Google at ICLR 2022

The 10th International Conference on Learning Representations (ICLR 2022) kicks off this week, bringing together researchers, entrepreneurs, engineers and students alike to discuss and explore the rapidly advancing field of deep learning. Entirely virtual this year, ICLR 2022 offers conference and workshop tracks that present some of the latest research in deep learning and its applications to areas ranging from computer vision, speech recognition and text understanding to robotics, computational biology, and more.

As a Platinum Sponsor of ICLR 2022 and Champion DEI Action Fund contributor, Google will have a robust presence with nearly 100 accepted publications and extensive participation on organizing committees and in workshops. If you have registered for ICLR 2022, we hope you’ll watch our talks and learn about the work done at Google to address complex problems that affect billions of people. Here you can learn more about the research we will be presenting as well as our general involvement at ICLR 2022 (those with Google affiliations in bold).

Senior Area Chairs:
Includes: Been Kim, Dale Schuurmans, Sergey Levine

Area Chairs:
Includes: Adam White, Aditya Menon, Aleksandra Faust, Amin Karbasi, Amir Globerson, Andrew Dai, Balaji Lakshminarayanan, Behnam Neyshabur, Ben Poole, Bhuwan Dhingra, Bo Dai, Boqing Gong, Cristian Sminchisescu, David Ha, David Woodruff, Denny Zhou, Dipanjan Das, Dumitru Erhan, Dustin Tran, Emma Strubell, Eunsol Choi, George Dahl, George Tucker, Hanie Sedghi, Heinrich Jiang, Hossein Mobahi, Hugo Larochelle, Izhak Shafran, Jasper Snoek, Jean-Philippe Vert, Jeffrey Pennington, Justin Gilmer, Karol Hausman, Kevin Swersky, Krzysztof Choromanski, Mathieu Blondel, Matt Kusner, Michael Ryoo, Ming-Hsuan Yang, Minmin Chen, Mirella Lapata, Mohammad Ghavamzadeh, Mohammad Norouzi, Naman Agarwal, Nicholas Carlini, Olivier Bachem, Piyush Rai, Prateek Jain, Quentin Berthet, Richard Nock, Rose Yu, Sewoong Oh, Silvio Lattanzi, Slav Petrov, Srinadh Bhojanapalli, Tim Salimans, Ting Chen, Tong Zhang, Vikas Sindhwani, Weiran Wang, William Cohen, Xiaoming Liu

Workflow Chairs:
Includes: Yaguang Li

Diversity Equity & Inclusion Chairs:
Includes: Rosanne Liu

Invited Talks
Beyond Interpretability: Developing a Language to Shape Our Relationships with AI
Google Speaker: Been Kim

Do You See What I See? Large-Scale Learning from Multimodal Videos
Google Speaker: Cordelia Schmid

Publications
Hyperparameter Tuning with Renyi Differential Privacy – 2022 Outstanding Paper Award
Nicolas Papernot, Thomas Steinke

MIDI-DDSP: Detailed Control of Musical Performance via Hierarchical Modeling
Yusong Wu, Ethan Manilow, Yi Deng, Rigel Swavely, Kyle Kastner, Tim Cooijmans, Aaron Courville, Cheng-Zhi Anna Huang, Jesse Engel

The Information Geometry of Unsupervised Reinforcement Learning
Benjamin Eysenbach, Ruslan Salakhutdinov, Sergey Levine

Learning Strides in Convolutional Neural Networks – 2022 Outstanding Paper Award
Rachid Riad*, Olivier Teboul, David Grangier, Neil Zeghidour

Poisoning and Backdooring Contrastive Learning
Nicholas Carlini, Andreas Terzis

Coordination Among Neural Modules Through a Shared Global Workspace
Anirudh Goyal, Aniket Didolkar, Alex Lamb, Kartikeya Badola, Nan Rosemary Ke, Nasim Rahaman, Jonathan Binas, Charles Blundell, Michael Mozer, Yoshua Bengio

Fine-Tuned Language Models Are Zero-Shot Learners (see the blog post)
Jason Wei, Maarten Bosma, Vincent Y. Zhao, Kelvin Guu, Adams Wei Yu, Brian Lester, Nan Du, Andrew M. Dai, Quoc V. Le

Large Language Models Can Be Strong Differentially Private Learners
Xuechen Li, Florian Tramèr, Percy Liang, Tatsunori Hashimoto

Progressive Distillation for Fast Sampling of Diffusion Models
Tim Salimans, Jonathan Ho

Exploring the Limits of Large Scale Pre-training
Samira Abnar, Mostafa Dehghani, Behnam Neyshabur, Hanie Sedghi

Scarf: Self-Supervised Contrastive Learning Using Random Feature Corruption
Dara Bahri, Heinrich Jiang, Yi Tay, Donald Metzler

Scalable Sampling for Nonsymmetric Determinantal Point Processes
Insu Han, Mike Gartrell, Jennifer Gillenwater, Elvis Dohmatob, Amin Karbasi

When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations
Xiangning Chen, Cho-Jui Hsieh, Boqing Gong

ViTGAN: Training GANs with Vision Transformers
Kwonjoon Lee, Huiwen Chang, Lu Jiang, Han Zhang, Zhuowen Tu, Ce Liu

Generalized Decision Transformer for Offline Hindsight Information Matching
Hiroki Furuta, Yutaka Matsuo, Shixiang Shane Gu

The MultiBERTs: BERT Reproductions for Robustness Analysis
Thibault Sellam, Steve Yadlowsky, Ian Tenney, Jason Wei, Naomi Saphra, Alexander D’Amour, Tal Linzen, Jasmijn Bastings, Iulia Turc, Jacob Eisenstein, Dipanjan Das, Ellie Pavlick

Scaling Laws for Neural Machine Translation
Behrooz Ghorbani, Orhan Firat, Markus Freitag, Ankur Bapna, Maxim Krikun, Xavier Garcia, Ciprian Chelba, Colin Cherry

Interpretable Unsupervised Diversity Denoising and Artefact Removal
Mangal Prakash, Mauricio Delbracio, Peyman Milanfar, Florian Jug

Understanding Latent Correlation-Based Multiview Learning and Self-Supervision: An Identifiability Perspective
Qi Lyu, Xiao Fu, Weiran Wang, Songtao Lu

Memorizing Transformers
Yuhuai Wu, Markus N. Rabe, DeLesley Hutchins, Christian Szegedy

Churn Reduction via Distillation
Heinrich Jiang, Harikrishna Narasimhan, Dara Bahri, Andrew Cotter, Afshin Rostamizadeh

DR3: Value-Based Deep Reinforcement Learning Requires Explicit Regularization
Aviral Kumar, Rishabh Agarwal, Tengyu Ma, Aaron Courville, George Tucker, Sergey Levine

Path Auxiliary Proposal for MCMC in Discrete Space
Haoran Sun, Hanjun Dai, Wei Xia, Arun Ramamurthy

On the Relation Between Statistical Learning and Perceptual Distances
Alexander Hepburn, Valero Laparra, Raul Santos-Rodriguez, Johannes Ballé, Jesús Malo

Possibility Before Utility: Learning And Using Hierarchical Affordances
Robby Costales, Shariq Iqbal, Fei Sha

MT3: Multi-Task Multitrack Music Transcription
Josh Gardner*, Ian Simon, Ethan Manilow*, Curtis Hawthorne, Jesse Engel

Bayesian Neural Network Priors Revisited
Vincent Fortuin, Adrià Garriga-Alonso, Sebastian W. Ober, Florian Wenzel, Gunnar Rätsch, Richard E. Turner, Mark van der Wilk, Laurence Aitchison

GradMax: Growing Neural Networks using Gradient Information
Utku Evci, Bart van Merrienboer, Thomas Unterthiner, Fabian Pedregosa, Max Vladymyrov

Scene Transformer: A Unified Architecture for Predicting Future Trajectories of Multiple Agents
Jiquan Ngiam, Benjamin Caine, Vijay Vasudevan, Zhengdong Zhang, Hao-Tien Lewis Chiang, Jeffrey Ling, Rebecca Roelofs, Alex Bewley, Chenxi Liu, Ashish Venugopal, David Weiss, Ben Sapp, Zhifeng Chen, Jonathon Shlens

The Role of Pretrained Representations for the OOD Generalization of RL Agents
Frederik Träuble, Andrea Dittadi, Manuel Wüthrich, Felix Widmaier, Peter Gehler, Ole Winther, Francesco Locatello, Olivier Bachem, Bernhard Schölkopf, Stefan Bauer

Autoregressive Diffusion Models
Emiel Hoogeboom, Alexey A. Gritsenko, Jasmijn Bastings, Ben Poole, Rianne van den Berg, Tim Salimans

The Role of Permutation Invariance in Linear Mode Connectivity of Neural Networks
Rahim Entezari, Hanie Sedghi, Olga Saukh, Behnam Neyshabur

DISSECT: Disentangled Simultaneous Explanations via Concept Traversals
Asma Ghandeharioun, Been Kim, Chun-Liang Li, Brendan Jou, Brian Eoff, Rosalind W. Picard

Anisotropic Random Feature Regression in High Dimensions
Gabriel C. Mel, Jeffrey Pennington

Open-Vocabulary Object Detection via Vision and Language Knowledge Distillation
Xiuye Gu, Tsung-Yi Lin*, Weicheng Kuo, Yin Cui

MCMC Should Mix: Learning Energy-Based Model with Flow-Based Backbone
Erik Nijkamp*, Ruiqi Gao, Pavel Sountsov, Srinivas Vasudevan, Bo Pang, Song-Chun Zhu, Ying Nian Wu

Effect of Scale on Catastrophic Forgetting in Neural Networks
Vinay Ramasesh, Aitor Lewkowycz, Ethan Dyer

Incremental False Negative Detection for Contrastive Learning
Tsai-Shien Chen, Wei-Chih Hung, Hung-Yu Tseng, Shao-Yi Chien, Ming-Hsuan Yang

Towards Evaluating the Robustness of Neural Networks Learned by Transduction
Jiefeng Chen, Xi Wu, Yang Guo, Yingyu Liang, Somesh Jha

What Do We Mean by Generalization in Federated Learning?
Honglin Yuan*, Warren Morningstar, Lin Ning, Karan Singhal

ViDT: An Efficient and Effective Fully Transformer-Based Object Detector
Hwanjun Song, Deqing Sun, Sanghyuk Chun, Varun Jampani, Dongyoon Han, Byeongho Heo, Wonjae Kim, Ming-Hsuan Yang

Measuring CLEVRness: Black-Box Testing of Visual Reasoning Models
Spyridon Mouselinos, Henryk Michalewski, Mateusz Malinowski

Wisdom of Committees: An Overlooked Approach To Faster and More Accurate Models (see the blog post)
Xiaofang Wang, Dan Kondratyuk, Eric Christiansen, Kris M. Kitani, Yair Alon (prev. Movshovitz-Attias), Elad Eban

Leveraging Unlabeled Data to Predict Out-of-Distribution Performance
Saurabh Garg*, Sivaraman Balakrishnan, Zachary C. Lipton, Behnam Neyshabur, Hanie Sedghi

Data-Driven Offline Optimization for Architecting Hardware Accelerators (see the blog post)
Aviral Kumar, Amir Yazdanbakhsh, Milad Hashemi, Kevin Swersky, Sergey Levine

Diurnal or Nocturnal? Federated Learning of Multi-branch Networks from Periodically Shifting Distributions
Chen Zhu*, Zheng Xu, Mingqing Chen, Jakub Konecny, Andrew Hard, Tom Goldstein

Policy Gradients Incorporating the Future
David Venuto, Elaine Lau, Doina Precup, Ofir Nachum

Discrete Representations Strengthen Vision Transformer Robustness
Chengzhi Mao*, Lu Jiang, Mostafa Dehghani, Carl Vondrick, Rahul Sukthankar, Irfan Essa

SimVLM: Simple Visual Language Model Pretraining with Weak Supervision (see the blog post)
Zirui Wang, Jiahui Yu, Adams Wei Yu, Zihang Dai, Yulia Tsvetkov, Yuan Cao

Neural Stochastic Dual Dynamic Programming
Hanjun Dai, Yuan Xue, Zia Syed, Dale Schuurmans, Bo Dai

PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions
Zhaoqi Leng, Mingxing Tan, Chenxi Liu, Ekin Dogus Cubuk, Xiaojie Shi, Shuyang Cheng, Dragomir Anguelov

Information Prioritization Through Empowerment in Visual Model-Based RL
Homanga Bharadhwaj*, Mohammad Babaeizadeh, Dumitru Erhan, Sergey Levine

Value Function Spaces: Skill-Centric State Abstractions for Long-Horizon Reasoning
Dhruv Shah, Peng Xu, Yao Lu, Ted Xiao, Alexander Toshev, Sergey Levine, Brian Ichter

Understanding and Leveraging Overparameterization in Recursive Value Estimation
Chenjun Xiao, Bo Dai, Jincheng Mei, Oscar Ramirez, Ramki Gummadi, Chris Harris, Dale Schuurmans

The Efficiency Misnomer
Mostafa Dehghani, Anurag Arnab, Lucas Beyer, Ashish Vaswani, Yi Tay

On the Role of Population Heterogeneity in Emergent Communication
Mathieu Rita, Florian Strub, Jean-Bastien Grill, Olivier Pietquin, Emmanuel Dupoux

No One Representation to Rule Them All: Overlapping Features of Training Methods
Raphael Gontijo-Lopes, Yann Dauphin, Ekin D. Cubuk

Data Poisoning Won’t Save You From Facial Recognition
Evani Radiya-Dixit, Sanghyun Hong, Nicholas Carlini, Florian Tramèr

AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation
David Berthelot, Rebecca Roelofs, Kihyuk Sohn, Nicholas Carlini, Alex Kurakin

Maximum Entropy RL (Provably) Solves Some Robust RL Problems
Benjamin Eysenbach, Sergey Levine

Auto-scaling Vision Transformers Without Training
Wuyang Chen, Wei Huang, Xianzhi Du, Xiaodan Song, Zhangyang Wang, Denny Zhou

Optimizing Few-Step Diffusion Samplers by Gradient Descent
Daniel Watson, William Chan, Jonathan Ho, Mohammad Norouzi

ExT5: Towards Extreme Multi-Task Scaling for Transfer Learning
Vamsi Aribandi, Yi Tay, Tal Schuster, Jinfeng Rao, Huaixiu Steven Zheng, Sanket Vaibhav Mehta, Honglei Zhuang, Vinh Q. Tran, Dara Bahri, Jianmo Ni, Jai Gupta, Kai Hui, Sebastian Ruder, Donald Metzler

Fortuitous Forgetting in Connectionist Networks
Hattie Zhou, Ankit Vani, Hugo Larochelle, Aaron Courville

Evading Adversarial Example Detection Defenses with Orthogonal Projected Gradient Descent
Oliver Bryniarski, Nabeel Hingun, Pedro Pachuca, Vincent Wang, Nicholas Carlini

Benchmarking the Spectrum of Agent Capabilities
Danijar Hafner

Charformer: Fast Character Transformers via Gradient-Based Subword Tokenization
Yi Tay, Vinh Q. Tran, Sebastian Ruder, Jai Gupta, Hyung Won Chung, Dara Bahri, Zhen Qin, Simon Baumgartner, Cong Yu, Donald Metzler

Mention Memory: Incorporating Textual Knowledge into Transformers Through Entity Mention Attention
Michiel de Jong, Yury Zemlyanskiy, Nicholas FitzGerald, Fei Sha, William Cohen

Eigencurve: Optimal Learning Rate Schedule for SGD on Quadratic Objectives with Skewed Hessian Spectrums
Rui Pan, Haishan Ye, Tong Zhang

Scale Efficiently: Insights from Pre-training and Fine-Tuning Transformers
Yi Tay, Mostafa Dehghani, Jinfeng Rao, William Fedus, Samira Abnar, Hyung Won Chung, Sharan Narang, Dani Yogatama, Ashish Vaswani, Donald Metzler

Omni-Scale CNNs: A Simple and Effective Kernel Size Configuration for Time Series Classification
Wensi Tang, Guodong Long, Lu Liu, Tianyi Zhou, Michael Blumenstein, Jing Jiang

Embedded-Model Flows: Combining the Inductive Biases of Model-Free Deep Learning and Explicit Probabilistic Modeling
Gianluigi Silvestri, Emily Fertig, Dave Moore, Luca Ambrogioni

Post Hoc Explanations May be Ineffective for Detecting Unknown Spurious Correlation
Julius Adebayo, Michael Muelly, Hal Abelson, Been Kim

Axiomatic Explanations for Visual Search, Retrieval, and Similarity Learning
Mark Hamilton, Scott Lundberg, Stephanie Fu, Lei Zhang, William T. Freeman

Pix2seq: A Language Modeling Framework for Object Detection (see the blog post)
Ting Chen, Saurabh Saxena, Lala Li, David J. Fleet, Geoffrey Hinton

Mirror Descent Policy Optimization
Manan Tomar, Lior Shani, Yonathan Efroni, Mohammad Ghavamzadeh

CodeTrek: Flexible Modeling of Code Using an Extensible Relational Representation
Pardis Pashakhanloo, Aaditya Naik, Yuepeng Wang, Hanjun Dai, Petros Maniatis, Mayur Naik

Conditional Object-Centric Learning From Video
Thomas Kipf, Gamaleldin F. Elsayed, Aravindh Mahendran, Austin Stone, Sara Sabour, Georg Heigold, Rico Jonschkowski, Alexey Dosovitskiy, Klaus Greff

A Loss Curvature Perspective on Training Instabilities of Deep Learning Models
Justin Gilmer, Behrooz Ghorbani, Ankush Garg, Sneha Kudugunta, Behnam Neyshabur, David Cardoze, George E. Dahl, Zack Nado, Orhan Firat

Autonomous Reinforcement Learning: Formalism and Benchmarking
Archit Sharma, Kelvin Xu, Nikhil Sardana, Abhishek Gupta, Karol Hausman, Sergey Levine, Chelsea Finn

TRAIL: Near-Optimal Imitation Learning with Suboptimal Data
Mengjiao Yang, Sergey Levine, Ofir Nachum

Minimax Optimization With Smooth Algorithmic Adversaries
Tanner Fiez, Lillian J. Ratliff, Chi Jin, Praneeth Netrapalli

Unsupervised Semantic Segmentation by Distilling Feature Correspondences
Mark Hamilton, Zhoutong Zhang, Bharath Hariharan, Noah Snavely, William T. Freeman

InfinityGAN: Towards Infinite-Pixel Image Synthesis
Chieh Hubert Lin, Hsin-Ying Lee, Yen-Chi Cheng, Sergey Tulyakov, Ming-Hsuan Yang

Shuffle Private Stochastic Convex Optimization
Albert Cheu, Matthew Joseph, Jieming Mao, Binghui Peng

Hybrid Random Features
Krzysztof Choromanski, Haoxian Chen, Han Lin, Yuanzhe Ma, Arijit Sehanobish, Deepali Jain, Michael S. Ryoo, Jake Varley, Andy Zeng, Valerii Likhosherstov, Dmitry Kalashnikov, Vikas Sindhwani, Adrian Weller

Vector-Quantized Image Modeling With Improved VQGAN
Jiahui Yu, Xin Li, Jing Yu Koh, Han Zhang, Ruoming Pang, James Qin, Alexander Ku, Yuanzhong Xu, Jason Baldridge, Yonghui Wu

On the Benefits of Maximum Likelihood Estimation for Regression and Forecasting
Pranjal Awasthi, Abhimanyu Das, Rajat Sen, Ananda Theertha Suresh

Surrogate Gap Minimization Improves Sharpness-Aware Training
Juntang Zhuang*, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha C. Dvornek, Sekhar Tatikonda, James S. Duncan, Ting Liu

Online Target Q-learning With Reverse Experience Replay: Efficiently Finding the Optimal Policy for Linear MDPs
Naman Agarwal, Prateek Jain, Dheeraj Nagaraj, Praneeth Netrapalli, Syomantak Chaudhuri

CrossBeam: Learning to Search in Bottom-Up Program Synthesis
Kensen Shi, Hanjun Dai, Kevin Ellis, Charles Sutton

Workshops
Workshop on the Elements of Reasoning: Objects, Structure, and Causality (OSC)
Organizers include: Klaus Greff, Thomas Kipf

Workshop on Agent Learning in Open-Endedness
Organizers include: Krishna Srinivasan
Speakers include: Natasha Jaques, Danijar Hafner

Wiki-M3L: Wikipedia and Multi-modal & Multi-lingual Research
Organizers include: Klaus Greff, Thomas Kipf
Speakers include: Jason Baldridge, Tom Duerig

Setting Up ML Evaluation Standards to Accelerate Progress
Organizers include: Rishabh Agarwal
Speakers and Panelists include: Katherine Heller, Sara Hooker, Corinna Cortes

From Cells to Societies: Collective Learning Across Scales
Organizers include: Mark Sandler, Max Vladymyrov
Speakers include: Blaise Aguera y Arcas, Alexander Mordvintsev, Michael Mozer

Emergent Communication: New Frontiers
Speakers include: Natasha Jaques

Deep Learning for Code
Organizers include: Jonathan Herzig

GroundedML: Anchoring Machine Learning in Classical Algorithmic Theory
Speakers include: Gintare Karolina Dziugaite

Generalizable Policy Learning in the Physical World
Speakers and Panelists include: Mrinal Kalakrishnan

CoSubmitting Summer (CSS) Workshop
Organizers include: Rosanne Liu



*Work done while at Google.  


Pix2Seq: A New Language Interface for Object Detection

Object detection is a long-standing computer vision task that attempts to recognize and localize all objects of interest in an image. The complexity arises from having to identify and localize every object instance while also avoiding duplicate detections. Existing approaches, like Faster R-CNN and DETR, are carefully designed and highly customized in their choice of architecture and loss function. This specialization has created two major barriers: (1) it adds complexity in tuning and training the different parts of the system (e.g., the region proposal network, or graph matching with GIoU loss), and (2) it can reduce the ability of a model to generalize, necessitating a redesign of the model for application to other tasks.

In “Pix2Seq: A Language Modeling Framework for Object Detection”, published at ICLR 2022, we present a simple and generic method that tackles object detection from a completely different perspective. Unlike existing approaches that are task-specific, we cast object detection as a language modeling task conditioned on the observed pixel inputs. We demonstrate that Pix2Seq achieves competitive results on the large-scale COCO object detection dataset compared to existing highly specialized and well-optimized detection algorithms, and that its performance can be further improved by pre-training the model on a larger object detection dataset. To encourage further research in this direction, we are also excited to release Pix2Seq’s code and pre-trained models, along with an interactive demo, to the broader research community.

Pix2Seq Overview
Our approach is based on the intuition that if a neural network knows where and what the objects in an image are, one could simply teach it how to read them out. By learning to “describe” objects, the model can learn to ground the descriptions on pixel observations, leading to useful object representations. Given an image, the Pix2Seq model outputs a sequence of object descriptions, where each object is described using five discrete tokens: the coordinates of the bounding box’s corners [ymin, xmin, ymax, xmax] and a class label.

Pix2Seq framework for object detection. The neural network perceives an image, and generates a sequence of tokens for each object, which correspond to bounding boxes and class labels.

With Pix2Seq, we propose a quantization and serialization scheme that converts bounding boxes and class labels into sequences of discrete tokens (similar to captions), and we leverage an encoder-decoder architecture to perceive pixel inputs and generate the sequence of object descriptions. The training objective is simply the maximum likelihood of tokens conditioned on pixel inputs and the preceding tokens.
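In symbols, for an image x and a target token sequence y_1, ..., y_T, maximizing the likelihood amounts to minimizing -Σ_t log p(y_t | x, y_<t). The sketch below computes that loss with NumPy for a single sequence. It assumes the encoder-decoder has already produced a matrix of next-token logits under teacher forcing; `sequence_nll` is a hypothetical helper for illustration, not part of the released Pix2Seq code.

```python
import numpy as np

def sequence_nll(logits: np.ndarray, targets: np.ndarray) -> float:
    """Negative log-likelihood of one target token sequence.

    logits:  shape (seq_len, vocab_size); row t holds the decoder's scores for
             token t, assumed to be conditioned on the image and tokens 0..t-1.
    targets: int array of shape (seq_len,) with the ground-truth token ids.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each ground-truth token, summed over the sequence.
    token_log_probs = log_probs[np.arange(len(targets)), targets]
    return float(-token_log_probs.sum())

# Uniform logits over a 4-token vocabulary: each token costs log(4) nats.
print(sequence_nll(np.zeros((5, 4)), np.array([0, 1, 2, 3, 0])))
```

Training would average this quantity over a batch and minimize it with gradient descent; the sketch only shows the per-sequence objective.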

Sequence Construction from Object Descriptions
In commonly used object detection datasets, images contain variable numbers of objects, represented as sets of bounding boxes and class labels. In Pix2Seq, a single object, defined by a bounding box and class label, is represented as [ymin, xmin, ymax, xmax, class]. However, typical language models are designed to process discrete tokens (or integers) and cannot directly handle continuous numbers. So, instead of representing image coordinates as continuous numbers, we normalize the coordinates to lie between 0 and 1 and quantize them into one of a few hundred or a few thousand discrete bins. Each coordinate then becomes a discrete token, so each object description becomes a short token sequence, similar to an image caption, that the language model can process. Quantization is achieved by multiplying the normalized coordinate (e.g., ymin) by the number of bins minus one and rounding the result to the nearest integer (the detailed process can be found in our paper).

Quantization of the coordinates of the bounding boxes with different numbers of bins on a 480 × 640 image. With a small number of bins/tokens, such as 500 bins (∼1 pixel/bin), it achieves high precision even for small objects.
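The quantization rule described above can be sketched in a few lines of Python. This is an illustrative sketch, not the released Pix2Seq code: the helper names are hypothetical, boxes are assumed to be given in pixels as (ymin, xmin, ymax, xmax), and vertical and horizontal coordinates are normalized by the image height and width respectively. The last helper checks the caption's claim: with 500 bins on a 480 × 640 image, adjacent bins are about 0.96 pixels apart vertically and 1.28 pixels apart horizontally.

```python
def quantize(coord: float, size: float, num_bins: int) -> int:
    """Map a pixel coordinate in [0, size] to an integer bin in [0, num_bins - 1]."""
    normalized = coord / size  # normalize to [0, 1]
    # Scale by (bins - 1) and round to the nearest integer.
    # Python's round() breaks exact .5 ties toward the even integer.
    return int(round(normalized * (num_bins - 1)))

def dequantize(token: int, size: float, num_bins: int) -> float:
    """Map a bin index back to a pixel coordinate (the approximate inverse)."""
    return token / (num_bins - 1) * size

def quantize_box(box, height, width, num_bins):
    """box = (ymin, xmin, ymax, xmax) in pixels -> four integer coordinate tokens."""
    ymin, xmin, ymax, xmax = box
    return [quantize(ymin, height, num_bins), quantize(xmin, width, num_bins),
            quantize(ymax, height, num_bins), quantize(xmax, width, num_bins)]

def pixels_per_bin(size: float, num_bins: int) -> float:
    """Spacing between adjacent bins, in pixels, along one image axis."""
    return size / (num_bins - 1)

# A box on a 480 x 640 image (height x width), quantized with 500 bins.
print(quantize_box((120, 160, 360, 480), 480, 640, 500))  # [125, 125, 374, 374]
print(round(pixels_per_bin(480, 500), 2), round(pixels_per_bin(640, 500), 2))  # 0.96 1.28
```

Because the worst-case rounding error is half a bin, 500 bins keep every recovered coordinate within about a pixel of the original on an image of this size.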

After quantization, the object annotations provided with each training image are ordered into a sequence of discrete tokens (shown below). Since the order of the objects does not matter for the detection task per se, we randomize the order of objects each time an image is shown during training. We also append an End of Sequence (EOS) token at the end, because different images often have different numbers of objects and hence different sequence lengths.

The bounding boxes and class labels for objects detected in the image on the left are represented in the sequences shown on the right. A random object ordering strategy is used in our work but other approaches to ordering could also be used.
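A minimal sketch of this sequence construction follows. It assumes a hypothetical vocabulary layout in which coordinate tokens occupy ids 0 to NUM_BINS - 1, class tokens come next, and a single EOS id comes last; the released code may lay out its vocabulary differently, and NUM_CLASSES = 80 is likewise an illustrative assumption.

```python
import random

NUM_BINS = 500                # coordinate tokens: ids 0 .. NUM_BINS - 1
NUM_CLASSES = 80              # hypothetical: class tokens follow the coordinate tokens
EOS = NUM_BINS + NUM_CLASSES  # hypothetical id for the End of Sequence token

def quantize(coord, size):
    """Normalize a pixel coordinate and round it to one of NUM_BINS bins."""
    return int(round(coord / size * (NUM_BINS - 1)))

def build_sequence(objects, height, width, rng=random):
    """objects: list of ((ymin, xmin, ymax, xmax), class_id), boxes in pixels."""
    objects = list(objects)
    rng.shuffle(objects)      # a new random object order every time the image is used
    tokens = []
    for (ymin, xmin, ymax, xmax), class_id in objects:
        tokens += [quantize(ymin, height), quantize(xmin, width),
                   quantize(ymax, height), quantize(xmax, width),
                   NUM_BINS + class_id]  # 5 tokens per object
    tokens.append(EOS)        # marks the end, since object counts vary across images
    return tokens

objects = [((0, 0, 480, 640), 3), ((120, 160, 360, 480), 7)]
print(build_sequence(objects, 480, 640, random.Random(0)))
```

Passing an explicit `random.Random` makes the shuffle reproducible for debugging; during training one would draw a fresh order each epoch.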

The Model Architecture, Objective Function, and Inference
We treat the sequences constructed from object descriptions as a “dialect” and address the problem with a powerful, general language model consisting of an image encoder and an autoregressive language decoder. As in language modeling, Pix2Seq is trained with a maximum likelihood loss to predict tokens given an image and the preceding tokens. At inference time, we sample tokens from the model's likelihood, and the sampled sequence ends when the EOS token is generated. Once the sequence is generated, we split it into chunks of 5 tokens and de-quantize each chunk to extract the object descriptions (i.e., the predicted bounding boxes and class labels). It is worth noting that both the architecture and loss function are task-agnostic in that they don’t assume prior knowledge about object detection (e.g., bounding boxes). We describe how we can incorporate task-specific prior knowledge with a sequence augmentation technique in our paper.
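The post-processing step (stop at EOS, split into chunks of 5, de-quantize) can be sketched as below. The token sampling loop itself is not shown, and the vocabulary layout (coordinate tokens in [0, num_bins), class tokens after them, EOS last) is a hypothetical assumption rather than the released implementation. This sketch also does not check that each chunk has coordinate tokens in the coordinate slots, which a real decoder would need to handle.

```python
def parse_sequence(tokens, height, width, num_bins=500, num_classes=80):
    """Turn generated tokens into (box, class_id) pairs.

    Assumes a hypothetical vocabulary: coordinate tokens in [0, num_bins),
    class tokens in [num_bins, num_bins + num_classes), EOS = num_bins + num_classes.
    """
    eos = num_bins + num_classes
    if eos in tokens:
        tokens = tokens[:tokens.index(eos)]  # drop EOS and anything after it

    def to_pixels(token, size):
        return token / (num_bins - 1) * size  # de-quantize one coordinate

    detections = []
    usable = len(tokens) - len(tokens) % 5    # ignore a trailing incomplete chunk
    for i in range(0, usable, 5):
        ymin, xmin, ymax, xmax, cls = tokens[i:i + 5]
        box = (to_pixels(ymin, height), to_pixels(xmin, width),
               to_pixels(ymax, height), to_pixels(xmax, width))
        detections.append((box, cls - num_bins))
    return detections

generated = [0, 0, 499, 499, 503, 125, 125, 374, 374, 507, 580]
for box, cls in parse_sequence(generated, 480, 640):
    print(cls, [round(v, 1) for v in box])
```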

Results
Despite its simplicity, Pix2Seq achieves impressive empirical performance on benchmark datasets. Specifically, we compare our method with well-established baselines, Faster R-CNN and DETR, on the widely used COCO dataset and demonstrate that it achieves competitive average precision (AP) results.

Pix2Seq achieves competitive AP results compared to existing systems that require specialization during model design, while being significantly simpler. The best performing Pix2Seq model achieved an AP score of 45.

Since our approach incorporates minimal inductive bias or prior knowledge of the object detection task into the model design, we further explore how pre-training the model on a larger object detection dataset, before fine-tuning it on COCO, affects its performance. Our results indicate that this training strategy (along with using bigger models) can further boost performance.

The average precision of the Pix2Seq model with pre-training followed by fine-tuning. The best performing Pix2Seq model without pre-training achieved an AP score of 45. When the model is pre-trained, we see an 11% improvement with an AP score of 50.
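The 11% figure is a relative gain over the model without pre-training. Using the rounded AP scores quoted in the caption:

```latex
\frac{\mathrm{AP}_{\text{pre-trained}} - \mathrm{AP}_{\text{no pre-training}}}{\mathrm{AP}_{\text{no pre-training}}}
  = \frac{50 - 45}{45} \approx 0.111 \approx 11\%
```

In absolute terms, this corresponds to a gain of 5 AP points.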

Pix2Seq can detect objects in densely populated and complex scenes, such as those shown below.

Example complex and densely populated scenes labeled by a trained Pix2Seq model.

Conclusion and Future Work
With Pix2Seq, we cast object detection as a language modeling task conditioned on pixel inputs, with a model architecture and loss function that are generic rather than engineered specifically for the detection task. One can, therefore, readily extend this framework to other domains or applications whose output can be represented by a relatively concise sequence of discrete tokens (e.g., keypoint detection, image captioning, visual question answering), or incorporate it into a perceptual system supporting general intelligence, where it provides a language interface to a wide range of vision and language tasks. We also hope that the release of Pix2Seq’s code, pre-trained models and interactive demo will inspire further research in this direction.

Acknowledgements
This post reflects the combined work with our co-authors: Saurabh Saxena, Lala Li, Geoffrey Hinton. We would also like to thank Tom Small for the visualization of the Pix2Seq illustration figure.


Hidden Interfaces for Ambient Computing

As consumer electronics and internet-connected appliances become more common, homes are beginning to embrace various types of connected devices that offer functionality like music control, voice assistance, and home automation. Integrating devices gracefully requires adapting to existing aesthetics and user styles rather than simply adding screens, which can easily disrupt a visual space, especially when they turn into monolithic surfaces or black screens when powered down or not actively used. Thus, there is an increasing desire to create connected ambient computing devices and appliances that can preserve the aesthetics of everyday materials, while providing on-demand access to interaction and digital displays.

Illustration of how hidden interfaces can appear and disappear in everyday surfaces, such as a mirror or the wood paneling of a home appliance.

In “Hidden Interfaces for Ambient Computing: Enabling Interaction in Everyday Materials through High-Brightness Visuals on Low-Cost Matrix Displays”, presented at ACM CHI 2022, we describe an interface technology that is designed to be embedded underneath materials and our vision of how such technology can co-exist with everyday materials and aesthetics. This technology makes it possible to have high-brightness, low-cost displays appear from underneath materials such as textile, wood veneer, acrylic or one-way mirrors, for on-demand touch-based interaction.

Hidden interface prototypes demonstrate bright and expressive rendering underneath everyday materials. From left to right: thermostat under textile, a scalable clock under wood veneer, and a caller ID display and a zooming countdown under mirrored surfaces.

Parallel Rendering: Boosting PMOLED Brightness for Ambient Computing
While many of today’s consumer devices employ active-matrix organic light-emitting diode (AMOLED) displays, their cost and manufacturing complexity are prohibitive for ambient computing. Yet other display technologies, such as E-ink and LCD, do not have sufficient brightness to penetrate materials.

To address this gap, we explore the potential of passive-matrix OLEDs (PMOLEDs), which are based on a simple design that significantly reduces cost and complexity. However, PMOLEDs typically use scanline rendering, where active display driver circuitry sequentially activates one row at a time, a process that limits display brightness and introduces flicker.

Instead, we propose a system that uses parallel rendering, where as many rows as possible are activated simultaneously in each operation by grouping rectilinear shapes of horizontal and vertical lines. For example, a square can be shown with just two operations, whereas traditional scanline rendering needs as many operations as there are rows. With fewer operations, parallel rendering can output significantly more light at each instant, which boosts brightness and eliminates flicker. The technique is not strictly limited to lines and rectangles, although that is where we see the most dramatic performance increase. For example, one could add rendering steps for antialiasing (i.e., smoothing) non-rectilinear content.

Illustration of scanline rendering (top) and parallel rendering (bottom) operations of an unfilled rectangle. Parallel rendering achieves bright, flicker-free graphics by simultaneously activating multiple rows.
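To make the operation count concrete, here is a simplified model, an assumption rather than the authors' renderer: on a passive matrix, driving a set of rows and a set of columns together lights every pixel at their intersections, so rows that share exactly the same lit-column pattern can be drawn in one operation. Grouping rows this way reproduces the two-operation unfilled square from the text, whereas scanline rendering spends one operation per display row. The function names are hypothetical, and the sketch ignores the drive-current and timing constraints a real display controller must respect.

```python
def parallel_ops(frame):
    """Group rows that share the same lit-column pattern.

    frame: list of rows, each a list of 0/1 pixel values.
    Returns (rows, cols) pairs; driving all `rows` and all `cols` together
    lights exactly the pixels in rows x cols on an idealized passive matrix.
    """
    groups = {}
    for r, row in enumerate(frame):
        cols = tuple(c for c, v in enumerate(row) if v)
        if cols:  # blank rows need no operation
            groups.setdefault(cols, []).append(r)
    return [(rows, list(cols)) for cols, rows in groups.items()]

def scanline_ops(frame):
    """Traditional scanline rendering: one operation per display row."""
    return len(frame)

def unfilled_square(n):
    """An n x n square outline as a 0/1 frame."""
    return [[1 if r in (0, n - 1) or c in (0, n - 1) else 0 for c in range(n)]
            for r in range(n)]

square = unfilled_square(10)
print(scanline_ops(square), len(parallel_ops(square)))  # 10 2
```

In this idealized model, each lit pixel of the 10-row square is driven for 1/2 of the frame instead of 1/10, which is the intuition behind the brightness gain; actual luminance also depends on drive current and panel characteristics.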

Rendering User Interfaces and Text
We show that hidden interfaces can be used to create dynamic and expressive interactions. With a set of fundamental UI elements such as buttons, switches, sliders, and cursors, each interface can provide different basic controls, such as light switches, volume controls and thermostats. We created a scalable font (i.e., a set of numbers and letters) that is designed to render in just a few operations. We currently exclude the letters “k”, “x” and “z” because of their diagonal lines, though they could be supported with additional operations. Per-frame control of font properties, coupled with the display's high frame rate, enables very fluid animations; this capability greatly expands the expressivity of the rectilinear graphics far beyond what is possible on fixed 7-segment LED displays.

In this work, we demonstrate various examples, such as a scalable clock, a caller ID display, a zooming countdown timer, and a music visualizer.
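The authors' font design is not reproduced here, but the idea of a scalable font that renders in a few operations can be illustrated with a hypothetical glyph for the digit “2”, built from five axis-aligned strokes on a 3 × 5 unit grid. Assuming a passive matrix in which driving a range of rows and a range of columns together lights the rectangle where they cross, each stroke is one operation, and scaling the glyph changes only the stroke coordinates, not the number of operations. `GLYPH_2`, `scale_glyph` and `rasterize` are illustrative names, not the authors' code.

```python
# Hypothetical glyph for the digit "2" on a 3-column x 5-row unit grid.
# Each stroke is an axis-aligned rectangle (row0, col0, row1, col1), inclusive.
GLYPH_2 = [
    (0, 0, 0, 2),  # top bar
    (0, 2, 2, 2),  # upper-right vertical
    (2, 0, 2, 2),  # middle bar
    (2, 0, 4, 0),  # lower-left vertical
    (4, 0, 4, 2),  # bottom bar
]

def scale_glyph(strokes, scale, origin=(0, 0)):
    """Scale unit-grid strokes to pixel rectangles; each stays one operation."""
    oy, ox = origin
    return [(oy + r0 * scale, ox + c0 * scale,
             oy + (r1 + 1) * scale - 1, ox + (c1 + 1) * scale - 1)
            for r0, c0, r1, c1 in strokes]

def rasterize(rects, height, width):
    """Light every pixel covered by any rectangle (for previewing only)."""
    frame = [[0] * width for _ in range(height)]
    for r0, c0, r1, c1 in rects:
        for r in range(r0, r1 + 1):
            for c in range(c0, c1 + 1):
                frame[r][c] = 1
    return frame

print(len(scale_glyph(GLYPH_2, 1)), len(scale_glyph(GLYPH_2, 8)))  # 5 5
for row in rasterize(scale_glyph(GLYPH_2, 1), 5, 3):
    print("".join("#" if v else "." for v in row))
```

Changing the scale from frame to frame is one way such a font could be animated, for example in a zooming countdown, while the per-frame operation count stays fixed.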

Realizing Hidden Interfaces with Interactive Hardware
To implement proof-of-concept hidden interfaces, we use a PMOLED display with 128×96 resolution that has all row and column drivers routed to a connector for direct access. We use a custom printed circuit board (PCB) with fourteen 16-channel digital-to-analog converters (DACs) to directly interface those 224 lines from a Raspberry Pi 3 A+. The touch interaction is enabled by a ring-shaped PCB surrounding the display with 12 electrodes arranged in arc segments.
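The hardware numbers are consistent with one another: the display exposes 128 + 96 = 224 row and column lines, and fourteen 16-channel DACs provide exactly 14 × 16 = 224 channels. A quick check, assuming one DAC channel per driver line as the text implies (`dacs_needed` is a hypothetical name):

```python
import math

def dacs_needed(rows: int, cols: int, channels_per_dac: int = 16) -> int:
    """Number of multi-channel DACs needed to drive every row and column line."""
    return math.ceil((rows + cols) / channels_per_dac)

# 128 x 96 PMOLED: 128 + 96 = 224 driver lines -> 224 / 16 = 14 DACs
print(128 + 96, dacs_needed(96, 128))  # 224 14
```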

Comparison to Existing Technologies
We compared the brightness of our parallel rendering to scanline rendering on the same PMOLED and to both a small and a large state-of-the-art AMOLED. We tested brightness through six common materials, such as wood and plastic, with thicknesses ranging from 0.2 mm for the one-way mirror film to 1.6 mm for basswood. We measured brightness in lux (lx, a measure of light intensity as perceived by the human eye) using a light meter placed near the display. The environmental light was kept dim, slightly above the light meter’s minimum sensitivity. For simple rectangular shapes, we observed a 5–40x brightness increase for the PMOLED compared with the AMOLED. The exception was the thick basswood, which let little light through for any rendering technology.

Example showing performance difference between parallel rendering on the PMOLED (this work) and a similarly sized modern 1.4″ AMOLED.

To validate the findings from our technical characterization with more realistic and complex content, we evaluated the number “2”, a grid of checkboxes, three progress bars, and the text “Good Life”. For this more complex content, we observed a 3.6–9.3x brightness improvement. These results suggest that parallel rendering on PMOLED enables display through several materials and outperforms common state-of-the-art AMOLED displays, which did not appear usable in the tested scenarios.

Brightness experiments with additional shapes that require different numbers of operations (ops). Measurements are shown in comparison to large state-of-the-art AMOLED displays.

What’s Next?
In this work, we enabled hidden interfaces that can be embedded in traditional materials and appear on demand. Our lab evaluation suggests unmet opportunities to introduce hidden displays with simple, yet expressive, dynamic and interactive UI elements and text in traditional materials, especially wood and mirror, to blend into people’s homes.

In the future, we hope to investigate more advanced parallel rendering techniques, using algorithms that could also support images and complex vector graphics. Furthermore, we plan to explore efficient hardware designs. For example, application-specific integrated circuits (ASICs) could enable an inexpensive and small display controller with parallel rendering instead of a large array of DACs. Finally, longitudinal deployment would enable us to go deeper into understanding user adoption and behavior with hidden interfaces.

Hidden interfaces demonstrate how control and feedback surfaces of smart devices and appliances could visually disappear when not in use and then appear when in the user’s proximity or touch. We hope this direction will encourage the community to consider other approaches and scenarios where technology can fade into the background for a more harmonious coexistence with traditional materials and human environments.

Acknowledgements
First and foremost, we would like to thank Ali Rahimi and Roman Lewkow for the collaboration, including providing the enabling technology. We also thank Olivier Bau, Aaron Soloway, Mayur Panchal and Sukhraj Hothi for their prototyping and fabrication contributions. We thank Michelle Chang and Mark Zarich for visual designs, illustrations and presentation support. We thank Google ATAP and the Google Interaction Lab for their support of the project. Finally, we thank Sarah Sterman and Mathieu Le Goc for helpful discussions and suggestions.


FormNet: Beyond Sequential Modeling for Form-Based Document Understanding

Form-based document understanding is a growing research topic because of its practical potential for automatically converting unstructured text data into structured information to gain insight about a document’s contents. Recent sequence models built on self-attention, which directly models relationships between all words in a selection of text, have demonstrated state-of-the-art performance on natural language tasks. A natural approach to form document understanding tasks is therefore to first serialize the form documents (usually in a left-to-right, top-to-bottom fashion) and then apply state-of-the-art sequence models to them.

However, form documents often have more complex layouts that contain structured objects, such as tables, columns, and text blocks. Their variety of layout patterns makes serialization difficult, substantially limiting the performance of strict serialization approaches. These unique challenges in form document structural modeling have been largely underexplored in the literature.

An illustration of the form document information extraction task using an example from the FUNSD dataset.

In “FormNet: Structural Encoding Beyond Sequential Modeling in Form Document Information Extraction”, presented at ACL 2022, we propose a structure-aware sequence model, called FormNet, to mitigate the sub-optimal serialization of forms for document information extraction. First, we design a Rich Attention (RichAtt) mechanism that leverages the 2D spatial relationship between word tokens for more accurate attention weight calculation. Then, we construct Super-Tokens (tokens that aggregate semantically meaningful information from neighboring tokens) for each word by embedding representations from their neighboring tokens through a graph convolutional network (GCN). Finally, we demonstrate that FormNet outperforms existing methods, while using less pre-training data, and achieves state-of-the-art performance on the CORD, FUNSD, and Payment benchmarks.

FormNet for Information Extraction
Given a form document, we first use an optical character recognition (OCR) engine to identify words and the BERT-multilingual vocabulary to tokenize them. We then feed the tokens and their corresponding 2D coordinates into a GCN for graph construction and message passing. Next, we use Extended Transformer Construction (ETC) layers with the proposed RichAtt mechanism to further process the GCN-encoded, structure-aware tokens for schema learning (i.e., semantic entity extraction). Finally, we use the Viterbi algorithm, which finds the sequence that maximizes the posterior probability, to decode the final entities for output.
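To illustrate the final decoding step, here is a minimal Viterbi decoder in Python with NumPy. It assumes the model produces per-token log-scores over entity tags (for example, a hypothetical O/B/I tagging scheme) plus a tag-to-tag transition score matrix; how FormNet actually parameterizes these scores is not specified here.

```python
import numpy as np


def viterbi(emissions, transitions):
    """Return the highest-scoring tag sequence as a list of tag indices.

    emissions:   (T, K) per-token log-scores for K tags.
    transitions: (K, K) log-scores; transitions[i, j] scores moving from tag i to tag j.
    """
    T, K = emissions.shape
    score = emissions[0].copy()              # best path score ending in each tag at t = 0
    backptr = np.zeros((T, K), dtype=int)    # best previous tag for each (t, tag)
    for t in range(1, T):
        # candidate[i, j]: best path ending in tag i at t-1, then moving to tag j at t
        candidate = score[:, None] + transitions + emissions[t][None, :]
        backptr[t] = candidate.argmax(axis=0)
        score = candidate.max(axis=0)
    # Trace the best path backwards from the highest-scoring final tag.
    best = [int(score.argmax())]
    for t in range(T - 1, 0, -1):
        best.append(int(backptr[t, best[-1]]))
    return best[::-1]
```

Unlike choosing the best tag for each token independently, the decoder scores whole sequences, so a very low transition score (here, for the hypothetical O followed by I) rules such paths out globally.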

Extended Transformer Construction (ETC)
We adopt ETC as the FormNet model backbone. ETC scales to relatively long inputs by replacing standard attention, which has quadratic complexity, with a sparse global-local attention mechanism that distinguishes between global and long input tokens. The global tokens attend to and are attended by all tokens, but the long tokens attend only locally to other long tokens within a specified local radius. For a fixed radius and a small number of global tokens, the attention cost therefore grows roughly linearly, rather than quadratically, with the number of long tokens.
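The global-local pattern can be pictured as a boolean attention mask. The sketch below assumes global tokens come first in the sequence and uses `radius` as the local window; it is an illustration only, since a real sparse implementation avoids materializing the full N × N mask at all.

```python
import numpy as np


def etc_attention_mask(num_global, num_long, radius):
    """Boolean mask (True = may attend) over [global tokens | long tokens].

    Simplified ETC-style global-local attention: global tokens attend to,
    and are attended by, every token; long tokens attend to other long
    tokens only within `radius` positions.
    """
    n = num_global + num_long
    mask = np.zeros((n, n), dtype=bool)
    mask[:num_global, :] = True    # global -> everything
    mask[:, :num_global] = True    # everything -> global
    idx = np.arange(num_long)
    local = np.abs(idx[:, None] - idx[None, :]) <= radius
    mask[num_global:, num_global:] = local  # long -> nearby long only
    return mask
```

With a fixed radius, the number of allowed long-to-long pairs is about `num_long * (2 * radius + 1)`, which is where the linear scaling comes from.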

Rich Attention
Our novel architecture, RichAtt, avoids the deficiencies of absolute and relative embeddings by avoiding embeddings entirely. Instead, it computes the order of and log distance between pairs of tokens with respect to the x and y axes on the layout grid, and adjusts the pre-softmax attention scores of each pair as a direct function of these values.
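As a concrete picture of the inputs RichAtt works with, the sketch below computes order and log-distance features for every token pair along the x and y axes. The coordinate convention (one point per token, such as a box center), the strict "to the right of / below" order test, and the use of `log1p` so that zero distances stay finite are all assumptions for illustration.

```python
import numpy as np


def pairwise_layout_features(x, y):
    """Order and log-distance features between every pair of tokens.

    x, y: (N,) token coordinates on the layout grid.
    Returns four (N, N) arrays:
      order_x[i, j]    = 1.0 if token j lies strictly to the right of token i,
      log_dist_x[i, j] = log(1 + |x_j - x_i|), and likewise for y.
    """
    dx = x[None, :] - x[:, None]
    dy = y[None, :] - y[:, None]
    order_x = (dx > 0).astype(float)
    order_y = (dy > 0).astype(float)
    log_dist_x = np.log1p(np.abs(dx))
    log_dist_y = np.log1p(np.abs(dy))
    return order_x, order_y, log_dist_x, log_dist_y
```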

In a traditional attention layer, each token representation is linearly transformed into a Query vector, a Key vector, and a Value vector. A token “looks” for other tokens from which it might want to absorb information (i.e., attend to) by finding the ones with Key vectors that create relatively high scores when matrix-multiplied (called Matmul) by its Query vector and then softmax-normalized. The token then sums together the Value vectors of all other tokens in the sentence, weighted by their score, and passes this up the network, where it will normally be added to the token’s original input vector.
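For reference, the traditional single-head attention layer described above can be sketched in a few lines of NumPy. The scaling by the square root of the Query width is the usual Transformer convention and is an addition here; the weight matrices are random placeholders rather than trained parameters.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dot_product_attention(h, w_q, w_k, w_v):
    """Single-head attention over token representations h of shape (N, d).

    Returns the updated representations and the (N, N) attention weights.
    """
    q, k, v = h @ w_q, h @ w_k, h @ w_v        # Query, Key, Value projections
    scores = q @ k.T / np.sqrt(q.shape[-1])    # pre-softmax scores (the Matmul)
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return h + weights @ v, weights            # weighted sum of Values, added to the input
```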

However, other features beyond the Query and Key vectors are often relevant to the decision of how strongly a token should attend to another given token, such as the order they’re in, how many other tokens separate them, or how many pixels apart they are. In order to incorporate these features into the system, we use a trainable parametric function paired with an error network, which takes the observed feature and the output of the parametric function and returns a penalty that reduces the dot product attention score.

The network uses the Query and Key vectors to consider what value some low-level feature (e.g., distance) should take if the tokens are related, and penalizes the attention score based on the error.

At a high level, for each attention head at each layer, FormNet examines each pair of token representations, determines the ideal features the tokens should have if there is a meaningful relationship between them, and penalizes the attention score according to how different the actual features are from the ideal ones. This allows the model to learn constraints on attention using logical implication.
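The sketch below adjusts one head's pre-softmax scores with penalties derived from an order feature and a log-distance feature. It is an assumption-laden illustration rather than FormNet's exact formulation. The parametric functions are plain linear maps of the concatenated Query and Key, and the error functions are a Bernoulli log-likelihood (for order) and a Gaussian log-likelihood (for distance). `theta` is a hypothetical fixed precision. Both penalties are non-positive, so they can only lower a score.

```python
import numpy as np


def log_sigmoid(z):
    # log(sigmoid(z)), computed stably
    return -np.logaddexp(0.0, -z)


def richatt_style_scores(q, k, order, log_dist, w_order, w_mu, theta=1.0):
    """Penalize dot-product scores by how far the observed layout features are
    from the features predicted for each (query, key) pair.

    q, k:          (N, d) Query and Key vectors for one head.
    order:         (N, N) observed 0/1 feature, e.g. "token j is right of token i".
    log_dist:      (N, N) observed log distance between the tokens.
    w_order, w_mu: (2d,) weights of linear parametric functions of [q_i; k_j].
    """
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # Parametric functions: predicted order logit and "ideal" log distance per pair.
    order_logit = (q @ w_order[:d])[:, None] + (k @ w_order[d:])[None, :]
    mu = (q @ w_mu[:d])[:, None] + (k @ w_mu[d:])[None, :]
    # Error functions: log-likelihood of the observed features (both <= 0).
    order_pen = order * log_sigmoid(order_logit) + (1 - order) * log_sigmoid(-order_logit)
    dist_pen = -0.5 * theta ** 2 * (log_dist - mu) ** 2
    return scores + order_pen + dist_pen
```

In the crow example, a token whose actual distance is far from the predicted ideal receives a large `dist_pen`, so its attention edge is suppressed after the softmax.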

A visualization of how RichAtt might act on a sentence. There are three adjectives that the word “crow” might attend to. “Lazy” is to the right, so it probably does not modify “crow” and its attention edge is penalized. “Sly” is many tokens away, so its attention edge is also penalized. “Cunning” receives no significant penalties, so by process of elimination, it is the best candidate for attention.

Furthermore, if one assumes that the softmax-normalized attention scores represent a probability distribution, and the distributions for the observed features are known, then this algorithm — including the exact choice of parametric functions and error functions — falls out algebraically, meaning FormNet has a mathematical correctness to it that is lacking from many alternatives (including relative embeddings).

Super-Tokens by Graph Learning
The key to sparsifying attention mechanisms in ETC for long sequence modeling is to have every token attend only to tokens that are nearby in the serialized sequence. Although the RichAtt mechanism empowers the transformers by taking the spatial layout structures into account, poor serialization can still place related word tokens far apart in the sequence, outside each other’s local attention radius, so that the attention weights between them are never computed.

To further mitigate the issue, we construct a graph to connect nearby tokens in a form document. We design the edges of the graph based on strong inductive biases so that connected tokens are more likely to belong to the same entity type. For each token, we obtain its Super-Token embedding by applying graph convolutions along these edges to aggregate semantically relevant information from neighboring tokens. We then use these Super-Tokens as an input to the RichAtt ETC architecture. This means that even though an entity may get broken up into multiple segments due to poor serialization, the Super-Tokens learned by the GCN will have retained much of the context of the entity phrase.
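A minimal sketch of the Super-Token idea: build a graph over token positions, then apply one graph convolution so each token's representation mixes in its neighbors'. As a stand-in for the edge design described above, this uses a plain k-nearest-neighbor graph over token coordinates; the paper's actual edge rules may differ. `knn_graph` and `graph_conv` are hypothetical helpers.

```python
import numpy as np


def knn_graph(coords, k):
    """Symmetric 0/1 adjacency linking each token to its k nearest tokens."""
    n = len(coords)
    d = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)                  # a token is not its own neighbor
    nbrs = np.argsort(d, axis=1)[:, :k]          # indices of the k closest tokens
    adj = np.zeros((n, n))
    adj[np.repeat(np.arange(n), k), nbrs.ravel()] = 1.0
    return np.maximum(adj, adj.T)                # make edges undirected


def graph_conv(h, adj, w):
    """One mean-aggregation graph convolution with self-loops and ReLU."""
    a = adj + np.eye(len(adj))                   # keep each token's own features
    a = a / a.sum(axis=1, keepdims=True)         # average over neighbors
    return np.maximum(a @ h @ w, 0.0)
```

Because the aggregation follows spatial edges rather than sequence order, two fragments of the same entity that end up far apart after serialization can still share context.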

An illustration of the word-level graph, with blue edges between tokens, of a FUNSD document.

Key Results
The figure below shows model size vs. F1 score (the harmonic mean of precision and recall) for recent approaches on the CORD benchmark. FormNet-A2 outperforms the most recent DocFormer while using a model that is 2.5x smaller. FormNet-A3 achieves state-of-the-art performance with a 97.28% F1 score. For more experimental results, please refer to the paper.
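For reference, F1 combines precision P and recall R as F1 = 2PR / (P + R). The helper below computes it from true-positive, false-positive, and false-negative counts; the counts used in the checks are made-up numbers, not results from the paper.

```python
def f1_score(tp, fp, fn):
    """F1 as the harmonic mean of precision and recall."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```

The harmonic mean penalizes imbalance: a model with perfect precision but 60% recall scores 0.75, not the arithmetic mean of 0.8.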

Model Size vs. Entity Extraction F1 Score on CORD benchmark. FormNet significantly outperforms other recent approaches in absolute F1 performance and parameter efficiency.

We study the importance of RichAtt and Super-Token by GCN on the large-scale masked language modeling (MLM) pre-training task across three FormNets. Both the RichAtt and GCN components improve upon the ETC baseline in reconstructing the masked tokens by a large margin, showing the effectiveness of their structural encoding capability on form documents. The best performance is obtained when incorporating both RichAtt and GCN.

Performance of the Masked-Language Modeling (MLM) pre-training. Both the proposed RichAtt and Super-Token by GCN components improve upon ETC baseline by a large margin, showing the effectiveness of their structural encoding capability on large-scale form documents.

Using BertViz, we visualize the local-to-local attention scores for specific examples from the CORD dataset for the standard ETC and FormNet models. Qualitatively, we confirm that for FormNet, tokens attend primarily to other tokens within the same visual block. Moreover, for that model, specific attention heads attend to horizontally aligned tokens, which is a strong signal of meaning in form documents. No clear attention pattern emerges for the ETC model, suggesting that RichAtt and Super-Tokens by GCN enable the model to learn structural cues and leverage layout information effectively.

The attention scores for ETC and FormNet (ETC+RichAtt+GCN) models. Unlike the ETC model, the FormNet model makes tokens attend to other tokens within the same visual blocks, along with tokens aligned horizontally, thus strongly leveraging structural cues.

Conclusion
We present FormNet, a novel model architecture for form-based document understanding. We show that the novel RichAtt mechanism and Super-Token components help the ETC transformer excel at form understanding despite sub-optimal, noisy serialization. We demonstrate that FormNet recovers local syntactic information that may have been lost during text serialization and achieves state-of-the-art performance on three benchmarks.

Acknowledgements
This research was conducted by Chen-Yu Lee, Chun-Liang Li, Timothy Dozat, Vincent Perot, Guolong Su, Nan Hua, Joshua Ainslie, Renshen Wang, Yasuhisa Fujii, and Tomas Pfister. Thanks to Evan Huang, Shengyang Dai, and Salem Elie Haykal for their valuable feedback, and Tom Small for creating the animation in this post.


Learning to Prompt for Continual Learning

Supervised learning is a common approach to machine learning (ML) in which the model is trained using data that is labeled appropriately for the task at hand. Ordinary supervised learning trains on independent and identically distributed (IID) data, where all training examples are sampled from a fixed set of classes, and the model has access to these examples throughout the entire training phase. In contrast, continual learning tackles the problem of training a single model on changing data distributions where different classification tasks are presented sequentially. This is particularly important, for example, to enable autonomous agents to process and interpret continuous streams of information in real-world scenarios.

To illustrate the difference between supervised and continual learning, consider two tasks: (1) classify cats vs. dogs and (2) classify pandas vs. koalas. In supervised learning, which assumes IID data, the model is given training data from both tasks and treats it as a single 4-class classification problem. However, in continual learning, these two tasks arrive sequentially, and the model only has access to the training data of the current task. As a result, such models tend to suffer from performance degradation on the previous tasks, a phenomenon called catastrophic forgetting.

Mainstream solutions try to address catastrophic forgetting by buffering past data in a “rehearsal buffer” and mixing it with current data to train the model. However, the performance of these solutions depends heavily on the size of the buffer and, in some cases, may not be possible at all due to data privacy concerns. Another branch of work designs task-specific components to avoid interference between tasks. But these methods often assume that the task at test time is known, which is not always true, and they require a large number of parameters. The limitations of these approaches raise critical questions for continual learning: (1) Is it possible to have a more effective and compact memory system that goes beyond buffering past data? (2) Can one automatically select relevant knowledge components for an arbitrary sample without knowing its task identity?

In “Learning to Prompt for Continual Learning”, presented at CVPR 2022, we attempt to answer these questions. Drawing inspiration from prompting techniques in natural language processing, we propose a novel continual learning framework called Learning to Prompt (L2P). Instead of continually re-learning all the model weights for each sequential task, we provide learnable task-relevant “instructions” (i.e., prompts) that guide a pre-trained backbone model through sequential training via a pool of learnable prompt parameters. L2P is applicable to various challenging continual learning settings and consistently outperforms previous state-of-the-art methods on all benchmarks. It achieves competitive results against rehearsal-based methods while also being more memory efficient. Most importantly, L2P is the first to introduce the idea of prompting in the field of continual learning.

Compared with typical methods that adapt entire or partial model weights to tasks sequentially using a rehearsal buffer, L2P uses a single frozen backbone model and learns a prompt pool to conditionally instruct the model. “Model 0” indicates that the backbone model is fixed at the beginning.


Prompt Pool and Instance-Wise Query
Given a pre-trained Transformer model, “prompt-based learning” modifies the original input using a fixed template. Suppose a sentiment analysis task is given the input “I like this cat”. A prompt-based method will transform the input to “I like this cat. It looks X”, where “X” is an empty slot to be predicted (e.g., “nice”, “cute”, etc.) and “It looks X” is the so-called prompt. By adding prompts to the input, one can condition pre-trained models to solve many downstream tasks. While designing fixed prompts requires prior knowledge along with trial and error, prompt tuning prepends a set of learnable prompts to the input embedding to instruct the pre-trained backbone to learn a single downstream task, under the transfer learning setting.
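Mechanically, prompt tuning is a concatenation along the sequence axis: learnable prompt vectors are placed in front of the (frozen) input token embeddings before the sequence enters the backbone. The sketch below shows only that step; the shapes and the name `prepend_prompts` are illustrative assumptions, and training would update `prompts` while leaving the backbone untouched.

```python
import numpy as np


def prepend_prompts(token_embeddings, prompts):
    """Prepend learnable prompt vectors to a sequence of input embeddings.

    token_embeddings: (L, D) embeddings of the input tokens.
    prompts:          (Lp, D) learnable prompt parameters (the only trained part).
    Returns an (Lp + L, D) sequence to feed to the frozen backbone.
    """
    if prompts.shape[1] != token_embeddings.shape[1]:
        raise ValueError("prompt and token embedding widths must match")
    return np.concatenate([prompts, token_embeddings], axis=0)
```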

In the continual learning scenario, L2P maintains a learnable prompt pool, where prompts can be flexibly grouped as subsets to work jointly. Specifically, each prompt is associated with a key that is learned by minimizing a cosine-distance loss between the key and the query features of the inputs matched to it. These keys are then used by a query function to dynamically look up a subset of task-relevant prompts based on the input features. At test time, the query function maps each input to the top-N closest keys in the prompt pool, and the associated prompt embeddings are then fed to the rest of the model to generate the output prediction. During training, we optimize the prompt pool and the classification head via the cross-entropy loss.
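The instance-wise lookup can be sketched as a cosine-similarity search over the prompt keys. In this sketch, the query feature is assumed to come from the frozen pre-trained encoder, the pool is a plain array of keys and prompts, and `select_prompts` is a hypothetical helper. The returned key loss is the mean cosine distance between the query and the chosen keys, the quantity that training would minimize to pull those keys toward matching inputs.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def select_prompts(query, keys, prompts, top_n):
    """Instance-wise prompt lookup from a key-value prompt pool.

    query:   (D,) feature of one input.
    keys:    (M, D) learnable keys, one per prompt in the pool.
    prompts: (M, Lp, D) learnable prompts.
    Returns (indices of the top-N keys, selected prompts of shape (top_n * Lp, D),
             mean cosine distance between the query and the selected keys).
    """
    sims = l2_normalize(keys) @ l2_normalize(query)    # (M,) cosine similarities
    idx = np.argsort(-sims)[:top_n]                    # top-N closest keys
    selected = prompts[idx].reshape(-1, prompts.shape[-1])
    key_loss = float(np.mean(1.0 - sims[idx]))         # cosine distance to chosen keys
    return idx, selected, key_loss
```

The selected prompts would then be prepended to the input tokens before the forward pass, so no task identity is needed: the input itself decides which prompts it uses.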

Illustration of L2P at test time. First, L2P selects a subset of prompts from a key-value paired prompt pool based on our proposed instance-wise query mechanism. Then, L2P prepends the selected prompts to the input tokens. Finally, L2P feeds the extended tokens to the model for prediction.

Intuitively, similar input examples tend to choose similar sets of prompts and vice versa. Thus, prompts that are frequently shared encode more generic knowledge while other prompts encode more task-specific knowledge. Moreover, prompts store high-level instructions and keep lower-level pre-trained representations frozen, thus catastrophic forgetting is mitigated even without the necessity of a rehearsal buffer. The instance-wise query mechanism removes the necessity of knowing the task identity or boundaries, enabling this approach to address the under-investigated challenge of task-agnostic continual learning.

Effectiveness of L2P
We evaluate the effectiveness of L2P against different baseline methods using an ImageNet pre-trained Vision Transformer (ViT) on representative benchmarks. The naïve baseline, called Sequential in the graphs below, refers to training a single model sequentially on all tasks. The EWC model adds a regularization term to mitigate forgetting, and the Rehearsal model saves past examples to a buffer for mixed training with current data. To measure overall continual learning performance, we report both the average accuracy and the average difference between the best accuracy achieved during training and the final accuracy for all tasks (except the last task), which we call forgetting. We find that L2P significantly outperforms the Sequential and EWC methods on both metrics. Notably, L2P even surpasses the Rehearsal approach, which uses an additional buffer to save past data. Because the L2P approach is orthogonal to Rehearsal, its performance could be further improved if it, too, used a rehearsal buffer.

L2P outperforms baseline methods in both accuracy (top) and forgetting (bottom). Accuracy refers to the average accuracy for all tasks and forgetting is defined as the average difference between the best accuracy achieved during training and the final accuracy for all tasks (except the last task).
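Both metrics can be computed from an accuracy matrix recorded during sequential training. The sketch below assumes `acc[i, j]` holds the accuracy on task j measured right after training on task i. The 3-task matrix in the checks is invented for illustration and is not a result from the paper.

```python
import numpy as np


def average_accuracy_and_forgetting(acc):
    """acc[i, j] = accuracy on task j after training on task i (T x T).

    Average accuracy: mean final accuracy over all tasks.
    Forgetting: mean over all tasks except the last of
                (best accuracy before the final step) - (final accuracy).
    """
    acc = np.asarray(acc, dtype=float)
    T = acc.shape[0]
    final = acc[-1]                      # accuracies after training on the last task
    avg_acc = float(final.mean())
    if T < 2:
        return avg_acc, 0.0
    # For task j, only rows j .. T-2 are "during training" (task j has been seen).
    best = np.array([acc[j:T - 1, j].max() for j in range(T - 1)])
    forgetting = float(np.mean(best - final[:T - 1]))
    return avg_acc, forgetting
```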

We also visualize the prompt selection result from our instance-wise query strategy on two different benchmarks, where one has similar tasks and the other has varied tasks. The results indicate that L2P promotes more knowledge sharing between similar tasks by having more shared prompts, and less knowledge sharing between varied tasks by having more task-specific prompts.

Prompt selection histograms for benchmarks of similar tasks (left) and varied tasks (right). The left benchmark has higher intra-task similarity, thus sharing prompts between tasks results in good performance, while the right benchmark favors more task-specific prompts.

Conclusion
In this work, we present L2P to address key challenges in continual learning from a new perspective. L2P does not require a rehearsal buffer or known task identity at test time to achieve high performance. Further, it can handle various complex continual learning scenarios, including the challenging task-agnostic setting. Because large-scale pre-trained models are widely used in the machine learning community for their robust performance on real-world problems, we believe that L2P opens a new learning paradigm towards practical continual learning applications.

Acknowledgements
We gratefully acknowledge the contributions of other co-authors, including Chen-Yu Lee, Han Zhang, Ruoxi Sun, Xiaoqi Ren, Guolong Su, Vincent Perot, Jennifer Dy, Tomas Pfister. We would also like to thank Chun-Liang Li, Jeremy Martin Kubica, Sayna Ebrahimi, Stratis Ioannidis, Nan Hua, and Emmanouil Koukoumidis, for their valuable discussions and feedback, and Tom Small for figure creation.


Locked-image Tuning: Adding Language Understanding to Image Models

The ability to classify images into categories has been transformed by deep learning. It has also been significantly accelerated by transfer learning, whereby models are first pre-trained on large datasets, like ImageNet, to learn visual representations that are then transferred via fine-tuning to a new task with less data (e.g., classifying animals). Previous works such as BiT and ViT employed these methods to achieve state-of-the-art performance on a wide range of classification tasks, such as the VTAB benchmark.

However, fine-tuning has some downsides: though pre-training is done only once, fine-tuning must be repeated for every new dataset, and each round requires task-specific data. Multimodal contrastive learning is an alternative, recently popularized paradigm (e.g., CLIP, ALIGN) that overcomes these issues by instead learning how to match free-form text with images. These models can then solve new tasks by reformulating them as image-text matching problems, without extra data (referred to as “zero-shot” learning). Contrastive learning is flexible and easy to adapt to new tasks, but has its own limitations, namely the need for a lot of paired image-text data and weaker performance than transfer learning approaches.

With those limitations in mind, we propose “LiT: Zero-Shot Transfer with Locked-image Text Tuning”, to appear at CVPR 2022. LiT models learn to match text to an already pre-trained image encoder. This simple yet effective setup provides the best of both worlds: strong image representations from pre-training, plus flexible zero-shot transfer to new tasks via contrastive learning. LiT achieves state-of-the-art zero-shot classification accuracy, significantly closing the gap between the two styles of learning. We think the best way to understand it is to try it yourself, so we’ve included a demo of LiT models at the end of this post.

Fine-tuning (left) requires task-specific data and training to adapt a pre-trained model to a new task. A LiT model (right) can be used with any task, without further data or adaptation.

Contrastive Learning on Image-Text Data
Contrastive learning models learn representations from “positive” and “negative” examples, such that representations for “positive” examples are similar to each other but different from “negative” examples.

Multimodal contrastive learning applies this to pairs of images and associated texts. An image encoder computes representations from images, and a text encoder does the same for texts. Each image representation is encouraged to be close to the representation of its associated text (“positive”), but distinct from the representation of other texts (“negatives”) in the data, and vice versa. This has typically been done with randomly initialized models (“from scratch”), meaning the encoders have to simultaneously learn representations and how to match them.
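A minimal NumPy sketch of this objective is the symmetric InfoNCE-style loss popularized by CLIP. Within a batch, row i of the image and text embeddings forms the positive pair, and every other row serves as a negative. The fixed temperature of 0.07 is an assumed value for illustration; real systems often learn it.

```python
import numpy as np


def log_softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def contrastive_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric image-text contrastive loss for a batch of B pairs, each (B, D)."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature        # (B, B) cosine similarities, scaled
    diag = np.arange(len(logits))             # positives sit on the diagonal
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()  # each image vs. all texts
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()  # each text vs. all images
    return float((loss_i2t + loss_t2i) / 2)
```

A larger batch means more negatives per positive in each row and column of `logits`.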

Multimodal contrastive learning trains models to produce similar representations for closely matched images and texts.

This training can be done on noisy, loosely aligned pairs of image and text, which naturally occur on the web. This circumvents the need for manual labeling, and makes data scaling easy. Furthermore, the model learns much richer visual concepts — it’s not constrained to what’s defined in the classification label space. Instead of classifying an image as “coffee”, it can understand whether it’s “a small espresso in a white mug” or “a large latte in a red flask”.

Once trained, a model that aligns image and text can be used in many ways. For zero-shot classification, we compare image representations to text representations of the class names. For example, a “wombat vs jaguar” classifier can be built by computing the representations of the texts “jaguar” and “wombat”, and classifying an image as a jaguar if its representation better matches the former. This approach scales to thousands of classes and makes it very easy to solve classification tasks without the extra data necessary for fine-tuning. Another application of contrastive models is image search (a.k.a. image-text retrieval), by finding the image whose representation best matches that of a given text, or vice versa.
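The zero-shot classifier reduces to a nearest-match lookup. In the sketch below, the embeddings are toy vectors standing in for the outputs of real image and text encoders, which are not modeled here. Adding classes only means adding rows to `class_text_embs`.

```python
import numpy as np


def zero_shot_classify(image_emb, class_text_embs, class_names):
    """Return the class whose text representation best matches the image."""
    img = image_emb / np.linalg.norm(image_emb)
    txt = class_text_embs / np.linalg.norm(class_text_embs, axis=1, keepdims=True)
    sims = txt @ img                           # cosine similarity to each class name
    return class_names[int(np.argmax(sims))]
```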

The Best of Both Worlds with Locked-image Tuning
As mentioned earlier, transfer learning achieves state-of-the-art accuracy, but requires per-task labels, datasets, and training. On the other hand, contrastive models are flexible, scalable, and easily adaptable to new tasks, but fall short in performance. To compare, at the time of writing, the state of the art on ImageNet classification using transfer learning is 90.94%, but the best contrastive zero-shot models achieve 76.4%.

LiT tuning bridges this gap: we contrastively train a text model to compute representations well aligned with the powerful ones available from a pre-trained image encoder. Importantly, for this to work well, the image encoder should be “locked”, that is: it should not be updated during training. This may be unintuitive since one usually expects the additional information from further training to increase performance, but we find that locking the image encoder consistently leads to better results.
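To make "locked" concrete, the toy sketch below treats the image side as a fixed array of pre-computed embeddings and trains only a linear text projection with hand-derived gradients of the symmetric contrastive loss. The linear "text encoder", the unnormalized text outputs, the fixed temperature, and the learning rate are all simplifying assumptions, not the LiT training setup.

```python
import numpy as np


def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def lit_step(w, text_feats, image_embs, lr=0.01, temperature=1.0):
    """One gradient step on a linear text head against locked image embeddings.

    w:          (Dt, D) trainable text projection (toy stand-in for a text encoder).
    text_feats: (B, Dt) text inputs.
    image_embs: (B, D) pre-computed image embeddings; never modified.
    Returns (updated w, loss before the update).
    """
    B = len(text_feats)
    txt = text_feats @ w                            # (B, D) text representations
    logits = txt @ image_embs.T / temperature       # (B, B) pair scores
    eye = np.eye(B)
    p_row, p_col = softmax(logits, axis=1), softmax(logits, axis=0)
    loss = -(np.log(np.diag(p_row)).mean() + np.log(np.diag(p_col)).mean()) / 2
    # Gradient of the symmetric cross-entropy with respect to the logits.
    d_logits = ((p_row - eye) + (p_col - eye)) / (2 * B)
    d_txt = d_logits @ image_embs / temperature     # backprop through the matmul
    d_w = text_feats.T @ d_txt                      # only the text side gets a gradient
    return w - lr * d_w, float(loss)
```

Because `image_embs` is just data here, it can be computed once and reused for every step, which is the efficiency argument for pre-computing image representations.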

LiT-tuning contrastively trains a text encoder to match a pre-trained image encoder. The text encoder learns to compute representations that align to those from the image encoder.

This can be considered an alternative to the classic fine-tuning stage, where the image encoder is separately adapted to every new classification task; instead we have one stage of LiT-tuning, after which the model can classify any data. LiT-tuned models achieve 84.5% zero-shot accuracy on ImageNet classification, showing significant improvements over previous methods that train models from scratch, and halving the performance gap between fine-tuning and contrastive learning.

Left: LiT-tuning significantly closes the gap between the best contrastive models and the best models fine-tuned with labels. Right: Using a pre-trained image encoder is always helpful, but locking it is surprisingly a key part of the recipe to success; unlocked image models (dashed) yield significantly worse performance.

An impressive benefit of contrastive models is increased robustness — they retain high accuracy on datasets that typically fool fine-tuned models, such as ObjectNet and ImageNet-C. Similarly, LiT-tuned models have high performance across various challenging versions of ImageNet, for example achieving a state-of-the-art 81.1% accuracy on ObjectNet.

LiT-tuning has other advantages. While prior contrastive works require large amounts of data and train for a very long time, the LiT approach is much less data hungry. LiT models trained on 24M publicly available image-text pairs rival the zero-shot classification performance of prior models trained on 400M image-text pairs of private data. The locked image encoder also leads to faster training with a smaller memory footprint. On larger datasets, image representations can be pre-computed; not running the image model during training further improves efficiency and also unlocks much larger batch sizes, which increases the number of “negatives” the model sees and is key to high-performance contrastive learning. The method works well with varied forms of image pre-training (e.g., including self-supervised learning), and with many publicly available image models. We hope that these benefits make LiT a great testbed for researchers.

Conclusion
We present Locked-image Tuning (LiT), which contrastively trains a text encoder to match image representations from a powerful pre-trained image encoder. This simple method is data and compute efficient, and substantially improves zero-shot classification performance compared to existing contrastive learning approaches.

Want to try it yourself?

A preview of the demo: use it to match free-form text descriptions to images and build your own zero-shot classifier!

We have prepared a small interactive demo to try some LiT-tuned models. We also provide a Colab with more advanced use cases and larger models, which are a great way to get started.

Acknowledgments
We would like to thank Xiaohua Zhai, Xiao Wang, Daniel Keysers, Alexander Kolesnikov, and Lucas Beyer who have co-authored the LiT paper and been involved in all aspects of its development, as well as the Brain team in Zürich. We also would like to thank Tom Small for creating the animations used in this blogpost.


Simple and Effective Zero-Shot Task-Oriented Dialogue

Modern conversational agents need to integrate with an ever-increasing number of services to perform a wide variety of tasks, from booking flights and finding restaurants, to playing music and telling jokes. Adding this functionality can be difficult — for each new task, one needs to collect new data and retrain the models that power the conversational agent. This is because most task-oriented dialogue (TOD) models are trained on a single task-specific ontology. An ontology is generally represented as a list of possible user intents (e.g., if the user wants to book a flight, if the user wants to play some music, etc.) and possible parameter slots to extract from the conversation (e.g., the date of the flight, the name of a song, and so on). A rigid ontology can be limiting, preventing the model from generalizing to new tasks or domains. For instance, a TOD model trained on a certain ontology only knows the intents in that ontology, and lacks the ability to generalize its knowledge to unseen intents. This is true even for new ontologies that overlap with ones already known to the agent — for example, if an agent already knows how to book train tickets, adding the ability to book airline tickets would require training on completely new data. Ideally, the agent should be able to leverage its existing knowledge from one ontology, and apply it to new ones.

New benchmarks, such as the Schema Guided Dialogue (SGD) dataset, have been designed to evaluate the ability to generalize to unseen tasks, by distilling each ontology into a schema of slots and intents. In the SGD setting, TOD models are trained on multiple schemas, and evaluated on how well they generalize to unseen ones, instead of how well they overfit to a single ontology. However, recent work shows the top models still have room for improvement.

To address this problem, we introduce two different sequence-to-sequence approaches toward zero-shot transfer for dialogue modeling, presented in the papers “Description-Driven Task-Oriented Dialogue” and “Show, Don’t Tell: Demonstrations Outperform Descriptions for Schema-Guided Task-Oriented Dialogue”. Both models condition on additional contextual information, either slot and intent descriptions, or single demonstrative examples. Results obtained on multiple dialogue state tracking benchmarks show that by doing away with the fixed schemas and ontologies, these new approaches lead to state-of-the-art results on the dialogue state tracking task with more efficient models. The source code for the described approaches can be found here.

Background: Dialogue State Tracking
To address the challenge of zero-shot transfer for dialogue models, we focus on the problem of Dialogue State Tracking (DST). DST is a fundamental problem for conversational agents, in which a model predicts the belief state of a conversation, i.e., the agent’s understanding of the user’s indicated preferences. The belief state is typically modeled as an assignment of values to slots for which the user has indicated a preference in the conversation. An example is shown below.

An example conversation and its ground truth slots and intents for dialogue state tracking. Here, the active user intent is “Book a train”, and pertinent information for booking this train is recorded in the slot values.

Description-Driven Task-Oriented Dialogue
In our first paper, we introduce Description-Driven Dialogue State Tracking (D3ST), a DST model that leverages slot and intent descriptions when making predictions about the belief state. D3ST is built on top of the T5 sequence-to-sequence language model, which was shown in previous work to be pretrained effectively for DST problems.

D3ST prompts the input sequence with slot and intent descriptions, allowing the T5 model to attend to both this contextual information and the conversation. Its ability to generalize comes from the formulation of these descriptions. Instead of using a name for each slot, we assign a random index for every slot. For categorical slots (i.e., slots that only take values from a small, predefined set), possible values are also arbitrarily enumerated and then listed. The same is done with intents, and together these descriptions form the schema representation to be included in the input string. This is concatenated with the conversation text and fed into the T5 model. The target output is the belief state and user intent, again identified by their assigned indices. An example is shown below.
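To make the input format concrete, here is a minimal Python sketch that builds one D3ST-style training pair from a toy schema. The `index=description` layout, the `0a)`-style enumeration of categorical values, and the `[states]`/`[intents]` target markers are assumptions modeled on the description above; the exact separators and tags used in the paper may differ, and `build_d3st_example` is a hypothetical helper, not part of the released code.

```python
import random

def build_d3st_example(slots, intents, turns, state, active_intent, rng):
    """Build one D3ST-style (input, target) pair with freshly shuffled indices.

    slots:   {slot_name: (description, categorical_values or None)}
    intents: {intent_name: description}
    turns:   [(speaker, utterance), ...]
    state:   {slot_name: value} ground-truth belief state
    """
    # Draw a new random index for every slot and intent in this example.
    slot_names = list(slots)
    rng.shuffle(slot_names)
    slot_index = {name: i for i, name in enumerate(slot_names)}
    intent_names = list(intents)
    rng.shuffle(intent_names)
    intent_index = {name: f"i{i}" for i, name in enumerate(intent_names)}

    # Schema part of the prompt: "<index>=<description>", categorical values enumerated.
    parts = []
    for name in slot_names:
        desc, values = slots[name]
        entry = f"{slot_index[name]}={desc}"
        if values:
            entry += "".join(
                f" {slot_index[name]}{chr(97 + j)}) {v}" for j, v in enumerate(values)
            )
        parts.append(entry)
    parts += [f"{intent_index[n]}={intents[n]}" for n in intent_names]
    # Conversation part of the prompt.
    parts += [f"[{speaker}] {utt}" for speaker, utt in turns]
    source = " ".join(parts)

    # Target: slot values and the active intent, referenced only by index.
    state_parts = []
    for name in sorted(state, key=slot_index.get):
        _, values = slots[name]
        idx = slot_index[name]
        value = f"{idx}{chr(97 + values.index(state[name]))}" if values else state[name]
        state_parts.append(f"{idx}={value}")
    target = "[states] " + " ".join(state_parts) + " [intents] " + intent_index[active_intent]
    return source, target, slot_index
```

Because the indices are redrawn on every call, the same slot can land on a different index in each training example.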

An example of the D3ST input and output format. The red text contains slot descriptions, while the blue text contains intent descriptions. The yellow text contains the conversation utterances.

This forces the model to refer to each slot by its index in the prompt rather than by a fixed slot name. By randomizing the index we assign to each slot between different examples, we prevent the model from learning specific schema information. The slot with index 0 could be the “Train Departure” slot in one example and the “Train Destination” slot in another; as such, the model is encouraged to use the slot description given at index 0 to find the correct value, and discouraged from overfitting to a specific schema. With this setup, a model that sees enough different tasks or domains will learn to generalize the action of belief state tracking and intent prediction.
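Going the other way, a predicted target only becomes meaningful once it is mapped back through the index assignment used in that example's prompt. The sketch below (a hypothetical `decode_d3st_target`, assuming an `index=value` target format with `[states]` and `[intents]` markers) decodes one prediction under two different assignments to show that the index itself carries no schema-specific meaning.

```python
import re

def decode_d3st_target(target, slot_index, intent_index):
    """Map an index-based D3ST-style output back to schema names.

    slot_index:   {slot_name: int} as used in this example's prompt
    intent_index: {intent_name: "i<k>"} as used in this example's prompt
    """
    idx_to_slot = {str(i): name for name, i in slot_index.items()}
    idx_to_intent = {tag: name for name, tag in intent_index.items()}
    states_part, _, intents_part = target.partition(" [intents] ")
    states_part = states_part.removeprefix("[states]").strip()
    # Split on "<index>=" markers; a value may itself contain spaces.
    pairs = re.findall(r"(\d+)=(.*?)(?=\s\d+=|$)", states_part)
    state = {idx_to_slot[i]: value for i, value in pairs}
    return state, idx_to_intent.get(intents_part.strip())
```

Decoding the same string `"[states] 0=New York 1=Boston [intents] i0"` with `{"departure": 0, "destination": 1}` and with `{"departure": 1, "destination": 0}` yields opposite slot assignments, which is exactly why the model has to read the descriptions rather than memorize indices.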

Show Don’t Tell
In our subsequent paper, “Show, Don’t Tell: Demonstrations Outperform Descriptions for Schema-Guided Task-Oriented Dialogue”, we employ a single annotated dialogue example that demonstrates the possible slots and values in a conversation, instead of relying on slot descriptions. In this sense, we “show” the semantics of the schema rather than “tell” the model through descriptions — hence the name “Show Don’t Tell” (SDT). SDT is also built on T5, and improves zero-shot performance beyond D3ST.

An example of the SDT input and output format. The text in red contains the demonstrative example, while the text in blue contains its ground truth belief state. The actual conversation for the model to predict is in yellow. While the D3ST prompt relies entirely on slot descriptions, the SDT prompt contains a concise example dialogue followed by the expected dialogue state annotations, resulting in more direct supervision.
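A minimal sketch of assembling an SDT-style input follows. The `[example]`, `[slots]` and `[conversation]` tags and the `slot=value` annotation format are hypothetical placeholders chosen for illustration; the paper's actual prompt format may differ.

```python
def build_sdt_input(demo_turns, demo_state, turns):
    """Hypothetical SDT-style prompt: one annotated demo dialogue, then the dialogue to track.

    demo_turns: [(speaker, utterance), ...] for the demonstration
    demo_state: {slot_name: value} ground-truth state of the demonstration
    turns:      [(speaker, utterance), ...] for the conversation to predict
    """
    demo = " ".join(f"[{spk}] {utt}" for spk, utt in demo_turns)
    demo_slots = " ".join(f"{slot}={value}" for slot, value in demo_state.items())
    conversation = " ".join(f"[{spk}] {utt}" for spk, utt in turns)
    # The demo and its annotations "show" the schema instead of describing it.
    return f"[example] {demo} [slots] {demo_slots} [conversation] {conversation}"
```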

The rationale for SDT’s single example demonstration is simple: there can still be ambiguities that are not fully captured in a slot or intent description, and require a concrete example to demonstrate. Moreover, from a developer’s standpoint, creating short dialogue examples to describe a schema can often be easier than writing descriptions that fully capture the meaning behind each slot and intent.

Benchmark Results
We evaluate both D3ST and SDT on a number of benchmarks, most notably the SGD dataset, which tests zero-shot generalization to unseen schemas in its test set. We evaluate our state tracking models on joint goal accuracy (JGA), the fraction of dialogue turns for which the model predicts an exactly correct belief state.
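Joint goal accuracy is simple to compute once each turn's belief state is represented as a slot-to-value mapping. The sketch below uses exact dictionary equality; official evaluation scripts may additionally normalize or fuzzily match non-categorical values, which is not modeled here.

```python
def joint_goal_accuracy(predicted_states, gold_states):
    """Fraction of turns whose predicted belief state matches the gold state exactly.

    Each state is a dict {slot_name: value}; a turn counts only if every slot
    and value matches, with no missing or extra slots.
    """
    if len(predicted_states) != len(gold_states):
        raise ValueError("need exactly one prediction per dialogue turn")
    if not gold_states:
        return 0.0
    correct = sum(pred == gold for pred, gold in zip(predicted_states, gold_states))
    return correct / len(gold_states)
```

A single wrong slot value makes the whole turn count as incorrect, which is what makes JGA a strict metric.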

Both of our models either match or outperform existing state-of-the-art baselines (T5DST and paDST) at comparable model sizes, as shown below. In general, SDT performs slightly better than D3ST. Note that our models can be trained on different sizes of the underlying T5 language model. In addition, while the baseline models can only make predictions for one slot per forward pass, both our models can decode the entire dialogue state in a single forward pass — a much more efficient method in both training and inference.

Joint Goal Accuracy on the SGD dataset plotted against model size for existing baselines and our proposed models D3ST and SDT. Note that paDST* includes additional data augmentation.

Additional metrics are reported in both papers. D3ST exhibits state-of-the-art quality on the MultiWOZ dataset, with 75.9% JGA on MultiWOZ 2.4. Both D3ST and SDT show state-of-the-art performance in the MultiWOZ cross-domain leave-one-out setting. In addition, both D3ST and SDT were evaluated using the SGD-X dataset, and demonstrated strong robustness to linguistic variations in schema. These benchmarks all indicate that D3ST and SDT are state-of-the-art TOD models, with the ability to generalize to unseen tasks and domains.

Zero-Shot Capability
D3ST and SDT sometimes demonstrate a surprising ability to generalize to unseen tasks, and we saw many interesting examples when trying completely new dialogues with the model. We’ve included one such example below:

A D3ST model trained on the SGD dataset makes predictions (right) for an unseen meta conversation (left) about creating this blog post. The model predicts a completely correct belief state, even though it is not fine-tuned on anything related to blogs, authors or NLP.

Future Work
These papers demonstrate the feasibility of a zero-shot TOD system that can generalize to unseen tasks or domains. However, we’ve limited ourselves to the DST problem for now — we plan to extend this research to enable zero-shot dialogue policy modeling, allowing TOD systems to take actions following arbitrary instructions. In addition, the current input format can often lead to long input sequences, which can be slow for inference — we’re exploring new and more efficient methods to encode schema information.

Acknowledgements
This post reflects the combined work of Jeffrey Zhao, Raghav Gupta, Harrison Lee, Mingqiu Wang, Dian Yu, Yuan Cao, and Abhinav Rastogi. We’d like to thank Yonghui Wu and Izhak Shafran for their continued advice and guidance.

Lidar-Camera Deep Fusion for Multi-Modal 3D Detection

LiDAR and visual cameras are two types of complementary sensors used for 3D object detection in autonomous vehicles and robots. LiDAR, which is a remote sensing technique that uses light in the form of a pulsed laser to measure ranges, provides low-resolution shape and depth information, while cameras provide high-resolution shape and texture information. While the features captured by LiDAR and cameras should be merged together to provide optimal 3D object detection, it turns out that most state-of-the-art 3D object detectors use LiDAR as the only input. The main reason is that to develop robust 3D object detection models, most methods need to augment and transform the data from both modalities, making the accurate alignment of the features challenging.

Existing algorithms for fusing LiDAR and camera outputs, such as PointPainting, PointAugmenting, EPNet, 4D-Net and ContinuousFusion, generally follow two approaches — input-level fusion where the features are fused at an early stage, decorating points in the LiDAR point cloud with the corresponding camera features, or mid-level fusion where features are extracted from both sensors and then combined. Despite realizing the importance of effective alignment, these methods struggle to efficiently process the common scenario where features are enhanced and aggregated before fusion. This indicates that effectively fusing the signals from both sensors might not be straightforward and remains challenging.

In our CVPR 2022 paper, “DeepFusion: LiDAR-Camera Deep Fusion for Multi-Modal 3D Object Detection”, we introduce a fully end-to-end multi-modal 3D detection framework called DeepFusion that applies a simple yet effective deep-level feature fusion strategy to unify the signals from the two sensing modalities. Unlike conventional approaches that decorate raw LiDAR point clouds with manually selected camera features, our method fuses the deep camera and deep LiDAR features in an end-to-end framework. We begin by describing two novel techniques, InverseAug and LearnableAlign, that improve the quality of feature alignment and are applied to the development of DeepFusion. We then demonstrate state-of-the-art performance by DeepFusion on the Waymo Open Dataset, one of the largest datasets for automotive 3D object detection.

InverseAug: Accurate Alignment under Geometric Augmentation
To achieve good performance on existing 3D object detection benchmarks for autonomous cars, most methods require strong data augmentation during training to avoid overfitting. However, the necessity of data augmentation poses a non-trivial challenge in the DeepFusion pipeline. Specifically, the data from the two modalities use different augmentation strategies, e.g., rotating along the z-axis for 3D point clouds combined with random flipping for 2D camera images, which often results in inaccurate alignment. The augmented LiDAR data then has to go through a voxelization step that converts the point clouds into volume data stored in a three-dimensional array of voxels. The voxelized features are quite different from the raw data, making the alignment even more difficult. To address the alignment issue caused by geometry-related data augmentation, we introduce Inverse Augmentation (InverseAug), a technique used to reverse the augmentation before fusion during the model’s training phase.

In the example below, we demonstrate the difficulties in aligning the augmented LiDAR data with the camera data. In this case, the LiDAR point cloud is augmented by rotation with the result that a given 3D key point, which could be any 3D coordinate, such as a LiDAR data point, cannot be easily aligned in 2D space simply through use of the original LiDAR and camera parameters. To make the localization feasible, InverseAug first stores the augmentation parameters before applying the geometry-related data augmentation. At the fusion stage, it reverses all data augmentation to get the original coordinate for the 3D key point, and then finds its corresponding 2D coordinates in the camera space.

During training, InverseAug resolves the inaccurate alignment from geometric augmentation.
Left: Alignment without InverseAug. Right: Alignment quality improvement with InverseAug.
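The bookkeeping behind InverseAug can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the DeepFusion code: the point-cloud augmentation is taken to be a rotation about the z-axis, a global scale and an optional mirror flip, and `proj` is a hypothetical 3x4 matrix standing in for the real LiDAR-to-camera extrinsics and intrinsics. The key idea is that the recorded parameters are undone in reverse order before a 3D key point is projected into the image.

```python
import numpy as np

def rot_z(theta):
    """3x3 rotation matrix about the z-axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def augment(points, theta, scale, flip_y):
    """Augment Nx3 points and record the parameters needed to invert it."""
    params = {"theta": theta, "scale": scale, "flip_y": flip_y}
    out = scale * (points @ rot_z(theta).T)          # 1) rotate about z, 2) scale
    if flip_y:
        out = out * np.array([1.0, -1.0, 1.0])       # 3) mirror across the x-z plane
    return out, params

def inverse_augment(points, params):
    """Undo the recorded augmentation, last step first."""
    out = points
    if params["flip_y"]:
        out = out * np.array([1.0, -1.0, 1.0])       # undo 3 (a flip is its own inverse)
    out = out / params["scale"]                      # undo 2
    return out @ rot_z(-params["theta"]).T           # undo 1

def project(points, proj):
    """Project Nx3 points to Nx2 pixel coordinates with a 3x4 projection matrix."""
    homo = np.hstack([points, np.ones((len(points), 1))])
    uvw = homo @ proj.T
    return uvw[:, :2] / uvw[:, 2:3]
```

Order matters because a flip and a rotation do not commute; projecting the augmented key point directly, without the inverse step, lands on the wrong pixel.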

LearnableAlign: A Cross-Modality-Attention Module to Learn Alignment
We also introduce Learnable Alignment (LearnableAlign), a cross-modality-attention–based feature-level alignment technique, to improve the alignment quality. For input-level fusion methods, such as PointPainting and PointAugmenting, given a 3D LiDAR point, only the corresponding camera pixel can be exactly located as there is a one-to-one mapping. In contrast, when fusing deep features in the DeepFusion pipeline, each LiDAR feature represents a voxel containing a subset of points, and hence, its corresponding camera pixels are in a polygon. So the alignment becomes the problem of learning the mapping between a voxel cell and a set of pixels.

A naïve approach is to average over all pixels corresponding to the given voxel. However, intuitively, and as supported by our visualized results, these pixels are not equally important because the information from the LiDAR deep feature unequally aligns with every camera pixel. For example, some pixels may contain critical information for detection (e.g., the target object), while others may be less informative (e.g., consisting of backgrounds such as roads, plants, occluders, etc.).

LearnableAlign leverages a cross-modality attention mechanism to dynamically capture the correlations between two modalities. Here, the input contains the LiDAR features in a voxel cell, and all its corresponding camera features. The output of the attention is essentially a weighted sum of the camera features, where the weights are collectively determined by a function of the LiDAR and camera features. More specifically, LearnableAlign uses three fully-connected layers to respectively transform the LiDAR features to a vector (ql), and camera features to vectors (kc) and (vc). For each vector (ql), we compute the dot products between (ql) and (kc) to obtain the attention affinity matrix that contains correlations between the LiDAR features and the corresponding camera features. Normalized by a softmax operator, the attention affinity matrix is then used to calculate weights and aggregate the vectors (vc) that contain camera information. The aggregated camera information is then processed by a fully-connected layer, and concatenated (Concat) with the original LiDAR feature. The output is then fed into any standard 3D detection framework, such as PointPillars or CenterPoint for model training.

LearnableAlign leverages the cross-attention mechanism to align LiDAR and camera features.
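The attention step described above can be sketched for a single voxel as follows. It mirrors the description (three fully-connected layers producing ql, kc and vc, dot-product affinities, a softmax, a weighted sum, one more fully-connected layer, then concatenation), but the bias-free layers, the `params` dictionary and all dimensions are simplifying assumptions, and batching over voxels with varying pixel counts is left out.

```python
import numpy as np

def softmax(x):
    x = x - x.max()                     # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum()

def learnable_align(lidar_feat, cam_feats, params):
    """Fuse one voxel's LiDAR feature with the camera features of its pixels.

    lidar_feat: (d_lidar,)      deep LiDAR feature for a single voxel
    cam_feats:  (n_pix, d_cam)  deep camera features of the pixels it maps to
    params:     weight matrices of the (bias-free) fully-connected layers
    """
    q = lidar_feat @ params["w_q"]      # (d,)        query from LiDAR (ql)
    k = cam_feats @ params["w_k"]       # (n_pix, d)  keys from camera (kc)
    v = cam_feats @ params["w_v"]       # (n_pix, d)  values from camera (vc)
    affinity = k @ q                    # (n_pix,)    LiDAR-camera dot products
    weights = softmax(affinity)         # attention weights over the pixels
    cam_info = weights @ v              # (d,)        weighted sum of camera values
    cam_out = cam_info @ params["w_o"]  # (d_out,)    final fully-connected layer
    return np.concatenate([lidar_feat, cam_out]), weights
```

Replacing `weights` with a uniform vector recovers the naive averaging baseline discussed earlier.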

DeepFusion: A Better Way to Fuse Information from Different Modalities
Powered by our two novel feature alignment techniques, we develop DeepFusion, a fully end-to-end multi-modal 3D detection framework. In the DeepFusion pipeline, the LiDAR points are first fed into an existing feature extractor (e.g., pillar feature net from PointPillars) to obtain LiDAR features (e.g., pseudo-images). In the meantime, the camera images are fed into a 2D image feature extractor (e.g., ResNet) to obtain camera features. Then, InverseAug and LearnableAlign are applied in order to fuse the camera and LiDAR features together. Finally, the fused features are processed by the remaining components of the selected 3D detection model (e.g., the backbone and detection head from PointPillars) to obtain the detection results.

The pipeline of DeepFusion.

Benchmark Results
We evaluate DeepFusion on the Waymo Open Dataset, one of the largest 3D detection challenges for autonomous cars, using the Average Precision with Heading (APH) metric under difficulty level 2, the default metric to rank a model’s performance on the leaderboard. Among the 70 participating teams all over the world, the DeepFusion single and ensemble models achieve state-of-the-art performance in their corresponding categories.

The single DeepFusion model achieves new state-of-the-art performance on Waymo Open Dataset.
The Ensemble DeepFusion model outperforms all other methods on Waymo Open Dataset, ranking No. 1 on the leaderboard.

The Impact of InverseAug and LearnableAlign
We also conduct ablation studies on the effectiveness of the proposed InverseAug and LearnableAlign techniques. We demonstrate that both InverseAug and LearnableAlign individually contribute to a performance gain over the LiDAR-only model, and combining both can further yield an even more significant boost.

Ablation studies on InverseAug (IA) and LearnableAlign (LA) measured in average precision (AP) and APH. Combining both techniques contributes to the best performance gain.

Conclusion
We demonstrate that late-stage deep feature fusion can be more effective when features are aligned well, but aligning features from two different modalities can be challenging. To address this challenge, we propose two techniques, InverseAug and LearnableAlign, to improve the quality of alignment among multimodal features. By integrating these techniques into the fusion stage of our proposed DeepFusion method, we achieve state-of-the-art performance on the Waymo Open Dataset.

Acknowledgements:
Special thanks to co-authors Tianjian Meng, Ben Caine, Jiquan Ngiam, Daiyi Peng, Junyang Shen, Bo Wu, Yifeng Lu, Denny Zhou, Quoc Le, Alan Yuille, Mingxing Tan.

Investing in Eastern Europe’s AI future

It was an honor and a privilege to attend a special event in the Bulgarian capital, Sofia, today to launch INSAIT, the Institute for Computer Science, Artificial Intelligence and Technology. INSAIT is a new AI and computer science research institute that will provide truly world-class facilities.

It’s fantastic to see the country where I was born leading the charge in bridging Eastern Europe to the world-stage in computer science research.

The institute is modeled on the computer science departments of renowned institutions such as MIT, UC Berkeley and the Max-Planck Institute, and is backed by the Bulgarian government with an endowment fund of nearly $100 million. Its computer science and AI research will span topics such as machine learning, quantum computing, information security, robotics and many more. Within two years, INSAIT expects faculty and students to publish papers in top conferences.

Google is investing $3 million over the next three years to provide INSAIT with cloud computing resources and access to its Tensor Processing Unit Research Cloud, a specialized infrastructure for running high-performance machine learning models. Supported by additional investment from DeepMind and Amazon Web Services, INSAIT aims to attract and develop the best researchers, engineers and top PhD and MSc students.

I know there’s no shortage of talented researchers, computer scientists and engineers in Eastern Europe – indeed, Sofia is already ranked as one of Europe’s top tech cities – but historically, the lack of local facilities, funding and support has meant limited opportunities for basic research. INSAIT has been created in partnership with two of the world’s leading technology universities, ETH Zurich and EPFL Lausanne, and its supervisory and advisory boards consist of leading researchers who are committed to helping the institute achieve its ambitious goals.

INSAIT opens in September, and I know the team is particularly keen to receive applications from women and other groups that are often underrepresented in the world of tech.

Google is delighted to support these efforts, and I cannot wait to see what new innovation emerges from this promising venture.

Large-Scale Matrix Factorization on TPUs

Matrix factorization is one of the oldest, yet still widely used, techniques for learning how to recommend items such as songs or movies from user ratings. In its basic form, it approximates a large, sparse (i.e., mostly empty) matrix of user-item interactions with a product of two smaller, denser matrices representing learned item and user features. These dense matrices, in turn, can be used to recommend items to a user with which they haven’t interacted before.
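As a small illustration of that last step, assuming `W` holds one user embedding per row and `H` one item embedding per row, recommending for a user amounts to scoring every item with a dot product and skipping the items the user already interacted with. `recommend` is a hypothetical helper, not part of ALX.

```python
import numpy as np

def recommend(W, H, user, seen, k):
    """Return the indices of the top-k unseen items for `user`."""
    scores = H @ W[user]                                  # predicted affinity for every item
    scores[np.asarray(sorted(seen), dtype=int)] = -np.inf  # drop already-seen items
    return np.argsort(-scores)[:k]                        # highest scores first
```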

Despite its algorithmic simplicity, matrix factorization can still achieve competitive performance in recommender benchmarks. Alternating least squares (ALS), and especially its implicit variation, is a fundamental algorithm for learning the parameters of matrix factorization. ALS is known for its high efficiency because it scales linearly in the number of rows, columns and non-zeros, which makes it very well suited for large-scale challenges. However, for very large real-world matrix factorization datasets, a single-machine implementation does not suffice, and a large distributed system is required. Most distributed implementations of matrix factorization that employ ALS leverage off-the-shelf CPU devices, and rightfully so, due to the inherently sparse nature of the problem (the input matrix is mostly empty).
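To make the ALS step concrete, here is a NumPy sketch of the user-side update for one common implicit-ALS objective, assumed here rather than taken from the post: squared error on observed entries, plus α times the squared prediction over all user-item pairs (the "unobserved weight"), plus an L2 penalty λ. With H fixed, each user's embedding is the solution of a small d x d linear system, and the all-pairs term only needs the Gramian HᵀH, which is computed once per sweep.

```python
import numpy as np

# Assumed objective (one common implicit-ALS form) for user embeddings W, item embeddings H:
#   sum over observed (u, i) of (y_ui - w_u . h_i)^2
#   + alpha * sum over ALL (u, i) of (w_u . h_i)^2      <- "unobserved weight"
#   + lam * (||W||^2 + ||H||^2)

def ials_user_step(H, gramian, item_ids, values, alpha, lam):
    """Closed-form update of one user's embedding with H held fixed."""
    H_s = H[item_ids]                                  # items this user interacted with
    A = H_s.T @ H_s + alpha * gramian + lam * np.eye(H.shape[1])
    b = H_s.T @ values
    return np.linalg.solve(A, b)                       # small d x d system per user

def ials_sweep(H, rows, alpha, lam):
    """Recompute every user embedding; rows is a list of (item_ids, values)."""
    gramian = H.T @ H                                  # shared by all users in the sweep
    return np.stack([ials_user_step(H, gramian, ids, vals, alpha, lam) for ids, vals in rows])
```

The item-side update is the same computation with the roles of W and H swapped; alternating between the two sides is what gives ALS its name.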

On the other hand, the recent success of deep learning, driven by growing computational capacity, has spurred a new wave of research and progress on hardware accelerators such as Tensor Processing Units (TPUs). TPUs afford domain-specific hardware speedups, especially for use cases like deep learning, which involves a large number of dense matrix multiplications. In particular, they allow significant speedups for traditional data-parallel workloads, such as training models with Stochastic Gradient Descent (SGD) in SPMD (single program, multiple data) fashion. The SPMD approach has gained popularity in computations like training neural networks with gradient descent algorithms, and can be used for both data-parallel and model-parallel computations, where the model's parameters are distributed across the available devices. Nevertheless, while TPUs have been enormously attractive for methods based on SGD, it is not immediately clear whether a high-performance implementation of ALS, which requires a large number of distributed sparse matrix multiplies, can be developed for a large-scale cluster of TPU devices.

In “ALX: Large Scale Matrix Factorization on TPUs”, we explore a distributed ALS design that makes efficient use of the TPU architecture and can scale well to matrix factorization problems of the order of billions of rows and columns by scaling the number of available TPU cores. The approach we propose leverages a combination of model and data parallelism, where each TPU core both stores a portion of the embedding table and trains over a unique slice of data, grouped in mini-batches. In order to spur future research on large-scale matrix factorization methods and to illustrate the scalability properties of our own implementation, we also built and released a real world web link prediction dataset called WebGraph.

The figure shows the flow of data and computation through the ALX framework on TPU devices. Similar to SGD-based training procedures, each TPU core performs identical computation for its own batch of data in SPMD fashion, which allows for synchronous computation in parallel on multiple TPU cores. Each TPU starts with gathering all the relevant item embeddings in the Sharded Gather stage. These materialized embeddings are used to solve for user embeddings which are scattered to the relevant shard of the embedding table in the Sharded Scatter stage.

Dense Batching for Improved Efficiency
We designed ALX specifically for TPUs, exploiting unique properties of TPU architecture while overcoming a few interesting limitations. For instance, each TPU core has limited memory and restricts all tensors to have a static shape, but each example in a mini-batch can have a wildly varying number of items (i.e., inputs can be long and sparse). To resolve this, we break exceedingly long examples into multiple smaller examples of the same shape, a process called dense batching. More details about dense batching can be found in our paper.

Illustrating example of how sparse batches are densified to increase efficiency on TPUs.
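A minimal sketch of the chunking idea follows, assuming each row is a list of item ids and that every chunk is padded to a fixed length with a mask; the exact ALX procedure (for example, how short rows are packed) may differ, and `dense_batch` is a hypothetical name. Splitting a row is safe for ALS because the per-row statistics that enter the normal equations (the sum of outer products of the row's item embeddings and their weighted sum) are additive, so chunks of the same row can be summed back together.

```python
import numpy as np

def dense_batch(rows, max_len, pad_id=-1):
    """Split variable-length rows into fixed-shape chunks.

    rows: list of (row_id, item_ids). Returns (row_ids, items, mask) where
    items and mask have the static shape (num_chunks, max_len).
    """
    row_ids, items, mask = [], [], []
    for row_id, ids in rows:
        for start in range(0, len(ids), max_len):
            chunk = list(ids[start:start + max_len])
            pad = max_len - len(chunk)
            row_ids.append(row_id)                  # remember which row the chunk came from
            items.append(chunk + [pad_id] * pad)    # pad to the static length
            mask.append([True] * len(chunk) + [False] * pad)
    return np.array(row_ids), np.array(items), np.array(mask, dtype=bool)
```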

Uniform Sharding of Embedding Tables
With the batching problem solved, we next want to factorize a sparse matrix into two dense embedding matrices (e.g., user and item embeddings) such that the dot products of the resulting embeddings approximate the original sparse matrix. This lets us infer predictions for all positions of the original matrix, including those that were empty, which can be used to recommend items with which users haven’t interacted. Both of the resulting embedding tables (W and H in the figure below) can potentially be too large to fit in a single TPU core, thus requiring a distributed training setup for most large-scale use cases.

Most previous attempts at distributed matrix factorization use a parameter server architecture, where the model parameters are stored on highly available servers and the training data is processed in parallel by workers that are solely responsible for the learning task. In our case, since each TPU core has identical compute and memory, it would be wasteful to use a core only for its memory (storing model parameters) or only for its compute (training). Thus, we designed our system such that each core does both.

Illustrative example of factorizing a sparse matrix Y into two dense embedding matrices W and H.

In ALX, we uniformly divide both embedding tables, thus fully exploiting both the size of distributed memory available and the dedicated low-latency interconnects between TPUs. This is highly efficient for very large embedding tables and results in good performance for distributed gather and scatter operations.

Uniform sharding of both embedding tables (W and H) across TPU cores (in blue).
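The sketch below simulates uniform sharding on one machine: a table is padded to a multiple of the number of cores and split into equally shaped shards, and the sharded gather and scatter reduce to mapping a global row id to a (shard, offset) pair. The contiguous row-to-shard assignment is an assumption for illustration; on TPUs the shards live on separate cores and the gather and scatter run as collective operations over the interconnect, which this single-array stand-in does not model.

```python
import numpy as np

def shard_table(table, num_shards):
    """Pad a (rows, d) table and split it into num_shards equally shaped shards."""
    rows, d = table.shape
    per_shard = -(-rows // num_shards)               # ceil division keeps shapes uniform
    padded = np.zeros((per_shard * num_shards, d), dtype=table.dtype)
    padded[:rows] = table
    return padded.reshape(num_shards, per_shard, d), per_shard

def sharded_gather(shards, per_shard, ids):
    """Fetch rows by global id from the sharded table."""
    shard_id, offset = np.divmod(ids, per_shard)
    return shards[shard_id, offset]

def sharded_scatter(shards, per_shard, ids, values):
    """Write updated rows back to the shards that own them (in place)."""
    shard_id, offset = np.divmod(ids, per_shard)
    shards[shard_id, offset] = values
```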

WebGraph
Since potential applications may involve very large datasets, scalability is an important area for advancement in matrix factorization. To that end, we also release a large real-world web link prediction dataset called WebGraph. This dataset can be easily modeled as a matrix factorization problem where rows and columns are source and destination links, respectively, and the task is to predict destination links from each source link. We use WebGraph to illustrate the scaling properties of ALX.

The WebGraph dataset was generated from a single crawl performed by CommonCrawl in 2021 where we strip everything and keep only the link->outlinks data. Since the performance of a factorization method depends on the properties of the underlying graph, we created six versions of WebGraph, each varying in the sparsity pattern and locale, to study how well ALS performs on each.

  • To study locale-specific graphs, we filter based on two top level domains: ‘de’ and ‘in’, each producing a graph with an order of magnitude fewer nodes.
  • These graphs can still have arbitrary sparsity patterns and dangling links. Thus we further filter the nodes in each graph to have a minimum of either 10 or 50 inlinks and outlinks.
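Both filters can be sketched as below. The helpers are hypothetical, and two details are assumptions not specified in the post: the locale filter keeps an edge only when both endpoints share the top level domain, and the degree filter is a single pass (removing nodes can push other nodes below the threshold, so a real pipeline might instead repeat it until nothing changes).

```python
from collections import Counter
from urllib.parse import urlparse

def filter_by_tld(edges, tld):
    """Keep (source, destination) URL edges whose endpoints both end in `tld`."""
    def host_tld(url):
        return (urlparse(url).hostname or "").rsplit(".", 1)[-1]
    return [(s, d) for s, d in edges if host_tld(s) == tld and host_tld(d) == tld]

def filter_by_degree(edges, min_degree):
    """Keep edges between nodes with at least `min_degree` inlinks and outlinks."""
    out_deg = Counter(s for s, _ in edges)
    in_deg = Counter(d for _, d in edges)
    nodes = set(out_deg) | set(in_deg)
    keep = {n for n in nodes if out_deg[n] >= min_degree and in_deg[n] >= min_degree}
    return [(s, d) for s, d in edges if s in keep and d in keep]
```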

For easy access, we have made these available as a TensorFlow Datasets package. For reference, the biggest version, WebGraph-sparse, has more than 365M nodes and 30B edges. We create and publish both training and testing splits for evaluation purposes.

Results
We carefully tune the system and quality parameters of ALX based on our observations related to precision and the choice of linear solvers. We observed that by carefully selecting the precision for storage of the embedding tables (bfloat16) and for the input to the linear solvers (float32), we were able to halve the memory required for the embeddings while still avoiding problems arising from lower-precision values during the solve stage. For our linear solvers, we selected conjugate gradients, which we found to be the fastest across the board on TPUs. We use embeddings of dimension 128 and train the model for 16 epochs. In our experience, hyperparameter tuning over both the norm penalty (λ) and the unobserved weight (α) has been indispensable for good recall metrics, as shown in the table below.

Results obtained by running ALX on all versions of WebGraph dataset. Recall values of 1.0 denote perfect recall.
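For reference, here is a textbook conjugate gradient solver for a symmetric positive-definite system, the kind of system that appears in each ALS row update. It casts its inputs to float32 to mirror the choice of float32 solver inputs described above; NumPy has no bfloat16 type, so the bfloat16 storage side is not shown, and this is a sketch rather than the ALX implementation.

```python
import numpy as np

def conjugate_gradient(A, b, max_iters=100, rtol=1e-5):
    """Solve A x = b for symmetric positive-definite A with conjugate gradients."""
    A = np.asarray(A, dtype=np.float32)     # solver inputs in float32
    b = np.asarray(b, dtype=np.float32)
    x = np.zeros_like(b)
    r = b.copy()                             # residual b - A x for x = 0
    p = r.copy()                             # first search direction
    rs_old = r @ r
    stop = (rtol * np.linalg.norm(b)) ** 2   # stop once ||r|| <= rtol * ||b||
    for _ in range(max_iters):
        if rs_old <= stop:
            break
        Ap = A @ p
        step = rs_old / (p @ Ap)             # exact line search along p
        x = x + step * p
        r = r - step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p        # next A-conjugate direction
        rs_old = rs_new
    return x
```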

Scaling Analysis
Since the input data are processed in parallel across TPU cores, increasing the number of cores decreases training time, ideally in a linear fashion. But at the same time, a larger number of cores requires more network communication (due to the sharded embedding tables). Thanks to high-speed interconnects, this overhead can be negligible for a small number of cores, but as the number of cores increases, the overhead eventually slows down the ideal linear scaling.
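The trade-off can be illustrated with a toy cost model: per-epoch time is perfectly parallel compute divided by the number of cores, plus a communication term that grows with it. The functional form and the constants below are made up for illustration and are not fitted to any ALX measurement.

```python
def epoch_time(num_cores, compute, comm):
    """Toy model: ideal linear speedup of compute plus overhead growing with core count."""
    return compute / num_cores + comm * num_cores

def sweet_spot(core_counts, compute, comm):
    """Core count with the lowest modeled epoch time."""
    return min(core_counts, key=lambda n: epoch_time(n, compute, comm))
```

With `compute=1000.0` and `comm=0.1`, the continuous optimum of this particular form is at sqrt(compute / comm) = 100 cores; among the powers of two from 8 to 256, the modeled time falls until 128 cores and rises again at 256.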

To confirm our hypothesis, we analyze the scaling properties of the four biggest WebGraph variants in terms of training time as we increase the number of available TPU cores. As shown below, we empirically observe the predicted linear decrease in training time up to a sweet spot, after which the network overhead slows the decline.

Scaling analysis of running time as the number of TPU cores is increased. Each figure plots the time taken to train for one epoch, in seconds.

Conclusion
For easy access and reproducibility, the ALX code is open-sourced and can be easily run on Google Cloud. In fact, we illustrate that a sparse matrix like WebGraph-dense of size 135M x 135M (with 22B edges) can be factorized in a Colab notebook connected to 8 TPU cores in less than a day. We have designed the ALX framework with scalability in mind. With 256 TPU cores, one epoch of the largest WebGraph variant, WebGraph-sparse (a 365M x 365M sparse matrix), takes around 20 minutes to finish (5.5 hours for the whole training run). The final model has around 100B parameters. We hope that ALX and WebGraph will be useful to both researchers and practitioners working in these fields. The code for ALX can be found here on GitHub!
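A quick back-of-the-envelope check connects these numbers to the precision choice described earlier. Assuming one 128-dimensional embedding per node in each of the two tables and taking 365M as the node count (the post says "more than 365M", so this is a lower bound), the parameter count and table memory work out as follows.

```python
nodes = 365_000_000        # "more than 365M" nodes, so this is a lower bound
dim = 128                  # embedding dimension used in the experiments
num_cores = 256            # TPU cores in the largest run

params = 2 * nodes * dim   # one row in W and one row in H per node
bf16_bytes = 2 * params    # bfloat16 storage: 2 bytes per value
f32_bytes = 4 * params     # float32 storage would need 4 bytes per value
per_core_bf16 = bf16_bytes / num_cores  # uniform sharding across cores

print(f"parameters: {params / 1e9:.1f}B")
print(f"tables: {bf16_bytes / 1e9:.0f} GB in bfloat16 vs {f32_bytes / 1e9:.0f} GB in float32")
print(f"per core: {per_core_bf16 / 1e9:.2f} GB")
```

That gives about 93.4B parameters, in line with the roughly 100B reported, and about 187 GB of bfloat16 tables versus about 374 GB in float32, or roughly 0.73 GB per core when spread uniformly over 256 cores.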

Acknowledgements
The core team includes Steffen Rendle, Walid Krichene and Li Zhang. We thank many Google colleagues for helping at various stages of this project. In particular, we are grateful to the JAX team for numerous discussions, especially James Bradbury and Skye Wanderman-Milne; Blake Hechtman for help with XLA and Rasmus Larsen for useful discussions about performance of linear solvers on TPUs. Finally, we’re also grateful to Nicolas Mayoraz and John Anderson for providing useful feedback.
