🎉 PyTorch Docathon H2 2023 Wrap-up 🎉

We are thrilled to announce the successful completion of the Fall 2023 PyTorch Docathon! The event was a resounding success, and we want to extend our heartfelt gratitude to all the participants who made it possible. The dedication, expertise, and tireless efforts of our open-source contributors have once again helped us improve PyTorch documentation.

This Docathon ran from Nov 1 through Nov 15 with more than 170 registrants. The energy and enthusiasm were palpable, and entrants were judged on the difficulty of their submissions, which resulted in over TBA merged pull requests. We fixed PyTorch docstrings to make them compliant with the PEP 257 Python docstring conventions. We also fixed multiple bugs in the pytorch/tutorials repo.

We want to give a special shout-out to our top contributors, who went above and beyond during this event. Your dedication and expertise have been invaluable in enhancing the PyTorch documentation and empowering developers worldwide.

Meet the top contributors:

You can see the full docathon leaderboard published here.

As we bring this Docathon to a close, we encourage each and every one of you to stay inspired, keep contributing to PyTorch documentation and code, and keep pushing the boundaries of what’s possible with PyTorch. Your collective efforts are shaping the landscape of deep learning and fostering innovation in the PyTorch community.

Thank you again for your participation and support. We look forward to seeing what you will achieve next!

Team PyTorch


Accelerating Generative AI with PyTorch: Segment Anything, Fast


This post is the first part of a multi-part blog series focused on how to accelerate generative AI models with pure, native PyTorch. We are excited to share a breadth of newly released PyTorch performance features, alongside practical examples of how these features can be combined to push PyTorch native performance as far as possible.

As announced during the PyTorch Developer Conference 2023, the PyTorch team rewrote Meta’s Segment Anything (“SAM”) Model resulting in 8x faster code than the original implementation, with no loss of accuracy, all using native PyTorch optimizations. We leverage a breadth of new PyTorch features:

  • Torch.compile: A compiler for PyTorch models
  • GPU quantization: Accelerate models with reduced precision operations
  • Scaled Dot Product Attention (SDPA): Memory efficient attention implementations
  • Semi-Structured (2:4) Sparsity: A GPU optimized sparse memory format
  • Nested Tensor: Batch together non-uniformly sized data into a single Tensor, such as images of different sizes.
  • Custom operators with Triton: Write GPU operations using Triton Python DSL and easily integrate it into PyTorch’s various components with custom operator registration.

We encourage readers to copy-paste code from our implementation of SAM on GitHub and to ask us questions on GitHub.


A quick glimpse of increasing throughput and decreasing memory overhead with our newly released PyTorch native features. Benchmarks were run on a p4d.24xlarge instance (8x A100s).

Segment Anything Model

SAM is a zero-shot vision model for generating promptable image masks.

sam image masks

The SAM architecture, described in its paper, includes multiple prompt and image encoders based on the Transformer architecture. Of these, we measured performance across the smallest and largest vision transformer backbones: ViT-B and ViT-H. For simplicity, we only show traces for the ViT-B model.


Below, we tell the story of optimizing SAM: profiling, identifying bottlenecks, and building new features into PyTorch that solve these problems. Throughout, we showcase our new PyTorch features: torch.compile, SDPA, Triton kernels, Nested Tensor, and semi-structured sparsity. Each section builds on the previous ones, ending with our SAM-fast, now available on GitHub. We motivate each feature using real kernel and memory traces collected with fully PyTorch native tooling, and we visualize these traces with Perfetto UI.


Our SAM baseline is Facebook Research’s unmodified model, using float32 dtype and a batch size of 1. After some initial warmup, we can look at a kernel trace using the PyTorch Profiler:
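A trace like this can be collected with torch.profiler and exported as a Chrome trace JSON file, which Perfetto UI opens directly. The sketch below is a minimal version of that workflow, not the post's benchmarking script: the toy model, the three warmup iterations, and the output path are all assumptions.

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile


def capture_kernel_trace(model, example_input, trace_path, warmup_iters=3):
    """Profile one forward pass after warmup and export a Chrome/Perfetto trace."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)  # also record GPU kernels
    with torch.no_grad():
        # Warmup keeps one-time costs (lazy init, allocator growth) out of the trace.
        for _ in range(warmup_iters):
            model(example_input)
        with profile(activities=activities) as prof:
            model(example_input)
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # make sure queued kernels are captured
    # The exported JSON opens in Perfetto UI or chrome://tracing.
    prof.export_chrome_trace(trace_path)
    return prof


# Hypothetical stand-in model; in the post this would be the SAM predictor.
toy_model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
trace_file = os.path.join(tempfile.gettempdir(), "toy_trace.json")
prof = capture_kernel_trace(toy_model, torch.randn(8, 64), trace_file)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```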

kernel trace

We notice two areas ripe for optimization.

The first is long calls to aten::index, the underlying call resulting from a Tensor index operation (e.g., []). While the actual GPU time spent on aten::index is relatively low, aten::index launches two kernels with a blocking cudaStreamSynchronize in between. This means the CPU waits for the GPU to finish processing before it launches the second kernel. To optimize SAM, we should aim to remove blocking GPU syncs that cause idle time.

The second is significant time spent on GPU in matrix multiplication (dark green on stream 7 7 above). This is common in Transformers. We can significantly speed up SAM if we can reduce the amount of GPU time spent on matrix multiplication.

We can measure the throughput (img/s) and memory overhead (GiB) from out of the box SAM to establish a baseline:

throughput (img/s) and memory overhead (GiB) from out of the box SAM

Bfloat16 Half precision (+GPU syncs and batching)

To address the second issue, the time spent in matrix multiplication, we can turn to bfloat16. Bfloat16 is a commonly used half-precision type. By using less precision for each parameter and activation, we can save significant time and memory in computation. When reducing the precision of parameters, it’s critical to validate end-to-end model accuracy.


Shown here is an example of replacing padding dtypes with half precision, bfloat16. Code is here.

In addition to simply setting model.to(torch.bfloat16), we have to change a few small places that assume the default dtype.
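As an illustration of the kind of change meant here, the sketch below uses a hypothetical toy module (not SAM's code) whose padding tensor would otherwise default to float32. Taking dtype and device from the input keeps the whole forward pass in bfloat16. It also compares against a float32 copy of the same weights, a very small stand-in for the end-to-end accuracy validation mentioned above.

```python
import copy

import torch


class ToyPromptEncoder(torch.nn.Module):
    """Hypothetical module with a padding step, standing in for SAM's encoders."""

    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, points):
        # torch.zeros(...) without dtype would default to float32 and clash with
        # bfloat16 weights; take dtype and device from the input instead.
        pad = torch.zeros(points.shape[0], 1, points.shape[2],
                          dtype=points.dtype, device=points.device)
        return self.proj(torch.cat([points, pad], dim=1))


model_bf16 = ToyPromptEncoder(16).to(torch.bfloat16)
points = torch.randn(2, 3, 16).to(torch.bfloat16)
out = model_bf16(points)

# Compare against a float32 copy of the same (already bfloat16-rounded) weights.
model_fp32 = copy.deepcopy(model_bf16).float()
ref = model_fp32(points.float())
max_abs_err = (out.float() - ref).abs().max().item()
print(out.dtype, max_abs_err)
```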

Now, in order to remove GPU syncs, we need to audit the operations that cause them. We can find these pieces of code by searching the GPU traces for calls to cudaStreamSynchronize. We found two locations that we were able to rewrite to be sync-free.
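One way to do this search programmatically, sketched here under the assumption that the relevant CUDA runtime calls appear by name in the profiler's aggregated events, is to run a function under torch.profiler and count the events whose names contain cudaStreamSynchronize or cudaDeviceSynchronize. The index_with_cpu_coords function is a hypothetical reproduction of the CPU-index pattern described in the text, not SAM code.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def find_sync_calls(fn, *args, patterns=("cudaStreamSynchronize", "cudaDeviceSynchronize")):
    """Run fn(*args) under the profiler and count host-side CUDA sync calls by name."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        fn(*args)
    # key_averages() groups events by name; runtime API calls keep their CUDA names.
    return {evt.key: evt.count for evt in prof.key_averages()
            if any(p in evt.key for p in patterns)}


def index_with_cpu_coords(table):
    # Hypothetical pattern: indices built on the CPU, data possibly on the GPU,
    # so indexing has to copy the indices to the device first.
    coords = torch.arange(4)
    return table[coords]


device = "cuda" if torch.cuda.is_available() else "cpu"
table = torch.randn(8, 3, device=device)
print(find_sync_calls(index_with_cpu_coords, table))
```

On a CPU-only machine the result is simply empty; on a GPU, any matching entries point to the operations worth rewriting.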

code sample 1


Specifically, we see that within SAM’s image encoder there are variables acting as coordinate scalers, q_coords and k_coords. These are both allocated and processed on the CPU. However, once these variables are used to index into rel_pos_resized, the index operation automatically moves them to the GPU. This host-to-device copy causes the GPU sync we observed above. We also found a second call to index in SAM’s prompt encoder, which we can rewrite with torch.where as shown above.
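As a rough, hypothetical sketch of the two ideas (the actual rewrites are in the code sample above), the first function creates the coordinate tensors directly on the lookup table's device. Its arithmetic follows SAM's relative-position lookup but should be treated as a reconstruction. The second replaces boolean-mask assignment, which on CUDA needs a data-dependent nonzero() and therefore a sync, with a single torch.where. Function names and tensor shapes are assumptions.

```python
import torch


def get_rel_pos_sync_free(q_size, k_size, rel_pos_resized):
    """Gather relative-position embeddings with indices created on the table's device.

    rel_pos_resized: (2 * max(q_size, k_size) - 1, C) lookup table.
    """
    device = rel_pos_resized.device
    # device=... avoids the implicit host-to-device copy (and its sync) at index time.
    q_coords = torch.arange(q_size, device=device)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size, device=device)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos_resized[relative_coords.long()]


def replace_padding_points(point_embedding, labels, not_a_point_weight):
    """Use a learned embedding wherever labels == -1, via torch.where.

    point_embedding: (B, N, C); labels: (B, N); not_a_point_weight: (1, C).
    """
    is_padding = (labels == -1).unsqueeze(-1)  # (B, N, 1), broadcasts over C
    # Elementwise select instead of point_embedding[labels == -1] = ...,
    # whose boolean mask would go through nonzero() and index kernels.
    return torch.where(is_padding, not_a_point_weight, point_embedding)
```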

Kernel trace

After applying these changes, we begin to see significant time between individual kernel calls. This is typically observed with small batch sizes (1 here) due to the overhead of launching kernels. To get a closer look at practical areas for optimization, we can start to profile SAM inference with batch size 8:

profile SAM inference with batch size 8

Looking at the time spent per kernel, we observe that most of SAM’s GPU time is spent on elementwise kernels and the softmax operation. Matrix multiplications have now become a much smaller relative overhead.

matrix multiplications have become a much smaller relative overhead

Taking the GPU sync and bfloat16 optimizations together, we have now improved SAM performance by up to 3x.

SAM performance by up to 3x

Torch.compile (+graph breaks and CUDA graphs)

When observing a large number of small operations, such as the elementwise kernels profiled above, turning to a compiler to fuse operations can have strong benefits. PyTorch’s recently released torch.compile does a great job optimizing by:

  1. Fusing sequences of operations, such as nn.LayerNorm or nn.GELU, into a single GPU kernel, and
  2. Epilogues: fusing operations that immediately follow matrix multiplication kernels to reduce the number of GPU kernel calls.

Through these optimizations, we reduce the number of GPU global memory roundtrips, thus speeding up inference. We can now try torch.compile on SAM’s image encoder. To maximize performance we use a few advanced compile techniques such as:

  • Using torch.compile’s max-autotune mode, which enables CUDA graphs and shape-specific kernels with custom epilogues.
  • Setting TORCH_LOGS="graph_breaks,recompiles" so we can manually verify that we are not running into graph breaks or recompiles.
  • Padding the batch of images input to the encoder with zeros, so that compile sees static shapes and can always use shape-specific optimized kernels with custom epilogues without recompiling.

# use_compile holds the compile mode string, here "max-autotune"
predictor.model.image_encoder = torch.compile(
    predictor.model.image_encoder, mode=use_compile)
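The padding approach from the last bullet can be sketched as two small helpers. The helper names and the static batch size are assumptions; the post's own padding logic lives in the SAM-fast repository. Padding only preserves results when the encoder treats batch elements independently, which holds for inference without batch-dependent layers.

```python
import torch


def pad_batch(images, static_batch_size):
    """Zero-pad the batch dimension so a compiled encoder always sees one static shape."""
    n = images.shape[0]
    if n > static_batch_size:
        raise ValueError(f"batch of {n} exceeds static size {static_batch_size}")
    if n == static_batch_size:
        return images
    padding = images.new_zeros((static_batch_size - n, *images.shape[1:]))
    return torch.cat([images, padding], dim=0)


def encode_with_static_batch(encoder, images, static_batch_size):
    features = encoder(pad_batch(images, static_batch_size))
    return features[: images.shape[0]]  # drop the rows computed for padding
```

In use, the encoder would first be wrapped with torch.compile as in the snippet above, and running with TORCH_LOGS="graph_breaks,recompiles" should then report no recompiles as the real batch size varies. The padded rows cost extra compute, which is the price of avoiding recompilation.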

Kernel trace

Kernel trace

torch.compile is working beautifully. We launch a single CUDA graph, which makes up a significant portion of GPU time within the timed region. Let’s run our profile again and look at the percentage of GPU time spent in specific kernels:

the percentage of GPU time spent in specific kernels

We now see that softmax makes up a significant portion of the time, followed by various GEMM variants. In summary, we observe the following measurements for batch size 8 with the above changes.

measurements for batch size 8 and above

SDPA: scaled_dot_product_attention

Next up, we can tackle one of the most common sources of transformer performance overhead: the attention mechanism. Naive attention implementations scale quadratically in time and memory with sequence length. PyTorch’s scaled_dot_product_attention operation, built upon the principles of FlashAttention, FlashAttention-2, and xFormers’ memory-efficient attention, can significantly speed up attention on the GPU. Combined with torch.compile, this operation allows us to express and fuse a common pattern within variants of MultiheadAttention. After a small set of changes, we can adapt the model to use scaled_dot_product_attention.


PyTorch native attention implementation, see code here.
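To make the equivalence concrete, here is a minimal comparison between a naive attention written with explicit matrix multiplications and torch.nn.functional.scaled_dot_product_attention, which computes the same result without our having to write a fused kernel. The float attn_bias argument stands in for the relative positional bias SAM adds, and the shapes are arbitrary assumptions.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v, attn_bias=None):
    # Materializes the full (L, S) score matrix: memory grows quadratically with length.
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if attn_bias is not None:
        scores = scores + attn_bias
    return scores.softmax(dim=-1) @ v


def sdpa_attention(q, k, v, attn_bias=None):
    # Same math; PyTorch dispatches to a fused kernel (flash, memory-efficient, or the
    # math fallback) depending on device, dtype, and inputs. A float mask is added.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
```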

Kernel trace

We can now see that in particular the memory efficient attention kernel is taking up a large amount of computational time on the GPU:

memory efficient attention kernel is taking up a large amount of computational time on the GPU

Using PyTorch’s native scaled_dot_product_attention, we can significantly increase the batch size. We now observe the following measurements for batch size 32 with the above changes.

batch size 32 and above

Triton: Custom SDPA for fused relative positional encoding

Transitioning away from inference throughput for a moment, we started profiling overall SAM memory. Within the image encoder, we saw significant spikes in memory allocation:

spikes in memory allocation

Zooming in, we see this allocation happens within add_decomposed_rel_pos, on the following line:

we see this allocation happens within add_decomposed_rel_pos

The attn variable here is the addition of two smaller tensors: rel_h of shape (B, q_h, q_w, k_h, 1) and rel_w of shape (B, q_h, q_w, 1, k_w).
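The memory cost follows directly from broadcasting. For square key grids with k_h = k_w = n, the two inputs hold 2n values per query position, while their sum holds n^2, a factor of n/2 more. The helper below computes both sizes for given shapes; it makes no claim about SAM's actual dimensions.

```python
import torch


def decomposed_bias_bytes(B, q_h, q_w, k_h, k_w, dtype=torch.bfloat16):
    """Bytes for rel_h + rel_w versus the broadcast attention bias they produce."""
    elem = torch.empty((), dtype=dtype).element_size()
    small = B * q_h * q_w * (k_h + k_w) * elem  # rel_h (..., k_h, 1) and rel_w (..., 1, k_w)
    full = B * q_h * q_w * k_h * k_w * elem     # materialized (..., k_h, k_w) sum
    return small, full


# Broadcasting is what inflates the result: (..., k_h, 1) + (..., 1, k_w) -> (..., k_h, k_w).
rel_h = torch.randn(2, 4, 4, 4, 1)
rel_w = torch.randn(2, 4, 4, 1, 4)
attn_bias = rel_h + rel_w
print(attn_bias.shape)  # torch.Size([2, 4, 4, 4, 4])
```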

It’s not surprising that the memory-efficient attention kernel (used via SDPA) takes a long time with an attention bias larger than 3.0GiB. If, instead of allocating this large attn tensor, we pass the two smaller rel_h and rel_w tensors into SDPA and only construct attn as needed, we would anticipate a significant performance gain.

Unfortunately, this is not a trivial modification; SDPA kernels are highly optimized and written in CUDA. We can turn to Triton and its easy-to-follow tutorial on a FlashAttention implementation. After some significant digging, and in close collaboration with xFormers’ Daniel Haziza, we found one case of input shapes where it is relatively straightforward to implement a fused version of the kernel. The details have been added to the repository. Surprisingly, this can be done in under 350 lines of code for the inference case.
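The Triton kernel itself lives in the repository and is not reproduced here. As a plain-PyTorch illustration of the same idea (build the bias only where it is needed, never as one whole tensor), the hypothetical function below processes queries in chunks and materializes rel_h + rel_w only for the current chunk. It bounds peak bias memory by the chunk size but, unlike the fused kernel, still writes each chunk's bias to memory and calls SDPA once per chunk.

```python
import torch
import torch.nn.functional as F


def attention_with_rel_pos_chunked(q, k, v, rel_h, rel_w, chunk_size=256):
    """Attention whose bias rel_h + rel_w is built one query chunk at a time.

    q: (B, q_h*q_w, D); k, v: (B, k_h*k_w, D)
    rel_h: (B, q_h, q_w, k_h, 1); rel_w: (B, q_h, q_w, 1, k_w)
    """
    B, q_h, q_w, k_h, _ = rel_h.shape
    k_w = rel_w.shape[-1]
    num_q = q_h * q_w
    rel_h = rel_h.reshape(B, num_q, k_h, 1)
    rel_w = rel_w.reshape(B, num_q, 1, k_w)
    outputs = []
    for start in range(0, num_q, chunk_size):
        end = min(start + chunk_size, num_q)
        # Only a (B, chunk, k_h*k_w) slice of the bias exists at any moment.
        bias = (rel_h[:, start:end] + rel_w[:, start:end]).reshape(B, end - start, k_h * k_w)
        outputs.append(F.scaled_dot_product_attention(q[:, start:end], k, v, attn_mask=bias))
    return torch.cat(outputs, dim=1)
```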

This is a great example of extending PyTorch with a new kernel, straightforwardly built with Triton code.

Kernel trace

kernel trace

With our custom positional Triton kernel we observe the following measurements for batch size 32.

we observe the following measurements for batch size 32

NT: NestedTensor and batching predict_torch

We have spent a lot of time on the image encoder. This makes sense, since it takes up the most computational time. At this point, however, it is fairly well optimized, and improving the operator that takes the most time would require significant additional investment.

We made an interesting observation about the mask prediction pipeline: for each image, there is an associated size, coords, and fg_labels Tensor. Each of these tensors has a different batch size, and each image itself is also of a different size. This representation of data looks like jagged data. With PyTorch’s recently released NestedTensor, we can modify our data pipeline to batch the coords and fg_labels Tensors into a single NestedTensor. This can have significant performance benefits for the prompt encoder and mask decoder that follow the image encoder. Invoking:

torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)
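Here is a minimal sketch of this batching, with made-up prompt sizes standing in for per-image coords and fg_labels. It uses the default (strided) nested layout because that layout is available across more PyTorch versions; the post itself passes layout=torch.jagged.

```python
import torch

# Hypothetical prompts for three images with 3, 1, and 5 points each.
num_points = (3, 1, 5)
coords = [torch.rand(n, 2) for n in num_points]
fg_labels = [torch.ones(n) for n in num_points]

# One nested tensor per field instead of a Python list of differently sized tensors.
coords_nt = torch.nested.nested_tensor(coords, dtype=torch.float32)
labels_nt = torch.nested.nested_tensor(fg_labels, dtype=torch.float32)

# Components keep their own lengths; padding is only materialized on request.
padded_coords = torch.nested.to_padded_tensor(coords_nt, 0.0)
print(coords_nt.is_nested, padded_coords.shape)  # True torch.Size([3, 5, 2])
```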

Kernel trace

Kernel trace

we can launch kernels much faster from the CPU than the GPU can process

We can now see that the CPU launches kernels much faster than the GPU can process them, and that it spends a long time waiting at the end of our timed region for the GPU to finish (cudaDeviceSynchronize). We also no longer see any idle time (white space) between kernels on the GPU.

With Nested Tensor, we observe the following measurements for batch size 32 with the above changes.

batch size 32 and above changes

int8: quantization and approximating matmul

We notice in the above trace that significant time is now spent in GEMM kernels. We’ve optimized enough that matrix multiplication now accounts for more inference time than scaled dot product attention.

Building on earlier learnings going from fp32 to bfloat16, let’s go a step further and emulate even lower precision with int8 quantization. Among quantization methods, we focus on dynamic quantization, wherein our model observes the range of possible inputs and weights of a layer and subdivides the expressible int8 range to uniformly “spread out” the observed values. Ultimately, each float input is mapped to a single integer in the range [-128, 127]. For more information, see PyTorch’s tutorial on quantization.
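The arithmetic behind dynamic int8 quantization can be sketched in plain PyTorch: compute a scale from the observed range, round to int8, multiply the integers with int32 accumulation, and rescale the result. This is a hypothetical illustration, not the pytorch-labs/ao implementation. It assumes symmetric quantization with one scale per activation row and per weight output channel, and it emulates the int8 matmul on the CPU instead of fusing the steps with torch.compile.

```python
import torch


def quantize_symmetric(x, dim=-1):
    """Dynamic symmetric int8 quantization: one scale per row, reducing over `dim`."""
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale


def int8_linear(x, weight):
    """Approximate x @ weight.T from int8 operands with int32 accumulation.

    x: (M, K) activations, quantized per row at call time ("dynamic").
    weight: (N, K), quantized per output channel (per row of weight).
    """
    x_q, x_scale = quantize_symmetric(x)       # x_scale: (M, 1)
    w_q, w_scale = quantize_symmetric(weight)  # w_scale: (N, 1)
    # Integer matmul, emulated here by upcasting to int32 before multiplying.
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()  # (M, N)
    # Dequantize: undo the activation (row) and weight (column) scales.
    return acc.to(torch.float32) * x_scale * w_scale.t()
```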

Reducing precision can immediately lead to peak memory savings, but to realize inference speedups, we have to make full use of int8 throughout SAM’s operations. This requires building an efficient int8@int8 matrix multiplication kernel, as well as casting logic to translate from high to low precision (quantization) and back from low to high (dequantization). Utilizing the power of torch.compile, we can compile and fuse these quantization and dequantization routines into efficient single kernels and epilogues of our matrix multiplication. The resulting implementation is fairly short, at less than 250 lines of code. For more information on the APIs and usage, see pytorch-labs/ao.

While it’s common to see some accuracy regression when quantizing models at inference time, SAM has been particularly robust to lower-precision inference, with minimal loss of accuracy. With quantization added, we now observe the following measurements for batch size 32 with the above changes.

batch size 32 and above changes

sparse: Semi-structured (2:4) sparsity

Matrix multiplications are still our bottleneck. We can turn to the model acceleration playbook with another classic method to approximate matrix multiplication: sparsification. By sparsifying our matrices (i.e., zeroing out values), we could theoretically use fewer bits to store weight and activation tensors. The process by which we decide which weights in the tensor to set to zero is called pruning. The idea behind pruning is that small weights in a weight tensor contribute little to the net output of a layer, typically the product of weights with activations. Pruning away small weights can potentially reduce model size without significant loss of accuracy.

Methods for pruning vary, from completely unstructured, wherein weights are greedily pruned, to highly structured, wherein large sub-components of a tensor are pruned at a time. The choice of method is not trivial. While unstructured pruning may theoretically have the least impact on accuracy, GPUs are highly efficient at multiplying large, dense matrices and may suffer significant performance degradation in sparse regimes. One recent pruning method supported in PyTorch, called semi-structured (or 2:4) sparsity, seeks to strike a balance. This sparse storage reduces the original tensor by a significant 50%, while simultaneously producing a dense tensor output that can leverage highly performant 2:4 GPU kernels. See the following picture for an illustration.

dense tensor output that can leverage highly performant, 2:4 GPU kernels

From developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt
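The constraint shown in the figure can be written as a mask: within each group of four consecutive weights in a row, keep the two with the largest magnitude. The helper below is a standalone, magnitude-based sketch of that rule (the function names are assumptions); the post itself uses torch.ao.pruning's WeightNormSparsifier.

```python
import torch


def two_four_mask(weight):
    """Boolean mask keeping the 2 largest-magnitude values in every 1x4 block of each row."""
    rows, cols = weight.shape
    if cols % 4 != 0:
        raise ValueError("2:4 sparsity needs the number of columns to be divisible by 4")
    blocks = weight.abs().reshape(rows, cols // 4, 4)
    keep = blocks.topk(2, dim=-1).indices  # positions of the 2 largest values per block
    mask = torch.zeros_like(blocks).scatter_(-1, keep, 1.0).bool()
    return mask.reshape(rows, cols)


def prune_2_4(weight):
    # Zero the two smallest-magnitude weights in each group of four.
    return weight * two_four_mask(weight)
```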

In order to use this sparse storage format and the associated fast kernels we need to prune our weights such that they adhere to the constraints for the format. We pick the two smallest weights to prune in a 1 by 4 region, measuring the performance vs accuracy tradeoff. It is easy to change a weight from its default PyTorch (“strided”) layout to this new, semi-structured sparse layout. To implement apply_sparse(model) we only require 32 lines of Python code:

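import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Sparsity helper functions
def apply_fake_sparsity(model):
    """
    This function simulates 2:4 sparsity on all linear layers in a model.
    It uses the torch.ao.pruning flow.
    """
    # torch.ao.pruning flow
    from torch.ao.pruning import WeightNormSparsifier
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    # Prune 2 of every 4 weights (1x4 blocks) in all blocks (sparsity_level=1.0).
    sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                      sparse_block_shape=(1, 4),
                                      zeros_per_block=2)
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()         # compute the 2:4 masks
    sparsifier.squash_mask()  # make the masked weights permanent, remove parametrizations


def apply_sparse(model):
    # Convert each (already 2:4-pruned) linear weight to the semi-structured sparse layout.
    # Requires a CUDA GPU with 2:4 sparse support and half-precision or int8 weights.
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))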

With 2:4 sparsity, we observe peak performance on SAM with vit_b and batch size 32:



Wrapping up, we are excited to have announced our fastest implementation of Segment Anything to date. We rewrote Meta’s original SAM in pure PyTorch with no loss of accuracy using a breadth of newly released features:

  • Torch.compile: PyTorch’s native JIT compiler, providing fast, automated fusion of PyTorch operations [tutorial]
  • GPU quantization: Accelerate models with reduced precision operations [api]
  • Scaled Dot Product Attention (SDPA): A new, memory-efficient implementation of attention [tutorial]
  • Semi-Structured (2:4) Sparsity: Accelerate models with fewer bits to store weights and activations [tutorial]
  • Nested Tensor: Highly optimized ragged array handling for non-uniform batch and image sizes [tutorial]
  • Triton kernels: Custom GPU operations, easily built and optimized via Triton

For more details on how to reproduce the data presented in this blog post, check out the experiments folder of segment-anything-fast. Please don’t hesitate to contact us or open an issue if you run into any technical issues.

In our next post, we are excited to share similar performance gains with our PyTorch natively authored LLM!


We would like to thank Meta’s xFormers team including Daniel Haziza and Francisco Massa for authoring SDPA kernels and helping us design our custom one-off Triton kernel.


PyTorch compile to speed up inference on Llama 2


In this blog, we discuss how to improve the inference latencies of the Llama 2 family of models using PyTorch native optimizations such as native fast kernels, compile transformations from torch.compile, and tensor parallel for distributed inference. Our approach results in 29ms/token latency for single-user requests on the 70B Llama 2 model (as measured on 8 A100 GPUs). We are excited to share our findings with the community and make our code available here.


We are amid a generative AI revolution with large language models of tens of billions of parameters becoming commoditized and available for use. However, it is well recognized in the community that deploying these large models in a cost-efficient manner remains a key challenge. Many different approaches have been attempted with varying degrees of success and offering different trade-offs. Hardware-specific optimizations (e.g., Faster Transformer from NVIDIA) are restricted to specific target hardware whereas approaches that rely on layers of abstraction (e.g., ONNX) enable arbitrary models but suffer from loss of efficiency. With the introduction of PyTorch compile last year, IBM and the PyTorch team started exploring the use of model compilation for inference optimizations with the goal of reducing the latency per token for generative models.

Model Choice

We chose to benchmark on the Llama 2 family of models, given their popularity. The models that we are interested in, and their hyperparameters relevant for this blog, are given in the table below:

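Model size | Hidden dimension | Num heads | Num layers | Attention type
-----------|------------------|-----------|------------|---------------
7B         | 4096             | 32        | 32         | MHA
13B        | 5120             | 40        | 40         | MHA
70B        | 8192             | 64        | 80         | GQA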

These models are decoder only, which means that tokens get generated in a serialized manner, which is typically sped up using KV caching. We take a similar approach in our latency and throughput measurements.
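Decoder-only generation with a KV cache can be sketched in a few lines: at each step, only the newest token's key and value are computed and appended, and its query attends over everything cached so far. The class and function names and the preallocated-buffer design are assumptions for illustration, not the code used in these measurements.

```python
import torch
import torch.nn.functional as F


class KVCache:
    """Preallocated key/value buffers for one attention layer (batch, heads, max_len, head_dim)."""

    def __init__(self, batch, heads, max_len, head_dim, dtype=torch.float32):
        self.k = torch.zeros(batch, heads, max_len, head_dim, dtype=dtype)
        self.v = torch.zeros_like(self.k)
        self.length = 0

    def append(self, k_new, v_new):
        t = k_new.shape[2]
        self.k[:, :, self.length:self.length + t] = k_new
        self.v[:, :, self.length:self.length + t] = v_new
        self.length += t
        return self.k[:, :, :self.length], self.v[:, :, :self.length]


def decode_with_cache(q, k, v):
    """Produce attention outputs one token at a time, attending to all cached positions."""
    batch, heads, seq_len, head_dim = q.shape
    cache = KVCache(batch, heads, seq_len, head_dim, dtype=q.dtype)
    outputs = []
    for t in range(seq_len):
        # Only the new token's key/value are added; earlier ones are reused from the cache.
        k_all, v_all = cache.append(k[:, :, t:t + 1], v[:, :, t:t + 1])
        # The single query only sees past and current keys, so no causal mask is needed.
        outputs.append(F.scaled_dot_product_attention(q[:, :, t:t + 1], k_all, v_all))
    return torch.cat(outputs, dim=2)
```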

Inference Approach

Our goal for inference is to provide a path for achieving the best possible latencies rapidly, to keep up with the velocity with which new model architectures are emerging in the community. A PyTorch native approach is appealing as it allows for the maximum flexibility in terms of “coverage” of models. We note that there are four orthogonal techniques that provide acceleration in inference: (a) Kernel fusion using compile, (b) Faster kernels, (c) Tensor parallel for larger models, and (d) Quantization. In our approach, we use the first three of these four levers – compile natively working with faster kernels from SDPA and a custom tensor parallel implementation that all work hand-in-glove to achieve inference latencies of 29ms/token on a 70B model as measured on 8 NVIDIA A100 GPUs with single user.

Compile all the way!

PyTorch Compile leverages tracing and graph capture to reduce the CPU overhead and in an ideal scenario results in a single graph execution/instruction from CPU to GPU. However, often compile introduces graph breaks due to model architecture and ops unsupported by compile. For example, complex operations such as einops are not supported by compile today. Similarly, tensor parallel inference can introduce graph breaks at each layer, since compile requires the tensor parallel implementation to use traceable communication collectives. If these graph breaks are not removed, the performance of the compiled artifacts will be hampered and could even be lower compared to eager mode execution. To get full benefit of the compiled artifacts, the graph breaks need to be removed.

Below, we describe how we went about doing this for the 70B Llama 2 model and the challenges we had to overcome to get compile to work all the way through.

Our first attempt was to compile the out-of-the-box Llama 2 model with torch.compile, but it failed because complex ops were not supported. Using TORCH_COMPILE_DEBUG=1, we identified that the RoPE positional encodings were using complex-number functions, resulting in graph breaks and significant slowdowns. We rewrote the RoPE function to bypass torch.einsum (the original implementation uses torch.polar, which also conflicts with compile) and to use torch.cos and torch.sin instead.

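# One 2x2 rotation matrix [[cos, -sin], [sin, cos]] per (position, frequency) pair
self.cached_freqs[dev_idx][alpha] = torch.stack(
    [
        torch.cos(freqs),
        -torch.sin(freqs),
        torch.sin(freqs),
        torch.cos(freqs),
    ],
    dim=2,
).view(*freqs.shape, 2, 2)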

Our implementation of the frequencies computation

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

Hugging Face implementation of the frequencies computation
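Both styles compute the same rotation. The sketch below contrasts a complex-number formulation (viewing adjacent pairs as complex numbers and multiplying by torch.polar, the style the text says conflicted with compile) with an equivalent one that uses only torch.cos and torch.sin on real tensors. It uses the interleaved-pair convention; the Hugging Face code above instead concatenates two halves, a different but equivalent layout. Function names and the base of 10000 are assumptions.

```python
import torch


def rope_frequencies(seq_len, head_dim, base=10000.0):
    """Rotation angle for every (position, frequency pair): shape (seq_len, head_dim // 2)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    return torch.outer(positions, inv_freq)


def rope_complex(x, angles):
    """Complex formulation: treat adjacent pairs as a + ib and multiply by e^(i*angle)."""
    x_pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotation = torch.polar(torch.ones_like(angles), angles)
    return torch.view_as_real(x_pairs * rotation).flatten(-2).type_as(x)


def rope_real(x, angles):
    """Same rotation on real tensors only: (a, b) -> (a*cos - b*sin, a*sin + b*cos)."""
    a, b = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    return torch.stack((a * cos - b * sin, a * sin + b * cos), dim=-1).flatten(-2)
```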

Once RoPE was fixed, we were able to get 7B and 13B models to compile without ANY graph breaks on a single A100 GPU.

We used SDPA, the PyTorch native implementation of efficient attention computation, with tracing enabled (for compile). To avoid graph breaks caused by forcing a single algorithm choice with a Python context manager (the recommended way), we had to use the torch.backends.cuda.enable_*_sdp functions instead.

attn = torch.nn.functional.scaled_dot_product_attention(
    queries,  # query/key/value tensors: (batch, heads, seq_len, head_dim)
    keys,
    values,
    dropout_p=self.p_dropout if self.training else 0.0,
)

Attention computation using SDPA
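The text does not say which backend was forced. As a hedged sketch, selecting one algorithm with the global torch.backends.cuda switches, set once outside the compiled region rather than with a context manager inside forward, might look like this. The helper name and the flash-only default are assumptions.

```python
import torch


def force_sdpa_backend(flash=True, mem_efficient=False, math=False):
    """Set the global SDPA backend switches and return the previous settings."""
    previous = (torch.backends.cuda.flash_sdp_enabled(),
                torch.backends.cuda.mem_efficient_sdp_enabled(),
                torch.backends.cuda.math_sdp_enabled())
    torch.backends.cuda.enable_flash_sdp(flash)
    torch.backends.cuda.enable_mem_efficient_sdp(mem_efficient)
    torch.backends.cuda.enable_math_sdp(math)
    return previous
```

Calling it once before torch.compile leaves no context manager for the tracer to handle; the returned tuple can be passed back in to restore the earlier settings.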

Next, we ran the same steps for the larger 70B model and found that even with half precision, the model does not fit on a single GPU and requires tensor parallel inference. Using torch.compile for the 70B model resulted in 162 graph breaks due to two all-reduces per layer, one all-gather for forward embedding, and one all-gather for reverse embedding. Because of this, we saw no significant improvement in inference latencies. We could not use PyTorch’s distributed tensor implementation at the time of writing, as it did not support compile. We rewrote the tensor parallel code from scratch so that it only depends on traceable collectives, making it work with compile. After this last change, the PyTorch compiler did not introduce any graph breaks, and we saw a significant speedup in inference latencies. Specifically, we measured latencies for the Llama 70B model at 29ms/token when using 8 A100 GPUs, a 2.4x improvement over unoptimized inference.
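The 162 figure matches the table above: 80 layers x 2 all-reduces + 2 all-gathers = 162. To show where a per-layer all-reduce comes from, the sketch below simulates a Megatron-style tensor parallel MLP in a single process. The first weight is split by output features and the second by input features, and the per-rank partial results are summed. In a real run, that sum is the all-reduce collective that must be traceable for compile. This layout is an assumption about a typical implementation, not the authors' code.

```python
import torch
import torch.nn.functional as F


def mlp_reference(x, w_up, w_down):
    return F.gelu(x @ w_up.t()) @ w_down.t()


def mlp_tensor_parallel(x, w_up, w_down, world_size):
    """MLP split across `world_size` simulated ranks in one process.

    w_up: (hidden, d_model), split by rows (each rank owns some hidden units).
    w_down: (d_model, hidden), split by columns (the matching hidden units).
    """
    up_shards = w_up.chunk(world_size, dim=0)
    down_shards = w_down.chunk(world_size, dim=1)
    # Each rank computes a partial output from its own hidden units, with no communication.
    partials = [F.gelu(x @ wu.t()) @ wd.t() for wu, wd in zip(up_shards, down_shards)]
    # Summing the partials across ranks is the all-reduce; on real GPUs this would be
    # a traceable collective rather than a Python-side sum.
    return torch.stack(partials).sum(dim=0)
```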

Serving aspects

Finally, note that simply compiling a model is not sufficient to serve it in a production setting. To realize the above performance with high throughput, we need to support dynamic batching and nested tensors, as well as have a warm-up phase where we pre-compile for bucketized sequence lengths. We are working on these aspects to realize such performance in a production setting.
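Bucketizing can be sketched as padding each request up to the smallest of a fixed set of lengths, so that a warm-up phase only needs to compile one graph per bucket. The bucket boundaries, pad id, and helper names are hypothetical, and a real server would also need an attention mask so that padded positions are ignored.

```python
import bisect

import torch

# Hypothetical bucket boundaries; a server would warm up (compile) once per bucket.
BUCKETS = (128, 256, 512, 1024, 2048, 4096)


def bucket_for(length, buckets=BUCKETS):
    """Smallest bucket that can hold `length` tokens."""
    i = bisect.bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"sequence length {length} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(input_ids, pad_id=0, buckets=BUCKETS):
    """Right-pad (batch, seq) token ids to the smallest bucket that fits."""
    seq_len = input_ids.shape[1]
    target = bucket_for(seq_len, buckets)
    padded = input_ids.new_full((input_ids.shape[0], target), pad_id)
    padded[:, :seq_len] = input_ids
    return padded
```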

Experiments and Measurements

We use nodes with 8 NVIDIA A100 80GB GPUs for all our measurements, in two different environments (IBM Cloud and AWS, both running OpenShift). First, we compare the various techniques: eager mode, with the SDPA Flash kernel, with Compile, and with Compile and SDPA. For the 70B model, we run it in Tensor Parallel mode with compile and SDPA. For this experiment, we use 512 tokens as input length with 50-token generation. For the 7B and 13B models, we use a single A100 to measure latencies, whereas we use 8 A100s for the 70B model. In addition, for the 70B model we use the reduce-overhead option in PyTorch compile, which uses CUDA Graphs to reduce CPU-to-GPU kernel launch overheads; using CUDA Graphs with the 7B and 13B models did not show any benefit (so those results are not reported here). We observe from Figure 1 that compile and SDPA provide very low latencies, with the 70B Llama 2 model at 29ms/token.
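Per-token latency on a GPU has to be measured with a synchronize before each clock read, since kernel launches are asynchronous, and after a warm-up that absorbs compilation and CUDA graph capture. A hedged sketch of such a harness follows. generate_step is a hypothetical callable that produces one token, and the warm-up count is an assumption; this is not the authors' benchmarking methodology.

```python
import statistics
import time

import torch


def median_latency_per_token(generate_step, num_tokens=50, warmup=5):
    """Median wall-clock seconds per decode step over `num_tokens` steps."""
    def sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # GPU work is async; wait before reading the clock

    for _ in range(warmup):  # keep compilation / CUDA graph capture out of the timings
        generate_step()
    sync()
    timings = []
    for _ in range(num_tokens):
        start = time.perf_counter()
        generate_step()
        sync()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```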

Figure 1. Median latency across different techniques with sequence length 512 (measured on IBM Cloud A100 servers)


Next, we examine the impact of sequence length, where we increase it from 1024 to 4096 and observe that the median latency per token increases sub-linearly, demonstrating that when we increase context to large documents, we do not sacrifice response times.

Figure 2. Median latency for compile+SDPA with different sequence lengths (Measured on A100s on AWS)


Finally, with increased batch sizes, we observe that the response latencies increase sub-linearly. For the 13B model, at batch size 8, we encounter an OOM. For the 70B model, given that it is running on 8 GPUs with tensor parallel, we do not see any such OOM issues.

Figure 3. Median latency for compile+SDPA with different batch sizes and sequence length fixed at 4096 (Measured on A100s on AWS)


Final Thoughts

We have demonstrated how a PyTorch compile pathway for inference achieves ultra-low latencies for 70B model inference. The next steps are to enable dynamic batching and nested tensors with the above levers.

Special thanks to Edward Yang, Elias Ellison, Driss Guessous, Will Feng, Will Constable, Horace He, Less Wright, and Andrew Gu from Team PyTorch, whose PR reviews and code contributions made it possible for us to realize these latencies using a PyTorch native approach. We thank the broader Team PyTorch, who have been tirelessly working to make PyTorch better, with special shout-outs to the SDPA team for enabling tracing and compile on fast kernels, and to the compile team, which has been closely guiding us on how to work around as well as fix issues (including identifying and raising NVIDIA driver bugs in CUDA graphs).

Inference latency has been one of the roadblocks for LLM adoption in critical enterprise workflows, but another major one is the need for safety, trustworthiness, and governance. IBM’s guide for AI safety and LLM risk can be found here, and Meta’s Responsible Use Guide for Llama can be found here.



High-Performance Llama 2 Training and Inference with PyTorch/XLA on Cloud TPUs


In a landscape where AI innovation is accelerating at an unprecedented pace, Meta’s Llama family of open sourced large language models (LLMs) stands out as a notable breakthrough. Llama marked a significant step forward for LLMs, demonstrating the power of pre-trained architectures for a wide range of applications. Llama 2 further pushed the boundaries of scale and capabilities, inspiring advancements in language understanding, generation, and beyond.

Shortly after the announcement of Llama, we published a blog post showcasing ultra-low inference latency for Llama using PyTorch/XLA on Cloud TPU v4. Building on these results, today, we are proud to share Llama 2 training and inference performance using PyTorch/XLA on Cloud TPU v4 and our newest AI supercomputer, Cloud TPU v5e.

In this blog post, we use Llama 2 as an example model to demonstrate the power of PyTorch/XLA on Cloud TPUs for LLM training and inference. We discuss the computation techniques and optimizations used to improve inference throughput and training model FLOPs utilization (MFU). For Llama 2 70B, we deliver 53% training MFU, 17 ms/token inference latency, and 42 tokens/s/chip throughput, powered by PyTorch/XLA on Google Cloud TPU. We offer a training user guide and an inference user guide for reproducing the results in this article. Additionally, you may find our Google Next 2023 presentation here.

Model Overview

Llama 2 comes in various sizes, ranging from 7B to 70B parameters, catering to different needs, computational resources, and training / inference budgets. Whether it’s small-scale projects or large-scale deployments, Llama models offer versatility and scalability to accommodate a wide range of applications.

Llama 2 is an auto-regressive language model that uses an optimized transformer architecture. The largest model, 70B, uses grouped-query attention, which speeds up inference without sacrificing quality. Llama 2 is trained on 2 trillion tokens (40% more data than Llama) and has a context length of 4,096 tokens for inference (double the context length of Llama), which enables greater accuracy, fluency, and creativity.

Llama 2 is a state-of-the-art LLM that outperforms many other open source language models on many benchmarks, including reasoning, coding, proficiency, and knowledge tests. The model’s scale and complexity place many demands on AI accelerators, making it an ideal benchmark for LLM training and inference performance of PyTorch/XLA on Cloud TPUs.

Performance Challenge of LLMs

Large-scale distributed training for LLMs such as Llama 2 introduces technical challenges that require practical solutions to make the most efficient use of TPUs. Llama’s size can strain both memory and processing resources of TPUs. To address this, we use model sharding, which involves breaking down the model into smaller segments, each fitting within the capacity of a single TPU core. This enables parallelism across multiple TPUs, improving training speed while reducing communication overhead.

Another challenge is managing the large datasets required for training Llama 2 efficiently, which requires effective data distribution and synchronization methods. Additionally, optimizing factors like learning rate schedules, gradient aggregation, and weight synchronization across distributed TPUs is crucial for achieving convergence.

After pretraining or fine-tuning Llama 2, running inference on the model checkpoint creates additional technical challenges. All of the challenges discussed in our previous blog post, such as autoregressive decoding, variable input prompt lengths, and the need for model sharding and quantization, still apply to Llama 2. In addition, Llama 2 introduced two new capabilities: grouped-query attention and early stopping. We discuss how PyTorch/XLA handles these challenges to enable high-performance, cost-efficient training and inference of Llama 2 on Cloud TPU v4 and v5e.

Large-Scale Distributed Training

PyTorch/XLA offers two major ways of doing large-scale distributed training: SPMD, which utilizes the XLA compiler to transform and partition a single-device program into a multi-device distributed program; and FSDP, which implements the widely-adopted Fully Sharded Data Parallel algorithm.

In this blog post, we show how to use the SPMD API to annotate the HuggingFace (HF) Llama 2 implementation to maximize performance. For comparison, we also show our FSDP results with the same configurations; read about PyTorch/XLA FSDP API here.

SPMD Overview

Let’s briefly review the fundamentals of SPMD. For details, please refer to our blog post and user guide.


Mesh

A multidimensional array that describes the logical topology of the TPU devices:

import numpy as np
import torch_xla.runtime as xr
from torch_xla.experimental.xla_sharding import Mesh

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

Partition Spec

A tuple that describes how the corresponding tensor’s dimensions are sharded across the mesh:

partition_spec = ('x', 'y')

Mark Sharding

An API that takes a mesh and a partition_spec, and then generates a sharding annotation for the XLA compiler.

import torch
import torch_xla.experimental.xla_sharding as xs

tensor = torch.randn(4, 4).to('xla')
# Let's reuse the above mesh and partition_spec.
# The tensor's dim 0 is sharded 4-way and dim 1 is sharded 2-way.
xs.mark_sharding(tensor, mesh, partition_spec)
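To make the effect of a partition spec concrete, the sketch below simulates in NumPy which slice of a tensor each device of the (4, 2) mesh would hold. It is not the torch_xla API: simulate_shards is a hypothetical helper, it assumes every sharded dimension divides evenly, and it numbers devices row-major, like np.array(range(num_devices)) reshaped to mesh_shape. In practice the XLA compiler performs this placement.

```python
import numpy as np


def simulate_shards(tensor, mesh_shape, axis_names, partition_spec):
    """Return {device_id: shard} for `tensor` laid out over a logical device mesh.

    Hypothetical helper for illustration. partition_spec[i] is the mesh axis
    name that splits tensor dim i, or None to replicate that dim.
    """
    device_grid = np.arange(int(np.prod(mesh_shape))).reshape(mesh_shape)
    shards = {}
    for coords in np.ndindex(*mesh_shape):
        index = []
        for dim, axis in enumerate(partition_spec):
            if axis is None:
                index.append(slice(None))  # this dim is replicated
                continue
            mesh_axis = axis_names.index(axis)
            size = tensor.shape[dim] // mesh_shape[mesh_axis]  # assumes an even split
            start = coords[mesh_axis] * size
            index.append(slice(start, start + size))
        shards[int(device_grid[coords])] = tensor[tuple(index)]
    return shards


tensor = np.arange(16).reshape(4, 4)
shards = simulate_shards(tensor, (4, 2), ("x", "y"), ("x", "y"))
print(shards[0])  # [[0 1]]
print(shards[1])  # [[2 3]]
```

With partition_spec = ('x', 'y'), each of the 8 devices holds a 1x2 tile. Replacing 'y' with None replicates dim 1, so devices that differ only in their 'y' coordinate hold identical shards.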

2D Sharding with SPMD

In our SPMD blog post, we demonstrated using 1D FSDP-style sharding. Here, we introduce a more powerful sharding strategy, called 2D sharding, where both the parameters and activations are sharded. This new sharding strategy not only allows fitting a larger model but also boosts MFU to up to 54.3%. For more details, read the Benchmarks section.

This section introduces a set of general rules that apply to most LLMs, and for convenience we directly reference the variable names and configuration names from HF Llama.

First, let’s create a 2D Mesh with corresponding axis names: data and model. The data axis is usually where we distribute the input data, and the model axis is where we further distribute the model.

mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

The mesh_shape can be a hyper-parameter that is tuned for different model sizes and hardware configurations. The same mesh will be reused in all following sharding annotations. In the next few sections, we will cover how to use the mesh to shard parameters, activations and input data.

Parameter Sharding

Below is a table that summarizes all parameters of HF Llama 2 and corresponding partition specifications. Example HF code can be found here.

| Parameter Name | Explanation | Parameter Shape | Partition Spec |
| --- | --- | --- | --- |
| embed_tokens | embedding layer | (vocab_size, hidden_size) | (model, data) |
| q_proj | attention weights | (num_heads x head_dim, hidden_size) | (data, model) |
| k_proj / v_proj | attention weights | (num_key_value_heads x head_dim, hidden_size) | (data, model) |
| o_proj | attention weights | (hidden_size, num_heads x head_dim) | (model, data) |
| gate_proj / up_proj | MLP weights | (intermediate_size, hidden_size) | (model, data) |
| down_proj | MLP weights | (hidden_size, intermediate_size) | (data, model) |
| lm_head | HF output embedding | (vocab_size, hidden_size) | (model, data) |

Table 1: SPMD 2D Sharding Parameter Partition Spec

The rule is to shard the hidden_size dim of any weights except QKVO projections according to the data axis of the mesh, then shard the other dim with the remaining model axis. For QKVO, do the opposite. This model-data axis rotation methodology is similar to that of Megatron-LM to reduce communication overhead. For layernorm weights, we implicitly mark them as replicated across different devices given they are 1D tensors.
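Applying Table 1 to a concrete mesh shows how large each device's shard ends up. The sketch below uses the (32, 4) ('data', 'model') mesh that Table 3 lists for 70B, and Llama 2 70B sizes taken from our reading of the public Hugging Face config (vocab 32,000, hidden 8,192, intermediate 28,672, 64 query heads and 8 KV heads of dimension 128); treat those sizes, and the helper names, as assumptions. It also checks that every row follows the hidden_size rule stated above.

```python
# Assumed Llama 2 70B sizes (our reading of the public HF config; illustrative).
DIMS = {
    "vocab_size": 32000,
    "hidden_size": 8192,
    "intermediate_size": 28672,
    "num_heads x head_dim": 64 * 128,
    "num_key_value_heads x head_dim": 8 * 128,
}
# The ('data', 'model') mesh listed for 70B in Table 3.
MESH = {"data": 32, "model": 4}

# (dimension names, partition spec) for each weight, copied from Table 1.
PARAMS = {
    "embed_tokens": (("vocab_size", "hidden_size"), ("model", "data")),
    "q_proj": (("num_heads x head_dim", "hidden_size"), ("data", "model")),
    "k_proj": (("num_key_value_heads x head_dim", "hidden_size"), ("data", "model")),
    "o_proj": (("hidden_size", "num_heads x head_dim"), ("model", "data")),
    "gate_proj": (("intermediate_size", "hidden_size"), ("model", "data")),
    "down_proj": (("hidden_size", "intermediate_size"), ("data", "model")),
    "lm_head": (("vocab_size", "hidden_size"), ("model", "data")),
}
QKVO = {"q_proj", "k_proj", "v_proj", "o_proj"}


def follows_rule(name, dim_names, spec):
    """hidden_size goes on 'data', except for QKVO weights, where it goes on 'model'."""
    expected = "model" if name in QKVO else "data"
    return spec[dim_names.index("hidden_size")] == expected


def shard_shape(dim_names, spec):
    """Per-device shard shape when each dim is split evenly along its mesh axis."""
    shape = []
    for dim_name, axis in zip(dim_names, spec):
        size, parts = DIMS[dim_name], MESH[axis]
        if size % parts:
            raise ValueError(f"{dim_name}={size} is not divisible by {axis}={parts}")
        shape.append(size // parts)
    return tuple(shape)


for name, (dim_names, spec) in PARAMS.items():
    assert follows_rule(name, dim_names, spec), name
    print(f"{name:<13} {shard_shape(dim_names, spec)}")
```

For example, q_proj shrinks from (8192, 8192) to a (256, 2048) shard per device, while k_proj, which has only 8 KV heads, becomes (32, 2048).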

Activation Sharding

In order to better utilize device memory, we often need to annotate the outputs of some memory-bound ops. That way, the compiler is forced to keep only a partial output on each device instead of the full output. In Llama 2, we explicitly annotate all torch.matmul and nn.Linear outputs. Table 2 summarizes the corresponding annotations; the example HF code can be found here.

| Output Name | Explanation | Output Shape | Partition Spec |
| --- | --- | --- | --- |
| inputs_embeds | embedding layer output | (batch_size, sequence_length, hidden_size) | (data, None, model) |
| query_states | attention nn.Linear output | (batch_size, sequence_length, num_heads x head_dim) | (data, None, model) |
| key_states / value_states | attention nn.Linear output | (batch_size, sequence_length, num_key_value_heads x head_dim) | (data, None, model) |
| attn_weights | attention weights | (batch_size, num_attention_heads, sequence_length, sequence_length) | (data, model, None, None) |
| attn_output | attention layer output | (batch_size, sequence_length, hidden_size) | (data, None, model) |
| up_proj / gate_proj | MLP nn.Linear outputs | (batch_size, sequence_length, intermediate_size) | (data, None, model) |
| down_proj | MLP nn.Linear output | (batch_size, sequence_length, hidden_size) | (data, None, model) |
| logits | HF output embedding output | (batch_size, sequence_length, vocab_size) | (data, None, model) |

Table 2: SPMD 2D Sharding Activation Partition Spec

The rule is to shard the batch_size dim of any outputs according to the data axis of the mesh, then replicate the length dims of any outputs, and finally shard the last dim along the model axis.

Input Sharding

For input sharding, the rule is to shard the batch dim along the data axis of the mesh, and replicate the sequence_length dim. Below is the example code, and the corresponding HF change may be found here.

import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs

partition_spec = ('data', None)
sharding_spec = xs.ShardingSpec(mesh, partition_spec)
# MpDeviceLoader will shard the input data before sending to the device.
pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, ...)

Now, all the data and model tensors that require sharding are covered!

Optimizer States & Gradients

You may be wondering whether it is necessary to shard the optimizer states and gradients as well. Great news: the sharding propagation feature of the XLA compiler automates the sharding annotation in these two scenarios, without needing more hints to improve performance.

It is important to note that optimizer states are typically initialized within the first iteration of the training loop. From the standpoint of the XLA compiler, the optimizer states are the outputs of the first graph, and therefore have the sharding annotation propagated. For subsequent iterations, the optimizer states become inputs to the second graph, with the sharding annotation propagated from the first one. This is also why PyTorch/XLA typically produces two graphs for the training loops. If the optimizer states are somehow initialized before the first iteration, users will have to manually annotate them, just like the model weights.

Again, all concrete examples of the above sharding annotation can be found in our fork of HF Transformers here. The repo also contains code for our experimental feature MultiSlice, including HybridMesh and dcn axis, which follows the same principles mentioned above.


While using SPMD for training, there are a few important things to pay attention to:

  • Use torch.einsum instead of torch.matmul; torch.matmul usually flattens tensors and does a torch.mm at the end, and that’s bad for SPMD when the combined axes are sharded. The XLA compiler will have a hard time determining how to propagate the sharding.
  • PyTorch/XLA provides patched [nn.Linear](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/xla_sharding.py#L570) to overcome the above constraint:
import torch_xla.experimental.xla_sharding as xs
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)
  • Always reuse the same mesh across all shardings.
  • Always specify --dataloader_drop_last yes. The last, smaller batch is hard to annotate.
  • Large models which are initialized on the host can induce host-side OOM. One way to avoid this issue is to initialize parameters on the meta device, then create and shard real tensors layer-by-layer.
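Two of these tips can be tried in plain PyTorch on CPU. The first part checks that an einsum with every axis spelled out computes the same nn.Linear-style projection as torch.matmul; the second builds a small model on the meta device and then materializes it one layer at a time. The toy shapes are hypothetical, and the TPU-specific sharding step is only indicated by a comment, since it depends on your mesh and partition specs.

```python
import torch
import torch.nn as nn

# 1) einsum vs. matmul for a Linear-style projection.
x = torch.randn(2, 3, 4)  # (batch, seq, hidden)
w = torch.randn(5, 4)     # (out_features, in_features), the nn.Linear layout
# Per the tip above, matmul usually flattens (batch, seq) and ends in a 2-D mm;
# the einsum keeps every axis named and separate.
y_matmul = torch.matmul(x, w.t())
y_einsum = torch.einsum("bsh,oh->bso", x, w)
assert torch.allclose(y_matmul, y_einsum, atol=1e-5)

# 2) Meta-device initialization: only shapes are recorded, no host memory is used.
with torch.device("meta"):
    model = nn.Sequential(*[nn.Linear(4, 4) for _ in range(3)])

# Materialize one layer at a time so peak host memory stays at one layer.
for layer in model:
    layer.to_empty(device="cpu")  # allocate real (uninitialized) storage
    layer.reset_parameters()      # initialize this layer's weights
    # On TPU, this is where the layer's weights would be moved to the device
    # and sharded before materializing the next layer.
```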

Infrastructure Improvements

Besides the above modeling techniques, we have developed additional features and improvements to maximize performance, including:

  • We enable asynchronous collective communication. This requires enhancements to the XLA compiler’s latency hiding scheduler to better optimize for the Llama 2 PyTorch code.
  • We now allow sharding annotations in the middle of the IR graph, just like JAX’s jax.lax.with_sharding_constraint. Previously, only graph inputs were annotated.
  • We also propagate the replicated sharding spec from the compiler to the graph outputs. This allows us to shard the optimizer states automatically.

Inference Optimizations

All the PyTorch/XLA optimizations implemented for Llama inference are applied to Llama 2 as well. These include tensor parallelism with Dynamo (torch.compile) using torch-xla collective ops, improved autoregressive decoding logic that avoids recompilation, bucketized prompt lengths, and a KV-cache with compilation-friendly index ops. Llama 2 introduces two new changes: Grouped Query Attention, and early stopping when the EOS token has been reached for all prompts. We applied corresponding changes to promote better performance and flexibility with PyTorch/XLA.
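Bucketized prompt lengths keep the number of distinct input shapes small, so a shape-specialized compiled graph can be reused instead of recompiled for every new prompt length. The sketch below shows the idea in plain Python; the bucket boundaries, right padding, and pad_id are assumptions for illustration, not the values used in the PyTorch/XLA Llama code.

```python
import bisect

# Hypothetical bucket boundaries, for illustration only.
PROMPT_BUCKETS = (128, 256, 512, 1024, 2048)


def bucket_for(prompt_len):
    """Smallest bucket that can hold prompt_len tokens."""
    i = bisect.bisect_left(PROMPT_BUCKETS, prompt_len)
    if i == len(PROMPT_BUCKETS):
        raise ValueError(f"prompt of {prompt_len} tokens exceeds the largest bucket")
    return PROMPT_BUCKETS[i]


def pad_to_bucket(token_ids, pad_id=0):
    """Right-pad token_ids to its bucket; also return an attention mask (1 = real token)."""
    target = bucket_for(len(token_ids))
    n_pad = target - len(token_ids)
    return token_ids + [pad_id] * n_pad, [1] * len(token_ids) + [0] * n_pad


# Prompts of 40, 100 and 128 tokens share the 128-token shape, so they reuse
# one compiled graph; a 129-token prompt moves to the 256 bucket.
for n in (40, 100, 128, 129):
    print(n, "->", bucket_for(n))  # 40 -> 128, 100 -> 128, 128 -> 128, 129 -> 256
```

With five buckets, at most five prompt shapes ever reach the compiler; the attention mask keeps the padded positions from affecting the result.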

Grouped Query Attention

Llama 2 enables Grouped Query Attention for the 70B models. It allows the number of Key and Value heads to be smaller than the number of Query heads, while still supporting KV-cache sharding up to the number of KV heads. For the 70B models, n_kv_heads is 8, which limits the degree of tensor parallelism to at most 8. To shard the model checkpoint to run on more devices, the K and V projection weights must first be replicated and then split into multiple pieces. For example, to shard the 70B model checkpoint from 8 pieces to 16 pieces, the K and V projection weights are duplicated and split into 2 pieces for each shard. We provide a reshard_checkpoints.py script to handle this and to make sure the resharded checkpoint is mathematically identical to the original checkpoint.
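Why replicating the K and V weights leaves the model unchanged can be checked numerically on toy sizes. This is a sketch in plain PyTorch with hypothetical dimensions and a hypothetical replicate_kv_weight helper; it is not the reshard_checkpoints.py script. It duplicates each KV head's rows in the K/V projection weights (2 KV heads become 4) and confirms that the attention output is unchanged, because every query head still attends with the same key/value head.

```python
import torch

torch.manual_seed(0)

# Toy sizes (hypothetical): 8 query heads share 2 KV heads.
hidden, head_dim, n_heads, n_kv_heads, seq = 32, 4, 8, 2, 5


def gqa_attention(x, wq, wk, wv, n_kv):
    """Grouped-query attention for one sequence x of shape (seq, hidden)."""
    q = (x @ wq.T).view(seq, n_heads, head_dim).transpose(0, 1)  # (n_heads, seq, d)
    k = (x @ wk.T).view(seq, n_kv, head_dim).transpose(0, 1)     # (n_kv, seq, d)
    v = (x @ wv.T).view(seq, n_kv, head_dim).transpose(0, 1)
    n_rep = n_heads // n_kv
    k = k.repeat_interleave(n_rep, dim=0)  # each KV head serves n_rep query heads
    v = v.repeat_interleave(n_rep, dim=0)
    scores = q @ k.transpose(-1, -2) / head_dim ** 0.5
    return torch.softmax(scores, dim=-1) @ v


def replicate_kv_weight(w, n_kv, factor):
    """Duplicate each KV head's rows `factor` times, preserving head order."""
    per_head = w.view(n_kv, head_dim, -1)
    return per_head.repeat_interleave(factor, dim=0).reshape(n_kv * factor * head_dim, -1)


# float64 keeps rounding noise far below the comparison tolerance.
x = torch.randn(seq, hidden, dtype=torch.float64)
wq = torch.randn(n_heads * head_dim, hidden, dtype=torch.float64)
wk = torch.randn(n_kv_heads * head_dim, hidden, dtype=torch.float64)
wv = torch.randn(n_kv_heads * head_dim, hidden, dtype=torch.float64)

# Replicate 2 KV heads -> 4 so the K/V projections can be split across 4 shards.
wk4 = replicate_kv_weight(wk, n_kv_heads, 2)
wv4 = replicate_kv_weight(wv, n_kv_heads, 2)
out_orig = gqa_attention(x, wq, wk, wv, n_kv_heads)
out_repl = gqa_attention(x, wq, wk4, wv4, 2 * n_kv_heads)
print(torch.allclose(out_orig, out_repl))  # True

shards = torch.chunk(wk4, 4, dim=0)  # one KV head (head_dim rows) per shard
```

The duplicated heads are ordered [h0, h0, h1, h1], so repeating each of them twice yields the same per-query-head assignment as repeating [h0, h1] four times.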

EOS Early Stopping

The Llama 2 generation code added early stopping logic. An eos_reached tensor is used to track the completion of all the prompt generations, and if the EOS token has been reached for all the prompts in the batch, generation stops early. A similar change is incorporated in the PyTorch/XLA optimized version as well, with some minor tweaks.

In PyTorch/XLA, checking the value of a tensor like eos_reached as part of a control flow condition invokes a blocking device-to-host transfer: the tensor is transferred from device memory to CPU memory to evaluate its value, while all other work waits. This introduced a delay on the order of milliseconds after every new token generation. As a trade-off, we reduce the rate of checking the eos_reached value to once every 10 new tokens. With this change, the cost of the blocking device-to-host transfer is reduced by 10x, while early stopping remains effective: at most 9 unnecessary tokens are generated after the last sequence in the batch reaches the EOS token.
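The reduced-frequency check can be sketched as a decoding loop in which eos_reached is updated on every step without leaving the device, but is read back to the host (.item()) only every 10 steps. The next-token function below is a hypothetical stand-in for the model and EOS_ID is an assumed token id; on PyTorch/XLA, the .item() call is where the blocking device-to-host transfer would happen.

```python
import torch

EOS_ID = 2        # hypothetical EOS token id
CHECK_EVERY = 10  # host-side check interval, as described above


def generate(next_token_fn, batch_size, max_new_tokens):
    eos_reached = torch.zeros(batch_size, dtype=torch.bool)
    tokens = []
    for step in range(max_new_tokens):
        tok = next_token_fn(step)     # (batch_size,) token ids for this step
        tokens.append(tok)
        eos_reached |= tok == EOS_ID  # tensor op only: no host sync here
        # .item() forces a device-to-host transfer, so do it every CHECK_EVERY steps.
        if (step + 1) % CHECK_EVERY == 0 and eos_reached.all().item():
            break
    return torch.stack(tokens, dim=1)  # (batch_size, generated_len)


# Toy decoder (hypothetical): every sequence emits EOS at step 12, token 5 otherwise.
def fake_next_token(step, batch_size=4):
    return torch.full((batch_size,), EOS_ID if step == 12 else 5)


out = generate(fake_next_token, batch_size=4, max_new_tokens=100)
print(out.shape)  # torch.Size([4, 20])
```

Here every sequence finishes at step 12, the check after step 19 stops generation, and 7 extra tokens are produced, within the bound of 9 stated above.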

Model Serving

PyTorch/XLA is working on a serving strategy to enable the PyTorch community to serve their deep learning applications via torch.export, StableHLO, and SavedModel. PyTorch/XLA Serving is an experimental feature in the PyTorch/XLA 2.1 release; for details, visit our serving user guide. Users can take advantage of TorchServe to run their single-host workloads.



Benchmarks

To measure training performance, we use the industry-standard metric: Model FLOPs Utilization (MFU). Model FLOPs are the floating point operations required to perform a single forward and backward pass. Model FLOPs are hardware- and implementation-independent and depend only on the underlying model. MFU measures how effectively the model uses the actual hardware during training: it is the ratio of the model FLOPs achieved per second to the peak FLOPs per second of the hardware. Achieving 100% MFU means that the model is using the hardware perfectly.
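As a worked example of the metric, the sketch below uses the widely used approximation of 6 × N model FLOPs per training token for an N-parameter model (roughly 2N for the forward pass and 4N for the backward pass). That approximation is our assumption; the post does not say exactly how it counts model FLOPs, and the inputs are hypothetical rather than measurements.

```python
def training_mfu(n_params, tokens_per_second, peak_flops_per_second):
    """Approximate MFU using the common 6 * N FLOPs-per-token rule of thumb.

    This ignores attention-score FLOPs, so it slightly undercounts the model
    FLOPs for long sequences.
    """
    model_flops_per_second = 6 * n_params * tokens_per_second
    return model_flops_per_second / peak_flops_per_second


# Hypothetical: a 1B-parameter model training at 1,000 tokens/s on hardware
# with a peak of 1e13 FLOP/s.
print(training_mfu(1e9, 1_000, 1e13))  # 0.6, i.e. 60% MFU
```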

To measure inference performance, we use the industry-standard metric of throughput. First, we measure latency per token once the model has been compiled and loaded. Then, we calculate per-chip throughput by dividing the batch size (BS) by the per-token latency and by the number of chips. As a result, throughput measures how the model performs in production environments regardless of how many chips are used.
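Under our reading of that definition, per-chip throughput is the batch size divided by the per-token latency, divided by the number of chips: each decoding step emits one token per sequence in the batch. The numbers below are hypothetical and only illustrate the arithmetic.

```python
def per_chip_throughput(batch_size, latency_per_token_s, num_chips):
    """Tokens generated per second per chip.

    One decoding step produces one token for every sequence in the batch, so
    the whole slice produces batch_size / latency tokens per second.
    """
    return batch_size / latency_per_token_s / num_chips


# Hypothetical: batch size 8, 20 ms per decoding step, 4 chips.
print(per_chip_throughput(8, 0.020, 4))  # 100.0
```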


Training Evaluation

Figure 1 shows Llama 2 SPMD 2D sharding training results on a range of Google TPU v4 hardware with PyTorch/XLA FSDP as the baseline. We increased MFU by 28% across all sizes of Llama 2 compared to FSDP running on the same hardware configuration. This performance improvement is largely due to two factors: 1) 2D sharding has less communication overhead than FSDP, and 2) asynchronous collective communication is enabled in SPMD, which allows communication and computation to overlap. Also note that we maintain high MFU as the model size scales. Table 3 shows all the hardware configurations plus some hyperparameters used in the training benchmarks.

Fig. 1: Llama 2 Training MFU on TPU v4 Hardware

The results in Figure 1 are produced with sequence length 1,024. Figure 2 shows how performance behaves with larger sequence lengths: it also scales linearly with sequence length. MFU is expected to decrease slightly, because a smaller per-device batch size is needed to accommodate the additional memory pressure introduced by the longer sequence (the sequence length axis is not sharded in 2D sharding), and TPU performance is very sensitive to batch size. For Llama 2 70B, the performance decrease is as low as 4%. At the time of preparing these results, the Hugging Face Llama 2 tokenizer limited the maximum model input to 2,048 tokens, preventing us from evaluating longer sequence lengths.

Fig. 2: Llama 2 SPMD Training MFU on TPU v4 with Different Sequence Lengths

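| Model Size | 7B | 7B | 13B | 13B | 70B | 70B |
| --- | --- | --- | --- | --- | --- | --- |
| TPU NumCores | V4-32 | V4-32 | V4-64 | V4-64 | V4-256 | V4-256 |
| Mesh Shape | (16, 1) | (16, 1) | (32, 1) | (32, 1) | (32, 4) | (32, 4) |
| Seq Len | 1,024 | 2,048 | 1,024 | 2,048 | 1,024 | 2,048 |
| Global Batch | 256 | 128 | 256 | 128 | 512 | 256 |
| Per Device Batch | 16 | 8 | 8 | 4 | 16 | 8 |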

Table 3: Llama 2 SPMD Training Benchmark TPU Configurations and Hyperparameters
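Table 3 is consistent with the input-sharding rule: the batch dim is sharded only along the 'data' mesh axis, so the per-device batch is the global batch divided by the first mesh dimension. The snippet below transcribes the table and checks this; it also shows that halving the global batch when doubling the sequence length keeps the number of tokens per training step constant for each model size.

```python
# Rows of Table 3: (model, seq_len, mesh_shape, global_batch, per_device_batch).
TABLE_3 = [
    ("7B", 1024, (16, 1), 256, 16),
    ("7B", 2048, (16, 1), 128, 8),
    ("13B", 1024, (32, 1), 256, 8),
    ("13B", 2048, (32, 1), 128, 4),
    ("70B", 1024, (32, 4), 512, 16),
    ("70B", 2048, (32, 4), 256, 8),
]

tokens_per_step = {}
for name, seq_len, (data, model_par), global_batch, per_device in TABLE_3:
    # The batch is split only across the 'data' axis; devices along the
    # 'model' axis share the same examples.
    assert global_batch // data == per_device, (name, seq_len)
    tokens_per_step.setdefault(name, set()).add(global_batch * seq_len)
    print(f"{name:>3} seq={seq_len}: {data}x{model_par} mesh, "
          f"{global_batch} / {data} = {per_device} per device")

print(tokens_per_step)  # {'7B': {262144}, '13B': {262144}, '70B': {524288}}
```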

One last thing to call out is that we use Adafactor as the optimizer for better memory utilization. And once again, here is the user guide to reproduce the benchmark results listed above.

Inference Evaluation

In this section, we extend our previous evaluation of Llama on Cloud TPU v4. Here, we demonstrate the performance properties of TPU v5e for inference applications.

We define inference throughput as the number of tokens produced by a model per second per TPU chip. Figure 3 shows Llama 2 70B throughput on a v5e-16 TPU node. Given that Llama is a memory-bound application, we see that applying weight-only quantization unblocks extending the model batch size to 32. Higher throughput would be possible on larger TPU v5e hardware, up to the point where the ICI network bandwidth between chips prevents the TPU slice from delivering higher throughput. Exploring the upper bound limits of TPU v5e on Llama 2 was outside the scope of this work. Note that to make the Llama 2 70B model run on v5e-16, we replicated the key/value attention heads to have one KV head per chip, as discussed in the Grouped Query Attention section above. As discussed previously, with increasing model batch size, per-token latency grows proportionally; quantization improves overall latency by reducing memory I/O demand.

Fig. 3: Llama 2 70B Inference Per-Chip Throughput on TPU v5e vs. Batch Size

Figure 4 shows inference throughput results across different model sizes. These results highlight the largest throughput achievable on the given hardware configuration when using bf16 precision. With weight-only quantization, throughput reaches 42 tokens/s/chip on the 70B model. As mentioned above, increasing hardware resources may lead to further performance gains.

Fig. 4: Llama 2 Inference Per-Chip Throughput on TPU v5e

Figure 5 shows the cost of serving Llama 2 models (from Figure 4) on Cloud TPU v5e. We report the TPU v5e per-chip cost based on the 3-year commitment (reserved) price in the us-west4 region. All model sizes use maximum sequence length of 2,048 and maximum generation length of 1,000 tokens. Note that with quantization, the cost for the 70B model drops to $0.0036 per 1,000 tokens.

Fig. 5: Llama 2 Inference Per-Chip Cost on TPU v5e
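The cost per 1,000 tokens follows from the per-chip price and the per-chip throughput; because both are per chip, the number of chips cancels out. The sketch below assumes cost is simply the hourly price divided by the tokens produced in that hour. The post does not spell out its formula, and the inputs are hypothetical, not the actual TPU v5e price.

```python
def cost_per_1k_tokens(price_per_chip_hour, tokens_per_second_per_chip):
    """Serving cost in dollars per 1,000 generated tokens."""
    tokens_per_chip_hour = tokens_per_second_per_chip * 3600
    return price_per_chip_hour / tokens_per_chip_hour * 1000


# Hypothetical inputs: $1.00 per chip-hour at 100 tokens/s/chip.
print(round(cost_per_1k_tokens(1.00, 100), 5))  # 0.00278
```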

Figure 6 summarizes our best Llama 2 inference latency results on TPU v5e. Llama 2 7B results are obtained from our non-quantized configuration (BF16 Weight, BF16 Activation), while the 13B and 70B results are from the quantized (INT8 Weight, BF16 Activation) configuration. We attribute this difference to the inherent trade-off of quantization between memory savings and compute overhead; as a result, for smaller models, quantization may not lead to lower inference latency.
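The memory-versus-compute trade-off can be seen in a small sketch of weight-only INT8 quantization with BF16 activations. It uses a generic symmetric, per-output-channel scheme chosen for illustration; PyTorch/XLA's actual implementation may differ, for example in how it fuses dequantization into the matmul. Storing the weight as INT8 roughly halves its memory footprint relative to BF16, while the dequantize step is the extra compute that may not pay off for smaller models.

```python
import torch


def quantize_weight_int8(w):
    """Symmetric per-output-channel int8 quantization of an (out, in) weight."""
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


def int8_weight_linear(x, q, scale):
    """BF16 activations with INT8 weights: dequantize, then matmul in BF16."""
    w = q.to(x.dtype) * scale.to(x.dtype)  # the extra compute quantization adds
    return x @ w.T


torch.manual_seed(0)
w = torch.randn(256, 512, dtype=torch.bfloat16)  # hypothetical layer weight
x = torch.randn(4, 512, dtype=torch.bfloat16)    # BF16 activations
q, scale = quantize_weight_int8(w.float())

bf16_bytes = w.numel() * w.element_size()
int8_bytes = q.numel() * q.element_size() + scale.numel() * scale.element_size()
print(bf16_bytes, int8_bytes)  # 262144 132096

ref = x.float() @ w.float().T
out = int8_weight_linear(x, q, scale).float()
rel_err = ((out - ref).norm() / ref.norm()).item()  # small, on the order of 1%
```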

Additionally, prompt length has a strong effect on the memory requirements of LLMs. For instance, we observe a latency of 1.2 ms/token (i.e., 201 tokens/second/chip) when max_seq_len=256 at batch size 1 with no quantization on v5e-4 running Llama 2 7B.

Fig. 6: Llama 2 Inference Latency on TPU v5e

Final Thoughts

The recent wave of AI innovation has been nothing short of transformative, with breakthroughs in LLMs at the forefront. Meta’s Llama and Llama 2 models stand as notable milestones in this wave of progress. PyTorch/XLA uniquely enables high-performance, cost-efficient training and inference for Llama 2 and other LLMs and generative AI models on Cloud TPUs, including the new Cloud TPU v5e. Looking forward, PyTorch/XLA will continue to push the performance limits on Cloud TPUs in both throughput and scalability and at the same time maintain the same PyTorch user experience.

We are ecstatic about what’s ahead for PyTorch/XLA and invite the community to join us. PyTorch/XLA is developed fully in open source. So, please file issues, submit pull requests, and send RFCs to GitHub so that we can openly collaborate. You can also try out PyTorch/XLA for yourself on various XLA devices including TPUs and GPUs.

We would like to extend our special thanks to Marcello Maggioni, Tongfei Guo, Andy Davis, Berkin Ilbeyi for their support and collaboration in this effort.

The PyTorch/XLA Team at Google

Accelerating Inference on x86-64 Machines with oneDNN Graph

Supported in PyTorch 2.0 as a beta feature, oneDNN Graph leverages aggressive fusion patterns to accelerate inference on x86-64 machines, especially Intel® Xeon® Scalable processors.

oneDNN Graph API extends oneDNN with a flexible graph API to maximize the optimization opportunity for generating efficient code on AI hardware. It automatically identifies the graph partitions to be accelerated via fusion. The fusion patterns focus on fusing compute-intensive operations such as convolution, matmul, and their neighbor operations for both inference and training use cases.

In PyTorch 2.0 and beyond, oneDNN Graph can help accelerate inference on x86-64 CPUs (primarily, Intel Xeon processor-based machines) with Float32 and BFloat16 (with PyTorch’s Automatic Mixed Precision support) datatypes. With BFloat16, speedup is limited to machines that support AVX512_BF16 ISA (Instruction Set Architecture), as well as machines that also support AMX_BF16 ISA.

oneDNN Graph Usage

From a user’s perspective, the usage is quite simple and intuitive, with the only change in code being an API invocation. To leverage oneDNN Graph with JIT-tracing, a model is profiled with an example input as shown below in Figure 1.

Fig. 1: A code-snippet that demonstrates using oneDNN Graph

oneDNN Graph receives the model’s graph and identifies candidates for operator fusion with respect to the shape of the example input. Currently, only static shapes are supported. This means that inputs with any other shape are not supported and would not receive any performance benefit.
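Since the snippet in Figure 1 is an image, here is a sketch of the same Float32 flow: enable fusion with torch.jit.enable_onednn_fusion, trace and freeze the model with an example input, and run a couple of warm-up iterations before inference. The tiny convolutional model is a hypothetical stand-in, and the number of warm-up runs is an assumption based on how TorchScript's profiling executor works, not a value taken from this post.

```python
import torch
import torch.nn as nn

# Enable oneDNN Graph fusion for TorchScript (a global switch).
torch.jit.enable_onednn_fusion(True)

# A small stand-in model (hypothetical); any traceable eval-mode model works.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).eval()
example_input = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    traced = torch.jit.trace(model, example_input)  # profiled with the example input
    traced = torch.jit.freeze(traced)
    # Warm-up runs let the profiling executor find and compile fusion groups.
    traced(example_input)
    traced(example_input)
    # Later calls should use the same input shape to benefit (static shapes only).
    output = traced(example_input)
```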


To ensure reproducibility of results, we used a fork of TorchBench to measure inference speed-up of some Vision models on an AWS m7i.16xlarge instance, which uses 4th Gen Intel® Xeon® Scalable processors.

The baseline for comparison was torch.jit.optimize_for_inference, which only supports the Float32 datatype. The batch size for each model was the one used for that model in TorchBench.

In Figure 2, we depict the inference speedup of using oneDNN Graph over PyTorch alone. The geomean speedup with oneDNN Graph was 1.24x for the Float32 datatype and 3.31x for the BFloat16 datatype[1].

Fig. 2: Inference speedup with oneDNN Graph over default CPU JIT Fuser (which only uses Float32 datatype)
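For reference, the geomean speedups quoted above are geometric means of per-model speedups: the exponential of the mean of their logarithms. Unlike the arithmetic mean, it treats a 2x speedup and a 2x slowdown symmetrically. The speedup values below are hypothetical, not the per-model results from Figure 2.

```python
import math


def geomean(speedups):
    """Geometric mean: exp of the mean of the logs."""
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))


# Hypothetical per-model speedups, not the measured values.
print(round(geomean([1.0, 4.0]), 9))  # 2.0
print(round(geomean([2.0, 0.5]), 9))  # 1.0: a 2x win and a 2x loss cancel out
```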

Future work

oneDNN Graph is currently supported in PyTorch through TorchScript, but Intel is already working to integrate it with the Inductor-CPU backend as a prototype feature in a future PyTorch release. Dynamo makes supporting dynamic shapes easier in PyTorch, and we would like to introduce dynamic shape support with Inductor-CPU. We also plan to add int8 quantization support.


The results presented in this blog are a joint effort between Meta and the Intel PyTorch team. Special thanks to Elias Ellison from Meta who spent precious time thoroughly reviewing the PRs and gave us helpful feedback.

AMD Extends Support for PyTorch Machine Learning Development on Select RDNA™ 3 GPUs with ROCm™ 5.7

Researchers and developers working with Machine Learning (ML) models and algorithms using PyTorch can now use AMD ROCm 5.7 on Ubuntu® Linux® to tap into the parallel computing power of the Radeon™ RX 7900 XTX and the Radeon™ PRO W7900 graphics cards which are based on the AMD RDNA™ 3 GPU architecture.

A client solution built on these two high-end GPUs enables a local, private, and cost-effective workflow for ML training and inference for those who previously relied on cloud-based solutions alone.

ML Development on Desktop

Accelerate Machine Learning With PyTorch On Your Desktop

  • A local PC or workstation system running PyTorch with a Radeon 7900 series GPU presents a capable, yet affordable solution to address growing ML workflow challenges, thanks to large GPU memory sizes of 24GB and even 48GB.

Unified Software Stack For The Desktop And The Datacenter

  • The latest AMD ROCm 5.7 software stack for GPU programming unlocks the massively parallel compute power of these RDNA™ 3 architecture-based GPUs for use with PyTorch, one of the leading ML frameworks. The same unified software stack also supports the CDNA™ GPU architecture of the AMD Instinct™ MI series accelerators.

Freedom To Customize

  • The AMD ROCm platform is primarily Open-Source Software (OSS). It allows developers the freedom to customize and tailor their GPU software for their own needs while collaborating with a community of other developers, and helping each other find solutions in an agile, flexible, and rapid manner. The AMD ROCm platform’s goal is to allow users to maximize their GPU hardware investment. The AMD ROCm platform is designed to help develop, test, and deploy GPU accelerated HPC, AI, scientific computing, CAD, and other applications in a free, open source, integrated and secure software ecosystem.

As the industry moves towards an ecosystem that supports a broad set of systems, frameworks and accelerators, AMD is determined to continue to make AI more accessible to PyTorch developers and researchers that benefit from a local client-based setup for ML development using RDNA™ 3 architecture-based desktop GPUs.

Learn More


Download Software


Visit the Documentation Portal to get started training ML models on your local desktop




How to Guide


© 2023 Advanced Micro Devices, Inc. All rights reserved. AMD, the AMD Arrow logo, CDNA, Radeon, ROCm, and combinations thereof are trademarks of Advanced Micro Devices, Inc. Linux® is the registered trademark of Linus Torvalds in the U.S. and other countries. Microsoft and Windows are registered trademarks of Microsoft Corporation in the US and/or other countries. PyTorch, the PyTorch logo and any related marks are trademarks of The Linux Foundation. TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc. Ubuntu and the Ubuntu logo are registered trademarks of Canonical Ltd. Other product names used in this publication are for identification purposes only and may be trademarks of their respective owners.

Radeon™ AI technology is compatible with all AMD Radeon 7000 Series graphics cards and newer. Please check with your system manufacturer for feature availability prior to purchase. GD-232.

  1. Based on AMD internal measurements, November 2022, comparing the Radeon RX 7900 XTX at 2.5GHz boost clock with 96 CUs issuing 2X the Bfloat16 math operations per clock vs. the RX 6900 XT GPU at 2.25 GHz boost clock and 80 CUs issuing 1X the Bfloat16 math operations per clock. RX-821

Compiling NumPy code into C++ or CUDA via torch.compile

Quansight engineers have implemented support for tracing through NumPy code via
torch.compile in PyTorch 2.1. This feature leverages PyTorch’s compiler to
generate efficient fused vectorized code without having to modify your original
NumPy code. Even more, it also allows for executing NumPy code on CUDA
just by running it through torch.compile under torch.device("cuda")!

In this post, we go over how to use this feature and give a few tips and tricks
to make the most out of it.

Compiling NumPy code into Parallel C++

We take as our running example one step in a K-Means algorithm.
This piece of code is borrowed from this NumPy book

import numpy as np

def kmeans(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)

We create a synthetic dataset with 20M random 2-D points. We can see that,
given that the means are chosen appropriately, the function returns the correct
cluster for all of them

npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
np_pred = kmeans(X, means)

Benchmarking this function gives us a baseline of 1.26s on an AMD 3970X CPU.

Compiling this function is now as easy as wrapping it with torch.compile and
executing it with the example inputs

import torch

compiled_fn = torch.compile(kmeans)
compiled_pred = compiled_fn(X, means)
assert np.allclose(np_pred, compiled_pred)

The compiled function yields a 9x speed-up when running it on 1 core. Even
better, as opposed to NumPy, our generated code does take advantage of all the
cores in a processor. As such, when we run it on 32 cores, we get a 57x
speed-up. Note that PyTorch always uses all the available cores unless
explicitly restricted, so this is the default behavior you get when using
torch.compile.
We may inspect the generated C++ code by running the script with the
environment variable TORCH_LOGS=output_code. When doing so, we can see that
torch.compile was able to compile the broadcasting and the two reductions
into just one for-loop, and parallelize it using OpenMP

extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        // Rest of the kernel omitted for brevity
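The kernel above fuses the broadcasted subtraction, the norm, and the argmin into a single loop over the points. To see why that is possible, here is the same computation written as one pass over the points in plain Python and NumPy. kmeans_one_loop is our own illustrative helper, not code that torch.compile emits, and it is far slower than both the NumPy and the compiled versions; it only shows the loop structure.

```python
import numpy as np


def kmeans(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)


def kmeans_one_loop(X, means):
    """Same result as kmeans(), written as a single pass over the points."""
    out = np.empty(X.shape[0], dtype=np.int64)
    for i in range(X.shape[0]):          # the loop the generated kernel parallelizes
        best_k, best_d = 0, np.inf
        for k in range(means.shape[0]):  # broadcast + norm + argmin, fused per point
            dx = X[i, 0] - means[k, 0]
            dy = X[i, 1] - means[k, 1]
            d = np.sqrt(dx * dx + dy * dy)
            if d < best_d:               # strict '<' keeps the first minimum, like np.argmin
                best_k, best_d = k, d
        out[i] = best_k
    return out


rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 2)) * 3 + 7
means = np.array([[5.0, 5.0], [10.0, 10.0]])
assert np.array_equal(kmeans(X, means), kmeans_one_loop(X, means))
```

No intermediate array of shape (n_means, n_points, 2) is ever built, which is what lets the generated C++ stream through the input once.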

Compiling NumPy code into CUDA

Compiling our code so that it runs on CUDA is as simple as setting the
default device to be CUDA

with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
assert np.allclose(np_pred, cuda_pred)

By inspecting the generated code via TORCH_LOGS=output_code, we see that,
rather than generating CUDA code directly, torch.compile generates rather
readable Triton code

def triton_(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    # Rest of the kernel omitted for brevity

Running this small snippet on an RTX 2060 gives an 8x speed-up over the
original NumPy code. This is something, but it is not particularly impressive,
given the speed-ups we have seen on CPU. Let’s have a look into how to squeeze
the most out of our GPU via a couple of minor changes.

float64 vs float32. Many GPUs, in particular consumer-grade ones, are
rather sluggish when running operations on float64. For this reason, when we
change the data generation to float32, the original NumPy code gets only a bit
faster, about 9%, but our CUDA code gets 40% faster, yielding an 11x speed-up
over the plain NumPy code.
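One way to generate the data directly in float32 is sketched below, assuming NumPy's Generator API. A pitfall worth avoiding: X.astype(np.float32) + np.random.randn(*X.shape) silently promotes the result back to float64, because np.random.randn always returns float64.

```python
import numpy as np

npts = 10_000_000
rng = np.random.default_rng()

# Build the two blobs directly in float32...
X = np.repeat(np.array([[5, 5], [10, 10]], dtype=np.float32), [npts, npts], axis=0)
# ...and add float32 noise in place, so no float64 array is ever created.
X += rng.standard_normal(X.shape, dtype=np.float32)
means = np.array([[5, 5], [10, 10]], dtype=np.float32)
```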

torch.compile, by default, respects NumPy semantics, and, as such, it uses
np.float64 as the default dtype for all its creation ops. As discussed, this
can hinder performance, so it is possible to change this default by setting:

from torch._dynamo import config
config.numpy_default_float = "float32"

CPU <> CUDA copies. An 11x speed-up is good, but it is not even close to
the CPU numbers. This is caused by a small transformation that torch.compile
does behind the scenes. The code above takes NumPy arrays and returns NumPy
arrays. All of these arrays are on CPU, but the computations are performed on
the GPU. This means that every time the function is called, torch.compile has
to copy all these arrays from CPU to GPU, and then copy the result back to
CPU to preserve the original semantics. There is no native solution to this
issue in NumPy, as NumPy does not have the notion of a device. That being
said, we can work around it by creating a wrapper around this function so that
it accepts PyTorch tensors and returns PyTorch tensors.

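import torch

# kmeans is the NumPy k-means function from earlier in this post.
# The wrapper must itself be compiled so that numpy()/from_numpy() are traced.
@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = kmeans(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)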

This function now takes tensors in CUDA memory and returns tensors in CUDA
memory, but the function itself is written in NumPy! torch.compile uses the
numpy() and from_numpy() calls as hints and optimizes them away; internally,
it simply works with PyTorch tensors without moving the memory at all. When we
keep the tensors in CUDA and perform the computations in float32, we see a
200x speed-up over the initial NumPy implementation on float32 arrays.

Mixing NumPy and PyTorch. In this example, we had to write a small adaptor
to convert tensors to ndarrays and then back to tensors. In programs that mix
PyTorch and NumPy, converting a tensor into an ndarray is often implemented as
x.detach().cpu().numpy(), or simply x.numpy(force=True). Since, when running
under torch.compile, we can run NumPy code on CUDA, we can implement this
conversion pattern as a call to x.numpy(), as we did above. Doing so and
running the resulting code under device("cuda") will generate efficient CUDA
code from the original NumPy calls without copying the data from CUDA to CPU
at all. Note that the resulting code does not run without torch.compile. For
it to run in eager mode, one would need to roll back to x.numpy(force=True).
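In eager mode, the two conversion idioms mentioned above produce the same array. The short check below shows this on a CPU tensor that requires grad; the tensor values are arbitrary. A bare x.numpy() raises a RuntimeError on such a tensor in eager mode, which is why eager code has to use one of these idioms instead.

```python
import numpy as np
import torch

x = torch.linspace(0.0, 1.0, steps=5, requires_grad=True)

# Explicit chain: drop the autograd history, move to host memory, then convert.
a = x.detach().cpu().numpy()
# Shorthand: force=True performs the detach and CPU copy when they are needed.
b = x.numpy(force=True)

try:
    x.numpy()  # fails in eager mode because x requires grad
    plain_ok = True
except RuntimeError:
    plain_ok = False

print(np.array_equal(a, b), plain_ok)  # True False
```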

Further Speed-up tricks

General advice. The CUDA code we have shown is already quite efficient, but
the running example is admittedly rather short. When dealing with larger
programs, we may need to tweak parts of them to make them more efficient. A
good place to start is the tutorials and FAQs for torch.compile, which showcase
a number of ways to inspect the tracing process and to identify problematic
code that may cause slowdowns.

Advice when compiling NumPy code. NumPy, even if rather similar to PyTorch,
is often used very differently. It is rather common to perform computations in
NumPy and then do an if/else depending on values within the array, or perform
operations in-place, perhaps via boolean masks. These constructions, while
supported by torch.compile, hamper its performance. Changes like writing the
code in a branchless way to avoid graph breaks, or avoiding in-place ops can go
a long way.
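To illustrate that advice, the sketch below contrasts two typical NumPy constructions with branchless equivalents built on np.clip and np.where. The first is an in-place update through boolean masks; the second is a Python if on a value computed from the array. The function names are hypothetical. The claim that the branchless forms trace better under torch.compile is the general guidance above, not something this snippet measures. The snippet only checks that both forms compute the same result in plain NumPy.

```python
import numpy as np

def clamp_inplace(x, lo, hi):
    # Typical NumPy style: copy, then mutate through boolean masks.
    y = x.copy()
    y[y < lo] = lo
    y[y > hi] = hi
    return y

def clamp_branchless(x, lo, hi):
    # Same result as a pure function, with no in-place writes.
    return np.clip(x, lo, hi)

def normalize_branchy(x):
    # Python-level branch on a value computed from the array: under
    # torch.compile this data-dependent condition causes a graph break.
    s = x.sum()
    if s != 0:
        return x / s
    return x

def normalize_branchless(x):
    # np.where selects the divisor, so no Python control flow
    # depends on the array's values.
    s = x.sum()
    safe = np.where(s != 0, s, 1)
    return x / safe
```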

To write fast NumPy code, it is best to avoid loops, but sometimes they are
unavoidable. When tracing through a loop, torch.compile will try to fully
unroll it. This is sometimes desirable, but sometimes it is not even possible,
for example when the loop has a dynamic stopping condition, as in a while loop.
In these cases, it may be best to compile just the body of the loop, perhaps
running a few iterations at a time (loop unrolling).
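Below is a minimal sketch of that pattern. The data-dependent while loop stays in Python, and only a fixed-trip-count body that runs a few iterations at a time would be handed to the compiler. Newton's iteration for square roots is a hypothetical stand-in for a real loop body. The torch.compile line is left as a comment so that the sketch runs with NumPy alone.

```python
import numpy as np

def newton_block(x, a, n_iters=4):
    # Fixed trip count: a tracer can unroll these iterations completely.
    for _ in range(n_iters):
        x = 0.5 * (x + a / x)
    return x

# With PyTorch available, this is the piece one would compile, e.g.:
# newton_block = torch.compile(newton_block)

def newton_sqrt(a, tol=1e-12, max_blocks=100):
    x = np.ones_like(a)
    blocks = 0
    while blocks < max_blocks:
        x_new = newton_block(x, a)
        # The dynamic stopping condition is evaluated in plain Python.
        if np.max(np.abs(x_new - x)) < tol:
            return x_new
        x = x_new
        blocks += 1
    raise RuntimeError("Newton iteration did not converge")

print(newton_sqrt(np.array([2.0, 9.0, 0.25])))
```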

Debugging NumPy code. Debugging is rather tricky when a compiler is
involved. To figure out whether an error you are hitting is a torch.compile
error or an error from the program, you can execute your NumPy program without
torch.compile by replacing the NumPy import with import torch._numpy as np.
This should only be used for debugging purposes and is in no way a
replacement for the PyTorch API, as it is much slower and, being a private
API, may change without notice. See also this FAQ for other tricks.

Differences between NumPy and torch.compile NumPy

NumPy scalars. NumPy returns NumPy scalars in almost any case where PyTorch
would return a 0-D tensor (e.g. from np.sum). Under torch.compile, NumPy
scalars are treated as 0-D arrays. This is just fine in most cases. The only
case when their behavior diverges is when NumPy scalars are implicitly used as
Python scalars. For example,

>>> np.asarray(2) * [1, 2, 3]  # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3]              # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6])               # acts as a 0-D array, not as a scalar ?!?!

As the compiled lambda shows, torch.compile treats u as a 0-D array. To
recover the eager semantics, we just need to make the cast explicit:

>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]

Type promotion and versioning. NumPy’s type promotion rules may be, at
times, a bit surprising:

>>> np.zeros(1, dtype=np.int8) + 127
array([127], dtype=int8)
>>> np.zeros(1, dtype=np.int8) + 128
array([128], dtype=int16)

NumPy 2.0 is changing these rules to ones that are closer to PyTorch’s. The
relevant technical document is NEP 50. torch.compile went ahead and
implemented NEP 50 rather than the about-to-be-deprecated rules.
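One case where the two rule sets disagree, taken from the examples in NEP 50, is adding an array to a NumPy scalar of a different integer dtype. Under the legacy value-based rules, the scalar's value decides: 1 fits in uint8, so the result stays uint8. Under NEP 50, the scalar's dtype decides, so the result is int64. The snippet below simply reports which behavior the installed NumPy follows, assuming NumPy's default promotion settings.

```python
import numpy as np

a = np.array([1], dtype=np.uint8)
res = a + np.int64(1)

# NumPy < 2 (legacy value-based casting): uint8
# NumPy >= 2 (NEP 50, the rules torch.compile follows): int64
print(np.__version__, res.dtype)
```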

In general, NumPy within torch.compile follows NumPy 2.0 pre-release.

Beyond NumPy: SciPy and scikit-learn

In parallel to this effort of making torch.compile understand NumPy code,
other Quansight engineers have designed and proposed a way to support PyTorch
tensors within scikit-learn and SciPy. This was received enthusiastically by
other maintainers from these libraries, as it was shown that using PyTorch as a
backend would often yield considerable speed-ups. Both projects have now merged
initial support for PyTorch tensors across a number of APIs and submodules.

This is a stepping stone towards a future where PyTorch tensors can
be used within other libraries in the Python data ecosystem. Moreover, this
will enable running these other libraries on GPUs and even compiling code
that mixes these libraries and PyTorch, similar to what we have discussed in
this post.

If you want to learn more about this effort, how to use it, or how to help
move it forward, see this other blog post.


Since its inception, PyTorch has committed to being a framework compatible
with the rest of the Python ecosystem. Enabling the compilation of NumPy
programs, and establishing the tools necessary to do the same for other
prominent libraries, are two more steps in this direction. Quansight and Meta
continue working hand in hand, improving the compatibility between PyTorch and
the rest of the Python ecosystem.

From Quansight, we would like to thank Mengwei, Voz, and Ed for their
invaluable help in integrating our work with torch.compile. We would also
like to thank Meta for funding this project as well as previous work on
improving NumPy compatibility within PyTorch, and the project that led to
supporting PyTorch within scikit-learn and SciPy. These are giant leaps towards
consolidating PyTorch as the framework of choice within the open source Python
data ecosystem.


PyTorch Edge: Enabling On-Device Inference Across Mobile and Edge Devices with ExecuTorch

Other contributors: Dave Bort, Kimish Patel, Mergen Nachin, Orion Reblitz-Richardson, Andrew Caples

We are excited to announce ExecuTorch, our all-new solution for enabling on-device inference capabilities across mobile and edge devices with the backing of industry leaders like Arm, Apple, and Qualcomm Innovation Center.

As part of PyTorch Edge’s vision for the future of the on-device AI stack and ecosystem, ExecuTorch addresses the fragmentation in the on-device AI ecosystem. It offers a design that provides extension points for seamless third-party integration to accelerate ML models on specialized hardware. Our partners have contributed custom delegate implementations to optimize model inference execution on their respective hardware platforms.

We have created extensive documentation that provides more details about ExecuTorch’s architecture, its high-level components, example ML models running on ExecuTorch, and end-to-end tutorials for exporting and running a model on various hardware devices. We are excited to see all of the innovative use cases of ExecuTorch built by the community.

Key Components of ExecuTorch

ExecuTorch offers a compact runtime with a lightweight operator registry to cover the PyTorch ecosystem of models, and a streamlined path to execute PyTorch programs on edge devices. These devices range from mobile phones to embedded hardware powered by specific delegates built by our partners. In addition, ExecuTorch ships with a Software Developer Kit (SDK) and toolchain that provide an ergonomic UX for ML developers to go from model authoring to training and device delegation in a single PyTorch workflow. This suite of tools enables ML developers to perform on-device model profiling and offers better ways of debugging the original PyTorch model.

ExecuTorch is architected from the ground up in a composable manner to allow ML developers to make decisions on what components to leverage as well as entry points to extend them if needed. This design provides the following benefits to the ML community:

  • Portability: Compatibility with a wide variety of computing platforms, from high-end mobile phones to highly constrained embedded systems and microcontrollers.
  • Productivity: Enabling developers to use the same toolchains and SDK from PyTorch model authoring and conversion, to debugging and deployment to a wide variety of platforms, resulting in productivity gains.
  • Performance: Providing end users with a seamless and high-performance experience due to a lightweight runtime as well as its ability to utilize full hardware capabilities, including general-purpose CPUs and special-purpose microprocessors such as NPUs and DSPs.

PyTorch Edge: from PyTorch Mobile to ExecuTorch

Bringing research and production environments closer together is a fundamental goal of PyTorch. ML engineers increasingly use PyTorch to author and deploy machine learning models in highly dynamic and ever-evolving environments, from servers to edge devices such as mobile phones and embedded hardware.

With the increasing adoption of AI in Augmented Reality (AR), Virtual Reality (VR), Mixed Reality (MR), Mobile, IoT and other domains, there is a growing need for an end-to-end on-device solution that is extensible, modular, and aligned with the PyTorch stack.

PyTorch Edge builds on the same fundamental principle of bringing research and production closer together by enabling the deployment of various ML models (spanning vision, speech, NLP, translation, ranking, integrity and content creation tasks) to edge devices via a low-friction development and deployment process. It provides a framework stack that spans the universe of on-device use cases that the PyTorch community cares about.

PyTorch Edge provides the portability of core components that is required to reach a wide spectrum of devices characterized by differing hardware configurations, performance, and efficiency. Such portability is achieved by allowing optimizations that are custom-developed for the target use cases, and developer productivity is achieved via well-defined entry points, representations, and tools that tie all this together into a thriving ecosystem.

PyTorch Edge is the future of the on-device AI stack and ecosystem for PyTorch. We are excited to see what the community builds with ExecuTorch’s on-device inference capabilities across mobile and edge devices backed by our industry partner delegates.


Lightning AI Joins the PyTorch Foundation as a Premier Member

The PyTorch Foundation, a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem, is announcing today that Lightning AI has joined as a premier member.

Lightning AI is the company behind PyTorch Lightning, the platform and open-source framework for companies to build and deploy AI products leveraging the latest generative AI models.

“This is a very important milestone for Lightning AI and the PyTorch Lightning community,” remarks Luca Antiga, Chief Technology Officer of Lightning AI. “By joining the PyTorch Foundation, we are strengthening our commitment to boost the adoption of PyTorch across industries. We look forward to partnering with the Foundation to push the vision of PyTorch forward.”

PyTorch Lightning is one of the leading projects in the PyTorch ecosystem, allowing developers to build, train, fine-tune and deploy AI models at scale. PyTorch Lightning is helping drive the rapid adoption of PyTorch by both the research community and the enterprise.

“Lightning AI has been a great steward of the AI community, and notably a key contributor to PyTorch over the years,” said PyTorch Foundation Executive Director Ibrahim Haddad. “Their goal of making AI research scalable directly aligns with our mission at the foundation.”

As a premier member, Lightning AI is granted one seat to the PyTorch Foundation Governing Board. The Board sets policy through our bylaws, mission and vision statements, describing the overarching scope of foundation initiatives, technical vision, and direction.

We’re happy to welcome Luca Antiga, Chief Technology Officer at Lightning AI, to our board. Luca joined the Lightning AI team in April 2021 when the Tensorwerk team joined Grid AI. Prior to joining Lightning AI, Luca co-founded Orobix, an applied AI company, and Tensorwerk. He was an early core contributor to PyTorch and co-authored Deep Learning with PyTorch (Manning).

To learn more about how you can be a part of the PyTorch Foundation, visit our website.

About Lightning AI

Lightning AI is the creator of PyTorch Lightning, the deep learning platform and open-source framework of choice for developers and companies seeking to build and deploy AI products.

About PyTorch Foundation

The PyTorch Foundation is a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem. The PyTorch Foundation is supported by its members and leading contributors to the PyTorch open source project. The Foundation leverages resources provided by members and contributors to enable community discussions and collaboration.

About The Linux Foundation

The Linux Foundation is the world’s leading home for collaboration on open source software, hardware, standards, and data. Linux Foundation projects are critical to the world’s infrastructure including Linux, Kubernetes, Node.js, ONAP, PyTorch, RISC-V, SPDX, OpenChain, and more. The Linux Foundation focuses on leveraging best practices and addressing the needs of contributors, users, and solution providers to create sustainable models for open collaboration. For more information, please visit us at linuxfoundation.org. The Linux Foundation has registered trademarks and uses trademarks. For a list of trademarks of The Linux Foundation, please see its trademark usage page. Linux is a registered trademark of Linus Torvalds.


Huawei Joins the PyTorch Foundation as a Premier Member

Today, the PyTorch Foundation, a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem, announced that Huawei has joined as a premier member.

Huawei has been a long-standing supporter of and contributor to the PyTorch ecosystem and, through the release of progressive diverse computing, provides easier access to the PyTorch ecosystem for more hardware vendors. By joining as a premier member, Huawei will continue to optimize PyTorch to fully unleash Ascend computing capabilities.

“We are delighted to join the PyTorch Foundation, and hope to further collaborate with other member companies and expand the community to a wider audience,” said Zhang Dixuan, President of Huawei Ascend Computing Business. “This move benefits Huawei, PyTorch, and the wider AI ecosystem. It also aligns with our long-held beliefs in openness, innovation, collaboration, and shared success, and we are confident that it will spur new innovations in the global AI community.”

Huawei unveiled the All Intelligence strategy to accelerate intelligence across all industries. To meet the demand for AI computing, Huawei invests in system-level technologies, guided by a belief in open hardware and software that enables partners and fosters talent. This strategy aligns with the PyTorch Foundation’s mission to develop AI as part of a sustainable open source ecosystem and produce inclusive technological feats.

PyTorch Foundation Executive Director Ibrahim Haddad said, “We are delighted to welcome Huawei to the PyTorch Foundation. Huawei is a leading body in researching computer vision, natural language processing, speech recognition, and other emerging areas, and has proven experience in the field of foundation models. We have no doubt that we will benefit from their support and guidance.”

As a premier member, Huawei is granted one seat to the PyTorch Foundation Governing Board, and will help set policies, bylaws, and mission and vision statements that define the overarching scope of the PyTorch Foundation’s initiatives, technical vision, and direction.

The Board welcomes Huawei representative Fred Li, Head of Computing Open Source Development Team at Huawei. Fred leads an active and creative team in R&D and operations projects under the principle of “upstream first”, which aims to make diverse computing power ubiquitous.

To learn more about how you can be a part of the PyTorch Foundation, visit our website.

About Huawei

Founded in 1987, Huawei is a leading global provider of information and communications technology (ICT) infrastructure and smart devices. We have 207,000 employees and operate in over 170 countries and regions, serving more than three billion people around the world. We are committed to bringing digital to every person, home and organization for a fully connected, intelligent world.



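Today, the PyTorch Foundation, a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem, announced that Huawei has joined as a Premier member.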


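“By joining the PyTorch Foundation, we can further collaborate with other member companies to accelerate the development of the PyTorch community,” said Zhang Dixuan, President of Huawei Ascend Computing Business. “We believe this is mutually beneficial for Huawei and the PyTorch ecosystem, and it is in keeping with our long-standing open source philosophy of open innovation and win-win collaboration, bringing more excitement and innovation to the global AI community.”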

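Huawei has announced its All Intelligence strategy to accelerate the intelligent transformation of industries of every kind. Through continuous system-level innovation, Huawei remains committed to open hardware, open source software, enabling partners, and developing talent, in order to meet the diverse AI computing needs of all industries. This fits perfectly with, and complements, the PyTorch Foundation’s mission of advancing AI by fostering and sustaining an open source ecosystem and making these technological innovations accessible to everyone.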

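“Huawei has conducted extensive research in computer vision, natural language processing, speech recognition, and other fields, and has also accumulated mature research experience in the field of large models. We believe the PyTorch Foundation will benefit greatly from their support for our members and ecosystem,” said Ibrahim Haddad, Executive Director of the PyTorch Foundation.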

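As a Premier member, Huawei is granted one seat on the PyTorch Foundation Governing Board. The Board sets policy through our bylaws, mission and vision statements, describing the overarching scope of foundation initiatives, technical vision, and direction.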

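We are pleased to welcome Li Yongle, General Manager of Huawei’s Computing Open Source Business, to our board. Li Yongle currently leads the open source business of Huawei’s Computing Product Line, heading a highly innovative and energetic technical and operations team that follows the “Upstream first” principle to make diverse computing power ubiquitous.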

To learn more about how you can be a part of the PyTorch Foundation, please visit our website.






