17. Application example: MNIST with replication and RTS

In this section, we modify the previous MNIST application example (Section 16, Application example: MNIST) to use RTS variables (Section 13.3, Variable tensors for replicated tensor sharding). Because RTS is based on replication, we first change the code to support replication, and then change the variable tensors into RTS variable tensors.

17.1. Add support for replications

Replication is a form of data parallelism, achieved by running the same program in parallel on multiple sets of IPUs. PopXL currently supports local replication, in which all replicas are handled by a single host instance. We’ve added the command line option --replication-factor to the code to specify the number of replicas. We then assign this value to each IR in build_train_ir and build_test_ir, as shown in the code below.

ir.replication_factor = opts.replication_factor
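The example's argument parser is not shown in this section. As a minimal sketch, assuming the options are declared with Python's argparse, it could look like the following. The --replication-factor and --batch-size options match the commands used later in this section. The --test-batch-size and --rts spellings, the defaults and the parse_args helper are hypothetical.

```python
import argparse


def parse_args(argv=None):
    """Parse command-line options (hypothetical sketch, not the example's real parser)."""
    parser = argparse.ArgumentParser(description="MNIST with replication and RTS (sketch)")
    # Number of replicas; argparse exposes it as opts.replication_factor.
    parser.add_argument("--replication-factor", type=int, default=1,
                        help="number of replicas (assumed default: 1)")
    # Per-replica batch sizes; the defaults are assumptions.
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--test-batch-size", type=int, default=80)
    # Use RTS variables; the example only does so when replication_factor > 1.
    parser.add_argument("--rts", action="store_true")
    return parser.parse_args(argv)


opts = parse_args(["--replication-factor", "2", "--batch-size", "4"])
print(opts.replication_factor, opts.batch_size, opts.rts)  # 2 4 False
```

argparse converts the dashes in option names to underscores, which is why the code refers to opts.replication_factor and opts.batch_size.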

When replication is used, the total number of samples processed in each step equals the batch size of one replica multiplied by the replication factor, so we need to provide enough input data for all replicas at each step. To get comparable training results with different replication factors, we need to keep batch_size * replication_factor constant. We can do this when preparing the dataset by keeping the batch size code unchanged (as it was for a replication factor of 1) and multiplying the batch sizes by the replication factor, as shown in the code below.

training_data, test_data = get_mnist_data(
    opts.test_batch_size * opts.replication_factor,
    opts.batch_size * opts.replication_factor)
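You can check the arithmetic behind this choice with a few lines of plain Python. This is an illustrative sketch. The helper names are hypothetical, 60,000 is the size of the standard MNIST training set, and steps_per_epoch assumes an incomplete final batch is dropped, which may not match the example's data loader.

```python
MNIST_TRAIN_SAMPLES = 60_000  # size of the standard MNIST training set


def global_batch_size(batch_size: int, replication_factor: int) -> int:
    """Samples consumed per step across all replicas."""
    return batch_size * replication_factor


def steps_per_epoch(batch_size: int, replication_factor: int,
                    num_samples: int = MNIST_TRAIN_SAMPLES) -> int:
    """Full steps per epoch, assuming an incomplete final batch is dropped."""
    return num_samples // global_batch_size(batch_size, replication_factor)


# The two command lines used in this section process the same samples per step.
for rf, bs in [(2, 4), (1, 8)]:
    print(rf, bs, global_batch_size(bs, rf), steps_per_epoch(bs, rf))
# 2 4 8 7500
# 1 8 8 7500
```

Both configurations consume 8 samples per step, which is why the two commands used later in this section should give similar test accuracy.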

We also need to reshape the data passed to the training session so that it matches the shape of the session inputs (Section 12.5, Data input shape):

for data, labels in bar:
    if opts.replication_factor > 1:
        # Put the replica index first: (replication_factor, batch_size, ...)
        data = data.reshape((opts.replication_factor, opts.batch_size, 28, 28))
        labels = labels.reshape((opts.replication_factor, opts.batch_size))
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(input_streams, [data.float().numpy(), labels.int().numpy()])
        )
    else:
        # Without replication, squeeze() removes size-1 dimensions such as the image channel
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(
                input_streams,
                [data.squeeze().float().numpy(), labels.int().numpy()],
            )
        )
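The reshape puts the replica index in the leading dimension. The following plain-Python model (not PopXL code) shows which samples end up in each replica's slice. It assumes a row-major reshape and that replica r receives index r of the leading dimension, as described in Section 12.5. split_for_replicas is a hypothetical helper.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_for_replicas(samples: Sequence[T], replication_factor: int,
                       batch_size: int) -> List[List[T]]:
    """Split one host batch of replication_factor * batch_size samples by replica.

    Mirrors a row-major reshape to (replication_factor, batch_size, ...):
    replica r receives samples [r * batch_size, (r + 1) * batch_size).
    """
    expected = replication_factor * batch_size
    if len(samples) != expected:
        raise ValueError(f"expected {expected} samples, got {len(samples)}")
    return [list(samples[r * batch_size:(r + 1) * batch_size])
            for r in range(replication_factor)]


# Eight sample indices from one host batch, split across two replicas of four.
print(split_for_replicas(list(range(8)), replication_factor=2, batch_size=4))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```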

After making similar changes to the data shape for the test session, replication is also supported when testing the trained model. You can check that replication works by running:

python mnist_rts.py --replication-factor 2 --batch-size 4

It should give a test accuracy similar to that of the following command:

python mnist_rts.py --replication-factor 1 --batch-size 8

17.2. Change variable tensors to RTS variable tensors

We can create RTS variables for training and testing in two different ways: one that exposes the remote buffer (using remote_replica_sharded_variable()) and one that does not (using replica_sharded_variable()). In the code, we create W0 by using replica_sharded_variable() and W1 by using remote_replica_sharded_variable().

To collect all the information needed for an RTS variable, we use the following named tuple:

Trainable = namedtuple('Trainable', ['var', 'shards', 'full', 'remote_buffer'])

Here, var is the remote variable tensor (or the ordinary variable tensor for a non-RTS variable), shards is the sharded tensor obtained by the remote load operation, full is the tensor produced by the all-gather operation, and remote_buffer is the remote buffer that holds the variable, if known. If a variable tensor is not an RTS variable, then shards, full and remote_buffer are None.
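These field conventions are what later code uses to tell the variable kinds apart. The following small sketch uses strings as stand-ins for PopXL tensors and buffers. The is_rts and needs_remote_store helpers are hypothetical and simply mirror the checks on shards and remote_buffer made in update_weights_bias.

```python
from collections import namedtuple

# Same fields as the example's Trainable tuple.
Trainable = namedtuple("Trainable", ["var", "shards", "full", "remote_buffer"])


def is_rts(t: Trainable) -> bool:
    """An RTS variable is recognised by having shards (hypothetical helper)."""
    return t.shards is not None


def needs_remote_store(t: Trainable) -> bool:
    """Only RTS variables with an explicit remote buffer need a remote store."""
    return is_rts(t) and t.remote_buffer is not None


# Strings stand in for PopXL tensors and buffers in this illustration.
plain = Trainable("W", None, None, None)
rts_no_buffer = Trainable("W0_remote", "W0_shards", "W0", None)
rts_with_buffer = Trainable("W1_remote", "W1_shards", "W1", "buffer")

for t in (plain, rts_no_buffer, rts_with_buffer):
    print(t.var, is_rts(t), needs_remote_store(t))
# W False False
# W0_remote True False
# W1_remote True True
```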

The Trainable for W0, trainable_w0, is created as shown in the code below:

if ir.replication_factor > 1 and opts.rts:
    W0_remote, W0_shards = popxl.replica_sharded_variable(
        W0_data, dtype=popxl.float32, name="W0"
    )
    W0 = ops.collectives.replicated_all_gather(W0_shards)
    trainable_w0 = Trainable(W0_remote, W0_shards, W0, None)
else:
    W0 = popxl.variable(W0_data, name="W0")
    trainable_w0 = Trainable(W0, None, None, None)
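To make the roles of the shards and of replicated_all_gather() concrete, here is a plain-Python model of the data movement. It assumes each replica holds one equal, contiguous slice of the flattened variable, in replica order. This is a simplification for illustration; PopXL's actual layout, and its handling of sizes that do not divide evenly, may differ.

```python
from typing import List


def shard(flat: List[float], replication_factor: int) -> List[List[float]]:
    """Split a flattened tensor into equal contiguous shards, one per replica."""
    if len(flat) % replication_factor != 0:
        raise ValueError("this sketch assumes the size divides evenly")
    n = len(flat) // replication_factor
    return [flat[r * n:(r + 1) * n] for r in range(replication_factor)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Concatenate every replica's shard; each replica receives this result."""
    return [x for s in shards for x in s]


w0_flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
shards = shard(w0_flat, replication_factor=2)
print(shards)                          # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
print(all_gather(shards) == w0_flat)   # True
```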

The Trainable for W1, trainable_w1, is created as shown in the code below:

if ir.replication_factor > 1 and opts.rts:
    # create a remote buffer that matches the shard shape and dtype
    var_shard_shape: Tuple[int, ...] = (W1_data.size // ir.replication_factor,)
    buffer = popxl.remote_buffer(
        var_shard_shape, popxl.float32, entries=ir.replication_factor
    )
    # create the remote RTS variable
    W1_remote = popxl.remote_replica_sharded_variable(W1_data, buffer, 0)
    # load this replica's shard of the remote RTS variable
    W1_shards = ops.remote_load(buffer, 0)
    # gather all the shards to get the full weight
    W1 = ops.collectives.replicated_all_gather(W1_shards)
    trainable_w1 = Trainable(W1_remote, W1_shards, W1, buffer)
else:
    W1 = popxl.variable(W1_data, name="W1")
    trainable_w1 = Trainable(W1, None, None, None)
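The shard shape used for the remote buffer is the flattened variable size divided by the replication factor. The sketch below computes it and compares the storage for one shard with that of a full copy. The (512, 10) shape is a hypothetical example rather than a value taken from this section, float32 is assumed to take 4 bytes, and the functions are illustrative helpers, not PopXL API.

```python
import math
from typing import Tuple


def rts_shard_shape(var_shape: Tuple[int, ...], replication_factor: int) -> Tuple[int, ...]:
    """1-D shard shape, as in (W1_data.size // ir.replication_factor,)."""
    size = math.prod(var_shape)
    if size % replication_factor != 0:
        raise ValueError(f"size {size} is not divisible by {replication_factor}")
    return (size // replication_factor,)


def bytes_per_replica(var_shape: Tuple[int, ...], replication_factor: int,
                      itemsize: int = 4) -> Tuple[int, int]:
    """Bytes for (a full copy, one RTS shard) of a float32 variable."""
    full = math.prod(var_shape) * itemsize
    sharded = rts_shard_shape(var_shape, replication_factor)[0] * itemsize
    return full, sharded


# Hypothetical weight shape, chosen only for illustration.
print(rts_shard_shape((512, 10), 4))    # (1280,)
print(bytes_per_replica((512, 10), 4))  # (20480, 5120)
```

With four replicas, each replica stores a quarter of the variable. The gathered full tensor still exists while it is being used.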

Note that when we compute the gradient with respect to an RTS variable, the input tensor is the (non-sharded) “full” tensor that was gathered from the shards with replicated_all_gather(). After obtaining the gradients for the “full” tensors, we slice them with replica_sharded_slice() so that each replica updates only its shard of the RTS variable tensor.

if params["W1"].shards is not None:
    grad_w_1 = ops.collectives.replica_sharded_slice(grad_w_1)

if params["W0"].shards is not None:
    grad_w_0 = ops.collectives.replica_sharded_slice(grad_w_0)
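The difference between slicing a gradient and reduce-scattering it can be shown with a plain-Python model. This sketch assumes equal contiguous shards in replica order, and the gradient values are made up for illustration. Slicing alone keeps each replica's part of its own local gradient, whereas a reduce-scatter first sums the gradients element-wise across replicas.

```python
from typing import List


def sharded_slice(full: List[float], replica: int, replication_factor: int) -> List[float]:
    """Part of a full tensor that one replica keeps (models replica_sharded_slice)."""
    n = len(full) // replication_factor
    return full[replica * n:(replica + 1) * n]


def reduce_scatter(per_replica: List[List[float]]) -> List[List[float]]:
    """Sum full tensors element-wise across replicas, then give each replica its slice."""
    rf = len(per_replica)
    summed = [sum(vals) for vals in zip(*per_replica)]
    return [sharded_slice(summed, r, rf) for r in range(rf)]


# Full gradients computed independently by two replicas (made-up numbers).
grads = [[1.0, 2.0, 3.0, 4.0],
         [10.0, 20.0, 30.0, 40.0]]

# Slicing only: each replica keeps its part of its own gradient.
print([sharded_slice(g, r, 2) for r, g in enumerate(grads)])  # [[1.0, 2.0], [30.0, 40.0]]
# Reduce-scatter: gradients are summed across replicas before slicing.
print(reduce_scatter(grads))  # [[11.0, 22.0], [33.0, 44.0]]
```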

When you update an RTS variable tensor that uses a remote buffer, you also need to store the updated shard back in the remote buffer with remote_store().

def update_weights_bias(opts, grads, params) -> None:
    """
    Update weights and bias by W += - lr * grads_w, b += - lr * grads_b.
    """
    for k, v in params.items():
        if v.shards is not None:
            # rts variable
            ops.scaled_add_(v.shards, grads[k], b=-opts.lr)
            if v.remote_buffer is not None:
                ops.remote_store(v.remote_buffer, 0, v.shards)
        else:
            # not rts variable
            ops.scaled_add_(v.var, grads[k], b=-opts.lr)
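The following plain-Python model checks that updating each shard and then gathering the shards gives the same result as updating the full weight directly. It is a sketch that makes several assumptions: two replicas with equal contiguous shards, a dict keyed by replica index standing in for the remote buffer, and sgd_update_sharded as a hypothetical helper whose comments point at the PopXL ops it models.

```python
from typing import Dict, List


def sgd_update_sharded(shards: List[List[float]], grad_shards: List[List[float]],
                       lr: float, remote: Dict[int, List[float]]) -> None:
    """In place: shard += -lr * grad_shard on each replica, then store it back."""
    for r, (w, g) in enumerate(zip(shards, grad_shards)):
        for i in range(len(w)):
            w[i] += -lr * g[i]   # models ops.scaled_add_(v.shards, grads[k], b=-opts.lr)
        remote[r] = list(w)      # models ops.remote_store(v.remote_buffer, 0, v.shards)


lr = 0.5
full_w = [1.0, 2.0, 3.0, 4.0]
full_g = [2.0, 2.0, 4.0, 4.0]

# Non-RTS reference: update the full weight directly.
expected = [w - lr * g for w, g in zip(full_w, full_g)]

# RTS model: each of two replicas updates only its half, then stores it.
shards = [full_w[:2], full_w[2:]]
grad_shards = [full_g[:2], full_g[2:]]
remote: Dict[int, List[float]] = {}
sgd_update_sharded(shards, grad_shards, lr, remote)

gathered = [x for r in sorted(remote) for x in remote[r]]
print(gathered)              # [0.0, 1.0, 1.0, 2.0]
print(gathered == expected)  # True
```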