This backend allows forest models trained by several popular machine learning frameworks (including XGBoost, LightGBM, Scikit-Learn, and cuML) to be deployed in a Triton inference server using the RAPIDS Forest Inference Library for fast GPU-based inference. Using this backend, forest models can be deployed seamlessly alongside deep learning models for fast, unified inference pipelines.
Pre-built Triton containers are available from NGC and may be pulled down via:

```bash
docker pull nvcr.io/nvidia/tritonserver:21.06.1-py3
```

Note that the FIL backend cannot be used in the 21.06 version of this container; the 21.06.1 patch release or later is required.
To build the Triton server container with the FIL backend, run the following from the repo root:
```bash
docker build -t triton_fil -f ops/Dockerfile .
```
For those with special build needs (including using a custom base container), you may find our documentation on customized end-to-end builds useful.
Before starting the server, you will need to set up a "model repository" directory containing the model you wish to serve as well as a configuration file. The FIL backend currently supports forest models serialized in XGBoost's binary format, XGBoost's JSON format, LightGBM's text format, and Treelite's binary checkpoint format. For those using cuML or Scikit-Learn random forest models, please see the documentation on how to prepare such models for use in Triton. Once you have a serialized model, you will need to prepare a directory structure similar to the following example, which uses an XGBoost binary file:
```
model_repository/
`-- fil
    |-- 1
    |   `-- xgboost.model
    `-- config.pbtxt
```
By default, the FIL backend assumes that XGBoost binary models will be named `xgboost.model`, XGBoost JSON models will be named `xgboost.json`, LightGBM models will be named `model.txt`, and Treelite binary models will be named `checkpoint.tl`, but this can be tweaked through standard Triton configuration options.
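As a concrete illustration, the following sketch shows one way to produce such a model file and directory layout. It assumes XGBoost and scikit-learn are installed; the synthetic dataset and training parameters are placeholders, and 500 features are used simply to match the example configuration shown later.

```python
import os

import xgboost as xgb
from sklearn.datasets import make_classification

# Generate a small synthetic binary classification problem (placeholder data)
X, y = make_classification(n_samples=1000, n_features=500, random_state=0)

# Train a small XGBoost model
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(
    {'objective': 'binary:logistic', 'max_depth': 6},
    dtrain,
    num_boost_round=10
)

# Save the model in XGBoost's binary format under the expected layout;
# a filename that does not end in ".json" produces the binary format
model_dir = os.path.join('model_repository', 'fil', '1')
os.makedirs(model_dir, exist_ok=True)
bst.save_model(os.path.join(model_dir, 'xgboost.model'))
```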
The FIL backend repository includes a Python script for generating example models and configuration files using XGBoost, LightGBM, Scikit-Learn, and cuML. These examples may serve as a useful template for setting up your own models on Triton. See the documentation on generating example models for more details.
Once you have chosen a model to deploy and placed it in the correct directory structure, you will need to create a corresponding `config.pbtxt` file. An example of this configuration file is shown below:
name: "fil"
backend: "fil"
max_batch_size: 8192
input [
{
name: "input__0"
data_type: TYPE_FP32
dims: [ 500 ]
}
]
output [
{
name: "output__0"
data_type: TYPE_FP32
dims: [ 2 ]
}
]
instance_group [{ kind: KIND_GPU }]
parameters [
{
key: "model_type"
value: { string_value: "xgboost" }
},
{
key: "predict_proba"
value: { string_value: "true" }
},
{
key: "output_class"
value: { string_value: "true" }
},
{
key: "threshold"
value: { string_value: "0.5" }
},
{
key: "algo"
value: { string_value: "ALGO_AUTO" }
},
{
key: "storage_type"
value: { string_value: "AUTO" }
},
{
key: "blocks_per_sm"
value: { string_value: "0" }
}
]
dynamic_batching {
preferred_batch_size: [1, 2, 4, 8, 16, 32, 64, 128, 1024, 8192]
max_queue_delay_microseconds: 30000
}
For a full description of the configuration schema, see the Triton server docs. Here, we will simply summarize the most commonly used options and those specific to FIL:
- `max_batch_size`: The maximum number of samples to process in a batch. In general, FIL's efficient handling of even large forest models means that this value can be quite high (2^13 in the example), but this may need to be reduced for your particular hardware configuration if you find that you are exhausting system resources (such as GPU or system RAM).
- `input`: This configuration block specifies information about the input arrays that will be provided to the FIL model. The `dims` field should be set to `[ NUMBER_OF_FEATURES ]`, but all other fields should be left as they are in the example. Note that the `name` field should always be given a value of `input__0`. Unlike some deep learning frameworks where models may have multiple input layers with different names, FIL-based tree models take a single input array with a consistent name of `input__0`.
- `output`: This configuration block specifies information about the arrays output by the FIL model. If the `predict_proba` option (described later) is set to "true" and you are using a classification model, the `dims` field should be set to `[ NUMBER_OF_CLASSES ]`. Otherwise, this can simply be `[ 1 ]`, indicating that the model returns a single class ID for each sample.
- `parameters`: This block contains FIL-specific configuration details. Note that all parameters are input as strings and should be formatted with `key` and `value` fields as shown in the example.
  - `model_type`: One of `"xgboost"`, `"xgboost_json"`, `"lightgbm"`, or `"treelite_checkpoint"`, indicating whether the provided model is in XGBoost binary format, XGBoost JSON format, LightGBM text format, or Treelite binary format, respectively.
  - `predict_proba`: Either `"true"` or `"false"`, depending on whether the desired output is a score for each class or merely the predicted class ID.
  - `output_class`: Either `"true"` or `"false"`, depending on whether the model is a classification or regression model.
  - `threshold`: The threshold score used for class prediction.
  - `algo`: One of `"ALGO_AUTO"`, `"NAIVE"`, `"TREE_REORG"`, or `"BATCH_TREE_REORG"`, indicating which FIL inference algorithm to use. More details are available in the cuML documentation. If you are uncertain which algorithm to use, we recommend selecting `"ALGO_AUTO"`, since it is a safe choice for all models.
  - `storage_type`: One of `"AUTO"`, `"DENSE"`, `"SPARSE"`, or `"SPARSE8"`, indicating the storage format that should be used to represent the imported model. `"AUTO"` indicates that the storage format should be automatically chosen. `"SPARSE8"` is currently experimental.
  - `blocks_per_sm`: If set to any nonzero value (generally between 2 and 7), this provides a limit to improve the cache hit rate for large forest models. In general, network latency will significantly overshadow any speedup from tweaking this setting, but it is provided for cases where maximizing throughput is essential. Please see the cuML documentation for a more thorough explanation of this parameter and how it may be used.
- `dynamic_batching`: This configuration block specifies how Triton should perform dynamic batching for your model. Full details about these options can be found in the main Triton documentation. You may find it useful to test your configuration using the Triton `perf_analyzer` tool in order to optimize performance.
  - `preferred_batch_size`: A list of preferred values for the number of samples to process in a single batch.
  - `max_queue_delay_microseconds`: The length of the window (in microseconds) during which requests may be accumulated to form a batch of a preferred size.
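If you prefer to generate the configuration programmatically, a rough sketch like the following shows how the options above fit together for a classification model. The `write_fil_config` helper is purely hypothetical (it is not part of the FIL backend or Triton); adjust the values to match your own model and the descriptions above.

```python
import os


def write_fil_config(repo_dir, model_name, num_features, num_classes,
                     model_type='xgboost', predict_proba=True):
    """Write a minimal config.pbtxt for a FIL classification model.

    Illustrative helper only; see the option descriptions above.
    """
    output_dim = num_classes if predict_proba else 1
    proba = 'true' if predict_proba else 'false'
    config = f'''name: "{model_name}"
backend: "fil"
max_batch_size: 8192
input [
  {{
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ {num_features} ]
  }}
]
output [
  {{
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ {output_dim} ]
  }}
]
instance_group [{{ kind: KIND_GPU }}]
parameters [
  {{ key: "model_type" value: {{ string_value: "{model_type}" }} }},
  {{ key: "predict_proba" value: {{ string_value: "{proba}" }} }},
  {{ key: "output_class" value: {{ string_value: "true" }} }},
  {{ key: "threshold" value: {{ string_value: "0.5" }} }},
  {{ key: "algo" value: {{ string_value: "ALGO_AUTO" }} }},
  {{ key: "storage_type" value: {{ string_value: "AUTO" }} }}
]
'''
    model_dir = os.path.join(repo_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(model_dir, 'config.pbtxt'), 'w') as f:
        f.write(config)


# For example, to generate a configuration like the one shown above:
write_fil_config('model_repository', 'fil', num_features=500, num_classes=2)
```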
Note that the configuration is in protobuf format. If invalid protobuf is provided, the model will fail to load, and you will see an error line in the server log containing `Error parsing text-format inference.ModelConfig:` followed by the line and column number where the parsing error occurred.
To run the server with the configured model, execute the following command:
```bash
docker run \
  --gpus=all \
  --rm \
  -p 8000:8000 \
  -p 8001:8001 \
  -p 8002:8002 \
  -v $PATH_TO_MODEL_REPO_DIR:/models \
  triton_fil \
  tritonserver \
  --model-repository=/models
```
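Before sending inference requests, you may want to confirm that the server and model have loaded successfully. A minimal check using the Triton Python client might look like the following sketch, which assumes the server is running locally on the default HTTP port mapped above:

```python
import tritonclient.http as triton_http

client = triton_http.InferenceServerClient(url='localhost:8000')

# Confirm that the server is live and that the "fil" model has loaded
assert client.is_server_ready()
assert client.is_model_ready('fil')
```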
General examples for submitting inference requests to a Triton server are available here. For convenience, we provide the following example code for using the Python client to submit inference requests to a FIL model deployed on a Triton server on the local machine:
```python
import numpy
import tritonclient.http as triton_http
import tritonclient.grpc as triton_grpc

# Set up both HTTP and GRPC clients. Note that the GRPC client is generally
# somewhat faster.
http_client = triton_http.InferenceServerClient(
    url='localhost:8000',
    verbose=False,
    concurrency=12
)
grpc_client = triton_grpc.InferenceServerClient(
    url='localhost:8001',
    verbose=False
)

# Generate dummy data to classify. The number of features must match the
# input dims given in config.pbtxt (500 in the example configuration).
features = 500
samples = 8_192
data = numpy.random.rand(samples, features).astype('float32')

# Set up Triton input and output objects for both HTTP and GRPC
triton_input_http = triton_http.InferInput(
    'input__0',
    (samples, features),
    'FP32'
)
triton_input_http.set_data_from_numpy(data, binary_data=True)
triton_output_http = triton_http.InferRequestedOutput(
    'output__0',
    binary_data=True
)
triton_input_grpc = triton_grpc.InferInput(
    'input__0',
    (samples, features),
    'FP32'
)
triton_input_grpc.set_data_from_numpy(data)
triton_output_grpc = triton_grpc.InferRequestedOutput('output__0')

# Submit inference requests (both HTTP and GRPC)
request_http = http_client.infer(
    'fil',
    model_version='1',
    inputs=[triton_input_http],
    outputs=[triton_output_http]
)
request_grpc = grpc_client.infer(
    'fil',
    model_version='1',
    inputs=[triton_input_grpc],
    outputs=[triton_output_grpc]
)

# Get results as numpy arrays
result_http = request_http.as_numpy('output__0')
result_grpc = request_grpc.as_numpy('output__0')

# Check that we got the same result with both GRPC and HTTP
numpy.testing.assert_almost_equal(result_http, result_grpc)
```
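Because the example configuration sets `predict_proba` to "true" for a two-class model, each row of `output__0` contains one score per class. If you want hard class predictions instead, you can either set `predict_proba` to "false" in the configuration or take the argmax over the class axis on the client side, for example:

```python
# Convert per-class scores into predicted class IDs, one per sample
predicted_class = result_grpc.argmax(axis=1)
```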
For full implementation details as well as information on modifying the FIL backend code or contributing code to the project, please see CONTRIBUTING.md.