zte ragas
NickLennonLiu committed Apr 17, 2024
1 parent ec877a5 commit 04c8190
Showing 9 changed files with 41 additions and 8 deletions.
configs/datasets/opseval/qa_gen.py (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ def get_qa_gen_datasets(dataset_name, path, langs=['zh'], qtypes=None):
inferencer=get_gen_inferencer(sc=False),
),
# eval_cfg=dict(evaluator=dict(type=BleuRougeEvaluator))
-eval_cfg=dict(evaluator=dict(type=OpsEvalGenQAEvaluator, language=lang), need_ragas=True)
+eval_cfg=dict(evaluator=dict(type=OpsEvalGenQAEvaluator, language=lang), need_ragas=True, num_gpus=4)
)
for shot_abbr, shot_hint_id, retriever_dict in zip(
['Zero-shot', '3-shot'],
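For reference, the eval_cfg this loop now emits looks roughly like the standalone sketch below; the evaluator type and language are copied from the config above, while the meaning of num_gpus (GPUs reserved for the RAGAS judge task) is an assumption, not something the commit states.

    # Hypothetical, self-contained illustration of the new eval_cfg shape.
    lang = 'zh'
    eval_cfg = dict(
        evaluator=dict(type='OpsEvalGenQAEvaluator', language=lang),
        need_ragas=True,  # route this dataset's outputs through the RAGAS judge
        num_gpus=4,       # assumed: GPU count requested for that judge task
    )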
configs/tests/test_ragas.py (4 changes: 2 additions & 2 deletions)
@@ -19,7 +19,7 @@
*owl_qa_ppl,
*rzy_qa_gen,
*rzy_qa_ppl,
-*zedx_qa_gen,
+# *zedx_qa_gen,
]

datasets = [
@@ -46,7 +46,7 @@
dataset['infer_cfg']['inferencer']['sc_size'] = 2
dataset['infer_cfg']['inferencer']['max_token_len'] = 200
dataset['eval_cfg']['sc_size'] = 2
-dataset['sample_setting'] = dict(sample_size=5) # !!!WARNING: Use for testing only!!!
+dataset['sample_setting'] = dict(sample_size=100) # !!!WARNING: Use for testing only!!!


infer = dict(
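sample_setting controls how many items of each dataset the test run keeps. The hook that applies it lives elsewhere in the repo; the sketch below is a hypothetical stand-in (function name and seeding invented) showing the intended effect of sample_size=100.

    import random

    def apply_sample_setting(rows, sample_setting):
        # Hypothetical helper: keep at most sample_size items for a quick test run.
        size = sample_setting.get('sample_size')
        if size is None or size >= len(rows):
            return rows
        random.seed(0)  # assumed: fixed seed so test subsets are reproducible
        return random.sample(rows, size)

    subset = apply_sample_setting(list(range(1000)), dict(sample_size=100))
    print(len(subset))  # 100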
configs/xz/runconfig.py (3 changes: 2 additions & 1 deletion)
@@ -74,7 +74,7 @@
zeroshot_datasets = []
fewshot_datasets = []

-for dataset in [*ceval_mc_ppl,*network_mc_ppl,*zte_mc_ppl,*owl_mc_ppl,*oracle_mc_ppl,*company_mc_ppl,*ceval_mc_gen,*network_mc_gen,*zte_mc_gen,*owl_mc_gen,*oracle_mc_gen,*company_mc_gen,*zedx_qa_gen,*zedx_qa_ppl]:
+for dataset in [*ceval_mc_ppl,*ceval_mc_gen,*zedx_qa_gen,*zedx_qa_ppl]:
# dataset['path'] = dataset['path'].replace('/mnt/mfs/opsgpt/evaluation','/mnt/home/opseval/evaluation/')
dataset['sample_setting'] = dict()
dataset['infer_cfg']['inferencer']['save_every'] = 8
@@ -131,5 +131,6 @@
runner=dict(
type=LocalRunner,
max_num_workers=8,
+max_workers_per_ragas=10,
task=dict(type=OpenICLEvalTask)),
)
configs/xz/runconfig_base.py (1 change: 1 addition & 0 deletions)
@@ -90,5 +90,6 @@
runner=dict(
type=LocalRunner,
max_num_workers=8,
+max_workers_per_ragas=10,
task=dict(type=OpenICLEvalTask)),
)
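Both run configs now forward the new max_workers_per_ragas knob to LocalRunner. A minimal sketch of the eval runner block, assuming the opencompass imports these configs already use elsewhere; the partitioner and other keys of the eval dict are omitted.

    from opencompass.runners import LocalRunner
    from opencompass.tasks import OpenICLEvalTask

    eval = dict(
        runner=dict(
            type=LocalRunner,
            max_num_workers=8,
            max_workers_per_ragas=10,  # eval tasks allowed to share one RAGAS judge
            task=dict(type=OpenICLEvalTask),
        ),
    )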
opencompass/openicl/icl_evaluator/opseval_gen_evaluator.py (2 changes: 2 additions & 0 deletions)
@@ -173,6 +173,8 @@ def clean_word(words):

def get_ragas_score(self, predictions, references, test_set) -> dict:
from opencompass.ragas.judge import calculate_score
+for ref, q in zip(references, test_set):
+assert ref == q['answer'], 'Reference and test set not match!'
reference = [{"id": idx, "question": question, "answer": ref}
for idx, (question, ref) in enumerate(zip(test_set['question'], references))]
answers = [{"id": idx, "question": question, "answer": ans}
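The new assertion guards against references drifting out of sync with the test set before the RAGAS payload is built. A toy, self-contained version of the check and of the id/question/answer records follows; in the evaluator itself, test_set is a column-indexable dataset rather than a list of dicts.

    test_set = [
        {'question': 'What does OSPF stand for?', 'answer': 'Open Shortest Path First'},
    ]
    references = ['Open Shortest Path First']
    predictions = ['Open Shortest Path First protocol']

    # New sanity check: each reference must equal the gold answer at the same index.
    for ref, row in zip(references, test_set):
        assert ref == row['answer'], 'Reference and test set not match!'

    reference = [{'id': idx, 'question': row['question'], 'answer': ref}
                 for idx, (row, ref) in enumerate(zip(test_set, references))]
    answers = [{'id': idx, 'question': row['question'], 'answer': ans}
               for idx, (row, ans) in enumerate(zip(test_set, predictions))]
    print(reference, answers)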
opencompass/ragas/config.py (13 changes: 13 additions & 0 deletions)
@@ -51,6 +51,19 @@ def load_llm(ragas_config: dict) -> BaseLanguageModel:

return ChatTongyi(model=models_config.get('llm_model', 'qwen1.5-72b-chat'))

+elif llm_type == 'vllm':
+
+from langchain_community.llms import VLLM
+
+llm = VLLM(model="/mnt/tenant-home_speed/gaozhengwei/projects/LLM/models/Qwen/Qwen1.5-72B-Chat",
+trust_remote_code=True,
+vllm_kwargs={
+# "tensor_parallel_size": 4,
+"gpu_memory_utilization": 0.8,
+"max_model_len": 2048,
+}
+)
+return llm

logger.error(f'Unsupported LLM model: {llm_type}')
sys.exit(1)
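A minimal usage sketch of the new vLLM branch, mirroring the constructor arguments added above; the local model path is swapped for a generic Hugging Face id as an assumption, and tensor_parallel_size stays commented out as in the commit.

    from langchain_community.llms import VLLM

    llm = VLLM(
        model='Qwen/Qwen1.5-72B-Chat',  # assumption: any local or HF checkpoint path works here
        trust_remote_code=True,
        vllm_kwargs={
            # 'tensor_parallel_size': 4,
            'gpu_memory_utilization': 0.8,
            'max_model_len': 2048,
        },
    )
    print(llm.invoke('In one sentence, what does answer correctness measure?'))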
opencompass/ragas/judge.py (9 changes: 8 additions & 1 deletion)
@@ -79,8 +79,13 @@ def calculate_score(reference: list[dict], answers: list[dict], ragas_config: di
gt_df = pd.DataFrame(reference)
preds_df = pd.DataFrame(validate_and_format_answers(answers))
data = preprocess_data(gt_df, preds_df)
+# data.to_csv("/mnt/home/lyh/ragas_data.csv")
res_df = compute_scores(data, ragas_config)
+# res_df.to_csv("/mnt/home/lyh/ragas_res.csv")
detail = res_df.to_dict(orient='records')
+# with open("/mnt/home/lyh/ragas_detail.json", 'w') as f:
+# import json
+# json.dump(detail, f, indent=4, ensure_ascii=False)

overall_score = sum([item['score'] for item in detail]) / len(detail)
accuracy = sum([item['correct'] for item in detail]) / len(detail)
@@ -163,7 +168,9 @@ def compute_scores(df: pd.DataFrame, ragas_config: dict) -> list[dict]:
],
llm=load_llm(ragas_config),
embeddings=load_embeddings(ragas_config),
-run_config=RunConfig(max_workers=judge_config.get('max_workers', 16)),
+run_config=RunConfig(max_workers=judge_config.get('max_workers', 16),
+timeout=judge_config.get('timeout', 300),
+max_wait=judge_config.get('max_wait', 300)),
callbacks=callbacks,
raise_exceptions=False
)
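The judge now exposes timeout and max_wait alongside max_workers; all three are read from judge_config with the same defaults as the .get() calls above. A minimal sketch of the resulting ragas RunConfig:

    from ragas.run_config import RunConfig

    judge_config = {'max_workers': 16, 'timeout': 300, 'max_wait': 300}

    run_config = RunConfig(
        max_workers=judge_config.get('max_workers', 16),
        timeout=judge_config.get('timeout', 300),    # seconds allowed per judge call
        max_wait=judge_config.get('max_wait', 300),  # cap on retry back-off, in seconds
    )
    print(run_config)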
opencompass/ragas/metric.py (11 changes: 10 additions & 1 deletion)
@@ -108,6 +108,10 @@ def _compute_statement_presence(self, prediction: t.Any) -> float:
gt_keywords, overlapping_keywords = [
item if isinstance(item, list) else np.nan for item in prediction
]
+if gt_keywords is None or (type(gt_keywords) == float and np.isnan(gt_keywords)):
+logger.warning('[gt_keywords] gt_keywords is nan!')
+gt_keywords = []
+logger.warning(f"[gt_keywords] {gt_keywords}")
gt_keywords = [k.lower() for k in gt_keywords]
overlapping_keywords = [k.lower() for k in overlapping_keywords]
overlapping_keywords = [k for k in overlapping_keywords if self.match(gt_keywords, k)]
@@ -136,13 +140,18 @@ async def _ascore(self, row: dict, callbacks: t.Any, is_async: bool) -> float:

q, a, g = row["question"], row["answer"], row["ground_truth"]
p_value = self.correctness_prompt.format(question=q, ground_truth=g, answer=a)

+# TODO: add chat_template
+p_value.prompt_str = '<|im_start|>system\nYou are a helpful assistant.<|im_end|><|im_start|>user\n' + p_value.prompt_str + '<|im_end|><|im_start|>assistant\n'
+
is_statement_present = await self.llm.generate(
-p_value, callbacks=callbacks, is_async=is_async
+p_value, callbacks=callbacks, is_async=is_async, stop=['<|im_end|>', '<|endoftext|>']
)

prediction = await json_loader.safe_load(
is_statement_present.generations[0][0].text, self.llm, is_async=is_async
)
+logger.warning(f"\n-------------------------------------\n[prompt] {p_value}\n[prediction] {prediction}")
f1_score = self._compute_statement_presence(prediction)

if self.weights[1] == 0:
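Two of the additions above are easier to read in isolation: the nan guard for gt_keywords, and the Qwen-style ChatML wrapping with matching stop tokens. A minimal sketch, with a plain string standing in for the ragas prompt object whose prompt_str the commit mutates:

    import math

    def normalize_gt_keywords(gt_keywords):
        # Guard: a missing keyword list arrives as None or float('nan'); treat it as empty.
        if gt_keywords is None or (isinstance(gt_keywords, float) and math.isnan(gt_keywords)):
            gt_keywords = []
        return [k.lower() for k in gt_keywords]

    def to_chatml(user_prompt: str) -> str:
        # Ad-hoc ChatML wrapping applied to the correctness prompt before generation.
        return ('<|im_start|>system\nYou are a helpful assistant.<|im_end|>'
                '<|im_start|>user\n' + user_prompt + '<|im_end|><|im_start|>assistant\n')

    STOP_TOKENS = ['<|im_end|>', '<|endoftext|>']  # passed as stop= to llm.generate

    print(normalize_gt_keywords(float('nan')))  # []
    print(to_chatml('Score the answer against the ground truth.').splitlines()[0])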
opencompass/runners/local.py (4 changes: 2 additions & 2 deletions)
@@ -51,7 +51,7 @@ def __init__(self,
max_num_workers: int = 16,
debug: bool = False,
max_workers_per_gpu: int = 1,
-max_workers_per_ragas: int = 1,
+max_workers_per_ragas: int = 10,
lark_bot_url: str = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
@@ -129,7 +129,7 @@ def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
print('DEBUG: ', gpus)

# ragas ports !AD HOC!
-all_ragas_ids = [0,1]
+all_ragas_ids = [0]
ragases = np.zeros(max(all_ragas_ids)+1, dtype=np.uint)
ragases[all_ragas_ids] = self.max_workers_per_ragas
ragas_lock = np.zeros(1)
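A plausible reading of the ad-hoc RAGAS bookkeeping: one capacity counter per judge endpoint, decremented when an eval task attaches and restored when it finishes. With all_ragas_ids = [0] and the new default of 10 workers per judge, a single endpoint serves up to 10 concurrent tasks. The sketch below is a hypothetical reconstruction of that accounting, with function names invented for illustration, not the runner's actual code.

    import numpy as np

    max_workers_per_ragas = 10
    all_ragas_ids = [0]  # single judge endpoint after this commit

    ragases = np.zeros(max(all_ragas_ids) + 1, dtype=np.uint)
    ragases[all_ragas_ids] = max_workers_per_ragas

    def acquire_ragas_slot():
        # Pick any judge that still has spare capacity; the policy here is illustrative only.
        free = np.where(ragases > 0)[0]
        if len(free) == 0:
            return None
        ragases[int(free[0])] -= 1
        return int(free[0])

    def release_ragas_slot(idx):
        ragases[idx] += 1

    print(acquire_ragas_slot(), ragases)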
