Copyright (c) 2022 Graphcore Ltd. All rights reserved.

# IPU Hardware
In this tutorial we show two techniques that can be used to address memory issues.
Before diving into the topic, we suggest you read the [IPU hardware overview](https://docs.graphcore.ai/projects/ipu-programmers-guide/en/3.0.0/about_ipu.html). Here we give only a very short overview.

The building blocks for IPU pods are IPU-machines, which are systems of 4 connected IPUs.

Each IPU has 10 **IPU-links** to connect to other IPUs. Multiple links can connect two IPUs, in which case data transfer between these IPUs is faster than between IPUs joined by fewer links.
In the M2000 machine, for example, the basic IPU-machine has the following structure:

![Fig1: IPU links](images/ipu_links.png)
**Figure 1:** *IPU links*

When you use replication you don't select the IPUs to be used for replicas. Given the required number of IPUs, the best option is chosen automatically. 

The memory in an IPU-based machine is made up of **In-Processor-Memory** (SRAM) and **Streaming Memory**.
![Fig2: IPU memory architecture](images/mem_arch.png )
**Figure 2:** *IPU memory architecture*

The In-Processor-Memory is the IPU's local **SRAM** and is split between tiles. Each tile only has direct access to code and data living in its local memory.

The Streaming Memory is made up of several **DDR memory chips** and is not directly accessible to the tiles. Reading data from and storing data to the Streaming Memory happens through PCIe links and Gateway links, and is slower than chip-to-chip communication via IPU-links.

Knowing that difference is important to understand how to efficiently apply the following techniques. 


## Replicated tensor sharding 

When we are using replication, some tensors may be the same across replicas.
In order to save memory, we can choose to slice these tensors so that each replica only has `num_elements/num_replicas` elements of the tensor, with extra padding if needed (see also the [PopART User Guide](https://docs.graphcore.ai/projects/popart-user-guide/en/3.0.0/tensor_locations.html?highlight=replicated%20tensor%20sharding#replicated-tensor-sharding-rts)).

This technique is called [Replicated tensor sharding](https://docs.graphcore.ai/projects/ipu-programmers-guide/en/3.0.0/algorithmic_techniques.html#replicated-tensor-sharding).

When the full tensors are required by the program you need to make use of `gather` collectives, hence sharding comes with extra communication cost. This communication will happen through IPU-links. 

However, the full tensor is no longer in memory for the entire lifetime of the program, so the memory gain can be significant.
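
The slicing and gathering described above can be sketched in plain Python. This is only a toy model of the idea, not the PopXL implementation: the helper names `shard_tensor` and `all_gather` are hypothetical, and the zero-padding scheme is an assumption made for illustration.

```python
import math


def shard_tensor(flat, num_replicas):
    """Split a flat list into equal shards, one per replica (hypothetical helper).

    The tail is zero-padded so that every replica holds
    ceil(num_elements / num_replicas) elements.
    """
    shard_len = math.ceil(len(flat) / num_replicas)
    padded = list(flat) + [0.0] * (shard_len * num_replicas - len(flat))
    return [padded[r * shard_len : (r + 1) * shard_len] for r in range(num_replicas)]


def all_gather(shards, num_elements):
    """Concatenate the shards of all replicas and drop the padding."""
    full = [value for shard in shards for value in shard]
    return full[:num_elements]


weights = [float(i) for i in range(10)]  # a tensor with 10 elements
shards = shard_tensor(weights, num_replicas=4)  # each replica keeps 3 elements
restored = all_gather(shards, len(weights))  # full tensor, rebuilt when needed

print([len(s) for s in shards])
print(restored == weights)
```

Between two gathers each replica only needs to keep its own shard alive, which is where the memory saving comes from.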

When we implement data parallelism with replication, a lot of variables are duplicated: each replica uses exactly the same weights and the same optimizer state. These are typically good candidates for RTS.

As we will see in the next section, RTS is commonly used together with remote variables because sharding helps amortise the cost of remote transfers.

![Fig3: Replicated Tensor Sharding](images/rts.png)
**Figure 3:** *Replicated Tensor Sharding*

### RTS in popxl.addons
Sharded tensors can be identified since they have a `meta_shape`, representing the shape of the original tensor. Hence, you can know if a tensor is sharded by checking if `t.meta_shape` is not `None`.
The `shape` of a sharded tensor is always flattened, `(num_elements/num_replicas, )`.

When you create a layer with the ```addons.Module``` class, you can specify sharded inputs with the
`module.add_replica_sharded_variable_input` method. This works just like `module.add_variable_input` but expects sharded variables. Even if the input is sharded, the full tensor will be initialised.

An analogous function `add_replica_sharded_variable_input` is available in `addons.variable_factory` if you are using `VariableFactories` outside the module context.

To shard a tensor either `remote_replica_sharded_variable` or `ops.collectives.replicated_reduce_scatter(..., configure_output_for_replicated_tensor_sharding=True)` can be used. We will see both examples in this tutorial.

Another useful function when dealing with sharded tensors is ```addons.rts.replica_sharded_spec(t: popxl.Tensor, threshold: int)```. If the tensor has more elements than threshold, it returns the appropriate `TensorSpec` for the shard.
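
To make the relation between `shape`, `meta_shape` and the threshold concrete, here is a small sketch in plain Python. It is not the library code: `sketch_sharded_shape` is a hypothetical helper, and the sharding condition (`nelms >= threshold` and divisibility by the replication factor) is an assumption that mirrors the check used in the training code later in this tutorial.

```python
from math import prod


def sketch_sharded_shape(full_shape, threshold, replication_factor):
    """Return (shape, meta_shape) as a shard of the tensor would report them.

    Hypothetical helper: tensors at or above the threshold are flattened and
    split evenly between replicas; smaller tensors are left untouched.
    """
    nelms = prod(full_shape)
    if nelms >= threshold and nelms % replication_factor == 0:
        return (nelms // replication_factor,), tuple(full_shape)
    return tuple(full_shape), None  # not sharded: no meta_shape


# Shapes taken from the MNIST network used later in this tutorial.
print(sketch_sharded_shape((784, 512), threshold=512, replication_factor=4))
print(sketch_sharded_shape((10,), threshold=512, replication_factor=4))
```

The first tensor is large enough to be sharded, so its shard is flat and remembers the original shape; the second one stays whole.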

## Remote Variables
When the IPU memory is insufficient the Streaming Memory can be used to store data remotely.
The ideal use case for the streaming memory is data which does not require frequent access, so that the communication cost can be amortised. 
Moreover, we can **shard** remote variables to make the transfer faster: in a way, sharding a remote variable is equivalent to performing part of the transfer over the IPU-Links instead of through remote transfer, which increases the effective bandwidth.
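
A rough back-of-the-envelope model shows why sharding can help. The sketch below is an assumption-laden illustration, not a measurement: the function names are hypothetical, the bandwidths are in arbitrary units, and the model ignores latency and any overlap between transfers.

```python
def load_time_unsharded(num_bytes, remote_bw):
    """Every replica pulls the full tensor from Streaming Memory."""
    return num_bytes / remote_bw


def load_time_sharded(num_bytes, num_replicas, remote_bw, link_bw):
    """Each replica pulls one shard remotely, then gathers the rest over IPU-Links."""
    shard_bytes = num_bytes / num_replicas
    remote_part = shard_bytes / remote_bw
    gather_part = (num_bytes - shard_bytes) / link_bw
    return remote_part + gather_part


# Illustrative values only: the links are assumed to be 10x faster than remote access.
unsharded = load_time_unsharded(400, remote_bw=10)
sharded = load_time_sharded(400, num_replicas=4, remote_bw=10, link_bw=100)
print(unsharded, sharded)
```

Under these assumptions most of the data travels over the faster links, so the sharded load is shorter even though it needs an extra gather.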

To store data in the Streaming Memory we make use of **remote buffers** and **remote variables** (see also [remote variables in popxl](https://docs.graphcore.ai/projects/popxl/en/3.0.0/api.html#remote-buffers)).

A remote buffer represents a data array in the Streaming Memory, and a remote variable is a tensor whose data is located at a certain position in a remote buffer (hence, in the Streaming Memory). 

A remote variable is always linked to a specific **entry** in a buffer, which tells where the variable starts.

The image below shows two different buffers, both with 2 entries but created with different tensor shapes, the first with a flat count of 4 elements and the second with 6. 
![Fig4: Remote buffers, entries and remote variables](images/remote_buffers.jpg)
**Figure 4:** *Remote buffers, entries and remote variables*

When you want to actually use a remote variable in your program, you need to load the data from the buffer to the IPU.

In `popxl`, remote buffers are created via ```remote_buffer(tensor_shape, tensor_dtype, entries)```.
The remote buffer can be thought of as an array of the given `tensor_dtype` with `tensor_total_elements * entries` slots.

Each entry corresponds to a different variable: indeed, when you create a remote variable you have to specify the buffer and the entry number: ```popxl.remote_variable(data, buffer, entry_number)```.

Likewise, when you load or store a variable from a buffer you have to specify the entry number, because this tells us where the data for the variable is located:
```python
loaded_x = ops.remote_load(buffer, entry_number)
ops.remote_store(buffer, entry_number, loaded_x)
```
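
The buffer layout described above can be mimicked with a flat Python list. The class below is a toy model written for this explanation (`ToyRemoteBuffer` is a hypothetical name, not a PopXL class); it only shows how an entry number maps to a span of `tensor_total_elements` slots.

```python
from math import prod


class ToyRemoteBuffer:
    """Flat array with `entries` slots groups, each holding one tensor."""

    def __init__(self, tensor_shape, entries):
        self.tensor_shape = tuple(tensor_shape)
        self.entries = entries
        self.elements_per_entry = prod(self.tensor_shape)
        self.slots = [0.0] * (self.elements_per_entry * entries)

    def _span(self, entry):
        # An entry number selects where the data of one variable starts and ends.
        if not 0 <= entry < self.entries:
            raise IndexError(f"entry {entry} is out of range")
        start = entry * self.elements_per_entry
        return start, start + self.elements_per_entry

    def store(self, entry, values):
        start, stop = self._span(entry)
        if len(values) != self.elements_per_entry:
            raise ValueError("values do not match the tensor size of the buffer")
        self.slots[start:stop] = values

    def load(self, entry):
        start, stop = self._span(entry)
        return list(self.slots[start:stop])


buffer = ToyRemoteBuffer(tensor_shape=(2, 2), entries=2)  # 4 elements per entry
buffer.store(1, [1.0, 2.0, 3.0, 4.0])
print(len(buffer.slots), buffer.load(0), buffer.load(1))
```

Storing to entry 1 leaves entry 0 untouched, which is why two variables can share a buffer as long as they use different entries.
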
### Remote variables in popxl.addons
In `popxl.addons` you typically manage `NamedTensors`. 

You can create `NamedRemoteBuffers` for a set of `NamedTensors` or `NamedVariableFactories` with 

```python
addons.named_buffers(named_tensors, entries, sharded_threshold)
addons.named_variable_buffers(named_factories, entries, sharded_threshold)
```
respectively.
These functions create a buffer for each `VariableFactory` or tensor, with the specified number of entries.

If the shape is such that `nelms >= sharded_threshold`, a replica sharded RemoteBuffer will be created instead.
`entries > 1` can be useful if you have multiple instances of the same layer. You can create buffers for all the layer variables by specifying `entries = layers_copies`, and access the variables of the different copies by changing the entry number.

Once you have the `NamedRemoteBuffers` you can create a graph to load variables with  
```python
load_graph, names = load_remote_graph(buffers: NamedRemoteBuffers, entries: int = 1)
```
This function returns a `GraphWithNamedArgs` and a list of names.
The resulting graph has no named inputs and it requires the `entry_index` as input: `load_graph.call(0)` returns the remote variable stored at index 0 in the remote buffer.
The `entries` argument can be provided to resize each buffer as needed, in case you want to enlarge your buffer.

Likewise, you can create a graph to store variables into the buffers after you have updated them with
```python
store_graph = store_remote_graph(buffers: NamedRemoteBuffers, entries: int = 1)
```
The graph has a named input for each buffer provided, with the tensor shape of the buffer.  It needs to be bound before calling it.
```python
store_graph.bind(w0).call(0)
```

In the example below we illustrate these concepts.

In [None]:
import sys

In [None]:
import argparse
from functools import partial
from typing import Mapping, Optional
import torch
import numpy as np
from time import time
from dataclasses import dataclass
import popxl
import popxl_addons as addons
import popxl.ops as ops
from typing import Union, Dict
from popxl_addons.graph import GraphWithNamedArgs
from popxl_addons.named_tensors import NamedTensors
from popxl_addons.variable_factory import NamedVariableFactories
from popxl_addons.rts import replica_sharded_spec, reduce_replica_sharded_graph

In [None]:
class Add(addons.Module):
    def build(self, x: popxl.Tensor):
        w = self.add_variable_input("weight", partial(np.ones, x.shape), x.dtype)
        x = popxl.ops.add(w, x)
        return x


ir = popxl.Ir()
with ir.main_graph, popxl.in_sequence(True):
    # ----- Create a single add graph -----
    data = np.ones((5,), dtype=np.float32)
    x = popxl.constant(data)
    facts, graph = Add().create_graph(x.spec)

    # ----- Create NamedRemoteBuffers -----
    # We use two entries since we want to have two copies of the layer,
    # referencing different remote variables
    entries = 2
    buffers = addons.named_variable_buffers(
        facts, entries=entries, sharded_threshold=10
    )  # no shards

    # ----- Create load and store graphs -----
    load_graph, names = addons.load_remote_graph(buffers, entries=entries)
    store_graph = addons.store_remote_graph(buffers, entries=entries)

    print(load_graph.print_schedule())
    print(store_graph.print_schedule())

    # ----- Initialise remote variables, both entries -----
    remote_vars_0 = facts.init_remote(buffers, 0)
    remote_vars_1 = facts.init_remote(buffers, 1)

    # ----- Load remote variables  -----
    (w_loaded_0,) = load_graph.call(0)
    (w_loaded_1,) = load_graph.call(1)

    # out streams to get the value back
    w0_initial_d2h = addons.host_store(
        w_loaded_0
    )  # expected output (all elements equal): 1
    w1_initial_d2h = addons.host_store(
        w_loaded_1
    )  # expected output (all elements equal): 1

    # ----- Bind and call  -----
    w_named_tensors_0 = NamedTensors.pack(names, (w_loaded_0,))
    w_named_tensors_1 = NamedTensors.pack(names, (w_loaded_1,))

    bound_add_0 = graph.bind(w_named_tensors_0)  # bind to first set of variables
    bound_add_1 = graph.bind(w_named_tensors_1)  # bind to second set of variables

    (x0,) = bound_add_0.call(x)
    (x1,) = bound_add_1.call(x)

    # out stream to get the value back
    x0_d2h = addons.host_store(x0)  # expected output (all elements equal): 1+1 = 2
    x1_d2h = addons.host_store(x1)  # expected output (all elements equal): 1+1 = 2

    # ----- Modify weights  -----
    update_val = np.full((5,), 1.0, dtype=np.float32)

    updater0 = popxl.constant(update_val)  # now first layer has w = 2
    updater1 = popxl.constant(2 * update_val)  # now second layer has w = 3

    w_loaded_0 += updater0
    w_loaded_1 += updater1

    # store the new weights into buffer
    store_graph.bind(w_named_tensors_0).call(0)
    store_graph.bind(w_named_tensors_1).call(1)

    # load again and check that the buffer contains the updated value
    (w_loaded_0,) = load_graph.call(0)
    (x0_new,) = bound_add_0.call(x)
    (w_loaded_1,) = load_graph.call(1)
    (x1_new,) = bound_add_1.call(x)

    # out stream to get the value back
    w0_after_d2h = addons.host_store(
        w_loaded_0
    )  # expected output (all elements equal): 2
    x0_new_d2h = addons.host_store(
        x0_new
    )  # expected output (all elements equal): 1+2 = 3
    w1_after_d2h = addons.host_store(
        w_loaded_1
    )  # expected output (all elements equal): 3
    x1_new_d2h = addons.host_store(
        x1_new
    )  # expected output (all elements equal): 1+3 = 4

with popxl.Session(ir, "ipu_hw") as session:
    outputs = session.run()

print("initial values")
print("\t w_0: ", outputs[w0_initial_d2h])
print("\t output0: ", outputs[x0_d2h])
print("\t w_1: ", outputs[w1_initial_d2h])
print("\t output1: ", outputs[x1_d2h])

print("updated values")
print("\t w_0: ", outputs[w0_after_d2h])
print("\t output0: ", outputs[x0_new_d2h])
print("\t w_1: ", outputs[w1_after_d2h])
print("\t output1: ", outputs[x1_new_d2h])

## MNIST with off-chip optimizer state

The starting point for this tutorial will be the [Data parallelism tutorial](../3_data_parallelism).

We are going to modify the optimizer so that its state is stored remotely in the Streaming Memory:
- The Adam optimizer now uses `add_replica_sharded_variable_input` when the variable is sharded.
- We introduce a `remote_step` function that creates the remote buffers and all the graphs (load/store + optimizer graph), loads the state, calls the optimizer and stores the new state to the buffer.

Since the optimizer state is sharded, the optimizer inputs must be sharded too: in ```def build(self, var: popxl.TensorByRef, grad: popxl.Tensor, ...)```, both `var` and `grad` must be sharded.
This implies a few things:

- When the optimizer is created, we need to specify sharded specs: 
```python
optimizer.create_graph(
                          replica_sharded_spec(var, threshold=sharded_threshold),
                          replica_sharded_spec(grad, threshold=sharded_threshold), 
                          ...
                          )
```
- When the optimizer is called, `var` and `grad` must be sharded tensors. To achieve this, we need to:
    - Reduce gradients across replicas with `reduce_replica_sharded_graph`. This becomes a `replicated_reduce_scatter(..., configure_output_for_replicated_tensor_sharding=True)` for tensors exceeding the provided threshold
    - Shard the variables using `ops.collectives.replica_sharded_slice`. This just means splitting the tensor. 
    ```python
    for name, v in variables.named_tensors.items():
        ir = popxl.gcg().ir
        if v.nelms >= opts.sharded_threshold and v.nelms % ir.replication_factor == 0:
            shard = ops.collectives.replica_sharded_slice(v)
    ```
- The optimizer is called with the shards as inputs: `opt.call(sharded_var, sharded_grad)`. After the call, `sharded_var` is updated with the new value since it is a `TensorByRef` input. However, we need to collect all the updated shards with `ops.collectives.replicated_all_gather` and copy the new value into the original full tensor with `ops.var_updates.copy_var_update_`.
```python
for idx, v in enumerate(sharded_vars.tensors):
    if v.meta_shape:
        # we need to gather the updated shards
        v_full = ops.collectives.replicated_all_gather(v)
        # and copy the updated value in the original full tensor
        ops.var_updates.copy_var_update_(variables.tensors[idx], v_full)
```
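
The whole sequence (reduce-scatter the gradients, slice the variables, update the shards, gather the result) can be checked with a toy simulation in plain Python. This is a sketch under simplifying assumptions: the helpers are hypothetical stand-ins for the collectives, a plain SGD update replaces Adam to keep the arithmetic short, and the replicas are modelled as list entries.

```python
def reduce_scatter_mean(per_replica_grads):
    """Average gradients over replicas; replica r keeps only slice r of the mean."""
    n = len(per_replica_grads)
    size = len(per_replica_grads[0])
    shard_len = size // n
    mean = [sum(g[i] for g in per_replica_grads) / n for i in range(size)]
    return [mean[r * shard_len : (r + 1) * shard_len] for r in range(n)]


def sharded_slice(full, num_replicas):
    """Each replica takes its own contiguous slice of an identical full tensor."""
    shard_len = len(full) // num_replicas
    return [full[r * shard_len : (r + 1) * shard_len] for r in range(num_replicas)]


def all_gather(shards):
    """Rebuild the full tensor from the shards of all replicas."""
    return [value for shard in shards for value in shard]


lr = 0.5
weights = [1.0, 2.0, 3.0, 4.0]  # identical on both replicas
grads = [[2.0, 4.0, 6.0, 8.0], [0.0, 0.0, 2.0, 0.0]]  # one gradient per replica

grad_shards = reduce_scatter_mean(grads)
var_shards = sharded_slice(weights, num_replicas=2)
# Each replica updates only its own shard (SGD used here instead of Adam).
updated_shards = [
    [w - lr * g for w, g in zip(ws, gs)] for ws, gs in zip(var_shards, grad_shards)
]
new_weights = all_gather(updated_shards)

# Reference: the same update computed on the full tensors without sharding.
mean_grad = [sum(g[i] for g in grads) / len(grads) for i in range(len(weights))]
reference = [w - lr * g for w, g in zip(weights, mean_grad)]
print(new_weights == reference)
```

The sharded path and the unsharded reference give the same weights, while each replica only ever holds half of the gradient and of the optimizer input.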

### Memory and performance
You can play around with the `sharded_threshold` option to see how the throughput, execution time and memory usage change.
With `data_parallelism = 4`, `gradient_accumulation = 8`, `micro_batch_size = 8`,
we evaluated the throughput for `100` iterations in four cases: no sharding, sharding only the optimizer state related to fc1.weight, sharding all optimizer variables related to the weights but not those related to biases, and sharding all the state.
The best throughput and execution time are reached with all tensors sharded, in agreement with the observation that IPU-links are faster than remote transfers, so sharding pays off when using remote variables. Compared with the on-chip optimizer (you can run the [Data parallelism tutorial](../3_data_parallelism) with the same parameters), throughput and speed are lower, as expected.

Case | Threshold | Throughput (samples/s) | Execution time (s)
 --- | --------- | ---------------------- | ----------------- 
no sharding | 1e10 (inf) | 43418.9 | 0.0059
only state related to fc1.weight | 262145 | 50909.2 | 0.0050
all state related to weights | 1024 | 67403.4 | 0.0038
all state | 512 | 72570.0 | 0.0035
**on-chip optimizer** | | 107418.9 | 0.0024
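
To put the table in relative terms, the short script below computes the speed-up of each case over the unsharded run and the fraction of the on-chip throughput it reaches. It only re-uses the throughput values from the table; the dictionary keys are shortened labels chosen for this sketch.

```python
# Throughput values (samples/s) copied from the table above.
throughput = {
    "no sharding": 43418.9,
    "fc1.weight state only": 50909.2,
    "all weight state": 67403.4,
    "all state": 72570.0,
    "on-chip optimizer": 107418.9,
}

baseline = throughput["no sharding"]
on_chip = throughput["on-chip optimizer"]

speedup = {case: value / baseline for case, value in throughput.items()}
fraction_of_on_chip = {case: value / on_chip for case, value in throughput.items()}

for case in throughput:
    print(
        f"{case:24s} x{speedup[case]:.2f} vs no sharding, "
        f"{100 * fraction_of_on_chip[case]:.0f}% of on-chip"
    )
```

Sharding all the state is roughly 1.67 times faster than not sharding, and reaches about 68% of the on-chip throughput.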

To analyse the memory usage we can use the [PopVision Graph Analyser](https://docs.graphcore.ai/projects/graph-analyser-userguide/en/3.11.2/user-guide.html?highlight=Rx%2FTx#execution-view-options).

Sharding or not sharding does not make a big difference in our case:

![](images/shard_non_shard_nal.png)
![](images/shard_non_shard_al.png)
**Figure 5:** *Target: off chip optimizer state, non sharded, Source: off chip optimizer state, all sharded*

Hence, we compare the sharded remote optimizer program, which is the one with higher throughput, with the on-chip program.

![](images/nal_remote_onChip.png)
![](images/al_remote_onChip.png)
**Figure 6:** *Target: on chip optimizer state, Source: off chip optimizer state, sharded*

The off-chip version of the program requires more always-live memory, which reflects a bigger code size. However, the not-always-live memory usage, which is where variables are found, is lower.
A lot more information is available in the profile, and you can compare specific variables and operations in the two cases.

To generate your own profiles, comment out the code related to validation in the `mnist.py` script and run it with
```
POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true"}' python3 mnist.py
```

Try generating profiles for different configurations to see how memory and throughput scale with the model size, the number of replicas, and so on.

In [None]:
import torch
import torchvision
from tqdm import tqdm

In [None]:
np.random.seed(42)

In [None]:
def get_mnist_data(test_batch_size: int, batch_size: int):
    training_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(
                        (0.1307,), (0.3081,)
                    ),  # mean and std computed on the training set.
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    validation_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=test_batch_size,
        shuffle=True,
        drop_last=True,
    )
    return training_data, validation_data


def accuracy(predictions: np.ndarray, labels: torch.Tensor):
    ind = np.argmax(predictions, axis=-1).flatten()
    labels = labels.detach().numpy().flatten()
    return np.mean(ind == labels) * 100.0


class Linear(addons.Module):
    def __init__(self, out_features: int, bias: bool = True):
        super().__init__()
        self.out_features = out_features
        self.bias = bias

    def build(self, x: popxl.Tensor) -> popxl.Tensor:
        # add a state variable to the module
        w = self.add_variable_input(
            "weight",
            partial(np.random.normal, 0, 0.02, (x.shape[-1], self.out_features)),
            x.dtype,
        )
        y = x @ w
        if self.bias:
            # add a state variable to the module
            b = self.add_variable_input("bias", partial(np.zeros, y.shape[-1]), x.dtype)
            y = y + b
        return y


class Net(addons.Module):
    def __init__(self, cache: Optional[addons.GraphCache] = None):
        super().__init__(cache=cache)
        self.fc1 = Linear(512)
        self.fc2 = Linear(512)
        self.fc3 = Linear(512)
        self.fc4 = Linear(10)

    def build(self, x: popxl.Tensor):
        x = x.reshape((-1, 28 * 28))
        x = ops.gelu(self.fc1(x))
        x = ops.gelu(self.fc2(x))
        x = ops.gelu(self.fc3(x))
        x = self.fc4(x)
        return x


def evaluate_throughput(session, samples_per_step, epochs: int = 5):
    inputs = {
        stream: np.ones(
            session._full_input_shape(stream.shape), stream.dtype.as_numpy()
        )
        for stream in session.expected_inputs()
    }

    durations = []
    with session:
        for i in range(epochs):
            start = time()
            session.run(inputs)
            dur = time() - start
            durations.append(dur)

    duration = np.mean(durations)

    result_str = (
        f"Mean duration: {duration} s "
        f"Throughput: {samples_per_step/duration:6.1f} samples/s "
    )
    print(result_str)

### Configs

In [None]:
@dataclass
class Train:
    micro_batch_size: int = 8
    lr: Union[float, popxl.Tensor] = 1e-3
    epochs: int = 1
    gradient_accumulation: int = 1
    data_parallel: int = 1
    device = "ipu_hw"
    sharded_threshold: int = 512


@dataclass
class Test:
    micro_batch_size: int = 80
    device = "ipu_hw"


class Options:
    train = Train()
    test = Test()


opts = Options()
opts.train.micro_batch_size = 8
opts.train.lr = 1e-3
opts.train.epochs = 1
opts.train.gradient_accumulation = 4
opts.train.data_parallel = 4
opts.train.sharded_threshold = 512

opts.test.micro_batch_size = 80

### Train

In [None]:
"""
Adam optimizer.
Defines adam update step for a single variable
"""


class Adam(addons.Module):
    # we need to specify in_sequence because a lot of operations are in place and their order
    # shouldn't be rearranged
    @popxl.in_sequence()
    def build(
        self,
        var: popxl.TensorByRef,
        grad: popxl.Tensor,
        *,
        lr: Union[float, popxl.Tensor],
        beta1: Union[float, popxl.Tensor] = 0.9,
        beta2: Union[float, popxl.Tensor] = 0.999,
        eps: Union[float, popxl.Tensor] = 1e-5,
        weight_decay: Union[float, popxl.Tensor] = 0.0,
        first_order_dtype: popxl.dtype = popxl.float16,
        bias_correction: bool = True
    ):

        # gradient estimators for the variable var - same shape as the variable

        # Sharded inputs must be added with add_replica_sharded_variable_input
        if var.meta_shape:
            first_order = self.add_replica_sharded_variable_input(
                "first_order",
                partial(np.zeros, var.meta_shape),
                first_order_dtype,
                by_ref=True,
            )
            second_order = self.add_replica_sharded_variable_input(
                "second_order",
                partial(np.zeros, var.meta_shape),
                popxl.float32,
                by_ref=True,
            )

        else:
            first_order = self.add_variable_input(
                "first_order",
                partial(np.zeros, var.shape),
                first_order_dtype,
                by_ref=True,
            )
            second_order = self.add_variable_input(
                "second_order", partial(np.zeros, var.shape), popxl.float32, by_ref=True
            )

        ops.var_updates.accumulate_moving_average_(first_order, grad, f=beta1)
        ops.var_updates.accumulate_moving_average_square_(second_order, grad, f=beta2)

        # adam is a biased estimator: provide the step to correct bias
        step = None
        if bias_correction:
            step = self.add_variable_input(
                "step", partial(np.zeros, ()), popxl.float32, by_ref=True
            )

        # calculate the weight increment with adam heuristic
        updater = ops.var_updates.adam_updater(
            first_order,
            second_order,
            weight=var,
            weight_decay=weight_decay,
            time_step=step,
            beta1=beta1,
            beta2=beta2,
            epsilon=eps,
        )

        # in place weight update: w += (-lr)*dw
        ops.scaled_add_(var, updater, b=-lr)


"""
Optimizer step  with off-chip state. Needs to be called in the main context.
A step consists in:
    - load state from buffer
    - call optimizer
    - store new state into buffer
"""


def remote_step(var: popxl.Tensor, grad: popxl.Tensor, optimizer: addons.Module, opts):
    facts, opt_graph = optimizer.create_graph(
        replica_sharded_spec(var, threshold=opts.train.sharded_threshold),
        replica_sharded_spec(grad, threshold=opts.train.sharded_threshold),
        lr=opts.train.lr,
    )

    # keep the state of the optimizer in remote buffers
    buffers = addons.named_variable_buffers(
        facts, entries=1, sharded_threshold=opts.train.sharded_threshold
    )
    # create graph for loading the state
    opt_load, names = addons.load_remote_graph(buffers, entries=1)
    # create graph for storing the state after it is updated
    opt_store = addons.store_remote_graph(buffers, entries=1)

    # init the buffer
    state = facts.init_remote(buffers)
    # load remote variables: remote buffer -> device memory
    loaded_state = opt_load.call(0)
    state = NamedTensors.from_dict(dict(zip(names, loaded_state)))

    # bind optimizer to loaded vars and call optimizer
    opt_graph.bind(state).call(var, grad)

    # bind store graph to loaded vars and store remote variables: device memory -> remote buffer
    opt_store.bind(state).call(0)


"""
Update all variables creating per-variable optimizers. 
"""


def optimizer_step(
    variables: NamedTensors,
    grads: Dict[popxl.Tensor, popxl.Tensor],
    optimizer: addons.Module,
    accum_counter: popxl.Tensor,
    opts,
):

    for name, var in variables.named_tensors.items():
        remote_step(var, grads[var], optimizer, opts)

    if accum_counter is not None:
        # Reset accumulators.
        ops.var_updates.accumulator_scale_(accum_counter, 0.0)

In [None]:
def train_program(opts):
    ir = popxl.Ir()
    ir.replication_factor = opts.train.data_parallel

    with ir.main_graph:
        # ----- Streams  -----

        img_spec = popxl.TensorSpec(
            (opts.train.micro_batch_size, 28, 28), popxl.float32
        )

        img_stream = popxl.h2d_stream(img_spec.shape, popxl.float32, "image")
        label_stream = popxl.h2d_stream(
            (opts.train.micro_batch_size,), popxl.int32, "labels"
        )
        loss_stream = popxl.d2h_stream((), popxl.float32, "loss")

        # ----- Create graphs  -----

        facts, fwd_graph = Net().create_graph(img_spec)
        variables = facts.init()
        bound_fwd = fwd_graph.bind(variables)

        counter = None
        required_grads = fwd_graph.args.tensors

        if opts.train.gradient_accumulation > 1:
            bwd_facts, bwd_graph = addons.autodiff_with_accumulation(
                fwd_graph, required_grads
            )
            accumulated_grads = bwd_facts.init()
            counter = accumulated_grads.mean_accum_counter
            bound_bwd = bwd_graph.bind(accumulated_grads)
        else:
            bwd_graph = addons.autodiff(fwd_graph, grads_required=required_grads)

        # ----- Gradient accumulation loop  -----
        with popxl.in_sequence(True):
            for ga_step in range(opts.train.gradient_accumulation):
                # ----- Load data  -----

                img_t = ops.host_load(img_stream)
                labels = ops.host_load(label_stream, "labels")

                # ----- Fwd  -----

                # full weights are used, we are not sharding the network weights
                fwd_info = bound_fwd.call_with_info(img_t)
                x = fwd_info.outputs[0]

                # ----- Loss  -----

                loss, dx = addons.ops.cross_entropy_with_grad(x, labels)
                ops.host_store(loss_stream, loss)

                # ----- Bwd  -----

                activations = bwd_graph.grad_graph_info.inputs_dict(fwd_info)
                if opts.train.gradient_accumulation > 1:
                    # full weights, we are not sharding the backward accumulators
                    bound_bwd.call(dx, args=activations)
                    grads = accumulated_grads.tensors[:-1]  # exclude the counter

                else:
                    grads = bwd_graph.call(dx, args=activations)

            if opts.train.data_parallel > 1:
                # ----- Reduce and shard gradients  -----
                keys = [
                    n
                    for n, g in accumulated_grads.named_tensors.items()
                    if n != "mean_accum_counter"
                ]
                grads = NamedTensors.pack(keys, grads)
                # tensors whose elements exceed threshold will be reduce_scattered -> sharded
                grad_reduce, names = reduce_replica_sharded_graph(
                    grads, "mean", threshold=opts.train.sharded_threshold
                )
                grads = grad_reduce.bind(grads).call()

                # ----- Shard forward variables  -----
                sharded_vars = []
                names = []
                for name, v in variables.named_tensors.items():
                    ir = popxl.gcg().ir
                    if (
                        v.nelms >= opts.train.sharded_threshold
                        and v.nelms % ir.replication_factor == 0
                    ):
                        shard = ops.collectives.replica_sharded_slice(v)
                    else:
                        shard = v

                    sharded_vars.append(shard)
                    names.append(name)

                sharded_vars = NamedTensors.pack(names, sharded_vars)
            else:
                sharded_vars = variables

            # ----- Optimizer  -----

            grads_dict = dict(zip(sharded_vars.tensors, grads))
            optimizer = Adam(cache=True)
            # the optimizer step will update the shards in place (sharded vars are TensorByRef inputs)
            optimizer_step(sharded_vars, grads_dict, optimizer, counter, opts)

            # gather shards and copy into full tensor
            if opts.train.data_parallel > 1:
                for idx, v in enumerate(sharded_vars.tensors):
                    if v.meta_shape:
                        # we need to gather the updated shards
                        v_full = ops.collectives.replicated_all_gather(v)
                        # and copy the updated value in the original full tensor
                        ops.var_updates.copy_var_update_(
                            variables.tensors[idx], v_full
                        )

    ir.num_host_transfers = opts.train.gradient_accumulation
    return (
        popxl.Session(ir, "ipu_hw"),
        [img_stream, label_stream],
        variables,
        loss_stream,
    )

In [None]:
global_batch_size = (
    opts.train.micro_batch_size
    * opts.train.gradient_accumulation
    * opts.train.data_parallel
)
training_data, test_data = get_mnist_data(opts.test.micro_batch_size, global_batch_size)
train_session, train_input_streams, train_variables, loss_stream = train_program(opts)

In [None]:
with train_session:
    nr_batches = len(training_data)
    for epoch in range(1, opts.train.epochs + 1):
        print("Epoch {0}/{1}".format(epoch, opts.train.epochs))
        bar = tqdm(training_data, total=nr_batches)
        for data, labels in bar:
            # reshape data accounting for replication and num hosts transfers
            data = data.reshape(
                train_session.ir.num_host_transfers,
                train_session.ir.replication_factor,
                opts.train.micro_batch_size,
                28,
                28,
            ).squeeze()
            labels = labels.reshape(
                train_session.ir.num_host_transfers,
                train_session.ir.replication_factor,
                opts.train.micro_batch_size,
            ).squeeze()

            inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
                zip(train_input_streams, [data.squeeze().float(), labels.int()])
            )
            loss = train_session.run(inputs)
            losses_np = loss[
                loss_stream
            ]  # shape(ir.num_host_transfers, ir.replication_factor, )
            avg_loss = np.mean(losses_np)
            bar.set_description("Loss:{:0.4f}".format(avg_loss))

In [None]:
# get weights data : dictionary { train_session variables : tensor data (numpy) }
train_vars_to_data = train_session.get_tensors_data(train_variables.tensors)

### Throughput and Testing

In [None]:
def test_program(test_batch_size, device):
    ir = popxl.Ir(replication=1)
    with ir.main_graph:
        # Inputs
        in_stream = popxl.h2d_stream((test_batch_size, 28, 28), popxl.float32, "image")
        in_t = ops.host_load(in_stream)

        # Create graphs
        facts, graph = Net().create_graph(in_t)

        # Initialise variables
        variables = facts.init()

        # Forward
        (outputs,) = graph.bind(variables).call(in_t)
        out_stream = popxl.d2h_stream(outputs.shape, outputs.dtype, "outputs")
        ops.host_store(out_stream, outputs)

    ir.num_host_transfers = 1
    return popxl.Session(ir, device), [in_stream], variables, out_stream

In [None]:
# Create test program and test session
test_session, test_input_streams, test_variables, out_stream = test_program(
    opts.test.micro_batch_size, opts.test.device
)

# dictionary { train_session variables : test_session variables }
train_vars_to_test_vars = train_variables.to_mapping(test_variables)

# Create a dictionary { test_session variables : tensor data (numpy) }
# We want to copy the values before evaluating throughput on synthetic data, otherwise weights are changed
test_vars_to_data = {
    test_var: train_vars_to_data[train_var].copy()
    for train_var, test_var in train_vars_to_test_vars.items()
}

# Copy trained weights to the program, with a single host to device transfer at the end
test_session.write_variables_data(test_vars_to_data)

# evaluate the ratio samples per step / time for train session
print("train session")
evaluate_throughput(train_session, global_batch_size)

In [None]:
nr_batches = len(test_data)
sum_acc = 0.0
with test_session:
    for data, labels in tqdm(test_data, total=nr_batches):
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(test_input_streams, [data.squeeze().float(), labels.int()])
        )
        output = test_session.run(inputs)
        sum_acc += accuracy(output[out_stream], labels)
print("Accuracy on test set: {:0.2f}%".format(sum_acc / len(test_data)))

In [None]:
samples_per_step = (
    opts.test.micro_batch_size
)  # no data parallelism or gradient accumulation for inference in this program
# evaluate the ratio samples per step / time for test session
print("test session")
evaluate_throughput(test_session, samples_per_step)