Graph neural networks in TensorFlow

Posted by Dustin Zelle – Software Engineer, Research and Arno Eigenwillig – Software Engineer, CoreML

This article is also shared on the Google Research Blog

Objects and their relationships are ubiquitous in the world around us, and relationships can be as important to understanding an object as its own attributes viewed in isolation, for example in transportation networks, production networks, knowledge graphs, or social networks. Discrete mathematics and computer science have a long history of formalizing such networks as graphs, consisting of nodes arbitrarily connected by edges in various irregular ways. Yet most machine learning (ML) algorithms allow only for regular and uniform relations between input objects, such as a grid of pixels, a sequence of words, or no relation at all.

Graph neural networks, or GNNs for short, have emerged as a powerful technique to leverage both the graph’s connectivity (as in the older algorithms DeepWalk and Node2Vec) and the input features on the various nodes and edges. GNNs can make predictions for graphs as a whole (Does this molecule react in a certain way?), for individual nodes (What’s the topic of this document, given its citations?) or for potential edges (Is this product likely to be purchased together with that product?). Apart from making predictions about graphs, GNNs are a powerful tool used to bridge the chasm to more typical neural network use cases. They encode a graph’s discrete, relational information in a continuous way so that it can be included naturally in another deep learning system.

We are excited to announce the release of TensorFlow GNN 1.0 (TF-GNN), a production-tested library for building GNNs at large scale. It supports both modeling and training in TensorFlow as well as the extraction of input graphs from huge data stores. TF-GNN is built from the ground up for heterogeneous graphs where types and relations are represented by distinct sets of nodes and edges. Real-world objects and their relations occur in distinct types and TF-GNN’s heterogeneous focus makes it natural to represent them.

Inside TensorFlow, such graphs are represented by objects of type tfgnn.GraphTensor. This is a composite tensor type (a collection of tensors in one Python class) accepted as a first-class citizen in tf.data, tf.function, etc. It stores both the graph structure and its features attached to nodes, edges and the graph as a whole. Trainable transformations of GraphTensors can be defined as Layer objects in the high-level Keras API, or directly using the tfgnn.GraphTensor primitive.

GNNs: Making predictions for an object in context

For illustration, let’s look at one typical application of TF-GNN: predicting a property of a certain type of node in a graph defined by cross-referencing tables of a huge database. Consider, for example, a citation database of Computer Science (CS) arXiv papers with one-to-many "cites" and many-to-one "cited" relationships, where we would like to predict the subject area of each paper.

Like most neural networks, a GNN is trained on a dataset of many labeled examples (~millions), but each training step consists only of a much smaller batch of training examples (say, hundreds). To scale to millions, the GNN gets trained on a stream of reasonably small subgraphs from the underlying graph. Each subgraph contains enough of the original data to compute the GNN result for the labeled node at its center and train the model. This process — typically referred to as subgraph sampling — is extremely consequential for GNN training. Most existing tooling accomplishes sampling in a batch way, producing static subgraphs for training. TF-GNN provides tooling to improve on this by sampling dynamically and interactively.
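The following minimal sketch shows fanout-based subgraph sampling around one labeled root node, in plain Python. It is not TF-GNN's sampling API. The adjacency-dict graph format, the `sample_subgraph` helper, and the per-hop fanouts are assumptions chosen for illustration.

```python
import random
from typing import Dict, List, Set, Tuple

# Assumed format: node id -> ids of the nodes it links to (e.g., papers it cites).
Graph = Dict[str, List[str]]


def sample_subgraph(graph: Graph, root: str, fanouts: List[int],
                    rng: random.Random) -> Tuple[Set[str], Set[Tuple[str, str]]]:
    """Samples up to fanouts[h] neighbors of each frontier node at hop h."""
    nodes = {root}
    edges = set()
    frontier = [root]
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = graph.get(node, [])
            # Sample without replacement; take all neighbors if there are fewer.
            for nbr in rng.sample(neighbors, min(fanout, len(neighbors))):
                edges.add((nbr, node))  # oriented toward the root for message passing
                if nbr not in nodes:
                    nodes.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return nodes, edges


citations = {
    "p0": ["p1", "p2", "p3"],
    "p1": ["p4", "p5"],
    "p2": ["p5", "p6"],
    "p3": ["p7"],
}
nodes, edges = sample_subgraph(citations, "p0", fanouts=[2, 2], rng=random.Random(0))
print(sorted(nodes), sorted(edges))
```

Capping the number of neighbors per hop bounds each subgraph's size (here at most 1 + 2 + 4 nodes). That bound keeps training examples small even when some nodes have a very large degree.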

Pictured, the process of subgraph sampling where small, tractable subgraphs are sampled from a larger graph to create input examples for GNN training.

TF-GNN 1.0 debuts a flexible Python API to configure dynamic or batch subgraph sampling at all relevant scales: interactively in a Colab notebook (like this one), for efficient sampling of a small dataset stored in the main memory of a single training host, or distributed by Apache Beam for huge datasets stored on a network filesystem (up to hundreds of millions of nodes and billions of edges). For details, please refer to our user guides for in-memory and beam-based sampling, respectively.

On those same sampled subgraphs, the GNN’s task is to compute a hidden (or latent) state at the root node; the hidden state aggregates and encodes the relevant information of the root node’s neighborhood. One classical approach is message-passing neural networks. In each round of message passing, nodes receive messages from their neighbors along incoming edges and update their own hidden state from them. After n rounds, the hidden state of the root node reflects the aggregate information from all nodes within n edges (pictured below for n = 2). The messages and the new hidden states are computed by hidden layers of the neural network. In a heterogeneous graph, it often makes sense to use separately trained hidden layers for the different types of nodes and edges.
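Here is a toy version of these message-passing rounds. It assumes scalar node states and a fixed, non-trainable update rule: "old state plus the mean of incoming messages". In a real GNN, the message and update functions are trained neural network layers. The `message_passing` helper is hypothetical.

```python
from typing import Dict, List


def message_passing(states: Dict[str, float],
                    incoming: Dict[str, List[str]],
                    rounds: int) -> Dict[str, float]:
    """Runs `rounds` of synchronous mean-aggregation message passing.

    incoming[v] lists the nodes that send messages to v along incoming edges.
    """
    for _ in range(rounds):
        new_states = {}
        for node, state in states.items():
            senders = incoming.get(node, [])
            # Each message is the sender's current state; pool messages by their mean.
            pooled = sum(states[s] for s in senders) / len(senders) if senders else 0.0
            # Update: residual combination of the old state and the pooled messages.
            new_states[node] = state + pooled
        states = new_states  # all nodes update at once, from the previous round's states
    return states


incoming = {"a": ["b"], "b": ["c"]}  # c -> b -> a, with "a" as the root node
initial = {"a": 0.0, "b": 0.0, "c": 1.0}
print(message_passing(initial, incoming, rounds=1)["a"])  # 0.0: c is two hops away
print(message_passing(initial, incoming, rounds=2)["a"])  # 1.0: c's signal has arrived
```

The sender `c` is two edges from the root, so its information reaches `a` only in the second round. This matches the n-hop reach of n rounds.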

Pictured, a simple message-passing neural network where, at each step, the node state is propagated from outer to inner nodes where it is pooled to compute new node states. Once the root node is reached, a final prediction can be made.

The training setup is completed by placing an output layer on top of the GNN’s hidden state for the labeled nodes, computing the loss (to measure the prediction error), and updating model weights by backpropagation, as usual in any neural network training.
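Here is a rough sketch of this final step, assuming a binary label and a hypothetical hidden state for the root node:

- A logistic output layer maps the state to a probability.
- Binary cross-entropy measures the error.
- One gradient-descent step updates the output weights.

For sigmoid plus cross-entropy, the gradient with respect to the logit simplifies to `p - y`. Backpropagation further into the GNN layers is omitted here.

```python
import math
from typing import List, Tuple


def predict(w: List[float], b: float, h: List[float]) -> float:
    """Output layer: sigmoid(w . h + b) applied to the root node's hidden state h."""
    z = sum(wi * hi for wi, hi in zip(w, h)) + b
    return 1.0 / (1.0 + math.exp(-z))


def bce_loss(p: float, y: float) -> float:
    """Binary cross-entropy between predicted probability p and label y in {0, 1}."""
    eps = 1e-12
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))


def sgd_step(w: List[float], b: float, h: List[float], y: float,
             lr: float = 0.1) -> Tuple[List[float], float]:
    """One gradient-descent step on the output layer's weights."""
    dz = predict(w, b, h) - y  # dLoss/dlogit for sigmoid + cross-entropy
    new_w = [wi - lr * dz * hi for wi, hi in zip(w, h)]
    return new_w, b - lr * dz


h_root = [0.5, -1.0, 2.0]  # hypothetical hidden state computed by the GNN
w, b = [0.0, 0.0, 0.0], 0.0
before = bce_loss(predict(w, b, h_root), 1.0)
w, b = sgd_step(w, b, h_root, y=1.0)
after = bce_loss(predict(w, b, h_root), 1.0)
print(before, after)  # the loss decreases after the update
```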

Beyond supervised training (i.e., minimizing a loss defined by labels), GNNs can also be trained in an unsupervised way (i.e., without labels). This lets us compute a continuous representation (or embedding) of the discrete graph structure of nodes and their features. These representations are then typically utilized in other ML systems. In this way, the discrete, relational information encoded by a graph can be included in more typical neural network use cases. TF-GNN supports a fine-grained specification of unsupervised objectives for heterogeneous graphs.

Building GNN architectures

The TF-GNN library supports building and training GNNs at various levels of abstraction.

At the highest level, users can take any of the predefined models bundled with the library that are expressed in Keras layers. Besides a small collection of models from the research literature, TF-GNN comes with a highly configurable model template that provides a curated selection of modeling choices that we have found to provide strong baselines on many of our in-house problems. The template implements the GNN layers; users only need to initialize the Keras layers.

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
  """Builds a GNN as a Keras model."""
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)

  # Encode input features (callback omitted for brevity).
  # set_initial_node_states is that user-defined callback.
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=set_initial_node_states)(graph)

  # For each round of message passing...
  for _ in range(2):
    # ... create and apply a Keras layer.
    graph = mt_albis.MtAlbisGraphUpdate(
        units=128, message_dim=64,
        receiver_tag=tfgnn.TARGET,  # assumed: messages are received at edge targets
        attention_type="none", simple_conv_reduce_type="mean",
        normalization_type="layer", next_state_type="residual",
        state_dropout_rate=0.2, l2_regularization=1e-5,
    )(graph)

  return tf.keras.Model(inputs, graph)

At the lowest level, users can write a GNN model from scratch in terms of primitives for passing data around the graph, such as broadcasting data from a node to all its outgoing edges or pooling data into a node from all its incoming edges (e.g., computing the sum of incoming messages). TF-GNN’s graph data model treats nodes, edges and whole input graphs equally when it comes to features or hidden states, making it straightforward to express not only node-centric models like the MPNN discussed above but also more general forms of GraphNets. This can, but need not, be done with Keras as a modeling framework on top of core TensorFlow. For more details, and intermediate levels of modeling, see the TF-GNN user guide and model collection.
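The two primitives can be sketched in plain Python over an edge set stored as parallel source and target index lists. The helper names echo TF-GNN's `tfgnn.broadcast_node_to_edges` and `tfgnn.pool_edges_to_node`. These list-based versions are illustrative assumptions, not the library's implementation or signatures.

```python
from typing import List


def broadcast_node_to_edges(node_values: List[float], endpoint: List[int]) -> List[float]:
    """Copies each node's value onto every edge whose given endpoint is that node."""
    return [node_values[i] for i in endpoint]


def pool_edges_to_node(edge_values: List[float], endpoint: List[int],
                       num_nodes: int) -> List[float]:
    """Sums edge values into the node at each edge's given endpoint."""
    pooled = [0.0] * num_nodes
    for value, node in zip(edge_values, endpoint):
        pooled[node] += value
    return pooled


# An edge set with three edges, 0->2, 1->2 and 2->0, as parallel index lists.
source = [0, 1, 2]
target = [2, 2, 0]
node_state = [1.0, 2.0, 4.0]

messages = broadcast_node_to_edges(node_state, source)  # [1.0, 2.0, 4.0]
incoming_sum = pool_edges_to_node(messages, target, num_nodes=3)
print(incoming_sum)  # [4.0, 0.0, 3.0]
```

Storing edges as index lists means edges can carry their own features or states alongside nodes. This mirrors the uniform treatment of nodes and edges described above.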

Training orchestration

While advanced users are free to do custom model training, the TF-GNN Runner also provides a succinct way to orchestrate the training of Keras models in the common cases. A simple invocation may look like this:

from tensorflow_gnn import runner
runner.run(
   task=runner.RootNodeBinaryClassification("papers", ...),
   trainer=runner.KerasTrainer(tf.distribute.MirroredStrategy(), model_dir="/tmp/model"),
   model_fn=model_fn,  # remaining arguments (datasets, optimizer, etc.) omitted
)

The Runner provides ready-to-use solutions for ML pains like distributed training and tfgnn.GraphTensor padding for fixed shapes on Cloud TPUs. Beyond training on a single task (as shown above), it supports joint training on multiple (two or more) tasks in concert. For example, unsupervised tasks can be mixed with supervised ones to inform a final continuous representation (or embedding) with application-specific inductive biases. Callers only need to substitute the task argument with a mapping of tasks:

from tensorflow_gnn import runner
from tensorflow_gnn.models import contrastive_losses
runner.run(
    task={"classification": runner.RootNodeBinaryClassification("papers", ...),
          "dgi": contrastive_losses.DeepGraphInfomaxTask("papers")},
    # ... remaining arguments as in the previous example
)

The TF-GNN Runner also includes an implementation of integrated gradients for use in model attribution. The integrated gradients output is a GraphTensor with the same connectivity as the observed GraphTensor, but with its features replaced by gradient values; features with larger values contribute more to the GNN prediction than those with smaller values. Users can inspect these values to see which features their GNN relies on most.
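Integrated gradients attributes a prediction F(x) to input features in two steps. It averages gradients along the straight path from a baseline x' to x, then scales each average by (x - x'). The sketch below approximates that path integral with a midpoint Riemann sum on a flat feature vector. The toy model, the zero baseline, and the finite-difference gradients (standing in for backpropagation) are all assumptions. The Runner's GraphTensor-based implementation is not shown.

```python
from typing import Callable, List

Vector = List[float]


def numerical_grad(f: Callable[[Vector], float], x: Vector, h: float = 1e-5) -> Vector:
    """Central-difference gradient; stands in for backpropagation in this sketch."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def integrated_gradients(f: Callable[[Vector], float], x: Vector,
                         baseline: Vector, steps: int = 50) -> Vector:
    """Midpoint Riemann-sum approximation of integrated gradients."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(numerical_grad(f, point)):
            total[i] += g
    # Scale the average path gradient by each feature's offset from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


def toy_model(x: Vector) -> float:
    """Hypothetical stand-in for a GNN's scalar prediction."""
    return x[0] ** 2 + 3.0 * x[1]


attributions = integrated_gradients(toy_model, [2.0, 1.0], [0.0, 0.0])
print(attributions)  # close to [4.0, 3.0]
```

A useful sanity check is completeness: the attributions sum (approximately) to F(x) - F(x'), which is 7.0 here.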


In short, we hope TF-GNN will be useful to advance the application of GNNs in TensorFlow at scale and fuel further innovation in the field. If you’re curious to find out more, please try our Colab demo with the popular OGBN-MAG benchmark (in your browser, no installation required), browse the rest of our user guides and Colabs, or take a look at our paper.


The TF-GNN release 1.0 was developed by a collaboration between Google Research (Sami Abu-El-Haija, Neslihan Bulut, Bahar Fatemi, Johannes Gasteiger, Pedro Gonnet, Jonathan Halcrow, Liangze Jiang, Silvio Lattanzi, Brandon Mayer, Vahab Mirrokni, Bryan Perozzi, Anton Tsitsulin, Dustin Zelle), Google Core ML (Arno Eigenwillig, Oleksandr Ferludin, Parth Kothari, Mihir Paradkar, Jan Pfeifer, Rachael Tamakloe), and Google DeepMind (Alvaro Sanchez-Gonzalez and Lisa Wang).

TensorFlow 2.15 update: hot-fix for Linux installation issue

Posted by the TensorFlow team

We are releasing a hot-fix for an installation issue affecting the TensorFlow installation process. The TensorFlow 2.15.0 Python package was released such that it requested tensorrt-related packages that cannot be found unless the user installs them beforehand or provides additional installation flags. This dependency affected anyone installing TensorFlow 2.15 alongside NVIDIA CUDA dependencies via pip install tensorflow[and-cuda]. Depending on the installation method, TensorFlow 2.14 would be installed instead of 2.15, or users could receive an installation error due to those missing dependencies.

To solve this issue as quickly as possible, we have released TensorFlow 2.15.0.post1 for the Linux x86_64 platform. This version removes the tensorrt Python package dependencies from the tensorflow[and-cuda] installation method. Support for TensorRT is otherwise unaffected as long as TensorRT is already installed on the system. Now, pip install tensorflow[and-cuda] works as originally intended for TensorFlow 2.15.

Using .post1 instead of a full minor release allowed us to push this release out quickly. However, please be aware of the following caveat: if you pin your Python dependencies in a requirements file or elsewhere, then under Python’s version specification rules, tensorflow[and-cuda]==2.15.0 will not install this fixed version. Please use ==2.15.0.post1 to specify this exact version on Linux platforms, or a fuzzy version specification, such as ==2.15.*, to specify the most recent compatible version of TensorFlow 2.15 on all platforms.
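A deliberately simplified matcher for PEP 440's `==` rules shows why `==2.15.0` skips the fixed release. (pip applies these rules through the `packaging` library.) This sketch handles only numeric release segments, an optional `.postN` suffix, and a trailing `.*`. It is not a full PEP 440 implementation.

```python
import re
from typing import Optional, Tuple

# Supports only release segments plus an optional ".postN" suffix, e.g. "2.15.0.post1".
_VERSION = re.compile(r"^(\d+(?:\.\d+)*)(?:\.post(\d+))?$")


def parse(version: str) -> Tuple[Tuple[int, ...], Optional[int]]:
    m = _VERSION.match(version)
    if m is None:
        raise ValueError(f"unsupported version: {version!r}")
    release = tuple(int(part) for part in m.group(1).split("."))
    post = int(m.group(2)) if m.group(2) is not None else None
    return release, post


def _strip_zeros(release: Tuple[int, ...]) -> Tuple[int, ...]:
    # PEP 440 treats 2.15 and 2.15.0 as the same release.
    parts = list(release)
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def matches(version: str, spec: str) -> bool:
    """Simplified '==' matching: exact match, or prefix match with a trailing '.*'."""
    if not spec.startswith("=="):
        raise ValueError("this sketch only handles '==' specifiers")
    target = spec[2:]
    release, post = parse(version)
    if target.endswith(".*"):
        prefix = tuple(int(part) for part in target[:-2].split("."))
        padded = release + (0,) * max(0, len(prefix) - len(release))
        return padded[: len(prefix)] == prefix
    target_release, target_post = parse(target)
    return _strip_zeros(release) == _strip_zeros(target_release) and post == target_post


for spec in ("==2.15.0", "==2.15.0.post1", "==2.15.*"):
    print(spec, matches("2.15.0.post1", spec))  # False, True, True
```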

Half-precision Inference Doubles On-Device Inference Performance

Half-precision Inference Doubles On-Device Inference Performance

Posted by Marat Dukhan and Frank Barchard, Software Engineers

CPUs deliver the widest reach for ML inference and remain the default target for TensorFlow Lite. Consequently, improving CPU inference performance is a top priority, and we are excited to announce that we doubled floating-point inference performance in TensorFlow Lite’s XNNPack backend by enabling half-precision inference on ARM CPUs. This means that more AI-powered features may be deployed to older and lower-tier devices.

Traditionally, TensorFlow Lite supported two kinds of numerical computations in machine learning models: a) floating-point using IEEE 754 single-precision (32-bit) format and b) quantized using low-precision integers. While single-precision floating-point numbers provide maximum flexibility and ease of use, they come at the cost of 4X overhead in storage and memory and exhibit a performance overhead compared to 8-bit integer computations. In contrast, half-precision (FP16) floating-point numbers pose an interesting alternative balancing ease of use and performance: the processor needs to transfer half as many bytes, and each vector operation produces twice as many elements. By virtue of this property, FP16 inference paves the way for a 2X speedup for floating-point models compared to the traditional FP32 way.
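Python's `struct` module can show both the storage halving and the precision trade-off. Its `e` format code packs IEEE 754 half-precision values. This only illustrates the number format, not XNNPack's kernels, and the weight values are arbitrary examples.

```python
import struct

weights = [0.1, -1.5, 3.14159, 1000.0]  # arbitrary example weights

fp32_bytes = struct.pack(f"<{len(weights)}f", *weights)  # 4 bytes per value
fp16_bytes = struct.pack(f"<{len(weights)}e", *weights)  # 2 bytes per value
print(len(fp32_bytes), len(fp16_bytes))  # 16 8

# FP16 keeps an 11-bit significand, so 0.1 is stored less precisely.
restored = struct.unpack(f"<{len(weights)}e", fp16_bytes)
print(restored[0])  # 0.0999755859375


def fits_fp16(x: float) -> bool:
    """True if x can be packed as FP16 without overflowing (largest finite value 65504)."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


print(fits_fp16(65504.0), fits_fp16(70000.0))  # True False
```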

For a long time, FP16 inference on CPUs remained primarily a research topic, as the lack of hardware support for FP16 computations limited production use cases. However, around 2017 new mobile chipsets started to include support for native FP16 computations, and by now most mobile phones, both high-end and low-end, support them. Building upon this broad availability, we are pleased to announce the general availability of half-precision inference in TensorFlow Lite and XNNPack.

Performance Improvements

Half-precision inference has already been battle-tested in production across Google Assistant, Google Meet, YouTube, and ML Kit, and demonstrated close to 2X speedups across a wide range of neural network architectures and mobile devices. Below, we present benchmarks on nine public models covering common computer vision tasks:

  1. MobileNet v2 image classification
  2. MobileNet v3-Small image classification
  3. DeepLab v3 segmentation
  4. BlazeFace face detection
  5. SSDLite 2D object detection
  6. Objectron 3D object detection
  7. Face Mesh landmarks
  8. MediaPipe Hands landmarks
  9. KNIFT local feature descriptor

These models were benchmarked on 5 popular mobile devices, including recent and older devices (Pixel 3a, Pixel 5a, Pixel 7, Galaxy M12 and Galaxy S22). The average speedup is shown below.

Single-threaded inference speedup with half-precision (FP16) inference compared to single-precision (FP32) across 5 mobile devices. Higher numbers are better.

The same models were also benchmarked on three laptop computers (MacBook Air M1, Surface Pro X and Surface Pro 9).

Single-threaded inference speedup with half-precision (FP16) inference compared to single-precision (FP32) across 3 laptop computers. Higher numbers are better.

Currently, the FP16-capable hardware supported in XNNPack is limited to ARM and ARM64 devices with the ARMv8.2 FP16 arithmetic extension, which includes Android phones starting with Pixel 3, Galaxy S9 (Snapdragon SoC) and Galaxy S10 (Exynos SoC), iOS devices with A11 or newer SoCs, all Apple Silicon Macs, and Windows ARM64 laptops based on the Snapdragon 850 SoC or newer.

How Can I Use It?

To benefit from the half-precision inference in XNNPack, the user must provide a floating-point (FP32) model with FP16 weights and special “reduced_precision_support” metadata to indicate model compatibility with FP16 inference. The metadata can be added during model conversion using the _experimental_supported_accumulation_type attribute of the tf.lite.TargetSpec object:

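converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)  # saved_model_dir: your model path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
converter.target_spec._experimental_supported_accumulation_type = tf.dtypes.float16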

When the compatible model is delegated to XNNPack on hardware with native support for FP16 computations, XNNPack will transparently replace FP32 operators with their FP16 equivalents, and insert additional operators to convert model inputs from FP32 to FP16 and convert model outputs back from FP16 to FP32. If the hardware is not capable of FP16 arithmetic, XNNPack will perform model inference with FP32 calculations. Therefore, a single model can be transparently deployed on both recent and legacy devices.

Additionally, the XNNPack delegate provides an option to force FP16 inference regardless of the model metadata. This option is intended for development workflows, in particular for testing the end-to-end accuracy of the model when FP16 inference is used. In addition to devices with native FP16 arithmetic support, forced FP16 inference is supported on x86/x86-64 devices with the AVX2 extension in emulation mode: all elementary floating-point operations are computed in FP32, then converted to FP16 and back to FP32. Note that such emulation is slow and not bit-exact with native FP16 inference, but it simulates the effects of the restricted mantissa precision and exponent range of native FP16 arithmetic. To force FP16 inference, either build TensorFlow Lite with the --define xnnpack_force_float_precision=fp16 Bazel option, or apply the XNNPack delegate explicitly and add the TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16 flag to the TfLiteXNNPackDelegateOptions.flags bitmask passed into the TfLiteXNNPackDelegateCreate call:
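This emulation scheme (compute in FP32, then convert to FP16 and back) can be mimicked in a few lines of Python. Python floats are FP64, so the sketch rounds to FP32 and FP16 with `struct`. It illustrates the numerical effects only; it is not XNNPack's implementation. Mapping overflow to infinity is an assumption about how out-of-range results should behave.

```python
import math
import operator
import struct


def to_fp32(x: float) -> float:
    """Rounds a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_fp16(x: float) -> float:
    """Rounds to the nearest FP16 value; out-of-range results become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def emulated_fp16(op, a: float, b: float) -> float:
    """Computes op in FP32, then converts the result to FP16 and back."""
    return to_fp16(to_fp32(op(a, b)))


print(emulated_fp16(operator.add, 1.0, 2.0 ** -12))  # 1.0: lost to FP16 mantissa precision
print(to_fp32(1.0 + 2.0 ** -12))                     # 1.000244140625 survives in FP32
print(emulated_fp16(operator.mul, 300.0, 300.0))     # inf: 90000 exceeds the FP16 range
```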

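#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
TfLiteXNNPackDelegateOptions xnnpack_options =
    TfLiteXNNPackDelegateOptionsDefault();
xnnpack_options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16;
TfLiteDelegate* xnnpack_delegate =
    TfLiteXNNPackDelegateCreate(&xnnpack_options);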

XNNPack provides full feature parity between FP32 and FP16 operators: all operators that are supported for FP32 inference are also supported for FP16 inference, and vice versa. In particular, sparse inference operators are supported for FP16 inference on ARM processors. Therefore, users can combine the performance benefits of sparse and FP16 inference in the same model.

Future Work

In addition to most ARM and ARM64 processors, the most recent Intel processors, code-named Sapphire Rapids, support native FP16 arithmetic via the AVX512-FP16 instruction set, and the recently announced AVX10 instruction set promises to make this capability widely available on the x86 platform. We plan to optimize XNNPack for these instruction sets in a future release.


We would like to thank Alan Kelly, Zhi An Ng, Artsiom Ablavatski, Sachin Joglekar, T.J. Alumbaugh, Andrei Kulik, Jared Duke, Matthias Grundmann for contributions towards half-precision inference in TensorFlow Lite and XNNPack.

What's new in TensorFlow 2.15

Posted by the TensorFlow team

TensorFlow 2.15 has been released! Highlights of this release (and 2.14) include a much simpler installation method for NVIDIA CUDA libraries for Linux, oneDNN CPU performance optimizations for Windows x64 and x86, full availability of tf.function types, an upgrade to Clang 17.0.1, and much more! For the full list of changes, see the release notes.

Note: Release updates on the new multi-backend Keras will be published on keras.io, starting with Keras 3.0.

TensorFlow Core

NVIDIA CUDA libraries for Linux

The tensorflow pip package has a new, optional installation method for Linux that installs necessary NVIDIA CUDA libraries through pip. As long as the NVIDIA driver is already installed on the system, you may now run pip install tensorflow[and-cuda] to install TensorFlow’s NVIDIA CUDA library dependencies in the Python environment. Aside from the NVIDIA driver, no other pre-existing NVIDIA CUDA packages are necessary. In TensorFlow 2.15, CUDA has been upgraded to version 12.2.

oneDNN CPU performance optimizations

For Windows x64 & x86 packages, oneDNN optimizations are now enabled by default on X86 CPUs. These optimizations can be enabled or disabled by setting the environment variable TF_ENABLE_ONEDNN_OPTS to 1 (enable) or 0 (disable) before running TensorFlow. To fall back to default settings, simply unset the environment variable.
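From Python, you can set the variable before TensorFlow is loaded. The `set_onednn_opts` helper below is a hypothetical convenience wrapper around `os.environ`. Following the text ("before running TensorFlow"), we assume the setting must be in place before TensorFlow is imported, so the import is shown afterward.

```python
import os
from typing import Optional


def set_onednn_opts(enabled: Optional[bool]) -> None:
    """Sets TF_ENABLE_ONEDNN_OPTS to "1" or "0"; None unsets it to restore the default."""
    if enabled is None:
        os.environ.pop("TF_ENABLE_ONEDNN_OPTS", None)
    else:
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "1" if enabled else "0"


set_onednn_opts(False)  # disable oneDNN optimizations for this process
# import tensorflow as tf  # import TensorFlow only after the variable is set
```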


tf.function

tf.function types are now fully available.

  • tf.types.experimental.TraceType now allows custom tf.function inputs to declare Tensor decomposition and type casting support. 
  • Introducing tf.types.experimental.FunctionType as the comprehensive representation of the signature of tf.function callables. It can be accessed through the function_type property of tf.functions and ConcreteFunctions. See the tf.types.experimental.FunctionType documentation for more details.
  • Introducing tf.types.experimental.AtomicFunction as the fastest way to perform TF computations in Python. This capability can be accessed through the inference_fn property of ConcreteFunctions. (Does not support gradients.) See the tf.types.experimental.AtomicFunction documentation for how to call and use it.

Upgrade to Clang 17.0.1 and CUDA 12.2

TensorFlow pip packages are now being built with Clang 17 and CUDA 12.2 to improve performance for NVIDIA Hopper-based GPUs. Moving forward, Clang 17 will be the default C++ compiler for TensorFlow. We recommend upgrading your compiler to Clang 17 when building TensorFlow from source.

Join us at the third Women in ML Symposium!

Posted by Sharbani Roy – Senior Director, Product Management, Google

We’re back with the third annual Women in Machine Learning Symposium on December 7, 2023! Join us virtually from 9:30 am to 1:00 pm PT for an immersive and insightful set of deep dives for every level of Machine Learning experience.

The Women in ML Symposium is an inclusive event for anyone passionate about the transformative fields of Machine Learning (ML) and Artificial Intelligence (AI). Dive into the latest advancements in generative AI, explore the intricacies of privacy-preserving AI, dig into the underlying accelerators and ML frameworks that power models, and uncover practical applications of ML across multiple industries.

Our event offers sessions for all expertise levels, from beginners to advanced practitioners. Hear about what’s new in ML and building with Google AI from our keynote speakers, gain insights from seasoned industry leaders across Google Health, Nvidia, Adobe, and more – and discover a wealth of knowledge on topics ranging from foundational AI concepts to open source tools, techniques, and beyond.

RSVP today to secure your spot and explore our exciting agenda. We can’t wait to see you there!

Simulated Spotify Listening Experiences for Reinforcement Learning with TensorFlow and TF-Agents

Posted by Surya Kanoria, Joseph Cauteruccio, Federico Tomasi, Kamil Ciosek, Matteo Rinaldi, and Zhenwen Dai – Spotify


Many of our music recommendation problems involve providing users with ordered sets of items that satisfy users’ listening preferences and intent at that point in time. We base current recommendations on previous interactions with our application and, in the abstract, are faced with a sequential decision making process as we continually recommend content to users.

Reinforcement Learning (RL) is an established tool for sequential decision making that can be leveraged to solve sequential recommendation problems. We decided to explore how RL could be used to craft listening experiences for users. Before we could start training Agents, we needed to pick an RL library that allowed us to easily prototype, test, and potentially deploy our solutions.

At Spotify we leverage TensorFlow and the extended TensorFlow Ecosystem (TFX, TensorFlow Serving, and so on) as part of our production Machine Learning Stack. We made the decision early on to leverage TensorFlow Agents as our RL Library of choice, knowing that integrating our experiments with our production systems would be vastly more efficient down the line.

One missing bit of technology we required was an offline Spotify environment we could use to prototype, analyze, explore, and train Agents offline prior to online testing. The flexibility of the TF-Agents library, coupled with the broader advantages of TensorFlow and its ecosystem, allowed us to cleanly design a robust and extendable offline Spotify simulator.

We based our simulator design on TF-Agents Environment primitives. Using this simulator, we developed, trained and evaluated sequential models for item recommendations, vanilla RL Agents (PPG, DQN), and a modified deep Q-Network, which we call the Action-Head DQN (AH-DQN), that addresses the specific challenges imposed by the large state and action spaces of our RL formulation.

Through live experiments we were able to show that our offline performance estimates were strongly correlated with online results. This then opened the door for large scale experimentation and application of Reinforcement Learning across Spotify, enabled by the technological foundations unlocked by TensorFlow and TF-Agents.

In this post we’ll provide more details about our RL problem and how we used TF-Agents to enable this work end to end.

The RL Loop and Simulated Users

Reinforcement Learning loop
In RL, Agents interact with the environment continuously. At a given time step t, the Agent consumes an observation from the environment and, using this observation, produces an action given its policy. The environment then processes the action and emits both a reward and the next observation. (Although the two terms are often used interchangeably, State is the complete information required to summarize the environment after an action, while Observation is the portion of this information actually exposed to the Agent.)
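The loop can be sketched with a hypothetical toy environment. Its `reset` and `step` return the same (observation, reward, done) shape as the environment abstraction defined later in this post. The user's preferred genre is part of the State but never appears in the Observation, which illustrates the distinction above. The environment, reward rule, and random policy are all assumptions for illustration.

```python
import random
from typing import Callable, List, Tuple

Observation = List[float]


class ToyListeningEnv:
    """Hypothetical environment whose hidden state is the user's preferred genre."""

    def __init__(self, num_genres: int, max_steps: int, seed: int = 0):
        self._rng = random.Random(seed)
        self._num_genres = num_genres
        self._max_steps = max_steps

    def reset(self) -> Observation:
        self._preferred = self._rng.randrange(self._num_genres)  # state, never observed
        self._t = 0
        self._last_reward = 0.0
        return self._observe()

    def step(self, action: int) -> Tuple[Observation, float, bool]:
        self._t += 1
        self._last_reward = 1.0 if action == self._preferred else 0.0
        done = self._t >= self._max_steps
        return self._observe(), self._last_reward, done

    def _observe(self) -> Observation:
        # The observation exposes only the step count and the last reward.
        return [float(self._t), self._last_reward]


def run_episode(env: ToyListeningEnv, policy: Callable[[Observation], int]) -> float:
    """The RL loop: observe, act, then receive a reward and the next observation."""
    observation = env.reset()
    done, total = False, 0.0
    while not done:
        action = policy(observation)
        observation, reward, done = env.step(action)
        total += reward
    return total


policy_rng = random.Random(1)
print(run_episode(ToyListeningEnv(num_genres=3, max_steps=10),
                  lambda obs: policy_rng.randrange(3)))
```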

In our case the reward emitted from the environment is the response of a user to music recommendations driven by the Agent’s action. In the absence of a simulator we would need to expose real users to Agents to observe rewards. We utilize a model-based RL approach to avoid letting an untrained Agent interact with real users (with the potential of hurting user satisfaction in the training process).

In this model-based RL formulation the Agent is not trained online against real users. Instead, it makes use of a user model that predicts responses to a list of tracks derived via the Agent’s action. Using this model we optimize actions in such a way as to maximize a (simulated) user satisfaction metric. During the training phase the environment makes use of this user model to return a predicted user response to the action recommended by the Agent.

We use Keras to design and train our user model. The serialized user model is then unpacked by the simulator and used to calculate rewards during Agent training and evaluation.
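A minimal stand-in for such a user model might score each recommended track and turn the scores into a reward. The sketch assumes a hypothetical logistic model over user and track embeddings. It defines the simulated reward as the expected number of tracks listened to. The post does not describe Spotify's actual Keras model or reward definition.

```python
import math
from typing import List, Sequence


def predict_listen_probs(user: Sequence[float],
                         tracks: List[Sequence[float]]) -> List[float]:
    """Hypothetical user model: P(listen to track) = sigmoid(user . track)."""
    probs = []
    for track in tracks:
        logit = sum(u * t for u, t in zip(user, track))
        probs.append(1.0 / (1.0 + math.exp(-logit)))
    return probs


def simulated_reward(user: Sequence[float], recommended: List[Sequence[float]]) -> float:
    """Reward = expected number of recommended tracks the simulated user listens to."""
    return sum(predict_listen_probs(user, recommended))


user_taste = [1.0, -1.0]  # hypothetical user embedding
print(simulated_reward(user_taste, [[2.0, 0.0], [0.0, 2.0]]))
```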

Simulator Design

In the abstract, what we needed to build was clear. We needed a way to simulate user listening sessions for the Agent. Given a simulated user and some content, instantiate a listening session and let the Agent drive recommendations in that session. Allow the simulated user to “react” to these recommendations and let the Agent adjust its strategy based on this result to drive some expected cumulative reward.

The TensorFlow Agents environment design guided us in developing the modular components of our system, each of which was responsible for different parts of the overall simulation.

In our codebase we define an environment abstraction that requires that the following be defined for every concrete instantiation:

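from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

class AbstractEnvironment(ABC):
    _user_model: AbstractUserModel = None
    _track_sampler: AbstractTrackSampler = None
    _episode_tracker: EpisodeTracker = None
    _episode_sampler: AbstractEpisodeSampler = None

    @abstractmethod
    def reset(self) -> List[float]:
        """Starts a new episode and returns the first observation."""

    @abstractmethod
    def step(self, action: float) -> Tuple[List[float], float, bool]:
        """Applies an action; returns (observation, reward, done)."""

    @abstractmethod
    def observation_space(self) -> Dict:
        """Describes the observations exposed to the Agent."""

    @abstractmethod
    def action_space(self) -> Dict:
        """Describes the actions the Agent may take."""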


At the start of Agent training we need to instantiate a simulation environment that has representations of hypothetical users and the content we’re looking to recommend to them. We base these instantiations on both real and hypothetical Spotify listening experiences. The critical information that defines these instantiations is passed to the environment via _episode_sampler. As mentioned, we also need to provide the simulator with a trained user model, in this case via _user_model.
Flow chart of agent training set up

Actions and Observations

Just like any Agent environment, our simulator requires that we specify the action_spec and observation_spec. Actions in our case may be continuous or discrete depending both on our Agent selection and how we propose to translate an Agent’s action into actual recommendations. We typically recommend ordered lists of items drawn from a pool of potential items. Formulating this action space directly would lead to it being combinatorially complex. We also assume the user will interact with multiple items, and as such previous work in this area that relies on single choice assumptions doesn’t apply.

In the absence of a discrete action space consisting of item collections, we need to provide the simulator with a method for turning the Agent’s action into actual recommendations. This logic is contained in _track_sampler. The “example play modes” proposed by the episode sampler contain information on items that can be presented to the simulated user. The track sampler consumes these and the Agent’s action and returns actual item recommendations.
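One way a track sampler could turn a continuous action into recommendations is to treat the action as a point in an item-embedding space. It would then rank the candidate tracks by dot product. This is a hypothetical scheme for illustration; the post does not specify how `_track_sampler` maps actions to items.

```python
from typing import Dict, List, Sequence


def sample_tracks(action: Sequence[float],
                  candidates: Dict[str, Sequence[float]],
                  k: int) -> List[str]:
    """Hypothetical track sampler: rank candidates by dot(action, track embedding)."""
    def score(track_id: str) -> float:
        return sum(a * e for a, e in zip(action, candidates[track_id]))
    return sorted(candidates, key=score, reverse=True)[:k]


# Candidate pool as it might come from the episode sampler (embeddings are made up).
pool = {"t1": [1.0, 0.0], "t2": [0.0, 1.0], "t3": [0.5, 0.5]}
print(sample_tracks([1.0, 0.2], pool, k=2))  # ['t1', 't3']
```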
Flow chart of Agent actions_spec and observation_spec combining to create a recommendation

Termination and Reset

We also need to handle the episode termination dynamics. In our simulator, the reset rules are set by the model builder and based on empirical investigations of interaction data relevant to a specific music listening experience. As a hypothetical, we may determine that 92% of listening sessions terminate after 6 sequential track skips, and we’d construct our simulation termination logic to match. This also requires abstractions in our simulator that allow us to check whether the episode should be terminated after each step.
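The hypothetical "6 sequential skips" rule above could be checked after each step with a small counter object. The `SkipTermination` class is an illustrative assumption, not part of TF-Agents or the authors' codebase.

```python
class SkipTermination:
    """Hypothetical rule: end the episode after `max_skips` consecutive skips."""

    def __init__(self, max_skips: int = 6):
        self.max_skips = max_skips
        self.consecutive_skips = 0

    def update(self, skipped: bool) -> bool:
        """Records one track outcome; returns True if the episode should end."""
        self.consecutive_skips = self.consecutive_skips + 1 if skipped else 0
        return self.consecutive_skips >= self.max_skips

    def reset(self) -> None:
        """Clears the counter when a new episode starts."""
        self.consecutive_skips = 0


rule = SkipTermination(max_skips=6)
for step, skipped in enumerate([True] * 6, start=1):
    if rule.update(skipped):
        print(f"episode ends after step {step}")  # step 6
```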

When the episode is reset, the simulator will sample a new pair of hypothetical user and listening session and begin the next episode.

Episode Steps

As with standard TF-Agents Environments, we need to define the step dynamics for our simulation. Our simulation has optional dynamics that we need to make sure are enforced at each step. For example, we may want to ensure that the same item cannot be recommended more than once. If the Agent’s action indicates a recommendation of an item that was previously recommended, we need to build in the functionality to pick the next best item based on this action.
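The no-repeat rule can be enforced by walking down the Agent's ranking until reaching an item not yet recommended in the episode. The `pick_next_best` helper and its ranked-list input are hypothetical.

```python
from typing import List, Set


def pick_next_best(ranked: List[str], already_recommended: Set[str]) -> str:
    """Returns the highest-ranked item that has not been recommended yet."""
    for item in ranked:
        if item not in already_recommended:
            return item
    raise ValueError("every candidate has already been recommended")


history: Set[str] = set()
picks = []
for ranked in (["a", "b"], ["a", "c"], ["c", "a", "d"]):  # per-step Agent rankings
    item = pick_next_best(ranked, history)
    history.add(item)
    picks.append(item)
print(picks)  # ['a', 'c', 'd']
```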

We also need to call the termination (and other supporting functions) mentioned above as needed at each step.

Episode Storage and Replay

Together, the functionality described so far creates a very complex simulation setup. The TF-Agents replay buffer provided the functionality required to store episodes for Agent training and evaluation. However, we quickly realized that we needed to store more episode data, both for debugging and for more detailed evaluations specific to our simulation, distinct from standard Agent performance measures.

We therefore added an expanded _episode_tracker that stores additional information about the user model predictions, the sampled user/content pairs, and more.
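Here is a minimal sketch of the kind of per-step record such a tracker might keep. It assumes one record per step holding three things: the sampled user/session pair, the recommended items, and the user model’s predicted skip probabilities. SketchEpisodeTracker and its fields are hypothetical. They do not reflect the structure of our _episode_tracker.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StepRecord:
    user_id: str
    session_id: str
    recommended: List[str]
    predicted_skip_probs: List[float]  # hypothetical user-model output per recommended item


@dataclass
class SketchEpisodeTracker:
    """Stores per-step details beyond what a replay buffer keeps for training."""
    steps: List[StepRecord] = field(default_factory=list)

    def record(self, step: StepRecord) -> None:
        self.steps.append(step)

    def mean_predicted_skip(self) -> float:
        """Example simulation-specific metric: average predicted skip probability."""
        probs = [p for s in self.steps for p in s.predicted_skip_probs]
        return sum(probs) / len(probs) if probs else 0.0


tracker = SketchEpisodeTracker()
tracker.record(StepRecord("u1", "s1", ["t1", "t2"], [0.2, 0.6]))
tracker.record(StepRecord("u1", "s1", ["t3"], [0.4]))
print(len(tracker.steps), tracker.mean_predicted_skip())
```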

Creating TF-Agent Environments

Our environment abstraction gives us a template that matches that of a standard TF-Agents Environment class. Some inputs to our environment need to be resolved before we can actually create the concrete TF-Agents environment instance. This happens in three steps.

First we define a specific simulation environment that conforms to our abstraction. For example:

class PlaylistEnvironment(AbstractEnvironment):
    def __init__(
        self,
        user_model: AbstractUserModel,
        track_sampler: AbstractTrackSampler,
        episode_tracker: EpisodeTracker,
        episode_sampler: AbstractEpisodeSampler,
    ):
        ...

Next, we use an Environment Builder class that takes as input a user model, track sampler, etc., and an environment class like PlaylistEnvironment. The builder creates a concrete instance of this environment:

self.playlist_env: PlaylistEnvironment = environment_ctor(
    ...  # user model, track sampler, etc.
)

Lastly, we utilize a conversion class that constructs a TF-Agents Environment from a concrete instance of ours:

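from tf_agents.environments import py_environment

class TFAgtPyEnvironment(py_environment.PyEnvironment):
    def __init__(self, environment: AbstractEnvironment):
        super().__init__()
        self.env = environment  # wrap our simulator so TF-Agents can drive it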

This conversion then runs inside our Environment Builder:

class EnvironmentBuilder(AbstractEnvironmentBuilder):

    def __init__(self, ...):
        ...

    def get_tf_env(self):
        # Wrap the concrete environment instance created by the builder.
        tf_env: TFAgtPyEnvironment = TFAgtPyEnvironment(self.playlist_env)
        return tf_env

The resulting TensorFlow Agents environment can then be used for Agent training.
Flow chart showing simulator design
This simulator design allows us to easily create and manage multiple environments with a variety of different configurations as needed.

We next discuss how we used our simulator to train RL Agents to generate Playlists.

A Customized Agent for Playlist Generation

As mentioned, Reinforcement Learning gives us a set of methods that naturally accommodates the sequential nature of music listening, allowing us to adapt to users’ ever-evolving preferences as sessions progress.

One specific problem we can attempt to solve with RL is automatic music playlist generation. Given a (large) set of tracks, we want to learn how to create one optimal playlist to recommend to the user in order to maximize satisfaction metrics. Our use case differs from standard slate recommendation tasks, where the user is usually assumed to select at most one item from the slate. In our case, we assume a user-generated response for multiple items in the slate, so slate recommendation systems are not directly applicable. Another complication is that the set of tracks from which recommendations are drawn is constantly changing.

We designed a DQN variant capable of handling these constraints, which we call Action Head DQN (AH-DQN).
Moving image of AH-DQN network creating recommendations based on changing variables
The AH-DQN network takes as input the current state and an available action to produce a single Q value for the input action. This process is repeated for every possible item in the input. Finally, the item with the highest Q value is selected and added to the slate, and the process continues until the slate is full.
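The greedy slate-building loop described above can be sketched as follows. This is not our AH-DQN implementation. toy_q is a hypothetical hand-written stand-in for the trained Q-network. Passing the partial slate to the Q-function, so later positions can depend on earlier picks, is an assumption of this sketch.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical Q-function signature: (state, partial slate, candidate item) -> Q value.
QFn = Callable[[Dict[str, float], Sequence[str], str], float]


def build_slate(state: Dict[str, float], candidates: List[str],
                q_value: QFn, slate_size: int) -> List[str]:
    """Greedily fill a slate: score every remaining candidate, add the argmax, repeat."""
    slate: List[str] = []
    remaining = list(candidates)
    while remaining and len(slate) < slate_size:
        # One evaluation per (state, candidate) pair, as described above.
        q = {item: q_value(state, slate, item) for item in remaining}
        best = max(remaining, key=lambda item: q[item])
        slate.append(best)
        remaining.remove(best)
    return slate


def toy_q(state: Dict[str, float], slate: Sequence[str], item: str) -> float:
    """Toy stand-in: user affinity for the item's genre minus a penalty per repeat."""
    genre = item.split(":")[0]
    repeat_penalty = 0.5 * sum(1 for s in slate if s.split(":")[0] == genre)
    return state.get(genre, 0.0) - repeat_penalty


state = {"rock": 0.9, "jazz": 0.6, "pop": 0.3}
candidates = ["rock:a", "rock:b", "jazz:c", "pop:d"]
print(build_slate(state, candidates, toy_q, slate_size=3))  # ['rock:a', 'jazz:c', 'rock:b']
```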

Experiments In Brief

We tested our approach both offline and online at scale to assess the ability of the Agent to power our real-world recommender systems. In addition to testing the Agent itself, we were also keen to assess the extent to which the offline performance estimates our simulator returned for various policies matched (or at least directionally aligned with) our online results.
Graph measuring simulated performance assessment by scaled online reward for different policies

We observed this directional alignment for numerous naive, heuristic, model-driven, and RL policies.

Please refer to our KDD paper for more information on the specifics of our model-based RL approach and Agent design.

Federico Tomasi, Joseph Cauteruccio, Surya Kanoria, Kamil Ciosek, Matteo Rinaldi, and Zhenwen Dai
KDD 2023


We’d like to thank all our Spotify teammates past and present who contributed to this work. Particularly, we’d like to thank Mehdi Ben Ayed for his early work in helping to develop our RL codebase. We’d also like to thank the TensorFlow Agents team for their support and encouragement throughout this project (and for the library that made it possible).

Building a board game with the TFLite plugin for Flutter

Posted by Wei Wei, Developer Advocate

In our previous blog posts Building a board game app with TensorFlow: a new TensorFlow Lite reference app and Building a reinforcement learning agent with JAX, and deploying it on Android with TensorFlow Lite, we demonstrated how to train a reinforcement learning (RL) agent with TensorFlow, TensorFlow Agents and JAX respectively, and then deploy the converted TFLite model in an Android app using TensorFlow Lite, to play a simple board game ‘Plane Strike’.

While these end-to-end tutorials are helpful for Android developers, we have heard from the Flutter developer community that it would be interesting to make the app cross-platform. Inspired by the recent official release of the TensorFlow Lite Plugin for Flutter, we are going to write one last tutorial and port the app to Flutter.
Flow Chart illustrating training a Reinforcement Learning (RL) Agent with TensorFlow, TensorFlow Agents and JAX, deploying the converted model in an Android app and Flutter using the TensorFlow Lite plugin

Since we already have the model trained with TensorFlow and converted to TFLite, we can just load the model with the TFLite interpreter:

void _loadModel() async {
  // Create the interpreter
  _interpreter = await Interpreter.fromAsset(_modelFile);
}

Then we pass in the user board state and help the game agent identify the most promising position to strike next (please refer to our previous blog posts if you need a refresher on the game rules) by running TFLite inference:

int predict(List<List<double>> boardState) {
  var input = [boardState];
  var output = List.filled(_boardSize * _boardSize, 0)
      .reshape([1, _boardSize * _boardSize]);

  // Run inference
  _interpreter.run(input, output);

  // Argmax
  double max = output[0][0];
  int maxIdx = 0;
  for (int i = 1; i < _boardSize * _boardSize; i++) {
    if (max < output[0][i]) {
      maxIdx = i;
      max = output[0][i];
    }
  }
  return maxIdx;
}
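For comparison with the Python side of the pipeline, the same argmax can be written in Python. Mapping the flat index back to a (row, column) cell assumes that the model’s _boardSize * _boardSize outputs are laid out row by row. That layout is an assumption of this sketch, and predict_strike is a hypothetical name.

```python
from typing import List, Tuple


def predict_strike(output: List[float], board_size: int) -> Tuple[int, int, int]:
    """Return (flat_index, row, col) of the highest-scoring cell in a flat model output."""
    if len(output) != board_size * board_size:
        raise ValueError("output length must be board_size * board_size")
    max_idx = 0
    for i in range(1, len(output)):
        # Strict comparison keeps the first maximum on ties, like the Dart loop.
        if output[i] > output[max_idx]:
            max_idx = i
    # Assumes row-major layout: index = row * board_size + col.
    row, col = divmod(max_idx, board_size)
    return max_idx, row, col


print(predict_strike([0.1, 0.3, 0.2, 0.9, 0.0, 0.4, 0.5, 0.2, 0.1], board_size=3))  # (3, 1, 0)
```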

That’s it! With some additional Flutter frontend code to render the game boards and track game progress, we can immediately run the game on both Android and iOS (currently the plugin only supports these two mobile platforms). You can find the complete code on GitHub.

If you want to dig deeper, there are a couple of things you can try:
  1. Convert the TFAgents-trained model to TFLite and run it with the plugin
  2. Leverage the RL technique we have used and build a new agent for the tic tac toe game in the Flutter Casual Games Toolkit. You will need to create a new RL environment and train the model from scratch before deployment, but the core concept and technique are pretty much the same.

This concludes this mini-series of blogs on leveraging TensorFlow/JAX to build games for Android and Flutter. We very much look forward to all the exciting things you build with our tooling, so be sure to share them with @googledevs, @TensorFlow, and your developer communities!

People of AI: Season 2

Posted by Ashley Oldacre

If you are joining us for the first time, you can binge listen to our amazing 8 episodes from Season 1 wherever you get your podcasts.

We are back for another season of People of AI with a new lineup of incredible guests! I am so excited to introduce my new co-host Luiz Gustavo Martins as we meet inspiring people with interesting stories in the field of Artificial Intelligence.

Last season we focused on the incredible journeys that our guests took to get into the field of AI. Through our stories, we highlighted that no matter who you are, what your interests are, or what you work on, there is a place for anyone to get into this field. We also explored how much more accessible the technology has become over the years, as well as the importance of building AI-related products responsibly and ethically. It is easier than ever to use tools, platforms and services powered by machine learning to leverage the benefits of AI and break down the barrier to entry.

For season 2, we will feature amazing conversations, focusing on Generative AI! Specifically, we will be discussing the explosive growth of Generative AI tools and the major technology shift that has happened in recent months. We will dive into various topics to explore areas where Generative AI can contribute tremendous value, as well as boost both productivity and economic growth. We will also continue to explore the personal paths and career development of this season’s guests as they share how their interest in technology was sparked, how they worked hard to get to where they are today, and explore what it is that they are currently working on.

Starting today, we will release one new episode of season 2 per week. Listen to the first episode on the People of AI site or wherever you get your podcasts. And stay tuned for later in the season when we premiere our first video podcasts as well!

  • Episode 1: meet your hosts, Ashley and Gus and learn about Generative AI, Bard and the big shift that has dramatically changed the industry. 
  • Episode 2: meet Sunita Verma, a long-time Googler, as she shares her personal journey from Engineering to CS, and into Google. As an early pioneer of AI and Google Ads, she will talk with us about the evolution of AI and how Generative AI will transform the way we work. 
  • Episode 3: meet Sayak Paul, a Google Developer Expert (GDE) as we explore what it means to be a GDE and how to leverage the power of your community through community contributions. 
  • Episode 4: meet Crispin Velez, the lead for Cloud’s Vertex AI as we dig into his experience in Cloud working with customers and partners on how to integrate and deploy AI. We also learn how he grew his AI developer community in LATAM from scratch. 
  • Episode 5: meet Joyce Shen, venture capital/private equity investor. She shares her fascinating career in AI and how she has worked with businesses to spot AI talent, incorporate AI technology into workflows and implement responsible AI into their products. 
  • Episode 6: meet Anne Simonds and Brian Gary, founders of Muse. Join us as we talk about their recent journeys into AI and their new company, which uses the power of Generative AI to spark creativity. 
  • Episode 7: meet Tulsee Doshi, product lead for Google’s Responsible AI efforts as we discuss the development of Google-wide resources and best practices for developing more inclusive, diverse, and ethical algorithm driven products. 
  • Episode 8: meet Jeanine Banks, Vice President and General Manager of Google Developer X and Head of Developer Relations. Join us as we debunk AI and get down to what Generative AI really is, how it has changed over the past few months and will continue to change the developer landscape. 
  • Episode 9: meet Simon Tokumine, Director of Product Management at Google. We will talk about how AI has brought us into the era of task-orientated products and is fueling a new community of makers.

Listen now to the first episode of Season 2. We can’t wait to share the stories of these exceptional People of AI with you!

This podcast is sponsored by Google. Any remarks made by the speakers are their own and are not endorsed by Google.

Pre-processing temporal data made easier with TensorFlow Decision Forests and Temporian

Posted by Google: Mathieu Guillame-Bert, Richard Stotz, Robert Crowe, Luiz GUStavo Martins (Gus), Ashley Oldacre, Kris Tonthat, Glenn Cameron, and Tryolabs: Ian Spektor, Braulio Rios, Guillermo Etchebarne, Diego Marvid, Lucas Micol, Gonzalo Marín, Alan Descoins, Agustina Pizarro, Lucía Aguilar, Martin Alcala Rubi

Temporal data is omnipresent in applied machine learning. Data often changes over time or is only available or valuable at a certain point in time. For example, market prices and weather conditions change constantly. Temporal data is also often highly discriminative in decision-making tasks. For example, the rate of change and interval between two consecutive heartbeats provide valuable insights into a person’s physical health, and temporal patterns of network logs are used to detect configuration issues and intrusions. Hence, it is essential to incorporate temporal data and temporal information in ML applications.

INFO:  Temporian is a new open-source Python library for preprocessing and feature engineering temporal data for machine learning applications. It is developed in collaboration between Google and Tryolabs. Check the sister blog post for more details.

This blog post demonstrates how to train a forecasting model on transactional data. Specifically, we will show how to forecast the total weekly sales from individual sales records. For the modeling part, we will use TensorFlow Decision Forests as they are well suited to handle temporal data. To feed the transaction data to our model, and to compute temporal specific features, we will use Temporian, a newly released library designed for ingesting and aggregating transactional data from multiple non-synchronized sources.


Time series are the most commonly used representation for temporal data. They consist of uniformly sampled values, which can be useful for representing aggregate signals. However, time series are sometimes not sufficient to represent the richness of available data. Instead, multivariate time series can represent multiple signals together, while time sequences or event sets can represent non-uniformly sampled measurements. Multi-index time sequences can be used to represent relations between different time sequences. In this blog post, we will use the multivariate multi-index time sequence, also known as event sets. Don’t worry, they’re not as complex as they sound.

Examples of temporal data include:

  • Weather and other environmental data for weather forecasting, soil profile forecasting and crop yield optimization, temperature tracking, and climate change characterization.

  • Sensory data for quality monitoring, and predictive maintenance.

  • Health data for early treatment, personalized medicine, and epidemic detection.

  • Retail customer data for sales forecasting, sales optimization, and targeted advertising.

  • Banking customer data for fraud detection and loan risk analysis.

  • Economic and financial data for risk analysis, budgetary analysis, stock market analysis, and yield projections.

A simple example

Let’s start with a simple example. We have collected sales records from a fictitious online shop. Each time a client makes a purchase, we record the following information: time of the purchase, client id, product purchased, and price of the product.

The dataset is stored in a single CSV file, with one transaction per line:

$ head -n 5 sales.csv
2010-10-05 11:09:56,c64,p35,405.35
2010-09-27 15:00:49,c87,p29,605.35
2010-09-09 12:58:33,c97,p10,108.99
2010-09-06 12:43:45,c60,p85,443.35

Looking at the data is crucial to understanding it and spotting potential issues. Our first task is to load the sales data into an EventSet and plot it.

INFO: A Temporian EventSet is a general-purpose container for temporal data. It can represent multivariate time series, time sequences, and indexed data.

# Import Temporian
import temporian as tp

# Load the csv dataset
sales = tp.from_csv("/tmp/sales.csv")

# Print details about the EventSet
print(sales)

This code snippet loads and prints the data.

We can also plot the data:
# Plot "price" feature of the EventSet
sales["price"].plot()


We have shown how to load and visualize temporal data in just a few lines of code. However, the resulting plot is very busy, as it shows all transactions for all clients in the same view.

A common operation on temporal data is to calculate the moving sum. Let’s calculate and plot the sum of sales for each transaction in the previous seven days. The moving sum can be computed using the moving_sum operator.

weekly_sales = sales["price"].moving_sum(tp.duration.days(7))
weekly_sales.plot()



BONUS: To make the plots interactive, you can add the interactive=True argument to the plot function. 
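The following plain-Python sketch shows concretely what moving_sum computes. It runs the same computation on the four sample rows shown above, not on the full dataset. It assumes that the window for an event at time t covers (t - 7 days, t], which is our understanding of Temporian’s moving windows. Check the Temporian documentation for the exact boundary convention.

```python
from datetime import datetime, timedelta
from typing import List, Tuple

Event = Tuple[datetime, float]  # (timestamp, price)


def moving_sum(events: List[Event], window: timedelta) -> List[Event]:
    """For each event at time t, sum the prices of events in (t - window, t]."""
    events = sorted(events)
    result = []
    for t, _ in events:
        total = sum(p for ts, p in events if t - window < ts <= t)
        result.append((t, total))
    return result


toy_sales = [
    (datetime(2010, 9, 6, 12, 43, 45), 443.35),
    (datetime(2010, 9, 9, 12, 58, 33), 108.99),
    (datetime(2010, 9, 27, 15, 0, 49), 605.35),
    (datetime(2010, 10, 5, 11, 9, 56), 405.35),
]
for t, total in moving_sum(toy_sales, timedelta(days=7)):
    print(t, round(total, 2))
```

The double loop is quadratic for clarity. A two-pointer sweep over the sorted events computes the same sums in linear time.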

Sales per products

In the previous step, we computed the overall moving sum of sales for the entire shop. However, what if we wanted to calculate the rolling sum of sales for each product or client separately?

For this task, we can use an index.

# Index the data by "product"
sales_per_product = sales.add_index("product")

# Compute the moving sum for each product
weekly_sales_per_product = sales_per_product["price"].moving_sum(
    tp.duration.days(7)
)

# Plot the results
weekly_sales_per_product.plot()


NOTE: Many operators, such as moving_sum, are applied independently to each index.
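Indexing behaves like a group-by: the moving sum is computed separately over each product’s own events. The following plain-Python sketch uses made-up toy transactions (the product ids, dates, and prices are illustrative) and a (t - 7 days, t] window.

```python
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple


def moving_sum_per_key(rows: List[Tuple[datetime, str, float]],
                       window: timedelta) -> Dict[str, List[Tuple[datetime, float]]]:
    """Group (timestamp, key, value) rows by key, then compute a moving sum per group."""
    groups: Dict[str, List[Tuple[datetime, float]]] = defaultdict(list)
    for ts, key, value in rows:
        groups[key].append((ts, value))
    out = {}
    for key, events in groups.items():
        events.sort()
        # Each event only sees events of the same key inside (t - window, t].
        out[key] = [(t, sum(v for s, v in events if t - window < s <= t)) for t, _ in events]
    return out


day = datetime(2010, 9, 1)
rows = [
    (day, "p1", 10.0),
    (day + timedelta(days=1), "p2", 5.0),
    (day + timedelta(days=2), "p1", 20.0),
    (day + timedelta(days=9), "p1", 1.0),
]
result = moving_sum_per_key(rows, timedelta(days=7))
print({k: [v for _, v in evs] for k, evs in result.items()})  # {'p1': [10.0, 30.0, 1.0], 'p2': [5.0]}
```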

Aggregate transactions into time series

Our dataset contains individual client transactions. To use this data with a machine learning model, it is often useful to aggregate it into time series, where the data is sampled uniformly over time. For example, we could aggregate the sales weekly, or calculate the total sales in the last week for each day.

However, it is important to note that aggregating transaction data into time series can result in some data loss. For example, the individual transaction timestamps and values would be lost. This is because the aggregated time series would only represent the total sales for each time period.

Let’s compute the total sales in the last week for each day for each product individually.

# The data is sampled daily
daily_sampling = sales_per_product.tick(tp.duration.days(1))

weekly_sales_daily = sales_per_product["price"].moving_sum(
    tp.duration.days(7),
    sampling=daily_sampling,  # The new bit
)
weekly_sales_daily.plot()


NOTE: The current plot is a continuous line, while the previous plots have markers. This is because Temporian uses continuous lines by default when the data is uniformly sampled, and markers otherwise.
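The sampling argument changes where the sum is evaluated: at regular daily ticks instead of at each transaction. The following plain-Python sketch uses toy data. Two choices here are assumptions of this sketch, and Temporian’s tick alignment may differ:

- the tick rule: every midnight from the day after the first event through the day after the last;
- the (t - 7 days, t] window.

```python
from datetime import datetime, timedelta
from typing import List, Tuple


def daily_ticks(first: datetime, last: datetime) -> List[datetime]:
    """Hypothetical tick rule: each midnight from the day after `first` to the day after `last`."""
    t = first.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
    ticks = []
    while t <= last + timedelta(days=1):
        ticks.append(t)
        t += timedelta(days=1)
    return ticks


def sampled_moving_sum(events: List[Tuple[datetime, float]], ticks: List[datetime],
                       window: timedelta) -> List[Tuple[datetime, float]]:
    """Evaluate the moving sum at each tick rather than at each event."""
    return [(tick, sum(v for s, v in events if tick - window < s <= tick)) for tick in ticks]


start = datetime(2010, 9, 1, 12)
events = [(start, 10.0), (start + timedelta(days=2), 20.0), (start + timedelta(days=9), 1.0)]
ticks = daily_ticks(events[0][0], events[-1][0])
for tick, total in sampled_moving_sum(events, ticks, timedelta(days=7)):
    print(tick.date(), total)
```

Every tick produces a value, including days without any sale. That uniform output is what makes the result a time series.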

After the data preparation stage is finished, the data can be exported to a Pandas DataFrame as a final step.


Train a forecasting model with TensorFlow Decision Forests

A key application of Temporian is to clean data and perform feature engineering for machine learning models. It is well suited for forecasting, anomaly detection, fraud detection, and other tasks where data comes continuously.

In this example, we show how to train a TensorFlow model to predict the next day’s sales using past sales for each product individually. We will feed the model various levels of aggregations of sales as well as calendar information.

Let’s first augment our dataset and convert it to a dataset compatible with a tabular ML model.

sales_per_product = sales.add_index("product")

# Create one example per day
daily_sampling = sales_per_product.tick(tp.duration.days(1))

# Compute moving sums with various window lengths.
# Machine learning models are able to select the ones that matter.
features = []
for w in [3, 7, 14, 28]:
    features.append(
        sales_per_product["price"]
        .moving_sum(tp.duration.days(w), sampling=daily_sampling)
        .rename(f"moving_sum_{w}")
    )

# Calendar information such as the day of the week is
# very informative of human activities.
features.append(daily_sampling.calendar_day_of_week())

# The label is the daily sales shifted / leaked one day into the future.
label = (
    sales_per_product["price"]
    .leak(tp.duration.days(1))
    .moving_sum(tp.duration.days(1), sampling=daily_sampling)
    .rename("label")
)

# Collect the features and labels together.
dataset = tp.glue(*features, label)
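Conceptually, each row of the resulting dataset is one (product, day) pair. The row holds trailing sums over 3, 7, 14, and 28 days, a day-of-week feature, and the next day’s sales as the label. The plain-Python sketch below builds one such row for a single product. It assumes (t - w, t] windows for the features and (t, t + 1 day] for the label. make_row, window_sum, and the toy events are hypothetical.

```python
from datetime import datetime, timedelta
from typing import Dict, List, Tuple


def window_sum(events: List[Tuple[datetime, float]], start: datetime, end: datetime) -> float:
    """Sum of values with timestamps in (start, end]."""
    return sum(v for t, v in events if start < t <= end)


def make_row(events: List[Tuple[datetime, float]], tick: datetime) -> Dict[str, float]:
    """Build one tabular example for one product at one daily tick."""
    row: Dict[str, float] = {}
    for w in [3, 7, 14, 28]:
        # Trailing sales over the last w days, evaluated at this tick.
        row[f"moving_sum_{w}"] = window_sum(events, tick - timedelta(days=w), tick)
    # Day of week (Python: Monday=0, Sunday=6; Temporian's numbering may differ).
    row["calendar_day_of_week"] = tick.weekday()
    # Label: sales during the following day, i.e. information leaked from the future.
    row["label"] = window_sum(events, tick, tick + timedelta(days=1))
    return row


toy_events = [
    (datetime(2010, 9, 1, 12), 10.0),
    (datetime(2010, 9, 10, 12), 5.0),
    (datetime(2010, 9, 12, 9), 7.0),
]
print(make_row(toy_events, datetime(2010, 9, 12)))
```

Only the label looks into the future. Leaking future values into a feature would make offline evaluation look far better than real forecasting performance.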



We can then convert the dataset from EventSet to TensorFlow Dataset format, and train a Random Forest.

import tensorflow_decision_forests as tfdf

def extract_label(example):
    example.pop("timestamp")  # Don't use the timestamps as a feature
    label = example.pop("label")
    return example, label

tf_dataset = tp.to_tensorflow_dataset(dataset).map(extract_label).batch(100)

model = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION, verbose=2)
model.fit(tf_dataset)

And that’s it, we have a model trained to forecast sales. We can now look at the variable importance of the model to understand which features matter the most.

model.summary()

In the summary, we can find the INV_MEAN_MIN_DEPTH variable importance:

Variable Importance: INV_MEAN_MIN_DEPTH:
1. "moving_sum_28" 0.342231 ################
2. "product" 0.294546 ############
3. "calendar_day_of_week" 0.254641 ##########
4. "moving_sum_14" 0.197038 ######
5. "moving_sum_7" 0.124693 #
6. "moving_sum_3" 0.098542

We see that moving_sum_28 is the feature with the highest importance (0.342231). This indicates that the sum of sales in the last 28 days is very important to the model. To further improve our model, we should probably add more temporal aggregation features. The product feature also matters a lot.

And to get an idea of the model itself, we can plot one of the trees of the Random Forest.

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=2)

More on temporal data preprocessing

We demonstrated some simple data preprocessing. If you want to see other examples of temporal data preprocessing on different data domains, check the Temporian tutorials. Notably:

  • Heart rate analysis ❤️ detects individual heartbeats and derives heart rate related features on raw ECG signals from Physionet.
  • M5 Competition 🛒 predicts retail sales in the M5 Makridakis Forecasting competition.
  • Loan outcomes prediction 🏦 prepares relational SQL data to predict outcomes for finished loans.
  • Detecting payment card fraud 💳 detects fraudulent payment card transactions in real time.
  • Supervised and unsupervised anomaly detection 🔎 performs data analysis and feature engineering to detect anomalies in a group of servers’ resource usage metrics.

Next Steps

We demonstrated how to handle temporal data such as transactions in TensorFlow using the Temporian library. Now you can try it too!

To learn more about model training with TensorFlow Decision Forests:

Distributed Fast Fourier Transform in TensorFlow

Posted by Ruijiao Sun, Google Intern – DTensor team

Fast Fourier Transform is an important method of signal processing, which is commonly used in a number of ways, including speeding up convolutions, extracting features, and regularizing models. Distributed Fast Fourier Transform (Distributed FFT) offers a way to compute Fourier Transforms in models that work with image-like datasets that are too large to fit into the memory of a single accelerator device. In a previous Google Research Paper, “Large-Scale Discrete Fourier Transform on TPUs” by Tianjian Lu, a Distributed FFT algorithm was implemented for TensorFlow v1 as a library. This work presents the newly added native support in TensorFlow v2 for Distributed FFT, through the new TensorFlow distribution API, DTensor.

About DTensor

DTensor is an extension to TensorFlow for synchronous distributed computing. It distributes the program and tensors through a procedure called Single Program, Multiple Data (SPMD) expansion. DTensor offers a uniform API for the traditional data and model parallelism patterns widely used in Machine Learning.

Example Usage

The API interface for distributed FFT is the same as the original FFT in TensorFlow. Users just need to pass a sharded tensor as an input to the existing FFT ops in TensorFlow, such as tf.signal.fft2d. The output of a distributed FFT becomes sharded too.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Set up devices
device_type = dtensor.preferred_device_type()
if device_type == 'CPU':
    cpu = tf.config.list_physical_devices(device_type)
    tf.config.set_logical_device_configuration(cpu[0], [tf.config.LogicalDeviceConfiguration()] * 8)
if device_type == 'GPU':
    gpu = tf.config.list_physical_devices(device_type)
    tf.config.set_logical_device_configuration(gpu[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * 8)

# Create a mesh
mesh = dtensor.create_distributed_mesh(mesh_dims=[('x', 1), ('y', 2), ('z', 4)], device_type=device_type)

# Set up a distributed input Tensor
input = tf.complex(
    tf.random.stateless_normal(shape=(2, 2, 4), seed=(1, 2), dtype=tf.float32),
    tf.random.stateless_normal(shape=(2, 2, 4), seed=(2, 4), dtype=tf.float32))
init_layout = dtensor.Layout(['x', 'y', 'z'], mesh)
d_input = dtensor.relayout(input, layout=init_layout)

# Run distributed fft2d. DTensor determines the most efficient
# layout of d_output.
d_output = tf.signal.fft2d(d_input)
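The layout in this example splits each tensor dimension that is sharded over a mesh dimension evenly across that mesh dimension’s devices. The helper below is a hypothetical plain-Python illustration of this, not a DTensor API. UNSHARDED is a local stand-in for DTensor’s own unsharded marker. For the mesh x=1, y=2, z=4 and the (2, 2, 4) input above, each of the 8 devices holds a (2, 1, 1) shard.

```python
from math import prod
from typing import Dict, Sequence, Tuple

UNSHARDED = "unsharded"  # local stand-in for DTensor's unsharded marker


def local_shard_shape(global_shape: Sequence[int], sharding_specs: Sequence[str],
                      mesh: Dict[str, int]) -> Tuple[int, ...]:
    """Per-device shard shape when each tensor dim is split evenly over its mesh dim."""
    shape = []
    for size, spec in zip(global_shape, sharding_specs):
        n = 1 if spec == UNSHARDED else mesh[spec]
        if size % n:
            raise ValueError(f"dimension of size {size} not divisible by mesh dim {spec}={n}")
        shape.append(size // n)
    return tuple(shape)


mesh = {"x": 1, "y": 2, "z": 4}
print(prod(mesh.values()), local_shard_shape((2, 2, 4), ["x", "y", "z"], mesh))  # 8 (2, 1, 1)
```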

Performance Analysis

The following experiment demonstrates that the distributed FFT can process more data than the non-distributed one by utilizing memory across multiple devices. The tradeoff is spending additional time on communication and data transposes that slow down the calculation speed.

Graph of performance on different machines, measuring wall clock time in seconds by size per dimension across single GPU, Distributed FFT and Undistributed FFT

This phenomenon is shown in detail in the profiling results of the 10K*10K distributed FFT experiment. The current implementation of distributed FFT in TensorFlow follows the simple shuffle+local FFT method, which is also used by other popular distributed FFT libraries such as FFTW and PFFT. Notably, the two local FFT ops only take 3.6% of the total time (15ms). This is around 1/3 of the time for non-distributed fft2d. Most of the computing time is spent on data shuffling, represented by the ncclAllToAll operation. Note that these experiments were conducted on an 8xV100 GPU system.
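The shuffle+local FFT method relies on the separability of the 2-D DFT: a 2-D transform equals 1-D transforms along every row, followed by 1-D transforms along every column. With rows sharded across devices, the method runs in three passes:

1. The first pass (1-D transforms along each row) is purely local.
2. The transpose in between is the all-to-all shuffle, which is the ncclAllToAll cost above.
3. The second pass (1-D transforms along the original columns) is local again.

The plain-Python sketch below simulates the shards as lists and checks the result against the 2-D DFT definition. A naive DFT stands in for a real FFT, and the function names are hypothetical.

```python
import cmath
from typing import List

Matrix = List[List[complex]]


def dft(xs: List[complex]) -> List[complex]:
    """Naive 1-D DFT: X[k] = sum_m x[m] * exp(-2j*pi*k*m/N)."""
    n = len(xs)
    return [sum(x * cmath.exp(-2j * cmath.pi * k * m / n) for m, x in enumerate(xs))
            for k in range(n)]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def split(a: Matrix, num_devices: int) -> List[Matrix]:
    """Shard a matrix by rows into equal contiguous blocks, one per simulated device."""
    step = len(a) // num_devices
    return [a[d * step:(d + 1) * step] for d in range(num_devices)]


def distributed_fft2d(a: Matrix, num_devices: int) -> Matrix:
    """Simulated row-sharded 2-D DFT: local row DFTs, all-to-all transpose, local row DFTs."""
    # Pass 1: each device transforms the rows it owns; no communication needed.
    shards = [[dft(row) for row in shard] for shard in split(a, num_devices)]
    # Shuffle (simulated all-to-all): transpose so each device owns a block of columns.
    shards = split(transpose([row for shard in shards for row in shard]), num_devices)
    # Pass 2: local DFTs along the original columns, then transpose back.
    shards = [[dft(row) for row in shard] for shard in shards]
    return transpose([row for shard in shards for row in shard])


def direct_dft2d(a: Matrix) -> Matrix:
    """Reference 2-D DFT straight from the definition."""
    h, w = len(a), len(a[0])
    return [[sum(a[m][n] * cmath.exp(-2j * cmath.pi * (k * m / h + l * n / w))
                 for m in range(h) for n in range(w))
             for l in range(w)]
            for k in range(h)]


a = [[complex(4 * i + j) for j in range(4)] for i in range(4)]
got = distributed_fft2d(a, num_devices=2)
want = direct_dft2d(a)
err = max(abs(g - r) for g_row, r_row in zip(got, want) for g, r in zip(g_row, r_row))
print(f"max abs error: {err:.2e}")
```

In a real system, the shuffle sends one block from every device to every other device instead of materializing the whole matrix. This communication is why it dominates the profile above.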

Table of Top 10 TensorFlow operations on GPU highlighting two local FFT ops in the top 3

Next steps

The feature is new, and we have adopted the simplest distributed FFT algorithm. A few ideas to fine-tune or improve the performance are:

  • Switch to a different DFT/FFT algorithm.
  • Tweak the NCCL communication settings for particular FFT sizes, which may improve utilization of the network bandwidth and increase the speed.
  • Reduce the number of collectives to minimize bandwidth requirements.
  • Use N-d local FFTs, rather than multiple 1-d local FFTs.

Try the new distributed FFT! We welcome your feedback on the TensorFlow Forum and look forward to working with you on improving the performance. Your input would be invaluable!
