BanditPAM: Almost Linear-Time k-medoids Clustering via Multi-Armed Bandits



Want something better than (k)-means? Our state-of-the-art (k)-medoids algorithm from NeurIPS, BanditPAM, is now publicly available! Run pip install banditpam and you’re good to go!

Like the (k)-means problem, the (k)-medoids problem is a clustering problem in which our objective is to partition a dataset into disjoint subsets. In (k)-medoids, however, we require that the cluster centers must be actual datapoints, which permits greater interpretability of the cluster centers. (k)-medoids also works better with arbitrary distance metrics, so your clustering can be more robust to outliers if you’re using metrics like (L_1).

Despite these advantages, most people don’t use (k)-medoids because prior algorithms were too slow. In our NeurIPS paper, BanditPAM, we sped up the best known algorithm from (O(n^2)) to (O(n log n)).

We’ve released our implementation, which is pip-installable. It’s written in C++ for speed and supports parallelization and intelligent caching, with no extra complexity for end users. Its interface also matches the sklearn.cluster.KMeans interface, so minimal changes are necessary to existing code.


(k)-means vs. (k)-medoids

If you’re an ML practitioner, you’re probably familiar with the (k)-means problem. In fact, you may know some of the common algorithms for the (k)-means problem. You’re much less likely, however, to be familiar with the (k)-medoids problem.

The (k)-medoids problem is a clustering problem similar to (k)-means. Given a dataset, we want to partition it into (k) clusters such that each point is closer to its own cluster center than to any of the other (k-1) cluster centers. Unlike in (k)-means, however, the (k)-medoids problem requires cluster centers to be actual datapoints.
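To make the objective concrete, here is a minimal NumPy sketch. It assumes Euclidean distance and a toy dataset; the helper name kmedoids_loss is ours for illustration and is not part of any library.

```python
import numpy as np

def kmedoids_loss(X, medoid_idx):
    """Total distance from each point to its nearest medoid (illustrative helper)."""
    medoids = X[medoid_idx]                      # medoids are rows of X, i.e. actual datapoints
    # Pairwise L2 distances between points and medoids, shape (n, k)
    d = np.linalg.norm(X[:, None, :] - medoids[None, :, :], axis=2)
    labels = d.argmin(axis=1)                    # each point joins its nearest medoid
    return d.min(axis=1).sum(), labels

# Two well-separated groups of three points each
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
loss, labels = kmedoids_loss(X, [0, 3])          # use points 0 and 3 as the medoids
print(loss, labels)
```

The (k)-medoids problem is then the search for the set of (k) row indices that minimizes this loss.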

Figure 1: The (k)-medoids solution (left) forces the cluster centers to be actual datapoints. This solution is often different from the (k)-means solution (right).

The (k)-medoids problem has several advantages over (k)-means. By forcing the cluster centers (dubbed the medoids) to be actual datapoints, solutions tend to be more interpretable since you can determine exactly which datapoint is the cluster center for each cluster. When clustering images from the ImageNet dataset, for example, the mean in a solution to the (k)-means problem with (k = 1) is usually a nondescript blob (Figure 2, left), whereas the medoid in the corresponding solution to the (k)-medoids problem is an actual image (Figure 2, right).

Figure 2: The cluster centers in (k)-means are often not easily interpretable, whereas they are actual datapoints in (k)-medoids. Shown are cluster centers for a subset of ImageNet with (k = 1) with (k)-means (top) and (k)-medoids (bottom). The mean of the dataset is the average per-pixel color, whereas the medoid is an image of a bee.

The (k)-medoids problem also supports arbitrary distance metrics, in contrast with (k)-means, which usually requires (L_2) distance for efficiency. In fact, you’re allowed to use any pairwise dissimilarity function with (k)-medoids; your dissimilarity function need not even satisfy the properties of a metric. It can be asymmetric, negative, and violate the triangle inequality. In practice, allowing for arbitrary dissimilarity functions enables the clustering of “exotic” objects like strings, natural language, trees, graphs, and more, without needing to embed these objects in a vector space first.
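As a small illustration of clustering without a vector space, the sketch below finds the medoid of a handful of strings under edit distance. The helpers edit_distance and medoid are written here for illustration only; they are not taken from the BanditPAM package.

```python
def edit_distance(a, b):
    """Levenshtein distance via the standard dynamic program."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution (free if equal)
        prev = cur
    return prev[-1]

def medoid(items, dist):
    """Return the item with the smallest total dissimilarity to all items."""
    return min(items, key=lambda c: sum(dist(c, x) for x in items))

words = ["cat", "cart", "bat", "hat", "chat"]
center = medoid(words, edit_distance)
print(center)
```

The center is one of the input strings, which is exactly what makes the result interpretable; a "mean string" would not be well defined.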

The advantages of (k)-medoids don’t stop there. Because the (k)-medoids problem supports arbitrary distance functions, the clustering can often be more robust to outliers if you’re using robust distance metrics. The (L_1) distance metric, for example, is more robust to outliers than the (L_2) distance metric; in one dimension, the (L_1) minimizer is the median of your datapoints whereas the (L_2) minimizer is the mean.
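The one-dimensional claim can be checked numerically. The sketch below uses a made-up sample with a single outlier and a brute-force grid search over candidate centers; the data and grid resolution are arbitrary choices for illustration.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])    # 100 is an outlier
grid = np.linspace(0.0, 100.0, 10001)         # candidate centers, spaced 0.01 apart

# Cost of every candidate center under each loss, shape (10001,)
l1_cost = np.abs(x[None, :] - grid[:, None]).sum(axis=1)
l2_cost = ((x[None, :] - grid[:, None]) ** 2).sum(axis=1)

l1_center = grid[l1_cost.argmin()]            # lands at the median of x
l2_center = grid[l2_cost.argmin()]            # lands at the mean of x
print(l1_center, np.median(x))
print(l2_center, x.mean())
```

The outlier drags the (L_2) center far from the bulk of the data, while the (L_1) center stays with the majority of the points.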

Despite all of these advantages, (k)-means is much more widely used than (k)-medoids, largely due to its much more favorable runtime. The best (k)-means algorithms scale linearly in dataset size, i.e., have (O(n)) complexity, whereas, until now, the best (k)-medoids algorithms scaled quadratically in dataset size, i.e., had (O(n^2)) complexity.

In our NeurIPS paper, BanditPAM, we reduced the complexity of the best known (k)-medoids algorithm from (O(n^2)) to (O(n log n)). This complexity almost matches the complexity of standard (k)-means algorithms, and now you get all the benefits of (k)-medoids on top. We’ve also released a high-performance implementation of our algorithm written in C++ for speed but callable from Python via Python bindings; run pip install banditpam and you’re ready to go! Our algorithm’s interface matches that of sklearn.cluster.KMeans and can be used with a simple 2-line change. You can also implement your own distance metrics, interpret cluster centers, and cluster structured data!

BanditPAM: Almost Linear Time (k)-medoids Clustering via Multi-Armed Bandits

How does our algorithm, BanditPAM, work? Our claim is that we match the prior state-of-the-art solutions in clustering quality by recovering the exact same solution and reduce the complexity from (O(n^2)) to (O(n log n)). But is this reduction in complexity just “for free”?

To discuss how BanditPAM works, we first need to discuss its predecessor, the Partitioning Around Medoids (PAM) algorithm. The PAM algorithm, first proposed in 1990 [1], is a greedy solution to the (k)-medoids problem. PAM is broken into two steps: the BUILD step and the SWAP step.

In the BUILD step, each of the (k) medoids is greedily initialized one by one. More concretely, PAM considers all possible datapoints as “candidate” medoids. For every candidate medoid, we compute the change in the overall loss if we were to add that candidate to the set of medoids, conditioned on the previously assigned medoids being fixed. This results in (O(n^2)) computational complexity since we need to compute every pairwise distance.
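A naive version of the BUILD step described above can be sketched in a few lines of NumPy. This is an illustrative sketch under our own assumptions (a precomputed dissimilarity matrix D and a hypothetical function name pam_build); it is not the released C++ implementation.

```python
import numpy as np

def pam_build(D, k):
    """Greedy BUILD on a precomputed (n, n) dissimilarity matrix D (illustrative)."""
    n = D.shape[0]
    medoids = []
    nearest = np.full(n, np.inf)         # distance from each point to its closest medoid so far
    for _ in range(k):
        # Row c holds the loss if candidate c were added: every point keeps the
        # smaller of its current nearest-medoid distance and its distance to c.
        cand_loss = np.minimum(D, nearest[None, :]).sum(axis=1)
        if medoids:
            cand_loss[medoids] = np.inf  # never pick an existing medoid again
        best = int(cand_loss.argmin())
        medoids.append(best)
        nearest = np.minimum(nearest, D[best])
    return medoids, nearest.sum()

X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)   # all n^2 pairwise distances
build_medoids, build_loss = pam_build(D, k=2)
print(build_medoids, build_loss)
```

Note that the sketch touches every entry of D for each new medoid, which is where the quadratic cost comes from.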

In the SWAP step, we consider all (kn) (medoid, non-medoid) pairs and the change in loss that would be induced if we were to swap the first element of the pair out of the medoid set in favor of the second. Again, this procedure incurs an (O(n^2)) time complexity (really (O(kn^2))).
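The SWAP step can be sketched the same way. The version below is deliberately naive: it recomputes the full loss for every candidate swap, so it does more work per swap than a careful PAM implementation would. The names total_loss and pam_swap and the max_iter cap are our own illustrative choices.

```python
import numpy as np

def total_loss(D, medoids):
    """Sum over points of the distance to the nearest medoid."""
    return D[medoids].min(axis=0).sum()

def pam_swap(D, medoids, max_iter=100):
    """Apply the best (medoid, non-medoid) swap until no swap lowers the loss."""
    medoids = list(medoids)
    n = D.shape[0]
    best_loss = total_loss(D, medoids)
    for _ in range(max_iter):
        best_swap = None
        for i in range(len(medoids)):             # k medoids ...
            for c in range(n):                    # ... times n candidates
                if c in medoids:
                    continue
                trial = medoids[:i] + [c] + medoids[i + 1:]
                loss = total_loss(D, trial)       # recomputed from scratch in this sketch
                if loss < best_loss - 1e-12:      # require a strict improvement
                    best_loss, best_swap = loss, (i, c)
        if best_swap is None:
            break                                 # no improving swap: local optimum
        i, c = best_swap
        medoids[i] = c
    return medoids, best_loss

X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
swap_medoids, swap_loss = pam_swap(D, [1, 2])     # start from a poor initialization
print(sorted(swap_medoids), swap_loss)
```

Starting from two medoids in the same group, the swaps move one medoid into the other group and then settle on the central point of each group.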

Figure 3: The (k)-medoids algorithm in action. In the BUILD step, each medoid is assigned greedily, one-by-one. In the SWAP step, we consider swapping medoid assignments to see if we can lower the overall loss.

Our fundamental insight was that for each step of the PAM algorithm, we don’t actually need to compute the distance from each point to all other (n) points. Instead, we can just sample these distances!

Consider, for example, the problem of assigning the first medoid at the beginning of the BUILD step. PAM would go through all (n) points and, for each point, compute its distance to every other point. We realized that, for each candidate, we only needed to compute the distance to (O(log n)) other points. By intelligently choosing which distances to compute, we can save a lot of unnecessary computation. Formally, we reduce the problem of assigning the first medoid to a multi-armed bandit problem, as demonstrated in Figure 4. In multi-armed bandit problems, our objective is to identify the best action to take (also referred to as the best arm to pull) when actions are independent and have stochastic returns.
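The idea of sampling distances for the first medoid can be sketched as a simple successive-elimination loop. This is a rough illustration under our own assumptions, not the algorithm from the paper or the released code: the batch size, the confidence radius built from each arm's empirical standard deviation, the delta parameter, and the function name bandit_first_medoid are all hypothetical choices made for this example.

```python
import numpy as np

def bandit_first_medoid(X, batch=20, delta=1e-3, seed=0):
    """Pick the first medoid by sampling distances (illustrative sketch only)."""
    rng = np.random.default_rng(seed)
    n = len(X)

    def dists(cands, refs):
        # L2 distances between candidate rows and reference rows, shape (len(cands), len(refs))
        return np.linalg.norm(X[cands][:, None, :] - X[refs][None, :, :], axis=2)

    alive = np.arange(n)     # arms (candidate medoids) still in play
    s1 = np.zeros(n)         # running sum of sampled distances per arm
    s2 = np.zeros(n)         # running sum of squared sampled distances per arm
    t = 0                    # samples drawn so far for each surviving arm
    n_evals = 0              # total number of distance evaluations
    log_term = np.log(2 * n * n / delta)
    while len(alive) > 1 and t + batch <= n:
        refs = rng.integers(0, n, size=batch)     # random reference points, with replacement
        d = dists(alive, refs)
        s1[alive] += d.sum(axis=1)
        s2[alive] += (d ** 2).sum(axis=1)
        t += batch
        n_evals += d.size
        mean = s1[alive] / t
        std = np.sqrt(np.maximum(s2[alive] / t - mean ** 2, 0.0))
        radius = std * np.sqrt(2 * log_term / t)  # heuristic confidence radius (assumption)
        # Keep an arm only if its lower bound does not exceed the best upper bound.
        alive = alive[mean - radius <= (mean + radius).min()]
    if len(alive) > 1:
        # Settle the remaining arms with an exact computation.
        exact = dists(alive, np.arange(n)).sum(axis=1)
        n_evals += exact.size * n
        alive = alive[[exact.argmin()]]
    return int(alive[0]), n_evals

rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
best, n_evals = bandit_first_medoid(X)
exact_idx = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2).sum(axis=1).argmin()
print(best, int(exact_idx), n_evals, "vs", len(X) ** 2, "for the exhaustive computation")
```

Candidates that are clearly far from the bulk of the data are discarded after only a few sampled distances, and the remaining budget is spent on the candidates that are hard to tell apart. How much work this saves depends on the data and on the confidence radius, so the printed counts should be read as a property of this toy setup only.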

Figure 4: PAM (top) computes every pairwise distance for each candidate medoid. BanditPAM (bottom) only samples the pairwise distances. With just a few samples, we see that the purple point is a better candidate than the green point since the purple arrows are, on average, shorter than the green ones.

It turns out that all steps of the PAM algorithm can also be reduced to multi-armed bandit problems. In each part of the BUILD step, we still view each candidate datapoint as an arm. Now, however, pulling an arm corresponds to computing the induced change in loss for a random datapoint if we were to add the candidate to the set of medoids, conditioned on the previous medoids already being assigned. In each SWAP step, we view each (medoid, non-medoid) pair as an arm and pulling an arm corresponds to computing the induced change in loss on a random datapoint if we were to perform the swap. With these modifications, the original PAM algorithm is now reformulated as a sequence of best-arm identification problems. This reformulation reduces every step of the PAM algorithm from (O(n^2)) to (O(n log n)).
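A single arm pull in each step can be written down directly from the description above. The sketch below assumes a precomputed dissimilarity matrix D and at least one existing medoid; the function names build_pull and swap_pull are ours and purely illustrative.

```python
import numpy as np

def build_pull(D, medoids, c, j):
    """Change in loss on reference point j if candidate c were added to the medoids."""
    cur = D[medoids, j].min()                  # j's current nearest-medoid distance
    return min(D[c, j], cur) - cur             # zero or negative: adding a medoid never hurts

def swap_pull(D, medoids, m_pos, c, j):
    """Change in loss on point j if medoids[m_pos] were swapped out for candidate c."""
    cur = D[medoids, j].min()
    rest = [m for i, m in enumerate(medoids) if i != m_pos]
    new = min([D[c, j]] + [D[m, j] for m in rest])
    return new - cur                           # can be positive: a swap may hurt point j

x = np.array([0.0, 1.0, 5.0, 6.0])             # four points on a line
D = np.abs(x[:, None] - x[None, :])
b = build_pull(D, [0], c=2, j=3)               # point 3 moves from distance 6 to distance 1
s = swap_pull(D, [0, 2], m_pos=0, c=1, j=0)    # point 0 loses its own medoid, moves to distance 1
print(b, s)
```

Averaging a pull over all (n) reference points recovers the exact change in loss divided by (n); averaging over a random subset gives the noisy estimate that the bandit formulation works with.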

Now, if you’re familiar with multi-armed bandits, you might protest. Our algorithm is a randomized algorithm and can sometimes return an incorrect result. In the full paper, we show that the probability of getting a “wrong” answer is very small. In practice, this means that users of our algorithm don’t have to worry and will almost always get the same answer as the original PAM algorithm.

The BanditPAM algorithm is an (O(n log n)) algorithm that matches prior state-of-the-art algorithms in clustering quality and almost matches the complexity of popular (k)-means algorithms. Want to try out BanditPAM? Run pip3 install banditpam and jump to our examples.

Figure 5: A formal proof that (k)-medoids is superior to (k)-means in every way.


This blog post is based on the paper: BanditPAM: Almost Linear Time (k)-medoids Clustering via Multi-Armed Bandits. NeurIPS 2021.

A special thanks to my collaborators on this project, Martin Jinye Zhang, James Mayclin, Sebastian Thrun, Chris Piech, and Ilan Shomorony, as well as the reviewers of this blog post, Drew A. Hudson and Sidd Karamcheti.

  1. Kaufman, Leonard; Rousseeuw, Peter J. (1990-03-08), “Partitioning Around Medoids (Program PAM)”, Wiley Series in Probability and Statistics, Hoboken, NJ, USA: John Wiley & Sons, Inc., pp. 68–125


Stanford AI Lab Papers and Talks at NeurIPS 2021


The Thirty-fifth Conference on Neural Information Processing Systems (NeurIPS) 2021 is being hosted virtually from Dec 6th – 14th. We’re excited to share all the work from SAIL that’s being presented at the main conference, at the Datasets and Benchmarks track and the various workshops, and you’ll find links to papers, videos and blogs below.

Some of the members in our SAIL community also serve as co-organizers of several exciting workshops that will take place on Dec 13-14, so we hope you will check them out!

Feel free to reach out to the contact authors and the workshop organizers directly to learn more about the work that’s happening at Stanford!

Main Conference

Improving Compositionality of Neural Networks by Decoding Representations to Inputs

Authors: Mike Wu, Noah Goodman, Stefano Ermon


Links: Paper

Keywords: generative models, compositionality, decoder

Reverse engineering recurrent neural networks with Jacobian switching linear dynamical systems

Authors: Jimmy T.H. Smith, Scott W. Linderman, David Sussillo


Links: Paper | Website

Keywords: recurrent neural networks, switching linear dynamical systems, interpretability, fixed points

Compositional Transformers for Scene Generation

Authors: Drew A. Hudson, C. Lawrence Zitnick


Links: Paper | Github

Keywords: GANs, transformers, compositionality, scene synthesis

Combining Recurrent, Convolutional, and Continuous-time Models with Linear State Space Layers

Authors: Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Chris Ré


Links: Paper

Keywords: recurrent neural networks, rnn, continuous models, state space, long range dependencies, sequence modeling

Emergent Communication of Generalizations

Authors: Jesse Mu, Noah Goodman


Links: Paper | Video

Keywords: emergent communication, multi-agent communication, language grounding, compositionality

ELLA: Exploration through Learned Language Abstraction

Authors: Suvir Mirchandani, Siddharth Karamcheti, Dorsa Sadigh


Links: Paper | Video

Keywords: instruction following, reward shaping, reinforcement learning

CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation

Authors: Yusuke Tashiro, Jiaming Song, Yang Song, Stefano Ermon


Links: Paper | Website

Keywords: score-based generative modeling, time series imputation

Confidence-Aware Imitation Learning from Demonstrations with Varying Optimality

Authors: Songyuan Zhang, Zhangjie Cao, Dorsa Sadigh, Yanan Sui


Links: Paper | Video | Website

Keywords: imitation learning, learning from demonstration, learning from suboptimal demonstrations

Explaining heterogeneity in medial entorhinal cortex with task-driven neural networks

Authors: Aran Nayebi, Alexander Attinger, Malcolm G. Campbell, Kiah Hardcastle, Isabel I.C. Low, Caitlin S. Mallory, Gabriel C. Mel, Ben Sorscher, Alex H. Williams, Surya Ganguli, Lisa M. Giocomo, Daniel L.K. Yamins


Award nominations: Spotlight Presentation

Links: Paper | Website

Keywords: neural coding, medial entorhinal cortex, grid cells, biologically-inspired navigation, path integration, recurrent neural networks

On the theory of reinforcement learning with once-per-episode feedback

Authors: Niladri Chatterji, Aldo Pacchiano, Peter Bartlett, Michael Jordan


Keywords: theoretical reinforcement learning, binary rewards, non-markovian rewards

HyperSPNs: Compact and Expressive Probabilistic Circuits

Authors: Andy Shih, Dorsa Sadigh, Stefano Ermon


Links: Paper | Video | Website

Keywords: generative models, tractable probabilistic models, sum product networks, probabilistic circuits

COMBO: Conservative Offline Model-Based Policy Optimization

Authors: Tianhe Yu*, Aviral Kumar*, Rafael Rafailov, Aravind Rajeswaran, Sergey Levine, Chelsea Finn


Links: Paper

Keywords: offline reinforcement learning, model-based reinforcement learning, deep reinforcement learning

Conservative Data Sharing for Multi-Task Offline Reinforcement Learning

Authors: Tianhe Yu*, Aviral Kumar*, Yevgen Chebotar, Karol Hausman, Sergey Levine, Chelsea Finn


Links: Paper

Keywords: offline reinforcement learning, multi-task reinforcement learning, deep reinforcement learning

Autonomous Reinforcement Learning via Subgoal Curricula

Authors: Archit Sharma, Abhishek Gupta, Sergey Levine, Karol Hausman, Chelsea Finn


Links: Paper | Website

Keywords: reinforcement learning, curriculum, autonomous learning, reset-free reinforcement learning

Lossy Compression for Lossless Prediction

Authors: Yann Dubois, Benjamin Bloem-Reddy, Karen Ullrich, Chris J. Maddison


Award nominations: Spotlight Presentation

Links: Paper | Video | Website

Keywords: compression, invariances, information theory, machine learning, self-supervised learning

Capturing implicit hierarchical structure in 3D biomedical images with self-supervised hyperbolic representations

Authors: Joy Hsu, Jeffrey Gu, Gong-Her Wu, Wah Chiu, Serena Yeung


Links: Paper

Keywords: hyperbolic representations, hierarchical structure, biomedical

Estimating High Order Gradients of the Data Distribution by Denoising

Authors: Chenlin Meng, Yang Song, Wenzhe Li, Stefano Ermon


Keywords: score matching, langevin dynamics, denoising, generative modeling

Universal Off-Policy Evaluation

Authors: Yash Chandak, Scott Niekum, Bruno Castro da Silva, Erik Learned-Miller, Emma Brunskill, Philip Thomas


Links: Paper | Website

Keywords: metrics, risk, distribution, cdf, off-policy evaluation, ope, reinforcement learning, counterfactuals, high-confidence bounds, confidence intervals

Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models

Authors: Phil Chen, Masha Itkina, Ransalu Senanayake, Mykel J. Kochenderfer


Links: Paper

Keywords: deep learning or neural networks, sparsity and feature selection, variational inference, (application) natural language and text processing

Provable Guarantees for Self-Supervised Deep Learning with Spectral Contrastive Loss

Authors: Jeff Z. HaoChen, Colin Wei, Adrien Gaidon, Tengyu Ma


Links: Paper

Keywords: deep learning theory, unsupervised learning theory, representation learning theory

Provable Model-based Nonlinear Bandit and Reinforcement Learning: Shelve Optimism, Embrace Virtual Curvature

Authors: Kefan Dong, Jiaqi Yang, Tengyu Ma


Links: Paper | Video

Keywords: nonlinear bandits, online learning, deep reinforcement learning theory, sequential rademacher complexity

Decrypting Cryptic Crosswords: Semantically Complex Wordplay Puzzles as a Target for NLP

Authors: Joshua Rozner, Christopher Potts, Kyle Mahowald


Links: Paper | Website

Keywords: compositionality in language, curriculum learning, meta-linguistics, systematicity, generalization

Design of Experiments for Stochastic Contextual Linear Bandits

Authors: Andrea Zanette*, Kefan Dong*, Jonathan Lee*, Emma Brunskill


Links: Paper

Keywords: linear bandits, design of experiments

Provable Benefits of Actor-Critic Methods for Offline Reinforcement Learning

Authors: Andrea Zanette, Martin J. Wainwright, Emma Brunskill


Links: Paper

Keywords: offline rl, mirror descent, bellman closure

A Topological Perspective on Causal Inference

Authors: Duligur Ibeling, Thomas Icard


Links: Paper

Keywords: causal inference, topological learning theory

Adversarial Training Helps Transfer Learning via Better Representations

Authors: Zhun Deng, Linjun Zhang, Kailas Vodrahalli, Kenji Kawaguchi, James Zou


Links: Paper

Keywords: transfer learning, adversarial training

Widening the Pipeline in Human-Guided Reinforcement Learning with Explanation and Context-Aware Data Augmentation

Authors: Lin Guan, Mudit Verma, Sihang Guo, Ruohan Zhang, Subbarao Kambhampati


Award nominations: Spotlight

Links: Paper | Website

Keywords: human-in-the-loop reinforcement learning, evaluative feedback, saliency map, visual explanation

Machine versus Human Attention in Deep Reinforcement Learning Tasks

Authors: Sihang Guo, Ruohan Zhang, Bo Liu, Yifeng Zhu, Dana Ballard, Mary Hayhoe, Peter Stone


Links: Paper

Keywords: deep reinforcement learning, interpretability, attention, eye tracking

Play to Grade: Testing Coding Games as Classifying Markov Decision Process

Authors: Allen Nie, Emma Brunskill, Chris Piech


Links: Paper | Website

Keywords: reinforcement learning, computational education, collaborative training, markov decision process

The Value of Information When Deciding What to Learn

Authors: Dilip Arumugam, Benjamin Van Roy


Links: Paper

Keywords: exploration, information theory, multi-armed bandits, reinforcement learning

Diversity Matters When Learning From Ensembles

Authors: Giung Nam*, Jongmin Yoon*, Yoonho Lee, Juho Lee


Links: Paper | Website

Keywords: deep ensembles, knowledge distillation, calibration, output diversified sampling, batchensemble

Reinforcement Learning with State Observation Costs in Action-Contingent Noiselessly Observable Markov Decision Processes

Authors: HyunJi Nam, Scott Fleming, Emma Brunskill


Links: Paper | Website

Keywords: reinforcement learning, observation cost, markov decision process, mdp, partially observable markov decision process, pomdp, probably approximately correct, pac, healthcare, health care

Meta-learning with an Adaptive Task Scheduler

Authors: Huaxiu Yao, Yu Wang, Ying Wei, Peilin Zhao, Mehrdad Mahdavi, Defu Lian, Chelsea Finn


Links: Paper

Keywords: adaptive task scheduler, meta-learning, sampling

Spatial-Temporal Super-Resolution of Satellite Imagery via Conditional Pixel Synthesis

Authors: Yutong He, Dingjie Wang, Nicholas Lai, William Zhang, Chenlin Meng, Marshall Burke, David B. Lobell, Stefano Ermon


Links: Paper | Video | Website

Keywords: remote sensing, super-resolution, generative models

Scatterbrain: Unifying Sparse and Low-rank Attention

Authors: Beidi Chen*, Tri Dao*, Eric Winsor, Zhao Song, Atri Rudra, Christopher Ré


Links: Paper

Keywords: efficient attention, sparse, low-rank

BCD Nets: Scalable Variational Approaches for Bayesian Causal Discovery

Authors: Chris Cundy, Aditya Grover, Stefano Ermon


Keywords: causal inference, variational inference

Calibrating Predictions to Decisions: A Novel Approach to Multi-Class Calibration

Authors: Shengjia Zhao, Michael P Kim, Roshni Sahoo, Tengyu Ma, Stefano Ermon


Links: Paper

Keywords: calibration, decision making under uncertainty

Beyond Pinball Loss: Quantile Methods for Calibrated Uncertainty Quantification

Authors: Youngseog Chung, Willie Neiswanger, Ian Char, Jeff Schneider


Links: Paper | Website

Keywords: uncertainty quantification, uq, quantile regression, pinball loss

Causal Abstractions of Neural Networks

Authors: Atticus Geiger*, Hanson Lu*, Thomas Icard, Christopher Potts


Links: Paper

Keywords: interpretability, analysis, nlp, causality

Generalized Shape Metrics on Neural Representations

Authors: Alex H Williams, Erin Kunz, Simon Kornblith, Scott Linderman


Keywords: representational similarity analysis, neural representations, shape analysis, metric space

D2C: Diffusion-Denoising Models for Few-shot Conditional Generation

Authors: Abhishek Sinha*, Jiaming Song*, Chenlin Meng, Stefano Ermon


Links: Paper | Website

Keywords: generative modeling, contrastive learning, conditional generation

Combiner: Full Attention Transformer with Sparse Computation Cost

Authors: Hongyu Ren, Hanjun Dai, Zihang Dai, Mengjiao Yang, Jure Leskovec, Dale Schuurmans, Bo Dai


Links: Paper

Keywords: efficient transformer

Maximum Likelihood Training of Score-Based Diffusion Models

Authors: Yang Song, Conor Durkan, Iain Murray, Stefano Ermon


Award nominations: Spotlight presentation

Links: Paper

Keywords: score-based generative models, denoising score matching, diffusion models, maximum likelihood training

Contrastive Reinforcement Learning of Symbolic Reasoning Domains

Authors: Gabriel Poesia, WenXin Dong, Noah Goodman


Keywords: reinforcement learning, education, contrastive learning, symbolic reasoning

Equivariant Manifold Flows

Authors: Isay Katsman, Aaron Lou, Derek Lim, Qingxuan Jiang, Ser Nam Lim, Christopher M. De Sa


Links: Paper | Website

Keywords: manifold, normalizing flow, equivariant, invariant

Lower Bounds on Metropolized Sampling Methods for Well-Conditioned Distributions

Authors: Yin Tat Lee, Ruoqi Shen, Kevin Tian


Award nominations: Oral presentation

Links: Paper | Video

Keywords: sampling, lower bounds, langevin dynamics, hamiltonian monte carlo

List-Decodable Mean Estimation in Nearly-PCA Time

Authors: Ilias Diakonikolas, Daniel M. Kane, Daniel Kongsgaard, Jerry Li, Kevin Tian


Award nominations: Spotlight presentation

Links: Paper

Keywords: robust statistics, semidefinite programming, mixture models

Robust Regression Revisited: Acceleration and Improved Estimation Rates

Authors: Arun Jambulapati, Jerry Li, Tselil Schramm, Kevin Tian


Links: Paper

Keywords: robust statistics, regression, generalized linear models, acceleration, sum of squares methods

Learning with User-Level Privacy

Authors: Daniel Levy*, Ziteng Sun*, Kareem Amin, Satyen Kale, Alex Kulesza, Mehryar Mohri, Ananda Theertha Suresh


Links: Paper

Keywords: differential privacy, user-level

Adapting to Function Difficulty and Growth Conditions in Private Optimization

Authors: Hilal Asi*, Daniel Levy*, John C. Duchi


Links: Paper

Keywords: differential privacy, adaptivity, optimization

Imitation with Neural Density Models

Authors: Kuno Kim, Akshat Jindal, Yang Song, Jiaming Song, Yanan Sui, Stefano Ermon


Links: Paper

Keywords: rl, imitation learning, density estimation

Why Do Pretrained Language Models Help in Downstream Tasks? An Analysis of Head and Prompt Tuning

Authors: Colin Wei, Sang Michael Xie, Tengyu Ma


Links: Paper

Keywords: nlp pretraining, theoretical analysis

Safe Reinforcement Learning by Imagining the Near Future

Authors: Garrett Thomas, Yuping Luo, Tengyu Ma


Links: Paper

Keywords: safe exploration, model-based rl

Pseudo-Spherical Contrastive Divergence

Authors: Lantao Yu, Jiaming Song, Yang Song, Stefano Ermon


Links: Paper

Keywords: deep generative models, energy-based models, proper scoring rules

IQ-Learn: Inverse soft-Q Learning for Imitation

Authors: Divyansh Garg, Shuvam Chakraborty, Chris Cundy, Jiaming Song, Stefano Ermon


Award nominations: Spotlight

Links: Paper | Website

Keywords: reinforcement learning, imitation learning, inverse reinforcement learning, statistical learning, energy-based models

Intrinsic Dimension, Persistent Homology and Generalization in Neural Networks

Authors: Tolga Birdal, Aaron Lou, Leonidas Guibas, Umut Simsekli


Links: Paper | Website

Keywords: generalization, persistent homology, intrinsic dimension, deep networks

Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval

Authors: Omar Khattab, Christopher Potts, Matei Zaharia


Award nominations: Spotlight paper

Links: Paper | Blog Post

Keywords: neural retrieval, multi-hop question answering, claim verification, reasoning, colbert

Datasets and Benchmarks Track


This year, multiple members of the SAIL community are also involved in great workshops that will take place on Dec 13-14. We hope you’ll check them out!

Machine Learning for Structural Biology Workshop (Dec 13)

Organizers: Namrata Anand, Bonnie Berger, Wouter Boomsma, Erika DeBenedictis, Stephan Eismann, John Ingraham, Sergey Ovchinnikov, Roshan Rao, Raphael Townshend and Ellen Zhong

Controllable Generative Modeling in Language and Vision (CtrlGen Workshop) (Dec 13)

Organizers: Steven Y. Feng, Drew A. Hudson, Anusha Balakrishnan, Varun Gangal, Dongyeop Kang, Tatsunori Hashimoto and Joel Tetreault

DistShift Workshop (Dec 13)

Organizers: Shiori Sagawa, Pang Wei Koh, Fanny Yang, Hongseok Namkoong, Jiashi Feng, Kate Saenko, Percy Liang, Sarah Bird and Sergey Levine

Data-centric AI Workshop (Dec 14)

Organizers: Andrew Ng, Lora Aroyo, Cody Coleman, Greg Diamos, Vijay Janapa Reddi, Joaquin Vanschoren, Carole-Jean Wu and Sharon Zhou

Physical Reasoning and Inductive Biases for the Real World Workshop (Dec 14)

Organizers: Krishna Murthy Jatavallabhula, Rika Antonova, Kevin Smith, Hsiao-Yu (Fish) Tung, Florian Shkurti, Jeannette Bohg and Josh Tenenbaum

Workshop Papers

  • How Does Contrastive Pre-training Connect Disparate Domains? by Kendrick Shen*, Robbie Jones*, Ananya Kumar*, Sang Michael Xie*, Percy Liang (DistShift Workshop)
  • Optimal Representations for Covariate Shifts by Yann Dubois, Yangjun Ruan, Chris J. Maddison (DistShift Workshop)
  • Correct-N-Contrast: a Contrastive Approach for Improving Robustness to Spurious Correlations by Michael Zhang, Nimit S. Sohoni, Hongyang R. Zhang, Chelsea Finn, Christopher Ré (DistShift Workshop)
  • Calibrated Ensembles: A Simple Way to Mitigate ID-OOD Accuracy Tradeoffs by Ananya Kumar, Aditi Raghunathan, Tengyu Ma, Percy Liang (DistShift Workshop)
  • Sharp Bounds for Federated Averaging (Local SGD) and Continuous Perspective by Margalit Glasgow*, Honglin Yuan*, Tengyu Ma (New Frontiers in Federated Learning)
  • What Matters in Learning from Offline Human Demonstrations for Robot Manipulation | Blog Post | Video | Website by Ajay Mandlekar, Danfei Xu, Josiah Wong, Soroush Nasiriany, Chen Wang, Rohun Kulkarni, Li Fei-Fei, Silvio Savarese, Yuke Zhu, Roberto Martín-Martín (Offline Reinforcement Learning Workshop)
  • An Algorithmic Theory of Metacognition in Minds and Machines | Blog Post by Rylan Schaeffer (Metacognition in the Age of AI: Challenges and Opportunities)
  • Beyond Ads: Sequential Decision-Making Algorithms in Public Policy by Peter Henderson, Ben Chugg, Brandon Anderson, Daniel E. Ho (Workshop on Causal Inference Challenges in Sequential Decision Making)
  • Tracking Urbanization in Developing Regions with Remote Sensing Spatial-Temporal Super-Resolution by Yutong He*, William Zhang*, Chenlin Meng, Marshall Burke, David B. Lobell, Stefano Ermon (Workshop on Machine Learning for the Developing World (ML4D))
  • Likelihood-free Density Ratio Acquisition Functions are not Equivalent to Expected Improvements by Jiaming Song, Stefano Ermon (Bayesian Deep Learning Workshop)
  • Exploiting Proximity Search and Easy Examples to Select Rare Events by Daniel Kang, Alex Derhacobian, Kaoru Tsuji, Trevor Hebert, Peter Bailis, Tadashi Fukami, Tatsunori Hashimoto, Yi Sun, Matei Zaharia (Data Centric AI workshop)

We look forward to seeing you at NeurIPS 2021!


Stanford AI Lab Papers at EMNLP/CoNLL 2021


The 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP 2021) will take place next week, colocated with CoNLL 2021. We’re excited to share all the work from SAIL that will be presented, and you’ll find links to papers, videos and blogs below. Feel free to reach out to the contact authors directly to learn more about the work that’s happening at Stanford!

List of Accepted Papers

Calibrate your listeners! Robust communication-based training for pragmatic speakers

Authors: Rose E. Wang, Julia White, Jesse Mu, Noah D. Goodman


Links: Paper | Video

Keywords: language generation, pragmatics, communication-based training, calibration, uncertainty

Cross-Domain Data Integration for Named Entity Disambiguation in Biomedical Text

Authors: Maya Varma, Laurel Orr, Sen Wu, Megan Leszczynski, Xiao Ling, Christopher Ré


Links: Paper | Video

Keywords: named entity disambiguation, biomedical text, rare entities, data integration

ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts

Authors: Yuta Koreeda, Christopher D. Manning


Links: Paper | Website

Keywords: natural language inference, contract, law, legal, dataset

Venue: The Findings of EMNLP 2021

The Emergence of the Shape Bias Results from Communicative Efficiency

Authors: Eva Portelance, Michael C. Frank, Dan Jurafsky, Alessandro Sordoni, Romain Laroche


Links: Paper | Website

Keywords: emergent communication, shape bias, multi-agent reinforcement learning, language learning, language acquisition

Conference: CoNLL

LM-Critic: Language Models for Unsupervised Grammatical Error Correction

Authors: Michihiro Yasunaga, Jure Leskovec, Percy Liang


Links: Paper | Blog Post | Website

Keywords: language model, grammatical error correction, unsupervised translation

Sensitivity as a complexity measure for sequence classification tasks

Authors: Michael Hahn, Dan Jurafsky, Richard Futrell


Links: Paper

Keywords: decision boundaries, computational complexity

Distributionally Robust Multilingual Machine Translation

Authors: Chunting Zhou*, Daniel Levy*, Marjan Ghazvininejad, Xian Li, Graham Neubig


Keywords: machine translation, robustness, distribution shift, dro, cross-lingual transfer

Learning from Limited Labels for Long Legal Dialogue

Authors: Jenny Hong, Derek Chong, Christopher D. Manning


Keywords: legal nlp, information extraction, weak supervision

Capturing Logical Structure of Visually Structured Documents with Multimodal Transition Parser

Authors: Yuta Koreeda, Christopher D. Manning


Links: Paper | Website

Keywords: legal, preprocessing

Workshop: Natural Legal Language Processing Workshop

We look forward to seeing you at EMNLP/CoNLL 2021!


Stanford AI Lab Papers at CoRL 2021


The Conference on Robot Learning (CoRL 2021) will take place next week. We’re excited to share all the work from SAIL that will be presented, and you’ll find links to papers, videos and blogs below. Feel free to reach out to the contact authors directly to learn more about the work that’s happening at Stanford!

List of Accepted Papers

LILA: Language-Informed Latent Actions

Authors: Siddharth Karamcheti*, Megha Srivastava*, Percy Liang, Dorsa Sadigh


Keywords: natural language, shared autonomy, human-robot interaction

BEHAVIOR: Benchmark for Everyday Household Activities in Virtual, Interactive, and Ecological Environments

Authors: Sanjana Srivastava*, Chengshu Li*, Michael Lingelbach*, Roberto Martín-Martín*, Fei Xia, Kent Vainio, Zheng Lian, Cem Gokmen, Shyamal Buch, C. Karen Liu, Silvio Savarese, Hyowon Gweon, Jiajun Wu, Li Fei-Fei


Links: Paper | Website

Keywords: embodied ai, benchmarking, household activities

Co-GAIL: Learning Diverse Strategies for Human-Robot Collaboration

Authors: Chen Wang, Claudia Pérez-D’Arpino, Danfei Xu, Li Fei-Fei, C. Karen Liu, Silvio Savarese


Links: Paper | Website

Keywords: learning for human-robot collaboration, imitation learning

DiffImpact: Differentiable Rendering and Identification of Impact Sounds

Authors: Samuel Clarke, Negin Heravi, Mark Rau, Ruohan Gao, Jiajun Wu, Doug James, Jeannette Bohg


Links: Paper | Website

Keywords: differentiable sound rendering, auditory scene analysis

Example-Driven Model-Based Reinforcement Learning for Solving Long-Horizon Visuomotor Tasks

Authors: Bohan Wu, Suraj Nair, Li Fei-Fei*, Chelsea Finn*


Links: Paper

Keywords: model-based reinforcement learning, long-horizon tasks

GRAC: Self-Guided and Self-Regularized Actor-Critic

Authors: Lin Shao, Yifan You, Mengyuan Yan, Shenli Yuan, Qingyun Sun, Jeannette Bohg


Links: Paper | Website

Keywords: deep reinforcement learning, q-learning

Influencing Towards Stable Multi-Agent Interactions

Authors: Woodrow Z. Wang, Andy Shih, Annie Xie, Dorsa Sadigh


Award nominations: Oral presentation

Links: Paper | Website

Keywords: multi-agent interactions, human-robot interaction, non-stationarity

Learning Language-Conditioned Robot Behavior from Offline Data and Crowd-Sourced Annotation

Authors: Suraj Nair, Eric Mitchell, Kevin Chen, Brian Ichter, Silvio Savarese, Chelsea Finn


Links: Paper | Website

Keywords: natural language, offline rl, visuomotor manipulation

Learning Multimodal Rewards from Rankings

Authors: Vivek Myers, Erdem Bıyık, Nima Anari, Dorsa Sadigh


Links: Paper | Video | Website

Keywords: reward learning, active learning, learning from rankings, multimodality

Learning Reward Functions from Scale Feedback

Authors: Nils Wilde*, Erdem Bıyık*, Dorsa Sadigh, Stephen L. Smith


Links: Paper | Video | Website

Keywords: preference-based learning, reward learning, active learning, scale feedback

Learning to Regrasp by Learning to Place

Authors: Shuo Cheng, Kaichun Mo, Lin Shao


Links: Paper | Website

Keywords: regrasping, object placement, robotic manipulation

Learning to be Multimodal : Co-evolving Sensory Modalities and Sensor Properties

Authors: Rika Antonova, Jeannette Bohg


Links: Paper

Keywords: co-design, multimodal sensing, corl blue sky track

O2O-Afford: Annotation-Free Large-Scale Object-Object Affordance Learning

Authors: Kaichun Mo, Yuzhe Qin, Fanbo Xiang, Hao Su, Leonidas J. Guibas


Links: Paper | Video | Website

Keywords: robotic vision, object-object interaction, visual affordance

ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations

Authors: Ruohan Gao, Yen-Yu Chang, Shivani Mall, Li Fei-Fei, Jiajun Wu


Links: Paper | Video | Website

Keywords: object dataset, multisensory learning, implicit representations

Taskography: Evaluating robot task planning over large 3D scene graphs

Authors: Christopher Agia, Krishna Murthy Jatavallabhula, Mohamed Khodeir, Ondrej Miksik, Vibhav Vineet, Mustafa Mukadam, Liam Paull, Florian Shkurti


Links: Paper | Website

Keywords: robot task planning, 3d scene graphs, learning to plan, benchmarks

What Matters in Learning from Offline Human Demonstrations for Robot Manipulation

Authors: Ajay Mandlekar, Danfei Xu, Josiah Wong, Soroush Nasiriany, Chen Wang, Rohun Kulkarni, Li Fei-Fei, Silvio Savarese, Yuke Zhu, Roberto Martín-Martín


Award nominations: Oral

Links: Paper | Blog Post | Video | Website

Keywords: imitation learning, offline reinforcement learning, robot manipulation

XIRL: Cross-embodiment Inverse Reinforcement Learning

Authors: Kevin Zakka, Andy Zeng, Pete Florence, Jonathan Tompson, Jeannette Bohg, Debidatta Dwibedi


Links: Paper | Website

Keywords: inverse reinforcement learning, imitation learning, self-supervised learning

iGibson 2.0: Object-Centric Simulation for Robot Learning of Everyday Household Tasks

Authors: Chengshu Li*, Fei Xia*, Roberto Martín-Martín*, Michael Lingelbach, Sanjana Srivastava, Bokui Shen, Kent Vainio, Cem Gokmen, Gokul Dharan, Tanish Jain, Andrey Kurenkov, C. Karen Liu, Hyowon Gweon, Jiajun Wu, Li Fei-Fei, Silvio Savarese


Links: Paper | Website

Keywords: simulation environment, embodied ai, virtual reality interface

Learning Feasibility to Imitate Demonstrators with Different Dynamics

Authors: Zhangjie Cao, Yilun Hao, Mengxi Li, Dorsa Sadigh


Keywords: imitation learning, learning from agents with different dynamics

We look forward to seeing you at CoRL 2021!

Selective Classification Can Magnify Disparities Across Groups

Selective classification, where models are allowed to “abstain” when they are uncertain about a prediction, is a useful approach for deploying models in settings where errors are costly. For example, in medicine, model errors can have life-or-death ramifications, but abstentions can be easily handled by backing off to a doctor, who then makes a diagnosis. Across a range of applications from vision [1, 2, 3] and NLP [4, 5], even simple selective classifiers, relying only on model logits, routinely and often dramatically improve accuracy by abstaining. This makes selective classification a compelling tool for ML practitioners [6, 7].

However, in our recent ICLR paper, we find that despite reliably improving average accuracy, selective classification can fail to improve and even hurt the accuracy over certain subpopulations of the data. As a motivating example, consider the task of diagnosing pleural effusion, or fluid in the lungs, from chest X-rays. Pleural effusion is often treated with a chest drain, so many pleural effusion cases also have chest drains, while most cases without pleural effusion do not have chest drains [8]. While selective classification improves average accuracy for this task, we find that it does not appreciably improve accuracy on the most clinically relevant subgroup, or subpopulation, of the data: those that have pleural effusion but don’t yet have a chest drain, i.e. those that have pleural effusion but have not yet been treated for it. Practitioners, thus, should be wary of these potential failure modes of using selective classification in the wild.

Example of the spurious correlation setup. This patient has a pleural effusion (excess fluid in the lung), but does not yet have a chest drain. The model, relying on the presence of a chest drain to make a prediction, incorrectly predicts negative.

To further outline this critical failure mode of selective classification, we’ll first provide an overview of selective classification. We then demonstrate empirically that selective classification can hurt or fail to significantly improve accuracy on certain subgroups of the data. We next outline our theoretical results, which suggest that selective classification is rarely a good tool to resolve differences in accuracy between subgroups. Finally, we suggest methods for building more equitable selective classifiers.

Selective classification basics

Imagine you are trying to build a model that classifies X-rays as either pleural effusion positive or negative. With standard classification, the model is required to either output positive or negative on each input. In contrast, a selective classifier can additionally abstain from making a prediction when it is not sufficiently confident in any class [9, 10, 11]. By abstaining, selective classifiers aim to avoid making predictions on examples they are likely to classify incorrectly, say a corrupted or difficult-to-classify X-ray, which increases their average accuracy.

Selective classification pipeline. The model makes the incorrect prediction of negative. However, the outputted confidence of 0.7 is less than the confidence threshold of 0.8, so the selective classifier abstains. Selective classifiers increase accuracy by abstaining on examples they would get wrong.

One key question in selective classification is how to choose which examples to abstain on. Selective classifiers can be viewed as two models: one that outputs a prediction (say, negative), and another that outputs a confidence in that prediction (say, 0.7 out of 1). Whenever the confidence is above a certain (confidence) threshold, the selective classifier outputs the original prediction; for example, if the threshold were 0.6, the selective classifier would predict negative. Otherwise, the selective classifier abstains. In our work, we primarily use softmax response [11] to extract confidences: the confidence in a prediction is simply the maximum softmax probability over the possible classes.
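The softmax-response rule described above can be sketched in a few lines. This is a minimal illustration rather than the paper's implementation: the function names and the two-class logits are made up for the example, and treating a confidence exactly equal to the threshold as "predict" is an assumption.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def selective_predict(logits, threshold):
    """Return (prediction, confidence); prediction is None when abstaining.

    Confidence is the softmax response: the maximum softmax probability.
    Assumption: a confidence equal to the threshold counts as confident enough.
    """
    probs = softmax(logits)
    confidence = max(probs)
    prediction = probs.index(confidence)
    if confidence >= threshold:
        return prediction, confidence
    return None, confidence

# Hypothetical two-class example: index 0 = negative, index 1 = positive.
# These logits give softmax probabilities of roughly (0.7, 0.3).
logits = [math.log(0.7), math.log(0.3)]
print(selective_predict(logits, threshold=0.6))  # confident enough: predicts class 0
print(selective_predict(logits, threshold=0.8))  # not confident enough: abstains
```

With a confidence of about 0.7, the same model output leads to a prediction at threshold 0.6 and an abstention at threshold 0.8, matching the two cases discussed in the text.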

Selective classifiers are typically measured in terms of the accuracy (also called selective accuracy) on predicted examples, and the coverage, or fraction of examples the selective classifier makes predictions on [12]. We can tweak both coverage and accuracy by adjusting the confidence threshold: a lower threshold for making predictions increases the coverage, since the model’s confidence for more examples is sufficiently high. However, this tends to lower average accuracy, as the model is less confident on average in its predictions. In contrast, higher thresholds increase confidence required to make a prediction, reducing the coverage but generally increasing average accuracy.

Typically, researchers measure the performance of selective classifiers by plotting accuracy as a function of coverage. In particular, for each possible coverage (ranging from 0: abstain on everything to 1: predict on everything) they compute the maximum threshold that achieves that coverage, and then plot the accuracy at that threshold. One particularly useful reference point is the full-coverage accuracy: the accuracy of the selective classifier at coverage 1, which is the accuracy of the regular classifier.
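As a rough sketch of how such a curve can be computed, the snippet below sorts examples by confidence and records the selective accuracy at each achievable coverage. The function name and the toy confidences are illustrative assumptions, and ties between confidences are ignored for simplicity.

```python
def accuracy_coverage_curve(confidences, correct):
    """Return a list of (coverage, selective_accuracy) points.

    Examples are sorted from most to least confident. Predicting on the
    top-k examples corresponds to setting the threshold at the k-th
    highest confidence (assuming no ties).
    """
    order = sorted(range(len(confidences)), key=lambda i: -confidences[i])
    n = len(order)
    points = []
    num_correct = 0
    for k, i in enumerate(order, start=1):
        num_correct += int(correct[i])
        points.append((k / n, num_correct / k))
    return points

# Toy data invented for illustration.
confidences = [0.95, 0.9, 0.8, 0.7, 0.6]
correct = [True, True, False, True, False]
for coverage, acc in accuracy_coverage_curve(confidences, correct):
    print(f"coverage={coverage:.1f}  accuracy={acc:.2f}")
```

The last point of the curve (coverage 1) is the full-coverage accuracy, i.e. the accuracy of the underlying classifier with no abstentions.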

For five datasets, we plot the average accuracy as a function of the coverage. Reading from high coverages to low coverages (right to left), as the confidence threshold increases, accuracy reliably increases. This is expected, since the model is more confident on average in its predictions at lower coverage, so more of them tend to be correct.

Selective classification can magnify accuracy disparities between subgroups

While prior work mostly focuses on average accuracy for selective classifiers, we instead focus on the accuracy of different subgroups of the data. In particular, we focus on datasets where models often latch onto spurious correlations. For example, in the above pleural effusion task, the model might learn to predict whether or not there is a chest drain, instead of directly diagnosing pleural effusion, because chest drains are highly correlated with pleural effusion; this correlation is spurious because not all pleural effusions have a chest drain. We consider subgroups that highlight this spurious correlation: two groups for when the spurious correlation gives the correct result (positive pleural effusion with chest drain, negative pleural effusion without a chest drain), and two groups when it does not (positive pleural effusion with no chest drain, negative pleural effusion with a chest drain). As a result, a model that learns this spurious correlation obtains high accuracy for the first two subgroups, but low accuracy for the latter two.

In principle, selective classification seems like a reasonable approach towards resolving these accuracy discrepancies between different subgroups of the data. Since we empirically see that selective classification reliably improves average accuracy, it must be more likely to cause a model to abstain when an example would be classified incorrectly. Incorrect examples disproportionately come from the lowest-accuracy subgroups of the data, suggesting that without bias in the confidence function, worst-group accuracy should increase faster than average accuracy.

To test this, we plot the accuracy-coverage curves over a range of tasks, including hair color classification (CelebA), bird type classification (Waterbirds), pleural effusion classification (CheXpert-device), toxicity classification (CivilComments) and natural language inference (MultiNLI). CelebA, Waterbirds, and MultiNLI use the same spurious correlation setup presented in [2]. CivilComments exhibits the same spurious correlations as described in the WILDS benchmark [13]. Finally, we created the CheXpert-device dataset by subsampling the original CheXpert dataset [3] such that the presence of a chest drain even more strongly correlates with pleural effusion.

Reading from right to left, while we see that as the coverage decreases the average accuracy reliably increases, the worst-group accuracies do not always increase, and exhibit a range of undesirable behaviors. On CelebA, worst-group accuracy actually decreases: this means the more confident predictions are more likely to be incorrect. For Waterbirds, CheXpert-device, and CivilComments, worst-group accuracy sometimes increases, but never by more than 10 points until the noisy low-coverage regime, and sometimes decreases. For MultiNLI, worst-group accuracy does slowly improve, but can’t even reach 80% until very low coverages.

These results highlight that practitioners should be wary: even if selective classification reliably increases average accuracy, it will not necessarily improve the accuracy of different subgroups.
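To make the distinction between average and per-group behavior concrete, here is a small sketch that computes selective accuracy separately for each subgroup at a given threshold. The data below is invented purely for illustration (it is not from the paper), and the helper name is hypothetical.

```python
def selective_accuracy_by_group(confidences, correct, groups, threshold):
    """Selective accuracy per group, over examples with confidence >= threshold."""
    kept = {}
    for conf, ok, g in zip(confidences, correct, groups):
        if conf >= threshold:
            hits, total = kept.get(g, (0, 0))
            kept[g] = (hits + int(ok), total + 1)
    return {g: hits / total for g, (hits, total) in kept.items()}

# Toy data: the minority group contains a confidently wrong prediction.
confidences = [0.9, 0.9, 0.6, 0.6, 0.9, 0.6, 0.6, 0.6]
correct = [True, True, True, False, False, True, True, False]
groups = ["majority"] * 4 + ["minority"] * 4

full = selective_accuracy_by_group(confidences, correct, groups, threshold=0.0)
selective = selective_accuracy_by_group(confidences, correct, groups, threshold=0.8)
print(full)       # per-group accuracy at full coverage
print(selective)  # per-group accuracy after abstaining below 0.8
```

In this toy setup, average accuracy rises from 5/8 to 2/3 after abstaining, while accuracy on the minority group falls from 0.5 to 0.0, which is the kind of pattern the accuracy-coverage curves above can hide when only the average is plotted.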

Selective classification rarely overcomes accuracy disparities

To better understand why selective classification can sometimes hurt worst-group accuracy and does not reduce full-coverage accuracy disparities, we theoretically characterize for a broad class of distributions: (1) when does selective classification improve accuracy as the confidence threshold increases and (2) when does selective classification disproportionately help the worst group.

At a high level, our analysis focuses on the margin, or the model’s confidence for a given prediction multiplied by -1 if that prediction was incorrect. Intuitively, the more negative the margin, the “worse” the prediction. Using only the margin distribution, we can recreate the accuracy-coverage curve by abstaining on density between the negative and positive threshold, and computing the fraction of remaining density that is correct.

The key result of our theoretical analysis is that the full-coverage accuracy of a subgroup dramatically impacts how well selective classification performs on that subgroup, which amplifies disparities. For a wide range of margin distributions, full-coverage accuracy and a property of the margin distribution we call left-log-concavity completely determine whether or not the accuracy of a selective classifier monotonically increases or decreases. When a margin distribution is left-log-concave, which many standard distributions (e.g. Gaussians) are, accuracy monotonically increases when full-coverage accuracy is at least 50% and decreases otherwise.
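As a worked illustration of this result, consider a Gaussian margin distribution. The sketch below assumes the margin is distributed as N(mu, sigma^2) and that the classifier abstains on margins strictly between -tau and tau; the parameter values and function names are chosen only for the example.

```python
import math

def normal_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def selective_accuracy_gaussian(mu, sigma, tau):
    """Selective accuracy when margin ~ N(mu, sigma^2) and we abstain on (-tau, tau).

    Kept predictions with margin >= tau are correct; those with margin <= -tau
    are incorrect. Accuracy is the correct mass over the total kept mass.
    """
    p_correct = 1.0 - normal_cdf((tau - mu) / sigma)  # P(margin >= tau)
    p_wrong = normal_cdf((-tau - mu) / sigma)         # P(margin <= -tau)
    return p_correct / (p_correct + p_wrong)

# mu > 0 gives full-coverage accuracy above 50%; mu < 0 gives below 50%.
for mu in (1.0, -1.0):
    accs = [selective_accuracy_gaussian(mu, 1.0, tau) for tau in (0.0, 0.5, 1.0, 2.0)]
    print(f"mu={mu:+.1f}", [round(a, 3) for a in accs])
```

For the group with a positive mean margin, accuracy climbs as the threshold grows; for the group with a negative mean margin, the same abstention rule drives accuracy down, so the gap between the two groups widens.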

Next steps

So far, we have painted a fairly bleak picture of selective classification: even though it reliably improves average accuracy, it can, both theoretically and empirically, exacerbate accuracy disparities between subgroups. There are still, however, mechanisms to improve selective classification, which we outline below.

One natural step towards improving selective classification is to develop confidence functions that allow selective classifiers to overcome accuracy disparities between groups. In our paper, we test the two most widely used methods: softmax response and Monte Carlo dropout [10]. We consistently find that both are disproportionately overconfident on incorrect examples from the worst groups. However, new confidence functions that are better calibrated across groups would likely resolve disparities [14], and developing them is an important direction for future work.

In the short term, however, we find that the most promising method to improve worst-group accuracy with selective classification is to build selective classifiers on top of already-equitable models, or models that achieve similar full-coverage accuracies across the relevant subgroups. One method to train such models is group DRO, which minimizes the maximum loss over subgroups [2]. We find empirically that selective classifiers trained with group DRO improve the accuracy of subgroups at roughly the same rate when they have the same accuracy at full coverage. However, group DRO is far from a perfect fix – it requires a priori knowledge of the relevant subgroups, and subgroup labels for each training example, which may be costly to obtain. Nevertheless, it is a promising start, and developing more broadly applicable methods for training already-equitable models is a critical area for future work.
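The quantity that group DRO targets can be sketched as follows. This shows only the objective (the maximum of the per-group average losses), not the optimization procedure used in [2]; the helper names and loss values are hypothetical.

```python
def group_losses(losses, groups):
    """Average per-example loss within each group."""
    sums, counts = {}, {}
    for loss, g in zip(losses, groups):
        sums[g] = sums.get(g, 0.0) + loss
        counts[g] = counts.get(g, 0) + 1
    return {g: sums[g] / counts[g] for g in sums}

def worst_group_loss(losses, groups):
    """Group DRO-style objective: the largest of the per-group average losses."""
    return max(group_losses(losses, groups).values())

# Toy per-example losses invented for illustration.
losses = [0.1, 0.2, 0.1, 1.5, 1.1]
groups = ["a", "a", "a", "b", "b"]

average_loss = sum(losses) / len(losses)
print(f"average loss:     {average_loss:.3f}")
print(f"worst-group loss: {worst_group_loss(losses, groups):.3f}")
```

Standard training minimizes the average loss, which can stay low even when one group's loss is high; minimizing the worst-group loss instead pushes the model towards similar performance across groups. Note that computing it requires a group label for every training example, which is the cost mentioned above.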

To conclude, despite the intuition that selective classification should improve worst-group accuracy, and selective classification’s ability to consistently improve average accuracy, common selective classifiers can severely exacerbate accuracy discrepancies between subgroups. We hope our work encourages practitioners to apply selective classification with caution, and in general focus on how different methods affect different subgroups of the data.


Thanks to the SAIL blog editors, Pang Wei Koh, and Shiori Sagawa for their helpful feedback on this blog post. This post is based on our ICLR 2021 paper:

Selective Classification Can Magnify Disparities Across Groups. Erik Jones*, Shiori Sagawa*, Pang Wei Koh*, Ananya Kumar, and Percy Liang. ICLR 2021.

  1. Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3730–3738, 2015. 

  2. Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In International Conference on Learning Representations (ICLR), 2020.

  3. Jeremy Irvin, Pranav Rajpurkar, Michael Ko, Yifan Yu, Silviana Ciurea-Ilcus, Chris Chute, Henrik Marklund, Behzad Haghgoo, Robyn Ball, Katie Shpanskaya, et al. Chexpert: A large chest radiograph dataset with uncertainty labels and expert comparison. In Association for the Advancement of Artificial Intelligence (AAAI), volume 33, pp. 590–597, 2019.

  4. Daniel Borkan, Lucas Dixon, Jeffrey Sorensen, Nithum Thain, and Lucy Vasserman. Nuanced metrics for measuring unintended bias with real data for text classification. In World Wide Web (WWW), pp. 491–500, 2019. 

  5. Adina Williams, Nikita Nangia, and Samuel Bowman. A broad-coverage challenge corpus for sentence understanding through inference. In Association for Computational Linguistics (ACL), pp. 1112–1122, 2018. 

  6. Yonatan Geifman and Ran El-Yaniv. SelectiveNet: A deep neural network with an integrated reject option. In International Conference on Machine Learning (ICML), 2019.

  7. Hussein Mozannar and David Sontag. Consistent estimators for learning to defer to an expert. In International Conference on Machine Learning (ICML), 2020. 

  8. Luke Oakden-Rayner, Jared Dunnmon, Gustavo Carneiro, and Christopher Ré. Hidden stratification causes clinically meaningful failures in machine learning for medical imaging. In Proceedings of the ACM Conference on Health, Inference, and Learning, pp. 151–159, 2020. 

  9. C. K. Chow. An optimum character recognition system using decision functions. In IRE Transactions on Electronic Computers, 1957. 

  10. Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning (ICML), 2016.

  11. Yonatan Geifman and Ran El-Yaniv. Selective classification for deep neural networks. In Advances in Neural Information Processing Systems (NeurIPS), 2017.

  12. Ran El-Yaniv and Yair Wiener. On the foundations of noise-free selective classification. Journal of Machine Learning Research (JMLR), 11, 2010. 

  13. Pang Wei Koh, Shiori Sagawa, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, Tony Lee, Etienne David, Ian Stavness, Wei Guo, Berton A. Earnshaw, Imran S. Haque, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang. WILDS: A benchmark of in-the-wild distribution shifts. arXiv, 2020. 

  14. Yoav Wald, Amir Feder, Daniel Greenfeld, and Uri Shalit. On Calibration and Out-of-domain Generalization. arXiv preprint arXiv:2102.10395, 2021. 

Stanford AI Lab Papers at ICCV 2021

The International Conference on Computer Vision (ICCV 2021) will be hosted virtually next week. We’re excited to share all the work from SAIL that will be presented, and you’ll find links to papers, videos and blogs below. Feel free to reach out to the contact authors directly to learn more about the work that’s happening at Stanford!

List of Accepted Papers

GLoRIA: A Multimodal Global-Local Representation Learning Framework for Label-efficient Medical Image Recognition

Authors: Mars Huang


Keywords: medical image, self-supervised learning, multimodal fusion

3D Shape Generation and Completion Through Point-Voxel Diffusion

Authors: Linqi Zhou, Yilun Du, Jiajun Wu


Links: Paper | Video | Website

Keywords: diffusion, shape generation

CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds

Authors: Yijia Weng*, He Wang*, Qiang Zhou, Yuzhe Qin, Yueqi Duan, Qingnan Fan, Baoquan Chen, Hao Su, Leonidas J. Guibas


Award nominations: Oral Presentation

Links: Paper | Video | Website

Keywords: category-level object pose tracking, articulated objects

Detecting Human-Object Relationships in Videos

Authors: Jingwei Ji, Rishi Desai, Juan Carlos Niebles


Links: Paper

Keywords: human-object relationships, video, detection, transformer, spatio-temporal reasoning

Geography-Aware Self-Supervised Learning

Authors: Kumar Ayush, Burak Uzkent, Chenlin Meng, Kumar Tanmay, Marshall Burke, David Lobell, Stefano Ermon


Links: Paper | Website

Keywords: self-supervised learning, contrastive learning, remote sensing, spatio-temporal, classification, object detection, segmentation

HuMoR: 3D Human Motion Model for Robust Pose Estimation

Authors: Davis Rempe, Tolga Birdal, Aaron Hertzmann, Jimei Yang, Srinath Sridhar, Leonidas Guibas


Award nominations: Oral Presentation

Links: Paper | Website

Keywords: 3d human pose estimation; 3d human motion; generative modeling

Learning Privacy-preserving Optics for Human Pose Estimation

Authors: Carlos Hinojosa, Juan Carlos Niebles, Henry Arguello


Links: Paper | Website

Keywords: computational photography; fairness, accountability, transparency, and ethics in vision; gestures and body pose

Learning Temporal Dynamics from Cycles in Narrated Video

Authors: Dave Epstein, Jiajun Wu, Cordelia Schmid, Chen Sun


Links: Paper | Website

Keywords: multi-modal learning, cycle consistency, video

Vector Neurons: A General Framework for SO(3)-Equivariant Networks

Authors: Congyue Deng, Or Litany, Yueqi Duan, Adrien Poulenard, Andrea Tagliasacchi, Leonidas Guibas


Links: Paper | Video | Website

Keywords: pointcloud network, rotation equivariance, rotation invariance

Neural Radiance for 4D View Synthesis and Video Processing

Authors: Yilun Du, Yinan Zhang, Hong-Xing Yu, Joshua B. Tenenbaum, Jiajun Wu


Links: Paper | Website

Keywords: 4d representation, neural rendering, video processing

Where2Act: From Pixels to Actions for Articulated 3D Objects

Authors: Kaichun Mo, Leonidas J. Guibas, Mustafa Mukadam, Abhinav Gupta, Shubham Tulsiani


Links: Paper | Website

Keywords: 3d computer vision, robotic vision, affordance learning, robot learning

Low-Shot Validation: Active Importance Sampling for Estimating Classifier Performance on Rare Categories

Authors: Fait Poms*, Vishnu Sarukkai*, Ravi Teja Mullapudi, Nimit S. Sohoni, William R. Mark, Deva Ramanan, Kayvon Fatahalian


Links: Paper | Blog | Video

Keywords: model evaluation, active learning

We look forward to seeing you at ICCV 2021!

Building Scalable, Explainable, and Adaptive NLP Models with Retrieval

Natural language processing (NLP) has witnessed impressive developments
in answering questions, summarizing or translating reports, and
analyzing sentiment or offensiveness. Much of this progress is owed to
training ever-larger language models, such
as T5 or GPT-3,
that use deep monolithic architectures to internalize how language is
used within text from massive Web crawls. During training, these models
distill the facts they read into implicit knowledge, storing in their
parameters not only the capacity to “understand” language tasks, but
also highly abstract knowledge representations of entities, events, and
facts the model needs for solving tasks.

Despite the well-publicized success of large language models, their
black-box nature hinders key goals of NLP. In particular, existing large
language models are generally:

  • Inefficient. Researchers continue to enlarge these models, leading
    to striking inefficiencies as the field already pushes past 1
    trillion parameters. This imposes a considerable environmental impact
    and its costs exclude all but a few large organizations from the
    ability to train—or in many cases even deploy—such models.

  • Opaque. They encode “knowledge” into model weights, synthesizing
    what they manage to memorize from training examples. This makes it
    difficult to discern what sources—if any—the model uses to make a
    prediction, a concerning problem in practice as these models
    frequently generate fluent yet untrue statements.

  • Static. They are expensive to update. We cannot efficiently adapt a
    GPT model trained on, say, Wikipedia text from 2019 so it reflects
    the knowledge encoded in the 2021 Wikipedia—or the latest snapshot
    of the medical preprint server medRxiv. In practice, adaptation often
    necessitates expensive retraining or fine-tuning on the new corpus.

This post explores an emerging alternative, Retrieval-based NLP, in
which models directly “search” for information in a text corpus to
exhibit knowledge, leveraging the representational strengths of language models
while addressing the challenges above. Such
models—including REALM, RAG, ColBERT-QA,
and Baleen—are
already advancing the state of the art for tasks like answering
open-domain questions and verifying complex claims, all with
architectures that back their predictions with checkable sources while
being 100–1000× smaller, and thus far cheaper to execute, than GPT-3. At
Stanford, we have shown that improving the expressivity and
supervision of scalable neural retrievers can lead to much stronger NLP
systems: for instance, ColBERT-QA improves answer correctness on open-QA
benchmarks by up to 16 EM points and Baleen improves the ability to
check complex claims correctly and with provenance, by up to 42
percentage points against existing work.

Retrieval-based NLP

Figure 1: An illustration comparing (a) black-box language models and (b) retrieval-oriented NLP models, the paradigm this post advocates for.

As Figure 1 illustrates, retrieval-based NLP methods view tasks as
exams: knowledge is encoded explicitly in the form of a text corpus like
Wikipedia, the medical literature, or a software’s API documentation. When
solving a language task, the model learns to search for pertinent passages
and to then use the retrieved information for crafting knowledgeable responses.
In doing so, retrieval helps decouple the capacity that language models have for
understanding text from how they store knowledge, leading to three key advantages.

Tackling Inefficiency. Retrieval-based models can be much smaller,
and thus more environmentally friendly. Unlike black-box language models,
the parameters no longer need to store an ever-growing list of facts, as
such facts can be retrieved. Instead, we can dedicate those parameters
for processing language and solving tasks, leaving us with smaller
models that are highly effective. For instance, ColBERT-QA achieves
47.8% EM on the open-domain Natural Questions task, whereas a fine-tuned
T5-11B model (with 24x more parameters) and a few-shot GPT-3 model (with
400x more parameters) achieve only 34.8% and 29.9%, respectively.

Tackling Opaqueness. Retrieval-based NLP offers a transparent contract
with users: when the model produces an answer, we can read the sources
it retrieved and judge their relevance and credibility for ourselves.
This is essential whether the model is factually correct or not: by
inspecting the sources surfaced by a system like Baleen, we can trust
its outputs only if we find that reliable sources do support them.

Tackling Static Knowledge. Retrieval-based models emphasize learning
general techniques for finding and connecting information from the
available resources. With facts stored as text, the retrieval knowledge
store can be efficiently updated or expanded by modifying the text
corpus, all while the model’s capacity for finding and using information
remains constant. Besides computational cost reductions, this expedites generality:
developers, even in niche domains, can “plug in” a domain-specific text
collection and rely on retrieval to facilitate domain-aware responses.

ColBERT: Scalable yet expressive neural retrieval

As the name suggests, retrieval-based NLP relies on semantically rich search to extract
information. For search to be practical and effective, it must scale to massive text corpora.
To draw on the open-book exam analogy, it’s hopeless to linearly look
through the pages of a hefty textbook during the exam—we need scalable
strategies for organizing the content in advance, and efficient
techniques for locating relevant information at inference time.

Figure 2: Schematic diagrams comparing two popular paradigms in neural IR in sub-figures (a) and (b) against the late interaction paradigm of ColBERT in sub-figure (c).

Traditionally in IR, search tasks were conducted using bag-of-words
models like BM25, which seek documents that contain the same tokens as
the query. In 2019, search was revolutionized with BERT for ranking and
its deployment in Google and Bing for Web search. The standard approach
is illustrated in Figure 2(a). Each document is concatenated with the
query, and both are fed jointly into a BERT model, fine-tuned to
estimate relevance. BERT doubled the MRR@10 quality metric over BM25 on
the popular MS MARCO Passage Ranking leaderboard, but it simultaneously
posed a fundamental limitation: scoring each query–document pair
requires billions of computational operations (FLOPs). As a result,
BERT can only be used to re-rank the top-k (e.g., top-1000) documents
already extracted by simpler methods like BM25, having no capacity to
recover useful documents that bag-of-words search misses.

The key limitation of this approach is that it encodes queries and
documents jointly. Many representation-similarity systems have been
proposed to tackle this, some of which re-purpose BERT within the
paradigm depicted in Figure 2(b). In these systems
(like SBERT and ORQA, and more recently DPR and ANCE),
every document in the corpus is fed into a BERT encoder that produces a
dense vector meant to capture the semantics of the document. At search
time, the query is encoded, separately, through another BERT encoder, and the
top-k related documents are found using a dot product between the query
and document vectors. By removing the expensive interactions between the
query and the document, these models are able to scale far more
efficiently than the approach in Figure 2(a).
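
To make the representation-similarity paradigm concrete, here is a minimal sketch of top-k search over precomputed single-vector document representations. The vectors are hand-written toy values standing in for encoder outputs, and the helper names `dot` and `top_k` are illustrative assumptions rather than part of any of the systems named above.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def top_k(query_vec: Vector, doc_vecs: Dict[str, Vector], k: int) -> List[Tuple[str, float]]:
    """Rank precomputed document vectors by dot product with the query vector."""
    scored = [(doc_id, dot(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy vectors (assumed values); a real system would produce these with an encoder
# and would use an approximate nearest-neighbor index instead of a full scan.
doc_vecs = {
    "d1": [0.9, 0.1, 0.0],
    "d2": [0.1, 0.8, 0.1],
    "d3": [0.0, 0.2, 0.9],
}
query_vec = [1.0, 0.2, 0.0]

results = top_k(query_vec, doc_vecs, k=2)
print(results)
```

Because each document is reduced to one vector ahead of time, query-time work is a single similarity computation per candidate, which is where the scalability of this paradigm comes from.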

Nonetheless, representation-similarity models suffer from an
architectural bottleneck: they encode the query and document into
coarse-grained representations and model relevance as a single dot
product. This greatly diminishes quality compared with expensive
re-rankers that model token-level interactions between the contents of
queries and documents. Can we efficiently scale fine-grained, contextual
interactions to a massive corpus, without compromising speed or quality?
It turns out that the answer is “yes”, using a paradigm called late
interaction, first devised in
our ColBERT1 model, which appeared at SIGIR 2020.

As depicted in Figure 2(c), ColBERT independently encodes queries and
documents into fine-grained multi-vector representations. It then
attempts to softly and contextually locate each query token inside the
document: for each query embedding, it finds the most similar embedding
in the document with a “MaxSim” operator and then sums up all of the
MaxSims to score the document. “MaxSim” is a careful choice that allows
us to index the document embeddings for Approximate Nearest Neighbor
(ANN) search, enabling us to scale this rich interaction to millions of passages with latency
on the order of tens of milliseconds. For instance, ColBERT can search over all
passages in English Wikipedia in approximately 70 milliseconds per query.
On MS MARCO Passage Ranking, ColBERT preserved the MRR@10 quality of BERT re-rankers while boosting recall@1k to nearly 97%
against the official BM25 ranking’s recall@1k of just 81%.
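
The MaxSim scoring described above can be sketched in a few lines. This is an illustration only: the embeddings are toy two-dimensional values, similarity is taken to be a plain dot product (which assumes the embeddings are already normalized), and `maxsim_score` is a name chosen for this example.

```python
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_embs: List[Vector], doc_embs: List[Vector]) -> float:
    """For each query embedding, take its best match in the document, then sum."""
    return sum(max(dot(q, d) for d in doc_embs) for q in query_embs)


# Toy multi-vector representations: one embedding per token (assumed values).
query_embs = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
doc_b = [[0.9, 0.1], [0.7, 0.3]]

score_a = maxsim_score(query_embs, doc_a)
score_b = maxsim_score(query_embs, doc_b)
print(score_a, score_b)
```

In this toy case `doc_a` scores higher because it contains a close match for both query embeddings, while `doc_b` only matches the first one well.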

Making neural retrievers more lightweight remains an active area of
development, with models like DeepImpact
that trade away some quality for extreme forms of efficiency and
developments like BPR
and quantized ColBERT
that reduce the storage footprint by an order of magnitude while
preserving the quality of DPR and ColBERT, respectively.

ColBERT-QA and Baleen: Specializing neural retrieval to complex tasks, with tracked provenance

While scaling expressive search mechanisms is critical, NLP models need
more than just finding the right documents. In particular, we want NLP models
to use retrieval to answer questions, fact-check claims, respond
informatively in a conversation, or identify the sentiment of a piece of
text. Many tasks of this kind—dubbed knowledge-intensive language
tasks—are collected in
the KILT benchmark.
The most popular task is open-domain question answering (or Open-QA).
Systems are given a question from any domain and must produce an answer,
often by reference to the passages in a large corpus, as depicted in
Figure 1(b).

Benchmark | System | Metric | Gains | Baselines
Open-Domain Question Answering
Open-NaturalQuestions | ColBERT-QA | Answer Match | +3 | RAG, DPR, REALM, BM25+BERT
Open-TriviaQA | | | +12 |
Open-SQuAD | | | +17 |
Multi-Hop Reasoning
HotPotQA | Baleen | Retrieval Success@20 | +10 / NA | MDR / IRRR
 | | Passage-Pair Match | +5 / +3 |
HoVer | | Retrieval Success@100 | +48 / +17 | TF-IDF / ColBERT-Hop
 | | “HoVer Score” for Claim Verification with Provenance | +42 | Official “TF-IDF + BERT” Baseline
Cross-Lingual Open-Domain Question Answering
 | from IBM Research | Recall@5000-tokens | +10 | Official “DPR + Vanilla Transformer” Baseline
Zero-Shot Information Retrieval
BEIR | ColBERT | Recall@100 | Outperforms other off-the-shelf dense retrievers on 13/17 tasks |
Table 1: Results of models using ColBERT, ColBERT-QA, and Baleen across a wide range of language tasks.

Two popular models in this space are REALM and RAG, which rely on the
ORQA and DPR retrievers discussed earlier. REALM and RAG jointly tune a
retriever as well as a reader, a modeling component that consumes the
retrieved documents and produces answers or responses. Take RAG as an
example: its reader is a generative BART model, which attends to the
passages while generating the target outputs. While they constitute
important steps toward retrieval-based NLP, REALM and RAG suffer from
two major limitations. First, they use the restrictive paradigm of
Figure 2(b) for retrieval, thereby sacrificing recall: they are often
unable to find relevant passages for conducting their tasks. Second,
when training the retriever, REALM and RAG collect documents by
searching for them inside the training loop and, to make this practical, they
freeze the document encoder when fine-tuning, restricting the model’s adaptation to the task.

ColBERT-QA2 is an Open-QA system (published at TACL’21) that we built on
top of ColBERT to tackle both problems. By adapting ColBERT’s expressive search to the task,
ColBERT-QA finds useful passages for a larger fraction of the questions and thus
enables the reader component to answer more questions correctly and with provenance.
In addition, ColBERT-QA introduces relevance-guided supervision (RGS),
a training strategy whose goal is to adapt a
retriever like ColBERT to the specifics of an NLP task like Open-QA. RGS
proceeds in discrete rounds, using the retriever trained in the previous
round to collect “positive” passages that are likely useful for the
reader—specifically, passages ranked highly by the latest version of the
retriever and that also overlap with the gold answer of the question—and
challenging “negative” passages. By converging to a high coverage of
positive passages and by effectively sampling hard negatives, ColBERT-QA
improves retrieval Success@20 by more than 5, 5, and 12 points on
the open-domain QA settings of NaturalQuestions, TriviaQA, and SQuAD, respectively, and thus greatly
improves downstream answer match.
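
As a rough illustration of how one round of relevance-guided supervision might pick training examples, the sketch below splits a retriever's ranked list into positives and hard negatives. Everything specific here is an assumption made for the example: the substring-based answer match, the `pos_depth` and `neg_depth` cutoffs, and the function names. It is not the ColBERT-QA implementation.

```python
from typing import List, Tuple


def contains_answer(passage: str, answer: str) -> bool:
    """Crude overlap check: does the gold answer string appear in the passage?"""
    return answer.lower() in passage.lower()


def select_examples(
    ranked: List[str], answer: str, pos_depth: int, neg_depth: int
) -> Tuple[List[str], List[str]]:
    """Pick positives and hard negatives from the previous round's ranking.

    Positives: highly ranked passages that overlap with the gold answer.
    Negatives: ranked passages that do not contain the answer.
    """
    positives = [p for p in ranked[:pos_depth] if contains_answer(p, answer)]
    negatives = [p for p in ranked[:neg_depth] if not contains_answer(p, answer)]
    return positives, negatives


# Toy ranking, as if produced by the retriever from the previous round.
ranked = [
    "Paris is the capital of France.",
    "France borders Spain and Italy.",
    "The capital city, Paris, hosts the Louvre.",
    "Berlin is the capital of Germany.",
]
positives, negatives = select_examples(ranked, "Paris", pos_depth=3, neg_depth=4)
print(positives)
print(negatives)
```

The selected examples would then be used to train the next round's retriever, whose ranking feeds the following round.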

A more sophisticated version of the Open-QA task is multi-hop reasoning,
where systems must answer questions or verify claims by gathering
information from multiple sources. Systems in this space,
like GoldEn, MDR,
and IRRR,
find relevant documents and “hop” between them—often by running
additional searches—to find all pertinent sources. While these models
have demonstrated strong performance for two-hop tasks, scaling robustly
to more hops is challenging as the search space grows exponentially.

To tackle this, our Baleen3 system
(accepted as a Spotlight paper at NeurIPS’21) introduces a richer pipeline for
multi-hop retrieval: after each retrieval “hop”, Baleen summarizes the
pertinent information from the passages into a short context that is used
to inform future hops. In doing so, Baleen controls the search space
architecturally—obviating the need to explore each potential passage
at every hop—without sacrificing recall. Baleen also extends ColBERT’s
late interaction: it allows the representations of different documents
to “focus” on distinct parts of the same query, as each of those documents
in the corpus might satisfy a distinct aspect of the same complex query.
As a result of its more deliberate architecture and its stronger
retrieval modeling, Baleen saturates retrieval on the popular two-hop
HotPotQA benchmark (raising answer-recall@20 from 89% by MDR to 96%) and
dramatically improves performance on the harder four-hop claim
benchmark HoVer,
finding all required passages in 92% of the examples—up from just 45%
for the official baseline and 75% for a many-hop flavor of ColBERT.
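
The retrieve-then-condense loop can be sketched as follows. This is a toy illustration of the control flow only: `retrieve` and `condense` are hypothetical callables, the keyword-lookup "retriever" and the pass-through "condenser" below are stand-ins, and none of it reflects Baleen's actual models.

```python
from typing import Callable, List


def multi_hop(
    query: str,
    retrieve: Callable[[str], List[str]],
    condense: Callable[[str, List[str]], str],
    hops: int,
) -> List[str]:
    """Run several retrieval hops, carrying a short condensed context forward."""
    context = ""
    collected: List[str] = []
    for _ in range(hops):
        # Search with the original query plus whatever has been condensed so far.
        hits = retrieve((query + " " + context).strip())
        new_passages = [p for p in hits if p not in collected]
        collected.extend(new_passages)
        # Only the condensed context, not every passage, informs the next hop.
        context = (context + " " + condense(query, new_passages)).strip()
    return collected


# Toy stand-ins (assumptions): a keyword-triggered lookup and an identity condenser.
index = {
    "bluebird": "Bluebird was directed by Carol.",
    "carol": "Carol was born in Lisbon.",
}


def retrieve(text: str) -> List[str]:
    return [passage for key, passage in index.items() if key in text.lower()]


def condense(query: str, passages: List[str]) -> str:
    return " ".join(passages)


found = multi_hop("Where was the director of Bluebird born?", retrieve, condense, hops=2)
print(found)
```

The second passage only becomes reachable after the first hop adds "Carol" to the context, which is the behavior the condensed context is meant to enable.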

In these tasks, when our retrieval-based models make predictions, we can
inspect their underlying sources and decide whether we can trust the
answer. And when model errors stem from specific sources, those can be
removed or edited, and making sure models are faithful to such edits
is an active area of work.

Generalizing models to new domains with robust neural retrieval

In addition to helping with efficiency and transparency, retrieval
approaches promise to make domain generalization and knowledge updates
much easier in NLP. Exhibiting up-to-date, domain-specific knowledge is
essential for many applications: you might want to answer questions over
recent publications on COVID-19 or to develop a chatbot that guides
customers to suitable products among those currently available in a
fast-evolving inventory. For such applications, NLP models should be
able to leverage any corpus provided to them, without having to train a
new version of the model for each emerging scenario or domain.

While large language models are trained using plenty of data from the
Web, this snapshot is:

  • Static. The Web evolves as the world does: Wikipedia articles
    reflect new elected officials, news articles describe current events, and
    scientific papers communicate new research. Despite this, a language
    model trained in 2020 has no way to learn about 2021 events, short
    of training and releasing a new version of the model.

  • Incomplete. Many topics are under-represented in Web crawls like C4
    and The Pile. Suppose we seek to answer questions over the ACL
    papers published 2010–2021; there is no guarantee that The Pile
    contains all papers from the ACL Anthology a priori and there is no
    way to plug that in ad-hoc without additional training. Even when
    some ACL papers are present (e.g., through arXiv, which is included
    in The Pile), they form only a tiny sliver of the data, and it is
    difficult to reliably restrict the model to specifically those
    papers for answering NLP questions.

  • Public-only. Many applications hinge on private text, like internal
    company policies, in-house software documentation, copyrighted
    textbooks and novels, or personal email. Because models like GPT-3
    never see such data in their training, they are fundamentally
    incapable of exhibiting knowledge pertaining to those topics without
    special re-training or fine-tuning.

With retrieval-based NLP, models learn effective ways to encode and
extract information, allowing them to generalize to updated text,
specialized domains, or private data without resorting to additional
training. This suggests a vision where developers “plug in” their text
corpus, like in-house software documentation, which is indexed by a
powerful retrieval-based NLP model that can then answer questions, solve
classification tasks, or generate summaries using the knowledge from the
corpus, while always supporting its predictions with provenance from the
corpus.

An exciting benchmark connected to this space
is BEIR,
which evaluates retrievers on their capacity for search “out-of-the-box”
on unseen IR tasks, like Argument Retrieval, and in new domains, like
the COVID-19 research literature. While retrieval offers a concrete
mechanism for generalizing NLP models to new domains, not every IR model
generalizes equally: the BEIR evaluations highlight the impact of
modeling and supervision choices on generalization. For instance, due to
its late interaction modeling, a vanilla off-the-shelf ColBERT retriever
achieved the strongest recall of all competing IR models in the initial
BEIR evaluations, outperforming the other off-the-shelf dense
retrievers—namely, DPR, ANCE, SBERT, and USE-QA—on 13 out of 17
datasets. The BEIR benchmark continues to develop quickly, a recent
addition being the
TAS-B model,
which advances a sophisticated supervision approach to distill ColBERT
and BERT models into single-vector representations, inheriting much of
their robustness in doing so. While retrieval allows rapid deployment in new
domains, explicitly adapting retrieval to new scenarios is also
possible. This is an active area of research, with work
like QGen and AugDPR that
generate synthetic questions and use those to explicitly fine-tune
retrievers for targeting a new corpus.

Summary: Is retrieval “all you need”?

The black-box nature of large language models like T5 and GPT-3 makes
them inefficient to train and deploy, opaque in their knowledge representations and in backing
their claims with provenance, and static in facing a constantly evolving world and diverse downstream contexts.
This post explores retrieval-based NLP, where models retrieve information
pertinent to solving their tasks from a plugged-in text corpus. This
paradigm allows NLP models to leverage the representational strengths
of language models, while needing much smaller architectures, offering
transparent provenance for claims, and enabling efficient updates and adaptation.

We surveyed much of the existing and emerging work in this space and
highlighted some of our work at Stanford, including
ColBERT for scaling up expressive retrieval to massive corpora via late
interaction, ColBERT-QA for
accurately answering open-domain questions by adapting high-recall
retrieval to the task, and
Baleen for
solving tasks that demand information from several independent sources
using a condensed retrieval architecture.
We continue to actively maintain
our code as open source.

Acknowledgments. We would like to thank Megha Srivastava and Drew A. Hudson for helpful comments and feedback on this blog post. We also thank Ashwin Paranjape, Xiang Lisa Li, and Sidd Karamcheti for valuable and insightful discussions.

  1. Omar Khattab and Matei Zaharia. “ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT.” Proceedings of the 43rd International ACM SIGIR conference on research and development in Information Retrieval. 2020. 

  2. Omar Khattab, Christopher Potts, and Matei Zaharia. “Relevance-guided Supervision for OpenQA with ColBERT.” Transactions of the Association for Computational Linguistics 9 (2021): 929–944.

  3. Omar Khattab, Christopher Potts, and Matei Zaharia. “Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval.” (To appear at NeurIPS 2021.) arXiv preprint arXiv:2101.00436 (2021). 

Break-It-Fix-It: Unsupervised Learning for Fixing Source Code Errors

Machine Learning for Code Repair

Across the board, programming has increased in popularity, ranging from developing with general-purpose programming languages like Python, C, and Java to using simpler languages like HTML, SQL, LaTeX, and Excel formulas. When writing code we often make syntax errors such as typos, unbalanced parentheses, and invalid indentation, and need to fix them. In fact, several studies 1 show that both beginner and professional programmers spend 50% of their time fixing code errors during programming. Automating code repair can dramatically enhance programming productivity 2.

Recent works 3 use machine learning models to fix code errors by training the models on human-labeled (broken code, fixed code) pairs. However, collecting this data is costly even for a single programming language, let alone for the dozens of languages commonly used in practice.

On the other hand, unlabeled (unaligned) data—not aligned as (broken, fixed) pairs—is readily available: for example, raw code snippets on the web like GitHub. An unsupervised approach for training code repair models would make them much more scalable and widely deployable. In our recent work 4 published at ICML 2021, we study how to leverage unlabeled data to learn code fixers effectively.

Problem Setup

In code repair, we are given a critic that assesses the quality of an input: for instance, a compiler or code analyzer that tells us if input code has any syntax errors. The code is bad if there is at least one error and it is good if there are no errors. What we want is a fixer that repairs bad code into good code that satisfies the critic, e.g. repairing a missing parenthesis as in the figure below. Our goal is to use unlabeled data and a critic to learn a fixer.

While unlabeled data can be split into a set of good code and a set of bad code using the critic, they are unaligned; in other words, they do not form (broken, fixed) pairs ready to be used for training a fixer.
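
For Python, a critic of this kind can be sketched with the standard-library `ast` module, which raises `SyntaxError` on code that does not parse. The function names `critic` and `split_by_critic` are chosen for this illustration, and the snippets are made-up examples.

```python
import ast
from typing import Iterable, List, Tuple


def critic(code: str) -> bool:
    """Return True if the snippet parses as Python (good), False otherwise (bad)."""
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False


def split_by_critic(snippets: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Partition unlabeled snippets into good and bad sets using the critic."""
    good: List[str] = []
    bad: List[str] = []
    for snippet in snippets:
        (good if critic(snippet) else bad).append(snippet)
    return good, bad


snippets = ["print(max(1, 2))", "print(max(1, 2)", "x = [1, 2, 3]"]
good, bad = split_by_critic(snippets)
print(good)
print(bad)
```

Note that the two resulting sets are still unaligned: nothing tells us which good snippet, if any, is the repaired version of a given bad one.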

A straightforward technique 5 is to apply random or heuristic perturbations to good code, such as dropping tokens, and prepare synthetic paired data (perturbed code, good code) to train a fixer. However, such synthetically-generated bad code does not match the distribution of real bad code written by humans. For instance, as the figure below shows, synthetic perturbations (purple box) may drop parentheses arbitrarily from code, generating errors that are rare in real code. In contrast, human-written code (red box) rarely misses parentheses when only a single pair appears, but misses parentheses often in a nested context (e.g., 10x more than non-nested in our Python code dataset collected from GitHub). This distributional mismatch between synthetic data and real data can result in low code repair performance when used in practice. To tackle this challenge, we introduce a new training approach, Break-It-Fix-It (BIFI), that adapts the fixer towards real distributions of bad code.
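
The synthetic-perturbation baseline described above can be sketched as follows. This is a simplified illustration that only drops a single token at random; the pre-tokenized input and the name `drop_random_token` are assumptions made for the example.

```python
import random
from typing import List, Tuple


def drop_random_token(tokens: List[str], rng: random.Random) -> List[str]:
    """Return a copy of the token list with one randomly chosen token removed."""
    i = rng.randrange(len(tokens))
    return tokens[:i] + tokens[i + 1:]


rng = random.Random(0)  # fixed seed so the example is repeatable
good_tokens = ["print", "(", "max", "(", "a", ",", "b", ")", ")"]
perturbed = drop_random_token(good_tokens, rng)

# Synthetic training pair: (perturbed code, good code).
pair: Tuple[str, str] = (" ".join(perturbed), " ".join(good_tokens))
print(pair)
```

Every token is equally likely to be dropped here, which is exactly the kind of uniform corruption that fails to match where humans actually make mistakes. Some drops also leave the code valid (removing `max` above still parses), so a critic check would be needed to confirm that a perturbed snippet is really bad.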

Approach: Break-It-Fix-It

The basic idea of BIFI is to introduce a machine learning-based breaker that learns to corrupt good code into realistic bad code, and iteratively train both the fixer and the breaker while using them in conjunction to generate more realistic paired data. Concretely, BIFI takes as inputs:

  • Critic
  • Unaligned set of good and bad code
  • Initial fixer, which potentially is trained on synthetic data

BIFI then improves the fixer by performing the following cycle of data generation and training procedure:

  1. Apply the fixer to the set of bad code, which consists of real code errors made by humans, and use the critic to assess if the fixer’s output is good. If good, keep the pair
  2. Train the breaker on the resulting paired data from Step 1. Consequently, the breaker can generate more realistic errors than the initial synthetic data
  3. Apply the breaker to the set of good code, and keep outputs that the critic judges as bad
  4. Train the fixer on the newly-generated paired data in Step 1 and Step 3

These steps are also illustrated in the left panel of the figure below. We iterate over this cycle to improve the fixer and the breaker simultaneously until they have both converged. The intuition is that a better fixer and breaker will be able to generate more realistic paired data, which in turn helps to train a better fixer and breaker.
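
The four steps above can be sketched as a single function. This is a schematic under strong assumptions, not the authors' implementation: the fixer and breaker are plain string-to-string callables, `train_fixer` and `train_breaker` are hypothetical callbacks that return an updated model, and the toy critic simply checks that parentheses counts match.

```python
from typing import Callable, List, Tuple

Pair = Tuple[str, str]  # (bad code, good code)
Model = Callable[[str], str]


def bifi_round(
    fixer: Model,
    breaker: Model,
    critic: Callable[[str], bool],
    good_code: List[str],
    bad_code: List[str],
    train_fixer: Callable[[Model, List[Pair]], Model],
    train_breaker: Callable[[Model, List[Pair]], Model],
) -> Tuple[Model, Model]:
    # Step 1: fix real bad code; keep a pair only if the critic accepts the output.
    fixed_pairs = [(bad, out) for bad in bad_code for out in [fixer(bad)] if critic(out)]
    # Step 2: train the breaker on the verified pairs from Step 1.
    breaker = train_breaker(breaker, fixed_pairs)
    # Step 3: break good code; keep a pair only if the critic rejects the output.
    broken_pairs = [(out, good) for good in good_code for out in [breaker(good)] if not critic(out)]
    # Step 4: train the fixer on pairs from both Step 1 and Step 3.
    fixer = train_fixer(fixer, fixed_pairs + broken_pairs)
    return fixer, breaker


# Toy stand-ins (all assumptions) so the round can be run end to end.
def toy_critic(code: str) -> bool:
    return code.count("(") == code.count(")")


def toy_fixer(code: str) -> str:
    return code + ")"  # naive: always append one closing parenthesis


def toy_breaker(code: str) -> str:
    return code[:-1]  # naive: always drop the last character


seen = {}


def train_breaker(model: Model, pairs: List[Pair]) -> Model:
    seen["breaker"] = pairs  # a real trainer would update the model here
    return model


def train_fixer(model: Model, pairs: List[Pair]) -> Model:
    seen["fixer"] = pairs
    return model


good_code = ["a(b)", "c(d(e))"]
bad_code = ["f(x", "g(h(y)", "k(("]
bifi_round(toy_fixer, toy_breaker, toy_critic, good_code, bad_code, train_fixer, train_breaker)
print(seen["fixer"])
```

In the toy run, the attempted fix for `k((` is rejected by the critic and never reaches the training data, which is the filtering role the critic plays in Steps 1 and 3.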

BIFI is related to the backtranslation (cycle-consistency) method in unsupervised translation 6. If we apply backtranslation directly to the code repair task, we would do the following:

  1. Apply the fixer to the set of bad code and generate (noisy) good code
  2. Train the breaker to reconstruct the bad code
  3. Apply the breaker to the set of good code and generate (noisy) bad code
  4. Train the fixer to reconstruct the good code

as illustrated in the right panel of the figure. BIFI improves on backtranslation in two aspects. First, while backtranslation may include non-fixed code as good or non-broken code as bad in Step 1 or 3, BIFI uses the critic to verify if the generated code is actually fixed or broken in Step 1 and 3, as highlighted with pink in the left panel of the figure. This ensures the correctness of training data generated by the breaker and fixer. Second, while backtranslation only uses paired data generated in Step 3 to train the fixer in Step 4, BIFI uses paired data generated in both Step 3 and Step 1, as paired data from Step 1 contains real code errors made by humans. This improves the distributional match of generated training data.

Let’s use our code repair model!

We apply and evaluate our method, BIFI, on two code repair benchmarks:

  • GitHub-Python 7: Fix syntax errors in Python code. The critic is the Python AST parser.
  • DeepFix 8: Fix compiler errors in C code. The critic is a C compiler.

BIFI improves on existing unsupervised methods for code repair
Using the GitHub-Python dataset, we first compare BIFI with existing unsupervised methods for code repair: a synthetic baseline that uses synthetic paired data generated by randomly dropping, inserting or replacing tokens from good code, and a backtranslation baseline that directly applies backtranslation to code repair. The synthetic baseline serves as the initial fixer for our BIFI algorithm. We find that BIFI improves the repair accuracy by 28% (62%→90%) over the synthetic baseline and by 10% (80%→90%) over the backtranslation baseline, as shown in the left panel of the figure. This result suggests that while we started from a simple initial fixer trained with random perturbations, BIFI can automatically turn it into a usable fixer with high repair accuracy.

For the other dataset, DeepFix, there are several prior works that use heuristic ways to generate synthetic paired data for the task: Gupta+17 9, Hajipour+19 10, DrRepair 11. We take the existing best model, DrRepair, as our initial fixer and apply BIFI. We find that it improves the repair accuracy by 5% (66%→71%), as shown in the right panel of the figure. This result suggests that while the initial fixer DrRepair was already trained with manually designed heuristics, there is still room for improving the adaptation to a more realistic distribution of code errors. BIFI helps to achieve this without additional manual effort.

Examples of breaker outputs
Let’s look at several examples of code generated by the trained breaker. Given the good Python code shown on the left below, we show on the right outputs that the breaker places high probability on. In output 1, the breaker converts raise ValueError(...) into raise ValueError, ..., which is an obsolete usage of raise in Python. In output 2, the breaker drops a closing parenthesis in a nested context. These are both errors commonly seen in human-written bad code.

Examples of fixer outputs
Let’s look at how our fixer performs through examples too. The left side of the figure shows human-written Python code with an indentation error—one needs to add indent to the err = 0 line and remove indent in the next line. The initial fixer, shown in the center, only inserts one indent token and fails to fix the error. This is most likely due to the mismatch between real errors and synthetic errors used in training: synthetic errors generated by random perturbations do not frequently contain this kind of indentation error where multiple tokens need to be inserted/removed accordingly. The fixer trained by BIFI, shown on the right, fixes the indentation error by inserting and removing the correct pair of indent tokens. We find that this is one of the representative examples of when BIFI successfully fixes code errors but the initial fixer fails.

Finally, one limitation of this work is that we focus on fixing syntactic errors (we use critics such as AST parser and compiler), and we are not evaluating the semantic correctness of our outputs. Extending BIFI to fixing semantic errors is an exciting future research avenue.


Machine learning of source code repair is an important direction to enhance programming productivity, but collecting human-labeled data is costly. In this work, we studied how to learn source code repair in an unsupervised way, and developed a new training method, BIFI. The key innovation of BIFI is that it creates realistic paired data for training fixers from a critic (e.g. compiler) and unlabeled data (e.g. code snippets on the web) only, which are cheaply available.

More broadly, the idea of learning fixers from critics + unlabeled data is applicable to various repair tasks beyond code repair, such as grammatical error correction 12 and molecule design, using domain-specific critics. Additionally, the idea of using a critic to improve the quality of paired data is applicable to various translation tasks by introducing a learned critic. We hope that BIFI can be an effective solution to unsupervised repair tasks and translation tasks.

You can check out our full paper here and our source code/data on GitHub.


This blog post is based on the paper:

Many thanks to Percy Liang, as well as members of the Stanford P-Lambda group, SNAP group and NLP group for their valuable feedback. Many thanks to Jacob Schreiber and Sidd Karamcheti for edits on this blog post.

  1. Reversible Debugging Software. Tom Britton, Lisa Jeng, Graham Carver, Paul Cheak, Tomer Katzenellenbogen. 2013. Programmers’ Build Errors: A Case Study (at Google). Hyunmin Seo, Caitlin Sadowski, Sebastian Elbaum, Edward Aftandilian, Robert Bowdidge. 2014. 

  2. Improving programming productivity with machine learning is an extremely active area of research. A prominent example is the Copilot / Codex service recently released by OpenAI and GitHub, which translates natural language (e.g. English) descriptions into code. Automated code repair is another complementary technology to improve programming productivity. 

  3. SEQUENCER: Sequence-to-Sequence Learning for End-to-End Program Repair. Zimin Chen, Steve Kommrusch, Michele Tufano, Louis-Noël Pouchet, Denys Poshyvanyk, Martin Monperrus. 2019. DeepDelta: Learning to Repair Compilation Errors. Ali Mesbah, Andrew Rice, Emily Johnston, Nick Glorioso, Eddie Aftandilian. 2019. Patching as Translation: the Data and the Metaphor. Yangruibo Ding, Baishakhi Ray, Premkumar Devanbu, Vincent J. Hellendoorn. 2020.

  4. Break-It-Fix-It: Unsupervised Learning for Program Repair. Michihiro Yasunaga, Percy Liang. 2021. 

  5. DeepFix: Fixing common C language errors by deep learning. Rahul Gupta, Soham Pal, Aditya Kanade, Shirish Shevade. 2017. DeepBugs: A Learning Approach to Name-based Bug Detection. Michael Pradel, Koushik Sen. 2018. Neural program repair by jointly learning to localize and repair. Marko Vasic, Aditya Kanade, Petros Maniatis, David Bieber, Rishabh Singh. 2019. Global relational models of source code. Vincent J. Hellendoorn, Charles Sutton, Rishabh Singh, Petros Maniatis, David Bieber. 2020. 

  6. Improving Neural Machine Translation Models with Monolingual Data. Rico Sennrich, Barry Haddow, Alexandra Birch. 2016. Phrase-Based & Neural Unsupervised Machine Translation. Guillaume Lample, Myle Ott, Alexis Conneau, Ludovic Denoyer, Marc’Aurelio Ranzato. 2018. 


  8. DeepFix: Fixing common C language errors by deep learning. Rahul Gupta, Soham Pal, Aditya Kanade, Shirish Shevade. 2017. 

  9. DeepFix: Fixing common C language errors by deep learning. Rahul Gupta, Soham Pal, Aditya Kanade, Shirish Shevade. 2017. 

  10. SampleFix: Learning to Correct Programs by Sampling Diverse Fixes. Hossein Hajipour, Apratim Bhattacharya, Mario Fritz. 2019. 

  11. Graph-based, Self-Supervised Program Repair from Diagnostic Feedback. Michihiro Yasunaga, Percy Liang. 2020. 

  12. LM-Critic: Language Models for Unsupervised Grammatical Error Correction. Michihiro Yasunaga, Jure Leskovec, Percy Liang. 2021. 

Our Journey towards Data-Centric AI: A Retrospective

This article provides a brief, biased retrospective of our road to data-centric AI. Our hope is to provide an entry point for people interested in this area, which has been scattered to the nooks and crannies of AI—even as it drives some of our favorite products, advancements, and benchmark improvements.

We’re collecting pointers to these resources on GitHub, and plan to write a few more articles about exciting new directions. We hope to engage with folks who are excited about data-centric AI in an upcoming HAI workshop in November — folks like you!

Starting in about 2016, researchers from our lab — the Hazy Research lab — circled through academia and industry giving talks about an intentionally provocative idea: machine learning (ML) models—long the darlings of researchers and practitioners—were no longer the center of AI. In fact, models were becoming commodities. Instead, we claimed that it was the training data that would drive progress towards more performant ML models and systems.

To underscore this, we had taglines like “AI is driven by data—not code” or worse “Training data is the new new oil”. We started building systems championed by little octopuses wearing snorkels. Eventually, we turned to others and called this “Software 2.0” (inspired by Karpathy’s post). Others have since termed it data-centric AI, and recently Andrew Ng gave a great talk about his perspective on this direction.

Our view that models were becoming a commodity was heretical for a few reasons.

First, people often think of data as a static thing. After all, data literally means “that which is given”. For most ML people, they download an off-the-shelf dataset, drop it into a PyTorch dataloader, and plug-and-play: losses go down, accuracy goes up, and the data is a mere accessory.

But to an engineer in the wild, the training data is never “that which is given”. It is the result of a process — usually a dirty, messy process that is critical and underappreciated.

An engineer and their training data in the wild. Credit: Vickie Shelton.

Still, we had hope. In applications, we took time to clean and merge data. We engineered it. We began to talk about how AI and ML systems were driven by this data, how they were programmed by this data. This led to understandably (obtuse) names like “data programming”.

Unfortunately, we were telling people to put on galoshes, jump into the sewer that is your data, and splash around. Not an easy sales pitch for researchers used to life in beautiful PyTorch land.

We started to recognize that model-itis is a real problem. With some friends at Apple, we realized that teams would often spend time writing new models instead of understanding their problem—and its expression in data—more deeply. We weren’t the only ones thinking this way; lots of no-code AI folks like Ludwig, H2O, and DataRobot were too. We began to argue that this aversion to data didn’t really lead to a great use of time. To make matters worse, 2016-2017 was a thrilling time to be in ML. Each week a new model came out, and each week, it felt like we were producing demos that we couldn’t dream of a decade earlier.

Despite this excitement, it was clear to us that success or failure to a level usable in applications we cared about—in medicine, at large technology companies or even pushing the limits on benchmarks—wasn’t really tied to models per se. That is, the advances were impressive, but they were hitting diminishing returns. You can see this in benchmarks, where most of the progress after 2017 is fueled by new advances in augmentations, weak supervision, and other issues of how you feed machines data. In round numbers, ten points of accuracy were due to those—while (by and large) model improvements were squeaking out a few tenths in accuracy points.

At the time, many of the folks who are now converts have shared with us that they were skeptical of our view of the future. We get it, our stupid jokes and general demeanor didn’t inspire confidence. But we weren’t totally insane. This idea has become mainstream and widespread. Our friends at Google in Ads, Gmail, YouTube and Apple extended to us a level of technical trust that we hope we’ve repaid. You’ve probably used some of the products that have incorporated these crazy ideas in the last few minutes. The Octopus is now widely used in the enterprise, and we’re just at the beginning!

This blog post is an incomplete, biased retrospective of this road. We’ll close with two thoughts:

  1. There is a data-centric research agenda inside AI. It’s intellectually deep, and it has been lurking at the core of AI progress for a while. Perhaps by calling it out we can make even more progress on an important viewpoint.
  2. We’d love to provide entry points for folks interested in this area. Our results are scattered in a number of different research papers, and we’d enjoy writing a survey (if anyone is interested – we have a form!). We’ve opted to be biased about what influenced us the most to try to present a coherent story here. Necessarily, this means we’re leaving out amazing work. Apologies, please send us notes and corrections.

On our end, we’ll do our best to build this data-centric community up on GitHub, with a collage of exciting related papers and lines of work. If you’re new to the area, use it as a pedagogical resource, and if you’re a veteran, please go ahead and send us PRs and contributions so we can expand the discussion! We’re gathering real-world case studies, so if you work on real applications that have benefited from a data-centric viewpoint (in academia, industry or anywhere), please don’t hesitate to reach out or create an Issue on GitHub so we can bring your experiences into the fold.

A more informal version of this blog can be found here.


Supporting COVID-19 policy response with large-scale mobility-based modeling


Mobility restrictions, from stay-at-home orders to indoor occupancy caps, have been utilized extensively by policymakers during the COVID-19 pandemic. These reductions in mobility help to control the spread of the virus [1, 2], but they come at a heavy cost to businesses and employees.

To balance these competing demands, policymakers need analytical tools that can evaluate the tradeoffs between mobility and COVID-19 infections. Furthermore, such tools should be fine-grained, able to test out heterogeneous plans—for example, allowing one level of mobility at essential retail, another level at gyms, and yet another at restaurants—so that policymakers can tailor restrictions to the specific risks and needs of each sector. At the same time, the tool also needs to be scalable, supporting analyses for a massive number of potential policies so that policymakers can find the best option for their jurisdiction.

Our tool

To fulfill these needs, we developed a novel computational tool, which we built in collaboration with the Biocomplexity Institute & Initiative at UVA to support the Virginia Department of Health (VDH). Described in our award-winning KDD 2021 paper, our tool enables policymakers to assess the costs and benefits of thousands of different mobility measures, based on millions of simulations from our underlying epidemiological model. We designed our tool to fulfill VDH’s desire to have a quantitative and comprehensive analysis of a range of reopening policies. With their guidance, we developed an interactive dashboard, where policymakers can select various proposed changes in mobility and observe their predicted impacts on COVID-19 infections over time and across regions.

Our dashboard focuses on mobility to five key categories of places: Restaurants, Gyms, Religious Organizations, Essential Retail (grocery stores, pharmacies, convenience stores), and Retail (clothing stores, book stores, hardware stores, etc.). For each category, the user can use sliders to choose a target level of mobility (e.g., 50% of normal levels, based on pre-pandemic mobility), or they can choose to continue current levels of mobility at these places. The other panels on the dashboard then visualize predicted COVID-19 infections under the selected mobility plan, and compare these outcomes to what would happen if all categories remained at their current levels of mobility.

Our tool enables policymakers to comprehensively analyze pandemic tradeoffs, by quantifying visits lost under each mobility plan as well as predicted infections. The sliders for each category allow them to test fine-grained, heterogeneous policies. Furthermore, the flexibility of our approach (i.e., allowing any combination of mobility levels) results in an exponential number of scenarios to test. To scale our modeling efforts, our tool features a robust computational infrastructure that compresses 2 years of compute time into the span of a few days.
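To make the "exponential number of scenarios" concrete, here is a minimal sketch that enumerates every combination of mobility levels across the five dashboard categories. The slider stops in `LEVELS`, the helper names, and the definition of visits lost (relative to a baseline of normal visits) are illustrative assumptions, not the dashboard's actual settings.

```python
from itertools import product

CATEGORIES = [
    "Restaurants",
    "Gyms",
    "Religious Organizations",
    "Essential Retail",
    "Retail",
]
# Hypothetical slider stops, as fractions of normal mobility.
LEVELS = [0.0, 0.25, 0.5, 0.75, 1.0]


def enumerate_plans(categories, levels):
    """Yield every assignment of one mobility level to each category."""
    for combo in product(levels, repeat=len(categories)):
        yield dict(zip(categories, combo))


def visits_lost(plan, baseline_visits):
    """Visits forgone under a plan, relative to baseline visits per category.

    Assumes each level is a fraction of the baseline (an assumption made
    here for illustration).
    """
    return sum(baseline_visits[c] * (1.0 - level) for c, level in plan.items())


plans = list(enumerate_plans(CATEGORIES, LEVELS))
num_plans = len(plans)  # equals len(LEVELS) ** len(CATEGORIES)
```

With m slider stops and five categories there are m^5 plans, so even a coarse grid grows quickly, and each plan still needs many simulation runs.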

Our approach

At the heart of our tool is our state-of-the-art epidemiological model, which uses large-scale mobility networks to accurately capture the spread of COVID-19 in cities across the US.

Our mobility networks encode the hourly movements of people from census block groups (CBGs) to points of interest (POIs), which are non-residential locations such as restaurants, grocery stores, and churches. Using iterative proportional fitting, we infer these networks from aggregated, anonymized location data provided by SafeGraph. In this work, we infer hourly networks for the Washington DC, Virginia Beach, and Richmond metropolitan areas, three of the largest metropolitan areas in Virginia. For the period from November 1 to December 31, 2020, the resulting networks contain 3.4 billion hourly edges between CBGs and POIs.
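For readers unfamiliar with iterative proportional fitting (IPF), the sketch below shows the generic technique for a single hour. It assumes a seed matrix of aggregate CBG-to-POI visit weights, a target number of people leaving each CBG, and a target number of visitors at each POI. These inputs and the function name are assumptions for illustration; the post does not specify how the marginals are built from the SafeGraph data.

```python
def iterative_proportional_fitting(seed, row_totals, col_totals,
                                   max_iters=1000, tol=1e-9):
    """Rescale `seed` so its row and column sums match the given totals.

    seed:       list of rows, seed[i][j] >= 0 (CBG i -> POI j weight)
    row_totals: target sum for each row (people leaving each CBG this hour)
    col_totals: target sum for each column (visitors at each POI this hour)
    The two sets of totals are assumed to add up to the same number.
    """
    num_rows, num_cols = len(seed), len(seed[0])
    w = [list(row) for row in seed]  # work on a copy
    for _ in range(max_iters):
        # Row step: scale each row to match its target total.
        for i in range(num_rows):
            s = sum(w[i])
            if s > 0:
                factor = row_totals[i] / s
                w[i] = [v * factor for v in w[i]]
        # Column step: scale each column to match its target total.
        for j in range(num_cols):
            s = sum(w[i][j] for i in range(num_rows))
            if s > 0:
                factor = col_totals[j] / s
                for i in range(num_rows):
                    w[i][j] *= factor
        # Columns are exact after the column step, so check the rows.
        row_error = max(abs(sum(w[i]) - row_totals[i]) for i in range(num_rows))
        if row_error < tol:
            break
    return w


hourly_network = iterative_proportional_fitting(
    seed=[[1.0, 2.0], [3.0, 4.0]],
    row_totals=[30.0, 70.0],
    col_totals=[40.0, 60.0],
)
```

Running this once per hour yields a time series of weighted bipartite networks whose marginals agree with the hourly totals while preserving the relative visit pattern of the seed.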

We integrate the mobility networks, along with other data sources such as daily mask use, into our model. The key to our model is that it maintains the number of people in each CBG who are susceptible (S), exposed (E), infectious (I), or removed (R).

These CBG states are updated in each hour of the simulation, based on transmission dynamics that capture both household transmission and transmission occurring at POIs. That is, if there are susceptible and infectious individuals visiting a POI at the same time, then we model some probability of new infection occurring. That probability depends on the POI’s area in square feet, its median dwell time, the percentage of people wearing masks, and the number of susceptible and infectious visitors. Based on all of these factors, our model realistically captures who was infected where and when, down to the individual POI and hour.
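The hourly update can be sketched as follows. This is a simplified, deterministic (expected-value) illustration, not the model's actual equations. The functional form of the POI infection rate, the mask adjustment, and all parameter names (`psi`, `household_rate`, `mask_effect`, `latent_hours`, `infectious_hours`) are assumptions chosen to reflect the factors listed above.

```python
from dataclasses import dataclass


@dataclass
class CBGState:
    S: float  # susceptible
    E: float  # exposed
    I: float  # infectious
    R: float  # removed

    @property
    def total(self) -> float:
        return self.S + self.E + self.I + self.R


@dataclass
class POI:
    area_sqft: float
    dwell_hours: float  # median dwell time, as a fraction of an hour


def step_hour(cbgs, pois, visits, psi, household_rate,
              mask_fraction, mask_effect, latent_hours, infectious_hours):
    """Advance every CBG by one hour; visits[(c, p)] = people from CBG c at POI p."""
    new_exposed = [0.0] * len(cbgs)
    for p, poi in enumerate(pois):
        susceptible_visitors = []
        infectious_visitors = 0.0
        for c, g in enumerate(cbgs):
            share = visits.get((c, p), 0.0) / g.total if g.total > 0 else 0.0
            susceptible_visitors.append(share * g.S)
            infectious_visitors += share * g.I
        # Assumed form: risk grows with dwell time and infectious density.
        rate = psi * poi.dwell_hours ** 2 * infectious_visitors / poi.area_sqft
        rate *= 1.0 - mask_effect * mask_fraction  # assumed mask adjustment
        rate = min(1.0, max(0.0, rate))
        for c in range(len(cbgs)):
            new_exposed[c] += rate * susceptible_visitors[c]

    next_states = []
    for c, g in enumerate(cbgs):
        if g.total > 0:
            # Household transmission within the CBG (assumed mass-action term).
            new_exposed[c] += household_rate * g.S * g.I / g.total
        infections = min(new_exposed[c], g.S)
        e_to_i = g.E / latent_hours
        i_to_r = g.I / infectious_hours
        next_states.append(CBGState(
            S=g.S - infections,
            E=g.E + infections - e_to_i,
            I=g.I + e_to_i - i_to_r,
            R=g.R + i_to_r,
        ))
    return next_states
```

Because the POI term is computed per venue and per hour, summing `rate * susceptible_visitors` by POI or by CBG gives the "who was infected where and when" breakdown that the later analyses rely on.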

To validate our model, we compare its predictions against actual daily COVID-19 cases and deaths, as reported by The New York Times. In our initial work [3], published in Nature in 2020, we showed that our dynamic mobility networks enable even these relatively simple SEIR models with minimal free parameters to accurately fit real case trajectories and predict case counts in held-out time periods, despite substantial changes in population behavior during the pandemic. Integrating these networks furthermore allows us to capture the fine-grained spread of the virus, enabling analyses of the riskiest venues to reopen and the most at-risk populations.
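One simple way to score a fit like this is sketched below: compare predicted and reported daily counts with root-mean-square error, and reserve the final days as a held-out period. RMSE and the helper names are assumptions for illustration; the post does not state which loss the model actually uses.

```python
import math


def daily_rmse(predicted, reported):
    """Root-mean-square error between two equal-length daily series."""
    if len(predicted) != len(reported) or not predicted:
        raise ValueError("series must be non-empty and the same length")
    squared = sum((p - r) ** 2 for p, r in zip(predicted, reported))
    return math.sqrt(squared / len(predicted))


def split_fit_and_holdout(series, num_holdout_days):
    """Split a daily series into a fitting period and a held-out tail."""
    if not 0 < num_holdout_days < len(series):
        raise ValueError("num_holdout_days must be between 1 and len(series) - 1")
    return series[:-num_holdout_days], series[-num_holdout_days:]


# Toy numbers, purely illustrative.
reported = [10.0, 12.0, 15.0, 20.0, 26.0, 30.0]
predicted = [11.0, 12.0, 14.0, 21.0, 25.0, 33.0]
fit_reported, holdout_reported = split_fit_and_holdout(reported, 2)
fit_predicted, holdout_predicted = split_fit_and_holdout(predicted, 2)
fit_error = daily_rmse(fit_predicted, fit_reported)
holdout_error = daily_rmse(holdout_predicted, holdout_reported)
```

Parameters would be chosen using only the fitting period; the held-out error then indicates how well the fitted model extrapolates.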

Illustration of our approach. We integrate many data sources to run, evaluate, and analyze our model. We pair our model output with an interactive dashboard, whose engineering architecture is described in the box on the right.

In this work, we sought to translate our model into a tool that can directly support COVID-19 decision-makers, motivated by our interactions with the Virginia Department of Health. This goal required many extensions to our computational pipeline, including fitting the model to new regions and time periods, and improving our computational infrastructure to deploy the model at scale. Furthermore, to keep pace with developments in the pandemic, we introduced new real-world features to the model such as daily mask use, time-varying case and death detection rates, and model initialization based on historical reported cases/deaths. These additions allowed us to accurately fit real COVID-19 trajectories in Virginia, and we showed that the inclusion of our new features contributed substantially toward reducing model loss. Most importantly, we worked with VDH to design use cases of our model that were most relevant to their needs, and developed a new dashboard to effectively communicate thousands of results from our model. Our full pipeline—the extended model, the computational infrastructure, and the new dashboard—constitutes advancements in this work that allowed us to truly transform our scientific model into a tool for real-world impact.

Using our model

Our fitted model can be applied to a wide variety of use cases. First, we can use it for retrospective analyses, by leveraging the model’s ability to capture who got infected where and when.

For example, we can use the model to compare the learned infection rates of lower-income and higher-income CBGs. What’s striking is that our model correctly predicts disparities from mobility data alone, even though we did not give our model any CBG demographics during runtime (only during analysis). In our prior work, we showed that two mechanisms in the mobility data explained these predicted disparities: lower-income CBGs were not able to reduce their mobility as much during the pandemic, and the POIs that they go to (even in the same category) tend to be more crowded with longer visits, and thus riskier. In this work, we show that this trend extends to both waves of the pandemic and to new metropolitan areas.
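A retrospective comparison of this kind can be sketched as below. The record fields and the median split into "lower" and "higher" income groups are assumptions for illustration. Note that demographics enter only at this analysis stage, matching the description above.

```python
from statistics import median


def infection_rate_by_income(cbgs):
    """Per-capita predicted infections for lower- vs higher-income CBGs.

    cbgs: list of dicts with keys "population", "median_income", and
          "predicted_infections" (cumulative model output for that CBG).
    CBGs below the median of median incomes are labeled "lower"
    (an assumed grouping rule).
    """
    cutoff = median(c["median_income"] for c in cbgs)
    totals = {"lower": [0.0, 0.0], "higher": [0.0, 0.0]}
    for c in cbgs:
        group = "lower" if c["median_income"] < cutoff else "higher"
        totals[group][0] += c["predicted_infections"]
        totals[group][1] += c["population"]
    return {
        group: (infections / population if population else 0.0)
        for group, (infections, population) in totals.items()
    }


# Toy inputs, purely illustrative.
example_cbgs = [
    {"population": 1000, "median_income": 30000, "predicted_infections": 30.0},
    {"population": 1000, "median_income": 40000, "predicted_infections": 20.0},
    {"population": 1000, "median_income": 80000, "predicted_infections": 5.0},
    {"population": 1000, "median_income": 90000, "predicted_infections": 5.0},
]
rates = infection_rate_by_income(example_cbgs)
disparity_ratio = rates["lower"] / rates["higher"]
```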

We can also use the model for forward-facing experiments. Essentially, the model has many different interpretable inputs, so we can simply modify one of those inputs, run the model, and observe what happens to the model’s predicted infections. For example, to generate data for our dashboard, we modify the mobility networks to reflect the user’s selected levels of mobility for each category, and run the model forward to produce predicted infections. We can also use our model to analyze vaccination strategies; for example, by reducing transmission rates per CBG based on the percentage of the CBG that is vaccinated.
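Both kinds of forward-facing experiment amount to editing an input before running the model. A minimal sketch follows, assuming the network is stored as a dictionary of visit counts at normal mobility and that vaccination scales transmission linearly. The function names, the `efficacy` parameter, and the linear form are illustrative assumptions.

```python
def apply_mobility_plan(visits, poi_category, plan):
    """Scale each CBG-to-POI visit count by the level chosen for its category.

    visits:       {(cbg_id, poi_id): visit count at normal mobility}
    poi_category: {poi_id: category name}
    plan:         {category name: fraction of normal mobility}
    POIs whose category is not in the plan are left unchanged.
    """
    return {
        (cbg, poi): count * plan.get(poi_category[poi], 1.0)
        for (cbg, poi), count in visits.items()
    }


def vaccinated_transmission_rate(base_rate, fraction_vaccinated, efficacy):
    """Reduce a CBG's transmission rate in proportion to its vaccination coverage."""
    return base_rate * (1.0 - efficacy * fraction_vaccinated)


# Toy inputs, purely illustrative.
visits = {("cbg_a", "poi_1"): 40.0, ("cbg_a", "poi_2"): 10.0}
poi_category = {"poi_1": "Restaurants", "poi_2": "Essential Retail"}
scaled = apply_mobility_plan(visits, poi_category, {"Restaurants": 0.5})
rate = vaccinated_transmission_rate(base_rate=0.5, fraction_vaccinated=0.5, efficacy=0.5)
```

The scaled network and adjusted rates are then passed to the simulation in place of the originals, and the resulting infections are compared against the unmodified run.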

Discussion & next steps

Our approach is not without its limitations, which we have discussed with policymakers. For instance, the mobility data from SafeGraph does not cover all POIs (e.g., limited coverage of nursing homes) or populations (e.g., children), and our model makes necessary but simplifying assumptions about the dynamics of disease transmission. Furthermore, in this work, we focused on how changes in mobility impact transmission, but where do these changes in mobility come from and how can we effect them? In future work, we plan to develop new models to answer these questions, to analyze and predict how complex mobility networks change in response to policy interventions and other pandemic events.

That said, in this work we’ve addressed a significant part of the puzzle, by introducing a tool that provides a quantitative and comprehensive near real-time assessment of the effects of mobility on transmission. Our underlying model is furthermore capable of many more types of analyses, from examining inequities to evaluating future vaccination strategies. In fact, we are now supporting the Virginia Department of Health on their vaccination efforts and extending our model to evaluate different vaccination policies. As the pandemic evolves, we will continue building decision-support tools and advancing the capabilities of our model, so that we can best support the needs of policymakers.


Special thanks to the SAIL blog editors, Emma Pierson, and Pang Wei Koh for their helpful feedback on this post. This blog post is based on our paper in KDD 2021:

Supporting COVID-19 policy response with large-scale mobility-based modeling. Serina Chang, Mandy L. Wilson, Bryan Lewis, Zakaria Mehrab, Komal K. Dudakiya, Emma Pierson, Pang Wei Koh, Jaline Gerardin, Beth Redbird, David Grusky, Madhav Marathe, and Jure Leskovec. KDD 2021 (Applied Data Science Track, Best Paper Award).

  1. S. Gao, J. Rao, Y. Kang, et al. Association of mobile phone location data indications of travel and stay-at-home mandates with COVID-19 infection rates in the US. JAMA Netw Open (2020). 

  2. J. Oh, HY. Lee, Q. Khuong, et al. Mobility restrictions were associated with reductions in COVID-19 incidence early in the pandemic: evidence from a real-time evaluation in 34 countries. Sci Rep 11, 13717 (2021). 

  3. S. Chang, E. Pierson, P.W. Koh, et al. Mobility network models of COVID-19 explain inequities and inform reopening. Nature 589, 82–87 (2020). 
