A foolproof way to shrink deep learning models

As more artificial intelligence applications move to smartphones, deep learning models are getting smaller to allow apps to run faster and save battery power. Now, MIT researchers have a new and better way to compress models. 

It’s so simple that they unveiled it in a tweet last month: Train the model, prune its weakest connections, retrain the model at its fast, early training rate, and repeat, until the model is as tiny as you want. 

“That’s it,” says Alex Renda, a PhD student at MIT. “The standard things people do to prune their models are crazy complicated.” 

Renda discussed the technique when the International Conference on Learning Representations (ICLR) convened remotely this month. Renda is a co-author of the work with Jonathan Frankle, a fellow PhD student in MIT’s Department of Electrical Engineering and Computer Science (EECS), and Michael Carbin, an assistant professor of electrical engineering and computer science; all three are members of the Computer Science and Artificial Intelligence Laboratory (CSAIL).

The search for a better compression technique grew out of Frankle and Carbin’s award-winning Lottery Ticket Hypothesis paper at ICLR last year. They showed that a deep neural network could perform with only one-tenth the number of connections if the right subnetwork was found early in training. Their revelation came as demand for computing power and energy to train ever larger deep learning models was increasing exponentially, a trend that continues to this day. Costs of that growth include a rise in planet-warming carbon emissions and a potential drop in innovation as researchers not affiliated with big tech companies compete for scarce computing resources. Everyday users are affected, too. Big AI models eat up mobile-phone bandwidth and battery power.

But at a colleague’s suggestion, Frankle decided to see what lessons the hypothesis might hold for pruning, a set of techniques for reducing the size of a neural network by removing unnecessary connections or neurons. Pruning algorithms had been around for decades, but the field saw a resurgence after the breakout success of neural networks at classifying images in the ImageNet competition. As models got bigger, with researchers adding on layers of artificial neurons to boost performance, others proposed techniques for whittling them down.

Song Han, now an assistant professor at MIT, was one pioneer. Building on a series of influential papers, Han unveiled a pruning algorithm he called AMC, or AutoML for model compression, that’s still the industry standard. Under Han’s technique, redundant neurons and connections are automatically removed, and the model is retrained to restore its initial accuracy. 

In response to Han’s work, Frankle recently suggested in an unpublished paper that results could be further improved by rewinding the smaller, pruned model to its initial parameters, or weights, and retraining the smaller model at its faster, initial rate. 

In the current ICLR study, the researchers realized that the model could simply be rewound to its early training rate without fiddling with any parameters. In any pruning regimen, the tinier a model gets, the less accurate it becomes. But when the researchers compared this new method to Han’s AMC or Frankle’s weight-rewinding methods, it performed better no matter how much the model shrank. 
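The prune-and-retrain loop described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' code. The "model" is a toy linear regression trained with full-batch gradient descent. The learning-rate schedule (a fast rate for the first half of training, then a 10x decay) and the 20 percent pruning fraction per round are arbitrary choices, and all function names are hypothetical. The key line is the retraining step: the surviving weights are kept as they are, and only the learning-rate schedule is rewound to its fast, early phase.

```python
import numpy as np

def lr_schedule(step, total_steps, base_lr=0.1):
    """Fast early rate, then a 10x decay halfway through (the schedule is an assumption)."""
    return base_lr if step < total_steps // 2 else base_lr * 0.1

def train(w, mask, X, y, total_steps=400):
    """Full-batch gradient descent on mean squared error; pruned weights are held at zero."""
    for step in range(total_steps):
        grad = X.T @ (X @ w - y) / len(y)
        w = (w - lr_schedule(step, total_steps) * grad) * mask
    return w

def prune(w, mask, fraction=0.2):
    """Zero out `fraction` of the surviving weights with the smallest magnitude."""
    alive = np.flatnonzero(mask)
    n_prune = max(1, int(round(fraction * alive.size)))
    weakest = alive[np.argsort(np.abs(w[alive]))[:n_prune]]
    new_mask = mask.copy()
    new_mask[weakest] = 0.0
    return new_mask

def prune_with_lr_rewinding(X, y, target_density=0.2, fraction=0.2, seed=0):
    """Train, prune, retrain with the learning-rate schedule restarted, and repeat
    until at most `target_density` of the weights remain (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=0.01, size=X.shape[1])
    mask = np.ones_like(w)
    w = train(w, mask, X, y)                 # train the dense model to completion
    while mask.mean() > target_density:
        mask = prune(w, mask, fraction)      # drop the weakest connections
        w = train(w * mask, mask, X, y)      # keep trained weights, rewind only the learning rate
    return w, mask
```

For example, on synthetic data where only 5 of 50 input features matter, prune_with_lr_rewinding(X, y) should leave about a fifth of the weights, with the 5 relevant ones among the survivors. Weight rewinding would differ only in the retraining line: it would reset w to its early-training values instead of keeping the trained weights.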

It’s unclear why the pruning technique works as well as it does. The researchers say they will leave that question for others to answer. As for those who wish to try it, the algorithm is as easy to implement as other pruning methods, without time-consuming tuning, the researchers say. 

“It’s the pruning algorithm from the ‘Book,’” says Frankle. “It’s clear, generic, and drop-dead simple.”

Han, for his part, has now partly shifted focus from compressing AI models to channeling AI to design small, efficient models from the start. His newest method, Once for All, also debuts at ICLR. Of the new learning rate method, he says: “I’m happy to see new pruning and retraining techniques evolve, giving more people access to high-performing AI applications.”

Support for the study came from the Defense Advanced Research Projects Agency, Google, MIT-IBM Watson AI Lab, MIT Quest for Intelligence, and the U.S. Office of Naval Research.


Curated samples

Provided with genre, artist, and lyrics as input, Jukebox outputs a new music sample produced from scratch. Below, we show some of our favorite samples.
To hear all uncurated samples, check out our sample explorer.

Motivation and prior work


Automatic music generation dates back more than half a century. A prominent approach is to generate music symbolically in the form of a piano roll, which specifies the timing, pitch, velocity, and instrument of each note to be played. This has led to impressive results such as Bach chorales, polyphonic music with multiple instruments, and minute-long musical pieces.

But symbolic generators have limitations: they cannot capture human voices or many of the more subtle timbres, dynamics, and expressivity that are essential to music. A different approach[1] is to model music directly as raw audio. Generating music at the audio level is challenging because the sequences are very long. A typical 4-minute song at CD quality (44.1 kHz, 16-bit) has over 10 million timesteps. For comparison, GPT-2 had 1,000 timesteps and OpenAI Five took tens of thousands of timesteps per game. Thus, to learn the high-level semantics of music, a model would have to deal with extremely long-range dependencies.

One way of addressing the long input problem is to use an autoencoder that compresses raw audio to a lower-dimensional space by discarding some of the perceptually irrelevant bits of information. We can then train a model to generate audio in this compressed space, and upsample back to the raw audio space.

We chose to work on music because we want to continue to push the boundaries of generative models. Our previous work on MuseNet explored synthesizing music based on large amounts of MIDI data. Now in raw audio, our models must learn to tackle high diversity as well as very long range structure, and the raw audio domain is particularly unforgiving of errors in short, medium, or long term timing.


Jukebox pipeline: raw audio (44.1k samples per second, where each sample is a float that represents the amplitude of sound at that moment in time) is encoded using CNNs (convolutional neural networks) into compressed audio (344 samples per second, where each sample is 1 of 2048 possible vocab tokens). A trained transformer conditioned on lyrics generates novel patterns in this space, producing novel compressed audio at 344 samples per second. This is upsampled using transformers and decoded using CNNs into novel raw audio at 44.1k samples per second.


Compressing music to discrete codes

Jukebox’s autoencoder model compresses audio to a discrete space, using a quantization-based approach called VQ-VAE. Hierarchical VQ-VAEs can generate short instrumental pieces from a few sets of instruments; however, they suffer from hierarchy collapse due to the use of successive encoders coupled with autoregressive decoders. A simplified variant called VQ-VAE-2 avoids these issues by using only feedforward encoders and decoders, and it shows impressive results at generating high-fidelity images.

We draw inspiration from VQ-VAE-2 and apply their approach to music. We modify their architecture as follows:

  • To alleviate codebook collapse common to VQ-VAE models, we use random restarts where we randomly reset a codebook vector to one of the encoded hidden states whenever its usage falls below a threshold.
  • To maximize the use of the upper levels, we use separate decoders and independently reconstruct the input from the codes of each level.
  • To allow the model to reconstruct higher frequencies easily, we add a spectral loss that penalizes the norm of the difference of input and reconstructed spectrograms.
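The first and third modifications are simple enough to sketch with NumPy. This illustration is ours, not Jukebox's implementation, and it makes several assumptions. Usage is taken to be a running count (or moving average) of how often each codebook vector is chosen, and "dead" vectors are reset to randomly chosen encoder outputs from the current batch. The spectral loss uses a hand-rolled magnitude STFT with an arbitrary window size and hop. The function names are hypothetical.

```python
import numpy as np

def random_restart(codebook, usage, encodings, threshold, rng):
    """Reset codebook vectors whose usage fell below `threshold` to random encoder outputs.

    codebook:  (K, D) array of code vectors
    usage:     (K,) running count of how often each code was selected (assumed)
    encodings: (N, D) batch of encoder hidden states
    """
    dead = np.flatnonzero(usage < threshold)
    if dead.size:
        picks = rng.integers(0, len(encodings), size=dead.size)
        codebook = codebook.copy()
        codebook[dead] = encodings[picks]
    return codebook, dead

def magnitude_spectrogram(x, n_fft=256, hop=64):
    """Magnitude STFT with a Hann window; framing is done by hand to stay dependency-free."""
    window = np.hanning(n_fft)
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[i * hop : i * hop + n_fft] * window for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))

def spectral_loss(x, x_hat, n_fft=256, hop=64):
    """Frobenius norm of the difference between input and reconstructed spectrograms."""
    diff = magnitude_spectrogram(x, n_fft, hop) - magnitude_spectrogram(x_hat, n_fft, hop)
    return np.linalg.norm(diff)
```

In a training loop, random_restart would run after each batch, and the usage counts of the reset codes would be cleared as well. The spectral_loss term would be added to the usual reconstruction loss with some weighting that the text above does not specify.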

We use three levels in our VQ-VAE, shown below, which compress the 44.1 kHz raw audio by 8x, 32x, and 128x, respectively, with a codebook size of 2048 for each level. This downsampling loses much of the audio detail, and the reconstructions sound noticeably noisier at the more heavily compressed levels. However, it retains essential information about the pitch, timbre, and volume of the audio.
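The arithmetic behind these compression factors is worth making concrete. The short calculation below assumes a 44,100 Hz sample rate and simply divides by each level's downsampling factor. It shows the top level running at about 344.5 codes per second, which is the 344 samples per second quoted for compressed audio. It also shows that the 4-minute song from the motivation section, 10,584,000 raw samples, shrinks to 82,687 top-level codes. Each code is one of 2048 tokens, so it carries 11 bits.

```python
import math

SAMPLE_RATE = 44_100                                   # raw audio samples per second
HOP_LENGTHS = {"bottom": 8, "middle": 32, "top": 128}  # downsampling factor per VQ-VAE level
CODEBOOK_SIZE = 2048                                   # possible tokens per code

def codes_per_second(hop):
    """Code rate at a level that downsamples raw audio by `hop`."""
    return SAMPLE_RATE / hop

def codes_in(seconds, hop):
    """Whole codes needed to cover `seconds` of audio at that level."""
    return int(seconds * SAMPLE_RATE) // hop

song = 4 * 60  # a 4-minute song, in seconds
print(f"raw samples: {int(song * SAMPLE_RATE):,}")
for level, hop in HOP_LENGTHS.items():
    print(f"{level:>6}: {codes_per_second(hop):7.1f} codes/s, {codes_in(song, hop):>9,} codes")
print(f"bits per code: {math.log2(CODEBOOK_SIZE):.0f}")
```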

Each VQ-VAE level independently encodes the input. The bottom level encoding produces the highest quality reconstruction, while the top level encoding retains only the essential musical information.
To generate novel songs, a cascade of transformers generates codes from top to bottom level, after which the bottom-level decoder can convert them to raw audio.



Generating codes using transformers

Next, we train the prior models whose goal is to learn the distribution of music codes encoded by VQ-VAE and to generate music in this compressed discrete space. Like the VQ-VAE, we have three levels of priors: a top-level prior that generates the most compressed codes, and two upsampling priors that generate less compressed codes conditioned on the codes from the level above.

The top-level prior models the long-range structure of music, and samples decoded from this level have lower audio quality but capture high-level semantics like singing and melodies. The middle and bottom upsampling priors add local musical structures like timbre, significantly improving the audio quality.

We train these as autoregressive models using a simplified variant of Sparse Transformers. Each of these models has 72 layers of factorized self-attention on a context of 8192 codes, which corresponds to approximately 24 seconds, 6 seconds, and 1.5 seconds of raw audio at the top, middle and bottom levels, respectively.
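These context durations follow directly from the compression factors. A context of 8192 codes at a level that downsamples by a factor h covers 8192 × h / 44100 seconds of audio. Here is a quick check, assuming the 44.1 kHz sample rate used throughout:

```python
SAMPLE_RATE = 44_100                             # raw audio samples per second
CONTEXT = 8192                                   # codes per transformer context
HOPS = {"top": 128, "middle": 32, "bottom": 8}   # downsampling factor per level

def context_seconds(hop, context=CONTEXT, sample_rate=SAMPLE_RATE):
    """Seconds of raw audio covered by `context` codes at a level that downsamples by `hop`."""
    return context * hop / sample_rate

for level, hop in HOPS.items():
    print(f"{level:>6}: {context_seconds(hop):5.2f} s")
```

This gives about 23.8 s, 5.9 s, and 1.5 s, consistent with the rounded figures above.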

Once all of the priors are trained, we can generate codes from the top level, upsample them using the upsamplers, and decode them back to the raw audio space using the VQ-VAE decoder to sample novel songs.
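The order of operations in that sampling procedure can be sketched as follows. Everything model-related here is a hypothetical stand-in. toy_prior is not a transformer but a trivial autoregressive sampler: it performs a random walk at the top level and small perturbations of the time-aligned upper-level code for the upsamplers. toy_decode just maps each bottom-level code to a constant amplitude. What the sketch does show is the structure. Codes are sampled one at a time, top level first. Each upsampler sees the codes from the level above, and only the bottom-level codes are decoded to audio.

```python
import numpy as np

SAMPLE_RATE = 44_100
HOPS = {"top": 128, "middle": 32, "bottom": 8}   # downsampling factor per level
VOCAB = 2048                                     # codebook size at every level

def n_codes(level, seconds):
    """Number of codes covering `seconds` of audio at `level`."""
    return int(seconds * SAMPLE_RATE) // HOPS[level]

def toy_prior(length, upper, rng):
    """Hypothetical stand-in for a trained prior: samples one token per step.
    Top level: random walk over the vocabulary. Upsamplers: stay close to the
    time-aligned code from the level above."""
    codes = np.zeros(length, dtype=np.int64)
    for t in range(length):
        if upper is None:
            anchor = codes[t - 1] if t > 0 else VOCAB // 2
        else:
            anchor = upper[t * len(upper) // length]
        codes[t] = (anchor + rng.integers(-2, 3)) % VOCAB
    return codes

def toy_decode(bottom_codes):
    """Hypothetical stand-in for the bottom-level VQ-VAE decoder: each code becomes
    HOPS["bottom"] samples of a constant amplitude in [-1, 1]."""
    amplitude = bottom_codes / (VOCAB - 1) * 2.0 - 1.0
    return np.repeat(amplitude, HOPS["bottom"])

def sample_song(seconds, rng, prior=toy_prior, decode=toy_decode):
    top = prior(n_codes("top", seconds), None, rng)           # long-range structure first
    middle = prior(n_codes("middle", seconds), top, rng)      # upsample, conditioned on top
    bottom = prior(n_codes("bottom", seconds), middle, rng)   # upsample, conditioned on middle
    return decode(bottom)                                     # back to raw audio
```

For instance, sample_song(1.0, np.random.default_rng(0)) samples 344, 1378, and 5512 codes in turn. It returns 44,096 audio samples, slightly under one second because of integer division.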


To train this model, we crawled the web to curate a new dataset of 1.2 million songs (600,000 of which are in English), paired with the corresponding lyrics and metadata from LyricWiki. The metadata includes the artist, album, genre, and year of each song, along with common moods or playlist keywords associated with it. We train on 32-bit, 44.1 kHz raw audio, and perform data augmentation by randomly downmixing the right and left channels to produce mono audio.
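The downmixing augmentation could look like the function below. The text does not say how the channels are mixed, so the rule used here is an assumption, as is the function name. The rule is a random convex combination, mono = a × left + (1 - a) × right, with a drawn uniformly from [0, 1].

```python
import numpy as np

def random_downmix(stereo, rng):
    """Mix a (n_samples, 2) stereo array down to mono with a random left/right weight.
    The uniform convex-combination rule is an assumption, not Jukebox's documented method."""
    a = rng.uniform(0.0, 1.0)
    return a * stereo[:, 0] + (1.0 - a) * stereo[:, 1]
```

Because each training example gets a fresh weight, the model sees many slightly different mono versions of the same song.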

Artist and genre conditioning

The top-level transformer is trained on the task of predicting compressed audio tokens. We can provide additional information, such as the artist and genre for each song. This has two advantages: first, it reduces the entropy of the audio prediction, so the model is able to achieve better quality in any particular style; second, at generation time, we are able to steer the model to generate in a style of our choosing.
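The first advantage can be illustrated with a toy calculation. The joint distribution below, over two genres and four audio tokens, is entirely made up: each genre favors two of the tokens. Marginally, the tokens are uniform (2 bits of entropy), but once the genre is known, the conditional entropy drops to about 1.72 bits. In the real model, conditioning on artist and genre plays the same role: it leaves less uncertainty for the audio prediction to resolve.

```python
import numpy as np

def entropy(p):
    """Shannon entropy in bits of a probability vector (zero entries are skipped)."""
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())

# Hypothetical joint distribution P(genre, token): rows are genres, columns are tokens.
joint = np.array([[0.20, 0.20, 0.05, 0.05],
                  [0.05, 0.05, 0.20, 0.20]])
p_genre = joint.sum(axis=1)
p_token = joint.sum(axis=0)

h_token = entropy(p_token)                                   # H(token)
h_token_given_genre = sum(p_genre[g] * entropy(joint[g] / p_genre[g])
                          for g in range(len(p_genre)))      # H(token | genre)
print(f"H(token) = {h_token:.3f} bits, H(token | genre) = {h_token_given_genre:.3f} bits")
```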

The t-SNE plot below shows how the model learns, in an unsupervised way, to cluster similar artists and genres close together, and also makes some surprising associations, like Jennifer Lopez being so close to Dolly Parton!

Lyrics conditioning

In addition to conditioning on artist and genre, we can provide more context at training time by conditioning the model on the lyrics for a song. A significant challenge is the lack of a well-aligned dataset: we only have lyrics at a song level without alignment to the music, and thus for a given chunk of audio we don’t know precisely which portion of the lyrics (if any) appears. We also may have song versions that don’t match the lyric versions, as might occur if a given song is performed by several different artists in slightly different ways. Additionally, singers frequently repeat phrases, or otherwise vary the lyrics, in ways that are not always captured in the written lyrics.

To match audio portions to their corresponding lyrics, we begin with a simple heuristic that aligns the characters of the lyrics to linearly span the duration of each song, and pass a fixed-size window of characters centered around the current segment during training. While this simple strategy of linear alignment worked surprisingly well, we found that it fails for certain genres with fast lyrics, such as hip hop. To address this, we use Spleeter to extract vocals from each song and run NUS AutoLyricsAlign on the extracted vocals to obtain precise word-level alignments of the lyrics. We chose a large enough window so that the actual lyrics have a high probability of being inside the window.
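The linear-alignment heuristic amounts to mapping time to a character position and slicing a window around it. The sketch below is one way to write it. The function name, the centering rule, and the clamping at the start and end of the song are illustrative assumptions. When the lyrics are shorter than the window, it returns a shorter string rather than padding.

```python
def lyric_window(lyrics, song_seconds, seg_start, seg_end, window_chars):
    """Return `window_chars` characters of `lyrics` centred on the audio segment
    [seg_start, seg_end), assuming the characters span the song linearly."""
    chars_per_second = len(lyrics) / song_seconds
    centre = int(round((seg_start + seg_end) / 2 * chars_per_second))
    # Clamp so the window stays inside the lyrics at the start and end of the song.
    start = max(0, min(centre - window_chars // 2, len(lyrics) - window_chars))
    return lyrics[start:start + window_chars]
```

For fast lyrics, the centre computed this way can drift far from what is actually being sung, which is why the wider window and, later, word-level alignments help.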

To attend to the lyrics, we add an encoder to produce a representation for the lyrics, and add attention layers that use queries from the music decoder to attend to keys and values from the lyrics encoder. After training, the model learns a more precise alignment.
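Encoder-decoder attention of this kind is standard transformer cross-attention. The single-head NumPy sketch below illustrates the data flow. The names, shapes, and random projections are assumptions, and Jukebox's actual layers differ in size and detail. The returned weights matrix has one row per music position and one column per lyric token. It is the kind of alignment map visualized in the figure that follows.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def encoder_decoder_attention(music_states, lyric_states, w_q, w_k, w_v):
    """Single-head cross-attention: queries come from the music decoder,
    keys and values from the lyrics encoder.
    music_states: (T_music, d_model); lyric_states: (T_lyrics, d_model)
    w_q, w_k, w_v: (d_model, d_head) projection matrices."""
    q = music_states @ w_q
    k = lyric_states @ w_k
    v = lyric_states @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (T_music, T_lyrics)
    weights = softmax(scores, axis=-1)        # each music position spreads attention over lyric tokens
    return weights @ v, weights
```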


Lyric–music alignment learned by encoder–decoder attention layer
Attention progresses from one lyric token to the next as the music progresses, with a few moments of uncertainty.


While Jukebox represents a step forward in musical quality, coherence, length of audio sample, and ability to condition on artist, genre, and lyrics, there is a significant gap between these generations and human-created music.

For example, while the generated songs show local musical coherence, follow traditional chord patterns, and can even feature impressive solos, we do not hear familiar larger musical structures such as choruses that repeat. Our downsampling and upsampling process introduces discernible noise. Improving the VQ-VAE so its codes capture more musical information would help reduce this. Our models are also slow to sample from, because of the autoregressive nature of sampling. It takes approximately 9 hours to fully render one minute of audio through our models, and thus they cannot yet be used in interactive applications. Using techniques that distill the model into a parallel sampler can significantly speed up sampling. Finally, we currently train on English lyrics and mostly Western music, but in the future we hope to include songs from other languages and parts of the world.
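The sampling cost is easier to appreciate by counting sequential steps. The calculation below makes two assumptions: one autoregressive step per code, and no accounting for the overlapping windows needed to sample sequences longer than one context. Under those assumptions, a minute of audio at 44.1 kHz requires the following numbers of steps at each level:

```python
SAMPLE_RATE = 44_100
HOPS = {"top": 128, "middle": 32, "bottom": 8}   # downsampling factor per level

def sequential_steps(seconds):
    """Autoregressive sampling steps (one per code) needed at each level for `seconds` of audio."""
    return {level: int(seconds * SAMPLE_RATE) // hop for level, hop in HOPS.items()}

steps = sequential_steps(60)
total = sum(steps.values())
print(steps, total)
```

That is 20,671 + 82,687 + 330,750 = 434,108 steps. Within a level, each step depends on the previous one, which is why distilling into a parallel sampler is attractive.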

Future directions

Our audio team is continuing to work on generating audio samples conditioned on different kinds of priming information. In particular, we’ve seen early success conditioning on MIDI files and stem files. Here’s an example of a raw audio sample conditioned on MIDI tokens. We hope this will improve the musicality of samples (in the way conditioning on lyrics improved the singing), and this would also be a way of giving musicians more control over the generations. We expect human and model collaborations to be an increasingly exciting creative space. If you’re excited to work on these problems with us, we’re hiring.

As generative modeling across various domains continues to advance, we are also conducting research into issues like bias and intellectual property rights, and are engaging with people who work in the domains where we develop tools. To better understand future implications for the music community, we shared Jukebox with an initial set of 10 musicians from various genres to discuss their feedback on this work. While Jukebox is an interesting research result, these musicians did not find it immediately applicable to their creative process given some of its current limitations. We are connecting with the wider creative community as we think generative work across text, images, and audio will continue to improve. If you’re interested in being a creative collaborator to help us build useful tools or new works of art in these domains, please let us know!

To connect with the corresponding authors, please email jukebox@openai.com.


  • Our first raw audio model, which learns to recreate instruments like piano and violin. We try it on a dataset of rock and pop songs, and surprisingly it works.

  • We collect a larger and more diverse dataset of songs, with labels for genres and artists. The model picks up artist and genre styles more consistently and with more diversity, and at convergence it can also produce full-length songs with long-range coherence.

  • We scale our VQ-VAE from 22 kHz to 44 kHz to achieve higher-quality audio. We also scale the top-level prior from 1B to 5B parameters to capture the increased information. We see better musical quality, clear singing, and long-range coherence. We also make novel completions of real songs.

  • We start training models conditioned on lyrics to incorporate further conditioning information. We only have unaligned lyrics, so the model has to learn alignment and pronunciation, as well as singing.


TFRT: A new TensorFlow runtime

Posted by Eric Johnson, TFRT Product Manager and Mingsheng Hong, TFRT Tech Lead/Manager

TensorFlow aims to make it easy for you to build and deploy ML models across many different devices. Yet, what it means to “build and deploy ML models” is not static and continues to change with increased investment in the ML ecosystem.

At the top half of the TensorFlow stack, innovation is leading to more complex models and deployment scenarios. Researchers are inventing new algorithms that require more compute, while application developers are enhancing their products with these new techniques across edge and server.

At the bottom half of the stack, the tension between increasing compute needs and rising compute costs due to the end of Moore’s law has sparked a proliferation of new hardware aimed at specific ML use cases. Traditional chip makers, startups, and software companies alike (including Google) have invested in specialized silicon.

The result is that the needs of the ML ecosystem are vastly different than they were 4 or 5 years ago when TensorFlow was first created. Of course, we’ve continued to iterate with the release of 2.x, but the current TensorFlow stack is optimized for graph execution and incurs non-trivial overhead when dispatching a single op. A high-performance, low-level runtime is key to enabling the trends of today and empowering the innovations of tomorrow.

Enter TFRT, a new TensorFlow RunTime. It aims to provide a unified, extensible infrastructure layer with best-in-class performance across a wide variety of domain-specific hardware. It provides efficient use of multithreaded host CPUs, supports fully asynchronous programming models, and focuses on low-level efficiency.

TFRT will benefit a broad range of users, including:

  • Researchers looking for faster iteration time and better error reporting when developing complex new models in eager mode.
  • Application developers looking for improved performance when training and serving models in production.
  • Hardware makers looking to integrate edge and datacenter devices into TensorFlow in a modular way.

What is TFRT?

TFRT is a new runtime that will replace the existing TensorFlow runtime. It is responsible for efficient execution of kernels – low-level device-specific primitives – on targeted hardware. It plays a critical part in both eager and graph execution, which is illustrated by this simplified diagram of the TensorFlow training stack:

TFRT’s role in graph and eager execution within the TensorFlow training stack

Note that everything in grey is part of TFRT. In eager execution, TensorFlow APIs call directly into the new runtime. In graph execution, your program’s computational graph is lowered to an optimized target-specific program and dispatched to TFRT. In both execution paths, the new runtime invokes a set of kernels that call into the underlying hardware devices to complete the model execution, as shown by the black arrows.

Key design points

Whereas the existing TensorFlow runtime was initially built for graph execution and training workloads, the new runtime will make eager execution and inference first-class citizens, while putting special emphasis on architecture extensibility and modularity. More specifically, TFRT has the following selected design highlights:

  • To achieve higher performance, TFRT has a lock-free graph executor that supports concurrent op execution with low synchronization overhead, and a thin eager op dispatch stack so that eager API calls will be asynchronous and more efficient.
  • To make extending the TF stack easier, we decoupled device runtimes from the host runtime, the core TFRT component that drives host CPU and I/O work.
  • To get consistent behavior, TFRT leverages common abstractions, such as shape functions and kernels, across both eager and graph.

The power of MLIR

TFRT is also tightly integrated with MLIR. For example:

  • TFRT utilizes MLIR’s compiler infrastructure to generate an optimized, target-specific representation of your computational graph that the runtime executes.
  • TFRT uses MLIR’s extensible type system to support arbitrary C++ types in the runtime, which removes tensor-specific limitations.

Together, TFRT and MLIR will improve TensorFlow’s unification, flexibility, and extensibility.

Initial Results

Early performance results from the inference and serving use case are encouraging. As part of a benchmarking study for TensorFlow Dev Summit 2020, we integrated TFRT with TensorFlow Serving and measured the latency of sending requests to the model and getting prediction results back. We picked a common MLPerf model, ResNet-50, and chose a batch size of 1 and a data precision of FP16 to focus our study on runtime related op dispatch overhead. In comparing performance of GPU inference over TFRT to the current runtime, we saw an improvement of 28% in average inference time. These early results are strong validation for TFRT, and we expect it to provide a big boost to performance. We hope you are as excited as we are!

What’s next

TFRT is being integrated with TensorFlow, and will be enabled initially through an opt-in flag, giving the team time to fix any bugs and fine-tune performance. Eventually, it will become TensorFlow’s default runtime. Although it is still an early stage project, we have made the GitHub repository available to the community. We are limiting contributions to begin with, but encourage participation in the form of requirements and design discussions.

To learn more, please check out our Dev Summit 2020 presentation, where we first introduced TFRT to the world, and our MLIR Open Design Deep Dive presentation, where we provided a detailed overview of TFRT’s core components, low-level abstractions, and general design principles. And finally, if you want to keep up with all things TFRT, please join our new mailing list. Thanks!

Automating the search for entirely new “curiosity” algorithms

Driven by an innate curiosity, children pick up new skills as they explore the world and learn from their experience. Computers, by contrast, often get stuck when thrown into new environments.

To get around this, engineers have tried encoding simple forms of curiosity into their algorithms with the hope that an agent pushed to explore will learn about its environment more effectively. An agent with a child’s curiosity might go from learning to pick up, manipulate, and throw objects to understanding the pull of gravity, a realization that could dramatically accelerate its ability to learn many other things. 

Engineers have discovered many ways of encoding curious exploration into machine learning algorithms. A research team at MIT wondered if a computer could do better, based on a long history of enlisting computers in the search for new algorithms. 

In recent years, the design of deep neural networks, algorithms that search for solutions by adjusting numeric parameters, has been automated with software like Google’s AutoML and auto-sklearn in Python. That’s made it easier for non-experts to develop AI applications. But while deep nets excel at specific tasks, they have trouble generalizing to new situations. Algorithms expressed in code, in a high-level programming language, by contrast, have the capacity to transfer knowledge across different tasks and environments. 

“Algorithms designed by humans are very general,” says study co-author Ferran Alet, a graduate student in MIT’s Department of Electrical Engineering and Computer Science and Computer Science and Artificial Intelligence Laboratory (CSAIL). “We were inspired to use AI to find algorithms with curiosity strategies that can adapt to a range of environments.”

The researchers created a “meta-learning” algorithm that generated 52,000 exploration algorithms. They found that the top two were entirely new — seemingly too obvious or counterintuitive for a human to have proposed. Both algorithms generated exploration behavior that substantially improved learning in a range of simulated tasks, from navigating a two-dimensional grid based on images to making a robotic ant walk. Because the meta-learning process generates high-level computer code as output, both algorithms can be dissected to peer inside their decision-making processes.

The paper’s senior authors are Leslie Kaelbling and Tomás Lozano-Pérez, both professors of computer science and electrical engineering at MIT. The work will be presented at the virtual International Conference on Learning Representations later this month. 

The paper received praise from researchers not involved in the work. “The use of program search to discover a better intrinsic reward is very creative,” says Quoc Le, a principal scientist at Google who has helped pioneer computer-aided design of deep learning models. “I like this idea a lot, especially since the programs are interpretable.”

The researchers compare their automated algorithm design process to writing sentences with a limited number of words. They started by choosing a set of basic building blocks to define their exploration algorithms. After studying other curiosity algorithms for inspiration, they picked nearly three dozen high-level operations, including basic programs and deep learning models, to guide the agent to do things like remember previous inputs, compare current and past inputs, and use learning methods to change its own modules. The computer then combined up to seven operations at a time to create computation graphs describing 52,000 algorithms. 

Even with a fast computer, testing them all would have taken decades. So, instead, the researchers limited their search by first ruling out algorithms predicted to perform poorly, based on their code structure alone. Then, they tested their most promising candidates on a basic grid-navigation task requiring substantial exploration but minimal computation. If the candidate did well, its performance became the new benchmark, eliminating even more candidates. 

Four machines searched over 10 hours to find the best algorithms. More than 99 percent were junk, but about a hundred were sensible, high-performing algorithms. Remarkably, the top 16 were both novel and useful, performing as well as, or better than, human-designed algorithms at a range of other virtual tasks, from landing a moon rover to raising a robotic arm and moving an ant-like robot in a physical simulation. 

All 16 algorithms shared two basic exploration functions. 

In the first, the agent is rewarded for visiting new places where it has a greater chance of making a new kind of move. In the second, the agent is also rewarded for visiting new places, but in a more nuanced way: one neural network learns to predict the future state while a second recalls the past, and then tries to predict the present by predicting the past from the future. If this prediction is erroneous, the agent rewards itself, as the error is a sign that it has discovered something it didn’t know before. The second algorithm was so counterintuitive that it took the researchers time to figure out.
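For readers unfamiliar with prediction-error curiosity, the generic idea behind the second function can be sketched as follows. This is the standard textbook version, not the algorithm the MIT search discovered. A single linear forward model predicts the next state from the current one, and the intrinsic reward is the squared prediction error. The model is updated online, so transitions the agent has already seen become less rewarding. The class name and learning rate are hypothetical.

```python
import numpy as np

class PredictionErrorCuriosity:
    """Generic prediction-error intrinsic reward (illustrative, not the discovered algorithm)."""

    def __init__(self, state_dim, lr=0.1):
        self.W = np.zeros((state_dim, state_dim))   # linear forward model: next_state ≈ W @ state
        self.lr = lr

    def reward(self, state, next_state):
        error = next_state - self.W @ state
        # One gradient step on 0.5 * ||error||^2 with respect to W.
        self.W += self.lr * np.outer(error, state)
        return float(error @ error)                 # surprise = squared prediction error
```

Repeating the same transition drives its reward toward zero, while a transition the model has never seen still earns a large reward.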

“Our biases often prevent us from trying very novel ideas,” says Alet. “But computers don’t care. They try, and see what works, and sometimes we get great unexpected results.”

More researchers are turning to machine learning to design better machine learning algorithms, a field known as AutoML. At Google, Le and his colleagues recently unveiled a new algorithm-discovery tool called AutoML-Zero. (Its name is a play on Google’s AutoML software for customizing deep net architectures for a given application, and Google DeepMind’s AlphaZero, the program that can learn to play different board games by playing millions of games against itself.)

Their method searches through a space of algorithms made up of simpler primitive operations. But rather than look for an exploration strategy, their goal is to discover algorithms for classifying images. Both studies show the potential for humans to use machine-learning methods themselves to create novel, high-performing machine-learning algorithms.

“The algorithms we generated could be read and interpreted by humans, but to actually understand the code we had to reason through each variable and operation and how they evolve with time,” says study co-author Martin Schneider, a graduate student at MIT. “It’s an interesting open challenge to design algorithms and workflows that leverage the computer’s ability to evaluate lots of algorithms and our human ability to explain and improve on those ideas.” 

The research received support from the U.S. National Science Foundation, Air Force Office of Scientific Research, Office of Naval Research, Honda Research Institute, SUTD Temasek Laboratories, and MIT Quest for Intelligence.

3 Questions: Tom Leighton on the major surge in internet traffic triggered by physical distancing

With various physical distancing guidelines in place throughout the world as a means to curb the spread of Covid-19, the internet has experienced a dramatic spike in overall traffic. MIT Professor Tom Leighton is chief executive officer and co-founder of Akamai Technologies, a global content delivery network, cybersecurity, and cloud service company that provides web and internet security services. At MIT he specializes in applied mathematics in the Department of Mathematics and is a member of the Computer Science and Artificial Intelligence Laboratory (CSAIL). The Department of Mathematics Communications spoke to Leighton about his company’s response to the world’s increased reliance on the internet during the Covid-19 pandemic.

Q: How is the pandemic changing the way people use the internet?

A: The internet has become our lifeline as we face the challenges of working remotely, distance learning, and sheltering in place. Everything has moved online: religious services, movie premieres, commerce of all kinds, and even gatherings of friends for a cup of coffee. We’ve already been doing many of these things online for years — the big difference now is that we are suddenly only doing them online.

When we’ve emerged from the pandemic, it seems quite possible that our usage of the internet for nearly every facet of our lives will have increased permanently. Many more people may be working remotely even when offices reopen; the shift to virtual meetings may become the norm even when we can travel again; a much greater share of commerce may be conducted online even when we can return to shopping malls; and our usage of social media and video streaming could well be greater than ever before, even when it’s OK to meet others in person.

Q: How much more use is the internet seeing as a result of the pandemic?

A: Akamai operates a globally distributed intelligent edge platform with more than 270,000 servers in 4,000 locations across 137 countries. From our vantage point, we can see that global internet traffic increased by about 30 percent during the past month. That’s about 10 times normal, and it means we’ve seen an entire year’s worth of growth in internet traffic in just the past few weeks. And that’s without any live sports streaming, like the usual March Madness college basketball tournament in the United States.

Just a few weeks ago, we set a new peak record of traffic on the Akamai edge platform of 167 terabits per second. That’s more than double the peak we saw one year before. These are truly unprecedented times. The internet is being used at a scale that the world has never experienced.

Q: Can the internet keep up with the surge in traffic?

A: The answer is yes, but with many more caveats now.

Around the world, some regulators, major carriers, and content providers are taking steps to reduce load during peak traffic times in an effort to avert online gridlock. For example, European regulators have asked telecom providers and streaming platforms to switch to standard definition video during periods of peak demand. And Akamai is working with leading companies such as Microsoft and Sony to deliver software updates for e-gaming at off-peak traffic times. The typical software update uses as much traffic as about 30,000 web pages, so this makes a big difference when it comes to managing congestion.

In addition, Akamai’s intelligent edge network architecture is designed to mitigate and minimize network congestion. Because we’ve deployed our infrastructure deep into carrier networks, we can help those networks avoid overload by diverting traffic away from areas experiencing high levels of congestion.

Overall, we fully expect to maintain the integrity and reliability of website and mobile application delivery, as well as security services, for all of our customers during this time. In particular, Akamai customers across sectors such as government, health care, financial services, commerce, manufacturing, and business services should not experience any change in the performance of their services. We will continue working with governments, network operators, and our customers to minimize stress on the system. At the same time, we’ll do our best to make sure that everyone who is relying on the internet for their work, studies, news, and entertainment continues to have a high-quality, positive experience.

MIT conference reveals the power of using artificial intelligence to discover new drugs

Developing drugs to combat Covid-19 is a global priority, requiring communities to come together to fight the spread of infection. At MIT, researchers with backgrounds in machine learning and life sciences are collaborating, sharing datasets and tools to develop machine learning methods that can identify novel cures for Covid-19.

This research is an extension of a community effort launched earlier this year. In February, before the Institute de-densified as a result of the pandemic, the first-ever AI Powered Drug Discovery and Manufacturing Conference (AIDM), conceived and hosted by the Abdul Latif Jameel Clinic for Machine Learning in Health, drew attendees including pharmaceutical industry researchers, government regulators, venture capitalists, and pioneering drug researchers. More than 180 health care companies and 29 universities developing new artificial intelligence methods used in pharmaceuticals got involved, making the conference a singular event designed to lift the mask and reveal what goes on in the process of drug discovery.

As secretive as Silicon Valley seems, computer science and engineering students typically know what a job looks like when aspiring to join companies like Facebook or Tesla. But the global head of research and development for Janssen — the innovative pharmaceutical company owned by Johnson & Johnson — said it’s often much harder for students to grasp how their work fits into drug discovery.

“That’s a problem at the moment,” Mathai Mammen says, after addressing attendees, including MIT graduate students and postdocs, who gathered in the Samberg Conference Center in part to get a glimpse behind the scenes of companies currently working on bold ideas blending artificial intelligence with health care. Mammen, a graduate of the Harvard-MIT Program in Health Sciences and Technology whose work at Theravance brought five new medicines to market, with many more on the way, is here to be part of the answer to that problem. “What the industry needs to do is talk to students and postdocs about the sorts of interesting scientific and medical problems whose solutions can directly and profoundly benefit the health of people everywhere,” he says.

“The conference brought together research communities that rarely overlap at technical conferences,” says Regina Barzilay, the Delta Electronics Professor of Electrical Engineering and Computer Science, Jameel Clinic faculty co-lead, and one of the conference organizers. “This blend enables us to better understand open problems and opportunities in the intersection. The exciting piece for MIT students, especially for computer science and engineering students, is to see where the industry is moving and to understand how they can contribute to this changing industry, which will happen when they graduate.”

Over two days, conference attendees snapped photographs through a packed schedule of research presentations, technical sessions, and expert panels, covering everything from discovering new therapeutic molecules with machine learning to funding AI research. Carefully curated, the conference provided a roadmap of bold tech ideas at work in health care now and traced the path to show how those tech solutions get implemented.

At the conference, Barzilay and Jim Collins, the Termeer Professor of Medical Engineering and Science in MIT’s Institute for Medical Engineering and Science (IMES) and Department of Biological Engineering, and Jameel Clinic faculty co-lead, presented research from a study published in Cell where they used machine learning to help identify a new drug that can target antibiotic-resistant bacteria. Together with MIT researchers Tommi Jaakkola, Kevin Yang, Kyle Swanson, and the first author Jonathan Stokes, they demonstrated how blending their backgrounds can yield potential answers to combat the growing antibiotic resistance crisis.

Collins saw the conference as an opportunity to inspire interest in antibiotic research. He hoped to get the top young minds involved in battling the resistance to antibiotics that has built up over decades of overuse and misuse. This is an urgent predicament in medicine, and computer science students may not realize they can help solve it. “I think we should take advantage of the innovation ecosystem at MIT and the fact that there are many experts here at MIT who are willing to step outside their comfort zone and get engaged in a new problem,” Collins says. “Certainly in this case, the development and discovery of novel antibiotics, is critically needed around the globe.”

AIDM showed the power of collaboration, inviting experts from major health-care companies and relevant organizations like Merck, Bayer, Darpa, Google, Pfizer, Novartis, Amgen, the U.S. Food and Drug Administration, and Janssen. By reaching attendee capacity, the conference also showed that people are ready to pull together and get on the same page. “I think the time is right and I think the place is right,” Collins says. “I think MIT is well-positioned to be a national, if not an international leader in this space, given the excitement and engagement of our students and our position in Kendall Square.”

A biotech hub for decades, Kendall Square has come a long way since big data came to Cambridge, Massachusetts, forever changing life science companies based here. AIDM kicked off with Institute Professor and Professor of Biology Phillip Sharp walking attendees through a brief history of AI in health care in the area. He was perhaps the person at the conference most excited for others to see the potential, as through his long career, he’s watched firsthand the history of innovation that led to this conference.

“The bigger picture, which this conference is a major part of, is this bringing together of the life science — biologists and chemists with machine learning and artificial intelligence — it’s the future of life science,” Sharp says. “It’s clear. It will reshape how we talk about our science, how we think about solving problems, how we deal with the other parts of the process of taking insights to benefit society.”

Muscle signals can pilot a robot

Albert Einstein famously postulated that “the only real valuable thing is intuition,” arguably one of the most important keys to understanding intention and communication. 

But intuitiveness is hard to teach — especially to a machine. Looking to improve this, a team from MIT’s Computer Science and Artificial Intelligence Laboratory (CSAIL) came up with a method that dials us closer to more seamless human-robot collaboration. The system, called “Conduct-A-Bot,” uses human muscle signals from wearable sensors to pilot a robot’s movement. 

“We envision a world in which machines help people with cognitive and physical work, and to do so, they adapt to people rather than the other way around,” says Professor Daniela Rus, director of CSAIL, deputy dean of research for the MIT Stephen A. Schwarzman College of Computing, and co-author on a paper about the system. 

To enable seamless teamwork between people and machines, electromyography and motion sensors are worn on the biceps, triceps, and forearms to measure muscle signals and movement. Algorithms then process the signals to detect gestures in real time, without any offline calibration or per-user training data. The system uses just two or three wearable sensors, and nothing in the environment, greatly lowering the barrier for casual users to interact with robots.

While Conduct-A-Bot could potentially be used for various scenarios, including navigating menus on electronic devices or supervising autonomous robots, for this research the team used a Parrot Bebop 2 drone, although any commercial drone could be used.

By detecting actions like rotational gestures, clenched fists, tensed arms, and activated forearms, Conduct-A-Bot can move the drone left, right, up, down, and forward, as well as allow it to rotate and stop. 

If you gestured toward the right to your friend, they could likely interpret that they should move in that direction. Similarly, if you waved your hand to the left, for example, the drone would follow suit and move to the left. 

In tests, the drone correctly responded to 82 percent of over 1,500 human gestures when it was remotely controlled to fly through hoops. The system also correctly identified approximately 94 percent of cued gestures when the drone was not being controlled.

“Understanding our gestures could help robots interpret more of the nonverbal cues that we naturally use in everyday life,” says Joseph DelPreto, lead author on the new paper. “This type of system could help make interacting with a robot more similar to interacting with another person, and make it easier for someone to start using robots without prior experience or external sensors.” 

This type of system could eventually target a range of applications for human-robot collaboration, including remote exploration, assistive personal robots, or manufacturing tasks like delivering objects or lifting materials. 

These intelligent tools are also consistent with social distancing, and could potentially open up a realm of future contactless work. For example, you can imagine machines being controlled by humans to safely clean a hospital room, or drop off medications, while letting us humans stay a safe distance away.

Muscle signals can often provide information about states that are hard to observe from vision, such as joint stiffness or fatigue.    

For example, if you watch a video of someone holding a large box, you might have difficulty guessing how much effort or force was needed — and a machine would also have difficulty gauging that from vision alone. Using muscle sensors opens up possibilities to estimate not only motion, but also the force and torque required to execute that physical trajectory.

For the gesture vocabulary currently used to control the robot, the movements were detected as follows: 

  • stiffening the upper arm to stop the robot (similar to briefly cringing when seeing something going wrong): biceps and triceps muscle signals;

  • waving the hand left/right and up/down to move the robot sideways or vertically: forearm muscle signals (with the forearm accelerometer indicating hand orientation);

  • fist clenching to move the robot forward: forearm muscle signals; and

  • rotating clockwise/counterclockwise to turn the robot: forearm gyroscope.
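The gesture vocabulary above can be read as a lookup table. Each entry maps a detected gesture to a drone command and to the sensor signals that detect it. The sketch below is only an illustration. The gesture names, command strings, and the `command_for` helper are hypothetical and are not Conduct-A-Bot's code. Mapping unrecognized motions to "none" is an assumed safety default.

```python
from typing import Dict, Tuple

# Hypothetical encoding of the vocabulary listed above:
# gesture name -> (robot command, signals used to detect the gesture)
GESTURE_VOCABULARY: Dict[str, Tuple[str, str]] = {
    "stiffen_upper_arm": ("stop", "biceps and triceps muscle signals"),
    "wave_left": ("move left", "forearm muscle signals + forearm accelerometer"),
    "wave_right": ("move right", "forearm muscle signals + forearm accelerometer"),
    "wave_up": ("move up", "forearm muscle signals + forearm accelerometer"),
    "wave_down": ("move down", "forearm muscle signals + forearm accelerometer"),
    "clench_fist": ("move forward", "forearm muscle signals"),
    "rotate_clockwise": ("turn clockwise", "forearm gyroscope"),
    "rotate_counterclockwise": ("turn counterclockwise", "forearm gyroscope"),
}


def command_for(gesture: str) -> str:
    """Return the drone command for a detected gesture (hypothetical helper)."""
    # Assumed default: anything outside the vocabulary issues no command.
    command, _signals = GESTURE_VOCABULARY.get(gesture, ("none", ""))
    return command


print(command_for("clench_fist"))   # move forward
print(command_for("scratch_head"))  # none
```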

Machine learning classifiers detected the gestures using the wearable sensors. Unsupervised classifiers processed the muscle and motion data and clustered it in real time to learn how to separate gestures from other motions. A neural network also predicted wrist flexion or extension from forearm muscle signals.  
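The paragraph above describes unsupervised classifiers that cluster muscle data in real time. Here is a minimal sketch of that idea, and explicitly not the classifier used in Conduct-A-Bot, whose clustering model may differ. The code runs online two-means clustering on a single feature to separate "active" windows from "rest" windows. That feature is the root-mean-square (RMS) amplitude of a window of muscle-signal samples. The two starting centers are rough guesses, and the incremental updates then adapt them to each wearer's signal levels. The window values are synthetic.

```python
import math
from typing import List


def rms(window: List[float]) -> float:
    """Root-mean-square amplitude of one window of muscle-signal samples."""
    return math.sqrt(sum(s * s for s in window) / len(window))


class OnlineTwoMeans:
    """Online 2-means on a scalar feature: cluster 0 ~ rest, cluster 1 ~ active."""

    def __init__(self, rest_guess: float, active_guess: float) -> None:
        self.centers = [rest_guess, active_guess]
        self.counts = [1, 1]  # each guess counts as one pseudo-observation

    def update(self, feature: float) -> int:
        """Assign the feature to the nearer center, then move that center toward it."""
        k = 0 if abs(feature - self.centers[0]) <= abs(feature - self.centers[1]) else 1
        self.counts[k] += 1
        self.centers[k] += (feature - self.centers[k]) / self.counts[k]  # running mean
        return k


clf = OnlineTwoMeans(rest_guess=0.1, active_guess=1.0)
rest = [0.05, -0.05] * 10    # synthetic low-amplitude window
active = [0.8, -0.8] * 10    # synthetic high-amplitude window
labels = [clf.update(rms(w)) for w in (rest, active, rest, active)]
print(labels)  # [0, 1, 0, 1]
```

Because the centers keep moving with each window, no separate calibration session is needed in this sketch. A real system would cluster richer, multi-channel features.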

The system essentially calibrates itself to each person’s signals while they’re making gestures that control the robot, making it faster and easier for casual users to start interacting with robots.

In the future, the team hopes to expand the tests to include more subjects. And while the movements for Conduct-A-Bot cover common gestures for robot motion, the researchers want to extend the vocabulary to include more continuous or user-defined gestures. Eventually, the hope is to have the robots learn from these interactions to better understand the tasks and provide more predictive assistance or increase their autonomy. 

“This system moves one step closer to letting us work seamlessly with robots so they can become more effective and intelligent tools for everyday tasks,” says DelPreto. “As such collaborations continue to become more accessible and pervasive, the possibilities for synergistic benefit continue to deepen.” 

DelPreto and Rus presented the paper virtually earlier this month at the ACM/IEEE International Conference on Human-Robot Interaction.

SAIL at ICLR 2020: Accepted Papers and Videos

The International Conference on Learning Representations (ICLR) 2020 is being hosted virtually from April 26th – May 1st. We’re excited to share all the work from SAIL that’s being presented, and you’ll find links to papers, videos and blogs below. Feel free to reach out to the contact authors directly to learn more about the work that’s happening at Stanford!

List of Accepted Papers

Hierarchical Foresight: Self-Supervised Learning of Long-Horizon Tasks via Visual Subgoal Generation


Suraj Nair, Chelsea Finn | contact: surajn@stanford.edu
keywords: visual planning, reinforcement learning, robotics

Active World Model Learning with Progress Curiosity


Kuno Kim, Megumi Sano, Julian De Freitas, Nick Haber, Dan Yamins | contact: khkim@cs.stanford.edu
keywords: curiosity, reinforcement learning, cognitive science

Kaleidoscope: An Efficient, Learnable Representation For All Structured Linear Maps

paper | blog post

Tri Dao, Nimit Sohoni, Albert Gu, Matthew Eichhorn, Amit Blonder, Megan Leszczynski, Atri Rudra, Christopher Ré | contact: trid@stanford.edu
keywords: structured matrices, efficient ml, algorithms, butterfly matrices, arithmetic circuits

Weakly Supervised Disentanglement with Guarantees


Rui Shu, Yining Chen, Abhishek Kumar, Stefano Ermon, Ben Poole | contact: ruishu@stanford.edu
keywords: disentanglement, generative models, weak supervision, representation learning, theory

Depth-Width Trade-offs for ReLU Networks via Sharkovsky’s Theorem


Vaggos Chatziafratis, Sai Ganesh Nagarajan, Ioannis Panageas, Xiao Wang | contact: vaggos@cs.stanford.edu
keywords: dynamical systems, benefits of depth, expressivity

Watch, Try, Learn: Meta-Learning from Demonstrations and Reward


Allan Zhou, Eric Jang, Daniel Kappler, Alex Herzog, Mohi Khansari, Paul Wohlhart, Yunfei Bai, Mrinal Kalakrishnan, Sergey Levine, Chelsea Finn | contact: ayz@stanford.edu
keywords: imitation learning, meta-learning, reinforcement learning

Assessing robustness to noise: low-cost head CT triage


Sarah Hooper, Jared Dunnmon, Matthew Lungren, Sanjiv Sam Gambhir, Christopher Ré, Adam Wang, Bhavik Patel | contact: smhooper@stanford.edu
keywords: ai for affordable healthcare workshop, medical imaging, sinogram, ct, image noise

Learning transport cost from subset correspondence


Ruishan Liu, Akshay Balsubramani, James Zou | contact: ruishan@stanford.edu
keywords: optimal transport, data alignment, metric learning

Generalization through Memorization: Nearest Neighbor Language Models


Urvashi Khandelwal, Omer Levy, Dan Jurafsky, Luke Zettlemoyer, Mike Lewis | contact: urvashik@stanford.edu
keywords: language models, k-nearest neighbors

Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization


Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, Percy Liang | contact: ssagawa@cs.stanford.edu
keywords: distributionally robust optimization, deep learning, robustness, generalization, regularization

Phase Transitions for the Information Bottleneck in Representation Learning


Tailin Wu, Ian Fischer | contact: tailin@cs.stanford.edu
keywords: information theory, representation learning, phase transition

Improving Neural Language Generation with Spectrum Control


Lingxiao Wang, Jing Huang, Kevin Huang, Ziniu Hu, Guangtao Wang, Quanquan Gu | contact: jhuang18@stanford.edu
keywords: neural language generation, pre-trained language model, spectrum control

Understanding and Improving Information Transfer in Multi-Task Learning

paper | blog post

Sen Wu, Hongyang Zhang, Christopher Ré | contact: senwu@cs.stanford.edu
keywords: multi-task learning

Strategies for Pre-training Graph Neural Networks

paper | blog post

Weihua Hu, Bowen Liu, Joseph Gomes, Marinka Zitnik, Percy Liang, Vijay Pande, Jure Leskovec | contact: weihuahu@cs.stanford.edu
keywords: pre-training, transfer learning, graph neural networks

Query2box: Reasoning over Knowledge Graphs in Vector Space using Box Embeddings


Hongyu Ren, Weihua Hu, Jure Leskovec | contact: hyren@cs.stanford.edu
keywords: knowledge graph embeddings

Learning Self-Correctable Policies and Value Functions from Demonstrations with Negative Sampling


Yuping Luo, Huazhe Xu, Tengyu Ma | contact: roosephu@gmail.com
keywords: imitation learning, model-based imitation learning, model-based rl, behavior cloning, covariate shift

Improved Sample Complexities for Deep Neural Networks and Robust Classification via an All-Layer Margin


Colin Wei, Tengyu Ma | contact: colinwei@stanford.edu
keywords: deep learning theory, generalization bounds, adversarially robust generalization, data-dependent generalization bounds

Selection via Proxy: Efficient Data Selection for Deep Learning

paper | blog post | code

Cody Coleman, Christopher Yeh, Stephen Mussmann, Baharan Mirzasoleiman, Peter Bailis, Percy Liang, Jure Leskovec, Matei Zaharia | contact: cody@cs.stanford.edu
keywords: active learning, data selection, deep learning

We look forward to seeing you at ICLR!

Automating Data Augmentation: Practice, Theory and New Direction

Data augmentation is a de facto technique used in nearly every state-of-the-art machine learning model in applications such as image and text classification. Heuristic data augmentation schemes are often tuned manually by human experts with extensive domain knowledge, and may result in suboptimal augmentation policies. In this blog post, we provide a broad overview of recent efforts in this exciting research area. These efforts have produced new algorithms for automating the search over transformation functions and new theoretical insights into augmentation techniques commonly used in practice. They have also produced a new framework that uses data augmentation to patch a flawed model and improve performance on crucial subpopulations of data.

Why Data Augmentation?

Modern machine learning models, such as deep neural networks, may have billions of parameters and require massive labeled training datasets—which are often not available. The technique of artificially expanding labeled training datasets—known as data augmentation—has quickly become critical for combating this data scarcity problem. Today, data augmentation is used as a secret sauce in nearly every state-of-the-art model for image classification, and is becoming increasingly common in other modalities such as natural language understanding as well. The goal of this blog post is to provide an overview of recent efforts in this exciting research area.

Figure 1. Heuristic data augmentations apply a deterministic sequence of transformation functions tuned by human experts. The augmented data will be used for training downstream models.

Heuristic data augmentation schemes often rely on the composition of a set of simple transformation functions (TFs) such as rotations and flips (see Figure 1). When chosen carefully, data augmentation schemes tuned by human experts can improve model performance. However, such heuristic strategies in practice can cause large variances in end model performance, and may not produce augmentations needed for state-of-the-art models.
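A heuristic scheme of this kind is just a fixed, hand-chosen composition of simple TFs. The numpy sketch below uses the rotation and flip named above as stand-ins. The helper names are our own, chosen for illustration.

```python
from typing import Callable, List

import numpy as np

# A transformation function (TF) maps an image array to an image array.
TF = Callable[[np.ndarray], np.ndarray]


def rotate90(img: np.ndarray) -> np.ndarray:
    """Rotate the image 90 degrees counterclockwise."""
    return np.rot90(img)


def hflip(img: np.ndarray) -> np.ndarray:
    """Mirror the image left to right."""
    return np.fliplr(img)


def compose(tfs: List[TF]) -> TF:
    """Chain TFs, in order, into one hand-fixed augmentation."""
    def apply(img: np.ndarray) -> np.ndarray:
        for tf in tfs:
            img = tf(img)
        return img
    return apply


heuristic = compose([rotate90, hflip])
image = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
augmented = heuristic(image)
print(augmented.tolist())  # [[5, 2], [4, 1], [3, 0]]
```

The order and choice of TFs here are fixed by hand. That hand-tuning is exactly what the learnable methods below try to replace.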

The Open Challenges in Data Augmentation

The limitations of conventional data augmentation approaches reveal huge opportunities for research advances. Below we summarize a few challenges that motivate some of the works in the area of data augmentation.

  • From manual to automated search algorithms: As opposed to performing suboptimal manual search, how can we design learnable algorithms to find augmentation strategies that can outperform human-designed heuristics?

  • From practical to theoretical understanding: Despite the rapid progress of creating various augmentation approaches pragmatically, understanding their benefits remains a mystery because of a lack of analytic tools. How can we theoretically understand various data augmentations used in practice?

  • From coarse-grained to fine-grained model quality assurance: While most existing data augmentation approaches focus on improving the overall performance of a model, it is often imperative to have a finer-grained perspective on critical subpopulations of data. When a model exhibits inconsistent predictions on important subgroups of data, how can we exploit data augmentations to mitigate the performance gap in a prescribed way?

In this blog post, we describe ideas and recent research that lead the way toward overcoming the challenges above.

Practical Methods of Learnable Data Augmentations

Learnable data augmentation is promising, in that it allows us to search for more powerful parameterizations and compositions of transformations. Perhaps the biggest difficulty with automating data augmentation is how to search over the space of transformations. This can be prohibitive due to the large number of transformation functions and associated parameters in the search space. How can we design learnable algorithms that explore the space of transformation functions efficiently and effectively, and find augmentation strategies that can outperform human-designed heuristics? In response to the challenge, we highlight a few recent methods below.

TANDA: Transformation Adversarial Networks for Data Augmentations

To address this problem, TANDA (Ratner et al. 2017) proposes a framework to learn augmentations, which models data augmentations as sequences of Transformation Functions (TFs) provided by users. For example, these might include “rotate 5 degrees” or “shift by 2 pixels”. At its core, this framework consists of two components: (1) learning a TF sequence generator that results in useful augmented data points, and (2) using the sequence generator to augment training sets for a downstream model. In particular, the TF sequence generator is trained to produce realistic images by having to fool a discriminator network, following the GANs framework (Goodfellow et al. 2014). The underlying assumption here is that the transformations will either lead to realistic images or to clearly distinguishable garbage images that are off the manifold. As shown in Figure 2, the objective for the generator is to produce sequences of TFs such that the augmented data point can fool the discriminator. The objective for the discriminator, by contrast, is to produce values close to 1 for data points in the original training set and values close to 0 for augmented data points.

Figure 2. Automating data augmentation with TANDA (Ratner et al. 2017). A TF sequence generator is trained adversarially to produce augmented images that are realistic compared to training data.
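The adversarial setup described above can be written in the usual GAN form (Goodfellow et al. 2014). This is a sketch in our own notation, not TANDA's exact loss. Here $G_\theta$ is the TF sequence generator and $\tau = (t_1, \dots, t_L)$ is a sampled TF sequence. Its composition is $h_\tau = t_L \circ \cdots \circ t_1$, $D_\phi$ is the discriminator, and $\mathcal{D}$ is the training set.

```latex
% Discriminator: score originals near 1 and augmented points near 0
\max_{\phi}\;
  \mathbb{E}_{x \sim \mathcal{D}}\big[\log D_\phi(x)\big]
  + \mathbb{E}_{x \sim \mathcal{D},\, \tau \sim G_\theta}\big[\log\big(1 - D_\phi(h_\tau(x))\big)\big]

% Generator: choose TF sequences whose outputs the discriminator scores as real
\min_{\theta}\;
  \mathbb{E}_{x \sim \mathcal{D},\, \tau \sim G_\theta}\big[\log\big(1 - D_\phi(h_\tau(x))\big)\big]
```

User-provided TFs are generally not differentiable. The generator's gradient therefore has to be estimated rather than backpropagated through $h_\tau$, for example with a score-function (REINFORCE-style) estimator. See the paper for how TANDA handles this in detail.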

AutoAugment and Further Improvements

Using a similar framework, AutoAugment (Cubuk et al. 2018), developed by Google, demonstrated state-of-the-art performance using learned augmentation policies. In this work, a TF sequence generator learns to directly optimize for validation accuracy on the end model. Several subsequent works, including RandAugment (Cubuk et al. 2019) and Adversarial AutoAugment (Zhang et al. 2019), have been proposed to reduce the computational cost of AutoAugment, establishing new state-of-the-art performance on image classification benchmarks.
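RandAugment's main simplification over AutoAugment is to drop the learned policy. It samples N transformations uniformly from a fixed list and applies each one at a single shared magnitude M, so only N and M need tuning. Below is a minimal numpy sketch of that idea. The five ops and their magnitude scalings are our own simplified stand-ins, not the paper's operation list.

```python
import random
from typing import Callable, List

import numpy as np

# Each op takes an image with values in [0, 1] and an integer magnitude m in [0, 10].
Op = Callable[[np.ndarray, int], np.ndarray]


def identity(img: np.ndarray, m: int) -> np.ndarray:
    return img


def hflip(img: np.ndarray, m: int) -> np.ndarray:
    # Magnitude is ignored for a flip.
    return np.fliplr(img)


def translate_x(img: np.ndarray, m: int) -> np.ndarray:
    # Shift right by m pixels; np.roll wraps around, a simplification of padding.
    return np.roll(img, shift=m, axis=1)


def brightness(img: np.ndarray, m: int) -> np.ndarray:
    # Brighten by up to +0.5 at m = 10, clipped to the valid range.
    return np.clip(img + 0.05 * m, 0.0, 1.0)


def contrast(img: np.ndarray, m: int) -> np.ndarray:
    # Stretch values away from the mean by a factor of (1 + 0.1 * m).
    mean = img.mean()
    return np.clip((img - mean) * (1.0 + 0.1 * m) + mean, 0.0, 1.0)


OPS: List[Op] = [identity, hflip, translate_x, brightness, contrast]


def rand_augment(img: np.ndarray, n: int, m: int, rng: random.Random) -> np.ndarray:
    """Apply n ops drawn uniformly (with replacement), all at shared magnitude m."""
    for op in rng.choices(OPS, k=n):
        img = op(img, m)
    return img


image = np.linspace(0.0, 1.0, 16).reshape(4, 4)
out = rand_augment(image, n=2, m=5, rng=random.Random(0))
print(out.shape)  # (4, 4)
```

Because the search space collapses to two integers, N and M can be tuned with a plain grid search instead of a separate policy-learning stage.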

Theoretical Understanding of Data Augmentations

Despite the rapid progress of practical data augmentation techniques, precisely understanding their benefits remains a mystery. Even for simpler models, it is not well-understood how training on augmented data affects the learning process, the parameters, and the decision surface. This is exacerbated by the fact that data augmentation is performed in diverse ways in modern machine learning pipelines, for different tasks and domains, thus precluding a general model of transformation. How can we theoretically characterize and understand the effect of various data augmentations used in practice? To address this challenge, our lab has studied data augmentation from a kernel perspective, as well as under a simplified linear setting.

Data Augmentation As a Kernel

Dao et al. 2019 developed a theoretical framework by modeling data augmentation as a Markov Chain, in which augmentation is performed via a random sequence of transformations, akin to how data augmentation is performed in practice. We show that the effect of applying the Markov Chain on the training dataset (combined with a k-nearest neighbor classifier) is akin to using a kernel classifier, where the kernel is a function of the base transformations.
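The toy sketch below makes the Markov-chain view concrete. At each step, the chain stops with probability `stop_prob` and otherwise applies a uniformly random TF. Every chain output keeps the original label, and a 1-nearest-neighbor classifier then runs on the augmented set. The chain parameters, the rotation-only TF set, and the data are our illustrative choices. The paper's analysis concerns the resulting kernel, which this code does not compute.

```python
import random
from typing import Callable, List, Tuple

import numpy as np

TF = Callable[[np.ndarray], np.ndarray]
Labeled = Tuple[np.ndarray, int]


def rot90(p: np.ndarray) -> np.ndarray:
    """Rotate a 2-D point 90 degrees counterclockwise (exact for integers)."""
    return np.array([-p[1], p[0]])


def markov_augment(x: np.ndarray, tfs: List[TF], stop_prob: float,
                   rng: random.Random) -> np.ndarray:
    """Random walk: stop with probability stop_prob, else apply a uniformly chosen TF."""
    while rng.random() >= stop_prob:
        x = rng.choice(tfs)(x)
    return x


def augment_dataset(data: List[Labeled], tfs: List[TF], stop_prob: float,
                    chains_per_point: int, rng: random.Random) -> List[Labeled]:
    """Keep the originals and add chain outputs, each with its source label."""
    out = list(data)
    for x, y in data:
        out.extend((markov_augment(x, tfs, stop_prob, rng), y)
                   for _ in range(chains_per_point))
    return out


def nearest_neighbor(train: List[Labeled], query: np.ndarray) -> int:
    return min(train, key=lambda xy: np.linalg.norm(xy[0] - query))[1]


# Toy task: the label says which circle (radius 1 or radius 3) a point lies near.
train = [(np.array([1, 0]), 0), (np.array([3, 0]), 1)]
query = np.array([0.0, 2.1])  # radius 2.1, nearer the radius-3 circle

print(nearest_neighbor(train, query))  # 0: misled by position on the plane
augmented = augment_dataset(train, [rot90], stop_prob=0.5,
                            chains_per_point=200, rng=random.Random(0))
print(nearest_neighbor(augmented, query))  # 1, once (0, 3) appears in the augmented set
```

With 200 chains per point, the rotated point (0, 3) is generated with overwhelming probability. The classifier then behaves as if it were rotation invariant, which matches the invariance-inducing effect discussed next.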

Built on the connection between kernel theory and data augmentation, Dao et al. 2019 show that a kernel classifier on augmented data approximately decomposes into two components: (i) an averaged version of the transformed features, and (ii) a data-dependent variance regularization term. This suggests a more nuanced explanation of data augmentation—namely, that it improves generalization both by inducing invariance and by reducing model complexity. Dao et al. 2019 validate the quality of our approximation empirically, and draw connections to other generalization-improving techniques, including recent work on invariant learning (van der Wilk et al. 2018) and robust optimization (Namkoong & Duchi, 2017).
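A second-order Taylor expansion can motivate the two-component decomposition. What follows is a sketch in our own notation, not necessarily the paper's exact statement or assumptions. Here $\phi$ is a feature map, $T$ is a random transformation drawn by the augmentation process, and $\bar\psi(x) = \mathbb{E}_T[\phi(T(x))]$ is the averaged transformed feature. The loss $\ell(\cdot\,; y)$ of a linear model with weights $w$ is assumed twice differentiable.

```latex
% Expand the loss around the averaged feature, \bar z = w^\top \bar\psi(x):
\ell\big(w^\top \phi(T(x)); y\big)
  \approx \ell(\bar z; y)
  + \ell'(\bar z; y)\, w^\top\big(\phi(T(x)) - \bar\psi(x)\big)
  + \tfrac{1}{2}\, \ell''(\bar z; y)\, \big(w^\top(\phi(T(x)) - \bar\psi(x))\big)^2

% Take the expectation over T; the first-order term vanishes because
% \mathbb{E}_T[\phi(T(x))] = \bar\psi(x):
\mathbb{E}_T\Big[\ell\big(w^\top \phi(T(x)); y\big)\Big]
  \approx \underbrace{\ell\big(w^\top \bar\psi(x); y\big)}_{\text{(i) averaged features}}
  + \underbrace{\tfrac{1}{2}\, \ell''\big(w^\top \bar\psi(x); y\big)\,
      w^\top \operatorname{Cov}_T\big[\phi(T(x))\big]\, w}_{\text{(ii) data-dependent variance regularizer}}
```

For a convex loss, $\ell'' \ge 0$, so term (ii) penalizes weights along the directions in which the transformations move the features. This is one way to read "reducing model complexity".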

Data Augmentation Under A Simplified Linear Setting

One limitation of the above works is that it is challenging to pin down the effect of applying a particular transformation on the resulting kernel. Furthermore, it is not yet clear how to apply data augmentation efficiently on kernel methods to get comparable performance to neural nets. In more recent work, we consider a simpler linear setting that is capable of modeling a wide range of linear transformations commonly used in image augmentation, as shown in Figure 3.

Theoretical Insights. We offer several theoretical insights by considering an over-parametrized linear model, where the training data lies in a low-dimensional subspace. We show that label-invariant transformations can add new information to the training data, and that the estimation error of the ridge estimator can be reduced by adding new points that are outside the span of the training data. In addition, we show that mixup (Zhang et al., 2017) can have a regularization effect by shrinking the weight of the training data relative to the L2 regularization term on the training data.

Figure 3. Illustration of common linear transformations applied in data augmentation.
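A small numeric example illustrates the first theoretical insight above. The numbers are our own toy values, not the paper's. The true weights are $\beta = (1, 2, 3)$, and both training points lie on the first coordinate axis. A label-invariant transformation maps $x = (1, 0, 0)$ to $(-1, 1, 0)$, which keeps the same label because both give $x^\top\beta = 1$. Adding that point, which lies outside the training span, lowers the ridge estimator's error.

```python
import numpy as np


def ridge(X: np.ndarray, y: np.ndarray, lam: float) -> np.ndarray:
    """Ridge estimator (X^T X + lam I)^{-1} X^T y."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)


beta = np.array([1.0, 2.0, 3.0])          # true weights (toy values)
X = np.array([[1.0, 0.0, 0.0],
              [2.0, 0.0, 0.0]])           # training data spans only the first axis
y = X @ beta

# Label-invariant transform of x = (1, 0, 0): (-1, 1, 0) @ beta == 1 == (1, 0, 0) @ beta
x_new = np.array([[-1.0, 1.0, 0.0]])
y_new = np.array([1.0])                   # reuse the original label

lam = 1e-6
err_base = np.linalg.norm(ridge(X, y, lam) - beta)
err_aug = np.linalg.norm(
    ridge(np.vstack([X, x_new]), np.concatenate([y, y_new]), lam) - beta)
print(round(err_base, 3), round(err_aug, 3))  # 3.606 3.0
```

Without the new point, the estimator recovers only the first coordinate of $\beta$. With it, the second coordinate is recovered as well. The third coordinate stays unidentified because no data point has a component along it.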

Theory-inspired New State-of-the-art. One insight from our theoretical investigation is that different (compositions of) transformations show very different end performance. Inspired by this observation, we’d like to make use of the fact that certain transformations are better performing than others. We propose an uncertainty-based random sampling scheme which, among the transformed data points, picks those with the highest losses, i.e. those “providing the most information” (see Figure 4). Our sampling scheme achieves higher accuracy by finding more useful transformations compared to RandAugment on three different CNN architectures, establishing new state-of-the-art performance on common benchmarks. For example, our method outperforms RandAugment by 0.59% on CIFAR-10 and 1.24% on CIFAR-100 using Wide-ResNet-28-10. Please check out our full paper here. Our code will be released soon for you to try out!

Figure 4. Uncertainty-based random sampling scheme for data augmentation. Each transformation function is randomly sampled from a set of pre-specified operations. We select among the transformed data points with highest loss for end model training.
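The selection step can be sketched as follows. The scalar "data point", the three toy TFs, and the identity "end model" behind `squared_loss` are placeholders chosen so the example runs on its own. In practice the loss would come from the CNN being trained, and the paper's exact sampling and selection details may differ.

```python
import heapq
import random
from typing import Callable, List

TF = Callable[[float], float]


def select_hardest(x: float, y: float, tfs: List[TF],
                   loss_fn: Callable[[float, float], float],
                   num_samples: int, depth: int, num_keep: int,
                   rng: random.Random) -> List[float]:
    """Sample random TF compositions of length `depth`; keep the highest-loss outputs."""
    candidates = []
    for _ in range(num_samples):
        z = x
        for tf in rng.choices(tfs, k=depth):  # each TF drawn uniformly at random
            z = tf(z)
        candidates.append(z)
    # Transformed points keep label y; rank them by the current model's loss.
    return heapq.nlargest(num_keep, candidates, key=lambda c: loss_fn(c, y))


def squared_loss(pred: float, y: float) -> float:
    """Toy 'end model' that predicts its input unchanged, scored by squared error."""
    return (pred - y) ** 2


tfs: List[TF] = [lambda v: v + 1.0, lambda v: v * 3.0, lambda v: v - 0.5]
kept = select_hardest(1.0, 1.0, tfs, squared_loss, num_samples=60,
                      depth=1, num_keep=1, rng=random.Random(0))
print(kept)  # [3.0], assuming the "times 3" TF was sampled at least once
```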

New Direction: Data Augmentations for Model Patching

Most machine learning research carried out today is still solving fixed tasks. However, in the real world, machine learning models in deployment can fail due to unanticipated changes in data distribution. This raises the concerning question of how we can move from model building to model maintenance in an adaptive manner. In our latest work, we propose model patching—the first framework that exploits data augmentation to mitigate the performance issues of a flawed model in deployment.

A Medical Use Case of Model Patching

To provide a concrete example, in skin cancer detection, researchers have shown that standard classifiers have drastically different performance on two subgroups of the cancerous class, due to the classifier’s association of colorful bandages with benign images (see Figure 5, left). This subgroup performance gap has also been studied in parallel research from our group (Oakden-Rayner et al., 2019), and arises due to the classifier’s reliance on subgroup-specific features, e.g. colorful bandages.

Figure 5: A standard model trained on a skin cancer dataset exhibits a subgroup performance gap between images of malignant cancers with and without colored bandages. GradCAM illustrates that the vanilla model spuriously associates the colored spot with benign skin lesions. With model patching, the malignancy is predicted correctly for both subgroups.

In order to fix such flaws in a deployed model, domain experts have to resort to manual data cleaning to erase the differences between subgroups, e.g. removing markings on skin cancer data with Photoshop (Winkler et al. 2019), and retrain the model with the modified data. This can be extremely laborious! Can we somehow learn transformations that allow augmenting examples to balance population among groups in a prescribed way? This is exactly what we are addressing through this new framework of model patching.

CLAMP: Class-conditional Learned Augmentations for Model Patching

The conceptual framework of model patching consists of two stages (as shown in Figure 6).

  • Learn inter-subgroup transformations between different subgroups. These transformations are class-preserving maps that allow semantically changing a datapoint’s subgroup identity (e.g. add or remove colorful bandages).
  • Retrain to patch the model with augmented data, encouraging the classifier to be robust to their variations.

Figure 6: Model Patching framework with data augmentation. The highlighted box contains samples from a class with differing performance between subgroups A and B. Conditional generative models are trained to transform examples from one subgroup to another (A->B and B->A) respectively.

We propose CLAMP, the first end-to-end instantiation of our model patching framework. We combine a novel consistency regularizer with a robust training objective inspired by recent work on Group Distributionally Robust Optimization (GDRO, Sagawa et al. 2019). We extend GDRO to a class-conditional training objective that jointly optimizes for the worst-subgroup performance in each class. CLAMP is able to balance the performance of subgroups within each class, reducing the performance gap by up to 24x. On the ISIC skin cancer detection dataset, CLAMP improves robust accuracy by 11.7% compared to the robust training baseline. Through visualization, we also show in Figure 5 that CLAMP successfully removes the model’s reliance on the spurious feature (colorful bandages), shifting its attention to the skin lesion, the true feature of interest.
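The class-conditional objective can be illustrated with a simplified Python sketch. Several assumptions are made here: predicted class probabilities are plain lists of strictly positive values; the robust term takes the hard maximum of mean cross-entropy over the subgroups within each class and averages those maxima across classes (GDRO itself optimizes the worst case through online group reweighting rather than a direct max); and the consistency regularizer is a hypothetical symmetric KL divergence between predictions on an original example and on its augmented counterpart, weighted by an assumed `lam`. This is not CLAMP's exact regularizer or training loop.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def cross_entropy(probs: Sequence[float], label: int) -> float:
    """Negative log-likelihood of the true class."""
    return -math.log(probs[label])


def class_conditional_worst_group_loss(
    probs: List[Sequence[float]], labels: List[int], groups: List[str]
) -> float:
    """For each class, take the worst (highest) mean loss over its
    subgroups, then average these worst-case losses across classes."""
    buckets: Dict[Tuple[int, str], List[float]] = defaultdict(list)
    for p, y, g in zip(probs, labels, groups):
        buckets[(y, g)].append(cross_entropy(p, y))

    worst: Dict[int, float] = {}
    for (y, _g), losses in buckets.items():
        mean_loss = sum(losses) / len(losses)
        worst[y] = max(worst.get(y, -math.inf), mean_loss)
    return sum(worst.values()) / len(worst)


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); assumes q is strictly positive wherever p is."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def consistency_penalty(
    probs_orig: List[Sequence[float]], probs_aug: List[Sequence[float]]
) -> float:
    """Hypothetical regularizer: mean symmetric KL between predictions on
    each original example and on its augmented counterpart."""
    total = sum(
        0.5 * (kl_divergence(p, q) + kl_divergence(q, p))
        for p, q in zip(probs_orig, probs_aug)
    )
    return total / len(probs_orig)


def patching_objective(probs, labels, groups, probs_aug, lam: float = 1.0) -> float:
    """Robust class-conditional loss plus a weighted consistency term."""
    robust = class_conditional_worst_group_loss(probs, labels, groups)
    return robust + lam * consistency_penalty(probs, probs_aug)


# Toy predictions for two classes, each split into subgroups A and B.
probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.3, 0.7]]
labels = [0, 0, 1, 1]
groups = ["A", "B", "A", "B"]
probs_aug = [[0.7, 0.3], [0.6, 0.4], [0.3, 0.7], [0.3, 0.7]]

print(patching_objective(probs, labels, groups, probs_aug, lam=0.5))
```

Taking the worst subgroup per class, rather than the worst group overall, means that a well-performing class cannot hide a poorly served subgroup in another class: every class contributes its own worst case to the objective.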

Our results suggest that the model patching framework is a promising direction for automating the process of model maintenance. Model patching is an emerging area that could help alleviate major problems in safety-critical systems, including healthcare (e.g. improving models to produce MRI scans free of artifacts) and autonomous driving (e.g. improving perception models that may perform poorly on irregular objects or road conditions). We envision that model patching can be widely useful in many other application domains. If you are intrigued by the latest research on model patching, please follow our Hazy Research repository on GitHub, where the code will be released soon. If you have any feedback on our drafts and latest work, we’d like to hear from you!

Further Reading


Thanks to members of Hazy Research who provided feedback on the blog post. Special thanks to Sidd Karamcheti and Andrey Kurenkov from the SAIL blog team for the editorial help.

About the Author

Sharon Y. Li is a postdoctoral fellow at Stanford, working with Chris Ré. She is an incoming Assistant Professor in the Department of Computer Sciences at the University of Wisconsin-Madison. Her research focuses on developing machine learning models and systems that can reduce human supervision during training and enhance reliability during deployment in the wild.
