9. Examples

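You can find PyTorch examples and tutorials in the Graphcore GitHub examples repository. This contains examples of popular machine learning models for training and inference, tutorials, examples of PopTorch features, and code used in technical notes, videos and blogs.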

9.1. MNIST example

The example in Listing 9.1 shows how an MNIST model can be run on the IPU. The highlighted lines (those that use the `poptorch` module) show the PopTorch-specific code required to run the example on multiple IPUs.

You can download the full source code from GitHub: mnist.py.

To run this example you will need to install the Poplar SDK (see the Getting Started Guide for your IPU system) and the appropriate version of torchvision:

$ python3 -m pip install torchvision==0.11.1
Listing 9.1 MNIST example
 16import torch
 17import torch.nn as nn
 18import torchvision
 19import poptorch
 20
 21# Normal pytorch batch size
 22training_batch_size = 20
 23validation_batch_size = 100
 24
 25opts = poptorch.Options()
 26# Device "step"
 27opts.deviceIterations(20)
 28
 29# How many IPUs to replicate over.
 30opts.replicationFactor(4)
 31
 32opts.randomSeed(42)
 33
 34# Load MNIST normally.
 35training_data = poptorch.DataLoader(
 36    opts,
 37    torchvision.datasets.MNIST('mnist_data/',
 38                               train=True,
 39                               download=True,
 40                               transform=torchvision.transforms.Compose([
 41                                   torchvision.transforms.ToTensor(),
 42                                   torchvision.transforms.Normalize(
 43                                       (0.1307, ), (0.3081, ))
 44                               ])),
 45    batch_size=training_batch_size,
 46    shuffle=True)
 47
 48# Load MNIST normally.
 49val_options = poptorch.Options()
 50validation_data = poptorch.DataLoader(
 51    val_options,
 52    torchvision.datasets.MNIST('mnist_data/',
 53                               train=True,
 54                               download=True,
 55                               transform=torchvision.transforms.Compose([
 56                                   torchvision.transforms.ToTensor(),
 57                                   torchvision.transforms.Normalize(
 58                                       (0.1307, ), (0.3081, ))
 59                               ])),
 60    batch_size=validation_batch_size,
 61    shuffle=True,
 62    drop_last=True)
 63
 64# A helper block to build convolution-pool-relu blocks.
 65class Block(nn.Module):
 66    def __init__(self, in_channels, num_filters, kernel_size, pool_size):
 67        super(Block, self).__init__()
 68        self.conv = nn.Conv2d(in_channels,
 69                              num_filters,
 70                              kernel_size=kernel_size)
 71        self.pool = nn.MaxPool2d(kernel_size=pool_size)
 72        self.relu = nn.ReLU()
 73
 74    def forward(self, x):
 75        x = self.conv(x)
 76        x = self.pool(x)
 77        x = self.relu(x)
 78        return x
 79
 80# Define the network using the above blocks.
 81class Network(nn.Module):
 82    def __init__(self):
 83        super().__init__()
 84        self.layer1 = Block(1, 10, 5, 2)
 85        self.layer2 = Block(10, 20, 5, 2)
 86        self.layer3 = nn.Linear(320, 256)
 87        self.layer3_act = nn.ReLU()
 88        self.layer4 = nn.Linear(256, 10)
 89
 90        self.softmax = nn.LogSoftmax(1)
 91        self.loss = nn.NLLLoss(reduction="mean")
 92
 93    def forward(self, x, target=None):
 94        x = self.layer1(x)
 95        x = self.layer2(x)
 96        x = x.view(-1, 320)
 97
 98        x = self.layer3_act(self.layer3(x))
 99        x = self.layer4(x)
100        x = self.softmax(x)
101
102        if target is not None:
103            loss = self.loss(x, target)
104            return x, loss
105        return x
106
# Build the host-side network once; both PopTorch wrappers below wrap this
# same `model` object.
model = Network()

# Training wrapper, configured with the options attached to the training
# DataLoader.
training_model = poptorch.trainingModel(model, options=training_data.options)

# Inference wrapper around the same `model`, configured with the validation
# DataLoader's options. The two wrappers share the weights held in `model`,
# which can be copied back once training is finished.
inference_model = poptorch.inferenceModel(model,
                                          options=validation_data.options)

def train():
    """Run ``training_model`` once over every batch of ``training_data``.

    Every 10th batch, prints the loss returned by the model and the
    percentage of correct arg-max predictions among the returned outputs.
    """
    for batch_number, (data, labels) in enumerate(training_data):
        output, losses = training_model(data, labels)

        if batch_number % 10 == 0:
            print(f"PoptorchIPU loss at batch: {batch_number} is {losses}")

            # Pick the highest probability.
            _, ind = torch.max(output, 1)
            assert training_model.options.output_mode in (
                poptorch.OutputMode.All, poptorch.OutputMode.Final
            ), "Only 'Final' and 'All' OutputMode supported"
            # If we're using Final: only keep the last labels, no-op if using All
            num_labels = ind.shape[0]
            labels = labels[-num_labels:]

            # Summing the boolean matches counts correct predictions in every
            # case. The previous unique()/counts logic reported 0% when the
            # whole batch was predicted correctly (only one unique value).
            correct = torch.eq(ind, labels).sum().item()
            acc = (correct / num_labels) * 100.0

            print(
                f"Training accuracy: {acc}% from batch of size {num_labels}"
            )
    print("Done training")

def test():
    """Evaluate ``inference_model`` over ``validation_data`` and print accuracy.

    Counts how many arg-max predictions match the labels across all
    validation batches and prints the percentage of correct samples.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for (data, labels) in validation_data:
            output = inference_model(data)

            # Argmax the probabilities to get the highest.
            _, ind = torch.max(output, 1)

            # Summing the element-wise matches gives the number of correct
            # predictions directly (0 when none match).
            correct += torch.eq(ind, labels).sum().item()

            # Count the samples actually returned instead of assuming each
            # batch contains exactly `validation_batch_size` of them.
            total += labels.shape[0]

    if total == 0:
        # Avoid ZeroDivisionError if the loader produced no batches.
        print("Validation: no samples were evaluated")
        return
    print(f"Validation: of {total} samples we got: "
          f"{(correct / total) * 100.0}% correct")

# Train on the IPU first, then evaluate on the validation data using the
# weights shared through `model`.
train()
test()
179test()