Posted by Ayush Jain, Carlos Araya, and Mani Varadarajan for the TensorFlow team
Welcome to TensorFlow and Keras at Google I/O!
The world of machine learning is changing faster than ever. The rise of Large Language Models (LLMs) is sparking the imagination of developers worldwide, with new generative AI applications reaching hundreds of millions of people. These models are trained on massive datasets and used to solve a variety of tasks, from natural language processing to image generation.
Powering all these new capabilities requires new levels of model efficiency and performance, as well as support for seamless deployment across a growing number of devices – be it on a server, the web, mobile devices, or beyond. As stewards of one of the largest machine learning communities in the world, the TensorFlow team is continually asking how we can better serve you.
To that end, this post covers a few of the many improvements and additions coming this year to the TensorFlow ecosystem. Let’s dive in!
A Growing Ecosystem
New functionality we’re covering today:
- KerasCV and KerasNLP allow you to access pre-trained, state-of-the-art models in just a few lines of code.
- DTensor helps you scale up your models and train them efficiently by combining different parallelism techniques.
- With JAX2TF, models written with the JAX numerical library can be used in the TensorFlow ecosystem.
- We also preview the TF Quantization API, which enables you to make your models more cost- and resource-efficient without compromising on accuracy.
Applied ML with KerasCV & KerasNLP
KerasCV and KerasNLP are powerful, modularized libraries that give you direct access to the state of the art in computer vision and natural language processing.
|The KerasCV + KerasNLP suite, at a glance.|
Whether you want to classify images, auto-generate text from prompts (as with Bard), or do anything in between, KerasCV and KerasNLP make it easy with just a few lines of code. And since they’re part of Keras, they’re fully integrated with the TensorFlow ecosystem.
Let’s look at some code for image generation. KerasCV is designed to support many models, and in this case we’ll use a diffusion model. Despite the complexity of the underlying architecture, you can get it up and running with just a few lines of code.
With one line to import and another to initialize the model, you can generate completely new images:
|KerasCV-generated images of an astronaut riding a horse!|
Machine Learning at Scale with DTensor
DTensor enables larger and more performant model training by giving developers the flexibility to combine and fine-tune multiple parallelism techniques.
Traditionally, ML developers have scaled up models through data parallelism, which splits up your data and feeds each slice to an identical copy of the model running on a separate device. This scales up training but has an important limitation: it requires that the model fit within a single hardware device.
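The data-parallel pattern can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not TensorFlow code: a hypothetical one-weight linear model with mean-squared-error loss, where each "replica" holds a full copy of the weight, computes a gradient on its own slice of the batch, and the replicas’ gradients are averaged (the job an all-reduce does across real devices). With equal-sized slices, the averaged gradient matches the full-batch gradient.

```python
# Minimal data-parallelism sketch (plain Python, not TensorFlow).
# Assumption: a toy one-weight linear model y = w * x with mean-squared-error
# loss. Each "replica" holds a FULL copy of w and sees only its slice of data.

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def data_parallel_grad(w, xs, ys, num_replicas):
    """Split the batch evenly across replicas, compute per-replica gradients,
    then average them (the job an all-reduce does across real devices)."""
    assert len(xs) % num_replicas == 0, "batch must split evenly"
    size = len(xs) // num_replicas
    grads = [
        grad_mse(w, xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
        for i in range(num_replicas)
    ]
    return sum(grads) / num_replicas

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.0 * x for x in xs]  # targets generated with a "true" weight of 3.0
w = 1.0

full = grad_mse(w, xs, ys)
parallel = data_parallel_grad(w, xs, ys, num_replicas=4)
print(full, parallel)  # -102.0 -102.0
```

Every replica still carries the entire model, which is why data parallelism alone stops working once the model no longer fits on one device.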
As models get bigger, fitting on a single device is no longer guaranteed: developers need to be able to scale their models across hardware devices. This is where model parallelism becomes important, allowing the model to be split into shards that are placed on different devices and trained in parallel.
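One common form of model parallelism splits a layer’s weight matrix across devices. The sketch below assumes a column-wise split of a single dense layer, with plain Python lists standing in for tensors and hypothetical helper names; each "device" multiplies the input by its own shard, and concatenating the partial outputs recovers the unsharded result.

```python
# Minimal model-parallelism sketch for one dense layer (plain Python).
# Assumption: the weight matrix W is split column-wise, so each "device"
# stores only a slice of W and produces a slice of the output.

def matmul(x, w):
    """Multiply a (rows x k) matrix by a (k x cols) matrix."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def shard_columns(w, num_devices):
    """Split W's columns into num_devices contiguous, equal-width shards."""
    cols = len(w[0])
    assert cols % num_devices == 0, "columns must divide evenly across devices"
    width = cols // num_devices
    return [[row[d * width:(d + 1) * width] for row in w] for d in range(num_devices)]

def model_parallel_matmul(x, w, num_devices):
    """Each device computes x @ W_shard; concatenating the partial outputs
    along the column axis reproduces x @ W."""
    partials = [matmul(x, shard) for shard in shard_columns(w, num_devices)]
    return [sum((p[r] for p in partials), []) for r in range(len(x))]

x = [[1, 2], [3, 4]]
w = [[1, 0, 2, -1], [0, 1, 1, 3]]
print(model_parallel_matmul(x, w, num_devices=2))  # [[1, 2, 4, 5], [3, 4, 10, 9]]
print(model_parallel_matmul(x, w, num_devices=2) == matmul(x, w))  # True
```

No single device ever holds all of W, which is what lets a model larger than one device’s memory be trained.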
With DTensor, data and model parallelism are not only both supported, but can also be combined directly to scale models even more efficiently. It’s also completely accelerator-agnostic, whether you use TPUs, GPUs, or something else.
|Mixed (data + model) parallelism, with DTensor.|
Let’s go through an example. Say you are building with a transformer model, like the Open Pre-trained Transformer (OPT) available through KerasNLP, and training it on some input dataset.
But here’s the thing about OPT: it’s big. With variants of up to 175 billion parameters, traditional data parallelism would fail outright, because there are simply too many weights to replicate on a single hardware device. That’s where DTensor comes in.
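A quick back-of-the-envelope calculation shows why. The sketch assumes 16-bit (2-byte) weights and, purely for illustration, an accelerator with 80 GB of memory; it ignores gradients, optimizer state, and activations, all of which add substantially more during training.

```python
import math

# Back-of-the-envelope memory estimate for the largest OPT variant.
PARAMS = 175e9               # 175 billion parameters
BYTES_PER_PARAM = 2          # assumption: 16-bit (fp16/bf16) weights
ACCELERATOR_MEMORY_GB = 80   # assumption: one high-memory accelerator

weights_gb = PARAMS * BYTES_PER_PARAM / 1e9
min_devices = math.ceil(weights_gb / ACCELERATOR_MEMORY_GB)

print(f"Weights alone: {weights_gb:.0f} GB")               # Weights alone: 350 GB
print(f"Devices needed just to hold them: {min_devices}")  # Devices needed just to hold them: 5
```

Even before training overheads, the weights alone need several devices, so some form of model sharding is unavoidable.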
To work with DTensor, we need to define two things:
First is a mesh, where you define (a) a set of hardware devices and (b) a topology, here the batch and model dimensions.
Second is a layout, which defines how each tensor dimension is sharded across your mesh. Through our Keras domain package integrations, you can do this in just one line:
layout_map = keras_nlp.models.OPTCausalLM.create_layout_map(mesh)
From there, you create the DTensor layout’s context and include your model creation code within it. Note that at no point did we have to make any changes to the model itself!
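To make the mesh and layout ideas concrete, here is a hypothetical pure-Python helper (not the DTensor API) that computes the shape of the shard each device would hold. The mesh sizes, tensor shapes, and the local_shard_shape name are illustrative assumptions; the idea is that each tensor axis is either split across one mesh dimension or kept whole on every device.

```python
# Hypothetical illustration of meshes and layouts (NOT the DTensor API).
# A mesh names its dimensions and gives each a size; a layout says, for every
# tensor axis, which mesh dimension shards it, or UNSHARDED to keep it whole.

UNSHARDED = "unsharded"

def local_shard_shape(global_shape, layout, mesh):
    """Shape of the piece of a tensor that each device holds."""
    assert len(global_shape) == len(layout), "one layout entry per tensor axis"
    shape = []
    for size, mesh_dim in zip(global_shape, layout):
        if mesh_dim == UNSHARDED:
            shape.append(size)
        else:
            parts = mesh[mesh_dim]
            assert size % parts == 0, f"axis of size {size} not divisible by {parts}"
            shape.append(size // parts)
    return tuple(shape)

# A 2 x 4 mesh over 8 devices: 2-way data parallel, 4-way model parallel.
mesh = {"batch": 2, "model": 4}

# Activations (batch, hidden): shard the batch axis over the "batch" dimension.
print(local_shard_shape((32, 1024), ["batch", UNSHARDED], mesh))   # (16, 1024)
# A dense kernel (hidden, ffn): shard its output axis over "model".
print(local_shard_shape((1024, 4096), [UNSHARDED, "model"], mesh))  # (1024, 1024)
# Mixed: shard both axes, one per mesh dimension.
print(local_shard_shape((32, 4096), ["batch", "model"], mesh))      # (16, 1024)
```

Changing only the layouts, not the model code, moves a tensor between data-parallel, model-parallel, or mixed placement.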
Performance for DTensor today is already on par with industry benchmarks, nearly matching the gold-standard implementation of model parallelism offered by NVIDIA’s Megatron for GPUs. Further improvements are in the works to raise the bar even higher across hardware devices.
In the future, DTensor will be fully integrated with key interfaces like tf.distribute and Keras as a whole, with a single entry point regardless of hardware and a number of other quality-of-life features. If you want to learn more, check out the DTensor overview or the Keras integration guide!
Bringing Research to Production with JAX2TF
Many of the ML advancements that are now household names had their beginnings in research. For example, the Transformer architecture, created and published by Google AI, underpins the fantastic advances in language models.
JAX has emerged as a trusted tool for much of this kind of discovery, but productionizing it is hard. To that end, we’ve been thinking about how to bring research more easily into TensorFlow, giving innovations built on JAX the full strength of TensorFlow’s uniquely robust and diverse production ecosystem.
That’s why we’ve built JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem. There are many examples of how this can be useful; here are just a few:
- Inference: Taking a model written for JAX and deploying it either on a server using TF Serving or on-device using TFLite.
- Fine-tuning: Taking a model that was trained using JAX, bringing its components to TF using JAX2TF, and continuing to train it in TensorFlow with your existing training data and setup.
- Fusion: Combining parts of models that were trained using JAX with those trained using TensorFlow for maximum flexibility.
The key to enabling this kind of interoperation between JAX and TensorFlow is baked into jax2tf.convert, which takes in model components created on top of JAX (e.g. your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.
We’ve created a code walkthrough for one of the examples above: a quick fine-tuning setup that creates a simple model using libraries from the JAX ecosystem (like Flax and Optax) and brings it into TF to finish training. Check it out here.
JAX2TF already powers various tools in the TensorFlow ecosystem under the hood. For example, here are code guides for simple conversion from JAX to TFLite for mobile devices and from JAX to TF.js for web deployment!
Coming Soon: The TensorFlow Quantization API
ML developers today face a wide variety of real-world constraints introduced by the settings they’re working in, like the size of a model or where it gets deployed.
With TensorFlow, we want developers to be able to quickly adjust to these kinds of constraints, and to do so without sacrificing model quality. To do this, we’re building the TF Quantization API, a native quantization toolkit for TF2 that will be available publicly later in 2023.
Briefly, quantization is a group of techniques designed to make models faster, smaller, and generally less resource- and infrastructure-intensive to train and serve.
Quantization does this by reducing the precision of a model’s parameters, much like reducing the pixel bit depth of an image such as the photograph of Albert Einstein below. Note that even with reduced precision, we can still make out the key details:
|Renderings of a photograph of Albert Einstein with increasingly reduced bit precision.|
At a high level, this works by taking a range of values in your starting precision and mapping that range to a single bucket in your ending precision. Let’s illustrate this with an example:
|Quantizing float representation to 4-bit integers.|
Take a look at the range [0.27, 0.49] on the x-axis: in float32, the blue line actually represents 7,381,976 distinct numbers! The red line represents the int4 quantization of this range, condensing all of those numbers into a single bucket: 1001 (the number 9 in decimal).
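The count of float32 values in that range can be checked directly. Rounded to float32, 0.27 and 0.49 share the same exponent, and for non-negative floats consecutive bit patterns are consecutive values, so the count is the difference of the two bit patterns plus one. This minimal sketch uses Python’s struct module to obtain the float32 bit patterns; note that it counts from float32(0.27) to float32(0.49) inclusive, i.e. after rounding each endpoint to float32.

```python
import struct

def float32_bits(x):
    """IEEE 754 bit pattern of x after rounding to float32, as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def count_float32_between(lo, hi):
    """Count float32 values from float32(lo) to float32(hi), inclusive.
    For non-negative floats, consecutive bit patterns are consecutive values,
    so the count is the difference of the two patterns plus one."""
    assert 0.0 <= lo <= hi
    return float32_bits(hi) - float32_bits(lo) + 1

print(count_float32_between(0.27, 0.49))  # 7381976
```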
By lowering precision through quantization, we can store model weights in a much more efficient, compressed form.
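The mapping from a range to a bucket can be sketched as uniform 4-bit quantization. The range [-1.0, 1.0], the 16 equal-width buckets, and the midpoint dequantization are illustrative assumptions, not the parameters behind the figure above. The sketch shows nearby values collapsing to the same 4-bit code, and storing 4 bits instead of 32 per weight giving an 8x reduction.

```python
# Minimal uniform 4-bit quantization sketch.
# Assumptions: float range [-1.0, 1.0] split into 16 equal-width buckets
# (unsigned codes 0..15); dequantization returns each bucket's midpoint.

LO, HI, LEVELS = -1.0, 1.0, 16
SCALE = (HI - LO) / LEVELS  # width of one bucket: 0.125

def quantize(x):
    """Map a float to the 4-bit code of the bucket containing it (clamped)."""
    code = int((x - LO) // SCALE)
    return min(max(code, 0), LEVELS - 1)

def dequantize(code):
    """Recover an approximate float: the midpoint of the code's bucket."""
    return LO + (code + 0.5) * SCALE

values = [0.26, 0.30, 0.37]
codes = [quantize(v) for v in values]
print([format(c, "04b") for c in codes])  # ['1010', '1010', '1010']
print(dequantize(codes[0]))               # 0.3125

# Storage for one million weights: 32 bits each vs 4 bits each.
num_weights = 1_000_000
float32_bytes = num_weights * 32 // 8
int4_bytes = num_weights * 4 // 8
print(float32_bytes, int4_bytes)          # 4000000 500000
```

Every value in the bucket [0.25, 0.375) comes back as 0.3125; that lost detail is the precision traded for the smaller footprint.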
There are a few different ways to quantize:
- Post-Training Quantization (PTQ): Convert to a quantized model after training. This is as simple as it gets and most readily accessible, but there can be a small quality drop.
- Quantization-Aware Training (QAT): Simulate quantization during training, in the forward pass only, providing maximal flexibility with a minimal quality tradeoff.
- Quantized Training: Quantize all computations while training. This is still nascent, and needs a lot more testing, but is a powerful tool we want to make sure TensorFlow users have access to.
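As a concrete instance of the first approach, here is a minimal PTQ sketch using one common scheme, symmetric per-tensor int8 quantization, applied to a handful of made-up trained weights. It is illustrative only and makes no claim about how the TF Quantization API quantizes. A single scale maps the largest-magnitude weight to 127, every weight is rounded to the nearest step, and the rounding error is bounded by half a step.

```python
# Minimal post-training quantization (PTQ) sketch.
# Assumption: symmetric, per-tensor int8 quantization of already-trained
# weights; this is one common scheme, not the TF Quantization API.

def ptq_int8(weights):
    """Quantize trained float weights to int8 codes plus one float scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale

def dequantize_int8(codes, scale):
    """Map int8 codes back to approximate float weights."""
    return [c * scale for c in codes]

trained = [0.82, -0.41, 0.003, -1.27, 0.5]  # made-up "trained" weights
codes, scale = ptq_int8(trained)
restored = dequantize_int8(codes, scale)
max_err = max(abs(a - b) for a, b in zip(trained, restored))

print(codes)                # [82, -41, 0, -127, 50]
print(max_err <= scale / 2)  # True: rounding error is at most half a step
```

Because the weights are never retrained, any accuracy lost to this rounding is simply accepted, which is the "small quality drop" PTQ can incur.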
TensorFlow has previously offered a few tools for developers to quantize their models, like this guide for PTQ and this one for QAT. However, these have been limited: PTQ depends on conversion to TFLite for mobile deployment, and QAT requires you to rewrite your model.
The TF Quantization API is different: it’s designed to work regardless of where you’re deploying, and without you having to rewrite a single line of existing modeling code. We’re building it with flexibility and fidelity in mind, so you get the benefits of a smaller quantized model with new levels of fine-grained control and without any concerns about how it’ll all fit into your stack.
Since you’ve made it this far into the blog, here’s a sneak peek at how it’ll look. We’ll start with a typical setup for a TensorFlow model, just a few layers in Keras. From there, we can load in a predefined quantization schema to apply as a config map to our model.
But if you need more flexibility, the TF Quantization API will also let you fully customize how you quantize. There’s built-in support for curating your schema to apply different behaviors for every layer, operation, or tensor!
With that, we can directly apply quantization and train or save within a quantization context. Our model still has natural compatibility with the rest of the TF ecosystem, where quantization truly bears fruit.
We ran a bunch of tests using the MobileNetV2 model on the Pixel 7, and saw up to 16.7x gains in serving throughput versus the non-quantized baseline. This gain comes without any noticeable detriment to quality: both the float32 baseline and the int8 quantized model reported 73% accuracy.
The TF Quantization API isn’t public just yet, but will be available very soon and will continue to evolve to provide even more benefits.
That’s a wrap!
Today, we’ve shown you just a few of the key things we’ve been working on, and there’s a lot more to come.
We can’t wait to see what you’ll build, and we’re always inspired by our community’s enduring enthusiasm and continued partnership. Thanks for stopping by!
Special thanks to George Necula, Francois Chollet, Jonathan Bischof, Scott Zhu, Martin Gorner, Dong Li, Adam Koch, Bruce Fontaine, Laurence Moroney, Josh Gordon, Lauren Usui, and numerous others for their contributions to this post.