New library updates in PyTorch 1.12

We are bringing a number of improvements to the current PyTorch libraries, alongside the PyTorch 1.12 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch.


  • TorchVision – Added multi-weight support API, new architectures, model variants, and pretrained weights. See the release notes here.
  • TorchAudio – Introduced beta features including a streaming API, a CTC beam search decoder, and new beamforming modules and methods. See the release notes here.
  • TorchText – Extended support for scriptable BERT tokenizer and added datasets for GLUE benchmark. See the release notes here.
  • TorchRec – Added EmbeddingModule benchmarks; examples for TwoTower Retrieval, inference, and sequential embeddings; metrics; an improved planner; and demonstrated integration with production components. See the release notes here.
  • TorchX – Added support for launching PyTorch trainers developed on local workspaces onto five different types of schedulers. See the release notes here.
  • FBGemm – Added and improved kernels for Recommendation Systems inference workloads, including table batched embedding bag, jagged tensor operations, and other special-case optimizations.

TorchVision v0.13

Multi-weight support API

TorchVision v0.13 offers a new Multi-weight support API for loading different weights into the existing model builder methods:

from torchvision.models import *

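# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)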

The new API bundles important details with the weights, such as the preprocessing transforms and metadata such as labels. Here is how to make the most out of it:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()  # inference mode (affects layers such as batch norm)

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

You can read more about the new API in the docs. To provide your feedback, please use this dedicated GitHub issue.

New architectures and model variants


The Swin Transformer and EfficientNetV2 are two popular classification models which are often used for downstream vision tasks. This release includes 6 pre-trained weights for their classification variants. Here is how to use the new models:

import torch
from torchvision.models import *

image = torch.rand(1, 3, 224, 224)
model = swin_t(weights="DEFAULT").eval()
prediction = model(image)

image = torch.rand(1, 3, 384, 384)
model = efficientnet_v2_s(weights="DEFAULT").eval()
prediction = model(image)

In addition to the above, we also provide new variants for existing architectures such as ShuffleNetV2, ResNeXt and MNASNet. The accuracies of all the new pre-trained models obtained on ImageNet-1K are listed below:

Model Acc@1 Acc@5
swin_t 81.474 95.776
swin_s 83.196 96.36
swin_b 83.582 96.64
efficientnet_v2_s 84.228 96.878
efficientnet_v2_m 85.112 97.156
efficientnet_v2_l 85.808 97.788
resnext101_64x4d 83.246 96.454
resnext101_64x4d (quantized) 82.898 96.326
shufflenet_v2_x1_5 72.996 91.086
shufflenet_v2_x1_5 (quantized) 72.052 90.700
shufflenet_v2_x2_0 76.230 93.006
shufflenet_v2_x2_0 (quantized) 75.354 92.488
mnasnet0_75 71.180 90.496
mnasnet1_3 76.506 93.522

We would like to thank Hu Ye for contributing the Swin Transformer implementation to TorchVision.

(BETA) Object Detection and Instance Segmentation

We have introduced 3 new model variants for RetinaNet, FasterRCNN and MaskRCNN that include several post-paper architectural optimizations and improved training recipes. All models can be used similarly:

import torch
from torchvision.models.detection import *

images = [torch.rand(3, 800, 600)]
model = retinanet_resnet50_fpn_v2(weights="DEFAULT")
# model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
# model = maskrcnn_resnet50_fpn_v2(weights="DEFAULT")
prediction = model(images)

Below we present the metrics of the new variants on COCO val2017. In parentheses we denote the improvement over the old variants:

Model Box mAP Mask mAP
retinanet_resnet50_fpn_v2 41.5 (+5.1) -
fasterrcnn_resnet50_fpn_v2 46.7 (+9.7) -
maskrcnn_resnet50_fpn_v2 47.4 (+9.5) 41.8 (+7.2)

We would like to thank Ross Girshick, Piotr Dollar, Vaibhav Aggarwal, Francisco Massa and Hu Ye for their past research and contributions to this work.

New pre-trained weights

SWAG weights

The ViT and RegNet model variants offer new pre-trained SWAG (Supervised Weakly from hashtAGs) weights. One of the biggest of these models achieves a whopping 88.6% accuracy on ImageNet-1K. We currently offer two versions of the weights: 1) fine-tuned end-to-end weights on ImageNet-1K (highest accuracy) and 2) frozen trunk weights with a linear classifier fit on ImageNet-1K (great for transfer learning). Below we see the detailed accuracies of each model variant:

Model Weights Acc@1 Acc@5
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1 86.012 98.054
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 83.976 97.244
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1 86.838 98.362
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 84.622 97.48
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1 88.228 98.682
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 86.068 97.844
ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 85.304 97.65
ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 81.886 96.18
ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1 88.064 98.512
ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 85.146 97.422
ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 88.552 98.694
ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1 85.708 97.73

The SWAG weights are released under the Attribution-NonCommercial 4.0 International license. We would like to thank Laura Gustafson, Mannat Singh and Aaron Adcock for their work and support in making the weights available to TorchVision.

Model Refresh

The release of the Multi-weight support API enabled us to refresh the most popular models and offer more accurate weights. On average, we improved each model by ~3 points. The new recipe was developed on top of ResNet50, and its details were covered in a previous blog post.

Model Old weights New weights
efficientnet_b1 78.642 79.838
mobilenet_v2 71.878 72.154
mobilenet_v3_large 74.042 75.274
regnet_y_400mf 74.046 75.804
regnet_y_800mf 76.42 78.828
regnet_y_1_6gf 77.95 80.876
regnet_y_3_2gf 78.948 81.982
regnet_y_8gf 80.032 82.828
regnet_y_16gf 80.424 82.886
regnet_y_32gf 80.878 83.368
regnet_x_400mf 72.834 74.864
regnet_x_800mf 75.212 77.522
regnet_x_1_6gf 77.04 79.668
regnet_x_3_2gf 78.364 81.196
regnet_x_8gf 79.344 81.682
regnet_x_16gf 80.058 82.716
regnet_x_32gf 80.622 83.014
resnet50 76.13 80.858
resnet50 (quantized) 75.92 80.282
resnet101 77.374 81.886
resnet152 78.312 82.284
resnext50_32x4d 77.618 81.198
resnext101_32x8d 79.312 82.834
resnext101_32x8d (quantized) 78.986 82.574
wide_resnet50_2 78.468 81.602
wide_resnet101_2 78.848 82.51

We would like to thank Piotr Dollar, Mannat Singh and Hugo Touvron for their past research and contributions to this work.

New Augmentations, Layers and Losses

This release brings a bunch of new primitives which can be used to produce SOTA models. Some highlights include the addition of the AugMix data-augmentation method, the DropBlock layer, the cIoU/dIoU loss and many more. We would like to thank Aditya Oke, Abhijit Deo, Yassine Alouini and Hu Ye for contributing to the project and for helping us keep TorchVision relevant and fresh.


Documentation

We completely revamped our models documentation to make it easier to browse, and added key information such as supported image sizes and the image pre-processing steps of pre-trained weights. We now have a main model page with various summary tables of available weights, and each model has a dedicated page. Each model builder is also documented on its own page, with more details about the available weights, including accuracy, minimal image size, link to training recipes, and other valuable info. For comparison, our previous models docs are here. To provide feedback on the new documentation, please use the dedicated GitHub issue.

TorchAudio v0.12

(BETA) Streaming API

StreamReader is TorchAudio’s new I/O API. It is backed by FFmpeg†, and allows users to:

  • Decode audio and video formats, including MP4 and AAC
  • Handle input forms, such as local files, network protocols, microphones, webcams, screen captures and file-like objects
  • Iterate over and decode chunk-by-chunk, while changing the sample rate or frame rate
  • Apply audio and video filters, such as low-pass filter and image scaling
  • Decode video with Nvidia’s hardware-based decoder (NVDEC)

For usage details, please check out the documentation and tutorials.

† To use StreamReader, FFmpeg libraries are required. Please install FFmpeg. The coverage of codecs depends on how these libraries are configured. TorchAudio official binaries are compiled to work with FFmpeg 4 libraries; FFmpeg 5 can be used if TorchAudio is built from source.

(BETA) CTC Beam Search Decoder

TorchAudio integrates the wav2letter CTC beam search decoder from Flashlight (GitHub). The addition of this inference time decoder enables running end-to-end CTC ASR evaluation using TorchAudio utils.

Customizable lexicon and lexicon-free decoders are supported, and both can be used either with a KenLM n-gram language model or without a language model. TorchAudio additionally supports downloading token, lexicon, and pretrained KenLM files for the LibriSpeech dataset.

For usage details, please check out the documentation and ASR inference tutorial.
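
To make the decoding step more concrete, here is a minimal sketch of the CTC collapsing rule (merge repeated tokens, then drop blanks) using greedy best-path decoding. This is deliberately a simpler technique than the beam search decoder described above and does not use TorchAudio; the blank symbol, the toy token set and the helper name ctc_greedy_decode are assumptions made for illustration.

```python
from typing import List, Sequence

BLANK = "-"  # assumed blank symbol for this toy example


def ctc_greedy_decode(emissions: Sequence[Sequence[float]], tokens: Sequence[str]) -> str:
    """Best-path CTC decoding: argmax per frame, merge repeats, drop blanks."""
    # Pick the most likely token index at every time step.
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in emissions]

    decoded: List[str] = []
    previous = None
    for index in best_path:
        # Keep a token only if it differs from the previous frame and is not blank.
        if index != previous and tokens[index] != BLANK:
            decoded.append(tokens[index])
        previous = index
    return "".join(decoded)


tokens = [BLANK, "a", "b"]
# Per-frame scores over (blank, a, b); the argmax path is: a a - a b b -
emissions = [
    [0.1, 0.8, 0.1],
    [0.2, 0.7, 0.1],
    [0.9, 0.05, 0.05],
    [0.1, 0.6, 0.3],
    [0.1, 0.2, 0.7],
    [0.2, 0.1, 0.7],
    [0.8, 0.1, 0.1],
]
print(ctc_greedy_decode(emissions, tokens))  # prints "aab"
```

A beam search decoder keeps several candidate paths per frame instead of a single argmax, which is what allows it to incorporate a lexicon and a language model score.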

(BETA) New Beamforming Modules and Methods

To improve flexibility in usage, the release adds two new beamforming modules under torchaudio.transforms: SoudenMVDR and RTFMVDR. The main differences from MVDR are:

  • Use power spectral density (PSD) and relative transfer function (RTF) matrices as inputs instead of time-frequency masks. The module can be integrated with neural networks that directly predict complex-valued STFT coefficients of speech and noise
  • Add ‘reference_channel’ as an input argument in the forward method, to allow users to select the reference channel in model training or dynamically change the reference channel in inference

Besides the two modules, new function-level beamforming methods are added under torchaudio.functional.

For usage details, please check out the documentation at torchaudio.transforms and torchaudio.functional and the Speech Enhancement with MVDR Beamforming tutorial.
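
For orientation, the two beamformers are commonly written as below. The notation is an assumption chosen for this sketch: Φ_SS and Φ_NN are the PSD matrices of speech and noise, v is the RTF vector, and u is a one-hot vector selecting the reference channel. The TorchAudio documentation remains the authoritative description of what the modules compute.

```latex
\mathbf{w}_{\text{Souden}}(f) =
  \frac{\mathbf{\Phi}_{NN}^{-1}(f)\,\mathbf{\Phi}_{SS}(f)}
       {\operatorname{tr}\!\left(\mathbf{\Phi}_{NN}^{-1}(f)\,\mathbf{\Phi}_{SS}(f)\right)}\,\mathbf{u},
\qquad
\mathbf{w}_{\text{RTF}}(f) =
  \frac{\mathbf{\Phi}_{NN}^{-1}(f)\,\mathbf{v}(f)}
       {\mathbf{v}^{\mathsf{H}}(f)\,\mathbf{\Phi}_{NN}^{-1}(f)\,\mathbf{v}(f)}
```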

TorchText v0.13

GLUE Datasets

We increased the number of datasets in TorchText from 22 to 30 by adding the remaining 8 datasets from the GLUE benchmark (SST-2 was already supported). The complete list of GLUE datasets is as follows:

  • CoLA (paper): Single sentence binary classification acceptability task
  • SST-2 (paper): Single sentence binary classification sentiment task
  • MRPC (paper): Dual sentence binary classification paraphrase task
  • QQP: Dual sentence binary classification paraphrase task
  • STS-B (paper): Single sentence to float regression sentence similarity task
  • MNLI (paper): Sentence ternary classification NLI task
  • QNLI (paper): Sentence binary classification QA and NLI tasks
  • RTE (paper): Dual sentence binary classification NLI task
  • WNLI (paper): Dual sentence binary classification coreference and NLI tasks

Scriptable BERT Tokenizer

TorchText has extended support for scriptable tokenizers by adding the WordPiece tokenizer used in BERT. It is one of the commonly used algorithms for splitting input text into sub-word units and was introduced in Japanese and Korean Voice Search (Schuster et al., 2012).
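
As a rough illustration of how WordPiece splits a word, the sketch below applies the greedy longest-match-first rule with a toy vocabulary. It is plain Python rather than the TorchText implementation; the function name wordpiece_tokenize, the vocabulary and the "[UNK]" token are assumptions made for this example.

```python
from typing import List, Set


def wordpiece_tokenize(word: str, vocab: Set[str], unk_token: str = "[UNK]") -> List[str]:
    """Split one word into sub-word units by greedy longest-match-first lookup."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking until a vocab hit.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece matches at this position: the whole word maps to the unknown token.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"un", "##aff", "##able", "play", "##ing", "[UNK]"}
print(wordpiece_tokenize("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("playing", toy_vocab))    # ['play', '##ing']
print(wordpiece_tokenize("xyz", toy_vocab))        # ['[UNK]']
```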

TorchScriptability support would allow users to embed the BERT text-preprocessing natively in C++ without needing the support of the Python runtime. As TorchText now supports the CMake build system to natively link torchtext binaries with application code, users can easily integrate BERT tokenizers for deployment needs.

For usage details, please refer to the corresponding documentation.

TorchRec v0.2.0

EmbeddingModule + DLRM benchmarks

We provide a set of benchmarking tests showing the performance characteristics of TorchRec’s base modules and of research models built out of TorchRec.

TwoTower Retrieval Example, with FAISS

We provide an example demonstrating training a distributed TwoTower (i.e. User-Item) Retrieval model that is sharded using TorchRec. The projected item embeddings are added to an IVFPQ FAISS index for candidate generation. The retrieval model and KNN lookup are bundled in a PyTorch model for efficient end-to-end retrieval.


Integrations

We demonstrate that TorchRec works out of the box with many components commonly used alongside PyTorch models in production-like systems, such as:

  • Training a TorchRec model on Ray Clusters utilizing the TorchX Ray scheduler
  • Preprocessing and DataLoading with NVTabular on DLRM
  • Training a TorchRec model with on-the-fly preprocessing with TorchArrow showcasing RecSys domain UDFs

Sequential Embeddings Example: Bert4Rec

We provide an example, using TorchRec, that reimplements the BERT4REC paper, showcasing EmbeddingCollection for non-pooled embeddings. Using DistributedModelParallel we see a 35% QPS gain over conventional data parallelism.

(Beta) Planner

The TorchRec library includes a built-in planner that selects a near-optimal sharding plan for a given model. The planner attempts to identify the best sharding plan by evaluating a series of proposals which are statically analyzed and fed into an integer partitioner. The planner is able to automatically adjust plans for a wide range of hardware setups, allowing users to scale performance seamlessly from a local development environment to large-scale production hardware. See this notebook for a more detailed tutorial.

(Beta) Inference

TorchRec Inference is a C++ library that supports multi-GPU inference. The TorchRec library is used to shard models written and packaged in Python via torch.package (an alternative to TorchScript). The torch.deploy library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus working around the GIL. Two models are provided as examples: DLRM multi-GPU (sharded via TorchRec) and DLRM single-GPU.

(Beta) RecMetrics

RecMetrics is a metrics library that collects common utilities and optimizations for Recommendation models. It extends torchmetrics.

  • A centralized metrics module that allows users to add new metrics
  • Commonly used metrics, including AUC, Calibration, CTR, MSE/RMSE, NE & Throughput
  • Optimization for metrics related operations to reduce the overhead of metric computation
  • Checkpointing

(Prototype) Single process Batched + Fused Embeddings

Previously TorchRec’s abstractions (EmbeddingBagCollection/EmbeddingCollection) over FBGEMM kernels, which provide benefits such as table batching, optimizer fusion, and UVM placement, could only be used in conjunction with DistributedModelParallel. We’ve decoupled these notions from sharding, and introduced the FusedEmbeddingBagCollection, which can be used as a standalone module, with all of the above features, and can also be sharded.

TorchX v0.2.0

TorchX is a job launcher that makes it easier to run PyTorch in distributed training clusters with many scheduler integrations including Kubernetes and Slurm. We’re excited to release TorchX 0.2.0 with a number of improvements. TorchX is currently being used in production in both on-premise and cloud environments.

Check out the quickstart to start launching local and remote jobs.


Workspaces

TorchX now supports workspaces, which allow users to easily launch training jobs using their local workspace. TorchX can automatically build a patch with your local training code on top of a base image to minimize iteration time and time to training.


.torchxconfig

Specifying options in .torchxconfig saves you from having to type long CLI commands each time you launch a job. You can also define project-level generic configs and drop a config file in your home directory for user-level overrides.

Expanded Scheduler Support

TorchX now supports AWS Batch and Ray (experimental) schedulers in addition to our existing integrations.

Distributed Training On All Schedulers

The TorchX dist.ddp component now works on all schedulers without any configuration. Distributed training workers will automatically discover each other when using torchelastic via the builtin dist.ddp component.

Hyper Parameter Optimization

TorchX integrates with Ax to let you scale hyper-parameter optimizations (HPO) by launching the search trials onto remote clusters.

File and Device Mounts

TorchX now supports remote filesystem mounts and custom devices. This enables your PyTorch jobs to efficiently access cloud storage such as NFS or Lustre. Device mounts enable the use of network accelerators like InfiniBand and custom inference/training accelerators.

FBGemm v0.2.0

The FBGEMM library contains optimized kernels meant to improve the performance of PyTorch workloads. We’ve added a number of new features and optimizations over the last few months that we are excited to report.

Inference Table Batched Embedding (TBE)

The table batched embedding bag (TBE) operator is an important base operation for embedding lookup for recommendation system inference on GPU. We added the following enhancements for performance and flexibility:

Alignment restriction removed

  • Previously, embedding dimension * data type size had to be a multiple of 4B; now it only needs to be a multiple of 1B.

Unified Virtual Memory (UVM) caching kernel optimizations

  • UVM caching kernels now scale linearly with the number of tables using UVM caching. Previously, the overhead was similar to having all tables use UVM caching
  • UVM caching kernel overhead is much smaller than before

Inference FP8 Table Batched Embedding (TBE)

The table batched embedding bag (TBE) previously supported FP32, FP16, INT8, INT4, and INT2 embedding weight types. While these weight types work well in many models, we integrate FP8 weight types (in both GPU and CPU operations) to allow for numerical and performance evaluations of FP8 in our models. Compared to INT8, FP8 does not require the additional bias and scale storage and calculations. Additionally, the next-generation H100 GPUs support FP8 on Tensor Cores (mainly matmul ops).

Jagged Tensor Kernels

We added optimized kernels to speed up TorchRec JaggedTensor. The purpose of JaggedTensor is to handle the case where one dimension of the input data is “jagged”, meaning that each consecutive row in a given dimension may be a different length, which is often the case with sparse feature inputs in recommendation systems. The internal representation is shown below:

We added ops for converting jagged tensors from sparse to dense formats and back, performing matrix multiplications with jagged tensors, and elementwise ops.
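
As a rough illustration of the idea, a jagged dimension can be stored as one flat list of values plus a list of offsets marking where each row starts and ends. The sketch below uses plain Python lists rather than TorchRec or FBGEMM kernels, and the helper names dense_to_jagged and jagged_to_padded_dense are hypothetical.

```python
from typing import List, Tuple


def dense_to_jagged(rows: List[List[float]]) -> Tuple[List[float], List[int]]:
    """Flatten variable-length rows into (values, offsets)."""
    values: List[float] = []
    offsets = [0]
    for row in rows:
        values.extend(row)
        offsets.append(len(values))  # row i lives in values[offsets[i]:offsets[i + 1]]
    return values, offsets


def jagged_to_padded_dense(
    values: List[float], offsets: List[int], max_length: int, padding_value: float = 0.0
) -> List[List[float]]:
    """Rebuild a rectangular layout, truncating or padding every row to max_length."""
    dense = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        row = values[start:end][:max_length]
        dense.append(row + [padding_value] * (max_length - len(row)))
    return dense


rows = [[1.0, 2.0], [], [3.0], [4.0, 5.0, 6.0]]
values, offsets = dense_to_jagged(rows)
print(values)   # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(offsets)  # [0, 2, 2, 3, 6]
print(jagged_to_padded_dense(values, offsets, max_length=2))
```

The flat layout stores only the six real values, while the padded dense form needs a slot for every (row, position) pair.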

Optimized permute102-baddbmm-permute102

It is difficult to fuse various matrix multiplications where the batch size is not the batch size of the model; switching the batch dimension is a quick solution. We created the permute102_baddbmm_permute102 operation, which switches the first and the second dimension, performs the batched matrix multiplication, and then switches back. Currently we only support the forward pass with the FP16 data type; we will support the FP32 type and the backward pass in the future.
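
The sequence of steps can be written out as a slow reference in plain Python. This is only a sketch of the semantics described above, not the FBGEMM kernel: the tensor shapes and the function name permute102_baddbmm_permute102_reference are assumptions for illustration.

```python
from typing import List

Matrix = List[List[float]]
Tensor3 = List[Matrix]


def permute102(x: Tensor3) -> Tensor3:
    """Swap the first and second dimensions: result[j][i] == x[i][j]."""
    return [[x[i][j] for i in range(len(x))] for j in range(len(x[0]))]


def permute102_baddbmm_permute102_reference(bias: Matrix, a: Tensor3, w: Tensor3) -> Tensor3:
    """Assumed shapes: a is (B, T, K), w is (T, K, N), bias is (T, N); result is (B, T, N)."""
    a_t = permute102(a)  # (T, B, K): T becomes the batch dimension of the matmul
    out_t = [
        [
            # bias[t] is broadcast over the B rows of the t-th matrix product.
            [bias[t][n] + sum(row[k] * w[t][k][n] for k in range(len(row)))
             for n in range(len(bias[t]))]
            for row in a_t[t]
        ]
        for t in range(len(a_t))
    ]
    return permute102(out_t)  # switch back to (B, T, N)


a = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # a[b][t][k]
w = [[[1.0], [1.0]], [[2.0], [0.0]]]                      # w[t][k][n]
bias = [[10.0], [20.0]]                                   # bias[t][n]
print(permute102_baddbmm_permute102_reference(bias, a, w))
# [[[13.0], [26.0]], [[21.0], [34.0]]]
```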

Optimized index_select for dim 0 index selection

index_select is normally used as part of a sparse operation. While PyTorch supports a generic index_select for an arbitrary-dimension index selection, its performance for a special case like the dim 0 index selection is suboptimal. For this reason, we implement a specialized index_select for dim 0. In some cases, we have observed 1.4x performance gain from FBGEMM’s index_select compared to the one from PyTorch (using uniform index distribution).

More about the implementation of influential instances can be found on our GitHub page and tutorials.

Thanks for reading. If you’re interested in these updates and want to join the PyTorch community, we encourage you to join the discussion forums and open GitHub issues. To get the latest news from PyTorch, follow us on Twitter, Medium, YouTube, and LinkedIn.


Team PyTorch

PyTorch 1.12: TorchArrow, Functional API for Modules and nvFuser, are now available

We are excited to announce the release of PyTorch 1.12 (release notes)! This release is composed of over 3124 commits from 433 contributors. Along with 1.12, we are releasing beta versions of AWS S3 Integration, PyTorch Vision Models on Channels Last on CPU, Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16 and FSDP API. We want to sincerely thank our dedicated community for your contributions.


  • Functional APIs to functionally apply module computation with a given set of parameters
  • Complex32 and Complex Convolutions in PyTorch
  • DataPipes from TorchData fully backward compatible with DataLoader
  • functorch with improved coverage for APIs
  • nvFuser, a deep learning compiler for PyTorch
  • Changes to float32 matrix multiplication precision on Ampere and later CUDA hardware
  • TorchArrow, a new beta library for machine learning preprocessing over batch data

Frontend APIs

Introducing TorchArrow

We’ve got a new Beta release ready for you to try and use: TorchArrow. This is a library for machine learning preprocessing over batch data. It features a performant and Pandas-style, easy-to-use API in order to speed up your preprocessing workflows and development.

Currently, it provides a Python DataFrame interface with the following features:

  • High-performance CPU backend, vectorized and extensible User-Defined Functions (UDFs) with Velox
  • Seamless handoff with PyTorch or other model authoring, such as Tensor collation and easily plugging into PyTorch DataLoader and DataPipes
  • Zero copy for external readers via Arrow in-memory columnar format

For more details, please find our 10-min tutorial, installation instructions, API documentation, and a prototype for data preprocessing in TorchRec.

(Beta) Functional API for Modules

PyTorch 1.12 introduces a new beta feature to functionally apply Module computation with a given set of parameters. Sometimes, the traditional PyTorch Module usage pattern that maintains a static set of parameters internally is too restrictive. This is often the case when implementing algorithms for meta-learning, where multiple sets of parameters may need to be maintained across optimizer steps.

The new torch.nn.utils.stateless.functional_call() API allows for:

  • Module computation with full flexibility over the set of parameters used
  • No need to reimplement your module in a functional way
  • Any parameter or buffer present in the module can be swapped with an externally-defined value for use in the call. Naming for referencing parameters / buffers follows the fully-qualified form in the module’s state_dict()


import torch
from torch import nn
from torch.nn.utils.stateless import functional_call

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 3)
        self.bn = nn.BatchNorm1d(3)
        self.fc2 = nn.Linear(3, 3)

    def forward(self, x):
        return self.fc2(self.bn(self.fc1(x)))

m = MyModule()

# Define parameter / buffer values to use during module computation.
my_weight = torch.randn(3, 3, requires_grad=True)
my_bias = torch.tensor([1., 2., 3.], requires_grad=True)
params_and_buffers = {
    'fc1.weight': my_weight,
    'fc1.bias': my_bias,
    # Custom buffer values can be used too.
    'bn.running_mean': torch.randn(3),
}

# Apply module computation to the input with the specified parameters / buffers.
inp = torch.randn(5, 3)
output = functional_call(m, params_and_buffers, inp)

(Beta) Complex32 and Complex Convolutions in PyTorch

PyTorch today natively supports complex numbers, complex autograd, complex modules, and numerous complex operations, including linear algebra and Fast Fourier Transform (FFT) operators. Many libraries, including torchaudio and ESPNet, already make use of complex numbers in PyTorch, and PyTorch 1.12 further extends complex functionality with complex convolutions and the experimental complex32 (“complex half”) data type that enables half precision FFT operations. Due to bugs in the CUDA 11.3 package, we recommend using the CUDA 11.6 package from wheels if you are using complex numbers.

(Beta) Forward-mode Automatic Differentiation

Forward-mode AD allows the computation of directional derivatives (or equivalently, Jacobian-vector products) eagerly in the forward pass. PyTorch 1.12 significantly improves the operator coverage for forward-mode AD. See our tutorial for more information.
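
To illustrate what "eagerly in the forward pass" means, here is a minimal dual-number sketch in plain Python. It is not PyTorch's implementation: the Dual class and the jvp helper are hypothetical names, and only addition, multiplication and sine are covered.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value paired with its directional derivative (tangent)."""
    primal: float
    tangent: float = 0.0

    def __add__(self, other):
        other = lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __radd__ = __add__
    __rmul__ = __mul__


def lift(value):
    """Treat plain numbers as constants with a zero tangent."""
    return value if isinstance(value, Dual) else Dual(float(value))


def sin(x):
    x = lift(x)
    # Chain rule: d/dt sin(x) = cos(x) * x'
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, primals, tangents):
    """Evaluate f and its Jacobian-vector product in a single forward pass."""
    out = f(*[Dual(p, t) for p, t in zip(primals, tangents)])
    return out.primal, out.tangent


def f(x, y):
    return x * y + sin(x)


value, directional = jvp(f, (2.0, 3.0), (1.0, 0.0))
print(value, directional)  # f(2, 3) and df/dx = y + cos(x) at (2, 3)
```

Each intermediate result carries its tangent along with its value, so no backward pass or stored graph is needed to obtain the directional derivative.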


TorchData

BC DataLoader + DataPipe

`DataPipe` from TorchData becomes fully backward compatible with the existing `DataLoader` regarding shuffle determinism and dynamic sharding in both multiprocessing and distributed environments. For more details, please check out the tutorial.

(Beta) AWS S3 Integration

DataPipes based on AWSSDK have been integrated into TorchData. It provides the following features backed by native AWSSDK:

  • Retrieve list of urls from each S3 bucket based on prefix
    • Support timeout to prevent hanging indefinitely
    • Support to specify S3 bucket region
  • Load data from S3 urls
    • Support buffered and multi-part download
    • Support to specify S3 bucket region

AWS native DataPipes are still in the beta phase, and we will keep tuning them to improve their performance.

(Prototype) DataLoader2

DataLoader2 is now available in prototype mode. We are introducing new ways to interact between DataPipes, the DataLoading API, and backends (aka ReadingServices). The feature is stable in terms of API, but not yet functionally complete. We welcome early adopters and feedback, as well as potential contributors.

For more details, please check out the link.


functorch

Inspired by Google JAX, functorch is a library that offers composable vmap (vectorization) and autodiff transforms. It enables advanced autodiff use cases that would otherwise be tricky to express in PyTorch.

We’re excited to announce functorch 0.2.0 with a number of improvements and new experimental features.

Significantly improved coverage

We significantly improved coverage for functorch.jvp (our forward-mode autodiff API) and other APIs that rely on it (functorch.{jacfwd, hessian}).

(Prototype) functorch.experimental.functionalize

Given a function f, functionalize(f) returns a new function without mutations (with caveats). This is useful for constructing traces of PyTorch functions without in-place operations. For example, you can use make_fx(functionalize(f)) to construct a mutation-free trace of a PyTorch function. To learn more, please see the documentation.

For more details, please see our installation instructions, documentation, tutorials, and release notes.

Performance Improvements

Introducing nvFuser, a deep learning compiler for PyTorch

In PyTorch 1.12, TorchScript is updating its default fuser (for Volta and later CUDA accelerators) to nvFuser, which supports a wider range of operations and is faster than NNC, the previous fuser for CUDA devices. A soon to be published blog post will elaborate on nvFuser and show how it speeds up training on a variety of networks.

See the nvFuser documentation for more details on usage and debugging.

Changes to float32 matrix multiplication precision on Ampere and later CUDA hardware

PyTorch supports a variety of “mixed precision” techniques, like the torch.amp (Automated Mixed Precision) module and performing float32 matrix multiplications using the TensorFloat32 datatype on Ampere and later CUDA hardware for faster internal computations. In PyTorch 1.12 we’re changing the default behavior of float32 matrix multiplications to always use full IEEE fp32 precision, which is more precise but slower than using the TensorFloat32 datatype for internal computation. For devices with a particularly high ratio of TensorFloat32 to float32 throughput such as A100, this change in defaults can result in a large slowdown.

If you’ve been using TensorFloat32 matrix multiplications then you can continue to do so by setting torch.backends.cuda.matmul.allow_tf32 = True

which has been supported since PyTorch 1.7. Starting in PyTorch 1.12 the new matmul precision API can be used, too: torch.set_float32_matmul_precision("highest"|"high"|"medium")

To reiterate, PyTorch’s new default is “highest” precision for all device types. We think this provides better consistency across device types for matrix multiplications. Documentation for the new precision API can be found here. Setting the “high” or “medium” precision types will enable TensorFloat32 on Ampere and later CUDA devices. If you’re updating to PyTorch 1.12 then to preserve the current behavior and faster performance of matrix multiplications on Ampere devices, set precision to “high”.
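
To give a feel for the precision difference, the sketch below emulates a TensorFloat32 value in plain Python by keeping float32's sign and 8 exponent bits but only the top 10 of its 23 mantissa bits. It is an illustration only: it truncates rather than reproducing the hardware's rounding, it does not involve PyTorch, and the helper name to_tf32_truncated is hypothetical.

```python
import struct


def to_tf32_truncated(x: float) -> float:
    """Keep float32's sign and exponent but only the top 10 mantissa bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits &= 0xFFFFE000  # clear the 13 low mantissa bits (23 - 10)
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


# 1 + 2**-10 still fits in 10 mantissa bits; 1 + 2**-11 does not.
print(to_tf32_truncated(1.0 + 2.0 ** -10))  # 1.0009765625
print(to_tf32_truncated(1.0 + 2.0 ** -11))  # 1.0
```

Inputs that differ only below the tenth mantissa bit become indistinguishable, which is the accuracy traded for speed when "high" or "medium" precision is selected.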

Using mixed precision techniques is essential for training many modern deep learning networks efficiently, and if you’re already using torch.amp this change is unlikely to affect you. If you’re not familiar with mixed precision training then see our soon to be published “What Every User Should Know About Mixed Precision Training in PyTorch” blogpost.

(Beta) Accelerating PyTorch Vision Models with Channels Last on CPU

Memory formats have a significant impact on performance when running vision models; generally, Channels Last is more favorable from a performance perspective due to better data locality. PyTorch 1.12 includes fundamental concepts of memory formats and demonstrates performance benefits using Channels Last on popular PyTorch vision models on Intel® Xeon® Scalable processors.

  • Enables Channels Last memory format support for the commonly used operators in CV domain on CPU, applicable for both inference and training
  • Provides native level optimization on Channels Last kernels from ATen, applicable for both AVX2 and AVX512
  • Delivers 1.3x to 1.8x inference performance gain over Channels First for TorchVision models on Intel® Xeon® Ice Lake (or newer) CPUs
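
The data-locality argument can be seen from the strides alone. The following sketch computes, in plain Python, where the channel values of one pixel land in memory under each layout; it does not call PyTorch, and the helper names are hypothetical.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, int, int, int]  # logical shape (N, C, H, W)


def channels_first_strides(shape: Shape) -> Tuple[int, int, int, int]:
    """Strides (in elements) when memory is laid out in N, C, H, W order."""
    _, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channels_last_strides(shape: Shape) -> Tuple[int, int, int, int]:
    """Strides (in elements) when memory is laid out in N, H, W, C order."""
    _, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def memory_offset(index: Sequence[int], strides: Sequence[int]) -> int:
    """Position of a logical (n, c, h, w) element in the flat buffer."""
    return sum(i * s for i, s in zip(index, strides))


shape = (1, 3, 2, 2)
for name, strides in [
    ("channels first", channels_first_strides(shape)),
    ("channels last", channels_last_strides(shape)),
]:
    # Offsets of the three channel values belonging to pixel (h=0, w=0).
    offsets = [memory_offset((0, c, 0, 0), strides) for c in range(3)]
    print(name, strides, offsets)
```

With Channels Last, the channel values of a pixel sit at adjacent offsets (0, 1, 2), whereas with Channels First they are a full image plane apart (0, 4, 8).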

(Beta) Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16

Reduced precision numeric formats like bfloat16 improve PyTorch performance across multiple deep learning training workloads. PyTorch 1.12 includes the latest software enhancements on bfloat16, which apply to a broader scope of user scenarios and showcase even higher performance gains. The main improvements include:

  • 2x hardware compute throughput vs. float32 with the new bfloat16 native instruction VDPBF16PS, introduced on Intel® Xeon® Cooper Lake CPUs
  • 1/2 memory footprint of float32, faster speed for memory bandwidth intensive operators
  • 1.4x to 2.2x inference performance gain over float32 for TorchVision models on Intel® Xeon® Cooper Lake (or newer) CPUs

(Prototype) Introducing Accelerated PyTorch Training on Mac

With the PyTorch 1.12 release, developers and researchers can now take advantage of Apple silicon GPUs for significantly faster model training. This unlocks the ability to perform machine learning workflows like prototyping and fine-tuning locally, right on Mac. Accelerated GPU training is enabled using Apple’s Metal Performance Shaders (MPS) as a backend. The benefits include performance speedup from accelerated GPU training and the ability to train larger networks or batch sizes locally. Learn more here.

Accelerated GPU training and evaluation speedups over CPU-only (times faster)

Alongside the new MPS device support, the M1 binaries for Core and Domain libraries that have been available for the last few releases are now an official prototype feature. These binaries can be used to run PyTorch natively on Apple Silicon.
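
A minimal sketch of selecting the MPS device, with a CPU fallback so the same script runs on machines without Apple silicon. The model and data are placeholders chosen for illustration.

```python
import torch
import torch.nn as nn

# Fall back to the CPU when the MPS backend is not available.
has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = torch.device("mps" if has_mps else "cpu")

model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 4, device=device)
target = torch.randn(16, 1, device=device)

# One ordinary training step; only the device placement is MPS-specific.
loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```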

(Prototype) BetterTransformer: Fastpath execution for Transformer Encoder Inference

PyTorch now supports CPU and GPU fastpath implementations (“BetterTransformer”) for several Transformer Encoder modules, including TransformerEncoder, TransformerEncoderLayer, and MultiHeadAttention (MHA). The BetterTransformer fastpath architecture is consistently faster, with a 2x speedup in many common execution scenarios, depending on model and input characteristics. The new BetterTransformer-enabled modules are API compatible with previous releases of the PyTorch Transformer API, accelerate existing models if they meet the fastpath execution requirements, and can read models trained with previous versions of PyTorch. PyTorch 1.12 includes:

  • BetterTransformer integration for Torchtext’s pretrained RoBERTa and XLM-R models
  • Torchtext which builds on the PyTorch Transformer API
  • Fastpath execution for improved performance by reducing execution overheads with fused kernels, which combine multiple operators into a single kernel
  • Option to achieve additional speedups by taking advantage of data sparsity during the processing of padding tokens in natural-language processing (by setting enable_nested_tensor=True when creating a TransformerEncoder)
  • Diagnostics to help users understand why fastpath execution did not occur
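
The following sketch shows how an encoder might be set up so that it is eligible for the fastpath with the padding optimization: inference mode, `batch_first=True`, and `enable_nested_tensor=True`. The layer sizes and the padding mask are assumptions for illustration; whether the fastpath is actually taken depends on the PyTorch version and the conditions mentioned above.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=16, nhead=2, dim_feedforward=32, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True).eval()

src = torch.randn(2, 5, 16)  # (batch, sequence, features)
# True marks padding tokens; the second sequence has two padded positions.
padding_mask = torch.tensor([
    [False, False, False, False, False],
    [False, False, False, True, True],
])

with torch.no_grad():
    out = encoder(src, src_key_padding_mask=padding_mask)
```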


(Beta) Fully Sharded Data Parallel (FSDP) API

FSDP API helps easily scale large model training by sharding a model’s parameters, gradients and optimizer states across data parallel workers while maintaining the simplicity of data parallelism. The prototype version was released in PyTorch 1.11 with a minimum set of features that helped scaling tests of models with up to 1T parameters.

In this beta release, the FSDP API adds the following features to support various production workloads. Highlights of the newly added features include:

  1. Universal sharding strategy API – Users can easily change between sharding strategies with a single line change, and thus compare and use DDP (only data sharding), FSDP (full model and data sharding), or Zero2 (only sharding of optimizer and gradients) to optimize memory and performance for their specific training needs
  2. Fine grained mixed precision policies – Users can specify a mix of half and full data types (bfloat16, fp16 or fp32) for model parameters, gradient communication, and buffers via mixed precision policies. Models are automatically saved in fp32 to allow for maximum portability
  3. Transformer auto wrapping policy – allows for optimal wrapping of Transformer based models by registering the model’s layer class, and thus accelerates training performance
  4. Faster model initialization using device_id init – initialization is performed in a streaming fashion to avoid OOM issues and optimize init performance vs CPU init
  5. Rank0 streaming for full model saving of larger models – Fully sharded models can be saved by all GPUs streaming their shards to the rank 0 GPU, and the model is built in full state on the rank 0 CPU for saving
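
To make the first two items more concrete, here is a configuration-only sketch. It builds the keyword arguments that could be passed to `FullyShardedDataParallel`, but it does not create a process group or wrap a model. The dictionary names and the rough mapping of "DDP", "FSDP", and "Zero2" to enum values are assumptions for illustration.

```python
import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Rough mapping of the strategies named above to the enum values.
strategies = {
    "DDP-like": ShardingStrategy.NO_SHARD,
    "FSDP": ShardingStrategy.FULL_SHARD,
    "Zero2": ShardingStrategy.SHARD_GRAD_OP,
}

# Parameters, gradient communication, and buffers can use different dtypes.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.float32,
)

# Keyword arguments one would pass to FullyShardedDataParallel(model, **fsdp_kwargs)
# inside an initialized process group (not created here).
fsdp_kwargs = {
    "sharding_strategy": strategies["Zero2"],
    "mixed_precision": bf16_policy,
}
```

Switching strategy then amounts to changing the single `sharding_strategy` entry.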

For more details and example code, please check out the documentation and the tutorial.

Thanks for reading. If you’re interested in these updates and want to join the PyTorch community, we encourage you to join the discussion forums and open GitHub issues. To get the latest news from PyTorch, follow us on Twitter, Medium, YouTube, and LinkedIn.


Team PyTorch


How Computational Graphs are Executed in PyTorch


Welcome to the last entry in the series on understanding the autograd engine of PyTorch!
If you haven’t read parts 1 & 2 check them now to understand how PyTorch creates the computational graph for the backward pass!

This post is based on PyTorch v1.11, so some highlighted parts may differ across versions.

PyTorch autograd graph execution

The last post showed how PyTorch constructs the graph to calculate the outputs’ derivatives w.r.t. the inputs when executing the forward pass. Now we will see how the execution of the backward pass is coordinated and done by looking at the whole process, starting from Python down to the lower C++ level internals.

What Happens when Calling backward()/grad() from Python

Using variable.backward()

After doing all our calculations with an input set to require the gradient, we call .backward() on the result to initiate the backward pass execution.

>>> x = torch.tensor([0.5, 0.75], requires_grad=True)
>>> y = torch.exp(x).sum()
>>> y.backward()
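
As a quick sanity check of what this call produces, the sketch below compares the stored gradient with the analytic derivative. Since the derivative of sum(exp(x)) with respect to x is exp(x), the two should agree.

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)
y = torch.exp(x).sum()
y.backward()

# d/dx sum(exp(x)) = exp(x), so the stored gradient should match exp(x).
expected = torch.exp(x.detach())
matches = torch.allclose(x.grad, expected)
```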

Calling .backward() on a tensor results in a call to torch.autograd.backward().

# torch/

def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

torch.autograd.backward() checks the arguments and calls the autograd engine in the C++ layer.

def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensors] = None,
) -> None:

    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("'inputs' argument to backward() cannot be empty.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \
        tuple(inputs) if inputs is not None else tuple()

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_)
    if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors_, retain_graph, create_graph, inputs,
        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

First, whether the grad_tensors argument was specified or not, there is a call to the _make_grads function. This is used to check the provided grad_tensors or to specify the default value for them by looking at the tensors argument values’ shapes. Check the first blog post for details on the default value for the grad_tensors of the backward pass. This function just provides the vector of the vector jacobian product if it was not initially specified.
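
The effect of this default can be observed from Python. In the sketch below, a scalar output needs no explicit gradient, while a non-scalar output raises an error unless the vector for the vector-Jacobian product is supplied. Passing a vector of ones reproduces the scalar case here because the scalar was obtained with sum().

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)

# Scalar output: the default gradient is implicitly a tensor of ones.
torch.exp(x).sum().backward()
grad_scalar = x.grad.clone()

# Non-scalar output: no implicit default, so an explicit vector is required.
x.grad = None
y = torch.exp(x)
try:
    y.backward()
    raised = False
except RuntimeError:
    raised = True

y.backward(torch.ones_like(y))
grad_explicit = x.grad.clone()
```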

In the above code, Variable has an _execution_engine attribute that is defined in torch.autograd.variable to be of type ImperativeEngine, the C++ engine exported to Python and declared in torch/csrc/autograd/python_engine.cpp. In the following sections, we explain in detail how this object executes the backward pass.

Note that the torch.autograd.backward function has an inputs optional argument. This argument is used when we want to calculate the .grad field of only a subset of input tensors in the forward pass.

>>> x = torch.tensor([0.5, 0.75], requires_grad=True)
>>> y = torch.tensor([0.1, 0.90], requires_grad=True)
>>> z = torch.exp(x * y).sum()
>>> torch.autograd.backward([z], inputs=[x])
>>> x.grad
tensor([0.1051, 1.7676])
>>> y.grad  # None

Using torch.autograd.grad

An alternative to backward() is to use torch.autograd.grad(). The main difference to backward() is that grad() returns a tuple of tensors with the gradients of the outputs w.r.t. the inputs kwargs instead of storing them in the .grad field of the tensors. As you can see, the grad() code shown below is very similar to backward.

def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False,
    is_grads_batched: bool = False
) -> Tuple[torch.Tensor, ...]:
    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    overridable_args = outputs + inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
    grad_outputs_ = _make_grads(outputs, grad_outputs_)

    if retain_graph is None:
        retain_graph = create_graph

    if is_grads_batched:
        # …. It will not be covered here
    else:
        return Variable._execution_engine.run_backward(
            outputs, grad_outputs_, retain_graph, create_graph, inputs,
            allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
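
A short usage sketch of the difference, reusing the tensors from the earlier inputs example: grad() hands back the gradient as a tuple and leaves the .grad fields untouched.

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)
y = torch.tensor([0.1, 0.90], requires_grad=True)
z = torch.exp(x * y).sum()

# grad() returns the gradients instead of writing them to .grad.
(dz_dx,) = torch.autograd.grad(outputs=z, inputs=x)

# Analytic derivative: dz/dx = y * exp(x * y).
expected = (y * torch.exp(x * y)).detach()
```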

Figure 1 shows the computational graph with the backward() and grad() arguments highlighted in red and blue, respectively:

Figure 1: Correspondence of `backward`/`grad` arguments in the graphs.

Going Inside the Autograd Engine

Refreshing Concepts: Nodes and Edges

As we saw in part 2, the computational graph comprises Node and Edge objects. Please read that post if you haven’t done so yet.


Node objects are defined in torch/csrc/autograd/function.h, and they provide an overload of operator() for the associated function and a list of edges to do the graph traversal. Note that Node is a base class that autograd functions inherit from and override the apply method to execute the backward function.

struct TORCH_API Node : std::enable_shared_from_this<Node> {
  /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {

  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;
  edge_list next_edges_;
  uint64_t topological_nr_ = 0;

Every Node object has an attribute called topological_nr_. This number is used to optimize graph execution, as it allows the engine to discard graph branches under certain conditions. The topological number is the longest distance between this node and any leaf node, as shown in Figure 2. Its main property is that, for any pair of nodes x, y in a directed graph, topo_nr(x) < topo_nr(y) means that there is no path from x to y. This reduces the number of paths in the graph that need to be traversed. Check the topological_nr() method comment for further details.

Figure 2: Example of the Topological Number calculation
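
The definition and its main property can be sketched in a few lines of plain Python. The graph below is hypothetical (its node names are made up for illustration) and the code is not the engine's implementation; it only computes the longest distance to a leaf and checks the no-path property by brute force.

```python
from functools import lru_cache

# Hypothetical backward graph: each node maps to the nodes its edges point to.
# "acc_x" and "acc_y" stand in for leaf accumulators.
NEXT = {
    "sum": ["exp"],
    "exp": ["mul"],
    "mul": ["acc_x", "acc_y"],
    "acc_x": [],
    "acc_y": [],
}

@lru_cache(maxsize=None)
def topo_nr(node):
    """Longest distance from `node` to any leaf node."""
    children = NEXT[node]
    return 0 if not children else 1 + max(topo_nr(c) for c in children)

def has_path(src, dst):
    """Depth-first search for a path src -> dst."""
    return src == dst or any(has_path(c, dst) for c in NEXT[src])

# If topo_nr(x) < topo_nr(y), there can be no path from x to y.
pairs = [(a, b) for a in NEXT for b in NEXT if topo_nr(a) < topo_nr(b)]
property_holds = all(not has_path(a, b) for a, b in pairs)
```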


The Edge object links Nodes together, and its implementation is straightforward.

struct Edge {
  /// The function this `Edge` points to.
  std::shared_ptr<Node> function;
  /// The identifier of a particular input to the function.
  uint32_t input_nr;

It only requires a function pointer to the Node and an input number, which is the index of the output from the forward function this edge points to. When preparing the set of gradients before calling “function”, we know that what flows from this edge should be accumulated in the “input_nr”th argument. Note that the input/output naming is flipped here: this is the input to the backward function.
Edge objects are constructed using the gradient_edge function.

  Edge gradient_edge(const Variable& self) {
    if (const auto& gradient = self.grad_fn()) {
      return Edge(gradient, self.output_nr());
    } else {
      return Edge(grad_accumulator(self), 0);
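
These edges are visible from Python through the next_functions attribute of a grad_fn, which lists (node, input_nr) pairs. The sketch below walks the small graph used earlier from the root down to the gradient accumulator of the leaf tensor.

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)
y = torch.exp(x).sum()

# Walk the graph from the root node towards the leaf.
sum_node = y.grad_fn
exp_node, exp_input_nr = sum_node.next_functions[0]
acc_node, acc_input_nr = exp_node.next_functions[0]

node_names = [sum_node.name(), exp_node.name(), acc_node.name()]
```

The last node is the accumulator created for the leaf x, which matches the else branch of gradient_edge.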

Entering the C++ Realm

Once torch.autograd.backward() has been invoked, the THPEngine_run_backward routine starts the graph traversal. The following is a schema of the function body:

PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  // Convert the python arguments to C++ objects
  const char *accepted_kwargs[] = { // NOLINT
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", "accumulate_grad", nullptr
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad))

  // Prepare arguments
  for(const auto i : c10::irange(num_tensors)) {
    // Check that the tensors require gradients

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
     // Prepare outputs

      // Calls the actual autograd engine
    pybind11::gil_scoped_release no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
    // Clean up and finish

First, we prepare the input arguments after converting the PyObject arguments to actual C++ objects. The tensors list contains the tensors from which we start the backward pass. These tensors are converted to edges using torch::autograd::impl::gradient_edge and added to a list called roots where the graph traversal starts.

  edge_list roots;
  variable_list grads;
  for(const auto i : c10::irange(num_tensors)) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    const auto& variable = THPVariable_Unpack(_tensor);
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);

Now, if the inputs argument was specified in backward or we used the torch.autograd.grad api, the following code creates a list of edges to accumulate the gradients in the specified tensors at the end of the computation. The engine uses this later to optimize the execution as it doesn’t add the gradients in all the leaf nodes, just the specified ones.

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      const auto& tensor = THPVariable_Unpack(input);
      const auto output_nr = tensor.output_nr();
      auto grad_fn = tensor.grad_fn();
      if (!grad_fn) {
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
      if (accumulate_grad) {
      if (!grad_fn) {
        output_edges.emplace_back(std::make_shared<Identity>(), 0);
      } else {
        output_edges.emplace_back(grad_fn, output_nr);

The next step is the actual graph traversal and node function execution, and finally, the cleanup and return.

    // Calls the actual autograd engine
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  // Clean up and finish

Starting the Real Execution

engine.execute is defined in torch/csrc/autograd/engine.cpp.

There are two differentiated steps here:

  1. Analyze the graph to find dependencies between functions
  2. Create worker threads that traverse the graph

Data Structures Used for the Execution


All the execution metadata is managed by the GraphTask class in torch/csrc/autograd/engine.h

struct GraphTask: std::enable_shared_from_this<GraphTask> {
  std::atomic<uint64_t> outstanding_tasks_{0};
  //  … 
  std::unordered_map<Node*, InputBuffer> not_ready_;
  std::unordered_map<Node*, int> dependencies_;

  struct ExecInfo {
     // …
  std::unordered_map<Node*, ExecInfo> exec_info_;
  std::vector<Variable> captured_vars_;
  // …
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;

Here we see a series of variables dedicated to maintaining the execution state.
outstanding_tasks_ tracks the number of tasks left to be executed for the backward pass to complete. not_ready_ holds the input arguments for the Nodes that are not ready to be executed. dependencies_ tracks the number of predecessors that a Node has. When the count reaches 0, the Node is ready for execution; it is placed in a ready queue to be retrieved and executed later.

exec_info_ and the associated ExecInfo struct are used only when the inputs argument is specified or when it is a call to autograd.grad(). They allow filtering out paths in the graph that are not needed, since gradients are calculated only for the variables in the inputs list.

captured_vars_ is where the results of the graph execution are temporarily stored if we used the torch.autograd.grad() api instead of torch.autograd.backward() since grad() returns the gradients as tensors instead of just filling the .grad field of the inputs.


The NodeTask struct is a basic class that holds an fn_ pointer to the node to execute, and an inputs_ buffer to store the input arguments to this function. Note that the functions executed by the backward pass are the derivatives specified in the derivatives.yaml file, or the user-provided backward function when using custom functions, as described in the second blog post.

The inputs_ buffer is also where the output gradients of the previously executed functions are aggregated, and it is defined as a std::vector<Variable> container with facilities to accumulate values at a given position.

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here.  Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
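
The "implicit addition" behaviour of the inputs_ buffer can be sketched in plain Python. The class below is a hypothetical stand-in and not the real InputBuffer: it uses floats instead of tensors and only shows accumulation at a given position.

```python
class InputBufferSketch:
    """Hypothetical stand-in for InputBuffer: one slot per function input."""

    def __init__(self, size):
        self.slots = [None] * size

    def add(self, pos, grad):
        # The first gradient is stored; later ones are summed into the slot.
        if self.slots[pos] is None:
            self.slots[pos] = grad
        else:
            self.slots[pos] = self.slots[pos] + grad

buffer = InputBufferSketch(2)
buffer.add(0, 1.5)   # gradient arriving through one edge
buffer.add(0, 2.0)   # a second edge pointing at the same input
buffer.add(1, -1.0)
```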


The GraphRoot is a special function used to hold multiple input variables in a single place. The code is pretty simple as it only acts as a container of variables.

struct TORCH_API GraphRoot : public Node {
  GraphRoot(edge_list functions, variable_list inputs)
      : Node(std::move(functions)),
      outputs(std::move(inputs)) {
    for (const auto& t : outputs) {

  variable_list apply(variable_list&& inputs) override {
    return outputs;


The AccumulateGrad function is set during graph creation in gradient_edge when the Variable object doesn’t have a grad_fn, that is, when it is a leaf node.

    if (const auto& gradient = self.grad_fn()) {
      // …
    } else {
      return Edge(grad_accumulator(self), 0);

The function body is defined in torch/csrc/autograd/functions/accumulate_grad.cpp and it essentially accumulates the input grads in the object’s .grad attribute.

auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
  check_input_variables("AccumulateGrad", grads, 1, 0);

  at::Tensor new_grad = callHooks(variable, std::move(grads[0]));
  std::lock_guard<std::mutex> lock(mutex_);

  at::Tensor& grad = variable.mutable_grad();
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  return variable_list();
}} // namespace torch::autograd

The function does several checks on the tensors’ format and eventually performs the variable_grad += new_grad accumulation.
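
The accumulation is easy to observe from Python: running two backward passes over the same leaf adds the second gradient to the first instead of overwriting it. A minimal sketch:

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)

torch.exp(x).sum().backward()
first = x.grad.clone()

# A second backward pass adds to .grad rather than overwriting it.
torch.exp(x).sum().backward()
second = x.grad.clone()
```

This is why training loops usually reset gradients between iterations.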

Preparing the graph for execution

Now, let’s walk through Engine::execute. The first thing to do besides arguments consistency checks is to create the actual GraphTask object we described above. This object keeps all the metadata of the graph execution.

auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {

  validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
    return msg;

  // Checks

  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue);

  // If we receive a single root, skip creating extra root node
  // …
  // Prepare graph by computing dependencies
  // …
  // Queue the root 
  // …
  // launch execution
  // …

After creating the GraphTask, we use its associated function if we only have one root node. If we have multiple root nodes, we create a special GraphRoot object as described before.

  bool skip_dummy_node = roots.size() == 1;
  auto graph_root = skip_dummy_node ? :
    std::make_shared<GraphRoot>(roots, inputs);

The next step is to fill the dependencies_ map in the GraphTask object since the engine must know when it can execute a task. The outputs here is the inputs argument passed to the torch.autograd.backward() call in Python. But here, we have reversed the names since the gradients w.r.t. the inputs of the forward pass are now the outputs of the backward pass. And from now on, there is no concept of forward/backward, but only graph traversal and execution.

  auto min_topo_nr = compute_min_topological_nr(outputs);
  // Now compute the dependencies for all executable functions
  compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);

  if (!outputs.empty()) {
    graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);

Here we preprocess the graph for the execution of the nodes. First, compute_min_topological_nr is called to obtain the minimum topological number of the tensors specified in outputs (0 if no inputs kwarg was supplied to .backward or no inputs for .grad). This computation prunes paths in the graph that lead to input variables whose gradients we don’t want or need to calculate.

Second is the compute_dependencies call. This function is a very simple graph traversal that starts with the root Node, and for each of the edges in node.next_edges() it increments the counter in dependencies_. Figure 3 shows the result of the dependencies calculation for the example graph. Note that the number of dependencies of any node is just the number of edges arriving at it.

Figure 3: Number of dependencies for each node
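
The traversal can be sketched in plain Python as follows. The graph and its node names are hypothetical, and the topological-number pruning is omitted; the sketch only counts the edges arriving at each reachable node.

```python
from collections import deque

# Hypothetical graph: node -> nodes reached through its next edges.
NEXT_EDGES = {
    "root": ["a", "b"],
    "a": ["c"],
    "b": ["c"],
    "c": ["leaf"],
    "leaf": [],
}

def compute_dependencies(root):
    """Count, for every reachable node, how many edges arrive at it."""
    dependencies = {}
    seen = {root}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nxt in NEXT_EDGES[node]:
            dependencies[nxt] = dependencies.get(nxt, 0) + 1
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return dependencies

deps = compute_dependencies("root")
```

Node "c" ends up with two dependencies, so it can only run after both "a" and "b" have produced their gradients.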

Finally comes the init_to_execute call, which populates the GraphTask::exec_info_ map in case inputs were specified in the Python backward call. It iterates over the graph again, starting from the root, and records in the exec_info_ map the intermediate nodes needed to calculate only the gradients of the given inputs.

  // Queue the root
  if (skip_dummy_node) {
    InputBuffer input_buffer(>num_inputs());
    auto input =;


    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {
    execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
  // Avoid a refcount bump for the Future, since we check for refcount in
  // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
  // in dist_engine.cpp).
  auto& fut = graph_task->future_result_;
  return fut->value().toTensorVector();

And now, we are ready to start the actual execution by creating the InputBuffer. In case we only have one root variable, we begin by copying the value of the inputs tensor (these are the gradients passed to the Python backward call) into position 0 of the input_buffer. This is a small optimization that avoids running the GraphRoot for no reason. Also, if the rest of the graph is not on the CPU, we directly start on that worker, while the GraphRoot is always placed on the CPU ready queue. Details of the workers and ready queues are explained in the section below.

On the other hand, if we have multiple roots, the GraphRoot object also holds the inputs, so it is enough to pass it an empty InputBuffer.

Graph Traversal and Node Execution

Devices, Threads and Queues

Before diving into the actual execution, we need to see how the engine is structured.

First of all, the engine is multithreaded with one thread per device. For example, the caller thread is associated with the CPU while additional threads are created and associated with each GPU or other devices available in the system. Each thread tracks its device using thread-local storage in the worker_device variable. In addition, the threads have a queue of tasks to be executed also located in thread-local storage, the local_ready_queue. This is where work is queued for this thread to execute in the thread_main function that is explained later.
You may wonder how the device where a task should be executed is decided. The InputBuffer class has a device() function that returns the first non-CPU device of all its tensors.
This function is used together with Engine::ready_queue to select the queue to queue a task.

auto Engine::ready_queue(std::shared_ptr<ReadyQueue> cpu_ready_queue, at::Device device) -> std::shared_ptr<ReadyQueue>{
  if (device.type() == at::kCPU || device.type() == at::DeviceType::Meta) {
    return cpu_ready_queue;
  } else {
    // See Note [Allocating GPUs to autograd threads]

The ReadyQueue object is defined in torch/csrc/autograd/engine.h and it is a simple wrapper over std::priority_queue that allows a thread to wait for a task if it’s empty. One interesting property of the ReadyQueue is that it increases the GraphTask::outstanding_tasks_ value used to determine if the execution has completed or not.

auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
    std::lock_guard<std::mutex> lock(mutex_);
    if (incrementOutstandingTasks) {
      std::shared_ptr<GraphTask> graph_task = item.base_.lock();

auto ReadyQueue::pop() -> NodeTask {
  std::unique_lock<std::mutex> lock(mutex_);
  not_empty_.wait(lock, [this]{ return !heap_.empty(); });
  auto task = std::move(const_cast<NodeTask&>(; heap_.pop();
  return task;
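
The blocking behaviour of such a queue can be sketched in Python with heapq and a condition variable. The class below is a hypothetical illustration: the integer priority is an assumption, whereas the engine orders NodeTask objects with its own comparison.

```python
import heapq
import threading

class ReadyQueueSketch:
    """Blocking priority queue; the integer priority is an assumption."""

    def __init__(self):
        self._heap = []
        self._counter = 0  # tie-breaker that keeps insertion order
        self._not_empty = threading.Condition()

    def push(self, priority, task):
        with self._not_empty:
            heapq.heappush(self._heap, (priority, self._counter, task))
            self._counter += 1
            self._not_empty.notify()

    def pop(self):
        with self._not_empty:
            # Block until some thread pushes a task.
            self._not_empty.wait_for(lambda: len(self._heap) > 0)
            return heapq.heappop(self._heap)[2]

ready_queue = ReadyQueueSketch()
received = []
worker = threading.Thread(target=lambda: received.append(ready_queue.pop()))
worker.start()  # waits inside pop() if the queue is still empty
ready_queue.push(1, "first task")
worker.join()

ready_queue.push(2, "low priority")
ready_queue.push(1, "high priority")
ordered = [ready_queue.pop(), ready_queue.pop()]
```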

Reentrant Backward

A reentrant backward happens when one of the tasks in a backward pass calls backward again. It is not a very common case, but it can be used to reduce memory utilization as it could potentially avoid saving intermediate results. For more information, check this PyTorch forum post.

class ReentrantBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input.sum()

    @staticmethod
    def backward(ctx, input):
        # Let's compute the backward by using autograd
        input = input.detach().requires_grad_()
        with torch.enable_grad():
            out = input.sum()
        out.backward()  # REENTRANT CALL!!
        return out.detach()

Here, we call backward() inside backward() for a user custom-defined autograd function.
This situation can lead to deadlocks because the first backward needs to wait for the second one to complete. But some internal implementation details can prevent the second backward from completing as it is explained in the dedicated subsection.

Thread Initialization

execute_with_graph_task is in charge of initializing the threads taking care of the computation and placing the root node in the queue of the device that produced it.

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {

  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  if (worker_device == NO_DEVICE) {
    graph_task->owner_ = worker_device;
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
    worker_device = NO_DEVICE;
  } else {
     // This deals with reentrant backwards, we will see it later.
  return graph_task->future_result_;

First, this function initializes several threads (one per device) by calling initialize_device_threads_pool(), where several things happen:

  • One ReadyQueue per device is created.
  • One thread per non-cpu device is created.
  • A thread local worker_device variable is set to track the current device associated with the thread.
  • thread_main function is called, and threads wait for tasks to be put in their queues.

Then it retrieves the queue to place the root node based on the device that holds the tensors present in the input_buffer using the ready_queue function. Now, the main thread (the one also executing the Python interpreter) has its worker_device set to NO_DEVICE, and it is in charge of executing functions with all its tensors living on the CPU. If worker_device is set to any other value, the graph execution has already started, and .backward() was called inside a running Node, creating a reentrant backward call. This is explained later. For now, the main thread places the task in the queue and calls thread_main.

Where the Magic Happens

It’s been a long way, but finally, we are ready to traverse the graph and execute the nodes. Each of the spawned threads, as well as the main thread, calls thread_main.

auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {

  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    std::shared_ptr<GraphTask> local_graph_task;
      NodeTask task = local_ready_queue->pop();

      if (task.isShutdownTask_) {

      if (!(local_graph_task = task.base_.lock())) {
        // GraphTask for function is no longer valid, skipping further
        // execution.

      if (task.fn_ && !local_graph_task->has_error_.load()) {
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);

    // Decrement the outstanding tasks.

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      auto base_owner = local_graph_task->owner_;
      if (worker_device != base_owner) {
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));

The code here is simple, given the local_ready_queue assigned to each thread in thread-local storage. The threads loop until there are no tasks left to execute in the graph. Note that for device-associated threads, the passed graph_task argument is nullptr, and they block in local_ready_queue->pop() until a task is pushed to their queue. After some consistency checks (whether the task is a shutdown task, and whether the graph is still valid), we get to the actual function invocation in evaluate_function.

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);

After calling evaluate_function, we check whether the graph_task execution is complete by looking at the outstanding_tasks_ number. This number increases when a task is pushed to a queue and decreases once a task has been executed, right before the local_graph_task->completed() check. When the execution is done, we return the results stored in captured_vars_ in case we called torch.autograd.grad() instead of torch.autograd.backward(), as this function returns tensors instead of storing them in the .grad attribute of the inputs. Finally, we wake up the main thread if it’s waiting by sending a dummy task.

    // Decrement the outstanding tasks.

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      auto base_owner = local_graph_task->owner_;
      if (worker_device != base_owner) {
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));

Calling the Function and Unlocking New Tasks

evaluate_function serves three purposes:

  1. Run the function.
  2. Accumulate its results in the InputBuffers of the next nodes.
  3. Decrease the dependency counters of the next nodes and enqueue the tasks that reach 0 to be executed.
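
The three steps can be sketched as a tiny single-threaded engine in plain Python. Everything here is hypothetical and simplified: gradients are floats, every node has a single input slot, and the graph encodes z = x * x with x = 3, so the leaf accumulator receives two edges and should end up with dz/dx = 6.

```python
from collections import deque

x = 3.0
grads = {}  # plays the role of the .grad fields

def mul_backward(inputs):
    # z = x * x: one gradient per forward input.
    return [inputs[0] * x, inputs[0] * x]

def accumulate_x(inputs):
    grads["x"] = grads.get("x", 0.0) + inputs[0]
    return []  # leaf: nothing further to execute

# node name -> (backward function, next edges as (node, input_nr) pairs)
NODES = {
    "mul": (mul_backward, [("acc_x", 0), ("acc_x", 0)]),
    "acc_x": (accumulate_x, []),
}
dependencies = {"acc_x": 2}   # two edges arrive at acc_x
not_ready = {}                # node -> partially filled input buffer
ready = deque([("mul", [1.0])])
executed = []

def evaluate_function(name, inputs):
    fn, next_edges = NODES[name]
    outputs = fn(inputs)                      # 1. run the function
    executed.append(name)
    for output, (nxt, input_nr) in zip(outputs, next_edges):
        buffer = not_ready.setdefault(nxt, [0.0])
        buffer[input_nr] += output            # 2. accumulate into the next buffer
        dependencies[nxt] -= 1                # 3. one fewer pending predecessor
        if dependencies[nxt] == 0:
            ready.append((nxt, not_ready.pop(nxt)))

while ready:
    evaluate_function(*ready.popleft())
```

The accumulator runs only once, after both edges have delivered their contribution.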

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {

  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    // Checks if the function needs to be executed 
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.

  auto outputs = call_function(graph_task, func, inputs);

  auto& fn = *func;
  if (!graph_task->keep_graph_) {

Initially, we check the exec_info_ map of the GraphTask structure to determine if the current node needs to be executed. Remember that if this map is empty, all the nodes are executed because we are calculating the grads for all the inputs of the forward pass.
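
The rule can be summarized in a few lines of Python. This is only a sketch of the decision described above: ExecInfo and should_execute are hypothetical names, and treating a node that is missing from a non-empty map as "not needed" is a simplifying assumption of this example rather than the engine's behavior.

```python
from dataclasses import dataclass


@dataclass
class ExecInfo:
    needed: bool


def should_execute(exec_info, node):
    # An empty map means gradients are wanted for every input,
    # so every node in the graph is executed.
    if not exec_info:
        return True
    info = exec_info.get(node)
    # Assumption of this sketch: unknown nodes are skipped.
    return info is not None and info.needed


all_nodes = ["MulBackward", "AddBackward", "AccumulateGrad"]

run_everything = [n for n in all_nodes if should_execute({}, n)]

partial = {"MulBackward": ExecInfo(True), "AddBackward": ExecInfo(False)}
run_partial = [n for n in all_nodes if should_execute(partial, n)]
```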

After this check, the function is executed by running call_function. Its implementation is very straightforward and calls the actual derivative function and registered hooks if any.

  int num_outputs = outputs.size();
  if (num_outputs == 0) {
    // Records leaf stream (if applicable)

  if (AnomalyMode::is_enabled()) {
    // check for nan values in result

Next, we check the outputs of the function after call_function is done. If the number of outputs is 0, there are no following nodes to be executed so we can safely return. This is the case of the AccumulateGrad node associated with the leaf nodes.

Also, the check for NaN values in the gradients is done here if requested.

  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (const auto i : c10::irange(num_outputs)) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i);

    if (!next.is_valid()) continue;


We have now executed a grad_fn that has returned one gradient for each input of the associated forward pass function. As we saw in the previous blog post, we have an Edge object for each of these input tensors, pointing to the grad_fn of the function that produced them in the forward pass. Essentially, Output[0] of the node in the backward pass corresponds to the first argument of the associated forward pass function. Figure 4 shows how the outputs of a backward function are related to the inputs of the forward function. Note that the outputs of grad_fn C are the gradients of z w.r.t. the inputs of Function C.

Figure 4: Correspondence between forward and backward functions inputs and outputs

We now iterate through these edges and check if the associated functions are ready to be executed.

 // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get());

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      is_ready = true;

    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());

For this, we check the graph_task->dependencies_ map. We decrement the counter, and if it reaches 0, we mark the function pointed to by the edge as ready to be executed. Next, we prepare the input buffers of the tasks indicated by the next edges.

    if (not_ready_it == not_ready.end()) {
      if (!exec_info_.empty()) {
        // Skip functions that aren't supposed to be executed

      // Creates an InputBuffer and moves the output to the corresponding input position
      InputBuffer input_buffer(next.function->num_inputs());

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));

Here, we look for the task in the graph_task->not_ready_ map. If it is not present, we create a new InputBuffer object and set the current output in the input_nr position of the buffer associated with the edge. If the task is ready to be executed, we enqueue it in the appropriate device ready_queue and complete the execution; otherwise, we store its buffer in the not_ready_ map. If the task is not ready and we have seen it before, it is already present in the not_ready_ map.

    } else {
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;
      // Accumulates into buffer
      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(NodeTask(graph_task, next.function, std::move(input_buffer)));

In this case, we accumulate the output in the existing input_buffer instead of creating a new one. Once all the tasks are processed, the worker thread exits the loop and completes.
All this process is summarized in the animation in Figure 5. We see how a thread peeks at the tasks in the ready queue and decrements the next nodes’ dependencies, unlocking them for execution.

Figure 5: Animation of the execution of the computational graph
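
The bookkeeping described in this section (dependency counters, the not_ready map, and the ready queue) can be sketched in plain Python. This is a toy model under stated assumptions, not the engine's implementation: the graph is a hand-written diamond, every node passes its incoming gradient through unchanged, gradients are plain floats, and the names next_edges, dependencies, and run_backward are invented for the example.

```python
from collections import deque

# Backward graph: each node lists the nodes its outputs flow into.
next_edges = {"root": ["a", "b"], "a": ["c"], "b": ["c"], "c": []}
# Number of gradients each node must receive before it can run.
dependencies = {"a": 1, "b": 1, "c": 2}


def run_backward(root, grad):
    remaining = dict(dependencies)
    ready = deque([(root, grad)])
    not_ready = {}  # node -> partially accumulated input buffer
    order = []
    received = {}
    while ready:
        node, grad_in = ready.popleft()
        order.append(node)
        received[node] = grad_in
        for nxt in next_edges[node]:
            # Decrement the counter; at zero the node is ready.
            remaining[nxt] -= 1
            is_ready = remaining[nxt] == 0
            if nxt not in not_ready:
                # First gradient seen for this node: new buffer.
                buffer = grad_in
            else:
                # Node already has a buffer: accumulate into it.
                buffer = not_ready.pop(nxt) + grad_in
            if is_ready:
                ready.append((nxt, buffer))
            else:
                not_ready[nxt] = buffer
    return order, received


order, received = run_backward("root", 1.0)
```

Node c receives contributions from both a and b, so it only runs after both have executed, and its input is the sum of the two incoming gradients.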

Flow with Reentrant Backward

As we saw above, the reentrant backward problem arises when the currently executed function makes a nested call to backward. When this happens, the thread running this function goes all the way down to execute_with_graph_task as in the non-reentrant case, but this is where things differ.

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {

  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  if (worker_device == NO_DEVICE) {
    //Regular case
  } else {
    // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
    //    backward call from that device.
    graph_task->owner_ = worker_device;

    // Now that all the non-thread safe fields of the graph_task have been populated,
    // we can enqueue it.
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));

    if (current_depth >= max_recursion_depth_) {
      // If reached the max depth, switch to a different thread
    } else {
  return graph_task->future_result_;

Here, execute_with_graph_task detects this as a reentrant call and then checks the current number of nested calls. If it exceeds the limit, we create a new thread to take care of the execution of this graph, and if not, we execute this reentrant call regularly.
The limit of nested calls was originally set to avoid stack overflows due to reentrant calls creating very large call stacks. However, the number was further reduced when sanitizer tests were added because of the maximum number of locks a thread can hold at a given moment. This can be seen in torch/csrc/autograd/engine.h.
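
The idea of switching threads once a nesting limit is reached can be illustrated with a small Python sketch. It only mimics the control flow: MAX_RECURSION_DEPTH is an arbitrary value chosen for the example, reentrant_backward is a hypothetical function standing in for a backward call that triggers another backward call, and a fresh thread is started on every hand-off instead of using a shared pool.

```python
import threading

MAX_RECURSION_DEPTH = 3  # illustrative value, not the engine's limit
_state = threading.local()
threads_used = set()


def reentrant_backward(levels_left):
    """Simulates a backward call that performs a nested backward call."""
    threads_used.add(threading.get_ident())
    if levels_left == 0:
        return 0
    depth = getattr(_state, "depth", 0)
    if depth >= MAX_RECURSION_DEPTH:
        # Too deep: run the nested call on a new thread, which starts
        # with an empty call stack, and wait for its result.
        result = []
        worker = threading.Thread(
            target=lambda: result.append(reentrant_backward(levels_left - 1))
        )
        worker.start()
        worker.join()
        return 1 + result[0]
    # Regular case: nest on the current thread and track the depth.
    _state.depth = depth + 1
    try:
        return 1 + reentrant_backward(levels_left - 1)
    finally:
        _state.depth = depth


total_calls = reentrant_backward(10)
```

With a limit of 3 and 10 nested levels, the work ends up spread over three threads, and no single thread nests deeper than the limit.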

When this maximum depth is exceeded, a new thread is created with the add_thread_pool_task function.

void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
  std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_);
  // if we have pending graph_task objects to be processed, create a worker.
   bool create_thread = (thread_pool_shared_->num_workers_ <= thread_pool_shared_->graphtasks_queue_.size());

  if (create_thread) {
    std::thread t(&Engine::reentrant_thread_init, this);


Before going in-depth, let’s look at the thread_pool_shared_ object in the Engine, which manages all the information related to the threads associated with the reentrant backward calls.

  struct ThreadPoolShared {
    unsigned int num_workers_;
    std::condition_variable work_;
    std::mutex mutex_;
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    ThreadPoolShared() : num_workers_(0) {}

ThreadPoolShared is a simple container holding a queue of GraphTask objects with synchronization mechanisms and the number of current workers.

Now it is easy to understand how add_thread_pool_task creates a thread when there are graph_task objects enqueued and insufficient workers to process them.

add_thread_pool_task initializes a thread by executing reentrant_thread_init.

void Engine::reentrant_thread_init() {
  auto tp_shared = thread_pool_shared_;
  while(true) {
    std::unique_lock<std::mutex> lk(tp_shared->mutex_);
    tp_shared->work_.wait(lk, [&tp_shared]{ return !tp_shared->graphtasks_queue_.empty();});
    auto task = tp_shared->graphtasks_queue_.front();
    std::shared_ptr<GraphTask> graph_task;
    if (!(graph_task = task.lock())) {
    // set the local_ready_queue to the ready queue on the graph_task->owner_ device
    local_ready_queue = ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
    total_depth = graph_task->reentrant_depth_;

The code is straightforward. The newly created thread waits on the thread_pool_shared_->graphtasks_queue_ for reentrant backward graphs to be available and executes them. Notice that this thread uses the task-ready queue associated with the device of the thread that started this call by accessing the graph_task->owner_ field set in the execute_with_graph_task function.
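
The interplay between add_thread_pool_task and the waiting worker can be sketched with Python's threading primitives. This is an illustration under assumptions rather than a port of the engine: tasks are plain callables, num_workers counts idle workers in this sketch, and the None sentinel plus the shutdown helper exist only so that the example terminates.

```python
import threading
from collections import deque


class ThreadPoolShared:
    def __init__(self):
        self.num_workers = 0  # workers currently waiting for work
        self.work = threading.Condition()
        self.graphtasks_queue = deque()


pool = ThreadPoolShared()
workers = []
results = []


def worker():
    while True:
        with pool.work:
            pool.num_workers += 1
            # Block until a task is available in the shared queue.
            pool.work.wait_for(lambda: len(pool.graphtasks_queue) > 0)
            pool.num_workers -= 1
            task = pool.graphtasks_queue.popleft()
        if task is None:  # sentinel used only by this sketch
            return
        results.append(task())


def add_thread_pool_task(task):
    with pool.work:
        # Create a worker if the waiting ones cannot cover the queue.
        create_thread = pool.num_workers <= len(pool.graphtasks_queue)
        pool.graphtasks_queue.append(task)
        pool.work.notify()
    if create_thread:
        thread = threading.Thread(target=worker)
        thread.start()
        workers.append(thread)


def shutdown():
    with pool.work:
        pool.graphtasks_queue.extend([None] * len(workers))
        pool.work.notify_all()
    for thread in workers:
        thread.join()


add_thread_pool_task(lambda: "graph 1 done")
add_thread_pool_task(lambda: "graph 2 done")
shutdown()
```

How many threads get created depends on timing, so the sketch makes no claim about that number; it only guarantees that every queued task is eventually picked up by some worker.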

Error Handling

Whenever an error happens in one of the worker threads, it is propagated to the thread that called backward.

To achieve this, there is a try/catch block in the thread_main that catches any exception in the Node function call and sets it to the associated GraphTask object.

       try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);

thread_on_exception and the functions it calls end up setting the exception in the local_graph_task object.

void Engine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  graph_task->set_exception(std::current_exception(), fn);

void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
  if (!has_error_.exchange(true)) {
    if (AnomalyMode::is_enabled() && fn) {

void GraphTask::set_exception(
    std::exception_ptr eptr,
    const std::shared_ptr<Node>& fn) {
  if (!future_completed_.exchange(true)) {
    // NOLINTNEXTLINE(performance-move-const-arg)

set_exception sets the has_error_ flag to true and calls the setError function of the future_result_ object. This causes the error to be re-thrown in the caller thread when future_result_->value() is accessed.

 IValue value() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (eptr_) {
    return value_;
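
The same pattern, storing an exception in a future on the worker side and re-raising it on the caller side, can be shown with Python's standard library. Note that this uses concurrent.futures.Future rather than the engine's own future type, and that thread_main and failing_node are hypothetical names used only in this sketch.

```python
import threading
from concurrent.futures import Future


def thread_main(future, fn):
    try:
        result = fn()
    except Exception as exc:
        # Store the error in the future instead of losing it
        # together with the worker thread.
        future.set_exception(exc)
    else:
        future.set_result(result)


def failing_node():
    raise RuntimeError("error raised inside a backward function")


future_result = Future()
worker = threading.Thread(target=thread_main, args=(future_result, failing_node))
worker.start()
worker.join()

try:
    # Accessing the result re-raises the stored exception
    # in the calling thread.
    future_result.result()
    caught = None
except RuntimeError as exc:
    caught = str(exc)
```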

Closing Remarks

This has been the last post of this series covering how PyTorch performs automatic differentiation. We hope you enjoyed reading it and that you are now familiar enough with PyTorch internals to start contributing to PyTorch development!

Geospatial deep learning with TorchGeo

TorchGeo is a PyTorch domain library providing datasets, samplers, transforms, and pre-trained models specific to geospatial data.

For decades, Earth observation satellites, aircraft, and more recently UAV platforms have been collecting increasing amounts of imagery of the Earth’s surface. With information about seasonal and long-term trends, remotely sensed imagery can be invaluable for solving some of the greatest challenges to humanity, including climate change adaptation, natural disaster monitoring, water resource management, and food security for a growing global population. From a computer vision perspective, this includes applications like land cover mapping (semantic segmentation), deforestation and flood monitoring (change detection), glacial flow (pixel tracking), hurricane tracking and intensity estimation (regression), and building and road detection (object detection, instance segmentation). By leveraging recent advancements in deep learning architectures, cheaper and more powerful GPUs, and petabytes of freely available satellite imagery datasets, we can come closer to solving these important problems.

National Oceanic and Atmospheric Administration satellite image of Hurricane Katrina, taken on August 28, 2005 (source). Geospatial machine learning libraries like TorchGeo can be used to detect, track, and predict future trajectories of hurricanes and other natural disasters.

The challenges

In traditional computer vision datasets, such as ImageNet, the image files themselves tend to be rather simple and easy to work with. Most images have 3 spectral bands (RGB), are stored in common file formats like PNG or JPEG, and can be easily loaded with popular software libraries like PIL or OpenCV. Each image in these datasets is usually small enough to pass directly into a neural network. Furthermore, most of these datasets contain a finite number of well-curated images that are assumed to be independent and identically distributed, making train-val-test splits straightforward. As a result of this relative homogeneity, the same pre-trained models (e.g., CNNs pretrained on ImageNet) have shown to be effective across a wide range of vision tasks using transfer learning methods. Existing libraries, such as torchvision, handle these simple cases well, and have been used to make large advances in vision tasks over the past decade.

Remote sensing imagery is not so uniform. Instead of simple RGB images, satellites tend to capture images that are multispectral (Landsat 8 has 11 spectral bands) or even hyperspectral (Hyperion has 242 spectral bands). These images capture information at a wider range of wavelengths (400 nm–15 µm), far outside of the visible spectrum. Different satellites also have very different spatial resolutions: GOES has a resolution of 4 km/px, Maxar imagery is 30 cm/px, and drone imagery resolution can be as high as 7 mm/px. These datasets almost always have a temporal component, with satellite revisits that are daily, weekly, or biweekly. Images often overlap with other images in the dataset, and need to be stitched together based on geographic metadata. These images tend to be very large (e.g., 10K x 10K pixels), so it isn’t possible to pass an entire image through a neural network. This data is distributed in hundreds of different raster and vector file formats like GeoTIFF and ESRI Shapefile, requiring specialty libraries like GDAL to load.

From left to right: Mercator, Albers Equal Area, and Interrupted Goode Homolosine projections (source). Geospatial data is associated with one of many different types of reference systems that project the 3D Earth onto a 2D representation. Combining data from different sources often involves re-projecting to a common reference system in order to ensure that all layers are aligned.

Although each image is 2D, the Earth itself is 3D. In order to stitch together images, they first need to be projected onto a 2D representation of the Earth, called a coordinate reference system (CRS). Most people are familiar with equal angle representations like Mercator that distort the size of regions (Greenland looks larger than Africa even though Africa is 15x larger), but there are many other CRSs that are commonly used. Each dataset may use a different CRS, and each image within a single dataset may also be in a unique CRS. In order to use data from multiple layers, they must all share a common CRS, otherwise the data won’t be properly aligned. For those who aren’t familiar with remote sensing data, this can be a daunting task.

Even if you correctly georeference images during indexing, if you don’t project them to a common CRS, you’ll end up with rotated images with nodata values around them, and the images won’t be pixel-aligned.
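
To get a feel for how strongly a projection can distort sizes, the following sketch computes the factor by which the Mercator projection inflates areas at a given latitude. On a spherical Earth model the linear scale grows as 1/cos(latitude), so areas grow as 1/cos²(latitude). The function name and the sample latitudes are choices made for this illustration.

```python
import math


def mercator_area_inflation(latitude_deg):
    """Factor by which Mercator inflates areas at a given latitude."""
    return 1.0 / math.cos(math.radians(latitude_deg)) ** 2


# No distortion at the equator, large distortion near the poles.
equator = mercator_area_inflation(0.0)
high_latitude = mercator_area_inflation(72.0)
```

At 72°N, a latitude that crosses Greenland, areas appear roughly ten times larger than at the equator, which is why regions at different latitudes cannot be compared by eye on a Mercator map.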

The solution

At the moment, it can be quite challenging to work with both deep learning models and geospatial data without having expertise in both of these very different fields. To address these challenges, we’ve built TorchGeo, a PyTorch domain library for working with geospatial data. TorchGeo is designed to make it simple:

  1. for machine learning experts to work with geospatial data, and
  2. for remote sensing experts to explore machine learning solutions.

TorchGeo is not just a research project, but a production-quality library that uses continuous integration to test every commit with a range of Python versions on a range of platforms (Linux, macOS, Windows). It can be easily installed with any of your favorite package managers, including pip, conda, and spack:

$ pip install torchgeo

TorchGeo is designed to have the same API as other PyTorch domain libraries like torchvision, torchtext, and torchaudio. If you already use torchvision in your workflow for computer vision datasets, you can switch to TorchGeo by changing only a few lines of code. All TorchGeo datasets and samplers are compatible with the PyTorch DataLoader class, meaning that you can take advantage of wrapper libraries like PyTorch Lightning for distributed training. In the following sections, we’ll explore possible use cases for TorchGeo to show how simple it is to use.

Geospatial datasets and samplers

Example application in which we combine A) a scene from Landsat 8 and B) Cropland Data Layer labels, even though these files are in different EPSG projections. We want to sample patches C) and D) from these datasets using a geospatial bounding box as an index.

Many remote sensing applications involve working with geospatial datasets —datasets with geographic metadata. In TorchGeo, we define a GeoDataset class to represent these kinds of datasets. Instead of being indexed by an integer, each GeoDataset is indexed by a spatiotemporal bounding box, meaning that two or more datasets covering a different geographic extent can be intelligently combined.

In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we’ll only use the bands that both satellites have in common. We’ll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets.

from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, Landsat7, Landsat8, stack_samples
from torchgeo.samplers import RandomGeoSampler

landsat7 = Landsat7(root="...")
landsat8 = Landsat8(root="...", bands=Landsat8.all_bands[1:-2])
landsat = landsat7 | landsat8

Next, we take the intersection between this dataset and the CDL dataset. We want to take the intersection instead of the union to ensure that we only sample from regions where we have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different CRSs or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used.

cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl

This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire contiguous United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of samplers. In this example, we’ll use a random sampler that returns 256 x 256 pixel images and 10,000 samples per epoch. We’ll also use a custom collation function to combine each sample dictionary into a mini-batch of samples.

sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)

This data loader can now be used in your normal training/evaluation pipeline.

for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]

    # train a model, or make predictions using a pre-trained model

Many applications involve intelligently composing datasets based on geospatial metadata like this. For example, users may want to:

  • Combine datasets for multiple image sources and treat them as equivalent (e.g., Landsat 7 and 8)
  • Combine datasets for disparate geospatial locations (e.g., Chesapeake NY and PA)

These combinations require that all queries are present in at least one dataset, and can be created using a UnionDataset. Similarly, users may want to:

  • Combine image and target labels and sample from both simultaneously (e.g., Landsat and CDL)
  • Combine datasets for multiple image sources for multimodal learning or data fusion (e.g., Landsat and Sentinel)

These combinations require that all queries are present in both datasets, and can be created using an IntersectionDataset. TorchGeo automatically composes these datasets for you when you use the intersection (&) and union (|) operators.
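
The difference between the two compositions comes down to how spatiotemporal extents are combined. The sketch below illustrates this with a hypothetical Box class written for this example; it is not TorchGeo's own bounding box type, and the coordinates are made-up numbers.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Box:
    """Hypothetical spatiotemporal extent: x, y, and time ranges."""

    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float

    def __and__(self, other: "Box") -> Optional["Box"]:
        # Intersection: the region covered by both extents.
        box = Box(
            max(self.minx, other.minx), min(self.maxx, other.maxx),
            max(self.miny, other.miny), min(self.maxy, other.maxy),
            max(self.mint, other.mint), min(self.maxt, other.maxt),
        )
        empty = box.minx > box.maxx or box.miny > box.maxy or box.mint > box.maxt
        return None if empty else box

    def __or__(self, other: "Box") -> "Box":
        # Union: the smallest extent that contains both.
        return Box(
            min(self.minx, other.minx), max(self.maxx, other.maxx),
            min(self.miny, other.miny), max(self.maxy, other.maxy),
            min(self.mint, other.mint), max(self.maxt, other.maxt),
        )


imagery = Box(0.0, 10.0, 0.0, 10.0, 0.0, 100.0)
labels = Box(5.0, 20.0, 5.0, 20.0, 50.0, 150.0)

overlap = imagery & labels   # where both imagery and labels exist
combined = imagery | labels  # everything covered by either one
disjoint = Box(0.0, 1.0, 0.0, 1.0, 0.0, 1.0) & Box(2.0, 3.0, 2.0, 3.0, 2.0, 3.0)
```

Sampling from the overlap guarantees that every patch has both an image and a label, which is why the image and label datasets are combined with an intersection rather than a union.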

Multispectral and geospatial transforms

In deep learning, it’s common to augment and transform the data so that models are robust to variations in the input space. Geospatial data can have variations such as seasonal changes and warping effects, as well as image processing and capture issues like cloud cover and atmospheric distortion. TorchGeo utilizes augmentations and transforms from the Kornia library, which supports GPU acceleration and supports multispectral imagery with more than 3 channels.

Traditional geospatial analyses compute and visualize spectral indices which are combinations of multispectral bands. Spectral indices are designed to highlight areas of interest in a multispectral image relevant to some application, such as vegetation health, areas of man-made change or increasing urbanization, or snow cover. TorchGeo supports numerous transforms, which can compute common spectral indices and append them as additional bands to a multispectral image tensor.

Below, we show a simple example where we compute the Normalized Difference Vegetation Index (NDVI) on a Sentinel-2 image. NDVI measures the presence of vegetation and vegetation health and is computed as the normalized difference between the red and near-infrared (NIR) spectral bands. Spectral index transforms operate on sample dictionaries returned from TorchGeo datasets and append the resulting spectral index to the image channel dimension.

First, we instantiate a Sentinel-2 dataset and load a sample image. Then, we plot the true color (RGB) representation of this data to see the region we are looking at.

import matplotlib.pyplot as plt
from torchgeo.datasets import Sentinel2
from torchgeo.transforms import AppendNDVI

dataset = Sentinel2(root="...")
sample = dataset[...]
fig = dataset.plot(sample)

Next, we instantiate and compute an NDVI transform, appending this new channel to the end of the image. Sentinel-2 imagery uses index 0 for its red band and index 3 for its NIR band. In order to visualize the data, we also normalize the image. NDVI values can range from -1 to 1, but we want to use the range 0 to 1 for plotting.

transform = AppendNDVI(index_red=0, index_nir=3)
sample = transform(sample)
sample["image"][-1] = (sample["image"][-1] + 1) / 2
plt.imshow(sample["image"][-1], cmap="RdYlGn_r")

True color (left) and NDVI (right) of the Texas Hill Region, taken on November 16, 2018 by the Sentinel-2 satellite. In the NDVI image, red indicates water bodies, yellow indicates barren soil, light green indicates unhealthy vegetation, and dark green indicates healthy vegetation.
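
For reference, the index itself is a one-line formula, NDVI = (NIR - Red) / (NIR + Red). The sketch below computes it for single pixels in plain Python and applies the same rescaling to the 0 to 1 range used for plotting. The reflectance values are made-up numbers chosen for illustration, and returning 0 when both bands are 0 is a convention of this sketch.

```python
def ndvi(red, nir):
    """Normalized difference between the NIR and red bands."""
    total = nir + red
    if total == 0:
        return 0.0  # convention of this sketch for empty pixels
    return (nir - red) / total


def to_unit_range(value):
    # Map NDVI from [-1, 1] to [0, 1] for plotting.
    return (value + 1) / 2


# Hypothetical reflectances: vegetation reflects strongly in NIR,
# while water reflects very little NIR.
vegetation = ndvi(red=0.05, nir=0.45)
water = ndvi(red=0.10, nir=0.02)

vegetation_plot = to_unit_range(vegetation)
water_plot = to_unit_range(water)
```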

Benchmark datasets

One of the driving factors behind progress in computer vision is the existence of standardized benchmark datasets like ImageNet and MNIST. Using these datasets, researchers can directly compare the performance of different models and training procedures to determine which perform the best. In the remote sensing domain, there are many such datasets, but due to the aforementioned difficulties of working with this data and the lack of existing libraries for loading these datasets, many researchers opt to use their own custom datasets.

One of the goals of TorchGeo is to provide easy-to-use data loaders for these existing datasets. TorchGeo includes a number of benchmark datasets —datasets that include both input images and target labels. This includes datasets for tasks like image classification, regression, semantic segmentation, object detection, instance segmentation, change detection, and more.

If you’ve used torchvision before, these types of datasets should be familiar. In this example, we’ll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.

from torch.utils.data import DataLoader
from torchgeo.datasets import VHR10

dataset = VHR10(root="...", download=True, checksum=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

for batch in dataloader:
    image = batch["image"]
    label = batch["label"]

    # train a model, or make predictions using a pre-trained model

All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch Tensor.
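
Working with dictionary samples mostly affects how a batch is assembled. As a rough sketch of the idea, the hypothetical collate_samples function below groups the values of each key across samples. It uses plain Python lists instead of tensors and is not the collation function shipped with TorchGeo.

```python
def collate_samples(samples):
    """Turn a list of sample dictionaries into a dictionary of lists."""
    batch = {}
    for sample in samples:
        for key, value in sample.items():
            batch.setdefault(key, []).append(value)
    return batch


# Made-up samples: tiny "images" and integer labels.
samples = [
    {"image": [0.1, 0.2], "label": 3},
    {"image": [0.3, 0.4], "label": 7},
]
batch = collate_samples(samples)
```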

Example predictions from a Mask R-CNN model trained on the NWPU VHR-10 dataset. The model predicts sharp bounding boxes and masks for all objects with high confidence scores.

Reproducibility with PyTorch Lightning

Another key goal of TorchGeo is reproducibility. For many of these benchmark datasets, there is no predefined train-val-test split, or the predefined split has issues with class imbalance or geographic distribution. As a result, the performance metrics reported in the literature either can’t be reproduced, or aren’t indicative of how well a pre-trained model would work in a different geographic location.
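
One simple ingredient of a reproducible split is a fixed random seed, so that the same indices always land in the same subset. The sketch below shows this with the standard library. The split_indices function, the subset sizes, and the seed are assumptions made for the example, and a purely random split like this does not address the class imbalance or geographic distribution issues mentioned above.

```python
import random


def split_indices(n, n_val, n_test, seed):
    """Deterministically split range(n) into train/val/test indices."""
    indices = list(range(n))
    # A dedicated generator keeps the split independent of global state.
    random.Random(seed).shuffle(indices)
    return {
        "val": indices[:n_val],
        "test": indices[n_val:n_val + n_test],
        "train": indices[n_val + n_test:],
    }


first = split_indices(100, n_val=10, n_test=20, seed=0)
second = split_indices(100, n_val=10, n_test=20, seed=0)
```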

In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created PyTorch Lightning datamodules with well-defined train-val-test splits and trainers for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the Kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the Inria Aerial Image Labeling dataset is as easy as a few imports and four lines of code.

from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask

datamodule = InriaAerialImageLabelingDataModule(root_dir="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(model="resnet18", pretrained=True, learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)

Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset. Reproducing these results is as simple as a few imports and four lines of code, making comparison of different models and training techniques simple and easy.

In our preprint we show a set of results that use the aforementioned datamodules and trainers to benchmark simple modeling approaches for several of the datasets in TorchGeo. For example, we find that a simple ResNet-50 can achieve state-of-the-art performance on the So2Sat dataset. These types of baseline results are important for evaluating the contribution of different modeling choices when tackling problems with remotely sensed data.

Future work and contributing

There is still a lot of remaining work to be done in order to make TorchGeo as easy to use as possible, especially for users without prior deep learning experience. One of the ways in which we plan to achieve this is by expanding our tutorials to include subjects like “writing a custom dataset” and “transfer learning”, or tasks like “land cover mapping” and “object detection”.

Another important project we are working on is pre-training models. Most remote sensing researchers work with very small labeled datasets, and could benefit from pre-trained models and transfer learning approaches. TorchGeo is the first deep learning library to provide models pre-trained on multispectral imagery. Our goal is to provide models for different image modalities (optical, SAR, multispectral) and specific platforms (Landsat, Sentinel, MODIS) as well as benchmark results showing their performance with different amounts of training data. Self-supervised learning is a promising method for training such models. Satellite imagery datasets often contain petabytes of imagery, but accurately labeled datasets are much harder to come by. Self-supervised learning methods will allow us to train directly on the raw imagery without needing large labeled datasets.

Aside from these larger projects, we’re always looking to add new datasets, data augmentation transforms, and sampling strategies. If you’re Python savvy and interested in contributing to TorchGeo, we would love to see contributions! TorchGeo is open source under an MIT license, so you can use it in almost any project.

If you like TorchGeo, give us a star on GitHub! And if you use TorchGeo in your work, please cite our paper.


We would like to thank all TorchGeo contributors for their efforts in creating the library, the Microsoft AI for Good program for support, and the PyTorch Team for their guidance. This research is part of the Blue Waters sustained-petascale computing project, which is supported by the National Science Foundation (awards OCI-0725070 and ACI-1238993), the State of Illinois, and as of December, 2019, the National Geospatial-Intelligence Agency. Blue Waters is a joint effort of the University of Illinois at Urbana-Champaign and its National Center for Supercomputing Applications. The research was supported in part by NSF grants IIS-1908104, OAC-1934634, and DBI-2021898.

How Disney Improved Activity Recognition Through Multimodal Approaches with PyTorch

Among the many things Disney Media & Entertainment Distribution (DMED) is responsible for is the management and distribution of a huge array of media assets, including news, sports, entertainment and features, episodic programs, marketing and advertising, and more.

Our team focuses on media annotation as part of DMED Technology’s content platforms group. In our day-to-day work, we automatically analyze a variety of content that constantly challenges the efficiency of our machine learning workflow and the accuracy of our models.

Several of our colleagues recently discussed the workflow efficiencies that we achieved by switching to an end-to-end video analysis pipeline using PyTorch, as well as how we approach animated character recognition. We invite you to read more about both in this previous post.

While the conversion to an end-to-end PyTorch pipeline is a solution that any company might benefit from, animated character recognition was a uniquely Disney concept and solution.

In this article we will focus on activity recognition, which is a general challenge across industries — but with some specific opportunities when leveraged in the media production field, because we can combine audio, video, and subtitles to provide a solution.

Experimenting with Multimodality

Working on a multimodal problem adds more complexity to the usual training pipelines. Having multiple information modes for each example means that the multimodal pipeline has to have specific implementations to process each mode in the dataset. Usually after this processing step, the pipeline has to merge or fuse the outputs.

Our initial experiments in multimodality were completed using the MMF framework. MMF is a modular framework for vision and language multimodal research. MMF contains reference implementations of state-of-the-art vision and language models and has also powered multiple research projects at Meta AI Research (as seen in this poster presented at PyTorch Ecosystem Day 2020). Along with the recent release of TorchMultimodal, a PyTorch library for training state-of-the-art multimodal models at scale, MMF highlights the growing interest in multimodal understanding.

MMF tackles this complexity with modular management of all the elements of the pipeline through a wide set of different implementations for specific modules, ranging from the processing of the modalities to the fusion of the processed information.

In our scenario, MMF was a great entry point to experiment with multimodality. It allowed us to iterate quickly by combining audio, video and closed captioning and experiment at different levels of scale with certain multimodal models, shifting from a single GPU to TPU Pods.

Multimodal Transformers

With a workbench based on MMF, our initial model was based on a concatenation of features from each modality. It then evolved into a pipeline that included a Transformer-based fusion module to combine the different input modes.

Specifically, we made use of the fusion module called MMFTransformer, developed in collaboration with the Meta AI Research team. This is an implementation based on VisualBERT for which the necessary modifications were added to be able to work with text, audio and video.

Despite decent results with the out-of-the-box MMFTransformer implementation, we were still far from our goal, and the Transformer-based models required more data than we had available.

Searching for less data-hungry solutions

Searching for less data-hungry solutions, our team started studying MLP-Mixer. This new architecture has been proposed by the Google Brain team and it provides an alternative to well established de facto architectures like convolutions or self-attention for computer vision tasks.


The core idea behind the Mixer variants consists of replacing the convolutions or self-attention mechanisms used in Transformers with multilayer perceptrons. This change in architecture favors the performance of the model in high-data regimes (especially with respect to Transformers), while also opening some questions regarding the inductive biases hidden in the convolutions and the self-attention layers.

Those proposals perform very well on image classification tasks by splitting the image into chunks, flattening those chunks into 1D vectors, and passing them through a sequence of Mixer layers.

Inspired by the advantages of Mixer-based architectures, our team searched for parallels with the type of problems we try to solve in video classification: instead of a single image, we have a set of frames that need to be classified, along with audio and closed captioning as additional modalities.

Activity Recognition reinterpreting the MLP-Mixer

Our proposal takes the core idea of the MLP-Mixer (applying multiple multi-layer perceptrons to a sequence and to its transpose) and extends it into a multimodal framework that allows us to process video, audio, and text with the same architecture.

For each of the modalities, we use different extractors that will provide embeddings describing the content. Given the embeddings of each modality, the MLP-Mixer architecture solves the problem of deciding which of the modalities might be the most important, while also weighing how much each modality contributes to the final labeling.

For example, when it comes to detecting laughs, sometimes the key information is in audio or in the frames, and in some of the cases we have a strong signal in the closed caption.

For video, we tried two options: processing each frame separately with a ResNet34 to get a sequence of embeddings, and using a video-specific model called R3D. The two were pre-trained on ImageNet and Kinetics400, respectively.

To process the audio, we use the pretrained ResNet34, and we remove the final layers to be able to extract 2D embeddings from the audio spectrograms (for 224×224 images we end up with 7×7 embeddings).
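The 7×7 figure follows from the downsampling built into the network. As a rough check, assuming the usual ResNet34 layout (a stride-2 stem convolution, a stride-2 max pool, and three further stride-2 stages; the helper names below are illustrative), the spatial size of the final feature map can be worked out as follows:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet34_grid(size: int) -> int:
    """Side length of the last feature map for a square input (assumed layout)."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem convolution
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max pooling
    for _ in range(3):  # the last three stages each halve the resolution
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


side = resnet34_grid(224)
num_audio_tokens = side * side  # grid cells once flattened into a sequence
print(side, num_audio_tokens)
```

Flattening the grid gives one embedding per cell, which is the form needed to place the audio features in a sequence next to the other modalities.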

For closed captioning, we are using a pre-trained BERT-large, with all layers frozen, except for the Embeddings & LayerNorms.

Once we have extracted the embedding from each modality, we concatenate them into a single sequence and pass it through a set of MLP-Mixer blocks; next we use average pooling & a classification head to get predictions.
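The fusion step described above can be sketched in PyTorch. This is a minimal illustration, not the authors' implementation: it assumes every modality has already been projected to a shared embedding size, and the class names, hidden size, depth, and sequence lengths are all hypothetical.

```python
import torch
from torch import nn


class MixerBlock(nn.Module):
    """One Mixer block: a token-mixing MLP followed by a channel-mixing MLP."""

    def __init__(self, num_tokens: int, dim: int, hidden: int = 256):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(
            nn.Linear(num_tokens, hidden), nn.GELU(), nn.Linear(hidden, num_tokens)
        )
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, tokens, dim); token mixing acts on the transpose.
        mixed = self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + mixed
        return x + self.channel_mlp(self.norm2(x))


class MultimodalMixer(nn.Module):
    """Concatenate modality embeddings, mix them, pool, and classify."""

    def __init__(self, num_tokens: int, dim: int, num_classes: int, depth: int = 2):
        super().__init__()
        self.blocks = nn.Sequential(
            *[MixerBlock(num_tokens, dim) for _ in range(depth)]
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, video, audio, text):
        x = torch.cat([video, audio, text], dim=1)  # one sequence of tokens
        x = self.blocks(x)
        return self.head(x.mean(dim=1))  # average pooling over the tokens


# Hypothetical embeddings: batch of 2, shared embedding size of 128.
video = torch.randn(2, 16, 128)  # 16 frame embeddings
audio = torch.randn(2, 49, 128)  # 7x7 spectrogram grid, flattened
text = torch.randn(2, 32, 128)   # 32 caption token embeddings

model = MultimodalMixer(num_tokens=16 + 49 + 32, dim=128, num_classes=15)
logits = model(video, audio, text)
print(logits.shape)
```

Because the token-mixing MLP operates across the whole concatenated sequence, every block can exchange information between modalities, which is what lets the model weigh them against each other.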

Our experiments have been performed on a custom, manually labeled dataset for activity recognition with 15 classes, which we know from experiments are hard and cannot all be predicted accurately using a single modality.

These experiments have shown a significant increase in performance using our approach, especially in a low/mid-data regime (75K training samples).

When it comes to using only Text and Audio, our experiments showed a 15 percent improvement in accuracy over using a classifier on top of the features extracted by state-of-the-art backbones.

Using Text, Audio and Video we have seen a 17 percent improvement in accuracy over using Meta AI’s MMF Framework, which uses a VisualBERT-like model to combine modalities using more powerful state of the art backbones.

We have since extended the initial model to cover up to 55 activity classes and 45 event classes. One of the challenges we expect to address in the future is including all activities and events, even those that are less frequent.

Interpreting the MLP-Mixer mode combinations

An MLP-Mixer is a concatenation of MultiLayer Perceptrons. This can be, very roughly, approximated to a linear operation, in the sense that, once trained, the weights are fixed and the input will directly affect the output.

Once we assume that approximation, we also assume that for an input consisting of NxM numbers, we could find a NxM matrix that (when multiplied elementwise) could approximate the predictions of the MLP-Mixer for a class.

We will call this matrix a stencil, and if we have access to it, we can find what parts of the input embeddings are responsible for a specific prediction.

You can think of it as a punch card with holes in specific positions. Only information in those positions will pass and contribute to a specific prediction. So we can measure the intensity of the input at those positions.

Of course, this is an oversimplification, and there won’t exist a unique stencil that perfectly represents all of the contributions of the input to a class (otherwise that would mean that the problem could be solved linearly). So this should be used for visualization purposes only, not as an accurate predictor.

Once we have a set of stencils for each class, we can effortlessly measure input contribution without relying on any external visualization techniques.

To find a stencil, we can start from a “random noise” stencil and optimize it to maximize the activations for a specific class by just back-propagating through the MLP-Mixer.

By doing this we can end up with many valid stencils, and we can reduce them to a few by using K-means to cluster them into similar stencils and averaging each cluster.
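The stencil search can be sketched as gradient ascent on the input of a frozen model. The sketch below is only an illustration under stated assumptions: a tiny MLP stands in for the trained MLP-Mixer, and the function name, optimizer, step count, and learning rate are hypothetical. The K-means clustering and averaging of multiple stencils is not shown.

```python
import torch
from torch import nn


def find_stencil(model, shape, target_class, steps=50, lr=0.1, seed=0):
    """Optimize a random N x M input to maximize one class activation."""
    torch.manual_seed(seed)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)  # the trained weights stay fixed
    stencil = torch.randn(1, *shape, requires_grad=True)  # "random noise" start
    optimizer = torch.optim.Adam([stencil], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        activation = model(stencil)[0, target_class]
        (-activation).backward()  # minimizing the negative maximizes it
        optimizer.step()
    return stencil.detach()[0]


# Stand-in classifier over N x M embeddings (not an actual MLP-Mixer).
N, M = 4, 6
classifier = nn.Sequential(
    nn.Flatten(), nn.Linear(N * M, 32), nn.ReLU(), nn.Linear(32, 3)
)

stencil = find_stencil(classifier, (N, M), target_class=1)

# Elementwise product with an input measures how strongly that input
# "passes through the holes" of the stencil for the chosen class.
embeddings = torch.randn(N, M)
contribution = (stencil * embeddings).sum()
print(stencil.shape, contribution.item())
```

Running the search from several random seeds yields the multiple valid stencils mentioned above, which can then be clustered and averaged.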

Using the Mixer to get the best of each world

MLP-Mixer, used as an image classification model without convolutional layers, requires a lot of data, since the lack of inductive bias – one of the model’s good points overall – is a weakness when it comes to working in low data domains.

Mixers shine when used as a way to combine information previously extracted by large pretrained backbones (as opposed to being used as a full end-to-end solution). The Mixer’s strength lies in finding temporal or structural coherence between different inputs. For example, in video-related tasks we could extract embeddings from the frames using a powerful, pretrained model that understands what is going on at frame level and use the Mixer to make sense of it in a sequential manner.

This way of using the Mixer allows us to work with limited amounts of data and still get better results than what was achieved with Transformers. This is because Mixers seem to be more stable during training and seem to pay attention to all the inputs, while Transformers tend to collapse and pay attention only to some modalities/parts of the sequence.

Acknowledgements: We would like to thank the Meta AI Research and Partner Engineering teams for this collaboration.


Introducing Accelerated PyTorch Training on Mac


In collaboration with the Metal engineering team at Apple, we are excited to announce support for GPU-accelerated PyTorch training on Mac. Until now, PyTorch training on Mac only leveraged the CPU, but with the upcoming PyTorch v1.12 release, developers and researchers can take advantage of Apple silicon GPUs for significantly faster model training. This unlocks the ability to perform machine learning workflows like prototyping and fine-tuning locally, right on Mac.

Metal Acceleration

Accelerated GPU training is enabled using Apple’s Metal Performance Shaders (MPS) as a backend for PyTorch. The MPS backend extends the PyTorch framework, providing scripts and capabilities to set up and run operations on Mac. MPS optimizes compute performance with kernels that are fine-tuned for the unique characteristics of each Metal GPU family. The new device maps machine learning computational graphs and primitives onto the MPS Graph framework and the tuned kernels provided by MPS.

Training Benefits on Apple Silicon

Every Apple silicon Mac has a unified memory architecture, providing the GPU with direct access to the full memory store. This makes Mac a great platform for machine learning, enabling users to train larger networks or batch sizes locally. This reduces costs associated with cloud-based development or the need for additional local GPUs. The Unified Memory architecture also reduces data retrieval latency, improving end-to-end performance.

In the graphs below, you can see the performance speedup from accelerated GPU training and evaluation compared to the CPU baseline:

Testing conducted by Apple in April 2022 using production Mac Studio systems with Apple M1 Ultra, 20-core CPU, 64-core GPU, 128GB of RAM, and 2TB SSD. Tested with macOS Monterey 12.3, prerelease PyTorch 1.12, ResNet50 (batch size=128), HuggingFace BERT (batch size=64), and VGG16 (batch size=64). Performance tests are conducted using specific computer systems and reflect the approximate performance of Mac Studio.

Getting Started

To get started, just install the latest Preview (Nightly) build on your Apple silicon Mac running macOS 12.3 or later with a native version (arm64) of Python.
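Once such a build is installed, code can opt into the new backend by selecting the "mps" device. The following is a minimal sketch, assuming a PyTorch build that ships `torch.backends.mps`; the model and tensor sizes are placeholders, and the snippet falls back to the CPU on machines where MPS is missing or unavailable.

```python
import torch

# Older builds may not have the MPS backend module at all, so look it up safely.
mps_backend = getattr(torch.backends, "mps", None)
use_mps = mps_backend is not None and mps_backend.is_available()
device = torch.device("mps" if use_mps else "cpu")

# Placeholder model and batch; both must live on the same device.
model = torch.nn.Linear(8, 2).to(device)
batch = torch.randn(4, 8, device=device)
output = model(batch)
print(device, output.shape)
```

Keeping the device in a single variable means the same script runs unchanged on Apple silicon and on machines without a Metal GPU.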

You can also learn more about Metal and MPS on Apple’s Metal page.


Ambient Clinical Intelligence: Generating Medical Reports with PyTorch



Complete and accurate clinical documentation is an essential tool for tracking patient care. It allows for treatment plans to be shared among care teams to aid in continuity of care and ensures a transparent and effective process for reimbursement.

Physicians are responsible for documenting patient care. Traditional clinical documentation methods have resulted in a sub-par patient-provider experience, less time interacting with patients, and decreased work-life balance. A significant amount of physicians’ time is spent in front of the computer doing administrative tasks. As a result, patients are less satisfied with the overall experience, and physicians, who prepare for years studying medicine, cannot practice at the top of their license and are burned out. Every hour physicians provide direct clinical face time to patients results in nearly two additional hours spent on EHR and desk work within the clinic day. Outside office hours, physicians spend another 1 to 2 hours of personal time each night doing additional computer and other clerical work.

Physician burnout is one of the primary causes for increased medical errors, malpractice suits, turnover, and decreased access to care. Burnout leads to an increase in healthcare costs and a decrease in overall patient satisfaction. Burnout costs the United States $4.6 billion a year.

What can we do to bring back trust, joy, and humanity to the delivery of healthcare? A significant portion of the administrative work consists of entering patient data into Electronic Health Records (EHRs) and creating clinical documentation. Clinical documentation is created from information already in the EHR as well as from the patient-provider encounter conversation.

This article will showcase how the Nuance Dragon Ambient eXperience (DAX), an AI-powered, voice-enabled, ambient clinical intelligence solution, automatically documents patient encounters accurately and efficiently at the point of care and the technologies that enable it.

Nuance DAX enhances the quality of care and patient experience, increases provider efficiency and satisfaction, and improves financial outcomes. It can be used in office and telehealth settings in all ambulatory specialties, including primary and urgent care.

Natural Language Processing

Natural Language Processing (NLP) is one of the most challenging fields in Artificial Intelligence (AI). It comprises a set of algorithms that allow computers to understand or generate the language used by humans. These algorithms can process and analyze vast amounts of natural language data from different sources (either sound or text) to build models that can understand, classify, or even generate natural language as humans would. Like other fields in AI, NLP has significantly progressed thanks to the advent of Deep Learning (DL), which has resulted in models that can obtain results on par with humans in some tasks.

These advanced NLP techniques are being applied in healthcare. During a typical patient-provider encounter, a conversation ensues where the doctor constructs, through questions and answers, a chronological description of the development of the patient’s presenting illness or symptoms. A physician examines the patient and makes clinical decisions to establish a diagnosis and determine a treatment plan. This conversation, and data in the EHR, provide the required information for physicians to generate the clinical documentation, referred to as medical reports.

Two main NLP components play a role in automating the creation of clinical documentation. The first component, Automatic Speech Recognition (ASR), is used to translate speech into text. It takes the audio recording of the encounter and generates a conversation transcription (cf. Figure 2). The second component, Automatic Text Summarization, helps generate summaries from large text documents. This component is responsible for understanding and capturing the nuances and most essential aspects from the transcribed conversation into a final report in narrative form (cf. Figure 3), structured form, or a combination of both.

We will focus on this second component, Automatic Text Summarization, which is a difficult task with many challenges:

  • Its performance is tied to the ASR quality from multiple speakers (noisy input).
  • The input is conversational in nature and contains layman’s terms.
  • Protected Health Information (PHI) regulations limit medical data access.
  • The information for one output sentence is potentially spread across multiple conversation turns.
  • There is no explicit sentence alignment between input and output.
  • Various medical specialties, encounter types, and EHR systems constitute a broad and complex output space.
  • Physicians have different styles of conducting encounters and have their preferences for medical reports; there is no standard.
  • Standard summarization metrics might differ from human judgment of quality.

Figure 2: Transcript of a patient-doctor conversation

Figure 3: Excerpt of an AI-generated medical report. HPI stands for History of present illness.

Text Summarization with PyTorch and Fairseq

PyTorch is an open-source machine learning framework developed by Facebook that helps researchers prototype Deep Learning models. The Fairseq toolkit is built on top of PyTorch and focuses on sequence generation tasks, such as Neural Machine Translation (NMT) or Text Summarization. Fairseq features an active community that is continuously providing reference implementations of state-of-the-art models. It contains many built-in components (model architectures, modules, loss functions, and optimizers) and is easily extendable with plugins.

Text summarization constitutes a significant challenge in NLP. We need models capable of generating a short version of a document while retaining the key points and avoiding uninformative content. These challenges can be addressed with different approaches: (1) abstractive text summarization, which trains models to generate a summary in narrative form; (2) extractive methods, where the models are trained to select the most important parts of the input text; and (3) a combination of the two, where the essential parts of the input are selected and then summarized in an abstractive fashion. Hence, summarization can be accomplished via a single end-to-end network or as a pipeline of extractive and abstractive components. To that end, Fairseq provides all the necessary tools to be successful in our endeavor. It features end-to-end models such as the classical Transformer, different types of language models, and pre-trained versions, which enable researchers to focus on what matters most: building state-of-the-art models that generate valuable reports.

However, we are not just summarizing the transcribed conversation; we generate high-quality medical reports, which have many considerations.

  • Every section of a medical report is different in terms of content, structure, fluency, etc.
  • All medical facts mentioned in the conversation should be present in the report, for example, a particular treatment or dosage.
  • In the healthcare domain, the vocabulary is extensive, and models need to deal with medical terminology.
  • Patient-doctor conversations are usually much longer than the final report.

All these challenges require our researchers to run a battery of extensive experiments. Thanks to the flexibility of PyTorch and Fairseq, their productivity has greatly increased. Further, the ecosystem offers an easy path from ideation through implementation and experimentation to final roll-out in production. Using multiple GPUs or CPUs is as simple as providing an additional argument to the tools, and because of the tight Python integration, PyTorch code can be easily debugged.

In our continuous effort to contribute to the open-source community, features have been developed at Nuance and pushed to the Fairseq GitHub repository. These address some of the challenges mentioned above: facilitating the copying of words, especially rare or unseen ones, from the input to the summary; speeding up training by improving Tensor Core utilization; and ensuring TorchScript compatibility of different Transformer configurations. In the following, we show an example of how to train a Transformer model with a Pointer Generator mechanism (Transformer-PG), which can copy words from the input.

How to build a Transformer model with a Pointer Generator mechanism

In this step-by-step guide, it is assumed the user has already installed PyTorch and Fairseq.

1. Create a vocabulary and extend it with source position markers:

These markers will allow the model to point to any word in the input sequence.

export LC_ALL=C
cat train.src train.tgt |
  tr -s '[:space:]' '\n' |
  sort |
  uniq -c |
  sort -k1,1bnr -k2 |
  head -n "$((vocab_size - 4))" |
  awk '{ print $2 " " $1 }' >
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>

This will create a vocabulary file that contains the most frequent words (at most vocab_size - 4 of them), followed by the position markers (512 in this example), named from “<unk-0>” to “<unk-511>”.
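For readers who prefer Python to a shell pipeline, the same vocabulary construction can be sketched with the standard library. This is an illustrative equivalent, not part of Fairseq: the function name is hypothetical, ties between equally frequent words are broken alphabetically as an assumption, and the four reserved slots are assumed to correspond to the special symbols a Fairseq dictionary adds on its own.

```python
from collections import Counter


def build_vocab(lines, vocab_size, position_markers):
    """Return "word count" entries followed by "<unk-N> 0" position markers."""
    counts = Counter(word for line in lines for word in line.split())
    # Most frequent words first; alphabetical order among ties (assumption).
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    keep = max(vocab_size - 4, 0)  # leave room for four reserved symbols
    entries = [f"{word} {count}" for word, count in ranked[:keep]]
    entries += [f"<unk-{n}> 0" for n in range(position_markers)]
    return entries


# Tiny made-up corpus standing in for train.src and train.tgt.
corpus = ["the patient has pain", "the pain is severe"]
vocab = build_vocab(corpus, vocab_size=6, position_markers=3)
print("\n".join(vocab))
```

The markers are given a count of 0 because they never occur in the raw text; they only need to exist in the vocabulary so the model can predict them.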

In case we have an input like

src = "Hello, I'm The Dogtor"

it could happen that our model has been trained without the word “Dogtor” in its vocabulary. Therefore, when we feed this sequence into the model, it should be converted to:

src = "Hello, I'm The <unk-3>"

Now, “<unk-3>” is part of our vocabulary and could be predicted by the model (this is where the pointer-generator comes in). In such a case, we will only need to post-process the output to replace “<unk-3>” by the word at input position 3.
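The source-side replacement can be illustrated in a few lines of Python. This is a simplified sketch rather than the Fairseq preprocessing script: it assumes whitespace tokenization, the function name is hypothetical, and it only handles the input sequence.

```python
def mark_unknowns(tokens, vocabulary):
    """Replace each out-of-vocabulary token with a marker for its position."""
    return [
        token if token in vocabulary else f"<unk-{position}>"
        for position, token in enumerate(tokens)
    ]


# Hypothetical vocabulary that lacks the word "Dogtor".
vocabulary = {"Hello,", "I'm", "The"}
src = "Hello, I'm The Dogtor"
marked = mark_unknowns(src.split(), vocabulary)
print(" ".join(marked))
```

Positions are counted from zero, so the fourth word of the input receives the marker with index 3.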

2. Preprocess the text data to replace unknown words with their positional markers:

We can use the scripts from

# Considering we have our data in:
# train_src = /path/to/train.src
# train_tgt = /path/to/train.tgt
# valid_src = /path/to/valid.src
# valid_tgt = /path/to/valid.tgt
./ --source /path/to/train.src \
                --target /path/to/train.tgt \
                --vocab <(cut -d' ' -f1 \
                --source-out /path/to/ \
                --target-out /path/to/

./ --source /path/to/valid.src \
                --target /path/to/valid.tgt \
                --vocab <(cut -d' ' -f1 \
                --source-out /path/to/ \
                --target-out /path/to/

./ --source /path/to/test.src \
                --vocab <(cut -d' ' -f1 \
                --source-out /path/to/

3. Now let’s binarize the data, so that it can be processed faster:

fairseq-preprocess --task "translation" \
                   --source-lang "pg.src" \
                   --target-lang "pg.tgt" \
                   --trainpref /path/to/train \
                   --validpref /path/to/valid \
                   --destdir <data_dir>

You might notice the type of task is “translation”. This is because there is no “summarization” task available; we could understand it as a kind of NMT task where the input and output languages are shared and the output (summary) is shorter than the input.

4. Now we can train the model:

# The last three options configure the Pointer Generator.
fairseq-train <data_dir> \
              --save-dir <model_dir> \
              --task "translation" \
              --source-lang "src" \
              --target-lang "tgt" \
              --arch "transformer_pointer_generator" \
              --max-source-positions 512 \
              --max-target-positions 128 \
              --max-tokens 2048 \
              --required-batch-size-multiple 1 \
              --required-seq-len-multiple 8 \
              --dropout 0.1 \
              --criterion "cross_entropy" \
              --optimizer adam \
              --adam-betas '(0.9, 0.98)' \
              --adam-eps 1e-9 \
              --update-freq 4 \
              --lr 0.004 \
              --alignment-layer -1 \
              --alignment-heads 1 \
              --source-position-markers 512

This configuration makes use of features Nuance has contributed back to Fairseq:

  • Transformer with a Pointer Generator mechanism to facilitate copying of words from the input.
  • Sequence length padded to a multiple of 8 to better use tensor cores and reduce training time.
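The effect of padding sequence lengths to a multiple of 8 can be shown with a small helper. This is only a sketch of the idea, not Fairseq's batching code: the function name is hypothetical, and the padding index of 1 is an assumption made for the example.

```python
def pad_to_multiple(token_ids, multiple=8, pad_id=1):
    """Right-pad a sequence so that its length is a multiple of `multiple`."""
    remainder = len(token_ids) % multiple
    if remainder == 0:
        return list(token_ids)  # already aligned, nothing to add
    return list(token_ids) + [pad_id] * (multiple - remainder)


sequence = [17, 42, 8, 99, 5, 23, 61, 7, 12, 30, 4, 88, 19]  # 13 made-up token ids
padded = pad_to_multiple(sequence)
print(len(sequence), "->", len(padded))
```

A few extra padding tokens per batch are a small price when the resulting tensor shapes line up with the sizes that Tensor Cores process most efficiently.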

5. Now let’s take a look at how to generate a summary with our new medical report generation system:

import torch
from examples.pointer_generator.pointer_generator_src.transformer_pg import TransformerPointerGeneratorModel

# Patient-Doctor conversation
input = ("[doctor] Lisa Simpson, thirty six year old female, presents to the clinic today because "
         "she has severe right wrist pain")

# Load the model
model = TransformerPointerGeneratorModel.from_pretrained(data_name_or_path=<data_dir>,

result = model.translate([input], beam=2)

Ms. <unk-2> is a 36-year-old female who presents to the clinic today for evaluation of her right wrist.

6. Alternatively, we can use fairseq-interactive and a postprocessing tool to replace the positional unknown tokens with the corresponding words from the input:

fairseq-interactive <data_dir> \
              --batch-size <batch_size> \
              --task translation \
              --source-lang src \
              --target-lang tgt \
              --path <model_dir>/ \
              --input /path/to/ \
              --buffer-size 20 \
              --max-len-a 0 \
              --max-len-b 128 \
              --beam 2 \
              --skip-invalid-size-inputs-valid-test | tee generate.out

grep "^H-" generate.out | cut -f 3- > generate.hyp

	--source <(awk 'NF<512' /path/to/ \
	--target generate.hyp \
	--target-out generate.hyp.processed

Now we have the final set of reports in “generate.hyp.processed”, with “<unk-N>” replaced by the original word from the input sequence.
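The substitution performed in this last step can be sketched in Python. This is an illustration of the idea rather than the Fairseq postprocessing tool: the function name is hypothetical, and the source is assumed to be whitespace-tokenized with punctuation already split off.

```python
import re

MARKER = re.compile(r"<unk-(\d+)>")


def restore_unknowns(hypothesis_tokens, source_tokens):
    """Replace each "<unk-N>" marker with the source token at position N."""
    restored = []
    for token in hypothesis_tokens:
        match = MARKER.fullmatch(token)
        if match and int(match.group(1)) < len(source_tokens):
            restored.append(source_tokens[int(match.group(1))])
        else:
            restored.append(token)  # ordinary word or out-of-range marker
    return restored


# Hypothetical tokenized input and generated hypothesis.
source = "[doctor] Lisa Simpson , thirty six year old female".split()
hypothesis = "Ms. <unk-2> is a 36-year-old female".split()
report = " ".join(restore_unknowns(hypothesis, source))
print(report)
```

Markers that point past the end of the source are left untouched here, so that unexpected model output stays visible instead of raising an error.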

Model Deployment

PyTorch offers great flexibility in modeling and a rich surrounding ecosystem. However, while several recent articles have suggested that the use of PyTorch in research and academia may be close to surpassing TensorFlow, there seems to be an overall sense of TensorFlow being the preferred platform for deployment to production. Is this still the case in 2021? Teams looking to serve their PyTorch models in production have a few options.

Before describing our journey, let’s take a brief detour and define the term model.

Models as computation graphs

A few years back, it was still common for machine learning toolkits to support only particular classes of models of a rather fixed and rigid structure, with only a few degrees of freedom (like the kernel of a support vector machine or the number of hidden layers of a neural network). Inspired by foundational work in Theano, toolkits like Microsoft’s CNTK or Google’s TensorFlow were among the first to popularize a more flexible view on models, as computation graphs with associated parameters that can be estimated from data. This view blurred the boundaries between popular types of models (such as DNNs or SVMs), as it became easy to blend the characteristics of each into your type of graph. Still, such a graph had to be defined upfront before estimating its parameters, and it was pretty static. This made it easy to save models to a self-contained bundle, like a TensorFlow SavedModel (such a bundle simply contains the structure of the graph, as well as the concrete values of the estimated parameters). However, debugging such models can be difficult because the statements in the Python code that build the graph are logically separate from the lines that execute it. Researchers also long for easier ways of expressing dynamic behavior, such as the computation steps of the forward pass of a model being conditionally dependent on its input data (or its previous output).

Most recently, the above limitations have led to a second revolution spearheaded by PyTorch and TensorFlow 2. The computation graph is no longer defined explicitly. Instead, it will be populated implicitly as the Python code executes operations on tensor arguments. An essential technique that powers this development is automatic differentiation. As the computation graph is being built implicitly while executing the steps of the forward pass, all the necessary data will be tracked for later computation of the gradient with respect to the model parameters. This allows for great flexibility in training a model, but it raises an important question. If the computation happening inside a model is only implicitly defined through our Python code’s steps as it executes on concrete data, what is it that we want to save as a model? The answer, at least initially, was the Python code with all its dependencies, along with the estimated parameters. This is undesirable for practical reasons. For instance, there is a danger that the team working on model deployment does not exactly reproduce the Python code dependencies used during training, leading to subtly divergent behavior. The solution typically consists of combining two techniques, scripting and tracing, that is, extra annotations in your Python code and execution of your code on example input data, allowing PyTorch to define and save the graph that should be executed during later inference on new, unseen data. This requires some discipline by whoever creates the model code (arguably voiding some of the original flexibility of eager execution), but it results in a self-contained model bundle in TorchScript format. The solution in TensorFlow 2 is remarkably similar.
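The difference between scripting and tracing can be seen on two toy modules. This is a minimal sketch with made-up modules, not the report generation model: scripting compiles the annotated Python code and keeps data-dependent control flow, while tracing records the operations executed on one example input.

```python
import io

import torch
from torch import nn


class Gate(nn.Module):
    """Toy module whose forward pass branches on its input data."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


class Projection(nn.Module):
    """Toy module with a fixed sequence of operations."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Scripting preserves both branches of the if statement.
scripted = torch.jit.script(Gate())

# Tracing runs the module once on example input and records what happened.
traced = torch.jit.trace(Projection().eval(), torch.randn(1, 4))

# The result is a self-contained bundle that can be saved and reloaded
# without the Python class definitions (an in-memory buffer is used here).
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)

print(restored(torch.ones(2)), restored(-torch.ones(2)))
```

Had the gate been traced instead of scripted, only the branch taken by the example input would have been recorded, which is why the two techniques are usually combined.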

Serving our report generation models

Our journey in deploying the report generation models reflects the above discussion. We started out serving our models by deploying the model code and its dependencies along with the parameter checkpoints in a custom Docker image exposing a gRPC service interface. However, we soon noticed that it became error-prone to replicate the exact code and environment used by the modeling team while estimating the parameters. Moreover, this approach prevented us from leveraging high-performance model serving frameworks like NVIDIA’s Triton, which is written in C++ and requires self-contained models that can be used without a Python interpreter. At this stage, we were facing a choice between attempting to export our PyTorch models to ONNX or TorchScript format. ONNX is an open specification for representing machine learning models that increasingly finds adoption. It is powered by a high-performance runtime developed by Microsoft (ONNX Runtime). Working closely with the ONNX team at Microsoft, we discovered that some operations that our models require were not yet supported in ONNX. Consequently, we turned our attention to TorchScript, the mechanism more native to PyTorch. Through a combination of tracing and scripting, annotating our code where needed, we succeeded and obtained self-contained TorchScript models that Triton could serve. This improved our deployment path considerably. We no longer had to worry about the code dependencies and now had the option of using Triton for high-performance model serving on NVIDIA GPUs.

A maturing ecosystem

Is it all roses? No, it has been a rockier journey than we expected. We encountered what seems to be a memory leak in the MKL libraries used by PyTorch while serving the PyTorch code directly. We encountered deadlocks in trying to load multiple models from multiple threads. We had difficulties exporting our models to ONNX and TorchScript formats. Models would not work out-of-the-box on hardware with multiple GPUs; they always accessed the particular GPU device on which they were exported. We encountered excessive memory usage in the Triton inference server while serving TorchScript models, which we found out was due to automatic differentiation accidentally being enabled during the forward pass. However, the ecosystem keeps improving, and there is a helpful and vibrant open-source community eager to work with us to mitigate such issues. Finally, for those of us who require enterprise-level support, Microsoft now offers Premier Support for use of PyTorch on Azure.
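Two of the issues above have commonly used mitigations in PyTorch, sketched below on a placeholder model. This is not necessarily how the authors resolved them: disabling gradient tracking during inference avoids building the autograd graph, and passing `map_location` when loading a TorchScript model remaps it to a device chosen at load time.

```python
import io

import torch
from torch import nn

model = nn.Linear(4, 2).eval()  # placeholder for a real model
inputs = torch.randn(3, 4)

# Without this context manager, the forward pass records autograd state.
with torch.no_grad():
    outputs = model(inputs)

# Save a scripted copy, then choose the target device when loading it.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(model), buffer)
buffer.seek(0)
target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loaded = torch.jit.load(buffer, map_location=target)

print(outputs.requires_grad, next(loaded.parameters()).device)
```

In a serving setup, the target device would typically come from the deployment configuration rather than being hard-coded.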

Where to go from here? For those who require the flexibility of serving PyTorch code directly, without going through the extra step of exporting self-contained models, it is worth pointing out that the TorchServe project now provides a way of bundling the code together with parameter checkpoints into a single servable archive, greatly reducing the risk of code and parameters drifting apart. To us, however, exporting models to TorchScript has proven beneficial. It provides a clear interface between modeling and deployment teams, and TorchScript further reduces the latency when serving models on GPU via its just-in-time compilation engine.

Scaling at large and the future

Finally, efficient deployment to the cloud is about more than just computing the response of a single model instance efficiently. Flexibility is needed in managing, versioning and updating models. High-level scalability must be achieved via techniques such as load-balancing, horizontal scaling and vertical scaling. If many models are involved, scale-to-zero quickly becomes a topic as it is unacceptable to pay for serving models that do not answer any requests. Providing such extra functionality on top of a low-level inference server like Triton is the job of an orchestration framework. To that end, after gaining some initial experience with KubeFlow, we decided to turn our attention to Azure ML, which provides similar functionality but integrates more deeply with the Azure platform, on which we already rely for large parts of our technology stack. This part of our journey has just begun.


Academia has long recognized that we are “standing on the shoulders of giants.” As Artificial Intelligence is maturing from a scientific discipline into technology, the same spirit of collaboration that originally fueled its scientific foundation has carried over into the world of software engineering. Open-source enthusiasts join technology companies worldwide to build open software ecosystems that allow for new angles at solving some of the most pressing challenges of modern society. In this article, we’ve taken a look at Nuance’s Dragon Ambient eXperience, an AI-powered, voice-enabled solution that automatically documents patient care, reducing healthcare providers’ administrative burdens. Nuance DAX improves the patient-provider experience, reduces physician burnout, and improves financial outcomes. It brings back trust, joy, and humanity to the delivery of healthcare. Fairseq and PyTorch have proven to be an incredible platform for powering this AI technology, and in turn, Nuance has contributed back some of its innovations in this space. For further reading, we invite you to take a look at our recent ACL publication and the Nuance “What’s Next” blog.


Running PyTorch Models on Jetson Nano



Nvidia Jetson Nano, part of the Jetson family of products or Jetson modules, is a small yet powerful Linux (Ubuntu) based embedded computer with a built-in GPU and 2GB or 4GB of memory. With it, you can run many PyTorch models efficiently. This document summarizes our experience of running different deep learning models using 3 different mechanisms on Jetson Nano:

  1. Jetson Inference, the higher-level Nvidia API that has built-in support for running most common computer vision models, which can be transfer-learned with PyTorch on the Jetson platform.

  2. TensorRT, a high-performance inference framework from Nvidia that requires the conversion of a PyTorch model to ONNX, and then to the TensorRT engine file that the TensorRT runtime can run.

  3. PyTorch, with the direct PyTorch API torch.nn for inference.

Setting up Jetson Nano

After purchasing a Jetson Nano here, simply follow the clear step-by-step instructions to download and write the Jetson Nano Developer Kit SD Card Image to a microSD card, and complete the setup. After the setup is done and the Nano is booted, you’ll see the standard Linux prompt along with the username and the Nano name used in the setup.

To check the GPU status on Nano, run the following commands:

sudo pip3 install jetson-stats
sudo jtop

You’ll see information, including:

You can also see the installed CUDA version:

$ ls -lt /usr/local
lrwxrwxrwx  1 root root   22 Aug  2 01:47 cuda -> /etc/alternatives/cuda
lrwxrwxrwx  1 root root   25 Aug  2 01:47 cuda-10 -> /etc/alternatives/cuda-10
drwxr-xr-x 12 root root 4096 Aug  2 01:47 cuda-10.2

To use a camera on Jetson Nano, for example, Arducam 8MP IMX219, follow the instructions here or run the commands below after installing a camera module:

cd ~
chmod +x
./ -m arducam

Another way to do this is to use the original Jetson Nano camera driver:

sudo dpkg -r arducam-nvidia-l4t-kernel
sudo shutdown -r now

Then, use ls /dev/video0 to confirm the camera is found:

$ ls /dev/video0

And finally, the following command to see the camera in action:

nvgstcapture-1.0 --orientation=2

Using Jetson Inference

Nvidia Jetson Inference API offers the easiest way to run image recognition, object detection, semantic segmentation, and pose estimation models on Jetson Nano. Jetson Inference has TensorRT built-in, so it’s very fast.

To test run Jetson Inference, first clone the repo and download the models:

git clone --recursive
cd jetson-inference

Then use the pre-built Docker Container that already has PyTorch installed to test run the models:

docker/ --volume ~/jetson_inference:/jetson_inference

To run image recognition, object detection, semantic segmentation, and pose estimation models on test images, use the following:

cd build/aarch64/bin
./ images/jellyfish.jpg /jetson_inference/jellyfish.jpg
./ images/dog.jpg /jetson_inference/dog.jpeg
./ images/peds_0.jpg /jetson_inference/peds_0.jpg
./ images/humans_0.jpg /jetson_inference/pose_humans_0.jpg

Four result images from running the four different models will be generated. Exit the docker image to see them:

$ ls -lt ~/jetson_inference/
-rw-r--r-- 1 root root  68834 Oct 15 21:30 pose_humans_0.jpg
-rw-r--r-- 1 root root 914058 Oct 15 21:30 peds_0.jpg
-rw-r--r-- 1 root root 666239 Oct 15 21:30 dog.jpeg
-rw-r--r-- 1 root root 179760 Oct 15 21:29 jellyfish.jpg

Figure. Jetson Inference result images (examples 1 to 4).

You can also use the docker image to run PyTorch models because the image has PyTorch, torchvision and torchaudio installed:

# pip list|grep torch
torch (1.9.0)
torchaudio (0.9.0a0+33b2469)
torchvision (0.10.0a0+300a8a4)

Although Jetson Inference includes models already converted to the TensorRT engine file format, you can fine-tune the models by following the steps in Transfer Learning with PyTorch (for Jetson Inference) here.

Using TensorRT

TensorRT is a high-performance inference framework from Nvidia. Jetson Nano supports TensorRT via the Jetpack SDK, included in the SD Card image used to set up Jetson Nano. To confirm that TensorRT is already installed in Nano, run dpkg -l|grep -i tensorrt:

Theoretically, TensorRT can be used to “take a trained PyTorch model and optimize it to run more efficiently during inference on an NVIDIA GPU.” Follow the instructions and code in the notebook to see how to use PyTorch with TensorRT through ONNX on a torchvision Resnet50 model:

  1. How to convert the model from PyTorch to ONNX;

  2. How to convert the ONNX model to a TensorRT engine file;

  3. How to run the engine file with the TensorRT runtime for performance improvement: inference time improved from the original 31.5ms/19.4ms (FP32/FP16 precision) to 6.28ms (TensorRT).
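
To put the latencies in the last step into perspective, the short sketch below turns them into speedup factors. It uses only the three numbers quoted above; the helper name `speedup` is ours and not part of any library.

```python
# Latencies reported above, in milliseconds per inference.
pytorch_fp32_ms = 31.5
pytorch_fp16_ms = 19.4
tensorrt_ms = 6.28


def speedup(baseline_ms: float, optimized_ms: float) -> float:
    """How many times faster the optimized run is than the baseline."""
    return baseline_ms / optimized_ms


print(f"TensorRT vs. PyTorch FP32: {speedup(pytorch_fp32_ms, tensorrt_ms):.1f}x")
print(f"TensorRT vs. PyTorch FP16: {speedup(pytorch_fp16_ms, tensorrt_ms):.1f}x")
```

In other words, the TensorRT engine is roughly five times faster than the FP32 PyTorch run and roughly three times faster than the FP16 run.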

You can replace the Resnet50 model in the notebook code with another PyTorch model, go through the conversion process above, and run the resulting TensorRT engine file with the TensorRT runtime to see the optimized performance. Be aware, however, that due to the Nano GPU memory size, models larger than 100MB are likely to fail to run, with the following error information:

Error Code 1: Cuda Runtime (all CUDA-capable devices are busy or unavailable)
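
One way to catch this early is to check the size of the exported model file before attempting the conversion. The sketch below is only a rough guard: the 100MB threshold is the approximate figure mentioned above rather than a documented TensorRT limit, and `likely_fits_on_nano` is a hypothetical helper name.

```python
import os

# Assumption: 100 MB is the rough limit quoted above, not a hard constant.
SIZE_LIMIT_BYTES = 100 * 1024 * 1024


def likely_fits_on_nano(model_path: str, limit_bytes: int = SIZE_LIMIT_BYTES) -> bool:
    """Return True if the file at model_path is no larger than limit_bytes."""
    return os.path.getsize(model_path) <= limit_bytes
```

For example, `likely_fits_on_nano("resnet50_pytorch.onnx")` could be called right after the ONNX export to decide whether the TensorRT conversion is worth trying.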

You may also see an error when converting a PyTorch model to an ONNX model, which may be fixed by replacing:

torch.onnx.export(resnet50, dummy_input, "resnet50_pytorch.onnx", verbose=False)

with:

torch.onnx.export(model, dummy_input, "deeplabv3_pytorch.onnx", opset_version=11, verbose=False)

Using PyTorch

First, to download and install PyTorch 1.9 on Nano, run the following commands (see here for more information):

wget -O torch-1.8.0-cp36-cp36m-linux_aarch64.whl -O torch-1.9.0-cp36-cp36m-linux_aarch64.whl
sudo apt-get install python3-pip libopenblas-base libopenmpi-dev 
pip3 install Cython
pip3 install numpy torch-1.9.0-cp36-cp36m-linux_aarch64.whl

To download and install torchvision 0.10 on Nano, run the commands below:

pip3 install torchvision-0.10.0a0+300a8a4-cp36-cp36m-linux_aarch64.whl

After the steps above, run this to confirm:

$ pip3 list|grep torch
torch (1.9.0)
torchvision (0.10.0)

You can also use the docker image described in the section Using Jetson Inference (which also has PyTorch and torchvision installed), to skip the manual steps above.

The official YOLOv5 repo is used to run the PyTorch YOLOv5 model on Jetson Nano. After logging in to Jetson Nano, follow the steps below:

  • Get the repo and install what’s required:
git clone
cd yolov5
pip install -r requirements.txt
  • Run python3, which by default uses the PyTorch model. You should see something like:
detect:, source=data/images, imgsz=[640, 640], conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=runs/detect, name=exp, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False
YOLOv5 🚀 v5.0-499-g48b00db torch 1.9.0 CUDA:0 (NVIDIA Tegra X1, 3956.1015625MB)

Fusing layers... 
Model Summary: 224 layers, 7266973 parameters, 0 gradients
image 1/5 /home/jeff/repos/yolov5-new/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 bus, 1 fire hydrant, Done. (0.142s)

The inference time on Jetson Nano GPU is about 140ms, more than twice as fast as the inference time on iOS or Android (about 330ms).
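
If you want to collect these timings automatically, a small parser over the log output is enough. The sketch below assumes the log format shown above, where each image line ends with `(<seconds>s)`; the function name `inference_seconds` is ours, and the 330ms mobile figure is the approximate number quoted in the text.

```python
import re

LOG_LINE = (
    "image 1/5 /home/jeff/repos/yolov5-new/yolov5/data/images/bus.jpg: "
    "640x480 4 persons, 1 bus, 1 fire hydrant, Done. (0.142s)"
)


def inference_seconds(line: str) -> float:
    """Extract the trailing '(<seconds>s)' timing from one log line."""
    match = re.search(r"\((\d+(?:\.\d+)?)s\)\s*$", line)
    if match is None:
        raise ValueError("no timing found in line")
    return float(match.group(1))


nano_ms = inference_seconds(LOG_LINE) * 1000
mobile_ms = 330  # approximate iOS/Android figure quoted in the text
print(f"Nano: {nano_ms:.0f} ms, about {mobile_ms / nano_ms:.1f}x faster than mobile")
```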

If you get an error “ImportError: The _imagingft C module is not installed.” then you need to reinstall pillow:

sudo apt-get install libpng-dev
sudo apt-get install libfreetype6-dev
pip3 uninstall pillow
pip3 install --no-cache-dir pillow

After successfully completing the python3 run, the object detection results of the test images located in data/images will be in the runs/detect/exp directory. To test the detection with a live webcam instead of local images, use the --source 0 parameter when running python3

~/repos/yolov5$ ls -lt runs/detect/exp10
total 1456
-rw-rw-r-- 1 jeff jeff 254895 Oct 15 16:12 zidane.jpg
-rw-rw-r-- 1 jeff jeff 202674 Oct 15 16:12 test3.png
-rw-rw-r-- 1 jeff jeff 217117 Oct 15 16:12 test2.jpg
-rw-rw-r-- 1 jeff jeff 305826 Oct 15 16:12 test1.png
-rw-rw-r-- 1 jeff jeff 495760 Oct 15 16:12 bus.jpg

Using the same test files used in the PyTorch iOS YOLOv5 demo app or Android YOLOv5 demo app, you can compare the results generated with running the YOLOv5 PyTorch model on mobile devices and Jetson Nano:

PyTorch YOLOv5 on Jetson Nano, example with a dog
PyTorch YOLOv5 on Jetson Nano, example with a horse and a rider

Figure 1. PyTorch YOLOv5 on Jetson Nano.

PyTorch YOLOv5 on iOS, example with a dog
PyTorch YOLOv5 on iOS, example with a horse and a rider

Figure 2. PyTorch YOLOv5 on iOS.

PyTorch YOLOv5 on Android, example with a dog
PyTorch YOLOv5 on Android, example with a horse and a rider

Figure 3. PyTorch YOLOv5 on Android.


Based on our experience of running different PyTorch models for potential demo apps on Jetson Nano, we see that even Jetson Nano, a lower-end member of the Jetson family of products, provides a powerful GPU and embedded system that can directly and efficiently run some of the latest PyTorch models, whether pre-trained or transfer-learned.

Building PyTorch demo apps on Jetson Nano can be similar to building PyTorch apps on Linux, but you can also choose to use TensorRT after converting the PyTorch models to the TensorRT engine file format.

If you just need to run some common computer vision models on Jetson Nano, Nvidia’s Jetson Inference, which supports image recognition, object detection, semantic segmentation, and pose estimation models, is the easiest way.




Introducing PyTorch Fully Sharded Data Parallel (FSDP) API


Recent studies have shown that training larger models tends to improve model quality. Over the last 3 years, model size has grown roughly 10,000 times, from BERT with 110M parameters to Megatron-2 with one trillion. However, training large AI models is not easy: aside from the need for large amounts of computing resources, the software engineering complexity is also challenging. PyTorch has been working on building tools and infrastructure to make it easier.

PyTorch Distributed data parallelism is a staple of scalable deep learning because of its robustness and simplicity. However, it requires the model to fit on one GPU. Recent approaches like DeepSpeed ZeRO and FairScale’s Fully Sharded Data Parallel allow us to break this barrier by sharding a model’s parameters, gradients and optimizer states across data parallel workers while still maintaining the simplicity of data parallelism.

With PyTorch 1.11 we’re adding native support for Fully Sharded Data Parallel (FSDP), currently available as a prototype feature. Its implementation heavily borrows from FairScale’s version while bringing more streamlined APIs and additional performance improvements.

Scaling tests of PyTorch FSDP on AWS show it can scale up to train dense models with 1T parameters. Realized performance in our experiments reached 84 TFLOPS per A100 GPU for the GPT 1T model and 159 TFLOPS per A100 GPU for the GPT 175B model on an AWS cluster. The native FSDP implementation also dramatically improved model initialization time compared to FairScale’s original when CPU offloading was enabled.

In future PyTorch versions, we’re going to enable users to seamlessly switch between DDP, ZeRO-1, ZeRO-2 and FSDP flavors of data parallelism, so that users can train different scales of models with simple configurations in the unified API.

How FSDP Works

FSDP is a type of data-parallel training, but unlike traditional data-parallel, which maintains a per-GPU copy of a model’s parameters, gradients and optimizer states, it shards all of these states across data-parallel workers and can optionally offload the sharded model parameters to CPUs.

The figure below shows how FSDP works for 2 data-parallel processes:

Figure 1. FSDP workflow

Usually, model layers are wrapped with FSDP in a nested way, so that only layers in a single FSDP instance need to gather the full parameters to a single device during forward or backward computations. The gathered full parameters will be freed immediately after computation, and the freed memory can be used for the next layer’s computation. In this way, peak GPU memory could be saved and thus training can be scaled to use a larger model size or larger batch size. To further maximize memory efficiency, FSDP can offload the parameters, gradients and optimizer states to CPUs when the instance is not active in the computation.
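
As a back-of-the-envelope illustration of why sharding helps, the sketch below compares the per-GPU memory needed for model states when they are replicated on every worker versus sharded across workers. The 16 bytes per parameter (fp32 parameters, fp32 gradients and two fp32 Adam moments) is our assumption for illustration, not a figure from this post, and the estimate ignores activations and the temporarily gathered full parameters. `per_gpu_state_gib` is a hypothetical helper.

```python
def per_gpu_state_gib(
    num_params: float,
    world_size: int,
    sharded: bool,
    bytes_per_param: int = 16,
) -> float:
    """Approximate per-GPU memory (GiB) for parameters, gradients and optimizer states.

    bytes_per_param=16 assumes fp32 parameters (4), fp32 gradients (4) and
    two fp32 Adam moments (8). This is an illustrative assumption.
    """
    total_bytes = num_params * bytes_per_param
    if sharded:
        total_bytes /= world_size  # each worker keeps only its own shard
    return total_bytes / 2**30


replicated = per_gpu_state_gib(1e9, world_size=8, sharded=False)
fully_sharded = per_gpu_state_gib(1e9, world_size=8, sharded=True)
print(f"replicated: {replicated:.2f} GiB, sharded: {fully_sharded:.2f} GiB per GPU")
```

Under these assumptions, a 1B-parameter model needs about 14.9 GiB of state per GPU when replicated, and about 1.9 GiB per GPU when sharded across 8 workers.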

Using FSDP in PyTorch

There are two ways to wrap a model with PyTorch FSDP. Auto wrapping is a drop-in replacement for DDP; manual wrapping needs minimal changes of model definition code with the ability to explore complex sharding strategies.

Auto Wrapping

Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code.

The fsdp_auto_wrap_policy argument allows specifying a callable function to recursively wrap layers with FSDP. The default_auto_wrap_policy function provided by PyTorch FSDP recursively wraps layers whose number of parameters is larger than 100M. You can supply your own wrapping policy as needed. An example of writing a customized wrapping policy is shown in the FSDP API doc.
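
To make the idea of size-based recursive wrapping more concrete, here is a simplified, framework-free sketch. It is not the FSDP implementation and does not use the FSDP API: `Node` and `plan_wrapping` are hypothetical names, and the toy module tree and its parameter counts are made up. The rule it follows is: visit children first, and wrap a module once the parameters under it that are not yet wrapped exceed the threshold.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """A toy stand-in for a module: its own parameter count plus child modules."""
    name: str
    own_params: int
    children: List["Node"] = field(default_factory=list)


def plan_wrapping(node: Node, min_num_params: int, wrapped: List[str]) -> int:
    """Record which modules would be wrapped; return the parameters left unwrapped."""
    remaining = node.own_params
    for child in node.children:
        remaining += plan_wrapping(child, min_num_params, wrapped)
    if remaining > min_num_params:
        wrapped.append(node.name)  # this subtree becomes one wrapped unit
        return 0
    return remaining  # too small on its own; let an ancestor absorb it


toy_model = Node("root", 0, [
    Node("embed", 150_000_000),
    Node("block1", 60_000_000),
    Node("block2", 60_000_000),
    Node("head", 10_000_000),
])
wrapped_units: List[str] = []
plan_wrapping(toy_model, 100_000_000, wrapped_units)
print(wrapped_units)
```

With a 100M threshold, the large embedding is wrapped on its own, while the smaller blocks and the head (130M parameters together) are grouped into the root unit.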

In addition, cpu_offload could be configured optionally to offload wrapped parameters to CPUs when these parameters are not used in computation. This can further improve memory efficiency at the cost of data transfer overhead between host and device.

The example below shows how FSDP is wrapped using auto wrapping.

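from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.wrap import default_auto_wrap_policy
import torch.nn as nn

class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = nn.Linear(8, 4)
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = nn.Linear(16, 4)

# With DDP, the wrapping line would be: ddp_model = DistributedDataParallel(model())
# With FSDP (requires an initialized process group), wrap the model instead:
fsdp_model = FullyShardedDataParallel(
   model(),
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)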

Manual Wrapping

Manual wrapping can be useful to explore complex sharding strategies by applying wrap selectively to some parts of the model. Overall settings can be passed to the enable_wrap() context manager.

from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.wrap import enable_wrap, wrap
import torch.nn as nn

class model(nn.Module):
   def __init__(self):
       super().__init__()
       # wrap() only shards a layer when called inside enable_wrap()
       self.layer1 = wrap(nn.Linear(8, 4))
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = wrap(nn.Linear(16, 4))

# Settings shared by every wrap() call inside the context manager
wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
   fsdp_model = wrap(model())

After wrapping the model with FSDP using one of the two above approaches, the model can be trained in a similar way as local training, like this:

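# next_batch() and criterion stand in for your data loader and loss function
optim = torch.optim.Adam(fsdp_model.parameters(), lr=0.0001)
for sample, label in next_batch():
  optim.zero_grad()
  out = fsdp_model(sample)
  loss = criterion(out, label)
  loss.backward()
  optim.step()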
Benchmark Results

We ran extensive scaling tests for 175B and 1T GPT models on AWS clusters using PyTorch FSDP. Each cluster node is an instance with 8 NVIDIA A100-SXM4-40GB GPUs, and inter-nodes are connected via AWS Elastic Fabric Adapter (EFA) with 400 Gbps network bandwidth.

GPT models are implemented using minGPT. A randomly generated input dataset is used for benchmarking purposes. All experiments ran with 50K vocabulary size, fp16 precision and SGD optimizer.

Model    | Number of layers | Hidden size | Attention heads | Model size (billions of parameters)
GPT 175B | 96               | 12288       | 96              | 175
GPT 1T   | 128              | 25600       | 160             | 1008
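
The model sizes in the table can be sanity-checked with a common rule of thumb for GPT-style transformers: about 12 · L · h² parameters for L layers of hidden size h, plus a vocabulary-sized embedding. The sketch below is an approximation under our own assumptions (a vocabulary of exactly 50,000 tokens; biases, layer norms and position embeddings ignored), and `approx_gpt_params` is a hypothetical helper, not part of minGPT.

```python
def approx_gpt_params(num_layers: int, hidden_size: int, vocab_size: int = 50_000) -> int:
    """Approximate parameter count of a GPT-style model.

    Uses 12 * L * h^2 for the transformer blocks plus vocab_size * h for the
    token embedding; smaller terms are ignored.
    """
    return 12 * num_layers * hidden_size**2 + vocab_size * hidden_size


for name, layers, hidden in [("GPT 175B", 96, 12288), ("GPT 1T", 128, 25600)]:
    billions = approx_gpt_params(layers, hidden) / 1e9
    print(f"{name}: about {billions:.1f} billion parameters")
```

This gives roughly 174.6 billion and 1007.9 billion parameters, in line with the 175 and 1008 listed in the table.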

In addition to using FSDP with parameter CPU offloading in the experiments, the activation checkpointing feature in PyTorch is also applied in the tests.

The maximum per-GPU throughput of 159 teraFLOP/s (51% of NVIDIA A100 peak theoretical performance 312 teraFLOP/s/GPU) is achieved with batch size 20 and sequence length 512 on 128 GPUs for the GPT 175B model; further increase of the number of GPUs leads to per-GPU throughput degradation because of growing communication between the nodes.

For the GPT 1T model, the maximum per-GPU throughput of 84 teraFLOP/s (27% of the peak teraFLOP/s) is achieved with batch size 4 and sequence length 2048 on 128 GPUs. However, further increase of the number of GPUs doesn’t affect the per-GPU throughput too much because we observed that the largest bottleneck in the 1T model training is not from communication but from the slow CUDA cache allocator when peak GPU memory is reaching the limit. The use of A100 80G GPUs with larger memory capacity will mostly resolve this issue and also help scale the batch size to achieve much larger throughput.
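
The utilization percentages quoted for the two models follow directly from the achieved throughput and the A100 peak figure. A minimal check, using only numbers from the text (the helper name `utilization` is ours):

```python
A100_PEAK_TFLOPS = 312.0  # peak theoretical figure quoted in the text


def utilization(achieved_tflops: float, peak_tflops: float = A100_PEAK_TFLOPS) -> float:
    """Fraction of peak throughput that was actually achieved."""
    return achieved_tflops / peak_tflops


for model_name, achieved in [("GPT 175B", 159.0), ("GPT 1T", 84.0)]:
    print(f"{model_name}: {utilization(achieved):.0%} of peak")
```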

Future Work

In the next beta release, we are planning to add efficient distributed model/states checkpointing APIs, meta device support for large model materialization, and mixed-precision support inside FSDP computation and communication. We’re also going to make it easier to switch between DDP, ZeRO1, ZeRO2 and FSDP flavors of data parallelism in the new API. To further improve FSDP performance, memory fragmentation reduction and communication efficiency improvements are also planned.

A Bit of History of 2 Versions of FSDP

FairScale FSDP was released in early 2021 as part of the FairScale library. We then started the effort to upstream FairScale FSDP to PyTorch in PT 1.11, making it production-ready. We have selectively upstreamed and refactored key features from FairScale FSDP, redesigned user interfaces and made performance improvements.

In the near future, FairScale FSDP will stay in the FairScale repository for research projects, while generic and widely adopted features will be upstreamed to PyTorch incrementally and hardened accordingly.

Meanwhile, PyTorch FSDP will focus more on production readiness and long-term support. This includes better integration with ecosystems and improvements on performance, usability, reliability, debuggability and composability.


We would like to thank the authors of FairScale FSDP: Myle Ott, Sam Shleifer, Min Xu, Priya Goyal, Quentin Duval, Vittorio Caggiano, Tingting Markstrum, Anjali Sridhar. Thanks to the Microsoft DeepSpeed ZeRO team for developing and popularizing sharded data parallel techniques. Thanks to Pavel Belevich, Jessica Choi, Sisil Mehta for running experiments using PyTorch FSDP on different clusters. Thanks to Geeta Chauhan, Mahesh Yadav, Pritam Damania, Dmytro Dzhulgakov for supporting this effort and insightful discussions.


Introducing TorchRec, and other domain library updates in PyTorch 1.11


We are introducing the beta release of TorchRec and a number of improvements to the current PyTorch domain libraries, alongside the PyTorch 1.11 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch. Highlights include:

  • TorchRec, a PyTorch domain library for Recommendation Systems, is available in beta. View it on GitHub.
  • TorchAudio – Added Emformer- and RNN-T-based models and recipes to support the full development lifecycle of a streaming ASR model. See the release notes here.
  • TorchText – Added beta support for RoBERTa and XLM-R models, byte-level BPE tokenizer, and text datasets backed by TorchData. See the release notes here.
  • TorchVision – Added 4 new model families and 14 new classification datasets such as CLEVR, GTSRB, FER2013. See the release notes here.

TorchRec 0.1

We announced TorchRec a few weeks ago and we are excited to release the beta version today. To recap, TorchRec is a PyTorch domain library for Recommendation Systems. This new library provides common sparsity and parallelism primitives, enabling researchers to build state-of-the-art personalization models and deploy them in production. TorchRec was used to train a 1.25 trillion parameter model, pushed to production in January 2022.

In particular, the library includes:

  • Modeling primitives, such as embedding bags and jagged tensors, that enable easy authoring of large, performant multi-device/multi-node models using hybrid data-parallelism and model-parallelism.
  • Optimized RecSys kernels powered by FBGEMM, including support for sparse and quantized operations.
  • A sharder which can partition embedding tables with a variety of different strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, and column-wise sharding.
  • A planner which can automatically generate optimized sharding plans for models.
  • Pipelining to overlap dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
  • GPU inference support.
  • Common modules for RecSys, such as models and public datasets (Criteo & Movielens).
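
To illustrate what the sharding strategies in the list above mean for a single embedding table, the sketch below computes the per-device shard shapes for a few of them. It is a conceptual illustration only and does not use the TorchRec API: `split_evenly` and `shard_shapes` are hypothetical helpers, placing a table-wise table on device 0 is an arbitrary choice, and table-wise-row-wise sharding is left out for brevity.

```python
from typing import List, Tuple


def split_evenly(total: int, parts: int) -> List[int]:
    """Split total into `parts` sizes that differ by at most one."""
    base, extra = divmod(total, parts)
    return [base + (1 if i < extra else 0) for i in range(parts)]


def shard_shapes(rows: int, dim: int, world_size: int, strategy: str) -> List[Tuple[int, int]]:
    """Per-device (rows, dim) shapes for one embedding table."""
    if strategy == "data_parallel":
        return [(rows, dim)] * world_size  # full copy on every device
    if strategy == "table_wise":
        # whole table on one device (device 0 here, chosen arbitrarily)
        return [(rows, dim)] + [(0, 0)] * (world_size - 1)
    if strategy == "row_wise":
        return [(r, dim) for r in split_evenly(rows, world_size)]
    if strategy == "column_wise":
        return [(rows, d) for d in split_evenly(dim, world_size)]
    raise ValueError(f"unknown strategy: {strategy}")


for strategy in ("data_parallel", "table_wise", "row_wise", "column_wise"):
    print(strategy, shard_shapes(rows=10, dim=8, world_size=4, strategy=strategy))
```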

Please check out the TorchRec announcement post here, the video tutorial, and the install instructions here; test drive the feature through the tutorial here; and refer to the reference document here.

TorchAudio 0.11

TorchAudio: Building Blocks for Audio and Speech Processing

We published a paper, TorchAudio: Building Blocks for Audio and Speech Processing, describing the overview of the TorchAudio library. If you find TorchAudio useful for your research, please help us share with the community by citing our paper.

(Beta) RNN-T & (Prototype) Emformer Models and Recipes

Emformer is an efficient memory-transformer-based streaming acoustic model that has demonstrated state-of-the-art streaming automatic speech recognition (ASR) performance in low-latency, resource-constrained scenarios, such as on-device applications (citation:

The TorchAudio v0.11 release includes the following beta features:

  • Implementation of Emformer (docs)
  • Recurrent neural network transducer (RNN-T) streaming ASR model that uses Emformer for its transcription network (docs)
  • RNN-T beam search decoder with TorchScript support (docs)
  • LibriSpeech Emformer RNN-T training recipe (GitHub) and corresponding pre-trained streaming ASR inference pipeline (docs)

There are also prototype features that are available from nightly builds or the main branch:

  • Training recipes trained on MuST-C and TED-LIUM3 datasets. (GitHub)
  • Pre-trained pipelines corresponding to the recipes. (docs)
  • Tutorial that steps through performing online speech recognition with RNN-T Emformer model. (docs)

Collectively, these features cover the full development lifecycle of a streaming ASR model, from definition through training and inference, and enable users to easily develop their own Emformer- and RNN-T-based models.

Special thanks to Yangyang Shi, Jay Mahadeokar, and Gil Keren for their code contributions and guidance.

(Beta) HuBERT Pretrain Model

The masked prediction training of HuBERT model requires the masked logits, unmasked logits, and feature norm as the outputs. The logits are for cross-entropy losses and the feature norm is for penalty loss. The release adds HuBERTPretrainModel and corresponding factory functions (hubert_pretrain_base, hubert_pretrain_large, and hubert_pretrain_xlarge) to enable training from scratch.

(Prototype) CTC Beam Search Decoder

In recent releases, TorchAudio has added support for ASR models fine-tuned on CTC loss. The addition of an inference time CTC beam search decoder enables running end-to-end ASR evaluation using TorchAudio utils.

The CTC decoder in TorchAudio supports customizable beam search decoding with lexicon constraint. It also has optional KenLM language model support.
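
For readers new to CTC, it may help to see the basic collapsing rule that any CTC decoder builds on. The sketch below shows the much simpler greedy variant (take the best label per frame, merge consecutive repeats, drop blanks), not the lexicon-constrained beam search provided by TorchAudio. The function name and the choice of 0 as the blank index are assumptions made for illustration.

```python
from typing import List, Optional, Sequence


def ctc_greedy_collapse(frame_labels: Sequence[int], blank: int = 0) -> List[int]:
    """Merge consecutive repeated labels, then drop blank labels."""
    output: List[int] = []
    previous: Optional[int] = None
    for label in frame_labels:
        if label != previous and label != blank:
            output.append(label)
        previous = label
    return output


# Per-frame best labels; a blank between the two 1s keeps them as separate tokens.
print(ctc_greedy_collapse([0, 1, 1, 0, 1, 2, 2, 0]))
```

A beam search decoder keeps several candidate sequences instead of a single greedy path, which is what allows lexicon constraints and language model scores to be applied.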

For more details, please check out the API tutorial and documentation. This prototype feature is available through nightly builds.

(Prototype) Streaming API

TorchAudio started as simple audio I/O APIs that supplement PyTorch. With the recent addition of ASR models and training recipes, the project has received requests to support high-level application development.

Streaming API makes it easy to develop and test the model in online inference. It utilizes ffmpeg under the hood, and enables reading media from online services and hardware devices, decoding media in an incremental manner, and applying filters and preprocessing.

Please check out the API tutorial and the documentation. There are also the streaming ASR tutorial and the device streaming ASR tutorial. This feature is available from nightly releases. Please refer to for how to install nightly builds.

TorchText 0.12

(Beta) RoBERTa and XLM-R Models

TorchText has added support for pre-trained RoBERTa and XLM-R models. This allows users to train end-to-end Transformer-encoder-based models on standard NLP tasks using TorchText.

More specifically:

  • The models are torchscriptable and hence can be employed for production use-cases.
  • The model APIs let users easily attach custom task-specific heads with pre-trained encoders.
  • The API also comes equipped with data pre-processing transforms to match the pre-trained weights and model configuration.

We have added a tutorial to demonstrate SST-2 binary text classification task with pre-trained XLM-R base architecture.

For additional details on model APIs and usage examples, please refer to the documentation.

(Beta) byte-level BPE tokenizer

TorchText has added support for a Byte-Level BPE tokenizer, as used in GPT-2. This tokenizer is also used for tokenizing inputs to the pre-trained RoBERTa models described previously. In addition to the RoBERTa vocab, users can also load their own custom BPE vocab to use the tokenizer. Furthermore, the tokenizer is fully torchscriptable and hence can be employed for production use-cases. For additional details on model APIs and usage examples, please refer to the documentation.
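
As a rough illustration of what "byte-level BPE" means, the sketch below starts from the UTF-8 bytes of a word and repeatedly applies the highest-priority merge from a merge list. It is a toy version for intuition only: it is not the TorchText implementation, `bpe_encode` is a hypothetical name, and the tiny merge list is made up rather than taken from the GPT-2 or RoBERTa vocabularies.

```python
from typing import Dict, List, Tuple


def bpe_encode(word: str, merges: List[Tuple[bytes, bytes]]) -> List[bytes]:
    """Apply BPE merges (earlier in the list = higher priority) to a word's UTF-8 bytes."""
    tokens: List[bytes] = [bytes([b]) for b in word.encode("utf-8")]
    ranks: Dict[Tuple[bytes, bytes], int] = {pair: i for i, pair in enumerate(merges)}
    while len(tokens) > 1:
        pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no applicable merge is left
        merged: List[bytes] = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens


toy_merges = [(b"l", b"o"), (b"lo", b"w")]  # made-up merge rules
print(bpe_encode("lower", toy_merges))
print(bpe_encode("é", toy_merges))  # a non-ASCII character becomes two byte tokens
```

Because the base alphabet is the 256 possible byte values, any input string can be tokenized without an unknown-token fallback.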

(Beta) Text datasets backed by TorchData

TorchText has modernized its datasets by migrating from older-style Iterable Datasets to TorchData’s DataPipes. TorchData is a library that provides modular/composable primitives, allowing users to load and transform data in performant data pipelines.

These DataPipes work out-of-the-box with the PyTorch DataLoader and enable new functionalities like auto-sharding. Users can now easily do data manipulation and pre-processing using user-defined functions and transformations in a functional programming style. Datasets backed by DataPipes also enable standard flow-control like batching, collation, shuffling and bucketizing.

Collectively, DataPipes provide a comprehensive experience for data preprocessing and tensorization needs in a pythonic and flexible way for model training. We have added a tutorial to demonstrate data-processing pipelining using the modernized dataset for binary text-classification.
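
The composable, functional style described here can be pictured with plain Python generators. The sketch below is not the TorchData API: `map_pipe` and `batch_pipe` are hypothetical stand-ins for a mapping stage and a batching stage, and the three-row dataset is made up.

```python
from typing import Callable, Iterable, Iterator, List, Tuple


def map_pipe(source: Iterable, fn: Callable) -> Iterator:
    """Lazily apply fn to every item of source."""
    for item in source:
        yield fn(item)


def batch_pipe(source: Iterable, batch_size: int) -> Iterator[List]:
    """Group items of source into lists of at most batch_size items."""
    batch: List = []
    for item in source:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly smaller batch


rows: List[Tuple[str, int]] = [("great movie", 1), ("dull plot", 0), ("fine", 1)]
tokenized = map_pipe(rows, lambda row: (row[0].split(), row[1]))
batches = list(batch_pipe(tokenized, batch_size=2))
print(batches)
```

Each stage only consumes the previous one when iterated, so stages can be chained freely without loading the whole dataset into memory.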

You can learn more about TorchData DataPipe APIs in its official documentation.

TorchVision 0.12

New Models

Four new model families have been released in the latest version along with pre-trained weights for their variants.

#1. Object Detection

FCOS is a popular, fully convolutional, anchor-free model for object detection. In this release we include a community-contributed model implementation as well as pre-trained weights. The model was trained on COCO train2017 and can be used as follows:

import torch
from torchvision import models

x = [torch.rand(3, 224, 224)]
fcos = models.detection.fcos_resnet50_fpn(pretrained=True).eval()
predictions =  fcos(x)

The box AP of the pre-trained model on COCO val2017 is 39.2 (see #4961 for more details).

We would like to thank Hu Ye and Zhiqiang Wang for contributing to the model implementation and initial training. This was the first community-contributed model in a long while, and given its success, we decided to use the learnings from this process and create new model contribution guidelines.

#2. Optical Flow support and RAFT model

TorchVision now supports optical flow! Optical Flow models try to predict movement in a video: given two consecutive frames, the model predicts where each pixel of the first frame ends up in the second frame. Check out our new tutorial on Optical Flow!
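
To make the prediction target concrete: a flow field stores, for every pixel of the first frame, a displacement (dx, dy) to its location in the second frame. The sketch below uses a tiny hand-written flow field stored as nested lists; this layout and the helper name `displaced_position` are assumptions for illustration and do not reflect the tensor format returned by the RAFT model.

```python
from typing import List, Tuple

Flow = List[List[Tuple[float, float]]]  # flow[y][x] == (dx, dy)


def displaced_position(x: int, y: int, flow: Flow) -> Tuple[float, float]:
    """Where the pixel at (x, y) in the first frame ends up in the second frame."""
    dx, dy = flow[y][x]
    return x + dx, y + dy


# A made-up 2x2 flow field: the top row moves right, the bottom row moves down.
toy_flow: Flow = [
    [(1.0, 0.0), (1.0, 0.0)],
    [(0.0, 2.0), (0.0, 2.0)],
]
print(displaced_position(0, 0, toy_flow))
print(displaced_position(1, 1, toy_flow))
```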

We implemented a torchscript-compatible RAFT model with pre-trained weights (both normal and “small” versions), and added support for training and evaluating optical flow models. Our training scripts support distributed training across processes and nodes, leading to much faster training time than the original implementation. We also added 5 new optical flow datasets: Flying Chairs, Flying Things, Sintel, Kitti, and HD1K.

#3. Image Classification

Vision Transformer (ViT) and ConvNeXt are two popular architectures which can be used as image classifiers or as backbones for downstream vision tasks. In this release we include 8 pre-trained weights for their classification variants. The models were trained on ImageNet and can be used as follows:

import torch
from torchvision import models

x = torch.rand(1, 3, 224, 224)
# Classification models live directly under torchvision.models, not models.detection
vit = models.vit_b_16(pretrained=True).eval()
convnext = models.convnext_tiny(pretrained=True).eval()
predictions1 = vit(x)
predictions2 = convnext(x)

The accuracies of the pre-trained models obtained on ImageNet val are seen below:

Model          | Acc@1  | Acc@5
vit_b_16       | 81.072 | 95.318
vit_b_32       | 75.912 | 92.466
vit_l_16       | 79.662 | 94.638
vit_l_32       | 76.972 | 93.07
convnext_tiny  | 82.52  | 96.146
convnext_small | 83.616 | 96.65
convnext_base  | 84.062 | 96.87
convnext_large | 84.414 | 96.976

The above models have been trained using an adjusted version of our new training recipe, which allows us to offer models with accuracies significantly higher than the ones in the original papers.

#4. GPU Video Decoding

In this release, we add support for GPU video decoding in the video reading API. To use hardware-accelerated decoding, we just need to pass a cuda device to the video reading API as shown below:

import torchvision

reader =, device="cuda:0")
for frame in reader:

We also support seeking to any frame or a keyframe in the video before reading, as shown below:

New Datasets

We have implemented 14 new classification datasets: CLEVR, GTSRB, FER2013, SUN397, Country211, Flowers102, FGVC Aircraft, OxfordIIITPet, DTD, Food 101, Rendered SST2, Stanford cars, PCAM, and EuroSAT.

As part of our work on Optical Flow support (see above for more details), we also added 5 new optical flow datasets: Flying Chairs, Flying Things, Sintel, KITTI, and HD1K.

Other Updates

  • New documentation layout: Each function / class is now documented in a separate page, clearing up some space in the per-module pages, and easing the discovery of the proposed APIs. Compare e.g. our previous docs vs the new ones. Please let us know if you have any feedback!
  • New model contribution guidelines have been published following the success of the FCOS model which was contributed by the community. These guidelines aim to be an overview of the model contribution process for anyone who would like to suggest, implement and train a new model.
  • Upcoming Prototype API – We are currently working on a prototype API which adds Multi-weight support to all of our model builder methods. This will enable us to offer multiple pre-trained weights, together with their meta-data and inference transforms. The API is still under review and thus was not included in the release, but you can read more about it in our blog post and provide your feedback on the dedicated GitHub issue.
  • Changes in our deprecation policy – Up until now, TorchVision would almost never remove deprecated APIs. To be more aligned and consistent with PyTorch core, we are updating our deprecation policy. We are now following a 2-release deprecation cycle: deprecated APIs will raise a warning for 2 versions and will be removed after that. To reflect these changes and to smooth the transition, we have decided to:

    • Remove all APIs that had been deprecated before or on v0.8, released 1.5 years ago.
    • Update the removal timeline of all other deprecated APIs to v0.14, to reflect the new 2-cycle policy starting now in v0.12.
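As a rough illustration of the 2-release cycle, the sketch below computes when a deprecated API would be removed. The helper names `removal_version` and `api_status` are hypothetical, and the code assumes each release bumps the minor version by one (v0.12, v0.13, v0.14, and so on).

```python
def parse(version):
    """Turn a 'major.minor' string into a comparable tuple of ints."""
    major, minor = version.split(".")
    return int(major), int(minor)


def removal_version(deprecated_in, cycle=2):
    """Version in which an API deprecated in `deprecated_in` is removed."""
    major, minor = parse(deprecated_in)
    return f"{major}.{minor + cycle}"


def api_status(deprecated_in, current):
    """Describe what calling the API does in the `current` version."""
    if parse(current) < parse(deprecated_in):
        return "available"
    if parse(current) < parse(removal_version(deprecated_in)):
        return "warns"
    return "removed"


print(removal_version("0.12"))     # 0.14
print(api_status("0.12", "0.13"))  # warns
print(api_status("0.12", "0.14"))  # removed
```

Under these assumptions, an API deprecated in v0.12 warns in v0.12 and v0.13 and is gone in v0.14, which matches the removal timeline stated above.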

Captum 0.5

Captum is a PyTorch library for model interpretability. For this release, we expanded Captum with influential instances and added support for both similarity-based influences and novel algorithms, TracIn and its variants. TracIn variants offer faster approximation of influence scores based on random projections for fully connected layers.

More specifically, the new influence subsection of Captum includes:

  • SimilarityInfluence computes similarity scores between test and training examples using default (cosine or Euclidean) or custom user-defined metrics w.r.t. given input model layers.
  • TracInCP approximates the influence score of each training example on a given test example based on the dot-product similarity between loss gradients w.r.t. model parameters for test and training examples. Note that if we use training examples as test examples, then we compute self-influence. This method and its variants described below also return the top-k proponents and opponents, which are the training examples with the top-k largest positive and negative influence scores, respectively.
  • TracInCPFast is an approximation of TracInCP that avoids computing the gradients w.r.t. large parameter matrices. It approximates the influence score based on the dot products between last fully connected layer activations and loss gradients w.r.t. that layer for training and test examples.
  • TracInCPFastRandProj uses a nearest neighbor approximation library such as annoy to compute the dot product between the training and test quantities. In addition, to reduce the dimensionality of the layer activations and corresponding gradients, this method allows projecting those vectors into a lower-dimensional space using random projection matrices.
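The gradient dot-product idea behind TracInCP can be sketched in a few lines of plain Python. This is an illustration only and does not use Captum's API: the function names and the toy gradients are made up, and summing over checkpoints with per-checkpoint weights (for example, learning rates) is an assumption based on the general TracIn formulation, so Captum's actual implementation may differ.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def tracin_scores(train_grads, test_grads, checkpoint_weights):
    """Influence of each training example on one test example.

    train_grads[c][i] is the loss gradient of training example i at checkpoint c;
    test_grads[c] is the loss gradient of the test example at checkpoint c.
    """
    scores = [0.0] * len(train_grads[0])
    for w, grads_c, test_c in zip(checkpoint_weights, train_grads, test_grads):
        for i, g in enumerate(grads_c):
            scores[i] += w * dot(g, test_c)
    return scores


def proponents_and_opponents(scores, k):
    """Indices of the k most positive and the k most negative influence scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k], order[::-1][:k]


# Toy data: 2 checkpoints, 3 training examples, 2 parameters (made-up gradients).
train_grads = [
    [[1, 0], [-1, 0], [0, 2]],   # checkpoint 0
    [[1, 1], [-2, 0], [0, -1]],  # checkpoint 1
]
test_grads = [[2, 1], [1, 1]]

scores = tracin_scores(train_grads, test_grads, checkpoint_weights=[0.5, 0.5])
proponents, opponents = proponents_and_opponents(scores, k=1)
print(scores, proponents, opponents)
```

Passing a training example's own gradients as `test_grads` would give its self-influence in this sketch.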

More about the implementation of influential instances can be found on our GitHub page and tutorials.

Thanks for reading. If you’re interested in these updates and want to join the PyTorch community, we encourage you to join the discussion forums and open GitHub issues. To get the latest news from PyTorch, follow us on Twitter, Medium, YouTube, and LinkedIn.


Team PyTorch
