Extend Amazon SageMaker Pipelines to include custom steps using callback steps

Launched at AWS re:Invent 2020, Amazon SageMaker Pipelines is the first purpose-built, easy-to-use continuous integration and continuous delivery (CI/CD) service for machine learning (ML). With Pipelines, you can create, automate, and manage end-to-end ML workflows at scale.

You can extend your pipelines to include steps for tasks performed outside of Amazon SageMaker by using custom callback steps. This feature lets you include tasks that run on other AWS services, on third-party services, or outside AWS. Before the launch of this feature, steps within a pipeline were limited to the supported native SageMaker steps. With the new CallbackStep, the pipeline generates a token and adds a message to an Amazon Simple Queue Service (Amazon SQS) queue. The message on the SQS queue triggers a task outside of the currently supported native steps. When that task is complete, you call the SendPipelineExecutionStepSuccess API (or SendPipelineExecutionStepFailure if the task failed) with the generated token to signal that the callback step and its corresponding tasks are finished and the pipeline run can continue.
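As a rough illustration of this flow, the following sketch parses a message of the kind a consumer would read from the queue and pulls out the token. The `token` and `arguments` field names match the handler shown later in this post; the helper name `parse_callback_message` and all sample values are made up for illustration.

```python
import json

def parse_callback_message(body):
    """Return (token, arguments) from the JSON body of a callback step message."""
    payload = json.loads(body)
    return payload["token"], payload.get("arguments", {})

# Sample body with placeholder values (not real pipeline output).
sample_body = json.dumps({
    "token": "example-token",
    "arguments": {
        "input_location": "s3://example-bucket/raw/",
        "output_location": "s3://example-bucket/prepared/",
    },
})

token, arguments = parse_callback_message(sample_body)
print(token, arguments["input_location"])
```

Whatever process consumes the message has to keep the token until the external task finishes, because the token is what ties the success or failure call back to the waiting step.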

In this post, we demonstrate how to use CallbackStep to perform data preprocessing using AWS Glue. We use an Apache Spark job to prepare NYC taxi data for ML training. The raw data has one row per taxi trip, and shows information like the trip duration, number of passengers, and trip cost. To train an anomaly detection model, we want to transform the raw data into a count of the number of passengers that took taxi rides over 30-minute intervals.
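To make the target shape of the data concrete, here is a small plain-Python sketch of the same aggregation on a handful of made-up rows. It is not the Spark job used in this post, only an approximation of its result; the function name `passengers_per_window` is hypothetical, and the rounding approach assumes the window length divides 60 evenly.

```python
from collections import defaultdict
from datetime import datetime

def passengers_per_window(trips, window_minutes=30):
    """Sum passenger counts into fixed windows keyed by window start time."""
    totals = defaultdict(int)
    for pickup_time, passenger_count in trips:
        ts = datetime.strptime(pickup_time, "%Y-%m-%d %H:%M:%S")
        # Round the pickup time down to the start of its window.
        floored_minute = ts.minute - ts.minute % window_minutes
        window_start = ts.replace(minute=floored_minute, second=0, microsecond=0)
        totals[window_start] += passenger_count
    return dict(sorted(totals.items()))

# Made-up rows: (pickup timestamp, passenger count).
sample_trips = [
    ("2019-01-01 00:05:10", 1),
    ("2019-01-01 00:29:59", 2),
    ("2019-01-01 00:30:00", 3),
    ("2019-01-01 01:10:00", 1),
]
counts = passengers_per_window(sample_trips)
for start, total in counts.items():
    print(start, total)
```

Each output row is one 30-minute interval with the total number of passengers picked up in it, which is the time series the anomaly detection model trains on.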

Although we could run this specific Spark job in SageMaker Processing, we use AWS Glue for this post. In some cases, we may need capabilities that Amazon EMR or AWS Glue offer, like support for Hive queries or integration with the AWS Glue metadata catalog, so we demonstrate how to invoke AWS Glue from the pipeline.

Solution overview

The pipeline step that launches the AWS Glue job sends a message to an SQS queue. The message contains the callback token we need to send success or failure information back to the pipeline. When we report success with this token, the pipeline proceeds to the next step. To handle the message, we need a handler that can launch the AWS Glue job and reliably check the job status until the job completes. We have to keep in mind that a Spark job can easily take longer than 15 minutes (the maximum duration of a single AWS Lambda function invocation), and that the Spark job itself could fail for a number of reasons. That last point is worth emphasizing: in most Apache Spark runtimes, the job code itself runs in transient containers under the control of a coordinator like Apache YARN. We can’t add custom code to YARN, so we need something outside the job to check for completion.
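The core of such a handler is a polling loop that runs outside the job. The following is a minimal sketch, assuming the status check is passed in as a function so the loop can be exercised without AWS access. The in-progress states are the ones used by the task code later in this post; `wait_for_job`, `max_polls`, and the `POLL_LIMIT` return value are hypothetical names introduced here.

```python
import time

IN_PROGRESS = {"STARTING", "RUNNING", "STOPPING"}

def wait_for_job(get_status, poll_interval=60, max_polls=None, sleep=time.sleep):
    """Poll get_status() until the job leaves the in-progress states.

    Returns the final state string, or "POLL_LIMIT" if max_polls is reached first.
    """
    polls = 0
    status = get_status()
    while status in IN_PROGRESS:
        if max_polls is not None and polls >= max_polls:
            return "POLL_LIMIT"
        sleep(poll_interval)
        polls += 1
        status = get_status()
    return status

# Simulated job that is in progress for three checks and then succeeds.
states = iter(["STARTING", "RUNNING", "RUNNING", "SUCCEEDED"])
final_state = wait_for_job(lambda: next(states), poll_interval=0)
print(final_state)
```

Because the loop may run far longer than a Lambda invocation allows, it has to live in something without that time limit, which is why the solution in this post runs it in a container task.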

We can accomplish this task several ways:

  • Have a Lambda function launch a container task that creates the AWS Glue job and polls for job completion, then sends the callback back to the pipeline
  • Have a Lambda function send a work notification to another SQS queue, with a separate Lambda function that picks up the message, checks for job status, and requeues the message if the job isn’t complete
  • Use AWS Glue job event notifications to respond to job status events sent by AWS Glue
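For comparison, the third option could be built on the job state change events that AWS Glue publishes to Amazon EventBridge. The following sketch builds an event pattern for one job and maps a received event to a success or failure outcome. The job name is a placeholder and both function names are hypothetical. This approach also needs a way to look up the callback token for a given job run, which the sketch does not address.

```python
def glue_state_change_pattern(job_name):
    """Event pattern matching terminal state changes of one AWS Glue job."""
    return {
        "source": ["aws.glue"],
        "detail-type": ["Glue Job State Change"],
        "detail": {
            "jobName": [job_name],
            "state": ["SUCCEEDED", "FAILED", "TIMEOUT", "STOPPED"],
        },
    }

def outcome_from_event(event):
    """Map a job state change event to ("success" or "failure", reason)."""
    detail = event["detail"]
    if detail["state"] == "SUCCEEDED":
        return "success", ""
    # Fall back to a generic reason when the event carries no message.
    default_reason = f"Glue job ended in state {detail['state']}"
    return "failure", detail.get("message", default_reason)

# Hand-written sample event with only the fields the sketch reads.
sample_event = {"detail": {"jobName": "example-job", "state": "FAILED", "message": "boom"}}
print(outcome_from_event(sample_event))
```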

For this post, we use the first technique because it’s the simplest (but likely not the most efficient). For this, we build out the solution as shown in the following diagram.

The solution is one example of how to use the new CallbackStep to extend your pipeline to steps outside SageMaker (such as AWS Glue). You can apply the same general steps and architectural guidance to extend pipelines to other custom processes or tasks. In our solution, the pipeline runs the following tasks:

  • Data preprocessing – This step (1 in the preceding diagram) uses CallbackStep to send a generated token and defined input payload to the configured SQS queue (2). In this example, the input sent to the SQS queue is the Amazon Simple Storage Service (Amazon S3) locations of the input data and the step output training data.
    • The new message in the SQS queue triggers a Lambda function (3) that is responsible for running an AWS Fargate task with Amazon Elastic Container Service (Amazon ECS) (4).
    • The Fargate task runs using a container image that is configured to run a task. The task in this case is an AWS Glue job (5) used to transform your raw data into training data stored in Amazon S3 (6). This task is also responsible for sending a callback message that signals either the job’s success or failure.
  • Model training – This step (7) runs when the previous callback step has completed successfully. It uses the generated training data to train a model using a SageMaker training job and the Random Cut Forest algorithm.
  • Package model – After the model is successfully trained, the model is packaged for deployment (8).
  • Deploy model – In this final step (9), the model is deployed using a batch transform job.

These pipeline steps are just examples; you can modify the pipeline to meet your use case, such as adding steps to register the model in the SageMaker Model Registry.

In the next sections, we discuss how to set up this solution.


Prerequisites

For the preceding pipeline, you need the prerequisites outlined in this section. The detailed setup of each of these prerequisites is available in the supporting notebook.

Notebook dependencies

To run the provided notebook, you need the following:

Pipeline dependencies

Your pipeline uses the following services:

  • SQS message queue – The callback step requires an SQS queue to trigger a task. For this, you need to create an SQS queue and ensure that AWS Identity and Access Management (IAM) permissions are in place that allow SageMaker to put a message in the queue and allow Lambda to poll the queue for new messages. See the following code:
import boto3

sqs_client = boto3.client('sqs')
queue_url = ''
queue_name = 'pipeline_callbacks_glue_prep'
try:
    response = sqs_client.create_queue(QueueName=queue_name)
    queue_url = response['QueueUrl']
except Exception as e:
    print(f"Failed to create queue: {e}")
  • Lambda function – The function is triggered by new messages put to the SQS queue. The function consumes these new messages and starts the ECS Fargate task. In this case, the Lambda execution IAM role needs permissions to pull messages from Amazon SQS, notify SageMaker of potential failures, and run the Amazon ECS task. For this solution, the function starts a task on ECS Fargate using the following code:
%%writefile queue_handler.py
import json
import boto3
import os
import traceback

ecs = boto3.client('ecs')
sagemaker = boto3.client('sagemaker')

def handler(event, context):
    print(f"Got event: {json.dumps(event)}")
    cluster_arn = os.environ["cluster_arn"]
    task_arn = os.environ["task_arn"]
    task_subnets = os.environ["task_subnets"]
    task_sgs = os.environ["task_sgs"]
    glue_job_name = os.environ["glue_job_name"]
    print(f"Cluster ARN: {cluster_arn}")
    print(f"Task ARN: {task_arn}")
    print(f"Task Subnets: {task_subnets}")
    print(f"Task SG: {task_sgs}")
    print(f"Glue job name: {glue_job_name}")
    for record in event['Records']:
        payload = json.loads(record["body"])
        print(f"Processing record {payload}")
        token = payload["token"]
        print(f"Got token {token}")
        try:
            input_data_s3_uri = payload["arguments"]["input_location"]
            output_data_s3_uri = payload["arguments"]["output_location"]
            print(f"Got input_data_s3_uri {input_data_s3_uri}")
            print(f"Got output_data_s3_uri {output_data_s3_uri}")

            # Launch the Fargate task and pass the job details as environment variables
            response = ecs.run_task(
                cluster = cluster_arn,
                count = 1,
                launchType = 'FARGATE',
                taskDefinition = task_arn,
                networkConfiguration = {
                    'awsvpcConfiguration': {
                        'subnets': task_subnets.split(','),
                        'securityGroups': task_sgs.split(','),
                        'assignPublicIp': 'ENABLED'
                    }
                },
                overrides = {
                    'containerOverrides': [
                        {
                            'name': 'FargateTask',
                            'environment': [
                                {
                                    'name': 'inputLocation',
                                    'value': input_data_s3_uri
                                },
                                {
                                    'name': 'outputLocation',
                                    'value': output_data_s3_uri
                                },
                                {
                                    'name': 'token',
                                    'value': token
                                },
                                {
                                    'name': 'glue_job_name',
                                    'value': glue_job_name
                                }
                            ]
                        }
                    ]
                }
            )
            if 'failures' in response and len(response['failures']) > 0:
                f = response['failures'][0]
                print(f"Failed to launch task for token {token}: {f['reason']}")
                sagemaker.send_pipeline_execution_step_failure(
                    CallbackToken = token,
                    FailureReason = f['reason']
                )
            else:
                print(f"Launched task {response['tasks'][0]['taskArn']}")
        except Exception as e:
            trc = traceback.format_exc()
            print(f"Error handling record: {str(e)}: {trc}")
            # FailureReason must be a string, not the exception object
            sagemaker.send_pipeline_execution_step_failure(
                CallbackToken = token,
                FailureReason = str(e)
            )
  • After we create the SQS queue and Lambda function, we need to set up the function as an SQS target so that new messages placed in the queue automatically trigger the function.
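One way to wire the queue to the function is an event source mapping. The following is a minimal sketch, not the notebook's code: the helper name `connect_queue_to_function`, the batch size, and the function name in the usage comment are assumptions. The Lambda client is passed in so the helper can be checked without AWS access.

```python
def connect_queue_to_function(lambda_client, queue_arn, function_name, batch_size=10):
    """Create an SQS event source mapping so new messages invoke the function.

    Returns the identifier of the new mapping.
    """
    response = lambda_client.create_event_source_mapping(
        EventSourceArn=queue_arn,
        FunctionName=function_name,
        Enabled=True,
        BatchSize=batch_size,
    )
    return response["UUID"]

# Typical use (requires AWS credentials and existing resources):
#   import boto3
#   connect_queue_to_function(boto3.client("lambda"), queue_arn, "my-queue-handler")
```

With a batch size greater than one, a single invocation can receive several messages, which is why the handler above loops over `event['Records']`.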
  • Fargate cluster – Because we use Amazon ECS to run and monitor the status of the AWS Glue job, we need to ensure we have an ECS Fargate cluster running:
import boto3
ecs = boto3.client('ecs')
response = ecs.create_cluster(clusterName='FargateTaskRunner')
  • Fargate task – We also need to create a container image with the code (task.py) that starts the data preprocessing job on AWS Glue and reports the status back to the pipeline upon the success or failure of that task. The IAM role attached to the task must include permissions that allow the task to pull images from Amazon ECR, create logs in Amazon CloudWatch, start and monitor an AWS Glue job, and send the callback token when the task is complete. When we issue send_pipeline_execution_step_success back to the pipeline, we also indicate the output file with the prepared training data. We use the output parameter in the model training step in the pipeline. The following is the code for task.py:
import boto3
import os
import sys
import traceback
import time

# Read the job details passed in by the Lambda function; exit if any are missing
if 'inputLocation' in os.environ:
    input_uri = os.environ['inputLocation']
else:
    print("inputLocation not found in environment")
    sys.exit(1)
if 'outputLocation' in os.environ:
    output_uri = os.environ['outputLocation']
else:
    print("outputLocation not found in environment")
    sys.exit(1)
if 'token' in os.environ:
    token = os.environ['token']
else:
    print("token not found in environment")
    sys.exit(1)
if 'glue_job_name' in os.environ:
    glue_job_name = os.environ['glue_job_name']
else:
    print("glue_job_name not found in environment")
    sys.exit(1)

print(f"Processing from {input_uri} to {output_uri} using callback token {token}")
sagemaker = boto3.client('sagemaker')
glue = boto3.client('glue')

poll_interval = 60

try:
    t1 = time.time()
    response = glue.start_job_run(
        JobName = glue_job_name,
        Arguments = {
            '--output_uri': output_uri,
            '--input_uri': input_uri
        }
    )
    job_run_id = response['JobRunId']
    print(f"Starting job {job_run_id}")
    job_status = 'STARTING'
    job_error = ''
    # Poll until the job leaves the in-progress states
    while job_status in ['STARTING','RUNNING','STOPPING']:
        time.sleep(poll_interval)
        response = glue.get_job_run(
            JobName = glue_job_name,
            RunId = job_run_id
        )
        job_status = response['JobRun']['JobRunState']
        if 'ErrorMessage' in response['JobRun']:
            job_error = response['JobRun']['ErrorMessage']
        print(f"Job is in state {job_status}")
    t2 = time.time()
    total_time = (t2 - t1) / 60.0
    if job_status == 'SUCCEEDED':
        print("Job succeeded")
        sagemaker.send_pipeline_execution_step_success(
            CallbackToken = token,
            OutputParameters = [
                {
                    'Name': 'minutes',
                    'Value': str(total_time)
                },
                {
                    'Name': 's3_data_out',
                    'Value': str(output_uri),
                }
            ]
        )
    else:
        print(f"Job failed: {job_error}")
        sagemaker.send_pipeline_execution_step_failure(
            CallbackToken = token,
            FailureReason = job_error
        )
except Exception as e:
    trc = traceback.format_exc()
    print(f"Error running ETL job: {str(e)}: {trc}")
    sagemaker.send_pipeline_execution_step_failure(
        CallbackToken = token,
        FailureReason = str(e)
    )
  • Data preprocessing code – The pipeline callback step does the actual data preprocessing using a PySpark job running in AWS Glue, so we need to create the code that is used to transform the data:
import sys
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job
from pyspark.sql.types import IntegerType
from pyspark.sql import functions as F

## @params: [JOB_NAME]
args = getResolvedOptions(sys.argv, ['JOB_NAME', 'input_uri', 'output_uri'])

sc = SparkContext()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)
job.init(args['JOB_NAME'], args)

df = spark.read.format("csv").option("header", "true").load("{0}*.csv".format(args['input_uri']))
df = df.withColumn("Passengers", df["passenger_count"].cast(IntegerType()))
df = df.withColumn(
  'pickup_time',
  F.to_timestamp(
  F.unix_timestamp('tpep_pickup_datetime', 'yyyy-MM-dd HH:mm:ss').cast('timestamp')))
dfW = df.groupBy(F.window("pickup_time", "30 minutes")).agg(F.sum("Passengers").alias("passenger"))
dfOut = dfW.drop('window')
dfOut.repartition(1).write.option("timestampFormat", "yyyy-MM-dd HH:mm:ss").csv(args['output_uri'])

  • Data preprocessing job – We also need to configure the AWS Glue job that runs the preceding code when triggered by your Fargate task. The IAM role used must have permissions to read and write from the S3 bucket. See the following code:
glue = boto3.client('glue')
response = glue.create_job(
    Name=job_name,        # placeholder: the job name is not shown in this listing
    Description='Prepare data for SageMaker training',
    Role=glue_role_arn,   # placeholder: ARN of the IAM role described above
    ExecutionProperty={
        'MaxConcurrentRuns': 1
    },
    Command={
        'Name': 'glueetl',
        'ScriptLocation': glue_script_location,
    },
)
glue_job_name = response['Name']

After these prerequisites are in place, including the necessary IAM permissions outlined in the example notebook, we’re ready to configure and run the pipeline.
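The SQS-related permissions mentioned earlier can be expressed as two small IAM policy documents. The sketch below builds them as Python dictionaries. It is an assumption about one reasonable way to scope the permissions, not the policy from the example notebook, and the ARN and function names are placeholders.

```python
import json

def queue_send_policy(queue_arn):
    """Policy for the pipeline's role: permission to put messages on one queue."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {"Effect": "Allow", "Action": ["sqs:SendMessage"], "Resource": queue_arn}
        ],
    }

def queue_poll_policy(queue_arn):
    """Policy for the Lambda execution role: permission to consume from one queue."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "sqs:ReceiveMessage",
                    "sqs:DeleteMessage",
                    "sqs:GetQueueAttributes",
                ],
                "Resource": queue_arn,
            }
        ],
    }

# Placeholder ARN for illustration only.
example_arn = "arn:aws:sqs:us-east-1:123456789012:pipeline_callbacks_glue_prep"
send_policy = queue_send_policy(example_arn)
poll_policy = queue_poll_policy(example_arn)
print(json.dumps(send_policy, indent=2))
```

Scoping each statement to the single queue ARN keeps the roles from gaining access to unrelated queues in the account.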

Configure the pipeline

To build out the pipeline, we rely on the preceding prerequisites in the callback step that performs data processing. We combine that step with steps native to SageMaker for model training and deployment to create an end-to-end pipeline.

To configure the pipeline, complete the following steps:

  1. Initialize the pipeline parameters:
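from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString,
)

# Parameter names are illustrative; the names and most default values
# are not shown in this listing.
input_data = ParameterString(name="InputData")
id_out = ParameterString(
    name="IdOut",
    default_value="taxiout"+ str(timestamp)
)
output_data = ParameterString(name="OutputData")
training_instance_count = ParameterInteger(name="TrainingInstanceCount")
training_instance_type = ParameterString(name="TrainingInstanceType")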
  1. Configure the first step in the pipeline, which is CallbackStep.

This step uses the SQS queue created in the prerequisites in combination with arguments that are used by tasks in this step. These arguments include the inputs of the Amazon S3 location of the input (raw taxi data) and output training data. The step also defines the outputs, which in this case includes the callback output and Amazon S3 location of the training data. The outputs become the inputs to the next step in the pipeline. See the following code:

from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum

callback1_output = CallbackOutput(output_name="s3_data_out", output_type=CallbackOutputTypeEnum.String)

step_callback_data = CallbackStep(
    name="GluePrepCallbackStep",  # step name assumed; not shown in this listing
    sqs_queue_url=queue_url,
    inputs={
        "input_location": f"s3://{default_bucket}/{taxi_prefix}/",
        "output_location": f"s3://{default_bucket}/{taxi_prefix}_{id_out}/"
    },
    outputs=[callback1_output],
)
  1. We use TrainingStep to train a model using the Random Cut Forest algorithm.

We first need to configure an estimator, then we configure the actual pipeline step. This step takes the output of the previous step and Amazon S3 location of the training data created by AWS Glue as input to train the model. See the following code:

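import boto3
import sagemaker

containers = {
    'us-west-2': '174872318107.dkr.ecr.us-west-2.amazonaws.com/randomcutforest:latest',
    'us-east-1': '382416733822.dkr.ecr.us-east-1.amazonaws.com/randomcutforest:latest',
    'us-east-2': '404615174143.dkr.ecr.us-east-2.amazonaws.com/randomcutforest:latest',
    'eu-west-1': '438346466558.dkr.ecr.eu-west-1.amazonaws.com/randomcutforest:latest'}
region_name = boto3.Session().region_name
container = containers[region_name]
model_prefix = 'model'

session = sagemaker.Session()

rcf = sagemaker.estimator.Estimator(
    image_uri=container,
    role=sagemaker.get_execution_role(),
    output_path='s3://{}/{}/output'.format(default_bucket, model_prefix),
    instance_count=training_instance_count,
    instance_type=training_instance_type,
    sagemaker_session=session)

# Random Cut Forest requires feature_dim; the prepared data has a single passenger-count column
rcf.set_hyperparameters(feature_dim=1)

from sagemaker.inputs import TrainingInput
from sagemaker.workflow.steps import TrainingStep

step_train = TrainingStep(
    name="TrainModel",  # step name assumed; not shown in this listing
    estimator=rcf,
    inputs={
        "train": TrainingInput(
            # s3_data is the output of the preceding callback step
            s3_data=step_callback_data.properties.Outputs['s3_data_out'],
            content_type="text/csv;label_size=0",
            distribution='ShardedByS3Key'
        )
    },
)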
  1. We use CreateModelStep to package the model for SageMaker deployment:
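from sagemaker.model import Model
from sagemaker import get_execution_role
role = get_execution_role()

image_uri = sagemaker.image_uris.retrieve("randomcutforest", region_name)

model = Model(
    image_uri=image_uri,
    model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
    sagemaker_session=session,
    role=role,
)
from sagemaker.inputs import CreateModelInput
from sagemaker.workflow.steps import CreateModelStep

# Instance type and step name are assumed; the original values are not shown in this listing
inputs = CreateModelInput(instance_type="ml.m5.large")

create_model = CreateModelStep(
    name="TaxiModel",
    model=model,
    inputs=inputs,
)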
  1. We deploy the trained model using a SageMaker batch transform job using TransformStep.

This step loads the trained model and processes the prediction request data stored in Amazon S3, then outputs the results (anomaly scores in this case) to the specified Amazon S3 location. See the following code:

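base_uri = step_callback_data.properties.Outputs['s3_data_out']
output_prefix = 'batch-out'

from sagemaker.transformer import Transformer

# Instance settings, output path, and step name are assumed;
# the original values are not shown in this listing
transformer = Transformer(
    model_name=create_model.properties.ModelName,
    instance_type="ml.m5.xlarge",
    instance_count=1,
    assemble_with = "Line",
    accept = 'text/csv',
    output_path=f"s3://{default_bucket}/{output_prefix}/",
)
from sagemaker.inputs import TransformInput
from sagemaker.workflow.steps import TransformStep

step_transform = TransformStep(
    name="TaxiTransform",
    transformer=transformer,
    inputs=TransformInput(data=base_uri, content_type="text/csv", split_type="Line"),
)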

Create and run the pipeline

You’re now ready to create and run the pipeline. To do this, complete the following steps:

  1. Define the pipeline including the parameters accepted and steps:
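from sagemaker.workflow.pipeline import Pipeline

pipeline_name = f"GluePipeline-{id_out}"
pipeline = Pipeline(
    name=pipeline_name,
    parameters=[input_data, id_out, output_data, training_instance_count, training_instance_type],
    steps=[step_callback_data, step_train, create_model, step_transform],
)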
  1. Submit the pipeline definition to create the pipeline using the role that is used to create all the jobs defined in each step:
from sagemaker import get_execution_role
pipeline.upsert(role_arn = get_execution_role())
  1. Run the pipeline:
execution = pipeline.start()

You can monitor your pipeline using the SageMaker SDK (for example, by calling execution.list_steps()) or via the Studio console, as shown in the following screenshot.
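When monitoring from the SDK, it can help to reduce the list of step records to a compact status summary. The following sketch assumes records with the `StepName`, `StepStatus`, and optional `FailureReason` fields; the helper name `summarize_steps` and the sample records are made up for illustration.

```python
def summarize_steps(steps):
    """Return {step name: status}, appending the failure reason for failed steps."""
    summary = {}
    for step in steps:
        status = step["StepStatus"]
        if status == "Failed" and "FailureReason" in step:
            status = f"Failed: {step['FailureReason']}"
        summary[step["StepName"]] = status
    return summary

# Hand-written sample records; in practice, pass in execution.list_steps().
sample_steps = [
    {"StepName": "TrainModel", "StepStatus": "Executing"},
    {"StepName": "GluePrepCallbackStep", "StepStatus": "Succeeded"},
]
step_summary = summarize_steps(sample_steps)
print(step_summary)
```

For a callback step, a status that stays at Executing for a long time usually means the external task has not yet reported back with the token.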

Use CallbackStep to integrate other tasks outside of SageMaker

You can follow the same pattern to integrate any long-running tasks or jobs with Pipelines. This may include running AWS Batch jobs, Amazon EMR job flows, or Amazon ECS or Fargate tasks.

You can also implement an email approval step for your models as part of your ML pipeline. In this pattern, a CallbackStep runs after the model evaluation step and sends a user an email that contains the model metrics along with approve and reject links. The workflow progresses to the next step after the user approves the task to proceed.

You can implement this pattern using a Lambda function and Amazon Simple Notification Service (Amazon SNS).
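As a sketch of the approval links, the callback token can travel in the link's query string so that whatever handles the click knows which step to resume. Everything here is an assumption for illustration: the base URL, the `action` and `token` parameter names, and both helper names. A real implementation would also need to protect the links against misuse.

```python
from urllib.parse import urlencode, urlparse, parse_qs

def build_decision_links(base_url, token):
    """Return approve and reject URLs that carry the callback token."""
    return {
        action: f"{base_url}?{urlencode({'action': action, 'token': token})}"
        for action in ("approve", "reject")
    }

def decision_from_url(url):
    """Recover (action, token) from a link built by build_decision_links."""
    query = parse_qs(urlparse(url).query)
    return query["action"][0], query["token"][0]

# Placeholder endpoint and token.
links = build_decision_links("https://example.com/decision", "example-token/with+chars")
print(links["approve"])
print(decision_from_url(links["reject"]))
```

The handler behind the links would then report success to the pipeline for an approval and failure for a rejection, using the recovered token.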


Conclusion

In this post, we showed you an example of how to use CallbackStep in Pipelines to extend your pipelines to integrate an AWS Glue job for data preprocessing. You can follow the same process to integrate any task or job outside of SageMaker. You can walk through the full solution explained in the example notebook.

About the Authors

Shelbee Eigenbrode is a Principal AI and Machine Learning Specialist Solutions Architect at Amazon Web Services (AWS). She holds 6 AWS certifications and has been in technology for 23 years spanning multiple industries, technologies, and roles. She is currently focusing on combining her DevOps and ML background to deliver and manage ML workloads at scale. With over 35 patents granted across various technology domains, she has a passion for continuous innovation and using data to drive business outcomes. Shelbee co-founded the Denver chapter of Women in Big Data.


Sofian Hamiti is an AI/ML specialist Solutions Architect at AWS. He helps customers across industries accelerate their AI/ML journey by helping them build and operationalize end-to-end machine learning solutions.




Randy DeFauw is a principal solutions architect at Amazon Web Services. He works with AWS customers to provide guidance and technical assistance on database projects, helping them improve the value of their solutions when using AWS.




Payton Staub is a senior engineer with Amazon SageMaker. His current focus includes model building pipelines, experiment management, image management and other tools to help customers productionize and automate machine learning at scale.
