4. Data types
Currently, PopXL supports the data types listed in Table 4.1. These data types are defined in `popxl` directly and will be converted to their IPU-compatible data type. Note that if the session option `popart.SessionOptions.enableSupportedDataTypeCasting` is set to `True`, then `int64` and `uint64` will be downcast to `int32` and `uint32`, respectively.
Table 4.1: Data types in PopXL.

| PopXL dtype | int | floating point | signed | NumPy dtype | Python dtype | alias |
|---|---|---|---|---|---|---|
| `bool` | False | False | False | bool | builtins.bool | N/A |
| `int8` | True | False | True | int8 | None | N/A |
| `int32` | True | False | True | int32 | None | N/A |
| `uint8` | True | False | False | uint8 | None | N/A |
| `uint32` | True | False | False | uint32 | None | N/A |
| `float16` | False | True | True | float16 | None | half |
| `float32` | False | True | True | float32 | builtins.float | float |
| `float64` | False | True | True | float64 | None | double |
| `float8_143` | False | True | True | uint8 | None | N/A |
| `float8_152` | False | True | True | uint8 | None | N/A |
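As a brief illustration of how these dtypes surface in practice, the sketch below creates tensors from NumPy arrays and checks the inferred `popxl` dtype (a minimal sketch; the tensor names are illustrative):

```python
import numpy as np
import popxl

ir = popxl.Ir()
with ir.main_graph:
    # The popxl dtype is inferred from the NumPy dtype of the host data.
    w = popxl.variable(np.zeros((2, 2), dtype=np.float32), name="w")
    c = popxl.constant(np.arange(4, dtype=np.int32), name="c")
    assert w.dtype == popxl.float32
    assert c.dtype == popxl.int32
```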
4.1. 8-bit floating point datatypes
There are two 8-bit float datatypes in PopXL, namely `popxl.float8_143` and `popxl.float8_152`. The numbers in the names of these types refer to the format: the number of bits used to represent the sign, exponent and mantissa. As with other floating point representations, the exponent is subject to a bias. This bias is different for each of the two formats:
Table 4.2: 8-bit floating point formats in PopXL.

| PopXL dtype | Number of sign bits | Number of exponent bits | Number of mantissa bits | Exponent bias | Smallest positive value | Largest positive value |
|---|---|---|---|---|---|---|
| `float8_143` | 1 | 4 | 3 | -8 | \(2^{-10}\) | \(240.0\) |
| `float8_152` | 1 | 5 | 2 | -16 | \(2^{-17}\) | \(57344.0\) |
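As a sanity check on these numbers, the `float8_143` extremes can be derived from the table entries, assuming an IEEE-754-style encoding with subnormal values. The largest positive value uses the maximum exponent field \(E = 15\) with an all-ones mantissa \(1 + 7/8 = 1.875\), while the smallest positive value is a subnormal with only the lowest mantissa bit set:

\[
1.875 \times 2^{15 - 8} = 1.875 \times 128 = 240.0,
\qquad
2^{1 - 8} \times 2^{-3} = 2^{-10}.
\]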
More details about the numerical properties of these two 8-bit floating point data types can be found in the arXiv paper 8-Bit Numerical Formats for Deep Neural Networks.
Because of the limited numerical range of 8-bit floating point numbers, operations that consume or produce tensors of these types are usually fused with a `pow2` scaling operation. These operations have a `log2_scale` parameter. Internally, these operations multiply your 8-bit floating point data by a factor of `pow2(log2_scale)`. Note that you can use a positive `log2_scale` to accommodate large numerical ranges, or you can use negative values for smaller numerical ranges. Currently, we support `log2_scale` parameter values in the range \([-32, 32)\).
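For example, the host-side helpers apply this scaling as part of the conversion. The sketch below assumes the positional signature used by the `host_pow2scale_then_cast(...)` call in Section 4.2, and that `host_cast_then_pow2scale(...)` mirrors it:

```python
import numpy as np
import popxl
from popxl.utils import host_cast_then_pow2scale, host_pow2scale_then_cast

# float32 data whose magnitude exceeds the float8_143 maximum of 240.0.
data = np.array([1000.0, -500.0, 0.25], dtype=np.float32)

# Multiply by pow2(-3) = 1/8 before the cast, so 1000.0 -> 125.0 is in range.
# The final argument is assumed to control NaN-on-overflow behaviour.
data_fp8 = host_pow2scale_then_cast(data, popxl.float8_143, -3, False)

# Convert back: cast to float32, then multiply by pow2(3) to undo the scaling.
data_f32 = host_cast_then_pow2scale(data_fp8, popxl.float32, 3)
```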
Table 4.3 lists a number of utility functions and operations for 8-bit floats.
Table 4.3: Utility functions and operations for 8-bit floats.

| API function | Description |
|---|---|
| `host_pow2scale_then_cast()` | Host-based conversion from 16/32/64-bit floating point data to an 8-bit floating point representation. |
| `host_cast_then_pow2scale()` | Host-based conversion from an 8-bit floating point representation back to 16/32/64-bit floating point data. |
| `pow2scale_then_cast()` | Operation to convert from 16-bit floating point to 8-bit floating point. |
| `cast_then_pow2scale()` | Operation to convert from 8-bit floating point to 16-bit floating point. |
| `matmul_pow2scaled()` | Operation to perform a matmul on 8-bit floating point data resulting in 16-bit floating point output. |
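To make the device-side conversion pair concrete, here is a minimal round-trip sketch inside a graph context. It uses the `data_type`/`log2_scale` keyword style from the example in Section 4.2 and assumes `cast_then_pow2scale()` mirrors it:

```python
import numpy as np
import popxl
import popxl.ops as ops

ir = popxl.Ir()
with ir.main_graph:
    x = popxl.variable(np.full((4,), 300.0, dtype=np.float16), name="x")
    # fp16 -> fp8: multiply by pow2(-1) = 0.5 first, so 300.0 -> 150.0
    # fits within the float8_143 range (largest value 240.0).
    x_fp8 = ops.pow2scale_then_cast(
        x, data_type=popxl.float8_143, log2_scale=popxl.constant(-1)
    )
    # fp8 -> fp16: cast back, then multiply by pow2(1) to undo the scaling.
    x_fp16 = ops.cast_then_pow2scale(
        x_fp8, data_type=popxl.float16, log2_scale=popxl.constant(1)
    )
```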
Note that for device-based operations that support 8-bit float operands, the `log2_scale` operand is also a tensor parameter in its own right. This means you can change this scaling at runtime if you so desire.
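One way to exploit this, sketched below with standard PopXL host-to-device streams (the names are illustrative), is to stream a fresh `log2_scale` value from the host on each run:

```python
with ir.main_graph:
    # Stream log2_scale from the host so each session.run can supply a new value.
    log2_scale_stream = popxl.h2d_stream((), popxl.int32, name="log2_scale_stream")
    log2_scale_t = ops.host_load(log2_scale_stream, "log2_scale")
    # log2_scale_t can now be passed to ops such as conv_pow2scaled(); a
    # different value can then be supplied on every run, for example:
    # session.run({log2_scale_stream: np.array(-2, dtype=np.int32)})
```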
4.2. 8-bit floating point inference model example
An example of using float8 tensors in an inference graph is shown in the example `float8_inference.py`. The float16 input data is loaded onto the device as-is, then cast to float8 on the device with a `pow2scale_then_cast()` operator. After this, we cast the trained weight data on the host (in this example the weights are randomly generated), then create the `popxl.variable()` for the float8 weights. Note that in both cases we do not scale the values, as this is done within the `conv_pow2scaled()` operator.
```python
# Cast to fp8 on device before the conv layer.
# Note we do not scale here, as scaling is done within the conv op.
a_fp8 = ops.pow2scale_then_cast(
    a, data_type=popxl.float8_143, log2_scale=popxl.constant(0)
)

conv_ = ConvFloat8(opts_)
# Convert the weight data on the host.
# Note we do not scale here, as scaling is done within the conv op.
weight_fp8 = host_pow2scale_then_cast(weight, popxl.float8_143, 0, False)

W_t = popxl.variable(weight_fp8, popxl.float8_143)
conv_graph_0 = ir.create_graph(conv_, a_fp8, log2_scale_t)
```
In the PopXL `Module` you can see the `conv_pow2scaled()` operator, which takes a `log2_scale` tensor in addition to our float8 input and weight tensors, as well as all of the usual parameters used in a `conv()` operator.
```python
class ConvFloat8(popxl.Module):
    """
    Define a float8 convolution layer in PopXL.
    """

    def __init__(self, opts_: argparse.Namespace) -> None:
        self.in_channel = opts_.in_channel
        self.out_channel = opts_.out_channel
        self.h_kernel = opts_.h_kernel
        self.w_kernel = opts_.w_kernel
        self.strides = opts_.strides
        self.group = opts_.group

        self.W: popxl.Tensor = None

    def build(self, x: popxl.Tensor, log2_scale: popxl.Tensor) -> popxl.Tensor:
        """
        Override the `build` method to build a graph.

        Note:
            x is a popxl.float8_143 tensor, and log2_scale is a popxl.int32
            tensor in the range [-32, 32).
        """
        self.W = popxl.graph_input(
            (
                self.out_channel,
                self.in_channel // self.group,
                self.h_kernel,
                self.w_kernel,
            ),
            popxl.float8_143,
            "W",
        )

        # Note this is a pow2scaled convolution that needs a log2_scale tensor.
        y = ops.conv_pow2scaled(x, self.W, log2_scale, stride=self.strides)

        y = ops.gelu(y)
        return y
```
See `conv_pow2scaled()` for more details on this operator.
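For completeness, a graph created this way would typically be invoked with `ops.call()`, binding the graph input `W` to the float8 variable created earlier (a sketch of the usual PopXL calling convention, not part of the quoted example):

```python
# Call the conv graph, binding the graph input W to the float8 variable W_t.
(y,) = ops.call(conv_graph_0, a_fp8, log2_scale_t, inputs_dict={conv_.W: W_t})
```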