import torch
import torch.nn as nn
import torchvision
import poptorch
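# Note: poptorch is assumed to come from Graphcore's Poplar SDK wheels rather
# than a plain PyPI install, and needs IPU hardware (or the IPU Model
# emulator) to run.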

# Normal PyTorch batch sizes.
training_batch_size = 20
validation_batch_size = 100

opts = poptorch.Options()
# Device "step": how many batches the IPU processes per call to the model.
opts.deviceIterations(20)

# How many IPUs to replicate over.
opts.replicationFactor(4)

opts.randomSeed(42)

# Load MNIST normally.
training_data = poptorch.DataLoader(
    opts,
    torchvision.datasets.MNIST('mnist_data/',
                               train=True,
                               download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307, ), (0.3081, ))
                               ])),
    batch_size=training_batch_size,
    shuffle=True)
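# Note: poptorch.DataLoader wraps torch.utils.data.DataLoader and scales the
# per-step batch by the options above, so each item yielded here holds
# batch_size * device_iterations * replication_factor = 20 * 20 * 4 = 1600
# samples (assuming no gradient accumulation is set).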

# Load the MNIST test split for validation; the validation DataLoader gets
# its own default Options object.
val_options = poptorch.Options()
validation_data = poptorch.DataLoader(
    val_options,
    torchvision.datasets.MNIST('mnist_data/',
                               train=False,
                               download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307, ), (0.3081, ))
                               ])),
    batch_size=validation_batch_size,
    shuffle=True,
    drop_last=True)
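# With the default Options (one device iteration, no replication) each
# validation step sees exactly validation_batch_size samples, and
# drop_last=True discards any partial final batch so the accuracy counting
# below can assume full batches.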

# A helper block to build convolution-pool-relu blocks.
class Block(nn.Module):
    def __init__(self, in_channels, num_filters, kernel_size, pool_size):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(in_channels,
                              num_filters,
                              kernel_size=kernel_size)
        self.pool = nn.MaxPool2d(kernel_size=pool_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        x = self.relu(x)
        return x

# Define the network using the above blocks.
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Block(1, 10, 5, 2)
        self.layer2 = Block(10, 20, 5, 2)
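        # For 1x28x28 MNIST inputs the two blocks above produce a 20x4x4
        # feature map, i.e. 320 values per image, which is why layer3
        # expects 320 inputs and forward() flattens with view(-1, 320).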
        self.layer3 = nn.Linear(320, 256)
        self.layer3_act = nn.ReLU()
        self.layer4 = nn.Linear(256, 10)

        self.softmax = nn.LogSoftmax(1)
        self.loss = nn.NLLLoss(reduction="mean")

    def forward(self, x, target=None):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(-1, 320)

        x = self.layer3_act(self.layer3(x))
        x = self.layer4(x)
        x = self.softmax(x)

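        # PopTorch needs the loss to be part of the compiled graph, so it is
        # computed inside forward() and returned; trainingModel() then uses
        # the returned loss for the backward pass and weight update.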
        if target is not None:
            loss = self.loss(x, target)
            return x, loss
        return x

# Create our model.
model = Network()

# Create the model for training, which will run on the IPU.
training_model = poptorch.trainingModel(model, training_data.options)

# Same model as above; the two wrappers share the weights held in 'model',
# which can be copied back to the host once training is finished.
inference_model = poptorch.inferenceModel(model, validation_data.options)
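# Assumption: no optimizer is passed to trainingModel() above, so training
# relies on whatever default optimizer the installed PopTorch release
# provides (plain SGD in the versions this example was written against);
# pass a torch.optim or poptorch.optim optimizer explicitly to control the
# learning rate.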

def train():
    for batch_number, (data, labels) in enumerate(training_data):
        output, losses = training_model(data, labels)
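        # The first call to training_model above compiles the program for the
        # IPU, so iteration 0 is much slower than the ones that follow.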

        if batch_number % 10 == 0:
            print(f"PopTorch IPU loss at batch: {batch_number} is {losses}")

            # Pick the highest probability.
            _, ind = torch.max(output, 1)
            assert training_data.options.anchor_mode in (
                poptorch.AnchorMode.All, poptorch.AnchorMode.Final
            ), "Only 'Final' and 'All' AnchorMode supported"
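            # With AnchorMode.Final (the training default) only the output of
            # the last device iteration is returned to the host, so 'output'
            # has fewer rows than 'labels'; the slice below keeps the labels
            # that correspond to it.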
            # If we're using Final: only keep the last labels, no-op if using All
            num_labels = ind.shape[0]
            labels = labels[-num_labels:]
            eq = torch.eq(ind, labels)
            elms, counts = torch.unique(eq,
                                        sorted=False,
                                        return_counts=True)

            acc = 0.0
            if len(elms) == 2:
                if elms[0]:
                    acc = (counts[0].item() / num_labels) * 100.0
                else:
                    acc = (counts[1].item() / num_labels) * 100.0
            elif elms[0]:
                # Every prediction in the batch was correct, so torch.unique
                # returned only a single element.
                acc = 100.0

            print(
                f"Training accuracy: {acc}% from batch of size {num_labels}")
    print("Done training")

def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for (data, labels) in validation_data:
            output = inference_model(data)

            # Argmax the probabilities to get the highest.
            _, ind = torch.max(output, 1)

            # Compare it against the ground truth for this batch.
            eq = torch.eq(ind, labels)

            # Count the number which are True and the number which are False.
            elms, counts = torch.unique(eq,
                                        sorted=False,
                                        return_counts=True)

            if len(elms) == 2 or elms[0]:
                if elms[0]:
                    correct += counts[0].item()
                else:
                    correct += counts[1].item()

            total += validation_batch_size
    print(f"Validation: of {total} samples we got: "
          f"{(correct / total) * 100.0}% correct")

# Train on IPU.
train()
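# Note: depending on the PopTorch version, the trained weights may need to be
# copied back to the host explicitly (training_model.copyWeightsToHost())
# before the inference wrapper sees them; recent releases do this implicitly.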

test()