5. Data types

Currently, PopXL supports the data types listed in Table 5.1. These data types are defined directly in popxl and are converted to the corresponding IPU-compatible data types. Note that if the session option popart.SessionOptions.enableSupportedDataTypeCasting is set to True, then int64 and uint64 are downcast to int32 and uint32, respectively.

Table 5.1 Data types in PopXL

PopXL dtype   int    floating point   signed   NumPy dtype   Python dtype    alias
-----------   -----  --------------   ------   -----------   ------------    ------
bool          False  False            False    bool          builtins.bool   N/A
int8          True   False            True     int8          None            N/A
int32         True   False            True     int32         None            N/A
uint8         True   False            False    uint8         None            N/A
uint32        True   False            False    uint32        None            N/A
float16       False  True             True     float16       None            half
float32       False  True             True     float32       builtins.float  float
float64       False  True             True     float64       None            double
float8_143    False  True             True     uint8         None            N/A
float8_152    False  True             True     uint8         None            N/A
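When preparing host data, it can help to create NumPy arrays whose dtype already matches the NumPy dtype column of Table 5.1, and to check that int64 data actually fits in int32 before relying on enableSupportedDataTypeCasting to downcast it. The following is a minimal pure-NumPy sketch; the dictionary NUMPY_DTYPE and the function fits_in_int32 are hypothetical helpers for illustration, not part of the PopXL API.

```python
import numpy as np

# Hypothetical lookup restating the "NumPy dtype" column of Table 5.1.
# Both float8 types are held on the host as raw uint8 bit patterns.
NUMPY_DTYPE = {
    "bool": np.bool_,
    "int8": np.int8,
    "int32": np.int32,
    "uint8": np.uint8,
    "uint32": np.uint32,
    "float16": np.float16,
    "float32": np.float32,
    "float64": np.float64,
    "float8_143": np.uint8,
    "float8_152": np.uint8,
}


def fits_in_int32(data: np.ndarray) -> bool:
    """Return True if every value of an integer array is representable as int32."""
    info = np.iinfo(np.int32)
    return bool(data.size == 0 or (data.min() >= info.min and data.max() <= info.max))


ids = np.array([0, 7, 2**31 - 1], dtype=np.int64)
print(fits_in_int32(ids))  # True: downcasting loses nothing
big = np.array([2**31], dtype=np.int64)
print(fits_in_int32(big))  # False: 2**31 is outside the int32 range

# Convert explicitly on the host once the range check has passed.
host_array = ids.astype(NUMPY_DTYPE["int32"])
print(host_array.dtype)  # int32
```

Converting explicitly on the host makes any out-of-range values visible before the data reaches the session, rather than depending on how the downcast handles them.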

5.1. 8-bit floating point datatypes

There are two 8-bit floating point data types in PopXL: popxl.float8_143 and popxl.float8_152. The digits in each name give the number of bits used for the sign, exponent and mantissa, respectively; for example, float8_143 has 1 sign bit, 4 exponent bits and 3 mantissa bits. As with other floating point representations, the exponent is subject to a bias. This bias is different for each of the two formats:

Table 5.2 Float8 formats in PopXL

PopXL dtype   Sign bits   Exponent bits   Mantissa bits   Exponent bias   Smallest positive value   Largest positive value
-----------   ---------   -------------   -------------   -------------   -----------------------   ----------------------
float8_143    1           4               3               -8              \(2^{-10}\)               \(240.0\)
float8_152    1           5               2               -16             \(2^{-17}\)               \(57344.0\)
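To see where the smallest and largest values in Table 5.2 come from, the following pure-Python sketch decodes every 8-bit pattern for both formats. The function decode_fp8 is hypothetical, and its encoding rules are assumptions for illustration: the bias from Table 5.2 is added to the stored exponent field, an exponent field of 0 encodes subnormal values, there are no infinities, and the negative-zero pattern 0x80 is reserved for NaN. Under these assumptions the decoded extremes reproduce the table.

```python
import math

# (exponent bits, mantissa bits, exponent bias) for each format in Table 5.2.
FORMATS = {"float8_143": (4, 3, -8), "float8_152": (5, 2, -16)}


def decode_fp8(bits: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Decode one 8-bit pattern laid out as sign | exponent | mantissa.

    Assumed conventions: the bias is added to the stored exponent, a zero
    exponent field encodes subnormals, there are no infinities, and the
    negative-zero pattern 0x80 encodes NaN.
    """
    if bits == 0x80:
        return math.nan
    sign = -1.0 if bits & 0x80 else 1.0
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    fraction = (bits & ((1 << man_bits) - 1)) / (1 << man_bits)
    if exponent == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 + bias.
        return sign * math.ldexp(fraction, 1 + bias)
    # Normal: implicit leading 1.
    return sign * math.ldexp(1.0 + fraction, exponent + bias)


for name, (exp_bits, man_bits, bias) in FORMATS.items():
    values = [decode_fp8(b, exp_bits, man_bits, bias) for b in range(256)]
    positive = [v for v in values if v > 0]  # NaN compares False, so it is dropped
    print(name, min(positive), max(positive))
# float8_143 0.0009765625 240.0
# float8_152 7.62939453125e-06 57344.0
```

The trade-off between the formats is visible here: float8_152 spends one more bit on the exponent, which widens the range, at the cost of one mantissa bit of precision.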

More details about the numerical properties of these two 8-bit floating point data types can be found in the arXiv paper "8-Bit Numerical Formats for Deep Neural Networks".

Because of the limited numerical range of 8-bit floating point numbers, operations that consume or produce tensors of these types are usually fused with a pow2 scaling operation. These operations have a log2_scale parameter. Internally, they multiply your 8-bit floating point data by a factor of pow2(log2_scale). You can therefore use a positive log2_scale to accommodate larger numerical ranges, or a negative log2_scale for smaller ones. Currently, log2_scale values in the range \([-32,32)\) are supported, that is, integers from -32 to 31.
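As a rough illustration of choosing log2_scale, the following NumPy sketch follows the convention described above, where the represented value is the 8-bit value multiplied by pow2(log2_scale). The helper choose_log2_scale is hypothetical, not part of PopXL: it picks the smallest log2_scale in the supported range for which the data, divided by pow2(log2_scale), fits within 240.0, the largest float8_143 value from Table 5.2. The exact sign convention each PopXL operation expects should be checked in its API reference before passing it a value computed this way.

```python
import math

import numpy as np

FLOAT8_143_MAX = 240.0  # largest positive float8_143 value (Table 5.2)
LOG2_SCALE_MIN, LOG2_SCALE_MAX = -32, 31  # supported range is [-32, 32)


def choose_log2_scale(x: np.ndarray, fp8_max: float = FLOAT8_143_MAX) -> int:
    """Hypothetical helper: smallest integer s in [-32, 31] with max|x| / 2**s <= fp8_max.

    Follows the convention represented_value = fp8_value * 2**s, so large
    data gives a positive s and small data gives a negative s.
    """
    amax = float(np.max(np.abs(x)))
    if amax == 0.0:
        return 0  # all zeros: any scale works
    s = math.ceil(math.log2(amax / fp8_max))
    if math.ldexp(amax, -s) > fp8_max:
        s += 1  # guard against rounding in log2
    # Clamp to the supported range; outside it the data may still overflow.
    return min(max(s, LOG2_SCALE_MIN), LOG2_SCALE_MAX)


large = np.array([0.5, -3.0, 1000.0], dtype=np.float32)
s_large = choose_log2_scale(large)
print(s_large)  # 3: 1000 / 2**3 = 125.0 fits, 1000 / 2**2 = 250.0 does not
print(np.max(np.abs(np.ldexp(large, -s_large))))  # 125.0

small = np.array([0.002, -0.001], dtype=np.float32)
print(choose_log2_scale(small))  # -16
```

Choosing the smallest scale that still fits uses as much of the 8-bit range as possible, which keeps small values out of the coarse subnormal region.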

Table 5.3 lists a number of utility functions and operations for 8-bit floats.

Table 5.3 8-Bit floating point API

API function                       Description
--------------------------------   -----------
host_pow2scale_cast_to_fp8()       Host-based conversion from 16/32/64-bit floating point data to an 8-bit floating point representation.
host_pow2scale_cast_from_fp8()     Host-based conversion from an 8-bit floating point representation back to 16/32/64-bit floating point data.
pow2scale_cast_to_fp8()            Operation to convert from 16-bit floating point to 8-bit floating point.
pow2scale_cast_from_fp8()          Operation to convert from 8-bit floating point to 16-bit floating point.
matmul_pow2scaled()                Operation to perform a matmul on 8-bit floating point data, producing 16-bit floating point output.
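To get a feel for what a pow2-scaled cast to float8 and back does numerically, the sketch below emulates a round trip in NumPy by rounding each scaled value to the nearest representable float8_143 value. It is not PopXL's implementation, and every name in it (float8_143_grid, emulate_cast_to_fp8, emulate_cast_from_fp8) is hypothetical. It assumes the encoding implied by Table 5.2 (bias added to the stored exponent, subnormals at exponent field 0, no infinities) and the convention from the pow2 scaling paragraph (represented value = 8-bit value × pow2(log2_scale)); it also saturates out-of-range values to ±240.0 and rounds ties toward zero, both of which are simplifications. For real data, use host_pow2scale_cast_to_fp8() and host_pow2scale_cast_from_fp8().

```python
import numpy as np


def float8_143_grid() -> np.ndarray:
    """All non-negative finite float8_143 values in ascending order.

    Assumed encoding: 4-bit exponent field with bias -8, 3-bit mantissa,
    subnormals when the exponent field is 0, and no infinities.
    """
    values = []
    for e in range(16):  # exponent field
        for m in range(8):  # mantissa field
            frac = m / 8.0
            if e == 0:
                values.append(np.ldexp(frac, 1 - 8))  # subnormal
            else:
                values.append(np.ldexp(1.0 + frac, e - 8))  # normal
    return np.array(values)


GRID = float8_143_grid()  # 128 values from 0.0 up to 240.0


def emulate_cast_to_fp8(x, log2_scale: int) -> np.ndarray:
    """Hypothetical: nearest float8_143 value to x / 2**log2_scale (as float64).

    Out-of-range magnitudes saturate to 240.0 and ties round toward zero;
    both are simplifications, not PopXL behaviour.
    """
    scaled = np.asarray(x, dtype=np.float64) / 2.0**log2_scale
    mag = np.minimum(np.abs(scaled), GRID[-1])
    idx = np.clip(np.searchsorted(GRID, mag), 1, len(GRID) - 1)
    lower, upper = GRID[idx - 1], GRID[idx]
    nearest = np.where(mag - lower <= upper - mag, lower, upper)
    return np.copysign(nearest, scaled)


def emulate_cast_from_fp8(q, log2_scale: int) -> np.ndarray:
    """Hypothetical: scale float8_143 values back up by 2**log2_scale."""
    return np.asarray(q, dtype=np.float64) * 2.0**log2_scale


x = np.array([0.1, 1.0, 3.3, 1000.0, 5000.0])
q = emulate_cast_to_fp8(x, log2_scale=3)
print(emulate_cast_from_fp8(q, log2_scale=3).tolist())
# [0.1015625, 1.0, 3.25, 1024.0, 1920.0]
```

The output shows the two sources of error in 8-bit data: rounding (0.1 becomes 0.1015625, 3.3 becomes 3.25) and overflow (5000.0 exceeds 240.0 × pow2(3) = 1920.0), which is why the choice of log2_scale matters.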

For device-based operations that support 8-bit float operands, the log2_scale operand is itself a tensor. This means you can change the scaling at runtime if you need to.

5.2. 8-bit floating point inference model example

An example of using float8 tensors in an inference graph is shown in float8_inference.py. The float16 input data is loaded onto the device as-is and then cast to float8 on the device with the pow2scale_cast_to_fp8() operator. Next, the trained weight data is cast to float8 on the host with host_pow2scale_cast_to_fp8() (in this example the weights are randomly generated), and a popxl.variable() is created from the float8 weights.

Note that in both cases we do not scale the values, as this is done within the conv_pow2scaled() operator.

Listing 5.1 Example of host-based casting to float8
# Excerpt from float8_inference.py: a, weight, opts_, ir and log2_scale_t
# are defined earlier in that file.
import popxl
import popxl.ops as ops
from popxl.utils import host_pow2scale_cast_to_fp8

# Cast to fp8 on device before the conv layer.
# Note: we do not scale here (log2_scale is 0), as scaling is done within the conv op.
a_fp8 = ops.pow2scale_cast_to_fp8(
    a, data_type=popxl.float8_143, log2_scale=popxl.constant(0)
)

conv_ = ConvFloat8(opts_)
# Convert the weight data on the host.
# Note: we do not scale here (log2_scale is 0), as scaling is done within the conv op.
weight_fp8 = host_pow2scale_cast_to_fp8(weight, popxl.float8_143, 0, False)

W_t = popxl.variable(weight_fp8, popxl.float8_143)
conv_graph_0 = ir.create_graph(conv_, a_fp8, log2_scale_t)


In the ConvFloat8 module (a popxl.Module subclass), the conv_pow2scaled() operator takes a log2_scale tensor in addition to the float8 input and weight tensors, along with the usual parameters of the conv() operator.

Listing 5.2 Example of using float8 tensors
import argparse

import popxl
import popxl.ops as ops


class ConvFloat8(popxl.Module):
    """
    Define a float8 convolution layer in PopXL.
    """

    def __init__(self, opts_: argparse.Namespace) -> None:

        self.in_channel = opts_.in_channel
        self.out_channel = opts_.out_channel
        self.h_kernel = opts_.h_kernel
        self.w_kernel = opts_.w_kernel
        self.strides = opts_.strides
        self.group = opts_.group

        self.W: popxl.Tensor = None

    def build(self, x: popxl.Tensor, log2_scale: popxl.Tensor) -> popxl.Tensor:
        """
        Override the `build` method to build a graph.
        Note:
            x is a popxl.float8_143 tensor, and log2_scale is a popxl.int32 tensor
            in the range [-32, 32).
        """
        # Weight layout: (out channels, in channels / groups, kernel height, kernel width).
        self.W = popxl.graph_input(
            (
                self.out_channel,
                self.in_channel // self.group,
                self.h_kernel,
                self.w_kernel,
            ),
            popxl.float8_143,
            "W",
        )

        # This is a pow2-scaled convolution, so it needs a log2_scale tensor.
        # Pass the group count so that it matches the weight shape above.
        y = ops.conv_pow2scaled(
            x, self.W, log2_scale, stride=self.strides, groups=self.group
        )

        y = ops.gelu(y)
        return y


See conv_pow2scaled() for more details on this operator.