15. Application example: MNIST with replication and RTS

In this section, we extend the previous MNIST application example (Section 14, Application example: MNIST) to use RTS variables (Section 12.3, Variable tensors for replicated tensor sharding). Since RTS builds on replication, we first need to change the code to support replication, and then change the variable tensors into RTS variable tensors.

15.1. Add support for replication

Replication, a form of data parallelism, is achieved by running the same program in parallel on multiple sets of IPUs. PopXL currently supports local replication, in which all replicas are handled by a single host instance. We've added the command-line option --replication-factor to the code to indicate the number of replicas required. We then assign this parameter to each IR in build_train_ir and build_test_ir, as shown in the code below.

ir.replication_factor = opts.replication_factor
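
For reference, the option itself can be declared with argparse along the following lines. This is a minimal sketch rather than the exact code in mnist_rts.py; argparse maps --replication-factor to opts.replication_factor automatically.

import argparse

# Hypothetical sketch of the command-line option; the real parser in
# mnist_rts.py defines more options and may use different defaults.
parser = argparse.ArgumentParser(description="MNIST with replication and RTS")
parser.add_argument(
    "--replication-factor",
    type=int,
    default=1,
    help="Number of replicas to run the program on",
)
opts = parser.parse_args()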

When replication is used, the total amount of data processed in each step is the batch size of one replica multiplied by the replication factor, so we need to provide enough input data to feed all the replicas at each step. To get the same training results with different replication factors, we need to keep the product batch_size * replication_factor constant. We can achieve this when preparing the dataset by keeping the batch-size code unchanged (as it is when replication_factor is 1) and multiplying by the replication factor, as shown in the code below.

training_data, test_data = get_mnist_data(
    opts.test_batch_size * opts.replication_factor,
    opts.batch_size * opts.replication_factor)
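
For example, --batch-size 4 with --replication-factor 2 loads 8 samples per step, the same total as --batch-size 8 with a single replica.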

We also need to reshape the data passed to each session to match the required shape of its inputs (Section 11.5, Data input shape), as shown below for train:

for data, labels in bar:
    if opts.replication_factor > 1:
        data = data.reshape((opts.replication_factor, opts.batch_size, 28, 28))
        labels = labels.reshape((opts.replication_factor, opts.batch_size))
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(input_streams, [data.float().numpy(), labels.int().numpy()])
        )
    else:
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(
                input_streams,
                [data.squeeze().float().numpy(), labels.int().numpy()],
            )
        )
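
The test loop needs an analogous reshape. A minimal sketch, assuming the test loop also iterates over (data, labels) batches and uses opts.test_batch_size; the exact code in mnist_rts.py may differ:

# Hypothetical sketch of the equivalent reshape in the test loop; labels would
# need the same treatment if they are also streamed to the device.
if opts.replication_factor > 1:
    data = data.reshape((opts.replication_factor, opts.test_batch_size, 28, 28))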

After making a similar change to the data shape for the test session, replication is also supported when testing the trained model. You can check that replication works by running:

python mnist_rts.py --replication-factor 2 --batch-size 4

It should give a similar test accuracy to that of the following command:

python mnist_rts.py --replication-factor 1 --batch-size 8

15.2. Change variable tensors to RTS variable tensors

We can create RTS variables for training and testing in two different ways: one which exposes the remote buffer (using remote_replica_sharded_variable()) and one which does not expose the remote buffer (using replica_sharded_variable()). In the code we use both: W0 is created with replica_sharded_variable(), without exposing its remote buffer, and W1 is created with remote_replica_sharded_variable(), with its remote buffer exposed.

To collect all the information needed for an RTS variable, we've used a named tuple:

from collections import namedtuple

Trainable = namedtuple('Trainable', ['var', 'shards', 'full', 'remote_buffer'])

Here, var is the variable tensor (the remote variable in the RTS case), shards is the sharded tensor obtained after the remote load operation, full is the full tensor obtained after the all-gather operation, and remote_buffer is the corresponding remote buffer that holds the variable, if known. If a variable tensor is not an RTS variable, then shards, full, and remote_buffer are all None.

The Trainable for W0, trainable_w0, is created as shown in the code below:

if ir.replication_factor > 1 and opts.rts:
    # RTS variable without exposing the remote buffer: the API returns the
    # remote variable and its shard on each replica
    W0_remote, W0_shards = popxl.replica_sharded_variable(
        W0_data, dtype=popxl.float32, name="W0"
    )
    # Gather all the shards to get the full weight
    W0 = ops.collectives.replicated_all_gather(W0_shards)
    trainable_w0 = Trainable(W0_remote, W0_shards, W0, None)
else:
    W0 = popxl.variable(W0_data, name="W0")
    trainable_w0 = Trainable(W0, None, None, None)

The Trainable for W1, trainable_w1, is created as shown in the code below:

if ir.replication_factor > 1 and opts.rts:
    # Create a remote buffer that matches the shard shape and dtype
    var_shard_shape: Tuple[int, ...] = (W1_data.size // ir.replication_factor,)
    buffer = popxl.remote_buffer(
        var_shard_shape, popxl.float32, entries=ir.replication_factor
    )
    # Create the remote RTS variable in that buffer
    W1_remote = popxl.remote_replica_sharded_variable(W1_data, buffer, 0)
    # Load each replica's shard of the remote RTS variable
    W1_shards = ops.remote_load(buffer, 0)
    # Gather all the shards to get the full weight
    W1 = ops.collectives.replicated_all_gather(W1_shards)
    trainable_w1 = Trainable(W1_remote, W1_shards, W1, buffer)
else:
    W1 = popxl.variable(W1_data, name="W1")
    trainable_w1 = Trainable(W1, None, None, None)

Notice that when we compute the gradient for an input RTS variable tensor, we compute it with respect to the (non-sharded) full tensor that was gathered from the shards with replicated_all_gather(). After the gradients of the full tensors have been obtained, each gradient is sliced with replica_sharded_slice() so that each replica can update its own shard of the RTS variable tensor.

...
if params["W1"].shards is not None:
    grad_w_1 = ops.collectives.replica_sharded_slice(grad_w_1)

...
if params["W0"].shards is not None:
    grad_w_0 = ops.collectives.replica_sharded_slice(grad_w_0)

When you update a variable tensor that has an exposed remote buffer, you also need to store the updated value back to the right place in that buffer.

def update_weights_bias(opts, grads, params) -> None:
    """
    Update weights and bias by W += - lr * grads_w, b += - lr * grads_b.
    """
    for k, v in params.items():
        if v.shards is not None:
            # RTS variable: update the local shard on each replica
            ops.scaled_add_(v.shards, grads[k], b=-opts.lr)
            if v.remote_buffer is not None:
                # Store the updated shard back to the exposed remote buffer
                ops.remote_store(v.remote_buffer, 0, v.shards)
        else:
            # Not an RTS variable: update the full variable tensor directly
            ops.scaled_add_(v.var, grads[k], b=-opts.lr)
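
Finally, note that the code above only creates RTS variables when ir.replication_factor > 1 and opts.rts is set. Assuming opts.rts is exposed as an --rts command-line flag (the flag name is an assumption based on the option used in the code), training and testing with RTS variables can then be run with, for example:

python mnist_rts.py --replication-factor 2 --batch-size 4 --rts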