Announcing the PyTorch Docathon 2025

We’re thrilled to announce the 2025 PyTorch Docathon! This is a hackathon-style event aimed at enhancing PyTorch documentation with the support of the community. Documentation is a vital component of any technology, and by refining it, we can simplify the onboarding process for new users, help them effectively utilize PyTorch’s features, and ultimately speed up the transition from research to production in machine learning.

WHY PARTICIPATE

Low Barrier to Entry

Unlike many open-source projects that require deep knowledge of the codebase and previous contributions to join hackathon events, the Docathon is tailored for newcomers. While we expect participants to be familiar with Python and have basic knowledge of PyTorch and machine learning, there are also tasks related to website issues that don’t require even that level of expertise.

Tangible Results

A major advantage of the Docathon is witnessing the immediate impact of your contributions. Enhancing documentation significantly boosts a project’s usability and accessibility, and you’ll be able to observe these improvements directly. Seeing tangible outcomes can also be a strong motivator to continue contributing.

Collaborative Environment

The Docathon fosters a collaborative atmosphere, offering you the chance to work alongside other contributors and PyTorch maintainers to improve the documentation. This is a fantastic opportunity to learn from peers, exchange ideas, and build connections.

Learning Opportunities

Even if you’re not a PyTorch expert, the Docathon offers a valuable learning experience. You’ll have the chance to delve into PyTorch modules, test tutorials on your machine, and explore them in the CI environment.

WHO SHOULD PARTICIPATE

Whether you’re a seasoned documentation expert or just starting out, we invite everyone to join the PyTorch Docathon to contribute, develop your skills and knowledge, and help improve the documentation for everyone! We will have issues labeled by skill level, and the PyTorch Discord will be available for collaboration and help.

EVENT DETAILS

  • June 3: Kick-off 10 AM PT
  • June 4 – June 15: Submissions and Feedback
  • June 16 – June 17: Final Reviews
  • June 18: Winner Announcements

Make sure to RSVP to the event so you receive all the notifications and instructions on how to participate.

Further details about the Docathon will be shared during the Kick-off call on June 3.

Don’t forget to register for this year’s event: RSVP now

Read More

FlexAttention Part II: FlexAttention for Inference

Overview

In the PyTorch 2.5.0 release, we introduced FlexAttention (torch.nn.attention.flex_attention) for ML researchers who’d like to customize their attention kernels without writing kernel code. This blog introduces our decoding backend optimized for inference, supporting GQA and PagedAttention, along with feature updates including nested jagged tensor support, performance tuning guides, and trainable biases support.

If you’re looking for an easy way to play around with FlexAttention in your post-training / inference pipeline, PyTorch native post-training library torchtune and inference codebase gpt-fast already have FlexAttention integrated. Try it out!

We are excited to share that our paper on FlexAttention has been accepted for presentation at the MLSys 2025 Conference, held May 12-15 in Santa Clara, California.

Title: FlexAttention: A Programming Model for Generating Optimized Attention Kernels. Poster

FlexAttention for Inference

TL;DR: torch.compile lowers flex_attention to a fused FlashDecoding kernel when it runs on a very short query.

One fused attention kernel does not suit all – especially in long-context LLM inference.

The decoding phase of LLM inference is an iterative process: tokens are generated one at a time, requiring N forward passes to generate an N-token sentence. Fortunately, each iteration doesn’t need to recompute self-attention over the full sentence — previously calculated tokens are cached, so we only need to attend the newly generated token to the cached context.

This results in a unique attention pattern where a short query sequence (1 token) attends to a long key-value cache (context length up to 128k). Traditional optimizations for square attention kernels (q_len ≈ kv_len) don’t directly apply here. This pattern poses new challenges for GPU memory utilization and occupancy. We built a dedicated FlexDecoding backend optimized for long-context LLM inference, incorporating decoding-specific techniques from FlashDecoding.

FlexDecoding is implemented as an alternative backend for the torch.nn.attention.flex_attention operator. flex_attention automatically switches to the FlexDecoding backend for its JIT compilation when given a short query and a long KV cache. If the input shape changes significantly, for example transitioning from the prefill phase to decoding, JIT recompilation generates a separate kernel for each scenario.

flex_attention = torch.compile(flex_attention)

k_cache = torch.randn(B, H, 16384, D)  # preallocated KV cache of shape [B, H, MAX_SEQ_LEN, D]
v_cache = torch.randn(B, H, 16384, D)

...

# Prefill Phase: query shape = [B, H, 8000, D]
flex_attention(q_prefill, k_cache, v_cache, ...) # Uses the FlexAttention backend optimized for prefill & training

# Decoding Phase: q_last_token shape = [B, H, 1, D]
flex_attention(q_last_token, k_cache, v_cache, ...) # Recompiles with the FlexDecoding backend

# Decode 2 tokens at the same time: q_last_2_tokens shape = [B, H, 2, D]
flex_attention(q_last_2_tokens, k_cache, v_cache, ...) # No recompilation needed! Runs the decoding kernel again.

Working with KV Cache

One of the key optimizations for efficient inference is maintaining a preallocated KV cache that updates in place as new tokens are generated. Instead of enforcing a specific KV cache policy with a dedicated API, FlexDecoding allows users to define and manage the KV cache themselves.

Similar to FlexAttention, FlexDecoding takes user-defined mask_mod and score_mod functions. These functions modify attention scores before the softmax operation.

score_mod(score, b, h, q_idx, kv_idx) -> tensor # return updated score

score is a scalar PyTorch tensor that represents the dot product of a query token and a key token. The remaining arguments specify which score is being computed (an illustrative sketch of another score_mod follows the list):

  • b batch index
  • h attention head index
  • q_idx token position in query tensor
  • kv_idx token position in key/value tensor
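
For illustration only (this example is not from the original post), an ALiBi-style relative-position bias can be written with the same signature. Here alibi_slopes is an assumed tensor of shape [num_heads] defined elsewhere:

# Illustrative sketch: ALiBi-style relative-position bias.
# `alibi_slopes` is an assumed per-head tensor, not part of the original post.
def alibi_bias(score, b, h, q_idx, kv_idx):
    # For causal attention kv_idx <= q_idx, so this subtracts a penalty
    # proportional to the distance between the key and query tokens.
    return score + alibi_slopes[h] * (kv_idx - q_idx)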

In the decoding phase, previously calculated tokens are cached, and only the latest generated token (i-th) is used as the query. A naive causal mask on this one token query looks like this:

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

This is problematic: the new token “saw” should attend to all previously generated tokens, i.e. “The cat sat on the mat and saw”, not just the first entry in the KV cache. To correct this, the score_mod needs to offset q_idx by the position of the current token in the sequence (i) for accurate decoding.

Creating a new score_mod for each token to accommodate the offset is slow, since it means FlexAttention needs to be recompiled every iteration for a different score_mod. Instead, we define this offset as a tensor and increment its value at each iteration:

offset = torch.tensor(i, device="cuda")

def causal_w_offset(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Attend the i-th token
flex_attention(..., score_mod=causal_w_offset) # Compiles the kernel here
...
# Attend the (i+1)-th token
offset = offset + 1 # Increment offset
flex_attention(..., score_mod=causal_w_offset) # Doesn't need to recompile!

Notably, offset here becomes a captured tensor, so FlexAttention does not need to recompile when its value changes.

Manually rewriting your score_mod and mask_mod for offset handling isn’t necessary. We can automate this process with a generic rewriter:

offset = torch.tensor(i, device="cuda")

def get_score_mod_w_offset(score_mod: _score_mod_signature, _offset: torch.Tensor):
    def _score_mod(score, b, h, q, kv):
        return score_mod(score, b, h, q + _offset, kv)
    return _score_mod

def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: torch.Tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod

causal_w_offset = get_score_mod_w_offset(causal, offset)

BlockMask for Inference

We can also use BlockMask with inference to leverage mask sparsity. The idea is to precompute the BlockMask once during model setup and use slices of it during decoding.

Precomputing BlockMask

During setup, we create a square BlockMask of shape MAX_SEQ_LEN x MAX_SEQ_LEN:

from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=MAX_SEQ_LEN, KV_LEN=MAX_SEQ_LEN)

Using BlockMask During Decoding

For the i-th token, we use a slice of the mask:

block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]

# don't forget to use the mask_mod with offset!
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask, offset)
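
Putting the pieces together, a single decoding step might look like the following sketch. It reuses names from the earlier snippets (q_last_token, k_cache, v_cache, causal_w_offset) and is illustrative rather than a complete program:

# Illustrative sketch; all names are defined in the snippets above.
output = flex_attention(
    q_last_token,               # [B, H, 1, D] query for the latest token
    k_cache, v_cache,           # preallocated KV cache
    score_mod=causal_w_offset,  # offset-aware causal score_mod
    block_mask=block_mask_slice,
)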

Performance

The FlexDecoding kernel performs on par with FlashDecoding (FAKV) and significantly outperforms PyTorch scaled_dot_product_attention (code).

FlexDecoding boosts LLaMa3.1-8B serving performance by 1.22x-2.04x, and LLaMa3.1-70B performance by 0.99x-1.66x, compared to SDPA in gpt-fast (code).

Paged Attention

vLLM is one of the most popular LLM serving engines, powered by efficient memory management via PagedAttention. The existing PagedAttention implementation requires dedicated CUDA kernels and offers limited flexibility in supporting emerging attention variants. In this section, we present a PT2-native PagedAttention implementation enabled by flex attention and torch.compile.

PagedAttention scatters the KV cache to reduce memory fragmentation and support higher batch sizes. Without PagedAttention, the KV cache for a given request is stored in contiguous memory, requiring two tensors of shape B x H x KV_LEN x D. We call this the logical KV cache. Here, KV_LEN is the maximum sequence length over all requests in a batch. In Figure 1(a), KV_LEN is 9, so all requests must be padded to 9 tokens, leading to large memory waste. With PagedAttention, we can chunk each request into multiple pages of the same size page_size and scatter these pages into a physical KV cache of shape 1 x H x max_seq_len x D, where max_seq_len = n_pages x page_size. This avoids padding requests to the same length and saves memory. Specifically, we provide an assign API to update the KV cache via index computations:

def assign(
    batch_idx: torch.Tensor,
    input_pos: torch.Tensor,
    k_val: torch.Tensor,
    v_val: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> None

Behind this assign API is a page table, a tensor mapping logical KV cache to physical KV cache:

[batch_idx, logical_page_idx] -> physical_page_idx

assign takes k_val and v_val and scatters them into the physical KV cache, guided by the mapping in the page table.
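
For intuition, here is a minimal sketch of what such an index computation could look like. This is not the actual implementation: the page_table and page_size arguments, the single-request batch_idx, and the tensor shapes are assumptions for illustration.

import torch

def assign_sketch(
    batch_idx: torch.Tensor,   # scalar tensor: which request is being updated
    input_pos: torch.Tensor,   # [S] logical token positions for this request
    k_val: torch.Tensor,       # [1, H, S, D] new keys
    v_val: torch.Tensor,       # [1, H, S, D] new values
    k_cache: torch.Tensor,     # [1, H, max_seq_len, D] physical key cache
    v_cache: torch.Tensor,     # [1, H, max_seq_len, D] physical value cache
    page_table: torch.Tensor,  # [B, n_pages]: (batch_idx, logical_page) -> physical_page
    page_size: int,
) -> None:
    logical_page = input_pos // page_size                 # logical page of each token
    page_offset = input_pos % page_size                   # offset inside that page
    physical_page = page_table[batch_idx, logical_page]   # indirect lookup, shape [S]
    physical_pos = physical_page * page_size + page_offset
    # Scatter the new keys/values into the physical cache along the sequence dim.
    k_cache[:, :, physical_pos] = k_val
    v_cache[:, :, physical_pos] = v_val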

Paged Attention with Page Table

A natural question is how to integrate PagedAttention with flex attention to support diverse attention variants. A naive idea is to materialize the logical KV cache before computing with flex attention, but this leads to redundant memory copies and poor performance. Another idea is to build a dedicated CUDA or Triton kernel for paged attention, similar to existing PagedAttention implementations. However, this adds significant manual effort and code complexity.

Instead, we design a fused indirect memory access by converting the logical block mask according to the page table. In FlexAttention, we exploit BlockMask to identify logical blocks and skip redundant computation. Since Paged Attention adds an extra layer of indirect memory access, we further convert the logical block mask to the physical block mask corresponding to the page table, as illustrated in Figure 2. Our PagedAttention implementation provides a convert_logical_block_mask API built on torch.gather calls:

def convert_logical_block_mask(
    block_mask: BlockMask,
    batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask
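
Conceptually, the conversion remaps the logical KV block indices stored in the BlockMask through the page table. The following rough sketch assumes the BlockMask block size equals page_size, so each logical KV block corresponds to exactly one physical page; the actual implementation handles more general cases:

import torch
from torch.nn.attention.flex_attention import BlockMask

# Rough sketch only. Assumes BLOCK_SIZE == page_size and a page_table of
# shape [B, n_pages] mapping logical page ids to physical page ids.
def convert_logical_block_mask_sketch(block_mask: BlockMask, page_table: torch.Tensor) -> BlockMask:
    B, H, NUM_Q_BLOCKS, MAX_KV_BLOCKS = block_mask.kv_indices.shape
    logical = block_mask.kv_indices.reshape(B, -1).to(torch.int64)
    # Remap every logical KV block id to its physical page id via the page table.
    physical = torch.gather(page_table, 1, logical)
    block_mask.kv_indices = physical.reshape(B, H, NUM_Q_BLOCKS, MAX_KV_BLOCKS).to(block_mask.kv_indices.dtype)
    return block_mask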

Paged Attention via Block Mask Conversion

One remaining question is how to rewrite user-specified mask_mod and score_mod functions for PagedAttention. Users write these modifications with logical indices, without knowledge of the page table maintained at runtime. The following code shows the automated conversion performed at runtime to rewrite user-specified modifications with physical KV indices. The new_mask_mod takes the physical_kv_idx, converts it back to the logical_kv_idx, and applies the user-specified mask_mod on the logical_kv_idx to produce the correct mask. For efficiency, we maintain physical_to_logical, a mapping from physical_kv_block to logical_kv_block, to facilitate the conversion. For correctness, we mask out-of-boundary blocks as False with a torch.where call: after batching logical KV caches from multiple requests into the same physical KV cache, there are many more physical blocks than logical blocks for any single request, so a physical block may not have a corresponding logical block for a specific request during block mask conversion. Masking such blocks as False ensures that data from different requests do not interfere with each other. The score_mod can be converted automatically in the same way.

def get_mask_mod(mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
    if mask_mod is None:
        mask_mod = noop_mask

    def new_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ):
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
        )

    return new_mask_mod

Figure 3 shows the latency of Paged Attention (code). Overall, there is less than 5% overhead from Flex Attention with Paged Attention compared with Flex Attention only. We also observe on-par performance with Flash Attention v2. A minimal serving example further shows that PagedAttention can support a 76x higher batch size when evaluating on the OpenOrca dataset, which includes 1M GPT-4 completions and 3.2M GPT-3.5 completions.

Paged Attention: Latency under diverse sequence length

Ragged input sequences with Nested Jagged Tensors (NJTs)

FlexAttention now supports ragged-sized input sequences through the use of Nested Jagged Tensors (NJTs). NJTs represent ragged-sized sequences by packing sequences into a single “stacked sequence” and maintaining a set of offsets delimiting sequence boundaries for each batch item.

A block mask can be created for input NJTs through the new create_nested_block_mask() API. The returned block mask is compatible with the ragged structure of the given NJT, treating it as a single “stacked sequence” with inter-sequence attention automatically masked out. The mask_mod or score_mod function can be written as usual.

from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

BATCH = 8
NUM_HEADS = 8
D = 16
device = "cuda"

# Input NJTs of shape (BATCH, SEQ_LEN*, D) with ragged SEQ_LEN
sequence_lengths = [torch.randint(5, 30, ()).item() for _ in range(BATCH)]
query = torch.nested.nested_tensor([
    torch.randn(seq_len, NUM_HEADS * D, device=device)
    for seq_len in sequence_lengths
], layout=torch.jagged)
key = torch.randn_like(query)
value = torch.randn_like(query)

# View as shape (BATCH, NUM_HEADS, SEQ_LEN*, HEAD_DIM)
query = query.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
key = key.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
value = value.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)

# Simple causal mask
def my_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Construct a block mask using the ragged structure of the
# specified query NJT. Ragged-sized sequences are treated as a single
# "stacked sequence" with inter-sequence attention masked out.
block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query)

# For cross attention, create_nested_block_mask() also supports a
# rectangular block mask using the ragged structures of both query / key.
#block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, key)

output = flex_attention(query, key, value, block_mask=block_mask)

Trainable Biases

FlexAttention now supports trainable parameters in score_mod functions. This feature enables users to reference tensors that require gradients within their score_mod implementations, with gradients automatically backpropagating through these parameters during training.

Memory-Efficient Gradient Accumulation

Instead of materializing the full attention scores matrix, FlexAttention uses atomic additions (tl.atomic_add) to accumulate gradients. This approach significantly reduces memory usage at the cost of introducing some non-determinism in gradient calculations.

Handling Broadcasted Operations

Broadcasting operations in the forward pass (e.g., score + bias[h]) require special consideration in the backward pass. When broadcasting a tensor across multiple attention scores within a head or other dimensions, we need to reduce these gradients back to the original tensor shape. Rather than materializing the full attention score matrix to perform this reduction, we use atomic operations. While this incurs some runtime overhead, it allows us to maintain memory efficiency by avoiding the materialization of large intermediate tensors.

Current Limitations

The implementation currently allows only a single read from each input tensor in the score_mod function. For example, bias[q_idx] + bias[kv_idx] would not be supported as it reads from the same tensor twice. We hope to remove this restriction in the future.

Simple Example:

bias = torch.randn(num_heads, requires_grad=True)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  
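
To see gradients flow end to end, here is a hedged sketch with assumed shapes and a CUDA device (not from the original post):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 128, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

bias = torch.randn(H, device="cuda", requires_grad=True)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  # trainable per-head bias

flex_attention_compiled = torch.compile(flex_attention)
out = flex_attention_compiled(q, k, v, score_mod=score_mod)
out.sum().backward()    # gradients backpropagate through the captured bias
print(bias.grad.shape)  # torch.Size([8])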

Performance Tuning for FlexAttention

TL;DR

For optimal performance, compile FlexAttention using max-autotune, especially when dealing with complex score_mods and mask_mods:

flex_attention = torch.compile(flex_attention, dynamic=True, mode="max-autotune")

What is max-autotune?

max-autotune is a torch.compile mode in which TorchInductor sweeps many kernel parameters (e.g., tile size, num_stages) and selects the best-performing configuration. During this sweep, TorchInductor can safely try configurations that fail as well as ones that succeed, and it keeps the best viable configuration.

While compilation takes longer with max-autotune, the optimal configuration is cached for future kernel executions.

Here’s an example of FlexAttention compiled with max-autotune:

triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

Why Use max-autotune for FlexAttention?

The amount of shared memory utilized in FlexAttention depends on the score_mod and mask_mod functions. This variability means that the preconfigured default kernel parameters may lead to performance cliffs or even out-of-shared-memory errors on certain hardware for some masks/mods.

For instance, with document masks, default configurations can halve GPU occupancy, reducing performance to ~75% of its potential on some GPUs. To avoid such issues, we strongly recommend enabling max-autotune.

Updates and Enhancements

  • Now available as a prototype feature in PyTorch 2.5.0
  • Fixed critical correctness issues, including a bug affecting multiple calls to FlexAttention within the same call to torch.compile

Expanded Architecture Support

  • Arbitrary sequence length support – no longer requires multiples of 128
  • Added native grouped-query attention (GQA) support via is_gqa=True
  • Enhanced dimension flexibility:
    • Different QK and V head dimensions
    • Non-power-of-two head dimensions
  • Trainable attention biases (prototype)

Under the Hood

  • New fused CPU backend
  • Improved TF32 handling for float32 inputs
  • Resolved various dynamic shape issues
  • Output layout matching query strides

These updates make FlexAttention more robust and flexible while maintaining its core promise of combining PyTorch’s ease of use with FlashAttention’s performance benefits.

Read More

6x faster Async Checkpointing in PyTorch, using Cached Plans, no GIL contention

Meta: Less Wright, Meet Vadakkanchery, Saurabh Mishra, Ela Krepska, Hamid Shojanazeri, Pradeep Fernando
Crusoe: Ethan Petersen, Martin Cala, Chip Smith

PyTorch DCP (Distributed Checkpointing) has recently enabled new optimizations in asynchronous checkpointing to reduce GPU utilization drop by minimizing collective overhead and improving overall checkpointing efficiency.

Using Crusoe’s 2K H200 cluster with TorchTitan to train a Llama3-70B model, we verified that these new features deliver substantial speedups at 1856 GPU scale, reducing the background processing time for async DCP checkpoints from ~436 seconds to ~67 seconds.

This is roughly a 6.5x reduction in background checkpoint processing time, enabling even more total training time to proceed at full training throughput.

Fig 1: 1856 GPU training run with high-frequency checkpointing. The first checkpoint (the dip in tps) does not have a cached save plan, and its background processing takes far longer than the rest, where the cached plan is used.

Background: What is Asynchronous Checkpointing?

In a standard checkpointing workflow, GPUs are blocked while the checkpointing data is offloaded from GPU to CPU and then written to storage. After the save to physical media is complete, training can resume.

Asynchronous checkpointing greatly reduces this downtime by enabling the actual saving to storage to be done via CPU threads, allowing GPU-based training to continue while the checkpoint data is being persisted in parallel. It is used primarily for intermediate/fault tolerant checkpoints as it unblocks the GPUs much faster compared to the synchronous checkpoints.
For example, in our large-scale experiment, GPU training was blocked for less than a second (0.78 seconds at 1856 GPU scale) while checkpoint data was moved from GPU to CPU (staging). At that point, GPU training immediately continues, which is a substantial training time improvement over traditional checkpointing. For reference, Async Checkpointing is covered in more detail here.
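
As a rough sketch of the workflow (the state_dict contents and checkpoint_id path below are placeholders):

import torch.distributed.checkpoint as dcp

# Placeholder state; in practice this comes from your model/optimizer.
state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}

# async_save stages the checkpoint (GPU -> CPU) and returns quickly;
# the write to storage happens in the background while training continues.
checkpoint_future = dcp.async_save(state_dict, checkpoint_id="checkpoints/step_500")

# ... keep training here ...

checkpoint_future.result()  # wait for the previous save before starting the next one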

Challenges with Asynchronous Checkpointing

However, the background processing inherent in Asynchronous Checkpointing has additional challenges that result in a temporary reduction of training throughput while the storage phase is being completed. These are highlighted below.

GPU utilization drop from GIL contention:

The Global Interpreter Lock (GIL) in Python is a mechanism that prevents multiple native threads from executing Python bytecode at the same time. This lock is necessary mainly because CPython’s memory management is not thread-safe.

DCP currently uses background threads for metadata collectives and for uploading to storage. Although these expensive steps are done asynchronously, they contend for the GIL with the trainer threads. This causes GPU utilization (QPS) to suffer significantly and also increases the e2e upload latency. For large-scale checkpoints, the overhead of this CPU parallel processing suppresses net GPU training speed, since the CPUs also drive the training process via GPU kernel launches.

Please refer to the following figure from our experiments:

Fig 2: One can see a sustained drop in training QPS even after staging (i.e. blocking operation to trainer) is complete.

The first dip in Figure 2 (marked by the purple line) indicates that staging is complete, and training can continue. However, a second drop is evident (marked by the area between the purple and yellow lines) which is due to trainer thread and checkpointing threads contending for the Python GIL, leading to degraded training QPS until the checkpoint thread completes execution.

Collective communications cost:

DCP performs multiple collectives today for various reasons: dedupe, global metadata for the checkpoint, resharding, and distributed exception handling. Collectives are costly as these require network I/O and pickling/unpickling of the large metadata being sent across the GPU network. These collectives become extremely expensive as the job scale grows, leading to significantly higher e2e latency and potential for collective timeouts.

Solutions

Process-based async checkpointing

DCP now supports async checkpoint save via a background process. This helps avoid the training QPS drop by eliminating the Python GIL contention with the trainer threads. Please see Fig 2 for checkpointing via threads and Fig 3 for checkpointing via a background process.

Caching of the save plans

DCP has a clear boundary between the planning and storage I/O steps. The SavePlanner in DCP is a stateful component that acts as an access proxy to the state_dict. The planner manages save plans prepared by individual ranks, which carry the metadata needed to perform the write I/O. The planning step involves a collective operation to gather a comprehensive view of the checkpoint on the coordinator rank. The coordinator rank is responsible for de-duplicating parameters/weights to eliminate redundancies, validating the global plan to ensure accuracy and consistency, and creating the global metadata structs. This is followed by a scatter collective in which the coordinator rank assigns I/O tasks to each rank. Any transformations done on the plans affect how the storage components finally write the data.

During the course of a training job, multiple checkpoints are saved. In the majority of these cases, only the checkpoint data changes between different save instances, and thus, the plan remains the same. This presented an opportunity for us to cache the plans, pay the planning cost only on the first save, and then amortize that cost across all the subsequent attempts. Only the updated plans (plans which changed in the next attempt) are sent via collective, thus reducing the collective overhead significantly.

Experiment Results

Set up: 1856 H200 GPUs, Llama3-70B, HSDP2 with TorchTitan

After deploying both the solutions above, the following are the key results:

  • The TPS drop has narrowed significantly, with a peak dip to 372 tps (vs 315 tps previously), and for a greatly reduced time window (~67 seconds vs ~437 seconds). This time window is now mostly attributed to blocking for CPU processing.
  • Subsequent checkpoint save attempts also continue to be much faster due to very low overhead at the planning stage. E2E latency is thus improved by over 6.5x. This will allow our partners to increase the checkpointing frequency and reduce the lost training progress (i.e. wasted training time).

If you look at the very first downspike in Figure 1, this drawdown in GPU processing time takes training throughput from 700 down to 320 tps, and suppresses it for roughly 7 minutes (467 seconds). Once the CPUs have finished processing, training continues again at full speed.

Previously, this ~7 minute suppression would be repeated at every checkpoint. However, with the new process-based checkpointing feature, only the first checkpoint has the full drawdown time (mainly due to overhead from daemon process initialization), as all future checkpoints are executed via the background process, mitigating GIL contention with the trainer threads.

This is visually shown in all the subsequent checkpoints where the average MFU suppression time drops to just over a minute, reflected by the sharp spikes that almost immediately revert to full MFU throughput.

Fig 3: The red box shows the non-cached plan checkpoint, which also includes Checkpoint Background Init process overhead, while the purple box highlights the first checkpoint to run with the cached plan.

This means that even large-scale checkpointing, such as shown in Fig 2 at 1856 GPU scale, can be done with ~6x reduced training throughput impact. This enables Asynchronous DCP checkpointing to be run more frequently (thus better rollback protection) while enhancing total training throughput relative to previous Async Checkpointing overhead.

Using DCP’s cached checkpointing:

This feature is already available as part of the PyTorch nightly builds, and you can test out PyTorch’s Asynchronous DCP checkpointing directly in TorchTitan. The instructions to enable these features are below, followed by a short usage sketch:

  • Process-based asynchronous checkpointing:
    • Set the async_checkpointer_type to AsyncCheckpointerType.PROCESS in the async_save API. (file: pytorch/torch/distributed/checkpoint/state_dict_saver.py)
  • Save plan caching:
    • Set the enable_plan_caching flag to True in the DefaultSavePlanner. (file: pytorch/torch/distributed/checkpoint/default_planner.py)
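
Putting both options together, a save call might look roughly like the sketch below. Import paths follow the files referenced above; exact names and signatures may differ across nightly builds, so treat this as an illustration rather than the definitive API:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType

# Illustrative sketch; state_dict and the checkpoint_id path are placeholders.
checkpoint_future = dcp.async_save(
    state_dict,
    checkpoint_id="checkpoints/step_1000",
    planner=DefaultSavePlanner(enable_plan_caching=True),    # cache save plans
    async_checkpointer_type=AsyncCheckpointerType.PROCESS,   # background process, no GIL contention
)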

Future work

DCP will be rolling out additional optimizations to further reduce checkpointing cost. Currently, even though the save plans are cached, the coordinator rank still prepares the metadata. For larger jobs and models with many tensors, this overhead is non-trivial. In the next iteration, DCP will eliminate the metadata overhead and further improve the e2e latency. DCP will also introduce additional optimizations, such as zero-overhead checkpointing, to enable efficient checkpointing in large-scale jobs.

Stay tuned!

Read More

6x faster Async Checkpointing in PyTorch, using Cached Plans, no GIL contention

6x faster Async Checkpointing in PyTorch, using Cached Plans, no GIL contention

Meta: Less Wright, Meet Vadakkanchery, Saurabh Mishra, Ela Krepska, Hamid Shojanazeri, Pradeep Fernando
Crusoe: Ethan Petersen, Martin Cala, Chip Smith

PyTorch DCP (Distributed Checkpointing) has recently enabled new optimizations in asynchronous checkpointing to reduce GPU utilization drop by minimizing collective overhead and improving overall checkpointing efficiency.

Using Crusoe’s 2K H200 cluster, with TorchTitan and training a Llama3-70B, we were able to verify these new features deliver substantial speedups at 1856 GPU scale, reducing the background processing time for async DCP checkpoints from ~436 seconds to ~67 seconds.

This is roughly a 6.5x reduction in background checkpoint processing time, enabling even more total training time to proceed at full training throughput.

Fig 1: 1856 training run with high frequency checkpointing. The first checkpoint (drop down in tps) does not have a cached save plan, and the background processing takes far longer than the rest where the cached plan is used.

Background: What is Asynchronous Checkpointing?

In a standard checkpointing workflow, GPUs are blocked while the checkpointing data is offloaded from GPU to CPU and then written to storage. After the save to physical media is complete, training can resume.

Asynchronous checkpointing greatly reduces this downtime by enabling the actual saving to storage to be done via CPU threads, allowing GPU-based training to continue while the checkpoint data is being persisted in parallel. It is used primarily for intermediate/fault tolerant checkpoints as it unblocks the GPUs much faster compared to the synchronous checkpoints.
For example, in our large-scale experiment, GPU training was blocked for less than a second (.78 seconds at 1856 scale) while checkpoint data was moved from GPU to CPU (staging). At that point, GPU training immediately continues, which is a substantial training time improvement over traditional checkpointing. For reference, Async Checkpointing is covered in more detail here.

Challenges with Asynchronous Checkpointing

However, the background processing inherent in Asynchronous Checkpointing has additional challenges that result in a temporary reduction of training throughput while the storage phase is being completed. These are highlighted below.

GPU utilization drop from GIL contention:

The Global Interpreter Lock (GIL) in Python is a mechanism that prevents multiple native threads from executing Python bytecode at the same time. This lock is necessary mainly because CPython’s memory management is not thread-safe.

DCP currently uses background threads for metadata collectives and uploading to storage. Although these expensive steps are done asynchronously, it leads to contention for the GIL with the trainer threads. This causes the GPU utilization (QPS) to suffer significantly and also increases the e2e upload latency. For large-scale checkpoints, the overhead of the CPU parallel processing has a suppressive effect on net GPU training speed since CPUs also drive the training process via GPU kernel launches.

Please refer to the following figure from our experiments:

Fig 2: One can see a sustained drop in training QPS even after staging (i.e. blocking operation to trainer) is complete.

The first dip in Figure 2 (marked by the purple line) indicates that staging is complete, and training can continue. However, a second drop is evident (marked by the area between the purple and yellow lines) which is due to trainer thread and checkpointing threads contending for the Python GIL, leading to degraded training QPS until the checkpoint thread completes execution.

Collective communications cost:

DCP performs multiple collectives today for various reasons: dedupe, global metadata for the checkpoint, resharding, and distributed exception handling. Collectives are costly as these require network I/O and pickling/unpickling of the large metadata being sent across the GPU network. These collectives become extremely expensive as the job scale grows, leading to significantly higher e2e latency and potential for collective timeouts.

Solutions

Process based async checkpointing

DCP now supports async checkpoint save via a background process. This helps avoid the training QPS drop by eliminating the python GIL contention with the trainer threads. Please see Fig 2 for checkpointing via threads and Fig 3 for checkpointing via background process.

Caching of the save plans

DCP has a clear boundary between the planning and storage I/O steps. SavePlanner in DCP is a stateful component which acts as an access proxy to the state_dict. Planner manages save plans prepared by individual ranks, which carry metadata information necessary to do the write I/O. The planning step involves a collective operation to gather a comprehensive view of the checkpoint on the coordinator rank. The coordinator rank is responsible for de-duplicating parameters/weights to eliminate redundancies, validating the global plan to ensure accuracy and consistency, and creating the global metadata structs. This is followed by a scatter collective where the coordinator rank assigns I/O tasks to each rank. Any transformations done on the plans affect how the storage components finally write the data.

During the course of a training job, multiple checkpoints are saved. In the majority of these cases, only the checkpoint data changes between different save instances, and thus, the plan remains the same. This presented an opportunity for us to cache the plans, pay the planning cost only on the first save, and then amortize that cost across all the subsequent attempts. Only the updated plans (plans which changed in the next attempt) are sent via collective, thus reducing the collective overhead significantly.

Experiment Results

Set up: 1856 H200 GPUs, Llama3-70B, HSDP2 with TorchTitan

After deploying both the solutions above, the following are the key results:

  • TPS drop has significantly narrowed, with a peak dip to 372 vs 315 tps, and for a greatly reduced time window (~67 seconds vs ~437 seconds). This time window is now mostly attributed to the blocking for CPU processing.
  • Subsequent checkpoint save attempts also continue to be much faster due to very low overhead at the planning stage. E2E latency is thus improved by over 6.5x. This will allow our partners to increase the checkpointing frequency and reduce the lost training progress (i.e. wasted training time).

If you look at the very first downspike in Figure 1, this drawdown in GPU processing time takes training throughput from 700 down to 320 tps, and suppresses it for roughly 7 minutes (467 seconds). Once the CPUs have finished processing, training continues again at full speed.

Previously, this ~7 minute suppression would be repeated at every checkpoint. However, with the new process-based checkpointing feature, only the first checkpoint has the full drawdown time (mainly due to overhead from daemon process initialization), as all future checkpoints are executed via the background process, mitigating GIL contention with the trainer threads.

This is visually shown in all the subsequent checkpoints where the average MFU suppression time drops to just over a minute, reflected by the sharp spikes that almost immediately revert to full MFU throughput.

Fig 3: The red box shows the non-cached plan checkpoint, which also includes Checkpoint Background Init process overhead, while the purple box highlights the first checkpoint to run with the cached plan.

This means that even large-scale checkpointing, such as shown in Fig 2 at 1856 GPU scale, can be done with ~6x reduced training throughput impact. This enables Asynchronous DCP checkpointing to be run more frequently (thus better rollback protection) while enhancing total training throughput relative to previous Async Checkpointing overhead.

Using DCP’s cached checkpointing:

This feature is already available in PyTorch nightly builds, and you can test PyTorch’s Asynchronous DCP checkpointing directly in TorchTitan. The instructions to enable these features are below, followed by a short usage sketch:

  • Process-based asynchronous checkpointing:
    • Set the async_checkpointer_type to AsyncCheckpointerType.PROCESS in the async_save API. (file: pytorch/torch/distributed/checkpoint/state_dict_saver.py)
  • Save plan caching:
    • Set the enable_plan_caching flag to True in the DefaultSavePlanner. (file: pytorch/torch/distributed/checkpoint/default_planner.py)
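
Putting the two settings together, a minimal sketch of an async save call might look like the following. This assumes the nightly APIs referenced in the bullets above (import paths follow the files named there, the checkpoint_id value and the state_dict variable are illustrative, and enable_plan_caching is assumed to be a constructor argument of DefaultSavePlanner):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType

# state_dict is whatever you already checkpoint (model / optimizer state).
future = dcp.async_save(
    state_dict,
    checkpoint_id="checkpoints/step_1000",
    planner=DefaultSavePlanner(enable_plan_caching=True),   # save plan caching
    async_checkpointer_type=AsyncCheckpointerType.PROCESS,  # background process
)

# Training continues here; wait on the returned future only when the save must be complete.
future.result()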

Future work

DCP will be rolling out additional optimizations to further reduce checkpointing cost. Currently, even though the save plans are cached, the coordinator rank still prepares the metadata. For larger jobs and models with many tensors, this overhead is non-trivial. In the next iteration, DCP will eliminate the metadata overhead and improve the e2e latency further. DCP will also introduce additional optimizations, such as zero-overhead checkpointing, to enable efficient checkpointing in large-scale jobs.

Stay tuned!

Read More

FlexAttention Part II: FlexAttention for Inference

FlexAttention Part II: FlexAttention for Inference

Overview

In PyTorch 2.5.0 release, we introduced FlexAttention torch.nn.attention.flex_attention for ML researchers who’d like to customize their attention kernels without writing kernel code. This blog introduces our decoding backend optimized for inference, supporting GQA and PagedAttention, along with feature updates including nested jagged tensor support, performance tuning guides and trainable biases support.

If you’re looking for an easy way to play around with FlexAttention in your post-training / inference pipeline, PyTorch native post-training library torchtune and inference codebase gpt-fast already have FlexAttention integrated. Try it out!

We are excited to share that our paper on FlexAttention has been accepted for presentation at the MLSys2025 Conference held from May 12-15th in Santa Clara, California.

Title: FlexAttention: A Programming Model for Generating Optimized Attention Kernels. Poster

FlexAttention for Inference

TL;DR: torch.compile lowers flex_attention to a fused FlashDecoding kernel when it runs on a very short query.

One fused attention kernel does not suit all – especially in long-context LLM inference.

The decoding phase of LLM inference is an iterative process: tokens are generated one at a time, requiring N forward passes to generate an N-token sentence. Fortunately, each iteration doesn’t need to recompute self-attention over the full sentence: the keys and values of previously generated tokens are cached, so we only need to attend the newly generated token to the cached context.

chart

This results in a unique attention pattern where a short query sequence (1 token) attends to a long key-value cache (context length up to 128k). Traditional optimizations for square attention kernels (q_len ≈ kv_len) don’t directly apply here, and the pattern poses new challenges for GPU memory utilization and occupancy. We built a dedicated FlexDecoding backend optimized for long-context LLM inference, incorporating decoding-specific techniques from FlashDecoding.

FlexDecoding is implemented as an alternative backend for the torch.nn.attention.flex_attention operator. flex_attention automatically switches to the FlexDecoding backend for its JIT compilation when given a short query and a long KV cache. If the input shape changes significantly, for example transitioning from the prefill phase to decoding, JIT recompilation generates a separate kernel for each scenario.

import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention)

k_cache = torch.randn(B, H, 16384, D)
v_cache = torch.randn(B, H, 16384, D)

...

# Prefill Phase: query shape = [B, H, 8000, D]
flex_attention(q_prefill, k_cache, v_cache, ...) # Uses FlexAttention backend optimized for prefill & training

# Decoding Phase: q_last_token shape = [B, H, 1, D]
flex_attention(q_last_token, k_cache, v_cache, ...) # Recompiles with the FlexDecoding backend

# Decode 2 tokens at the same time: q_last_2_tokens shape = [B, H, 2, D]
flex_attention(q_last_2_tokens, k_cache, v_cache, ...) # No recompilation needed! Runs the decoding kernel again.

Working with KV Cache

One of the key optimizations for efficient inference is maintaining a preallocated KV cache that updates in place as new tokens are generated. Instead of enforcing a specific KV cache policy with a dedicated API, FlexDecoding allows users to define and manage the KV cache themselves.

Similar to FlexAttention, FlexDecoding takes user-defined mask_mod and score_mod functions. These functions modify attention scores before the softmax operation.

chart

score_mod(score, b, h, q_idx, kv_idx) -> tensor # return updated score

score is a scalar PyTorch tensor that represents the dot product of a query token and a key token. The remaining arguments specify which score is being computed:

  • b batch index
  • h attention head index
  • q_idx token position in query tensor
  • kv_idx token position in key/value tensor

In the decoding phase, previously calculated tokens are cached, and only the latest generated token (i-th) is used as the query. A naive causal mask on this one token query looks like this:

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

chart

This is problematic: the new token “saw” should attend to all previously generated tokens i.e. “The cat sat on the mat and saw”, not just the first entry in the kv cache. To correct this, the score_mod needs to offset q_idx by i for accurate decoding.

chart

Creating a new score_mod for each token to accommodate the offset is slow, since it would force FlexAttention to recompile at every iteration for a different score_mod. Instead, we define the offset as a tensor and increment its value at each iteration:

offset = torch.tensor(i, device="cuda")

def causal_w_offset(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Attend the i-th token
flex_attention(..., score_mod=causal_w_offset) # Compiles the kernel here
...
# Attend the i+1-th token
offset = offset + 1 # Increment offset
flex_attention(..., score_mod=causal_w_offset) # Doesn't need to recompile!

Notably, offset here is a captured tensor, so updating its value does not trigger recompilation.

Manually rewriting your score_mod and mask_mod for offset handling isn’t necessary. We can automate this process with a generic rewriter:

from torch.nn.attention.flex_attention import _score_mod_signature, _mask_mod_signature

offset = torch.tensor(i, device="cuda")

def get_score_mod_w_offset(score_mod: _score_mod_signature, _offset: torch.Tensor):
    def _score_mod(score, b, h, q, kv):
        return score_mod(score, b, h, q + _offset, kv)
    return _score_mod

def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: torch.Tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod

causal_w_offset = get_score_mod_w_offset(causal, offset)

BlockMask for Inference

We can also use BlockMask at inference time to leverage mask sparsity. The idea is to precompute the BlockMask once during model setup and use slices of it during decoding.

Precomputing BlockMask

During setup, we create a square BlockMask of size MAX_SEQ_LEN x MAX_SEQ_LEN:

from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=MAX_SEQ_LEN, KV_LEN=MAX_SEQ_LEN)

chart

Using BlockMask During Decoding

For the i-th token, we use a slice of the mask:

block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]

# don't forget to use the mask_mod with offset! 
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask, offset)

chart
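
Putting the pieces together, a single decoding step might look like the sketch below. This is a hedged example: q_last_token, k_cache, and v_cache are the illustrative tensors from earlier, and the captured offset tensor is advanced after each step.

# One decoding step for the i-th token (assumes the definitions above).
block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask, offset)

out = flex_attention(
    q_last_token,       # [B, H, 1, D]
    k_cache, v_cache,   # preallocated KV cache, updated in place elsewhere
    score_mod=causal_w_offset,
    block_mask=block_mask_slice,
)

offset = offset + 1     # advance to the next token position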

Performance

chart

The FlexDecoding kernel performs on par with FlashDecoding (FAKV) and significantly outperforms PyTorch’s scaled_dot_product_attention (code).

chart

FlexDecoding boosts LLaMa3.1-8B serving performance by 1.22x–2.04x, and LLaMa3.1-70B performance by 0.99x–1.66x, compared to SDPA in gpt-fast. (code)

Paged Attention

vLLM is one of the most popular LLM serving engines, powered by the efficient memory management of PagedAttention. The existing PagedAttention implementation requires dedicated CUDA kernels and offers limited flexibility for supporting emerging attention variants. In this section, we present a PT2-native PagedAttention implementation enabled by flex attention and torch.compile.

PagedAttention scatters the KV cache to reduce memory fragmentation and support higher batch sizes. Without PagedAttention, the KV cache for each request is stored in contiguous memory, requiring two tensors of shape B x H x KV_LEN x D. We call this the logical KV cache. Here, KV_LEN is the maximum sequence length over all requests in a batch. In Figure 1(a), KV_LEN is 9, so all requests must be padded to 9 tokens, leading to significant memory waste. With PagedAttention, we can chunk each request into multiple pages of the same size page_size and scatter these pages into a physical KV cache of shape 1 x H x max_seq_len x D, where max_seq_len = n_pages x page_size. This avoids padding requests to the same length and saves memory. Specifically, we provide an assign API to update the KV cache via index computations:

def assign(
    batch_idx: torch.Tensor,
    input_pos: torch.Tensor,
    k_val: torch.Tensor,
    v_val: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> None

Behind this assign API is a page table, a tensor mapping logical KV cache to physical KV cache:

[batch_idx, logical_page_idx] -> physical_page_idx

assign takes k_val and v_val and scatters to physical KV cache guided by the mapping from the page table.
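
To make the index arithmetic concrete, here is a conceptual single-token sketch of what assign does. This is not the library implementation; the page_table tensor and cache shapes follow the description above, and the helper name is hypothetical.

import torch

def assign_one_token(
    batch_idx: int,
    input_pos: int,
    k_val: torch.Tensor,      # [H, D] key for the new token
    v_val: torch.Tensor,      # [H, D] value for the new token
    k_cache: torch.Tensor,    # [1, H, n_pages * page_size, D] physical cache
    v_cache: torch.Tensor,
    page_table: torch.Tensor, # [B, n_pages_per_request]: logical page -> physical page
    page_size: int,
) -> None:
    logical_page = input_pos // page_size       # which logical page the token lands in
    page_offset = input_pos % page_size         # position within that page
    physical_page = int(page_table[batch_idx, logical_page])
    slot = physical_page * page_size + page_offset
    k_cache[0, :, slot] = k_val                 # scatter new key into the shared cache
    v_cache[0, :, slot] = v_val                 # scatter new value into the shared cache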

chart

Paged Attention with Page Table

A natural question is: how do we integrate PagedAttention with flex attention to support diverse attention variants? A naive idea is to materialize the logical KV cache before computing with flex attention, but this requires redundant memory copies and hurts performance. Another idea is to build a dedicated CUDA or Triton kernel for paged attention, similar to existing PagedAttention implementations. However, this adds significant manual effort and code complexity.

Instead, we design a fused indirect memory access by converting a logical block mask according to the page table. In FlexAttention, we exploit BlockMask to identify logical blocks and skip redundant computation. While Paged Attention adds an extra layer of indirect memory access, we can further convert the logical block mask to the physical block mask corresponding to the page table, as illustrated in Figure 2. Our PagedAttention implementation provides a convert_logical_block_mask via torch.gather calls:

def convert_logical_block_mask(
    block_mask: BlockMask,
    batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask

chart

Paged Attention via Block Mask Conversion
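
As a usage sketch (hedged; the exact call sites may differ from the library code), the conversion happens once per step before calling flex_attention against the physical KV cache:

# Convert the precomputed logical BlockMask into its physical counterpart,
# then attend against the physical (paged) KV cache.
physical_block_mask = convert_logical_block_mask(logical_block_mask, batch_idx=batch_idx)
out = flex_attention(q, k_cache, v_cache, block_mask=physical_block_mask)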

One remaining question is how to rewrite user-specified mask_mod and score_mod functions for PagedAttention. Users write these modifications in terms of logical indices, with no knowledge of the page table maintained at runtime. The following code shows the automated conversion applied at runtime to rewrite user-specified modifications with physical KV indices: the new_mask_mod takes the physical_kv_idx, converts it back to the logical_kv_idx, and applies the user-specified mask_mod on the logical_kv_idx to produce the correct mask. For efficiency, we maintain physical_to_logical, a mapping from physical_kv_block to logical_kv_block, to facilitate the conversion. For correctness, we mask out-of-bound blocks as False with a torch.where call: after batching logical KV caches from multiple requests into the same physical KV cache, there are many more physical blocks than logical blocks for any single request, so a physical block may have no corresponding logical block for a given request during block mask conversion. Masking such blocks as False ensures that data from different requests do not interfere with each other. The score_mod can be converted automatically in the same way; a sketch follows the code below.

from typing import Optional
from torch.nn.attention.flex_attention import _mask_mod_signature, noop_mask

def get_mask_mod(mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
    if mask_mod is None:
        mask_mod = noop_mask

    def new_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ):
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
        )

    return new_mask_mod
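
The score_mod conversion follows the same pattern. Below is a hedged sketch (not the exact library code): out-of-bound blocks contribute -inf rather than False so they cannot survive the softmax, and the inline lambda stands in for a no-op score_mod.

def get_score_mod(score_mod):
    if score_mod is None:
        score_mod = lambda score, b, h, q_idx, kv_idx: score  # no-op score_mod

    def new_score_mod(score, b, h, q_idx, physical_kv_idx):
        # Same physical -> logical translation as new_mask_mod above.
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0,
            score_mod(score, b, h, q_idx, logical_kv_idx),
            torch.full_like(score, float("-inf")),
        )

    return new_score_mod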

Figure 3 shows the latency of Paged Attention (code). Overall, Flex Attention with Paged Attention adds less than 5% overhead compared with Flex Attention alone, and we observe on-par performance with Flash Attention v2. A minimal serving example further shows that PagedAttention can support a 76x higher batch size when evaluating on the OpenOrca dataset, which includes 1M GPT-4 completions and 3.2M GPT-3.5 completions.

chart

Paged Attention: Latency under diverse sequence length

Ragged input sequences with Nested Jagged Tensors (NJTs)

FlexAttention now supports ragged-sized input sequences through the use of Nested Jagged Tensors (NJTs). NJTs represent ragged-sized sequences by packing sequences into a single “stacked sequence” and maintaining a set of offsets delimiting sequence boundaries for each batch item.

A block mask can be created for input NJTs through the new create_nested_block_mask() API. The returned block mask is compatible with the ragged structure of the given NJT, treating it as a single “stacked sequence” with inter-sequence attention automatically masked out. The mask_mod or score_mod function can be written as usual.

import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

BATCH = 8
NUM_HEADS = 8
D = 16
device = "cuda"

# Input NJTs of shape (BATCH, SEQ_LEN*, D) with ragged SEQ_LEN
sequence_lengths = [torch.randint(5, 30, ()).item() for _ in range(BATCH)]
query = torch.nested.nested_tensor([
    torch.randn(seq_len, NUM_HEADS * D, device=device)
    for seq_len in sequence_lengths
], layout=torch.jagged)
key = torch.randn_like(query)
value = torch.randn_like(query)

# View as shape (BATCH, NUM_HEADS, SEQ_LEN*, HEAD_DIM)
query = query.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
key = key.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
value = value.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)

# Simple causal mask
def my_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Construct a block mask using the ragged structure of the
# specified query NJT. Ragged-sized sequences are treated as a single
# "stacked sequence" with inter-sequence attention masked out.
block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query)

# For cross attention, create_nested_block_mask() also supports a
# rectangular block mask using the ragged structures of both query / key.
#block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, key)

output = flex_attention(query, key, value, block_mask=block_mask)

Trainable Biases

FlexAttention now supports trainable parameters in score_mod functions. This feature enables users to reference tensors that require gradients within their score_mod implementations, with gradients automatically backpropagating through these parameters during training.

Memory-Efficient Gradient Accumulation

Instead of materializing the full attention scores matrix, FlexAttention uses atomic additions (tl.atomic_add) to accumulate gradients. This approach significantly reduces memory usage at the cost of introducing some non-determinism in gradient calculations.

Handling Broadcasted Operations

Broadcasting operations in the forward pass (e.g., score + bias[h]) require special consideration in the backward pass. When broadcasting a tensor across multiple attention scores within a head or other dimensions, we need to reduce these gradients back to the original tensor shape. Rather than materializing the full attention score matrix to perform this reduction, we use atomic operations. While this incurs some runtime overhead, it allows us to maintain memory efficiency by avoiding the materialization of large intermediate tensors.

Current Limitations

The implementation currently allows only a single read from each input tensor in the score_mod function. For example, bias[q_idx] + bias[kv_idx] would not be supported as it reads from the same tensor twice. We hope to remove this restriction in the future.

Simple Example:

bias = torch.randn(num_heads, requires_grad=True)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  
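
The following is a slightly fuller sketch of how gradients flow back to the bias; the shapes, device, and the compile step are illustrative assumptions, while the single-read constraint matches the limitation described above.

import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention)

B, H, S, D = 2, 4, 256, 64
q = torch.randn(B, H, S, D, device="cuda", requires_grad=True)
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")
bias = torch.randn(H, device="cuda", requires_grad=True)  # one trainable bias per head

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  # single read from the trainable tensor

out = flex_attention(q, k, v, score_mod=score_mod)
out.sum().backward()
print(bias.grad.shape)  # torch.Size([4]): gradients are reduced back to the bias shape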

Performance Tuning for FlexAttention

TL;DR

For optimal performance, compile FlexAttention using max-autotune, especially when dealing with complex score_mods and mask_mods:

flex_attention = torch.compile(flex_attention, dynamic=True, mode="max-autotune")

What is max-autotune?

max-autotune is a torch.compile mode in which TorchInductor sweeps many kernel parameters (e.g., tile size, num_stages) and selects the best-performing configuration. During the sweep, configurations that fail (for example, by exceeding shared memory) are discarded safely, so autotuning always lands on the best viable configuration.

While compilation takes longer with max-autotune, the optimal configuration is cached for future kernel executions.

Here’s an example of FlexAttention compiled with max-autotune:

triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

Why Use max-autotune for FlexAttention?

The amount of shared memory utilized in FlexAttention depends on the score_mod and mask_mod functions. This variability means that the preconfigured default kernel parameters may lead to performance cliffs or even out-of-shared-memory errors on certain hardware for some masks/mods.

For instance, with document masks, default configurations can halve GPU occupancy, reducing performance to ~75% of its potential on some GPUs. To avoid such issues, we strongly recommend enabling max-autotune.

Updates and Enhancements

  • Now available as a prototype feature in PyTorch 2.5.0
  • Fixed critical correctness issues, including a bug affecting multiple calls to FlexAttention within the same call to torch.compile

Expanded Architecture Support

  • Arbitrary sequence length support – no longer requires multiples of 128
  • Added native grouped-query attention (GQA) support via is_gqa=True
  • Enhanced dimension flexibility:
    • Different QK and V head dimensions
    • Non-power-of-two head dimensions
  • Trainable attention biases (prototype)

Under the Hood

  • New fused CPU backend
  • Improved TF32 handling for float32 inputs
  • Resolved various dynamic shape issues
  • Output layout matching query strides

These updates make FlexAttention more robust and flexible while maintaining its core promise of combining PyTorch’s ease of use with FlashAttention’s performance benefits.

Read More

PyTorch Foundation Expands to an Umbrella Foundation to Accelerate AI Innovation

Today, I am thrilled to announce a significant milestone for the PyTorch Foundation: we are expanding our scope to become an umbrella foundation, allowing us to host additional projects. This expansion positions the PyTorch Foundation to foster a broader ecosystem of high-value, trusted, and innovative AI projects that cater to all stages of the AI lifecycle—from training and inference to industry-specific applications.

Why Expand?

Since its inception at the Linux Foundation two and a half years ago, the PyTorch Foundation has rapidly grown, now encompassing over 30 member organizations and 120 vibrant ecosystem projects. PyTorch itself has become the framework of choice for AI researchers, practitioners, and industry leaders worldwide. Our flagship PyTorch Conference has seen attendance multiply sixfold over just two years, reflecting the community’s tremendous enthusiasm and engagement.

With new initiatives such as PyTorch Day events, global community meetups, the PyTorch Ambassador Program, Open Source Program Office (OSPO) outreach, the Speaker’s Bureau, and our upcoming training and certification programs, we have significantly deepened our community’s expertise and collaboration capabilities. To sustain and accelerate this momentum, the logical next step was to expand the PyTorch Foundation into an umbrella organization.

What Does an Umbrella Foundation Mean?

By transitioning into an umbrella foundation, PyTorch will now host a range of diverse, high-quality AI and ML projects beyond PyTorch Core. These include foundation-hosted projects in two categories:

  • Platform Projects: Domain-agnostic solutions essential across various stages of the AI lifecycle, such as training, inference, model optimization, and deployment, as well as agentic systems.
  • Vertical Projects: Domain-specific projects tailored to particular industries or applications, such as biomedical imaging, protein folding, and geospatial analysis.

Projects under our umbrella gain immediate access to vendor-neutral governance, enhanced visibility, increased funding opportunities, and robust community engagement and support.

Foundation-Hosted vs. Ecosystem Projects

As we expand, it’s important to clarify the distinction between foundation-hosted and ecosystem projects:

  • Foundation-Hosted Projects fall under the umbrella: they are officially governed and administered under the PyTorch Foundation’s neutral and transparent governance model. Project maintainers continue to oversee their projects while transferring assets to the Linux Foundation for independent stewardship and adopting an open governance model, which significantly reduces vendor bias and encourages broader community contributions and adoption. These projects gain greater stability and longevity and integrate with the larger PyTorch community.
  • Ecosystem Projects remain independently managed but receive recognition and increased visibility by aligning themselves closely with the PyTorch Foundation community standards. These projects meet specific quality and maturity criteria but retain full independence in governance and asset management.

How to Join the PyTorch Ecosystem or Become a Foundation-Hosted Project

We have clearly defined pathways for projects looking to become part of the PyTorch community:

  1. Ecosystem Project Status: Projects must meet defined criteria, such as active development, comprehensive documentation, CI/CD infrastructure, clear governance, and community engagement. Approved ecosystem projects benefit from increased exposure and official recognition on the PyTorch Landscape.
  2. Candidate Project Status: Ecosystem projects aspiring to foundation-hosted status can become candidates by securing sponsorship from a PyTorch Foundation Technical Advisory Council (TAC) voting member. Candidates receive guidance on meeting all necessary governance, technical, and strategic criteria.
  3. Foundation-Hosted Project Status: Candidate projects demonstrating high maturity, stability, multi-platform support, security best practices, and strategic value to the PyTorch community can be approved by the TAC. These projects gain extensive benefits, including neutral trademark hosting, foundation support, marketing and events resources, governance guidance, and strategic funding opportunities.

Ensuring Long-Term Success and Innovation

By expanding our scope to become an umbrella foundation, the PyTorch Foundation is uniquely positioned to enhance collaboration, innovation, and sustained growth across the entire AI community. Our mission is clear: create a vendor-neutral, open source environment where the best AI and ML tools can thrive, benefiting users, contributors, and industry stakeholders worldwide.

“PyTorch is absolutely the foundation of the innovation happening in AI today and with projects like Llama, ChatGPT, and hundreds of thousands of open projects built on PyTorch, it has cemented itself as a critical ingredient to the world of AI. This move to create an umbrella foundation enables PyTorch to significantly expand its ecosystem both horizontally and vertically in this new era of agentic systems. I am very excited about this opportunity to take the PyTorch community to the next level!” – Joe Spisak, Product Director for PyTorch at Meta.

“PyTorch sits at the very core of AI today. Meanwhile, the depth of the AI stack has grown dramatically—evolving from enabling accelerated compute to powering fully autonomous systems. Broadening the PyTorch Foundation is a key step in keeping the AI revolution open and accessible to all, across the stack and aligned with the principles PyTorch was built on.” – Luca Antiga, CTO at Lightning AI.

We are incredibly optimistic about the opportunities ahead and excited to welcome new projects into our growing family. The PyTorch Foundation remains deeply committed to driving AI innovation forward, and together, we will continue to build the future of open source artificial intelligence.

Stay tuned for more updates, announcements, and opportunities to participate!

Read More
