On Privacy and Personalization in Federated Learning: A Retrospective on the US/UK PETs Challenge

TL;DR: We study the use of differential privacy in personalized, cross-silo federated learning (NeurIPS’22), explain how these insights led us to develop a 1st place solution in the US/UK Privacy-Enhancing Technologies (PETs) Prize Challenge, and share challenges and lessons learned along the way. If you are feeling adventurous, check out the extended version of this post with more technical details!


How can we be better prepared for the next pandemic?

Patient data collected by groups such as hospitals and health agencies is a critical tool for monitoring and preventing the spread of disease. Unfortunately, while this data contains a wealth of useful information for disease forecasting, the data itself may be highly sensitive and stored in disparate locations (e.g., across multiple hospitals, health agencies, and districts).

In this post we discuss our research on federated learning, which aims to tackle this challenge by performing decentralized learning across private data silos. We then explore an application of our research to the problem of privacy-preserving pandemic forecasting—a scenario where we recently won a 1st place, $100k prize in a competition hosted by the US & UK governments—and end by discussing several directions of future work based on our experiences.


Part 1: Privacy, Personalization, and Cross-Silo Federated Learning

Federated learning (FL) is a technique to train models using decentralized data without directly communicating such data. Typically:

  • a central server sends a model to participating clients;
  • the clients train that model using their own local data and send back updated models; and
  • the server aggregates the updates (e.g., via averaging, as in FedAvg)

and the cycle repeats. Companies like Apple and Google have deployed FL to train models for applications such as predictive keyboards, text selection, and speaker verification in networks of user devices.
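
To make the loop above concrete, here is a minimal sketch of a single FedAvg round, assuming models are flat NumPy weight vectors and `local_train` is a hypothetical client-side training routine (e.g., a few epochs of SGD on the silo's data):

```python
import numpy as np

def fedavg_round(global_weights, client_datasets, local_train):
    """One FedAvg round: broadcast the model, let each client train locally,
    then average the returned models weighted by local dataset size."""
    updates, sizes = [], []
    for data in client_datasets:
        updates.append(local_train(np.copy(global_weights), data))
        sizes.append(len(data))
    weights = np.asarray(sizes, dtype=float) / sum(sizes)
    return np.average(np.stack(updates), axis=0, weights=weights)
```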

However, while significant attention has been given to cross-device FL (e.g., learning across large networks of devices such as mobile phones), the area of cross-silo FL (e.g., learning across a handful of data silos such as hospitals or financial institutions) is relatively under-explored, and it presents interesting challenges in terms of how to best model federated data and mitigate privacy risks. In Part 1.1, we’ll examine a suitable privacy granularity for such settings, and in Part 1.2, we’ll see how this interfaces with model personalization, an important technique in handling data heterogeneity across clients.

1.1. How should we protect privacy in cross-silo federated learning?

Although the high-level federated learning workflow described above can help to mitigate systemic privacy risks, past work suggests that FL’s data minimization principle alone isn’t sufficient for data privacy, as the client models and updates can still reveal sensitive information.

This is where differential privacy (DP) can come in handy. DP provides both a formal guarantee and an effective empirical mitigation to attacks like membership inference and data poisoning. In a nutshell, DP is a statistical notion of privacy where we add randomness to a query on a “dataset” to create quantifiable uncertainty about whether any one “data point” has contributed to the query output. DP is typically measured by two scalars \((\varepsilon, \delta)\)—the smaller, the more private.

In the above, “dataset” and “data point” are in quotes because privacy granularity matters. In cross-device FL, it is common to apply “client-level DP” when training a model, where the federated clients (e.g., mobile phones) are thought of as “data points”. This effectively ensures that each participating client/mobile phone user remains private.

However, while client-level DP makes sense for cross-device FL as each client naturally corresponds to a person, this privacy granularity may not be suitable for cross-silo FL, where there are fewer (2-100) ‘clients’ but each holds many data subjects that require protection, e.g., each ‘client’ may be a hospital, bank, or school with many patient, customer, or student records.

Visualizing client-level DP vs. silo-specific example-level DP in federated learning.

In our recent work (NeurIPS’22), we instead consider the notion of “silo-specific example-level DP” in cross-silo FL (see figure above). In short, this says that the \(k\)-th data silo may set its own \((\varepsilon_k, \delta_k)\) example-level DP target for any learning algorithm with respect to its local dataset.

This notion is better aligned with real-world use cases of cross-silo FL, where each data subject contributes a single “example”, e.g., each patient in a hospital contributes their individual medical record. It is also very easy to implement: each silo can just run DP-SGD for local gradient steps with calibrated per-step noise. As we discuss below, this alternate privacy granularity affects how we consider modeling federated data to improve privacy/utility trade-offs.
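
To give a flavor of the implementation, below is a minimal, unhardened sketch of a single DP-SGD step that a silo might run locally; `per_example_grads` is a hypothetical helper returning one loss gradient per example, and the noise multiplier would be calibrated (via a privacy accountant) to the silo's own \((\varepsilon_k, \delta_k)\) target.

```python
import numpy as np

def dp_sgd_step(weights, batch, per_example_grads, lr=0.1,
                clip_norm=1.0, noise_multiplier=1.0):
    """One DP-SGD step: clip each example's gradient to bound sensitivity,
    add calibrated Gaussian noise, then take an averaged gradient step."""
    grads = per_example_grads(weights, batch)              # shape: (batch_size, dim)
    norms = np.linalg.norm(grads, axis=1, keepdims=True)
    clipped = grads * np.minimum(1.0, clip_norm / (norms + 1e-12))
    noise = np.random.normal(0.0, noise_multiplier * clip_norm, size=weights.shape)
    noisy_mean = (clipped.sum(axis=0) + noise) / len(batch)
    return weights - lr * noisy_mean
```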

1.2. The interplay of privacy, heterogeneity, and model personalization

Let’s now look at how this privacy granularity may interface with model personalization in federated learning.

Model personalization is a common technique used to improve model performance in FL when data heterogeneity (i.e., non-identically distributed data) exists between data silos.1 Indeed, existing benchmarks suggest that realistic federated datasets may be highly heterogeneous and that fitting separate local models on the federated data is already a competitive baseline.

When considering model personalization techniques under silo-specific example-level privacy, we find that a unique trade-off may emerge between the utility costs from privacy and data heterogeneity (see figure below):

  • As DP noise is added independently by each silo for its own privacy targets, this noise is reflected in the silos’ model updates and can thus be smoothed out when the updates are averaged (e.g., via FedAvg), leading to a smaller utility drop from DP for the federated model.
  • On the other hand, federation also means that the shared, federated model may suffer from data heterogeneity (“one size does not fit all”).

Consider two interesting phenomena illustrated by a simple experiment where all silos use \((\varepsilon = 1, \delta = 10^{-7})\) example-level DP for their own datasets. Left: FedAvg can smooth out the independent, per-silo DP noise and lead to a smaller average utility drop from DP; Mid/Right: Local finetuning (FedAvg followed by further local training) may not improve utility as expected, as the effect of noise reduction is removed when finetuning begins.

This “privacy-heterogeneity cost tradeoff” is interesting because it suggests that model personalization can play a key and distinct role in cross-silo FL. Intuitively, local training (no FL participation) and FedAvg (full FL participation) can be viewed as two ends of a personalization spectrum with identical privacy costs—silos’ participation in FL itself does not incur privacy costs due to DP’s robustness to post-processing—and various personalization algorithms (finetuning, clustering, …) are effectively navigating this spectrum in different ways.

If local training minimizes the effect of data heterogeneity but enjoys no DP noise reduction, and vice versa for FedAvg, it is natural to wonder whether there are personalization methods that lie in between and achieve better utility. If so, what methods would work best?

Privacy-utility tradeoffs for representative personalization methods under silo-specific example-level DP across four cross-silo FL datasets. Finetune: a common baseline for model personalization; IFCA/HypCluster: hard clustering of client models; Ditto: a recently proposed method for personalized FL. MR-MTL: mean-regularized multi-task learning, which consistently outperforms the other baselines.

Our analysis points to mean-regularized multi-task learning (MR-MTL) as a simple yet particularly suitable form of personalization. MR-MTL simply asks each client \(k\) to train its own local model \(w_k\), regularize it towards the mean of others’ models \(\bar w\) via a penalty \(\frac{\lambda}{2} \| w_k - \bar w \|_2^2\), and keep \(w_k\) across rounds (i.e., clients are stateful). The mean model \(\bar w\) is maintained by the FL server (as in FedAvg) and may be updated in every round. More concretely, each local update step takes the form

\[ w_k \leftarrow w_k - \eta \big( \nabla \ell(w_k) + \lambda (w_k - \bar w) \big), \]

where \(\ell\) is the local loss on a minibatch and \(\eta\) is the learning rate.

The hyperparameter \(\lambda\) serves as a smooth knob between local training and FedAvg: \(\lambda = 0\) recovers local training, and a larger \(\lambda\) forces the personalized models to be closer to each other (intuitively, “federate more”).
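
As a minimal sketch (again assuming flat NumPy weights and a hypothetical `per_example_grads` helper, as in the DP-SGD snippet above), a private MR-MTL local step simply adds the mean-regularization gradient \(\lambda(w_k - \bar w)\) to the noisy DP-SGD gradient:

```python
import numpy as np

def mr_mtl_local_step(w_k, w_mean, batch, per_example_grads, lam=0.5, lr=0.1,
                      clip_norm=1.0, noise_multiplier=1.0):
    """One private MR-MTL step: DP-SGD on the local loss, plus a pull towards
    the server-maintained mean model (the penalty adds no extra privacy cost)."""
    grads = per_example_grads(w_k, batch)                  # (batch_size, dim)
    norms = np.linalg.norm(grads, axis=1, keepdims=True)
    clipped = grads * np.minimum(1.0, clip_norm / (norms + 1e-12))
    noise = np.random.normal(0.0, noise_multiplier * clip_norm, size=w_k.shape)
    dp_grad = (clipped.sum(axis=0) + noise) / len(batch)
    return w_k - lr * (dp_grad + lam * (w_k - w_mean))     # lam = 0 recovers local training
```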

MR-MTL has some nice properties in the context of private cross-silo FL:

  1. Noise reduction is attained throughout training via the soft proximity constraint towards an averaged model;
  2. The mean-regularization itself has no privacy overhead;2 and
  3. \(\lambda\) provides a smooth interpolation along the personalization spectrum.

Why is the above interesting? Consider the following experiment where we try a range of \(\lambda\) values roughly interpolating between local training and FedAvg. Observe that we could find a “sweet spot” \(\lambda^\ast\) that outperforms both of the endpoints under the same privacy cost. Moreover, both the utility advantage of MR-MTL\((\lambda^\ast)\) over the endpoints, and \(\lambda^\ast\) itself, are larger under privacy; intuitively, this says that silos are encouraged to “federate more” for noise reduction.

Test acc ± std of MR-MTL on a simple cross-silo FL task with varying \(\lambda\). A “sweet spot” \(\lambda^\ast\) exists where it outperforms both ends of the personalization spectrum (local / FedAvg) under the same privacy budget. Results correspond to \(\varepsilon = 0.5\) in the first subplot of the privacy-utility tradeoff curves. Ditto resembles MR-MTL in terms of the training procedure and exhibits similar interpolation behaviors, but it suffers from privacy overhead due to 2× local training iterations.

The above provides rough intuition on why MR-MTL may be a strong baseline for private cross-silo FL and motivates this approach for a practical pandemic forecasting problem, which we discuss in Part 2. Our full paper delves deeper into the analyses and provides additional results and discussions!


Part 2: Federated Pandemic Forecasting at the US/UK PETs Challenge

Illustration of the pandemic forecasting problem at the US/UK PETs challenge (image source).

Let’s now take a look at a federated pandemic forecasting problem at the US/UK Privacy-Enhancing Technologies (PETs) prize challenge, and how we may apply the ideas from Part 1.

2.1. Problem setup

The pandemic forecasting problem asks the following: Given a person’s demographic attributes (e.g. age, household size), locations, activities, infection history, and the contact network, what is the likelihood of infection in the next \(t_{\text{pred}} = 7\) days? Can we make predictions while protecting the privacy of individuals? Moreover, what if the data are siloed across administrative regions?

There’s a lot to unpack in the above. First, the pandemic outbreak problem follows a discrete-time SIR model (Susceptible → Infectious → Recovered) and we begin with a subset of the population infected. Subsequently,

  • Each person goes about their usual daily activities and gets into contact with others (e.g. at a shopping mall)—this forms a contact graph where individuals are nodes and direct contacts are edges;
  • Each person may get infected with different risk levels depending on a myriad of factors—their age, the nature and duration of their contact(s), their node centrality, etc.; and
  • Such infection can also be asymptomatic—the individual can appear in the S state while being secretly infectious.

The challenge dataset models a pandemic outbreak in Virginia and contains roughly 7.7 million nodes (persons) and 186 million edges (contacts) with health states over 63 days; so the actual contact graph is fairly large but also quite sparse.

There are a few extra factors that make this problem challenging:

  1. Data imbalance: less than 5% of people are ever in the I or R state, and roughly 0.3% of people become infected in the final week.
  2. Data silos: the true contact graph is cut along administrative boundaries, e.g., by grouped FIPS codes/counties. Each silo only sees a local subgraph, but people may still travel and make contacts across multiple regions! In the official evaluation, the population sizes can also vary by more than 10\(\times\) across silos.
  3. Temporal modeling: we are given the first \(t_{\text{train}} = 56\) days of each person’s health states (S/I/R) and asked to predict individual infections any time in the subsequent \(t_{\text{pred}} = 7\) days. What is a training example in this case? How should we perform temporal partitioning? How does this relate to privacy accounting?
  4. Graphs generally complicate DP: we are often used to ML settings where we can clearly define the privacy granularity and how it relates to an actual individual (e.g. medical images of patients). This is tricky with graphs: people can make different numbers of contacts each of different natures, and their influence can propagate throughout the graph. At a high level (and as specified by the scope of sensitive data of the competition), what we care about is known as node-level DP—the model output is “roughly the same” if we add/remove/replace a node, along with its edges.
2.2. Applying MR-MTL with silo-specific example-level privacy

One clean approach to the pandemic forecasting problem is to just operate on the individual level and view it as (federated) binary classification: if we could build a feature vector to summarize an individual, then risk scores are simply the sigmoid probabilities of near-term infection.

Of course, the problem lies in what that feature vector (and the corresponding label) is—we’ll get to this in the following section. But already, we can see that MR-MTL with silo-specific example-level privacy (from Part 1) is a nice framework for a number of reasons:

  • Model personalization is likely needed as the silos are large and heterogeneous by construction (geographic regions are unlikely to all be similar).
  • Privacy definition: There are a small number of clients, but each holds many data subjects, and client-level DP isn’t suitable.
  • Usability, efficiency, and scalability: MR-MTL is remarkably easy to implement with minimal resource overhead (over FedAvg and local training). This is crucial for real-world applications.
  • Adaptability and explainability: The framework is highly adaptable to any learning algorithm that can take DP-SGD-style updates. It also preserves the explainability of the underlying ML algorithm as we don’t obfuscate the model weights, updates, or predictions.

It is also helpful to look at the threat model we might be dealing with and how our framework behaves under it; the interested reader may find more details in the extended post!

2.3. Building training examples
Illustration of iterative, ℓ-hop neighborhood aggregation. Here, green nodes are the sampled neighbors and the yellow node can’t be sampled.

We now describe how to convert individual information and the contact network into a tabular dataset for every silo \(k\) with \(n_k\) nodes.

Recall that our task is to predict the risk of infection of a person within \(t_{\text{pred}} = 7\) days, and that each silo only sees its local subgraph. We formulate this via a silo-specific set of examples \((X_k \in \mathbb{R}^{n_k \times d}, Y_k \in \{0, 1\}^{n_k})\), where the features \(X_k^{(i)} \in \mathbb{R}^d\) describe the neighborhood around a person \(i\) (see figure) and the binary label \(Y_k^{(i)}\) denotes whether the person becomes infected in the next \(t_{\text{pred}}\) days.

Each example’s features \(X_k^{(i)}\) consist of the following:

(1) Individual features: Basic (normalized) demographic features like age, gender, and household size; activity features like working, school, going to church, or shopping; and the individual’s infection history as concatenated one-hot vectors (which depends on how we create labels; see below).

(2) Contact features: One of our key simplifying heuristics is that each node’s \(\ell\)-hop neighborhood should contain most of the information we need to predict infection. We build the contact features as follows:

  • Every sampled neighbor \(v\) of a node \(u\) is encoded using its individual features (as above) along with the edge features describing the contact—e.g. the location, the duration, and the activity type.
  • We use iterative neighborhood sampling (figure above), meaning that we first select a set of \(S_1\) 1-hop neighbors, and then sample \(S_2\) 2-hop neighbors adjacent to those 1-hop neighbors, and so on. This allows reusing 1-hop edge features and keeps the feature dimension \(d\) low.
  • We also used deterministic neighborhood sampling—the same person always takes the same subset of neighbors. This drastically reduces computation as the graph/neighborhoods can now be cached. For the interested reader, this also has implications on privacy accounting.
Illustration of the tabularized features. Red/pink blocks are individual (node) features and green blocks are edge features describing the contact. Each blue block denotes the combined features of a single social contact (the neighboring node & the edge), and contacts of higher degrees are concatenated.

The figure above illustrates the neighborhood feature vector that describes a person and their contacts for the binary classifier! Intriguingly, this makes the per-silo models a simplified variant of a graph neural network (GNN) with a single-step, non-parameterized neighborhood aggregation and prediction (cf. SGC models).
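
To sketch the construction (with hypothetical data structures: `graph[u]` is the set of node \(u\)'s contacts, and `node_feats` / `edge_feats` are feature lookups), deterministic iterative sampling and concatenation might look like the following; in practice, missing contacts are zero-padded so every example has the same dimension \(d\).

```python
import numpy as np

def neighborhood_features(graph, node, node_feats, edge_feats, s1=3, s2=2):
    """Feature vector for `node`: its own features, then [neighbor + edge] features
    for a fixed sample of 1-hop contacts, then for 2-hop contacts through them."""
    def pick(candidates, k, seed):
        # Deterministic sampling: the same node always yields the same neighbors,
        # so neighborhoods (and their features) can be precomputed and cached.
        rng = np.random.default_rng(seed)
        return list(rng.permutation(sorted(candidates)))[:k]

    parts = [node_feats[node]]
    hop1 = pick(graph[node], s1, seed=node)
    for v in hop1:
        parts.append(np.concatenate([node_feats[v], edge_feats[(node, v)]]))
    for u in hop1:
        for w in pick(set(graph[u]) - {node}, s2, seed=u):
            parts.append(np.concatenate([node_feats[w], edge_feats[(u, w)]]))
    return np.concatenate(parts)
```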

For the labels \(Y_k^{(i)}\), we deployed a random infection window strategy:

  1. Pick a window size \(t_{\text{window}}\) (say 21 days);
  2. Select a random day \(t'\) within the valid range \(t_{\text{window}} \le t' \le t_{\text{train}} - t_{\text{pred}}\);
  3. Encode the S/I/R states in the past window from \(t'\) for every node in the neighborhood as individual features;
  4. The label is then whether person \(i\) is infected in any of the next \(t_{\text{pred}}\) days from \(t'\).

During training, every time we sample a person (node), we take a random window of infection states to use as features (the “observation” window) and labels (1 iff the person transitions into infection during the “prediction” window); the sampled neighbors use the same window when building the neighborhood feature vector. During testing, we deterministically take the latest days of the infection history.
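
Here is a minimal sketch of this windowing, assuming `states[node]` is an integer array of length \(t_{\text{train}}\) with codes 0/1/2 for S/I/R; the constants mirror the values above. In the full pipeline, the sampled neighbors reuse the same reference day \(t'\) when building the neighborhood features.

```python
import numpy as np

T_TRAIN, T_PRED, T_WINDOW = 56, 7, 21   # days, as in the challenge setup

def make_example(states, node, rng, train=True):
    """Features: one-hot S/I/R history over a past window ending at a reference day.
    Label: whether the person enters the I state within the next T_PRED days."""
    if train:
        t = int(rng.integers(T_WINDOW, T_TRAIN - T_PRED + 1))  # random reference day
    else:
        t = T_TRAIN                                            # test: latest history
    history = states[node][t - T_WINDOW:t]                     # "observation" window
    features = np.eye(3)[history].reshape(-1)                  # one-hot, flattened
    if not train:
        return features, None                                  # test labels are held out
    future = states[node][t:t + T_PRED]                        # "prediction" window
    return features, int((future == 1).any())                  # 1 = infected (I state)
```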

Our strategy implicitly assumes that a person’s infection risk is individual: whether Bob gets infected depends only on his own activities and contacts in the past window. This is certainly not perfect as it ignores population-level modeling (e.g. denser areas have higher risks of infection), but it makes the ML problem very simple: just plug in existing tabular data modeling approaches!

2.4. Putting it all together

We can now see our solution coming together: each silo builds a tabular dataset using neighborhood vectors for features and infection windows for labels, and each silo trains a personalized binary classifier under MR-MTL with silo-specific example-level privacy. We complete our method with a few additional ingredients:

  1. Privacy accounting. We’ve so far glossed over what silo-specific “example-level” DP actually means for an individual. We’ve put more details in the extended blog post, and the main idea is that local DP-SGD can give “neighborhood-level” DP since each node’s enclosing neighborhood is fixed and unique, and we can then convert it to node-level DP (our privacy goal from Part 2.1) by carefully accounting for how a certain node may appear in other nodes’ neighborhoods.
  2. Noisy SGD as an empirical defense. While we have a complete framework for providing silo-specific node-level DP guarantees, for the PETs challenge specifically we decided to opt for weak DP (\(\varepsilon > 500\)) as an empirical protection, rather than a rigorous theoretical guarantee. While some readers may find this mildly disturbing at first glance, we note that the strength of protection depends on the data, the models, the actual threats, the desired privacy-utility trade-off, and several crucial factors linking theory and practice which we outline in the extended blog. Our solution was in turn attacked by several red teams to test for vulnerabilities.
  3. Model architecture: simple is good. While the model design space is large, we are interested in methods amenable to gradient-based private optimization (e.g. DP-SGD) and weight-space averaging for federated learning. We compared simple logistic regression and a 3-layer MLP and found that the variance in data strongly favors linear models, which also have benefits in privacy (in terms of limited capacity for memorization) as well as explainability, efficiency, and robustness.
  4. Computation-utility tradeoff for neighborhood sampling. While larger neighborhood sizes \(S\) and more hops \(\ell\) better capture the original contact graph, they also blow up the computation, and our experiments found that larger \(S\) and \(\ell\) tend to have diminishing returns.
  5. Data imbalance and weighted loss. Because the data are highly imbalanced, training naively will suffer from low recall and AUPRC. While there are established over-/under-sampling methods to deal with such imbalance, they unfortunately make privacy accounting a lot trickier in terms of the subsampling assumption or the increased data queries. We instead leveraged the focal loss from the computer vision literature, which is designed to emphasize hard examples (here, infected cases), and found that it improved both the AUPRC and the recall considerably (see the sketch below).
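
As an illustration, here is a minimal binary focal loss in PyTorch following the standard formulation \(-\alpha_t (1 - p_t)^\gamma \log p_t\); the \(\alpha\) and \(\gamma\) values shown are common defaults, not the ones we tuned for the challenge.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Focal loss for imbalanced binary classification: down-weights easy,
    well-classified examples so that rare positives (infections) drive training."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = targets * p + (1 - targets) * (1 - p)               # prob. of the true class
    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
    return (alpha_t * (1.0 - p_t) ** gamma * bce).mean()
```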

The above captures the essence of our entry to the challenge. Despite the many subtleties in fully building out a working system, the main ideas were quite simple: train personalized models with DP and add some proximity constraints!


Takeaways and Open Challenges

In Part 1, we reviewed our NeurIPS’22 paper that studied the application of differential privacy in cross-silo federated learning scenarios, and in Part 2, we saw how the core ideas and methods from the paper helped us develop our submission to the PETs prize challenge and win a 1st place in the pandemic forecasting track. For readers interested in more details—such as theoretical analyses, hyperparameter tuning, further experiments, and failure modes—please check out our full paper. Our work also identified several important future directions in this context:

DP under data imbalance. DP is inherently a uniform guarantee, but data imbalance implies that examples are not created equal—minority examples (e.g., disease infection, credit card fraud) are more informative, and they tend to give off (much) larger gradients during model training. Should we instead do class-specific (group-wise) DP or refine “heterogeneous DP” or “outlier DP” notions to better cater to the discrepancy between data points?

Graphs and privacy. Another fundamental basis of DP is that we could delineate what is and isn’t an individual. But as we’ve seen, the information boundaries are often nebulous when an individual is a node in a graph (think social networks and gossip propagation), particularly when the node is arbitrarily well connected. Instead of having rigid constraints (e.g., imposing a max node degree and accounting for it), are there alternative privacy definitions that offer varying degrees of protection for varying node connectedness?

Scalable, private, and federated trees for tabular data. Decision trees/forests tend to work extremely well for tabular data such as ours, even with data imbalance, but despite recent progress, we argue that they are not yet mature under private and federated settings due to some underlying assumptions.

Novel training frameworks. While MR-MTL is a simple and strong baseline under our privacy granularity, it has clear limitations in terms of modeling capacity. Are there other methods that can also provide similar properties to balance the emerging privacy-heterogeneity cost tradeoff?

Honest privacy cost of hyperparameter search. When searching for better frameworks, the dependence on hyperparameters is particularly interesting: our full paper (section 7) made a surprising but somewhat depressing observation that the honest privacy cost of tuning just 10 configurations on average (values of \(\lambda\) in this case) may already outweigh the utility advantage of the best-tuned MR-MTL\((\lambda^\ast)\). What does this mean if MR-MTL is already a strong baseline with just a single hyperparameter?




DISCLAIMER: All opinions expressed in this post are those of the authors and do not represent the views of CMU.

Footnotes

1    Note that “personalization” refers to customizing models for each client (data silo) in federated learning rather than for a specific person.
2    As compared to local training or FedAvg for a fixed \(\lambda\). However, tuning \(\lambda\) as a hyperparameter can incur privacy cost.


TIDEE: An Embodied Agent that Tidies Up Novel Rooms using Commonsense Priors


Example of embodied commonsense reasoning. A robot proactively identifies a remote on the floor and knows it is out of place without instruction. Then, the robot figures out where to place it in the scene and manipulates it there.

For robots to operate effectively in the world, they should be more than explicit step-by-step instruction followers. Robots should take actions in situations when there is a clear violation of the normal circumstances and be able to infer relevant context from partial instruction. Consider a situation where a home robot identifies a remote control which has fallen to the kitchen floor. The robot should not need to wait until a human instructs the robot to “pick the remote control off the floor and place it on the coffee table”. Instead, the robot should understand that the remote on the floor is clearly out of place, and act to pick it up and place it in a reasonable location. Even if a human were to spot the remote control first and instruct the agent to “put away the remote that is on the living room floor”, the robot should not require a second instruction for where to put the remote, but instead infer from experience that a reasonable location would be, for example, on the coffee table. After all, it would become tiring for a home robot user to have to specify every desire in excruciating detail (imagine specifying, for each item you want the robot to move, an instruction such as “pick up the shoes beneath the coffee table and place them next to the door, aligned with the wall”).

The type of reasoning that would permit such partial or self-generated instruction following involves a deep sense of how things in the world (objects, physics, other agents, etc.) ought to behave. Reasoning and acting of this kind are all aspects of embodied commonsense reasoning and are vastly important for robots to act and interact seamlessly in the physical world.

There has been much work on embodied agents that follow detailed step-by-step instructions, but less on embodied commonsense reasoning, where the task involves learning how to perceive and act without explicit instruction. One task in which to study embodied commonsense reasoning is that of tidying up, where the agent must identify objects which are out of their natural locations and act in order to bring the identified objects to plausible locations. This task combines many desirable capabilities of intelligent agents with commonsense reasoning of object placements. The agent must search in likely locations for objects to be displaced, identify when objects are out of their natural locations in the context of the current scene, and figure out where to reposition the objects so that they are in proper locations – all while intelligently navigating and manipulating.

In our recent work, we propose TIDEE, an embodied agent that can tidy up never-before-seen rooms without any explicit instruction. TIDEE is the first of its kind for its ability to search a scene for out of place objects, identify where in the scene to reposition the out of place objects, and effectively manipulate the objects to the identified locations. We’ll walk through how TIDEE is able to do this in a later section, but first let’s describe how we create a dataset to train and test our agent for the task of tidying up.

Creating messy homes

To teach our agent what constitutes a tidy scene and what constitutes a messy one, we create clean and messy scenes using the AI2-THOR simulation environment. AI2-THOR is an interactive 3D environment of indoor scenes that allows objects to be picked up and moved around. The simulator comes ready with 120 scenes of kitchens, bathrooms, living rooms, and bedrooms with over 116 object categories (and significantly more object instances) scattered throughout. Each of the scenes comes with a default initialization of object placements that are meticulously chosen by humans to be highly structured and “neat”. These default object locations make up our “tidy” scenes for providing our agent examples of objects in their natural locations. To create messy scenes, we apply forces to a subset of the objects with a random direction and magnitude (we “throw” the objects around) so they end up in uncommon locations and poses. You can see below some examples of objects which have been moved out of place.

Examples of “messy” object locations. These objects are moved out of place by applying forces to them in the simulator in order to generate untidy scenes for the robot to clean up.

Next, let’s see how TIDEE learns from this dataset to be able to tidy up rooms.

How does TIDEE work?

We give our agent a depth and RGB sensor to use for perceiving the scene. From this input, the agent must navigate around, detect objects, pick them up, and place them. The goal of the tidying task is to rearrange a messy room back to a tidy state.

TIDEE tidies up rooms in three phases. In the first phase, TIDEE explores the room and runs an out of place object detector at each time step until one is identified. Then, TIDEE navigates over to the object and picks it up. In the second phase, TIDEE uses graph inference in its joint external graph memory and scene graph to infer a plausible receptacle to place the object on within the scene. It then explores the scene guided by a visual search network that suggests where the receptacle may be found if TIDEE has not identified one in a previous time step. For navigation and keeping track of objects, TIDEE maintains an obstacle map of the scene and stores in memory the estimated 3D centroids of previously detected objects.

The three stages of TIDEE. TIDEE first searches for out of place objects. Then, once an out of place object is found, TIDEE infers where to put it in the scene. Finally, TIDEE searches for the correct placement location and places the object.

The out of place detector uses visual and relational language features to determine if an object is in or out of place in the context of the scene. The visual features for each object are obtained from an off-the-shelf object detector, and the relational language features are obtained by giving predicted 3D relations of the objects (e.g. next to, supported by, above, etc.) to a pretrained language model. We combine the visual and language features to classify whether each detected object is in or out of place. We find that combining the visual and relational modalities performs best for out of place classification over using a single modality.
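
As a rough sketch of this late-fusion step (not TIDEE's exact architecture; the feature dimensions and the two-layer head below are assumptions), the classifier might look like:

```python
import torch
import torch.nn as nn

class OutOfPlaceClassifier(nn.Module):
    """Concatenate an object's visual features (from an off-the-shelf detector)
    with relational language features (from a language model over predicted 3D
    relations) and classify the object as in place vs. out of place."""
    def __init__(self, visual_dim=512, language_dim=768, hidden_dim=256):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(visual_dim + language_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, visual_feats, language_feats):
        return self.head(torch.cat([visual_feats, language_feats], dim=-1))
```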

Out of place object classification. The classifier makes use of visual and relational language features to infer if the object-under-consideration is in place or out of place.

To infer where to place an object once it has been picked up, TIDEE includes a neural graph module which is trained to predict plausible placement proposals for objects. The module works by passing information between the object to be placed, a memory graph encoding plausible contextual relations from training scenes, and a scene graph encoding the object-relation configuration in the current scene. For our memory graph, we take inspiration from “Beyond Categories: The Visual Memex Model for Reasoning About Object Relationships” by Tomasz Malisiewicz and Alexei A. Efros (2009), which models instance-level object features and their relations to provide more complete appearance-based context. Our memory graph consists of the tidy object instances in the training scenes to provide fine-grained contextualization of tidy object placements. We show in the paper that this fine-grained visual and relational information is important for TIDEE to place objects in human-preferred locations.

Neural graph module for determining where to place an object. The neural graph module makes use of a graph made from training houses, which we call the memex graph. This gives the network priors about how objects are generally arranged in a “clean” state. We additionally give the network the current scene graph and the out of place object. The network outputs a plausible location to put the out of place object in the current scene.

To search for objects that have not been previously found, TIDEE uses a visual search network that takes as input the semantic obstacle map and a search category and predicts the likelihood of the object being present at each spatial location in the obstacle map. The agent then searches in those likely locations for the object of interest.

Object search network for finding objects-of-interest. The network conditions on a search category and outputs a heat-map for likely locations for the category to exist in the map. The robot searches in these likely locations to find the object.

Combining all the above modules provides us with a method to be able to detect out of place objects, infer where they should go, search intelligently, and navigate & manipulate effectively. In the next section, we’ll show you how well our agent performs at tidying up rooms.

How good is TIDEE at tidying up?

Using a set of messy test scenes that TIDEE has never seen before, we task our agent with reconfiguring the messy room to a tidy state. Since a single object may be tidy in multiple locations within a scene, we evaluate our method by asking humans whether they prefer the placements of TIDEE compared to baseline placements that do not make use of one or more of TIDEE’s commonsense priors. Below we show that TIDEE placements are significantly preferred to the baseline placements, and even competitive with human placements (last row).

TIDEE outperforms ablative versions of the model that do not use one or more of the commonsense priors, outperforms messy placements, and is competitive with human placements.

We additionally show that the placements of TIDEE can be customized based on user preferences. For example, given user input such as “I never want my alarm on the desk”, we can use online learning techniques to change the model's output so that an alarm clock on the desk is considered out of place (and should be moved). Below we show some examples of locations and relations of alarm clocks that were predicted as being in the correct locations (and not out of place) within the scene after our initial training. However, after the user-specified fine-tuning, our network predicts that the alarm clock on the desk is out of place and should be repositioned.

Alarm clock locations and their relations with other objects. Alarm clock is often found on desks in the training scenes.
Detection probabilities of the three alarm clocks before and after online learning of “alarm clock on the desk is out of place”. We show we are able to customize the priors of our out of place detector given user input.

We also show that a simplified version of TIDEE can generalize to the task of rearrangement, where the agent sees the original state of the objects, some of the objects are then moved to new locations, and the agent must rearrange the objects back to their original state. We outperform the previous state-of-the-art model that utilizes semantic mapping and reinforcement learning, even with noisy sensor measurements.

Rearrangement performance of TIDEE (blue) compared to the reinforcement learning baseline (orange). We are able to adapt our networks to perform object rearrangement, and beat a state-of-the-art baseline by a significant margin, even with noisy sensor measurements.

Summary

In this article, we discussed TIDEE, an embodied agent that uses commonsense reasoning to tidy up novel messy scenes. We introduce a new benchmark to test agents in their ability to clean up messy scenes without any human instruction. To check out our paper, code, and more, please visit our website at https://tidee-agent.github.io/.

Also, feel free to shoot me an email at gsarch@andrew.cmu.edu! I would love to chat!


Are Model Explanations Useful in Practice? Rethinking How to Support Human-ML Interactions.


Figure 1. This blog post discusses the effectiveness of black-box model explanations in aiding end users to make decisions. We observe that explanations do not in fact help with concrete applications such as fraud detection and paper matching for peer review. Our work further motivates novel directions for developing and evaluating tools to support human-ML interactions.

Model explanations have been touted as crucial information to facilitate human-ML interactions in many real-world applications where end users make decisions informed by ML predictions. For example, explanations are thought to assist model developers in identifying when models rely on spurious artifacts and to aid domain experts in determining whether to follow a model’s prediction. However, while numerous explainable AI (XAI) methods have been developed, XAI has yet to deliver on this promise. XAI methods are typically optimized for diverse but narrow technical objectives disconnected from their claimed use cases. To connect methods to concrete use cases, we argued in our Communications of the ACM paper [1] for researchers to rigorously evaluate how well proposed methods can help real users in their real-world applications.

Towards bridging this gap, our group has since completed two collaborative projects where we worked with domain experts in e-commerce fraud detection and paper matching for peer review. Through these efforts, we’ve gleaned the following two insights:

  1. Existing XAI methods are not useful for decision-making. Presenting humans with popular, general-purpose XAI methods does not improve their performance on real-world use cases that motivated the development of these methods. Our negative findings align with those of contemporaneous works.
  2. Rigorous, real-world evaluation is important but hard. These findings were obtained through user studies that were time-consuming to conduct. 

We believe that each of these insights motivates a corresponding research direction to support human-ML interactions better moving forward. First, beyond methods that attempt to explain the ML model itself, we should consider a wider range of approaches that present relevant task-specific information to human decision-makers; we refer to these approaches as human-centered ML (HCML) methods [10]. Second, we need to create new workflows to evaluate proposed HCML methods that are both low-cost and informative of real-world performance.

In this post, we first outline our workflow for evaluating XAI methods.  We then describe how we instantiated this workflow in two domains: fraud detection and peer review paper matching. Finally, we describe the two aforementioned insights from these efforts; we hope these takeaways will motivate the community to rethink how HCML methods are developed and evaluated.

How do you rigorously evaluate explanation methods?

In our CACM paper [1], we introduced a use-case-grounded workflow to evaluate explanation methods in practice—this means showing that they are ‘useful,’ i.e., that they can actually improve human-ML interactions in the real-world applications that they are motivated by. This workflow contrasts with evaluation workflows of XAI methods in prior work, which relied on researcher-defined proxy metrics that may or may not be relevant to any downstream task. Our proposed three-step workflow is based on the general scientific method:

Step 1: Define a concrete use case. To do this, researchers may need to work closely with domain experts to define a task that reflects the practical use case of interest.

Step 2: Select explanation methods for evaluation. While the selected methods might include popular XAI methods, the appropriate set is to a large extent application-specific and should also include relevant non-explanation baselines.

Step 3: Evaluate explanation methods against baselines. While researchers should ultimately evaluate selected methods through a user study with real-world users, researchers may want to first conduct cheaper, noisier forms of evaluation to narrow down the set of methods in consideration (Figure 2). 

Figure 2. Evaluation is a key component of our proposed use-case-grounded workflow and consists of four stages ranging from cheaper, lower-signal evaluations to more expensive, task-specific user studies. The stages of evaluation are adapted from Doshi-Velez and Kim (2017); we introduce an additional stage, use-case-grounded algorithmic evaluations, in a recent Neurips 2022 paper [2].

Instantiating the workflow in practice

We collaborated with experts from two domains (fraud detection and peer review paper matching) to instantiate this use-case-grounded workflow and evaluate existing XAI methods:

Figure 3. Example of the user interface used by fraud analysts in our experiment (populated with sample data for illustrative purposes). (a) Basic interface components, including the model score (shown in the top left), buttons to approve or decline the transactions, and transaction details. (b) A component of the interface that presents the explanations of the model score.

Domain 1: Fraud detection [3]. We partnered with researchers at Feedzai, a financial start-up, to assess whether providing model explanations improved the ability of fraud analysts to detect fraudulent e-commerce transactions. Given that we had access to real-world data (i.e., historical e-commerce transactions for which we had ground truth answers of whether the transaction was fraudulent) and real users (i.e., fraud analysts), we directly conducted a user study in this context. An example of the interface shown to analysts is in Figure 3. We compared analysts’ average performance when shown different explanations to a baseline setting where they were only provided the model prediction. We ultimately found that none of the popular XAI methods we evaluated (LIME, SHAP, and Tree Interpreter) resulted in any improvement in the analysts’ decisions compared to the baseline setting (Figure 5, left). Evaluating these methods with real users additionally posed many logistical challenges because fraud analysts took time from their regular day-to-day work to periodically participate in our study. 

Figure 4. Peer review paper matching is an example of a document matching application. For each submitted paper, the matching model pre-screens a list of candidate reviewers via affinity scores (solid arrows). Meta-reviewers, typically under a time constraint, then select the best match to the submitted paper among the pre-screened reviewers (box with a solid line). We study whether providing additional assistive information, namely highlighting potentially relevant information in the candidate documents, can help the meta-reviewers make better decisions (dotted arrows and boxes).

Domain 2: Peer review paper matching [4]. We collaborated with Professor Nihar Shah (CMU), an expert in peer review, to investigate what information could help meta-reviewers of a conference better match submitted papers to suitable reviewers. Learning from our prior experience, we first conducted a user study using proxy tasks and users, which we designed together with Professor Shah as shown in Figure 4. In this proxy setting, we found that providing explanations from popular XAI methods in fact led users to be more confident: the majority of participants shown highlights from XAI methods believed the highlighted information was helpful. Yet they made statistically worse decisions (Figure 5, right)!

Figure 5. We evaluated popular XAI methods in two domains: e-commerce fraud (left), where we conducted a user study with a real use case and users, and peer review paper matching (right), where we conducted a user study with a proxy task and users that we designed with a domain expert. Although we find that explanations from popular XAI methods do not outperform baselines of only providing the model prediction (and often result in statistically worse performance), we are optimistic about the potential of task-specific methods. In particular, our proposed method in the peer review paper matching task outperformed both the model-score-only baseline and existing general-purpose methods.

How can we better support human-ML interactions?

Through these collaborations, we identified two important directions for future work, which we describe in more detail along with our initial efforts in each direction.

We need to develop methods for specific use cases. Our results suggest that explanations from popular, general-purpose XAI methods can both hurt decision-making while making users overconfident. These findings have also been observed in multiple contemporaneous works (e.g., [7,8,9]). Researchers, instead, need to consider developing human-centered ML (HCML) methods [10] tailored for each downstream use case. HCML methods are any approach that provides information about the particular use case and context that can inform human decisions.

Figure 6. Examples of highlighted information from different methods in our peer review matching proxy task. Highlights for “Key Parts” (second row) provide the “ground truth”, i.e., they indicate the information relevant to the query summary (first row), all of which ideally should be visibly highlighted by the methods that follow. Existing methods like SHAP (third row) and BERTSum (fourth row) fail to fully highlight all key parts. Critically, they fail to visibly highlight the key part about “river levels rising” (yellow highlights in Key Parts), the unique information that distinguishes the ground truth from other candidate articles, which can directly impact the participant’s performance. On the other hand, our task-specific method (bottom row) visibly highlights all key parts.

Our contributions: In the peer review matching setting, we proposed an HCML method designed in tandem with a domain expert [4]. Notably, our method is not a model explanation approach, as it highlights information in the input data, specifically sentences and phrases that are similar in the submitted paper and the reviewer profile. Figure 6 compares the text highlighted using our method to the text highlighted using existing methods. Our method outperformed both a baseline where there was no explanation and the model explanation condition (Figure 5, right). Based on these positive results, we plan to move evaluations of our proposed method to more realistic peer review settings. Further, we performed an exploratory study to better understand how people interact with information provided by HCML methods as a first step towards coming up with a more systematic approach to devise task-specific HCML methods [5].

We need more efficient evaluation pipelines. While user studies conducted in a real-world use case and with real users are the ideal way to evaluate HCML methods, it is a time- and resource-consuming process. We highlight the need for more cost-effective evaluations that can be utilized to narrow down candidate HCML methods and still implicate the downstream use case. One option is to work with domain experts to design a proxy task as we did in the peer review setting, but even these studies require careful consideration of the generalizability to the real-world use case. 

Our contributions. We introduced an algorithmic-based evaluation called simulated user evaluation (SimEvals) [2]. Instead of conducting studies on proxy tasks, researchers can train SimEvals, which are ML models that serve as human proxies. SimEvals more faithfully reflects aspects of real-world evaluation because their training and evaluation data are instantiated on the same data and task considered in real-world studies. To train SimEvals, the researcher first needs to generate a dataset of observation-label pairs. The observation corresponds to the information that would be presented in a user study (and critically includes the HCML method), while the output is the ground truth label for the use case of interest. For example, in the fraud detection setting, the observation would consist of both the e-commerce transaction and ML model score shown in Figure 3(a) along with the explanation shown in Figure 3(b). The ground truth label is whether or not the transaction was fraudulent. SimEvals are trained to predict a label given an observation and their test set accuracies can be interpreted as a measure of whether the information contained in the observation is predictive for the use case. 
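
For illustration, here is a minimal sketch of a SimEval for the fraud use case, assuming hypothetical arrays holding the transaction features, model scores, and explanation attributions that make up each observation; the proxy model's held-out accuracy indicates how predictive that bundle of information is for the decision.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

def simeval_accuracy(transactions, model_scores, attributions, is_fraud):
    """Train an ML 'human proxy' on observation -> ground-truth-label pairs and
    report held-out accuracy as a use-case-grounded score for the explanation."""
    observations = np.hstack([transactions,
                              model_scores.reshape(-1, 1),
                              attributions])
    X_tr, X_te, y_tr, y_te = train_test_split(
        observations, is_fraud, test_size=0.25, random_state=0)
    proxy = GradientBoostingClassifier().fit(X_tr, y_tr)
    return proxy.score(X_te, y_te)
```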

We not only evaluated SimEvals on a variety of proxy tasks but also tested SimEvals in practice by working with Feedzai, where we found results that corroborate the negative findings from the user study [6]. Although SimEvals should not replace user studies because SimEvals are not designed to mimic human decision-making, these results suggest that SimEvals could be initially used to identify more promising explanations (Figure 6). 

Figure 6. An overview of how simulated user studies (SimEvals) can help a researcher select which explanation methods to evaluate given their specific use case. (Left) When conducting user studies, researchers often only evaluate a small number of explanation methods due to resource constraints and select popular methods as candidate explanations to evaluate, with little justification about why each choice may be helpful for the downstream use case. (Right) We propose using SimEvals, which are use-case-grounded, algorithmic evaluations, to efficiently screen explanations before running a user study. In this example, the researcher runs a SimEval on each of the four candidate explanation methods and then uses the results of the SimEvals to select two promising explanation methods where the algorithmic agent has high accuracy for their human subject study.

Conclusion

In summary, our recent efforts motivate two ways the community should rethink how to support human-ML interactions: (1) we need to replace general-purpose XAI techniques with HCML methods tailored to specific use cases, and (2) creating intermediate evaluation procedures that can help narrow down the HCML methods to evaluate in more costly settings. 

For more information about the various papers mentioned in this blog post, see the links below:

[1] Chen, V., Li, J., Kim, J. S., Plumb, G., & Talwalkar, A. Interpretable Machine Learning. Communications of the ACM, 2022. (link)

[2] Chen, V., Johnson, N., Topin, N., Plumb, G., & Talwalkar, A. Use-case-grounded simulations for explanation evaluation. NeurIPS, 2022. (link)

[3] Amarasinghe, K., Rodolfa, K. T., Jesus, S., Chen, V., Balayan, V., Saleiro, P., Bizzaro, P., Talwalkar, A. & Ghani, R. (2022). On the Importance of Application-Grounded Experimental Design for Evaluating Explainable ML Methods. arXiv. (link)

[4] Kim, J. S., Chen, V., Pruthi, D., Shah, N., Talwalkar, A. Assisting Human Decisions in Document Matching. arXiv. (link)

[5] Chen, V., Liao, Q. V., Vaughan, J. W., & Bansal, G. (2023). Understanding the Role of Human Intuition on Reliance in Human-AI Decision-Making with Explanations. arXiv. (link)

[6] Martin, A., Chen, V., Jesus, S., Saleiro, P. A Case Study on Designing Evaluations of ML Explanations with Simulated User Studies. arXiv. (link)

[7] Bansal, G., Wu, T., Zhou, J., Fok, R., Nushi, B., Kamar, E., Ribeiro, M. T. & Weld, D. Does the whole exceed its parts? the effect of ai explanations on complementary team performance. CHI, 2021. (link)

[8] Adebayo, J., Muelly, M., Abelson, H., & Kim, B. Post hoc explanations may be ineffective for detecting unknown spurious correlation. ICLR, 2022. (link)

[9] Zhang, Y., Liao, Q. V., & Bellamy, R. K. Effect of confidence and explanation on accuracy and trust calibration in AI-assisted decision making. FAccT, 2020. (link)

[10] Chancellor, S. (2023). Toward Practices for Human-Centered Machine Learning. Communications of the ACM, 66(3), 78-85. (link)

Acknowledgments

We would like to thank Kasun Amarasinghe, Jeremy Cohen, Nari Johnson, Joon Sik Kim, Q. Vera Liao, and Junhong Shen for helpful feedback and suggestions on earlier versions of the blog post. Thank you also to Emma Kallina for her help with designing the main figure!


Towards Behavior-Driven AI Development


Figure 1: Behavior-driven AI development centers model iteration on evaluating and improving specific real-world use cases.

It has never been easier to prototype AI-driven systems. With a bit of programming knowledge and a couple of hours, you can spin up a chatbot for your notes, a text-based image editor, or a tool for summarizing customer feedback. But play around with your prototype for a bit, and you might find that it doesn’t work as well as you first expected. Your system might make up facts or respond with racist suggestions. How would you evaluate your model and predict its performance in deployment?

The canonical process for benchmarking AI systems revolves around model-centric metrics. Calculate a metric (F1-score, precision, etc.), and if it increases, you are going in the right direction. But these metrics are oversimplified objectives that sand away the complexity of model behavior and cannot fully represent a model’s performance. A metric may tell you how well your model can predict the next word in a sentence, but it won’t tell you how factually accurate, logical, or fair your model is across diverse, real-world use cases. Generative AI systems such as ChatGPT or Stable Diffusion make evaluation even more challenging since there are no well-defined metrics that can summarize their performance.

When creating deployed AI products, practitioners instead focus on the specific use cases their customers have and whether or not their models are fulfilling them. In interviews with 18 AI practitioners, we found that they constantly collect user feedback and develop “golden test sets” of behaviors that they expect deployed models to have. We term this behavior-driven AI development, a development process focused on evaluating and updating models to improve performance on real-world use cases. While chatbot A might sound more human-like, a practitioner will deploy chatbot B if it produces concise and accurate answers that customers prefer.

The landscape of AI evaluation tools primarily revolves around model-centric metrics that do not capture important behaviors like these chatbot characteristics. While there are specific tools for behavior-driven development, such as fairness toolkits and robustness analysis libraries, practitioners end up cobbling together disparate tools into ad-hoc scripts or computational notebooks that are hard to maintain and reproduce.

I believe that there is a set of abstractions that can unify AI evaluation in line with how models are used in practice. This philosophy revolves around model behaviors: metrics summarizing patterns of output on subgroups of instances. This simple concept can encode any model evaluation or analysis, from fairness audits to language model hallucinations. We show what this can look like with Zeno, an interactive platform we built for behavior-driven development that supports interactive data exploration, slicing, and reporting. By investigating their own models using Zeno, practitioners have been able to pinpoint significant and actionable issues such as biases and systematic failures.

What is model behavior?

The dictionary describes behavior as anything that an organism does involving action and response to stimulation. In the case of AI systems, model behavior is a specific pattern of output for a semantically meaningful subgroup of input data (stimulus). By semantically meaningful, I mean subgroups that can be described with human-interpretable concepts, such as “audio with noise in the background” or “people who identify as women.” Similarly, a pattern of output could be “high audio transcription error” or “low loan approval rate.” 

Behaviors can be quantified as metrics on subgroups of data, often using the same metrics as are used for model-centric evaluation. But unlike summary metrics across an entire dataset, metrics in behavior-centric development quantify specific patterns of behavior, like how often an image generation model produces unintelligible text. Tests of model behaviors are like exams for specific subjects, while summary metrics resemble IQ tests.
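
To make this concrete, here is a minimal sketch of a behavior expressed as a subgroup plus a metric, written in pandas; the column names and numbers are made up purely for illustration.

```python
import pandas as pd

# Toy evaluation table: one row per instance, with metadata and a per-instance score.
# Column names and values are illustrative only.
df = pd.DataFrame({
    "word_error_rate":  [0.10, 0.45, 0.05, 0.38],
    "background_noise": [False, True, False, True],
})

# A behavior = a semantically meaningful subgroup + a metric computed on that subgroup.
noisy = df[df["background_noise"]]          # subgroup: audio with background noise
print("WER on noisy audio:", noisy["word_error_rate"].mean())
print("WER overall:", df["word_error_rate"].mean())
```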

Figure 2. How model behaviors are defined from a dataset. Behaviors are subgroups of data (typically defined by combinations of metadata) quantified by a specific metric. For the example behavior of “blurry text” from a text-to-image model, a metadata column for “images with text” could be used to create a subgroup on which a metric measuring the clarity of text can be calculated.

Model behavior is a relatively simple concept, but encoding behaviors can be challenging in practice. Practitioners may not have enough data to validate or fix important model behaviors and have to collect or generate more data. If they have extensive data, they need ways to subdivide it into meaningful groups of instances – how do I find all images that have text? Lastly, for each subgroup, practitioners have to derive the appropriate metrics to quantify the prevalence of the behavior – how do I detect blurry text? Succinctly, behavior-driven development requires sufficient data that is representative of expected behaviors, plus metadata for defining and quantifying those behaviors.

A platform for behavior-driven AI development

The beauty of a behavior-based framing of AI development is that it remains data- and model-agnostic. While the specific behaviors for each ML task will be vastly different, subgroups of data and metrics are universal concepts.

To test this theory, we built Zeno, a platform for behavior-driven AI development. Zeno empowers users to explore data and model outputs, interactively create subgroups of data, and calculate and quantify model behaviors. It consists of a Python API for scaffolding the data needed for analysis and a user interface for interactively creating subgroups and evaluating behaviors.

Figure 3. The Zeno interface shown for the Imagenette dataset and image classification. The right side has the instance view showing the input images and model outputs. The left side shows distributions for the dataset’s metadata, which has been interactively filtered to show images of English Springer Spaniels.

The Python API is a set of decorator functions (wrappers on user-defined functions) that can be used to plug in ML models and derive metadata features and metrics from input data. Since the decorators are generic wrappers, Zeno supports any Python-based model, processing function, or metric. Zeno preprocesses the input data with these functions, which it passes into the UI for analysis.
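
To give a rough sense of this decorator pattern, here is a sketch of what user code can look like. The decorator and function names below are placeholders for exposition rather than Zeno’s exact API; in Zeno, the registered functions are the ones used to preprocess the data before it is passed to the UI.

```python
import numpy as np

# Placeholder decorators that simply tag user functions by role; the real platform's
# decorators plug these functions into its preprocessing pipeline.
def model(fn): fn.role = "model"; return fn
def distill(fn): fn.role = "metadata"; return fn
def metric(fn): fn.role = "metric"; return fn

@model
def load_model(checkpoint_path):
    # Return a callable mapping raw instances to predictions (dummy model here).
    return lambda instances: ["dog"] * len(instances)

@distill
def brightness(images):
    # Derived metadata column: mean pixel intensity of each image.
    return [float(np.mean(img)) for img in images]

@metric
def accuracy(predictions, labels):
    # Subgroup metric: fraction of correct predictions.
    return float(np.mean([p == l for p, l in zip(predictions, labels)]))
```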

Zeno’s UI is the primary interface for behavior-driven evaluation. It allows users to interactively explore and filter their data, create slices, calculate metrics, and create exportable visualizations. On the right side of the UI is Zeno’s instance view, where users can explore the raw data on which the model is being evaluated. In addition to the standard list view, users can also see the data in a table or a 2D scatterplot representation. The left side of the interface holds the metadata panel. All the metadata columns that either came with the dataset or were generated with the Python API have their distributions displayed in the panel. Users can interactively filter the distributions to update the instance view and create named subgroups.

The UI also has a report page for creating interactive summary visualizations of behaviors. For example, a user could create a bar chart comparing the performance of three models across ten different slices. Or they could create a line chart showing how a model performs on data slices from each day of data. These visualizations can be exported or shared directly with other stakeholders.

Figure 4: With Zeno, users can interactively filter their data to create slices and calculate subgroup metrics. They can also use the 2D projection to find new areas of data where their model is underperforming. In this example, a user is exploring the CIFAR-10 classification model. They first filter the dataset to compare low versus high brightness images, finding a significant difference in accuracy between the two groups. They then find a group of instances with high error in the projection view, which is mostly made up of birds in the sky being misclassified as airplanes.

Case Studies

We have worked with various ML practitioners to apply Zeno to the models and tasks on which they work. Using Zeno, practitioners found significant model issues and areas for improvement, including gender biases and regional model disparities.

Audio transcription. I ran this first case study myself after I heard that OpenAI had released Whisper, a new speech-to-text model with state-of-the-art performance. I was curious how the model compared to some existing off-the-shelf transcription models. Instead of looking at aggregate metrics, I ran the models on the Speech Accent Archive dataset, which has speakers from around the world saying the same phrase. By filtering the dataset’s extensive metadata, I found that the models perform worse for English speakers who learned the language later in life and for speakers from countries where English is not the native language.

Figure 5. (left) The average word error rate (WER) for both models across different ages when participants started learning English. (right) The average WER of the Silero and Whisper transcription models across speakers from different continents.
Charts exported directly from the Zeno Report UI. 

Cancer classification. In another case study, we worked with a researcher who wanted to improve a breast cancer classifier for mammogram images. Since the data was anonymized and lacked meaningful metadata, the practitioner wrote dozens of functions using a Python library to extract meaningful metadata features. By exploring the distributions, they found that images with higher “entropy”, which correlates with denser breast tissue, had a significantly higher error rate than images with lower entropy (less dense tissue). This finding matches performance differences in human radiologists, who also perform worse on images of denser breast tissue because dense tissue makes lesions harder to detect.

Low density (4937): entropy < 2.75 && gray level variance < 2.5, AUC 0.86
High density (656): entropy > 2.75 && gray level variance > 2.5, AUC 0.76
Figure 6. The breast cancer classification model performed significantly worse for high-density images (described by high entropy and gray level variance metadata values) than for low-density images (left: low density; right: high density).

Image generation. Models with complex outputs often do not have clearly defined metrics, including text-to-image generation models such as DALL·E and Stable Diffusion. We can instead look at metrics that measure specific behaviors. In this example, a practitioner we worked with was exploring the DiffusionDB dataset, which has over two million prompt-image pairs from the Stable Diffusion model. The dataset also has metadata for how NSFW, or inappropriate, the prompts and images are. This data was used to derive an “average NSFW” metric, which can reveal potential biases in the model. For example, the practitioner compared the images generated using prompts with the word “boy” versus “girl” and found that prompts with “girl” generated images with a significantly higher NSFW level than prompts with “boy”, showing potential biases in the types of images created by the model.

Figure 7. Given similarly or less inappropriate prompts, the images generated with Stable Diffusion are much more inappropriate (NSFW) for prompts with “girl” or “woman” than for prompts with “boy” or “man”.
Charts exported directly from the Zeno Report UI. 

Discussion and Opportunities

Model iteration is still a primarily reactive process of finding and defining behaviors after a model has been deployed and the customer complaints start rolling in. There remains significant room for improving this process, from making it easier to ideate model behaviors to tracking model changes over time.

Discovering behaviors. While practitioners often need a model to discover the behaviors the model should have, methods for defining expected model behaviors before deployment can prevent serious real-world model issues.  For example, crowdsourcing techniques for eliciting potential edge cases could preemptively catch model errors. Algorithmic methods that find clusters of data with high error have also shown promise for surfacing problematic behaviors.

Data discovery and generation. Having high-quality, representative data remains a persistent obstacle for behavioral evaluation. In some domains with ample data, such as natural images, methods like Stable Diffusion have shown promise for generating new data for evaluation or training. In less data-rich domains, techniques for searching through large unlabeled datasets, such as text-based image search, can surface valuable data for evaluation and retraining. It is also challenging to derive metadata from instances for creating subgroups and calculating metrics. While it can be easy to generate metadata for simple concepts like “image brightness,” many behaviors are defined by complex metadata such as “images with a person wearing clear glasses” that cannot be encoded by a simple function. Foundation models have shown some promise in using text-based descriptions to generate complex metadata and metrics.

Model comparison. Models are almost never one-off jobs and can be updated daily or weekly. While it is easy to compare aggregate metrics, it can be challenging to compare model performance in behavior-driven development. To pick between models, users may have to compare dozens of behaviors and qualitative insights. Improved visual encodings or intelligent recommendations of model differences could help users make informed decisions and deploy the right models.

Fixing behaviors. Discovering and encoding behaviors is one thing, but fixing behaviors is another massive challenge. A common approach to fixing issues is to gather more data and retrain the model, but this process can lead to catastrophic forgetting and regressions. There are recent techniques that align well with behavior-driven development, such as slice-based learning, which can selectively fix model behaviors without new data.

Conclusion

There is significant excitement for this new era of AI systems. But along with their growing capability, the complexity of their behavior is also increasing. We need powerful tools to empower behavior-driven development and ensure we build intelligent systems that align with human values. Zeno provides a general-purpose platform that empowers users to do this deep evaluation across the diverse tasks of modern AI. Learn more about Zeno at zenoml.com, read the full paper, or reach out if you would like to use Zeno for your models!

Acknowledgments

I’d like to thank Will Epperson, Jason I. Hong, Yi-Cheng Huang, Misha Khodak, Adam Perer, Venkat Sivaraman, Ameet Talwalkar, and Kristen Vossler for their thoughtful feedback and advice.

RLPrompt: Optimizing Discrete Text Prompts with Reinforcement Learning

Figure 1: Overview of RLPrompt for discrete prompt optimization. All language models (LMs) are frozen. We build our policy network by training a task-specific multi-layer perceptron (MLP) inserted into a frozen pre-trained LM. The figure illustrates generation of a prompt (left), example usages in a masked LM for classification (top right) and a left-to-right LM for generation (bottom right), and the update of the MLP using RL reward signals (red arrows).

TL;DR: Prompting enables large language models (LLMs) to perform various NLP tasks without changing the model. Discrete prompts have many desirable properties, but are difficult to optimize. We propose an efficient approach using reinforcement learning, which shows superior performance and facilitates rich interpretations and analyses. You can easily adapt it for your own tasks using our code base here.

Prompting has emerged as a promising approach to solving a wide range of NLP problems using large pre-trained language models (LMs), including left-to-right models such as GPTs (Radford et al., 2019; Brown et al., 2020) and masked LMs such as BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019), etc.

Compared to conventional fine-tuning that expensively updates the massive LM parameters for each downstream task, prompting concatenates the inputs with an additional piece of text that steers the LM to produce the desired outputs. A key question with prompting is how to find the optimal prompts to improve the LM’s performance on various tasks, often with only a few training examples.

Most existing work resorts to tuning soft prompts (e.g., embeddings), which fall short in interpretability, reusability across LMs, and applicability when gradients are not accessible. Discrete prompts, on the other hand, are difficult to optimize and are often created by “enumeration (e.g., paraphrasing)-then-selection” heuristics that do not explore the prompt space systematically.

Instead, we propose RLPrompt, an efficient discrete prompt optimization approach with reinforcement learning (RL). RLPrompt is flexibly applicable to different types of LMs (e.g., BERT and GPTs) for both classification and generation tasks. Experiments on few-shot classification and unsupervised text style transfer show superior performance over a wide range of existing finetuning or prompting methods. 

Interestingly, the resulting optimized prompts are often ungrammatical gibberish text; and surprisingly, those gibberish prompts are transferable between different LMs to retain significant performance, indicating LMs may have grasped shared structures for prompting, but do not follow human language patterns.

Discrete Prompt Optimization with RL

This paper presents RLPrompt, a new discrete prompt optimization approach based on reinforcement learning (RL). This approach brings together a wide range of desirable properties for efficient use on diverse tasks and LMs (see the table below). 

Table 1: RLPrompt unites the desirable properties of a wide range of previous prompt optimization approaches

Crucially, rather than directly editing the discrete tokens, which has been difficult and inefficient, RLPrompt trains a policy network that generates the desired prompts. Discrete prompt optimization thus amounts to learning a small number of policy parameters which we set as an MLP layer inserted into a frozen compact model such as distilGPT-2 (HuggingFace, 2019). We describe the specific formulations in Section §2.1-2.3 of our paper.
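
As a rough sketch of this parameterization (my illustration using the HuggingFace transformers API, not the released RLPrompt implementation), the trainable part can be as small as a two-layer head on top of a frozen distilGPT-2:

```python
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
lm = AutoModel.from_pretrained("distilgpt2")
for p in lm.parameters():              # the backbone LM stays frozen
    p.requires_grad = False

hidden = lm.config.n_embd
policy_head = nn.Sequential(           # the only trainable parameters
    nn.Linear(hidden, hidden),
    nn.Tanh(),
    nn.Linear(hidden, lm.config.vocab_size),
)

def sample_next_prompt_token(prefix_ids):
    h = lm(input_ids=prefix_ids).last_hidden_state[:, -1]   # frozen LM features
    logits = policy_head(h)                                  # trainable mapping to the vocabulary
    return torch.distributions.Categorical(logits=logits).sample()

prefix = tokenizer("Classify the sentiment:", return_tensors="pt").input_ids
print(sample_next_prompt_token(prefix))
```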

This formulation also allows us to employ off-the-shelf RL algorithms (e.g., Guo et al., 2021) that learn the policy with arbitrary reward functions—defined either with available data (e.g., in few-shot classification) or other weak signals when no supervised data is accessible (e.g., in controllable text generation).

Reward Stabilization 

On the other hand, RL for prompt optimization poses new challenges to learning efficiency: the large black-box LM presents a highly complex environment that, given the prompt (i.e., actions), goes through a long series of complex transitions (e.g., reading the input and inferring the output) before computing the rewards. This makes the reward signals extremely unstable and hard to learn from. 

To overcome this difficulty, we propose two simple yet surprisingly effective ways to stabilize the rewards and improve the optimization efficiency.

  1. Normalizing the training signal by computing the z-score of rewards for the same input.
  2. Designing piecewise reward functions that provide a sparse, qualitative bonus for desirable behaviors (e.g., reaching a certain accuracy on a certain class).

We describe more details in Section §2.4 of our paper.
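
For instance, the z-score normalization in step 1 can be sketched in a few lines (toy code, not the paper’s implementation): rewards of prompts sampled for the same training input are standardized against each other, so the learning signal reflects which prompts are better relative to their peers.

```python
import numpy as np

def z_score_rewards(rewards):
    """Normalize the rewards of candidate prompts sampled for the same input."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + 1e-8)

# Raw task rewards for four candidate prompts on one training input.
print(z_score_rewards([0.20, 0.25, 0.90, 0.30]))
```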

Experiments

We evaluate our approach on both classification (in the few-shot setting) and generation (unsupervised text style transfer), and perform rich analyses for new insights into LM prompting. We describe implementation details such as reward function design in Section §3 of our paper, and publish the code in our GitHub codebase.

Few-Shot Text Classification

For few-shot classification, we follow previous work and experiment on popular sentiment and topic classification tasks, using 16 examples per class for both training and validation (Perez et al., 2021). Results using RoBERTa-large (left table below) show that our approach improves over a wide range of fine-tuning and prompting methods and is as efficient to optimize as similar methods that tune soft prompts (right figure below). We report detailed dataset-level results in Section §3.1 of our paper.

Table 1: Average accuracy for few-shot text classification across all tested datasets. All methods use RoBERTa-large for fine-tuning or prompting.
Figure 2: Comparison of our method (orange) and BlackBox (BB) Tuning (Sun et al., 2022) (blue) in terms of training efficiency. The solid curves are the mean and the shaded regions are the max. and min. test accuracies over 5 trials.

Unsupervised Text Style Transfer

For text style transfer, we evaluate on the popular Yelp sentiment transfer dataset (Shen et al., 2017) using standard automatic metrics for content preservation, style accuracy, and fluency, and report their sentence-level joint product \(J(\cdot)\) below. Our full paper also includes few-shot experiments on the Shakespeare (Xu et al., 2012) dataset and human evaluations.

Results using GPT-2 (left table below) show our method outperforms or competes with various fine-tuning and prompting baselines, including DiRR (Liu et al., 2021c) which expensively fine-tunes all parameters of a GPT-2 model. Ablation study (right figure below) shows that our proposed reward normalization technique is crucial to optimization success. We describe the full evaluation results in Section §3.2 of our paper. 

Table 2: Automatic evaluation of our method vs. baselines on the Yelp (Shen et al., 2017) sentiment transfer dataset. \(J(\cdot)\) is our main metric, which measures the average joint sentence-level scores of content preservation, style accuracy, and fluency. Numbers in parentheses are standard deviations across 3 sets of prompts.
Figure 3: Comparison of our method with (orange) and without (purple) z-score reward normalization. The format is the same as Figure 2.

Analysis

Optimal Prompts Don’t Follow Human Language

The resulting discrete prompts also facilitate rich interpretations and analyses for new insights into LM prompting. In particular, the optimized prompts, though inducing strong task performance, tend to be gibberish text without clear human-understandable meaning (e.g., table below), echoing recent research (Webson and Pavlick, 2021; Zhao et al., 2021; Prasad et al., 2022) that LMs making use of prompts do not necessarily follow human language patterns. 

Table 3: Comparison of our method (RLPrompt) with manually-written (Manual) prompts for text style transfer performance on Yelp (Shen et al., 2017). For the manual prompts, we take one from Reif et al. (2021) and write two more for this experiment. \(J(\cdot)\) is the main metric introduced in Table 2. All outputs are generated using GPT-2-xl and metrics are averaged over 5 runs.

Learned Prompts Transfer Trivially Across LMs

Perhaps surprisingly, those gibberish prompts learned with one LM can be used in other LMs for significant performance, indicating that those different pre-trained LMs have grasped shared structures for prompting (e.g., figures below).

Figure 4: Heatmap of sentiment analysis performance with transferred discrete prompts of 2 tokens. The columns represent the models used to learn the prompts, and the rows represent the models we perform classification with. Brighter color represents higher accuracy.
Figure 5: Heatmap of text style transfer performance with transferred discrete prompts. The columns represent the models used to learn the prompts, and the rows represent the models we perform text generation with. Manual and Random refer to manual prompts and random tokens, respectively. Brighter color represents a better joint score \(J(\cdot)\).

Conclusion

We have presented RLPrompt, an efficient and flexible approach for discrete prompt optimization using RL, which improves over a wide range of fine-tuning and prompting methods in experiments on few-shot classification and unsupervised text style transfer.

Analysis reveals that strong optimized prompts are incoherent but transferable between LMs for remarkable performance. The observation opens up many promising possibilities for prompting, such as learning prompts cheaply from smaller models and performing inference with larger models. We are excited to explore further.

Bottom-up Top-Down Detection Transformers For Open Vocabulary Object Detection

We perform open vocabulary detection of the objects mentioned in the sentence using both bottom-up and top-down feedback.

Object detection is the fundamental computer vision task of finding all “objects” that are present in a visual scene. However, this raises the question, what is an object? Typically, this question is side-stepped by defining a vocabulary of categories and then training a model to detect instances of this vocabulary. This means that if “apple” is not in this vocabulary, the model does not consider it as an object. The problem gets even worse when we try to integrate these object detectors into real household agents. Imagine that we want a robot that can pick up “your favorite green mug from the table right in front of you”. We want the robot to specifically detect the “green mug” which is on the “table in front of you” and not any other mug or table. Obviously, treating descriptions such as “green mug from the table right in front of you” as separate classes in the detector’s vocabulary cannot scale; one can come up with countless variations of such descriptions.

In light of this, we introduce Bottom-up Top-Down DEtection TRansformer (BUTD-DETR pron. Beauty-DETER), a model that conditions directly on a language utterance and detects all objects that the utterance mentions. When the utterance is a list of object categories, BUTD-DETR operates as a standard object detector. It is trained from both fixed vocabulary object detection datasets and referential grounding datasets which provide image-language pairs annotated with the bounding boxes for all objects referred to in the language utterance. With minimal changes, BUTD-DETR grounds language phrases both in 3D point clouds and 2D images.

BUTD-DETR conditions on language and can detect objects that SOTA object detectors frequently miss.

No box bottleneck: BUTD-DETR decodes object boxes directly by attending to language and visual input instead of selecting them from a pool. Language-directed attention helps us localize objects that our bottom-up, task-agnostic attention may miss. For example, in the above image, the hint of “clock on top of the shelf” suffices to guide our attention to the right place, though the clock is not a salient object in the scene. Previous approaches for language grounding are detection-bottlenecked: they select the referred object from a pool of box proposals obtained from a pre-trained object detector. This means that if the object detector fails, then the grounding model will fail as well.

How does it work?

BUTD-DETR architecture: conditioning on the visual, language, and object detection streams, our model decodes boxes and spans for all mentioned objects.

The input to our model is a scene and a language utterance. A pre-trained object detector is used to extract box proposals. Next, the scene, boxes, and utterance are encoded using per-modality-specific encoders into visual, box, and language tokens respectively. These tokens are contextualized by attending to one another. The refined visual tokens are used to initialize object queries that attend to the different streams and decode boxes and spans.

Augmenting supervision with detection prompts

Object Detection as Referential Language Grounding using detection prompts: We can generate additional grounding annotations/examples by chaining multiple object category tokens.

Object detection is an instance of referential language grounding in which the utterance is simply the object category label. We cast object detection as the referential grounding of detection prompts: we randomly sample some object categories from the detector’s vocabulary and generate synthetic utterances by sequencing them, e.g., “Couch. Person. Chair.”, as shown in the figure above. We use these detection prompts as additional supervision data: the task is to localize all object instances of the category labels mentioned in the prompt if they appear in the scene. For the category labels with no instances present in the visual input (e.g. “person” in the above figure), the model is trained to not match them to any boxes. In this way, a single model can perform both language grounding and object detection simultaneously and share the supervision information.
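
A minimal sketch of how such detection prompts could be constructed (illustrative code only; the actual sampling and box-matching logic lives in the training pipeline):

```python
import random

def make_detection_prompt(vocabulary, num_labels=3, rng=random):
    """Chain randomly sampled category labels into a synthetic utterance."""
    labels = rng.sample(vocabulary, num_labels)
    return " ".join(label.capitalize() + "." for label in labels)

vocab = ["couch", "person", "chair", "table", "clock", "shelf"]
print(make_detection_prompt(vocab))   # e.g. "Couch. Person. Chair."
```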

Results

BUTD-DETR achieves a large boost in performance over state-of-the-art approaches across all 3D language grounding benchmarks (SR3D, NR3D, ScanRefer). Moreover, it was the winning entry in the ReferIt3D challenge, held at the ECCV workshop on Language for 3D Scenes. On 2D language grounding benchmarks, BUTD-DETR performs on par with state-of-the-art methods when trained on large-scale data. Importantly, our model converges twice as fast as the state-of-the-art MDETR, mainly because of the efficient deformable attention we use in our 2D model.

Quantitative Results across 3D Benchmarks: Our model significantly outperforms all prior methods across all established 3D benchmarks.

We show the qualitative results of our model in the video at the beginning of the blog. For more visualizations, please refer to our project page and paper.

What’s next?

Our method detects all objects mentioned in the sentence — however, this assumes that the user needs to mention all relevant objects in the sentence. This is not desirable in general — for example, in response to “make breakfast” we would like our model to detect all the relevant ingredients like bread, eggs etc., even if they are not mentioned in the sentence. Additionally, while our architecture works for both 2D and 3D language grounding with minimal changes, we do not share parameters between the two modalities. This prevents transferring representations across modalities, which would be particularly helpful for the low-resource 3D modality. Our ongoing work is investigating these two directions.

We have released our code and model weights on GitHub, making it easy to reproduce our results and build upon our method. If you are interested in a language-conditioned open vocabulary detector for your project, then give BUTD-DETR a run! For more details, please check out our project page and paper.

Causal Confounds in Sequential Decision Making

A standard assumption in sequential decision making is that we observe everything required to make good decisions. In practice however, this isn’t always the case. We discuss two specific examples (temporally correlated noise (a) and unobserved contexts (c)) that have stymied the use of IL/RL algorithms (in autonomous helicopters (b) and self-driving (d)). We derive provably correct algorithms for both of these problems that scale to continuous control problems.

Reinforcement Learning (RL) and Imitation Learning (IL) methods have achieved impressive results in recent years like beating the world champion at Go or controlling stratospheric balloons. Usually, these results are on problems where we either a) observe the full state or b) are able to faithfully execute our intended actions on the system. However, we frequently have to contend with situations where this isn’t the case: our self-driving car might miss a person’s hand gestures or persistent wind might make it difficult to fly our quadcopter perfectly straight. These sorts of situations can cause standard IL approaches to perform poorly ([1], [2]). In causal inference, we call a random variable that we don’t observe that influences a relationship we’d like to model a confounder. Using techniques from causal inference, we derive provably correct and scalable algorithms for sequential decision making in these sorts of confounded settings.

We’re going to be focused mostly on imitation learning (see our last blog post for more details). In IL, we observe trajectories (sequences of states and actions) generated by some expert policy \(\pi_E\) and want to recover the policy that generated them. The expert could be a) an experienced driver if we’re trying to build a self-driving car, b) a user interacting with a recommender system we’re attempting to model, or c) scraped internet text used to train a large language model. We’re now going to discuss two issues of causation that are difficult for standard IL algorithms to handle. I’ll be drawing upon material from our ICML ’22 paper Causal Imitation under Temporally Correlated Noise and our NeurIPS ’22 paper Sequence Model Imitation Learning with Unobserved Contexts.

Issue 1: Temporally Correlated Noise

The first problem we’ll consider is when there’s temporally correlated noise (TCN) affecting pairs of expert actions. For example, if our expert was a human quadcopter pilot, the TCN could be persistent wind [1]. Let’s assume that the expert intended to fly completely straight but was unable to do so because of the persistent wind, producing swerving trajectories. If the learner ignores the fact that the wind was the root cause of these swerves, they might attempt to reproduce the swerving behavior of the expert, causing them to deviate even further at test time due to the continued influence of the wind.

In effect, adding correlated noise to both \(X = s_t\) and \(Y = a_t\) changes the apparent slope of the line we’re trying to fit, making standard regression-based approaches no longer consistent.

What’s happened is that the TCN has created a spurious correlation (e.g. being a little on the left is often followed by going more to the left) which the learner has latched onto. What we’d like is to instead learn a policy that can fly as straight as the expert did in the demonstrations.

Let’s try to formalize what’s happening in the above example via graphical models. In a Markov Decision Process (MDP), the learner \(\pi\) responds to the observed state \(s_t\) by taking an action \(a_t\), and the MDP transitions via the dynamics \(\mathcal{T}: \mathcal{S} \times \mathcal{A} \rightarrow \mathcal{S}\) to the next state \(s_{t+1}\).

Adding in the TCN corresponds to adding in a common cause of a pair of actions that’s not part of the observed state (\(u_t\) below).

Now, let’s dig into what goes wrong with imitation learning under TCN. Let’s say we apply a standard algorithm like behavioral cloning: direct regression from states (\(X = s_t\)) to actions (\(Y = a_t\)). The TCN travels through the dynamics to influence both the inputs and outputs of our regression procedure, creating spurious correlations that our learner might latch onto, leading to poor test-time performance. We visualize this effect in red below.

In effect, the TCN makes some elements of state \(s_2\) look more correlated with action \(a_2\) than they would otherwise. Attempting to minimize training error, the learner will use this correlation to their advantage but will end up swerving unnecessarily at test time as a result.

Filtering out TCN with Instrumental Variable Regression

To learn well under TCN, we’re going to utilize the idea of a natural experiment from causal inference, in which we’re able to use variation in the observational data to simulate the effect of an intervention / randomized control trial (RCT). Interestingly enough, this idea was one of the central contributions of the winners of the Nobel Prize in Economics in 2021. Let’s first take a look at this idea in the single shot setting. Assume that we’re trying to predict from \(X\) to \(Y\), both of which are affected by an unobserved confounder \(U\):

Notice that the random variable \(Z\) both (1) affects \(X\) and (2) is independent of \(U\). \(Z\) is therefore an independent source of variation in \(X\). We call such variables instruments. Under some assumptions (see our paper for details), one can condition on an instrument to de-noise the inputs to our regression procedure and recover a causal relationship. So, instead of regressing from $$X \rightarrow Y,$$ one regresses from $$X \mid Z \rightarrow Y \mid Z.$$

We’ll skip the derivation here but intuitively, conditioning on an independent source of variation, \(Z\), “washes out” the effect of the confounder, \(U\). The variation in \(Z\) therefore acts as a sort of “natural experiment,” simulating an RCT without requiring a true intervention.

Now, you might well be wondering: what is a valid instrument in the sequential setting? Well, given that the past is independent of future confounding, we can use past states as an instrument! Graphically,

Many econometrics papers will spend quite a lot of text arguing that something is a valid instrument (i.e., satisfies (1) and (2)), but here the arrow of time gives us an instrument for free. Algorithmically, instrumental variable regression for imitation learning under TCN corresponds to solving $$s_t \mid s_{t-1} \rightarrow a_t \mid s_{t-1}.$$

So, one first fits models of the left and right-hand sides of the above expression and then regresses between them. We give two algorithms for doing so efficiently in our paper. On simulated control tasks, we observe that our approaches (in teal and orange) are able to recover a policy that matches expert performance, while regression-based behavioral cloning struggles to do so.
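
Here is a minimal two-stage least squares sketch of that recipe on a toy linear system, with the past state as the instrument (synthetic data and plain linear regression; the estimators in the paper are more general):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
T = 5000
s_prev = rng.normal(size=(T, 1))                         # instrument Z: past state s_{t-1}
u = rng.normal(size=(T, 1))                              # unobserved temporally correlated noise
s_t = 0.8 * s_prev + u + 0.1 * rng.normal(size=(T, 1))   # current state, corrupted by u
a_t = 1.5 * s_t + u                                      # expert action, also corrupted by u

# Stage 1: condition on the instrument by predicting s_t from s_{t-1}.
s_t_hat = LinearRegression().fit(s_prev, s_t).predict(s_prev)

# Stage 2: regress actions onto the de-noised (predicted) states.
iv_fit = LinearRegression().fit(s_t_hat, a_t)
naive_fit = LinearRegression().fit(s_t, a_t)

print("IV estimate of the state-to-action slope:", iv_fit.coef_.ravel())   # close to 1.5
print("Naive behavioral cloning slope:", naive_fit.coef_.ravel())          # biased upward
```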

So, putting it all together, temporally correlated noise between expert actions can create spurious correlations between recorded states and actions, causing standard regression-based imitation learning approaches to produce inaccurate predictors. One can use the “natural experiment” caused by the variation in past states to de-noise the inputs to their regression procedure and recover causally correct predictors that do not suffer from spurious correlations.

Issue 2: Unobserved Contexts

We’re now going to consider a different kind of confounder. Let’s assume we’re trying to imitate an expert who observes some context, \(c\), which the learner does not. For example, if we were trying to imitate an expert driver but our sensors don’t pick up a stop sign on the side of the road, \(c\) could refer to this side information. Now, this problem is in general impossible to solve: if we don’t see the sign, how would we know to stop? However, if we pay attention to our surroundings and accumulate information over time, we might be able to eventually filter out what we missed: if we observe all the other cars around us slowing down for a few timesteps, we might rationally conclude that we should as well, even though we never saw the stop sign.

More formally, we’re in a Partially Observed Markov Decision Process (a POMDP), which we can visualize as follows:

So, the effect of the stop sign (\(c\)) is reflected in the states (\(s_1, s_2\)). There are two key differences between this graphical model and the TCN graphical model. First, the confounder directly affects the state. Second, the confounder is constant, rather than time-varying. These two characteristics make it plausible that, given enough state observations, we can figure out what \(c\) is and act as the expert does.

To enable us to accumulate evidence over time, we’re going to allow our policy to condition its actions on the entire history up to this point. Denoting \(h_t = (s_1, a_1, \dots, s_t)\), our policy now takes the form $$\pi: h_t \rightarrow a_t.$$ This means our policy is now a sequence model (e.g., an RNN or a transformer).
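
Concretely, such a history-conditioned policy can be instantiated as a small recurrent network (a sketch with made-up dimensions; the experiments in the paper may use different architectures):

```python
import torch
import torch.nn as nn

class HistoryPolicy(nn.Module):
    """pi(a_t | h_t): encode the (state, previous action) history, then predict the next action."""
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.encoder = nn.GRU(state_dim + action_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, action_dim)

    def forward(self, states, prev_actions):
        # states: (batch, t, state_dim); prev_actions: (batch, t, action_dim)
        history = torch.cat([states, prev_actions], dim=-1)
        encoded, _ = self.encoder(history)
        return self.head(encoded[:, -1])          # action for the current timestep

policy = HistoryPolicy(state_dim=4, action_dim=2)
a_t = policy(torch.randn(8, 10, 4), torch.randn(8, 10, 2))
print(a_t.shape)   # torch.Size([8, 2])
```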

OK, so we use sequence models as our policy class, apply behavioral cloning (prediction of the next token from the prefix), and call it a day, right? Unfortunately, this often produces policies that naively repeat their own past actions. We call this the latching effect. For example, in the self-driving domain, this can lead to learning a driving policy that begins to turn and then just keeps turning, driving in circles [2]. This is clearly less than ideal.

Modeling Unobserved Contexts with Sequence Models + On-Policy Training

Let’s dig into why the latching effect stymies off-policy imitation learners. The model we learn during training looks like $$\pi(a_t \mid h_t) \approx p(a_t^E \mid s_1^E, a_1^E, \dots, s_t^E),$$ where the superscript \(E\) denotes that these are expert states and actions. Now, at test time, the past actions we pass into our sequence model policy are our own rather than the expert’s. This means we’re trying to sample actions from $$p(a_t^E \mid s_1, a_1, \dots, s_t).$$ Unless we know that the elements to the right of the conditioning bar are exactly the same (i.e., we’ve perfectly matched the expert policy), we have no reason to believe that the model we’ve learned will generalize to our own induced state distribution. Put differently, we have no reason to believe we’ve learned a policy that will perform well at test time.

In machine learning terms, we’re facing the problem of policy-induced covariate shift (PICS). Here, our covariates are the history of states and actions the policy uses to predict the next action. PICS is a well known problem in imitation learning and sequential prediction writ large (see last blog post for more details). What’s special about the unobserved context setting is the degree to which it causes problems. Early on in an episode, no learner can do well as they haven’t accumulated enough information to perform well. This means that we have an unavoidable source of covariate shift from preliminary mistakes.

Off-policy training (i.e. training only on the green expert states) might not ensure we generalize well to our own induced state distribution (the orange dashed line). The effect of unobserved contexts is to ensure we don’t do well at first because we don’t have enough information to act properly, making it quite likely that we go off the rails.

Now, let’s assume that the expert’s actions were relatively similar between adjacent timesteps (which is frequently true in practice). This means rather than learning a complex mapping from states to actions, our learner might have instead learned a policy that mostly copies the last action. This was fine when the previous actions were the expert’s. But when they are the learner’s extremely suboptimal initial actions, naive copying is a recipe for disaster. This is what leads to us learning a driving policy that is only capable of turning in circles. A different way of phrasing this point is because the learner was trained only on trajectories from the expert, it treats its own test-time past actions as though they were produced by the expert, leading to unfortunate downstream consequences.

On-policy training (i.e., performing rollouts in the environment and comparing them to expert rollouts) doesn’t suffer from these issues. This is because we’re instead trying to satisfy $$p(a_t^E \mid s_1^E, a_1^E, \dots, s_t^E) \approx p(a_t \mid s_1, a_1, \dots, s_t).$$ In words, rather than matching expert actions on expert histories, we’re instead trying to match expert and learner trajectory distributions. We’re making sure we predict well based on our own histories and therefore can’t be surprised at test-time by states we didn’t expect to end up in. We prove in our paper that under certain identifiability conditions, on-policy training of sequence model policies is both necessary and sufficient for the learner to generalize well to their own induced state distribution and successfully imitate the expert.

On-policy training involves rolling out the learner’s current policy (in orange) and minimizing some notion of divergence (in red) to expert trajectories (in green). Because we observe our own induced state distribution, we can’t fool ourselves into thinking our own past actions are like those of the expert.

We also conduct experiments on continuous control tasks with unobserved contexts and show that equipping an off-policy learner (in grey) with access to history can actually hurt its performance, while we see no such effect for an on-policy learner (in orange).

Discussion

Confounding commonly occurs in real world situations, including those in which we want to make a sequence of decisions. When we don’t filter out or model the effects of what we don’t observe, we can end up making extremely suboptimal decisions. We describe two situations in which there’s clean algorithms for handling confounding in sequential decision making (temporally correlated noise can be filtered out via IVR, unobserved contexts can be modeled via on-policy training of a sequence model). Moving forward, we’re interested in other causal structures (e.g. negative controls or proxy shifts) and applying our techniques to problems where confounders abound (e.g. recommender systems).

If you’re interested in learning more, I recommend you look at the full papers and the accompanying code.

DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.

How do Authors’ Perceptions about their Papers Compare with Co-authors’ Perceptions and Peer-review Decisions?

NeurIPS 2021 Author Perception Experiment

Alina Beygelzimer, Yann N. Dauphin, Percy Liang, Jennifer Wortman Vaughan (NeurIPS 2021 Program Chairs)

Charvi Rastogi, Ivan Stelmakh, Zhenyu Xue, Hal Daumé III, Emma Pierson, and Nihar B. Shah

There is a considerable body of research on peer review. Within the machine learning community, there have been experiments establishing significant disagreement across reviewers and across reviewer panels—including at NeurIPS 2021—and active discussions about the state of peer review. But how do author perceptions about their submitted papers match up to the outcomes of the peer-review process and perceptions of other authors? We investigate this question by asking authors who submitted papers to NeurIPS 2021 three questions:

(Q1) [At the time of paper submission] What is your best estimate of the probability (as a percentage) that this submission will be accepted?

(Q2) [At the time of paper submission; to authors submitting two or more papers] Rank your submissions in terms of your own perception of their scientific contributions to the NeurIPS community, if published in their current form.

(Q3) [After preliminary reviews were available to authors] After you read the reviews of this paper, how did your perception of the value of its scientific contribution to the NeurIPS community change (assuming it was published in its initially submitted form)?  

Here are five key findings.

1. How well do authors estimate the probability of acceptance of their papers?

Authors significantly overestimate their papers’ chances of acceptance. When answering Q1, authors were informed that the acceptance rate at NeurIPS over the last 4 years had been about 21%. The acceptance rate at NeurIPS 2021 turned out to be 25.8%. The authors’ responses overestimated this nearly three-fold, with a median prediction of 70%.

2. Are some sub-groups better calibrated than others?

We examined calibration error across sub-groups, measuring this error in terms of the Brier score (squared loss) and controlling for other confounders. We find that the calibration error of female authors is slightly (but statistically significantly) higher than that of male authors. We also see a trend of miscalibration decreasing with seniority, with authors who were invited to serve as (meta-)reviewers better calibrated than the rest. All sub-groups we examined over-predicted their papers’ chances of acceptance.
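
For reference, the Brier score of a set of predictions is simply the mean squared error between the predicted acceptance probabilities and the 0/1 outcomes; the numbers below are made up purely to illustrate the calculation.

```python
import numpy as np

def brier_score(predicted_probs, accepted):
    """Mean squared error between predicted probabilities (in [0, 1]) and 0/1 outcomes."""
    return float(np.mean((np.asarray(predicted_probs) - np.asarray(accepted)) ** 2))

# An author who predicts 70% for four papers of which one is accepted is worse
# calibrated (higher Brier score) than one who predicts 30% for the same outcomes.
print(brier_score([0.7, 0.7, 0.7, 0.7], [1, 0, 0, 0]))   # 0.39
print(brier_score([0.3, 0.3, 0.3, 0.3], [1, 0, 0, 0]))   # 0.19
```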

 

3. Among authors with multiple papers, how much do their predictions of acceptance probabilities agree with their own perceived scientific merit?

These two sets of responses are largely in agreement: The strict ranking provided by authors about their perceived scientific merit (Q2) and the strict ranking induced by their predicted acceptance probabilities (Q1) agree for 93% of responses. However, there is a noticeable 7% of responses where the authors think that the peer review is more likely to reject the better of their two papers.

4. How much do co-authors agree on the relative quality of their joint papers?

Strikingly, the amount of disagreement between co-authors in terms of the perceived relative scientific contribution of their papers (Q2) is similar to the amount of disagreement between authors and reviewers! In cases where one paper from an author was ultimately accepted and another rejected, authors rated the rejected paper higher about a third of the time. But looking at pairs of papers with overlapping authors in which both authors provided rankings, the co-authors also disagreed with each other about a third of the time. While there are discussions in the literature about inter-reviewer disagreements, this result suggests that there is similar disagreement in co-authors’ views of their papers as well.

5. Does peer review change authors’ perception of their own papers?

The question Q3 was a multiple-choice question with five choices: much more positive (“++”), slightly more positive (“+”), did not change (“0”), slightly more negative (“-”), much more negative (“- -”).

We find that among both accepted and rejected papers, about 50% of authors report that their perception about their own paper changed after seeing the initial reviews (Q3). Moreover, among both accepted and rejected papers, over 30% of authors report that their perception became more positive.

Figure: distributions of responses to Q3, shown separately for accepted papers and rejected papers.

Discussion

The fact that authors vastly overestimated the probability that their papers will be accepted suggests it would be useful for conference organizers and research mentors to attempt to recalibrate expectations prior to each conference. The disagreements we document around paper quality — between co-authors as well as between authors and reviewers — taken together with the disagreement among committees of reviewers observed in the complementary NeurIPS 2021 consistency experiment, suggest that assessing paper quality is not only an extremely noisy process, but may be a fundamentally challenging task with no objective right answer. The outcomes of paper submissions should thus be taken with a grain of salt. More broadly, as a community, we may take these findings into account when deciding on our policies and perceptions pertaining to the peer-review process and its outcomes. We hope the results of our experiment encourage discussion and introspection in the community.

More details: Available here

Tackling Diverse Tasks with Neural Architecture Search

DASH searches for the optimal kernel size and dilation rate efficiently from a large set of options for each convolutional layer in a CNN backbone. The resulting model can achieve task-specific feature extraction and work as well as hand-designed expert architectures, making DASH an effective tool for tackling diverse tasks beyond well-researched domains like vision.

The past decade has witnessed the success of machine learning (ML) in solving diverse real-world problems, from facial recognition and machine translation to disease diagnosis and protein sequence prediction. However, progress in such areas has involved painstaking manual effort in designing and training task-specific neural networks, leveraging human and computational resources that most practitioners do not have access to.

In contrast to this task-specific approach, general-purpose models such as DeepMind’s Perceiver IO and Gato and Google’s Pathways have been developed to solve more than one task at once. However, as these proprietary pretrained models are not publicly available, practitioners cannot even assess whether fine-tuning one of these models would work on their task of interest. Independently developing a general-purpose model from scratch is also infeasible due to the massive amount of compute and training data it requires.

A more accessible alternative is the field of automated machine learning (AutoML), which aims to obtain high-quality models for diverse tasks with minimal human effort and computational resources, as noted in a recent blogpost. In particular, we can use Neural Architecture Search (NAS) to automate the design of neural networks for different learning problems. Indeed, compared with training large-scale transformer-based general-purpose models, many efficient NAS algorithms such as DARTS can be run on a single GPU and take a few hours to complete a simple task. However, while NAS has enabled fast and effective model development in well-studied areas such as computer vision, its application to domains beyond vision remains largely unexplored. In fact, a major difficulty in applying NAS to more diverse problems is the trade-off between considering a sufficiently expressive set of neural networks and being able to efficiently search over this set. In this blog post, we will introduce our approach to find a suitable balance between expressivity and efficiency in NAS.

In our upcoming NeurIPS 2022 paper, we developed a NAS method called DASH that generates and trains task-specific convolutional neural networks (CNNs) with high prediction accuracy. Our core hypothesis is that for a broad set of problems (especially those with non-vision inputs such as audio and protein sequences), simply searching for the right kernel sizes and dilation rates for the convolutional layers in a CNN can achieve high-quality feature extraction and yield models competitive to expert-designed ones. We explicitly focus on extending the generalization ability of CNNs due to the well known effectiveness of convolutions as feature extractors, coupled with recent work demonstrating the success of modern CNNs on a variety of tasks (e.g., the state-of-the-art performance of the ConvNeXt model that incorporates many techniques used by Transformers).

While a search space of diverse kernels is easy to define, searching it efficiently is challenging because we want to consider many kernels with different kernel sizes and dilation rates, which results in a combinatorial explosion of possible architectures. To address this issue, we introduce three techniques exploiting the mathematical properties of convolution and fast matrix multiplication on GPUs. We evaluate DASH on 10 different tasks spanning multiple domains (vision, audio, electrocardiogram, music, protein, genomics, cosmic-ray, and mathematics), input dimensions (1D and 2D), and prediction types (point and dense). While searching up to 10x faster than existing NAS techniques, DASH achieves the lowest error rates among all NAS baselines on 7/10 tasks and all hand-crafted expert models on 7/10 tasks.

In the following, we will first discuss how DASH is inspired by and differs from existing NAS work. Then, we will introduce three novel “tricks” that improve the efficiency of searching over a diverse kernel space. Finally, we will present the empirical evaluation to demonstrate DASH’s effectiveness.

The Expressivity-Efficiency Trade-Off in NAS

Most NAS methods have two components for generating task-specific models: a search space that defines all candidate networks and a search algorithm that explores the search space until a final model is found. Effective models for arbitrary new tasks can be developed if and only if the search space is sufficiently expressive, but this also means we need more time to explore the set of possible architectures in the space.

This tension between search space expressivity and search algorithm efficiency has been prominent in NAS research. On one hand, vision-centric approaches like DARTS are designed to explore multiple architectures quickly, but the search spaces are limited to models with (inverted) residual blocks, and are thus highly tailored to vision tasks. On the other hand, new approaches like AutoML-Zero and XD aim to solve arbitrary tasks by considering highly expressive search spaces, but the associated search algorithms are often practically intractable. For instance, XD tries to substitute layers in existing networks with matrix transformations. The matrix search space is continuous and expansive, but optimizing it is extremely time-consuming even for simple benchmarking tasks like CIFAR-100, rendering XD impractical for diverse tasks with more data points or larger input dimensions.

To bridge this gap, we present DASH, which fixes a CNN as the backbone and searches for the optimal kernel configurations. The intuition is that modern convolutional models like ConvNeXt and Conv-Mixer are powerful enough to compete with attention-based architectures, and varying kernel sizes and dilations can further strengthen the feature extraction process for different problems. For instance, small filters are generally used for visual tasks to detect low-level features such as edges and corners, whereas large kernels are typically more effective for sequence tasks to model long-range dependencies. Unlike conventional cell-based NAS which searches for a block of operations and stacks several copies of the same block together, DASH is more flexible as it decouples layer operations from the network structure: since the searched operators can vary from the beginning to the end of a network, features at different granularities can be processed differently. 

“Kernel Tricks” for Improving NAS Efficiency

As mentioned above, we seek a sufficiently expressive (i.e., large) kernel search space to ensure that there exist kernels that can effectively extract features for a diverse set of tasks. To achieve this, we replace each convolutional layer in the backbone network with the following aggregated convolution operator:

$$S_{\mathbf{AggConv}_{K, D}} = \{\mathbf{Conv}_{k,d} \mid k \in K, d \in D\}, \tag{1}\label{1}$$

where \(K\) and \(D\) denote the set of kernel sizes and dilations that we consider, respectively. A naive approach to searching over this kernel space is to use the continuous relaxation scheme of DARTS to compute the output (we call this approach mixed-results):

$$\mathbf{AggConv}_{K,D}(\mathbf{x}) := \sum_{k\in K}\sum_{d\in D} \alpha_{k,d}\cdot \mathbf{Conv}(\mathbf{w}_{k,d})(\mathbf{x}), \tag{2}\label{2}$$

where \(\mathbf{w}_{k,d}\) are the kernel weights and \(\alpha_{k, d}\) are the architecture parameters. However, DARTS only considers a few kernels with small sizes and dilations (with \(k_{\max} = 5\), \(d_{\max} = 2\), and thus small \(|K||D|\)), whereas we aim to search over many and large kernels (e.g., with \(k_{\max} = 11\), \(d_{\max} = 127\), and large \(|K||D|\)). This increased expressivity leads to drastically higher search costs for naive search, whose runtime complexity is \(O(n|K||D|)\), where \(n\) is the input size.

In the following, we will describe three techniques (mixed-weights, Fourier convolution, and Kronecker dilation) that DASH collectively employs to enable efficient search. Complexity-wise, DASH’s efficient search replaces the \(O(n|K||D|)\) complexity of naive search with an \(O(n\log n)\) complexity, where \(\log n\) is small for any realistic \(n\), including long sequence inputs. Empirically, this latter complexity translates to significantly improved search speed, e.g., DASH searches about 10 times faster than DARTS in the large \(|K||D|\) regime (Figure 1).

Figure 1. Combined forward- and backward-pass time for one search epoch vs. the search space size for 1D input with \(n = 1000\). We vary the search space by letting \(K = \{2p+1 \mid 1\leq p \leq c\}\), \(D = \{2^q-1 \mid 1\leq q\leq c\}\) and increasing \(c\) from 1 to 7. As the aggregated kernel size \(\bar{D}\) increases, the DASH curves grow much more slowly than those of the other methods.

Technique 1: Mixed-Weights. We first observe that all computations in Equation 2 are linear, so the distributive law applies. Hence, instead of computing \(|K||D|\) separate convolutions, we can combine the kernels and compute the convolution once:

$$\mathbf{AggConv}_{K, D}(\mathbf{x}) = \mathbf{Conv}\left(\sum_{k\in K}\sum_{d\in D} \alpha_{k, d}\cdot\mathbf{w}_{k,d}\right)(\mathbf{x}). \tag{3}\label{3}$$

Let’s call this approach mixed-weights. Mixed-weights allows the search complexity to depend on the aggregated kernel size \(\bar{D} := \max_{k\in K,\, d\in D}\, (k-1)d + 1\) rather than on \(|K||D|\), but the former still scales with the size of the search space (a minimal sketch of the mixed-weights computation is given below). Can we do better than this?
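Here is a minimal sketch of the same computation done the mixed-weights way (again our own toy code, assuming odd kernel sizes so every candidate kernel can be centered): each kernel is dilated, zero-padded to the aggregated size \(\bar{D}\), and summed with its architecture weight, after which a single convolution is applied.

```python
import torch
import torch.nn.functional as F

def mixed_weights(x, weights, alphas, K, D):
    """Mixed-weights relaxation (Equation 3): combine all (k, d) kernels into a
    single kernel of the aggregated size, then convolve once."""
    D_bar = max((k - 1) * d + 1 for k in K for d in D)   # aggregated kernel size
    combined = 0
    for k in K:
        for d in D:
            w = weights[(k, d)]                          # (C_out, C_in, k)
            dilated = torch.zeros(w.shape[0], w.shape[1], (k - 1) * d + 1)
            dilated[:, :, ::d] = w                       # insert d-1 zeros between taps
            left = (D_bar - dilated.shape[-1]) // 2      # center the smaller kernels
            right = D_bar - dilated.shape[-1] - left
            combined = combined + alphas[(k, d)] * F.pad(dilated, (left, right))
    return F.conv1d(x, combined, padding=(D_bar - 1) // 2)
```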

Technique 2: Fourier Convolution. Imagine that \(\mathbf{x}\) is an image input. Then, Equation 3 operates on the pixel values in the spatial domain. Alternatively, one can work with the rate of change of the pixel values in the frequency domain to compute the convolution output more efficiently, taking advantage of the celebrated convolution theorem:

$$\mathbf{AggConv}_{K,D}(\mathbf{x}) = \mathbf{F}^{-1}\,\mathrm{diag}\left(\mathbf{F}\left(\sum_{k\in K}\sum_{d\in D} \alpha_{k,d}\cdot \mathbf{w}_{k,d}\right)\right)\mathbf{F}\mathbf{x}, \tag{4}\label{4}$$

where \(\mathbf{F}\) represents the discrete Fourier transform. Equation 4 allows us to remove the dependence on the combined kernel size \(\bar{D}\), as \(\mathbf{F}\) can be applied in time \(O(n \log n)\) using the Fast Fourier Transform.
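To see what the Fourier route looks like, here is a stripped-down sketch for a single channel in 1D (our own toy function; it computes a circular convolution and ignores the padding and multi-channel details a real implementation must handle):

```python
import torch

def fft_conv1d(x, kernel):
    """Convolution via the convolution theorem: F^{-1}(diag(F(w)) F(x)).

    x:      (..., n) signal
    kernel: (m,) combined kernel with m <= n
    Runs in O(n log n) regardless of the kernel size.
    """
    n = x.shape[-1]
    w = torch.zeros(n)
    w[: kernel.shape[-1]] = kernel        # zero-pad the kernel to length n
    X = torch.fft.rfft(x, n=n)
    W = torch.fft.rfft(w, n=n)
    return torch.fft.irfft(X * W, n=n)    # circular (wrap-around) convolution
```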

Technique 3: Kronecker Dilation. Lastly, we focus on accelerating search in a subset of the search space where the kernel size \(k\) is fixed and the dilation rate \(d\) varies. To dilate a kernel before applying it to the input, we need to insert \(d-1\) zeros between adjacent elements of the weight matrix. An efficient implementation on GPUs exploits the Kronecker product \(\otimes\). For example, in 2D, we can introduce a sparse pattern matrix \(P \in \mathbb{R}^{d\times d}\) whose entries are all \(0\)s except for the upper-left entry \(P_{1,1} = 1\). Then, \(\mathbf{w}_{k,d} = \mathbf{w}_{k,1} \otimes P\). After dilating the kernels, we can proceed by following Equation 4.
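The pattern-matrix construction maps directly to a few lines of code. A small sketch using torch.kron (trimming the trailing zero rows and columns is our own addition, to recover the standard dilated kernel size):

```python
import torch

def dilate_kernel_2d(w, d):
    """Dilate a (k, k) kernel by inserting d-1 zeros between adjacent entries,
    via a Kronecker product with a sparse d x d pattern matrix P."""
    P = torch.zeros(d, d)
    P[0, 0] = 1.0                          # only the upper-left entry is 1
    dilated = torch.kron(w, P)             # (k*d, k*d), with trailing zeros
    size = (w.shape[0] - 1) * d + 1        # standard dilated kernel size
    return dilated[:size, :size]

dilate_kernel_2d(torch.arange(9.).view(3, 3), d=3)   # -> 7 x 7 dilated kernel
```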

Empirical Results on NAS-Bench-360

Figure 2. We use performance profiles to measure the aggregate performance of DASH and other NAS methods (higher is better). Each profile is a curve corresponding to a single method; it plots the fraction of tasks on which that method is no more than a multiplicative factor \(\tau\) worse than the best method, where \(\tau\) is the variable on the horizontal axis and ranges from 1 to infinity. DASH's curve being far in the top left corner indicates that it is rarely suboptimal and is often the best.

To verify that DASH finds a balance between expressivity and efficiency, we evaluate its performance with the Wide ResNet backbone on ten diverse tasks from NAS-Bench-360. We present the performance profile (a technique for comparing different methods while revealing both ranking and absolute performance characteristics) for DASH and the NAS baselines in Figure 2. The exact accuracy metrics can be found in our paper. We highlight the following observations:

  • DASH ranks first among all NAS baselines (and outperforms DARTS) on 7/10 tasks.  It also dominates traditional non-DL approaches such as Auto-Sklearn and general-purpose models such as Perceiver IO.
  • DASH outperforms hand-crafted expert models on 7/10 tasks. While the degree of sophistication of the expert networks varies task by task, the performance of DASH on tasks such as Darcy Flow suggests that it is capable of competing with highly specialized networks, e.g., Fourier Neural Operator for PDE solving. This implies that equipping backbone networks with task-specific kernels is a promising approach for model development in new domains.
  • Speedwise, DASH is consistently faster than DARTS, and its search process often takes only a fraction of the time needed to train the backbone. Figure 3 visualizes the trade-off between efficiency and effectiveness for each method-task combination. Evidently, DASH is both faster and more effective than most NAS methods on the tasks we considered.

Figure 3. Comparing \(-\log\tau\)-suboptimality of speed vs. accuracy on all tasks. DASH’s concentration in the top right corner indicates its strong efficacy-efficiency trade-offs relative to the other methods.
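For readers less familiar with performance profiles, here is a tiny sketch (with made-up error numbers, not results from the paper) of how such a curve is computed: for each method, the profile at \(\tau\) is the fraction of tasks on which that method's error is within a factor \(\tau\) of the best error achieved on that task.

```python
import numpy as np

# rows = methods, columns = tasks; entries = error rates (lower is better); made-up numbers
errors = np.array([
    [0.10, 0.25, 0.40],   # method A
    [0.12, 0.20, 0.80],   # method B
])
ratios = errors / errors.min(axis=0)          # per-task suboptimality factor of each method
taus = np.linspace(1.0, 3.0, 50)
profiles = [(ratios[m][None, :] <= taus[:, None]).mean(axis=1)
            for m in range(errors.shape[0])]
# profiles[m][i] = fraction of tasks on which method m is within a factor taus[i] of the best
```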

In addition to the Wide ResNet backbone and the NAS-Bench-360 tasks, we have also verified the efficacy of DASH on other backbones, including TCN and ConvNeXt, and on large-scale datasets, including ImageNet. In particular, DASH achieves a 1.5% increase in top-1 accuracy on ImageNet-1K on top of the ConvNeXt backbone (note that ConvNeXt itself was developed in part via manual tuning of the kernel size). These results provide further support that DASH is backbone-agnostic and can be used to augment modern architectures with task-specific kernels to solve diverse problems effectively and efficiently.

Discussion and Takeaways

In this blogpost, we argue that a crucial goal of AutoML is to discover effective models for solving diverse learning problems that we may encounter in reality. To this end, we propose DASH, which efficiently searches for task-specific kernel patterns and integrates them into existing convolutional backbones. Please see our paper and codebase to learn more about DASH or try it out on your own tasks and backbones.

While DASH makes progress toward obtaining high-quality models by striking a suitable balance between expressivity and efficiency in NAS, we view it as an early step in the broader challenge of AutoML for diverse tasks. Thus, we would like to highlight and encourage participation in the ongoing AutoML Decathlon competition at NeurIPS 2022, for which we are implementing DASH as one of the baselines. Meanwhile, we hope to see more automated and practical methods developed for tackling diverse tasks in the future.


Tracking Any Pixel in a Video

We upgrade pixels into PIPs: “Persistent Independent Particles”. With this representation, we track any pixel over time, and overcome visibility issues with a learned temporal prior.

Motion estimation is a fundamental task of computer vision, with extremely broad applications. By tracking something, you can build models of its various properties: shape, texture, articulation, dynamics, affordances, and so on. More fine-grained tracking allows more fine-grained understanding. For robots, fine-grained tracking also enables fine-grained manipulation. Even setting aside downstream AI-related applications, motion tracks are directly useful for video editing applications — making realistic edits to a person or object in a video demands precise-as-possible tracking of the pixels, across an indefinite timespan.

There are a variety of methods for tracking objects (at the level of segmentation masks or bounding boxes), or for tracking certain points in certain categories (e.g., the joints of a person), but there are actually very few options for general-purpose fine-grained tracking. In this domain, the dominant approaches are feature matching and optical flow. The feature matching approach is: compute a feature for the target on the first frame, then compute features for pixels in other frames, and then compute “matches” using feature similarity (i.e., nearest neighbors). This often works well, but does not take into account temporal context, like smoothness of motion. The optical flow approach is: compute a dense “motion field” that relates each pair of frames, and then do some post-processing to link the fields together. Optical flow is very powerful, but since it only describes motion for a pair of frames at a time, it cannot produce useful outputs for targets that undergo multi-frame occlusion. “Occlusion” means our view is obstructed, and we need to guess the target’s location from context.

During an occlusion, appearance information does not suffice, because the target is not even present in the frames. Multi-frame temporal priors are key.

Around the year 2006, Peter Sand and Seth Teller proposed an alternative to flow-based and feature-based methods, called a “particle video.” This approach aims to represent a video with a set of particles that move across multiple frames. Their proposed method did not handle occlusions, but in our view, they laid the groundwork for treating pixels as persistent entities, with multi-frame trajectories and long-range temporal priors.

Inspired by their work, we propose Persistent Independent Particles (PIPs), a new particle video method. Our method takes a video as input, along with the \((x, y)\) coordinate of a target to track, and produces the target’s trajectory as output. The model can be queried for any number of particles, at any positions.

Particle trajectories for arbitrary “target” pixels in the video.

You may notice in our visualizations that the PIP trajectories often leave the video bounds. We treat pixels flying out-of-bounds as just another “occlusion.” This type of robustness is exactly what is missing from feature-based and flow-based methods.

Let’s step through how we achieved this.

How does it work?

At a high level, our method makes an extreme trade-off between spatial awareness and temporal awareness: we estimate the trajectory of every target independently. This extreme choice allows us to devote the majority of parameters to a module that simultaneously learns (1) temporal priors, and (2) an iterative inference mechanism that searches for the target pixel’s location in all input frames. Related work on optical flow estimation typically takes the opposite approach: it estimates the motion of all pixels simultaneously (using maximal spatial context), but for just 2 frames at a time (using minimal temporal context).

Given the \((x_1, y_1)\) coordinate of the target on the first frame, our concrete goal is to estimate the target’s full trajectory over \(T\) frames: \((x_1, y_1), \ldots, (x_T, y_T)\).

We start by initializing a zero-velocity estimate. This means copying the initial coordinate to every timestep.

To track the target, we need to know what it looks like, so we compute appearance features for all the frames (using a CNN), and initialize the target’s appearance trajectory with a bilinear sample at the given coordinate on the first frame. (The “bilinear sample” step extracts a feature vector at a subpixel location in the spatial map of features.)
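To make the initialization step concrete, here is a small PyTorch sketch of the zero-velocity initialization plus bilinear sampling (the function name, shapes, and normalization details are our own simplifications, not the released PIPs API):

```python
import torch
import torch.nn.functional as F

def init_trajectory(fmap1, x, y, T):
    """Copy the query coordinate to all T timesteps (zero-velocity init) and
    bilinearly sample the target's appearance feature from frame 1.

    fmap1: (1, C, H, W) feature map of the first frame; x, y in pixel coordinates.
    """
    _, C, H, W = fmap1.shape
    coords = torch.tensor([[x, y]], dtype=torch.float32).repeat(T, 1)   # (T, 2)
    # grid_sample expects (x, y) coordinates normalized to [-1, 1]
    grid = torch.tensor([[[[2 * x / (W - 1) - 1,
                            2 * y / (H - 1) - 1]]]], dtype=torch.float32)
    feat = F.grid_sample(fmap1, grid, align_corners=True)               # (1, C, 1, 1)
    feats = feat.view(1, C).repeat(T, 1)                                # (T, C)
    return coords, feats

coords, feats = init_trajectory(torch.randn(1, 64, 90, 120), x=40.5, y=22.0, T=8)
```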

Our inference process will proceed by iteratively refining the sequence of positions, and sequence of appearance features, until they (hopefully) match the true trajectory of the target. This idea is illustrated in the video below: at initialization, only “timestep 1” has the correct location of the target (since this was given), and gradually, the model “locks on” to the target in all frames.

Our model’s job is to produce updates (i.e., deltas) for the positions and features, so that the trajectory tracks the target on every frame. A critical detail here is that we ask our model to produce these updates for multiple timesteps simultaneously. This allows us to “catch” a target after it re-emerges from an occluder, and “fill in” the missing part of the trajectory.

There are many ways to implement this, but for fast training and good generalization, we need to carefully select what information we provide to the model.

The main source of information we provide to the model is: measurements of local appearance similarity. We obtain these measurements using cross correlation (i.e., dot products), computed at multiple scales. When the target is visible, it should show up as a strong peak in at least one of the similarity maps. Also, when we are “locked on”, the peak should be in the middle of the map. This is illustrated in the animation below.
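As a rough single-scale sketch of this measurement (our own simplification; PIPs uses a multi-scale pyramid of such maps), the target’s feature vector is dot-producted against the features in a small window around the current estimate in every frame:

```python
import torch

def local_similarity(fmaps, coords, target_feat, radius=3):
    """Dot-product similarity between the target feature and the features in a
    (2*radius+1)^2 window around the current position estimate in each frame.

    fmaps:       (T, C, H, W) per-frame feature maps
    coords:      (T, 2) current (x, y) estimates
    target_feat: (C,) appearance feature of the target
    """
    T, C, H, W = fmaps.shape
    sims = []
    for t in range(T):
        x, y = coords[t].round().long().tolist()
        x = min(max(x, radius), W - 1 - radius)      # keep the window inside the map
        y = min(max(y, radius), H - 1 - radius)
        patch = fmaps[t, :, y - radius:y + radius + 1, x - radius:x + radius + 1]
        sims.append((patch * target_feat.view(C, 1, 1)).sum(dim=0))   # (2r+1, 2r+1)
    return torch.stack(sims)   # (T, 2r+1, 2r+1); peak near the center when "locked on"
```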

The second source of information we provide to the model is: the estimated trajectory itself. This allows the model to impose a temporal prior, and fix up parts of the trajectory where the local similarity information was ambiguous.

Finally, we allow the model to inspect the feature vector of the target, in case it might learn different strategies for different types of features. For example, depending on the scale or texture of the target, it may adjust the way it uses information from the multi-scale similarity maps.
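Putting these three inputs together, the inference loop sketched below is only a toy stand-in for the real thing (a plain MLP replaces the learned update module, and the similarity maps are held fixed rather than re-measured around the updated coordinates at each iteration); its key property is that the deltas are produced for all \(T\) timesteps at once, which is what lets the model fill in trajectories through occlusions.

```python
import torch
import torch.nn as nn

T, C = 8, 64
coords = torch.zeros(T, 2)       # zero-velocity initialization of the trajectory
feats = torch.zeros(T, C)        # per-timestep appearance features of the target
corr = torch.zeros(T, 49)        # flattened 7x7 local similarity maps (toy stand-in)

# Stand-in for the learned update module: it maps per-timestep inputs to
# per-timestep deltas for ALL T frames at once.
update_model = nn.Sequential(nn.Linear(2 + C + 49, 128), nn.ReLU(),
                             nn.Linear(128, 2 + C))

for step in range(4):            # a few refinement iterations
    inp = torch.cat([coords, feats, corr], dim=-1)   # (T, 2 + C + 49)
    delta = update_model(inp)                        # (T, 2 + C)
    coords = coords + delta[:, :2]                   # refine the positions
    feats = feats + delta[:, 2:]                     # refine the appearance features
```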

For the model architecture, we elected to use an MLP-Mixer, which we found to have a good trade-off between model capacity, training time, and generalization. We also tried convolutional models and transformers, but the convolutional models could not fit the data as well as the MLP-Mixer, and the transformers took too long to train.

We trained the model on synthetic data that we made (based on an existing optical flow dataset), where we could provide multi-frame ground truth for targets that undergo occlusions. The animation below shows the kind of data we trained on. You might say the data looks crazy, but that's the point! If you can’t get real data, your best bet is synthetic data with extremely high diversity.

FlyingThings++: We train our model to track objects in this data, so that real videos are easy in comparison.

After training for a couple days on this data, the model starts to work on real videos.

Results

In the paper, we provide some quantitative analysis, showing that it works better than existing methods. It’s certainly not perfect — on keypoint tracking, the model works about 6/10 times, so there is a lot of room for improvement. The baselines are at around 5/10 or less.

The idea here is: pick a point on the first frame, and try to locate that same point in other frames of the video. This is a hard task especially when the point gets occluded.

Output of our PIPs model.

Baseline methods tend to get stuck on occluders, since they do not use multi-frame temporal context. For example, here is the output of a state-of-the-art optical flow method, on the same video and same target.

Output of an optical flow model (RAFT).

We have had fun trying the model on various videos, and observing the estimated trajectories. Sometimes they are surprisingly complex, since the target’s actual motion is subtly entangled with camera motion.

Visualizing the trajectories more densely gives mesmerizing results. Notice that the model even tracks ripples and specularities in the water.

Despite the fact that each particle trajectory is estimated independently of the others, they show surprisingly accurate grouping. Notice that the background particles all move together.

What’s next?

Our method upgrades pixels into PIPs: “Persistent Independent Particles.” This independence assumption, however, is probably not what we want in general. In ongoing work, we are trying to incorporate context across particles, so that confident particles can help the unconfident ones, and so that we track at multiple levels of granularity simultaneously.

We have released our code and model weights on GitHub. We encourage you to try our demo.py! If you are interested in building on our method, the provided tests, visualizations, and training scripts should make that easy. Or, if you are working on a video-based method that currently relies on optical flow, you may want to try our PIP trajectories as a replacement, which should give a better signal under occlusions.

We hope that our work opens up long-range fine-grained tracking of “anything.” For more details, please check out our project page and paper.
