16. Application example: MNIST with replication and RTS
In this section, we add RTS variables (Section 12.3, Variable tensors for replicated tensor sharding) to the previous MNIST application example (Section 15, Application example: MNIST). Recall that RTS builds on replication, so we first need to change the code to support replication, and then change the variable tensors into RTS variable tensors.
16.1. Add support for replication
Replication, a form of data parallelism, is achieved by running the same program in parallel on multiple sets of IPUs. PopXL currently supports local replication, in which all replicas are handled by a single host instance. We've added the command-line option --replication-factor to the code to indicate the number of replicas required. This value is then assigned to each IR, as shown in the code below, in both build_train_ir and build_test_ir.
ir.replication_factor = opts.replication_factor
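As a point of reference, a minimal sketch of how this option might be parsed with argparse is shown below; the actual argument handling in mnist_rts.py may differ, and other options such as --batch-size and --rts would be added in the same way.

import argparse

parser = argparse.ArgumentParser(description="MNIST with replication and RTS")
parser.add_argument(
    "--replication-factor",
    type=int,
    default=1,
    help="Number of replicas to run the same program on",
)
opts = parser.parse_args()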
When replication is used, the total amount of data processed in each batch equals the batch size of one replica multiplied by the replication factor, so we need to provide enough input data to feed all replicas each step. To obtain the same training results with different replication factors, we need to keep batch_size * replication_factor constant. We achieve this when preparing the dataset by keeping the batch-size code unchanged (the case where replication_factor equals 1) and multiplying by the replication factor, as shown in the code below.
training_data, test_data = get_mnist_data(
    opts.test_batch_size * opts.replication_factor,
    opts.batch_size * opts.replication_factor)
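For context, get_mnist_data comes from the original MNIST example. A minimal sketch of what it might look like, assuming the data is loaded with torchvision datasets and PyTorch DataLoader objects (the actual helper may differ), is:

import torch
import torchvision

def get_mnist_data(test_batch_size: int, batch_size: int):
    # standard MNIST normalisation values; an assumption for this sketch
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    training_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets", train=True, download=True, transform=transform
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "~/.torch/datasets", train=False, download=True, transform=transform
        ),
        batch_size=test_batch_size,
        shuffle=False,
        drop_last=True,
    )
    return training_data, test_data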
We also need to reshape the data passed to each session to match the required dimensions of its inputs in train (Section 11.5, Data input shape):
for data, labels in bar:
    if opts.replication_factor > 1:
        data = data.reshape((opts.replication_factor, opts.batch_size, 28, 28))
        labels = labels.reshape((opts.replication_factor, opts.batch_size))
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(input_streams, [data.float().numpy(), labels.int().numpy()])
        )
    else:
        inputs: Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(
            zip(
                input_streams,
                [data.squeeze().float().numpy(), labels.int().numpy()],
            )
        )
After making similar changes to the data shape for the test session, replication is also supported when testing the trained model. You can check that replication works by running:
python mnist_rts.py --replication-factor 2 --batch-size 4
It should give similar test accuracy to the following command, since both process the same total of batch_size * replication_factor = 8 images per step:
python mnist_rts.py --replication-factor 1 --batch-size 8
16.2. Change variable tensors to RTS variable tensors
We can create RTS variables for training and testing in two different ways: one exposes the remote buffer (using remote_replica_sharded_variable()) and one does not (using replica_sharded_variable()). In the code we:

- create the variable tensor W0 using replica_sharded_variable();
- create the variable tensor W1 using remote_replica_sharded_variable().
To collect all the information needed by an RTS variable, we've used the named tuple
Trainable = namedtuple('Trainable', ['var', 'shards', 'full', 'remote_buffer'])
Here var is the (remote) variable tensor, shards is the sharded tensor obtained from the remote load operation, full is the tensor obtained from the all-gather operation, and remote_buffer is the corresponding remote buffer that holds the variable, if known. If a variable tensor is not an RTS variable, then shards, full, and remote_buffer will be None.
The Trainable for W0, trainable_w0, is created as shown in the code below:
if ir.replication_factor > 1 and opts.rts:
    # create the remote RTS variable and its shards
    W0_remote, W0_shards = popxl.replica_sharded_variable(
        W0_data, dtype=popxl.float32, name="W0"
    )
    # gather all the shards to get the full weight
    W0 = ops.collectives.replicated_all_gather(W0_shards)
    trainable_w0 = Trainable(W0_remote, W0_shards, W0, None)
else:
    W0 = popxl.variable(W0_data, name="W0")
    trainable_w0 = Trainable(W0, None, None, None)
The Trainable for W1, trainable_w1, is created as shown in the code below:
if ir.replication_factor > 1 and opts.rts:
    # create a remote buffer that matches the shard shape and dtype
    var_shard_shape: Tuple[int, ...] = (W1_data.size // ir.replication_factor,)
    buffer = popxl.remote_buffer(
        var_shard_shape, popxl.float32, entries=ir.replication_factor
    )
    # create the remote RTS variable
    W1_remote = popxl.remote_replica_sharded_variable(W1_data, buffer, 0)
    # load this replica's shard of the remote RTS variable
    W1_shards = ops.remote_load(buffer, 0)
    # gather all the shards to get the full weight
    W1 = ops.collectives.replicated_all_gather(W1_shards)
    trainable_w1 = Trainable(W1_remote, W1_shards, W1, buffer)
else:
    W1 = popxl.variable(W1_data, name="W1")
    trainable_w1 = Trainable(W1, None, None, None)
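The two approaches differ in who manages the remote buffer. With replica_sharded_variable(), the buffer stays hidden and the remote load and store operations are handled for us, which is why trainable_w0 stores None for its buffer. With remote_replica_sharded_variable(), we create the buffer ourselves and are responsible for loading the shards with remote_load() before use and, as shown later in update_weights_bias(), storing them back with remote_store() after the update.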
Notice that the gradient we get for an input RTS variable tensor is with respect to the (non-sharded) “full” tensor, which was gathered from the shards using replicated_all_gather(). After obtaining the gradients for the “full” tensors, they are sliced with replica_sharded_slice() so that each replica updates its own shard of the RTS variable tensor.
...
if params["W1"].shards is not None:
    grad_w_1 = ops.collectives.replica_sharded_slice(grad_w_1)
...
if params["W0"].shards is not None:
    grad_w_0 = ops.collectives.replica_sharded_slice(grad_w_0)
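For example, with a replication factor of 2, the “full” gradient of W1 has W1_data.size elements on every replica, and replica_sharded_slice() keeps only that replica's flat shard of W1_data.size // 2 elements, matching the var_shard_shape used for the remote buffer above.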
When you update a variable tensor that uses an explicitly created remote buffer, you also need to store the updated value back to the right place in that buffer.
def update_weights_bias(opts, grads, params) -> None:
    """
    Update weights and bias by W += - lr * grads_w, b += - lr * grads_b.
    """
    for k, v in params.items():
        if v.shards is not None:
            # RTS variable
            ops.scaled_add_(v.shards, grads[k], b=-opts.lr)
            if v.remote_buffer is not None:
                ops.remote_store(v.remote_buffer, 0, v.shards)
        else:
            # not an RTS variable
            ops.scaled_add_(v.var, grads[k], b=-opts.lr)
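Here remote_store() writes the updated shards back into entry 0 of the remote buffer, the same entry that remote_replica_sharded_variable() and remote_load() used when W1 was created. For W0, whose Trainable has no exposed buffer, updating the shards in place is sufficient, as noted above.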