PyTorch compile to speed up inference on Llama 2


In this blog, we discuss how to improve the inference latencies of the Llama 2 family of models using PyTorch native optimizations: native fast kernels, compile transformations from torch compile, and tensor parallel for distributed inference. Our approach results in 29ms/token latency for single user requests on the 70B Llama 2 model (as measured on 8 A100 GPUs). We are excited to share our findings with the community and make our code available here.


We are in the midst of a generative AI revolution, with large language models of tens of billions of parameters becoming commoditized and available for use. However, it is well recognized in the community that deploying these large models cost-efficiently remains a key challenge. Many different approaches have been attempted, with varying degrees of success and different trade-offs. Hardware-specific optimizations (e.g., NVIDIA's FasterTransformer) are restricted to specific target hardware, whereas approaches that rely on layers of abstraction (e.g., ONNX) support arbitrary models but lose efficiency. With the introduction of PyTorch compile last year, IBM and the PyTorch team started exploring the use of model compilation for inference optimizations, with the goal of reducing the latency per token for generative models.

Model Choice

We chose to benchmark the Llama 2 family of models, given their popularity. The models we are interested in, and their hyperparameters relevant to this blog, are given in the table below:

Model size | Hidden dimension | Num heads | Num layers | Attention type
7B         | 4096             | 32        | 32         | MHA
13B        | 5120             | 40        | 40         | MHA
70B        | 8192             | 64        | 80         | GQA

(MHA: multi-head attention; GQA: grouped-query attention)

These are decoder-only models, which means tokens are generated one at a time, each conditioned on all previous tokens; this sequential generation is typically sped up using KV caching. We take the same approach in our latency and throughput measurements.
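To make the role of the KV cache concrete, here is a minimal pure-Python sketch of single-head attention during decoding. It assumes the per-token query, key, and value vectors are already given (in a real model they come from linear projections of each layer's hidden states). The cache stores the keys and values of earlier tokens, so each new step only appends one entry instead of recomputing the whole prefix, yet produces the same outputs as recomputing causal attention from scratch. The function and class names are illustrative, not from the post's code.

```python
import math


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[j] for w, v in zip(weights, values)) / total
        for j in range(len(values[0]))
    ]


class KVCache:
    """Keeps keys/values of already-processed tokens; each decode step adds one."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, q, k, v):
        # Append the new token's key/value, then attend over the whole cached prefix.
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)


def decode_without_cache(qs, ks, vs):
    # Reference: recompute causal attention over the full prefix at every position.
    return [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(qs))]


def decode_with_cache(qs, ks, vs):
    cache = KVCache()
    return [cache.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
```

Both decoders return identical outputs; the cached version simply avoids redoing work for tokens it has already seen.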

Inference Approach

Our goal for inference is to provide a path for achieving the best possible latencies rapidly, to keep up with the velocity with which new model architectures are emerging in the community. A PyTorch native approach is appealing as it allows for the maximum flexibility in terms of “coverage” of models. We note that there are four orthogonal techniques that provide acceleration in inference: (a) Kernel fusion using compile, (b) Faster kernels, (c) Tensor parallel for larger models, and (d) Quantization. In our approach, we use the first three of these four levers: compile working natively with faster kernels from SDPA and a custom tensor parallel implementation, all working hand-in-glove to achieve inference latencies of 29ms/token on a 70B model as measured on 8 NVIDIA A100 GPUs with a single user.

Compile all the way!

PyTorch compile leverages tracing and graph capture to reduce CPU overhead and, in an ideal scenario, results in a single graph execution/instruction from CPU to GPU. In practice, however, compile often introduces graph breaks because of the model architecture or ops that compile does not support. For example, complex operations such as einops are not supported by compile today. Similarly, tensor parallel inference can introduce graph breaks at each layer, since compile requires the tensor parallel implementation to use traceable communication collectives. If these graph breaks are not removed, the compiled artifacts will perform poorly, possibly even worse than eager mode execution. To get the full benefit of the compiled artifacts, the graph breaks need to be removed.

Below, we describe how we did this for the 70B Llama 2 model and the challenges we had to overcome to get compile to work all the way through.

Our first attempt was to compile the out-of-the-box Llama 2 model with torch.compile, but it failed because complex ops were not supported. Using TORCH_COMPILE_DEBUG=1, we identified that the RoPE positional encoding used complex-number functions, resulting in graph breaks and significant slowdowns. We rewrote the RoPE function to bypass torch.einsum (the original implementation uses torch.polar, which also conflicts with compile) and to use torch.cos and torch.sin instead.

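# Store a 2x2 rotation matrix [[cos, -sin], [sin, cos]] for every (position, frequency) pair
self.cached_freqs[dev_idx][alpha] = torch.stack(
    [
        torch.cos(freqs),
        -torch.sin(freqs),
        torch.sin(freqs),
        torch.cos(freqs),
    ],
    dim=2,
).view(*freqs.shape, 2, 2)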

Our implementation of the frequencies computation

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

Hugging Face implementation of the frequencies computation
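Why is the cos/sin form a drop-in replacement for the complex-number form? Multiplying a channel pair (a, b), viewed as the complex number a + ib, by e^{i·pos·θ} is exactly a 2x2 rotation by the angle pos·θ. The pure-Python sketch below checks this equivalence. It assumes adjacent channels are paired (as in the original Llama reference code) and the usual base of 10000; the Hugging Face version pairs the first and second halves of the head dimension instead, which, as its comment notes, yields the same calculation under a different permutation.

```python
import cmath
import math


def inv_freqs(head_dim, base=10000.0):
    # One rotation frequency per channel pair: theta_i = base ** (-2i / head_dim)
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def rope_complex(x, pos, freqs):
    # Complex-number form (what torch.polar-based code does): multiply by e^{i * pos * theta}
    out = []
    for i, theta in enumerate(freqs):
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * pos * theta)
        out += [z.real, z.imag]
    return out


def rope_cos_sin(x, pos, freqs):
    # Real-valued form: apply the rotation [[cos, -sin], [sin, cos]] to each pair
    out = []
    for i, theta in enumerate(freqs):
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out += [c * a - s * b, s * a + c * b]
    return out
```

The real-valued version only needs cos, sin, multiply, and add, which is what lets compile trace it without falling back to complex-number ops.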

Once RoPE was fixed, we were able to get 7B and 13B models to compile without ANY graph breaks on a single A100 GPU.

We used SDPA, the PyTorch native implementation of efficient attention computation, which supports tracing (and therefore compile). The recommended way to force a single attention algorithm is a Python context manager, but it caused graph breaks, so we used the torch.backends.cuda.enable_*_sdp functions instead.

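# queries, keys, values: (batch, n_heads, seq_len, head_dim); variable names are illustrative
attn = torch.nn.functional.scaled_dot_product_attention(
    queries,
    keys,
    values,
    attn_mask=attn_mask,
    dropout_p=self.p_dropout if self.training else 0.0,  # no dropout at inference time
)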

Attention computation using SDPA

Next, we ran the same steps for the larger 70B model and found that even in half precision, the model does not fit on a single GPU (70B parameters at 2 bytes each is roughly 140 GB of weights, versus 80 GB of memory on an A100), so it requires tensor parallel inference. Using torch.compile on the 70B model resulted in 162 graph breaks: two all-reduces in each of the 80 layers, one all-gather for the forward embedding, and one all-gather for the reverse embedding (2 × 80 + 2 = 162). Because of this, we saw no significant improvement in inference latencies. We could not use PyTorch's distributed tensor implementation, as it did not support compile at the time of writing. Instead, we rewrote the tensor parallel code from scratch so that it depends only on traceable collectives and therefore works with compile. After this last change, the PyTorch compiler did not introduce any graph breaks, and we saw a significant speedup in inference latencies. Specifically, we measured a latency of 29ms/token for the Llama 2 70B model on 8 A100 GPUs, a 2.4x improvement over unoptimized inference.
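The post does not spell out its exact sharding scheme, so the sketch below assumes a common Megatron-style split to show where the per-layer all-reduces come from. Each simulated rank holds a slice of the first weight matrix's output rows (column parallel) and the matching input columns of the second (row parallel); the elementwise nonlinearity needs no communication, and a single all-reduce (here a plain elementwise sum standing in for a traceable collective) reconstructs the full output. The same pattern in the attention block gives the second all-reduce per layer. All names are hypothetical.

```python
def matvec(w, x):
    # w is a list of rows (out_features x in_features)
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def relu(v):
    return [max(0.0, a) for a in v]


def all_reduce_sum(partials):
    # Stand-in for an all-reduce collective: elementwise sum of every rank's partial result
    return [sum(vals) for vals in zip(*partials)]


def mlp_reference(w1, w2, x):
    return matvec(w2, relu(matvec(w1, x)))


def mlp_tensor_parallel(w1, w2, x, world_size):
    hidden = len(w1)
    assert hidden % world_size == 0, "hidden size must divide evenly across ranks"
    shard = hidden // world_size
    partials = []
    for rank in range(world_size):  # each iteration plays the role of one GPU
        lo, hi = rank * shard, (rank + 1) * shard
        h = relu(matvec(w1[lo:hi], x))        # column parallel: this rank's hidden units only
        w2_shard = [row[lo:hi] for row in w2]  # row parallel: matching input slice of w2
        partials.append(matvec(w2_shard, h))   # partial sum of the full output
    return all_reduce_sum(partials)            # the only communication in this block


def collectives_per_forward(n_layers):
    # From the post: 2 all-reduces per layer plus 2 embedding all-gathers
    return 2 * n_layers + 2
```

For the 80-layer 70B model, `collectives_per_forward(80)` gives the 162 collectives that each caused a graph break before the rewrite.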

Serving aspects

Finally, note that simply compiling a model is not sufficient to serve it in a production setting. To realize the above performance with high throughput, we need to support dynamic batching and nested tensors, as well as a warm-up phase in which we pre-compile for bucketized sequence lengths. We are working on these aspects to realize such performance in production.
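Bucketizing sequence lengths means padding each request up to one of a small, fixed set of lengths, so a warm-up pass can pre-compile one graph per bucket and later requests reuse those graphs instead of triggering recompilation. The pure-Python sketch below shows only the bucketing logic; the bucket sizes are hypothetical, and the left padding with an attention mask is an assumption that suits decoder-only generation, not something the post specifies.

```python
import bisect

# Hypothetical bucket boundaries; choose them to match your traffic.
BUCKETS = (128, 256, 512, 1024, 2048, 4096)


def bucket_for(seq_len, buckets=BUCKETS):
    """Smallest bucket that can hold seq_len tokens."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(f"length {seq_len} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(token_ids, pad_id=0, buckets=BUCKETS):
    """Left-pad token_ids to its bucket length; return (padded_ids, attention_mask)."""
    target = bucket_for(len(token_ids), buckets)
    n_pad = target - len(token_ids)
    padded = [pad_id] * n_pad + list(token_ids)
    mask = [0] * n_pad + [1] * len(token_ids)  # 0 marks padding the model must ignore
    return padded, mask
```

A warm-up phase would then run one dummy forward pass per entry in `BUCKETS` at startup, so every shape seen in production has already been compiled.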

Experiments and Measurements

We use nodes with 8 NVIDIA A100 80GB GPUs for all our measurements, in two different environments (IBM Cloud and AWS, both running OpenShift). First, we compare the various techniques: eager mode, eager mode with the SDPA Flash kernel, compile, and compile with SDPA. For the 70B model, we run it in tensor parallel mode with compile and SDPA. For this experiment, we use an input length of 512 tokens and generate 50 tokens. For the 7B and 13B models, we measure latencies on a single A100, whereas we use 8 A100s for the 70B model. In addition, for the 70B model we use the reduce-overhead mode of PyTorch compile, which uses CUDA Graphs to reduce CPU-to-GPU kernel launch overheads; CUDA Graphs did not show any benefit for the 7B and 13B models, so those results are not reported here. Figure 1 shows that compile and SDPA together provide very low latencies, with the 70B Llama 2 model at 29ms/token.
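The figures report median latency per token. The post does not describe its timing harness, so the sketch below is only one plausible way to compute that metric: record a timestamp before decoding and after each generated token, take consecutive differences, and report the median, which is robust to the occasional slow step. `step_fn` is a hypothetical callable that generates one token; on a GPU it should synchronize (for example with torch.cuda.synchronize()) before returning, otherwise the timer only measures asynchronous kernel launches.

```python
import statistics
import time


def per_token_latencies_ms(timestamps_s):
    """timestamps_s[0] marks the start of decoding; timestamps_s[i] marks token i."""
    return [(b - a) * 1000.0 for a, b in zip(timestamps_s, timestamps_s[1:])]


def median_token_latency_ms(timestamps_s):
    return statistics.median(per_token_latencies_ms(timestamps_s))


def time_decode(step_fn, n_tokens, clock=time.perf_counter):
    """Call step_fn once per generated token and record a timestamp after each call."""
    stamps = [clock()]
    for _ in range(n_tokens):
        step_fn()
        stamps.append(clock())
    return stamps
```

For example, `median_token_latency_ms(time_decode(generate_one_token, 50))` would match the 50-token generation used in this experiment, with `generate_one_token` standing in for your decode step.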

Figure 1. Median latency across different techniques with sequence length 512 (measured on IBM Cloud A100 servers)


Next, we examine the impact of sequence length: as we increase it from 1024 to 4096, the median latency per token increases sub-linearly, which suggests that extending the context to large documents does not come at a proportional cost in response time.

Figure 2. Median latency for compile+SDPA with different sequence lengths (Measured on A100s on AWS)


Finally, as we increase the batch size, response latencies increase sub-linearly. For the 13B model, we run out of memory (OOM) at batch size 8. The 70B model, which runs on 8 GPUs with tensor parallel, does not hit any such OOM issues.

Figure 3. Median latency for compile+SDPA with different batch sizes and sequence length fixed at 4096 (Measured on A100s on AWS)


Final Thoughts

We have shown that a PyTorch compile pathway for inference achieves ultra-low latencies for 70B model inference. The next steps are to enable dynamic batching and nested tensors alongside the above levers.

Special thanks to Edward Yang, Elias Ellison, Driss Guessous, Will Feng, Will Constable, Horace He, Less Wright, and Andrew Gu from Team PyTorch, whose PR reviews and code contributions made it possible for us to realize these latencies using a PyTorch native approach. We thank the broader Team PyTorch, who have been working tirelessly to make PyTorch better, with special shout-outs to the SDPA team for enabling tracing and compile on fast kernels, and to the compile team, who closely guided us on how to work around as well as fix issues (including identifying and raising NVIDIA driver bugs in CUDA graphs).

Inference latency has been one of the roadblocks to LLM adoption in critical enterprise workflows, but another major one is the need for safety, trustworthiness, and governance. IBM’s guide for AI safety and LLM risk can be found here, and Meta’s Responsible Use Guide for Llama can be found here.



Use generative AI to increase agent productivity through automated call summarization


Your contact center serves as the vital link between your business and your customers. Every call to your contact center is an opportunity to learn more about your customers’ needs and how well you are meeting those needs.

Most contact centers require their agents to summarize their conversation after every call. Call summarization is a valuable tool that helps contact centers understand and gain insights from customer calls. Additionally, accurate call summaries enhance the customer journey by eliminating the need for customers to repeat information when transferred to another agent.

In this post, we explain how to use the power of generative AI to reduce the effort and improve the accuracy of creating call summaries and call dispositions. We also show how to get started quickly using the latest version of our open source solution, Live Call Analytics with Agent Assist.

Challenges with call summaries

As contact centers collect more speech data, the need for efficient call summarization has grown significantly. However, most summaries are empty or inaccurate because manually creating them is time-consuming, impacting agents’ key metrics like average handle time (AHT). Agents report that summarizing can take up to a third of the total call, so they skip it or fill in incomplete information. This hurts the customer experience—long holds frustrate customers while the agent types, and incomplete summaries mean asking customers to repeat information when transferred between agents.

The good news is that automating and solving the summarization challenge is now possible through generative AI.

Generative AI is helping summarize customer calls accurately and efficiently

Generative AI is powered by very large machine learning (ML) models referred to as foundation models (FMs) that are pre-trained on vast amounts of data at scale. A subset of these FMs focused on natural language understanding are called large language models (LLMs) and are able to generate human-like, contextually relevant summaries. The best LLMs can process even complex, non-linear sentence structures with ease and determine various aspects, including topic, intent, next steps, outcomes, and more. Using LLMs to automate call summarization allows for customer conversations to be summarized accurately and in a fraction of the time needed for manual summarization. This in turn enables contact centers to deliver superior customer experience while reducing the documentation burden on their agents.

The following screenshot shows an example of the Live Call Analytics with Agent Assist call details page, which contains information about each call.

The following video shows an example of the Live Call Analytics with Agent Assist summarizing an in-progress call, summarizing after the call ends, and generating a follow-up email.

Solution overview

The following diagram illustrates the solution workflow.

The first step to generating abstractive call summaries is transcribing the customer call. Having accurate, ready-to-use transcripts is crucial to generate accurate and effective call summaries. Amazon Transcribe can help you create transcripts with high accuracy for your contact center calls. Amazon Transcribe is a feature-rich speech-to-text API with state-of-the-art speech recognition models that are fully managed and continuously trained. Customers such as The New York Times, Slack, Zillow, Wix, and thousands of others use Amazon Transcribe to generate highly accurate transcripts to improve their business outcomes. A key differentiator for Amazon Transcribe is its ability to protect customer data by redacting sensitive information from the audio and text. Although protecting customer privacy and safety is important to contact centers in general, it’s even more important to mask sensitive information such as bank account information and Social Security numbers before generating automated call summaries, so that this information doesn’t end up in the summaries.
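As a concrete illustration of transcript redaction, the sketch below builds the keyword arguments for a post-call (batch) Amazon Transcribe job with PII redaction enabled, which you could pass to `boto3.client("transcribe").start_transcription_job(**args)`. This is an assumption-laden example rather than how LCA itself is wired: LCA transcribes in real time, and the streaming API exposes its own, similar redaction settings. The job name, media URI, bucket, and chosen entity types are placeholders.

```python
def redacted_transcription_job_args(job_name, media_uri, output_bucket):
    """Keyword arguments for Transcribe StartTranscriptionJob with PII redaction."""
    return {
        "TranscriptionJobName": job_name,
        "Media": {"MediaFileUri": media_uri},
        "LanguageCode": "en-US",
        "OutputBucketName": output_bucket,
        "ContentRedaction": {
            "RedactionType": "PII",
            # "redacted" writes only the masked transcript;
            # "redacted_and_unredacted" writes both versions
            "RedactionOutput": "redacted",
            # Limit masking to financial identifiers; omit this key to redact all PII types
            "PiiEntityTypes": ["BANK_ACCOUNT_NUMBER", "BANK_ROUTING", "SSN"],
        },
    }
```

Redacting before the transcript reaches the LLM means the summaries, and anything downstream such as a CRM update, never see the masked values.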

For customers who are already using Amazon Connect, our omnichannel cloud contact center, Contact Lens for Amazon Connect provides real-time transcription and analytics features natively. However, if you want to use generative AI with your existing contact center, we have developed solutions that do most of the heavy lifting associated with transcribing conversations in real time or post-call from your existing contact center, and generating automated call summaries using generative AI. Additionally, the solution detailed in this section allows you to integrate with your Customer Relationship Management (CRM) system to automatically update your CRM of choice with generated call summaries. In this example, we use our Live Call Analytics with Agent Assist (LCA) solution to generate real-time call transcriptions and call summaries with LLMs hosted on Amazon Bedrock. You can also write an AWS Lambda function and provide LCA the function’s Amazon Resource Name (ARN) in the AWS CloudFormation parameters, and use the LLM of your choice.

The following simplified LCA architecture illustrates call summarization with Amazon Bedrock.

LCA is provided as a CloudFormation template that deploys the preceding architecture and allows you to transcribe calls in real time. The workflow steps are as follows:

  1. Call audio can be streamed via SIPREC from your telephony system to Amazon Chime SDK Voice Connector, which buffers the audio in Amazon Kinesis Video Streams. LCA also supports other audio ingestion mechanisms, such as Genesys Cloud Audiohook.
  2. Amazon Chime SDK Call Analytics then streams the audio from Kinesis Video Streams to Amazon Transcribe, and writes the JSON output to Amazon Kinesis Data Streams.
  3. A Lambda function processes the transcription segments and persists them to an Amazon DynamoDB table.
  4. After the call ends, Amazon Chime SDK Voice Connector publishes an Amazon EventBridge notification that triggers a Lambda function that reads the persisted transcript from DynamoDB, generates an LLM prompt (more on this in the following section), and runs an LLM inference with Amazon Bedrock. The generated summary is persisted to DynamoDB and can be used by the agent in the LCA user interface. You can optionally provide a Lambda function ARN that will be run after the summary is generated to integrate with third-party CRM systems.

LCA also lets you invoke the summarization Lambda function during the call, because the transcript can be fetched and a prompt created at any time, even while the call is in progress. This is useful when a call is transferred to another agent or escalated to a supervisor: rather than putting the customer on hold to explain the call, the new agent can quickly read an auto-generated summary describing the current issue and what the previous agent tried in order to resolve it.

Example call summarization prompt

You can run LLM inferences with prompt engineering to generate and improve your call summaries. You can modify the prompt templates to see what works best for the LLM you select. The following is an example of the default prompt for summarizing a transcript with LCA. We replace the {transcript} placeholder with the actual transcript of the call.

Human: Answer the questions below, defined in <question></question> based on the transcript defined in <transcript></transcript>. If you cannot answer the question, reply with 'n/a'. Use gender neutral pronouns. When you reply, only respond with the answer.

<question>What is a summary of the transcript?</question>

<transcript>
{transcript}
</transcript>

A:
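The sketch below shows one way the placeholder substitution and the Amazon Bedrock call could look. It is not LCA's actual Lambda code: the transcript segment format (`start`, `speaker`, `text` keys) is hypothetical, and the request body assumes an Anthropic Claude v2 text-completion model, which expects the prompt to start with "\n\nHuman:" and end with "\n\nAssistant:". `str.replace` is used instead of `str.format` so that braces inside the transcript are left alone.

```python
import json

PROMPT_TEMPLATE = (
    "\n\nHuman: Answer the questions below, defined in <question></question> based on the "
    "transcript defined in <transcript></transcript>. If you cannot answer the question, "
    "reply with 'n/a'. Use gender neutral pronouns. When you reply, only respond with the answer."
    "\n\n<question>What is a summary of the transcript?</question>"
    "\n\n<transcript>\n{transcript}\n</transcript>"
    "\n\nAssistant:"
)


def format_transcript(segments):
    """segments: dicts with hypothetical keys 'start', 'speaker', 'text'."""
    ordered = sorted(segments, key=lambda s: s["start"])
    return "\n".join(f"{s['speaker']}: {s['text']}" for s in ordered)


def build_prompt(segments, template=PROMPT_TEMPLATE):
    # Fill only the {transcript} placeholder; other braces are kept verbatim
    return template.replace("{transcript}", format_transcript(segments))


def claude_request_body(prompt, max_tokens=512):
    # Request body for Claude v2 text completions on Amazon Bedrock
    return json.dumps({"prompt": prompt, "max_tokens_to_sample": max_tokens, "temperature": 0})


def summarize(segments, model_id="anthropic.claude-v2"):
    import boto3  # imported lazily so the helpers above work without AWS access

    client = boto3.client("bedrock-runtime")
    response = client.invoke_model(
        modelId=model_id,
        body=claude_request_body(build_prompt(segments)),
        contentType="application/json",
        accept="application/json",
    )
    return json.loads(response["body"].read())["completion"].strip()
```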



LCA runs the prompt and stores the generated summary. Besides summarization, you can direct the LLM to generate almost any text that is important for agent productivity. For example, you can have it choose from a set of topics covered during the call (agent disposition), generate a list of required follow-up tasks, or even write an email thanking the caller for the call.

The following screenshot is an example of agent follow-up email generation in the LCA user interface.

With a well-engineered prompt, some LLMs have the ability to generate all of this information in a single inference as well, reducing inference cost and processing time. The agent can then use the generated response within a few seconds of ending the call for their after-contact work. You can also integrate the generated response automatically into your CRM system.
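One way to request several outputs in a single inference is to number the questions and ask the model to tag each answer with the matching number, then parse the tags. This is a hypothetical sketch, not LCA's built-in prompt: the question list, the `<answer id="N">` convention, and the parser are assumptions, and whether a given LLM follows the format reliably is something to verify for your model.

```python
import re

QUESTIONS = [  # hypothetical after-call outputs requested in one inference
    "What is a summary of the transcript?",
    "Which topics from this list were discussed: billing, cancellation, technical support?",
    "What follow-up actions did the agent promise?",
]


def build_multi_question_prompt(transcript, questions=QUESTIONS):
    numbered = "\n".join(
        f'<question id="{i}">{q}</question>' for i, q in enumerate(questions, 1)
    )
    return (
        "\n\nHuman: Answer each question in <question></question> based on the transcript in "
        "<transcript></transcript>. Wrap each answer in <answer id=\"N\"></answer> using the "
        "matching question id. If you cannot answer a question, reply with 'n/a'."
        f"\n\n{numbered}\n\n<transcript>\n{transcript}\n</transcript>\n\nAssistant:"
    )


def parse_answers(completion):
    # Map question id -> answer text; answers may span several lines
    pairs = re.findall(r'<answer id="(\d+)">(.*?)</answer>', completion, flags=re.DOTALL)
    return {int(i): text.strip() for i, text in pairs}
```

Each parsed answer can then be stored separately (summary, disposition, follow-up tasks) or pushed to a CRM.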

The following screenshot shows an example summary in the LCA user interface.

It’s also possible to generate a summary while the call is still ongoing (see the following screenshot), which can be especially helpful for long customer calls.

Prior to generative AI, agents would be required to pay attention while also taking notes and performing other tasks as required. By automatically transcribing the call and using LLMs to automatically create summaries, we can lower the mental burden on the agent, so they can focus on delivering a superior customer experience. This also leads to more accurate after-call work, because the transcription is an accurate representation of what occurred during the call—not just what the agent took notes on or remembered.


The sample LCA application is provided as open source—use it as a starting point for your own solution, and help us make it better by contributing back fixes and features via GitHub pull requests. For information about deploying LCA, refer to Live call analytics and agent assist for your contact center with Amazon language AI services. Browse to the LCA GitHub repository to explore the code, sign up to be notified of new releases, and check out the README for the latest documentation updates. For customers who are already on Amazon Connect, you can learn more about generative AI with Amazon Connect by referring to How contact center leaders can prepare for generative AI.

About the authors

Christopher Lott is a Senior Solutions Architect in the AWS AI Language Services team. He has 20 years of enterprise software development experience. Chris lives in Sacramento, California and enjoys gardening, aerospace, and traveling the world.

Smriti Ranjan is a Principal Product Manager in the AWS AI/ML team focusing on language and search services. Prior to joining AWS, she worked at Amazon Devices and other technology startups leading product and growth functions. Smriti lives in Boston, MA and enjoys hiking, attending concerts and traveling the world.


Customize Amazon Textract with business-specific documents using Custom Queries


Amazon Textract is a machine learning (ML) service that automatically extracts text, handwriting, and data from scanned documents. Queries is a feature that enables you to extract specific pieces of information from varying, complex documents using natural language. Custom Queries provides a way for you to customize the Queries feature for your business-specific, non-standard documents such as auto lending contracts, checks, and pay statements, in a self-service way. By customizing the feature to recognize the unique terms, structures, and key information specific to these document types, you can meet your downstream processing needs with greater precision and minimal human intervention. Custom Queries is easy to integrate into your existing Amazon Textract pipeline, and you continue to benefit from the fully managed intelligent document processing features of Amazon Textract without having to invest in ML expertise or infrastructure management.

In this post, we show how Custom Queries can accurately extract data from checks that are complex, non-standard documents. In addition, we discuss the benefits of Custom Queries and share best practices for effectively using this feature.

Solution overview

When starting with a new use case, you can evaluate how Textract Queries performs on your documents by navigating to the Textract console and using the Analyze Document Demo or Bulk Document Uploader. Refer to Best Practices for Queries to draft queries applicable to your use case. If you identify errors in the query responses due to the nature of your business documents, you can use Custom Queries to improve accuracy. Within hours, you can annotate your sample documents using the AWS Management Console and train an adapter. Adapters are components that plug in to the Amazon Textract pre-trained deep learning model, customizing its output based on your annotated documents. You can use the adapter for inference by passing the adapter identifier as an additional parameter to the Analyze Document Queries API request.

Let’s examine how Custom Queries can improve extraction accuracy in a challenging real-world scenario such as extraction of data from checks. The primary challenge when processing checks arises from their high degree of variation depending on the type (e.g., personal or cashier’s checks), financial institution, and country (e.g., MICR line format). These variations can include the placement of the payee’s name, the amount in numbers and words, the date, and the signature. Recognizing and adapting to these variations can be a complex task during data extraction. To improve data extraction, organizations often employ manual verification and validation processes, which increase the cost and time of the extraction process.

Custom Queries addresses these challenges by enabling you to customize the pre-trained Queries features on the different variations of checks. Customization of the pre-trained feature helps you achieve a high data extraction accuracy on the specific variety of layouts that you process.

In our use case, a financial institution wants to extract the following fields from a check: payee name, payer name, account number, routing number, payment amount (in numbers), payment amount (in words), check number, date, and memo.

Let’s explore the process of generating an adapter (component that customizes the output) for checks processing. Adapters can be created via the console or programmatically via the API. This post details the console experience; however, if you’d like to programmatically create the adapter, refer to the code samples in the custom-queries-checks-blog.ipynb Jupyter notebook (Option 2).

The adapter generation process involves five high-level steps: create an adapter, upload sample documents, annotate the documents, train the adapter, and evaluate performance metrics.

Create an adapter

On the Amazon Textract console, create a new adapter by providing a name, description, and optional tags that can help you identify the adapter. You have the option to enable automatic updates, which allows Amazon Textract to update your adapter when the underlying Queries feature is updated with new capabilities.

After the adapter is created, you will see an adapter details page with a list of steps in the How it works section. This section will activate your next steps as you complete them sequentially.

Upload sample documents

The initial phase in adapter generation involves the careful selection of an appropriate set of sample documents for annotation, training, and testing. You have the option to automatically split the documents into training and test datasets; for this walkthrough, however, we split the dataset manually.

It’s important to note that you can construct an adapter with as few as five test and five training samples, but it’s essential to ensure that this sample set is diverse and representative of the workload encountered in a production environment.
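If you split the dataset yourself, as this walkthrough does, a small helper can enforce the five-document minimum per split. The sketch below is hypothetical (the function name, 50/50 default, and fixed seed are our assumptions) and only does a random shuffle; to keep each split representative of production, you may prefer to split within each check type (personal, cashier’s, stimulus, pay stub) and combine the results.

```python
import random


def split_samples(paths, test_fraction=0.5, min_per_split=5, seed=42):
    """Shuffle sample documents and split them into (train, test) lists."""
    paths = list(paths)
    random.Random(seed).shuffle(paths)  # fixed seed gives a reproducible split
    n_test = round(len(paths) * test_fraction)
    train, test = paths[n_test:], paths[:n_test]
    if len(train) < min_per_split or len(test) < min_per_split:
        raise ValueError(
            f"need at least {min_per_split} training and {min_per_split} test documents"
        )
    return train, test
```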

For this tutorial, we have curated sample check datasets that you can download. Our dataset includes variations such as personal checks, cashier’s checks, stimulus checks, and checks embedded within pay stubs. We also included handwritten and printed checks, along with variations in fields such as the memo line.

Annotate sample documents

As a next step, you annotate the sample documents by associating queries with their corresponding answers via the console. You can initiate annotation via auto labeling or manual labeling. Auto labeling uses Amazon Textract Queries to pre-label the dataset. We recommend using auto labeling to fast-track the annotation process.

For this checks processing use case, we use the following queries. If your use case involves other document types, refer to Best Practices for Queries to draft queries applicable to your use case.

  • Who is the payee?
  • What is the check#?
  • What is the payee address?
  • What is the date?
  • What is the account#?
  • What is the check amount in words?
  • What is the account name/payer/drawer name?
  • What is the dollar amount?
  • What is the bank name/drawee name?
  • What is the bank routing number?
  • What is the MICR line?
  • What is the memo?

When the auto labeling process is complete, you have the option to review and make edits to the answers provided for each document. Choose Start reviewing to review the annotations against each image.

If the response to a query is missing or wrong, you can add or edit the response either by drawing a bounding box or entering the response manually.

To accelerate your walkthrough, we have pre-annotated the checks samples for you to copy to your AWS account. Run the custom-queries-checks-blog.ipynb Jupyter notebook within the Amazon Textract code samples library to automatically update your annotations.

Train the adapter

After you’ve reviewed all the sample documents to ensure the accuracy of the annotations, you can begin the adapter training process. During this step, you need to designate a storage location where the adapter should be saved. The duration of the training process varies with the size of the training dataset. The training API can also be invoked programmatically if you prefer to annotate with a tool of your own and pass the relevant input files to the API. Refer to Custom Queries for more details.
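For the programmatic route, a minimal sketch might look like the following. It assumes the boto3 Textract operations `create_adapter`, `create_adapter_version`, and `get_adapter_version`, a manifest file describing your annotated documents already uploaded to Amazon S3, and placeholder bucket and key names; the notebook referenced above is the authoritative example. The Textract client is passed in as a parameter so the polling logic can be exercised without AWS access.

```python
import time


def train_adapter(textract, name, manifest_bucket, manifest_key,
                  output_bucket, output_prefix, poll_seconds=60):
    """Create a Queries adapter, train a version, and wait until training finishes."""
    adapter_id = textract.create_adapter(
        AdapterName=name,
        FeatureTypes=["QUERIES"],
        AutoUpdate="DISABLED",  # "ENABLED" lets Textract update the adapter automatically
    )["AdapterId"]
    version = textract.create_adapter_version(
        AdapterId=adapter_id,
        # Manifest listing the annotated sample documents
        DatasetConfig={"ManifestS3Object": {"Bucket": manifest_bucket, "Name": manifest_key}},
        # Where Textract writes training output
        OutputConfig={"S3Bucket": output_bucket, "S3Prefix": output_prefix},
    )["AdapterVersion"]
    while True:
        status = textract.get_adapter_version(
            AdapterId=adapter_id, AdapterVersion=version
        )["Status"]
        if status != "CREATION_IN_PROGRESS":
            return adapter_id, version, status  # e.g. "ACTIVE" or "CREATION_ERROR"
        time.sleep(poll_seconds)
```

In practice you would call it as `train_adapter(boto3.client("textract"), "checks-adapter", ...)`, then use the returned ID and version in `AdaptersConfig`.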

Evaluate performance metrics

After the adapter has completed training, you can assess its performance by examining evaluation metrics such as F1 score, precision, and recall. You can analyze these metrics either collectively or on a per-document basis. Using our sample checks dataset, you will see the accuracy metric (F1 score) improve from 68% to 92% with the trained adapter.
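For reference, F1 is the harmonic mean of precision and recall, so it only stays high when both are high. The sketch below computes it from precision and recall, or from raw true-positive, false-positive, and false-negative counts. How Amazon Textract matches predicted answers against annotations to produce these counts is not described here, so treat this only as the definition of the metric.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def precision_recall(true_positives, false_positives, false_negatives):
    predicted = true_positives + false_positives
    actual = true_positives + false_negatives
    precision = true_positives / predicted if predicted else 0.0
    recall = true_positives / actual if actual else 0.0
    return precision, recall
```

For example, precision 0.5 and recall 1.0 give an F1 of about 0.67, noticeably below their arithmetic mean of 0.75.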

Additionally, you can test the adapter’s output on new documents by choosing Try Adapter.

Following the evaluation, you can choose to enhance the adapter’s performance by either incorporating additional sample documents into the training dataset or by re-annotating documents with scores that are lower than your threshold. To re-annotate documents, choose Verify documents on the adapter details page, select the document, and choose Review annotations.

Programmatically test the adapter

With the training successfully completed, you can now use the adapter in your AnalyzeDocument API calls. The API request is similar to the Amazon Textract Queries API request, with the addition of the AdaptersConfig object.

You can run the following sample code or directly run it within the custom-queries-checks-blog.ipynb Jupyter notebook. The sample notebook also provides code to compare results between Amazon Textract Queries and Amazon Textract Custom Queries.

Create an AdaptersConfig object with the adapter ID and adapter version, and optionally include the pages you want the adapter to be applied to:

!python -m pip install amazon-textract-caller --upgrade
!python -m pip install amazon-textract-response-parser --upgrade
!python -m pip install tabulate --upgrade

import boto3
from textractcaller.t_call import call_textract, Textract_Features, Query, QueriesConfig, Adapter, AdaptersConfig
import trp.trp2 as t2
from tabulate import tabulate

# Create AdaptersConfig (replace the ID and version with your own adapter's values)
adapter1 = Adapter(adapter_id="111111111", version="1", pages=["*"])
adapters_config = AdaptersConfig(adapters=[adapter1])

Create a QueriesConfig object with the queries you trained the adapter with and call the Amazon Textract API. Note that you can also include additional queries that the adapter has not been trained on. Amazon Textract will automatically use the Queries feature for these questions and not Custom Queries, thereby providing you with the flexibility of using Custom Queries only where needed.

# Create QueriesConfig
queries = []
queries.append(Query(text="What is the check#?", alias="CHECK_NUMBER", pages=["*"]))
queries.append(Query(text="What is the date?", alias="DATE", pages=["*"]))
queries.append(Query(text="What is the check amount in words?", alias="CHECK_AMOUNT_WORDS", pages=["*"]))
queries.append(Query(text="What is the dollar amount?", alias="DOLLAR_AMOUNT", pages=["*"]))
queries.append(Query(text="Who is the payee?", alias="PAYEE_NAME", pages=["*"]))
queries.append(Query(text="What is the customer account#?", alias="ACCOUNT_NUMBER", pages=["*"]))
queries.append(Query(text="What is the payee address?", alias="PAYEE_ADDRESS", pages=["*"]))
queries.append(Query(text="What is the bank routing number?", alias="BANK_ROUTING_NUMBER", pages=["*"]))
queries.append(Query(text="What is the memo?", alias="MEMO", pages=["*"]))
queries.append(Query(text="What is the account name/payer/drawer name?", alias="ACCOUNT_NAME", pages=["*"]))
queries.append(Query(text="What is the bank name/drawee name?", alias="BANK_NAME", pages=["*"]))
queries_config = QueriesConfig(queries=queries)

textract_client = boto3.client("textract")
document_name = "<image_name>"  # local file path or S3 URI of the check image

# Call AnalyzeDocument with the Queries feature and the adapter applied
textract_json_with_adapter = call_textract(input_document=document_name,
                                           boto3_textract_client=textract_client,
                                           features=[Textract_Features.QUERIES],
                                           queries_config=queries_config,
                                           adapters_config=adapters_config)

Finally, we tabulate our results for better readability:

def tabulate_query_answers(textract_json):
    d = t2.TDocumentSchema().load(textract_json)
    for page in d.pages:
        query_answers = d.get_query_answers(page=page)
        print(tabulate(query_answers, tablefmt="github"))

tabulate_query_answers(textract_json_with_adapter)
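For comparison, the same request can be expressed directly against the AnalyzeDocument API in Boto3, which accepts QueriesConfig and AdaptersConfig parameters. The following is a minimal sketch that only builds the request arguments, so it runs without AWS access; the helper name build_analyze_document_request, the bucket, and the object key are hypothetical, and the adapter ID reuses the placeholder from the example above.

```python
def build_analyze_document_request(bucket, key, queries, adapter_id, adapter_version):
    """Return keyword arguments for Textract AnalyzeDocument with Custom Queries.

    `queries` is a list of (text, alias) pairs. Every query and the adapter
    apply to all pages ("*"), as in the textractcaller example.
    """
    return {
        "Document": {"S3Object": {"Bucket": bucket, "Name": key}},
        "FeatureTypes": ["QUERIES"],
        "QueriesConfig": {
            "Queries": [
                {"Text": text, "Alias": alias, "Pages": ["*"]}
                for text, alias in queries
            ]
        },
        "AdaptersConfig": {
            "Adapters": [
                {"AdapterId": adapter_id, "Version": adapter_version, "Pages": ["*"]}
            ]
        },
    }


request = build_analyze_document_request(
    bucket="amzn-s3-demo-bucket",        # hypothetical bucket name
    key="checks/sample-check.png",       # hypothetical object key
    queries=[("What is the check#?", "CHECK_NUMBER"), ("What is the date?", "DATE")],
    adapter_id="111111111",              # placeholder adapter ID
    adapter_version="1",
)

# Sending the request requires AWS credentials:
# import boto3
# response = boto3.client("textract").analyze_document(**request)
```

Building the arguments separately from the call keeps the query list and adapter configuration easy to inspect or unit test before any request is sent.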


Clean up

To clean up your resources, complete the following steps:

  1. On the Amazon Textract console, choose Custom Queries in the navigation pane.
  2. Select the adapter you want to delete.
  3. Choose Delete.

Adapter management

You can regularly improve your adapters by creating new versions of a previously generated adapter. To create a new version of an adapter, you add new sample documents to an existing adapter, label the documents, and perform training. You can simultaneously maintain multiple versions of an adapter for use in your development pipelines. To update your adapters seamlessly, do not make changes to or delete your Amazon Simple Storage Service (Amazon S3) bucket where the files needed for adapter generation are saved.

Best practices

When using Custom Queries on your documents, refer to Best practices for Amazon Textract Custom Queries for additional considerations and best practices.

Benefits of Custom Queries

Custom Queries offers the following benefits:

  • Enhanced document understanding – Through its ability to extract and normalize data with high accuracy, Custom Queries reduces reliance on manual reviews and audits, and enables you to build more reliable automation for your intelligent document processing workflows.
  • Faster time to value – When you encounter new document types where you need higher accuracy, you can use Custom Queries to generate an adapter in a self-service manner within a few hours. You don’t have to wait for a pre-trained model update when you encounter new document types or variations of existing ones in your workflow. You have complete control over your pipeline and don’t need to depend on Amazon Textract to support your new document types.
  • Data privacy – Custom Queries does not retain or use the data employed in generating adapters to enhance our general pretrained models available to all customers. The adapter is limited to the customer’s account or other accounts explicitly designated by the customer, ensuring that only such accounts can access the improvements made using the customer’s data.
  • Convenience – Custom Queries provides a fully managed inference experience similar to Queries. The adapter training is free, and you only pay for inference. Custom Queries saves you the overhead and expenses of training and operating custom models.


In this post, we discussed the benefits of Custom Queries, showed how Custom Queries can accurately extract data from checks, and shared best practices for effectively utilizing this feature. In just a few hours, you can create an adapter using the console and use it in the AnalyzeDocument API for your data extraction needs. For more information, refer to Custom Queries.

About the authors

Shibin Michaelraj is a Sr. Product Manager with the Amazon Textract team. He is focused on building AI/ML-based products for AWS customers. He is excited about helping customers solve their complex business challenges by using AI and ML technologies. In his spare time, he enjoys running, tuning into podcasts, and refining his amateur tennis skills.

Keith Mascarenhas is a Sr. Solutions Architect with the Amazon Textract service team. He is passionate about solving business problems at scale using machine learning, and currently helps our worldwide customers automate their document processing to achieve faster time to market with reduced operational costs.

Stream large language model responses in Amazon SageMaker JumpStart

We are excited to announce that Amazon SageMaker JumpStart can now stream large language model (LLM) inference responses. Token streaming allows you to see the model response output as it is being generated instead of waiting for LLMs to finish the response generation before it is made available for you to use or display. The streaming capability in SageMaker JumpStart can help you build applications with better user experience by creating a perception of low latency to the end-user.

In this post, we walk through how to deploy and stream the response from a Falcon 7B Instruct model endpoint.

At the time of this writing, the following LLMs available in SageMaker JumpStart support streaming:

  • Mistral AI 7B, Mistral AI 7B Instruct
  • Falcon 180B, Falcon 180B Chat
  • Falcon 40B, Falcon 40B Instruct
  • Falcon 7B, Falcon 7B Instruct
  • Rinna Japanese GPT NeoX 4B Instruction PPO
  • Rinna Japanese GPT NeoX 3.6B Instruction PPO

To check for updates on the list of models supporting streaming in SageMaker JumpStart, search for “huggingface-llm” in the Built-in Algorithms with pre-trained Model Table.

Note that you can use the streaming feature of Amazon SageMaker hosting out of the box for any model deployed using the SageMaker TGI Deep Learning Container (DLC) as described in Announcing the launch of new Hugging Face LLM Inference containers on Amazon SageMaker.

Foundation models in SageMaker

SageMaker JumpStart provides access to a range of models from popular model hubs, including Hugging Face, PyTorch Hub, and TensorFlow Hub, which you can use within your ML development workflow in SageMaker. Recent advances in ML have given rise to a new class of models known as foundation models, which typically have billions of parameters and can be adapted to a wide category of use cases, such as text summarization, generating digital art, and language translation. Because these models are expensive to train, customers want to use existing pre-trained foundation models and fine-tune them as needed, rather than train these models themselves. SageMaker provides a curated list of models that you can choose from on the SageMaker console.

You can now find foundation models from different model providers within SageMaker JumpStart, enabling you to get started with foundation models quickly. SageMaker JumpStart offers foundation models based on different tasks or model providers, and you can easily review model characteristics and usage terms. You can also try these models using a test UI widget. When you want to use a foundation model at scale, you can do so without leaving SageMaker by using prebuilt notebooks from model providers. Because the models are hosted and deployed on AWS, you can trust that your data, whether used for evaluating or using the model at scale, won’t be shared with third parties.

Token streaming

Token streaming allows the inference response to be returned as it’s being generated by the model. This way, you can see the response generated incrementally rather than wait for the model to finish before providing the complete response. Streaming can help enable a better user experience because it lowers the latency perceived by the end-user. You can start seeing the output as it’s generated and therefore can stop generation early if the output isn’t looking useful for your purposes. Streaming can make a big difference, especially for long-running queries, because you can start seeing outputs as they’re generated, which can create a perception of lower latency even though the end-to-end latency stays the same.
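To make the perceived-latency point concrete, the following sketch times any iterator of tokens and reports both the time to the first token and the total generation time. The helper measure_stream and the simulated fake_stream are hypothetical illustrations, not SageMaker APIs; in practice you could pass a streaming iterator such as the TokenIterator defined later in this post.

```python
import time


def measure_stream(token_iter):
    """Consume an iterator of token strings.

    Returns (text, time_to_first_token, total_time) in seconds;
    time_to_first_token is None if the iterator yields nothing.
    """
    start = time.perf_counter()
    first_token_at = None
    pieces = []
    for token in token_iter:
        if first_token_at is None:
            first_token_at = time.perf_counter() - start
        pieces.append(token)
    total = time.perf_counter() - start
    return "".join(pieces), first_token_at, total


def fake_stream(tokens, delay_s=0.01):
    """Hypothetical stand-in for a model that emits one token every delay_s seconds."""
    for tok in tokens:
        time.sleep(delay_s)
        yield tok


text, ttft, total = measure_stream(fake_stream(["Building", " a", " website"]))
print(f"{text!r}: first token after {ttft:.3f}s, full response after {total:.3f}s")
```

With streaming, a reader can start at the first-token time; without it, nothing is shown until the total time has elapsed, even though the end-to-end time is the same in both cases.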

As of this writing, you can use streaming in SageMaker JumpStart for models that utilize Hugging Face LLM Text Generation Inference DLC.

Response with no streaming vs. response with streaming

Solution overview

For this post, we use the Falcon 7B Instruct model to showcase the SageMaker JumpStart streaming capability.

You can use the following code to find other models in SageMaker JumpStart that support streaming:

from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And

filter_value = And("task == llm", "framework == huggingface")
model_ids = list_jumpstart_models(filter=filter_value)

We get the following model IDs that support streaming:

['huggingface-llm-bilingual-rinna-4b-instruction-ppo-bf16', 'huggingface-llm-falcon-180b-bf16', 'huggingface-llm-falcon-180b-chat-bf16', 'huggingface-llm-falcon-40b-bf16', 'huggingface-llm-falcon-40b-instruct-bf16', 'huggingface-llm-falcon-7b-bf16', 'huggingface-llm-falcon-7b-instruct-bf16', 'huggingface-llm-mistral-7b', 'huggingface-llm-mistral-7b-instruct', 'huggingface-llm-rinna-3-6b-instruction-ppo-bf16']


Before running the notebook, there are some initial steps required for setup. Run the following commands:

%pip install --upgrade sagemaker --quiet

Deploy the model

As a first step, use SageMaker JumpStart to deploy a Falcon 7B Instruct model. For full instructions, refer to Falcon 180B foundation model from TII is now available via Amazon SageMaker JumpStart. Use the following code:

from sagemaker.jumpstart.model import JumpStartModel

my_model = JumpStartModel(model_id="huggingface-llm-falcon-7b-instruct-bf16")
predictor = my_model.deploy()

Query the endpoint and stream response

Next, construct a payload to invoke your deployed endpoint with. Importantly, the payload should contain the key/value pair "stream": True. This indicates to the text generation inference server to generate a streaming response.

payload = {
    "inputs": "How do I build a website?",
    "parameters": {"max_new_tokens": 256},
    "stream": True,
}

Before you query the endpoint, you need to create an iterator that can parse the bytes stream response from the endpoint. Data for each token is provided as a separate line in the response, so this iterator returns a token each time a new line is identified in the streaming buffer. This iterator is minimally designed, and you might want to adjust its behavior for your use case; for example, while this iterator returns token strings, the line data contains other information, such as token log probabilities, that could be of interest.

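import io
import json

class TokenIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Resume reading from where the last complete line ended.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                full_line = line[:-1].decode("utf-8").strip()
                if not full_line:
                    continue  # skip blank separator lines between events
                if full_line.startswith("data:"):
                    full_line = full_line[len("data:"):]
                line_data = json.loads(full_line)
                return line_data["token"]["text"]
            # No complete line buffered yet: append the next chunk from the stream.
            # StopIteration from the exhausted stream ends the iteration.
            chunk = next(self.byte_iterator)
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])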
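The class above keeps a BytesIO buffer and a read offset. An equivalent, slightly simpler design is a generator that accumulates bytes and splits off complete lines. The sketch below (the name iter_tokens is hypothetical) assumes the same event shape, {"PayloadPart": {"Bytes": ...}}, and the same data:{...} line format; because it only needs an iterable of dicts, it can be exercised offline with simulated chunks, including a token whose JSON is split across two events.

```python
import json


def iter_tokens(event_stream):
    """Yield token text from a SageMaker response stream of TGI events.

    Assumes each event looks like {"PayloadPart": {"Bytes": b"..."}} and that the
    payload is newline-separated lines of the form 'data:{"token": {"text": ...}}'.
    A line may be split across several events.
    """
    buffer = b""
    for event in event_stream:
        buffer += event["PayloadPart"]["Bytes"]
        # Process every complete line currently in the buffer.
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            line = line.strip()
            if not line:
                continue  # blank separator line between events
            if line.startswith(b"data:"):
                line = line[len(b"data:"):]
            yield json.loads(line)["token"]["text"]
```

To use it with the endpoint response, iterate over iter_tokens(response["Body"]) in place of TokenIterator(response["Body"]). Repeated bytes concatenation is fine for short responses; the BytesIO approach avoids re-copying the buffer for very long ones.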

Now you can use the Boto3 invoke_endpoint_with_response_stream API on the endpoint that you created and enable streaming by iterating over a TokenIterator instance:

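import boto3
import json

client = boto3.client("sagemaker-runtime")
response = client.invoke_endpoint_with_response_stream(
    EndpointName=predictor.endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)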

for token in TokenIterator(response["Body"]):
    print(token, end="")

Specifying an empty end parameter to the print function will enable a visual stream without new line characters inserted. This produces the following output:

Building a website can be a complex process, but it generally involves the following steps:

1. Determine the purpose and goals of your website
2. Choose a domain name and hosting provider
3. Design and develop your website using HTML, CSS, and JavaScript
4. Add content to your website and optimize it for search engines
5. Test and troubleshoot your website to ensure it is working properly
6. Maintain and update your website regularly to keep it running smoothly.

There are many resources available online to guide you through these steps, including tutorials and templates. It may also be helpful to seek the advice of a web developer or designer if you are unsure about any of these steps.<|endoftext|>

You can use this code in a notebook or other applications like Streamlit or Gradio to see the streaming in action and the experience it provides for your customers.

Clean up

Finally, remember to clean up your deployed model and endpoint to avoid incurring additional costs:

predictor.delete_model()
predictor.delete_endpoint()

In this post, we showed you how to use the newly launched streaming feature in SageMaker JumpStart. We hope you will use the token streaming capability to build interactive applications that require low latency for a better user experience.

About the authors

Rachna Chadha is a Principal Solution Architect AI/ML in Strategic Accounts at AWS. Rachna is an optimist who believes that ethical and responsible use of AI can improve society in the future and bring economic and social prosperity. In her spare time, Rachna likes spending time with her family, hiking, and listening to music.

Dr. Kyle Ulrich is an Applied Scientist with the Amazon SageMaker built-in algorithms team. His research interests include scalable machine learning algorithms, computer vision, time series, Bayesian non-parametrics, and Gaussian processes. His PhD is from Duke University and he has published papers in NeurIPS, Cell, and Neuron.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He got his PhD from University of Illinois Urbana-Champaign. He is an active researcher in machine learning and statistical inference, and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.

High-Performance Llama 2 Training and Inference with PyTorch/XLA on Cloud TPUs

In a landscape where AI innovation is accelerating at an unprecedented pace, Meta’s Llama family of open source large language models (LLMs) stands out as a notable breakthrough. Llama marked a significant step forward for LLMs, demonstrating the power of pre-trained architectures for a wide range of applications. Llama 2 further pushed the boundaries of scale and capabilities, inspiring advancements in language understanding, generation, and beyond.

Shortly after the announcement of Llama, we published a blog post showcasing ultra-low inference latency for Llama using PyTorch/XLA on Cloud TPU v4. Building on these results, today, we are proud to share Llama 2 training and inference performance using PyTorch/XLA on Cloud TPU v4 and our newest AI supercomputer, Cloud TPU v5e.

In this blog post, we use Llama 2 as an example model to demonstrate the power of PyTorch/XLA on Cloud TPUs for LLM training and inference. We discuss the computation techniques and optimizations used to improve inference throughput and training model FLOPs utilization (MFU). For Llama 2 70B parameters, we deliver 53% training MFU, 17 ms/token inference latency, 42 tokens/s/chip throughput powered by PyTorch/XLA on Google Cloud TPU. We offer a training user guide and an inference user guide for reproducing the results in this article. Additionally, you may find our Google Next 2023 presentation here.

Model Overview

Llama 2 comes in various sizes, ranging from 7B to 70B parameters, catering to different needs, computational resources, and training / inference budgets. Whether it’s small-scale projects or large-scale deployments, Llama models offer versatility and scalability to accommodate a wide range of applications.

Llama 2 is an auto-regressive language model that uses an optimized transformer architecture. The largest, 70B model, uses grouped-query attention, which speeds up inference without sacrificing quality. Llama 2 is trained on 2 trillion tokens (40% more data than Llama) and has the context length of 4,096 tokens for inference (double the context length of Llama), which enables more accuracy, fluency, and creativity for the model.

Llama 2 is a state-of-the-art LLM that outperforms many other open source language models on many benchmarks, including reasoning, coding, proficiency, and knowledge tests. The model’s scale and complexity place many demands on AI accelerators, making it an ideal benchmark for LLM training and inference performance of PyTorch/XLA on Cloud TPUs.

Performance Challenge of LLMs

Large-scale distributed training for LLMs such as Llama 2 introduces technical challenges that require practical solutions to make the most efficient use of TPUs. Llama’s size can strain both memory and processing resources of TPUs. To address this, we use model sharding, which involves breaking down the model into smaller segments, each fitting within the capacity of a single TPU core. This enables parallelism across multiple TPUs, improving training speed while reducing communication overhead.

Another challenge is managing the large datasets required for training Llama 2 efficiently, which requires effective data distribution and synchronization methods. Additionally, optimizing factors like learning rate schedules, gradient aggregation, and weight synchronization across distributed TPUs is crucial for achieving convergence.

After pretraining or fine-tuning Llama 2, running inference on the model checkpoint creates additional technical challenges. All of the challenges discussed in our previous blog post, such as autoregressive decoding, variable input prompt lengths, and the need for model sharding and quantization still apply for Llama 2. In addition, Llama 2 introduced two new capabilities: grouped-query attention and early stopping. We discuss how PyTorch/XLA handles these challenges to enable high-performance, cost-efficient training and inference of Llama 2 on Cloud TPU v4 and v5e.

Large-Scale Distributed Training

PyTorch/XLA offers two major ways of doing large-scale distributed training: SPMD, which utilizes the XLA compiler to transform and partition a single-device program into a multi-device distributed program; and FSDP, which implements the widely-adopted Fully Sharded Data Parallel algorithm.

In this blog post, we show how to use the SPMD API to annotate the HuggingFace (HF) Llama 2 implementation to maximize performance. For comparison, we also show our FSDP results with the same configurations; read about PyTorch/XLA FSDP API here.

SPMD Overview

Let’s briefly review the fundamentals of SPMD. For details, please refer to our blog post and user guide.


Mesh

A multidimensional array that describes the logical topology of the TPU devices:

import numpy as np
import torch_xla.runtime as xr
from torch_xla.experimental.xla_sharding import Mesh

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

Partition Spec

A tuple that describes how the corresponding tensor’s dimensions are sharded across the mesh:

partition_spec = ('x', 'y')

Mark Sharding

An API that takes a mesh and a partition_spec, and then generates a sharding annotation for the XLA compiler.

import torch
import torch_xla.experimental.xla_sharding as xs

tensor = torch.randn(4, 4).to('xla')
# Let's reuse the above mesh and partition_spec.
# It means the tensor's 0th dim is sharded 4 way and 1st dim is sharded 2 way.
xs.mark_sharding(tensor, mesh, partition_spec)
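To reason about what each device actually stores, it can help to compute the per-device shard shape from the mesh and partition spec by hand. The plain-Python helper below (local_shard_shape is a hypothetical name, not a PyTorch/XLA API) divides each tensor dimension by the size of the mesh axis it is mapped to, with None meaning the dimension is replicated; for simplicity it only accepts dimensions that divide evenly.

```python
def local_shard_shape(tensor_shape, mesh_axes, partition_spec):
    """Per-device shard shape of a tensor.

    mesh_axes maps each mesh axis name to its size, e.g. {"x": 4, "y": 2}.
    partition_spec has one entry per tensor dim: a mesh axis name, or None
    if that dim is replicated. Only evenly divisible dims are accepted.
    """
    if len(tensor_shape) != len(partition_spec):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    shape = []
    for dim, axis in zip(tensor_shape, partition_spec):
        ways = 1 if axis is None else mesh_axes[axis]
        if dim % ways:
            raise ValueError(f"dimension {dim} is not divisible by {ways}")
        shape.append(dim // ways)
    return tuple(shape)


mesh_axes = {"x": 4, "y": 2}  # the (4, 2) mesh from the example above
print(local_shard_shape((4, 4), mesh_axes, ("x", "y")))   # (1, 2)
print(local_shard_shape((4, 4), mesh_axes, ("x", None)))  # (1, 4): dim 1 replicated
```

For the 4x4 tensor above, each of the 8 devices therefore holds a 1x2 slice.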

2D Sharding with SPMD

In our SPMD blog post, we demonstrated using 1D FSDP style sharding. Here, we introduce a more powerful sharding strategy, called 2D sharding, where both the parameters and activations are sharded. This new sharding strategy not only allows fitting a larger model but also boosts the MFU to up to 54.3%. For more details, read the Benchmarks section.

This section introduces a set of general rules that apply to most LLMs, and for convenience we directly reference the variable names and configuration names from HF Llama.

First, let’s create a 2D Mesh with corresponding axis names: data and model. The data axis is usually where we distribute the input data, and the model axis is where we further distribute the model.

mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

The mesh_shape can be a hyper-parameter that is tuned for different model sizes and hardware configurations. The same mesh will be reused in all following sharding annotations. In the next few sections, we will cover how to use the mesh to shard parameters, activations and input data.

Parameter Sharding

Below is a table that summarizes all parameters of HF Llama 2 and corresponding partition specifications. Example HF code can be found here.

Parameter Name | Explanation | Parameter Shape | Partition Spec
embed_tokens | embedding layer | (vocab_size, hidden_size) | (model, data)
q_proj | attention weights | (num_heads x head_dim, hidden_size) | (data, model)
k_proj / v_proj | attention weights | (num_key_value_heads x head_dim, hidden_size) | (data, model)
o_proj | attention weights | (hidden_size, num_heads x head_dim) | (model, data)
gate_proj / up_proj | MLP weights | (intermediate_size, hidden_size) | (model, data)
down_proj | MLP weights | (hidden_size, intermediate_size) | (data, model)
lm_head | HF output embedding | (vocab_size, hidden_size) | (model, data)

Table 1: SPMD 2D Sharding Parameter Partition Spec

The rule is to shard the hidden_size dim of any weights except QKVO projections according to the data axis of the mesh, then shard the other dim with the remaining model axis. For QKVO, do the opposite. This model-data axis rotation methodology is similar to that of Megatron-LM to reduce communication overhead. For layernorm weights, we implicitly mark them as replicated across different devices given they are 1D tensors.
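The rule can be written as a small function and checked against Table 1. This is a hypothetical sketch: param_partition_spec takes the weight name and the index of its hidden_size dimension and returns a partition spec in the same (axis, axis) notation as the table; it covers only the 2D weights, since the 1D layernorm weights are replicated.

```python
QKVO = {"q_proj", "k_proj", "v_proj", "o_proj"}


def param_partition_spec(name, hidden_dim_index):
    """Partition spec for a 2D Llama weight under the 2D-sharding rule.

    hidden_dim_index is the position (0 or 1) of hidden_size in the weight shape.
    Non-QKVO weights put hidden_size on the 'data' axis and the other dim on
    'model'; the QKVO projections do the opposite.
    """
    if name in QKVO:
        hidden_axis, other_axis = "model", "data"
    else:
        hidden_axis, other_axis = "data", "model"
    spec = [other_axis, other_axis]
    spec[hidden_dim_index] = hidden_axis
    return tuple(spec)


# (weight name, index of hidden_size in its shape), taken from the shapes in Table 1
weights = [
    ("embed_tokens", 1),  # (vocab_size, hidden_size)
    ("q_proj", 1),        # (num_heads x head_dim, hidden_size)
    ("k_proj", 1),        # (num_key_value_heads x head_dim, hidden_size)
    ("o_proj", 0),        # (hidden_size, num_heads x head_dim)
    ("gate_proj", 1),     # (intermediate_size, hidden_size)
    ("down_proj", 0),     # (hidden_size, intermediate_size)
    ("lm_head", 1),       # (vocab_size, hidden_size)
]
for name, idx in weights:
    print(name, param_partition_spec(name, idx))
```

Every printed spec matches the Partition Spec column, which shows the rule only needs to know where hidden_size sits and whether the weight is one of the QKVO projections.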

Activation Sharding

In order to better utilize the device memory, very often we need to annotate the output of some memory bound ops. That way the compiler is forced to only keep partial output on devices instead of the full output. In Llama 2, we explicitly annotate all torch.matmul and nn.Linear outputs. Table 2 summarizes the corresponding annotations; the example HF code can be found here.

Output Name | Explanation | Output Shape | Partition Spec
inputs_embeds | embedding layer output | (batch_size, sequence_length, hidden_size) | (data, None, model)
query_states | attention nn.Linear output | (batch_size, sequence_length, num_heads x head_dim) | (data, None, model)
key_states / value_states | attention nn.Linear output | (batch_size, sequence_length, num_key_value_heads x head_dim) | (data, None, model)
attn_weights | attention weights | (batch_size, num_attention_heads, sequence_length, sequence_length) | (data, model, None, None)
attn_output | attention layer output | (batch_size, sequence_length, hidden_size) | (data, None, model)
up_proj / gate_proj / down_proj | MLP nn.Linear outputs | (batch_size, sequence_length, intermediate_size); down_proj: (batch_size, sequence_length, hidden_size) | (data, None, model)
logits | HF output embedding output | (batch_size, sequence_length, vocab_size) | (data, None, model)

Table 2: SPMD 2D Sharding Activation Partition Spec

The rule is to shard the batch_size dim of any outputs according to the data axis of the mesh, then replicate the length dims of any outputs, and finally shard the last dim along the model axis.
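To see why these activation annotations matter for memory, the sketch below estimates how many bytes one device holds for an activation under a given partition spec. It assumes bf16 (2 bytes per element) and uses the 70B configuration from Table 3 (a (32, 4) mesh with a global batch of 512) together with Llama 2 70B's 64 attention heads; the function name bytes_per_device is hypothetical.

```python
def bytes_per_device(shape, spec, mesh_axes, bytes_per_elem=2):
    """Bytes one device stores for a tensor of `shape` annotated with `spec`.

    spec has one mesh axis name (or None for replicated) per dimension.
    bytes_per_elem defaults to 2 (bf16). Uneven splits are rounded up.
    """
    elems = 1
    for dim, axis in zip(shape, spec):
        ways = 1 if axis is None else mesh_axes[axis]
        elems *= -(-dim // ways)  # ceiling division
    return elems * bytes_per_elem


mesh_axes = {"data": 32, "model": 4}
# attn_weights: (batch_size, num_attention_heads, sequence_length, sequence_length)
shape = (512, 64, 1024, 1024)
full = bytes_per_device(shape, (None, None, None, None), mesh_axes)
sharded = bytes_per_device(shape, ("data", "model", None, None), mesh_axes)
print(f"unsharded: {full / 2**30:.0f} GiB, sharded: {sharded / 2**30:.1f} GiB per device")
# unsharded: 64 GiB, sharded: 0.5 GiB per device
```

The annotation cuts the per-device footprint of this one tensor by the full mesh size (32 x 4 = 128), which is the kind of saving that keeps the compiler from materializing full outputs on every device.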

Input Sharding

For input sharding, the rule is to shard the batch dim along the data axis of the mesh, and replicate the sequence_length dim. Below is the example code, and the corresponding HF change may be found here.

import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs

partition_spec = ('data', None)
sharding_spec = xs.ShardingSpec(mesh, partition_spec)
# MpDeviceLoader will shard the input data before sending it to the device.
pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, ...)

Now, all the data and model tensors that require sharding are covered!

Optimizer States & Gradients

You may be wondering whether it is necessary to shard the optimizer states and gradients as well. Great news: the sharding propagation feature of the XLA compiler automates the sharding annotation in these two scenarios, without needing more hints to improve performance.

It is important to note that optimizer states are typically initialized within the first iteration of the training loop. From the standpoint of the XLA compiler, the optimizer states are the outputs of the first graph, and therefore have the sharding annotation propagated. For subsequent iterations, the optimizer states become inputs to the second graph, with the sharding annotation propagated from the first one. This is also why PyTorch/XLA typically produces two graphs for the training loops. If the optimizer states are somehow initialized before the first iteration, users will have to manually annotate them, just like the model weights.

Again, all concrete examples of the above sharding annotation can be found in our fork of HF Transformers here. The repo also contains code for our experimental feature MultiSlice, including HybridMesh and dcn axis, which follows the same principles mentioned above.


While using SPMD for training, there are a few important things to pay attention to:

  • Use torch.einsum instead of torch.matmul; torch.matmul usually flattens tensors and performs a 2D matrix multiplication at the end, which is bad for SPMD when the combined axes are sharded, because the XLA compiler has a hard time determining how to propagate the sharding.
  • PyTorch/XLA provides a patched nn.Linear to overcome the above constraint:
import torch_xla.experimental.xla_sharding as xs
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)
  • Always reuse the same mesh across all shardings.
  • Always specify --dataloader_drop_last yes. The last, smaller batch is hard to annotate.
  • Large models which are initialized on the host can induce host-side OOM. One way to avoid this issue is to initialize parameters on the meta device, then create and shard real tensors layer-by-layer.

Infrastructure Improvements

Besides the above modeling techniques, we have developed additional features and improvements to maximize performance, including:

  • We enable asynchronous collective communication. This requires enhancements to the XLA compiler’s latency hiding scheduler to better optimize for the Llama 2 PyTorch code.
  • We now allow sharding annotations in the middle of the IR graph, just like JAX’s jax.lax.with_sharding_constraint. Previously, only graph inputs were annotated.
  • We also propagate replicated sharding spec from the compiler to the graph outputs. This allows us to shard the optimizer states automatically.

Inference Optimizations

All the PyTorch/XLA optimizations implemented for Llama inference are applied to Llama 2 as well. That includes Tensor Parallelism + Dynamo (torch.compile) using torch-xla collective ops, autoregressive decoding logic improvement to avoid recompilation, bucketized prompt length, KV-cache with compilation friendly index ops. Llama 2 introduces two new changes: Grouped Query Attention, and Early Stopping when eos is reached for all prompts. We applied corresponding changes to promote better performance and flexibility with PyTorch/XLA.

Grouped Query Attention

Llama 2 enables Grouped Query Attention for the 70B models. It allows the number of Key and Value heads to be smaller than the number of Query heads, while still supporting KV-cache sharding up to the number of KV heads. For the 70B models, n_kv_heads is 8, which limits the tensor parallelism degree to at most 8. In order to shard the model checkpoint to run on more devices, the K, V projection weights need to be replicated first, and then split into multiple pieces. For example, to shard the 70B model checkpoint from 8 pieces to 16 pieces, the K, V projection weights are duplicated and split into 2 pieces for each shard. We provide a script to handle that, and to make sure the sharded checkpoint is mathematically identical to the original checkpoint.
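The replicate-then-split step can be illustrated with plain Python lists standing in for weight tensors. This is a hypothetical sketch, not the script mentioned above: it assumes the standard GQA grouping in which KV head h serves query heads h*g through (h+1)*g - 1, where g = num_heads / num_kv_heads, and repeats each KV head in place so every shard keeps the KV head its query heads originally attended to.

```python
def replicate_and_split_kv(weight_rows, num_kv_heads, head_dim, num_shards):
    """Split a K or V projection weight across more shards than there are KV heads.

    weight_rows is the weight as a list of rows with shape
    (num_kv_heads * head_dim, hidden_size). Each KV head is repeated in place
    (h0, h0, h1, h1, ...) so shard s holds KV head s // factor, the head that
    the query heads placed on shard s attend to under standard GQA grouping.
    Returns one list of head_dim rows per shard.
    """
    if num_shards % num_kv_heads:
        raise ValueError("num_shards must be a multiple of num_kv_heads")
    factor = num_shards // num_kv_heads
    heads = [weight_rows[h * head_dim:(h + 1) * head_dim] for h in range(num_kv_heads)]
    return [list(heads[s // factor]) for s in range(num_shards)]


# Toy weight for 8 KV heads: row r of KV head h is [h, r]; head_dim shrunk to 2.
rows = [[h, r] for h in range(8) for r in range(2)]
shards = replicate_and_split_kv(rows, num_kv_heads=8, head_dim=2, num_shards=16)
print([shard[0][0] for shard in shards])  # KV head held by each shard: 0, 0, 1, 1, ..., 7, 7
```

Interleaving the copies (rather than appending a second full copy of all heads) is what keeps the sharded model equivalent: with 64 query heads split over 16 shards, shard s holds query heads 4s to 4s + 3, all of which belong to KV head s // 2.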

EOS Early Stopping

The Llama 2 generation code added early stopping logic. An eos_reached tensor is used to track the completion of all the prompt generations, and if the eos token is reached for all the prompts in the batch, generation stops early. A similar change is incorporated in the PyTorch/XLA optimized version as well, with some minor tweaks.

In PyTorch/XLA, checking the value of a tensor like eos_reached as part of a control-flow condition invokes a blocking device-to-host transfer: the tensor is transferred from device memory to CPU memory to evaluate its value, while all other logic waits. This introduced a delay on the order of milliseconds after every new token generation. As a trade-off, we reduce the rate of checking the eos_reached value to once every 10 new token generations. With this change, the impact of the blocking device-to-host transfer is reduced by 10x, while early stopping is still effective: at most 9 unnecessary tokens are generated after the last sequence in the batch reaches the eos token.
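The trade-off can be illustrated with a host-side simulation. In the sketch below, plain Python lists stand in for the device tensors, step is a hypothetical callable returning the next token id for each sequence, and the eos_reached flag is only inspected every check_every steps, which in PyTorch/XLA is when the blocking device-to-host transfer would happen.

```python
def generate_with_periodic_eos_check(step, eos_id, batch_size, max_new_tokens, check_every=10):
    """Decode loop that reads the 'all sequences finished' flag every check_every steps.

    step(t) is a hypothetical callable returning one next-token id per sequence
    at decode step t. Python lists stand in for device tensors; reading
    all(eos_reached) models the blocking device-to-host transfer.
    """
    eos_reached = [False] * batch_size
    outputs = [[] for _ in range(batch_size)]
    for t in range(max_new_tokens):
        for i, tok in enumerate(step(t)):
            outputs[i].append(tok)
            eos_reached[i] = eos_reached[i] or tok == eos_id
        # Only "transfer" the flag to the host on every check_every-th step.
        if (t + 1) % check_every == 0 and all(eos_reached):
            break
    return outputs


EOS = 2
# Both sequences emit EOS at step index 10 (their 11th token), the worst case for a
# check every 10 steps: the loop keeps going until the check after the 20th token.
out = generate_with_periodic_eos_check(lambda t: [EOS if t == 10 else 7] * 2, EOS, 2, 100)
print(len(out[0]) - 11)  # extra tokens generated after EOS: 9
```

If EOS instead lands exactly on a checking step, no extra tokens are produced, so the overhead ranges from 0 to check_every - 1 tokens per batch.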

Model Serving

PyTorch/XLA is working on a serving strategy to enable the PyTorch community to serve their deep learning applications via Torch.Export, StableHLO, and SavedModel. PyTorch/XLA Serving is an experimental feature in PyTorch/XLA 2.1 release; for details visit our serving user guide. Users can take advantage of TorchServe to run their single-host workloads.



Benchmarks

To measure training performance, we use the industry-standard metric: Model FLOPS Utilization (MFU). Model FLOPs are the floating point operations required to perform a single forward and backward pass. Model FLOPs are hardware and implementation independent and only depend on the underlying model. MFU measures how effectively the model is using the actual hardware during training. Achieving 100% MFU means that the model is using the hardware perfectly.
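As a worked example, MFU can be estimated from measured training throughput. The sketch below uses the common approximation of about 6 FLOPs per parameter per token for a forward plus backward pass (roughly 2N forward and 4N backward), which ignores attention-score FLOPs; the exact model FLOP count behind the reported MFU numbers may differ, and all inputs in the usage line are hypothetical.

```python
def training_mfu(n_params, tokens_per_second, peak_flops_per_second):
    """Model FLOPs Utilization with the ~6 * N FLOPs-per-token approximation
    (about 2N for the forward pass and 4N for the backward pass)."""
    achieved_model_flops_per_second = 6 * n_params * tokens_per_second
    return achieved_model_flops_per_second / peak_flops_per_second


# Hypothetical inputs: a 7B-parameter model processing 10,000 tokens/s on
# hardware with an aggregate peak of 1e15 FLOP/s.
print(f"MFU = {training_mfu(7e9, 10_000, 1e15):.0%}")  # MFU = 42%
```

Because only the model's FLOPs appear in the numerator, recomputation or other implementation overhead lowers MFU rather than inflating it, which is what makes the metric comparable across implementations.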

To measure inference performance, we use the industry-standard metric of throughput. First, we measure latency per token when the model has been compiled and loaded. Then, we calculate per-chip throughput by dividing the batch size (BS) by the per-token latency and by the number of chips. As a result, throughput measures how the model is performing in production environments regardless of how many chips are used.
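The per-chip throughput calculation amounts to the following; the function name and the numbers in the example are hypothetical.

```python
def per_chip_throughput(batch_size, latency_per_token_s, num_chips):
    """Tokens/s/chip: each decode step emits batch_size tokens and takes
    latency_per_token_s seconds on a slice of num_chips chips."""
    return batch_size / latency_per_token_s / num_chips


# Hypothetical measurement: batch size 8 at 50 ms/token on 16 chips.
print(per_chip_throughput(8, 0.050, 16))
```

Normalizing by the chip count is what lets configurations of different sizes be compared on the same axis.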


Training Evaluation

Figure 1 shows Llama 2 SPMD 2D sharding training results on a range of Google TPU v4 hardware with PyTorch/XLA FSDP as the baseline. We increased MFU by 28% across all sizes of Llama 2 compared to FSDP running on the same hardware configuration. This performance improvement is largely due to: 1) 2D Sharding has less communication overhead than FSDP, and 2) asynchronous collective communication is enabled in SPMD, which allows communication and computation to overlap. Also note that as the model size scales, we maintain the high MFU. Table 3 shows all the hardware configurations plus some hyperparameters used in the training benchmarks.

Figure 1. Llama 2 Training MFU on TPU v4 Hardware


The results in Figure 1 are produced with sequence length 1,024. Figure 2 shows how the performance behaves with larger sequence lengths: our performance also scales linearly with sequence length. The MFU is expected to decrease slightly because a smaller per-device batch size is needed to accommodate the additional memory pressure of a longer sequence (the sequence length axis is not sharded in 2D sharding), and TPU performance is very sensitive to batch size. For Llama 2 70B, the performance decrease is as low as 4%. At the time of preparing these results, the Hugging Face Llama 2 tokenizer limited the max model input to 2,048 tokens, preventing us from evaluating larger sequence lengths.

Figure 2. Llama 2 SPMD Training MFU on TPU v4 with Different Sequence Lengths


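Model Size | TPU NumCores | Mesh Shape | Seq Len | Global Batch | Per Device Batch
7B | V4-32 | (16, 1) | 1,024 | 256 | 16
7B | V4-32 | (16, 1) | 2,048 | 128 | 8
13B | V4-64 | (32, 1) | 1,024 | 256 | 8
13B | V4-64 | (32, 1) | 2,048 | 128 | 4
70B | V4-256 | (32, 4) | 1,024 | 512 | 16
70B | V4-256 | (32, 4) | 2,048 | 256 | 8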

Table 3: Llama 2 SPMD Training Benchmark TPU Configurations and Hyperparameters
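The Per Device Batch values follow from the other columns: under 2D sharding the batch dimension is split only along the data axis, so the per-device batch is the global batch divided by the first entry of the mesh shape (assuming the shape is listed as (data, model), as in the mesh created earlier). The hypothetical helper below checks this against every configuration in Table 3.

```python
def per_device_batch(global_batch, mesh_shape):
    """Per-device batch under 2D sharding, assuming mesh_shape = (data, model).

    Only the data axis splits the batch; devices along the model axis
    hold the same batch slice.
    """
    data_axis_size = mesh_shape[0]
    if global_batch % data_axis_size:
        raise ValueError("global batch must be divisible by the data axis size")
    return global_batch // data_axis_size


# (model size, mesh shape, global batch, per-device batch) from Table 3
table3 = [
    ("7B", (16, 1), 256, 16), ("7B", (16, 1), 128, 8),
    ("13B", (32, 1), 256, 8), ("13B", (32, 1), 128, 4),
    ("70B", (32, 4), 512, 16), ("70B", (32, 4), 256, 8),
]
for model, mesh_shape, global_batch, expected in table3:
    assert per_device_batch(global_batch, mesh_shape) == expected, model
```

For the 70B runs, the 4-way model axis therefore adds devices without shrinking the per-device batch, which is why the 70B configuration can keep a per-device batch of 16 on 128 chips.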

One last thing to call out is that we use Adafactor as the optimizer for better memory utilization. And once again, here is the user guide to reproduce the benchmark results listed above.

Inference Evaluation

In this section, we extend our previous evaluation of Llama on Cloud TPU v4. Here, we demonstrate the performance properties of TPU v5e for inference applications.

We define inference throughput as the number of tokens produced by a model per second per TPU chip. Figure 3 shows Llama 2 70B throughput on a v5e-16 TPU node. Because Llama inference is memory-bound, we see that applying weight-only quantization allows the batch size to be extended to 32. Higher throughput would be possible on larger TPU v5e hardware, up to the point where the ICI network bandwidth between chips throttles the TPU slice. Exploring the upper-bound limits of TPU v5e on Llama 2 was outside the scope of this work. Note that to run the Llama 2 70B model on v5e-16, we replicated the attention heads so that there is one head per chip, as discussed in the Inference section above. As discussed previously, per-token latency grows proportionally with batch size; quantization improves overall latency by reducing memory I/O demand.
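The throughput definition above, and the reason weight-only quantization helps a memory-bound model, can be made concrete with a little arithmetic. The sketch below assumes the 70B weights are sharded evenly across the 16 chips of a v5e-16 with no replication (in practice the replicated attention heads add some memory), and uses a nominal 70e9 parameters; the helper names are hypothetical.

```python
def per_chip_throughput(batch_size: int, per_token_latency_s: float,
                        num_chips: int) -> float:
    """Tokens per second per chip during decoding.

    Each decode step emits one new token for every sequence in the batch.
    """
    return batch_size / per_token_latency_s / num_chips


def weight_gb_per_chip(n_params: float, bytes_per_param: float,
                       num_chips: int) -> float:
    # Assumes weights are sharded evenly across chips with no replication.
    return n_params * bytes_per_param / num_chips / 1e9


# Nominal 70B parameters on 16 chips: bf16 = 2 bytes, int8 = 1 byte.
print(weight_gb_per_chip(70e9, 2, 16))  # bf16 weights
print(weight_gb_per_chip(70e9, 1, 16))  # int8 weights (weight-only quantization)
```

In a memory-bound decode step, each chip streams its whole weight shard from HBM once per step regardless of batch size, so halving the bytes per weight both reduces that traffic and frees HBM capacity for the KV cache of a larger batch.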

Figure 3. Llama 2 70B Inference Per-Chip Throughput on TPU v5e vs. Batch Size

Figure 4 shows inference throughput results across different model sizes. These results show the highest throughput achievable with each hardware configuration when using bf16 precision. With weight-only quantization, throughput on the 70B model reaches 42 tokens per second per chip. As mentioned above, adding hardware resources may lead to further performance gains.
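Weight-only quantization stores weights in int8 while activations stay in bf16, dequantizing on the fly inside the matrix multiply. The sketch below shows one common scheme, symmetric per-output-channel int8 quantization, in plain Python floats standing in for bf16; it is an illustrative assumption, not necessarily the scheme PyTorch/XLA uses, and the function names are hypothetical.

```python
def quantize_per_channel_int8(weight):
    """Symmetric per-output-channel int8 quantization of a weight matrix.

    `weight` is a list of rows (one row per output channel). Each row gets
    its own scale so that its largest magnitude maps to 127.
    """
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def int8_weight_matvec(q_rows, scales, x):
    """y = W x with int8 weights and float activations.

    The integer dot product for each row is rescaled by that row's scale.
    """
    return [scale * sum(q * xi for q, xi in zip(row, x))
            for row, scale in zip(q_rows, scales)]


W = [[0.5, -1.0], [2.0, 0.25]]
q, s = quantize_per_channel_int8(W)
y = int8_weight_matvec(q, s, [1.0, 2.0])
print(y)  # close to the unquantized result [-1.5, 2.5]
```

The extra rescaling work is the compute overhead that, for small models, can outweigh the memory-traffic savings.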

Figure 4. Llama 2 Inference Per-Chip Throughput on TPU v5e

Figure 5 shows the cost of serving the Llama 2 models from Figure 4 on Cloud TPU v5e. We report the TPU v5e per-chip cost based on the 3-year commitment (reserved) price in the us-west4 region. All model sizes use a maximum sequence length of 2,048 and a maximum generation length of 1,000 tokens. Note that with quantization, the cost for the 70B model drops to $0.0036 per 1,000 tokens.
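Because both the price and the throughput are reported per chip, cost per 1,000 tokens follows directly from them, and the number of chips cancels out. The sketch below shows that conversion; the price and throughput passed in are hypothetical placeholders, not actual Cloud TPU pricing or measured results.

```python
SECONDS_PER_HOUR = 3600


def cost_per_1k_tokens(price_per_chip_hour: float,
                       tokens_per_second_per_chip: float) -> float:
    """Serving cost in dollars per 1,000 generated tokens.

    Price and throughput are both per chip, so the chip count cancels.
    """
    tokens_per_chip_hour = tokens_per_second_per_chip * SECONDS_PER_HOUR
    return price_per_chip_hour / tokens_per_chip_hour * 1000


# Hypothetical price and throughput, not actual Cloud TPU pricing.
print(f"${cost_per_1k_tokens(price_per_chip_hour=1.00, tokens_per_second_per_chip=100):.5f}")
```

Any throughput gain, such as the one from weight-only quantization, translates into a proportional cost reduction at a fixed per-chip price.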

Figure 5. Llama 2 Inference Per-Chip Cost on TPU v5e

Figure 6 summarizes our best Llama 2 inference latency results on TPU v5e. The Llama 2 7B results come from our non-quantized configuration (BF16 weights, BF16 activations), while the 13B and 70B results come from the quantized configuration (INT8 weights, BF16 activations). We attribute this difference to the inherent tradeoff in quantization between memory savings and compute overhead: for smaller models, quantization may not lead to lower inference latency.

Additionally, prompt length has a strong effect on the memory requirements of LLMs. For instance, we observe a latency of 1.2 ms/token (i.e., 201 tokens/second/chip) when running Llama 2 7B with max_seq_len=256 at batch size 1 with no quantization on v5e-4.
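One reason sequence length matters so much is the KV cache, which grows linearly with it. The sketch below estimates the size of a KV cache preallocated for max_seq_len, assuming a bf16 cache and the Llama 2 7B shape from the model table at the top of this post (32 layers, 32 heads with multi-head attention, hidden dimension 4096, so a head dimension of 128). The function name is hypothetical.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   max_seq_len: int, batch_size: int,
                   bytes_per_elem: int = 2) -> int:
    """Size of a preallocated KV cache: one K and one V tensor per layer,
    each shaped [batch_size, max_seq_len, num_kv_heads, head_dim]."""
    return (2 * num_layers * batch_size * max_seq_len
            * num_kv_heads * head_dim * bytes_per_elem)


# Llama 2 7B: 32 layers, 32 heads (MHA), head_dim = 4096 / 32 = 128, bf16 cache.
for seq_len in (256, 2048):
    size = kv_cache_bytes(32, 32, 128, seq_len, batch_size=1)
    print(f"max_seq_len={seq_len}: {size / 2**20:.0f} MiB")
```

Under these assumptions a single sequence needs 128 MiB of cache at max_seq_len=256 but 1 GiB at 2,048, and the cache must be read at every decode step.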

Figure 6. Llama 2 Inference Latency on TPU v5e

Final Thoughts

The recent wave of AI innovation has been nothing short of transformative, with breakthroughs in LLMs at the forefront. Meta’s Llama and Llama 2 models stand as notable milestones in this wave of progress. PyTorch/XLA uniquely enables high-performance, cost-efficient training and inference for Llama 2 and other LLMs and generative AI models on Cloud TPUs, including the new Cloud TPU v5e. Looking forward, PyTorch/XLA will continue to push the performance limits on Cloud TPUs in both throughput and scalability and at the same time maintain the same PyTorch user experience.

We are ecstatic about what’s ahead for PyTorch/XLA and invite the community to join us. PyTorch/XLA is developed fully in open source. So, please file issues, submit pull requests, and send RFCs to GitHub so that we can openly collaborate. You can also try out PyTorch/XLA for yourself on various XLA devices including TPUs and GPUs.

We would like to extend our special thanks to Marcello Maggioni, Tongfei Guo, Andy Davis, Berkin Ilbeyi for their support and collaboration in this effort.

The PyTorch/XLA Team at Google

Introducing GPTs

You can now create custom versions of ChatGPT that combine instructions, extra knowledge, and any combination of skills.

OpenAI Blog

Best of both worlds: Achieving scalability and quality in text clustering

Clustering is a fundamental, ubiquitous problem in data mining and unsupervised machine learning, where the goal is to group together similar items. The standard forms of clustering are metric clustering and graph clustering. In metric clustering, a given metric space defines distances between data points, which are grouped together based on their separation. In graph clustering, a given graph connects similar data points through edges, and the clustering process groups data points together based on the connections between them. Both clustering forms are particularly useful for large corpora where class labels can’t be defined. Examples of such corpora are the ever-growing digital text collections of various internet platforms, with applications including organizing and searching documents, identifying patterns in text, and recommending relevant documents to users (see more examples in the following posts: clustering related queries based on user intent and practical differentially private clustering).

The choice of text clustering method often presents a dilemma. One approach is to use embedding models, such as BERT or RoBERTa, to define a metric clustering problem. Another is to utilize cross-attention (CA) models, such as PaLM or GPT, to define a graph clustering problem. CA models can provide highly accurate similarity scores, but constructing the input graph may require a prohibitive quadratic number of inference calls to the model. On the other hand, a metric space can be defined efficiently by distances between embeddings produced by embedding models. However, these similarity distances are typically of substantially lower quality than the similarity signals of CA models, and hence the resulting clustering can be of much lower quality.
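The scalability gap comes from how each model is called. An embedding model is run once per document, and any pairwise similarity is then a cheap vector operation; a CA model scores a pair of documents jointly, so a complete similarity graph needs one call per pair. The sketch below illustrates that counting argument, with cosine similarity as an assumed choice of embedding metric and hypothetical function names.

```python
import math


def cosine_similarity(u, v):
    # Embedding-based similarity: a cheap vector operation per pair.
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def embedding_model_calls(n_docs: int) -> int:
    # Each document is embedded once; all pairwise scores reuse the vectors.
    return n_docs


def cross_attention_calls(n_docs: int) -> int:
    # A CA model scores each unordered pair jointly: n * (n - 1) / 2 calls.
    return n_docs * (n_docs - 1) // 2


for n in (1_000, 100_000):
    print(n, embedding_model_calls(n), cross_attention_calls(n))
```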

An overview of the embedding-based and cross-attention–based similarity scoring functions and their scalability vs. quality dilemma.

Motivated by this, in “KwikBucks: Correlation Clustering with Cheap-Weak and Expensive-Strong Signals”, presented at ICLR 2023, we describe a novel clustering algorithm that effectively combines the scalability benefits of embedding models and the quality of CA models. This graph clustering algorithm has query access to both the CA model and the embedding model; however, we apply a budget on the number of queries made to the CA model. The algorithm uses the CA model to answer edge queries and benefits from unlimited access to similarity scores from the embedding model. We describe how this proposed setting bridges algorithm design and practical considerations, and how it can be applied to other clustering problems with similar available scoring functions, such as clustering problems on images and media. We demonstrate how this algorithm yields high-quality clusters with an almost linear number of query calls to the CA model. We have also open-sourced the data used in our experiments.

The clustering algorithm

The KwikBucks algorithm is an extension of the well-known KwikCluster algorithm (Pivot algorithm). The high-level idea is to first select a set of documents (i.e., centers) with no similarity edge between them, and then form clusters around these centers. To obtain the quality from CA models and the runtime efficiency from embedding models, we introduce the novel combo similarity oracle mechanism. In this approach, we utilize the embedding model to guide the selection of queries to be sent to the CA model. When given a set of center documents and a target document, the combo similarity oracle mechanism outputs a center from the set that is similar to the target document, if present. The combo similarity oracle enables us to save on budget by limiting the number of query calls to the CA model when selecting centers and forming clusters. It does this by first ranking centers based on their embedding similarity to the target document, and then querying the CA model for the pair (i.e., target document and ranked center), as shown below.

A combo similarity oracle that for a set of documents and a target document, returns a similar document from the set, if present.
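The combo similarity oracle described above can be sketched as a ranking step followed by a verification step: sort the centers by cheap embedding similarity to the target, then ask the CA model about the top-ranked candidates until one is confirmed similar. The cap `max_ca_queries` on how many candidates are checked is a hypothetical parameter we add to make the budget explicit; the paper's exact mechanism may differ. The toy usage treats documents as numbers and counts CA calls.

```python
from typing import Callable, Optional, Sequence, TypeVar

Doc = TypeVar("Doc")


def combo_similarity_oracle(target: Doc, centers: Sequence[Doc],
                            embed_sim: Callable[[Doc, Doc], float],
                            ca_is_similar: Callable[[Doc, Doc], bool],
                            max_ca_queries: int) -> Optional[Doc]:
    """Return a center that the CA model judges similar to `target`, or None.

    Centers are ranked by cheap embedding similarity, and only the top
    `max_ca_queries` candidates are sent to the expensive CA model.
    """
    ranked = sorted(centers, key=lambda c: embed_sim(target, c), reverse=True)
    for center in ranked[:max_ca_queries]:
        if ca_is_similar(target, center):
            return center
    return None


# Toy usage: documents are numbers, and "embeddings" are the numbers themselves.
ca_calls = []


def toy_ca(a, b):
    ca_calls.append((a, b))  # record each expensive CA query
    return abs(a - b) <= 1


def toy_embed_sim(a, b):
    return -abs(a - b)


print(combo_similarity_oracle(11, [0, 10, 20], toy_embed_sim, toy_ca, 2))
print(len(ca_calls))
```

Here the embedding ranking puts center 10 first, so a single CA query suffices instead of one query per center.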

We then perform a post-processing step to merge two clusters if there is a strong connection between them, i.e., when the number of connecting edges between the two clusters is higher than the number of missing edges. Additionally, we apply the following steps for further computational savings on queries made to the CA model, and to improve performance at runtime:

  1. We leverage query-efficient correlation clustering to form a set of centers from a set of randomly selected documents instead of selecting these centers from all the documents (in the illustration below, the center nodes are red).
  2. We apply the combo similarity oracle mechanism to perform the cluster assignment step in parallel for all non-center documents and leave documents with no similar center as singletons. In the illustration below, the assignments are depicted by blue arrows and initially two (non-center) nodes are left as singletons due to no assignment.
  3. In the post-processing step, to ensure scalability, we use the embedding similarity scores to filter down the potential mergers (in the illustration below, the green dashed boundaries show these merged clusters).

Illustration of progress of the clustering algorithm on a given graph instance.
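The three steps above can be sketched end to end as a simplified, KwikBucks-style pipeline. This is an illustration under stated simplifications, not the paper's algorithm: centers are chosen pivot-style from a random sample using a ranking-then-verify oracle (repeated here so the block runs on its own), and because that oracle checks at most `max_queries` centers, two centers could occasionally be similar; the merge step queries the CA model for every cross pair of a candidate cluster pair without any budget. The parameters `sample_size`, `max_queries`, and `merge_threshold` are hypothetical.

```python
import random


def combo_oracle(target, centers, embed_sim, ca_is_similar, max_queries):
    # Rank centers by embedding similarity; confirm with at most max_queries CA calls.
    ranked = sorted(centers, key=lambda c: embed_sim(target, c), reverse=True)
    for center in ranked[:max_queries]:
        if ca_is_similar(target, center):
            return center
    return None


def kwikbucks_sketch(docs, embed_sim, ca_is_similar, sample_size,
                     max_queries, merge_threshold, seed=0):
    rng = random.Random(seed)

    # Step 1: pick centers from a random sample, pivot style: a sampled doc
    # becomes a new center unless the oracle finds a similar existing center.
    centers = []
    for doc in rng.sample(docs, min(sample_size, len(docs))):
        if combo_oracle(doc, centers, embed_sim, ca_is_similar, max_queries) is None:
            centers.append(doc)

    # Step 2: assign each non-center doc to a similar center; docs with no
    # similar center are left as singletons.
    clusters = {c: [c] for c in centers}
    for doc in docs:
        if doc in clusters:
            continue
        center = combo_oracle(doc, centers, embed_sim, ca_is_similar, max_queries)
        if center is None:
            clusters[doc] = [doc]
        else:
            clusters[center].append(doc)

    # Step 3: consider merging only cluster pairs whose representatives are
    # close in embedding space; merge when CA edges outnumber missing edges.
    merged = True
    while merged:
        merged = False
        reps = list(clusters)
        for i, a in enumerate(reps):
            for b in reps[i + 1:]:
                if embed_sim(a, b) < merge_threshold:
                    continue
                pairs = [(x, y) for x in clusters[a] for y in clusters[b]]
                edges = sum(ca_is_similar(x, y) for x, y in pairs)
                if edges > len(pairs) - edges:
                    clusters[a].extend(clusters.pop(b))
                    merged = True
                    break
            if merged:
                break
    return list(clusters.values())


docs = [0, 1, 2, 10, 11, 12, 30]
result = kwikbucks_sketch(docs, embed_sim=lambda a, b: -abs(a - b),
                          ca_is_similar=lambda a, b: abs(a - b) <= 2,
                          sample_size=len(docs), max_queries=2,
                          merge_threshold=-5)
print(sorted(sorted(c) for c in result))
```

The embedding-based filter in step 3 is what keeps the merge step from comparing every pair of clusters with the CA model.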


We evaluate the novel clustering algorithm on various datasets with different properties, using different embedding-based and cross-attention–based models. We compare the clustering algorithm’s performance with the two best-performing baselines (see the paper for more details).

To evaluate the quality of clustering, we use precision and recall. Precision is the percentage of co-clustered pairs that are similar, and recall is the percentage of similar pairs that are co-clustered. To measure the quality of the solutions obtained in our experiments, we use the F1-score, the harmonic mean of precision and recall: 1.0 is the highest possible value and indicates perfect precision and recall, while 0 is the lowest possible value and indicates that precision or recall is zero. The table below reports the F1-score for KwikBucks and various baselines when only a linear number of queries to the CA model is allowed. KwikBucks offers a substantial boost in performance, with a 45% relative improvement over the best baseline when averaged across all datasets.

Comparing the clustering algorithm to two baseline algorithms using various public datasets: (1) The query-efficient correlation clustering algorithm for budgeted clustering with access to CA only, and (2) spectral clustering on the k-nearest neighbor (kNN) graph formed by querying the CA model for the k-nearest neighbors of each vertex from embedding-based similarity. Pre-processed datasets can be downloaded here.
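The pairwise precision, recall, and F1 definitions above can be computed directly from a clustering and a set of ground-truth similar pairs. The sketch below is a straightforward reading of those definitions; the function name and the representation of pairs as frozensets are our own choices, and the toy data is made up.

```python
from itertools import combinations


def pairwise_scores(clusters, similar_pairs):
    """Pairwise precision, recall and F1 of a clustering.

    clusters: list of lists of document ids.
    similar_pairs: set of frozensets, the ground-truth similar pairs.
    """
    co_clustered = {frozenset(p) for c in clusters for p in combinations(c, 2)}
    true_pos = len(co_clustered & similar_pairs)
    # Precision: fraction of co-clustered pairs that are truly similar.
    precision = true_pos / len(co_clustered) if co_clustered else 0.0
    # Recall: fraction of truly similar pairs that ended up co-clustered.
    recall = true_pos / len(similar_pairs) if similar_pairs else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


clusters = [["a", "b", "c"], ["d"]]
truth = {frozenset(("a", "b")), frozenset(("c", "d"))}
print(pairwise_scores(clusters, truth))
```

In this toy case one of three co-clustered pairs is correct and one of two similar pairs is recovered, giving precision 1/3, recall 1/2, and F1 0.4.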

The figure below compares the clustering algorithm’s performance with baselines using different query budgets. We observe that KwikBucks consistently outperforms other baselines at various budgets.

A comparison of KwikBucks with top-2 baselines when allowed different budgets for querying the cross-attention model.


Text clustering often presents a dilemma in the choice of similarity function: embedding models are scalable but lack quality, while cross-attention models offer quality but substantially hurt scalability. We present a clustering algorithm that offers the best of both worlds, the scalability of embedding models and the quality of cross-attention models, as validated by an extensive set of experiments on various datasets with diverse properties. KwikBucks can also be applied to other clustering problems with multiple similarity oracles of varying accuracy levels. See the paper for more details.


This project was initiated during Sandeep Silwal’s summer internship at Google in 2022. We would like to express our gratitude to our co-authors, Andrew McCallum, Andrew Nystrom, Deepak Ramachandran, and Sandeep Silwal, for their valuable contributions to this work. We also thank Ravi Kumar and John Guilyard for assistance with this blog post.

‘Starship for the Mind’: University of Florida Opens Malachowsky Hall, an Epicenter for AI and Data Science

Embodying the convergence of AI and academia, the University of Florida Friday inaugurated the Malachowsky Hall for Data Science & Information Technology.

The sleek, seven-story building is poised to play a pivotal role in UF’s ongoing efforts to harness the transformative power of AI, reaffirming its stature as one of the nation’s leading public universities.

Evoking Apple co-founder Steve Jobs’ iconic description of a personal computer, NVIDIA’s founder and CEO Jensen Huang described Malachowsky Hall — named for NVIDIA co-founder Chris Malachowsky — and the HiPerGator AI supercomputer it hosts as a “starship for knowledge discovery.”

“Steve Jobs called (the PC) ‘the bicycle of the mind,’ a device that propels our thoughts further and faster,” Huang said.

“What Chris Malachowsky has gifted this institution is nothing short of the ‘starship of the mind’ — a vehicle that promises to take our intellect to uncharted territories,” Huang said.

The inauguration of the 260,000-square-foot structure marks a milestone in the partnership between UF alum Malachowsky, NVIDIA and the state of Florida — a collaboration that has propelled UF to the forefront of AI innovation.

Malachowsky and NVIDIA both made major contributions toward its construction, bolstered by a $110 million investment from the state of Florida.

University of Florida President Ben Sasse and NVIDIA CEO Jensen Huang speak at the opening of Malachowsky Hall.

Following the opening, Huang and UF’s new president, Ben Sasse, met to discuss the impact of AI and data science across UF and beyond for students just starting their careers.

Sasse underscored the importance of adaptability in a rapidly changing world, telling the audience: “work in lots and lots of different organizations … because lifelong work in any one, not just firm, but any one industry is going to end in our lives. You’re ultimately going to have to figure out how to reinvent yourselves at 30, 35, 40 and 45.”

Huang offered students very different advice, recalling how he met his wife, Lori, who was in the audience, as an undergraduate. “Have a good pickup line … do you want to know the pickup line?” Huang asked, pausing a beat. “You want to see my homework?”

The spirit of Sasse and Huang’s adaptable approach to career and personal development is embodied in Malachowsky Hall, designed to bring together people from academia and industry, research and government.

Packed with innovative collaboration spaces and labs, the hall features a spacious 400-seat auditorium, dedicated high-performance computing study spaces and a rooftop terrace that unveils panoramic campus vistas.

Sustainability is woven into its design, highlighted by energy-efficient systems and rainwater harvesting facilities.

Malachowsky Hall will serve as a conduit to bring the on-campus advances in AI to Florida’s thriving economy, which continues to outpace the nation in jobs and GDP growth.

Together, NVIDIA founder and UF alumnus Chris Malachowsky and NVIDIA donated $50 million toward the University of Florida’s HiPerGator AI supercomputer.

UF’s efforts to bring AI and academia together, catalyzed by support from Malachowsky and NVIDIA, go far beyond Malachowsky Hall.

In 2020, UF announced that Malachowsky and NVIDIA together donated $50 million toward HiPerGator, one of the most powerful AI supercomputers in the country.

With additional state support, UF recently added more than 110 AI faculty members to the 300 already engaged in AI teaching and research.

As a result, UF extended AI-focused courses, workshops and projects across the university, enabling its 55,000 students to delve into AI and its interdisciplinary applications.

Friday’s ribbon-cutting will open exciting new opportunities for the university, its students and the state of Florida to realize the potential of AI innovations across sectors.

Huang likened pursuing knowledge through AI to embarking on a “starship.” “You’ve got to go as far as you can,” he urged students.

For a deeper exploration of Malachowsky Hall and UF’s groundbreaking AI initiatives, visit UF’s website.