PLAS: Latent Action Space for Offline Reinforcement Learning

Figure 1: Overview: To avoid out-of-distribution actions in the Offline Reinforcement Learning problem, we propose to implicitly constrain the policy by modifying the action space instead of enforcing explicit constraints. The policy is trained to output a latent action which will be passed into a decoder pretrained with the dataset. We demonstrate that our method provides competitive performance in both simulation and real-robot experiments.

Offline RL: Learning a policy from a static dataset

The goal of Reinforcement Learning (RL) is to learn to perform a task by interacting with the environment. It has achieved significant success in many applications such as games and robotics. One major challenge in RL is that it requires a huge amount of interactive data collected in the environment to learn a policy. However, data collection is expensive and potentially dangerous in many real-world applications, such as robotics in safety-critical situations (e.g., around humans) or healthcare problems. Worse, RL algorithms usually also assume that the data used to update the policy comes from the current policy or from its own training process.

To use data more wisely, we may consider Offline Reinforcement Learning. The goal of offline RL is to learn a policy from a static dataset of transitions without further data collection. Although we may still need a large amount of data, the assumption of static datasets allows more flexibility in data collection. For example, in robotics, we can include human demonstrations, reuse rollouts from previous experiments, and share data within the community. In this way, the dataset is more likely to be scaled up in size even when data collection is expensive.

Figure 2: In contrast to a common reinforcement learning pipeline that alternates between collecting data and updating the policy, Offline Reinforcement Learning aims to learn a policy from a static dataset without further data collection.

One important feature of offline RL is that it requires no assumption about the performance of the policy that is used to collect the dataset. This is in contrast to behavior cloning, where we assume that the dataset is collected by an expert, so that we can directly “copy” the actions given states without reasoning about the future reward. In offline RL, the dataset could be collected by a policy (or several policies) with arbitrary performance.

At first glance, off-policy algorithms seem to be able to meet the above requirements. Off-policy algorithms save the agent’s interactions during training in a replay buffer and train the policy by sampling transitions from the replay buffer (Lillicrap 2015, Haarnoja 2018). However, as shown in previous work (Fujimoto 2018b), when we apply off-policy algorithms to a static dataset, the performance can be very poor due to out-of-distribution actions. In off-policy algorithms, the Q-function is updated by the Bellman operator:

$$ \mathcal{T} \hat{Q}^\pi(s_t, a_t) = \mathbb{E}_{r_t, s_{t+1}}[r_t + \gamma \hat{Q}^\pi(s_{t+1}, \pi(s_{t+1}))] $$

As explained in Fujimoto (2018b), if the policy selects an action \(\pi(s_{t+1})\) that is not included in this static dataset, then the term \(\hat{Q}^\pi(s_{t+1}, \pi(s_{t+1}))\) may have a large extrapolation error. The extrapolation error will be accumulated by the Bellman operator and exacerbated by the policy updates. These errors eventually lead to significant overestimation bias that can hurt the performance. This has always been an issue for Q-learning-based methods (Thrun 1993, Fujimoto 2018a), but it is especially problematic when applied to a static dataset because the policy is not able to try out and correct the overly optimistic actions. The problem is more significant when the action space is large, such as a high-dimensional continuous action space.
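To make the accumulation of extrapolation error concrete, here is a minimal tabular sketch. It is an illustration under stated assumptions, not code from the paper: the single-state problem, the rewards, and the value 50.0 assigned to the unseen action are all hypothetical. Bootstrapping from an action that never appears in the dataset inflates the values of the actions that do appear, while bootstrapping only from in-dataset actions recovers the correct value.

```python
import numpy as np

GAMMA = 0.9
# Hypothetical single-state dataset: only actions 0 and 1 were ever tried.
DATASET_REWARDS = {0: 1.0, 1: 0.5}


def run_backups(allowed_actions, n_iters=200):
    """Repeat Bellman backups, bootstrapping from the best allowed action."""
    # Action 2 never appears in the dataset, so its Q-value is never
    # corrected; 50.0 stands in for an arbitrary extrapolation error.
    q = np.array([0.0, 0.0, 50.0])
    for _ in range(n_iters):
        bootstrap = max(q[a] for a in allowed_actions)
        for action, reward in DATASET_REWARDS.items():
            q[action] = reward + GAMMA * bootstrap
    return q


q_unconstrained = run_backups(allowed_actions=[0, 1, 2])
q_in_support = run_backups(allowed_actions=[0, 1])
print(q_unconstrained)  # Q(s, a_0) inherits the error of the unseen action
print(q_in_support)     # Q(s, a_0) approaches 1 / (1 - GAMMA) = 10
```

In this toy problem the error of the unseen action is never corrected because no transition in the dataset ever updates it, which mirrors the situation described above.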

Objectives for Offline RL algorithms to avoid out-of-distribution actions

To fix the issue discussed above and to fully utilize the dataset, we need two objectives in offline RL. First, the policy should be constrained to select actions within the support of the dataset. The policy that represents the probability distribution \(p(a|s)\) of the dataset is usually called the behavior policy, denoted as \(\pi_B\). We aim to maximize the return of the policy \(G_t\) subject to the constraint that \(\pi_B(a|s)\) is larger than a threshold \(\epsilon\):

$$ \max_{a \sim \pi(\cdot|s)} \mathbb{E}[G_t] $$
$$ \text{s.t. } \pi_B(a|s) > \epsilon $$

An illustration of this constraint is shown in Figure 3 below. Given a behavior policy \(\pi_B(a|s)\) in the top panel, the agent policy \(\pi\) should only choose actions within the green region where \(\pi_B(a|s) > \epsilon\). On the other hand, the constraint cannot be overly restrictive. Specifically, the policy should not be affected by the density of \(\pi_B\). In the example in Figure 3, the policy should have the flexibility to choose any action within the green region even if it deviates from the most probable action of \(\pi_B\) and the “shape” of the distribution \(\pi\) is very different from the shape of \(\pi_B\).

Figure 3: An illustration of the two objectives of offline RL: Given a behavior policy distribution at a state \(s\) (top), the policy (bottom) should (1) only choose actions within the green region where \(\pi_B > \epsilon\) and (2) not be restricted by the density of the behavior policy \(\pi_B\).

Figure 4 below shows a more intuitive example. Consider an agent in a grid world. Suppose that the agent has a dataset of transitions marked as blue dots and arrows. The agent aims to find a path to the goal without any information about the other parts of the map. As shown in the left figure, it cannot select out-of-distribution actions because doing so might be dangerous. As shown in the right figure, if action \(a_1\) appears 10 times in the dataset and action \(a_0\) appears 5 times, the agent should not choose action \(a_1\) just because it appears more often in the dataset; as shown, this might be a suboptimal action for the task.
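The two objectives can be sketched for a single state with a discrete action set. The counts for \(a_0\) and \(a_1\) follow the example above; the third action, the Q-values, and the threshold are hypothetical numbers chosen only for illustration.

```python
import numpy as np

# Times each action appears in the dataset at this state: a_0, a_1, a_2.
counts = np.array([5, 10, 0])
# Hypothetical Q-estimates; the value for the unseen a_2 is spurious.
q_values = np.array([1.0, 0.2, 3.0])
epsilon = 0.05

pi_b = counts / counts.sum()   # empirical behavior policy pi_B(a|s)
in_support = pi_b > epsilon    # objective 1: stay where pi_B(a|s) > epsilon

unconstrained = int(np.argmax(q_values))   # picks the out-of-distribution a_2
most_frequent = int(np.argmax(pi_b))       # biased by density: picks a_1
constrained = int(np.argmax(np.where(in_support, q_values, -np.inf)))  # a_0

print(unconstrained, most_frequent, constrained)
```

The constrained choice ignores how often each in-support action occurs and ranks the in-support actions by value only, which is the behavior the second objective asks for.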

Figure 4: An intuitive explanation of the two objectives in offline RL: (1) Left: Choosing out-of-distribution actions may lead to a dangerous state. (2) Right: The action selection given a state should not be biased by the probability of the actions of the dataset.

Previous methods struggled with achieving both of these objectives. For example, BCQ (Fujimoto 2018b) proposes to sample from the behavior policy, perturb around it and then take the action that maximizes the Q-value. This method will be restricted by the density of the behavior policy distribution if the sample size \(N\) is not large enough. Another line of work uses explicit policy constraints in the optimization process (Jaques 2019, Kumar 2019, Wu 2019). They try to force the agent policy to be close to the behavior policy in terms of different measures of distance, such as KL or MMD (Figure 1). The explicit constraints create difficulties in the optimization, and distance metrics such as KL will be affected by the density (see Appendix E in our paper).

Proposed Method: Policy in Latent Action Space (PLAS)

In our paper, PLAS: Latent Action Space for Offline Reinforcement Learning (CoRL 2020), we propose a method that can satisfy both the objectives discussed above by simply modifying the action space of the policy. That is, the policy will only select actions for which \(\pi_B(a|s) > \epsilon\), but will not be restricted by the density of the distribution \(\pi_B(a|s)\). In our method, we first model the behavior policy using a Conditional Variational Autoencoder (CVAE) as in previous work (Fujimoto 2018b, Kumar 2019). The CVAE is trained to reconstruct actions conditioned on the states. The decoder of the CVAE creates a mapping from the latent space to the action space. Instead of training a policy in the action space of the environment, we propose to learn a Policy in the Latent Action Space (PLAS) of the CVAE and then use the pretrained decoder to output an action in the original action space.
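For readers unfamiliar with CVAEs, the training objective can be sketched as a reconstruction term plus a KL term that pulls the latent posterior toward a standard normal prior. The snippet below is a generic sketch rather than the loss used in the paper: the encoder (which would map a state-action pair to `mu` and `log_var`) and the decoder (which would map a state and a latent to `reconstructed_actions`) are omitted, and the `kl_weight` value is an assumption.

```python
import numpy as np


def cvae_loss(actions, reconstructed_actions, mu, log_var, kl_weight=0.5):
    """Reconstruction error plus weighted KL(q(z|s,a) || N(0, I)), batch-averaged."""
    # Squared reconstruction error of the actions, summed over action dimensions.
    recon = np.mean(np.sum((reconstructed_actions - actions) ** 2, axis=1))
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I).
    kl = np.mean(-0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var), axis=1))
    return recon + kl_weight * kl


actions = np.array([[0.1, -0.3], [0.5, 0.2]])
# Perfect reconstruction with a posterior equal to the prior gives zero loss.
perfect_loss = cvae_loss(actions, actions, mu=np.zeros((2, 2)), log_var=np.zeros((2, 2)))
# Shifting the posterior mean away from the prior is penalized by the KL term.
shifted_loss = cvae_loss(actions, actions, mu=np.ones((2, 2)), log_var=np.zeros((2, 2)))
print(perfect_loss, shifted_loss)
```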

Figure 5: Proposed Method: Policy in Latent Action Space (PLAS). We propose to first train a CVAE using the dataset and freeze the decoder. Second, we train a policy in the latent action space. The latent action will be passed into the decoder and transformed into an action within the distribution of the dataset. In contrast to previous work, it forms an implicit policy constraint in the latent action space.

Using the above approach, we can naturally constrain the policy to select actions within the dataset because the action is chosen from the latent space. The prior of the latent variable of the CVAE is set to be a normal distribution for simplicity, following common practice. To constrain the latent policy from selecting actions that are too “far away” from this prior, we use a tanh activation at the output of the policy; this implicitly constrains the policy to select within a fixed number of standard deviations of the mean of the latent prior. It is important to note that the action output from the decoder should be conditioned on the state because we care about \(\pi_B(a|s) > \epsilon\) instead of \(\pi_B(a) > \epsilon\). This approach also satisfies the second objective because the policy can select any action within the latent space and will not be affected by the density of the behavior policy.
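The forward pass described above can be sketched as follows. This is a minimal illustration, not the released implementation: the random linear maps stand in for the trained networks, and the dimensions and the bound `MAX_LATENT` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
STATE_DIM, LATENT_DIM, ACTION_DIM = 4, 2, 3
MAX_LATENT = 2.0  # hypothetical bound, in standard deviations of the prior

# Stand-in for the frozen, pretrained CVAE decoder (state, latent) -> action.
W_dec = rng.normal(size=(STATE_DIM + LATENT_DIM, ACTION_DIM))
# Stand-in for the trainable latent policy state -> latent action.
W_pi = rng.normal(size=(STATE_DIM, LATENT_DIM))


def decoder(states, latents):
    """Map a latent action to an environment action, conditioned on the state."""
    return np.tanh(np.concatenate([states, latents], axis=-1) @ W_dec)


def latent_policy(states):
    """The tanh output keeps every latent coordinate within [-MAX_LATENT, MAX_LATENT]."""
    return MAX_LATENT * np.tanh(states @ W_pi)


def act(states):
    """Full policy: pick a bounded latent action, then decode it."""
    return decoder(states, latent_policy(states))


states = 10.0 * rng.normal(size=(5, STATE_DIM))
latents = latent_policy(states)
actions = act(states)
print(np.abs(latents).max(), actions.shape)
```

Only `W_pi` would be updated during policy training in this sketch; the decoder stays fixed, so every action the policy can emit is the decoder's output for some bounded latent.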

Experiments: Cloth Sliding and D4RL benchmark

This modification of the action space can be built on top of any off-policy algorithm with either a stochastic or deterministic policy. In our experiment, we use TD3 (Fujimoto 2018a) with a deterministic latent policy. We evaluate our algorithm on a wide range of continuous control tasks, including a real robot experiment on cloth sliding and the D4RL benchmark.

The task for the real-robot experiment is to slide along the edge of the cloth without dropping it. The dataset we use consists of a replay buffer from a previous experiment (around 7000 timesteps) and 5 episodes of expert trajectories (around 300 timesteps). Our method outperforms all the baselines and achieves performance similar to that of the expert.

Figure 6: Results from the real robot experiment. Left: Performance of PLAS on the cloth-sliding task. More videos can be found here. Right: Training curves of PLAS and the baselines. PLAS outperforms the other baselines and achieves performance similar to that of the expert.

On the D4RL benchmark, our method also achieves consistent performance across a wide range of datasets with different environments and qualities despite its simplicity. We provide some of the qualitative and quantitative results below in Figures 7 and 8. Check out the full results on the D4RL benchmark in the paper. More videos can be found on our website.

Figure 7: We evaluate our method on different environments and datasets from the D4RL benchmark. Here are some examples of trained policies in Hopper-v2, Walker2d-v2, Adroit Hammer, and the FrankaKitchen environment. The policies are able to perform the tasks without further data collection. More videos can be found here.

Figure 8: Training curves on the medium expert datasets on the locomotion tasks. Our method achieves performance comparable to that of the expert on the medium expert datasets. More results can be found in the paper.

To further analyze the result, we plot the estimation error of the learned Q-functions in Figure 9. During the evaluation, we compare the estimated Q-value of the state-action pairs with their true return from the rollouts. Our method has the lowest mean squared error (MSE) while the baselines have either more significant overestimation or underestimation.

Figure 9: Analysis of Q-functions on the Walker2d-v2 Medium Expert dataset: (a) Mean squared error (b) The percentage of overestimated Q-values. Our method has the lowest MSE without significant overestimation or underestimation.
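The comparison between estimated Q-values and true returns can be sketched as below. This is an assumed evaluation procedure rather than the exact script behind Figure 9: it treats the discounted return of a finite rollout as the ground truth, and the rollout and Q-estimates are made-up numbers.

```python
import numpy as np


def discounted_returns(rewards, gamma=0.99):
    """Monte Carlo return-to-go for every timestep of one rollout."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def q_error_stats(q_estimates, rewards, gamma=0.99):
    """Return (mean squared error, fraction of overestimated Q-values)."""
    errors = np.asarray(q_estimates) - discounted_returns(rewards, gamma)
    return float(np.mean(errors ** 2)), float(np.mean(errors > 0))


# Hypothetical three-step rollout with returns-to-go [1.75, 1.5, 1.0].
mse, overestimated_fraction = q_error_stats(
    q_estimates=[2.0, 1.5, 0.5], rewards=[1.0, 1.0, 1.0], gamma=0.5
)
print(mse, overestimated_fraction)
```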

As mentioned earlier in the objectives, our method focuses on avoiding out-of-distribution actions. In our experiment, we analyze the effect of out-of-distribution actions by introducing an additional component: we add a perturbation layer that is trained together with the latent policy, inspired by Fujimoto 2018b. The perturbation layer outputs a residual over the action output of the decoder. This allows the final policy output to deviate from the support of the dataset in a controlled way. More precisely, restricting the range of the output of the perturbation layer is essentially constraining the action output to be close to the dataset in terms of the L-infinity norm. In Figure 10, we plot the performance of our method with different ranges of allowed perturbation. We found that out-of-distribution actions introduced by the perturbation layer are usually harmful to datasets with high-quality rollouts such as the medium-expert datasets. However, it could be helpful for some of the random or medium datasets depending on the environment. The full analysis of the perturbation layer can be found in Appendix D. The results shed light on the disentangled contributions of in-distribution generalization and out-of-distribution generalization in offline reinforcement learning.

Figure 10: Effect of the perturbation layer on different datasets. More allowed perturbation is usually harmful to the datasets with high-quality rollouts such as the medium expert datasets, but it could be helpful for the medium or random datasets for certain environments.
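The effect of bounding the perturbation layer can be sketched as follows. This is an illustration under assumptions, not the paper's implementation: the tanh squashing, the clipping step, and the normalized action range of [-1, 1] are choices made here for the example.

```python
import numpy as np


def perturbed_action(decoded_action, raw_perturbation, max_perturbation,
                     action_low=-1.0, action_high=1.0):
    """Add a bounded residual to the decoder output."""
    # tanh keeps each residual coordinate within [-max_perturbation, max_perturbation],
    # so the result stays within that L-infinity distance of the decoded action.
    residual = max_perturbation * np.tanh(raw_perturbation)
    # Clipping to the valid action range (assumed here) cannot increase that distance
    # as long as the decoded action is itself inside the range.
    return np.clip(decoded_action + residual, action_low, action_high)


rng = np.random.default_rng(0)
decoded = rng.uniform(-1.0, 1.0, size=(6, 3))   # hypothetical decoder outputs
raw = 5.0 * rng.normal(size=(6, 3))             # hypothetical perturbation-layer outputs
final = perturbed_action(decoded, raw, max_perturbation=0.05)
max_deviation = float(np.max(np.abs(final - decoded)))
print(max_deviation)
```

Setting `max_perturbation` to zero recovers the decoder output exactly, which corresponds to the purely in-distribution policy.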

Conclusion and Discussion

We propose a simple and straightforward approach to offline RL: Policy in the Latent Action Space (PLAS). To summarize:

  • Our approach naturally avoids out-of-distribution actions while allowing the flexibility for improvement over the performance of the dataset through implicit constraints.
  • It achieves competitive performance in both simulated environments and a real robot experiment on cloth manipulation.
  • We provide analyses of the Q-function estimation error and of the separation of in-distribution vs. out-of-distribution generalization in offline RL.

Please visit our website for the paper, code, and more videos.

Our method can be extended in different ways. First, it will benefit from a better generative model. For example, replacing the VAE with a normalizing flow could potentially lead to theoretical guarantees and better evaluation performance. Second, it can also be extended to allow better “out-of-distribution” generalization. We hope that our method will pave the way for future possibilities of applying reinforcement learning algorithms to real-world applications by using static datasets more efficiently.


This material is based upon work supported by the United States Air Force and DARPA under Contract No. FA8750-18-C-0092, LG Electronics and the National Science Foundation under Grant No. IIS-1849154. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of United States Air Force and DARPA and the National Science Foundation.

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of Carnegie Mellon University.

A Learning Theoretic Perspective on Local Explainability

Fig 1: A formal relationship between interpretability and complexity. Going from left to right, we consider increasingly complex functions. As the complexity increases, local linear explanations can approximate the function only in smaller and smaller neighborhoods. These neighborhoods, in other words, need to become more and more disjoint as the function becomes more complex. Indeed, we quantify “disjointedness” of the neighborhoods via a term denoted by \(\rho_S\) and relate it to the complexity of the function class, and subsequently, its generalization properties.


There has been growing interest in interpretable machine learning (IML) as a way to help users better understand how their ML models behave. IML has become a particularly relevant concern as practitioners aim to apply ML in important domains such as healthcare [Caruana et al., ’15], financial services [Chen et al., ’18], and scientific discovery [Karpatne et al., ’17].

While much of the work in IML has been qualitative and empirical, in our recent ICLR 2021 paper, we study how concepts in interpretability can be formally related to learning theory. At a high level, the connection between these two fields seems quite natural. Broadly, one can consider there to be a trade-off between notions of a model’s interpretability and its complexity. That is, as a model’s decision boundary gets more complicated (e.g., compare a sparse linear model vs. a neural network) it is harder in a general sense for a human to understand how it makes decisions. At the same time, learning theory commonly analyzes relationships between notions of complexity for a class of functions (e.g., the number of parameters required to represent those functions) and the functions’ generalization properties (i.e., their predictive accuracy on unseen test data). Therefore, it is natural to suspect that, through some appropriate notion of complexity, one can establish connections between interpretability and generalization.

How do we establish this connection? First, IML encompasses a wide range of definitions and problems, spanning both the design of inherently interpretable models as well as post-hoc explanations for black-boxes (e.g. including but not limited to approximation [Ribeiro et al., ’16], counterfactual [Dhurandhar et al., ’18], and feature importance-based explanations [Lundberg & Lee, ’17]). In our work, we focus on a notion of interpretability that is based on the quality of local approximation explanations. Such explanations are a common and flexible post-hoc approach for IML, used by popular methods such as LIME [Ribeiro et al., ’16] and MAPLE [Plumb et al., ’18], which we’ll briefly outline later in the blog. We then answer two questions that relate this notion of local explainability to important statistical properties of the model:

  1. Performance Generalization: How can a model’s predictive accuracy on unseen test data be related to the interpretability of the learned model? 
  2. Explanation Generalization: We look at a novel statistical problem that arises in a growing subclass of local approximation algorithms (such as MAPLE and RL-LIM [Yoon et al., ’19]). Since these algorithms learn explanations by fitting them on finite data, the explanations may not necessarily fit unseen data well. Hence, we ask, what is the quality of those explanations on unseen data?

In what follows, we’ll first provide a quick introduction to local explanations. Then, we’ll motivate and answer our two main questions by discussing a pair of corresponding generalization guarantees that are in terms of how “accurate” and “complex” the explanations are. Here, the “complexity” of the local explanations corresponds to how large of a local neighborhood the explanations span (the larger the neighborhood, the lower the complexity; see Fig 1 for a visualization). For question (1), this results in a bound that roughly captures the idea that an easier-to-locally-approximate \(f\) enjoys better performance generalization. For question (2), our bound tells us that, when the explanations can accurately fit all the training data that fall within large neighborhoods, the explanations are likely to fit unseen data better. Finally, we’ll examine our insights in practice by verifying that these guarantees capture non-trivial relationships in a few real-world datasets.

Local Explainability

Local approximation explanations operate on a basic idea: use a model from a simple family (like a linear model) to locally mimic a model from a more complex family (like a neural network model). Then, one can directly inspect the approximation (e.g. by looking at the weights of the linear model). More formally, for a given black-box model \(f : \mathcal{X} \rightarrow \mathcal{Y}\), the explanation system produces at any input \(x \in \mathcal{X}\) a “simple” function \(g_x(\cdot) : \mathcal{X} \rightarrow \mathcal{Y}\) that approximates \(f\) in a local neighborhood around \(x\). Here, we assume \(f \in \mathcal{F}\) (the complex model class) and \(g_x \in \mathcal{G}_{\text{local}}\) (the simple model class).

As an example, here’s what LIME (Local Interpretable Model-agnostic Explanations) does. At any point \(x\), in order to produce the corresponding explanation \(g_x\), LIME samples a set of perturbed points in the neighborhood around \(x\) and labels them using the complex function \(f\). It then learns a simple function that fits the resulting dataset. One can then use this simple function to better understand how \(f\) behaves in the locality around \(x\).
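The sample, label, and fit procedure just described can be sketched in a few lines. This is a simplified illustration and not the LIME library: it uses Gaussian perturbations with a hypothetical width `sigma`, an unweighted least-squares fit, and made-up black-box functions, whereas the actual method also weights samples by proximity and works in an interpretable feature space.

```python
import numpy as np


def local_linear_explanation(f, x, sigma=0.1, n_samples=500, seed=0):
    """Fit a linear model g_x to f on points sampled around x."""
    rng = np.random.default_rng(seed)
    # Sample perturbed points in a neighborhood of x.
    neighbors = x + sigma * rng.normal(size=(n_samples, x.shape[0]))
    # Label them with the black-box model.
    labels = f(neighbors)
    # Least-squares fit of weights and bias (a column of ones carries the bias).
    design = np.hstack([neighbors, np.ones((n_samples, 1))])
    coef, *_ = np.linalg.lstsq(design, labels, rcond=None)
    return coef[:-1], coef[-1]


def black_box(points):
    """Hypothetical nonlinear model to be explained."""
    return np.sin(points[:, 0]) + 2.0 * points[:, 1]


def linear_model(points):
    """Hypothetical model that is already linear, as a sanity check."""
    return points @ np.array([3.0, -1.0]) + 0.5


weights, bias = local_linear_explanation(black_box, np.zeros(2))
lin_weights, lin_bias = local_linear_explanation(linear_model, np.array([1.0, 2.0]))
print(weights, bias)  # close to the gradient of black_box at the origin
```

Inspecting `weights` then tells us how each input feature locally influences the prediction around the chosen point.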

Performance Generalization

Our first result is a generalization bound on the squared error loss of the function \(f\). Now, a typical generalization bound would look something like

$$\text{TestLoss}(f) \leq \text{TrainLoss}(f) + \sqrt{\frac{\text{Complexity}(\mathcal{F})}{\text{\# of training examples}}}$$

where the bound is in terms of how well \(f\) fits the training data, and also how “complex” the function class \(\mathcal{F}\) is. In practice though, \(\mathcal{F}\) can often be a very complex class, rendering these bounds too large to be meaningful.

Yet while \(\mathcal{F}\) is complex, what if the function \(f\) itself was picked from a subset of \(\mathcal{F}\) that is in some way much simpler? For example, this is the case for neural networks trained by gradient descent [Zhang et al., ’17; Neyshabur et al., ’15]. Capturing this sort of simplicity could lead to more interesting bounds that (a) aren’t as loose and/or (b) shed insight into what meaningful properties of \(f\) can influence how well it generalizes. While there are many different notions of simplicity that different learning theory results have studied, here we are interested in quantifying simplicity in terms of how “interpretable” \(f\) is, and relating that to generalization.

To state our result, imagine that we have a training set \(S = \{(x_i, y_i)\}_{i=1}^{m}\) sampled from the data distribution \(D\). Then, we show the following bound on the test-time squared error loss:

$$\underbrace{\mathbb{E}_{D}[(f(x)-y)^2]}_{\text{Test Loss}} \leq \underbrace{\hat{\mathbb{E}}_{S}[(f(x)-y)^2]}_{\text{Train Loss}} + \underbrace{\mathbb{E}_{x \sim D} \mathbb{E}_{x' \sim N_{x}}[(g_{x'}(x) - f(x))^2]}_{\text{Explanation Quality (MNF)}} + \underbrace{\rho_S \cdot \mathcal{R}(\mathcal{G}_{\text{local}})}_{\substack{\text{Complexity of} \\ \text{Explanation System}}}.$$

Let’s examine these terms one by one.

Train Loss: The first term, as is typical of many generalization bounds, is simply the training error of \(f\) on \(S\).

Explanation Quality (MNF): The second term captures a notion of how interpretable \(f\) is, measuring how accurate the set of local explanations \(g\) is with a quantity that we call the “mirrored neighborhood fidelity” (MNF). This metric is a slight modification of a standard notion of local explanation quality used in the IML literature, called neighborhood fidelity (NF) [Ribeiro et al., ’16; Plumb et al., ’18]. More concretely, we explain how MNF and NF are calculated below in Fig 2.

Fig 2. How MNF is calculated: We use orange to denote “source points” (where explanations are generated) and teal to denote “target points” (where approximation error is computed). To compute the inner expectation for MNF, we sample a single target point \(x \sim D\). Then, we sample a source point \(x'\) from a “neighborhood” distribution \(N_x\) (typically a distribution centered at \(x\)). We then measure how well \(f\) is approximated at \(x\) by the explanation generated at \(x'\). Averaging over \(x\) and \(x'\), we define \(\text{MNF} = \mathbb{E}_{x \sim D} \mathbb{E}_{x' \sim N_{x}}[(g_{x'}(x) - f(x))^2]\). To get NF, we simply need to swap \(x\) and \(x'\) in the innermost term: \(\text{NF} = \mathbb{E}_{x \sim D} \mathbb{E}_{x' \sim N_{x}}[(g_{x}(x') - f(x'))^2]\).
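Both quantities can be estimated by Monte Carlo sampling. The sketch below is an illustration under assumptions that are not from the paper: a one-dimensional \(f(x) = \sin(x)\), a standard normal data distribution \(D\), a Gaussian neighborhood \(N_x\) of width `sigma`, and an explainer that returns the tangent line of \(f\) at the source point.

```python
import numpy as np


def f(x):
    """Hypothetical black-box function."""
    return np.sin(x)


def explain(x_source, x_target):
    """g_{x_source}(x_target): tangent line of f at x_source, evaluated at x_target."""
    return np.sin(x_source) + np.cos(x_source) * (x_target - x_source)


def mnf_and_nf(sigma, n_samples=20000, seed=0):
    """Monte Carlo estimates of MNF and NF for neighborhood width sigma."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n_samples)                    # x ~ D
    x_prime = x + sigma * rng.normal(size=n_samples)  # x' ~ N_x
    # MNF: explanation generated at x', error measured at the data point x.
    mnf = np.mean((explain(x_prime, x) - f(x)) ** 2)
    # NF: explanation generated at x, error measured at the sampled point x'.
    nf = np.mean((explain(x, x_prime) - f(x_prime)) ** 2)
    return float(mnf), float(nf)


mnf_small, nf_small = mnf_and_nf(sigma=0.1)
mnf_large, nf_large = mnf_and_nf(sigma=1.0)
print(mnf_small, mnf_large)
```

Note that MNF always measures the error at points drawn from \(D\), whereas NF measures it at the perturbed points, which may lie off the data distribution.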

While notationally the differences between MNF and the standard notion of NF are slight, there are some noteworthy differences and potential (dis)advantages of using MNF over NF from an interpretability point of view. At a high level, we argue that MNF offers more robustness (when compared to NF) to any potential irregular behavior of \(f\) off the manifold of the data distribution. We discuss this in greater detail in the full paper.

Complexity of Explanation System: Finally, the third and perhaps the most interesting term measures how complex the infinite set of explanation functions \(\{g_x\}_{x \in \mathcal{X}}\) is. As it turns out, this system of explanations \(\{g_x\}_{x \in \mathcal{X}}\), which we will call \(g\), has a complexity that can be nicely decomposed into two factors. One factor, namely \(\mathcal{R}(\mathcal{G}_{\text{local}})\), corresponds to the (Rademacher) complexity of the simple local class \(\mathcal{G}_{\text{local}}\), which is going to be a very small quantity, much smaller than the complexity of \(\mathcal{F}\). Think of this factor as typically scaling linearly with the number of parameters for \(\mathcal{G}_{\text{local}}\) and also with the dataset size \(m\) as \(1/\sqrt{m}\). The second factor is \(\rho_S\), and is what we call the “neighborhood disjointedness” factor. This factor lies in the interval \([1, \sqrt{m}]\) and is defined by how little overlap there is between the different local neighborhoods specified for each of the training datapoints in \(S\). When there is absolutely no overlap, \(\rho_S\) can be as large as \(\sqrt{m}\), but when all these neighborhoods are exactly the same, \(\rho_S\) equals \(1\).

Implications of the overall bound: Having unpacked all the terms, let us take a step back and ask: assuming that \(f\) has fit the training data well (i.e., the first term is small), when are the other two terms large or small? We visualize this in Fig 1. Consider the case where MNF can be made small by approximating \(f\) by \(g\) on very large neighborhoods (Fig 1 left). In such a case, the neighborhoods would overlap heavily, thus keeping \(\rho_S\) small as well. Intuitively, this suggests good generalization when \(f\) is “globally simple”. On the other hand, when \(f\) is too complex, then we need to either shrink the neighborhoods or increase the complexity of \(\mathcal{G}_{\text{local}}\) to keep MNF small. Thus, one would either suffer from MNF or \(\rho_S\) exploding, suggesting bad generalization. In fact, when \(\rho_S\) is as large as \(\sqrt{m}\), the bound is “vacuous” as the complexity term no longer decreases with the dataset size \(m\), suggesting no generalization!
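To see why the bound becomes vacuous, here is a short worked calculation. It is a sketch under two simplifying assumptions that go beyond the text: that \(\mathcal{R}(\mathcal{G}_{\text{local}}) \approx C/\sqrt{m}\) for some constant \(C\) (hypothetical, absorbing the dependence on the number of parameters), and that the disjointedness grows polynomially as \(\rho_S = m^{c}\).

$$\rho_S \cdot \mathcal{R}(\mathcal{G}_{\text{local}}) \approx m^{c} \cdot \frac{C}{\sqrt{m}} = C \, m^{c - 1/2}, \qquad 0 \leq c \leq \tfrac{1}{2}.$$

With identical neighborhoods (\(c = 0\)) the term is \(C/\sqrt{m}\) and shrinks as more data arrives. With fully disjoint neighborhoods (\(c = 1/2\)) it equals \(C\) regardless of \(m\). For an intermediate value such as \(c = 0.2\) and \(m = 10^4\), the term is \(C \cdot 10^{-1.2} \approx 0.063\,C\), so the bound still tightens with more data, only more slowly.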

Explanation Generalization

We’ll now turn to a different, novel statistical question which arises when considering a number of recent IML algorithms. Here we are concerned with how well explanations learned from finite data generalize to unseen data. 

To motivate this question more clearly, we need to understand a key difference between canonical and finite-sample-based IML approaches. In canonical approaches (e.g. LIME), at different values of \(x'\), the explanations \(g_{x'}\) are learned by fitting on a fresh set of points \(S_{x'}\) from a (user-defined) neighborhood distribution \(N_{x'}\) (see Fig 3, top). But a growing number of approaches such as MAPLE and RL-LIM learn their explanations by fitting \(g_{x'}\) on a “realistic” dataset \(S\) drawn from \(D\) (rather than from an arbitrary distribution) and then re-weighting the datapoints in \(S\) depending on a notion of their closeness to \(x'\) (see Fig 3, bottom).

Now, while the canonical approaches effectively train \(g\) on an infinite dataset \(\cup_{x' \in \mathcal{X}} S_{x'}\), recent approaches train \(g\) on only that finite dataset \(S\) (reused for every \(x'\)).
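The finite-sample approach can be sketched as a weighted least-squares fit on one shared dataset. This is a simplified stand-in and not the MAPLE or RL-LIM algorithm: the Gaussian kernel weighting and its width `sigma` are assumptions made for illustration, and those methods learn their own notion of closeness.

```python
import numpy as np


def weighted_local_fit(X, y, x_source, sigma=1.0):
    """Fit a linear explanation at x_source by re-weighting the shared dataset (X, y)."""
    # Points of the shared dataset that are closer to the source point get more weight.
    sq_dist = np.sum((X - x_source) ** 2, axis=1)
    weights = np.exp(-sq_dist / (2.0 * sigma ** 2))
    # Weighted least squares: scale rows by sqrt(weight), then solve ordinary least squares.
    design = np.hstack([X, np.ones((X.shape[0], 1))])
    root_w = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * root_w[:, None], y * root_w, rcond=None)
    return coef[:-1], coef[-1]


rng = np.random.default_rng(0)
X_train = rng.normal(size=(50, 2))                    # the single finite dataset S
y_train = X_train @ np.array([2.0, -1.0]) + 0.5       # hypothetical black-box outputs f(x)
# The same (X_train, y_train) is reused for every source point.
w_a, b_a = weighted_local_fit(X_train, y_train, x_source=np.zeros(2))
w_b, b_b = weighted_local_fit(X_train, y_train, x_source=np.array([1.0, 1.0]))
print(w_a, b_a)
```

With a very small `sigma`, only a handful of points in the shared dataset receive non-negligible weight, which is exactly the regime where overfitting becomes a concern.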

Using a realistic \(S\) has certain advantages (as motivated in this blog post), but on the flip side, since \(S\) is finite, it carries a potentially severe risk of overfitting (we visualize this in Fig 3 right). This makes it valuable to seek a guarantee on the approximation error of \(g\) on test data (“Test MNF”) in terms of its fit on the training data \(S\) (“Train MNF”). In our paper, we derive such a result:

$$\underbrace{\mathbb{E}_{x \sim D} \mathbb{E}_{x' \sim N_{x}}[(g_{x'}(x) - f(x))^2]}_{\text{Test MNF}} \leq \underbrace{\hat{\mathbb{E}}_{x \sim S} \mathbb{E}_{x' \sim N_{x}}[(g_{x'}(x) - f(x))^2]}_{\text{Train MNF}} + \rho_S \cdot \mathcal{R}(\mathcal{G}_{\text{local}}).$$

As before, what this bound implies is that when the neighborhoods have very little overlap, there is poor generalization. This indeed makes sense: if the neighborhoods are too tiny, any explanation \(g_{x'}\) would have been trained on a very small subset of \(S\) that falls within its neighborhood. Thus the fit of \(g_{x'}\) won’t generalize to other neighboring points.

Fig 3. Difference between canonical local explanations (a) vs. finite-sample-based explanations (b and c): On the top panel (a), we visualize how one would go about generating explanations for different source points in a canonical method like LIME. In the bottom panels (b and c), we visualize the more recent approaches where one uses (and reuses) a single dataset for each explanation. Crucially, to learn an explanation at a particular source point, these procedures correspondingly re-weight this common dataset (visualized by the orange square boxes which are more opaque for points closer to each source point). In panel (b), the common dataset is large enough that it leads to good explanations; but in panel (c), the dataset is so small that the explanations do not generalize well to their neighborhoods.


While our theoretical results offer insights that make sense qualitatively, we also want to make a case empirically that they indeed capture meaningful relationships between the quantities involved. In particular, we explore this for neural networks trained on various regression tasks from the UCI suite and in the context of the “explanation generalization” bound. That is, we learn explanations to fit a finite dataset \(S\) by minimizing Train MNF, and then evaluate what Test MNF is like. Here, there are two important relationships we empirically establish:

  1. Dependence on \(\rho_S\): Given that our bounds may be vacuous for large \(\rho_S = \sqrt{m}\), does this quantity actually scale well in practice (i.e. more slowly than \(O(\sqrt{m})\))? Indeed, we observe that we can find reasonably large choices of the neighborhood size \(\sigma\) without causing Train MNF to become too high (somewhere around \(\sigma = 1\) in Fig 4 bottom) and for which we can also achieve a reasonably small \(\rho_S \approx O(m^{0.2})\) (Fig 4 top).
  2. Dependence on neighborhood size: Do wider neighborhoods actually lead to improved generalization gaps? We do observe that as the neighborhood width increases, Train MNF and Test MNF overall get closer to each other, indicating that the generalization gap decreases (Fig 4 bottom).

Fig 4. Empirical study of our bounds: For various neighborhood widths \(\sigma\), in the top, we plot the approximate exponent of \(\rho_S\)’s polynomial growth rate, i.e., the exponent \(c\) in \(\rho_S = O(|S|^c)\). Below, we plot train/test MNF. We observe a tradeoff here: increasing \(\sigma\) results in better values of \(\rho_S\) but hurts the MNF terms.
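One simple way to estimate such a growth exponent is a straight-line fit in log-log space. The sketch below shows that generic technique on synthetic numbers; it is an assumption about how one might do it, not the procedure used to produce Fig 4, and the dataset sizes and \(\rho_S\) values are made up.

```python
import numpy as np


def growth_exponent(dataset_sizes, rho_values):
    """Estimate c in rho_S ~ |S|^c as the slope of log(rho_S) against log(|S|)."""
    slope, _intercept = np.polyfit(np.log(dataset_sizes), np.log(rho_values), deg=1)
    return float(slope)


# Synthetic measurements that follow rho_S = 1.5 * |S|^0.2 exactly.
sizes = np.array([100.0, 200.0, 400.0, 800.0])
rho = 1.5 * sizes ** 0.2
estimated_c = growth_exponent(sizes, rho)
print(estimated_c)
```

An estimated exponent below 0.5 indicates that the complexity term of the bound still shrinks as the dataset grows.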


We have shown how a model’s local explainability can be formally connected to some of its important statistical properties. One direction for future work is to consider extending these ideas to high-dimensional datasets, a challenging setting where our current bounds become prohibitively large. Another direction would be to more thoroughly explore these bounds in the context of neural networks, for which researchers are in search of novel types of bounds [Zhang et al., ’17; Nagarajan and Kolter, ’19].

Separately, when it comes to the interpretability community, it would be interesting to explore the advantages/disadvantages of evaluating and learning explanations via MNF rather than NF. As discussed here, MNF appears to have reasonable connections to generalization, and as we show in the paper, it may also promise more robustness to off-manifold behavior.

To learn more about our work, check out our upcoming ICLR paper. Moreover, for a broader discussion about IML and some of the most pressing challenges in the field, here is a link to a recent white paper we wrote.


Jeffrey Li, Vaishnavh Nagarajan, Gregory Plumb, and Ameet Talwalkar, 2021, “A Learning Theoretic Perspective on Local Explainability”, ICLR 2021.

Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad, 2015, “Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission.” ACM SIGKDD, 2015.

Chaofan Chen, Kancheng Lin, Cynthia Rudin, Yaron Shaposhnik, Sijia Wang, and Tong Wang, 2018, “An interpretable model with globally consistent explanations for credit risk.” NeurIPS 2018 Workshop on Challenges and Opportunities for AI in Financial Services: the Impact of Fairness, Explainability, Accuracy, and Privacy, 2018.

Anuj Karpatne, Gowtham Atluri, James H. Faghmous, Michael Steinbach, Arindam Banerjee, Auroop Ganguly, Shashi Shekhar, Nagiza Samatova, and Vipin Kumar, 2017, “Theory-guided data science: A new paradigm for scientific discovery from data.” IEEE Transactions on Knowledge and Data Engineering, 2017.

Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin, 2016, “Why should I trust you?: Explaining the predictions of any classifier.” ACM SIGKDD, 2016.

Scott M. Lundberg, and Su-In Lee, 2017, “A unified approach to interpreting model predictions.” NeurIPS, 2017.

Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu, Paishun Ting, Karthikeyan Shanmugam, and Payel Das, 2018, “Explanations based on the Missing: Towards Contrastive Explanations with Pertinent Negatives” NeurIPS, 2018.

Jinsung Yoon, Sercan O. Arik, and Tomas Pfister, 2019, “RL-LIM: Reinforcement learning-based locally interpretable modeling”, arXiv 2019 1909.12367.

Gregory Plumb, Denali Molitor, and Ameet S. Talwalkar, 2018, “Model Agnostic Supervised Local Explanations”, NeurIPS 2018.

Vaishnavh Nagarajan and J. Zico Kolter, 2019, “Uniform convergence may be unable to explain generalization in deep learning”, NeurIPS 2019

Behnam Neyshabur, Ryota Tomioka, Nathan Srebro, 2015, “In Search of the Real Inductive Bias: On the Role of Implicit Regularization in Deep Learning”, ICLR 2015 Workshop.

Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, Oriol Vinyals, 2017, “Understanding deep learning requires rethinking generalization”, ICLR’ 17.

Valerie Chen, Jeffrey Li, Joon Sik Kim, Gregory Plumb, and Ameet Talwalkar, 2021, “Towards Connecting Use Cases and Methods in Interpretable Machine Learning”, arXiv 2021 2103.06254.


Counterfactual predictions under runtime confounding


Figure 1. Due to feasibility or ethical requirements, a prediction model may only access a subset of the confounding factors that affect both the decision and outcome. We propose a procedure for learning valid counterfactual predictions in this setting.

In machine learning, we often want to predict the likelihood of an outcome if we take a proposed decision or action. A healthcare setting, for instance, may require predicting whether a patient will be re-admitted to the hospital if the patient receives a particular treatment. In the child welfare setting, a social worker needs to assess the likelihood of adverse outcomes if the agency offers family services. In such settings, algorithmic predictions can be used to help decision-makers. Since the prediction target depends on a particular decision (e.g., the particular medical treatment, or offering family services), we refer to these predictions as counterfactual.

In general, for valid counterfactual inference, we need to measure all factors that affect both the decision and the outcome of interest. However, we may not want to use all such factors in our prediction model. Some factors such as race or gender may be too sensitive to use for prediction. Some factors may be too complex to use when model interpretability is desired, or some factors may be difficult to measure at prediction time.

Child welfare example: The child welfare screening task requires a social worker to decide which calls to the child welfare hotline should be investigated. In jurisdictions such as Allegheny County, the social worker makes their decision based on allegations in the call and historical information about individuals associated with the call, such as their prior child welfare interaction and criminal justice history. Both the call allegations and historical information may contain factors that affect both the decision and future child outcomes, but the child welfare agency may be unable to parse and preprocess call information in real-time for use in a prediction system. The social worker would still benefit from a prediction that summarizes the risk based on historical information. Therefore, the goal is a prediction based on a subset of the confounding factors.

Figure 2. Algorithmic predictions can help child welfare hotline screeners decide which cases to investigate. However, these predictions cannot access allegations in the call because of limitations in real-time processing.

Healthcare example: Healthcare providers may make decisions based on the patient’s history as well as lab results and diagnostic tests, but the patient’s health record may not be in a form that can be easily input to a prediction algorithm.

Figure 3. Predictions used to inform medical treatment decisions may not have access to all confounding factors.

How can we make counterfactual predictions using only a subset of confounding factors?

We propose a method for using offline data to build a prediction model that only requires access to the available subset of confounders at prediction time. Offline data is an important part of the solution because if we know nothing about the unmeasured confounders, then in general we cannot make progress. Fortunately, in our settings of interest, it is often possible to obtain an offline dataset that contains measurements of the full set of confounders as well as the outcome of interest and historical decision.

What is “runtime confounding?”

Runtime confounding occurs when all confounding factors are recorded in the training data, but the prediction model cannot use all confounding factors as features due to sensitivity, interpretability, or feasibility requirements. As examples,

  • It may not be possible to measure factors efficiently enough for use in the prediction model but it is possible to measure factors offline with sufficient processing time. Child welfare agencies typically do record call allegations for offline processing.
  • It may be undesirable to use some factors that are too sensitive or too complex for use in a prediction model.

Formally, let \(V \in \mathbb{R}^{d_v}\) denote the vector of factors available for prediction and \(Z \in \mathbb{R}^{d_z}\) denote the vector of confounding factors unavailable for prediction (but available in the training data). Given \(V\), our goal is to predict an outcome under a proposed decision; we wish to predict the potential outcome \(Y^{A=a}\) that we would observe under decision \(a\).

Prediction target: $$\nu(v) := \mathbb{E}[Y^{A=a} \mid V = v].$$ In order to estimate this hypothetical counterfactual quantity, we need assumptions that enable us to identify this quantity with observable data. We require three assumptions that are standard in causal inference:

Assumption 1: The decision assigned to one unit does not affect the potential outcomes of another unit.
Assumption 2: All units have some non-zero probability of receiving decision \(a\) (the decision of interest for prediction).
Assumption 3: \((V, Z)\) describe all factors that jointly affect the decision and outcome.

These assumptions enable us to identify our target estimand as $$\nu(v) = \mathbb{E}[\, \mathbb{E}[Y \mid A = a, V = v, Z = z] \mid V = v].$$
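
For readers who want the intermediate steps, the identification can be sketched as a short chain of equalities. This is the standard textbook argument rather than a quotation from the paper, and it assumes the usual consistency condition (the observed \(Y\) equals \(Y^{A=a}\) whenever \(A = a\)), which is typically stated alongside Assumption 1:

$$\begin{aligned}
\nu(v) &= \mathbb{E}[Y^{A=a} \mid V = v] \\
&= \mathbb{E}\big[\, \mathbb{E}[Y^{A=a} \mid V = v, Z] \mid V = v \big] && \text{(iterated expectations)} \\
&= \mathbb{E}\big[\, \mathbb{E}[Y^{A=a} \mid A = a, V = v, Z] \mid V = v \big] && \text{(Assumptions 2 and 3)} \\
&= \mathbb{E}\big[\, \mathbb{E}[Y \mid A = a, V = v, Z] \mid V = v \big] && \text{(consistency)}
\end{aligned}$$

Assumption 3 is what allows conditioning on \(A = a\) without changing the inner expectation, and Assumption 2 guarantees that this conditional expectation is well defined for every \((v, z)\).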

This suggests that we can estimate an outcome model \(\mu(v,z) := \mathbb{E}[Y \mid A = a, V = v, Z = z]\) and then regress the outcome model estimates on \(V\).

The simple plug-in (PL) approach:

  1. Estimate the outcome model \(\mu(v,z)\) by regressing \(Y \sim V, Z \mid A = a\). Use this model to construct pseudo-outcomes \(\hat{\mu}(V,Z)\) for each case in our training data.
  2. Regress \(\hat{\mu}(V,Z) \sim V\) to yield a prediction model that only requires knowledge of \(V\).

Figure 4. The Plug-in (PL) learning procedure. The full set of confounders \((V, Z)\) is used to build an outcome model. The output of the outcome model and the available predictors \(V\) are used to build a prediction model.
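
The two PL stages can be sketched in a few lines. This is a minimal illustration, assuming ordinary least squares for both regressions (the procedure itself allows any regression method) and a synthetic data-generating process invented for this example; the helper names `fit_linear` and `plug_in` are hypothetical. A treatment-conditional regression on \(V\) alone is included only to show the confounding bias the PL approach is meant to remove.

```python
import numpy as np

def fit_linear(X, y):
    """Least squares with an intercept; returns a prediction function."""
    Xb = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(Xb, y, rcond=None)
    return lambda Xnew: np.column_stack([np.ones(len(Xnew)), Xnew]) @ coef

def plug_in(V, Z, A, Y, a=1):
    """Two-stage plug-in estimate of nu(v) = E[Y^a | V = v]."""
    VZ = np.column_stack([V, Z])
    treated = (A == a)
    # Stage 1: outcome model mu(v, z), fit only on cases that received decision a.
    mu_hat = fit_linear(VZ[treated], Y[treated])
    # Pseudo-outcomes for every training case, whether or not it received a.
    pseudo = mu_hat(VZ)
    # Stage 2: regress the pseudo-outcomes on V alone.
    return fit_linear(V, pseudo)

# Synthetic data (an assumption of this sketch): Z drives both decision and outcome.
rng = np.random.default_rng(0)
n = 5000
V = rng.normal(size=(n, 1))
Z = rng.normal(size=(n, 1))
p = 1.0 / (1.0 + np.exp(-(0.5 * V[:, 0] + Z[:, 0])))
A = rng.binomial(1, p)
Y1 = 1.0 + 2.0 * V[:, 0] + 3.0 * Z[:, 0] + rng.normal(scale=0.1, size=n)
Y = A * Y1                      # outcome under the other decision is set to 0

nu_hat = plug_in(V, Z, A, Y)                     # true nu(v) = 1 + 2 v
tcr_hat = fit_linear(V[A == 1], Y[A == 1])       # ignores Z, so it is biased
v_grid = np.array([[0.0], [1.0]])
print(nu_hat(v_grid), tcr_hat(v_grid))
```

Because the outcome model here is correctly specified, the PL predictions land close to the true \(1 + 2v\), while the regression that ignores \(Z\) overshoots.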

How does the PL approach perform?

  • Yields valid counterfactual predictions under our three causal assumptions.
  • Not optimal: Consider the setting in which \(d_z \gg d_v\), for instance, in the child welfare setting where \(Z\) corresponds to the natural language in the hotline call. The PL approach requires us to efficiently estimate a more challenging high-dimensional target \(\mathbb{E}[Y \mid A = a, V = v, Z = z]\) when our target is a lower-dimensional quantity \(\nu(V)\).

We can better take advantage of the lower-dimensional structure of our target estimand using doubly-robust techniques, which are popular in causal inference because they give us two chances to get our estimation right.

Our proposed doubly-robust (DR) approach

In addition to estimating the outcome model like the PL approach, a doubly-robust approach also estimates a decision model \(\pi(v,z) := \mathbb{E}[\mathbb{I}\{A=a\} \mid V = v, Z = z]\), which is known as the propensity model in causal inference. This is particularly helpful in settings where it is easier to estimate the decision model than the outcome model.

We propose a doubly-robust (DR) approach that also involves two stages:

  1. Regress \(Y \sim V, Z \mid A = a\) to yield outcome model \(\hat{\mu}(v,z)\). Regress \(\mathbb{I}\{A=a\} \sim V, Z\) to yield decision model \(\hat{\pi}(v,z)\).
  2. Regress $$\frac{\mathbb{I}\{A=a\}}{\hat{\pi}(V,Z)}\big(Y - \hat{\mu}(V,Z)\big) + \hat{\mu}(V,Z) \sim V.$$

Figure 5. Our proposed doubly-robust (DR) learning procedure. The full set of confounders \((V, Z)\) is used to build an outcome model and a decision model. The output of the outcome and decision models and the available predictors \(V\) are used to build a prediction model.
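
A minimal sketch of the DR procedure follows. Several choices here are assumptions of the illustration rather than part of the method: least squares for the outcome and second-stage regressions, a small hand-written logistic regression for the decision model, clipping of the estimated propensities away from zero, and a synthetic dataset. For brevity it also fits everything on one sample, whereas the theoretical guarantee discussed in this post assumes sample-splitting. The names `fit_logistic` and `doubly_robust` are hypothetical.

```python
import numpy as np

def add_intercept(X):
    return np.column_stack([np.ones(len(X)), X])

def fit_linear(X, y):
    """Least squares with an intercept; returns a prediction function."""
    coef, *_ = np.linalg.lstsq(add_intercept(X), y, rcond=None)
    return lambda Xnew: add_intercept(Xnew) @ coef

def fit_logistic(X, y, steps=25):
    """Logistic regression via Newton's method; returns a probability function."""
    Xb = add_intercept(X)
    w = np.zeros(Xb.shape[1])
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-Xb @ w))
        grad = Xb.T @ (p - y)
        hess = (Xb * (p * (1.0 - p))[:, None]).T @ Xb + 1e-8 * np.eye(len(w))
        w = w - np.linalg.solve(hess, grad)
    return lambda Xnew: 1.0 / (1.0 + np.exp(-add_intercept(Xnew) @ w))

def doubly_robust(V, Z, A, Y, a=1, clip=0.01):
    """Two-stage DR estimate of nu(v) = E[Y^a | V = v]."""
    VZ = np.column_stack([V, Z])
    ind = (A == a).astype(float)
    # Stage 1: outcome model on cases with A = a, decision model on all cases.
    mu = fit_linear(VZ[A == a], Y[A == a])(VZ)
    pi = np.clip(fit_logistic(VZ, ind)(VZ), clip, 1.0)   # clipping is an assumption
    # Stage 2: regress the bias-corrected pseudo-outcome on V alone.
    pseudo = ind / pi * (Y - mu) + mu
    return fit_linear(V, pseudo)

# Synthetic data (an assumption of this sketch); true nu(v) = 1 + 2 v.
rng = np.random.default_rng(0)
n = 5000
V = rng.normal(size=(n, 1))
Z = rng.normal(size=(n, 1))
A = rng.binomial(1, 1.0 / (1.0 + np.exp(-(0.5 * V[:, 0] + Z[:, 0]))))
Y = A * (1.0 + 2.0 * V[:, 0] + 3.0 * Z[:, 0] + rng.normal(scale=0.1, size=n))

nu_hat = doubly_robust(V, Z, A, Y)
print(nu_hat(np.array([[0.0], [1.0]])))
```

The correction term only uses outcomes from cases with \(A = a\), reweighted by the inverse estimated propensity, which is what gives the estimator its second chance when the outcome model is poor.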

When does the DR approach perform well?

  • When we can build either a very good outcome model or a very good decision model
  • If both the decision model and outcome model are somewhat good

The DR approach can achieve oracle optimality; that is, it achieves the same regression error (up to constants) as an oracle with access to the true potential outcomes \(Y^a\).

We can see this by bounding the error of our method \(\hat{\nu}\) with the sum of the oracle error and a product of error terms on the outcome and decision models:

$$\begin{aligned}
\mathbb{E}[(\hat{\nu}(v) - \nu(v))^2] \lesssim\; & \mathbb{E}[(\tilde{\nu}(v) - \nu(v))^2] \; + \\
& \mathbb{E}[(\hat{\pi}(V,Z) - \pi(V,Z))^2 \mid V = v]\; \mathbb{E}[(\hat{\mu}(V,Z) - \mu(V,Z))^2 \mid V = v],
\end{aligned}$$

where \(\tilde{\nu}(v)\) denotes the function we would get in our second-stage estimation if we had oracle access to \(Y^a\).

So as long as we can estimate the outcome and decision models such that their product of errors is smaller than the oracle error, then the DR approach is oracle-efficient. This result holds for any regression method, assuming that we have used sample-splitting to learn \(\hat{\nu}\), \(\hat{\mu}\), and \(\hat{\pi}\).

While the DR approach has this desirable theoretical guarantee, in practice it is possible that the PL approach may perform better depending on the dimensionality of the problem.

How do I know which method I should use?

To determine which method will work best in a given setting, we provide an evaluation procedure that can be applied to any prediction method to estimate its mean-squared error. Under our three causal assumptions, the prediction error of a model \(\hat{\nu}\) is identified as

$$\mathbb{E}[(Y^a - \hat{\nu}(V))^2] = \mathbb{E}[\, \mathbb{E}[(Y - \hat{\nu}(V))^2 \mid V, Z, A = a]].$$

Defining the error regression \(\eta(v,z) = \mathbb{E}[(Y - \hat{\nu}(V))^2 \mid V = v, Z = z, A = a]\), we propose the following doubly-robust estimator for the MSE on a validation sample of \(n\) cases:

$$\frac{1}{n} \sum_{i=1}^n \left[ \frac{\mathbb{I}\{A_i = a\}}{\hat{\pi}(V_i, Z_i)} \left( (Y_i - \hat{\nu}(V_i))^2 - \hat{\eta}(V_i, Z_i) \right) + \hat{\eta}(V_i, Z_i) \right].$$

Under mild assumptions, this estimator is \(\sqrt{n}\)-consistent, enabling us to get error estimates with confidence intervals.
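
As a rough sketch, the MSE estimator can be computed directly from arrays of fitted values on the validation sample. The function name `dr_mse` is hypothetical, the nuisance estimates \(\hat{\pi}\) and \(\hat{\eta}\) are assumed to have been fit elsewhere, and the normal-approximation 95% interval is an assumption of this example rather than the paper's exact construction.

```python
import numpy as np

def dr_mse(Y, A, nu_pred, pi_hat, eta_hat, a=1):
    """Doubly-robust MSE estimate with a normal-approximation 95% interval.

    Y, A      : observed outcomes and decisions on the validation sample
    nu_pred   : predictions nu_hat(V_i) of the model being evaluated
    pi_hat    : estimated probabilities of decision a given (V_i, Z_i)
    eta_hat   : estimated error regression eta(V_i, Z_i)
    """
    ind = (A == a).astype(float)
    sq_err = (Y - nu_pred) ** 2
    # One term per validation case, exactly as in the displayed estimator.
    phi = ind / pi_hat * (sq_err - eta_hat) + eta_hat
    est = phi.mean()
    se = phi.std(ddof=1) / np.sqrt(len(phi))
    return est, (est - 1.96 * se, est + 1.96 * se)

# Sanity check: if every case received decision a with probability 1,
# the estimator reduces to the ordinary mean squared error.
Y = np.array([1.0, 2.0, 3.0, 4.0])
A = np.array([1, 1, 1, 1])
nu_pred = np.array([1.5, 1.5, 3.5, 3.5])
est, (lo, hi) = dr_mse(Y, A, nu_pred, pi_hat=np.ones(4), eta_hat=np.full(4, 0.1))
print(est, np.mean((Y - nu_pred) ** 2))
```

Running this for each candidate model on the same validation sample gives comparable error estimates for choosing among them.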

DR achieves lowest MSE in synthetic experiments

We perform simulations on synthetic data to show how the level of confounding and dimensionalities of \(V\) and \(Z\) determine which method performs best. Synthetic experiments enable us to evaluate the methods on the ground-truth counterfactual outcomes. We compare the PL and DR approaches to a biased single-stage approach that estimates \(\mathbb{E}[Y \mid V, A = a]\), which we refer to as the treatment-conditional regression (TCR) approach.

MSE of the plug-in (PL), doubly-robust (DR), and treatment-conditional regression (TCR) approaches to counterfactual prediction under runtime confounding as we vary the level of confounding \(k_z\) in the left-hand panel and as we vary \(d_v\), the dimensionality of our predictors \(V\), in the right-hand panel.

In the left-hand panel above, we compare the methods as we vary the amount of confounding. When there is no confounding (\(k_z = 0\)), the TCR approach performs best as expected. Under no confounding, the TCR approach is no longer biased and efficiently estimates the target of interest in one stage. However, as we increase the level of confounding, the TCR performance degrades faster than the PL and DR methods. The DR method performs best under any non-zero level of confounding.

The right-hand panel compares the methods as we vary the dimensionality of our predictors. We hold the total dimensionality of \((V, Z)\) fixed at \(500\) (so \(d_z = 500 - d_v\)). The DR approach performs best across the board, and the TCR approach performs well when the dimensionality is low because TCR avoids the high-dimensional second stage regression. However, this advantage disappears as \(d_v\) increases. The gap between the PL and DR methods is largest for low \(d_v\) because the DR method is able to take advantage of the lower dimensional target. At high \(d_v\) the PL error approaches the DR error.

DR is comparable to PL in a real-world task

We compare the methods on a real-world child welfare screening task where the goal is to predict the likelihood that a case will require services under the decision “screened in for investigation” using historical information as predictors and controlling for confounders that are sensitive (race) and hard to process (the allegations in the call). Our dataset consists of over 30,000 calls to the child welfare hotline in Allegheny County, PA. We evaluate the methods using our proposed real-world evaluation procedure since we do not have access to the ground-truth outcomes for cases that were not screened in for investigation.

Child welfare screening task: estimated MSE. The PL and DR methods achieve lower MSE than the TCR approach. Parentheses denote 95% confidence intervals.

We find that the DR and PL approaches perform comparably on this task, both outperforming the TCR method.


  • Runtime confounding arises when it is undesirable or impermissible to use some confounding factors in the prediction model.
  • We propose a generic procedure to build counterfactual predictions when the factors are available in offline training data.
  • In theory, our approach is provably efficient in the oracle sense.
  • In practice, we recommend building the DR, PL, and TCR approaches and using our proposed evaluation scheme to choose the best performing model.
  • Our full paper is available in the Proceedings of NeurIPS 2020.


Tilted Empirical Risk Minimization


Figure 1. A toy linear regression example illustrating Tilted Empirical Risk Minimization (TERM) as a function of the tilt hyperparameter \(t\). Classical ERM (\(t=0\)) minimizes the average loss and is shown in pink. As \(t \to -\infty\) (blue), TERM finds a line of best fit while ignoring outliers. In some applications, these ‘outliers’ may correspond to minority samples that should not be ignored. As \(t \to +\infty\) (red), TERM recovers the min-max solution, which minimizes the worst loss. This can ensure the model is a reasonable fit for all samples, reducing unfairness related to representation disparity.

In machine learning, models are commonly estimated via empirical risk minimization (ERM), a principle that considers minimizing the average empirical loss on observed data. Unfortunately, minimizing the average loss alone in ERM has known drawbacks—potentially resulting in models that are susceptible to outliers, unfair to subgroups in the data, or brittle to shifts in distribution. Previous works have thus proposed numerous bespoke solutions for these specific problems.

In contrast, in this post, we describe our work in tilted empirical risk minimization (TERM), which provides a unified view on the deficiencies of ERM (Figure 1). TERM considers a modification to ERM that can be used for diverse applications such as enforcing fairness between subgroups, mitigating the effect of outliers, and addressing class imbalance—all in one unified framework.

What is Tilted ERM (TERM)?

Empirical risk minimization is a popular technique for statistical estimation where the model, \(\theta \in \mathbb{R}^d\), is estimated by minimizing the average empirical loss over data, \(\{x_1, \dots, x_N\}\):

$$\overline{R}(\theta) := \frac{1}{N} \sum_{i \in [N]} f(x_i; \theta).$$

Despite its popularity, ERM is known to perform poorly in situations where average performance is not an appropriate surrogate for the problem of interest. In our work (ICLR 2021), we aim to address deficiencies of ERM through a simple, unified framework: tilted empirical risk minimization (TERM). TERM encompasses a family of objectives, parameterized by the hyperparameter \(t\):

$$\widetilde{R}(t; \theta) := \frac{1}{t} \log\left(\frac{1}{N} \sum_{i \in [N]} e^{t f(x_i; \theta)}\right).$$

TERM recovers ERM when \(t \to 0\). It also recovers other popular alternatives such as the max-loss (\(t \to +\infty\)) and min-loss (\(t \to -\infty\)). While the tilted objective used in TERM is not new and is commonly used in other domains,1 it has not seen widespread use in machine learning.
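
To make the limits concrete, here is a small numerical sketch of the tilted objective evaluated on a fixed vector of per-sample losses. The function name `tilted_risk` and the loss values are made up for illustration; the max-shift inside the function is the usual log-sum-exp trick for numerical stability and does not change the value of the objective.

```python
import numpy as np

def tilted_risk(losses, t):
    """(1/t) * log(mean(exp(t * losses))), with t = 0 defined as the mean."""
    losses = np.asarray(losses, dtype=float)
    if t == 0:
        return float(losses.mean())
    z = t * losses
    m = z.max()                       # shift to avoid overflow in exp
    return float((m + np.log(np.mean(np.exp(z - m)))) / t)

losses = [0.1, 0.5, 4.0]              # hypothetical per-sample losses
for t in [-100.0, -1.0, 0.0, 1.0, 100.0]:
    print(t, tilted_risk(losses, t))
```

For large negative \(t\) the value approaches the smallest loss, at \(t = 0\) it is the average, and for large positive \(t\) it approaches the largest loss, with the value increasing monotonically in \(t\) in between.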

In our work, we investigate tilting by: (i) rigorously studying properties of the TERM objective, and (ii) exploring its utility for a number of ML applications. Surprisingly, we find that this simple and general extension to ERM is competitive with state-of-the-art, problem-specific solutions for a wide range of problems in ML. 

TERM: Properties and Interpretations

Given the modifications that TERM makes to ERM, the first question we ask is: What happens to the TERM objective when we vary (t)? Below we explore properties of TERM with varying (t) to better understand the potential benefits of (t)-tilted losses. In particular, we find:

Figure 2. We present a toy problem where there are three samples with individual losses: \(f_1\), \(f_2\), and \(f_3\). ERM will minimize the average of the three losses, while TERM aggregates them via exponential tilting parameterized by a family of \(t\)’s. As \(t\) moves from \(-\infty\) to \(+\infty\), \(t\)-tilted losses smoothly move from min-loss to avg-loss to max-loss. The colored dots are optimal solutions of TERM for \(t \in (-\infty, +\infty)\). TERM is smooth for all finite \(t\) and convex for positive \(t\).
  • TERM with varying (t)’s reweights samples to magnify/suppress outliers (as in Figure 1).
  • TERM smoothly moves between traditional ERM (pink line), the max-loss (red line), and min-loss (blue line), and can be used to trade-off between these problems. 
  • TERM approximates a popular family of quantile losses (such as median loss, shown in the orange line) with different tilting parameters. Quantile losses have nice properties but can be hard to directly optimize.

Next, we discuss these properties and interpretations in more detail.

Varying (t) reweights the importance of outlier samples

First, we take a closer look at the gradient of the \(t\)-tilted loss, and observe that the gradients of the \(t\)-tilted objective \(\widetilde{R}(t; \theta)\) are of the form:

$$\nabla_{\theta} \widetilde{R}(t; \theta) = \sum_{i \in [N]} w_i(t; \theta) \nabla_{\theta} f(x_i; \theta), \quad \text{where } w_i \propto e^{t f(x_i; \theta)}.$$

This indicates that the tilted gradient is a weighted average of the gradients of the original individual losses, and the weights are exponentially proportional to the loss values. As illustrated in Figure 1, for positive values of \(t\), TERM will thus magnify outliers (samples with large losses), and for negative \(t\)’s, it will suppress outliers by downweighting them.
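
The reweighting can be seen in a toy problem. The sketch below is an illustration, not the solver from the paper: it estimates a single location parameter under squared loss by plain gradient descent on the tilted objective, with weights normalized to sum to one. The data, the step size, the median initialization, and the names `tilt_weights` and `fit_location` are all assumptions of the example.

```python
import numpy as np

def tilt_weights(losses, t):
    """Normalized weights w_i proportional to exp(t * loss_i)."""
    z = t * np.asarray(losses, dtype=float)
    w = np.exp(z - z.max())           # shift for numerical stability
    return w / w.sum()

def fit_location(x, t, lr=0.1, steps=500):
    """Gradient descent on the tilted squared loss f(x_i; theta) = (x_i - theta)^2."""
    theta = float(np.median(x))       # initialization is a choice of this sketch
    for _ in range(steps):
        losses = (x - theta) ** 2
        grads = 2.0 * (theta - x)     # per-sample gradients of the squared loss
        theta -= lr * np.sum(tilt_weights(losses, t) * grads)
    return theta

x = np.array([0.0, 0.1, -0.1, 10.0])  # three inliers and one outlier
theta_erm = fit_location(x, t=0.0)     # uniform weights: the sample mean
theta_robust = fit_location(x, t=-1.0) # outlier receives almost no weight
print(theta_erm, theta_robust)
```

With \(t = 0\) the weights are uniform and the estimate is the mean, which the outlier drags upward; with a negative \(t\) the outlier's large loss gives it a vanishing weight and the estimate stays near the inliers.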

TERM offers a trade-off between the average loss and min-/max-loss

Another perspective on TERM is that it offers a continuum of solutions between the min and max losses. As \(t\) goes from 0 to \(+\infty\), the average loss will increase, and the max-loss will decrease (going from the pink star to the red star in Figure 2), smoothly trading average-loss for max-loss. Similarly, for \(t<0\), the solutions achieve a smooth tradeoff between average-loss and min-loss. Additionally, as \(t>0\) increases, the empirical variance of the losses across all samples also decreases. In the applications below we will see how these properties can be used in practice.

TERM solutions approximate superquantile methods

Finally, we note a connection to another popular variant on ERM: superquantile methods. The (k)-th quantile loss is defined as the (k)-th largest individual loss among all samples, which may be useful for many applications. For example, optimizing for the median loss instead of mean may be desirable for applications in robustness, and the max-loss is an extreme of the quantile loss which can be used to enforce fairness. However, minimizing such objectives can be challenging especially in large-scale settings, as they are non-smooth (and generally non-convex). The TERM objective offers an upper bound on the given quantile of the losses, and the solutions of TERM can provide close approximations to the solutions of the quantile loss optimization problem.

[Note] All discussions above assume that the loss functions belong to generalized linear models. However, we empirically observe competitive performance when applying TERM to broader classes of objectives, including deep neural networks. Please see our paper for full statements and proofs.

TERM Applications

TERM is a general framework applicable to a variety of real-world machine learning problems. Using our understanding of TERM from the previous section, we can consider ‘tilting’ at different hierarchies of the data to adjust to the problem of interest. For instance, one can tilt at the sample level, as in the linear regression toy example. It is also natural to perform tilting at the group level to upweight underrepresented groups. Further, we can tilt at multiple levels to address practical applications requiring multiple objectives. In our work, we consider applying TERM to the following problems:

  • [negative (t)’s]: robust regression, robust classification, mitigating noisy annotators
  • [positive (t)’s]: handling class imbalance, fair PCA, variance reduction for generalization
  • [hierarchical tilting with (t_1<0, t_2>0)]: jointly addressing robustness and fairness

For all applications considered, we find that TERM is competitive with or outperforms state-of-the-art, problem-specific tailored baselines. In this post, we discuss three examples on robust classification with (t<0), fair PCA with (t>0), and hierarchical tilting. 

Robust classification 

Crowdsourcing is a popular technique for obtaining data labels from a large crowd of annotators. However, the quality of annotators varies significantly as annotators may be unskilled or even malicious. Thus, handling a large amount of noise is essential for the crowdsourcing setting. Here we consider applying TERM ((t<0)) to the application of mitigating noisy annotators.

Specifically, we explore a common benchmark—taking the CIFAR10 dataset and simulating 100 annotators where 20 of them are always correct and 80 of them assign labels uniformly at random. We use negative (t)’s for annotator-level tilting, which is equivalent to assigning annotator-level weights based on the aggregate value of their loss.

Figure 3 demonstrates that our approach performs on par with the oracle method that knows the qualities of annotators in advance. Additionally, we find that the accuracy of TERM alone is 5% higher than that reported by previous approaches which are specifically designed for this problem.

Figure 3. TERM removes the impact of noisy annotators.

Fair principal component analysis (PCA)

While the previous application explored TERM with (t<0), here we consider an application of TERM with positive (t)’s to fair PCA. PCA is commonly used for dimension reduction while preserving useful information of the original data for downstream tasks. The goal of fair PCA  is to learn a projection that achieves similar (or the same) reconstruction errors across subgroups. 

Applying standard PCA can be unfair to underrepresented groups. Figure 4 demonstrates that the classical PCA algorithm results in a large gap in the representation quality between two groups (G1 and G2).

To promote fairness, previous methods have proposed to solve a min-max problem via semidefinite programming, which scales poorly with the problem dimension. We apply TERM to this problem, reweighting the gradients based on the loss on each group. We see that TERM with a large (t) can recover the min-max results where the resulting losses on two groups are almost identical. In addition, with moderate values of (t), TERM offers more flexible tradeoffs between performance and fairness by reducing the performance gap less aggressively. 

Figure 4. TERM applied to PCA recovers the min-max fair solution with a large t, while offering more flexibility to trade performance on Group 1 (G1) for performance on Group 2 (G2).

Solving compound issues: multi-objective tilting

Finally, we note that in practice, multiple issues can co-exist in the data, e.g., we may have issues with both class imbalance and label noise. In these settings we can adopt hierarchical TERM as described previously to address compound issues. Depending on the application, one can choose whether to apply tilting at each level (e.g., possibly more than two levels of hierarchies exist in the data), and at either direction ((t>0) or (t<0)). For example, we can perform negative tilting at the sample level within each group to mitigate outlier samples, and perform positive tilting across all groups to promote fairness.
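
One way to picture two-level tilting is to tilt per-sample losses within each group first and then tilt the resulting group values. The sketch below evaluates such a nested objective on fixed losses. Weighting groups by their size in the outer tilt, the names `tilted` and `hierarchical_tilted`, and the loss values are assumptions of this illustration; the paper gives the precise objective used in the experiments.

```python
import numpy as np

def tilted(values, t, weights=None):
    """(1/t) * log(sum_i w_i * exp(t * v_i)); t = 0 gives the weighted mean."""
    values = np.asarray(values, dtype=float)
    if weights is None:
        weights = np.ones(len(values))
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    if t == 0:
        return float(np.sum(weights * values))
    z = t * values
    m = z.max()                        # shift for numerical stability
    return float((m + np.log(np.sum(weights * np.exp(z - m)))) / t)

def hierarchical_tilted(losses_by_group, tau, t):
    """Sample-level tilt tau inside each group, then group-level tilt t."""
    inner = [tilted(g, tau) for g in losses_by_group]
    sizes = [len(g) for g in losses_by_group]
    return tilted(inner, t, weights=sizes)

# Hypothetical losses: group 0 contains one outlier sample (9.0).
groups = [[0.2, 0.3, 9.0], [1.0, 1.2]]
plain = hierarchical_tilted(groups, tau=0.0, t=0.0)        # overall average loss
robust_fair = hierarchical_tilted(groups, tau=-50.0, t=50.0)
print(plain, robust_fair)
```

With both tilts at zero the value is the plain average over all samples. With a negative inner tilt and a positive outer tilt, the outlier in the first group is suppressed and the objective is dominated by the group whose typical loss is highest.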

We test this protocol on the HIV-1 data with logistic regression. In Table 1 below, we find that TERM is superior to all baselines which perform well in their respective problem settings (only showing a subset here) when considering noisy samples and class imbalance simultaneously. 

Table 1. Hierarchical TERM can address both class imbalance and noisy samples

More broadly, the idea of tilting can be applied to other learning problems, like GAN training, meta-learning, and improving calibration and generalization for deep learning. We encourage interested readers to view our paper, which explores a more comprehensive set of applications.

[Solving TERM] Wondering how to solve TERM? In our work, we discuss the properties of the objective in terms of its smoothness and convexity behavior. Based on that, we develop both batch and (scalable) stochastic solvers for TERM, where the computation cost is within 2\(\times\) of standard ERM solvers. We describe these algorithms as well as their convergence guarantees in our paper.


Our work explores tilted empirical risk minimization (TERM), a simple and general alternative to ERM, which is ubiquitous throughout machine learning. Our hope is that the TERM framework will allow machine learning practitioners to easily modify the ERM objective to handle practical concerns such as enforcing fairness amongst subgroups, mitigating the effect of outliers, and ensuring robust performance on new, unseen data. Critical to the success of such a framework is understanding the implications of the modified objective (i.e., the impact of varying (t)), both theoretically and empirically. Our work rigorously explores these effects—demonstrating the potential benefits of tilted objectives across a wide range of applications in machine learning.

Interested in learning more?


Thanks to Maruan Al-Shedivat, Ahmad Beirami, Virginia Smith, and Ivan Stelmakh for feedback on this blog post.


1 For instance, this type of exponential smoothing (when \(t>0\)) is commonly used to approximate the max. Variants of tilting have also appeared in other contexts, including importance sampling, decision making, and large deviation theory.


Carnegie Mellon University at the Conference on Fairness, Accountability, and Transparency (ACM FAccT 2021)


This week researchers from all across computer science, social sciences and humanities are gathering for the flagship conference of the emerging field of Fairness, Accountability and Transparency in algorithmic systems: FAccT. FAccT (previously FAT*) is dedicated to the inherent risks that come with the increasing adoption of data-driven algorithmic decision making systems in socially consequential domains such as policing, criminal justice, health care and education. The conference was formed as a venue for the increasing volume of work in this area in 2018 and has since become one of the top venues in the study of societal impacts of machine learning – submissions have more than quadrupled since the inaugural conference!

Number of submitted and accepted papers at FAccT since inaugural conference.

Now in its 4th year, the fully-virtual event spans 82 paper presentations from 15 different countries across 14 time zones as well as 13 tutorials, a doctoral consortium and 10 CRAFT sessions aimed at Critiquing and Rethinking Accountability, Fairness and Transparency. Complementing paper presentations and tutorials, the CRAFT sessions aim for interaction between participants with different backgrounds including academics, journalists, advocates, activists, educators and artists with the idea of reflection and discussion of the field from a more holistic perspective.

Many influential papers have been published at FAccT even within these first few years of the conference. Examples include Joy Buolamwini and Timnit Gebru’s 2018 study Gender Shades, in which the authors uncover significantly higher error rates in commercial gender classification for darker-skinned females. The study led companies to adjust their algorithms and sparked a wider discussion of similar problems in computer vision. Leading up to this year’s conference, the paper ‘On the Dangers of Stochastic Parrots: Can Language Models Be Too Big?’, coming out of Google and the University of Washington, has gotten much attention in the wider field: it led to the firing of Timnit Gebru as co-lead of Google’s Ethical AI team, leaving room for speculation and sparking a discussion on the future of AI ethics research in private companies.

As one of the main contributing institutions, Carnegie Mellon University is proud to present 10 papers and one tutorial at this year’s conference. Contributions are made from all across campus with authors from the Machine Learning Department, the Department of Statistics and Data Science, the Institute for Software Research, the Computer Science Department, Heinz College of Public Policy, and the Philosophy Department. Several of the studies focus on auditing existing systems in the context of predictive policing [4], image representations learned in an unsupervised manner [5], or the use of mobility data for Covid-19 policy [6]. Others propose new algorithmic solutions to analyze the allocation of opportunities for intergenerational mobility [1], post-process predictions in risk assessment [2], examine the equity of cash bail decisions [3], or understand the fairness implications of leave-one-out training data [8]. The authors of [7] focus on disparity amplification avoidance under different world views and fairness notions, while [9] introduce Value Cards, an educational toolkit for teaching the societal impacts of machine learning. Finally, the authors of [10] provide counternarratives on data sharing in Africa using a storytelling approach based on a series of interviews. We give a short description of each of the papers along with the session times at the conference and links to the preprints below.


[1] Allocating Opportunities in a Dynamic Model of Intergenerational Mobility
Hoda Heidari (Carnegie Mellon University), Jon Kleinberg (Cornell University)
Session: March 8, 22:00 – 23:45 UTC 
Tags: Algorithm Development, Fairness
Summary: The authors develop a model for analyzing the allocation of opportunities for intergenerational mobility such as higher education and find that purely payoff-maximizing objectives can still lead to a form of socioeconomic affirmative action in the optimal allocation.

[2] Fairness in Risk Assessment Instruments: Post-Processing to Achieve Counterfactual Equalized Odds
Alan Mishler (Carnegie Mellon University), Edward Kennedy (Carnegie Mellon University), Alexandra Chouldechova (Carnegie Mellon University)
Session: March 10, 20:00 – 21:30 UTC
Tags: Algorithm Development, Causality, Evaluation, Fairness
Summary: The authors develop a method to post-process existing binary predictors used in risk assessment, e.g. for recidivism prediction, to satisfy approximate counterfactual equalized odds. They discuss the convergence rate to an optimal fair predictor and propose doubly robust estimation of the risk and fairness properties of a fixed post-processed predictor.

[3] A Bayesian Model of Cash Bail Decisions
Joshua Williams (Carnegie Mellon University), Zico Kolter (Carnegie Mellon University)
Session: March 8, 20:00 – 21:45 UTC
Tags: Algorithm Development, Data, Fairness, Law & Policy
Summary: The authors create a hierarchical Bayesian model of cash bail assignments to analyze fairness between racial groups while overcoming the problem of infra-marginality. Results on 50 judges uniformly show that they are more likely to assign cash bail to black defendants than to white defendants given the same likelihood of skipping a court appearance.

[4] The effect of differential victim crime reporting on predictive policing systems
Nil-Jana Akpinar (Carnegie Mellon University), Maria De-Arteaga (University of Texas at Austin), Alexandra Chouldechova (Carnegie Mellon University)
Session: March 8, 20:00 – 21:45 UTC
Tags: Auditing, Data, Evaluation
Summary: The authors audit place-based predictive policing algorithms trained on victim crime reporting data and find that geographical bias arises when victim crime reporting rates vary within a city. This result requires no use of arrest data or data from police initiated contact.

[5] Image Representations Learned With Unsupervised Pre-Training Contain Human-like Biases
Ryan Steed (Carnegie Mellon University), Aylin Caliskan (George Washington University)
Session: March 8, 12:00 – 13:45 UTC
Tags: Computer Vision, Data, Evaluation, Fairness, Humanistic Theory & Critique
Summary: The authors develop a method for quantifying biased associations between representations of social concepts and attributes in images using image representations learned in an unsupervised manner. The results closely match hypotheses about intersectional bias from social psychology and suggest that machine learning models can automatically learn bias from the way people are stereotypically portrayed on the web.

[6] Leveraging Administrative Data for Bias Audits: Assessing Disparate Coverage with Mobility Data for COVID-19 Policy
Amanda Coston (Carnegie Mellon University), Neel Guha (Stanford University), Derek Ouyang (Stanford University), Lisa Lu (Stanford University), Alexandra Chouldechova (Carnegie Mellon University), Daniel E. Ho (Stanford University)
Session: March 9, 14:00 – 15:45 UTC
Tags: Auditing, Data, Evaluation
Summary: The authors audit the use of smartphone-based mobility data for COVID-19 policy by leveraging administrative voter roll data in the absence of demographic information. Their results suggest that older and non-white voters are less likely to be captured by mobility data, which can disproportionately harm these groups if allocation of public health resources is based on such data sets.

[7] Avoiding Disparity Amplification under Different Worldviews
Samuel Yeom (Carnegie Mellon University), Michael Carl Tschantz (International Computer Science Institute)
Session: March 10, 14:00 – 15:45 UTC
Tags: Data, Evaluation, Metrics
Summary: The authors mathematically compare competing definitions of group-level fairness and their properties under various worldviews, i.e., assumptions about how, if at all, the observed data is biased. They discuss the criterion of disparity amplification and introduce a new worldview with a corresponding notion of fairness as a more realistic perspective.

[8] Leave-one-out Unfairness
Emily Black (Carnegie Mellon University), Matt Fredrikson (Carnegie Mellon University)
Session: March 9, 22:00 – 23:45 UTC
Tags: Algorithm Development, Data, Evaluation, Fairness, Metrics 
Summary: The authors introduce leave-one-out unfairness which focuses on the change of prediction for an individual due to inclusion or exclusion of a single other individual from the training data. They discuss the relation of this concept to robustness, memorization and individual fairness in deep models.

[9] Value Cards: An Educational Toolkit for Teaching Social Impacts of Machine Learning through Deliberation
Hong Shen (Carnegie Mellon University), Wesley Deng (UC Berkeley), Aditi Chattopadhyay (Carnegie Mellon University), Steven Wu (Carnegie Mellon University), Xu Wang (University of Michigan), Haiyi Zhu (Carnegie Mellon University)
Session: March 8, 22:00 – 23:45 UTC
Tags: Accountability, Education, Human Factors
Summary: The authors introduce Value Cards, an educational toolkit with topics related to Fairness, Accountability, and Ethics, and present an early use of the approach in a college-level computer science course. Results suggest that the use of the toolkit can improve students’ understanding of both technical definitions and trade-offs of performance metrics and apply them in real-world contexts.

[10] Narratives and Counternarratives on Data Sharing in Africa
Rediet Abebe (UC Berkeley), Kehinde Aruleba (University of Witwatersrand), Abeba Birhane (University College Dublin), Sara Kingsley (Carnegie Mellon University), George Obaido (University of Witwatersrand), Sekou L. Remy (IBM Research Africa), Swathi Sadagopan (Deloitte)
Session: March 9, 12:00 – 13:50 UTC
Tags: Data, Ethics, Humanistic Theory & Critique 
Summary: The authors use storytelling via fictional personas built from a series of interviews with African data experts to complicate dominant narratives and provide counternarratives on data sharing in Africa. They discuss issues arising from power imbalances and Western-centric policies in the context of open data initiatives centered around data extracted from African communities and discuss avenues for addressing these issues.


Sociocultural diversity in machine learning: Lessons from philosophy, psychology, and organizational science
Sina Fazelpour (Carnegie Mellon University) and Maria De-Arteaga (University of Texas at Austin)
Session: March 4, 14:00 – 15:30 UTC
Summary: The current discussion of sociocultural diversity in machine learning research leaves a gap between the conversation about measures and benefits and the philosophical, psychological and organizational research on the underlying concepts. This tutorial addresses the concepts and consequences of sociocultural diversity and situates this understanding and its implications for the discussion of sociocultural diversity in machine learning.

An Inferential Perspective on Federated Learning

TL;DR: motivated to better understand the fundamental tradeoffs in federated learning, we present a probabilistic perspective that generalizes and improves upon federated optimization and enables a new class of efficient federated learning algorithms.

Thanks to deep learning, today we can train better machine learning models when given access to massive data. However, the standard, centralized training is impossible in many interesting use-cases—due to the associated data transfer and maintenance costs (most notably in video analytics), privacy concerns (e.g., in healthcare settings), or sensitivity of the proprietary data (e.g., in drug discovery). And yet, different parties that own even a small amount of data want to benefit from access to accurate models. This is where federated learning comes to the rescue!

Broadly, federated learning (FL) allows multiple data owners (or clients) to train shared models collaboratively under the orchestration of a central server without having to share any data. Typically, FL proceeds in multiple rounds of communication between the server and the clients: the clients compute model updates on their local data and send them to the server which aggregates and applies these updates to the shared model. While gaining popularity very quickly, FL is a relatively new subfield with many open questions and unresolved challenges.

Here is one interesting conundrum driving our work:

Client-server communication is often too slow and expensive. To speed up training (often by 10-100x), we can make clients spend more time at each round on local training (e.g., do more local SGD steps), thereby reducing the total number of communication rounds. However, because of client data heterogeneity (natural in practice), it turns out that increasing the amount of local computation per round results in convergence to inferior models!

This phenomenon is illustrated below in Figure 1 on a toy convex problem, where we see that more local steps lead the classical federated averaging (FedAvg) algorithm to converge to points that are much further away from the global optimum. But why does this happen?

Figure 1: A toy 2D setting with two clients and quadratic objectives that illustrates the convergence issues of FedAvg. Left: convergence trajectories in the parameter space. Right: convergence in terms of distance from the global optimum. Each drawing of the plot corresponds to a run of federated optimization from a different starting point in the parameter space. More local SGD steps per round speed up training, but the progress eventually stagnates at an inferior point further away from the global optimum.

In this post, we will present a probabilistic perspective on federated learning that will help us better understand this phenomenon and design new FL algorithms that can utilize local computation much more efficiently, converging faster, to better optima.

The classical approach: FL as a distributed optimization problem

Federated learning was originally introduced as a new setting for distributed optimization with a few distinctive properties such as a massive number of distributed nodes (or clients), slow and expensive communication, and unbalanced and non-IID data scattered across the nodes. The main goal of FL is to approximate centralized training (the gold-standard) and converge to the same optimum as the centralized optimization would have, at the fastest rate possible.

Mathematically, FL is formulated as minimization of a linear combination of local objectives, \(f_i\): $$\min_{\theta \in \mathbb{R}^d} \left\{F(\theta) := \sum_{i=1}^N q_i f_i(\theta) \right\}$$ where the weights \(q_i\) are usually set proportional to the sizes \(n_i\) of the local datasets to make \(F(\theta)\) match the centralized training objective. So, how can we solve this optimization problem within a minimal number of communication rounds?
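To make the role of the weights concrete, here is a minimal sketch, assuming each local objective is the mean squared error on that client's data (the dataset sizes, the random data, and all function names are hypothetical). With weights proportional to the local dataset sizes, the federated objective coincides with the loss on the pooled data.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
sizes = [5, 20, 50]  # hypothetical local dataset sizes n_i
# Each client holds (X_i, y_i); rows of X_i are examples (an assumption).
clients = [(rng.normal(size=(n, d)), rng.normal(size=n)) for n in sizes]

def local_objective(theta, X, y):
    """f_i(theta): half the mean squared error on one dataset."""
    return 0.5 * np.mean((X @ theta - y) ** 2)

def federated_objective(theta):
    """F(theta) = sum_i q_i f_i(theta) with q_i = n_i / sum_j n_j."""
    q = np.array(sizes) / sum(sizes)
    return sum(q_i * local_objective(theta, X, y)
               for q_i, (X, y) in zip(q, clients))

def centralized_objective(theta):
    """The same loss evaluated on the pooled data."""
    X = np.vstack([X for X, _ in clients])
    y = np.concatenate([y for _, y in clients])
    return local_objective(theta, X, y)

theta = rng.normal(size=d)
print(federated_objective(theta), centralized_objective(theta))
```

The two printed values should agree up to floating-point error; with any other choice of weights they generally would not.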

The trick is simple: at each round \(t\), instead of asking clients to estimate and send gradients of their local objective functions (as done in conventional distributed optimization), let them optimize their objectives for multiple steps (or even epochs!) to obtain \(\theta^t_i\) and send differences (or “deltas”) between the initial state \(\theta^t\) and the updated state \(\theta^t_i\) to the server as pseudo-gradients, which the server then averages, scales by a learning rate \(\alpha_t\), and uses to update the model state: $$\theta^{t+1} = \theta^t - \alpha_t \sum_{i=1}^N q_i \Delta_i^t, \quad \text{where } \Delta_i^t := \theta^t - \theta_i^t$$ This approach, known as FedAvg or local SGD, allows clients to make more progress at each round. And since taking additional SGD steps locally is orders of magnitude faster than communicating with the server, the method converges much faster both in the number of rounds and in wall-clock time.
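The following is a minimal sketch of this procedure on a toy problem, not the setup used for the figures in this post. It assumes two hypothetical clients with quadratic objectives, full-batch gradient descent in place of SGD, and deltas treated as pseudo-gradients (so the server subtracts them). All constants and function names are illustrative.

```python
import numpy as np

# Two hypothetical clients with quadratic objectives
# f_i(theta) = 0.5 * (theta - mu_i)^T A_i (theta - mu_i).
A = [np.diag([1.0, 0.1]), np.diag([0.1, 1.0])]
mu = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
q = [0.5, 0.5]

# Global optimum of F = sum_i q_i f_i, available in closed form here.
theta_star = np.linalg.solve(
    sum(qi * Ai for qi, Ai in zip(q, A)),
    sum(qi * Ai @ mi for qi, Ai, mi in zip(q, A, mu)),
)

def client_delta(theta, A_i, mu_i, local_steps, lr=0.1):
    """Run local gradient descent; return Delta_i = theta^t - theta_i^t."""
    theta_i = theta.copy()
    for _ in range(local_steps):
        theta_i -= lr * A_i @ (theta_i - mu_i)  # gradient step on f_i
    return theta - theta_i

def fedavg(local_steps, rounds=500, server_lr=1.0):
    """Server loop: average the client deltas and apply them."""
    theta = np.zeros(2)
    for _ in range(rounds):
        deltas = [client_delta(theta, A_i, mu_i, local_steps)
                  for A_i, mu_i in zip(A, mu)]
        theta = theta - server_lr * sum(qi * di for qi, di in zip(q, deltas))
    return theta

for steps in (1, 10, 100):
    err = np.linalg.norm(fedavg(steps) - theta_star)
    print(f"local steps = {steps:3d}, distance to optimum = {err:.4f}")
```

In this toy setting, a single local step recovers the global optimum, while more local steps leave the iterates at a point further away from it, which is the behavior discussed next.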

The problem (a.k.a. “client drift”): as we mentioned in the beginning, allowing multiple local SGD steps between client-server synchronization makes the algorithm converge to an inferior optimum in the non-IID setting (i.e., when clients have different data distributions), since the resulting pseudo-gradients turn out to be biased compared to the gradients of centralized training.

There are ways to overcome client drift using local regularization, carefully setting learning rate schedules, or using different control variate methods, but most of these mitigation strategies intentionally have to limit the optimization progress clients can make at each round.

Fundamentally, viewing FL as a distributed optimization problem runs into a tradeoff between the amount of local progress allowed and the quality of the final solution.

So, is there a way around this fundamental limitation?

An alternative approach: FL via posterior inference

Typically, client objectives \(f_i(\theta)\) correspond to negative log-likelihoods of their local data. Therefore, statistically speaking, FL is solving a maximum likelihood estimation (MLE) problem. Instead of solving it using distributed optimization techniques, however, we can take a Bayesian approach: first, infer the posterior distribution, \(P(\theta \mid D)\), then identify its mode, which will be the solution.

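Why is posterior inference better than optimization? Because, under a uniform prior and with client datasets that are conditionally independent given \(\theta\), the posterior decomposes exactly into a product of sub-posteriors: $$P(\theta \mid D) \propto \prod_{i=1}^N P(\theta \mid D_i)$$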
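For readers who want to see where this decomposition comes from, here is a short derivation. It assumes (our labeling of the conditions, in case they are not obvious) that the local datasets \(D_i\) are conditionally independent given \(\theta\) and that the prior \(P(\theta)\) is uniform. By Bayes' rule, $$P(\theta \mid D) \propto P(\theta) \prod_{i=1}^N P(D_i \mid \theta), \qquad P(\theta \mid D_i) \propto P(\theta) \, P(D_i \mid \theta).$$ Multiplying the sub-posteriors gives $$\prod_{i=1}^N P(\theta \mid D_i) \propto P(\theta)^N \prod_{i=1}^N P(D_i \mid \theta),$$ which matches \(P(\theta \mid D)\) up to a normalizing constant whenever \(P(\theta) \propto 1\).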

Thus, we are guaranteed to find the correct solution in three simple steps:

  1. Infer local sub-posteriors on each client and send their sufficient statistics to the server.
  2. Multiplicatively aggregate sub-posteriors on the server into the global posterior.
  3. Find and return the mode of the global posterior.

Wait, isn’t posterior inference intractable!? 😱

Indeed, there is a reason why posterior inference is not as popular as optimization: it is either intractable or often significantly more complex and computationally expensive. Moreover, posterior distributions rarely have closed form expressions and require various approximations.

For example, consider federated least squares regression, with quadratic local objectives: \(f_i(\theta) = \frac{1}{2} \|X_i^\top \theta - y_i\|^2\). In this case, the global posterior mode has a closed form expression: $$\theta^\star = \left( \sum_{i=1}^N q_i \Sigma_i^{-1} \right)^{-1} \left( \sum_{i=1}^N q_i \Sigma_i^{-1} \mu_i \right)$$ where \(\mu_i\) and \(\Sigma_i\) are the means and covariances of the local posteriors. Even though in this simple case the posterior is Gaussian and inference is technically tractable, computing \(\theta^\star\) requires inverting multiple matrices and communicating local means and covariances from the clients to the server. In comparison to FedAvg, which requires only \(O(d)\) computation and \(O(d)\) communication per round, posterior inference seems like a very bad idea…
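As a sanity check on the closed form, here is a small sketch under stated assumptions: a uniform prior, so that each local posterior is Gaussian with \(\Sigma_i^{-1} = X_i X_i^\top\) and \(\mu_i = (X_i X_i^\top)^{-1} X_i y_i\), plus randomly generated data. The sizes and function names are hypothetical. The sketch compares the aggregated mode with a direct minimization of the weighted objective.

```python
import numpy as np

rng = np.random.default_rng(0)
d, sizes = 4, [30, 60, 10]  # hypothetical dimension and dataset sizes n_i
q = np.array(sizes) / sum(sizes)
# X_i has shape (d, n_i) so that the residual is X_i^T theta - y_i.
data = [(rng.normal(size=(d, n)), rng.normal(size=n)) for n in sizes]

def local_posterior(X, y):
    """Mean and covariance of the Gaussian local posterior (uniform prior)."""
    precision = X @ X.T  # Sigma_i^{-1}
    mean = np.linalg.solve(precision, X @ y)
    return mean, np.linalg.inv(precision)

def global_mode(posteriors, q):
    """theta* = (sum q_i Sigma_i^{-1})^{-1} (sum q_i Sigma_i^{-1} mu_i)."""
    precisions = [np.linalg.inv(cov) for _, cov in posteriors]
    lhs = sum(qi * P for qi, P in zip(q, precisions))
    rhs = sum(qi * P @ m for qi, P, (m, _) in zip(q, precisions, posteriors))
    return np.linalg.solve(lhs, rhs)

theta_star = global_mode([local_posterior(X, y) for X, y in data], q)

# Reference: minimize F(theta) = sum_i q_i f_i(theta) via its normal equations.
H = sum(qi * X @ X.T for qi, (X, _) in zip(q, data))
b = sum(qi * X @ y for qi, (X, y) in zip(q, data))
theta_direct = np.linalg.solve(H, b)
print(np.allclose(theta_star, theta_direct))
```

Note that `global_mode` inverts a d-by-d matrix per client and needs the full covariances on the server, which is exactly the cost the paragraph above is worried about.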

Approximate inference FTW! 😎

It turns out that we can compute \(\theta^\star\) approximately using an elegant distributed inference algorithm, which we call federated posterior averaging (or FedPA):

  1. On the server, we can compute \(\theta^\star\) iteratively over multiple rounds: $$\theta^{t+1} = \theta^t - \alpha_t \sum_{i=1}^N q_i \underbrace{\Sigma_i^{-1}\left( \theta^t - \mu_i \right)}_{:= \Delta_i^t}$$ where \(\alpha_t\) is the server learning rate. This procedure avoids the outer matrix inverse and requires clients to send to the server only some delta vectors instead of full covariance matrices. Also, the summation can be substituted with a stochastic approximation, i.e., only a subset of clients needs to participate in each round. Note how similar it is to FedAvg!
  2. On the clients, we can compute \(\Delta_i^t := \Sigma_i^{-1}\left( \theta^t - \mu_i \right)\) very efficiently in two steps:
    1. Use stochastic gradient Markov chain Monte Carlo (SG-MCMC) to produce multiple approximate samples from the local posterior.
    2. Use an efficient dynamic programming procedure to compute the inverse covariance matrix multiplied by a vector in \(O(d)\) time and memory.
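The server-side iteration above can be sketched as follows. This is a simplified illustration, not the full algorithm: it assumes the clients know their local means and inverse covariances exactly (the values below are hypothetical), whereas FedPA estimates the deltas from SG-MCMC samples. The function names are ours.

```python
import numpy as np

# Hypothetical Gaussian local posteriors N(mu_i, Sigma_i) for two clients.
Sigma_inv = [np.diag([1.0, 0.1]), np.diag([0.1, 1.0])]
mu = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
q = [0.5, 0.5]

def exact_delta(theta, Sigma_inv_i, mu_i):
    """Client update Delta_i = Sigma_i^{-1} (theta - mu_i), computed exactly."""
    return Sigma_inv_i @ (theta - mu_i)

def run_server(rounds=200, server_lr=1.0):
    """Iterate theta <- theta - lr * sum_i q_i Delta_i."""
    theta = np.zeros(2)
    for _ in range(rounds):
        deltas = [exact_delta(theta, P, m) for P, m in zip(Sigma_inv, mu)]
        theta = theta - server_lr * sum(qi * di for qi, di in zip(q, deltas))
    return theta

# Closed-form global posterior mode, for comparison.
theta_star = np.linalg.solve(
    sum(qi * P for qi, P in zip(q, Sigma_inv)),
    sum(qi * P @ m for qi, P, m in zip(q, Sigma_inv, mu)),
)
print(run_server(), theta_star)
```

With exact deltas the iteration is plain gradient descent on a quadratic, so it converges to the closed-form mode without the server ever inverting a matrix.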

Note: in the case of arbitrary non-Gaussian likelihoods (which is the case for deep neural nets), FedPA essentially approximates the local and global posteriors with the best fitting Gaussians (a.k.a. the Laplace approximation).

What is the difference between FedAvg and FedPA? 🤔

FedPA has the same computation and communication complexity as FedAvg. In fact, the algorithms differ only in how the client updates \(\Delta_i^t\) are computed. Since FedAvg computes \(\Delta_i^t \approx \theta^t - \mu_i\), we can also view it as an approximate posterior inference algorithm that estimates local covariances \(\Sigma_i\) with identity matrices, which results in biased updates!
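To illustrate the effect of the identity-covariance approximation, the sketch below compares the fixed points of the two update rules on a hypothetical two-client Gaussian example. It considers the idealized limit in which clients converge fully to their local optima, and the helper `fixed_point` is ours.

```python
import numpy as np

# Hypothetical local posteriors for two clients.
Sigma_inv = [np.diag([1.0, 0.1]), np.diag([0.1, 1.0])]
mu = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
q = [0.5, 0.5]

def fixed_point(precisions):
    """Point where sum_i q_i P_i (theta - mu_i) = 0 for given matrices P_i."""
    lhs = sum(qi * P for qi, P in zip(q, precisions))
    rhs = sum(qi * P @ m for qi, P, m in zip(q, precisions, mu))
    return np.linalg.solve(lhs, rhs)

theta_fedpa = fixed_point(Sigma_inv)                # true inverse covariances
theta_fedavg = fixed_point([np.eye(2)] * len(mu))   # identity covariances
print("with true covariances:    ", theta_fedpa)
print("with identity covariances:", theta_fedavg)
```

With identity covariances the fixed point is just the weighted average of the local means, which differs from the global posterior mode whenever the local covariances are not all equal.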

Figure 2: Bias and variance of the deltas computed by FedAvg and FedPA for 10-dimensional federated least squares. More local steps increase the bias of FedAvg; FedPA is able to utilize additional computation to reduce that bias.

Figure 2 illustrates the difference between FedAvg and FedPA in terms of the bias and variance of updates they compute at each round as functions of the number of SGD steps:

  • More local SGD steps increase the bias of FedAvg updates, leading the algorithm to converge to a point further away from the optimum.
  • FedPA uses local SGD steps to produce more posterior samples, which improves the estimates of the local means and covariances and reduces the bias of model updates.

Does FedPA actually work in practice? 🧐

The bias-variance tradeoff argument seems great in theory, but does it actually work in practice? First, let’s revisit our toy 2D example with 2 clients and quadratic objectives:

Figure 3: FedPA vs. FedAvg in our toy 2D setting with two clients and quadratic objectives.

We see that not only is FedPA as fast as FedAvg initially but it also converges to a point that is significantly closer to the global optimum. At the end of convergence, FedPA exhibits some oscillations that could be further eliminated by increasing the number of local posterior samples.

Next, let’s compare FedPA with FedAvg head-to-head on realistic and challenging benchmarks, such as the federated CIFAR-100 and StackOverflow datasets:

Figure 4: CIFAR-100: Evaluation loss (left) and accuracy (right) for FedAvg and FedPA. Each algorithm used 20 clients per round and ran local SGD with momentum for 10 epochs (hence “-ME” suffixes, which stand for “multi-epoch”).
Figure 5: StackOverflow LR: Evaluation loss (left) and macro-F1 (right) for FedAvg and FedPA. Each algorithm used 10 clients per round and ran local SGD with momentum for 5 epochs (hence “-ME” suffixes, which stand for “multi-epoch”).

For clients to be able to sample from local posteriors using SG-MCMC, their models have to be close enough to local optima in the parameter space. Therefore, we first “burn-in” FedPA for a few rounds by running it in the FedAvg regime (i.e., compute the deltas the same way as FedAvg). At some point, we switch to local SG-MCMC sampling. Figures 4 and 5 show the evaluation metrics over the course of training. We clearly see a significant jump in performance right at the point when the algorithm was essentially switched from FedAvg to FedPA.

Concluding thoughts & what’s next?

Viewing federated learning through the lens of probabilistic inference turned out to be fruitful. Not only were we able to reinterpret FedAvg as a biased approximate inference algorithm and explain the strange effect of multiple local SGD steps on its convergence, but this new perspective allowed us to design a new FL algorithm that blends together optimization with local MCMC-based posterior sampling and utilizes local computation efficiently.

We believe that FedPA is just the beginning of a new class of approaches to federated learning. One of the biggest advantages of the distributed optimization over posterior inference so far is a strong theoretical understanding of FedAvg’s convergence and its variations in different IID and non-IID settings, which was developed over the past few years by the optimization community. Convergence analysis of posterior inference in different federated settings is an important research avenue to pursue next.

While FedPA relies on a number of specific design choices we had to make (the Laplace approximation, MCMC-based local inference, the shrinkage covariance estimation, etc.), our inferential perspective connects FL to a rich toolbox of techniques from Bayesian machine learning literature: variational inference, expectation propagation, ensembling and Bayesian deep learning, privacy guarantees for posterior sampling, among others. Exploring application of these techniques in different FL settings may lead us to even more interesting discoveries!

ACKNOWLEDGEMENTS: Thanks to Jenny Gillenwater, Misha Khodak, Peter Kairouz, and Afshin Rostamizadeh for feedback on this blog post.

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.

Representational Aspects of Depth and Conditioning in Normalizing Flows

Top and Bottom Right: RealNVP [3] uses checkerboard and channel-wise partitioning schemes in order to factor out parameters and ensure that there aren’t redundant partitions from previous layers. GLOW [4] uses an invertible 1×1 convolution which allows the partition to be ‘learned’ by a linear layer. We show that arbitrary partitions can be simulated in a constant number of layers with a fixed partition, showing that these ideas increase representational power by at most a constant factor. Bottom Left: Random points are well-separated with high probability on a high-dimensional sphere, which allows us to construct a distribution that is challenging for flows.

The promise of unsupervised learning lies in its potential to take advantage of cheap and plentiful unlabeled data to learn useful representations or generate high-quality samples. For the latter task, neural network-based generative models have recently enjoyed a lot of success in producing realistic images and text. Two major paradigms in deep generative modeling are generative adversarial networks (GANs) and normalizing flows. When successfully scaled up and trained, both can generate high-quality and diverse samples from high-dimensional distributions. The training procedure for GANs involves min-max (saddle-point) optimization, which is considerably more difficult than standard loss minimization, leading to problems like mode dropping.

Samples from a GLOW [4] model trained on the CelebA Faces Dataset.

Normalizing flows [1] have been proposed as an alternative type of generative model which allows not only efficient sampling but also training via maximum likelihood through a closed-form computation of the likelihood function. They are written as pushforwards of a simple distribution (typically a Gaussian) through an invertible transformation \(f\), typically parametrized as a composition of simple invertible transformations. The main reason for this parametrization is the change-of-variables formula: if \(z\) is a random variable sampled from a known base distribution \(P(z)\) (typically a standard multivariate normal), \(f: \mathbb{R}^d \to \mathbb{R}^d\) is invertible and differentiable, and \(x = f^{-1}(z)\), then $$p(x) = P(f(x))\left|\det\left(\frac{\partial f(x)}{\partial x^T}\right)\right|.$$ Here, \(P\) on the right-hand side is the base density and \(\frac{\partial f(x)}{\partial x^T}\) is the Jacobian of \(f\).
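As a quick numerical illustration of the change-of-variables formula, consider the simplest possible flow, an invertible affine map \(f(x) = Ax + b\) (our choice for illustration; the matrix and vector below are hypothetical). In that case \(x = f^{-1}(z)\) is Gaussian with a known mean and covariance, so the formula can be checked against the Gaussian density directly.

```python
import numpy as np

def std_normal_logpdf(z):
    """Log-density of a standard multivariate normal at z."""
    d = z.shape[-1]
    return -0.5 * (z @ z) - 0.5 * d * np.log(2 * np.pi)

def flow_logpdf(x, A, b):
    """log p(x) for the flow z = f(x) = A x + b, via change of variables."""
    z = A @ x + b
    _, logabsdet = np.linalg.slogdet(A)  # the Jacobian of f is A
    return std_normal_logpdf(z) + logabsdet

def gaussian_logpdf(x, mean, cov):
    """Reference log-density of N(mean, cov)."""
    d = x.shape[-1]
    diff = x - mean
    _, logdet = np.linalg.slogdet(cov)
    quad = diff @ np.linalg.solve(cov, diff)
    return -0.5 * quad - 0.5 * logdet - 0.5 * d * np.log(2 * np.pi)

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3)) + 3 * np.eye(3)  # hypothetical invertible map
b = rng.normal(size=3)
x = rng.normal(size=3)

# x = A^{-1}(z - b) with z ~ N(0, I) is Gaussian with this mean and covariance.
A_inv = np.linalg.inv(A)
reference = gaussian_logpdf(x, -A_inv @ b, A_inv @ A_inv.T)
print(flow_logpdf(x, A, b), reference)
```

The two log-densities should agree up to floating-point error. For deep flows the same bookkeeping applies layer by layer, with the log-determinants adding up.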

Normalizing flows are trained by maximizing the likelihood with gradient-based optimization. However, in practice, training normalizing flows runs into difficulties as well: models which produce good samples typically need to be extremely deep, which brings vanishing/exploding gradient problems. A closely related problem is that they are often poorly conditioned. Because data like images are often inherently lower-dimensional than the ambient space, the map from low-dimensional data to high-dimensional latent variables can be difficult to invert and therefore to train.

In our recent work [2], we tackle representational questions around depth and conditioning of normalizing flows—first for general invertible architectures, then for a particular common architecture—where the normalizing flow is a composition of so-called affine couplings.

Depth Bound on General Invertible Architectures

The most fundamental restriction of the normalizing flow paradigm is that each layer needs to be invertible. We ask whether this restriction has any ‘cost’ in terms of the size, and in particular the depth, of the model. Here we’re counting depth in terms of the number of the invertible transformations that make up the flow. A requirement for large depth would explain training difficulties due to exploding (or vanishing) gradients.

Since the Jacobian of a composition of functions is the product of the Jacobians of the functions being composed, the min (max) singular value of the Jacobian of the composition is bounded below (above) by the product of the min (max) singular values of the individual Jacobians. This means that the smallest (largest) singular value of the Jacobian can become exponentially smaller (larger) as the number of compositions grows.
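A small numerical illustration of this point, using random matrices as stand-ins for per-layer Jacobians (an assumption made purely for illustration): the extreme singular values of the product are bracketed by the products of the per-layer extremes, and that bracket widens exponentially with depth.

```python
import numpy as np

rng = np.random.default_rng(0)
d, depth = 5, 8
# Hypothetical per-layer Jacobians, here just random Gaussian matrices.
jacobians = [rng.normal(size=(d, d)) for _ in range(depth)]

total = np.eye(d)
for J in jacobians:
    total = J @ total  # chain rule: Jacobian of the composition

svals = [np.linalg.svd(J, compute_uv=False) for J in jacobians]
prod_min = np.prod([s.min() for s in svals])
prod_max = np.prod([s.max() for s in svals])
s_total = np.linalg.svd(total, compute_uv=False)

print(f"product of per-layer min singular values: {prod_min:.3e}")
print(f"composition, min singular value:          {s_total.min():.3e}")
print(f"composition, max singular value:          {s_total.max():.3e}")
print(f"product of per-layer max singular values: {prod_max:.3e}")
```

The ratio of the largest to the smallest singular value of `total` is its condition number, so the same experiment also gives a feel for how conditioning can degrade with depth.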

A natural way of formalizing this question is by exhibiting a distribution which is easy to model for an unconstrained generator network but hard for a shallow normalizing flow. Precisely, we ask: is there a probability distribution that can be represented by a shallow generator with a small number of parameters that could not be approximately represented by a shallow composition of invertible transformations?

We demonstrate that such a distribution exists. Specifically, we show that

Theorem: For every \(k\) s.t. \(k = o(\exp(d))\) and any parameterized family of compositions of Lipschitz invertible transformations with \(p\) parameters per transformation and at most \(O(k / p)\) transformations, there exists a generator \(g: \mathbb{R}^{d+1} \to \mathbb{R}^d\) with depth \(O(1)\) and \(O(k)\) parameters s.t. the pushforward of a Gaussian through \(g\) cannot be approximated in either KL or Wasserstein-1 distance by a network in this family.

The result above is extremely general: it only requires a bound on the number of parameters per transformation in the parametrization of the normalizing flow and Lipschitzness of these maps. As such it easily includes common choices used in practice like affine couplings with at most \(p\) parameters per layer or invertible feedforward networks, where each intermediate layer is of dimension \(d\) and the nonlinearity is invertible (e.g. leaky ReLU). On the flip side, for possible architectures with a large number of parameters per transformation, this theorem gives a (possibly loose) lower bound of a small number of transformations.

Proof Sketch: The generator for our construction approximates a mixture of \(k\) Gaussians with means placed uniformly randomly on a \(d\)-dimensional sphere in the ambient space. We will use the probabilistic method to show there is a family of such mixtures, s.t. each pair of members in this family are far apart (say, in Wasserstein distance). Furthermore, by an epsilon net discretization argument we can count how many “essentially” distinct invertible Lipschitz networks there are. If the number of mixtures in the family is much larger than the size of the epsilon net, at least one mixture must be far from all invertible networks.

We choose well-separated modes on a hypersphere in order to generate a large family of well separated mixtures of Gaussians.

The family of mixtures is constructed by choosing the \(k\) means for the components uniformly at random on a sphere. It’s well known that \(\exp(o(d))\) randomly chosen points on a unit sphere will, with high probability, have constant pairwise distance. Similarly, coding-theoretic arguments (used to prove the so-called Gilbert-Varshamov bound) can be used to show that selecting \(\exp(o(d))\) \(k\)-tuples of those means will, with high probability, ensure that each pair of \(k\)-tuples is such that the average pair of means is at constant distance. This suffices to ensure the Wasserstein distance between pairs of mixtures is large. ∎
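The separation of random points on a high-dimensional sphere is easy to observe numerically. The sketch below (the sample size and dimensions are arbitrary choices of ours) draws points uniformly on the unit sphere by normalizing Gaussian vectors and reports the range of pairwise distances.

```python
import numpy as np

def random_unit_vectors(n, d, rng):
    """n points drawn uniformly at random from the unit sphere in R^d."""
    v = rng.normal(size=(n, d))
    return v / np.linalg.norm(v, axis=1, keepdims=True)

def pairwise_distance_range(points):
    """Smallest and largest distance between distinct points."""
    gram = points @ points.T
    sq = np.clip(2.0 - 2.0 * gram, 0.0, None)  # ||u - v||^2 for unit vectors
    upper = np.triu_indices(len(points), k=1)
    dists = np.sqrt(sq[upper])
    return dists.min(), dists.max()

rng = np.random.default_rng(0)
for d in (3, 30, 300, 3000):
    lo, hi = pairwise_distance_range(random_unit_vectors(200, d, rng))
    print(f"d = {d:4d}: pairwise distances in [{lo:.3f}, {hi:.3f}]")
```

As the dimension grows, all pairwise distances concentrate around \(\sqrt{2}\), since random directions become nearly orthogonal.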

Results for Affine Couplings

Affine Couplings [3] are one of the most common transformations in scalable architectures for normalizing flows. An affine coupling is a map \(f: \mathbb{R}^d \to \mathbb{R}^d\) such that for some partition into a set \(S\) containing approximately half of the coordinates and its complement, $$f(x_S, x_{[d]\setminus S}) = (x_S, x_{[d]\setminus S} \odot s(x_S) + t(x_S))$$ for some scaling and translation functions \(s\) and \(t\) (typically parameterized by neural networks). Clearly, an affine coupling block only transforms one part of the coordinates at a time by an affine function while leaving the other part intact. It’s easy to see that an affine coupling is invertible if each coordinate of \(s(x_S)\) is nonzero. Moreover, the Jacobian of this function is $$\begin{bmatrix} I & 0 \\ \frac{\partial \left( x_{[d]\setminus S} \odot s(x_S) + t(x_S) \right)}{\partial x_S^T} & \text{diag}(s(x_S)) \end{bmatrix}$$ In particular it’s lower triangular, so we can calculate the determinant in linear time by multiplying the \(d\) diagonal elements (in general determinants take \(O(d^3)\) time to compute). This allows us to efficiently compute likelihoods and their gradients for SGD on large models via the change of variables formula.
These affine coupling blocks are stacked, often while changing the part of the partition that is updated or more generally, permuting the elements in between the application of the coupling.
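Here is a minimal sketch of a single affine coupling block. It makes some simplifying assumptions of our own: the partition is the first half of the coordinates versus the rest, and the scaling and translation functions are small fixed random maps rather than trained neural networks. The sketch checks the inverse and compares the linear-time log-determinant with a numerically estimated Jacobian.

```python
import numpy as np

def make_coupling(d, rng):
    """Hypothetical scale/shift functions built from fixed random linear maps."""
    half = d // 2
    Ws = rng.normal(size=(d - half, half))
    Wt = rng.normal(size=(d - half, half))
    s = lambda x_S: np.exp(np.tanh(Ws @ x_S))  # strictly positive, so nonzero
    t = lambda x_S: Wt @ x_S
    return half, s, t

def forward(x, half, s, t):
    """Apply the coupling; return the output and log|det Jacobian|."""
    x_S, x_rest = x[:half], x[half:]
    y = np.concatenate([x_S, x_rest * s(x_S) + t(x_S)])
    return y, np.sum(np.log(np.abs(s(x_S))))

def inverse(y, half, s, t):
    """Invert the coupling using the untouched coordinates y_S = x_S."""
    y_S, y_rest = y[:half], y[half:]
    return np.concatenate([y_S, (y_rest - t(y_S)) / s(y_S)])

def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference estimate of the Jacobian of f at x."""
    cols = []
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = eps
        cols.append((f(x + e) - f(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
half, s, t = make_coupling(6, rng)
x = rng.normal(size=6)
y, logdet = forward(x, half, s, t)
J = numerical_jacobian(lambda v: forward(v, half, s, t)[0], x)
print(logdet, np.linalg.slogdet(J)[1])
```

The exponential in `s` is one common way to keep the scaling nonzero. Stacking several such blocks with alternating partitions lets every coordinate eventually be transformed.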

Affine couplings consist of nonlinear affine transformations of half of the data dimensions at a time which end in a normal distribution. We show that the choice of which dimensions are in which half can be simulated in a constant number of couplings. Source

Effect of the Choice of Partition on Representational Power

Partitions in Real NVP. Source

The choice of partition is often somewhat ad-hoc and involves domain knowledge (e.g. for image datasets, a typical choice is a checkerboard pattern or a channel-wise partition). In fact, some recent approaches like GLOW [4] try to “learn” permutations to apply between each pair of affine couplings. (Technically, since a permutation is a discrete object, in [4] the authors learn 1×1 convolutions instead.)

While ablation experiments provide definite evidence that including learned 1×1 convolutions is beneficial for modeling image data in practice, it’s unclear whether this effect is from increased modeling power or algorithmic effects — and even less so how to formally quantify it. In this section, we come to a clear understanding of the representational value of adding this flexibility in partitioning. We knew from GLOW that adding these partitions helped. Now we know why!

We formalize the representational question as follows: how many affine couplings with a fixed partition are needed to simulate an arbitrary linear map? Since a linear map is more general than a 1×1 convolution, if it’s possible to do so with a small (say constant) number of affine couplings, we can simulate any affine coupling-based normalizing flow including 1×1 convolutions by one that does not include them which is merely a constant factor larger.

Concretely, we consider linear functions of the form $$T = \prod_{i=1}^K \begin{bmatrix} I & 0 \\ A_i & B_i \end{bmatrix} \begin{bmatrix} C_i & D_i \\ 0 & I \end{bmatrix},$$ for matrices \(A_i, D_i \in \mathbb{R}^{d \times d}\) and diagonal matrices \(B_i, C_i \in \mathbb{R}^{d \times d}\). The right hand side is precisely a composition of affine coupling blocks with linear maps \(s, t\) with a fixed partition (the parts of the input that are being updated alternate). We show the following result:
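To make the structure of these products concrete, the sketch below builds such a \(T\) from random blocks. The sampling distributions and helper names are ours, and since the blocks are \(d \times d\), the resulting matrix acts on \(2d\) coordinates. Because every factor is block triangular, the determinant of \(T\) is simply the product of the diagonal entries of the \(B_i\) and \(C_i\).

```python
import numpy as np

def coupling_pair(A, B_diag, C_diag, D):
    """One factor [[I, 0], [A, B]] @ [[C, D], [0, I]] with diagonal B and C."""
    d = A.shape[0]
    I, Z = np.eye(d), np.zeros((d, d))
    lower = np.block([[I, Z], [A, np.diag(B_diag)]])
    upper = np.block([[np.diag(C_diag), D], [Z, I]])
    return lower @ upper

def compose(K, d, rng):
    """Product of K random factors; also returns all diagonal entries used."""
    T, diag_entries = np.eye(2 * d), []
    for _ in range(K):
        B_diag, C_diag = rng.uniform(0.5, 2.0, size=(2, d))
        A, D = rng.normal(size=(d, d)), rng.normal(size=(d, d))
        T = T @ coupling_pair(A, B_diag, C_diag, D)
        diag_entries += [B_diag, C_diag]
    return T, np.concatenate(diag_entries)

rng = np.random.default_rng(0)
T, diags = compose(K=3, d=2, rng=rng)
print(T.shape, np.linalg.det(T), np.prod(diags))
```

The representational question is the converse: given an arbitrary invertible matrix, how many such factors are needed to reproduce it.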

Theorem: To represent an arbitrary invertible \(T\), \(K\) needs to be at most 24. Additionally, there exist invertible matrices \(T\) that require \(K \geq 3\).
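The block-matrix product that defines \(T\) can be written out numerically. The following is a minimal NumPy sketch under the assumption that the diagonal matrices \(B_i, C_i\) are passed as vectors of diagonal entries; the function names `lower_coupling`, `upper_coupling`, and `compose` are hypothetical and chosen only for this illustration.

```python
import numpy as np

def lower_coupling(A, b):
    """Block matrix [[I, 0], [A, diag(b)]]: keeps the first half, updates the second."""
    d = A.shape[0]
    return np.block([[np.eye(d), np.zeros((d, d))], [A, np.diag(b)]])

def upper_coupling(c, D):
    """Block matrix [[diag(c), D], [0, I]]: keeps the second half, updates the first."""
    d = D.shape[0]
    return np.block([[np.diag(c), D], [np.zeros((d, d)), np.eye(d)]])

def compose(layers):
    """Multiply the pairs (A_i, b_i, c_i, D_i) in the order of the product formula."""
    d = layers[0][0].shape[0]
    T = np.eye(2 * d)
    for A, b, c, D in layers:
        T = T @ lower_coupling(A, b) @ upper_coupling(c, D)
    return T
```

Because each factor is block triangular, the determinant of the composed matrix is the product of all diagonal entries of the \(B_i\) and \(C_i\), so the result is invertible exactly when none of those entries is zero.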

Proof sketch: The statement hopefully reminds the reader of the standard LU decomposition, with the twist that the matrices on the right-hand side have a more constrained structure than merely being triangular. Our proof starts with the existence of an \(LUP\) decomposition for every matrix.

We first show that we can construct an arbitrary permutation (up to sign) using at most 21 alternating matrices of the desired form. The argument is group theoretic: we use the fact that a permutation decomposes into a composition of two permutations of order 2, each of which must be a product of disjoint swaps, and show that swapping elements can be implemented “in parallel” using several partitioned matrices of the type we’re considering.
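Both ingredients of this step can be checked on small examples. The sketch below is an illustration under stated assumptions, not the construction from the paper: `two_involutions` (a hypothetical name) uses the standard trick of writing each cycle as a product of two reflections, and `swap_halves_up_to_sign` (also hypothetical) shows only the simplest case, exchanging the two halves of the partition with three coupling-type matrices.

```python
import numpy as np

def two_involutions(perm):
    """Return lists (a, b), each of order at most 2, with perm[i] == b[a[i]] for all i.

    perm[i] is the image of i. Each cycle is handled separately.
    """
    n = len(perm)
    a, b = list(range(n)), list(range(n))
    seen = [False] * n
    for start in range(n):
        if seen[start]:
            continue
        cycle, j = [], start
        while not seen[j]:
            seen[j] = True
            cycle.append(j)
            j = perm[j]
        m = len(cycle)
        for k in range(m):
            a[cycle[k]] = cycle[(-k) % m]      # reflection k -> -k on the cycle
            b[cycle[k]] = cycle[(1 - k) % m]   # reflection k -> 1 - k on the cycle
    return a, b

def swap_halves_up_to_sign(d):
    """Map (x1, x2) to (x2, -x1) using three matrices of the coupling form."""
    I, Z = np.eye(d), np.zeros((d, d))
    upper = np.block([[I, I], [Z, I]])    # C = I, D = I
    lower = np.block([[I, Z], [-I, I]])   # A = -I, B = I
    return upper @ lower @ upper
```

Swapping coordinates that sit inside the same half of the partition takes more work, which is part of why the count in the proof is larger than three.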

Next, we show that we can produce an arbitrary triangular matrix with our partitioned matrices. We use similar techniques as above to reduce the matrix to a regular system of block linear equations which we can then solve. Our upper bound comes from just counting the total number of matrices required for these operations: the 21 for the permutation and 13 for each triangular matrix (upper and lower), giving a total of 21 + 2 × 13 = 47 required matrices, which fit within the \(K \leq 24\) pairs of the theorem. ∎

To reiterate the takeaway: a GLOW-style linear layer in between affine couplings could in theory make your network between 5 and 47 times smaller while representing the same function. We now have a precise understanding of the value of that architectural choice!

We also verified empirically in the figure below how well these linear models would fit randomly chosen (i.e. with iid Gaussian entries) linear functions. It seems empirically that at least for this ensemble our upper bound is loose and we can fit the functions well without using the full 47 layers. Closing this gap is an interesting problem for future work.

We regress affine couplings with linear scaling and translation functions at a variety of depths on linear functions determined by random matrices. It seems like we can fit these functions arbitrarily well with 4-16 layers, suggesting that at least in random cases the true number of layers required is closer to our lower bound.

Universal Approximation with Poorly Conditioned Networks

In our earlier result on the depth of invertible networks, we assumed that our network was Lipschitz and therefore well-conditioned. A natural question is then, if we remove this requirement, how powerful is the resulting class of models? In particular, we ask: are poorly conditioned affine coupling-based normalizing flows universal approximators as they are used in practice?

Curiously, this question has in fact not been answered in prior work. In a very recent work [5], it was shown that if we allow for padding of our data with extra dimensions that take a constant value 0, affine couplings are universal approximators. (Note that this kind of padding clearly results in a singular Jacobian, as the value in the added dimensions is constant.) The idea for why padding helps is that these extra dimensions are used as a “scratch pad” for the computation the network is performing. Another recent work [6] gives a proof of universal approximation for affine couplings assuming arbitrary permutations in between the layers are allowed (à la GLOW) and a partition separating \(d - 1\) dimensions from the remaining one. However, in practice, these models are trained using a roughly half-half split and often without linear layers in between couplings (which already works quite well). We prove that none of these architectural modifications to affine couplings are necessary for universal approximation and additionally suggest a trade-off between the conditioning of the model and the quality of its approximation. Concretely, we show:

Theorem: For any bounded and absolutely continuous distribution \(Q\) over \(\mathbb{R}^n\) and any \(\epsilon > 0\), there exists a 3-layer affine coupling \(g\) with maps \(s, t\) represented by feedforward ReLU networks such that \(W_2(g_\# P, Q) \leq \epsilon\), where \(W_2\) is the 2-Wasserstein distance and \(g_\# P\) is the pushforward of a standard Gaussian \(P\) through \(g\).
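To get a feel for the distance in the theorem, here is a minimal sketch of the empirical 2-Wasserstein distance between two equal-size samples in one dimension, where the optimal coupling simply matches sorted values. The function name `empirical_w2_1d` is hypothetical, and the restriction to one dimension is an assumption made for simplicity; the theorem itself concerns distributions over \(\mathbb{R}^n\).

```python
import numpy as np

def empirical_w2_1d(x, y):
    """Empirical W_2 between two 1-D samples of equal size, via sorted matching."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("samples must have the same number of points")
    return float(np.sqrt(np.mean((x - y) ** 2)))
```

For instance, shifting every sample point by a constant `c` gives a distance of `abs(c)`, and reordering the points of a sample leaves the distance unchanged.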

We note that the construction for the theorem trades off quality of approximation (\(\epsilon\)) with conditioning: the smallest singular value of the Jacobian in the construction for our theorem above will scale like \(1/\epsilon\). This suggests that if we want to use affine couplings as universal approximators, conditioning may be an issue even if we don’t pad the added dimensions with a constant value as prior works do, which obviously results in a singular Jacobian.

Proof sketch: The proof is based on two main ideas.

The first is a deep result from optimal transport, Brenier’s theorem, which for sufficiently “regular” distributions \(p\) over \(\mathbb{R}^d\) guarantees an invertible map \(\phi\), s.t. the pushforward of the Gaussian through \(\phi\) equals \(p\). This reduces our problem to approximating \(\phi\) using a sequence of affine couplings.
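In one dimension the map guaranteed by Brenier’s theorem has a closed form: it is the monotone map obtained by composing the Gaussian CDF with the inverse CDF of the target. The sketch below assumes a uniform target on an interval, chosen purely as an example of a bounded, absolutely continuous distribution; the function name `transport_gaussian_to_uniform` is hypothetical.

```python
from statistics import NormalDist

def transport_gaussian_to_uniform(z, low=0.0, high=1.0):
    """Monotone map pushing N(0, 1) forward to Uniform(low, high).

    The standard normal CDF sends z to a Uniform(0, 1) value, which the
    inverse CDF of the target then rescales to the interval [low, high].
    """
    return low + (high - low) * NormalDist().cdf(z)
```

In higher dimensions no such closed form is available in general, which is why the map has to be approximated by a network.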

The difficulty in approximating \(\phi\) is the fact that affine couplings are only allowed to change one part of the input, and in a constrained way. The trick we use to do this without a “scratchpad” to store intermediate computation as in prior works is to instead hide information in the “low order bits” of the other partition. For details, refer to our paper. ∎
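The “low order bits” idea can be conveyed with a toy example: round one coordinate to a coarse grid and store a second value in the space below the grid resolution. This is only an illustration of the encoding idea under assumed names (`hide`, `reveal`) and an assumed grid width `delta`; it is not the ReLU-network construction from the paper.

```python
import math

def hide(y, x, delta=1e-3):
    """Quantize y to a grid of width delta and store x (in [0, 1)) below that grid."""
    k = math.floor(y / delta)
    return (k + x) * delta

def reveal(z, delta=1e-3):
    """Recover the quantized y and the hidden x from the packed value z."""
    t = z / delta
    k = math.floor(t)
    return k * delta, t - k
```

The packed value differs from the original `y` by at most `delta`, so a smaller `delta` hides the information more finely, at the cost of making the decoding map steeper.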

Finally, on the experimental front, we wanted to experiment with how padding affects the conditioning of a learned model. We considered synthetic 2d datasets (see figure below) and found that padding with zeros resulted in a very poorly conditioned model which produced poor samples, as might be expected. We also considered a type of padding which is reasonable but for which we have no theory — namely, to use iid Gaussian samples as values for the added dimensions (in this case, the resulting Jacobians are not prima facie singular, and the model can still use them as a “scratch pad”). While we have no result that this in any formal sense can result in better-conditioned networks, we found that in practice it frequently does and it also results in better samples. This seems like a very fruitful direction for future research. Finally, without padding, the model produces samples of middling quality and has a condition number in between that of zero and Gaussian padding.

In both examples, Gaussian padding of the data gives a sharper distribution and a better-conditioned model.
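The two padding schemes compared above are easy to state in code. The sketch below is a minimal NumPy illustration with a hypothetical helper `pad_data`; it only prepares the padded data and does not reproduce the training experiment.

```python
import numpy as np

def pad_data(data, extra_dims, mode, rng=None):
    """Append `extra_dims` columns of zeros or iid standard Gaussian samples to `data`."""
    n = data.shape[0]
    if mode == "zero":
        padding = np.zeros((n, extra_dims))
    elif mode == "gaussian":
        rng = np.random.default_rng() if rng is None else rng
        padding = rng.standard_normal((n, extra_dims))
    else:
        raise ValueError("mode must be 'zero' or 'gaussian'")
    return np.concatenate([data, padding], axis=1)
```

With zero padding the sample covariance of the padded data is rank deficient, since the added coordinates never vary, whereas with Gaussian padding it is full rank with probability one. This mirrors the remark that constant padding forces a degenerate model.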


Normalizing flows are one of the most popular generative models across various domains, though we still have a relatively narrow understanding of their pros and cons compared to other models. We show in this work that there are fundamental tradeoffs between the depth, conditioning, and representational power of this type of function. Though we have considerably clarified the representational aspects of these models, the algorithmic and statistical questions are still wide open. We hope that this work guides both users of flows and theoreticians as to the fine-grained properties of flows as compared to other generative models.


[1] Rezende and Mohamed, 2015, Variational Inference with Normalizing Flows, ICML 2015

[2] Koehler, Mehta, and Risteski, 2020, Representational aspects of depth and conditioning in normalizing flows, Under Submission.

[3] Dinh, Sohl-Dickstein, and S. Bengio, 2016, Density estimation using Real NVP, ICLR 2016

[4] Kingma and Dhariwal, 2018, GLOW: Generative flow with 1×1 convolutions, NeurIPS 2018

[5] Huang, Dinh, and Courville, 2020, Augmented Normalizing Flows: Bridging the Gap Between Generative Flows and Latent Variable Models

[6] Teshima, Ishikawa, Tojo, Oono, Ikeda, and Sugiyama, 2020, Coupling-based Invertible Neural Networks Are Universal Diffeomorphism Approximators, NeurIPS 2020

Carnegie Mellon University at NeurIPS 2020

Carnegie Mellon University is proud to present 88 papers at the 34th Conference on Neural Information Processing Systems (NeurIPS 2020), which will be held virtually this week. Our faculty and researchers are also giving invited talks at 7 workshops and are involved in organizing 14 workshops at the conference.

Here is a quick overview of the areas our researchers are working on:

We are also proud to collaborate with many other researchers in academia and industry:


Reinforcement Learning

Breaking the Sample Size Barrier in Model-Based Reinforcement Learning with a Generative Model
Gen Li (Tsinghua University) · Yuting Wei (Carnegie Mellon University) · Yuejie Chi (CMU) · Yuantao Gu (Tsinghua University) · Yuxin Chen (Princeton University)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #82

Reinforcement Learning with General Value Function Approximation: Provably Efficient Approach via Bounded Eluder Dimension
Ruosong Wang (Carnegie Mellon University) · Russ Salakhutdinov (Carnegie Mellon University) · Lin Yang (UCLA)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #167

Provably Efficient Exploration for Reinforcement Learning Using Unsupervised Learning
Fei Feng (University of California, Los Angeles) · Ruosong Wang (Carnegie Mellon University) · Wotao Yin (Alibaba US, DAMO Academy) · Simon Du (Institute for Advanced Study) · Lin Yang (UCLA)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #169

Object Goal Navigation using Goal-Oriented Semantic Exploration [code] [video]
Devendra Singh Chaplot (Carnegie Mellon University) · Dhiraj Prakashchand Gandhi (Carnegie Mellon University) · Abhinav Gupta (Facebook AI Research/CMU) · Russ Salakhutdinov (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #358

Sparse Graphical Memory for Robust Planning [code]
Scott Emmons (UC Berkeley) · Ajay Jain (UC Berkeley) · Misha Laskin (UC Berkeley) · Thanard Kurutach (University of California Berkeley) · Pieter Abbeel (UC Berkeley & · Deepak Pathak (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #419

Task-Agnostic Online Reinforcement Learning with an Infinite Mixture of Gaussian Processes
Mengdi Xu (Carnegie Mellon University) · Wenhao Ding (Carnegie Mellon University) · Jiacheng Zhu (Carnegie Mellon University) · Zuxin Liu (Carnegie Mellon University) · Baiming Chen (Tsinghua University) · Ding Zhao (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #420

On Reward-Free Reinforcement Learning with Linear Function Approximation
Ruosong Wang (Carnegie Mellon University) · Simon Du (Institute for Advanced Study) · Lin Yang (UCLA) · Russ Salakhutdinov (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #499

Rewriting History with Inverse RL: Hindsight Inference for Policy Improvement [code]
Ben Eysenbach (Carnegie Mellon University) · Xinyang Geng (UC Berkeley) · Sergey Levine (UC Berkeley) · Russ Salakhutdinov (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #594

Planning with General Objective Functions: Going Beyond Total Rewards
Ruosong Wang (Carnegie Mellon University) · Peilin Zhong (Columbia University) · Simon Du (Institute for Advanced Study) · Russ Salakhutdinov (Carnegie Mellon University) · Lin Yang (UCLA)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #600

Preference-based Reinforcement Learning with Finite-Time Guarantees
Yichong Xu (Carnegie Mellon University) · Ruosong Wang (Carnegie Mellon University) · Lin Yang (UCLA) · Aarti Singh (CMU) · Artur Dubrawski (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #601

Is Long Horizon RL More Difficult Than Short Horizon RL?
Ruosong Wang (Carnegie Mellon University) · Simon Du (Institute for Advanced Study) · Lin Yang (UCLA) · Sham Kakade (University of Washington & Microsoft Research)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #602

Neural Dynamic Policies for End-to-End Sensorimotor Learning
Shikhar Bahl (Carnegie Mellon University) · Mustafa Mukadam (Facebook AI Research) · Abhinav Gupta (Facebook AI Research/CMU) · Deepak Pathak (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1371

Weakly-Supervised Reinforcement Learning for Controllable Behavior
Lisa Lee (CMU / Google Brain / Stanford) · Ben Eysenbach (Carnegie Mellon University) · Russ Salakhutdinov (Carnegie Mellon University) · Shixiang (Shane) Gu (Google Brain) · Chelsea Finn (Stanford)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1832

Estimation & Inference

Robust Density Estimation under Besov IPM Losses
Ananya Uppal (Carnegie Mellon University) · Shashank Singh (Google) · Barnabas Poczos (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #429

Rewriting History with Inverse RL: Hindsight Inference for Policy Improvement [code]
Ben Eysenbach (Carnegie Mellon University) · Xinyang Geng (UC Berkeley) · Sergey Levine (UC Berkeley) · Russ Salakhutdinov (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #594

Domain Adaptation as a Problem of Inference on Graphical Models
Kun Zhang (CMU) · Mingming Gong (University of Melbourne) · Petar Stojanov (Carnegie Mellon University) · Biwei Huang (Carnegie Mellon University) · Qingsong Liu (Unisound Intelligence Co., Ltd.) · Clark Glymour (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #698

Efficient semidefinite-programming-based inference for binary and multi-class MRFs [code]
Chirag Pabbaraju (Carnegie Mellon University) · Po-Wei Wang (CMU) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #851

Randomized tests for high-dimensional regression: A more efficient and powerful solution
Yue Li (Carnegie Mellon University) · Ilmun Kim (CMU) · Yuting Wei (Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #999

Distribution-free binary classification: prediction sets, confidence intervals and calibration
Chirag Gupta (Carnegie Mellon University) · Aleksandr Podkopaev (Carnegie Mellon University) · Aaditya Ramdas (CMU)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1537

Deep Learning

Neural Methods for Point-wise Dependency Estimation [code]
Yao-Hung Hubert Tsai (Carnegie Mellon University) · Han Zhao (Carnegie Mellon University) · Makoto Yamada (Kyoto University/RIKEN AIP) · Louis-Philippe Morency (Carnegie Mellon University) · Russ Salakhutdinov (Carnegie Mellon University)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #15

Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing [code]
Zihang Dai (Carnegie Mellon University) · Guokun Lai (Carnegie Mellon University) · Yiming Yang (CMU) · Quoc V Le (Google)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #64

Big Bird: Transformers for Longer Sequences [unofficial video]
Manzil Zaheer (Google) · Guru Guruganesh (Google Research) · Kumar Avinava Dubey (Carnegie Mellon University) · Joshua Ainslie (Google) · Chris Alberti (Google) · Santiago Ontanon (Google LLC) · Philip Pham (Google) · Anirudh Ravula (Google) · Qifan Wang (Google Research) · Li Yang (Google) · Amr Ahmed (Google Research)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #65

Deep Transformers with Latent Depth [code]
Xian Li (Facebook) · Asa Cooper Stickland (University of Edinburgh) · Yuqing Tang (Facebook AI) · Xiang Kong (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #287

Multiscale Deep Equilibrium Models [code]
Shaojie Bai (Carnegie Mellon University) · Vladlen Koltun (Intel Labs) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #320

Monotone operator equilibrium networks [code]
Ezra Winston (Carnegie Mellon University) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #323

Beyond Homophily in Graph Neural Networks: Current Limitations and Effective Designs [code]
Jiong Zhu (University of Michigan) · Yujun Yan (University of Michigan) · Lingxiao Zhao (Carnegie Mellon University) · Mark Heimann (University of Michigan) · Leman Akoglu (CMU) · Danai Koutra (U Michigan)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #374

On Completeness-aware Concept-Based Explanations in Deep Neural Networks
Chih-Kuan Yeh (Carnegie Mellon University) · Been Kim (Google) · Sercan Arik (Google) · Chun-Liang Li (Google) · Tomas Pfister (Google) · Pradeep Ravikumar (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #640

A Causal View on Robustness of Neural Networks
Cheng Zhang (Microsoft Research, Cambridge, UK) · Kun Zhang (CMU) · Yingzhen Li (Microsoft Research Cambridge)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #805

Improving GAN Training with Probability Ratio Clipping and Sample Reweighting
Yue Wu (Carnegie Mellon University) · Pan Zhou (National University of Singapore) · Andrew Wilson (New York University) · Eric Xing (Petuum Inc. / Carnegie Mellon University) · Zhiting Hu (Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #945

AutoSync: Learning to Synchronize for Data-Parallel Distributed Deep Learning
Hao Zhang (Carnegie Mellon University, Petuum Inc.) · Yuan Li (Duke University) · Zhijie Deng (Tsinghua University) · Xiaodan Liang (Sun Yat-sen University) · Lawrence Carin (Duke University) · Eric Xing (Petuum Inc. / Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1037

Deep Archimedean Copulas
Chun Kai Ling (Carnegie Mellon University) · Fei Fang (Carnegie Mellon University) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1754

A Study on Encodings for Neural Architecture Search [code]
Colin White (Abacus.AI) · Willie Neiswanger (Carnegie Mellon University) · Sam Nolen (RealityEngines.AI) · Yash Savani (RealityEngines.AI)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1777

Is normalization indispensable for training deep neural network?
Jie Shao (Fudan University) · Kai Hu (Carnegie Mellon University) · Changhu Wang (ByteDance.Inc) · Xiangyang Xue (Fudan University) · Bhiksha Raj (Carnegie Mellon University)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1887

Algorithms & Optimization

Latent Dynamic Factor Analysis of High-Dimensional Neural Recordings [code]
Heejong Bong (Carnegie Mellon University) · Zongge Liu (Carnegie Mellon University) · Zhao Ren (University of Pittsburgh) · Matthew Smith (Carnegie Mellon University) · Valerie Ventura (Carnegie Mellon University) · Robert E. Kass (CMU)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #32

Neutralizing Self-Selection Bias in Sampling for Sortition
Bailey Flanigan (Carnegie Mellon University) · Paul Goelz (Carnegie Mellon University) · Anupam Gupta (Carnegie Mellon University) · Ariel Procaccia (Harvard University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #710

Efficient semidefinite-programming-based inference for binary and multi-class MRFs [code]
Chirag Pabbaraju (Carnegie Mellon University) · Po-Wei Wang (CMU) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #851

Linear Dynamical Systems as a Core Computational Primitive
Shiva Kaul (Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1083

Distributed Training with Heterogeneous Data: Bridging Median- and Mean-Based Algorithms
Xiangyi Chen (University of Minnesota) · Tiancong Chen (University of Minnesota) · Haoran Sun (University of Minnesota) · Steven Wu (Carnegie Mellon University) · Mingyi Hong (University of Minnesota)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1144

WOR and p’s: Sketches for lp-Sampling Without Replacement
Edith Cohen (Google) · Rasmus Pagh (University of Copenhagen) · David Woodruff (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1435

Confidence sequences for sampling without replacement
Ian Waudby-Smith (Carnegie Mellon University) · Aaditya Ramdas (CMU)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1445

PLLay: Efficient Topological Layer based on Persistent Landscapes
Kwangho Kim (Carnegie Mellon University) · Jisu Kim (Inria Saclay) · Manzil Zaheer (Google) · Joon Kim (Carnegie Mellon University) · Frederic Chazal (INRIA) · Larry Wasserman (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1582

Tackling the Objective Inconsistency Problem in Heterogeneous Federated Optimization
Jianyu Wang (Carnegie Mellon University) · Qinghua Liu (Princeton University) · Hao Liang (Carnegie Mellon University) · Gauri Joshi (Carnegie Mellon University) · H. Vincent Poor (Princeton University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1636

Transferable Graph Optimizers for ML Compilers
Yanqi Zhou (Google Brain) · Sudip Roy (Google) · Amirali Abdolrashidi (UC Riverside) · Daniel Wong (Carnegie Mellon University) · Peter Ma (Google) · Qiumin Xu (Google) · Hanxiao Liu (Google Brain) · Phitchaya Phothilimtha (Google Brain) · Shen Wang (Google Inc) · Anna Goldie (Google Brain / Stanford) · Azalia Mirhoseini (Google Brain) · James Laudon (Google)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1781

Community detection using fast low-cardinality semidefinite programming
Po-Wei Wang (CMU) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1803

Learning Theory

Sample Complexity of Asynchronous Q-Learning: Sharper Analysis and Variance Reduction
Gen Li (Tsinghua University) · Yuting Wei (Carnegie Mellon University) · Yuejie Chi (CMU) · Yuantao Gu (Tsinghua University) · Yuxin Chen (Princeton University)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #160

Reinforcement Learning with General Value Function Approximation: Provably Efficient Approach via Bounded Eluder Dimension
Ruosong Wang (Carnegie Mellon University) · Russ Salakhutdinov (Carnegie Mellon University) · Lin Yang (UCLA)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #167

Provably Efficient Exploration for Reinforcement Learning Using Unsupervised Learning
Fei Feng (University of California, Los Angeles) · Ruosong Wang (Carnegie Mellon University) · Wotao Yin (Alibaba US, DAMO Academy) · Simon Du (Institute for Advanced Study) · Lin Yang (UCLA)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #169

Agnostic Q-learning with Function Approximation in Deterministic Systems: Near-Optimal Bounds on Approximation Error and Sample Complexity
Simon Du (Institute for Advanced Study) · Jason Lee (Princeton University) · Gaurav Mahajan (University of California, San Diego) · Ruosong Wang (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #226

Generalized Boosting
Arun Suggala (Carnegie Mellon University) · Bingbin Liu (Carnegie Mellon University) · Pradeep Ravikumar (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #364

PAC-Bayes Learning Bounds for Sample-Dependent Priors
Pranjal Awasthi (Google/Rutgers University) · Satyen Kale (Google) · Stefani Karp (Google/CMU) · Mehryar Mohri (Google Research & Courant Institute of Mathematical Sciences)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #436

Revisiting the Sample Complexity of Sparse Spectrum Approximation of Gaussian Processes
Minh Hoang (Carnegie Mellon University) · Nghia Hoang (Amazon) · Hai Pham (Carnegie Mellon University) · David Woodruff (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #666

Follow the Perturbed Leader: Optimism and Fast Parallel Algorithms for Smooth Minimax Games
Arun Suggala (Carnegie Mellon University) · Praneeth Netrapalli (Microsoft Research)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1021

On Learning Ising Models under Huber’s Contamination Model
Adarsh Prasad (Carnegie Mellon University) · Vishwak Srinivasan (Carnegie Mellon University) · Sivaraman Balakrishnan (Carnegie Mellon University) · Pradeep Ravikumar (Carnegie Mellon University)
Wed Dec 09 09:00 PM — 11:00 PM (PST) @ Poster Session 4 #1186

Axioms for Learning from Pairwise Comparisons
Ritesh Noothigattu (Carnegie Mellon University) · Dominik Peters (Carnegie Mellon University) · Ariel Procaccia (Harvard University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1447

A Unified View of Label Shift Estimation
Saurabh Garg (CMU) · Yifan Wu (Carnegie Mellon University) · Sivaraman Balakrishnan (CMU) · Zachary Lipton (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1535

Weak Supervision

Unsupervised Data Augmentation for Consistency Training [code]
Qizhe Xie (CMU, Google Brain) · Zihang Dai (Carnegie Mellon University) · Eduard Hovy (CMU) · Thang Luong (Google Brain) · Quoc V Le (Google)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #21

Provably Efficient Exploration for Reinforcement Learning Using Unsupervised Learning
Fei Feng (University of California, Los Angeles) · Ruosong Wang (Carnegie Mellon University) · Wotao Yin (Alibaba US, DAMO Academy) · Simon Du (Institute for Advanced Study) · Lin Yang (UCLA)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #169

Model-based Policy Optimization with Unsupervised Model Adaptation
Jian Shen (Shanghai Jiao Tong University) · Han Zhao (Carnegie Mellon University) · Weinan Zhang (Shanghai Jiao Tong University) · Yong Yu (Shanghai Jiao Tong University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #547

Modeling Task Effects on Meaning Representation in the Brain via Zero-Shot MEG Prediction
Mariya Toneva (Carnegie Mellon University) · Otilia Stretcu (Carnegie Mellon University) · Barnabas Poczos (Carnegie Mellon University) · Leila Wehbe (Carnegie Mellon University) · Tom Mitchell (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1687

Demystifying Contrastive Self-Supervised Learning: Invariances, Augmentations and Dataset Biases
Senthil Purushwalkam Shiva Prakash (Carnegie Mellon University) · Abhinav Gupta (Facebook AI Research/CMU)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1696

Comprehensive Attention Self-Distillation for Weakly-Supervised Object Detection
Zeyi Huang (Carnegie Mellon University) · Yang Zou (Carnegie Mellon University) · B. V. K. Vijaya Kumar (CMU, USA) · Dong Huang (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1704

Weakly-Supervised Reinforcement Learning for Controllable Behavior
Lisa Lee (CMU / Google Brain / Stanford) · Ben Eysenbach (Carnegie Mellon University) · Russ Salakhutdinov (Carnegie Mellon University) · Shixiang (Shane) Gu (Google Brain) · Chelsea Finn (Stanford)
Thu Dec 10 09:00 PM — 11:00 PM (PST) @ Poster Session 6 #1832

Computational Linguistics

Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing [code]
Zihang Dai (Carnegie Mellon University) · Guokun Lai (Carnegie Mellon University) · Yiming Yang (CMU) · Quoc V Le (Google)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #64

Big Bird: Transformers for Longer Sequences [unofficial video]
Manzil Zaheer (Google) · Guru Guruganesh (Google Research) · Kumar Avinava Dubey (Carnegie Mellon University) · Joshua Ainslie (Google) · Chris Alberti (Google) · Santiago Ontanon (Google LLC) · Philip Pham (Google) · Anirudh Ravula (Google) · Qifan Wang (Google Research) · Li Yang (Google) · Amr Ahmed (Google Research)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #65

Learning Sparse Prototypes for Text Generation
Junxian He (Carnegie Mellon University) · Taylor Berg-Kirkpatrick (University of California San Diego) · Graham Neubig (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #184

Deep Transformers with Latent Depth [code]
Xian Li (Facebook) · Asa Cooper Stickland (University of Edinburgh) · Yuqing Tang (Facebook AI) · Xiang Kong (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #287

Computer Vision

Swapping Autoencoder for Deep Image Manipulation [website] [unofficial code] [video]
Taesung Park (UC Berkeley) · Jun-Yan Zhu (Adobe, CMU) · Oliver Wang (Adobe Research) · Jingwan Lu (Adobe Research) · Eli Shechtman (Adobe Research, US) · Alexei Efros (UC Berkeley) · Richard Zhang (Adobe)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #105

Residual Force Control for Agile Human Behavior Imitation and Extended Motion Synthesis [website] [code] [video]
Ye Yuan (Carnegie Mellon University) · Kris Kitani (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #354

See, Hear, Explore: Curiosity via Audio-Visual Association [code] [video]
Victoria Dean (Carnegie Mellon University) · Shubham Tulsiani (Facebook AI Research) · Abhinav Gupta (Facebook AI Research/CMU)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #355

SDF-SRN: Learning Signed Distance 3D Object Reconstruction from Static Images [code]
Chen-Hsuan Lin (Carnegie Mellon University) · Chaoyang Wang (Carnegie Mellon University) · Simon Lucey (CMU)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #480

Measuring Robustness to Natural Distribution Shifts in Image Classification [code]
Rohan Taori (Stanford University) · Achal Dave (Carnegie Mellon University) · Vaishaal Shankar (UC Berkeley) · Nicholas Carlini (Google) · Benjamin Recht (UC Berkeley) · Ludwig Schmidt (UC Berkeley)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #679

Pixel-Level Cycle Association: A New Perspective for Domain Adaptive Semantic Segmentation [code]
Guoliang Kang (Carnegie Mellon University) · Yunchao Wei (UTS) · Yi Yang (UTS) · Yueting Zhuang (Zhejiang University) · Alexander Hauptmann (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #693

Group Contextual Encoding for 3D Point Clouds [code]
Xu Liu (The University of Tokyo) · Chengtao Li (MIT) · Jian Wang (Carnegie Mellon University) · Jingbo Wang (Peking University) · Boxin Shi (Peking University) · Xiaodong He (JD AI research)
Wed Dec 09 09:00 PM — 11:00 PM (PST) @ Poster Session 4 #1151

Comprehensive Attention Self-Distillation for Weakly-Supervised Object Detection
Zeyi Huang (Carnegie Mellon University) · Yang Zou (Carnegie Mellon University) · B. V. K. Vijaya Kumar (CMU, USA) · Dong Huang (Carnegie Mellon University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1704

Graphical Models

Domain Adaptation as a Problem of Inference on Graphical Models
Kun Zhang (CMU) · Mingming Gong (University of Melbourne) · Petar Stojanov (Carnegie Mellon University) · Biwei Huang (Carnegie Mellon University) · Qingsong Liu (Unisound Intelligence Co., Ltd.) · Clark Glymour (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #698

Generalized Independent Noise Condition for Estimating Latent Variable Causal Graphs
Feng Xie (Peking University) · Ruichu Cai (Guangdong University of Technology) · Biwei Huang (Carnegie Mellon University) · Clark Glymour (Carnegie Mellon University) · Zhifeng Hao (Guangdong University of Technology) · Kun Zhang (CMU)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #887

On the Role of Sparsity and DAG Constraints for Learning Linear DAGs
Ignavier Ng (University of Toronto) · AmirEmad Ghassami (Johns Hopkins University) · Kun Zhang (CMU)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1665

Transfer Learning

Pixel-Level Cycle Association: A New Perspective for Domain Adaptive Semantic Segmentation [code]
Guoliang Kang (Carnegie Mellon University) · Yunchao Wei (UTS) · Yi Yang (UTS) · Yueting Zhuang (Zhejiang University) · Alexander Hauptmann (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #693

Domain Adaptation as a Problem of Inference on Graphical Models [code]
Kun Zhang (CMU) · Mingming Gong (University of Melbourne) · Petar Stojanov (Carnegie Mellon University) · Biwei Huang (Carnegie Mellon University) · Qingsong Liu (Unisound Intelligence Co., Ltd.) · Clark Glymour (Carnegie Mellon University)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #698

Look-ahead Meta Learning for Continual Learning [code]
Gunshi Gupta (University of Montreal) · Karmesh Yadav (Carnegie Mellon) · Liam Paull (Université de Montréal)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #767

Mitigating Forgetting in Online Continual Learning via Instance-Aware Parameterization
Hung-Jen Chen (National Tsing Hua University) · An-Chieh Cheng (National Tsing Hua University) · Da-Cheng Juan (Google) · Wei Wei (CMU) · Min Sun (Appier, Inc.)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #770

Domain Adaptation with Conditional Distribution Matching and Generalized Label Shift
Remi Tachet des Combes (Microsoft Research Montreal) · Han Zhao (Carnegie Mellon University) · Yu-Xiang Wang (UC Santa Barbara) · Geoffrey Gordon (MSR Montréal & CMU)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1008

Privacy & Robustness

Denoised Smoothing: A Provable Defense for Pretrained Classifiers [code]
Hadi Salman (Microsoft Research AI) · Mingjie Sun (Carnegie Mellon University) · Greg Yang (Microsoft Research) · Ashish Kapoor (Microsoft) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #302

Multi-Robot Collision Avoidance under Uncertainty with Probabilistic Safety Barrier Certificates
Wenhao Luo (Carnegie Mellon University) · Wen Sun (Cornell University) · Ashish Kapoor (Microsoft)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #312

A Closer Look at Accuracy vs. Robustness [code]
Yao-Yuan Yang (UCSD) · Cyrus Rashtchian (UCSD) · Hongyang Zhang (TTIC) · Russ Salakhutdinov (Carnegie Mellon University) · Kamalika Chaudhuri (UCSD)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #667

Measuring Robustness to Natural Distribution Shifts in Image Classification [code]
Rohan Taori (Stanford University) · Achal Dave (Carnegie Mellon University) · Vaishaal Shankar (UC Berkeley) · Nicholas Carlini (Google) · Benjamin Recht (UC Berkeley) · Ludwig Schmidt (UC Berkeley)
Tue Dec 08 09:00 PM — 11:00 PM (PST) @ Poster Session 2 #679

Smoothed Geometry for Robust Attribution
Zifan Wang (Carnegie Mellon University) · Haofan Wang (Carnegie Mellon University) · Shakul Ramkumar (Carnegie Mellon University) · Piotr Mardziel (Carnegie Mellon University) · Matt Fredrikson (CMU) · Anupam Datta (Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #936

Trade-offs and Guarantees of Adversarial Representation Learning for Information Obfuscation
Han Zhao (Carnegie Mellon University) · Jianfeng Chi (University of Virginia) · Yuan Tian (University of Virginia) · Geoffrey Gordon (MSR Montréal & CMU)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1066

Understanding Gradient Clipping in Private SGD: A Geometric Perspective
Xiangyi Chen (University of Minnesota) · Steven Wu (Carnegie Mellon University) · Mingyi Hong (University of Minnesota)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1081

Fairness & Interpretability

Fair Hierarchical Clustering
Sara Ahmadian (Google Research) · Alessandro Epasto (Google) · Marina Knittel (University of Maryland, College Park) · Ravi Kumar (Google) · Mohammad Mahdian (Google Research) · Benjamin Moseley (Carnegie Mellon University) · Philip Pham (Google) · Sergei Vassilvitskii (Google) · Yuyan Wang (Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #859

Metric-Free Individual Fairness in Online Learning
Yahav Bechavod (Hebrew University of Jerusalem) · Christopher Jung (University of Pennsylvania) · Steven Wu (Carnegie Mellon University)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #861

How do fair decisions fare in long-term qualification?
Xueru Zhang (University of Michigan) · Ruibo Tu (KTH Royal Institute of Technology) · Yang Liu (UC Santa Cruz) · Mingyan Liu (University of Michigan, Ann Arbor) · Hedvig Kjellstrom (KTH Royal Institute of Technology) · Kun Zhang (CMU) · Cheng Zhang (Microsoft Research, Cambridge, UK)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #869

Regularizing Black-box Models for Improved Interpretability
Gregory Plumb (Carnegie Mellon University) · Maruan Al-Shedivat (Carnegie Mellon University) · Ángel Alexander Cabrera (Carnegie Mellon University) · Adam Perer (Carnegie Mellon University) · Eric Xing (Petuum Inc. / Carnegie Mellon University) · Ameet Talwalkar (CMU)
Wed Dec 09 09:00 AM — 11:00 AM (PST) @ Poster Session 3 #1078

Explainable Voting
Dominik Peters (Carnegie Mellon University) · Ariel Procaccia (Harvard University) · Alexandros Psomas (Purdue University) · Zixin Zhou (Peking University)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1560

Counterfactual Predictions under Runtime Confounding
Amanda Coston (Carnegie Mellon University) · Edward Kennedy (Carnegie Mellon University) · Alexandra Chouldechova (CMU)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1622

Multi-agent Systems

Improving Policy-Constrained Kidney Exchange via Pre-Screening
Duncan McElfresh (University of Maryland) · Michael Curry (University of Maryland) · Tuomas Sandholm (CMU, Strategic Machine, Strategy Robot, Optimized Markets) · John Dickerson (University of Maryland)
Mon Dec 07 09:00 PM — 11:00 PM (PST) @ Poster Session 0 #126

Mitigating Manipulation in Peer Review via Randomized Reviewer Assignments
Steven Jecmen (Carnegie Mellon University) · Hanrui Zhang (Duke University) · Ryan Liu (Carnegie Mellon University) · Nihar Shah (CMU) · Vincent Conitzer (Duke University) · Fei Fang (Carnegie Mellon University)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #267

Polynomial-Time Computation of Optimal Correlated Equilibria in Two-Player Extensive-Form Games with Public Chance Moves and Beyond
Gabriele Farina (Carnegie Mellon University) · Tuomas Sandholm (CMU, Strategic Machine, Strategy Robot, Optimized Markets)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #341

No-Regret Learning Dynamics for Extensive-Form Correlated Equilibrium
Andrea Celli (Politecnico di Milano) · Alberto Marchesi (Politecnico di Milano) · Gabriele Farina (Carnegie Mellon University) · Nicola Gatti (Politecnico di Milano)
Tue Dec 08 09:00 AM — 11:00 AM (PST) @ Poster Session 1 #535

EvolveGraph: Multi-Agent Trajectory Prediction with Dynamic Relational Reasoning
Jiachen Li (University of California, Berkeley) · Fan Yang (Carnegie Mellon University) · Masayoshi Tomizuka (University of California, Berkeley) · Chiho Choi (Honda Research Institute US)
Wed Dec 09 09:00 PM — 11:00 PM (PST) @ Poster Session 4 #1236

Evaluating and Rewarding Teamwork Using Cooperative Game Abstractions
Tom Yan (Carnegie Mellon University) · Christian Kroer (Columbia University) · Alexander Peysakhovich (Facebook)
Wed Dec 09 09:00 PM — 11:00 PM (PST) @ Poster Session 4 #1272

Small Nash Equilibrium Certificates in Very Large Games
Brian Zhang (Carnegie Mellon University) · Tuomas Sandholm (CMU, Strategic Machine, Strategy Robot, Optimized Markets)
Thu Dec 10 09:00 AM — 11:00 AM (PST) @ Poster Session 5 #1465


Invited Speakers

Differentiable Computer Vision, Graphics, and Physics in Machine Learning
Abhinav Gupta
Fri Dec 11 05:00 AM — 12:30 PM (PST)

Advances and Opportunities: Machine Learning for Education
Carolyn Rose, Ken Koedinger
Fri Dec 11 05:30 AM — 02:10 PM (PST)

Human in the Loop Dialogue Systems
Maxine Eskenazi, Alexander Rudnicky
Fri Dec 11 06:10 AM — 05:20 PM (PST)

Causal Discovery and Causality-Inspired Machine Learning
Clark Glymour
Fri Dec 11 06:50 AM — 04:50 PM (PST)

Self-Supervised Learning — Theory and Practice
Katerina Fragkiadaki, Abhinav Gupta, Ruslan Salakhutdinov
Sat Dec 12 08:50 AM — 06:40 PM (PST)

Algorithmic Fairness through the Lens of Causality and Interpretability
Hoda Heidari
Sat Dec 12 01:00 AM — 12:00 PM (PST)

International Workshop on Scalability, Privacy, and Security in Federated Learning (SpicyFL 2020)
Ruslan Salakhutdinov, Virginia Smith
Sat Dec 12


Differentiable Computer Vision, Graphics, and Physics in Machine Learning
Krishna Murthy Jatavallabhula · Kelsey Allen · Victoria Dean · Johanna Hansen · Shuran Song · Florian Shkurti · Liam Paull · Derek Nowrouzezahrai · Josh Tenenbaum
Fri Dec 11 05:00 AM — 12:30 PM (PST)

Self-Supervised Learning for Speech and Audio Processing
Abdel-rahman Mohamed · Hung-yi Lee · Shinji Watanabe · Shang-Wen Li · Tara Sainath · Karen Livescu
Fri Dec 11 06:50 AM — 04:25 PM (PST)

Causal Discovery and Causality-Inspired Machine Learning
Biwei Huang · Sara Magliacane · Kun Zhang · Danielle Belgrave · Elias Bareinboim · Daniel Malinsky · Thomas Richardson · Christopher Meek · Peter Spirtes · Bernhard Schölkopf
Fri Dec 11 06:50 AM — 04:50 PM (PST)

Machine Learning and the Physical Sciences
Anima Anandkumar · Kyle Cranmer · Shirley Ho · Mr. Prabhat · Lenka Zdeborová · Atilim Gunes Baydin · Juan Carrasquilla · Adji Bousso Dieng · Karthik Kashinath · Gilles Louppe · Brian Nord · Michela Paganini · Savannah Thais
Fri Dec 11 07:00 AM — 03:15 PM (PST)

First Workshop on Quantum Tensor Networks in Machine Learning
Xiao-Yang Liu · Qibin Zhao · Jacob Biamonte · Cesar Caiafa · Paul Pu Liang · Nadav Cohen · Stefan Leichenauer
Fri Dec 11 08:00 AM — 07:00 PM (PST)

ML Retrospectives, Surveys & Meta-Analyses (ML-RSA)
Chhavi Yadav · Prabhu Pradhan · Abhishek Gupta · Ryan Lowe · Peter Henderson · Jessica Forde · Mayoore Jaiswal · Jesse Dodge
Fri Dec 11 08:30 AM — 09:00 PM (PST)

BabyMind: How Babies Learn and How Machines Can Imitate
Byoung-Tak Zhang · Gary Marcus · Angelo Cangelosi · Pia Knoeferle · Klaus Obermayer · David Vernon · Chen Yu
Fri Dec 11 08:40 AM — 05:30 PM (PST)

Machine Learning for Autonomous Driving
Rowan McAllister · Xinshuo Weng · Daniel Omeiza · Nick Rhinehart · Fisher Yu · German Ros · Vladlen Koltun
Fri Dec 11 08:55 AM — 05:00 PM (PST)

Workshop on Dataset Curation and Security
Nathalie Baracaldo Angel · Yonatan Bisk · Avrim Blum · Michael Curry · John Dickerson · Micah Goldblum · Tom Goldstein · Bo Li · Avi Schwarzschild
Fri Dec 11

Tackling Climate Change with ML
David Dao · Evan Sherwin · Priya Donti · Yumna Yusuf · Lauren Kuntz · Lynn Kaack · David Rolnick · Catherine Nakalembe · Claire Monteleoni · Yoshua Bengio
Fri Dec 11

HAMLETS (Human And Machine in-the-Loop Evaluation and Learning Strategies)
Divyansh Kaushik · Bhargavi Paranjape · Forough Arabshahi · Yanai Elazar · Yixin Nie · Max Bartolo · Polina Kirichenko · Pontus Lars Erik Saito Stenetorp · Mohit Bansal · Zachary Lipton · Douwe Kiela
Sat Dec 12 08:15 AM — 08:00 PM (PST)

Self-Supervised Learning — Theory and Practice
Pengtao Xie · Shanghang Zhang · Pulkit Agrawal · Ishan Misra · Cynthia Rudin · Abdel-rahman Mohamed · Wenzhen Yuan · Barret Zoph · Laurens van der Maaten · Eric Xing
Sat Dec 12 08:50 AM — 06:40 PM (PST)

International Workshop on Scalability, Privacy, and Security in Federated Learning (SpicyFL 2020)
Xiaolin Andy Li · Dejing Dou · Ameet Talwalkar · Hongyu Li · Jianzong Wang · Yanzhi Wang
Sat Dec 12

Machine Learning for Engineering Modeling, Simulation and Design
Alex Beatson · Priya Donti · Amira Abdel-Rahman · Stephan Hoyer · Rose Yu · J. Zico Kolter · Ryan Adams
Sat Dec 12

Experiments with the ICML 2020 Peer-Review Process

This post is cross-listed on 

The International Conference on Machine Learning (ICML) is a flagship machine learning conference that in 2020 received 4,990 submissions and managed a pool of 3,931 reviewers and area chairs. Given that the stakes in the review process are high — the careers of researchers are often significantly affected by the publications in top venues — we decided to scrutinize several components of the peer-review process in a series of experiments. Specifically, in conjunction with the ICML 2020 conference, we performed three experiments that target: resubmission policies, management of reviewer discussions, and reviewer recruiting. In this post, we summarize the results of these studies.

Resubmission Bias

Motivation. Several leading ML and AI conferences have recently started requiring authors to declare previous submission history of their papers. In part, such measures are taken to reduce the load on reviewers by discouraging resubmissions without substantial changes. However, this requirement poses a risk of bias in reviewers’ evaluations.

Research question. Do reviewers get biased when they know that the paper they are reviewing was previously rejected from a similar venue?

Procedure. We organized an auxiliary conference review process with 134 junior reviewers from 5 top US schools and 19 papers from various areas of ML. We assigned each participant 1 paper and asked them to review it as if it were submitted to ICML. Unbeknown to participants, we allocated them to a test or control condition uniformly at random:

  • Control. Participants review the papers as usual.
  • Test. Before reading the paper, participants are told that the paper they review is a resubmission.

Hypothesis. We expect that if the bias is present, reviewers in the test condition should be harsher than in the control. 

Key findings. When told that a paper is a resubmission, reviewers give an overall score that is almost one point lower (95% Confidence Interval: [0.24, 1.30]) on a 10-point Likert item. Among the narrower review criteria, reviewers tend to underrate “Paper Quality” the most.

Implications. Conference organizers need to evaluate a trade-off between envisaged benefits such as the hypothetical reduction in the number of submissions and the potential unfairness introduced to the process by the resubmission bias. One option to reduce the bias is to postpone the moment in which the resubmission signal is revealed until after the initial reviews are submitted. This finding must also be accounted for when deciding whether the reviews of rejected papers should be publicly available on systems like and others. 


Herding Effects in Discussions

Motivation. Past research on human decision making shows that group discussion is susceptible to various biases related to social influence. For instance, it is documented that the decision of a group may be biased towards the opinion of the group member who proposes the solution first. We call this effect herding and note that, in peer review, herding (if present) may result in undesirable artifacts in decisions as different area chairs use different strategies to select the discussion initiator.

Research question. Conditioned on a set of reviewers who actively participate in a discussion of a paper, does the final decision of the paper depend on the order in which reviewers join the discussion?

Procedure. We performed a randomized controlled trial on herding in ICML 2020 discussions that involved about 1,500 papers and 2,000 reviewers. In peer review, the discussion takes place after the reviewers submit their initial reviews, so we know prior opinions of reviewers about the papers. With this information, we split a subset of ICML papers into two groups uniformly at random and applied different discussion-management strategies to them: 

  • Positive Group. First ask the most positive reviewer to start the discussion, then later ask the most negative reviewer to contribute to the discussion.
  • Negative Group. First ask the most negative reviewer to start the discussion, then later ask the most positive reviewer to contribute to the discussion.
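The two discussion-management strategies can be sketched in a few lines of Python. This is only an illustration of the randomization idea: the function name, the data layout (`initial_scores` mapping papers to reviewer scores), and the independent per-paper coin flip are assumptions made for this sketch, not the tooling actually used at ICML 2020.

```python
import random


def assign_discussion_order(initial_scores, seed=0):
    """Randomly place each paper in the Positive or Negative group.

    initial_scores: {paper_id: {reviewer_id: initial_score}} (hypothetical layout).
    Returns {paper_id: {"group": ..., "order": [first_reviewer, second_reviewer]}}.
    """
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    plan = {}
    for paper_id, scores in initial_scores.items():
        most_positive = max(scores, key=scores.get)
        most_negative = min(scores, key=scores.get)
        # Assumption: an independent fair coin flip per paper.
        group = rng.choice(["positive", "negative"])
        if group == "positive":
            order = [most_positive, most_negative]
        else:
            order = [most_negative, most_positive]
        plan[paper_id] = {"group": group, "order": order}
    return plan


# Toy input with made-up reviewer scores.
toy_scores = {
    "paper-1": {"r1": 7, "r2": 3, "r3": 5},
    "paper-2": {"r4": 2, "r5": 8},
}
toy_plan = assign_discussion_order(toy_scores, seed=1)
```

In both groups the same two reviewers are asked to contribute; only the order in which they are asked differs.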

Hypothesis. The only difference between the strategies is the order in which reviewers are supposed to join the discussion. Hence, if herding is absent, the strategies will not impact submissions from the two groups disproportionately. However, if herding is present, we expect the difference in order to introduce a difference in the acceptance rates across the two groups of papers.

Key findings. The analysis of outcomes of approximately 1,500 papers does not reveal a statistically significant difference in acceptance rates between the two groups of papers. Hence, we find no evidence of herding in the discussion phase of peer review.
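To make "no statistically significant difference in acceptance rates" concrete, here is a minimal sketch of one standard way to compare two acceptance rates, a pooled two-proportion z-test. The choice of test and the counts below are assumptions for illustration only; they are not the analysis or the data from the experiment.

```python
import math


def two_proportion_z_test(accepted_a, total_a, accepted_b, total_b):
    """Two-sided pooled z-test for equality of two proportions.

    Returns (z statistic, p-value) under the normal approximation.
    """
    rate_a = accepted_a / total_a
    rate_b = accepted_b / total_b
    pooled = (accepted_a + accepted_b) / (total_a + total_b)
    std_err = math.sqrt(pooled * (1 - pooled) * (1 / total_a + 1 / total_b))
    if std_err == 0:
        # All papers accepted or all rejected: the rates are identical.
        return 0.0, 1.0
    z = (rate_a - rate_b) / std_err
    # Two-sided tail probability of a standard normal variable.
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value


# Hypothetical counts, NOT the ICML 2020 numbers.
z_stat, p_val = two_proportion_z_test(170, 750, 160, 750)
```

A large p-value means the observed gap between the two groups is consistent with chance, which is the sense in which the experiment finds no evidence of herding.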

Implications. Although herding has been documented in other settings that involve group decisions, discussions in peer review do not appear to be susceptible to it. Hence, our results suggest that no specific measures to counteract herding in peer-review discussions are needed.


Novice Reviewer Recruiting

Motivation. A surge in the number of submissions received by leading ML and AI conferences has challenged the sustainability of the review process by increasing the burden on the pool of qualified reviewers. Leading conferences have been addressing the issue by relaxing the seniority bar for reviewers and inviting very junior researchers with limited or no publication history, but there is mixed evidence regarding the impact of such interventions on the quality of reviews.

Research question. Can very junior reviewers be recruited and guided such that they enlarge the reviewer pool of leading ML and AI conferences without compromising the quality of the process?

Procedure. We implemented a twofold approach towards managing novice reviewers:

  • Selection. We evaluated reviews written in the aforementioned auxiliary conference review process involving 134 junior reviewers, and invited 52 of these reviewers who produced the strongest reviews to join the reviewer pool of ICML 2020. Most of these 52 “experimental” reviewers come from the population not considered by the conventional way of reviewer recruiting used in ICML 2020.
  • Mentoring. In the actual conference, we provided these experimental reviewers with a senior researcher as a point of contact who offered additional mentoring.

Hypothesis. If our approach brings strong reviewers into the pool, we expect experimental reviewers to perform at least as well as reviewers from the main pool on various metrics, including the quality of reviews as rated by area chairs.

Key findings. The combination of the selection and mentoring mechanisms yields reviews of at least comparable, and on some metrics higher-rated, quality than those from the conventional reviewer pool: 30% of reviews written by the experimental reviewers exceeded the expectations of area chairs (compared to only 14% for the main pool).

Implications. The experiment received positive feedback from participants who appreciated the opportunity to become a reviewer in ICML 2020 and from authors of papers used in the auxiliary review process who received a set of useful reviews without submitting to a real conference. Hence, we believe that a promising direction is to replicate the experiment at a larger scale and evaluate the benefits of each component of our approach.



All in all, the experiments we conducted in ICML 2020 reveal some useful and actionable insights about the peer-review process. We hope that some of these ideas will help to design a better peer-review pipeline in future conferences.

We thank ICML area chairs, reviewers, and authors for their tremendous efforts. We would also like to thank the Microsoft Conference Management Toolkit (CMT) team for their continuous support and implementation of features necessary to run these experiments, the authors of papers contributed to the auxiliary review process for their responsiveness, and participants of the resubmission bias experiment for their enthusiasm. Finally, we thank Ed Kennedy and Devendra Chaplot for their help with designing and executing the experiments.

The post is based on joint works with Nihar B. Shah, Aarti Singh, Hal Daumé III, and Charvi Rastogi.

On Learning Language-Invariant Representations for Universal Machine Translation

Figure 1: An encoder-decoder generative model of translation pairs, which helps to circumvent the limitation discussed before. There is a global distribution \(\mathcal{D}\) over the representation space \(\mathcal{Z}\), from which sentences of language \(L_i\) are generated via decoder \(D_i\). Similarly, sentences could also be encoded via \(E_i\) to \(\mathcal{Z}\).

Despite the recent improvements in neural machine translation (NMT), training a large NMT model with hundreds of millions of parameters usually requires a collection of parallel corpora at a large scale, on the order of millions or even billions of aligned sentences for supervised training (Arivazhagan et al.). While it might be possible to automatically crawl the web to collect parallel sentences for high-resource language pairs, such as German-English and French-English, it is often infeasible or expensive to manually translate large amounts of sentences for low-resource language pairs, such as Nepali-English, Sinhala-English, etc. To this end, the goal of the so-called multilingual universal machine translation, a.k.a., universal machine translation (UMT), is to learn to translate between any pair of languages using a single system, given pairs of translated documents for some of these languages. The hope is that by learning a shared “semantic space” between multiple source and target languages, the model can leverage language-invariant structure from high-resource translation pairs to transfer to the translation between low-resource language pairs, or even enable zero-shot translation.

Indeed, training such a single massively multilingual model has achieved impressive empirical results, especially for low-resource language pairs (see Fig. 2). However, this success comes at a cost. From Fig. 2 we observe that the translation quality on high-resource language pairs from such a single UMT system is worse than that of the corresponding bilingual baselines.

Figure 2: Translation quality by using a single massively multilingual model against bilingual baselines that are trained for each one of the 103 language pairs. While the translation performances over low resource languages increase, the performances over high resource languages decrease. Our work provides a theoretical explanation for this empirical phenomenon. Figure credit: Exploring Massively Multilingual, Massive Neural Machine Translation.

Is this empirical phenomenon a coincidence? If not, why does it happen? Furthermore, what kind of structural assumptions about languages could help us overcome this detrimental effect? In this blog post, based on our recent ICML paper, we take the first step towards understanding universal machine translation by providing answers to the above questions. The key takeaways of this blog post can be summarized as follows:

  • In a completely assumption-free setup, based on a common shared representation, no matter what decoder is used to translate into the target languages, it is impossible to avoid making a large translation error on at least one of the translation pairs.
  • Under a natural generative model assumption for the data, after seeing aligned sentences for a linear number of language pairs (instead of quadratic!), we can learn encoder/decoders that perform well on any unseen language pair, i.e., zero-shot translation is possible.

An Impossibility Theorem on UMT via Language-Invariant Representations

Suppose we have an unlimited amount of parallel sentences for each pair of languages, with unbounded computational resources. Could we train a single model that performs well on all pairs of translation tasks based on a common representation space? Put another way, is there any information-theoretic limit of such systems for the task of UMT? In this section we will show that there is an inherent tradeoff between the translation quality and the degree of representation invariance w.r.t. languages: the better the language invariance, the higher the cost on at least one of the translation pairs. At a high level, this result holds due to the general data-processing principle: if a representation is invariant to multiple source languages, then any decoder based on this representation will have to generate the same language model on the target language. On the other hand, the parallel corpora we use to train such a system could have drastically different sentence distributions on the target language, thus leading to a discrepancy (error) between the generated sentence distribution and the ground-truth sentence distribution over the target language.

To keep our discussions simple and transparent, let’s start with a basic Two-to-One setup where there are only two source languages \(L_0\) and \(L_1\) and one target language \(L\). Furthermore, for each source language \(L_i, i\in\{0, 1\}\), let’s assume that there is a perfect translator \(f_{L_i\to L}^*\) that takes a sentence (or string, sequence) from \(L_i\) and outputs the corresponding translation in \(L\). Under this setup, it is easy to see that there exists a perfect translator \(f_L^*\) in this Two-to-One task: $$f_L^*(x) = \sum_{i\in\{0, 1\}}\mathbb{I}(x\in L_i)\cdot f_{L_i\to L}^*(x)$$ In words: upon receiving a sentence \(x\), \(f_L^*\) simply checks which source language \(x\) comes from and then calls the corresponding ground-truth translator.
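The dispatch structure of this Two-to-One translator can be sketched as follows. The toy vocabularies, the language detector, and all names here are hypothetical stand-ins; the sketch assumes the two source languages are disjoint so that exactly one indicator is active for each input.

```python
def make_two_to_one_translator(translators, detect_language):
    """Build f_L^* from per-language translators f_{L_i -> L}^*.

    translators: {language_id: function mapping a source sentence to a target sentence}
    detect_language: function returning the language_id of a sentence
    """
    def f_star(sentence):
        # The indicator I(x in L_i) is 1 for exactly one i; call that translator.
        language_id = detect_language(sentence)
        return translators[language_id](sentence)
    return f_star


# Toy "languages": tiny lookup tables standing in for perfect translators.
toy_l0 = {"hallo": "hello", "danke": "thanks"}
toy_l1 = {"bonjour": "hello", "merci": "thanks"}


def toy_detect(sentence):
    # Assumes the toy vocabularies do not overlap.
    return 0 if sentence in toy_l0 else 1


f_star = make_two_to_one_translator(
    {0: toy_l0.__getitem__, 1: toy_l1.__getitem__},
    toy_detect,
)
```

Note that this translator looks at the raw input to decide which branch to take, which is exactly the language-dependent information that a perfectly language-invariant representation discards.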

To make the idea of language-invariant representations formal, let \(g: \Sigma^*\to \mathcal{Z}\) be an encoder that takes a sentence (string) over alphabet \(\Sigma\) to a representation in a vector space \(\mathcal{Z}\). We call \(g\) an \(\epsilon\)-universal language mapping if the distributions of sentence representations from different languages \(L_0\) and \(L_1\) are \(\epsilon\)-close to each other. Formally, \(d(g_\sharp\mathcal{D}_0, g_\sharp\mathcal{D}_1)\leq \epsilon\) for some divergence measure \(d\), where \(g_\sharp\mathcal{D}_i\) is the induced distribution of sentence (from \(L_i\)) representations in the shared space \(\mathcal{Z}\). Subsequently, a multilingual system will train a decoder \(h\) that takes a sentence representation \(z\) and outputs the corresponding target translation in language \(L\). The hope here is that \(z\) encodes the language-invariant semantic information about the input sentence (either from \(L_0\) or from \(L_1\)) based on which to translate to the target language \(L\).

So far so good, but could we recover the perfect translator \(f_L^*\) by learning a common, shared representation \(Z\), i.e., one for which \(\epsilon\) is small? Unfortunately, the answer is negative if we make no assumptions about the parallel corpora we use to train our encoder \(g\) and decoder \(h\):

Theorem (informal): Let \(g:\Sigma^*\to\mathcal{Z}\) be an \(\epsilon\)-universal language mapping. Then for any decoder \(h:\mathcal{Z}\to \Sigma_L^*\), the following lower bound holds: $$\text{Err}_{\mathcal{D}_0}^{L_0\to L}(h\circ g) + \text{Err}_{\mathcal{D}_1}^{L_1\to L}(h\circ g)\geq d(\mathcal{D}_0(L), \mathcal{D}_1(L))- \epsilon.$$

Here the error term \(\text{Err}_{\mathcal{D}_i}^{L_i\to L}(h\circ g)\) measures the 0-1 translation error of the encoder-decoder pair \(h\circ g\) from \(L_i\) to \(L\) over distribution \(\mathcal{D}_i\). The first term \(d(\mathcal{D}_0(L), \mathcal{D}_1(L))\) in the lower bound measures the difference between the distributions over target-language sentences in the two parallel corpora, i.e., \(L_0\)-\(L\) and \(L_1\)-\(L\). For example, in many practical scenarios, the parallel corpus of a high-resource language pair, e.g., German-English, contains sentences from diverse domains, whereas the parallel corpus of a low-resource language pair, e.g., Sinhala-English, contains target translations only from a specific domain, e.g., sports, news, or product reviews. In this case, even though the target is the same language \(L\), the corresponding English sentence distributions differ substantially between the corpora, leading to a large lower bound. As a result, our theorem, which could be interpreted as a kind of uncertainty principle in UMT, says that no matter what decoder we use, it has to incur a large error on at least one of the translation pairs. It is also worth pointing out that our lower bound is algorithm-independent and holds even with unbounded computation and data. As a final note, for fixed distributions \(\mathcal{D}_i, i\in\{0, 1\}\), the smaller the \(\epsilon\) (hence the better the language-invariant representations), the larger the lower bound, demonstrating an inherent tradeoff between language invariance and translation performance in general.

Proof Sketch: Here we provide a proof-by-picture (Fig. 3) in the special case of perfectly language-invariant representations, i.e., \(\epsilon = 0\), to highlight the main idea in our proof of the above impossibility theorem. Please refer to our paper for a more detailed proof as well as an extension of the above impossibility theorem to the more general many-to-many translation setting.

Figure 3: Proof by picture: Language-invariant representation \(g\) induces the same feature distribution over \(\mathcal{Z}\), which leads to the same output distribution over the target language \(\Sigma_L^*\). However, the parallel corpora of the two translation tasks in general have different marginal distributions over the target language, hence a triangle inequality over the output distributions gives the desired lower bound.
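The triangle-inequality step in the picture can be checked numerically on a toy example. The sketch below assumes total variation distance as the divergence \(d\) (the post only says "some divergence measure") and uses made-up target-side distributions over two sentence types. It illustrates only the \(\epsilon = 0\) step, where both translation tasks are forced to share one output distribution; it is not the full theorem.

```python
def total_variation(p, q):
    """Total variation distance between two discrete distributions given as dicts."""
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(s, 0.0) - q.get(s, 0.0)) for s in support)


# Made-up target-language marginals of two parallel corpora:
# a sports-heavy corpus and a news-heavy corpus.
corpus_0 = {"sports": 0.9, "news": 0.1}
corpus_1 = {"sports": 0.2, "news": 0.8}
gap = total_variation(corpus_0, corpus_1)

# With a perfectly invariant representation, the decoder produces ONE output
# distribution for both tasks. Try a grid of candidates for that distribution.
candidates = [{"sports": t / 10, "news": 1 - t / 10} for t in range(11)]
smallest_total = min(
    total_variation(q, corpus_0) + total_variation(q, corpus_1)
    for q in candidates
)

# Triangle inequality: no shared output distribution can be close to both corpora.
assert smallest_total >= gap - 1e-9
```

However the shared output distribution is chosen, its distances to the two target marginals add up to at least the gap between those marginals, which is the source of the lower bound.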

How can we Bypass this Limitation?

One way is to allow the decoder \(h\) to have access to the input sentences (besides the language-invariant representations) during the decoding process, e.g., via an attention mechanism on the input level. Technically, such information flow from input sentences during decoding would break the Markov structure of “input-representation-output” in Fig. 3, which is an essential ingredient in the proof of our theorem. Intuitively, in this case both language-invariant (hence language-independent) and language-dependent information would be used.

Another way would be to assume extra structure on the distributions of our corpora \(\mathcal{D}_{i}\), i.e., by assuming some natural generative process capturing the distribution of the parallel corpora that are used for training. Since languages share a lot of semantic and syntactic characteristics, this would make a lot of sense, and intuitively, this is what universal translation approaches are banking on. In the next section, we will do exactly this: we will show that under a suitable generative model, not only will there be a language-invariant representation, but it will be learnable using corpora from a very small (linear) number of language pairs.

A Generative Model for UMT: A Linear Number of Translation Pairs Suffices!

In this section we will discuss a generative model under which a language-invariant representation not only exists but is also learnable using corpora from a very small (linear) number of language pairs. Note that there are a quadratic number of translation pairs in our universe, hence our result shows that under this generative model zero-shot translation is actually possible.

To start with, what kind of generative model is suitable for the task of UMT? Ideally, we would like to have a feature space where vectors correspond to the semantic encoding of sentences from different languages. One could also understand it as a sort of “meaning” space. Then, language-dependent decoders would take these semantic vectors and decode them as the observable sentences. Figure 1 illustrates the generative process of our model, where we assume there is a common distribution \(\mathcal{D}\) over the feature space \(\mathcal{Z}\), from which parallel sentences are sampled and generated.

For ease of presentation, let’s first assume that each encoder-decoder pair \((E_i, D_i)\) consists of deterministic mappings (see our paper for extensions to randomized encoders/decoders). The first question to ask is: how does this generative model assumption circumvent our previous lower bound? Under the encoder-decoder generative assumption in Figure 1, the first term in our lower bound, \(d(\mathcal{D}_0(L), \mathcal{D}_1(L))\), reduces to 0. Hence even if we learn perfectly language-invariant representations (\(\epsilon = 0\)), there is no loss of translation accuracy from using a universal language mapping. Perhaps more interesting is that, under proper assumptions on the structure of \(\mathcal{F}\), the class of encoders and decoders we learn from, if we use the traditional empirical risk minimization (ERM) framework to learn the language-dependent encoders and decoders on a small number of language pairs, we can expect the learned encoders/decoders to generalize well to unseen language pairs too! Informally,

Theorem (informal): Let \(H\) be a connected graph where each node \(L_i\) corresponds to a language and each edge \((L_i, L_j)\) means that the learner has been trained on the language pair \(L_i\) and \(L_j\), with empirical translation error \(\epsilon_{i,j}\) and a corpus of size \(\Omega(1 / \epsilon_{i,j}^2 \cdot \log C(\mathcal{F}))\). Then with high probability, for any pair of languages \(L\) and \(L'\) that are connected by a path \(L = L_0, L_1, \ldots, L_m = L'\) in \(H\), the population-level translation error is upper bounded by \(O(\sum_{k=0}^{m-1}\epsilon_{k,k+1})\).

In the theorem above, \(C(\mathcal{F})\) is some complexity measure of the class \(\mathcal{F}\). If we slightly simplify the theorem by defining \(\epsilon := \max_{(L_i, L_j) \in H} \epsilon_{i,j}\) and noting that the length \(m\) of a shortest path is upper bounded by the diameter of the graph \(H\), \(\text{diam}(H)\), we immediately obtain the following intuitive result:

For any pair of languages \(L, L'\) (the parallel corpus between \(L\) and \(L'\) may not necessarily appear in our training corpora), the translation error between \(L\) and \(L'\) is upper bounded by \(O(\text{diam}(H) \cdot \epsilon)\).
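One way to see how this follows from the theorem, assuming we pick a shortest path between \(L\) and \(L'\) in \(H\) (so that \(m \le \text{diam}(H)\)) and write \(\epsilon\) for the largest per-edge error \(\max_{(L_i, L_j) \in H} \epsilon_{i,j}\), is to bound every term in the sum by that maximum:

\[
\sum_{k=0}^{m-1} \epsilon_{k,k+1} \;\le\; m \cdot \max_{(L_i, L_j) \in H} \epsilon_{i,j} \;=\; m \cdot \epsilon \;\le\; \text{diam}(H) \cdot \epsilon .
\]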

The above corollary says that graphs \(H\) that do not have long paths are preferable. For example, \(H\) could be a star graph, where a central (high-resource) language acts as a pivot node. The proof of the theorem above essentially boils down to two steps: first, we use an epsilon-net argument to show that the learned encoders/decoders generalize on a language pair that appears in our training corpora, and then, using the connectivity of the graph \(H\), we apply a chain of triangle-like inequalities to bound the error along the path connecting any pair of languages.
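To illustrate why short paths help, the sketch below compares a star graph with a chain over the same five languages. The language codes, the per-edge error of 0.05, and the helper names are all hypothetical, and the hidden constant in the \(O(\cdot)\) bound is ignored: the code simply sums the per-edge errors along a shortest path found by breadth-first search.

```python
from collections import deque


def shortest_path(graph, src, dst):
    """Breadth-first search; returns the list of nodes on a shortest path."""
    parent = {src: None}
    queue = deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            return path[::-1]
        for nxt in graph[node]:
            if nxt not in parent:
                parent[nxt] = node
                queue.append(nxt)
    raise ValueError("languages are not connected in H")


def path_error_bound(edge_errors, src, dst):
    """Return (sum of per-edge errors along a shortest path, number of hops)."""
    graph = {}
    for a, b in edge_errors:
        graph.setdefault(a, []).append(b)
        graph.setdefault(b, []).append(a)
    path = shortest_path(graph, src, dst)
    total = 0.0
    for a, b in zip(path, path[1:]):
        # Edges are undirected, so look the error up in either orientation.
        total += edge_errors.get((a, b), edge_errors.get((b, a)))
    return total, len(path) - 1


EPS = 0.05  # hypothetical empirical error on every trained pair

# Star graph: "en" is the pivot language.
star = {("en", "de"): EPS, ("en", "fr"): EPS, ("en", "zh"): EPS, ("en", "hi"): EPS}
# Chain graph over the same languages.
chain = {("de", "fr"): EPS, ("fr", "en"): EPS, ("en", "zh"): EPS, ("zh", "hi"): EPS}

star_bound, star_hops = path_error_bound(star, "de", "hi")
chain_bound, chain_hops = path_error_bound(chain, "de", "hi")
print("star :", star_hops, "hops, summed error", star_bound)
print("chain:", chain_hops, "hops, summed error", chain_bound)
```

Both graphs use four training pairs, but the unseen pair de-hi is two hops apart in the star and four hops apart in the chain, so the summed error along the path is twice as large for the chain.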

Some Concluding Thoughts

The prospect of building a single system for universal machine translation is appealing. Compared with building a quadratic number of bilingual translators, such a single system is easier to train, build, deploy, and maintain. More importantly, this could potentially allow the system to transfer some common knowledge in translation from high-resource languages to low-resource ones. However, such promise often comes with a price, which calls for proper assumptions on the generative process of the parallel corpora used for training. Our paper takes a first step towards better understanding the tradeoff in this regard and proposes a simple setup that allows for zero-shot translation. On the other hand, there are still some gaps between theory and practice. For example, it would be interesting to see whether the BLEU score, a metric used in the empirical evaluation of translation quality, bears a similar kind of lower bound. Also, could we further extend our generative modeling of sentences so that there are more hierarchical structures in the semantic space \(\mathcal{Z}\)? Empirically, it would be interesting to implement the above generative model on synthetic data to see the actual performance of zero-shot translation under the model assumption. These challenging problems (and more) will require collaborative efforts from a wide range of research communities and we hope our initial efforts could inspire more efforts in bridging the gap.


  1. Massively Multilingual Neural Machine Translation in the Wild: Findings and Challenges, Arivazhagan et al.
  2. Investigating Multilingual NMT Representations at Scale, Kudugunta et al., EMNLP 2019.
  3. On Learning Language-Invariant Representations for Universal Machine Translation, Zhao et al., ICML 2020.
  4. The Source-Target Domain Mismatch Problem in Machine Translation, Shen et al.
  5. How multilingual is Multilingual BERT? Pires et al., ACL 2019.

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
