[Bug] InternVL2-1B performance of lmdeploy is much worse compared to the original Hugging Face PyTorch model. #2705

Open
3 tasks done
henry16lin opened this issue Nov 4, 2024 · 7 comments

@henry16lin

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

InternVL2-1B performance of lmdeploy is much worse compared to the original Hugging Face PyTorch model.

Reproduction

Thanks for your great work.
My task is image captioning. I tried OpenGVLab/InternVL2-1B from its Hugging Face model card and ran it with the sample code from the quick start. The output quality is clearly worse with lmdeploy (although it is faster).
For example:
test_img (source image: https://content.api.news/v3/images/bin/7c169b05712f7657366268afaa47ae88 )

  • The transformers version responds: The image shows a group of police officers at a McDonald's restaurant, with a white car parked in front of them.
  • The lmdeploy version responds: The image shows a police officer and a uniformed officer in a parking lot in front of a McDonald's restaurant. The officer in the uniform is in a handcuff, and the officer in the police uniform is in a handcuff. The McDonald's restaurant is in the background with a McDonald's and a Starbucks Coffee.

As you can see, the lmdeploy response keeps repeating similar sentences.
I tried many images and different parameters (including repetition_penalty), but nothing helped. Sometimes lmdeploy also answers partly in Chinese, for example: "The CCTV is in the parking lot of a convenience store, with a Coke-Cola vending machine and a suspiciously parked car. The store's facade is a stone and brick wall, and the parking lot is a curb with a red and white警示标志。" (警示标志 means "warning sign").
(This happens even though I replaced the Chinese system prompt in pipe.chat_template.meta_instruction and told the model that it must answer in English.)
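Editorial aside: the repetition described above can be given a rough number. The sketch below is not part of the original report; `repeated_ngram_fraction` is a hypothetical helper that measures the share of word trigrams occurring more than once, applied to the two captions quoted above. It is a crude indicator under that assumption, not an established captioning metric.

```python
import re
from collections import Counter


def repeated_ngram_fraction(text: str, n: int = 3) -> float:
    """Fraction of word n-grams in `text` that occur more than once."""
    # Lowercase and keep apostrophes so "McDonald's" stays one token.
    tokens = re.findall(r"[a-z0-9']+", text.lower())
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    counts = Counter(ngrams)
    # Count every occurrence of an n-gram that appears at least twice.
    repeated = sum(c for c in counts.values() if c > 1)
    return repeated / len(ngrams)


# Captions copied from the report above.
hf_caption = (
    "The image shows a group of police officers at a McDonald's restaurant, "
    "with a white car parked in front of them."
)
lmdeploy_caption = (
    "The image shows a police officer and a uniformed officer in a parking lot "
    "in front of a McDonald's restaurant. The officer in the uniform is in a "
    "handcuff, and the officer in the police uniform is in a handcuff. The "
    "McDonald's restaurant is in the background with a McDonald's and a "
    "Starbucks Coffee."
)

print(f"transformers: {repeated_ngram_fraction(hf_caption):.2f}")
print(f"lmdeploy:     {repeated_ngram_fraction(lmdeploy_caption):.2f}")
```

Running the same helper over a batch of captions from both backends would show whether the repetition is systematic or limited to a few images.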

I tried QwenVL2-2B, and its performance remained consistent, but the memory usage of InternVL2-1B is more suitable for the scenario I'm dealing with.

Do you have any thoughts or suggestions on this issue? I'm using lmdeploy 0.6.2

BTW, I tried a few VLM models (QwenVL and InternVL series) and found that with lmdeploy the memory usage is higher (5 GB -> 6 GB) but inference is faster (10 s -> 3 s). Is this expected?

Thank you very much!

Environment

I'm using lmdeploy 0.6.2 and torch 2.2.0 (installed via pip).

Error traceback

No response

@RunningLeon
Collaborator

RunningLeon commented Nov 5, 2024

Thanks for your feedback. Could you provide sample code that reproduces the results you mentioned with lmdeploy?

@henry16lin
Author

henry16lin commented Nov 5, 2024

Thanks for your response. Here is my code.
The upper part is the transformers version and the lower part is the lmdeploy version; a comment line separates the two.

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torchvision.transforms as T
from PIL import Image
import time
from torchvision.transforms.functional import InterpolationMode

from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig, PytorchEngineConfig, ChatTemplateConfig


model_path = 'OpenGVLab/InternVL2-1B'
prompt = "AI: "
system_prompt = """Describe what is happening in the image concisely. 
Focus on the key observable points or any unusual events occurring in the image. 

# Rules
- The description should be less than three sentences.
- The response must be in plain English without any formatting (e.g., no bullet points, Markdown, or JSON).
"""

test_img_list = ["test_img.jpeg"]


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def resize_image(img, max_image_size):
  width, height = img.size
  
  if max(width, height) > max_image_size:
      scaling_factor = max_image_size / float(max(width, height))
      new_width = int(width * scaling_factor)
      new_height = int(height * scaling_factor)
      img = img.resize((new_width, new_height), Image.LANCZOS)
      print(f'Image is resized from {width} x {height} to {new_width}x{new_height}')
  
  return img


def load_image(image_file, transform=None):
    image = Image.open(image_file).convert('RGB')
    image = resize_image(image, 512)
    if transform:
        pixel_values = torch.unsqueeze(transform(image), dim=0)
        return pixel_values
    else:
        return image

model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    # load_in_8bit=True,
    low_cpu_mem_usage=True,
    # use_flash_attn=True,
    attn_implementation="flash_attention_2",
    use_safetensors=True,
    trust_remote_code=True).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
transform = build_transform(input_size=448)
generation_config = dict(max_new_tokens=512,
                         do_sample=False,
                         temperature=0.8,
                         repetition_penalty=1,
                         )


for image_path in test_img_list:
    print(image_path)
    start_time = time.time()

    pixel_values = load_image(image_path, transform).to(torch.bfloat16).cuda()
    question = system_prompt
    response = model.chat(tokenizer, pixel_values, question, generation_config)

    print(f'elapsed time: {time.time() - start_time}')
    print(response)
    print('-' * 50)



# ================================= for lmdeploy part =================================

backend_config = TurbomindEngineConfig(# session_len=2048,
                                       # cache_chunk_size = 0,
                                       # max_batch_size = 1, # does not help 
                                       # quant_policy=8,  # does not help and have higher prob to get bad answer
                                       # cache_block_seq_len = 32,
                                       eager_mode=True, 
                                       # max_prefill_token_num=2048,
                                       enable_prefix_caching=False,  # it will output same sentence
                                       cache_max_entry_count=0.01,
                                       tp=1
                                       )

pipe = pipeline(model_path, backend_config=backend_config, log_level='WARN')
pipe.chat_template.meta_instruction = system_prompt

gen_config = GenerationConfig(temperature=0.5,
                              repetition_penalty=1.5,
                              do_sample=False,
                              # max_new_tokens=256  #default=512
                              )

for image_path in test_img_list:
    start_time = time.time()
    print(image_path)
    image = load_image(image_path)
    for i in range(3):  # it will response "internal error happened for some images, then retry (but still same error)"
        response = pipe((prompt, image), gen_config=gen_config)
        elapsed_time = time.time() - start_time

        print(f'elapsed time: {elapsed_time}')
        print(response.text)
        if response.text != "internal error happened":
            break
    print('-' * 50)

Note: my transformers version is 4.45.0 (with 4.46, the transformers part at the top fails with a "got multiple values for keyword argument 'return_dict'" error).
I appreciate your response, thank you!

@RunningLeon
Collaborator

@henry16lin hi, thanks for your sample code. In your script, inference differs between hf and lmdeploy in two ways. Please keep the inputs the same on both sides and then compare the inference outputs again.

  1. The pre-processing is different. lmdeploy's pre-processing is aligned with the example code in hf.
  2. hf uses the default system_prompt, but lmdeploy uses the custom system_prompt from your script.
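Editorial aside: one way to keep the text inputs and decoding settings identical is to derive both calls from a single shared definition. The sketch below rests on assumptions: `SHARED_DECODING`, `hf_inputs` and `lmdeploy_inputs` are hypothetical names, and the `<image>` tag handling simply mirrors the scripts posted in this thread (tag on the hf side, none on the lmdeploy side). Image pre-processing is not covered, and `pipe.chat_template.meta_instruction` is assumed to be left at its default.

```python
# Decoding settings shared by both backends. Greedy decoding (do_sample=False)
# removes sampling noise from the comparison.
SHARED_DECODING = dict(max_new_tokens=512, do_sample=False, repetition_penalty=1.0)
QUESTION = "describe this image."


def hf_inputs(question: str, decoding: dict):
    """Prompt and generation settings for the transformers model.chat() call."""
    # The hf script in this thread puts the <image> tag in the question itself.
    return f"<image>\n{question}", dict(decoding)


def lmdeploy_inputs(question: str, decoding: dict):
    """Prompt and keyword arguments for lmdeploy's GenerationConfig."""
    # The lmdeploy script in this thread passes the bare question with the image.
    return question, dict(decoding)


hf_prompt, hf_cfg = hf_inputs(QUESTION, SHARED_DECODING)
lm_prompt, lm_cfg = lmdeploy_inputs(QUESTION, SHARED_DECODING)

# Intended use (needs the models and a GPU, so it is left commented out):
#   response = model.chat(tokenizer, pixel_values, hf_prompt, hf_cfg)
#   response = pipe((lm_prompt, image), gen_config=GenerationConfig(**lm_cfg))
```

With both sides decoding greedily from the same question, any remaining difference in the captions points to pre-processing or backend implementation rather than to sampling.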

@henry16lin
Author

Yes, there are a few differences, but I don't think they are the key factor behind the poor responses.
To check, I pasted the quick start code from the hf model card unchanged (only the image source is modified), and the quality of the descriptions still differs significantly.
Here is the code (almost identical to the quick start):

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

# If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
path = 'OpenGVLab/InternVL2-1B'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)


pixel_values = load_image('test_img.jpeg', max_num=12).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)

question = '<image>\ndescribe this image.'
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(f'User: {question}\nAssistant: {response}')

print('=============================================================================')

from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

model = 'OpenGVLab/InternVL2-1B'
image = load_image('test_img.jpeg')

pipe = pipeline(model, backend_config=TurbomindEngineConfig(session_len=8192))
response = pipe(('describe this image', image))
print(response.text)

The responses are as follows:

# hf:
User: <image>
describe this image.
Assistant: The image captures a scene outside a McDonald's restaurant. A group of police officers and other civilians are gathered around a white car in front of the restaurant. The officers appear to be engaged in a conversation or discussion. One officer is holding a folder. The car might be parked there temporarily, as there is a yellow line in the parking area. In the background, McDonald's restaurant sign can be seen alongside a greenery from McDonald's canteen. The weather appears to be rainy, as there are puddles on the road surface. A woman in pink is standing slightly away, and there is a clear view of a trash can labeled with an "PPS" mark in front of a pedestrian exit area.

# lmdeploy:
The image shows a police officer and a uniformed officer in a parking lot in front of a McDonald's restaurant. The officer in the uniform is in a handcuff, and the officer in the police uniform is in a conversation. The McDonald's restaurant is in the background with a McDonald's McDonald's and a Starbucks Coffee. The parking lot is wet, and there are a few parked cars.

Thank you for your response and please help me to use InternVL2-1B with lmdeploy 🙏

@RunningLeon
Collaborator

@henry16lin hi, could you try InternVL2-2B? It seems that smaller LLMs are less tolerant of the slight implementation differences between transformers and lmdeploy.

@henry16lin
Author

Yes, I had tried InternVL2-2B in lmdeploy, but it needs around 6 GB of memory and it returns an empty string.
My use case has to stay below 4 GB of memory.

@RunningLeon
Collaborator

In that case, could you try quantization to reduce runtime memory? See https://lmdeploy.readthedocs.io/en/latest/quantization/w4a16.html#
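Editorial aside: a minimal sketch of what loading a W4A16 (AWQ) checkpoint might look like, following the linked documentation. The work directory name, the helper names and the 0.2 cache ratio are assumptions for illustration. Whether InternVL2-1B quantizes cleanly, or fits in 4 GB afterwards, is not confirmed anywhere in this thread.

```python
# Quantize once from the shell first (paths are placeholders):
#   lmdeploy lite auto_awq OpenGVLab/InternVL2-1B --work-dir ./internvl2-1b-4bit


def awq_engine_kwargs(cache_ratio: float = 0.2) -> dict:
    """Keyword arguments for TurbomindEngineConfig when serving an AWQ model.

    cache_ratio maps to cache_max_entry_count, the share of free GPU memory
    reserved for the k/v cache; a smaller value lowers runtime memory.
    """
    if not 0.0 < cache_ratio < 1.0:
        raise ValueError("cache_ratio must be between 0 and 1 (exclusive)")
    return dict(model_format="awq", cache_max_entry_count=cache_ratio)


def build_pipeline(work_dir: str, cache_ratio: float = 0.2):
    """Create an lmdeploy pipeline for a quantized checkpoint (needs a GPU)."""
    # Imported lazily so the helper above can be used without lmdeploy installed.
    from lmdeploy import pipeline, TurbomindEngineConfig

    config = TurbomindEngineConfig(**awq_engine_kwargs(cache_ratio))
    return pipeline(work_dir, backend_config=config)
```

Calling `build_pipeline('./internvl2-1b-4bit')` would then replace the `pipeline(model_path, ...)` call in the scripts above, with the rest of the captioning loop unchanged.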
