9. Examples

You can find more PyTorch examples and tutorials in the Graphcore GitHub repositories.

9.1. MNIST example

The example in Listing 9.1 shows how an MNIST model can be run on the IPU. The highlighted lines show the PopTorch-specific code required to run the example on multiple IPUs.

You can download the full source code from GitHub: mnist.py.

To run this example, you will need to install the Poplar SDK (see the Getting Started Guide for your IPU system) and the appropriate version of torchvision:

$ python3 -m pip install torchvision==0.11.1
Listing 9.1 MNIST example
 16 import torch
 17 import torch.nn as nn
 18 import torchvision
 19 import poptorch
 20
 21 # Normal pytorch batch size
 22 training_batch_size = 20
 23 validation_batch_size = 100
 24
 25 opts = poptorch.Options()
 26 # Device "step"
 27 opts.deviceIterations(20)
 28
 29 # How many IPUs to replicate over.
 30 opts.replicationFactor(4)
 31
 32 opts.randomSeed(42)
 33
 34 # Load MNIST normally.
 35 training_data = poptorch.DataLoader(
 36     opts,
 37     torchvision.datasets.MNIST('mnist_data/',
 38                                train=True,
 39                                download=True,
 40                                transform=torchvision.transforms.Compose([
 41                                    torchvision.transforms.ToTensor(),
 42                                    torchvision.transforms.Normalize(
 43                                        (0.1307, ), (0.3081, ))
 44                                ])),
 45     batch_size=training_batch_size,
 46     shuffle=True)
 47
 48 # Load MNIST normally.
 49 val_options = poptorch.Options()
 50 validation_data = poptorch.DataLoader(
 51     val_options,
 52     torchvision.datasets.MNIST('mnist_data/',
 53                                train=True,
 54                                download=True,
 55                                transform=torchvision.transforms.Compose([
 56                                    torchvision.transforms.ToTensor(),
 57                                    torchvision.transforms.Normalize(
 58                                        (0.1307, ), (0.3081, ))
 59                                ])),
 60     batch_size=validation_batch_size,
 61     shuffle=True,
 62     drop_last=True)
 63
 64 # A helper block to build convolution-pool-relu blocks.
 65 class Block(nn.Module):
 66     def __init__(self, in_channels, num_filters, kernel_size, pool_size):
 67         super(Block, self).__init__()
 68         self.conv = nn.Conv2d(in_channels,
 69                               num_filters,
 70                               kernel_size=kernel_size)
 71         self.pool = nn.MaxPool2d(kernel_size=pool_size)
 72         self.relu = nn.ReLU()
 73
 74     def forward(self, x):
 75         x = self.conv(x)
 76         x = self.pool(x)
 77         x = self.relu(x)
 78         return x
 79
 80 # Define the network using the above blocks.
 81 class Network(nn.Module):
 82     def __init__(self):
 83         super().__init__()
 84         self.layer1 = Block(1, 10, 5, 2)
 85         self.layer2 = Block(10, 20, 5, 2)
 86         self.layer3 = nn.Linear(320, 256)
 87         self.layer3_act = nn.ReLU()
 88         self.layer4 = nn.Linear(256, 10)
 89
 90         self.softmax = nn.LogSoftmax(1)
 91         self.loss = nn.NLLLoss(reduction="mean")
 92
 93     def forward(self, x, target=None):
 94         x = self.layer1(x)
 95         x = self.layer2(x)
 96         x = x.view(-1, 320)
 97
 98         x = self.layer3_act(self.layer3(x))
 99         x = self.layer4(x)
100         x = self.softmax(x)
101
102         if target is not None:
103             loss = self.loss(x, target)
104             return x, loss
105         return x
106
# Build the network once; both executors below wrap this same instance.
model = Network()

# Training executor, configured with the training loader's options.
training_model = poptorch.trainingModel(model,
                                        options=training_data.options)

# Inference executor over the same `model`, so it shares its weights; once
# training is finished the trained weights can be copied back into `model`.
inference_model = poptorch.inferenceModel(model,
                                          options=validation_data.options)

def train():
    """Run one pass over ``training_data`` with ``training_model``.

    Every 10th step, print the loss and the accuracy of the predictions
    returned for that step.

    Raises:
        RuntimeError: if the training model's output mode is neither
            ``OutputMode.All`` nor ``OutputMode.Final``.
    """
    # The label alignment below only understands these two output modes; the
    # options do not change during the loop, so check once up front.
    if training_model.options.output_mode not in (poptorch.OutputMode.All,
                                                  poptorch.OutputMode.Final):
        raise RuntimeError("Only 'Final' and 'All' OutputMode supported")

    for batch_number, (data, labels) in enumerate(training_data):
        output, losses = training_model(data, labels)

        if batch_number % 10 != 0:
            continue

        print(f"PoptorchIPU loss at batch: {batch_number} is {losses}")

        # Pick the highest probability.
        _, ind = torch.max(output, 1)
        # If we're using Final: only keep the last labels, no-op if using All
        num_labels = ind.shape[0]
        labels = labels[-num_labels:]

        # Summing the boolean match tensor counts correct predictions, which
        # also handles batches where every prediction is right (previously
        # reported as 0%) or every prediction is wrong.
        num_correct = torch.eq(ind, labels).sum().item()
        acc = (num_correct / num_labels) * 100.0

        print(f"Training accuracy: {acc}% from batch of size {num_labels}")
    print("Done training")

def test():
    """Evaluate ``inference_model`` on ``validation_data`` and print accuracy."""
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in validation_data:
            output = inference_model(data)

            # Argmax the probabilities to get the highest.
            _, ind = torch.max(output, 1)

            # Summing the boolean match tensor counts correct predictions and
            # needs no special case when all (or none) are correct.
            correct += torch.eq(ind, labels).sum().item()

            # Count the samples actually evaluated rather than assuming every
            # batch holds exactly validation_batch_size samples.
            total += labels.size(0)

    if total == 0:
        print("Validation: no samples were evaluated")
        return
    print("Validation: of " + str(total) + " samples we got: " +
          str((correct / total) * 100.0) + "% correct")

# Run the example: train on the IPUs first, then evaluate with the inference
# model, which wraps the same `model` instance as the training model.
train()
test()