Decision forest models like random forests and gradient boosted trees are often the most effective tools available for working with tabular data. They provide many advantages over neural networks, including being easier to configure and faster to train. Using trees greatly reduces the amount of code required to prepare your dataset, as they natively handle numeric, categorical, and missing features. And they often give good results out of the box, with interpretable properties.
Although we usually think of TensorFlow as a library to train neural networks, a popular use case at Google is to use TensorFlow to create decision forests.
|An animation of a decision tree classifying data.|
This article provides a migration guide if you were previously creating tree-based models using tf.estimator.BoostedTrees, which was introduced in 2019. The Estimator API took care of much of the complexity of working with models in production, including distributed training and serialization. However, it is no longer recommended for new code.
If you are starting a new project, we recommend that you use TensorFlow Decision Forests (TF-DF). This library provides state-of-the-art algorithms for training, serving and interpreting decision forest models, with many benefits over the previous approach, notably regarding quality, speed, and ease of use.
To start, here are equivalent examples using the Estimator API and TF-DF to create a boosted tree model.
Previously, this is how you would train a gradient boosted trees model with tf.estimator.BoostedTrees (no longer recommended):
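import tensorflow as tf
# Dataset generators (train_path, valid_path, train_df and FLAGS are defined elsewhere).
def make_dataset_fn(dataset_path):
    data = ... # read dataset
    return lambda: tf.data.Dataset.from_tensor_slices(data).batch(64)
# List the possible values for the feature "f_2".
f_2_dictionary = ["NA", "red", "blue", "green"]
# The feature columns define the input features of the model.
feature_columns = [
    tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "f_2", f_2_dictionary,
            # A special value "missing" is used to represent missing values.
            # default_value=0 maps any unlisted value to index 0 ("NA").
            default_value=0)),
]
# Configure the estimator
estimator = tf.estimator.BoostedTreesClassifier(
    feature_columns=feature_columns, n_classes=3, n_trees=1000,
    # Rule of thumb proposed in the BoostedTreesClassifier documentation.
    n_batches_per_layer=max(2, int(len(train_df) / 2 / FLAGS.batch_size)))
# Stop the training if the validation loss stops decreasing (illustrative values).
early_stopping_hook = tf.estimator.experimental.stop_if_no_decrease_hook(
    estimator, metric_name="loss", max_steps_without_decrease=100, min_steps=50)
tf.estimator.train_and_evaluate(estimator,
    train_spec=tf.estimator.TrainSpec(make_dataset_fn(train_path), hooks=[
        # Early stopping needs a CheckpointSaverHook.
        tf.estimator.CheckpointSaverHook(estimator.model_dir, save_steps=500),
        early_stopping_hook]),
    eval_spec=tf.estimator.EvalSpec(make_dataset_fn(valid_path)))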
How to train the same model using TensorFlow Decision Forests
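import tensorflow as tf
import tensorflow_decision_forests as tfdf
# Load the datasets
# This code is similar to the estimator.
def make_dataset(dataset_path):
    data = ... # read dataset
    return tf.data.Dataset.from_tensor_slices(data).batch(64)
train_dataset = make_dataset(train_path)
valid_dataset = make_dataset(valid_path)
# List the input features of the model.
features = [
    tfdf.keras.FeatureUsage("f_2", tfdf.keras.FeatureSemantic.CATEGORICAL),
]
model = tfdf.keras.GradientBoostedTreesModel(
    task=tfdf.keras.Task.CLASSIFICATION,
    num_trees=1000,
    features=features,
    exclude_non_specified_features=True)
model.fit(train_dataset, validation_data=valid_dataset)
# Export the model to a SavedModel ("project/model" is an example path).
model.save("project/model")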
- While not explicit in this example, early stopping is automatically enabled and configured.
- The dictionary of the “f_2” feature is automatically built and optimized (e.g., rare values are merged into an out-of-vocabulary item).
- The number of classes (3 in this example) is automatically determined from the dataset.
- The batch size (64 in this example) has no impact on the model training. Larger values are often preferable because they make reading the dataset more efficient.
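The automatic dictionary construction described above can be pictured with a minimal pure-Python sketch. It assumes a hypothetical `min_frequency` threshold and a hypothetical `<OOV>` item name; TF-DF's model constructors expose a comparable `min_vocab_frequency` argument, but this is not TF-DF's implementation. For simplicity, the sketch also maps missing values to the out-of-vocabulary item, whereas TF-DF handles missing values natively.

```python
from collections import Counter

OOV = "<OOV>"  # Hypothetical name for the out-of-vocabulary item.


def build_dictionary(values, min_frequency=2):
    """Keeps values seen at least `min_frequency` times, plus OOV at index 0."""
    counts = Counter(v for v in values if v is not None)  # None = missing value.
    kept = [v for v, c in counts.items() if c >= min_frequency]
    # Most frequent values first; ties broken alphabetically for determinism.
    kept.sort(key=lambda v: (-counts[v], v))
    return [OOV] + kept


def encode_all(values, dictionary):
    """Maps raw values to indices; rare, unseen and missing values map to 0."""
    index = {v: i for i, v in enumerate(dictionary)}
    return [index.get(v, 0) for v in values]


f_2 = ["red", "blue", "red", "green", "red", "blue", "purple", None]
dictionary = build_dictionary(f_2, min_frequency=2)
codes = encode_all(f_2, dictionary)
print(dictionary)  # ['<OOV>', 'red', 'blue']
print(codes)       # [1, 2, 1, 0, 1, 2, 0, 0]
```

Here "green" and "purple" each appear once, so they share the out-of-vocabulary index instead of getting their own entries.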
TF-DF is all about ease of use, and the previous example can be further simplified and improved, as shown next.
How to train a TensorFlow Decision Forests model (recommended solution)
import tensorflow_decision_forests as tfdf
import pandas as pd
# Pandas dataset can be used easily with pd_dataframe_to_tf_dataset.
train_df = pd.read_csv("project/train.csv")
# Convert the Pandas dataframe into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
model = tfdf.keras.GradientBoostedTreesModel(num_trees=1000)
model.fit(train_ds)
- We did not specify the semantics (e.g., numerical or categorical) of the features. In this case, the semantics are inferred automatically.
- We also didn’t list which input features to use. In this case, all the columns (except for the label) are used. The list and semantics of the input features are visible in the training logs, or with the model inspector API.
- We did not specify any validation dataset. When an algorithm needs one, it extracts its own validation dataset from the training examples. For example, by default, GradientBoostedTreesModel uses 10% of the training data for validation if no validation dataset is provided.
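To make these defaults concrete, here is a rough pure-Python sketch of "use every column except the label, infer each column's semantics, and hold out 10% of the examples for validation". The inference rule (numbers are numerical, everything else categorical), the helper names, and the column "f_1" are hypothetical simplifications; TF-DF's actual inference rules and validation split logic are more detailed.

```python
import random

NUMERICAL, CATEGORICAL = "NUMERICAL", "CATEGORICAL"


def infer_semantics(columns, label):
    """Guesses a semantic for every column except the label.

    `columns` maps a column name to its list of values; None marks a
    missing value and is ignored during inference.
    """
    semantics = {}
    for name, values in columns.items():
        if name == label:
            continue  # The label is never used as an input feature.
        present = [v for v in values if v is not None]
        is_number = all(
            isinstance(v, (int, float)) and not isinstance(v, bool) for v in present
        )
        semantics[name] = NUMERICAL if is_number else CATEGORICAL
    return semantics


def split_train_valid(num_examples, validation_ratio=0.1, seed=1234):
    """Returns (train_indices, valid_indices) for a random holdout split."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    num_valid = int(round(num_examples * validation_ratio))
    return indices[num_valid:], indices[:num_valid]


columns = {
    "f_1": [0.5, 1.2, None, 3.3],  # Hypothetical numerical column.
    "f_2": ["red", "blue", "red", None],
    "my_label": ["a", "b", "a", "c"],
}
print(infer_semantics(columns, label="my_label"))
# {'f_1': 'NUMERICAL', 'f_2': 'CATEGORICAL'}

train_idx, valid_idx = split_train_valid(1000)
print(len(train_idx), len(valid_idx))  # 900 100
```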
Now, let’s look at a couple of differences between the Estimator API and TF-DF.
Differences between the Estimator API and TF-DF
Type of algorithms
TF-DF is a collection of decision forest algorithms. This includes (but is not limited to) the Gradient Boosted Trees available with the Estimator API. Notably, TF-DF also supports Random Forest (great for noisy datasets) and a CART implementation (great for model interpretation).
Exact vs approximate splits
The TF1 GBT Estimator is an approximate tree learning algorithm. Informally, the Estimator builds trees by considering only a random subset of examples and a random subset of the conditions at each step.
By default, TF-DF is an exact tree training algorithm. Informally, TF-DF considers all the training examples and all the possible splits at each step. This is a more common and often better performing solution.
While sometimes faster on very large datasets (>10B examples x features), the Estimator's approximation is often less accurate, because more trees need to be grown to reach the same quality. On small datasets (<100M examples x features), the form of approximate training implemented in the Estimator can even be slower than exact training.
TF-DF also supports various types of “approximate” tree training. The recommended approach is to use exact training and, on large datasets, optionally try approximate training.
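The difference between exact and approximate split search can be sketched on a single numerical feature. The sketch below is an illustration under stated assumptions: it uses a squared-error criterion, and its "approximate" variant scores candidate thresholds on a random sample of the examples only. Neither function reproduces the Estimator's or TF-DF's actual split-finding code.

```python
import random


def sse(ys):
    """Sum of squared errors of a list of targets around their mean."""
    if not ys:
        return 0.0
    mean = sum(ys) / len(ys)
    return sum((y - mean) ** 2 for y in ys)


def split_score(xs, ys, threshold):
    """Total SSE after splitting the examples on `x <= threshold` (lower is better)."""
    left = [y for x, y in zip(xs, ys) if x <= threshold]
    right = [y for x, y in zip(xs, ys) if x > threshold]
    return sse(left) + sse(right)


def exact_split(xs, ys):
    """Tries the midpoint between every pair of consecutive distinct values.

    Quadratic for clarity; real implementations sort once and sweep.
    """
    values = sorted(set(xs))
    candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]
    return min(candidates, key=lambda t: split_score(xs, ys, t))


def approximate_split(xs, ys, sample_size, seed=0):
    """Runs the same search, but only on a random subset of the examples."""
    rng = random.Random(seed)
    sample = rng.sample(range(len(xs)), sample_size)
    return exact_split([xs[i] for i in sample], [ys[i] for i in sample])


xs = [float(i) for i in range(20)]
ys = [0.0 if x < 12 else 1.0 for x in xs]
t_exact = exact_split(xs, ys)
t_approx = approximate_split(xs, ys, sample_size=5)
print("exact:", t_exact, split_score(xs, ys, t_exact))  # exact: 11.5 0.0
print("approximate:", t_approx, split_score(xs, ys, t_approx))
```

Because the exact search considers every distinct partition of the data, its score on the full training set can never be worse than the sampled search's; the sampled search trades some of that quality for less work per step.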
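Inference speed
The Estimator runs inference using a top-down tree routing algorithm, while TF-DF uses an extension of the QuickScorer algorithm. While both algorithms return exactly the same results, the top-down algorithm is less efficient because of excessive branch mispredictions and cache misses. TF-DF inference is generally 10x faster on the same model.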
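To see why a top-down traversal and a QuickScorer-style evaluation agree, here is a simplified single-tree sketch in pure Python. The QuickScorer-style function evaluates every node's condition independently and combines precomputed bitmasks of leaves, instead of following one data-dependent path. The nested-tuple tree encoding and the function names are assumptions made for this illustration; a production engine works on whole forests with optimized memory layouts, not Python tuples.

```python
# Hypothetical encoding: ("split", feature, threshold, left, right) or ("leaf", value).
# An example goes to the left child when x[feature] <= threshold.


def predict_top_down(node, x):
    """Classic traversal: one data-dependent branch per visited node."""
    while node[0] == "split":
        _, feature, threshold, left, right = node
        node = left if x[feature] <= threshold else right
    return node[1]


def compile_quickscorer(tree):
    """Numbers the leaves left to right and builds one bitmask per split node."""
    leaves = []      # Leaf values; bit i of a mask stands for leaves[i].
    conditions = []  # (feature, threshold, bits of the left subtree's leaves)

    def visit(node):
        """Returns the (first, last) leaf indices covered by `node`."""
        if node[0] == "leaf":
            leaves.append(node[1])
            return len(leaves) - 1, len(leaves) - 1
        _, feature, threshold, left, right = node
        first, left_last = visit(left)
        _, last = visit(right)
        left_bits = ((1 << (left_last - first + 1)) - 1) << first
        conditions.append((feature, threshold, left_bits))
        return first, last

    visit(tree)
    all_ones = (1 << len(leaves)) - 1
    # When a condition is false, every leaf of its left subtree is unreachable.
    return leaves, [(f, t, all_ones & ~bits) for f, t, bits in conditions]


def predict_quickscorer(compiled, x):
    leaves, conditions = compiled
    reachable = (1 << len(leaves)) - 1
    for feature, threshold, mask in conditions:
        if not x[feature] <= threshold:  # A "false" node clears its left leaves.
            reachable &= mask
    # The exit leaf is the leftmost leaf that is still reachable (lowest set bit).
    exit_leaf = (reachable & -reachable).bit_length() - 1
    return leaves[exit_leaf]


tree = ("split", 0, 0.5,
        ("split", 1, 2.0, ("leaf", "A"), ("leaf", "B")),
        ("split", 1, 5.0, ("leaf", "C"), ("leaf", "D")))
compiled = compile_quickscorer(tree)
examples = [[0.2, 3.0], [0.9, 6.0], [0.9, 1.0], [0.2, 1.0]]
for x in examples:
    print(x, predict_top_down(tree, x), predict_quickscorer(compiled, x))
```

The loop in `predict_quickscorer` has the same shape for every example, which is what makes this family of algorithms friendlier to branch predictors and caches than a pointer-chasing traversal.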
For latency-critical applications, TF-DF offers a C++ API, which often achieves an inference time of about 1µs per example per core. This is often a 50x-1000x speed-up over TF SavedModel inference (especially on small batches).
Multi-head models
The Estimator supports multi-head models (models that output multiple predictions). TF-DF does not (currently) support multi-head models directly. However, using the Keras Functional API, multiple TF-DF models trained in parallel can be assembled into a multi-head model.
You can learn more about TensorFlow Decision Forests by visiting the website. If you’re new to this library, the beginner example is a good place to start. Experienced TensorFlow users can visit this guide for important details about the differences between using decision forests and neural networks in TensorFlow, including how to configure your training pipeline, and tips on Dataset I/O. You can also see Migrate from Estimator to Keras APIs for more info on migrating from Estimators to Keras in general.