OpenReg: A Self-Contained PyTorch Out-of-Tree Backend Implementation Using “PrivateUse1” Mechanism

OpenReg is a self-contained demonstration of a PyTorch out-of-tree backend implementation utilizing the core framework’s “PrivateUse1” mechanism. This implementation serves two primary purposes:

  1. Reference Implementation: Provides a practical template for third-party device vendors integrating with PyTorch through PrivateUse1.
  2. CI Testing Infrastructure: Enables device-agnostic testing capabilities for continuous integration pipelines.

Usage

Module Installation

cd {project}/test/cpp_extensions/open_registration_extension
python setup.py install

Use Case

import torch
import pytorch_openreg

if __name__ == "__main__":
    print(torch.ones(1, 2, device='openreg'))

Architectural Overview

Process Management

OpenReg implements virtual device isolation by spawning N independent subprocesses, each maintaining dedicated request/response queues for inter-process communication. The parent process driver encapsulates device operations into command packets that are:

  1. Dispatched to target devices via request queues
  2. Processed asynchronously with results returned through response queues

Figure: Parent-Subprocess Communication Flow
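
The request/response pattern described above can be sketched in a few lines. This is only an illustration under stated assumptions: OpenReg uses subprocesses, whereas this sketch uses threads and `queue.Queue` as a stand-in so that it stays short and self-contained. `Driver`, `device_worker`, and the `add` command are hypothetical names, not part of OpenReg.

```python
import queue
import threading


def device_worker(requests: queue.Queue, responses: queue.Queue) -> None:
    """Stand-in for one virtual device: serve command packets until told to stop."""
    while True:
        cmd, args = requests.get()
        if cmd == "stop":
            break
        if cmd == "add":  # hypothetical command used only for this sketch
            responses.put(("ok", args[0] + args[1]))
        else:
            responses.put(("error", f"unknown command: {cmd}"))


class Driver:
    """Parent-side driver holding one request/response queue pair per device."""

    def __init__(self, num_devices: int) -> None:
        self.channels = []
        self.workers = []
        for _ in range(num_devices):
            req, resp = queue.Queue(), queue.Queue()
            worker = threading.Thread(
                target=device_worker, args=(req, resp), daemon=True
            )
            worker.start()
            self.channels.append((req, resp))
            self.workers.append(worker)

    def run(self, device: int, cmd: str, *args):
        """Send one command packet and block until that device replies."""
        req, resp = self.channels[device]
        req.put((cmd, args))
        status, payload = resp.get()
        if status != "ok":
            raise RuntimeError(payload)
        return payload

    def shutdown(self) -> None:
        """Ask every worker to stop and wait for it to exit."""
        for (req, _), worker in zip(self.channels, self.workers):
            req.put(("stop", ()))
            worker.join()


driver = Driver(num_devices=2)
print(driver.run(1, "add", 2, 3))
driver.shutdown()
```

Giving each device its own queue pair means a reply can never be delivered to the wrong device, at the cost of one blocking round trip per command.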

Memory Management

Device memory allocations occur within individual subprocesses to ensure:

  1. Strict memory isolation between devices
  2. Realistic simulation of physical device constraints

Component Breakdown

_aten_impl.py

This module handles dual responsibilities:

  1. Hook Registration:
    • Utilizes _IMPL_REGISTRY to bind C++ backend hooks (e.g., getDevice, getStream) to device driver implementations
  2. Fallback Mechanism:
    • Defines a new torch.Library that registers a fallback, which is called whenever a backend kernel for PrivateUse1 is invoked. The fallback contains the logic to handle all kinds of native functions: it computes the output metadata, allocates the output, and calls into the device daemon only to perform the computation

_device_daemon.py

Core Subsystems

  1. Allocators:
    • HostAllocator: Manages pinned memory in parent process
    • DeviceAllocator: Handles device memory with tensor reconstruction capabilities
  2. Driver (Parent Process):
    • Maintains device context (active device/streams)
    • Implements device control operations:
      • setDevice/getDevice
      • deviceCount
      • exchangeStream
    • Orchestrates command execution through queue-based IPC
  3. Executor (Subprocess):
    • Processes command types:
      • Memory operations (malloc/free)
      • Tensor computations (run_op)
      • Data transfers (send_data/recv_data)
      • Stream/event management (primarily no-op due to CPU sync nature)
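
To make the executor's role more concrete, here is a minimal single-process sketch of a command table for the memory and data-transfer commands listed above. The command names mirror the list, but the `Executor` class, its integer handles, and the `bytearray` storage are assumptions made for illustration, not OpenReg's actual implementation.

```python
import itertools


class Executor:
    """Hypothetical device-side executor that owns all of its device's memory."""

    def __init__(self) -> None:
        self._buffers = {}  # handle -> bytearray standing in for device memory
        self._next_handle = itertools.count(1)

    def malloc(self, size: int) -> int:
        """Allocate `size` zeroed bytes and return an opaque handle."""
        handle = next(self._next_handle)
        self._buffers[handle] = bytearray(size)
        return handle

    def free(self, handle: int) -> None:
        """Release an allocation; raises KeyError for an unknown handle."""
        del self._buffers[handle]

    def send_data(self, handle: int, data: bytes) -> None:
        """Host-to-device copy into the start of an existing allocation."""
        buf = self._buffers[handle]
        if len(data) > len(buf):
            raise ValueError("data does not fit in the allocation")
        buf[: len(data)] = data

    def recv_data(self, handle: int) -> bytes:
        """Device-to-host copy of the whole allocation."""
        return bytes(self._buffers[handle])

    def handle_command(self, cmd: str, *args):
        """Dispatch one command packet to the matching handler."""
        handlers = {
            "malloc": self.malloc,
            "free": self.free,
            "send_data": self.send_data,
            "recv_data": self.recv_data,
        }
        if cmd not in handlers:
            raise ValueError(f"unknown command: {cmd}")
        return handlers[cmd](*args)
```

Because the parent only ever sees handles, all reads and writes have to go through `send_data` and `recv_data`, which is one way to model the per-device isolation described in this section.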

_meta_parser.py

Key Features:

  • Implements serialization utilities for cross-process object transfer
  • OpenRegTensorMeta class encapsulates complete tensor metadata for:
    • Output tensor reconstruction
    • Device-side computation preparation
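
As a rough illustration of the metadata that has to cross the process boundary, the sketch below round-trips a small record through JSON. The `TensorMeta` class and its fields are assumptions chosen for this example and are not the actual fields of OpenRegTensorMeta. JSON is used here only because it is simple to inspect.

```python
import json
from dataclasses import asdict, dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical metadata record; not the real OpenRegTensorMeta layout."""

    data_ptr: int  # opaque handle to the device allocation
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    storage_offset: int
    dtype: str
    nbytes: int


def contiguous_strides(shape):
    """Row-major strides (in elements) for a tensor of the given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def to_wire(meta: TensorMeta) -> str:
    """Serialize the metadata so it can be sent to another process."""
    return json.dumps(asdict(meta))


def from_wire(payload: str) -> TensorMeta:
    """Rebuild the metadata on the receiving side."""
    fields = json.loads(payload)
    fields["shape"] = tuple(fields["shape"])  # JSON turns tuples into lists
    fields["stride"] = tuple(fields["stride"])
    return TensorMeta(**fields)
```

Only the metadata travels with each command. The tensor data itself stays in the device process and is referenced through the handle.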

Design Considerations

Execution Characteristics

  • Synchronous Computation: CPU operator execution necessitates synchronous processing
  • Stream/Event Semantics: Implemented as no-ops due to synchronous execution model
  • Memory Isolation: Strict per-device memory boundaries enforced through subprocess allocation

This architecture enables realistic simulation of device integration while maintaining PyTorch compatibility through standard backend interfaces.

SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine

We’re thrilled to announce that the SGLang project has been integrated into the PyTorch ecosystem! This integration ensures that SGLang aligns with PyTorch’s standards and practices, providing developers with a reliable and community-supported framework for fast and flexible serving of LLMs.

To view the PyTorch Ecosystem, see the PyTorch Landscape and learn more about how projects can join the PyTorch Ecosystem.

About SGLang

SGLang is a fast-serving engine for large language models and vision language models. It makes the interaction with models faster and more controllable by co-designing the backend runtime and frontend language.

The core features include:

  • Fast Backend Runtime: Provides efficient serving with RadixAttention for prefix caching, zero-overhead CPU scheduler, continuous batching, token attention (paged attention), speculative decoding, tensor parallelism, chunked prefill, structured outputs, and quantization (FP8/INT4/AWQ/GPTQ).
  • Flexible Frontend Language: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
  • Extensive Model Support: Supports a wide range of generative models (Llama, Gemma, Mistral, Qwen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
  • Active Community: SGLang is open source and backed by an active community with industry adoption.

SGLang is known for its speed: it can often significantly outperform other state-of-the-art frameworks in serving throughput and latency. You can learn more about the underlying techniques from the past release blog posts: v0.2 blog, v0.3 blog, v0.4 blog.

SGLang has been widely adopted by leading industry companies and frontier research labs. For example, xAI uses SGLang to serve its flagship model, Grok 3, which is currently the best model according to the Chatbot Arena leaderboard. Microsoft Azure uses SGLang to serve DeepSeek R1 on AMD GPUs, which is currently the best open source model.

Serving DeepSeek Models

You can easily launch a Docker container to serve a DeepSeek model with the following command:

# Pull the latest image
docker pull lmsysorg/sglang:latest

# Launch a server
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host --network=host --privileged lmsysorg/sglang:latest \
    python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000

Then you can query the server with the OpenAI-compatible API:

import openai
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

The server launch command above works for 8xH200. You can find detailed instructions for other hardware (MI300X, H100, A100, H20, L40S) at https://docs.sglang.ai/references/deepseek.html.

SGLang integrates DeepSeek-specific optimizations, such as MLA throughput optimizations, MLA-optimized kernels, data-parallel attention, multi-token prediction, and DeepGemm, making it the top choice for serving DeepSeek models by dozens of companies, including AMD, NVIDIA, and many cloud providers. The team is actively working on integrating more optimizations following the 2025 H1 roadmap below.

Serving Llama Models

Similarly, you can launch the server for a Llama 3.1 text model with:

python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct

Or a Llama 3.2 multimodal model with:

python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct --chat-template=llama_3_vision

Roadmap

This year, the SGLang team will continue to push the boundaries of system efficiency. You can find the 2025 H1 roadmap here. The focus areas are:

  • Throughput-oriented large-scale deployment similar to the DeepSeek inference system
  • Long context optimizations
  • Low latency speculative decoding
  • Reinforcement learning training framework integration
  • Kernel optimizations

Community

SGLang has been deployed to large-scale production, generating trillions of tokens every day. It has an active community with over three hundred contributors on GitHub. It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, iFlytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI.

Conclusion

We’re excited to welcome SGLang to the PyTorch ecosystem. SGLang accelerates the serving of large language and vision language models. It’s widely adopted by industry, powering the large-scale online serving of frontier models like Grok and DeepSeek.

We invite you to explore the SGLang GitHub repo, join the community on Slack, and reach out to contact@sglang.ai for inquiries or collaboration opportunities. Together, we can make powerful AI models accessible to everyone.

PyTorch Day China 2025 Call for Proposals Open

We’re excited to announce the first-ever PyTorch Day China! This new event, hosted by the PyTorch Foundation, will take place on June 7 in Beijing, China, bringing together AI practitioners, researchers, and industry professionals to explore the latest advancements in open source AI and machine learning. Co-located with the BAAI Conference, PyTorch Day China is a chance to connect with the community, share knowledge, and help shape the future of deep learning.

Why Submit a Proposal?

PyTorch Day China offers a platform for AI practitioners and researchers to showcase their work, exchange ideas, and connect with others in the community. If you’re working on innovative applications, tools, or research in the PyTorch ecosystem, we encourage you to share your expertise.

Topics for Submission:

  • AI Applications and Use Cases
  • Core PyTorch Framework
  • DL Compilers and Kernel Authoring
  • Edge AI and On-Device
  • Ethical AI, Governance, and Regulation
  • Generative AI and Large Language Models (LLMs) with PyTorch
  • Open Source Collaboration, Education, and Community Building
  • Optimization for Training and Inference
  • PyTorch on Accelerator Hardware
  • PyTorch Ecosystem and Tools
  • PyTorch in Research and Academia
  • Performance Measurement and Benchmarking
  • Scaling Training and Inference

The submission deadline is April 13. Submit and learn more here: https://www.lfasiallc.com/pytorch-day-china/call-for-proposals-cfp/

Why Attend?

PyTorch Day China will feature technical talks, discussions, and poster sessions that highlight real-world applications and developments in AI and machine learning. Attendees will have the opportunity to learn from experts, contribute to the open source community, and engage with fellow PyTorch users. Registration information will be available in April.

Event Details

  • Date: June 7, 2025
  • Location: Zhongguancun Exhibition Center, Beijing, China
  • Address: 索家坟, Hai Dian Qu, Bei Jing Shi, China, 100080
  • Co-located with: BAAI Conference

Travel Information

The venue, Zhongguancun Exhibition Center, is approximately 39 km from Beijing International Airport. More details on travel and accommodation will be available on the BAAI Conference website and updated here as they become available.

Have Questions?

For inquiries, please contact pytorchevents@linuxfoundation.org.

Submit your proposal by April 13 and join the conversation shaping the future of PyTorch.

PyTorch at GTC 2025

GTC is coming back to San Jose on March 17–21, 2025. Join PyTorch Foundation members Arm, AWS, Google Cloud, IBM, Lightning AI, Meta, Microsoft Azure, Snowflake, and thousands of developers as we celebrate PyTorch. Together learn how AI & accelerated computing are helping humanity solve our most complex challenges.

Join in person with discounted GTC registration for PyTorch Foundation or watch online with free registration.

Scaling Open Source AI: From Foundation Models to Ecosystem Success

Hear from PyTorch Foundation’s Executive Director Matt White and panelists from UC Berkeley, Meta, NVIDIA, and Sequoia Capital on how open source is transforming AI development. The panel brings together experts from industry, academia, and venture capital to discuss the technical and business aspects of collaborative open source AI development. They’ll examine how open source projects like PyTorch, vLLM, Ray, and NVIDIA’s NeMo are accelerating AI innovation while creating new opportunities for businesses and researchers. They’ll share real-world experiences from PyTorch’s development, Berkeley’s research initiatives, and successful AI startups. Take away valuable insights into the technical and business aspects of open source AI. – Monday, Mar 17 10:00 AM – 11:00 AM PDT

PyTorch @ GTC

The Performance of CUDA with the Flexibility of PyTorch
Mark Saroufim, Software Engineer, Meta Platforms

This talk explores how PyTorch users are also becoming CUDA developers. We’ll start with motivating examples from eager, the launch of torch.compile and the more recent trend of kernel zoos. We will share details on how we went about integrating low bit matmuls in torchao and the torch.compile CUTLASS backend. We’ll also discuss details on how you can define, build and package your own custom ops in PyTorch so you get the raw performance of CUDA while maintaining the flexibility of PyTorch.

Make My PyTorch Model Fast, and Show Me How You Did It
Thomas Viehmann, Principal Research Engineer, Lightning AI
Luca Antiga, CTO, Lightning AI

PyTorch is popular in deep learning and LLMs for richness and ease of expressions. To make the most of compute resources, PyTorch models benefit from nontrivial optimizations, but this means losing some of their ease and understandability. Learn how with Thunder, a PyTorch-to-Python compiler focused on usability, understandability, and extensibility, you can optimize and transform (i.e., distribute across many machines) models while:

  • leaving the PyTorch code unchanged
  • targeting a variety of models without needing to adapt to each of them
  • understanding each transformation step because the results are presented as simple Python code
  • accessing powerful extension code for your own optimizations with just one or a few lines of code

We’ll show how the combination of Thunder transforms and the NVIDIA stack (NVFuser, cuDNN, Apex) delivers optimized performance in training and inference on a variety of models.

FlexAttention: The Flexibility of PyTorch With the Performance of FlashAttention
Driss Guessous, Machine Learning Engineer, Meta Platforms

Introducing FlexAttention: a novel PyTorch API that enables custom, user-defined attention mechanisms with performance comparable to state-of-the-art solutions. By leveraging the PyTorch compiler stack, FlexAttention supports dynamic modifications to attention scores within SDPA, achieving both runtime and memory efficiency through kernel fusion with the FlashAttention algorithm. Our benchmarks on A100 GPUs show FlexAttention achieves 90% of FlashAttention2’s performance in forward passes and 85% in backward passes. On H100 GPUs, FlexAttention’s forward performance averages 85% of FlashAttention3 and is ~25% faster than FlashAttention2, while backward performance averages 76% of FlashAttention3 and is ~3% faster than FlashAttention2. Explore how FlexAttention balances near-state-of-the-art performance with unparalleled flexibility, empowering researchers to rapidly iterate on attention mechanisms without sacrificing efficiency.

Keep Your GPUs Going Brrr : Crushing Whitespace in Model Training
Syed Ahmed, Senior Software Engineer, NVIDIA
Alban Desmaison, Research Engineer, Meta
Aidyn Aitzhan, Senior Software Engineer, NVIDIA

Substantial progress has recently been made on the compute-intensive portions of model training, such as high-performing attention variants. While invaluable, this progress exposes previously hidden bottlenecks in model training, such as redundant copies during collectives and data loading time. We’ll present recent improvements in PyTorch achieved through Meta/NVIDIA collaboration to tackle these newly exposed bottlenecks and how practitioners can leverage them.

Accelerated Python: The Community and Ecosystem
Andy Terrel, CUDA Python Product Lead, NVIDIA
Jeremy Tanner, Open Source Programs, NVIDIA
Anshuman Bhat, CUDA Product Management, NVIDIA

Python is everywhere. Simulation, data science, and Gen AI all depend on it. Unfortunately, the dizzying array of tools leaves a newcomer baffled at where to start. We’ll take you on a guided tour of the vibrant community and ecosystem surrounding accelerated Python programming. Explore a variety of tools, libraries, and frameworks that enable efficient computation and performance optimization in Python, including CUDA Python, RAPIDS, Warp, and Legate. We’ll also discuss integration points with PyData, PyTorch, and JAX communities. Learn about collaborative efforts within the community, including open source projects and contributions that drive innovation in accelerated computing. We’ll discuss best practices for leveraging these frameworks to enhance productivity in developing AI-driven applications and conducting large-scale data analyses.

Supercharge large scale AI with Google Cloud AI hypercomputer (Presented by Google Cloud)
Deepak Patil, Product Manager, Google Cloud
Rajesh Anantharaman, Product Management Lead, ML Software, Google Cloud

Unlock the potential of your large-scale AI workloads with Google Cloud AI Hypercomputer – a supercomputing architecture designed for maximum performance and efficiency. In this session, we will deep dive into PyTorch and JAX stacks on Google Cloud on NVIDIA GPUs, and showcase capabilities for high performance foundation model building on Google Cloud.

Peering Into the Future: What AI and Graph Networks Can Mean for the Future of Financial Analysis
Siddharth Samsi, Sr. Solutions Architect, NVIDIA
Sudeep Kesh, Chief Innovation Officer, S&P Global

Artificial Intelligence, agentic systems, and graph neural networks (GNNs) are providing the new frontier to assess, monitor, and estimate opportunities and risks across work portfolios within financial services. Although many of these technologies are still developing, organizations are eager to understand their potential. See how S&P Global and NVIDIA are working together to find practical ways to learn and integrate such capabilities, ranging from forecasting corporate debt issuance to understanding capital markets at a deeper level. We’ll show a graph representation of market data using the PyTorch-Geometric library and a dataset of issuances spanning three decades and across financial and non-financial industries. Technical developments include generation of a bipartite graph and link-prediction GNN forecasting. We’ll address data preprocessing, pipelines, model training, and how these technologies can broaden capabilities in an increasingly complex world.

Unlock Deep Learning Performance on Blackwell With cuDNN
Yang Xu (Enterprise Products), DL Software Engineering Manager, NVIDIA

Since its launch, cuDNN, a library for GPU-accelerating deep learning (DL) primitives, has been powering many AI applications in domains such as conversational AI, recommender systems, and speech recognition, among others. CuDNN remains a core library for DL primitives in popular frameworks such as PyTorch, JAX, Tensorflow, and many more while covering training, fine-tuning, and inference use cases. Even in the rapidly evolving space of Gen AI — be it Llama, Gemma, or mixture-of-experts variants requiring complex DL primitives such as flash attention variants — cuDNN is powering them all. Learn about new/updated APIs of cuDNN pertaining to Blackwell’s microscaling format, and how to program against those APIs. We’ll deep dive into leveraging its graph APIs to build some fusion patterns, such as matmul fusion patterns and fused flash attention from state-of-the-art models. Understand how new CUDA graph support in cuDNN, not to be mistaken with the cuDNN graph API, could be exploited to avoid rebuilding CUDA graphs, offering an alternative to CUDA graph capture with real-world framework usage.

Train and Serve AI Systems Fast With the Lightning AI Open-Source Stack (Presented by Lightning AI)
Luca Antiga, CTO, Lightning AI

See how the Lightning stack can cover the full life cycle, from data preparation to deployment, with practical examples and particular focus on distributed training and high-performance inference. We’ll show examples that focus on new features like support for multi-dimensional parallelism through DTensors, as well as quantization through torchao.

Connect With Experts (Interactive Sessions)

Meet the Experts From Deep Learning Framework Teams
Eddie Yan, Technical Lead of PyTorch, NVIDIA
Masaki Kozuki, Senior Software Engineer in PyTorch, NVIDIA
Patrick Wang (Enterprise Products), Software Engineer in PyTorch, NVIDIA
Mike Ruberry, Distinguished Engineer in Deep Learning Frameworks, NVIDIA
Rishi Puri, Sr. Deep Learning Engineer and Lead for PyTorch Geometric, NVIDIA

Training Labs

Kernel Optimization for AI and Beyond: Unlocking the Power of Nsight Compute
Felix Schmitt, Sr. System Software Engineer, NVIDIA
Peter Labus, Senior System Software Engineer, NVIDIA

Learn how to unlock the full potential of NVIDIA GPUs with the powerful profiling and analysis capabilities of Nsight Compute. AI workloads are rapidly increasing the demand for GPU computing, and ensuring that they efficiently utilize all available GPU resources is essential. Nsight Compute is the most powerful tool for understanding kernel execution behavior and performance. Learn how to configure and launch profiles customized for your needs, including advice on profiling accelerated Python applications, AI frameworks like PyTorch, and optimizing Tensor Core utilization essential to modern AI performance. Learn how to debug your kernel and use the expert system built into Nsight Compute, known as “Guided Analysis,” that automatically detects common issues and directs you to the most relevant performance data all the way down to the source code level.

Make Retrieval Better: Fine-Tuning an Embedding Model for Domain-Specific RAG
Gabriel Moreira, Sr. Research Scientist, NVIDIA
Ronay Ak, Sr. Data Scientist, NVIDIA

LLMs power AI applications like conversational chatbots and content generators, but are constrained by their training data. This might lead to hallucinations in content generation, which requires up-to-date or domain-specific information. Retrieval augmented generation (RAG) addresses this issue by enabling LLMs to access external context without modifying model parameters. Embedding or dense retrieval models are a key component of a RAG pipeline for retrieving relevant context to the LLM. However, an embedding model’s effectiveness to capture the unique characteristics of the custom data hinges on the quality and domain relevance of its training data. Fine-tuning embedding models is gaining interest to provide more accurate and relevant responses tailored to users’ specific domain.

In this lab, you’ll learn to generate a synthetic dataset with question-context pairs from a domain-specific corpus, and process the data for fine-tuning. Then, fine-tune a text embedding model using synthetic data and evaluate it.

Poster Presentations

Single-View X-Ray 3D Reconstruction Using Neural Back Projection and Frustum Resampling
Tran Minh Quan, Developer Technologist, NVIDIA

Enable Novel Applications in the New AI Area in Medicine: Accelerated Feature Computation for Pathology Slides
Nils Bruenggel, Principal Software Engineer, Roche Diagnostics Int. AG

Introducing the New PyTorch Landscape: Your Guide to the PyTorch Ecosystem

We’re excited to reveal our brand new PyTorch Landscape. The PyTorch Landscape helps researchers, developers, and organizations easily locate useful, curated, community-built tools that augment the PyTorch core framework.

What the Landscape Offers

The Landscape visually organizes projects into three categories—Modeling, Training, and Optimizations—making finding relevant frameworks, libraries, and projects easy. Users can quickly locate curated, valuable tools for a variety of use cases that complement the PyTorch framework. Each tool that is part of the Landscape has been reviewed and vetted by PyTorch project experts. The projects in the Landscape are considered to be mature and healthy and provide valuable capabilities that complement the PyTorch framework in their respective domains.

Explore the AI Landscape

The Explore page presents platforms, tools, and libraries, each with a logo, description, and links to GitHub and further details. This categorized, visual approach simplifies discovery and provides quick access to essential technologies.

Guide Page: A Closer Look

For deeper insights, the Guide page expands on each project, highlighting methodologies and trends shaping AI development, from adversarial robustness to self-supervised learning. It also provides statistics for each project, such as number of stars, contributors, commit history, languages used, and license, along with other metrics that give an in-depth understanding of the project and how it may be used.

Tracking AI’s Growth: The Stats Page

The Stats page provides insights into AI development trends, tracking repository activity, programming languages, and industry funding data.

  • Repositories: 117 repositories, 20.5k contributors, and 797.2k stars across 815MB of source code.
  • Development Trends: Weekly commit activity over the last year.
  • Licensing Breakdown: Repositories are categorized by license type.
  • Funding & Acquisitions: Insights into investment trends, including funding rounds and acquisitions.

Why Use the PyTorch Landscape?

Finding useful and high quality open source projects that complement the PyTorch core system can be overwhelming. The PyTorch Landscape offers a clear, accessible way to explore the ecosystem of community-built tools, whether you’re researching, building models, or making strategic decisions.

Stay ahead with the PyTorch Landscape — your guide to the PyTorch Ecosystem.

Want to Contribute a Project to the PyTorch Landscape?

Have you built a useful open source tool that you would like to share with the PyTorch community? Then help us grow the Ecosystem by contributing your tool! You can find the instructions to apply here. We welcome all contributions from the community!

Scaling Recommendation Systems Training to Thousands of GPUs with 2D Sparse Parallelism

At Meta, recommendation systems are the cornerstone of delivering relevant and personalized ads to billions of users globally. Through technologies like PyTorch’s TorchRec, we’ve successfully developed solutions that enable model training across hundreds of GPUs. While these systems have served us well, recent research on scaling laws has revealed a compelling opportunity: we can achieve significantly better model performance by training dramatically larger neural networks.

However, this insight presents us with a new challenge. Our current training infrastructure, though highly optimized for hundreds of GPUs, cannot efficiently scale to the thousands of GPUs needed to train these larger models. The leap from hundreds to thousands of GPUs introduces complex technical challenges, particularly around handling sparse operations in recommendation models. These challenges require fundamentally new approaches to distributed training, which we address with a novel parallelization strategy.

To address these issues, we introduced 2D embedding parallel, a novel parallelism strategy that overcomes the sparse scaling challenges inherent in training large recommendation models across thousands of GPUs. This is available today in TorchRec through the DMPCollection API. This approach combines two complementary parallelization techniques: data parallelism for the sparse components of the model, and model parallelism for the embedding tables, leveraging TorchRec’s robust sharding capabilities. By strategically integrating these techniques, we’ve created a solution that scales to thousands of GPUs and now powers Meta’s largest recommendation model training runs.

What are the sparse scaling challenges?

We identified three key challenges that prevented us from naively scaling our model to thousands of GPUs:

  • Imbalance and stragglers: With more GPUs, balanced sharding is harder to achieve. Some ranks can carry a much heavier workload for embedding computations, which can slow down the entire training job.
  • Communication across nodes: As training jobs use more GPUs, the all-to-all communication bandwidth can drop under certain network topologies, which can increase communication latency significantly.
  • Memory overhead: The memory used by input features is often negligible. However, as we scale to thousands of GPUs, we can introduce larger input features, and their memory requirements can become significant.

Figure 1 illustrates the 2D embedding parallel scheme. In this example, we have 2 model replicas (Replica 1: GPU1/GPU3, Replica 2: GPU2/GPU4).

Figure 1: Layout illustration of 2D Sparse Parallelism

2D sparse parallelism addresses these challenges. Instead of sharding tables across all ranks, we first evenly divide all ranks into several parallel groups:

  1. Within each group, we use model parallelism for the embedding tables, such as column-wise or row-wise sharding. At scale, for our largest tables, we have also developed grid sharding, which shards embedding tables along both the row and column dimensions.
  2. Across groups, we use data parallelism, such that each rank in a group has a corresponding replica rank in each of the other groups (a replica rank stores the same embedding table shards).
    1. After each group has completed its own backward pass, we all-reduce the embedding table weights across the replicas to keep them synchronized.
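
The synchronization step can be pictured with a small pure-Python simulation. This is a sketch under stated assumptions rather than TorchRec code: the reduction is assumed to be an average (the text above only says the weights are all-reduced), and `all_reduce_mean` is a hypothetical helper that stands in for a real collective call.

```python
def all_reduce_mean(weights_by_rank, replica_groups):
    """Average each replica group's weights and give every member the result.

    Averaging is an assumption made for this sketch; the reduction used in
    production is not specified here.
    """
    synced = {}
    for group in replica_groups:
        width = len(weights_by_rank[group[0]])
        mean = [
            sum(weights_by_rank[rank][i] for rank in group) / len(group)
            for i in range(width)
        ]
        for rank in group:
            synced[rank] = list(mean)
    return synced


# Ranks 0 and 1 hold copies of one shard; ranks 2 and 3 hold copies of another.
weights = {0: [1.0, 2.0], 1: [3.0, 4.0], 2: [10.0], 3: [20.0]}
synced = all_reduce_mean(weights, replica_groups=[[0, 1], [2, 3]])
print(synced[0], synced[2])
```

Each reduction involves only the ranks that store the same shard, so no collective in this step spans the full set of GPUs.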

Our production solution

TorchRec is our library for building the sparse part of recommendation models in native PyTorch. Its traditional API is DistributedModelParallel, which applies model parallelism to the embedding tables. Alongside it, we introduce a new API, DMPCollection, which serves as the main entry point for enabling 2D parallelism on TorchRec models. We designed it to be as easy a change as applying FSDP or DDP.

To understand what DMPCollection does, we have to understand what DistributedModelParallel (DMP) does first:

  1. Create embedding tables, known as EmbeddingBagCollection and EmbeddingCollection.
  2. Generate a sharding plan with respect to GPU topology, embedding tables, memory available, input data, and more.
  3. Wrap the model with DMP, passing in the associated sharding plan.
  4. DMP initializes and shards the embedding tables in accordance with the sharding plan.
  5. On a train step, DMP takes an input batch, communicates it to the GPUs that hold the embedding table shards of interest, looks up the values, and returns them to the GPU that requested them. This is all done on the global process group, with some exceptions for special sharding (such as table row-wise sharding).

DistributedModelParallel was built for model parallelism, and many of its parts assume that sharding and communication span the global world size. We need to change these parts so that we can introduce additional dimensions of parallelism without losing the optimizations and feature set of TorchRec.

DMPCollection changes a few key parts to enable 2D parallel in an extensible way:

  • Generate the sharding plan once, for the smaller sharding group. Once the plan is passed in, we communicate it to the appropriate ranks across the global group and remap the ranks to fit the new sharding group ranks.
  • Create two new NCCL process groups, known as the sharding and replica process groups. The sharding process group is passed into the sharding and train step components of TorchRec. The replica process group is used for weight and optimizer state synchronization; the all-reduce call happens over this process group.
    • The sub NCCL process groups allow us to efficiently communicate only between the ranks that are relevant for a particular comm. Each rank will have two associated process groups.

To the user, the change is very simple, and DMPCollection takes away all the complexity of applying the parallelism strategies to the model.

How do we create these sharding and replication groups?

These process groups are one of the keys to DMPCollection’s performant implementation. Our earlier diagram showed a simple 2×2 GPU setup. At scale, however, how do we decide which ranks belong to a given sharding group, and which ranks are their replicas across the sharding groups?

Consider the following setup with 2 nodes, each with 4 GPUs. The sharding and replication groups under 2D parallel will be:

Sharding Group | Sharding Ranks
0 | 0, 2, 4, 6
1 | 1, 3, 5, 7

Replication Group | Replication Ranks
0 | 0, 1
1 | 2, 3
2 | 4, 5
3 | 6, 7

We use the following formulation:

  1. Divide all trainers into G sharding groups, each with L trainers.
    1. The number of groups is G = T / L, where T is the total number of trainers.
  2. Assign each sharding group non-contiguous trainer ranks based on its group index i:
    1. [i, G+i, 2G+i, …, (L – 1)G+i], where i = 0 to G – 1.
  3. From the sharding groups, create the replication groups, each of which is a set of G contiguous ranks:
    1. (0 to G – 1, G to 2G – 1, …); each contiguous set stores duplicate embedding table shards.

This means we have G sharding groups, each of size L, where L is the number of ranks we apply model parallel across. This, in turn, gives us L replica groups, each of size G, which are the ranks we data parallel across.
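As a quick check of the formulation above, here is a minimal pure-Python sketch that derives both kinds of groups from the total number of trainers and the sharding group size. The function name `build_2d_groups` and its parameters are illustrative assumptions, not part of TorchRec.

```python
def build_2d_groups(total_trainers: int, sharding_group_size: int):
    """Return (sharding_groups, replica_groups) as lists of rank lists."""
    T, L = total_trainers, sharding_group_size
    if T % L != 0:
        raise ValueError("total_trainers must be divisible by sharding_group_size")
    G = T // L  # number of sharding groups

    # Sharding group i holds the strided ranks [i, G+i, 2G+i, ..., (L-1)G+i].
    sharding_groups = [[k * G + i for k in range(L)] for i in range(G)]
    # Replica group k holds the G contiguous ranks [kG, ..., (k+1)G - 1].
    replica_groups = [list(range(k * G, (k + 1) * G)) for k in range(L)]
    return sharding_groups, replica_groups


# 2 nodes x 4 GPUs, sharding group size 4 (hypothetical helper, same setup as the table).
sharding, replicas = build_2d_groups(total_trainers=8, sharding_group_size=4)
print(sharding)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(replicas)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

The output reproduces the 2-node, 4-GPU table: each replica group is a run of contiguous ranks, which is what keeps replicas of a shard on the same node in this setup.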

In DMPCollection, we create these process groups efficiently using DeviceMesh. We lay out the entire GPU topology as a 2D matrix, with each row representing a group of sharding ranks and each column representing the corresponding replica ranks:

# create peer matrix
peer_matrix = []
num_groups = global_world_size // sharding_group_size
for group_rank in range(num_groups):
    peers = [num_groups * rank + group_rank for rank in range(sharding_group_size)]
    peer_matrix.append(peers)

# initialize DeviceMesh with two dimensions (shard, replicate)
# slice DeviceMesh on shard for sharding process group
# slice DeviceMesh on replicate for replica process group

With our DeviceMesh approach, should we want to change the topology or provide further flexibility in the future, we can easily extend our creation logic to any form of topologies and even extend for further dimensions of parallelism if needed.

Performance of 2D parallel

Our rank partitioning strategy optimizes communication patterns by placing the model replica ranks for each shard within the same compute node. This architecture provides significant performance benefits for the weight synchronization operation. After the backward pass, we perform all-reduce operations to synchronize model weights, which is expensive given the large parameter counts we have to communicate and sync. By placing replicas on the same node, we leverage the high intra-node bandwidth instead of relying on slower inter-node bandwidth.

This design choice also generally improves the latencies of the other communication collectives. The improvement stems from two factors.

  1. By sharding the embedding tables over a reduced number of ranks and conducting communications for the model within the smaller group, we achieve a lower all-to-all latency.
  2. With the replication in 2D parallel, the embedding lookup latency on a rank is reduced, because we can reduce the local batch size to 1/Nth of the equivalent global batch size, where N is the number of model replicas.

A production model trace exemplifies these two factors. Here we run the 2D parallel job on 1024 GPUs, with a sharding group size of 256 GPUs.

Figure 2: Comparing latencies between non 2D parallel and 2D parallel workloads

There are two key levers users have to tune to maximize performance for their workloads:

  1. The size of the model sharding group relative to the global world size. The global world size divided by the sharding group size represents the number of model replicas we will have.
    1. To maximize performance, users can scale up their model by up to 8x; this scaling factor keeps the all-reduce within a host.
      1. For further scaling, the all-reduce would have to happen across hosts. In our experiments, we did not see an obvious performance regression, and in fact we note advantages of an inter-host all-reduce. We can change our sharding and replica topology to use an inter-host all-reduce, which can help us introduce fault tolerance strategies should a particular host go down.
  2. Frequency of all-reduce synchronization. DMPCollection comes with a sync() call, which can be tuned to be called every N training steps, performing a sort of local SGD training. At scale, reducing the frequency of synchronization can bring significant performance gains.
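To make the second lever concrete, the following toy simulation shows what syncing every N steps means. It is a sketch under stated assumptions, not DMPCollection code: each replica holds a single scalar weight, and the sync is modeled as a mean over replicas, which stands in for the all-reduce. The name `train_with_periodic_sync` and its parameters are hypothetical.

```python
def train_with_periodic_sync(grads_per_replica, lr=0.1, sync_every=2):
    """Simulate replicas that update locally and average weights every `sync_every` steps."""
    num_replicas = len(grads_per_replica)
    num_steps = len(grads_per_replica[0])
    weights = [0.0] * num_replicas  # one scalar "embedding weight" per replica
    sync_count = 0
    for step in range(1, num_steps + 1):
        # Local update: each replica applies its own gradient.
        for r in range(num_replicas):
            weights[r] -= lr * grads_per_replica[r][step - 1]
        # Periodic sync: stand-in for an all-reduce (mean) over the replica group.
        if step % sync_every == 0:
            mean = sum(weights) / num_replicas
            weights = [mean] * num_replicas
            sync_count += 1
    return weights, sync_count


weights, syncs = train_with_periodic_sync(
    grads_per_replica=[[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]],
    sync_every=2,
)
print(syncs)    # 2 syncs over 4 steps
print(weights)  # both replicas hold the same value after the final sync
```

A larger `sync_every` means fewer collective calls, at the cost of letting the replicas drift apart between syncs.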

Future Work

Readers should note that 2D sparse parallel training differs from non-parallelized training because we synchronize the embedding table weights rather than the gradients. This approach is made possible by TorchRec’s use of FBGEMM, which provides optimized kernels under the hood. One of FBGEMM’s key optimizations is the fusion of the optimizer in the backward pass. Instead of fully materializing the embedding table gradients—which would consume significant memory—they are passed directly to the optimizer update. Attempting to materialize and synchronize these gradients would create substantial overhead, making that approach impractical.

Our exploration revealed that to achieve training results comparable to the baseline, we need to synchronize optimizer states on a delayed schedule, with the timing dependent on the number of sharding/replica groups (i.e., for Adagrad we update the momentum one sync step behind). This approach also enables users to implement local SGD or semi-synchronized training strategies, which can achieve convergence and potentially produce better loss curves than the baseline.

We thank you for reading our post! This is an exciting direction we have come across that we hope to develop further to maximize performance of recommendation systems and push the state of the art.

Powering AI with PyTorch, Fedora, and Open Source Communities

At DevConf.IN 2025 in Pune, I had the opportunity to host a PyTorch Meetup on February 28th. The session, titled “Powering AI with PyTorch, Fedora, and Open Source Communities” was aimed at introducing PyTorch to students and professionals, explaining why PyTorch+Fedora form an ideal AI development platform. The other key aspect I covered was collaboration between open source communities.

Introduction to PyTorch

The Power of Deep Learning made simple

With the explosion of GPTs, there is renewed interest in the field of AI and ML. The myth that developing AI/ML technologies and their applications is rocket science, and out of reach, needs correction. Only open source has the power to dispel this myth and further evolve the technology to make it versatile and developer friendly. Since its inception, PyTorch has evolved and has been a driving force in making AI/ML development extremely simple. I covered PyTorch’s key components, its features, and why PyTorch is the best choice as a deep learning framework.

The code walkthrough was designed to showcase how easy and simple it is to utilise the power of GPUs, create a simple neural network, and train the model. It was very well received, and it was great to hear from the attendees that they never knew how powerful PyTorch is for deep learning. The real-world examples cited showed how this powerful framework can be used beyond the common GPTs and can influence a broad spectrum of applications.

Fedora+PyTorch the Ideal AI/ML Development Platform

One of the highlights of the event was the discussion on Fedora’s role as an AI platform. Fedora’s reliability, flexibility, and strong community support make it an ideal partner for PyTorch, allowing developers to focus on model-building without worrying about infrastructure. The students were intrigued by the idea of contributing to Fedora’s AI/ML ecosystem while building their own projects. Sumantro Mukherjee spoke about the AI policy in Fedora and how one can start contributing to AI/ML using Fedora as a platform. He highlighted how Fedora is evolving to meet the needs of AI practitioners. The idea that an open-source operating system could provide the perfect foundation for AI research sparked an engaging conversation.

Innovation in Open Source When Communities Come Together

It is important that we learn from history and repeat the good things! When open source communities come together they can create seismic shifts in the industry. To drive this home, I took the audience on a journey through history, revisiting a pivotal moment when Apache and Linux came together, solving common problems and fundamentally reshaping enterprise computing. That moment was not just about technology; it was about collaboration. It was about two powerful communities recognizing that they were stronger together. Today, we stand at the cusp of another such moment – PyTorch and Linux, particularly Fedora, are coming together to shape the future of AI/ML. This is not just an opportunity but a responsibility for contributors, developers, and AI/ML enthusiasts to be part of this movement.

Looking Ahead

One of the best parts of the event was the enthusiasm it generated. The audience was diverse, including students, AI enthusiasts, and industry professionals. Notably, Vincent Caldeira (CTO, APAC, Red Hat) and Chris Butler (Senior Principal Chief Architect, Red Hat) were present, reinforcing the growing interest in open-source AI/ML. Many students were eager to explore PyTorch and Fedora, contribute to open-source AI projects, and start their own AI experiments. Industry experts saw the potential for scalable, community-driven AI innovation. The session sparked curiosity and conversations that continued long after the event ended.

Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel

LinkedIn: Shivam Sahni, Byron Hsu, Yanning Chen
Meta: Ankith Gunapal, Evan Smothers

This blog explores the integration of custom Triton kernels from Liger Kernel with torch.compile to enhance the performance of fine-tuning large language models (LLMs) using torchtune. torchtune, a PyTorch-native library, offers modular building blocks and customizable finetuning recipes which include torch.compile support for various LLMs, while Liger Kernel provides optimized Triton kernels to improve training efficiency and reduce memory usage. The integration involves modifying the TransformerDecoder module in torchtune to bypass the linear layer computation, allowing the Liger Fused Linear Cross Entropy Loss to handle the forward projection weights. Experiments conducted on an NVIDIA A100 instance demonstrate that torch.compile outperforms PyTorch Eager in throughput and memory efficiency, with Liger Kernel further reducing peak memory allocation and enabling larger batch sizes. The results show a 47% reduction in peak memory at batch size 256 and a marginal increase in throughput with meta-llama/Llama-3.2-1B, confirming the effectiveness of the integration without affecting the loss curves.

Introduction to torchtune

torchtune is a PyTorch-native library which has been designed for finetuning LLMs. torchtune provides composable and modular building blocks along with finetuning recipes that can be easily customized for your use case, as will be shown in this blog.
torchtune provides:

  • PyTorch implementations of popular LLM model architectures from Llama, Gemma, Mistral, Phi, and Qwen model families
  • Hackable training recipes for full finetuning, LoRA, QLoRA, DPO, PPO, QAT, knowledge distillation, and more
  • Out-of-the-box memory efficiency, performance improvements, and scaling with the latest PyTorch APIs, including torch.compile
  • YAML configs for easily configuring training, evaluation, quantization or inference recipes
  • Built-in support for many popular dataset formats and prompt templates

Introduction to Liger Kernel

Liger Kernel is an open source library of optimized Triton kernels designed to enhance the efficiency and scalability of training Large Language Models (LLMs). It focuses on kernel-level optimizations such as operation fusing and input chunking, achieving significant improvements in training throughput and GPU memory usage compared to existing implementations like those from HuggingFace. By using a single line of code, Liger Kernel can improve training throughput by 20% and reduce memory usage by 60%.

Fused Linear Cross Entropy

The bulk of Liger Kernel’s performance improvement comes from the Fused Linear Cross Entropy (FLCE) Loss, whose core idea is as follows:

In LLMs, the vocabulary size has increased significantly, leading to a large logit tensor during cross-entropy (CE) loss computation. This logit tensor consumes excessive memory, causing a bottleneck in training. For example, when training with a batch size of 8 and sequence length of 4096, the 256k vocabulary size results in a 16.8 GB logit tensor. The FLCE kernel breaks down the computation into smaller chunks, reducing memory consumption.
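The 16.8 GB figure can be reproduced with simple arithmetic. The sketch below assumes that "256k" means 256,000 vocabulary entries and that each logit takes 2 bytes (bf16 or fp16); neither assumption is stated explicitly in the text.

```python
batch_size = 8
seq_len = 4096
vocab_size = 256_000     # assumption: "256k" read as 256,000
bytes_per_element = 2    # assumption: bf16/fp16 logits

# One logit per (batch, position, vocabulary entry).
logit_elements = batch_size * seq_len * vocab_size
logit_gb = logit_elements * bytes_per_element / 1e9
print(f"{logit_gb:.1f} GB")  # 16.8 GB
```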

Here’s how it works:

  1. Flattens the 3D hidden states into a 2D matrix by collapsing the batch size and sequence length dimensions.
  2. Applies the linear projection head sequentially on the chunked hidden states.
  3. Computes the partial loss and returns the chunked logits gradient using the Liger CE kernel.
  4. Derives the chunked hidden states gradients and accumulates the projection head gradients.
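The chunking idea in the steps above can be illustrated with a small forward-only sketch in plain Python. This is not the Liger Triton kernel: it omits the gradient computation of steps 3 and 4, and all function names here are hypothetical. It only shows that computing the loss chunk by chunk gives the same result as computing it over all tokens at once, while only `chunk_size` rows of logits exist at any time.

```python
import math


def linear(h, weight):
    """Project one hidden vector h (size H) to logits (size V); weight is V x H."""
    return [sum(w * x for w, x in zip(row, h)) for row in weight]


def cross_entropy(logits, target):
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def chunked_linear_ce(hidden, weight, targets, chunk_size):
    """Mean CE loss over flattened (B*T, H) hidden states, computed chunk by chunk."""
    total = 0.0
    for start in range(0, len(hidden), chunk_size):
        chunk = hidden[start:start + chunk_size]
        chunk_targets = targets[start:start + chunk_size]
        # Only chunk_size x V logits are materialized here, never (B*T) x V.
        chunk_logits = [linear(h, weight) for h in chunk]
        total += sum(cross_entropy(z, t) for z, t in zip(chunk_logits, chunk_targets))
    return total / len(hidden)


hidden = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.0, 0.5]]  # 4 tokens, H = 2
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]                # V = 3
targets = [0, 2, 1, 2]

full = chunked_linear_ce(hidden, weight, targets, chunk_size=len(hidden))
chunked = chunked_linear_ce(hidden, weight, targets, chunk_size=2)
print(abs(full - chunked) < 1e-9)  # True
```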

Torchtune’s recipes provide torch.compile support out of the box. It has been shown that utilizing torch.compile with FLCE makes FLCE 2x faster.

Integrating Liger Kernel with torch.compile & torchtune

We demonstrate integration of Liger Kernel with torch.compile & torchtune by running a full fine-tuning recipe with meta-llama/Llama-3.2-1B. To make this integration happen, we have defined a custom full finetuning recipe; the details of the changes are described below.

CUDA_VISIBLE_DEVICES=0,1,2,3 tune run --nproc_per_node 4 recipes/full_finetune_distributed.py --config llama3_2/1B_full optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False gradient_accumulation_steps=1  dataset.packed=True compile=True enable_activation_checkpointing=True tokenizer.max_seq_len=512  batch_size=128

One of the inputs to the LCE Kernel is the forward projection weights. torchtune is designed as a modular library with composable blocks. There is a TransformerDecoder block where at the end of the block, we pass the final hidden state through a linear layer to get the final output. Since the linear layer is combined with the CE loss in LCE Kernel, we write a custom forward function for TransformerDecoder where we skip the computation through the linear layer.

In the full finetuning recipe, we override the model’s forward method with this custom method:

import types
from liger_kernel.torchtune.modules.transformers import decoder_forward
self._model.forward = types.MethodType(decoder_forward, self._model)
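For readers unfamiliar with `types.MethodType`, the toy example below shows the same override pattern on a stand-in class. `TinyDecoder` and `forward_without_projection` are hypothetical names used only for illustration; they are not torchtune or Liger Kernel code.

```python
import types


class TinyDecoder:
    """Hypothetical stand-in for a decoder whose forward ends with an output projection."""

    def __init__(self, scale):
        self.scale = scale

    def hidden_states(self, x):
        return [v + 1 for v in x]

    def forward(self, x):
        # Default path: hidden states followed by the "output projection".
        return [self.scale * h for h in self.hidden_states(x)]


def forward_without_projection(self, x):
    # Replacement forward: return the hidden states and skip the projection.
    return self.hidden_states(x)


model = TinyDecoder(scale=10)
other = TinyDecoder(scale=10)

# Bind the replacement to this one instance; the class and other instances are untouched.
model.forward = types.MethodType(forward_without_projection, model)

print(model.forward([1, 2]))  # [2, 3]
print(other.forward([1, 2]))  # [20, 30]
```

Binding on the instance rather than editing the class keeps the change local to the model used in the custom recipe.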

We then pass the model’s forward projection weights to calculate the loss with LCE Kernel:

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

# Use LCE loss instead of CE loss
self._loss_fn = LigerFusedLinearCrossEntropyLoss()

# call torch.compile on the loss function
if self._compile:
    training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

# pass the model's forward projection weights for loss computation
current_loss = (
    self._loss_fn(
        self._model.output.tied_module.weight,
        logits,
        labels,
    )
    * current_num_tokens
)

The complete code and instructions can be found in the GitHub repo.

Experiments & Benchmarking Results

We conduct 3 types of experiments to demonstrate how Liger Kernel integration with torch.compile enhances the performance of torchtune. We set up our experiments on an instance running NVIDIA A100. We fine-tune a small LLM meta-llama/Llama-3.2-1B with differing batch sizes. We record the throughput in terms of tokens/second and measure the peak memory allocated during finetuning. Since it’s a small model, we only use 4 A100 GPUs for the benchmarking. The following are the experiments we conducted:

  1. Increase batch_size in powers of 2 with PyTorch eager
  2. Increase batch_size in powers of 2 with torch.compile
  3. Increase batch_size in powers of 2 with torch.compile & Liger integration

We notice that with PyTorch Eager, throughput increases with increasing batch_size until we hit OOM at batch_size 256. With torch.compile, the throughput is higher than PyTorch Eager for each batch_size. Peak memory allocation with torch.compile also falls further below PyTorch Eager as batch_size increases, with a reduction of more than 50% at batch_size 128. As a result, torch.compile can support batch_size 256, and its overall throughput is 36% greater than PyTorch Eager. Integrating Liger Kernel with torch.compile doesn’t reduce throughput at lower batch_size, and as batch_size increases, torchtune consumes less memory than with torch.compile alone. At batch_size 256, we see a 47% reduction in peak memory allocation with the Liger kernel. This allows us to use batch_size 512 with torch.compile & Liger. We also notice a marginal 1-2% increase in throughput compared to torch.compile without custom Triton kernels.

Figure 2: Plot of tokens/sec per rank vs batch_size

Figure 3: Peak memory allocated vs batch_size

To rule out any potential functional issues with our integration of Liger Kernel with torchtune, we plot the loss curve against training steps with & without Liger. We see that there is no visible difference in the loss curves.

Figure 4: Plot of loss vs training steps for batch_size=128

Next Steps

Acknowledgments

We thank Hamid Shojanazeri (Meta), Less Wright (Meta), Horace He (Meta) & Gregory Chanan (Meta) for their feedback and support in making this blog post happen.

Current and New Activation Checkpointing Techniques in PyTorch

As models scale in depth, batch size, sequence length, and so on, activation memory becomes an increasingly significant contributor to the overall memory usage. To help address this, PyTorch provides utilities for activation checkpointing, which reduce the number of saved tensors by recomputing them when needed, trading off memory usage for additional compute.

In this post, we’ll walk through the basics of what activation memory is, the high-level ideas behind existing activation checkpointing techniques, and also introduce some newer techniques that aim to improve flexibility and provide more optimization/automation out of the box.

As we look at these techniques, we’ll compare how these methods fit into a speed vs. memory trade-off diagram and hopefully provide some insight on how to choose the right strategy for your use case.

(If you prefer to jump straight to the new APIs, please skip ahead to the “Selective Activation Checkpoint” and “Memory Budget API” sections below.)

flow diagram


Activation Memory Basics

By default, in eager mode (rather than using torch.compile), PyTorch’s autograd preserves intermediate activations for backward computation. For example, if you call sin on a tensor x during the forward pass, autograd must remember x to compute cos(x) during backward.
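The sin example can be observed directly. The sketch below uses `torch.autograd.graph.saved_tensors_hooks` to record what autograd saves during the forward pass, then checks that the gradient of `sin(x).sum()` is `cos(x)`. The hook names `pack` and `unpack` are arbitrary.

```python
import torch

saved = []


def pack(t):
    saved.append(t)  # record every tensor autograd saves for backward
    return t


def unpack(t):
    return t


x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.sin(x)

y.sum().backward()

print(len(saved))                                     # tensors saved while computing sin
print(torch.allclose(x.grad, torch.cos(x.detach())))  # True
```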

flow diagram

If this tensor x is saved at the beginning of the forward pass, it remains in memory throughout both the forward and backward phases. It can only be cleared after it is used to compute the gradient, which happens at the end of the backward pass (due to the reverse order of execution).

Thus, as you proceed through the forward pass and perform more and more operations, you accumulate more and more activations, resulting in more and more activation memory until it (typically) reaches its peak at the start of backward (at which point activations can start to get cleared).

flow diagram

In the diagram above, the orange boxes represent operations, and the black arrows represent their tensor inputs and outputs. The black arrows that cross over to the right represent tensors that autograd saves for backward.

A useful way to visually organize this default saving behavior in eager as well as the techniques we’re about to introduce is based on how they trade off speed versus memory.

flow diagram

The ideal place to be on this diagram is the top-left, where you have “high” speed but also low memory usage.

We begin by putting the default saving behavior on the top-right (for reasons we’ll explain in more detail as we introduce more points for other techniques).


Activation Checkpointing (AC)

Activation checkpointing (AC) is a popular technique to reduce memory usage in PyTorch.

During forward, any operations performed inside the AC’d region do not save tensors for backward. (Only the inputs to the function are saved.) During backward, the intermediate activations needed for gradient computation are rematerialized by running the function a second time.
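A minimal sketch of this behavior, using the existing `torch.utils.checkpoint.checkpoint` API with `use_reentrant=False`: the checkpointed region is recomputed during backward, and the resulting gradient matches the default eager path. The function `region` is an arbitrary example, not taken from the post.

```python
import torch
from torch.utils.checkpoint import checkpoint


def region(x):
    # Without AC, the intermediates of sin and exp are saved for backward.
    return torch.exp(torch.sin(x)).sum()


x_eager = torch.randn(8, requires_grad=True)
x_ac = x_eager.detach().clone().requires_grad_(True)

# Default eager: intermediates are saved during forward.
region(x_eager).backward()

# AC: only the input is kept; the region runs a second time during backward.
checkpoint(region, x_ac, use_reentrant=False).backward()

print(torch.allclose(x_eager.grad, x_ac.grad))  # True
```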

flow diagram

In the diagram (right), the black box shows where activation checkpointing is applied. Compared to the default eager approach (left), this setup results in fewer tensors being saved (1 versus 3).

Applying AC on the right parts of the model has the effect of reducing peak memory, because the intermediate activations are no longer materialized in memory when the memory usage typically peaks (at the beginning of backward).

On the speed-versus-memory tradeoff diagram, AC is plotted on the bottom-left. Relative to eager mode, it reduces the amount of memory saved for backward but comes with an added cost in compute due to recomputation.

flow diagram

Note that AC’s speed–memory tradeoff can be adjusted by selecting which parts of the forward pass to checkpoint and by defining how many checkpoint regions to use. However, implementing these changes may require modifying your model’s structure and can be cumbersome depending on how your code is organized. For the purposes of this diagram, we assume only one region is checkpointed; under this assumption, AC appears as a single point on the tradeoff diagram.

Also note that “memory” here does not refer to peak memory usage; rather, it indicates how much memory is saved for backward for a fixed region.


torch.compile and min-cut partitioner

Another notable approach to keep in mind is torch.compile (introduced in PyTorch 2.0). Like activation checkpointing, torch.compile can also perform some level of recomputation under the hood. Specifically, it traces the forward and backward computations into a single joint graph, which is then processed by a “min-cut” partitioner. This partitioner uses a min-cut/max-flow algorithm to split the graph such that it minimizes the number of tensors that need to be saved for backward.

At first glance, this might sound a lot like what we want for activation memory reduction. However, the reality is more nuanced. By default, the partitioner’s primary goal is to reduce runtime. As a result, it only recomputes certain types of operations—primarily simpler, fusible, and non-compute-intensive ops (like pointwise ops).

Placing “compile” on the speed-versus-memory tradeoff diagram…

flow diagram

It is to the top-left of the eager non-AC point, as we expect torch.compile to improve on both speed and memory.

On the other hand, relative to activation checkpointing, torch.compile is more conservative about what it recomputes, placing it closer to the top-left on the speed-versus-memory diagram.


Selective Activation Checkpoint [NEW!]

While normal checkpointing recomputes every op in a chosen region, selective activation checkpointing (SAC) is an additional setting on top of activation checkpointing that gives you more granular control over which operations to recompute.

This can be useful if you have certain more expensive operations like matmuls which you prefer to avoid recomputing, but still generally want to recompute cheaper operations like pointwise.

flow diagram

Where plain AC (left) would save a single tensor and then recompute the entire AC’d region, with SAC (right) you can selectively save specific operations (marked red) in the region, so you can avoid recomputing them.

To specify what to selectively save, you provide a policy_fn. To illustrate the additional trade-offs you can make with this, we present two simple policy functions.

Policy 1: Not recomputing matmuls:

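import torch
from torch.utils.checkpoint import CheckpointPolicy

aten = torch.ops.aten
compute_intensive_ops = [
    aten.mm,
    aten.bmm,
    aten.addmm,
]

def policy_fn(ctx, op, *args, **kwargs):
    # op is typically an OpOverload (e.g. aten.mm.default), so compare its
    # overload packet (aten.mm) against the list above.
    if getattr(op, "overloadpacket", op) in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE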

flow diagram

Policy 2: More aggressively save anything compute intensive

import torch
from torch.utils.checkpoint import CheckpointPolicy

# List taken from torch/_functorch/partitioners.py
aten = torch.ops.aten
compute_intensive_ops = [
    aten.mm,
    aten.convolution,
    aten.convolution_backward,
    aten.bmm,
    aten.addmm,
    aten._scaled_dot_product_flash_attention,
    aten._scaled_dot_product_efficient_attention,
    aten._flash_attention_forward,
    aten._efficient_attention_forward,
    aten.upsample_bilinear2d,
    aten._scaled_mm,
]

def policy_fn(ctx, op, *args, **kwargs):
    # op is typically an OpOverload (e.g. aten.mm.default), so compare its
    # overload packet (aten.mm) against the list above.
    if getattr(op, "overloadpacket", op) in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

flow diagram

On the speed-versus-memory diagram, SAC is plotted as a range of points from closer to AC to closer to Eager, depending on your chosen policy.

flow diagram

Try it out! (Available in 2.5 as a prototype feature; see docs for more info + copy-pastable example)

from functools import partial

from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Create a policy function that returns a CheckpointPolicy
def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

# Use the context_fn= arg of the existing checkpoint API
out = checkpoint(
    fn, *args,
    use_reentrant=False,
    # Fill in SAC context_fn's policy_fn with functools.partial
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)


(compile-only) Memory Budget API [NEW!]

As mentioned previously, any given SAC policy can be represented as a point on a speed-memory tradeoff diagram. Not all policies are created equal, however. The “optimal” policies are the ones that fall on a Pareto curve; that is, among all policies that incur the same memory overhead, an optimal policy is the one that minimizes the amount of required compute.
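To illustrate what falling on a Pareto curve means, the sketch below filters a set of made-up policies down to the ones that no other policy beats on both memory and compute. The policy names and numbers are hypothetical and in arbitrary units; this is not how the partitioner chooses its policy.

```python
def pareto_front(policies):
    """Keep policies not dominated in (memory, compute); lower is better for both."""
    front = []
    for name, mem, comp in policies:
        dominated = any(
            (m <= mem and c <= comp) and (m < mem or c < comp)
            for _, m, c in policies
        )
        if not dominated:
            front.append(name)
    return front


# Hypothetical (memory saved for backward, recompute cost) pairs.
policies = [
    ("save_everything", 10.0, 0.0),
    ("save_matmuls", 5.0, 2.0),
    ("wasteful", 6.0, 3.0),  # worse than save_matmuls on both axes
    ("recompute_everything", 1.0, 8.0),
]

print(pareto_front(policies))
# ['save_everything', 'save_matmuls', 'recompute_everything']
```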

For users who are using torch.compile, we offer a memory budget API that automatically applies SAC over your compiled region with a pareto-optimal policy given a user-specified “memory budget” between 0 and 1, where a budget of 0 behaves like plain-AC and a budget of 1 behaves like default torch.compile.

flow diagram

Below are some real results on a transformer model:

flow diagram

We observe a 50% memory reduction by recomputing only pointwise ops, with a steady drop-off as you recompute more and more of your matmuls. Attention is the most expensive op, so you tend to want to recompute it last.

Try it out! (Available in 2.4 as an experimental feature; see this comment block for more info)

import torch
import torch._functorch.config

torch._functorch.config.activation_memory_budget = 0.5

out = torch.compile(fn)(inp)

Conclusion

flow diagram

In summary, activation checkpointing techniques in PyTorch offer a variety of ways to balance memory and compute demands, from simple region-based checkpointing to more selective and automated methods. By choosing the option that best matches your model’s structure and resource constraints, you can achieve significant memory savings with an acceptable trade-off in compute.

Acknowledgements

We would like to thank Meta’s xformers team including Francisco Massa for working on the original version of Selective Activation Checkpoint.

📣 Submit to Speak at PyTorch Conference + Save on Registration

Step into the Future of AI at PyTorch Conference 2025.

The Call for Proposals for PyTorch Conference 2025 is officially open!

Join us in San Francisco from October 22–23, 2025, to showcase your expertise and innovations with PyTorch—the industry-leading, open-source machine learning framework powering innovations from bare-metal infrastructure to sophisticated application and agent layers. This is your opportunity to share insights, breakthroughs, and case studies with a global audience of AI and Generative AI practitioners, researchers, and developers.

Submit your proposals and prepare to engage, learn, and network alongside some of the brightest minds in the AI/ML community. We’re seeking sessions, Birds of a Feather discussions, lightning talks, and poster sessions on the following topics:

  • Core PyTorch Framework
  • PyTorch on Accelerator Hardware
  • PyTorch Ecosystem and Tools
  • AI Applications and Use Cases
  • AI in Research and Academia
  • AI in Industry and Enterprise Applications
  • AI Infrastructure and Scalability
  • Ethical AI, Governance, and Regulation
  • Training, Fine-Tuning, and Alignment
  • Inference, Deployment, and Serving
  • Performance Measurement and Benchmarking
  • Data Engineering and Management for AI
  • Generative AI and Large Language Models (LLMs)
  • Model Optimization and Efficiency
  • Open Source Collaboration, Education and Community Building
  • Edge AI and On-Device
  • DL Compilers and Kernel Authoring

Learn more and submit your talk by Sunday, June 1, at 11:59 PDT!

Save up to USD$500 with Super Early Bird Pricing!

  • Reserve your pass by 11:59 PM PDT on March 21 and score Super Early Bird pricing for just USD$499. That’s a savings of up to USD$500!
  • Student or faculty? Learn more about our discounted academic rate.
  • Need help covering travel costs? We offer discretionary travel funding for those community members who would otherwise not be able to attend. Learn more.

Become a Sponsor at PyTorch Conference 2025!

Seize your opportunity to influence the future of Generative AI and Machine Learning by sponsoring PyTorch Conference 2025. PyTorch is at the forefront of innovation—empowering rapid experimentation, flexible model development, and efficient deployment into production environments with its powerful, versatile ecosystem of tools and thriving community of dedicated users.

As a sponsor, you’ll gain more than visibility; you’ll strategically position your organization at the heart of a vibrant, global AI/ML ecosystem. Connect directly with 3,000+ expert attendees, researchers, engineers, and decision-makers, and actively shape the conversations driving the next generation of AI advancements.

For more details on CFP submissions, registration, and sponsorship, visit the PyTorch Conference Website.
