Automating complex deep learning model training using Amazon SageMaker Debugger and AWS Step Functions

Amazon SageMaker Debugger can monitor ML model parameters, metrics, and computation resources as the model optimization is in progress. You can use it to identify issues during training, gain insights, and take actions like stopping the training or sending notifications through built-in or custom actions. Debugger is particularly useful in training challenging deep learning model architectures, which often require multiple rounds of manual tweaks to model architecture or training parameters to rectify training issues. You can also use AWS Step Functions, our powerful event-driven function orchestrator, to automate manual workflows with pre-planned steps in reaction to anticipated events.

In this post, we show how we can use Debugger with Step Functions to automate monitoring, training, and tweaking deep learning models with complex architectures and challenging training convergence characteristics. Designing deep neural networks often involves manual trials where the model is modified based on training convergence behavior to arrive at a baseline architecture. In these trials, new layers may be added or existing layers removed to stabilize unwanted behaviors like gradients becoming too large (exploding) or too small (vanishing), or different learning methods or parameters may be tried to speed up training or improve performance. This manual monitoring and adjusting is a time-consuming part of the model development workflow, exacerbated by the typically long training times of deep learning models.

Instead of manually inspecting the training trajectory, you can configure Debugger to monitor convergence, and the new Debugger built-in actions can, for example, stop training if any of the specified rules trigger. Furthermore, we can use Debugger as part of an iterative Step Functions workflow that modifies the model architecture and training strategy until we arrive at a successfully trained model. In such an architecture, we use Debugger to identify potential issues like misbehaving gradients or activation units, and Step Functions orchestrates modifying the model in response to events produced by Debugger.

Overview of the solution

A common challenge in training very deep convolutional neural networks is exploding or vanishing gradients, where gradients grow too large or too small during training, respectively. Debugger supports a number of useful built-in rules to monitor training issues like exploding gradients, dead activation units, or overfitting, and can respond to them through built-in or custom actions. Debugger also allows custom rules, although the built-in rules are quite comprehensive and give good insight into what to look for when training doesn’t yield the desired results.

We build this post’s example around the seminal 2016 paper “Deep Residual Networks with Exponential Linear Unit” by Shah et al., which investigates exponential linear unit (ELU) activation, instead of the combination of ReLU activation with batch normalization layers, for the challenging ResNet family of very deep residual network models. The paper explores several architectures, and in particular, the ELU-Conv-ELU-Conv architecture (Section 3.2.2 and Figure 3b in the paper) is reported to be among the more challenging constructs, suffering from exploding gradients. To stabilize training, the paper modifies the architecture by adding batch normalization before the addition layers.

For this post, we use Debugger to monitor the training process for exploding gradients, and use Debugger’s built-in stop-training and notification actions to automatically stop the training and notify us if issues occur. As the next step, we devise a Step Functions workflow that addresses training issues on the fly with pre-planned strategies we can try each time training fails during model development. Our workflow first attempts to stabilize training by trying different warmup parameters to stabilize the starting point of training, and if that fails, resorts to Shah et al.’s approach of adding batch normalization before the addition layers. You can use the workflow and model code as a template to add other strategies, for example, swapping the activation units or trying different gradient-descent optimizers like Adam or RMSprop.
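
For reference, the following is a minimal sketch (assuming the SageMaker Python SDK v2 and its Debugger built-in actions) of how the ExplodingTensor built-in rule can be attached to an estimator with stop-training and email actions. The entry point script, role ARN, instance settings, and email address are placeholders rather than the solution’s actual values:

from sagemaker.debugger import Rule, rule_configs
from sagemaker.tensorflow import TensorFlow

# Built-in actions: stop the job and email the team when the rule triggers
actions = rule_configs.ActionList(
    rule_configs.StopTraining(),
    rule_configs.Email('data-scientist@example.com')
)

estimator = TensorFlow(
    entry_point='train.py',                      # hypothetical training script
    role='<your SageMaker execution role ARN>',
    instance_count=1,
    instance_type='ml.m5.xlarge',
    framework_version='2.3.1',
    py_version='py37',
    rules=[Rule.sagemaker(rule_configs.exploding_tensor(), actions=actions)]
)
estimator.fit()  # assumes train.py downloads its own data, so no input channels are passed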

The workflow

The following diagram shows a schematic of the workflow.

The main components are state, model, train, and monitor, which we discuss in more detail in this section.
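
Before diving into each component, the following is a simplified sketch, in Amazon States Language, of how the train and monitor steps can be chained into a loop driven by the state payload’s next_action field. The Lambda function ARNs are placeholders, and the actual state machine created by the notebook may differ in its details:

{
  "Comment": "Simplified sketch of the train/monitor loop",
  "StartAt": "TrainStep",
  "States": {
    "TrainStep": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:us-west-2:111122223333:function:launch-training",
      "Next": "WaitStep"
    },
    "WaitStep": { "Type": "Wait", "Seconds": 60, "Next": "MonitorStep" },
    "MonitorStep": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:us-west-2:111122223333:function:monitor-training",
      "Next": "CheckNextAction"
    },
    "CheckNextAction": {
      "Type": "Choice",
      "Choices": [
        { "Variable": "$.state.next_action", "StringEquals": "launch_new", "Next": "TrainStep" },
        { "Variable": "$.state.next_action", "StringEquals": "monitor", "Next": "WaitStep" }
      ],
      "Default": "Done"
    },
    "Done": { "Type": "Succeed" }
  }
}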

State component

The state component is a JSON collection that keeps track of the history of models or training parameters tried, current training status, and what to try next when an issue is observed. Each step of the workflow receives this state payload, possibly modifies it, and passes it to the next step. See the following code:

{
    "state": {
        "history": {
            "num_warmup_adjustments": int,
            "num_batch_layer_adjustments": int,
            "num_retraining": int,
            "latest_job_name": str,
            "num_learning_rate_adjustments": int,
            "num_monitor_transitions": int
        },
        "next_action": "<launch_new|monitor|end>",
        "job_status": str,
        "run_spec": {
            "warmup_learning_rate": float,
            "learning_rate": float,
            "add_batch_norm": int,
            "bucket": str,
            "base_job_name": str,
            "instance_type": str,
            "region": str,
            "sm_role": str,
            "num_epochs": int,
            "debugger_save_interval": int
        }
    }
}

Model component

Faithful to Shah et al.’s paper, the model is a residual network of (configurable) depth 20, with hooks to insert additional layers, change activation units, or change the learning behavior via input configuration parameters. See the following code:

def generate_model(input_shape=(32, 32, 3), activation='elu',
    add_batch_norm=False, depth=20, num_classes=10, num_filters_layer0=16):
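
For illustration, the following is a minimal sketch (assuming a Keras functional-style implementation, not the repository’s exact code) of how the add_batch_norm flag might gate a batch normalization layer before the residual addition in the ELU-Conv-ELU-Conv block:

from tensorflow.keras import layers

def residual_block(x, num_filters, add_batch_norm=False, activation='elu'):
    shortcut = x
    y = layers.Activation(activation)(x)
    y = layers.Conv2D(num_filters, kernel_size=3, padding='same')(y)
    y = layers.Activation(activation)(y)
    y = layers.Conv2D(num_filters, kernel_size=3, padding='same')(y)
    if add_batch_norm:
        # Batch normalization before the addition is the stabilizing change
        # Shah et al. apply when gradients explode
        y = layers.BatchNormalization()(y)
    return layers.add([shortcut, y])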

Train component

The train step reads the model and training parameters that the monitor step specified to be tried next, and uses an AWS Lambda step to launch the training job using the SageMaker API. See the following code:

def lambda_handler(event, context):
    try:
        state = event['state']
        params = state['run_spec']
    except KeyError as e:
        ...
        ...
        ... 

    try:
        job_name = (params['base_job_name'] + '-' +
                    datetime.datetime.now().strftime('%Y-%b-%d-%Hh-%Mm-%S'))
        sm_client.create_training_job(
            TrainingJobName=job_name,
            RoleArn=params['sm_role'],
            AlgorithmSpecification={
                'TrainingImage': sm_tensorflow_image,
                'TrainingInputMode': 'File',
                'EnableSageMakerMetricsTimeSeries': True,
                'MetricDefinitions': [{'Name': 'loss', 'Regex': 'loss: ([0-9.]+)'}]
            },
        ...
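
Although elided above, the same create_training_job call also needs the Debugger hook and rule configuration attached so the monitor step has rule statuses to query. The following is a hedged sketch of those arguments; the rule evaluator image URI differs per Region and is shown only as a placeholder:

# Additional Debugger arguments so the ExplodingTensor built-in rule runs
# alongside training; 'params' is the run_spec dict from the state payload
debugger_args = {
    'DebugHookConfig': {
        'S3OutputPath': f"s3://{params['bucket']}/debugger-output",
        'CollectionConfigurations': [{
            'CollectionName': 'gradients',
            'CollectionParameters': {'save_interval': str(params['debugger_save_interval'])}
        }]
    },
    'DebugRuleConfigurations': [{
        'RuleConfigurationName': 'ExplodingTensor',
        'RuleEvaluatorImage': '<Region-specific Debugger rule evaluator image URI>',
        'RuleParameters': {'rule_to_invoke': 'ExplodingTensor'}
    }]
}
# These can be merged into the call above, for example:
# sm_client.create_training_job(..., **debugger_args)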

Monitor component

The monitor step uses another Lambda step that queries the status of the latest training job and plans the next steps of the workflow: Wait if there are no changes, or stop and relaunch with new parameters if training issues are found. See the following code:

if rule['RuleEvaluationStatus'] == "IssuesFound":
    logger.info(
        'Evaluation of rule configuration {} resulted in "IssuesFound". '
        'Attempting to stop training job {}'.format(
            rule.get("RuleConfigurationName"), job_name
        )
    )
    stop_job(job_name)
    logger.info('Planning a new launch')
    state = plan_launch_spec(state)
    logger.info(f'New training spec {json.dumps(state["run_spec"])}')
    state["rule_status"] = "ExplodingTensors"
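
The plan_launch_spec function encodes the pre-planned strategies. The following is a hypothetical sketch, consistent with the workflow described earlier but not necessarily identical to the repository’s code: halve the warmup learning rate a few times first, and if that budget is exhausted, add batch normalization before the addition layers:

MAX_WARMUP_ADJUSTMENTS = 5  # assumed retry budget, not taken from the repository

def plan_launch_spec(state):
    history = state['history']
    spec = state['run_spec']
    if history['num_warmup_adjustments'] < MAX_WARMUP_ADJUSTMENTS:
        # First strategy: halve the warmup learning rate to stabilize the start of training
        spec['warmup_learning_rate'] /= 2.0
        history['num_warmup_adjustments'] += 1
    else:
        # Fallback strategy: add batch normalization before the addition layers
        spec['add_batch_norm'] = 1
        history['num_batch_layer_adjustments'] += 1
    history['num_retraining'] += 1
    state['next_action'] = 'launch_new'
    return state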

The monitor step is also responsible for publishing updates about the status of the workflow to an Amazon Simple Notification Service (Amazon SNS) topic:

if state["next_action"] == "launch_new":
    sns.publish(TopicArn=topic_arn, Message=f'Retraining.\n'
                                            f'State: {json.dumps(state)}')
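
The SNS topic and email subscription can be created from the notebook with a few boto3 calls along these lines; the topic name and email address are placeholders:

import boto3

sns = boto3.client('sns')
# Create (or return the existing) topic that the monitor step publishes to
topic_arn = sns.create_topic(Name='debugger-workflow-notifications')['TopicArn']
# Subscribe an email address; the recipient must confirm the subscription
sns.subscribe(TopicArn=topic_arn, Protocol='email', Endpoint='you@example.com')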

Prerequisites

To follow this walkthrough, you only need an AWS account and basic familiarity with SageMaker notebooks.

Solution code

The entire code for this solution can be found in the following GitHub repository. This notebook serves as the entry point to the repository and includes all the necessary code to deploy and run the workflow. Use this AWS CloudFormation stack to create a SageMaker notebook linked to the repository, together with the required AWS Identity and Access Management (IAM) roles to run the notebook. Besides the notebook and the IAM roles, other resources like the Step Functions workflow are created from inside the notebook itself.

In summary, to run the workflow, complete the following steps:

  1. Launch this CloudFormation stack. This stack creates a SageMaker notebook with the necessary IAM roles and clones the solution’s repository.
  2. Follow the steps in the notebook to create the resources and step through the workflow.

Creating the required resources manually without using the above CloudFormation stack

To manually create and run our workflow through a SageMaker notebook, we need to be able to create and run Step Functions state machines, and to create Lambda functions and SNS topics. The Step Functions workflow also needs an IAM policy that allows it to invoke Lambda functions, and we define a role that allows our Lambda functions to access SageMaker. If you do not have permission to use the CloudFormation stack, you can create these roles on the IAM console.

The IAM policy for our notebook can be found in the solution’s repository here. Create an IAM role named sagemaker-debugger-notebook-execution and attach this policy to it. Our Lambda functions need permissions to create or stop training jobs and check their status. Create an IAM role for Lambda, name it lambda-sagemaker-train, and attach the policy provided here to it. We also need to add sagemaker.amazonaws.com as a trusted principal, in addition to lambda.amazonaws.com, for this role.
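
For example, a trust policy along the following lines allows both Lambda and SageMaker to assume the lambda-sagemaker-train role:

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Principal": {
        "Service": ["lambda.amazonaws.com", "sagemaker.amazonaws.com"]
      },
      "Action": "sts:AssumeRole"
    }
  ]
}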

Finally, the Step Functions workflow only requires access to invoke Lambda functions. Create an IAM role for the workflow, name it step-function-basic-role, and attach the AWS managed policy AWSLambdaRole.

The following screenshot shows the policy on the IAM console.

Next, launch a SageMaker notebook. Use the SageMaker console to create a SageMaker notebook instance, keeping the default settings except for those we specify in this post. For the IAM role, use the sagemaker-debugger-notebook-execution role we created earlier. This role allows our notebook to create the resources we need, run our workflow, and clean up the resources at the end. You can link the project’s GitHub repository to the notebook, or alternatively, clone the repository into the /home/ec2-user/SageMaker folder using a terminal inside the notebook.

Final results

Step through the notebook. At the end, you get a link to the Step Functions workflow.

Follow the link to navigate to the Step Functions console and view the workflow.

The following diagram shows the workflow’s state machine.

As the workflow runs through its steps, it sends SNS notifications with the latest training parameters. When the workflow is complete, we receive a final notification that includes the final training parameters and the status of the training job. The output of the workflow shows the final state of the state payload, where we can see that the workflow completed seven retraining iterations and settled on lowering the warmup learning rate to 0.003125 and adding a batch normalization layer to the model (“add_batch_norm”: 1). See the following code:

{
  "state": {
    "history": {
      "num_warmup_adjustments": 5,
      "num_batch_layer_adjustments": 1,
      "num_retraining": 7,
      "latest_job_name": "complex-resnet-model-2021-Jan-27-06h-45m-19",
      "num_learning_rate_adjustments": 0,
      "num_monitor_transitions": 16
    },
    "next_action": "end",
    "job_status": "Completed",
    "run_spec": {
      "sm_role": "arn:aws:iam::xxxxxxx:role/lambda-sagemaker-train",
      "bucket": "xxxxxxx-sagemaker-debugger-model-automation",
      "add_batch_norm": 1,
      "warmup_learning_rate": 0.003125,
      "base_job_name": "complex-resnet-model",
      "region": "us-west-2",
      "learning_rate": 0.1,
      "instance_type": "ml.m5.xlarge",
      "num_epochs": 5,
      "debugger_save_interval": 100
    },
    "rule_status": "InProgress"
  }
}

Cleaning up

Follow the steps in the notebook under the Clean Up section to delete the resources created. The notebook’s final step deletes the CloudFormation stack, which also deletes the notebook itself. Alternatively, you can delete the SageMaker notebook on the SageMaker console.

Conclusion

Debugger provides a comprehensive set of tools to develop and train challenging deep learning models. Debugger can monitor the training process for hardware resource usage and training problems like dead activation units, misbehaving gradients, or stalling performance, and through its built-in and custom actions, take automatic actions like stopping the training job or sending notifications. Furthermore, you can easily devise Step Functions workflows around Debugger events to change the model architecture, try different training strategies, or tweak optimizer parameters and algorithms, while tracking the history of recipes tried and keeping data scientists in full control through detailed notifications. The combination of Debugger and Step Functions significantly reduces experimentation turnaround time and saves on development and infrastructure costs.


About the Authors

Peyman Razaghi is a data scientist at AWS. He holds a PhD in information theory from the University of Toronto and was a post-doctoral research scientist at the University of Southern California (USC), Los Angeles. Before joining AWS, Peyman was a staff systems engineer at Qualcomm contributing to a number of notable international telecommunication standards. He has authored several peer-reviewed scientific research articles in statistics and systems engineering, and enjoys parenting and road cycling outside work.

 

Ross Claytor is a Sr Data Scientist on the ProServe Intelligence team at AWS. He works on the application of machine learning and orchestration to real world problems across industries including media and entertainment, life sciences, and financial services.
