Posted by Yunlu Li and Artsiom Ablavatski
Pruning is one of the core optimization techniques provided in the TensorFlow Model Optimization Toolkit (TF MOT). Not only does it help to significantly reduce model size, but it can also be used to accelerate CPU inference on mobile and web. With modern compute-intensive models, pruning as a model optimization technique has drawn significant attention, demonstrating that dense networks can be easily pruned (i.e. a fraction of the weights set to zero) with negligible quality degradation. Today, we are excited to announce a set of updates to the TF MOT Pruning API that simplify pruning and enable developers to build sparse models for fast on-device inference.
Updates to TF MOT
TensorFlow has long-standing support for neural network pruning via the TensorFlow Model Optimization Toolkit (TF MOT) Pruning API. The API, first featured in 2019, introduced essential primitives for pruning and gave researchers throughout the world access to new optimization techniques. Today we are happy to announce experimental updates to the API that further advance model pruning. We are releasing tools that simplify the control of pruning and enable latency reduction for on-device inference.
The TF MOT Pruning API has extensive functionality that provides the user with tools for model manipulation:
- The prune_low_magnitude function applies the PruneLowMagnitude wrapper to every layer in the model
- The PruneLowMagnitude wrapper handles low-level pruning logic
- PruningSchedule controls when pruning is applied
- The PruningSummaries callback logs the pruning progress
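To make the idea behind low-magnitude pruning concrete, here is a minimal, framework-free sketch. It is not the TF MOT implementation: the function name `prune_low_magnitude_weights` and the flat list of weights are assumptions made for illustration, and the real wrapper operates on tensors during training rather than on a Python list.

```python
def prune_low_magnitude_weights(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries zeroed.

    Hypothetical helper for illustration only; not part of TF MOT.
    """
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError("target_sparsity must be in [0, 1]")
    num_to_prune = int(round(target_sparsity * len(weights)))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_indices = set(order[:num_to_prune])
    return [0.0 if i in pruned_indices else w for i, w in enumerate(weights)]


weights = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2, -0.03, 0.6]
sparse_weights = prune_low_magnitude_weights(weights, target_sparsity=0.5)
print(sparse_weights)
```

With a target sparsity of 0.5, the four entries closest to zero are removed and the four largest-magnitude weights are kept unchanged.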
These abstractions make it possible to control almost every aspect of model pruning, i.e. how to prune (PruneLowMagnitude), when to prune (PruningSchedule), and how to track the progress of pruning (PruningSummaries), with the exception of what to prune, i.e. where the PruneLowMagnitude wrapper is applied. We are happy to release an extension of TF MOT, PruningPolicy, a class that controls which parts of the model the PruneLowMagnitude wrapper is applied to. An instance of PruningPolicy is passed as an argument to the prune_low_magnitude function and provides the following functionality:
- Controls where the pruning wrapper should be applied on a per-layer basis
- Checks that the whole model supports pruning via the ensure_model_supports_pruning method
PruningPolicy is an abstract interface, and it can have many implementations depending on the particular application. For latency improvements on CPU via XNNPACK, the concrete implementation PruneForLatencyOnXNNPack applies the pruning wrapper only to the parts of the model that can be accelerated via sparse on-device inference while leaving the rest of the network untouched. Such selective pruning allows an application to maintain model quality while targeting parts of the model that can be accelerated by sparsity.
The example below showcases the PruneForLatencyOnXNNPack policy in action on MobileNetV2:
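import tensorflow as tf
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# See the implementation of the function below.
model = load_model_for_pruning()

# Pass the policy so that only layers XNNPACK can accelerate are wrapped.
model_for_pruning = prune_low_magnitude(
    model, pruning_policy=tfmot.sparsity.keras.PruneForLatencyOnXNNPack())

The load_model_for_pruning helper modifies the padding of the first convolution of the Keras MobileNetV2 so that it is compatible with XNNPACK sparse inference:

def load_model_for_pruning():
  input_tensor = tf.keras.layers.Input((224, 224, 3))
  # Pad explicitly, since the first convolution is switched to `valid` padding.
  input_tensor = tf.keras.layers.ZeroPadding2D(1)(input_tensor)
  model = tf.keras.applications.MobileNetV2(input_tensor=input_tensor)

  def clone_fn(layer):
    if layer.name == 'Conv1':
      # The default padding `SAME` for the first convolution is incompatible
      # with XNNPACK sparse inference.
      layer.padding = 'valid'
      # We ask the model to rebuild since we've changed the padding parameter.
      layer.built = False
    return layer

  return tf.keras.models.clone_model(model, clone_function=clone_fn)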
The PruneForLatencyOnXNNPack policy applies the pruning wrapper only to convolutions with a 1×1 kernel size, since only these layers can be accelerated on CPU by as much as 2x using XNNPACK. The rest of the layers are left untouched, which allows the network to recover from the quality degradation introduced at the pruning step. The policy also verifies that the model is amenable to pruning by using the ensure_model_supports_pruning method. Once the sparse model has been trained and converted, we recommend using the TensorFlow Lite benchmark utility in debug mode to confirm that the final model is compatible with XNNPACK's sparse inference backend.
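The per-layer decision described above can be sketched without any framework. The `LayerSpec` class, the layer names, and both helper functions below are hypothetical and exist only for illustration. The real PruneForLatencyOnXNNPack policy inspects Keras layers and may apply further checks that are not reproduced here.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LayerSpec:
    """Hypothetical, framework-free description of a layer."""
    name: str
    kind: str  # e.g. "conv2d", "depthwise_conv2d", "dense"
    kernel_size: Tuple[int, int] = (1, 1)


def allow_pruning_1x1_only(layer: LayerSpec) -> bool:
    """Allow pruning only for regular convolutions with a 1x1 kernel."""
    return layer.kind == "conv2d" and layer.kernel_size == (1, 1)


def select_layers_to_prune(layers: List[LayerSpec]) -> List[str]:
    """Return the names of the layers that would receive a pruning wrapper."""
    return [layer.name for layer in layers if allow_pruning_1x1_only(layer)]


layers = [
    LayerSpec("stem_conv", "conv2d", (3, 3)),
    LayerSpec("expand", "conv2d", (1, 1)),
    LayerSpec("depthwise", "depthwise_conv2d", (3, 3)),
    LayerSpec("project", "conv2d", (1, 1)),
]
selected = select_layers_to_prune(layers)
print(selected)
```

In this toy block only the two 1×1 convolutions are selected, while the 3×3 stem and the depthwise convolution stay dense.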
We hope that this newly introduced experimental API will be useful in practice and we will continue to improve its stability and flexibility in the future.
Compression and Latency Improvements
Model compression is another major benefit of applying pruning to a model. Using a smart compression format allows efficient storage of model weights, which leads to a significant size reduction. The sparse tensor format used for storage has several advantages:
- It supports flexible traversal order to store a tensor in row-major or column-major formats easily.
- It supports multi-dimensional sparse tensors like the 4-D filter of a convolution op.
- It can represent block structure as the inner dimension of the tensor (for example, a 4×4 tensor with a 2×2 inner block structure).
We also adapted the format to use flexible data types for the metadata storing the indices of non-zero elements. This reduces the storage overhead for small tensors and for tensors with compact data types.
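As a rough illustration of block structure, the sketch below encodes a 4×4 matrix with 2×2 blocks in a BSR-like layout, storing only the blocks that contain a non-zero value. This is an assumption-laden toy and not the actual TensorFlow Lite storage format: the function `encode_block_sparse` and the names `row_ptr`, `col_idx` and `blocks` are chosen for illustration.

```python
def encode_block_sparse(matrix, block_rows, block_cols):
    """Encode a dense 2-D list as non-zero blocks in a BSR-like layout.

    Returns (row_ptr, col_idx, blocks). row_ptr[i]..row_ptr[i + 1] delimits
    the stored blocks of block-row i, col_idx holds the block-column of each
    stored block, and blocks holds each block's values in row-major order.
    """
    n_rows, n_cols = len(matrix), len(matrix[0])
    if n_rows % block_rows or n_cols % block_cols:
        raise ValueError("matrix shape must be divisible by the block shape")
    row_ptr, col_idx, blocks = [0], [], []
    for br in range(0, n_rows, block_rows):
        for bc in range(0, n_cols, block_cols):
            block = [matrix[r][c]
                     for r in range(br, br + block_rows)
                     for c in range(bc, bc + block_cols)]
            # Keep the block only if at least one of its values is non-zero.
            if any(v != 0 for v in block):
                col_idx.append(bc // block_cols)
                blocks.append(block)
        row_ptr.append(len(col_idx))
    return row_ptr, col_idx, blocks


dense = [
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 5, 6],
]
row_ptr, col_idx, blocks = encode_block_sparse(dense, 2, 2)
print(row_ptr, col_idx, blocks)
```

The index metadata here (`row_ptr` and `col_idx`) contains only small integers, which is why a narrow integer type for metadata can noticeably reduce overhead for small tensors.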
To realize size reductions in practice, the tf.lite.Optimize.EXPERIMENTAL_SPARSITY optimization needs to be applied during model conversion. This optimization examines the model for sparse tensors and converts them to an efficient storage format. It also works seamlessly with quantization, and you can combine the two to achieve more aggressive model compression. The full example of such a conversion is shown below:
# Remove the pruning wrappers from the model.
model = tfmot.sparsity.keras.strip_pruning(model)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# We apply float16 quantization together with sparsity optimization that
# compactly stores pruned weights.
converter.optimizations = [
    tf.lite.Optimize.EXPERIMENTAL_SPARSITY,  # Enables size reduction optimization.
    tf.lite.Optimize.DEFAULT  # Enables quantization at conversion.
]
converter.target_spec.supported_types = [tf.float16]
tflite_buffer = converter.convert()
After applying the tf.lite.Optimize.EXPERIMENTAL_SPARSITY optimization together with the PruneForLatencyOnXNNPack pruning policy, a ~2x size reduction can be achieved, as demonstrated in Figure 1:
Figure 1. Ablation study of MobileNetV2 model size (float32 and float16 types) with different sparsity levels using the PruneForLatencyOnXNNPack pruning policy. Only the 1×1 convolutional layers are pruned and the rest of the layers are left dense.
In addition to size reduction, pruning can provide inference acceleration on CPU via XNNPACK. Using the PruneForLatencyOnXNNPack pruning policy, we've conducted an ablation study of CPU inference latency for a MobileNetV2 model on Pixel 4 using the TensorFlow Lite benchmark with the use_xnnpack option enabled:
Figure 2. Ablation study of CPU inference speed of MobileNetV2 model with different sparsity levels on a Pixel 4 device.
The study in Figure 2 demonstrates a 1.7x latency improvement when running on mobile devices using XNNPACK. The strategies for training the sparse MobileNetV2 model, together with hyperparameters and pre-trained checkpoints, are described in Elsen et al.
Pruning techniques & tips
Pruning-aware training is a key step in model optimization. Many hyperparameters are involved in training, and some of them, like the pruning schedule and learning rate, can have a dramatic impact on the final quality of the model. Though many strategies have been proposed, a simple yet effective 3-step strategy (see Table 1) achieves strong performance for the majority of our use cases. The strategy builds on top of the well-proven approach from Zhu & Gupta and produces good results without extensive re-training:
| Step | Learning rate | Duration | Notes |
|---|---|---|---|
| 1. Pre-training or using pre-trained weights (optional) | The same as for the regular dense network: starting from a high value (possibly with warm-up) and ending with a low value | The same as for the regular dense network | Paired with weight decay regularization, this step helps the model to push unimportant weights towards 0 for pruning in the next step |
| 2. Iterative pruning | Constant, mean of the learning rate values for the regular training | 30 to 100 epochs | Iterative pruning step during which weights become sparse |
| 3. Fine-tuning | The same as at the first stage but without the warm-up stage | The same as at the first stage | Helps to mitigate quality degradation after the pruning step |

Table 1. 3-step schedule for training the sparse model
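The learning-rate column of the schedule above can be sketched in plain Python. The linear warm-up and linear decay shapes, the function names, and the numeric values are all assumptions made for illustration. The text only prescribes how the three stages relate to the regular dense schedule, not its exact shape.

```python
def regular_schedule(num_epochs, peak_lr, final_lr, warmup_epochs):
    """Hypothetical dense-training schedule: linear warm-up, then linear decay."""
    rates = []
    for epoch in range(num_epochs):
        if epoch < warmup_epochs:
            rates.append(peak_lr * (epoch + 1) / warmup_epochs)
        else:
            span = max(num_epochs - warmup_epochs - 1, 1)
            frac = (epoch - warmup_epochs) / span
            rates.append(peak_lr + (final_lr - peak_lr) * frac)
    return rates


def three_step_learning_rates(num_epochs, peak_lr, final_lr, warmup_epochs,
                              pruning_epochs):
    """Return per-epoch learning rates for the three stages."""
    # Stage 1: the regular dense schedule, including warm-up.
    stage1 = regular_schedule(num_epochs, peak_lr, final_lr, warmup_epochs)
    # Stage 2: constant rate equal to the mean of the regular schedule.
    stage2 = [sum(stage1) / len(stage1)] * pruning_epochs
    # Stage 3: same duration as stage 1, but without the warm-up phase.
    stage3 = regular_schedule(num_epochs, peak_lr, final_lr, warmup_epochs=0)
    return stage1, stage2, stage3


stage1, stage2, stage3 = three_step_learning_rates(
    num_epochs=10, peak_lr=0.1, final_lr=0.001, warmup_epochs=2,
    pruning_epochs=30)
print(len(stage1), len(stage2), len(stage3))
```

Stage 2 is the only stage whose length is chosen independently (30 to 100 epochs in the table), while stages 1 and 3 reuse the duration of regular training.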
The strategy inevitably leads to a substantial increase (~3x) in training time. However, paired with the PolynomialDecay pruning schedule, this 3-step strategy achieves limited or no quality degradation with significantly pruned (>70%) neural networks.
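For intuition about how a polynomial decay schedule ramps up sparsity, here is a minimal sketch of the cubic form proposed by Zhu & Gupta. The function name and parameters are assumptions for illustration, and the TF MOT PolynomialDecay class may differ in details such as how often the sparsity is updated.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step`, rising from initial to final sparsity.

    Hypothetical helper for illustration; power=3 gives the cubic schedule.
    """
    if end_step <= begin_step:
        raise ValueError("end_step must be greater than begin_step")
    # Outside [begin_step, end_step] the sparsity stays at the boundary value.
    clamped = min(max(step, begin_step), end_step)
    progress = (clamped - begin_step) / (end_step - begin_step)
    return (final_sparsity
            + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power)


for step in (0, 250, 500, 750, 1000):
    sparsity = polynomial_decay_sparsity(
        step, initial_sparsity=0.0, final_sparsity=0.8,
        begin_step=0, end_step=1000)
    print(step, round(sparsity, 4))
```

The schedule prunes quickly at the beginning, when many weights are redundant, and slows down as the target sparsity is approached.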
Pruned models in MediaPipe
Together with the updates to the TF MOT Pruning API, we are happy to release pruned models for some of the MediaPipe solutions. The released models include pose and face detectors as well as a pruned hand tracking model. All of these models have been trained with the newly introduced functionality using the 3-step pruning strategy. Compared with dense baselines, the released pruned models have demonstrated significant model size reduction as well as superior performance when running on CPU via XNNPACK. Quality-wise, the pruned models achieve similar metrics, including in the evaluation on our fairness datasets (see model cards for details). Side-by-side demos of the solutions are shown below:
Figure 3. Comparison of dense (left) and sparse (right) models in the end-to-end examples of face (top) and pose (bottom) detection.
Pruning for GPU
While exploiting sparsity on GPUs can be challenging, recent work has made progress in improving the performance of sparse operations on these platforms. There is momentum for adding first-class support for sparse matrices and sparse operations in popular frameworks, and state-of-the-art GPUs have recently added hardware acceleration for some forms of structured sparsity. Going forward, improvements in software and hardware support for sparsity in both training and inference will be a key contributor to progress in the field.
TF MOT offers a variety of model optimization methods, many of which have proven to be essential for efficient on-device model inference. We will continue to expand the TF MOT Pruning API with algorithms beyond low magnitude pruning, and also investigate the combination of pruning and quantization techniques to achieve even better results for on-device inference. Stay tuned!
Huge thanks to all who worked on this project: Karthik Raveendran, Ethan Kim, Marat Dukhan, Trevor Gale, Utku Evci, Erich Elsen, Frank Barchard, Yury Kartynnik, Valentin Bazarevsky, Matsvei Zhdanovich, Juhyun Lee, Chuo-Ling Chang, Ming Guang Yong, Jared Duke and Matthias Grundmann.