Accelerated Diffusers with PyTorch 2.0


PyTorch 2.0 has just been released. Its flagship new feature is torch.compile(), a one-line code change that promises to automatically improve performance across codebases. We have previously checked on that promise in Hugging Face Transformers and TIMM models, and delved deep into its motivation, architecture and the road ahead.

As important as torch.compile() is, there’s much more to PyTorch 2.0. Notably, PyTorch 2.0 incorporates several strategies to accelerate transformer blocks, and these improvements are very relevant for diffusion models too. Techniques such as FlashAttention, for example, have become very popular in the diffusion community thanks to their ability to significantly speed up Stable Diffusion and achieve larger batch sizes, and they are now part of PyTorch 2.0.

In this post we discuss how attention layers are optimized in PyTorch 2.0 and how these optimizations are applied to the popular 🧨 Diffusers library. We finish with a benchmark that shows how the use of PyTorch 2.0 and Diffusers immediately translates to significant performance improvements across different hardware.

Accelerating transformer blocks

PyTorch 2.0 includes a scaled dot-product attention function as part of torch.nn.functional. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. Before PyTorch 2.0, you had to search for third-party implementations and install separate packages in order to take advantage of memory-optimized algorithms, such as FlashAttention. The available implementations are:

  • FlashAttention, from the official FlashAttention project.
  • Memory-Efficient Attention, from the xFormers project.
  • A native C++ implementation suitable for non-CUDA devices or when high precision is required.

All these methods are available by default, and PyTorch will try to select the optimal one automatically through the use of the new scaled dot-product attention (SDPA) API. You can also toggle them individually for finer-grained control; see the documentation for details.
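A minimal sketch of the SDPA API, assuming PyTorch 2.0 or later: on CPU it checks F.scaled_dot_product_attention against the textbook softmax(QKᵀ/√d)V formula, and on a CUDA machine it shows one way to toggle backends with the PyTorch 2.0 context manager torch.backends.cuda.sdp_kernel (later releases deprecate it in favour of torch.nn.attention.sdpa_kernel). The tensor shapes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v):
    # Textbook attention: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
# Shapes: (batch, heads, sequence length, head dim), chosen arbitrarily
q = torch.randn(2, 8, 64, 40)
k = torch.randn(2, 8, 64, 40)
v = torch.randn(2, 8, 64, 40)

# PyTorch chooses the implementation automatically based on inputs and hardware
out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, reference_attention(q, k, v), atol=1e-4))

if torch.cuda.is_available():
    qc, kc, vc = (t.cuda().half() for t in (q, k, v))
    # Finer-grained control: exclude FlashAttention so SDPA falls back to the
    # memory-efficient or math kernel
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=True, enable_math=True
    ):
        out_no_flash = F.scaled_dot_product_attention(qc, kc, vc)
```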

Using scaled dot-product attention in diffusers

The incorporation of Accelerated PyTorch 2.0 Transformer attention into the Diffusers library was achieved through the set_attn_processor method, which allows pluggable attention modules to be configured. In this case, a new attention processor was created, which is enabled by default when PyTorch 2.0 is available. For clarity, this is how you could enable it manually (but it’s usually not necessary since diffusers will automatically take care of it):

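from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0
# (in later diffusers releases, AttnProcessor2_0 lives in diffusers.models.attention_processor)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
# explicitly route the UNet's attention layers through PyTorch 2.0's SDPA
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]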
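To make the idea concrete, here is a hypothetical, stripped-down cross-attention layer (the class and attribute names are ours, not diffusers' AttnProcessor2_0) that does what an SDPA-based processor does conceptually: project queries from the image tokens and keys/values from the text embeddings, split heads, call F.scaled_dot_product_attention, and merge the heads back. The dimensions (320 channels, 8 heads, 77 text tokens of width 768) are illustrative assumptions loosely modeled on Stable Diffusion v1.5.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCrossAttention(nn.Module):
    """Hypothetical minimal cross-attention block; not diffusers' implementation."""

    def __init__(self, query_dim, context_dim, heads):
        super().__init__()
        assert query_dim % heads == 0
        self.heads = heads
        self.to_q = nn.Linear(query_dim, query_dim, bias=False)
        self.to_k = nn.Linear(context_dim, query_dim, bias=False)
        self.to_v = nn.Linear(context_dim, query_dim, bias=False)
        self.to_out = nn.Linear(query_dim, query_dim)

    def _split_heads(self, t):
        b, n, c = t.shape
        # (batch, tokens, channels) -> (batch, heads, tokens, head_dim)
        return t.view(b, n, self.heads, c // self.heads).transpose(1, 2)

    def forward(self, hidden_states, context):
        q = self._split_heads(self.to_q(hidden_states))
        k = self._split_heads(self.to_k(context))
        v = self._split_heads(self.to_v(context))
        out = F.scaled_dot_product_attention(q, k, v)  # fused attention kernel
        b, h, n, d = out.shape
        # (batch, heads, tokens, head_dim) -> (batch, tokens, channels)
        out = out.transpose(1, 2).reshape(b, n, h * d)
        return self.to_out(out)


torch.manual_seed(0)
attn = ToyCrossAttention(query_dim=320, context_dim=768, heads=8)
latents = torch.randn(2, 64, 320)  # (batch, image tokens, channels)
text = torch.randn(2, 77, 768)  # (batch, text tokens, embedding width)
out = attn(latents, text)
print(out.shape)  # torch.Size([2, 64, 320])
```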

Stable Diffusion Benchmark

We ran a number of tests using accelerated dot-product attention from PyTorch 2.0 in Diffusers. We installed diffusers from pip and used nightly versions of PyTorch 2.0, since our tests were performed before the official release. We also used torch.set_float32_matmul_precision('high') to enable additional fast matrix multiplication algorithms.

We compared results with the traditional attention implementation in diffusers (referred to as vanilla below) as well as with the best-performing solution in pre-2.0 PyTorch: PyTorch 1.13.1 with the xFormers package (v0.0.16) installed.

Results were measured without compilation (i.e., no code changes at all), and also with a single call to torch.compile() to wrap the UNet module. We did not compile the image decoder because most of the time is spent in the 50 denoising iterations that run UNet evaluations.
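The post does not publish its benchmarking harness, so the following is only a sketch of a fair timing loop under our own assumptions: warm-up calls absorb one-off costs (such as the first-call compilation triggered by torch.compile), CUDA is synchronized before reading the clock because kernels launch asynchronously, and "speedup" is taken as baseline time / new time - 1. In the real benchmark, compiling only the UNet would amount to something like pipe.unet = torch.compile(pipe.unet); here two attention implementations stand in for the timed workload.

```python
import math
import time

import torch
import torch.nn.functional as F


def benchmark(fn, *args, warmup=2, iters=10):
    """Mean wall-clock seconds per call, measured after `warmup` untimed calls."""
    for _ in range(warmup):
        fn(*args)  # absorbs one-off costs such as torch.compile tracing
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # GPU work is asynchronous; drain the queue first
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


def speedup_pct(baseline_s, new_s):
    """Percent speedup of the new timing over the baseline (50.0 means 1.5x faster)."""
    return (baseline_s / new_s - 1.0) * 100.0


def vanilla_attention(q, k, v):
    # Materializes the full (tokens x tokens) score matrix
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


q, k, v = (torch.randn(1, 8, 1024, 40) for _ in range(3))
with torch.no_grad():
    t_vanilla = benchmark(vanilla_attention, q, k, v)
    t_sdpa = benchmark(F.scaled_dot_product_attention, q, k, v)
print(f"SDPA speedup over vanilla: {speedup_pct(t_vanilla, t_sdpa):.1f}%")
```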

Results in float32

Diffusers Speedup vs xFormers float32

The following figures explore performance improvement vs batch size for various representative GPUs belonging to different generations. We collected data for each combination until we reached maximum memory utilization. Vanilla attention runs out of memory earlier than xFormers or PyTorch 2.0, which explains the missing bars for larger batch sizes. Similarly, A100 (we used the 40 GB version) is capable of running batch sizes of 64, but the other GPUs could only reach 32 in our tests.

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

We found very significant performance improvements over vanilla attention across the board, without even using torch.compile(). An out-of-the-box installation of PyTorch 2.0 and diffusers yields about a 50% speedup on A100 and between 35% and 50% on 4090 GPUs, depending on batch size. Performance improvements are more pronounced for modern CUDA architectures such as Ada (4090) or Ampere (A100), but they are still very significant for older architectures that remain heavily used in cloud services.

In addition to faster speeds, the accelerated transformers implementation in PyTorch 2.0 allows much larger batch sizes to be used. With vanilla attention, a single 40 GB A100 GPU runs out of memory with a batch size of 10, and 24 GB high-end consumer cards such as the 3090 and 4090 cannot generate 8 images at once. Using PyTorch 2.0 and diffusers, we could achieve batch sizes of 48 for the 3090 and 4090, and 64 for the A100. This is of great significance for cloud services and applications, as they can efficiently process more images at a time.
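A back-of-the-envelope calculation helps explain why vanilla attention runs out of memory first: it materializes a (batch, heads, tokens, tokens) score matrix, while FlashAttention and memory-efficient attention compute the same result tile by tile without storing that matrix in full. The sketch below assumes Stable Diffusion v1.5 at 512×512, where the VAE downsamples by 8 so the highest-resolution UNet attention layers see 64 × 64 = 4096 tokens, with 8 attention heads. It counts a single score matrix for a single layer, so real peak usage is higher and the numbers do not predict the exact batch size at which a given GPU fails.

```python
def score_matrix_bytes(batch, heads, tokens, bytes_per_element):
    """Bytes needed to hold one (batch, heads, tokens, tokens) attention score matrix."""
    return batch * heads * tokens * tokens * bytes_per_element


# Assumed Stable Diffusion v1.5 settings at 512x512 (illustrative, not measured)
LATENT_SIDE = 512 // 8  # VAE downsampling factor of 8 -> 64
TOKENS = LATENT_SIDE ** 2  # 4096 spatial tokens
HEADS = 8
FP32_BYTES = 4

per_image_mib = score_matrix_bytes(1, HEADS, TOKENS, FP32_BYTES) / 2**20
batch10_gib = score_matrix_bytes(10, HEADS, TOKENS, FP32_BYTES) / 2**30
print(f"{per_image_mib:.0f} MiB per image, {batch10_gib:.1f} GiB for a batch of 10")
# prints: 512 MiB per image, 5.0 GiB for a batch of 10
```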

When compared with PyTorch 1.13.1 + xFormers, the new accelerated transformers implementation is still faster and requires no additional packages or dependencies. In this case we found moderate speedups of up to 2% on datacenter cards such as A100 or T4, but performance was great on the last two generations of consumer cards: up to 20% speed improvement on 3090 and between 10% and 45% on 4090, depending on batch size.

When torch.compile() is used, we get an additional performance boost of typically 2% to 3% over the previous improvements. As compilation takes some time, this is better suited to user-facing inference services or training.

Results in float16

Diffusers Speedup vs xFormers float16

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

When we consider float16 inference, the performance improvements of the accelerated transformers implementation in PyTorch 2.0 are between 20% and 28% over standard attention, across all the GPUs we tested, except for the 4090, which belongs to the more modern Ada architecture. This GPU benefits from a dramatic performance improvement when using PyTorch 2.0 nightlies. With respect to optimized SDPA vs xFormers, results are usually on par for most GPUs, except again for the 4090. Adding torch.compile() to the mix boosts performance a few more percentage points across the board.


PyTorch 2.0 comes with multiple features to optimize the crucial components of the foundational transformer block, and they can be further improved with the use of torch.compile. These optimizations lead to significant memory and time improvements for diffusion models, and remove the need for third-party library installations.

To take advantage of these speed and memory improvements all you have to do is upgrade to PyTorch 2.0 and use diffusers >= 0.13.0.
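A quick way to confirm that an environment is ready is to check for the fused attention function. This is a minimal sketch; the exact check diffusers performs internally may differ, and diffusers.__version__ can be inspected the same way to confirm version 0.13.0 or later.

```python
import torch
import torch.nn.functional as F


def sdpa_available() -> bool:
    # F.scaled_dot_product_attention only exists in PyTorch 2.0 and later
    return hasattr(F, "scaled_dot_product_attention")


print(f"PyTorch {torch.__version__}, SDPA available: {sdpa_available()}")
```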

For more examples and in-detail benchmark numbers, please also have a look at the Diffusers with PyTorch 2.0 docs.


The authors are grateful to the PyTorch team for their insights, assistance and suggestions during the elaboration of this post, and for creating such excellent software. We are particularly indebted to Hamid Shojanazeri, Grigory Sizov, Christian Puhrsch, Driss Guessous, Michael Gschwind and Geeta Chauhan.


PyTorch 2.0: Our next generation release that is faster, more Pythonic and Dynamic as ever


We are excited to announce the release of PyTorch® 2.0 which we highlighted during the PyTorch Conference on 12/2/22! PyTorch 2.0 offers the same eager-mode development and user experience, while fundamentally changing and supercharging how PyTorch operates at compiler level under the hood with faster performance and support for Dynamic Shapes and Distributed.

This next-generation release includes a Stable version of Accelerated Transformers (formerly called Better Transformers). Beta features include torch.compile as the main API for PyTorch 2.0, the scaled_dot_product_attention function as part of torch.nn.functional, the MPS backend, and functorch APIs in the torch.func module. There are also other Beta/Prototype improvements across various inference, performance and training optimization features on GPUs and CPUs. For a comprehensive introduction and technical overview of torch.compile, please visit the 2.0 Get Started page.

Along with 2.0, we are also releasing a series of beta updates to the PyTorch domain libraries, including those that are in-tree, and separate libraries including TorchAudio, TorchVision, and TorchText. An update for TorchX is also being released as it moves to community supported mode. More details can be found in this library blog.

This release is composed of over 4,541 commits and 428 contributors since 1.13.1. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.0 and the overall 2-series this year.


  • torch.compile is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature and hence 2.0 is 100% backward compatible by definition.
  • As an underpinning technology of torch.compile, TorchInductor with Nvidia and AMD GPUs will rely on the OpenAI Triton deep learning compiler to generate performant code and hide low-level hardware details. OpenAI Triton-generated kernels achieve performance that’s on par with hand-written kernels and specialized CUDA libraries such as cuBLAS.
  • Accelerated Transformers introduce high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA). The API is integrated with torch.compile(), and model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator.
  • Metal Performance Shaders (MPS) backend provides GPU accelerated PyTorch training on Mac platforms with added support for Top 60 most used ops, bringing coverage to over 300 operators.
  • AWS optimized PyTorch CPU inference on AWS Graviton3-based C7g instances. PyTorch 2.0 improves inference performance on Graviton compared to previous releases, including improvements for ResNet50 and BERT.
  • New prototype features and technologies across TensorParallel, DTensor, 2D parallel, TorchDynamo, AOTAutograd, PrimTorch and TorchInductor.
Stable
  • Accelerated PT 2 Transformers

Beta
  • PyTorch MPS Backend
  • Scaled dot product attention
  • Dispatchable Collectives
  • torch.set_default_device & torch.device
  • X86 quantization backend
  • GNN inference and training performance

Prototype
  • 2D Parallel
  • torch.compile (dynamic=True)

Performance Improvements
  • CUDA support for 11.7 & 11.8 (deprecating CUDA 11.6)
  • Python 3.8 (deprecating Python 3.7)
  • AWS Graviton3

*To see a full list of public 2.0, 1.13 and 1.12 feature submissions click here.

Stable Features

[Stable] Accelerated PyTorch 2 Transformers (previously known as “Better Transformer”)

The PyTorch 2.0 release includes a new high-performance implementation of the PyTorch Transformer API, formerly known as “Better Transformer API” and now renamed Accelerated PyTorch 2 Transformers. In releasing accelerated PT2 Transformers, our goal is to make training and deployment of state-of-the-art Transformer models affordable across the industry. This release introduces high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA).

Similar to the “fastpath” architecture, custom kernels are fully integrated into the PyTorch Transformer API – thus, using the native Transformer and MultiHeadAttention API will enable users to:

  • transparently see significant speed improvements;
  • support many more use cases including models using Cross-Attention, Transformer Decoders, and for training models; and
  • continue to use fastpath inference for fixed and variable sequence length Transformer Encoder and Self Attention use cases.

To take full advantage of different hardware models and Transformer use cases, multiple SDPA custom kernels are supported (see below), with custom kernel selection logic that will pick the highest-performance kernel for a given model and hardware type. In addition to the existing Transformer API, model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator. Accelerated PyTorch 2 Transformers are integrated with torch.compile(). To use your model while benefiting from the additional acceleration of PT2 compilation (for inference or training), pre-process the model with model = torch.compile(model).
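As a minimal sketch, assuming a plain CPU setup and arbitrary sizes: the existing nn.TransformerEncoder API needs no changes to pick up the accelerated kernels, and running it in eval mode under torch.inference_mode() with a padding mask exercises the variable-sequence-length inference path. Whether the fastpath is actually taken depends on the conditions documented for nn.TransformerEncoderLayer (for example batch_first=True and no autograd), so treat this as an illustration rather than a guarantee. Compilation would be the one-line encoder = torch.compile(encoder), omitted here to keep the example quick.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()

batch, max_len = 4, 128
x = torch.randn(batch, max_len, 256)
# Variable-length sequences: True marks padding positions to be ignored
lengths = torch.tensor([128, 100, 64, 17])
padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]

with torch.inference_mode():
    y = encoder(x, src_key_padding_mask=padding_mask)

print(y.shape)  # torch.Size([4, 128, 256])
```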

We have achieved major speedups for training transformer models and in particular large language models with Accelerated PyTorch 2 Transformers using a combination of custom kernels and torch.compile().

Figure: Using scaled dot product attention with custom kernels and torch.compile delivers significant speedups for training large language models, such as for nanoGPT shown here.

Beta Features

[Beta] torch.compile

torch.compile is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature and hence 2.0 is 100% backward compatible by definition.

Underpinning torch.compile are new technologies – TorchDynamo, AOTAutograd, PrimTorch and TorchInductor:

  • TorchDynamo captures PyTorch programs safely using Python Frame Evaluation Hooks and is a significant innovation that was a result of 5 years of our R&D into safe graph capture.
  • AOTAutograd overloads PyTorch’s autograd engine as a tracing autodiff for generating ahead-of-time backward traces.
  • PrimTorch canonicalizes 2000+ PyTorch operators down to a closed set of ~250 primitive operators that developers can target to build a complete PyTorch backend. This substantially lowers the barrier of writing a PyTorch feature or backend.
  • TorchInductor is a deep learning compiler that generates fast code for multiple accelerators and backends. For NVIDIA and AMD GPUs, it uses OpenAI Triton as a key building block. For Intel CPUs, it generates C++ code using multithreading and vectorized instructions, offloading appropriate operations to mkldnn when possible.

With all the new technologies, torch.compile works 93% of the time across 165 open-source models and runs 20% faster on average at float32 precision and 36% faster on average at AMP precision.
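The layering described above can be observed through the backend argument of torch.compile. As a small sketch with a toy model and arbitrary sizes: backend="aot_eager" runs TorchDynamo graph capture and AOTAutograd but executes the captured graphs eagerly instead of generating code, which makes it a cheap way to check that a model compiles and matches eager results. The default backend, "inductor", adds TorchInductor code generation on top.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))
x = torch.randn(8, 16)

# TorchDynamo captures the graph and AOTAutograd traces forward and backward ahead
# of time; "aot_eager" then runs those graphs without Inductor code generation.
compiled = torch.compile(model, backend="aot_eager")

expected = model(x)
actual = compiled(x)  # the first call triggers compilation
print(torch.allclose(expected, actual, atol=1e-5))

# The compiled module shares parameters with `model`, so backward populates their grads
actual.sum().backward()
print(model[0].weight.grad.shape)  # torch.Size([32, 16])
```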

For more information, please refer to and for TorchInductor CPU with Intel here.

[Beta] PyTorch MPS Backend

MPS backend provides GPU-accelerated PyTorch training on Mac platforms. This release brings improved correctness, stability, and operator coverage.

MPS backend now includes support for the Top 60 most used ops, along with the most frequently requested operations by the community, bringing coverage to over 300 operators. The major focus of the release was to enable full OpInfo-based forward and gradient mode testing to address silent correctness issues. These changes have resulted in wider adoption of MPS backend by 3rd party networks such as Stable Diffusion, YoloV5, WhisperAI, along with increased coverage for Torchbench networks and Basic tutorials. We encourage developers to update to the latest macOS release to see the best performance and stability on the MPS backend.


  1. MPS Backend
  2. Developer information
  3. Accelerated PyTorch training on Mac
  4. Metal, Metal Performance Shaders & Metal Performance Shaders Graph
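A minimal device-selection sketch, assuming nothing beyond a stock PyTorch install: it uses the MPS backend when torch.backends.mps.is_available() reports a usable Apple GPU and otherwise falls back to the CPU, so the same toy training step runs on any machine.

```python
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

torch.manual_seed(0)
model = nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(32, 10, device=device)
y = torch.randn(32, 1, device=device)

# One ordinary training step; nothing here is MPS-specific besides the device
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
print(device, loss.item())
```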

[Beta] Scaled dot product attention 2.0

We are thrilled to announce the release of PyTorch 2.0, which introduces a powerful scaled dot product attention function as part of torch.nn.functional. This function includes multiple implementations that can be seamlessly applied depending on the input and hardware in use.

In previous versions of PyTorch, you had to rely on third-party implementations and install separate packages to take advantage of memory-optimized algorithms like FlashAttention. With PyTorch 2.0, all these implementations are readily available by default.

These implementations include FlashAttention from HazyResearch, Memory-Efficient Attention from the xFormers project, and a native C++ implementation that is ideal for non-CUDA devices or when high-precision is required.

PyTorch 2.0 will automatically select the optimal implementation for your use case, but you can also toggle them individually for finer-grained control. Additionally, the scaled dot product attention function can be used to build common transformer architecture components.
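For example, the causal self-attention used in decoder-only language models is a single call with is_causal=True. The sketch below (arbitrary shapes) checks it against an explicit lower-triangular boolean mask, where True marks positions that may be attended to, and confirms the defining property that changing a later token cannot affect earlier outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 16, 32  # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit mask: token i may attend to tokens 0..i
allowed = torch.ones(T, T, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
print(torch.allclose(out_causal, out_masked, atol=1e-4))

# Perturb only the last token's keys and values: earlier outputs must not change
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 10.0
v2[:, :, -1] += 10.0
out_perturbed = F.scaled_dot_product_attention(q, k2, v2, is_causal=True)
print(torch.allclose(out_causal[:, :, :-1], out_perturbed[:, :, :-1], atol=1e-6))
```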

Learn more with the documentation and this tutorial.

[Beta] functorch -> torch.func

Inspired by Google JAX, functorch is a library that offers composable vmap (vectorization) and autodiff transforms. It enables advanced autodiff use cases that would otherwise be tricky to express in PyTorch. Examples include:

We’re excited to announce that, as the final step of upstreaming and integrating functorch into PyTorch, the functorch APIs are now available in the torch.func module. Our function transform APIs are identical to before, but we have changed how the interaction with NN modules works. Please see the docs and the migration guide for more details.

Furthermore, we have added support for torch.autograd.Function: one is now able to apply function transformations (e.g. vmap, grad, jvp) over torch.autograd.Function.
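A common autodiff use case for these transforms, per-sample gradients, takes a few lines with torch.func. The sketch below assumes a toy linear model and mean-squared-error loss: functional_call runs the module with an explicit parameter dictionary, grad differentiates the single-sample loss with respect to that dictionary, and vmap maps it over the batch dimension of the inputs while sharing the parameters.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
params = {name: p.detach() for name, p in model.named_parameters()}


def sample_loss(params, x, y):
    # x: (3,), y: (1,); add a batch dimension of 1 for the module call
    pred = functional_call(model, params, (x.unsqueeze(0),))
    return F.mse_loss(pred, y.unsqueeze(0))


x = torch.randn(5, 3)
y = torch.randn(5, 1)
# in_dims: share params (None), map over dim 0 of x and y
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)
print(per_sample_grads["weight"].shape)  # torch.Size([5, 1, 3])
```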

[Beta] Dispatchable Collectives

Dispatchable collectives is an improvement to the existing init_process_group() API which changes backend to an optional argument. For users, the main advantage of this feature is that it will allow them to write code that can run on both GPU and CPU machines without having to change the backend specification. The dispatchability feature will also make it easier for users to support both GPU and CPU collectives, as they will no longer need to specify the backend manually (e.g. “NCCL” or “GLOO”). Existing backend specifications by users will be honored and will not require change.

Usage example:

import torch.distributed as dist
# old
dist.init_process_group(backend="nccl", ...)
dist.all_reduce(...) # with CUDA tensors works
dist.all_reduce(...) # with CPU tensors does not work

# new
dist.init_process_group(...) # backend is optional
dist.all_reduce(...) # with CUDA tensors works
dist.all_reduce(...) # with CPU tensors works

Learn more here.

[Beta] torch.set_default_device and torch.device as context manager

torch.set_default_device allows users to change the default device that factory functions in PyTorch allocate on. For example, if you call torch.set_default_device('cuda'), a call to torch.empty(2) will allocate on CUDA (rather than on CPU). You can also use torch.device as a context manager to change the default device on a local basis. This resolves a long-standing feature request, dating back to PyTorch’s initial release, for a way to do this.
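A small sketch of both mechanisms. To keep it runnable without a GPU it uses PyTorch's meta device, which records shapes and dtypes without allocating storage; on a CUDA machine, "cuda" can be substituted. Resetting the default with torch.set_default_device("cpu") afterwards is our choice to keep the rest of a program unaffected.

```python
import torch

# Local change: factory calls inside the block (including parameter creation
# in nn.Module constructors) allocate on the given device
with torch.device("meta"):
    big_layer = torch.nn.Linear(10_000, 10_000)
print(big_layer.weight.device)  # meta: no 400 MB allocation took place

# Global change: affects every subsequent factory call
torch.set_default_device("meta")
t = torch.empty(2)
print(t.device)  # meta
torch.set_default_device("cpu")  # restore the usual default
print(torch.empty(2).device)  # cpu
```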

Learn more here.

[Beta] “X86” as the new default quantization backend for x86 CPU

The new X86 quantization backend, which utilizes FBGEMM and oneDNN kernel libraries, replaces FBGEMM as the default quantization backend for x86 CPU platforms and offers improved int8 inference performance compared to the original FBGEMM backend, leveraging the strengths of both libraries, with 1.3X – 2X inference performance speedup measured on 40+ deep learning models. The new backend is functionally compatible with the original FBGEMM backend.

Table: Geomean Speedup of X86 Quantization Backend vs. FBGEMM Backend on Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz

  • 1 core/instance: 1.76X
  • 2 cores/instance: 1.80X
  • 4 cores/instance: 2.04X
  • 1 socket (32 cores)/instance: 1.34X

By default, users on x86 platforms will utilize the x86 quantization backend and their PyTorch programs will remain unchanged when using the default backend. Alternatively, users have the option to specify “X86” as the quantization backend explicitly. Example code is shown below:

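import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx
# get default configuration (the "x86" backend is the default on x86 CPUs)
qconfig_mapping = get_default_qconfig_mapping()
# or explicitly specify the backend
# qengine = 'x86'
# torch.backends.quantized.engine = qengine
# qconfig_mapping = get_default_qconfig_mapping(qengine)
# construct fp32 model (a small toy model stands in for a real one here)
model_fp32 = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
x = torch.randn(1, 3, 32, 32)
# prepare: insert observers; example_inputs must be a tuple
prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs=(x,))
# calibrate: run representative inputs through the observed model
with torch.no_grad():
    for _ in range(4):
        prepared_model(torch.randn(1, 3, 32, 32))
# convert to a quantized int8 model
quantized_model = convert_fx(prepared_model)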

Find more information: and

[Beta] GNN inference and training optimization on CPU

PyTorch 2.0 includes several critical optimizations to improve GNN inference and training performance on CPU. Before 2.0, GNN models of PyG suffered from low efficiency on CPU due to the lack of performance tuning for several critical kernels (scatter/gather, etc.) and the lack of GNN-related sparse matrix multiplication ops. Specifically, the optimizations include:

  • scatter_reduce: performance hotspot in Message Passing when the edge index is stored in Coordinate format (COO).
  • gather: backward of scatter_reduce, specially tuned for the GNN compute when the index is an expanded tensor.
  • with reduce flag: performance hotspot in Message Passing when the edge index is stored in Compressed Sparse Row (CSR). Supported reduce flag of: sum, mean, amax, amin.
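To illustrate the first two hotspots, here is a toy message-passing aggregation in plain PyTorch, with a made-up four-node graph whose edges are stored in COO format: indexing with the source nodes gathers one message per edge, and Tensor.scatter_reduce combines the messages that arrive at each target node. This is a sketch of the access pattern, not PyG's implementation.

```python
import torch

# Edges as (source, target) pairs in COO format: 0->1, 1->2, 2->1, 3->2, 0->3
edge_index = torch.tensor([[0, 1, 2, 3, 0],
                           [1, 2, 1, 2, 3]])
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])  # one feature per node


def aggregate(x, edge_index, reduce):
    src, dst = edge_index
    messages = x[src]  # gather: one message per edge
    index = dst.unsqueeze(-1).expand_as(messages)
    # include_self=False: nodes with no incoming edges keep the initial 0
    return torch.zeros_like(x).scatter_reduce(
        0, index, messages, reduce=reduce, include_self=False
    )


print(aggregate(x, edge_index, "sum").flatten().tolist())  # [0.0, 4.0, 6.0, 1.0]
print(aggregate(x, edge_index, "amax").flatten().tolist())  # [0.0, 3.0, 4.0, 1.0]
```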

On PyG benchmarks/examples, OGB benchmarks, a 1.12x – 4.07x performance speedup is measured (1.13.1 compared with 2.0) for single node inference and training.

Speedup ratio by Model-Dataset and Option:

GCN-Reddit (inference)
  • 512-2-64-dense: 1.22x
  • 1024-3-128-dense: 1.25x
  • 512-2-64-sparse: 1.31x
  • 1024-3-128-sparse: 1.68x
GraphSage-ogbn-products (inference)
  • 512-2-64-dense: 1.22x
  • 1024-3-128-dense: 1.15x
  • 512-2-64-sparse: 1.20x
  • 1024-3-128-sparse: 1.33x
  • full-batch-sparse: 4.07x
GCN-PROTEINS (training)
  • 3-32: 1.67x
GCN-REDDIT-BINARY (training)
  • 3-32: 1.67x
GCN-Reddit (training)
  • 512-2-64-dense: 1.20x
  • 1024-3-128-dense: 1.12x

Learn more: PyG CPU Performance Optimization.

[Beta] Accelerating inference on CPU with PyTorch by leveraging oneDNN Graph

oneDNN Graph API extends oneDNN with a flexible graph API to maximize the optimization opportunity for generating efficient code on AI hardware.

  • It automatically identifies the graph partitions to be accelerated via fusion.
  • The fusion patterns focus on fusing compute-intensive operations such as convolution, matmul and their neighbor operations for both inference and training use cases.
  • Although work is ongoing to integrate oneDNN Graph with TorchDynamo as well, its integration with the PyTorch JIT Fuser attained beta status in PyTorch 2.0 for Float32 & BFloat16 inference (on machines that support AVX512_BF16 ISA).

From a developer’s/researcher’s perspective, the usage is quite simple & intuitive, with the only change in code being an API invocation:

  • To leverage oneDNN Graph with JIT tracing, the model is profiled with an example input.
  • The context manager with torch.jit.fuser("fuser3"): can also be used instead of invoking torch.jit.enable_onednn_fusion(True).
  • For accelerating BFloat16 inference, we rely on eager-mode AMP (Automatic Mixed Precision) support in PyTorch & disable JIT mode’s AMP, as both of them are currently divergent:
import torch
# Assuming we have a model of the name 'model'
example_input = torch.rand(1, 3, 224, 224)
# enable oneDNN Graph
torch.jit.enable_onednn_fusion(True)
# Disable AMP for JIT (private API)
torch._C._jit_set_autocast_mode(False)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # 2 warm-ups (2 for tracing/scripting with an example, 3 without an example)
    model(example_input)
    model(example_input)
    # speedup would be observed in subsequent runs.
    model(example_input)

Learn more here.

Prototype Features

Distributed API

[Prototype] DTensor

PyTorch DistributedTensor (DTensor) is a prototyping effort with distributed tensor primitives to allow easier distributed computation authoring in the SPMD (Single Program Multiple Devices) paradigm. The primitives are simple but powerful when used to express tensor distributions with both sharded and replicated parallelism strategies. PyTorch DTensor underpins PyTorch Tensor Parallelism along with other advanced parallelism explorations. In addition, it offers a uniform way to save/load state_dict for distributed checkpointing purposes, even when there are complex tensor distribution strategies such as combining tensor parallelism with parameter sharding in FSDP. More details can be found in this RFC and the DTensor examples notebook.

[Prototype] TensorParallel

We now support DTensor-based Tensor Parallel, with which users can distribute their model parameters across different GPU devices. We also support Pairwise Parallel, which shards two concatenated linear layers in a col-wise and row-wise style separately so that only one collective (all-reduce/reduce-scatter) is needed in the end. More details can be found in this example.

[Prototype] 2D Parallel

We implemented the integration of the aforementioned TP with FullyShardedDataParallel(FSDP) as 2D parallel to further scale large model training. More details can be found in this slide and code example.

[Prototype] torch.compile(dynamic=True)

Experimental support for PT2 compilation with dynamic shapes is available in this release. Inference compilation with inductor for simple models is supported, but there are a lot of limitations:

  • Training available in a future release (This is partially fixed in nightlies!)
  • Minifier available in a future release.
  • It is easy to end up in a situation where the dimension you wanted to be dynamic gets specialized anyway. Some of these issues are fixed in nightlies, others are not.
  • We do not appropriately propagate Inductor guards to the top-level, this is tracked at #96296.
  • Data-dependent operations like nonzero still require a graph break.
  • Dynamic does not work with non-standard modes like reduce-overhead or max-autotune.
  • There are many bugs in Inductor compilation. To track known bugs, check the dynamic shapes label on the PyTorch issue tracker.

For the latest and greatest news about dynamic shapes support on master, check out our status reports.

Highlights/Performance Improvements

Deprecation of CUDA 11.6 and Python 3.7 support for PyTorch 2.0

If you are still using or depending on CUDA 11.6 or Python 3.7 builds, we strongly recommend moving to at least CUDA 11.7 and Python 3.8, as they will be the minimum versions required for PyTorch 2.0. For more detail, please refer to the Release Compatibility Matrix for PyTorch releases.

Python 3.11 support on Anaconda Platform

Due to the lack of Python 3.11 support on the Anaconda platform for packages that PyTorch depends on, including NumPy, SciPy, SymPy, Pillow and others, we will not be releasing Conda binaries compiled with Python 3.11 for PyTorch Release 2.0. Pip packages with Python 3.11 support will be released, so if you intend to use PyTorch 2.0 with Python 3.11, please use our Pip packages. Please note: Conda packages with Python 3.11 support will be made available on our nightly channel. We also plan to release Conda Python 3.11 binaries as part of a future release once Anaconda provides these key dependencies. More information and instructions on how to download the Pip packages can be found here.

Optimized PyTorch Inference with AWS Graviton processors

The optimizations focused on four key areas: GEMM kernels, bfloat16 support, primitive caching and the memory allocator. For aarch64 platforms, PyTorch supports Arm Compute Library (ACL) GEMM kernels via the Mkldnn (oneDNN) backend. The ACL library provides Neon/SVE GEMM kernels for fp32 and bfloat16 formats. The bfloat16 support on C7g allows efficient deployment of bfloat16-trained, AMP (Automatic Mixed Precision)-trained, or even standard fp32-trained models. The standard fp32 models leverage bfloat16 kernels via oneDNN fast math mode, without any model quantization. Next, we implemented primitive caching for conv, matmul and inner product operators. More information on the updated PyTorch user guide with the upcoming 2.0 release improvements and TorchBench benchmark details can be found here.


New Library Updates in PyTorch 2.0


We are bringing a number of improvements to the current PyTorch libraries, alongside the PyTorch 2.0 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch.

Along with 2.0, we are also releasing a series of beta updates to the PyTorch domain libraries, including those that are in-tree, and separate libraries including TorchAudio, TorchVision, and TorchText. An update for TorchX is also being released as it moves to community supported mode. Please find the list of the latest stable versions and updates below.

Latest Stable Library Versions (Full List)

  • TorchArrow 0.1.0
  • TorchRec 0.4.0
  • TorchVision 0.15
  • TorchAudio 2.0
  • TorchServe 0.7.1
  • TorchX 0.4.0
  • TorchData 0.6.0
  • TorchText 0.15.0
  • PyTorch on XLA Devices 1.14

*To see prior versions or (unstable) nightlies, click on versions in the top left menu above ‘Search Docs’.


[Beta] Data augmentation operators

The release adds several data augmentation operators under torchaudio.functional and torchaudio.transforms:

  • torchaudio.functional.add_noise
  • torchaudio.functional.convolve
  • torchaudio.functional.deemphasis
  • torchaudio.functional.fftconvolve
  • torchaudio.functional.preemphasis
  • torchaudio.functional.speed
  • torchaudio.transforms.AddNoise
  • torchaudio.transforms.Convolve
  • torchaudio.transforms.Deemphasis
  • torchaudio.transforms.FFTConvolve
  • torchaudio.transforms.Preemphasis
  • torchaudio.transforms.Speed
  • torchaudio.transforms.SpeedPerturbation

The operators can be used to synthetically diversify training data to improve the generalizability of downstream models.

For usage details, please refer to the functional and transform documentation and Audio Data Augmentation tutorial.
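The idea behind adding noise at a target signal-to-noise ratio can be sketched in plain PyTorch. This is our own illustration of the math, not torchaudio's code: the noise is rescaled so that 10·log10(P_signal / P_noise) equals the requested SNR in decibels, with power measured as the mean squared amplitude. For real pipelines, torchaudio.functional.add_noise and torchaudio.transforms.AddNoise are the supported entry points; check their documentation for the exact arguments.

```python
import torch


def add_noise_at_snr(waveform, noise, snr_db):
    """Mix `noise` into `waveform` so the result has the requested SNR in dB."""
    signal_power = waveform.pow(2).mean(dim=-1, keepdim=True)
    noise_power = noise.pow(2).mean(dim=-1, keepdim=True)
    # Choose scale so that signal_power / (scale**2 * noise_power) == 10**(snr_db / 10)
    scale = torch.sqrt(signal_power / (noise_power * 10 ** (snr_db / 10)))
    return waveform + scale * noise


torch.manual_seed(0)
clean = torch.randn(1, 16000)  # stand-in for one second of 16 kHz audio
noise = torch.randn(1, 16000)
noisy = add_noise_at_snr(clean, noise, snr_db=10.0)

added = noisy - clean
achieved_db = 10 * torch.log10(clean.pow(2).mean() / added.pow(2).mean())
print(f"achieved SNR: {achieved_db.item():.2f} dB")  # close to 10 dB
```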

[Beta] WavLM and XLS-R models

The release adds two self-supervised learning models for speech and audio.

  • WavLM that is robust to noise and reverberation.
  • XLS-R that is trained on cross-lingual datasets.

Besides the model architectures, torchaudio also supports corresponding pre-trained pipelines:

  • torchaudio.pipelines.WAVLM_BASE
  • torchaudio.pipelines.WAVLM_BASE_PLUS
  • torchaudio.pipelines.WAVLM_LARGE
  • torchaudio.pipelines.WAV2VEC_XLSR_300M
  • torchaudio.pipelines.WAV2VEC_XLSR_1B
  • torchaudio.pipelines.WAV2VEC_XLSR_2B

For usage details, please refer to the factory function and pre-trained pipelines documentation.


The initial release of torchrl includes several features that span across the entire RL domain. TorchRL can already be used in online, offline, multi-agent, multi-task and distributed RL settings, among others. See below:

[Beta] Environment wrappers and transforms

torchrl.envs includes several wrappers around common environment libraries. This allows users to swap one library with another without effort. These wrappers build an interface between these simulators and torchrl:

  • dm_control
  • Gym
  • Brax
  • EnvPool
  • Jumanji
  • Habitat

It also comes with many commonly used transforms and vectorized environment utilities that allow for a fast execution across simulation libraries. Please refer to the documentation for more detail.

[Beta] Datacollectors

Data collection in RL is made easy via the usage of single process or multiprocessed/distributed data collectors that execute the policy in the environment over a desired duration and deliver samples according to the user’s needs. These can be found in torchrl.collectors and are documented here.

[Beta] Objective modules

Several objective functions are included in torchrl.objectives, among which:

  • A generic PPOLoss class and derived ClipPPOLoss and KLPPOLoss
  • SACLoss and DiscreteSACLoss
  • DDPGLoss
  • DQNLoss
  • REDQLoss
  • A2CLoss
  • TD3Loss
  • ReinforceLoss
  • Dreamer

Vectorized value function operators also appear in the library. Check the documentation here.

[Beta] Models and exploration strategies

We provide multiple models, modules and exploration strategies. Get a detailed description in the doc.

[Beta] Composable replay buffer

A composable replay buffer class is provided that can be used to store data in multiple contexts, including single- and multi-agent, on- and off-policy, and many more. Components include:

  • Storages (list, physical or memory-based contiguous storages)
  • Samplers (Prioritized, sampler without repetition)
  • Writers
  • Possibility to add transforms

Replay buffers and other data utilities are documented here.
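The composition described above can be sketched in plain Python. Storage, writer and sampler are separate objects that the buffer wires together, so swapping one of them (for example, a prioritized sampler instead of the sampler without repetition) leaves the others untouched. This is a toy illustration: the class names only loosely echo torchrl's components, and the code is neither torchrl's API nor its implementation.

```python
import random
from typing import Any, Callable, List, Optional


class ListStorage:
    """Fixed-capacity storage backed by a Python list."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.data: List[Any] = []

    def set(self, index: int, item: Any) -> None:
        if index < len(self.data):
            self.data[index] = item  # overwrite an existing slot
        else:
            self.data.append(item)

    def get(self, indices: List[int]) -> List[Any]:
        return [self.data[i] for i in indices]

    def __len__(self) -> int:
        return len(self.data)


class RoundRobinWriter:
    """Writes to the next slot and wraps around, overwriting the oldest items once full."""

    def __init__(self):
        self.cursor = 0

    def add(self, storage: ListStorage, item: Any) -> None:
        storage.set(self.cursor, item)
        self.cursor = (self.cursor + 1) % storage.max_size


class SamplerWithoutReplacement:
    """Returns every stored index once per pass before any index repeats."""

    def __init__(self):
        self.pending: List[int] = []

    def sample(self, storage: ListStorage, batch_size: int) -> List[int]:
        if len(storage) == 0:
            raise ValueError("cannot sample from an empty storage")
        out = []
        while len(out) < batch_size:
            if not self.pending:
                self.pending = random.sample(range(len(storage)), len(storage))
            out.append(self.pending.pop())
        return out


class ReplayBuffer:
    """Composes a storage, a writer, a sampler and an optional transform."""

    def __init__(self, storage, writer, sampler, transform: Optional[Callable[[Any], Any]] = None):
        self.storage, self.writer, self.sampler, self.transform = storage, writer, sampler, transform

    def extend(self, items) -> None:
        for item in items:
            self.writer.add(self.storage, item)

    def sample(self, batch_size: int) -> List[Any]:
        batch = self.storage.get(self.sampler.sample(self.storage, batch_size))
        return [self.transform(x) for x in batch] if self.transform else batch


rb = ReplayBuffer(ListStorage(max_size=4), RoundRobinWriter(), SamplerWithoutReplacement(), transform=lambda x: x * 10)
rb.extend(range(6))  # items 4 and 5 overwrite the oldest items 0 and 1
print(rb.storage.data)       # [4, 5, 2, 3]
print(sorted(rb.sample(4)))  # [20, 30, 40, 50]
```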

[Beta] Logging tools and trainer

We support multiple logging tools including tensorboard, wandb and mlflow.

We provide a generic Trainer class that allows for easy code recycling and checkpointing.

These features are documented here.


TensorDict is a new data carrier for PyTorch.

[Beta] TensorDict: specialized dictionary for PyTorch

TensorDict allows you to execute many common operations across batches of tensors carried by a single container. TensorDict supports many shape and device or storage operations, and can readily be used in distributed settings. Check the documentation to know more.
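As an illustration of one container carrying a batch of tensors, here is a deliberately tiny, hypothetical MiniTensorDict written in plain PyTorch. It is not the tensordict library's API, which offers much more. It only shows two things: a single index or device move is applied to every entry, and the shared leading batch size is checked.

```python
import torch


class MiniTensorDict:
    """Hypothetical toy container: named tensors that share leading batch dimensions."""

    def __init__(self, tensors, batch_size):
        self.batch_size = torch.Size(batch_size)
        for key, t in tensors.items():
            if t.shape[: len(self.batch_size)] != self.batch_size:
                raise ValueError(f"{key!r} has shape {tuple(t.shape)}, expected batch size {tuple(self.batch_size)}")
        self.tensors = dict(tensors)

    def __getitem__(self, index):
        if isinstance(index, str):
            return self.tensors[index]
        # One index is applied to every entry along the batch dimensions
        new_batch_size = torch.empty(self.batch_size)[index].shape
        return MiniTensorDict({k: t[index] for k, t in self.tensors.items()}, new_batch_size)

    def apply(self, fn):
        # fn must preserve the leading batch dimensions
        return MiniTensorDict({k: fn(t) for k, t in self.tensors.items()}, self.batch_size)

    def to(self, device):
        return self.apply(lambda t: t.to(device))


td = MiniTensorDict({"obs": torch.randn(8, 4), "reward": torch.zeros(8, 1)}, batch_size=[8])
half = td[:4].to("cpu")
print(half.batch_size, half["obs"].shape, half["reward"].shape)
```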

[Beta] @tensorclass: a dataclass for PyTorch

Like TensorDict, tensorclass provides the opportunity to write dataclasses with built-in torch features such as shape or device operations.

[Beta] tensordict.nn: specialized modules for TensorDict

The tensordict.nn module provides specialized nn.Module subclasses that make it easy to build arbitrarily complex graphs that can be executed with TensorDict inputs. It is compatible with the latest PyTorch features such as functorch, torch.fx and torch.compile.


[Beta] KeyedJaggedTensor All-to-All Redesign and Input Dist Fusion

We observed a performance regression caused by a bottleneck in sparse data distribution for models that have multiple large KJTs to redistribute.

To combat this, we altered the comms pattern so that an initial collective transports the minimum data required to support the collective calls for the actual KJT tensor data. Sending this data (the ‘splits’) in an initial collective means more data is transmitted over the comms stream overall, but the CPU is blocked for significantly shorter periods, leading to better overall QPS.

Furthermore, we altered the TorchRec train pipeline to group the initial collective calls for the splits together before launching the more expensive KJT tensor collective calls. This fusion minimizes the CPU blocked time as launching each subsequent input distribution is no longer dependent on the previous input distribution.
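The two-phase idea can be simulated in a single process. Every rank first exchanges only the lengths ("splits") for all KJTs, and only then are the larger value exchanges launched. This is a toy sketch with hypothetical function names. Real TorchRec runs these steps as asynchronous torch.distributed collectives, and the CPU-blocking benefit cannot be observed in a sequential simulation like this one.

```python
from typing import List


def all_to_all(send: List[list]) -> List[list]:
    """Simulated all-to-all: send[src][dst] is delivered as recv[dst][src]."""
    world = len(send)
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def fused_input_dist(kjts: List[List[List[list]]]):
    """kjts[i][src][dst]: the variable-length values rank `src` sends to rank `dst` for KJT i."""
    # Phase 1 for every KJT first: exchange only the lengths ("splits"), a small message
    splits = [all_to_all([[len(chunk) for chunk in row] for row in kjt]) for kjt in kjts]
    # Phase 2: exchange the actual values; receivers already know how much to expect
    values = [all_to_all(kjt) for kjt in kjts]
    for kjt_splits, kjt_values in zip(splits, values):
        for dst, row in enumerate(kjt_values):
            assert [len(chunk) for chunk in row] == kjt_splits[dst]
    return splits, values


# Two ranks, two KJTs
kjt_a = [[[1], [2, 2]], [[3, 3, 3], []]]
kjt_b = [[[], [4]], [[5], [6, 6]]]
splits, values = fused_input_dist([kjt_a, kjt_b])
print(splits[0])  # [[1, 3], [2, 0]]
print(values[0])  # [[[1], [3, 3, 3]], [[2, 2], []]]
```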

With this feature, variable batch sizes are now natively supported across ranks. These features are documented here.


[Beta] Extending TorchVision’s Transforms to Object Detection, Segmentation & Video tasks

TorchVision is extending its Transforms API! Here is what’s new:

  • You can use them not only for Image Classification but also for Object Detection, Instance & Semantic Segmentation and Video Classification.
  • You can use new functional transforms for transforming Videos, Bounding Boxes and Segmentation Masks.

Learn more about these new transforms from our docs, and submit any feedback in our dedicated issue.
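Here is why detection tasks need transforms that understand more than images: a geometric transform has to update the image and its bounding boxes together. The sketch below is a hand-written horizontal flip in plain PyTorch. It assumes boxes in XYXY pixel coordinates, and it is not the torchvision transforms API.

```python
import torch


def hflip_image_and_boxes(image: torch.Tensor, boxes: torch.Tensor):
    """Hypothetical joint transform: flip a (C, H, W) image and its (N, 4) XYXY pixel boxes."""
    width = image.shape[-1]
    flipped_image = image.flip(-1)  # mirror along the width axis
    x1, y1, x2, y2 = boxes.unbind(-1)
    # Mirror x coordinates and swap them so that x1 <= x2 still holds
    flipped_boxes = torch.stack([width - x2, y1, width - x1, y2], dim=-1)
    return flipped_image, flipped_boxes


image = torch.arange(30.0).reshape(1, 3, 10)  # C=1, H=3, W=10
boxes = torch.tensor([[1.0, 0.0, 3.0, 2.0]])
flipped_image, flipped_boxes = hflip_image_and_boxes(image, boxes)
print(flipped_boxes)  # tensor([[7., 0., 9., 2.]])
```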


[Beta] Adding scriptable T5 and Flan-T5 to the TorchText library with incremental decoding support!

TorchText has added the T5 model architecture with pre-trained weights for both the original T5 paper and Flan-T5. The model is fully torchscriptable and features an optimized multiheaded attention implementation. We include several examples of how to utilize the model including summarization, classification, and translation.

For more details, please refer to our docs.


TorchX is moving to community supported mode. More details will be coming in at a later time.


Democratizing AI with PyTorch Foundation and ROCm™ support for PyTorch


AMD Founding Member

Last year, Meta announced that PyTorch joined the Linux Foundation as a neutral home for growing the machine learning project and community with AMD representation as a part of the founding membership and governing board.

PyTorch Foundation’s mission is to drive AI adoption by democratizing its software ecosystem through open source principles aligning with the AMD core principle of an Open software ecosystem. AMD strives to foster innovation through the support for latest generations of hardware, tools, libraries, and other components to simplify and accelerate adoption of AI across a broad range of scientific discoveries.

AMD, along with key PyTorch codebase developers (including those at Meta AI), delivered a set of updates to the ROCm™ open software ecosystem that brings stable support for AMD Instinct™ accelerators as well as many Radeon™ GPUs. This now gives PyTorch developers the ability to build their next great AI solutions leveraging AMD GPU accelerators & ROCm. The support from the PyTorch community in identifying gaps, prioritizing key updates, providing feedback for performance optimization and supporting our journey from “Beta” to “Stable” was immensely helpful, and we deeply appreciate the strong collaboration between the AMD and PyTorch teams. The move of ROCm support from “Beta” to “Stable” came in the PyTorch 1.12 release (June 2022) and makes it easy to run PyTorch in a native environment without having to configure custom Docker containers. This is a sign of confidence in the quality of support and performance of PyTorch using AMD Instinct and ROCm. The results of these collaborative efforts are evident in the performance measured on key industry benchmarks like Microsoft’s SuperBench, shown below in Graph 1.

“We are excited to see the significant impact of developers at AMD to contribute to and extend features within PyTorch to make AI models run in a more performant, efficient, and scalable way. A great example of this is the thought-leadership around unified memory approaches between the framework and future hardware systems, and we look forward to seeing that feature progress.”

– Soumith Chintala, PyTorch lead-maintainer and Director of Engineering, Meta AI

The progressive improvements in the AMD CDNA™ architecture, as well as in ROCm and PyTorch, show single-GPU model throughput increasing from the AMD Instinct MI100 to the latest-generation AMD Instinct MI200 family GPUs, going from ROCm 4.2 to ROCm 5.3 and from PyTorch 1.7 to PyTorch 1.12.


Graph 1: ML model performance over generation using Microsoft Superbench Suite 1, 2, 3

Below are a few of the key updates for ROCm support since the PyTorch 1.12 release:

Full Continuous Integration (CI) for ROCm on PyTorch

With ROCm support for PyTorch moving from “Beta” to “Stable,” all function and feature commits are now verified through a full Continuous Integration (CI) process. The CI process helps ensure the proper build and test process ahead of an expected Docker and PIP wheel release with stable commits forthcoming.

Support for Kineto Profiler

The addition of Kineto profiler support to ROCm now helps developers and users understand performance bottlenecks through effective diagnosis and profiling tools. The tool also provides recommendations to improve known issues and visualization through the TensorBoard UI.
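Kineto is the engine behind torch.profiler. The minimal sketch below profiles a matrix multiply on CPU, so it runs on any build. It does not check whether GPU kernels are captured on a given ROCm setup; on a GPU machine one would typically also pass ProfilerActivity.CUDA. For the TensorBoard view mentioned above, torch.profiler.tensorboard_trace_handler can be passed as on_trace_ready.

```python
import torch
from torch.profiler import ProfilerActivity, profile

a, b = torch.randn(256, 256), torch.randn(256, 256)

# CPU-only so the sketch runs anywhere
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        c = a @ b

# Aggregate the recorded operators and show the most expensive ones
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```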

Key PyTorch Libraries support added

PyTorch ecosystem libraries like TorchText (Text classification), TorchRec (libraries for recommender systems – RecSys), TorchVision (Computer Vision), TorchAudio (audio and signal processing) are fully supported since ROCm 5.1 and upstreamed with PyTorch 1.12.

Key libraries provided with the ROCm software stack including MIOpen (Convolution models), RCCL (ROCm Collective Communications) and rocBLAS (BLAS for transformers) were further optimized to offer new potential efficiencies and higher performance.

MIOpen innovates on several fronts, such as implementing fusion to optimize for memory bandwidth and GPU launch overheads, providing an auto-tuning infrastructure to overcome the large design space of problem configurations, and implementing different algorithms to optimize convolutions for different filter and input sizes. MIOpen is one of the first libraries to publicly support the bfloat16 data-type for convolutions, allowing efficient training at lower precision maintaining expected accuracy.

RCCL (pronounced “Rickle”) is a stand-alone library of standard collective communication routines for GPUs, implementing all-reduce, all-gather, reduce, broadcast, reduce-scatter, gather, scatter, and all-to-all. There is support for direct GPU-to-GPU send and receive operations. It has been optimized to achieve high bandwidth on platforms using PCIe®, Infinity Fabric™ (GPU to GPU) as well as networking using InfiniBand Verbs or TCP/IP sockets. RCCL supports an arbitrary number of GPUs installed in single or multiple nodes and can be used in either single- or multi-process (e.g., MPI) applications.

Along with the above key highlights, over 50 features and functionality improvements were completed jointly between AMD and PyTorch to add stable support for ROCm. These include improvements to tools, compilers, runtime, graph optimizations through TorchScript, INT8 quant path usage, and ONNX Runtime integration, including support for the Navi 21-based Radeon™ PRO datacenter graphics card, to name a few.

AITemplate Inference Engine

Meta AI recently published a blog announcing the release of its open source AITemplate, a unified inference system supporting AMD Instinct GPU accelerators using the AMD ROCm stack. This Python-based framework can help significantly improve performance through increased utilization of AMD matrix cores for transformer blocks. This is achieved through the AMD Composable Kernel (CK) library, which provides performance-critical kernels for ML/AI workloads across multiple architectures, including GPUs and CPUs, through HIP & C++.

Moreover, AITemplate provides out-of-the-box support for widely used AI models like BERT, ResNet, Vision Transformer and Stable Diffusion, simplifying the deployment process through these pretrained models.

What’s coming with future ROCm releases?

Unified memory models for CPU + GPU

As system architecture evolves to address the complexity of large problem sizes and data sets, memory management becomes a key performance bottleneck that needs a cohesive strategy, addressed through innovations at both the hardware and software levels. AMD is uniquely positioned to address this problem with its effective data center solutions integrating AMD EPYC™ CPU cores with its AMD Instinct GPU compute units in a truly unified datacenter APU (Accelerated Processing Unit) form factor set to be launched in 2H 2023.

The software work to leverage the unified CPU + GPU memory has already started in collaboration with the PyTorch team. It aims to enable a fast, low-latency, synchronized memory model that lets not only AMD but also other AI accelerators address today's complex memory management problem. We are looking forward to this joint effort and to an announcement soon.


The content in this blog highlights the joint work between AMD and key PyTorch contributors including Meta, working on many of the core features, as well as Microsoft enabling ONNX Runtime support. We are looking forward to working with the other founding members at the PyTorch Foundation on the next steps and improvements to democratize and grow adoption of PyTorch across the industry.


This blog contains forward-looking statements concerning Advanced Micro Devices, Inc. (AMD) such as the availability, timing and expected benefits of an AMD datacenter APU form factor, which are made pursuant to the Safe Harbor provisions of the Private Securities Litigation Reform Act of 1995. Forward-looking statements are commonly identified by words such as “would,” “may,” “expects,” “believes,” “plans,” “intends,” “projects” and other terms with similar meaning. Investors are cautioned that the forward-looking statements in this blog are based on current beliefs, assumptions and expectations, speak only as of the date of this blog and involve risks and uncertainties that could cause actual results to differ materially from current expectations. Such statements are subject to certain known and unknown risks and uncertainties, many of which are difficult to predict and generally beyond AMD’s control, that could cause actual results and other future events to differ materially from those expressed in, or implied or projected by, the forward-looking information and statements. Investors are urged to review in detail the risks and uncertainties in AMD’s Securities and Exchange Commission filings, including but not limited to AMD’s most recent reports on Forms 10-K and 10-Q. AMD does not assume, and hereby disclaims, any obligation to update forward-looking statements made in this blog, except as may be required by law.


  1. MI100D-01 SuperBench v0.5 model training results based on AMD internal testing as of 11/09/2022 measuring the total training throughput, at half precision, using a 2P AMD EPYC™ 7763 CPU server tested with 1x AMD Instinct™ MI100 (32GB HBM2e) 300W GPU, SBIOS 2.2, Ubuntu® 20.04.5 LTS, host ROCm™ 5.2.0, guest ROCm 4.2, PyTorch 1.7.0. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.
  2. MI200D-01 SuperBench v0.6 model training results based on AMD internal testing as of 11/09/2022 measuring the total training throughput, at half precision, using a 2P AMD EPYC™ 7763 CPU server tested with 1x AMD Instinct™ MI210 (64GB HBM2e) 300W GPU, SBIOS 2.2, Ubuntu 20.04.5 LTS, host ROCm 5.3.0, guest ROCm 5.3, PyTorch 1.12. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.
  3. MI200D-02: SuperBench v0.6 model training results based on AMD internal testing as of 11/09/2022 measuring the total training throughput, at half precision, using a 2P AMD EPYC™ 7763 CPU server tested with 1x AMD Instinct™ MI250 (128GB HBM2e) 560W GPU, SBIOS M12, Ubuntu 20.04 LTS, host ROCm 5.3.0, guest ROCm 5.3, PyTorch 1.12. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.


Deprecation of CUDA 11.6 and Python 3.7 Support

For the upcoming PyTorch 2.0 feature release (target March 2023), we will target CUDA 11.7 as the stable version and CUDA 11.8 as the experimental version of CUDA, and Python >=3.8, <=3.11.

If you are still using or depending on CUDA 11.6 or Python 3.7 builds, we strongly recommend moving to at least CUDA 11.7 and Python 3.8, as these will be the minimum versions required for PyTorch 2.0.

Please note that as of Feb 1, CUDA 11.6 and Python 3.7 are no longer included in the nightlies.

Please refer to the Release Compatibility Matrix for PyTorch releases:

PyTorch Version Python Stable CUDA Experimental CUDA
2.0 >=3.8, <=3.11 CUDA 11.7, CUDNN CUDA 11.8, CUDNN
1.13 >=3.7, <=3.10 CUDA 11.6, CUDNN CUDA 11.7, CUDNN
1.12 >=3.7, <=3.10 CUDA 11.3, CUDNN CUDA 11.6, CUDNN

As of 2/1/2023

For more information on PyTorch releases, updated compatibility matrix and release policies, please see (and bookmark) Readme.
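The Python ranges in the matrix above can be encoded as a small lookup table, for example to fail fast in a setup script. The helper below is hypothetical, and its values are copied from the matrix as of 2/1/2023. The official compatibility matrix remains the authoritative source.

```python
import sys

# Python ranges and CUDA versions copied from the matrix above (as of 2/1/2023)
COMPAT = {
    "2.0": {"python": ((3, 8), (3, 11)), "stable_cuda": "11.7", "experimental_cuda": "11.8"},
    "1.13": {"python": ((3, 7), (3, 10)), "stable_cuda": "11.6", "experimental_cuda": "11.7"},
    "1.12": {"python": ((3, 7), (3, 10)), "stable_cuda": "11.3", "experimental_cuda": "11.6"},
}


def python_supported(release, python_version=None):
    """True if (major, minor) falls within the release's supported Python range."""
    version = tuple(python_version or sys.version_info[:2])
    lowest, highest = COMPAT[release]["python"]
    return lowest <= version <= highest


print(python_supported("2.0", (3, 7)), python_supported("2.0", (3, 11)))  # False True
```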


Performance experiments with Stable Diffusion

This is a companion to the main blog “Accelerated Stable Diffusion with PyTorch 2”, containing detailed information on the benchmarking setup and the results of individual experiments. It is mainly aimed at hands-on readers who want to reproduce or further develop the work we described in the main text. Please see the main text for all the context and the summary of results.

Appendix 1: benchmarked versions definition

Here we define precisely what we mean by “original code” and “optimized code” in the main text.

Original code

Lives on the original-benchmark branch, specifically in this commit. This is almost the same code as in the SD 2.1 release, with minimal modifications necessary for benchmarking. In particular, the code is able to turn off xFormers attention when the environment variable USE_XFORMERS is set to False.

This code uses PyTorch 1.12 and the original custom implementation of attention.

Optimized code

The optimized version is the code living here. It has all the optimizations we mentioned in the main text:

  • nn.MultiheadAttention in CrossAttention instead of custom attention implementation
  • Compilation with torch.compile
  • Other minor optimizations in PyTorch-related code.

The first optimization (using nn.MultiheadAttention in CrossAttention) schematically boils down to the following pseudocode:

class CrossAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # Create matrices: Q, K, V, out_proj
    def forward(self, x, context=None, mask=None):
        # Compute out = softmax(Q K^T / sqrt(d)) V
        # Return out_proj(out)

gets replaced with

class CrossAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.mha = nn.MultiheadAttention(...)
    def forward(self, x, context):
        # nn.MultiheadAttention returns (output, attention weights); keep the output
        return self.mha(x, context, context)[0]

See the full diff here.
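To make the replacement concrete, here is a runnable sketch with three pieces: a simplified custom attention module, an nn.MultiheadAttention-based one, and a helper that copies weights from the former into the latter. That weight copy is the kind of conversion needed to produce a checkpoint like /tmp/model_native_mha.ckpt below. The class and function names are hypothetical. The sketch assumes that the inner attention dimension equals the query dimension and that the custom q/k/v projections have no bias; the actual SD code and conversion script may differ. Passing need_weights=False is our choice: it skips returning the attention weights.

```python
import torch
import torch.nn as nn


class CustomCrossAttention(nn.Module):
    """Simplified hand-written attention in the spirit of the original SD code."""

    def __init__(self, query_dim, context_dim, heads):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(query_dim, query_dim, bias=False)
        self.to_k = nn.Linear(context_dim, query_dim, bias=False)
        self.to_v = nn.Linear(context_dim, query_dim, bias=False)
        self.to_out = nn.Linear(query_dim, query_dim)

    def forward(self, x, context):
        b, n, d = x.shape
        h, dh = self.heads, d // self.heads
        # Project and split into heads: (batch, heads, seq, head_dim)
        q = self.to_q(x).view(b, n, h, dh).transpose(1, 2)
        k = self.to_k(context).view(b, -1, h, dh).transpose(1, 2)
        v = self.to_v(context).view(b, -1, h, dh).transpose(1, 2)
        # Materializes the full attention matrix: softmax(Q K^T / sqrt(dh)) V
        attn = (q @ k.transpose(-2, -1) * dh ** -0.5).softmax(dim=-1)
        return self.to_out((attn @ v).transpose(1, 2).reshape(b, n, d))


class MHACrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim, heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(query_dim, heads, kdim=context_dim, vdim=context_dim, batch_first=True)

    def forward(self, x, context):
        return self.mha(x, context, context, need_weights=False)[0]


def convert(custom, native):
    """Hypothetical conversion: copy projection weights into nn.MultiheadAttention."""
    mha = native.mha
    with torch.no_grad():
        mha.q_proj_weight.copy_(custom.to_q.weight)
        mha.k_proj_weight.copy_(custom.to_k.weight)
        mha.v_proj_weight.copy_(custom.to_v.weight)
        mha.in_proj_bias.zero_()  # the custom q/k/v projections have no bias
        mha.out_proj.weight.copy_(custom.to_out.weight)
        mha.out_proj.bias.copy_(custom.to_out.bias)


torch.manual_seed(0)
custom = CustomCrossAttention(query_dim=64, context_dim=32, heads=4).eval()
native = MHACrossAttention(query_dim=64, context_dim=32, heads=4).eval()
convert(custom, native)
x, context = torch.randn(2, 10, 64), torch.randn(2, 7, 32)
print(torch.allclose(custom(x, context), native(x, context), atol=1e-5))  # True
```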

We have also introduced the following CLI flags:

  • --disable_math, --disable_mem_efficient, --disable_flash to allow turning specific attention backends off
  • --compile to turn on PyTorch compilation

The optimized version uses PyTorch 2.0.0.dev20230111+cu117.

Flags added to both code versions

In both code versions we have added the following CLI options:

  • --skip_first to use a “warm-up” iteration before starting to measure time. See the end of section “Benchmarking setup and results summary” in the main text for why this was necessary.
  • --time_file <FILENAME> to write runtime in seconds in text format to the specified file
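The semantics of these two flags can be sketched as a small timing harness: an optional untimed warm-up call (which is where torch.compile would do its work), then timed iterations, then the elapsed seconds written as plain text. The function and parameter names are hypothetical, and this is not the benchmark script's code. The optional sync callable is our addition. For GPU work one would pass torch.cuda.synchronize, because CUDA kernels run asynchronously.

```python
import tempfile
import time
from pathlib import Path
from typing import Callable, Optional


def benchmark(run_once: Callable[[], None], n_iter: int, skip_first: bool = False,
              time_file: Optional[Path] = None, sync: Optional[Callable[[], None]] = None) -> float:
    """Hypothetical harness mirroring --skip_first and --time_file."""
    if skip_first:
        run_once()  # warm-up (e.g. triggers compilation); not timed
    if sync:
        sync()  # make sure queued GPU work is finished before starting the clock
    start = time.perf_counter()
    for _ in range(n_iter):
        run_once()
    if sync:
        sync()
    elapsed = time.perf_counter() - start
    if time_file is not None:
        Path(time_file).write_text(f"{elapsed}\n")  # runtime in seconds, plain text
    return elapsed


calls = []
with tempfile.TemporaryDirectory() as tmp:
    out_file = Path(tmp) / "time.txt"
    seconds = benchmark(lambda: calls.append(1), n_iter=3, skip_first=True, time_file=out_file)
    recorded = float(out_file.read_text())
print(len(calls), recorded == seconds)  # 4 True: one warm-up call plus three timed calls
```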


Now it should already be clear how to run the 5 configurations mentioned in the main text. For completeness, we provide the commands that can be used to run each of them. This assumes you have:

  • installed dependencies from the original version into conda environment ldm-original
  • installed dependencies from the optimized version into conda environment ldm
  • downloaded model weights into /tmp/model.ckpt
  • converted model weights to the new architecture and saved them into /tmp/model_native_mha.ckpt

(see Colab for a bash script which does that)

Commands for the 5 configurations:

# Run optimized with memory-efficient attention and compilation
conda activate ldm
git checkout optmize-w-compile
python scripts/ --prompt "A photo" --seed 1 --plms --config configs/stable-diffusion/v2-inference_native_mha.yaml --ckpt /tmp/model_native_mha.ckpt --n_iter 2 --n_samples 1 --compile --skip_first

# Run optimized with memory-efficient attention
conda activate ldm
git checkout optmize-w-compile
python stable-diffusion/scripts/ --prompt "A photo" --seed 1 --plms --config stable-diffusion/configs/stable-diffusion/v2-inference_native_mha.yaml --ckpt /tmp/model_native_mha.ckpt --n_iter 2 --n_samples 1 --skip_first

# Run optimized without memory-efficient or flash attention
conda activate ldm
git checkout optmize-w-compile
python stable-diffusion/scripts/ --prompt "A photo" --seed 1 --plms --config stable-diffusion/configs/stable-diffusion/v2-inference_native_mha.yaml --ckpt /tmp/model_native_mha.ckpt --n_iter 2 --n_samples 1 --disable_mem_efficient --disable_flash --skip_first 

# Run original code with xFormers
conda activate ldm-original
git checkout original-benchmark
python stable-diffusion-original/scripts/ --prompt "A photo" --seed 1 --plms --config stable-diffusion-original/configs/stable-diffusion/v2-inference.yaml --ckpt /tmp/model.ckpt --n_iter 2 --n_samples 1 --skip_first

# Run original code without xFormers
conda activate ldm-original
git checkout original-benchmark
USE_XFORMERS=False python stable-diffusion-original/scripts/ --prompt "A photo" --seed 1 --plms --config stable-diffusion-original/configs/stable-diffusion/v2-inference.yaml --ckpt /tmp/model.ckpt --n_iter 2 --n_samples 1 --skip_first

Appendix 2: per-run data

Plots with per-run benchmark data can be found here. Each plot shows all the runs for a particular GPU (P100, V100, T4, A10, A100) and batch size (1, 2, or 4). The bar charts in the main text are obtained from this data by averaging. The file names are self-explanatory; for example, “original_vs_optimized_A10_n_samples_2_n_iter_2_sd2.png” contains runs for the A10 GPU, batch size 2 and number of iterations 2.

Appendix 3: Accelerated Stable Diffusion 1

Before the work on Stable Diffusion 2 described in the main text, we also applied similar optimizations to Stable Diffusion 1 by CompVis. The original implementation of SD1 does not integrate with xFormers yet, so the speedup from simply using the PyTorch optimized attention instead of the custom implementation is significant. It should be noted that the HuggingFace Diffusers port of SD1 allows integration with xFormers. An interesting open question, which we didn't explore, is therefore how the performance of SD1 with PyTorch optimized attention compares to HuggingFace SD1+xFormers.

We benchmarked two versions of SD1, original and optimized:

  • As the original version we took the first SD release, and placed it here with minimal modifications to simplify benchmarking. It uses PyTorch 1.11 and custom implementation of attention.
  • The optimized version is the code living here. It uses nn.MultiheadAttention in CrossAttention and PyTorch 2.0.0.dev20221220+cu117.

Here are the results for different GPU architectures and batch size 2:


Configuration T4 P100 V100 A100
Original SD1 (runtime in s) 70.9 71.5 20.3 14.4
Optimized SD1 (runtime in s) 52.7 (-25.6%) 57.5 (-19.5%) 14.3 (-29.3%) 10.4 (-27.9%)

As for SD2, we used Meta hardware for the P100, V100 and A100 benchmarks. The T4 benchmark was done in Google Colab here.

We didn’t apply compilation to SD1, and so didn’t include a “warm-up” iteration in these benchmarks, as we did for SD2.

Both applying torch.compile to SD1 and benchmarking the HuggingFace version of SD1 with PyTorch 2 optimizations would be great exercises for the reader. Try it and let us know if you get interesting results.


Accelerated Stable Diffusion with PyTorch 2


TL;DR: PyTorch 2.0 nightly offers out-of-the-box performance improvement for Stable Diffusion 2.1 by using the new torch.compile() compiler and optimized implementations of Multihead Attention integrated with PyTorch 2.


Stable Diffusion (SD) is a great example of Generative AI, producing high-quality images from text prompts. However, as with other diffusion-based models, generation is rather slow due to the iterative nature of the sampling process by which the images are produced. This makes it important to optimize the code running inside the sampling loop.

We took SD 2.1 from Stability AI as a starting point and accelerated its text-to-image generation using two optimizations available in PyTorch 2: compilation and a fast attention implementation. Together with a few minor memory processing improvements in the code, these optimizations give up to 49% inference speedup relative to the original SD implementation without xFormers, and 39% inference speedup relative to using SD with xFormers (excluding the compilation time), depending on the GPU architecture and batch size. Importantly, the speedup comes without a need to install xFormers or any other extra dependencies.

The table below shows the improvement in runtime between the original implementation with xFormers installed and our optimized version with PyTorch-integrated memory efficient attention (originally developed for and released in the xFormers library) and PyTorch compilation. The compilation time is excluded.

Runtime improvement in % compared to original+xFormers

See the absolute runtime numbers in section “Benchmarking setup and results summary”

GPU Batch size 1 Batch size 2 Batch size 4
P100 (no compilation) -3.8 0.44 5.47
T4 2.12 10.51 14.2
A10 -2.34 8.99 10.57
V100 18.63 6.39 10.43
A100 38.5 20.33 12.17

One can notice the following:

  • The improvements are significant for powerful GPUs like A100 and V100. For those GPUs, the improvement is most pronounced for batch size 1.
  • For less powerful GPUs we observe smaller speedups (or, in two cases, slight regressions). The batch size trend is reversed here: the improvement is larger for larger batches.

In the following sections we describe the applied optimizations and provide detailed benchmarking data, comparing SD performance with various optimization features on/off.

Specifically, we benchmark 5 configurations, and the plots below compare their absolute performance for different GPUs and batch sizes. For definitions of these configurations, see section “Benchmarking setup and results summary”.

Benchmark of Stable Diffusion 2 versions across GPU architectures, batch size 1

Benchmark of Stable Diffusion 2 versions across GPU architectures, batch size 2

Benchmark of Stable Diffusion 2 versions across GPU architectures, batch size 4

If you prefer looking directly at the code, see the Google Colab which runs the benchmark on T4.


Here we’ll go into more detail about the optimizations introduced into the SD code. At the moment they rely on features only available in the nightlies, so we pinned the PyTorch version to a recent nightly (see here). Once the PyTorch 2.0 release comes out, these optimizations won’t have to rely on nightlies any more.

Optimized Attention

One part of the code which we optimized was the scaled dot-product attention. Attention is known to be a heavy operation: a naive implementation materializes the attention matrix, leading to time and memory complexity quadratic in sequence length. In Stable Diffusion, attention (CrossAttention) appears as part of Transformer blocks in multiple parts of the U-Net. Since the U-Net runs at every sampling step, this becomes a critical point to optimize. In PyTorch 2, an optimized attention implementation is integrated into torch.nn.MultiheadAttention, so we used it to replace the custom attention implementation in CrossAttention.
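Some back-of-the-envelope arithmetic shows why the quadratic term matters for SD: it gives the size of the score matrix a naive implementation keeps in memory. The inputs are illustrative assumptions we chose, not measurements: an 8x VAE downsampling factor, self-attention over the full highest-resolution latent, 5 heads, fp16 scores and batch size 1. The function name is hypothetical.

```python
def attention_matrix_bytes(batch, heads, query_len, key_len, bytes_per_elem=2):
    """Size of the (batch, heads, query_len, key_len) score matrix a naive implementation stores."""
    return batch * heads * query_len * key_len * bytes_per_elem


# Assumptions for illustration: 8x VAE downsampling, 5 heads, fp16 scores, batch size 1
for image_px in (512, 768):
    tokens = (image_px // 8) ** 2  # self-attention over the highest-resolution latent
    mib = attention_matrix_bytes(1, 5, tokens, tokens) / 2**20
    print(f"{image_px}px -> {tokens} tokens -> {mib:.0f} MiB per attention call")
# 512px -> 4096 tokens -> 160 MiB per attention call
# 768px -> 9216 tokens -> 810 MiB per attention call
```

Since 9216 / 4096 = 2.25, the token count grows by 2.25x while the matrix grows by about 5x (2.25²). Fused kernels avoid storing this matrix altogether.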

The optimized implementation of attention was already available in PyTorch 1.13 (see here) and has been widely adopted (see e.g. the HuggingFace transformers library example). In particular, it integrates memory-efficient attention from the xFormers library and flash attention. PyTorch 2.0 expands this to additional attention functions such as cross attention, and adds custom kernels for further acceleration, making it applicable to SD.

Flash attention is available on GPUs with compute capability SM 7.5 or SM 8.x, for example on T4, A10, and A100, which are included in our benchmark (you can check the compute capability of each NVIDIA GPU here). However, in our tests on A100 the memory-efficient attention performed better than flash attention for the particular case of SD, due to the small number of attention heads and small batch size. PyTorch takes this into account and chooses memory-efficient attention over flash attention for SD when both are available (see the logic here). For full control over the attention backends (memory-efficient attention, flash attention, “vanilla math”, or any future ones), power users can enable and disable them manually with the help of the context manager torch.backends.cuda.sdp_kernel.
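Here is a minimal sketch of the context manager in use. Inside the block, only the “vanilla math” backend is allowed, and the previous flags are restored on exit. The tensor shapes are illustrative assumptions, including a 77-token text context for cross-attention. The code runs on CPU, where the flags mainly matter for CUDA kernel selection. It compares the result with an explicit softmax(Q K^T / sqrt(d)) V. torch.backends.cuda.sdp_kernel is the PyTorch 2.0 API referenced in the text. Later releases mark it as deprecated in favor of torch.nn.attention.sdpa_kernel.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 5, 64, 64)  # (batch, heads, query_len, head_dim), illustrative sizes
k = torch.randn(1, 5, 77, 64)  # cross-attention over an assumed 77-token text context
v = torch.randn(1, 5, 77, 64)

# Allow only the "vanilla math" backend inside the block; flags are restored on exit
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    flash_inside = torch.backends.cuda.flash_sdp_enabled()
    out = F.scaled_dot_product_attention(q, k, v)

# Reference: explicit softmax(Q K^T / sqrt(d)) V
ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1) @ v
print(flash_inside, torch.backends.cuda.flash_sdp_enabled(), torch.allclose(out, ref, atol=1e-4))
# expected: False True True
```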


Compilation is a new feature of PyTorch 2.0, enabling significant speedups with a very simple user experience. To invoke the default behavior, simply wrap a PyTorch module or a function into torch.compile:

model = torch.compile(model)

The PyTorch compiler then turns Python code into a set of instructions which can be executed efficiently without Python overhead. The compilation happens dynamically the first time the code is executed. With the default behavior, under the hood PyTorch uses TorchDynamo to compile the code and TorchInductor to further optimize it. See this tutorial for more details.

Although the one-liner above is enough for compilation, certain modifications in the code can squeeze out a larger speedup. In particular, one should avoid so-called graph breaks: places in the code which PyTorch can't compile. Unlike previous PyTorch compilation approaches (like TorchScript), the PyTorch 2 compiler doesn't fail in this case. Instead, it falls back on eager execution, so the code runs, but with reduced performance. We introduced a few minor changes to the SD code to eliminate graph breaks (here and here). See this doc to learn more about graph breaks and how to eliminate them.
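Graph breaks can be observed without a GPU or a C++ toolchain by using a custom torch.compile backend that only counts the FX graphs TorchDynamo hands it and then runs them eagerly. In the sketch below, a Python if on a tensor value splits the function into several graphs, while expressing the same branch with torch.where keeps it in one. This is our own illustration, not taken from the SD changes, and exact graph counts may vary across PyTorch versions.

```python
import torch


def make_counting_backend():
    """Custom torch.compile backend: records each captured FX graph, then runs it eagerly."""
    graphs = []

    def backend(gm: torch.fx.GraphModule, example_inputs):
        graphs.append(gm)
        return gm.forward

    return backend, graphs


def with_graph_break(x):
    y = x * 2
    if y.sum() > 0:  # Python branch on a tensor value: Dynamo cannot trace through it
        return y + 1
    return y - 1


def without_graph_break(x):
    y = x * 2
    return torch.where(y.sum() > 0, y + 1, y - 1)  # the same branch as a tensor op


x = torch.randn(8)
graph_counts = {}
for fn in (with_graph_break, without_graph_break):
    backend, graphs = make_counting_backend()
    compiled = torch.compile(fn, backend=backend)
    assert torch.equal(compiled(x), fn(x))  # results match; only the capture differs
    graph_counts[fn.__name__] = len(graphs)
print(graph_counts)  # more than one graph with the break, a single graph without it
```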

Note that compilation requires GPU compute capability >= SM 7.0 to run in non-eager mode. This covers all GPUs in our benchmarks – T4, V100, A10, A100 – except for P100 (see the full list).
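The two hardware rules stated in this post can be written down as a small lookup: flash attention on SM 7.5 or SM 8.x, and non-eager compilation on SM 7.0 or newer. The compute capabilities below are the published values for these GPUs. The rules themselves reflect PyTorch 2.0 as described here and may differ in later releases, and the helper names are hypothetical. On a real machine, torch.cuda.get_device_capability() returns the (major, minor) tuple to plug in.

```python
# Compute capability (major, minor) of the GPUs used in the benchmark
GPUS = {"P100": (6, 0), "V100": (7, 0), "T4": (7, 5), "A10": (8, 6), "A100": (8, 0)}


def flash_attention_eligible(cc):
    # Rule as stated in the text: SM 7.5 or SM 8.x
    return cc == (7, 5) or cc[0] == 8


def compile_eligible(cc):
    # Rule as stated in the text: SM 7.0 or newer for non-eager compilation
    return cc >= (7, 0)


for name, cc in GPUS.items():
    print(f"{name}: SM {cc[0]}.{cc[1]} flash={flash_attention_eligible(cc)} compile={compile_eligible(cc)}")
```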

Other optimizations

In addition, we have improved efficiency of some memory operations – e.g. creating a tensor on GPU directly rather than creating it on CPU and later moving to GPU (see here and here). The places where such optimizations were necessary were determined by line-profiling and looking at CPU/GPU traces and Flame Graphs.
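Here is a minimal sketch of the allocation pattern described above. It uses random noise of an arbitrary illustrative shape rather than the actual SD tensors. On a machine without CUDA it falls back to CPU, where both patterns behave the same.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shape = (4, 77)

# Pattern to avoid: allocate on the CPU, then copy host -> device
noise_slow = torch.randn(shape).to(device)

# Preferred: allocate directly on the target device, with no intermediate CPU tensor or copy
noise_fast = torch.randn(shape, device=device)
positions = torch.arange(shape[1], device=device).expand(shape)  # same idea for arange, zeros, full, ...
print(noise_fast.device, positions.shape)
```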

Benchmarking setup and results summary

We have two versions of SD code to compare: original and optimized. On top of this, several optimization features (xFormers, PyTorch memory efficient attention, compilation) can be turned on/off. Overall, as mentioned in the introduction, we will be benchmarking 5 configurations:

  • Original code without xFormers
  • Original code with xFormers
  • Optimized code with vanilla math attention backend and no compilation
  • Optimized code with memory-efficient attention backend and no compilation
  • Optimized code with memory-efficient attention backend and compilation

As the original version we took the SD 2.1 release, and placed it here with minimal modifications necessary for benchmarking. It uses PyTorch 1.12 and a custom implementation of attention.

The optimized version is the code living here. It uses nn.MultiheadAttention in CrossAttention and PyTorch 2.0.0.dev20230111+cu117. It also has a few other minor optimizations in PyTorch-related code.

Please see the appendix “Benchmarked versions definition” in the companion page for the precise definition of the 5 configurations and prompts triggering each of them.

The table below shows runtime of each version of the code in seconds, and the percentage improvement compared to the original with xFormers. The compilation time is excluded.
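A note on the sign convention: in these tables, a positive percentage means faster than the “Original with xFormers” row (the SD1 table in the companion page uses the opposite sign). Below is a one-line helper we wrote to reproduce the convention. Recomputing from the rounded runtimes shown can differ from the printed percentage by a few tenths, presumably because the percentages were computed from unrounded averages.

```python
def relative_improvement(reference_s: float, runtime_s: float) -> float:
    """Percent runtime reduction relative to the reference; positive means faster."""
    return (reference_s - runtime_s) / reference_s * 100


# V100, batch size 1: Original with xFormers (8.2s) vs optimized, mem. efficient, no compilation (6.9s)
print(round(relative_improvement(8.2, 6.9), 1))  # 15.9 (the table reports 16.0)
```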

Runtimes for batch size 1. In parentheses: relative improvement with respect to the “Original with xFormers” row; n/a: compilation is not supported on P100.

Configuration | P100 | T4 | A10 | V100 | A100
Original without xFormers | 30.4s (-19.3%) | 29.8s (-77.3%) | 13.0s (-83.9%) | 10.9s (-33.1%) | 8.0s (-19.3%)
Original with xFormers | 25.5s (0.0%) | 16.8s (0.0%) | 7.1s (0.0%) | 8.2s (0.0%) | 6.7s (0.0%)
Optimized with vanilla math attention, no compilation | 27.3s (-7.0%) | 19.9s (-18.7%) | 13.2s (-87.2%) | 7.5s (8.7%) | 5.7s (15.1%)
Optimized with mem. efficient attention, no compilation | 26.5s (-3.8%) | 16.8s (0.2%) | 7.1s (-0.8%) | 6.9s (16.0%) | 5.3s (20.6%)
Optimized with mem. efficient attention and compilation | n/a | 16.4s (2.1%) | 7.2s (-2.3%) | 6.6s (18.6%) | 4.1s (38.5%)
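The percentages appear to follow the formula below (an inference from the table values, not stated by the authors): improvement = (baseline − runtime) / baseline × 100, with the “Original with xFormers” runtime as the baseline, so negative values mean slower. Recomputing from the rounded runtimes shown can differ from the table in the last digit, presumably because the table was computed from unrounded measurements.

```python
def relative_improvement(baseline_s, runtime_s):
    """Percentage improvement of `runtime_s` over `baseline_s`; negative means slower."""
    return (baseline_s - runtime_s) / baseline_s * 100


# A100, batch size 1 (rounded values from the table)
baseline = 6.7  # Original with xFormers
compiled = 4.1  # Optimized with mem. efficient attention and compilation
print(f"{relative_improvement(baseline, compiled):.1f}%")  # 38.8% from rounded inputs; the table reports 38.5%
```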

Runtimes for batch size 2 (relative improvement in parentheses; n/a: compilation is not supported on P100)

Configuration | P100 | T4 | A10 | V100 | A100
Original without xFormers | 58.0s (-21.6%) | 57.6s (-84.0%) | 24.4s (-95.2%) | 18.6s (-63.0%) | 12.0s (-50.6%)
Original with xFormers | 47.7s (0.0%) | 31.3s (0.0%) | 12.5s (0.0%) | 11.4s (0.0%) | 8.0s (0.0%)
Optimized with vanilla math attention, no compilation | 49.3s (-3.5%) | 37.9s (-21.0%) | 17.8s (-42.2%) | 12.7s (-10.7%) | 7.8s (1.8%)
Optimized with mem. efficient attention, no compilation | 47.5s (0.4%) | 31.2s (0.5%) | 12.2s (2.6%) | 11.5s (-0.7%) | 7.0s (12.6%)
Optimized with mem. efficient attention and compilation | n/a | 28.0s (10.5%) | 11.4s (9.0%) | 10.7s (6.4%) | 6.4s (20.3%)

Runtimes for batch size 4 (relative improvement in parentheses; n/a: compilation is not supported on P100)

Configuration | P100 | T4 | A10 | V100 | A100
Original without xFormers | 117.9s (-20.0%) | 112.4s (-81.8%) | 47.2s (-101.7%) | 35.8s (-71.9%) | 22.8s (-78.9%)
Original with xFormers | 98.3s (0.0%) | 61.8s (0.0%) | 23.4s (0.0%) | 20.8s (0.0%) | 12.7s (0.0%)
Optimized with vanilla math attention, no compilation | 101.1s (-2.9%) | 73.0s (-18.0%) | 28.3s (-21.0%) | 23.3s (-11.9%) | 14.5s (-13.9%)
Optimized with mem. efficient attention, no compilation | 92.9s (5.5%) | 61.1s (1.2%) | 23.9s (-1.9%) | 20.8s (-0.1%) | 12.8s (-0.9%)
Optimized with mem. efficient attention and compilation | n/a | 53.1s (14.2%) | 20.9s (10.6%) | 18.6s (10.4%) | 11.2s (12.2%)

To minimize fluctuations and external influence on the performance of the benchmarked code, we ran each version of the code one after another, and then repeated this sequence 10 times: A, B, C, D, E, A, B, … So the results of a typical run would look like the one in the picture below. For results of all runs please see appendix “Per-run data” in the companion page. Note that one shouldn’t rely on comparison of absolute run times between different graphs, but comparison of run times inside one graph is pretty reliable, thanks to our benchmarking setup.
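The interleaved ordering can be sketched as follows. This is a hypothetical helper, not the authors' benchmarking harness: each configuration is a zero-argument callable, and every configuration runs once per round so that slow drifts in machine state (thermals, background load) affect all of them roughly equally.

```python
import time
from collections import defaultdict


def interleaved_benchmark(configs, repeats=10):
    """Run every configuration once per round (A, B, C, ..., A, B, ...) and collect runtimes.

    `configs` maps a label to a zero-argument callable; returns {label: [seconds, ...]}.
    """
    runtimes = defaultdict(list)
    for _ in range(repeats):
        for label, run in configs.items():
            start = time.perf_counter()
            run()
            runtimes[label].append(time.perf_counter() - start)
    return dict(runtimes)


if __name__ == "__main__":
    # Placeholder workloads standing in for the five benchmarked configurations
    configs = {label: (lambda n=n: sum(range(n))) for label, n in zip("ABCDE", range(10_000, 60_000, 10_000))}
    results = interleaved_benchmark(configs, repeats=3)
    for label, times in results.items():
        print(label, f"median {sorted(times)[len(times) // 2]:.6f}s")
```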

Stable Diffusion 2.1 benchmarks

Each run of the script generates several batches, the number of which is regulated by the CLI parameter --n_iter. In the benchmarks we used n_iter = 2, but introduced an additional “warm-up” iteration, which doesn’t contribute to the run time. This was necessary for the runs with compilation, because compilation happens the first time the code runs, so the first iteration is much longer than all subsequent ones. To make the comparison fair, we also introduced this additional “warm-up” iteration to all other runs; it is turned on by the CLI option --skip_first provided to the modified script.
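A minimal sketch of the warm-up logic, assuming a hypothetical `run_batch` callable that generates one batch. It mirrors the intent of `--skip_first` but is not the modified script's code.

```python
import time


def timed_iterations(run_batch, n_iter=2, skip_first=True):
    """Time `n_iter` batches, optionally preceded by one untimed warm-up batch.

    The warm-up call (where compilation would happen) is executed but excluded from the total.
    """
    if skip_first:
        run_batch()  # warm-up: executed, not timed
    start = time.perf_counter()
    for _ in range(n_iter):
        run_batch()
    return time.perf_counter() - start


if __name__ == "__main__":
    calls = []

    def fake_batch():
        # Hypothetical stand-in: the first call is slow, like a run that triggers compilation
        time.sleep(0.5 if not calls else 0.01)
        calls.append(1)

    print(f"measured: {timed_iterations(fake_batch):.2f}s")
```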

The numbers in the tables above are for 2 iterations (plus one “warm-up” iteration), prompt ”A photo”, seed 1, the PLMS sampler, and autocast turned on. See the companion page for the precise CLI commands in appendix “Benchmarked versions definition” and detailed results of individual runs in appendix “Per-run data”.

The P100, V100, and A100 benchmarks were done on Meta internal infrastructure. The T4 benchmarks were done in Google Colab Pro (see the Google Colab notebook). The A10 benchmarks were done on g5.4xlarge AWS instances with 1 GPU.

Conclusions and next steps

We have shown that the new features of PyTorch 2 – the compiler and the optimized attention implementation – give performance improvements comparable to or exceeding those that previously required installing an external dependency (xFormers). PyTorch achieved this, in particular, by integrating memory-efficient attention from xFormers into its codebase. This is a significant improvement for user experience, given that xFormers, being a state-of-the-art library, in many scenarios requires a custom installation process and long builds.

There are a few natural directions in which this work can be continued:

  • There are new implementations of SD, including a port to the Hugging Face Diffusers library. It would be interesting to benchmark against them. Note that Diffusers also requires installing xFormers in order to use memory-efficient attention
  • The optimizations we implemented and described here are only benchmarked for text-to-image inference so far. It would be interesting to see how they affect training. PyTorch compilation can be directly applied to training; enabling training with PyTorch optimized attention is on the roadmap
  • We intentionally minimized changes to the original SD code. Further profiling and optimization can probably bring more improvements
  • At the moment compilation is applied only to the U-Net model inside the sampler. Since there is a lot happening outside of U-Net (e.g. operations directly in the sampling loop), it would be beneficial to compile the whole sampler. However, this would require analysis of the compilation process to avoid recompilation at every sampling step
  • Current code only applies compilation within the PLMS sampler, but it should be trivial to extend it to other samplers
  • Besides text-to-image generation, SD 2.1 has other pipelines – image-to-image and inpainting. It would be interesting to measure how their performance improves from PyTorch 2 optimizations

Try some of this in the Colab or on a GPU of your choice. See if you can further increase the performance of SD, and share the results! This is your chance to get a preview of PyTorch 2.0 and experience the features coming in the next release.

As a note, if you want access to new PyTorch features which come after this post is published, just tweak the PyTorch and TorchVision versions in environment.yaml.



We would like to thank Geeta Chauhan, Natalia Gimelshein, Patrick Labatut, Bert Maher, Mark Saroufim, Michael Voznesensky and Francisco Massa for their valuable advice and early feedback on the text.

Special thanks to Yudong Tao for creating the first version of Stable Diffusion with PyTorch native attention.

For more information, visit this page with additional resources.


PyTorch Trace Analysis for the Masses



We are excited to announce the public release of Holistic Trace Analysis (HTA), an open source performance analysis and visualization Python library for PyTorch users. HTA takes as input Kineto traces collected by the PyTorch profiler, which are complex and challenging to interpret, and up-levels the performance information contained in these traces. It was initially developed internally at Meta to understand and debug performance problems for large-scale distributed training jobs on GPUs. The multidisciplinary team has made a number of enhancements to HTA’s features and scaled them to support state-of-the-art ML workloads.

ML researchers and systems engineers often struggle to computationally scale up their models because they are not aware of the performance bottlenecks in their workloads. The resources requested for a job (e.g. GPUs, memory) are often misaligned with the resources actually required due to lack of visibility “under the hood”. To achieve the best performance from the hardware stack, it is imperative to understand the resource utilization and bottlenecks for distributed training workloads.

The initial HTA implementation was specifically targeted at Deep Learning Based Recommendation Models (DLRM). To make the features in HTA generic and applicable to use cases such as analyzing Vision and NLP models, we decided to refactor the HTA codebase and make the library available to the larger community. This new codebase has implemented several important ideas which lead to significant efficiency and performance improvements.

In this blog, we present several features implemented in the open source version of HTA, which can be used as a Python script as well as interactively in a Jupyter notebook. HTA provides the following features:

  1. Breakdown by Dimensions
    1. Temporal: Breakdown of GPU time in terms of time spent in computation, communication, memory events, and idle time on a single node and across all ranks.
    2. Idle Time: Breakdown of GPU idle time into waiting for the host, waiting for another kernel or attributed to an unknown cause.
    3. Kernel: Find kernels with the longest duration on each rank.
    4. Communication Computation Overlap: Calculate the percentage of time when communication overlaps computation.
  2. Statistical Analysis
    1. Kernel Duration Distribution: Distribution of average time taken by longest kernels across different ranks.
    2. CUDA Kernel Launch: Distributions of GPU kernels with very small duration, large duration, and excessive launch time.
    3. Augmented Counters (Memory bandwidth, Queue length): Augmented trace files which provide insights into memory copy bandwidth and number of outstanding operations on each CUDA stream.
  3. Patterns
    1. Frequent CUDA Kernels: Find the CUDA kernels most frequently launched by any given PyTorch or user defined operator.
  4. Trace Comparison
    1. Trace Diff: A trace comparison tool to identify and visualize the differences between traces.

HTA source code is available to users via GitHub. Users can request new features or build their own analysis using the core libraries and data structures provided in the codebase in addition to the features mentioned above.

GPU Training Performance Debugging 101

To understand the GPU performance in distributed training jobs, we consider how the model operators interact with the GPU devices and how such interactions are reflected in certain measurable metrics.

At a high level, we can break down the GPU operations in a model execution into three broad categories, henceforth referred to as kernel types:

  1. Computation (COMP) – Compute kernels execute compiled routines for matrix multiplication and similar numeric calculations. They are responsible for all of the number-crunching necessary for model execution.
  2. Communication (COMM) – Communication kernels are routines which are responsible for exchanging and synchronizing data between different GPU devices in a distributed training job. The NVIDIA Collective Communication Library (NCCL) is a widely used communication library and all its kernels have the prefix “nccl”. Example NCCL kernels include NCCL_AllGather, NCCL_ReduceScatter, NCCL_AllReduce, etc.
  3. Memory (MEM) – Memory kernels manage the memory allocations/deallocations on the GPU devices and data movement between the memory space on the host and the GPUs. The memory kernels include Memcpy_H2D, Memcpy_D2H, Memcpy_D2D, Memset, etc. Here, H represents the Host and D represents the GPU Device. Thus, H2D, D2H, and D2D stand for Host to Device, Device to Host, and Device to Device, respectively.
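As a rough illustration of these three categories, the heuristic below classifies a kernel by its name prefix, using the conventions described above (NCCL kernels start with “nccl”, memory kernels with Memcpy/Memset). It is a hypothetical sketch, not HTA's classification logic, and the example kernel names are illustrative.

```python
def kernel_type(kernel_name):
    """Classify a GPU kernel as COMM, MEM, or COMP from its name (heuristic sketch)."""
    name = kernel_name.lower()
    if name.startswith("nccl"):
        return "COMM"  # NCCL communication kernels
    if name.startswith(("memcpy", "memset")):
        return "MEM"  # memory copies and memsets
    return "COMP"  # everything else is treated as computation


if __name__ == "__main__":
    for k in ["NCCL_AllReduce", "Memcpy_H2D", "Memset", "example_gemm_kernel"]:
        print(k, "->", kernel_type(k))
```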

Because a modern GPU device like the NVIDIA A100 GPU is a massively parallel device which is capable of running multiple kernels simultaneously, it is possible to overlap the computation, communication, and memory kernels to reduce the model execution time. One common technique to achieve the overlap is to utilize multiple CUDA streams. A CUDA stream is a sequence of operations that execute on a GPU device in the order in which they are issued by the host code. Different CUDA streams can be interleaved and even run concurrently, thus achieving the effect of kernel overlap.

To help understand the above concepts, Figure 1 provides a timeline of the GPU kernels in a sample distributed training job on 8 GPUs for one iteration. In the figure below, each rank represents one GPU and the kernels on each GPU run on 6 CUDA streams. In the right column of the figure, you can see names of the GPU kernels used. In the middle of the figure, you see the overlap between compute and communicate kernels. This figure is created using the plot_timeline example notebook available in HTA.

Figure 1. An example of the execution timeline of GPU Kernels across multiple ranks


The performance of multi-GPU training jobs is affected by multiple factors. Among these factors, how a model execution creates and orchestrates the GPU kernels plays a critical role. HTA provides insights into how the model execution interacts with the GPU devices and highlights opportunities for performance improvement.

With the features we built in HTA, we aim to provide users insight into “what is happening under the hood in distributed GPU training?” We briefly describe these features in the next few paragraphs.

Features in Holistic Trace Analysis

For most users, understanding the performance of GPU training jobs is nontrivial. Thus, we built this library to simplify the task of trace analysis and provide the user useful insights by examining the model execution traces. As the first step, we developed features which are important and generic enough so that most users can benefit from this library.

Temporal Breakdown: We begin by asking whether the GPU is spending time on computation, communication, or memory events, or whether it is idle. To answer this question, the temporal breakdown feature presents a breakdown in terms of these categories. To achieve high training efficiency, the code should maximize the time used by computation kernels and minimize idle time and non-compute time (time used by communication or memory kernels). This is accomplished by executing computation kernels concurrently with communication or memory kernels. Note that during such concurrent execution, the time spent by communication/memory kernels is accounted for under compute time.
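The accounting rule above (communication or memory time that overlaps computation counts as compute) can be sketched for a single rank as follows. Kernels are assumed to be `(start, end, kind)` tuples; this is an illustrative reimplementation with hypothetical names, not HTA's code or API.

```python
def _union_length(intervals):
    """Total length covered by a list of (start, end) intervals, merging overlaps."""
    total, cur_start, cur_end = 0.0, None, None
    for start, end in sorted(intervals):
        if cur_end is None or start > cur_end:
            if cur_end is not None:
                total += cur_end - cur_start
            cur_start, cur_end = start, end
        else:
            cur_end = max(cur_end, end)
    if cur_end is not None:
        total += cur_end - cur_start
    return total


def temporal_breakdown(kernels):
    """kernels: list of (start, end, kind) with kind in {"COMP", "COMM", "MEM"}.

    Returns compute, non-compute (COMM/MEM not overlapped by compute) and idle time,
    measured over the window from the first kernel start to the last kernel end.
    """
    window = max(e for _, e, _ in kernels) - min(s for s, _, _ in kernels)
    compute = _union_length([(s, e) for s, e, k in kernels if k == "COMP"])
    busy = _union_length([(s, e) for s, e, _ in kernels])
    non_compute = busy - compute  # COMM/MEM time overlapped by compute counts as compute
    idle = window - busy
    return {"compute": compute, "non_compute": non_compute, "idle": idle}


# Toy timeline (arbitrary time units)
kernels = [(0, 4, "COMP"), (2, 6, "COMM"), (8, 10, "MEM")]
print(temporal_breakdown(kernels))  # {'compute': 4.0, 'non_compute': 4.0, 'idle': 2.0}
```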

Figure 2: Temporal Breakdown across 8 GPUs


Kernel Breakdown: It is natural to ask which kernels take the most time. The next feature breaks down the time spent within each kernel type (COMM, COMP, MEM) and sorts the kernels by duration. We present this information for each kernel type and for each rank as a pie chart. See Figure 3 below.

Figure 3: Pie chart of top computation and communication kernels


Kernel Duration Distribution: Subsequently, one can also ask – for any given kernel, what is the distribution of the time spent across the ranks? To answer this, HTA generates bar graphs for the average duration of a given kernel across all ranks. Additionally, the error bars in the bar graphs show the minimum and maximum amount of time taken by a given kernel on a given rank. Figure 4 below shows a discrepancy between average duration on rank 0 as compared to other ranks. This anomalous behavior on rank 0 guides the user on where to look for possible bugs.
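A sketch of the per-rank statistics behind such a bar graph, assuming the durations of a single kernel have already been grouped by rank. The outlier rule (mean above 1.5× the median of per-rank means) is a hypothetical heuristic added for illustration, and the sample durations are made up.

```python
from statistics import mean


def kernel_duration_stats(durations_by_rank):
    """durations_by_rank: {rank: [durations of one kernel on that rank]}.

    Returns {rank: (mean, min, max)}: the bar height and error-bar range for each rank.
    """
    return {
        rank: (mean(durations), min(durations), max(durations))
        for rank, durations in durations_by_rank.items()
    }


def outlier_ranks(stats, factor=1.5):
    """Ranks whose mean exceeds `factor` times the (upper) median of the per-rank means."""
    means = sorted(m for m, _, _ in stats.values())
    median = means[len(means) // 2]
    return [rank for rank, (m, _, _) in stats.items() if m > factor * median]


# Made-up durations (microseconds) of one kernel on 4 ranks
stats = kernel_duration_stats({0: [9.0, 11.0], 1: [2.0, 2.2], 2: [2.1, 1.9], 3: [2.0, 2.0]})
print(outlier_ranks(stats))  # [0]
```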

Figure 4: Average duration of NCCL AllReduce Kernel across 8 ranks


Communication Computation Overlap: In distributed training, a significant amount of time is spent in communication and synchronization events among multiple GPU devices. To achieve high GPU efficiency (i.e. TFLOPS/GPU), it is vital to keep the GPU doing actual computation work; in other words, a GPU should not be blocked waiting for data from other GPUs. One way to measure the extent to which computation is blocked by data dependencies is to calculate the computation-communication overlap. Higher GPU efficiency is observed when communication events overlap computation events, while a lack of overlap leaves the GPU idle and lowers efficiency. The communication computation overlap feature therefore calculates, for each rank, the percentage of time during which communication and computation overlap, and generates a bar graph representation (see the figure below). More precisely, we measure the following ratio:

(time spent in computation while communicating) / (time spent in communication)
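The ratio can be computed from two sets of `(start, end)` intervals using inclusion-exclusion, |comp ∩ comm| = |comp| + |comm| − |comp ∪ comm|. A minimal sketch with hypothetical names, not HTA's implementation; returning 0 when there is no communication at all is an arbitrary choice.

```python
def _union_length(intervals):
    """Total length covered by a list of (start, end) intervals, merging overlaps."""
    total, cur_start, cur_end = 0.0, None, None
    for start, end in sorted(intervals):
        if cur_end is None or start > cur_end:
            if cur_end is not None:
                total += cur_end - cur_start
            cur_start, cur_end = start, end
        else:
            cur_end = max(cur_end, end)
    if cur_end is not None:
        total += cur_end - cur_start
    return total


def comm_comp_overlap(comp_intervals, comm_intervals):
    """(time spent in computation while communicating) / (time spent in communication)."""
    comm = _union_length(comm_intervals)
    if comm == 0:
        return 0.0
    overlap = _union_length(comp_intervals) + comm - _union_length(comp_intervals + comm_intervals)
    return overlap / comm


# Communication from t=3 to t=8 overlaps computation during [3, 4] and [6, 8]
print(comm_comp_overlap([(0, 4), (6, 10)], [(3, 8)]))  # 0.6
```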

Figure 5: Communication computation overlap


Augmented Counters (Queue length, Memory bandwidth): To aid in debugging, HTA calculates the memory bandwidth statistics for D2H, H2D and D2D memory copy (memcpy) and memory set (memset) events. HTA also computes the number of outstanding CUDA operations on each CUDA stream, which we refer to as queue length. When the queue length on a stream reaches 1024 or more, new events cannot be scheduled on that stream and the CPU stalls until the GPU events have been processed. Additionally, HTA generates a new trace file containing tracks with the memory bandwidth and queue length time series. See Figure 6 below.
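Queue length at any moment is the number of operations the host has enqueued on a stream that the GPU has not yet completed. A minimal sketch, assuming each operation is given as an `(enqueue_time, completion_time)` pair; the 1024 limit comes from the text above, and the function names are hypothetical.

```python
MAX_QUEUE_LENGTH = 1024  # per the text: at this length the CPU stalls on new launches


def queue_length_series(ops):
    """ops: list of (enqueue_time, completion_time) for CUDA operations on one stream.

    Returns [(time, outstanding_ops)] after each event.
    """
    events = [(t, +1) for t, _ in ops] + [(t, -1) for _, t in ops]
    # Sorting on (time, delta) processes completions before enqueues at the same timestamp
    events.sort()
    series, outstanding = [], 0
    for t, delta in events:
        outstanding += delta
        series.append((t, outstanding))
    return series


def stalled_times(series, limit=MAX_QUEUE_LENGTH):
    """Timestamps at which the stream's queue was at or above the limit."""
    return [t for t, n in series if n >= limit]


series = queue_length_series([(0, 5), (1, 3), (2, 6)])
print(series)  # [(0, 1), (1, 2), (2, 3), (3, 2), (5, 1), (6, 0)]
```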

Figure 6: Memory Bandwidth and Queue Length


These primary features give us a peek into the system performance and help answer “what is happening in the system?”. As HTA evolves, we hope to address “why is X happening?” and also suggest possible solutions to overcome the bottlenecks.

Installation and Usage


To install HTA, please refer to the README. In brief, the user is required to clone the repo and install the necessary Python packages via pip.


This version of Holistic Trace Analysis is currently in beta and we recommend using HTA in a Jupyter notebook. A demo notebook is provided for your convenience. To get started, import the hta package in a Jupyter notebook, create a TraceAnalysis object and off we go in exactly two lines of code.

from hta.trace_analysis import TraceAnalysis
analyzer = TraceAnalysis(trace_dir="/trace/folder/path")


  • All trace files for a training or inference job must be stored in a unique folder.
  • Trace files are in json or gzipped json format.


Q. How can I install HTA?

Please see the README in the root directory of the repository.

Q. Is there any documentation on the features and API in HTA?

The documentation and detailed API is available here.

Q. Can you implement feature X?

Depending on how widely the feature is needed and the level of effort required to implement it, we would consider developing the feature. Please open a GitHub issue and tag it with the feature-request label.

Q. Can I modify the code?

Please do and send a PR along the way, if you think it would be useful for others.

Q. How can I collect traces in PyTorch?

Please refer to this tutorial here.
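For a quick single-process example, the sketch below uses `torch.profiler` to record a few steps of a toy model and export the trace as JSON. A real HTA workflow would collect one trace per rank of a distributed job, as the tutorial describes; the toy model and the file name are assumptions.

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile

# Profile CPU activity, plus CUDA kernels when a GPU is present
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

model = torch.nn.Linear(128, 64)  # toy model standing in for a real training step
inputs = torch.randn(32, 128)

trace_dir = tempfile.mkdtemp()  # all traces of one job go into a single folder
trace_path = os.path.join(trace_dir, "rank0_trace.json")  # hypothetical file name

with profile(activities=activities) as prof:
    for _ in range(3):
        model(inputs).sum().backward()

prof.export_chrome_trace(trace_path)
print("trace written to", trace_path)
```

The folder holding the exported file is what you would then pass as `trace_dir` when creating the analyzer.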

Q. Can HTA be used at production scale?

Yes, please see a use case study here.


Compromised PyTorch-nightly dependency chain between December 25th and December 30th, 2022.

If you installed PyTorch-nightly on Linux via pip between December 25, 2022 and December 30, 2022, please uninstall it and torchtriton immediately, and use the latest nightly binaries (newer than Dec 30th 2022).

$ pip3 uninstall -y torch torchvision torchaudio torchtriton
$ pip3 cache purge

PyTorch-nightly Linux packages installed via pip during that time installed a dependency, torchtriton, which was compromised on the Python Package Index (PyPI) code repository and ran a malicious binary. This is what is known as a supply chain attack and directly affects dependencies for packages that are hosted on public package indices.

NOTE: Users of the PyTorch stable packages are not affected by this issue.

How to check if your Python environment is affected

The following command searches for the malicious binary in the torchtriton package (PYTHON_SITE_PACKAGES/triton/runtime/triton) and prints out whether your current Python environment is affected or not.

python3 -c "import pathlib;import importlib.util;s=importlib.util.find_spec('triton'); affected=any(x.name == 'triton' for x in (pathlib.Path(s.submodule_search_locations[0] if s is not None else '/' ) / 'runtime').glob('*'));print('You are {}affected'.format('' if affected else 'not '))"

The malicious binary is executed when the triton package is imported, which requires explicit code to do and is not PyTorch’s default behavior.

The Background

At around 4:40pm GMT on December 30 (Friday), we learned about a malicious dependency package (torchtriton) that was uploaded to the Python Package Index (PyPI) code repository with the same package name as the one we ship on the PyTorch nightly package index. Since the PyPI index takes precedence, this malicious package was being installed instead of the version from our official repository. This design enables somebody to register a package by the same name as one that exists in a third party index, and pip will install their version by default.

This malicious package has the same name, torchtriton, but adds code that uploads sensitive data from the machine.

What we know

torchtriton on PyPI contains a malicious triton binary which is installed at PYTHON_SITE_PACKAGES/triton/runtime/triton. Its SHA256 hash is listed below.

SHA256(triton)= 2385b29489cd9e35f92c072780f903ae2e517ed422eae67246ae50a5cc738a0e
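To go beyond the presence check above, one can compare the file's hash with the published value. A minimal sketch with hypothetical helper names; note that `importlib.util.find_spec` locates the package without importing it, which matters because importing triton is what executes the binary.

```python
import hashlib
import importlib.util
from pathlib import Path

# SHA256 of the malicious binary, as published above
MALICIOUS_SHA256 = "2385b29489cd9e35f92c072780f903ae2e517ed422eae67246ae50a5cc738a0e"


def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_malicious_binary(path):
    """True if the file at `path` has the published hash of the malicious triton binary."""
    return Path(path).is_file() and sha256_of(path) == MALICIOUS_SHA256


if __name__ == "__main__":
    spec = importlib.util.find_spec("triton")  # locates the package without importing it
    if spec is None or not spec.submodule_search_locations:
        print("triton package not found")
    else:
        candidate = Path(spec.submodule_search_locations[0]) / "runtime" / "triton"
        print("MALICIOUS BINARY FOUND" if matches_malicious_binary(candidate) else "no match")
```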

The binary’s main function does the following:

  • Get system information
    • nameservers from /etc/resolv.conf
    • hostname from gethostname()
    • current username from getlogin()
    • current working directory name from getcwd()
    • environment variables
  • Read the following files
    • /etc/hosts
    • /etc/passwd
    • The first 1,000 files in $HOME/*
    • $HOME/.gitconfig
    • $HOME/.ssh/*
  • Upload all of this information, including file contents, via encrypted DNS queries to the domain *.h4ck[.]cfd, using the DNS server wheezy[.]io

The binary’s file upload functionality is limited to files less than 99,999 bytes in size. It also uploads only the first 1,000 files in $HOME (but all files < 99,999 bytes in the .ssh directory).

Steps taken towards mitigation

  • torchtriton has been removed as a dependency for our nightly packages and replaced with pytorch-triton (pytorch/pytorch#91539) and a dummy package registered on PyPI (so that this issue doesn’t repeat)
  • All nightly packages that depend on torchtriton have been removed from our package indices until further notice
  • We have reached out to the PyPI security team to get proper ownership of the torchtriton package on PyPI and to delete the malicious version


Torchserve Performance Tuning, Animated Drawings Case-Study


Serving models in production

In this post we discuss performance tuning of Torchserve for serving your models in production. One of the biggest challenges in the life cycle of an ML project is deploying models in production. This requires a reliable serving solution along with solutions that address the MLOps needs. A robust serving solution needs to provide support for multi-model serving, model versioning, metric logging, monitoring, and scaling to serve peak traffic. In this post, we give an overview of Torchserve and how to tune its performance for production use cases. We also discuss the Animated Drawings app from Meta, which can turn your human figure sketches into animations, and how it could serve peak traffic with Torchserve. The Animated Drawings workflow is shown below.

While many AI systems and tools are designed to handle realistic images of humans, children’s drawings add a level of complexity and unpredictability, as they are often constructed in abstract, fanciful ways. These morphological and stylistic variations can confuse even state-of-the-art AI systems that excel at spotting objects in photorealistic images and drawings.
Meta AI researchers are working to overcome this challenge so that AI systems will be better able to recognize drawings of human figures in the wildly varied ways that children create them. This great blog post provides more details about the Animated Drawings and the approach taken.


Fig 1. Overall flow of Torchserve performance tuning

Once you have trained your model, it needs to be integrated into a larger system to form a full-fledged application; we use the term “model serving” to refer to this integration. Put simply, model serving makes your trained model available for running inference and for subsequent use.

Torchserve is the PyTorch-preferred solution for serving models in production. It is a performant and scalable tool that wraps your model in an HTTP or HTTPS API. It has a frontend implemented in Java that handles multiple tasks, from assigning workers for serving models to handling the connection between client and server. Torchserve also has a Python backend that is responsible for handling the inference service.

Torchserve supports multi-model serving, model versioning for A/B testing, dynamic batching, logging, and metrics. It exposes four APIs: inference, explanations, management, and metrics.

The Inference API listens on port 8080 and is accessible through localhost by default (this can be changed in the Torchserve configuration); it is used to get predictions from the model.

The Explanation API uses Captum under the hood to provide explanations of the model being served, and also listens on port 8080.

The Management API allows you to register, unregister, and describe models. It also enables users to scale the number of workers that serve a model up or down.

The Metrics API listens on port 8082 by default and enables us to monitor the model being served.

Torchserve lets you scale your model serving and handle peak traffic by supporting batch inference and multiple workers that serve your model. Scaling can be done through the Management API and through settings in a configuration file. The Metrics API also helps you monitor your model serving through default and customizable metrics.

Other advanced settings such as the length of the queue for the received requests, maximum wait time for a batch of inputs and many other properties are configurable through a config file that can be passed to Torchserve when it is started.

Steps to serve your model with Torchserve

  1. Install Torchserve, model archiver and its requirements.
  2. Choose a default handler that fits your task (e.g. image classification) or author a custom handler.
  3. Package your model artifacts (trained model checkpoint and all other files necessary for loading and running your model) and the handler into a “.mar” file using the Torch model archiver (torch-model-archiver), and place it in the model store.
  4. Start serving your model.
  5. Run inference.
We will discuss model handlers and metrics in more detail here.
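Step 5 amounts to an HTTP request against the Inference API. A minimal standard-library client sketch, assuming a model registered under a hypothetical name and TorchServe's default inference port 8080; `/predictions/<model_name>` is TorchServe's inference endpoint, while the helper names are illustrative.

```python
import urllib.request


def prediction_url(model_name, host="localhost", port=8080):
    """URL of TorchServe's inference endpoint for a registered model."""
    return f"http://{host}:{port}/predictions/{model_name}"


def predict(model_name, payload: bytes, **kwargs):
    """POST raw bytes (e.g. an encoded image) to the model and return the response body as text."""
    request = urllib.request.Request(prediction_url(model_name, **kwargs), data=payload, method="POST")
    with urllib.request.urlopen(request) as response:
        return response.read().decode("utf-8")
```

With a server running, calling `predict("my_model", image_bytes)` with the bytes of an image file (the model name is hypothetical) returns whatever the handler's postprocess step produced, serialized as JSON text.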

Model handlers

Torchserve uses a handler in the backend to load the models, preprocess the received data, run inference, and post-process the response. A handler in Torchserve is a Python script that contains all of the model initialization, preprocessing, inference, and post-processing logic.

Torchserve provides out-of-the-box handlers for a number of applications, such as image classification, segmentation, object detection, and text classification. It also supports custom handlers in case your use case is not supported by the default handlers.

Custom handlers provide great flexibility and can potentially make Torchserve a multi-framework serving tool: they let you define custom logic to initialize a model, which can also be used to load models from other frameworks such as ONNX.

A Torchserve handler is made of four main functions: initialize, preprocess, inference, and postprocess; the last three each return a list. Custom handlers inherit from BaseHandler in Torchserve and can override any of the main functions. The code snippet below shows an example of a custom handler: the one used for loading the Detectron2 model for figure detection. This model has been exported to TorchScript and uses model.half() to run inference in FP16; details are explained in another section of this post.

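import io
import os

import cv2
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler


class MyModelHandler(BaseHandler):
    def initialize(self, context):
        # Locate the serialized TorchScript model inside the extracted .mar archive
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)

        # Use the GPU assigned to this worker when available, otherwise the CPU
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        self.model = torch.jit.load(model_pt_path, map_location=self.device)

        # Run inference in FP16 (intended for GPU workers)
        self.model = self.model.half()
        self.initialized = True

    def preprocess(self, data):
        inputs = []
        for request in data:
            # Raw image bytes sent by the client
            request_body = request.get("body")
            input_ = io.BytesIO(request_body)
            # Decode to an HWC BGR image, then convert to a CHW FP16 tensor on the model's device
            image = cv2.imdecode(np.frombuffer(input_.read(), np.uint8), cv2.IMREAD_COLOR)
            tensor = torch.Tensor(image).permute(2, 0, 1)
            tensor = tensor.to(self.device)
            tensor = tensor.half()
            inputs.append({"image": tensor})

        return inputs

    def inference(self, inputs):
        # Detectron2-style models take a list of {"image": tensor} dicts
        predictions = self.model(inputs)
        return predictions

    def postprocess(self, inference_outputs):
        # One JSON-serializable response per request in the batch
        responses = []
        for inference_output in inference_outputs:
            responses_json = {
                "classes": inference_output["pred_classes"].tolist(),
                "scores": inference_output["scores"].tolist(),
                "boxes": inference_output["pred_boxes"].tolist(),
            }
            responses.append(responses_json)

        return responses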


An essential component in serving models in production is the ability to monitor them. Torchserve collects system level metrics regularly and allows adding custom metrics as well.

System-level metrics consist of CPU utilization, available and used disk space, and memory on the host machine, along with the number of requests with different response codes (e.g. 200-300, 400-500, and above 500). Custom metrics can be added as explained here. TorchServe logs these two sets of metrics to different log files. By default, metrics are collected at:

  • System metrics – log_directory/ts_metrics.log
  • Custom metrics – log_directory/model_metrics.log

As mentioned before, Torchserve also exposes a Metrics API that listens on port 8082 by default and enables users to query and monitor the collected metrics. The default metrics endpoint returns Prometheus-formatted metrics. You can query metrics using curl requests, or point a Prometheus server to the endpoint and use Grafana for dashboards.

While serving a model, you can query metrics with a curl request as follows:

curl http://127.0.0.1:8082/metrics

In case you are looking into exporting the logged metrics, please refer to this example that uses mtail to export metrics to Prometheus. Tracking these metrics in a dashboard allows you to monitor performance regressions that may have been sporadic or hard to spot during an offline benchmark run.
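For quick ad-hoc checks without a Prometheus server, the text returned by the metrics endpoint can also be parsed directly. A minimal sketch of a parser for the Prometheus text format; it is hypothetical, ignores timestamps and escaped characters in label values, and is no substitute for a real Prometheus client.

```python
def parse_prometheus_text(text):
    """Parse Prometheus text exposition format into {(metric_name, labels): value}.

    Skips # HELP / # TYPE comment lines; does not handle timestamps or escaped
    quotes/commas inside label values.
    """
    samples = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        series, value = line.rsplit(" ", 1)
        if "{" in series:
            name, _, label_str = series.partition("{")
            # Labels look like key="value",key="value", (a trailing comma is tolerated)
            pairs = [pair.split("=", 1) for pair in label_str.rstrip("}").split(",") if pair]
            labels = tuple((k, v.strip('"')) for k, v in pairs)
        else:
            name, labels = series, ()
        samples[(name, labels)] = float(value)
    return samples
```

Assuming the default metrics port, something like `parse_prometheus_text(urllib.request.urlopen("http://127.0.0.1:8082/metrics").read().decode())` would return every sample keyed by metric name and labels.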

What to consider for tuning performance of a model in production

The workflow suggested in Fig 1 is the general approach to model deployment in production with Torchserve.

In many cases, serving models in production is optimized based on throughput or latency service level agreements (SLAs). Real-time applications are usually more concerned with latency, whereas offline applications may care more about throughput.

There are a number of main factors contributing to the performance of a model served in production. We focus here on serving PyTorch models with Torchserve, but most of these factors generalize to models from other frameworks as well.

  • Model optimizations: this is a pre-step for deploying models into production. This is a very broad discussion that we will get into in a series of future blogs. This includes techniques like quantization, pruning to decrease the size of the model, using Intermediate representations (IR graphs) such as Torchscript in Pytorch, fusing kernels and many others. Currently torchprep provides many of these techniques as a CLI tool.
  • Batch inference: feeding multiple inputs into a model at once. While essential during training, it can also be very helpful for managing cost at inference time. Hardware accelerators are optimized for parallelism, and batching helps to saturate the compute capacity, which often leads to higher throughput. The main difference at inference time is that you can’t wait too long for a batch to fill up with client requests, which is what we call dynamic batching.
  • Number of Workers : Torchserve uses workers to serve models. Torchserve workers are Python processes that hold a copy of the model weights for running inference. Too few workers means you’re not benefitting from enough parallelism but too many can cause worker contention and degrade end to end performance.

  • Hardware: choosing the appropriate hardware based on the model, the application, and the latency and throughput budget. This could be any of the hardware supported by Torchserve: CPU, GPU, or AWS Inferentia. Some hardware configurations are intended for best-in-class performance, and others are better suited for cost-effective inference. From our experiments, we've found that GPUs shine best at larger batch sizes, whereas the right CPUs and AWS Inferentia can be far more cost-effective for lower batch sizes and low latency.
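The dynamic batching trade-off described in the list can be illustrated with a small simulation. This is a simplified sketch, not Torchserve's actual scheduler: requests are grouped until a batch reaches max_batch_size or until a request arrives more than max_delay_ms after the batch's first request. Both parameter names and the dynamic_batches function are hypothetical, loosely mirroring Torchserve's batch size and batch delay settings.

```python
from typing import List


def dynamic_batches(arrivals_ms: List[float], max_batch_size: int,
                    max_delay_ms: float) -> List[List[int]]:
    """Group request indices into batches, simulating dynamic batching.

    A batch is dispatched as soon as it holds max_batch_size requests, or
    when a request arrives more than max_delay_ms after the batch's first
    request (that request then starts a new batch). Any batch still open
    at the end is dispatched partially full. arrivals_ms must be sorted.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    window_start = 0.0
    for i, t in enumerate(arrivals_ms):
        # The open batch's delay budget has run out: dispatch it as is.
        if current and t - window_start > max_delay_ms:
            batches.append(current)
            current = []
        # The first request of a batch starts its delay window.
        if not current:
            window_start = t
        current.append(i)
        # A full batch is dispatched immediately, without waiting.
        if len(current) == max_batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)
    return batches
```

For example, with max_batch_size=4 and max_delay_ms=100, requests arriving at 0, 10, 20, 30, 500, and 510 ms form the batches [0, 1, 2, 3] and [4, 5]: the first batch fills up, while the second is dispatched partially full. A longer delay fills batches better but adds waiting time to each request's latency.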

Best Practices for Performance tuning on Torchserve

To help you get the best performance out of your model while serving it with Torchserve, we share some best practices here. Torchserve provides a benchmark suite that offers helpful insight for making informed decisions on the different choices detailed below.

  • Optimize your model as the first step; see the PyTorch model optimization tutorials. Model optimization choices are also closely tied to the choice of hardware. We will discuss this in more detail in another blog post.
  • Choosing the hardware for model deployment is closely related to the latency and throughput budget and the cost per inference. The right choice varies with model size and application; for some models, such as computer vision models, running in production on CPU has historically not been affordable. However, with optimizations such as IPEX, recently added to Torchserve, it has become much more affordable and cost-effective. You can learn more in this investigative case study.
  • Workers in Torchserve are Python processes that provide parallelism, so the number of workers should be set carefully. By default, Torchserve launches a number of workers equal to the number of vCPUs or available GPUs on the host, which can add a considerable amount of time to Torchserve startup.

    Torchserve exposes a config property to set the number of workers. To achieve efficient parallelism across multiple workers while keeping them from competing for resources, we recommend the following baseline settings on CPU and GPU:

    CPU: In the handler, call torch.set_num_threads(1), then set the number of workers to (number of physical cores) / 2. However, the best threading configuration can be achieved by leveraging the Intel CPU launcher script.

    GPU: The number of available GPUs can be set through number_of_gpu in config.properties. Torchserve uses round robin to assign workers to GPUs. We recommend setting the number of workers as follows: number of workers = (number of available GPUs) / (number of unique models). Note that pre-Ampere GPUs do not support Multi-Instance GPU (MIG), so they provide no resource isolation between workers.

  • Batch size directly affects latency and throughput. Increasing the batch size makes better use of compute resources, but there is a tradeoff between latency and throughput: larger batch sizes can increase throughput but also result in higher latency. Batch size can be set in Torchserve in two ways: through the model config in config.properties, or while registering the model using the Management API.
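The worker and batch size baselines in the list above can be combined into a single registration request. The following is a minimal sketch assuming the default management address http://127.0.0.1:8081 and the Management API's register endpoint (POST /models with the url, batch_size, max_batch_delay, and initial_workers query parameters); baseline_workers, register_model_url, and register_model are hypothetical helpers. The caller must pass the number of physical cores, since os.cpu_count() reports logical CPUs.

```python
import urllib.request
from urllib.parse import urlencode

# Assumed default Torchserve management address (port 8081); adjust it if
# you changed management_address in config.properties.
MANAGEMENT_URL = "http://127.0.0.1:8081"


def baseline_workers(device: str, physical_cores: int = 0,
                     num_gpus: int = 0, num_unique_models: int = 1) -> int:
    """Baseline worker count from the rule of thumb above (at least 1).

    CPU: physical cores / 2, assuming the handler calls torch.set_num_threads(1).
    GPU: available GPUs / number of unique models served.
    """
    if device == "cpu":
        return max(1, physical_cores // 2)
    if device == "gpu":
        return max(1, num_gpus // num_unique_models)
    raise ValueError(f"unknown device: {device!r}")


def register_model_url(mar_file: str, batch_size: int,
                       max_batch_delay_ms: int, initial_workers: int) -> str:
    """Build a Management API URL that registers a model with batching enabled."""
    query = urlencode({
        "url": mar_file,
        "batch_size": batch_size,
        "max_batch_delay": max_batch_delay_ms,
        "initial_workers": initial_workers,
    })
    return f"{MANAGEMENT_URL}/models?{query}"


def register_model(url: str) -> str:
    """POST the registration request and return the server's response body."""
    request = urllib.request.Request(url, method="POST")
    with urllib.request.urlopen(request) as response:
        return response.read().decode("utf-8")
```

For example, register_model(register_model_url("animated_drawings.mar", 8, 100, baseline_workers("gpu", num_gpus=1))) would register a hypothetical animated_drawings.mar archive from the model store with batch size 8 and a single worker. For models loaded at startup, config.properties offers default_workers_per_model instead.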

In the next section, we use the Torchserve benchmark suite to decide the best combination of model optimization, hardware, workers, and batch size.

Animated Drawings Performance Tuning

To use the Torchserve benchmark suite, we first need an archive file (a ".mar" file, as discussed above) that contains the model, the handler, and all other artifacts needed to load the model and run inference. Animated Drawings uses Detectron2's implementation of Mask-RCNN as its object detection model.

How to run benchmark suite

The automated benchmark suite in Torchserve lets you benchmark multiple models with different settings, including batch size and number of workers, and then generates a report for you. To get started:

git clone

cd serve/benchmarks

pip install -r requirements-ab.txt

apt-get install apache2-utils

Model-level settings can be configured in a YAML file similar to the following (replace Model_name with your model's name):

Model_name:
    eager_mode:
        benchmark_engine: "ab"
        url: "Path to .mar file"
        workers:
            - 1
            - 4
        batch_delay: 100
        batch_size:
            - 1
            - 2
            - 4
            - 8
        requests: 10000
        concurrency: 10
        input: "Path to model input"
        backend_profiling: False
        exec_env: "local"
        processors:
            - "cpu"
            - "gpus": "all"
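To estimate how long a sweep will take, it helps to count the runs a config like this implies. The sketch below assumes the suite benchmarks every combination of the listed worker counts, batch sizes, and processors (our reading of the config, worth confirming in the benchmark documentation); expand_grid is a hypothetical helper, not part of the suite.

```python
from itertools import product

# Values copied from the example YAML.
WORKERS = [1, 4]
BATCH_SIZES = [1, 2, 4, 8]
PROCESSORS = ["cpu", "gpus: all"]
REQUESTS_PER_RUN = 10000


def expand_grid(workers, batch_sizes, processors):
    """Return one run description per combination of the listed values."""
    return [
        {"workers": w, "batch_size": b, "processor": p}
        for w, b, p in product(workers, batch_sizes, processors)
    ]


runs = expand_grid(WORKERS, BATCH_SIZES, PROCESSORS)
total_requests = len(runs) * REQUESTS_PER_RUN
print(len(runs), total_requests)  # prints: 16 160000
```

Under that assumption, this config yields 2 × 4 × 2 = 16 runs, or 160,000 requests in total, which grows quickly as more values are added to each list.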

This YAML file is referenced in the benchmark_config_template.yaml file, which includes other settings for generating reports; it can optionally work with AWS CloudWatch for logs as well.

python benchmarks/ --input benchmark_config_template.yaml

After the benchmarks run, the results are written to a CSV file at /tmp/benchmark/ab_report.csv, and the full report to /tmp/ts_benchmark/. The report includes items such as Torchserve average latency, model P99 latency, throughput, concurrency, number of requests, handler time, and some other metrics. Here we focus on the metrics we track to tune performance: concurrency, model P99 latency, and throughput. We look at these numbers specifically in combination with batch size, the device used, the number of workers, and whether any model optimization has been applied.

The latency SLA for this model has been set to 100 ms. Since this is a real-time application, latency is the main concern, as discussed earlier; throughput should ideally be as high as possible without violating the latency SLA.

We searched the space over different batch sizes (1-32), numbers of workers (1-16), and devices (CPU, GPU); the table below summarizes the best results from this set of experiments.

Device  Concurrency  # Requests  # Workers  Batch size  Payload/image  Optimization   Throughput  Latency P99
CPU     10           1000        1          1           small          N/A            3.45        305.3 ms
CPU     1            1000        1          1           small          N/A            3.45        291.8 ms
GPU     10           1000        1          1           small          N/A            41.05       25.48 ms
GPU     1            1000        1          1           small          N/A            42.21       23.6 ms
GPU     10           1000        1          4           small          N/A            54.78       73.62 ms
GPU     10           1000        1          4           small          model.half()   78.62       50.69 ms
GPU     10           1000        1          8           small          model.half()   85.29       94.4 ms
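The selection rule used here (maximize throughput subject to the P99 latency SLA) can be applied mechanically to the table. In this sketch, the rows are transcribed from the table, throughput is kept in whatever units the table reports, and Result and best_under_sla are hypothetical names.

```python
from typing import List, NamedTuple, Optional


class Result(NamedTuple):
    device: str
    concurrency: int
    workers: int
    batch_size: int
    optimization: str
    throughput: float  # units as reported in the table
    p99_ms: float


# Rows transcribed from the results table.
RESULTS: List[Result] = [
    Result("CPU", 10, 1, 1, "N/A", 3.45, 305.3),
    Result("CPU", 1, 1, 1, "N/A", 3.45, 291.8),
    Result("GPU", 10, 1, 1, "N/A", 41.05, 25.48),
    Result("GPU", 1, 1, 1, "N/A", 42.21, 23.6),
    Result("GPU", 10, 1, 4, "N/A", 54.78, 73.62),
    Result("GPU", 10, 1, 4, "model.half()", 78.62, 50.69),
    Result("GPU", 10, 1, 8, "model.half()", 85.29, 94.4),
]


def best_under_sla(results: List[Result], sla_ms: float) -> Optional[Result]:
    """Highest-throughput configuration whose P99 latency meets the SLA, or None."""
    eligible = [r for r in results if r.p99_ms <= sla_ms]
    return max(eligible, key=lambda r: r.throughput, default=None)


best = best_under_sla(RESULTS, sla_ms=100.0)
```

With sla_ms=100.0, best is the GPU row with batch size 8 and model.half() (throughput 85.29, P99 94.4 ms). Tightening the SLA to 60 ms would instead select the batch size 4 model.half() row.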

With every batch size, concurrency level, and number of workers we tried, the model's latency on CPU did not meet the SLA: P99 latency was around 300 ms, roughly 3x the 100 ms SLA.

Moving model serving to the GPU immediately improved latency ~13x, from 305 ms down to 23.6 ms.

One of the simplest optimizations we could apply was lowering the model's precision to fp16. It is a one-liner (model.half()), and at batch size 4 it reduced the model P99 latency by ~31% (from 73.62 ms to 50.69 ms) and increased throughput by ~44% (from 54.78 to 78.62).
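The fp16 effect can be checked directly against the batch size 4 GPU rows of the table. This is simple arithmetic on the reported numbers; pct_change is a hypothetical helper.

```python
def pct_change(before: float, after: float) -> float:
    """Relative change from before to after, in percent (negative means a decrease)."""
    return (after - before) / before * 100.0


# GPU, concurrency 10, batch size 4 rows: without vs with model.half().
latency_change = pct_change(73.62, 50.69)     # P99 latency in ms
throughput_change = pct_change(54.78, 78.62)  # throughput
print(f"P99 latency: {latency_change:+.1f}%, throughput: {throughput_change:+.1f}%")
# prints: P99 latency: -31.1%, throughput: +43.5%
```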

Other optimizations, such as TorchScripting the model and using optimize_for_inference, or tricks such as ONNX or TensorRT runtime optimizations that leverage aggressive fusions, are out of the scope of this post. We will discuss model optimizations in a separate post.

On both CPU and GPU, we found that setting the number of workers to 1 worked best in this case.

  • Moving the model to GPU with number of workers = 1 and batch size = 1 increased throughput ~12x and improved latency ~13x compared to CPU.
  • Moving the model to GPU and using model.half(), with number of workers = 1 and batch size = 8, yielded the best results in terms of throughput with tolerable latency: throughput increased ~25x compared to CPU, while latency still met the SLA (94.4 ms).

Note: if you are running the benchmark suite, make sure you set a proper batch_delay and set the request concurrency to a number proportional to your batch size. Concurrency here means the number of concurrent requests sent to the server.


In this post, we have discussed the considerations and knobs that Torchserve exposes for tuning performance in production. We have discussed the Torchserve benchmark suite as a means to tune performance and gain insight into possible choices for model optimization, hardware, and cost in general. We used the Animated Drawings app, which uses Detectron2's Mask-RCNN model, as a case study to showcase performance tuning with the benchmark suite.

For more details on performance tuning in Torchserve, please refer to our documentation here.
Also, feel free to open a ticket on the Torchserve repo with any further questions and feedback.


We would like to thank Somya Jain (Meta) and Christopher Gustave (Meta) for their great support and guidance throughout many steps of this blog and for providing insights into the Sketch Animator workflow. Also, special thanks to Li Ning from AWS for the great efforts to make performance tuning much easier on Torchserve with the automated benchmark suite.
