PyTorch 1.7 released w/ CUDA 11, New APIs for FFTs, Windows support for Distributed training and more

Today, we’re announcing the availability of PyTorch 1.7, along with updated domain libraries. The PyTorch 1.7 release includes a number of new APIs including support for NumPy-Compatible FFT operations, profiling tools and major updates to both distributed data parallel (DDP) and remote procedure call (RPC) based distributed training. In addition, several features moved to stable including custom C++ Classes, the memory profiler, extensions via custom tensor-like objects, user async functions in RPC and a number of other features in torch.distributed such as Per-RPC timeout, DDP dynamic bucketing and RRef helper.

A few of the highlights include:

  • CUDA 11 is now officially supported with binaries available at PyTorch.org
  • Updates and additions to profiling and performance for RPC, TorchScript and stack traces in the autograd profiler
  • (Beta) Support for NumPy compatible Fast Fourier transforms (FFT) via torch.fft
  • (Prototype) Support for Nvidia A100 generation GPUs and native TF32 format
  • (Prototype) Distributed training on Windows now supported
  • torchvision
    • (Stable) Transforms now support Tensor inputs, batch computation, GPU, and TorchScript
    • (Stable) Native image I/O for JPEG and PNG formats
    • (Beta) New Video Reader API
  • torchaudio
    • (Stable) Added support for speech recognition (wav2letter), text to speech (WaveRNN) and source separation (ConvTasNet)

To reiterate, starting with PyTorch 1.6, features are classified as stable, beta and prototype. You can see the detailed announcement here. Note that the prototype features listed in this blog are available as part of this release.

Find the full release notes here.

Front End APIs

[Beta] NumPy Compatible torch.fft module

FFT-related functionality is commonly used in a variety of scientific fields like signal processing. While PyTorch has historically supported a few FFT-related functions, the 1.7 release adds a new torch.fft module that implements FFT-related functions with the same API as NumPy.

This new module must be imported to be used in the 1.7 release, since its name conflicts with the historic (and now deprecated) torch.fft function.

Example usage:

>>> import torch.fft
>>> t = torch.arange(4)
>>> t
tensor([0, 1, 2, 3])

>>> torch.fft.fft(t)
tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])

>>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
>>> torch.fft.fft(t)
tensor([12.+16.j, -8.+0.j, -4.-4.j,  0.-8.j])
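
The values in the example above can be checked by hand. As a rough illustration (not how torch.fft is implemented), a naive O(n²) discrete Fourier transform in plain Python, using the unnormalized forward convention that NumPy also uses by default, reproduces both outputs. The helper name naive_dft is ours and is not part of PyTorch.

```python
import cmath


def naive_dft(values):
    """Unnormalized forward DFT: X[k] = sum_m x[m] * exp(-2*pi*i*k*m/n)."""
    n = len(values)
    return [
        sum(x * cmath.exp(-2j * cmath.pi * k * m / n) for m, x in enumerate(values))
        for k in range(n)
    ]


# Same inputs as the torch.fft.fft calls above.
real_result = naive_dft([0, 1, 2, 3])
complex_result = naive_dft([0 + 1j, 2 + 3j, 4 + 5j, 6 + 7j])
```

Up to floating-point rounding, real_result matches the first tensor printed above and complex_result matches the second.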

[Beta] C++ Support for Transformer NN Modules

Since PyTorch 1.5, we’ve continued to maintain parity between the Python and C++ frontend APIs. This update allows developers to use the nn.transformer module abstraction from the C++ frontend. Moreover, developers no longer need to save a module from Python/JIT and load it into C++, as it can now be used in C++ directly.

[Beta] torch.set_deterministic

Reproducibility (bit-for-bit determinism) may help identify errors when debugging or testing a program. To facilitate reproducibility, PyTorch 1.7 adds the torch.set_deterministic(bool) function that can direct PyTorch operators to select deterministic algorithms when available, and to throw a runtime error if an operation may result in nondeterministic behavior. The flag this function controls is false by default, so existing behavior is unchanged: unless the flag is set, PyTorch may implement its operations nondeterministically.

More precisely, when this flag is true:

  • Operations known to not have a deterministic implementation throw a runtime error;
  • Operations with deterministic variants use those variants (usually with a performance penalty versus the non-deterministic version); and
  • torch.backends.cudnn.deterministic = True is set.

Note that this is necessary, but not sufficient, for determinism within a single run of a PyTorch program. Other sources of randomness like random number generators, unknown operations, or asynchronous or distributed computation may still cause nondeterministic behavior.

See the documentation for torch.set_deterministic(bool) for the list of affected operations.

Performance & Profiling

[Beta] Stack traces added to profiler

Users can now see not only operator name/inputs in the profiler output table but also where the operator is in the code. The workflow requires very little change to take advantage of this capability. The user uses the autograd profiler as before but with optional new parameters: with_stack and group_by_stack_n. Caution: regular profiling runs should not use this feature as it adds significant overhead.
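
A minimal sketch of the workflow, under the assumption that with_stack is passed to torch.autograd.profiler.profile() while group_by_stack_n is passed to key_averages() when the results are aggregated. The helper name profile_with_stacks and its defaults are hypothetical.

```python
def profile_with_stacks(fn, *args, stack_depth=5, row_limit=10):
    """Run fn(*args) under the autograd profiler and return a summary table."""
    # Imported lazily so the helper can be defined even where torch is absent.
    import torch.autograd.profiler as profiler

    # with_stack=True asks the profiler to record where each operator was called.
    with profiler.profile(with_stack=True) as prof:
        fn(*args)

    # group_by_stack_n keeps the top `stack_depth` stack frames when grouping rows.
    return prof.key_averages(group_by_stack_n=stack_depth).table(
        sort_by="self_cpu_time_total", row_limit=row_limit
    )
```

With torch installed, printing profile_with_stacks(torch.nn.functional.relu, torch.randn(64, 64)) should show the usual profiler table with source locations attached to each operator row.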

Distributed Training & RPC

[Stable] TorchElastic now bundled into PyTorch docker image

TorchElastic offers a strict superset of the current torch.distributed.launch CLI, with added features for fault tolerance and elasticity. Users who are not interested in fault tolerance can get exact functionality/behavior parity by setting max_restarts=0, with the added convenience of auto-assigned RANK and MASTER_ADDR|PORT (versus manually specified in torch.distributed.launch).

By bundling TorchElastic in the same Docker image as PyTorch, users can start experimenting with TorchElastic right away without having to install it separately. In addition to convenience, this work is a nice-to-have when adding support for elastic parameters in Kubeflow’s existing distributed PyTorch operators.

[Beta] Support for uneven dataset inputs in DDP

PyTorch 1.7 introduces a new context manager to be used in conjunction with models trained using torch.nn.parallel.DistributedDataParallel to enable training with uneven dataset sizes across different processes. This feature enables greater flexibility when using DDP and saves the user from having to manually ensure that dataset sizes are the same across different processes. With this context manager, DDP will handle uneven dataset sizes automatically, which can prevent errors or hangs at the end of training.
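
A minimal sketch of how such a training loop might look, assuming the context manager in question is DistributedDataParallel.join() (the text above does not name it), that ddp_model is already wrapped in DDP, and that the process group is initialized. The helper name train_one_epoch is hypothetical.

```python
def train_one_epoch(ddp_model, optimizer, loss_fn, batches):
    """Train on this rank's batches; ranks may hold different numbers of batches.

    Assumes ddp_model is a torch.nn.parallel.DistributedDataParallel instance
    and that join() is its uneven-inputs context manager (an assumption).
    """
    processed = 0
    # Inside this block, a rank that exhausts its data early is expected to keep
    # matching the collectives of the other ranks instead of leaving them waiting.
    with ddp_model.join():
        for inputs, targets in batches:
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(inputs), targets)
            loss.backward()  # gradient synchronization happens here in DDP
            optimizer.step()
            processed += 1
    return processed
```

Each process calls the helper with its own, possibly shorter or longer, iterable of batches; no padding or truncation of the per-rank datasets is needed.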

[Beta] NCCL Reliability – Async Error/Timeout Handling

In the past, NCCL training runs would hang indefinitely due to stuck collectives, leading to a very unpleasant experience for users. This feature will abort stuck collectives and throw an exception/crash the process if a potential hang is detected. When used with something like torchelastic (which can recover the training process from the last checkpoint), users can have much greater reliability for distributed training. This feature is completely opt-in and sits behind an environment variable that needs to be explicitly set in order to enable this functionality (otherwise users will see the same behavior as before).
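
As a hedged sketch of the opt-in, assuming the environment variable is NCCL_ASYNC_ERROR_HANDLING (the text above does not name it) and that the timeout passed to init_process_group decides when a collective counts as stuck. The five-minute value and the helper name are illustrative only.

```python
import os
from datetime import timedelta

# Assumption: this is the opt-in variable; it must be set before the
# process group is created.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

# Illustrative value: how long a collective may run before it is treated as stuck.
COLLECTIVE_TIMEOUT = timedelta(minutes=5)


def init_nccl_process_group(rank, world_size):
    """Create an NCCL process group that uses the timeout above."""
    # Imported lazily so the settings above can be inspected without torch.
    import torch.distributed as dist

    # Uses the default env:// rendezvous, so MASTER_ADDR and MASTER_PORT
    # are expected to be set in the environment.
    dist.init_process_group(
        "nccl", rank=rank, world_size=world_size, timeout=COLLECTIVE_TIMEOUT
    )
```

Paired with a launcher that restarts failed workers, a timed-out collective then surfaces as a process failure rather than an indefinite hang.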

[Beta] TorchScript rpc_remote and rpc_sync

torch.distributed.rpc.rpc_async has been available in TorchScript in prior releases. In PyTorch 1.7, this functionality is extended to the remaining two core RPC APIs, torch.distributed.rpc.rpc_sync and torch.distributed.rpc.remote. This completes the major RPC APIs targeted for support in TorchScript. It allows users to use the existing Python RPC APIs within TorchScript (in a script function or script method, which releases the Python Global Interpreter Lock) and could improve application performance in multithreaded environments.

[Beta] Distributed optimizer with TorchScript support

PyTorch provides a broad set of optimizers for training algorithms, and these have been used repeatedly as part of the Python API. However, users often want to use multithreaded training instead of multiprocess training, as it provides better resource utilization and efficiency in the context of large-scale distributed training (e.g. Distributed Model Parallel) or any RPC-based training application. Users couldn’t do this with the distributed optimizer before, because doing so requires getting rid of the Python Global Interpreter Lock (GIL) limitation.

In PyTorch 1.7, we are enabling TorchScript support in the distributed optimizer to remove the GIL limitation and make it possible to run the optimizer in multithreaded applications. The new distributed optimizer has the exact same interface as before, but it automatically converts the optimizers within each worker into TorchScript to make each one GIL-free. This is done by leveraging a functional optimizer concept and allowing the distributed optimizer to convert the computational portion of the optimizer into TorchScript. This will help use cases like distributed model parallel training and improve performance using multithreading.

Currently, the only optimizer that supports automatic conversion with TorchScript is Adagrad; all other optimizers still work as before, without TorchScript support. We are working on expanding the coverage to all PyTorch optimizers and expect more to come in future releases. Enabling TorchScript support is automatic and uses exactly the same Python APIs as before. Here is an example:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
  # Forward pass.
  rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
  rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
  loss = rref1.to_here() + rref2.to_here()

  # Backward pass.
  dist_autograd.backward(context_id, [loss.sum()])

  # Optimizer, pass in optim.Adagrad, DistributedOptimizer will
  # automatically convert/compile it to TorchScript (GIL-free)
  dist_optim = DistributedOptimizer(
     optim.Adagrad,
     [rref1, rref2],
     lr=0.05,
  )
  dist_optim.step(context_id)
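
The “functional optimizer concept” mentioned above can be pictured as an update rule written as a plain function of parameters, gradients and state, with nothing read from or stored on Python objects. The following is a minimal sketch in plain Python, assuming the standard Adagrad rule without learning-rate decay or weight decay. It is not the code that DistributedOptimizer generates, and the function name is hypothetical.

```python
import math


def functional_adagrad_step(params, grads, state_sums, lr=0.05, eps=1e-10):
    """Apply one Adagrad update and return the new (params, state_sums).

    All inputs are passed explicitly and nothing is mutated, which is the
    property that makes this style easy to compile and run without shared state.
    """
    new_params, new_sums = [], []
    for p, g, s in zip(params, grads, state_sums):
        s = s + g * g  # running sum of squared gradients
        new_sums.append(s)
        new_params.append(p - lr * g / (math.sqrt(s) + eps))
    return new_params, new_sums


params, sums = functional_adagrad_step([1.0, -2.0], [0.5, 0.0], [0.0, 0.0])
```

Here the first parameter moves from 1.0 to roughly 0.95, and the second, whose gradient is zero, is left unchanged.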

[Beta] Enhancements to RPC-based Profiling

Support for using the PyTorch profiler in conjunction with the RPC framework was first introduced in PyTorch 1.6. In PyTorch 1.7, the following enhancements have been made:

  • Implemented better support for profiling TorchScript functions over RPC
  • Achieved parity in terms of profiler features that work with RPC
  • Added support for asynchronous RPC functions on the server-side (functions decorated with rpc.functions.async_execution).

Users can now use familiar profiling tools such as with torch.autograd.profiler.profile() and with torch.autograd.profiler.record_function, and these work transparently with the RPC framework, with full feature support, including profiling of asynchronous functions and TorchScript functions.

[Prototype] Windows support for Distributed Training

PyTorch 1.7 brings prototype support for DistributedDataParallel and collective communications on the Windows platform. In this release, the support only covers Gloo-based ProcessGroup and FileStore.

To use this feature across multiple machines, please provide a file from a shared file system in init_process_group.

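import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# rank, world_size and local_model are assumed to be defined by the caller
# initialize the process group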
dist.init_process_group(
    "gloo",
    # multi-machine example:
    # init_method = "file://////{machine}/{share_folder}/file"
    init_method="file:///{your local file path}",
    rank=rank,
    world_size=world_size
)

model = DistributedDataParallel(local_model, device_ids=[rank])

Mobile

PyTorch Mobile supports both iOS and Android with binary packages available in Cocoapods and JCenter respectively. You can learn more about PyTorch Mobile here.

[Beta] PyTorch Mobile Caching allocator for performance improvements

On some mobile platforms, such as Pixel, we observed that memory is returned to the system more aggressively. This results in frequent page faults because PyTorch, being a functional framework, does not maintain state for the operators; outputs are therefore allocated dynamically on each execution of the op, for most ops. To ameliorate the resulting performance penalties, PyTorch 1.7 provides a simple caching allocator for CPU. The allocator caches allocations by tensor size and is currently available only via the PyTorch C++ API. The caching allocator itself is owned by the client, so its lifetime is also managed by client code. Such a client-owned caching allocator can then be used with a scoped guard, c10::WithCPUCachingAllocatorGuard, to enable the use of cached allocation within that scope.

Example usage:

#include <c10/mobile/CPUCachingAllocator.h>
.....
c10::CPUCachingAllocator caching_allocator;
  // Owned by client code. Can be a member of some client class so as to tie
  // the lifetime of the caching allocator to that of the class.
.....
{
  c10::optional<c10::WithCPUCachingAllocatorGuard> caching_allocator_guard;
  if (FLAGS_use_caching_allocator) {
    caching_allocator_guard.emplace(&caching_allocator);
  }
  ....
  model.forward(..);
}
...

NOTE: Caching allocator is only available on mobile builds, thus the use of caching allocator outside of mobile builds won’t be effective.

torchvision

[Stable] Transforms now support Tensor inputs, batch computation, GPU, and TorchScript

torchvision transforms now inherit from nn.Module and can be torchscripted and applied to torch Tensor inputs as well as to PIL images. They also support Tensors with batch dimensions and work seamlessly on CPU/GPU devices:

import torch
import torchvision.transforms as T

# to fix random seed, use torch.manual_seed
# instead of random.seed
torch.manual_seed(12)

transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)
scripted_transforms = torch.jit.script(transforms)
# Note: we can similarly use T.Compose to define transforms
# transforms = T.Compose([...]) and 
# scripted_transforms = torch.jit.script(torch.nn.Sequential(*transforms.transforms))

tensor_image = torch.randint(0, 256, size=(3, 256, 256), dtype=torch.uint8)
# works directly on Tensors
out_image1 = transforms(tensor_image)
# on the GPU
out_image1_cuda = transforms(tensor_image.cuda())
# with batches
batched_image = torch.randint(0, 256, size=(4, 3, 256, 256), dtype=torch.uint8)
out_image_batched = transforms(batched_image)
# and has torchscript support
out_image2 = scripted_transforms(tensor_image)

These improvements enable the following new features:

  • support for GPU acceleration
  • batched transformations e.g. as needed for videos
  • transform multi-band torch tensor images (with more than 3-4 channels)
  • torchscript transforms together with your model for deployment
    Note: Exceptions to TorchScript support include Compose, RandomChoice, RandomOrder, Lambda and those applied on PIL images, such as ToPILImage.

[Stable] Native image IO for JPEG and PNG formats

torchvision 0.8.0 introduces native image reading and writing operations for JPEG and PNG formats. These operators support TorchScript and return CxHxW tensors in uint8 format, and can thus now be part of your model for deployment in C++ environments.

import torch
from torchvision.io import read_image

# tensor_image is a CxHxW uint8 Tensor
tensor_image = read_image('path_to_image.jpeg')

# or equivalently
from torchvision.io import read_file, decode_image
# raw_data is a 1d uint8 Tensor with the raw bytes
raw_data = read_file('path_to_image.jpeg')
tensor_image = decode_image(raw_data)

# all operators are torchscriptable and can be
# serialized together with your model torchscript code
scripted_read_image = torch.jit.script(read_image)

[Stable] RetinaNet detection model

This release adds pretrained models for RetinaNet with a ResNet50 backbone from Focal Loss for Dense Object Detection.

[Beta] New Video Reader API

This release introduces a new video reading abstraction, which gives more fine-grained control of iteration over videos. It supports image and audio, and implements an iterator interface so that it is interoperable with other Python libraries such as itertools.

from torchvision.io import VideoReader

# stream indicates if reading from audio or video
reader = VideoReader('path_to_video.mp4', stream='video')
# can change the stream after construction
# via reader.set_current_stream

# to read all frames in a video starting at 2 seconds
for frame in reader.seek(2):
    # frame is a dict with "data" and "pts" metadata
    print(frame["data"], frame["pts"])

# because reader is an iterator you can combine it with
# itertools
from itertools import takewhile, islice
# read 10 frames starting from 2 seconds
for frame in islice(reader.seek(2), 10):
    pass
    
# or to return all frames between 2 and 5 seconds
for frame in takewhile(lambda x: x["pts"] < 5, reader.seek(2)):
    pass

Notes:

  • In order to use the Video Reader API beta, you must compile torchvision from source and have ffmpeg installed in your system.
  • The VideoReader API is currently released as beta and its API may change following user feedback.
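
Because the real reader needs a source build and ffmpeg, the itertools patterns above can be tried first on a stand-in. The following sketch uses a hypothetical FakeReader that only mimics the shape described above (seek() returning an iterator of dicts with "data" and "pts"); it is not the torchvision class.

```python
from itertools import islice, takewhile


class FakeReader:
    """Hypothetical stand-in that yields {"data", "pts"} dicts at a fixed rate."""

    def __init__(self, duration_s=10.0, fps=2.0):
        self._step = 1.0 / fps
        self._duration = duration_s
        self._pts = 0.0

    def seek(self, time_s):
        self._pts = float(time_s)
        return self  # returning self lets seek() be used directly in a for loop

    def __iter__(self):
        return self

    def __next__(self):
        if self._pts >= self._duration:
            raise StopIteration
        frame = {"data": None, "pts": self._pts}
        self._pts += self._step
        return frame


reader = FakeReader()
# 10 frames starting from 2 seconds
first_ten = [f["pts"] for f in islice(reader.seek(2), 10)]
# all frames between 2 and 5 seconds; seek again because the reader is stateful
window = [f["pts"] for f in takewhile(lambda f: f["pts"] < 5, reader.seek(2))]
```

Two properties of stateful iterators are worth keeping in mind: each pass continues from wherever the previous one stopped unless seek() is called again, and takewhile consumes the first frame that fails the predicate, so that frame is not seen by a later loop over the same reader.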

torchaudio

With this release, torchaudio is expanding its support for models and end-to-end applications, adding a wav2letter training pipeline and end-to-end text-to-speech and source separation pipelines. Please file an issue on GitHub to provide feedback on them.

[Stable] Speech Recognition

Building on the addition of the wav2letter model for speech recognition in the last release, we’ve now added an example wav2letter training pipeline with the LibriSpeech dataset.

[Stable] Text-to-speech

With the goal of supporting text-to-speech applications, we added a vocoder built on the WaveRNN model, following the implementation from this repository. The original implementation was introduced in “Efficient Neural Audio Synthesis”. We also provide an example WaveRNN training pipeline that uses the LibriTTS dataset added to torchaudio in this release.

[Stable] Source Separation

With the addition of the ConvTasNet model, based on the paper “Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation,” torchaudio now also supports source separation. An example ConvTasNet training pipeline is provided with the wsj-mix dataset.

Cheers!

Team PyTorch
