Copyright (c) 2022 Graphcore Ltd. All rights reserved.

# Phased Execution

Phased execution is an execution strategy that can be applied when the whole model is too big to fit in memory.

When you design a phased execution strategy, you partition the graph execution into **phases**. The variables and activations that each phase requires are stored remotely, in the streaming memory.
When a phase needs to be executed, its variables and the activations it needs are loaded. This way, tensors only need to be alive while the phase executes.

<figure>
<img src="images/phase_diagram.jpg" style="width:400px;"/>
<figcaption> <b>Fig 1: </b> Execution diagram for a phase, consisting of <code>load</code> -> <code>compute</code> -> <code>store</code> steps. Load and store operations can target either the streaming memory (remote load/store) or the host.
</figcaption>
</figure>

<figure>
<img src="images/phased.jpg" style="width:700px;"/>
<figcaption> <b>Fig 2: </b> On the right, a non-phased model: all variables and activations are alive for the whole duration of the program. On the left, an example of phased execution with three phases, corresponding to the three layers of the model. Variables and activations for each phase are stored remotely and only loaded when the layer needs to execute. The lower section of the image shows the difference in memory occupation (just an example) due to the reduced liveness of variables and activations. The backward pass is not shown.
</figcaption>
</figure>

## Batch serialisation

The building block of phased execution is the phase graph, made up of its `load` -> `compute` -> `store` steps.
Phased execution is typically combined with gradient accumulation, so the phase graph needs to be repeated in the gradient accumulation loop.
A naive implementation of the gradient accumulation loop would look like this:

```python
for _ in range(gradient_accumulation_steps):
    ...
    # ---- Layer A ----

    a_vs = load()  # load layer specific variables
    a_xs = load()  # load inputs
    a_ys = a.bind(a_vs).call(a_xs)
    store(a_ys)  # store activations

    # ---- Layer B ----

    b_vs = load()  # load layer specific variables
    b_xs = load()  # load inputs, that is, the previous layer activations a_ys
    b_ys = b.bind(b_vs).call(b_xs)
    store(b_ys)  # store activations
    ...
optimiser()  # update the variables once, after the loop
```

However, variables are updated only after the gradient accumulation loop. Hence, they can be loaded just once, before the loop, which reduces the number of load operations and the communication cost.
A better loop for phased execution is:

```python
...
# ---- Layer A ----

a_vs = load()  # load layer specific variables, once for all steps
for step in range(gradient_accumulation_steps):
    a_xs = load()  # load inputs
    a_ys = a.bind(a_vs).call(a_xs)
    store(a_ys)  # store activations. Note that we now need to store activations for each step!

# ---- Layer B ----

b_vs = load()  # load layer specific variables, once for all steps
for step in range(gradient_accumulation_steps):
    b_xs = load()  # load inputs: the previous layer activations for the same GA step, a_ys[step]
    b_ys = b.bind(b_vs).call(b_xs)
    store(b_ys)  # store activations
...
optimiser()  # update the variables once, after the loop
```
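To make the difference between the two loops concrete, here is a plain-Python simulation. It does not use popxl: `load`, `compute` and `store` are only recorded as events, and the layer names are hypothetical. It counts how many variable loads each loop performs and measures the peak number of layer A activations waiting in remote memory for layer B, which is why the second loop has to store activations for every step.

```python
from typing import List, Optional, Tuple

# (operation, layer, step); step is None for variable loads
Event = Tuple[str, str, Optional[int]]


def naive_schedule(layers: List[str], steps: int) -> List[Event]:
    """Variables are reloaded inside the gradient accumulation loop."""
    events: List[Event] = []
    for step in range(steps):
        for layer in layers:
            events.append(("load_vars", layer, None))
            events.append(("load_acts", layer, step))
            events.append(("compute", layer, step))
            events.append(("store_acts", layer, step))
    return events


def batch_serialised_schedule(layers: List[str], steps: int) -> List[Event]:
    """Variables are loaded once per phase, then the step loop runs inside the phase."""
    events: List[Event] = []
    for layer in layers:
        events.append(("load_vars", layer, None))
        for step in range(steps):
            events.append(("load_acts", layer, step))
            events.append(("compute", layer, step))
            events.append(("store_acts", layer, step))
    return events


def count(events: List[Event], op: str) -> int:
    return sum(1 for e in events if e[0] == op)


def max_pending(events: List[Event], producer: str, consumer: str) -> int:
    """Peak number of producer activations stored but not yet loaded by the consumer."""
    pending = peak = 0
    for op, layer, _ in events:
        if op == "store_acts" and layer == producer:
            pending += 1
        elif op == "load_acts" and layer == consumer:
            pending -= 1
        peak = max(peak, pending)
    return peak


layers = ["A", "B", "C"]
naive = naive_schedule(layers, steps=4)
serial = batch_serialised_schedule(layers, steps=4)
print(count(naive, "load_vars"), count(serial, "load_vars"))  # 12 3
print(max_pending(naive, "A", "B"), max_pending(serial, "A", "B"))  # 1 4
```

Both loops run the same 12 compute steps; the batch serialised one trades fewer variable loads for keeping one activation entry per step.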
Batch serialisation is a transform that takes a graph and builds a `repeat` loop with this structure:
```
load variables for the phase
repeat(load activations - compute - store activations)
... additional things (for example, optimizer)
```
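The idea of a transform that turns a single-step graph into a looped phase can be sketched in plain Python. In this toy version, `toy_batch_serialise` and the dictionary standing in for remote memory are hypothetical and are not the `popxl.addons` API: the function takes a one-step compute function and returns a phase that loads its variable once and then repeats load, compute and store for every step.

```python
from typing import Callable, Dict, Tuple

# Toy "remote memory": (buffer name, step) -> activation value
RemoteMemory = Dict[Tuple[str, int], float]
Phase = Callable[[RemoteMemory, Dict[str, float]], None]


def toy_batch_serialise(
    compute: Callable[[float, float], float],
    var_name: str,
    in_buffer: str,
    out_buffer: str,
    steps: int,
) -> Phase:
    """Wrap a single-step compute function into a phase with the structure
    load variables -> repeat(load activations - compute - store activations)."""

    def phase(memory: RemoteMemory, variables: Dict[str, float]) -> None:
        w = variables[var_name]  # load variables for the phase, once
        for step in range(steps):  # repeat over gradient accumulation steps
            x = memory[(in_buffer, step)]  # load activations for this step
            y = compute(w, x)  # compute
            memory[(out_buffer, step)] = y  # store activations for this step

    return phase


steps = 4
memory: RemoteMemory = {("input", s): float(s + 1) for s in range(steps)}
variables = {"a.w": 2.0, "b.w": 1.0}

layer_a = toy_batch_serialise(lambda w, x: w * x, "a.w", "input", "a_out", steps)
layer_b = toy_batch_serialise(lambda w, x: w + x, "b.w", "a_out", "b_out", steps)
layer_a(memory, variables)  # phase 1: every step of layer A
layer_b(memory, variables)  # phase 2: every step of layer B
print([memory[("b_out", s)] for s in range(steps)])  # [3.0, 5.0, 7.0, 9.0]
```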

Since phased execution makes use of remote buffers, you should use replicated tensor sharding (RTS) to amortise the communication overhead. In that case, you also need to include the appropriate collective operations in the picture.

The images below illustrate batch serialised phased execution.
<figure>
<img src="images/batch_serialisation.jpg" style="width:700px;"/>
<figcaption> <b>Fig 3: </b> A batch serialised forward phase. On the right, RTS is included.
</figcaption>
</figure>

<figure>
<img src="images/batch_serialisation_grad.jpg" style="width:700px;"/>
<figcaption> <b>Fig 4: </b> A batch serialised backward phase. On the right, RTS is included.
</figcaption>
</figure>


### Batch serialisation in popxl.addons
As the code snippets above show, batch serialisation requires storing activations for all steps.
This is reflected in the structure of the remote buffers. In the context of `popxl.addons` batch serialisation, it's useful to think of remote buffers as matrices: the rows correspond to different phases that can share the same buffer (which requires them to have the same graph), and the `steps` columns correspond to the batch indices. Underneath this logical layout, each (row, column) element has its own entry in the underlying `popxl` remote buffer, which therefore has `rows * steps` entries.

<figure>
<img src="images/remote_buffer.jpg" style="width:700px;"/>
<figcaption> <b>Fig 5: </b> Remote buffer logical layout. Rows correspond to different phases that can share the same buffer (for example, when the phases are identical layers). Columns are the different batches, one for each gradient accumulation step. The entry index in the buffer is given by the <code> flat_index </code>.
</figcaption>
</figure>

You can create this kind of buffer with `batch_serial_buffer(t: popxl.Tensor, steps: int, rows: int = 1)`.
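The logical layout can be mimicked with a NumPy array to see how many entries such a buffer holds and how a (row, step) pair maps to a flat entry. `ToyBatchSerialBuffer` is a hypothetical stand-in, not a popxl class, and the row-major formula `row * steps + step` for the flat index is an assumption of this sketch; the library may order entries differently.

```python
from typing import Tuple

import numpy as np


class ToyBatchSerialBuffer:
    """NumPy stand-in for a batch serial remote buffer (illustration only, not popxl).

    Logical layout: rows x steps. Physical layout: rows * steps flat entries,
    each with the shape of the tensor t. The row-major flat index
    row * steps + step is an assumption of this sketch.
    """

    def __init__(self, shape: Tuple[int, ...], dtype, steps: int, rows: int = 1):
        self.steps = steps
        self.rows = rows
        self.entries = np.zeros((rows * steps, *shape), dtype=dtype)

    def flat_index(self, row: int, step: int) -> int:
        if not (0 <= row < self.rows and 0 <= step < self.steps):
            raise IndexError(f"(row={row}, step={step}) is outside {self.rows}x{self.steps}")
        return row * self.steps + step

    def store(self, row: int, step: int, value: np.ndarray) -> None:
        self.entries[self.flat_index(row, step)] = value

    def load(self, row: int, step: int) -> np.ndarray:
        return self.entries[self.flat_index(row, step)]


# 3 rows (phases) x 4 steps, each entry shaped like a (micro_batch, features) activation
buf = ToyBatchSerialBuffer((8, 512), np.float32, steps=4, rows=3)
buf.store(row=2, step=1, value=np.full((8, 512), 7.0, dtype=np.float32))
print(buf.entries.shape)  # (12, 8, 512)
print(buf.flat_index(2, 1))  # 9
```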

To build a batch serialised graph in `popxl.addons`, you can use `batch_serialise(...)` and `batch_serialise_fwd_and_grad(...)`:

```python
batch_serialise(
    graph: GraphWithNamedArgs,
    steps: int,
    load_handles: Dict[popxl.Tensor, Union[popxl.HostToDeviceStream, RemoteBufferAndOffset]],
    store_streams: Dict[popxl.Tensor, popxl.DeviceToHostStream],
    store_buffers: Dict[popxl.Tensor, RemoteBufferAndOffset],
    seed_input: Optional[popxl.Tensor] = None,
    rows: int = 1,
    io_mode: Literal['compute', 'io', 'io_overlapped'] = 'io'
) -> BatchSerialResult
```
You need to provide the `graph` that you want to repeat `steps` times.

You also need to specify how the inputs of the graph are loaded, by providing a dictionary that maps each input to either a host-to-device stream (if the input is loaded from the host) or a `RemoteBufferAndOffset`, which is just a `tuple` of a buffer and a row offset used to access it.
For example, if you created a buffer with 
```python
x_buffer = batch_serial_buffer(first_layer_output, steps=opts.gradient_accumulation, rows=3)
```
you can then use it as a load handle for the next layers:
```python
layer2_bs = batch_serialise(layer2, steps, {layer2.graph.inputs[0]: (x_buffer, 0)}, ...)
layer3_bs = batch_serialise(layer3, steps, {layer3.graph.inputs[0]: (x_buffer, 1)}, ...)
layer4_bs = batch_serialise(layer4, steps, {layer4.graph.inputs[0]: (x_buffer, 2)}, ...)
```
`(x_buffer, 0)` is a `RemoteBufferAndOffset`, specifying that the `layer2` input should be loaded from the first row of `x_buffer`. Likewise, `layer3.graph.inputs[0]: (x_buffer, 1)` specifies that the `layer3` input should be loaded from row 1, that is, the second row.
As you can see, you don't have to worry about the batch dimension: you only need to think about the row you want to access. Internally, the graph accesses the correct column at each step.
If you don't provide a handle for an input, it remains an input of the batch serialised graph.

After specifying the `load_handles`, you can provide `store_streams` and `store_buffers` for the layer outputs. They are kept separate because sometimes you may want to use both: stream an output to the host and also store it in a buffer.
Outputs that are not specified in `store_streams` or `store_buffers` are not outputs of the batch serialised graph.

If your layer requires a different seed each time it's executed (for example, if you are using dropout), you should provide that input as the `seed_input` parameter. That way, a new random seed is generated for each step and fed to the layer graph.

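The `rows` parameter specifies how many rows of the remote buffers the batch serialised graph can access. Use it when several phases share the same graph and buffers, one row per phase.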

When you call a batch serialised graph, its first input is the row offset used to access the remote buffers, as in:
```python
bs_result = batch_serialise(graph, ..., rows=2)
bs_result.graph.call(0)  # access the first row in the buffers
bs_result.graph.call(1)  # access the second row in the buffers
```
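Judging by how `layer_batch_serialise` in the MNIST example later in this tutorial uses its handles, the row actually accessed appears to be the row passed at call time plus the offset stored in the handle. The helper below is hypothetical and only models that assumption; it shows how one shared inner-layer graph reads and writes different rows for `fc2` and `fc3`.

```python
from typing import Dict, Tuple

# (buffer name, row offset), mirroring the (buffer, offset) tuples used as handles
Handle = Tuple[str, int]


def resolve_rows(handles: Dict[str, Handle], call_row: int) -> Dict[str, Tuple[str, int]]:
    """Absolute (buffer, row) accessed for each tensor when the graph is called with call_row.

    Assumes the absolute row is call_row + handle offset (inferred from the tutorial's usage).
    """
    return {tensor: (buffer, call_row + offset) for tensor, (buffer, offset) in handles.items()}


# Handles for the shared inner-layer forward graph (fc2 and fc3); tensor names are hypothetical
inner_fwd_handles = {"x_in": ("x_buffer", 0), "x_out": ("x_buffer", 1)}

print(resolve_rows(inner_fwd_handles, call_row=0))  # fc2: {'x_in': ('x_buffer', 0), 'x_out': ('x_buffer', 1)}
print(resolve_rows(inner_fwd_handles, call_row=1))  # fc3: {'x_in': ('x_buffer', 1), 'x_out': ('x_buffer', 2)}
```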

Finally, the `io_mode` parameter controls how tensors are loaded and stored during the loop:
- `compute` uses the compute tiles.
- `io` uses the IO tiles.
- `io_overlapped` uses the IO tiles and builds the loop so that compute and IO can execute at the same time.
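The benefit of overlapping can be pictured as a three-stage pipeline. This is a conceptual timeline only, assuming that load, compute and store each take one time unit and that IO tiles and compute tiles can work independently; it is not necessarily how `popxl.addons` structures the `io_overlapped` loop. Each slot lists the operations that may run concurrently.

```python
from typing import List


def sequential_schedule(steps: int) -> List[List[str]]:
    """Load, compute and store one after the other for every step."""
    slots: List[List[str]] = []
    for s in range(steps):
        slots += [[f"load {s}"], [f"compute {s}"], [f"store {s}"]]
    return slots


def overlapped_schedule(steps: int) -> List[List[str]]:
    """Three-stage pipeline: while step s computes, step s+1 loads and step s-1 stores."""
    slots: List[List[str]] = []
    for t in range(steps + 2):
        ops = []
        if t < steps:
            ops.append(f"load {t}")
        if 0 <= t - 1 < steps:
            ops.append(f"compute {t - 1}")
        if 0 <= t - 2 < steps:
            ops.append(f"store {t - 2}")
        slots.append(ops)
    return slots


for slot in overlapped_schedule(4):
    print(slot)
print(len(sequential_schedule(4)), len(overlapped_schedule(4)))  # 12 6
```

Under these idealised assumptions, the overlapped loop needs `steps + 2` slots instead of `3 * steps`.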

The `io` and `io_overlapped` modes require some tiles to be reserved as IO tiles (see also the [PopART User Guide](https://docs.graphcore.ai/projects/popart-user-guide/en/3.1.0/overlap_io.html?highlight=io%20tiles#configuring-io-tiles)).
You can do that by setting the session options after creating the `ir`:
```python
session_opts = ir._pb_ir.getSessionOptions()
session_opts.numIOTiles = 32 
```

The `batch_serialise_fwd_and_grad` transform is very similar to the `batch_serialise` transform:
```python
def batch_serialise_fwd_and_grad(
        forward_graph: GraphWithNamedArgs,
        gradient_graph: GraphWithNamedArgs,
        named_inputs_for_grad_graph: NamedTensors,
        steps: int,
        load_handles: Dict[popxl.Tensor, Union[popxl.HostToDeviceStream, RemoteBufferAndOffset]],
        store_streams: Dict[popxl.Tensor, popxl.DeviceToHostStream],
        store_buffers: Dict[popxl.Tensor, RemoteBufferAndOffset],
        seed_input: Optional[popxl.Tensor] = None,
        rows: int = 1,
        io_mode: Literal['compute', 'io', 'io_overlapped'] = 'io') -> Tuple[BatchSerialResult, BatchSerialResult]:
```
It applies batch serialisation to both the forward and backward graphs, ensuring that all the inputs of the backward graph that derive from the forward graph are properly managed.
This means that:
- If there is a `store_buffer` for the tensor, the same buffer is used as a `load_handle` in the backward graph.
- If there is a `load_handle` for the tensor, the same `load_handle` is used in the backward graph.
- If the tensor is provided in `named_inputs_for_grad_graph`, the returned gradient graph has this tensor as a named input, and you need to bind the graph to it. Typically, you use this parameter to provide the forward variables:
```python
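fwd_vars = fwd_facts.init()  # forward variables
bwd_vars = bwd_facts.init()  # gradient accumulators created by autodiff_with_accumulation
bwd_vars.update(fwd_vars.copy())  # the backward graph also takes the forward variables as named inputs
batch_serialised_bwd.graph.bind(bwd_vars).call(0)  # call with row offset 0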
```
- If the tensor is not provided in any of these ways, a new buffer is created for it, where the forward tensor will be stored.
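The rules above can be summarised as a small decision function. `grad_input_source` is hypothetical and only classifies string names; it checks the rules in the order they are listed, and that precedence is an assumption rather than something stated for the library. The tensor names are also hypothetical, loosely modelled on the `fc1` phase of the MNIST example.

```python
from typing import Dict, Set, Tuple

Handle = Tuple[str, int]  # (buffer name, row offset)


def grad_input_source(
    tensor: str,
    store_buffers: Dict[str, Handle],
    load_handles: Dict[str, object],
    named_inputs_for_grad_graph: Set[str],
) -> Tuple[str, object]:
    """Classify how a forward-derived input of the backward graph would be provided.

    Follows the rules listed above, checked in that order (the order is an assumption).
    """
    if tensor in store_buffers:
        return ("load from store buffer", store_buffers[tensor])
    if tensor in load_handles:
        return ("reuse load handle", load_handles[tensor])
    if tensor in named_inputs_for_grad_graph:
        return ("named input, bind it", tensor)
    return ("new buffer", None)


# Hypothetical names loosely modelled on the fc1 phase
store_buffers = {"fc1_out": ("x_buffer", 0)}
load_handles = {"image": "input_stream"}
named = {"weight", "bias"}
for t in ["fc1_out", "image", "weight", "gelu_pre_activation"]:
    print(t, "->", grad_input_source(t, store_buffers, load_handles, named))
```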

The result of a batch serialisation transform is a `BatchSerialResult`, which gathers the batch serialised `graph`, the `store_buffers`, and a dictionary to remap tensors from the original graph to the transformed one.


## MNIST with phased execution and batch serialisation
In this tutorial we are going to implement a phased execution MNIST example that illustrates all these concepts.
We will use data parallelism, remote buffers and replicated tensor sharding, so you may want to check out the previous tutorials on these topics:
- [Data parallelism and gradient accumulation](../3_data_parallelism)
- [Remote variables and RTS](../5_remote_variables_and_rts)

You may also want to have another look at outlining in the very [first tutorial](../1_basic_concepts).

Our network has 4 layers. We define 7 phases for training:
- fc1 forward
- fc2 forward
- fc3 forward
- output layer fwd + loss + output layer bwd 
- fc3 backward
- fc2 backward
- fc1 backward
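The phase order follows a simple pattern: forward through the hidden layers, one fused output-layer phase, then backward through the hidden layers in reverse. The helper below is hypothetical and not part of the tutorial code; it shows that `N` hidden layers give `2N + 1` phases, which for `fc1`, `fc2` and `fc3` reproduces the seven phases listed above.

```python
from typing import List


def training_phases(layers: List[str]) -> List[str]:
    """Phase order for a phased training step: forward through the hidden layers,
    one fused output-layer phase, then backward in reverse order."""
    forward = [f"{name} forward" for name in layers]
    output = ["output layer fwd + loss + output layer bwd"]
    backward = [f"{name} backward" for name in reversed(layers)]
    return forward + output + backward


phases = training_phases(["fc1", "fc2", "fc3"])
for i, phase in enumerate(phases, start=1):
    print(i, phase)
```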

Forward variables and optimizer state are stored remotely.
We use a `Graphs` class to keep the forward, backward and optimizer graphs of the same module together, and to easily load, store and update variables.
We create three `Graphs` objects:
- one for `fc1`, corresponding to a `Linear` module. In `input_layer_batch_serialise`, we batch serialise the forward and backward graphs for this layer, using `batch_serialise_fwd_and_grad`.
- one for the inner layers, `fc2` and `fc3`, which are identical: a `Linear` module with a different input shape from the first one. Since they are identical, they share the same graph and can use the same remote buffers for forward variables and optimizer state, but we need to specify `2` entries (the `Graphs` class has an `entries` parameter for this purpose). In `layer_batch_serialise`, we batch serialise the forward and backward graphs using `batch_serialise_fwd_and_grad`. We need to specify `rows=2` when using the transform, one row for each layer.
- one corresponding to the `OutputLayerWithBwd` module. In this case, we don't have a backward graph: the forward graph already includes the backward. Hence, in `output_layer_batch_serialise`, we use the `batch_serialise` transform.

We also need two buffers to connect the phases, an `x_buffer` and a `dx_buffer`, where each phase can read its inputs and store its outputs:
```python
x_buffer = batch_serial_buffer(
                               fc1.fwd.graph.outputs[0],
                               steps=opts.train.gradient_accumulation,
                               rows=num_inner_layers + 1
                               )
dx_buffer = batch_serial_buffer(
                                fc1.bwd.graph.inputs[0],
                                steps=opts.train.gradient_accumulation,
                                rows=num_inner_layers + 1
                                )
```
Each row in `x_buffer` corresponds to a forward phase, excluding the last forward phase (which only reads from the buffer), and each row in `dx_buffer` corresponds to a backward phase, excluding the last backward phase (which only reads from the buffer).

The first forward phase loads its inputs from the host and stores its output in the first row of `x_buffer`.
The subsequent forward phases load their inputs from the previous phase's row and store their output in their own row.
The output layer only reads from the buffer.

During the backward pass, the order is reversed: the output layer stores its `dx` output in the last row, and the subsequent backward phases read their `dx` input from the next row and store their `dx` output in their own row. The input layer only reads from the buffer.

The image below illustrates the concept:
<figure>
<img src="images/x_dx_buffers.png"/>
<figcaption> <b>Fig 6: </b> Connecting phases together: x and dx buffers
</figcaption>
</figure>
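The row routing can be checked with a small simulation. The helpers below are hypothetical and not part of the tutorial code; the rows they assign are read off the load and store handles in `input_layer_batch_serialise`, `layer_batch_serialise` and `output_layer_batch_serialise` further down. The simulation walks the phases in execution order and verifies that every row is written before it is read and that `num_inner_layers + 1` rows are enough.

```python
from typing import List, Optional, Tuple

# (phase, buffer, row read, row written); None means no read or no write on that buffer
Access = Tuple[str, str, Optional[int], Optional[int]]


def buffer_routing(num_inner_layers: int) -> List[Access]:
    """x/dx buffer rows touched by each phase, in execution order.

    fc1 is layer 0 and the inner layers are 1..num_inner_layers.
    """
    n = num_inner_layers
    plan: List[Access] = []
    for k in range(n + 1):  # forward: read x[k-1] (fc1 reads from the host), write x[k]
        plan.append((f"fc{k + 1} fwd", "x", k - 1 if k > 0 else None, k))
    plan.append(("output fwd+bwd", "x", n, None))  # reads the last x row
    plan.append(("output fwd+bwd", "dx", None, n))  # writes the last dx row
    for k in reversed(range(n + 1)):  # backward: read dx[k], write dx[k-1] (fc1 writes nothing)
        plan.append((f"fc{k + 1} bwd", "dx", k, k - 1 if k > 0 else None))
    return plan


def check_routing(plan: List[Access], rows: int) -> bool:
    """Raise if a row is read before being written or a write falls outside the buffer."""
    written = set()
    for phase, buffer, read_row, write_row in plan:
        if read_row is not None and (buffer, read_row) not in written:
            raise ValueError(f"{phase} reads {buffer}[{read_row}] before it is written")
        if write_row is not None:
            if not 0 <= write_row < rows:
                raise ValueError(f"{phase} writes outside {buffer} ({rows} rows)")
            written.add((buffer, write_row))
    return True


plan = buffer_routing(num_inner_layers=2)
for access in plan:
    print(access)
print(check_routing(plan, rows=2 + 1))  # True
```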

### Imports

In [None]:
import argparse
from functools import partial
from typing import Mapping, Optional
from typing_extensions import Literal
import torch
import torchvision
from tqdm import tqdm
import numpy as np
from time import time
from dataclasses import dataclass, field

import popxl
import popxl_addons as addons
import popxl.ops as ops
from typing import Union, Dict
from popxl_addons.graph import GraphWithNamedArgs, BoundGraph
from popxl_addons.named_tensors import NamedTensors
from popxl_addons.variable_factory import NamedVariableFactories
from popxl.transforms import GradGraphInfo
import logging
from popxl_addons import (
    batch_serialise,
    batch_serialise_fwd_and_grad,
    batch_serial_buffer,
)
from popxl_addons.rts import (
    reduce_replica_sharded_graph,
    all_gather_replica_sharded_graph,
    replica_sharded_spec,
)

from popxl_addons.remote import (
    named_variable_buffers,
    load_remote_graph,
    store_remote_graph,
)

np.random.seed(42)

### Layers

In [None]:
# includes gelu
class Linear(addons.Module):
    def __init__(self, out_features: int, bias: bool = True, gelu: bool = True):
        super().__init__()
        self.out_features = out_features
        self.bias = bias
        self.gelu = gelu

    def build(self, x: popxl.Tensor) -> popxl.Tensor:
        # add a state variable to the module
        w = self.add_variable_input(
            "weight",
            partial(np.random.normal, 0, 0.02, (x.shape[-1], self.out_features)),
            x.dtype,
        )
        y = x @ w
        if self.bias:
            # add a state variable to the module
            b = self.add_variable_input("bias", partial(np.zeros, y.shape[-1]), x.dtype)
            y = y + b
        if self.gelu:
            y = ops.gelu(y)
        return y


class OutputLayerWithBwd(addons.Module):
    def __init__(self, out_features: int, bias: bool = True, gelu: bool = True):
        super().__init__()
        self.linear = Linear(out_features=out_features, bias=bias, gelu=gelu)

    def build(self, x: popxl.Tensor, labels: popxl.Tensor):  # returns (dx, loss)

        fwd_facts, fwd_graph = self.linear.create_graph(x.spec)
        bwd_facts, bwd_graph = addons.transforms.autodiff_with_accumulation(
            fwd_graph,
            tensors_to_accumulate_grads=fwd_graph.args.tensors,
            grads_required=[fwd_graph.graph.inputs[0]],
        )

        # outline forward
        vars = self.add_variable_inputs("fwd", fwd_facts)
        fwd_info = fwd_graph.bind(vars).call_with_info(x)
        x = fwd_info.parent_output(0)

        loss, dx = addons.ops.cross_entropy_with_grad(x, labels)

        # outline backward
        bwd_vars = self.add_variable_inputs("bwd", bwd_facts)
        (dx,) = bwd_graph.bind(bwd_vars).call(
            dx, args=bwd_graph.grad_graph_info.inputs_dict(fwd_info)
        )

        return dx, loss


# gelu included in the linear layer
class Net(addons.Module):
    def __init__(self, cache: Optional[addons.GraphCache] = None):
        super().__init__(cache=cache)
        self.fc1 = Linear(512)
        self.fc2 = Linear(512)
        self.fc3 = Linear(512)
        self.fc4 = Linear(10, gelu=False)

    def build(self, x: popxl.Tensor):
        x = x.reshape((-1, 28 * 28))
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x

### Optimizer

In [None]:
"""
Adam optimizer.
Defines adam update step for a single variable
"""


class Adam(addons.Module):
    # we need to specify in_sequence because a lot of operations are in place and their order
    # shouldn't be rearranged
    @popxl.in_sequence()
    def build(
        self,
        var: popxl.TensorByRef,
        grad: popxl.Tensor,
        *,
        lr: Union[float, popxl.Tensor],
        beta1: Union[float, popxl.Tensor] = 0.9,
        beta2: Union[float, popxl.Tensor] = 0.999,
        eps: Union[float, popxl.Tensor] = 1e-5,
        weight_decay: Union[float, popxl.Tensor] = 0.0,
        first_order_dtype: popxl.dtype = popxl.float16,
        bias_correction: bool = True
    ):

        # first and second moment estimates for the variable var - same shape as the variable

        # Sharded inputs must be added with add_replica_sharded_variable_input
        if var.meta_shape:
            first_order = self.add_replica_sharded_variable_input(
                "first_order",
                partial(np.zeros, var.meta_shape),
                first_order_dtype,
                by_ref=True,
            )
            second_order = self.add_replica_sharded_variable_input(
                "second_order",
                partial(np.zeros, var.meta_shape),
                popxl.float32,
                by_ref=True,
            )

        else:
            first_order = self.add_variable_input(
                "first_order",
                partial(np.zeros, var.shape),
                first_order_dtype,
                by_ref=True,
            )
            second_order = self.add_variable_input(
                "second_order", partial(np.zeros, var.shape), popxl.float32, by_ref=True
            )

        ops.var_updates.accumulate_moving_average_(first_order, grad, f=beta1)
        ops.var_updates.accumulate_moving_average_square_(second_order, grad, f=beta2)

        # adam is a biased estimator: provide the step to correct bias
        step = None
        if bias_correction:
            step = self.add_variable_input(
                "step", partial(np.zeros, ()), popxl.float32, by_ref=True
            )

        # calculate the weight increment with adam heuristic
        updater = ops.var_updates.adam_updater(
            first_order,
            second_order,
            weight=var,
            weight_decay=weight_decay,
            time_step=step,
            beta1=beta1,
            beta2=beta2,
            epsilon=eps,
        )

        # in place weight update: w += (-lr)*dw
        ops.scaled_add_(var, updater, b=-lr)
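
For reference, this is the textbook Adam update that the module above expresses with `ops.var_updates`, written in NumPy. It is a sketch for checking intuition, not a re-implementation of popxl: the exact placement of `eps` and the handling of weight decay inside `ops.var_updates.adam_updater` may differ, and weight decay is omitted here (the module's default is `0.0`). Passing `step=None` mirrors `bias_correction=False`, which is what `Graphs._setup_optim` below uses.

```python
import numpy as np


def adam_step(w, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-5, step=None):
    """One textbook Adam update on NumPy arrays (weight decay omitted).

    step=None mirrors bias_correction=False; otherwise step is the 1-based update count.
    """
    m = beta1 * m + (1.0 - beta1) * g  # first moment: moving average of g
    v = beta2 * v + (1.0 - beta2) * g * g  # second moment: moving average of g^2
    if step is None:
        m_hat, v_hat = m, v
    else:
        m_hat = m / (1.0 - beta1 ** step)
        v_hat = v / (1.0 - beta2 ** step)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)  # w += (-lr) * updater
    return w, m, v


w, m, v = adam_step(np.array([1.0]), np.array([0.5]), np.zeros(1), np.zeros(1), lr=0.1, step=1)
print(w)  # approximately [0.9]
w_nobc, _, _ = adam_step(np.array([1.0]), np.array([0.5]), np.zeros(1), np.zeros(1), lr=0.1)
```

With bias correction, the first update moves the weight by roughly `lr`; without it, the early updates are scaled by the still-small moment estimates.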

### Graphs

In [None]:
"""
Groups together the forward, backward and optimizer graphs of a layer for easy access and handling.
"""


class Graphs:
    def __init__(
        self,
        opts,
        layer: addons.Module,
        optimizer: addons.Module,
        entries: int,
        require_dx_0: bool,
        *args,
        **kwargs
    ):
        # Create Graphs for computing forward, gradient and optimizer
        fwd_facts, self.fwd = layer.create_graph(*args, **kwargs)
        required_grads = (self.fwd.graph.inputs[0],) if require_dx_0 else ()
        grad_facts, self.bwd = addons.autodiff_with_accumulation(
            self.fwd,
            tensors_to_accumulate_grads=self.fwd.args.tensors,
            grads_required=required_grads,
        )

        optim_facts = self._setup_optim(optimizer, self.fwd.args, opts)
        self._set_factories(fwd_facts, optim_facts, grad_facts)
        self._setup_graphs(opts, entries)

    @classmethod
    def empty(cls):
        return super().__new__(cls)

    @staticmethod
    def from_fwd_and_bwd(
        opts,
        fwd_and_bwd: addons.Module,
        optimizer: addons.Module,
        entries: int,
        *args,
        **kwargs
    ):
        graphs = Graphs.empty()
        graphs.bwd = None
        facts, graphs.fwd = fwd_and_bwd.create_graph(*args, **kwargs)
        optim_facts = graphs._setup_optim(optimizer, graphs.fwd.args.fwd, opts)
        graphs._set_factories(facts.fwd, optim_facts, facts.pop("bwd"))
        graphs._setup_graphs(opts, entries)
        return graphs

    def _setup_optim(self, optimizer: addons.Module, fwd_vars: NamedTensors, opts):
        optim_facts = {}
        self.optim = {}
        for name, var in fwd_vars.to_dict().items():
            optim_facts[name], self.optim[name] = optimizer.create_graph(
                replica_sharded_spec(var, threshold=opts.train.sharded_threshold),
                replica_sharded_spec(var, threshold=opts.train.sharded_threshold),
                lr=opts.train.lr,
                bias_correction=False,
            )
        return optim_facts

    def _set_factories(self, fwd_facts, optim_facts, grad_facts):
        self.facts = NamedVariableFactories(
            fwd=fwd_facts, optim=NamedVariableFactories.from_dict(optim_facts)
        )
        self.grad_facts = grad_facts

    def _setup_graphs(self, opts, entries: int):
        # Create remote buffers for fwd vars and optimiser state.
        self.buffers = named_variable_buffers(
            self.facts, entries, sharded_threshold=opts.train.sharded_threshold
        )
        # Create Graphs for loading/gathering/storing/reducing
        self._fwd_load, self._fwd_load_names = load_remote_graph(
            self.buffers.fwd, entries
        )
        self._optim_fwd_load, self._optim_fwd_load_names = load_remote_graph(
            self.buffers, entries
        )
        self._optim_store = store_remote_graph(self.buffers, entries)

        (
            self._fwd_all_gather,
            self._fwd_all_gather_names,
        ) = all_gather_replica_sharded_graph(
            NamedTensors.pack(self._fwd_load_names, self._fwd_load.graph.outputs)
        )
        grad_accums = self.bwd.args.copy() if self.bwd else self.fwd.args.bwd.copy()
        grad_accums.pop("mean_accum_counter")
        self._grad_reduce, self._grad_reduce_names = reduce_replica_sharded_graph(
            grad_accums, "mean", threshold=opts.train.sharded_threshold
        )

    # load forward variables
    def fwd_load(self, i: Union[int, popxl.Tensor]):
        return NamedTensors.pack(self._fwd_load_names, self._fwd_load.call(i))

    # load forward variables and optimizer state
    def optim_fwd_load(self, i: Union[int, popxl.Tensor]):
        return NamedTensors.pack(
            self._optim_fwd_load_names, self._optim_fwd_load.call(i)
        )

    # store forward variables and optimizer state
    def optim_store(self, args: NamedTensors, i: Union[int, popxl.Tensor]):
        return self._optim_store.bind(args).call(i)

    # gathers replica sharded forward variables
    def fwd_all_gather(self, args: NamedTensors):
        return NamedTensors.pack(
            self._fwd_all_gather_names, self._fwd_all_gather.bind(args).call()
        )

    # reduce scatter gradients
    def grad_reduce(self, args: NamedTensors):
        return NamedTensors.pack(
            self._grad_reduce_names, self._grad_reduce.bind(args).call()
        )

    # update forward variables
    def optimizer_remote_step(
        self,
        i: int,
        vars_and_state: NamedTensors,
        grads: NamedTensors,
        accum_counter: popxl.Tensor,
    ):
        _variables = vars_and_state.fwd.to_dict()
        _state = vars_and_state.optim
        _grads = grads.accum.to_dict()
        for name, graph in self.optim.items():
            state = self._get_optimizer_state(name, _state)
            graph.bind(state).call(_variables[name], _grads[name])
        ops.var_updates.accumulator_scale_(accum_counter, 0.0)

    def _get_optimizer_state(self, name: str, state: NamedTensors) -> NamedTensors:
        attrs = name.split(".")
        for attr in attrs:
            state = getattr(state, attr)
        return state

### Batch serialisation utils

In [None]:
def input_layer_batch_serialise(
    opts,
    layer_graphs: Graphs,
    x_buffer: popxl.RemoteBuffer,
    dx_buffer: popxl.RemoteBuffer,
    input_stream: popxl.HostToDeviceStream,
):
    fwd_bs, bwd_bs = batch_serialise_fwd_and_grad(
        layer_graphs.fwd,
        layer_graphs.bwd,
        layer_graphs.fwd.args,
        opts.train.gradient_accumulation,
        load_handles={
            layer_graphs.fwd.graph.inputs[0]: input_stream,
            layer_graphs.bwd.graph.inputs[0]: (dx_buffer, 0),
        },
        store_streams={},
        store_buffers={
            layer_graphs.fwd.graph.outputs[0]: (x_buffer, 0),
        },
        rows=1,
        io_mode=opts.train.io_mode,
    )
    layer_graphs.fwd = fwd_bs.graph
    layer_graphs.bwd = bwd_bs.graph


def layer_batch_serialise(
    opts,
    layer_graphs: Graphs,
    x_buffer: popxl.RemoteBuffer,
    dx_buffer: popxl.RemoteBuffer,
    rows: int,
):
    fwd_bs, bwd_bs = batch_serialise_fwd_and_grad(
        layer_graphs.fwd,
        layer_graphs.bwd,
        layer_graphs.fwd.args,
        opts.train.gradient_accumulation,
        load_handles={
            layer_graphs.fwd.graph.inputs[0]: (
                x_buffer,
                0,
            ),  # load x from previous layer row
            layer_graphs.bwd.graph.inputs[0]: (
                dx_buffer,
                1,
            ),  # load dx from next layer row
        },
        store_streams={},
        store_buffers={
            layer_graphs.fwd.graph.outputs[0]: (
                x_buffer,
                1,
            ),  # store x in next layer row
            layer_graphs.bwd.graph.outputs[0]: (
                dx_buffer,
                0,
            ),  # store dx in previous layer row
        },
        rows=rows,
        io_mode=opts.train.io_mode,
    )
    layer_graphs.fwd = fwd_bs.graph
    layer_graphs.bwd = bwd_bs.graph


def output_layer_batch_serialise(
    opts,
    layer_graphs: Graphs,
    x_buffer: popxl.RemoteBuffer,
    dx_buffer: popxl.RemoteBuffer,
    label_stream: popxl.HostToDeviceStream,
    output_stream: popxl.DeviceToHostStream,
):
    fwd_bs = batch_serialise(
        layer_graphs.fwd,
        opts.train.gradient_accumulation,
        load_handles={
            layer_graphs.fwd.graph.inputs[0]: (x_buffer, 2),
            layer_graphs.fwd.graph.inputs[1]: label_stream,
        },
        store_streams={layer_graphs.fwd.graph.outputs[1]: output_stream},
        store_buffers={layer_graphs.fwd.graph.outputs[0]: (dx_buffer, 2)},
        rows=1,
        io_mode=opts.train.io_mode,
    )
    layer_graphs.fwd = fwd_bs.graph

### Utils

In [None]:
def evaluate_throughput(session, samples_per_step, epochs: int = 5):
    inputs = {
        stream: np.ones(
            session._full_input_shape(stream.shape), stream.dtype.as_numpy()
        )
        for stream in session.expected_inputs()
    }

    durations = []
    with session:
        for i in range(epochs):
            start = time()
            session.run(inputs)
            dur = time() - start
            durations.append(dur)

    duration = np.mean(durations)

    result_str = (
        f"Mean duration: {duration} s "
        f"Throughput: {samples_per_step/duration:6.1f} samples/s "
    )
    print(result_str)


def get_mnist_data(test_batch_size: int, batch_size: int):
    training_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(
                        (0.1307,), (0.3081,)
                    ),  # mean and std computed on the training set.
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    validation_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=test_batch_size,
        shuffle=True,
        drop_last=True,
    )
    return training_data, validation_data


def accuracy(predictions: np.ndarray, labels: torch.Tensor):
    ind = np.argmax(predictions, axis=-1).flatten()
    labels = labels.detach().numpy().flatten()
    return np.mean(ind == labels) * 100.0

### Configs

In [None]:
@dataclass
class Train:
    micro_batch_size: int = 8
    lr: Union[float, popxl.Tensor] = 1e-3
    epochs: int = 1
    gradient_accumulation: int = 1
    data_parallel: int = 1
    device = "ipu_hw"
    sharded_threshold: int = 512
    io_mode: str = "io"


@dataclass
class Test:
    micro_batch_size: int = 80
    device = "ipu_hw"


class Options:
    train = Train()
    test = Test()


opts = Options()
opts.train.micro_batch_size = 8
opts.train.lr = 1e-3
opts.train.epochs = 1
opts.train.gradient_accumulation = 4
opts.train.data_parallel = 2
opts.train.sharded_threshold = 512
opts.train.io_mode = "io"

opts.test.micro_batch_size = 80
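
With these settings, a quick back-of-the-envelope check of the training configuration may help. The two helpers below are hypothetical and not used elsewhere in the notebook: the first computes the number of samples consumed per weight update across all replicas, and the second estimates the size of one `x_buffer`, assuming float32 entries of shape `(micro_batch_size, 512)`, `num_inner_layers + 1 = 3` rows, no padding or alignment overhead, and one copy of the buffer per replica.

```python
def samples_per_step(micro_batch_size: int, gradient_accumulation: int, data_parallel: int) -> int:
    """Samples consumed by one weight update across all replicas."""
    return micro_batch_size * gradient_accumulation * data_parallel


def batch_serial_buffer_bytes(entry_shape, rows: int, steps: int, bytes_per_element: int = 4) -> int:
    """Size of one rows x steps batch serial buffer, assuming float32 and no padding."""
    elements = 1
    for d in entry_shape:
        elements *= d
    return rows * steps * elements * bytes_per_element


print(samples_per_step(8, 4, 2))  # 64
print(batch_serial_buffer_bytes((8, 512), rows=3, steps=4))  # 196608
```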

### Train

In [None]:
def train_program(opts):
    assert opts.train.gradient_accumulation > 1
    assert opts.train.data_parallel > 1

    ir = popxl.Ir()
    ir.replication_factor = opts.train.data_parallel
    num_inner_layers = 2
    if opts.train.io_mode != "compute":
        session_opts = ir._pb_ir.getSessionOptions()
        session_opts.numIOTiles = 32

    with ir.main_graph:
        # -----  Define input and output streams -----
        img_spec = popxl.TensorSpec(
            (opts.train.micro_batch_size, 28 * 28), popxl.float32
        )
        inner_spec = popxl.TensorSpec((opts.train.micro_batch_size, 512), popxl.float32)

        img_stream = popxl.h2d_stream(img_spec.shape, popxl.float32, "image")
        label_stream = popxl.h2d_stream(
            (opts.train.micro_batch_size,), popxl.int32, "labels"
        )
        loss_stream = popxl.d2h_stream((), popxl.float32, "loss")
        optimizer = Adam(cache=True)

        # ----- Create graphs -----
        fc1 = Graphs(opts, Linear(512), optimizer, 1, False, img_spec)
        inner_layer = Graphs(
            opts, Linear(512), optimizer, num_inner_layers, True, inner_spec
        )
        fc4_fwd_bwd = Graphs.from_fwd_and_bwd(
            opts,
            OutputLayerWithBwd(10, gelu=False),
            optimizer,
            1,
            inner_spec,
            label_stream.spec,
        )

        x_buffer = batch_serial_buffer(
            fc1.fwd.graph.outputs[0],
            steps=opts.train.gradient_accumulation,
            rows=num_inner_layers + 1,
        )
        dx_buffer = batch_serial_buffer(
            fc1.bwd.graph.inputs[0],
            steps=opts.train.gradient_accumulation,
            rows=num_inner_layers + 1,
        )

        # ----- Transform graphs -----

        # apply batch serialisation
        input_layer_batch_serialise(opts, fc1, x_buffer, dx_buffer, img_stream)
        # the inner layer graphs are shared by fc2 and fc3 (num_inner_layers = 2):
        # row offset 0 is for fc2 and row offset 1 for fc3
        layer_batch_serialise(
            opts, inner_layer, x_buffer, dx_buffer, num_inner_layers
        )
        output_layer_batch_serialise(
            opts, fc4_fwd_bwd, x_buffer, dx_buffer, label_stream, loss_stream
        )

        # ----- Create Variables -----
        variables = NamedTensors()
        variables.insert("fc1", fc1.facts.init_remote(fc1.buffers, 0, "fc1"))
        variables.insert(
            "fc2", inner_layer.facts.init_remote(inner_layer.buffers, 0, "fc2")
        )
        variables.insert(
            "fc3", inner_layer.facts.init_remote(inner_layer.buffers, 1, "fc3")
        )
        variables.insert(
            "fc4", fc4_fwd_bwd.facts.init_remote(fc4_fwd_bwd.buffers, 0, "fc4")
        )

        # ----- Construct Execution Scheme -----

        # Phased Execution (with remote fwd variables and optimizer state). N layers executing separately
        # phase 1: fwd for layer 1:
        #   load fwd variables.
        #   for i in range(gradient_accumulation_steps):
        #        load inputs (xs)
        #        execute fwd
        #        store outputs & activations
        #
        # phase 2: fwd for layer 2:
        # ...
        # phase N: fwd + bwd + optimizer for layer N:
        #   load fwd variables and optimizer state.
        #   for i in range(gradient_accumulation_steps):
        #       load fwd and bwd inputs (xs)
        #       execute fwd
        #       compute loss
        #       execute bwd
        #       store outputs & activations
        #   call optimizer
        #   store updated fwd variables and optimizer state
        #
        # phase N+1: bwd for layer N-1 + optimizer
        #   load fwd variables and optimizer state. Both are needed: the fwd vars are used by the backward.
        #   for i in range(gradient_accumulation_steps):
        #       load bwd inputs (xs)
        #       execute bwd
        #       store outputs & activations
        #   call optimizer
        #   store updated fwd variables and optimizer state
        # ...
        # phase 2N-1: bwd for layer 1 + optimizer
        #

        # ----- Phased Execution -----
        with popxl.in_sequence(True):

            def forward_phase(graphs: Graphs, row_offset: int):
                # load the forward remote variables, which are sharded
                vars = graphs.fwd_load(row_offset)
                # all-gather the shards: the graph must be bound to the gathered vars
                vars = graphs.fwd_all_gather(vars)
                # calling the graph executes the GA loop for the phase: repeat ( load xs - execute - store )
                graphs.fwd.bind(vars).call(row_offset)

            def backward_phase(graphs: Graphs, row_offset: int):
                is_joint_fwd_bwd = graphs.bwd is None
                # forward vars and optimizer state are needed in the backward.
                # loading them together is convenient
                fwd_vars_and_state = graphs.optim_fwd_load(row_offset)  # sharded
                vars: NamedTensors  # gathered variables comprising forward and backward named inputs
                reduced_grads: NamedTensors  # scattered gradient accumulators
                mean_accum_counter: popxl.Tensor
                if is_joint_fwd_bwd:
                    vars = NamedTensors(
                        fwd=graphs.fwd_all_gather(
                            fwd_vars_and_state.fwd
                        ),  # gathered forward variables
                        bwd=graphs.grad_facts.init_zero(),  # gradient accumulators
                    )
                    # calling the graph executes the GA loop for the phase:
                    # repeat ( load xs - execute fwd - compute loss - execute bwd - store ).
                    # The fwd graph includes everything.
                    graphs.fwd.bind(vars).call(row_offset)
                    reduced_grads = graphs.grad_reduce(vars.bwd)
                    mean_accum_counter = vars.bwd.mean_accum_counter
                else:
                    vars = graphs.fwd_all_gather(fwd_vars_and_state.fwd)
                    grad_accums = graphs.grad_facts.init_zero()  # gradient accumulators
                    vars.update(grad_accums.copy())
                    # calling the graph executes the GA loop for the phase: repeat ( load xs - execute bwd - store )
                    graphs.bwd.bind(vars).call(row_offset)  # just call the batch-serialised bwd
                    reduced_grads = graphs.grad_reduce(grad_accums)
                    mean_accum_counter = vars.mean_accum_counter
                # optimizer
                graphs.optimizer_remote_step(
                    row_offset, fwd_vars_and_state, reduced_grads, mean_accum_counter
                )
                graphs.optim_store(fwd_vars_and_state, row_offset)  # store updated vars

            # ----- Phase 1 (fwd): fc1 Fwd -----
            forward_phase(fc1, 0)
            # ----- Phase 2 (fwd): fc2 Fwd-----
            forward_phase(inner_layer, 0)
            # ----- Phase 3 (fwd): fc3 Fwd -----
            forward_phase(inner_layer, 1)
            # ----- Phase 4 (merged fwd-bwd): Fwd for output layer, loss, Bwd for output layer - Optimizer for output layer -----
            backward_phase(fc4_fwd_bwd, 0)
            # ----- Phase 5 (bwd): fc3 bwd - Optimizer for fc3 -----
            backward_phase(inner_layer, 1)
            # ----- Phase 6 (bwd): fc2 bwd - Optimizer for fc2 -----
            backward_phase(inner_layer, 0)
            # ----- Phase 7 (bwd): fc1 bwd - Optimizer for fc1 -----
            backward_phase(fc1, 0)

    # each phase loops over gradient_accumulation micro-batches, so the number of host transfers per step equals gradient_accumulation
    ir.num_host_transfers = opts.train.gradient_accumulation
    # weights we need to retrieve and copy to the test session. They need to be in the same order as in the full model (fc1-fc2-fc3-fc4).
    vars = NamedTensors(
        fc1=variables.fc1.fwd,
        fc2=variables.fc2.fwd,
        fc3=variables.fc3.fwd,
        fc4=variables.fc4.fwd,
    )

    return popxl.Session(ir, opts.train.device), [img_stream, label_stream], vars, loss_stream
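The seven `forward_phase`/`backward_phase` calls follow the general scheme in the comment block. For N layers there are N-1 forward phases, one merged forward-backward phase for the output layer, and N-1 backward phases that visit the remaining layers in reverse order, giving 2N-1 phases in total. The plain-Python sketch below reproduces that ordering. The helper `phase_schedule` is hypothetical and makes no popxl calls; it only lists the phases:

```python
from typing import List


def phase_schedule(layer_names: List[str]) -> List[str]:
    """Phase order for N layers: N-1 fwd, one merged fwd+bwd, N-1 bwd (hypothetical helper)."""
    *hidden, output = layer_names
    # forward phases for every layer except the output layer
    phases = [f"fwd {name}" for name in hidden]
    # the output layer runs fwd, loss, bwd and its optimizer step in a single phase
    phases.append(f"fwd + loss + bwd + optimizer {output}")
    # backward phases (each with its optimizer step) in reverse layer order
    phases += [f"bwd + optimizer {name}" for name in reversed(hidden)]
    return phases


schedule = phase_schedule(["fc1", "fc2", "fc3", "fc4"])
for i, phase in enumerate(schedule, start=1):
    print(f"Phase {i}: {phase}")
```

For the four layers in this notebook it prints seven phases, from `fwd fc1` to `bwd + optimizer fc1`, matching the phase comments in `train_program`.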

global_batch_size = (
    opts.train.micro_batch_size
    * opts.train.gradient_accumulation
    * opts.train.data_parallel
)
training_data, test_data = get_mnist_data(opts.test.micro_batch_size, global_batch_size)
train_session, train_input_streams, train_variables, loss_stream = train_program(opts)

nr_batches = len(training_data)
with train_session:
    for epoch in range(1, opts.train.epochs + 1):
        print("Epoch {0}/{1}".format(epoch, opts.train.epochs))
        bar = tqdm(training_data, total=nr_batches)
        for data, labels in bar:
            # reshape data accounting for replication and the number of host transfers
            data = data.reshape(
                train_session.ir.num_host_transfers,
                train_session.ir.replication_factor,
                opts.train.micro_batch_size,
                28 * 28,
            ).squeeze()
            labels = labels.reshape(
                train_session.ir.num_host_transfers,
                train_session.ir.replication_factor,
                opts.train.micro_batch_size,
            ).squeeze()

            inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
                zip(train_input_streams, [data.squeeze().float(), labels.int()])
            )
            loss = train_session.run(inputs)
            # shape: (ir.num_host_transfers, ir.replication_factor)
            losses_np = loss[loss_stream]
            avg_loss = np.mean(losses_np)
            bar.set_description("Loss:{:0.4f}".format(avg_loss))

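The reshape in the training loop splits each global batch into leading axes of (host transfers, replicas, micro-batch, features), which is the layout the loop builds for `train_session.run` when both `num_host_transfers` and `replication_factor` are greater than one. A NumPy sketch with this notebook's sizes, where zero arrays stand in for one MNIST batch as returned by the loader (the `(64, 1, 28, 28)` image shape is an assumption about the torchvision loader):

```python
import numpy as np

num_host_transfers = 4  # ir.num_host_transfers == gradient_accumulation
replication_factor = 2  # ir.replication_factor == data_parallel
micro_batch_size = 8

global_batch_size = num_host_transfers * replication_factor * micro_batch_size  # 64
images = np.zeros((global_batch_size, 1, 28, 28), dtype=np.float32)  # stand-in for MNIST images
labels = np.zeros((global_batch_size,), dtype=np.int32)  # stand-in for MNIST labels

# One micro-batch per (host transfer, replica) pair, with each image flattened to 784 features.
data = images.reshape(num_host_transfers, replication_factor, micro_batch_size, 28 * 28)
label_batches = labels.reshape(num_host_transfers, replication_factor, micro_batch_size)
print(data.shape, label_batches.shape)  # (4, 2, 8, 784) (4, 2, 8)
```

The notebook's `.squeeze()` then drops any axis of size 1. With this configuration that is a no-op, but it matters if `gradient_accumulation` or `data_parallel` is set to 1.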
# get weights data : dictionary { train_session variables : tensor data (numpy) }
train_vars_to_data = train_session.get_tensors_data(train_variables.tensors)

### Throughput and Testing

def test_program(test_batch_size, device):
    ir = popxl.Ir(replication=1)
    with ir.main_graph:
        # Inputs
        in_stream = popxl.h2d_stream((test_batch_size, 28, 28), popxl.float32, "image")
        in_t = ops.host_load(in_stream)

        # Create graphs
        facts, graph = Net().create_graph(in_t)

        # Initialise variables
        variables = facts.init()

        # Forward
        (outputs,) = graph.bind(variables).call(in_t)
        out_stream = popxl.d2h_stream(outputs.shape, outputs.dtype, "outputs")
        ops.host_store(out_stream, outputs)

    ir.num_host_transfers = 1
    return popxl.Session(ir, device), [in_stream], variables, out_stream

# Create test program and test session
test_session, test_input_streams, test_variables, out_stream = test_program(
    opts.test.micro_batch_size, opts.test.device
)

# dictionary { train_session variables : test_session variables }
train_vars_to_test_vars = train_variables.to_mapping(test_variables)

# Create a dictionary { test_session variables : tensor data (numpy) }
# Copy the values now, before evaluating throughput on synthetic data: that run keeps training and would change the weights
test_vars_to_data = {
    test_var: train_vars_to_data[train_var].copy()
    for train_var, test_var in train_vars_to_test_vars.items()
}

# Copy trained weights to the program, with a single host to device transfer at the end
test_session.write_variables_data(test_vars_to_data)

# evaluate the ratio samples per step / time for train session
print("train session")
evaluate_throughput(train_session, global_batch_size)

nr_batches = len(test_data)
sum_acc = 0.0
with test_session:
    for data, labels in tqdm(test_data, total=nr_batches):
        # the test program has a single input stream (images); labels are only used to compute accuracy
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = {
            test_input_streams[0]: data.squeeze().float()
        }
        output = test_session.run(inputs)
        sum_acc += accuracy(output[out_stream], labels)
print("Accuracy on test set: {:0.2f}%".format(sum_acc / len(test_data)))
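Dividing the summed per-batch accuracies by `len(test_data)` gives the accuracy over the whole test set only because every batch has the same size. The test loader is built with `drop_last=True`, so a smaller final batch cannot skew the mean. The following check uses random synthetic hits, not MNIST results:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_batches = 80, 5
# synthetic per-sample "prediction was correct" flags
correct = rng.random((num_batches, batch_size)) < 0.9

per_batch_acc = correct.mean(axis=1) * 100.0  # what accuracy() returns per batch
mean_of_batches = per_batch_acc.sum() / num_batches  # what the test loop reports
overall = correct.mean() * 100.0  # accuracy over all samples at once
print(np.isclose(mean_of_batches, overall))  # True
```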

# no data parallelism or gradient accumulation for inference in this program
samples_per_step = opts.test.micro_batch_size
# evaluate the ratio samples per step / time for the test session
print("test session")
evaluate_throughput(test_session, samples_per_step)
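`evaluate_throughput` is defined earlier in the notebook. The hypothetical `measure_throughput` below is not that function. It is a minimal, generic sketch of how a samples-per-second figure is typically obtained: run one untimed warm-up step, time several steps, then divide the samples processed by the elapsed time. In the notebook, `run_step` would wrap a `session.run` call; here a 10 ms sleep stands in for it:

```python
import time
from typing import Callable


def measure_throughput(run_step: Callable[[], None], samples_per_step: int, n: int = 10) -> float:
    """Samples per second over `n` timed calls of `run_step` (hypothetical helper)."""
    run_step()  # warm-up: exclude one-off costs such as the first host transfers
    start = time.perf_counter()
    for _ in range(n):
        run_step()
    elapsed = time.perf_counter() - start
    return samples_per_step * n / elapsed


# Stand-in for session.run: each step takes at least 10 ms.
print(f"{measure_throughput(lambda: time.sleep(0.01), samples_per_step=80):.0f} samples/s")
```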