Skip to content

Commit

Permalink
Fix ctypes, don't close fd, use smart pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
fpetrini15 committed Apr 6, 2024
1 parent 0a124ef commit 479dccc
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 29 deletions.
13 changes: 12 additions & 1 deletion src/python/library/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ else()

if (${TRITON_ENABLE_PERF_ANALYZER})
set(perf_analyzer_arg --perf-analyzer ${CMAKE_INSTALL_PREFIX}/bin/perf_analyzer)
endif()
endif() # TRITON_ENABLE_PERF_ANALYZER
if (${TRITON_ENABLE_GPU})
set(gpu_arg --include-gpu-libs)
endif() # TRITON_ENABLE_GPU
set(linux_wheel_stamp_file "linux_stamp.whl")
add_custom_command(
OUTPUT "${linux_wheel_stamp_file}"
Expand All @@ -138,6 +141,7 @@ else()
--dest-dir "${CMAKE_CURRENT_BINARY_DIR}/linux"
--linux
${perf_analyzer_arg}
${gpu_arg}
DEPENDS ${LINUX_WHEEL_DEPENDS}
)

Expand Down Expand Up @@ -178,7 +182,14 @@ if(${TRITON_ENABLE_PYTHON_GRPC})
)
endif() # TRITON_ENABLE_PYTHON_GRPC

# Generic Wheel
set(WHEEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/generic")
install(
CODE "file(GLOB _Wheel \"${WHEEL_DIR}/triton*.whl\")"
CODE "file(INSTALL \${_Wheel} DESTINATION \"${CMAKE_INSTALL_PREFIX}/python\")"
)

# Platform-specific wheels
if(WIN32)
set(WHEEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/windows")
else()
Expand Down
24 changes: 16 additions & 8 deletions src/python/library/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def sed(pattern, replace, source, dest=None):
required=False,
help="Include windows specific artifacts.",
)
parser.add_argument(
"--include-gpu-libs",
action="store_true",
required=False,
help="Include gpu specific libraries",
)
parser.add_argument(
"--perf-analyzer",
type=str,
Expand Down Expand Up @@ -186,10 +192,11 @@ def sed(pattern, replace, source, dest=None):
"tritonclient/utils/libcshm.so",
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/libcshm.so"),
)
cpdir(
"tritonclient/utils/cuda_shared_memory",
os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
)
if FLAGS.include_gpu_libs:
cpdir(
"tritonclient/utils/cuda_shared_memory",
os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
)

# Copy the pre-compiled perf_analyzer binary
if FLAGS.perf_analyzer is not None:
Expand All @@ -212,10 +219,11 @@ def sed(pattern, replace, source, dest=None):
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/cshm.dll"),
)
# FIXME: Enable when Windows supports GPU tensors DLIS-4169
# cpdir(
# "tritonclient/utils/cuda_shared_memory",
# os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
# )
# if FLAGS.include_gpu_libs:
# cpdir(
# "tritonclient/utils/cuda_shared_memory",
# os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
# )

shutil.copyfile("LICENSE.txt", os.path.join(FLAGS.whl_dir, "LICENSE.txt"))
shutil.copyfile("setup.py", os.path.join(FLAGS.whl_dir, "setup.py"))
Expand Down
11 changes: 8 additions & 3 deletions src/python/library/tritonclient/utils/shared_memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,10 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
The numpy array generated using the contents of the specified shared
memory region.
"""
shm_file = c_void_p()
# Safe initializer for Unix case where shm_file must be dereferenced to
# base in order to store file descriptor.
safe_initializer = c_int(-1)
shm_file = cast(byref(safe_initializer), c_void_p)
region_offset = c_uint64()
byte_size = c_uint64()
shm_addr = c_char_p()
Expand Down Expand Up @@ -284,8 +287,10 @@ def destroy_shared_memory_region(shm_handle):
SharedMemoryException
If unable to unlink the shared memory region.
"""

shm_file = c_void_p()
# Safe initializer for Unix case where shm_file must be dereferenced to
# base in order to store file descriptor.
safe_initializer = c_int(-1)
shm_file = cast(byref(safe_initializer), c_void_p)
offset = c_uint64()
byte_size = c_uint64()
shm_addr = c_char_p()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ SharedMemoryHandleCreate(
handle->triton_shm_name_ = triton_shm_name;
handle->base_addr_ = shm_addr;
handle->shm_key_ = shm_key;
handle->platform_handle_ = new ShmFile(shm_file);
handle->platform_handle_ = std::make_unique<ShmFile>(shm_file);
handle->offset_ = offset;
handle->byte_size_ = byte_size;
return static_cast<void*>(handle);
Expand Down Expand Up @@ -97,8 +97,7 @@ SharedMemoryRegionMap(
return -1;
}

// close shared memory descriptor, return 0 if success else return -1
return close(fd);
return 0;
#endif
}

