PyTorch 2.1: automatic dynamic shape compilation, distributed checkpointing

We are excited to announce the release of PyTorch® 2.1 (release note)! PyTorch 2.1 offers automatic dynamic shape support in torch.compile, torch.distributed.checkpoint for saving/loading distributed training jobs on multiple ranks in parallel, and torch.compile support for the NumPy API.

In addition, this release offers numerous performance improvements (e.g. CPU inductor improvements, AVX512 support, scaled-dot-product-attention support) as well as a prototype release of torch.export, a sound full-graph capture mechanism, and torch.export-based quantization.

Along with 2.1, we are also releasing a series of updates to the PyTorch domain libraries. More details can be found in the library updates blog. 

This release is composed of 6,682 commits and 784 contributors since 2.0. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.1.  More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.


  • torch.compile now includes automatic support for detecting and minimizing recompilations due to tensor shape changes using automatic dynamic shapes.
  • torch.distributed.checkpoint enables saving and loading models from multiple ranks in parallel, as well as resharding due to changes in cluster topology.
  • torch.compile can now compile NumPy operations by translating them into PyTorch-equivalent operations.
  • torch.compile now includes improved support for Python 3.11.
  • New CPU performance features include inductor improvements (e.g. bfloat16 support and dynamic shapes), AVX512 kernel support, and scaled-dot-product-attention kernels.
  • torch.export, a sound full-graph capture mechanism, is introduced as a prototype feature, as well as torch.export-based quantization.
  • torch.sparse now includes prototype support for semi-structured (2:4) sparsity on NVIDIA® GPUs.

| Stable | Beta | Prototype | Performance Improvements |
|---|---|---|---|
|  | Automatic Dynamic Shapes | torch.export() | AVX512 kernel support |
|  | torch.distributed.checkpoint | torch.export-based Quantization | CPU optimizations for scaled-dot-product-attention (SDPA) |
|  | torch.compile + NumPy | semi-structured (2:4) sparsity | CPU optimizations for bfloat16 |
|  | torch.compile + Python 3.11 | cpp_wrapper for torchinductor |  |
|  | torch.compile + autograd.Function |  |  |
|  | third-party device integration: PrivateUse1 |  |  |

*To see a full list of public 2.1, 2.0, and 1.13 feature submissions click here.

Beta Features

[Beta] Automatic Dynamic Shapes

Dynamic shapes is functionality built into torch.compile that can minimize recompilations by tracking and generating code based on the symbolic shape of a tensor rather than the static shape (e.g. [B, 128, 4] rather than [64, 128, 4]). This allows torch.compile to generate a single kernel that can work for many sizes, at only a modest cost to efficiency. Dynamic shapes has been greatly stabilized in PyTorch 2.1, and is now automatically enabled if torch.compile notices recompilation due to varying input shapes. You can disable automatic dynamic shapes by passing dynamic=False to torch.compile, or by setting torch._dynamo.config.automatic_dynamic_shapes = False.

In PyTorch 2.1, we have shown good performance with dynamic shapes enabled on a variety of model types, including large language models, on both CUDA and CPU.

For more information on dynamic shapes, see this documentation.

[Beta] torch.distributed.checkpoint

torch.distributed.checkpoint enables saving and loading models from multiple ranks in parallel. In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.

For more information, see torch.distributed.checkpoint documentation and tutorial.
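A minimal single-process sketch of the save/load round trip follows. It assumes the `save_state_dict` and `load_state_dict` entry points with `FileSystemWriter` and `FileSystemReader`, and passes `no_dist=True` so that no process group is required; in a real multi-rank job each rank would call these functions collectively instead. Later releases may offer different or additional entry points, so treat the exact calls as an assumption tied to this release.

```python
import tempfile

import torch
import torch.nn as nn
import torch.distributed.checkpoint as dist_cp

model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Save: the state dict is written as files under ckpt_dir.
    state = {"model": model.state_dict()}
    dist_cp.save_state_dict(
        state_dict=state,
        storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
        no_dist=True,  # assumption: single process, no process group initialized
    )

    # Load: tensors in the target state dict are filled in place.
    restored = nn.Linear(4, 2)
    target = {"model": restored.state_dict()}
    dist_cp.load_state_dict(
        state_dict=target,
        storage_reader=dist_cp.FileSystemReader(ckpt_dir),
        no_dist=True,
    )
    restored.load_state_dict(target["model"])
```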

[Beta] torch.compile + NumPy

torch.compile now understands how to compile NumPy operations by translating them into PyTorch-equivalent operations. Because this integration operates in a device-agnostic manner, you can now GPU-accelerate NumPy programs, or even mixed NumPy/PyTorch programs, just by using torch.compile.

Please see this section in the torch.compile FAQ for more information about torch.compile + NumPy interaction, and follow the PyTorch Blog for a forthcoming blog about this feature.
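To make the idea concrete, here is a small sketch that compiles a pure NumPy function. The function name `weighted_row_sums` is hypothetical, and the `backend="eager"` argument is an assumption chosen so the example runs without a C++ toolchain; omitting it would use the default backend.

```python
import numpy as np
import torch

def weighted_row_sums(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # Plain NumPy code; torch.compile traces these calls into PyTorch operations.
    return np.sum(x * w, axis=1)

# Assumption: the "eager" backend is used only to avoid needing a compiler here.
compiled_fn = torch.compile(weighted_row_sums, backend="eager")

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 8))
w = rng.standard_normal((16, 8))

result = compiled_fn(x, w)  # inputs and outputs remain NumPy arrays
print(type(result), result.shape)
```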

[Beta] torch.compile + Python 3.11

torch.compile previously only supported Python versions 3.8-3.10. Users can now optimize models with torch.compile in Python 3.11.

[Beta] torch.compile + autograd.Function

torch.compile can now trace and optimize the backward function of user-defined autograd Functions, which unlocks training optimizations for models that make heavier use of extension mechanisms.
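The following sketch shows what this might look like for a simple user-defined function. The `Cube` class and `loss_fn` are hypothetical names for illustration, and `backend="aot_eager"` is an assumption that lets the forward and backward be traced without requiring a compiler toolchain.

```python
import torch

class Cube(torch.autograd.Function):
    """Hypothetical custom op: y = x^3 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2  # d/dx x^3 = 3x^2

def loss_fn(x: torch.Tensor) -> torch.Tensor:
    return Cube.apply(x).sum()

# Assumption: "aot_eager" traces forward and backward but executes them eagerly.
compiled_loss = torch.compile(loss_fn, backend="aot_eager")

x = torch.randn(5, requires_grad=True)
compiled_loss(x).backward()
print(x.grad)
```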

[Beta] Improved third-party device support: PrivateUse1

Third-party device types can now be registered to PyTorch using the PrivateUse1 dispatch key. This allows device extensions to register new kernels to PyTorch and to associate them with the new key, allowing user code to work equivalently to built-in device types. For example, to register “my_hardware_device”, one can do the following:

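import torch
torch.utils.rename_privateuse1_backend("my_hardware_device")  # name the PrivateUse1 backend
x = torch.randn((2, 3), device='my_hardware_device')  # requires the device extension's kernels
y = x + x # run add kernel on 'my_hardware_device'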

To validate this feature, the OSS team from Ascend NPU has successfully integrated torch_npu into PyTorch as a plug-in through the PrivateUse1 functionality.

For more information, please see the PrivateUse1 tutorial here.

Prototype Features

[Prototype] torch.export()

torch.export() provides a sound tracing mechanism to capture a full graph from a PyTorch program based on new technologies provided by PT2.0.

Users can extract a clean representation (Export IR) of a PyTorch program in the form of a dataflow graph, consisting of mostly straight-line calls to PyTorch operators. Export IR can then be transformed, serialized, saved to file, transferred, and loaded back for execution in an environment with or without Python.

For more information, please see the tutorial here.

[Prototype] torch.export-based Quantization

torch.ao.quantization now supports post-training static quantization on PyTorch 2 torch.export-based flows. This includes support for the built-in XNNPACK and X86Inductor Quantizers, as well as the ability to specify one’s own Quantizer.

For an explanation of post-training static quantization with torch.export, see this tutorial. For quantization-aware training for static quantization with torch.export, see this tutorial.

For an explanation on how to write one’s own Quantizer, see this tutorial.

[Prototype] semi-structured (2:4) sparsity for NVIDIA® GPUs

torch.sparse now supports creating and accelerating compute over semi-structured sparse (2:4) tensors. For more information on the format, see this blog from NVIDIA. A minimal example introducing semi-structured sparsity is as follows:

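import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

x = torch.rand(64, 64).half().cuda()
# 2:4 pattern: keep two of every four consecutive weights
mask = torch.tensor([0, 0, 1, 1]).tile((64, 16)).cuda().bool()
linear = nn.Linear(64, 64).half().cuda()

# zero the masked-out weights, then convert the weight to the semi-structured sparse format
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight.masked_fill(~mask, 0)))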

To learn more, please see the documentation and accompanying tutorial.

[Prototype] cpp_wrapper for torchinductor

cpp_wrapper can reduce the Python overhead for invoking kernels in torchinductor by generating the kernel wrapper code in C++. This feature is still in the prototype phase; it does not support all programs that successfully compile in PT2 today. Please file issues if you discover limitations for your use case to help us prioritize.

The API to turn this feature on is:

import torch
import torch._inductor.config as config
config.cpp_wrapper = True

For more information, please see the tutorial.

Performance Improvements

AVX512 kernel support

In PyTorch 2.0, AVX2 kernels would be used even if the CPU supported AVX512 instructions.  Now, PyTorch defaults to using AVX512 CPU kernels if the CPU supports those instructions, equivalent to setting ATEN_CPU_CAPABILITY=avx512 in previous releases.  The previous behavior can be enabled by setting ATEN_CPU_CAPABILITY=avx2.

CPU optimizations for scaled-dot-product-attention (SDPA)

Previous versions of PyTorch provided optimized CUDA implementations for transformer primitives via torch.nn.functional.scaled_dot_product_attention. PyTorch 2.1 includes optimized FlashAttention-based CPU routines.

See the documentation here.
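For orientation, the sketch below calls the fused attention function on CPU tensors and compares it with a straightforward reference computation. The tensor sizes are arbitrary choices for illustration, and whether the FlashAttention-based CPU path is actually selected depends on the build and inputs; the sketch only demonstrates the call and its semantics.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Shapes are (batch, heads, sequence length, head dimension); sizes are arbitrary.
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))

# Fused scaled-dot-product attention with a causal mask, on CPU tensors.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: softmax(Q K^T / sqrt(d) + causal mask) V, computed step by step.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.tril(torch.ones(16, 16, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
reference = torch.softmax(scores, dim=-1) @ v

print((fused - reference).abs().max())
```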

CPU optimizations for bfloat16

PyTorch 2.1 includes CPU optimizations for bfloat16, including improved vectorization support and torchinductor codegen.
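One common way to exercise bfloat16 on CPU is through autocast, sketched below. The small `model` is a hypothetical stand-in, and the sketch does not by itself show which optimized kernels are used; it only illustrates running inference in bfloat16 on CPU.

```python
import torch
import torch.nn as nn

# Hypothetical toy model used only to demonstrate CPU autocast.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).eval()
x = torch.randn(4, 64)

# Under CPU autocast, eligible ops such as linear layers run in bfloat16.
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(out.dtype)
```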
