TPU unit with 4 cores

Last week, we talked about training an image classifier on the CIFAR-10 dataset using Google Colab on a Tesla K80 GPU in the cloud. This time, we will instead carry out the classifier training on a Tensor Processing Unit (TPU).

As Google puts it: "Because training and running deep learning models can be computationally demanding, we built the Tensor Processing Unit (TPU), an ASIC designed from the ground up for machine learning that powers several of our major products, including Translate, Photos, Search, Assistant, and Gmail."

TPUs were recently added to the Google Colab portfolio, making it even more attractive for quick-and-dirty machine learning projects when your own local processing units are just not fast enough. While the Tesla K80 available in Google Colab delivers a respectable 1.87 TFLOPS and has 12 GB of RAM, the TPUv2 available from within Google Colab comes with a whopping 180 TFLOPS, give or take. It also comes with 64 GB of High Bandwidth Memory (HBM).

Enabling TPU support in the notebook

To try out the TPU on a concrete project, we will work with a Colab notebook in which a Keras model is trained to classify the CIFAR-10 dataset. It can be found HERE.

If you would just like to execute the TPU-compatible notebook, you can find it HERE. Otherwise, just follow the next simple steps to add TPU support to an existing notebook.

Enabling TPU support for the notebook is really straightforward. First, let’s change the runtime settings:


And choose TPU as the hardware accelerator:


Code adjustments

We also have to make minor adjustments to the Python code in the notebook. We add a new cell anywhere in the notebook in which we check that the TPU devices are properly recognized in the environment:

import os
import pprint
import tensorflow as tf

if 'COLAB_TPU_ADDR' not in os.environ:
  print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')
else:
  tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  print('TPU address is', tpu_address)

  # Connect to the TPU worker and ask it which devices it exposes.
  with tf.Session(tpu_address) as session:
    devices = session.list_devices()

  print('TPU devices:')
  pprint.pprint(devices)

This should output a list of the 8 TPU devices available in our Colab environment. To run the tf.keras model on a TPU, we have to convert it to a TPU model using the tf.contrib.tpu.keras_to_tpu_model function. Luckily, the function takes care of everything for us, leaving us with only a couple of lines of boilerplate code.

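# This address identifies the TPU we'll use when configuring TensorFlow.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']

# Convert the Keras model defined earlier in the notebook (assumed to be named
# resnet_model) using the standard TF 1.x contrib distribution strategy.
resnet_model = tf.contrib.tpu.keras_to_tpu_model(
    resnet_model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))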
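
Both snippets above build the gRPC address from the COLAB_TPU_ADDR environment variable. As a minimal sketch, that lookup and the device count could be factored into two small helpers. The names colab_tpu_address and count_tpu_cores are hypothetical, the address in the example is made up, and the device type strings are assumed to match what session.list_devices() reports.

```python
import os


def colab_tpu_address(environ=None):
    """Return the gRPC address of the Colab TPU, or None without a TPU runtime."""
    environ = os.environ if environ is None else environ
    addr = environ.get('COLAB_TPU_ADDR')
    return 'grpc://' + addr if addr else None


def count_tpu_cores(device_types):
    """Count the entries whose device type is exactly 'TPU' (one per core)."""
    return sum(1 for device_type in device_types if device_type == 'TPU')


# Illustrative values only: the address and the device list are made up.
fake_env = {'COLAB_TPU_ADDR': '10.0.0.2:8470'}
print(colab_tpu_address(fake_env))  # grpc://10.0.0.2:8470
print(colab_tpu_address({}))        # None
print(count_tpu_cores(['CPU', 'XLA_CPU'] + ['TPU'] * 8 + ['TPU_SYSTEM']))  # 8
```

Returning None instead of raising lets the calling cell decide whether to print an error or fall back to CPU or GPU.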

In case your model is defined using the recently presented TensorFlow Estimator API, you only have to make some minor adjustments to your Estimator’s model_fn like so:

# .... body of model_fn

  optimizer = tf.train.AdamOptimizer()
  if FLAGS.use_tpu:
    # Aggregate gradients across the TPU shards before applying them.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

#   return tf.estimator.EstimatorSpec(   # CPU or GPU estimator
#     mode=mode,
#     loss=loss,
#     train_op=train_op,
#     predictions=predictions)

  return tf.contrib.tpu.TPUEstimatorSpec(  # TPU estimator
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions)

You can find an example of a TPUEstimator in the TensorFlow GitHub repository.

You should also consider increasing the batch size for training and validation of your model. Since we have 8 TPU cores available, a batch size of \(8 \times 128\) might be reasonable, depending on your model’s size. Generally speaking, a batch size of \(8 \times 8n\) with \(n = 1, 2, \ldots\) is advised, so that every core receives a multiple of 8 samples. Due to the increased batch size, you can experiment with increasing the learning rate as well, making training even faster.
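
The batch size arithmetic can be sketched in a few lines of Python. The helper names are hypothetical, and the linear learning rate scaling is a common heuristic assumed here for illustration, not something the notebook prescribes. Treat the result as a starting point for experiments.

```python
NUM_TPU_CORES = 8  # cores exposed by the Colab TPU, as described above


def global_batch_size(per_core_batch, num_cores=NUM_TPU_CORES):
    """Total batch size when every core processes per_core_batch samples."""
    return per_core_batch * num_cores


def scaled_learning_rate(base_lr, base_batch, new_batch):
    """Linear scaling heuristic (an assumption): grow the rate with the batch."""
    return base_lr * new_batch / base_batch


batch = global_batch_size(128)
print(batch)  # 1024

# Hypothetical baseline: learning rate 0.001 at a batch size of 128.
print(scaled_learning_rate(0.001, 128, batch))
```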

Performance gains

Compiling the TPU model and initializing the optimizer shards takes time. Depending on the Colab environment workload, it can take a couple of minutes before the initialization and the first epoch have completed. However, once the TPU model is up and running, it is lightning fast.

Using the ResNet model discussed in the previous post, one epoch takes approximately 25 seconds, compared to approximately 7 minutes on the Tesla K80 GPU, a speedup of almost 17x.
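
For reference, the speedup follows directly from the two epoch times. The short sketch below (the function name is hypothetical) also computes the ratio of the peak throughput figures quoted earlier, which shows that the measured gain on this model is well below the ratio of the advertised peak numbers.

```python
def speedup(baseline_seconds, accelerated_seconds):
    """How many times faster the accelerated run is than the baseline."""
    return baseline_seconds / accelerated_seconds


gpu_epoch_seconds = 7 * 60  # approx. 7 minutes per epoch on the Tesla K80
tpu_epoch_seconds = 25      # approx. 25 seconds per epoch on the TPU

print(round(speedup(gpu_epoch_seconds, tpu_epoch_seconds), 1))  # 16.8

# Ratio of the quoted peak figures: 180 TFLOPS (TPUv2) vs. 1.87 TFLOPS (K80).
print(round(180 / 1.87, 1))  # 96.3
```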