1#!/usr/bin/env python3
2# Copyright (c) 2020 Graphcore Ltd. All rights reserved.
3
4import sys
5import torch
6import torch.nn as nn
7import torchvision
8import poptorch
9
# This example needs real IPU hardware: when none is available, log a
# warning and exit with status 0 before anything else is set up.
if not poptorch.ipuHardwareIsAvailable():
    # NOTE(review): assumes poptorch.logger is a standard logging.Logger, for
    # which `warn` is a deprecated alias of `warning` — confirm.
    poptorch.logger.warning("This example requires IPU hardware to run")
    sys.exit(0)
13
# Normal PyTorch batch sizes, passed as `batch_size` to the DataLoaders below.
training_batch_size = 20
validation_batch_size = 100

# Options for the training run; they are handed to the training DataLoader,
# and the training model is later built from that loader's options.
opts = poptorch.Options()
# Device "step".
# NOTE(review): presumably the number of batches the device processes per
# call to the model — confirm against the PopTorch Options documentation.
opts.deviceIterations(20)

# How many IPUs to replicate over.
opts.replicationFactor(4)

# Fixed random seed.
# NOTE(review): presumably for run-to-run reproducibility — confirm what the
# seed covers in the PopTorch documentation.
opts.randomSeed(42)
26
# Training data: the MNIST training split, stored under 'mnist_data/' and
# downloaded when requested via `download=True`. Each image is converted to a
# tensor and normalised with the constants 0.1307 and 0.3081, then served in
# shuffled batches of `training_batch_size` using the training options.
training_data = poptorch.DataLoader(
    opts,
    torchvision.datasets.MNIST('mnist_data/',
                               train=True,
                               download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307, ),
                                                                    (0.3081, ))
                               ])),
    batch_size=training_batch_size,
    shuffle=True)
40
# Validation data: the held-out MNIST test split (train=False), so that the
# reported accuracy is measured on samples the model was not trained on.
# The same tensor conversion and normalisation as for the training data are
# applied. Default options are used for the validation loader.
val_options = poptorch.Options()
validation_data = poptorch.DataLoader(
    val_options,
    torchvision.datasets.MNIST('mnist_data/',
                               train=False,
                               download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307, ),
                                                                    (0.3081, ))
                               ])),
    batch_size=validation_batch_size,
    shuffle=True,
    # Drop the final partial batch so every batch has validation_batch_size
    # samples.
    drop_last=True)
56
57
58# A helper block to build convolution-pool-relu blocks.
class Block(nn.Module):
    """Convolution followed by max-pooling and a ReLU activation.

    Args:
        in_channels: Number of channels of the input passed to ``nn.Conv2d``.
        num_filters: Number of output channels of the convolution.
        kernel_size: Kernel size of the convolution.
        pool_size: Kernel size of the max-pooling layer.
    """

    def __init__(self, in_channels, num_filters, kernel_size, pool_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_filters, kernel_size=kernel_size)
        self.pool = nn.MaxPool2d(kernel_size=pool_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(pool(conv(x)))``."""
        return self.relu(self.pool(self.conv(x)))
73
74
75# Define the network using the above blocks.
class Network(nn.Module):
    """Small convolutional classifier with an optional built-in loss.

    Two ``Block`` stages feed two fully connected layers; the output goes
    through a log-softmax over dimension 1. When a target is supplied to
    ``forward`` the mean negative log-likelihood loss is returned as well.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 10 -> 20 channels.
        self.layer1 = Block(in_channels=1, num_filters=10, kernel_size=5, pool_size=2)
        self.layer2 = Block(in_channels=10, num_filters=20, kernel_size=5, pool_size=2)

        # Classifier head: 320 flattened features -> 256 -> 10 outputs.
        self.layer3 = nn.Linear(in_features=320, out_features=256)
        self.layer3_act = nn.ReLU()
        self.layer4 = nn.Linear(in_features=256, out_features=10)

        self.softmax = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss(reduction="mean")

    def forward(self, x, target=None):
        """Return the log-softmax output, plus the loss when ``target`` is given.

        Args:
            x: Input batch fed to the first convolution block.
            target: Optional targets for ``nn.NLLLoss``.

        Returns:
            The log-softmax output alone when ``target`` is ``None``,
            otherwise the tuple ``(output, loss)``.
        """
        features = self.layer2(self.layer1(x))

        # Flatten to rows of 320 features, the input width of layer3.
        flattened = features.view(-1, 320)

        hidden = self.layer3_act(self.layer3(flattened))
        log_probs = self.softmax(self.layer4(hidden))

        if target is None:
            return log_probs
        return log_probs, self.loss(log_probs, target)
101
102
# Create our model.
model = Network()

# Create the model for training, which will run on the IPU. It is built from
# the options held by the training DataLoader.
training_model = poptorch.trainingModel(model, training_data.options)

# Inference wrapper around the same 'model' instance, built from the
# validation DataLoader's options. Both wrappers share the weights held in
# 'model', which can be copied back once training is finished.
inference_model = poptorch.inferenceModel(model, validation_data.options)
111
112
def train():
    """Run one pass over ``training_data`` through ``training_model``.

    Every tenth batch, the loss returned by the model is printed together
    with the accuracy of the returned predictions on that batch. Prints
    "Done training" once the loader is exhausted.

    Raises:
        AssertionError: If the DataLoader's anchor mode is neither ``All``
            nor ``Final``.
    """
    # The label slicing below is only valid for these anchor modes. The
    # options do not change while iterating, so check once up front.
    assert training_data.options.anchor_mode in (
        poptorch.AnchorMode.All, poptorch.AnchorMode.Final
    ), "Only 'Final' and 'All' AnchorMode supported"

    for batch_number, (data, labels) in enumerate(training_data):
        output, losses = training_model(data, labels)

        # Only report progress every tenth batch.
        if batch_number % 10 != 0:
            continue

        print(f"PoptorchIPU loss at batch: {batch_number} is {losses}")

        # Pick the highest probability.
        _, ind = torch.max(output, 1)

        # If we're using Final: only keep the last labels, no-op if using All
        num_labels = ind.shape[0]
        labels = labels[-num_labels:]

        # Count the matches directly. Going through torch.unique reported
        # 0% whenever every prediction in the batch was correct, because
        # only a single unique value (True) is returned in that case.
        num_correct = torch.eq(ind, labels).sum().item()
        acc = (num_correct / num_labels) * 100.0

        print(
            f"Training accuracy: {acc}% from batch of size {num_labels}")
    print("Done training")
141
142
def test():
    """Evaluate ``inference_model`` on ``validation_data`` and print accuracy.

    Predictions are the index of the largest output along dimension 1; they
    are compared with the labels of each batch and the overall percentage of
    matches is printed. If the loader yields no samples a message is printed
    instead of dividing by zero.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in validation_data:
            output = inference_model(data)

            # Take the index of the highest output for each sample.
            _, ind = torch.max(output, 1)

            # Compare against the ground truth and count the matches.
            correct += torch.eq(ind, labels).sum().item()

            # Count the samples actually seen rather than assuming every
            # batch holds exactly validation_batch_size elements.
            total += labels.shape[0]

    if total == 0:
        print("Validation: no samples were evaluated")
        return

    print("Validation: of " + str(total) + " samples we got: " +
          str((correct / total) * 100.0) + "% correct")
168
169
# Train on IPU.
train()

# Then measure accuracy on the validation data with the inference model.
test()