A Unifying, Game-Theoretic Framework for Imitation Learning

Imitation learning (IL) is the problem of finding a policy, \(\pi\), that is as close as possible to an expert’s policy, \(\pi_E\). IL algorithms can be grouped broadly into (a) online, (b) offline, and (c) interactive methods. We provide, for each setting, performance bounds for learned policies that apply for all algorithms, provably efficient algorithmic templates for achieving said bounds, and practical realizations that outperform recent work.

From beating the world champion at Go (Silver et al.) to getting cars to drive themselves (Bojarski et al.), we’ve seen unprecedented successes in learning to make sequential decisions over the last few years. Viewed algorithmically, many of these accomplishments share a common paradigm: imitation learning (IL). In imitation learning, one is given access to samples of expert behavior (e.g., moves chosen by Monte-Carlo Tree Search or steering angles recorded from an expert driver) and tries to learn a policy that mimics this behavior. Unlike reinforcement learning, imitation learning does not require careful tuning of a reward function, making it easier to scale to real-world tasks where one is able to gather expert behavior (like Go or driving). As we continue to apply imitation learning algorithms to safety-critical problems, it becomes increasingly important for us to have strong guarantees on their performance: while wrong steps in Go lead to a lost game at worst, mistakes by self-driving cars could have far worse consequences. In our ICML ’21 paper, Of Moments and Matching: A Game Theoretic Framework for Closing the Imitation Gap, we provide bounds on how well any imitation algorithm can do, as well as provably efficient algorithms for achieving these bounds.

A Taxonomy of Imitation Learning Algorithms

Let’s focus on the problem of trying to teach a car to drive around a track from expert demonstrations. We instrument the car with cameras and sensors that measure the angle of the wheel and how hard the pedals are being pushed. Then, in terms of increasing requirements, the approaches we could take are:

  • Offline: Have the expert drive laps, recording their states (camera images) and actions (pedals/wheel). Use your favorite supervised learning algorithm to regress from states to actions. This approach is called Behavioral Cloning.
  • Online: Record expert states and actions. Then, have the car try to drive around the track and measure the delta between learner and expert trajectories. Train the policy to minimize this delta. GAIL is an algorithm that uses a discriminator network to measure this delta.
  • Interactive: (0) Start with an empty dataset D. (1) Record the car driving a sample lap. (2) Ask the expert driver what they would have done for each recorded image. Append this data to D. (3) Regress over data in D. (4) Go back to 1. This approach is known as DAgger.
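The interactive recipe in the last bullet can be sketched in a few lines of Python. This is a toy illustration rather than the setup from the paper: the 1-D dynamics, the linear `expert_action`, and the least-squares learner are all assumptions chosen to keep the loop readable.

```python
import random

def expert_action(state):
    # Hypothetical expert: steer back toward the center line (state 0).
    return -0.5 * state

def rollout(theta, horizon=20, seed=0):
    # Drive one "lap" with the learner's linear policy a = theta * s
    # and return the states the learner visited.
    rng = random.Random(seed)
    state, visited = 1.0, []
    for _ in range(horizon):
        visited.append(state)
        state = state + theta * state + rng.gauss(0.0, 0.1)
    return visited

def fit(dataset):
    # Least-squares regression of expert actions on states (no intercept).
    num = sum(s * a for s, a in dataset)
    den = sum(s * s for s, _ in dataset)
    return num / den

def dagger(rounds=5):
    dataset, theta = [], 0.0                                  # (0) empty dataset D
    for k in range(rounds):                                   # (4) repeat
        states = rollout(theta, seed=k)                       # (1) learner drives a lap
        dataset += [(s, expert_action(s)) for s in states]    # (2) expert relabels, append to D
        theta = fit(dataset)                                  # (3) regress over all of D
    return theta

theta = dagger()
print(f"learned gain: {theta:.3f}")
```

Dropping the loop and fitting once on states visited by the expert would turn the same code into Behavioral Cloning; the only change is whose states the labels are collected on.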

One of our key insights is that all three of these approaches can be seen as minimizing a sort of divergence from expert behavior. Concretely,

  • Offline: We measure a divergence between learner and expert actions on states from expert demonstrations.
  • Online: We measure a divergence between learner and expert trajectories.
  • Interactive: We measure a divergence between learner and expert actions but on states from learner rollouts.

Also notice that as we transition from Offline to Online IL, we add a requirement of access to the environment or an accurate simulator. As we move from Online to Interactive IL, we also need access to a queryable expert. Let \(\pi\) denote the policy, \(\pi_E\) denote the expert’s policy, and \(f\) denote the divergence. We can visualize our thoughts thus far as:

With this divergence-minimizing perspective in mind, we’re able to introduce a unifying, game-theoretic perspective.

A Game-Theoretic Perspective on IL

A natural question at this point might be: what divergence should one use to measure the difference between learner and expert behavior? Examples abound in the literature: Kullback-Leibler? Wasserstein? Jensen-Shannon? Total Variation? Maximum Mean Discrepancy? Without prior knowledge about the problem, it’s really hard to say. For example, KL divergence has a mode-covering effect: if half the data was the expert swerving left to avoid a tree and half the data was them swerving right, the learner would learn to pick a point in the middle and drive straight into the tree!
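The tree example can be reproduced numerically. The sketch below assumes hypothetical steering angles clustered around -1 (swerve left) and +1 (swerve right) and fits a single Gaussian by maximum likelihood, which is the same as minimizing the KL divergence from the expert data to the learner's distribution.

```python
import random
import statistics

rng = random.Random(0)
# Hypothetical steering angles: half the demos swerve left (-1), half swerve right (+1).
demos = ([rng.gauss(-1.0, 0.1) for _ in range(500)]
         + [rng.gauss(1.0, 0.1) for _ in range(500)])

# The maximum-likelihood single Gaussian is given by the sample mean and
# standard deviation of the demonstrations.
mean = statistics.fmean(demos)
std = statistics.pstdev(demos)

print(f"fitted mean steering angle: {mean:+.3f}")  # near 0: straight at the tree
print(f"fitted standard deviation:  {std:.3f}")    # wide enough to cover both modes
```

The fitted distribution places its peak between the two modes, where the expert never steered.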

If we’re not sure what divergence is the right choice, we can just minimize all of them, which is equivalent to minimizing a worst-case or adversarially-chosen one. Using \(\pi\) and \(\pi_E\) to denote the learner and expert policies, we can write out the optimization problem for each setting:

  • Offline: $$ \min_{\pi} \max_f \mathbb{E}_{s, a \sim \pi_E}[f(s, \pi(s)) - f(s, a)] $$
  • Online: $$ \min_{\pi} \max_f \mathbb{E}_{s, a \sim \pi}[f(s, a)] - \mathbb{E}_{s, a \sim \pi_E}[f(s, a)] $$
  • Interactive: $$ \min_{\pi} \max_f \mathbb{E}_{s, a \sim \pi}[f(s, a) - f(s, \pi_E(s))] $$
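To make the inner maximization concrete, here is a small sketch of the online objective under one assumed discriminator class: linear functions f(s, a) = w · φ(s, a) with ||w|| ≤ 1. For that class the worst-case discriminator has a closed form, and the objective reduces to the distance between learner and expert feature means. The feature map `feature` and the two toy policies are hypothetical.

```python
import math
import random

def feature(state, action):
    # Hypothetical moment features phi(s, a); any fixed feature map would do.
    return (state, action, state * action)

def mean_features(samples):
    feats = [feature(s, a) for s, a in samples]
    return [sum(col) / len(feats) for col in zip(*feats)]

def online_gap(learner_samples, expert_samples):
    # max over ||w|| <= 1 of E_pi[w . phi] - E_piE[w . phi]
    # is attained at w = diff / ||diff|| and equals ||diff||.
    mu_learner = mean_features(learner_samples)
    mu_expert = mean_features(expert_samples)
    diff = [l - e for l, e in zip(mu_learner, mu_expert)]
    return math.sqrt(sum(d * d for d in diff)), diff

rng = random.Random(0)
expert = [(s, -0.5 * s) for s in (rng.uniform(-1, 1) for _ in range(200))]
learner = [(s, -0.2 * s) for s in (rng.uniform(-1, 1) for _ in range(200))]

gap, diff = online_gap(learner, expert)
print(f"worst-case linear discriminator gap: {gap:.4f}")
```

The offline and interactive objectives follow the same pattern; only the source of the states and the pairing of actions change.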

Each of these equations is in the form of a two-player zero-sum game between a learner \(\pi\) and a discriminator \(f\). Two-player zero-sum games have been extensively studied in game theory, allowing us to use standard tools to analyze and solve them. Notice the similarity of the forms of these games: the only real difference is which state-action distributions the divergence is calculated between. Thus, we can view all three classes of imitation learning as solving games with different classes of discriminators. This game-theoretic perspective is extremely powerful for a few reasons:

  1. As we have access to more information (e.g. a simulator or a queryable expert), we’re able to evaluate more powerful discriminators. Minimizing these more powerful discriminators leads to tighter performance bounds. Specifically, we show that the difference between learner and expert performance for offline IL scales quadratically with the horizon of the problem, and linearly for online / interactive IL. Quadratically compounding errors translate to poor real-world performance. Thus, one perspective on our bounds is that they show that access to a simulator or a queryable expert is both necessary and sufficient for learning performant policies. We recommend checking out the full paper for the precise upper and lower bounds.
  2. These performance bounds apply for all algorithms in each class — after all, you can’t do better by considering a more restrictive class of divergences. This means our bounds apply for a lot of prior work (e.g. Behavioral Cloning, GAIL, DAgger, MaxEnt IRL, …). Importantly, these bounds also apply for all non-adversarial algorithms: they’re just optimizing over a singleton discriminator class.
  3. Our game-theoretic perspective also tells us that finding a policy that minimizes the worst-case divergence is equivalent to finding a Nash Equilibrium of the corresponding game, a problem we know how to solve provably efficiently for two-player zero-sum games. By solving a particular game, we inherit the performance bounds that come with the class of divergences considered.
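As rough arithmetic for the first point, the snippet below compares how the two scalings grow with the horizon. It assumes a per-step error `eps` and drops all constants, so the numbers indicate orders of magnitude only; the precise bounds are in the paper.

```python
def offline_gap_scale(eps, horizon):
    # Quadratic compounding: on the order of eps * T^2 (constants omitted).
    return eps * horizon ** 2

def online_gap_scale(eps, horizon):
    # Linear scaling: on the order of eps * T (constants omitted).
    return eps * horizon

for T in (10, 100, 1000):
    print(T, offline_gap_scale(0.01, T), online_gap_scale(0.01, T))
```

The ratio between the two is the horizon itself, so the penalty for staying offline grows with the length of the task.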

Together, these three points tell us that a game-theoretic perspective allows us to unify imitation learning as well as efficiently find strong policies!

A Practical Prescription for each IL Setting

Let’s dig into how we can compute Nash equilibria efficiently in theory and in practice for all three games. Intuitively, a Nash equilibrium is a strategy for each player such that no player wants to unilaterally deviate. This means that each player is playing a best-response to every other player. We can find such an equilibrium by competing two types of algorithms:

  • No-Regret: slow, stable, choosing best option over history.
  • Best-Response: fast, choosing best option to last iterate of other player.

Classic analysis shows that having one player follow a no-regret algorithm and the other player follow a best-response algorithm will, within a polynomial number of iterations, converge to an approximate Nash equilibrium of the game. The intuition of the proof is that if player 1 is steadily converging to a strategy that performs well even when player 2 chooses their strategy adversarially, player 1 can’t have much of an incentive to deviate, meaning their strategy must be half of a Nash equilibrium.
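A minimal sketch of this dynamic on a game small enough to check by hand: matching pennies, with Hedge (multiplicative weights) as the no-regret player and an exact best response for the opponent. Both the game and the choice of Hedge are assumptions made for illustration; the post does not prescribe a particular no-regret algorithm.

```python
import math

# Payoff to the column player (the row player pays it): matching pennies.
A = [[1.0, -1.0],
     [-1.0, 1.0]]

def solve(rounds=2000):
    n, m = len(A), len(A[0])
    eta = math.sqrt(math.log(n) / rounds)   # Hedge step size
    cum_loss = [0.0] * n                     # row player's cumulative loss per action
    avg_x = [0.0] * n
    for _ in range(rounds):
        # No-regret step: Hedge weights each row by its loss over the whole history.
        weights = [math.exp(-eta * c) for c in cum_loss]
        total = sum(weights)
        x = [w / total for w in weights]
        # Best-response step: the column player reacts only to the latest x.
        payoffs = [sum(x[i] * A[i][j] for i in range(n)) for j in range(m)]
        j = max(range(m), key=payoffs.__getitem__)
        for i in range(n):
            cum_loss[i] += A[i][j]
            avg_x[i] += x[i] / rounds
    return avg_x

avg_x = solve()
# How much the column player can still gain against the averaged row strategy.
exploitability = max(sum(avg_x[i] * A[i][j] for i in range(2)) for j in range(2))
print(avg_x, exploitability)
```

The averaged row strategy approaches the uniform mix, which is the row player's half of the equilibrium for this game, and the exploitability shrinks as the number of rounds grows.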

We’d like to emphasize the generality of this approach to imitation learning: you can plug in any no-regret algorithm and both our policy performance and efficiency results still hold. There’s a plethora of algorithms that can be developed from this no-regret reduction perspective!

We instantiate this general template into an implementable procedure for each setting. We compare our approaches against similar recent work. We plot the performance of our methods in orange. \(J(\pi)\) refers to the learner’s expected cumulative reward, while \(\pi_E\) in green is the expert’s performance. As stated above, our goal is for the learner to match expert performance.

Offline: We adopt a model similar to a Wasserstein GAN where the learner acts as the generator and the discriminator tries to distinguish between learner and expert actions on expert states. We set the learner’s learning rate to be much lower than that of the discriminator, simulating no-regret on policy vs. best response on divergence. We term this approach Adversarial Value-moment IL, or AdVIL. We find it to be competitive with recent work:

Online: We repurpose the replay buffer of an off-policy RL algorithm as the discriminator by assigning negative rewards to actions that don’t directly match the expert. We impute a reward of +1 for expert behavior and -1/k for learner behavior from a past round, where k is the round number. The slow-moving append-only replay buffer implements a no-regret algorithm against a policy that best-responds via RL at each round. We term this approach Adversarial Reward-moment IL, or AdRIL, and find that it can significantly outperform other online IL algorithms at some tasks:
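The reward imputation can be sketched as a relabeling pass over the replay buffer. The buffer layout (dicts with a `source` field) and the function name `relabel` are hypothetical, and the sketch takes one reading of the rule above: at round k, every learner transition in the buffer gets reward -1/k. The paper gives the exact scheme.

```python
def relabel(buffer, round_number):
    # buffer: list of dicts with a "transition" and a "source" ("expert" or "learner").
    # Expert transitions get +1; learner transitions get -1/k at round k (assumed reading).
    labeled = []
    for item in buffer:
        reward = 1.0 if item["source"] == "expert" else -1.0 / round_number
        labeled.append({**item, "reward": reward})
    return labeled

buffer = [
    {"transition": ("s0", "a0"), "source": "expert"},
    {"transition": ("s1", "a1"), "source": "learner"},
]
labeled = relabel(buffer, round_number=4)
print([item["reward"] for item in labeled])
```

An off-policy RL algorithm trained on the relabeled buffer then plays the role of the best-responding policy.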

Interactive: We modify DAgger to use adversarially chosen losses at each round instead of a fixed function. At each round, a discriminator network is trained between the last policy and the expert. Then, for all samples for that round, this discriminator network is used as the loss function. Then, just like DAgger, the learner minimizes loss over the history of samples and loss functions for all rounds. Thus, the learner is following a no-regret algorithm against a best-response by the discriminator. We call this algorithm DAgger-esque Qu-moment IL, or DAeQuIL.

To demonstrate the potential advantages of DAeQuIL over DAgger, we test out both algorithms on a simulated UAV forest navigation task, where the expert demonstrates a wide variety of tree avoidance behaviors (left). DAgger attempts to match the mean of these interactively queried action labels, leading to it learning to crash directly into the first tree it sees (center). DAeQuIL, on the other hand, is able to learn to swerve out of the way of trees and navigate successfully through the forest (right).

Parting Thoughts

We provide, for all three settings of imitation learning, performance bounds for learned policies, a provably efficient reduction to no-regret online learning, and practical algorithms. If you’re interested in learning more, I recommend you check out:

There are lots of interesting areas left to explore in imitation learning, including imitation from observation alone that would allow one to leverage the large corpus of instructional videos online to train robots. Another direction that we’re particularly excited about is mimicking expert behavior, even in the presence of unobserved confounders. Stay tuned!

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.

Analyze customer churn probability using call transcription and customer profiles with Amazon SageMaker

Regardless of the industry or product, customers are the most important component in a business’s success and growth. Businesses go to great lengths to acquire and more importantly retain their existing customers. Customer satisfaction links directly to revenue growth, business credibility, and reputation. These are all key factors in a sustainable and long-term business growth strategy.

Given the marketing and operational costs of customer acquisition and satisfaction, and how costly losing a customer to a competitor can be, it’s generally less costly to retain existing customers than to acquire new ones. Therefore, it’s crucial for businesses to understand why and when a customer might stop using their services or switch to a competitor, so they can take proactive measures by providing incentives or offering upgrades for new packages that could encourage the customer to stay with the business.

Customer service interactions provide invaluable insight into the customer’s opinion about the business and its services, and can be used, in addition to other quantitative factors, to enable the business to better understand the sentiment and trends of customer conversations and to identify crucial company and product feedback. Customer churn prediction using machine learning (ML) techniques can be a powerful tool for customer service and care.

In this post, we walk you through the process of training and deploying a churn prediction model on Amazon SageMaker that uses Hugging Face Transformers to find useful signals in customer-agent call transcriptions. In addition to textual inputs, we show you how to incorporate other types of data, such as numerical and categorical features in order to predict customer churn.

To try out the solution in your own account, make sure that you have the following in place:

The JumpStart solution launch creates the resources properly set up and configured to successfully run the solution.

Architecture overview

In this solution, we focus on SageMaker components. We use SageMaker training jobs to train the churn prediction model and a SageMaker endpoint to deploy the model. We use Amazon Simple Storage Service (Amazon S3) to store the training data and model artifacts, and Amazon CloudWatch to log training and endpoint outputs. The following figure illustrates the architecture for the solution.

Exploring the data

In this post, we use a mobile operator’s historical records of which customers ended up churning and which continued using the service. The data also includes transcriptions of the latest phone call conversations between the customer and the agent (which could also be the streaming transcription as the call is happening). We can use this historical information to train an ML classifier model, which we can then use to predict the probability of customer churn based on the customer’s profile information and the content of the phone call transcription. We create a SageMaker endpoint to make real-time predictions using the model and provide more insight to customer service agents as they handle customer phone calls.

The dataset we use is synthetically generated and available under the CC BY 4.0 license. The data used to generate the numerical and categorical features is based on the public dataset KDD Cup 2009: Customer relationship prediction. We have generated over 50,000 samples and randomly split the data into 45,000 samples for training and 5,000 samples for testing. In addition, the phone conversation transcripts were synthetically generated using the GPT2 (Generative Pre-trained Transformer 2) algorithm. The data is hosted on Amazon S3.

More details on customer churn classification models using similar data, and also step-by-step instructions on how to build a binary classifier model using similar data, can be found in the blog post Predicting Customer Churn with Amazon Machine Learning. That post is focused more on binary classification using the tabular data. This blog post approaches this problem from a different perspective, and brings in natural language processing (NLP) by processing the context of agent-customer phone conversations.

The following are the attributes (features) of the customer profiles dataset:

  • CustServ Calls – The number of calls placed to customer service
  • State – The US state in which the customer resides, indicated by a two-letter abbreviation; for example, OH or NJ
  • VMail Message – The average number of voice mail messages per month
  • Account Length – The number of days that this account has been active
  • Day Mins, Day Calls, Day Charge – The minutes used, number of calls, and billed cost for calls placed during the day
  • Eve Mins, Eve Calls, Eve Charge – The minutes used, number of calls, and billed cost for calls placed during the evening
  • Night Mins, Night Calls, Night Charge – The minutes used, number of calls, and billed cost for calls placed during nighttime
  • Intl Mins, Intl Calls, Intl Charge – The minutes used, number of calls, and billed cost for international calls
  • Location – Whether the customer is located in urban, suburban, rural, or other areas
  • State – The state location of the customer
  • Plan – The plan category
  • Limit – Limited or unlimited plan type
  • Text – The synthetic GPT-2 generated transcription of the customer-agent phone conversation
  • Y – Whether the customer left the service (true/false)

The last attribute, Y, is known as the target feature, or the feature we want the ML model to predict. Because the target feature is binary (true/false), the type of modeling is a binary classification model. The model we train later in this post predicts the likelihood of churn as well.

We don’t go over exploratory data analysis in this post. For more details, see Predicting Customer Churn with Amazon Machine Learning and the Customer Churn Prediction with XGBoost sample notebook.

The training script is developed to allow the ML practitioner to pick and choose the features used in training. For example, we don’t use all the features in training. We focus more on the maturity of the customer’s account, number of times the customer has contacted customer service, type of plan they have, and transcription of the latest phone call. You can use additional features in training by including the list in the hyperparameters, as we show in the next section.

The transcription of the customer-agent phone call in the text column is synthetic text generated with the GPT2 algorithm. Its purpose is to show how you can apply this solution to real-world customer service phone conversations. GPT2 is an unsupervised transformer language model developed by OpenAI. It’s a powerful generative NLP model that excels in processing long-range dependencies, and is pre-trained on a diverse corpus of text. For more details on how to generate text using GPT2, see Experimenting with GPT-2 XL machine learning model package on Amazon SageMaker and the Creative Writing using GPT2 Text Generation example notebook.

Train the model

For this post, we use the SageMaker PyTorch Estimator to build a SageMaker estimator using an Amazon-built Docker container that runs functions defined in the supplied entry_point Python script within a SageMaker training job. The training job is started by calling .fit() on this estimator. Later, we deploy the model by calling the .deploy() method on the estimator. Visit Amazon SageMaker Python SDK technical documentation for more details on preparing PyTorch scripts for SageMaker training and using the PyTorch Estimator.

Also, visit Available Deep Learning Containers Images on GitHub to get a list of supported PyTorch versions. At the time of this writing, the latest version available is PyTorch 1.8.1 with Python version 3.6. You can update the framework version to the latest supported version by changing the framework_version parameter in the PyTorch Estimator. You can also use SageMaker utility API image URIs to get the latest list of supported versions.

The hyperparameters dictionary defines which features we want to use for training and also the number of trees in the forest (n-estimators) for the model. You can add any other hyperparameters for the RandomForestClassifier; however, you also need to revise your custom training script to receive these parameters in the form of arguments (using the argparse library) and add them to your model. See the following code:

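hyperparameters = {
    "n-estimators": 100,
    "numerical-feature-names": "CustServ Calls,Account Length",
    "categorical-feature-names": "plan,limit",
    "textual-feature-names": "text",
    "label-name": "y",
}

# The argument list was truncated in the source; only values stated in this post are shown
estimator = PyTorch(
    entry_point="entry_point.py",
    framework_version="1.8.1",
    hyperparameters=hyperparameters,
    # plus role, instance_count, instance_type, and py_version for your account
)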
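On the script side, hyperparameters like these arrive as command-line arguments. The following is a sketch of how a training script might receive them with argparse; it is not the actual parsing code from entry_point.py, and `--max-depth` is a hypothetical extra RandomForestClassifier hyperparameter added for illustration.

```python
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-estimators", type=int, default=100)
    parser.add_argument("--numerical-feature-names", type=str, default="")
    parser.add_argument("--categorical-feature-names", type=str, default="")
    parser.add_argument("--textual-feature-names", type=str, default="")
    parser.add_argument("--label-name", type=str, default="y")
    # Hypothetical extra hyperparameter, not part of the original dictionary.
    parser.add_argument("--max-depth", type=int, default=None)
    args = parser.parse_args(argv)
    # Comma-separated strings become Python lists of column names.
    for name in ("numerical", "categorical", "textual"):
        key = f"{name}_feature_names"
        raw = getattr(args, key)
        setattr(args, key, [column for column in raw.split(",") if column])
    return args

args = parse_args([
    "--n-estimators", "200",
    "--numerical-feature-names", "CustServ Calls,Account Length",
    "--categorical-feature-names", "plan,limit",
    "--textual-feature-names", "text",
    "--max-depth", "8",
])
print(args.n_estimators, args.numerical_feature_names, args.max_depth)
```

A new hyperparameter therefore needs two changes: a key in the hyperparameters dictionary and a matching `add_argument` call whose value is passed to the classifier.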

If you launched the SageMaker JumpStart solution in your account, the custom scripts are available in your Studio files. We use the entry_point.py script. This script receives a list of numerical features, categorical features, textual features, and the target label, and trains a SKLearn RandomForestClassifier on the data. However, the key here is processing the features before using them in the classifier, especially the call transcription. The following figure shows this process, which applies mean imputation to numerical features (replacing missing values with the column mean), one-hot encoding to categorical features, and transformer-based sentence embeddings to textual features.

The purpose of the script presented in this post is to provide an example of how you can develop your own custom feature transformation pipeline. You can apply other transformations to the data based on your specific use case and the nature of your dataset, and make it as complex or as simple as you want. For example, depending on the nature of your dataset and the results of the exploratory data analysis, you may want to consider normalization, log transformation, or dropping records with null values. For a more complete list of feature transformation techniques, visit SKLearn Dataset Transformations.

The following code snippet shows you how to instantiate these transformers for numerical and categorical features, and how to apply them to your dataset. More details on how this is done in the training script are available in the entry_point.py script that is launched in your files by the JumpStart solution.

from pathlib import Path

import joblib
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

# Instantiate transformers (missing numerical values are replaced with the column mean)
numerical_transformer = SimpleImputer(missing_values=np.nan, strategy="mean")
categorical_transformer = OneHotEncoder(handle_unknown="ignore")

# Train transformers on data, and store transformers for future use by predict function
numerical_transformer.fit(numerical_features)
joblib.dump(numerical_transformer, Path(args.model_dir, "numerical_transformer.joblib"))

categorical_transformer.fit(categorical_features)
joblib.dump(categorical_transformer, Path(args.model_dir, "categorical_transformer.joblib"))

# transform the data
numerical_features = numerical_transformer.transform(numerical_features)
categorical_features = categorical_transformer.transform(categorical_features)
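To see what these two transformers compute, here is a dependency-free sketch of mean imputation and one-hot encoding on three made-up records. It mirrors the behavior of the scikit-learn classes only loosely: missing values are represented by `None` rather than NaN, and the helper names are hypothetical.

```python
def fit_mean_imputer(column):
    # Learn the mean of the observed (non-missing) values.
    observed = [v for v in column if v is not None]
    return sum(observed) / len(observed)

def impute(column, mean):
    # Replace each missing value with the learned mean.
    return [mean if v is None else v for v in column]

def fit_one_hot(column):
    # Learn the sorted list of categories seen during training.
    return sorted(set(column))

def one_hot(column, categories):
    # Unseen categories map to an all-zero row, like handle_unknown="ignore".
    return [[1.0 if v == c else 0.0 for c in categories] for v in column]

account_length = [66.0, None, 120.0]
plan = ["A", "B", "A"]

mean = fit_mean_imputer(account_length)
categories = fit_one_hot(plan)

train_numeric = impute(account_length, mean)
train_plan = one_hot(plan, categories)
new_plan = one_hot(["C"], categories)  # a plan never seen during training

print(train_numeric, train_plan, new_plan)
```

Fitting once and reusing the learned mean and category list at inference time is the reason the fitted transformers are saved alongside the model.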

Now let’s focus on the textual data. We use Hugging Face sentence transformers, which you can use for sentence embedding generation. They come with pre-trained models that you can use out of the box based on your use case. In this post, we use the bert-base-nli-cls-token model, which is described in Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks.

Recently, SageMaker introduced new Hugging Face Deep Learning Containers (DLCs) that enable you to train, fine-tune, and run inference using Hugging Face models for NLP on SageMaker. In this post, we use the PyTorch container and a custom training script. For this purpose, in our training script, we define a BertEncoder class based on Hugging Face SentenceTransformer and define the pre-trained model as bert-base-nli-cls-token, as shown in the following code. The reason for this is to be able to apply the transformer to the dataset in the same way as the other dataset transformers, by calling the .transform() method. The benefit of using Hugging Face pre-trained models is that you don’t need to do additional training to be able to use the model. However, you can still fine-tune the models with custom data, as described in Fine-tuning a pretrained model.

from sentence_transformers import SentenceTransformer
from sklearn.base import BaseEstimator, TransformerMixin

# Define a class for BertEncoder
class BertEncoder(BaseEstimator, TransformerMixin):
    def __init__(self, model_name='bert-base-nli-cls-token'):
        self.model = SentenceTransformer(model_name)
        self.model.parallel_tokenization = False

    def fit(self, X, y=None):
        # The pre-trained model needs no fitting
        return self

    def transform(self, X):
        output = []
        for sample in X:
            encodings = self.model.encode(sample)
            # Collect the embedding for each sample
            output.append(encodings)
        return output

# Instantiate the class 
textual_transformer = BertEncoder()

# Apply the transformation to textual features
textual_features = textual_transformer.transform(textual_features)

Now that the dataset is processed and ready to be consumed by an ML model, we can train any classifier model to predict if a customer will churn or not. In addition to predicting the class (0/1 or true/false) for customer churn, these models also generate the probability of each class, meaning the probability of a customer churning. This is particularly useful for customer service teams for strategizing the incentives or upgrades they can offer to the customer based on how likely the customer is to cancel the service or subscription. In this post, we use the SKLearn RandomForestClassifier model. You can choose from many hyperparameters for this model and also optimize the hyperparameters for a more accurate model prediction by using strategies like grid search, random search, and Bayesian search. SageMaker automatic hyperparameter tuning can be a powerful tool for this purpose.

Training the model in entry_point.py is handled by the train_fn() function in the custom script. This function is called when the .fit() method is applied to the estimator. This function also stores the trained model and trained data transformers on Amazon S3. These files are used later by model_fn() to load the model for inference purposes.

train_fn() also includes evaluation of the trained model, and provides accuracy scores for the model for both train and test datasets. This helps you better evaluate model performance. Because this is a classification problem, we recommend including other metrics in your evaluation script, for example F1 score, ROC AUC score, and recall score, the same way we added accuracy scores. These are printed as the training progresses. Because we’re using synthetic data for training the model in this example notebook, especially for the agent-customer call transcription, we’re not expecting to see high-performing models with regards to classification metrics, and therefore we’re not focusing on these metrics in this example. However, when you use your own data, you should consider how each classification metric could impact the applicability of the model to your use case. Training this model on 45,000 samples on an ml.p3.2xlarge instance takes about 30 minutes.

estimator.fit({
    'train': 's3://path/to/your/train.jsonl',
    'test': 's3://path/to/your/test.jsonl'
})
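The additional classification metrics recommended above can be computed from counts of true and false positives and negatives. The sketch below does this in plain Python on made-up labels; in a real evaluation script you would more likely call the equivalent functions in sklearn.metrics (for example f1_score and recall_score), and ROC AUC additionally needs predicted probabilities rather than hard labels.

```python
def classification_metrics(y_true, y_pred):
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t and p)          # churners correctly flagged
    fp = sum(1 for t, p in pairs if not t and p)      # loyal customers flagged as churners
    fn = sum(1 for t, p in pairs if t and not p)      # churners the model missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0
    accuracy = sum(1 for t, p in pairs if bool(t) == bool(p)) / len(pairs)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Made-up labels: 1 means the customer churned.
y_true = [1, 0, 0, 1, 0, 0, 0, 1]
y_pred = [1, 0, 0, 0, 0, 1, 0, 0]
metrics = classification_metrics(y_true, y_pred)
print(metrics)
```

In this toy case accuracy looks acceptable while recall is low, which is why accuracy alone can be misleading when churners are a minority of the data.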

When you’re comfortable with the performance of your model, you can move to the next step, which is deploying your model for real-time inference.

Deploy the model

When the training is complete, you can deploy the model as a SageMaker hosted endpoint for real-time inference, or use the model for offline batch inference, using SageMaker batch transform. The task of performing inference (either real time or batch) is handled by four main functions in the custom script:

  • input_fn() processes the input data
  • model_fn() loads the trained model artifacts from Amazon S3
  • predict_fn() makes predictions
  • output_fn() prepares the model output

The following diagram illustrates this process.

The following script is a snippet of the entry_point.py script, and shows how the four functions work together to perform inference:

# Model function to load the trained model and trained transformers from S3
def model_fn(model_dir):
    print('loading feature_names')
    numerical_feature_names, categorical_feature_names, textual_feature_names = load_feature_names(Path(model_dir, "feature_names.json"))
    print('loading numerical_transformer')
    numerical_transformer = joblib.load(Path(model_dir, "numerical_transformer.joblib"))
    print('loading categorical_transformer')
    categorical_transformer = joblib.load(Path(model_dir, "categorical_transformer.joblib"))
    print('loading textual_transformer')
    textual_transformer = BertEncoder()
    classifier = joblib.load(Path(model_dir, "classifier.joblib"))
    model_assets = {
        'numerical_feature_names': numerical_feature_names,
        'numerical_transformer': numerical_transformer,
        'categorical_feature_names': categorical_feature_names,
        'categorical_transformer': categorical_transformer,
        'textual_feature_names': textual_feature_names,
        'textual_transformer': textual_transformer,
        'classifier': classifier
    }
    return model_assets

# Input Preparation Function to receive the request body and ensure proper format
def input_fn(request_body_str, request_content_type):
    assert (
        request_content_type == "application/json"
    ), "content_type must be 'application/json'"
    request_body = json.loads(request_body_str)
    return request_body

# Predict function to make inference
def predict_fn(request, model_assets):
    print('making batch')
    request = [request]
    print('extracting features')
    # The arguments were truncated in the source; this order is assumed from the keys in model_assets
    numerical_features, categorical_features, textual_features = extract_features(
        request,
        model_assets['numerical_feature_names'],
        model_assets['categorical_feature_names'],
        model_assets['textual_feature_names']
    )
    print('transforming numerical_features')
    numerical_features = model_assets['numerical_transformer'].transform(numerical_features)
    print('transforming categorical_features')
    categorical_features = model_assets['categorical_transformer'].transform(categorical_features)
    print('transforming textual_features')
    textual_features = model_assets['textual_transformer'].transform(textual_features)
    # Concatenate Features
    print('concatenating features')
    categorical_features = categorical_features.toarray()
    textual_features = np.array(textual_features)
    textual_features = textual_features.reshape(textual_features.shape[0], -1)
    features = np.concatenate([
        numerical_features,
        categorical_features,
        textual_features
    ], axis=1)
    print('predicting using model')
    prediction = model_assets['classifier'].predict_proba(features)
    # Probability of the positive (churn) class for the single request in the batch
    probability = prediction[0][1].tolist()
    output = {
        'probability': probability
    }
    return output

# Output function to prepare the output
def output_fn(prediction, response_content_type):
    assert (
        response_content_type == "application/json"
    ), "accept must be 'application/json'"
    response_body_str = json.dumps(prediction)
    return response_body_str
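Before deploying, it can help to trace a request through the handlers locally. The sketch below copies input_fn and output_fn from the snippet above and swaps in a stand-in, `fake_predict_fn`, that returns a constant instead of running the classifier; the stand-in, the `handle` helper, and the constant are hypothetical.

```python
import json

def input_fn(request_body_str, request_content_type):
    assert request_content_type == "application/json", "content_type must be 'application/json'"
    return json.loads(request_body_str)

def fake_predict_fn(request, model_assets):
    # Stand-in for predict_fn: returns a constant instead of running the classifier.
    return {"probability": model_assets["constant_probability"]}

def output_fn(prediction, response_content_type):
    assert response_content_type == "application/json", "accept must be 'application/json'"
    return json.dumps(prediction)

def handle(body, assets, content_type="application/json", accept="application/json"):
    # Per-request order: parse the input, predict, then serialize the output.
    request = input_fn(body, content_type)
    prediction = fake_predict_fn(request, assets)
    return output_fn(prediction, accept)

body = json.dumps({
    "CustServ Calls": 10.0,
    "Account Length": 66,
    "plan": "B",
    "limit": "limited",
    "text": "I am quite happy with your service",
})
response = handle(body, {"constant_probability": 0.31})
print(response)
```

Replacing the stand-in with the real predict_fn and the assets returned by model_fn gives an end-to-end check that does not require a running endpoint.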

To deploy the model, when the training is complete, we use the .deploy() method on the estimator and define the number and type of instances we want to attach to the endpoint, and SageMaker manages the infrastructure on your behalf. When calling the endpoint from the notebook, we use a SageMaker SDK predictor. The predictor sends data to an endpoint (as part of a request), and interprets the response. See the following code:

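from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

# Deploy the predictor (the arguments were truncated in the source; these values are assumed)
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.p3.2xlarge",
)
predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()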

This deploys the model as an endpoint predictor. After deployment is complete, we can use that to make predictions on sample data. Let’s determine the probability of churn for a hypothetical customer:

data = {
    "CustServ Calls": 10.0,
    "Account Length": 66,
    "plan": "B",
    "limit": "limited",
    'text': "Well, I've been dealing with TelCom for three months now and I am quite happy with your service"}

response = predictor.predict(data=data)

print("{:.2%} probability of churn".format(response['probability']))

In this case, the probability of churn is about 31%. For the same customer, we change the transcript to “I have been using your service for 6 months and I am disappointed in your customer service.” The probability of churn increases to over 46%. This demonstrates that a change in the customer’s sentiment affects the probability of churn.
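
The comparison described above can be sketched as a small helper that scores one customer profile against several transcripts. `churn_by_transcript` and `StubPredictor` are hypothetical names introduced here; the stub exists only so the sketch runs without a live endpoint, and its keyword rule and scores are invented for illustration, not the trained model's behavior. With a deployed endpoint, you would pass the SageMaker predictor instead.

```python
import copy


def churn_by_transcript(predictor, profile, transcripts):
    """Score one customer profile against several call transcripts."""
    results = {}
    for text in transcripts:
        payload = copy.deepcopy(profile)  # leave the caller's profile untouched
        payload["text"] = text
        response = predictor.predict(data=payload)
        results[text] = response["probability"]
    return results


class StubPredictor:
    """Hypothetical stand-in for the endpoint predictor (illustration only)."""

    def predict(self, data):
        # Toy rule, NOT the trained model: a made-up score keyed on one word.
        score = 0.9 if "disappointed" in data["text"] else 0.1
        return {"probability": score}


profile = {
    "CustServ Calls": 10.0,
    "Account Length": 66,
    "plan": "B",
    "limit": "limited",
}
transcripts = [
    "Well, I've been dealing with TelCom for three months now and I am quite happy with your service",
    "I have been using your service for 6 months and I am disappointed in your customer service.",
]

scores = churn_by_transcript(StubPredictor(), profile, transcripts)
for text, probability in scores.items():
    print("{:.2%} probability of churn: {}".format(probability, text))
```

Copying the profile for each call keeps the tabular features fixed, so any difference in the score comes from the transcript alone.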

Clean up

To clean up the resources and stop incurring charges in your account, you can delete the endpoint:

predictor.delete_endpoint()

As we explained earlier, you can use additional features in training and also incorporate more feature transformers in the feature engineering pipeline, which can help improve model performance.

In addition, now that you have a working endpoint that is performing real-time inference, you can use it for your applications or website. However, your SageMaker endpoint is still not public facing, so you need to build an API Gateway to allow external traffic to your SageMaker endpoint. Amazon API Gateway is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any scale. You can use API Gateway to present an external-facing, single point of entry for SageMaker endpoints, and provide security, throttling, authentication, firewall as provided by AWS WAF, and more. With API Gateway mapping templates, you can invoke your SageMaker endpoint with a REST API request and receive an API response back without needing any intermediate AWS Lambda functions, thereby improving the performance and cost-effectiveness of your applications.

To create an API Gateway and use it to perform real-time inference with your SageMaker endpoint, you can follow the instructions outlined in Creating a machine learning-powered REST API with Amazon API Gateway mapping templates and Amazon SageMaker.
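
Once such an API is in place, a client could call it with a plain HTTPS POST. The following is a minimal sketch using only the Python standard library. The invoke URL is a placeholder, `build_request` and `predict_churn` are hypothetical helper names, and the sketch assumes the mapping template forwards the JSON body to the endpoint unchanged and returns the JSON produced by output_fn.

```python
import json
import urllib.request

# Placeholder invoke URL (an assumption); use the one API Gateway gives you.
API_URL = "https://example.execute-api.us-east-1.amazonaws.com/prod/churn"


def build_request(api_url, customer):
    """Build a JSON POST request carrying one customer record."""
    body = json.dumps(customer).encode("utf-8")
    return urllib.request.Request(
        api_url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict_churn(api_url, customer, timeout=10):
    """Send the request and decode the JSON response (needs a live API)."""
    request = build_request(api_url, customer)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


customer = {
    "CustServ Calls": 10.0,
    "Account Length": 66,
    "plan": "B",
    "limit": "limited",
    "text": "I have been using your service for 6 months.",
}

# Only build the request here; predict_churn(API_URL, customer) would send it.
request = build_request(API_URL, customer)
print(request.get_method(), request.full_url)
```

Any authentication the API requires (for example, an API key header) would be added to the request headers.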

In addition, you can use Amazon Transcribe to generate transcriptions of recorded customer-agent conversations and use them for training purposes, and also use Amazon Transcribe streaming to send the conversation audio stream and receive a stream of text in real time. You can use this text stream to add a real-time speech-to-text capability to your applications and also send that text to the endpoint and provide customer churn insights to your customer service agents in real time.


In this post, we explained an end-to-end solution for creating a customer churn prediction model based on customer profiles and customer-agent call transcriptions. The solution included training a PyTorch model with a custom script and creating an endpoint for real-time model hosting. We also explained how you can create a public-facing API Gateway that can be securely used in your mobile applications or website. In addition, we explained how you can use Amazon Transcribe for batch or real-time transcription of customer-agent conversations, which you can use for training of your model or real-time inference.

For more SageMaker examples, visit the Amazon SageMaker Examples GitHub repo. For more PyTorch BYO script examples, visit the following GitHub repository. For more SageMaker Python examples for MXNet, TensorFlow, and PyTorch, visit the Amazon SageMaker Pre-Built Framework Containers and the Python SDK GitHub repo. Additional information about SageMaker is available in the technical documentation.

About the Author

Nick Minaie is a Sr. AI/ML Specialist Solutions Architect with AWS, helping customers on their journey to well-architected machine learning solutions at scale. In his spare time, Nick enjoys family time, abstract painting, and exploring nature.



Ehsan M. Kermani is a Machine Learning Engineer in the AWS ML Automation Services group. He helps customers through their MLOps journey by providing his expertise in Software Engineering best practices to solve customers’ end-to-end Machine Learning tasks from infrastructure to deployment.


Dr. Li Zhang is a Principal Product Manager-Technical for Amazon SageMaker JumpStart and Amazon SageMaker built-in algorithms, a service that helps data scientists and machine learning practitioners get started with training and deploying their models, and uses reinforcement learning with Amazon SageMaker. His past work as a principal research staff member and master inventor at IBM Research has won the test of time paper award at IEEE INFOCOM.

Get started with the Amazon Kendra Amazon WorkDocs connector

Amazon Kendra is an intelligent search service powered by machine learning (ML). Amazon Kendra reimagines enterprise search for your websites and applications so your employees and customers can easily find the content they’re looking for, even when it’s scattered across multiple locations and content repositories within your organization.

With Amazon Kendra, you can search through troves of unstructured data and discover the right answers to your questions, when you need them. Amazon Kendra is a fully managed service, so there are no servers to provision, and no ML models to build, train, or deploy.

Amazon WorkDocs is a fully managed and secure content creation, storage, and collaboration service. With Amazon WorkDocs, you can easily create, edit, and share content. Moreover, because it’s stored centrally on AWS, you can access it from anywhere on any device.

In this post, we show how Amazon Kendra allows your users to search documents stored in Amazon WorkDocs.

Use case

For this post, we created a specific folder in Amazon WorkDocs containing a set of PDFs and Microsoft Word documents whose content we want to search. The Amazon WorkDocs connector also allows you to ingest comments for those documents.

The following screenshot shows the contents of a fictional WorkDocs folder called WorkdocsBlogpostDataset.

Create an Amazon WorkDocs connector

To create an Amazon WorkDocs connector, complete the following steps:

  1. On the Amazon Kendra console, choose Data sources.
  2. Choose Add data source.
  3. Under WorkDocs, choose Add connector.
  4. For Data source name, enter a name for your data source.
  5. Enter an optional description.
  6. Choose Next.
  7. In the Source section, choose the organization ID for your Amazon WorkDocs site.
  8. Create a new AWS Identity and Access Management (IAM) role for the data source.
  9. For Sync scope, select Crawl document comments and Use change logs.

For this post, we want Amazon Kendra to ingest the documents in the WorkdocsBlogpostDataset folder.

  1. In the Additional configuration section, enter WorkdocsBlogpostDataset as a path on the Include patterns tab.
  2. Choose Add.
  3. For Sync run schedule, choose Run on demand.
  4. Choose Next.
  5. In the WorkDocs field mapping section, use the default field mapping.
  6. Choose Next.
  7. Review the settings and choose Create.
  8. When the creation process is complete, choose Sync.

When the sync process is complete, you can see how many documents were ingested.

Now your documents are ready to be searched by Amazon Kendra.

  1. In the navigation pane, choose Search console.

You can now submit some test queries, as shown in the following screenshots.

Also, with the Amazon WorkDocs connector, you can ingest feedback (comments) on your documents. For example, the following screenshot shows that this document has feedback.

The following screenshot shows what the feedback search experience looks like.


In this post, you created a data source and ingested your Amazon WorkDocs documents into your Amazon Kendra index. As a next step, you can try some more queries and see what kind of results you obtain. You can also dive deep into Amazon Kendra with the Amazon Kendra Essentials workshop or try the multilingual chatbot experience.

About the Author

Juan Bustos is an AI Services Specialist Solutions Architect at Amazon Web Services, based in Dallas, TX. Outside of work, he loves spending time writing and playing music as well as trying random restaurants with his family.




Vijai Gandikota is a Senior Product Manager at Amazon Web Services for Amazon Kendra.

Investing in academic research to improve our privacy technology: Our approach and recent RFP winners

One of our goals over the next decade is to build stronger privacy protections for everyone who uses our apps and services. Our latest research award opportunity in privacy-enhancing technology and the recently launched request for proposals on Building Tools to Enhance Transparency in Fairness and Privacy are the next of many steps toward that goal, and a continuation of several years of investments in the privacy research space.

Our approach to academic research and investments

Through a variety of programs, partnerships, and collaborations, Facebook researchers work with the global academic community on topics that align with our mission to give people the power to build community and bring the world closer together. “We are sponsoring labs and conferences, partnering with academics on short- and long-term projects, and supporting PhD students through our Fellowship program,” says Sharon Ayalde, Research Program Manager, Facebook Academic Engagements. “We also provide research award opportunities through open requests for proposals.”

Requests for proposals (RFPs) in particular help us strengthen our ties to academia and foster community. Through RFPs, we are able to discover activities and key players in academia that are aligned with our research challenges. Research funds are generally awarded as unrestricted gifts to accredited universities to help finance winning proposals. In general, there are 15 to 20 RFP opportunities each year across a variety of research topics, such as privacy, networking, data science, probability, machine learning, and UX.

Investing in these research projects helps accelerate the field for everyone and allows us to apply the most cutting-edge technologies to our apps and services. In the privacy research space, we’ve steadily increased opportunities for academic collaboration, and research project funding continues to be available. Last year, we granted research awards in key topics such as privacy-preserving technologies and cryptography, user experiences in privacy, and privacy in AR/VR and smart device products. These opportunities alone attracted more than 300 applications, with over $2 million in total funding.

The 2020 People’s Expectations and Experiences with Digital Privacy RFP, in particular, received 147 proposals from 34 countries and 120 universities. The five winning proposals represented 14 universities, including Cornell University, Carnegie Mellon University, the Hebrew University of Jerusalem, India Institute of Technology, Brigham Young University, Northwestern University, and Hamad Bin Khalifa University.

What’s next

In 2021 and beyond, we will continue our investment in research and innovation to help us develop new ways to build products and process data with privacy in mind. We’ll also continue to work with policymakers, privacy experts, global organizations and developers on building solutions to ensure that people feel safe and comfortable using our products.

“Our world and the role of technology in our lives and society is evolving faster than ever before,” says Scott Renfro, Facebook Software Engineer. “It’s critical that we work hard to put privacy, safety, and security first and work with people at the forefront of emerging technologies and scientific understanding to find better solutions. This is why we want to collaborate with academia and support the important work they do by launching another research award opportunity.”

As part of our continued investment, we are pleased to announce the winners and finalists of the 2021 Privacy-Enhancing Technologies RFP, which sought proposals from academics conducting research in applied cryptography, data policies and compliance, differential privacy, and privacy in AI. The research award opportunity attracted 159 proposals from 102 universities. Thank you to everyone who took the time to submit a proposal, and congratulations to the winners.

Research award recipients

Principal investigators are listed first unless otherwise noted.

Bridging secure computation and differential privacy
Jonathan Katz (University of Maryland College Park)

Cryptographic enforcement of end-to-end data privacy
Anwar Hithnawi (ETH Zurich)

Implementing a flexible framework for privacy accounting
Salil Vadhan (Harvard University)

InferViz: Weighted inference and visualization of insecure code paths
Musard Balliu (KTH Royal Institute of Technology), Marco Guarnieri (IMDEA Software Institute)

Practical differential privacy: Using past and present to inform future
Aleksandra Korolova, Brendan Avent (University of Southern California)

Privacy-preserving machine learning via ADMM
Yupeng Zhang (Texas A&M University)

Private authentication with complex assertions and abuse prevention
Ian Miers (University of Maryland College Park)

Safeguarding user data against cross-library data harvesting
Luyi Xing, Xiaojing Liao (Indiana University Bloomington)

SEBRA: SEcuring BRowser Extensions by Information Flow Analysis
Andrei Sabelfeld (Chalmers University of Technology)

Towards privacy-preserving and fair ad targeting with federated learning
Golnoosh Farnadi (HEC Montreal and MILA), Martine De Cock (University of Washington Tacoma)

Finalists


A methodological approach to privacy-preserving data analysis pipelines
Patrick Thomas Eugster, Savvas Savvides (Università della Svizzera italiana)

A toolkit for locally private statistical inference
Clement Canonne, Vincent Gramoli (University of Sydney)

Advancing differential privacy accounting
Yu-Xiang Wang (University of California Santa Barbara)

An informed consent management engine to control the privacy of IoT devices
John Grundy, Mohan Chhetri, Zubir Baig, Chehara Pathmabandu (Monash University)

Beyond cookies: Private personalization for the tracker-free web
Henry Corrigan-Gibbs (Massachusetts Institute of Technology)

Challenges in E2E encryption
Yevgeniy Dodis (New York University)

Consent flows tracking for OAuth2.0 standard protocol
Alex Pentland, Thomas Hardjono (Massachusetts Institute of Technology)

Deletion compliance in data systems
Manos Athanassoulis (Boston University)

Differentially private analyses of textual data, such as Facebook posts
Gary King (Harvard University)

Differentially private collection of key-value pairs using multi-party computation
Florian Kerschbaum (University of Waterloo)

Differentially private analysis of streaming and graph data
Jerome Le Ny (Polytechnique Montreal)

Differentially private multi-task learning
Virginia Smith, Steven Wu (Carnegie Mellon University)

DragonFLy: Private, efficient, and accurate federated learning
Adam O’Neill, Amir Houmansadr (University of Massachusetts Amherst)

Efficient sparse vector aggregation for private federated learning
Giulia Fanti, Elaine Shi (Carnegie Mellon University)

End-to-end privacy compliance in distributed web services
Malte Schwarzkopf (Brown University)

Fast identity online with attributes and global revocation (sFIDO)
Lucjan Hanzlik (CISPA Helmholtz Center for Information Security)

InferViz: Weighted inference and visualization of insecure code paths
Musard Balliu (KTH Royal Institute of Technology), Marco Guarnieri (IMDEA Software Institute)

Practical private information retrieval with privacy-enhancing applications
Ling Ren (University of Illinois Urbana-Champaign)

Privacy-preserving machine learning through label differential privacy
Prateek Mittal, Amir Houmansadr (Princeton University)

Privacy in sketches for big data analytics
Pedro Reviriego-Vasallo (University Carlos III de Madrid)

Privacy of data set properties in machine learning
Olga Ohrimenko (University of Melbourne)

Searching for accurate and efficient private models
Reza Shokri (National University of Singapore)

Symmetric homomorphic encryption for fast privacy-preserving data analysis
Patrick Thomas Eugster, Savvas Savvides (Università della Svizzera italiana)

Scalable and secure protocols for data linking and analytics
Xiao Wang (Northwestern University)

The post Investing in academic research to improve our privacy technology: Our approach and recent RFP winners appeared first on Facebook Research.

Setting the Virtual Stage: ‘Deathtrap Dungeon’ Gets Interactive Thanks to NVIDIA RTX

Deathtrap Dungeon: The Golden Room is a gripping choose-your-own-adventure story, but it’s no page-turner.

Based on the best-selling book of the same name, it’s an interactive film in which viewers become the player on their quest to find The Golden Room while facing down dungeon masters and avoiding traps.

NVIDIA RTX technology powers the real-time graphics and virtual sets behind this latest adaptation, which showcases the future of interactive storytelling on a virtual production stage.

On-Set Facilities (OSF) provided the technology for the virtual production. Using its own low-latency computing platform, the GODBOX powered by NVIDIA RTX, OSF enhanced virtual production workflows and delivered real-time compositing and previsualization for the interactive experience.

Bringing Virtual Sets to Life with NVIDIA RTX

When it comes to bringing VFX on set, OSF faced a common challenge — finding computers that could be configured for their creative teams and production needs. So they created their own on-set computer platform, GODBOX Workstations and Servers. It’s a synchronized real-time virtual production platform for low-latency, frame-accurate, virtual production applications and workflows.

From LED and in-camera VFX to mixed reality and motion capture, the GODBOX provides all the tools, features and solutions needed to set up and run a virtual production from any set.

All images courtesy of On-Set Facilities.

For Deathtrap Dungeon, OSF laid the foundations during preproduction and previsualization. The team used virtual sets and real locations, and combined that with real-time visual effects to bring sets to life. Digitally creating the previsual assets allowed the team to specify the size of stages, amounts of props and how many physical sets were needed.

“The objective was to previsualize the final VFX on the set, so that the directors, actors and crew could all see the virtual world,” said Asa Bailey, director of virtual production at OSF. “The GODBOX delivers in-camera VFX and real-time compositing pipelines powered by NVIDIA RTX. The platform was specifically designed to work with all kinds of virtual productions.”

Throughout the preproduction and the film shoot, OSF used Unreal Engine to conduct virtual scouting sessions using the GODBOX cloud production VPN. Using the secure cloud platform, OSF tested the virtual sets and worked with the production to set lighting and camera movements, all before going on set.

With the RTX-powered GODBOX, OSF also delivered real-time compositing so the cast and crew could see their performance with the virtual set and characters. The team combined green screen live action with virtual sets and VFX elements in Unreal Engine. Then they fed the real-time composition through to large screens and projectors on set.

OSF’s GODBOX stays updated with the latest NVIDIA drivers, as well as recently released optimizations. This helps increase the stability of the machine, which becomes crucial when the cameras are rolling.

Learn more about On-Set Facilities, GODBOX low-latency computing and virtual production. And see other NVIDIA solutions in media and entertainment.

The post Setting the Virtual Stage: ‘Deathtrap Dungeon’ Gets Interactive Thanks to NVIDIA RTX appeared first on The Official NVIDIA Blog.

Partnering with the NSF on a research institute for AI to improve care for older adults

From the early days of the internet to the development of the Human Genome Project, U.S. government-funded R&D has yielded remarkable progress for society, and today it is an important engine for AI research. That’s why, last year, we were proud to announce our partnership with the U.S. National Science Foundation (NSF) to provide $5M to support the establishment of national research institutes working in the area of Human-AI Interaction and Collaboration (HAIC). This partnership—which is part of a more than $300M NSF investment in AI Research Institutes—will create vibrant research centers across the U.S. to advance how people and AI collaborate through speech, text, gestures, and more. It also builds on our partnership with the NSF on next generation networks, and our AI research collaborations with U.S. federal agencies on weather modeling, robust AI systems, whale population monitoring, and more. 

Today, we are delighted to share that NSF has selected the AI Institute for Collaborative Assistance and Responsive Interaction for Networked Groups (AI-CARING) led by Georgia Tech, along with Carnegie Mellon University, Oregon State University, and University of Massachusetts Lowell to receive the $20M AI Institute for HAIC grant. AI-CARING will improve collaboration and communication in caregiving environments for older adults by developing AI systems that adjust to the evolving personal needs and behaviors of those requiring care. With our growing research presence in Atlanta, we’re excited to build on our rich history of collaboration with Georgia Tech and its partners in this effort—most recently supporting some of these universities’ work to help vulnerable populations find important information on COVID-19 and monitoring and forecasting disease spread.

With a growing population of older adults in need of caregiving, AI systems can be useful in a variety of contexts, like conversational assistants, health sensing, and improving coordination across the care network. For example, AI can help existing voice assistants better understand people with speech impairments, and can be integrated in home bathrooms to make them more accessible. The AI-CARING Institute will develop assistive AI agents across these types of contexts to help those requiring caregiving to sustain their independence and improve their quality of life. Additionally, this research will be the product of interdisciplinary teams—with expertise across AI, geriatrics, behavioral sciences, and design—working to ensure that AI is deployed responsibly in this context, with human-centered principles in mind.

Congratulations to the recipient universities of the AI Institute awards and the faculty, listed below. We look forward to learning from the team’s research, sharing our resources and expertise, and building a collaboration to help older adults lead more independent lives and improve the quality of their care.

Recipient university institutions:

  • Georgia Institute of Technology
  • Carnegie Mellon University
  • Oregon State University
  • University of Massachusetts Lowell

Faculty:


  • Sonia Chernova (Georgia Tech) – PI
  • Elizabeth Mynatt (Georgia Tech) – Co-PI
  • Reid Simmons (Carnegie Mellon University) – Co-PI
  • Kagan Tumer (Oregon State University) – Co-PI
  • Holly Yanco (University of Massachusetts Lowell) – Co-PI

GFN Thursday Brings ‘Evil Genius 2: World Domination,’ ‘Escape From Naraka’ with RTX, and More This Week on GeForce NOW

This GFN Thursday shines a spotlight on the latest games joining the collection of over 1,000 titles in the GeForce NOW library from the many publishers that have opted in to stream their games on our open cloud-gaming service.

Members can look forward to 14 games — including Evil Genius 2: World Domination from Rebellion and Escape From Naraka, which features RTX for Founders and Priority members — joining the GeForce NOW library this week. Team17’s Hell Let Loose, already streaming on GeForce NOW, has left early access with the introduction of the Soviet forces on the eastern front.

Getting Rebellious

Ready to take over the world? This GFN Thursday, stream several exciting titles from Rebellion — the award-winning British independent studio. Games include Evil Genius 2: World Domination, Evil Genius, Battlezone: Combat Commander and Zombie Army 4: Dead War.

Evil Genius 2 on GeForce NOW
Become a criminal mastermind in this wicked awesome sequel.

Be the best bad guy you can be in Evil Genius 2: World Domination (Steam). Construct an evil lair, train your minions to carry out nefarious plans and fight against the Forces of Justice with an array of traps to achieve global domination in this satirical spy-fi lair-builder game.

Members can also look forward to the original Evil Genius coming to the cloud this week, as well as Battlezone: Combat Commander for space-shooting intergalactic battles. And gamers can sink their teeth into Zombie Army 4: Dead War for the fight against the undead Armageddon. With all of these new additions, more gamers now can play these awesome titles from Rebellion.

“GeForce NOW gives us a great opportunity to provide access to our catalog of games to even more gamers,” said Matt Jeffery, chief strategy officer at Rebellion. “NVIDIA makes onboarding games to the cloud a simple process, helping us bring our games to players on low-powered or incompatible devices, with the power of a real gaming rig.”

These newest additions join a host of other Rebellion titles available to stream on GeForce NOW, including Sniper Elite 3, Sniper Elite 4 and Sniper Elite V2 Remastered for players who aim to have a good time with a stealth story. Plus more undead action in the Zombie Army Trilogy and an adventure full of puzzles, traps and mummies in Strange Brigade that’s too good to keep under wraps.

Heating Things Up

The front line is calling. Hell Let Loose (Steam) from Team17 has left early access and is streaming in open access on GeForce NOW.

Charge into battle in this hardcore World War II first-person shooter with epic battles and players filling roles of infantry, tanks, artillery and the dynamically shifting front line. Battle across more than nine maps modeled on real-life locales. Fight in iconic battles from the Western Front, including Carentan, Omaha Beach, Foy and more. Members can experience the chaos of all-out war and stream Hell Let Loose with GeForce NOW this week.

Members can also check out other popular titles from Team17 in the GeForce NOW library, including multiplayer games like Overcooked! and Overcooked! 2 and Golf with Friends for gamers looking to play with friends. Or experience a variety of RPGs with the cozy and creative Hokko Life, the pirate adventure King of the Seas, and the alien world of Planet Alpha.

Up This Week

As always, GFN Thursday means more games. This week, members can journey to a nightmarish temple in Escape from Naraka and save their beloved from an evil demon of legend.

Escape from Naraka on GeForce NOW
Don’t miss out on a new adventure with ‘Escape from Naraka,’ one of 14 titles joining GeForce NOW this week.

Complete challenges, unlock the secrets of the temple and use unique abilities to confront terrifying enemies and escape the temple. Founders and Priority members can play Escape from Naraka on GeForce NOW with beautifully ray-traced reflections, ray-traced shadows and RTX Global Illumination.

Here’s the complete list of 14 titles coming to the cloud this week:

And while you decide which of these games to spend your weekend with, we have an important question for you.

who’s the most iconic villain in video game history? 🤔

🌩 NVIDIA GeForce NOW (@NVIDIAGFN) July 28, 2021

Shout it out on Twitter, and we’ll see you next week!

The post GFN Thursday Brings ‘Evil Genius 2: World Domination,’ ‘Escape From Naraka’ with RTX, and More This Week on GeForce NOW appeared first on The Official NVIDIA Blog.

Facebook Fellow Spotlight: Empowering women in rural communities through research in HCI

Each year, PhD students from around the world apply for the Facebook Fellowship, a program designed to encourage and support doctoral students engaged in innovative and relevant research in areas related to computer science and engineering at an accredited university.

As a continuation of our Fellowship spotlight series, we’re highlighting 2020 Facebook Fellow Sharifa Sultana.

Sharifa is a PhD candidate in Information Science at Cornell University. Her work focuses on human-computer interaction (HCI) and information and communication technologies for development (ICTD) from a critical computing and feminist HCI perspective.

Raised in Jessore, Bangladesh, Sharifa noticed that women were underrepresented in STEM education and other professions around the world, particularly in Bangladesh. Because of this underrepresentation, many women in rural communities have difficulties accessing, trusting, and using technology. This inspired Sharifa to work towards creating a more inclusive environment in which women would feel empowered to use technology, and where technology could, in turn, help fight the oppression of women in her home country.

“My research asks the questions, ‘Why is tech not working for rural Bangladeshi women? How can we fight against oppression using tech?’” she says. Sharifa’s approach explores how women interact with technology in rural communities in an effort to develop and implement solutions that address their critical needs.

One of these needs is combating gender harassment. “In Bangladesh, women are often harassed by colleagues, friends, family members — people who they want to trust,” she says. “Yet it is often difficult for them to seek legal help for many reasons.”

In order to empower women to counter harassment, Sharifa designed a digital tool – ‘Unmochon’ – to collect evidence of tech-based harassment through Facebook Messenger. Users can install and run it to collect image evidence of harassing messages and the harassers’ Facebook handles. This tool allows users to report the incident to the appropriate authorities and confirm the authenticity of the evidence.

Sharifa’s most recent research focuses on alternative rationalities in computing – namely, exploring how rural communities determine what information is true and how misinformation can prevent women from seeking healthcare. “The aim is to design tech that would actually help [women], that they would actually use,” Sharifa says.

Healthcare misinformation is a serious issue for rural communities in Bangladesh, especially during the COVID-19 pandemic. She hopes to develop technology that will give people access to reliable information and connect them with the healthcare they need.

Sharifa’s research has opened up a new discussion on how HCI design can be used to address online gender harassment and on how studying HCI can help bridge the gap between women and access to life-saving healthcare. Currently, Sharifa is in Bangladesh, collaborating on a local research project to determine what kind of technology and healthcare practices could benefit rural communities.

To learn more about Sharifa Sultana, visit her Fellowship profile.

The post Facebook Fellow Spotlight: Empowering women in rural communities through research in HCI appeared first on Facebook Research.

Orchestrate XGBoost ML Pipelines with Amazon Managed Workflows for Apache Airflow

The ability to scale machine learning operations (MLOps) at an enterprise is quickly becoming a competitive advantage in the modern economy. When firms started dabbling in ML, only the highest priority use cases were the focus. Businesses are now demanding more from ML practitioners: more intelligent features, delivered faster, and continually maintained over time. An effective MLOps strategy requires a unified platform that can orchestrate and automate complex data processing and ML tasks, and integrates with the latest tooling to best complete those tasks.

This post demonstrates the value of using Amazon Managed Workflows for Apache Airflow (Amazon MWAA) to orchestrate an ML pipeline using the popular XGBoost (eXtreme Gradient Boosting) algorithm. For more advanced and comprehensive MLOps capabilities, including a purpose-built model orchestration framework and a continuous integration and continuous delivery (CI/CD) service for ML, readers are encouraged to check out Amazon SageMaker Pipelines.

Why Airflow for orchestration

Customers choose Apache Airflow and specifically Amazon MWAA for several reasons, but three stand out:

  • Airflow is Python-based – Airflow, as a Python-based tool, enjoys the benefits of an imperative programming paradigm. This enables developers to programmatically define how tasks are to be done. Tools that are declarative, such as AWS Step Functions, only allow you to define what is to be done. When orchestrating ML pipelines, the ability to directly define the control flow is often required to navigate complex workflows.
  • Directed Acyclic Graph (DAG) workflow management – Airflow provides a DAG interface as a simple mechanism for defining and running complex workflows with dependencies. These DAG workflows are visualized through a GUI for operations management.
  • Extensibility – Airflow operators provide a structured way to perform common tasks using reusable modules. This capability is extensible and providers are free to develop custom Airflow operators that integrate with their tools and services. Many cloud-based services are supported. These operators provide useful abstraction, repeatability, and an API. In the context of big data and ML, these operators are especially valuable because they provide a way to orchestrate sometimes very long-running data pipelines or asynchronous ML processes such as model training.

Set up an Amazon MWAA environment

To create your Amazon MWAA environment, complete the following steps:

  1. On the Amazon MWAA console, choose Create environment.
  2. For Name, enter a unique name.
  3. For Airflow version, choose the version to use. For this post, we use Airflow v2.0.2. We also include code for Airflow v1.10.12.
  4. In the DAG code in Amazon S3 section, specify the Amazon Simple Storage Service (Amazon S3) bucket where Amazon MWAA can find the DAGs, plugins.zip file, and requirements.txt file.

Airflow configuration for XGBoost

An XGBoost model requires a specific configuration in the Managed Airflow environment. The core.enable_xcom_pickling parameter must be set to True. The reason for this is the trained XGBoost model needs to be serialized in order to save it as a file in Amazon S3. Certain Python objects (like datetime) can’t be serialized without converting the Python object hierarchy into a byte stream through a process called pickling.
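
To illustrate why pickling matters here, the following minimal sketch (not taken from the original repo) compares JSON serialization, which Airflow uses for XComs when pickling is disabled, with pickling for a value that contains a datetime. The payload dictionary is a hypothetical stand-in for an object passed between tasks.

```python
import json
import pickle
from datetime import datetime

# Hypothetical stand-in for a value passed between tasks.
payload = {"trained_at": datetime(2021, 7, 1, 12, 0, 0), "rounds": 100}

# JSON has no native encoding for datetime, so json.dumps raises TypeError.
try:
    json.dumps(payload)
    json_ok = True
except TypeError:
    json_ok = False

# Pickling converts the object hierarchy into a byte stream and back.
blob = pickle.dumps(payload)
restored = pickle.loads(blob)

print("JSON serializable:", json_ok)
print("Pickle round trip preserved the value:", restored == payload)
```

Note that unpickling runs arbitrary code from the byte stream, so pickling should only be enabled when every task that writes XComs is trusted.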

Requirements.txt file

Upload a requirements.txt file to the Amazon S3 location you specified in the Amazon MWAA setup. To support this demonstration, the requirements.txt file should have the following entries:


Orchestrate an XGBoost ML pipeline

Our ML pipeline is a simplified three-step pipeline:

  1. Data preprocessing using AWS Glue. Real pipelines could require numerous processing steps for data cleaning and feature engineering. Although Amazon SageMaker Pipelines provides similar functionality, we use AWS Glue to illustrate how different AWS services or third-party tools and services are orchestrated in a single pipeline.
  2. Train an XGBoost model using a SageMaker training job.
  3. Deploy the trained model as a real-time inference endpoint.

Solution architecture

Our ML pipeline is pictured in the following diagram. We use AWS Lambda to invoke DAGs with a Lambda function. We also use Amazon EventBridge to trigger Lambda functions. For more information, see Tutorial: Schedule AWS Lambda functions using EventBridge.
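
The post does not show the Lambda function itself. As one possible sketch, a Lambda function can trigger a DAG by requesting a CLI token from the Amazon MWAA API and posting an Airflow CLI command to the environment's web server. The environment name, the DAG ID, and the optional mwaa_client parameter below are hypothetical; the parameter exists only so the request-building step can be exercised without AWS credentials.

```python
import urllib.request


def build_trigger_request(mwaa_client, environment_name, dag_id):
    """Build (but do not send) the HTTPS request that triggers a DAG run."""
    # create_cli_token returns a short-lived token and the web server hostname.
    token = mwaa_client.create_cli_token(Name=environment_name)
    url = "https://{}/aws_mwaa/cli".format(token["WebServerHostname"])
    # Airflow v2 CLI syntax; Airflow v1.10.12 uses "trigger_dag <dag_id>".
    command = "dags trigger {}".format(dag_id)
    return urllib.request.Request(
        url,
        data=command.encode("utf-8"),
        headers={
            "Authorization": "Bearer {}".format(token["CliToken"]),
            "Content-Type": "text/plain",
        },
        method="POST",
    )


def lambda_handler(event, context, mwaa_client=None):
    """Hypothetical Lambda entry point invoked by an EventBridge rule."""
    if mwaa_client is None:
        import boto3  # imported lazily so the helper above stays testable

        mwaa_client = boto3.client("mwaa")
    # Both names below are placeholders, not values from the original post.
    request = build_trigger_request(
        mwaa_client, "my-mwaa-environment", "xgboost-pipeline"
    )
    with urllib.request.urlopen(request) as response:
        return {"status": response.status}
```

The Lambda execution role would need permission to call airflow:CreateCliToken on the target environment for this approach to work.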

Stage the AWS Glue script

In our demo, we create the AWS Glue job dynamically using a PySpark script saved in Amazon S3. Copy the glue_etl.py file provided in the source code repo to an Amazon S3 location.

Set DAG configuration values

To keep things simple, we use a config.py file to import any environment-specific configurations rather than define them in the main DAG script. You can view the config.py file in its entirety on GitHub. A best practice is to use AWS Secrets Manager to store configuration and secrets information (as of this writing, AWS Systems Manager Parameter Store isn’t a supported backend on Amazon MWAA). Detailed documentation on how to securely store secrets in AWS Secrets Manager for Amazon MWAA is available here.

Upload the updated config.py file to the DAG directory.
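
For orientation, a config.py along these lines might look like the sketch below. The attribute names are the ones referenced by the code listings in this post; every value is a hypothetical placeholder, the list is not exhaustive, and deriving the training and validation locations from DATA_S3_DEST is an assumption. Refer to the GitHub repo for the actual file.

```python
# config.py: illustrative placeholder values only; replace with your own.

REGION_NAME = "us-east-1"

# Location of the staged AWS Glue PySpark script
GLUE_JOB_SCRIPT_S3_BUCKET = "example-bucket"
GLUE_JOB_SCRIPT_S3_KEY = "scripts/glue_etl.py"

# Raw input and processed output locations passed to the Glue job
DATA_S3_SOURCE = "s3://example-bucket/raw/customer-churn.csv"
DATA_S3_DEST = "s3://example-bucket/processed/"

# SageMaker settings
SAGEMAKER_ROLE_NAME = "example-sagemaker-role"
SAGEMAKER_CONTENT_TYPE = "text/csv"

# Assumption: the Glue job writes its splits under train/ and validation/
SAGEMAKER_TRAINING_DATA_S3_SOURCE = DATA_S3_DEST + "train/"
SAGEMAKER_VALIDATION_DATA_S3_SOURCE = DATA_S3_DEST + "validation/"
```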

Stage the customer churn training data

The customer churn dataset is mentioned in the book Discovering Knowledge in Data by Daniel T. Larose. It’s attributed by the author to the University of California Irvine Repository of Machine Learning Datasets. The dataset is publicly available and provided in the GitHub repo.

Upload the customer-churn.csv file to the Amazon S3 location you specified in the config.py file.

Construct the DAG

For our demonstration, the DAG consists of four primary sections:

  • Import statements
  • DAG operator configuration
  • DAG task definitions
  • DAG task dependency definition

Import statements

Because Airflow is Python-based, the DAG file is a simple Python file and the modules for Airflow are imported just as they would be for any Python application.

Some services have native Airflow operators available that manage asynchronous API calls and polling to determine success or failure of orchestrated tasks. We recommend using native operators wherever possible. AWS services that don’t have native Airflow operators, like AWS Glue, can still be orchestrated in Airflow using AWS SDKs called from the general PythonOperator.

For nearly all AWS services, the AWS SDK for Python (Boto3) provides service-level access to the APIs. This SDK provides a high degree of control, but also a lower level of abstraction. For ML pipelines using SageMaker, you can use the SageMaker Python SDK. This is a streamlined SDK abstracted specifically for ML experimentation.

The following import statements include general Airflow modules and operators, native Airflow operators for SageMaker, and the Boto3 and SageMaker SDKs:

# Airflow Operators
import airflow
from airflow.models import DAG
from airflow.utils.dates import days_ago
from airflow.operators.python_operator import PythonOperator

# Airflow Sagemaker Operators
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator
from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

# AWS SDK for Python
import boto3

# Amazon SageMaker SDK
import sagemaker
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.estimator import Estimator
from sagemaker.session import s3_input

# Airflow SageMaker Configuration
from sagemaker.workflow.airflow import training_config
from sagemaker.workflow.airflow import model_config_from_estimator
from sagemaker.workflow.airflow import deploy_config_from_estimator

# Configuration variables
import config

Other import statements are needed to support this demonstration; refer to the GitHub repo for the full code.

DAG operator configuration

The DAG and DAG tasks are defined based on the operators invoked to run each task.

For the AWS Glue task, we invoke the PythonOperator using the SDK for Python to create a client for AWS Glue. To keep the DAG code tidy, we abstract the AWS Glue client code in a helper function called preprocess_glue. We stage the glue_etl.py (referenced in the GitHub repo) in Amazon S3 so it can be loaded when the AWS Glue job is created. See the following code:

def preprocess_glue():
  """preprocess data using glue for etl"""

  # not best practice to hard code location 
  glue_script_location = 's3://{}/{}'.format(config.GLUE_JOB_SCRIPT_S3_BUCKET, config.GLUE_JOB_SCRIPT_S3_KEY)
  glue_client = boto3.client('glue')

  # instantiate the Glue ETL job
  # NOTE: the Name and Role arguments are required by create_job but are missing
  # from this listing; GLUE_JOB_NAME and GLUE_ROLE_NAME are assumed config
  # entries, so check the GitHub repo for the values the authors use
  response = glue_client.create_job(
    Name=config.GLUE_JOB_NAME,
    Description='PySpark job to extract the data and split into training and validation data sets',
    Role=config.GLUE_ROLE_NAME,
    ExecutionProperty={
      'MaxConcurrentRuns': 2
    },
    Command={
      'Name': 'glueetl',
      'ScriptLocation': glue_script_location,
      'PythonVersion': '3'
    },
    DefaultArguments={
      '--job-language': 'python'
    }
  )

  # execute the previously instantiated Glue ETL job
  response = glue_client.start_job_run(
    JobName=response['Name'],
    Arguments={
      '--S3_SOURCE': config.DATA_S3_SOURCE,
      '--S3_DEST': config.DATA_S3_DEST,
      '--TRAIN_KEY': 'train/',
      '--VAL_KEY': 'validation/'
    }
  )

We create a helper function that returns the ARN of the SageMaker role:

def get_sagemaker_role_arn(role_name, region_name):
    iam = boto3.client("iam", region_name=region_name)
    response = iam.get_role(RoleName=role_name)
    return response["Role"]["Arn"]

The XGBoost estimator requires the SageMaker role, container image, and hyperparameters, which we collect using a hook into SageMaker:

hook = AwsBaseHook(aws_conn_id="airflow-sagemaker", client_type="sagemaker")
sess = hook.get_session(region_name=config.REGION_NAME)
sagemaker_role = get_sagemaker_role_arn(config.SAGEMAKER_ROLE_NAME, config.REGION_NAME)
container = get_image_uri(sess.region_name, "xgboost")
hyperparameters = {

With the parameters defined, we can create the estimator object:

xgb_estimator = Estimator(

This estimator object is an input parameter into the training configuration. We need to define other training parameters:

# create unique name with guid

# define S3 locations for training & validation data processed using Glue
sagemaker_training_data = s3_input(config.SAGEMAKER_TRAINING_DATA_S3_SOURCE, content_type=config.SAGEMAKER_CONTENT_TYPE)
sagemaker_validation_data = s3_input(config.SAGEMAKER_VALIDATION_DATA_S3_SOURCE, content_type=config.SAGEMAKER_CONTENT_TYPE)

sagemaker_training_inputs = {
  'train': sagemaker_training_data,
  'validation': sagemaker_validation_data
}

Let’s take a closer look at the arguments for sagemaker_training_inputs. The XGBoost algorithm supports both LIBSVM and CSV text formats for training and validation datasets, with LIBSVM as the default. This means that we must specify CSV explicitly so XGBoost interprets our data correctly. The content type is set as text/csv in our custom DAG configuration file. We use CSV because it’s a common data file format that most ML practitioners are familiar with.
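
To make the difference between the two formats concrete, here is a small illustrative sketch that renders one labeled row in each layout. The helper names and the sample values are hypothetical, and the LIBSVM indices are written 1-based here as an assumption about the convention in use.

```python
def to_csv_row(label, features):
    """Label first, then the features, comma-separated, with no header."""
    return ",".join(str(value) for value in [label] + list(features))


def to_libsvm_row(label, features):
    """Label, then index:value pairs (1-based indices assumed here)."""
    pairs = ["{}:{}".format(i, v) for i, v in enumerate(features, start=1)]
    return " ".join([str(label)] + pairs)


row_label, row_features = 1, [128, 0.5, 3]
print(to_csv_row(row_label, row_features))
print(to_libsvm_row(row_label, row_features))
```

If CSV data were read as LIBSVM, each comma-separated line would not parse as index:value pairs, which is why the content type has to be stated explicitly.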

With these parameters defined, we can create the training config object:

training_config = training_config(

For native Airflow SageMaker operators, you can construct and reference well-defined configuration objects when invoking the operators.

The next configuration definition is for the SageMaker endpoint:

# create unique name using guid

For this simple pipeline, we use the deploy_config_from_estimator API option in the SageMaker SDK to export an Airflow deploy config directly from the SageMaker XGBoost estimator (the endpoint_name parameter must be 63 characters or less):

endpoint_config = deploy_config_from_estimator(

For more information about how we set up the model training and deployment configuration, including how we used the SageMaker SDK sagemaker.workflow.airflow APIs, see the GitHub repo.
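
The unique-name step mentioned above can be sketched as follows. This is an illustrative helper rather than the code from the repo: the function name and prefix are hypothetical, and it simply appends a GUID to a prefix while staying within the 63-character limit.

```python
import uuid

MAX_ENDPOINT_NAME_LENGTH = 63


def unique_endpoint_name(prefix):
    """Return '<prefix>-<guid>' trimmed to the 63-character limit."""
    suffix = uuid.uuid4().hex  # 32 hexadecimal characters
    # Reserve room for the suffix and the joining hyphen.
    room = MAX_ENDPOINT_NAME_LENGTH - len(suffix) - 1
    trimmed = prefix[:room].rstrip("-")
    if not trimmed:
        return suffix
    return "{}-{}".format(trimmed, suffix)


name = unique_endpoint_name("xgboost-churn-endpoint")
print(name, len(name))
```

Trimming the prefix rather than the GUID keeps the random part intact, so names stay unique even when the prefix is long.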

With the operator configuration complete, we’re ready to put it all together to define our DAG.

DAG task definitions

For the XGBoost model training task, we invoke the SageMakerTrainingOperator. For the endpoint deployment task, we invoke the SageMakerEndpointOperator. It’s important to note the separation of concerns: we create a model using the SageMakerModelOperator but configure the SageMaker endpoint using the SageMakerEndpointConfigOperator. This provides added granular control over the creation and deployment of the model. See the following code:

args = {"owner": "airflow", "start_date": airflow.utils.dates.days_ago(2), 'depends_on_past': False}

with DAG(
) as dag:
    process_task = PythonOperator(

    train_task = SageMakerTrainingOperator(
      task_id = "train",
      config = training_config,
      aws_conn_id = "airflow-sagemaker",
      wait_for_completion = True,
      check_interval = 60, #check status of the job every minute
      max_ingestion_time = None, #allow training job to run as long as it needs, change for early stop
    )

    endpoint_deploy_task = SageMakerEndpointOperator(
      task_id = "endpoint-deploy",
      config = endpoint_config,
      aws_conn_id = "airflow-sagemaker",
      wait_for_completion = True,
      check_interval = 60, #check status of endpoint deployment every minute
      max_ingestion_time = None,
      operation = 'create', #change to update if you are updating rather than creating an endpoint
    )

DAG task dependency definition

After we define the tasks, we set the dependencies between them. Airflow overloads Python’s right bitshift operator (>>) to define downstream dependencies and the left bitshift operator (<<) to define upstream dependencies. In our example, we only define downstream dependencies:

# set the dependencies between tasks
process_task >> train_task >> endpoint_deploy_task
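
To show how this chaining syntax can work in plain Python, here is a toy sketch of operator overloading. ToyTask is a hypothetical class written for illustration only; it is not Airflow's implementation, but it follows the same idea of returning the right-hand operand so that chains read left to right.

```python
class ToyTask:
    """Toy stand-in for an Airflow task; not Airflow's implementation."""

    def __init__(self, task_id):
        self.task_id = task_id
        self.downstream = []
        self.upstream = []

    def __rshift__(self, other):
        # self >> other: other runs after self
        self.downstream.append(other.task_id)
        other.upstream.append(self.task_id)
        return other  # returning the right operand lets chains continue

    def __lshift__(self, other):
        # self << other: other runs before self
        other >> self
        return other


process = ToyTask("process")
train = ToyTask("train")
deploy = ToyTask("endpoint-deploy")

process >> train >> deploy
print(train.upstream, train.downstream)
```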

When the completed DAG is uploaded to the designated Amazon S3 location, Amazon MWAA automatically ingests the DAG. The graph view visually shows the task dependencies. You can trigger the DAG manually from the console during iterative testing, or as we described earlier, from an external source such as EventBridge and a Lambda function. Each task is highlighted depending on the stage of completion, as shown in the following screenshot. Dark green indicates successful completion of the task.

Test the deployed endpoint

After the endpoint-deploy task is complete, we can view the endpoint on the SageMaker console. The SageMaker endpoint is a real-time inference endpoint. SageMaker takes care of deploying, hosting, and exposing the HTTPS endpoint.

We can test the deployed endpoint with a SageMaker notebook.

Follow these steps to set up a SageMaker notebook environment:

  1. Launch a SageMaker notebook instance.
  2. On the Notebook instances page, open your notebook instance by choosing either Open JupyterLab for the JupyterLab interface or Open Jupyter for the classic Jupyter view.
  3. Choose Upload to import the test notebook available in the GitHub repo.

Prepare a test sample

We use Pandas DataFrames to create a test dataset out of the customer churn dataset that was used for training. For the test dataset, we must drop the label column, which is the first column. We also take a random sample of the dataset using the Pandas DataFrame sample method.
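
The steps described above could be sketched as follows. The tiny inline DataFrame and its column names are hypothetical stand-ins for the customer churn data; in the notebook you would load the CSV file used for training instead.

```python
import pandas as pd

# Tiny stand-in for the churn data; column names are hypothetical.
df = pd.DataFrame(
    {
        "churn": [0, 1, 0, 1, 0],
        "day_mins": [265.1, 161.6, 243.4, 299.4, 166.7],
        "custserv_calls": [1, 1, 0, 2, 3],
    }
)

# Drop the label, which is the first column, then sample a few rows.
features = df.drop(columns=df.columns[0])
sample = features.sample(n=3, random_state=42)

# Serialize without header or index, matching the text/csv content type.
payload = sample.to_csv(header=False, index=False)
print(payload)
```

Fixing random_state makes the sample reproducible, which is convenient when comparing predictions across runs.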

Request inferences

Now that we have our sampled test data, we use the Boto3 library to create a SageMaker runtime client. We use the client when we invoke our endpoint, pass it test data, and receive an inference value.
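
A minimal sketch of this request step is shown below, assuming a CSV payload such as the one prepared in the previous step. The helper name is hypothetical, and the client is passed in as a parameter so that the function can be exercised without AWS credentials; in the notebook it would be created with boto3.client("sagemaker-runtime").

```python
import re


def request_inferences(runtime_client, endpoint_name, csv_payload):
    """Send CSV rows to a SageMaker endpoint and return predictions as floats."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",
        Body=csv_payload,
    )
    text = response["Body"].read().decode("utf-8")
    # Assumption: predictions arrive separated by commas or newlines,
    # so split on either.
    return [float(token) for token in re.split(r"[,\s]+", text.strip()) if token]
```

Each returned value is the model's score for the corresponding input row, so the list has one entry per row in the payload.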


Conclusion

You can use Amazon MWAA to orchestrate and automate complex ML pipelines from the data processing stage through model training and endpoint deployment. You can set special configuration options in the Amazon MWAA environment to support popular ML frameworks like XGBoost.

In this post, we demonstrated how to dynamically create and run an AWS Glue job to preprocess training and validation data. We showed how to construct the DAG to support this ML pipeline, including the import statements, the DAG operator configuration, the DAG task definitions, and the DAG dependency definition. We demonstrated the difference between using native Airflow operators vs. invoking AWS SDK API calls from a generic PythonOperator.

Amazon MWAA is a highly versatile orchestration tool that enterprises can use to operationalize and scale their ML capabilities.

About the authors

Justin Leto is a Sr. Solutions Architect at Amazon Web Services with specialization in big data analytics and machine learning. His passion is helping customers achieve better cloud adoption. In his spare time, he enjoys offshore sailing and playing jazz piano. He lives in Manhattan with his wife Veera.



David Ehrlich is a Machine Learning Specialist at Amazon Web Services. He is passionate about helping customers unlock the true potential of their data. In his spare time, he enjoys exploring the different neighborhoods in New York City, going to comedy clubs, and traveling.




Shreyas Subramanian is an AI/ML Specialist Solutions Architect who helps customers use machine learning to solve their business challenges with AWS services.
