1. Introduction

PopTorch is a set of extensions for PyTorch that enables PyTorch models to run directly on Graphcore IPU hardware. PopTorch has been designed to require as few changes as possible to your models in order to run them on the IPU. However, in order to get the most out of the IPU hardware, it does differ from native PyTorch execution in some respects.

PopTorch supports executing native PyTorch models for both inference and training. To run a PyTorch model on the IPU, you must wrap your model with either:

  • poptorch.trainingModel() if you want to train the model, or

  • poptorch.inferenceModel() if you want to run the model for inference.

Both of these functions accept a PyTorch model (torch.nn.Module) and create a representation of the model that can be executed on the IPU hardware.

In training mode, PopTorch uses its own automatic differentiation engine (autograd) that differs from native PyTorch. The input model (torch.nn.Module) is required to have at least one loss built into the forward pass. PopTorch backpropagates the gradients from the loss value(s) to update the model parameters. This is all taken care of automatically so your training loop does not need to call .backward() on the loss value(s) or .step() on the optimiser.
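
For example, a model with the loss built into the forward pass might look like the sketch below. This is only an illustrative sketch (the layer sizes and loss function are assumptions, and it is not the definition of the ExampleModelWithLoss class used in the listings that follow): the model computes its prediction and, when a target is provided, also computes and returns the loss.

    import torch

    class ModelWithLoss(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(1, 1)
            self.loss = torch.nn.MSELoss()

        def forward(self, x, target=None):
            out = self.fc(x)
            if target is None:
                # Inference: return only the prediction.
                return out
            # Training: also compute and return the loss so that PopTorch
            # can backpropagate from it.
            return out, self.loss(out, target)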

The following example shows a typical native PyTorch training loop. The model incorporates a loss criterion within the .forward() method and returns the loss value as a second output, alongside the prediction. The training loop manually calls .backward() on the loss to backpropagate the gradients, and manually updates the weights by calling .step() on the optimiser.

Listing 1.1 A simple example of training using PyTorch on the CPU
    import torch

    # ExampleDataset and ExampleModelWithLoss are assumed to be defined
    # already; the model returns its prediction and the loss computed in
    # its forward pass.
    training_data = torch.utils.data.DataLoader(ExampleDataset(shape=[1],
                                                               length=20000),
                                                batch_size=10,
                                                shuffle=True,
                                                drop_last=True)

    model = ExampleModelWithLoss()
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    momentum_loss = None

    for batch, target in training_data:
        # Zero the gradients accumulated on the previous iteration.
        optimizer.zero_grad()

        # Run the model: forward pass and loss evaluation.
        _, loss = model(batch, target)

        # Backpropagate the gradients.
        loss.backward()

        # Update the weights.
        optimizer.step()

        # Track an exponential moving average of the loss.
        if momentum_loss is None:
            momentum_loss = loss
        else:
            momentum_loss = momentum_loss * 0.95 + loss * 0.05

        # Once the loss is low enough, switch to a smaller learning rate.
        if momentum_loss < 0.1:
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

1.1. Data batching

An equivalent training loop executing the model on the IPU with PopTorch is shown below. The poptorch.DataLoader is used to load batches of data efficiently for the IPU. PopTorch follows the data batching semantics of PopART: by default, this means you simply pass in data of the normal batch size. However, PopTorch provides a number of options that enable more efficient data loading. See Efficient data batching for more information.

Notice that the torch.optim.AdamW optimiser is passed as an argument to the poptorch.trainingModel() wrapper, which applies the optimiser algorithm during training on the IPU. The optimiser state is managed automatically by the PopART framework, so there is no need to call the .step() method. Another significant change from the native training loop is that there is no loss.backward(): as mentioned above, PopTorch uses its own automatic differentiation engine and detects the loss value(s) from which to backpropagate the gradients.

Listing 1.2 Equivalent code using PopTorch to train on the IPU
    import torch
    import poptorch

    # Set up the PopTorch DataLoader; with 10 device iterations, each
    # iteration of the host loop below loads enough data for 10 steps
    # on the device.
    opts = poptorch.Options()
    opts.deviceIterations(10)
    training_data = poptorch.DataLoader(options=opts,
                                        dataset=ExampleDataset(shape=[1],
                                                               length=20000),
                                        batch_size=10,
                                        shuffle=True,
                                        drop_last=True)

    model = ExampleModelWithLoss()
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    # Wrap the model in a PopTorch training wrapper.
    poptorch_model = poptorch.trainingModel(model,
                                            options=opts,
                                            optimizer=optimizer)

    momentum_loss = None

    for batch, target in training_data:
        # Performs the forward pass, loss evaluation, backward pass and
        # weight update in one go on the device.
        _, loss = poptorch_model(batch, target)

        # Track an exponential moving average of the loss.
        if momentum_loss is None:
            momentum_loss = loss
        else:
            momentum_loss = momentum_loss * 0.95 + loss * 0.05

        # The optimiser can be updated via setOptimizer, for example to
        # reduce the learning rate once the loss is low enough.
        if momentum_loss < 0.1:
            poptorch_model.setOptimizer(
                torch.optim.AdamW(model.parameters(), lr=0.0001))

1.2. Distributed execution

For additional scalability, you can wrap individual layers in an IPU helper to designate which IPU should execute each layer. Using these user-provided annotations, PopTorch will use PopART to parallelise the model over the given number of IPUs. Additional parallelism can be expressed via a replication factor, which enables you to data-parallelise the model over more IPUs. See Distributed execution for additional information.
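
As an illustration, the sketch below annotates a hypothetical two-layer model with poptorch.BeginBlock so that each layer runs on its own IPU, and sets a replication factor through poptorch.Options to add data parallelism. The model, layer sizes and replication factor are assumptions made for this example only.

    import torch
    import poptorch

    # Hypothetical two-layer model, used purely for illustration.
    model = torch.nn.Sequential(torch.nn.Linear(100, 100),
                                torch.nn.Linear(100, 10))

    # Designate the IPU that should execute each layer.
    model[0] = poptorch.BeginBlock(model[0], ipu_id=0)
    model[1] = poptorch.BeginBlock(model[1], ipu_id=1)

    opts = poptorch.Options()
    # Data-parallel replication: two copies of the two-IPU model,
    # using four IPUs in total.
    opts.replicationFactor(2)

    poptorch_model = poptorch.inferenceModel(model, options=opts)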

1.3. Constraints

PopTorch uses PyTorch’s torch.jit.trace API and therefore inherits that API’s constraints. These include:

  • Inputs must be PyTorch tensors or tuples containing PyTorch tensors.

  • None can be used as a default value for a parameter but cannot be explicitly passed as an input value.

  • torch.jit.trace cannot handle control flow or shape variations within the model. That is, the inputs passed at run-time cannot vary the control flow of the model or the shapes/sizes of results. If you attempt this, the graph will be frozen to whichever control flow path was traced as a result of the first inputs given to the wrapped model.
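
To illustrate the last constraint, the sketch below (a hypothetical module, not one of this guide’s examples) shows torch.jit.trace freezing data-dependent control flow to the branch taken for the first input it is traced with:

    import torch

    class ThresholdExample(torch.nn.Module):
        def forward(self, x):
            # Data-dependent control flow: tracing records only the branch
            # taken for the example input and discards the other branch.
            if x.sum() > 0:
                return x * 2
            return x * 0

    traced = torch.jit.trace(ThresholdExample(), torch.ones(3))
    print(traced(torch.ones(3)))   # tensor([2., 2., 2.]): the traced branch
    print(traced(-torch.ones(3)))  # tensor([-2., -2., -2.]): still the traced
                                   # branch, although eager execution would
                                   # return zeros here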

Note

All tensor data types and shapes must be constant for the entire dataset.

Not all PyTorch operations have been implemented by the PopTorch compiler yet. See IPU supported operations for a list of operators that are supported on the IPU. Please also report any unsupported operators to support@graphcore.ai so that these ops may be incorporated into a future release.

1.4. Other resources

Please see Graphcore’s website for How-to Videos and Graphcore’s code examples GitHub repository for PopTorch applications, code examples and tutorials. Further developer resources can be found on Graphcore’s developer portal.