PyTorch Grows as the Dominant Open Source Framework for AI and ML: 2024 Year in Review


This past year was a monumental one for PyTorch, from major releases to the flagship PyTorch Conference. We’ve seen incredible growth in contributions from more than 3,500 individuals and 3,000 organizations. It’s safe to say PyTorch has now become the dominant deep learning framework for AI/ML. PyTorch leads the model training space with a 63% adoption rate, according to the recent Shaping the Future of Generative AI report from the Linux Foundation.


The PyTorch Foundation was formed in 2022 with the goal to drive the adoption of AI tooling by fostering and sustaining an ecosystem of open source, vendor-neutral projects centered around PyTorch and today remains a vibrant, collaborative hub created for and by the deep learning community. As we wrap up the year, let’s take a look back at a few highlights and how this year has been one of growth, collaboration, innovation, and community.

2024 Highlights: A Year of Growth and Impact

PyTorch accelerated its growth this year. Contributions are up 133%, coming from twice as many organizations worldwide compared to last year.

The project has seen 20% year-over-year growth in new repositories using PyTorch, and a 30% increase in forks and users this past year.

Over 70% of AI research implementations are now using PyTorch.

Statistics based on the 2024 Linux Foundation Annual Report.


The PyTorch Tools ecosystem grew by over 25%, enhancing both software and hardware capabilities. Working with all major cloud service providers, dozens of major software vendors, and industry partners, PyTorch is setting a new bar for the pace and breadth of AI innovation.


This year featured four milestone PyTorch releases: 2.2, 2.3, 2.4, and 2.5. These releases delivered hallmark features such as AOTInductor, FlashAttention-2 support, Tensor Parallelism, a new Python Custom Operator API, and the introduction of FlexAttention. Engineers from across PyTorch Foundation member companies also came together to introduce support and optimizations for platforms like Intel GPUs (XPU) and AWS Graviton processors, along with Inductor performance improvements.

Throughout the year the PyTorch team has been working hard to introduce a number of new PyTorch-native libraries! The ExecuTorch team released their alpha in collaboration with partners from Arm, Apple, and Qualcomm Technologies, Inc., and quickly followed with a beta focused on stability that added MediaTek support. TorchTune is a PyTorch-native library for easily fine-tuning large language models. TorchAO introduced a PyTorch-native library that makes models faster and smaller by leveraging low-bit dtypes, quantization, and sparsity. TorchCodec was launched to give developers a simple, performant, and PyTorch-native way to decode videos into tensors. TorchRec 1.0 was released, the first stable release of the PyTorch-native recommendation systems library.

We’ve also had a number of strong technical showcases throughout the year to highlight how PyTorch can be used! TorchTitan exhibited what an open source, PyTorch-native distributed training system could look like for training large language models (LLMs). TorchChat showcased how to seamlessly and performantly run LLMs across laptop, desktop, and mobile devices.

We were also excited to welcome multiple new projects into the PyTorch Ecosystem throughout 2024, including vLLM, a state-of-the-art inference engine that gives machine learning engineers an easy, fast, and cheap way of serving LLMs. If you are interested in joining the PyTorch Ecosystem, find out how your project can be included.


In June in Paris, France we premiered the official PyTorch documentary on powering the AI Revolution that spotlights PyTorch’s vibrant ecosystem and its role in advancing AI innovation. The film unveiled the authentic narrative of PyTorch’s inception, attributing its existence to a dedicated group of unsung heroes driving technological innovation.


The PyTorch Conference 2024 brought in triple the registrations compared to 2023, reflecting the rapid growth of AI and machine learning communities around open source technologies. The two-day event included insightful talks, hands-on sessions, and lively discussions about the future of AI, covering everything from generative AI to large language models.

A brand new Startup Showcase featured early-stage founders pitching their AI startups to a panel of top venture capitalists. A DL Compiler Mini-Summit took a deep dive into the advances in deep learning (DL) compilers that are transforming AI workloads. A Fine-Tuning Mini-Summit brought together a thriving community of researchers, developers, practitioners, and hobbyists to discuss topics like memory efficiency, parameter-efficient fine-tuning, and performance at scale.


Outstanding contributors were honored with PyTorch Contributor Awards. Congratulations to this year’s nominees and recipients, the outstanding individuals and teams who played a pivotal role in PyTorch’s journey this year.


PyTorch Foundation membership is growing with the addition of Arm and Rebellions this year. At the year-end mark, Premier Members include: AMD, Arm, AWS, Google Cloud, Huawei, Hugging Face, IBM, Intel, Lightning AI, Meta, Microsoft Azure, and NVIDIA. General Members include: Graphcore, Rebellions, and Snowflake. If your organization is interested in joining, find out how you can become a member of the PyTorch Foundation.

PyTorch hosted numerous in-person and virtual events, including the PyTorch Docathon, where contributors worked to improve PyTorch documentation and foster collaboration; local meetups around the world, which brought together interested parties in locations from Shanghai to Seoul; and more than a dozen webinars that drew attendees from everywhere during our Summer Webinar Series, live Q&As, and Expert Exchanges.


PyTorch Foundation welcomed new leadership this year. Executive Director Matt White took the reins in April and immediately began raising the profile of PyTorch across the AI landscape. The Technical Advisory Council (TAC) also elected new leadership, with Luca Antiga (Lightning AI) as Chair and Jiong Gong (Intel) as Vice Chair.

The PyTorch Governing Board continued to set the direction and lead the Foundation in accomplishing its mission. The PyTorch Marketing and Outreach Committee developed programs to maximize the visibility of PyTorch and advance the interests of the community. The PyTorch CI Working Group assembled to successfully migrate the PyTorch CI pipeline to the Linux Foundation.

Our social media community grew to 775 thousand followers across X, LinkedIn, Facebook, and YouTube, with more than 12 million impressions of PyTorch content throughout the year. The PyTorch Ecosystem also grew, adding many new projects that leverage PyTorch deep learning across many vertical domains.


PyTorch was mentioned in the media in top technology publications such as The New Stack’s article on Why PyTorch Gets All the Love and InfoWorld’s article on how the TorchAO PyTorch library makes models faster and smaller.

We published 74 technical and community blogs, and nearly ten million people visited the PyTorch website throughout the year.


Thanks to each of you who helped make this year an outstanding success! The evolution and growth we’ve seen PyTorch undergo over the past year is driven by the passion, dedication, and ingenuity of this amazing community. Looking ahead to next year, we’re excited to build on this momentum as we continue to push the boundaries of AI.

Save the date for the PyTorch Conference which will be held October 22-23, 2025 in San Francisco. 2025 promises even greater innovation and stronger community collaboration.


Improve RAG performance with torch.compile on AWS Graviton Processors


Large Language Models (LLMs) are trained on vast volumes of data and use billions of parameters to support tasks like answering questions, translating languages, and completing sentences. There are a few challenges when working with LLMs, such as domain knowledge gaps, factuality issues, and hallucination, which affect their reliability, especially in fields that require high levels of accuracy such as healthcare, law, or engineering. Retrieval Augmented Generation (RAG) mitigates some of these issues by augmenting LLMs with a specific domain or an organization’s internal knowledge base, without the need to retrain the model.

The RAG knowledge source is generally a business-specific database, typically deployed on general-purpose CPU infrastructure. So, deploying RAG on general-purpose CPU infrastructure alongside related business services is both efficient and cost-effective. With this motivation, we evaluated RAG deployment on AWS Graviton-based Amazon EC2 instances, which have been delivering up to 40% price-performance advantage compared to comparable instances for the majority of workloads, including databases, in-memory caches, big data analytics, media codecs, gaming servers, and machine learning inference.

In the past we published a few blog posts on how PyTorch was optimized for AWS Graviton processors to accelerate ML inference performance for both eager mode (blog) and torch.compile mode (blog). In this blog we cover how to deploy a typical RAG workload using PyTorch and torch.compile, how we improved its performance by up to 1.7x for the embedding model and 1.3x for the RAG query on an AWS Graviton3-based m7g.xlarge instance compared to the default PyTorch eager mode, and finally a few recommendations that you can apply to your own RAG use cases.

How to Optimize RAG?

Without RAG, the LLM takes the user input and creates a response based on information it was trained on (what it already knows). With RAG, an information retrieval component is introduced that utilizes the user input to first pull information from a new data source. The user query and the relevant information are both given to the LLM. The LLM uses the new knowledge and its training data to create better responses. The following diagram shows the conceptual flow of using RAG with LLMs.


Image 1: Conceptual flow of using RAG with LLMs

Source: https://aws.amazon.com/what-is/retrieval-augmented-generation/

Embedding model

At the core of RAG is an embedding model that takes text data and converts it into a vector representation. These vectors are then stored in a vector database. When a user makes a query, the query is first converted to a vector, and RAG performs a similarity search on the vector database. Hence, the first step in optimizing RAG performance is optimizing the embedding model’s inference performance. We used the AWS Graviton3-based m7g.xlarge instance and the Hugging Face sentence-transformers embedding model for the optimization work. Here is a sample script for profiling the sentence-transformers embedding model inference with PyTorch eager mode.

import torch
from torch.profiler import profile, ProfilerActivity, record_function
from transformers import AutoModel, AutoTokenizer

model_name = "sentence-transformers/all-mpnet-base-v2"
input_text = ["This is an example sentence", "Each sentence is converted"]

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

encoded_input = tokenizer(
    input_text, padding=True, truncation=True, return_tensors="pt"
)

warmup, actual = 100, 100
model.eval()

with torch.no_grad():
    # warmup
    for i in range(warmup):
        embeddings = model(**encoded_input)

    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("model_inference"):
            for i in range(actual):
                embeddings = model(**encoded_input)

    # print the profiling results after the profiler context exits
    print(prof.key_averages().table(sort_by="self_cpu_time_total"))

Eager mode

Since PyTorch eager mode was already optimized on AWS Graviton processors with the following runtime environment settings, we included them in the baseline and measured the following performance. Please refer to Optimized PyTorch 2.0 Inference with AWS Graviton processors for more details on how we optimized the PyTorch eager mode on AWS Graviton processors.

# Enable the fast math GEMM kernels, to accelerate fp32 inference with bfloat16 gemm
export DNNL_DEFAULT_FPMATH_MODE=BF16

# Enable Linux Transparent Huge Page (THP) allocations,
# to reduce the tensor memory allocation latency
export THP_MEM_ALLOC_ENABLE=1

# Set LRU Cache capacity to cache the primitives and avoid redundant
# memory allocations
export LRU_CACHE_CAPACITY=1024
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::addmm        61.01%        2.638s        62.49%        2.702s     370.197us          7300  
            model_inference        12.01%     519.161ms       100.00%        4.324s        4.324s             1  
                  aten::bmm         6.25%     270.084ms        11.96%     517.089ms     215.454us          2400  
               aten::select         3.98%     172.165ms         5.34%     230.863ms       1.331us        173500  
                aten::copy_         2.11%      91.133ms         2.11%      91.133ms       6.200us         14700   
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 4.324s

Table 1: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with PyTorch Eager mode

Next, we added torch.compile, weights pre-packing, and torch.inference_mode and observed around 1.7x performance improvement. The following section talks about each of these optimizations and the resulting speedup.

torch.compile

In contrast to eager mode, torch.compile pre-compiles the entire model into a single graph in a manner that’s optimized for running on the given hardware. Please refer to Accelerated PyTorch Inference with torch.compile on AWS Graviton processors for more details on torch.compile features and how we optimized them on AWS Graviton processors. Invoke torch.compile as shown in the following snippet to trigger PyTorch dynamo compilation for the model. This resulted in around 1.04x performance improvement over the baseline.

model = torch.compile(model)

----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 aten::addmm        64.46%        2.675s        66.66%        2.766s     378.905us          7300  
       Torch-Compiled Region        19.76%     820.085ms        99.04%        4.109s      41.094ms           100  
                   aten::bmm         6.66%     276.216ms        12.52%     519.527ms     216.470us          2400  
                aten::select         3.98%     164.991ms         5.41%     224.488ms       1.299us        172800  
            aten::as_strided         1.66%      69.039ms         1.66%      69.039ms       0.383us        180100  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 4.149s

Table 2: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with torch.compile mode

Weights pre-packing

torch.compile opens up opportunities such as pre-packing the model weights, during compilation, into a format better suited to the given hardware, which improves performance. Set the following config to trigger weights pre-packing. This resulted in around 1.69x improvement over the baseline.

import torch._inductor.config as config
config.cpp.weight_prepack=True
config.freezing=True
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
    mkldnn::_linear_pointwise        39.10%     994.821ms        41.50%        1.056s     144.628us          7300  
        Torch-Compiled Region        35.12%     893.675ms        98.42%        2.504s      25.043ms           100  
                    aten::bmm        10.96%     278.859ms        21.66%     551.073ms     229.614us          2400  
                 aten::select         7.34%     186.838ms         9.98%     253.840ms       1.469us        172800  
             aten::as_strided         2.63%      67.002ms         2.63%      67.002ms       0.388us        172800   
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.544s

Table 3: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with torch.compile and weights pre-packing

torch.inference_mode

Additionally, use torch.inference_mode() to get further savings by turning off the tensor version counter and view tracking. Please refer to the PyTorch documentation for more details.

with torch.inference_mode():
    # instead of: with torch.no_grad():
    ...
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
    mkldnn::_linear_pointwise        38.92%     987.276ms        41.17%        1.044s     143.056us          7300  
        Torch-Compiled Region        34.92%     885.895ms        98.45%        2.498s      24.975ms           100  
                    aten::bmm        11.25%     285.292ms        22.22%     563.594ms     234.831us          2400  
                 aten::select         7.74%     196.223ms        10.22%     259.251ms       1.500us        172800  
             aten::as_strided         2.48%      63.027ms         2.48%      63.027ms       0.365us        172800  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.537s

Table 4: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with torch.compile, weights pre-packing, and inference_mode

The following table shows the incremental performance improvements achieved for the standalone embedding model inference.

Optimization level               Latency measured (in sec)    Improvement over the baseline
PyTorch eager mode (Baseline)    0.04324                      NA
torch.compile                    0.04149                      1.04x
weights pre-packing              0.02544                      1.69x
torch.inference_mode             0.02537                      1.70x

The following script is an updated example of the embedding model inference with the previously discussed optimizations applied. The optimizations are called out in the comments.

import torch
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import AutoTokenizer, AutoModel

# Optimization: enable weights pre-packing via the inductor config
import torch._inductor.config as config
config.cpp.weight_prepack = True
config.freezing = True

model_name = "sentence-transformers/all-mpnet-base-v2"
input_text = ["This is an example sentence", "Each sentence is converted"]

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

encoded_input = tokenizer(
    input_text, padding=True, truncation=True, return_tensors="pt"
)

warmup, actual = 100, 100
model.eval()

# Optimization: compile the model with torch.compile
model = torch.compile(model)

# Optimization: use inference_mode instead of no_grad
with torch.inference_mode():
    # warmup
    for i in range(warmup):
        embeddings = model(**encoded_input)

    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("model_inference"):
            for i in range(actual):
                embeddings = model(**encoded_input)

print(prof.key_averages().table(sort_by="self_cpu_time_total"))

End-to-End RAG scenario on CPU

After optimizing the embedding model inference, we started with a PyTorch eager mode based RAG setup, mainly to validate the functionality on the CPU backend. We built the RAG solution with HuggingFaceEmbeddings from langchain_community.embeddings, as shown in the following code snippet.

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain.prompts import PromptTemplate
from langchain_core.prompts import format_document
from bs4 import BeautifulSoup as Soup
import torch

url =  "https://pytorch.org/blog/pytorch2-5/"
chunk_size = 1000
chunk_overlap = 0
embedding_model = "sentence-transformers/all-mpnet-base-v2"
N = 5

question = "What's new in PyTorch 2.5?"

from transformers import AutoTokenizer, AutoModel
from typing import Any, List

loader = RecursiveUrlLoader(
    url=url, max_depth=3, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# Split the document into chunks with a specified chunk size
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
all_splits = text_splitter.split_documents(docs)

# Store the document into a vector store with a specific embedding model
model = HuggingFaceEmbeddings(model_name=embedding_model)

warmup, actual = 100, 100

with torch.inference_mode():
    vectorstore = FAISS.from_documents(all_splits, model)

    for i in range(warmup):
        searchDocs = vectorstore.similarity_search(question, k=N)

    import time

    start = time.time()
    for i in range(actual):
        searchDocs = vectorstore.similarity_search(question, k=N)
    end = time.time()
    print(f"Time for 1 inference is {(end-start)/actual} seconds")

    doc_prompt = PromptTemplate.from_template("{page_content}")
    context = ""
    for i, doc in enumerate(searchDocs):
        context += f"n{format_document(doc, doc_prompt)}n"

Next, our goal was to optimize the end-to-end RAG use case with torch.compile and weights pre-packing, which gave the 1.7x improvement for the standalone embedding model inference. However, the optimizations didn’t work out of the box for the RAG scenario.

What are the challenges and solutions to achieve similar gains in an end-to-end RAG scenario?

Challenge 1: model handle

The model instantiated with HuggingFaceEmbeddings was not directly accessible, and the wrapper class doesn’t provide compile APIs. So, there was no way for our application to invoke torch.compile and trigger the PyTorch dynamo compilation process.

Solution

We implemented a custom embedding class so that we could get a handle to the model. This class instantiates the embedding model from sentence-transformers and keeps the handle available for compilation, either immediately or at a later stage. With this, we were able to trigger torch.compile and hence the dynamo compilation.

class CustomEmbedding(HuggingFaceEmbeddings):

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)

        # Load model from HuggingFace Hub and keep a handle to it
        self.client = AutoModel.from_pretrained(self.model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        texts = list(map(lambda x: x.replace("\n", " "), texts))

        # Tokenize sentences
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        embeddings = self.client(**encoded_input, output_hidden_states=True)
        embeddings = embeddings.pooler_output.detach().numpy()

        return embeddings.tolist()


# instead of model = HuggingFaceEmbeddings(model_name=embedding_model)
model = CustomEmbedding(model_name=embedding_model)

# torch.compile the model
model.client = torch.compile(model.client)

Challenge 2: triggering the optimization

For a typical inference scenario where the graph is frozen and gradient calculations are disabled, Torch inductor (the compiler backend we used for CPUs) invokes hardware specific optimizations like graph rewrite into more performant operators, operator fusion, and weights pre-packing. Though Torch dynamo was able to see the model and trigger generic compilation, it failed to trigger these additional Fx passes in the Torch inductor.

There were two main reasons for Torch inductor not triggering the optimization passes: (1) The application didn’t set no_grad() or inference_mode() for torch inductor to detect that the graph was frozen; and (2) We hit a limitation with the torch.compile framework, where, if the no_grad is set just at the beginning of the compiled region, torch.compile wouldn’t be able to detect it while invoking the inductor Fx passes because it would not have hit the no_grad region by then. Please refer to this GitHub issue for more details.

Solution

We worked around this limitation by moving the no_grad() context out of the model class and into the application code. With this, the model compilation happened as expected and gave around 1.3x performance improvement when we profiled the stable inference pass for the eager and compiled versions.
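
For reference, here is a minimal sketch of that workaround, reusing the CustomEmbedding wrapper and the variables from the earlier end-to-end snippet:

# Open the inference-only context in the application, outside the model class,
# so Torch inductor sees a frozen, gradient-free graph when compilation is
# triggered on the first embedding call.
model = CustomEmbedding(model_name=embedding_model)
model.client = torch.compile(model.client)

with torch.inference_mode():
    vectorstore = FAISS.from_documents(all_splits, model)      # triggers compilation
    searchDocs = vectorstore.similarity_search(question, k=N)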

Challenge 3: extra compilation

With the previous fixes, the query lookup inference performance was improved, but not the total execution time of the benchmarking script. We root-caused it to redundant compilation for the model during the RAG inference. Further deep diving revealed that it was because of the batch size mismatch between the word embedding and the RAG query stages. For example, in our benchmarking script, when the database was vectorized and stored in vector db, we used the batch size of 16, hence the model was compiled with shapes of 16xNxK. Whereas, the RAG query lookup is usually a single request of shape 1xNxK. So, there was a batch size mismatch (dimension “0” of these tensors) that triggered the recompilation for the query lookup stage. We confirmed it with the following Torch logging: TORCH_LOGS="recompiles"

TORCH_LOGS="recompiles" python rag_compile.py 
V1103 02:48:08.805986 34281 site-packages/torch/_dynamo/guards.py:2813] [0/1] [__recompiles] Recompiling function forward in site-packages/transformers/models/mpnet/modeling_mpnet.py:502
V1103 02:48:08.805986 34281 site-packages/torch/_dynamo/guards.py:2813] [0/1] [__recompiles]     triggered by the following guard failure(s):
V1103 02:48:08.805986 34281 site-packages/torch/_dynamo/guards.py:2813] [0/1] [__recompiles]     - 0/0: tensor 'L['input_ids']' size mismatch at index 0. expected 16, actual 1

Solution

Torch dynamo provides a decorator to mark the dimension of a given tensor as dynamic and specify an expected value for the same, so that re-compilation is not triggered. For example, specifying dimension “0” of input_ids and attention_mask as dynamic, and specifying that value of “1” is allowed in that dimension (as shown in the following code snippet), should have avoided the redundant compilations.

torch._dynamo.decorators.mark_unbacked(encoded_input['input_ids'], 0)
torch._dynamo.mark_dynamic(encoded_input['input_ids'], 1)
torch._dynamo.decorators.mark_unbacked(encoded_input['attention_mask'], 0)
torch._dynamo.mark_dynamic(encoded_input['attention_mask'], 1)

However, the Torch dynamo decorator and marking didn’t work in this particular case, and using the decorator created graph breaks. So, we added some warmup iterations to hide the compilation latency and profiled the query lookup performance in the steady state. The good news is that, in practice, this recompilation is triggered only for the first query, so it might not affect the production scenario if the database size is fixed. Moreover, PyTorch AOT Inductor (a new feature in PyTorch) addresses recompilation and warmup challenges with torch.compile. In a follow-up blog we will cover how AOT Inductor can be used to address these challenges in a production environment.

With these solutions, we were able to apply torch.compile, weights pre-packing, and the AWS Graviton-specific optimizations to an end-to-end RAG scenario and improve performance by 1.3x over the baseline eager mode.

Deployment

A detailed guide on how to deploy torch compiled RAG on AWS Graviton-based Amazon EC2 instances and how to deploy it in conjunction with Llama using TorchServe can be found on the PyTorch website.

Conclusion

In this blog, we covered how we optimized embedding model inference performance on AWS Graviton3-based EC2 instances. We also shared the challenges faced, the solutions we implemented to bring those optimizations for a RAG use case, and the resulting speedups. We hope that you will give it a try! If you need any support with ML software on Graviton, please open an issue on the AWS Graviton Technical Guide GitHub.

We would like to express our gratitude to Eli Uriegas for the support in making this blog post happen.

Authors

Sunita Nadampalli is a Principal Engineer and AI/ML expert at AWS. She leads AWS Graviton software performance optimizations for AI/ML and HPC workloads. She is passionate about open source software development and delivering high-performance and sustainable software solutions for SoCs based on the Arm ISA.

Ankith Gunapal is an AI Partner Engineer at Meta (PyTorch). He leads customer support, evangelism, and release engineering of TorchServe. He is passionate about solving production problems in model inference and model serving. He also enjoys distilling technically complex material into a user-friendly format.

Hamid Shojanazeri leads the AI Frameworks Partner Engineering team at Meta. He is passionate about building scalable AI solutions and specializes in working with PyTorch to tackle the challenges of large-scale distributed training, inference, model serving, and optimization.


docTR joins PyTorch Ecosystem: From Pixels to Data, Building a Recognition Pipeline with PyTorch and docTR



We’re thrilled to announce that the docTR project has been integrated into the PyTorch ecosystem! This integration ensures that docTR aligns with PyTorch’s standards and practices, giving developers a reliable, community-backed solution for powerful OCR workflows.

For more information on what it means to be a PyTorch ecosystem project, see the PyTorch Ecosystem Tools page.

About docTR

docTR is an Apache 2.0 project developed and distributed by Mindee to help developers integrate OCR capabilities into applications with no prior knowledge required.

To quickly and efficiently extract text information, docTR uses a two-stage approach:

  • First, it performs text detection to localize words.
  • Then, it conducts text recognition to identify all characters in a word.

Detection and recognition are performed by state-of-the-art models written in PyTorch. To learn more about this approach, you can refer to the docTR documentation.

docTR enhances the user experience in PyTorch projects by providing high-performance OCR capabilities right out of the box. Its specially designed models require minimal to no fine-tuning for common use cases, allowing developers to quickly integrate advanced document analysis features.

Local installation

docTR requires Python >= 3.10 and supports Windows, macOS, and Linux. Please refer to our README for the dependencies needed on MacBooks with the M1 chip.

pip3 install -U pip
pip3 install "python-doctr[torch,viz]"

This will install docTR along with the latest version of PyTorch.

Note: docTR also provides Docker images for easy deployment, for example as part of a Kubernetes cluster.

Text recognition

Now, let’s try docTR’s OCR recognition on this sample:

OCR sample

The OCR recognition model expects an image with only one word on it and will output the predicted word with a confidence score. You can use the following snippet to test OCR capabilities from docTR:

from doctr.io import DocumentFile
from doctr.models import recognition_predictor

doc = DocumentFile.from_images("/path/to/image")

# Load the OCR model
# This will download pre-trained models hosted by Mindee
model = recognition_predictor(pretrained=True)

result = model(doc)
print(result)

Here, the most important line of code is model = recognition_predictor(pretrained=True). This will load a default text recognition model, crnn_vgg16_bn, but you can select other models through the arch parameter. You can check out the available architectures.

When run on the sample, the recognition predictor retrieves the following data: [('MAGAZINE', 0.9872216582298279)]

Note: the DocumentFile object provides an easy way to manipulate PDFs or images with docTR.

Text detection

The last example was a crop on a single word. Now, what about an image with several words on it, like this one?

photo of magazines

A text detection model is used before the text recognition to output a segmentation map representing the location of the text. Following that, the text recognition is applied on every detected patch.

Below is a snippet to run only the detection part:

from doctr.io import DocumentFile
from doctr.models import detection_predictor
from matplotlib import pyplot as plt
from doctr.utils.geometry import detach_scores
from doctr.utils.visualization import draw_boxes

doc = DocumentFile.from_images("path/to/my/file")
model = detection_predictor(pretrained=True)

result = model(doc)

draw_boxes(detach_scores([result[0]["words"]])[0][0], doc[0])
plt.axis('off')
plt.show()

Running it on the full sample yields the following:

photo of magazines

Similarly to the text recognition, detection_predictor will load a default model (fast_base here). You can also load another one by providing it through the arch parameter.

The full implementation

Now, let’s plug both components into the same pipeline.

Conveniently, docTR provides a wrapper that does exactly that for us:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_images("/path/to/image")

model = ocr_predictor(pretrained=True, assume_straight_pages=False)

result = model(doc)
result.show()

photo of magazines

The last line should display a matplotlib window which shows the detected patches. Hovering the mouse over them will display their contents.

You can also do more with this output, such as reconstituting a synthetic document like so:

import matplotlib.pyplot as plt

synthetic_pages = result.synthesize()
plt.imshow(synthetic_pages[0])
plt.axis('off')
plt.show()

black text on white

The pipeline is highly customizable: you can modify the behavior of the detection or recognition models by passing arguments to ocr_predictor. Please refer to the documentation to learn more about it.
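
For instance, here is a minimal sketch that swaps in specific detection and recognition architectures; the argument and model names below follow the docTR documentation, so treat them as assumptions to verify against your installed version:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_images("/path/to/image")

# Explicitly pick the detection and recognition models instead of the defaults.
model = ocr_predictor(
    det_arch="db_resnet50",
    reco_arch="crnn_vgg16_bn",
    assume_straight_pages=False,
    pretrained=True,
)
result = model(doc)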

Conclusion

We’re excited to welcome docTR into the PyTorch Ecosystem, where it seamlessly integrates with PyTorch pipelines to deliver state-of-the-art OCR capabilities right out of the box.

By empowering developers to quickly extract text from images or PDFs using familiar tooling, docTR simplifies complex document analysis tasks and enhances the overall PyTorch experience.

We invite you to explore the docTR GitHub repository, join the docTR community on Slack, and reach out at contact@mindee.com for inquiries or collaboration opportunities.

Together, we can continue to push the boundaries of document understanding and develop even more powerful, accessible tools for everyone in the PyTorch community.


torchcodec: Easy and Efficient Video Decoding for PyTorch


We are pleased to officially announce torchcodec, a library for decoding videos into PyTorch tensors. It is fast, accurate, and easy to use. When running PyTorch models on videos, torchcodec is our recommended way to turn those videos into data your model can use.

Highlights of torchcodec include:

  • An intuitive decoding API that treats a video file as a Python sequence of frames. We support both index-based and presentation-time-based frame retrieval.
  • An emphasis on accuracy: we ensure you get the frames you requested, even if your video has variable frame rates.
  • A rich sampling API that makes it easy and efficient to retrieve batches of frames.
  • Best-in-class CPU decoding performance.
  • CUDA accelerated decoding that enables high throughput when decoding many videos at once.
  • Support for all codecs available in your installed version of FFmpeg.
  • Simple binary installs for Linux and Mac.

Easy to Use

A simple, intuitive API was one of our main design principles. We start with simple decoding and extracting specific frames of a video:

from torchcodec.decoders import VideoDecoder
from torch import Tensor

decoder = VideoDecoder("my_video.mp4")

# Index based frame retrieval.
first_ten_frames: Tensor = decoder[:10]
last_ten_frames: Tensor = decoder[-10:]

# Multi-frame retrieval, index and time based.
frames = decoder.get_frames_at(indices=[10, 0, 15])
frames = decoder.get_frames_played_at(seconds=[0.2, 3, 4.5])

All decoded frames are already PyTorch tensors, ready to be fed into models for training.
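
As a small illustration, assuming the decoder’s default uint8 output in (N, C, H, W) layout, preparing a slice of frames for a model is just a dtype conversion and rescale:

# Frames come back as uint8 tensors; convert and rescale before feeding a model.
frames = decoder[:10]             # shape (10, C, H, W), dtype uint8
batch = frames.float() / 255.0    # float batch in [0, 1], ready for an nn.Module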

Of course, more common in ML training pipelines is sampling multiple clips from videos. A clip is just a sequence of frames in presentation order—but the frames are often not consecutive. Our sampling API makes this easy:

from torchcodec.samplers import clips_at_regular_timestamps

clips = clips_at_regular_timestamps(
  decoder,
  seconds_between_clip_starts=10,
  num_frames_per_clip=5,
  seconds_between_frames=0.2,
)

The above call yields a batch of clips where each clip starts 10 seconds apart, each clip has 5 frames, and those frames are 0.2 seconds apart. See our tutorials on decoding and sampling for more!

Fast Performance

Performance was our other main design principle. Decoding videos for ML training has different performance requirements than decoding videos for playback. A typical ML video training pipeline will process many different videos (sometimes in the millions!), but only sample a small number of frames (dozens to hundreds) from each video.

For this reason, we’ve paid particular attention to our decoder’s performance when seeking multiple times in a video, decoding a small number of frames after each seek. We present experiments with the following four scenarios:

  1. Decoding and transforming frames from multiple videos at once, inspired by what we have seen in data loading for large-scale training pipelines:

    a. Ten threads decode batches of 50 videos in parallel.
    b. For each video, decode 10 frames at evenly spaced times.
    c. For each frame, resize it to a 256×256 resolution.

  2. Decoding 10 frames at random locations in a single video.
  3. Decoding 10 frames at evenly spaced times of a single video.
  4. Decoding the first 100 frames of a single video.

We compare the following video decoders:

  • Torchaudio, CPU decoding only.
  • Torchvision, using the video_reader backend which is CPU decoding only.
  • Torchcodec, GPU decoding with CUDA.
  • Torchcodec, CPU decoding only.

Using the following three videos:

  1. A synthetically generated video using FFmpeg’s mandelbrot generation pattern. The video is 10 seconds long, 60 frames per second and 1920×1080.
  2. Same as above, except the video is 120 seconds long.
  3. A promotional video from NASA that is 206 seconds long, 29.7 frames per second and 960×540.

The experimental script is in our repo. Our experiments run on a Linux system with an Intel processor that has 22 available cores and an NVIDIA GPU. For CPU decoding, all libraries were instructed to automatically determine the best number of threads to use.

Benchmark chart

From our experiments, we draw several conclusions:

  • Torchcodec is consistently the best-performing library for the primary use case we designed it for: decoding many videos at once as a part of a training data loading pipeline. In particular, high-resolution videos see great gains with CUDA where decoding and transforms both happen on the GPU.
  • Torchcodec is competitive on the CPU with seek-heavy use cases such as random and uniform sampling. Currently, torchcodec’s performance is better with shorter videos that have a smaller file size. This performance is due to torchcodec’s emphasis on seek-accuracy, which involves an initial linear scan.
  • Torchcodec is not as competitive when there is no seeking; that is, opening a video file and decoding from the beginning. This is again due to our emphasis on seek-accuracy and the initial linear scan.

Implementing an approximate seeking mode in torchcodec should resolve these performance gaps, and it’s our highest priority feature for video decoding.

What’s Next?

As the name implies, the long-term future for torchcodec is more than just video decoding. Our next big feature is audio support—both decoding audio streams from video, and from audio-only media. In the long term, we want torchcodec to be the media decoding library for PyTorch. That means as we implement functionality in torchcodec, we will deprecate and eventually remove complementary features from torchaudio and torchvision.

We also have video decoding improvements lined up, such as the previously mentioned approximate seeking mode for those who are willing to sacrifice accuracy for performance.

Most importantly, we’re looking for feedback from the community! We’re most interested in working on features that the community finds valuable. Come share your needs and influence our future direction!


vLLM Joins PyTorch Ecosystem: Easy, Fast, and Cheap LLM Serving for Everyone



We’re thrilled to announce that the vLLM project has become a PyTorch ecosystem project, and joined the PyTorch ecosystem family!

Running large language models (LLMs) is both resource-intensive and complex, especially as these models scale to hundreds of billions of parameters. That’s where vLLM comes in — a high-throughput, memory-efficient inference and serving engine designed for LLMs.

Originally built around the innovative PagedAttention algorithm, vLLM has grown into a comprehensive, state-of-the-art inference engine. A thriving community is also continuously adding new features and optimizations to vLLM, including pipeline parallelism, chunked prefill, speculative decoding, and disaggregated serving.

Since its release, vLLM has garnered significant attention, achieving over 31,000 GitHub stars—a testament to its popularity and thriving community. This milestone marks an exciting chapter for vLLM as we continue to empower developers and researchers with cutting-edge tools for efficient and scalable AI deployment. Welcome to the next era of LLM inference!

vLLM has always had a strong connection with the PyTorch project. It is deeply integrated into PyTorch, leveraging it as a unified interface to support a wide array of hardware backends. These include NVIDIA GPUs, AMD GPUs, Google Cloud TPUs, Intel GPUs, Intel CPUs, Intel Gaudi HPUs, and AWS Neuron, among others. This tight coupling with PyTorch ensures seamless compatibility and performance optimization across diverse hardware platforms.

Did you know you can experience the power of vLLM right from your phone? During this year’s Amazon Prime Day, vLLM played a crucial role in delivering lightning-fast responses to millions of users. Across three regions, over 80,000 Trainium and Inferentia chips powered an average of 3 million tokens per minute, all while maintaining a P99 latency of less than 1 second for the first response. That means when customers opened the Amazon app and chatted with Rufus, they were seamlessly interacting with vLLM in action!

vLLM also collaborates tightly with leading model vendors to ensure support for popular models. This includes tight integration with Meta LLAMA, Mistral, QWen, and DeepSeek models, plus many others. One particularly memorable milestone was the release of LLAMA 3.1 (405B). As the launching partner, vLLM was the first to enable running this very large model, showcasing vLLM’s capability to handle the most complex and resource-intensive language models.

To install vLLM, simply run:

pip install vllm

vLLM is designed for both researchers and production-grade serving.

To run vLLM as an OpenAI API compatible server, just use the Huggingface model ID:

vllm serve meta-llama/Llama-3.1-8B
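
Once the server is up, it can be queried with any OpenAI-compatible client. A minimal sketch, assuming the default local port 8000 (the api_key value is just a placeholder):

from openai import OpenAI

# Point the OpenAI client at the local vLLM server (default port assumed).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B",
    prompt="The capital of France is",
    max_tokens=16,
)
print(completion.choices[0].text)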

To run vLLM as a simple function:

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
   "Hello, my name is",
   "The president of the United States is",
   "The capital of France is",
   "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="meta-llama/Llama-3.1-8B")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
   prompt = output.prompt
   generated_text = output.outputs[0].text
   print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Open-source innovation is part of vLLM’s DNA. Born out of a Berkeley academic project, it follows the legacy of other pioneering open-source initiatives such as BSD, which revolutionized operating systems in the 1980s. Other innovations from the same organization include Apache Spark and Ray, now standards for big data and AI systems. In the Gen AI era, vLLM serves as a platform dedicated to democratizing AI inference.

The vLLM team remains steadfast in its mission to keep the project “of the community, by the community, and for the community.” Collaboration and inclusivity lie at the heart of everything we do.

If you have collaboration requests or inquiries, feel free to reach out at vllm-questions@lists.berkeley.edu. To join the active and growing vLLM community, explore our GitHub repository or connect with us on the vLLM Slack. Together, we can push the boundaries of AI innovation and make it accessible to all.


Accelerating 2D Dynamic Block Quantized Float8 GEMMs in Triton


2D block quantization for Float8 (FP8) holds the promise of improving the accuracy of Float8 quantization while also accelerating GEMM’s for both inference and training. In this blog, we showcase advances using Triton for the two main phases involved in doing block quantized Float8 GEMMs.

For the incoming quantization of A and B tensors from high precision (BFloat16) to Float8, we showcase GridQuant which leverages a mini-grid stride loop style of processing with nearly 2x speedups (99.31%) over a current 2D block quantization kernel.

For the Float8 GEMM, we showcase 3 new developments for Triton – Warp Specialization, TMA and a persistent kernel to effectively create a cooperative style kernel (an alternative to the Ping-Pong schedule). As a result, we achieve ~1.2x speedup over our best-performing SplitK kernel from last year.


Figure 1: A comparison of the 2D quantization speedup over a current baseline, across a range of sizes. (lower-is-better)

Why 2D Blockwise Quantization for FP8?

Generally speaking, the accuracy of fp8 quantization improves as we move from tensor-wise scaling, to row-wise scaling, to 2D block-wise, and then finally to column-wise scaling. This is because features for a given token are stored in each column, and thus each column in that tensor is more similarly scaled.

To minimize the number of outliers in a given numerical set, we want to find commonality so that numbers are scaled in a similar fashion. For transformers, this means column-based quantization could be optimal. However, columnar memory access is massively inefficient because the data is laid out in memory in a row-wise contiguous manner; columnwise loading would require large strides in memory to pull isolated values, contrary to the core tenets of efficient memory access.

However, 2D is the next best option as it includes some aspects of columnar while being more memory efficient to pull since we can vectorize these loads with 2D vectorization. Therefore, we want to find ways to improve the speed for 2D block quantization which is why we developed the GridQuant kernel.

For the quantization process, we need to 2D block quantize both the higher precision BF16 incoming tensors (A = input activations, B = weights) and then proceed to do the Float8 matmul using the quantized tensors and their 2D block scaling values, and return an output C tensor in BF16.
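
To make the data flow concrete, below is a minimal pure-PyTorch reference sketch of the per-block quantization step. It is not the GridQuant Triton kernel; the helper name, the padding behavior, and the choice of float8_e4m3fn are illustrative assumptions.

import torch

def block_quantize_fp8(x: torch.Tensor, block: int = 256):
    """Reference (non-Triton) sketch of 2D block quantization to Float8.

    Pads the input so both dims are multiples of `block`, computes one
    absmax-based scale per (block x block) tile, and casts to float8_e4m3fn.
    Returns the quantized tensor (in the padded shape) and the per-block scales.
    """
    M, N = x.shape
    pad_m, pad_n = (-M) % block, (-N) % block
    x_p = torch.nn.functional.pad(x, (0, pad_n, 0, pad_m))
    Mb, Nb = x_p.shape[0] // block, x_p.shape[1] // block

    # View as (Mb, block, Nb, block) tiles and take the absmax per tile.
    tiles = x_p.view(Mb, block, Nb, block)
    amax = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-12)       # (Mb, Nb)

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = fp8_max / amax                                      # one multiplier per block
    x_q = (tiles * scale[:, None, :, None]).clamp(-fp8_max, fp8_max)
    x_q = x_q.view(Mb * block, Nb * block).to(torch.float8_e4m3fn)
    return x_q, scale

# Example: quantize a BF16 activation tensor A.
A = torch.randn(512, 1024, dtype=torch.bfloat16)
A_fp8, A_scales = block_quantize_fp8(A.float())

In the real kernel this work is fused and tiled as described below, but the math per 256×256 block is the same: find the block absmax, derive a scale, and cast.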

How does GridQuant improve 2D block quantization efficiency?

The GridQuant kernel has several improvements over the initial baseline quantization implementation which was a standard tile based implementation. The GridQuant kernel has two full passes through the entire input tensor and works as follows:

Phase 1 – Determine the max abs value for each 256×256 sub block from the incoming high precision tensor.

1 – We divide the BF16 tensor into 256 x 256 sub blocks. This quantization size is configurable, but 256×256 is the default as it provides a blend of quantization precision and processing efficiency.

2 – Each 256×256 sub-block is subdivided into 64 sub-blocks arranged in an 8×8 pattern, with each sub-block processing a 32×32 element block. A single warp (32 threads) handles the computation for all elements within its assigned 32×32 block.

3 – We declare a 32×32 max_vals array in shared memory. This will store the current max val for each position i,j as the 2d vector block moves across the entire 256×256 sub_block.

This is an important improvement because it means we can do vectorized, rather than scalar, updates to the max vals scoring system and allows for much more efficient updates.


Figure 2: The Fractionalized layout of an incoming tensor – a grid of 256×256 is created across the tensor, and within each 256×256 block, it is further refined into 32×32 sub blocks. A 32×32 max_vals is created for each 256×256 block.

4 – Each warp processes a 32×32 chunk and because we are using 4 warps, we ensure the Triton compiler can pipeline the memory loads for the next 32×32 chunk with the actual processing of absmax calculations for the current chunk. This ensures that the warp scheduler is able to toggle warps loading data with those processing and keep the SM continuously busy.

5 – The 32×32 2D vector block processing is moved across and through the entire 256×256 subblock in a grid stride looping fashion, with each warp updating the shared memory 32×32 max_vals against its current 32×32 sub-block. Thus max_vals[i,j] holds the latest max value as each sub block is processed.

After completing the 256×256 block grid stride loop, the maxvals matrix is then itself reduced to find the absolute single max value for that entire 256 block.

This gives us our final scaling factor value for this 2D 256 x 256 block.

Phase 2 – Quantize the 256×256 block values to Float8, by using the single max value scaling factor found during Phase 1.

Next, we make a second pass through the entire 256×256 block to rescale all the numbers using this max value found in phase 1 to convert them to the float 8 format.

Because we know we need to do 2 complete passes, for the loads during the phase 1 portion we instruct the triton compiler to keep these values in cache at higher priority (evict policy = last).

This means that during the second pass, we can get a high hit rate from the L2 cache which provides much faster memory access than going all the way to HBM.

With the 2D block quantization processing complete when all 256×256 blocks are processed, we can return the new Float8 quantized tensor along with its scaling factor matrix, which we’ll use in the next phase of the GEMM processing. This input quantization is repeated for the second input tensor as well, meaning we end up with A_Float8 and A_scaling_matrix, and B_Float8 and B_scaling_matrix.

GridQuant – GEMM Kernel

The GridQuant-GEMM kernel takes in the four outputs from the quantization above for processing. Our high-performance GEMM kernel features several new Triton developments to achieve SOTA performance for matrix shape profiles relevant in LLM inference during the decoding phase.

These new features are commonly found in Hopper optimized kernels like FlashAttention-3 and Machete, built using CUTLASS 3.x. Here, we discuss these methods and showcase the performance benefits that can be achieved leveraging them in Triton.

Tensor Memory Accelerator (TMA)

The TMA unit on NVIDIA Hopper GPUs is a dedicated hardware unit for load/store operations that act on the multidimensional tensors commonly found in AI workloads. This has several important benefits.

Transferring data from global and shared memory can occur without involving other resources on GPU SMs, freeing up registers and CUDA Cores. Further, when used in warp-specialized kernels, light-weight TMA operations can be assigned to a producer warp allowing for a high degree of overlap of memory transfers and computation.

For more details on how TMA is used in Triton see our previous blog.

Warp-Specialization (Cooperative Persistent Kernel Design)

Warp Specialization is a technique to leverage pipeline parallelism on GPUs. This experimental feature enables the expression of specialized threads through a tl.async_task API, allowing the user to specify how operations in a Triton program should be “split” amongst warps. The cooperative Triton kernel performs different types of computation and loads that each take place on their own dedicated hardware. Having dedicated hardware for each of these specialized tasks makes it possible to realize parallelism efficiently for operations that have no data dependency.


Figure 3. Logical view of dedicated HW units in NVIDIA H100 SM

The operations in our kernel that create the pipeline are:

A – Load per-block scale from GMEM into SMEM (cp.async engine)

B – Load activation (A) and Weight (B) tiles from GMEM into SMEM (TMA)

C – Matrix-Multiplication of A tile and B tile = C tile (Tensor Core)

D – Scale C tile with per-block scale from A and per-block scale from B (CUDA core)

These steps can be assigned to “tasks” which are carried out by specialized warp groups in a threadblock. The cooperative strategy has three warp groups. A producer warp group that is responsible for feeding the compute units and 2 consumer warp groups that perform the computation. The two consumer warp groups each work on half of the same output tile.


Figure 4. Warp-Specialized Persistent Cooperative kernel (source: NVIDIA)

This is different from the ping-pong schedule we discussed in our previous blog, where each consumer warp group works on a different output tile. Note that the Tensor Core ops are not overlapped with the epilogue computation. Letting the Tensor Core pipeline go idle during the epilogue phase reduces register pressure for the consumer warp groups compared to ping-pong, which keeps the Tensor Core busy at all times, and thus allows for larger tile sizes.

Lastly, our kernel is designed to be persistent when the grid size exceeds the number of available compute units on H100 GPUs (132 SMs). Persistent kernels remain active on the GPU for an extended period and compute multiple output tiles during their lifetime. Our kernel leverages TMA async shared-to-global memory stores, continuing to work on the next output tile rather than incurring the cost of scheduling multiple threadblocks.

Microbenchmarks

Figure 5: Latency comparison (us) of Gridquant-GEMM vs our best performing SplitK kernel for small batch regime and Llama3 8192 N,K sizing. (lower-is-better)

The Warp-Specialized Triton kernel achieves SOTA performance at the above small-M and square matrix shapes, achieving a nearly 1.2x speedup over the SplitK Triton kernel, which was the previous best performing strategy for Triton GEMMs in this low arithmetic intensity regime. For future work, we plan to tune our kernel performance for the medium-to-large M regime and non-square matrices.

Conclusion and Future Work

Future work includes benchmarking GridQuant on end-to-end workflows. In addition, we plan to run more extensive benchmarks on non-square (rectangular) matrices as well as medium-to-large M sizes. Finally, we plan to explore ping-pong style warp specialization in Triton versus the current cooperative implementation.

Read More

HadaCore: Tensor Core Accelerated Hadamard Transform Kernel

IBM: Krish Agarwal, Rishi Astra, Adnan Hoque, Mudhakar Srivatsa, Raghu Ganti
Meta: Less Wright, Sijia Chen

Quantization is a method for improving model inference speeds by compressing model weights and performing (faster) computation in lower precision data types. However, quantization can result in accuracy loss due to the presence of outliers. Recent works like QuaRot, SpinQuant, and FlashAttention-3 introduce methods to increase the numerical accuracy of INT4, INT8 and FP8 quantization in LLMs. These methods rely on Hadamard Transforms. In this blog, we present HadaCore, a Hadamard Transform CUDA kernel that achieves state-of-the-art performance on NVIDIA A100 and H100 GPUs. Our kernel achieves speedups of 1.1–1.4x and 1.0–1.3x, with a peak gain of 3.5x and 3.6x respectively, over Dao AI Lab’s Fast Hadamard Transform Kernel. We leverage a hardware-aware work decomposition that benefits from Tensor Core acceleration while maintaining quantization error reduction.

Figure 1: Speedup of HadaCore vs the Dao AI Hadamard CUDA kernel. A peak gain of 3.46x on the A100 is achieved using a size-128 Hadamard rotation on 8.4M elements.

The HadaCore Kernel is publicly available.

Background

QuaRot and SpinQuant both propose methods to increase the numerical accuracy of INT4 and INT8 quantization in LLMs. Both methods rotate model activations, since rotations are statistically likely to reduce the magnitude of outliers: a rotation “distributes” extreme values among other (less extreme) dimensions, and it is easily inverted using the inverse of the rotation matrix. These methods can also improve FP8 inference accuracy, such as in FlashAttention-3.

Figure 2. Transformer block showing online (red) and offline rotations (blue) in QuaRot

Applying these rotation matrices introduces model runtime overhead due to the online operations shown in Figure 2. These rotations can be applied through matrix multiplication, but the added overhead would diminish the benefits from quantization. Therefore, QuaRot and SpinQuant opt to use Walsh-Hadamard matrices, a special type of rotation matrix that can be applied faster than matrix multiplication using the Fast Walsh-Hadamard Transform algorithm. HadaCore is an optimized implementation of this algorithm for NVIDIA GPUs that support Tensor Cores.
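For reference, a straightforward (and deliberately unoptimized) PyTorch implementation of the Fast Walsh-Hadamard Transform is sketched below; HadaCore computes the same transform but maps the butterfly stages onto Tensor Core matrix multiplies.

import torch

def fwht(x: torch.Tensor) -> torch.Tensor:
    # Unnormalized Fast Walsh-Hadamard Transform over the last dimension.
    # The last dimension must be a power of two; divide by sqrt(n) to make it a rotation.
    n = x.shape[-1]
    y = x.clone()
    h = 1
    while h < n:
        y = y.reshape(*y.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :].clone(), y[..., 1, :].clone()
        y[..., 0, :] = a + b      # butterfly: pair element j with element j + h
        y[..., 1, :] = a - b
        y = y.reshape(*y.shape[:-3], n)
        h *= 2
    return y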

Tensor Core Accelerated Hadamard Transform

HadaCore leverages NVIDIA Tensor Cores, which are specialized compute units on NVIDIA GPUs optimized for matrix multiplication. To achieve this, our kernel performs a hardware-aware work decomposition of the Fast Walsh-Hadamard algorithm. This work decomposition ensures that we can utilize the MMA PTX instructions that execute on Tensor Cores. HadaCore applies a 16×16 Hadamard transform to chunks of the input data. The computation can then be offloaded to the FP16 Tensor Cores using the mma.m16n8k16 instruction. The warp-level parallelism for HadaCore is shown below.

Figure 3: HadaCore Parallelization, 1×256 vectors (rows) being rotated by a size 256 Hadamard.

We process fragments of 256 elements in parallel using warp-level Tensor Core operations to achieve up to a 256-size Hadamard transform. For larger sizes, we shuffle data between warps and repeat.
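One way to see why 16×16 matrix multiplies are sufficient: with the Sylvester construction, H_256 = H_16 ⊗ H_16, so a size-256 Hadamard transform of a vector is equivalent to two 16×16 matmuls applied to its 16×16 reshape, which is exactly the operand shape Tensor Core MMA instructions such as mma.m16n8k16 consume. The sketch below is illustrative and not necessarily the exact decomposition HadaCore uses.

import torch

H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H16 = H2
for _ in range(3):
    H16 = torch.kron(H16, H2)          # 16x16 unnormalized Hadamard matrix

def hadamard256(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 256) -> unnormalized size-256 Hadamard transform of the last dimension.
    X = x.reshape(*x.shape[:-1], 16, 16)
    return (H16 @ X @ H16).reshape(*x.shape[:-1], 256)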

Microbenchmarks

We benchmark HadaCore against the Dao AI Lab Hadamard Kernel on both NVIDIA H100 and A100 GPUs across varying Hadamard and input tensor sizes.

Figure 4: HadaCore Kernel Speedup on NVIDIA A100 over Dao AI Lab Fast Hadamard Kernel

Color coded Speedup Table for NVIDIA A100, Green = Speedup over Baseline

Figure 5: HadaCore Kernel Speedup on NVIDIA H100 over Dao AI Lab Fast Hadamard Kernel

Color coded Speedup Table for NVIDIA H100, Green = Speedup over Baseline

We showcase our speedup as the input tensor size (labeled element count in our charts) increases. Element count is the number of elements in the target matrix we are rotating. For example, in multi-head attention:

The queries (Q), keys (K) and values (V) tensors are 4D tensors of size:

(batch_size, seq_len, n_heads, head_dim)

A Hadamard matrix of size head_dim is applied to these activation tensors, so we refer to this as using a Hadamard size of head_dim with an element count of:

batch_size*seq_len*n_heads*head_dim.
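To make the bookkeeping concrete, here is the arithmetic for the Llama-2 70B prefill entry in the table below (all values taken from that table):

batch_size, seq_len, n_heads, head_dim = 1, 4096, 64, 128
element_count = batch_size * seq_len * n_heads * head_dim
print(element_count)   # 33554432 elements, rotated with a Hadamard of size head_dim = 128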

Common element counts for query rotations in an attention block:

Model | Prefill | Decoding
Llama-2 70B | 33,554,432 elements, Hadamard size 128 (1 batch × 64 heads × 4096 tokens × 128-dimensional embeddings per head per token) | 8,192 elements, Hadamard size 128 (1 batch × 64 heads × 1 token × 128-dimensional embeddings per head per token)
Llama-3 8B | 33,554,432 elements, Hadamard size 128 (1 batch × 32 heads × 8192 tokens × 128-dimensional embeddings per head per token) | 4,096 elements, Hadamard size 128 (1 batch × 32 heads × 1 token × 128-dimensional embeddings per head per token)

HadaCore achieves 1.1–1.4x speedup on A100 and 1.0–1.3x speedup on H100 over Dao AI Lab’s Fast Hadamard kernel, with a peak gain of 3.5x and 3.6x, respectively. For smaller sizes on H100, HadaCore’s gain decreases. For future work, we plan to incorporate usage of Hopper specific features like TMA and WGMMA for improved H100 performance.

MMLU Benchmarks

We evaluated MMLU scores on a Llama 3.1-8B inference workload where the FlashAttention computation was performed in FP8. Newer generation NVIDIA Hopper GPUs come equipped with FP8 Tensor Cores that deliver substantial compute gain over FP16.

Our results show the benefit of using HadaCore for accuracy preservation when combined with optimizations such as FP8 FlashAttention.

Format | Method | Llama3.1-8B Avg. 5-Shot MMLU Accuracy
Q, K, V: FP16; FlashAttention: FP16 | N/A | 65.38
Q, K, V: FP16; FlashAttention: FP8 | No Hadamard | 64.40
Q, K, V: FP8; FlashAttention: FP8 | HadaCore | 65.09
Q, K, V: FP8; FlashAttention: FP8 | Dao AI Fast Hadamard Kernel | 65.45

Table 1: MMLU scores for Llama3.1 8B with FP16 baseline and FP8 attention using Hadamard transforms, comparing an implementation with explicit Hadamard matrix multiplications vs. HadaCore (higher is better)

From the above MMLU scores, we note that for Llama3.1-8B inference with FP8 attention, HadaCore mitigates the quantization error introduced by computing attention in a lower precision.

Conclusion

We showcased our speedups achieved by moving the Fast-Walsh Hadamard algorithm into a CUDA kernel that leverages Tensor Core acceleration and achieves a peak speedup of 3.5x and 3.6x over the Dao AI Fast-Hadamard kernel on NVIDIA A100 and H100, respectively.

Further, we showed on the MMLU benchmark that rotating with HadaCore maintains similar quantization error reduction to the Fast-Hadamard kernel, while providing computational acceleration.

Future Work

We plan to implement a Triton version of our kernel and experiment with more advanced techniques such as kernel fusion to support fused Hadamard transform and quantization. Further, we plan to extend our kernel to support BF16 Tensor Core compute.

Read More

Supercharging Training using float8 and FSDP2

IBM: Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam
Meta: Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous

In this blog, we will demonstrate how we achieve up to 50% throughput speedup over FSDP1 bf16 training while maintaining loss and evaluation benchmark parity. We achieve this speedup by leveraging FSDP2, DTensor, and torch.compile with torchao’s float8 via linear layer updates (compute), and float8 all_gathers for weight communication. We showcase these improvements across a spectrum of Meta Llama model architecture sizes, ranging from a small 1.8B model all the way to a 405B model, making training faster than ever.

We demonstrate these improvements using the Meta Llama3 architecture, and then perform model quality studies at two scales: 100B tokens at 8B model size, and 50B tokens at 70B model size, which provide an exact comparison of float8 and bf16 training loss curves. We demonstrate that the float8 runs achieve loss convergence identical to their bf16 counterparts. Further, we train a 3B model to 1T tokens using the FineWeb-edu dataset and run standard evaluation benchmarks to ensure that model quality is intact and comparable to a bf16 run.

At IBM Research, we plan to adopt these capabilities for our data ablations to improve the number of experiments we can perform in a given GPU budget. Longer term, we will follow up with a larger scale model run to demonstrate the end-to-end feasibility of float8 training.

What is Float8?

The float8 format for training models was introduced by NVIDIA, ARM, and Intel in a 2022 paper which demonstrated the feasibility of training using lower precision float8, without sacrificing model quality. With the introduction of newer GPUs like the NVIDIA Hopper series, FP8 training became feasible with the potential of more than 2x improvement in training throughput due to native float8 tensor core support. There are a few challenges to realize this promise:
(i) Enable the core model operations like matmul and attention in float8,
(ii) Enable float8 training in a distributed framework, and
(iii) Enable weight communication between GPUs in float8.
While the float8 matmul was enabled by NVIDIA libraries, the latter two were provided in recent updates to FSDP2 and torchao.

In this blog, we are using torchtitan as the entry point for training, IBM’s deterministic data loader, the float8 linear layer implementation from torchao, and the float8 all gather from the latest PyTorch nightlies in conjunction with FSDP2. For this training, we are using the float8 per tensor (tensorwise) scaling granularity rather than rowwise. We leverage torch.compile to ensure that we get maximum performance gains. We are computing attention in bf16 using SDPA and are currently working on moving this to float8 as well.
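As a rough illustration of what tensorwise scaling means, the sketch below quantizes a tensor with a single scale covering all of its elements. This is only conceptual; in the actual training stack, torchao’s float8 linear layers, torch.compile, and FSDP2 handle the quantization, the scaled float8 matmuls, and the float8 all_gather.

import torch

def to_float8_tensorwise(x: torch.Tensor):
    # One scale for the entire tensor (tensorwise granularity), as opposed to rowwise scaling.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = fp8_max / x.abs().amax().clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()   # fp8 tensor and its dequantization scale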

Experiments

We perform various experiments to demonstrate the benefits of float8 training. The first is to ensure that model quality is not sacrificed. To verify this, we train an 8B model and 70B model for a few thousand steps and compare the loss curves between both the float8 and bf16 training run. Our experiments are performed on three different H100 clusters with 128, 256, and 512 H100 GPU configurations in very different environments to demonstrate reproducibility. The first cluster is customized on Grand Teton in Meta with 400Gbps custom interconnect, the second is an IBM research cluster with 3.2Tbps Infiniband interconnect, and the third is an IBM Cloud cluster with 3.2Tbps RoCE interconnect for GPU-to-GPU communication.

First, we plot the loss curve comparisons for both these models in the below figures to demonstrate loss parity for a few thousand steps.

Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps

We observe that across these different models and in different environments, we obtain loss parity for the small scale of tokens. Next, we characterize the throughput gains for four different model sizes ranging from 1.8B to 405B. We explored the best batch size and activation checkpointing schemes for both the float8 and bf16 training runs to determine the tokens/sec/GPU (wps) metric and report the performance gain. For the 405B model, we leveraged DTensor for tensor parallel training with FSDP2. We use a sequence length of 8K for all our measurements.

Model size wps (bf16) wps (float8) Percent gain
1.8B 29K 35K 18%
8B 8K 10K 28%
70B 956 1430 50%
405B (TP4) 149 227 52%

Table 1: Performance gains over bf16 (both bf16 and float8 use torch.compile)

We observe from Table 1 that the gains for the larger models (70B and 405B) reach up to 50%, while the smaller models see gains between roughly 20 and 30%. In further experiments, we observed that the addition of float8 all_gather enables a boost of ~5% beyond the float8 compute itself, in line with the observations in this blog.

Second, to demonstrate the effectiveness of an FP8 model, we trained a 3B model following the Llama3 architecture for 1T tokens using the FineWeb-edu dataset from Hugging Face. We performed evaluations using the lm-eval-harness framework and present a small portion of these results in the table below. We observe that the bf16 performance is marginally better than the float8 scores (about one percent). While some scores are significantly better with bf16 (e.g., MMLU is 3 pts higher), we expect these gaps to vanish when the right hyperparameters are chosen and across larger scale training runs (e.g., the bf16 run had half the batch size, and it is well known that smaller batch size runs can improve evaluation scores).

Benchmark Score (float8) Score (bf16)
MMLU (5-shot) 0.26 0.29
ARC-e 0.73 0.73
ARC-c 0.43 0.46
Hellaswag 0.65 0.67
sciq 0.89 0.88
OpenBook QA 0.43 0.43
PIQA 0.76 0.76
Winogrande 0.60 0.65
Average 0.59 0.60

Table 2: Benchmark scores for float8 trained model running in FP16 for eval (at 1T tokens of FineWeb pre-training).

Finally, we scale our experiments to 512 H100 GPUs on the IBM Cloud cluster. We were able to recreate the results and speedups that we observed even at 512 GPU scale. We summarize these results only for the large models in the below table (70B and 405B).

Model size wps (bf16) wps (float8) Percent gain
70B 960 1448 51%
405B (TP4) 152 217 43%

Table 3: Performance gains over bf16 (both bf16 and float8 use torch.compile) for 512 GPU scale

Future work

We are working on evaluating other forms of parallelism, such as Context Parallelism. We plan to evaluate all of these features to demonstrate their composability and the ability to make choices when training large scale models.

Acknowledgements

We thank Davis Wertheimer from IBM Research for enabling the data loader for torchtitan runs, allowing us to replay data in the same order across multiple runs. We also thank IBM Cloud for providing early test access to the H100 cluster.

Read More

Rebellions Joins the PyTorch Foundation as a General Member

Rebellions logo

The PyTorch Foundation, a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem, is announcing today that Rebellions has joined as a general member.

Rebellions is a South Korea-based semiconductor company specializing in the design and development of AI chips for data centers and edge devices. Their innovative hardware and software solutions aim to accelerate generative AI and machine learning workloads, focusing on high energy efficiency and performance. The company successfully launched and deployed its AI chip ‘ATOM’ targeting data centers in 2023 and is developing its next-generation AI accelerator ‘REBEL’.

“We’re thrilled to welcome Rebellions as a new general member of the PyTorch Foundation,” said Matt White, Executive Director of the PyTorch Foundation. “Rebellions brings a unique perspective to the PyTorch ecosystem with their focus on advancing the integration of NPU architectures for AI acceleration with PyTorch. Their expertise will play a vital role in ensuring PyTorch continues to evolve as a versatile framework, accommodating the diverse needs of modern AI workloads. We look forward to collaborating with Rebellions to drive innovation and strengthen the PyTorch ecosystem for developers worldwide.”

Rebellions has introduced native support for PyTorch 2.0 in their RBLN SDK. This integration includes compatibility with torch.compile, a pivotal feature of PyTorch 2.0 that enhances model performance. Through this development, Rebellions has empowered developers to seamlessly harness the full potential of their AI accelerator lineup within the PyTorch environment.

Rebellions is also deeply committed to advancing the PyTorch ecosystem through collaborative innovation starting in Korea. The company has established a Special Interest Group (SIG) focusing on PyTorch Core within the PyTorch Korea community and is actively working with volunteers recruited through MODULABS, an open research institute, to integrate native support for the deep learning framework into their Neural Processing Unit (NPU).

In addition, Rebellions is collaborating with academic institutions, such as Yonsei University, Hanyang University, and the University of Science & Technology (UST), and national agencies, such as the Electronics and Telecommunications Research Institute (ETRI), to offer undergraduate and graduate courses on PyTorch and enable them to leverage PyTorch as their research platform.

These initiatives highlight Rebellions’ dedication to optimizing the PyTorch experience for developers and researchers alike, while also fostering education and innovation in the field.

“By integrating our hardware innovations with PyTorch, we’re building native NPU support to accelerate diverse AI workloads,” said Hong-seok Kim, the Chief Software Architect at Rebellions. “We’re excited to contribute to the PyTorch community through community-driven initiatives and partnerships, advancing NPU architecture support for next-generation AI solutions. Together with the PyTorch community, we aim to pioneer new possibilities in AI acceleration and empower developers worldwide with efficient computing solutions.”

To learn more about how your organization can be a part of the PyTorch Foundation, visit our website.

About Rebellions

Rebellions is a South Korea-based semiconductor company specializing in the design and development of AI chips for data centers and edge devices. Their innovative hardware and software solutions aim to accelerate generative AI and machine learning workloads, focusing on high energy efficiency and performance. The company successfully launched and deployed its AI chip ‘ATOM’ targeting data centers in 2023 and is developing its next-generation AI accelerator ‘REBEL’ incorporating a scalable chiplet architecture and high-bandwidth memory.

About PyTorch Foundation

The PyTorch Foundation is a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem. The PyTorch Foundation is supported by its members and leading contributors to the PyTorch open source project. The Foundation leverages resources provided by members and contributors to enable community discussions and collaboration.

About The Linux Foundation

The Linux Foundation is the world’s leading home for collaboration on open source software, hardware, standards, and data. Linux Foundation projects are critical to the world’s infrastructure including Linux, Kubernetes, Node.js, ONAP, PyTorch, RISC-V, SPDX, OpenChain, and more. The Linux Foundation focuses on leveraging best practices and addressing the needs of contributors, users, and solution providers to create sustainable models for open collaboration. For more information, please visit us at linuxfoundation.org.

Read More

Distilling Llama3.1 8B into 1B in torchtune

In this blog, we present a case study on distilling a Llama 3.1 8B model into Llama 3.2 1B using torchtune’s knowledge distillation recipe. We demonstrate how knowledge distillation (KD) can be used in post-training to improve instruction-following task performance and showcase how users can leverage the recipe.

What is Knowledge Distillation?

Knowledge Distillation is a widely used compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. Larger models have more parameters and capacity for knowledge; however, this larger capacity is also more computationally expensive to deploy. Knowledge distillation can be used to compress the knowledge of a larger model into a smaller one. The idea is that a smaller model’s performance can be improved by learning from the larger model’s outputs.

How does Knowledge Distillation work?

Knowledge is transferred from the teacher to the student model by training on a transfer set, where the student is trained to imitate the token-level probability distributions of the teacher. The assumption is that the teacher model’s distribution is similar to that of the transfer dataset. The diagram below is a simplified representation of how KD works.

Figure 1: Simplified representation of knowledge transfer from teacher to student model

Knowledge distillation for LLMs is an active area of research, with papers such as MiniLLM, DistiLLM, AKL, and Generalized KD investigating different loss approaches. In this case study, we focus on the standard cross-entropy (CE) loss combined with the forward Kullback-Leibler (KL) divergence loss as the baseline. Forward KL divergence minimizes the difference by forcing the student’s distribution to align with the entire teacher distribution.
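For concreteness, a compact sketch of this baseline forward KL loss is shown below. torchtune ships its own ForwardKLLoss; the function name and masking convention here are only illustrative.

import torch
import torch.nn.functional as F

def forward_kl(student_logits, teacher_logits, labels, ignore_index=-100):
    # KL(teacher || student): the student is pushed to cover the entire teacher distribution.
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    student_logp = F.log_softmax(student_logits, dim=-1)
    kl_per_token = (teacher_logp.exp() * (teacher_logp - student_logp)).sum(dim=-1)
    mask = labels != ignore_index                      # skip padded / prompt tokens
    return (kl_per_token * mask).sum() / mask.sum().clamp(min=1)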

Why is Knowledge Distillation useful?

The idea of knowledge distillation is that a smaller model can achieve better performance using a teacher model’s outputs as an additional signal than it could by training from scratch or with supervised fine-tuning alone. For instance, the Llama 3.2 lightweight 1B and 3B text models incorporated logits from Llama 3.1 8B and 70B to recover performance after pruning. In addition, for fine-tuning on instruction-following tasks, research in LLM distillation demonstrates that knowledge distillation methods can outperform supervised fine-tuning (SFT) alone.

Model | Method | DollyEval (GPT-4 Eval) | Self-Inst (GPT-4 Eval) | S-NI (Rouge-L)
Llama 7B | SFT | 73.0 | 69.2 | 32.4
Llama 7B | KD | 73.7 | 70.5 | 33.7
Llama 7B | MiniLLM | 76.4 | 73.1 | 35.5
Llama 1.1B | SFT | 22.1 | – | 27.8
Llama 1.1B | KD | 22.2 | – | 28.1
Llama 1.1B | AKL | 24.4 | – | 31.4
OpenLlama 3B | SFT | 47.3 | 41.7 | 29.3
OpenLlama 3B | KD | 44.9 | 42.1 | 27.9
OpenLlama 3B | SeqKD | 48.1 | 46.0 | 29.1
OpenLlama 3B | DistiLLM | 59.9 | 53.3 | 37.6

Table 1: Comparison of knowledge distillation approaches to supervised fine-tuning

Below is a simplified example of how knowledge distillation differs from supervised fine-tuning.

Supervised fine-tuning:

model = llama3_2_1b()
ce_loss = CrossEntropyLoss()

tokens, labels = batch["tokens"], batch["labels"]
logits = model(tokens, ...)

loss = ce_loss(logits, labels)
loss.backward()

Knowledge distillation:

model = llama3_2_1b()
teacher_model = llama3_1_8b()
ce_loss = CrossEntropyLoss()
kd_loss = ForwardKLLoss()

tokens, labels = batch["tokens"], batch["labels"]
logits = model(tokens, ...)
teacher_logits = teacher_model(tokens, ...)
loss = ce_loss(logits, labels) + kd_loss(logits, teacher_logits, labels)
loss.backward()

KD recipe in torchtune

With torchtune, we can easily apply knowledge distillation to Llama3, as well as other LLM model families, using torchtune’s KD recipe. The objective for this recipe is to fine-tune Llama3.2-1B on the Alpaca instruction-following dataset by distilling from Llama3.1-8B. This recipe focuses on post-training and assumes the teacher and student models have already been pre-trained.

First, we have to download the model weights. To be consistent with other torchtune fine-tuning configs, we will use the instruction tuned models of Llama3.1-8B as teacher and Llama3.2-1B as student.

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

In order for the teacher model distribution to be similar to the Alpaca dataset, we will fine-tune the teacher model using LoRA. Based on our experiments, shown in the next section, we’ve found that KD performs better when the teacher model is already fine-tuned on the target dataset.

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

Finally, we can run the following command to distill the fine-tuned 8B model into the 1B model on a single GPU. For this case study, we used a single A100 80GB GPU. We also have a distributed recipe for running on multiple devices.

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

Ablation studies

In this section, we demonstrate how changing configurations and hyperparameters can affect performance. By default, our configuration uses the LoRA fine-tuned 8B teacher model, downloaded 1B student model, learning rate of 3e-4 and KD loss ratio of 0.5. For this case study, we fine-tuned on the alpaca_cleaned_dataset and evaluated the models on truthfulqa_mc2, hellaswag and commonsense_qa tasks through the EleutherAI LM evaluation harness. Let’s take a look at the effects of:

  1. Using a fine-tuned teacher model
  2. Using a fine-tuned student model
  3. Hyperparameter tuning of KD loss ratio and learning rate

Using a fine-tuned teacher model

The default settings in the config use the fine-tuned teacher model. Now, let’s take a look at the effects of not fine-tuning the teacher model first.

Taking a look at the losses, using the baseline 8B as teacher results in a higher loss than using the fine-tuned teacher model. The KD loss also remains relatively constant, suggesting that the teacher model should have the same distribution as the transfer dataset.

Figure 2: (left to right) KD loss from forward KL divergence, class loss from cross entropy, total loss: even combination of KD and class loss.

In our benchmarks, we can see that supervised fine-tuning of the 1B model achieves better accuracy than the baseline 1B model. By using the fine-tuned 8B teacher model, we see comparable results for truthfulqa and improvement for hellaswag and commonsense. When using the baseline 8B as a teacher, we see improvement across all metrics, but lower than the other configurations.

Model | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc)
Baseline Llama 3.1 8B | 0.5401 | 0.5911 | 0.7915 | 0.7707
Fine-tuned Llama 3.1 8B using LoRA | 0.5475 | 0.6031 | 0.7951 | 0.7789
Baseline Llama 3.2 1B | 0.4384 | 0.4517 | 0.6064 | 0.5536
Fine-tuned Llama 3.2 1B using LoRA | 0.4492 | 0.4595 | 0.6132 | 0.5528
KD using baseline 8B as teacher | 0.444 | 0.4576 | 0.6123 | 0.5561
KD using fine-tuned 8B as teacher | 0.4481 | 0.4603 | 0.6157 | 0.5569

Table 2: Comparison between using baseline and fine-tuned 8B as teacher model

Using a fine-tuned student model

For these experiments, we look at the effects of KD when the student model is already fine-tuned. We analyze the effects using different combinations of baseline and fine-tuned 8B and 1B models.

Based on the loss graphs, using a fine-tuned teacher model results in a lower loss irrespective of whether the student model is fine-tuned or not. It’s also interesting to note that the class loss starts to increase when using a fine-tuned student model.

Figure 3: Comparing losses of different teacher and student model initializations

Using the fine-tuned student model boosts accuracy even further for truthfulqa, but the accuracy drops for hellaswag and commonsense. Using a fine-tuned teacher model and baseline student model achieved the best results on the hellaswag and commonsense datasets. Based on these findings, the best configuration will change depending on which evaluation dataset and metric you are optimizing for.

Model | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc)
Baseline Llama 3.1 8B | 0.5401 | 0.5911 | 0.7915 | 0.7707
Fine-tuned Llama 3.1 8B using LoRA | 0.5475 | 0.6031 | 0.7951 | 0.7789
Baseline Llama 3.2 1B | 0.4384 | 0.4517 | 0.6064 | 0.5536
Fine-tuned Llama 3.2 1B using LoRA | 0.4492 | 0.4595 | 0.6132 | 0.5528
KD using baseline 8B and baseline 1B | 0.444 | 0.4576 | 0.6123 | 0.5561
KD using baseline 8B and fine-tuned 1B | 0.4508 | 0.448 | 0.6004 | 0.5274
KD using fine-tuned 8B and baseline 1B | 0.4481 | 0.4603 | 0.6157 | 0.5569
KD using fine-tuned 8B and fine-tuned 1B | 0.4713 | 0.4512 | 0.599 | 0.5233

Table 3: Comparison using baseline and fine-tuned teacher and student models

Hyperparameter tuning: learning rate

By default, the recipe has a learning rate of 3e-4. For these experiments, we changed the learning rate from as high as 1e-3 to as low as 1e-5.

Based on the loss graphs, all learning rates result in similar losses except for 1e-5, which has a higher KD and class loss.

Figure 4: Comparing losses of different learning rates

Based on our benchmarks, the optimal learning rate changes depending on which metric and tasks you are optimizing for.

Model | learning rate | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc)
Baseline Llama 3.1 8B | – | 0.5401 | 0.5911 | 0.7915 | 0.7707
Fine-tuned Llama 3.1 8B using LoRA | – | 0.5475 | 0.6031 | 0.7951 | 0.7789
Baseline Llama 3.2 1B | – | 0.4384 | 0.4517 | 0.6064 | 0.5536
Fine-tuned Llama 3.2 1B using LoRA | – | 0.4492 | 0.4595 | 0.6132 | 0.5528
KD using fine-tuned 8B and baseline 1B | 3e-4 | 0.4481 | 0.4603 | 0.6157 | 0.5569
KD using fine-tuned 8B and baseline 1B | 1e-3 | 0.4453 | 0.4535 | 0.6071 | 0.5258
KD using fine-tuned 8B and baseline 1B | 1e-4 | 0.4489 | 0.4606 | 0.6156 | 0.5586
KD using fine-tuned 8B and baseline 1B | 1e-5 | 0.4547 | 0.4548 | 0.6114 | 0.5487

Table 4: Effects of tuning learning rate

Hyperparameter tuning: KD ratio

By default, the KD ratio is set to 0.5, which gives even weighting to both the class and KD loss. In these experiments, we look at the effects of different KD ratios, where a ratio of 0 uses only the class loss and a ratio of 1 uses only the KD loss (i.e., total loss = (1 − kd_ratio) × class loss + kd_ratio × KD loss).

Overall, the benchmark results show that for these tasks and metrics, higher KD ratios perform slightly better.

Model | kd_ratio (lr=3e-4) | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc)
Baseline Llama 3.1 8B | – | 0.5401 | 0.5911 | 0.7915 | 0.7707
Fine-tuned Llama 3.1 8B using LoRA | – | 0.5475 | 0.6031 | 0.7951 | 0.7789
Baseline Llama 3.2 1B | – | 0.4384 | 0.4517 | 0.6064 | 0.5536
Fine-tuned Llama 3.2 1B using LoRA | – | 0.4492 | 0.4595 | 0.6132 | 0.5528
KD using fine-tuned 8B and baseline 1B | 0.25 | 0.4485 | 0.4595 | 0.6155 | 0.5602
KD using fine-tuned 8B and baseline 1B | 0.5 | 0.4481 | 0.4603 | 0.6157 | 0.5569
KD using fine-tuned 8B and baseline 1B | 0.75 | 0.4543 | 0.463 | 0.6189 | 0.5643
KD using fine-tuned 8B and baseline 1B | 1.0 | 0.4537 | 0.4641 | 0.6177 | 0.5717

Table 5: Effects of tuning KD ratio

Looking Ahead

In this blog, we presented a study on how to distill LLMs through torchtune using the forward KL divergence loss on Llama 3.1 8B and Llama 3.2 1B logits. There are many directions for future exploration to further improve performance and offer more flexibility in distillation methods.

  • Expand KD loss offerings. The KD recipe uses the forward KL divergence loss. However, aligning the student distribution to the whole teacher distribution may not be effective, as mentioned above. There are multiple papers, such as MiniLLM, DistiLLM, and Generalized KD, that introduce new KD losses and policies to address this limitation and have been shown to outperform the standard use of cross entropy with forward KL divergence loss. For instance, MiniLLM uses reverse KL divergence to prevent the student from over-estimating low-probability regions of the teacher. DistiLLM introduces a skewed KL loss and an adaptive training policy.
  • Enable cross-tokenizer distillation. The current recipe requires the teacher and student model to use the same tokenizer, which limits the ability to distill across different LLM families. There has been research on cross-tokenizer approaches (e.g. Universal Logit Distillation) that we could explore.
  • Expand distillation to multimodal LLMs and encoder models. A natural extension of the KD recipe is to expand to multimodal LLMs. Similar to deploying more efficient LLMs, there’s also a need to deploy smaller and more efficient multimodal LLMs. In addition, there has been work in demonstrating LLMs as encoder models (e.g. LLM2Vec). Distillation from LLMs as encoders to smaller encoder models may also be a promising direction to explore.

Read More