Optimize GPU tensor support for Python backend #293

Merged: 16 commits into main on Oct 25, 2023

Conversation

@krishung5 (Contributor) commented Aug 31, 2023

Currently, the CUDA IPC calls dominate the time for transferring GPU tensors between processes. Specifically, the functions cudaIpcOpenMemHandle and cudaIpcCloseMemHandle are heavily used. These functions are necessary because the allocated buffers are not in the same pool. So, we need to call these functions to open an interprocess memory handle exported from another process and get a device pointer that can be used in the local process.
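
For reference, a minimal sketch of the per-tensor IPC pattern described above; the function names here are illustrative and not the backend's actual code:

```cpp
// Illustrative sketch only: the per-tensor CUDA IPC pattern this PR moves away
// from. Every GPU tensor transfer pays for opening and closing an interprocess
// memory handle exported by the other process.
#include <cuda_runtime_api.h>

void* OpenRemoteTensor(const cudaIpcMemHandle_t& handle)
{
  void* device_ptr = nullptr;
  // Open the handle exported from another process to get a device pointer
  // usable in this process; this call is expensive when issued per tensor.
  if (cudaIpcOpenMemHandle(
          &device_ptr, handle, cudaIpcMemLazyEnablePeerAccess) != cudaSuccess) {
    return nullptr;
  }
  return device_ptr;
}

void CloseRemoteTensor(void* device_ptr)
{
  // Every successful cudaIpcOpenMemHandle must be paired with a close.
  cudaIpcCloseMemHandle(device_ptr);
}
```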

This PR makes use of Triton's CUDA shared memory pool for GPU tensor transfers. The parent process will get the base address of the CUDA pool and share it with the stub. Here's how the data transfer process works:

  • Data transfer from parent to stub
    When the parent process wants to send a tensor to the stub using the pool, it stores the data in the pool, calculates the offset from the pool's base address, and shares this offset with the stub. The stub process then uses this offset to retrieve the data (see the offset sketch after this list).
  • Data transfer from stub to parent
    Because only the parent process can interact with the CUDA pool memory allocation, the stub first notifies the parent about the byte size of the data using an IPC message. Then, the parent pre-allocates a buffer from the memory pool and communicates the calculated offset details back to the stub. Afterward, the stub fills the buffer with the tensor data and notifies the parent once the task is done.
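
A minimal sketch of the offset arithmetic described in the two bullets above, assuming only that both processes see the same pool and differ in the base address; the function names are illustrative:

```cpp
// Illustrative offset arithmetic for the CUDA shared memory pool.
// The parent shares only a byte offset; each process applies that offset to
// its own mapping of the pool's base address.
#include <cstdint>

// Parent side: convert a pointer allocated from the pool into an offset.
std::uint64_t PointerToOffset(void* pool_base, void* tensor_ptr)
{
  return static_cast<std::uint64_t>(
      reinterpret_cast<std::uintptr_t>(tensor_ptr) -
      reinterpret_cast<std::uintptr_t>(pool_base));
}

// Stub side: convert the received offset back into a usable device pointer.
void* OffsetToPointer(void* pool_base, std::uint64_t offset)
{
  return reinterpret_cast<char*>(pool_base) + offset;
}
```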

Testing: triton-inference-server/server#6276

@krishung5 marked this pull request as ready for review on September 1, 2023 17:56
@Tabrizian (Member) commented:

Discussed with @krishung5 offline. It looks like there is one additional data copy introduced as part of this change, which is affecting single-model latency and throughput. @krishung5 is working on removing that extra copy and gathering profiling numbers again.

@krishung5 requested a review from @Tabrizian on September 12, 2023 08:55
@krishung5 (Contributor, Author) commented Oct 5, 2023

Updated the functionality of PbMemory: when creating a PbMemory object, it only sets the cuda_pool_offset if the data is allocated from the CUDA pool; when loading a PbMemory object from shared memory, it returns a pointer based on the offset. No extra data copy or logic happens inside PbMemory.
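
A rough sketch of that behavior, using a hypothetical class and member names rather than the actual PbMemory implementation:

```cpp
// Hypothetical sketch of the behavior described above, not the real PbMemory.
// The offset is recorded only when the data lives in the CUDA pool, and
// loading resolves the pointer from base + offset without any extra copy.
#include <cstddef>
#include <cstdint>
#include <optional>

class PbMemorySketch {
 public:
  // On creation: set the offset only if the buffer was allocated from the pool.
  PbMemorySketch(void* data, void* pool_base, std::size_t pool_size)
      : data_(data)
  {
    const auto addr = reinterpret_cast<std::uintptr_t>(data);
    const auto base = reinterpret_cast<std::uintptr_t>(pool_base);
    if (addr >= base && addr < base + pool_size) {
      cuda_pool_offset_ = addr - base;
    }
  }

  // On load: if an offset was recorded, compute the pointer from the local
  // view of the pool base plus the offset; otherwise return the original one.
  void* DataPtr(void* local_pool_base) const
  {
    if (cuda_pool_offset_) {
      return reinterpret_cast<char*>(local_pool_base) + *cuda_pool_offset_;
    }
    return data_;
  }

 private:
  void* data_;
  std::optional<std::uint64_t> cuda_pool_offset_;
};
```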

Summarizing the logic for GPU tensor transfer below in case it is helpful for review.

Different cases for the GPU tensor transfer:

Inference

  • input
    • non-decoupled - Get the input buffer from input_collector. The input_collector allocates the buffer from the GPU pool first using BackendMemory, so no extra handling is needed here; if the buffer does not come from the CUDA pool, create a new BackendMemory from the pool and copy the input data from the original buffer into it.
    • decoupled - Allocate the GPU memory using BackendMemory, and call backend::ReadInputTensor to read the input tensor into the buffer.
  • output
    • non-decoupled - In the InferResponse::Send function, get the Triton-provided output buffer. If that buffer is not from the CUDA pool, try to allocate a new buffer from the pool and add it to gpu_buffer_helper. Once the stub fills in the pool buffer, the output tensor needs to be copied back to the Triton-provided buffer (see the copy-back sketch after this list).
    • decoupled - Same as the non-decoupled case. The final copy back happens here.
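
A minimal sketch of that copy-back step, assuming a plain device-to-device cudaMemcpy; the real code path goes through BackendMemory and gpu_buffer_helper and may use streams:

```cpp
// Illustrative copy-back: once the stub has written the output tensor into a
// buffer from the CUDA pool, the parent copies it into the buffer Triton
// provided for the response.
#include <cstddef>
#include <cuda_runtime_api.h>

cudaError_t CopyPoolOutputToTritonBuffer(
    void* triton_output_buffer, const void* pool_buffer, std::size_t byte_size)
{
  return cudaMemcpy(
      triton_output_buffer, pool_buffer, byte_size, cudaMemcpyDeviceToDevice);
}
```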

BLS

  • input - No differences between decoupled and non-decoupled cases. Allocate memory using BackendMemory and add the buffer to gpu_buffer_helper.
  • output - The request_executor now uses BackendMemory to allocate the buffer for the BLS output. Both the non-decoupled and decoupled cases call ModelInstanceState::PrepareResponseHandle to prepare the response. It's possible that the CUDA memory pool hasn't been shared with the stub process at the time the BLS output is allocated during the callback, and there is no way to share the CUDA pool with the stub at that point since we are not passing the StubLauncher object to the ResponseAlloc callback (I thought it would be more complicated if we did, but I'm open to any ideas!). Hence, the CUDA pool offset is updated here after the associated PbMemory is created.

@krishung5 requested a review from @Tabrizian on October 5, 2023 10:09
@Tabrizian (Member) left a comment:


Great work, Kris!

@krishung5 merged commit 4c0a977 into main on Oct 25, 2023
3 checks passed
@krishung5 deleted the krish-python-gpu branch on October 25, 2023 22:15