From f6b71362f195dd37d3eac3e0533fe67d4c4c1f72 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Tue, 22 Mar 2022 16:31:13 -0700 Subject: [PATCH 1/2] [ODLA/TRT] Add memory operation support --- ODLA/include/ODLA/odla.h | 1 + ODLA/include/ODLA/odla_memory.h | 93 ++++++++++++++++++++++++ ODLA/platforms/tensorrt/odla_tensorrt.cc | 51 +++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 ODLA/include/ODLA/odla_memory.h diff --git a/ODLA/include/ODLA/odla.h b/ODLA/include/ODLA/odla.h index c13f4f23f..f01c46fb9 100644 --- a/ODLA/include/ODLA/odla.h +++ b/ODLA/include/ODLA/odla.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include diff --git a/ODLA/include/ODLA/odla_memory.h b/ODLA/include/ODLA/odla_memory.h new file mode 100644 index 000000000..ee17c92a6 --- /dev/null +++ b/ODLA/include/ODLA/odla_memory.h @@ -0,0 +1,93 @@ +//===- odla_memory.h ------------------------------------------------------===// +// +// Copyright (C) 2019-2021 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef _ODLA_MEMORY_H_ +#define _ODLA_MEMORY_H_ + +#include + +/*! \file + * \details This file defines the ODLA memory related APIs. + */ + +#ifdef __cplusplus +extern "C" { +#endif + +//! 
\brief memory copy types +typedef enum { + ODLA_MEMCPY_H2H, + ODLA_MEMCPY_H2D, + ODLA_MEMCPY_D2H, + ODLA_MEMCPY_D2D, +} odla_memcpy_type; + +//! \brief Allocate device memory +/*! + \param dev_ptr the pointer to allocated device memory + \param size the requested allocation size in bytes + + \return odla_status +*/ +extern ODLA_API_EXPORT odla_status ODLA_API_CALL +odla_AllocateDeviceMemory(odla_void** dev_ptr, odla_size_t size); + +//! \brief Free device memory +/*! + \param ptr the device pointer to memory to free + + \return odla_status +*/ +extern ODLA_API_EXPORT odla_status ODLA_API_CALL +odla_FreeDeviceMemory(odla_void* ptr); + +//! \brief Allocate host memory +/*! + \param host_ptr the pointer to allocated host memory + \param size the requested allocation size in bytes + + \return odla_status +*/ +extern ODLA_API_EXPORT odla_status ODLA_API_CALL +odla_AllocateHostMemory(odla_void** host_ptr, odla_size_t size); + +//! \brief Free host memory +/*! + \param ptr the host pointer to memory to free + + \return odla_status +*/ +extern ODLA_API_EXPORT odla_status ODLA_API_CALL +odla_FreeHostMemory(odla_void* ptr); + +//! \brief Copy data between host and/or device memory. +/*! 
+ \param dst the destination memory address + \param src the source memory address + \param size the size in bytes to copy + \param type the memory copy type + + \return odla_status +*/ +extern ODLA_API_EXPORT odla_status ODLA_API_CALL odla_CopyMemory( + odla_void* dst, odla_void* src, odla_size_t size, odla_memcpy_type type); + +#ifdef __cplusplus +} // C extern +#endif + +#endif // _ODLA_MEMORY_H_ diff --git a/ODLA/platforms/tensorrt/odla_tensorrt.cc b/ODLA/platforms/tensorrt/odla_tensorrt.cc index a14e2f113..ef0a1721e 100644 --- a/ODLA/platforms/tensorrt/odla_tensorrt.cc +++ b/ODLA/platforms/tensorrt/odla_tensorrt.cc @@ -2766,4 +2766,55 @@ odla_values odla_EndIf(odla_value_ids value_ids) { return ret; } +odla_status odla_AllocateDeviceMemory(odla_void** devPtr, odla_size_t size) { + cudaError_t code = cudaMalloc(devPtr, size); + CHECK(code); /* NOTE(review): CHECK is defined outside this patch; if it aborts or throws on failure, the ODLA_MEM_ERROR branch below is unreachable (same pattern in the other memory functions) - confirm */ + return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR; +} + +odla_status odla_FreeDeviceMemory(odla_void* devPtr) { + cudaError_t code = cudaFree(devPtr); + CHECK(code); + return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR; +} + +odla_status odla_AllocateHostMemory(odla_void** ptr, odla_size_t size) { + cudaError_t code = cudaMallocHost(ptr, size); /* page-locked (pinned) host allocation */ + CHECK(code); + return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR; +} + +odla_status odla_FreeHostMemory(odla_void* ptr) { + cudaError_t code = cudaFreeHost(ptr); + CHECK(code); + return (code == cudaSuccess) ? 
ODLA_SUCCESS : ODLA_MEM_ERROR; +} + +odla_status odla_CopyMemory(odla_void* dst, odla_void* src, odla_size_t size, + odla_memcpy_type type) { + cudaMemcpyKind kind = cudaMemcpyDefault; + switch (type) { + case ODLA_MEMCPY_H2H: + kind = cudaMemcpyHostToHost; + break; + case ODLA_MEMCPY_H2D: + kind = cudaMemcpyHostToDevice; + break; + case ODLA_MEMCPY_D2H: + kind = cudaMemcpyDeviceToHost; + break; + case ODLA_MEMCPY_D2D: + kind = cudaMemcpyDeviceToDevice; + break; + + default: /* any other value: keep cudaMemcpyDefault, same as the initializer above */ + kind = cudaMemcpyDefault; + break; + } + + cudaError_t code = cudaMemcpy(dst, src, size, kind); + CHECK(code); + return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR; +} + } // C extern From c3c63969dfcb1fbc31596622a96e69cd40f6fcb2 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Tue, 22 Mar 2022 17:14:12 -0700 Subject: [PATCH 2/2] [ODLA/TRT] Replace shared_ptr with unique_ptr --- ODLA/platforms/tensorrt/odla_tensorrt.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ODLA/platforms/tensorrt/odla_tensorrt.cc b/ODLA/platforms/tensorrt/odla_tensorrt.cc index ef0a1721e..68e7ca605 100644 --- a/ODLA/platforms/tensorrt/odla_tensorrt.cc +++ b/ODLA/platforms/tensorrt/odla_tensorrt.cc @@ -54,9 +54,11 @@ static auto Dummy = cudaFree(0); template struct TrtDestroyer { void operator()(T* t) { - if (t) { - t->destroy(); - } +#if NV_TENSORRT_MAJOR < 8 + t->destroy(); /* NOTE(review): the null guard is removed here; presumably safe because std::unique_ptr does not invoke its deleter on a null pointer - confirm TrtDestroyer is only used via TrtUniquePtr */ +#else + delete (t); +#endif } }; @@ -175,8 +177,8 @@ typedef struct { } branch_info; struct _odla_computation { - std::shared_ptr builder = nullptr; - std::shared_ptr network = nullptr; + TrtUniquePtr builder; + TrtUniquePtr network; std::unordered_map inputs; std::unordered_map outputs; std::vector> buffers; @@ -254,7 +256,7 @@ struct _odla_context { odla_computation comp = nullptr; TrtUniquePtr engine{nullptr}; TrtUniquePtr ctx{nullptr}; - std::shared_ptr builder_cfg = nullptr; + TrtUniquePtr builder_cfg{nullptr}; nvinfer1::IOptimizationProfile* builder_profile = nullptr; std::vector bindings;