Towards Behavior-Driven AI Development

Figure 1: Behavior-driven AI development centers model iteration on evaluating and improving specific real-world use cases.

It has never been easier to prototype AI-driven systems. With a bit of programming knowledge and a couple of hours, you can spin up a chatbot for your notes, a text-based image editor, or a tool for summarizing customer feedback. But play around with your prototype for a bit, and you might find that it doesn’t work as well as you first expected. Your system might make up facts or respond with racist suggestions. How would you evaluate your model and predict its performance in deployment?

The canonical process for benchmarking AI systems revolves around model-centric metrics. Calculate a metric (F1-score, precision, etc.), and if it increases, you are going in the right direction. But these metrics are oversimplified objectives that sand away the complexity of model behavior and cannot fully represent a model’s performance. A metric may tell you how well your model can predict the next word in a sentence, but it won’t tell you how factually accurate, logical, or fair your model is across diverse, real-world use cases. Generative AI systems such as ChatGPT or Stable Diffusion make evaluation even more challenging since there are no well-defined metrics that can summarize their performance.

When creating deployed AI products, practitioners instead focus on the specific use cases their customers have and whether or not their models are fulfilling them. In interviews with 18 AI practitioners, we found that they constantly collect user feedback and develop “golden test sets” of behaviors that they expect deployed models to have. We term this behavior-driven AI development, a development process focused on evaluating and updating models to improve performance on real-world use cases. While chatbot A might sound more human-like, a practitioner will deploy chatbot B if it produces concise and accurate answers that customers prefer.

The landscape of AI evaluation tools primarily revolves around model-centric metrics that do not capture important behaviors like these chatbot characteristics. While there are specific tools for behavior-driven development, such as fairness toolkits and robustness analysis libraries, practitioners end up cobbling together disparate tools into ad-hoc scripts or computational notebooks that are hard to maintain and reproduce.

I believe that there is a set of abstractions that can unify AI evaluation in line with how models are used in practice. This philosophy revolves around model behaviors: metrics summarizing patterns of output on subgroups of instances. This simple concept can encode any model evaluation or analysis, from fairness audits to language model hallucinations. We show what this can look like with Zeno, an interactive platform we built for behavior-driven development that supports interactive data exploration, slicing, and reporting. By investigating their own models using Zeno, practitioners have been able to pinpoint significant and actionable issues such as biases and systematic failures.

What is model behavior?

The dictionary describes behavior as anything that an organism does involving action and response to stimulation. In the case of AI systems, model behavior is a specific pattern of output for a semantically meaningful subgroup of input data (stimulus). By semantically meaningful, I mean subgroups that can be described with human-interpretable concepts, such as “audio with noise in the background” or “people who identify as women.” Similarly, a pattern of output could be “high audio transcription error” or “low loan approval rate.” 

Behaviors can be quantified as metrics on subgroups of data, often using the same metrics as are used for model-centric evaluation. But unlike summary metrics across an entire dataset, metrics in behavior-centric development quantify specific patterns of behavior, like how often an image generation model produces unintelligible text. Tests of model behaviors are like exams for specific subjects, while summary metrics resemble IQ tests.

Figure 2. How model behaviors are defined from a dataset. Behaviors are subgroups of data (typically defined by combinations of metadata) quantified by a specific metric. For the example behavior of “blurry text” from a text-to-image model, a metadata column for “images with text” could be used to create a subgroup on which a metric measuring the clarity of text can be calculated.
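
The definition in Figure 2 fits in a few lines of code. The sketch below is a minimal illustration under our own assumptions, not Zeno's implementation: the `Behavior` class, the `brightness` metadata column, and the four toy rows are all hypothetical. A behavior is simply a named pair of a subgroup predicate and a metric that is evaluated only on the rows the predicate selects.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

Instance = Dict[str, object]  # one row: metadata, ground-truth label, and model output

@dataclass
class Behavior:
    """A named behavior: a subgroup predicate plus a metric over that subgroup."""
    name: str
    subgroup: Callable[[Instance], bool]
    metric: Callable[[List[Instance]], float]

    def evaluate(self, data: List[Instance]) -> float:
        rows = [row for row in data if self.subgroup(row)]
        if not rows:
            raise ValueError(f"subgroup for {self.name!r} is empty")
        return self.metric(rows)

def accuracy(rows: List[Instance]) -> float:
    return sum(r["output"] == r["label"] for r in rows) / len(rows)

# Hypothetical toy data with a "brightness" metadata column.
data = [
    {"brightness": 0.2, "label": "bird", "output": "airplane"},
    {"brightness": 0.3, "label": "cat", "output": "cat"},
    {"brightness": 0.8, "label": "dog", "output": "dog"},
    {"brightness": 0.9, "label": "ship", "output": "ship"},
]

dark = Behavior("accuracy on dark images", lambda r: r["brightness"] < 0.5, accuracy)
bright = Behavior("accuracy on bright images", lambda r: r["brightness"] >= 0.5, accuracy)
print(dark.evaluate(data), bright.evaluate(data))  # 0.5 1.0
```

Comparing the two subgroup numbers, rather than the overall accuracy of 0.75, is what exposes the gap between dark and bright images.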

Model behaviors are a relatively simple concept, but encoding behaviors can be challenging in practice. Practitioners may not have enough data to validate or fix important model behaviors and have to collect or generate more data. If they have extensive data, they need ways to subdivide it into meaningful groups of instances – how do I find all images that have text? Lastly, for each subgroup, practitioners have to derive the appropriate metrics to quantify the prevalence of behavior – how do I detect blurry text? Succinctly, behavior-driven development requires sufficient data that is representative of expected behaviors and metadata for defining and quantifying the behaviors.

A platform for behavior-driven AI development

The beauty of a behavior-based framing of AI development is that it remains data- and model-agnostic. While the specific behaviors for each ML task will be vastly different, subgroups of data and metrics are universal concepts.

To test this theory, we built a platform for behavior-driven AI development called Zeno. Zeno is a platform that empowers users to explore data and model outputs, interactively create subgroups of data, and calculate and quantify model behaviors. Zeno consists of a Python API for scaffolding the data needed for analysis and a user interface for interactively creating subgroups and evaluating behaviors.

Figure 3. The Zeno interface shown for the Imagenette dataset and image classification. The right side has the instance view showing the input images and model outputs. The left side shows distributions for the dataset’s metadata, which has been interactively filtered to show images of English Springer Spaniels.

The Python API is a set of decorator functions (wrappers around user-defined functions) that can be used to plug in ML models and derive metadata features and metrics from input data. Since the decorators are generic wrappers, Zeno supports any Python-based model, processing function, or metric. Zeno preprocesses the input data with these functions and passes the results to the UI for analysis.
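
The decorator pattern described above can be sketched generically. The registries and the `register`, `preprocess`, `n_words`, and `shouting_model` names below are hypothetical and are not Zeno's actual API; the sketch only shows why wrapping arbitrary user functions lets one tool support any Python model, feature extractor, or metric.

```python
from typing import Callable, Dict, List

# Hypothetical registries for illustration only; these are not Zeno's actual decorators.
MODELS: Dict[str, Callable] = {}
FEATURES: Dict[str, Callable] = {}

def register(registry: Dict[str, Callable]) -> Callable:
    """Return a decorator that records a user-defined function under its name."""
    def decorator(fn: Callable) -> Callable:
        registry[fn.__name__] = fn
        return fn  # the function itself is returned unchanged
    return decorator

@register(FEATURES)
def n_words(text: str) -> int:
    # Derived metadata: number of whitespace-separated tokens.
    return len(text.split())

@register(MODELS)
def shouting_model(text: str) -> str:
    # Stand-in "model" that upper-cases its input.
    return text.upper()

def preprocess(instances: List[str]) -> Dict[str, list]:
    """Apply every registered feature and model once, producing columns for analysis."""
    table: Dict[str, list] = {"instance": list(instances)}
    for name, fn in FEATURES.items():
        table[name] = [fn(x) for x in instances]
    for name, fn in MODELS.items():
        table[f"output:{name}"] = [fn(x) for x in instances]
    return table

table = preprocess(["hello world", "behavior driven development"])
print(table["n_words"])  # [2, 3]
```

Because the decorator returns the original function unchanged, user code can still call `n_words` directly; the registry only records it so that all features and models can be run over the dataset in one preprocessing pass.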

Zeno’s UI is the primary interface for behavior-driven evaluation. It allows users to interactively explore and filter their data, create slices, calculate metrics, and create exportable visualizations. On the right side of the UI is Zeno’s instance view, where users can explore the raw data on which the model is being evaluated. In addition to the standard list view, users can also see the data in a table or a 2D scatterplot representation. The left side of the interface holds the metadata panel. All the metadata columns that either came with the dataset or were generated with the Python API have their distributions displayed in the panel. Users can interactively filter the distributions to update the instance view and create named subgroups.

The UI also has a report page for creating interactive summary visualizations of behaviors. For example, a user could create a bar chart comparing the performance of three models across ten different slices. Or they could create a line chart showing how a model performs on data slices from each day of data. These visualizations can be exported or shared directly with other stakeholders.
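
The data behind such a chart is just a metric computed for every (model, slice) pair. A minimal sketch, assuming each row stores a ground-truth `label`, a metadata column such as `day`, and one `output:<model>` column per model (a hypothetical naming convention, with made-up rows):

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def slice_accuracy_table(rows: List[dict], models: List[str], slice_key: str) -> Dict[Tuple[str, str], float]:
    """Accuracy of each model on each value of a metadata column."""
    correct: Dict[Tuple[str, str], int] = defaultdict(int)
    total: Dict[Tuple[str, str], int] = defaultdict(int)
    for row in rows:
        for model in models:
            key = (model, row[slice_key])
            total[key] += 1
            correct[key] += int(row[f"output:{model}"] == row["label"])
    return {key: correct[key] / total[key] for key in total}

rows = [
    {"day": "mon", "label": 1, "output:a": 1, "output:b": 0},
    {"day": "mon", "label": 0, "output:a": 0, "output:b": 0},
    {"day": "tue", "label": 1, "output:a": 0, "output:b": 1},
    {"day": "tue", "label": 1, "output:a": 1, "output:b": 1},
]
for (model, day), acc in sorted(slice_accuracy_table(rows, ["a", "b"], "day").items()):
    print(model, day, acc)
```

Each entry becomes one bar in a grouped bar chart, or one point on a per-day line chart.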

Figure 4: With Zeno, users can interactively filter their data to create slices and calculate subgroup metrics. They can also use the 2D projection to find new areas of data where their model is underperforming. In this example, a user is exploring a CIFAR-10 classification model. They first filter the dataset to compare low- versus high-brightness images, finding a significant difference in accuracy between the two groups. They then find a group of instances with high error in the projection view, which is mostly made up of birds in the sky being misclassified as airplanes.

Case Studies

We have worked with various ML practitioners to apply Zeno to the models and tasks on which they work. Using Zeno, practitioners found significant model issues and areas for improvement, including gender biases and regional model disparities.

Audio transcription. I ran this first case study myself after OpenAI released Whisper, a new speech-to-text model with state-of-the-art performance. I was curious how it compared to some existing off-the-shelf transcription models. Instead of looking only at aggregate metrics, I ran the models on the Speech Accent Archive dataset, in which speakers from around the world read the same English passage. By filtering the dataset’s extensive metadata, I found that the models perform worse for speakers who learned English later in life and for speakers from countries where English is not the native language.

Figure 5. (left) The average word error rate (WER) for both models across the ages at which participants started learning English. (right) The average WER of the Silero and Whisper transcription models across speakers from different continents.
Charts exported directly from the Zeno Report UI. 
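
Figure 5 is reported in word error rate (WER). As a reference point, here is a minimal sketch of the standard definition: the word-level edit distance between a reference transcript and a model's transcript, divided by the number of reference words. The lowercasing and whitespace tokenization are simplifying assumptions (real pipelines usually also normalize punctuation), and the example transcripts are made up.

```python
from typing import List

def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (substitutions + deletions + insertions) / number of reference words,
    using a word-level Levenshtein (edit) distance."""
    ref: List[str] = reference.lower().split()
    hyp: List[str] = hypothesis.lower().split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] = edit distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(prev[j] + 1,         # delete ref[i-1]
                          curr[j - 1] + 1,     # insert hyp[j-1]
                          prev[j - 1] + cost)  # substitute or match
        prev = curr
    return prev[-1] / len(ref)

reference = "please call stella ask her to bring these things"
hypothesis = "please call stella ask her to bring this thing"
print(round(word_error_rate(reference, hypothesis), 3))  # 0.222
```

Two substitutions over nine reference words give a WER of about 0.222. Note that WER can exceed 1 when a model inserts many extra words. Averaging this per-speaker score within each metadata slice (for example, by continent) yields charts like Figure 5.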

Cancer classification. In another case study, we worked with a researcher who wanted to improve a breast cancer classifier for mammogram images. Since the data was anonymized and lacked meaningful metadata, the practitioner wrote dozens of functions using a Python library to extract metadata features. By exploring the distributions, they found that images with higher “entropy,” which correlates with denser breast tissue, had a significantly higher error rate than images with lower entropy (less dense tissue). This finding matches performance differences among human radiologists, who also perform worse on images of dense breast tissue because it makes lesions harder to detect.

Slice          Definition                                        Instances   AUC
Low density    entropy < 2.75 and gray level variance < 2.5      4937        0.86
High density   entropy > 2.75 and gray level variance > 2.5      656         0.76
Figure 6. The breast cancer classification model performed significantly worse on high-density images (described by high entropy and gray level variance metadata) than on low-density images (left: low density; right: high density).
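
The case study does not say exactly how “entropy” and “gray level variance” were computed, so the sketch below makes an assumption: entropy is the Shannon entropy (in bits) of an image's gray-level histogram, and variance is the population variance of its pixel values. It also shows how a per-slice AUC like the one in Figure 6 can be computed. The function names are hypothetical, and the 2.75 and 2.5 thresholds are only meaningful on the case study's own feature scale.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence

def histogram_entropy(pixels: Sequence[int]) -> float:
    """Shannon entropy (bits) of the gray-level histogram of a flattened image."""
    counts = Counter(pixels)
    n = len(pixels)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())

def gray_level_variance(pixels: Sequence[int]) -> float:
    """Population variance of pixel intensities."""
    mean = sum(pixels) / len(pixels)
    return sum((p - mean) ** 2 for p in pixels) / len(pixels)

def density_slice(meta: Dict[str, float]) -> str:
    """Assign a slice using Figure 6's thresholds; anything else falls outside both slices."""
    if meta["entropy"] > 2.75 and meta["gray_level_variance"] > 2.5:
        return "high density"
    if meta["entropy"] < 2.75 and meta["gray_level_variance"] < 2.5:
        return "low density"
    return "other"

def auc(scores: List[float], labels: List[int]) -> float:
    """ROC AUC: probability that a random positive outscores a random negative
    (ties count 1/2). Requires at least one positive and one negative."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

flat_image = [0, 0, 1, 1, 2, 2, 3, 3]  # toy 8-pixel "image"
meta = {"entropy": histogram_entropy(flat_image), "gray_level_variance": gray_level_variance(flat_image)}
print(meta, density_slice(meta))
```

Computing `auc` separately on the rows of each slice, rather than once over the whole test set, is what produces a per-slice comparison like the 0.86 versus 0.76 reported above.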

Image generation. Models with complex outputs, including text-to-image generation models such as DALL·E and Stable Diffusion, often do not have clearly defined metrics. We can instead look at metrics that measure specific behaviors. In this example, a practitioner we worked with was exploring the DiffusionDB dataset, which has over two million prompt-image pairs from the Stable Diffusion model. The dataset also has metadata for how NSFW or inappropriate the prompts and images are. This data was used to derive an “average NSFW” metric, which can surface potential biases in the model. For example, the practitioner compared images generated from prompts containing the word “boy” versus “girl” and found that prompts with “girl” produced images with a significantly higher NSFW level than prompts with “boy”, suggesting biases in the types of images the model creates.

Figure 7. Given similar or less inappropriate prompts, the images generated with Stable Diffusion are much more inappropriate (NSFW) for prompts with “girl” or “woman” than for prompts with “boy” or “man”.
Charts exported directly from the Zeno Report UI. 
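
A slice like “prompts containing girl” and the “average NSFW” metric can be sketched as follows. The rows and their NSFW values are made up, and the `prompt_nsfw` and `image_nsfw` column names are assumptions for this sketch. Both averages are computed so that a gap in image NSFW can be checked against a comparable or smaller gap in prompt NSFW, as in Figure 7.

```python
import re
from statistics import mean
from typing import Dict, List

def keyword_slice(rows: List[Dict], word: str) -> List[Dict]:
    """Rows whose prompt contains `word` as a whole word, case-insensitively."""
    pattern = re.compile(rf"\b{re.escape(word)}\b", re.IGNORECASE)
    return [row for row in rows if pattern.search(row["prompt"])]

def average_nsfw(rows: List[Dict], column: str) -> float:
    """Mean of an NSFW score column over a slice."""
    return mean(row[column] for row in rows)

# Made-up rows; column names are assumptions for this sketch.
rows = [
    {"prompt": "portrait of a girl, oil painting", "prompt_nsfw": 0.05, "image_nsfw": 0.6},
    {"prompt": "Girl reading in a library", "prompt_nsfw": 0.02, "image_nsfw": 0.4},
    {"prompt": "portrait of a boy, oil painting", "prompt_nsfw": 0.05, "image_nsfw": 0.1},
    {"prompt": "my girlfriend at the beach", "prompt_nsfw": 0.10, "image_nsfw": 0.9},
]
for word in ("girl", "boy"):
    subset = keyword_slice(rows, word)
    print(word, len(subset), average_nsfw(subset, "prompt_nsfw"), average_nsfw(subset, "image_nsfw"))
```

Whole-word matching keeps “girlfriend” out of the “girl” slice; whether that is the right choice depends on the behavior being tested.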

Discussion and Opportunities

Model iteration is still primarily a reactive process: behaviors are found and defined after a model has been deployed and customer complaints start rolling in. There remains significant room for improving this process, from making it easier to ideate model behaviors to tracking model changes over time.

Discovering behaviors. While practitioners often discover the behaviors a model should have only after they have a model to test, methods for defining expected model behaviors before deployment can prevent serious real-world model issues. For example, crowdsourcing techniques for eliciting potential edge cases could preemptively catch model errors. Algorithmic methods that find clusters of data with high error have also shown promise for surfacing problematic behaviors.
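
As a deliberately simple stand-in for such methods (not any specific published algorithm), one can cluster instances in an embedding or 2D projection, like the one in Figure 4, and rank clusters by error rate. The sketch below uses plain k-means with a deterministic farthest-point initialization; the function names and the toy data (two well-separated blobs, one of which is always misclassified) are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float]

def kmeans(points: Sequence[Point], k: int, iters: int = 20) -> List[int]:
    """Lloyd's k-means; returns a cluster index for each point."""
    centers: List[Point] = [points[0]]
    while len(centers) < k:
        # Next initial center: the point farthest from its nearest existing center.
        centers.append(max(points, key=lambda p: min(math.dist(p, c) for c in centers)))
    assign: List[int] = []
    for _ in range(iters):
        # Assignment step: nearest center for every point.
        assign = [min(range(k), key=lambda j: math.dist(p, centers[j])) for p in points]
        # Update step: move each center to the mean of its members.
        for j in range(k):
            members = [p for p, a in zip(points, assign) if a == j]
            if members:
                centers[j] = (sum(x for x, _ in members) / len(members),
                              sum(y for _, y in members) / len(members))
    return assign

def rank_clusters_by_error(assign: List[int], errors: List[bool]) -> List[Tuple[int, float, int]]:
    """Return (cluster id, error rate, size) tuples, highest error rate first."""
    stats: Dict[int, List[int]] = {}
    for cluster, wrong in zip(assign, errors):
        entry = stats.setdefault(cluster, [0, 0])  # [size, number of errors]
        entry[0] += 1
        entry[1] += int(wrong)
    ranked = [(c, n_err / n, n) for c, (n, n_err) in stats.items()]
    return sorted(ranked, key=lambda t: t[1], reverse=True)

points = [(i * 0.1, (i % 5) * 0.1) for i in range(20)] + \
         [(10 + i * 0.1, 10 + (i % 5) * 0.1) for i in range(20)]
errors = [False] * 20 + [True] * 20
for cluster, rate, size in rank_clusters_by_error(kmeans(points, k=2), errors):
    print(f"cluster {cluster}: error rate {rate:.2f} over {size} instances")
```

The top-ranked cluster is a candidate behavior; a person still has to inspect it and name it (for example, “birds in the sky”) before it becomes a reusable slice.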

Data discovery and generation. Having high-quality, representative data remains a persistent obstacle for behavioral evaluation. In some domains with ample data, such as natural images, methods like Stable Diffusion have shown promise for generating new data for evaluation or training. In less data-rich domains, techniques for searching through large unlabeled datasets, such as text-based image search, can surface valuable data for evaluation and retraining. It is also challenging to derive metadata from instances for creating subgroups and calculating metrics. While it can be easy to generate metadata for simple concepts like “image brightness,” many behaviors are defined by complex metadata such as “images with a person wearing clear glasses” that cannot be encoded by a simple function. Foundation models have shown some promise in using text-based descriptions to generate complex metadata and metrics.

Model comparison. Models are almost never one-off jobs and can be updated daily or weekly. While it is easy to compare aggregate metrics, it can be challenging to compare model performance in behavior-driven development. To pick between models, users may have to compare dozens of behaviors and qualitative insights. Improved visual encodings or intelligent recommendations of model differences could help users make informed decisions and deploy the right models.

Fixing behaviors. Discovering and encoding behaviors is one thing, but fixing behaviors is another massive challenge. A common approach to fixing issues is to gather more data and retrain the model, but this process can lead to catastrophic forgetting and regressions. There are recent techniques that align well with behavior-driven development, such as slice-based learning, which can selectively fix model behaviors without new data.


There is significant excitement for this new era of AI systems. But along with their growing capability, the complexity of their behavior is also increasing. We need powerful tools to empower behavior-driven development and ensure we build intelligent systems that align with human values. Zeno provides a general-purpose platform that empowers users to do this deep evaluation across the diverse tasks of modern AI. Learn more about Zeno at, read the full paper, or reach out if you would like to use Zeno for your models!


I’d like to thank Will Epperson, Jason I. Hong, Yi-Cheng Huang, Misha Khodak, Adam Perer, Venkat Sivaraman, Ameet Talwalkar, and Kristen Vossler for their thoughtful feedback and advice.

RLPrompt: Optimizing Discrete Text Prompts with Reinforcement Learning

Figure 1: Overview of RLPrompt for discrete prompt optimization. All language models (LMs) are frozen. We build our policy network by training a task-specific multi-layer perceptron (MLP) network inserted into a frozen pre-trained LM. The figure above illustrates generation of a prompt (left), example usages in a masked LM for classification (top right) and a left-to-right LM for generation (bottom right), and update of the MLP using RL reward signals (red arrows).

TL;DR: Prompting enables large language models (LLMs) to perform various NLP tasks without changing the model. Discrete prompts have many desirable properties, but are difficult to optimize. We propose an efficient approach using reinforcement learning, which shows superior performance and facilitates rich interpretations and analyses. You can easily adapt it for your own tasks using our code base here.

Prompting has emerged as a promising approach to solving a wide range of NLP problems using large pre-trained language models (LMs), including left-to-right models such as GPTs (Radford et al., 2019; Brown et al., 2020) and masked LMs such as BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019), etc.

Compared to conventional fine-tuning that expensively updates the massive LM parameters for each downstream task, prompting concatenates the inputs with an additional piece of text that steers the LM to produce the desired outputs. A key question with prompting is how to find the optimal prompts to improve the LM’s performance on various tasks, often with only a few training examples.

Most existing work resorts to tuning soft prompts (e.g., embeddings), which fall short in interpretability, reusability across LMs, and applicability when gradients are not accessible. Discrete prompts, on the other hand, are difficult to optimize and are often created by “enumeration (e.g., paraphrasing)-then-selection” heuristics that do not explore the prompt space systematically.

Instead, we propose RLPrompt, an efficient discrete prompt optimization approach with reinforcement learning (RL). RLPrompt is flexibly applicable to different types of LMs (e.g., BERT and GPTs) for both classification and generation tasks. Experiments on few-shot classification and unsupervised text style transfer show superior performance over a wide range of existing fine-tuning or prompting methods.

Interestingly, the resulting optimized prompts are often ungrammatical gibberish. Surprisingly, these gibberish prompts transfer between different LMs while retaining significant performance, indicating that LMs may have grasped shared structures for prompting but do not follow human language patterns.

Discrete Prompt Optimization with RL

This paper presents RLPrompt, a new discrete prompt optimization approach based on reinforcement learning (RL). This approach brings together a wide range of desirable properties for efficient use on diverse tasks and LMs (see the table below). 

Table 1: RLPrompt unites the desirable properties of a wide range of previous prompt optimization approaches

Crucially, rather than directly editing the discrete tokens, which has been difficult and inefficient, RLPrompt trains a policy network that generates the desired prompts. Discrete prompt optimization thus amounts to learning a small number of policy parameters which we set as an MLP layer inserted into a frozen compact model such as distilGPT-2 (HuggingFace, 2019). We describe the specific formulations in Section §2.1-2.3 of our paper.

This formulation also allows us to employ off-the-shelf RL algorithms (e.g., Guo et al., 2021) that learn the policy with arbitrary reward functions—defined either with available data (e.g., in few-shot classification) or other weak signals when no supervised data is accessible (e.g., in controllable text generation).

Reward Stabilization 

On the other hand, RL for prompt optimization poses new challenges to learning efficiency: the large black-box LM presents a highly complex environment that, given the prompt (i.e., actions), goes through a long series of complex transitions (e.g., reading the input and inferring the output) before computing the rewards. This makes the reward signals extremely unstable and hard to learn from. 

To overcome this difficulty, we propose two simple yet surprisingly effective ways to stabilize the rewards and improve the optimization efficiency.

  1. Normalizing the training signal by computing the z-score of rewards for the same input.
  2. Designing piecewise reward functions that provide a sparse, qualitative bonus to desirable behaviors (e.g., reaching a certain accuracy on a certain class).
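
Both stabilization ideas are easy to state in code. The sketch below is our own minimal illustration, not the RLPrompt implementation (see Section §2.4 of the paper for the exact formulations); the function names, the additive bonus form, and the toy rewards are hypothetical. It shows why per-input z-scoring helps: an easy input and a hard input produce very different raw reward scales, but after normalization the policy receives the same relative signal, namely which sampled prompt was better for that input.

```python
from statistics import mean, pstdev
from typing import Dict, List

def z_score_rewards(rewards_by_input: Dict[str, List[float]], eps: float = 1e-8) -> Dict[str, List[float]]:
    """Standardize the rewards of all prompts sampled for the same input
    to zero mean and (approximately) unit standard deviation."""
    normalized = {}
    for x, rewards in rewards_by_input.items():
        mu, sigma = mean(rewards), pstdev(rewards)
        normalized[x] = [(r - mu) / (sigma + eps) for r in rewards]
    return normalized

def piecewise_reward(signal: float, achieved: bool, bonus: float = 1.0) -> float:
    """A continuous signal plus a discrete bonus paid only when a desirable
    condition holds (e.g., the prediction is correct). The additive form and
    the bonus size are illustrative assumptions."""
    return signal + (bonus if achieved else 0.0)

# Raw rewards for three sampled prompts on an easy input and on a hard input.
raw = {"easy input": [90.0, 95.0, 100.0], "hard input": [10.0, 15.0, 20.0]}
print(z_score_rewards(raw))
```

Both inputs map to roughly [-1.22, 0.0, 1.22], so neither dominates the policy update just because its raw rewards are larger.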

We describe more details in Section §2.4 of our paper.


We evaluate our approach on both classification (in the few-shot setting) and generation (unsupervised text style transfer), and perform rich analyses for new insights on LM prompting. We describe implementation details such as reward function design in Section §3 of our paper, and publish the code at our GitHub codebase.

Few-Shot Text Classification

For few-shot classification, we follow previous work and experiment on popular sentiment and topic classification tasks, using 16 examples per class for both training and validation (Perez et al., 2021). Results using RoBERTa-large (left table below) show that our approach improves over a wide range of fine-tuning and prompting methods and is as efficient to optimize as similar methods that tune soft prompts (e.g., right figure below). We report detailed dataset-level results in Section §3.1 of our paper.

Table 1: Average accuracy for few-shot text classification across all tested datasets. All methods use RoBERTa-large for fine-tuning or prompting.
Figure 2: Comparison of our method (orange) and BlackBox (BB) Tuning (Sun et al., 2022) (blue) in terms of training efficiency. The solid curves are the mean and the shaded regions are the max. and min. test accuracies over 5 trials.

Unsupervised Text Style Transfer

For text style transfer, we evaluate on the popular Yelp sentiment transfer dataset (Shen et al., 2017) using standard automatic metrics for content preservation, style accuracy, and fluency, and report their sentence-level joint product \(J(\cdot)\) below. Our full paper also includes few-shot experiments on the Shakespeare (Xu et al., 2012) dataset and human evaluations.
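
A minimal sketch of how a sentence-level joint score like \(J(\cdot)\) is aggregated, assuming each metric already yields a per-sentence score in [0, 1] (how those scores are produced, e.g. which classifiers are used, is outside this sketch): multiply the three scores for each sentence, then average over sentences. The function name and toy numbers are made up. Taking the product per sentence means an output that is fluent but fails to preserve content cannot make up for it with a good score on another metric.

```python
from statistics import mean
from typing import Sequence

def joint_score(content: Sequence[float], style: Sequence[float], fluency: Sequence[float]) -> float:
    """Average over sentences of the product of per-sentence scores, each assumed in [0, 1]."""
    if not (len(content) == len(style) == len(fluency)) or not content:
        raise ValueError("need one non-empty score list per metric, all the same length")
    return mean(c * s * f for c, s, f in zip(content, style, fluency))

content, style, fluency = [0.8, 0.5], [1.0, 1.0], [1.0, 0.0]
print(joint_score(content, style, fluency))          # sentence-level joint score
print(mean(content) * mean(style) * mean(fluency))   # product of corpus-level averages
```

Here the second sentence's zero fluency zeroes its joint score, so the sentence-level score is 0.4, whereas multiplying the corpus-level averages gives about 0.325; the two aggregations generally differ.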

Results using GPT-2 (left table below) show that our method outperforms or competes with various fine-tuning and prompting baselines, including DiRR (Liu et al., 2021c), which expensively fine-tunes all parameters of a GPT-2 model. An ablation study (right figure below) shows that our proposed reward normalization technique is crucial to optimization success. We describe the full evaluation results in Section §3.2 of our paper.

Table 2: Automatic evaluation of our method vs. baselines on the Yelp (Shen et al., 2017) sentiment transfer dataset. \(J(\cdot)\) is our main metric, which measures the average joint sentence-level scores of content preservation, style accuracy, and fluency. Numbers in (parentheses) are standard deviations across 3 sets of prompts.
Figure 3: Comparison of our method with (orange) and without (purple) z-score reward normalization. The format is the same as Figure 2.


Optimal Prompts Don’t Follow Human Language

The resulting discrete prompts also facilitate rich interpretations and analyses for new insights into LM prompting. In particular, the optimized prompts, though inducing strong task performance, tend to be gibberish text without clear human-understandable meaning (e.g., table below), echoing recent research (Webson and Pavlick, 2021; Zhao et al., 2021; Prasad et al., 2022) that LMs making use of prompts do not necessarily follow human language patterns. 

Table 3: Comparison of our method (RLPrompt) with manually-written (Manual) prompts for text style transfer performance on Yelp (Shen et al., 2017). For the manual prompts, we take one from Reif et al. (2021) and write two more for this experiment. \(J(\cdot)\) is the main metric introduced in Table 2. All outputs are generated using GPT-2-xl and metrics are averaged over 5 runs.

Learned Prompts Transfer Trivially Across LMs

Perhaps surprisingly, those gibberish prompts learned with one LM can be used in other LMs for significant performance, indicating that those different pre-trained LMs have grasped shared structures for prompting (e.g., figures below).

Figure 4: Heatmap of sentiment analysis performance with transferred discrete prompts of 2 tokens. The columns represent the models used to learn the prompts, and the rows represent the models we perform classification with. Brighter color represents higher accuracy.
Figure 5: Heatmap of text style transfer performance with transferred discrete prompts. The columns represent the models used to learn the prompts, and the rows represent the models we perform text generation with. Manual and Random refer to manual prompts and random tokens, respectively. Brighter color represents better joint score \(J(\cdot)\).


We have presented RLPrompt, an efficient and flexible approach for discrete prompt optimization using RL, which improves over a wide range of fine-tuning and prompting methods in experiments on few-shot classification and unsupervised text style transfer.

Analysis reveals that strong optimized prompts are incoherent but transferable between LMs for remarkable performance. The observation opens up many promising possibilities for prompting, such as learning prompts cheaply from smaller models and performing inference with larger models. We are excited to explore further.

Bottom-up Top-Down Detection Transformers For Open Vocabulary Object Detection

We perform open vocabulary detection of the objects mentioned in the sentence using both bottom-up and top-down feedback.

Object detection is the fundamental computer vision task of finding all “objects” that are present in a visual scene. However, this raises the question: what is an object? Typically, this question is side-stepped by defining a vocabulary of categories and then training a model to detect instances of this vocabulary. This means that if “apple” is not in this vocabulary, the model does not consider it as an object. The problem gets even worse when we try to integrate these object detectors into real household agents. Imagine that we want a robot that can pick up “your favorite green mug from the table right in front of you”. We want the robot to specifically detect the “green mug” which is on the “table in front of you” and not any other mug or table. Obviously, treating descriptions such as “green mug from the table right in front of you” as separate classes in the detector’s vocabulary cannot scale; one can come up with countless variations of such descriptions.

In light of this, we introduce the Bottom-up Top-Down DEtection TRansformer (BUTD-DETR, pronounced “Beauty-DETER”), a model that conditions directly on a language utterance and detects all objects that the utterance mentions. When the utterance is a list of object categories, BUTD-DETR operates as a standard object detector. It is trained on both fixed-vocabulary object detection datasets and referential grounding datasets, which provide image-language pairs annotated with the bounding boxes for all objects referred to in the language utterance. With minimal changes, BUTD-DETR grounds language phrases both in 3D point clouds and 2D images.

BUTD-DETR conditions on language and can detect objects that state-of-the-art object detectors frequently miss.

No box bottleneck: BUTD-DETR decodes object boxes directly by attending to language and visual input instead of selecting them from a pool. Language-directed attention helps us localize objects that our bottom-up, task-agnostic attention may miss. For example, in the above image, the hint of “clock on top of the shelf” suffices to guide our attention to the right place, though the clock is not a salient object in the scene. Previous approaches for language grounding are detection-bottlenecked: they select the referred object from a pool of box proposals obtained from a pre-trained object detector. This means that if the object detector fails, then the grounding model will fail as well.

How does it work?

BUTD-DETR Architecture: Conditioning on visual, language, and object detection streams, our model decodes boxes and spans for all mentioned objects.

The input to our model is a scene and a language utterance. A pre-trained object detector is used to extract box proposals. Next, the scene, boxes, and utterance are encoded by modality-specific encoders into visual, box, and language tokens, respectively. These tokens are contextualized by attending to one another. The refined visual tokens are used to initialize object queries that attend to the different streams and decode boxes and spans.

Augmenting supervision with Detection prompts

Object Detection as Referential Language Grounding using detection prompts: We can generate additional grounding annotations/examples by chaining multiple object category tokens.

Object detection is an instance of referential language grounding in which the utterance is simply the object category label. We cast object detection as the referential grounding of detection prompts: we randomly sample some object categories from the detector’s vocabulary and generate synthetic utterances by sequencing them, e.g., “Couch. Person. Chair.”, as shown in the figure above. We use these detection prompts as additional supervision data: the task is to localize all object instances of the category labels mentioned in the prompt if they appear in the scene. For the category labels with no instances present in the visual input (e.g. “person” in the above figure), the model is trained to not match them to any boxes. In this way, a single model can perform both language grounding and object detection simultaneously and share the supervision information.
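
The data side of detection prompts is simple to sketch. The function below is a hypothetical illustration, not code from the BUTD-DETR release: it samples categories, builds the synthetic utterance, and records which character span of the utterance each category occupies together with that category's ground-truth boxes. How many categories are sampled, and how present and absent categories are balanced, are assumptions left as parameters; the toy vocabulary and boxes are made up.

```python
import random
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)
Span = Tuple[int, int]                   # character offsets [start, end) in the utterance

def make_detection_prompt(
    vocab: Sequence[str],
    gt_boxes: Dict[str, List[Box]],
    k: int,
    rng: random.Random,
) -> Tuple[str, Dict[str, Tuple[Span, List[Box]]]]:
    """Sample k category names, join them into an utterance such as
    "Couch. Person. Chair.", and pair each name's character span with its
    ground-truth boxes. Absent categories get an empty box list, i.e. the
    model should match that span to no box."""
    categories = rng.sample(list(vocab), k)
    parts: List[str] = []
    targets: Dict[str, Tuple[Span, List[Box]]] = {}
    start = 0
    for name in categories:
        word = name.capitalize()
        targets[name] = ((start, start + len(word)), gt_boxes.get(name, []))
        parts.append(word + ".")
        start += len(word) + 2  # skip the word, its period, and the joining space
    return " ".join(parts), targets

vocab = ["couch", "person", "chair", "tv", "dining table"]
scene = {"couch": [(10.0, 40.0, 120.0, 90.0)], "chair": [(130.0, 50.0, 170.0, 95.0)]}
utterance, targets = make_detection_prompt(vocab, scene, k=3, rng=random.Random(0))
print(utterance)
for name, (span, boxes) in targets.items():
    print(repr(utterance[span[0]:span[1]]), boxes)
```

Pairing spans with boxes is what lets the same grounding loss supervise both referential utterances and these synthetic detection prompts.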


BUTD-DETR achieves a large boost in performance over state-of-the-art approaches across all 3D language grounding benchmarks (SR3D, NR3D, ScanRefer). Moreover, it was the winning entry in the ReferIt3D challenge, held at the ECCV workshop on Language for 3D Scenes. On 2D language grounding benchmarks, BUTD-DETR performs on par with state-of-the-art methods when trained on large-scale data. Importantly, our model converges twice as fast as the state-of-the-art MDETR, mainly because of the efficient deformable attention we used in our 2D model.

Quantitative Results across 3D Benchmarks: Our model significantly outperforms all prior methods across all established 3D benchmarks.

We show the qualitative results of our model in the video at the beginning of the blog. For more visualizations, please refer to our project page and paper.

What’s next?

Our method detects all objects mentioned in the sentence — however, this assumes that the user needs to mention all relevant objects in the sentence. This is not desirable in general — for example, in response to “make breakfast” we would like our model to detect all the relevant ingredients like bread, eggs etc., even if they are not mentioned in the sentence. Additionally, while our architecture works for both 2D and 3D language grounding with minimal changes, we do not share parameters between the two modalities. This prevents transferring representations across modalities, which would be particularly helpful for the low-resource 3D modality. Our ongoing work is investigating these two directions.

We have released our code and model weights on GitHub, making it easy to reproduce our results and build upon our method. If you are interested in a language-conditioned open vocabulary detector for your project, then give BUTD-DETR a run! For more details, please check out our project page and paper.

Causal Confounds in Sequential Decision Making

A standard assumption in sequential decision making is that we observe everything required to make good decisions. In practice however, this isn’t always the case. We discuss two specific examples (temporally correlated noise (a) and unobserved contexts (c)) that have stymied the use of IL/RL algorithms (in autonomous helicopters (b) and self-driving (d)). We derive provably correct algorithms for both of these problems that scale to continuous control problems.

Reinforcement Learning (RL) and Imitation Learning (IL) methods have achieved impressive results in recent years, such as beating the world champion at Go or controlling stratospheric balloons. Usually, these results are on problems where we either a) observe the full state or b) are able to faithfully execute our intended actions on the system. However, we frequently have to contend with situations where this isn’t the case: our self-driving car might miss a person’s hand gestures, or persistent wind might make it difficult to fly our quadcopter perfectly straight. These sorts of situations can cause standard IL approaches to perform poorly ([1], [2]). In causal inference, an unobserved random variable that influences a relationship we’d like to model is called a confounder. Using techniques from causal inference, we derive provably correct and scalable algorithms for sequential decision making in these sorts of confounded settings.

We’re going to be focused mostly on imitation learning (see our last blog post for more details). In IL, we observe trajectories (sequences of states and actions) generated by some expert policy \(\pi_E\) and want to recover the policy that generated them. The expert could be a) an experienced driver if we’re trying to build a self-driving car, b) a user interacting with a recommender system we’re attempting to model, or c) scraped internet text used to train a large language model. We’re now going to discuss two issues of causation that are difficult for standard IL algorithms to handle. I’ll be drawing upon material from our ICML ’22 paper Causal Imitation under Temporally Correlated Noise and our NeurIPS ’22 paper Sequence Model Imitation Learning with Unobserved Contexts.

Issue 1: Temporally Correlated Noise

The first problem we’ll consider is when there’s temporally correlated noise (TCN) affecting pairs of expert actions. For example, if our expert was a human quadcopter pilot, the TCN could be persistent wind [1]. Let’s assume that the expert intended to fly completely straight but was unable to do so because of the persistent wind, producing swerving trajectories. If the learner ignores the fact that the wind was the root cause of these swerves, they might attempt to reproduce the swerving behavior of the expert, causing them to deviate even further at test time due to the continued influence of the wind.

In effect, adding correlated noise to both \(X = s_t\) and \(Y = a_t\) changes the apparent slope of the line we’re trying to fit, making standard regression-based approaches no longer consistent.

What’s happened is that the TCN has created a spurious correlation (e.g. being a little on the left is often followed by going more to the left) which the learner has latched onto. What we’d like is to instead learn a policy that can fly as straight as the expert did in the demonstrations.

Let’s try and formalize what’s happening in the above example via graphical models. In a Markov Decision Process (MDP), the learner \(\pi\) responds to the observed state \(s_t\) by taking an action \(a_t\), and the MDP transitions via dynamics \(\mathcal{T}: s \times a \rightarrow s\) to the next state \(s_{t+1}\).

Adding in the TCN corresponds to adding in a common cause of a pair of actions that’s not part of the observed state (\(u_t\) below).

Now, let’s dig into what goes wrong with imitation learning under TCN. Let’s say we apply a standard algorithm like behavioral cloning: direct regression from states (\(X = s_t\)) to actions (\(Y = a_t\)). The TCN travels through the dynamics to influence both the inputs and outputs of our regression procedure, creating spurious correlations that our learner might latch onto, leading to poor test-time performance. We visualize this effect in red below.

In effect, the TCN makes some elements of state \(s_2\) look more correlated with action \(a_2\) than they would otherwise. Attempting to minimize training error, the learner will use this correlation to their advantage but will end up swerving unnecessarily at test time as a result.

Filtering out TCN with Instrumental Variable Regression

To learn well under TCN, we’re going to utilize the idea of a natural experiment from causal inference, in which we’re able to use variation in the observational data to simulate the effect of an intervention / randomized controlled trial (RCT). Interestingly enough, this idea was one of the central contributions of the winners of the Nobel Prize in Economics in 2021. Let’s first take a look at this idea in the single-shot setting. Assume that we’re trying to predict from \(X\) to \(Y\), both of which are affected by an unobserved confounder \(U\):

Notice that the random variable \(Z\) both (1) affects \(X\) and (2) is independent of \(U\). \(Z\) is therefore an independent source of variation in \(X\). We call such variables instruments. Under some assumptions (see our paper for details), one can condition on an instrument to de-noise the inputs to our regression procedure and recover a causal relationship. So, instead of regressing from $$X \rightarrow Y,$$ one regresses from $$X \mid Z \rightarrow Y \mid Z.$$

We’ll skip the derivation here but intuitively, conditioning on an independent source of variation, \(Z\), “washes out” the effect of the confounder, \(U\). The variation in \(Z\) therefore acts as a sort of “natural experiment,” simulating an RCT without requiring a true intervention.
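To make the “washing out” concrete, here is a minimal sketch of the single-shot case with linear models. It assumes a hypothetical data-generating process in which \(U\) and \(Z\) are standard normal, \(X = Z + U\), and \(Y = 2X + 3U + \text{noise}\); the helper names `ols_slope` and `fitted` are our own. Regressing \(X \rightarrow Y\) directly absorbs the confounder (its slope converges to 3.5), while regressing linear estimates of \(E[X \mid Z]\) and \(E[Y \mid Z]\) against each other recovers the causal slope of 2. With linear first stages this is ordinary two-stage least squares, not necessarily the estimator used in the paper.

```python
import numpy as np

def ols_slope(x, y):
    """Least-squares slope of y regressed on x (with an intercept)."""
    c = np.cov(x, y)
    return c[0, 1] / c[0, 0]

def fitted(z, v):
    """Linear estimate of E[v | z], evaluated at each observed z."""
    return v.mean() + ols_slope(z, v) * (z - z.mean())

rng = np.random.default_rng(0)
n = 100_000
u = rng.normal(size=n)                      # unobserved confounder U
z = rng.normal(size=n)                      # instrument Z: shifts X, independent of U
x = z + u                                   # X depends on both Z and U
y = 2.0 * x + 3.0 * u + rng.normal(size=n)  # hypothetical: causal effect of X on Y is 2

naive = ols_slope(x, y)                     # X -> Y: biased by U
iv = ols_slope(fitted(z, x), fitted(z, y))  # X|Z -> Y|Z
print(f"naive slope: {naive:.2f}, IV slope: {iv:.2f}")
```

Because both fitted values are linear in \(Z\), the second regression reduces to the ratio \(\mathrm{Cov}(Y, Z) / \mathrm{Cov}(X, Z)\), in which the \(U\) terms vanish since \(Z\) is independent of \(U\).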

Now, you might well be wondering what a valid instrument is in the sequential setting. Well, given that the past is independent of future confounding, we can use past states as an instrument! Graphically:

Many econometrics papers will spend quite a lot of text arguing that something is a valid instrument (i.e., satisfies (1) and (2)), but here, the arrow of time gives us an instrument for free. Algorithmically, instrumental variable regression for imitation learning under TCN corresponds to solving $$s_t \mid s_{t-1} \rightarrow a_t \mid s_{t-1}.$$

So, one first fits models of the left and right-hand sides of the above expression and then regresses between them. We give two algorithms for doing so efficiently in our paper. On simulated control tasks, we observe that our approaches (in teal and orange) are able to recover a policy that matches expert performance, while regression-based behavioral cloning struggles to do so.
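The sketch below applies the same recipe to a toy imitation problem. It assumes a hypothetical 1D system with integrator dynamics \(s_{t+1} = s_t + a_t\), an expert that intends to play \(a_t = -0.5\, s_t\), and TCN in which each pair of consecutive actions shares one Gaussian shock. The past state \(s_{t-1}\) serves as the instrument, and both first-stage “models” are plain linear regressions, so this is a simplified stand-in rather than either of the two algorithms from the paper.

```python
import numpy as np

def ols_slope(x, y):
    """Least-squares slope of y regressed on x (with an intercept)."""
    c = np.cov(x, y)
    return c[0, 1] / c[0, 0]

def fitted(z, v):
    """Linear estimate of E[v | z], evaluated at each observed z."""
    return v.mean() + ols_slope(z, v) * (z - z.mean())

def simulate_expert(T, k=0.5, seed=0):
    """Hypothetical expert steering toward s = 0 with gain k. Action t receives
    noise e[t + 1] + e[t], so consecutive actions share the shock e[t] (TCN)."""
    rng = np.random.default_rng(seed)
    e = rng.normal(size=T + 1)
    s = np.zeros(T + 1)
    a = np.zeros(T)
    for t in range(T):
        a[t] = -k * s[t] + e[t + 1] + e[t]  # intended action + correlated noise
        s[t + 1] = s[t] + a[t]              # integrator dynamics
    return s[:T], a

s, a = simulate_expert(100_000)
s_prev, s_t, a_t = s[:-1], s[1:], a[1:]     # aligned (instrument, input, target)

bc = ols_slope(s_t, a_t)                                   # behavioral cloning: s_t -> a_t
ivr = ols_slope(fitted(s_prev, s_t), fitted(s_prev, a_t))  # s_t|s_{t-1} -> a_t|s_{t-1}
print(f"BC gain: {bc:.2f}, IVR gain: {ivr:.2f}, expert gain: -0.50")
```

In this toy, \(s_t\) already contains the shock that also perturbs \(a_t\), so behavioral cloning converges to a gain of about \(-0.25\) and under-corrects. The instrument \(s_{t-1}\) only contains older shocks, so the instrumented regression recovers the expert's intended \(-0.5\).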

So, putting it all together, temporally correlated noise between expert actions can create spurious correlations between recorded states and actions, causing standard regression-based imitation learning approaches to produce inaccurate predictors. One can use the “natural experiment” caused by the variation in past states to de-noise the inputs to their regression procedure and recover causally correct predictors that do not suffer from spurious correlations.

Issue 2: Unobserved Contexts

We’re now going to consider a different kind of confounder. Let’s assume we’re trying to imitate an expert who observes some context, \(c\), which the learner does not. For example, if we were trying to imitate an expert driver but our sensors don’t pick up a stop sign on the side of the road, \(c\) could refer to this side information. Now, this problem is in general impossible to solve: if we don’t see the sign, how would we know to stop? However, if we pay attention to our surroundings and accumulate information over time, we might be able to eventually filter out what we missed: if we observe all the other cars around us slowing down for a few timesteps, we might rationally conclude that we should as well, even though we never saw the stop sign.

More formally, we’re in a Partially Observable Markov Decision Process (a POMDP), which we can visualize as follows:

So, the effect of the stop sign (\(c\)) is reflected in the states (\(s_1, s_2\)). There are two key differences between this graphical model and the TCN graphical model. First, the confounder directly affects the state. Second, the confounder is constant rather than time-varying. These two characteristics make it plausible that, given enough state observations, we can figure out what \(c\) is and act as the expert does.

To enable us to accumulate evidence over time, we’re going to allow our policy to condition its actions on the entire history up to this point. Denoting \(h_t = (s_1, a_1, \dots, s_t)\), our policy now takes the form $$\pi: h_t \rightarrow a_t.$$ This means our policy is now a sequence model (e.g., an RNN or a transformer).
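As a toy illustration of why conditioning on \(h_t\) helps, the sketch below replaces the learned sequence model with an exact Bayes filter over a hypothetical binary context (“is there a stop sign?”). It assumes, purely for illustration, that other cars slow down at each step with probability 0.8 if the sign is present and 0.2 otherwise, and that the policy brakes once its posterior exceeds 0.9. A memoryless policy sees only the current step, whereas a history-conditioned one can accumulate evidence; a real sequence-model policy would have to learn an approximation of this filtering from data.

```python
import math

# Hypothetical observation model: probability that nearby cars slow down at a step.
P_SLOW_IF_SIGN = 0.8      # stop sign present (context c = 1)
P_SLOW_IF_NO_SIGN = 0.2   # no stop sign (c = 0)

def posterior_stop(history, prior=0.5):
    """P(c = 1 | h_t) for a history of booleans ("did other cars slow down?")."""
    log_odds = math.log(prior / (1 - prior))
    for slowed in history:
        p1 = P_SLOW_IF_SIGN if slowed else 1 - P_SLOW_IF_SIGN
        p0 = P_SLOW_IF_NO_SIGN if slowed else 1 - P_SLOW_IF_NO_SIGN
        log_odds += math.log(p1 / p0)   # Bayes update in log-odds form
    return 1 / (1 + math.exp(-log_odds))

def policy(history, threshold=0.9):
    """History-conditioned policy pi(h_t): brake only once the evidence is strong."""
    return "brake" if posterior_stop(history) > threshold else "cruise"

for h in ([], [True], [True, True]):
    print(h, round(posterior_stop(h), 3), policy(h))
```

After one observed slowdown the posterior is 0.8 and the policy keeps cruising; after two it is \(16/17 \approx 0.94\) and the policy brakes, even though the sign itself was never observed.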

OK, so we use sequence models as our policy class, apply behavioral cloning (predicting the next token from the prefix), and call it a day, right? Unfortunately, this often produces policies that naively repeat their own past actions. We call this the latching effect. For example, in the self-driving domain, this can lead to learning a driving policy that begins to turn and then just keeps turning, driving in circles [2]. This is clearly less than ideal.

Modeling Unobserved Contexts with Sequence Models + On-Policy Training

Let’s dig into why the latching effect stymies off-policy imitation learners. The model we learn during training looks like $$\pi(a_t \mid h_t) \approx p(a_t^E \mid s_1^E, a_1^E, \dots, s_t^E),$$ where the superscript \(E\) denotes that these are expert states and actions. Now at test time, the past actions we pass into our sequence model policy are our own rather than the expert’s. This means we’re trying to sample actions from $$p(a_t^E \mid s_1, a_1, \dots, s_t).$$ Unless we know that the elements to the right of the conditioning bar are exactly the same (i.e., we’ve perfectly matched the expert policy), we have no reason to believe that the model we’ve learned will generalize to our own induced state distribution. Put differently, we have no reason to believe we’ve learned a policy that will perform well at test time.

In machine learning terms, we’re facing the problem of policy-induced covariate shift (PICS). Here, our covariates are the history of states and actions the policy uses to predict the next action. PICS is a well-known problem in imitation learning and sequential prediction writ large (see our last blog post for more details). What’s special about the unobserved context setting is the degree to which it causes problems. Early in an episode, no learner can act well, since it hasn’t yet accumulated enough information about the context. This means that we have an unavoidable source of covariate shift from preliminary mistakes.

Off-policy training (i.e. training only on the green expert states) might not ensure we generalize well to our own induced state distribution (the orange dashed line). The effect of unobserved contexts is to ensure we don’t do well at first because we don’t have enough information to act properly, making it quite likely that we go off the rails.

Now, let’s assume that the expert’s actions were relatively similar between adjacent timesteps (which is frequently true in practice). This means that rather than learning a complex mapping from states to actions, our learner might have instead learned a policy that mostly copies the last action. This was fine when the previous actions were the expert’s. But when they are the learner’s extremely suboptimal initial actions, naive copying is a recipe for disaster. This is what leads to us learning a driving policy that is only capable of turning in circles. Put differently, because the learner was trained only on trajectories from the expert, it treats its own past actions at test time as though they were produced by the expert, leading to unfortunate downstream consequences.

On-policy training (i.e., performing rollouts in the environment and comparing them to expert rollouts) doesn’t suffer from these issues. This is because we’re instead trying to satisfy $$p(a_t^E \mid s_1^E, a_1^E, \dots, s_t^E) \approx p(a_t \mid s_1, a_1, \dots, s_t).$$ In words, rather than matching expert actions on expert histories, we’re trying to match expert and learner trajectory distributions. We make sure we predict well based on our own histories and therefore can’t be surprised at test time by states we didn’t expect to end up in. We prove in our paper that under certain identifiability conditions, on-policy training of sequence model policies is both necessary and sufficient for the learner to generalize well to its own induced state distribution and successfully imitate the expert.

On-policy training involves rolling out the learner’s current policy (in orange) and minimizing some notion of divergence (in red) to expert trajectories (in green). Because we observe our own induced state distribution, we can’t fool ourselves into thinking our own past actions are like those of the expert.

We also conduct experiments on continuous control tasks with unobserved contexts and show that equipping an off-policy learner (in grey) with access to history can actually hurt its performance, while we see no such effect for an on-policy learner (in orange).


Confounding commonly occurs in real-world situations, including those in which we want to make a sequence of decisions. When we don’t filter out or model the effects of what we don’t observe, we can end up making extremely suboptimal decisions. We describe two situations in which there are clean algorithms for handling confounding in sequential decision making: temporally correlated noise can be filtered out via instrumental variable regression (IVR), and unobserved contexts can be modeled via on-policy training of a sequence model. Moving forward, we’re interested in other causal structures (e.g., negative controls or proxy shifts) and in applying our techniques to problems where confounders abound (e.g., recommender systems).

If you’re interested in learning more, I recommend looking at the full papers or the attached code.

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.

How do Authors’ Perceptions about their Papers Compare with Co-authors’ Perceptions and Peer-review Decisions?

NeurIPS 2021 Author Perception Experiment

Alina Beygelzimer, Yann N. Dauphin, Percy Liang, Jennifer Wortman Vaughan
(NeurIPS 2021 Program Chairs)

Charvi Rastogi, Ivan Stelmakh, Zhenyu Xue, Hal Daumé III, Emma Pierson, and Nihar B. Shah

There is a considerable body of research on peer review. Within the machine learning community, there have been experiments establishing significant disagreement across reviewers and across reviewer panels—including at NeurIPS 2021—and active discussions about the state of peer review. But how do author perceptions about their submitted papers match up to the outcomes of the peer-review process and perceptions of other authors? We investigate this question by asking authors who submitted papers to NeurIPS 2021 three questions:

(Q1) [At the time of paper submission] What is your best estimate of the probability (as a percentage) that this submission will be accepted?

(Q2) [At the time of paper submission; to authors submitting two or more papers] Rank your submissions in terms of your own perception of their scientific contributions to the NeurIPS community, if published in their current form.

(Q3) [After preliminary reviews were available to authors] After you read the reviews of this paper, how did your perception of the value of its scientific contribution to the NeurIPS community change (assuming it was published in its initially submitted form)?  

Here are five key findings.

1. How well do authors estimate the probability of acceptance of their papers?

Authors significantly overestimate their papers’ chances of acceptance. When answering Q1, authors were informed that the acceptance rate at NeurIPS over the last 4 years had been about 21%. The acceptance rate at NeurIPS 2021 turned out to be 25.8%. The median predicted probability of acceptance was 70%, a nearly three-fold overestimate.

2. Are some sub-groups better calibrated than others?

We examined calibration error across sub-groups, measuring this error in terms of the Brier score (squared loss) and controlling for other confounders. We find that the calibration error of female authors is slightly (but statistically significantly) higher than that of male authors. We also see a trend of miscalibration decreasing with seniority, with authors who were invited to serve as (meta-)reviewers better calibrated than the rest. All sub-groups we examined over-predicted their papers’ chances of acceptance.
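For reference, the Brier score of a set of probabilistic predictions is the mean squared difference between each predicted probability and the realized 0/1 outcome, so lower means better calibrated. The sketch below uses hypothetical numbers, not data from the experiment, and does not attempt the control for confounders described above: predicting 70% for four submissions of which one is accepted scores worse than predicting the 25% base rate.

```python
def brier_score(predictions, outcomes):
    """Mean squared difference between predicted probabilities and 0/1 outcomes."""
    if len(predictions) != len(outcomes):
        raise ValueError("predictions and outcomes must have the same length")
    return sum((p - y) ** 2 for p, y in zip(predictions, outcomes)) / len(outcomes)

outcomes = [1, 0, 0, 0]                          # hypothetical: one of four papers accepted
overconfident = brier_score([0.7] * 4, outcomes)
calibrated = brier_score([0.25] * 4, outcomes)
print(round(overconfident, 4), round(calibrated, 4))  # 0.39 0.1875
```

The overconfident score is \((0.3^2 + 3 \cdot 0.7^2)/4 = 0.39\), while the base-rate score is \((0.75^2 + 3 \cdot 0.25^2)/4 = 0.1875\).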


3. Among authors with multiple papers, how much do their predictions of acceptance probabilities agree with their own perceived scientific merit?

These two sets of responses are largely in agreement: the strict ranking provided by authors about their perceived scientific merit (Q2) and the strict ranking induced by their predicted acceptance probabilities (Q1) agree for 93% of responses. However, for a noticeable 7% of responses, the authors think that peer review is more likely to reject the better of their two papers.

4. How much do co-authors agree on the relative quality of their joint papers?

Strikingly, the amount of disagreement between co-authors in terms of the perceived relative scientific contribution of their papers (Q2) is similar to the amount of disagreement between authors and reviewers! In cases where one paper from an author was ultimately accepted and another rejected, authors rated the rejected paper higher about a third of the time. But looking at pairs of papers with overlapping authors in which both authors provided rankings, the co-authors also disagreed with each other about a third of the time. While there are discussions in the literature about inter-reviewer disagreements, this result suggests that there is similar disagreement in co-authors’ views of their papers as well.

5. Does peer review change authors’ perception of their own papers?

The question Q3 was a multiple-choice question with five choices: much more positive (“++”), slightly more positive (“+”), did not change (“0”), slightly more negative (“-”), much more negative (“--”).

We find that among both accepted and rejected papers, about 50% of authors report that their perception about their own paper changed after seeing the initial reviews (Q3). Moreover, among both accepted and rejected papers, over 30% of authors report that their perception became more positive.

Figure: Q3 responses for accepted papers and for rejected papers.


The fact that authors vastly overestimated the probability that their papers will be accepted suggests it would be useful for conference organizers and research mentors to attempt to recalibrate expectations prior to each conference. The disagreements we document around paper quality — between co-authors as well as between authors and reviewers — taken together with the disagreement among committees of reviewers observed in the complementary NeurIPS 2021 consistency experiment, suggest that assessing paper quality is not only an extremely noisy process, but may be a fundamentally challenging task with no objective right answer. The outcomes of paper submissions should thus be taken with a grain of salt. More broadly, as a community, we may take these findings into account when deciding on our policies and perceptions pertaining to the peer-review process and its outcomes. We hope the results of our experiment encourage discussion and introspection in the community.

More details: Available here

Tackling Diverse Tasks with Neural Architecture Search

DASH searches for the optimal kernel size and dilation rate efficiently from a large set of options for each convolutional layer in a CNN backbone. The resulting model can achieve task-specific feature extraction and work as well as hand-designed expert architectures, making DASH an effective tool for tackling diverse tasks beyond well-researched domains like vision.

The past decade has witnessed the success of machine learning (ML) in solving diverse real-world problems, from facial recognition and machine translation to disease diagnosis and protein sequence prediction. However, progress in such areas has involved painstaking manual effort in designing and training task-specific neural networks, leveraging human and computational resources that most practitioners do not have access to.

In contrast to this task-specific approach, general-purpose models such as DeepMind’s Perceiver IO and Gato and Google’s Pathways have been developed to solve more than one task at once. However, as these proprietary pretrained models are not publicly available, practitioners cannot even assess whether fine-tuning one of these models would work on their task of interest. Independently developing a general-purpose model from scratch is also infeasible due to the massive amount of compute and training data it requires.

A more accessible alternative is the field of automated machine learning (AutoML), which aims to obtain high-quality models for diverse tasks with minimal human effort and computational resources, as noted in a recent blogpost. In particular, we can use Neural Architecture Search (NAS) to automate the design of neural networks for different learning problems. Indeed, compared with training large-scale transformer-based general-purpose models, many efficient NAS algorithms such as DARTS can be run on a single GPU and take a few hours to complete a simple task. However, while NAS has enabled fast and effective model development in well-studied areas such as computer vision, its application to domains beyond vision remains largely unexplored. In fact, a major difficulty in applying NAS to more diverse problems is the trade-off between considering a sufficiently expressive set of neural networks and being able to efficiently search over this set. In this blog post, we will introduce our approach to find a suitable balance between expressivity and efficiency in NAS.

In our upcoming NeurIPS 2022 paper, we developed a NAS method called DASH that generates and trains task-specific convolutional neural networks (CNNs) with high prediction accuracy. Our core hypothesis is that for a broad set of problems (especially those with non-vision inputs such as audio and protein sequences), simply searching for the right kernel sizes and dilation rates for the convolutional layers in a CNN can achieve high-quality feature extraction and yield models competitive with expert-designed ones. We explicitly focus on extending the generalization ability of CNNs due to the well-known effectiveness of convolutions as feature extractors, coupled with recent work demonstrating the success of modern CNNs on a variety of tasks (e.g., the state-of-the-art performance of the ConvNeXt model that incorporates many techniques used by Transformers).

While a search space of diverse kernels is easy to define, searching it efficiently is challenging because we want to consider many kernels with different kernel sizes and dilation rates, which results in a combinatorial explosion of possible architectures. To address this issue, we introduce three techniques exploiting the mathematical properties of convolution and fast matrix multiplication on GPUs. We evaluate DASH on 10 different tasks spanning multiple domains (vision, audio, electrocardiogram, music, protein, genomics, cosmic-ray, and mathematics), input dimensions (1D and 2D), and prediction types (point and dense). While searching up to 10x faster than existing NAS techniques, DASH achieves the lowest error rates among all NAS baselines on 7/10 tasks and all hand-crafted expert models on 7/10 tasks.

In the following, we will first discuss how DASH is inspired by and differs from existing NAS work. Then, we will introduce three novel “tricks” that improve the efficiency of searching over a diverse kernel space. Finally, we will present the empirical evaluation to demonstrate DASH’s effectiveness.

The Expressivity-Efficiency Trade-Off in NAS

Most NAS methods have two components for generating task-specific models: a search space that defines all candidate networks and a search algorithm that explores the search space until a final model is found. Effective models for arbitrary new tasks can be developed only if the search space is sufficiently expressive, but this also means we need more time to explore the set of possible architectures in the space.

This tension between search space expressivity and search algorithm efficiency has been prominent in NAS research. On one hand, vision-centric approaches like DARTS are designed to explore multiple architectures quickly, but the search spaces are limited to models with (inverted) residual blocks, and are thus highly tailored to vision tasks. On the other hand, new approaches like AutoML-Zero and XD aim to solve arbitrary tasks by considering highly expressive search spaces, but the associated search algorithms are often practically intractable. For instance, XD tries to substitute layers in existing networks with matrix transformations. The matrix search space is continuous and expansive, but optimizing it is extremely time-consuming even for simple benchmarking tasks like CIFAR-100, rendering XD impractical for diverse tasks with more data points or larger input dimensions.

To bridge this gap, we present DASH, which fixes a CNN as the backbone and searches for the optimal kernel configurations. The intuition is that modern convolutional models like ConvNeXt and ConvMixer are powerful enough to compete with attention-based architectures, and varying kernel sizes and dilations can further strengthen the feature extraction process for different problems. For instance, small filters are generally used for visual tasks to detect low-level features such as edges and corners, whereas large kernels are typically more effective for sequence tasks to model long-range dependencies. Unlike conventional cell-based NAS, which searches for a block of operations and stacks several copies of the same block together, DASH is more flexible as it decouples layer operations from the network structure: since the searched operators can vary from the beginning to the end of a network, features at different granularities can be processed differently.

“Kernel Tricks” for Improving NAS Efficiency

As mentioned above, we seek a sufficiently expressive (i.e., large) kernel search space to ensure that there exist kernels that can effectively extract features for a diverse set of tasks. To achieve this, we replace each convolutional layer in the backbone network with the following aggregated convolution operator:

$$S_{\mathbf{AggConv}_{K,D}} = \{\mathbf{Conv}_{k,d} \mid k \in K, d \in D\}, \tag{1}\label{1}$$

where \(K\) and \(D\) denote the sets of kernel sizes and dilations that we consider, respectively. A naive approach to searching over this kernel space is to use the continuous relaxation scheme of DARTS to compute the output (we call this approach mixed-results):

$$\mathbf{AggConv}_{K,D}(\mathbf{x}) := \sum_{k \in K} \sum_{d \in D} \alpha_{k,d} \cdot \mathbf{Conv}(\mathbf{w}_{k,d})(\mathbf{x}), \tag{2}\label{2}$$

where \(\mathbf{w}_{k,d}\) are the kernel weights and \(\alpha_{k,d}\) are the architecture parameters. However, DARTS only considers a few kernels with small sizes and dilations (with \(k_{\max} = 5\), \(d_{\max} = 2\), and thus small \(|K||D|\)), whereas we aim to search over many large kernels (e.g., with \(k_{\max} = 11\), \(d_{\max} = 127\), and large \(|K||D|\)). This increased expressivity leads to drastically higher search costs for naive search, whose runtime complexity is \(O(n|K||D|)\), where \(n\) is the input size.

In the following, we will describe three techniques (kernel-mixing, Fourier convolution, and Kronecker dilation) that DASH collectively employs to enable efficient search. Complexity-wise, DASH’s efficient search replaces the \(O(n|K||D|)\) complexity of naive search with an \(O(n \log n)\) complexity, where \(\log n\) is small for any realistic \(n\), including long sequence inputs. Empirically, this latter complexity translates to significantly improved search speed, e.g., DASH searches about 10 times faster than DARTS for the large \(|K||D|\) regime (Figure 1).

Figure 1. Combined forward- and backward-pass time for one search epoch vs. the search space size for 1D input with \(n = 1000\). We vary the search space by letting \(K = \{2p+1 \mid 1 \leq p \leq c\}\), \(D = \{2^q - 1 \mid 1 \leq q \leq c\}\) and increasing \(c\) from 1 to 7. As the aggregated kernel size \(\bar{D}\) increases, the DASH curves grow much more slowly than those of the other methods.
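The search spaces in the caption can be tabulated directly. This short sketch (the function name is our own) computes the number of candidate kernels \(|K||D|\) and the aggregated kernel size \(\bar{D} = \max (k-1)d + 1\) for each \(c\). The count \(|K||D| = c^2\) grows quadratically, but \(\bar{D}\) grows roughly exponentially because the largest dilation is \(2^c - 1\); at \(c = 7\) it reaches 1779, larger than the \(n = 1000\) input used in the figure.

```python
def search_space_stats(c):
    """|K||D| and the aggregated kernel size D_bar for the Figure 1 search spaces."""
    K = [2 * p + 1 for p in range(1, c + 1)]     # kernel sizes 3, 5, ..., 2c + 1
    D = [2 ** q - 1 for q in range(1, c + 1)]    # dilations 1, 3, 7, ..., 2^c - 1
    d_bar = max((k - 1) * d + 1 for k in K for d in D)
    return len(K) * len(D), d_bar

for c in range(1, 8):
    num_kernels, d_bar = search_space_stats(c)
    print(f"c={c}: |K||D|={num_kernels}, D_bar={d_bar}")
```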

Technique 1: Mixed-Weights. We first observe that all computations in Equation 2 are linear, so the distributive law applies. Hence, instead of computing \(|K||D|\) convolutions, we can combine the kernels and compute the convolution once:

$$\mathbf{AggConv}_{K,D}(\mathbf{x}) = \mathbf{Conv}\left(\sum_{k \in K} \sum_{d \in D} \alpha_{k,d} \cdot \mathbf{w}_{k,d}\right)(\mathbf{x}). \tag{3}\label{3}$$

Let’s call this approach mixed-weights. Mixed-weights allows the search complexity to depend on the aggregated kernel size \(\bar{D} := \max_{k \in K, d \in D} (k - 1)d + 1\) rather than \(|K||D|\), but \(\bar{D}\) still grows with the size of the search space. Can we do better than this?
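Because convolution is linear in the kernel, Equations 2 and 3 give the same output once every dilated kernel is zero-padded to the common size \(\bar{D}\). The NumPy sketch below checks this on a toy 1D search space; the sizes, random weights, and helper names (`dilate`, `center_pad`) are illustrative assumptions, not DASH’s code. Note that `np.convolve` computes a true (flipped) convolution while CNN layers usually compute cross-correlation; the linearity argument holds for both.

```python
import numpy as np

def dilate(w, d):
    """Insert d - 1 zeros between adjacent taps of a 1D kernel."""
    out = np.zeros((len(w) - 1) * d + 1)
    out[::d] = w
    return out

def center_pad(w, size):
    """Zero-pad an odd-length kernel symmetrically to length `size` (also odd)."""
    extra = (size - len(w)) // 2
    return np.pad(w, (extra, extra))

rng = np.random.default_rng(0)
K, D = [3, 5], [1, 3]                                   # toy search space
x = rng.normal(size=64)                                 # 1D input
w = {(k, d): rng.normal(size=k) for k in K for d in D}  # one kernel per (k, d)
alpha = dict(zip(w, rng.dirichlet(np.ones(len(w)))))   # architecture weights
D_bar = max((k - 1) * d + 1 for k in K for d in D)     # aggregated kernel size (13)
kernels = {kd: center_pad(dilate(w[kd], kd[1]), D_bar) for kd in w}

# Mixed-results (Equation 2): |K||D| convolutions, then a weighted sum.
mixed_results = sum(alpha[kd] * np.convolve(x, kernels[kd], mode="same") for kd in w)

# Mixed-weights (Equation 3): combine the kernels first, then convolve once.
combined = sum(alpha[kd] * kernels[kd] for kd in w)
mixed_weights = np.convolve(x, combined, mode="same")

print(np.allclose(mixed_results, mixed_weights))  # True
```

The single convolution in the mixed-weights path still uses a kernel of length \(\bar{D}\), which is why the cost continues to grow with the search space.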

Technique 2: Fourier Convolution. Imagine that \(\mathbf{x}\) is an image input. Then, Equation 3 operates on the pixel values in the spatial domain. Alternatively, one can work in the frequency domain, where convolution becomes elementwise multiplication, and compute the convolution output more efficiently by taking advantage of the celebrated convolution theorem:

$$\mathbf{AggConv}_{K,D}(\mathbf{x}) = \mathbf{F}^{-1} \mathbf{diag}\left(\mathbf{F}\left(\sum_{k \in K} \sum_{d \in D} \alpha_{k,d} \cdot \mathbf{w}_{k,d}\right)\right) \mathbf{F}\mathbf{x}, \tag{4}\label{4}$$

where \(\mathbf{F}\) represents the discrete Fourier transform. Equation 4 allows us to remove the dependence on the combined kernel size \(\bar{D}\), as \(\mathbf{F}\) can be applied in time \(O(n \log n)\) using the Fast Fourier Transform.
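The convolution theorem can be checked numerically with NumPy’s FFT routines. One caveat is worth making explicit: \(\mathbf{F}^{-1}\mathbf{diag}(\mathbf{F}\mathbf{w})\mathbf{F}\mathbf{x}\) is a circular convolution, so this sketch zero-pads both signals to the full output length to reproduce ordinary linear convolution. The kernel of length 13 is a stand-in for a combined kernel of size \(\bar{D}\); the padding choice is our assumption, and DASH’s implementation may handle boundaries differently.

```python
import numpy as np

def fft_conv(x, w):
    """Linear convolution via the convolution theorem; cost is O(N log N) in the
    padded length N, no matter how the kernel was assembled."""
    n = len(x) + len(w) - 1          # full output length; padding avoids circular wrap-around
    X = np.fft.rfft(x, n)            # F x (real-input FFT, zero-padded to n)
    W = np.fft.rfft(w, n)            # F w
    return np.fft.irfft(X * W, n)    # F^{-1} diag(F w) F x

rng = np.random.default_rng(0)
x = rng.normal(size=1000)            # n = 1000, as in the Figure 1 caption
w = rng.normal(size=13)              # stand-in for a combined kernel of size D_bar
print(np.allclose(fft_conv(x, w), np.convolve(x, w)))  # True
```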

Technique 3: Kronecker Dilation. Lastly, we focus on accelerating search in a subset of the search space where the kernel size \(k\) is fixed and the dilation rate \(d\) varies. To dilate a kernel before applying it to the input, we need to insert \(d-1\) zeros between the adjacent elements in the weight matrix. An efficient implementation on GPUs exploits the Kronecker product \(\otimes\). For example, in 2D, we can introduce a sparse pattern matrix \(P \in \mathbb{R}^{d \times d}\) whose entries are all \(0\)’s except for the upper-left entry \(P_{1,1} = 1\). Then, \(\mathbf{w}_{k,d} = \mathbf{w}_{k,1} \otimes P\). After dilating the kernels, we can proceed by following Equation 4.
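A quick NumPy check of the Kronecker construction, assuming a square 2D kernel: `np.kron(w, P)` places each weight at a multiple of \(d\) and fills the rest with zeros. It has size \(kd \times kd\), i.e., \(d-1\) more trailing zero rows and columns than the \((k-1)d + 1\) footprint of a dilated kernel, so the sketch trims them before comparing against a slicing-based dilation. Both helper functions are illustrative, not taken from DASH.

```python
import numpy as np

def dilate_kron(w, d):
    """Dilate a square 2D kernel with a Kronecker product against a sparse pattern P."""
    P = np.zeros((d, d))
    P[0, 0] = 1.0                          # P_{1,1} = 1 in the text's 1-based indexing
    size = (w.shape[0] - 1) * d + 1        # footprint of the dilated kernel
    return np.kron(w, P)[:size, :size]     # drop the d - 1 trailing zero rows/cols

def dilate_slice(w, d):
    """Reference dilation: scatter the weights onto a strided zero grid."""
    size = (w.shape[0] - 1) * d + 1
    out = np.zeros((size, size))
    out[::d, ::d] = w
    return out

w = np.arange(1.0, 10.0).reshape(3, 3)     # toy 3x3 kernel
print(dilate_kron(w, 2).shape)                                # (5, 5)
print(np.array_equal(dilate_kron(w, 2), dilate_slice(w, 2)))  # True
```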

Empirical Results on NAS-Bench-360

Figure 2. We use performance profiles to measure the aggregate performance of DASH and other NAS methods (higher is better). Each profile is a curve corresponding to a single method and plots the fraction of tasks on which that method is within a multiplicative factor \(\tau\) of the best method, where \(\tau\) is the domain variable between 1 and infinity. DASH being far in the top left corner indicates it is rarely suboptimal and is often the best.
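A performance profile can be computed from a table of per-task errors in a few lines. This sketch assumes lower-is-better errors that are strictly positive and uses made-up numbers for two hypothetical methods, A and B. For each method and each \(\tau\), it reports the fraction of tasks on which the method’s error is at most \(\tau\) times the best error on that task.

```python
import numpy as np

def performance_profile(errors, taus):
    """errors: dict mapping method name -> per-task errors (lower is better, > 0).
    Returns method -> fraction of tasks whose error is within a factor tau of the best."""
    names = list(errors)
    E = np.array([errors[m] for m in names], dtype=float)  # shape (methods, tasks)
    ratios = E / E.min(axis=0)                             # 1.0 where a method is best
    return {m: [float(np.mean(ratios[i] <= t)) for t in taus]
            for i, m in enumerate(names)}

errors = {"A": [1.0, 2.0, 3.0], "B": [2.0, 2.5, 1.5]}       # hypothetical errors, 3 tasks
profile = performance_profile(errors, taus=[1.0, 1.5, 2.0])
print(profile)
```

Method A is best on two of the three tasks, so its curve starts at 2/3; B is best on only one and starts at 1/3, but both reach 1 by \(\tau = 2\). A curve that rises to 1 near \(\tau = 1\) is what “far in the top left corner” means.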

To verify that DASH finds a balance between expressivity and efficiency, we evaluate its performance with the Wide ResNet backbone on ten diverse tasks from NAS-Bench-360. We present the performance profile (a technique for comparing different methods while revealing both ranking and absolute performance characteristics) for DASH and the NAS baselines in Figure 2. The exact accuracy metrics can be found in our paper. We highlight the following observations:

  • DASH ranks first among all NAS baselines (and outperforms DARTS) on 7/10 tasks. It also dominates traditional non-DL approaches such as Auto-Sklearn and general-purpose models such as Perceiver IO.
  • DASH outperforms hand-crafted expert models on 7/10 tasks. While the degree of sophistication of the expert networks varies task by task, the performance of DASH on tasks such as Darcy Flow suggests that it is capable of competing with highly specialized networks, e.g., Fourier Neural Operator for PDE solving. This implies that equipping backbone networks with task-specific kernels is a promising approach for model development in new domains.
  • Speedwise, DASH is consistently faster than DARTS, and its search process often takes only a fraction of the time needed to train the backbone. Figure 3 visualizes the trade-off between efficiency and effectiveness for each method-task combination. Evidently, DASH is both faster and more effective than most NAS methods on the tasks we considered.
Figure 3. Comparing \(-\log\tau\)-suboptimality of speed vs. accuracy on all tasks. DASH’s concentration in the top right corner indicates its strong efficacy-efficiency trade-offs relative to the other methods.

In addition to the Wide ResNet backbone and NAS-Bench-360 tasks, we have also verified the efficacy of DASH on other backbones including TCN and ConvNeXt, and on large-scale datasets including ImageNet. In particular, DASH is able to achieve a 1.5% increase in top-1 accuracy for ImageNet-1K on top of the ConvNeXt backbone (note that ConvNeXt itself was developed in part via manual tuning of the kernel size). These results provide further support that DASH is backbone-agnostic, and it can be used to augment modern architectures with task-specific kernels to solve diverse problems effectively and efficiently.

Discussion and Takeaways

In this blogpost, we argue that a crucial goal of AutoML is to discover effective models for solving diverse learning problems that we may encounter in reality. To this end, we propose DASH, which efficiently searches for task-specific kernel patterns and integrates them into existing convolutional backbones. Please see our paper and codebase to learn more about DASH or try it out on your own tasks and backbones.

While DASH makes progress in obtaining high-quality models by finding a suitable balance between expressivity and efficiency in NAS, we view it as an early step as part of the broader challenge of AutoML for diverse tasks. Thus, we would like to highlight and encourage participation in the ongoing AutoML Decathlon competition at NeurIPS 2022. At the moment, we are working on implementing DASH as one of the baselines for the competition. Meanwhile, we hope to see more automated and practical methods developed for tackling diverse tasks in the future.  


Tracking Any Pixel in a Video

We upgrade pixels into PIPs: “Persistent Independent Particles”. With this representation, we track any pixel over time, and overcome visibility issues with a learned temporal prior.

Motion estimation is a fundamental task of computer vision, with extremely broad applications. By tracking something, you can build models of its various properties: shape, texture, articulation, dynamics, affordances, and so on. More fine-grained tracking allows more fine-grained understanding. For robots, fine-grained tracking also enables fine-grained manipulation. Even setting aside downstream AI-related applications, motion tracks are directly useful for video editing applications — making realistic edits to a person or object in a video demands precise-as-possible tracking of the pixels, across an indefinite timespan.

There are a variety of methods for tracking objects (at the level of segmentation masks or bounding boxes), or for tracking certain points in certain categories (e.g., the joints of a person), but there are actually very few options for general-purpose fine-grained tracking. In this domain, the dominant approaches are feature matching and optical flow. The feature matching approach is: compute a feature for the target on the first frame, then compute features for pixels in other frames, and then compute “matches” using feature similarity (i.e., nearest neighbors). This often works well, but does not take into account temporal context, like smoothness of motion. The optical flow approach is: compute a dense “motion field” that relates each pair of frames, and then do some post-processing to link the fields together. Optical flow is very powerful, but since it only describes motion for a pair of frames at a time, it cannot produce useful outputs for targets that undergo multi-frame occlusion. “Occlusion” means our view is obstructed, and we need to guess the target’s location from context.
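To make the feature-matching approach concrete, here is a minimal sketch, assuming per-frame feature maps have already been computed (for example, by a CNN) and stored as a (T, H, W, C) array. Each frame is matched independently by cosine-similarity nearest neighbor, which is exactly why this approach has no notion of motion smoothness or occlusion. The function name and array layout are illustrative assumptions.

```python
import numpy as np

def match_by_features(feats, query_xy):
    """Track a pixel by nearest-neighbor feature matching, one frame at a time.

    feats:    (T, H, W, C) array of per-frame feature maps (assumed layout).
    query_xy: integer (x, y) location of the target on frame 0.
    Returns a (T, 2) integer array of matched (x, y) locations.
    """
    T, H, W, C = feats.shape
    x0, y0 = query_xy
    target = feats[0, y0, x0]
    target = target / (np.linalg.norm(target) + 1e-8)
    traj = np.zeros((T, 2), dtype=int)
    for t in range(T):
        flat = feats[t].reshape(-1, C)
        flat = flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-8)
        sims = flat @ target               # cosine similarity to every pixel in frame t
        idx = int(np.argmax(sims))         # best match; no temporal context is used
        traj[t] = (idx % W, idx // W)      # row-major index -> (x, y)
    return traj

# Toy video: random features with one distinctive vector that moves between frames.
rng = np.random.default_rng(0)
feats = rng.normal(size=(3, 8, 8, 16))
u = rng.normal(size=16)
for t, (x, y) in enumerate([(2, 3), (4, 3), (5, 6)]):
    feats[t, y, x] = u
print(match_by_features(feats, (2, 3)))
```

If the distinctive vector were removed from one frame (an occlusion), `argmax` would still return some pixel, with no way to fall back on where the target was likely to be.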

During an occlusion, appearance information does not suffice, because the target is not even present in the frames. Multi-frame temporal priors are key.

Around the year 2006, Peter Sand and Seth Teller proposed an alternative to flow-based and feature-based methods, called a “particle video.” This approach aims to represent a video with a set of particles that move across multiple frames. Their proposed method did not handle occlusions, but in our view, they laid the groundwork for treating pixels as persistent entities, with multi-frame trajectories and long-range temporal priors.

Inspired by their work, we propose Persistent Independent Particles (PIPs), a new particle video method. Our method takes a video as input, along with the \((x, y)\) coordinate of a target to track, and produces the target’s trajectory as output. The model can be queried for any number of particles, at any positions.

Particle trajectories for arbitrary “target” pixels in the video.

You may notice in our visualizations that the PIP trajectories often leave the video bounds. We treat pixels flying out-of-bounds as just another “occlusion.” This type of robustness is exactly what is missing from feature-based and flow-based methods.

Let’s step through how we achieved this.

How does it work?

At a high level, our method makes an extreme trade-off between spatial awareness and temporal awareness: we estimate the trajectory of every target independently. This extreme choice allows us to devote most of the model’s parameters to a module that simultaneously learns (1) temporal priors and (2) an iterative inference mechanism that searches for the target pixel’s location in all input frames. Related work on optical flow estimation typically takes the opposite approach: it estimates the motion of all pixels simultaneously (using maximal spatial context), for just 2 frames at a time (using minimal temporal context).

Given the \((x_1, y_1)\) coordinate of the target on the first frame, our concrete goal is to estimate the target’s full trajectory over \(T\) frames: \((x_1, y_1), \ldots, (x_T, y_T)\).

We start by initializing a zero-velocity estimate. This means copying the initial coordinate to every timestep.

To track the target, we need to know what it looks like, so we compute appearance features for all the frames (using a CNN), and initialize the target’s appearance trajectory with a bilinear sample at the given coordinate on the first frame. (The “bilinear sample” step extracts a feature vector at a subpixel location in the spatial map of features.)
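The two initialization steps described above, zero-velocity coordinates and a bilinear feature sample, can be sketched as follows. This assumes features are stored as a (T, H, W, C) array with (x, y) given in pixel units of the feature map. It is plain bilinear interpolation rather than the released PIPs code, and the helper names `bilinear_sample` and `init_trajectory` are hypothetical.

```python
import numpy as np

def bilinear_sample(fmap, x, y):
    """Interpolate an (H, W, C) feature map at subpixel location (x, y)."""
    H, W, _ = fmap.shape
    x = float(np.clip(x, 0, W - 1))
    y = float(np.clip(y, 0, H - 1))
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    x1, y1 = min(x0 + 1, W - 1), min(y0 + 1, H - 1)
    wx, wy = x - x0, y - y0                          # fractional offsets in [0, 1)
    top = (1 - wx) * fmap[y0, x0] + wx * fmap[y0, x1]
    bottom = (1 - wx) * fmap[y1, x0] + wx * fmap[y1, x1]
    return (1 - wy) * top + wy * bottom

def init_trajectory(feats, xy1):
    """Zero-velocity initialization for a (T, H, W, C) feature video."""
    T = feats.shape[0]
    coords = np.tile(np.asarray(xy1, dtype=float), (T, 1))   # copy (x1, y1) to every timestep
    f1 = bilinear_sample(feats[0], *xy1)                    # target appearance on frame 1
    appearance = np.tile(f1, (T, 1))                        # same feature at every timestep
    return coords, appearance

# A feature map whose single channel equals x + 10*y, so bilinear sampling is exact.
ys, xs = np.mgrid[0:4, 0:5]
fmap = (xs + 10 * ys)[..., None].astype(float)
feats = np.stack([fmap] * 6)                                # T = 6 identical frames
coords, appearance = init_trajectory(feats, (1.5, 2.25))
print(bilinear_sample(fmap, 1.5, 2.25))
```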

Our inference process will proceed by iteratively refining the sequence of positions, and sequence of appearance features, until they (hopefully) match the true trajectory of the target. This idea is illustrated in the video below: at initialization, only “timestep 1” has the correct location of the target (since this was given), and gradually, the model “locks on” to the target in all frames.

Our model’s job is to produce updates (i.e., deltas) for the positions and features, so that the trajectory tracks the target on every frame. A critical detail here is that we ask our model to produce these updates for multiple timesteps simultaneously. This allows us to “catch” a target after it re-emerges from an occluder, and “fill in” the missing part of the trajectory.
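The refinement loop can be sketched as below. Here `update_fn` is a purely hypothetical stand-in for the learned model; the property it illustrates is that each call returns deltas for all T timesteps at once, so frames after an occlusion can be corrected jointly. Keeping the first-frame position fixed and the default number of iterations are our assumptions, not details taken from the paper.

```python
import numpy as np

def refine_trajectory(coords, appearance, update_fn, num_iters=6):
    """Iteratively refine a (T, 2) trajectory and a (T, C) appearance sequence.

    update_fn(coords, appearance) -> (d_coords, d_appearance) is a hypothetical
    stand-in for the learned model; it predicts deltas for every timestep jointly.
    num_iters=6 is an arbitrary illustrative default.
    """
    coords = coords.astype(float).copy()
    appearance = appearance.astype(float).copy()
    for _ in range(num_iters):
        d_coords, d_appearance = update_fn(coords, appearance)
        d_coords = np.array(d_coords, dtype=float)
        d_coords[0] = 0.0            # assumption: the given first-frame location stays fixed
        coords += d_coords           # all timesteps are updated simultaneously
        appearance += d_appearance
    return coords, appearance

# Toy "model": move every estimate halfway toward a known ground-truth path.
gt = np.array([[10.0, 10.0], [12.0, 11.0], [14.0, 12.0], [16.0, 13.0]])
init = np.tile(gt[0], (len(gt), 1))          # zero-velocity initialization

def toy_update(c, a):
    return 0.5 * (gt - c), np.zeros_like(a)

out, _ = refine_trajectory(init, np.zeros((4, 8)), toy_update, num_iters=20)
print(np.round(out, 3))
```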

There are many ways to implement this, but for fast training and good generalization, we need to carefully select what information we provide to the model.

The main source of information we provide to the model is: measurements of local appearance similarity. We obtain these measurements using cross correlation (i.e., dot products), computed at multiple scales. When the target is visible, it should show up as a strong peak in at least one of the similarity maps. Also, when we are “locked on”, the peak should be in the middle of the map. This is illustrated in the animation below.
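A minimal sketch of these similarity measurements follows. It assumes (H, W, C) feature maps, builds coarser scales by 2x average pooling, and uses nearest-pixel lookup instead of subpixel sampling for brevity. The window radius, number of levels, and function names are illustrative choices, not the values used in PIPs.

```python
import numpy as np

def local_correlation(fmap, target_feat, xy, radius=3):
    """Dot products between target_feat and every feature in a (2r+1)x(2r+1) window around xy."""
    H, W, _ = fmap.shape
    cx, cy = int(round(xy[0])), int(round(xy[1]))
    corr = np.zeros((2 * radius + 1, 2 * radius + 1))
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            x = min(max(cx + dx, 0), W - 1)          # clamp to the image border
            y = min(max(cy + dy, 0), H - 1)
            corr[dy + radius, dx + radius] = fmap[y, x] @ target_feat
    return corr

def avg_pool2(fmap):
    """2x2 average pooling of an (H, W, C) map (an odd trailing row/column is dropped)."""
    H, W, C = fmap.shape
    h, w = H // 2, W // 2
    return fmap[:2 * h, :2 * w].reshape(h, 2, w, 2, C).mean(axis=(1, 3))

def multiscale_correlation(fmap, target_feat, xy, levels=3, radius=3):
    """Stack of local correlation maps, one per scale: shape (levels, 2r+1, 2r+1)."""
    maps = []
    for level in range(levels):
        scale = 2 ** level
        maps.append(local_correlation(fmap, target_feat, (xy[0] / scale, xy[1] / scale), radius))
        fmap = avg_pool2(fmap)
    return np.stack(maps)

rng = np.random.default_rng(0)
fmap = 0.1 * rng.normal(size=(16, 16, 8))
target = np.ones(8)
fmap[6, 9] = target                                  # the target sits at (x, y) = (9, 6)
corr = multiscale_correlation(fmap, target, xy=(9, 6))
print(corr.shape, np.unravel_index(corr[0].argmax(), corr[0].shape))
```

When the estimate is "locked on" at (9, 6), the finest map peaks at its center; querying at (7, 6) instead moves the peak two columns to the right, which is the kind of signal an update model can turn into a position delta.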

The second source of information we provide to the model is: the estimated trajectory itself. This allows the model to impose a temporal prior, and fix up parts of the trajectory where the local similarity information was ambiguous.

Finally, we allow the model to inspect the feature vector of the target, in case it might learn different strategies for different types of features. For example, depending on the scale or texture of the target, it may adjust the way it uses information from the multi-scale similarity maps.

For the model architecture, we elected to use an MLP-Mixer, which we found to have a good trade-off between model capacity, training time, and generalization. We also tried convolutional models and transformers, but the convolutional models could not fit the data as well as the MLP-Mixer, and the transformers took too long to train.
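For reference, a single MLP-Mixer block alternates an MLP that mixes information across tokens with one that mixes across channels. The NumPy sketch below assumes one token per timestep of a trajectory, so token mixing is where information flows along time. The sizes, random weights, and omission of learned layer-norm parameters are simplifications, not the PIPs configuration.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last axis (no learned scale/shift, for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def mlp(x, w1, w2):
    return gelu(x @ w1) @ w2

def mixer_block(x, token_w, channel_w):
    """x: (T, C) with one token per timestep (assumed layout). Returns the same shape."""
    # Token mixing: an MLP applied along the T axis, independently per channel.
    x = x + mlp(layer_norm(x).T, *token_w).T
    # Channel mixing: an MLP applied along the C axis, independently per timestep.
    x = x + mlp(layer_norm(x), *channel_w)
    return x

T, C, hidden = 8, 16, 32                          # illustrative sizes
rng = np.random.default_rng(0)
token_w = (0.1 * rng.normal(size=(T, hidden)), 0.1 * rng.normal(size=(hidden, T)))
channel_w = (0.1 * rng.normal(size=(C, hidden)), 0.1 * rng.normal(size=(hidden, C)))
x = rng.normal(size=(T, C))
y = mixer_block(x, token_w, channel_w)
print(y.shape)
```

Because token mixing connects every timestep to every other, changing the input at the first timestep also changes the output at the last one, which is what lets a trajectory-level model apply a temporal prior.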

We trained the model on synthetic data that we made (based on an existing optical flow dataset), where we could provide multi-frame ground truth for targets that undergo occlusions. The animation below shows the kind of data we trained on. You might say the data looks crazy — but that’s the point! If you can’t get real data, your best bet is synthetic data with extremely high diversity.

FlyingThings++: We train our model to track objects in this data, so that real videos are easy in comparison.

After training for a couple days on this data, the model starts to work on real videos.


In the paper, we provide some quantitative analysis, showing that it works better than existing methods. It’s certainly not perfect — on keypoint tracking, the model works about 6/10 times, so there is a lot of room for improvement. The baselines are at around 5/10 or less.

The idea here is: pick a point on the first frame, and try to locate that same point in other frames of the video. This is a hard task especially when the point gets occluded.

Output of our PIPs model.

Baseline methods tend to get stuck on occluders, since they do not use multi-frame temporal context. For example, here is the output of a state-of-the-art optical flow method, on the same video and same target.

Output of an optical flow model (RAFT).

We have had fun trying the model on various videos, and observing the estimated trajectories. Sometimes they are surprisingly complex, since the target’s actual motion is subtly entangled with camera motion.

Visualizing the trajectories more densely gives mesmerizing results. Notice that the model even tracks ripples and specularities in the water.

Despite the fact that each particle trajectory is estimated independently of the others, they show surprisingly accurate grouping. Notice that the background particles all move together.

What’s next?

Our method upgrades pixels into PIPs: “Persistent Independent Particles.” This independence assumption, however, is probably not what we want in general. In ongoing work, we are trying to incorporate context across particles, so that confident particles can help the unconfident ones, and so that we track at multiple levels of granularity simultaneously.

We have released our code and model weights on GitHub. We encourage you to try it out! If you are interested in building on our method, the provided tests, visualizations, and training scripts should make that easy. Or, if you are working on a video-based method that currently relies on optical flow, you may want to try our PIP trajectories as a replacement, which should give a better signal under occlusions.

We hope that our work opens up long-range fine-grained tracking of “anything.” For more details, please check out our project page and paper.


Long-term Dynamics of Fairness Intervention in Connection Recommender Systems

Figure 1: Modeled recommendation cycle. We demonstrate how enforcing group fairness in every recommendation slate separately does not necessarily promote equity in second-order variables of interest like network size.

Connection recommendation is at the heart of user experience in many online social networks. Given a prompt such as ‘People you may know’, connection recommender systems suggest a list of users, and the recipient of the recommendation decides which of the users to connect with. In some instances, connection recommendations can account for more than 50% of the social network graph [1]. Depending on the platform, being connected to the right people is tied to important advantages such as job opportunities or increased visibility. While this makes it imperative to treat users fairly, it is far from obvious how fairness can be enforced or what it even means to have a ‘fair’ system in this scenario. In fact, when enforcing fairness in dynamic systems like this, we are often interested in second-order variables while interventions target equity in single steps [2]. In connection recommendation, this generally means enforcing some sort of parity condition in each recommendation slate, e.g. equal exposure of female and male users for every query, while simultaneously hoping that this will promote more equitable network sizes in the long run. In our work, we demonstrate how this approach can fail and common statistical notions of recommendation fairness do not ensure equity of network sizes in the long run. In fact, we see that fairness interventions are not even sufficient to mitigate the amplification of existing biases. We reach these conclusions based on an extensive simulation study supplemented by theoretical limit analyses. In the following, we focus on a discussion of empirical results and refer to the full paper for theoretical derivations.

Recommendation procedure

The assumed recommendation cycle is depicted in Figure 1. We briefly summarize the involved steps and give more detail on the most important components in the following sections. On the left in the flowchart, a user queries a connection recommendation, for example, by loading the ‘People You May Know’ page on the platform’s website. The system then computes relevance scores between the user who is seeking recommendations, whom we will refer to as the source user, and other previously unconnected users, who will be referred to as the destination users. Based on these relevance scores, we derive a ranking of destination users subject to fairness constraints and display a list of possible connections to the source user. After new connections have been made, the cycle repeats.

Fairness constrained probabilistic ranking framework

For each recommendation query, the system obtains a set of relevance scores and derives a probabilistic ranking. This probabilistic ranking takes the form of a matrix \(P_s^q\), where the entry \(P_s^q(d,r)\) in the \(d\)th row and \(r\)th column denotes the probability of displaying member \(d\) in slot \(r\) of the recommendation list. The probabilities are selected to maximize the expected utility for the source user, which is common practice in the recommendation literature [3].

More formally, we denote the source member as \(s\) and a destination member as \(d\). Let \(q\) denote a query for an ordered list of \(m\) destination members, \(D_q\) the number of candidate destination members for that query, and \(u_{s,d}^q\) the relevance score of \(s\) and \(d\) under query \(q\). We model a form of position bias by discounting the expected attention a recommended user receives based on how far down the list the recommendation appears, i.e., we define the exposure of slot \(r\) as \(v_r = 1/\log(r+1)\). Given these quantities, the optimization problem for query \(q\) can be written as

$$\max_{P_s^q}\ \sum_{d=1}^{D_q}u_{s,d}^q\sum_{r=1}^{m}P_s^q(d,r)\,v_r$$

$$\text{s.t. } \sum_{r=1}^{m}P_s^q(d,r)\leq 1 \text{ for all } d\in[D_q],$$

$$\sum_{d=1}^{D_q}P_s^q(d,r)=1\text{ for all }r\in[m],$$

$$0\leq P_s^q(d,r)\leq 1\text{ for all }d\in[D_q],\ r\in[m].$$

The constraints in this problem ensure that each slot has a recommended user and each user can be recommended at most once in a given list. We explore two types of fairness constraints to add to this problem. First, we consider demographic parity of exposure, which is a commonly suggested form of fairness in recommendations. It requires that groups receive expected exposure proportional to their population shares, i.e., for groups \(G_0\) and \(G_1\),

$$\frac{1}{\vert G_0\vert}\sum_{d\in G_0}\sum_{r=1}^m P_s^q(d,r)\,v_r=\frac{1}{\vert G_1\vert}\sum_{d\in G_1}\sum_{r=1}^m P_s^q(d,r)\,v_r.$$

Assume a setting with no position bias (i.e., \(v_r = 1\) for every slot), a population split into a 60% majority group and a 40% minority group, and recommendation lists with 10 slots. In this setting, demographic parity of exposure requires that, in expectation, 6 of the recommended members belong to the majority group and 4 belong to the minority group in every recommendation list.

Second, we explore a dynamic parity of utility constraint. Unlike the first constraint, this fairness measure considers not only the exposure of groups but also the total utility that can be expected from that exposure, i.e., for two groups \(G_0\) and \(G_1\),

$$\frac{1}{\vert G_0\vert}\sum_{d\in G_0}u_{s,d}^q\sum_{r=1}^m P_s^q(d,r)\,v_r=\frac{1}{\vert G_1\vert}\sum_{d\in G_1}u_{s,d}^q\sum_{r=1}^m P_s^q(d,r)\,v_r.$$

At a high level, the constraint requires that the exposure-weighted sum of relevance scores each group receives is proportional to that group’s population share, where exposure is discounted by position bias. To understand what this means, let us consider the example setting from before with no position bias, a 60%/40% group split, and recommendation lists with 10 slots. Assume that all destination members of the majority group have a relevance score of 0.12, while all destination members of the minority group have a relevance score of 0.08. The probabilistic ranking fulfills dynamic parity of utility if the expected group split in the recommendation list is 50%/50%, because \(0.5 \times 10 \times 0.12 = 0.6\) and \(0.5 \times 10 \times 0.08 = 0.4\), which matches the 60%/40% population split.
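Both worked examples can be checked numerically. The sketch below computes the per-capita group exposure and per-capita group utility that appear on the two sides of each constraint, for a given ranking matrix. It assumes a candidate pool of 60 majority and 40 minority members and uses deterministic slates (0/1 entries) for simplicity; the helper names are hypothetical, and the base-2 logarithm in the position-bias weights is our assumption, since the base is not specified above.

```python
import numpy as np

def exposure_weights(m, position_bias=True):
    # v_r = 1 / log(r + 1) for slots r = 1..m (base 2 assumed); all ones without position bias.
    r = np.arange(1, m + 1)
    return 1.0 / np.log2(r + 1) if position_bias else np.ones(m)

def group_sides(P, groups, v, u=None):
    """Per-capita left- and right-hand sides of the parity constraints.

    P: (D, m) ranking matrix, P[d, r] = probability that member d is shown in slot r.
    groups: length-D array of 0/1 group labels.
    u: optional length-D relevance scores (dynamic parity of utility); None gives exposure.
    """
    exposure = P @ v                                   # expected exposure of each member
    weight = exposure if u is None else u * exposure
    return tuple(float(weight[groups == g].sum() / (groups == g).sum()) for g in (0, 1))

def slate(n_major, n_minor, D0=60, D1=40, m=10):
    """Hypothetical deterministic ranking: n_major majority members, then n_minor minority members."""
    P = np.zeros((D0 + D1, m))
    for r in range(n_major):
        P[r, r] = 1.0
    for k in range(n_minor):
        P[D0 + k, n_major + k] = 1.0
    return P

groups = np.array([0] * 60 + [1] * 40)
u = np.where(groups == 0, 0.12, 0.08)
v = exposure_weights(10, position_bias=False)

print(group_sides(slate(6, 4), groups, v))       # demographic parity of exposure: equal sides
print(group_sides(slate(5, 5), groups, v, u))    # dynamic parity of utility: equal sides
```

The 6/4 slate that satisfies demographic parity of exposure violates dynamic parity of utility (per-capita utilities of 0.012 vs. 0.008), so the two constraints generally select different rankings.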

How do we model relevance scores?

Relevance scores usually model the probability of connection if recommended, a measure of downstream engagement, or some mixture of the two. The exact models employed by social media platforms are usually proprietary, so we opt for a synthetic model in the form of a logistic regression on three realistic features [4,5]. Assuming our prediction target is the probability of connection, we first use the source member’s network size. This is based on the assumption that users with larger networks are more likely to be proactive in forming connections, as they are generally more active on the platform. Next, we assume that users with more common connections are more likely to connect. The common connections feature is rooted in the social network literature, where it is commonly referred to as triadic closure [e.g. 6,7]. Lastly, we assume that users with similarities such as demographics, interests, education, or workplace are more likely to connect. This follows the observation that individuals like to be connected to similar individuals, a tendency commonly referred to as homophily in sociology and other social sciences [e.g. 8].
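A minimal sketch of such a scoring model follows. It assumes an undirected 0/1 adjacency matrix `A` and a covariate matrix `X`, counts common connections via `A @ A`, measures similarity as negative Euclidean distance (as in the simulation procedure described below), and enters network size through a log transform. The coefficient values and the log transform are hypothetical; the ground-truth parameters actually used are in the paper.

```python
import numpy as np

def relevance_scores(A, X, s, beta=(-4.0, 0.3, 0.5, 1.0)):
    """Logistic-regression relevance of every member d for source member s.

    A: (N, N) symmetric 0/1 adjacency matrix; X: (N, k) member covariates.
    beta = (intercept, network size, common connections, similarity): hypothetical values.
    Scores for s itself and its existing neighbors should be masked out before ranking.
    """
    b0, b_size, b_common, b_sim = beta
    network_size = A[s].sum()                          # source member's degree
    common = A[s] @ A                                  # number of common connections with each d
    similarity = -np.linalg.norm(X - X[s], axis=1)     # negative Euclidean distance (homophily)
    logits = b0 + b_size * np.log1p(network_size) + b_common * common + b_sim * similarity
    return 1.0 / (1.0 + np.exp(-logits))               # estimated probability of connection

# Square graph 0-1, 0-2, 1-3, 2-3; member 2 is far from member 0 in covariate space.
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]])
X = np.array([[0.0, 0.0], [0.0, 0.0], [3.0, 4.0], [0.0, 0.0]])
scores = relevance_scores(A, X, s=0)
print(np.round(scores, 4))
```

Member 3 shares two connections with member 0 and therefore scores above member 1, while member 2 scores lowest because of its covariate distance.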

The main simulation procedure uses the relevance scoring model for user pairs as follows. We assume a fixed-size graph of evolving connections with two groups of members, where 65% of members belong to an initially more connected majority group. First, we define a ground-truth model for matching scores by imposing ground-truth parameter values on the logistic regression function presented above. We use this function to simulate a data set of recommendations and formed connections, and we use the synthetic data to train a logistic regression matching-score model.

Simulation procedure

We sample covariate vectors from group-dependent distributions for each member in the graph. These vectors will be used to compute the similarity between members which we set as negative Euclidean distance. The connections in the graph are initialized with a stochastic block model in which user pairs in the majority group are slightly more likely to connect than user pairs in the minority group and user pairs who belong to the same group are slightly more likely to connect than user pairs who belong to opposite groups. This models the initial advantage for majority group members we would expect to see in practice and accounts for homophily preferences. Given the initial specifications, we go through the previously described recommendation cycle for 2,500 timesteps and keep track of key fairness and performance metrics. Following our reasoning from the scoring model, the frequency with which users seek out recommendations is based on exponential waiting times with a mean depending on the current network size of a user. The whole simulation procedure is repeated with each fairness intervention separately and without intervention for comparison. Results are averaged over 10 runs. 
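The graph initialization and the activity model can be sketched as below. The 65%/35% split is taken from the text; the block probabilities (chosen only to respect the stated ordering: majority pairs > minority pairs > cross-group pairs) and the form of the waiting-time mean (shrinking as network size grows, so better-connected members seek recommendations more often) are illustrative assumptions, not the paper’s parameter values.

```python
import numpy as np

def init_sbm(n, majority_share, p_maj, p_min, p_cross, rng):
    """Stochastic block model: edge probability depends only on the two members' groups."""
    n_major = int(round(majority_share * n))
    groups = np.array([0] * n_major + [1] * (n - n_major))      # 0 = majority, 1 = minority
    same = groups[:, None] == groups[None, :]
    within = np.where(groups[:, None] == 0, p_maj, p_min)       # within-group probability
    probs = np.where(same, within, p_cross)
    upper = np.triu(rng.random((n, n)) < probs, k=1)            # sample each pair once
    A = (upper | upper.T).astype(int)                           # undirected, no self-loops
    return A, groups

def next_request_times(A, now, base_mean, rng):
    """Exponential waiting times whose mean shrinks with network size (assumed form)."""
    degree = A.sum(axis=1)
    mean_wait = base_mean / (1.0 + degree)
    return now + rng.exponential(mean_wait)

rng = np.random.default_rng(0)
A, groups = init_sbm(200, 0.65, p_maj=0.06, p_min=0.05, p_cross=0.03, rng=rng)
times = next_request_times(A, now=0.0, base_mean=50.0, rng=rng)
print(A[groups == 0].sum(axis=1).mean(), A[groups == 1].sum(axis=1).mean())
```

In a full simulation, the member with the earliest request time would query a recommendation next, and their waiting time would be redrawn after the graph is updated.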

Figure 2: Results of the simulation study averaged over 10 runs. Panel (a) shows the absolute difference in average network sizes between groups over the simulation time period. Panel (b) shows the share of network degrees that belong to the majority group. A connection essentially translates into two degrees, one for the source member of the connection and one for the destination member. In panel (c), we see the share of the majority group on the destination members of new connections, and in panel (d), we see the share of the majority group among all new degrees.

Rich-gets-richer in groups

Figure 2 depicts the results for all three settings: no intervention and the two fairness interventions. With no intervention, which is displayed in red in the graphs, we observe that the gap in average network sizes between groups drastically increases over time. In fact, our results reveal a group-wise rich-get-richer effect in which network sizes tend toward a power-law distribution with a lower mode in the minority population. Members whose networks are in the tail of the distribution tend to be the members who had large networks, as compared to their peers, to begin with. These observations are in line with previous findings on rich-get-richer phenomena under homophily.

Demographic parity of exposure intervention

We now move to the results for the demographic parity of exposure intervention. We can see that the majority group share of exposure in recommendations is down to 66.3%, which suggests that the intervention is working. However, as panels (a) and (b) show, the gap in average network sizes still increases over time. Why is this the case? First, majority group members seek out recommendations more frequently because of their already larger network sizes and activity levels. This leads to more connections formed with majority group members on the source side of the recommendation, while our intervention can only target the destination side. Second, majority group members have higher ranking scores and are thus more likely to be invited to connect even if both groups are exposed equally. Consider the following example recommendation:

Here, both female and male users have the same exposure – we will ignore position bias for this example. However, the probability of connection and its proxy – the matching score – are much lower for women than for men essentially leading to fewer new connections for women than for men. While the intervention suggests recommendations are made fairly, this does not align with our intuitive goal of more equitable network sizes.

Dynamic parity of utility intervention

Some of the problems with the demographic parity of exposure intervention are addressed by the dynamic parity of utility intervention. We see that the majority group share of destination members of new connections is down to 65.4% as desired. Yet, even with this type of constraint, panels (a) and (b) show that the gap in average network sizes is still increasing over time. Like in the previous case, one of the reasons for this is that majority group members seek out recommendations more frequently. In addition, our analysis reveals that source members from the majority group generally form more connections per recommendation query leading to the increased share in panel (d). To understand why this is the case, consider the following example:

In the first row, a majority group member – here displayed as male – receives recommendations. In the second row, a minority group member – here female – receives recommendations. In both recommendation lists the dynamic parity of utility constraint is fulfilled, but the male source member receives more overall utility from the query because the constraint can only target fairness within a list and not in between different sets of recommendations.

Summary of findings

Let us summarize the key findings. Our study shows that unconstrained connection recommendation leads to a group-wise rich-get-richer effect. Enforcing demographic parity of exposure or dynamic parity of utility between groups, which are commonly suggested remedies against demographic bias in recommender systems, leads to less bias amplification but is not sufficient to prevent the disparities in network sizes from growing over time. The theoretical limit analysis in the full paper shows that dynamic parity of utility would be the optimal intervention if there were no source-side bias. Yet, this is an unrealistic assumption in practice.

Overall, the common practice of measuring fairness in recommender systems in a one-shot or time-aggregate static manner can lead to an illusion of fairness and deployment of fairness-enhancing algorithms with unforeseen consequences. Connection recommendation operates on a dynamical system that needs to be taken into account to ensure equitable outcomes in the long run.

The full paper is published in the proceedings of the AAAI/ACM Conference on Artificial Intelligence, Ethics, and Society (AIES 2022). A preprint is also available.


[1] LinkedIn PYMK: [Online; accessed 7/6/22]

[2] Lydia T. Liu, Sarah Dean, Esther Rolf, Max Simchowitz, and Moritz Hardt. 2018. Delayed impact of fair machine learning. In Proceedings of the 35th International Conference on Machine Learning (ICML 2018).

[3] Deepak K. Agarwal and Bee-Chung Chen. 2016. Statistical Methods for Recommender Systems. Cambridge University Press.

[4] LinkedIn PYMK: [Online; accessed 7/6/22]

[5] Facebook PYMK: [Online; accessed 7/6/22]

[6] Gueorgi Kossinets and Duncan J. Watts. 2006. Empirical analysis of an evolving social network. Science 311, 5757 (2006), 88–90.

[7] David Liben-Nowell and Jon Kleinberg. 2007. The link-prediction problem for social networks. J. Am. Soc. Inf. Sci. Technol. 58, 7 (May 2007), 1019–1031.

[8] Miller McPherson, Lynn Smith-Lovin, and James M. Cook. 2001. Birds of a feather: Homophily in social networks. Annual Review of Sociology 27, 1 (2001), 415–444.


Recurrent Model-Free RL Can Be a Strong Baseline for Many POMDPs

Figure 1. Our implementation of recurrent model-free RL outperforms the on-policy version (PPO/A2C-GRU), and a recent model-based POMDP algorithm (VRM) on most tasks of a POMDP benchmark where VRM was evaluated in their paper.

While algorithms for decision-making typically focus on relatively easy problems where everything is known, most realistic problems involve noise and incomplete information. Complex algorithms have been proposed to tackle these complex problems, but there’s a simple approach that (in theory) works on both the easy and the complex problems. We show how to make this simple approach work in practice.


Decision-making tasks in the real world are messy, with noise, occlusions, and uncertainty that are typically missing from their canonical problem formulation as a Markov decision process (MDP; Bellman, 1957). In contrast, Partially Observable MDPs (POMDPs; Åström, 1965) can capture the uncertainty in the states, rewards, and dynamics. Such uncertainty arises in applications such as robotics, healthcare, NLP and finance.

Apart from being realistic, POMDPs are a general framework that contains many subareas in RL, including:

  • Meta-RL
  • Robust RL
  • Generalization in RL
  • Temporal credit assignment

What is recurrent model-free RL?

Solving POMDPs is hard because the agent needs to learn two tasks simultaneously: inference and control. Inference aims to infer the posterior over current states conditioned on history. Control aims to perform RL / planning algorithms on the inferred state space. While prior methods typically decouple the two jobs with separate models, deep learning provides us with a general and simple baseline: combine an off-the-shelf RL algorithm with a recurrent neural network (RNN; e.g., LSTM (Hochreiter & Schmidhuber, 1997) and GRU (Chung et al., 2014)).

By backpropagating the gradients from policy loss, RNNs make it possible to process sequences (histories in POMDPs) and learn implicit inference on the state space for control. We refer to it as recurrent model-free RL. Fig. 2 presents our design, where actor and critic networks each have an RNN as the history encoder.

Figure 2. Our recurrent actor and critic architecture. Each network contains an RNN as the sequence encoder. The recurrent actor takes as input current observation, previous action and reward (optional), and outputs current action. The recurrent critic also takes current action as an input and outputs Q-values.
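A minimal NumPy sketch of the architecture in Fig. 2 follows, assuming a GRU encoder with randomly initialized weights. The actor and critic each own a separate GRU (no shared weights); the actor consumes (observation, previous action, previous reward) and outputs an action, and the critic additionally consumes the current action and outputs a Q-value. Layer sizes, tanh action squashing, and all function names are hypothetical, and training (e.g., with an off-policy algorithm such as TD3) is omitted.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def init_gru(in_dim, hid, rng):
    # One weight matrix per gate (update z, reset r, candidate n), each acting on [x, h].
    return {k: 0.1 * rng.normal(size=(in_dim + hid, hid)) for k in ("z", "r", "n")}

def gru_step(p, x, h):
    xh = np.concatenate([x, h])
    z = sigmoid(xh @ p["z"])                              # update gate
    r = sigmoid(xh @ p["r"])                              # reset gate
    n = np.tanh(np.concatenate([x, r * h]) @ p["n"])      # candidate state
    return (1 - z) * n + z * h

def encode(p, inputs, hid):
    """Run a GRU over a history of input vectors and return the final hidden state."""
    h = np.zeros(hid)
    for x in inputs:
        h = gru_step(p, x, h)
    return h

obs_dim, act_dim, hid = 4, 2, 16                          # illustrative sizes
rng = np.random.default_rng(0)
actor_gru = init_gru(obs_dim + act_dim + 1, hid, rng)     # separate encoder for the actor
critic_gru = init_gru(obs_dim + act_dim + 1, hid, rng)    # and a separate one for the critic
actor_head = 0.1 * rng.normal(size=(hid, act_dim))
critic_head = 0.1 * rng.normal(size=hid + act_dim)

def actor(history):
    """history: list of (obs_t, prev_action, prev_reward). Returns the current action."""
    inputs = [np.concatenate([o, a, [r]]) for o, a, r in history]
    return np.tanh(encode(actor_gru, inputs, hid) @ actor_head)

def critic(history, action):
    """Q-value of taking `action` after `history`."""
    inputs = [np.concatenate([o, a, [r]]) for o, a, r in history]
    h = encode(critic_gru, inputs, hid)
    return float(np.concatenate([h, action]) @ critic_head)

history = [(rng.normal(size=obs_dim), np.zeros(act_dim), 0.0) for _ in range(5)]
a = actor(history)
q = critic(history, a)
print(a, q)
```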

Recurrent model-free RL has many merits at first glance:

  • Conceptually simple. The policy purely learns from rewards, without extra objectives.
  • Easy to implement. Practitioners can just change several lines of code from model-free RL.
  • Expressive in theory. RNNs have been shown to be universal function approximators (Siegelmann & Sontag, 1995, Schäfer & Zimmermann, 2006), and thus recurrent model-free RL can (approximately) express any memory-based policy.

Due to its simplicity and expressivity, there is a rich literature (Schmidhuber, 1991, Bakker, 2001, Wierstra et al., 2007, Heess et al., 2015, Hausknecht & Stone, 2015) studying different RL algorithms and RNN architectures for recurrent model-free RL. However, prior work has shown that it often fails in practice with poor or unstable performance (Igl et al., 2018, Hung et al., 2018, Packer et al., 2018, Rakelly et al., 2019, Zintgraf et al., 2020, Han et al., 2020, Zhang et al., 2021, Raposo et al., 2021), with only a few exceptions (Yu et al., 2019, Fakoor et al., 2020).

Motivated by the poor performance of this simple baseline, prior work has proposed more sophisticated methods. Some introduce model-based objectives that explicitly learn inference, while others incorporate the assumptions used in a particular subarea of POMDPs as inductive bias. Both achieve good results on their respective tasks, although the model-based methods may suffer from stale belief states stored in the replay buffer, and the specialized methods require more assumptions than recurrent model-free RL (e.g., meta-RL methods normally assume that the hidden variable is constant within a single episode).

How to train recurrent model-free RL?

In this work, we found that recurrent model-free RL does not fail fundamentally; it just needs to be implemented differently. The key differences are:

  1. Separating the RNNs in actor and critic networks. Un-sharing the weights can prevent gradient explosion, and can be the difference between the algorithm learning nothing and solving the task almost perfectly.
  2. Using an off-policy RL algorithm to improve sample efficiency. Using, say, TD3 instead of PPO greatly improves sample efficiency.
  3. Tuning the RNN context length. We found that the RNN architecture (LSTM vs. GRU) does not matter much, but the RNN context length (the length of the sequence fed into the RL algorithm) is crucial and task-dependent. We suggest starting with a medium length.
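One common way to feed histories of a chosen context length to an off-policy learner, sketched here under our own assumptions rather than as the authors’ implementation, is to store whole episodes and sample windows of `context_len` consecutive transitions, zero-padding episodes that are too short and returning a mask so padded steps can be excluded from the loss. The buffer layout and function name are hypothetical.

```python
import numpy as np

def sample_context_windows(episodes, batch_size, context_len, rng):
    """Sample (batch, context_len, dim) windows of consecutive transitions.

    episodes: list of (T_i, dim) arrays, one per stored episode (assumed layout).
    Returns (windows, mask); mask is 1.0 for real steps and 0.0 for padding.
    """
    dim = episodes[0].shape[1]
    windows = np.zeros((batch_size, context_len, dim))
    mask = np.zeros((batch_size, context_len))
    for b in range(batch_size):
        ep = episodes[rng.integers(len(episodes))]
        start = rng.integers(max(1, len(ep) - context_len + 1))   # window start inside the episode
        chunk = ep[start:start + context_len]
        windows[b, :len(chunk)] = chunk                          # short episodes are zero-padded
        mask[b, :len(chunk)] = 1.0
    return windows, mask

rng = np.random.default_rng(0)
episodes = [np.arange(20 * 3, dtype=float).reshape(20, 3),    # a 20-step episode
            np.ones((5, 3))]                                  # a 5-step episode
windows, mask = sample_context_windows(episodes, batch_size=4, context_len=8, rng=rng)
print(windows.shape, mask.sum(axis=1))
```

Because the best `context_len` is task-dependent, exposing it as a hyperparameter, as here, makes it easy to sweep.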

Properly tuned, the simple baseline outperforms alternatives on many POMDPs

With these changes, our implementation of recurrent model-free RL is at least on par with (if not much better than) prior methods, on the tasks those prior methods were designed to solve. While prior methods are typically designed to solve special cases of POMDPs, recurrent model-free RL applies to all types of POMDPs.

Our first comparison looks at meta-RL tasks, which are usually approached by methods that decouple the task inference and reward maximization steps (Rakelly et al., 2019, Zintgraf et al., 2020). Compared with these prototypical methods (on-policy variBAD (Zintgraf et al., 2020) and off-policy variBAD (Dorfman et al., 2020)), our recurrent model-free approach often performs at least on par. These results (Fig. 3, 4) suggest that disentangling inference and control may not be necessary in many tasks.

Figure 3. Two meta-RL environments from off-policy variBAD (left: Semi-Circle, right: Wind), where ours outperforms off-policy variBAD.
Figure 4. Two meta-RL environments from on-policy variBAD (left: Cheetah-Dir, right: Ant-Dir), where we have mixed results.

Next, we move on to the robust RL setting, which is mostly tackled by algorithms that explicitly maximize the worst-case return (Rajeswaran et al., 2017, Mankowitz et al., 2020). Comparing against a recent robust RL algorithm, MRPO (Jiang et al., 2021), we find that our recurrent model-free approach performs better on all the tasks. The results (Fig. 5) indicate that with implicit task inference by RNNs, we can improve both the average and the worst-case returns.

Figure 5. One robust RL environment from MRPO, Cheetah-Robust (left: average return; right: worst return), where ours outperforms MRPO.

Then we explore generalization in RL, where specialized methods include policy regularization (Farebrother et al., 2018) and data augmentation (Lee et al., 2020). Here we choose the popular SunBlaze benchmark (Packer et al., 2018), which provides a specialized baseline, EPOpt-PPO-FF (Rajeswaran et al., 2017). The results (Fig. 6) show that, despite not explicitly enhancing generalization, our recurrent model-free approach extrapolates better than the baseline. This suggests that the task inference learned by the RNN generalizes.

Figure 6. One RL generalization environment from SunBlaze, Hopper-Generalize (left: interpolated success rate; right: extrapolated success rate), where ours outperforms EPOpt-PPO-FF on extrapolation.

Finally, we study the temporal credit assignment domain, where existing methods usually involve reward decomposition or redistribution (Liu et al., 2019, Hung et al., 2018). Here, we choose a recent specialized method, IMPALA+SR (Raposo et al., 2021), and evaluate our method on its benchmark with pixel-based discrete control and sparse rewards. Despite being unaware of the reward structure and not performing credit assignment explicitly, our recurrent model-free approach performs better. The results (Fig. 7) indicate that recurrent policies can cope with sparse rewards effectively, perhaps better than previously expected.

Figure 7. Two temporal credit assignment environments from IMPALA+SR (left: Delayed-Catch, right: Key-to-Door), where ours outperforms IMPALA+SR.


In a sense, our finding can be interpreted as echoing the motivation for deep learning: we can achieve better results by reducing a method to a single differentiable architecture, optimized end-to-end with a single loss. This is exciting because RL systems often involve many interconnected parts (e.g., feature extraction, model learning, value estimation) trained with different objectives, and perhaps these could be replaced by end-to-end approaches equipped with sufficiently expressive architectures.

We have open-sourced the code on GitHub to support reproducibility and to help future work develop better POMDP algorithms. Please see our project site for the paper and our ICML 2022 presentation.


auton-survival: An Open-Source Package for Regression, Counterfactual Estimation, Evaluation and Phenotyping Censored Time-to-Event Data



Real-world decision-making often requires reasoning about when an event will occur. The overarching goal of such reasoning is to aid decision-making for optimal triage and subsequent intervention. Such time-to-event estimation problems arise frequently across many application areas, including:

Healthcare and Bio-informatics: Here the task, more commonly known as ‘survival analysis’, involves prognostication of an adverse physiological event such as a stroke, the onset of cancer, re-hospitalization, or mortality. Time-to-event or survival analysis can be used to proactively mitigate adverse outcomes and extend the longevity of patients.
Internet Marketing and e-commerce: Models employed for estimating customer churn and retention in large commercial organizations are essentially time-to-event regression models and help determine best practices to maximize customer retention.
Predictive Maintenance: Reliability engineering and systems safety research involves the use of remaining useful life prediction models to help extend the longevity of machinery and equipment by proactive part and component replacement.
Finance and Actuarial Sciences: Time-to-event models are ubiquitous in estimating optimal financial strategies, such as setting insurance premiums, and in estimating credit default behavior.

Figure 1: Patients A and C died 1 and 4 years following entry into the study, whereas mortality outcomes are missing for Patients B and D prior to their exit from the study at 3 and 2 years following entry. Time-to-Event or Survival Regression involves adjusting estimates for such individuals whose outcomes are censored.


Figure 1 illustrates a typical example of a time-to-event problem in healthcare. The challenge of working with time-to-event data is compounded by the fact that, as the figure shows, such data typically include individuals whose outcomes are unobserved, or ‘censored’, due either to loss to follow-up or to the end of the study.

Discretizing time-to-event outcomes into binary indicators of whether an event occurs (for example, by a fixed horizon) is a common approach in standard machine learning. However, this neglects temporal context, which can result in models that misestimate risk and generalize poorly.
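The information loss is easy to see on the patients in Figure 1. The sketch below uses a hypothetical helper (not part of auton-survival) that converts each record into a binary ‘event by horizon’ label, treating an individual censored at or after the horizon as event-free at the horizon:

```python
def binary_label_at_horizon(time, event, horizon):
    """Turn one (time, event) record into an 'event by the horizon' label.

    time: event time if event == 1, otherwise censoring time.
    Returns 1 or 0 when the label is determined, or None when censoring hides it.
    """
    if event == 1 and time <= horizon:
        return 1      # event observed by the horizon
    if time >= horizon:
        return 0      # known to be event-free at the horizon
    return None       # censored before the horizon: label unknown


# Figure 1 patients as (time in years, event observed?): A, B, C, D.
records = {"A": (1, 1), "B": (3, 0), "C": (4, 1), "D": (2, 0)}
labels = {name: binary_label_at_horizon(t, e, horizon=3) for name, (t, e) in records.items()}
# labels -> {'A': 1, 'B': 0, 'C': 0, 'D': None}
```

Patient D, censored at 2 years, has no well-defined label at a 3-year horizon: a binary classifier must either drop D, discarding two years of event-free follow-up, or label D as event-free, which may be wrong.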

The auton-survival Package

In our recent Machine Learning for Healthcare ’22 paper, we present auton-survival, a comprehensive Python code repository of user-friendly machine learning tools for working with censored time-to-event data. The package includes an extensive suite of workflows for tasks ranging from data pre-processing and regression modeling to model evaluation. auton-survival has an API similar to that of the scikit-learn package (Pedregosa et al., 2011), which makes it easy to adopt for users with machine learning experience in Python. Additionally, to promote usability and rapid prototyping of solutions by both machine learning and clinical researchers, we include detailed documentation as well as example notebooks.

Time-to-Event Regression

Time-to-event, or survival, regression can be used to estimate the conditional probability of an event occurring within a specified time period, or event horizon. A time-to-event estimation problem thus reduces to estimating the conditional distribution of survival:

\( \mathbb{E}[\mathbb{1}\{T > t\} \mid X = x] = \mathbb{P}(T > t \mid X = x) = 1 - \mathbb{P}(T \le t \mid X = x) \)

Here \( X \) is a set of covariates, and \( T \) denotes the observed, censored survival time \( T = \min(T^*, C) \), where \( T^* \) is the true time-to-event and \( C \) is the censoring time. Assuming that the true event time and the censoring time are conditionally independent given the covariates (i.e., \( T^* \perp C \mid X \)) allows identification of the distribution \( \mathbb{P}(T^* \mid X) \).
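A small simulation makes the notation concrete. Assuming, purely for illustration, exponentially distributed true event times \( T^* \) and independent exponential censoring times \( C \), the observed data are \( T = \min(T^*, C) \) and the event indicator \( \delta = \mathbb{1}\{T^* \le C\} \). Averaging the observed times without accounting for censoring underestimates the true mean time-to-event.

```python
import numpy as np


def simulate_censored_data(n, event_rate, censor_rate, seed=0):
    """Simulate right-censored data with independent exponential T* and C."""
    rng = np.random.default_rng(seed)
    t_star = rng.exponential(scale=1.0 / event_rate, size=n)  # true times-to-event T*
    c = rng.exponential(scale=1.0 / censor_rate, size=n)      # censoring times C
    t = np.minimum(t_star, c)                                 # observed time T = min(T*, C)
    delta = (t_star <= c).astype(int)                         # 1 = event observed, 0 = censored
    return t_star, c, t, delta


t_star, c, t, delta = simulate_censored_data(n=10_000, event_rate=0.5, censor_rate=0.5)
print(t_star.mean())  # about 2, the true mean 1 / 0.5
print(t.mean())       # about 1: min(T*, C) is exponential with rate 0.5 + 0.5
print(delta.mean())   # about 0.5: roughly half the individuals are censored
```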

Survival regression naturally accounts for censored data. Under censoring, the likelihood \( \ell \) is given by

\( \ell(\{x, t, \delta\}) \propto \mathbb{P}(T = t \mid X = x)^{\delta}\, \mathbb{P}(T > t \mid X = x)^{1-\delta} \).

Here \( x \in \mathbb{R}^d \) are the covariates, \( t \in \mathbb{R}^{+} \) is the event or censoring time, and \( \delta \in \{0, 1\} \) is a binary indicator that equals 1 if the event was observed and 0 if the individual was censored. For censored individuals, the likelihood corresponds to the probability that the event takes place beyond the censoring time \( t \), \( \mathbb{P}(T > t \mid X = x) \), also known as the ‘survival function’.
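As a worked instance of this likelihood, consider the simplest parametric case, an exponential model with constant rate \( \lambda \) and no covariates (an assumption chosen for illustration). Then \( \mathbb{P}(T = t) \) is the density \( \lambda e^{-\lambda t} \) and \( \mathbb{P}(T > t) = e^{-\lambda t} \), so the log-likelihood is \( \sum_i (\delta_i \log\lambda - \lambda t_i) \), which is maximized at \( \hat\lambda = \sum_i \delta_i / \sum_i t_i \). The sketch checks this numerically on the Figure 1 patients:

```python
import numpy as np


def exponential_censored_loglik(rate, t, delta):
    """Censored log-likelihood of an exponential model with a constant rate.

    Events (delta = 1) contribute the log-density log(rate) - rate * t;
    censored records (delta = 0) contribute the log-survival -rate * t.
    """
    t = np.asarray(t, dtype=float)
    delta = np.asarray(delta, dtype=float)
    return np.sum(delta * np.log(rate) - rate * t)


def exponential_rate_mle(t, delta):
    """Closed-form maximizer: observed events divided by total time at risk."""
    return np.sum(delta) / np.sum(t)


# Figure 1: A and C died at 1 and 4 years; B and D were censored at 3 and 2 years.
t = [1.0, 3.0, 4.0, 2.0]
delta = [1, 0, 1, 0]
rate_hat = exponential_rate_mle(t, delta)  # 2 events / 10 person-years = 0.2
grid = np.linspace(0.01, 1.0, 1000)
best_on_grid = grid[np.argmax([exponential_censored_loglik(r, t, delta) for r in grid])]
```

Dropping the censored patients B and D would instead give 2 / (1 + 4) = 0.4, overstating the event rate because their event-free follow-up is ignored.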

Broadly, the popular approaches for learning estimators of survival in the presence of censoring can be categorized into:

  • Parametric: Assume that the time-to-event distribution \( \mathbb{P}(T) \) follows a known parametric family, such as the Weibull or Log-Normal distribution.
  • Non-Parametric: Involve learning kernels or similarity functions of the input covariates followed by a non-parametric (Kaplan-Meier or Nelson-Aalen) estimation of the survival rate weighted with the learned kernel.
  • Semi-Parametric: As with Cox Proportional Hazards models, feature interactions are learned through a parametric model followed by a non-parametric estimation of the base survival (hazard) rate.
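The non-parametric building block mentioned above, the Kaplan-Meier estimator, fits in a few lines. This is a from-scratch sketch for illustration, not the auton-survival implementation: at each distinct event time \( t_k \) with \( d_k \) events and \( n_k \) individuals still at risk, the running survival estimate is multiplied by \( 1 - d_k / n_k \).

```python
import numpy as np


def kaplan_meier(t, delta):
    """Return distinct event times and the Kaplan-Meier survival estimate after each."""
    t = np.asarray(t, dtype=float)
    delta = np.asarray(delta, dtype=int)
    event_times = np.unique(t[delta == 1])
    survival = []
    s = 1.0
    for tk in event_times:
        n_at_risk = np.sum(t >= tk)                  # still under observation just before tk
        d_events = np.sum((t == tk) & (delta == 1))  # events occurring at tk
        s *= 1.0 - d_events / n_at_risk
        survival.append(s)
    return event_times, np.array(survival)


# Figure 1 example: events at 1 and 4 years, censoring at 2 and 3 years.
times, surv = kaplan_meier([1.0, 3.0, 4.0, 2.0], [1, 0, 1, 0])
# times -> [1., 4.]; surv -> [0.75, 0.]
```

Censored patients B and D still count in the at-risk denominators until they leave the study.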

Estimators of Survival

Figure 2: When proportional hazards assumptions are satisfied, survival curves and their corresponding hazard rates do not intersect. auton-survival includes flexible estimators of time-to-events in the presence of non-proportional hazards.

Complex multimodal data, often observed in healthcare and other applications, pose many challenges for traditional machine learning. auton-survival provides a simple interface for using deep neural networks and representation learning to model such complex data.

auton-survival includes extensions of the standard Cox Proportional Hazards (CPH) model (Cox, 1972) that involve deep representation learning (Faraggi and Simon, 1995; Katzman et al., 2018), as well as the latent variable survival regression models Deep Cox Mixtures (DCM) and Deep Survival Machines (DSM) (Nagpal et al., 2021a,b), which relax the strong proportional hazards assumption illustrated in Figure 2 by modeling the time-to-event distribution as a fixed-size mixture.

The SurvivalModel Class

The package provides a convenient SurvivalModel class that enables rapid experimentation via a consistent API that wraps multiple alternative regression estimators. In addition to the models mentioned above, the SurvivalModel class includes Random Survival Forests (RSF) (Ishwaran et al., 2008), which is a popular non-parametric survival model.

Hyperparameter tuning for model selection can be streamlined with the SurvivalRegressionCV class, which applies \( K \)-fold cross-validation over a user-specified hyperparameter grid.
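The idea behind \( K \)-fold cross-validation over a hyperparameter grid can be sketched without the package. In this toy version (our assumption, not the SurvivalRegressionCV implementation), the ‘model’ is an exponential rate with a Gamma prior whose hypothetical hyperparameters `a` (pseudo-events) and `b` (pseudo-time) give the posterior-mean rate \( (a + \sum\delta) / (b + \sum t) \); each grid setting is scored by its mean held-out censored log-likelihood.

```python
import itertools

import numpy as np


def fit_shrunk_rate(t, delta, a, b):
    """Posterior-mean exponential rate under a Gamma(a, b) prior; a > 0 keeps it positive."""
    return (a + np.sum(delta)) / (b + np.sum(t))


def heldout_loglik(rate, t, delta):
    """Censored exponential log-likelihood on held-out data (higher is better)."""
    return np.sum(delta * np.log(rate) - rate * t)


def kfold_grid_search(t, delta, grid, k=5, seed=0):
    """Return the grid setting with the best mean held-out log-likelihood over k folds."""
    t = np.asarray(t, dtype=float)
    delta = np.asarray(delta, dtype=float)
    folds = np.array_split(np.random.default_rng(seed).permutation(len(t)), k)
    keys = sorted(grid)
    results = {}
    for values in itertools.product(*(grid[key] for key in keys)):
        params = dict(zip(keys, values))
        scores = []
        for i in range(k):
            test_idx = folds[i]
            train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
            rate = fit_shrunk_rate(t[train_idx], delta[train_idx], **params)
            scores.append(heldout_loglik(rate, t[test_idx], delta[test_idx]))
        results[values] = np.mean(scores)
    best = max(results, key=results.get)
    return dict(zip(keys, best)), results


rng = np.random.default_rng(1)
t_star = rng.exponential(scale=2.0, size=500)
c = rng.exponential(scale=2.0, size=500)
t = np.minimum(t_star, c)
delta = (t_star <= c).astype(int)
best, results = kfold_grid_search(t, delta, grid={"a": [0.5, 5.0, 50.0], "b": [1.0, 10.0, 100.0]})
```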

Time-Varying Survival Regression

Real-world data often consist of multiple time-dependent observations per individual, or time-varying covariates. auton-survival handles time-varying covariates with auto-regressive deep learning models that learn temporal dependencies when estimating time-to-event outcomes. The time-varying DSM and Deep Cox Proportional Hazards implementations use RNNs, LSTMs, or GRUs (Chung et al., 2014; Hochreiter and Schmidhuber, 1997) for time-varying survival regression, as shown in Figure 3.

Figure 3: For an individual with time-to-event \( \mathcal{T}_i \), we observe covariates \( x_i^j \) at multiple time points \( t_i^j \). At each time-step \( j \), we estimate the distribution of the remaining time-to-event \( \mathcal{T}_i - t_i^j \). The representation \( \widetilde{x}_i^{j+1} \) at time-step \( j+1 \) is a function of the covariates \( x_i^{j+1} \) and the representation of the preceding time-step \( \widetilde{x}_i^{j} \).
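The recurrence in Figure 3 can be sketched as a plain (Elman) RNN in NumPy. This is an untrained, illustrative forward pass under our own assumptions: random weights, a tanh cell, and an exponential output head that yields one positive rate per visit for the remaining time-to-event. The actual time-varying DSM and Deep Cox models use trained RNN, LSTM, or GRU layers with their own output distributions.

```python
import numpy as np


def rnn_remaining_time_rates(x_seq, hidden_dim=8, seed=0):
    """Toy recurrent survival model: one positive rate per visit for the remaining time.

    x_seq: array of shape (num_visits, num_features), covariates x_i^j in time order.
    Weights are random (untrained); under an exponential assumption, the expected
    remaining time-to-event at visit j would be 1 / rate_j.
    """
    rng = np.random.default_rng(seed)
    num_features = x_seq.shape[1]
    w_in = rng.normal(scale=0.1, size=(hidden_dim, num_features))
    w_rec = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
    w_out = rng.normal(scale=0.1, size=hidden_dim)
    h = np.zeros(hidden_dim)  # representation before the first visit
    rates = []
    for x_j in x_seq:
        # New representation from the current covariates and the previous representation.
        h = np.tanh(w_in @ x_j + w_rec @ h)
        rates.append(np.exp(w_out @ h))  # exp keeps the rate positive
    return np.array(rates)


visits = np.array([[0.2, 1.0], [0.4, 0.8], [0.9, 0.3]])  # three visits, two covariates
rates = rnn_remaining_time_rates(visits)
```

Because the recurrence only looks backward, the estimate at each visit depends on that visit and earlier ones, never on future observations.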

Counterfactual Estimators of Survival

Decision support often requires reasoning about ‘what if’ scenarios regarding the effect of different treatments on outcomes. In observational settings, outcomes and treatment assignments may share common causes, and adjusting for such confounding factors is crucial when performing causal inference. auton-survival includes counterfactual survival regression as a tool for causal inference that accounts for confounding factors when estimating the effect of treatment on survival. Counterfactual survival regression involves fitting separate regression models on the treated and control populations and computing survival rates across treatment arms. Under the standard causal inference assumption of strong ignorability, the survival function under the intervention \( \text{do}(A = a) \) can then be estimated as

\( \hat{S}\big(t \mid \text{do}(A = a)\big) = \mathbb{E}_{X}\big[\hat{\mathbb{E}}[\mathbb{1}\{T > t\} \mid X = x, A = a]\big] \)

where \( \hat{\mathbb{E}}[\mathbb{1}\{T > t\} \mid X = x, A = a] \) is an estimate of the conditional expectation of survival, learned on the subpopulation that received \( A = a \).
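The estimator above can be sketched with a deliberately simple conditional model. Assuming, for illustration only, a single discrete confounder and an exponential survival model fitted separately within each treatment arm and confounder stratum, we fit the per-arm model, predict survival for every individual as if assigned to that arm, and average over the whole population, which is the \( \mathbb{E}_X \) in the formula. The helper names are hypothetical.

```python
import numpy as np


def fit_arm_rates(x, a, t, delta, arm):
    """Per-stratum exponential rates (events / time at risk), fit only on A == arm."""
    rates = {}
    for stratum in np.unique(x):
        sel = (a == arm) & (x == stratum)
        # Positivity: every stratum must contain individuals from this arm.
        rates[stratum] = delta[sel].sum() / t[sel].sum()
    return rates


def counterfactual_survival(x, a, t, delta, arm, horizon):
    """Estimate S(horizon | do(A = arm)) by standardizing over everyone's covariates."""
    x, a, t, delta = (np.asarray(v) for v in (x, a, t, delta))
    rates = fit_arm_rates(x, a, t, delta, arm)
    # Predict each individual's survival had they been in `arm`, then average over all X.
    per_person = np.array([np.exp(-rates[xi] * horizon) for xi in x])
    return per_person.mean()


# Synthetic data: x drives both arm membership and the outcome; the arm itself has no effect.
rng = np.random.default_rng(0)
n = 20_000
x = rng.integers(0, 2, size=n)
a = (rng.random(n) < np.where(x == 1, 0.8, 0.2)).astype(int)
t = rng.exponential(scale=np.where(x == 1, 1.0, 5.0))
delta = np.ones(n, dtype=int)  # no censoring, for simplicity
adjusted = [counterfactual_survival(x, a, t, delta, arm, horizon=1.0) for arm in (0, 1)]
naive = [np.mean(t[a == arm] > 1.0) for arm in (0, 1)]
```

Because the arm has no effect by construction, the naive per-arm survival fractions differ substantially while the adjusted estimates nearly coincide.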

Consider data from the large SEER cancer incidence registry (Ries et al., 2008). When stratified by region (Figure 4a), there is an apparent disparity in survival rates. We demonstrate how counterfactual regression can provide insight into whether these discrepancies can be attributed to the geographic region or to other socio-economic or physiological confounding factors that affect both membership in these regions and the outcomes. To adjust estimates of survival with counterfactual estimation, we train two separate Deep Cox models, on data from Greater California and from Louisiana, as counterfactual regressors. Each fitted regressor is then applied to every instance to estimate its survival curve, and these curves are averaged to compute the region-specific survival rate. Figure 4b compares the counterfactual survival rates with the survival rates obtained from a Kaplan-Meier estimator. The Kaplan-Meier estimator does not adjust for confounding factors and overestimates the treatment effect, as evidenced by the extent to which its survival rates differ between regions. In contrast, counterfactual regression adjusts for confounding factors and predicts more similar survival rates between regions.

Phenotyping Censored Survival Data

Figure 5: Phenotypers in auton-survival ((a) unsupervised, (b) supervised, (c) counterfactual): \( X \) represents the covariates, \( T \) the time-to-event, and \( Z \) the phenotype to be inferred.

Survival rates differ across groups of individuals with heterogeneous characteristics. Identifying groups of patients with similar survival rates can yield insight into practices and interventions that help improve longevity for such groups. While domain knowledge can help identify such subgroups, in practice complex, non-linear feature interactions may determine subgroup membership, making identification difficult. In auton-survival, we refer to this group identification and survival assessment as phenotyping.

Our package offers multiple approaches to phenotyping, which involve either the use of specific domain knowledge, as in the case of the intersectional phenotyper, or a completely unsupervised approach that clusters subjects based on the observed covariates. auton-survival also offers phenotypers in which explicit supervision, in the form of observed outcomes or counterfactuals, informs the learned phenotypes to better stratify the data. Directed acyclic graph representations of the probabilistic phenotypers in auton-survival are shown in Figure 5.

  • Intersectional Phenotyping: Recovers groups, or phenotypes, of individuals over exhaustive combinations of user-specified categorical and numerical features.
  • Unsupervised Phenotyping: Identifies groups of individuals based on structured similarity in the feature space by first performing dimensionality reduction of the input covariates, followed by clustering. The estimated probability of an individual belonging to a latent group is computed as the distance to the cluster normalized by the sum of distances to other clusters.
  • Supervised Phenotyping: Identifies latent groups of individuals with similar survival outcomes, using the observed outcomes as supervision. These groups are obtained as a direct by-product of training the DSM and DCM latent variable survival estimators.
  • Counterfactual Phenotyping: Identifies groups of individuals that demonstrate enhanced or diminished treatment effects (Nagpal et al., 2022).
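Intersectional phenotyping, as described above, amounts to enumerating every combination of user-specified feature levels. A minimal sketch, assuming categorical features are used as-is and each numerical feature is split at user-given cut points; the function name and binning scheme are illustrative, not the package's API:

```python
import itertools

import numpy as np


def intersectional_phenotypes(categorical, numerical, cut_points):
    """Assign each individual to one cell of the cross-product of feature levels.

    categorical: dict of feature name -> sequence of category labels.
    numerical: dict of feature name -> sequence of values, binned at cut_points[name].
    Returns one tuple label per individual and the list of all possible groups.
    """
    levels, columns = [], []
    for values in categorical.values():
        values = list(values)
        levels.append(sorted(set(values)))
        columns.append(values)
    for name, values in numerical.items():
        # Bin index: 0 below the first cut, 1 from the first cut up to the second, ...
        bins = np.digitize(values, cut_points[name])
        levels.append(list(range(len(cut_points[name]) + 1)))
        columns.append(bins.tolist())
    groups = list(itertools.product(*levels))  # exhaustive combinations, some possibly empty
    labels = list(zip(*columns))               # one (level, level, ...) tuple per individual
    return labels, groups


labels, groups = intersectional_phenotypes(
    categorical={"sex": ["F", "M", "F", "M"]},
    numerical={"age": [35, 70, 62, 45]},
    cut_points={"age": [50]},
)
# labels -> [('F', 0), ('M', 1), ('F', 1), ('M', 0)]
# groups -> [('F', 0), ('F', 1), ('M', 0), ('M', 1)]
```

Survival within each resulting group could then be summarized with a Kaplan-Meier estimate.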

Figure 6 presents the Kaplan-Meier survival curves of the phenogroups extracted from SUPPORT (Knaus et al., 1995) using unsupervised and supervised phenotyping. The intersecting survival curves suggest that the phenotypers can recover phenogroups that do not strictly adhere to the proportional hazards assumption. From Figure 6, it can also be inferred that supervised phenotyping extracts phenogroups with higher discriminative power, as indicated by the more contrasting Kaplan-Meier estimates of phenogroup-level survival.

Treatment Effect Estimation

auton-survival offers additional tools to analyze the effect of an intervention on outcomes. It computes propensity-adjusted treatment effects, via bootstrap resampling of the dataset with replacement, in terms of the following metrics:

  • Hazard Ratio: Assuming the proportional hazards assumption holds, the treatment effect can be measured as the ratio of hazard rates between the treatment and control arms.
  • Time at Risk (TaR) (Figure 7a): The treatment effect can be measured as the difference in time-to-event at a specified level of risk.
  • Risk at Time (Figure 7b): The treatment effect can be measured as the difference in risk at a specified time horizon.
  • Restricted Mean Survival Time (RMST) (Figure 7c): The treatment effect can be measured as the difference in the expected (mean) time-to-event restricted to a specified time horizon.
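Risk at Time and RMST can both be read off an estimated survival curve. As a sketch, assuming the curve is a right-continuous step function (such as a Kaplan-Meier estimate) given by its jump times and the survival value just after each jump, \( \text{RMST}(\tau) = \int_0^\tau S(t)\,dt \) is the area under the curve up to the horizon \( \tau \), and the risk at \( \tau \) is \( 1 - S(\tau) \). The two arm curves below are made up for illustration.

```python
import numpy as np


def step_survival(times, surv, t):
    """Evaluate a right-continuous step survival curve; S(t) = 1 before the first jump."""
    idx = np.searchsorted(times, t, side="right") - 1
    return 1.0 if idx < 0 else float(surv[idx])


def rmst(times, surv, horizon):
    """Restricted mean survival time: area under the step curve on [0, horizon]."""
    knots = [0.0] + [float(s) for s in times if s < horizon] + [float(horizon)]
    area = 0.0
    for left, right in zip(knots[:-1], knots[1:]):
        area += step_survival(times, surv, left) * (right - left)  # S is flat on [left, right)
    return area


def risk_at_time(times, surv, horizon):
    """Risk (cumulative incidence) at the horizon: 1 - S(horizon)."""
    return 1.0 - step_survival(times, surv, horizon)


# Hypothetical curves for two arms: (jump times, survival just after each jump).
control = ([1.0, 4.0], [0.75, 0.0])
treated = ([2.0, 6.0], [0.8, 0.4])
rmst_effect = rmst(*treated, horizon=5.0) - rmst(*control, horizon=5.0)              # 4.4 - 3.25
risk_effect = risk_at_time(*treated, horizon=3.0) - risk_at_time(*control, horizon=3.0)  # 0.2 - 0.25
```

A positive RMST difference and a negative risk difference both indicate that the treated arm fares better up to the chosen horizon.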

Propensity adjustment provides a way to estimate treatment effects while accounting for potential confounders that influence both treatment assignment and the outcome. Not adjusting for treatment propensity can result in misestimated treatment effects.

auton-survival supports adjusting for treatment propensity by computing treatment effects with a bootstrap that uses sample weights. When the sample weights are derived from propensity scores, such as those obtained from a classification model, so that each individual is weighted by \( 1/\widehat{\mathbb{P}}(A \mid X) \), the bootstrapped distribution of the treatment effect converges to the Inverse Propensity of Treatment Weighting (IPTW) Horvitz-Thompson estimate of the population Average Treatment Effect:

\( \mathbb{ATE}(\mathcal{D}^*, f) = \mathbb{E}_{x \sim \mathcal{D}^*}\big[\mathbb{E}[f_1(x) - f_0(x) \mid X = x]\big]; \quad \mathcal{D}^* \sim \frac{1}{\widehat{\mathbb{P}}(A \mid X)} \cdot \mathbb{P}^*(\mathcal{D}) \)
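A minimal sketch of the inverse-propensity-weighted (Horvitz-Thompson) average treatment effect, assuming a binary treatment, a scalar outcome per individual, and known or previously estimated propensities \( \widehat{\mathbb{P}}(A = 1 \mid X) \). It leaves out the survival-specific metrics, and its bootstrap is a swapped-in variant: it resamples individuals uniformly and applies the weights inside each replicate, rather than passing the weights to the bootstrap as described above.

```python
import numpy as np


def iptw_ate(outcome, treated, propensity):
    """Horvitz-Thompson (IPTW) estimate of the average treatment effect E[Y(1)] - E[Y(0)]."""
    y = np.asarray(outcome, dtype=float)
    a = np.asarray(treated, dtype=float)
    e = np.asarray(propensity, dtype=float)        # estimated P(A = 1 | X), assumed in (0, 1)
    treated_term = np.mean(a * y / e)              # treated units weighted by 1 / P(A=1|X)
    control_term = np.mean((1 - a) * y / (1 - e))  # control units weighted by 1 / P(A=0|X)
    return treated_term - control_term


def bootstrap_iptw_ate(outcome, treated, propensity, n_boot=100, seed=0):
    """Bootstrap distribution of the IPTW estimate, resampling individuals with replacement."""
    rng = np.random.default_rng(seed)
    y, a, e = (np.asarray(v, dtype=float) for v in (outcome, treated, propensity))
    n = len(y)
    draws = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        draws.append(iptw_ate(y[idx], a[idx], e[idx]))
    return np.array(draws)


# Synthetic check: x confounds treatment and outcome; the true effect of A is 1.0.
rng = np.random.default_rng(0)
n = 20_000
x = rng.integers(0, 2, size=n)
e = np.where(x == 1, 0.8, 0.2)  # true propensity P(A = 1 | X)
a = (rng.random(n) < e).astype(int)
y = 2.0 * x + 1.0 * a + rng.normal(0.0, 0.1, size=n)
naive = y[a == 1].mean() - y[a == 0].mean()  # confounded, about 2.2
adjusted = iptw_ate(y, a, e)                 # close to 1.0
draws = bootstrap_iptw_ate(y, a, e, n_boot=50)
```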

In a second analysis of the effect of geographical region on breast cancer mortality using data from the SEER cancer registry (Ries et al., 2008), we compare treatment effects before and after adjusting for confounding factors by inverse propensity weighting. As in the previous analysis with counterfactual regression, we treat the regions “Greater California” and “Louisiana” as the binary “treatment” in question. To adjust treatment effects for confounding factors, we first train an \( \ell_2 \)-penalized logistic regression that regresses the geographical region on the set of confounding variables. The estimated propensity scores are then used as sample weights when computing the treatment effects in terms of hazard ratios, restricted mean survival time (RMST), and risk difference, as shown in Figure 8. Adjusting for region propensity noticeably shrinks the estimated treatment effects, indicating that the difference in breast cancer mortality between the regions is likely explained by confounding socio-economic and physiological factors rather than by geographic region alone.


We present auton-survival, an open-source Python package encapsulating multiple pipelines for working with censored time-to-event data. Such data are ubiquitous in many fields, including healthcare and equipment maintenance. Through continuous collaboration with the machine learning for healthcare community, we aim to build a robust, comprehensive repository of rigorous tools for reproducible analysis of censored time-to-event data.


Chirag Nagpal
PhD Candidate, Auton Lab

Willa Potosnak
Research Intern and
Incoming PhD Student, Auton Lab


[1] Nagpal, C., Potosnak, W. and Dubrawski, A., 2022. auton-survival: an Open-Source Package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Event Data. arXiv preprint arXiv:2204.07276.

[2] Fabian Pedregosa, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, et al. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825–2830, 2011.

[3] D. R. Cox. Regression models and life-tables. Journal of the Royal Statistical Society. Series B (Methodological), 34(2):187–220, 1972.

[4] David Faraggi and Richard Simon. A neural network model for survival data. Statistics in Medicine, 14(1):73–82, 1995.

[5] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology, 18(1):1–12, 2018.

[6] Nagpal, C., Li, X. and Dubrawski, A., 2021a. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. IEEE Journal of Biomedical and Health Informatics, 25(8), pp.3163-3175.

[7] Nagpal, C., Yadlowsky, S., Rostamzadeh, N. and Heller, K., 2021b, October. Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference (pp. 674-708). PMLR.

[8] H. Ishwaran, Udaya B. Kogalur, Eugene H. Blackstone, and Michael S. Lauer. Random survival forests. The Annals of Applied Statistics, 2(3), 2008.

[9] Junyoung Chung, Caglar Gulcehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.

[10] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.

[11] LAG Ries, D Melbert, M Krapcho, DG Stinchcomb, N Howlader, MJ Horner, A Mariotto, BA Miller, EJ Feuer, SF Altekruse, et al. SEER Cancer Statistics Review, 1975–2005. Bethesda, MD: National Cancer Institute, 2999, 2008.

[12] Nagpal, C., Goswami, M., Dufendach, K. and Dubrawski, A., 2022. Counterfactual Phenotyping with Censored Time-to-Events. arXiv preprint arXiv:2202.11089.

[13] W. A. Knaus, F. E. Harrell, J. Lynn, et al. The SUPPORT prognostic model: Objective estimates of survival for seriously ill hospitalized adults. Annals of Internal Medicine, 122:191–203, 1995.
