PyTorch Foundation Technical Advisory Council Elects New Leadership

We are pleased to announce the first-ever Chair and Vice Chair of the PyTorch Foundation’s Technical Advisory Council (TAC): Luca Antiga as the Chair and Jiong Gong as Vice Chair. Both leaders bring extensive experience and deep commitment to the PyTorch community, and they are set to guide the TAC in its mission to foster an open, diverse, and innovative PyTorch technical community.

Meet the New Leadership

Luca Antiga

Luca Antiga has been the CTO at Lightning AI since 2022. He is an early contributor to PyTorch core and co-authored “Deep Learning with PyTorch” (published by Manning). He started his journey as a researcher in Bioengineering, and later co-founded Orobix, a company focused on building and deploying AI in production settings.

“I am looking forward to taking on the role of the chair of the PyTorch TAC,” says Luca. “As the TAC chair, I will ensure effective, timely topic selection and enhance visibility of technical needs from the board members and from the ecosystem at large. I will strive for directional, cohesive messaging throughout the transition of PyTorch from Meta to the Linux Foundation.”

Jiong Gong

Jiong Gong is a Principal Engineer and SW Architect for PyTorch Optimization from Intel. He serves as one of the PyTorch CPU module maintainers and is an active contributor to the TorchInductor CPU backend.

“I plan to further strengthen the collaboration between PyTorch developers and hardware vendors, promoting innovation and performance optimization across various hardware platforms, enhancing PyTorch ecosystem and streamlining the decision-making process,” says Jiong. “I am honored to serve as the vice chair of the TAC.”

What Does the TAC Do?

The PyTorch Foundation’s TAC provides a forum for technical communication, leadership, and collaboration for the PyTorch Foundation. Its members are drawn from the PyTorch Foundation’s membership, and the committee holds open meetings once a month that anyone in the community can attend. The committee provides thought leadership on technical topics, facilitates knowledge sharing, and offers a forum for discussing issues with other technical experts in the community.

New TAC Webpage

Stay connected with the PyTorch Foundation’s Technical Advisory Council (TAC) by visiting our new TAC webpage. There you can find the list of TAC members, view upcoming meeting agendas, access presentations, attend public meetings, watch meeting recordings, and participate in discussions on key technical topics.

Plus stay tuned on our blog for regular updates from the PyTorch Foundation TAC leadership.

Read More

PyTorch Conference 2024 Recap: On Fire 🔥

women dancing with fire

The 2024 PyTorch Conference in San Francisco gathered nearly 1,500 AI researchers, developers, and enthusiasts. Over two days, the event featured engaging discussions, insightful keynotes, and hands-on sessions focused on artificial intelligence (AI) and advancements in PyTorch, the leading open-source machine learning framework. Attendees delved into the future of generative AI, Large Language Models (LLMs), and the crucial role open-source technology plays in driving AI innovation. Here’s a recap of the key themes, highlights, and major takeaways from this year’s conference.

Key Themes of the PyTorch Conference 2024

Three core themes emerged throughout the conference:

  1. Generative AI and LLMs: Many sessions focused on how PyTorch continues to evolve as a primary framework for Large Language Models and Generative AI applications. From scaling these models to optimizing their performance on various hardware platforms, the conference showcased the ongoing advancements and challenges in LLM architecture.
  2. Democratizing AI Through Open Source: One of the recurring themes was the importance of open source tools and communities in shaping the future of AI. PyTorch is committed to inclusivity, ease of use, and accessibility to developers of all levels, with a focus on bringing AI to an even larger global audience.
  3. Distributed and Edge Computing: Distributed computing and edge deployment appeared in many discussions, highlighting how PyTorch is being used to drive AI to the edge. The focus on edge accelerators, scalable training, and inference showcased how PyTorch enables the deployment of powerful models across diverse environments, from the cloud to on-device applications.

panel of people on a conference stage

Watch the Sessions from PyTorch Conference

The PyTorch Conference featured keynote sessions from top AI leaders and interesting lightning talks. You can view all of the conference sessions on our YouTube channel.

PyTorch Conference Startup Showcase

man speaking at a conference

New this year, the Startup Showcase was an exciting addition to the PyTorch Conference. Featuring early-stage founders pitching their AI startups to a panel of top venture capitalists, this event showcased the next generation of AI-driven innovation. The finalists for the inaugural PyTorch Conference Startup Showcase included Remix Inc., Cartesia, OpenBabylon, Remyx AI, A2 Labs, Inc., QuicSnap, Iso AI, CTGT, and Creao.ai, representing some of the most innovative AI/ML startups in the industry. Attendees got a front-row seat to see cutting-edge AI startups in action, while top VCs from the AI industry evaluated the pitches.

Congratulations to the PyTorch Conference Startup Showcase winner, CTGT! Deep learning can be opaque and biased, which limits its potential in crucial areas like healthcare and finance. CTGT is changing the game by enhancing data lineage in LLMs and cutting hallucinations. They’re empowering companies to create customized models using 500x less compute.

View the Startup Showcase

Mini-Summits

The DL Compiler Mini-Summit offered attendees a deep dive into the advances in deep learning (DL) compilers that are transforming AI workloads.

View the DL Compiler Mini-Summit

People watching an event

The Fine-Tuning Mini-Summit brought together a thriving community of researchers, developers, practitioners, and hobbyists, covering topics ranging from memory efficiency, parameter-efficient fine-tuning, and quantization to performance at scale and reproducible evaluations.

View the Fine-Tuning Mini-Summit

Major Takeaways from the PyTorch Conference 2024

Matt giving his keynote

  1. LLMs are Here to Stay: LLMs were a focal point of the event, reaffirming their pivotal role in the future of AI. As these models continue to scale, PyTorch remains the preferred framework for developing, training, and deploying them across various platforms and industries.
  2. Open Source Drives Innovation: A key takeaway from the conference was that open-source tools like PyTorch are vital for democratizing AI. This community-driven approach accelerates innovation, enabling researchers and developers globally to collaborate and contribute to faster advancements and more accessible AI technologies.
  3. Ethics and Sustainability Matter: The focus on ethical AI development was a significant takeaway. Talks on the inclusivity of computer vision models, the environmental impacts of AI infrastructure, and the need for transparent, unbiased AI models highlighted the growing importance of ethical considerations in the future of AI.
  4. PyTorch Expands Beyond the Cloud: With several sessions dedicated to edge AI and distributed computing, the conference showcased how PyTorch is expanding beyond cloud-based applications into edge devices and diverse computing environments. This shift is crucial as AI advances into areas like autonomous vehicles, mobile applications, and IoT devices.

Thank You to Our Sponsors

A crowd of people at a conference

Sponsor logos

We would like to thank each of the sponsors that made the PyTorch Conference 2024 possible. These include:

Diamond Sponsors:

  • AMD
  • Cloud Native Computing Foundation
  • IBM
  • Intel – PyTorch
  • Lightning.ai
  • Meta – PyTorch

Platinum Sponsors:

  • Arm
  • Google
  • Lambda Labs
  • Nvidia

Silver Sponsors:

  • Anyscale – PyTorch
  • Baseten
  • Chainguard
  • Databricks
  • Fal
  • FuriosaAi
  • HPE
  • Jane Street
  • Microsoft – PyTorch
  • MinIO
  • Outerbounds
  • Together.AI

Bronze Sponsors:

  • d-Matrix
  • MemVerge
  • Perforated AI
  • Quansight
  • Rotational Labs
  • ScaleGenAI

Special Event Sponsors:

  • PyTorch Flare Party: Hugging Face
  • Startup Showcase: Mayfield
  • Diversity Scholarship: AWS
  • Women and Non-Binary in PyTorch Lunch: Google
  • Happy Hour Reception: Lightning.AI

Thank you for your continued support in advancing the PyTorch ecosystem and helping to shape the future of AI!

Save the Date

See you next year for the PyTorch Conference in San Francisco at the Palace of Fine Arts from October 22-23, 2025.

Read More

PyTorch Native Architecture Optimization: torchao

By Team PyTorch

We’re happy to officially launch torchao, a PyTorch native library that makes models faster and smaller by leveraging low-bit dtypes, quantization, and sparsity. torchao is an accessible toolkit of techniques written (mostly) in easy-to-read PyTorch code, spanning both inference and training. This blog will help you pick which techniques matter for your workloads.

We benchmarked our techniques on popular GenAI models like Llama 3 and diffusion models and saw minimal drops in accuracy. Unless otherwise noted, the baselines are bf16 runs on an A100 80GB GPU.

Our topline metrics for Llama 3 are:

For inference

  • 97% speedup for Llama 3 8B using autoquant with int4 weight only quantization and hqq
  • 73% peak VRAM reduction for Llama 3.1 8B at 128K context length with a quantized KV cache

For training

  • 50% speedup for Llama 3 70B pretraining using float8 training on H100
  • 30% peak VRAM reduction for Llama 3 8B using 4 bit quantized optimizers.

Our topline metrics for diffusion model inference

  • 53% speedup using float8 dynamic quantization inference with float8 row-wise scaling on Flux.1-dev on H100
  • 50% reduction in model VRAM for CogVideoX using int8 dynamic quantization

Below we’ll walk through some of the techniques available in torchao you can apply to your models for inference and training.

Inference

Our inference quantization algorithms work over arbitrary PyTorch models that contain nn.Linear layers. Weight-only and dynamic activation quantization for various dtypes and sparse layouts can be chosen using our top-level quantize_ API:

from torchao.quantization import (
    quantize_,
    int4_weight_only,
)
quantize_(model, int4_weight_only())

Sometimes quantizing a layer can make it slower because of overhead, so if you’d rather we just pick how to quantize each layer in your model for you, you can instead run:

model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

The quantize_ API has a few different options, depending on whether your model is compute bound or memory bound.

from torchao.quantization import (
    # Memory bound models
    int4_weight_only,
    int8_weight_only,

    # Compute bound models
    int8_dynamic_activation_int8_semi_sparse_weight,
    int8_dynamic_activation_int8_weight,

    # Device capability 8.9+
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)

We also have extensive benchmarks on diffusion models, in collaboration with the HuggingFace diffusers team, in diffusers-torchao, where we demonstrated a 53.88% speedup on Flux.1-Dev and a 27.33% speedup on CogVideoX-5b.

Our APIs are composable; for example, we’ve composed sparsity and quantization to bring a 5% speedup to ViT-H inference.

We can also do things like quantize weights to int4 and the KV cache to int8 to support Llama 3.1 8B at the full 128K context length running in under 18.9 GB of VRAM.

QAT

Post-training quantization, especially at less than 4 bits, can suffer from serious accuracy degradation. Using Quantization Aware Training (QAT) we’ve managed to recover up to 96% of the accuracy degradation on hellaswag. We’ve integrated this as an end-to-end recipe in torchtune with a minimal tutorial.
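
To make the QAT flow concrete, here is a rough sketch of its prepare/train/convert shape. The module path and class name follow torchao’s prototype QAT API at the time of writing and may differ between releases, so treat them as assumptions rather than a definitive recipe.

from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Insert "fake quantize" ops so training sees quantization error (assumed API path).
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# ... run the usual fine-tuning loop on `model` ...

# After training, swap the fake-quantized modules for actually quantized ones.
model = qat_quantizer.convert(model)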

Training

Low precision compute and communications

torchao provides easy-to-use e2e workflows for reducing the precision of training compute and distributed communications, starting with float8 for `torch.nn.Linear` layers. Here is a one-liner to convert the compute gemms of your training run to float8:

from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)

For an e2e example of how to speed up LLaMa 3 70B pretraining by up to 1.5x with float8, see our README, and torchtitan’s blog and float8 recipe.

Performance and accuracy of float8 pretraining of LLaMa 3 70B, vs bfloat16


(source: https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359)

We are expanding our training workflows to more dtypes and layouts

  1. NF4 QLoRA in torchtune
  2. Prototype int8 training support
  3. Accelerated sparse 2:4 training

Low bit Optimizers

Inspired by Bits and Bytes, we’ve also added prototype support for 8-bit and 4-bit optimizers as a drop-in replacement for AdamW.

from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit
optim = AdamW8bit(model.parameters())

Integrations

We’ve been actively working on making sure torchao works well in some of the most important projects in open source.

  1. Huggingface transformers as an inference backend
  2. In diffusers-torchao as a reference implementation for accelerating diffusion models
  3. In HQQ for fast 4 bit inference
  4. In torchtune for PyTorch native QLoRA and QAT recipes
  5. In torchchat for post training quantization
  6. In SGLang for int4 and int8 post training quantization
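
As one concrete example of the Hugging Face transformers integration listed above, recent transformers releases expose a torchao-backed quantization config. The snippet below is a sketch of that usage; the exact class and argument names may vary across transformers versions, and the model ID is only an example.

from transformers import AutoModelForCausalLM, TorchAoConfig

# int4 weight-only quantization applied through torchao at load time (assumed API).
quant_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype="auto",
)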

Conclusion

If you’re interested in making your models faster and smaller for training or inference, we hope you’ll find torchao useful and easy to integrate.

pip install torchao

There are a lot of things we’re excited about next, ranging from going lower than 4 bit, performant kernels for high-throughput inference, expanding to more layers, scaling types or granularities, MX hardware support, and supporting more hardware backends. If any of the above sounds exciting, you can follow our progress at: https://github.com/pytorch/ao

If you’re interested in working on torchao, we’ve created a contributors guide, and if you have any questions we hang out on the #torchao channel on discord.gg/cudamode

Acknowledgements

We are fortunate to stand on the shoulders of giants and collaborate with some of the best people in open source. Thank you!

  1. Bits and Bytes for pioneering work in low bit optimizers and QLoRA
  2. Answer.ai for their engineering work to get FSDP and QLoRA composing
  3. Mobius Labs for the lovely back and forths on quantization algorithms and low bit kernels
  4. HuggingFace transformers for their help in battle testing and integrating our work
  5. HuggingFace diffusers for our collaboration on extensive benchmarks and best practices
  6. torch.compile so we could write our algorithms in pure PyTorch
  7. CUDA MODE for most of our early contributors

Read More

Challenges and Efforts in PyTorch Multi-Device Integration: Compatibility, Portability, and Integration Efficiencies

Introduction

As the demand for diverse hardware accelerators grows, the need for a robust and adaptable deep learning framework becomes increasingly critical. While working through this integration, several challenges have surfaced in the PyTorch ecosystem, potentially affecting various hardware vendors. This blog aims to highlight these issues and propose solutions to enhance PyTorch’s adaptability, portability, and resilience across different hardware platforms.

Improve Users’ Code Portability via Accelerator Autoloading

Currently, users face additional work when running their code on different accelerators. One such task is manually importing modules for out-of-tree devices. This requires users to not only understand the different usage patterns between accelerators but also make their code aware of these differences. If you have projects originally running on GPU/CPU and want to migrate to other accelerators, this can lead to significant work and potential frustration.

Examples of extra import:

# Case 1: Use HPU
import torch
import torchvision.models as models
import habana_frameworks.torch # <-- extra import
model = models.resnet50().eval().to("hpu")
input = torch.rand(128, 3, 224, 224).to("hpu")
output = model(input)

# Case 2: Use torch_npu
import torch
import torch_npu # <-- extra import
print(torch.ones(1, 2, device='npu'))

As a high-level machine learning framework, PyTorch’s ability to shield users from device differences is a competitive feature. Accelerator Autoloading allows users to continue using the familiar PyTorch device programming model without explicitly loading or importing device-specific extensions.

How does it work?

Utilize Python’s plugin architecture to enable automatic loading of device extensions via entry points in the PyTorch package.

Python entry points provide a standardized way for Python packages to expose and discover components or plugins within an application. Via an entry point definition in the accelerator package’s setup.py, PyTorch can automatically initialize the accelerator module when import torch is called, which gives users a consistent experience across different backend devices.

From the device side, the package only needs to declare the following setup in its setup.py (using torch_npu as an example):

# setup.py
entry_points={
    'torch.backends': ['torch_npu = torch_npu:_autoload'],
}

When import torch is invoked, the accelerator module will be loaded automatically. This provides users with a consistent programming experience across out-of-tree devices, eliminating the need to be aware of differences between CUDA, HPU, and NPU.

# Case 1: Use HPU 
import torch 
import torchvision.models as models 
model = models.resnet50().eval().to("hpu") 
input = torch.rand(128, 3, 224, 224).to("hpu") 
output = model(input) 

# Case 2: Use torch_npu 
import torch 
print(torch.ones(1, 2, device='npu'))
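
As a quick way to see this mechanism in your own environment, you can list the installed torch.backends entry points with the standard library. This is a generic inspection snippet, not part of the autoloading proposal itself; the group keyword requires Python 3.10+.

from importlib.metadata import entry_points

# Print every autoload hook that installed packages have registered under
# the 'torch.backends' entry point group.
for ep in entry_points(group="torch.backends"):
    print(f"{ep.name} -> {ep.value}")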

Device Integration Optimization

What is PrivateUse1?

In PyTorch, the dispatcher is a crucial component of the framework’s backend that manages how operations are routed to the appropriate device-specific implementation. Dispatch keys are an integral part of this system, serving as identifiers that represent various execution contexts—such as the device (CPU, CUDA, XPU), layout (dense, sparse), and autograd functionality. These keys ensure that operations are directed to the correct implementation.

PrivateUse1 is a customizable device dispatch key, similar to CUDA, CPU, XPU, etc., reserved for out-of-tree devices. It provides developers with a way to extend PyTorch’s functionality without modifying the core framework, allowing for the integration of new devices, hardware accelerators, or other specialized computing environments.
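
For illustration, the Python-visible half of a PrivateUse1 integration can be sketched as follows. The backend name "my_device" is a placeholder, and a real out-of-tree backend must also register a C++ device guard and operator implementations under the PrivateUse1 key, which is not shown here.

import torch

# Give the PrivateUse1 key a user-facing backend name (placeholder name).
torch.utils.rename_privateuse1_backend("my_device")

# Generate helper methods such as `Tensor.my_device()` / `Tensor.is_my_device`
# for the renamed backend.
torch.utils.generate_methods_for_privateuse1_backend()

# With a real backend registered behind the key, users could then write:
#   x = torch.randn(2, 2, device="my_device")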

Why do we need PrivateUse1?

Internally, dispatch keys are represented as bit masks, where each bit indicates whether a certain key is active. This bit mask representation is efficient for quick lookup and combination of keys, but it inherently limits the number of distinct keys (typically to 64 or fewer).

The current implementation of BackendComponent dispatch keys in PyTorch has encountered a critical bottleneck, which restricts the addition of new backends and, as a result, limits the expansion of the PyTorch ecosystem.

bit diagram

In response to this challenge, a series of optimizations have been applied to the PrivateUse1 mechanism to enhance its capacity.

  • PrivateUse1 integration mechanism

    Initially reserved as fallback options, PrivateUse1, along with PrivateUse2 and PrivateUse3, were designed to be activated only when existing key resources became scarce.

    PrivateUse1 is now being developed to match the robustness and versatility of established keys like CUDA and CPU. Achieving this required a deep integration across critical PyTorch modules. This integration wasn’t just a simple switch—it involved significant updates to core components such as AMP (Automatic Mixed Precision), Autograd, Distributed Training, Checkpointing, DataLoader, Optimization, and Quantization, etc.

flow diagram

The activation of PrivateUse1 was a massive collaborative effort, culminating in over 100 pull requests aimed at taking it from a placeholder to a fully operational dispatch key.

  • PrivateUse1 UT/CI Quality Assurance

    While unit tests are essential for ensuring quality during the development of the PrivateUse1 mechanism, they are not sufficient on their own to prevent new pull requests from inadvertently affecting existing functionality or compatibility of out-of-tree devices.

    To mitigate this risk, the community has added the pytorch_openreg module to the test suite. This module leverages a CPU backend to simulate interactions with accelerators, creating a controlled environment for rigorous testing. Once implemented, this will enable automatic execution of device-generic test cases whenever relevant code is updated, allowing us to quickly detect and address any potential issues affecting the PrivateUse1 integration mechanism.

  • Comprehensive Documentation

    By providing comprehensive and easy-to-understand documentation, we aim to lower the barrier to entry for developers and encourage wider adoption of the PrivateUse1 mechanism in the PyTorch ecosystem. This documentation includes:

    • Step-by-step guides for integrating new backends using PrivateUse1
    • Clear explanations of PrivateUse1’s functionality and benefits
    • Code examples and best practices for efficient implementation

These enhancements aim to improve the robustness and reliability of the PrivateUse1 mechanism, facilitating better integration of new backends and expanding the capabilities of PyTorch.

Compatibility Between Upstream and Downstream

Device-Generic Unit Tests

Most unit tests in PyTorch focus on CPU and CUDA devices, which limits participation from users with other hardware. To address this, there is a plan to modify PyTorch’s unit testing framework to enable better support for non-CUDA devices. This plan includes removing existing device restrictions, implementing dynamic data type loading, and generalizing decorators to accommodate a broader range of devices. Additionally, we aim to enforce the use of universal device code and expand distributed testing to support non-NCCL backends.

Through these improvements, we hope to significantly increase test coverage and pass rates for non-CUDA devices, integrating them into PyTorch’s continuous integration process. Initial changes have already been implemented, paving the way for new hardware support and creating a reference template for other devices.
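
To give a flavor of what device-generic tests look like, the sketch below uses PyTorch’s internal device-type test helpers to parametrize a single test body over every registered backend. These helpers live under torch.testing._internal and their exact interfaces may change between releases.

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestAddGeneric(TestCase):
    def test_add(self, device):
        # 'device' is injected per registered backend (cpu, cuda, or an
        # out-of-tree device), so one test body covers all of them.
        a = torch.ones(2, 2, device=device)
        b = torch.ones(2, 2, device=device)
        self.assertEqual((a + b).cpu(), torch.full((2, 2), 2.0))

# Generates per-device test classes (e.g. TestAddGenericCPU, TestAddGenericCUDA).
instantiate_device_type_tests(TestAddGeneric, globals())

if __name__ == "__main__":
    run_tests()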

Ensuring Robust Device Integration through Automated Testing

To uphold the high standards of quality assurance in PyTorch, an independent build repository and daily continuous integration (CI) workflows have been established, focusing on smoke and integration testing.

The pytorch-integration-tests repository automates the testing of PyTorch’s device-specific functionalities, ensuring that they operate correctly and efficiently across a variety of hardware platforms (NPUs and other specialized devices). In this repository, we are working toward a fully automated system that continuously validates PyTorch’s compatibility with different hardware backends.

  • Automated Integration Tests: Run automated tests across different devices using GitHub Actions. This automation ensures that every change in the codebase is thoroughly tested against multiple hardware platforms, catching potential issues early in the development process.
  • Reusable Workflows: Workflows in this repository are modular and reusable, which streamlines the testing process. Developers can easily adapt these workflows to new devices or testing scenarios, making the system both flexible and scalable as PyTorch evolves.
  • Awareness of Out-of-Tree Devices: The repository displays the existence and behavior of all out-of-tree devices, keeping the community informed. This approach minimizes the risk of accidentally breaking downstream functionalities and provides fast feedback on changes.

Efforts to enhance multi-device integration are pivotal for its adaptability in the evolving deep learning landscape. These initiatives not only benefit current users but also lower entry barriers for new hardware vendors and developers, fostering innovation in AI and machine learning. As PyTorch continues to evolve, its commitment to flexibility, robustness, and inclusivity positions it as a leading framework capable of meeting the diverse needs of the deep learning community.

Read More

Arm Joins the PyTorch Foundation as a Premier Member

The PyTorch Foundation, a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem, is announcing today that Arm has joined as a premier member.

Arm designs a high-performance, power-efficient compute platform with unmatched scalability, supporting a vast ecosystem of developers deploying AI at the edge and in the cloud, ranging from the Arm instances offered by all major cloud service providers to smartphones, laptops, software-defined vehicles and more.

“Our continued investments in software are accelerating development and AI performance for over 20 million software developers, ensuring they can develop for Arm, on Arm,” said Alex Spinelli, VP Developer Technology at Arm. “PyTorch is a pivotal framework in advancing AI research and development. This membership demonstrates our strong commitment to open source – ensuring PyTorch just works on Arm and can leverage seamless acceleration for the most demanding AI models, now and in the future.”

Last year at the PyTorch Conference, Arm partnered with Apple, Meta and Qualcomm to release ExecuTorch, an end-to-end solution for enabling on-device inference capabilities across mobile and edge devices including wearables, embedded devices and microcontrollers.

“We’re thrilled to welcome Arm to the PyTorch Foundation. As we look to the future of AI and machine learning, the role of specialized silicon and edge devices becomes increasingly crucial. Arm’s expertise in these areas will be invaluable as we work to make PyTorch more efficient and accessible across a wider range of hardware,” said PyTorch Foundation Executive Director Matt White. “This collaboration underscores our commitment to fostering innovation and expanding PyTorch’s capabilities to meet the evolving needs of developers and researchers worldwide.”

As a premier member, Arm is granted one seat to the PyTorch Foundation Governing Board. The Board sets policy through our bylaws, mission and vision statements, describing the overarching scope of foundation initiatives, technical vision, and direction.

We’re happy to welcome Alex Spinelli, VP Developer Technology at Arm, to our board. Prior to Arm, Alex was VP of Product for Core Machine Learning at Google, where he led Google’s technology and infrastructure for building, training, and serving machine learning, including the TensorFlow stack.

To learn more about how you can be a part of the PyTorch Foundation, visit our website.

About PyTorch Foundation

The PyTorch Foundation is a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem. The PyTorch Foundation is supported by its members and leading contributors to the PyTorch open source project. The Foundation leverages resources provided by members and contributors to enable community discussions and collaboration.

About The Linux Foundation

The Linux Foundation is the world’s leading home for collaboration on open source software, hardware, standards, and data. Linux Foundation projects are critical to the world’s infrastructure including Linux, Kubernetes, Node.js, ONAP, PyTorch, RISC-V, SPDX, OpenChain, and more. The Linux Foundation focuses on leveraging best practices and addressing the needs of contributors, users, and solution providers to create sustainable models for open collaboration. For more information, please visit us at linuxfoundation.org. The Linux Foundation has registered trademarks and uses trademarks. For a list of trademarks of The Linux Foundation, please see its trademark usage page. Linux is a registered trademark of Linus Torvalds.

Read More

PyTorch Shanghai Meetup Notes

Summary

group photo

We were honored to host the PyTorch Shanghai Meetup on August 15, 2024. The Meetup received great attention from the industry. We invited senior PyTorch developers from Intel and Huawei as guest speakers, who shared their valuable experience and the latest technical trends. In addition, the event attracted PyTorch enthusiasts from many technology companies and well-known universities. A total of more than 40 participants gathered to discuss and exchange the latest applications and technological advances of PyTorch.

This Meetup not only strengthened the connection between PyTorch community members, but also provided a platform for local AI technology enthusiasts to learn, communicate and grow. We look forward to the next gathering to continue to promote the development of PyTorch technology in the local area.

1. PyTorch Foundation Updates

man instructing students

PyTorch Board member Fred Li shared the latest updates in the PyTorch community. He reviewed the development history of the PyTorch community, explained in detail the growth path of community developers, encouraged everyone to delve deeper into technology, and introduced matters related to the upcoming PyTorch Conference 2024.

2. Intel’s Journey with PyTorch: Democratizing AI with Ubiquitous Hardware and Open Software

PyTorch CPU module maintainer Jiong Gong shared Intel’s six years of technical contributions to PyTorch and its ecosystem, and explored the remarkable advancements Intel has made in both software and hardware to democratize AI, ensuring accessibility and optimizing performance across a diverse range of Intel hardware platforms.

man instructing students

3. Exploring Multi-Backend Support in PyTorch Ecosystem: A Case Study of Ascend

man instructing students

Fengchun Hua, a PyTorch contributor from Huawei, took Huawei Ascend NPU as an example to demonstrate the latest achievements in multi-backend support for PyTorch applications. He introduced the hardware features of Huawei Ascend NPU and the infrastructure of CANN (Compute Architecture for Neural Networks), and explained the key achievements and innovations in native support work. He also shared the current challenges and the next work plan.

Yuanhao Ji, another PyTorch contributor from Huawei, then introduced the Autoload Device Extension proposal, explained its implementation details and value in improving the scalability of PyTorch, and introduced the latest work progress of the PyTorch Chinese community.

4. Intel XPU Backend for Inductor

man instructing students

Eikan is a PyTorch contributor from Intel. He focuses on torch.compile stack for both Intel CPU and GPU. In this session, Eikan presented Intel’s efforts on torch.compile for Intel GPUs. He provided updates on the current status of Intel GPUs within PyTorch, covering both functionality and performance aspects. Additionally, Eikan used Intel GPU as a case study to demonstrate how to integrate a new backend into the Inductor using Triton.

5. PyTorch PrivateUse1 Evolution Approaches and Insights

man instructing students

Jiawei Li, a PyTorch collaborator from Huawei, introduced PyTorch’s Dispatch mechanism and emphasized the limitations of DispatchKey. He took the Huawei Ascend NPU as an example to share best practices for the PyTorch PrivateUse1 mechanism. He mentioned that while using the PrivateUse1 mechanism, Huawei also submitted many improvements and bug fixes for the mechanism to the PyTorch community. He also noted that, due to the lack of upstream CI support for out-of-tree devices, changes in upstream code may affect their stability and quality, an insight that was recognized by everyone.

Read More

CUDA-Free Inference for LLMs

In this blog, we discuss the methods we used to achieve FP16 inference with popular LLM models such as Meta’s Llama3-8B and IBM’s Granite-8B Code, where 100% of the computation is performed using OpenAI’s Triton Language.
For single token generation times using our Triton kernel based models, we were able to approach 0.76-0.78x performance relative to the CUDA kernel dominant workflows for both Llama and Granite on Nvidia H100 GPUs, and 0.62-0.82x on Nvidia A100 GPUs.

Why explore using 100% Triton? Triton provides a path for enabling LLMs to run on different types of GPUs – NVIDIA, AMD, and in the future Intel and other GPU based accelerators. It also provides a higher layer of abstraction in Python for programming GPUs and has allowed us to write performant kernels faster than authoring them using vendor specific APIs. In the rest of this blog, we will share how we achieve CUDA-free compute, micro-benchmark individual kernels for comparison, and discuss how we can further improve future Triton kernels to close the gaps.

Figure 1. Inference throughput benchmarks with Triton and CUDA variants of Llama3-8B and Granite-8B, on NVIDIA H100 and A100
Settings: batch size = 2, input sequence length = 512, output sequence length = 256

2.0 Composition of a Transformer Block

We start with a breakdown of the computations that happen in Transformer-based models. The figure below shows the “kernels” of a typical Transformer block.

Figure 2. Transformer Block by core kernels

The core operations for a Llama3 architecture are summarized in this list:

  1. RMSNorm
  2. Matrix multiplication: Fused QKV
  3. RoPE
  4. Attention
  5. Matrix multiplication: Output Projection
  6. RMSNorm
  7. Matrix multiplication: Fused Gate + Up Projection
  8. Activation function: SiLU
  9. Element Wise Multiplication
  10. Matrix multiplication: Down Projection

Each of these operations is computed on the GPU through the execution of one (or multiple) kernels. While the specifics of each of these kernels can vary across different transformer models, the core operations remain the same. For example, IBM’s Granite 8B Code model uses bias in the MLP layer, different from Llama3. Such changes do require modifications to the kernels. A typical model is a stack of these transformer blocks wired together with embedding layers.

3.0 Model Inference

Typical model architecture code is shared in a Python model.py file that is launched by PyTorch. In the default PyTorch eager execution mode, these kernels are all executed with CUDA. To achieve 100% Triton for end-to-end Llama3-8B and Granite-8B inference, we need to write and integrate handwritten Triton kernels as well as leverage torch.compile (to generate Triton ops). First, we replace smaller ops with compiler-generated Triton kernels, and second, we replace more expensive and complex computations (e.g. matrix multiplication and flash attention) with handwritten Triton kernels.

Torch.compile generates Triton kernels automatically for RMSNorm, RoPE, SiLU and Element Wise Multiplication. Using tools like Nsight Systems we can observe these generated kernels; they appear as tiny dark green kernels in-between the matrix multiplications and attention.

Figure 3. Trace of Llama3-8B with torch.compile, showing CUDA kernels being used for matrix multiplications and flash attention

For the above trace, we note that the two major ops that make up 80% of the E2E latency in a Llama3-8B style model are matrix multiplication and attention kernels and both remain CUDA kernels. Thus to close the remaining gap, we replace both matmul and attention kernels with handwritten Triton kernels.

4.0 Triton SplitK GEMM Kernel

For the matrix multiplications in the linear layers, we wrote a custom FP16 Triton GEMM (General Matrix-Matrix Multiply) kernel that leverages a SplitK work decomposition. We have previously discussed this parallelization in other blogs as a way to accelerate the decoding portion of LLM inference.

5.0 GEMM Kernel Tuning

To achieve optimal performance we used the exhaustive search approach to tune our SplitK GEMM kernel. Granite-8B and Llama3-8B have linear layers with the following shapes:

Linear Layer                  Shape (in_features, out_features)
Fused QKV Projection          (4096, 6144)
Output Projection             (4096, 4096)
Fused Gate + Up Projection    (4096, 28672)
Down Projection               (14336, 4096)

Figure 4. Granite-8B and Llama3-8B Linear Layer Weight Matrix Shapes

Each of these linear layers have different weight matrix shapes. Thus, for optimal performance the Triton kernel must be tuned for each of these shape profiles. After tuning for each linear layer we were able to achieve 1.20x E2E speedup on Llama3-8B and Granite-8B over the untuned Triton kernel.

6.0 Flash Attention Kernel

We evaluated a suite of existing Triton flash attention kernels with different configurations, namely:

  1. AMD Flash
  2. OpenAI Flash
  3. Dao AI Lab Flash
  4. XFormers Flash
  5. PyTorch FlexAttention

We evaluated the text generation quality of each of these kernels, first, in eager mode and then (if we were able to torch.compile the kernel with standard methods) compile mode. For kernels 2-5, we noted the following:

  • AMD Flash: Coherent text generation; torch.compile support: Yes; arbitrary sequence length: Yes
  • OpenAI Flash: Incoherent text generation; torch.compile support: Did not evaluate (WIP to debug precision in eager mode first); arbitrary sequence length: No
  • Dao AI Lab Flash: Incoherent text generation; torch.compile support: Did not evaluate (WIP to debug precision in eager mode first); arbitrary sequence length: Yes
  • Xformers FlashDecoding: Hit a compilation error before we were able to evaluate text quality; torch.compile support: WIP; arbitrary sequence length: No (this kernel is optimized for decoding)
  • PyTorch FlexAttention: Coherent text generation; torch.compile support: WIP; arbitrary sequence length: WIP

Figure 5. Table of combinations we tried with different Flash Attention Kernels

The above table summarizes what we observed out-of-the box. With some effort we expect that kernels 2-5 can be modified to meet the above criteria. However, this also shows that having a kernel that works for benchmarking is often only the start of having it usable as an end to end production kernel.
We chose to use the AMD flash attention kernel in our subsequent tests as it can be compiled via torch.compile and produces legible output in both eager and compiled mode.

To satisfy torch.compile compatibility with the AMD flash attention kernel, we had to define it as a torch custom operator. This process is explained in detail here. The tutorial link discusses how to wrap a simple image crop operation. However, we note that wrapping a more complex flash attention kernel follows a similar process. The two step approach is as follows:

  1. Wrap the function into a PyTorch Custom Operator.

  2. Add a FakeTensor Kernel to the operator, which, given the shapes of the input tensors of flash (q, k, and v), provides a way to compute the output shape of the flash kernel.

After defining the Triton flash kernel as a custom op, we were able to successfully compile it for our E2E runs.
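
As a rough illustration of that two-step recipe, the sketch below wraps an attention function as a custom operator and registers a FakeTensor kernel for it. The operator name mylib::flash_attention is a placeholder, and the body falls back to PyTorch’s built-in SDPA purely for illustration; the real setup launches the handwritten Triton flash kernel instead.

import torch
import torch.nn.functional as F
from torch.library import custom_op, register_fake

# Step 1: wrap the kernel as a custom op. In the real setup this body would
# launch the handwritten Triton flash attention kernel.
@custom_op("mylib::flash_attention", mutates_args=())
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Step 2: the FakeTensor kernel tells torch.compile the output shape/dtype
# without running the kernel.
@register_fake("mylib::flash_attention")
def _(q, k, v):
    return torch.empty_like(q)

# The wrapped op can now be traced and compiled end to end.
compiled_attn = torch.compile(lambda q, k, v: torch.ops.mylib.flash_attention(q, k, v))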

Figure 6. Trace of Llama3-8B with torch.compile, after swapping in Triton matmul and Triton flash attention kernels

From Figure 6, we note that now, after integrating both the SplitK matrix multiplication kernel and the torch op wrapped flash attention kernel, and then running torch.compile, we are able to achieve a forward pass that uses 100% Triton computation kernels.

7.0 End-to-End Benchmarks

We performed end-to-end measurements on NVIDIA H100s and A100s (single GPU) with Granite-8B and Llama3-8B models. We performed our benchmarks with two different configurations.

The Triton kernel configuration uses:

  1. Triton SplitK GEMM
  2. AMD Triton Flash Attention

The CUDA Kernel configuration uses:

  1. cuBLAS GEMM
  2. cuDNN Flash Attention – Scaled Dot-Product Attention (SDPA)

We found the following throughput and inter-token latencies for both eager and torch compiled modes, with typical inference settings:

GPU    Model        Kernel Config    Median Latency (Eager) [ms/tok]    Median Latency (Compiled) [ms/tok]
H100   Granite-8B   Triton           27.42                              11.59
H100   Granite-8B   CUDA             18.84                              9.50
H100   Llama3-8B    Triton           20.36                              10.61
H100   Llama3-8B    CUDA             16.59                              8.59
A100   Granite-8B   Triton           53.44                              16.88
A100   Granite-8B   CUDA             37.13                              14.25
A100   Llama3-8B    Triton           44.44                              17.94
A100   Llama3-8B    CUDA             32.45                              12.96

Figure 7. Granite-8B and Llama3-8B Single Token Generation Latency on H100 and A100,
(batch size = 2, input sequence length = 512, output sequence length = 256)

To summarize, the Triton models can get up to 78% of the performance of the CUDA models on the H100 and up to 82% on the A100.

The performance gap can be explained by the kernel latencies we observe for matmul and flash attention, which are discussed in the next section.

8.0 Microbenchmarks

Kernel                         Triton [us]    CUDA [us]
QKV Projection Matmul          25             21
Flash Attention                13             8
Output Projection Matmul       21             17
Gate + Up Projection Matmul    84             83
Down Projection Matmul         58             42

Figure 8. Triton and CUDA Kernel Latency Comparison (Llama3-8B on NVIDIA H100)
Input was an arbitrary prompt (bs=1, prompt = 44 seq length), decoding latency time

From the above, we note the following:

  1. Triton matmul kernels are 1.2-1.4x slower than CUDA

  2. AMD’s Triton Flash Attention kernel is 1.6x slower than CUDA SDPA

These results highlight the need to further improve the performance of kernels that are core primitives like GEMM and Flash Attention. We leave this as future research, as recent works (e.g. FlashAttention-3, FlexAttention) provide ways to leverage the underlying hardware better as well as Triton pathways that we hope to be able to build on to produce greater speedups. To illustrate this, we compared FlexAttention with SDPA and AMD’s Triton Flash kernel.

We are working to verify E2E performance with FlexAttention. For now, initial microbenchmarks with Flex show promise for longer context lengths and decoding problem shapes, where the query vector is small:

Figure 9. FlexAttention Kernel Benchmarks on NVIDIA H100 SXM5 80GB
(batch=1, num_heads=32, seq_len=seq_len, head_dim=128)

9.0 Future Work

For future work we plan to explore ways to further optimize our matmuls that leverage the hardware better, such as this blog we published on utilizing TMA for H100, as well as different work decompositions (persistent kernel techniques like StreamK etc.) to get greater speedups for our Triton-based approach. For flash attention, we plan to explore FlexAttention and FlashAttention-3 as the techniques used in these kernels can be leveraged to help further close the gap between Triton and CUDA.
We also note that our prior work has shown promising results for FP8 Triton GEMM kernel performance versus cuBLAS FP8 GEMM, thus in a future post we will explore E2E FP8 LLM inference.

Read More

Accelerate Your AI: PyTorch 2.4 Now Supports Intel GPUs for Faster Workloads

We have exciting news! PyTorch 2.4 now supports the Intel® Data Center GPU Max Series and the SYCL software stack, making it easier to speed up your AI workflows for both training and inference. This update allows you to have a consistent programming experience with minimal coding effort and extends PyTorch’s device and runtime capabilities, including device, stream, event, generator, allocator, and guard, to seamlessly support streaming devices. This enhancement simplifies deploying PyTorch on ubiquitous hardware, making it easier for you to integrate different hardware back ends.

Intel GPU support upstreamed into PyTorch provides support for both eager and graph modes, fully running Dynamo Hugging Face benchmarks. Eager mode now includes common Aten operators implemented with SYCL. The most performance-critical graphs and operators are highly optimized by using oneAPI Deep Neural Network Library (oneDNN) and oneAPI Math Kernel Library (oneMKL). Graph mode (torch.compile) now has an enabled Intel GPU back end to implement the optimization for Intel GPUs and to integrate Triton. Furthermore, data types such as FP32, BF16, FP16, and automatic mixed precision (AMP) are supported. The PyTorch Profiler, based on Kineto and oneMKL, is being developed for the upcoming PyTorch 2.5 release.

Take a look at the current and planned front-end and back-end improvements for Intel GPU upstreamed into PyTorch.

the current and planned front-end and back-end improvements for Intel GPU upstreamed into PyTorch

PyTorch 2.4 on Linux supports Intel Data Center GPU Max Series for training and inference while maintaining the same user experience as other hardware. If you’re migrating code from CUDA, you can run your existing application on an Intel GPU with minimal changes—just update the device name from cuda to xpu. For example:

# CUDA Code 
tensor = torch.tensor([1.0, 2.0]).to("cuda") 
 
# Code for Intel GPU 
tensor = torch.tensor([1.0, 2.0]).to("xpu")
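
If you want the same script to run on whichever accelerator is present, one possible device-agnostic pattern (a sketch, not an official recipe) is to probe for an Intel GPU first and fall back to CUDA or CPU:

import torch

# Prefer an Intel GPU ("xpu") if this build exposes one, otherwise fall back.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tensor = torch.tensor([1.0, 2.0]).to(device)
model = torch.nn.Linear(2, 2).to(device)
print(tensor.device, next(model.parameters()).device)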

Get Started

Try PyTorch 2.4 on the Intel Data Center GPU Max Series through the Intel® Tiber™ Developer Cloud. Get a tour of the environment setup, source build, and examples. To learn how to create a free Standard account, see Get Started, then do the following:

  1. Sign in to the cloud console.

  2. From the Training section, open the PyTorch 2.4 on Intel GPUs notebook.

  3. Ensure that the PyTorch 2.4 kernel is selected for the notebook.

Summary

PyTorch 2.4 introduces initial support for Intel Data Center GPU Max Series to accelerate your AI workloads. With Intel GPU, you’ll get continuous software support, unified distribution, and synchronized release schedules for a smoother development experience. We’re enhancing this functionality to reach Beta quality in PyTorch 2.5. Planned features in 2.5 include:

  • More Aten operators and full Dynamo Torchbench and TIMM support in Eager Mode.

  • Full Dynamo Torchbench and TIMM benchmark support in torch.compile.

  • Intel GPU support in torch.profile.

  • PyPI wheels distribution.

  • Windows and Intel Client GPU Series support.

We welcome the community to evaluate these new contributions to Intel GPU support on PyTorch. 

Resources

Acknowledgments

We want to thank the PyTorch open source community for their technical discussions and insights: Nikita Shulga, Jason Ansel, Andrey Talman, Alban Desmaison, and Bin Bao.

We also thank collaborators from PyTorch for their professional support and guidance.

1 To enable GPU support and improve performance, we suggest installing the Intel® Extension for PyTorch

Read More

Enabling Fast Gradient Clipping and Ghost Clipping in Opacus

Introduction and Context

Differentially Private Stochastic Gradient Descent (DP-SGD) is the canonical method for training machine learning models with differential privacy. It involves the following two modifications to its non-private counterpart, Stochastic Gradient Descent.

  1. Per-sample gradient clipping: Clip gradients with respect to every sample in the mini-batch, ensuring that its norm is at most a pre-specified value, “Clipping Norm”, C, in every iteration.

  2. Noise addition: Add Gaussian noise of pre-specified variance, depending on the clipping norm and privacy parameters, to the average clipped gradient, in every iteration.

The first change, per-sample gradient clipping, introduces additional complexities since, in general, it requires instantiating per-sample gradients.

Opacus is a PyTorch implementation of DP-SGD. Opacus addresses the above task by employing hook functions, which allows intervening on specific events, such as forward and backward passes. For more details about Opacus, we encourage readers to review the previous blog posts: DP-SGD Algorithm Explained, Efficient Per-Sample Gradient Computation in Opacus and Efficient Per-Sample Gradient Computation for More Layers in Opacus.

While Opacus provides substantial efficiency gains compared to the naive approaches, the memory cost of instantiating per-sample gradients is significant. In particular, memory usage is proportional to the batch size times the number of trainable parameters. Consequently, memory limits Opacus to small batch sizes and/or small models, significantly restricting its range of applications.

We introduce Fast Gradient Clipping and Ghost Clipping to Opacus, which enable developers and researchers to perform gradient clipping without instantiating the per-sample gradients. As an example, this allows for fine-tuning 7M parameters of BERT on a single 16GB GPU with a batch size of 1024, with memory comparable to using PyTorch (without applying DP-SGD). In contrast, the previous version of Opacus supported a maximum batch size of roughly 256 for the same setting. We provide a tutorial on how to use Fast Gradient Clipping in Opacus with the aforementioned task as an example.

Fast Gradient Clipping and Ghost Clipping

The key idea behind these techniques is based on the following observation: suppose per-sample gradient norms are known, then gradient clipping can be achieved by backpropagation on a re-weighted loss function $\bar{L}$. This loss function is defined as $\bar{L} = \sum_{i} R_{i} L_{i}$, where $R_i = \min\left(\frac{C}{C_i}, 1\right)$ are the clipping coefficients computed from the per-sample gradient norms $\{C_i\}$ and $\{L_i\}$ are per-sample losses.

The above idea may seem circular at first glance, as it appears to require instantiating per-sample gradients in order to calculate per-sample gradient norms. However, for certain widely-used components of neural network architectures, such as fully connected/linear layers, it is indeed possible to obtain per-sample gradient norms in a single backpropagation pass without the need for per-sample gradients. This suggests a workflow that involves two backpropagation passes: the first to compute per-sample gradient norms, and the second to compute the aggregated (not per-sample) clipped gradient. The second backpropagation is simply the standard batched backpropagation.

backpropagation diagram

Figure 1: Comparison between vanilla Opacus (top left), Fast Gradient Clipping (top right), and Ghost Clipping (bottom). Gradient instantiations that become memory bottlenecks are marked in red. Vanilla Opacus has to instantiate the full per-sample gradients. Fast Gradient Clipping instantiates the per-sample gradients of each layer only to compute their norms, and releases them as soon as the backward pass moves on to the next layer. Ghost Clipping works directly from per-sample activation gradients and per-sample activations, and avoids gradient instantiation altogether.

Fast Gradient Clipping
In Fast Gradient Clipping, the per-sample gradient norm is calculated in three steps (a sketch for a linear layer follows the list):

  1. For each layer, the per-sample gradient is instantiated and its norm is calculated.
  2. The per-sample gradient is then immediately discarded.
  3. The (squared) per-sample gradient norms of each layer are summed up to obtain the overall (squared) per-sample gradient norm.
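For a linear layer, the per-sample weight gradient is the outer product of the per-sample output gradients ("backprops") and the layer inputs ("activations"), both of which are available during backpropagation. A minimal sketch of the three steps for one such layer (names are illustrative):

import torch

def fgc_squared_norms(activations, backprops):
    # activations: (batch_size, input_width)  -- layer inputs saved in the forward pass
    # backprops:   (batch_size, output_width) -- per-sample gradients w.r.t. the layer output

    # Step 1: instantiate the per-sample weight gradients (outer products)
    # and compute their squared norms.
    per_sample_grads = torch.einsum("bo,bi->boi", backprops, activations)
    sq_norms = per_sample_grads.flatten(1).pow(2).sum(dim=1)

    # Step 2: discard the per-sample gradients immediately, so only one
    # layer's gradients are alive at any time.
    del per_sample_grads

    # Step 3: the caller sums these squared norms across layers and takes the
    # square root to obtain the overall per-sample gradient norms.
    return sq_norms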

Ghost Clipping
Extending the approach of Fast Gradient Clipping, Ghost Clipping uses the fact that for linear layers¹, per-sample gradient norms can be calculated from activation gradients and activations alone. In particular, let backprops and activations be the per-sample activation gradients and activations, of dimensions batch_size ✕ output_width and batch_size ✕ input_width, respectively. The per-sample gradient is the outer product of the two, which takes O(batch_size ✕ input_width ✕ output_width) time and space to instantiate.

The ghost clipping trick instead calculates the (squared) norms of backprops and activations sample-wise and takes their product, which gives the (squared) norm of the gradient. This takes O(batch_size ✕ (input_width + output_width)) time and O(batch_size) space to store. Since the per-sample activations and per-sample activation gradients are already stored, additional memory is needed only for the norms.
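The identity behind the trick, for a linear layer with 2D (non-sequence) inputs, is that the squared Frobenius norm of an outer product factorizes: ‖g_i‖² = ‖backprops_i‖² ✕ ‖activations_i‖². A minimal sketch (bias term omitted; names are illustrative):

import torch

def ghost_squared_norms(activations, backprops):
    # activations: (batch_size, input_width), backprops: (batch_size, output_width)
    act_sq = activations.pow(2).sum(dim=1)  # per-sample squared activation norms
    bp_sq = backprops.pow(2).sum(dim=1)     # per-sample squared backprop norms
    # Per-sample squared weight-gradient norms, without forming the outer product.
    return act_sq * bp_sq

Compared to the Fast Gradient Clipping sketch above, the largest intermediate here is a pair of length-batch_size vectors rather than a batch_size ✕ output_width ✕ input_width tensor.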

Relationship between Fast Gradient Clipping and Ghost Clipping

  1. Fast Gradient Clipping and Ghost Clipping are complementary techniques. Fast Gradient Clipping can be applied to any type of layer, while Ghost Clipping is a strictly better technique for supported layers.
  2. Our implementation automatically switches to Fast Gradient Clipping when the layer is not supported by Ghost Clipping.

How to use Fast Gradient Clipping in Opacus

The training loop is identical to that of a standard PyTorch loop. As before, we use PrivacyEngine(), which “sanitizes” the model and optimizer. To enable Ghost Clipping, pass the argument grad_sample_mode="ghost". Additionally, make_private() takes the loss criterion as an extra input and sanitizes it, which lets us hide the two backward passes, and the loss rescaling in between, inside loss.backward().

import torch.nn as nn
from opacus import PrivacyEngine

criterion = nn.CrossEntropyLoss()  # example loss function

privacy_engine = PrivacyEngine()
model_gc, optimizer_gc, criterion_gc, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
        criterion=criterion,
        grad_sample_mode="ghost",
)

# The training loop below is identical to that of PyTorch

for input_data, target_data in train_loader:
    output_gc = model_gc(input_data) # Forward pass
    optimizer_gc.zero_grad()
    loss = criterion_gc(output_gc, target_data)
    loss.backward()
    optimizer_gc.step()  # Add noise and update the model

Internally, before the first pass, we enable the hooks, which allows us to capture layer-wise values corresponding to forward and backward calls. They are used to compute the per-sample gradient norms. We then compute the clipping coefficients, rescale the loss function and disable hooks, which lets us use the standard PyTorch backward pass.

Memory Complexity Analysis

Consider a multi-layer neural network with the following properties:

  • L: Number of layers
  • d: Maximum layer width
  • B: Batch size
  • K: Number of non-supported/non-linear layers

The memory overhead of DP-SGD with Ghost Clipping compared to plain (PyTorch) SGD is an additive O(BL), needed to store the per-sample gradient norms for all layers. Further, if at least one layer is not supported (K ≥ 1), an additional O(Bd²) of memory is needed to instantiate that layer's per-sample gradients.
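For illustration (a back-of-the-envelope estimate, not a measured figure): with B = 1024 and L = 100 supported layers, the per-sample norms amount to about 10⁵ scalars, well under 1 MB in fp32, whereas a single non-supported layer of width d = 1024 would add roughly B ✕ d² ≈ 10⁹ values, on the order of 4 GB in fp32.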

Memory Benchmarking

We provide results on the memory usage for a variety of settings.

Fine-Tuning BERT

We consider the problem of privately fine-tuning the last three layers of BERT for a text classification task. The base model has over 100M parameters, of which we fine-tune the last three layers, BertEncoder, BertPooler, and Classifier, comprising roughly 7.6M parameters. The experiments are run on a P100 GPU with 16 GB of memory.

The following table reports the maximum memory and time taken per iteration for the various methods:

(Each cell shows peak memory / time per iteration.)

| Method | B = 32 | B = 128 | B = 512 | B = 1024 | B = 2048 |
|---|---|---|---|---|---|
| PyTorch SGD | 236 MB / 0.15 s | 1.04 GB / 0.55 s | 5.27 GB / 2.1 s | 12.7 GB / 4.2 s | OOM |
| DP-SGD | 1,142 MB / 0.21 s | 4.55 GB / 0.68 s | OOM | OOM | OOM |
| FGC DP-SGD | 908 MB / 0.21 s | 3.6 GB / 0.75 s | OOM | OOM | OOM |
| GC DP-SGD | 362 MB / 0.21 s | 1.32 GB / 0.67 s | 5.27 GB / 2.5 s | 12.7 GB / 5 s | OOM |

In terms of peak memory footprint, DP-SGD > FGC DP-SGD ≫ GC DP-SGD ≈ PyTorch SGD. Further, the runtimes are similar because most of the parameters are frozen and the forward pass takes up most of the time.

Synthetic Setup: Memory Profiling

We consider the following setup to profile the memory used by PyTorch SGD, vanilla DP-SGD, and Ghost Clipping DP-SGD (GC DP-SGD).

  • 2-layer fully connected neural network
    • Input: 5120
    • Hidden: 2560
    • Output: 1280
    • Total number of model parameters = 15.6M
    • Model size = 62.5 MB
  • Batch size: varied, as shown in the table below.

The table below summarizes the max memory increase (in MB) broken down by stages of the training loop for each of the methods.

| Batch Size | Method | Model to GPU | Forward | First Backward | Second Backward | Optimizer Step |
|---|---|---|---|---|---|---|
| 32 | PyTorch SGD | 62.5 | 0.5 | 62.5 | N/A | 0 |
| 32 | Vanilla DP-SGD | 62.5 | 0.47 | 3,663 | N/A | 162.5 |
| 32 | GC DP-SGD | 62.5 | 0.47 | 63.13 | 50 | 125 |
| 2¹⁷ | PyTorch SGD | 62.5 | 1,920 | 1,932.5 | N/A | 0 |
| 2¹⁷ | Vanilla DP-SGD | OOM | — | — | — | — |
| 2¹⁷ | GC DP-SGD | 62.5 | 1,920 | 2,625 | 1,932.5 | 125 |

Industry use case

We tested Ghost Clipping DP-SGD on an internal Meta use case: a model with roughly 100B parameters, of which 40M are trainable. Our initial results show that Ghost Clipping DP-SGD reduces memory usage by 95% compared to vanilla DP-SGD and achieves memory usage comparable to PyTorch SGD.

Conclusion

In this post, we describe implementations of Fast Gradient Clipping and Ghost Clipping in Opacus that enable memory-efficient training of machine learning models with differential privacy. Currently, the Ghost Clipping implementation only applies to linear layers, but, as outlined in part 3 of the series, it can be extended to “generalized” linear layers such as convolutions and multi-head attention. The current techniques require two explicit backpropagation steps, which increases runtime. We will explore developments on top of Ghost Clipping, such as the Book-Keeping algorithm, to mitigate this overhead.

To learn more about Opacus, visit opacus.ai and github.com/pytorch/opacus.

Acknowledgements

We thank Iden Kalemaj, Darren Liu, Karthik Prasad, Hao Shi, Igor Shilov, Davide Testuggine, Eli Uriegas, Haicheng Wang, and Richard Zou for valuable feedback and suggestions.

  1. There are ways to extend Ghost Clipping to non-linear layers. 

Read More