Copyright (c) 2022 Graphcore Ltd. All rights reserved.

# Data parallelism

Data parallelism techniques are essential to achieve better throughput during execution.
There are two main kinds of data parallelism:

- **Intra-device data parallelism**. This is controlled by the **micro batch size**. When we use a ```micro_batch_size``` greater than one, the IPU executes the same program in parallel on the different samples of the micro batch.
- **Across-device data parallelism**. We can also choose to run the same program on different devices, feeding each copy with different data.
An efficient way to implement this is to use [**replication**](https://docs.graphcore.ai/projects/ipu-programmers-guide/en/3.2.0/algorithmic_techniques.html?highlight=micro%20batch#replication). Replication executes the same program across multiple devices and provides collective operations for the replicas to communicate. The ```session.ir.replication_factor``` controls the total number of replicas, regardless of what they are used for (we will see other uses of replication). When replicas are used only to implement data parallelism, this factor is equal to your ```data_parallel``` factor.

Unless otherwise specified, the term data parallelism will always mean data parallelism across different replicas.

## Replication

Each program has an inner execution strategy, which can include multiple devices. 
For example, we can design a program that uses 2 IPUs for its execution.

It's important to understand that replication is always orthogonal to the inner structure of the program: if we choose to run a 2-IPU program with 4 replicas, we will use ``` 2 * 4 = 8 ``` IPUs, since each replica needs 2 IPUs for its execution.
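The device count above can be written as a one-line helper. This is only a sketch of the arithmetic; `total_ipus` is a hypothetical name used for illustration and is not part of PopXL.

```python
def total_ipus(ipus_per_replica: int, replication_factor: int) -> int:
    """IPUs needed when a program using `ipus_per_replica` devices is replicated."""
    if ipus_per_replica < 1 or replication_factor < 1:
        raise ValueError("both arguments must be positive")
    # Replication multiplies the footprint of the inner program.
    return ipus_per_replica * replication_factor


# The example from the text: a 2-IPU program run with 4 replicas.
print(total_ipus(ipus_per_replica=2, replication_factor=4))  # prints 8
```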

With data parallelism we can process more data simultaneously, but this comes with an additional **communication** cost.
This communication is achieved by means of **replicated collective operations**. Replicated collective operations perform a tensor operation across replicas, while standard **collective operations** perform tensor operations across the devices of a single replica in multi-IPU programs.

Available collective operations can be found in ```popxl.ops.collectives```. For example, a ```replicated_all_reduce``` op takes a tensor on each replica, calculates the sum (other reduction options are available) across the replicas and then creates a new tensor on each replica holding that sum. An ```all_reduce``` op will do the same, but across IPUs within a single replica.
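To make the semantics of a replicated all-reduce concrete, here is a minimal host-side simulation in plain Python. It does not call PopXL: `simulated_replicated_all_reduce` is a hypothetical helper, and each inner list stands in for the tensor held by one replica.

```python
from typing import List


def simulated_replicated_all_reduce(
    per_replica: List[List[float]], op: str = "add"
) -> List[List[float]]:
    """Reduce element-wise across replicas and give every replica a copy of the result."""
    num_replicas = len(per_replica)
    # zip(*...) groups the i-th element of every replica together.
    reduced = [sum(values) for values in zip(*per_replica)]
    if op == "mean":
        reduced = [value / num_replicas for value in reduced]
    elif op != "add":
        raise ValueError(f"unsupported reduction: {op}")
    # Every replica ends up holding the same reduced tensor.
    return [list(reduced) for _ in range(num_replicas)]


grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one gradient tensor per replica
print(simulated_replicated_all_reduce(grads))             # every replica: [9.0, 12.0]
print(simulated_replicated_all_reduce(grads, op="mean"))  # every replica: [3.0, 4.0]
```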

Without data parallelism, a training step has the form

```shell
repeat:
  # fwd and bwd can happen on multiple IPUs
  calculate forward pass of model on micro_batch to compute loss 
  calculate backward pass of model to compute weight gradients
  update weight using weight gradients
```

With data parallelism the training program has extra collective operations:

```shell
repeat:
  # fwd and bwd can happen on multiple IPUs
  calculate forward pass of model on micro_batch to compute loss
  calculate backward pass of model to compute weight gradients

  # collectives: extra communication step, extra cost
  obtain sum of gradients across all replicas

  update weight using weight gradients
```

Each model replica runs this program on a different micro batch, achieving data parallelism.
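The following sketch illustrates why the extra collective is enough to keep replicas consistent. It uses a toy one-parameter model (an assumption made for illustration, not the network used later) and checks that averaging the per-replica gradients gives the same value as computing the gradient on the combined batch, provided all micro batches have the same size.

```python
from typing import Sequence


def mean_gradient(w: float, xs: Sequence[float], ts: Sequence[float]) -> float:
    """Gradient of mean(0.5 * (w * x - t) ** 2) with respect to w."""
    return sum((w * x - t) * x for x, t in zip(xs, ts)) / len(xs)


def data_parallel_gradient(
    w: float, xs: Sequence[float], ts: Sequence[float], replicas: int
) -> float:
    """Give each replica an equal slice of the batch, then average the local gradients."""
    if len(xs) % replicas != 0:
        raise ValueError("batch must split evenly across replicas")
    size = len(xs) // replicas
    local = [
        mean_gradient(w, xs[r * size:(r + 1) * size], ts[r * size:(r + 1) * size])
        for r in range(replicas)
    ]
    # Stand-in for a replicated all-reduce with a mean reduction.
    return sum(local) / replicas


xs = [1.0, 2.0, 3.0, 4.0]
ts = [1.0, 2.0, 3.0, 4.0]
print(data_parallel_gradient(2.0, xs, ts, replicas=2))  # 7.5
print(mean_gradient(2.0, xs, ts))                       # 7.5, same as above
```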

![Figure 1: Replication during training ](images/data_parallel_training.png)
<figcaption> <b>Fig 1: </b> Replication during training 
 </figcaption>

Sometimes the communication cost can be significant, in which case we may want to amortize it with larger batches. [Gradient accumulation](https://docs.graphcore.ai/projects/ipu-programmers-guide/en/3.2.0/algorithmic_techniques.html?highlight=micro%20batch#gradient-accumulation) can help in this regard.

## Gradient Accumulation

Up to now we have been updating the weights of the network after each micro batch.
With [gradient accumulation](https://docs.graphcore.ai/projects/ipu-programmers-guide/en/3.2.0/algorithmic_techniques.html#gradient-accumulation), gradients are instead accumulated for ```N = gradient_accumulation``` micro batches before updating the weights. Accumulation may be a summation, a mean or a running mean.

Without gradient accumulation, a training step has the form
```shell
repeat: #train steps
    load a micro batch 
    # process a micro batch
    calculate forward pass of model on micro batch to get loss 
    calculate backward pass of model to get current weight gradients 
    # if replication
    (obtain sum of gradients across all replicas)
    
    update weight using weight gradients
```
Each model replica processes ```micro_batch_size ``` samples per weight update.

With gradient accumulation you instead have:

```shell
repeat: #train steps
  zero accumulated weight gradients
  repeat: # gradient accumulation
    load a micro batch 
    # process a micro batch
    calculate forward pass of model on micro batch to get loss 
    calculate backward pass of model to get current weight gradients
    # accumulation step
    add current weight gradients to accumulated weight gradients
  #if replication
  (obtain sum of gradients across all replicas)
  update weight using accumulated weight gradients
```
In this case, each model replica processes ```micro_batch_size * gradient_accumulation``` samples per weight update, but the samples are not loaded all at once: only one micro batch is loaded in each gradient accumulation step. This way, larger batches can still fit in device memory.
For a given number of processed samples, the compute time is the same with or without gradient accumulation: ```device_time = time_per_sample * micro_batch_size * gradient_accumulation```. It makes no difference whether we use several small micro batches with ```gradient_accumulation > 1``` or a single large micro batch with ```gradient_accumulation = 1```.
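The accumulation loop above can be sketched in plain Python. This is a toy illustration under stated assumptions: a one-parameter model, plain SGD, and accumulation by summation followed by a division by the number of micro batches. It shows that the weight update is the same whether the samples arrive as one large micro batch or as several small ones.

```python
from typing import Sequence


def mean_gradient(w: float, xs: Sequence[float], ts: Sequence[float]) -> float:
    """Gradient of mean(0.5 * (w * x - t) ** 2) with respect to w."""
    return sum((w * x - t) * x for x, t in zip(xs, ts)) / len(xs)


def train_step(
    w: float, xs: Sequence[float], ts: Sequence[float], micro_batch_size: int, lr: float
) -> float:
    """One weight update; the samples are consumed one micro batch at a time."""
    if len(xs) % micro_batch_size != 0:
        raise ValueError("samples must split evenly into micro batches")
    gradient_accumulation = len(xs) // micro_batch_size
    accumulated = 0.0  # zero accumulated weight gradients
    for step in range(gradient_accumulation):
        lo, hi = step * micro_batch_size, (step + 1) * micro_batch_size
        # Only this micro batch needs to be resident at this point.
        accumulated += mean_gradient(w, xs[lo:hi], ts[lo:hi])
    # Update the weight using the accumulated (here: averaged) gradient.
    return w - lr * accumulated / gradient_accumulation


xs = [1.0, 2.0, 3.0, 4.0]
ts = [1.0, 2.0, 3.0, 4.0]
print(train_step(2.0, xs, ts, micro_batch_size=4, lr=0.125))  # 1.0625
print(train_step(2.0, xs, ts, micro_batch_size=1, lr=0.125))  # 1.0625
```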

However, from a memory perspective we have an advantage: since we don't have to load the full batch on the device, we can use larger batches and get better throughput (although large batches are not always a good strategy from a training perspective).
For the same reason, gradient accumulation can be used with data parallelism to amortize the communication cost: the gradients are exchanged once per weight update, so the more samples each update processes, the smaller the share of time spent on communication.
Another use case for gradient accumulation is pipelining, which will be investigated in another tutorial.
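The amortization argument can be made concrete with a simple cost model. The sketch below assumes that each weight update pays a fixed communication time on top of the compute time given by the formula above. The timing values are made up for illustration and are not measurements.

```python
def step_time(
    micro_batch_size: int,
    gradient_accumulation: int,
    time_per_sample: float,
    comm_time: float,
) -> float:
    """Time for one weight update on a replica: compute plus one gradient exchange."""
    compute = time_per_sample * micro_batch_size * gradient_accumulation
    return compute + comm_time


def communication_fraction(
    micro_batch_size: int,
    gradient_accumulation: int,
    time_per_sample: float,
    comm_time: float,
) -> float:
    """Share of the weight-update time spent exchanging gradients."""
    total = step_time(micro_batch_size, gradient_accumulation, time_per_sample, comm_time)
    return comm_time / total


# Hypothetical timings: 0.1 ms of compute per sample, 2 ms per gradient exchange.
for ga in (1, 4, 16):
    frac = communication_fraction(8, ga, time_per_sample=1e-4, comm_time=2e-3)
    print(f"gradient_accumulation={ga:2d}: {frac:.0%} of the step is communication")
```

Under this model the communication share can only shrink as ```gradient_accumulation``` grows, because the fixed cost is spread over more samples.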

Taking into account both data parallelism and gradient accumulation, the total number of samples that contribute to an optimizer step is

```global_batch_size = micro_batch_size * gradient_accumulation * data_parallel```

## Batch Terminology
- ```micro_batch_size```: size of the micro batch. It determines the data parallelism on a single device and is the number of samples that contribute to a gradient accumulation step.
- ```gradient_accumulation```: number of micro batches processed by a single replica before updating the weights.
- ```data_parallel```: number of replicas used to implement data parallelism.
- ```global_batch_size```: total number of samples that contribute to a weight update, ```global_batch_size = micro_batch_size * gradient_accumulation * data_parallel```.
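As a worked example of these definitions, the snippet below computes the global batch size for the values used later in this notebook (micro batch size 8, gradient accumulation 4, data parallel 2). It also computes the resulting number of weight updates per epoch on the 60,000-image MNIST training set when incomplete batches are dropped. The helper names are illustrative only.

```python
def global_batch_size(
    micro_batch_size: int, gradient_accumulation: int, data_parallel: int
) -> int:
    """Total number of samples contributing to one weight update."""
    return micro_batch_size * gradient_accumulation * data_parallel


def weight_updates_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of full global batches in the dataset (the remainder is dropped)."""
    return num_samples // batch_size


batch = global_batch_size(micro_batch_size=8, gradient_accumulation=4, data_parallel=2)
print(batch)                                    # 64
print(weight_updates_per_epoch(60_000, batch))  # 937
```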

## MNIST with Gradient Accumulation & Data Parallelism

To add data parallelism to the program we first need to create the IR with an ```ir.replication_factor``` greater than 1. Before updating the weights we also have to add communication between replicas via **collective** ops. We can use ```ops.collectives.replicated_all_reduce_```. Note that if we perform the operation in place, we need to use the ```in_sequence(True)``` context.

```python
def train_program(opts):
    # total number of replicas used in the program, regardless of their use.
    # Here we use them to implement data parallelism, and no other
    # use of replication is involved.
    ir = popxl.Ir(replication=opts.train.data_parallel)

    with ir.main_graph:
        ...
        # in_sequence is needed because the all-reduce is performed in place
        with popxl.in_sequence(True):
            # fwd
            # bwd
            # reduce gradients across replicas
            for g in grads:
                g = ops.collectives.replicated_all_reduce_(g, op="mean")

            # optimizer step
    ...
```

To also add gradient accumulation, the first step is to use the ```addons.transforms.autodiff_with_accumulation``` transform instead of ```autodiff``` to generate the backward graph.
```python
def autodiff_with_accumulation(
        graph: GraphWithNamedArgs,
        tensors_to_accumulate_grads: Iterable[popxl.Tensor],
        grads_required: Optional[Iterable[popxl.Tensor]] = None) -> Tuple[NamedVariableFactories, GraphWithNamedArgs]
```
While the standard ```autodiff``` transform produces a graph without state (that is, without variables), the ```autodiff_with_accumulation``` transform generates a graph with state. It therefore returns both the graph and the variable factories for its ```NamedArgs```: the accumulators for the ```tensors_to_accumulate_grads``` and a ```mean_accum_counter```, which is incremented with each call of the gradient graph.  
Each tensor in ```tensors_to_accumulate_grads``` is automatically added as a required grad. You can provide another list of tensors in ```grads_required``` for non-accumulated gradients.

<figure>
    <img src="images/autodiff.png" />
    <img src="images/backward_auto_accum.png"/>
    <figcaption> <b>Fig 2: </b> Differences between <code>autodiff</code> and <code>autodiff_with_accumulation</code>. On the left, <code>autodiff</code> logic. The backward takes as inputs forward activations y, the derivative y', x and w and produces the derivatives x' and w'. On the right,  <code>autodiff_with_accumulation</code>  logic. The backward takes as inputs forward activations y, the derivative y', x, w and the accumulators X' and W'. No output is produced since the result is accumulated in X' and W'. I are intermediate tensors.
</figcaption>
</figure>


<figure>
    <img src="images/autodiff_transf.png"/>
    <figcaption> <b>Fig 3: </b> <code>autodiff_with_accumulation</code> calls <code>autodiff</code> and then, for each tensor in <code>tensors_to_accumulate_grads</code>, adds an operation to the output gradient
    graph which takes a running mean of the gradient and stores the result in an accumulator tensor. The accumulators are added as NamedArgs TensorByRef inputs to the grad graph, and the corresponding gradient outputs are removed.
 </figcaption>
</figure>


After each weight update the running mean needs to be reset. For this it is enough to reset the counter, which can be done by scaling it by zero with `ops.var_updates.accumulator_scale_(grads.mean_accum_counter, 0.0)` after each weight update:

```python
def optimizer_step(...):
    # update variables
    ...
    # Reset accumulators.
    ops.var_updates.accumulator_scale_(grads.mean_accum_counter, 0.0)
```
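To see why resetting only the counter is sufficient, consider the following pure-Python sketch of a running mean. It assumes the update rule `acc = acc * n / (n + 1) + grad / (n + 1)`, where `n` is the counter. The class is a hypothetical stand-in for the accumulator and counter variables, not the PopXL implementation.

```python
class RunningMeanAccumulator:
    """Scalar running mean driven by a counter, mimicking an accumulator/counter pair."""

    def __init__(self) -> None:
        self.value = 0.0
        self.counter = 0.0

    def accumulate(self, grad: float) -> None:
        # Old content is weighted by counter / (counter + 1): zero when counter == 0.
        weight_old = self.counter / (self.counter + 1)
        self.value = self.value * weight_old + grad / (self.counter + 1)
        self.counter += 1

    def reset_counter(self) -> None:
        # Only the counter is cleared; `value` still holds the previous mean.
        self.counter = 0.0


acc = RunningMeanAccumulator()
acc.accumulate(2.0)
acc.accumulate(4.0)
print(acc.value)   # 3.0, the mean of 2.0 and 4.0

acc.reset_counter()
print(acc.value)   # 3.0, the stale value is still stored
acc.accumulate(10.0)
print(acc.value)   # 10.0, the stale value was multiplied by zero
```

With this rule the first accumulation after a reset overwrites the accumulator, so no separate pass to zero the accumulators is needed.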

Finally, note that in a gradient accumulation loop we typically have ```host loads``` in the program:

```python
for _ in range(gradient_accumulation):
    ...
    input = ops.host_load(input_stream)
    ...
```
Remember from the [session user guide]() that

> For each host_load (per tensor) operation in your model to run, you will need to increment the num_host_transfers by one. 

Hence, we will need to set
```
session.ir.num_host_transfers = gradient_accumulation
```
We will also need to provide training data in the appropriate data format when running the session: it should be of shape 
```
(num_host_transfers, replication_factor, *device_shape)
```
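The sketch below spells out this shape for the configuration used in this notebook and shows which samples of a global batch end up in which host transfer and replica. It assumes a row-major split, which is what a plain `reshape` of the batch produces. Both helper names are hypothetical.

```python
from typing import List, Tuple


def session_input_shape(
    num_host_transfers: int, replication_factor: int, device_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Shape of the host array passed for one stream."""
    return (num_host_transfers, replication_factor, *device_shape)


def sample_layout(
    num_host_transfers: int, replication_factor: int, micro_batch_size: int
) -> List[List[List[int]]]:
    """Row-major split of sample indices into [host transfer][replica][sample]."""
    layout = []
    index = 0
    for _ in range(num_host_transfers):
        per_replica = []
        for _ in range(replication_factor):
            per_replica.append(list(range(index, index + micro_batch_size)))
            index += micro_batch_size
        layout.append(per_replica)
    return layout


# gradient_accumulation = 4, data_parallel = 2, micro_batch_size = 8, 28x28 images
print(session_input_shape(4, 2, (8, 28, 28)))  # (4, 2, 8, 28, 28)

layout = sample_layout(4, 2, 8)
print(layout[0][1])  # [8, 9, 10, 11, 12, 13, 14, 15]: first transfer, second replica
```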

In this example we present both techniques. When gradient accumulation is 1, standard autodiff is used to avoid the extra memory from the accumulator variables.

In [None]:
import sys

In [None]:
import argparse
from dataclasses import dataclass
from functools import partial
from time import time
from typing import Dict, Mapping, Optional, Union

import numpy as np
import torch
import torchvision
from tqdm import tqdm

import popxl
import popxl.ops as ops
import popxl_addons as addons
from popxl_addons.graph import GraphWithNamedArgs
from popxl_addons.named_tensors import NamedTensors
from popxl_addons.variable_factory import NamedVariableFactories

In [None]:
np.random.seed(42)

In [None]:
def get_mnist_data(test_batch_size: int, batch_size: int):
    training_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(
                        (0.1307,), (0.3081,)
                    ),  # mean and std computed on the training set.
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    validation_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=test_batch_size,
        shuffle=True,
        drop_last=True,
    )
    return training_data, validation_data


def accuracy(predictions: np.ndarray, labels: torch.Tensor) -> float:
    ind = np.argmax(predictions, axis=-1).flatten()
    labels = labels.detach().numpy().flatten()
    return np.mean(ind == labels) * 100.0


class Linear(addons.Module):
    def __init__(self, out_features: int, bias: bool = True):
        super().__init__()
        self.out_features = out_features
        self.bias = bias

    def build(self, x: popxl.Tensor) -> popxl.Tensor:
        # add a state variable to the module
        w = self.add_variable_input(
            "weight",
            partial(np.random.normal, 0, 0.02, (x.shape[-1], self.out_features)),
            x.dtype,
        )
        y = x @ w
        if self.bias:
            # add a state variable to the module
            b = self.add_variable_input("bias", partial(np.zeros, y.shape[-1]), x.dtype)
            y = y + b
        return y


class Net(addons.Module):
    def __init__(self, cache: Optional[addons.GraphCache] = None):
        super().__init__(cache=cache)
        self.fc1 = Linear(512)
        self.fc2 = Linear(512)
        self.fc3 = Linear(512)
        self.fc4 = Linear(10)

    def build(self, x: popxl.Tensor):
        x = x.reshape((-1, 28 * 28))
        x = ops.gelu(self.fc1(x))
        x = ops.gelu(self.fc2(x))
        x = ops.gelu(self.fc3(x))
        x = self.fc4(x)
        return x


def evaluate_throughput(session, samples_per_step, epochs: int = 5):
    inputs = {
        stream: np.ones(
            session._full_input_shape(stream.shape), stream.dtype.as_numpy()
        )
        for stream in session.expected_inputs()
    }

    durations = []
    with session:
        for i in range(epochs):
            start = time()
            session.run(inputs)
            dur = time() - start
            durations.append(dur)

    duration = np.mean(durations)

    result_str = (
        f"Mean duration: {duration} s "
        f"Throughput: {samples_per_step/duration:6.1f} samples/s "
    )
    print(result_str)

## Configs

In [None]:
@dataclass
class Train:
    micro_batch_size: int = 8
    lr: Union[float, popxl.Tensor] = 1e-3
    epochs: int = 1
    gradient_accumulation: int = 1
    data_parallel: int = 1
    device = "ipu_hw"


@dataclass
class Test:
    micro_batch_size: int = 80
    device = "ipu_hw"


class Options:
    train = Train()
    test = Test()


opts = Options()
opts.train.micro_batch_size = 8
opts.train.lr = 1e-3
opts.train.epochs = 1
opts.train.gradient_accumulation = 4
opts.train.data_parallel = 2

opts.test.micro_batch_size = 80

## Train

In [None]:
"""
Adam optimizer.
Defines adam update step for a single variable
"""


class Adam(addons.Module):
    # we need to specify in_sequence because a lot of operations are in place and their order
    # shouldn't be rearranged
    @popxl.in_sequence()
    def build(
        self,
        var: popxl.TensorByRef,
        grad: popxl.Tensor,
        *,
        lr: Union[float, popxl.Tensor],
        beta1: Union[float, popxl.Tensor] = 0.9,
        beta2: Union[float, popxl.Tensor] = 0.999,
        eps: Union[float, popxl.Tensor] = 1e-5,
        weight_decay: Union[float, popxl.Tensor] = 1e-2,
        first_order_dtype: popxl.dtype = popxl.float16,
        bias_correction: bool = True
    ):

        # gradient estimators for the variable var - same shape as the variable
        first_order = self.add_variable_input(
            "first_order", partial(np.zeros, var.shape), first_order_dtype, by_ref=True
        )
        ops.var_updates.accumulate_moving_average_(first_order, grad, f=beta1)

        # variance estimators for the variable var - same shape as the variable
        second_order = self.add_variable_input(
            "second_order", partial(np.zeros, var.shape), popxl.float32, by_ref=True
        )
        ops.var_updates.accumulate_moving_average_square_(second_order, grad, f=beta2)

        # adam is a biased estimator: provide the step to correct bias
        step = None
        if bias_correction:
            step = self.add_variable_input(
                "step", partial(np.zeros, ()), popxl.float32, by_ref=True
            )

        # calculate the weight increment with adam heuristic
        updater = ops.var_updates.adam_updater(
            first_order,
            second_order,
            weight=var,
            weight_decay=weight_decay,
            time_step=step,
            beta1=beta1,
            beta2=beta2,
            epsilon=eps,
        )

        # in place weight update: w += (-lr)*dw
        ops.scaled_add_(var, updater, b=-lr)


"""
Update all variables creating per-variable optimizers. 
"""


def optimizer_step(
    variables: NamedTensors,
    grads: Dict[popxl.Tensor, popxl.Tensor],
    optimizer: addons.Module,
    accum_counter: Optional[popxl.Tensor],
    lr: Union[float, popxl.Tensor] = 1e-3,
):
    for name, var in variables.named_tensors.items():
        # create optimizer and state factories for the variable
        opt_facts, opt_graph = optimizer.create_graph(
            var, var.spec, lr=lr, weight_decay=0.0, bias_correction=False
        )
        state = opt_facts.init()
        # bind the graph to its state and call it.
        opt_graph.bind(state).call(var, grads[var])

    if accum_counter is not None:
        # Reset accumulators.
        # Resetting the counter for mean gradient accumulation is sufficient to zero the accumulators
        # in the next call to ops.accumulate_mean_
        ops.var_updates.accumulator_scale_(accum_counter, 0.0)

In [None]:
def train_program(opts):
    # total number of replicas used in the program, regardless of their use.
    # Here we use them to implement data parallelism, and no other
    # use of replication is involved.
    ir = popxl.Ir(replication=opts.train.data_parallel)

    with ir.main_graph:
        # Create input streams from host to device
        img_spec = popxl.TensorSpec(
            (opts.train.micro_batch_size, 28, 28), popxl.float32
        )
        img_stream = popxl.h2d_stream(img_spec.shape, popxl.float32, "image")
        label_stream = popxl.h2d_stream(
            (opts.train.micro_batch_size,), popxl.int32, "labels"
        )
        loss_stream = popxl.d2h_stream((), popxl.float32, "loss")

        # Create forward graph
        facts, fwd_graph = Net().create_graph(img_spec)
        variables = facts.init()
        bound_fwd = fwd_graph.bind(variables)

        counter = None
        required_grads = fwd_graph.args.tensors

        if opts.train.gradient_accumulation > 1:
            bwd_facts, bwd_graph = addons.autodiff_with_accumulation(
                fwd_graph, required_grads
            )
            accumulated_grads = bwd_facts.init()
            counter = accumulated_grads.mean_accum_counter
            bound_bwd = bwd_graph.bind(accumulated_grads)
        else:
            # standard autodiff, avoid extra memory
            bwd_graph = addons.autodiff(fwd_graph, grads_required=required_grads)

        # in sequence needed for in place ops
        with popxl.in_sequence(True):
            # gradient accumulation loop
            for ga_step in range(opts.train.gradient_accumulation):
                # load data
                img_t = ops.host_load(img_stream)
                labels = ops.host_load(label_stream, "labels")

                # fwd
                fwd_info = bound_fwd.call_with_info(img_t)
                x = fwd_info.outputs[0]

                # loss
                loss, dx = addons.ops.cross_entropy_with_grad(x, labels)
                ops.host_store(loss_stream, loss)

                # bwd
                activations = bwd_graph.grad_graph_info.inputs_dict(fwd_info)

                if opts.train.gradient_accumulation > 1:
                    bound_bwd.call(dx, args=activations)
                    grads = accumulated_grads.tensors[:-1]  # exclude the counter

                else:
                    grads = bwd_graph.call(dx, args=activations)

            # average gradients across replicas (sum, then divide by the number of replicas)
            if opts.train.data_parallel > 1:
                for g in grads:
                    g = ops.collectives.replicated_all_reduce_(g, op="mean")

            # optimizer step: the optimizer resets the accumulators
            grads_dict = dict(zip(variables.tensors, grads))
            optimizer = Adam(cache=True)
            optimizer_step(variables, grads_dict, optimizer, counter, opts.train.lr)

    # we have a for loop, the number of host loads is equal to gradient_accumulation
    ir.num_host_transfers = opts.train.gradient_accumulation

    return (
        popxl.Session(ir, opts.train.device),
        [img_stream, label_stream],
        variables,
        loss_stream,
    )

In [None]:
global_batch_size = (
    opts.train.micro_batch_size
    * opts.train.gradient_accumulation
    * opts.train.data_parallel
)
training_data, test_data = get_mnist_data(opts.test.micro_batch_size, global_batch_size)
train_session, train_input_streams, train_variables, loss_stream = train_program(opts)

In [None]:
nr_batches = len(training_data)
with train_session:
    for epoch in range(1, opts.train.epochs + 1):
        print("Epoch {0}/{1}".format(epoch, opts.train.epochs))
        bar = tqdm(training_data, total=nr_batches)
        for data, labels in bar:
            # reshape data accounting for replication and num host transfers
            data = data.reshape(
                train_session.ir.num_host_transfers,
                train_session.ir.replication_factor,
                opts.train.micro_batch_size,
                28,
                28,
            ).squeeze()
            labels = labels.reshape(
                train_session.ir.num_host_transfers,
                train_session.ir.replication_factor,
                opts.train.micro_batch_size,
            ).squeeze()

            inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
                zip(train_input_streams, [data.float(), labels.int()])
            )
            loss = train_session.run(inputs)
            # shape (ir.num_host_transfers, ir.replication_factor)
            losses_np = loss[loss_stream]
            avg_loss = np.mean(losses_np)
            bar.set_description("Loss:{:0.4f}".format(avg_loss))

In [None]:
# get weights data : dictionary { train_session variables : tensor data (numpy) }
train_vars_to_data = train_session.get_tensors_data(train_variables.tensors)

### Throughput and Testing

In [None]:
def test_program(test_batch_size, device):
    ir = popxl.Ir(replication=1)

    with ir.main_graph:
        # Inputs
        in_stream = popxl.h2d_stream((test_batch_size, 28, 28), popxl.float32, "image")
        in_t = ops.host_load(in_stream)

        # Create graphs
        facts, graph = Net().create_graph(in_t)

        # Initialise variables
        variables = facts.init()

        # Forward
        (outputs,) = graph.bind(variables).call(in_t)
        out_stream = popxl.d2h_stream(outputs.shape, outputs.dtype, "outputs")
        ops.host_store(out_stream, outputs)

    ir.num_host_transfers = 1
    return popxl.Session(ir, device), [in_stream], variables, out_stream

In [None]:
test_session, test_input_streams, test_variables, out_stream = test_program(
    opts.test.micro_batch_size, opts.test.device
)

# dictionary { train_session variables : test_session variables }
train_vars_to_test_vars = train_variables.to_mapping(test_variables)

# Create a dictionary { test_session variables : tensor data (numpy) }
# We want to copy the values before evaluating throughput on synthetic data, otherwise weights are changed
test_vars_to_data = {
    test_var: train_vars_to_data[train_var].copy()
    for train_var, test_var in train_vars_to_test_vars.items()
}

# Copy trained weights to the program, with a single host to device transfer at the end
test_session.write_variables_data(test_vars_to_data)

# evaluate the ratio samples per step / time for train session
print("train session")
evaluate_throughput(train_session, global_batch_size)

In [None]:
nr_batches = len(test_data)
sum_acc = 0.0
with test_session:
    for data, labels in tqdm(test_data, total=nr_batches):
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(test_input_streams, [data.squeeze().float(), labels.int()])
        )
        output = test_session.run(inputs)
        sum_acc += accuracy(output[out_stream], labels)
print("Accuracy on test set: {:0.2f}%".format(sum_acc / len(test_data)))

In [None]:
# no data parallelism or gradient accumulation for inference in this program
samples_per_step = opts.test.micro_batch_size
# evaluate the ratio samples per step / time for test session
print("test session")
evaluate_throughput(test_session, samples_per_step)