7. Example using IPUEstimator

This example shows how to use the IPUEstimator to train a simple CNN on the CIFAR-10 dataset. The IPUEstimator handles XLA compilation automatically, so the model_fn must not be compiled manually with ipu_compiler.

  1import argparse
  2import time
  3
  4import tensorflow.compat.v1 as tf
  5
  6from tensorflow.keras import Sequential
  7from tensorflow.keras.datasets import cifar10
  8from tensorflow.keras.layers import Conv2D, MaxPooling2D
  9from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
 10from tensorflow.python import ipu
 11
 12NUM_CLASSES = 10
 13
 14
 15def model_fn(features, labels, mode, params):
 16  """A simple CNN based on https://keras.io/examples/cifar10_cnn/"""
 17
 18  model = Sequential()
 19  model.add(Conv2D(16, (3, 3), padding="same"))
 20  model.add(Activation("relu"))
 21  model.add(Conv2D(16, (3, 3)))
 22  model.add(Activation("relu"))
 23  model.add(MaxPooling2D(pool_size=(2, 2)))
 24  model.add(Dropout(0.25))
 25
 26  model.add(Conv2D(32, (3, 3), padding="same"))
 27  model.add(Activation("relu"))
 28  model.add(Conv2D(32, (3, 3)))
 29  model.add(Activation("relu"))
 30  model.add(MaxPooling2D(pool_size=(2, 2)))
 31  model.add(Dropout(0.25))
 32
 33  model.add(Flatten())
 34  model.add(Dense(256))
 35  model.add(Activation("relu"))
 36  model.add(Dropout(0.5))
 37  model.add(Dense(NUM_CLASSES))
 38
 39  logits = model(features, training=mode == tf.estimator.ModeKeys.TRAIN)
 40
 41  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
 42
 43  if mode == tf.estimator.ModeKeys.EVAL:
 44    predictions = tf.argmax(input=logits, axis=-1)
 45    eval_metric_ops = {
 46        "accuracy": tf.metrics.accuracy(labels=labels,
 47                                        predictions=predictions),
 48    }
 49    return tf.estimator.EstimatorSpec(mode,
 50                                      loss=loss,
 51                                      eval_metric_ops=eval_metric_ops)
 52
 53  if mode == tf.estimator.ModeKeys.TRAIN:
 54    optimizer = tf.train.GradientDescentOptimizer(params["learning_rate"])
 55    if params["replicas"] > 1:
 56      optimizer = ipu.cross_replica_optimizer.CrossReplicaOptimizer(optimizer)
 57    train_op = optimizer.minimize(loss=loss)
 58    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
 59
 60  raise NotImplementedError(mode)
 61
 62
 63def parse_args():
 64  parser = argparse.ArgumentParser()
 65
 66  parser.add_argument(
 67      "--test-only",
 68      action="store_true",
 69      help="Skip training and test using latest checkpoint from model_dir.")
 70
 71  parser.add_argument("--batch-size",
 72                      type=int,
 73                      default=32,
 74                      help="The batch size.")
 75
 76  parser.add_argument(
 77      "--iterations-per-loop",
 78      type=int,
 79      default=100,
 80      help="The number of iterations (batches) per loop on IPU.")
 81
 82  parser.add_argument("--log-interval",
 83                      type=int,
 84                      default=10,
 85                      help="Interval at which to log progress.")
 86
 87  parser.add_argument("--summary-interval",
 88                      type=int,
 89                      default=1,
 90                      help="Interval at which to write summaries.")
 91
 92  parser.add_argument("--training-steps",
 93                      type=int,
 94                      default=200000,
 95                      help="Total number of training steps.")
 96
 97  parser.add_argument(
 98      "--learning-rate",
 99      type=float,
100      default=0.01,
101      help="The learning rate used with stochastic gradient descent.")
102
103  parser.add_argument(
104      "--replicas",
105      type=int,
106      default=1,
107      help="The replication factor. Increases the number of IPUs "
108      "used and the effective batch size by this factor.")
109
110  parser.add_argument(
111      "--model-dir",
112      help="Directory where checkpoints and summaries are stored.")
113
114  return parser.parse_args()
115
116
def create_ipu_estimator(args):
  """Construct an IPUEstimator wrapping `model_fn` from the parsed arguments.

  Automatically selects `args.replicas` IPUs, runs `args.iterations_per_loop`
  iterations per device loop, and stores checkpoints and summaries in
  `args.model_dir`. The learning rate and replication factor are forwarded to
  `model_fn` through `params`.
  """
  ipu_options = ipu.config.IPUConfig()
  ipu_options.auto_select_ipus = args.replicas

  run_config = ipu.ipu_run_config.RunConfig(
      ipu_run_config=ipu.ipu_run_config.IPURunConfig(
          iterations_per_loop=args.iterations_per_loop,
          num_replicas=args.replicas,
          ipu_options=ipu_options,
      ),
      log_step_count_steps=args.log_interval,
      save_summary_steps=args.summary_interval,
      model_dir=args.model_dir,
  )

  # Read inside model_fn as params["learning_rate"] and params["replicas"].
  hyperparams = {
      "learning_rate": args.learning_rate,
      "replicas": args.replicas,
  }
  return ipu.ipu_estimator.IPUEstimator(config=run_config,
                                        model_fn=model_fn,
                                        params=hyperparams)
