Optical character recognition with TensorFlow Lite: A new example app

Posted by Wei Wei, TensorFlow Developer Advocate

As the old adage goes, “a picture is worth a thousand words.” Images are rich in visual information, but sometimes the key information lies in the text within them. While it is easy for literate human beings to read words embedded in images, how do we use computer vision and machine learning to teach computers to do so?

Today, we are going to show you how to use TensorFlow Lite to extract text from images on Android devices. We will walk you through the key steps of the Optical Character Recognition (OCR) Android app that we recently open sourced here, which you can refer to for the complete code. You can see how the app extracts the product names from three Google product logos in the animation below.

Optical Character Recognition demo

The process of recognizing text from images is called Optical Character Recognition and is widely used in many domains. For example, Google Maps uses OCR technology to automatically extract information from the geo-located imagery to improve Google Maps.

Generally speaking, OCR is a pipeline with multiple steps, usually consisting of text detection and text recognition:

  • Use a text detection model to find bounding boxes around text;
  • Do some post-processing to transform the bounding boxes;
  • Transform the images within those bounding boxes into grayscale, so that a text recognition model can map out the words and numbers.

In our case, we are going to leverage the text detection and text recognition models from TensorFlow Hub. There are several different model versions for speed / accuracy tradeoffs; we use the float16 quantized models here. For more information on model quantization, please refer to the TensorFlow Lite quantization section. We also use OpenCV, a widely used computer vision library, for Non-Maximum Suppression (NMS) and perspective transformation (we’ll expand on this later) to post-process detection results. In addition, we use the TFLite Support Library to grayscale and normalize the images.

OCR pipeline from text detection, perspective transformation, to recognition.

For text detection, since the detection model accepts a fixed size of 320×320, we use the TFLite Support Library to resize and normalize the input image:

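val imageProcessor =
  ImageProcessor.Builder()
    .add(ResizeOp(height, width, ResizeOp.ResizeMethod.BILINEAR))
    .add(NormalizeOp(means, stds))
    .build()
var tensorImage = TensorImage(DataType.FLOAT32)

// Load the input image before processing; `bitmapIn` is an assumed name for the input Bitmap.
tensorImage.load(bitmapIn)
tensorImage = imageProcessor.process(tensorImage)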

Then we use TFLite to run the detection model:

detectionInterpreter.runForMultipleInputsOutputs(detectionInputs, detectionOutputs)

The output of the detection model is a set of rotated bounding boxes that contain the text in the image. We then use OpenCV to run Non-Maximum Suppression, which keeps one bounding box for each text block.
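To illustrate the idea behind Non-Maximum Suppression, here is a minimal Python sketch. It is a simplification and not the app's code: it assumes axis-aligned boxes given as `(x_min, y_min, x_max, y_max)`, whereas the app works with rotated boxes through OpenCV, and the function names and threshold values are hypothetical.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max), assumed layout


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    overlap_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    overlap_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    intersection = overlap_w * overlap_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - intersection
    return intersection / union if union > 0 else 0.0


def non_max_suppression(
    boxes: Sequence[Box],
    scores: Sequence[float],
    score_threshold: float = 0.5,  # hypothetical default
    iou_threshold: float = 0.4,    # hypothetical default
) -> List[int]:
    """Return the indices of the boxes to keep, highest score first."""
    # Drop low-confidence boxes, then visit the rest from most to least confident.
    candidates = sorted(
        (i for i, score in enumerate(scores) if score >= score_threshold),
        key=lambda i: scores[i],
        reverse=True,
    )
    kept: List[int] = []
    for i in candidates:
        # Keep a box only if it does not overlap too much with any box already kept.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in kept):
            kept.append(i)
    return kept


# Two heavily overlapping detections of the same word, plus one separate word.
boxes = [(10, 10, 110, 40), (12, 12, 112, 42), (200, 50, 300, 80)]
scores = [0.9, 0.8, 0.7]
print(non_max_suppression(boxes, scores))  # prints [0, 2]
```

The greedy loop is quadratic in the number of candidate boxes, which is usually acceptable for the handful of text regions found in a single image.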


Sometimes text inside images is distorted by a perspective angle (e.g., the ‘kubernetes’ sticker on my laptop):

Perspective transformation demo

If we just feed the raw rotated bounding box into the recognition model, the model is unlikely to correctly identify the characters. In this case, we need to use OpenCV to do perspective transformation:

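val rotationMatrix = getPerspectiveTransform(srcPtsMat, targetPtsMat)
// srcBitmapMat and recognitionBitmapMat are assumed names for the source and output Mats.
warpPerspective(srcBitmapMat, recognitionBitmapMat, rotationMatrix,
  Size(recognitionImageWidth.toDouble(), recognitionImageHeight.toDouble()))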
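For readers curious about what a perspective transform computes, the following is a minimal pure-Python sketch, not the OpenCV implementation. It assumes four source corners and four target corners, solves the resulting 8×8 linear system for a 3×3 homography, and maps points through it. The function names, corner coordinates, and target size are hypothetical.

```python
def solve_linear_system(a, b):
    """Solve a * x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [list(row) + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("degenerate point configuration")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        known = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - known) / m[r][r]
    return x


def perspective_transform(src, dst):
    """Return the 3x3 homography mapping four src points onto four dst points."""
    rows, rhs = [], []
    for (x, y), (u, v) in zip(src, dst):
        # u = (h11*x + h12*y + h13) / (h31*x + h32*y + 1), and likewise for v.
        rows.append([x, y, 1, 0, 0, 0, -u * x, -u * y])
        rhs.append(u)
        rows.append([0, 0, 0, x, y, 1, -v * x, -v * y])
        rhs.append(v)
    h = solve_linear_system(rows, rhs)
    return [h[0:3], h[3:6], h[6:8] + [1.0]]


def warp_point(h, x, y):
    """Apply the homography to a single point."""
    w = h[2][0] * x + h[2][1] * y + h[2][2]
    return (
        (h[0][0] * x + h[0][1] * y + h[0][2]) / w,
        (h[1][0] * x + h[1][1] * y + h[1][2]) / w,
    )


# Hypothetical corners of a skewed text region, mapped onto an upright rectangle.
src = [(30, 20), (210, 40), (200, 110), (20, 90)]
dst = [(0, 0), (200, 0), (200, 31), (0, 31)]
matrix = perspective_transform(src, dst)
```

Warping a whole image amounts to applying the inverse of this mapping to every output pixel and sampling the source image there, which is the part OpenCV handles in the app.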

After that, we use the TFLite Support Library again to resize, grayscale, and normalize the transformed images inside the bounding boxes:

val imageProcessor =
  ImageProcessor.Builder()
    .add(ResizeOp(height, width, ResizeOp.ResizeMethod.BILINEAR))
    .add(TransformToGrayscaleOp())
    .add(NormalizeOp(mean, std))
    .build()

Finally, we run the text recognition model, map out the characters and numbers from the model output, and update the app UI:

recognitionInterpreter.run(recognitionTensorImage.buffer, recognitionResult)

var recognizedText = ""
for (k in 0 until recognitionModelOutputSize) {
  // Each output element occupies 8 bytes; skip indices outside the alphabet.
  val alphabetIndex = recognitionResult.getInt(k * 8)
  if (alphabetIndex in 0..alphabets.length - 1)
    recognizedText = recognizedText + alphabets[alphabetIndex]
}
Log.d("Recognition result:", recognizedText)
if (recognizedText != "") {
  ocrResults.put(recognizedText, getRandomColor())
}
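The decoding loop above can be sketched in Python as follows. This is an illustration under stated assumptions rather than the app's code: it assumes the output buffer holds one little-endian 64-bit integer per character position (suggested by the 8-byte stride in the loop), and the alphabet string and function name are hypothetical.

```python
import struct

# Assumed character set; the real alphabet is defined by the recognition model.
ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz"


def decode_indices(buffer: bytes, alphabet: str = ALPHABET) -> str:
    """Map int64 class indices to characters, skipping out-of-range indices."""
    count = len(buffer) // 8
    indices = struct.unpack(f"<{count}q", buffer[: count * 8])
    return "".join(alphabet[i] for i in indices if 0 <= i < len(alphabet))


# Indices for "tflite", with two out-of-range values (-1 and 36) mixed in.
raw_output = struct.pack("<8q", 29, 15, 21, 18, -1, 29, 14, 36)
print(decode_indices(raw_output))  # prints tflite
```

Out-of-range indices act as "no character" markers here, which is how the Kotlin loop treats them as well.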

That’s it. We are now able to extract text from input images using TFLite within our app.

Finally, if you just want a ready-to-use OCR SDK, Google also offers on-device OCR functionality through ML Kit, which uses TFLite underneath and should be sufficient for most OCR use cases. There are some situations where you may want to build your own OCR solution with TFLite such as:

  • You have your own text detection / recognition TFLite models that you would like to use;
  • You have special business requirements (e.g. recognizing upside-down text) and need to customize the OCR pipeline;
  • You want to support languages not covered by ML Kit;
  • Your target user devices don’t necessarily have Google Play services installed;
  • You want to have control over hardware backends (CPU / GPU / etc.) used to run your models.

In these cases, I hope that this tutorial and our example implementation can help you get started on building your own OCR functionality in your app.

The author would like to thank Tian Lin for the helpful feedback and community contributors @Tulasi123789 and @risingsayak for their prior work on OCR using TFLite (creating and uploading the models to TF Hub, providing accompanying notebooks, etc.).

TensorFlow Hub’s Experience with Google Summer of Code 2021

Posted by Sayak Paul (MLE at Carted, and GDE) and Morgan Roff (Google)

header with GSOC and TFHub logos

We’re happy to share the work completed by Google Summer of Code students working with TensorFlow Hub this year. If you’re a student who is interested in writing open source code, then you’ll likely be interested in Google’s Summer of Code program.

Through this program, students propose project ideas to open source organizations, and if selected, receive a stipend to work with them to complete their projects over the summer. Students have the opportunity to learn directly from mentors within their selected organization, and organizations benefit from the students’ contributions. This year, 17 students successfully completed projects with the TensorFlow organization. In this article, we’ll focus on some of the work completed on TensorFlow Hub.

We’re Sayak and Morgan, two mentors for projects on TensorFlow Hub (TF Hub). Here we share what the students learned about building and publishing state-of-the-art models and training them on large-scale benchmark datasets, what we learned as mentors, and how rewarding Summer of Code was for each of us and for the community.

We had the opportunity to mentor two students – Aditya Kane and Vasudev Gupta. Aditya successfully implemented several variants of RegNets including one based on this paper, and trained them on the ImageNet-1k dataset. Vasudev ported the pre-trained wav2vec2 weights from this paper to TensorFlow, which required him to implement the model architecture from scratch. He then demonstrated fine-tuning these pre-trained checkpoints on the LibriSpeech dataset, making his work more customizable and relevant for the community.

With model training happening at such a large scale, it becomes especially important to follow good engineering practices during the implementation. These include code modularization, unit tests, good design patterns, optimizations, and so on. Models were trained on Cloud TPUs to accelerate training time, and as such, substantial effort was put into the data input pipelines to ensure maximum accelerator utilization.

All of these factors collectively contributed to the complexity of the projects. Thanks to the Summer of Code program, students have the opportunity to tackle these challenges with the help of experienced mentors. This also enables students to gain insight into their organizations, and interact with people with many skillsets who cooperate to make large projects possible. A big thank you here to our students, who gracefully handled this engineering work and listened to our feedback.

Vasudev and Aditya contributed significant pre-trained models to TensorFlow Hub, along with tutorials (Wav2Vec, RegNetY) on their use, and TensorFlow implementations for folks who want to dig deeper. In their own words:

The last 2-3 months were full of lots of learning and coding. GSoC helped me get into the speech domain and motivated me to explore more about the TensorFlow ecosystem. I am thankful to my mentors for their continuous & timely feedback. I am looking forward to contributing more to the TensorFlow community and other awesome open source projects out there. – Vasudev Gupta

More about RegNets and Wav2Vec2

Almost 6 years after they were first published, ResNets are still widely used as benchmark architectures across image understanding tasks. Many recent self-supervised and semi-supervised learning frameworks still leverage ResNet50 as their backbone architectures. However, ResNets often do not scale well under larger data regimes and suffer from large training and inference time latencies as they grow. In contrast, RegNets were developed specifically to be a scalable architecture framework that maintains low latency while demonstrating high performance on standard image recognition tasks. Aditya’s models are published on TF Hub, with code and tutorials on GitHub.

Self-supervised learning is an important area of machine learning research. Many recent success stories have been focused on NLP and Computer Vision, and for Vasudev’s project, we wanted to explore speech. Last year, a group of researchers released the wav2vec2 framework for learning representations from audio in a self-supervised manner, benefiting downstream tasks like speech-to-text.

Using wav2vec2, you can now pre-train speech models without labeled data, and fine-tune those models on downstream tasks like speaker recognition. Vasudev’s models are available on TF Hub, along with a new tutorial on fine-tuning, and code on GitHub.

Wrapping up

We’d like to say a heartfelt thank you to all the students, mentors, and organizers who made Summer of Code a success despite this year’s many challenges. We encourage you to check out these models and share what you have built with us by tagging #TFHub on your social media posts, or share your work for the community spotlight program. If you have questions or want to learn more about these new models, you can ask them on discuss.tensorflow.org.

How Waze Uses TFX to Scale Production-Ready ML

Posted by Gal Moran, Iris Shmuel, and Daniel Marcous (Data Scientists at Waze)


Waze is the world’s largest community-based traffic and navigation app. It uses real-time data to help users circumvent literal and figurative bumps in the road. On top of mobile navigation, Waze offers a web platform, a carpool app, partnership services, an advertisement platform and more. Such a broad portfolio brings along diverse technological challenges and many different use cases.

GIF of Waze logo

ML @Waze

Waze relies on many ML solutions, including:

  • Predicting ETA
  • Matching Riders & Drivers (Carpool)
  • Serving The Right Ads

But it’s not that easy to get things like these right and “production grade”. It is very common for these kinds of projects to require complex surrounding infrastructure to get them to production, and hence multiple engineers (data scientists, software engineers and software reliability engineers) and a lot of time. Even more so when you mix in Waze-y requirements like large-scale data, low-latency (real-time, actually) inference, diverse use cases, and a whole lot of geospatial data.

The above is a good reason why opportunistically starting to do ML created a chaotic state at Waze. For us it manifested as:

  • Multiple ML frameworks – you name it (sklearn, xgboost, TensorFlow, fbprophet, Java PMML, hand made etc.)
  • ML & Ops disconnect – models & feature engineering embedded in (Java) backend servers by engineers with limited monitoring and validation capabilities
  • Semi-manual operations for training, validation and deployment
  • A hideously long development cycle from idea to production

Overall, data scientists ended up spending a lot of their time on ops and monitoring instead of focusing on the actual modelling and data processing. At a certain level of growth we decided to organize the chaos and invest in automation and processes so we could scale faster. We chose to invest heavily in dramatically increasing velocity and quality by adopting a full cycle data science philosophy. This means that, in the new world we wanted to build, a single data scientist is able to close the product cycle from research to a production-grade service.

Data scientists now contribute directly to production to maximize impact. They focus on modelling and data processing and get much of the infrastructure and ops work out-of-the-box. While we have not yet reached the end of this journey and fully realized the above vision, we feel that the effort laid out here was crucial in putting us on the right track.

Waze’s ML Stack

Translating the above philosophy to a tech spec, we were set on creating an easy, stable, automated and uniform way of building ML pipelines at Waze.

Deep diving into tech requirements we came up with the below criteria:

  • Simple — to understand, use, operate
  • Managed — no servers, no hardware, just code
  • Customizable — get the simple stuff for free, yet flexible enough to go crazy for the 5% that would require going outside the lines
  • Scalable — auto scalable data processing, training, inference
  • Pythonic — we need something production-ready, that works with most tools and code today and fits the standard data scientist. There are practically no other options than Python these days.

For the above reasons we’ve landed on TFX and the power of its built-in components to deliver these capabilities mostly out of the box.

It’s worth saying – Waze runs its tech stack on Google Cloud Platform (GCP).

As it happens, GCP offers a suite of tools called Vertex AI, which is the ML infrastructure platform Waze is building on top of. While we use many components of Vertex AI’s managed services, we will focus here on Vertex Pipelines, a framework for ML pipelines that helps us encapsulate TFX (or any pipeline) complexity and setup.

Together with our data tech stack, the overall ML architecture at Waze (all managed, scaled, pythonic etc.) is as follows:

graph of ML architecture at Waze

Careful readers will notice the alleged caveat here – we go all in on TensorFlow.

TFX means TensorFlow (even though that’s not exactly true anymore, let’s assume it is).

It might be a little scary at first when you have many different use cases.

Fortunately, the TF ecosystem is rich and Waze has the merit of having large enough data that neural nets converge.

Since starting this, we’ve yet to find a use case that TF magic does not solve as well as or better than other frameworks (and we’re not talking about micro percentage points; we’re not trying to win a Kaggle competition here, but to get something to production).

Waze TFX

You might think that landing on TFX and Vertex Pipelines solved all our problems, but that’s not exactly true.

In order to make things truly simple, we’ve had to write some “glue code” (integrating the various products in the above architecture diagram) and abstract away enough details so that the common data scientist could use this stuff effectively and fast.

That resulted in:

  • Eliminating boilerplate
  • Hiding all common TFX components so data scientists focus only on feature engineering and modelling and get the entire pipeline for free
  • Generating a BigQuery-based train / eval split
  • Providing pre-implemented, optional common feature transforms (e.g. scaling, normalization, imputation)
  • Providing pre-implemented Keras models (e.g. DNN/RNN models, similar to TF Estimators but written in Keras and compatible with TFX)
  • Providing utility functions (e.g. TF columns preparation)
  • Providing a unit testing framework for tf.transform feature engineering code
  • Orchestrating and scheduling pipeline runs from Airflow using a Cloud Run instance with all TFX packages installed (without installing them on the Airflow Composer environment)

We’ve put it all in an easy-to-use Python package called “waze-data-tfx”.

Pyramid chart showing levels of Waze data tfx

On top of that, we provided a super detailed walkthrough, usage guides and code templates to our data scientists, so the common DS workflow is: fork, change config, tweak the code a little, deploy.

For reference, this is what a simple waze-data-tfx pipeline looks like:

  1. Configuration
    _DATASET_NAME = 'tfx_examples'
    _TABLE_NAME = 'simple_template_data'

    _LABEL_KEY = 'label'
    "categorical_calculated": 2,
    _DENSE_FLOAT_FEATURE_KEYS = ["numeric_feature1", "numeric_feature2"]
    "numeric_feature1": 5,
    "categorical_feature": {
    'top_k': 5,
    'num_oov_buckets': 3

    _EVAL_BATCH_SIZE = 128
    _NUM_EPOCHS = 250

    'dnn_hidden_units': [6, 3],
    'optimizer': tf.keras.optimizers.Adam,
    'optimizer_kwargs': {
    'learning_rate': 0.01
    'layer_activation': None,
    'metrics': ["Accuracy"]

    _EVAL_METRIC_SPEC = create_metric_spec([
    mse_metric(upper_bound=25, absolute_change=1),
  2. Feature Engineering
    def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.

    inputs: map from feature keys to raw not-yet-transformed features.

    Map from string feature key to transformed feature operations.
    outputs = features_transform(
    return outputs
  3. Modelling
    def _build_keras_model(**training_args):
    """Build a keras model.

    hidden_units: [int], the layer sizes of the DNN (input layer first).
    learning_rate: [float], learning rate of the Adam optimizer.

    A keras model
    feature_columns =

    return _dnn_regressor(deep_columns=list(feature_columns.values()),
  4. Orchestration
    pipeline_run = WazeTFXPipelineOperator(

Simple, right?

When you commit a configuration file to the code base, it gets deployed and sets up continuous training and a full-blown pipeline, including all the TFX and Vertex AI magic like data validation, transforms deployed to Dataflow, monitoring, etc.
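To make configuration entries such as `top_k`, `num_oov_buckets`, and a per-feature bucket count more concrete, here is a minimal pure-Python sketch of what transforms of this kind typically do. It is an assumption about their meaning, loosely modelled on the vocabulary and bucketization transforms in tf.transform, and not the internals of waze-data-tfx; all function names and sample values are hypothetical.

```python
import zlib
from collections import Counter


def build_vocabulary(values, top_k):
    """Keep the top_k most frequent values (ties broken alphabetically)."""
    counts = Counter(values)
    ranked = sorted(counts, key=lambda value: (-counts[value], value))
    return {value: index for index, value in enumerate(ranked[:top_k])}


def apply_vocabulary(value, vocabulary, num_oov_buckets):
    """Return the vocabulary index, or an out-of-vocabulary bucket placed after it."""
    if value in vocabulary:
        return vocabulary[value]
    # Hash unseen values into a small fixed number of extra buckets.
    bucket = zlib.crc32(value.encode("utf-8")) % num_oov_buckets
    return len(vocabulary) + bucket


def bucket_boundaries(values, num_buckets):
    """Approximate quantile boundaries splitting values into num_buckets groups."""
    ordered = sorted(values)
    return [ordered[len(ordered) * i // num_buckets] for i in range(1, num_buckets)]


def bucketize(value, boundaries):
    """Return the index of the bucket that value falls into."""
    return sum(value >= boundary for boundary in boundaries)


# Hypothetical categorical feature: keep the 2 most common values, 3 OOV buckets.
cities = ["tlv", "tlv", "nyc", "nyc", "nyc", "sf"]
vocabulary = build_vocabulary(cities, top_k=2)
print(vocabulary)  # prints {'nyc': 0, 'tlv': 1}

# Hypothetical numeric feature split into 5 buckets.
boundaries = bucket_boundaries(list(range(10)), num_buckets=5)
print(boundaries, bucketize(5, boundaries))  # prints [2, 4, 6, 8] 2
```

Declaring these choices as configuration means the same transform code can serve many pipelines, with only the dictionaries changing between use cases.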


We knew we were onto something good when one of our data scientists came back from a long leave and had to use this new framework for a use case. She said that she was able to spin up a full production-ready pipeline in hours, something that before her leave would have taken her weeks to do.

Going forward, we have much planned that we want to bake into `waze-data-tfx`. A key advantage of having this common infrastructure is that once a feature is added, everyone can enjoy it “for free”. For example, we plan on adding additional components to the pipeline, such as Infra Validator and Fairness Indicators. Once these are supported, every new or existing ML pipeline will get these components out-of-the-box, with no extra code needed.

Additional improvements we are planning are around deployment. We wish to provide deployment quality assurance while automating as much as possible.

One way we are currently exploring is canary deployments. A data scientist will simply need to configure an evaluation metric, and the framework (using Vertex Prediction traffic splitting capabilities and other continuous evaluation magic) will test the new model in production and gradually deploy it or roll it back according to the evaluated metrics.
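The decision logic behind such a canary rollout could look roughly like the sketch below. It is purely illustrative: the function name, thresholds, and step size are hypothetical, it assumes a metric where higher is better, and it does not call any Vertex AI API.

```python
def next_traffic_split(current_percent, canary_metric, baseline_metric,
                       max_regression=0.01, step=10):
    """Return the canary's next traffic share in percent (0 means roll back)."""
    # Roll back as soon as the canary is clearly worse than the current model.
    if canary_metric < baseline_metric - max_regression:
        return 0
    # Otherwise shift a bit more traffic to the canary, capped at 100%.
    return min(100, current_percent + step)


# Simulated rollout: the canary keeps matching or beating the baseline.
share = 10
history = [share]
while 0 < share < 100:
    share = next_traffic_split(share, canary_metric=0.92, baseline_metric=0.91)
    history.append(share)
print(history)  # prints [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
```

In a real setup, the metrics would come from continuous evaluation on live traffic, and each step would wait until enough requests had been observed.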

Introducing TensorFlow Similarity

Posted by Elie Bursztein and Owen Vallis, Google

Today we are releasing the first version of TensorFlow Similarity, a Python package designed to make it easy and fast to train similarity models using TensorFlow.

Examples of nearest neighbor searches performed on the embeddings generated by a similarity model trained on the Oxford IIIT Pet Dataset

The ability to search for related items has many real world applications, from finding similar looking clothes, to identifying the song that is currently playing, to helping rescue missing pets. More generally, being able to quickly retrieve related items is a vital part of many core information systems such as multimedia searches, recommender systems, and clustering pipelines.

Similarity models learn to output embeddings that project items in a metric space where similar items are close together and far from dissimilar ones

Under the hood, many of these systems are powered by deep learning models that are trained using contrastive learning. Contrastive learning teaches the model to learn an embedding space in which similar examples are close while dissimilar ones are far apart, e.g., images belonging to the same class are pulled together, while distinct classes are pushed apart from each other. In our example, all the images from the same animal breed are pulled together while different breeds are pushed apart from each other.

Oxford-IIIT Pet dataset visualization using the TensorFlow Similarity projector

When applied to an entire dataset, contrastive losses allow a model to learn how to project items into the embedding space such that the distances between embeddings are representative of how similar the input examples are. At the end of training you end up with a well clustered space where the distance between similar items is small and the distance between dissimilar items is large. For example, as visible above, training a similarity model on the Oxford-IIIT Pet dataset leads to meaningful clusters where similar looking breeds are close-by and cats and dogs are clearly separated.
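As a rough illustration of how a contrastive objective pulls similar items together and pushes dissimilar ones apart, here is the classic pairwise contrastive loss in pure Python. It is a simpler loss than the ones shipped with TensorFlow Similarity, and the function names, margin, and sample embeddings are hypothetical.

```python
import math


def euclidean(a, b):
    """Euclidean distance between two embeddings."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def contrastive_loss(a, b, same_class, margin=1.0):
    """Pairwise contrastive loss for one pair of embeddings."""
    distance = euclidean(a, b)
    if same_class:
        # Similar pairs are penalized for being far apart.
        return distance ** 2
    # Dissimilar pairs are penalized only when closer than the margin.
    return max(0.0, margin - distance) ** 2


cat_a, cat_b, dog = (0.0, 0.0), (0.1, 0.0), (3.0, 4.0)
print(contrastive_loss(cat_a, cat_b, same_class=True))   # small: already close
print(contrastive_loss(cat_a, dog, same_class=False))    # 0.0: already far apart
print(contrastive_loss(cat_a, cat_b, same_class=False))  # large: too close together
```

Minimizing this quantity over many pairs is what produces the clustered embedding space described above.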

Finding related items involves computing the query image embedding, performing an ANN search to find similar items, and fetching the similar items’ metadata, including the image bytes.

Once the model is trained, we build an index that contains the embeddings of the various items we want to make searchable. Then at query time, TensorFlow Similarity leverages Fast Approximate Nearest Neighbor search (ANN) to retrieve the closest matching items from the index in sub-linear time. This fast look up leverages the fact that TensorFlow Similarity learns a metric embedding space where the distance between embedded points is a function of a valid distance metric. These distance metrics satisfy the triangle inequality, making the space amenable to Approximate Nearest Neighbor search and leading to high retrieval accuracy.

Other approaches, such as using model feature extraction, require the use of an exact nearest neighbor search to find related items and may not be as accurate as a trained similarity model. This prevents those methods from scaling, as an exact search must compare each query against every item in the search index. In contrast, TensorFlow Similarity’s built-in Approximate Nearest Neighbor indexing system, which relies on NMSLIB, makes it possible to search over millions of indexed items, retrieving the top-K similar matches within a fraction of a second.

Besides accuracy and retrieval speed, the other major advantage of similarity models is that they allow you to add an unlimited number of new classes to the index without having to retrain. Instead you only need to compute the embeddings for representative items of the new classes and add them to the index. This ability to dynamically add new classes is particularly useful when tackling problems where the number of distinct items is unknown ahead of time, constantly changing, or is extremely large. An example of this would be enabling users to discover newly released music that is similar to songs they have liked in the past.
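The idea of indexing embeddings and adding new classes without retraining can be sketched with a tiny brute-force index in pure Python. This is an exact search for illustration only, not the approximate NMSLIB-backed index used by TensorFlow Similarity; the class name, the embeddings, and the labels are hypothetical, and embeddings are assumed to be non-zero vectors.

```python
import math


def _normalize(vector):
    """Scale a non-zero vector to unit length."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / norm for v in vector]


class BruteForceIndex:
    """Exact nearest-neighbor index using cosine distance."""

    def __init__(self):
        self._embeddings = []
        self._labels = []

    def add(self, embedding, label):
        """Make one embedded item searchable; no model retraining involved."""
        self._embeddings.append(_normalize(embedding))
        self._labels.append(label)

    def lookup(self, query, k=1):
        """Return the k closest items as (cosine_distance, label) pairs."""
        q = _normalize(query)
        scored = [
            (1.0 - sum(a * b for a, b in zip(q, e)), label)
            for e, label in zip(self._embeddings, self._labels)
        ]
        return sorted(scored)[:k]


index = BruteForceIndex()
index.add([1.0, 0.1], "cat")
index.add([0.1, 1.0], "dog")
print(index.lookup([0.9, 0.2])[0][1])   # prints cat

# A new class appears later: embed a representative item and add it to the index.
index.add([-1.0, 0.1], "rabbit")
print(index.lookup([-0.9, 0.2])[0][1])  # prints rabbit
```

Every lookup here scans all indexed items, which is exactly the cost an approximate nearest neighbor index is designed to avoid.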

TensorFlow Similarity provides all the necessary components to make similarity training, evaluation, and querying intuitive and easy. In particular, as illustrated below, TensorFlow Similarity introduces the SimilarityModel(), a new Keras model that natively supports embedding indexing and querying. This allows you to perform end-to-end training and evaluation quickly and efficiently.

A minimal example that trains, indexes and searches on MNIST data can be written in fewer than 20 lines of code:

from tensorflow.keras import layers

# Embedding output layer with L2 norm
from tensorflow_similarity.layers import MetricEmbedding
# Specialized metric loss
from tensorflow_similarity.losses import MultiSimilarityLoss
# Sub classed keras Model with support for indexing
from tensorflow_similarity.models import SimilarityModel
# Data sampler that pulls datasets directly from tf dataset catalog
from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler
# Nearest neighbor visualizer
from tensorflow_similarity.visualization import viz_neigbors_imgs

# Data sampler that generates balanced batches from MNIST dataset
sampler = TFDatasetMultiShotMemorySampler(dataset_name='mnist', classes_per_batch=10)

# Build a Similarity model using standard Keras layers
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1/255)(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
outputs = MetricEmbedding(64)(x)

# Build a specialized Similarity model
model = SimilarityModel(inputs, outputs)

# Train Similarity model using contrastive loss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)

# Index 100 embedded MNIST examples to make them searchable
sx, sy = sampler.get_slice(0,100)
model.index(x=sx, y=sy, data=sx)

# Find the top 5 most similar indexed MNIST examples for a given example
qx, qy = sampler.get_slice(3713, 1)
nns = model.single_lookup(qx[0])

# Visualize the query example and its top 5 neighbors
viz_neigbors_imgs(qx[0], qy[0], nns)

Even though the code snippet above uses a sub-optimal model, it still yields good matching results where the nearest neighbors clearly look like the queried digit, as visible in the screenshot below:

Code example showing number 5

This initial release focuses on providing all the necessary components to help you build contrastive learning based similarity models, such as losses, indexing, batch samplers, metrics, and tutorials. TF Similarity also makes it easy to work with the Keras APIs and use the existing Keras architectures. Moving forward, we plan to build on this solid foundation to support semi-supervised and self-supervised methods such as BYOL, SwAV, and SimCLR.

You can start experimenting with TF Similarity right away by heading to the Hello World tutorial. For more information, you can check out the project’s GitHub repository.

Faster Quantized Inference with XNNPACK

Posted by Marat Dukhan and Frank Barchard, software engineers

Quantization is among the most popular methods to speed up neural network inference on CPUs. A year ago, TensorFlow Lite increased performance for floating-point models with the integration of the XNNPACK backend. Today, we are extending the XNNPACK backend to quantized models, with average speedups across computer vision models of 30% on ARM64 mobile phones, 5X on x86-64 laptop and desktop systems, and 20X for in-browser inference with WebAssembly SIMD, compared to the default TensorFlow Lite quantized kernels.

Quantized inference in XNNPACK is optimized for symmetric quantization schemas used by the TensorFlow Model Optimization Toolkit. XNNPACK supports both the traditional per-tensor quantization schema and the newer accuracy-optimized schema with per-channel quantization of weights and per-tensor quantization of activations. Additionally, XNNPACK supports the asymmetric quantization schema, albeit with reduced efficiency.
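To show why per-channel quantization of weights tends to be more accurate than per-tensor quantization, here is a small pure-Python sketch of symmetric 8-bit quantization. It illustrates the general technique rather than XNNPACK's or the Model Optimization Toolkit's code; the function names and weight values are hypothetical.

```python
def symmetric_scale(values, qmax=127):
    """Scale mapping the largest magnitude onto qmax, with the zero point fixed at 0."""
    largest = max(abs(v) for v in values)
    return largest / qmax if largest > 0 else 1.0


def quantize(values, scale, qmax=127):
    """Round to the nearest integer step and clamp to [-qmax, qmax]."""
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


def round_trip_errors(weights, per_channel):
    """Worst quantize/dequantize error for each output channel."""
    shared_scale = symmetric_scale([w for channel in weights for w in channel])
    errors = []
    for channel in weights:
        scale = symmetric_scale(channel) if per_channel else shared_scale
        restored = [q * scale for q in quantize(channel, scale)]
        errors.append(max(abs(w - r) for w, r in zip(channel, restored)))
    return errors


# Two hypothetical output channels with very different weight magnitudes.
weights = [[0.5, -1.0, 0.25], [0.004, -0.002, 0.001]]
per_tensor = round_trip_errors(weights, per_channel=False)
per_channel = round_trip_errors(weights, per_channel=True)
print(per_tensor[1] > per_channel[1])  # prints True
```

With a single shared scale, the small-magnitude channel collapses onto just a couple of integer levels, while a per-channel scale lets it use the full [-127, 127] range.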

Performance improvements

We evaluated XNNPACK-accelerated quantized inference on a number of edge devices and neural network architectures. Below, we present benchmarks on four public and two internal quantized models covering common computer vision tasks:

  1. EfficientNet-Lite0 image classification [download]
  2. EfficientDet-Lite0 object detection [download]
  3. DeepLab v3 segmentation with MobileNet v2 feature extractor [download]
  4. CartoonGAN image style transfer [download]
  5. Quantized version of the Face Mesh landmarks
  6. Quantized version of the Video Segmentation

Speedup from XNNPACK on single-threaded inference of quantized computer vision models on Android/ARM64 mobile phones.

Across the six Android ARM64 mobile devices XNNPACK delivers, on average, 30% speedup over the default TensorFlow Lite quantized kernels.

Speedup from XNNPACK on single-threaded inference of quantized computer vision models on x86-64 laptop and desktop systems.

XNNPACK offers even greater improvements on laptop and desktop systems with x86 processors. On the 5 x86 processors in our benchmarks XNNPACK accelerated inference on average by 5 times. Notably, low-end and older processors which don’t support AVX instructions see over 20X speedup from switching quantized inference to XNNPACK: while the previous TensorFlow Lite inference backend had optimized implementations only for AVX, AVX2, and AVX512 instruction sets, XNNPACK provides optimized implementations for all x86-64 processors.

Speedup from XNNPACK on single-threaded WebAssembly SIMD inference of quantized computer vision models on mobile phones, laptops, and desktops when running through V8.

Besides the traditional mobile and laptop/desktop platforms, XNNPACK brings accelerated quantized inference to the Web platform through the TensorFlow Lite Web API. The above plot demonstrates a geomean speedup of 20X over the default TensorFlow Lite implementation when running WebAssembly SIMD benchmarks through the V8 JavaScript engine on 3 x86-64 and 2 ARM64 systems.

Two years of optimizations

XNNPACK started its life as a fork of the QNNPACK library, but as the first version of XNNPACK focused on floating-point inference and QNNPACK focused on quantized inference, it was not possible to compare the two. Now, with XNNPACK introducing support for quantized inference, we can directly evaluate and attribute the two further years of performance optimizations.

Graph showing XNNPACK speedup over QNNPACK

To compare the two quantized inference backends, we ported randomized MobileNet v1 and MobileNet v2 models from XNNPACK API to QNNPACK API, and benchmarked their single-threaded performance on two ARM64 Android phones and two x86-64 systems. The results are presented in the plot above, and the progress made by XNNPACK in two years is striking. XNNPACK is 50% faster on the older Pixel 3a phone and 4-5X faster on the newer Pixel 4a phone, 2.5X faster on the x86-64 laptop, and over 3X faster on the x86-64 workstation. These improvements are the result of multiple optimizations XNNPACK gained in the two years since it forked from QNNPACK:

  • XNNPACK retained the optimizations in QNNPACK, like the Indirect Convolution algorithm and microarchitecture-specific microkernel selection, and further augmented them with the Indirect Deconvolution algorithm and more flexible capabilities, like built-in numpy-like broadcasting in the quantized addition and quantized multiplication operators.
  • Convolution, Deconvolution, and Fully Connected operators accumulate products of 8-bit activations and weights into a 32-bit number, and in the end this number needs to be converted back, or requantized, to an 8-bit number. There are multiple ways to implement requantization, but QNNPACK adapted the schema from the GEMMLOWP library, which pioneered quantized computations for neural network inference. However, it has since been discovered that the GEMMLOWP requantization schema is suboptimal in terms of both accuracy and performance, and XNNPACK replaced it with more performant and accurate alternatives.
  • Whereas QNNPACK targeted asymmetric quantization schema, where both activations and weights are represented as unsigned integers with zero point and scale quantization parameters, XNNPACK’s optimizations focus on symmetric quantization, where both activations and weights are signed integers, and weights have additional restrictions: the zero point of the weights is always zero and the quantized weights elements are limited to the [-127, 127] range (-128 is excluded even though it can be represented as a signed 8-bit integer). Symmetric quantization offers two computational advantages exploited in XNNPACK. First, when the filter weights are static, the results of accumulating the product of input zero point by the filter weights can be completely fused into the bias term in the Convolution, Deconvolution, and Fully Connected operators. Thus, zero point parameters are completely absent from the inference computations. Secondly, the product of a signed 8-bit input element by the weight element restricted to [-127, 127] fits into 15 bits. This enables the microkernels for Convolution, Deconvolution, and Fully Connected operators to do half of the accumulations on 16-bit variables rather than always extending the products to 32 bits.
  • QNNPACK microkernels were optimized for NEON SIMD instructions on ARM and SSE2 SIMD instructions on x86, but XNNPACK supports a much wider set of instruction set-specific optimizations. Most quantized microkernels in XNNPACK are optimized for SSE2, SSE4.1, AVX, XOP, AVX2, and AVX512 instructions on x86/x86-64, for NEON, NEON V8, and NEON dot product instructions on ARM/ARM64, and for WebAssembly SIMD instructions. Additionally, XNNPACK provides scalar implementations for WebAssembly 1.0 and pre-NEON ARM processors.
  • QNNPACK introduced the idea of specialized assembly microkernels for high-end ARM and low-end ARM cores, but XNNPACK takes this idea much further. XNNPACK not only includes specialized expert-tuned software pipelined assembly microkernels for Cortex-A53, Cortex-A55, and high-end cores with and without NEON dot product instructions, but even supports switching between them on the fly. When a thread doing inference migrates from a big to a little core, XNNPACK automatically adapts from using a microkernel optimized for the big core to the one optimized for the little core.
  • QNNPACK mainly focused on multi-threaded inference and organized computations as a large number of small tasks, each computing a tiny tile of the output tensor. XNNPACK reworked parallelization and made the tasks flexible: they can be fine-grained or coarse-grained depending on the number of threads participating in the parallelization. Through dynamic adjustment of task granularity, XNNPACK achieves low overhead in single-threaded execution and high parallelization efficiency for multi-threaded inference.
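
The two advantages of symmetric quantization described in the list above can be checked with a little arithmetic. The sketch below is plain Python written for illustration only; it is not XNNPACK code, and the function names, sample values, and quantization parameters are all made up for the example.

```python
def dot_with_zero_point(inputs, weights, bias, input_zero_point):
    """Reference result: subtract the input zero point before every multiply."""
    acc = bias
    for x, w in zip(inputs, weights):
        acc += (x - input_zero_point) * w
    return acc


def fuse_zero_point_into_bias(weights, bias, input_zero_point):
    """Precompute bias - zero_point * sum(weights); valid when weights are static."""
    return bias - input_zero_point * sum(weights)


def dot_with_fused_bias(inputs, weights, fused_bias):
    """Inference-time loop: the zero point no longer appears."""
    acc = fused_bias
    for x, w in zip(inputs, weights):
        acc += x * w
    return acc


def max_abs_product():
    """Largest |input * weight| for int8 inputs and weights in [-127, 127]."""
    return max(abs(x * w) for x in range(-128, 128) for w in range(-127, 128))


inputs = [-128, -5, 0, 17, 127]      # signed 8-bit activations (sample values)
weights = [-127, 3, 64, -20, 127]    # weights restricted to [-127, 127]
bias, zero_point = 1000, -7          # hypothetical quantization parameters

fused_bias = fuse_zero_point_into_bias(weights, bias, zero_point)
reference = dot_with_zero_point(inputs, weights, bias, zero_point)
fused = dot_with_fused_bias(inputs, weights, fused_bias)
```

Here `reference` and `fused` are equal, which is why the zero point can disappear from the inner loop. The largest product magnitude is 128 × 127 = 16256, which is below 2^14, so each product fits in a signed 15-bit value and two of them can be summed in a 16-bit variable without overflow.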

Taken together, these optimizations make XNNPACK the new state of the art for quantized inference, and turn TensorFlow Lite into the most versatile quantized inference solution, covering systems from Raspberry Pi Zero to Chromebooks to workstations with server-class processors.

How can you use it?

Quantized XNNPACK inference is enabled by default in the CMake builds of TensorFlow Lite for all platforms, in the Bazel builds of TensorFlow Lite for the Web platform, and will be available in TensorFlow Lite Web API in the 2.7 release. In Bazel builds for other platforms, quantized XNNPACK inference is enabled via a build-time opt-in mechanism. When building TensorFlow Lite with Bazel, add --define tflite_with_xnnpack=true --define xnn_enable_qs8=true, and the TensorFlow Lite interpreter will use the XNNPACK backend by default for supported operators with symmetric quantization. Limited support for operators with asymmetric quantization is available via the --define xnn_enable_qu8=true Bazel option.

Which operations are accelerated?

The XNNPACK backend currently supports a subset of quantized TensorFlow Lite operators (see documentation for details and limitations). XNNPACK supports models produced by the Model Optimization Toolkit through post-training integer quantization and quantization-aware training, but not post-training dynamic range quantization.

Future work

This is the third version of the XNNPACK integration into TensorFlow Lite, following the initial release of the floating-point implementation and the subsequent release that brought sparse inference support. In upcoming versions we plan to add the following improvements:

  • Half-precision inference on recent ARM processors.
  • Sparse quantized inference.
  • Even faster dense inference.

We encourage you to leave your thoughts and comments on our GitHub and StackOverflow pages, and you can ask questions on discuss.tensorflow.org.

Easy Machine Learning for On-Device Audio

Posted by Luiz GUStavo Martins, Developer Advocate

At Google I/O, we shared a set of tutorials to help you use machine learning on audio. In this blog post you’ll find resources to help you develop and customize an audio classification model for your app, and a couple of real world examples for inspiration.

GIF of dog with audio waves picking up sound

Machine learning for audio

Sound and audio are sometimes used interchangeably, but they have a key difference. Sound is in essence what you can hear while audio is the sound’s electronic representation. That’s why we usually use the term audio when talking about machine learning.

Machine Learning for audio can be used to:

  • Understand speech
  • Understand musical instruments
  • Classify events (which bird is that?)
  • Detect pitch
  • Generate music

In this post we will focus on audio classification of events, a common scenario in practice with many real world applications like NOAA creating a humpback whale acoustic detector, and the Zoological Society of London using audio recognition to protect wildlife.

A number of classification models are available for you to try right now on TensorFlow Hub (YAMNet, Whale detection).

Audio recognition can also run completely on-device. For example, Android has a sound notifications feature that provides push notification for important sounds around you. It can also detect which music is playing, or even help with an ML-powered audio recorder app that can transcribe conversations on-device.

Having the models is only the beginning. Now you might ask:

  • How do I use them on my app?
  • How do I customize them for my audio use case?

Deploying machine learning models on-device

Imagine you have an audio classification model ready, such as a pretrained one from TF-Hub. How would you use it in a mobile app? To help you integrate audio classification into your app, we created the TensorFlow Lite Task Library. The Audio Classifier component has been released, and you only need a couple of lines of code to add audio classification to your application:

// Initialization
val classifier = AudioClassifier.createFromFile(this, modelPath)

// Start recording
val record = classifier.createAudioRecord()
record.startRecording()

// Load latest audio samples
val tensor = classifier.createInputTensorAudio()
tensor.load(record)

// Run inference
val output = classifier.classify(tensor)

The library takes care of loading the model into memory, creating the audio recorder with the proper model specifications (sample rate, bit rate), and providing the classification method that returns the model’s inference results. Here you can find a full sample to get some inspiration.

Customizing the models

What if you need to recognize audio events that are not in the set provided by the pretrained models? Or if you need to specialize them to fewer classes? In these situations, you need to fine tune the model using a technique called Transfer Learning.

This is a very popular process and you don’t need to be an expert on machine learning to be able to do it. You can use Model Maker to help you with this.

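from tflite_model_maker import audio_classifier

# DATA_DIR points to a folder with one subfolder of audio files per class.
spec = audio_classifier.YamNetSpec()
data = audio_classifier.DataLoader.from_folder(spec, DATA_DIR)

# Use 80% of the data for training and the rest for validation.
train_data, validation_data = data.split(0.8)
model = audio_classifier.create(train_data, spec, validation_data)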


You can find the complete code here. The output model can be directly loaded by the Task Library. Model Maker can customize models not only for audio but also for image, text and recommendation systems.


Machine learning for audio is an exciting field with many possibilities, enabling many new features. Doing ML on-device is getting easier and faster with tools like the TensorFlow Lite Task Library, and customization can be done without expertise in the field with Model Maker.

You can learn more about it on our new On-Device Machine Learning website (the audio path is here). You’ll find tutorials, codelabs and lots of resources not only for audio-related tasks but also for image (classification, object detection) and text (classification, entity extraction, question and answer) tasks.

You can share with us what you build by adding #TensorFlow on your social network post with your project, or submit it for the TensorFlow community spotlight program. And if you have any questions, you can ask them on discuss.tensorflow.org.

3D Pose Detection with MediaPipe BlazePose GHUM and TensorFlow.js

Posted by Ivan Grishchenko, Valentin Bazarevsky, Eduard Gabriel Bazavan, Na Li, Jason Mayes, Google

Pose detection is an important step in understanding more about the human body in videos and images. Our existing models have supported 2D pose estimation for some time, which many of you may have already tried.

Today, we are launching our first 3D model in TF.js pose-detection API. 3D pose estimation opens up new design opportunities for applications such as fitness, medical, motion capture and beyond – in many of these areas we’ve seen a growing interest from the TensorFlow.js community. A great example of this is 3D motion capture to drive a character animation in the browser.

3D motion capture with BlazePose GHUM by Richard Yee

(used with permission, live demo available at 3d.kalidoface.com)

This community demo uses multiple models powered by MediaPipe and TensorFlow.js (namely FaceMesh, BlazePose and HandPose). Even better, no app install is needed as you just need to visit a webpage to enjoy the experience. So with that in mind, let’s learn more and see this new model in action!

Try out the live demo!

The pose-detection API provides two runtimes for BlazePose GHUM, namely MediaPipe runtime and TensorFlow.js runtime.

To install the API and runtime library, you can either use the <script> tag in your HTML file or use NPM.

Through script tag:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/pose-detection"></script>
<!-- Include below scripts if you want to use TF.js runtime. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-core"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-webgl"></script>

<!-- Optional: Include below scripts if you want to use MediaPipe runtime. -->
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/pose"></script>

Through NPM:

yarn add @tensorflow-models/pose-detection

# Run below commands if you want to use TF.js runtime.
yarn add @tensorflow/tfjs-core @tensorflow/tfjs-converter
yarn add @tensorflow/tfjs-backend-webgl

# Run below commands if you want to use MediaPipe runtime.
yarn add @mediapipe/pose

How you reference the API in your JS code depends on how you installed the library.

If installed through script tag, you can reference the library through the global namespace poseDetection.

If installed through NPM, you need to import the libraries first:

import * as poseDetection from '@tensorflow-models/pose-detection';
// Uncomment the line below if you want to use TF.js runtime.
// import '@tensorflow/tfjs-backend-webgl';
// Uncomment the line below if you want to use MediaPipe runtime.
// import '@mediapipe/pose';

Try it yourself!

First, you need to create a detector:

const model = poseDetection.SupportedModels.BlazePose;
const detectorConfig = {
  runtime: 'mediapipe', // or 'tfjs'
  modelType: 'full'
};
const detector = await poseDetection.createDetector(model, detectorConfig);

Choose a modelType that fits your application needs. There are three options to choose from: lite, full, and heavy. From lite to heavy, the accuracy increases while the inference speed decreases. Please try our live demo to compare different configurations.

Once you have a detector, you can pass in a video stream to detect poses:

const video = document.getElementById('video');
const poses = await detector.estimatePoses(video);

How do you use the output? poses is an array of detected pose predictions in the image frame. Each pose contains keypoints and keypoints3D. The keypoints are the same as in the 2D model we launched before: an array of 33 keypoint objects, each with x and y in pixel units.

keypoints3D is an additional array with 33 keypoint objects, each object has x, y, z. The x, y, z are in meter units. The person is modeled as if they were in a 2m x 2m x 2m cubic space. The range for each axis goes from -1 to 1 (therefore 2m total delta). The origin of this 3D space is the hip center (0, 0, 0). From the origin, z is positive if moving closer to the camera, and negative if moving away from the camera. See below output snippet for example:

[{
  score: 0.8,
  keypoints: [
    {x: 230, y: 220, score: 0.9, name: "nose"},
    {x: 212, y: 190, score: 0.8, name: "left_eye"},
    ...
  ],
  keypoints3D: [
    {x: 0.5, y: 0.9, z: 0.06, score: 0.9, name: "nose"},
    ...
  ]
}]
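
Because keypoints3D is expressed in meters with the hip center as the origin, simple geometry on the output is meaningful. The following is a minimal sketch in Python, assuming the pose data has been exported from the browser (for example as JSON); the helper names and the sample coordinates are hypothetical and are not part of the pose-detection API.

```python
import math


def find_keypoint(keypoints3d, name):
    """Return the keypoint dict with the given name, or None if absent."""
    for keypoint in keypoints3d:
        if keypoint["name"] == name:
            return keypoint
    return None


def distance_in_meters(a, b):
    """Euclidean distance between two 3D keypoints (units are meters)."""
    return math.sqrt((a["x"] - b["x"]) ** 2 +
                     (a["y"] - b["y"]) ** 2 +
                     (a["z"] - b["z"]) ** 2)


def is_in_front_of_hips(keypoint):
    """Origin is the hip center; positive z means closer to the camera."""
    return keypoint["z"] > 0


# Made-up sample values, not real model output.
keypoints3d = [
    {"x": 0.0, "y": -0.6, "z": 0.06, "score": 0.9, "name": "nose"},
    {"x": 0.3, "y": -0.1, "z": -0.4, "score": 0.8, "name": "left_wrist"},
]
nose = find_keypoint(keypoints3d, "nose")
wrist = find_keypoint(keypoints3d, "left_wrist")
nose_to_wrist = distance_in_meters(nose, wrist)
```

With these sample values the nose sits slightly in front of the hips (positive z) while the wrist is behind them (negative z).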

You can refer to our ReadMe for more details about the API.

As you begin to play and develop with BlazePose GHUM, we would appreciate your feedback and contributions. If you make something using this model, tag it with #MadeWithTFJS on social media so we can find your work, as we would love to see what you create.

Model deep dive

The key challenge in building the 3D part of our pose model was obtaining realistic, in-the-wild 3D data. In contrast to 2D, which can be obtained via human annotation, accurate manual 3D annotation is a uniquely challenging task. It requires either a lab setup or specialised hardware with depth sensors for 3D scans, both of which make it harder to preserve a good level of human and environment diversity in the dataset. An alternative that many researchers choose is to build a completely synthetic dataset, which introduces yet another challenge: domain adaptation to real-world pictures.

Our approach is based on a statistical 3D human body model called GHUM, which is built using a large corpus of human shapes and motions. To obtain 3D human body pose ground truth, we fitted the GHUM model to our existing 2D pose dataset and extended it with real-world 3D keypoint coordinates in metric space. During the fitting process the shape and the pose variables of GHUM were optimized such that the reconstructed model aligns with the image evidence. This includes 2D keypoint and silhouette semantic segmentation alignment as well as shape and pose regularization terms. For more details see related work on 3D pose and shape inference (HUND, THUNDR).

Sample GHUM fitting for an input image. From left to right: original image, 3D GHUM reconstruction (different viewpoint) and blended result projected on top of the original image.

Due to the nature of 3D to 2D projection, multiple points in 3D can have the same projection in 2D (i.e. with the same X and Y but different Z). So the fitting can result in several realistic 3D body poses for the given 2D annotation. To minimize this ambiguity, in addition to a 2D body pose, we asked annotators to provide depth order between pose skeleton edges where they are certain (check the figure below). This task proved to be an easy one (compared to a real depth annotation) showing high consistency between annotators (98% on cross-validation) and helped to reduce the depth ordering errors for the fitted GHUM reconstructions from 25% to 3%.

“Depth order” annotation: the wider edge corner denotes the corner closer to the camera (e.g. the person’s right shoulder is closer to camera than left shoulder on both examples)

BlazePose GHUM utilizes a two-step detector-tracker approach where the tracker operates on a cropped human image. Thus the model is trained to predict 3D body pose in relative coordinates of a metric space with origin in the subject’s hips center.

MediaPipe vs. TF.js runtime

There are some pros and cons of using each runtime. As shown in the performance table below, the MediaPipe runtime provides faster inference speed on desktops, laptops and Android phones. The TF.js runtime provides faster inference speed on iPhones and iPads. The TF.js runtime is also about 1 MB smaller than the MediaPipe runtime.

MediaPipe Runtime (with WASM & GPU acceleration):

  • MacBook Pro 15” 2019 (Intel Core i9, AMD Radeon Pro Vega 20 Graphics): 75 | 67 | 34
  • iPhone 11: 9 | 6 | N/A
  • Pixel 5: 25 | 21 | 8
  • Intel i9-10900K, Nvidia GTX 1070 GPU: 150 | 130 | 97

TFJS Runtime (with WebGL backend):

  • MacBook Pro 15” 2019 (Intel Core i9, AMD Radeon Pro Vega 20 Graphics): 52 | 40 | 24
  • iPhone 11: 43 | 32 | 22
  • Pixel 5: 14 | 10 | 4
  • Intel i9-10900K, Nvidia GTX 1070 GPU: 42 | 35 | 29

Inference speed of BlazePose GHUM across different devices and runtimes. The first number in each entry is for the lite model, the second for the full model, and the third for the heavy model.


We would like to acknowledge our colleagues who participated in creating BlazePose GHUM 3D: Andrei Zanfir, Cristian Sminchisescu, Tyler Zhu, the other contributors to MediaPipe: Chuo-Ling Chang, Michael Hays, Ming Guang Yong, Matthias Grundmann, along with those involved with the TensorFlow.js pose-detection API: Ahmed Sabie and Ping Yu, and of course the community who are doing amazing work with these models: Richard Yee.

Pose estimation and classification on edge devices with MoveNet and TensorFlow Lite

Posted by Khanh LeViet, TensorFlow Developer Advocate and Yu-hui Chen, Software Engineer

Since MoveNet’s announcement at Google I/O earlier this year, we have received a lot of positive feedback and feature requests. Today, we are excited to share several updates with you:

  • The TensorFlow Lite version of MoveNet is now available on TensorFlow Hub. This includes a few updates to improve accuracy and make it compatible with hardware accelerators including GPUs and other accelerators available via the Android NN API.
  • We’ve released new Android and Raspberry Pi pose estimation samples that let you try out MoveNet on mobile and IoT devices. (iOS is coming soon.)
  • We’ve also released a Colab notebook that teaches you how to do custom pose classification (e.g. recognize different yoga poses) with MoveNet. You can try pose classification on the Android, iOS and Raspberry Pi apps mentioned earlier.

What is pose estimation?

Gif of pose estimation using machine learning

Pose estimation is a machine learning task that estimates the pose of a person from an image or a video by estimating the spatial locations of specific body parts (keypoints). MoveNet is a state-of-the-art pose estimation model that can detect these 17 keypoints:

  • Nose
  • Left and right eye
  • Left and right ear
  • Left and right shoulder
  • Left and right elbow
  • Left and right wrist
  • Left and right hip
  • Left and right knee
  • Left and right ankle

We have released two versions of MoveNet:

  • MoveNet.Lightning is smaller, faster but less accurate than the Thunder version. It can run in realtime on modern smartphones.
  • MoveNet.Thunder is the more accurate version but also larger and slower than Lightning.

The MoveNet models outperform Posenet (paper, blog post, model), our previous TensorFlow Lite pose estimation model, on a variety of benchmark datasets (see the evaluation/benchmark result in the table below).

These MoveNet models are available in both the TensorFlow Lite FP16 and INT8 quantized formats, allowing maximum compatibility with hardware accelerators.

This version of MoveNet can recognize a single pose from the input image. If there is more than one person in the image, the model along with the cropping algorithm will try its best to focus on the person who is closest to the image center. We have also implemented a smart cropping algorithm to improve the detection accuracy on videos. In short, the model will zoom into the region where there’s a pose detected in the previous frame, so that the model can see the finer details and make better predictions in the current frame.
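
The idea behind the cropping step can be sketched in a few lines of Python. This is only an illustration of zooming into the region found in the previous frame, not the algorithm used in the sample apps; the function name, the confidence threshold, and the margin are assumptions made for the example.

```python
def crop_region_from_keypoints(keypoints, image_width, image_height,
                               min_score=0.3, margin=0.25):
    """Return (left, top, right, bottom) around confident keypoints.

    keypoints: list of (x, y, score) in pixels from the previous frame.
    Falls back to the full image when no keypoint is confident enough.
    """
    confident = [(x, y) for x, y, score in keypoints if score >= min_score]
    if not confident:
        return (0, 0, image_width, image_height)

    xs = [x for x, _ in confident]
    ys = [y for _, y in confident]
    # Pad the tight bounding box so the whole body stays in view.
    pad_x = (max(xs) - min(xs)) * margin
    pad_y = (max(ys) - min(ys)) * margin

    # Clamp the padded box to the image borders.
    left = max(0, min(xs) - pad_x)
    top = max(0, min(ys) - pad_y)
    right = min(image_width, max(xs) + pad_x)
    bottom = min(image_height, max(ys) + pad_y)
    return (left, top, right, bottom)


previous_frame = [(100, 100, 0.9), (200, 300, 0.8), (500, 500, 0.1)]
region = crop_region_from_keypoints(previous_frame, 640, 480)
```

The low-confidence keypoint at (500, 500) is ignored, so the crop stays tight around the two confident keypoints. When nothing was detected in the previous frame, the sketch falls back to the full image.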

If you are interested in a deep-dive into MoveNet’s implementation details, check out an earlier blog post including its model architecture and the dataset it was trained on.

Sample app for Android and Raspberry Pi

We have released new pose estimation sample apps for these platforms so that you can quickly try out different pose estimation models (MoveNet Lightning, MoveNet Thunder, Posenet) on the platform of your choice.

  • Android sample
  • iOS sample
  • Raspberry Pi sample

In the Android and iOS sample, you can also choose an accelerator (GPU, NNAPI, CoreML) to run the pose estimation models.

Screenshot of the Android sample app. The image is from Pixabay.

MoveNet performance

We have optimized MoveNet to run well on hardware accelerators supported by TensorFlow Lite, including GPU and accelerators available via the Android NN API. This performance benchmark result may help you choose the runtime configurations that are most suitable for your use cases.


[Table: size (MB), mAP*, and latency (ms)** on Pixel 5 (CPU, 4 threads), Pixel 5 (GPU), and Raspberry Pi 4 (CPU, 4 threads) for MoveNet.Thunder (FP16 quantized), MoveNet.Thunder (INT8 quantized), MoveNet.Lightning (FP16 quantized), MoveNet.Lightning (INT8 quantized), and Posenet (MobileNetV1 backbone, FP32)]

* mAP was measured on a subset of the COCO keypoint dataset where we filter and crop each image to contain only one person.

** Latency was measured end-to-end using the Android and Raspberry Pi sample apps with TensorFlow 2.5 under sustained load.

Here are some tips when deciding which model and accelerator to use:

  • Choose Lightning or Thunder. Firstly, you should see whether the accuracy of the Lightning version is enough for your use case.
    • If the Lightning INT8 model’s accuracy is good enough, then go with it because it’s the smallest and fastest model in the lineup. A faster model also means less battery consumed.
    • If having good accuracy is critical for your use case, go with the Thunder FP16 model.
  • Choose the accelerator. Accelerator performance varies a lot between Android devices from different manufacturers.
    • CPU is the safest and simplest choice because you can know for sure that it will work on practically any Android device that can run TensorFlow Lite. However, it is usually slower and consumes more power than running the model on accelerators. All MoveNet models can run well on CPU so you should choose a model based on your accuracy needs.
    • GPU is the most widely available accelerator and provides a decent performance boost. Choose the FP16 quantized models if you want to leverage GPUs.
    • Android NNAPI is the convenient way to access additional ML accelerators on Android devices. If you are already using the CPU or GPU for other workloads and your user’s device runs Android 10 or a newer version, you can choose a model that suits your accuracy needs, and let NNAPI choose the path that it thinks works best for your model.
    • If you are an IoT developer, you may want to use Coral to increase inference speed. See the benchmark numbers for Coral here.
  • Deploy the model over-the-air rather than bundle it in the app binary. Due to the variety of the Android ecosystem, there’s no single model that is optimal for all of your users. For users with lower-end devices, the Lightning INT8 model might be optimal because it’s the fastest and consumes the least battery. However, for users with high-end devices, you may want to deliver better performance using the Thunder FP16 model. If you want to change models according to the user device, consider using the free Firebase ML to host your models instead of bundling all the models you intend to use into your app. You can write logic to download the optimal model for each user’s device when the user starts using a feature in your app that requires the TFLite model.
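
The tips above can be condensed into a small decision helper. This is a rough sketch of the guidance rather than an official selection algorithm; the function name and the accelerator labels are hypothetical.

```python
def choose_movenet_variant(accuracy_critical, accelerator):
    """Pick a MoveNet variant following the tips above.

    accelerator: one of "cpu", "gpu", "nnapi" (hypothetical labels).
    """
    if accelerator not in ("cpu", "gpu", "nnapi"):
        raise ValueError(f"unknown accelerator: {accelerator}")

    family = "Thunder" if accuracy_critical else "Lightning"
    # GPUs are best served by FP16 models; Thunder FP16 is the accuracy pick.
    if accelerator == "gpu" or accuracy_critical:
        precision = "FP16"
    else:
        precision = "INT8"
    return f"MoveNet.{family} ({precision} quantized)"


low_end_choice = choose_movenet_variant(accuracy_critical=False, accelerator="cpu")
high_end_choice = choose_movenet_variant(accuracy_critical=True, accelerator="gpu")
```

In an app that downloads models over-the-air, a helper like this could decide which model file to request for a given device.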

Pose classification

While the pose estimation model tells you where the pose key points are, in many fitness applications, you may want to go further and classify the pose, for example whether it’s a yoga goddess pose or a plank pose, to deliver relevant information to your users.

To make pose classification easier to implement, we’ve also released a Colab notebook that teaches you how to use MoveNet and TensorFlow Lite to train a custom pose classification model from your custom pose dataset. It means that if you want to recognize yoga poses, all you need is to collect images of poses that you want to recognize, label them, and follow the tutorial to train and deploy a yoga pose classifier into your applications.

The pose classifier consists of two stages:

  1. Use MoveNet to detect keypoints from the input image.
  2. Use a small TensorFlow Lite model to classify the pose from the detected keypoints.

An example of pose classification using MoveNet. The input image is from Pixabay.

In order to train a custom pose classifier, you need to prepare the pose images and put them into a folder structure as below. Each subfolder name is the name of the class you want to recognize. Then you can run the notebook to train a custom pose classifier and convert it to the TensorFlow Lite format.

|__ downdog
|______ 00000128.jpg
|______ 00000181.bmp
|______ ...
|__ goddess
|______ 00000243.jpg
|______ 00000306.jpg
|______ ...

The pose classification TensorFlow Lite model is very small, only about 30 KB. It takes the landmarks output from MoveNet, normalizes the pose coordinates, and feeds them through a few fully connected layers. The model output is a list of probabilities, one for each of the known pose types.

Overview of the pose classification TensorFlow Lite model.
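
To make the normalization and probability steps more concrete, here is a minimal Python sketch. It assumes one plausible normalization scheme (center on the hip midpoint, scale so the farthest keypoint is at distance 1) and a softmax output; the notebook may normalize differently, and the function names and sample coordinates are made up for the example.

```python
import math


def normalize_keypoints(keypoints):
    """Center keypoints on the hip midpoint and scale to unit size.

    keypoints: dict mapping name -> (x, y). Assumed scheme for illustration.
    """
    left_hip, right_hip = keypoints["left_hip"], keypoints["right_hip"]
    center_x = (left_hip[0] + right_hip[0]) / 2
    center_y = (left_hip[1] + right_hip[1]) / 2
    centered = {name: (x - center_x, y - center_y)
                for name, (x, y) in keypoints.items()}
    size = max(math.hypot(x, y) for x, y in centered.values())
    if size == 0:
        return centered
    return {name: (x / size, y / size) for name, (x, y) in centered.items()}


def softmax(logits):
    """Turn raw classifier scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


# Made-up pixel coordinates for three of the 17 keypoints.
sample = {"nose": (100, 20), "left_hip": (90, 120), "right_hip": (110, 120)}
normalized = normalize_keypoints(sample)
probabilities = softmax([1.0, 2.0, 3.0])
```

Normalizing this way makes the classifier input independent of where the person stands in the frame and how large they appear, which is the usual motivation for such a step.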

You can try your pose classification model in any of the pose estimation sample apps for Android or Raspberry Pi that we have just released.

What’s next

Our goal is to provide the core pose estimation and action recognition engine so that developers can build creative applications on top of it. Here are some of the directions that we are actively working on:

  • An improved version of MoveNet that can detect multiple poses in one forward pass.
  • Action recognition based on the detected poses on multiple frames.

Please let us know via tflite@tensorflow.org or the TensorFlow Forum if you have any feedback or suggestions!


We would like to thank the other contributors to MoveNet: Ronny Votel, Ard Oerlemans, Francois Belletti, along with those involved with TensorFlow Lite: Tian Lin, Lu Wang.

How Digitec Galaxus trains and serves millions of personalized newsletters per week with TFX

Posted by Christian Sager (Product Owner, Digitec Galaxus) and Anant Nawalgaria (ML Specialist, Google)

gif from blog titled how Digitec Galaxus trains and serves millions of personalized newsletters per week with TFX showing an animation of a cloud and people

In the retail industry it is important to be able to engage and excite users by serving personalized content on newsletters at scale. It is important to do this in a manner which leverages existing trends, while exploring and unearthing potentially new trends with an even higher user engagement. This project was done as a collaboration between Digitec Galaxus and Google, by designing a system based on Contextual Bandits to personalize newsletters for more than 2 million users every week.

To accomplish this, we leveraged several products in the TensorFlow ecosystem and Google Cloud, including TF Agents and TensorFlow Extended (TFX) running on Vertex AI, to build a system that personalizes newsletters in a scalable, modularized and cost-effective manner with low latency. In this article, we’ll highlight a few of the pieces, and point you to resources you can use to learn more.

About Digitec Galaxus

Digitec Galaxus AG is the largest online retailer in Switzerland. It offers a wide range of products to its customers, from electronics to clothes. As an online retailer, we naturally make use of recommendation systems, not only on our home or product pages but also in our newsletters. We have multiple recommendation systems in place already for newsletters, and have been extensive early adopters of the Google Cloud recommendations AI. Because we have multiple recommendation systems and very large amounts of data, we are faced with the following complications.

1. Personalization

We use over 12 recommenders in the newsletters; however, we would like to contextualize these by choosing different recommenders (which in turn select the items) for different users. Furthermore, we would like to exploit existing trends as well as experiment with new ones.

2. Latency

We would like to ensure that the ranked list of recommenders can be retrieved with sub 50 ms latency.

3. End-to-end easy to maintain and generalizable/modular architecture

We wanted the solution to have an easy-to-maintain, platform-invariant architecture, complete with all the MLOps capabilities required to train and use contextual bandit models. It was also important to us that it be built in a modular fashion so that it can be adapted easily to other use cases we have in mind, such as recommendations on the homepage, Smartags and more.

Before we get to the details of how we built a machine learning infrastructure capable of dealing with all requirements, we’ll dig a little deeper into how we got here and what problem we’re trying to solve.

Using contextual bandits

Digitec Galaxus has multiple recommendation systems in place already. Because we have multiple recommendation systems, it is sometimes difficult to choose between them in a personalized fashion. Hence we reached out to Google seeking assistance with implementing Contextual Bandit driven recommendations, which personalize our homepage as well as our newsletter. Because we only send newsletters to registered users, we can incorporate features for every user.

We chose TFAgents to implement the contextual bandit model. Training and serving pipelines were orchestrated by Vertex AI pipelines running TFX, which in turn used TFAgents for the development of the contextual bandit models. Here’s an overview of our approach.

TFAgents to implement the contextual bandit model

Rewarding subscribes, and penalizing unsubscribes

Given some features (context) about the user, and each of the 12 available recommenders, we aim to suggest best recommender (action) which increases the chance (reward) of the user clicking (reward = 1) on at least one of the recommendations by the selected recommender, and minimizes the chance of incurring a click which leads to unsubscribe (reward = -1).

By formulating the problem and reward function in this manner, we hypothesized that the system would optimize for increasing clicks, while still showing relevant (and not click-baity) content to the user in order to sustain the potential increase in performance. This is because the reward function penalizes an event in which a user unsubscribes, which click-baity content is likely to lead to. We tackled the problem with contextual bandits because they excel at exploiting trends that work well, as well as exploring and uncovering potentially even better-performing trends.
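The reward scheme described above can be sketched in a few lines of Python. This is only an illustration: the `NewsletterOutcome` record and its field names are hypothetical, and treating "no interaction" as a reward of 0 is an assumption, since only the +1 and -1 cases are stated.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NewsletterOutcome:
    """What happened after one newsletter was sent (hypothetical record layout)."""
    clicked_recommendation: bool  # user clicked at least one recommended item
    unsubscribed: bool            # the interaction led to an unsubscribe


def reward(outcome: NewsletterOutcome) -> int:
    """Map an observed outcome to the bandit reward."""
    if outcome.unsubscribed:
        return -1  # penalize clicks that end in an unsubscribe
    if outcome.clicked_recommendation:
        return 1   # reward a click on at least one recommendation
    return 0       # assumption: no interaction is treated as neutral


print(reward(NewsletterOutcome(clicked_recommendation=True, unsubscribed=False)))
```

Checking the unsubscribe case first means a click that ends in an unsubscribe is always penalized, which is what discourages click-baity recommenders.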

Serving millions of users every week with low latency

A diagram showing the high-level architecture of the recommendation training and prediction systems on GCP.

There’s a lot of detail here, as the architecture shown in the diagram covers three phases of ML: development, training, and serving. Here are some of the key pieces.

Model development

Vertex Notebooks are used as data science environments for experimentation and prototyping, in addition to implementing model training and scoring components and pipelines. The source code is version controlled in GitHub. A continuous integration (CI) pipeline is set up to run unit tests, build pipeline components, and store the container images to Cloud Container Registry.


The training pipeline is executed using TFX on Vertex Pipelines. In essence, the pipeline trains the model using new training data extracted from BigQuery, validates the produced model, and stores it in the model registry. In our system, the model registry is curated in Cloud Storage. The training pipeline uses Dataflow for large scale data extraction, validation, processing and model evaluation, as well as Vertex Training for large scale distributed training of the model. In addition, AI Platform Pipelines stores artifacts produced by the various pipeline steps to Cloud Storage, and information about these artifacts is stored in an ML metadata database in Cloud SQL.


Predictions are produced using a batch prediction pipeline, and stored in Cloud Datastore for consumption. The batch prediction pipeline is made using TFX and runs on Vertex Pipelines. The pipeline uses the most recent model in the model registry to score the serving queries from BigQuery. A Cloud Function is provided as a REST/HTTP endpoint to retrieve predictions from Datastore.

Continuous Training Pipeline

A diagram of the TFX pipeline for the training workflow.

Our TFX-based continuous training workflow uses many components. Training is currently done on demand, but we plan to run it on a bi-weekly cadence later. Here is a little detail on the important ones.

Raw Data

Our data consists of multiple datasets stored in heterogeneous formats, in BigQuery tables and elsewhere, which are then joined in denormalized fashion by the customer into a single BigQuery table for training. To help avoid bias and drift in our model, we train the model on a rolling window of 4 weeks, with one overlapping week per training cycle. This was a simple design choice that was very straightforward to implement, as BigQuery has good compatibility as a source with TFX and also allows the user to do some basic data preprocessing and cleaning during fetching.
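To make the rolling-window idea concrete, here is a small sketch of how consecutive training windows could be computed. It assumes one reading of the description above: each window spans 4 weeks and each new cycle starts 3 weeks after the previous one, so neighbouring windows share exactly one week. The function name and the start date are made up for illustration.

```python
from datetime import date, timedelta


def training_windows(first_start, cycles, window_weeks=4, overlap_weeks=1):
    """Return (start, end) date pairs, with `end` exclusive, for each training cycle."""
    length = timedelta(weeks=window_weeks)
    step = timedelta(weeks=window_weeks - overlap_weeks)  # how far each cycle advances
    windows = []
    for cycle in range(cycles):
        start = first_start + cycle * step
        windows.append((start, start + length))
    return windows


for start, end in training_windows(date(2022, 1, 3), cycles=3):
    print(start.isoformat(), "to", end.isoformat())
```

Each pair could then be substituted into the date filter of the query that selects training rows.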


We first leverage BigQuery's built-in functions to preprocess the data. By embedding our own specific processes into the query calls made by the ExampleGen component, we were able to avoid building out a separate ETL that would exist outside the scope of a TFX pipeline. This ultimately proved to be a good way to get the model in production more quickly. This preprocessed data is then split into training and eval sets and converted to tf.Examples via the ExampleGen component.


The Transform component does the feature engineering and transformations needed to handle strings, fill in missing values, log-normalize values, set up embeddings, etc. The major benefit here is that the resulting transformation is ultimately prepended to the computational graph, so that the exact same code is used for training and serving. The Transform component runs on Cloud Dataflow in production.


The Trainer component trains the model using TF-Agents. We leverage parallel training on Vertex Training to speed things up. The model is designed such that the user id passes in from the input to the output unaltered, so that it can be used as part of the downstream serving pipeline. The Trainer component runs on Vertex Training in production.


The Evaluator compares the existing production model to the model received from the Trainer and prepares the metrics required by the validator component to bless the “better” one for use in production. The model gating criteria are based on the AUC scores as well as counterfactual policy evaluation, and possibly other metrics in the future. It is easy to implement custom metrics which meet the business requirements owing to the extensibility of the Evaluator component. The Evaluator runs on Vertex AI.
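The gating logic can be pictured as a simple comparison of candidate metrics against production metrics. The sketch below is plain Python, not the TFX Evaluator API; the metric names (`auc`, `cpe_reward`), the `min_gain` margin, and the rule "every metric must be at least as good" are all assumptions made for illustration.

```python
from typing import Mapping, Optional


def should_bless(candidate: Mapping[str, float],
                 production: Optional[Mapping[str, float]],
                 min_gain: float = 0.0) -> bool:
    """Bless the candidate only if it does not regress on any tracked metric."""
    if production is None:
        return True  # nothing in production yet, so there is nothing to beat
    return all(candidate[name] - production[name] >= min_gain
               for name in production)


production_metrics = {"auc": 0.71, "cpe_reward": 0.12}  # hypothetical values
candidate_metrics = {"auc": 0.74, "cpe_reward": 0.15}   # hypothetical values
print(should_bless(candidate_metrics, production_metrics))
```

Keeping the rule in one small function makes it easy to add further metrics later without touching the rest of the pipeline.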


The Pusher’s primary function is to send the blessed model to our TFServing deployment for production. However, we added functionality to use the custom metrics produced in the Evaluator to determine decisioning criteria to be used in serving, and attach that to the computational graph. The level of abstraction available in TFX components made it easy to make this custom modification. Overall, the modification allows the pipeline to operate without a human in the loop so that we are able to make model updates frequently, while continuing to deliver consistent performance on metrics that are important to our business.


This is a custom TFX component which creates a dictionary with hyperparameters (e.g., batch size, learning rate) and stores the dictionary as an artifact. The hyperparameters are passed as input to the trainer.


This custom component takes a serving policy (which includes exploration) and a corresponding eval policy (without exploration), and resolves which policy will be used for serving.


This custom component copies the pushed/blessed model from the TFX artifacts directory to a curated destination.

The out-of-box components from TFX provided most of the functionality we require, and it is easy to create some new custom components to make the entire pipeline satisfy our requirements. There are also other components of the pipeline such as StatisticsGen (which also runs on Dataflow).

Batch Prediction Pipeline

A diagram showing the TFX pipeline for the batch prediction workflow.

Here are a few of the key pieces of our batch prediction system.

Inference Dataset

Our inference dataset has nearly identical format to the training dataset, except that it is emptied and repopulated with new data daily.


Just like for the Training pipeline, we use this component to read data from BigQuery and convert it into tf.Examples.

Model Importer

This component imports the computation graph exported by the Pusher component of the training pipeline. As mentioned above, since it contains the whole computation graph generated by the training pipeline, including feature transformation and the TF-Agents model (including the exploration/exploitation aspect), it is very portable and prevents training/serving skew.


As the name implies, this component uses the imported computation graph to perform mass inference on the inference dataset. It runs on Cloud Dataflow in production and makes it very easy to scale.


This is a custom Python component which takes the inference results from BulkInferrer, post-processes them to format/filter the fields as required, and persists them to Cloud Datastore. This runs on Cloud Dataflow in production as well.

Serving is done via Cloud Functions, which take the user ids as input and return the precomputed results stored in Datastore for each user id with sub 50 ms latency.

Extending the work so far

In the few months since implementing the first version, we have made dozens of improvements to the pipeline, everything from changing the architecture/approach of the original model to changing the way the model’s results are used in the downstream application to generate newsletters. Moreover, each of these improvements delivers value to us more quickly than was possible in the past.

Since our initial implementation of this reference architecture, we have released simple GitHub code samples, based on a Vertex AI pipeline, for implementing recommender systems using TF-Agents here. This template and guide will help users build recommender systems using contextual bandits on GCP in a scalable, modularized, low latency and cost effective manner. It’s quite remarkable how many of the existing TFX components that we have in place carry over to new projects, and even more so how drastically we’ve reduced the time it takes to get a model in production. As a result, even the software engineers on our team without much ML expertise feel confident in being able to reuse this architecture and adapt it to more use cases. The data scientists are able to spend more of their time optimizing the parameters and architectures of the models they produce, understanding their impact on the business, and ultimately delivering more value to the users and the business.


None of this would have been possible without the joint collaboration of the following Googlers: Khalid Salama, Efi Kokiopoulou, Gábor Bartók and Digitec Galaxus’s team of engineers.

A Google Cloud blog on this project can be found here.

Using the TensorFlow-Agents Bandits Library for Recommendations

Posted by Gábor Bartók and Efi Kokiopoulou, Google Research

This article assumes you have some prior experience with reinforcement learning and/or multi-armed bandits. If you’re new to the subject, a good starting point is the Bandits Wikipedia entry, or for a bit more technical and in-depth introduction, this book.

In this blog post we introduce the TensorFlow-Agents Bandits library. This library offers a comprehensive list of the most popular bandit algorithms along with a variety of test problems on which the algorithms can be run. The test problems (called bandit environments) include some synthetic environments as well as environments converted from real-life (classification or recommendation) datasets.

One of the latter is the MovieLens environment, which utilizes this dataset. In this blog post, we will guide you through the usage of the TF-Agents Bandits library with the help of the MovieLens Environment.

Multi-Armed Bandits

Multi-Armed Bandits is a machine learning framework in which an agent repeatedly selects actions from a set of actions and collects rewards by interacting with the environment. The goal of the agent is to accumulate as much reward as possible, within a given time horizon. The name “bandit” comes from the illustrative example of finding the best slot machine (one-armed bandit) from a set of machines with different payoffs. The actions are also known as “arms”.

Image from Wikipedia

There are two more important concepts to be aware of: “context”, and “regret”. In many real life scenarios, it’s not enough to find the best action that on average provides the highest reward: we want to find the best action depending on the situation/context. To extend the bandits framework in this direction, we introduce the notion of “context”. Before the agent has to select an action, it receives the context that provides information about the current round. Then the agent’s goal is to find the policy that selects the highest-rewarding action for the given context.

In bandits literature, the notion of “regret” is very important. The regret can be informally defined as the difference in performance between the optimal policy and the learned policy. Typically the performance is measured in terms of cumulative reward (i.e., sum of rewards across several rounds); otherwise, one may also refer to the “instantaneous regret” which is the regret the agent suffers at a certain round. Bandit algorithms typically come with performance guarantees in terms of upper bound on the regret given a family of bandit problems.
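As a small worked example of these two notions, the sketch below computes instantaneous and cumulative regret from two reward sequences. It is an informal illustration: formal definitions usually use expected rewards, whereas here we simply subtract the reward the agent received from the reward the optimal policy would have received in the same round. The function names and the numbers are made up.

```python
from itertools import accumulate


def instantaneous_regret(optimal_rewards, received_rewards):
    """Per-round gap between the optimal policy's reward and the agent's reward."""
    return [best - got for best, got in zip(optimal_rewards, received_rewards)]


def cumulative_regret(optimal_rewards, received_rewards):
    """Running sum of the per-round regret."""
    return list(accumulate(instantaneous_regret(optimal_rewards, received_rewards)))


optimal = [1.0, 1.0, 1.0, 1.0]   # reward of the best action in each round
received = [0.0, 1.0, 0.5, 1.0]  # reward the agent actually collected

print(instantaneous_regret(optimal, received))  # [1.0, 0.0, 0.5, 0.0]
print(cumulative_regret(optimal, received))     # [1.0, 1.0, 1.5, 1.5]
```

A learning agent is doing well when the cumulative regret curve flattens out, meaning that the per-round regret is shrinking towards zero.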

Example: Movie Recommendation

Consider the following scenario. You are tasked with recommending movies to users of a movie streaming service. In every round you receive information about the user. Your task is to choose from a handful of movies for the user, with the goal of choosing one that the user will enjoy and rate highly.

A Recommendation Dataset

For illustration purposes, we will turn the well-known MovieLens dataset into a bandit problem. The dataset consists of ~100K ratings from 943 users on 1682 movies. Our first step to turn this dataset into a contextual bandit problem is to construct the matrix `A` of user/movie ratings, where `A_ij` is the rating of user `i` of movie `j`. Since each user has rated only a few movies, one issue with the ratings matrix `A` is that it is very sparse, i.e., only a few entries `A_ij` are available; all the other entries are unknown. To address this sparsity issue, we construct a low-rank SVD decomposition `A ~= U*V'` (low-rank matrix decomposition is a popular approach for collaborative filtering in recommender systems, see e.g., Koren et al. 2009). This way, the rows of `U` are the context features. The movies to be recommended to the user are the set of actions, represented as rows of `V`. The reward for recommending movie `j` to user `i` can then be calculated as the inner product of the corresponding rows `U_i` and `V_j`. Therefore, using the low-rank SVD decomposition to compute rewards gives us the ability to approximate the reward even for movies that were never recommended to a user and whose rating is therefore unknown.
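The construction above can be tried out on a toy matrix with NumPy. This is a sketch under stated assumptions, not the code of the MovieLens environment: the 4x4 ratings matrix is invented, unknown ratings are stored as 0 before factorizing, and the singular values are split evenly between the two factors.

```python
import numpy as np

# Toy ratings matrix: rows are users, columns are movies.
# Assumption for this sketch: unknown ratings are stored as 0.
A = np.array([
    [5.0, 4.0, 0.0, 1.0],
    [4.0, 0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0, 5.0],
    [0.0, 1.0, 5.0, 4.0],
])

rank = 2
u, s, vt = np.linalg.svd(A, full_matrices=False)

# Keep the top `rank` singular directions and share sqrt(s) between the factors,
# so that U @ V.T is the rank-2 approximation of A.
U = u[:, :rank] * np.sqrt(s[:rank])      # one row of context features per user
V = vt[:rank, :].T * np.sqrt(s[:rank])   # one row of action features per movie


def reward(user: int, movie: int) -> float:
    """Approximate rating: inner product of the user row and the movie row."""
    return float(U[user] @ V[movie])


# User 1 never rated movie 1, but the factorization still yields an estimate.
print(round(reward(1, 1), 3))
```

Because the reward is defined for every (user, movie) pair, the environment can score any action the agent picks, including movies the user never rated.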

TF-Agents Bandits

Now let’s see how the above problem is modeled and solved with the help of the TF-Agents Bandits library. TF-Agents is a modular library that has building blocks for every aspect of Reinforcement Learning and Bandits. A problem can be expressed in terms of an “environment”. An environment is a class that generates observations (aka contexts), and also outputs a reward after being presented with actions. In the case of the MovieLens environment, an observation is a random row of the matrix `U`, while the reward is given after an algorithm has chosen an action (i.e., row of the matrix `V`, a movie in our case). The implementation of the MovieLens environment can be found here. It’s worth noting here that it is rather simple to implement a bandit environment in TF-Agents. For a walkthrough, we refer the reader to our Bandits Tutorial.


Bandit algorithms in TF-Agents have two main building blocks: “policies” and “agents”. A policy is a function that, given an observation, chooses an action. The agent is responsible for learning a good policy: given examples of (observation, action, reward) tuples, it trains the policy so that it chooses better actions. The TF-Agents Bandits library offers a comprehensive list of the most popular algorithms, including linear methods as well as nonlinear ones (e.g., those with neural network-based value functions). Let’s see how LinUCB tackles the MovieLens problem!

The LinUCB algorithm

In short, the LinUCB algorithm keeps track of running average rewards for all actions, along with confidence intervals around the estimates. In every turn, the algorithm chooses the action that has the highest upper confidence bound on its reward estimate.
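To make the "estimate plus confidence bound" idea concrete, here is a compact, dependency-free sketch of a textbook-style LinUCB with one linear model per action. It is not the TF-Agents implementation; the class and function names, the ridge regularization with the identity matrix, and the exploration weight `alpha` are choices made for this illustration.

```python
import math


class LinUCBArm:
    """Ridge-regression reward estimate for one action, plus a confidence width."""

    def __init__(self, dim: int):
        # Inverse of A = I + sum(x x^T), maintained incrementally.
        self.a_inv = [[1.0 if r == c else 0.0 for c in range(dim)] for r in range(dim)]
        self.b = [0.0] * dim  # running sum of reward * context

    def _a_inv_times(self, vec):
        return [sum(row[k] * vec[k] for k in range(len(vec))) for row in self.a_inv]

    def upper_confidence_bound(self, x, alpha: float) -> float:
        theta = self._a_inv_times(self.b)  # current weight estimate A^-1 b
        mean = sum(t * xi for t, xi in zip(theta, x))
        width = math.sqrt(sum(xi * v for xi, v in zip(x, self._a_inv_times(x))))
        return mean + alpha * width

    def update(self, x, reward: float) -> None:
        # Sherman-Morrison: (A + x x^T)^-1 = A^-1 - (A^-1 x)(A^-1 x)^T / (1 + x^T A^-1 x)
        a_inv_x = self._a_inv_times(x)
        denom = 1.0 + sum(xi * v for xi, v in zip(x, a_inv_x))
        for r in range(len(x)):
            for c in range(len(x)):
                self.a_inv[r][c] -= a_inv_x[r] * a_inv_x[c] / denom
        self.b = [bi + reward * xi for bi, xi in zip(self.b, x)]


def choose_action(arms, x, alpha: float = 1.0) -> int:
    """Pick the action with the highest upper confidence bound (ties go to the lowest index)."""
    scores = [arm.upper_confidence_bound(x, alpha) for arm in arms]
    return max(range(len(arms)), key=scores.__getitem__)


arms = [LinUCBArm(dim=2) for _ in range(2)]
context = [1.0, 0.0]
action = choose_action(arms, context)
arms[action].update(context, reward=1.0)
```

An action that has rarely been tried keeps a wide confidence interval, so it is still selected from time to time; this is the optimism that drives exploration.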

In the TF-Agents library, the LinUCB algorithm is built from a LinearBanditPolicy with an “Optimistic Exploration Strategy”, and a LinearBanditAgent responsible for updating the estimates. Note that the exploration strategy can be changed from “Optimistic” to “Sampling”, in which case the algorithm becomes Linear Thompson Sampling.

So let’s see how LinUCB performs on the MovieLens environment! We ran LinUCB on the MovieLens environment (with 100 actions and SVD decomposition rank 20) and we get results on TensorBoard:

(Note that all of the plots below are based on averaging five runs; the shadows show standard deviations. A rolling average smoothing is also applied on the curves.)

Linear Thompson Sampling

As mentioned above, with a slight modification of LinUCB, we get an implementation for Linear Thompson Sampling (LinTS). If we run LinTS on the same problem (implementation here), we get a very similar result to that of LinUCB (see joint graph further down).


Let’s compare these results with another agent, say, the NeuralEpsilonGreedy agent. As the name suggests, this agent uses a neural network to estimate the rewards, and adds uniform exploration with probability `epsilon`. This exploration strategy is known as “epsilon-greedy” since the method is greedy most of the time but with probability `epsilon` it explores by picking an action uniformly at random. If we run Neural Epsilon Greedy and plot the results from the three algorithms together, we get:
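The epsilon-greedy selection rule itself fits in a few lines. In the sketch below the reward estimates are passed in as a plain list; in the NeuralEpsilonGreedy agent they would come from the neural network. The function name and the example numbers are illustrative only.

```python
import random


def epsilon_greedy(estimated_rewards, epsilon, rng):
    """Explore uniformly with probability epsilon, otherwise pick the best estimate."""
    if rng.random() < epsilon:
        return rng.randrange(len(estimated_rewards))  # explore: uniform random action
    # exploit: index of the highest estimated reward
    return max(range(len(estimated_rewards)), key=estimated_rewards.__getitem__)


rng = random.Random(0)
estimates = [0.1, 0.9, 0.3]  # made-up reward estimates for three actions
picks = [epsilon_greedy(estimates, epsilon=0.1, rng=rng) for _ in range(1000)]
print(picks.count(1) / len(picks))  # fraction of greedy picks of action 1
```

With `epsilon=0.1`, roughly nine out of ten rounds pick the currently best-looking action, and the rest are spent on uniform exploration regardless of how much is already known.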

NeuralEpsilonGreedy graph

It’s interesting to also look at how often the methods pick suboptimal actions. This is shown below:


We see that LinUCB and LinTS have very similar performance, which is not very surprising, as they are very similar algorithms. On the other hand, Neural epsilon-Greedy is not doing very well on this problem. After fifty thousand iterations, its metrics are still far away from those of the linear methods. Note, nevertheless, that even the epsilon-Greedy algorithm manages to find the best movie out of 100 about half the time. Still not bad!

To be fair, it’s expected that linear algorithms do better than non-linear ones on this problem, as the problem is linear (by the reward calculation construction).

As for the difference between the two linear algorithms, it seems that LinUCB struggles in the beginning a little bit, but in the long run it is slightly (not significantly) better than LinTS.

Recommendation with Arm Features

The MovieLens example above has some shortcomings: its actions are a selection of movies, algorithms have to learn a distinct model for every movie, and it’s also hard to introduce new movies in the system. To this end, we change the environment a little bit: instead of treating every movie as an independent action, we model the movies with features, similarly to users: the rows of `V` will be the movie features. Then the model only has to learn one reward function, whose input is both the user features `u` and the movie features `v`. This way we can have an unlimited number of movies in the system, and we can introduce new movies on the fly. This version of the environment can be found here.
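A minimal sketch of this per-arm idea: a single reward function scores every (user, movie) feature pair, so the set of arms can grow without adding a model per movie. The names `choose_arm` and `inner_product` and all feature values are hypothetical; in practice the reward function would be a learned model rather than a fixed inner product.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def choose_arm(user: Vector,
               arms: Sequence[Vector],
               reward_fn: Callable[[Vector, Vector], float]) -> int:
    """Score every arm with the shared reward function and return the best index."""
    scores = [reward_fn(user, arm) for arm in arms]
    return max(range(len(arms)), key=scores.__getitem__)


def inner_product(u: Vector, v: Vector) -> float:
    """Stand-in reward function for this sketch."""
    return sum(a * b for a, b in zip(u, v))


user = [1.0, 0.0]
catalog = [[0.2, 0.9], [0.8, 0.1]]
print(choose_arm(user, catalog, inner_product))  # 1

# A new movie only needs a feature vector; the reward function stays the same.
catalog.append([0.95, 0.0])
print(choose_arm(user, catalog, inner_product))  # 2
```

Note that the number of arms is never baked into the reward function, which is what allows movies to be introduced on the fly.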

Agents Running on Per Arm Feature Environments

Most of the agents implemented in our library can run on environments that have features for their actions (we call these environments “per-arm environments”).

Now let’s see how the different algorithms behave on the per-arm version of the MovieLens environment. We ran the arm-feature versions of the three algorithms: LinUCB, LinTS, and eps-Greedy. The result is quite different from the previous section: Here the linear methods seem to fail to find the relationship between actions and rewards, while the neural approach gives similar results to that of the non-arm feature problem.


The neural algorithm still finds the best action ~45% of the time, while the linear algorithms do so only ~30% of the time.

Your New Bandit Algorithm

If you haven’t found what you are looking for in the list of agents within the library, it’s possible, and not too complicated, to implement your own algorithm. You need to:

  • subclass tf_agents.policies.TFPolicy and
  • subclass tf_agents.agents.TFAgent.


To define a policy, one needs to implement its private member function _distribution(…). In short, this function takes an observation and outputs a distribution of actions (or simply an action in case of a deterministic policy).


As stated above, an agent is responsible for training the policy. To this end, subclasses of TF-Agents’ TFAgent (sorry) have to implement the private member function _train() (among others, some details are omitted for clarity). This function takes batches of training data, and trains the policy.

Your New Bandit Environment

If you want to test your (new) algorithm and have an idea for an environment, it’s also simple to implement it in TF-Agents. A Bandit environment has two main roles: (i) to generate observations, and (ii) to return a reward after the agent chooses an action. One can easily create an environment class by defining these two functions.
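The two roles can be illustrated with a plain-Python stand-in. This class does not subclass anything from TF-Agents; the class name, the method names `observe` and `apply_action`, and the toy reward rule are all made up for the sketch, which only shows the shape of the two functions an environment has to provide.

```python
import random


class TwoArmEnvironment:
    """Toy bandit environment: the reward is the context coordinate picked by the action."""

    def __init__(self, seed: int = 0):
        self._rng = random.Random(seed)
        self._context = None

    def observe(self):
        """Role (i): generate the observation (context) for the current round."""
        self._context = [self._rng.uniform(-1.0, 1.0), self._rng.uniform(-1.0, 1.0)]
        return list(self._context)

    def apply_action(self, action: int) -> float:
        """Role (ii): return the reward for the action chosen in this round."""
        if self._context is None:
            raise RuntimeError("call observe() before apply_action()")
        return self._context[action]


env = TwoArmEnvironment(seed=42)
context = env.observe()
best_action = 0 if context[0] >= context[1] else 1  # an oracle policy for this toy problem
print(env.apply_action(best_action) == max(context))  # True
```

An agent interacts with such an environment in a loop: observe, choose an action, receive the reward, and use the resulting (observation, action, reward) tuple for training.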


In this blog post, we introduced the TF-Agents Bandit library and showed how to tackle a recommendation problem with it. If you want to play around with the environments and agents used in this post, you can go directly to this executable to run these agents and more. If you want to explore the library or just want to read more about it, we suggest starting with this tutorial. And if you’re interested in learning more about making recommendations on this MovieLens dataset, you can also check out another great library called TensorFlow Recommenders.


The TF-Agents Bandits library has been built in collaboration with Jesse Berent, Tzu-Kuo Huang, Kishavan Bhola, Sergio Guadarrama, Anoop Korattikara, Oscar Ramirez, Eugene Brevdo, and many others from the TF-Agents team.
