11. Legacy tracing frontend

Warning

Tracing has been deprecated since PopTorch 3.0. We suggest you use the dispatcher frontend, which is enabled by default on supported platforms and brings many benefits, including greatly simplified handling of float16 operations. However, if you need to use tracing for legacy reasons, this section explains the limitations imposed and the workarounds available.

11.1. Dispatcher support

Up to version 2.6, PopTorch used torch.jit.trace to build a static graph representation of a torch.nn.Module.

However, this approach suffered from several limitations:

  • Only tensors could be passed as arguments.

  • The traced model ran on the CPU as part of the tracing process.

    • This was expensive for large batch sizes.

    • This meant we needed to add workarounds to trace types which were not supported on the CPU, for example float16 (see 16-bit float operations when using tracing for more details).

  • Source code locations were not supported: most of the instructions pointed at torch.nn.module.py rather than at user code.

To address these issues, PopTorch now uses the PyTorch dispatcher by default to build its graph directly.

The dispatcher frontend is supported on most PopTorch platforms (see poptorch.hasMLIRSupportOnPlatform()), but if you run into any issues you can revert to tracing by using traceModel(). The only PopTorch platform currently without dispatcher support is CentOS 7.
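
For example, a minimal sketch of selecting the frontend at runtime, falling back to tracing when the dispatcher is not available on the current platform:

import poptorch

opts = poptorch.Options()
if not poptorch.hasMLIRSupportOnPlatform():
    # Fall back to the legacy tracing frontend (for example, on CentOS 7).
    opts.Jit.traceModel(True)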

11.2. Constraints when using tracing

When tracing, PopTorch uses PyTorch’s torch.jit.trace API. This means that the tracing frontend inherits the constraints of that API. These include:

  • Inputs must be PyTorch tensors or tuples containing PyTorch tensors.

  • None can be used as a default value for a parameter but cannot be explicitly passed as an input value (see the sketch below).
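
For illustration, here is a minimal sketch (ScaleModel and its bias argument are made-up names for this example): a parameter defaulting to None can simply be left unset, but None cannot be passed explicitly:

import torch
import poptorch

class ScaleModel(torch.nn.Module):
    def forward(self, x, bias=None):
        # "bias" may default to None, but tracing does not allow passing
        # None explicitly as an input value.
        if bias is None:
            return x * 2
        return x * 2 + bias

opts = poptorch.Options()
opts.Jit.traceModel(True)  # use the legacy tracing frontend
poptorch_model = poptorch.inferenceModel(ScaleModel(), opts)

poptorch_model(torch.ones(4))          # OK: "bias" keeps its default
# poptorch_model(torch.ones(4), None)  # not supported when tracing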

See also Section 1.3, Constraints for general PopTorch constraints.

11.3. 16-bit float operations when using tracing

Note

Handling of float16 operations is greatly simplified with the dispatcher frontend. For help on migrating float16 code from tracing to the dispatcher frontend, see 16-bit float migration.

Due to the limitations of PyTorch's float16 support on the CPU (which is used when tracing the model), certain operations may use float32 where float16 would be expected, or float16 where float32 would be expected. This is because the model must always be traced with its float16 inputs converted to float32.

This limitation is much less noticeable when opts.Precision.halfFloatCasting(poptorch.HalfFloatCastingBehavior.HalfUpcastToFloat) has not been set, because PopTorch's default casting behaviour is to output a float16 if any input of the op is float16. In such situations, any value which incorrectly resolves to float16 would have been cast to float16 in any case.
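
For instance, under this default behaviour, adding a float16 tensor to a float32 tensor yields a float16 result on the IPU, whereas native PyTorch would return float32. A minimal sketch (AddModel is a made-up name for this example):

import torch
import poptorch

class AddModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y

opts = poptorch.Options()
opts.Jit.traceModel(True)  # legacy tracing frontend, default casting behaviour

poptorch_model = poptorch.inferenceModel(AddModel(), opts)
out = poptorch_model(torch.tensor([1.0], dtype=torch.float16),
                     torch.tensor([2.0], dtype=torch.float32))
# Any op with at least one float16 input outputs float16 by default.
assert out.dtype == torch.float16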

11.3.1. Casting

The dtype argument in tensor.to(dtype) will be ignored if it is torch.float32, because it may refer to one or more float16 tensors which were converted to float32 to allow tracing to happen: for example, in a.to(b.dtype), b may be a float16 tensor that was converted to float32. The type is resolved later, once the output of the op, or one of its descendants, encounters a known float16 or float32 input.

The following examples show cases where the cast type is resolved from context, either correctly or incorrectly:

Listing 11.1 Cases where casting resolves to the correct type
class Model(torch.nn.Module):
    def forward(self, x, y):
        # In spite of "y.dtype" being ignored if it is float32, the dtype used
        # for the cast resolves to be the type of y because of the "+ y"
        return x.to(y.dtype) + y


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

assert native_model(float16_tensor, float16_tensor).dtype == torch.float16
assert native_model(float16_tensor, float32_tensor).dtype == torch.float32
assert native_model(float32_tensor, float16_tensor).dtype == torch.float16
assert native_model(float32_tensor, float32_tensor).dtype == torch.float32

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor, float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor, float32_tensor).dtype == torch.float32

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor, float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor, float32_tensor).dtype == torch.float32

Listing 11.2 Cases where casting resolves to an incorrect type
class Model(torch.nn.Module):
    def forward(self, x, y):
        # torch.float32 is ignored and the type is resolved to be the type of y
        return x.to(torch.float32) + y


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

assert native_model(float16_tensor, float16_tensor).dtype == torch.float32
assert native_model(float32_tensor, float16_tensor).dtype == torch.float32

opts = poptorch.Options()
# Important: this is only needed for traceModel(True)
opts.Jit.traceModel(True)
opts.Precision.halfFloatCasting(
    poptorch.HalfFloatCastingBehavior.HalfUpcastToFloat)

# This incorrectly results in a float16 tensor
poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float16_tensor, float16_tensor).dtype == torch.float16

# This incorrectly results in a float16 tensor
poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float32_tensor, float16_tensor).dtype == torch.float16

# UPDATE: with the new default of traceModel(False) PopTorch now matches the native behaviour
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor, float16_tensor).dtype == native_model(
    float16_tensor, float16_tensor).dtype

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor, float16_tensor).dtype == native_model(
    float32_tensor, float16_tensor).dtype

11.3.2. Creation functions

The following functions are affected:

  • torch.ones

  • torch.rand

  • torch.zeros

  • torch.distributions.uniform.Uniform

The dtype arguments to these functions will be ignored because they may refer to float16 tensors which were converted to float32 to allow tracing to succeed. Once the output of the op, or one of its descendants, encounters a known float16 or float32 input, the dtype is resolved to this type.

The following examples show cases where the output type differs from PyTorch:

Listing 11.3 Type resolution when using torch.zeros
## torch.ones and zeros
class Model(torch.nn.Module):
    def forward(self, x):
        # dtype is ignored, however the type is resolved to be the type of x
        return torch.zeros((2, 3, 4), dtype=torch.float32) + x


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

# The native model always yields a float32 tensor
assert native_model(float16_tensor).dtype == torch.float32
assert native_model(float32_tensor).dtype == torch.float32

opts = poptorch.Options()
# Important: this is only needed for traceModel(True)
opts.Jit.traceModel(True)
opts.Precision.halfFloatCasting(
    poptorch.HalfFloatCastingBehavior.HalfUpcastToFloat)

# The poptorch model will resolve to the type of x
poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float32_tensor).dtype == torch.float32

# UPDATE: with the new default of traceModel(False) PopTorch now matches the native behaviour
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor).dtype == native_model(
    float16_tensor).dtype

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor).dtype == native_model(
    float32_tensor).dtype

Listing 11.4 Type resolution when using torch.rand
## torch.rand
class Model(torch.nn.Module):
    def forward(self, x):
        # dtype is ignored, however the type is resolved to be the type of x
        return torch.rand((2, 3, 4), dtype=torch.float32) + x


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

opts = poptorch.Options()
# Important: this is only needed for traceModel(True)
opts.Jit.traceModel(True)
opts.Precision.halfFloatCasting(
    poptorch.HalfFloatCastingBehavior.HalfUpcastToFloat)

# The native model always yields a float32 tensor
assert native_model(float16_tensor).dtype == torch.float32
assert native_model(float32_tensor).dtype == torch.float32

# The poptorch model will resolve to the type of x
poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float32_tensor).dtype == torch.float32

# UPDATE: with the new default of traceModel(False) PopTorch now matches the native behaviour
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor).dtype == native_model(
    float16_tensor).dtype

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor).dtype == native_model(
    float32_tensor).dtype

Listing 11.5 Type resolution when using torch.distributions.uniform.Uniform
## torch.distributions.uniform.Uniform
class Model(torch.nn.Module):
    def forward(self, x):
        # dtype is ignored, however the type is resolved to be the type of x
        ud = torch.distributions.uniform.Uniform(
            torch.tensor([0.0], dtype=torch.float16),
            torch.tensor([1.0], dtype=torch.float32))
        return ud.sample((10, 10, 1000)) + x


native_model = Model()

float16_tensor = torch.tensor([1.0], dtype=torch.float16)
float32_tensor = torch.tensor([1.0], dtype=torch.float32)

# The native model always yields a float32 tensor
assert native_model(float16_tensor).dtype == torch.float32
assert native_model(float32_tensor).dtype == torch.float32

opts = poptorch.Options()
# Important: this is only needed for traceModel(True)
opts.Jit.traceModel(True)
opts.Precision.halfFloatCasting(
    poptorch.HalfFloatCastingBehavior.HalfUpcastToFloat)

# The poptorch model will resolve to the type of x
poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float16_tensor).dtype == torch.float16

poptorch_model = poptorch.inferenceModel(native_model, opts)
assert poptorch_model(float32_tensor).dtype == torch.float32

# UPDATE: with the new default of traceModel(False) PopTorch now matches the native behaviour
poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float16_tensor).dtype == native_model(
    float16_tensor).dtype

poptorch_model = poptorch.inferenceModel(native_model)
assert poptorch_model(float32_tensor).dtype == native_model(
    float32_tensor).dtype

11.3.3. Normalization

Some normalization layers require the computation of running statistics: the mean and variance. These tensors are computed as float32 even when the inputs to the operator are float16. This behaviour was chosen to strike a balance between performance and numerical accuracy.

The following operators are affected:

  • torch.nn.BatchNorm1d

  • torch.nn.BatchNorm2d

  • torch.nn.BatchNorm3d

The type of running statistics computations may be controlled via opts.Precision.runningStatisticsAlwaysFloat(bool). For example, in the script below, mean and variance computations will be performed in half-precision:

Listing 11.6 Controlling type of running mean and variance computations
model = torch.nn.Sequential()
model.add_module('lin', torch.nn.Linear(16, 16))
model.add_module('bn', torch.nn.BatchNorm1d(16))
model.half()  # convert to float16 so the running statistics follow

opts = poptorch.Options()
opts.Precision.runningStatisticsAlwaysFloat(False)
poptorch_model = poptorch.inferenceModel(model, opts)

11.4. Automatic mixed-precision casting

Warning

The autocasting API is only available when using the legacy tracing frontend. When using the dispatcher frontend, which is the default, simply use PyTorch casting.

PopTorch supports automatically converting your model between float16 and float32. This functionality is not active by default; you must enable it explicitly by calling the autocast(enabled=True) method at the model level.

Listing 11.7 Enabling automatic casting at model level
model = MyModel()
model.autocast()
poptorch_model = poptorch.inferenceModel(model)

During compilation, selected layers and operators will have their types adjusted, aiming to strike a good compromise between compute efficiency, memory requirements and numerical precision.

You can also set automatic casting at the layer level. In this situation, its effect is hierarchical: changing the setting for a layer affects it and all layers it contains.

In the following example, automatic casting is enabled for all layers of the model, except for the first activation and second convolution.

Listing 11.8 Controlling automatic casting at layer level
model = torch.nn.Sequential()
model.add_module('conv1', torch.nn.Conv2d(1, 20, 5))
model.add_module('relu1', torch.nn.ReLU())
model.add_module('conv2', torch.nn.Conv2d(20, 64, 5))
model.add_module('relu2', torch.nn.ReLU())
model.autocast()
model.relu1.autocast(False)
model.conv2.autocast(False)

You can also set automatic casting with the function decorator @poptorch.autocast(enabled=True). Its effect is to apply automatic casting to the body of the function. Setting its parameter to False has the opposite effect. A typical use-case is applying it to the forward function of custom modules.

Listing 11.9 Controlling automatic casting via decorator
class MyModel(torch.nn.Module):
    @poptorch.autocast()
    def forward(self, x, y):
        return torch.bmm(x, y)


In addition, you can apply poptorch.autocast(enabled=True) to a code-block, with similar effect.

Listing 11.10 Applying automatic casting to a code-block
x = torch.randn(1, 10, 10)
y = torch.randn(1, 10, 10)
with poptorch.autocast():
    z = torch.bmm(x, y)

You can disable this feature for the whole application via the autocastEnabled(bool) method of _PrecisionOptions.

Listing 11.11 Disabling automatic casting
opts = poptorch.Options()
opts.Precision.autocastEnabled(False)
poptorch_model = poptorch.inferenceModel(model, opts)

11.4.1. Custom casting policies

PopTorch provides a mechanism to customize automatic casting behaviour in the form of casting policy classes. A casting policy is defined by four sets of PyTorch modules and/or PyTorch operators:

  1. fp16 - set of operations to be typed as float16

  2. fp32 - set of operations to be typed as float32

  3. promote - set of operations to be promoted to float32 should they take mixed-precision inputs

  4. demote - set of operations to be demoted to float16 should they take mixed-precision inputs

The following example describes a policy where convolution and ReLU operations are to be performed using float16, whilst batch matrix multiplication is to be performed using float32. Dot product computations will be promoted to float32 when operands have mixed precision.

Listing 11.12 Custom casting policies
fp16 = [torch.nn.Conv2d, torch.relu]
fp32 = [torch.bmm]
promote = [torch.dot]
demote = []
policy = poptorch.autocasting.Policy(fp16, fp32, promote, demote)

opts = poptorch.Options()
opts.Precision.autocastPolicy(policy)
poptorch_model = poptorch.inferenceModel(model, opts)