diff --git a/src/python/library/tritonclient/utils/shared_memory/__init__.py b/src/python/library/tritonclient/utils/shared_memory/__init__.py index e6c21ad4a..719b96819 100755 --- a/src/python/library/tritonclient/utils/shared_memory/__init__.py +++ b/src/python/library/tritonclient/utils/shared_memory/__init__.py @@ -48,7 +48,10 @@ def from_param(cls, value): class ShmFile(Structure): if sys.platform == "win32": - _fields_ = [("shm_handle_", c_void_p)] + _fields_ = [ + ("backing_file_handle_", c_void_p), + ("shm_mapping_handle_", c_void_p), + ] else: _fields_ = [("shm_fd_", c_int)] @@ -334,7 +337,9 @@ def __init__(self, err): -4: "unable to read/mmap the shared memory region", -5: "unable to unlink the shared memory region", -6: "unable to munmap the shared memory region", - -7: "unable to create file mapping", + -7: "unable to create shm directory or backing file", + -8: "unable to create file mapping", + -9: "unable to delete backing file", } self._msg = None if type(err) == str: diff --git a/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc b/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc index da81a6d34..5242c007d 100644 --- a/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc +++ b/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc @@ -39,6 +39,8 @@ #include "shared_memory.h" #include "shared_memory_handle.h" +#define TRITON_SHM_FILE_ROOT "C:\\triton_shm\\" + //============================================================================== // SharedMemoryControlContext namespace { @@ -46,13 +48,13 @@ namespace { void* SharedMemoryHandleCreate( std::string triton_shm_name, void* shm_addr, std::string shm_key, - ShmFile* shm_file, size_t offset, size_t byte_size) + std::unique_ptr&& shm_file, size_t offset, size_t byte_size) { SharedMemoryHandle* handle = new SharedMemoryHandle(); handle->triton_shm_name_ = triton_shm_name; handle->base_addr_ = shm_addr; handle->shm_key_ = shm_key; - handle->platform_handle_.reset(shm_file); + handle->platform_handle_ = std::move(shm_file); handle->offset_ = offset; handle->byte_size_ = byte_size; return static_cast(handle); @@ -73,14 +75,14 @@ SharedMemoryRegionMap( DWORD low_order_offset = upperbound_offset & 0xFFFFFFFF; // map shared memory to process address space *shm_addr = MapViewOfFile( - shm_file->shm_handle_, // handle to map object - FILE_MAP_ALL_ACCESS, // read/write permission - high_order_offset, // offset (high-order DWORD) - low_order_offset, // offset (low-order DWORD) + shm_file->shm_mapping_handle_, // handle to map object + FILE_MAP_ALL_ACCESS, // read/write permission + high_order_offset, // offset (high-order DWORD) + low_order_offset, // offset (low-order DWORD) byte_size); if (*shm_addr == NULL) { - CloseHandle(shm_file->shm_handle_); + CloseHandle(shm_file->shm_mapping_handle_); return -1; } // For Windows, we cannot close the shared memory handle here. When all @@ -100,6 +102,38 @@ SharedMemoryRegionMap( #endif } +#ifdef _WIN32 +int +SharedMemoryCreateBackingFile(const char* shm_key, HANDLE* backing_file_handle) +{ + LPCSTR backing_file_directory(TRITON_SHM_FILE_ROOT); + bool success = CreateDirectory(backing_file_directory, NULL); + if (!success && GetLastError() != ERROR_ALREADY_EXISTS) { + return -1; + } + LPCSTR backing_file_path = + std::string(TRITON_SHM_FILE_ROOT + std::string(shm_key)).c_str(); + *backing_file_handle = CreateFile( + backing_file_path, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, NULL, + OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); + if (*backing_file_handle == INVALID_HANDLE_VALUE) { + return -1; + } + return 0; +} + +int +SharedMemoryDeleteBackingFile(const char* key, HANDLE backing_file_handle) +{ + CloseHandle(backing_file_handle); + LPCSTR backing_file_path = + std::string(TRITON_SHM_FILE_ROOT + std::string(key)).c_str(); + if (!DeleteFile(backing_file_path)) { + return -1; + } +} +#endif + } // namespace TRITONCLIENT_DECLSPEC int @@ -108,6 +142,11 @@ SharedMemoryRegionCreate( void** shm_handle) { #ifdef _WIN32 + HANDLE backing_file_handle; + int err = SharedMemoryCreateBackingFile(shm_key, &backing_file_handle); + if (err == -1) { + return -7; + } // The CreateFileMapping function takes a high-order and low-order DWORD (4 // bytes each) for size. 'size_t' can either be 4 or 8 bytes depending on the // operating system. To handle both cases agnostically, we cast 'byte_size' to @@ -118,22 +157,28 @@ SharedMemoryRegionCreate( DWORD low_order_size = upperbound_size & 0xFFFFFFFF; HANDLE win_handle = CreateFileMapping( - INVALID_HANDLE_VALUE, // use paging file - NULL, // default security - PAGE_READWRITE, // read/write access - high_order_size, // maximum object size (high-order DWORD) - low_order_size, // maximum object size (low-order DWORD) - shm_key); // name of mapping object + backing_file_handle, // use backing file + NULL, // default security + PAGE_READWRITE, // read/write access + high_order_size, // maximum object size (high-order DWORD) + low_order_size, // maximum object size (low-order DWORD) + shm_key); // name of mapping object if (win_handle == NULL) { - return -7; + LPCSTR backing_file_path = + std::string(TRITON_SHM_FILE_ROOT + std::string(shm_key)).c_str(); + // Cleanup backing file on failure + SharedMemoryDeleteBackingFile(shm_key, backing_file_handle); + return -8; } - ShmFile* shm_file = new ShmFile(win_handle); + std::unique_ptr shm_file = + std::make_unique(backing_file_handle, win_handle); // get base address of shared memory region void* shm_addr = nullptr; - int err = SharedMemoryRegionMap(shm_file, 0, byte_size, &shm_addr); + err = SharedMemoryRegionMap(shm_file.get(), 0, byte_size, &shm_addr); if (err == -1) { + SharedMemoryDeleteBackingFile(shm_key, backing_file_handle); return -4; } #else @@ -149,18 +194,18 @@ SharedMemoryRegionCreate( return -3; } - ShmFile* shm_file = new ShmFile(shm_fd); + std::unique_ptr shm_file = std::make_unique(shm_fd); // get base address of shared memory region void* shm_addr = nullptr; - int err = SharedMemoryRegionMap(shm_file, 0, byte_size, &shm_addr); + int err = SharedMemoryRegionMap(shm_file.get(), 0, byte_size, &shm_addr); if (err == -1) { return -4; } #endif // create a handle for the shared memory region *shm_handle = SharedMemoryHandleCreate( - std::string(triton_shm_name), shm_addr, std::string(shm_key), shm_file, 0, - byte_size); + std::string(triton_shm_name), shm_addr, std::string(shm_key), + std::move(shm_file), 0, byte_size); return 0; } @@ -186,7 +231,8 @@ GetSharedMemoryHandleInfo( *offset = handle->offset_; *byte_size = handle->byte_size_; #ifdef _WIN32 - file->shm_handle_ = handle->platform_handle_->shm_handle_; + file->backing_file_handle_ = handle->platform_handle_->shm_mapping_handle_; + file->shm_mapping_handle_ = handle->platform_handle_->shm_mapping_handle_; #else file->shm_fd_ = handle->platform_handle_->shm_fd_; #endif @@ -204,10 +250,12 @@ SharedMemoryRegionDestroy(void* shm_handle) if (!success) { return -6; } - // We keep Windows shared memory handles open until we are done - // using them. When all handles are closed, the system will free - // the section of the paging file that the object uses. - CloseHandle(handle->platform_handle_->shm_handle_); + CloseHandle(handle->platform_handle_->shm_mapping_handle_); + int err = SharedMemoryDeleteBackingFile( + handle->shm_key_.c_str(), handle->platform_handle_->backing_file_handle_); + if (err == -1) { + return -9; + } #else int status = munmap(shm_addr, handle->byte_size_); if (status == -1) { diff --git a/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h b/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h index 0bf7b71cb..bd264546a 100644 --- a/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h +++ b/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h @@ -37,8 +37,11 @@ struct ShmFile { #ifdef _WIN32 - HANDLE shm_handle_; - ShmFile(HANDLE shm_handle) : shm_handle_(shm_handle){}; + HANDLE backing_file_handle_; + HANDLE shm_mapping_handle_; + ShmFile(HANDLE backing_file_handle, HANDLE shm_mapping_handle) + : backing_file_handle_(backing_file_handle), + shm_mapping_handle_(shm_mapping_handle){}; #else int shm_fd_; ShmFile(int shm_fd) : shm_fd_(shm_fd){};