Accelerate deep learning model training up to 35% with Amazon SageMaker smart sifting

In today’s rapidly evolving landscape of artificial intelligence, deep learning models have found themselves at the forefront of innovation, with applications spanning computer vision (CV), natural language processing (NLP), and recommendation systems. However, the increasing cost associated with training and fine-tuning these models poses a challenge for enterprises. This cost is primarily driven by the sheer volume of data used in training deep learning models. Today, large models are often trained on terabytes of data and can take weeks to train, even with powerful GPU or AWS Trainium-based hardware. Typically, customers rely on techniques and optimizations that improve the efficiency of a model’s training loop, such as optimized kernels or layers, mixed precision training, or features such as the Amazon SageMaker distributed training libraries. However, there is less focus today on the efficiency of the training data itself. Not all data contributes equally to the learning process during model training: a significant proportion of the computational resources may be spent on processing simple examples that don’t contribute substantially to the model’s overall accuracy.

Customers have traditionally relied on preprocessing techniques such as upsampling or downsampling and deduplication to refine and improve the information quality of their data. These techniques can help, but are often time consuming, require specialized data science experience, and can sometimes be more art than science. Customers often also rely on curated datasets, such as RefinedWeb, to improve the performance of their models; however, these datasets aren’t always fully open source and are often more general purpose and not related to your specific use case.

How else can you overcome this inefficiency related to low-information data samples during model training?

We’re excited to announce a public preview of smart sifting, a new capability of SageMaker that can reduce the cost of training deep learning models by up to 35%. Smart sifting is a new data efficiency technique that actively analyzes your data samples during training and filters out the samples that are less informative to the model. By training on a smaller subset of data with only the samples that contribute the most to model convergence, total training time and cost decrease with minimal or no impact on accuracy. Additionally, because the feature operates online during model training, smart sifting doesn’t require changes to your upstream data or downstream training pipeline.

In this post, we discuss the following topics:

  • The new smart sifting capability in SageMaker and how it works
  • How to use smart sifting with PyTorch training workloads

You can also check out our documentation and sample notebooks for additional resources on how to get started with smart sifting.

How SageMaker smart sifting works

We begin this post with an overview of how the smart sifting capability can accelerate your model training on SageMaker.

Smart sifting’s task is to sift through your training data during the training process and only feed the more informative samples to the model. During a typical training with PyTorch, data is iteratively sent in batches to the training loop and to accelerator devices (for example, GPUs or Trainium chips) by the PyTorch DataLoader. Smart sifting is implemented at this data loading stage and therefore is independent of any upstream data preprocessing in your training pipeline.

Smart sifting uses your model and a user-specified loss function to do an evaluative forward pass of each data sample as it’s loaded. Samples that are high-loss will materially impact model training and therefore are used in training; data samples that are relatively low-loss are set aside and excluded from training.

A key input to smart sifting is the proportion of data to exclude, which you control with beta_value. For example, beta_value=0.5 corresponds to excluding roughly 33% of the data: samples in approximately the bottom third of loss in each batch are excluded from training. When enough high-loss samples have been identified to complete a batch, the data is sent through the full training loop and the model learns and trains normally. You don’t need to make any changes to your training loop when smart sifting is enabled.

The following diagram illustrates this workflow.

By including only a subset of your training data, smart sifting reduces the time and computation needed to train the model. In our tests, we achieved reductions of up to nearly 40% in total training time and cost. Because the excluded samples were relatively low-loss for the model, smart sifting can have minimal or no impact on model accuracy. In the following table, we include a set of experimental results demonstrating the performance improvement possible with SageMaker smart sifting.

In the table, the % Accepted column indicates the proportion of data that is included and used in the training loop. Decreasing this tunable parameter (that is, sifting out more data) decreases the cost, as demonstrated in the IMR Savings % column, but it can also affect the accuracy. The appropriate setting for % Accepted is a function of your dataset and model; you should experiment with and tune this parameter to achieve the best balance between reduced cost and impact to accuracy.
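To build intuition for how beta_value relates to the proportion of data excluded, the following sketch assumes a selection rule in which each sample is kept with probability u ** beta, where u is the percentile of the sample's loss within recent loss history. This rule is our illustrative assumption, not code from the smart_sifting library, but it reproduces the example given earlier (beta_value=0.5 excludes roughly 33% of the data). The function names are hypothetical.

```python
import random


def expected_excluded_fraction(beta: float) -> float:
    """Closed form under the assumed rule P(keep) = u ** beta, u ~ Uniform(0, 1).

    E[u ** beta] = 1 / (1 + beta), so the expected excluded fraction is beta / (1 + beta).
    """
    return beta / (1.0 + beta)


def simulated_excluded_fraction(beta: float, n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo check of the same assumed rule."""
    rng = random.Random(seed)
    kept = 0
    for _ in range(n):
        # Where this sample's loss ranks within recent loss history (0 = lowest, 1 = highest).
        percentile = rng.random()
        # Higher-loss samples (larger percentile) are more likely to be kept.
        if rng.random() < percentile ** beta:
            kept += 1
    return 1.0 - kept / n


for beta in (0.5, 1.0, 3.0):
    # Expected column: 0.333, 0.500, 0.750
    print(f"beta_value={beta}: expected {expected_excluded_fraction(beta):.3f}, "
          f"simulated {simulated_excluded_fraction(beta):.3f}")
```

Under this assumption, raising beta_value excludes a larger share of each batch, so it lowers the % Accepted value discussed above.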

Solution overview

In the following sections, we walk through a practical example of enabling smart sifting with a PyTorch training job on SageMaker. If you want to get started quickly, you can jump to the PyTorch or PyTorch Lightning examples.

Prerequisites

We assume that you already know how to train a PyTorch or PyTorch Lightning model with the SageMaker Python SDK, using the Estimator class and SageMaker Deep Learning Containers for training. If not, refer to Using the SageMaker Python SDK before continuing.

Get started with SageMaker smart sifting

In a typical PyTorch training job, you initialize the PyTorch training DataLoader with your dataset and other required parameters, and it provides input batches as training progresses. To enable smart sifting of your training data, you use a new class: smart_sifting.dataloader.sift_dataloader.SiftingDataloader. This class wraps your existing PyTorch DataLoader, and the training process uses SiftingDataloader instead to get input batches. The SiftingDataloader gets the input batch from your original PyTorch DataLoader, evaluates the importance of the samples in the batch, and constructs a sifted batch of high-loss samples, which is then passed to the training step. The wrapper looks like the following code:

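from torch.utils.data import DataLoader
from smart_sifting.dataloader.sift_dataloader import SiftingDataloader

# Wrap the original PyTorch DataLoader; training then iterates over train_dataloader.
train_dataloader = SiftingDataloader(
    sift_config=sift_config,  # sifting parameters, described below
    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),
    loss_impl=BertLoss(),  # must return one loss value per sample
    model=self.model  # model used for the evaluative forward pass
)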
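The following is a conceptual sketch of what a sifting wrapper does: score every sample in an incoming batch, drop the low-loss ones, and emit a batch only once enough high-loss samples have accumulated. It's not the smart_sifting implementation. For clarity it uses a deterministic per-batch cutoff (exclude_fraction) instead of the probabilistic, loss-history-based rule that RelativeProbabilisticSiftConfig configures, and the function name sift_batches is hypothetical.

```python
from typing import Callable, Iterable, Iterator, List, Sequence, TypeVar

Sample = TypeVar("Sample")


def sift_batches(
    batches: Iterable[Sequence[Sample]],
    per_sample_loss: Callable[[Sequence[Sample]], List[float]],
    exclude_fraction: float,
    batch_size: int,
) -> Iterator[List[Sample]]:
    """Hypothetical sifting wrapper around any iterable of batches.

    For each incoming batch, drop the lowest-loss exclude_fraction of samples
    and re-pack the survivors into full batches of batch_size.
    """
    buffer: List[Sample] = []
    for batch in batches:
        losses = per_sample_loss(batch)  # evaluative pass: one loss per sample
        n_exclude = int(len(batch) * exclude_fraction)
        # Rank sample indices from highest to lowest loss and keep the top ones.
        ranked = sorted(range(len(batch)), key=lambda i: losses[i], reverse=True)
        kept = sorted(ranked[: len(batch) - n_exclude])  # restore original order
        buffer.extend(batch[i] for i in kept)
        # Emit a batch only once enough high-loss samples have accumulated.
        while len(buffer) >= batch_size:
            yield buffer[:batch_size]
            buffer = buffer[batch_size:]
    if buffer:
        yield buffer  # leftover samples at the end of the epoch


# Toy usage: each sample's "loss" is simply its own value.
toy_batches = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(list(sift_batches(toy_batches, lambda b: [float(x) for x in b], 0.34, 3)))
# [[1, 2, 4], [5, 7, 8]]
```

Because the wrapper only consumes and yields batches, the training loop that iterates over it doesn't need to change, which mirrors how SiftingDataloader slots in where the original DataLoader was used.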

The SiftingDataloader requires some additional parameters to analyze your training data, which you specify through the sift_config parameter. First, create a smart_sifting.sift_config.sift_configs.RelativeProbabilisticSiftConfig object. This object holds the required, configurable beta_value and loss_history_length parameters, which respectively define the proportion of samples to exclude and the window of recent samples to include when evaluating relative loss. Note that because smart sifting uses your model to determine the importance of each sample, sifting with a model whose weights are still completely random can have negative effects: its loss values don’t yet reflect how informative each sample is. Instead, you can use loss_based_sift_config and a sift_delay to postpone sifting until the model’s weights have been updated beyond their random initial values. (For more details, refer to Apply smart sifting to your training script.) In the following code, we define sift_config, specify beta_value and loss_history_length, and delay the start of sifting using loss_based_sift_config:

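from smart_sifting.sift_config.sift_configs import RelativeProbabilisticSiftConfig, LossConfig, SiftingBaseConfig

sift_config = RelativeProbabilisticSiftConfig(
    beta_value=3,  # higher values exclude a larger proportion of samples
    loss_history_length=500,  # window of recent losses used to judge relative loss
    loss_based_sift_config=LossConfig(
        # Postpone sifting while the model's weights are still near random
        sift_config=SiftingBaseConfig(sift_delay=10)
    )
)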
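To show how the three knobs in sift_config interact, here is a toy, stateful sifter. It's a hypothetical stand-in, not the library's code: ToyRelativeSifter and keep_mask are illustrative names, it assumes the keep probability is the loss percentile raised to beta_value, and it counts sift_delay in batches, which is our assumption about the unit. During the delay every sample is kept; afterward, each loss is ranked against a rolling window of the most recent loss_history_length losses.

```python
import bisect
import random
from collections import deque
from typing import List


class ToyRelativeSifter:
    """Hypothetical stand-in for the sift_config knobs (not the smart_sifting code).

    beta_value          -> exponent applied to a sample's loss percentile
    loss_history_length -> size of the rolling window of recent losses
    sift_delay          -> number of initial batches passed through unsifted (assumed unit)
    """

    def __init__(self, beta_value: float, loss_history_length: int, sift_delay: int, seed: int = 0):
        self.beta_value = beta_value
        self.history = deque(maxlen=loss_history_length)
        self.sift_delay = sift_delay
        self.batches_seen = 0
        self.rng = random.Random(seed)

    def keep_mask(self, losses: List[float]) -> List[bool]:
        self.batches_seen += 1
        if self.batches_seen <= self.sift_delay or not self.history:
            # Warm-up: losses from near-random weights aren't trusted, so keep everything.
            mask = [True] * len(losses)
        else:
            ordered = sorted(self.history)
            mask = []
            for loss in losses:
                # Fraction of recent losses that are <= this sample's loss.
                percentile = bisect.bisect_right(ordered, loss) / len(ordered)
                mask.append(self.rng.random() < percentile ** self.beta_value)
        # Record the new losses; deque(maxlen=...) drops the oldest automatically.
        self.history.extend(losses)
        return mask


sifter = ToyRelativeSifter(beta_value=3, loss_history_length=500, sift_delay=10)
data_rng = random.Random(1)
masks = [sifter.keep_mask([data_rng.expovariate(1.0) for _ in range(32)]) for _ in range(200)]
after_delay = [kept for mask in masks[10:] for kept in mask]
print("all kept during delay:", all(all(mask) for mask in masks[:10]))
print("kept after delay:", round(sum(after_delay) / len(after_delay), 2))
```

A longer loss_history_length makes the percentile estimate steadier but slower to track a model whose losses are falling as it trains.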

Next, you must also include a loss_impl parameter in the SiftingDataloader object. Smart sifting works at the level of individual samples, so it needs a loss calculation method to determine the importance of each sample. You must implement a sifting loss method that returns an n x 1 tensor holding the loss values of n samples. Typically, you specify the same loss method your model uses during training, computed per sample rather than averaged over the batch. Finally, pass your model to the SiftingDataloader object; it’s used to evaluate samples before they’re included in training. See the following code:

from typing import Any

import torch
from smart_sifting.data_model.data_model_interface import SiftingBatch
from smart_sifting.loss.abstract_sift_loss_module import Loss
from smart_sifting.sift_config.sift_configs import RelativeProbabilisticSiftConfig, LossConfig, SiftingBaseConfig

## Defining Sift loss
class SiftBertLoss(Loss):
    # reduction='none' makes CrossEntropyLoss return one loss
    # per sample instead of a single value for the whole batch.
    def __init__(self):
        self.celoss = torch.nn.CrossEntropyLoss(reduction='none')

    def loss(
            self,
            model: torch.nn.Module,
            transformed_batch: SiftingBatch,
            original_batch: Any = None,
    ) -> torch.Tensor:
        # Move the original batch to the same device as the model
        device = next(model.parameters()).device
        batch = [t.to(device) for t in original_batch]

        # compute per-sample loss; batch[2] holds the labels in this example
        outputs = model(batch)
        return self.celoss(outputs.logits, batch[2])

....
....

train_dataloader = SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),
    loss_impl=SiftBertLoss(),
    model=self.model
)
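The reason SiftBertLoss sets reduction='none' is that sifting has to rank samples, and the default reduction='mean' collapses a batch into a single number. The following pure-Python helper (a hypothetical name, for illustration only, not a replacement for the torch loss) computes what CrossEntropyLoss with reduction='none' returns for class-index targets: the negative log-softmax probability of each sample's true label.

```python
import math
from typing import List, Sequence


def per_sample_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> List[float]:
    """Pure-Python analogue of CrossEntropyLoss(reduction='none'): one loss per sample."""
    losses = []
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])  # -log softmax(row)[label]
    return losses


logits = [[2.0, 0.0], [0.0, 2.0]]
labels = [0, 0]
per_sample = per_sample_cross_entropy(logits, labels)
batch_mean = sum(per_sample) / len(per_sample)  # what reduction='mean' would return
print(per_sample)  # approximately [0.127, 2.127]
print(batch_mean)  # approximately 1.127
```

The second sample is confidently misclassified, so its loss is far higher and it's the one a sifter would keep; the batch mean alone gives no way to tell the two samples apart.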

The following code shows a complete example of enabling smart sifting with an existing BERT training job:

from typing import Any

import torch
from torch.utils.data import DataLoader
from smart_sifting.data_model.data_model_interface import SiftingBatch
from smart_sifting.dataloader.sift_dataloader import SiftingDataloader
from smart_sifting.loss.abstract_sift_loss_module import Loss
from smart_sifting.sift_config.sift_configs import RelativeProbabilisticSiftConfig, LossConfig, SiftingBaseConfig
...
...
...

## Defining Sift loss
class SiftBertLoss(Loss):
    # reduction='none' makes CrossEntropyLoss return one loss
    # per sample instead of a single value for the whole batch.
    def __init__(self):
        self.celoss = torch.nn.CrossEntropyLoss(reduction='none')

    def loss(
            self,
            model: torch.nn.Module,
            transformed_batch: SiftingBatch,
            original_batch: Any = None,
    ) -> torch.Tensor:
        # Move the original batch to the same device as the model
        device = next(model.parameters()).device
        batch = [t.to(device) for t in original_batch]

        # compute per-sample loss; batch[2] holds the labels in this example
        outputs = model(batch)
        return self.celoss(outputs.logits, batch[2])

....
....
....

sift_config = RelativeProbabilisticSiftConfig(
    beta_value=3,
    loss_history_length=500,
    loss_based_sift_config=LossConfig(
        sift_config=SiftingBaseConfig(sift_delay=10)
    )
)

train_dataloader = SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),
    loss_impl=SiftBertLoss(),
    model=self.model
)

......

# use train_dataloader in the rest of the training logic.

Conclusion

In this post, we explored the public preview of smart sifting, a new capability of SageMaker that can reduce deep learning model training costs by up to 35%. Smart sifting improves data efficiency by filtering out less informative data samples during training. By including only the data that contributes most to model convergence, you can significantly reduce training time and expense with minimal or no impact on accuracy. What’s more, it integrates into your existing processes without requiring changes to your data or training pipeline.

To dive deeper into SageMaker smart sifting, explore how it works, and implement it with PyTorch training workloads, check out our documentation and sample notebooks and get started with this new capability.


About the authors

Robert Van Dusen is a Senior Product Manager with Amazon SageMaker. He leads frameworks, compilers, and optimization techniques for deep learning training.

K Lokesh Kumar Reddy is a Senior engineer in the Amazon Applied AI team. He is focused on efficient ML training techniques and building tools to improve conversational AI systems. In his spare time he enjoys seeking out new cultures, new experiences, and staying up to date with the latest technology trends.

Abhishek Dan is a senior Dev Manager in the Amazon Applied AI team and works on machine learning and conversational AI systems. He is passionate about AI technologies and works at the intersection of science and engineering to advance the capabilities of AI systems and create more intuitive and seamless human-computer interactions. He is currently building applications on large language models to drive efficiency and CX improvements for Amazon.