142
143
def train(ipu_estimator, args, x_train, y_train):
  """Train on the IPU for `args.training_steps` steps and report throughput.

  Checkpoints are saved to `args.model_dir` by the estimator. The reported
  wall-clock time covers the whole `train()` call.
  """

  def input_fn():
    # Dataset.from_tensor_slices() would embed the training arrays in the
    # graph as constants, producing an impractically large graph. Stream the
    # examples from a Python generator instead, with prefetching and caching
    # to keep the input pipeline fast.
    num_examples = len(x_train)
    dataset = tf.data.Dataset.from_generator(
        lambda: zip(x_train, y_train),
        (x_train.dtype, y_train.dtype),
        (x_train.shape[1:], y_train.shape[1:]),
    )
    return (dataset.prefetch(num_examples)
            .cache()
            .repeat()
            .shuffle(num_examples)
            .batch(args.batch_size, drop_remainder=True))

  # The estimator reports training progress at INFO level.
  tf.logging.set_verbosity(tf.logging.INFO)

  start = time.time()
  ipu_estimator.train(input_fn=input_fn, steps=args.training_steps)
  elapsed = time.time() - start

  # Each step processes one batch on every replica. The timing presumably
  # also includes graph compilation, so this is a conservative figure.
  images_seen = args.training_steps * args.batch_size * args.replicas
  print("Took {:.2f} minutes, i.e. {:.0f} images per second".format(
      elapsed / 60, images_seen / elapsed))
178
179
def calc_batch_size(num_examples, batches_per_loop, batch_size):
  """Reduce the batch size if needed to cover all examples without a remainder.

  Returns the largest batch size `b <= batch_size` such that `num_examples`
  is an exact multiple of `b * batches_per_loop`, so that whole loops of
  whole batches visit every example exactly once.

  Args:
    num_examples: Total number of examples to cover.
    batches_per_loop: Number of batches consumed per loop; must be positive
      and divide `num_examples`.
    batch_size: The preferred (maximum) batch size; must be positive.

  Returns:
    The adjusted batch size, between 1 and `batch_size` inclusive.

  Raises:
    ValueError: If `batch_size` or `batches_per_loop` is not positive, or if
      `num_examples` is not a multiple of `batches_per_loop`.
  """
  # Explicit exceptions instead of `assert`, which `python -O` strips.
  if batch_size <= 0:
    raise ValueError("batch_size must be positive, got {}".format(batch_size))
  if batches_per_loop <= 0:
    raise ValueError(
        "batches_per_loop must be positive, got {}".format(batches_per_loop))
  if num_examples % batches_per_loop != 0:
    raise ValueError(
        "num_examples ({}) must be a multiple of batches_per_loop ({})".format(
            num_examples, batches_per_loop))

  # Because batches_per_loop divides num_examples, num_examples is a multiple
  # of b * batches_per_loop exactly when b divides examples_per_batch_slot.
  # b == 1 always qualifies, so the loop terminates with batch_size >= 1.
  examples_per_batch_slot = num_examples // batches_per_loop
  while examples_per_batch_slot % batch_size != 0:
    batch_size -= 1
  return batch_size
187
188
def test(ipu_estimator, args, x_test, y_test):
  """Evaluate on the IPU using the final checkpoint in `args.model_dir`.

  The batch size may be reduced (via `calc_batch_size`) so that whole loops
  cover the test set exactly. The resulting loss and accuracy are printed.
  """
  num_test_examples = len(x_test)
  batches_per_loop = args.replicas * args.iterations_per_loop

  test_batch_size = calc_batch_size(num_test_examples, batches_per_loop,
                                    args.batch_size)
  if test_batch_size != args.batch_size:
    print("Test batch size changed to {}.".format(test_batch_size))

  def input_fn():
    return tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
        test_batch_size, drop_remainder=True)

  # Each step consumes one batch on every replica.
  steps = num_test_examples // (test_batch_size * args.replicas)
  metrics = ipu_estimator.evaluate(input_fn=input_fn, steps=steps)
  test_loss, test_accuracy = metrics["loss"], metrics["accuracy"]

  print("Test loss: {:g}".format(test_loss))
  print("Test accuracy: {:.2f}%".format(100 * test_accuracy))
214
215
def main():
  """Entry point: optionally train on CIFAR-10, then evaluate on its test set.

  Raises:
    ValueError: If replicas * iterations_per_loop does not evenly divide the
      number of test examples.
  """

  def normalise(images, labels):
    # Pixels are divided by 255 (presumably uint8 inputs mapped to [0, 1]);
    # labels become int32.
    return images.astype("float32") / 255.0, labels.astype("int32")

  args = parse_args()
  train_data, test_data = cifar10.load_data()

  # Evaluation runs in whole loops, so the test set must split evenly.
  num_test_examples = len(test_data[0])
  batches_per_loop = args.replicas * args.iterations_per_loop
  if num_test_examples % batches_per_loop:
    raise ValueError(
        "replicas * iterations_per_loop ({} * {}) must evenly divide the "
        "number of test examples ({})".format(args.replicas,
                                               args.iterations_per_loop,
                                               num_test_examples))

  ipu_estimator = create_ipu_estimator(args)

  if not args.test_only:
    print("Training...")
    train(ipu_estimator, args, *normalise(*train_data))

  print("Testing...")
  test(ipu_estimator, args, *normalise(*test_data))


if __name__ == "__main__":
  main()