Information extraction with LLMs using Amazon SageMaker JumpStart

Large language models (LLMs) have unlocked new possibilities for extracting information from unstructured text data. Although much of the current excitement is around LLMs for generative AI tasks, many of the key use cases that you might want to solve have not fundamentally changed. Tasks such as routing support tickets, recognizing customer intents from a chatbot conversation session, extracting key entities from contracts, invoices, and other types of documents, and analyzing customer feedback are examples of long-standing needs.

What makes LLMs so transformative, however, is their ability to achieve state-of-the-art results on these common tasks with minimal data and simple prompting, and their ability to multitask. Rather than requiring extensive feature engineering and dataset labeling, LLMs can be fine-tuned on small amounts of domain-specific data to quickly adapt to new use cases. By handling most of the heavy lifting, services like Amazon SageMaker JumpStart remove the complexity of fine-tuning and deploying these models.

SageMaker JumpStart is a machine learning (ML) hub with foundation models (FMs), built-in algorithms, and prebuilt ML solutions that you can deploy with just a few clicks. With SageMaker JumpStart, you can evaluate, compare, and select FMs quickly based on predefined quality and responsibility metrics to perform tasks like article summarization and image generation.

This post walks through examples of building information extraction use cases by combining LLMs with prompt engineering and frameworks such as LangChain. We also examine the uplift from fine-tuning an LLM for a specific extractive task. Whether you’re looking to classify documents, extract keywords, detect and redact personally identifiable information (PII), or parse semantic relationships, you can start ideating your use case and use LLMs for your natural language processing (NLP) tasks.

Prompt engineering

Prompt engineering enables you to instruct LLMs to generate suggestions, explanations, or completions of text in an interactive way. Prompt engineering relies on large pretrained language models that have been trained on massive amounts of text data. There might not be one best way to design a prompt, and different LLMs might work better or worse with different prompts. Therefore, prompts are often iteratively refined through trial and error to produce better results. As a starting point, you can refer to the model documentation, which typically includes recommendations and best practices for prompting the model, as well as the examples provided in SageMaker JumpStart.

In the following sections, we focus on the prompt engineering techniques required for extractive use cases. They help unlock the power of LLMs by providing helpful constraints that guide the model toward its intended behavior. We discuss the following use cases:

  • Sensitive information detection and redaction
  • Entity extraction, including generic and specific entities in structured formats
  • Classification, using prompt engineering and fine-tuning

Before we explore these use cases, we need to set up our development environment.

Prerequisites

The source code accompanying this example is available in this GitHub repo. It consists of several Jupyter notebooks and a utils.py module. The utils.py module houses the shared code that is used throughout the notebooks.

The simplest way to run this example is by using Amazon SageMaker Studio with the Data Science 3.0 kernel or an Amazon SageMaker notebook instance with the conda_python3 kernel. For the instance type, you can choose the default settings.

In this example, we use ml.g5.2xlarge and ml.g5.48xlarge instances for endpoint usage, and ml.g5.24xlarge for training job usage. Use the Service Quotas console to make sure you have sufficient quotas for these instances in the Region where you’re running this example.

We use Jupyter notebooks throughout this post. Before we explore the examples, it’s crucial to confirm that you have the latest version of the SageMaker Python SDK. This SDK offers a user-friendly interface for training and deploying models on SageMaker. To install or upgrade to the latest version, run the following command in the first cell of your Jupyter notebook:

%pip install --quiet --upgrade sagemaker

Deploy Llama-2-70b-chat using SageMaker JumpStart

There are many LLMs available in SageMaker JumpStart to choose from. In this example, we use Llama-2-70b-chat, but you might use a different model depending on your use case. To explore the list of SageMaker JumpStart models, see JumpStart Available Model Table.

To deploy a model from SageMaker JumpStart, you can use either APIs, as demonstrated in this post, or the SageMaker Studio UI. After the model is deployed, you can test it by asking the model a question:

from sagemaker import get_execution_role
from sagemaker.jumpstart.model import JumpStartModel

model_id, model_version = "meta-textgeneration-llama-2-70b-f", "2.*"
endpoint_name = model_id
instance_type = "ml.g5.48xlarge"

# IAM role that SageMaker assumes when creating the endpoint
role_arn = get_execution_role()

model = JumpStartModel(
    model_id=model_id, model_version=model_version, role=role_arn
)
predictor = model.deploy(
    endpoint_name=endpoint_name, instance_type=instance_type
)

If no instance_type is provided, the SageMaker JumpStart SDK will select the default type. In this example, you explicitly set the instance type to ml.g5.48xlarge.
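
After the endpoint is in service, you can send a test question directly to it. The following is a minimal sketch; the payload structure mirrors the llama2_chat helper shown later in this post, and the question text is only an illustration.

payload = {
    "inputs": [[{"role": "user", "content": "What is Amazon SageMaker JumpStart?"}]],
    "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
}
# Llama 2 models require accepting the EULA when invoking the endpoint
response = predictor.predict(payload, custom_attributes="accept_eula=true")
print(response)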

Sensitive data extraction and redaction

LLMs show promise for extracting sensitive information for redaction. Techniques such as prompt engineering can help, which involves priming the model to understand the redaction task and providing examples that can improve performance. For example, priming the model by stating “redact sensitive information” and demonstrating a few examples of redacting names, dates, and locations can help the LLM infer the rules of the task.

More in-depth forms of priming the model include providing positive and negative examples, demonstrations of common errors, and in-context learning to teach the nuances of proper redaction. With careful prompt design, LLMs can learn to redact information while maintaining readability and utility of the document. In real-life applications, however, additional evaluation is often necessary to improve the reliability and safety of LLMs for handling confidential data. This is often achieved through the inclusion of human review, because no automated approach is entirely foolproof.

The following are a few examples of using prompt engineering for the extraction and redaction of PII. The prompt consists of multiple parts: the report_sample, which contains the text in which you want to identify and mask the PII data, and the instructions (or guidance) passed to the model as the system message.

report_sample = """
This month at AnyCompany, we have seen a significant surge in orders from a diverse clientele. On November 5th, 2023, customer Alice from US placed an order with total of $2190. Following her, on Nov 7th, Bob from UK ordered a bulk set of twenty-five ergonomic keyboards for his office setup with total of $1000. The trend continued with Jane from Australia, who on Nov 12th requested a shipment of ten high-definition monitors with total of $9000, emphasizing the need for environmentally friendly packaging. On the last day of that month, customer John, located in Singapore, finalized an order for fifteen USB-C docking stations, aiming to equip his design studio with the latest technology for total of $3600.
"""

system = """
Your task is to precisely identify Personally Identifiable Information (PII) and identifiable details, including name, address, and the person's country, in the provided text. Replace these details with exactly four asterisks (****) as the masking characters. Use '****' for masking text of any length. Only write the masked text in the response.
"""

In the following example, you define the llama2_chat function that encapsulates sending the prompt to the Llama-2 model. You reuse this function throughout the examples.

def llama2_chat(
    predictor,
    user,
    temperature=0.1,
    max_tokens=512,
    top_p=0.9,
    system=None,
):
    """Constructs the payload for the llama2 model, sends it to the endpoint,
    and returns the response."""

    inputs = []
    if system:
        inputs.append({"role": "system", "content": system})
    if user:
        inputs.append({"role": "user", "content": user})

    payload = {
        "inputs": [inputs],
        "parameters": {
            "max_new_tokens": max_tokens,
            "top_p": top_p,
            "temperature": temperature,
        },
    }
    response = predictor.predict(payload, custom_attributes="accept_eula=true")
    return response

Use the following code to call the function, passing your parameters:

response = utils.llama2_chat(
    predictor,
    system=system,
    user=report_sample,
)
print(utils.llama2_parse_output(response))

You get the following output:

This month at AnyCompany, we have seen a significant surge in orders from a diverse clientele. On November 5th, 2023, customer ***** from ***** placed an order with total of $2190. Following her, on Nov 7th, ***** from ***** ordered a bulk set of twenty-five ergonomic keyboards for his office setup with total of $1000. The trend continued with ***** from *****, who on Nov 12th requested a shipment of ten high-definition monitors with total of $9000, emphasizing the need for environmentally friendly packaging. On the last day of that month, customer *****, located in *****, finalized an order for fifteen USB-C docking stations, aiming to equip his design studio with the latest technology for total of $3600.

Entity extraction

Entity extraction is the process of identifying and extracting key information entities from unstructured text. This technique helps create structured data from unstructured text and provides useful contextual information for many downstream NLP tasks. Common applications for entity extraction include building a knowledge base, extracting metadata to use for personalization or search, and improving user inputs and conversation understanding within chatbots.

You can effectively use LLMs for entity extraction tasks through careful prompt engineering. With a few examples of extracting entities from text, explanatory prompts, and the desired output format, the model can learn to identify and extract entities such as people, organizations, and locations from new input texts. In the following examples, we demonstrate a few different entity extraction tasks ranging from simpler to more complex using prompt engineering with the Llama-2-70b-chat model you deployed earlier.

Extract generic entities

Use the following code to extract generic entities, in this case email addresses, from a sample text:

email_sample = "Hello, My name is John. Your AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0008 has a minimum payment of $24.53 that is due by July 31st. Based on your autopay settings, we will withdraw your payment on the due date from your bank account number XXXXXX1111 with the routing number XXXXX0000. Customer feedback for Sunshine Spa, 123 Main St, Anywhere. Send comments to Alice at alice_aa@anycompany.com and Bob at bob_bb@anycompany.com. I enjoyed visiting the spa. It was very comfortable but it was also very expensive. The amenities were ok but the service made the spa a great experience."

system = """
Your task is to precisely identify any email addresses from the given text and then write them, one per line. Remember to ONLY write an email address if it's precisely spelled out in the input text. If there are no email addresses in the text, write "N/A". DO NOT write anything else.
"""

result = utils.llama2_chat(predictor, system=system, user=email_sample)
print(utils.llama2_parse_output(result))

You get the following output:

alice_aa@anycompany.com
bob_bb@anycompany.com

Extract specific entities in a structured format

Using the previous sample report, you can extract more complex information in a structured manner. This time, you provide a JSON template for the model to use and return the output in JSON format.

With LLMs generating JSON documents as output, you can effortlessly parse them into a range of other data structures. This enables simple conversions to dictionaries, YAML, or even Pydantic models using third-party libraries, such as LangChain’s PydanticOutputParser. You can see the implementation in the GitHub repo.

import json

system = """
Your task is to precisely extract information from the text provided, and format it according to the given JSON schema delimited with triple backticks. Only include the JSON output in your response. If a specific field has no available data, indicate this by writing `null` as the value for that field in the output JSON. In cases where there is no data available at all, return an empty JSON object. Avoid including any other statements in the response.

```
{json_schema}
```
"""

json_schema = """
{
    "orders":
        [
            {
                "name": "<customer_name>",
                "location": "<customer_location>",
                "order_date": "<order_date in format YYYY-MM-DD>",
                "order_total": "<order_total>",
                "order_items": [
                    {
                        "item_name": "<item_name>",
                        "item_quantity": "<item_quantity>"
                    }
                ]
            }
        ]
}
"""


response = utils.llama2_chat(
    predictor,
    system=system.format(json_schema=json_schema),
    user=report_sample,
)
json_str = utils.llama2_parse_output(response)
print(json_str)

You get the following output:

{
    "orders": [
        {
            "name": "Alice",
            "location": "US",
            "order_date": "2023-11-05",
            "order_total": 2190,
            "order_items": [
                {
                    "item_name": null,
                    "item_quantity": null
                }
            ]
        },
        {
            "name": "Bob",
            "location": "UK",
            "order_date": "2023-11-07",
            "order_total": 1000,
            "order_items": [
                {
                    "item_name": "ergonomic keyboards",
                    "item_quantity": 25
                }
            ]
        },
        {
            "name": "Jane",
            "location": "Australia",
            "order_date": "2023-11-12",
            "order_total": 9000,
            "order_items": [
                {
                    "item_name": "high-definition monitors",
                    "item_quantity": 10
                }
            ]
        },
        {
            "name": "John",
            "location": "Singapore",
            "order_date": "2023-11-30",
            "order_total": 3600,
            "order_items": [
                {
                    "item_name": "USB-C docking stations",
                    "item_quantity": 15
                }
            ]
        }
    ]
}
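
Because the response is plain JSON, you can validate it into typed Python objects. The following is a minimal sketch using Pydantic; the model classes are hypothetical and simply mirror the JSON schema used in this example.

from typing import List, Optional
from pydantic import BaseModel

class OrderItem(BaseModel):
    item_name: Optional[str] = None
    item_quantity: Optional[int] = None

class Order(BaseModel):
    name: str
    location: str
    order_date: str
    order_total: float
    order_items: List[OrderItem]

class Orders(BaseModel):
    orders: List[Order]

# json_str is the model output parsed in the previous code block
orders = Orders(**json.loads(json_str))
print(orders.orders[0].name)  # "Alice"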

Classification using prompt engineering

LLMs can be a useful tool for information extraction tasks such as text classification. Common applications include classifying the intents of user interactions via channels such as email, chatbots, voice, and others, or categorizing documents to route their requests to downstream systems. The initial step involves identifying the intent or class of the user’s request or the document. These intents or classes could take many forms—from short single words to thousands of hierarchical classes and sub-classes.

In the following examples, we demonstrate prompt engineering on synthetic conversation data to extract intents. Additionally, we show how pre-trained models can be assessed to determine if fine-tuning is needed.

Let’s start with the following example. You have a list of customer interactions with an imaginary health and life insurance company. To start, use the Llama-2-70b-chat model you deployed in the previous section:

inference_instance_type = "ml.g5.48xlarge"

# Llama-2-70b chat
model_id, model_version = "meta-textgeneration-llama-2-70b-f", "2.*"
endpoint_name = model_id

predictor = utils.get_predictor(
    endpoint_name=endpoint_name,
    model_id=model_id,
    model_version=model_version,
    inference_instance_type=inference_instance_type,
)

The get_predictor function is a helper function that creates a predictor object from a model ID and version. If the specified endpoint doesn’t exist, it creates a new endpoint and deploys the model. If the endpoint already exists, it uses the existing endpoint.
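
The following is a possible sketch of such a helper, shown here for illustration only; the actual implementation in the repository’s utils.py may differ. It assumes the endpoint accepts and returns JSON.

import boto3
from botocore.exceptions import ClientError
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

def get_predictor(endpoint_name, model_id, model_version, inference_instance_type):
    """Return a predictor for endpoint_name, deploying the JumpStart model
    only if the endpoint does not already exist."""
    try:
        boto3.client("sagemaker").describe_endpoint(EndpointName=endpoint_name)
        # The endpoint exists; attach to it with JSON serialization.
        return Predictor(
            endpoint_name=endpoint_name,
            serializer=JSONSerializer(),
            deserializer=JSONDeserializer(),
        )
    except ClientError:
        # The endpoint doesn't exist; deploy the model to a new endpoint.
        model = JumpStartModel(model_id=model_id, model_version=model_version)
        return model.deploy(
            endpoint_name=endpoint_name, instance_type=inference_instance_type
        )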

customer_interactions = [
    """Hello, I've recently moved to a new state and I need to update my address for my health insurance policy.
Can you assist me with that?
""",
    """Good afternoon! I'm interested in adding dental coverage to my existing health plan.
Could you provide me the options and prices?
""",
    """I had a disappointing experience with the customer service yesterday regarding my claim.
I want to file a formal complaint and speak with a supervisor.
""",
]

system = """
Your task is to identify the customer intent from their interactions with the support bot in the provided text. The intent output must not be more than 4 words. If the intent is not clear, please provide a fallback intent of "unknown".
"""

def get_intent(system, customer_interactions):
    for customer_interaction in customer_interactions:
        response = utils.llama2_chat(
            predictor,
            system=system,
            user=customer_interaction,
        )
        content = utils.llama2_parse_output(response)
        print(content)
get_intent(system, customer_interactions)

You get the following output:

Update Address
Intent: Informational
Intent: Escalate issue

Looking at the output, these intents seem reasonable. However, their format and style can vary depending on the language model. Another limitation of this approach is that the intents are not confined to a predefined list, which means the language model might generate and word them differently each time you run it.

To address this, you can use the in-context learning technique in prompt engineering to steer the model towards selecting from a predefined set of intents, or class labels, that you provide. In the following example, alongside the customer conversation, you include a list of potential intents and ask the model to choose from this list:

system = """
Your task is to identify the intent from the customer interaction with the support bot. Select from the intents provided in the following list delimited with ####. If the intent is not clear, please provide a fallback intent of "unknown". ONLY write the intent.

####
- information change
- add coverage
- complaint
- portal navigation
- free product upgrade
####
"""

get_intent(system, customer_interactions)

You get the following output:

information change
add coverage
complaint

Reviewing the results, it’s evident that the language model performs well in selecting the appropriate intent in the desired format.

Sub-intents and intent trees

If you make the preceding scenario more complex, as in many real-life use cases, intents can span a large number of categories and be organized hierarchically, which makes the classification task more challenging for the model. Therefore, you can further improve and modify your prompt to provide an example to the model, also known as n-shot learning, k-shot learning, or few-shot learning.

The following is the intent tree to use in this example. You can find its source code in the utils.py file in the code repository.

INTENTS = [
    {
        "main_intent": "profile_update",
        "sub_intents": [
            "contact_info",
            "payment_info",
            "members",
        ],
    },
    {
        "main_intent": "health_cover",
        "sub_intents": [
            "add_extras",
            "add_hospital",
            "remove_extras",
            "remove_hospital",
            "new_policy",
            "cancel_policy",
        ],
    },
    {
        "main_intent": "life_cover",
        "sub_intents": [
            "new_policy",
            "cancel_policy",
            "beneficiary_info",
        ],
    },
    {
        "main_intent": "customer_retention",
        "sub_intents": [
            "complaint",
            "escalation",
            "free_product_upgrade",
        ],
    },
    {
        "main_intent": "technical_support",
        "sub_intents": [
            "portal_navigation",
            "login_issues",
        ],
    },
]

Using the following prompt (which includes the intents), you can ask the model to pick from the provided list of intents:

system = """
Your task is to identify the intent from the customer interaction with the support bot. Identify the intent of the provided text using the list of provided intent tree delimited with ####. The intents are defined in classes and sub-classes. Write the intention with this format: <main-intent>:<sub-intent>. ONLY write the intent.

OUTPUT EXAMPLE:
profile_update:contact_info

OUTPUT EXAMPLE:
customer_retention:complaint

####
{intents}
####
"""

intents_json = json.dumps(utils.INTENTS, indent=4)
system = system.format(intents=intents_json)
get_intent(system, customer_interactions)

You get the following output:

profile_update:contact_info
health_cover:add_extras
customer_retention:complaint

Although LLMs can often correctly identify intent from a list of possible intents, they may sometimes produce additional outputs or fail to adhere to the exact intent structure and output schema. There are also scenarios where intents are not as straightforward as they initially seem or are highly specific to a business domain context that the model doesn’t fully comprehend.

As an example, in the following sample interaction, the customer ultimately wants to change their coverage, but their immediate question and interaction intent is to get help with portal navigation. Similarly, in the second interaction, the more appropriate intent is “free product upgrade,” which is what the customer is requesting. However, the model is unable to detect these nuanced intents as accurately as desired.

customer_interactions = [
    "I want to change my coverage plan. But I'm not seeing where to do this on the online website. Could you please point me to it?",
    "I'm unhappy with the current benefits of my plan and I'm considering canceling unless there are better alternatives. What can you offer?",
]

get_intent(system, customer_interactions)

You get the following output:

profile_update:contact_info
customer_retention:complaint

Prompt engineering can often successfully extract specific intents from text. However, for some use cases, relying solely on prompt engineering has limitations. Scenarios where additional techniques beyond prompt engineering may be needed include:

  • Conversations with a large number of intent classes or long contexts that exceed the language model’s context window, or that make queries more computationally expensive
  • Desired outputs in specific formats that the model struggles to adopt
  • Enhancing model understanding of the domain or task to boost performance

In the following section, we demonstrate how fine-tuning can boost the accuracy of the LLM for the intent classification task attempted earlier.

Fine-tuning an LLM for classification

The following sections detail the fine-tuning process of the FlanT5-XL and Mistral 7B models using SageMaker JumpStart. We use the FlanT5-XL and Mistral 7B models to compare their accuracy. Both models are significantly smaller than Llama-2-70b-chat. The goal is to determine whether smaller models can achieve state-of-the-art performance on specific tasks after they’re fine-tuned.

We have fine-tuned both the Mistral 7B and FlanT5-XL models. You can see the details of the Mistral 7B fine-tuning in the code repository. In the following, we outline the steps for fine-tuning and evaluating FlanT5-XL.

Initially, you deploy (or reuse) the FlanT5 endpoint as the base_predictor, which represents the base model prior to any fine-tuning. Subsequently, you compare the performance of the base and fine-tuned models.

inference_instance_type = "ml.g5.2xlarge"

model_id, model_version = "huggingface-text2text-flan-t5-xl", "2.0.0"
base_endpoint_name = model_id

base_predictor = utils.get_predictor(
    endpoint_name=base_endpoint_name,
    model_id=model_id,
    model_version=model_version,
    inference_instance_type=inference_instance_type,
)

Prepare training data for fine-tuning

Preparing for fine-tuning requires organizing several files, including the dataset and template files. The dataset is structured to align with the required input format for fine-tuning. For example, each record in our training dataset adheres to the following structure:

{"query": "customer query", "response": "main-intent:sub-intent"}

In this example, you use a synthesized dataset comprising customer interactions with a fictional insurance company. To learn more about the data and gain access to it, refer to the source code.
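
For illustration, a single training record might look like the following (the query text is hypothetical; the label follows the main-intent:sub-intent format from the intent tree defined earlier):

{"query": "Hi, I'd like to add dental extras to my current health policy.", "response": "health_cover:add_extras"}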

intent_dataset_file = "data/intent_dataset.jsonl"
intent_dataset_train_file = "data/intent_dataset_train.jsonl"
intent_dataset_test_file = "data/intent_dataset_test.jsonl"
ft_template_file = "data/template.json"

The following is the prompt for fine-tuning. The prompt has the query parameter, which is set during the fine-tuning using the SageMaker JumpStart SDK.

FT_PROMPT = """Identify the intent classes from the given user query, delimited with ####. Intents are categorized into two levels: main intent and sub intent. In your response, provide only ONE set of main and sub intents that is most relevant to the query. Write your response ONLY in this format <main-intent>:<sub-intent>. ONLY Write the intention.

OUTPUT EXAMPLE:
profile_update:contact_info

OUTPUT EXAMPLE:
technical_support:portal_navigation

#### QUERY:
{query}
####
"""

The following code creates a template file that the SageMaker JumpStart framework uses to fine-tune the model. The template has two fields, prompt and completion. These fields pass the labeled data to the model during the fine-tuning process.

template = {
    "prompt": utils.FT_PROMPT,
    "completion": "{response}",
}

with open(ft_template_file, "w") as f:
    json.dump(template, f)

The training data is uploaded to an Amazon Simple Storage Service (Amazon S3) bucket, setting the stage for the actual fine-tuning process.

train_data_location = utils.upload_train_and_template_to_s3(
    bucket_prefix="intent_dataset_flant5",
    train_path=intent_dataset_train_file,
    template_path=ft_template_file,
)
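
The upload_train_and_template_to_s3 helper lives in utils.py. A minimal sketch of what it might do, assuming the default SageMaker bucket, looks like the following (the actual implementation in the repository may differ):

import sagemaker

def upload_train_and_template_to_s3(bucket_prefix, train_path, template_path):
    """Upload the training data and template file to Amazon S3 and return the S3 prefix."""
    session = sagemaker.Session()
    bucket = session.default_bucket()
    session.upload_data(train_path, bucket=bucket, key_prefix=bucket_prefix)
    session.upload_data(template_path, bucket=bucket, key_prefix=bucket_prefix)
    return f"s3://{bucket}/{bucket_prefix}"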

Fine-tune the model

Configure the JumpStartEstimator, specifying your chosen model and other parameters like instance type and hyperparameters (in this example, you use five epochs for the training). This estimator drives the fine-tuning process.

from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator(
    model_id=model_id,
    disable_output_compression=True,
    instance_type="ml.g5.24xlarge",
    role=utils.get_role_arn(),
)

estimator.set_hyperparameters(
    instruction_tuned="True", epochs="5", max_input_length="1024"
)

estimator.fit({"training": train_data_location})

Deploy the fine-tuned model

After fine-tuning, deploy the fine-tuned model:

finetuned_endpoint_name = "flan-t5-xl-ft-infoext"
finetuned_model_name = finetuned_endpoint_name
# Deploying the finetuned model to an endpoint
finetuned_predictor = estimator.deploy(
    endpoint_name=finetuned_endpoint_name,
    model_name=finetuned_model_name,
)

Use the following code to test the fine-tuned model against its base model with ambiguous queries, which you saw in the previous section:

ambiguous_queries = [
    {
        "query": "I want to change my coverage plan. But I'm not seeing where to do this on the online site. Could you please show me how?",
        "main_intent": "techincal_support",
        "sub_intent": "portal_navigation",
    },
    {
        "query": "I'm unhappy with the current benefits of my plan and I'm considering canceling unless there are better alternatives. What can you offer?",
        "main_intent": "customer_retention",
        "sub_intent": "free_product_upgrade",
    },
]
for query in ambiguous_queries:
    question = query["query"]
    print("query:", question, "n")
    print(
        "expected intent:  ", f"{query['main_intent']}:{query['sub_intent']}"
    )

    prompt = utils.FT_PROMPT.format(query=question)
    response = utils.flant5(base_predictor, user=prompt, max_tokens=13)
    print("base model:  ", utils.parse_output(response))

    response = utils.flant5(finetuned_predictor, user=prompt, max_tokens=13)
    print("finetuned model:  ", utils.parse_output(response))
    print("-" * 80)

You get the following output:

query: I want to change my coverage plan. But I'm not seeing where to do this on the online site. Could you please show me how?
expected intent:   technical_support:portal_navigation
base model:   main_intent>:sub_intent> change
finetuned model:   technical_support:portal_navigation
--------------------------------------------------------------------------------
query: I'm unhappy with the current benefits of my plan and I'm considering canceling unless there are better alternatives. What can you offer?

expected intent:   customer_retention:free_product_upgrade
base model:   main_intent>:sub_intent> cancel
finetuned model:   customer_retention:free_product_upgrade
--------------------------------------------------------------------------------

As shown in this example, the fine-tuned model is able to classify the ambiguous queries correctly.

In evaluations, fine-tuned models performed better in identifying the correct class for both clear and ambiguous intents. The following section details the benchmark’s performance overall, and against each intent.

Performance comparisons and considerations

In this section, we have gathered the evaluation results and performance benchmarks for each model, before and after fine-tuning, as well as a comparison between prompt engineering and fine-tuning the LLM. The dataset consists of 7,824 examples, with a 90% split for training (including validation) and 10% for testing.
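
Accuracy here is measured as an exact match between the generated main-intent:sub-intent string and the label in the test set. A minimal sketch of such an evaluation loop, reusing the helpers and dataset files introduced earlier in this post (shown for FlanT5-XL), might look like the following:

def evaluate_intent_accuracy(predictor, test_file):
    """Compute exact-match accuracy of predicted intents over a JSONL test set."""
    correct, total = 0, 0
    with open(test_file) as f:
        for line in f:
            record = json.loads(line)
            prompt = utils.FT_PROMPT.format(query=record["query"])
            response = utils.flant5(predictor, user=prompt, max_tokens=13)
            prediction = utils.parse_output(response).strip()
            correct += int(prediction == record["response"])
            total += 1
    return correct / total

print(evaluate_intent_accuracy(finetuned_predictor, intent_dataset_test_file))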

| Model | Overall Accuracy | Fine-tuning Duration (minutes) | Notes |
| --- | --- | --- | --- |
| Mistral-7b (fine-tuned, five epochs, without classes in the prompt) | 98.97% | 720 | Given Mistral-7b’s nature as a text generation model, parsing its output to extract intent can be challenging due to tendencies for character repetition and generation of additional characters. Performance improved with more epochs: 98% accuracy for five epochs compared to 92% for one epoch. |
| Flan-T5-XL (fine-tuned, five epochs, without classes in the prompt) | 98.46% | 150 | Marginal improvement in accuracy with increased epochs: from 97.5% (one epoch) to 98.46% (five epochs). |
| Llama-2-70b-chat (with classes in the prompt) | 78.42% | N/A | Low accuracy in ambiguous scenarios. |
| Llama-2-70b-chat (without classes in the prompt) | 10.85% | N/A | |
| Flan-T5-XL (base model, without classes in the prompt) | 0.0% | N/A | Unable to identify any of the intent classes with the expected format. |
| Mistral-7b (base model, without classes in the prompt) | 0.0% | N/A | Unable to identify any of the intent classes with the expected format. |

The following table contains a breakdown of models’ accuracy for each intent class.

| Main Intent | Sub-intent | Example Count | Llama2-70b (without classes in prompt) | Llama2-70b (with classes in prompt) | Flan-T5-XL (fine-tuned) | Mistral-7b (fine-tuned) |
| --- | --- | --- | --- | --- | --- | --- |
| Customer Retention | Complaint | 63 | 7.94% | 44.44% | 98.41% | 98.41% |
| Customer Retention | Escalation | 49 | 91.84% | 100% | 100% | 100% |
| Customer Retention | Free Product Upgrade | 50 | 0.00% | 64.00% | 100% | 100% |
| Health Cover | Add Extras | 38 | 0.00% | 100% | 97.37% | 100% |
| Health Cover | Add Hospital | 44 | 0.00% | 81.82% | 100% | 97.73% |
| Health Cover | Cancel Policy | 43 | 0.00% | 100% | 100% | 97.67% |
| Health Cover | New Policy | 41 | 0.00% | 82.93% | 100% | 100% |
| Health Cover | Remove Extras | 47 | 0.00% | 85.11% | 100% | 100% |
| Health Cover | Remove Hospital | 53 | 0.00% | 84.90% | 100% | 100% |
| Life Cover | Beneficiary Info | 45 | 0.00% | 100% | 97.78% | 97.78% |
| Life Cover | Cancel Policy | 47 | 0.00% | 55.32% | 100% | 100% |
| Life Cover | New Policy | 40 | 0.00% | 90.00% | 92.50% | 100% |
| Profile Update | Contact Info | 45 | 35.56% | 95.56% | 95.56% | 95.56% |
| Profile Update | Members | 52 | 0.00% | 36.54% | 98.08% | 98.08% |
| Profile Update | Payment Info | 47 | 40.43% | 97.87% | 100% | 100% |
| Technical Support | Login Issues | 39 | 0.00% | 92.31% | 97.44% | 100% |
| Technical Support | Portal Navigation | 40 | 0.00% | 45.00% | 95.00% | 97.50% |

This comparative analysis illustrates the trade-offs between fine-tuning time and model accuracy. It highlights the ability of models like Mistral-7b and FlanT5-XL to achieve higher classification accuracy through fine-tuning. Additionally, it shows how smaller models can match or surpass the performance of larger models on specific tasks when fine-tuned, contrasted with using prompt engineering alone on the larger models.

Clean up

Complete the following steps to clean up your resources:

  1. Delete the SageMaker endpoints, configuration, and models.
  2. Delete the S3 bucket created for this example.
  3. Delete the SageMaker notebook instance (if you used one to run this example).

Summary

Large language models have revolutionized information extraction from unstructured text data. These models excel in tasks such as classifying information and extracting key entities from various documents, achieving state-of-the-art results with minimal data.

This post demonstrated the use of large language models for information extraction through prompt engineering and fine-tuning. While effective, relying solely on prompt engineering can have limitations for complex tasks that require rigid output formats or a large number of classes. In these scenarios, fine-tuning even smaller models on domain-specific data can significantly improve performance beyond what prompt engineering alone can achieve.

The post included practical examples highlighting how fine-tuned smaller models can surpass prompt engineering with larger models for such complex use cases. Although prompt engineering is a good starting point for simpler use cases, fine-tuning offers a more robust solution for complex information extraction tasks, ensuring higher accuracy and adaptability to specific use cases. SageMaker JumpStart tools and services facilitate this process, making it accessible for individuals and teams across all levels of ML expertise.

Additional reading

You can read more on using SageMaker JumpStart for intelligent document processing, fine-tuning, and evaluation of LLMs in the following resources:


About the Authors

Pooya Vahidi is a Senior Solutions Architect at AWS, passionate about computer science, artificial intelligence, and cloud computing. As an AI professional, he is an active member of the AWS AI/ML Area-of-Depth team. With a background spanning over two decades of expertise in leading the architecture and engineering of large-scale solutions, he helps customers on their transformative journeys through cloud and AI/ML technologies.

Dr. Romina Sharifpour is a Senior Machine Learning and Artificial Intelligence Solutions Architect at Amazon Web Services (AWS). She has spent over 10 years leading the design and implementation of innovative end-to-end solutions enabled by advancements in ML and AI. Romina’s areas of interest are natural language processing, large language models, and MLOps.

Read More

NVIDIA DGX SuperPOD to Power U.S. Government Generative AI

In support of President Biden’s executive order on AI, the U.S. government will use an NVIDIA DGX SuperPOD to produce generative AI advances in climate science, healthcare and cybersecurity.

The executive order, issued in October, is aimed at ensuring U.S. leadership in AI and managing its risks. MITRE, a nonprofit organization that operates federally funded research and development centers, is implementing a new NVIDIA DGX SuperPOD system that will provide researchers and developers access to massive computing leaps.

The DGX SuperPOD will support MITRE’s Federal AI Sandbox, a platform to improve experimentation with next-generation, AI-enabled applications across federal government agencies.

“The recent executive order on AI encourages federal agencies to reduce barriers for AI adoptions, but agencies often lack the computing environment necessary for experimentation and prototyping,” said Charles Clancy, senior vice president and chief technology officer at MITRE. “Our new Federal AI Sandbox will help level the playing field, making the high-quality compute power needed to train and test custom AI solutions available to any agency.”

The Federal AI Sandbox will deliver federal agencies the computing gains needed to train large language models and other generative AI tools to develop cutting-edge applications.

The NVIDIA DGX SuperPOD system powering the sandbox is capable of an exaflop of 8-bit AI compute, meaning it performs a quintillion math operations each second to train and deploy custom LLMs and other AI solutions at scale.

The supercomputing initiative comes as the White House recently unveiled plans, which include NVIDIA, for a $110 million partnership to help universities teach AI skills.

Read More

LoftQ: Reimagining LLM fine-tuning with smarter initialization

This research paper was presented at the 12th International Conference on Learning Representations (ICLR 2024), the premier conference dedicated to the advancement of deep learning.

Large language models (LLMs) use extensive datasets and advanced algorithms to generate nuanced, context-sensitive content. However, their development requires substantial computational resources. To address this, we developed LoftQ, an innovative technique that streamlines the fine-tuning process—which is used to adapt pre-trained language models to perform well in specialized applications, such as analyzing medical documents. During fine-tuning, the model undergoes additional training on a smaller, task-specific dataset. This results in improved performance, such as more accurate predictions, better understanding of domain-specific language, and more relevant responses in the context of the specialized area.

LoftQ’s strength lies in its ability to combine quantization and adaptive initialization during fine-tuning. Quantization reduces the precision of model parameters, lowering memory and computation needs. This not only accelerates processing but also reduces power consumption. Adaptive initialization closely aligns the model’s parameters to its optimal pre-trained state, preserving its capabilities while minimizing resource use. Our paper, “LoftQ: LoRA-Fine-Tuning-Aware Quantization for Large Language Models,” presented at ICLR 2024, details how this method can help make AI technologies more efficient and sustainable. 

How LoftQ works 

LoftQ builds on the principles of LoRA and QLoRA. LoRA is a method that greatly reduces the number of parameters needed for training, decreasing the memory requirements for fine-tuning. QLoRA is a fine-tuning approach that uses 4-bit quantized, frozen weights and low rank adapters, significantly reducing memory requirements while maintaining high performance. This is illustrated in Table 1, which shows the amount of memory needed for fine-tuning an LLM with 7 billion parameters as well as the memory requirements for LoRA and QLoRA. LoRA achieves a fourfold reduction in memory usage, and QLoRA further reduces it by twofold.

Table 1: This table shows the GPU memory usage for a 7-billion parameter LLM with the following configurations: full fine-tuning on the left, LoRA in the middle, and QLoRA on the right.

Unlike LoRA, QLoRA comes with a tradeoff, where some quality of the pretrained model is sacrificed due to the quantization of weights. LoftQ recognizes this and optimizes the initialization of quantization and low-rank adaptation matrices. That is, LoftQ seeks to identify a combination of a quantized matrix and a low rank matrix such that their sum closely approximates the original pretrained weight. This is done for every matrix that would be adapted in the model.

The LoftQ algorithm alternates between two primary steps. First it quantizes (simplifies) the weights, and then it finds the best low-rank factors that approximate the residual between the pretrained weight and its quantized counterpart. The process repeats for a few iterations. This method enables the fine-tuning process to start from a more effective initial state, which preserves accuracy while using less computational power and much simpler weights.
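
To make the alternating procedure concrete, the following is a minimal, illustrative sketch in PyTorch. It is not the reference implementation; a simple min-max uniform quantizer stands in for the paper's low-bit quantization scheme.

import torch

def uniform_quantize(W, num_bits):
    # Simple min-max uniform quantizer, standing in for the paper's low-bit scheme.
    levels = 2 ** num_bits - 1
    w_min, w_max = W.min(), W.max()
    scale = (w_max - w_min) / levels
    return torch.round((W - w_min) / scale) * scale + w_min

def loftq_init(W, num_bits=4, rank=16, num_iters=5):
    """Alternate between quantizing the residual and refitting low-rank factors
    so that Q + A @ B.T approximates the pretrained weight W."""
    A = torch.zeros(W.shape[0], rank)
    B = torch.zeros(W.shape[1], rank)
    for _ in range(num_iters):
        Q = uniform_quantize(W - A @ B.T, num_bits)               # step 1: quantize
        U, S, Vh = torch.linalg.svd(W - Q, full_matrices=False)   # step 2: low-rank fit
        A = U[:, :rank] * S[:rank]
        B = Vh[:rank, :].T
    return Q, A, B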

LoftQ requires a one-time setup to simplify and prepare these weights, allowing a fixed portion of the model’s parameters (e.g., 5 percent) to be adjusted. Once established, this configuration can be repeatedly applied as the model transitions between various tasks and settings.

Evaluating LoftQ 

Tests using various types of LLMs, including those with different combinations of encoding and decoding capabilities, like Llama-2, show that models initialized with LoftQ consistently achieve strong performance, often matching or surpassing those configured with QLoRA.

In practical terms, comparing the performance of LoftQ and QLoRA on different tasks using the Llama-2 model family yields distinct results, which are highlighted in Table 2. For the WikiText-2 dataset, which measures the model’s perplexity (lower is better), and the GSM8K dataset, which tests the model’s ability to solve basic math problems (higher is better), we demonstrate the effectiveness of varying degrees of weight simplification—averaging 3, 2.5, and 2.25 bits per weight. Our paper discusses the results in more detail. 

Table 2. This table compares LoftQ and QLoRA during the fine-tuning of two Llama-2 models on the Wikitext-2 and GSM8K datasets.

Implications and looking forward 

LoftQ promises to advance the field of AI by accelerating research and facilitating the creation of cutting-edge tools while supporting sustainable development. While initially focused on LLMs, LoftQ’s flexible design also supports fine-tuning in other types of models, such as those for vision and speech technologies. As our research progresses, we expect to make further enhancements that will boost performance on downstream tasks. We hope these improvements will lead to broader adoption across various AI applications. We’re excited about the breadth of this technology’s applicability and encourage the AI community to explore its benefits. LoftQ is available as open source through the Hugging Face PEFT library.

The post LoftQ: Reimagining LLM fine-tuning with smarter initialization appeared first on Microsoft Research.

Read More

Rephrasing the Web: A Recipe for Compute and Data-Efficient Language Modeling

This paper has been accepted at the Data Problems for Foundation Models workshop at ICLR 2024.
Large language models are trained on massive scrapes of the web, which are often unstructured, noisy, and poorly phrased. Current scaling laws show that learning from such data requires an abundance of both compute and data, which grows with the size of the model being trained. This is infeasible both because of the large compute costs and duration associated with pre-training, and the impending scarcity of high-quality data on the web. In this work, we propose Web Rephrase Augmented Pre-training…

Apple Machine Learning Research

A Mighty Meeting: Generative AI, Cybersecurity Connect at RSA

Cybersecurity experts at the RSA Conference this week will be on the hunt for ways to secure their operations in the era of generative AI.

They’ll find many of the latest tools use AI and accelerated computing. This intersection of security and AI is coming into focus with collaborations that companies like NVIDIA and its partners will describe at the event.

Data Science for a Data Problem

Machine learning is a great tool for cybersecurity because data is exploding.

“With more devices and users expanding the landscape to defend, cybersecurity has become a data problem; and AI is a data solution,” said David Reber, NVIDIA’s chief security officer.

Today, security analysts can be overwhelmed by a data tsunami. Generative AI can provide security copilots that act like filters, extracting from the firehose flow of information the context and anomalies that need a human expert’s attention.

Generative AI also lets security managers interrogate their data directly instead of setting rules, chasing alerts and tracking dashboards. In the age of AI, security experts will move from a command line to a conversational interface.

AI Puts Security Context on Steroids

This shift takes context-based security to a new level, according to a talk Reber delivered at GTC.

The potential is huge, but it requires work to unearth it. At GTC, Reber encouraged cybersecurity experts to begin working with AI, starting with low-risk use cases to identify and secure gaps.

He also provided suggestions for how to go about securing machine learning processes, saying security experts need to:

  • secure data supply chains,
  • develop tests for securing AI models and datasets across the development lifecycle,
  • use model cards, data cards and software bills of materials to provide AI transparency and reliability,
  • participate in community testing events such as security hackathons, and
  • update policies on how to respond to AI security events.

Foundations for AI Cybersecurity

To give users a leg up, NVIDIA provides NVIDIA Morpheus, a cybersecurity AI framework that filters and classifies large volumes of real-time data. Morpheus, part of the NVIDIA AI Enterprise software suite, lets developers build applications that can detect spear phishing, insider threats and more.

Users can employ Morpheus with NVIDIA NIM and NeMo Retriever, microservices from the NVIDIA API Catalog for rapidly deploying AI. The combination can unlock new use cases, such as reducing from days to seconds the time to find and resolve common software vulnerabilities and exposures, one of many NVIDIA AI workflows.

A new release of NVIDIA DOCA — the software framework for programming NVIDIA BlueField DPUs and NVIDIA ConnectX NICs — provides another foundation for AI security. It now sports updated encryption features for network and storage data.

An Expanding AI Security Ecosystem

At RSA, many companies will show products built on NVIDIA technologies that extend security for the generative AI era, including:

  • AIC will demonstrate Qrypt’s key generation for quantum secure encryption running on a BlueField DPU in an AIC-designed server.
  • Anjuna will discuss how the U.S. Navy is evaluating confidential computing on the Anjuna Seaglass platform with proprietary LLMs running on NVIDIA H100 Tensor Core GPUs.
  • Bloombase will show an updated version of its StoreSafe Storage Firewall powered by Morpheus and BlueField DPUs and new use cases for threat detection and fast, quantum-safe encryption of AI models and data.
  • Check Point Software will show its AI Cloud Protect security solution on BlueField DPUs, Quantum Force security gateways on ConnectX NICs, and Quantum Maestro software on NVIDIA Spectrum switches.
  • Cisco will feature Cisco Hypershield, an AI-native security architecture, to protect against both known and unknown attacks. It will also discuss its expanding partnership with NVIDIA to help customers harness the power of AI.
  • CrowdStrike will show its CrowdStrike Falcon Foundry and CrowdStrike Falcon platform that employs NVIDIA’s GPU-optimized AI software, including NIM microservices.
  • Deloitte will showcase CyberSphere, a cyber operations platform that uses Morpheus to speed detection of cyber threats.
  • Palo Alto Networks will describe its collaboration with NVIDIA on two use cases, a next-generation reference architecture for securing generative AI deployments with NIM and its VM-Series Virtual Next-Generation Firewall, with expanded intelligent traffic offload (ITO), supporting BlueField-3 DPUs.
  • Qrypt will demonstrate its quantum-secure key generation for securing AI workloads running on BlueField-3 DPUs using DOCA.
  • Sygnia will announce the use of BlueField DPUs and Morpheus in Velocity Edge, its new hardware-accelerated MXDR service for the energy and industrial sectors.

They are part of the NVIDIA ecosystem building a new layer of security for generative AI. That community includes more than 20 companies at RSA this week from NVIDIA Inception, a program for cutting-edge startups.

At RSA, Daniel Rohrer, vice president of software product security at NVIDIA, will be part of the keynote panel on AI safety.

In addition, Kevin Deierling, senior vice president of networking at NVIDIA, will share insights on security at the Cloudflare Executive Supper Club. And NVIDIA will participate in an event about women in cybersecurity.

To get started with AI-powered cybersecurity, try a workflow in NVIDIA LaunchPad.

Read More

NVIDIA and Alphabet’s Intrinsic Put Next-Gen Robotics Within Grasp

Intrinsic, a software and AI robotics company at Alphabet, has integrated NVIDIA AI and Isaac platform technologies to advance the complex field of autonomous robotic manipulation.

This week at the Automate trade show, in Chicago, Intrinsic is spotlighting leaps in robotic grasping and industrial scalability assisted by foundation models enabled by NVIDIA Isaac Manipulator, unlocking new value in industrial automation with AI.

NVIDIA unveiled Isaac Manipulator at GTC in March. Isaac Manipulator is a collection of foundation models and modular GPU-accelerated libraries that help industrial automation companies build scalable and repeatable workflows for dynamic manipulation tasks by accelerating AI model training and task reprogramming.

Foundation models are based on a transformer deep learning architecture that allows a neural network to learn by tracking relationships in data. They’re generally trained on huge datasets and can be used to process and understand sensor and robot information as magically as ChatGPT for text. This enables robot perception and decision-making like never before and provides zero-shot learning — the ability to perform tasks without prior examples.

NVIDIA’s collaboration with Intrinsic, a leading robotics software and AI company, demonstrates the potential for a universally applicable robotic-grasping skill to work across grippers, environments and objects.

“For the broader industry, our work with NVIDIA shows how foundation models can have a profound impact, including making today’s processing challenges easier to manage at scale, creating previously infeasible applications, reducing development costs, and increasing flexibility for end users,” said Wendy Tan White, CEO at Intrinsic, in a blog post announcing the collaboration with NVIDIA. (White will deliver a keynote address at Automate about what the rise of AI means for innovation and growth, on Thursday, May 9, at 7 a.m. PT.)

Developing Better Robot Grip With Isaac Manipulator

Grasping has been a long-sought-after robotics skill. So far, it has been time-consuming and expensive to program, and difficult to scale. As a result, many repetitive pick-and-place conditions haven’t been handled seamlessly by robots to date.

Simulation is changing that. Enlisting NVIDIA Isaac Sim on the NVIDIA Omniverse platform, Intrinsic generated synthetic data for vacuum grasping using computer-aided design models of sheet metal and suction grippers. This allowed Intrinsic to create a prototype for its customer Trumpf Machine Tools, a leading maker of industrial machine tools.

The prototype uses Intrinsic Flowstate, a developer environment for AI-based robotics solutions, for visualizing processes, associated perception and motion planning. With a workflow that includes Isaac Manipulator, one can generate grasp poses and CUDA-accelerated robot motions, which can first be evaluated in simulation with Isaac Sim — a cost-saving step — before deployment in the real world with the Intrinsic platform.

Under the collaboration, NVIDIA and Intrinsic plan to bring state-of-the-art dexterity and modular AI capabilities for robotic arms, with a robust collection of foundation models and GPU-accelerated libraries to accelerate a greater number of new robotics and automation tasks.

On Tuesday, May 7, at 11 a.m. CT, NVIDIA Senior Research Scientist Adithya Murali and Intrinsic Chief Science Officer Torsten Kroeger will demonstrate the companies’ work in the session “Automating Smart Pick-and-Place With Intrinsic Flowstate and NVIDIA Isaac Manipulator” in the Intrinsic booth 2808 at Automate. Join our speaking sessions at Automate.

Read More

Abstracts: May 6, 2024

Members of the research community at Microsoft work continuously to advance their respective fields. Abstracts brings its audience to the cutting edge with them through short, compelling conversations about new and noteworthy achievements.

In this episode, Senior Principal Researcher Michel Galley joins host Gretchen Huizinga to discuss “MathVista: Evaluating Mathematical Reasoning of Foundation Models in Visual Contexts,” which was accepted at the 2024 International Conference on Learning Representations (ICLR). MathVista, an open-source benchmark, combines new and existing data to measure how good models are at solving a variety of math problems that involve processing images as well as text, helping to gain insight into their reasoning capabilities.

Transcript

[MUSIC]

GRETCHEN HUIZINGA: Welcome to Abstracts, a Microsoft Research Podcast that puts the spotlight on world-class research in brief. I’m Dr. Gretchen Huizinga. In this series, members of the research community at Microsoft give us a quick snapshot—or a podcast abstract—of their new and noteworthy papers.

[MUSIC FADES]

My guest today is Dr. Michel Galley, a senior principal researcher at Microsoft Research. Dr. Galley is the coauthor of a paper called “MathVista: Evaluating Mathematical Reasoning of Foundation Models in Visual Contexts.” Michel, thanks for joining us on Abstracts today!


MICHEL GALLEY: Thank you for having me.

HUIZINGA: So I like to start with a distillation or sort of an elevator pitch of your research. Tell us in just a couple sentences what problem or issue your paper addresses and why we should care about it.

GALLEY: So this paper is about evaluating large foundation models. So it’s a very important part of researching large language models because it’s a good way to evaluate, kind of, the capabilities—what these models are good at and not good at. And a part of the focus of MathVista is to evaluate these large foundation models in a multimodal setup, so when the input to the model is actually not just text but also text and images. And then, an example of a task that such a model would perform is, like, the input is maybe a mathematical question, and then there’s some visual support to that question, let’s say, of an image of a graph, and then the model has to respond to something related to that. And why this is important … there has been a lot of work, of course, on large foundation model. Especially when it comes to reasoning tasks, like mathematical reasoning, a lot has focused more on written form.

HUIZINGA: Yeah …

GALLEY: So MathVista is one of the very first datasets that has input that is both images and text.

HUIZINGA: Yeah, yeah. Well, reading your paper, it seems like this is an area that hasn’t been studied systematically. In fact, you actually say that! And say that the field is largely unexplored. But quickly tell us what has been done in this field, and then tell us how your research addresses the proverbial gap in the literature.

GALLEY: Well, there has been a lot of work on vision and language for other problems, not just reasoning. Maybe let me just mention why reasoning is important. So one reason I think it’s very interesting to evaluate these large language models in terms of reasoning skill is that we evaluate their capabilities beyond just memorization. So as many of your listeners probably know, these large foundation models are trained on large amounts of public text data from various sources. So when you ask a question to a large foundation model, it could often be the case that it just memorizes things it has seen in the data.

HUIZINGA: Sure.

GALLEY: So what makes it interesting in terms of reasoning is that the answer oftentimes is not there in the data. So it needs to develop this ability to connect the dots between various pieces of information to come up with a new answer. So the focus of our paper is really on mathematical reasoning, but it goes also a bit beyond that because what is also represented in the data are science questions and so on.

HUIZINGA: Yeah …

GALLEY: So this reasoning part has largely focused, until MathVista, on text-only modalities.

HUIZINGA: Yeah …

GALLEY: So it’s one of the very first ones that combines text and images in terms of evaluating these large foundation models. So you ask about what was done before. So, yes, there has been a lot of work, text only, on reasoning, for example, mathematical questions that are just based on text. And there has been a different stream of work that was much more focused on vision. A lot of work has been on tasks such as visual question answering …

HUIZINGA: Yeah …

GALLEY: … where basically, you have an image and the task is to answer a question about this image. So, yes, we’re trying to fuse the two lines of research here.

HUIZINGA: Right …

GALLEY: And that’s one of the first works that does that.

HUIZINGA: Yeah. Well, let’s talk about your methodology for a minute. Tell us how you went about conducting this research, and what methods did you use?

GALLEY: Yes, sure. So that’s a bit different from a typical, kind of, machine learning paper because the focus of this work is really on benchmarking and on the dataset. So the methodology is more about how we collect the data and process it. There were two components to doing that. One was to look at existing data that already combines vision and text. And there are existing datasets that are actually already fairly big but that were not focused on reasoning. So we used those existing datasets and looked for instances in the data that actually include some mathematical or science reasoning. And so that part is leveraging existing datasets, but the important part is, like, we really wanted to carve out the pieces that were interesting in terms of reasoning. And we had different stages of processing the data to identify the subset that was reasoning-based. So one first step was basically to apply some automatic filter to determine whether or not a given example, let’s say something that is visual and text, actually involves some mathematical reasoning. We had different strategies. For example, if the answer is numerical, it’s likely that it might be something mathematically related. But that’s just the first stage. In the second stage, we actually had human annotators certify that the selected data is actually of high quality, so for each example we confirm, “Oh, this is mathematical,” and whether it’s mathematical or scientific, and so on. And that’s one part of the effort. The other part is that we realized, while we collected the data, that there are certain types of mathematical reasoning, or reasoning related to mathematics, that were not represented in the data. So we created three new datasets as part of MathVista. So when I say dataset, it’s more like, think of MathVista as an aggregate of different types of data, and we added three of them, three new types of data. One is what we call PaperQA, which is basically data collected from scientific papers on arXiv, with questions about the paper that include some visual component from the paper, typically a plot or a figure.
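The two-stage selection Galley describes, a cheap automatic pass followed by human verification, can be sketched roughly like this in Python. The field names, keyword list, and numeric-answer heuristic are illustrative assumptions, not the MathVista authors’ actual pipeline.

```python
# A minimal sketch of a two-stage filter for reasoning-focused examples.
# Field names, keywords, and the numeric-answer heuristic are assumptions.

def looks_like_math_reasoning(example: dict) -> bool:
    """Stage 1: a cheap automatic filter that flags candidate examples."""
    answer = str(example.get("answer", "")).strip()
    question = str(example.get("question", "")).lower()
    math_keywords = ("calculate", "how many", "sum", "area", "angle", "percent")
    numeric_answer = answer.replace(".", "", 1).replace("-", "", 1).isdigit()
    return numeric_answer or any(k in question for k in math_keywords)

def build_candidate_pool(examples: list) -> list:
    """Stage 2 input: candidates that human annotators then verify by hand."""
    return [ex for ex in examples if looks_like_math_reasoning(ex)]

# Hypothetical records to show the filter in action.
pool = build_candidate_pool([
    {"question": "What is the area of the shaded region?", "answer": "12.5"},
    {"question": "What color is the bus?", "answer": "red"},
])
print(len(pool))  # 1 -- only the math-related example survives the first stage
```

The first stage only shrinks the pool cheaply; in the approach described, human annotators make the final call on what counts as mathematical or scientific reasoning.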

HUIZINGA: Yeah …

GALLEY: And then we had IQTest, which is, I mean, only vaguely related to mathematics, but it also, kind of, tries to test more abstract thinking about input that is both text and visual. And the final one is FunctionQA, which is basically about algebraic reasoning, function plots, and so on.

HUIZINGA: OK …

GALLEY: The important part was actually to identify among vast amounts of data what is actually very interesting in terms of mathematical reasoning.

HUIZINGA: Yeah …

GALLEY: So that part, I think, was quite a big part of doing that work—finding existing data but also creating new data.

HUIZINGA: Yeah, yeah. Well, my favorite part of a research paper is where it says, “and what we found was … ,” so talk a little bit about your results. What did you find?

GALLEY: So we evaluated a wide variety of models, including GPT-4, Claude 2, GPT-4V, multimodal Bard, and LLaVA, and we categorized them into three categories. So one is text only. Basically, you take a model that is by default just text, and we give it the text part of the question and ask it to answer the question. Of course, that’s, kind of, a difficult task because oftentimes [LAUGHTER] we crucially build these questions so that you have to rely on the vision part. But that’s for, you know, scientific investigation, to know how well they can do, and so that’s one category of model. A different category is still text only but is given the text detected from the image. So on the image, we do OCR. So we convert those words from images to text. It’s kind of an extension of the text-based model, except that what was images is translated into text, and then the input to the model is words only, and that’s a different category of model. And the third one is a truly multimodal model. And what we found, not surprisingly, is that the one doing most poorly is the one that is text only. The second is text plus OCR. And then finally, the one that does best is the multimodal model like GPT-4V. But while the ordering between these three categories makes sense, it was a bit surprising that the gap between multimodal and text plus OCR was not bigger. Well, it’s big, but maybe not as big as we were expecting. So, for example, the best text-plus-OCR model achieved like 35 percent accuracy while GPT-4V was at 50 percent. So it’s a substantial gap but not huge.
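The three evaluation setups Galley contrasts differ only in what the model is allowed to see. A rough sketch follows; the helper names and prompt wording are hypothetical, not the paper’s exact protocol.

```python
# A rough sketch of the three evaluation setups compared above. Helper names
# and prompt wording are assumptions, not the paper's exact protocol.

def text_only_input(question: str) -> str:
    # Setup 1: the model sees only the textual question, no image information.
    return question

def text_plus_ocr_input(question: str, ocr_text: str) -> str:
    # Setup 2: the image is replaced by whatever text OCR recovered from it.
    return f"{question}\n\nText detected in the image:\n{ocr_text}"

def multimodal_input(question: str, image_path: str) -> dict:
    # Setup 3: a truly multimodal model receives the question and the raw image.
    return {"text": question, "image": image_path}

question = "What is the y-value of the plotted function at x = 2?"
print(text_only_input(question))
print(text_plus_ocr_input(question, "y = x^2 (axis labels: x, y)"))
print(multimodal_input(question, "figure_01.png"))  # hypothetical file name
```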

HUIZINGA: Right. Just to clarify, you’re saying OCR. What does that stand for?

GALLEY: [Optical] character recognition.

HUIZINGA: Gotcha.

GALLEY: So, basically, it’s the task of taking text in an image, sometimes typed but sometimes handwritten, and converting it into the actual text like you would have in a text file.
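As a concrete illustration, an OCR call in Python typically looks roughly like the following. pytesseract is just one common choice and is an assumption here; the conversation does not name the OCR tooling used for MathVista.

```python
# A minimal OCR sketch using pytesseract (an assumption; the paper's OCR tool
# is not named here). Requires the Tesseract binary alongside the Python packages.
from PIL import Image
import pytesseract

def extract_text(image_path: str) -> str:
    """Convert typed or handwritten text in an image into a plain string."""
    return pytesseract.image_to_string(Image.open(image_path))

print(extract_text("figure_01.png"))  # hypothetical image file
```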

HUIZINGA: Right. Michel, does any of this have to do with the difficulty of the math problems that you present these models with? I mean, it seems to me, similar to humans, that the easier the problem, the easier it would be for the machine. So at what level of math are we talking for these tests?

GALLEY: What’s nice about MathVista is there’s a continuum of different difficulties. So the spectrum is quite broad, going from elementary school to more advanced concepts such as calculus. So it’s quite broad. In the paper, we do have this, kind of, broken down by level. So the number I gave you, like 50 percent, is an aggregate over all the difficulties. But …

HUIZINGA: Gotcha.

GALLEY: But the goal there was really, kind of, to compare different models, but we do have a fair amount of analysis in the appendix. Actually, we have 100 pages of appendices with plenty of analysis and so on. So if people, I mean …

HUIZINGA: I saw that. I saw the length of the paper, and I’m going, what? [LAUGHS] That’s a LONG paper! Well, research in the lab is one thing, I always like to say, but understanding real-world impact is important, too. So where’s this work going to make the most difference, and who does it help most at this point?

GALLEY: Well, I think perhaps that’s the main point of this line of work on reasoning: when looking at these difficult problems that are mathematical, it’s a way to, kind of, abstract away maybe more complex capabilities, and while thinking just about mathematics might seem a bit narrow, I don’t think it really is. It’s more about seeing whether this model has the ability to do, kind of, multistep processing of your input and think maybe somewhat intelligently about a given problem. So we focus mostly on math. There is some science, but we would be very interested, especially in future work, to, kind of, go beyond that.

HUIZINGA: OK, well, let me press in a little bit there because … let’s say I’m a regular person using a GPT model. Is your work addressed more upstream from that, to the research community, to say, how do we get these models to be better so that downstream people like me can be more confident in the models?

GALLEY: Yes, I would say at the moment, I mean, this line of work is perhaps geared more toward the research community, but I think it could be a seed for researchers to think about applications that also require some kind of step-by-step reasoning, but perhaps without going beyond math.

HUIZINGA: Yeah. Michel, if there was one thing you wanted our listeners to take away from this research, kind of golden nugget, what would it be?

GALLEY: Well, I would say it’s the challenging part of these datasets. I think that’s what makes MathVista stand out compared to other datasets. By now, there are a few other vision-and-language datasets, and of course, many that are more text-based. And we’ve seen, for example, some recent papers showing that MathVista actually remains one of the most challenging ones. So I think it’s probably going to stay around for a while because of the difficulty it represents. And it’s an open-source dataset that everybody can use, and I very much encourage people to use it.

HUIZINGA: Is it on GitHub?

GALLEY: Yes, it’s on GitHub.
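For listeners who want to try the benchmark themselves, it can typically be loaded with the Hugging Face datasets library. The dataset identifier, split name, and field name below are assumptions based on the public release, so check the project’s GitHub page for the canonical instructions.

```python
# A minimal sketch of loading MathVista for local experimentation. The dataset
# identifier, split, and field name are assumptions; confirm them against the
# project's GitHub or Hugging Face page before relying on this.
from datasets import load_dataset

mathvista = load_dataset("AI4Math/MathVista", split="testmini")  # assumed identifier/split
print(mathvista[0]["question"])  # field names may differ in the actual release
```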

HUIZINGA: So what’s next on the research agenda for helping LLMs get better at math, Michel? What are the big challenges still ahead in the field? I mean, you’ve alluded to many of them already, sort of, but what’s next on your research agenda?

GALLEY: Well, I would say what we found so far is that these models are very good at processing the textual part of the problems they’re given, but the equivalent in images is actually harder somehow. So I think a lot more work needs to be done in terms of vision capabilities, in terms of reasoning over images, because the capabilities you see in text are actually quite advanced, whereas the equivalent in images doesn’t seem that good. I mean, a fair disclaimer: my background is more on the text side, [LAUGHTER] so some of my colleagues on the paper are more on the vision side, so if a listener runs into some of our coauthors at the conference, they might want to talk to those vision people because that’s less of my background. [LAUGHS]

HUIZINGA: Well, and if you think about Venn diagrams, you know, you’ve got people that are doing text, people that are doing vision, and then the people that are trying to do both to see how the worlds collide.

[MUSIC]

Well, Michel Galley, thanks for joining us today. And to our listeners, thanks for tuning in. If you want to read this paper, you can find a link at aka.ms/abstracts, or you can find it on arXiv. You can also read it on the website for the International Conference on Learning Representations, or ICLR. And if you happen to be at the ICLR conference this week, you can hear more about it there. See you next time on Abstracts!

[MUSIC FADES]



Knowledge Transfer from Vision Foundation Models for Efficient Training of Small Task-specific Models

Vision Foundation Models (VFMs) pretrained on massive datasets exhibit impressive performance on various downstream tasks, especially with limited labeled target data. However, due to their high inference compute cost, these models cannot be deployed for many real-world applications. Motivated by this, we ask the following important question, “How can we leverage the knowledge from a large VFM to train a small task-specific model for a new target task with limited labeled training data?”, and propose a simple task-oriented knowledge transfer approach as a highly effective solution to this…

Apple Machine Learning Research