5. IPU supported operations
Below is a list of the operations currently supported for execution on IPU hardware. This list will be expanded over time as support is added. Some overloads and modes of operation are not supported; we have tried to list all the caveats, but some may have been missed.
5.1. Torch operations
5.1.1. Tensor operations
Many of the tensor operations are executed before the graph ever reaches the IPU, so they can be considered supported in any case. Some, such as contiguous(), make no sense on a distributed memory system like the IPU and are therefore ignored. There are no constraints on the memory format with which operations are called, other than that all graph inputs must be contiguous.
We will also create tensor views. However, the aliasing property of views with respect to in-place operations should not be relied on, because our view behaviour may differ slightly from PyTorch's.
Additionally, some PyTorch operations may be implemented by composition of the listed ops and are therefore supported even though they are not explicitly listed.
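As a minimal illustration of the contiguity constraint (the tensor and the transpose are purely illustrative), a non-contiguous input can be made contiguous before it is passed to a compiled model:

import torch

x = torch.randn(4, 8).t()   # transposing yields a non-contiguous view
assert not x.is_contiguous()

x = x.contiguous()          # graph inputs must be contiguous
assert x.is_contiguous()
# x can now be passed as an input to a poptorch-wrapped model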
Creation ops
torch.arange
tensor.fill
torch.full
torch.full_like
torch.ones
torch.zeros
Indexing, Slicing, Joining, Mutating Ops
PyTorch functions
torch.cat
torch.chunk
torch.reshape
torch.stack
torch.split
torch.squeeze
torch.t
torch.transpose
torch.unsqueeze
torch.where
Tensor methods
tensor.expand
tensor.expand_as
tensor.masked_fill
Random Samplers
To set the random state, use poptorch.Options.randomSeed. A brief sketch follows the list below.
torch.distributions.Uniform
torch.normal
torch.rand
torch.randn
torch.uniform
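A minimal sketch of setting the random state; the toy model, tensor shapes and seed value are purely illustrative:

import torch
import poptorch

class RandModel(torch.nn.Module):
    def forward(self, x):
        return x + torch.rand(x.shape)

# Fix the random state so that runs are reproducible.
opts = poptorch.Options()
opts.randomSeed(42)  # illustrative seed value

poptorch_model = poptorch.inferenceModel(RandModel(), opts)
out = poptorch_model(torch.zeros(2, 2))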
5.1.2. Math operations
Pointwise Ops
torch.abs
torch.add
torch.asin
torch.atan
torch.ceil
torch.clamp
torch.cos
torch.cosh
torch.div
torch.exp
torch.expm1
torch.floor
torch.floor_divide
torch.frac
torch.log
torch.log10
torch.log1p
torch.log2
torch.mul
torch.norm
torch.neg
torch.pow
torch.reciprocal
torch.round
torch.rsqrt
torch.sigmoid
torch.sign
torch.sin
torch.sinh
torch.sqrt
torch.square
torch.sub
torch.tan
torch.tanh
torch.true_divide
torch.trunc
Reduction Ops
torch.argmax
torch.argmin
torch.mean
torch.prod
torch.logsumexp
torch.sum
Comparison Ops
torch.eq
torch.ge
torch.gt
torch.le
torch.lt
torch.min and torch.max only support the (tensor, tensor) and (tensor) overloads. They do not support the (tensor, dim=.*, keepdim=.*) overload (see the example after this list).
torch.max
torch.min
torch.ne
torch.isnan
torch.topk only supports the sorted=True and largest=True arguments.
torch.topk
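The snippet below illustrates the supported and unsupported overloads described above; it runs as plain PyTorch, and the tensor shapes are arbitrary:

import torch

a = torch.randn(3, 4)
b = torch.randn(3, 4)

m1 = torch.max(a)      # (tensor) overload: supported
m2 = torch.max(a, b)   # (tensor, tensor) overload: supported
# torch.max(a, dim=1, keepdim=True) is not supported

# torch.topk: only sorted=True and largest=True
values, indices = torch.topk(a, k=2, dim=1, largest=True, sorted=True)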
Other Ops
torch.cumsum
torch.meshgrid
torch.cartesian_prod
torch.tensordot
BLAS and LAPACK Operations
torch.addmm
torch.matmul
torch.bmm
5.2. Torch.nn operations
5.2.1. Containers
torch.nn.Module and torch.nn.Sequential can be passed into our compiler wrappers and just work.
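For example, a toy torch.nn.Sequential (the layer sizes are arbitrary) can be wrapped directly:

import torch
import poptorch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

poptorch_model = poptorch.inferenceModel(model)
output = poptorch_model(torch.randn(1, 4))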
5.2.2. Convolution layers
Conv transpose operations do not yet support dilations.
torch.nn.Conv1d
torch.nn.Conv2d
torch.nn.Conv3d
torch.nn.ConvTranspose1d
torch.nn.ConvTranspose2d
torch.nn.ConvTranspose3d
5.2.3. Pooling layers
Currently the max pool layers do not return the indices, so only the variants with return_indices=False are supported.
torch.nn.MaxPool1d
torch.nn.MaxPool2d
torch.nn.MaxPool3d
torch.nn.AvgPool1d
torch.nn.AvgPool2d
torch.nn.AvgPool3d
torch.nn.AdaptiveAvgPool2d
5.2.4. Padding layers
All padding layers are supported.
torch.nn.ReflectionPad1d
torch.nn.ReflectionPad2d
torch.nn.ReplicationPad1d
torch.nn.ReplicationPad2d
torch.nn.ReplicationPad3d
torch.nn.ZeroPad2d
torch.nn.ConstantPad1d
torch.nn.ConstantPad2d
torch.nn.ConstantPad3d
5.2.5. Activations
torch.nn.ELU
torch.nn.GELU
torch.nn.LeakyReLU
torch.nn.LogSoftmax
torch.nn.ReLU
torch.nn.SELU
torch.nn.Sigmoid
torch.nn.Softmax
torch.nn.Softsign
torch.nn.Tanh
torch.nn.PReLU
torch.nn.Hardtanh
torch.nn.functional.glu
5.2.6. Normalization layers
Currently only affine=True is supported as a parameter; that is, only the variants with trainable parameters are supported.
torch.nn.BatchNorm1d
torch.nn.BatchNorm2d
torch.nn.BatchNorm3d
torch.nn.LayerNorm
torch.nn.GroupNorm
torch.nn.InstanceNorm1d
torch.nn.InstanceNorm2d
torch.nn.InstanceNorm3d
5.2.7. Recurrent layers
torch.nn.LSTM
5.2.8. Linear layers
torch.nn.Identity
torch.nn.Linear
torch.nn.Bilinear
5.2.9. Dropout
torch.nn.Dropout
5.2.10. Sparse layers
Embedding is supported, except that the padding_idx argument is ignored.
torch.nn.Embedding
5.2.11. Loss functions
This version supports a limited subset of loss functions. However, we support poptorch.identity_loss(), which gives you the ability to implement an arbitrary loss function; a sketch of its use follows the list below.
One caveat for the following loss functions: if they are used, they will always be included in the back-propagation and will always receive a gradient. This is a slight deviation from normal PyTorch operation, where they have to opt in to the gradient pass.
torch.nn.L1Loss
torch.nn.MSELoss
torch.nn.CrossEntropyLoss
torch.nn.NLLLoss
torch.nn.BCELoss
torch.nn.KLDivLoss
torch.nn.PoissonNLLLoss
torch.nn.HingeEmbeddingLoss
torch.nn.BCEWithLogitsLoss
torch.nn.SmoothL1Loss
torch.nn.SoftMarginLoss
torch.nn.CosineEmbeddingLoss
torch.nn.MarginRankingLoss
torch.nn.TripletMarginLoss
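As a sketch of using poptorch.identity_loss() to build a custom loss, assuming an illustrative model, layer sizes and an L1-style loss expression (none of which come from this document):

import torch
import poptorch

class ModelWithCustomLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x, target):
        out = self.linear(x)
        # Mark an arbitrary expression as the loss to be back-propagated.
        loss = poptorch.identity_loss((out - target).abs(), reduction="mean")
        return out, loss

model = ModelWithCustomLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
poptorch_model = poptorch.trainingModel(model, optimizer=optimizer)
out, loss = poptorch_model(torch.randn(8, 4), torch.randn(8, 1))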
5.2.12. Vision Layers
Only the nearest mode is supported, as shown below.
torch.nn.Upsample
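A one-line illustration of the supported mode (the scale factor and tensor shape are arbitrary):

import torch

upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")  # supported
# Other modes, such as mode="bilinear", are not supported
y = upsample(torch.randn(1, 3, 8, 8))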
5.3. Float 16 operations
Due to the limitation of PyTorch’s float 16 support on the CPU (used for tracing the model), certain operations may result in the use of float 32 where float 16 would be expected, or float 16 where float 32 would be expected. This is because the model must always be traced with float 16 inputs converted to float 32.
5.3.1. Casting
The dtype argument of tensor.to(dtype) will be ignored, because it may refer to one or more float 16 tensors which were converted to float 32 to allow tracing to happen, for example a.to(b.dtype) where b may be a float 16 tensor converted to a float 32 tensor.
Once the output of the op, or one of its descendants, encounters a known float 16 or float 32 input, the type will be resolved to that type.
The following examples show cases where the casting is resolved based on context, correctly or incorrectly:
import torch
import poptorch


class Model(torch.nn.Module):
    def forward(self, x, y):
        # y.dtype is ignored, however the type is resolved to be the type of y
        return x.to(y.dtype) + y


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

assert native_model(float16_tensor, float16_tensor).dtype == torch.float16
assert native_model(float16_tensor, float32_tensor).dtype == torch.float32
assert native_model(float32_tensor, float16_tensor).dtype == torch.float16
assert native_model(float32_tensor, float32_tensor).dtype == torch.float32

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor, float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor, float32_tensor).dtype == torch.float32

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor, float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor, float32_tensor).dtype == torch.float32
class Model(torch.nn.Module):
    def forward(self, x, y):
        # torch.float32 is ignored and the type is resolved to be the type of y
        return x.to(torch.float32) + y


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

assert native_model(float16_tensor, float16_tensor).dtype == torch.float32
assert native_model(float32_tensor, float16_tensor).dtype == torch.float32

# This incorrectly results in a float 16 tensor
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor, float16_tensor).dtype == torch.float16

# This incorrectly results in a float 16 tensor
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor, float16_tensor).dtype == torch.float16
5.3.2. Creation functions
The following functions are affected:
torch.ones
torch.rand
torch.zeros
torch.distributions.uniform.Uniform
The dtype arguments will be ignored because they may refer to float 16 tensors which were converted to float 32 tensors to allow tracing to succeed. Once the output of the op, or one of its descendants, encounters a known float 16 or float 32 input, the dtype is resolved to that type.
The following examples show cases where the type output differs from PyTorch:
## torch.ones and zeros
class Model(torch.nn.Module):
    def forward(self, x):
        # dtype is ignored, however the type is resolved to be the type of x
        return torch.zeros((2, 3, 4), dtype=torch.float32) + x


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

# The native model always yields a float32 tensor
assert native_model(float16_tensor).dtype == torch.float32
assert native_model(float32_tensor).dtype == torch.float32

# The poptorch model will resolve to the type of x
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor).dtype == torch.float32
## torch.rand
class Model(torch.nn.Module):
    def forward(self, x):
        # dtype is ignored, however the type is resolved to be the type of x
        return torch.rand((2, 3, 4), dtype=torch.float32) + x


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

# The native model always yields a float32 tensor
assert native_model(float16_tensor).dtype == torch.float32
assert native_model(float32_tensor).dtype == torch.float32

# The poptorch model will resolve to the type of x
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor).dtype == torch.float32
## torch.distributions.uniform.Uniform
class Model(torch.nn.Module):
    def forward(self, x):
        # dtype is ignored, however the type is resolved to be the type of x
        ud = torch.distributions.uniform.Uniform(
            torch.tensor([10.0], dtype=torch.float16),
            torch.tensor([10.0], dtype=torch.float32))
        return ud.sample((10, 10, 1000)) + x


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

# The native model always yields a float32 tensor
assert native_model(float16_tensor).dtype == torch.float32
assert native_model(float32_tensor).dtype == torch.float32

# The poptorch model will resolve to the type of x
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor).dtype == torch.float16