diff --git a/doc/deployment/tensorrt.md b/doc/deployment/tensorrt.md
new file mode 100644
index 00000000..f8ca113c
--- /dev/null
+++ b/doc/deployment/tensorrt.md
@@ -0,0 +1,47 @@
+
+# TensorRT Backend
+
+For deployment on NVIDIA targets we support a TensorRT backend.
+Currently the TensorRT backend always compiles for the first GPU of the local system.
+
+## Installation
+
+The tensorrt package needs to be installed separately via pip inside the poetry environment:
+
+```
+poetry shell
+pip install tensorrt
+```
+
+## Configuration
+
+The backend supports the following configuration options.
+
+val_batches
+: 1 (number of batches used for validation)
+
+test_batches
+: 1 (number of batches used for testing)
+
+val_frequency
+: 10 (run the backend every n validation epochs)
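+
+## Usage
+
+A minimal usage sketch (assuming the standard Hydra setup of the `hannah-train` entry point, with this file registered as the `tensorrt` option of the `backend` config group):
+
+```
+hannah-train backend=tensorrt
+```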
+
+## TODO:
+
+- [ ] remote execution support
+- [ ] profiling and feedback support
diff --git a/hannah/backends/__init__.py b/hannah/backends/__init__.py
index 2eca7df0..d9e728a5 100644
--- a/hannah/backends/__init__.py
+++ b/hannah/backends/__init__.py
@@ -19,9 +19,11 @@
 from .onnxrt import OnnxruntimeBackend
+from .tensorrt import TensorRTBackend
 from .torch_mobile import TorchMobileBackend
 
 __all__ = [
     "OnnxruntimeBackend",
     "TorchMobileBackend",
+    "TensorRTBackend",
 ]
diff --git a/hannah/backends/tensorrt.py b/hannah/backends/tensorrt.py
new file mode 100644
index 00000000..4b370bab
--- /dev/null
+++ b/hannah/backends/tensorrt.py
@@ -0,0 +1,222 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import time
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import torch
+
+try:
+    import tensorrt as trt
+    from cuda import cuda, cudart
+except ModuleNotFoundError:
+    trt = None
+    cuda = None
+    cudart = None
+
+from .base import InferenceBackendBase, ProfilingResult
+
+
+# Wrapper for cudaMemcpy which infers the copy size and does error checking
+def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
+    nbytes = host_arr.size * host_arr.itemsize
+    cuda_call(
+        cudart.cudaMemcpy(
+            device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
+        )
+    )
+
+
+# Wrapper for cudaMemcpy which infers the copy size and does error checking
+def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
+    nbytes = host_arr.size * host_arr.itemsize
+    cuda_call(
+        cudart.cudaMemcpy(
+            host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
+        )
+    )
+
+
+def check_cuda_err(err):
+    if isinstance(err, cuda.CUresult):
+        if err != cuda.CUresult.CUDA_SUCCESS:
+            raise RuntimeError("Cuda Error: {}".format(err))
+    elif isinstance(err, cudart.cudaError_t):
+        if err != cudart.cudaError_t.cudaSuccess:
+            raise RuntimeError("Cuda Runtime Error: {}".format(err))
+    else:
+        raise RuntimeError("Unknown error type: {}".format(err))
+
+
+def cuda_call(call):
+    err, res = call[0], call[1:]
+    check_cuda_err(err)
+    if len(res) == 1:
+        res = res[0]
+    return res
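+
+
+# Usage sketch for cuda_call (hypothetical values): the cuda-python bindings
+# return (error, result...) tuples, which cuda_call unwraps after error checking:
+#   num_devices = cuda_call(cudart.cudaGetDeviceCount())
+#   device_ptr = cuda_call(cudart.cudaMalloc(1024))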
+ """ + return self.outputs[0]["shape"], self.outputs[0]["dtype"] + + def prepare(self, module): + with TemporaryDirectory() as tmp_dir: + tmp_dir = Path(tmp_dir) + onnx_path = tmp_dir / "model.onnx" + + logging.info("transfering model to onnx") + dummy_input = module.example_input_array + dummy_input = dummy_input.to(module.device) + torch.onnx.export(module, dummy_input, onnx_path, verbose=False) + + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + + self.network = self.builder.create_network(network_flags) + self.parser = trt.OnnxParser(self.network, self.trt_logger) + onnx_path = os.path.realpath(onnx_path) + with open(onnx_path, "rb") as f: + if not self.parser.parse(f.read()): + logging.error("Failed to load ONNX file: {}".format(onnx_path)) + for error in range(self.parser.num_errors): + logging.error(self.parser.get_error(error)) + + self.engine = self.builder.build_engine(self.network, self.config) + self.context = self.engine.create_execution_context() + + # Setup I/O bindings + self.inputs = [] + self.outputs = [] + self.allocations = [] + for i in range(self.engine.num_bindings): + is_input = False + if self.engine.binding_is_input(i): + is_input = True + name = self.engine.get_binding_name(i) + dtype = self.engine.get_binding_dtype(i) + shape = self.engine.get_binding_shape(i) + if is_input: + self.batch_size = shape[0] + size = np.dtype(trt.nptype(dtype)).itemsize + for s in shape: + size *= s + + if size <= 0: + continue + + print("Allocation", name, "size: ", size) + allocation = cuda_call(cudart.cudaMalloc(size)) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(trt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + self.allocations.append(allocation) + if self.engine.binding_is_input(i): + self.inputs.append(binding) + else: + self.outputs.append(binding) + + assert self.batch_size > 0 + assert len(self.inputs) > 0 + assert len(self.outputs) > 0 + assert len(self.allocations) > 0 + + def run(self, *inputs): + output = np.zeros(*self.output_spec()) + + memcpy_host_to_device( + self.inputs[0]["allocation"], np.ascontiguousarray(inputs[0].cpu().numpy()) + ) + self.context.execute_v2(self.allocations) + memcpy_device_to_host(output, self.outputs[0]["allocation"]) + + result = torch.from_numpy(output) + + return result + + def profile(self, *inputs): + output = np.zeros(*self.output_spec()) + + memcpy_host_to_device( + self.inputs[0]["allocation"], np.ascontiguousarray(inputs[0].cpu().numpy()) + ) + + for _ in range(self.warmup): + self.context.execute_v2(self.allocations) + + start = time.perf_counter() + for _ in range(self.repeat): + self.context.execute_v2(self.allocations) + end = time.perf_counter() + + duration = (end - start) / self.repeat + + memcpy_device_to_host(output, self.outputs[0]["allocation"]) + + result = torch.from_numpy(output) + + return ProfilingResult( + outputs=result, metrics={"duration": duration}, profile=None + ) + + @classmethod + def available(cls): + if trt is not None and cuda is not None and cudart is not None: + return cuda.cuDeviceGetCount()[1] > 0 + + return False diff --git a/hannah/callbacks/summaries.py b/hannah/callbacks/summaries.py index e8737fe0..c183e731 100644 --- a/hannah/callbacks/summaries.py +++ b/hannah/callbacks/summaries.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2023 Hannah contributors. +# Copyright (c) 2024 Hannah contributors. # # This file is part of hannah. # See https://github.com/ekut-es/hannah for further info. 
+
+    def prepare(self, module):
+        with TemporaryDirectory() as tmp_dir:
+            tmp_dir = Path(tmp_dir)
+            onnx_path = tmp_dir / "model.onnx"
+
+            logging.info("transferring model to onnx")
+            dummy_input = module.example_input_array
+            dummy_input = dummy_input.to(module.device)
+            torch.onnx.export(module, dummy_input, onnx_path, verbose=False)
+
+            network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+
+            self.network = self.builder.create_network(network_flags)
+            self.parser = trt.OnnxParser(self.network, self.trt_logger)
+            onnx_path = os.path.realpath(onnx_path)
+            with open(onnx_path, "rb") as f:
+                if not self.parser.parse(f.read()):
+                    logging.error("Failed to load ONNX file: {}".format(onnx_path))
+                    for error in range(self.parser.num_errors):
+                        logging.error(self.parser.get_error(error))
+                    raise RuntimeError("Failed to parse ONNX file: {}".format(onnx_path))
+
+            self.engine = self.builder.build_engine(self.network, self.config)
+            self.context = self.engine.create_execution_context()
+
+        # Setup I/O bindings
+        self.inputs = []
+        self.outputs = []
+        self.allocations = []
+        for i in range(self.engine.num_bindings):
+            is_input = False
+            if self.engine.binding_is_input(i):
+                is_input = True
+            name = self.engine.get_binding_name(i)
+            dtype = self.engine.get_binding_dtype(i)
+            shape = self.engine.get_binding_shape(i)
+            if is_input:
+                self.batch_size = shape[0]
+            size = np.dtype(trt.nptype(dtype)).itemsize
+            for s in shape:
+                size *= s
+
+            if size <= 0:
+                continue
+
+            logging.debug("Allocating binding %s with size %d", name, size)
+            allocation = cuda_call(cudart.cudaMalloc(size))
+            binding = {
+                "index": i,
+                "name": name,
+                "dtype": np.dtype(trt.nptype(dtype)),
+                "shape": list(shape),
+                "allocation": allocation,
+            }
+            self.allocations.append(allocation)
+            if self.engine.binding_is_input(i):
+                self.inputs.append(binding)
+            else:
+                self.outputs.append(binding)
+
+        assert self.batch_size > 0
+        assert len(self.inputs) > 0
+        assert len(self.outputs) > 0
+        assert len(self.allocations) > 0
+
+    def run(self, *inputs):
+        output = np.zeros(*self.output_spec())
+
+        memcpy_host_to_device(
+            self.inputs[0]["allocation"], np.ascontiguousarray(inputs[0].cpu().numpy())
+        )
+        self.context.execute_v2(self.allocations)
+        memcpy_device_to_host(output, self.outputs[0]["allocation"])
+
+        result = torch.from_numpy(output)
+
+        return result
+
+    def profile(self, *inputs):
+        output = np.zeros(*self.output_spec())
+
+        memcpy_host_to_device(
+            self.inputs[0]["allocation"], np.ascontiguousarray(inputs[0].cpu().numpy())
+        )
+
+        for _ in range(self.warmup):
+            self.context.execute_v2(self.allocations)
+
+        start = time.perf_counter()
+        for _ in range(self.repeat):
+            self.context.execute_v2(self.allocations)
+        end = time.perf_counter()
+
+        duration = (end - start) / self.repeat
+
+        memcpy_device_to_host(output, self.outputs[0]["allocation"])
+
+        result = torch.from_numpy(output)
+
+        return ProfilingResult(
+            outputs=result, metrics={"duration": duration}, profile=None
+        )
+
+    @classmethod
+    def available(cls):
+        if trt is not None and cuda is not None and cudart is not None:
+            # The CUDA driver API needs to be initialized before querying devices
+            (err,) = cuda.cuInit(0)
+            if err != cuda.CUresult.CUDA_SUCCESS:
+                return False
+            return cuda.cuDeviceGetCount()[1] > 0
+
+        return False
diff --git a/hannah/callbacks/summaries.py b/hannah/callbacks/summaries.py
index e8737fe0..c183e731 100644
--- a/hannah/callbacks/summaries.py
+++ b/hannah/callbacks/summaries.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -21,6 +21,7 @@
 import traceback
 from collections import OrderedDict
 
+import numpy as np
 import pandas as pd
 import torch
 import torch.fx as fx
@@ -29,12 +30,17 @@
 from tabulate import tabulate
 from torch.fx.graph_module import GraphModule
 
-from hannah.nas.functional_operators.operators import add, conv2d, linear, conv1d, self_attention2d
+from hannah.nas.functional_operators.operators import (
+    add,
+    conv1d,
+    conv2d,
+    linear,
+    self_attention2d,
+)
 from hannah.nas.graph_conversion import GraphConversionTracer
 
 from ..models.factory import qat
 from ..models.sinc import SincNet
-import numpy as np
 
 msglogger = logging.getLogger(__name__)
diff --git a/hannah/conf/backend/tensorrt.yaml b/hannah/conf/backend/tensorrt.yaml
new file mode 100644
index 00000000..8b57b4ad
--- /dev/null
+++ b/hannah/conf/backend/tensorrt.yaml
@@ -0,0 +1,24 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+## http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+
+
+_target_: hannah.backends.TensorRTBackend
+val_batches: 10
+test_batches: 10
+val_frequency: 10
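+
+## Individual values can be overridden from the command line, e.g. (sketch):
+##   hannah-train backend=tensorrt backend.val_frequency=5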
diff --git a/hannah/conf/model/tc-res8-bs.yaml b/hannah/conf/model/tc-res8-bs.yaml
index 239084b5..5f7cf141 100644
--- a/hannah/conf/model/tc-res8-bs.yaml
+++ b/hannah/conf/model/tc-res8-bs.yaml
@@ -43,9 +43,3 @@ width_multiplier: 1.0
 dilation: 3
 clipping_value: 100000.0
 small: true
-
-
-# Set by LigthningModule
-width: 101
-height: 40
-n_labels: 12
diff --git a/hannah/models/timm.py b/hannah/models/timm.py
index c3f224ba..76d78bce 100644
--- a/hannah/models/timm.py
+++ b/hannah/models/timm.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -220,12 +220,11 @@ def __init__(
         super().__init__()
         self.name = name
 
-        dummy_input = torch.randn(input_shape)
+        dummy_input = torch.randn(tuple(input_shape))
 
         self.encoder = timm.create_model(
             name, num_classes=0, global_pool="", pretrained=pretrained, **kwargs
         )
-        _, input_channels, input_x, input_y = dummy_input.shape
 
         if stem == "auto":
             logger.info("""Using default logger for automatic stem creation""")
diff --git a/hannah/modules/vision/base.py b/hannah/modules/vision/base.py
index 2ca0e8f3..abadde14 100644
--- a/hannah/modules/vision/base.py
+++ b/hannah/modules/vision/base.py
@@ -93,23 +93,25 @@
         msglogger.info(" Dev Set: %d", len(self.dev_set))
         msglogger.info(" Test Set: %d", len(self.test_set))
 
-        example_data = self._decode_batch(self.test_set[0])["data"]
+        example_data = self._decode_batch(self.test_set[0])["data"].unsqueeze(0)
 
         if not isinstance(example_data, torch.Tensor):
             example_data = torch.tensor(example_data, device=self.device)
 
-        self.example_input_array = example_data.clone().detach().unsqueeze(0)
-        self.example_feature_array = example_data.clone().detach().unsqueeze(0)
+        self.example_input_array = example_data.clone().detach()
+        self.example_feature_array = example_data.clone().detach()
 
         self.num_classes = 0
         if self.train_set.class_names:
             self.num_classes = len(self.train_set.class_names)
 
+        input_shape = self.example_input_array.shape
+
         if hasattr(self.hparams, "model"):
             msglogger.info("Setting up model %s", self.hparams.model.name)
             self.model = instantiate(
                 self.hparams.model,
-                input_shape=self.example_input_array.shape,
+                input_shape=input_shape,
                 labels=self.num_classes,
                 _recursive_=False,
             )
@@ -288,9 +290,8 @@
         augmentations = {k: torch.nn.Sequential(*v) for k, v in augmentations.items()}
         self.augmentations = torch.nn.ModuleDict(augmentations)
-
+        return augmentations
 
-
     def _get_dataloader(self, dataset, unlabeled_data=None, shuffle=False):
 
         batch_size = self.hparams["batch_size"]
@@ -314,7 +315,7 @@ def calc_workers(dataset):
                 else dataset.max_workers
             )
             return result
-
+
         num_workers = calc_workers(dataset)
 
         loader = data.DataLoader(
@@ -324,9 +325,9 @@
             num_workers=num_workers,
             sampler=sampler if not dataset.sequential else None,
             collate_fn=vision_collate_fn,
-            multiprocessing_context="fork" if num_workers > 0 else None,
-            persistent_workers = True if num_workers > 0 else False,
-            prefetch_factor = 2 if num_workers > 0 else None,
+            multiprocessing_context="fork" if num_workers > 0 else None,
+            persistent_workers=True if num_workers > 0 else False,
+            prefetch_factor=2 if num_workers > 0 else None,
             pin_memory=True,
         )
         self.batches_per_epoch = len(loader)
@@ -341,9 +342,7 @@
                 sampler=data.RandomSampler(unlabeled_data)
                 if not unlabeled_data.sequential
                 else None,
-                multiprocessing_context="fork"
-                if unlabeled_workers > 0
-                else None,
+                multiprocessing_context="fork" if unlabeled_workers > 0 else None,
             )
 
         return CombinedLoader(
diff --git a/poetry.lock b/poetry.lock
index e234cf17..fe67547b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -930,6 +930,27 @@ ssh = ["bcrypt (>=3.1.5)"]
 test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
 test-randomorder = ["pytest-randomly"]
 
+[[package]]
+name = "cuda-python"
+version = "12.4.0"
+description = "Python bindings for CUDA"
+optional = true
+python-versions = "*"
+files = [
hash = "sha256:3a9ddcbfe14a9b38eedf7e98323f680acc8c61ed6529638da2c7d4a5a8787e12"}, + {file = "cuda_python-12.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9e1a2cf38d71be0c85f1063967409b48badc06fb4294dd72fdf05de852f42fb"}, + {file = "cuda_python-12.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d505667cae5f793435dbe5fcaab4bff4a2a3029024148ce16e4003bae43cf45"}, + {file = "cuda_python-12.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e845338ad9634be2778b787f12f3a05b0116d49610a9a517f017fd91e79b1b97"}, + {file = "cuda_python-12.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dde31841f204f3de78e5e121252f7c43f1cfcf207ff3a0cf1d1017acb08dabf"}, + {file = "cuda_python-12.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:43d92a8f7531c5fa8d9d68e24b36dd352d1653353791583c4a3b10a629176fd4"}, + {file = "cuda_python-12.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84cd4f3a174fc116a6c1496efb3d60a7d4dce791194456461582908653ac67b2"}, + {file = "cuda_python-12.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcf44462196af79e34d2516f4416c5e46691a4d6b9faa13e75664115bdc40d4c"}, + {file = "cuda_python-12.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:07db8b89ac34ba4520b070f8de971f0266229fabe2443608ca91a8b93d572c09"}, + {file = "cuda_python-12.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2202ed9793c5948282e4e575e8227c127adb29e21fd31cb6d80098f93a2c14e"}, + {file = "cuda_python-12.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7146ed28810c28704649204f83ca60a9609263b3e8f08f77a33e154501cd84dd"}, + {file = "cuda_python-12.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:de0a2cb4ce6ea8e302010a580576ab78a2e95d015963e0d6933bb2508f65bd6a"}, +] + [[package]] name = "cycler" version = "0.12.1" @@ -5143,6 +5164,19 @@ files = [ {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, ] +[[package]] +name = "tensorrt" +version = "8.6.1.post1" +description = "A high performance deep learning inference library" +optional = true +python-versions = ">=3.6" +files = [ + {file = "tensorrt-8.6.1.post1.tar.gz", hash = "sha256:0ca64da500480a2d204c18d7c6791ec462c163ae4fa1db574b8c211da1116ea2"}, +] + +[package.extras] +numpy = ["numpy"] + [[package]] name = "testslide" version = "2.7.1" @@ -6021,10 +6055,11 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [extras] onnxrt = ["onnx", "onnxruntime"] +tensorrt = ["cuda-python", "onnx", "tensorrt"] tvm = ["hannah-tvm"] vision = ["albumentations", "gdown", "imagecorruptions", "kornia", "pycocotools", "timm"] [metadata] lock-version = "2.0" python-versions = ">=3.9 <3.12" -content-hash = "5d0e2e4512f788caf6632ee89a1236229ca05222d905548ee76bc062b444ed0f" +content-hash = "3e8977d3fab1c67a043508ad06bae926079618b56f99f55eaefbe7f1d3da2e91" diff --git a/pydoc-markdown.yml b/pydoc-markdown.yml index 46e7ce87..c0bb4de7 100644 --- a/pydoc-markdown.yml +++ b/pydoc-markdown.yml @@ -1,5 +1,5 @@ ## -## Copyright (c) 2023 Hannah contributors. +## Copyright (c) 2024 Hannah contributors. ## ## This file is part of hannah. ## See https://github.com/ekut-es/hannah for further info. 
@@ -107,6 +107,9 @@
       - title: TVM
         name: deployment/tvm
         source: doc/deployment/tvm.md
+      - title: TensorRT
+        name: deployment/tensorrt
+        source: doc/deployment/tensorrt.md
   - title: "Evaluation"
     children:
       - title: "Evaluation"
diff --git a/pyproject.toml b/pyproject.toml
index c48a941a..e5e7fd5b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,9 @@ lightning = "^2.1.2"
 dgl = "1.1.3"
 pytorch-lightning = "^2.1.3"
 onnx = "^1.16.0"
+cuda-python = {version = "^12.1.0", optional = true}
+tensorrt = {version = "^8.6.1.post1", optional = true}
+
 
 
 [tool.poetry.dev-dependencies]
@@ -88,6 +91,7 @@ flaky = ">=3.7.0"
 tvm = ["hannah-tvm"]
 onnxrt = ["onnxruntime", "onnx"]
 vision = ["pycocotools", "albumentations", "imagecorruptions", "timm", "gdown", "kornia"]
+tensorrt = ["cuda-python", "onnx", "tensorrt"]
 
 [tool.poetry.scripts]
 hannah-train = 'hannah.tools.train:main'
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 00000000..b934de88
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,94 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Test the prepare and run methods for each of the hannah backends"""
+
+import inspect
+
+import pytest
+import torch
+import torch.nn as nn
+
+import hannah.backends
+from hannah.modules.base import ClassifierModule
+
+
+def backends():
+    """Iterates over the backends"""
+    for item in inspect.getmembers(hannah.backends):
+        if inspect.isclass(item[1]) and item[0] != "InferenceBackendBase":
+            if item[1].available():
+                yield item[1]
+            else:
+                print(f"Skipping {item[0]} because it is not available")
+
+
+class SimpleModule(ClassifierModule):
+    """Simple test module for the backends"""
+
+    def __init__(self):
+        super().__init__(None, nn.Linear(1, 1), None, None)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def setup(self, stage):
+        self.example_input_array = torch.tensor([1.0]).unsqueeze(0)
+
+    def prepare_data(self):
+        pass
+
+    def get_class_names(self):
+        return ["test"]
+
+
+@pytest.mark.parametrize("backend", backends())
+def test_backend(backend):
+    test_module = SimpleModule()
+    test_module.prepare_data()
+    test_module.setup("fit")
+
+    x = torch.tensor([1.0]).unsqueeze(0)
+
+    backend = backend()
+    backend.prepare(test_module)
+    results = backend.run(x)
+
+    ref = test_module(x)
+
+    assert torch.allclose(results[0], ref)
+
+
+@pytest.mark.parametrize("backend", backends())
+def test_profile(backend):
+    test_module = SimpleModule()
+    test_module.prepare_data()
+    test_module.setup("fit")
+
+    x = torch.tensor([1.0]).unsqueeze(0)
+
+    backend = backend()
+    backend.prepare(test_module)
+    result = backend.profile(x)
+
+    ref = test_module(x)
+
+    assert torch.allclose(result.outputs[0], ref)
+
+    assert result.metrics["duration"] is not None
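+
+# These tests can be run standalone, e.g. (assuming a development install
+# with pytest available):
+#   pytest test/test_backend.py -v
+# Backends whose runtime dependencies are missing are filtered out by
+# backends() above and therefore generate no test cases.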