10. Distributed training
We support distributed training for two different types of systems:
IPU-POD systems: An IPU-M2000-based system where IPUs in a rack are interconnected by IPU-Links, and IPUs in different racks are interconnected by GW-Links. Distributed training on IPU-PODs uses these links to perform collective operations without host involvement. When using multiple instances (host processes), there may still be a need for communication over the host network, for example to broadcast the initial values of variables from the first instance to the others.
IPU-Server systems: A Mk1 PCIe card-based system with IPUs interconnected by IPU-Links. IPUs in distinct IPU-Servers are not directly interconnected. Distributed training on IPU-Servers therefore uses the host network for communication. A collective operation is typically performed in a hierarchical fashion where the IPU-Links are used first for intra-server communication, and then the host network is used for inter-server communication.
We provide distribution strategies designed for these two types of systems, with different implementations of the host communication:
IPUMultiReplicaStrategy
IPUHorovodStrategy
IPUMultiWorkerStrategy
Their main differences can be summarized like this:
| Distribution strategy | System | Host communication |
|---|---|---|
| IPUMultiReplicaStrategy | IPU-POD | Horovod (OpenMPI) |
| IPUHorovodStrategy | IPU-Server | Horovod (OpenMPI) |
| IPUMultiWorkerStrategy | IPU-Server | gRPC |
There are some things they have in common:
They all perform data-parallel synchronous training using multiple host processes. In this sense they are all similar to the MultiWorkerMirroredStrategy provided in standard TensorFlow.
They all broadcast the initial values of variables over the host network (using either Horovod or gRPC as described above).
And these are the main differences:
With the IPUMultiReplicaStrategy, designed for IPU-POD systems, a collective operation (performed either explicitly by calling a member function such as reduce(), or implicitly by using an optimizer under the strategy scope) will be performed directly on the IPU by using compiled communications with the GCL library over the IPU-Links and GW-Links. The IPUMultiReplicaStrategy is designed for use with PopDist and PopRun. Please refer to the PopDist and PopRun User Guide for more details.
With the two distribution strategies designed for IPU-Server systems, an equivalent collective operation will involve transferring the tensor from the IPU to the host and performing the collective communication over the host network (using either Horovod with OpenMPI or gRPC). A local (cross-replica) collective operation can be performed by using the cross_replica_ops module.
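As an illustration, here is a minimal sketch of how the two strategies designed for IPU-Server systems are constructed (both constructions also appear in the complete examples later in this chapter; the IPUMultiReplicaStrategy is instead set up through PopDist and PopRun and is not shown):

import tensorflow.compat.v1 as tf
from tensorflow.python import ipu
from tensorflow.python.ipu import horovod as hvd
from tensorflow.python.ipu.horovod import ipu_horovod_strategy

# gRPC-based host communication: the cluster is described by the TF_CONFIG
# environment variable (see Section 10.1.3).
cluster = tf.distribute.cluster_resolver.TFConfigClusterResolver()
grpc_strategy = ipu.ipu_multi_worker_strategy.IPUMultiWorkerStrategyV1(cluster)

# Horovod (OpenMPI) based host communication: cluster discovery is handled by
# MPI, so no cluster resolver is needed (see Section 10.2).
hvd.init()
mpi_strategy = ipu_horovod_strategy.IPUHorovodStrategyV1()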
A distinction should be made between these distribution strategies and the IPUStrategy provided in TensorFlow 2. The IPUStrategy targets a single system with one or more IPUs attached, while the distribution strategies discussed here target distributed systems like those described above (IPU-PODs or multiple IPU-Servers). Also, unlike the IPUStrategy, these distribution strategies do not currently support the Keras Model.fit() family of APIs, and the use of ipu_compiler.compile() is still required to ensure a single XLA graph is compiled, except when using the IPUEstimator or IPUPipelineEstimator, which already use it internally.
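For reference, a minimal, self-contained sketch of wrapping a computation with ipu_compiler.compile() on a single IPU (outside of any distribution strategy or estimator) could look like this; the function my_net is just a placeholder computation:

import tensorflow.compat.v1 as tf
from tensorflow.python import ipu

def my_net(x):
  # Placeholder computation; compiled into a single XLA graph for the IPU.
  return tf.reduce_sum(x * x)

cfg = ipu.config.IPUConfig()
cfg.auto_select_ipus = 1
cfg.configure_ipu_system()

with ipu.scopes.ipu_scope("/device:IPU:0"):
  [result] = ipu.ipu_compiler.compile(my_net, inputs=[tf.constant([1.0, 2.0, 3.0])])

with tf.Session() as sess:
  print(sess.run(result))  # 14.0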
10.1. Example using IPUMultiWorkerStrategyV1
This example shows how to use the IPUEstimator with the IPUMultiWorkerStrategyV1 to perform distributed training of a model on the MNIST dataset.
The example is based on the following official tutorial, with some modifications for use with the IPU: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator
Below, we highlight the changes needed to convert code using the IPUEstimator to support distributed training.
10.1.1. The input function
In multi-worker training, it is necessary to shard the dataset so that each worker processes a distinct portion of it.
When used in a distributed context, the input function is passed an additional argument, input_context, which can be used to obtain the current worker index and the total number of workers. We pass this information to the Dataset.shard() function to perform the sharding.
Note that the batch size provided by the input function is the per-worker batch size. The global batch size will be this multiplied by the number of workers.
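A pared-down input function following this pattern might look like the sketch below (the complete example in Section 10.1.4 applies the same sharding to the MNIST dataset); the random data here is only a stand-in:

import numpy as np
import tensorflow.compat.v1 as tf

BATCH_SIZE = 64  # per-worker batch size

def input_fn(mode, input_context=None):
  # Stand-in data; the complete example below uses the MNIST dataset.
  images = np.random.rand(1024, 28, 28, 1).astype(np.float32)
  labels = np.random.randint(10, size=1024).astype(np.int32)
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))

  if input_context:
    # Keep only the shard belonging to this worker.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  return dataset.batch(BATCH_SIZE, drop_remainder=True).repeat()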
10.1.2. The model function
The optimiser will automatically divide the loss by the number of workers, so in the model function we should only divide the loss by the local batch size.
We will make some changes to how the weights of the model are updated. Instead of using the high-level Optimizer.minimize() function, we will call Optimizer.compute_gradients() and Optimizer.apply_gradients() separately in order to control where they are placed. The Optimizer.compute_gradients() call (the backward pass) is placed on the IPU, while the Optimizer.apply_gradients() call (the all-reduce of gradients and the weight updates) is placed on the host. This is done by using the host_call parameter in IPUEstimatorSpec.
In practice this means that the gradients are streamed from the IPU to the host as soon as they are computed. The workers then start reducing the gradients amongst themselves, allowing the backward pass on the IPUs to overlap with the reductions on the hosts. After a gradient has been reduced across the workers, the corresponding weight update is also done on the host.
The reduction is done using a ring-based collectives implementation with gRPC as the cross-host communication layer.
One benefit of this approach is that any additional optimiser state (such as momentum) is only needed in host memory, so there is no additional IPU memory consumption when using stateful optimisers with this approach.
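The essence of this split, mirroring the complete example in Section 10.1.4, can be sketched as follows; the single dense layer is an arbitrary stand-in for the real model:

import tensorflow.compat.v1 as tf
from tensorflow.python import ipu

def model_fn(features, labels, mode):
  # Stand-in model; the complete example below uses a small convolutional net.
  logits = tf.layers.dense(tf.layers.flatten(features), 10)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                     logits=logits))
  optimizer = tf.train.AdamOptimizer()
  variables = tf.trainable_variables()

  def host_model_fn(*host_gradients):
    # Runs on the host: all-reduce the gradients across the workers and
    # apply the weight updates there.
    return optimizer.apply_gradients(zip(host_gradients, variables))

  # Runs on the IPU: only the backward pass (gradient computation).
  grads_and_vars = optimizer.compute_gradients(loss, var_list=variables)
  gradients = [g for (g, _) in grads_and_vars]

  # The computed gradients are streamed to the host and passed to host_model_fn.
  return ipu.ipu_estimator.IPUEstimatorSpec(mode=mode,
                                            loss=loss,
                                            train_op=tf.identity(loss),
                                            host_call=(host_model_fn, gradients))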
10.1.3. Cluster definition
We use the TFConfigClusterResolver, which reads the TF_CONFIG environment variable to determine the cluster definition.
There are two components of TF_CONFIG: cluster and task.
cluster provides information about the entire cluster, namely the workers and parameter servers in the cluster.
task provides information about the current task.
In this example, the task type is worker and the task index is 0. You could run this example with two workers on the same machine (in different terminals) like this:
$ TF_CONFIG='{"cluster":{"worker":["localhost:3737","localhost:3738"]},"task":{"type":"worker","index":0}}' python distributed_training_example.py
$ TF_CONFIG='{"cluster":{"worker":["localhost:3737","localhost:3738"]},"task":{"type":"worker","index":1}}' python distributed_training_example.py
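With TF_CONFIG set as in the first command above, the resolver exposes the cluster layout and this process's task programmatically (shown here only for illustration; the complete example simply passes the resolver to the strategy):

import tensorflow.compat.v1 as tf

cluster = tf.distribute.cluster_resolver.TFConfigClusterResolver()
print(cluster.cluster_spec().as_dict())  # {'worker': ['localhost:3737', 'localhost:3738']}
print(cluster.task_type, cluster.task_id)  # worker 0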
10.1.4. Complete example
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import argparse
import numpy as np

import tensorflow.compat.v1 as tf

from tensorflow.python import ipu

BATCH_SIZE = 64


def input_fn(mode, input_context=None):  # pylint: disable=unused-argument
  train_data, _ = tf.keras.datasets.mnist.load_data()

  def normalise(image, label):
    image = image.astype(np.float32) / 255.0
    image = np.expand_dims(image, axis=-1)
    label = label.astype(np.int32)
    return image, label

  x_train, y_train = normalise(*train_data)

  def generator():
    return zip(x_train, y_train)

  types = (x_train.dtype, y_train.dtype)
  shapes = (x_train.shape[1:], y_train.shape[1:])
  mnist_dataset = tf.data.Dataset.from_generator(generator, types, shapes)

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)

  mnist_dataset = mnist_dataset.shuffle(len(y_train)) \
      .cache().batch(BATCH_SIZE, drop_remainder=True).repeat()
  return mnist_dataset


def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(8, 3, activation="relu"),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(8, activation="relu"),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=mode == tf.estimator.ModeKeys.TRAIN)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {"logits": logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.compat.v1.losses.Reduction.NONE)(labels,
                                                                      logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    predictions = tf.argmax(input=logits, axis=-1)
    eval_metric_ops = {
        "accuracy":
            tf.compat.v1.metrics.accuracy(labels=labels,
                                          predictions=predictions),
    }
    return tf.estimator.EstimatorSpec(mode,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops)

  variables = model.trainable_variables

  def host_model_fn(*host_gradients):
    # This will allreduce the gradients and update the weights on the host.
    return optimizer.apply_gradients(zip(host_gradients, variables))

  train_op = tf.identity(loss)
  grads_and_vars = optimizer.compute_gradients(loss, var_list=variables)
  gradients = [g for (g, _) in grads_and_vars]
  host_call = (host_model_fn, gradients)

  return ipu.ipu_estimator.IPUEstimatorSpec(mode=mode,
                                            loss=loss,
                                            train_op=train_op,
                                            host_call=host_call)


# Get the cluster configuration from the TF_CONFIG environment variable.
cluster = tf.distribute.cluster_resolver.TFConfigClusterResolver()
# Create strategy that places variables (including momentums) on the host.
strategy = ipu.ipu_multi_worker_strategy.IPUMultiWorkerStrategyV1(
    cluster, variables_on_host=True)

ipu_options = ipu.config.IPUConfig()
ipu_options.auto_select_ipus = 1
ipu_run_config = ipu.ipu_run_config.IPURunConfig(ipu_options=ipu_options)

config = ipu.ipu_run_config.RunConfig(
    session_config=tf.ConfigProto(allow_soft_placement=False),
    ipu_run_config=ipu_run_config,
    train_distribute=strategy,
)

parser = argparse.ArgumentParser()
parser.add_argument("--num-steps", type=int, default=10000)
parser.add_argument("--model-dir")
args = parser.parse_args()

classifier = ipu.ipu_estimator.IPUEstimator(
    config=config,
    model_fn=model_fn,
    model_dir=args.model_dir,
)

# Training progress is logged as INFO, so enable that logging level.
tf.logging.set_verbosity(tf.logging.INFO)

tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn,
                                      max_steps=args.num_steps),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
10.2. Distributed training with Horovod
Distributed training can also be performed using Horovod, which is included in the TensorFlow wheel provided by Graphcore.
The class IPUHorovodStrategyV1 can be used in the same manner as the IPUMultiWorkerStrategyV1.
While the IPUMultiWorkerStrategyV1 uses collective operations over gRPC, the IPUHorovodStrategyV1 uses the collective operations provided by Horovod, based on MPI. Horovod also has built-in cluster discovery, so there is no cluster resolver argument to provide, as there is for the IPUMultiWorkerStrategyV1, and there is no need to start a tf.distribute.Server.
Apart from these differences, the API and semantics should be the same for the IPUHorovodStrategyV1 and the IPUMultiWorkerStrategyV1. In other words, they both provide data-parallel distributed training that keeps the variables in sync on the different workers. During variable initialisation the values are broadcast from the root rank to the other ranks, and during training the gradients are all-reduced as part of the Optimizer.apply_gradients call.
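The basic Horovod set-up, also used in the complete example in Section 10.4, therefore reduces to the following sketch:

from tensorflow.python.ipu import horovod as hvd
from tensorflow.python.ipu.horovod import ipu_horovod_strategy

# Initialise the Horovod runtime; the cluster is discovered via MPI, so no
# cluster resolver or tf.distribute.Server is required.
hvd.init()
strategy = ipu_horovod_strategy.IPUHorovodStrategyV1()

# The total number of workers and this worker's rank, which the complete
# Horovod example uses for dataset sharding.
num_workers = hvd.size()
worker_index = hvd.rank()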
10.3. Launching Horovod training
The mpirun tool can be used to run the distributed training across a cluster. For instance, running distributed training across two processes on the same machine can be done with the following command:
$ mpirun -np 2 -H localhost:2 python distributed_training_horovod_example.py
10.4. Complete Horovod example
Below is a complete example using Horovod, adapted from the example above.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import argparse
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python import ipu
from tensorflow.python.ipu import horovod as hvd
from tensorflow.python.ipu.horovod import ipu_horovod_strategy

BATCH_SIZE = 64


def input_fn(mode):  # pylint: disable=unused-argument
  train_data, _ = tf.keras.datasets.mnist.load_data()

  def normalise(image, label):
    image = image.astype(np.float32) / 255.0
    image = np.expand_dims(image, axis=-1)
    label = label.astype(np.int32)
    return image, label

  x_train, y_train = normalise(*train_data)

  def generator():
    return zip(x_train, y_train)

  types = (x_train.dtype, y_train.dtype)
  shapes = (x_train.shape[1:], y_train.shape[1:])
  mnist_dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  mnist_dataset = mnist_dataset.shard(hvd.size(), hvd.rank())
  mnist_dataset = mnist_dataset.shuffle(len(y_train)) \
      .cache().batch(BATCH_SIZE, drop_remainder=True).repeat()
  return mnist_dataset


def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(8, 3, activation="relu"),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(8, activation="relu"),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=mode == tf.estimator.ModeKeys.TRAIN)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {"logits": logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.compat.v1.losses.Reduction.NONE)(labels,
                                                                      logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)

  variables = model.trainable_variables

  def host_model_fn(*host_gradients):
    # This will allreduce the gradients and update the weights on the host.
    return optimizer.apply_gradients(zip(host_gradients, variables))

  train_op = tf.identity(loss)
  grads_and_vars = optimizer.compute_gradients(loss, var_list=variables)
  gradients = [g for (g, _) in grads_and_vars]
  host_call = (host_model_fn, gradients)

  return ipu.ipu_estimator.IPUEstimatorSpec(mode=mode,
                                            loss=loss,
                                            train_op=train_op,
                                            host_call=host_call)


# Initialise the Horovod runtime.
hvd.init()

# Create a Horovod strategy that places variables on the host.
strategy = ipu_horovod_strategy.IPUHorovodStrategyV1(variables_on_host=True)

ipu_options = ipu.config.IPUConfig()
ipu_options.auto_select_ipus = 1
ipu_run_config = ipu.ipu_run_config.IPURunConfig(ipu_options=ipu_options)

config = ipu.ipu_run_config.RunConfig(
    session_config=tf.ConfigProto(allow_soft_placement=False),
    ipu_run_config=ipu_run_config,
    train_distribute=strategy,
)

parser = argparse.ArgumentParser()
parser.add_argument("--num-steps", type=int, default=10000)
parser.add_argument("--model-dir")
args = parser.parse_args()

classifier = ipu.ipu_estimator.IPUEstimator(
    config=config,
    model_fn=model_fn,
    model_dir=args.model_dir,
)

# Training progress is logged as INFO, so enable that logging level.
tf.logging.set_verbosity(tf.logging.INFO)
classifier.train(input_fn=input_fn, max_steps=args.num_steps)