Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unique ptr #865

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ODLA/include/ODLA/odla.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <ODLA/odla_common.h>
#include <ODLA/odla_compute.h>
#include <ODLA/odla_device.h>
#include <ODLA/odla_memory.h>
#include <ODLA/odla_ops.h>
#include <ODLA/odla_profiler.h>
#include <ODLA/odla_task.h>
Expand Down
93 changes: 93 additions & 0 deletions ODLA/include/ODLA/odla_memory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//===- odla_memory.h ------------------------------------------------------===//
//
// Copyright (C) 2019-2021 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef _ODLA_MEMORY_H_
#define _ODLA_MEMORY_H_

#include <ODLA/odla_common.h>

/*! \file
* \details This file defines the ODLA memory related APIs.
*/

#ifdef __cplusplus
extern "C" {
#endif

//! \brief memory copy types
typedef enum {
ODLA_MEMCPY_H2H,
ODLA_MEMCPY_H2D,
ODLA_MEMCPY_D2H,
ODLA_MEMCPY_D2D,
} odla_memcpy_type;

//! \brief Allocate device memory
/*!
\param devPtr the pointer to allocated device memory
\param size the requested allocation size in bytes

\return odla_status
*/
extern ODLA_API_EXPORT odla_status ODLA_API_CALL
odla_AllocateDeviceMemory(odla_void** dev_ptr, odla_size_t size);

//! \brief Free device memory
/*!
\param devPtr the device pointer to memory to free

\return odla_status
*/
extern ODLA_API_EXPORT odla_status ODLA_API_CALL
odla_FreeDeviceMemory(odla_void* ptr);

//! \brief Allocate host memory
/*!
\param ptr the pointer to allocated host memory
\param size the requested allocation size in bytes

\return odla_status
*/
extern ODLA_API_EXPORT odla_status ODLA_API_CALL
odla_AllocateHostMemory(odla_void** host_ptr, odla_size_t size);

//! \brief Free host memory
/*!
\param ptr the host pointer to memory to free

\return odla_status
*/
extern ODLA_API_EXPORT odla_status ODLA_API_CALL
odla_FreeHostMemory(odla_void* ptr);

//! \brief Copy data between host and device.
/*!
\param dst the destination memory address
\param src the source memory address
\param size the size in bytes to copy
\param type the memory copy type

\return odla_status
*/
extern ODLA_API_EXPORT odla_status ODLA_API_CALL odla_CopyMemory(
odla_void* dst, odla_void* src, odla_size_t size, odla_memcpy_type type);

#ifdef __cplusplus
} // C extern
#endif

#endif // _ODLA_MEMORY_H_
65 changes: 59 additions & 6 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ static auto Dummy = cudaFree(0);
template <typename T>
struct TrtDestroyer {
void operator()(T* t) {
if (t) {
t->destroy();
}
#if NV_TENSORRT_MAJOR < 8
t->destroy();
#else
delete (t);
#endif
}
};

Expand Down Expand Up @@ -175,8 +177,8 @@ typedef struct {
} branch_info;

struct _odla_computation {
std::shared_ptr<nvinfer1::IBuilder> builder = nullptr;
std::shared_ptr<nvinfer1::INetworkDefinition> network = nullptr;
TrtUniquePtr<nvinfer1::IBuilder> builder;
TrtUniquePtr<nvinfer1::INetworkDefinition> network;
std::unordered_map<std::string, odla_value> inputs;
std::unordered_map<std::string, odla_value> outputs;
std::vector<std::vector<float>> buffers;
Expand Down Expand Up @@ -254,7 +256,7 @@ struct _odla_context {
odla_computation comp = nullptr;
TrtUniquePtr<nvinfer1::ICudaEngine> engine{nullptr};
TrtUniquePtr<nvinfer1::IExecutionContext> ctx{nullptr};
std::shared_ptr<nvinfer1::IBuilderConfig> builder_cfg = nullptr;
TrtUniquePtr<nvinfer1::IBuilderConfig> builder_cfg{nullptr};
nvinfer1::IOptimizationProfile* builder_profile = nullptr;
std::vector<void*> bindings;

Expand Down Expand Up @@ -2766,4 +2768,55 @@ odla_values odla_EndIf(odla_value_ids value_ids) {
return ret;
}

odla_status odla_AllocateDeviceMemory(odla_void** devPtr, odla_size_t size) {
cudaError_t code = cudaMalloc(devPtr, size);
CHECK(code);
return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR;
}

odla_status odla_FreeDeviceMemory(odla_void* devPtr) {
cudaError_t code = cudaFree(devPtr);
CHECK(code);
return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR;
}

odla_status odla_AllocateHostMemory(odla_void** ptr, odla_size_t size) {
cudaError_t code = cudaMallocHost(ptr, size);
CHECK(code);
return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR;
}

odla_status odla_FreeHostMemory(odla_void* ptr) {
cudaError_t code = cudaFreeHost(ptr);
CHECK(code);
return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR;
}

odla_status odla_CopyMemory(odla_void* dst, odla_void* src, odla_size_t size,
odla_memcpy_type type) {
cudaMemcpyKind kind = cudaMemcpyDefault;
switch (type) {
case ODLA_MEMCPY_H2H:
kind = cudaMemcpyHostToHost;
break;
case ODLA_MEMCPY_H2D:
kind = cudaMemcpyHostToDevice;
break;
case ODLA_MEMCPY_D2H:
kind = cudaMemcpyDeviceToHost;
break;
case ODLA_MEMCPY_D2D:
kind = cudaMemcpyDeviceToDevice;
break;

default:
kind = cudaMemcpyDefault;
break;
}

cudaError_t code = cudaMemcpy(dst, src, size, kind);
CHECK(code);
return (code == cudaSuccess) ? ODLA_SUCCESS : ODLA_MEM_ERROR;
}

} // C extern