Our Summer of Code Project on TF-GAN

Posted by Nived P A, Margaret Maynard-Reid, Joel Shor

Google Summer of Code is a program that brings student developers into open-source projects each summer. This article describes enhancements made to the TensorFlow GAN library (TF-GAN) last summer that were proposed by Nived PA, an undergraduate student of Amrita School of Engineering. The goal of Nived’s project was to improve the TF-GAN library by adding new tutorials, and adding new functionality to the library itself.

This article provides an overview of TF-GAN and our accomplishments from last summer. We will share our experience from the perspective of both the student and the mentors, and walk through one of the new tutorials Nived created, an ESRGAN TensorFlow implementation, and show you how easy it is to use TF-GAN to help with training and evaluation.

What is TF-GAN?

TF-GAN provides common building blocks and infrastructure support for training GANs, and offers easy-to-use, standard techniques for evaluating them. Using TF-GAN helps developers and researchers save time with common GAN tools, and avoids common pitfalls in implementations. In addition, TF-GAN offers a collection of famous examples that include GANs from the image and audio space, as well as GPU and TPU support.

Since its launch in 2017, the team has updated the infrastructure to work with TensorFlow 2.0, released a self-study GAN course viewed by over 150K people in 2020, and given an ML Tech Talk on GANs. The project itself has been downloaded millions of times. Papers using TF-GAN have thousands of citations (e.g. 1, 2, 3, 4, 5).

The TF-GAN library can be divided into a number of independent parts, namely Core, Features, Losses, Evaluation and Examples. Each of these different parts can be used to simplify the training or evaluation process of GANs.

Project Scope

The Google Summer of Code 2021 project on TF-GAN aimed to add more recent GAN models as examples to the library, and to add more tutorial notebooks that explore various functionalities of TF-GAN while training and evaluating state-of-the-art GAN models such as ESRGAN. Through this project, new loss functions that can improve the training process of GANs were also added to the library. Next, we will walk through the ESRGAN code and demonstrate how to use TF-GAN to help with training and evaluation.

If you are new to GANs, a good start is to read this Intro to GANs post written by Margaret (who mentored this project), these GANs tutorials on tensorflow.org and the self-study GAN course on Machine Learning Crash Course as mentioned above.


Image super resolution is an important use case of GANs. Super resolution is the process of reconstructing a high resolution (HR) image from a given low resolution (LR) image. Super resolution can be applied to solve real world problems such as photo editing.

The SRGAN paper (Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network) applied GANs to single-image super resolution, using residual blocks and a perceptual loss. The ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks) paper enhanced SRGAN by introducing the Residual-in-Residual Dense Block (RRDB) without batch normalization as the basic building block, using a relativistic loss, and improving the perceptual loss.

Now let’s walk through how to implement ESRGAN with TensorFlow 2 and evaluate its performance with TF-GAN. There are two versions of the Colab notebook: one using GPU and the other using TPU. We will go over the TPU version.


First let’s make sure that we are set up with a Colab TPU and a Google Cloud Storage bucket.

  1. Colab TPU

To enable the TPU runtime in Colab, go to Edit → Notebook Settings or Runtime → Change runtime type, and then select “TPU” from the Hardware Accelerator drop-down menu.

  2. Google Cloud Storage Bucket

In order to train with a TPU, we first need to set up a Google Cloud Storage bucket to store the dataset and model weights during training. Please refer to the Google Cloud documentation on Creating Storage buckets. After you create a storage bucket, authenticate from Colab so that you can grant the Google Cloud SDK access to the bucket:

import os

bucket = 'enter-your-bucket-name-here'
tpu_address = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

from google.colab import auth
auth.authenticate_user()  # Opens the authentication prompt described below


You will be prompted to follow a link in your browser to authenticate the connection to the bucket. Clicking the link takes you to a new browser tab. Follow the instructions there to get the verification code, then go back to the Colab notebook and enter the code. You should now be able to access the bucket for the rest of the notebook.

Training parameters

Now that we have enabled the TPU for Colab and set up a GCS bucket to store training data and model weights, we first define some parameters that will be used throughout, from data loading to model training, such as the batch size, the HR image resolution and the scale by which to downscale the image into LR.

Params = {
    'batch_size' : 32, # Number of image samples used in each training step
    'hr_dimension' : 256, # Dimension of a High Resolution (HR) image
    'scale' : 4, # Factor by which HR images are downscaled to produce Low Resolution (LR) images
    'data_name': 'div2k/bicubic_x4', # Dataset name - loaded using tfds
    'trunk_size' : 11, # Number of residual blocks used in the generator
}


We are using the DIV2K dataset: DIVerse 2K resolution high quality images. We will load the data into our cloud bucket with the TensorFlow Datasets (tfds) API.

We need both high resolution (HR) and low resolution (LR) data for training. So we will download the original images and scale them down to 96×96 for HR and 28×28 for LR.

Note: the data downloading and rescaling to store in the cloud bucket could take over 30 minutes.

Visualize the dataset

Let’s visualize the dataset downloaded and scaled:

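import numpy as np
from PIL import Image

# Take one batch of (LR, HR) image pairs from the training dataset.
img_lr, img_hr = next(iter(train_ds))

# Convert the first image of each batch to a PIL image and resize it for display.
lr = Image.fromarray(np.array(img_lr)[0].astype(np.uint8))
lr = lr.resize([256, 256])

hr = Image.fromarray(np.array(img_hr)[0].astype(np.uint8))
hr = hr.resize([256, 256])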

Model architecture

We will first define the generator architecture, the discriminator architecture and the loss functions; and then put everything together to form the ESRGAN model.

Generator – as with most GAN generators, the ESRGAN generator upsamples the input a few times. What makes it different is the Residual-in-Residual Dense Block (RRDB) without batch normalization.

For the generator, we define functions that create the Conv block, the Dense block, the RRDB block and the upsampling block. Then we define a function that creates the generator network with the Keras Functional API, as follows:

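import tensorflow as tf
from tensorflow.keras import layers

# NOTE: the signature was cut off in the text. The trunk_size and out_channels
# defaults below are assumptions (the Params value and a 3-channel RGB output).
def generator_network(filter=32, trunk_size=Params['trunk_size'], out_channels=3):
    lr_input = layers.Input(shape=(None, None, 3))

    x = layers.Conv2D(filter, kernel_size=[3,3], strides=[1,1],
                      padding='same', use_bias=True)(lr_input)
    x = layers.LeakyReLU(0.2)(x)
    ref = x
    # Trunk of RRDB blocks (rrdb is one of the helper functions mentioned above).
    for i in range(trunk_size):
        x = rrdb(x)

    x = layers.Conv2D(filter, kernel_size=[3,3], strides=[1,1],
                      padding='same', use_bias = True)(x)
    x = layers.Add()([x, ref])  # Long skip connection around the trunk

    # Two upsampling stages.
    x = upsample(x, filter)
    x = upsample(x, filter)
    x = layers.Conv2D(filter, kernel_size=3, strides=1,
                      padding='same', use_bias=True)(x)
    x = layers.LeakyReLU(0.2)(x)
    hr_output = layers.Conv2D(out_channels, kernel_size=3, strides=1,
                              padding='same', use_bias=True)(x)

    model = tf.keras.models.Model(inputs=lr_input, outputs=hr_output)
    return model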


The discriminator is a fairly straightforward CNN with Conv2D, BatchNormalization, LeakyReLU and Dense layers. Again, we build it with the Keras Functional API.

def discriminator_network(filters = 64, training=True):
    img = layers.Input(shape = (Params['hr_dimension'], Params['hr_dimension'], 3))

    x = layers.Conv2D(filters, [3,3], 1, padding='same', use_bias=False)(img)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = layers.Conv2D(filters, [3,3], 2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = _conv_block_d(x, filters *2)
    x = _conv_block_d(x, filters *4)
    x = _conv_block_d(x, filters *8)
    x = layers.Flatten()(x)
    x = layers.Dense(100)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dense(1)(x)

    model = tf.keras.models.Model(inputs = img, outputs = x)
    return model

Loss Functions

The ESRGAN model uses three loss functions to balance visual quality against metrics such as Peak Signal-to-Noise Ratio (PSNR), and to encourage the generator to produce more realistic images with natural textures:

  1. Pixel loss – the pixel-wise loss between the generated image and the ground truth.
  2. Adversarial loss (using a relativistic GAN) – calculated for both the generator (G) and the discriminator (D).
  3. Perceptual loss – calculated using the pre-trained VGG-19 network.

Let’s dive deeper into the adversarial loss here since this is the most complex one and it’s a function added to the TF-GAN library as part of the project.

In GANs the discriminator network classifies the input data as real or fake. The generator is trained to generate fake data and fool the discriminator into mistakenly classifying it as real. As the generator increases the probability of fake data being real, the probability of real data being real should also decrease. This was a missing property of standard GANs as pointed out in this paper, and the relativistic discriminator was introduced to overcome this issue. The relativistic average discriminator estimates the probability that the given real data is more realistic than fake data, on average. This improves the quality of generated data and the stability of the model while training. In the TF-GAN library, see relativistic_generator_loss and relativistic_discriminator_loss for the implementation of this loss function.

def ragan_generator_loss(d_real, d_fake):
    real_logits = d_real - tf.reduce_mean(d_fake)
    fake_logits = d_fake - tf.reduce_mean(d_real)
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(real_logits), logits=real_logits))
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(fake_logits), logits=fake_logits))

    return real_loss + fake_loss

def ragan_discriminator_loss(d_real, d_fake):
    def get_logits(x, y):
        return x - tf.reduce_mean(y)
    real_logits = get_logits(d_real, d_fake)
    fake_logits = get_logits(d_fake, d_real)

    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(real_logits), logits=real_logits))
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_logits), logits=fake_logits))

    return real_loss + fake_loss
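
As a quick sanity check of how these two losses behave, here is a plain-Python sketch (no TensorFlow) that mirrors the two functions above on short lists of logits. The helper names and the example logits are illustrative assumptions and are not part of TF-GAN.

```python
import math

def sigmoid_cross_entropy(logit, label):
    # Numerically stable form of the sigmoid cross entropy for one logit.
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean(values):
    return sum(values) / len(values)

def ragan_losses(d_real, d_fake):
    """Return (generator_loss, discriminator_loss) for lists of raw logits."""
    real_logits = [r - mean(d_fake) for r in d_real]
    fake_logits = [f - mean(d_real) for f in d_fake]
    # Discriminator: real should look more realistic than the average fake.
    d_loss = (mean([sigmoid_cross_entropy(x, 1.0) for x in real_logits]) +
              mean([sigmoid_cross_entropy(x, 0.0) for x in fake_logits]))
    # Generator: same logits with the labels swapped.
    g_loss = (mean([sigmoid_cross_entropy(x, 0.0) for x in real_logits]) +
              mean([sigmoid_cross_entropy(x, 1.0) for x in fake_logits]))
    return g_loss, d_loss

# Discriminator cannot tell real from fake: both losses equal 2 * ln(2).
g_tie, d_tie = ragan_losses([0.5, 0.5], [0.5, 0.5])
# Discriminator separates them easily: small D loss, large G loss.
g_sep, d_sep = ragan_losses([4.0, 5.0], [-4.0, -5.0])
```

Because only the difference between a logit and the average logit of the opposite class matters, the generator is penalized both when fake images look less realistic than real ones and when real images look more realistic than fake ones.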


The ESRGAN model is trained in two phases:

  • Phase 1: train the generator network on its own, aiming to improve the PSNR of generated images by reducing the L1 loss.
  • Phase 2: continue training the same generator together with the discriminator network. In this phase, the generator minimizes the L1 loss, the relativistic average GAN (RaGAN) loss, which indicates how realistic the generated image looks, and the improved perceptual loss proposed in the paper.

If starting from scratch, phase 1 training can be completed within an hour on a free Colab TPU, whereas phase 2 can take around 2-3 hours to get good results. Saving the weights and checkpoints is therefore an important step during training.

Phase 1 training

Here are the steps of phase 1 training:

  • Define the generator and its optimizer
  • Take LR, HR image pairs from the training dataset
  • Input the LR image to the generator network
  • Calculate the L1 loss using the generated image and HR image
  • Compute the gradients and apply them with the optimizer
  • Update the optimizer’s learning rate at each decay step for better performance
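
The loss and learning-rate steps in this list can be sketched without TensorFlow. This is a minimal illustration, assuming a schedule that halves the learning rate each time a decay step is reached; the function names, the decay factor, the initial rate and the step boundaries are hypothetical and are not taken from the notebook.

```python
def l1_loss(generated, target):
    """Mean absolute difference between two equally sized lists of pixel values."""
    return sum(abs(g - t) for g, t in zip(generated, target)) / len(target)

def decayed_learning_rate(step, initial_lr, decay_steps, decay_factor=0.5):
    """Scale the rate by decay_factor once for every decay step already reached."""
    reached = sum(1 for boundary in decay_steps if step >= boundary)
    return initial_lr * decay_factor ** reached

# Toy "images" of three pixels each.
loss = l1_loss([0.0, 10.0, 20.0], [0.0, 12.0, 16.0])

# Hypothetical schedule: start at 2e-4 and halve at steps 1000 and 2000.
lr_start = decayed_learning_rate(0, 2e-4, [1000, 2000])
lr_middle = decayed_learning_rate(1500, 2e-4, [1000, 2000])
lr_end = decayed_learning_rate(2000, 2e-4, [1000, 2000])
```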

Phase 2 training

In this phase of training:

  • Load the generator network trained in phase 1
  • Define checkpoints that can be useful during training
  • Use VGG-19 pretrained network for calculating perceptual loss

Then we define the training step as follows:

  • Input the LR image to the generator network
  • Calculate L1 loss, perceptual loss and adversarial loss for both the generator and the discriminator.
  • Apply the resulting gradients with the optimizers of both networks
  • Update the optimizers’ learning rates at each decay step for better performance
  • Use TF-GAN’s image grid function to display the generated images in the validation steps

Please refer to the Colab notebook for the complete code implementation.

During training we visualize the 3 images: LR image, HR image (generated), HR image (training data), and these metrics: generator loss, discriminator loss and PSNR.

step 0

Generator Loss = 0.636057436466217

Disc Loss = 0.0191921629011631

PSNR : 20.95576286315918

Here are some more results at the end of the training which look pretty good.


Now that training has completed, we will evaluate the ESRGAN model with 3 metrics: Fréchet Inception Distance (FID), Inception Scores and Peak signal-to-noise ratio (PSNR).

FID and Inception Scores are two common metrics used to evaluate the performance of a GAN model. Peak Signal-to-Noise Ratio (PSNR) quantifies the similarity between two images and is used for benchmarking super resolution models.

Instead of writing the code from scratch to calculate each of the metrics, we are using the TF-GAN library to evaluate our GAN implementation with ease for FID and Inception Scores. Then we make use of the `tf.image` module to calculate PSNR values for evaluating the super resolution algorithm.

Why do we need the TF-GAN library for evaluation?

Standard evaluation metrics for GANs such as Inception Scores, Frechet Distance or Kernel Distance are available inside TF-GAN Evaluation. Various implementations of such metrics can be prone to errors and this can result in unreliable evaluation scores. By using TF-GAN, such errors can be avoided and GAN evaluations can be made easy. For evaluating the ESRGAN model we have made use of the Inception Score (tfgan.eval.inception_score) and Frechet Distance Score (tfgan.eval.frechet_inception_distance) from the TF-GAN library.

Here is how we use TF-GAN for evaluation in code.

First we need to install the TF-GAN library, which should already have been done as part of the imports at the beginning of the notebook. Then we import the library.

!pip install tensorflow-gan
import tensorflow_gan as tfgan

Now we are ready to use the library for the ESRGAN evaluation!

Fréchet inception distance (FID)

def get_fid_score(real_image, gen_image):
    # Resize to the input size expected by the Inception network.
    size = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE

    resized_real_images = tf.image.resize(real_image, [size, size], method=tf.image.ResizeMethod.BILINEAR)
    resized_generated_images = tf.image.resize(gen_image, [size, size], method=tf.image.ResizeMethod.BILINEAR)
    num_inception_images = 1
    num_batches = Params['batch_size'] // num_inception_images
    fid = tfgan.eval.frechet_inception_distance(resized_real_images, resized_generated_images, num_batches=num_batches)
    return fid
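
For intuition about what this metric measures: FID compares Gaussian statistics fitted to Inception features of real and generated images. The sketch below shows the one-dimensional special case of that distance in plain Python. It is an illustration only, not how TF-GAN computes the score (the library works on multi-dimensional feature vectors with full covariance matrices), and the function names and sample values are made up for the example.

```python
import statistics

def frechet_distance_1d(mean_a, std_a, mean_b, std_b):
    """Squared Frechet distance between two 1-D Gaussians."""
    return (mean_a - mean_b) ** 2 + (std_a - std_b) ** 2

def frechet_distance_from_samples(samples_a, samples_b):
    """Fit a Gaussian to each list of samples, then compare the two fits."""
    return frechet_distance_1d(
        statistics.fmean(samples_a), statistics.pstdev(samples_a),
        statistics.fmean(samples_b), statistics.pstdev(samples_b))

# Identical statistics give a distance of zero.
same = frechet_distance_from_samples([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
# Shifting every sample by 1 changes only the mean term.
shifted = frechet_distance_from_samples([1.0, 2.0, 3.0], [2.0, 3.0, 4.0])
```

Lower values mean the generated distribution is closer to the real one, which is why a lower FID is better.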

Inception Scores

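def get_inception_score(images, gen, num_inception_images = 8):
    # Note: the gen argument is not used inside this function.
    # Resize to the input size expected by the Inception network.
    size = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE
    resized_images = tf.image.resize(images, [size, size], method=tf.image.ResizeMethod.BILINEAR)

    num_batches = Params['batch_size'] // num_inception_images
    inc_score = tfgan.eval.inception_score(resized_images, num_batches=num_batches)

    return inc_score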
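
To show what the score rewards, here is a small plain-Python sketch of the Inception Score formula, exp(E[KL(p(y|x) || p(y))]), applied to hand-written class probabilities. In the real metric those probabilities come from the Inception network; the function name and the toy inputs here are assumptions for illustration.

```python
import math

def inception_score(class_probs):
    """class_probs: one list of class probabilities per image."""
    n = len(class_probs)
    num_classes = len(class_probs[0])
    # Marginal class distribution p(y), averaged over all images.
    marginal = [sum(p[k] for p in class_probs) / n for k in range(num_classes)]
    # Average KL divergence between each p(y|x) and the marginal.
    kl_sum = 0.0
    for p in class_probs:
        kl_sum += sum(pk * math.log(pk / marginal[k])
                      for k, pk in enumerate(p) if pk > 0)
    return math.exp(kl_sum / n)

# Every image gets the same uncertain prediction: lowest possible score.
uninformative = inception_score([[0.5, 0.5], [0.5, 0.5]])
# Confident predictions spread over both classes: score equals the class count.
confident_and_diverse = inception_score([[1.0, 0.0], [0.0, 1.0]])
```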

Peak Signal-to-Noise Ratio (PSNR)

def get_psnr(real, generated):
    # max_val is the dynamic range of the pixel values (255 for 8-bit images).
    psnr_value = tf.reduce_mean(tf.image.psnr(generated, real, max_val=255.0))
    return psnr_value
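
The quantity behind this call is PSNR = 10 · log10(MAX² / MSE), where MAX is the dynamic range of the pixel values and MSE is the mean squared error between the two images. A minimal plain-Python sketch, assuming 8-bit pixel values (MAX = 255) and flat lists of pixels; the function name and inputs are illustrative only:

```python
import math

def psnr(reference, generated, max_val=255.0):
    """PSNR in decibels between two equally sized lists of pixel values."""
    mse = sum((r - g) ** 2 for r, g in zip(reference, generated)) / len(reference)
    if mse == 0:
        return math.inf  # Identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

close_match = psnr([100.0, 100.0], [101.0, 99.0])  # MSE of 1
worst_case = psnr([0.0, 0.0], [255.0, 255.0])      # MSE of 255 squared
identical = psnr([10.0, 20.0], [10.0, 20.0])
```

Higher values indicate that the generated image is closer to the reference.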

GSoC experience

Here is the Google Summer of Code 2021 experience in our own words:


As a student, Google Summer of Code gave me an opportunity to participate in exciting open source projects for TensorFlow and the mentorship that I got during this period was invaluable. I got to learn a lot about implementing various GAN models, writing tutorial notebooks, using Cloud TPUs for training models and using tools such as Google Cloud Platform. I received a lot of support from Margaret and Joel throughout the program which kept the project on track. From the beginning their suggestions helped define the project scope and during the coding period, Margaret and I met on a weekly basis to clear all my doubts and solve various issues that I was facing. Joel also helped in reviewing all the PRs made to the TF-GAN library. GSoC is indeed a great way of getting involved with various interesting TensorFlow libraries and I look forward to continuing making valuable contributions to the community.


As the project mentor, I have been involved since the project selection phase. Mentoring Nived and collaborating with Joel on TF-GAN has been a fulfilling experience. Nived has done an excellent job implementing the ESRGAN paper with TensorFlow 2 and TF-GAN. Nived and I spent a lot of time looking at the various text-to-image GANs to choose one that can potentially be implemented during the GSoC timeframe. Aside from writing the ESRGAN tutorial, he made great progress on ControlGAN for text-to-image generation. I hope this project helps others to learn how to use the TF-GAN library and contribute to TF-GAN and other open source TensorFlow projects.


As an unofficial technical mentor, I was impressed by how independently and effectively Nived worked. I felt more like I was working with a junior colleague than an intern, in that I helped give technical and project pointers, but ultimately Nived made the decisions. I think the impressive results reflect this: Nived owned the project, and I think as a result the example and Colab are more well-written and cohesive than they otherwise might have been. Furthermore, Nived successfully navigated the multi-timezone reality that is working-from-home!

What’s next

During the GSoC coding period the implementation of the ESRGAN model was completed and the Python code and Colab notebooks were merged to the TF-GAN repo. The implementation of the ControlGAN model for text-to-image generation is still in progress. Once the implementation of ControlGAN is completed, we plan to extend it to serve some real-world applications in areas such as art generation or image editing. We are also planning to write tutorials to explore different models that solve the task of text-to-image translation.

If you want to contribute to TF-GAN, you can reach out to `tfgan-users@google.com` to propose a project or addition. Unless you’ve contributed to OSS Google projects before, it’s usually a good idea to check with someone before submitting a large pull request. We look forward to seeing your contributions and working with you!


We would like to thank the GSoC program committee for their support, in particular Josh Gordon from the TensorFlow team.

Many thanks also to the Machine Learning (ML) Google Developer Expert (GDE) program, Google Cloud Platform and TensorFlow Research Cloud for their support.


Continuous Adaptation for Machine Learning System to Data Changes

A guest post by Chansung Park, Sayak Paul (ML-GDEs)

Continuous integration and delivery (CI/CD) is a much sought-after topic in the DevOps domain. In the MLOps (Machine Learning + Operations) domain, we have another form of continuity — continuous evaluation and retraining. MLOps systems evolve according to the changes of the world, and that is usually caused by data/concept drift. So, to cater to the data changes we need to continuously evaluate our deployed ML models and retrain and re-deploy them as necessary.

In this blog post, we present a project that implements a workflow combining batch prediction and model evaluation for continuous evaluation and retraining, in order to capture changes in the data. We will first discuss the general setup of the project. Then we will move on to key components (batch prediction, new data spans, retraining, etc.) that are important for continuously evaluating an ML model and then re-training it if needed. Rather than discussing the technical implementation details of the project, we will keep it high-level so that we can focus on understanding the underlying concepts.

The project is implemented with TensorFlow Extended (TFX), Keras, and various services offered from Google Cloud Platform. You can find the project on GitHub.


This project shows how to build two separate pipelines working together to create a CI/CD workflow which responds to changes in the data. The first pipeline is for model training, and the second pipeline is for model evaluation based on the result of a batch prediction as shown in Figure 1.

Figure 1. Overview of the project structure (original)

The model training pipeline is built by combining standard TFX components such as ImportExampleGen and Trainer with custom TFX components such as VertexUploader and VertexDeployer. Since the Pusher standard component had an issue when we were doing this project, we have brought custom components from our previous project, Dual Deployments.

There is one significant implementation detail on how ImportExampleGen handles the dataset to be fed into the model. We have designed our project to hold datasets from different distributions in separate folders with filesystem paths which indicate the span number. For instance, the initial training and test datasets can be stored in SPAN-1/train and SPAN-1/test, while the drifted dataset can be stored in SPAN-2/train and SPAN-2/test respectively, as shown in Figure 2.

Given the versioning feature in Google Cloud Storage (GCS), you might think we don’t need to manage datasets in this manner. However, we found that this layout makes datasets much more manageable. For example, you might want to pick data from SPAN-1 and SPAN-2, or from SPAN-1 and SPAN-3, to train the model depending on the situation. Also, datasets belonging to the same distribution can still benefit from the versioning feature in GCS.

Figure 2. How datasets are managed (original)

The batch evaluation pipeline does not leverage any standard TFX components. Rather it consists of five custom TFX components which are FileListGen, BatchPredictionGen, PerformanceEvaluator, SpanPreparator, and PipelineTrigger. These components are available as standalone modules here.

Figure 3. Custom TFX components in batch evaluation pipeline (original)

FileListGen generates a text file to be looked up by the currently deployed model on Vertex AI to perform batch prediction according to the format required by Vertex Prediction. Then BatchPredictionGen will simply perform Vertex Prediction based on the prepared text file from the FileListGen and output a set of files containing the batch prediction results. PerformanceEvaluator calculates the average accuracy based on the batch prediction results and outputs False if it is less than the threshold. If the output is True, the pipeline will be terminated. Or if the output is False, SpanPreparator prepares TFRecord files by compressing the list of raw data, and then puts those TFRecords into a new folder whose name contains the successive span number such as span-2. Finally, PipelineTrigger triggers the model training pipeline by passing the span numbers for the data which should be included for training the model through RuntimeParameter.

General setup

In this section, we walk through the key components of the project and also leave some notes on the tools we used to implement them.

Getting the initial model ready

We focus on the concepts and consider implementing them in a minimal manner so that our implementations are as reproducible and as accessible as possible. Keeping that in mind, we use the CIFAR-10 training set as our training data and we fine-tune a ResNet50 model to fit the data. Our training pipeline is demonstrated in this notebook.

Simulating data drift and labeling new data

To simulate a data drift scenario, we then collect a bunch of images from the internet matching CIFAR-10 classes. To make it easy to follow we implement this workflow inside a Colab Notebook which is available here. This workflow also includes uploading and deploying the trained model as a service on the Vertex AI platform.

Continuous evaluation with batch inference

We then perform inference on these images with the trained model from the above step. We perform batch inference rather than online inference to get the results. We use Vertex AI’s batch prediction service to realize this. In practice, usually after this step, the model test images and model predictions are sent to domain experts for audit purposes. They also provide the expected ground-truth labels on the test images. Only after that, we can validate the prediction results. But for the purpose of this project, we eliminate this step and pretend that the ground-truth labels are already available. So, as soon as the batch prediction results are available we evaluate them. This entire workflow is covered in this notebook.

We deploy a Cloud Function to monitor a specific location inside a Google Cloud Storage (GCS) bucket. If there is a sufficient number of new test images available inside that location, we trigger the batch prediction pipeline. We cover this workflow in this notebook. This is how we achieve the “continuous evaluation” aspect of our project.

There are other ways to capture drift in data, though. For example, using JS-Divergence, we can compare the distributions between the newly available data and training data. You can follow this Coursera lecture from Robert Crowe which dives deep into these techniques.
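
As a rough illustration of the JS-Divergence idea, the sketch below compares two discrete distributions in plain Python. The class histograms are invented for the example, and how features or labels would be binned into such histograms in a real pipeline is left open.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    """Jensen-Shannon divergence: symmetric and bounded by ln(2)."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

# Hypothetical class frequencies for training data and newly collected data.
training_hist = [0.25, 0.25, 0.25, 0.25]
new_data_hist = [0.70, 0.10, 0.10, 0.10]

no_drift = js_divergence(training_hist, training_hist)
some_drift = js_divergence(training_hist, new_data_hist)
max_drift = js_divergence([1.0, 0.0], [0.0, 1.0])
```

A value near zero suggests the two distributions are similar, while larger values (up to ln 2 with natural logarithms) suggest drift. The point at which to react is a project-specific choice.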

Model retraining

After the batch predictions are evaluated, the next step is to determine if we need to re-train the model based on a predefined performance threshold that generally depends on the business context and a lot of other factors. We set this threshold to 0.9 in the project. If we need to re-train then we trigger the same model training pipeline (as shown in this notebook) but with the newly available data added to the CIFAR-10 training set. We can either warm-start our model from a previous checkpoint or we can train the model from scratch using all the available training data. For this project, we do the latter.
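
The retraining decision itself is a simple comparison. Here is a minimal sketch, assuming predictions and ground-truth labels are available as plain lists; the function names and sample labels are hypothetical, and only the 0.9 threshold comes from the project.

```python
def average_accuracy(predicted_labels, true_labels):
    """Fraction of predictions that match the ground-truth labels."""
    correct = sum(1 for p, t in zip(predicted_labels, true_labels) if p == t)
    return correct / len(true_labels)

def needs_retraining(predicted_labels, true_labels, threshold=0.9):
    """Retrain when the batch accuracy falls below the threshold."""
    return average_accuracy(predicted_labels, true_labels) < threshold

predicted = ["cat", "dog", "ship", "frog", "cat"]
expected = ["cat", "dog", "ship", "deer", "dog"]

accuracy = average_accuracy(predicted, expected)  # 3 of 5 correct
retrain = needs_retraining(predicted, expected)
```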

In the following section, we will go over a few non-trivial components from our implementation and discuss their motivation and technicalities. As a reminder, our implementation is fully open-sourced here.

Implementation details on managing datasets with span numbers

In this section, we walk through the implementation details on some key aspects of the project. Please go through the project repository and review all notebooks for further information.

The initial CIFAR-10 datasets are stored in the {bucket-name}/span-1/train and {bucket-name}/span-1/test GCS locations respectively. This step is done in the first notebook. Then, we download more images of the same categories as in CIFAR-10 by using Bing Image Downloader. Those images are resized to 32×32 to make them compatible with the CIFAR-10 dataset, and they are stored in a separate GCS bucket such as {bucket-batch-prediction}/2021-10/.

Note that we use the YYYY-MM format for the name of the location where the images are stored. This is because the Cloud Function, which is fired by Cloud Scheduler, looks for the latest GCS location to launch the batch evaluation pipeline, as shown below.

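import os
import re
from datetime import datetime

def get_latest_directory(storage_client, bucket):
    blobs = storage_client.list_blobs(bucket)

    # Collect the unique folder names that are exactly of the form YYYY-MM.
    folders = list(
        set(
            os.path.dirname(blob.name)
            for blob in blobs
            if re.fullmatch(
                "[1-9][0-9][0-9][0-9]-[0-1][0-9]", os.path.dirname(blob.name)
            )
        )
    )

    # Sort chronologically (oldest first) and return the most recent folder.
    folders.sort(key=lambda date: datetime.strptime(date, "%Y-%m"))
    return folders[-1]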

As you see, it only looks for the GCS location that exactly matches the YYYY-MM format. The Cloud Function launches the batch evaluation pipeline by passing which GCS location to look up for batch prediction via RuntimeParameter. The code snippet below shows how it is passed to the pipeline with the name data_gcs_prefix on the Cloud Function side.

from kfp.v2.google.client import AIPlatformClient

api_client = AIPlatformClient(project_id=project, region=region)

response = api_client.create_run_from_job_spec(
    # (job spec path and other arguments are omitted here)
    parameter_values={"data_gcs_prefix": latest_directory},
)

The pipeline recognizes data_gcs_prefix is a type of RuntimeParameter, and it is used in the FileListGen component which prepares a text file in the required format to perform Vertex AI Batch Prediction.

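def _create_pipeline(
    # (other parameters omitted)
    data_gcs_prefix: data_types.RuntimeParameter,
    # (other parameters omitted)
) -> Pipeline:
    # (other components omitted)
    filelist_gen = FileListGen(
        # (arguments omitted; data_gcs_prefix is used here)
    )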

Let’s skip the batch prediction performed by the BatchPredictionGen component.

When the PerformanceEvaluator component determines that retraining should be performed based on the result from the BatchPredictionGen component, the SpanPreparator prepares TFRecord files with the newly collected images, moves them to {bucket-name}/span-2/train and {bucket-name}/span-2/test where the training pipeline ingests data for model training, and renames the GCS location where the newly collected images are stored to {bucket-batch-prediction}/YYYY-MM_old/.

We add the _old suffix so that Cloud Function will ignore the renamed GCS location. If the retrained model doesn’t show a good enough performance metric, then you can have a chance to collect more data and merge them with the images in the _old GCS location.

The PipelineTrigger component at the end of the batch evaluation pipeline will trigger the training pipeline by passing which span numbers to look for in order to do model training. The data will be consumed by ImportExampleGen, based on the glob pattern matching feature. For instance, if data from span-1 and span-2 should be used for model training, then the glob pattern for the training dataset might be span-[12]/train/*.tfrecord. The code snippet below clearly shows the generalized version of the idea.

response = api_client.create_run_from_job_spec(
    # (job spec path and other arguments are omitted here)
    parameter_values={
        "input-config": json.dumps({
            "splits": [
                {"name": "train",
                 "pattern": f"span-[{int(latest_span)-1}{latest_span}]/train/*.tfrecord"},
                {"name": "val",
                 "pattern": f"span-[{int(latest_span)-1}{latest_span}]/test/*.tfrecord"},
            ]
        }),
        "output-config": json.dumps({}),
    },
)

The reason we formed the RuntimeParameter in the parameter_values in this way is that the pattern matching feature of the ImportExampleGen component should be specified in the input-config and output-config parameters. We do not need the output-config parameter for our purpose, but it is required when passing the input-config parameter as a RuntimeParameter. That’s why the output-config parameter is left empty. Note that you have to form the parameter in protocol buffer format when using RuntimeParameter for standard TFX components. The code below shows how the passed input-config and output-config can be consumed by the ImportExampleGen component.

example_gen = tfx.components.ImportExampleGen(
    input_base=data_root, input_config=input_config, output_config=output_config
)
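
To make the span pattern concrete, the sketch below builds the same kind of pattern string and checks it against a few example paths using Python's fnmatch module. This only approximates glob matching for illustration; it is not the matching code that ImportExampleGen runs, and the helper name and file paths are hypothetical.

```python
from fnmatch import fnmatchcase

def span_pattern(latest_span, split):
    """Pattern covering the latest span and the one before it."""
    return f"span-[{int(latest_span) - 1}{latest_span}]/{split}/*.tfrecord"

paths = [
    "span-1/train/a.tfrecord",
    "span-2/train/b.tfrecord",
    "span-3/train/c.tfrecord",
    "span-2/test/d.tfrecord",
]

pattern = span_pattern(2, "train")  # "span-[12]/train/*.tfrecord"
matched = [p for p in paths if fnmatchcase(p, pattern)]

# Caveat: a character class matches a single character, so this form
# only works while span numbers have one digit.
two_digit_pattern = span_pattern(10, "train")  # "span-[910]/train/*.tfrecord"
matches_span_10 = fnmatchcase("span-10/train/e.tfrecord", two_digit_pattern)
```

With latest_span set to 2, only the training files from span-1 and span-2 are selected. Once span numbers reach two digits, a different pattern scheme would be needed.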

It is worth noting that you can leverage the rolling window feature supported by TFX with the standard components if the backend environment is Kubeflow Pipeline v1. The code snippet below shows how to achieve this with the CsvExampleGen component and a Resolver node.

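examplegen_range_config = proto.RangeConfig(
    static_range=proto.StaticRange(
        start_span_number=2, end_span_number=2))

example_gen = tfx.components.CsvExampleGen(
    # (arguments omitted)
)

resolver_range_config = proto.RangeConfig(
    # (range settings omitted)
)

examples_resolver = tfx.dsl.Resolver(
    # (strategy and other arguments omitted)
    config={
        'range_config': resolver_range_config
    },
)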

This is a much better approach, since it reuses the artifacts generated by the previous ExampleGens, and the current pipeline run only takes care of the data in the new span. Unfortunately, this feature is not supported by Vertex AI Pipeline, which is based on Kubeflow Pipeline v2. We had an extensive discussion with the TFX team about this, which is why we came up with a different approach from the standard way.


Vertex AI Training is a separate service from Pipeline. We need to pay for the Vertex AI Pipeline individually, and at the time of writing this article, it costs about $0.03 USD per pipeline run. The type of compute instance for each TFX component was e2-standard-4, and it costs about $0.134 per hour. Since the whole pipeline took less than an hour to be finished, we can estimate that the total cost was about $0.164 for a Vertex AI Pipeline run.

The cost of custom model training depends on the type of machine and the number of hours. Also, you have to consider that you pay for the server and the accelerator separately. For this project, we chose the n1-standard-4 machine type, priced at $0.19 per hour, and the NVIDIA_TESLA_K80 accelerator type, priced at $0.45 per hour. The training for each model was done in less than an hour, so it cost about $1.28 in total. So, as per our estimates, the upper bound of the costs incurred should not be more than $5.

The cost only stems from Vertex AI because the rest of the components like Pub/Sub, Cloud Functions, etc., have very minimal usage. So even if we add a small estimate for those costs, the upper bound of the total cost for this project should not be more than $5. Please refer to the official documents on the price: Vertex AI price reference, Cloud Build price reference.
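The arithmetic behind these estimates can be reproduced in a few lines. This is only a sketch using the prices quoted above. It assumes every job is billed for a full hour, and that two models were trained, which is inferred from $1.28 being twice the combined hourly rate of $0.64.

```python
# Prices quoted in the post (USD). Billing a full hour per job is a conservative assumption.
PIPELINE_RUN_FEE = 0.03          # flat fee per Vertex AI Pipeline run
E2_STANDARD_4_PER_HOUR = 0.134   # machine type used by each TFX component
N1_STANDARD_4_PER_HOUR = 0.19    # machine type used for custom training
K80_PER_HOUR = 0.45              # NVIDIA_TESLA_K80 accelerator


def pipeline_cost(hours=1.0):
    """Run fee plus compute time for one pipeline run."""
    return PIPELINE_RUN_FEE + E2_STANDARD_4_PER_HOUR * hours


def training_cost(num_models, hours_per_model=1.0):
    """Server and accelerator are billed separately, so their rates add up."""
    hourly_rate = N1_STANDARD_4_PER_HOUR + K80_PER_HOUR
    return num_models * hourly_rate * hours_per_model


pipeline_total = round(pipeline_cost(), 3)
training_total = round(training_cost(num_models=2), 2)  # two models is an assumption
print(pipeline_total, training_total)
```

Both figures together stay well below the $5 upper bound mentioned above.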

In any case, you should use this GCP Price Calculator to get a better understanding of how your cost for the GCP services might differ.


In this blog post, we touched upon the idea of continuous evaluation and re-training for machine learning systems as well as the tooling needed to implement them. There is also a more traditional form of CI/CD for ML systems that responds to code changes, including changes in hyperparameters, model architecture, etc. We have a separate project demonstrating that use case. We encourage you to check it out here: Part I and Part II.


We are grateful to the ML-GDE program that provided GCP credits for supporting our experiments. We sincerely thank Robert Crowe and Jiayi Zhao of Google for their help with the review.


Recognizing the 2021 TensorFlow Contributor Awardees

Posted by the TensorFlow team

TensorFlow wouldn’t be where it is today without its bustling global community of contributors. There are many ways these developers contribute. They write code to improve TensorFlow, teach classes, answer questions on forums, and organize and host events.

We are thankful to every person that’s helped the TensorFlow community over the years. And at this year’s TensorFlow Contributor Summit, we wanted to show thanks by recognizing individuals who went above and beyond on their TensorFlow contributions in 2021.

So without further ado, we are pleased to introduce the TensorFlow Contributor Awardees of 2021!

SIG Leadership Award

Awarded to an impactful Special Interest Group (SIG) leader

Jason Zaman, SIG Build

Active SIG Award

Awarded to a highly active SIG

Sean Morgan, SIG Add-ons

TF Forum Award

Awarded to a helpful TF Forum user with many liked posts and responses

Ekaterina Dranitsyna

Diversity and Inclusion Award

Awarded to the person who made a significant effort to bring diversity into the TensorFlow ecosystem

Merve Noyan

Education Outreach Awards

Awarded to the people who made significant contributions to educational outreach

Gant Laborde

Sandeep Mistry

Community Management Award

Awarded to highly active community leaders

TensorFlow User Group Pune (TFUG Pune)

Yogesh Kulkarni, Shashank Sane, and Aditya Kane

Regional Awards

Awarded to top contributors by geographic region

Margaret Maynard-Reid, Americas

Sayak Paul, South Asia / Oceania

Chansung Park, East Asia

Ruqiya Bin Safi, Middle East / Africa

M. Yusuf Sarıgöz, Europe

Art by Margaret Maynard-Reid

Thank you again to all the TensorFlow contributors! We look forward to recognizing even more of you next year.


An Introduction to Keras Preprocessing Layers

Posted by Matthew Watson, Keras Developer

Determining the right feature representation for your data can be one of the trickiest parts of building a model. Imagine you are working with categorical input features such as names of colors. You could one-hot encode the feature so each color gets a 1 in a specific index ('red' = [0, 0, 1, 0, 0]), or you could embed the feature so each color maps to a unique trainable vector ('red' = [0.1, 0.2, 0.5, -0.2]). Larger category spaces might do better with an embedding, and smaller spaces with a one-hot encoding, but the answer is not clear cut. It will require experimentation on your specific dataset.
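To make the two representations concrete, here is a minimal plain-Python sketch. The five-color vocabulary and the random embedding values are made up for illustration. In a real model the embedding rows would be trainable weights rather than fixed random numbers.

```python
import random

# Hypothetical vocabulary; 'red' lands at index 2, as in the example above.
colors = ['blue', 'green', 'red', 'white', 'yellow']
index_of = {color: i for i, color in enumerate(colors)}


def one_hot(color):
    """A vector of zeros with a single 1 at the color's index."""
    vector = [0] * len(colors)
    vector[index_of[color]] = 1
    return vector


# An embedding is a lookup table with one row per category.
rng = random.Random(0)
embedding_table = [[rng.uniform(-1, 1) for _ in range(4)] for _ in colors]


def embed(color):
    """Look up the color's row in the embedding table."""
    return embedding_table[index_of[color]]


print(one_hot('red'))
print(embed('red'))
```

The one-hot vector grows with the vocabulary, while the embedding width (4 here) is a free choice, which is why larger category spaces tend to favor embeddings.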

Ideally, we would like updates to our feature representation and updates to our model architecture to happen in a tight iterative loop, applying new transformations to our data while changing our model architecture. In practice, feature preprocessing and model building are usually handled by entirely different libraries, frameworks, or languages. This can slow the process of experimentation.

On the Keras team, we recently released Keras Preprocessing Layers, a set of Keras layers aimed at making data preprocessing fit more naturally into model development workflows. In this post we are going to use the layers to build a simple sentiment classification model with the IMDB movie review dataset. The goal will be to show how preprocessing can be flexibly developed and applied. To start, we can import TensorFlow and download the training data.

import tensorflow as tf
import tensorflow_datasets as tfds

train_ds = tfds.load('imdb_reviews', split='train', as_supervised=True).batch(32)

Keras preprocessing layers can handle a wide range of input, including structured data, images, and text. In this case, we will be working with raw text, so we will use the TextVectorization layer.

By default, the TextVectorization layer will process text in three phases:

  • First, remove punctuation and lowercase the input.
  • Next, split text into lists of individual string words.
  • Finally, map strings to numeric outputs using a vocabulary of known words.

A simple approach we can try here is a multi-hot encoding, where we only consider the presence or absence of terms in the review. For example, say a layer vocabulary is ['movie', 'good', 'bad'], and a review read 'This movie was bad.'. We would encode this as [1, 0, 1], where movie (the first vocab term) and bad (the last vocab term) are present.

text_vectorizer = tf.keras.layers.TextVectorization(
    output_mode='multi_hot', max_tokens=2500)
features = train_ds.map(lambda x, y: x)
text_vectorizer.adapt(features)

Above, we create a TextVectorization layer with multi-hot output, and do two things to set the layer’s state. First, we map over our training dataset and discard the integer label indicating a positive or negative review. This gives us a dataset containing only the review text. Next, we adapt() the layer over this dataset, which causes the layer to learn a vocabulary of the most frequent terms in all documents, capped at a max of 2500.

Adapt is a utility function on all stateful preprocessing layers, which allows layers to set their internal state from input data. Calling adapt is always optional. For TextVectorization, we could instead supply a precomputed vocabulary on layer construction, and skip the adapt step.
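The three processing phases, the adapt step, and the multi-hot encoding can be mimicked in plain Python to see what happens to a single review. This is only an illustrative sketch, not how the Keras layer is implemented: the helper names are hypothetical, and the real layer also handles out-of-vocabulary tokens, which this sketch ignores.

```python
import string
from collections import Counter


def standardize_and_split(text):
    """Phases 1 and 2: lowercase, strip punctuation, split on whitespace."""
    lowered = text.lower()
    stripped = lowered.translate(str.maketrans('', '', string.punctuation))
    return stripped.split()


def adapt(documents, max_tokens):
    """Learn a vocabulary of the most frequent terms, capped at max_tokens."""
    counts = Counter(
        token for doc in documents for token in standardize_and_split(doc))
    return [token for token, _ in counts.most_common(max_tokens)]


def multi_hot(text, vocabulary):
    """Phase 3: mark which vocabulary terms are present in the text."""
    present = set(standardize_and_split(text))
    return [1 if term in present else 0 for term in vocabulary]


vocabulary = ['movie', 'good', 'bad']
encoded = multi_hot('This movie was bad.', vocabulary)
print(encoded)

learned = adapt(['Good, good, bad', 'good movie'], max_tokens=2)
print(learned)
```

Supplying `vocabulary` directly corresponds to passing a precomputed vocabulary at construction time, while `adapt` corresponds to learning it from data.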

We can now train a simple linear model on top of this multi-hot encoding. We will define two functions: preprocess, which converts raw input data to the representation we want for our model, and forward_pass, which applies the trainable layers.

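def preprocess(x):
    return text_vectorizer(x)

def forward_pass(x):
    return tf.keras.layers.Dense(1)(x)  # Linear model.

inputs = tf.keras.Input(shape=(1,), dtype='string')
outputs = forward_pass(preprocess(inputs))
model = tf.keras.Model(inputs, outputs)
# The Dense layer outputs logits, so the loss is computed from logits.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
model.fit(train_ds, epochs=5)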

That’s it for an end-to-end training example, and already enough for 85% accuracy. You can find complete code for this example at the bottom of this post.

Let’s experiment with a new feature. Our multi-hot encoding does not contain any notion of review length, so we can try adding a feature for normalized string length. Preprocessing layers can be mixed with TensorFlow ops and custom layers as desired. Here we can combine the tf.strings.length function with the Normalization layer, which will scale the input to have 0 mean and 1 variance. We have only updated code up to the preprocess function below, but we will show the rest of training for clarity.

# This layer will scale our review length feature to mean 0 variance 1.
normalizer = tf.keras.layers.Normalization(axis=None)
normalizer.adapt(features.map(lambda x: tf.strings.length(x)))

def preprocess(x):
    multi_hot_terms = text_vectorizer(x)
    normalized_length = normalizer(tf.strings.length(x))
    # Combine the multi-hot encoding with review length.
    return tf.keras.layers.concatenate((multi_hot_terms, normalized_length))

def forward_pass(x):
    return tf.keras.layers.Dense(1)(x)  # Linear model.

inputs = tf.keras.Input(shape=(1,), dtype='string')
outputs = forward_pass(preprocess(inputs))
model = tf.keras.Model(inputs, outputs)
# The Dense layer outputs logits, so the loss is computed from logits.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
model.fit(train_ds, epochs=5)

Above, we create the normalization layer and adapt it to our input. Within the preprocess function, we simply concatenate our multi-hot encoding and length features together. We learn a linear model over the union of the two feature representations.
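What adapting a normalization layer amounts to can be sketched in plain Python: compute the mean and variance of the feature once, then reuse them to rescale every input. The helper names and sample reviews below are hypothetical, and `len` counts characters, which only matches a byte count for ASCII text.

```python
import math


def adapt_normalizer(values):
    """Compute the statistics the layer would store as its state."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, variance


def normalize(value, mean, variance):
    """Rescale a value so the adapted data has mean 0 and variance 1."""
    return (value - mean) / math.sqrt(variance)


reviews = ['Bad.', 'Loved it!', 'A long, slow, but rewarding film.']  # made-up data
lengths = [len(review) for review in reviews]

mean, variance = adapt_normalizer(lengths)
normalized = [normalize(length, mean, variance) for length in lengths]
print(normalized)
```

Because the statistics are fixed after adapting, the same transformation is applied at training and inference time.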

The last change we can make is to speed up training. We have one major opportunity to improve our training throughput. Right now, every training step, we spend some time on the CPU performing string operations (which cannot run on an accelerator), followed by calculating a loss function and gradients on a GPU.

With all computation in a single model, we will first preprocess each batch on the CPU and then update parameter weights on the GPU. This leaves gaps in our GPU usage.

This gap in accelerator usage is totally unnecessary! Preprocessing is distinct from the actual forward pass of our model. The preprocessing doesn’t use any of the parameters being trained. It’s a static transformation that we could precompute.

To speed things up, we would like to prefetch our preprocessed batches, so that each time we are training on one batch we are preprocessing the next. This is easy to do with the tf.data library, which was built for uses like this. The only major change we need to make is to split our monolithic keras.Model into two: one for preprocessing and one for training. This is easy with Keras’ functional API.

inputs = tf.keras.Input(shape=(1,), dtype="string")
preprocessed_inputs = preprocess(inputs)
outputs = forward_pass(preprocessed_inputs)

# The first model will only apply preprocessing.
preprocessing_model = tf.keras.Model(inputs, preprocessed_inputs)
# The second model will only apply the forward pass.
training_model = tf.keras.Model(preprocessed_inputs, outputs)
training_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

# Apply preprocessing asynchronously with tf.data.
# It is important to call prefetch and remember the AUTOTUNE options.
preprocessed_ds = train_ds.map(
    lambda x, y: (preprocessing_model(x), y),
    num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

# Now the GPU can focus on the training part of the model.
training_model.fit(preprocessed_ds, epochs=5)

In the above example, we pass a single keras.Input through our preprocess and forward_pass functions, but define two separate models over the transformed inputs. This slices our single graph of operations into two. Another valid option would be to only make a training model, and call the preprocess function directly when we map over our dataset. In this case, the keras.Input would need to reflect the type and shape of the preprocessed features rather than the raw strings.

Using tf.data to prefetch batches cuts our train step time by over 30%! Our compute time now looks more like the following:

With tf.data, we are now precomputing each preprocessed batch before the GPU needs it. This significantly speeds up training.

We could even go a step further than this, and use tf.data to cache our preprocessed dataset in memory or on disk. We would simply add a .cache() call directly before the call to prefetch. In this way, we could entirely skip computing our preprocessing batches after the first epoch of training.
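The idea behind prefetching can be shown with a small producer and consumer in plain Python. This is a conceptual sketch only: the `prefetch` helper is hypothetical, and tf.data implements the same pattern far more efficiently and tunes the buffer size for you.

```python
import queue
import threading


def prefetch(iterable, buffer_size=1):
    """Yield items from `iterable` while a background thread prepares the next ones."""
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()

    def producer():
        try:
            for item in iterable:
                buffer.put(item)  # blocks while the buffer is full
        finally:
            buffer.put(sentinel)  # always signal the end of the stream

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is sentinel:
            return
        yield item


raw_batches = ['Great film', 'Awful plot', 'Just fine']  # made-up data
# The generator expression stands in for the CPU-side preprocessing step.
preprocessed = prefetch((batch.lower().split() for batch in raw_batches))
results = list(preprocessed)
print(results)
```

While the consumer works on one batch, the producer thread is already preparing the next one, which is what closes the gaps in accelerator usage.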

After training, we can rejoin our split model into a single model during inference. This allows us to save a model that can directly handle raw input data.

inputs = preprocessing_model.input
outputs = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outputs)
inference_model.predict(
    tf.constant(["Terrible, no good, trash.", "I loved this movie!"]))

Keras preprocessing layers aim to provide a flexible and expressive way to build data preprocessing pipelines. Prebuilt layers can be mixed and matched with custom layers and other TensorFlow functions. Preprocessing can be split from training and applied efficiently with tf.data, and joined later for inference. We hope they allow for more natural and efficient iterations on feature representation in your models.

To play around with the code from this post in a Colab, you can follow this link. To see a wide range of tasks you can do with preprocessing layers, see the Quick Recipes section of our preprocessing guide. You can also check out our complete tutorials for basic text classification, image data augmentation, and structured data classification.


Announcing TensorFlow’s Kaggle Challenge to Help Protect Coral Reefs

Posted by Megha Malpani & Tim Davis, Google Product Managers

We are excited to announce a TensorFlow-sponsored Kaggle challenge to locate and identify harmful crown-of-thorns starfish (COTS), as part of a broader partnership between the Commonwealth Scientific and Industrial Research Organization (CSIRO) and Google, to help protect coral reefs everywhere.

Coral reefs are some of the most diverse and important ecosystems in the world – both for marine life and society more broadly. Not only are healthy reefs critical to fisheries and food security, they provide countless additional benefits: protecting coastlines from storm surge, supporting tourism-based economies and sustainable livelihoods, and pushing forward drug discovery research.

Reefs around the world face a number of rising threats, most notably climate change, pollution, and overfishing. In the past 30 years alone, there have been dramatic losses in coral cover and habitat in the Great Barrier Reef (GBR), with other reefs experiencing similar declines. In Australia, outbreaks of the coral-eating COTS have been shown to cause major coral loss. These outbreaks can strip a reef of 90% of its coral tissue. While COTS naturally exist in the Indo-Pacific ocean, overfishing and excess run-off nutrients have led to massive outbreaks that are devastating already vulnerable coral communities.

Controlling COTS populations is critical to reducing coral mortality from outbreaks. Google has teamed up with CSIRO to supercharge efforts in monitoring COTS using artificial intelligence. This is just the beginning of a much deeper collaboration and we, along with the Great Barrier Reef Foundation, are extremely excited to invite you, our global ML community, to help protect the world’s reefs.

We are challenging the Kaggle community to build the most accurate and performant (in terms of runtime and memory usage) crown-of-thorns starfish object detection models for image sequences. For this challenge, we are offering $150,000 in prizes to the best solutions.

We have two tiers of prizes – the first, in standard Kaggle fashion, for the most accurate models. Since we will be deploying these models on the edge, we are offering an additional prize for the most performant models (that fall in the top 10% of the accuracy leaderboard). We are looking for creative ideas on how to maximize performance while working effectively with underwater image sequences. We intend to ultimately bring the most innovative ideas together in a single model that we deploy on the Great Barrier Reef. We plan to open-source the winning model for other scientific organizations and agencies around the world to use.

This is an amazing opportunity to have a real impact protecting coral reefs everywhere! The competition is now live, so please join the challenge today and get started with this notebook. We look forward to seeing what you come up with. Good luck!

Acknowledgements: Thanks to everyone whose hard work made this collaboration possible!

Google: Martin Wicke, Kemal El Moujahid, Sarah Sirajuddin, Scott Riddle, Glenn Cameron, Addison Howard, Will Cukierski, Sohier Dane, Ryan Holbrook, Khanh LeViet, Sachin Joglekar, Tei Jeong, Rachel Stiegler, Daniel Formoso, Tom Small, Ana Nieto, Arun Venkatesan

CSIRO: Jiajun Liu, Brano Kusy, Ross Marchant, David Ahmedt, Lachlan Tychsen-Smith, Joey Crosswell, Geoffrey Carlin, Russ Babcock


Introducing TensorFlow Graph Neural Networks

Posted by Sibon Li, Jan Pfeifer, Bryan Perozzi, and Douglas Yarrington

Today, we are excited to release TensorFlow Graph Neural Networks (GNNs), a library designed to make it easy to work with graph structured data using TensorFlow. We have used an earlier version of this library in production at Google in a variety of contexts (for example, spam and anomaly detection, traffic estimation, YouTube content labeling) and as a component in our scalable graph mining pipelines. In particular, given the myriad types of data at Google, our library was designed with heterogeneous graphs in mind. We are releasing this library with the intention to encourage collaborations with researchers in industry.

Why use GNNs?

Graphs are all around us, in the real world and in our engineered systems. A set of objects, places, or people and the connections between them is generally describable as a graph. More often than not, the data we see in machine learning problems is structured or relational, and thus can also be described with a graph. And while fundamental research on GNNs is perhaps decades old, recent advances in the capabilities of modern GNNs have led to advances in domains as varied as traffic prediction, rumor and fake news detection, modeling disease spread, physics simulations, and understanding why molecules smell.

Graphs can model the relationships between many different types of data, including web pages (left), social connections (center), or molecules (right).

A graph represents the relations (edges) between a collection of entities (nodes or vertices). We can characterize each node, edge, or the entire graph, and thereby store information in each of these pieces of the graph. Additionally, we can ascribe directionality to edges to describe information or traffic flow, for example.
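As a minimal illustration of these pieces, the sketch below stores features on nodes, on directed edges, and on the graph as a whole using plain Python containers. The tiny web-page graph and all of its values are made up, and this is not the TF-GNN data structure.

```python
from collections import defaultdict

# Hypothetical toy graph: every name and feature value is made up.
node_features = {
    'page_a': {'words': 120},
    'page_b': {'words': 45},
    'page_c': {'words': 300},
}
# Each edge is (source, target, features); the order encodes direction.
directed_edges = [
    ('page_a', 'page_b', {'clicks': 30}),
    ('page_a', 'page_c', {'clicks': 5}),
    ('page_c', 'page_a', {'clicks': 12}),
]
graph_features = {'name': 'tiny_web'}


def out_neighbors(edges):
    """Build an adjacency list that follows edge direction."""
    adjacency = defaultdict(list)
    for source, target, _ in edges:
        adjacency[source].append(target)
    return dict(adjacency)


adjacency = out_neighbors(directed_edges)
print(adjacency)
```

Node-level, edge-level, and graph-level tasks correspond to making predictions about entries of `node_features`, `directed_edges`, and `graph_features` respectively.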

GNNs can be used to answer questions about multiple characteristics of these graphs. By working at the graph level, we try to predict characteristics of the entire graph. We can identify the presence of certain “shapes,” like circles in a graph that might represent sub-molecules or perhaps close social relationships. GNNs can be used on node-level tasks, to classify the nodes of a graph, and predict partitions and affinity in a graph similar to image classification or segmentation. Finally, we can use GNNs at the edge level to discover connections between entities, perhaps using GNNs to “prune” edges to identify the state of objects in a scene.


TF-GNN provides building blocks for implementing GNN models in TensorFlow. Beyond the modeling APIs, our library also provides extensive tooling around the difficult task of working with graph data: a Tensor-based graph data structure, a data handling pipeline, and some example models for users to quickly onboard.

The various components of TF-GNN that make up the workflow.

The initial release of the TF-GNN library contains a number of utilities and features for use by beginners and experienced users alike, including:

  • A high-level Keras-style API to create GNN models that can easily be composed with other types of models. GNNs are often used in combination with ranking, deep-retrieval (dual-encoders) or mixed with other types of models (image, text, etc.)
    • GNN API for heterogeneous graphs. Many of the graph problems we approach at Google and in the real world contain different types of nodes and edges. Hence we chose to provide an easy way to model this.
  • A well-defined schema to declare the topology of a graph, and tools to validate it. This schema describes the shape of its training data and serves to guide other tools.
  • A GraphTensor composite tensor type which holds graph data, can be batched, and has graph manipulation routines available.
  • A library of operations on the GraphTensor structure:
    • Various efficient broadcast and pooling operations on nodes and edges, and related tools.
    • A library of standard baked convolutions, that can be easily extended by ML engineers/researchers.
    • A high-level API for product engineers to quickly build GNN models without necessarily worrying about its details.
  • An encoding of graph-shaped training data on disk, as well as a library used to parse this data into a data structure from which your model can extract the various features.

Example usage

In the example below, we build a model using the TF-GNN Keras API to recommend movies to a user based on what they watched and genres that they liked.

We use the ConvGNNBuilder method to specify the type of edge and node configuration, namely to use WeightedSumConvolution (defined below) for edges. And for each pass through the GNN, we will update the node values through a Dense interconnected layer:

import tensorflow as tf
import tensorflow_gnn as tfgnn

# Model hyper-parameters:
h_dims = {'user': 256, 'movie': 64, 'genre': 128}

# Model builder initialization:
gnn = tfgnn.keras.ConvGNNBuilder(
    lambda edge_set_name: WeightedSumConvolution(),
    lambda node_set_name: tfgnn.keras.layers.NextStateFromConcat(
        tf.keras.layers.Dense(h_dims[node_set_name]))
)

# Two rounds of message passing to target node sets:
model = tf.keras.models.Sequential([
    gnn.Convolve({'genre'}),  # sends messages from movie to genre
    gnn.Convolve({'user'}),  # sends messages from movie and genre to users
])

The code above works great, but sometimes we may want to use a more powerful custom model architecture for our GNNs. For example, in our previous use case, we might want to specify that certain movies or genres hold more weight when we give our recommendation. In the following snippet, we define a more advanced GNN with custom graph convolutions, in this case with weighted edges. We define the WeightedSumConvolution class to pool edge values as a sum of weights across all edges:

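class WeightedSumConvolution(tf.keras.layers.Layer):
    """Weighted sum of source nodes states."""

    def call(self, graph: tfgnn.GraphTensor,
             edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
        # Copy each source node's state onto its outgoing edges.
        messages = tfgnn.broadcast_node_to_edges(
            graph,
            edge_set_name,
            tfgnn.SOURCE,
            feature_name=tfgnn.DEFAULT_STATE_NAME)
        weights = graph.edge_sets[edge_set_name]['weight']
        weighted_messages = tf.expand_dims(weights, -1) * messages
        # Sum the weighted messages arriving at each target node.
        pooled_messages = tfgnn.pool_edges_to_node(
            graph,
            edge_set_name,
            tfgnn.TARGET,
            reduce_type='sum',
            feature_value=weighted_messages)
        return pooled_messages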

Note that even though the convolution was written with only the source and target nodes in mind, TF-GNN makes sure it’s applicable and works on heterogeneous graphs (with various types of nodes and edges) seamlessly.
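The computation inside a weighted-sum convolution can be traced by hand on a tiny example. The following plain-Python sketch does not use TF-GNN: the function, the index-based edge encoding, and the numbers are all assumptions made for illustration.

```python
def weighted_sum_convolution(source_states, edge_sources, edge_targets,
                             edge_weights, num_targets):
    """Sum weight * source_state over the edges arriving at each target node."""
    dim = len(source_states[0])
    pooled = [[0.0] * dim for _ in range(num_targets)]
    for source, target, weight in zip(edge_sources, edge_targets, edge_weights):
        message = source_states[source]  # broadcast the node state onto the edge
        for d in range(dim):
            pooled[target][d] += weight * message[d]  # weight, then sum-pool
    return pooled


# Two hypothetical movies with 2-dimensional states, connected to two users.
movie_states = [[1.0, 0.0], [0.0, 2.0]]
edge_sources = [0, 1, 1]        # movie index for each edge
edge_targets = [0, 0, 1]        # user index for each edge
edge_weights = [0.5, 1.0, 2.0]

user_messages = weighted_sum_convolution(
    movie_states, edge_sources, edge_targets, edge_weights, num_targets=2)
print(user_messages)
```

User 0 receives 0.5 times the first movie's state plus 1.0 times the second's, and user 1 receives 2.0 times the second movie's state.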

Next steps

You can check out the TF-GNN GitHub repo for more information. To stay up to date, you can read the TensorFlow blog, join the TensorFlow Forum at discuss.tensorflow.org, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub. Thank you!


The work described here was a research collaboration between Oleksandr Ferludin, Martin Blais, Jan Pfeifer, Arno Eigenwillig, Dustin Zelle, Bryan Perozzi and Da-Cheng Juan of Google, and Sibon Li, Alvaro Sanchez-Gonzalez, Peter Battaglia, Kevin Villela, Jennifer She and David Wong of DeepMind.


ML Community Day 2021 Recap

Posted by the TensorFlow Team

Thanks to everyone who joined our inaugural virtual ML Community Day! It was so great to get the community together and hear incredible talks like how JAX and TPUs make AlphaFold possible from the DeepMind team, and how Edge Impulse makes it easy for developers to work with TinyML using TensorFlow.

We also celebrated TensorFlow’s 6th birthday! The TensorFlow ecosystem has come a long way in 6 years, and we love seeing what you all achieve with our tools, from using machine learning to help advance access to human rights information to creating a custom, TensorFlow-powered drumming arm.

In this article are a few of the updates and topics we shared during the event. You can watch the keynote below, and you can find recordings of every talk on the TensorFlow YouTube channel.

Model building

TensorFlow 2.7 is here! This release offers performance and usability improvements, including TFLite use of XNNPack for mobile inference performance boosts, training improvements on GPUs, and a dramatic improvement in debugging efficiency in Keras and TF.

Keras has been modularized as a separate pip package on top of TensorFlow (installed by default) and now lives in a separate GitHub repository. This will make it much easier for the community to contribute to the development of Keras. We welcome your PRs!

Responsible AI

The Responsible AI team also announced v0.4 of our Language Interpretability Tool (LIT). LIT is an open-source platform for visualization and understanding of NLP models. This new release includes new interpretability techniques like TCAV (Testing with Concept Activation Vectors). TCAV is an interpretability method for ML models that shows the importance of high-level concepts for a predicted class.


We recently launched on-device training in TensorFlow Lite. When deploying a TensorFlow Lite machine learning model to a mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. Using on-device training techniques allows you to update a model without data leaving your users’ devices, improving user privacy, and without requiring users to update the device software. It’s currently available on Android.

And we continue to work on improving TensorFlow Lite performance. As mentioned above, XNNPACK, a library for faster floating point ops, is now turned on by default in TensorFlow Lite. This allows your models to run on average 2.3x faster on the CPU.

Find all the talks here

You can find all of the content in this playlist.


3D Hand Pose with MediaPipe and TensorFlow.js

Posted by Valentin Bazarevsky, Ivan Grishchenko, Eduard Gabriel Bazavan, Andrei Zanfir, Mihai Zanfir, Jiuqiang Tang, Jason Mayes, Ahmed Sabie, Google

Today, we’re excited to share a new version of our model for hand pose detection, with improved accuracy for 2D, novel support for 3D, and the new ability to predict keypoints on both hands simultaneously. Support for multi-hand tracking was one of the most common requests from the developer community, and we’re pleased to support it in this release.

You can try a live demo of the new model here. This work improves on our previous model which predicted 21 keypoints, but could only detect a single hand at a time. In this article, we’ll describe the new model, and how you can get started.

The new hand pose detection model in action.

Try out the live demo!

How to use it

1. The first step is to import the library. You can either use the <script> tag in your HTML file or use NPM:

Through script tag:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/hand-pose-detection"></script>
<!-- Optional: Include below scripts if you want to use MediaPipe runtime. -->
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/hands"></script>

Through NPM:

yarn add @tensorflow-models/hand-pose-detection

# Run below commands if you want to use TF.js runtime.
yarn add @tensorflow/tfjs-core @tensorflow/tfjs-converter
yarn add @tensorflow/tfjs-backend-webgl

# Run below commands if you want to use MediaPipe runtime.
yarn add @mediapipe/hands

If installed through NPM, you need to import the libraries first:

import * as handPoseDetection from '@tensorflow-models/hand-pose-detection';

Next create an instance of the detector:

const model = handPoseDetection.SupportedModels.MediaPipeHands;
const detectorConfig = {
  runtime: 'mediapipe', // or 'tfjs'
  modelType: 'full'
};
const detector = await handPoseDetection.createDetector(model, detectorConfig);

Choose a modelType that fits your application needs. There are two options to choose from: lite and full. From lite to full, the accuracy increases while the inference speed decreases.

2. Once you have a detector, you can pass in a video stream or static image to detect poses:

const video = document.getElementById('video');
const hands = await detector.estimateHands(video);

The output format is as follows: hands represents an array of detected hand predictions in the image frame. For each hand, the structure contains a prediction of the handedness (left or right) as well as a confidence score for this prediction. An array of 2D keypoints is also returned, where each keypoint contains x, y, and name. The x and y values denote the horizontal and vertical position of the hand keypoint in image pixel space, and name denotes the joint label. In addition to 2D keypoints, we also return 3D keypoints (x, y, z values) in a metric scale, with the origin at an auxiliary keypoint formed as the average of the first knuckles of the index, middle, ring and pinky fingers.

[
  {
    score: 0.8,
    handedness: 'Right',
    keypoints: [
      {x: 105, y: 107, name: "wrist"},
      {x: 108, y: 160, name: "pinky_finger_tip"},
      ...
    ],
    keypoints3D: [
      {x: 0.00388, y: -0.0205, z: 0.0217, name: "wrist"},
      {x: -0.025138, y: -0.0255, z: -0.0051, name: "pinky_finger_tip"},
      ...
    ]
  }
]
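To illustrate how the 3D origin is defined, the following Python sketch averages the four first-knuckle keypoints of one hand. The knuckle names (`index_finger_mcp` and so on) and the coordinate values are assumptions made for this example, so check the README for the exact keypoint names the API returns.

```python
# Assumed names for the first knuckles; verify them against the API's keypoint list.
KNUCKLES = ('index_finger_mcp', 'middle_finger_mcp',
            'ring_finger_mcp', 'pinky_finger_mcp')


def auxiliary_origin(keypoints_3d):
    """Average the first-knuckle positions of one hand."""
    by_name = {kp['name']: kp for kp in keypoints_3d}
    return tuple(
        sum(by_name[name][axis] for name in KNUCKLES) / len(KNUCKLES)
        for axis in ('x', 'y', 'z'))


# Made-up 3D keypoints, mirroring the structure of the keypoints3D array.
hand_keypoints_3d = [
    {'x': 0.004, 'y': -0.02, 'z': 0.02, 'name': 'wrist'},
    {'x': 0.02, 'y': 0.01, 'z': 0.0, 'name': 'index_finger_mcp'},
    {'x': 0.01, 'y': -0.01, 'z': 0.0, 'name': 'middle_finger_mcp'},
    {'x': -0.01, 'y': 0.02, 'z': 0.0, 'name': 'ring_finger_mcp'},
    {'x': -0.02, 'y': -0.02, 'z': 0.0, 'name': 'pinky_finger_mcp'},
]

origin = auxiliary_origin(hand_keypoints_3d)
print(origin)
```

Because the 3D keypoints are expressed relative to this auxiliary point, the average computed here comes out at approximately zero on every axis.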

You can refer to our README for more details about the API.

Model deep dive

The updated version of our hand pose detection API improves the quality for 2D keypoint prediction, handedness (classification output whether it is left or right hand), and minimizes the number of false positive detections. More details about the updated model can be found in our recent paper: On-device Real-time Hand Gesture Recognition.

Following our recently released BlazePose GHUM 3D in TensorFlow.js, we also added metric-scale 3D keypoint prediction to hand pose detection in this release, with the origin being represented by an auxiliary keypoint, formed as a mean of first knuckles for index, middle, ring and pinky fingers. Our 3D ground truth is based on a statistical 3D human body model called GHUM, which is built using a large corpus of human shapes and motions.

To obtain hand pose ground truth, we fitted the GHUM hand model to our existing 2D hand dataset and recovered real world 3D keypoint coordinates. The shape and the hand pose variables of the GHUM hand model were optimized such that the reconstructed model aligns with the image evidence. This includes 2D keypoint alignment, shape, and pose regularization terms as well as anthropometric joint angle limits and model self contact penalties.

Sample GHUM hand fittings for hand images with 2D keypoint annotations overlaid. The data was used to train and test a variety of poses leading to better results for more extreme poses.

Model quality

In this new release, we substantially improved the quality of the models and evaluated them on a dataset of American Sign Language (ASL) gestures. As the evaluation metric for 2D screen coordinates, we used the Mean Average Precision (mAP) suggested by the COCO keypoint challenge methodology.

Hand model evaluation on American Sign Language dataset

For 3D evaluation we used Mean Absolute Error in Euclidean 3D metric space, with the average error measured in centimeters.
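As an illustration of this metric, the sketch below computes the mean Euclidean distance between predicted and ground-truth 3D keypoints and reports it in centimeters. It assumes the inputs are given in meters and interprets the metric as a per-keypoint average; the function name `mean_3d_error_cm` is hypothetical and this is not the evaluation code used for the table.

```python
import math


def mean_3d_error_cm(predicted, ground_truth):
    """Mean Euclidean distance between matching 3D keypoints, in centimeters.

    Both arguments are sequences of (x, y, z) tuples, assumed to be in meters.
    """
    if len(predicted) != len(ground_truth) or not predicted:
        raise ValueError(
            "Expected two non-empty keypoint lists of equal length. "
            f"Received lengths {len(predicted)} and {len(ground_truth)}."
        )
    distances = [math.dist(p, g) for p, g in zip(predicted, ground_truth)]
    return 100.0 * sum(distances) / len(distances)  # meters -> centimeters


# Made-up keypoints: the errors are 1 cm and 5 cm.
predicted = [(0.0, 0.0, 0.0), (0.03, 0.04, 0.0)]
ground_truth = [(0.0, 0.0, 0.01), (0.0, 0.0, 0.0)]
error_cm = mean_3d_error_cm(predicted, ground_truth)
```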

Model Name | 2D, mAP, % | 3D, mean 3D error, cm
HandPose GHUM Lite | |
HandPose GHUM Full | |
Previous TensorFlow.js HandPose | |

Quality metrics for the newly released HandPose GHUM models vs. the previously released TensorFlow.js HandPose model, for 2D and 3D predictions

Browser performance

We benchmarked the model across multiple devices. All benchmarks were run with two hands present in the frame.

MediaPipe Runtime (with WASM & GPU Accel.)

  • MacBook Pro 15” 2019 (Intel Core i9, AMD Radeon Pro Vega 20 Graphics): 62 | 48
  • iPhone 11: 8 | 5
  • Pixel 5: 19 | 15
  • Intel i9-10900K, Nvidia GTX 1070 GPU: 136 | 120

TensorFlow.js Runtime (with WebGL backend)

  • MacBook Pro 15” 2019 (Intel Core i9, AMD Radeon Pro Vega 20 Graphics): 36 | 31
  • iPhone 11: 15 | 12
  • Pixel 5: 11 | 8
  • Intel i9-10900K, Nvidia GTX 1070 GPU: 42 | 35

Inference speed of HandPose across different devices and runtimes. The first number in each pair is for the lite model, and the second number is for the full model.

To see the model’s FPS on your device, try our demo. You can switch the model type and runtime live in the demo UI to see what works best for your device.

Cross platform availability

In addition to the JavaScript hand pose detection API, these updated hand models are also available in MediaPipe Hands as a ready-to-use Android Solution API and Python Solution API, with prebuilt packages in Android Maven Repository and Python PyPI respectively.

For instance, for Android developers the Maven package can be easily integrated into an Android Studio project by adding the following into the project’s Gradle dependencies:

dependencies {
    implementation 'com.google.mediapipe:solution-core:latest.release'
    implementation 'com.google.mediapipe:hands:latest.release'
}

The MediaPipe Android Solution is designed to handle different use scenarios, such as processing live camera feeds, video files, and static images. It also comes with utilities for overlaying the output landmarks onto either CPU images (with Canvas) or GPU images (using OpenGL). For instance, the following code snippet demonstrates how it can be used to process a live camera feed and render the output on screen in real time:

// Creates MediaPipe Hands.
HandsOptions handsOptions =
Hands hands = new Hands(activity, handsOptions);

// Connects MediaPipe Hands to camera.
CameraInput cameraInput = new CameraInput(activity);
cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame));

// Registers a result listener.
handsResult -> {

// Starts the camera to feed data to MediaPipe Hands.

To learn more about MediaPipe Android Solutions, please refer to our documentation and try them out with the example Android Studio project. Also visit MediaPipe Solutions for more cross-platform solutions.


We would like to acknowledge our colleagues who participated in or sponsored creating HandPose GHUM 3D and building the APIs: Cristian Sminchisescu, Michael Hays, Na Li, Ping Yu, George Sung, Jonathan Baccash, Esha Uboweja, David Tian, Kanstantsin Sokal, Gregory Karpiak, Tyler Mullen, Chuo-Ling Chang, Matthias Grundmann.


Women in Machine Learning Symposium – Event Recap

Posted by Joana Carrasqueira, Program Manager, TensorFlow.

Thank you to everyone who joined us at the first Women in Machine Learning Symposium!

Hundreds of practitioners joined from all over the world to share tips and insights for careers in ML, how to be involved in the community, contribute to open source, and much more. It was very inspiring to learn from each other’s experiences. Following is a quick recap, and an overview of the resources we discussed at the event. Thanks again.

Online education

Get involved in the community

Build your portfolio

Connect with (or become) a GDE


What’s new in TensorFlow 2.7?

Posted by Goldie Gadde and Josh Gordon for the TensorFlow team

TensorFlow 2.7 is here! This release improves usability with clearer error messages and simplified stack traces, and adds new tools and documentation for users migrating to TF2.

Improved Debugging Experience

The process of debugging your code is a fundamental part of the user experience of a machine learning framework. In this release, we’ve considerably improved the TensorFlow debugging experience to make it more productive and more enjoyable, via three major changes: simplified stack traces, displaying additional context information in errors that originate from custom Keras layers, and a wide-ranging audit of all error messages in Keras and TensorFlow.

Simplified stack traces

By default, TensorFlow now filters the stack traces displayed upon error to hide any frame that originates from TensorFlow-internal code, keeping the information focused on what matters to you: your own code. This makes stack traces simpler and shorter, and makes it easier to understand and fix the problems in your code.

If you’re actually debugging the TensorFlow codebase itself (for instance, because you’re preparing a PR for TensorFlow), you can turn off the filtering mechanism by calling tf.debugging.disable_traceback_filtering().
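As a small sketch of how the toggle can be used, the helper below turns filtering off for a single call and then restores the previous setting. It assumes TensorFlow 2.7 or later is installed; `run_with_full_tracebacks` is a hypothetical helper name, while `tf.debugging.disable_traceback_filtering`, `tf.debugging.enable_traceback_filtering`, and `tf.debugging.is_traceback_filtering_enabled` are the library calls it relies on.

```python
import tensorflow as tf


def run_with_full_tracebacks(fn):
    """Call fn() with traceback filtering disabled, then restore the prior state."""
    was_enabled = tf.debugging.is_traceback_filtering_enabled()
    tf.debugging.disable_traceback_filtering()
    try:
        return fn()
    finally:
        # Only re-enable filtering if it was on before we were called.
        if was_enabled:
            tf.debugging.enable_traceback_filtering()


# Any error raised inside the callable would show TensorFlow-internal frames too.
result = run_with_full_tracebacks(lambda: tf.add(1, 2))
```

Restoring the previous state keeps the unfiltered, longer traces confined to the code you are actively investigating.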

Automatic context injection for Keras layer exceptions

One of the most common use cases for writing low-level code is creating custom Keras layers, so we wanted to make debugging your layers as easy and productive as possible. The first thing you do when you’re debugging a layer is to print the shapes and dtypes of its inputs, as well as the values of its training and mask arguments. We now add this information automatically to all stack traces that originate from custom Keras layers.

See the effect of stack trace filtering and call context information display in practice in the image below:

Simplified stack traces in TensorFlow 2.7

Audit and improve all error messages in the TensorFlow and Keras codebases

Lastly, we’ve audited every error message in the Keras and TensorFlow codebases (thousands of error locations!) and improved them to make sure they follow UX best practices. A good error message should tell you what the framework expected, what you did that didn’t match the framework’s expectations, and should provide tips to fix the problem.
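The three-part structure described above can be sketched in plain Python. This is only an illustration of the guideline, not code from the Keras or TensorFlow codebases; `build_error_message` and `check_rank` are hypothetical helpers.

```python
def build_error_message(expected, received, hint):
    """Combine the three parts of a helpful error message into one string."""
    return f"Expected {expected}. Received: {received}. {hint}"


def check_rank(shape, expected_rank):
    """Raise a descriptive ValueError if `shape` does not have `expected_rank` dimensions."""
    if len(shape) != expected_rank:
        raise ValueError(
            build_error_message(
                expected=f"an input of rank {expected_rank}",
                received=f"shape {tuple(shape)} with rank {len(shape)}",
                hint=f"Reshape the input so that it has {expected_rank} dimensions.",
            )
        )
    return tuple(shape)


try:
    check_rank((32, 10), expected_rank=3)
except ValueError as err:
    message = str(err)
```

Here the message names what was expected (rank 3), what was actually passed (shape `(32, 10)`), and a concrete way to fix it.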

Improve tf.function error messages

We have improved two common types of tf.function error messages, runtime error messages and “Graph” tensor error messages, by including tracebacks that point to the source of the error in the user code. We also updated other vague or inaccurate tf.function error messages to be clearer and more accurate.

Consider the runtime error caused by the following user code:

import tensorflow as tf

@tf.function
def f():
  l = tf.range(tf.random.uniform((), minval=1, maxval=10, dtype=tf.int32))
  return l[20]  # Always out of bounds: the range has at most 9 elements.

f()

A summary of the old error message looks like

# … Python stack trace of the function call …

InvalidArgumentError: slice index 20 of dimension 0 out of bounds.
[[node strided_slice (defined at <ipython-input-8-250c76a76c0e>:5) ]] [Op:__inference_f_75]

Errors may have originated from an input operation.
Input Source operations connected to node strided_slice:
range (defined at <ipython-input-8-250c76a76c0e>:4)

Function call stack:

A summary of the new error message looks like

# … Python stack trace of the function call …

InvalidArgumentError: slice index 20 of dimension 0 out of bounds.
[[node strided_slice
(defined at <ipython-input-3-250c76a76c0e>:5)
]] [Op:__inference_f_15]

Errors may have originated from an input operation.
Input Source operations connected to node strided_slice:
In[0] range (defined at <ipython-input-3-250c76a76c0e>:4)
In[1] strided_slice/stack:
In[2] strided_slice/stack_1:
In[3] strided_slice/stack_2:

Operation defined at: (most recent call last)
# … Stack trace of the error within the function …
>>> File "<ipython-input-3-250c76a76c0e>", line 7, in <module>
>>> f()
>>> File "<ipython-input-3-250c76a76c0e>", line 5, in f
>>> return l[20]

The main difference is that runtime errors raised while executing a tf.function now include a stack trace that shows the source of the error in the user’s code.

# … Original error message and information …
# … More stack frames …
>>> File "<ipython-input-3-250c76a76c0e>", line 7, in <module>
>>> f()
>>> File "<ipython-input-3-250c76a76c0e>", line 5, in f
>>> return l[20]

For the “Graph” tensor error messages caused by the following user code

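import tensorflow as tf

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

@tf.function
def captures_leaked_tensor(b):
  b += x  # Uses the symbolic tensor leaked from leaky_function's graph.
  return b

leaky_function(tf.constant(1))
captures_leaked_tensor(tf.constant(2))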


A summary of the old error message looks like

# … Python stack trace of the function call …

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: add:0

A summary of the new error message looks like

# … Python stack trace of the function call …

TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
# … Stack trace of the error within the function …
>>> File <ipython-input-5-95ca3a98778f>, line 6, in leaky_function
# … More stack trace of the error within the function …

Error detected in node 'add' defined at: File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

The main difference is that errors raised when attempting to capture a tensor leaked from an unreachable graph now include a stack trace that shows where the tensor was created in the user’s code:

# … Original error message and information …
# … More stack frames …
>>> File <ipython-input-5-95ca3a98778f>, line 6, in leaky_function

Error detected in node 'add' defined at: File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
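The new message suggests using return values instead of leaking the tensor. A minimal sketch of that fix, assuming TensorFlow 2.x is installed, passes the intermediate tensor out of the first function and into the second explicitly. The function names `non_leaky_function` and `uses_returned_tensor` are hypothetical.

```python
import tensorflow as tf


@tf.function
def non_leaky_function(a):
    # Return the intermediate value instead of assigning it to a global.
    x = a + 1
    return x, a + 2


@tf.function
def uses_returned_tensor(b, x):
    # x arrives as an ordinary argument, so no foreign graph tensor is captured.
    return b + x


x, _ = non_leaky_function(tf.constant(1))
result = uses_returned_tensor(tf.constant(2), x)
```

Because `x` is returned from the first call, it is a concrete eager tensor rather than a symbolic tensor tied to another graph, so the second function can consume it safely.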

Introducing tf.experimental.ExtensionType

User-defined types can make your projects more readable, modular, and maintainable. TensorFlow 2.7.0 introduces the ExtensionType API, which can be used to create user-defined object-oriented types that work seamlessly with TensorFlow’s APIs. Extension types are a great way to track and organize the tensors used by complex models. Extension types can also be used to define new tensor-like types, which specialize or extend the basic concept of “Tensor.” To create an extension type, simply define a Python class with tf.experimental.ExtensionType as its base, and use type annotations to specify the type for each field:

import typing
import tensorflow as tf

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor                      # shape=[num_nodes, num_nodes]
  node_labels: typing.Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor  # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

The ExtensionType base class adds a constructor and special methods based on the field type annotations (similar to typing.NamedTuple and @dataclasses.dataclass from the standard Python library). You can optionally customize the type by overriding these defaults, or adding new methods, properties, or subclasses.
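As a brief sketch of the generated constructor and of adding a method, the example below builds a MaskedTensor and passes it through a function wrapped with @tf.function. It assumes TensorFlow 2.7 or later; the `with_default` method and the `masked_sum` function are illustrative additions, not part of the ExtensionType API itself.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    """A tensor paired with a boolean mask, indicating which values are valid."""
    values: tf.Tensor
    mask: tf.Tensor

    def with_default(self, default):
        # Replace masked-out entries with `default`.
        return tf.where(self.mask, self.values, default)


@tf.function
def masked_sum(mt):
    # Extension types can be passed to tf.function like ordinary tensors.
    return tf.reduce_sum(mt.with_default(0))


# The constructor is generated from the field annotations.
mt = MaskedTensor(
    values=tf.constant([1, 2, 3]),
    mask=tf.constant([True, False, True]),
)
total = masked_sum(mt)
```

Only the first and third values are valid under the mask, so the masked sum adds 1 and 3 and ignores the 2.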

Extension types are supported by the following TensorFlow APIs:

  • Keras: Extension types can be used as inputs and outputs for Keras Models and Layers.
  • Dataset: Extension types can be included in Datasets, and returned by dataset Iterators.
  • TensorFlow Hub: Extension types can be used as inputs and outputs for tf.hub modules.
  • SavedModel: Extension types can be used as inputs and outputs for SavedModel functions.
  • tf.function: Extension types can be used as arguments and return values for functions wrapped with the @tf.function decorator.
  • control flow: Extension types can be used by control flow operations, such as tf.cond and tf.while_loop. This includes control flow operations added by autograph.
  • tf.py_function: Extension types can be used as arguments and return values for the func argument to tf.py_function.
  • Tensor ops: Extension types can be extended to support most TensorFlow ops that accept Tensor inputs (e.g., tf.matmul, tf.gather, and tf.reduce_sum), using dispatch decorators.
  • distribution strategy: Extension types can be used as per-replica values.
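To illustrate the dispatch decorators mentioned in the list above, the sketch below registers one handler so that unary elementwise ops such as tf.abs accept a MaskedTensor. It assumes TensorFlow 2.7 or later and uses tf.experimental.dispatch_for_unary_elementwise_apis; the handler name is hypothetical, and this covers only unary elementwise ops, not every op in the list.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    """A tensor paired with a boolean mask, indicating which values are valid."""
    values: tf.Tensor
    mask: tf.Tensor


@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_unary_elementwise_handler(api_func, x):
    # Apply the op to the underlying values and carry the mask through unchanged.
    return MaskedTensor(api_func(x.values), x.mask)


mt = MaskedTensor(
    values=tf.constant([1, -2, -3]),
    mask=tf.constant([True, False, True]),
)
result = tf.abs(mt)  # Dispatches to the handler and returns a MaskedTensor.
```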

For more information about extension types, see the Extension Type guide.

Note: The tf.experimental prefix indicates that this is a new API, and we would like to collect feedback from real-world usage; barring any unforeseen design issues, we plan to migrate ExtensionType out of the experimental package in accordance with the TF experimental policy.

TF2 Migration made easier!

To support users interested in migrating their workloads from TF1 to TF2, we have created a new Migrate to TF2 tab on the TensorFlow website, which includes updated guides and completely new documentation with concrete, runnable examples in Colab.

A new shim tool has been added that dramatically eases migration of variable_scope-based models to TF2. It is expected to enable most TF1 users to run existing model architectures as-is (or with only minor adjustments) in TF2 pipelines without having to rewrite their modeling code. You can learn more about it in the model mapping guide.

New community contributed models on TensorFlow Hub

Since the last TensorFlow release, the community has come together to make many new models available on TensorFlow Hub. Now you can find models like MLP-Mixer, Vision Transformers, Wav2Vec2, RoBERTa, ConvMixer, DistilBERT, YOLOv5 and many more. All of these models are ready to use via TensorFlow Hub. You can learn more about publishing your models here.

Next steps

Check out the release notes for more information. To stay up to date, you can read the TensorFlow blog, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub or post to the TensorFlow Forum. Thank you!
