Optimizing TF, XLA and JAX for LLM Training on NVIDIA GPUs

Posted by Douglas Yarrington (Google TPgM), James Rubin (Google PM), Neal Vaidya (NVIDIA TME), Jay Rodge (NVIDIA PMM)

Together, NVIDIA and Google are delighted to announce new milestones and plans to optimize TensorFlow and JAX for the Ampere and recently announced Hopper GPU architectures by leveraging the power of XLA: a performant, flexible and extensible ML compiler built by Google. We will deepen our ongoing collaboration with dedicated engineering teams focused on delivering improved performance in currently available A100 GPUs. NVIDIA and Google will also jointly support unique features in the recently announced H100 GPU, including the Transformer Engine with support for hardware-accelerated 8-bit floating-point (FP8) data types and the transformer library.

We are announcing improved performance in TensorFlow, new NVIDIA GPU-specific features in XLA and the first release of JAX for multi-node, multi-GPU training, which will significantly improve large language model (LLM) training. We expect the Hopper architecture to be especially popular for LLMs.

NVIDIA H100 Tensor Core GPU


Google delivers high performance with LLMs on NVIDIA GPUs because of a notable technology, XLA, which supports all leading ML frameworks, such as TensorFlow, JAX, and PyTorch. Over 90% of Google’s ML compilations, across research and production, happen on XLA. These span the gamut of ML use cases, from ultra-large scale model training at DeepMind and Google Research, to optimized deployments across our products, to edge inferencing at Waymo.

XLA’s deep feature set accelerates large language model performance and is solving most large model challenges seen in the industry today. For example, a feature unique to XLA, SPMD, automates most of the work needed to partition models across multiple cores and devices, making large model training significantly more scalable and performant. XLA can also automatically recognize and select the best hand-written library implementation for your target backend, like cuDNN for CUDA-capable GPUs. Otherwise, XLA can natively generate optimized code for performant execution.
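As a rough illustration of the compile path described above, the following sketch uses JAX to hand a small function to XLA and then inspects what the compiler produced. The function name `gelu_dense` and the array shapes are made up for this example; which library calls or fused kernels show up in the compiled text depends on the backend and the XLA version, so treat the output as something to explore rather than a fixed result.

```python
import jax
import jax.numpy as jnp


def gelu_dense(x, w, b):
    # A matmul followed by elementwise ops: the matmul is a candidate for a
    # library call, the bias add and activation are candidates for fusion.
    return jax.nn.gelu(x @ w + b)


# Hypothetical shapes chosen only for illustration.
x = jnp.ones((8, 16))
w = jnp.ones((16, 4))
b = jnp.zeros((4,))

# Lower the function to XLA and compile it for the current backend.
compiled = jax.jit(gelu_dense).lower(x, w, b).compile()

# Text form of the optimized program that XLA will actually run.
hlo_text = compiled.as_text()

# The compiled object can be called like the original function.
y = compiled(x, w, b)
```

Printing `hlo_text` on a GPU backend is one way to check whether XLA chose a library call or generated its own kernel for a given operation.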

We’ve been collaborating with NVIDIA on several exciting features and integrations that will further optimize LLMs for GPUs. We recently enabled collectives such as all-reduce to run in parallel with compute. This has resulted in a significant reduction in end-to-end latency for customers. Furthermore, we enabled support for bfloat16, which has resulted in compute gains of 4.5x over 32-bit floating point while retaining the same dynamic range of values.
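The dynamic-range point can be checked with a few lines of plain Python. bfloat16 keeps the sign bit, all 8 exponent bits and the top 7 mantissa bits of a 32-bit float, so the sketch below emulates it by zeroing the low 16 bits. The helper names are invented for this example, and simple truncation is an assumption made for brevity; real hardware typically rounds to nearest instead.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping only the top 16 bits of a float32.

    Assumption: truncation instead of round-to-nearest, for simplicity.
    """
    bits = float32_bits(x) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


big = 1.0e38         # close to the top of the float32 range
fine_step = 1.001    # needs more than 7 mantissa bits to represent

big_bf16 = to_bfloat16(big)         # still finite: the exponent is untouched
step_bf16 = to_bfloat16(fine_step)  # precision is lost: collapses to 1.0

print(big_bf16, step_bf16)
```

The large value survives with a small relative error, while the fine-grained value loses its fractional detail: range is preserved and precision is what gets traded away.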

Our joint efforts mean that XLA integrates even more deeply with NVIDIA’s AI tools and can better leverage NVIDIA’s suite of hardware-optimized AI libraries. In Q1 2023, we will release an XLA-cuDNN Graph API integration, which provides customers with optimized fusion of convolution/matmul operations and multi-headed attention in transformers for improved use of memory and faster GPU kernel execution. As a result, overheads drop significantly and performance improves notably.

TensorFlow for GPU

TensorFlow recently released distributed tensors (or DTensors) to enable Tensor storage across devices like NVIDIA GPUs while allowing programs to manipulate them seamlessly. The goal of DTensor is to make parallelizing large-scale TensorFlow models across multiple devices easy, understandable, and fast. DTensors are a drop-in replacement for local TensorFlow tensors and scale well to large clusters. In addition, the DTensor project improves the underlying TensorFlow execution and communication primitives, and they are available for use today!

We are also collaborating with NVIDIA on several exciting new features in TensorFlow that leverage GPUs, including support for the new FP8 data type, which should yield a significant improvement in training times for transformer models when using the Hopper H100 GPU.

JAX for GPU

Google seeks to empower every developer with purpose-built tools for every step of the ML workflow. That includes TensorFlow for robust, production-ready models and JAX with highly optimized capabilities for cutting-edge research. We are pleased to announce the unique collaboration between NVIDIA and Google engineering teams to enhance TensorFlow and JAX for large deep-learning models, like LLMs. Both frameworks fully embrace NVIDIA A100 GPUs, and will support the recently-announced H100 GPUs in the future.

One of the key advantages of JAX is the ease of achieving superior hardware utilization with industry-leading FLOPs across accelerators. Through our collaboration with NVIDIA, we are translating these advantages to GPUs using some XLA compiler magic. Specifically, we are leveraging XLA for operator fusion, improving GSPMD for GPUs to support generalized data and model parallelism, and optimizing for cross-host NVLink.
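To give a feel for what compiler-driven partitioning looks like from the user's side, here is a minimal data-parallel sketch using JAX's sharding API. The mesh axis name `"data"`, the function `forward` and the shapes are assumptions picked for this example. It runs on however many devices are visible, including a single CPU, in which case the mesh simply has one entry.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Hypothetical shapes; the batch is sized to divide evenly across devices.
batch = 8 * len(devices)
x = jnp.ones((batch, 16))
w = jnp.ones((16, 4))

# Shard the batch dimension across the "data" axis, replicate the weights.
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w = jax.device_put(w, NamedSharding(mesh, P()))


@jax.jit
def forward(x, w):
    # Written as if for one device; the compiler partitions the computation
    # according to the shardings of the inputs.
    return jnp.tanh(x @ w)


y = forward(x, w)
print(y.shape, y.sharding)
```

Model parallelism follows the same pattern: adding a second mesh axis and sharding `w` along it changes the partitioning without changing the body of `forward`.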

Future Plans

NVIDIA and Google are pleased with all the progress shared in this post, and are excited to hear from community members about their experience using TensorFlow and JAX with XLA on Ampere (A100) and Hopper (H100) GPUs.

Check out the release notes for more information. To stay up to date, you can read the TensorFlow blog, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub or post to the TensorFlow Forum.

TensorFlow is also available in the NVIDIA GPU Cloud (NGC) as a Docker container that contains a validated set of libraries that enable and optimize GPU performance, with a JAX NGC container coming later this year.

Thank you!

Contributors: Frederic Bastien (NVIDIA), Abhishek Ratna (Google), Sean Lee (NVIDIA), Nathan Luehr (NVIDIA), Ayan Moitra (NVIDIA), Yash Katariya (Google), Peter Hawkins (Google), Skye Wanderman-Milne (Google), David Majnemer (Google), Stephan Herhut (Google), George Karpanov (Google), Mahmoud Soliman (NVIDIA), Yuan Lin (NVIDIA), Vartika Singh (NVIDIA), Vinod Grover (NVIDIA), Pooya Jannaty (NVIDIA), Paresh Kharya (NVIDIA), Santosh Bhavani (NVIDIA)
