How IBM Research Uses PyTorch and TerraTorch to Make Geospatial Computer Vision Accessible for Everyone

Earth observation-based analytics are becoming essential for understanding our planet, from monitoring deforestation to tracking urban development and analyzing the impacts of climate change. However, the coding and deep learning skills required to apply AI models to satellite imagery and earth observation data have traditionally been a major barrier for many practitioners.

With IBM Research’s launch of TerraTorch 1.0, a PyTorch domain library for fine-tuning geospatial computer vision foundation models, we make geospatial AI not only more accessible but also more practical for the wider PyTorch community. Our goal: simplify the process so that any data scientist, researcher, or enthusiast can build powerful geospatial models with ease and with low GPU and data processing requirements.

 

The power of foundation models: even with 75-95% of the input data removed, the models do a fantastic job of reconstructing the input data, thereby learning the underlying physics of our planet in a deep latent space.
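To make the masking idea in the caption concrete, here is a minimal sketch of choosing which input patches to hide before reconstruction. It is an illustration only: the function name `random_patch_mask`, the patch count of 196, and uniform random sampling are assumptions, not details taken from the Prithvi or TerraMind training code.

```python
import random


def random_patch_mask(num_patches: int, mask_ratio: float, seed: int = 0):
    """Split patch indices into (visible, masked) lists for one sample.

    Hypothetical helper: a real pretraining pipeline would do this on
    tensors, per batch element, inside the model's forward pass.
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    num_masked = int(round(num_patches * mask_ratio))
    rng = random.Random(seed)  # seeded so the split is reproducible
    masked = sorted(rng.sample(range(num_patches), num_masked))
    masked_set = set(masked)
    visible = [i for i in range(num_patches) if i not in masked_set]
    return visible, masked


# Assumed example: a 14 x 14 patch grid with 75% of the patches removed.
visible, masked = random_patch_mask(num_patches=196, mask_ratio=0.75)
print(len(visible), len(masked))  # 49 147
```

In such a setup the encoder would only see the `visible` patches, and the reconstruction loss would be computed on the `masked` ones.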

The Business Challenge

Our goal was to remove the technical barriers that prevent people from working with satellite imagery, weather and climate data at scale. Together with NASA, we’ve developed the Prithvi family of foundation models. PyTorch’s clean API made it easier to integrate the latest innovations from AI research.

We wanted to create a framework that anyone can use to go from raw data to inference-ready models in just a few steps.

 

How a weather and climate foundation model created and fine-tuned on PyTorch is used for weather forecasts

How IBM Research Used PyTorch

We’ve built TerraTorch on top of PyTorch, leveraging its dynamic ecosystem to integrate:

  • PyTorch Lightning for clean, scalable training loops
  • TorchGeo for geospatial data handling and transformations (PyTorch transforms)
  • Foundation models: TerraTorch has been used to fine-tune all of the downstream geospatial models for satellite imagery, weather and climate data that build on the leading generative multimodal foundation model ‘TerraMind’, co-developed by IBM and ESA, and on the ‘Prithvi’ family, co-developed by IBM and NASA. This includes the family of fine-tuned models that IBM has released as part of Granite. Other foundation models and ecosystem components such as Clay, SatMAE, Satlas, DeCur and DOFA are also included in TerraTorch.
  • Powerful and state-of-the-art vision transformers to experiment with modern neural network architectures
  • TerraTorch-Iterate, built on top of PyTorch, Optuna, MLflow and Ray Tune, for Hyperparameter Optimization (HPO), Neural Architecture Search (NAS) and Foundation Model Benchmarking (GeoBench), where TerraTorch became the reference implementation

The fine-tuning and inference process is completely described in a single YAML config file. There, the architectural building blocks of the model (backbone, neck, decoder, head) are defined. The Model Factory assembles the model using the built-in and custom registries. In addition, the Optimizer and Data Modules are created as defined in the config. Finally, everything is passed to the Lightning Trainer, which executes the task.
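The registry-and-factory pattern described above can be sketched in a few lines of plain Python. This is not TerraTorch's actual API: the names `REGISTRY`, `register`, `build_model` and the toy components are hypothetical, and the `config` dictionary stands in for what a parsed YAML file might look like.

```python
from typing import Callable, Dict

Component = Callable[[list], list]

# Hypothetical registry: maps a component name to its constructor.
REGISTRY: Dict[str, Callable[..., Component]] = {}


def register(name: str):
    """Decorator that adds a component constructor to the registry."""
    def wrap(ctor):
        REGISTRY[name] = ctor
        return ctor
    return wrap


@register("scale_backbone")
def scale_backbone(factor: float = 2.0) -> Component:
    return lambda xs: [x * factor for x in xs]


@register("identity_neck")
def identity_neck() -> Component:
    return lambda xs: list(xs)


@register("shift_decoder")
def shift_decoder(offset: float = 0.0) -> Component:
    return lambda xs: [x + offset for x in xs]


@register("sum_head")
def sum_head() -> Component:
    return lambda xs: [sum(xs)]


def build_model(config: dict) -> Component:
    """Look up each building block by name and chain them in order."""
    stages = []
    for part in ("backbone", "neck", "decoder", "head"):
        spec = config[part]
        ctor = REGISTRY[spec["name"]]
        stages.append(ctor(**spec.get("args", {})))

    def model(xs):
        for stage in stages:
            xs = stage(xs)
        return xs

    return model


# Stand-in for a parsed YAML config (structure is an assumption).
config = {
    "backbone": {"name": "scale_backbone", "args": {"factor": 2.0}},
    "neck": {"name": "identity_neck"},
    "decoder": {"name": "shift_decoder", "args": {"offset": 1.0}},
    "head": {"name": "sum_head"},
}
model = build_model(config)
print(model([1.0, 2.0, 3.0]))  # [15.0]
```

The point of the pattern is that swapping a backbone or decoder becomes a one-line change in the config, and custom components only need to be registered under a name.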

With PyTorch’s flexibility, we were able to prototype quickly, iterate on model architectures, and deploy pipelines for a range of geospatial applications, from flood and biomass detection to increasing the resolution of climate data. Some of our work became part of the IBM Granite Geospatial Model Family.

 

Architecture of the Prithvi-EO-2.0-600M foundation model which IBM Research developed together with NASA

Solving AI Challenges with PyTorch

PyTorch helped us to tackle three major challenges:

  • Ease of experimentation: Dynamic computation graphs, automatic differentiation, full abstraction of CUDA and rich visualization tools made it simple to test different models and training strategies.
  • Scalability: With DDP, FSDP, PyTorch Lightning and TorchGeo, we could train models on large-scale datasets without worrying about infrastructure.
  • Community support: PyTorch is the de facto standard in AI research, and its active community and excellent documentation made it easy to overcome hurdles and stay up to date with the latest advancements in AI research.

A Word from IBM Research

“PyTorch gave me the power to turn complex linear algebra and optimization problems into accessible, shareable solutions for the community. It feels empowering that we’re building and fine-tuning models for anyone curious about understanding our planet through AI.”

— Romeo Kienzler, AI Research Engineer at IBM Research Zurich, Rueschlikon

The Benefits of Using PyTorch

Using PyTorch allowed us to:

  • Build a reproducible, open-source framework for fine-tuning geospatial foundation models
  • Share our work with the community through easy-to-follow notebooks, TerraTorch configuration files, tutorials and model checkpoints on Hugging Face
  • Rapidly iterate over foundation model architectures and deploy fine-tuned models for inference, from research to real-world client products

Learn More

For more information about this project and to explore the code, visit:

Announcing the PyTorch Docathon 2025

We’re thrilled to announce the 2025 PyTorch Docathon! This is a hackathon-style event aimed at enhancing PyTorch documentation with the support of the community. Documentation is a vital component of any technology, and by refining it, we can simplify the onboarding process for new users, help them effectively utilize PyTorch’s features, and ultimately speed up the transition from research to production in machine learning.

WHY PARTICIPATE

Low Barrier to Entry

Unlike many open-source projects that require deep knowledge of the codebase and previous contributions to join hackathon events, the Docathon is tailored for newcomers. While we expect participants to be familiar with Python, and have basic knowledge of PyTorch and machine learning, there are tasks related to website issues that don’t even require that level of expertise.

Tangible Results

A major advantage of the Docathon is witnessing the immediate impact of your contributions. Enhancing documentation significantly boosts a project’s usability and accessibility, and you’ll be able to observe these improvements directly. Seeing tangible outcomes can also be a strong motivator to continue contributing.

Collaborative Environment

The Docathon fosters a collaborative atmosphere, offering you the chance to work alongside other contributors and PyTorch maintainers to improve the documentation. This is a fantastic opportunity to learn from peers, exchange ideas, and build connections.

Learning Opportunities

Even if you’re not a PyTorch expert, the Docathon offers a valuable learning experience. You’ll have the chance to delve into PyTorch modules, test tutorials on your machine, and explore them in the CI environment.

WHO SHOULD PARTICIPATE

Whether you’re a seasoned documentation expert or just starting out, we invite everyone to join the PyTorch Docathon, develop their skills and knowledge, and help improve the documentation for everyone! We will have issues labelled by skill level, and the PyTorch Discord will be available for collaboration and help.

EVENT DETAILS

  • June 3: Kick-off 10 AM PT
  • June 4 – June 15: Submissions and Feedback
  • June 16 – June 17: Final Reviews
  • June 18: Winner Announcements

Make sure to RSVP to the event so you receive all the notifications and instructions on how to participate.

Further details about the Docathon will be shared during the Kick-off call on June 3.

Don’t forget to register for this year’s event: RSVP now

May the Cloud Be With You: GeForce NOW Unveils 21 New Games This Month

May brings more than just rainbows and sunshine: it’s also time for fresh adventures and epic battles. This GFN Thursday spotlights 21 can’t-miss games joining the cloud this month, with something for every kind of gamer.

Gear up with the nine games available this week, on top of the launch of Rust’s Jungle Biome update.

Welcome to the Jungle

In Rust, a multiplayer survival game by Facepunch Studios, everyone starts off with only a rock and torch and must gather resources, build bases and fend off environmental threats and other players in a harsh, open-world setting. The game features intense player vs. player combat, dynamic alliances and frequent updates.

Jungle Rust update on GeForce NOW
Crocodiles, snakes and tigers — oh my!

The latest update introduces a new jungle biome: a lush but dangerous environment filled with crocodiles, snakes, tigers and other wildlife. Added features include new early-game weapons, like a blowpipe with venom darts, as well as unique mechanics, like regrowing jungle vines.

GeForce NOW members can experience the intense survival gameplay of Rust from the cloud. Whether on an underpowered PC, Mac, smartphone or smart TV, dive into the game’s open-world chaos with smooth performance and stunning visuals.

May-day Games

Haunted House Renovator on GeForce NOW
Spoooooky.

Check out Haunted House Renovator, new in the cloud for members to stream. Step into the shoes of a paranormal renovator to explore and restore haunted locations, all while dealing with mischievous spirits that can help or hinder. Whether players are exorcising spirits or taking a more forceful approach, the game offers a unique blend of ghost hunting and home-makeover fun — perfect for fans of both genres.

Look for the following games available to stream in the cloud this week:

  • Deadzone: Rogue (New release on Steam, April 29)
  • Haunted House Renovator (New release on Steam, April 30)
  • Far Cry 4 (New release on Xbox, available on PC Game Pass, April 30)
  • Anno 1800 (New release on Xbox, available on PC Game Pass, May 1)
  • Call of Duty: Modern Warfare 2 Remastered (New release on Xbox, available on PC Game Pass, May 1. Find it on GeForce NOW in the Call of Duty experience)
  • Blood Strike (Steam)
  • DREDGE (Xbox, available on the Microsoft Store)
  • LONESTAR (Steam)
  • Soulstone Survivors (Steam)

Learn how to stream supported Ubisoft games from PC Game Pass on GeForce NOW, including this week’s additions of Far Cry 4 and Anno 1800.

Here’s what to expect for the rest of May:

  • Survival Machine (New release on Steam, May 7)
  • Revenge of the Savage Planet (New release on Steam and Xbox, available on PC Game Pass, May 8)
  • Spirit of the North 2 (New release on Steam, May 8)
  • The Precinct (New release on Steam, May 13)
  • DOOM: The Dark Ages (New release on Steam, Battle.net and Xbox, available on PC Game Pass, May 15)
  • Blacksmith Master (New release on Steam, May 15)
  • 9 Kings (New release on Steam, May 19)
  • RoadCraft (New release on Steam, May 20)
  • Monster Train 2 (New release on Steam, May 21)
  • Survive the Fall (New release on Steam, May 21)
  • Blades of Fire (New release on Epic Games Store, May 22)
  • Tokyo Xtreme Racer (Steam)
  • The Last Spell (Steam)
  • War Robots: Frontiers (Steam)
  • Torque Drift 2 (Epic Games Store)

April Showers Bring More Games

In addition to the 21 games announced last month, 13 more joined the GeForce NOW library:

  • Forever Skies (New release on Steam, April 14)
  • Hunt: Showdown 1896 (New release on Xbox, available on PC Game Pass, April 15)
  • Crime Scene Cleaner (New release on Xbox, available on PC Game Pass, April 17)
  • The Elder Scrolls IV: Oblivion Remastered (New release on Steam and Xbox, available on PC Game Pass, April 22)
  • Ace Attorney Investigations Collection (Steam and Xbox, available on the Microsoft Store)
  • Ace Attorney Investigations Collection Demo (Steam and Xbox, available on the Microsoft Store)
  • Dead Rising Deluxe Remaster Demo (Steam)
  • Diablo III (Xbox, available on PC Game Pass)
  • Gedonia 2 (Steam) 
  • MARVEL vs. CAPCOM Fighting Collection: Arcade Classics (Steam)
  • Path of Exile 2 (Epic Games Store)
  • Sands of Aura (Epic Games Store)
  • Sultan’s Game (Steam)

What are you planning to play this weekend? Let us know on X or in the comments below.

Wandercraft Begins Clinical Trials for Physical AI-Powered Personal Exoskeleton

For Nicolas Simon, advancing the field of robotics is a personal mission that could change his siblings’ lives.

Two-thirds of Simon’s family members use wheelchairs due to mobility challenges related to Charcot-Marie-Tooth disease, an inherited genetic condition. As an engineering student and robotics club chair at France’s École Polytechnique, he saw the opportunity to build a new device that could help his brother and other relatives walk again.

His robotics startup, Wandercraft, builds mobility solutions for individuals with spinal cord injuries, stroke and other neuromuscular disorders — with the potential to support the estimated 80 million individuals worldwide who require wheelchairs for mobility.

The company’s Personal Exoskeleton, currently in clinical trials, enables users to stand and walk with the support of AI-powered mechanisms for stability and movement. Users can control the robotic system with a joystick.

Simon founded the company in 2012 with Matthieu Masselin and Jean-Louis Constanza, whose son also has Charcot-Marie-Tooth disease. The team is accelerating its workflows with NVIDIA technologies — enabling Wandercraft to harness the latest simulation tools and AI infrastructure to bring new capabilities to its exoskeletons.

“It’s essential for the exoskeleton to be fast enough that it can be used in the real world,” said Simon. “By integrating NVIDIA AI into the device, we can someday enable users to walk at an average pace, cross the road and go up and down stairs.”

Advancing Development With Physical AI 

Wandercraft’s first exoskeleton, called Atalante X, is FDA-cleared and already in use as a neurological rehabilitation tool in over 100 clinical and research settings worldwide, with patients taking over one million steps per month. Approved for use in the European Union in 2019 and in the U.S. in 2022, the device has helped hundreds of patients regain mobility through physiotherapy.

The startup’s latest device, the Personal Exoskeleton, is aimed at everyday indoor and outdoor use. With clinical trials underway in New York and New Jersey, the Personal Exoskeleton integrates AI to continuously adapt to a user’s movements in real time, supporting smooth and stable walking across different surfaces, including concrete, carpet and tile.

“It’s very important for us to use physical AI to deliver a better device and a better experience for our users, so they can move through their daily lives smoothly and efficiently,” Simon said.

Woman walking through conference exhibit hall using AI exoskeleton
Laubach demonstrating the Personal Exoskeleton at NVIDIA GTC. Image credit: Wandercraft

In addition to helping users gain new levels of mobility at home and in their communities, the Personal Exoskeleton could also help reduce the health impacts of being seated all day — which include increased risk of cardiovascular, skin and digestive conditions.

Wandercraft is currently experimenting with NVIDIA Isaac Sim — a reference application built on NVIDIA Omniverse for simulating and testing AI-driven robotics solutions in physically based virtual environments — to accelerate its reinforcement learning pipeline. The company is also investigating the use of the NVIDIA Isaac for Healthcare developer framework for AI healthcare robotics and NVIDIA Jetson Thor, an on-robot edge computer built on the NVIDIA Blackwell architecture.

With these systems for physical AI training, simulation and runtime in place, Wandercraft will have a three-computer solution for its robotics development.

“The technology is there — you just have to build the device,” Simon said. “We take all the technology from the field of humanoid robotics, and we apply it to our exoskeleton. So now, the possibilities are endless.”

At the NVIDIA GTC global AI conference in March, Wandercraft demonstrated the prototype Personal Exoskeleton with the help of Caroline Laubach, a spinal stroke survivor and full-time wheelchair user. The exoskeleton system was also featured in last year’s Olympic and Paralympic Torch Relay.

Wandercraft aims to apply for FDA clearance for its Personal Exoskeleton immediately following completion of its clinical trial, with the goal of making it accessible to millions of wheelchair users in the U.S. with expected Medicare coverage. The company is currently recruiting additional participants for the clinical trial, which it aims to complete this year.

“With this technology, we can enable people to move around and access the environment of the city,” Simon said. “My hope is to see my device in the streets — of New York at first, but in every city in the U.S.”

How IBM Research Uses PyTorch and TerraTorch to Make Geospatial Computer Vision Accessible for Everyone

How IBM Research Uses PyTorch and TerraTorch to Make Geospatial Computer Vision Accessible for Everyone

Earth Observation-based analytics are becoming essential for understanding our planet — from monitoring deforestation to tracking urban development and analyzing the impacts of climate change. However, the coding and deep learning skills for applying AI models to satellite imagery and earth observation data has traditionally been a major barrier for many practitioners.

By IBM Research’s launch of TerraTorch 1.0, a PyTorch domain library for fine-tuning of Geospatial Computer Vision Foundation Models, we make geospatial AI not only more accessible but also more practical for the wider PyTorch community. Our goal: simplify the process so that any data scientist, researcher, or enthusiast can build powerful geospatial models with ease and low GPU and data processing requirements.

globes

The power of foundation models, even with 75-95% of the input data removed, the models do a fantastic job in reconstruction of the input data – therefore learning the underlying physics of our planet in a deep, latent space

The Business Challenge

Our goal was to remove the technical barriers that prevent people from working with satellite imagery, weather and climate data at scale. Together with NASA, we’ve developed the Prithvi family of foundation models. Integrating the latest innovations of AI research using the clean API PyTorch provides has facilitated the job.

We wanted to create a framework that anyone can use to go from raw data to inference ready models in just a few steps.

globes

How a weather and climate foundation model created and fine-tuned on PyTorch is used for weather forecasts

How IBM Research Used PyTorch

We’ve built TerraTorch on top of PyTorch, leveraging its dynamic ecosystem to integrate:

  • PyTorch Lightning for clean, scalable training loops
  • TorchGeo for geospatial data handling and transformations (PyTorch transforms)
  • For foundation models like the leading generative multimodal foundation model ‘Terramind’, co-developed by IBM and ESA, and the ‘Prithvi’ family, co-developed by IBM and NASA, TerraTorch has been used to fine-tune all of the downstream geospatial models for satellite imagery, weather and climate data. It includes the family of fine-tuned models that IBM has released as part of Granite. In addition, other interesting foundation models and ecosystem components like Clay, SatMAE, Satlas, DeCur and DOFA are included in TerraTorch.
  • Powerful and state-of-the-art vision transformers to experiment with modern neural network architectures
  • TerraTorch-Iterate build on top of PyTorch, Optuna, MLFlow and Ray Tune for Hyperparameter Optimization (HPO), Neural Architecture Search (NAS) and Foundation Model Benchmarking (GeoBench), where TerraTorch became the reference implementation

flow diagram

The fine-tuning and inference process is completely described in a single YAML config file. There, the architectural building blocks of the model (backbone, neck, decoder, head) are defined. The Model Factory assembles the model using the build-in and custom registries. In addition, the Optimizer and Data Modules are created as defined in the config. Finally, everything is passed to the Lightning Trainer, who executes the task.

With PyTorch’s flexibility, we were able to prototype quickly, iterate on model architectures, and deploy pipelines for a range of geospatial applications — from flood and biomass detection to increasing resolution of climate data, where some of our our work became part of the IBM Granite Geospatial Model Family.

flow diagram

Architecture of the Prithvi-EO-2.0-600M foundation model which IBM Research developed together with NASA

Solving AI Challenges with PyTorch

PyTorch helped us to tackle three major challenges:

  • Ease of experimentation: Dynamic computation graphs, automatic differentiation, full abstraction of CUDA and rich visualization tools made it simple to test different models and training strategies.
  • Scalability: With DDP, FSDP, PyTorch Lightning and TorchGeo, we could train models on large-scale datasets without worrying about infrastructure.
  • Community support: PyTorch – the de-facto standard in AI research – with its active community and excellent documentation made it easy to overcome hurdles and stay up to date with the latest advancements in AI research.

A Word from IBM Research

“PyTorch gave me the power to turn complex linear algebra and optimization problems into accessible, shareable solutions for the community. It feels empowering that we’re building and fine-tuning models for anyone curious about understanding our planet through AI.”

— Romeo Kienzler, AI Research Engineer at IBM Research Zurich, Rueschlikon


The Benefits of Using PyTorch

Using PyTorch allowed us to:

  • Build a reproducible, open-source framework for fine-tuning geospatial foundation models
  • Share our work with the community through easy-to-follow notebooks, TerraTorch configuration files, tutorials and model checkpoints on HuggingFace
  • Rapidly iterate over foundation model architectures and deploy fine-tuned models for inference, from research to real-world client products

Learn More

For more information about this project and to explore the code, visit:


FlexAttention Part II: FlexAttention for Inference


Overview

In the PyTorch 2.5.0 release, we introduced FlexAttention (torch.nn.attention.flex_attention) for ML researchers who’d like to customize their attention kernels without writing kernel code. This blog introduces our decoding backend optimized for inference, supporting GQA and PagedAttention, along with feature updates including nested jagged tensor support, performance tuning guides and trainable biases support.

If you’re looking for an easy way to play around with FlexAttention in your post-training / inference pipeline, PyTorch native post-training library torchtune and inference codebase gpt-fast already have FlexAttention integrated. Try it out!

We are excited to share that our paper on FlexAttention has been accepted for presentation at the MLSys2025 Conference held from May 12-15th in Santa Clara, California.

Title: FlexAttention: A Programming Model for Generating Optimized Attention Kernels. Poster

FlexAttention for Inference

TL;DR: torch.compile lowers flex_attention to a fused FlashDecoding kernel when it runs on a very short query.

One fused attention kernel does not suit all – especially in long-context LLM inference.

The decoding phase of LLM inference is an iterative process: tokens are generated one at a time, requiring N forward passes to generate an N-token sentence. Fortunately, each iteration doesn’t need to recompute self-attention over the full sentence — previously calculated tokens are cached, therefore we only need to attend the newly generated token to the cached context.

This results in a unique attention pattern where a short query sequence (1 token) attends to a long key-value cache (context length up to 128k). Traditional optimizations for square attention kernels (q_len ≈ kv_len) don’t directly apply here. This pattern poses new challenges for GPU memory utilization and occupancy. We build a dedicated FlexDecoding backend optimized for long-context LLM inference incorporating decoding-specific techniques from FlashDecoding.

FlexDecoding is implemented as an alternative backend for the torch.nn.attention.flex_attention operator. flex_attention automatically switches to the FlexDecoding backend for its JIT compilation when given a short query and a long KV cache. If the input shape changes significantly, for example transitioning from the prefill phase to decoding, JIT recompilation generates a separate kernel for each scenario.

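import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention)

# Preallocated KV cache; B, H and D are the batch size, head count and head dim
k_cache = torch.randn(B, H, 16384, D)
v_cache = torch.randn(B, H, 16384, D)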

...

# Prefill Phase: query shape = [B, H, 8000, D]
flex_attention(q_prefill, k_cache, v_cache, ...) # Uses FlexAttention backend optimized for prefill & training

# Decoding Phase: q_last_token shape = [B, H, 1, D]
flex_attention(q_last_token  , k_cache, v_cache, ...) # Recompiles with the FlexDecoding backend 

# decode 2 tokens at the same time: q_last_2_tokens shape = [B, H, 2, D]
flex_attention(q_last_2_tokens, k_cache, v_cache, ...) # No recompilation needed! Runs the decoding kernel again.
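One decoding-specific technique associated with FlashDecoding is to split the long KV cache into chunks, compute a partial softmax per chunk, and merge the partial results with a log-sum-exp style rescaling. The plain-Python sketch below illustrates that reduction for a single query vector. It is not the FlexDecoding kernel, and the function names and toy data are assumptions made for illustration.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend_full(q, keys, values):
    """Reference: one softmax over the whole KV cache."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def attend_split_kv(q, keys, values, chunk_size):
    """Split the cache into chunks, then merge the partial softmax results."""
    dim = len(values[0])
    partials = []  # one (chunk max, chunk sum of exp, unnormalized output) per chunk
    for start in range(0, len(keys), chunk_size):
        ks = keys[start:start + chunk_size]
        vs = values[start:start + chunk_size]
        scores = [dot(q, k) for k in ks]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        acc = [sum(w * v[d] for w, v in zip(weights, vs)) for d in range(dim)]
        partials.append((m, sum(weights), acc))
    # Rescale every chunk to the global maximum before combining.
    g = max(m for m, _, _ in partials)
    total = sum(z * math.exp(m - g) for m, z, _ in partials)
    return [sum(acc[d] * math.exp(m - g) for m, _, acc in partials) / total
            for d in range(dim)]

# Toy data: one query attending to a cache of 37 entries.
q = [0.5, -1.0]
keys = [[0.1 * i, 0.05 * (i % 7)] for i in range(37)]
values = [[float(i), float(i % 5)] for i in range(37)]
full = attend_full(q, keys, values)
split = attend_split_kv(q, keys, values, chunk_size=8)
```

Because each chunk is independent until the final merge, the chunks can be processed in parallel, which is what keeps the GPU busy when the query is only one token long.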

Working with KV Cache

One of the key optimizations for efficient inference is maintaining a preallocated KV cache that updates in place as new tokens are generated. Instead of enforcing a specific KV cache policy with a dedicated API, FlexDecoding allows users to define and manage the KV cache themselves.

Similar to FlexAttention, FlexDecoding takes user-defined mask_mod and score_mod functions. These functions modify attention scores before the softmax operation.

score_mod(score, b, h, q_idx, kv_idx) -> tensor # return updated score

Here, score is a scalar PyTorch tensor that represents the dot product of a query token and a key token. The rest of the arguments specify which score is being computed:

  • b batch index
  • h attention head index
  • q_idx token position in query tensor
  • kv_idx token position in key/value tensor

In the decoding phase, previously calculated tokens are cached, and only the latest generated token (i-th) is used as the query. A naive causal mask on this one token query looks like this:

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

This is problematic: the new token “saw” should attend to all previously generated tokens, i.e. “The cat sat on the mat and saw”, not just the first entry in the KV cache. To correct this, the score_mod needs to offset q_idx by i for accurate decoding.

Creating a new score_mod for each token to accommodate the offset is slow, since it means FlexAttention needs to be recompiled every iteration for a different score_mod. Instead, we define this offset as a tensor and increment its value at each iteration:

offset = torch.tensor(i, device="cuda")
def causal_w_offset(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Attend the i-th token
flex_attention(..., score_mod=causal_w_offset  ) # Compiles the kernel here 
...
# Attend the i+1-th token
offset = offset + 1 # Increment offset
flex_attention(..., score_mod=causal_w_offset ) # Doesn't need to recompile! 

Notably, offset becomes a captured tensor here, so flex_attention does not need to recompile when the value of offset changes.
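A plain-Python check of the index arithmetic shows why the offset matters. The helper name visible_keys is hypothetical, and integers stand in for the index tensors that score_mod receives.

```python
def visible_keys(q_idx, kv_len, offset=0):
    """KV positions a query row may attend to under q_idx + offset >= kv_idx."""
    return [kv_idx for kv_idx in range(kv_len) if q_idx + offset >= kv_idx]

cache = "The cat sat on the mat and saw".split()  # 8 cached tokens
i = len(cache) - 1  # position of the newest token, "saw"

# During decoding the query tensor holds one token, so q_idx is always 0.
naive = visible_keys(0, len(cache))            # no offset
fixed = visible_keys(0, len(cache), offset=i)  # q_idx shifted by i
```

Without the offset, the single query row only sees position 0 of the cache. With the offset, it sees every position up to and including its own.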

Manually rewriting your score_mod and mask_mod for offset handling isn’t necessary. We can automate this process with a generic rewriter:

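import torch
from torch.nn.attention.flex_attention import _mask_mod_signature, _score_mod_signature

offset = torch.tensor(i, device="cuda")

# Wrap a score_mod so that the query index is shifted by the captured offset
def get_score_mod_w_offset(score_mod: _score_mod_signature, _offset: torch.Tensor):
    def _score_mod(score, b, h, q, kv):
        return score_mod(score, b, h, q + _offset, kv)
    return _score_mod

# Same idea for a mask_mod, which has no score argument
def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: torch.Tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod

causal_w_offset = get_score_mod_w_offset(causal, offset)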
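The same wrapping idea works for any mask, not only the causal one. The sketch below applies it to a sliding-window mask in plain Python. A one-element list plays the role of the captured offset tensor, and with_offset, sliding_window and offset_box are hypothetical names used only for this illustration.

```python
def with_offset(mask_mod, offset_box):
    """Wrap mask_mod so that q is shifted by the current value in offset_box."""
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset_box[0], kv)
    return _mask_mod

def sliding_window(b, h, q_idx, kv_idx, window=3):
    # Causal attention restricted to the last `window` positions.
    return 0 <= q_idx - kv_idx < window

offset_box = [5]  # stands in for the captured offset tensor
mask = with_offset(sliding_window, offset_box)

step_5 = [kv for kv in range(10) if mask(0, 0, 0, kv)]
offset_box[0] += 1  # updated in place; `mask` is not rebuilt
step_6 = [kv for kv in range(10) if mask(0, 0, 0, kv)]
```

The wrapper is created once and keeps reading the current offset, which mirrors how a captured tensor lets the compiled kernel be reused across decoding steps.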

BlockMask for Inference

We can also use BlockMask with inference to leverage mask sparsity. The idea is to precompute the BlockMask once during model setup and use slices of it during decoding.

Precomputing BlockMask

During setup, we create a square BlockMask for MAX_SEQ_LEN x MAX_SEQ_LEN:

from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=MAX_SEQ_LEN, KV_LEN=MAX_SEQ_LEN)

Using BlockMask During Decoding

For the i-th token, we use a slice of the mask:

block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]

# don't forget to use the mask_mod with offset! 
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask, offset)
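The slicing step is integer arithmetic on block indices. As a rough illustration, assuming a block size of 128 for this example, the following plain-Python sketch shows which block row a token falls into and how many KV blocks a causal mask needs for it. The helper names are hypothetical.

```python
BLOCK_SIZE = 128  # assumed block size for this illustration

def block_row(i, block_size=BLOCK_SIZE):
    """Index of the query block that contains token position i."""
    return i // block_size

def causal_kv_blocks(i, block_size=BLOCK_SIZE):
    """KV block indices that contain at least one position <= i."""
    return list(range(block_row(i, block_size) + 1))

rows = {i: block_row(i) for i in (0, 127, 128, 1000)}
needed = causal_kv_blocks(1000)
```

Tokens 0 through 127 share block row 0, so the same mask slice serves 128 consecutive decoding steps. Only the mask_mod offset changes from step to step.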

Performance

The FlexDecoding kernel performs on par with FlashDecoding (FAKV) and significantly outperforms PyTorch scaled_dot_product_attention (code).

FlexDecoding boosts LLaMa3.1-8B serving performance by 1.22x-2.04x, and LLaMa3.1-70B performance by 0.99x – 1.66x compared to SDPA in gpt-fast. (code)

Paged Attention

vLLM is one of the popular LLM serving engines, powered by the efficient memory management of PagedAttention. The existing PagedAttention implementation requires dedicated CUDA kernels and offers limited flexibility for supporting emerging attention variants. In this section, we present a PT2-native PagedAttention implementation that is enabled by flex attention and torch.compile.

PagedAttention scatters the KV cache to reduce memory fragmentation and support higher batch sizes. Without PagedAttention, the KV cache for each request is stored in contiguous memory, requiring two tensors of shape B x H x KV_LEN x D. We call this the logical KV cache. Here, KV_LEN is the maximum sequence length over all requests in a batch. In Figure 1(a), KV_LEN is 9, so all requests must be padded to 9 tokens, which wastes a large amount of memory. With PagedAttention, we can chunk each request into multiple pages of the same size page_size and scatter these pages into a physical KV cache of shape 1 x H x max_seq_len x D, where max_seq_len = n_pages x page_size. This avoids padding requests to the same length and saves memory. Specifically, we provide an assign API to update the KV cache via index computations:

def assign(
    batch_idx: torch.Tensor,
    input_pos: torch.Tensor,
    k_val: torch.Tensor,
    v_val: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> None

Behind this assign API is a page table, a tensor mapping logical KV cache to physical KV cache:

[batch_idx, logical_page_idx] -> physical_page_idx

assign takes k_val and v_val and scatters to physical KV cache guided by the mapping from the page table.
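To make the page table mapping concrete, here is a toy version in plain Python with a single head and scalar cache entries. The names assign_toy and physical_pos and the page layout are assumptions made for illustration. The real assign API operates on tensors with the signature shown above.

```python
PAGE_SIZE = 4
N_PAGES = 6

# page_table[batch_idx][logical_page_idx] -> physical_page_idx
page_table = {
    0: [2, 5],  # request 0 owns physical pages 2 and 5
    1: [0],     # request 1 owns physical page 0
}

# Flat physical cache with N_PAGES * PAGE_SIZE slots.
physical_k = [None] * (N_PAGES * PAGE_SIZE)

def physical_pos(batch_idx, input_pos):
    """Translate a logical token position into a physical cache slot."""
    logical_page, offset = divmod(input_pos, PAGE_SIZE)
    return page_table[batch_idx][logical_page] * PAGE_SIZE + offset

def assign_toy(batch_idx, input_pos, k_val, k_cache):
    """Scatter the values for one request into the physical cache."""
    for pos, val in zip(input_pos, k_val):
        k_cache[physical_pos(batch_idx, pos)] = val

assign_toy(0, [0, 1, 4, 5], ["a0", "a1", "a4", "a5"], physical_k)
assign_toy(1, [0, 1, 2], ["b0", "b1", "b2"], physical_k)
```

Request 0 spans two pages that are not adjacent in physical memory, and neither request is padded to the length of the other.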

Paged Attention with Page Table

A natural question is, how to integrate PagedAttention with flex attention to support diverse attention variants? A naive idea is to materialize the logical KV cache before computing with flex attention. But this leads to redundant memory copy and bad performance. Another idea is to build a dedicated CUDA or Triton kernel for paged attention, similar to existing PagedAttention implementation. However, this adds much manual effort and code complexity.

Instead, we design a fused indirect memory access by converting a logical block mask according to the page table. In FlexAttention, we exploit BlockMask to identify logical blocks and skip redundant computation. While Paged Attention adds an extra layer of indirect memory access, we can further convert the logical block mask to the physical block mask corresponding to the page table, as illustrated in Figure 2. Our PagedAttention implementation provides a convert_logical_block_mask via torch.gather calls:

def convert_logical_block_mask(
    block_mask: BlockMask,
    batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask

Paged Attention via Block Mask Conversion

One remaining question is how to rewrite user-specified mask_mod and score_mod for PagedAttention. Users write these modifications with logical indices, without knowledge of the page table maintained at runtime, so they must be converted automatically at runtime to work with physical KV indices. The following code shows this conversion: new_mask_mod takes the physical_kv_idx, converts it back to the logical_kv_idx, and applies the user-specified mask_mod to the logical_kv_idx to obtain the correct mask. For efficiency, we maintain physical_to_logical as a mapping from physical_kv_block to logical_kv_block. For correctness, we mask out-of-boundary blocks as False with a torch.where call. After batching logical KV caches from multiple requests into the same physical KV cache, there are many more physical blocks than logical blocks for each request, so a physical block may not have a corresponding logical block for a specific request during block mask conversion. Masking these as False with torch.where ensures that data from different requests do not interfere with each other. The score_mod can be converted automatically in the same way.

def get_mask_mod(mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
    if mask_mod is None:
        mask_mod = noop_mask

    def new_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ):
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
        )

    return new_mask_mod
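A worked example with small numbers may help. The sketch below repeats the conversion in plain Python, with integers instead of tensors and an ordinary if instead of torch.where. The page layout and the helper name convert are assumptions made for illustration.

```python
PAGE_SIZE = 4

# physical_to_logical[b][physical_block] -> logical block, or -1 if the
# physical block does not belong to request b.
physical_to_logical = {
    0: [-1, -1, 0, -1, -1, 1],
    1: [0, -1, -1, -1, -1, -1],
}

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def convert(mask_mod):
    """Wrap a logical mask_mod so that it accepts physical KV indices."""
    def new_mask_mod(b, h, q_idx, physical_kv_idx):
        block, offset = divmod(physical_kv_idx, PAGE_SIZE)
        logical_block = physical_to_logical[b][block]
        if logical_block < 0:
            return False  # slot belongs to another request or is unused
        return mask_mod(b, h, q_idx, logical_block * PAGE_SIZE + offset)
    return new_mask_mod

paged_causal = convert(causal)

# Logical query position 5: which of the 24 physical slots are visible?
visible_req0 = [p for p in range(24) if paged_causal(0, 0, 5, p)]
visible_req1 = [p for p in range(24) if paged_causal(1, 0, 5, p)]
```

For request 0, slots 8 to 11 map to logical positions 0 to 3 and slots 20 to 21 map to logical positions 4 to 5, so exactly those six slots pass the causal test. Slots owned by request 1 are never visible to request 0.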

Figure 3 shows the latency of Paged Attention (code). Overall, Flex Attention with Paged Attention adds less than 5% overhead compared with Flex Attention alone. We also observe on-par performance with Flash Attention v2. A minimal serving example further shows that PagedAttention can support a 76x higher batch size when evaluated on the OpenOrca dataset, which includes 1M GPT-4 completions and 3.2M GPT-3.5 completions.

Paged Attention: Latency under diverse sequence length

Ragged input sequences with Nested Jagged Tensors (NJTs)

FlexAttention now supports ragged-sized input sequences through the use of Nested Jagged Tensors (NJTs). NJTs represent ragged-sized sequences by packing sequences into a single “stacked sequence” and maintaining a set of offsets delimiting sequence boundaries for each batch item.

A block mask can be created for input NJTs through the new create_nested_block_mask() API. The returned block mask is compatible with the ragged structure of the given NJT, treating it as a single “stacked sequence” with inter-sequence attention automatically masked out. The mask_mod or score_mod function can be written as usual.

import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

BATCH = 8
NUM_HEADS = 8
D = 16
device = "cuda"

# Input NJTs of shape (BATCH, SEQ_LEN*, D) with ragged SEQ_LEN
sequence_lengths = [torch.randint(5, 30, ()).item() for _ in range(BATCH)]
query = torch.nested.nested_tensor([
    torch.randn(seq_len, NUM_HEADS * D, device=device)
    for seq_len in sequence_lengths
], layout=torch.jagged)
key = torch.randn_like(query)
value = torch.randn_like(query)

# View as shape (BATCH, NUM_HEADS, SEQ_LEN*, HEAD_DIM)
query = query.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
key = key.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
value = value.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)

# Simple causal mask
def my_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Construct a block mask using the ragged structure of the
# specified query NJT. Ragged-sized sequences are treated as a single
# "stacked sequence" with inter-sequence attention masked out.
block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query)

# For cross attention, create_nested_block_mask() also supports a
# rectangular block mask using the ragged structures of both query / key.
#block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, key)

output = flex_attention(query, key, value, block_mask=block_mask)
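The "stacked sequence" behavior can be illustrated without tensors. The sketch below packs three sequences, derives their offsets, and builds a dense mask in which a position may only attend within its own sequence. It mirrors the behavior described above, not the implementation of create_nested_block_mask, and the helper names are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate

lengths = [3, 2, 4]                        # ragged sequence lengths
offsets = [0] + list(accumulate(lengths))  # [0, 3, 5, 9]

def seq_id(stacked_idx):
    """Which sequence a position in the stacked sequence belongs to."""
    return bisect_right(offsets, stacked_idx) - 1

def nested_mask(mask_mod, q, kv):
    """Apply mask_mod with per-sequence indices; block cross-sequence pairs."""
    s = seq_id(q)
    if s != seq_id(kv):
        return False
    start = offsets[s]
    return mask_mod(s, 0, q - start, kv - start)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

total = offsets[-1]
dense = [[nested_mask(causal, q, kv) for kv in range(total)]
         for q in range(total)]
```

The user-written mask_mod still sees indices that start at 0 for every sequence, which is why it can be written as usual.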

Trainable Biases

FlexAttention now supports trainable parameters in score_mod functions. This feature enables users to reference tensors that require gradients within their score_mod implementations, with gradients automatically backpropagating through these parameters during training.

Memory-Efficient Gradient Accumulation

Instead of materializing the full attention scores matrix, FlexAttention uses atomic additions (tl.atomic_add) to accumulate gradients. This approach significantly reduces memory usage at the cost of introducing some non-determinism in gradient calculations.

Handling Broadcasted Operations

Broadcasting operations in the forward pass (e.g., score + bias[h]) require special consideration in the backward pass. When broadcasting a tensor across multiple attention scores within a head or other dimensions, we need to reduce these gradients back to the original tensor shape. Rather than materializing the full attention score matrix to perform this reduction, we use atomic operations. While this incurs some runtime overhead, it allows us to maintain memory efficiency by avoiding the materialization of large intermediate tensors.
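The reduction can be pictured with a small plain-Python sketch for score + bias[h]. Every score that used bias[h] adds its upstream gradient to the same slot, so no full gradient matrix for the bias has to be stored. The sequential loop stands in for the atomic additions performed on the GPU, and the shapes and gradient values are made up for illustration.

```python
B, H, Q, KV = 2, 3, 2, 4

# Upstream gradient with respect to each modified score (a made-up pattern).
grad_score = [[[[0.1 * (h + 1) for _ in range(KV)] for _ in range(Q)]
               for h in range(H)] for _ in range(B)]

grad_bias = [0.0] * H
for b in range(B):
    for h in range(H):
        for q in range(Q):
            for kv in range(KV):
                # d(score + bias[h]) / d(bias[h]) = 1, so each score
                # contributes its upstream gradient to the same slot.
                grad_bias[h] += grad_score[b][h][q][kv]
```

On a GPU the order of these additions is not fixed, and floating-point addition is not associative, which is where the non-determinism mentioned earlier comes from.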

Current Limitations

The implementation currently allows only a single read from each input tensor in the score_mod function. For example, bias[q_idx] + bias[kv_idx] would not be supported as it reads from the same tensor twice. We hope to remove this restriction in the future.

Simple Example:

bias = torch.randn(num_heads, requires_grad=True)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  

Performance Tuning for FlexAttention

TL;DR

For optimal performance, compile FlexAttention using max-autotune, especially when dealing with complex score_mods and mask_mods:

flex_attention = torch.compile(flex_attention, dynamic=True, mode="max-autotune")

What is max-autotune?

max-autotune is a torch.compile mode in which TorchInductor sweeps many kernel parameters (e.g., tile size, num_stages) and selects the best-performing configuration. This process allows kernels to test both successful and failing configurations without issues, and find the best viable configuration.

While compilation takes longer with max-autotune, the optimal configuration is cached for future kernel executions.
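The idea of sweeping configurations and keeping the best viable one can be sketched with a toy model. The shared memory budget, the memory formula and the synthetic latency below are assumptions made purely for illustration. They do not reflect TorchInductor's actual search space or cost model.

```python
from itertools import product

SHARED_MEM_LIMIT = 96 * 1024  # bytes; an assumed budget for the toy model

def toy_benchmark(block, num_stages, extra_bytes):
    """Return a synthetic latency, or raise if the config does not fit."""
    needed = block * 512 * num_stages + extra_bytes
    if needed > SHARED_MEM_LIMIT:
        raise MemoryError("out of shared memory")
    # Synthetic cost: larger tiles and more stages are faster in this toy.
    return 100.0 / block + 1.0 / num_stages

def autotune(extra_bytes):
    """Try every config, skip the ones that fail, keep the fastest."""
    best = None
    for block, stages in product((32, 64, 128), (2, 3, 4)):
        try:
            latency = toy_benchmark(block, stages, extra_bytes)
        except MemoryError:
            continue  # failing configurations are skipped
        if best is None or latency < best[0]:
            best = (latency, block, stages)
    return best[1:]

light = autotune(extra_bytes=0)          # cheap score_mod / mask_mod
heavy = autotune(extra_bytes=48 * 1024)  # one that needs extra shared memory
```

In this toy, the heavier case can no longer fit the tile size that wins in the light case, so the sweep settles on a smaller tile instead of failing.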

Here’s an example of FlexAttention compiled with max-autotune:

triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

Why Use max-autotune for FlexAttention?

The amount of shared memory utilized in FlexAttention depends on the score_mod and mask_mod methods. This variability means that the preconfigured default kernel parameters may lead to performance cliffs or even out-of-shared-memory errors on certain hardware for some masks/mods.

For instance, with document masks, default configurations can halve GPU occupancy, reducing performance to ~75% of its potential on some GPUs. To avoid such issues, we strongly recommend enabling max-autotune.

Updates and Enhancements

  • Now available as a prototype feature in PyTorch 2.5.0
  • Fixed critical correctness issues, including a bug affecting multiple calls to FlexAttention within the same call to torch.compile

Expanded Architecture Support

  • Arbitrary sequence length support – no longer requires multiples of 128
  • Added native grouped-query attention (GQA) support via enable_gqa=True
  • Enhanced dimension flexibility:
    • Different QK and V head dimensions
    • Non-power-of-two head dimensions
  • Trainable attention biases (prototype)
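For readers new to grouped-query attention, the head sharing can be shown with a few lines of plain Python. The sketch assumes the common convention that consecutive query heads share one key/value head, and kv_head_for is a hypothetical helper name.

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Index of the key/value head shared by a given query head."""
    assert num_q_heads % num_kv_heads == 0
    group = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group

mapping = [kv_head_for(h, num_q_heads=8, num_kv_heads=2) for h in range(8)]
```

With 8 query heads and 2 KV heads, each KV head serves a group of 4 query heads, which shrinks the KV cache by the same factor.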

Under the Hood

  • New fused CPU backend
  • Improved TF32 handling for float32 inputs
  • Resolved various dynamic shape issues
  • Output layout matching query strides

These updates make FlexAttention more robust and flexible while maintaining its core promise of combining PyTorch’s ease of use with FlashAttention’s performance benefits.

Read More

FlexAttention Part II: FlexAttention for Inference

FlexAttention Part II: FlexAttention for Inference

Overview

In PyTorch 2.5.0 release, we introduced FlexAttention torch.nn.attention.flex_attention for ML researchers who’d like to customize their attention kernels without writing kernel code. This blog introduces our decoding backend optimized for inference, supporting GQA and PagedAttention, along with feature updates including nested jagged tensor support, performance tuning guides and trainable biases support.

If you’re looking for an easy way to play around with FlexAttention in your post-training / inference pipeline, PyTorch native post-training library torchtune and inference codebase gpt-fast already have FlexAttention integrated. Try it out!

We are excited to share that our paper on FlexAttention has been accepted for presentation at the MLSys2025 Conference held from May 12-15th in Santa Clara, California.

Title: FlexAttention: A Programming Model for Generating Optimized Attention Kernels. Poster

FlexAttention for Inference

TL;DR: torch.compile lowers flex_attention to a fused FlashDecoding kernel when it runs on a very short query.

One fused attention kernel does not suit all – especially in long-context LLM inference.

The decoding phase of LLM inference is an iterative process: tokens are generated one at a time, requiring N forward passes to generate an N-token sentence. Fortunately, each iteration doesn’t need to recompute self-attention over the full sentence — previously calculated tokens are cached, therefore we only need to attend the newly generated token to the cached context.

This results in a unique attention pattern where a short query sequence (1 token) attends to a long key-value cache (context length up to 128k). Traditional optimizations for square attention kernels (q_len ≈ kv_len) don’t directly apply here. This pattern poses new challenges for GPU memory utilization and occupancy. We build a dedicated FlexDecoding backend optimized for long-context LLM inference incorporating decoding-specific techniques from FlashDecoding.

FlexDecoding is implemented as an alternative backend for the torch.nn.attention.flex_attention operator. flex_attention automatically switches to the FlexDecoding backend for its JIT compilation when given a short query and a long KV cache. If the input shape changes significantly, for example transitioning from the prefill phase to decoding, JIT recompilation generates a separate kernel for each scenario.

flex_attention = torch.compile(flex_attention)

k_cache = torch.random(B, H, 16384, D) 
v_cache = torch.random(B, H, 16384, D)

...

# Prefill Phase: query shape = [B, H, 8000, D]
flex_attention(q_prefill, k_cache, v_cache, ...) # Uses FlexAttention backend optimized for prefill & training

# Decoding Phase: q_last_token shape = [B, H, 1, D]
flex_attention(q_last_token  , k_cache, v_cache, ...) # Recompiles with the FlexDecoding backend 

# decode 2 tokens at the same time: q_last_2_tokens shape = [B, H, 2, D]
flex_attention(q_last_2_tokens, k_cache, v_cache, ...) # No recompilation needed! Runs the decoding kernel again.

Working with KV Cache

One of the key optimizations for efficient inference is maintaining a preallocated KV cache that updates in place as new tokens are generated. Instead of enforcing a specific KV cache policy with a dedicated API, FlexDecoding allows users to define and manage the KV cache themselves.

Similar to FlexAttention, FlexDecoding takes user-defined mask_mod and score_mod functions. These functions modify attention scores before the softmax operation.

score_mod(score, b, h, q_idx, kv_idx) -> tensor # return updated score

Score is a scalar pytorch tensor that represents the dot product of a query token and a key token. The rest of the arguments specify which score is being computed:

  • b batch index
  • h attention head index
  • q_idx token position in query tensor
  • kv_idx token position in key/value tensor

In the decoding phase, previously calculated tokens are cached, and only the latest generated token (i-th) is used as the query. A naive causal mask on this one token query looks like this:

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

This is problematic: the new token “saw” should attend to all previously generated tokens i.e. “The cat sat on the mat and saw”, not just the first entry in the kv cache. To correct this, the score_mod needs to offset q_idx by for accurate decoding.

Creating a new score_mod for each token to accommodate the offset is slow since it means FlexAttention needs to be recompiled every iteration for a different score_mod. Instead,

We define this offset as a tensor and increment its value at each iteration:

offset = torch.tensor(i, "cuda")
def causal_w_offset(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Attend the i-th token
flex_attention(..., score_mod=causal_w_offset  ) # Compiles the kernel here 
...
# Attend the i+1-th token
offset = offset + 1 # Increment offset
flex_attention(..., score_mod=causal_w_offset ) # Doesn't need to recompile! 

Notably, here offset becomes a captured tensor and it does not need to recompile if offset changes values.

Manually rewriting your score_mod and mask_mod for offset handling isn’t necessary. We can automate this process with a generic rewriter:

offset = torch.tensor(i, "cuda")

def get_score_mod_w_offset(score_mod: _score_mod_signature, _offset: tensor):
    def _score_mod(score, b, h, q, kv):
        return score_mod(score, b, h, q + _offset, kv)
    return _score_mod

def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod

causal_w_offset = get_score_mod_w_offset(causal, offset)

BlockMask for Inference

We can also use BlockMask with inference to leverage mask sparsity. The idea is to precompute the BlockMask once during model setup and use slices of it during decoding

Precomputing BlockMask

During setup, we create a squared BlockMask for MAX_SEQ_LEN x MAX_SEQ_LEN:

from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=MAX_SEQ_LEN,KV_LEN=MAX_SEQ_LEN)

Using BlockMask During Decoding

For the i-th token, we use a slice of the mask:

block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]

# don't forget to use the mask_mod with offset! 
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask)

Performance

FlexDecoding kernel performs on par with FlashDecoding (FAKV) and significantly outperforms pytorch scaled_dot_product_attention (code).

FlexDecoding boosts LLaMa3.1-8B serving performance by 1.22x-2.04x, and LLaMa3.1-70B performance by 0.99x – 1.66x compared to SDPA in gpt-fast. (code)

Paged Attention

vLLM is one of the popular LLM serving engines, powered by the efficient memory management from PagedAttention. Existing PagedAttention implementation requires dedicated CUDA kernels and shows limited flexibility on supporting emerging attention variants. In this section, we present a PT2-native PagedAttention implementation that is enabled by flex attention and torch.compile.

PagedAttention scatters KV cache to reduce memory fragmentation and support higher batch sizes. Without PagedAttention, KV cache from the same request are stored in a contiguous memory, requiring 2 tensor of shape B x H x KV LEN x D. We call it a logical KV cache. Here, KV_LEN is the maximum sequence length over all requests in a batch. Considering the Figure 1(a), KV_LEN is 9 thus all requests must be padded to 9 tokens, leading to large memory waste. With PagedAttention, we can chunk each request into multiple pages of the same size page_size and scatter these pages into a physical KV cache of shape 1 x H x max seq len x D, where max_seq_len=n_pages x page_size. This avoids padding requests to the same length and saves memory. Specifically, we provide an assign API to update KV cache via index computations:

def assign(
    batch_idx: torch.Tensor,
    input_pos: torch.Tensor,
    k_val: torch.Tensor,
    v_val: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> None

Behind this assign API is a page table, a tensor mapping logical KV cache to physical KV cache:

[batch_idx, logical_page_idx] -> physical_page_idx

assign takes k_val and v_val and scatters to physical KV cache guided by the mapping from the page table.

Paged Attention with Page Table

A natural question is, how to integrate PagedAttention with flex attention to support diverse attention variants? A naive idea is to materialize the logical KV cache before computing with flex attention. But this leads to redundant memory copy and bad performance. Another idea is to build a dedicated CUDA or Triton kernel for paged attention, similar to existing PagedAttention implementation. However, this adds much manual effort and code complexity.

Instead, we design a fused indirect memory access by converting a logical block mask according to the page table. In FlexAttention, we exploit BlockMask to identify logical blocks and skip redundant computation. While Paged Attention adds an extra layer of indirect memory access, we can further convert the logical block mask to the physical block mask corresponding to the page table, as illustrated in Figure 2. Our PagedAttention implementation provides a convert_logical_block_mask via torch.gather calls:

def convert_logical_block_mask(
    block_mask: BlockMask,
    batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask

Paged Attention via Block Mask Conversion

One remaining question is how to rewrite user-specified mask_mod and score_mod for PagedAttention. When users specify these modifications, they write with logical indices without the knowledge of the page table maintained at runtime. The following code shows an automated conversion at runtime which is necessary to rewrite user-specified modifications with physical kv indices. The new_mask_mod would take the physical_kv_idx and convert it back to the logical_kv_idx and apply user-specified mask_mod on the logical_kv_idx for the correct mask. For efficiency, we maintain physical_to_logical as a mapping from physical_kv_block to logical_kv_block to facilitate the conversion. For correctness, we mask out-of-boundary blocks as False with a torch.where call. After batching logical KV caches from multiple requests into the same physical KV cache, there are much more physical blocks than the number of logical blocks for each request. Thus, a physical block may not have a corresponding logical block for a specific request during block mask conversion. By masking as False with torch.where, we can ensure the correctness that data from different requests do not interfere with each other. Similarly, we can convert the score_mod automatically.

def get_mask_mod(mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
    if mask_mod is None:
        mask_mod = noop_mask

    def new_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ):
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
        )

    return new_mask_mod

Figure 3 demonstrates the latency from Paged Attention (code). Overall, there is less than 5% overhead from Flex Attention with Paged Attention, compared with Flex Attention only. We also observe an on-par performance with Flash Attention v2. A minimal serving example further shows that PagedAttention can support 76x higher batch size when evaluating on OpenOrca dataset which includes 1M GPT-4 completions and 3.2M GPT-3.5 completions.

Paged Attention: Latency under diverse sequence length

Ragged input sequences with Nested Jagged Tensors (NJTs)

FlexAttention now supports ragged-sized input sequences through the use of Nested Jagged Tensors (NJTs). NJTs represent ragged-sized sequences by packing sequences into a single “stacked sequence” and maintaining a set of offsets delimiting sequence boundaries for each batch item.

A block mask can be created for input NJTs through the new create_nested_block_mask() API. The returned block mask is compatible with the ragged structure of the given NJT, treating it as a single “stacked sequence” with inter-sequence attention automatically masked out. The mask_mod or score_mod function can be written as usual.

import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

BATCH = 8
NUM_HEADS = 8
D = 16
device = "cuda"

# Input NJTs of shape (BATCH, SEQ_LEN*, D) with ragged SEQ_LEN
sequence_lengths = [torch.randint(5, 30, ()).item() for _ in range(BATCH)]
query = torch.nested.nested_tensor([
    torch.randn(seq_len, NUM_HEADS * D, device=device)
    for seq_len in sequence_lengths
], layout=torch.jagged)
key = torch.randn_like(query)
value = torch.randn_like(query)

# View as shape (BATCH, NUM_HEADS, SEQ_LEN*, HEAD_DIM)
query = query.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
key = key.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
value = value.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)

# Simple causal mask
def my_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Construct a block mask using the ragged structure of the
# specified query NJT. Ragged-sized sequences are treated as a single
# "stacked sequence" with inter-sequence attention masked out.
block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query)

# For cross attention, create_nested_block_mask() also supports a
# rectangular block mask using the ragged structures of both query / key.
# block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, key)

output = flex_attention(query, key, value, block_mask=block_mask)
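
To make the "stacked sequence" idea concrete, the following is a conceptual sketch in plain PyTorch. It is not how create_nested_block_mask() is implemented; offsets, seq_id and stacked_mask_mod are hypothetical names used only here. It packs three ragged sequences, derives the offsets that delimit them, and wraps a causal mask_mod so that attention across sequence boundaries is masked out.

```python
import torch

sequence_lengths = [3, 1, 2]
D = 4

# Pack the ragged sequences into one "stacked sequence" of length 6.
packed = torch.cat([torch.randn(n, D) for n in sequence_lengths], dim=0)

# offsets[i] is where sequence i starts; offsets[-1] is the total length.
offsets = torch.tensor([0] + sequence_lengths).cumsum(0)

# seq_id[p] is the index of the sequence that packed position p belongs to.
seq_id = torch.repeat_interleave(
    torch.arange(len(sequence_lengths)), torch.tensor(sequence_lengths)
)


# The user-facing mask_mod, written as usual in per-sequence indices.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def stacked_mask_mod(b, h, q_idx, kv_idx):
    # Attend only within the same sequence, using sequence-local positions.
    same_sequence = seq_id[q_idx] == seq_id[kv_idx]
    q_local = q_idx - offsets[seq_id[q_idx]]
    kv_local = kv_idx - offsets[seq_id[kv_idx]]
    return same_sequence & causal(b, h, q_local, kv_local)


# Materialize the dense mask for inspection (only viable for tiny inputs).
positions = torch.arange(packed.shape[0])
dense_mask = stacked_mask_mod(0, 0, positions[:, None], positions[None, :])
```

The resulting dense_mask is block-diagonal with a lower-triangular block per sequence, which is the structure that the block mask for an NJT encodes without materializing the dense matrix.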

Trainable Biases

FlexAttention now supports trainable parameters in score_mod functions. This feature enables users to reference tensors that require gradients within their score_mod implementations, with gradients automatically backpropagating through these parameters during training.

Memory-Efficient Gradient Accumulation

Instead of materializing the full attention scores matrix, FlexAttention uses atomic additions (tl.atomic_add) to accumulate gradients. This approach significantly reduces memory usage at the cost of introducing some non-determinism in gradient calculations.

Handling Broadcasted Operations

Broadcasting operations in the forward pass (e.g., score + bias[h]) require special consideration in the backward pass. When broadcasting a tensor across multiple attention scores within a head or other dimensions, we need to reduce these gradients back to the original tensor shape. Rather than materializing the full attention score matrix to perform this reduction, we use atomic operations. While this incurs some runtime overhead, it allows us to maintain memory efficiency by avoiding the materialization of large intermediate tensors.
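
As a reference for what this reduction computes, here is a small eager-mode sketch. It deliberately materializes the full score tensor, which FlexAttention avoids; the shapes and variable names are illustrative assumptions. It checks that the gradient of a per-head bias equals the sum of the incoming score gradients over every dimension the bias was broadcast across.

```python
import torch

torch.manual_seed(0)
B, H, S = 2, 3, 4  # illustrative batch size, head count, sequence length

scores = torch.randn(B, H, S, S)
bias = torch.randn(H, requires_grad=True)

# Forward: bias[h] is broadcast over batch, query and key positions.
out = scores + bias[None, :, None, None]

# Backward with an arbitrary upstream gradient.
grad_out = torch.randn_like(out)
out.backward(grad_out)

# The broadcast gradient reduces back to the shape of bias by summing
# over all broadcast dimensions (batch, q_idx, kv_idx).
manual_grad = grad_out.sum(dim=(0, 2, 3))
grads_match = torch.allclose(bias.grad, manual_grad)
```

FlexAttention arrives at the same per-head sums, but accumulates them with atomic additions inside the kernel instead of building grad_out in full.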

Current Limitations

The implementation currently allows only a single read from each input tensor in the score_mod function. For example, bias[q_idx] + bias[kv_idx] would not be supported as it reads from the same tensor twice. We hope to remove this restriction in the future.

Simple Example:

num_heads = 8  # illustrative value
bias = torch.randn(num_heads, requires_grad=True)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]

Performance Tuning for FlexAttention

TL;DR

For optimal performance, compile FlexAttention using max-autotune, especially when dealing with complex score_mods and mask_mods:

flex_attention = torch.compile(flex_attention, dynamic=True, mode="max-autotune")

What is max-autotune?

max-autotune is a torch.compile mode in which TorchInductor sweeps many kernel parameters (e.g., tile size, num_stages) and selects the best-performing configuration. Configurations that fail during the sweep are skipped rather than raising an error, so the search settles on the best viable configuration.

While compilation takes longer with max-autotune, the optimal configuration is cached for future kernel executions.

Here’s an example of the autotuning output for a FlexAttention kernel compiled with max-autotune:

triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

Why Use max-autotune for FlexAttention?

The amount of shared memory used by FlexAttention depends on the score_mod and mask_mod functions. Because of this variability, the preconfigured default kernel parameters may lead to performance cliffs or even out-of-shared-memory errors on certain hardware for some masks/mods.

For instance, with document masks, default configurations can halve GPU occupancy, reducing performance to ~75% of its potential on some GPUs. To avoid such issues, we strongly recommend enabling max-autotune.

Updates and Enhancements

  • Now available as a prototype feature in PyTorch 2.5.0
  • Fixed critical correctness issues, including a bug affecting multiple calls to FlexAttention within the same call to torch.compile

Expanded Architecture Support

  • Arbitrary sequence length support – no longer requires multiples of 128
  • Added native grouped-query attention (GQA) support via enable_gqa=True
  • Enhanced dimension flexibility:
    • Different QK and V head dimensions
    • Non-power-of-two head dimensions
  • Trainable attention biases (prototype)

Under the Hood

  • New fused CPU backend
  • Improved TF32 handling for float32 inputs
  • Resolved various dynamic shape issues
  • Output layout matching query strides

These updates make FlexAttention more robust and flexible while maintaining its core promise of combining PyTorch’s ease of use with FlashAttention’s performance benefits.
