Build fast, sparse on-device models with the new TF MOT Pruning API

Posted by Yunlu Li and Artsiom Ablavatski

Introduction

Pruning is one of the core optimization techniques provided in the TensorFlow Model Optimization Toolkit (TF MOT). Not only does it help to significantly reduce model size, but it can also be used to accelerate CPU inference on mobile and web. With modern compute-intensive models, pruning as a model optimization technique has drawn significant attention, demonstrating that dense networks can be easily pruned (i.e. a fraction of the weights set to zero) with negligible quality degradation. Today, we are excited to announce a set of updates to the TF MOT Pruning API that simplify pruning and enable developers to build sparse models for fast on-device inference.

Updates to TF MOT

TensorFlow has long-standing support for neural network pruning via the TensorFlow Model Optimization Toolkit (TF MOT) Pruning API. The API, featured in 2019, introduced essential primitives for pruning and gave researchers throughout the world access to new optimization techniques. Today we are happy to announce experimental updates to the API that further advance model pruning. We are releasing tools that simplify the control of pruning and enable latency reduction for on-device inference.

The TF MOT Pruning API has extensive functionality that provides the user with tools for model manipulation:

  • prune_low_magnitude function applies PruneLowMagnitude wrapper to every layer in the model
  • PruneLowMagnitude wrapper handles low-level pruning logic
  • PruningSchedule controls when pruning is applied
  • PruningSummaries callback logs the pruning progress
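
To make the low-level pruning logic in the list above more concrete, here is a minimal pure-Python sketch of low-magnitude pruning on a flat list of weights. It is an illustration only, not the PruneLowMagnitude implementation: the function name zero_smallest_weights and the sample weights are made up for this example, and the real wrapper operates on tensors during training under a pruning schedule.

```python
def zero_smallest_weights(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries zeroed.

    `sparsity` is the fraction of entries to set to zero, in [0, 1].
    Hypothetical helper for illustration; not part of TF MOT.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_indices = set(order[:num_to_prune])
    return [0.0 if i in pruned_indices else w for i, w in enumerate(weights)]


# Made-up weights; half of them are pruned.
weights = [0.8, -0.05, 0.3, 0.01, -0.6, 0.2, -0.02, 0.4]
pruned = zero_smallest_weights(weights, sparsity=0.5)
print(pruned)  # [0.8, 0.0, 0.3, 0.0, -0.6, 0.0, 0.0, 0.4]
```

The four weights closest to zero are removed and the larger ones are kept unchanged, which is why a well-trained dense network can often tolerate this step with little quality loss.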

These abstractions make it possible to control almost any aspect of model pruning: how to prune (PruneLowMagnitude), when to prune (PruningSchedule) and how to track the progress of pruning (PruningSummaries). The exception is what to prune, i.e. where the PruneLowMagnitude wrapper is applied. We are happy to release an extension of TF MOT: PruningPolicy, a class that controls which parts of the model the PruneLowMagnitude wrapper is applied to. An instance of PruningPolicy is passed as an argument to the prune_low_magnitude function and provides the following functionality:

  • Controls where the pruning wrapper should be applied on a per-layer basis through the allow_pruning function
  • Checks that the whole model supports pruning via the ensure_model_supports_pruning function

PruningPolicy is an abstract interface, and it can have many implementations depending on the particular application. For latency improvements on CPU via XNNPACK, the concrete implementation PruneForLatencyOnXNNPack applies the pruning wrapper only to the parts of the model that can be accelerated via sparse on-device inference while leaving the rest of the network untouched. Such selective pruning allows an application to maintain model quality while targeting parts of the model that can be accelerated by sparsity.

The example below showcases the PruneForLatencyOnXNNPack policy in action on MobileNetV2 (the full example is available in a recently introduced colab):

import tensorflow as tf
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# See the implementation of the function below.
model = load_model_for_pruning()

model_for_pruning = prune_low_magnitude(
    model, pruning_policy=tfmot.sparsity.keras.PruneForLatencyOnXNNPack())

In order to comply with the constraints of XNNPACK sparse inference the Keras implementation of MobileNetV2 model requires slight modification of the padding for the first convolution operation:

def load_model_for_pruning():
  input_tensor = tf.keras.layers.Input((224, 224, 3))
  input_tensor = tf.keras.layers.ZeroPadding2D(1)(input_tensor)
  model = tf.keras.applications.MobileNetV2(input_tensor=input_tensor)

  def clone_fn(layer):
    if layer.name == 'Conv1':
      # The default padding `SAME` for the first convolution is incompatible
      # with XNNPACK sparse inference.
      layer.padding = 'valid'
      # We ask the model to rebuild since we've changed the padding parameter.
      layer.built = False
    return layer

  return tf.keras.models.clone_model(model, clone_function=clone_fn)

The PruneForLatencyOnXNNPack policy applies the pruning wrapper only to convolutions with 1×1 kernel size, since only these layers can be accelerated on CPU by as much as 2x using XNNPACK. The rest of the layers are left untouched, which allows the network to recover from the quality degradation introduced at the pruning step. The policy also verifies that the model is amenable to pruning by using the ensure_model_supports_pruning method. Once the sparse model has been trained and converted, we recommend using the TensorFlow Lite benchmark utility in debug mode to confirm that the final model is compatible with XNNPACK's sparse inference backend.
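
The selection rule described above can be sketched without TensorFlow to show how the two policy methods divide the work. This is a simplified illustration under stated assumptions: LayerInfo, PolicySketch, PointwiseConvOnly, select_layers and the layer names are hypothetical stand-ins, and the rule covers only the 1×1 kernel check mentioned in the text, not the full set of checks a real policy may perform.

```python
import abc
from dataclasses import dataclass


@dataclass
class LayerInfo:
    """Hypothetical stand-in for a Keras layer; not a TF MOT type."""
    name: str
    kind: str                # e.g. "conv2d", "depthwise_conv2d", "dense"
    kernel_size: tuple = ()  # empty for layers without a kernel


class PolicySketch(abc.ABC):
    """Mirrors the two responsibilities of a pruning policy."""

    @abc.abstractmethod
    def allow_pruning(self, layer):
        """Return True if the pruning wrapper should be applied to `layer`."""

    @abc.abstractmethod
    def ensure_model_supports_pruning(self, layers):
        """Raise ValueError if the model as a whole cannot be pruned."""


class PointwiseConvOnly(PolicySketch):
    """Simplified rule: only 1x1 convolutions are pruned."""

    def allow_pruning(self, layer):
        return layer.kind == "conv2d" and tuple(layer.kernel_size) == (1, 1)

    def ensure_model_supports_pruning(self, layers):
        if not any(self.allow_pruning(layer) for layer in layers):
            raise ValueError("model has no layers this policy can prune")


def select_layers(layers, policy):
    """Return the names of the layers that would receive the wrapper."""
    policy.ensure_model_supports_pruning(layers)
    return [layer.name for layer in layers if policy.allow_pruning(layer)]


# Made-up layer list loosely shaped like one inverted-residual block.
layers = [
    LayerInfo("stem_conv", "conv2d", (3, 3)),
    LayerInfo("expand_1x1", "conv2d", (1, 1)),
    LayerInfo("depthwise_3x3", "depthwise_conv2d", (3, 3)),
    LayerInfo("project_1x1", "conv2d", (1, 1)),
    LayerInfo("classifier", "dense"),
]
selected = select_layers(layers, PointwiseConvOnly())
print(selected)  # ['expand_1x1', 'project_1x1']
```

Keeping the per-layer decision separate from the whole-model check means an unsupported model fails early, before any layer is wrapped.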

We hope that this newly introduced experimental API will be useful in practice and we will continue to improve its stability and flexibility in the future.

Compression and Latency Improvements

Model compression is another major benefit of applying pruning to a model. Using a smart compression format allows efficient storage of model weights which leads to a significant size reduction.

TFLite adopted the TACO format to encode sparse tensors. Compared to widely used formats like CSR and CSC, the TACO format has several advantages:

  1. It supports flexible traversal order to store a tensor in row-major or column-major formats easily.
  2. It supports multi-dimensional sparse tensors like the 4-D filter of a convolution op.
  3. It can represent block structure as the inner dimension of the tensor (example of a 4×4 tensor with 2×2 inner block structure).
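
As a rough illustration of the block structure in item 3, the sketch below stores a 4×4 matrix with 2×2 blocks by keeping only the non-zero blocks. It is a block-CSR style layout written for this example and is not the serialized format TFLite uses; the names segments, indices and values, and both helper functions, are assumptions.

```python
def encode_block_sparse(matrix, block_rows, block_cols):
    """Encode a dense 2-D list as (segments, indices, values) in block-CSR style."""
    n_rows, n_cols = len(matrix), len(matrix[0])
    if n_rows % block_rows or n_cols % block_cols:
        raise ValueError("matrix shape must be a multiple of the block shape")
    segments = [0]  # segments[i]:segments[i + 1] spans the blocks of block-row i
    indices = []    # block-column index of each stored block
    values = []     # contents of each stored block, flattened row-major
    for br in range(0, n_rows, block_rows):
        for bc in range(0, n_cols, block_cols):
            block = [matrix[r][c]
                     for r in range(br, br + block_rows)
                     for c in range(bc, bc + block_cols)]
            if any(v != 0 for v in block):
                indices.append(bc // block_cols)
                values.extend(block)
        segments.append(len(indices))
    return segments, indices, values


def decode_block_sparse(segments, indices, values, shape, block_rows, block_cols):
    """Rebuild the dense matrix from the block-sparse encoding."""
    n_rows, n_cols = shape
    matrix = [[0] * n_cols for _ in range(n_rows)]
    block_size = block_rows * block_cols
    for block_row in range(len(segments) - 1):
        for k in range(segments[block_row], segments[block_row + 1]):
            block = values[k * block_size:(k + 1) * block_size]
            for offset, v in enumerate(block):
                r = block_row * block_rows + offset // block_cols
                c = indices[k] * block_cols + offset % block_cols
                matrix[r][c] = v
    return matrix


# A 4x4 matrix whose non-zero entries form two 2x2 blocks.
dense = [
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 5, 6],
    [0, 0, 7, 8],
]
segments, indices, values = encode_block_sparse(dense, 2, 2)
print(segments, indices, values)
restored = decode_block_sparse(segments, indices, values, (4, 4), 2, 2)
```

Here 8 values plus 5 small metadata entries replace 16 dense entries, and each stored block stays contiguous in memory.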

We also adapted the format to use flexible data types for the metadata storing the indices of non-zero elements. This reduces the storage overhead for small tensors, or tensors with compact data types like int8_t.
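
The saving from flexible index types is easy to estimate. The sketch below picks the narrowest unsigned integer width that can hold the largest index and compares it with a fixed 4-byte index. The candidate widths (1, 2 and 4 bytes) and the helper names are assumptions for illustration, and the index list is made up.

```python
def smallest_index_width(max_index):
    """Return the byte width of the narrowest unsigned type that holds max_index."""
    if max_index < 0:
        raise ValueError("indices must be non-negative")
    for width in (1, 2, 4):
        if max_index < 2 ** (8 * width):
            return width
    raise ValueError("index does not fit in 32 bits")


def index_metadata_bytes(indices):
    """Bytes needed to store `indices` using the narrowest suitable width."""
    if not indices:
        return 0
    return len(indices) * smallest_index_width(max(indices))


# Made-up example: 100 non-zero positions, all below 256.
column_indices = list(range(0, 200, 2))
narrow_bytes = index_metadata_bytes(column_indices)
fixed_bytes = len(column_indices) * 4
print(narrow_bytes, fixed_bytes)  # 100 400
```

With int8_t weights each value takes one byte, so 4-byte indices would cost four times as much as the data they describe; one-byte indices bring the metadata down to the size of the values.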

In order to realize size reductions in practice, the tf.lite.Optimize.EXPERIMENTAL_SPARSITY optimization needs to be applied during model conversion. This optimization examines the model for sparse tensors and converts them to an efficient storage format. It also works seamlessly with quantization, and you can combine the two to achieve more aggressive model compression. The full example of such a conversion is shown below:

# Remove the pruning wrappers from the model. 
model = tfmot.sparsity.keras.strip_pruning(model)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# We apply float16 quantization together with sparsity optimization that
# compactly stores pruned weights.
converter.optimizations = [
    tf.lite.Optimize.EXPERIMENTAL_SPARSITY,  # Enables size reduction optimization.
    tf.lite.Optimize.DEFAULT  # Enables quantization at conversion.
]
converter.target_spec.supported_types = [tf.float16]
tflite_buffer = converter.convert()

After applying the tf.lite.Optimize.EXPERIMENTAL_SPARSITY optimization together with the PruneForLatencyOnXNNPack pruning policy, a ~2x size reduction can be achieved, as demonstrated in Figure 1:

Figure 1. Ablation study of MobileNetV2 model size (float32 and float16 types) with different sparsity levels using PruneForLatencyOnXNNPack pruning policy. Only the 1×1 convolutional layers are pruned and the rest of the layers are left dense.

In addition to size reduction, pruning can provide inference acceleration on CPU via XNNPACK. Using the PruneForLatencyOnXNNPack pruning policy, we’ve conducted an ablation study of CPU inference latency for a MobileNetV2 model on Pixel 4 using TensorFlow Lite benchmark with the use_xnnpack option enabled:

Figure 2. Ablation study of CPU inference speed of MobileNetV2 model with different sparsity levels on a Pixel 4 device.

The study in Figure 2 demonstrates a 1.7x latency improvement when running on mobile devices using XNNPACK. The strategies for training the sparse MobileNetV2 model, together with hyperparameters and pre-trained checkpoints, are described in Elsen et al.

Pruning techniques & tips

Pruning-aware training is a key step in model optimization. Many hyperparameters are involved in training, and some of them, like the pruning schedule and learning rate, can have a dramatic impact on the final quality of the model. Though many strategies have been proposed, a simple yet effective 3-step strategy (see Table 1) achieves strong performance for the majority of our use cases. The strategy builds on top of the well-proven approach from Zhu & Gupta and produces good results without extensive re-training:

| Step | Learning rate | Duration | Notes |
|---|---|---|---|
| 1. Pre-training or using pre-trained weights (optional) | The same as for the regular dense network: starting from high value (possibly with warm-up) and ending with low value | The same as for the regular dense network | Paired with weight decay regularization this step helps the model to push unimportant weights towards 0 for pruning in the next step |
| 2. Iterative pruning | Constant, mean of the learning rate values for the regular training | 30 to 100 epochs | Iterative pruning step during which weights become sparse |
| 3. Fine-tuning | The same as at the first stage but without warm up stage | The same as at the first stage | Helps to mitigate quality degradation after the pruning step |

Table 1. 3-step schedule for training the sparse model
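
The learning-rate column of the schedule above can be sketched in plain Python. Everything specific here is an assumption made for illustration: the dense schedule is taken to be a linear warm-up followed by a linear decay, durations are counted in steps rather than epochs, and the function names and numeric values are hypothetical.

```python
def dense_learning_rate(step, total_steps, warmup_steps, peak_lr, final_lr):
    """Hypothetical dense schedule: linear warm-up, then linear decay."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps - 1)
    progress = (step - warmup_steps) / decay_steps
    return peak_lr + (final_lr - peak_lr) * progress


def three_step_plan(total_steps, warmup_steps, pruning_steps, peak_lr, final_lr):
    """Return per-step learning rates for the three stages."""
    # Step 1: the regular dense schedule, including warm-up.
    pretraining = [
        dense_learning_rate(s, total_steps, warmup_steps, peak_lr, final_lr)
        for s in range(total_steps)
    ]
    # Step 2: constant rate equal to the mean of the dense schedule.
    pruning_lr = sum(pretraining) / len(pretraining)
    pruning = [pruning_lr] * pruning_steps
    # Step 3: the dense schedule again, but without the warm-up stage.
    fine_tuning = [
        dense_learning_rate(s, total_steps, 0, peak_lr, final_lr)
        for s in range(total_steps)
    ]
    return {"pretraining": pretraining, "pruning": pruning, "fine_tuning": fine_tuning}


plan = three_step_plan(
    total_steps=100, warmup_steps=10, pruning_steps=50, peak_lr=0.1, final_lr=0.001)
print(plan["pruning"][0])
```

The pruning stage uses a single constant value, so the only extra hyperparameter it introduces beyond the dense recipe is its duration.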

The strategy inevitably leads to a substantial increase (~3x) in the training time. However, paired with the PolynomialDecay pruning schedule, this 3-step strategy achieves limited or no quality degradation with significantly pruned (>70%) neural networks.
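
For intuition about how sparsity ramps up under a polynomial decay schedule, here is a small pure-Python sketch of the kind of curve proposed by Zhu & Gupta. It is written for illustration rather than taken from the library: the exponent of 3, the parameter names and the step counts are assumptions, and the target of 80% is just an example above the 70% mentioned in the text.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step`, rising quickly at first and then flattening."""
    if end_step <= begin_step:
        raise ValueError("end_step must be greater than begin_step")
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # hold the value outside the window
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


# Hypothetical settings: ramp from 0% to 80% sparsity over 1000 steps.
for step in (0, 250, 500, 750, 1000):
    target = polynomial_decay_sparsity(step, 0.0, 0.8, begin_step=0, end_step=1000)
    print(step, round(target, 4))
```

Most of the weights are removed early, while the network still has plenty of redundancy, and the last few percent are pruned slowly so that training can compensate.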

Pruned models in MediaPipe

Together with the updates to the TF MOT Pruning API, we are happy to release pruned models for some of the MediaPipe solutions. The released models include pose and face detectors as well as a pruned hand tracking model. All of these models have been trained with the newly introduced functionality using the 3-step pruning strategy. Compared with dense baselines, the released pruned models have demonstrated significant model size reduction as well as superior performance when running on CPU via XNNPACK. Quality-wise, the pruned models achieve similar metrics, including in the evaluation on our fairness datasets (see model cards for details). Side-by-side demos of the solutions are shown below:

MediaPipe example showing female waving at camera
MediaPipe example showing person jumping
Figure 3. Comparison of dense (left) and sparse (right) models in the end-to-end examples of face (top) and pose (bottom) detection

Pruning for GPU

While exploiting sparsity on GPUs can be challenging, recent work has made progress in improving the performance of sparse operations on these platforms. There is momentum for adding first-class support for sparse matrices and sparse operations in popular frameworks, and state-of-the-art GPUs have recently added hardware acceleration for some forms of structured sparsity. Going forward, improvements in software and hardware support for sparsity in both training and inference will be a key contributor to progress in the field.

Future directions

TF MOT offers a variety of model optimization methods, many of which have proven to be essential for efficient on-device model inference. We will continue to expand the TF MOT Pruning API with algorithms beyond low magnitude pruning, and also investigate the combination of pruning and quantization techniques to achieve even better results for on-device inference. Stay tuned!

Acknowledgments

Huge thanks to all who worked on this project: Karthik Raveendran, Ethan Kim, Marat Dukhan, Trevor Gale, Utku Evci, Erich Elsen, Frank Barchard, Yury Kartynnik, Valentin Bazarevsky, Matsvei Zhdanovich, Juhyun Lee, Chuo-Ling Chang, Ming Guang Yong, Jared Duke and Matthias Grundmann.
