4. Keras with IPUs
The Graphcore implementation of TensorFlow includes Keras support for IPUs.
Creating a Keras model is no different from creating one for training on other devices. To target the Poplar XLA device, however, the model must be created inside the strategy.scope() of an IPUStrategy.
For a more practical walkthrough, see this tutorial about using Keras on the IPU from the Graphcore tutorials repository.
4.1. Single IPU models
You can train, evaluate or run inference on single-IPU models through the Keras APIs as you would with other accelerators, as long as you create the model inside the scope of an IPUStrategy:
import tensorflow as tf
from tensorflow.python import ipu

# Configure the IPU device.
config = ipu.config.IPUConfig()
config.auto_select_ipus = 1
config.configure_ipu_system()


# Create a simple model.
def create_model():
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(256, activation='relu'),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])


# Create a dataset for the model.
def create_dataset():
  mnist = tf.keras.datasets.mnist

  (x_train, y_train), (_, _) = mnist.load_data()
  x_train = x_train / 255.0

  train_ds = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(10000).batch(32, drop_remainder=True)
  train_ds = train_ds.map(lambda d, l:
                          (tf.cast(d, tf.float32), tf.cast(l, tf.int32)))

  return train_ds.repeat().prefetch(16)


dataset = create_dataset()

# Create a strategy for execution on the IPU.
strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  # Create a Keras model inside the strategy.
  model = create_model()

  # Compile the model for training.
  # The final Dense(10) layer outputs raw logits, so the loss expects logits.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.RMSprop(),
      metrics=["accuracy"],
  )

  model.fit(dataset, epochs=2, steps_per_epoch=100)
4.2. Using steps_per_execution
To reduce Python overhead and maximize the performance of your model, pass the steps_per_execution argument to the compile() method. This argument sets the number of batches to process sequentially in a single execution. Increasing this number can improve accelerator utilization.
Note
To achieve the best performance, steps_per_execution needs to be set before using fit(), evaluate() and predict(), even if no training is performed.
See the documentation for the compile method for full details.
The example below highlights the usage of steps_per_execution:
import tensorflow as tf
from tensorflow.python import ipu

# Configure the IPU device.
config = ipu.config.IPUConfig()
config.auto_select_ipus = 1
config.configure_ipu_system()


# Create a simple model.
def create_model():
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(256, activation='relu'),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])


# Create a dataset for the model.
def create_dataset():
  mnist = tf.keras.datasets.mnist

  (x_train, y_train), (_, _) = mnist.load_data()
  x_train = x_train / 255.0

  train_ds = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(10000).batch(32, drop_remainder=True)
  train_ds = train_ds.map(lambda d, l:
                          (tf.cast(d, tf.float32), tf.cast(l, tf.int32)))

  return train_ds.repeat().prefetch(16)


dataset = create_dataset()

# Create a strategy for execution on the IPU.
strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  # Create a Keras model inside the strategy.
  model = create_model()

  # Compile the model for training.
  # The final Dense(10) layer outputs raw logits, so the loss expects logits.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.RMSprop(),
      metrics=["accuracy"],
      # Anything between 2 and `steps_per_epoch` could help here.
      steps_per_execution=50,
  )

  model.fit(dataset, epochs=2, steps_per_epoch=100)
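To make the effect of steps_per_execution concrete, the following pure-Python sketch counts how many device executions (and therefore host round trips) one epoch takes. It models only the arithmetic and calls no IPU API. The helper name executions_per_epoch is hypothetical, and it assumes steps_per_epoch is an exact multiple of steps_per_execution, rejecting other values rather than guessing how IPU Keras handles them.

```python
def executions_per_epoch(steps_per_epoch: int, steps_per_execution: int) -> int:
  """Return how many device executions one epoch needs (hypothetical helper).

  Simplifying assumption: steps_per_epoch must be an exact multiple of
  steps_per_execution; other combinations are rejected here.
  """
  if steps_per_epoch < 1 or steps_per_execution < 1:
    raise ValueError("step counts must be positive")
  if steps_per_epoch % steps_per_execution != 0:
    raise ValueError("steps_per_epoch must be a multiple of steps_per_execution")
  return steps_per_epoch // steps_per_execution


# Values from the example above: 100 steps per epoch.
print(executions_per_epoch(100, 50))  # 2 executions of 50 batches each
print(executions_per_epoch(100, 1))   # 100 executions of 1 batch each
```

With steps_per_execution=50, control returns to Python twice per epoch instead of 100 times, which is where the reduction in Python overhead comes from.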
4.3. Gradient accumulation
When training, gradient accumulation allows us to simulate bigger batch sizes. This is achieved by accumulating the gradients over several batches and then performing a single weight update.
For example, if each step of a model uses a batch size of 16 and we use a gradient accumulation factor of 4, this simulates an input batch of size 64.
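The claim that accumulating 4 batches of 16 behaves like one batch of 64 can be checked numerically. The sketch below uses a toy one-parameter linear model with a mean squared error loss, written in plain Python rather than with IPU TensorFlow. It assumes the accumulated gradients are averaged over the accumulation steps before the update, which makes the result match the gradient of a 64-sample batch exactly (up to rounding) when every micro-batch has the same size. How IPU TensorFlow combines accumulated gradients internally is not modelled here.

```python
import math
import random


def grad_mse(w, xs, ys):
  """Gradient with respect to w of the mean squared error of y ~ w * x."""
  return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(64)]
ys = [3.0 * x + random.gauss(0.0, 0.1) for x in xs]
w = 0.5

# Gradient of a single batch of 64 samples.
full_batch_grad = grad_mse(w, xs, ys)

# Gradient accumulation: 4 micro-batches of 16 samples each. Gradients are
# summed across micro-batches and averaged before one weight update
# (assumption: averaging, not plain summation).
micro_batch_size, accumulation_steps = 16, 4
accumulated = 0.0
for step in range(accumulation_steps):
  start = step * micro_batch_size
  end = start + micro_batch_size
  accumulated += grad_mse(w, xs[start:end], ys[start:end])
accumulated_grad = accumulated / accumulation_steps

# A single weight update using the accumulated gradient.
learning_rate = 0.1
w_new = w - learning_rate * accumulated_grad

print(math.isclose(full_batch_grad, accumulated_grad,
                   rel_tol=1e-9, abs_tol=1e-12))  # True
```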
Gradient accumulation can be enabled for Keras models created inside an IPUStrategy by calling the set_gradient_accumulation_options() method, which is provided by FunctionalExtension for Functional Keras models and by SequentialExtension for Sequential Keras models. See the respective method documentation for more details.
Note
When using data parallelism, the steps_per_execution value the model was compiled with must be an integer multiple of gradient_accumulation_steps_per_replica multiplied by the number of replicas in the model. Data parallelism is discussed in the Automatic data parallelism section below.
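The rule in the note above can be expressed as a small check. This is a hypothetical pure-Python helper, not part of the IPU TensorFlow API; it simply tests whether steps_per_execution is an integer multiple of gradient_accumulation_steps_per_replica multiplied by the number of replicas.

```python
def is_valid_steps_per_execution(steps_per_execution: int,
                                 gradient_accumulation_steps_per_replica: int,
                                 num_replicas: int = 1) -> bool:
  """Hypothetical check of the steps_per_execution constraint.

  Returns True when steps_per_execution is an integer multiple of
  gradient_accumulation_steps_per_replica * num_replicas.
  """
  required_multiple = gradient_accumulation_steps_per_replica * num_replicas
  if required_multiple < 1 or steps_per_execution < 1:
    raise ValueError("all arguments must be positive")
  return steps_per_execution % required_multiple == 0


# 50 steps per execution with 10 accumulation steps on one replica: valid.
print(is_valid_steps_per_execution(50, 10, num_replicas=1))  # True
# The same values on 2 replicas need a multiple of 20, so 50 is invalid.
print(is_valid_steps_per_execution(50, 10, num_replicas=2))  # False
```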
Note
Not all operations are compatible with gradient accumulation.
The example below highlights the usage of set_gradient_accumulation_options():
import tensorflow as tf
from tensorflow.python import ipu

# Configure the IPU device.
config = ipu.config.IPUConfig()
config.auto_select_ipus = 1
config.configure_ipu_system()


# Create a simple model.
def create_model():
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(256, activation='relu'),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])


# Create a dataset for the model.
def create_dataset():
  mnist = tf.keras.datasets.mnist

  (x_train, y_train), (_, _) = mnist.load_data()
  x_train = x_train / 255.0

  train_ds = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(10000).batch(32, drop_remainder=True)
  train_ds = train_ds.map(lambda d, l:
                          (tf.cast(d, tf.float32), tf.cast(l, tf.int32)))

  return train_ds.repeat().prefetch(16)


dataset = create_dataset()

# Create a strategy for execution on the IPU.
strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  # Create a Keras model inside the strategy.
  model = create_model()

  # Compile the model for training.
  # The final Dense(10) layer outputs raw logits, so the loss expects logits.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.RMSprop(),
      metrics=["accuracy"],
      steps_per_execution=50,
  )

  # steps_per_execution (50) is a multiple of the accumulation steps (10).
  model.set_gradient_accumulation_options(
      gradient_accumulation_steps_per_replica=10)

  model.fit(dataset, epochs=2, steps_per_epoch=100)
4.4. Model parallelism
The models described so far occupy a single IPU device; however, some models might require the model layers to be split across multiple IPU devices to achieve high compute efficiency.
One method to achieve model parallelism is called pipelining, where the model layers are assigned to pipeline stages. Each pipeline stage can be assigned to a different device and different devices can execute in parallel.
How you pipeline your model depends on whether it is a Sequential or a Functional model.
4.4.1. Sequential model
To enable IPU pipelining for a Sequential model (an instance of tensorflow.keras.Sequential), pass a list of per-layer pipeline stage assignments to the set_pipeline_stage_assignment() method of the model.
For example, a simple four-layer Sequential model could be assigned to two different pipeline stages as follows:
import tensorflow as tf
from tensorflow.python import ipu

strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(8),  # Pipeline stage 0.
      tf.keras.layers.Dense(16),  # Pipeline stage 0.
      tf.keras.layers.Dense(16),  # Pipeline stage 1.
      tf.keras.layers.Dense(1),  # Pipeline stage 1.
  ])

  model.set_pipeline_stage_assignment([0, 0, 1, 1])
You can confirm which layers are assigned to which stages using the print_pipeline_stage_assignment_summary() method of the model.
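For deeper Sequential models, writing the assignment list by hand becomes tedious. The following hypothetical helper (plain Python, not part of IPU TensorFlow) builds a per-layer list that splits num_layers layers into num_stages contiguous, roughly equal pipeline stages, producing the same kind of list passed to set_pipeline_stage_assignment() above. Splitting by layer count is only a naive heuristic; balancing stages by compute or memory usually matters more.

```python
def even_stage_assignment(num_layers: int, num_stages: int) -> list:
  """Assign layers to contiguous pipeline stages of (nearly) equal size.

  Hypothetical helper: earlier stages receive one extra layer when
  num_layers is not divisible by num_stages.
  """
  if num_stages < 1 or num_layers < num_stages:
    raise ValueError("need at least one layer per stage")
  base, extra = divmod(num_layers, num_stages)
  assignment = []
  for stage in range(num_stages):
    count = base + (1 if stage < extra else 0)
    assignment.extend([stage] * count)
  return assignment


print(even_stage_assignment(4, 2))  # [0, 0, 1, 1]
print(even_stage_assignment(5, 2))  # [0, 0, 0, 1, 1]
```

For the four-layer example above, even_stage_assignment(4, 2) reproduces the [0, 0, 1, 1] list.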
4.4.2. Functional model
There are two ways to enable IPU pipelining for a Functional model (an instance of tensorflow.keras.Model), depending on whether you are pipelining a model you are writing yourself or an existing model.
Pipelining a model you are writing yourself
To pipeline a Functional model you are writing yourself, each layer call must happen within the scope of an ipu.keras.PipelineStage context.
For example, a simple four-layer Functional model could be assigned to two different pipeline stages as follows:
import tensorflow as tf
from tensorflow.python import ipu

strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  input_layer = tf.keras.layers.Input((28, 28))

  with ipu.keras.PipelineStage(0):
    x = tf.keras.layers.Dense(8)(input_layer)
    x = tf.keras.layers.Dense(16)(x)

  with ipu.keras.PipelineStage(1):
    x = tf.keras.layers.Dense(16)(x)
    x = tf.keras.layers.Dense(1)(x)

  model = tf.keras.Model(inputs=input_layer, outputs=x)
Pipelining an existing functional model
To pipeline an existing Functional model, you can use get_pipeline_stage_assignment().
Each layer invocation in the model has an associated FunctionalLayerPipelineStageAssignment object, which indicates the pipeline stage that invocation is assigned to. get_pipeline_stage_assignment() returns a list of these stage assignments, which you can inspect and modify. Note that the list is in post-order, which means the assignments are returned in the order in which they will be executed.
Once you have finished modifying the stage assignments, use set_pipeline_stage_assignment() to set them on the model.
For example, a naive way of pipelining ResNet50 would be to assign everything up to and including the "conv4_block2_add" layer invocation to the first stage, and everything else to the second stage, as follows:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.python import ipu

strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  model = ResNet50(weights='imagenet')

  # Get the individual assignments - note that they are returned in post-order.
  assignments = model.get_pipeline_stage_assignment()

  # Iterate over them and set their pipeline stages.
  stage_id = 0
  for assignment in assignments:
    assignment.pipeline_stage = stage_id
    # Split the model on the `conv4_block2_add` layer.
    if assignment.layer.name.startswith("conv4_block2_add"):
      stage_id = 1

  # Set the assignments to the model.
  model.set_pipeline_stage_assignment(assignments)

  model.print_pipeline_stage_assignment_summary()
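The splitting loop above does not depend on ResNet50 itself, so its logic can be checked without an IPU. In the sketch below, FakeLayer and FakeAssignment are hypothetical stand-ins that only mimic the two attributes the loop uses (layer.name and pipeline_stage); they are not the real FunctionalLayerPipelineStageAssignment class, and the layer names form a made-up, shortened post-order list.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeLayer:
  # Hypothetical stand-in exposing only the layer name.
  name: str


@dataclass
class FakeAssignment:
  # Hypothetical stand-in for a per-invocation stage assignment.
  layer: FakeLayer
  pipeline_stage: Optional[int] = None


def split_after(assignments: List[FakeAssignment], split_prefix: str) -> None:
  """Stage 0 up to and including the first split layer, stage 1 afterwards."""
  stage_id = 0
  for assignment in assignments:
    assignment.pipeline_stage = stage_id
    if assignment.layer.name.startswith(split_prefix):
      stage_id = 1


# A shortened, made-up post-order list of layer invocations.
names = ["conv1_conv", "conv4_block1_add", "conv4_block2_add",
         "conv4_block3_add", "predictions"]
assignments = [FakeAssignment(FakeLayer(n)) for n in names]
split_after(assignments, "conv4_block2_add")
print([a.pipeline_stage for a in assignments])  # [0, 0, 0, 1, 1]
```

Because the stage is assigned before the name check, the split layer itself stays in stage 0, and because the list is in post-order, every later invocation lands in stage 1.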
Note
You can use print_pipeline_stage_assignment_summary() to print the pipeline stage assignments of the model's layer invocations.
Note
This method of pipelining can also be used with Functional models you are writing yourself, as well as with Sequential models using the SequentialExtension equivalents.
4.5. Automatic data parallelism
IPU TensorFlow supports automatic data parallelism when multiple IPU devices are configured with the system. Automatic data parallelism is achieved by model replication across available IPU devices. The number of times the model is replicated is called the replication factor; higher replication factors allow higher data throughput.
When replicating, gradients are reduced across replicas during training, which has implications for gradient accumulation. For a non-replicated model, the effective batch size is the product of the dataset batch size and the number of gradient accumulation steps. With a replication factor greater than one, the effective batch size is additionally scaled by the replication factor according to the following formula:
effective_batch_size = dataset_batch_size * gradient_accumulation_steps_per_replica * num_replicas
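The formula translates directly into a small function. The helper name is hypothetical; the first call reproduces the 16 × 4 = 64 example from the gradient accumulation section (a single replica), and the second shows how replication multiplies the result.

```python
def effective_batch_size(dataset_batch_size: int,
                         gradient_accumulation_steps_per_replica: int = 1,
                         num_replicas: int = 1) -> int:
  """Number of samples contributing to each weight update."""
  return (dataset_batch_size * gradient_accumulation_steps_per_replica
          * num_replicas)


print(effective_batch_size(16, 4, num_replicas=1))   # 64
print(effective_batch_size(32, 10, num_replicas=4))  # 1280
```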
4.6. Asynchronous callbacks
IPU TensorFlow supports the use of Callback objects with the Keras APIs; however, there is an important difference to note when specifying steps_per_execution. In IPU TensorFlow, if steps_per_execution is specified for your model, then per-batch callback functions will only be invoked every steps_per_execution steps, which can delay access to results.
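A simplified way to picture this: if the device runs steps_per_execution batches before handing control back to the host, per-batch callbacks can only run at those hand-back points. The pure-Python sketch below lists the number of completed batches at each point where a synchronous per-batch callback could observe results. It illustrates that assumption only, not the actual IPU TensorFlow callback scheduling, and it assumes steps_per_epoch is a multiple of steps_per_execution; callback_points is a hypothetical name.

```python
def callback_points(steps_per_epoch: int, steps_per_execution: int) -> list:
  """Completed-batch counts at which per-batch callbacks could see results.

  Simplified model: results only reach the host at the end of each
  execution of steps_per_execution batches.
  """
  if steps_per_execution < 1 or steps_per_epoch % steps_per_execution != 0:
    raise ValueError("steps_per_epoch must be a multiple of steps_per_execution")
  return list(range(steps_per_execution, steps_per_epoch + 1,
                    steps_per_execution))


print(callback_points(100, 50))  # [50, 100]
print(callback_points(8, 1))     # [1, 2, 3, 4, 5, 6, 7, 8]
```

Under this model, with steps_per_execution=50 the result of the first batch only becomes visible to a callback once batch 50 has finished.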
However, IPU TensorFlow also supports asynchronous callbacks by providing a polling mechanism that allows results to be accessed as early as possible. Asynchronous callbacks can be enabled by calling set_asynchronous_callbacks() with True on your Sequential or Functional Keras model.
4.7. Configuring infeeds and outfeeds
Keras models created inside an IPUStrategy scope automatically create IPUInfeedQueue and IPUOutfeedQueue data queues for efficiently feeding data to and from the IPU devices when using fit(), evaluate() and predict().
Instances of IPUInfeedQueue and IPUOutfeedQueue can be created with optional arguments which can affect the performance of the model.
To configure the IPUInfeedQueue, use set_infeed_queue_options() on your Sequential or Functional Keras model.
To configure the IPUOutfeedQueue, use set_outfeed_queue_options() on your Sequential or Functional Keras model.
For example, the prefetch_depth parameter of the IPUInfeedQueue and the buffer_depth parameter of the IPUOutfeedQueue can be configured as follows:
import tensorflow as tf
from tensorflow.python import ipu

# Configure the IPU device.
config = ipu.config.IPUConfig()
config.auto_select_ipus = 1
config.configure_ipu_system()


# Create a simple model.
def create_model():
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(256, activation='relu'),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])


# Create a strategy for execution on the IPU.
strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():

  model = create_model()

  # Set the infeed and outfeed options.
  model.set_infeed_queue_options(prefetch_depth=2)
  model.set_outfeed_queue_options(buffer_depth=2)
4.8. Porting models from TensorFlow 2.1
Previously, IPU TensorFlow included IPU-specific Keras model classes for Functional and Sequential models. These classes no longer exist and must be replaced with their standard Keras counterparts.
Specifically, use of the old IPUSequential (or tensorflow.python.ipu.keras.Sequential) class should be changed to tensorflow.keras.Sequential, and use of the old IPUModel (or tensorflow.python.ipu.keras.Model) class should be changed to tensorflow.keras.Model.
Any IPU-specific arguments to the old IPU-specific classes (such as gradient_accumulation_count) should also be removed, and the behaviour they specify should be achieved by the means outlined in this document.
For reference, the following table details APIs that have been removed and their replacements:
| TF2.1 | TF2.4 |
|  | Removed, use |
|  | Removed, use |
|  | Removed, use |
|  | Removed, use |
|  | Removed, set via |
|  | Removed, set via |
|  | Removed |
|  | Set via |
As an example, the following snippets show equivalent TF2.1 and TF2.4 code for creating and fitting a pipelined Sequential Keras model.
4.8.1. TF2.1
strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  # Using IPU-specific PipelineSequential model.
  # IPU-specific arguments passed into model constructor.
  model = ipu.keras.PipelineSequential(
      [tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(256, activation='relu'),
       tf.keras.layers.Dense(128, activation='relu'),
       tf.keras.layers.Dense(10)],
      gradient_accumulation_count=16,
      device_mapping=[0, 0, 1, 1])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.RMSprop()
  )
  model.fit(dataset, epochs=2, steps_per_epoch=128)
4.8.2. TF2.4
strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  # Using standard Keras Sequential model.
  model = tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(256, activation='relu'),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  # IPU-specific arguments passed into separate configuration methods.
  model.set_pipeline_stage_assignment([0, 0, 1, 1])

  # Replication factor is 1 in this example.
  model.set_pipelining_options(gradient_accumulation_steps_per_replica=16)

  # steps_per_execution specified to improve performance.
  model.compile(steps_per_execution=256,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.RMSprop())

  model.fit(dataset, epochs=2, steps_per_epoch=128)
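When porting many models, the mapping from the two constructor arguments in the TF2.1 snippet to the configuration calls in the TF2.4 snippet can be captured in a helper. The sketch below is hypothetical plain Python: it only covers gradient_accumulation_count and device_mapping as used in these snippets, assumes a replication factor of 1 (as in the TF2.4 example), and rejects any other argument rather than guessing its replacement.

```python
def port_pipeline_kwargs(**old_kwargs) -> dict:
  """Translate TF2.1 PipelineSequential kwargs into TF2.4 method calls.

  Hypothetical helper. Returns a mapping from method name to the value it
  should be called with. Assumes a replication factor of 1.
  """
  supported = {"gradient_accumulation_count", "device_mapping"}
  unknown = set(old_kwargs) - supported
  if unknown:
    raise ValueError(f"no known translation for: {sorted(unknown)}")

  calls = {}
  if "device_mapping" in old_kwargs:
    calls["set_pipeline_stage_assignment"] = list(old_kwargs["device_mapping"])
  if "gradient_accumulation_count" in old_kwargs:
    calls["set_pipelining_options"] = {
        "gradient_accumulation_steps_per_replica":
            old_kwargs["gradient_accumulation_count"]}
  return calls


calls = port_pipeline_kwargs(gradient_accumulation_count=16,
                             device_mapping=[0, 0, 1, 1])
print(calls["set_pipeline_stage_assignment"])  # [0, 0, 1, 1]
print(calls["set_pipelining_options"])
# {'gradient_accumulation_steps_per_replica': 16}
```

Inside the IPUStrategy scope, the entries would then be applied as model.set_pipeline_stage_assignment(calls["set_pipeline_stage_assignment"]) and model.set_pipelining_options(**calls["set_pipelining_options"]), matching the TF2.4 snippet.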
4.9. Implementation details
When a standard TensorFlow Keras model is instantiated inside the scope of an IPUStrategy instance, it is dynamically injected with additional IPU-specific methods.
This is done through the relevant IPU Keras extension classes: for tensorflow.keras.Sequential, the IPU-specific extensions are in SequentialExtension, and for tensorflow.keras.Model they are in FunctionalExtension.
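The exact mechanism IPU TensorFlow uses to inject these methods is not described here. As a generic illustration only, the pure-Python sketch below shows one common way to add methods to an already-created object at runtime: building a new class that inherits from both an extension class and the object's original class, then reassigning the object's __class__. All class and function names are hypothetical stand-ins, not the real SequentialExtension or Keras classes.

```python
class ToySequential:
  """Hypothetical stand-in for a standard Keras Sequential model."""

  def __init__(self, layers):
    self.layers = list(layers)


class ToyPipelineExtension:
  """Hypothetical stand-in for an IPU-specific extension class."""

  def set_pipeline_stage_assignment(self, stages):
    if len(stages) != len(self.layers):
      raise ValueError("need one stage per layer")
    self.pipeline_stages = list(stages)


def inject_extension(obj, extension_cls):
  """Give obj the extension's methods while keeping its original type."""
  original_cls = type(obj)
  # The extension comes first in the bases, so its methods take precedence.
  extended_cls = type(original_cls.__name__, (extension_cls, original_cls), {})
  obj.__class__ = extended_cls
  return obj


model = inject_extension(ToySequential(["d1", "d2", "d3", "d4"]),
                         ToyPipelineExtension)
model.set_pipeline_stage_assignment([0, 0, 1, 1])
print(isinstance(model, ToySequential), model.pipeline_stages)
# True [0, 0, 1, 1]
```

Because the object keeps its original class as a base, existing isinstance checks and methods continue to work, which is the property any such injection approach needs to preserve.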