PyTorch 2.4 Release Blog

We are excited to announce the release of PyTorch® 2.4 (release note)! PyTorch 2.4 adds support for the latest version of Python (3.12) for torch.compile. AOTInductor freezing gives developers running AOTInductor more performance-based optimizations by allowing the serialization of MKLDNN weights. In addition, a new default TCPStore server backend utilizing libuv has been introduced, which should significantly reduce initialization times for users running large-scale jobs. Finally, a new Python Custom Operator API makes it easier than before to integrate custom kernels into PyTorch, especially for torch.compile.

This release is composed of 3661 commits and 475 contributors since PyTorch 2.3. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.4. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Beta:
  • Python 3.12 support for torch.compile
  • AOTInductor Freezing for CPU
  • New Higher-level Python Custom Operator API
  • Switching TCPStore’s default server backend to libuv

Prototype:
  • FSDP2: DTensor-based per-parameter-sharding FSDP
  • torch.distributed.pipelining, simplified pipeline parallelism
  • Intel GPU is available through source build

Performance Improvements:
  • torch.compile optimizations for AWS Graviton (aarch64-linux) processors
  • BF16 symbolic shape optimization in TorchInductor
  • Performance optimizations for GenAI projects utilizing CPU devices

*To see a full list of public feature submissions click here.

Beta Features

[Beta] Python 3.12 support for torch.compile

torch.compile() previously only supported Python 3.8-3.11. Users can now optimize models with torch.compile() with Python 3.12.

[Beta] AOTInductor Freezing for CPU

This feature enables users to turn on the freezing flag when using AOTInductor on CPU. With this feature, AOTInductor can cover the same set of op scenarios and reach on-par performance with the Inductor CPP backend. Previously, when models contained MKLDNN operators (i.e., when computation-intensive operators such as Convolution, Linear, and ConvTranspose were involved) and freezing was on, those models would fail to run because AOTInductor did not support serializing MKLDNN weights, which have an opaque format.

The workflow is as explained in the AOTInductor tutorial; in addition, users can now set the freezing flag to get better performance:

export TORCHINDUCTOR_FREEZING=1
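
For illustration, a minimal end-to-end sketch following the pattern in the AOTInductor tutorial is shown below; the torch._export.aot_compile entry point and the toy module are per that tutorial and should be treated as version-dependent.

import os
os.environ["TORCHINDUCTOR_FREEZING"] = "1"  # same flag as the shell export above

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))

model = Net().eval()
example_inputs = (torch.randn(8, 64),)

with torch.no_grad():
    # AOT-compile to a shared library; with freezing on, the MKLDNN weights
    # used by computation-intensive ops (e.g. Linear) can be serialized on CPU.
    so_path = torch._export.aot_compile(model, example_inputs)
print(so_path)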

[Beta] New Higher-level Python Custom Operator API

We’ve added a new higher-level Python Custom Operator API that makes it easier than before to extend PyTorch with custom operators that behave like PyTorch’s built-in operators. Operators registered using the new high-level torch.library APIs are guaranteed to be compatible with torch.compile and other PyTorch subsystems; authoring a custom operator in Python using the previous low-level torch.library APIs required a deep understanding of PyTorch internals and had many footguns.

Please see the tutorial for more information.
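
As a concrete illustration, the sketch below registers a NumPy-backed sine kernel through the new torch.library.custom_op decorator, together with a fake (meta) implementation so torch.compile can trace through it; the mylib::numpy_sin operator is purely illustrative.

import numpy as np
import torch

@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    # The "kernel": opaque to PyTorch tracing (here it simply calls into NumPy).
    x_np = x.detach().cpu().numpy()
    return torch.from_numpy(np.sin(x_np)).to(x.device)

@numpy_sin.register_fake
def _(x):
    # Describe output metadata (shape/dtype/device) without running the kernel,
    # which is what torch.compile needs in order to trace the op.
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    return numpy_sin(x)

print(f(torch.randn(4)))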

[Beta] Switching TCPStore’s default server backend to libuv

We introduced a new default server backend for TCPStore built with libuv, which should bring significantly lower initialization times and better scalability. Users running large-scale jobs should benefit from a much shorter startup time.

For more information on the motivation + fallback instructions please refer to this tutorial.
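
For context, the sketch below shows how a TCPStore is typically constructed; in 2.4 the server side defaults to the libuv backend. The fallback knobs mentioned in the comments (use_libuv / USE_LIBUV) follow the tutorial linked above and should be treated as version-dependent assumptions.

from datetime import timedelta
import torch.distributed as dist

# Rank 0 hosts the store; other ranks would connect with is_master=False.
# With PyTorch 2.4, the TCPStore server backend defaults to libuv.
store = dist.TCPStore("127.0.0.1", 29500, 1, True,
                      timeout=timedelta(seconds=30), wait_for_workers=False)
store.set("status", "ready")
print(store.get("status"))

# If needed, the tutorial describes falling back to the legacy backend,
# e.g. by constructing the store with use_libuv=False or exporting USE_LIBUV=0
# before init_process_group (exact knob names are version-dependent).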

Prototype Features

[PROTOTYPE] FSDP2: DTensor-based per-parameter-sharding FSDP

FSDP2 is a new fully sharded data parallelism implementation that uses dim-0 per-parameter sharding to resolve fundamental composability challenges with FSDP1’s flat-parameter sharding.

For more information regarding the motivation / design for FSDP2 please refer to the RFC on Github.
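
As a rough sketch of what dim-0 per-parameter sharding looks like in user code (the torch.distributed._composable.fsdp import path and the defaults used here are assumptions based on the RFC; FSDP2 is a prototype and the API may change):

# Launch with torchrun so a process group exists, e.g.:
#   torchrun --nproc_per_node=8 fsdp2_sketch.py
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard  # import path per the RFC (assumption)

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(4)]).cuda()

# Shard each submodule's parameters along dim-0 (per-parameter), then the root module.
for layer in model:
    fully_shard(layer)
fully_shard(model)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()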

[PROTOTYPE] torch.distributed.pipelining, simplified pipeline parallelism

Pipeline Parallelism is one of the primitive parallelism techniques for deep learning. It allows the execution of a model to be partitioned such that multiple micro-batches can execute different parts of the model code concurrently.

torch.distributed.pipelining provides a toolkit that allows for easy implementation of pipeline parallelism on general models while also offering composability with other common PyTorch distributed features like DDP, FSDP, or tensor parallel.

For more information on this please refer to our documentation and tutorial.

Performance Improvements

torch.compile optimizations for AWS Graviton (aarch64-linux) processors

AWS optimized the PyTorch torch.compile feature for AWS Graviton3 processors. This optimization results in up to 2x better performance for Hugging Face model inference (based on geomean of performance improvement for 33 models) and up to 1.35x better performance for TorchBench model inference (geomean of performance improvement for 45 models) compared to the default eager mode inference across several natural language processing (NLP), computer vision (CV), and recommendation models on AWS Graviton3-based Amazon EC2 instances.

For more information regarding specific technical details please refer to the blog post.

BF16 symbolic shape optimization in TorchInductor

PyTorch users can now experience improved quality and performance gains with the beta BF16 symbolic shape support. While static shapes may afford additional optimization opportunities compared to symbolic shapes, they are insufficient for scenarios such as inference services with varying batch size and sequence length, or detection models with data-dependent output shapes.

Verification using TorchBench, Hugging Face, and timm models shows a pass rate and speedup comparable to the BF16 static shape scenario. By combining the benefits of symbolic shapes with the BF16 AMX instruction hardware acceleration provided by Intel CPUs and the general Inductor CPU backend optimizations applicable to both static and symbolic shapes in PyTorch 2.4, BF16 symbolic shape performance has improved significantly compared to PyTorch 2.3.

The API to use this feature:

model = ...  # your FP32 model
model.eval()
with torch.autocast(device_type="cpu", dtype=torch.bfloat16), torch.no_grad():
    compiled_model = torch.compile(model, dynamic=True)
    # run inference with varying input shapes, e.g. compiled_model(inputs)

Performance optimizations for GenAI projects utilizing CPU devices

This work highlights the enhanced performance of PyTorch on CPU, as demonstrated through the optimizations made for the “Segment Anything Fast” and “Diffusion Fast” projects. Previously, only CUDA devices were supported in these projects. We have incorporated CPU support, enabling users to leverage the increased power of the CPU for running the projects’ experiments. We have also employed a block-wise attention mask for SDPA, which can significantly reduce peak memory usage and improve performance, and optimized a series of layout propagation rules in Inductor CPU to improve performance.

To facilitate this, we have updated the README file. To use this feature, simply pass --device cpu on the command line:

  • For Segment Anything Fast:

    export SEGMENT_ANYTHING_FAST_USE_FLASH_4=0
    python run_experiments.py 16 vit_b <pytorch_github> <segment-anything_github>
    <path_to_experiments_data> --run-experiments --num-workers 32 --device cpu
    
  • For Diffusion Fast:

    python run_benchmark.py --compile_unet --compile_vae --enable_fused_projections --device=cpu
    

Users can follow the guidelines to run the experiments and observe the performance improvements firsthand, as well as explore the performance improvement trends across FP32 and BF16 data types.

Additionally, users can achieve good performance using torch.compile and SDPA. By observing the performance trends across these different factors, users can gain a deeper understanding of how various optimizations enhance PyTorch’s performance on CPU.

Read More

Deep Dive on the Hopper TMA Unit for FP8 GEMMs


Abstract

The Hopper (H100) GPU architecture, billed as the “first truly asynchronous GPU”, includes a new, fully asynchronous hardware copy engine for bulk data movement between global and shared memory called Tensor Memory Accelerator (TMA). While CUTLASS has built-in support for TMA via its asynchronous pipeline paradigm, Triton exposes TMA support via an experimental API.

In this post, we provide a deeper dive into the details of how TMA works, for developers to understand the new async copy engine. We also show the importance of leveraging TMA for H100 kernels by building a TMA-enabled FP8 GEMM kernel in Triton, which delivers 1.4-2.2x performance gains over cuBLAS FP16 for small-to-medium problem sizes. Finally, we showcase key implementation differences between Triton and CUTLASS that may account for reports of performance regressions with TMA in Triton. We open source our implementation for reproducibility and review at https://github.com/pytorch-labs/applied-ai/tree/main/kernels


Figure 1. The throughput in TFLOPs of various Triton and cuBLAS FP8 and FP16 kernels, for M=M, N=4096, K=4096. The red line is the Triton TMA, which showcases the advantages of leveraging TMA.

TMA Background

TMA is an H100 hardware addition that allows applications to asynchronously and bi-directionally transfer 1D-5D tensors between GPU global and shared memory. In addition, TMA can transfer the same data not just to the calling SM’s shared memory, but also to the shared memory of other SMs in the same Thread Block Cluster. This is termed ‘multicast’.

TMA is very lightweight, as only a single thread is needed to kick off a TMA transfer. By moving data directly from GMEM (global) to SMEM (shared), it avoids the earlier GPU requirement of staging data through registers when moving it between different memory spaces.


Figure 2. A100-style data movement vs H100 with TMA. TMA hardware eliminates the need for a large amount of threads and registers participating in bulk data transfers. (Image credit Nvidia)

A single thread can issue large data movement instructions, allowing the majority of a given thread block to continue working on other instructions while data is in-flight. Combined with asynchronous pipelining, this allows memory transfers to be easily hidden and ensures that the majority of any given thread block cluster can focus on computational tasks.

This lightweight invocation for data movement enables the creation of warp-group specialized kernels, where warp-groups take on different roles, namely producers and consumers. Producers elect a leader thread that fires off TMA requests, which are then asynchronously coordinated with the consumer (MMA) warp-groups via an arrival barrier. Consumers then process the data using warp-group MMA, and signal back to the producers when they have finished reading from the SMEM buffer and the cycle repeats.

Further, within threadblock clusters, producers can lower their max register requirements since they are only issuing TMA calls, and effectively transfer additional registers to MMA consumers, which helps to alleviate register pressure for consumers.

In addition, TMA handles the address computation for the shared memory destination where the data requested should be placed. This is why calling threads (producers) can be so lightweight.

To maximize read access speed, TMA can lay out the arriving data according to swizzling instructions so that consumers can read it as quickly as possible, since the swizzling pattern helps avoid shared memory bank conflicts.

Finally, for outgoing TMA instructions (i.e., moving data from SMEM to GMEM), TMA can also perform reduction operations (add/min/max) and bitwise (and/or) operations.

TMA usage in Triton

Pre-Hopper Load:

offs_am = pid_m * block_m + tl.arange(0, block_m)
offs_bn = pid_n * block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)

a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

a = tl.load(a_ptrs)
b = tl.load(b_ptrs)

Figure 3. Traditional style bulk load from global to shared memory in Triton

In the above Triton example showing a pre-Hopper load, we see how the data for tensors a and b are loaded by each thread block computing global offsets (a_ptrs, b_ptrs) from their relevant program_id (pid_m, pid_n, k) and then making a request to move blocks of memory into shared memory for a and b.

Now let’s examine how to perform a load using TMA in Triton.

The TMA instruction requires a special data structure called a tensor map, in contrast to the above where we directly pass pointers to global memory. To build the tensor map, we first create a TMA descriptor on the CPU. The descriptor handles the creation of the tensor map by using the cuTensorMapEncode API. The tensor map holds metadata such as the global and shared memory layout of the tensor and serves as a compressed representation of the structure of the multi-dimensional tensor stored in global memory.


Figure 4. TMA address generation via a copy descriptor (Image credit: Nvidia)

The TMA descriptor holds the tensor’s key properties:

  1. Base Pointer
  2. Shape and Block Size
  3. Datatype

The TMA descriptor is created on the host before the kernel launch and is then moved to the device by copying the descriptor into a torch tensor. Thus, in Triton, the GEMM kernel receives a global pointer to the tensor map.

Triton Host Code

   desc_a = np.empty(TMA_SIZE, dtype=np.int8)
   desc_b = np.empty(TMA_SIZE, dtype=np.int8)
   desc_c = np.empty(TMA_SIZE, dtype=np.int8)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
  
   desc_a = torch.tensor(desc_a, device='cuda')
   desc_b = torch.tensor(desc_b, device='cuda')
   desc_c = torch.tensor(desc_c, device='cuda')

This is the code that is used to set up the descriptors in the kernel invoke function.

Triton Device Code

Offsets/Pointer Arithmetic:

   offs_am = pid_m * block_m
   offs_bn = pid_n * block_n
   offs_k = 0

Load:

  a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
  b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)

Store:

 tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

We no longer need to calculate a pointer array for both load and store functions in the kernel. Instead, we pass a single descriptor pointer, the offsets, block size and the input datatype. This simplifies address calculation and reduces register pressure, as we no longer have to do complex pointer arithmetic in software and dedicate CUDA cores for address computation.

TMA Performance Analysis

Below, we discuss the PTX instructions for different load mechanisms on Hopper.

PTX for Loading Tile (cp.async) – H100 no TMA

add.s32 	%r27, %r100, %r8;
add.s32 	%r29, %r100, %r9;
selp.b32 	%r30, %r102, 0, %p18;


@%p1 cp.async.cg.shared.global [ %r27 + 0 ], [ %rd20 + 0 ], 0x10, %r30;
@%p1 cp.async.cg.shared.global [ %r29 + 0 ], [ %rd21 + 0 ], 0x10, %r30;


cp.async.commit_group ;

Here, we observe the older cp.async instruction responsible for global memory copies. From the traces below we can see that both loads bypass the L1 cache. A major difference with the newer TMA load is that previously, before tiles from A and B were ready to be consumed by the Tensor Core, an ldmatrix instruction operating on data held in register files had to be executed. On Hopper, the data can now be reused directly from shared memory.


Figure 5. H100 Memory Chart showing GMEM Throughput = 910.22 GB/s (Triton GEMM without TMA) for M=128, N=4096, K=4096

By leveraging TMA through the Triton API changes we mentioned above, we can investigate the PTX that Triton generates for a single 2D tile load with TMA.

PTX for Loading Tile (cp.async.bulk.tensor) – H100 using TMA

bar.sync 	0;
shr.u32 	%r5, %r4, 5;
shfl.sync.idx.b32	%r66, %r5, 0, 31, -1;

elect.sync _|%p7, 0xffffffff;


add.s32 	%r24, %r65, %r67;
shl.b32 	%r25, %r66, 7;

@%p8
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r24], [%rd26, {%r25,%r152}], [%r19];

The cp.async.bulk.tensor.2d.shared TMA instruction is passed the destination address in shared memory, a pointer to the tensor map, the tensor map coordinates, and a pointer to the mbarrier object, in that order.


Figure 6. H100 Memory Chart GMEM Throughput =1.45 TB/s (Triton GEMM with TMA) for M=128, N=4096, K=4096

For optimal performance we tuned the TMA GEMM kernel extensively. Amongst other parameters such as tile sizes, number of warps and number of pipeline stages, the biggest increase in memory throughput was observed when we increased the TMA_SIZE (descriptor size) from 128 to 512. From the above NCU profiles, we can see that the final tuned kernel has increased global memory transfer throughput from 910 GB/s to 1.45 TB/s, a 59% increase in GMEM throughput, over the non-TMA Triton GEMM kernel.
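
In the host-code sketch shown earlier, this tuning simply corresponds to allocating larger descriptor buffers (TMA_SIZE is a constant in the benchmarking code above, not a Triton API):

import numpy as np

TMA_SIZE = 512  # descriptor buffer size in bytes; increased from 128 while tuning
desc_a = np.empty(TMA_SIZE, dtype=np.int8)
desc_b = np.empty(TMA_SIZE, dtype=np.int8)
desc_c = np.empty(TMA_SIZE, dtype=np.int8)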

Comparison of CUTLASS and Triton FP8 GEMM and TMA Implementation – Kernel Architecture


Figure 7. Triton vs CUTLASS Ping-Pong FP8 GEMM TFLOPs, M=M, N=4096, K=4096

The above chart shows the performance of a CUTLASS Ping-Pong GEMM kernel against Triton. The Ping-Pong kernel leverages TMA differently than Triton: it makes use of the full set of TMA hardware and software capabilities, while Triton currently does not. Specifically, CUTLASS supports the following TMA features, which help explain the performance gaps in pure GEMM performance:

  1. TMA Multicast

    • Enables copy of data from GMEM to multiple SMs
  2. Warp Specialization

    • Enables warp groups within a threadblock to take on different roles
  3. Tensor Map (TMA Descriptor) Prefetch

    • Enables prefetching the Tensor Map object from GMEM, which allows pipelining of TMA loads

To put the performance numbers in perspective, below we show a ‘speed-up’ chart highlighting the latency differences on a percentage basis:


Figure 8: % Speedup of CUTLASS Ping-Pong vs Triton FP8 with TMA.

This speedup is purely kernel throughput, not including E2E launch overhead which we will discuss below.

TMA Descriptor movement – a key difference between Triton and CUTLASS with E2E performance implications

As noted previously, creation of a 2D+ dimensional TMA descriptor takes place on the host and is then transferred to the device. However, this transfer process takes place very differently depending on the implementation.

Here we showcase the differences between how Triton transfers TMA descriptors compared with CUTLASS.

Recall that TMA transfers require a special data structure, a tensor map, to be created on the CPU through the cuTensorMap API; for an FP8 GEMM kernel this means creating three descriptors, one each for A, B, and C. We see below that for both the Triton and CUTLASS kernels the same CPU procedures are invoked.


Figure 9. Calls to cuTensorMapEncodeTiled (Both Triton and CUTLASS use this path)

However, for Triton, each descriptor is transferred in its own distinct copy kernel, which adds a significant amount of overhead and serves as a barrier to using this kernel in an end-to-end inference scenario.


Figure 10. Three H2D Copy Kernels are launched before the kernel execution, for A, B and C

These copies are not observed in the CUTLASS implementation, due to the way TMA descriptors are passed to the kernel. We can see from the PTX below that with CUTLASS, tensor maps are passed by value to the kernel.

.entry _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE(

.param .align 64 .b8 _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0[1024]


mov.b64 	%rd110, _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_10bfloat16_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEES8_NS7_ILi256EEEEEENS6_IJNS7_ILi1EEESB_SB_EEENS_4gemm24KernelTmaWarpSpecializedENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0;

add.s64 	%rd70, %rd110, 704;
cvta.param.u64 	%rd69, %rd70;

cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd69, {%r284, %r283}], [%r1880];

Figure 11. CUTLASS kernel PTX showing pass-by-value

By directly passing the TMA Descriptor as opposed to passing a global memory pointer, the CUTLASS kernel avoids the three extra H2D copy kernels and instead these copies are included in the single device kernel launch for the GEMM.

Because of the difference in how descriptors are moved to the device, the kernel latencies, including the time to prepare the tensors to be consumed by the TMA, are drastically different. For M=1-128, N=4096, K=4096, the CUTLASS pingpong kernel has an average latency of 10us, while the Triton TMA kernels complete in an average of 4ms. This is roughly 400x slower and appears to be directly linked to the 3 independent kernel launches for TMA descriptor transfer by Triton.

CUDA graphs may be one way to reduce this, but given the overhead created by the H2D copies, the current Triton implementation, when measured end to end, is not competitive. A rework of how the Triton compiler manages TMA descriptors would likely resolve this gap. We thus focused on comparing the actual compute kernel throughput and not E2E in our data above.

Results Summary


Figure 12. Triton FP8 TMA GEMM TFLOPs Comparison

M     Triton TMA   Triton Tutorial   Triton SplitK   cuBLAS FP8   cuBLAS FP16   CUTLASS Ping-Pong FP8
1     2.5          1                 2.4             1.5          1.8           3.57
2     5.1          2.5               4.8             3.1          3.6           5.9
4     10.3         7.21              9.6             6.1          7.2           14.3
8     21.0         16.5              19.2            12.3         14.4          28.6
16    44.5         41.0              37.2            24.5         27.7          55.1
32    89.7         81.2              72.2            71.6         56.8          114.4
64    178.5        163.7             130.8           144.6        105.3         228.7
128   359.7        225.9             160.1           244.0        189.2         377.7

Figure 13. Triton FP8 TMA GEMM TFLOPs Comparison Table

The above chart and table summarize the gains we’ve been able to achieve on a single NVIDIA H100 for FP8 GEMM by leveraging the TMA hardware unit, over non-TMA Triton kernels and high performance CUDA (cuBLAS) kernels. The key point to note is this kernel’s superior scaling with batch size over the competition. The problem sizes we benchmarked on are representative of the matrix shapes found in small-to-medium batch size LLM inference. Thus, TMA GEMM kernel performance in the mid-M regime (M=32 to M=128) will be critical for those interested in leveraging this kernel for FP8 LLM deployment use cases, as the FP8 compressed data type can allow larger matrices to fit in GPU memory.

To summarize our analysis, the TMA implementation in Triton and CUTLASS differ in terms of full featureset support (multicast, prefetch etc.) and how the TMA Descriptor is passed to the GPU kernel. If this descriptor is passed in a manner that more closely matches the CUTLASS kernel (pass-by-value), the extraneous H2D copies could be avoided and thus the E2E performance would be greatly improved.

Future Work

For future research, we plan to improve upon these results by working with the community to incorporate the CUTLASS architecture of TMA loads into Triton, as well as investigating the Cooperative Kernel for FP8 GEMM, a modified strategy to the Ping-Pong Kernel.

In addition, once features like thread block clusters and TMA atomic operations are enabled in Triton, we may be able to get further speedups by leveraging the SplitK strategy in the TMA GEMM Kernel, as atomic operations on Hopper can be performed in Distributed Shared Memory (DSMEM) as opposed to L2 Cache. We also note the similarities of NVIDIA Hopper GPUs to other AI hardware accelerators, like Google’s TPU and IBM’s AIU, which are dataflow architectures. On Hopper, data can now “flow” from GMEM to a network of connected SMs due to the additions of TMA, which we discussed extensively in this blog, and DSMEM, which we plan to cover in a future post.

Read More

FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision


Attention, as a core layer of the ubiquitous Transformer architecture, is a bottleneck for large language models and long-context applications. FlashAttention (and FlashAttention-2) pioneered an approach to speed up attention on GPUs by minimizing memory reads/writes, and is now used by most libraries to accelerate Transformer training and inference. This has contributed to a massive increase in LLM context length in the last two years, from 2-4K (GPT-3, OPT) to 128K (GPT-4), or even 1M (Llama 3). However, despite its success, FlashAttention has yet to take advantage of new capabilities in modern hardware, with FlashAttention-2 achieving only 35% utilization of theoretical max FLOPs on the H100 GPU. In this blogpost, we describe three main techniques to speed up attention on Hopper GPUs: exploiting asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp-specialization and (2) interleave block-wise matmul and softmax operations, and (3) incoherent processing that leverages hardware support for FP8 low-precision.

We’re excited to release FlashAttention-3 that incorporates these techniques. It’s 1.5-2.0x faster than FlashAttention-2 with FP16, up to 740 TFLOPS, i.e., 75% utilization of H100 theoretical max FLOPS. With FP8, FlashAttention-3 reaches close to 1.2 PFLOPS, with 2.6x smaller error than baseline FP8 attention.

FlashAttention-3 is available at: https://github.com/Dao-AILab/flash-attention
Paper

FlashAttention Recap

FlashAttention is an algorithm that reorders the attention computation and leverages tiling and recomputation to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. We use tiling to load blocks of inputs from HBM (GPU memory) to SRAM (fast cache), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, which brings 2-4x wallclock time speedup.

Here we show a diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.

[Figure: FlashAttention forward pass, computed block-by-block with tiling and softmax rescaling]
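
In equations (our notation, summarizing the standard FlashAttention recurrence rather than anything specific to FlashAttention-3), processing key/value block j updates a running row-max m, a running softmax normalizer ℓ, and an un-normalized output O:

$$S_j = Q K_j^\top, \qquad m^{(j)} = \max\big(m^{(j-1)},\ \mathrm{rowmax}(S_j)\big),$$
$$\ell^{(j)} = e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \mathrm{rowsum}\big(e^{\,S_j - m^{(j)}}\big),$$
$$O^{(j)} = e^{\,m^{(j-1)} - m^{(j)}}\,O^{(j-1)} + e^{\,S_j - m^{(j)}}\,V_j,$$

and after the final block the output is rescaled as $O = \mathrm{diag}(\ell)^{-1} O$, which equals $\mathrm{softmax}(QK^\top)V$ exactly.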

New hardware features on Hopper GPUs – WGMMA, TMA, FP8

While FlashAttention-2 can achieve up to 70% theoretical max FLOPS on Ampere (A100) GPUs, it does not yet take advantage of new features on Hopper GPUs to maximize performance. We describe some of the new Hopper-specific features here, and why they are important.

1. WGMMA (Warpgroup Matrix Multiply-Accumulate). This new feature makes use of the new Tensor Cores on Hopper, with much higher throughput[1] than the older mma.sync instruction in Ampere (image from the H100 white paper).


2. TMA (Tensor Memory Accelerator). This is a special hardware unit that accelerates the transfer of data between global memory and shared memory, taking care of all index calculation and out-of-bound predication. This frees up registers, which is a valuable resource to increase tile size and efficiency.


3. Low-precision with FP8. This doubles the Tensor Core throughput (e.g. 989 TFLOPS with FP16 and 1978 TFLOPS with FP8), but trades off accuracy by using fewer bits to represent floating point numbers.


FlashAttention-3 makes use of all of these new features of Hopper, using powerful abstractions from NVIDIA’s CUTLASS library.

By rewriting FlashAttention to use these new features, we can already significantly speed it up (e.g., from 350 TFLOPS in FlashAttention-2 FP16 forward pass to around 540-570 TFLOPS). However, the asynchronous nature of the new instructions on Hopper (WGMMA and TMA) opens up additional algorithmic opportunities to overlap operations and thereby extract even greater performance. For this blogpost, we’ll explain two such techniques specific to attention. The generic technique of warp specialization, with separate producer and consumer warps doing TMA and WGMMA, is well-covered elsewhere in the context of GEMM and works the same here.

Asynchrony: Overlapping GEMM and Softmax

Why overlap?

Attention has GEMMs (those matmuls between Q and K and between attention probability P and V) and softmax as its two main operations. Why do we need to overlap them? Isn’t most of the FLOPS in the GEMMs anyway? As long as the GEMMs are fast (e.g., computed using WGMMA instructions), shouldn’t the GPU be going brrrr?

The problem is that non-matmul operations are much slower than matmul operations on modern accelerators. Special functions such as exponential (for the softmax) have even lower throughput than floating point multiply-add; they are evaluated by the multi-function unit, a unit separate from floating point multiply-add or matrix multiply-add. As an example, the H100 GPU SXM5 has 989 TFLOPS of FP16 matrix multiply, but only 3.9 TFLOPS (256x less throughput) for special functions[2]! For head dimension 128, there are 512x more matmul FLOPS than exponential, which means that exponential can take 50% of the time compared to matmul. The situation is even worse for FP8, where the matmul FLOPS are twice as fast yet exponential FLOPS stay the same speed. Ideally we want matmul and softmax to operate in parallel. While the Tensor Cores are busy with matmul, the multi-function units should be calculating exponential!
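
Spelling out the arithmetic behind those numbers (our notation): for head dimension d = 128, each attention score costs roughly 4d = 512 matmul FLOPs across the two GEMMs versus a single exponential, so

$$\frac{t_{\mathrm{exp}}}{t_{\mathrm{matmul}}} \approx \frac{1\ \text{exp op} \,/\, 3.9\ \text{TFLOPS}}{512\ \text{FLOPs} \,/\, 989\ \text{TFLOPS}} = \frac{989}{512 \times 3.9} \approx 0.5 .$$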

Inter-warpgroup overlapping with pingpong scheduling

The first and easiest way to overlap GEMM and softmax is to do nothing at all! The warp schedulers already try to schedule warps so that if some warps are blocked (e.g., waiting for GEMM results), other warps can run. That is, the warp schedulers do some of this overlapping for us, for free.

However, we can improve on this by doing some of the scheduling manually. As an example, if we have 2 warpgroups (labeled 1 and 2 – each warpgroup is a group of 4 warps), we can use synchronization barriers (bar.sync) so that warpgroup 1 first does its GEMMs (e.g., GEMM1 of one iteration and GEMM0 of the next iteration), and then warpgroup 2 does its GEMMs while warpgroup 1 does its softmax, and so on. This “pingpong” schedule is illustrated in the figure below, where the same color denotes the same iteration.

[Figure: “pingpong” scheduling of GEMMs and softmax across two warpgroups; the same color denotes the same iteration]

This would allow us to perform the softmax in the shadow of the GEMMs of the other warpgroup. Of course, this figure is just a caricature; in practice the scheduling is not really this clean. Nevertheless, pingpong scheduling can improve FP16 attention forward pass from around 570 TFLOPS to 620 TFLOPS (head dim 128, seqlen 8K).

Intra-warpgroup overlapping of GEMM and Softmax

Even within one warpgroup, we can have some part of softmax running while the GEMMs of that warpgroup are running. This is illustrated in this figure, where the same color denotes the same iteration.

[Figure: intra-warpgroup overlapping of GEMM and softmax; the same color denotes the same iteration]

This pipelining increases throughput from around 620 TFLOPS to around 640-660 TFLOPS for FP16 attention forward, at the cost of higher register pressure. We need more registers to hold both accumulators of the GEMMs, and the input/output of softmax. Overall, we find this technique to offer a favorable tradeoff.

Low-precision: reduce quantization error with incoherent processing

LLM activations can have outliers with much larger magnitude than the rest of the features. These outliers make quantization difficult, producing much larger quantization errors. We leverage incoherent processing, a technique used in the quantization literature (e.g. from QuIP) that multiplies the query and key with a random orthogonal matrix to “spread out” the outliers and reduce quantization error. In particular, we use the Hadamard transform (with random signs), which can be done per attention head in O(d log d) instead of O(d^2) time, where d is the head dimension. Since the Hadamard transform is memory-bandwidth bound, it can be fused with previous operations such as rotary embedding (also memory-bandwidth bound) “for free”.
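
To make the idea concrete, here is a toy PyTorch sketch (ours, not the fused kernel used in FlashAttention-3): a random-sign fast Walsh-Hadamard transform is applied to both Q and K; since the transform is orthogonal, QK^T, and hence the attention output, is unchanged up to floating point error, while outlier mass in any single channel is spread across the whole head dimension before quantization.

import torch

def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    # Fast Walsh-Hadamard transform over the last dim (size must be a power of 2),
    # scaled by 1/sqrt(d) so the overall transform is orthogonal. O(d log d) per row.
    d = x.shape[-1]
    y = x.clone()
    h = 1
    while h < d:
        y = y.reshape(*x.shape[:-1], d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2)
        h *= 2
    return y.reshape_as(x) / d ** 0.5

d = 128
signs = (torch.randint(0, 2, (d,)) * 2 - 1).float()  # random +/-1 per channel
q, k = torch.randn(4, d), torch.randn(4, d)

# Incoherent processing: rotate Q and K by the same random-sign Hadamard matrix.
q_rot, k_rot = hadamard_transform(q * signs), hadamard_transform(k * signs)

# Attention scores are preserved exactly (up to float error), while any outlier
# concentrated in a single channel is now spread across all d channels.
torch.testing.assert_close(q @ k.T, q_rot @ k_rot.T, rtol=1e-4, atol=1e-4)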

In our experiment where Q, K, V are generated from a standard normal distribution but 0.1% of the entries have large magnitudes (to simulate outliers), we found that incoherent processing can reduce the quantization error by 2.6x. We show numerical error comparison in the table below. Please see the paper for details.

[Table: numerical error comparison; see the paper for details]

Attention benchmark

We show some results with FlashAttention-3, and compare it to FlashAttention-2, as well as the implementation in Triton and cuDNN (both of which already use new hardware features of Hopper GPUs).

For FP16, we see about 1.6x-1.8x speedup over FlashAttention-2.

[Figures: FP16 attention forward speed, FlashAttention-3 vs FlashAttention-2, Triton, and cuDNN]

For FP8, we can reach close to 1.2 PFLOPS!

[Figure: FP8 attention forward speed]

Discussion

This blogpost highlights some of the optimizations for FlashAttention available on Hopper GPUs. Other optimizations (e.g., variable length sequences, persistent kernel, and in-kernel transpose for FP8) are covered in the paper.

We have seen that designing algorithms that take advantage of the hardware they run on can bring significant efficiency gains and unlock new model capabilities such as long context. We look forward to future work on optimization for LLM inference, as well as generalizing our techniques to other hardware architectures.

We also look forward to FlashAttention-3 being integrated in a future release of PyTorch.

Notes

  1. Without the wgmma instruction, the older mma.sync instruction can only reach about ⅔ the peak throughput of Hopper Tensor Cores: https://arxiv.org/abs/2402.13499v1 

  2. The CUDA programming guide specifies that the throughput for special functions is 16 operations per streaming multiprocessor (SM) per clock cycle. We multiply 16 by 132 SMs and 1830 MHz (the clock speed used to calculate 989 TFLOPS of FP16 matmul) to get 3.9 TFLOPS.

Read More

Learn how to develop Android applications with ExecuTorch and Llama models

This blog is courtesy of the PyTorch team at Arm. More details can be found here.

Arm’s compute platform is delivering GenAI applications on phones, laptops, and servers. Cost, privacy, performance, security, and energy efficiency are just some of the reasons developers are investigating on-device AI.

A new Learning Path explaining how to leverage the capabilities of large language models (LLMs) on Android using ExecuTorch and XNNPACK is now available.

Here’s a summary of what you’ll learn:

  • Development Environment setup

    The Learning Path begins by guiding you through setting up your development environment, ensuring you have all the necessary tools installed, including Android Studio, the Android NDK, Java JDK, and Python.

  • ExecuTorch and XNNPACK

    You’ll learn about the core technologies: ExecuTorch, a framework for deploying PyTorch models to edge devices, and XNNPACK, a high-performance library for executing neural networks on Arm-based platforms.

  • Llama models

    The Learning Path explores Llama, a family of powerful LLMs, focusing specifically on the 8B Llama 3 model. You’ll learn about quantization techniques, which are essential for optimizing model size and performance on mobile devices.

  • Prepare Llama models for ExecuTorch

    You’ll be guided through the process of downloading, exporting, and evaluating Llama models, ensuring they are ready for deployment using ExecuTorch.

  • Check model performance on Android

    The Learning Path walks you through cross-compiling the Llama runner binary for Android, allowing you to test your model’s performance on your phone.

  • Build and run an Android Chat App

    Finally, you’ll learn how to build a native Android chat app using the LlamaDemo application from the ExecuTorch repository. This hands-on experience allows you to put your knowledge into practice and create a real-world application.

Explore this Learning Path if you want to learn how to leverage the power of LLMs on your Android phone, and gain expertise in tools for on-device machine learning.

Dig into the excitement of building Android chat apps and understand more about how they work on the Arm Developer Hub.

Read More

Accelerated PyTorch inference with torch.compile on AWS Graviton processors


Summary

Originally, PyTorch used an eager mode where each PyTorch operation that forms the model is run independently as soon as it’s reached. PyTorch 2.0 introduced torch.compile to speed up PyTorch code over the default eager mode. In contrast to eager mode, torch.compile pre-compiles the entire model into a single graph in a manner that’s optimal for running on a given hardware platform. AWS optimized the PyTorch torch.compile feature for AWS Graviton3 processors. This optimization results in up to 2x better performance for Hugging Face model inference (based on the geomean of performance improvement for 33 models) and up to 1.35x better performance for TorchBench model inference (geomean of performance improvement for 45 models) compared to the default eager mode inference across several natural language processing (NLP), computer vision (CV), and recommendation models on AWS Graviton3-based Amazon EC2 instances. Starting with PyTorch 2.3.1, the optimizations are available in the torch Python wheels and the AWS Graviton PyTorch deep learning container (DLC).

In this blog post, we show how we optimized torch.compile performance on AWS Graviton3-based EC2 instances, how to use the optimizations to improve inference performance, and the resulting speedups.

Why torch.compile and what’s the goal?

In eager mode, operators in a model are run immediately as they are encountered. It’s easier to use, more suitable for machine learning (ML) researchers, and hence is the default mode. However, eager mode incurs runtime overhead because of redundant kernel launches and memory reads. In torch.compile mode, by contrast, operators are first synthesized into a graph, wherein one operator is merged with another to reduce and localize memory reads and total kernel launch overhead.

The goal for the AWS Graviton team was to optimize torch.compile backend for Graviton3 processors. PyTorch eager mode was already optimized for Graviton3 processors with Arm Compute Library (ACL) kernels using oneDNN (also known as MKLDNN). So, the question was, how to reuse those kernels in torch.compile mode to get the best of graph compilation and the optimized kernel performance together?

Results

The AWS Graviton team extended the torch inductor and oneDNN primitives to reuse the ACL kernels and optimized compile mode performance on Graviton3 processors. Starting with PyTorch 2.3.1, the optimizations are available in the torch Python wheels and the AWS Graviton DLC. Please see the Running an inference section that follows for instructions on installation, runtime configuration, and how to run the tests.

To demonstrate the performance improvements, we used NLP, CV, and recommendation models from TorchBench and the most downloaded NLP models from Hugging Face across Question Answering, Text Classification, Token Classification, Translation, Zero-Shot Classification, Summarization, Feature Extraction, Text Generation, Text2Text Generation, Fill-Mask, and Sentence Similarity tasks to cover a wide variety of customer use cases.

We started by measuring TorchBench model inference latency, in milliseconds (msec), for eager mode, which is marked 1.0 with a red dotted line in the following graph. Then we compared the improvements from torch.compile for the same model inference; the normalized results are plotted in the graph. You can see that for the 45 models we benchmarked, there is a 1.35x latency improvement (geomean for the 45 models).


Image 1: PyTorch model inference performance improvement with torch.compile on AWS Graviton3-based c7g instance using TorchBench framework. The reference eager mode performance is marked as 1.0. (higher is better)

Similar to the preceding TorchBench inference performance graph, we started by measuring the Hugging Face NLP model inference latency, in msec, for eager mode, which is marked 1.0 with a red dotted line in the following graph. Then we compared the improvements from torch.compile for the same model inference; the normalized results are plotted in the graph. You can see that for the 33 models we benchmarked, there is around a 2x performance improvement (geomean for the 33 models).


Image 2: Hugging Face NLP model inference performance improvement with torch.compile on AWS Graviton3-based c7g instance using Hugging Face example scripts. The reference eager mode performance is marked as 1.0. (higher is better)

Running an inference

Starting with PyTorch 2.3.1, the optimizations are available in the torch Python wheel and in AWS Graviton PyTorch DLC. This section shows how to run inference in eager and torch.compile modes using torch Python wheels and benchmarking scripts from Hugging Face and TorchBench repos.

To successfully run the scripts and reproduce the speedup numbers mentioned in this post, you need an instance from the Graviton3 family (c7g/r7g/m7g/hpc7g) of hardware. For this post, we used the c7g.4xl (16 vcpu) instance. The instance, the AMI details, and the required torch library versions are mentioned in the following snippet.

Instance: c7g.4xl instance
Region: us-west-2
AMI: ami-05cc25bfa725a144a (Ubuntu 22.04/Jammy with 6.5.0-1017-aws kernel)

# Install Python
sudo apt-get update
sudo apt-get install -y python3 python3-pip

# Upgrade pip3 to the latest version
python3 -m pip install --upgrade pip

# Install PyTorch and extensions
python3 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1

The generic runtime tunings implemented for eager mode inference are equally applicable for torch.compile mode, so we set the following environment variables to further improve the torch.compile performance on AWS Graviton3 processors.

# Enable the fast math GEMM kernels, to accelerate fp32 inference with bfloat16 gemm
export DNNL_DEFAULT_FPMATH_MODE=BF16

# Enable Linux Transparent Huge Page (THP) allocations,
# to reduce the tensor memory allocation latency
export THP_MEM_ALLOC_ENABLE=1

# Set LRU Cache capacity to cache the primitives and avoid redundant
# memory allocations
export LRU_CACHE_CAPACITY=1024

TORCHBENCH BENCHMARKING SCRIPTS

TorchBench is a collection of open source benchmarks used to evaluate PyTorch performance. We benchmarked 45 models using the scripts from the TorchBench repo. The following code shows how to run the scripts for eager mode and for compile mode with the inductor backend.

# Set OMP_NUM_THREADS to number of vcpus, 16 for c7g.4xl instance
export OMP_NUM_THREADS=16

# Install the dependencies
sudo apt-get install -y libgl1-mesa-glx
sudo apt-get install -y libpangocairo-1.0-0
python3 -m pip install psutil numpy transformers pynvml numba onnx onnxruntime scikit-learn timm effdet gym doctr opencv-python h5py==3.10.0 python-doctr 

# Clone pytorch benchmark repo
git clone https://github.com/pytorch/benchmark.git
cd benchmark
# PyTorch benchmark repo doesn't have any release tags. So,
# listing the commit we used for collecting the performance numbers
git checkout 9a5e4137299741e1b6fb7aa7f5a6a853e5dd2295

# Setup the models
python3 install.py 

# Collect eager mode performance using the following command. The results will be
# stored at .userbenchmark/cpu/metric-<timestamp>.json.
python3 run_benchmark.py cpu --model BERT_pytorch,hf_Bert,hf_Bert_large,hf_GPT2,hf_Albert,hf_Bart,hf_BigBird,hf_DistilBert,hf_GPT2_large,dlrm,hf_T5,mnasnet1_0,mobilenet_v2,mobilenet_v3_large,squeezenet1_1,timm_efficientnet,shufflenet_v2_x1_0,timm_regnet,resnet50,soft_actor_critic,phlippe_densenet,resnet152,resnet18,resnext50_32x4d,densenet121,phlippe_resnet,doctr_det_predictor,timm_vovnet,alexnet,doctr_reco_predictor,vgg16,dcgan,yolov3,pytorch_stargan,hf_Longformer,timm_nfnet,timm_vision_transformer,timm_vision_transformer_large,nvidia_deeprecommender,demucs,tts_angular,hf_Reformer,pytorch_CycleGAN_and_pix2pix,functorch_dp_cifar10,pytorch_unet --test eval --metrics="latencies,cpu_peak_mem"

# Collect torch.compile mode performance with inductor backend
# and weights pre-packing enabled. The results will be stored at
# .userbenchmark/cpu/metric-<timestamp>.json
python3 run_benchmark.py cpu --model BERT_pytorch,hf_Bert,hf_Bert_large,hf_GPT2,hf_Albert,hf_Bart,hf_BigBird,hf_DistilBert,hf_GPT2_large,dlrm,hf_T5,mnasnet1_0,mobilenet_v2,mobilenet_v3_large,squeezenet1_1,timm_efficientnet,shufflenet_v2_x1_0,timm_regnet,resnet50,soft_actor_critic,phlippe_densenet,resnet152,resnet18,resnext50_32x4d,densenet121,phlippe_resnet,doctr_det_predictor,timm_vovnet,alexnet,doctr_reco_predictor,vgg16,dcgan,yolov3,pytorch_stargan,hf_Longformer,timm_nfnet,timm_vision_transformer,timm_vision_transformer_large,nvidia_deeprecommender,demucs,tts_angular,hf_Reformer,pytorch_CycleGAN_and_pix2pix,functorch_dp_cifar10,pytorch_unet --test eval --torchdynamo inductor --freeze_prepack_weights --metrics="latencies,cpu_peak_mem"

On successful completion of the inference runs, the script stores the results in JSON format. The following is the sample output:

{
  "name": "cpu",
  "environ": {
      "pytorch_git_version": "d44533f9d073df13895333e70b66f81c513c1889"
  },
  "metrics": {
       "BERT_pytorch-eval_latency": 56.3769865,
       "BERT_pytorch-eval_cmem": 0.4169921875
  }
}
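
For convenience, a small helper like the one below (the file names are placeholders for the metric-<timestamp>.json outputs mentioned above) can compute the eager-to-compile speedup from the two result files:

import json

def speedup(eager_json_path, compile_json_path, key="BERT_pytorch-eval_latency"):
    # Ratio > 1 means torch.compile is faster than eager for the given metric.
    with open(eager_json_path) as f_eager, open(compile_json_path) as f_compile:
        eager = json.load(f_eager)
        compiled = json.load(f_compile)
    return eager["metrics"][key] / compiled["metrics"][key]

print(speedup(".userbenchmark/cpu/metric-eager.json",
              ".userbenchmark/cpu/metric-compile.json"))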

HUGGING FACE BENCHMARKING SCRIPTS

The Google T5 Small Text Translation model is one of the around 30 Hugging Face models we benchmarked. We’re using it as a sample model to demonstrate how to run inference in eager and compile modes. The additional configurations and APIs required to run it in compile mode are the torch._inductor.config settings and the torch.compile call in the script below. Save the following script as google_t5_small_text_translation.py.

import argparse
from transformers import T5Tokenizer, T5Model
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import torch._inductor.config as config
config.cpp.weight_prepack=True
config.freezing=True

def test_inference(mode, num_iter):
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5Model.from_pretrained("t5-small")

    input_ids = tokenizer(
        "Studies have been shown that owning a dog is good for you", return_tensors="pt"
    ).input_ids  # Batch size 1
    decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

    if (mode == 'compile'):
        model = torch.compile(model)

    with torch.no_grad():
        for _ in range(50):
            outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        with profile(activities=[ProfilerActivity.CPU]) as prof:
            with record_function("model_inference"):
                for _ in range(num_iter):
                    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

    print(prof.key_averages().table(sort_by="self_cpu_time_total"))

def main() -> None:
    global m, args
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        "-m",
        "--mode",
        choices=["eager", "compile"],
        default="eager",
        help="Which test to run.",
    )
    parser.add_argument(
        "-n",
        "--number",
        type=int,
        default=100,
        help="how many iterations to run.",
    )
    args = parser.parse_args()
    test_inference(args.mode, args.number)

if __name__ == "__main__":
    main()

Run the script with the following steps:

# Set OMP_NUM_THREADS to 4 because the script runs
# inference sequentially and doesn't need a large number of vcpus
export OMP_NUM_THREADS=4

# Install the dependencies
python3 -m pip install transformers

# Run the inference script in Eager mode
# using number of iterations as 1 just to show the torch profiler output
# but for the benchmarking, we used 1000 iterations.
python3 google_t5_small_text_translation.py -n 1 -m eager

# Run the inference script in torch compile mode
python3 google_t5_small_text_translation.py -n 1 -m compile

On successful completion of the inference runs, the script prints the torch profiler output with the latency breakdown for the torch operators. The following is the sample output from torch profiler:

# Torch profiler output for the eager mode run on c7g.xl (4vcpu)
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::mm        40.71%      12.502ms        40.71%      12.502ms     130.229us            96  
         model_inference        26.44%       8.118ms       100.00%      30.708ms      30.708ms             1  
               aten::bmm         6.85%       2.102ms         9.47%       2.908ms      80.778us            36  
            aten::matmul         3.73%       1.146ms        57.26%      17.583ms     133.205us           132  
            aten::select         1.88%     576.000us         1.90%     583.000us       0.998us           584  
         aten::transpose         1.51%     464.000us         1.83%     563.000us       3.027us           186  
------------------------ ------------ ------------ ------------ ------------ ------------ -------------------
Self CPU time total: 30.708ms

# Torch profiler output for the compile mode run for the same model on the same instance
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
        mkldnn::_linear_pointwise        37.98%       5.461ms        45.91%       6.602ms      68.771us            96  
            Torch-Compiled Region        29.56%       4.251ms        98.53%      14.168ms      14.168ms             1  
                        aten::bmm        14.90%       2.143ms        21.73%       3.124ms      86.778us            36  
                     aten::select         4.51%     648.000us         4.62%     665.000us       1.155us           576  
                       aten::view         3.29%     473.000us         3.29%     473.000us       1.642us           288  
                      aten::empty         2.53%     364.000us         2.53%     364.000us       3.165us           115  
--------------------------------- ------------ ------------ ------------ ------------ ------------ --------------------
Self CPU time total: 14.379ms

Technical deep dive: What are the challenges and optimization details

Underpinning torch.compile are new technologies – TorchDynamo, AOTDispatcher, and TorchInductor.

  • TorchDynamo captures PyTorch programs safely using Python Frame Evaluation Hooks.
  • AOTDispatcher overloads PyTorch’s autograd engine as a tracing autodiff for generating ahead-of-time backward traces.
  • TorchInductor is a deep learning compiler that generates fast code for multiple accelerators and backends.

The PyTorch compilation process source

Image 3: The PyTorch compilation process

When torch.compile is invoked, torch dynamo rewrites Python bytecode to extract sequences of PyTorch operations into an FX Graph, which is then compiled with the inductor backend. For a typical inference scenario where the graph is frozen and gradient calculations are disabled, the inductor invokes platform-specific optimizations like graph rewrites into more performant operators, operator fusion, and weights pre-packing.
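
For illustration, one quick way to see the FX graph(s) that TorchDynamo captures before handing them to a backend is the torch._dynamo.explain debugging helper (a private utility whose output format varies across releases):

import torch

def fn(x, w):
    # A linear-then-GELU pattern similar to the transformer blocks discussed below.
    return torch.nn.functional.gelu(x @ w)

explanation = torch._dynamo.explain(fn)(torch.randn(4, 8), torch.randn(8, 8))
print(explanation)  # summarizes captured FX graphs, graph breaks, etc.

# The same function compiled with the inductor backend, as in this post:
compiled_fn = torch.compile(fn, backend="inductor")
print(compiled_fn(torch.randn(4, 8), torch.randn(8, 8)).shape)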

However, on Graviton3, the inductor wasn’t able to perform any of those optimizations because there was no aarch64 backend defined. To fix this, we extended the inductor’s FX passes to pick oneDNN operators for linear layer compilation on Graviton3 processors with ACL backend. The code snippet for this follows:

packed_weight_op = (
    mkldnn._reorder_linear_weight
    if (is_bf16_weight or mkldnn._is_mkldnn_acl_supported())
    else ...  # non-ACL weight-packing op, elided in this excerpt
)

packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
if is_bf16_weight or mkldnn._is_mkldnn_acl_supported():
    packed_linear_inputs += (bias, "none", [], "")
    packed_linear_op = mkldnn._linear_pointwise.default

After this was done, the FX pass was successful in compiling the matmul operators to linear_pointwise. The following snippet highlights the matmul operator in the original model:

 %attention_scores   : [num_users=1] = call_function[target=torch.matmul](args = (%query_layer, %transpose), kwargs = {})
 %attention_scores_1 : [num_users=1] = call_function[target=operator.truediv](args = (%attention_scores, 8.0), kwargs = {})
 %attention_scores_2 : [num_users=1] = call_function[target=operator.add](args = (%attention_scores_1, %extended_attention_mask_3), kwargs = {})

The following snippet highlights the linear_pointwise operator in the compiled graph:

%_linear_pointwise_default_140 : [num_users=2] = call_function[target=torch.ops.mkldnn._linear_pointwise.default](args = (%add_7, %_frozen_param278, %_frozen_param16, none, [], ), kwargs = {})
%mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.5), kwargs = {})
%mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.7071067811865476), kwargs = {})
%erf   : [num_users=1] = call_function[target=torch.ops.aten.erf.default](args = (%mul_6,), kwargs = {})
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%erf, 1), kwargs = {})

This completes the torch inductor changes required to compile the graph into optimized operators on AWS Graviton3 processors. Next comes the actual inference, where the compiled graph is dispatched to be run. oneDNN with ACL was the backend we chose during the inductor compilation, so the new operators were dispatched to oneDNN as expected, for example, mkldnn._linear_pointwise. However, due to gaps in the oneDNN ACL primitives, the operators were run with C++ reference kernels instead of the optimized ACL kernels. Hence, the compile performance was still significantly behind the eager mode performance.

There were mainly three areas where oneDNN ACL primitives lack support for torch.compile mode. The following section talks about them in detail.

1 ACL primitives didn’t have support for weights in blocked layout

ACL primitives originally designed for eager mode supported weights only in the standard channels last (NHWC) format, without any pre-packing. Weights pre-packing into blocked layout, by contrast, is one of the main optimizations in the inductor compilation passes, where the weights are reordered into blocks specific to the runtime platform. This avoids the redundant and on-the-fly reorders when running the General Matrix Multiplication (GEMM), which otherwise would be the bottleneck for inference performance. But the ACL primitives didn’t have support for blocked layout, and hence the operators were run with oneDNN C++ reference kernels instead.

2 Mixed precision primitives weren’t supported in oneDNN

AWS Graviton3 processors support bfloat16 MMLA instructions, which can be used to accelerate fp32 inference with bfloat16 GEMM as mixed precision compute. ACL supports bfloat16 mixed precision GEMM kernels, and they are integrated into oneDNN as a fast-math compute option for the existing fp32 operators. However, the fast-math approach didn’t work for compile mode because of the weights pre-packing optimization. Compile mode requires an explicit mixed precision primitive implementation in oneDNN in order to use bfloat16 acceleration.

3 ACL primitives didn’t support fused kernels for some of the activation functions

In eager mode, operators are dispatched individually because each operator is run as soon as it’s reached. In compile mode, operator fusion is another important optimization: operators are fused together for runtime efficiency. For example, the Gaussian Error Linear Unit (GELU) is one of the most widely used activation functions in transformer-based neural network architectures, so it’s typical to have a linear layer (with matrix multiplications) followed by GELU activation. As part of compiling the model into efficient operators, the torch inductor fuses matmul and GELU into a single linear_pointwise + gelu operator. However, the oneDNN ACL primitives didn’t support fused kernels with GELU.

We addressed these gaps by extending oneDNN primitives to handle the additional layouts and new primitive definitions. The following sections talk about the optimizations in detail.

Optimization 1: Extended ACL primitives to accept weight tensors in blocked layout

We extended the ACL primitives to accept blocked layout in addition to the standard NHWC format. The code snippet for this is as follows:

const bool is_weights_md_format_ok
                    = utils::one_of(weights_format_kind_received,
                      format_kind::any, format_kind::blocked);


const memory_desc_t weights_md_received = weights_md_;
acl_utils::reorder_to_weight_format(aip.wei_tensor_info,
             weights_md_, expected_weight_format, inner_dim, o_dim,
             remaining_dims, {});

ACL_CHECK_SUPPORT(
     (weights_format_kind_received == format_kind::blocked)
      && !(dnnl_memory_desc_equal(
      &weights_md_received, &weights_md_)),
      "specified blocked format not supported by ACL, use "
      "format_kind_t::any to find a supported blocked format for "
      "your platform");

Optimization 2: Defined new ACL primitives to handle mixed precision operators (weights in bfloat16 and activations in fp32)

We defined mixed precision primitive definitions and updated the existing oneDNN ACL fp32 primitives to handle bfloat16 tensors.

 /* With graph compilation, we are able to reorder and pre-pack the weights during the model load
  * and compilation phase itself so that redundant and on-the-fly reorders can be avoided.
  * This primitive definition is to support gemm fastmath mode for the compile scenario where src is
  * in fp32 and weights are in bf16
  */
 {{forward, f32, bf16, f32}, {
    CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
    nullptr,
 }},

Optimization 3: Disabled operator fusion pass in torch inductor

We bypassed the operator fusion pass in torch inductor so that the compiled graph doesn’t contain GELU-fused operators. This is a temporary solution to enable ACL kernels in torch.compile; there is work in progress to enable the operator fusion pass in future PyTorch releases. With this workaround, we were able to successfully dispatch the linear layer to ACL. As shown in the following torch.profiler output, the aten::addmm (one of the variants of the matmul operator) and aten::gelu operators in the original model (highlighted in Image 4) were compiled to mkldnn::_linear_pointwise without gelu operator fusion (highlighted in Image 5).

---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::addmm        73.32%      46.543ms        74.49%      47.287ms     647.767us            73  
            model_inference         9.92%       6.296ms       100.00%      63.479ms      63.479ms             1  
                  aten::bmm         4.37%       2.776ms         5.46%       3.467ms     144.458us            24  
                aten::copy_         1.74%       1.102ms         1.74%       1.102ms       8.103us           136  
                 aten::gelu         1.50%     950.000us         1.50%     950.000us      79.167us            12  

Image 4: torch.profiler output for Hugging Face BERT base model inference in eager mode, showing addmm and gelu operators

 
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            mkldnn::_linear_pointwise        53.61%      15.529ms        57.53%      16.665ms     228.288us            73  
                                Torch-Compiled Region        36.95%      10.705ms        99.31%      28.769ms      28.769ms             1  
    aten::_scaled_dot_product_flash_attention_for_cpu         3.67%       1.064ms         4.43%       1.284ms     107.000us            12  
                                           aten::view         1.97%     572.000us         1.97%     572.000us       2.509us           228  
                                          aten::empty         1.38%     399.000us         1.38%     399.000us       3.270us           122 

Image 5: torch.profiler output for Hugging Face BERT base model inference in torch.compile mode, showing the linear_pointwise operator without gelu fusion
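For reference, profiles like the two above can be collected with torch.profiler. The following is a minimal sketch, where compiled_model and example_input stand in for the compiled BERT model and its input batch:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("model_inference"):
        with torch.no_grad():
            compiled_model(example_input)

# Print the top operators by self CPU time, similar to Images 4 and 5.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))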

Lastly, the gelu operator was compiled into erf (error function) and was dispatched to the inductor auto-vectorization backend. The following snippets show the erf operator in the compiled graph and its dispatch to libm.so at runtime.

%_linear_pointwise_default_140 : [num_users=2] = call_function[target=torch.ops.mkldnn._linear_pointwise.default](args = (%add_7, %_frozen_param278, %_frozen_param16, none, [], ), kwargs = {})
%mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.5), kwargs = {})
%mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.7071067811865476), kwargs = {})
%erf   : [num_users=1] = call_function[target=torch.ops.aten.erf.default](args = (%mul_6,), kwargs = {})
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%erf, 1), kwargs = {})

Image 6: snippet after post grad pass showing erf function in the compiled graph

 
     0.82%     0.40%  python3  libm.so.6            [.] erff32
     0.05%     0.00%  python3  libtorch_python.so   [.] torch::autograd::THPVariable_erf
     0.05%     0.00%  python3  libtorch_cpu.so      [.] at::_ops::erf::call

Image 7: Linux perf report showing erf dispatch to libm.so

With this work, we were able to optimize torch.compile performance on Graviton3 processors by using inductor graph compilation along with the oneDNN+ACL backend.

TorchBench enhancements

To demonstrate the torch.compile performance improvements on AWS Graviton3 processors, we extended the TorchBench framework with a new argument that enables graph freezing and weights pre-packing, and disables torch autograd for the eval test mode. The code snippet for this is as follows:

parser.add_argument(
    "--freeze_prepack_weights",
    action='store_true',
    help="set to freeze the graph and prepack weights",
)

if args.freeze_prepack_weights:
    torch._inductor.config.freezing = True
    torch._inductor.config.cpp.weight_prepack = True

Image 8: Added freeze_prepack_weights option for torchdynamo backend in TorchBench to demonstrate torch.compile performance improvements on AWS Graviton3 processors

We have upstreamed all of these optimizations, and starting with PyTorch 2.3.1 they are included in the torch Python wheels and the AWS Graviton PyTorch DLC.

What’s next

Next, we’re extending the torch inductor CPU backend support to compile the Llama model, and adding support for fused GEMM kernels to enable the torch inductor operator fusion optimization on AWS Graviton3 processors.

Conclusion

In this tutorial, we covered how we optimized torch.compile performance on AWS Graviton3-based EC2 instances and how to use the optimizations to improve PyTorch model inference performance, and we demonstrated the resulting speedups. We hope that you will give it a try! If you need any support with ML software on Graviton, please open an issue on the AWS Graviton Technical Guide GitHub.

Acknowledgements

We would like to thank the PyTorch community for the baseline torch.compile framework and their continued efforts to optimize it further.

Author

Sunita Nadampalli is a Software Development Manager and AI/ML expert at AWS. She leads AWS Graviton software performance optimizations for AI/ML and HPC workloads. She is passionate about open source software development and delivering high-performance and sustainable software solutions for SoCs based on the Arm ISA.

Read More

Announcing Hacker Cup AI Track at NeurIPS 2024

The PyTorch team, in partnership with Meta Hacker Cup and Microsoft Research, is excited to announce the Hacker Cup AI Track at NeurIPS 2024. This will be the first AI track for the popular Meta Hacker Cup programming competition, designed to assess the capabilities of generative AI in performing autonomous code generation tasks. We aim to test the limits of AI in complex coding challenges and measure the performance gap between AI systems and human programmers. We will provide access to all Hacker Cup problems since 2011, alongside their respective solutions, in a multimodal (image and text) format, and utilize the existing Hacker Cup infrastructure for competitor evaluation. Featuring both “open evaluation, open model” and “open evaluation, closed model” tracks, this competition invites diverse participation from research institutions of varied interests and resource constraints, including academic labs, AI startups, large technology companies, and AI enthusiasts. Our goal is to develop and democratize meaningful advancements in code automation with the very first open evaluation process for competitive AI programmers. Registration will begin in early August, with our first qualification round on September 20th.

For more information please visit our website at https://www.facebook.com/codingcompetitions/hacker-cup/ and join our Discord at discord.gg/wWeN9hTH32

Read More

Powering the AI Revolution: The PyTorch Documentary

Now live: The official PyTorch Documentary! This film unveils the authentic narrative of PyTorch’s inception, attributing its existence to a dedicated group of unsung heroes driving technological innovation.

The documentary shares the strength of the PyTorch community, resonating with our communities across the globe. We hope this story of PyTorch inspires greater contributions, attracts more contributors to the project, and fosters widespread recognition of PyTorch’s significance in the open source community.

We couldn’t have produced this without the support of our PyTorch Foundation members and sponsors:

company logos

AMD

“PyTorch’s growth and adoption in the AI community is a testament to open collaboration. The collective efforts of all the contributors have helped propel PyTorch as one of the most widely adopted AI frameworks in the industry. AMD is proud to be a part of this movement – making sure that the future of AI is open – and we are excited to continue contributing to this vibrant ecosystem.”

– Niles Burbank, AMD

AWS

“The release of the PyTorch Documentary showcases the innovation and real-world impact of one of the most widely adopted open source machine learning frameworks. By supporting and contributing to the PyTorch community, AWS helps enable cutting-edge machine learning research that drives advancements in AI capabilities. We are excited about the documentary as it highlights the power of collaboration in propelling PyTorch to the forefront of machine learning and empowering developers and data scientists to create groundbreaking models. At AWS, we celebrate frameworks like PyTorch that foster environments where open source machine learning technologies can grow and benefit the community at-large, as well as our customers.”

– Brian Granger, AWS

Google Cloud

“Google recognizes the impact of PyTorch on the AI community, providing researchers and developers with powerful, flexible tools for innovation. This documentary not only celebrates the remarkable achievements of the PyTorch community but also highlights the collaborative spirit driving advancements in AI. We look forward to continuing our support for PyTorch and fostering an open ecosystem that accelerates machine learning research and application.”

– Dwarak Rajagopal, Google

Meta

“We have been so impressed with the growth and collaboration that PyTorch has created over the years. From very humble beginnings at Meta to a cornerstone in AI research and development, the documentary showcases the dedication of our contributors since the start. It’s an honor to be a part of something so impactful, and now it’s been documented for our community to take part in.”

– Soumith Chintala, Meta

Microsoft Azure

“We’re truly excited about the premiere of the PyTorch Documentary. At Microsoft, PyTorch has been our default deep learning framework for building AI solutions including Microsoft Copilot. Additionally, we have made significant investments to create an optimized environment for our customers to develop, train, fine-tune and deploy their PyTorch workloads on Azure and Windows, furthering our commitment to democratize AI.”

– Eric Boyd, Microsoft

PyTorch Foundation

“The release of the PyTorch documentary marks a significant milestone for our community, showcasing the incredible journey and rapid evolution of PyTorch. We are excited to share these stories and achievements with the world, and we look forward to continuing to foster innovation and growth of the PyTorch community and PyTorch’s evolving ecosystem.”

– Matt White, PyTorch Foundation

Read More

Training MoEs at Scale with PyTorch

Over the past year, Mixture of Experts (MoE) models have surged in popularity, fueled by powerful open-source models like DBRX, Mixtral, DeepSeek, and many more. In this blog post, we’ll talk about how we scale to over three thousand GPUs using PyTorch Distributed and MegaBlocks, an efficient open-source MoE implementation in PyTorch.

What is a MoE?

A MoE model is a model architecture that uses multiple expert networks to make predictions. A gating network is used to route and combine the outputs of experts, ensuring each expert is trained on a different, specialized distribution of tokens. The architecture of a transformer-based large language model typically consists of an embedding layer that leads into multiple transformer blocks (Figure 1, Subfigure A). Each transformer block contains an attention block and a dense feed forward network (Figure 1, Subfigure B). These transformer blocks are stacked such that the output of one transformer block leads to the input of the next block. The final output goes through a fully connected layer and softmax to obtain probabilities for the next token to output.

When using a MoE in LLMs, the dense feed forward layer is replaced by a MoE layer which consists of a gating network and a number of experts (Figure 1, Subfigure D). The gating network, typically a linear feed forward network, takes in each token and produces a set of weights that determine which tokens are routed to which experts. The experts themselves are typically implemented as a feed forward network as well. During training, the gating network adapts to assign inputs to the experts, enabling the model to specialize and improve its performance. The router outputs are then used to weigh expert outputs to give the final output of the MoE layer.

Figure 1: Using Mixture of Experts in a transformer block

Compared to dense models, MoEs provide more efficient training for a given compute budget. This is because the gating network only sends tokens to a subset of experts, reducing the computational load. As a result, the capacity of a model (its total number of parameters) can be increased without proportionally increasing the computational requirements. During inference, only some of the experts are used, so a MoE is able to perform faster inference than a dense model. However, the entire model needs to be loaded in memory, not just the experts being used.

The sparsity in MoEs that allows for greater computational efficiency comes from the fact that a particular token will only be routed to a subset of experts. The number of experts and how experts are chosen depends on the implementation of the gating network, but a common method is top k. The gating network first predicts a probability value for each expert, then routes the token to the top k experts to obtain the output. However, if all tokens always go to the same subset of experts, training becomes inefficient and the other experts end up undertrained. To alleviate this problem, a load balancing loss is introduced that encourages even routing to all experts.
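To make the routing concrete, here is a minimal, illustrative top-k MoE layer in PyTorch. The module name, sizes, and the exact form of the load-balancing loss are our own simplifications rather than the MegaBlocks or LLM Foundry implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Minimal top-k MoE layer for illustration (not the MegaBlocks implementation)."""

    def __init__(self, d_model, d_ff, num_experts, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(d_model, num_experts, bias=False)   # gating network
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):                                   # x: (num_tokens, d_model)
        probs = F.softmax(self.gate(x), dim=-1)             # probability per expert
        topk_p, topk_i = probs.topk(self.k, dim=-1)         # route each token to its top-k experts
        topk_p = topk_p / topk_p.sum(dim=-1, keepdim=True)  # renormalize routing weights

        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):           # simple loops; real kernels batch this
            for slot in range(self.k):
                sel = topk_i[:, slot] == e
                if sel.any():
                    out[sel] += topk_p[sel, slot].unsqueeze(-1) * expert(x[sel])

        # Auxiliary load-balancing loss: penalize uneven routing across experts.
        importance = probs.mean(dim=0)                      # mean router probability per expert
        counts = torch.bincount(topk_i.flatten(), minlength=len(self.experts)).to(x.dtype)
        load = counts / counts.sum()                        # fraction of assignments per expert
        aux_loss = len(self.experts) * (importance * load).sum()
        return out, aux_loss

In use, something like TopKMoE(1024, 4096, num_experts=8)(tokens) returns both the layer output and the auxiliary loss, which is added to the language-modeling loss during training.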

The number of experts and the choice of top k are important factors in designing MoEs. A higher number of experts allows scaling up to larger models without increasing computational cost. This means the model has a higher capacity for learning; however, past a certain point the performance gains tend to diminish. The number of experts also needs to be balanced against the inference cost of serving the model, since the entire model must be loaded in memory. Similarly, when choosing top k, a lower top k during training results in smaller matrix multiplications, which can leave compute underutilized if communication costs are large enough. During inference, however, a higher top k generally leads to slower inference speed.

MegaBlocks

MegaBlocks is an efficient MoE implementation that uses sparse matrix multiplication to compute expert outputs in parallel despite uneven token assignment. MegaBlocks implements a dropless MoE that avoids dropping tokens while using GPU kernels that maintain efficient training. Prior to MegaBlocks, dynamic routing formulations forced a tradeoff between model quality and hardware efficiency: users had to either drop tokens from computation or waste computation and memory on padding. With MegaBlocks, experts can receive a variable number of tokens and the expert computation can still be performed efficiently using block-sparse matrix multiplication. We’ve integrated MegaBlocks into LLM Foundry to enable scaling MoE training to thousands of GPUs.

Figure 2: Matrix multiplication for expert computations

Expert Parallelism

As models scale to larger sizes and fail to fit on a single GPU, we require more advanced forms of parallelism. Expert parallelism is a form of model parallelism where we place different experts on different GPUs for better performance. Instead of expert weights being communicated across all GPUs, tokens are sent to the device that contains the expert. By moving data instead of weights, we can aggregate data across multiple machines for a single expert. The router determines which tokens from the input sequence should be sent to which experts. This is typically done by computing a gating score for each token-expert pair, and then routing each token to the top-scoring experts. Once the token-to-expert assignments are determined, an all-to-all communication step is performed to dispatch the tokens to the devices hosting the relevant experts. This involves each device sending the tokens assigned to experts on other devices, while receiving tokens assigned to its local experts.

The key advantage of expert parallelism is processing a few larger matrix multiplications instead of several small matrix multiplications. As each GPU only has a subset of experts, it only has to do computation for those experts. Correspondingly, as we aggregate tokens across multiple GPUs, the size of each matrix is proportionally larger. As GPUs are optimized for large-scale parallel computations, larger operations can better exploit their capabilities, leading to higher utilization and efficiency. A more in-depth explanation of the benefits of larger matrix multiplications can be found here. Once the computation is complete, another all-to-all communication step is performed to send the expert outputs back to their original devices.
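The dispatch step can be sketched with torch.distributed’s all_to_all_single. The helper below is illustrative only: it assumes the process group is already initialized, that experts are assigned to ranks in contiguous blocks, and that all tensors live on the GPU when the NCCL backend is used; it is not the Foundry/MegaBlocks implementation.

import torch
import torch.distributed as dist

def dispatch_to_experts(tokens, expert_ids, experts_per_rank):
    # tokens: (num_tokens, hidden); expert_ids: (num_tokens,) global expert index per token.
    world_size = dist.get_world_size()
    dest_rank = expert_ids // experts_per_rank              # rank hosting each token's expert
    order = torch.argsort(dest_rank)                        # group tokens by destination rank
    tokens = tokens[order]

    send_counts = torch.bincount(dest_rank, minlength=world_size)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)        # exchange how much each peer will send

    recv_buf = tokens.new_empty(int(recv_counts.sum()), tokens.shape[-1])
    dist.all_to_all_single(
        recv_buf, tokens,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=send_counts.tolist(),
    )
    # recv_buf now holds the tokens assigned to this rank's experts; `order` is kept so the
    # second all-to-all (after expert computation) can restore the original token positions.
    return recv_buf, order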

Figure 3: Token routing in expert parallelism

We leverage PyTorch’s DTensor, a low-level abstraction for describing how tensors are sharded and replicated, to effectively implement expert parallelism. We first manually place experts on different GPUs, typically sharding across a node to ensure we can leverage NVLink for fast GPU communication when we route tokens. We can then build a device mesh on top of this layout, which lets us succinctly describe the parallelism across the entire cluster. We can use this device mesh to easily checkpoint or rearrange experts when we need alternate forms of parallelism.
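A device-mesh layout along these lines can be expressed with init_device_mesh; the 8x8 shape and the dimension names below are illustrative, not the exact configuration used in our cluster:

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical cluster of 64 GPUs: 8 nodes with 8 GPUs each.
# Experts are sharded across the GPUs of a node so routed tokens travel over NVLink,
# while the inter-node dimension is used for the data-parallel layers.
mesh = init_device_mesh("cuda", (8, 8), mesh_dim_names=("inter_node", "intra_node"))

expert_mesh = mesh["intra_node"]   # place/shard experts along this dimension
dp_mesh = mesh["inter_node"]       # shard/replicate the remaining layers along this one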

Scaling ZeRO-3 with PyTorch FSDP

In conjunction with expert parallelism, we use data parallelism for all other layers, where each GPU stores a copy of the model and optimizer and processes a different chunk of data. After each GPU has completed a forward and backward pass, gradients are accumulated across GPUs for a global model update.

ZeRO-3 is a form of data parallelism where weights and optimizers are sharded across each GPU instead of being replicated. Each GPU now only stores a subset of the full model, dramatically reducing memory pressure. When a part of the model is needed for computation, it is gathered across all the GPUs, and after the computation is complete, the gathered weights are discarded. We use PyTorch’s implementation of ZeRO-3, called Fully Sharded Data Parallel (FSDP).

As we scale to thousands of GPUs, the cost of communication across devices increases, slowing down training. Communication increases due to the need to synchronize and share model parameters, gradients, and optimizer states across all GPUs, which involves all-gather and reduce-scatter operations. To mitigate this issue while keeping the benefits of FSDP, we utilize Hybrid Sharded Data Parallel (HSDP) to shard the model and optimizer across a set number of GPUs and replicate this group multiple times to fully utilize the cluster. With HSDP, an additional all-reduce operation is needed in the backward pass to sync gradients across replicas. This approach allows us to balance memory efficiency and communication cost during large-scale distributed training. To use HSDP we can extend our previous device mesh from expert parallelism and let PyTorch do the heavy lifting of actually sharding and gathering when needed.
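A minimal HSDP sketch is shown below; it assumes the process group is already initialized (for example via torchrun), and the mesh shape and the toy module are placeholders for the real non-expert layers:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# 2D mesh: shard parameters within each group of 8 GPUs, replicate across the 8 groups.
hsdp_mesh = init_device_mesh("cuda", (8, 8), mesh_dim_names=("replicate", "shard"))

module = torch.nn.Linear(4096, 4096).cuda()   # stand-in for the non-expert transformer layers
model = FSDP(
    module,
    device_mesh=hsdp_mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # ZeRO-3 within a group, all-reduce across replicas
    use_orig_params=True,
)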

Figure 4: FSDP and HSDP

With PyTorch, we can effectively combine these two types of parallelism, leveraging FSDP’s higher level API while using the lower-level DTensor abstraction when we want to implement something custom like expert parallelism. We now have a 3D device mesh with expert parallel shard dimension, ZeRO-3 shard dimension, and a replicate dimension for pure data parallelism. Together, these techniques deliver near linear scaling across very large clusters, allowing us to achieve MFU numbers over 40%.

Elastic Checkpointing with Torch Distributed

Fault tolerance is crucial for ensuring that LLMs can be trained reliably over extended periods, especially in distributed environments where node failures are common. To avoid losing progress when jobs inevitably encounter failures, we checkpoint the state of the model, which includes parameters, optimizer states, and other necessary metadata. When a failure occurs, the system can resume from the last saved state rather than starting over. To ensure robustness to failures, we need to checkpoint often and save and load checkpoints in the most performant way possible to minimize downtime. Additionally, if too many GPUs fail, our cluster size may change. Accordingly, we need the ability to elastically resume on a different number of GPUs.

PyTorch supports elastic checkpointing through its distributed training framework, which includes utilities for both saving and loading checkpoints across different cluster configurations. PyTorch Distributed Checkpoint ensures the model’s state can be saved and restored accurately across all nodes in the training cluster in parallel, regardless of any changes in the cluster’s composition due to node failures or additions.

Additionally, when training very large models, the size of checkpoints may be very large, leading to very slow checkpoint upload and download times. PyTorch Distributed Checkpoint supports sharded checkpoints, which enables each GPU to save and load only its portion of the model. When combining sharded checkpointing with elastic training, each GPU reads the metadata file to determine which shards to download on resumption. The metadata file contains information on what parts of each tensor are stored in each shard. The GPU can then download the shards for its part of the model and load that part of the checkpoint.
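In code, a sharded save and an elastic load can be sketched with torch.distributed.checkpoint. The path is illustrative, and with FSDP/HSDP you would typically build the state dict via the distributed state-dict utilities rather than calling state_dict() directly:

import torch.distributed.checkpoint as dcp

# Each rank writes only its own shards, in parallel, plus a shared metadata file.
state = {"model": model.state_dict(), "optim": optimizer.state_dict()}
dcp.save(state, checkpoint_id="checkpoints/step_1000")

# On resumption (possibly with a different number of GPUs), each rank reads the metadata
# file, fetches the shards it needs, and loads them into its resharded state dict.
state = {"model": model.state_dict(), "optim": optimizer.state_dict()}
dcp.load(state, checkpoint_id="checkpoints/step_1000")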

Figure 5: Checkpointing saving and resumption resharded on additional GPUs

By parallelizing checkpointing across GPUs, we can spread out network load, improving robustness and speed. When training a model with 3000+ GPUs, network bandwidth quickly becomes a bottleneck. We take advantage of the replication in HSDP to first download checkpoints on one replica and then send the necessary shards to other replicas. With our integration in Composer, we can reliably upload checkpoints to cloud storage as frequently as every 30 minutes and automatically resume from the latest checkpoint in the event of a node failure in less than 5 minutes.

Conclusion

We’re very excited to see how PyTorch is enabling training state-of-the-art LLMs with great performance. In our post, we’ve shown how we implemented efficient MoE training through Pytorch Distributed and MegaBlocks on Foundry. Furthermore, Pytorch elastic checkpointing allowed us to quickly resume training on a different number of GPUs when node failures occurred. Using Pytorch HSDP has allowed us to scale training efficiently as well as improve checkpointing resumption times. We look forward to continuing building on a strong and vibrant open-source community to help bring great AI models to everyone. Come join us in building great models at LLM Foundry and PyTorch.

Read More

🎉 PyTorch Docathon H2 2024 Wrap-up 🎉

We are thrilled to announce the successful completion of the H1 2024 PyTorch Docathon! The event was a resounding success, and we want to extend our heartfelt gratitude to all the participants who made it possible. The dedication, expertise, and tireless efforts of our open-source contributors have once again helped us improve the PyTorch documentation.

This Docathon ran from June 4 through June 20 with more than 176 registrants. The energy and enthusiasm were palpable, and entrants were judged on the difficulty of their submissions, which resulted in over 50 merged pull requests.

We want to give a special shout-out to our top contributors, who went above and beyond during this event. Your dedication and expertise have been invaluable in enhancing the PyTorch documentation and empowering developers worldwide.

Meet the top contributors

For the full list of participants, see here.

As we bring this Docathon to a close, we encourage each and every one of you to stay inspired and keep contributing to PyTorch documentation and code, and pushing the boundaries of what’s possible with PyTorch. Your collective efforts are shaping the landscape of deep learning and fostering innovation in the PyTorch community.

Thank you again for your participation and support. We look forward to seeing what you will achieve next!

Team PyTorch

Read More

Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity

Over the past year, we’ve added support for semi-structured (2:4) sparsity into PyTorch. With just a few lines of code, we were able to show a 10% end-to-end inference speedup on segment-anything by replacing dense matrix multiplications with sparse matrix multiplications.

However, matrix multiplications are not unique to neural network inference; they happen during training as well. By expanding on the core primitives we used earlier to accelerate inference, we were also able to accelerate model training. We wrote a replacement nn.Linear layer, SemiSparseLinear, that is able to achieve a 1.3x speedup across the forward + backward pass of the linear layers in the MLP block of ViT-L on an NVIDIA A100.

End-to-end, we see a wall time reduction of 6% for a DINOv2 ViT-L training, with virtually no accuracy degradation out of the box (82.8 vs 82.7 on ImageNet top-1 accuracy).

2 strategies for training a ViT model

We compare 2 strategies for training a ViT model for 125k iterations on 4x NVIDIA A100s: either fully dense (blue), or sparse for 70% of the training, then dense (orange). Both achieve similar results on the benchmarks, but the sparse variant trains 6% faster. For both experiments, we evaluate the intermediate checkpoints with and without sparsity.

As far as we are aware, this is the first OSS implementation of accelerated sparse training and we’re excited to provide a user API in torchao. You can try accelerating your own training runs with just a few lines of code:

# Requires torchao and pytorch nightlies and CUDA compute capability 8.0+
import torch
from torchao.sparsity.training import (
    SemiSparseLinear,
    swap_linear_with_semi_sparse_linear,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().half()

# Specify the fully-qualified-name of the nn.Linear modules you want to swap
sparse_config = {
    "seq.0": SemiSparseLinear
}

# Swap nn.Linear with SemiSparseLinear, you can run your normal training loop after this step
swap_linear_with_semi_sparse_linear(model, sparse_config)

How does this work?

The general idea behind sparsity is simple: skip calculations involving zero-valued tensor elements to speed up matrix multiplication. However, simply setting weights to zero isn’t enough, as the dense tensor still contains these pruned elements and dense matrix multiplication kernels will continue to process them, incurring the same latency and memory overhead. To achieve actual performance gains, we need to replace dense kernels with sparse kernels that intelligently bypass calculations involving pruned elements.

These kernels work on sparse matrices, which remove the pruned elements and store the specified elements in a compressed format. There are many different sparse formats, but we’re particularly interested in semi-structured sparsity, also known as 2:4 structured sparsity or fine-grained structured sparsity or more generally N:M structured sparsity.

2:4 sparse compressed representation

2:4 sparse compressed representation. Original Source

A 2:4-sparse matrix is a matrix where at most 2 elements are non-zero for every 4 elements, as illustrated in the image above. Semi-structured sparsity is attractive because it exists in a goldilocks spot of performance and accuracy:

  1. NVIDIA GPUs since Ampere offer hardware acceleration and library support (cuSPARSELt) for this format, with matrix multiplication being up to 1.6x faster
  2. Pruning models to fit this sparsity pattern does not degrade accuracy as much as other patterns. NVIDIA’s whitepaper shows pruning then retraining is able to recover accuracy for most vision models.

Illustration of 2:4 (sparse) matrix multiplication on NVIDIA GPUs

Illustration of 2:4 (sparse) matrix multiplication on NVIDIA GPUs. Original source

Accelerating inference with semi-structured sparsity is straightforward. Since our weights are fixed during inference, we can prune and compress the weight ahead of time (offline) and store the compressed sparse representation instead of our dense tensor.

flow chart

Then, instead of dispatching to dense matrix multiplication we dispatch to sparse matrix multiplication, passing in the compressed sparse weight instead of the normal dense one. For more information about accelerating models for inference using 2:4 sparsity, please refer to our tutorial.
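As a rough sketch of that inference flow (a CUDA GPU with compute capability 8.0+ is assumed, and the simple magnitude-based mask here stands in for a proper pruning step):

import torch
from torch.sparse import to_sparse_semi_structured

linear = torch.nn.Linear(2048, 2048).half().cuda()

# Offline: keep the 2 largest-magnitude weights in every contiguous group of 4, zero the rest.
w = linear.weight.detach()
groups = w.view(-1, 4)
keep = torch.zeros_like(groups, dtype=torch.bool)
keep.scatter_(-1, groups.abs().topk(2, dim=-1).indices, True)
pruned = (groups * keep).view_as(w)

# Store the compressed 2:4 representation in place of the dense weight.
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(pruned))

x = torch.rand(128, 2048).half().cuda()
with torch.inference_mode():
    y = linear(x)   # dispatches to the sparse GEMM instead of the dense one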

Extending sparse inference acceleration to training

In order to use sparsity to reduce the training time of our models, we need to consider when the mask is calculated, as once we store the compressed representation the mask is fixed.

Training with a fixed mask applied to an existing trained dense model (also known as pruning) does not degrade accuracy, but this requires two training runs – one to obtain the dense model and another to make it sparse, offering no speedups.

Instead, we’d like to train a sparse model from scratch (dynamic sparse training), but training from scratch with a fixed mask leads to a significant drop in evaluation accuracy, as the sparsity mask would be selected at initialization, when the model weights are essentially random.

To maintain the accuracy of the model when training from scratch, we prune and compress the weights at runtime, so that we can calculate the optimal mask at each step of the training process.

Conceptually you can think of our approach as an approximate matrix multiplication technique, where we `prune_and_compress` and dispatch to `sparse_GEMM` in less time than a `dense_GEMM` call would take. This is difficult because the native pruning and compression functions are too slow to show speedups.

Given the shapes of our ViT-L training matrix multiplications (13008x4096x1024), we measured the runtime of a dense and sparse GEMM respectively at 538us and 387us. In other words, the pruning and compression step of the weight matrix must run in less than 538-387=151us to have any efficiency gain. Unfortunately, the compression kernel provided in cuSPARSELt already takes 380us (without even considering the pruning step!).

Given the max NVIDIA A100 memory IO (2TB/s), and considering that a prune and compress kernel would be memory bound, we could theoretically prune and compress our weight (4096x1024x2 bytes=8MB) in 4us (8MB / 2TB/s)! And in fact, we were able to write a kernel that prunes and compresses a matrix into 2:4-sparse format, and runs in 36 us (10x faster than the compression kernel in cuSPARSELt), making the entire GEMM (including the sparsification) faster. Our kernel is available for use in PyTorch.
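For reference, the bandwidth bound above works out as follows (a quick back-of-the-envelope check using the nominal 2 TB/s figure):

bytes_to_move = 4096 * 1024 * 2         # fp16 weight matrix from the ViT-L MLP: ~8 MB
bandwidth = 2e12                        # nominal A100 HBM bandwidth in bytes/s
print(bytes_to_move / bandwidth * 1e6)  # ~4.2 us lower bound for a memory-bound prune+compress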

Our custom sparsification kernel

Our custom sparsification kernel, which includes pruning + compression, is ~30% faster across a linear layer forward+backward. Benchmarks run on a NVIDIA A100-80GB GPU.

Writing a performant runtime sparsification kernel

There were multiple challenges we faced in order to implement a performant runtime sparsification kernel, which we will explore below.

1) Handling the backwards pass

For the backwards pass, we need to calculate dL/dX and dL/dW for the gradient update and the subsequent layer, which means we need to calculate xWᵀ and xᵀW respectively.

Overview of runtime sparsification for training acceleration (FW + BW pass)

However, this is problematic because the compressed representation cannot be transposed, since there’s no guarantee that the tensor is 2:4 sparse in both directions.

Both matrices are valid 2:4 matrices. However, the right one is no longer a valid 2:4 matrix once transposed because one column contains more than 2 elements

Therefore, we prune a 4×4 tile, instead of a 1×4 strip. We greedily preserve the largest values, ensuring that we take at most 2 values for each row / column. While this approach is not guaranteed to be optimal, as we sometimes only preserve 7 values instead of 8, it efficiently calculates a tensor that is 2:4 sparse both row-wise and column-wise.
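The greedy selection can be illustrated in a few lines of Python. This is a readable reference version for a single tile; the actual kernel does this in registers with a sorting network, as described below.

import torch

def prune_4x4_tile_2by4_both_ways(tile: torch.Tensor) -> torch.Tensor:
    """Keep the largest-magnitude values of a 4x4 tile, with at most 2 kept per row and per column."""
    assert tile.shape == (4, 4)
    kept = torch.zeros(4, 4, dtype=torch.bool)
    row_counts = [0, 0, 0, 0]
    col_counts = [0, 0, 0, 0]
    # Visit all 16 elements in descending order of magnitude and keep greedily.
    for flat_idx in tile.abs().flatten().argsort(descending=True).tolist():
        r, c = divmod(flat_idx, 4)
        if row_counts[r] < 2 and col_counts[c] < 2:
            kept[r, c] = True
            row_counts[r] += 1
            col_counts[c] += 1
    # As noted above, this sometimes keeps only 7 values instead of 8.
    return tile * kept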

We then compress both the packed tensor and the packed transpose tensor, storing the transpose tensor for the backwards pass. By calculating both the packed and packed transpose tensor at the same time, we avoid a secondary kernel call in the backwards pass.

Our kernel prunes the weight matrix in registers

Our kernel prunes the weight matrix in registers, and writes the compressed values in global memory. It also prunes at the same time W.t, which is needed for the backward pass, minimizing the memory IO

There’s some additional transpose trickery needed to handle the backwards pass: the underlying hardware only supports operations where the first matrix is sparse. For weight sparsification during inference, when we need to calculate xWᵀ we rely on transpose properties to swap the order of the operands.

xWᵀ = (Wxᵀ)ᵀ

During inference, we use torch.compile to fuse the outer transpose into subsequent pointwise ops in order to avoid paying a performance penalty.

However in the case of the backwards pass of training, we have no subsequent pointwise op to fuse with. Instead, we fuse the transposition into our matrix multiplication by taking advantage of cuSPARSELt’s ability to specify the row / column layout of the result matrix.

2) Kernel tiling for efficient memory-IO

In order for our kernel to be as efficient as possible, we want to coalesce our reads / writes, as we found that memory IO to be the main bottleneck. This means that within a CUDA thread, we want to read/write chunks of 128 bytes at a time, so that multiple parallel reads/writes can be coalesced into a single request by the GPU memory controller.

Therefore, instead of a thread handling a single 4×4 tile, which is only 4x4x2 = 32 bytes, we decided that each thread will handle four 4×4 tiles (i.e., an 8×8 tile), which allows us to operate on 8x8x2 = 128-byte chunks.

Kernel tiling for efficient memory-IO

3) Sorting elements in a 4×4 tile without warp-divergence

For each individual 4×4 tile within our thread we calculate a bitmask that specifies which elements to prune and which elements to keep. To do this we sort all 16 elements and greedily preserve elements, so long as they do not break our 2:4 row / col constraint. This preserves only the weights with the largest values.

Crucially we observe that we are only ever sorting a fixed number of elements, so by using a branchless sorting network, we can avoid warp divergence.

Sorting network diagram

For clarity, the transposed packed tensor and metadata are omitted. Sorting network diagram taken from Wikipedia.

Warp divergence occurs when we have conditional execution inside a thread block. In CUDA, work items in the same work group (thread block) are dispatched at the hardware level in batches (warps). If we have conditional execution, such that some work-items in the same batch run different instructions, then they are masked when the warp is dispatched, or dispatched sequentially.

For example, if we have some code like if (condition) do(A) else do(B), where condition is satisfied by all the odd-numbered work items, then the total runtime of this conditional statement is do(A) + do(B), since we would dispatch do(A) for all odd-numbered work-items, masking out even-numbered work-items, and do(B) for all even numbered work-items, masking out odd-numbered work-items. This answer provides more information about warp divergence.

4) Writing the compressed matrices and metadata

Once the bitmask has been computed, the weight data has to be written back in a compressed format in global memory. This is not trivial, because the data needs to stay in registers, and it’s not possible to index registers (e.g., C[i++] = a prevents us from storing C in registers). Furthermore, we found that nvcc was using many more registers than we expected, which caused register spilling and impacted global performance. We write this compressed matrix to global memory in column-major format to make the writes more efficient.

compressed matrix to global memory in Column-Major format

We also need to write the cuSPARSELt metadata. This metadata layout is quite similar to the one from the open-source CUTLASS library, and it is optimized to be loaded efficiently through shared memory in the GEMM kernel with the PTX ldmatrix instruction.

However, this layout is not optimized to be written efficiently: the first 128 bits of the metadata tensor contain metadata about the first 32 columns of rows 0, 8, 16 and 24. Recall that each thread handles an 8×8 tile, which means that this information is scattered across 16 threads.

We rely on a series of warp-shuffle operations, once for the original and once for the transposed representation, to write the metadata. Fortunately, this data represents less than 10% of the total IO, so we can afford not to fully coalesce the writes.

DINOv2 Sparse Training: Experimental Setup and Results

For our experiments, the ViT-L model is trained on ImageNet for 125k steps using the DINOv2 method. All our experiments were run on 4x AMD EPYC 7742 64-core CPUs and 4x NVIDIA A100-80GB GPUs. During sparse training, the model is trained with 2:4 sparsity enabled for the first part of the training, where only half of the weights are enabled. This sparsity mask on the weights is dynamically recomputed at every step, as weights are continuously updated during the optimization. For the remaining steps, the model is trained densely, producing a final model without 2:4 sparsity (except in the 100% sparse training setup), which is then evaluated.

Training setup ImageNet 1k log-regression
0% sparse (125k dense steps, baseline) 82.8
40% sparse (40k sparse -> 85k dense steps) 82.9
60% sparse (75k sparse -> 50k dense steps) 82.8
70% sparse (87.5k sparse -> 37.5k dense steps) 82.7
80% sparse (100k sparse -> 25k dense steps) 82.7
90% sparse (112.5k sparse -> 12.5k dense steps) 82.0
100% sparse (125k sparse steps) 82.3 (2:4-sparse model)

sparsity training diagrams

During the sparse training steps, in the backward pass we obtain a dense gradient for the sparse weights. For the gradient descent to be sound, we should also sparsify this gradient before using it in the optimizer to update the weights. Instead of doing that, we use the full dense gradient to update the weights – we found this to work better in practice: this is the STE (Straight Through Estimator) strategy. In other words, we update all the parameters at every step, even the ones we don’t use.
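A toy version of this STE update rule can be written as a custom autograd function. The 1x4 magnitude-based pruning below is a simplification of the 4x4-tile kernel described earlier, and the class name is our own:

import torch

class Sparsify24STE(torch.autograd.Function):
    """Forward: prune the weight to 2:4. Backward: pass the dense gradient through unchanged."""

    @staticmethod
    def forward(ctx, w):
        groups = w.view(-1, 4)
        keep = torch.zeros_like(groups, dtype=torch.bool)
        keep.scatter_(-1, groups.abs().topk(2, dim=-1).indices, True)
        return (groups * keep).view_as(w)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output   # dense gradient: every parameter is updated, even pruned ones

# In the forward pass of a sparse linear layer, one would use Sparsify24STE.apply(weight)
# in place of the dense weight, so the mask is recomputed at every training step.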

Conclusion and Future Work

In this blog post, we’ve shown how to accelerate neural network training with semi-structured sparsity and explained some of the challenges we faced. We were able to achieve a 6% end-to-end speedup on DINOv2 training with a small 0.1 pp accuracy drop.

There are several areas of expansion for this work:

  • Expansion to new sparsity patterns: Researchers have created new sparsity patterns, like V:N:M sparsity, that use the underlying semi-structured sparse kernels to allow for more flexibility. This is especially interesting for applying sparsity to LLMs, as 2:4 sparsity degrades accuracy too much, but we have seen some positive results for more general N:M patterns.
  • Performance optimizations for sparse fine-tuning: This post covers sparse training from scratch, but oftentimes we want to fine-tune a foundational model. In this case, a static mask may be sufficient to preserve accuracy which would enable us to make additional performance optimizations.
  • More experiments on pruning strategy: We calculate the mask at each step of the network, but calculating the mask every n steps may yield better training accuracy. Overall, figuring out the best strategy to use semi-structured sparsity during training is an open area of research.
  • Compatibility with fp8: The hardware also supports fp8 semi-structured sparsity (in the 4:8 format instead of 2:4), and this approach should work similarly with fp8 in principle. In practice, we would need to write similar sparsification kernels, and could possibly fuse them with the scaling of the tensors.
  • Activation Sparsity: Efficient sparsification kernels also make it possible to sparsify the activations during training. Because the sparsification overhead grows linearly with the size of the sparsified matrix, setups with activation tensors that are large compared to the weight tensors could benefit more from activation sparsity than from weight sparsity. Furthermore, activations are naturally sparse because of the usage of ReLU or GELU activation functions, reducing accuracy degradation.

If you are interested in these problems, please feel free to open an issue or PR in torchao, a community we’re building for architecture optimization techniques like quantization and sparsity. Additionally, if you have a general interest in sparsity, please reach out in CUDA-MODE (#sparsity).

Read More