Triton Inference Server#
Time expected: 10 minutes
NVIDIA Triton Inference Server is a high-performance, open-source inference server for serving deep learning models. It is optimized to deploy models from multiple deep learning frameworks, including TensorRT, TensorFlow, and ONNX, to various deployment targets and cloud providers. Triton is also designed with optimizations to maximize hardware utilization through concurrent model execution and efficient batching strategies.
BentoML now supports running Triton Inference Server as a Runner. The following integration guide assumes that readers are familiar with the BentoML architecture. Check out our tutorial should you wish to learn more about the BentoML service definition.
For more information about Triton, please refer to the Triton Inference Server documentation.
The code examples in this guide can also be found in the example folder.
Why Integrate BentoML with Triton Inference Server?#
If you are an existing Triton user, the integration provides simpler ways to add custom logic in Python, deploy distributed multi-model inference graphs, unify model management across different ML frameworks and workflows, and standardise the model packaging format with versioning and collaboration features. If you are an existing BentoML user, the integration improves runner efficiency and throughput under high load thanks to Triton's efficient C++ runtime.
Prerequisites#
Make sure to have at least BentoML 1.0.16:
$ pip install -U "bentoml[triton]"
Note
Triton Inference Server is currently only available in production mode (the default mode) and will not work in development mode (the --development flag).
Additionally, you will need to have Triton Inference Server installed on your system. Refer to Triton's build documentation to set up your environment. The recommended way to run Triton is through a container (Docker/Podman). To pull the latest Triton container for testing, run:
$ docker pull nvcr.io/nvidia/tritonserver:<yy>.<mm>-py3
Note
<yy>.<mm>: the version of Triton you wish to use. For example, at the time of writing, the latest version is 23.01.
Finally, the example Bento built from the example project with the YOLOv5 model will be referenced throughout this guide.
Note
To develop your own Bento with Triton, you can refer to the example folder for more usage.
Get started with Triton Inference Server#
The Triton Inference Server architecture revolves around a model repository and an inference server. The model repository is a filesystem-based persistent volume that contains the model files and their respective configurations, which define how each model should be loaded and served. The inference server serves said models over either the HTTP/REST or gRPC protocol with various batching strategies.
BentoML provides a simple integration with Triton via Runner:
import bentoml
triton_runner = bentoml.triton.Runner("triton_runner", model_repository="/path/to/model_repository")
The argument model_repository is the path to said model repository that Triton can use to serve the model. Note that model_repository also supports S3 paths:
import bentoml

triton_runner = bentoml.triton.Runner(
    "triton_runner",
    model_repository="s3://bucket/path/to/model_repository",
    cli_args=["--load-model=torchscript_yolov5s", "--model-control-mode=explicit"],
)
Note
If models are saved on the local filesystem, using the Triton runner requires including the model repository explicitly through the include key in bentofile.yaml.
Note
The cli_args argument is a list of arguments that will be passed to the tritonserver command. For example, the --load-model argument is used to load a specific model from the model repository. See tritonserver --help for all available arguments.
From a developer perspective, remote invocation of Triton runners is similar to invoking any other BentoML runner.
Note
By default, bentoml.triton.Runner will run tritonserver with the gRPC protocol. To use the HTTP/REST protocol, pass tritonserver_type="http" to the Runner constructor.
import bentoml
triton_runner = bentoml.triton.Runner("triton_runner", model_repository="/path/to/model_repository", tritonserver_type="http")
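Like any other BentoML runner, the Triton runner is then attached to a Service. A minimal sketch, assuming the triton_runner created above and mirroring the "triton-integration" service shown later in this guide:

import bentoml

# triton_runner is the bentoml.triton.Runner created above.
svc = bentoml.Service("triton-integration", runners=[triton_runner])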
Triton Runner Signatures#
Normally, in a BentoML Runner, one can access the model signatures directly from the runner's attributes. For example, the model signature predict of an iris_classifier_runner (see service definition) can be accessed as iris_classifier_runner.predict.run.
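To illustrate, a minimal sketch of the usual runner call, assuming a hypothetical iris_clf model previously saved with bentoml.sklearn:

import bentoml

# Hypothetical model name; assumes an "iris_clf" sklearn model was saved beforehand.
iris_classifier_runner = bentoml.sklearn.get("iris_clf:latest").to_runner()
iris_classifier_runner.init_local()  # for local testing only

# The "predict" model signature is available directly as a runner attribute.
result = iris_classifier_runner.predict.run([[5.1, 3.5, 1.4, 0.2]])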
However, a Triton runner's attributes represent the individual models defined under the model repository. For example, if the model repository has the following structure:
model_repository
├── onnx_mnist
│   ├── 1
│   │   └── model.onnx
│   └── config.pbtxt
├── tensorflow_mnist
│   ├── 1
│   │   └── model.savedmodel/
│   └── config.pbtxt
└── torchscript_mnist
    ├── 1
    │   └── model.pt
    └── config.pbtxt
Then each model can be accessed as triton_runner.onnx_mnist, triton_runner.tensorflow_mnist, or triton_runner.torchscript_mnist and invoked using either run or async_run.
An example to demonstrate how to call the Triton runner:
import typing as t

import numpy as np
from PIL.Image import Image
from numpy.typing import NDArray

import bentoml

# triton_runner and svc are defined as shown earlier in this guide.


@svc.api(
    input=bentoml.io.Image.from_sample("./data/0.png"), output=bentoml.io.NumpyNdarray()
)
async def bentoml_torchscript_mnist_infer(im: Image) -> NDArray[t.Any]:
    arr = np.array(im) / 255.0
    arr = np.expand_dims(arr, (0, 1)).astype("float32")
    InferResult = await triton_runner.torchscript_mnist.async_run(arr)
    return InferResult.as_numpy("OUTPUT__0")
There are a few things to note here:
1. Triton runners should only be called within an API function. In other words, if triton_runner.torchscript_mnist.async_run is invoked in the global scope, it will not work. This is because Triton is not implemented natively in Python, and hence init_local is not supported:

   triton_runner.init_local()  # TritonRunner 'triton_runner' will not be available for development mode.

2. run and async_run for any Triton runner call take either all positional arguments or all keyword arguments. The arguments should be in the same order as the input/output signatures defined in config.pbtxt.

   For example, if the following config.pbtxt is used for torchscript_mnist:

   platform: "pytorch_libtorch"
   dynamic_batching {}
   input {
     name: "INPUT__0"
     data_type: TYPE_FP32
     dims: -1
     dims: 1
     dims: 28
     dims: 28
   }
   input {
     name: "INPUT__1"
     data_type: TYPE_FP32
     dims: -1
     dims: 1
     dims: 28
     dims: 28
   }
   output {
     name: "OUTPUT__0"
     data_type: TYPE_FP32
     dims: -1
     dims: 10
   }
   output {
     name: "OUTPUT__1"
     data_type: TYPE_FP32
     dims: -1
     dims: 10
   }

   Then run or async_run takes either two positional arguments or the two keyword arguments INPUT__0 and INPUT__1:

   # Both are valid
   triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
   await triton_runner.torchscript_mnist.async_run(
       INPUT__0=np.zeros((1, 28, 28)), INPUT__1=np.zeros((1, 28, 28))
   )

   Mixing positional and keyword arguments will result in an error:

   triton_runner.torchscript_mnist.run(
       np.zeros((1, 28, 28)), INPUT__1=np.zeros((1, 28, 28))
   )  # throws errors

3. run and async_run return an InferResult object. Regardless of the protocol used, the InferResult object has the following methods:

   as_numpy(name: str) -> NDArray[T]: returns the result as a NumPy array. The argument is the name of the output defined in config.pbtxt.

   get_output(name: str) -> InferOutputTensor | dict[str, T]: returns the result as an InferOutputTensor (gRPC) or a dictionary (HTTP). The argument is the name of the output defined in config.pbtxt.

   get_response(self) -> ModelInferResponse | dict[str, T]: returns the entire response as a ModelInferResponse (gRPC) or a dictionary (HTTP). See the short sketch after this list.

   Using the above config.pbtxt as an example, the model has two outputs, OUTPUT__0 and OUTPUT__1.

   To get OUTPUT__0 as a NumPy array:

   InferResult = triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
   return InferResult.as_numpy("OUTPUT__0")

   To get OUTPUT__1 as a JSON dictionary:

   InferResult = triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
   return InferResult.get_output("OUTPUT__1", as_json=True)
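To inspect the entire raw response with get_response, a minimal sketch assuming the same torchscript_mnist model and the default gRPC protocol:

InferResult = triton_runner.torchscript_mnist.run(np.zeros((1, 28, 28)), np.zeros((1, 28, 28)))
# ModelInferResponse when using gRPC, a plain dictionary when using HTTP
raw_response = InferResult.get_response()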
Additionally, the Triton runner exposes all tritonclient functions.
The list below comprises all the model management APIs from tritonclient that are supported by Triton runners:
get_model_config
get_model_metadata
get_model_repository_index
is_model_ready
is_server_live
is_server_ready
load_model
unload_model
infer
stream_infer
The following advanced client APIs are also supported:
get_cuda_shared_memory_status
get_inference_statistics
get_log_settings
get_server_metadata
get_system_shared_memory_status
get_trace_settings
register_cuda_shared_memory
register_system_shared_memory
unregister_cuda_shared_memory
unregister_system_shared_memory
update_log_settings
update_trace_settings
Important: All of the client APIs are asynchronous. To use them, make sure to call them within an async @svc.api. See Sync vs Async APIs.

service.py

@svc.api(input=bentoml.io.Text.from_sample("onnx_mnist"), output=bentoml.io.JSON())
async def unload_model(input_model: str):
    await triton_runner.unload_model(input_model)
    return {"unloaded": input_model}
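As another illustration of these client APIs, the following is a minimal sketch that loads a model on demand and reports whether it is ready. It assumes the triton_runner and svc defined earlier, a model named onnx_mnist in the repository, and that the server was started with --model-control-mode=explicit:

@svc.api(input=bentoml.io.Text.from_sample("onnx_mnist"), output=bentoml.io.JSON())
async def model_status(input_model: str):
    # load_model only takes effect when --model-control-mode=explicit is set
    await triton_runner.load_model(input_model)
    ready = await triton_runner.is_model_ready(input_model)
    return {"model": input_model, "ready": ready}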
Packaging BentoService with Triton Inference Server#
To build your BentoService with Triton Inference Server, add the following to your bentofile.yaml or use bentoml.bentos.build:
service: service:svc
include:
  - /model_repository
  - /data/*.png
  - /*.py
exclude:
  - /__pycache__
  - /venv
  - /train.py
  - /build_bento.py
  - /containerize_bento.py
python:
  packages:
    - bentoml[triton]
docker:
  base_image: nvcr.io/nvidia/tritonserver:22.12-py3
Building this Bento with bentoml build:

$ bentoml build

Alternatively, build the Bento programmatically with bentoml.bentos.build:

if __name__ == "__main__":
    import bentoml

    bentoml.bentos.build(
        "service:svc",
        include=["/model_repository", "/data/*.png", "service.py"],
        exclude=["/__pycache__", "/venv"],
        docker={"base_image": "nvcr.io/nvidia/tritonserver:22.12-py3"},
    )
Notice that we are using nvcr.io/nvidia/tritonserver:22.12-py3 as our base image. This can be substituted with any other custom base image that has the tritonserver binary available. See Triton's documentation here to learn more about building or composing a custom Triton image.
Important: The provided Triton image from NVIDIA includes Python 3.8. Therefore, if you are developing your Bento with any other Python version, make sure that your service.py is compatible with Python 3.8.
Tip
To see all available options for Triton, run:
$ docker run --init --rm -p 3000:3000 triton-integration:gpu tritonserver --help
Current Caveats#
At the time of writing, there are a few caveats that you should be aware of when using TritonRunner:
Versioning Policy Limitations#
By default, the model configuration version policy is set to latest(n=1), meaning only the latest version of the model will be loaded into the Triton server.
Currently, TritonRunner only supports the latest policy.
If you have multiple versions of the same model in your BentoService, the runner will only consider the latest version.
For example, if the model repository has the following structure:
model_repository
├── onnx_mnist
│   ├── 1
│   │   └── model.onnx
│   ├── 2
│   │   └── model.onnx
│   └── config.pbtxt
...
Then triton_runner.onnx_mnist will reference the latest version of the model (in this case, version 2).
To use a specific version of said model, refer to the example below:
from __future__ import annotations

import typing as t

import numpy as np
from tritonclient.grpc.aio import InferInput
from tritonclient.grpc.aio import np_to_triton_dtype
from tritonclient.grpc.aio import InferRequestedOutput

import bentoml

if t.TYPE_CHECKING:
    from PIL.Image import Image
    from numpy.typing import NDArray

# triton runner
triton_runner = bentoml.triton.Runner(
    "triton_runner",
    "./model_repository",
    cli_args=[
        "--load-model=onnx_mnist",
        "--load-model=torchscript_yolov5s",
        "--model-control-mode=explicit",
    ],
)

svc = bentoml.Service("triton-integration", runners=[triton_runner])


@svc.api(
    input=bentoml.io.Image.from_sample("./data/0.png"), output=bentoml.io.NumpyNdarray()
)
async def predict_v1(input_data: Image) -> NDArray[t.Any]:
    arr = np.array(input_data) / 255.0
    arr = np.expand_dims(arr, (0, 1)).astype("float32")
    input_0 = InferInput("input_0", arr.shape, np_to_triton_dtype(arr.dtype))
    input_0.set_data_from_numpy(arr)
    output_0 = InferRequestedOutput("output_0")

    # Pass model_version explicitly to target version 1 of onnx_mnist.
    InferResult = await triton_runner.infer(
        "onnx_mnist", inputs=[input_0], model_version="1", outputs=[output_0]
    )
    return InferResult.as_numpy("output_0")
Inference Protocol and Metrics Server#
By default, TritonRunner uses Triton's inference protocol, which is available over both REST and gRPC.
The HTTP/REST API is disabled by default, though it can be enabled when creating the runner by passing tritonserver_type to the Runner:
triton_runner = bentoml.triton.Runner(
    "http_runner",
    "/path/to/model_repository",
    tritonserver_type="http",
)
Currently, TritonRunner does not support running the metrics server. If you are interested in having the metrics server supported, please open an issue on GitHub.
Additionally, BentoML will allocate a random port for the gRPC/HTTP server; hence, any grpc-port or http-port options passed to the Runner's cli_args will be ignored.
Adaptive Batching#
Adaptive batching is a feature supported by BentoML runners that allows for efficient batch size selection during inference. However, it's important to note that this feature is not compatible with TritonRunner.
TritonRunner is designed as a standalone Triton server, which means that the adaptive batching logic in BentoML runners is not invoked when using TritonRunner.
Fortunately, Triton supports its own solution for efficient batching called dynamic batching. Similar to adaptive batching, dynamic batching also allows for the selection of the optimal batch size during inference. To use dynamic batching in Triton, relevant settings can be specified in the model configuration file.
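The example config.pbtxt earlier in this guide already enables dynamic batching with an empty dynamic_batching {} block. To verify which configuration Triton actually loaded, a minimal sketch using the runner's get_model_config client API, assuming the default gRPC protocol and that the call accepts the client's as_json flag:

@svc.api(input=bentoml.io.Text.from_sample("torchscript_mnist"), output=bentoml.io.JSON())
async def model_config(input_model: str):
    # Returns the loaded model configuration, including any dynamic_batching settings.
    return await triton_runner.get_model_config(input_model, as_json=True)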
🚧 Help us improve the integration!
This integration is still in its early stages and we are looking for feedback and contributions to make it even better!
If you have any feedback or want to contribute any improvements to the Triton Inference Server integration, we would love to see your feature requests and pull requests!
Check out the BentoML development guide and documentation guide to get started.