diff --git a/configs/local_models/bert/bert.py b/configs/local_models/bert/bert.py new file mode 100644 index 0000000..f6fcec5 --- /dev/null +++ b/configs/local_models/bert/bert.py @@ -0,0 +1,18 @@ +from opencompass.models import Bert + +from mmengine.config import read_base +with read_base(): + from ...paths import ROOT_DIR + +bert_large_cased = dict( + type=Bert, + abbr='bert_large_cased', + path=ROOT_DIR+"models/google-bert/bert-large-cased", + tokenizer_path=ROOT_DIR+"models/google-bert/bert-large-cased", + tokenizer_kwargs=dict(trust_remote_code=True), + max_out_len=400, + max_seq_len=2048, + batch_size=8, + model_kwargs=dict(trust_remote_code=True), + run_cfg=dict(num_gpus=1, num_procs=1), + ) \ No newline at end of file diff --git a/configs/local_models/qwen/qwen_vllm.py b/configs/local_models/qwen/qwen_vllm.py new file mode 100644 index 0000000..950be76 --- /dev/null +++ b/configs/local_models/qwen/qwen_vllm.py @@ -0,0 +1,15 @@ +from opencompass.models import VLLM + +qwen_1_5b_14b_chat_vllm = dict( + type=VLLM, + abbr='qwen-1.5b-14b-chat', + path="/home/junetheriver/models/qwen/Qwen1.5-14B-Chat", + max_seq_len=2048, + model_kwargs=dict(trust_remote_code=True, max_model_len=2048), + generation_kwargs=dict(), + meta_template=None, + mode='none', + batch_size=1, + use_fastchat_template=False, + end_str=None, +) \ No newline at end of file diff --git a/configs/lyh/t5_all.py b/configs/lyh/t5_all.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/tests/test_model.py b/configs/tests/test_model.py index 89be770..87e650b 100644 --- a/configs/tests/test_model.py +++ b/configs/tests/test_model.py @@ -8,24 +8,22 @@ from ..datasets.opseval.datasets import owl_mc, owl_qa # Models from ..local_models.google.t5 import t5_base - from ..local_models.lmsys.vicuna import vicuna_bases - from ..local_models.internlm.internlm import internlm2_bases - from ..local_models.yi.yi import yi_bases - from ..local_models.mistral.mistral import mistral_7b + from ..local_models.bert.bert import bert_large_cased + from ..paths import ROOT_DIR -yi_bases = [model for model in yi_bases if '34' not in model['abbr']] datasets = [ *owl_mc, *owl_qa, ] datasets = [ - dataset for dataset in datasets if 'Zero-shot' in dataset['abbr'] + dataset for dataset in datasets if 'Zero-shot' in dataset['abbr'] and 'zh' in dataset['abbr'] ] models = [ t5_base, + # bert_large_cased, # *vicuna_bases, # *internlm2_bases, # *yi_bases, diff --git a/configs/tests/test_vllm.py b/configs/tests/test_vllm.py new file mode 100644 index 0000000..168a820 --- /dev/null +++ b/configs/tests/test_vllm.py @@ -0,0 +1,69 @@ +from mmengine.config import read_base +from opencompass.partitioners import SizePartitioner, NaivePartitioner +from opencompass.runners import LocalRunner +from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask + +with read_base(): + # Datasets + from ..datasets.opseval.datasets import owl_mc, owl_qa + # Models + from ..local_models.google.t5 import t5_base + from ..local_models.bert.bert import bert_large_cased + from ..local_models.qwen.qwen_vllm import qwen_1_5b_14b_chat_vllm + + from ..paths import ROOT_DIR + + +datasets = [ + *owl_mc, *owl_qa, +] + +datasets = [ + dataset for dataset in datasets if 'Zero-shot' in dataset['abbr'] and 'zh' in dataset['abbr'] +] + +models = [ + t5_base, + # bert_large_cased, + qwen_1_5b_14b_chat_vllm, + # *vicuna_bases, + # *internlm2_bases, + # *yi_bases, + # mistral_7b +] + +for model in models: + model['run_cfg'] = dict(num_gpus=1, num_procs=1) + pass + +for dataset in datasets: + dataset['sample_setting'] = dict() + dataset['infer_cfg']['inferencer']['save_every'] = 8 + dataset['infer_cfg']['inferencer']['sc_size'] = 2 + dataset['infer_cfg']['inferencer']['max_token_len'] = 20 + dataset['eval_cfg']['sc_size'] = 2 + dataset['sample_setting'] = dict(sample_size=2) # !!!WARNING: Use for testing only!!! + + +infer = dict( + partitioner=dict( + # type=SizePartitioner, + # max_task_size=100, + # gen_task_coef=1, + type=NaivePartitioner + ), + runner=dict( + type=LocalRunner, + max_num_workers=16, + max_workers_per_gpu=1, + task=dict(type=OpenICLInferTask), + ), +) + +eval = dict( + partitioner=dict(type=NaivePartitioner), + runner=dict( + type=LocalRunner, + max_num_workers=32, + task=dict(type=OpenICLEvalTask)), +) diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py index 58f1e87..d35e64c 100644 --- a/opencompass/models/__init__.py +++ b/opencompass/models/__init__.py @@ -33,6 +33,7 @@ # from .custom import CustomModel # noqa from .wenxin_api import WenXinAI # noqa from .t5 import T5 # noqa +from .bert import Bert # noqa from .pangu_api import PanGu # noqa: F401 from .qwen_api import Qwen # noqa: F401 from .sensetime_api import SenseTime # noqa: F401 diff --git a/opencompass/models/bert.py b/opencompass/models/bert.py new file mode 100644 index 0000000..e535dd6 --- /dev/null +++ b/opencompass/models/bert.py @@ -0,0 +1,392 @@ +import os +from typing import Dict, List, Optional, Union + +import numpy as np +import torch + +from opencompass.models.base import BaseModel +from opencompass.registry import MODELS +from opencompass.utils.logging import get_logger +from opencompass.utils.prompt import PromptList +from transformers import GenerationConfig + +PromptType = Union[PromptList, str] + + +@MODELS.register_module() +class Bert(BaseModel): + """Model wrapper around Bert general models. + + Args: + path (str): The name or path to Bert's model. + hf_cache_dir: Set the cache dir to HF model cache dir. If None, it will + use the env variable HF_MODEL_HUB. Defaults to None. + max_seq_len (int): The maximum length of the input sequence. Defaults + to 2048. + tokenizer_path (str): The path to the tokenizer. Defaults to None. + tokenizer_kwargs (dict): Keyword arguments for the tokenizer. + Defaults to {}. + peft_path (str, optional): The name or path to the Bert's PEFT + model. If None, the original model will not be converted to PEFT. + Defaults to None. + tokenizer_only (bool): If True, only the tokenizer will be initialized. + Defaults to False. + model_kwargs (dict): Keyword arguments for the model, used in loader. + Defaults to dict(device_map='auto'). + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + extract_pred_after_decode (bool): Whether to extract the prediction + string from the decoded output string, instead of extract the + prediction tokens before decoding. Defaults to False. + batch_padding (bool): If False, inference with be performed in for-loop + without batch padding. + + Note: + About ``extract_pred_after_decode``: Commonly, we should extract the + the prediction tokens before decoding. But for some tokenizers using + ``sentencepiece``, like LLaMA, this behavior may change the number of + whitespaces, which is harmful for Python programming tasks. + """ + + def __init__( + self, + path: str, + hf_cache_dir: Optional[str] = None, + max_seq_len: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + peft_path: Optional[str] = None, + tokenizer_only: bool = False, + model_kwargs: dict = dict(device_map='auto'), + meta_template: Optional[Dict] = None, + extract_pred_after_decode: bool = False, + batch_padding: bool = False, + generate_kwargs: dict = None, + ): + super().__init__(path=path, + max_seq_len=max_seq_len, + tokenizer_only=tokenizer_only, + meta_template=meta_template) + from opencompass.utils.fileio import patch_hf_auto_model + if hf_cache_dir is None: + hf_cache_dir = os.getenv('HF_MODEL_HUB', None) + patch_hf_auto_model(hf_cache_dir) + self.logger = get_logger() + self._load_tokenizer(path=path, + tokenizer_path=tokenizer_path, + tokenizer_kwargs=tokenizer_kwargs) + self.batch_padding = batch_padding + self.extract_pred_after_decode = extract_pred_after_decode + self.generate_kwargs = generate_kwargs if generate_kwargs else dict() + if not tokenizer_only: + self._load_model(path=path, + model_kwargs=model_kwargs, + peft_path=peft_path) + + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], + tokenizer_kwargs: dict): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path if tokenizer_path else path, **tokenizer_kwargs) + + # A patch for QwenTokenizer + if self.tokenizer.__class__.__name__ == 'QWenTokenizer': + self.tokenizer.pad_token_id = self.tokenizer.eod_id + self.tokenizer.bos_token_id = self.tokenizer.eod_id + self.tokenizer.eos_token_id = self.tokenizer.eod_id + + if self.tokenizer.pad_token_id is None: + self.logger.warning('pad_token_id is not set for the tokenizer. ' + 'Using eos_token_id as pad_token_id.' + f'Which is {self.tokenizer.eos_token_id}') + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + + # A patch for llama when batch_padding = True + if 'decapoda-research/llama' in path or \ + (tokenizer_path and + 'decapoda-research/llama' in tokenizer_path): + self.logger.warning('We set new pad_token_id for LLaMA model') + # keep consistent with official LLaMA repo + # https://github.com/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb # noqa + self.tokenizer.bos_token = '' + self.tokenizer.eos_token = '' + self.tokenizer.pad_token_id = 0 + + def _load_model(self, + path: str, + model_kwargs: dict, + peft_path: Optional[str] = None): + from transformers import BertLMHeadModel + + model_kwargs.setdefault('torch_dtype', torch.float16) + self.model = BertLMHeadModel.from_pretrained(path, **model_kwargs) + if peft_path is not None: + from peft import PeftModel + self.model = PeftModel.from_pretrained(self.model, + peft_path, + is_trainable=False) + self.model.eval() + + # A patch for llama when batch_padding = True + if 'decapoda-research/llama' in path: + self.model.config.bos_token_id = 1 + self.model.config.eos_token_id = 2 + self.model.config.pad_token_id = self.tokenizer.pad_token_id + + def generate(self, inputs: List[str], max_out_len: int, + **kwargs) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str]): A list of strings. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + kwargs = {**kwargs, **self.generate_kwargs} + + if self.batch_padding and len(inputs) > 1: + return self._batch_generate(inputs=inputs, + max_out_len=max_out_len, + **kwargs) + else: + return sum((self._single_generate( + inputs=[input_], max_out_len=max_out_len, **kwargs) + for input_ in inputs), []) + + def _batch_generate(self, inputs: List[str], max_out_len: int, + **kwargs) -> List[str]: + """Support for batch prompts inference. + + Args: + inputs (List[str]): A list of strings. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + if self.extract_pred_after_decode: + prompt_lens = [len(input_) for input_ in inputs] + + # step-1: tokenize the input with batch_encode_plus + tokens = self.tokenizer.batch_encode_plus(inputs, + padding=True, + truncation=True, + max_length=self.max_seq_len - + max_out_len) + tokens = { + k: torch.tensor(np.array(tokens[k]), device=self.model.device) + for k in tokens if k in ['input_ids', 'attention_mask'] + } + + # step-2: conduct model forward to generate output + outputs = self.model.generate(**tokens, + max_new_tokens=max_out_len, + **kwargs) + + if not self.extract_pred_after_decode: + outputs = outputs[:, tokens['input_ids'].shape[1]:] + + decodeds = self.tokenizer.batch_decode(outputs, + skip_special_tokens=True) + + if self.extract_pred_after_decode: + decodeds = [ + token[len_:] for token, len_ in zip(decodeds, prompt_lens) + ] + + return decodeds + + def _single_generate(self, inputs: List[str], max_out_len: int, + **kwargs) -> List[str]: + """Support for single prompt inference. + + Args: + inputs (List[str]): A list of strings. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + if self.extract_pred_after_decode: + prompt_lens = [len(input_) for input_ in inputs] + + input_ids = self.tokenizer(inputs, + truncation=True, + max_length=self.max_seq_len - + max_out_len)['input_ids'] + input_ids = torch.tensor(input_ids, device=self.model.device) + # To accommodate the PeftModel, parameters should be passed in + # key-value format for generate. + outputs = self.model.generate(input_ids=input_ids, + max_new_tokens=max_out_len, + **kwargs) + + if not self.extract_pred_after_decode: + outputs = outputs[:, input_ids.shape[1]:] + + decodeds = self.tokenizer.batch_decode(outputs, + skip_special_tokens=True) + + if self.extract_pred_after_decode: + decodeds = [ + token[len_:] for token, len_ in zip(decodeds, prompt_lens) + ] + + return decodeds + + def get_logits(self, inputs: List[str]): + + if self.batch_padding and len(inputs) > 1: + # batch inference + tokens = self.tokenizer(inputs, + padding=True, + truncation=True, + max_length=self.max_seq_len) + + tokens = { + k: torch.tensor(np.array(tokens[k]), device=self.model.device) + for k in tokens if k in ['input_ids', 'attention_mask'] + } + outputs = self.model(**tokens) + + else: + inputs = self.tokenizer( + inputs, + padding=False, + truncation=True, + max_length=self.max_seq_len) + input_ids = inputs['input_ids'] + # input_ids = torch.tensor(input_ids, device=self.model.device) + tokens = {'input_ids': input_ids} + + # outputs = self.model(input_ids, decoder_input_ids=input_ids) + outputs = self.model(**inputs, labels=input_ids) + return outputs[0], {'tokens': tokens} + + def get_ppl(self, + inputs: List[str], + mask_length: Optional[List[int]] = None) -> List[float]: + """Get perplexity scores given a list of inputs. + + Args: + inputs (List[str]): A list of strings. + mask_length (Optional[List[int]]): A list of mask lengths. If + provided, the perplexity scores will be calculated with the + first mask_length[i] tokens masked out. It's okay to skip + its implementation if advanced features in PPLInfernecer is + not needed. + + Returns: + List[float]: A list of perplexity scores. + """ + + if self.batch_padding and len(inputs) > 1: + assert self.tokenizer.pad_token + return self._get_ppl(inputs, mask_length=mask_length) + else: + return np.concatenate([ + self._get_ppl(inputs=[text], mask_length=mask_length) + for text in inputs + ]) + + def _get_ppl(self, + inputs: List[str], + mask_length: Optional[List[int]] = None) -> List[float]: + """Get perplexity scores given a list of inputs. + + Args: + inputs (List[str]): A list of strings. + mask_length (Optional[List[int]]): A list of mask lengths. If + provided, the perplexity scores will be calculated with the + first mask_length[i] tokens masked out. It's okay to skip + its implementation if advanced features in PPLInfernecer is + not needed. + + Returns: + List[float]: A list of perplexity scores. + """ + + outputs, inputs = self.get_logits(inputs) + shift_logits = outputs[..., :-1, :].contiguous() + + shift_labels = inputs['tokens']['input_ids'][..., 1:].contiguous() + + # if not self.tokenizer.pad_token_id: + # self.tokenizer.pad_token_id = 151643 # TODO: temporally measure!!! PLEASE FIX LATER!! + loss_fct = torch.nn.CrossEntropyLoss( + reduction='none', ignore_index=self.tokenizer.pad_token_id) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)).view(shift_labels.size()) + + if mask_length is not None: + mask = torch.zeros_like(shift_labels) # [batch,seqlen] + for i in range(len(mask)): + for j in range(mask_length[i] - 1, len(mask[i])): + mask[i][j] = 1 + loss = loss * mask + + lens = (inputs['tokens']['input_ids'] != + self.tokenizer.pad_token_id).sum(-1).cpu().numpy() + if mask_length is not None: + lens -= np.array(mask_length) + ce_loss = loss.sum(-1).cpu().detach().to(torch.float).cpu().numpy() / lens # FIXING ERROR BFloat16 unsupported + return ce_loss + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized strings. + + Args: + prompt (str): Input string. + + Returns: + int: Length of the input tokens + """ + return len(self.tokenizer.encode(prompt)) + + +@MODELS.register_module() +class BertCausalLM(Bert): + """Model wrapper around Bert CausalLM. + + Args: + path (str): The name or path to Bert's model. + hf_cache_dir: Set the cache dir to HF model cache dir. If None, it will + use the env variable HF_MODEL_HUB. Defaults to None. + max_seq_len (int): The maximum length of the input sequence. Defaults + to 2048. + tokenizer_path (str): The path to the tokenizer. Defaults to None. + tokenizer_kwargs (dict): Keyword arguments for the tokenizer. + Defaults to {}. + peft_path (str, optional): The name or path to the Bert's PEFT + model. If None, the original model will not be converted to PEFT. + Defaults to None. + tokenizer_only (bool): If True, only the tokenizer will be initialized. + Defaults to False. + model_kwargs (dict): Keyword arguments for the model, used in loader. + Defaults to dict(device_map='auto'). + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + batch_padding (bool): If False, inference with be performed in for-loop + without batch padding. + """ + + def _load_model(self, + path: str, + model_kwargs: dict, + peft_path: Optional[str] = None): + from transformers import AutoModelForCausalLM + + model_kwargs.setdefault('torch_dtype', torch.float16) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) + if peft_path is not None: + from peft import PeftModel + self.model = PeftModel.from_pretrained(self.model, + peft_path, + is_trainable=False) + self.model.eval()