8. Examples

8.1. MNIST example

Listing 8.1 MNIST example
  1    import torch
  2    import torch.nn as nn
  3    import torchvision
  4    import poptorch
  5
  6    # Normal pytorch batch size
  7    training_batch_size = 20
  8    validation_batch_size = 100
  9
 10    opts = poptorch.Options()
 11    # Device "step"
 12    opts.deviceIterations(20)
 13
 14    # How many IPUs to replicate over.
 15    opts.replicationFactor(4)
 16
 17    opts.randomSeed(42)
 18
 19    # Load MNIST normally.
 20    training_data = poptorch.DataLoader(
 21        opts,
 22        torchvision.datasets.MNIST('mnist_data/',
 23                                   train=True,
 24                                   download=True,
 25                                   transform=torchvision.transforms.Compose([
 26                                       torchvision.transforms.ToTensor(),
 27                                       torchvision.transforms.Normalize(
 28                                           (0.1307, ), (0.3081, ))
 29                                   ])),
 30        batch_size=training_batch_size,
 31        shuffle=True)
 32
 33    # Load MNIST normally.
 34    val_options = poptorch.Options()
 35    validation_data = poptorch.DataLoader(
 36        val_options,
 37        torchvision.datasets.MNIST('mnist_data/',
 38                                   train=True,
 39                                   download=True,
 40                                   transform=torchvision.transforms.Compose([
 41                                       torchvision.transforms.ToTensor(),
 42                                       torchvision.transforms.Normalize(
 43                                           (0.1307, ), (0.3081, ))
 44                                   ])),
 45        batch_size=validation_batch_size,
 46        shuffle=True,
 47        drop_last=True)
 48
 49    # A helper block to build convolution-pool-relu blocks.
 50    class Block(nn.Module):
 51        def __init__(self, in_channels, num_filters, kernel_size, pool_size):
 52            super(Block, self).__init__()
 53            self.conv = nn.Conv2d(in_channels,
 54                                  num_filters,
 55                                  kernel_size=kernel_size)
 56            self.pool = nn.MaxPool2d(kernel_size=pool_size)
 57            self.relu = nn.ReLU()
 58
 59        def forward(self, x):
 60            x = self.conv(x)
 61            x = self.pool(x)
 62            x = self.relu(x)
 63            return x
 64
 65    # Define the network using the above blocks.
 66    class Network(nn.Module):
 67        def __init__(self):
 68            super().__init__()
 69            self.layer1 = Block(1, 10, 5, 2)
 70            self.layer2 = Block(10, 20, 5, 2)
 71            self.layer3 = nn.Linear(320, 256)
 72            self.layer3_act = nn.ReLU()
 73            self.layer4 = nn.Linear(256, 10)
 74
 75            self.softmax = nn.LogSoftmax(1)
 76            self.loss = nn.NLLLoss(reduction="mean")
 77
 78        def forward(self, x, target=None):
 79            x = self.layer1(x)
 80            x = self.layer2(x)
 81            x = x.view(-1, 320)
 82
 83            x = self.layer3_act(self.layer3(x))
 84            x = self.layer4(x)
 85            x = self.softmax(x)
 86
 87            if target is not None:
 88                loss = self.loss(x, target)
 89                return x, loss
 90            return x
 91
 92    # Create our model.
 93    model = Network()
 94
 95    # Create model for training which will run on IPU.
 96    training_model = poptorch.trainingModel(model, training_data.options)
 97
 98    # Same model as above, they will share weights (in 'model') which once training is finished can be copied back.
 99    inference_model = poptorch.inferenceModel(model, validation_data.options)
100
def train():
    """Run one pass over ``training_data`` with ``training_model``.

    Every 10th batch, print the loss returned by ``training_model`` and the
    accuracy of the returned predictions against the matching labels.
    Finally print "Done training".
    """
    # The label trimming below relies on these anchor modes. The setting
    # cannot change while iterating, so check it once, before training.
    assert training_data.options.anchor_mode in (
        poptorch.AnchorMode.All, poptorch.AnchorMode.Final
    ), "Only 'Final' and 'All' AnchorMode supported"

    for batch_number, (data, labels) in enumerate(training_data):
        output, losses = training_model(data, labels)

        if batch_number % 10 == 0:
            print(f"PoptorchIPU loss at batch: {batch_number} is {losses}")

            # Pick the highest probability.
            _, ind = torch.max(output, 1)
            # If we're using Final: only keep the last labels, no-op if using All
            num_labels = ind.shape[0]
            labels = labels[-num_labels:]

            # Summing the boolean matches counts correct predictions directly,
            # including batches where every (or no) prediction is correct.
            num_correct = torch.eq(ind, labels).sum().item()
            acc = (num_correct / num_labels) * 100.0

            print(
                f"Training accuracy: {acc}% from batch of size {num_labels}"
            )
    print("Done training")


def test():
    """Evaluate ``inference_model`` on ``validation_data`` and print accuracy.

    Accuracy is the percentage of samples whose arg-max prediction equals the
    label, accumulated over every batch the loader yields.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in validation_data:
            output = inference_model(data)

            # Argmax the probabilities to get the highest.
            _, ind = torch.max(output, 1)

            # Summing the boolean matches counts the correct predictions,
            # including batches where all or none of them are correct.
            correct += torch.eq(ind, labels).sum().item()

            # Count the samples actually compared rather than assuming every
            # batch holds exactly validation_batch_size entries.
            total += labels.size(0)
    print(f"Validation: of {total} samples we got: "
          f"{(correct / total) * 100.0}% correct")


# Train on the IPU, then evaluate on validation_data using inference_model,
# which shares weights with training_model through `model`.
train()

test()