Expand All @@ -119,29 +118,29 @@ SharedMemoryRegionCreate(
DWORD high_order_size = (upperbound_size >> 32) & 0xFFFFFFFF;
DWORD low_order_size = upperbound_size & 0xFFFFFFFF;

HANDLE local_handle = CreateFileMapping(
HANDLE shm_file = CreateFileMapping(
INVALID_HANDLE_VALUE, // use paging file
NULL, // default security
PAGE_READWRITE, // read/write access
high_order_size, // maximum object size (high-order DWORD)
low_order_size, // maximum object size (low-order DWORD)
shm_key); // name of mapping object

if (local_handle == NULL) {
if (shm_file == NULL) {
return -7;
}

// get base address of shared memory region
void* shm_addr = nullptr;
int err = SharedMemoryRegionMap((void*)local_handle, 0, byte_size, &shm_addr);
int err = SharedMemoryRegionMap((void*)shm_file, 0, byte_size, &shm_addr);
if (err == -1) {
return -4;
}

// create a handle for the shared memory region
*shm_handle = SharedMemoryHandleCreate(
std::string(triton_shm_name), shm_addr, std::string(shm_key),
(void*)local_handle, 0, byte_size);
(void*)shm_file, 0, byte_size);
#else
// get shared memory region descriptor
int shm_fd = shm_open(shm_key, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
Expand Down Expand Up @@ -188,12 +187,12 @@ GetSharedMemoryHandleInfo(
#ifdef _WIN32
HANDLE* file = static_cast<HANDLE*>(shm_file);
#else
int* file = *static_cast<int**>(shm_file);
int* file = *reinterpret_cast<int**>(shm_file);
#endif // _WIN32
SharedMemoryHandle* handle = static_cast<SharedMemoryHandle*>(shm_handle);
*shm_addr = static_cast<char*>(handle->base_addr_);
*shm_key = handle->shm_key_.c_str();
*file = handle->platform_handle_->shm_file_;
*file = *(handle->platform_handle_->GetShmFile());
*offset = handle->offset_;
*byte_size = handle->byte_size_;
return 0;
Expand All @@ -213,7 +212,7 @@ SharedMemoryRegionDestroy(void* shm_handle)
// We keep Windows shared memory handles open until we are done
// using them. When all handles are closed, the system will free
// the section of the paging file that the object uses.
CloseHandle(handle->platform_handle_->shm_file_);
CloseHandle(*(handle->platform_handle_->GetShmFile()));
#else
int status = munmap(shm_addr, handle->byte_size_);
if (status == -1) {
Expand All @@ -224,11 +223,11 @@ SharedMemoryRegionDestroy(void* shm_handle)
if (shm_fd == -1) {
return -5;
}
close(*(handle->platform_handle_->GetShmFile()));
#endif // _WIN32

// FIXME: Investigate use of smart pointers for these
// allocations instead
delete handle->platform_handle_;
// FIXME: Investigate use of smart pointers for this
// allocation instead
delete handle;

return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ extern "C" {
#ifdef _WIN32
#define TRITONCLIENT_DECLSPEC __declspec(dllexport)
#else
define TRITONCLIENT_DECLSPEC
#define TRITONCLIENT_DECLSPEC
#endif

//==============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,28 @@
#ifdef _WIN32
#include <windows.h>
#endif // _WIN32
#include <memory>

struct ShmFile {
#ifdef _WIN32
HANDLE shm_file_;
ShmFile(void* shm_file) { shm_file_ = static_cast<HANDLE>(shm_file); };
HANDLE* GetShmFile() { return &shm_file_; };
#else
int shm_file_;
ShmFile(int shm_file) { shm_file_ = *static_cast<int*>(shm_file); };
std::unique_ptr<int> shm_file_;
ShmFile(void* shm_file)
{
shm_file_ = std::make_unique<int>(*static_cast<int*>(shm_file));
};
int* GetShmFile() { return shm_file_.get(); }
#endif // _WIN32
};

struct SharedMemoryHandle {
std::string triton_shm_name_;
std::string shm_key_;
void* base_addr_;
ShmFile* platform_handle_;
std::unique_ptr<ShmFile> platform_handle_;
size_t offset_;
size_t byte_size_;
};

0 comments on commit 479dccc

Please sign in to comment.