Checklist
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
When I use lmdeploy to deploy MiniCPM-V-2_6 in Triton, the GPU memory usage keeps increasing until an out-of-memory exception occurs, but this does not happen when I run inference in a loop from a plain Python script (see the sketch after the backend script below). I also experimented with the official model weights (https://huggingface.co/openbmb/MiniCPM-V-2_6) and still had the same problem. Calling torch.cuda.empty_cache() to free unused GPU memory on each request works around the issue, but it increases the inference time.
Reproduction
The Python backend inference script is as follows:
import os
import triton_python_backend_utils as pb_utils
import json
import time
os.system("pip3 list")  # debug: print the installed packages
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image
import torch
import numpy as np
import io
from PIL import Image
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        # You must parse model_config. JSON string is not parsed here
        pwd = os.path.abspath(os.path.dirname(__file__))
        self.ckpt_path = f'{pwd}/MiniCPM-V-2_6/'
        print('self.ckpt_path', self.ckpt_path)
        engine_config = TurbomindEngineConfig(session_len=8192, max_batch_size=1)
        self.pipeline = pipeline(self.ckpt_path, backend_config=engine_config)
        self.gen_config = GenerationConfig(top_k=1)
        print(os.environ['LD_LIBRARY_PATH'])

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []
        # Post-process the model output and create a response for each request
        for request in requests:
            # tic_pre = time.time()
            query = pb_utils.get_input_tensor_by_name(request, "query")
            query = query.as_numpy()
            image = pb_utils.get_input_tensor_by_name(request, "image")
            image = image.as_numpy()
            # pil_image = load_image(image[0][0].decode('utf-8'))
            pil_image = Image.open(io.BytesIO(image[0][0])).convert('RGB')
            messages = []
            message = dict()
            message['role'] = 'user'
            lm_content = []
            lm_content.append(dict(type='text', text=query[0][0].decode('utf-8')))
            lm_content.append(dict(type='image_data', image_data=dict(max_slice_nums=1, data=pil_image)))
            message['content'] = lm_content
            messages.append(message)
            # toc_pre = time.time()
            # print(f'preprocess cost {toc_pre-tic_pre}')
            # tic_infer = time.time()
            result = self.pipeline(messages, gen_config=self.gen_config)
            output_bytes = result.text.encode("utf-8")
            output_bytes = np.array(output_bytes, dtype=bytes).reshape([1, -1])
            output_tensor = pb_utils.Tensor("response", output_bytes)
            responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
            # Workaround: uncommenting this avoids the OOM but increases inference time
            # torch.cuda.empty_cache()
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
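For comparison, here is a minimal sketch of the standalone loop-inference script referred to in the description, which did not show the memory growth. The image path and prompt are illustrative placeholders, not taken from the original script:

from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image

# Same engine settings as the Triton backend above.
engine_config = TurbomindEngineConfig(session_len=8192, max_batch_size=1)
pipe = pipeline('./MiniCPM-V-2_6/', backend_config=engine_config)
gen_config = GenerationConfig(top_k=1)

# Placeholder image and prompt used only for illustration.
image = load_image('demo.jpg')
for i in range(1000):
    result = pipe(('Describe this image.', image), gen_config=gen_config)
    print(i, result.text[:80])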
Environment
sys.platform: linux
Python: 3.9.16 (main, Apr 2 2024, 20:40:25) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: NVIDIA L40
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.2, V12.2.91
GCC: gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
PyTorch: 2.2.1
PyTorch compiling details: PyTorch built with:
- GCC 10.2
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2024.1-Product Build 20240215 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX512
- CUDA Runtime 12.2
- NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90
- CuDNN 8.9.6
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.2, CUDNN_VERSION=8.9.6, CXX_COMPILER=/opt/rh/devtoolset-10/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
TorchVision: 0.17.1
LMDeploy: 0.6.1+
transformers: 4.41.2
gradio: Not Found
fastapi: 0.115.0
pydantic: 2.9.2
triton: 2.3.1
NVIDIA Topology:
        GPU0    NIC0    NIC1    CPU Affinity    NUMA Affinity    GPU NUMA ID
GPU0     X      SYS     SYS                                      N/A
NIC0    SYS      X      SYS
NIC1    SYS     SYS      X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_bond_0
NIC1: mlx5_bond_1
Error traceback
terminate called after throwing an instance of 'std::runtime_error'
  what():  [TM][ERROR] CUDA runtime error: out of memory /lmdeploy/src/turbomind/utils/allocator.h:246
Hi, it seems you are using a quantized model (/xx/model_repos/minicpmv_2_6_awq_4bit); you should add model_format='awq' to TurbomindEngineConfig, see here. Could you try again? If the issue still exists, could you provide a Dockerfile to reproduce it? Thanks.
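For reference, a minimal sketch of the suggested engine configuration for AWQ-quantized weights; the checkpoint path is the one quoted above and should be replaced with your local path:

from lmdeploy import pipeline, TurbomindEngineConfig

# Path quoted above; replace with your local AWQ checkpoint directory.
ckpt_path = '/xx/model_repos/minicpmv_2_6_awq_4bit'

# For AWQ-quantized weights, tell TurboMind the model format explicitly.
engine_config = TurbomindEngineConfig(
    session_len=8192,
    max_batch_size=1,
    model_format='awq',
)
pipe = pipeline(ckpt_path, backend_config=engine_config)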
Sorry for the late reply. I still have the above problem when using the official weights without AWQ quantization. Since I don't know how to export a Dockerfile, I can only provide the output of lmdeploy check_env. In addition, my Triton version is 23.02.