Tracking Any Pixel in a Video

We upgrade pixels into PIPs: “Persistent Independent Particles”. With this representation, we track any pixel over time, and overcome visibility issues with a learned temporal prior.

Motion estimation is a fundamental task of computer vision, with extremely broad applications. By tracking something, you can build models of its various properties: shape, texture, articulation, dynamics, affordances, and so on. More fine-grained tracking allows more fine-grained understanding. For robots, fine-grained tracking also enables fine-grained manipulation. Even setting aside downstream AI-related applications, motion tracks are directly useful for video editing applications — making realistic edits to a person or object in a video demands precise-as-possible tracking of the pixels, across an indefinite timespan.

There are a variety of methods for tracking objects (at the level of segmentation masks or bounding boxes), or for tracking certain points in certain categories (e.g., the joints of a person), but there are actually very few options for general-purpose fine-grained tracking. In this domain, the dominant approaches are feature matching and optical flow. The feature matching approach is: compute a feature for the target on the first frame, then compute features for pixels in other frames, and then compute “matches” using feature similarity (i.e., nearest neighbors). This often works well, but does not take into account temporal context, like smoothness of motion. The optical flow approach is: compute a dense “motion field” that relates each pair of frames, and then do some post-processing to link the fields together. Optical flow is very powerful, but since it only describes motion for a pair of frames at a time, it cannot produce useful outputs for targets that undergo multi-frame occlusion. “Occlusion” means our view is obstructed, and we need to guess the target’s location from context.

During an occlusion, appearance information does not suffice, because the target is not even present in the frames. Multi-frame temporal priors are key.

Around the year 2006, Peter Sand and Seth Teller proposed an alternative to flow-based and feature-based methods, called a “particle video.” This approach aims to represent a video with a set of particles that move across multiple frames. Their proposed method did not handle occlusions, but in our view, they laid the groundwork for treating pixels as persistent entities, with multi-frame trajectories and long-range temporal priors.

Inspired by their work, we propose Persistent Independent Particles (PIPs), a new particle video method. Our method takes a video as input, along with the \((x, y)\) coordinate of a target to track, and produces the target’s trajectory as output. The model can be queried for any number of particles, at any positions.

Particle trajectories for arbitrary “target” pixels in the video.

You may notice in our visualizations that the PIP trajectories often leave the video bounds. We treat pixels flying out-of-bounds as just another “occlusion.” This type of robustness is exactly what is missing from feature-based and flow-based methods.

Let’s step through how we achieved this.

How does it work?

At a high level, our method makes an extreme trade-off between spatial awareness and temporal awareness: we estimate the trajectory of every target independently. This extreme choice allows us to devote most of the model’s parameters to a module that simultaneously learns (1) temporal priors and (2) an iterative inference mechanism that searches for the target pixel’s location in all input frames. Related work on optical flow estimation typically takes the opposite approach: it estimates the motion of all pixels simultaneously (using maximal spatial context), for just 2 frames at a time (using minimal temporal context).

Given the \((x_1, y_1)\) coordinate of the target on the first frame, our concrete goal is to estimate the target’s full trajectory over \(T\) frames: \((x_1, y_1), \ldots, (x_T, y_T)\).

We start by initializing a zero-velocity estimate. This means copying the initial coordinate to every timestep.

To track the target, we need to know what it looks like, so we compute appearance features for all the frames (using a CNN), and initialize the target’s appearance trajectory with a bilinear sample at the given coordinate on the first frame. (The “bilinear sample” step extracts a feature vector at a subpixel location in the spatial map of features.)
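To make the initialization concrete, here is a minimal pure-Python sketch (not the released PIPs code). It assumes a feature map stored as nested lists indexed `fmap[y][x][channel]` and coordinates already expressed in feature-map units (i.e., divided by the CNN’s stride). The names `init_trajectory`, `bilinear_sample`, and `init_appearance` are hypothetical, and clamping coordinates to the map border is our own choice.

```python
from typing import List, Tuple

# Assumed layout for this sketch: fmap[y][x][channel].
FeatureMap = List[List[List[float]]]
Point = Tuple[float, float]


def init_trajectory(x1: float, y1: float, T: int) -> List[Point]:
    """Zero-velocity initialization: copy the first-frame coordinate to all T timesteps."""
    return [(x1, y1) for _ in range(T)]


def bilinear_sample(fmap: FeatureMap, x: float, y: float) -> List[float]:
    """Return the feature vector at a subpixel location (x, y).

    Coordinates are clamped to the map bounds (an assumption of this sketch).
    """
    H, W = len(fmap), len(fmap[0])
    x = min(max(x, 0.0), W - 1.0)
    y = min(max(y, 0.0), H - 1.0)
    x0, y0 = int(x), int(y)                          # top-left integer neighbor
    xr, yb = min(x0 + 1, W - 1), min(y0 + 1, H - 1)  # right and bottom neighbors
    wx, wy = x - x0, y - y0                          # fractional offsets
    out = []
    for c in range(len(fmap[0][0])):
        top = (1 - wx) * fmap[y0][x0][c] + wx * fmap[y0][xr][c]
        bottom = (1 - wx) * fmap[yb][x0][c] + wx * fmap[yb][xr][c]
        out.append((1 - wy) * top + wy * bottom)
    return out


def init_appearance(first_fmap: FeatureMap, x1: float, y1: float, T: int) -> List[List[float]]:
    """Initialize the appearance trajectory by repeating the first-frame feature T times."""
    feat = bilinear_sample(first_fmap, x1, y1)
    return [list(feat) for _ in range(T)]
```

For example, sampling the 2x2 map `[[[0.0], [1.0]], [[2.0], [3.0]]]` at (0.5, 0.5) weights all four cells equally and returns `[1.5]`.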

Our inference process will proceed by iteratively refining the sequence of positions, and sequence of appearance features, until they (hopefully) match the true trajectory of the target. This idea is illustrated in the video below: at initialization, only “timestep 1” has the correct location of the target (since this was given), and gradually, the model “locks on” to the target in all frames.

Our model’s job is to produce updates (i.e., deltas) for the positions and features, so that the trajectory tracks the target on every frame. A critical detail here is that we ask our model to produce these updates for multiple timesteps simultaneously. This allows us to “catch” a target after it re-emerges from an occluder, and “fill in” the missing part of the trajectory.
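The loop below sketches the structural point of this paragraph: deltas are predicted for all timesteps jointly, not frame by frame. The learned model is replaced by a hypothetical hand-written `make_toy_update`, which only illustrates how a joint update can pull an occluded timestep into place from its neighbors. It is not how PIPs computes its updates, and for brevity it refines positions only, not features.

```python
from typing import Callable, List, Optional, Tuple

Point = Tuple[float, float]
# Stand-in for the learned model: it sees the whole trajectory at once and
# returns one (dx, dy) delta per timestep.
UpdateFn = Callable[[List[Point]], List[Point]]


def refine(traj: List[Point], update_fn: UpdateFn, iters: int) -> List[Point]:
    """Apply jointly predicted deltas to every timestep, `iters` times.

    Timestep 1 stays fixed because its location is given.
    """
    traj = list(traj)
    for _ in range(iters):
        deltas = update_fn(traj)
        traj = [traj[0]] + [
            (x + dx, y + dy) for (x, y), (dx, dy) in zip(traj[1:], deltas[1:])
        ]
    return traj


def make_toy_update(obs: List[Optional[Point]]) -> UpdateFn:
    """Hypothetical hand-written update rule, NOT the learned PIPs model.

    Visible timesteps move halfway toward their observation; occluded interior
    timesteps (obs[t] is None) move halfway toward the midpoint of their
    neighbors, a crude stand-in for a temporal prior.
    """
    def update(traj: List[Point]) -> List[Point]:
        deltas: List[Point] = []
        for t, (x, y) in enumerate(traj):
            ob = obs[t]
            if ob is not None:
                deltas.append((0.5 * (ob[0] - x), 0.5 * (ob[1] - y)))
            elif 0 < t < len(traj) - 1:
                mx = (traj[t - 1][0] + traj[t + 1][0]) / 2
                my = (traj[t - 1][1] + traj[t + 1][1]) / 2
                deltas.append((0.5 * (mx - x), 0.5 * (my - y)))
            else:
                deltas.append((0.0, 0.0))
        return deltas
    return update
```

With observations `[(0, 0), (1, 0), None, (3, 0)]` (the target is occluded at the third timestep), 50 iterations starting from the zero-velocity trajectory converge to approximately `[(0, 0), (1, 0), (2, 0), (3, 0)]`, filling in the occluded step.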

There are many ways to implement this, but for fast training and good generalization, we need to carefully select what information we provide to the model.

The main source of information we provide to the model is: measurements of local appearance similarity. We obtain these measurements using cross-correlation (i.e., dot products) computed at multiple scales. When the target is visible, it should show up as a strong peak in at least one of the similarity maps, and when we are “locked on”, the peak should be in the middle of the map. This is illustrated in the animation below.
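A sketch of such multi-scale similarity measurements follows, under simplifying assumptions of our own: coarser scales are built by average pooling, and each window is read at the nearest pixel rather than sampled bilinearly at subpixel offsets. The function names and the `radius` and `scales` parameters are hypothetical. When the estimate is locked on, the largest value should sit at the center of each window.

```python
from typing import List, Sequence

FeatureMap = List[List[List[float]]]  # fmap[y][x][channel], an assumed layout


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


def avg_pool(fmap: FeatureMap, s: int) -> FeatureMap:
    """Downsample by averaging non-overlapping s x s blocks (remainder rows/cols dropped)."""
    H, W, C = len(fmap) // s, len(fmap[0]) // s, len(fmap[0][0])
    return [
        [
            [sum(fmap[y * s + i][x * s + j][c] for i in range(s) for j in range(s)) / (s * s)
             for c in range(C)]
            for x in range(W)
        ]
        for y in range(H)
    ]


def local_correlation(fmap: FeatureMap, target: Sequence[float],
                      x: float, y: float, radius: int) -> List[List[float]]:
    """Dot products in a (2r+1) x (2r+1) window centered on the pixel nearest (x, y).

    Positions outside the map get 0.0.
    """
    H, W = len(fmap), len(fmap[0])
    cx, cy = round(x), round(y)
    window = []
    for dy in range(-radius, radius + 1):
        row = []
        for dx in range(-radius, radius + 1):
            px, py = cx + dx, cy + dy
            row.append(dot(fmap[py][px], target) if 0 <= px < W and 0 <= py < H else 0.0)
        window.append(row)
    return window


def multiscale_correlation(fmap: FeatureMap, target: Sequence[float], x: float, y: float,
                           radius: int = 1,
                           scales: Sequence[int] = (1, 2)) -> List[List[List[float]]]:
    """One correlation window per scale; (x, y) is rescaled to each pooled map."""
    return [local_correlation(avg_pool(fmap, s), target, x / s, y / s, radius) for s in scales]
```

For instance, in a 4x4 map whose only matching feature is at (2, 1), querying at (2, 1) gives a scale-1 window with `1.0` at its center and zeros elsewhere, while the pooled scale-2 window shows a weaker, blurred response.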

The second source of information we provide to the model is: the estimated trajectory itself. This allows the model to impose a temporal prior, and fix up parts of the trajectory where the local similarity information was ambiguous.

Finally, we allow the model to inspect the feature vector of the target, in case it might learn different strategies for different types of features. For example, depending on the scale or texture of the target, it may adjust the way it uses information from the multi-scale similarity maps.

For the model architecture, we elected to use an MLP-Mixer, which we found to have a good trade-off between model capacity, training time, and generalization. We also tried convolutional models and transformers, but the convolutional models could not fit the data as well as the MLP-Mixer, and the transformers took too long to train.

We trained the model on synthetic data that we made (based on an existing optical flow dataset), where we could provide multi-frame ground truth for targets that undergo occlusions. The animation below shows the kind of data we trained on. You might say the data looks crazy, but that’s the point! If you can’t get real data, your best bet is synthetic data with extremely high diversity.

FlyingThings++: We train our model to track objects in this data, so that real videos are easy in comparison.

After training for a couple of days on this data, the model starts to work on real videos.


In the paper, we provide some quantitative analysis showing that it works better than existing methods. It’s certainly not perfect: on keypoint tracking, the model succeeds about 6 times out of 10, so there is a lot of room for improvement. The baselines succeed around 5 times out of 10 or less.

The idea here is: pick a point on the first frame, and try to locate that same point in other frames of the video. This is a hard task, especially when the point gets occluded.

Output of our PIPs model.

Baseline methods tend to get stuck on occluders, since they do not use multi-frame temporal context. For example, here is the output of a state-of-the-art optical flow method, on the same video and same target.

Output of an optical flow model (RAFT).

We have had fun trying the model on various videos, and observing the estimated trajectories. Sometimes they are surprisingly complex, since the target’s actual motion is subtly entangled with camera motion.

Visualizing the trajectories more densely gives mesmerizing results. Notice that the model even tracks ripples and specularities in the water.

Despite the fact that each particle trajectory is estimated independently of the others, they show surprisingly accurate grouping. Notice that the background particles all move together.

What’s next?

Our method upgrades pixels into PIPs: “Persistent Independent Particles.” This independence assumption, however, is probably not what we want in general. In ongoing work, we are trying to incorporate context across particles, so that confident particles can help the unconfident ones, and so that we track at multiple levels of granularity simultaneously.

We have released our code and model weights on GitHub. We encourage you to try it out! If you are interested in building on our method, the provided tests, visualizations, and training scripts should make that easy. Or, if you are working on a video-based method that currently relies on optical flow, you may want to try our PIP trajectories as a replacement, which should give a better signal under occlusions.

We hope that our work opens up long-range fine-grained tracking of “anything.” For more details, please check out our project page and paper.

Long-term Dynamics of Fairness Intervention in Connection Recommender Systems

Figure 1: Modeled recommendation cycle. We demonstrate how enforcing group fairness in every recommendation slate separately does not necessarily promote equity in second-order variables of interest like network size.

Connection recommendation is at the heart of user experience in many online social networks. Given a prompt such as ‘People you may know’, connection recommender systems suggest a list of users, and the recipient of the recommendation decides which of the users to connect with. In some instances, connection recommendations can account for more than 50% of the social network graph [1]. Depending on the platform, being connected to the right people is tied to important advantages such as job opportunities or increased visibility. While this makes it imperative to treat users fairly, it is far from obvious how fairness can be enforced or what it even means to have a ‘fair’ system in this scenario. In fact, when enforcing fairness in dynamic systems like this, we are often interested in second-order variables, while interventions target equity in single steps [2]. In connection recommendation, this generally means enforcing some sort of parity condition in each recommendation slate, e.g., equal exposure of female and male users for every query, while hoping that this will promote more equitable network sizes in the long run. In our work, we demonstrate that this approach can fail: common statistical notions of recommendation fairness do not ensure equity of network sizes in the long run. In fact, we see that fairness interventions are not even sufficient to mitigate the amplification of existing biases. We reach these conclusions based on an extensive simulation study supplemented by theoretical limit analyses. In the following, we focus on a discussion of empirical results and refer to the full paper for theoretical derivations.

Recommendation procedure

The assumed recommendation cycle is depicted in Figure 1. We briefly summarize the involved steps and give more detail on the most important components in the following sections. On the left in the flowchart, a user requests connection recommendations, for example, by loading the ‘People You May Know’ page on the platform’s website. The system then computes relevance scores between the user seeking recommendations, whom we will refer to as the source user, and other previously unconnected users, whom we will refer to as destination users. Based on these relevance scores, we derive a ranking of destination users subject to fairness constraints and display a list of possible connections to the source user. After new connections have been made, the cycle repeats.

Fairness constrained probabilistic ranking framework

For each recommendation query, the system obtains a set of relevance scores and derives a probabilistic ranking. This probabilistic ranking takes the form of a matrix \(P_s^q\) whose entry \(P_s^q(d,r)\) in the \(d\)th row and \(r\)th column denotes the probability of displaying member \(d\) in slot \(r\) of the recommendation list. The probabilities are selected to maximize the expected utility for the source user, which is common practice in the recommendation literature [3].

More formally, we denote the source member as \(s\) and the destination member as \(d\). \(q\) denotes a query for an ordered list of \(m\) destination members, \(D_q\) the number of candidate destination members for that query, and \(u_{s,d}^q\) the relevance score of \(s\) and \(d\) under query \(q\). We model a form of position bias by discounting the expected attention a recommended user receives based on how far down the list the recommendation appears, i.e., we define the exposure of slot \(r\) as \(v_r = v(r) = 1/\log(r+1)\). Given these quantities, the optimization problem for query \(q\) can be written as


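$$\max_{P_s^q}\ \sum_{d=1}^{D_q}\sum_{r=1}^{m}u_{s,d}^q\,P_s^q(d,r)\,v_r$$

$$\text{s.t.}\ \sum_{r=1}^{m}P_s^q(d,r)\leq 1\ \text{ for all } d\in[D_q],$$

$$\sum_{d=1}^{D_q}P_s^q(d,r)=1\ \text{ for all } r\in[m],$$

$$0\leq P_s^q(d,r)\leq 1\ \text{ for all } d\in[D_q],\ r\in[m].$$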

The constraints in this problem ensure that each slot has a recommended user and each user can be recommended at most once in a given list. We explore two types of fairness constraints to add to this problem. First, we consider demographic parity of exposure, which is a commonly suggested form of fairness in recommendations. It requires that groups receive expected exposure proportional to their population shares, i.e., for groups \(G_0\) and \(G_1\),

$$\frac{1}{\vert G_0\vert}\sum_{d\in G_0}\sum_{r=1}^{m}P_s^q(d,r)\,v_r=\frac{1}{\vert G_1\vert}\sum_{d\in G_1}\sum_{r=1}^{m}P_s^q(d,r)\,v_r.$$

Assume a setting in which there is no position bias (i.e., \(v_r = 1\) for every slot), the population is split into a 60% majority group and a 40% minority group, and each recommendation list has 10 slots. In this setting, demographic parity of exposure requires that, in expectation, 6 of the recommended members belong to the majority group and 4 belong to the minority group in every recommendation list.
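The ranking constraints and the exposure parity condition are easy to check numerically. The sketch below uses hypothetical helper names (`slot_exposure`, `is_feasible`, `expected_utility`, `per_capita_exposure`) and assumes the natural logarithm for the position-bias discount, since the base is not stated; the base only rescales every \(v_r\) and does not affect parity. It does not solve the optimization problem; it only evaluates a given probabilistic ranking.

```python
import math
from typing import List, Sequence

# P[d][r]: probability that candidate d is shown in slot r (0-indexed here).
Matrix = List[List[float]]


def slot_exposure(m: int, position_bias: bool = True) -> List[float]:
    """v_r = 1 / log(r + 1) for slots r = 1..m (natural log assumed), or 1.0 per slot."""
    return [1.0 / math.log(r + 1) if position_bias else 1.0 for r in range(1, m + 1)]


def is_feasible(P: Matrix, tol: float = 1e-9) -> bool:
    """Each candidate shown at most once, each slot filled exactly once, entries in [0, 1]."""
    m = len(P[0])
    rows_ok = all(sum(row) <= 1 + tol for row in P)
    cols_ok = all(abs(sum(row[r] for row in P) - 1) <= tol for r in range(m))
    bounds_ok = all(-tol <= p <= 1 + tol for row in P for p in row)
    return rows_ok and cols_ok and bounds_ok


def expected_utility(P: Matrix, u: Sequence[float], v: Sequence[float]) -> float:
    """Expected utility: sum over candidates d and slots r of u_d * P(d, r) * v_r."""
    return sum(u[d] * P[d][r] * v[r] for d in range(len(P)) for r in range(len(v)))


def per_capita_exposure(P: Matrix, v: Sequence[float], members: Sequence[int]) -> float:
    """(1 / |G|) * sum over d in G and slots r of P(d, r) * v_r."""
    return sum(P[d][r] * v[r] for d in members for r in range(len(v))) / len(members)
```

With 30 majority and 20 minority candidates (a 60%/40% split), 10 slots, and `slot_exposure(10, position_bias=False)`, a ranking that assigns probability 1/50 to every candidate-slot pair is feasible, gives both groups the same per-capita exposure (about 0.2), and shows 30 × 0.2 = 6 majority members per list in expectation, matching the example.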

Second, we explore a dynamic parity of utility constraint. As opposed to the first constraint, this fairness measure considers not only the exposure of groups but also the total utility that can be expected from the exposure, i.e., for two groups \(G_0\) and \(G_1\),

$$\frac{1}{\vert G_0\vert}\sum_{d\in G_0}u_{s,d}^q\sum_{r=1}^{m}P_s^q(d,r)\,v_r=\frac{1}{\vert G_1\vert}\sum_{d\in G_1}u_{s,d}^q\sum_{r=1}^{m}P_s^q(d,r)\,v_r.$$

At a high level, the constraint requires that the exposure-weighted sum of relevance scores each group receives is proportional to that group’s population share. To understand what this means, let us consider the example setting from before with no position bias, a 60%/40% group split, and recommendation lists with 10 slots. Assume that all destination members of the majority group have a relevance score of 0.12, while all destination members of the minority group have a relevance score of 0.08. The probabilistic ranking fulfills dynamic parity of utility if the expected group split in the recommendation list is 50%/50%, because \(0.5 \cdot 10 \cdot 0.12 = 0.6\) and \(0.5 \cdot 10 \cdot 0.08 = 0.4\), which matches the 60%/40% population split.
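The arithmetic in this example can be reproduced with a short sketch. Under the example’s assumptions (no position bias, one fixed relevance score per group), parity of utility requires `relevance[g] * E[g] / shares[g]` to be equal across groups, where `E[g]` is the expected number of slots given to group `g`; hence `E[g]` is proportional to `shares[g] / relevance[g]`. The function names are hypothetical.

```python
from typing import List, Sequence


def utility_parity_split(shares: Sequence[float], relevance: Sequence[float],
                         slots: int) -> List[float]:
    """Expected slots per group under dynamic parity of utility.

    Assumes no position bias and a single fixed relevance score per group.
    Parity requires relevance[g] * E[g] / shares[g] to be equal for all groups,
    so E[g] is proportional to shares[g] / relevance[g].
    """
    weights = [s / u for s, u in zip(shares, relevance)]
    total = sum(weights)
    return [slots * w / total for w in weights]


def group_utility(split: Sequence[float], relevance: Sequence[float]) -> List[float]:
    """Total expected relevance each group receives from its slots."""
    return [e * u for e, u in zip(split, relevance)]
```

With shares `[0.6, 0.4]`, relevance `[0.12, 0.08]`, and 10 slots, the split is approximately `[5.0, 5.0]` and the group utilities approximately `[0.6, 0.4]`. Setting both relevance scores equal recovers the 6/4 split required by demographic parity of exposure.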

How do we model relevance scores?

Relevance scores usually model the probability of connection if recommended, a measure of downstream engagement, or some mixture of the two. The exact models employed by social media platforms are usually proprietary, so we opt for a synthetic model in the form of a logistic regression on three realistic features [4,5]. Assuming our prediction target is the probability of connection, we first use the source member’s network size. This is based on the assumption that users with larger networks are more likely to be proactive in forming connections, as they are generally more active on the platform. Next, we assume that users with more common connections are more likely to connect. The common connections feature is rooted in the social network literature, where it is commonly referred to as triadic closure [e.g. 6,7]. Lastly, we assume that users with similarities such as demographics, interests, education, workplace, etc. are more likely to connect. This follows the observation that individuals like to be connected to similar individuals, commonly referred to as homophily in sociology and other social sciences [e.g. 8].
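A sketch of such a relevance model follows. The three features match the description above, and similarity is the negative Euclidean distance between covariate vectors, as in the simulation procedure. The weight values, the use of raw (untransformed) features, and the function names are hypothetical, since the ground-truth parameters are not given here.

```python
import math
from typing import Sequence, Tuple

# Hypothetical ground-truth weights: (intercept, network size, common connections, similarity).
DEFAULT_WEIGHTS: Tuple[float, float, float, float] = (-4.0, 0.01, 0.5, 1.0)


def similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Negative Euclidean distance between two members' covariate vectors."""
    return -math.dist(a, b)


def relevance_score(source_network_size: int, common_connections: int, sim: float,
                    weights: Tuple[float, float, float, float] = DEFAULT_WEIGHTS) -> float:
    """Logistic regression: predicted probability that s and d connect if d is recommended."""
    b, w_size, w_common, w_sim = weights
    z = b + w_size * source_network_size + w_common * common_connections + w_sim * sim
    return 1.0 / (1.0 + math.exp(-z))
```

With positive weights, the score increases with the source member’s network size, with the number of common connections, and with similarity, mirroring the three assumptions in the text.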

The main simulation procedure uses the relevance scoring model for user pairs as follows. We assume a fixed-size graph of evolving connections with two groups of members; 65% of members belong to an initially more connected majority group. First, we define a ground-truth model for matching scores by fixing ground-truth parameter values in the logistic regression function presented above. We use this function to simulate a data set of recommendations and formed connections, and we use the synthetic data to train a logistic regression matching scores model.

Simulation procedure

We sample covariate vectors from group-dependent distributions for each member in the graph. These vectors are used to compute the similarity between members, which we define as the negative Euclidean distance. The connections in the graph are initialized with a stochastic block model in which user pairs in the majority group are slightly more likely to connect than user pairs in the minority group, and user pairs who belong to the same group are slightly more likely to connect than user pairs who belong to opposite groups. This models the initial advantage for majority group members we would expect to see in practice and accounts for homophily preferences. Given the initial specifications, we go through the previously described recommendation cycle for 2,500 timesteps and keep track of key fairness and performance metrics. Following our reasoning from the scoring model, the frequency with which users seek out recommendations is governed by exponential waiting times whose mean depends on the user’s current network size. The whole simulation procedure is repeated with each fairness intervention separately and without intervention for comparison. Results are averaged over 10 runs.
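The graph initialization and query timing can be sketched as follows. The block probabilities, the base mean, and the form `base_mean / (1 + degree)` for the mean waiting time are hypothetical placeholders. The text only states that majority pairs are slightly more likely to connect than minority pairs, that same-group pairs are slightly more likely to connect than cross-group pairs, and that the mean waiting time depends on network size; having larger networks query more often is our reading of the scoring-model rationale.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]


def init_sbm(n: int, majority_share: float, p_maj: float, p_min: float,
             p_between: float, rng: random.Random) -> Tuple[List[int], Set[Edge]]:
    """Stochastic block model: group 0 = majority, group 1 = minority.

    The text implies p_maj > p_min and within-group > between-group; the
    exact values are left to the caller.
    """
    n_maj = round(n * majority_share)
    group = [0] * n_maj + [1] * (n - n_maj)
    edges: Set[Edge] = set()
    for i in range(n):
        for j in range(i + 1, n):
            if group[i] != group[j]:
                p = p_between
            else:
                p = p_maj if group[i] == 0 else p_min
            if rng.random() < p:
                edges.add((i, j))
    return group, edges


def degrees(n: int, edges: Set[Edge]) -> List[int]:
    """Network size (degree) of every member."""
    deg = [0] * n
    for i, j in edges:
        deg[i] += 1
        deg[j] += 1
    return deg


def next_query_time(now: float, degree: int, rng: random.Random,
                    base_mean: float = 10.0) -> float:
    """Time of the member's next recommendation query (exponential waiting time).

    Hypothetical form: the mean shrinks as the network grows, so more
    connected (more active) members query more often.
    """
    mean = base_mean / (1 + degree)
    return now + rng.expovariate(1.0 / mean)
```

For example, `init_sbm(200, 0.65, 0.06, 0.05, 0.04, random.Random(0))` builds a 200-member graph with the 65% majority share used in the study and illustrative block probabilities.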

Figure 2: Results of the simulation study averaged over 10 runs. Panel (a) shows the absolute difference in average network sizes between groups over the simulation time period. Panel (b) shows the share of network degrees that belong to the majority group. A connection essentially translates into two degrees, one for the source member of the connection and one for the destination member. In panel (c), we see the share of the majority group on the destination members of new connections, and in panel (d), we see the share of the majority group among all new degrees.

Rich-gets-richer in groups

Figure 2 depicts the results for all three settings (no intervention and the two fairness interventions). With no intervention, displayed in red in the graphs, we observe that the gap in average network sizes between groups drastically increases over time. In fact, our results reveal a group-wise rich-get-richer effect in which network sizes tend toward a power-law distribution with a lower mode in the minority population. Members whose networks are in the tail of the distribution tend to be the members who had large networks, compared to their peers, to begin with. These observations are in line with previous findings on rich-get-richer phenomena under homophily.

Demographic parity of exposure intervention

We now move to the results for the demographic parity of exposure intervention. First, we can see that the majority group share of exposure in recommendations is down to 66.3%, which suggests that the intervention is working. However, as we can see in panels (a) and (b), the gap in average network sizes is still increasing over time. Why is this the case? First, majority group members seek out recommendations more frequently because of their already larger network sizes and activity levels. This leads to more connections formed with majority group members on the source side of the recommendation, while our intervention can only target the destination side. Second, majority group members have higher ranking scores and are thus more likely to be invited to connect even if both groups are exposed equally. Consider the following example recommendation:

Here, female and male users have the same exposure (we ignore position bias for this example). However, the probability of connection and its proxy, the matching score, are much lower for women than for men, which leads to fewer new connections for women than for men. While the intervention suggests recommendations are made fairly, this does not align with our intuitive goal of more equitable network sizes.

Dynamic parity of utility intervention

Some of the problems with the demographic parity of exposure intervention are addressed by the dynamic parity of utility intervention. We see that the majority group share of destination members of new connections is down to 65.4%, as desired. Yet, even with this type of constraint, panels (a) and (b) show that the gap in average network sizes is still increasing over time. As in the previous case, one of the reasons for this is that majority group members seek out recommendations more frequently. In addition, our analysis reveals that source members from the majority group generally form more connections per recommendation query, leading to the increased share in panel (d). To understand why this is the case, consider the following example:

In the first row, a majority group member – here displayed as male – receives recommendations. In the second row, a minority group member – here female – receives recommendations. In both recommendation lists the dynamic parity of utility constraint is fulfilled, but the male source member receives more overall utility from the query because the constraint can only target fairness within a list and not in between different sets of recommendations.

Summary of findings

Let us summarize the key findings. Our study shows that unconstrained connection recommendation leads to a group-wise rich-get-richer effect. Enforcing demographic parity of exposure or dynamic parity of utility between groups, which are commonly suggested remedies against demographic bias in recommender systems, leads to less bias amplification but is not sufficient to prevent the disparities in network sizes from growing over time. The theoretical limit analysis in the full paper shows that dynamic parity of utility would be the optimal intervention if there were no source-side bias. Yet, this is an unrealistic assumption in practice.

Overall, the common practice of measuring fairness in recommender systems in a one-shot or time-aggregate static manner can lead to an illusion of fairness and deployment of fairness-enhancing algorithms with unforeseen consequences. Connection recommendation operates on a dynamical system that needs to be taken into account to ensure equitable outcomes in the long run.

The full paper is published in the proceedings of the AAAI / ACM conference on Artificial Intelligence, Ethics, and Society (AIES 2022). A preprint is available here.


[1] LinkedIn PYMK. [Online; accessed 7/6/22]

[2] Lydia T. Liu, Sarah Dean, Esther Rolf, Max Simchowitz, and Moritz Hardt. 2018. Delayed impact of fair machine learning. In Proceedings of the 35th International Conference on Machine Learning (ICML 2018).

[3] Deepak K. Agarwal and Bee-Chung Chen. 2016. Statistical Methods for Recommender Systems. Cambridge University Press.

[4] LinkedIn PYMK. [Online; accessed 7/6/22]

[5] Facebook PYMK. [Online; accessed 7/6/22]

[6] Gueorgi Kossinets and Duncan J. Watts. 2006. Empirical analysis of an evolving social network. Science 311, 5757, 88–90. American Association for the Advancement of Science (AAAS).

[7] David Liben-Nowell and Jon Kleinberg. 2007. The link-prediction problem for social networks. J. Am. Soc. Inf. Sci. Technol. 58, 7 (May 2007), 1019–1031.

[8] Miller McPherson, Lynn Smith-Lovin, and James M. Cook. 2001. Birds of a feather: Homophily in social networks. Annual Review of Sociology 27, 1, 415–444. Annual Reviews.

Recurrent Model-Free RL Can Be a Strong Baseline for Many POMDPs

Figure 1. Our implementation of recurrent model-free RL outperforms the on-policy version (PPO/A2C-GRU), and a recent model-based POMDP algorithm (VRM) on most tasks of a POMDP benchmark where VRM was evaluated in their paper.

While algorithms for decision-making typically focus on relatively easy problems where everything is known, most realistic problems involve noise and incomplete information. Complex algorithms have been proposed to tackle these complex problems, but there’s a simple approach that (in theory) works on both the easy and the complex problems. We show how to make this simple approach work in practice.


Decision-making tasks in the real world are messy, with noise, occlusions, and uncertainty that are typically missing from their canonical problem formulation as a Markov decision process (MDP; Bellman, 1957). In contrast, Partially Observable MDPs (POMDPs; Åström, 1965) can capture the uncertainty in the states, rewards, and dynamics. Such uncertainty arises in applications such as robotics, healthcare, NLP and finance.

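Apart from being realistic, POMDPs are a general framework that contains many subareas in RL, including meta-RL, robust RL, generalization in RL, and temporal credit assignment, each of which we revisit below.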

What is recurrent model-free RL?

Solving POMDPs is hard because the agent needs to learn two tasks simultaneously: inference and control. Inference aims to infer the posterior over current states conditioned on history. Control aims to perform RL / planning algorithms on the inferred state space. While prior methods typically decouple the two jobs with separate models, deep learning provides us with a general and simple baseline: combine an off-the-shelf RL algorithm with a recurrent neural network (RNN; e.g., LSTM (Hochreiter & Schmidhuber, 1997) and GRU (Chung et al., 2014)).

By backpropagating the gradients from the policy loss, RNNs make it possible to process sequences (histories in POMDPs) and learn implicit inference on the state space for control. We refer to this approach as recurrent model-free RL. Fig. 2 presents our design, where actor and critic networks each have an RNN as the history encoder.

Figure 2. Our recurrent actor and critic architecture. Each network contains an RNN as the sequence encoder. The recurrent actor takes as input current observation, previous action and reward (optional), and outputs current action. The recurrent critic also takes current action as an input and outputs Q-values.
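To make the architecture in Figure 2 concrete, here is a forward-pass-only sketch in pure Python. A plain Elman RNN stands in for the LSTM/GRU, and the class names, hidden size, and weight initialization are hypothetical. Training (backpropagating the actor and critic losses through their RNNs with an off-policy algorithm) is omitted. Note that the actor and the critic each own a separate encoder, so no recurrent weights are shared.

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


def linear(W: List[Vector], x: Sequence[float]) -> Vector:
    """Matrix-vector product W @ x (no bias)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def randn_matrix(rows: int, cols: int, rng: random.Random) -> List[Vector]:
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


class RNNEncoder:
    """Elman RNN h_t = tanh(W_x x_t + W_h h_{t-1} + b), a stand-in for LSTM/GRU."""

    def __init__(self, in_dim: int, hid_dim: int, rng: random.Random):
        self.W_x = randn_matrix(hid_dim, in_dim, rng)
        self.W_h = randn_matrix(hid_dim, hid_dim, rng)
        self.b = [0.0] * hid_dim

    def encode(self, xs: Sequence[Sequence[float]]) -> List[Vector]:
        """Return the hidden state after each step of the history."""
        h = [0.0] * len(self.b)
        hs = []
        for x in xs:
            pre = [a + c + b for a, c, b in zip(linear(self.W_x, x), linear(self.W_h, h), self.b)]
            h = [math.tanh(z) for z in pre]
            hs.append(h)
        return hs


def make_history(obs: Sequence[Sequence[float]], prev_actions: Sequence[Sequence[float]],
                 prev_rewards: Sequence[float]) -> List[Vector]:
    """Per-step input [o_t, a_{t-1}, r_{t-1}] (the reward input is optional in Figure 2)."""
    return [list(o) + list(a) + [float(r)] for o, a, r in zip(obs, prev_actions, prev_rewards)]


class RecurrentActor:
    def __init__(self, obs_dim: int, act_dim: int, hid_dim: int, rng: random.Random):
        self.rnn = RNNEncoder(obs_dim + act_dim + 1, hid_dim, rng)  # actor's own RNN
        self.head = randn_matrix(act_dim, hid_dim, rng)

    def act(self, history: Sequence[Sequence[float]]) -> Vector:
        h = self.rnn.encode(history)[-1]
        return [math.tanh(z) for z in linear(self.head, h)]  # bounded continuous action


class RecurrentCritic:
    def __init__(self, obs_dim: int, act_dim: int, hid_dim: int, rng: random.Random):
        self.rnn = RNNEncoder(obs_dim + act_dim + 1, hid_dim, rng)  # separate, unshared RNN
        self.head = randn_matrix(1, hid_dim + act_dim, rng)

    def q_value(self, history: Sequence[Sequence[float]], action: Sequence[float]) -> float:
        h = self.rnn.encode(history)[-1]
        return linear(self.head, h + list(action))[0]  # also conditions on the current action
```

In a real implementation, the two encoders would be trained by their own losses, which is the weight un-sharing discussed in the next section.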

Recurrent model-free RL has many merits at first glance:

  • Conceptually simple. The policy purely learns from rewards, without extra objectives.
  • Easy to implement. Practitioners can just change several lines of code from model-free RL.
  • Expressive in theory. RNNs have been shown to be universal function approximators (Siegelmann & Sontag, 1995, Schäfer & Zimmermann, 2006), and thus recurrent model-free RL can (approximately) express any memory-based policy.

Due to its simplicity and expressivity, there is rich literature (Schmidhuber, 1991, Bakker, 2001, Wierstra et al., 2007, Heess et al., 2015, Hausknecht & Stone, 2015) on studying different RL algorithms and RNN architectures of recurrent model-free RL. However, prior work has shown that it often fails in practice with poor or unstable performance (Igl et al., 2018, Hung et al., 2018, Packer et al., 2018, Rakelly et al., 2019, Zintgraf et al., 2020, Han et al., 2020, Zhang et al., 2021, Raposo et al., 2021), with only a few exceptions (Yu et al., 2019, Fakoor et al., 2020).

Motivated by the poor performance of this simple baseline, prior work has proposed more sophisticated methods. Some introduce model-based objectives that explicitly learn inference, while others incorporate the assumptions of a particular POMDP subarea as inductive bias. Both achieve good results on a range of their respective tasks, although the model-based methods may suffer from stale belief states stored in the replay buffer, and the specialized methods require more assumptions than recurrent model-free RL (e.g., meta-RL methods normally assume the hidden variable is constant within a single episode).

How to train recurrent model-free RL?

In this work, we found that recurrent model-free RL is not doomed to fail; it just needs to be implemented differently, with differences including:

  1. Separating the RNNs in actor and critic networks. Un-sharing the weights can prevent gradient explosion, and can be the difference between the algorithm learning nothing and solving the task almost perfectly.
  2. Using an off-policy RL algorithm to improve sample efficiency. Using, say, TD3 instead of PPO greatly improves sample efficiency.
  3. Tuning the RNN context length. We found that the RNN architectures (LSTM and GRU) do not matter much, but the RNN context length (the length of the sequence fed into the RL algorithm) is crucial and depends on the task. We suggest choosing a medium length as a start.
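The role of the context length can be illustrated with a sketch of how fixed-length training sequences might be drawn from a replay buffer of episodes; `context_length` is the knob discussed in point 3. The function name, the uniform choice of episode and start index, and returning short episodes unpadded are simplifications of ours rather than the authors’ implementation.

```python
import random
from typing import List, Sequence, TypeVar

Step = TypeVar("Step")


def sample_context_windows(episodes: Sequence[Sequence[Step]], context_length: int,
                           batch_size: int, rng: random.Random) -> List[List[Step]]:
    """Sample sub-sequences of at most `context_length` consecutive steps.

    An episode is picked uniformly at random, then a start index uniformly among
    valid starts. Episodes shorter than the context length are returned whole;
    a real implementation would pad and mask them so batches share one shape.
    """
    batch = []
    for _ in range(batch_size):
        ep = rng.choice(episodes)
        start = rng.randint(0, max(len(ep) - context_length, 0))
        batch.append(list(ep[start:start + context_length]))
    return batch
```

A longer `context_length` gives the RNN more history to infer hidden state from, at the cost of longer backpropagation through time; this trade-off is why the best value is task-dependent.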

Properly tuned, the simple baseline outperforms alternatives on many POMDPs

With these changes, our implementation of recurrent model-free RL is at least on par with (if not much better than) prior methods, on the tasks those prior methods were designed to solve. While prior methods are typically designed to solve special cases of POMDPs, recurrent model-free RL applies to all types of POMDPs.

Our first comparison looks at meta-RL tasks, which are usually approached by methods that decouple the task inference and reward maximization steps (Rakelly et al., 2019, Zintgraf et al., 2020). Comparing against these prototypical methods (on-policy variBAD (Zintgraf et al., 2020) and off-policy variBAD (Dorfman et al., 2020)), we find that our recurrent model-free approach can often perform at least on par with them. These results (Fig. 3, 4) suggest that disentangling inference and control may not be necessary in many tasks.

Figure 3. Two meta-RL environments from off-policy variBAD (left: Semi-Circle, right: Wind), where ours outperforms off-policy variBAD.
Figure 4. Two meta-RL environments from on-policy variBAD (left: Cheetah-Dir, right: Ant-Dir), where we have mixed results.

Next, we move on to the robust RL setting, which is mostly addressed by algorithms that explicitly maximize the worst-case return (Rajeswaran et al., 2017, Mankowitz et al., 2020). Comparing against a recent robust RL algorithm, MRPO (Jiang et al., 2021), we find that our recurrent model-free approach performs better in all the tasks. The results (Fig. 5) indicate that with the power of implicit task inference with RNNs, we can improve both average and worst returns.

Figure 5. One robust RL environment from MRPO, Cheetah-Robust (left: average return; right: worst return), where ours outperforms MRPO.

Then we explore generalization in RL, where prior work proposes various specialized methods, including policy regularization (Farebrother et al., 2018) and data augmentation (Lee et al., 2020). Here we choose the popular SunBlaze benchmark (Packer et al., 2018), which provides a specialized baseline, EPOpt-PPO-FF (Rajeswaran et al., 2017). The results (Fig. 6) show that despite not explicitly enhancing generalization, our recurrent model-free approach can perform better than the baseline in extrapolation. This suggests that the task inference learned by the RNN generalizes.

Figure 6. One generalization in RL environment from SunBlaze, Hopper-Generalize (left: interpolated success rate; right: extrapolated success rate), where ours outperforms EPOpt-PPO-FF on extrapolation.

Finally, we study the temporal credit assignment domain, where methods usually involve reward decomposition or redistribution (Liu et al., 2019, Hung et al., 2018). Here, we choose a recent specialized method, IMPALA+SR (Raposo et al., 2021), and evaluate our method on its benchmark with pixel-based discrete control and sparse rewards. Despite being unaware of the reward structure and not performing credit assignment explicitly, our recurrent model-free approach performs better. The results (Fig. 7) indicate that recurrent policies can cope with sparse rewards effectively, perhaps better than previously expected.

Figure 7. Two temporal credit assignment environments from IMPALA+SR (left: Delayed-Catch, right: Key-to-Door), where ours outperforms IMPALA+SR.


In a sense, our finding echoes the motivation for deep learning: we can achieve better results by reducing a method to a single differentiable architecture, optimized end-to-end with a single loss. This is exciting because RL systems often involve many interconnected parts (e.g., feature extraction, model learning, value estimation) trained with different objectives, and perhaps these could be replaced by end-to-end approaches equipped with sufficiently expressive architectures.

We have open-sourced the code on GitHub to support reproducibility and to help future work develop better POMDP algorithms. Please see our project site for the paper and our ICML 2022 presentation.


auton-survival: An Open-Source Package for Regression, Counterfactual Estimation, Evaluation and Phenotyping Censored Time-to-Event Data



Real-world decision-making often requires reasoning about when an event will occur. The overarching goal of such reasoning is to aid decisions about optimal triage and subsequent intervention. Problems that involve estimating times-to-an-event arise across many application areas, including:

Healthcare and Bio-informatics: Known more commonly as ‘survival analysis’, this involves prognostication of an adverse physiological event such as a stroke, the onset of cancer, re-hospitalization, or mortality. Time-to-event or survival analysis can be used to proactively mitigate adverse outcomes and extend the longevity of patients.
Internet Marketing and e-commerce: Models employed for estimating customer churn and retention in large commercial organizations are essentially time-to-event regression models and help determine best practices to maximize customer retention.
Predictive Maintenance: Reliability engineering and systems safety research involves the use of remaining useful life prediction models to help extend the longevity of machinery and equipment by proactive part and component replacement.
Finance and Actuarial Sciences: Time-to-event models are ubiquitous in estimating optimal strategies for setting insurance premiums, as well as in estimating credit default behavior.

Figure 1: Patients A and C died 1 and 4 years following entry into the study, whereas mortality outcomes are missing for Patients B and D prior to their exit from the study at 3 and 2 years following entry. Time-to-Event or Survival Regression involves adjusting estimates for such individuals whose outcomes are censored.

Figure 1 illustrates a typical time-to-event problem in healthcare. As the figure shows, working with time-to-event data is made harder by the fact that such data typically include individuals whose outcomes are unobserved, or ‘censored’, due either to loss of follow-up or to the end of the study.

Discretizing time-to-event outcomes to predict whether an event will occur is a common approach in standard machine learning. However, this neglects temporal context, which can result in models that misestimate risk and generalize poorly.

The auton-survival Package

In our recent Machine Learning for Healthcare ’22 paper, we present auton-survival, a comprehensive Python code repository of user-friendly machine learning tools for working with censored time-to-event data. The package includes a suite of workflows for tasks ranging from data pre-processing and regression modeling to model evaluation. auton-survival has an API similar to that of scikit-learn (Pedregosa et al., 2011), which makes it easy to adopt for users with machine learning experience in Python. Additionally, to promote usability and rapid prototyping of solutions for both machine learning and clinical researchers, we include detailed documentation as well as example notebooks.

Time-to-Event Regression

Time-to-Event or Survival regression can be used to estimate the conditional probability of an event occurring within a specified time period or event-horizon. A time-to-event estimation problem thus reduces to estimating the conditional distribution of survival:

\( \mathbb{E}[\mathbb{1}\{T > t\} \mid X = x] = \mathbb{P}(T > t \mid X = x) = 1 - \mathbb{P}(T \le t \mid X = x) \)

Here \( X \) is a set of covariates, and \( T = \min(T^*, C) \) is the censored survival time, where \( T^* \) is the true time-to-event and \( C \) is the censoring time. Assuming that the true time-to-event and the censoring time are conditionally independent given the covariates (i.e., \( T^* \perp C \mid X \)) allows identification of \( \mathbb{P}(T^* \mid X) \).

Survival regression naturally accounts for censored data. Under censoring, the likelihood \( \ell \) of an observation is

\( \ell(\{x, t, \delta\}) \propto \mathbb{P}(T = t \mid X = x)^{\delta}\, \mathbb{P}(T > t \mid X = x)^{1 - \delta} \).

Here \( x \in \mathbb{R}^d \) are the covariates, \( t \in \mathbb{R}^{+} \) is the event or censoring time, and \( \delta \in \{0, 1\} \) is a binary event indicator: \( \delta = 1 \) if the event was observed and \( \delta = 0 \) if the individual was censored. For censored individuals, the likelihood is the probability that the event takes place beyond the censoring time \( t \), namely \( \mathbb{P}(T > t \mid X = x) \), also known as the ‘survival function’.
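To make the role of \( \delta \) concrete, here is a minimal sketch of this likelihood for the simplest parametric choice, an exponential time-to-event model without covariates (our assumption; with continuous time, \( \mathbb{P}(T = t \mid X = x) \) is read as the density). The function name and the toy data, which encode the four patients of Figure 1, are illustrative and not part of auton-survival.

```python
import numpy as np

def exponential_censored_loglik(rate, t, delta):
    """Right-censored log-likelihood of a covariate-free exponential model.

    Individuals with an observed event (delta = 1) contribute log f(t) = log(rate) - rate * t;
    censored individuals (delta = 0) contribute log S(t) = -rate * t.
    """
    t = np.asarray(t, dtype=float)
    delta = np.asarray(delta, dtype=float)
    log_density = np.log(rate) - rate * t
    log_survival = -rate * t
    return float(np.sum(delta * log_density + (1.0 - delta) * log_survival))

# Figure 1: patients A and C die at years 1 and 4; B and D are censored at years 3 and 2.
t = np.array([1.0, 3.0, 4.0, 2.0])
delta = np.array([1, 0, 1, 0])

# For this model the maximizer is closed-form: observed events / total follow-up time.
rate_mle = delta.sum() / t.sum()  # 2 / 10 = 0.2
```

Simply discarding the censored patients B and D would instead give 2 / 5 = 0.4, doubling the estimated event rate; the \( 1 - \delta \) term is what lets their partial follow-up count.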

Broadly, the popular approaches for learning estimators of survival in the presence of censoring can be categorized into:

  • Parametric: Assume that the time-to-event distribution \( \mathbb{P}(T) \) follows a known parametric family, such as the Weibull or Log-Normal.
  • Non-Parametric: Learn kernels or similarity functions of the input covariates, followed by a non-parametric (Kaplan-Meier or Nelson-Aalen) estimate of the survival rate weighted by the learned kernel.
  • Semi-Parametric: As with Cox Proportional Hazards models, feature interactions are learned through a parametric model, followed by a non-parametric estimate of the baseline survival (hazard) rate.
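The non-parametric building block can be sketched in a few lines. This is a plain NumPy Kaplan-Meier estimator, not auton-survival's implementation; the optional `weights` argument is our hypothetical stand-in for the kernel weights described in the list above, and with uniform weights it reduces to the ordinary estimator.

```python
import numpy as np

def kaplan_meier(t, delta, weights=None):
    """Return (event_times, survival) for an (optionally weighted) Kaplan-Meier estimate.

    t: event or censoring times; delta: 1 if the event was observed, 0 if censored;
    weights: hypothetical per-individual weights (e.g. from a learned kernel).
    """
    t = np.asarray(t, dtype=float)
    delta = np.asarray(delta, dtype=bool)
    w = np.ones_like(t) if weights is None else np.asarray(weights, dtype=float)

    event_times = np.unique(t[delta])      # survival only drops at observed event times
    survival = []
    s = 1.0
    for u in event_times:
        at_risk = w[t >= u].sum()           # (weighted) count still under observation at u
        events = w[(t == u) & delta].sum()  # (weighted) count of events exactly at u
        s *= 1.0 - events / at_risk
        survival.append(s)
    return event_times, np.array(survival)

times, surv = kaplan_meier([1.0, 3.0, 4.0, 2.0], [1, 0, 1, 0])
# times -> [1., 4.], surv -> [0.75, 0.]
```

On the Figure 1 patients, the estimate drops to 0.75 at year 1 and to 0 at year 4: B and D leave the risk set when censored, but never count as deaths.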

Estimators of Survival

Figure 2: When proportional hazards assumptions are satisfied, survival curves and their corresponding hazard rates do not intersect. auton-survival includes flexible estimators of time-to-events in the presence of non-proportional hazards.

The complex, multimodal data often observed in healthcare and other applications bring many challenges to traditional machine learning. auton-survival provides a simple interface for using deep neural networks and representation learning to model such data.

auton-survival includes extensions to the standard Cox Proportional Hazards (CPH) model (Cox, 1972) that involve deep representation learning (Faraggi and Simon, 1995; Katzman et al., 2018), as well as the latent variable survival regression models Deep Cox Mixtures (DCM) and Deep Survival Machines (DSM) (Nagpal et al., 2021 a,b), which relax the strong proportional hazards assumption illustrated in Figure 2 by modeling the time-to-event distribution as a fixed-size mixture.
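The sketch below illustrates, under assumptions of our own, why a mixture can relax proportional hazards. Under PH, \( S(t \mid x) = S_0(t)^{\exp(\beta^\top x)} \), so one individual's curve is a fixed power of another's and the two can never cross. A fixed-size mixture of Weibull components whose weights differ between individuals (in DSM, for instance, the weights are functions of the covariates) can produce crossing curves. All shapes, scales, and weights here are arbitrary illustrative values, not fitted models.

```python
import numpy as np

def weibull_survival(t, shape, scale):
    """Weibull survival function S(t) = exp(-(t / scale) ** shape)."""
    return np.exp(-(t / scale) ** shape)

def curves_cross(s1, s2):
    """True if s1 - s2 changes sign somewhere on the evaluation grid."""
    signs = np.sign(s1 - s2)
    signs = signs[signs != 0]
    return bool(np.any(signs[1:] != signs[:-1]))

t = np.linspace(0.1, 5.0, 50)

# Proportional hazards: S(t | x) = S0(t) ** exp(beta * x), so the curves stay ordered for all t.
s0 = weibull_survival(t, shape=1.5, scale=2.0)
s_ph_low = s0 ** np.exp(0.0)    # beta * x = 0
s_ph_high = s0 ** np.exp(0.7)   # beta * x = 0.7, i.e. a uniformly higher hazard

# Two Weibull components: early failures (shape < 1) and late failures (shape > 1).
early = weibull_survival(t, shape=0.5, scale=1.0)
late = weibull_survival(t, shape=3.0, scale=2.0)
# Hypothetical mixture weights for two individuals.
s_mix_a = 0.9 * early + 0.1 * late
s_mix_b = 0.1 * early + 0.9 * late

print(curves_cross(s_ph_low, s_ph_high))  # False
print(curves_cross(s_mix_a, s_mix_b))     # True
```

Individual A is more likely to fail early and individual B late, so B's curve starts above A's and ends below it, which no single proportional-hazards model can express.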

The SurvivalModel Class

The package provides a convenient SurvivalModel class that enables rapid experimentation via a consistent API that wraps multiple alternative regression estimators. In addition to the models mentioned above, the SurvivalModel class includes Random Survival Forests (RSF) (Ishwaran et al., 2008), which is a popular non-parametric survival model.

Hyperparameter tuning for model selection can be streamlined with the SurvivalRegressionCV class, which applies \( K \)-fold cross-validation over a user-specified hyperparameter grid.
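As a rough illustration of the idea (not of the SurvivalRegressionCV API, whose signature we do not reproduce here), the sketch below runs \( K \)-fold cross-validation over a small hyperparameter grid. The model (a covariate-free Weibull whose shape is the grid-searched hyperparameter), the selection score (held-out censored log-likelihood), the helper names, and the simulated data are all assumptions made for the example.

```python
import numpy as np

def weibull_fit_scale(t, delta, shape):
    """Censored MLE of the Weibull scale for a fixed shape: scale**shape = sum(t**shape) / #events."""
    return (np.sum(t ** shape) / max(delta.sum(), 1.0)) ** (1.0 / shape)

def weibull_censored_loglik(t, delta, shape, scale):
    """Events contribute log f(t); censored individuals contribute log S(t) = -(t / scale)**shape."""
    z = t / scale
    log_f = np.log(shape / scale) + (shape - 1.0) * np.log(z) - z ** shape
    log_s = -z ** shape
    return float(np.sum(delta * log_f + (1.0 - delta) * log_s))

def kfold_select_shape(t, delta, shape_grid, k=5, seed=0):
    """Pick the shape with the best mean held-out censored log-likelihood over k folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(t)), k)
    scores = {}
    for shape in shape_grid:
        fold_scores = []
        for i in range(k):
            train = np.concatenate([folds[j] for j in range(k) if j != i])
            scale = weibull_fit_scale(t[train], delta[train], shape)  # fit on the other k-1 folds
            fold_scores.append(weibull_censored_loglik(t[folds[i]], delta[folds[i]], shape, scale))
        scores[shape] = float(np.mean(fold_scores))
    return max(scores, key=scores.get), scores

# Simulated data: true shape 2, scale 2, with independent uniform censoring.
rng = np.random.default_rng(1)
t_true = 2.0 * rng.weibull(2.0, size=2000)
c = rng.uniform(0.0, 6.0, size=2000)
t, delta = np.minimum(t_true, c), (t_true <= c).astype(float)

best_shape, cv_scores = kfold_select_shape(t, delta, shape_grid=[0.5, 1.0, 2.0, 4.0])
```

Scoring on held-out folds with the censored likelihood, rather than on training data, is what keeps censored individuals in the evaluation without treating them as events.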

Time-Varying Survival Regression

Real-world data often consists of multiple time-dependent observations per individual or time-varying covariates. auton-survival is equipped to handle time-varying covariates for survival analysis with auto-regressive deep learning models that allow learning temporal dependencies when estimating time-to-event outcomes. Implementations of time-varying DSM and Deep Cox Proportional Hazards model involve the use of RNNs, LSTMs, or GRUs (Chung et al., 2014; Hochreiter and Schmidhuber, 1997) for time-varying survival regression as shown in Figure 3.

Figure 3: For an individual with time-to-event \( T_i \), we observe covariates \( x_i^j \) at multiple time points \( t_i^j \). At each time-step \( j \), we estimate the distribution of the remaining time-to-event \( T_i - t_i^j \). The representation \( \widetilde{x}_i^{j+1} \) at time-step \( j+1 \) is a function of the covariates \( x_i^{j+1} \) and the representation of the preceding time-step, \( \widetilde{x}_i^{j} \).
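A minimal sketch of the recurrence in Figure 3, assuming a plain Elman RNN in NumPy rather than the LSTM or GRU cells used by the package: each step's representation (the \( \widetilde{x}_i^{j} \) of the caption, `reps[j]` here) depends only on the current covariates and the previous representation. The exponential output head for the remaining time-to-event and all weights are hypothetical and untrained.

```python
import numpy as np

def rnn_representations(x_seq, W_x, W_h, b):
    """Elman-style recurrence: rep_{j+1} = tanh(W_x @ x^{j+1} + W_h @ rep_j + b), with rep_0 = 0."""
    rep = np.zeros(W_h.shape[0])
    reps = []
    for x in x_seq:  # one covariate vector per observation time t_i^j
        rep = np.tanh(W_x @ x + W_h @ rep + b)
        reps.append(rep)
    return np.stack(reps)

def remaining_time_survival(reps, w_out, b_out, u):
    """Hypothetical exponential head: P(T_i - t_i^j > u) = exp(-rate_j * u) at each step j."""
    rates = np.exp(reps @ w_out + b_out)  # exp keeps every rate positive
    return np.exp(-rates * u)

rng = np.random.default_rng(0)
n_steps, n_cov, n_hidden = 5, 3, 4
W_x = rng.normal(size=(n_hidden, n_cov))
W_h = rng.normal(size=(n_hidden, n_hidden))
b = np.zeros(n_hidden)
w_out, b_out = rng.normal(size=n_hidden), 0.0

x_seq = rng.normal(size=(n_steps, n_cov))      # 5 visits, 3 covariates each
reps = rnn_representations(x_seq, W_x, W_h, b)
surv = remaining_time_survival(reps, w_out, b_out, u=1.0)
```

Because the recurrence only looks backwards, changing the covariates observed at a later visit leaves the representations of earlier visits unchanged.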

Counterfactual Estimators of Survival

Decision support often requires reasoning about ‘what if’ scenarios regarding the effect of different treatments on outcomes. In observational settings, outcomes and treatment assignments may share common causes, and adjusting for such confounding factors is crucial when performing causal inference. auton-survival includes counterfactual survival regression as a tool for causal inference that accounts for confounding factors when estimating the effect of treatment on survival. Counterfactual survival regression involves fitting separate regression models on the treated and control populations and computing survival rates across treatment arms. Under the standard causal inference assumption of strong ignorability, the survival function under the intervention \( \text{do}(A = a) \) can then be estimated as

\( \hat{S}\big(t \mid \text{do}(A = a)\big) = \mathbb{E}_{X}\big[\hat{\mathbb{E}}[\mathbb{1}\{T > t\} \mid X = x, A = a]\big] \)

where \( \hat{\mathbb{E}}[\mathbb{1}\{T > t\} \mid X = x, A = a] \) is an estimate of the conditional expectation of survival, learned on the population that received treatment \( A = a \).
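The two-model recipe above (sometimes called a T-learner) can be sketched as follows. As assumptions for brevity, a per-stratum exponential model on a single discrete confounder stands in for the Deep Cox regressors, and the data are simulated so that treatment has no effect on survival while the confounder drives both treatment and outcome. The key step is the outer expectation: each arm's model is averaged over the covariates of the whole population, not just of that arm.

```python
import numpy as np

def fit_stratified_exponential(x, t, delta):
    """Per-stratum exponential rate (events / follow-up time) for a discrete covariate x."""
    return {v: delta[x == v].sum() / t[x == v].sum() for v in np.unique(x)}

def counterfactual_survival(x, t, delta, a, horizon):
    """Estimate S(horizon | do(A = arm)) for arm in {0, 1} by standardization.

    A separate regression is fitted on each arm, and its predictions are averaged over
    the covariates of the *whole* population (the outer expectation over X).
    """
    out = {}
    for arm in (0, 1):
        m = a == arm
        rates = fit_stratified_exponential(x[m], t[m], delta[m])
        predicted = np.exp(-np.array([rates[v] for v in x]) * horizon)
        out[arm] = float(predicted.mean())
    return out

# Simulated confounding: x drives both treatment and hazard; A itself has no effect.
rng = np.random.default_rng(0)
n = 20000
x = rng.integers(0, 2, size=n)
a = (rng.random(n) < np.where(x == 1, 0.8, 0.2)).astype(int)
t_true = rng.exponential(np.where(x == 1, 2.0, 10.0))   # hazard 0.5 if x = 1, else 0.1
c = rng.uniform(0.0, 15.0, size=n)                      # independent censoring
t, delta = np.minimum(t_true, c), (t_true <= c).astype(float)

cf = counterfactual_survival(x, t, delta, a, horizon=2.0)
```

In this simulation both counterfactual estimates land near the true value \( \tfrac{1}{2}(e^{-0.2} + e^{-1}) \approx 0.593 \), whereas comparing the raw survival of the two arms suggests a large treatment effect that is entirely due to confounding.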

Consider data from the large SEER Cancer Incidence registry (Ries et al., 1975). When stratified by region (Figure 4a), survival rates show an apparent disparity. We demonstrate the use of counterfactual regression to gain insight into whether these discrepancies can be attributed to the geographic region itself, or to other socio-economic or physiological confounding factors that affect both membership in these regions and the outcomes. To adjust estimates of survival with counterfactual estimation, we train two separate Deep Cox models, on data from Greater California and from Louisiana, as counterfactual regressors. The fitted regressors are then applied to estimate a survival curve for each instance, and these curves are averaged over the population to compute each region's survival rate. Figure 4b compares the counterfactual survival rates with those obtained from a Kaplan-Meier estimator. The Kaplan-Meier estimator does not adjust for confounding factors and overestimates the treatment effect, as evidenced by how much the survival rates differ between regions. In contrast, counterfactual regression adjusts for confounding factors and predicts more similar survival rates across regions.

Phenotyping Censored Survival Data

Figure 5: Phenotypers in auton-survival: (a) unsupervised, (b) supervised, (c) counterfactual. \( X \) represents the covariates, \( T \) the time-to-event, and \( Z \) the phenotype to be inferred.

Survival rates differ across groups of individuals with heterogeneous characteristics. Identifying groups of patients with similar survival rates can provide insight into practices and interventions that may help improve longevity for such groups. While domain knowledge can help identify such subgroups, in practice assignment to subgroups may be determined by complex, non-linear feature interactions, which makes identification difficult. In auton-survival, we refer to this group identification and survival assessment as phenotyping.

Our package offers multiple approaches to phenotyping, ranging from the use of specific domain knowledge, as in the case of the intersectional phenotyper, to a completely unsupervised approach that clusters subjects based on the observed covariates. auton-survival also offers phenotypers that explicitly use supervision, in which the observed outcomes or counterfactuals inform the learned phenotypes to better stratify the data. Directed acyclic graph representations of the probabilistic phenotypers in auton-survival are shown in Figure 5.

  • Intersectional Phenotyping: Recovers groups, or phenotypes, of individuals over exhaustive combinations of user-specified categorical and numerical features.
  • Unsupervised Phenotyping: Identifies groups of individuals based on structured similarity in the feature space, by first performing dimensionality reduction of the input covariates and then clustering. The estimated probability of an individual belonging to a latent group is computed from the individual's distance to each cluster, normalized across all clusters so that nearer clusters receive higher probability.
  • Supervised Phenotyping: Identifies latent groups of individuals with similar survival outcomes, using the observed outcomes as supervision. These phenotypes are obtained as a direct consequence of training the DSM and DCM latent variable survival estimators.
  • Counterfactual Phenotyping: Identifies groups of individuals that demonstrate enhanced or diminished treatment effects (Nagpal et al., 2022).
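The unsupervised phenotyper described above (dimensionality reduction followed by clustering) can be sketched with PCA and k-means in NumPy. These specific choices, the farthest-point initialization, and the softmax over negative distances used for soft assignment are our assumptions for illustration, not necessarily what auton-survival does; the data are two simulated groups.

```python
import numpy as np

def pca(x, n_components):
    """Dimensionality reduction: project centred covariates onto the top principal components."""
    xc = x - x.mean(axis=0)
    _, _, vt = np.linalg.svd(xc, full_matrices=False)
    return xc @ vt[:n_components].T

def kmeans(z, k, n_iter=50, seed=0):
    """Lloyd's k-means with farthest-point initialization; returns the cluster centres."""
    rng = np.random.default_rng(seed)
    centres = [z[rng.integers(len(z))]]
    for _ in range(1, k):
        nearest = np.min([np.linalg.norm(z - c, axis=1) for c in centres], axis=0)
        centres.append(z[nearest.argmax()])
    centres = np.array(centres)
    for _ in range(n_iter):
        labels = np.linalg.norm(z[:, None] - centres[None], axis=-1).argmin(axis=1)
        centres = np.array([z[labels == c].mean(axis=0) if np.any(labels == c) else centres[c]
                            for c in range(k)])
    return centres

def phenotype_probabilities(z, centres):
    """Soft phenotype assignment: softmax over negative distances (nearer => more probable)."""
    logits = -np.linalg.norm(z[:, None] - centres[None], axis=-1)
    logits -= logits.max(axis=1, keepdims=True)  # for numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

# Two simulated groups of 100 individuals with 5 covariates each.
rng = np.random.default_rng(0)
x = np.vstack([rng.normal(0.0, 1.0, size=(100, 5)), rng.normal(5.0, 1.0, size=(100, 5))])
z = pca(x, n_components=2)
probs = phenotype_probabilities(z, kmeans(z, k=2))
phenotype = probs.argmax(axis=1)
```

Survival within each recovered phenotype could then be summarized with a Kaplan-Meier estimate, as in Figure 6.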

Figure 6 presents the Kaplan-Meier survival curves of the phenogroups extracted from SUPPORT (Knaus et al., 1995) using unsupervised and supervised phenotyping. The intersecting survival curves suggest that the phenotypers can recover phenogroups that do not strictly adhere to the proportional hazards assumption. Figure 6 also suggests that supervised phenotyping extracts phenogroups with higher discriminative power, as indicated by the contrasting Kaplan-Meier estimates of phenogroup-level survival.

Treatment Effect Estimation

auton-survival offers additional tools to analyze the effect of an intervention on outcomes by computing propensity-adjusted treatment effects in terms of the following metrics through bootstrap resampling of the dataset with replacement:

  • Hazard Ratio: Assuming the proportional hazards assumption holds, the treatment effect can be measured as the ratio of hazard rates between the treatment and control arms.
  • Time at Risk (TaR) (Figure 7a): The treatment effect can be measured as the difference in time-to-event at a specified level of risk.
  • Risk at Time (Figure 7b): The treatment effect can be measured as the difference in risk at a specified time horizon.
  • Restricted Mean Survival Time (RMST) (Figure 7c): The treatment effect can be measured as the difference in the expected (mean) time-to-event restricted to a specified time horizon, i.e., the area under the survival curve up to that horizon.
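Two of these metrics can be read directly off a survival curve. The sketch below assumes step-function (e.g., Kaplan-Meier) curves given as event times and survival levels; the control curve is the Kaplan-Meier estimate for the Figure 1 patients, while the treated curve is a made-up example. The treatment effect is then the difference of each metric between arms.

```python
import numpy as np

def rmst(event_times, survival, horizon):
    """Restricted mean survival time: area under a step survival curve S(t) on [0, horizon]."""
    times = np.asarray(event_times, dtype=float)
    surv = np.asarray(survival, dtype=float)
    keep = times < horizon
    knots = np.concatenate(([0.0], times[keep], [horizon]))
    levels = np.concatenate(([1.0], surv[keep]))  # S(t) on each interval [knots[i], knots[i+1])
    return float(np.sum(levels * np.diff(knots)))

def risk_at_time(event_times, survival, horizon):
    """Risk at a time horizon: 1 - S(horizon) for a right-continuous step curve."""
    times = np.asarray(event_times, dtype=float)
    surv = np.asarray(survival, dtype=float)
    idx = np.searchsorted(times, horizon, side="right") - 1
    return 0.0 if idx < 0 else float(1.0 - surv[idx])

control = ([1.0, 4.0], [0.75, 0.0])   # Kaplan-Meier curve of the Figure 1 patients
treated = ([2.0, 5.0], [0.8, 0.4])    # hypothetical curve for a treated arm

rmst_effect = rmst(*treated, 5.0) - rmst(*control, 5.0)                  # 4.4 - 3.25 ≈ 1.15
risk_effect = risk_at_time(*treated, 3.0) - risk_at_time(*control, 3.0)  # 0.2 - 0.25 ≈ -0.05
```

A positive RMST difference means the treated arm lives longer on average within the horizon, while a negative risk difference means fewer events by the chosen time.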

Propensity adjustment provides an approach to estimating treatment effects in the presence of potential confounders that influence both treatment assignment and the outcome. Not adjusting for treatment propensity can result in misestimated treatment effects.

auton-survival allows adjusting for treatment propensity by computing treatment effects bootstrapped with sample weights. When the sample weights are derived from propensity scores, such as those obtained from a classification model, the bootstrapped distribution of the treatment effect converges to the Inverse Probability of Treatment Weighting (IPTW) Horvitz-Thompson estimate of the population Average Treatment Effect:

\( \mathbb{ATE}(\mathcal{D}^*, f) = \mathbb{E}_{x \sim \mathcal{D}^*}\big[\mathbb{E}[f_1(x) - f_0(x) \mid X = x]\big], \quad \mathcal{D}^* \sim \frac{1}{\widehat{\mathbb{P}}(A \mid X)} \cdot \mathbb{P}^*(\mathcal{D}) \)
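A minimal sketch of inverse propensity weighting, under simplifying assumptions of our own: a binary outcome (such as an event before a fixed horizon, ignoring censoring), the true propensity standing in for one fitted by a classifier, and simulated data in which treatment has no effect. `weighted_difference` computes the normalized (Hájek) form of the weighted estimate, and `weighted_bootstrap` resamples with probability proportional to the weights, mirroring the bootstrapping with sample weights described above; all function names are hypothetical.

```python
import numpy as np

def iptw_weights(a, propensity):
    """Inverse probability of treatment weights: 1 / P(A = a_i | X = x_i)."""
    return np.where(a == 1, 1.0 / propensity, 1.0 / (1.0 - propensity))

def weighted_difference(y, a, w):
    """Difference of weighted mean outcomes between treated (a = 1) and control (a = 0)."""
    return (np.average(y[a == 1], weights=w[a == 1])
            - np.average(y[a == 0], weights=w[a == 0]))

def weighted_bootstrap(y, a, w, n_boot=200, seed=0):
    """Resample with probabilities proportional to w; return the unweighted effect per resample."""
    rng = np.random.default_rng(seed)
    p = w / w.sum()
    effects = []
    for _ in range(n_boot):
        idx = rng.choice(len(y), size=len(y), replace=True, p=p)
        effects.append(y[idx][a[idx] == 1].mean() - y[idx][a[idx] == 0].mean())
    return np.array(effects)

# Simulated confounding: x raises both the chance of treatment and the outcome; A has no effect.
rng = np.random.default_rng(0)
n = 20000
x = rng.normal(size=n)
propensity = 1.0 / (1.0 + np.exp(-x))   # true P(A = 1 | X); in practice, fit a classifier
a = (rng.random(n) < propensity).astype(int)
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-2.0 * x))).astype(float)

naive = y[a == 1].mean() - y[a == 0].mean()
w = iptw_weights(a, propensity)
adjusted = weighted_difference(y, a, w)
boot = weighted_bootstrap(y, a, w)
```

The naive difference between arms is large even though the true effect is zero; the weighted estimate and the mean of the weighted bootstrap are both close to zero.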

In a second analysis of the effect of geographical region on breast cancer mortality using data from the SEER cancer registry (Ries et al., 1975), we compare treatment effects before and after adjusting for confounding factors by inverse propensity weighting. As in the previous analysis with counterfactual regression, we consider the regions of “Greater California” and “Louisiana” as the binary “treatment” in question. To adjust treatment effects for confounding factors, we first trained a logistic regression with an \( \ell_2 \) penalty, regressing the geographical region on the set of confounding variables. The estimated propensity scores are then employed as sampling weights for the treatment effects in terms of hazard ratios, restricted mean survival time (RMST), and risk difference, as in Figure 8. Adjusting for region propensity noticeably reduces the differences in treatment effects, indicating that the difference in breast cancer mortality between the regions is likely explained by confounding socio-economic and physiological factors rather than by geographic region alone.


We present auton-survival, an open-source Python package encapsulating multiple pipelines to work with censored time-to-event data. Such data is ubiquitous in many fields, including healthcare and the maintenance of equipment. Through continuous collaboration with the machine learning for healthcare community, we aim to better aid machine learning research in efforts to create a robust, comprehensive repository of rigorous tools for reproducible analysis of censored time-to-event data.


Chirag Nagpal
PhD Candidate, Auton Lab

Willa Potosnak
Research Intern and
Incoming PhD Student, Auton Lab


[1] Nagpal, C., Potosnak, W. and Dubrawski, A., 2022. auton-survival: an Open-Source Package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Event Data. arXiv preprint arXiv:2204.07276.

[2] Fabian Pedregosa, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, et al. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825–2830, 2011.

[3] D. R. Cox. Regression models and life-tables. Journal of the Royal Statistical Society. Series B (Methodological), 34(2):187–220, 1972.

[4] David Faraggi and Richard Simon. A neural network model for survival data. Statistics in medicine, 14(1):73–82, 1995.

[5] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology, 18(1):1–12, 2018.

[6] Nagpal, C., Li, X. and Dubrawski, A., 2021a. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. IEEE Journal of Biomedical and Health Informatics, 25(8), pp.3163-3175.

[7] Nagpal, C., Yadlowsky, S., Rostamzadeh, N. and Heller, K., 2021b, October. Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference (pp. 674-708). PMLR.

[8] H. Ishwaran, Udaya B. Kogalur, Eugene H. Blackstone, and Michael S. Lauer. Random survival forests. The Annals of Applied Statistics, 2(3), 2008.

[9] Junyoung Chung, Caglar Gulcehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.

[10] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.

[11] LAG Ries, D Melbert, M Krapcho, DG Stinchcomb, N Howlader, MJ Horner, A Mariotto, BA Miller, EJ Feuer, SF Altekruse, et al. SEER cancer statistics review, 1975–2005. Bethesda, MD: National Cancer Institute, 2999, 2008.

[12] Nagpal, C., Goswami, M., Dufendach, K. and Dubrawski, A., 2022. Counterfactual Phenotyping with Censored Time-to-Events. arXiv preprint arXiv:2202.11089.

[13] W. A. Knaus, F. E. Harrell, J. Lynn, et al. The SUPPORT prognostic model: Objective estimates of survival for seriously ill hospitalized adults. Annals of Internal Medicine, 122:191–203, 1995.


Does AutoML work for diverse tasks?

Over the past decade, machine learning (ML) has grown rapidly in both popularity and complexity. Driven by advances in deep neural networks, ML is now being applied far beyond its traditional domains like computer vision and text processing, with applications in areas as diverse as solving partial differential equations (PDEs), tracking credit card fraud, and predicting medical conditions from gene sequences. However, progress in such areas has often required expert-driven development of complex neural network architectures, expensive hyperparameter tuning, or both. Given that such resource-intensive iteration is expensive and inaccessible to most practitioners, AutoML has emerged with the overarching goal of enabling any team of ML developers to deploy ML on arbitrary new tasks. Here we ask about the current status of AutoML, namely: can available AutoML tools quickly and painlessly attain near-expert performance on diverse learning tasks?

This blog post is dedicated to two recent but related efforts that measure the field’s current effectiveness at achieving this goal: NAS-Bench-360 and the AutoML Decathlon. The first is a benchmark suite focusing on the burgeoning field of neural architecture search (NAS), which seeks to automate the development of neural network models. With evaluations on ten diverse tasks—including a precomputed tabular benchmark on three of them—NAS-Bench-360 is the first NAS testbed that goes beyond traditional AI domains such as vision, text, and audio signals. Specifically, the 10 tasks vary in their domain (including image, finance time series, audio, and natural sciences), problem type (including regression, single-label, and multi-label classification), and scale (ranging from several thousands to hundreds of thousands of observations).

The second is a NeurIPS 2022 competition (which we are soft-launching today!) that builds on our NAS-Bench-360 work yet has a broader vision of understanding what is truly the best approach for a practitioner to take when faced with a modern ML problem.  During the public development phase of the competition we will release a set of diverse tasks that will be representative of (but distinct from) the final set of test tasks on which evaluation will be performed. Unlike most past competitions in the AutoML community, competitors in the AutoML Decathlon are free (and in fact encouraged) to consider a wide range of approaches from traditional hyperparameter optimization and ensembling methods to modern techniques such as NAS and large-scale transfer learning.

You can learn more about getting involved with either of these efforts at the bottom of this post.

NAS-Bench-360: A NAS Benchmark for diverse tasks

NAS-Bench-360 is a benchmark suite consisting of ten ML tasks that we developed jointly with Renbo Tu, Nick Roberts, Junhong Shen, and Fred Sala. These tasks represent a diverse set of signals, including various kinds of imaging sources, simulation data, genomic data, and more. At the same time, we constrain all tasks to be amenable to modern NAS search spaces, i.e. we do not include tabular or graph-based data, thus allowing for the application of most NAS methods. Our evaluation on NAS-Bench-360 is thus a robustness test that checks whether the massive amount of largely computer vision-driven progress in the field of NAS is actually indicative of wider success of AutoML across a variety of applications, data types, and tasks. More importantly, the benchmark will serve as a useful tool to develop and evaluate new, better methods for NAS.

So can AutoML tools (specifically NAS methods) quickly and painlessly attain near-expert performance on NAS-Bench-360? In positive news, searching over a large search space such as DARTS using a state-of-the-art algorithm such as GAEA does yield models that outperform available expert architectures on half of the tasks, in addition to consistently beating perennial Kaggle favorite XGBoost and a recent attempt at a general-purpose architecture, Perceiver IO. On the other hand, it fails catastrophically on several tasks, doing little better than a simple baseline, namely a tuned Wide ResNet (Figure 1, left panel). Indeed, despite having been developed on CIFAR-10, it does surprisingly poorly on 2D classification tasks from the medical and audio domains. Furthermore, in a resource-constrained setting where AutoML methods are given little more time than it takes to train a single architecture, the leading NAS method DenseNAS does worse than an untuned Wide ResNet (Figure 1, right panel).

Figure 2:  Whereas high-performance architectures on vision datasets often perform well on other vision datasets (left), we use NAS-Bench-360 to show that this does not translate to diverse tasks (right).

Our evaluation of modern NAS methods on NAS-Bench-360 demonstrates the need for such a benchmark and a lack of robustness in the field. NAS-Bench-360 is also useful for understanding past and future search spaces and algorithms, specifically whether current beliefs about NAS extend to diverse tasks. For example, Figure 2 shows that high-performing architectures transfer well between vision tasks—a quality used extensively in NAS research—but not between diverse tasks. Other examples of scientific uses of NAS-Bench-360—such as one investigating a recent paper on operation redundancy—are provided in our paper and in a recent ICLR 2022 blog post on zero-cost proxies. We also expect NAS-Bench-360 to be used for the development of new NAS methods; to further this, for two of the datasets we provide precomputed models for all architectures in the NAS-Bench-201 search space; together with existing CIFAR-100 precompute results this means three NAS-Bench-360 datasets have precomputed tabular benchmarks to accelerate search algorithm development. 

The AutoML Decathlon: A competition focused on diverse tasks and methods

Our goal in releasing NAS-Bench-360 is to spur the development of NAS methods that work well on diverse tasks. However, given the mixed performance of NAS on this benchmark, there remains a question of whether automatic architecture design should even be the focus of AutoML research more broadly. Building on our efforts from NAS-Bench-360, a group of researchers at CMU, Hewlett Packard Enterprise (HPE), the University of Wisconsin-Madison, and Morgan Stanley are organizing the AutoML Decathlon competition at NeurIPS 2022 precisely to ask the following broader question: what automated technique(s) are best for diverse tasks?

This competition is designed to address two gaps between research and practice:

  1. Lack of task diversity. The field of NAS is no exception here, as the vast majority of recent AutoML benchmarking and competition efforts have focused on computer vision or other well studied tasks in speech and language processing. Evaluating AutoML methods on such well-studied tasks does not give a good indication of their utility on more far-afield applications.
  2. Siloed methodological development. Many developments in AutoML narrowly focus on particular techniques rather than the downstream benefits to the end user. A practitioner with a specific ML task ultimately cares about the quality of the resulting model (in terms of accuracy and other non-accuracy metrics), as opposed to the underlying technical details of the procedure yielding this model, e.g., whether the model is the result of a weight-sharing NAS method, a fine-tuned large model,  a more classical non-deep learning AutoML technique, or some other automated procedure.

By designing our competition in a practitioner-centric fashion and accounting for the two aforementioned gaps, our competition aims to spur innovation in AutoML with results that are directly transferable to ML practitioners. We envision that the results of our competition will provide novel empirical insights into several open practical and scientific questions, including:

  • Given the growing methodological diversity of (Auto)ML approaches, what methods should I consider as a practitioner in 2022?  
  • How do leading NAS methods compare to the increasingly popular pre-training/fine-tuning paradigm?
  • How do either of these more modern approaches compare to classical AutoML approaches or to standard baselines such as XGBoost or a tuned ResNet?
  • Should I consider using any AutoML procedure given that I’m working on a specific scientific, technological, or industrial problem that seemingly differs drastically from well-studied tasks in computer vision and NLP?  
  • Given a reasonable computational budget, can any AutoML approach (whether classical or more modern) consistently outperform bespoke models that were hand-crafted by domain experts and/or ML experts?
Figure 3: Summary of the AutoML Decathlon competition timeline. To ensure efficiency, the evaluation will be conducted under a fixed computational budget. To ensure robustness, the performance profile methodology described below will be used for determining the winners.

We note that while AutoML is not a new research area, we view our competition as being particularly timely given (1) rapid growth of ML task diversity, (2) progress in ML model development, and (3) acceleration in the scale of both datasets and available compute resources. Indeed, recent progress along these three dimensions has led us to make remarkably different design choices from those of past competitions like the AutoDL competition, which was launched just three years ago. For instance, we work with bigger datasets, allow larger computational budgets, consider an expanding set of applications, and perform more robust evaluations based on performance profiles. Relatedly, while over the past three years we’ve witnessed significant progress in NAS and the emergence of the pretrain/fine-tuning paradigm in various settings, neither of these types of approaches featured prominently in the AutoDL competition (or other past competitions). In contrast, we hypothesize that these approaches will be more prominently featured in the AutoML Decathlon.
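Performance profiles are not defined in this post; the sketch below assumes the standard Dolan-Moré definition, which may differ in detail (score normalization, range of \( \tau \), aggregation into a single ranking) from the competition's actual scoring. For each method, the profile at \( \tau \) is the fraction of tasks on which its error is within a factor \( \tau \) of the best method's error on that task. The error values are hypothetical.

```python
import numpy as np

def performance_profile(errors, taus):
    """Dolan-More performance profile.

    errors: (n_tasks, n_methods) array of strictly positive scores, lower is better.
    Returns rho of shape (len(taus), n_methods), where rho[i, m] is the fraction of tasks on
    which method m is within a factor taus[i] of the best method on that task.
    """
    errors = np.asarray(errors, dtype=float)
    ratios = errors / errors.min(axis=1, keepdims=True)  # 1.0 for the best method on each task
    taus = np.asarray(taus, dtype=float)
    return (ratios[None, :, :] <= taus[:, None, None]).mean(axis=1)

# Hypothetical test errors of two methods on three tasks.
errors = [[0.10, 0.20],
          [0.30, 0.15],
          [0.50, 0.50]]
rho = performance_profile(errors, taus=[1.0, 1.5, 2.0])
# rho -> [[2/3, 2/3], [2/3, 2/3], [1, 1]]
```

Here both methods win one task outright and tie on the third, so their profiles coincide; a method that is rarely best but never far off would rise quickly as \( \tau \) grows.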

The AutoML Decathlon is built around a set of 20 datasets that we have curated which represent a broad spectrum of practical applications in scientific, technological, and industrial domains. As explained in Figure 3, ten of the tasks will be used for development and an additional ten tasks will be used for final evaluation and revealed only after the competition. We will provide computational resources to participants as needed, with funding provided by Morgan Stanley. The results of our performance-profile based evaluation will determine monetary prizes, including a $15K first prize, with sponsorship provided by HPE.

Getting Involved: Using NAS-Bench-360 and competing in the AutoML Decathlon

Our goal with both NAS-Bench-360 and the AutoML Decathlon is to encourage community participation in evaluating what AutoML is already good at, what areas need improving, and what directions seem most promising for future work. We hope that these rigorous benchmarking activities will help the field move more rapidly towards a truly democratized ML toolkit that can be used by researchers and practitioners alike.

To learn more, check out the following links:

  • NAS-Bench-360: You can download the ten datasets on the website, and learn more about the benchmark and our various insights from our paper.
  • AutoML Decathlon: The competition officially starts next week and runs through mid-October, but we are soft-launching today to spread the word. You can learn more about the details at the competition website and the associated CodaLab website.

Also, stay tuned for a follow-up blog post where our collaborator Junhong Shen describes our recent algorithmic NAS work targeting diverse tasks.

Deep Attentive Variational Inference

Figure 1: Overview of a local variational layer (left) and an attentive variational layer (right) proposed in this post.

Generative models are a class of machine learning models that are able to generate novel data samples such as fictional celebrity faces, digital artwork, or scenery images. Currently, the most powerful generative models are deep probabilistic models. This class of models uses deep neural networks to express statistical hypotheses about the way in which the data have been generated. Latent variable models augment the set of the observed data with latent (unobserved) information in order to better characterize the procedure that generates the data of interest.

Despite these successes, deep generative modeling remains one of the most complex and expensive tasks in AI. Recent models rely on increased architectural depth to improve performance. However, as we show in our paper [1], the predictive gains diminish as depth increases. Keeping a Green-AI perspective in mind when designing such models could lead to their wider adoption in describing large-scale, complex phenomena.

A quick review of Deep Variational AutoEncoders

Latent variable models augment the set of the observed variables with auxiliary latent variables. They are characterized by a posterior distribution over the latent variables, one which is generally intractable and typically approximated by closed-form alternatives. Moreover, they provide an explicit parametric characterization of the joint distribution over the expanded random variable space. The generative and the inference portions of such a model are jointly trained. The Variational AutoEncoder (VAE) belongs to this model category. Figure 2 provides an overview of a VAE.

Figure 2: A Variational AutoEncoder consists of a generative model and an inference model. The generative model, or decoder, is defined by a joint distribution of latent and observed variables. The inference model, or encoder, approximates the true posterior of the latent variables given the observations. The two parts are jointly trained.

VAEs are trained by maximizing the Evidence Lower BOund (ELBO), a tractable lower bound on the marginal log-likelihood:

$$\log p(x) \ge \mathbb{E}_{q(z\mid x)}\left[\log p(x\mid z)\right] - D_{KL}\left(q(z\mid x) \,\|\, p(z)\right).$$
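To make the bound concrete, here is a minimal sketch that estimates the ELBO for a VAE with a diagonal-Gaussian encoder q(z|x) and a standard-normal prior p(z): the KL term is computed in closed form and the reconstruction term by Monte Carlo over reparameterized samples. The encode and log_likelihood callables and the one-dimensional linear-Gaussian toy model are assumptions for illustration, not the architecture used in the post. Because the toy model's exact posterior is known, the estimate should match \(\log p(x)\) up to Monte Carlo error.

```python
import numpy as np

def kl_diag_gaussian_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def elbo_estimate(x, encode, log_likelihood, n_samples=1000, rng=None):
    """Monte Carlo estimate of E_q[log p(x|z)] - KL(q(z|x) || p(z)) with p(z) = N(0, I).

    encode(x) -> (mu, log_var) of a diagonal-Gaussian q(z|x)      (assumed interface)
    log_likelihood(x, z) -> log p(x|z) for a batch of z samples   (assumed interface)
    """
    rng = np.random.default_rng() if rng is None else rng
    mu, log_var = encode(x)
    eps = rng.standard_normal((n_samples, mu.shape[-1]))
    z = mu + np.exp(0.5 * log_var) * eps            # reparameterized samples from q(z|x)
    expected_log_lik = np.mean(log_likelihood(x, z))
    return expected_log_lik - kl_diag_gaussian_to_standard_normal(mu, log_var)

# Hypothetical toy model: p(z) = N(0, 1), p(x|z) = N(z, 1), hence p(x) = N(0, 2) and the
# exact posterior is N(x/2, 1/2). With the exact posterior the bound is tight.
x = np.array([1.0])
encode = lambda x: (x / 2.0, np.log(np.array([0.5])))
log_lik = lambda x, z: np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * (x - z) ** 2, axis=-1)
elbo = elbo_estimate(x, encode, log_lik, n_samples=200_000, rng=np.random.default_rng(0))
exact_log_px = -0.5 * np.log(2 * np.pi * 2.0) - x[0] ** 2 / (2 * 2.0)
print(elbo, exact_log_px)
```

With an approximate (non-exact) posterior, the same estimator would return a value below \(\log p(x)\) on average, the gap being the KL divergence from q(z|x) to the true posterior.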
Figure 3: Overview of a hierarchical VAE.

The most powerful VAEs introduce large latent spaces \(z\) that are organized in blocks such that \(z = \{z_1, z_2, \dots, z_L\}\), with each block being generated by a layer in a hierarchy. Figure 3 illustrates a typical architecture of a hierarchical VAE. Most state-of-the-art VAEs correspond to a fully connected probabilistic graphical model. More formally, the prior distribution follows the factorization:

$$p(z) = p(z_1) \prod_{l=2}^{L} p(z_l \mid z_{<l}). \quad \text{(1)}$$

In words, \(z_l\) depends on all previous latent factors \(z_{<l}\). Similarly, the posterior distribution is given by:

$$q(z\mid x) = q(z_1 \mid x) \prod_{l=2}^{L} q(z_l \mid x, z_{<l}). \quad \text{(2)}$$

The long-range conditional dependencies are implicitly enforced via deterministic features that are mixed with the latent variables and are propagated through the hierarchy. Concretely, each layer \(l\) is responsible for providing the next layer with a latent sample \(z_l\) along with context information \(c_l\):

$$c_l \leftarrow T_l\left(z_{l-1} \oplus c_{l-1}\right). \quad \text{(3)}$$

In a convolutional VAE, \(T_l\) is a non-linear transformation implemented by ResNet blocks as shown in Figure 1. The operator \(\oplus\) combines two branches in the network. Due to its recursive definition, \(c_l\) is a function of \(z_{<l}\).
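The recursion in Equation (3) is what lets a single context vector carry the conditioning set \(z_{<l}\) of Equation (1). The sketch below illustrates this with ancestral sampling from a tiny hierarchical prior. It is a toy under explicit assumptions: a one-layer tanh map stands in for the ResNet blocks \(T_l\), the operator \(\oplus\) is taken to be concatenation, the initial context \(c_1\) is all zeros, and all sizes and weights are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
L, Z_DIM, C_DIM = 4, 2, 8            # hypothetical sizes

# Hypothetical parameters: a one-layer tanh map stands in for the ResNet blocks T_l,
# and the combine operator (+) is taken to be concatenation (both are assumptions).
W_T = [0.3 * rng.standard_normal((C_DIM, Z_DIM + C_DIM)) for _ in range(L)]
W_prior = [0.3 * rng.standard_normal((2 * Z_DIM, C_DIM)) for _ in range(L)]

def T(l, z_prev, c_prev):
    """Eq. (3): c_l <- T_l(z_{l-1} (+) c_{l-1})."""
    return np.tanh(W_T[l] @ np.concatenate([z_prev, c_prev]))

def contexts(zs):
    """Contexts c_1..c_L for latents z_1..z_L; by construction c_l only sees z_{<l}."""
    cs = [np.zeros(C_DIM)]           # c_1: hypothetical all-zero initial context
    for l in range(1, len(zs)):
        cs.append(T(l, zs[l - 1], cs[-1]))
    return cs

def sample_prior():
    """Ancestral sampling from Eq. (1): z_1 ~ N(0, I), then z_l ~ p(z_l | z_{<l}) read off c_l."""
    zs = [rng.standard_normal(Z_DIM)]
    c = np.zeros(C_DIM)
    for l in range(1, L):
        c = T(l, zs[-1], c)
        mu, log_var = np.split(W_prior[l] @ c, 2)
        zs.append(mu + np.exp(0.5 * log_var) * rng.standard_normal(Z_DIM))
    return zs

zs = sample_prior()
cs = contexts(zs)
```

Perturbing \(z_1\) changes the last context \(c_L\) only indirectly, through every intermediate \(T_l\); this is the long chain of local updates that the next section argues can wash out early-layer information.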

Deep Variational AutoEncoders are “overthinking”

Recent models such as NVAE [2] rely on increased depth to improve performance and deliver results comparable to those of purely generative, autoregressive models, while permitting fast sampling that requires a single network evaluation. However, as we show in our paper and in Table 1, the predictive gains diminish as depth increases. After some point, even doubling the number of layers yields only a slight increase in the marginal likelihood.

Depth (L) | -log p(x) (bits/dim) | Relative decrease (%)
--------- | -------------------- | ---------------------
2         | 3.5                  | –
4         | 3.26                 | -6.8
8         | 3.06                 | -6.1
16        | 2.96                 | -3.2
30        | 2.91                 | -1.7

Table 1: Deep VAEs suffer from diminishing returns. \(-\log p(x)\) in bits per dimension and its relative decrease for a varying number of variational layers \(L\).
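The last column of Table 1 is the percentage change in bits/dim relative to the previous depth. Recomputing it from the rounded bits/dim values, as in the short sketch below, reproduces the table to within about 0.1 percentage point; the small mismatches are presumably due to rounding in the published bits/dim values.

```python
# bits/dim values copied from Table 1 (depth L -> -log p(x) in bits per dimension)
bits_per_dim = {2: 3.5, 4: 3.26, 8: 3.06, 16: 2.96, 30: 2.91}

relative_change = {}
depths = sorted(bits_per_dim)
for prev, cur in zip(depths, depths[1:]):
    # Percentage change in bits/dim when going from `prev` to `cur` layers
    relative_change[cur] = 100.0 * (bits_per_dim[cur] - bits_per_dim[prev]) / bits_per_dim[prev]
    print(f"L = {prev:>2} -> {cur:>2}: {relative_change[cur]:+.1f}%")
```

Going from 16 to 30 layers nearly doubles depth yet improves the likelihood by less than 2%, which is the diminishing-returns pattern described above.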

We argue that this may be because the effect of the latent variables of earlier layers diminishes as the context feature \(c_l\) traverses the hierarchy and is updated with latent information from subsequent layers. In turn, this means that in practice the network may no longer respect the factorization of the variational distributions of Equations (1) and (2), leading to sub-optimal performance. Concretely, large portions of early blocks \(z_l\) collapse to their prior counterparts and therefore no longer contribute to inference.

This phenomenon can be attributed to the local connectivity of the layers in the hierarchy, as shown in Figure 4.a. In fact, a layer is directly connected only with the adjacent layers in a deep VAE, limiting long-range conditional dependencies between \(z_l\) and \(z_{\ll l}\) as depth increases.

The flexibility of the prior \(p(z)\) and the posterior \(q(z \mid x)\) can be improved by designing more informative representations for the conditioning factors of the conditional distributions \(p(z_l \mid z_{<l})\) and \(q(z_l \mid x, z_{<l})\). This can be accomplished by designing a hierarchy of densely connected stochastic layers that dynamically learn to attend to latent and observed information most critical to inference. A high-level description of this idea is illustrated in Figure 4.b.

Figure 4: (a) Locally connected variational layer. (b) Strongly connected variational layer.

In the following sections, we describe the technical tool that allows our model to realize the strong couplings presented in Figure 4.b.

Problem: Handling long sequences of large 3D tensors

In deep convolutional architectures, we usually need to handle long sequences of large 3D context tensors. A typical sequence is shown in Figure 5. The problem of effectively constructing strong couplings between the current and previous layers of a deep architecture can be formulated as follows:

Figure 5: Sequence of 3D tensors in a convolutional architecture.

Problem definition: Given a sequence \(c_{<l} = \{c_m\}_{m=1}^{l-1}\) of \(l-1\) contexts \(c_m\) with \(c_m \in \mathbb{R}^{H \times W \times C}\), we need to construct a single context \(\hat{c}_l \in \mathbb{R}^{H \times W \times C}\) that summarizes the information in \(c_{<l}\) that is most critical to the task.

In our framework, the task of interest is the construction of posterior and prior beliefs. Accordingly, the contexts \(\hat{c}^q_l\) and \(\hat{c}^p_l\) represent the conditioning factors of the posterior and prior distributions of layer \(l\).

There are two ways to view a long sequence of \(l-1\) large \(H \times W \times C\)-dimensional contexts:

  • Inter-Layer couplings: As \(H \times W\) independent pixel sequences of \(C\)-dimensional features of length \(l-1\). One such sequence is highlighted in Figure 5.
  • Intra-Layer couplings: As \(l-1\) independent pixel sequences of \(C\)-dimensional features of length \(H \times W\).

This observation leads to a factorized attention scheme that identifies important long-range, inter-layer, and intra-layer dependencies separately. Such decomposition of large and long pixel sequences leads to significantly less compute.
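The two views are just different reshapes of the same stack of contexts. The sketch below builds both views with NumPy and, as a rough cost proxy, counts the query-key pairs scored by one joint attention over all \((l-1) \cdot H \cdot W\) positions versus the two factorized attentions. Pair counting is an assumption made here for illustration; it ignores feature dimensions, key sizes, and the causal masking over depth, so the actual savings reported in the paper may differ. The sizes in the usage example are hypothetical.

```python
import numpy as np

def factorized_views(contexts):
    """contexts: array of shape (l-1, H, W, C) holding c_1..c_{l-1}."""
    n_layers, H, W, C = contexts.shape
    # Inter-layer view: H*W independent sequences of length l-1 (one per pixel).
    inter = contexts.reshape(n_layers, H * W, C).transpose(1, 0, 2)   # (H*W, l-1, C)
    # Intra-layer view: l-1 independent sequences of length H*W (one per layer).
    intra = contexts.reshape(n_layers, H * W, C)                      # (l-1, H*W, C)
    return inter, intra

def attention_pairs(n_layers, H, W):
    """Query-key pairs scored by joint vs. factorized attention (a rough cost proxy)."""
    n_pixels = H * W
    joint = (n_layers * n_pixels) ** 2
    factorized = n_pixels * n_layers**2 + n_layers * n_pixels**2
    return joint, factorized

# Hypothetical sizes: 15 previous layers of 16x16 feature maps with 32 channels.
ctx = np.random.default_rng(0).standard_normal((15, 16, 16, 32))
inter, intra = factorized_views(ctx)
joint, factorized = attention_pairs(15, 16, 16)
print(inter.shape, intra.shape, joint, factorized, joint / factorized)
```

Under this proxy the factorized scheme scores roughly \(l-1\) times fewer pairs than joint attention, which is one way to see why the decomposition is cheaper.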

Inter-Layer couplings: Depth-wise Attention

The network relies on a depth-wise attention scheme to discover inter-layer dependencies. The task is characterized by a query feature \(s\). During this phase, the pixel sequences correspond to instances of a pixel at the previous layers in the architecture. They are processed concurrently and independently from the rest. The contexts are represented by key features \(k\) of a lower dimension. The final context is computed as a weighted sum of the contexts according to an attention distribution. The mechanism is explained in Figure 6.

Figure 6: Explanation of depth-wise attention in convolutional architectures.

The layers in the variational hierarchy are augmented with two depth-wise attention blocks for constructing the contexts of the prior and posterior distributions. Figure 1 displays the computational block of an attentive variational layer. As shown in Figure 6, each layer also needs to emit attention-relevant features: the keys \(k_l\) and queries \(s_l\), along with the contexts \(c_l\). Equation (3) is revised for the attention-driven path in the decoder such that the context, its key, and the query are jointly learned:

$$[c_l, s_l, k_l] \leftarrow T_l\left(z_{l-1} \oplus c_{l-1}\right). \quad \text{(4)}$$

A formal description, along with normalization schemes, is provided in our paper.
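The following sketch shows the inter-layer step in isolation: at every pixel, the query \(s_l\) of the current layer is scored against the keys \(k_1, \dots, k_{l-1}\) of the previous layers, a softmax over depth turns the scores into an attention distribution, and the new context \(\hat{c}_l\) is the weighted sum of \(c_1, \dots, c_{l-1}\). Scaled dot-product scoring is an assumption borrowed from standard attention; the normalization schemes actually used are described in the paper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def depthwise_attention(query, keys, contexts):
    """Summarize the previous layers' contexts independently at every pixel.

    query:    (H, W, D)        s_l, emitted by the current layer
    keys:     (l-1, H, W, D)   k_1..k_{l-1}, lower-dimensional than the contexts
    contexts: (l-1, H, W, C)   c_1..c_{l-1}
    returns:  (H, W, C)        hat{c}_l
    """
    d = query.shape[-1]
    # Score each earlier layer m at each pixel (h, w); pixels never interact here.
    scores = np.einsum("hwd,mhwd->mhw", query, keys) / np.sqrt(d)
    weights = softmax(scores, axis=0)          # attention distribution over depth m
    return np.einsum("mhw,mhwc->hwc", weights, contexts)

rng = np.random.default_rng(0)
H, W, D, C, n_prev = 4, 4, 4, 8, 5             # hypothetical sizes
contexts = rng.standard_normal((n_prev, H, W, C))
c_hat = depthwise_attention(rng.standard_normal((H, W, D)),
                            rng.standard_normal((n_prev, H, W, D)), contexts)
```

Because the softmax is taken over depth separately at each pixel, a layer can draw directly on a context produced many layers earlier instead of relying on it surviving the chain of local updates.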

Intra-Layer couplings: Non-local blocks

Intra-layer dependencies can be leveraged by interleaving non-local blocks [3] with the convolutions in the ResNet blocks of the architecture, also shown in Figure 1.
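For reference, a non-local block in its embedded-Gaussian form [3] lets every pixel of one feature map aggregate information from every other pixel: \(y_i = \sum_j \operatorname{softmax}_j\left(\theta(x_i)^\top \phi(x_j)\right) g(x_j)\), followed by a residual output \(W_z y_i + x_i\). The sketch below implements that form on a single \(H \times W \times C\) map, writing the 1x1 convolutions as matrix products over channels. How the blocks are placed inside the ResNet blocks, and any normalization layers, are not modeled here; the weights and sizes are hypothetical.

```python
import numpy as np

def non_local_block(x, W_theta, W_phi, W_g, W_z):
    """Embedded-Gaussian non-local block on one feature map.

    x: (H, W, C). W_theta, W_phi, W_g: (C, C_inner). W_z: (C_inner, C).
    1x1 convolutions are written as matrix products over the channel axis.
    """
    H, W, C = x.shape
    flat = x.reshape(H * W, C)                     # all H*W pixels as one sequence
    theta, phi, g = flat @ W_theta, flat @ W_phi, flat @ W_g
    scores = theta @ phi.T                         # (H*W, H*W) pairwise affinities
    scores -= scores.max(axis=1, keepdims=True)    # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=1, keepdims=True)        # softmax over all positions j
    y = attn @ g                                   # every pixel aggregates every other pixel
    return (y @ W_z).reshape(H, W, C) + x          # residual connection

rng = np.random.default_rng(0)
H, W, C, C_inner = 4, 4, 3, 2                      # hypothetical sizes
x = rng.standard_normal((H, W, C))
weights = [rng.standard_normal((C, C_inner)) for _ in range(3)] + [rng.standard_normal((C_inner, C))]
out = non_local_block(x, *weights)
```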


We evaluate Attentive VAEs on several public benchmark datasets of both binary and natural images. In Table 2, we show the performance and training time of state-of-the-art deep VAEs on CIFAR-10, a dataset of 32×32 natural images. Attentive VAEs achieve state-of-the-art likelihoods compared to other deep VAEs. More importantly, they do so with significantly fewer layers, and fewer layers mean decreased training and sampling time.

Model                          | Layers | Training time (GPU hours) | -log p(x) (bits/dim)
------------------------------ | ------ | ------------------------- | --------------------
Attentive VAE, 400 epochs [1]  | 16     | 272                       | 2.82
Attentive VAE, 500 epochs [1]  | 16     | 336                       | 2.81
Attentive VAE, 900 epochs [1]  | 16     | 608                       | 2.79
NVAE [2]                       | 30     | 440                       | 2.91
Very Deep VAE [4]              | 45     | 288                       | 2.87

Table 2: Comparison of performance and computational requirements of deep state-of-the-art VAE models.

In Figures 8 and 9, we show reconstructed and novel images generated by an attentive VAE.

Figure 8: Original & Reconstructed CIFAR-10 images.
Figure 9: Uncurated fantasy CIFAR-10 images.

The reason behind this improvement is that the attention-driven, long-range connections between layers lead to better utilization of the latent space. In Figure 7, we visualize the KL divergence per layer during training. As we see in (b), the KL penalty is evenly distributed among layers. In contrast, as shown in (a), the upper layers in a local, deep VAE are significantly less active. This confirms our hypothesis that the fully connected factorizations of Equations (1) and (2) may not be supported by local models. An attentive VAE, on the other hand, dynamically prioritizes the statistical dependencies between latent variables that are most critical to inference.

Figure 7: KL visualization in (a) a local and (b) an attentive VAE.
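The per-layer KL shown in Figure 7 is the quantity one can monitor to spot layers whose posterior has collapsed onto its prior. The sketch below computes the closed-form KL between diagonal-Gaussian posteriors and priors for each layer, averages it over a batch, and flags layers below a threshold. The threshold value and the two toy layers are hypothetical; the post does not specify a collapse criterion.

```python
import numpy as np

def diag_gaussian_kl(mu_q, log_var_q, mu_p, log_var_p):
    """KL(q || p) between diagonal Gaussians, summed over the last axis."""
    return 0.5 * np.sum(
        log_var_p - log_var_q
        + (np.exp(log_var_q) + (mu_q - mu_p) ** 2) / np.exp(log_var_p)
        - 1.0,
        axis=-1,
    )

def kl_per_layer(posteriors, priors, collapse_threshold=1e-2):
    """posteriors, priors: lists of (mu, log_var) per layer, each of shape (batch, dim).

    Returns the batch-averaged KL of every layer and a boolean flag for layers whose
    posterior (nearly) matches the prior. The threshold is an arbitrary choice.
    """
    kls = np.array([diag_gaussian_kl(*q, *p).mean() for q, p in zip(posteriors, priors)])
    return kls, kls < collapse_threshold

# Hypothetical two-layer example: layer 0 has collapsed onto its prior, layer 1 has not.
posteriors = [(np.zeros((5, 3)), np.zeros((5, 3))),
              (0.5 * np.ones((5, 3)), np.log(0.5) * np.ones((5, 3)))]
priors = [(np.zeros((5, 3)), np.zeros((5, 3)))] * 2
kls, collapsed = kl_per_layer(posteriors, priors)
print(kls, collapsed)
```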

Finally, attention-guided VAEs close the gap in the performance between variational models and expensive, autoregressive models. Comprehensive comparisons, quantitative and qualitative results are provided in our paper.


The expressivity of current deep probabilistic models can be improved by selectively prioritizing statistical dependencies between latent variables that are potentially distant from each other. Attention mechanisms can be leveraged to build more expressive variational distributions in deep probabilistic models by explicitly modeling both nearby and distant interactions in the latent space. Attentive inference reduces the computational footprint by alleviating the need for deep hierarchies.


A special word of thanks is due to Christos Louizos for helpful pointers to prior works on VAEs, Katerina Fragkiadaki for helpful discussions on generative models and attention mechanisms for computer vision tasks, Andrej Risteski for insightful conversations on approximate inference, and Jeremy Cohen for his remarks on a late draft of this work. Moreover, we are very grateful to Radium Cloud for granting us access to computing infrastructure that enabled us to scale up our experiments. We also thank the International Society for Bayesian Analysis (ISBA) for the travel grant and the invitation to present our work as a contributed talk at the 2022 ISBA World Meeting. This material is based upon work supported by the Defense Advanced Research Projects Agency under award number FA8750-17-2-0130, and by the National Science Foundation under grant number 2038612. Moreover, the first author acknowledges support from the Alexander Onassis Foundation and from A. G. Leventis Foundation. The second author is supported by the National Science Foundation Graduate Research Fellowship Program under Grant No. DGE1745016 and DGE2140739.

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.


[1] Apostolopoulou I, Char I, Rosenfeld E, Dubrawski A. Deep Attentive Variational Inference. In International Conference on Learning Representations 2021 Sep 29.

[2] Vahdat A, Kautz J. NVAE: A deep hierarchical variational autoencoder. Advances in Neural Information Processing Systems. 2020;33:19667-79.

[3] Wang X, Girshick R, Gupta A, He K. Non-local neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition 2018 (pp. 7794-7803).

[4] Child R. Very deep VAEs generalize autoregressive models and can outperform them on images. arXiv preprint arXiv:2011.10650. 2020 Nov 20.

An Experimental Design Perspective on Model-Based Reinforcement Learning

Reinforcement learning (RL) has achieved astonishing successes in domains where the environment is easy to simulate. For example, in games like Go or those in the Atari library, agents can play millions of games in the course of days to explore the environment and find superhuman policies [1]. However, transferring these advances to broader real-world applications is challenging because the cost of exploration in many important domains is high. For example, while RL-based solutions for controlling plasmas in nuclear fusion power generation are promising, there is only one operating tokamak in the United States and its resources are in excessive demand. The most data-efficient RL algorithms typically require thousands of samples to solve even moderately complex problems [2, 3], which is infeasible in plasma control and many other applications.

In contrast to conventional machine learning settings where the data is given to the decision maker, an RL agent can choose data to learn from. A natural idea for reducing data requirements is to choose data wisely such that a smaller amount of data is sufficient to perform well on a task. In this post, we describe a practical implementation of this idea. Specifically, we offer an answer to the following question: "If we were to collect one additional datapoint from anywhere in the state-action space to best improve our solution to the task, which one would it be?" This question is related to a more fundamental idea in the design of intelligent agents with limited resources: such agents should be able to understand what information about the world is the most useful to help them accomplish their task. We see this work as a small step towards this bigger goal.

In our recent ICLR paper, An Experimental Design Perspective on Model-Based Reinforcement Learning, we derive an acquisition function that guides an agent in choosing data for the most successful learning. In doing this, we draw a connection between model-based reinforcement learning and Bayesian optimal experimental design (BOED) and evaluate data prospectively in the context of the task reward function and the current uncertainty about the dynamics. Our approach can be efficiently implemented under a conventional assumption of a Gaussian Process (GP) prior on the dynamics function. Typically in BOED, acquisition functions are used to sequentially design experiments that are maximally informative about some quantity of interest by repeatedly choosing the maximizer, running the experiment, and recomputing the acquisition function with the new data. Generalizing this procedure, we propose a simple algorithm that is able to solve a wide variety of control tasks, often using orders of magnitude less data than competitor methods to reach similar asymptotic performance.


In this work, we consider an RL agent that operates in an environment with unknown dynamics. This is a general RL model for decision-making problems in which the agent starts without any knowledge of how its actions impact the world. The agent can then query different state-action pairs to explore the environment and find a behavior policy that results in the best reward. For example, in the plasma control task, the states are various physical configurations of the plasma and possible actions include injecting power and changing the current. The agent does not have prior knowledge of how its actions impact the conditions of the plasma. Thus, it needs to quickly explore the space to ensure efficient and safe operation of the physical system, a requirement captured in the corresponding reward function.

Given that each observation of a state-action pair is costly, an agent needs to query as few state-action pairs as possible. In this work, we develop an algorithm that tells the agent which queries to make.

We operate under a setting we call transition query reinforcement learning (TQRL). In this setting, an agent can sequentially query the dynamics at arbitrary states and actions to learn a good policy, essentially teleporting between states as it wishes. By contrast, in the traditional rollout setting, agents must choose actions and execute entire episodes to collect data. TQRL is therefore a slightly more informative form of access to the real environment.
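The difference between the two kinds of access can be illustrated with a toy environment. In the minimal sketch below, the ToyEnv class, its one-dimensional dynamics \(s' = s + a + \text{noise}\), and the start distribution are all hypothetical. A TQRL query returns one transition from any chosen \((s, a)\), while a rollout must start from \(p_0\) and can only observe transitions along the path its policy actually takes.

```python
import numpy as np

class ToyEnv:
    """Hypothetical 1-D environment with dynamics s' = s + a + noise (unknown to the agent)."""

    def __init__(self, horizon=10, seed=0):
        self.horizon = horizon
        self.rng = np.random.default_rng(seed)

    def _step(self, s, a):
        return s + a + 0.01 * self.rng.standard_normal()

    def transition_query(self, s, a):
        """TQRL access: observe s' ~ T(s, a) at ANY (s, a), i.e. 'teleport' there."""
        return self._step(s, a)

    def rollout(self, policy):
        """Rollout access: start from s_0 ~ p_0 and follow the policy for a whole episode."""
        s = self.rng.normal(0.0, 0.1)            # hypothetical start distribution p_0
        transitions = []
        for _ in range(self.horizon):
            a = policy(s)
            s_next = self._step(s, a)
            transitions.append((s, a, s_next))
            s = s_next
        return transitions

env = ToyEnv()
s_next = env.transition_query(5.0, -1.0)   # one query, anywhere in the state space
episode = env.rollout(lambda s: 0.5)       # H transitions, each starting where the last ended
```

Reaching the state \(s = 5\) under rollout access would require a policy that drives the system there first; under TQRL it costs a single query.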

Precise Definition of the Setup

More precisely: we address finite-horizon, discrete-time Markov Decision Processes (MDPs), which consist of a tuple \(\langle \mathcal{S}, \mathcal{A}, T, r, p_0, H \rangle\) where:

  • \(\mathcal{S}\) is the state space
  • \(\mathcal{A}\) is the action space
  • \(T\) (dynamics) is the stochastic transition function that maps (state, action) pairs \(\mathcal{S} \times \mathcal{A}\) to a probability distribution over states \(\mathcal{S}\)
  • \(r: \mathcal{S} \times \mathcal{A} \to \mathbb{R}\) is a reward function
  • \(p_0\) is a start distribution over states \(\mathcal{S}\)
  • \(H\) is an integer-valued horizon, that is, the number of steps the agent will perform in the environment

We assume that all of these parameters are known except the dynamics \(T\). The key quantity that defines the behaviour of the agent is its policy \(\pi: \mathcal{S} \to \mathcal{A}\), which tells the agent what action to take in a given state. Thus, the overall goal of the agent is to find a policy that maximizes the cumulative reward over its trajectory \(\tau \sim p(\tau \mid \pi, T)\). Formally, a trajectory is simply a sequence of (state, action) pairs \(\tau = [s_0, a_0, \dots, a_{H-1}, s_H]\), where \(a_i = \pi(s_i)\) is the action taken by the agent at step \(i\) and \(s_i \sim T(s_{i-1}, a_{i-1})\) is the state the agent is in at time \(i\). Denoting the cumulative reward over trajectory \(\tau\) as \(R(\tau)\), the agent needs to solve the following optimization problem:

$$\max_\pi \mathbb{E}_{\tau \sim p(\tau\mid \pi, T)}\left[R(\tau)\right].$$
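The objective above can be evaluated for any fixed policy by sampling trajectories and averaging their returns. The sketch below encodes the MDP tuple as a small dataclass with scalar states and actions, samples \(\tau = [s_0, a_0, \dots, a_{H-1}, s_H]\) exactly as defined in the text, and forms a Monte Carlo estimate of \(\mathbb{E}_{\tau}[R(\tau)]\). Scalar states and actions, and \(R(\tau)\) as an undiscounted sum of \(r(s_i, a_i)\), are assumptions; the class and function names are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np

@dataclass
class FiniteHorizonMDP:
    """<S, A, T, r, p_0, H> with scalar states and actions (a simplifying assumption)."""
    transition: Callable[[float, float, np.random.Generator], float]   # s' ~ T(s, a)
    reward: Callable[[float, float], float]                            # r(s, a)
    sample_start: Callable[[np.random.Generator], float]               # s_0 ~ p_0
    horizon: int                                                       # H

def sample_trajectory(mdp, policy, rng):
    """tau = [s_0, a_0, ..., a_{H-1}, s_H] with a_i = pi(s_i) and s_{i+1} ~ T(s_i, a_i)."""
    s = mdp.sample_start(rng)
    states, actions = [s], []
    for _ in range(mdp.horizon):
        a = policy(s)
        s = mdp.transition(s, a, rng)
        actions.append(a)
        states.append(s)
    return states, actions

def cumulative_reward(mdp, states, actions):
    """R(tau), assumed here to be the undiscounted sum of r(s_i, a_i) for i < H."""
    return sum(mdp.reward(s, a) for s, a in zip(states[:-1], actions))

def expected_return(mdp, policy, n_samples=1000, seed=0):
    """Monte Carlo estimate of E_{tau ~ p(tau | pi, T)}[R(tau)]."""
    rng = np.random.default_rng(seed)
    return np.mean([cumulative_reward(mdp, *sample_trajectory(mdp, policy, rng))
                    for _ in range(n_samples)])

# Hypothetical toy MDP: noisy integrator, reward for staying near s = 1.
toy = FiniteHorizonMDP(
    transition=lambda s, a, rng: s + a + 0.05 * rng.standard_normal(),
    reward=lambda s, a: -(s - 1.0) ** 2,
    sample_start=lambda rng: 0.0,
    horizon=5,
)
print(expected_return(toy, lambda s: 0.0), expected_return(toy, lambda s: 1.0 - s))
```

In BARL's setting the difficulty is that the true transition function is unknown, so expectations like this one can only be computed under a learned model \(\hat{T}\) or under samples of it.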

We denote an optimal policy for given dynamics \(T\) by \(\pi^*_T\). Since the other parts of the MDP are known, solving the optimization problem requires learning a model \(\hat{T}\) of the transition function.

Main Idea

Inspired by BOED and Bayesian algorithm execution (BAX) [11], we use ideas from information theory to motivate our method to effectively choose data points. Our goal is to sequentially choose queries \((s, a)\) such that our agent quickly finds a good policy. We observe that to perform the task successfully, we do not need to approximate the optimal policy \(\pi^*\) everywhere in the state space. Indeed, there could be regions of the state space that are unlikely to be visited by the optimal policy. Thus, we only need to approximate the optimal policy in the regions of the state space that are visited by the optimal policy.

Therefore, we choose to learn about \(\tau^*\), the optimal trajectory governed by the optimal policy \(\pi^*\). This objective only requires data about the areas we believe \(\pi^*\) will visit as it solves the task, so intuitively we should not "waste" samples on irrelevant regions in the state-action space. In plasma control, this idea might look like designing experiments in certain areas of the state and action space that will teach us the most about controlling plasma in the target regimes we need to maintain fusion.

We thus define our acquisition function to be the expected information gain about \(\tau^*\) from sampling a point \(T(s, a)\) given a dataset \(D\):

$$\operatorname{EIG}_{\tau^*}(s, a) = \mathbb{H}[\tau^* \mid D] - \mathbb{E}_{s' \sim T(s, a)}\left[\mathbb{H}[\tau^* \mid D \cup \{(s, a, s')\}]\right].$$

Intuitively, this quantity measures how much the additional data is expected to reduce the uncertainty (here given by Shannon entropy, denoted \(\mathbb{H}\)) about the optimal trajectory.
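The expected entropy reduction defined above is a mutual information, and mutual information is symmetric: the expected reduction in uncertainty about \(\tau^*\) from observing \(s'\) equals the expected reduction in uncertainty about \(s'\) from learning \(\tau^*\). The post relies on this symmetry when it rewrites the acquisition function in terms of the entropy of \(T(s, a)\). The tiny discrete check below computes both forms from a hypothetical joint distribution table (entropies in nats); it is only a numerical illustration of the identity, not part of BARL.

```python
import numpy as np

def entropy(p):
    """Shannon entropy in nats, ignoring zero-probability entries."""
    p = p[p > 0]
    return -np.sum(p * np.log(p))

def eig_two_ways(joint):
    """joint[i, j] = P(tau* = i, s' = j) for a hypothetical discrete toy problem.

    Returns (H[tau*] - E_{s'} H[tau* | s'],  H[s'] - E_{tau*} H[s' | tau*]).
    """
    p_tau, p_obs = joint.sum(axis=1), joint.sum(axis=0)
    h_tau_given_obs = sum(p_obs[j] * entropy(joint[:, j] / p_obs[j])
                          for j in range(joint.shape[1]) if p_obs[j] > 0)
    h_obs_given_tau = sum(p_tau[i] * entropy(joint[i, :] / p_tau[i])
                          for i in range(joint.shape[0]) if p_tau[i] > 0)
    return entropy(p_tau) - h_tau_given_obs, entropy(p_obs) - h_obs_given_tau

mi_a, mi_b = eig_two_ways(np.array([[0.4, 0.1], [0.1, 0.4]]))
print(mi_a, mi_b)
```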

At a high level, following methods related to the InfoBAX algorithm [11], we can approximate this acquisition function by a three-step procedure:

  • First, we sample many possible dynamics functions from our posterior (e.g., functions that describe plasma evolution).
  • Second, we find optimal trajectories on each of the sampled dynamics functions without taking new data from the environment, as we can simulate controls on these dynamics.
  • Third, we compute predictive entropies of our model at \((s, a)\) and of our model with additional data taken from each optimal trajectory. We can then subtract the trajectory-conditioned entropies, averaged over the sampled trajectories, from the original entropy.

This final step allows us to estimate the mutual information between \(T(s, a)\) and \(\tau^*\) given \(D\), which is precisely the quantity we want. We give a more precise description of this below.

Given a task of regulating plasma to the goal conditions (green dot), we compute posterior samples of the optimal trajectory (paths in color). We can then estimate the point (red circle) with maximal mutual information with these optimal trajectories and query the dynamics at that point.

Computing \(\operatorname{EIG}_{\tau^*}\) via posterior function sampling

More formally: as \(\tau^*\) is a high-dimensional object that implicitly assumes access to an optimal decision-making policy, it is not obvious that the entropies involved in computing it will be easy to estimate. However, by making two additional assumptions and leveraging properties of mutual information, we can derive a practical method for estimating \(\operatorname{EIG}_{\tau^*}\). In particular, we need to assume that:

  • The dynamics \(T\) are drawn from a GP prior \(P(T)\), a fairly mild assumption since GPs are universal approximators [10].
  • \(\pi_T \approx \pi^*\) for an MDP with known rewards and transition function \(T\), i.e., that a model-predictive control (MPC) policy using known dynamics will be close to optimal on those dynamics. This is not true in all settings, and we investigate how crucial this assumption is to the performance of the algorithm in our experiments section.

We know from information theory that

$$\operatorname{EIG}_{\tau^*}(s,a) = I(\tau^*; T(s, a)) = \mathbb{H}[T(s, a)\mid D] - \mathbb{E}_{\tau^*\sim P(\tau^* \mid D)}\left[\mathbb{H}[T(s, a)\mid D\cup \tau^*]\right],$$

where \(I\) refers to the mutual information. This expression is much easier to deal with, given a GP. Because the marginal posterior of a GP at any point in the domain is a Gaussian in closed form, we can compute the first term \(\mathbb{H}[T(s, a)\mid D]\) exactly. We compute the second term \(\mathbb{E}_{\tau^*\sim P(\tau^* \mid D)}\left[\mathbb{H}[T(s, a)\mid D\cup \tau^*]\right]\) via a Monte Carlo approximation: we sample \(T' \sim P(T'\mid D)\) (which can be done efficiently [6]) and then sample trajectories \(\tau\sim P(\tau \mid T', \pi_{T'})\) by executing MPC using \(T'\) both as the dynamics model used for planning and as the dynamics function of the MDP used to sample transitions in \(\tau\). As \(\tau\) is a sequence of state-action transitions, it is essentially additional data for our estimate of the transition model. So it is straightforward to compute the model posterior \(P(T(s, a)\mid D\cup\tau^*)\) (which again must be Gaussian) and read off the entropy of the prediction. The full Monte Carlo estimator is

$$\operatorname{EIG}_{\tau^*}(s, a) \approx \mathbb{H}[T(s, a)\mid D] - \frac{1}{n}\sum_{i \in [n]}\mathbb{H}[T(s, a)\mid D\cup \tau_i]$$

for \(\tau_i\) sampled as described above.

In summary, we can estimate our acquisition function via the following procedure, which is subject to the two assumptions listed above:

  1. Sample many functions \(T_i\sim P(T\mid D)\)
  2. Sample trajectories \(\tau_i \sim P(\tau\mid T_i, \pi_{T_i})\) by executing the MPC policy \(\pi_{T_i}\) on the dynamics \(T_i\).
  3. Compute the entropies \(\mathbb{H}[T(s, a)\mid D]\) and \(\mathbb{H}[T(s, a)\mid D\cup \tau_i]\), for all \(i\), using standard GP techniques.
  4. Compute the acquisition function using our Monte Carlo estimator.
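Steps 3 and 4 can be sketched with a plain GP. A convenient property is that a GP's posterior variance depends only on the input locations, not on the observed outputs, so conditioning on a sampled trajectory \(\tau_i\) only requires the \((s, a)\) pairs it visits. The sketch below assumes a one-dimensional output, an RBF kernel with fixed unit hyperparameters, a fixed observation-noise variance, and that the trajectory inputs from steps 1 and 2 have already been produced elsewhere (by MPC on posterior function samples); none of these choices, nor the function names, come from the post.

```python
import numpy as np

def rbf_kernel(A, B, lengthscale=1.0, signal_var=1.0):
    """Squared-exponential kernel between the rows of A (n, d) and B (m, d)."""
    sq_dists = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2.0 * A @ B.T
    return signal_var * np.exp(-0.5 * np.maximum(sq_dists, 0.0) / lengthscale**2)

def predictive_entropy(x_query, X_data, noise_var=1e-2):
    """Gaussian entropy of the GP prediction at each query point (1-D output assumed)."""
    K = rbf_kernel(X_data, X_data) + noise_var * np.eye(len(X_data))
    k_star = rbf_kernel(X_data, x_query)                       # (n_data, n_query)
    solve = np.linalg.solve(K, k_star)
    post_var = rbf_kernel(x_query, x_query).diagonal() - np.sum(k_star * solve, axis=0)
    total_var = np.maximum(post_var, 0.0) + noise_var           # include observation noise
    return 0.5 * np.log(2.0 * np.pi * np.e * total_var)

def eig_tau_star(x_query, X_data, sampled_trajectory_inputs, noise_var=1e-2):
    """Monte Carlo EIG: H[T(x) | D] - (1/n) sum_i H[T(x) | D u tau_i].

    sampled_trajectory_inputs: list of (H, d) arrays with the (s, a) pairs visited by
    each sampled optimal trajectory tau_i (steps 1-2, assumed computed elsewhere).
    """
    h_prior = predictive_entropy(x_query, X_data, noise_var)
    h_cond = np.mean([predictive_entropy(x_query, np.vstack([X_data, tau]), noise_var)
                      for tau in sampled_trajectory_inputs], axis=0)
    return h_prior - h_cond

# Hypothetical 2-D (s, a) inputs: one observed point and two sampled optimal trajectories.
X_data = np.array([[0.0, 0.0]])
taus = [np.array([[1.0, 0.0], [2.0, 0.0]]), np.array([[1.0, 0.1], [2.0, 0.1]])]
x_query = np.array([[1.5, 0.0], [10.0, 10.0]])
eig = eig_tau_star(x_query, X_data, taus)
print(eig)
```

A query point near where the sampled optimal trajectories go receives a high score, while a point far from every sampled trajectory scores roughly zero even if the model is very uncertain there, which is the behaviour the text contrasts with plain uncertainty sampling.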

Inspired by the main ideas of BOED and active learning, we give a simple greedy procedure which we call BARL (Bayesian Active Reinforcement Learning) for using our acquisition function to acquire data given some initial dataset:

  1. Compute \(\operatorname{EIG}_{\tau^*}(s, a)\) given the dataset for a large random set of state-action pairs. Samples of \(\tau^*\) can be reused between these points.
  2. Sample \(s' \sim T(s, a)\) for the \((s, a)\) that was found to maximize the acquisition function and add \((s, a, s')\) to the dataset.
  3. Repeat steps 1-2 until the query budget is exhausted. The evaluation policy is simply MPC on the GP posterior mean.
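The outer loop of this greedy procedure is short once an acquisition function is available. The skeleton below makes the control flow explicit; acquisition, query_dynamics, and sample_candidates are hypothetical callables standing in for the EIG estimator, a TQRL query to the real system, and the random candidate set. The usage example plugs in a deliberately trivial score that is not \(\operatorname{EIG}_{\tau^*}\), only to show how the pieces connect.

```python
import numpy as np

def barl_loop(initial_data, acquisition, query_dynamics, sample_candidates, budget):
    """Greedy acquisition loop in the spirit of BARL (a sketch, not the authors' code).

    acquisition(candidates, data) -> one score per candidate, e.g. EIG_{tau*}  (hypothetical)
    query_dynamics(s, a) -> s' ~ T(s, a), a single TQRL query               (hypothetical)
    sample_candidates() -> list of (s, a) pairs                              (hypothetical)
    """
    data = list(initial_data)
    for _ in range(budget):                            # step 3: repeat until the budget is spent
        candidates = sample_candidates()               # large random set of (s, a)
        scores = acquisition(candidates, data)         # step 1 (tau* samples shared across candidates)
        s, a = candidates[int(np.argmax(scores))]      # maximizer of the acquisition function
        data.append((s, a, query_dynamics(s, a)))      # step 2: query T and grow the dataset
    return data  # afterwards, evaluate with MPC on the GP posterior mean

# Toy wiring: the score below simply prefers larger s and is NOT EIG_{tau*}.
candidates = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
data = barl_loop(
    initial_data=[(0.0, 0.0, 0.0)],
    acquisition=lambda cands, data: [s for s, a in cands],
    query_dynamics=lambda s, a: s + a,
    sample_candidates=lambda: candidates,
    budget=3,
)
print(data)
```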

Does BARL reduce the data requirements of RL?

We evaluate BARL in the TQRL setting on 5 environments that span a variety of reward function types, dimensionalities, and amounts of required data. In this evaluation, we estimate the minimum amount of data an algorithm needs to learn a controller. The evaluation environments include the standard underactuated pendulum swing-up task, a cartpole swing-up task, the standard 2-DOF reacher task, a navigation problem where the agent must find a path across pools of lava, and a simulated nuclear fusion control problem where the agent is tasked with modulating the power injected into the plasma to achieve a target pressure.

To assess the performance of BARL in solving MDPs quickly, we assembled a group of reinforcement learning algorithms that represent the state of the art in solving continuous MDPs. We compare against model-based algorithms PILCO [7], PETS [2], model-predictive control with a GP (MPC), and uncertainty sampling with a GP (\(\operatorname{EIG}_T\)), as well as model-free algorithms SAC [3], TD3 [8], and PPO [9]. Apart from uncertainty sampling (which operates in the TQRL setting and is directly comparable to BARL), these methods rely on the rollout setting for RL and are therefore somewhat disadvantaged relative to BARL.

In terms of data efficiency, BARL clearly outperforms each of the comparison methods on nearly every problem. We see that simpler methods like \(\operatorname{EIG}_T\) and MPC perform well on lower-dimensional problems like Lava Path, Pendulum, and Beta Tracking, but struggle with the higher-dimensional Reacher problem. Model-free methods like SAC, TD3, and PPO are notably sample-hungry.

Sample Complexity: Median number of samples across 5 seeds required to reach the performance of MPC on the ground truth dynamics, averaged across 5 trials on our control environments. We record N/A when the median run is unable to solve the problem by the end of training.

After further investigation, we also find that models that used data chosen by BARL are more accurate on the datapoints required to solve the problem and less accurate on a randomly chosen test set of points than models using data chosen via \(\operatorname{EIG}_T\). This implies that BARL is choosing the 'right data'. Since the same model is used in both of these methods, there will inevitably be areas of the input space where each of the methods performs better than the other. BARL performs better on the areas that are needed to solve the problem.

We compare BARL and an uncertainty sampling baseline \(\operatorname{EIG}_T\) on three criteria. In the left chart, we plot control performance as queries are made. \(\pi_T\) is the performance of MPC with a perfect model. In the middle, we plot modeling errors for BARL vs \(\operatorname{EIG}_T\) on the points where the model is queried in order to plan actions. On the right, we plot modeling errors on a uniform test set. BARL models the dynamics well on the points required to plan the optimal actions (middle) while not learning the dynamics well in general (right). This focus on choosing relevant datapoints allows BARL to solve the task quickly (left).


We believe \(\operatorname{EIG}_{\tau^*}\) is an important first step towards agents that think proactively about the data that they will acquire in the future. Though we are encouraged by the strong performance we have seen so far, there is substantial future work to be done. In particular, we are currently working to extend BARL to the rollout setting by planning actions that will lead to maximum information in the future. We also aim to scale these ideas to the high-dimensional state and action spaces that are necessary for many real-world problems.


[1] Mastering the game of Go with deep neural networks and tree search, Silver et al., Nature 2016

[2] Deep Reinforcement Learning in a Handful of Trials using Probabilistic Dynamics Models, Chua et al., NeurIPS 2018

[3] Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al., ICML 2018

[4] Cross-Entropy Randomized Motion Planning, Kobilarov et al., RSS 2008

[5] Sample-efficient Cross-Entropy Method for Real-time Planning, Pinneri et al., CoRL 2020

[6] Efficiently Sampling Functions from Gaussian Process Posteriors, Wilson et al., ICML 2020

[7] PILCO: A Model-Based and Data-Efficient Approach to Policy Search, Deisenroth & Rasmussen, ICML 2011

[8] Addressing Function Approximation Error in Actor-Critic Methods, Fujimoto et al., ICML 2018

[9] Proximal Policy Optimization Algorithms, Schulman et al., 2017

[10] Universal Kernels, Micchelli et al., JMLR 2006

[11] Bayesian Algorithm Execution: Estimating Computable Properties of Black-box Functions Using Mutual Information, Neiswanger et al., ICML 2021

Two Experiments in Peer Review: Posting Preprints and Citation Bias

There is growing interest, in computer science and elsewhere, in understanding and improving peer review (see here for an overview). With this motivation, we conducted two experiments on peer review, which we summarize in this blog post.

Pros and Cons of Posting Preprints Online

Motivation. Whether authors should post preprints online before review in double-blind peer review is a widely debated question, both as a matter of policy and as a matter of authors’ personal choice. This choice is especially challenging for authors who perceive that they may be at a disadvantage in the review process if their identity is revealed. By posting their preprints online, they stand to gain publicity but may lose the benefits of double-blind review. We inform this debate quantitatively by addressing two research questions.

Research questions. We study the following research questions:

  • (Q1) What fraction of reviewers deliberately search for their assigned paper on the Internet? 
  • (Q2) For preprints posted online, what is the relation between the papers’ visibility and the ranking of the authors’ affiliations?

Methods. We conduct survey-based experiments in the ICML 2021 and EC 2021 conferences. 

  • (Q1) We conduct an anonymized survey, where we ask each reviewer whether they deliberately searched online for any of their assigned papers.
  • (Q2) We consider the papers submitted to ICML or EC which were also available online before review. We survey relevant reviewers to assess the visibility of these papers: we ask them whether they have seen these papers online outside of the reviewing context. Finally, we compute a correlation between the papers’ visibility and the ranking of the associated authors’ affiliations.

Key findings. We report two main insights:

  • (Q1) More than a third of the respondents self-report searching online for their assigned paper in both ICML and EC.
  • (Q2) We find a weak positive correlation: preprints from better-ranked affiliations enjoy a higher visibility. In ICML, the correlation coefficient is 0.06 and is statistically significant; in EC, the correlation coefficient is 0.05 and is not statistically significant. In particular, papers associated with the top-10-ranked affiliations had a visibility of about 11% in ICML and 22% in EC, whereas the remaining papers had a visibility of 7% and 18% respectively.
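As a rough illustration of the Q2 analysis, the sketch below computes the two kinds of quantities reported above. The first is the visibility of papers inside and outside the top-ranked affiliations. The second is a rank correlation between visibility and affiliation ranking. Several details are assumptions for illustration: the choice of Kendall's tau-b, the per-paper 0/1 visibility encoding, and the function names. The paper's exact statistic and aggregation may differ.

```python
from math import sqrt


def kendall_tau_b(x, y):
    """Kendall's tau-b between two equal-length sequences, with tie correction.

    tau_b = (concordant - discordant) / sqrt(pairs untied in x * pairs untied in y).
    Uses an O(n^2) loop over all pairs.
    """
    n = len(x)
    assert n == len(y) and n >= 2
    s = 0         # (#concordant - #discordant) pairs
    untied_x = 0  # pairs with x_i != x_j
    untied_y = 0  # pairs with y_i != y_j
    for i in range(n):
        for j in range(i + 1, n):
            dx = (x[i] > x[j]) - (x[i] < x[j])  # sign of x_i - x_j
            dy = (y[i] > y[j]) - (y[i] < y[j])  # sign of y_i - y_j
            s += dx * dy
            untied_x += dx != 0
            untied_y += dy != 0
    if untied_x == 0 or untied_y == 0:
        return 0.0  # undefined when one variable is constant; report 0 here
    return s / sqrt(untied_x * untied_y)


def visibility_by_group(seen, rank, top_k=10):
    """Fraction of papers seen by reviewers, for affiliation rank <= top_k vs the rest.

    seen: 0/1 per paper (hypothetical encoding); rank: affiliation rank, 1 = best.
    """
    def frac(values):
        return sum(values) / len(values) if values else float("nan")

    top = [s for s, r in zip(seen, rank) if r <= top_k]
    rest = [s for s, r in zip(seen, rank) if r > top_k]
    return frac(top), frac(rest)
```

Ranks are negated before correlating, as in `kendall_tau_b(seen, [-r for r in rank])`. A positive coefficient then means that better-ranked affiliations (smaller rank numbers) go with higher visibility, which matches the sign convention in the findings above. The pairwise loop is fine for a few thousand papers.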

Implications. Conference organizers looking to design blinding policies and authors looking to post preprints online can use our findings to gauge the tradeoffs involved.


Citation Bias

Motivation.  Many anecdotes suggest that including citations to the works of potential reviewers is a good (albeit unethical) way to increase the acceptance chances of a manuscript. 

Research question. Does the citation of a reviewer’s work in a submission cause the reviewer to be positively biased towards the submission, that is, cause a shift in reviewer’s evaluation that goes beyond the genuine change in the submission’s scientific merit?

Methods. We pair cited and uncited reviewers for each submitted paper and then carefully analyze the differences in their scores. Our analysis accounts for the many confounding factors that may exist. By pairing reviewers, we alleviate the confounding factor of “paper quality” as both cited and uncited reviewers review the same paper. We also control for confounders related to reviewer identities by accommodating various associated aspects such as the reviewers’ expertise and preferences in reviewing papers. Finally, we analyze reviews of uncited reviewers to exclude cases in which a reviewer genuinely decreases their evaluation of a paper because it fails to cite their own relevant past work. 
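The following deliberately naive sketch makes the pairing idea concrete. Within each paper, it compares the mean score of cited reviewers with that of uncited reviewers, then averages over the papers that have both. The input format and function name are hypothetical. The sketch omits the adjustments described above for reviewer expertise, reviewer preferences and missing-citation effects, so it should not be read as the authors' estimator.

```python
from collections import defaultdict
from statistics import mean


def naive_paired_citation_effect(reviews):
    """reviews: iterable of (paper_id, score, cited) tuples (hypothetical format).

    For each paper with at least one cited and one uncited reviewer, take
    mean(cited scores) - mean(uncited scores), then average these differences
    across papers. Returns None if no paper has both kinds of reviewers.
    No confounder adjustment is performed here.
    """
    by_paper = defaultdict(lambda: {True: [], False: []})
    for paper_id, score, cited in reviews:
        by_paper[paper_id][bool(cited)].append(score)
    diffs = [
        mean(groups[True]) - mean(groups[False])
        for groups in by_paper.values()
        if groups[True] and groups[False]
    ]
    return mean(diffs) if diffs else None
```

Pairing within a paper is what removes "paper quality" as a confounder: both scores in each difference refer to the same submission.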

Key findings. Our findings suggest that citation bias exists, and papers enjoy higher scores from a cited reviewer. Due to this bias, the expected increase in a cited reviewer’s score is 0.16 (on a 6-point scale) in ICML and 0.23 (on a 5-point scale) in EC. For reference, a one-point increase of a score by a single reviewer improves the position of a submission by 11% on average.

Implications. We detect and quantify the strength of citation bias in peer review, informing stakeholders of the presence of the bias. Our work also raises an important open problem of mitigating citation bias.



The post is based on joint works with Ivan Stelmakh, Ryan Liu,  Xinwei Shen, Marina Meila, Shuchi Chawla, Federico Echenique, and Nihar B. Shah.

Assessing Generalization of SGD via Disagreement

Imagine training a deep network twice with two different random seeds on the same data, and then measuring the rate at which they disagree on unlabeled test points. Naively, they can disagree with one another with probability anywhere between zero and twice the error rate. But surprisingly, in practice, we observe that the disagreement and test error of deep neural networks are remarkably close to each other. In the figure, the y-axis shows the average generalization error of the two models and the x-axis shows their disagreement.

Estimating the generalization error of a model — how well the model performs on unseen data — is a fundamental component in any machine learning system. Generalization performance is traditionally estimated in a supervised manner, by dividing the labeled data into a training set and test set. However, high-quality labels are usually costly and, ideally, we would like to use all of them to train the model. On the other hand, in many real-world settings, a large amount of unlabeled data is readily available. How can we tap into the rich information in these unlabeled data and leverage them to assess a model’s performance without labels?

In this work (full paper), we demonstrate that a simple procedure can accurately estimate the generalization error with only unlabeled data. This result reveals a surprising fact about how neural networks make mistakes and their connection to calibration through an identity we call Generalization Disagreement Equality (GDE).

A surprising observation

Stochastic gradient descent (SGD) is perhaps the most popular optimization algorithm for deep neural networks. Due to the non-convex nature of the deep neural network’s optimization landscape, different runs of SGD will find different solutions. As a result, if the solutions are not perfect, they will disagree with each other on some of the unseen data. This disagreement can be harnessed to estimate generalization error without labels:

  1. Given a model, run SGD with the same hyperparameters but different random seeds on the training data to get two different solutions.
  2. Measure how often the networks’ predictions disagree on a new unlabeled test dataset.
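The two steps above reduce to comparing prediction vectors. Below is a minimal NumPy sketch. It assumes the predicted class labels of both runs on the same unlabeled points are already available; the training code is not shown, and the function names are our own.

```python
import numpy as np


def disagreement_rate(preds_a, preds_b):
    """Fraction of unlabeled test points on which two models predict different classes."""
    preds_a, preds_b = np.asarray(preds_a), np.asarray(preds_b)
    return float(np.mean(preds_a != preds_b))


def avg_test_error(preds_a, preds_b, labels):
    """Average test error of the two models; needs labels, so only used to check the estimate."""
    labels = np.asarray(labels)
    err_a = np.mean(np.asarray(preds_a) != labels)
    err_b = np.mean(np.asarray(preds_b) != labels)
    return float((err_a + err_b) / 2)
```

Only `disagreement_rate` needs unlabeled data. `avg_test_error` is included so that, when labels are available, one can check how close the disagreement-based estimate is to the true average test error.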

We find that the disagreement rate is approximately equal to the average test error over the two models. Our observation builds on the phenomenon reported by Nakkiran and Bansal (2020) [1]: given two networks of the same architecture trained to zero training error on two independently drawn datasets of the same size, the disagreement rate of the pair on the test dataset is nearly equal to the average test error. Our observation generalizes prior work by showing that the same phenomenon holds under small changes in hyperparameters and, more importantly, when both networks are trained on the same dataset, which makes the procedure relevant for practical applications.

This procedure estimates the test error with unlabeled data, but the mechanisms behind it are not immediately obvious. Let \(h(x)\) be the prediction of classifier \(h\) on input \(x\), and let \(y\) be the true label. By the triangle inequality, the disagreement rate can be anywhere between 0 and 2 times the test error: \(0 \leq \mathbb{E}[\mathbb{1}[h(x) \neq h'(x)]] \leq \mathbb{E}[\mathbb{1}[h(x) \neq y]] + \mathbb{E}[\mathbb{1}[h'(x) \neq y]]\). Given that the models observe the same data, we might expect them to extract similar knowledge from the data and consequently make similar errors, which would make the disagreement rate lower than the test error. Yet, in practice, we find that the disagreement rate approximates the test error without any proportionality constant. Why is this the case?
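The upper bound holds pointwise, and this is the step behind the triangle-inequality argument. Whenever two classifiers disagree on \(x\), at most one of them can equal \(y\), so at least one is wrong. Taking expectations over the test distribution then gives the stated range:

```latex
\mathbb{1}[h(x) \neq h'(x)] \;\leq\; \mathbb{1}[h(x) \neq y] + \mathbb{1}[h'(x) \neq y]
\quad\Longrightarrow\quad
0 \;\leq\; \mathbb{E}\big[\mathbb{1}[h(x) \neq h'(x)]\big] \;\leq\; \mathbb{E}\big[\mathbb{1}[h(x) \neq y]\big] + \mathbb{E}\big[\mathbb{1}[h'(x) \neq y]\big].
```

The lower end is reached when the two models always agree. The upper end is reached when their errors never fall on the same points.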

The final parameters found by SGD depend on many sources of randomness: 1) random initialization, 2) random ordering of a fixed training dataset, and/or 3) random sampling of training data. To understand this phenomenon, we analyze which kind of randomness is responsible for this peculiar property. Other sources of randomness, such as dropout, also exist but are not studied in this work.

We study the phenomenon by observing how the behavior of disagreement changes when the different sources of randomness are isolated. For example, when examining the effect of different initializations, we fix the dataset and the order in which the dataset is presented to the model and only change the random initialization. We observe that across different types of randomness, disagreement remains consistently close to the test error. This approach is particularly useful because we can use the same training data to train the two copies of the model. In addition, we empirically observe the phenomenon across a wide range of model architectures (ResNet, VGG, and fully-connected networks) and several popular image recognition benchmarks (MNIST, SVHN, CIFAR-10, and CIFAR-100).

GDE on CIFAR-10: Scatter plots of pairwise model disagreement (x-axis) vs. test error (y-axis) for different ResNet18 models trained on CIFAR-10. The dashed line is the diagonal, where disagreement equals test error. Orange dots represent models that use data augmentation. The first two plots correspond to pairs of networks trained on independent datasets; the last two correspond to pairs trained on the same dataset.

Furthermore, estimating generalization error is especially important when the test distribution differs from the training distribution. On that note, we see some promising results on the PACS dataset [2]. This dataset contains images of the same object categories in four environments (Photo, Art, Cartoon, Sketch). For each environment, we trained two ResNet18 models on that environment with different initializations and data orderings, and measured their disagreement rate on the remaining environments. We observe that similar phenomena hold across many (but not all) pairs of environments, suggesting that the technique might be adaptable for estimating generalization error under distribution shift.

GDE under distribution shift: Scatter plots of pairwise model disagreement (x-axis) vs. test error (y-axis) for different ResNet50 models trained on PACS. Each plot corresponds to models evaluated on the domain specified in the title. The source (training) domain is indicated by different marker shapes.

Generalization-Disagreement Equality

We now investigate theoretically a sufficient condition under which the disagreement rate equals the generalization error.

In the deep learning literature, it is well-known that ensembles of SGD-trained deep networks are well-calibrated over the training distribution [3]. An ensemble is well-calibrated if the confidence \(\tilde{h}_k(x) = \mathbb{E}_{h \sim \mathcal{H}_A}[\mathbb{1}[h(x) = k]]\) it outputs for class \(k\) matches the expected accuracy, i.e., \(P(y = k \mid \tilde{h}_k(x) = p) = p\). There exist various formalisms for calibration. In this paper, we show that if an ensemble satisfies class-wise calibration on a distribution \(\mathcal{D}\) (a more general form of calibration is discussed in the paper), then \(\mathbb{E}_{h, h' \sim \mathcal{H}_A}\left[\mathbb{E}_{\mathcal{D}}[\mathbb{1}[h(x) \neq h'(x)]]\right] = \mathbb{E}_{h \sim \mathcal{H}_A}\left[\mathbb{E}_{\mathcal{D}}[\mathbb{1}[h(x) \neq y]]\right]\) holds exactly. We refer to this equality as the Generalization Disagreement Equality (GDE).

Proof Sketch

The proof sketch we show here focuses on the special case where class-wise calibration holds. Consider binary classification. Suppose the SGD algorithm \(A\) induces a distribution \(\mathcal{H}_A\) over hypotheses, and we sample two hypotheses \(h, h'\) independently from this distribution. Given test data \(\mathcal{D}\), we partition it by the confidence output by the ensemble, i.e., \(\mathcal{D}_q = P(X \mid \tilde{h}_0(X) = q)\). For any fixed \(x\) in the support of \(\mathcal{D}_q\), the disagreement rate in expectation over \(h, h'\) is

$$\begin{aligned} \mathbb{E}_{h,h'}[\mathbb{1}[h(x) \neq h'(x)]] &= P(h(x) = 1)\,P(h'(x) = 0) + P(h(x) = 0)\,P(h'(x) = 1) \\ &= (1-q)q + q(1-q) = 2q(1-q). \end{aligned}$$

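Additionally, note that for any \(x \in \mathcal{D}\), the expected error over \(h\) equals \(\tilde{h}_{1-y}(x)\), the ensemble's confidence in the wrong class. Since the ensemble is also class-wise calibrated, in expectation over \(x \sim \mathcal{D}_q\) and \(h \sim \mathcal{H}_A\), the error is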

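$$\begin{aligned} \mathbb{E}_{x \sim \mathcal{D}_q}\mathbb{E}_{h}[\mathbb{1}[h(x) \neq y]] &= \tilde{h}_0(x)\,P(y = 1 \mid \tilde{h}_0(x) = q) + \tilde{h}_1(x)\,P(y = 0 \mid \tilde{h}_0(x) = q) \\ &= q(1-q) + (1-q)q = 2q(1-q). \end{aligned}$$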

Since the disagreement rate and the expected error both equal \(2q(1-q)\), GDE holds on each calibration level set \(\mathcal{D}_q\). Averaging over the level sets, GDE holds in expectation over \(\mathcal{D}\).
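The following is a quick Monte Carlo check of this argument. It rests on toy assumptions chosen for illustration, not taken from the paper:

- The ensemble confidence \(q = \tilde{h}_0(x)\) is uniform on [0, 1] across test points.
- Labels are drawn so that the ensemble is calibrated, i.e. \(P(y = 0 \mid q) = q\).
- \(h\) and \(h'\) are independent draws that each predict class 0 with probability \(q\).

Under these assumptions, both quantities should come out close to \(\mathbb{E}[2q(1-q)] = 1/3\).

```python
import numpy as np


def simulate_gde_binary(n_points=200_000, seed=0):
    """Monte Carlo estimate of (disagreement, error) for a calibrated binary ensemble.

    Toy setup (assumption): q ~ Uniform(0, 1) per test point, y = 0 with
    probability q (calibration), and h, h' independently predict 0 with probability q.
    """
    rng = np.random.default_rng(seed)
    q = rng.uniform(0.0, 1.0, size=n_points)
    y = np.where(rng.uniform(size=n_points) < q, 0, 1)
    h = np.where(rng.uniform(size=n_points) < q, 0, 1)
    h_prime = np.where(rng.uniform(size=n_points) < q, 0, 1)
    disagreement = float(np.mean(h != h_prime))
    error = float(np.mean(h != y))
    return disagreement, error
```

Replacing the calibrated label draw with a deterministic one breaks calibration. For example, with \(y = 0\) whenever \(q > 0.5\), the expected disagreement stays at 1/3 while the expected error drops to 1/4. This illustrates why the calibration assumption matters.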

Note that the theorem only shows that the test error and expected disagreement are equal in expectation over all the models learned by SGD; it does not necessarily explain why the equality seems to hold for as few as a single pair of training runs. It captures the essence of the observation, but explaining why a single pair of models is sufficient is still an open problem. Intuitively, this can happen if the distribution of disagreement has an unusually small variance. In our experiments, we do observe very small variance for disagreement, but a rigorous theoretical explanation of why the variance is small remains open.

Further, in practice, no models are perfectly calibrated, but we can still characterize how deviation from calibration affects GDE. We show that the calibration error upper bounds the absolute difference between the expected disagreement and the expected test error. This inequality generalizes GDE and implies that as long as the ensemble is reasonably calibrated, disagreement gives a good estimate of the generalization error. Finally, we empirically validate that this assumption often holds in practice. We train 100 copies of the same model, producing an ensemble that approximates the true ensemble, and measure its calibration error. The results show that these ensembles are indeed well-calibrated.
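The bound above depends on how far the ensemble is from calibrated. Below is a sketch of a standard equal-width binned estimate of a class-wise calibration error, summed over classes. It is applied to ensemble confidences computed as the fraction of members predicting each class. The binning scheme, the bin count and the function names are assumptions; the paper's calibration measure may be defined or estimated differently.

```python
import numpy as np


def ensemble_confidence(member_preds, num_classes):
    """member_preds: (M, n) predicted labels from M ensemble members on n points.

    Returns an (n, K) array: the fraction of members predicting each class.
    """
    member_preds = np.asarray(member_preds)
    return np.stack([(member_preds == k).mean(axis=0) for k in range(num_classes)], axis=1)


def binned_classwise_calibration_error(probs, labels, n_bins=10):
    """Binned calibration-error estimate, summed over classes.

    For each class k, points are bucketed by probs[:, k] into equal-width bins;
    each bin contributes |mean confidence - frequency of y == k|, weighted by
    the fraction of points in the bin.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    total = 0.0
    for k in range(probs.shape[1]):
        conf = probs[:, k]
        hit = (labels == k).astype(float)
        # Bin index in [0, n_bins - 1]; a confidence of exactly 1.0 goes to the last bin.
        bins = np.minimum((conf * n_bins).astype(int), n_bins - 1)
        for b in range(n_bins):
            mask = bins == b
            if mask.any():
                total += mask.mean() * abs(conf[mask].mean() - hit[mask].mean())
    return total
```

A value near zero indicates a well-calibrated ensemble, which is the regime where disagreement should track test error closely.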


We have presented a method for estimating the generalization error of black-box deep neural networks with only unlabeled data, as well as some theoretical motivation for why the method works. Specifically, we showed that if the ensembles of the models trained by SGD are well-calibrated, the expected disagreement equals the expected test error. Moreover, we observe that in practice a single pair of models suffices to estimate the generalization error accurately. In the broader picture, this result marks a departure from the more traditional approaches of studying generalization and points to the tantalizing possibility of leveraging unlabeled data to estimate the generalization error. We are excited about the results we presented in the paper, but we are even more excited about the questions that we did not answer in the paper. We hope this work will encourage future work to investigate more unconventional ways of understanding generalization in deep learning.


[1] Nakkiran and Bansal. Distributional Generalization: A New Kind of Generalization. 2020.
[2] Li et al. Deeper, Broader and Artier Domain Generalization. IEEE Intl. Conf. on Computer Vision (ICCV), 2017.
[3] Lakshminarayanan et al. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. Conference on Neural Information Processing Systems (NeurIPS), 2017.

Why Spectral Normalization Stabilizes GANs: Analysis and Improvements

Figure 1: Training instability is one of the biggest challenges in training GANs. Despite the existence of successful heuristics like Spectral Normalization (SN) for improving stability, it is poorly understood why they work. In our research, we theoretically explain why SN stabilizes GAN training. Using these insights, we further propose a better normalization technique for improving GANs’ stability called Bidirectional Scaled Spectral Normalization.

Generative adversarial networks (GANs) are a class of popular generative models enabling many cutting-edge applications such as photorealistic image synthesis. Despite their tremendous success, GANs are notoriously unstable to train—small hyper-parameter changes and even randomness in optimization can cause training to fail altogether, which leads to poor generated samples. One empirical heuristic that is widely used to stabilize GAN training is spectral normalization (SN) (Figure 2). Although it is very widely adopted, little is understood about why it works, and therefore there is little analytical basis for using it, configuring it, and more importantly, improving it.

Figure 2: Spectral normalization divides the weights \(W_i\) by their spectral norms \(\sigma(W_i)\) (i.e., the largest singular value of \(W_i\)).
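Here is a minimal NumPy sketch of the operation in Figure 2, estimating \(\sigma(W)\) by power iteration. The iteration count and seed are arbitrary choices here. Common implementations work differently: the original SN paper by Miyato et al., for instance, runs a single power-iteration step per training update and carries the vector \(u\) across updates.

```python
import numpy as np


def spectral_norm_power_iteration(W, n_iters=100, seed=0):
    """Estimate sigma(W), the largest singular value of the 2-D matrix W."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u              # right singular vector estimate
        v /= np.linalg.norm(v)
        u = W @ v                # left singular vector estimate
        u /= np.linalg.norm(u)
    return float(u @ W @ v)      # equals ||W v|| at convergence


def spectral_normalize(W, n_iters=100):
    """Return W / sigma(W), whose spectral norm is (approximately) 1."""
    return W / spectral_norm_power_iteration(W, n_iters)
```

Power iteration converges at a rate set by the ratio of the top two singular values. Keeping \(u\) between updates, as practical implementations do, amortizes this cost because the weights change only slightly per step.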

In this post, we discuss our recent work at NeurIPS 2021. We prove that spectral normalization controls two well-known causes of training instability: exploding and vanishing gradients. More interestingly, we uncover a surprising connection between spectral normalization and neural network initialization techniques. This connection not only helps explain how spectral normalization stabilizes GANs, but also motivates us to design Bidirectional Scaled Spectral Normalization (BSSN), a simple change to spectral normalization that yields better stability than SN (Figure 3).

Figure 3: The interesting connections we find between spectral normalization and prior initialization techniques: (1) LeCun initialization can help explain why spectral normalization avoids vanishing gradients; (2) Motivated by newer initialization techniques (Xavier and Kaiming), we propose BSSN to further improve spectral normalization.

Exploding and vanishing gradients cause training instability

Exploding and vanishing gradients describe a problem in which gradients either grow or shrink rapidly during training. These phenomena are known to be closely related to the instability of GANs. Figure 4 shows an illustrative example: when exploding or vanishing gradients occur, the sample quality measured by inception score (higher is better) deteriorates rapidly.

Figure 4: The close connection between gradient scales and training instability. Left: the gradient norm during the training of three GANs on CIFAR-10, either with exploding, vanishing, or stable gradients. Right: the inception score (measuring sample quality; the higher, the better) of these three GANs. We see that the GANs with bad gradient scales (exploding or vanishing) have worse sample quality as measured by inception score.

In the next section, we will show how spectral normalization alleviates exploding and vanishing gradients, which may explain its success.

How spectral normalization mitigates exploding gradients

The fact that spectral normalization prevents gradient explosion is not too surprising. Intuitively, it achieves this by limiting the ability of weight tensors to amplify inputs in any direction. More precisely, when the spectral norm of each weight matrix is 1 (as ensured by spectral normalization) and the activation functions are 1-Lipschitz (e.g., (Leaky)ReLU), we show that

$$\|\text{gradient}\|_{\text{Frobenius}} \leq \sqrt{\text{number of layers}} \cdot \|\text{input}\|.$$

(Please refer to the paper for more general results.) In other words, the gradient norm of spectrally normalized GANs cannot exceed a strict bound. This explains why spectral normalization can mitigate exploding gradients. 

Note that this good property is not unique to spectral normalization—our analysis can also be used to show the same result for other normalization and regularization techniques that control the spectral norm of weights, such as weight normalization and orthogonal regularization. The more surprising and important fact is that spectral normalization can also control vanishing gradients at the same time, as discussed below.

How spectral normalization mitigates vanishing gradients

To understand why spectral normalization prevents gradient vanishing, let’s take a brief detour to the world of neural network initialization. In 1996, LeCun, Bottou, Orr, and Müller introduced a new initialization technique (commonly called LeCun initialization) that aimed to prevent vanishing gradients. It achieved this by carefully setting the variance of the weight initialization distribution as $$\text{Var}(W) = \left(\text{fan-in of the layer}\right)^{-1},$$ where the fan-in of a layer is the number of input connections from the previous layer (e.g., in fully-connected networks, it is the number of neurons in the previous layer). LeCun et al. showed that

  • If the weight variance is larger than \(\left(\text{fan-in of the layer}\right)^{-1}\), the internal outputs of the neural network can be saturated by bounded activation or loss functions (e.g., sigmoid), which causes vanishing gradients.
  • If the weight variance is too small, gradients will also vanish because gradient norms are bounded by the scale of the weights.

We show theoretically that spectral normalization controls the variance of weights in a way similar to LeCun initialization. More specifically, for a weight matrix \(W \in \mathbb{R}^{m \times n}\) with i.i.d. entries from a zero-mean Gaussian distribution (common for weight initialization), we show that

$$\text{Var}\left(\text{spectrally-normalized } W\right) \quad \text{is on the order of} \quad \left(\max\{m, n\}\right)^{-1}$$

(Please refer to the paper for more general results.) This result has separate implications on the fully-connected layers and convolutional layers:

  • For fully-connected layers with a fixed width across hidden layers, \(\max\{m, n\} = m = n = \text{fan-in of the layer}\). Therefore, spectrally-normalized weights have variances of the same order as LeCun initialization.
  • For convolutional layers, the weight (i.e., the convolution kernel) is a 4-dimensional tensor \(W \in \mathbb{R}^{c_{out} \times c_{in} \times k_w \times k_h}\), where \(c_{out}, c_{in}, k_w, k_h\) denote the number of output channels, the number of input channels, the kernel width, and the kernel height, respectively. The popular implementation of spectral normalization normalizes the weights as \(\frac{W}{\sigma\left(W_{c_{out} \times (c_{in} k_w k_h)}\right)}\), where \(\sigma\left(W_{c_{out} \times (c_{in} k_w k_h)}\right)\) is the spectral norm of the weight reshaped into a \(c_{out} \times (c_{in} k_w k_h)\) matrix, i.e., \(m = c_{out}\) and \(n = c_{in} k_w k_h\). In hidden layers, usually \(\max\{m, n\} = \max\{c_{out}, c_{in} k_w k_h\} = c_{in} k_w k_h = \text{fan-in of the layer}\). Therefore, spectrally-normalized convolutional layers also maintain the same desired variances as LeCun initialization!
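These variance claims are easy to probe numerically. The sketch below draws a standard Gaussian \(W \in \mathbb{R}^{m \times n}\), divides it by its exact spectral norm, and reports \(\text{Var}(W / \sigma(W)) \cdot \max\{m, n\}\). For such matrices the largest singular value is close to \(\sqrt{m} + \sqrt{n}\). The ratio should therefore be a constant between roughly 1/4 (square) and 1 (very elongated), consistent with "on the order of". The shapes and the function name are our own choices.

```python
import numpy as np


def sn_entry_variance_ratio(m, n, seed=0):
    """Empirical Var(W / sigma(W)) * max(m, n) for W with i.i.d. N(0, 1) entries.

    For comparison, LeCun initialization targets Var(W) * fan_in = 1.
    """
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((m, n))
    W_sn = W / np.linalg.norm(W, ord=2)  # ord=2 on a 2-D array: largest singular value
    return float(np.var(W_sn) * max(m, n))
```

For example, `sn_entry_variance_ratio(256, 256)` lands near 1/4 and `sn_entry_variance_ratio(16, 1024)` near 0.8. In both cases the ratio is bounded by a constant, so the variance is on the order of \(1/\max\{m, n\}\).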

Whereas LeCun initialization only controls the gradient vanishing problem at the beginning of training, we observe empirically that spectral normalization preserves this nice property throughout training (Figure 5). These results may help explain why spectral normalization controls vanishing gradients during GAN training.

Figure 5: Parameter variances throughout training. The blue lines show the parameter variances of different layers when SN is applied, and the orange line shows our theoretical bound at initialization: \(\left(\max\{m, n\}\right)^{-1}\). The parameter variances of SN are close to the bound throughout training.

How to improve spectral normalization

The next question we ask is: can we use the above theoretical insights to improve spectral normalization? Many advanced initialization techniques have been proposed in recent years to improve LeCun initialization, including Xavier initialization and Kaiming initialization. They derived better parameter variances by incorporating more realistic assumptions into the analysis. We propose Bidirectional Scaled Spectral Normalization (BSSN) so that the parameter variances parallel the ones in these newer initialization techniques:

  • Xavier initialization. The idea of Xavier initialization is to set the variance of the parameter initialization distribution to \(\text{Var}(W) = \left(\frac{\text{fan-in of the layer} + \text{fan-out of the layer}}{2}\right)^{-1}\). Its authors show that this controls not only the variances of outputs (as in LeCun initialization) but also the variances of backpropagated gradients, giving better gradient values. We propose Bidirectional Spectral Normalization, which normalizes convolutional kernels by \(\frac{W}{\left(\sigma\left(W_{c_{out} \times (c_{in} k_w k_h)}\right) + \sigma\left(W_{c_{in} \times (c_{out} k_w k_h)}\right)\right)/2}\). We show that by doing this, the parameter variances mimic the ones in Xavier initialization.
  • Kaiming initialization. The analyses behind LeCun and Xavier initialization did not cover activation functions like (Leaky)ReLU, which decrease the scale of the network outputs. To cancel out the effect of (Leaky)ReLU, Kaiming initialization scales up the variances in LeCun or Xavier initialization by a constant. Motivated by this, we propose to scale the above normalization formula by a tunable constant \(c\): \(c \cdot \frac{W}{\left(\sigma\left(W_{c_{out} \times (c_{in} k_w k_h)}\right) + \sigma\left(W_{c_{in} \times (c_{out} k_w k_h)}\right)\right)/2}\).
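Below is a sketch of standard SN and BSSN for a convolutional kernel stored with shape \((c_{out}, c_{in}, k_w, k_h)\), following the formulas above. For clarity it computes exact spectral norms with NumPy's SVD-based matrix 2-norm rather than power iteration. The function names are hypothetical, and this is not the authors' released code.

```python
import numpy as np


def _sigma(M):
    """Largest singular value of a 2-D matrix."""
    return float(np.linalg.norm(M, ord=2))


def sn_conv(W):
    """Standard SN: divide by sigma of W reshaped to c_out x (c_in * k_w * k_h)."""
    c_out = W.shape[0]
    return W / _sigma(W.reshape(c_out, -1))


def bssn_conv(W, c=1.0):
    """BSSN as described above: divide by the average spectral norm of the
    c_out x (c_in*k_w*k_h) and c_in x (c_out*k_w*k_h) reshapes, then scale by c."""
    c_out, c_in = W.shape[:2]
    sigma_out = _sigma(W.reshape(c_out, -1))
    # Swap the channel axes so the second reshape groups rows by input channel.
    sigma_in = _sigma(W.transpose(1, 0, 2, 3).reshape(c_in, -1))
    return c * W / ((sigma_out + sigma_in) / 2)
```

With \(c = 1\) this is the unscaled Bidirectional Spectral Normalization. For 1×1 kernels the two reshapes are transposes of each other and share the same spectral norm, so BSSN with \(c = 1\) then coincides with standard SN.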

BSSN can be easily plugged into GAN training with minimal code changes and little computational overhead. We compare spectral normalization and BSSN on several image datasets, using standard metrics for image quality like inception score (higher is better) and FID (lower is better). We show that simply replacing spectral normalization with BSSN not only makes GAN training more stable (Figure 6), but also improves sample quality (Table 1). Generated samples from BSSN are in Figure 7.

Table 1: Inception score (IS) and FID. Our proposed BSSN method outperforms spectral normalization in sample quality metrics across different datasets by a large margin.
Figure 6: Inception score training curve in CIFAR10. Spectral normalization (in blue) exhibits (one type of) training instability: the sample quality drops as training proceeds. Our proposed BSSN (in orange) does not have the problem.
Figure 7: Generated samples from BSSN in CIFAR10 dataset.


This post only covers a portion of our theoretical and empirical results. Please refer to our NeurIPS 2021 paper and code if you are interested in learning more.
