dynamic batches handling #10

Open
haiderasad opened this issue Jul 28, 2024 · 4 comments


haiderasad commented Jul 28, 2024

Hey, nice work.
I think the dynamic batching flow is a bit broken. It works fine if the engine is built with a single fixed shape, but when it is built with

min_shape = (1, 360, 640)   # Minimum shape with batch size 1
opt_shape = (6, 360, 640)   # Optimal shape with batch size 6
max_shape = (10, 360, 640)  # Maximum shape with batch size 10

and I give it an input of shape (6, 360, 640), it fails with
ValueError: could not broadcast input array from the shape (1382400,) into shape (230400,)

Upon investigating, I see that the shape and size of the inputs are set to (10, 360, 640), so it is expecting (10, 360, 640) and I don't know why. Are you aware of the best practice in TensorRT for handling dynamic inputs?

Below is my whole code.

Engine building:


import tensorrt as trt

# Set up TensorRT logger, builder, and network
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# Use the ONNX parser to load your ONNX model
parser = trt.OnnxParser(network, TRT_LOGGER)

# Path to the ONNX file generated
onnx_file_path = 'det_model.onnx'

# Parse the ONNX file
with open(onnx_file_path, 'rb') as model:
    if not parser.parse(model.read()):
        print('ERROR: Failed to parse the ONNX file.')
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        exit()

# Configure the builder and create an optimization profile
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB, adjust this as necessary

# Assuming your model's input shape can vary in batch size, you need to set the optimization profile accordingly
min_shape = (1, 360, 640)   # Minimum shape with batch size 1
opt_shape = (6, 360, 640)   # Optimal shape with batch size 6
max_shape = (10, 360, 640)  # Maximum shape with batch size 10

profile = builder.create_optimization_profile()
profile.set_shape(network.get_input(0).name, min=min_shape, opt=opt_shape, max=max_shape)
config.add_optimization_profile(profile)

# Build the TensorRT engine
engine = builder.build_serialized_network(network, config)

# Save the engine to a file
engine_file_path = 'det_model_dynamic.trt'
with open(engine_file_path, 'wb') as f:
    f.write(engine)

print("TensorRT model is successfully created and saved to", engine_file_path)

Main.py:


import ctypes
import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
import cv2 as cv
try:
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError

EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    # elif so that a successful driver-API result does not fall through to the unknown-error branch
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))

def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res

def GiB(val):
    return val * 1 << 30

class HostDeviceMem:
    def __init__(self, size: int, dtype: np.dtype, name=None, shape=None, format=None):
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes
        self._name = name
        self._shape = shape
        self._format = format
        self._dtype = dtype

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, arr: np.ndarray):
        if arr.size > self.host.size:
            raise ValueError(f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}")
        np.copyto(self.host[:arr.size], arr.flat, casting='safe')

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    @property
    def name(self):
        return self._name

    @property
    def shape(self):
        return self._shape

    @property
    def format(self):
        return self._format

    @property
    def dtype(self) -> np.dtype:
        return self._dtype

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))

def allocate_buffers(engine: trt.ICudaEngine, profile_idx=None):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda_call(cudart.cudaStreamCreate())
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for binding in tensor_names:
        format = engine.get_tensor_format(binding)
        
        
        shape = engine.get_tensor_shape(binding) if profile_idx is None else engine.get_tensor_profile_shape(binding, profile_idx)[0]
        shape_valid = np.all([s >= 0 for s in shape])
        if not shape_valid and profile_idx is None:
            raise ValueError(f"Binding {binding} has dynamic shape, but no profile was specified.")
        size = trt.volume(shape)
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding)))

        print(shape)
        binding_memory = HostDeviceMem(size, dtype, name=binding, shape=shape, format=format)

        bindings.append(int(binding_memory.device))

        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(binding_memory)
        else:
            outputs.append(binding_memory)
    return inputs, outputs, bindings, stream


def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))

def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))

def _do_inference_base(inputs, outputs, stream, execute_async_func):
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
    execute_async_func()
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
    cuda_call(cudart.cudaStreamSynchronize(stream))
    return [out.host for out in outputs]

def do_inference(context, engine, bindings, inputs, outputs, stream):
    def execute_async_func():
        context.execute_async_v3(stream_handle=stream)

    num_io = engine.num_io_tensors
    context.set_input_shape('input', (6, 360, 640))
    for i in range(num_io):
        context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
        # if engine.get_tensor_name(i)=='input':
        #     context.set_input_shape('input', (6, 360, 640))
        
    #print(context.all_binding_shapes_specified)
    return _do_inference_base(inputs, outputs, stream, execute_async_func)

def load_engine(engine_file_path):
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

def preprocess_images(images, width=1280 // 2, height=720 // 2):
    shapes = [img.shape for img in images]
    images = [cv.resize(img, (width, height)) for img in images]
    images = np.stack(images)
    images = images / 128.0 - 1
    return images

engine_file_path = 'det_model_dynamic.trt'
engine = load_engine(engine_file_path)
inputs, outputs, bindings, stream = allocate_buffers(engine=engine,profile_idx=0)
images = [np.random.rand(360, 640, 6).astype(np.float32) for _ in range(1)]  # Adjust the batch size as needed
preprocessed_images = preprocess_images(images)

#print(inputs[0].shape)
for host_device_buffer in inputs:
    np.copyto(host_device_buffer.host, preprocessed_images.flatten())
    
context = engine.create_execution_context()
masks = do_inference(context=context, engine=engine, inputs=inputs, outputs=outputs, bindings=bindings, stream=stream)
#print(len(masks))
for mask in masks:
    print(mask.shape)
leimao (Owner) commented Jul 28, 2024

You will need to make the batch size -1 in your ONNX model to enable dynamic batching for your TensorRT engine.
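
For context, if the ONNX file comes from PyTorch, the dynamic batch dimension is usually declared at export time via dynamic_axes. A minimal sketch, assuming the input tensor is named "input" as in the code above (the stand-in model, output name, and file path are placeholders):

import torch

# Stand-in module; in practice this is the real detection model.
model = torch.nn.Identity().eval()

# Dummy input with batch size 1, used only to trace the export.
dummy_input = torch.randn(1, 360, 640)

torch.onnx.export(
    model,
    dummy_input,
    "det_model.onnx",
    input_names=["input"],
    output_names=["output"],
    # Marking dim 0 as dynamic makes it appear as -1 in the ONNX graph,
    # which is what allows the TensorRT optimization profile to vary the batch size.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)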

haiderasad (Author) commented:

[screenshot attached]
like this?

leimao (Owner) commented Jul 28, 2024

Yes. But there are many other caveats to using dynamic batching. This example emphasizes using the TensorRT custom plugin interface. Dynamic batching is considered a somewhat more "advanced" feature and is not covered in this example.

If you are interested in using dynamic batching with your TensorRT engine and custom plugins, please refer to the TensorRT developer guide for guidance.

haiderasad (Author) commented:

Thanks for the info. I was able to get the batching to work; it is just the host_device_buffer in inputs that needs to be sized dynamically based on the runtime input batch size.
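
For anyone hitting the same broadcast error, a minimal sketch of that idea, assuming the buffers were allocated for the profile's maximum shape (e.g. by taking get_tensor_profile_shape(name, 0)[2] instead of [0] in allocate_buffers) and the input tensor is named "input"; the stage_batch helper below is only for illustration:

import numpy as np

def stage_batch(context, inputs, batch):
    # 'batch' is a float32 array of shape (N, 360, 640), with N no larger than
    # the maximum batch size of the optimization profile.
    n = batch.shape[0]

    # Tell TensorRT the actual input shape for this execution.
    context.set_input_shape("input", (n, 360, 640))

    # Copy only the n * 360 * 640 elements that are actually used; the rest of
    # the max-sized host buffer is left untouched and ignored by the engine.
    flat = np.ascontiguousarray(batch, dtype=np.float32).ravel()
    inputs[0].host[: flat.size] = flat

After inference, only the leading portion of each output buffer that corresponds to the runtime batch size holds valid results.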
