fixing bugs
NickLennonLiu committed Mar 14, 2024
1 parent a21df98 commit f10811f
Showing 10 changed files with 38 additions and 52 deletions.
14 changes: 7 additions & 7 deletions configs/commons/inferencers.py
@@ -1,19 +1,18 @@
 from opencompass.openicl.icl_inferencer import PPLInferencer, SCInferencer, CoTInferencer, GenInferencer
 
 
-def get_ppl_inferencer(save_every=20, fixidlist=dict(fix_id_list=None)):
+def get_ppl_inferencer(save_every=20):
     ppl_inferencer = dict(
         type=PPLInferencer,
         save_every=save_every,
         infer_type='PPL',
-        **fixidlist
+        # **fixidlist
     )
     return ppl_inferencer
 
 def get_gen_inferencer(save_every=20,
                        max_out_len=400,
                        sc_size=1,
-                       fixidlist=dict(fix_id_list=None),
                        generation_kwargs=dict(temperature=0.7),
                        sc=True,
                        ):
@@ -25,7 +24,6 @@ def get_gen_inferencer(save_every=20,
             infer_type='SC',
             sc_size=sc_size,
             max_out_len=max_out_len,
-            **fixidlist
         )
     else:
         inferencer = dict(
@@ -34,11 +32,14 @@ def get_gen_inferencer(save_every=20,
             generation_kwargs=generation_kwargs,
             infer_type='Gen',
             max_out_len=max_out_len,
-            **fixidlist
         )
     return inferencer
 
-def get_cot_inferencer(save_every=20, max_out_len=400, sc_size=1, fixidlist=dict(fix_id_list=None), generation_kwargs=dict(temperature=0.7), cot_prompts=None):
+def get_cot_inferencer(save_every=20,
+                       max_out_len=400,
+                       sc_size=1,
+                       generation_kwargs=dict(temperature=0.7),
+                       cot_prompts=None):
     inferencer = dict(
         type=CoTInferencer,
         save_every=save_every,
@@ -47,6 +48,5 @@ def get_cot_inferencer(save_every=20,
         infer_type='SC',
         sc_size=sc_size,
         max_out_len=max_out_len,
-        **fixidlist
     )
     return inferencer
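
The pattern behind these edits: few-shot example selection moves out of the inferencer factories and into the retriever config. A minimal usage sketch under that reading, with illustrative values and an import path assumed from the retriever file touched later in this commit:

from opencompass.openicl.icl_retriever import ZeroRetriever, FixKRetriever

# Few-shot selection now lives entirely in the retriever dict...
zero_shot = dict(type=ZeroRetriever)
three_shot = dict(type=FixKRetriever, fix_id_list=[0, 1, 2])

# ...so the inferencer factories no longer take a fixidlist argument.
infer_cfg = dict(
    retriever=three_shot,
    inferencer=get_gen_inferencer(save_every=20, sc_size=1),
)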
6 changes: 3 additions & 3 deletions configs/commons/templates.py
@@ -48,7 +48,7 @@ def mc_abcd_gen_prompt_template(prompt_hint, answer_hint):
                 role="HUMAN",
                 prompt=f'{prompt_hint}{{question}}\nA: {{A}}\nB: {{B}}\nC: {{C}}\nD: {{D}}\n{answer_hint}'
             ),
-            dict(role="BOT", prompt="{answer}")
+            # dict(role="BOT", prompt="{answer}")
         ],
     ),
     ice_token="</E>",
@@ -85,7 +85,7 @@ def mc_abcd_cot_prompt_template(prompt_hint, cot_think_hint):
                 role="HUMAN",
                 prompt=f'{prompt_hint}{{question}}\nA: {{A}}\nB: {{B}}\nC: {{C}}\nD: {{D}}\n{cot_think_hint}'
             ),
-            dict(role="BOT", prompt="{answer}")
+            # dict(role="BOT", prompt="{answer}")
         ]
     ),
     ice_token="</E>",
@@ -116,7 +116,7 @@ def qa_gen_prompt_template(prompt_hint, answer_hint):
                 role="HUMAN",
                 prompt=f'{prompt_hint}{{question}}\n{answer_hint}'
             ),
-            dict(role="BOT", prompt="{answer}")
+            # dict(role="BOT", prompt="{answer}")
         ],
     ),
     ice_token="</E>",
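
All three template changes are identical: the BOT turn carrying {answer} is commented out, leaving only the HUMAN turn, so the reference answer is no longer baked into the generation prompt. A small self-contained sketch of the resulting round structure, where prompt_hint and answer_hint stand in for the real hints:

# Hypothetical reconstruction of a qa_gen round after this commit.
prompt_hint = "Answer the following question.\n"
answer_hint = "Answer:"

qa_gen_round = [
    dict(
        role="HUMAN",
        prompt=f'{prompt_hint}{{question}}\n{answer_hint}'
    ),
    # dict(role="BOT", prompt="{answer}")  # dropped; the model generates it
]

print(qa_gen_round[0]['prompt'].format(question='What does SLA stand for?'))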
18 changes: 8 additions & 10 deletions configs/datasets/opseval/mc_gen.py
@@ -31,16 +31,15 @@ def get_mc_gen_datasets(dataset_name, path, langs=['zh'], qtypes=['single']):
         infer_cfg=dict(
             ice_template=mc_abcd_gen_ice_template(prompt_hint, answer_hint),
             prompt_template=mc_abcd_gen_prompt_template(prompt_hint, answer_hint),
-            retriever=dict(type=retriever, fix_id_list=fixidlist),
-            inferencer=get_gen_inferencer(sc_size=SAMPLE_SIZE, fixidlist=fixidlist),
+            retriever=retriever_dict,
+            inferencer=get_gen_inferencer(sc_size=SAMPLE_SIZE),
         ),
         eval_cfg=dict(evaluator=dict(type=OpsEvalGenMCEvaluator))
     )
-    for shot_abbr, fixidlist, shot_hint_id, retriever in zip(
+    for shot_abbr, shot_hint_id, retriever_dict in zip(
         ['Zero-shot', '3-shot'],
-        [dict(fix_id_list=None), dict(fix_id_list=[0, 1, 2])],
         [0, 1],
-        [ZeroRetriever, FixKRetriever]
+        [dict(type=ZeroRetriever), dict(type=FixKRetriever, fix_id_list=[0,1,2])]
     )
     for qtype, qtype_hint_id in zip(
         ['single'],
@@ -68,15 +67,14 @@ def get_mc_gen_datasets(dataset_name, path, langs=['zh'], qtypes=['single']):
         infer_cfg=dict(
             ice_template=mc_abcd_cot_ice_template(prompt_hint, cot_think_hint, cot_conclude_hint),
             prompt_template=mc_abcd_cot_prompt_template(prompt_hint, cot_think_hint),
-            retriever=dict(type=retriever, fix_id_list=fixidlist),
-            inferencer=get_cot_inferencer(sc_size=SAMPLE_SIZE, fixidlist=fixidlist, cot_prompts=cot_conclude_hint),
+            retriever=retriever_dict,
+            inferencer=get_cot_inferencer(sc_size=SAMPLE_SIZE, cot_prompts=cot_conclude_hint),
         ),
         eval_cfg=dict(evaluator=dict(type=OpsEvalGenMCEvaluator)))
-    for shot_abbr, fixidlist, shot_hint_id, retriever in zip(
+    for shot_abbr, shot_hint_id, retriever_dict in zip(
         ['Zero-shot', '3-shot'],
-        [dict(fix_id_list=None), dict(fix_id_list=[0,1,2])],
         [0, 1],
-        [ZeroRetriever, FixKRetriever]
+        [dict(type=ZeroRetriever), dict(type=FixKRetriever, fix_id_list=[0,1,2])]
     )
     for qtype, qtype_hint_id in zip(
         ['single', 'multiple'],
9 changes: 4 additions & 5 deletions configs/datasets/opseval/mc_ppl.py
@@ -32,15 +32,14 @@ def get_mc_ppl_datasets(dataset_name, path, langs=['zh'], qtypes=['single']):
         infer_cfg=dict(
             ice_template=mc_abcd_ppl_ice_template(prompt_hint, answer_hint),
             prompt_template=mc_abcd_ppl_prompt_template(prompt_hint, answer_hint),
-            retriever=dict(type=retriever, fix_id_list=fixidlist),
-            inferencer=get_ppl_inferencer(fixidlist=fixidlist),
+            retriever=retriever_dict,
+            inferencer=get_ppl_inferencer(),
         ),
         eval_cfg=dict(evaluator=dict(type=AccEvaluator)))
-    for shot_abbr, fixidlist, shot_hint_id, retriever in zip(
+    for shot_abbr, shot_hint_id, retriever_dict in zip(
         ['Zero-shot', '3-shot'],
-        [dict(fix_id_list=None), dict(fix_id_list=[0, 1, 2])],
         [0, 1],
-        [ZeroRetriever, FixKRetriever]
+        [dict(type=ZeroRetriever), dict(type=FixKRetriever, fix_id_list=[0,1,2])]
     )
     for qtype, qtype_hint_id in zip(
         ['single'],
9 changes: 4 additions & 5 deletions configs/datasets/opseval/qa_gen.py
@@ -31,16 +31,15 @@ def get_qa_gen_datasets(dataset_name, path, langs=['zh'], qtypes=None):
         infer_cfg=dict(
             ice_template=qa_gen_ice_template(prompt_hint, answer_hint),
             prompt_template=qa_gen_prompt_template(prompt_hint, answer_hint),
-            retriever=dict(type=retriever, fix_id_list=fixidlist),
-            inferencer=get_gen_inferencer(sc=False, fixidlist=fixidlist),
+            retriever=retriever_dict,
+            inferencer=get_gen_inferencer(sc=False),
         ),
         eval_cfg=dict(evaluator=dict(type=BleuRougeEvaluator))
     )
-    for shot_abbr, fixidlist, shot_hint_id, retriever in zip(
+    for shot_abbr, shot_hint_id, retriever_dict in zip(
         ['Zero-shot', '3-shot'],
-        [dict(fix_id_list=None), dict(fix_id_list=[0, 1, 2])],
         [0, 1],
-        [ZeroRetriever, FixKRetriever]
+        [dict(type=ZeroRetriever), dict(type=FixKRetriever, fix_id_list=[0,1,2])]
     )
     for lang, prompt_hint, answer_hint in zip(
         ['zh', 'en'],
9 changes: 4 additions & 5 deletions configs/datasets/opseval/qa_ppl.py
@@ -31,16 +31,15 @@ def get_qa_ppl_datasets(dataset_name, path, langs=['zh'], qtypes=None):
         infer_cfg=dict(
             ice_template=qa_ppl_ice_template(prompt_hint, answer_hint),
             prompt_template=qa_ppl_prompt_template(prompt_hint, answer_hint),
-            retriever=dict(type=retriever, fix_id_list=fixidlist),
-            inferencer=get_ppl_inferencer(fixidlist=fixidlist),
+            retriever=retriever_dict,
+            inferencer=get_ppl_inferencer(),
         ),
         eval_cfg=dict(evaluator=dict(type=AccEvaluator))
     )
-    for shot_abbr, fixidlist, shot_hint_id, retriever in zip(
+    for shot_abbr, shot_hint_id, retriever_dict in zip(
         ['Zero-shot', '3-shot'],
-        [dict(fix_id_list=None), dict(fix_id_list=[0, 1, 2])],
         [0, 1],
-        [ZeroRetriever, FixKRetriever]
+        [dict(type=ZeroRetriever), dict(type=FixKRetriever, fix_id_list=[0,1,2])]
     )
     for lang, prompt_hint, answer_hint in zip(
         ['zh', 'en'],
10 changes: 5 additions & 5 deletions configs/tests/test_merge.py
@@ -7,7 +7,7 @@
 # Datasets
 from ..datasets.opseval.datasets import owl_mc, owl_qa
 # Models
-from ..local_models.qwen.qwen import qwen1_5_0_5b_base
+from ..local_models.qwen.qwen import qwen1_5_4b_base
 from ..paths import ROOT_DIR
 
 datasets = [
@@ -16,16 +16,16 @@
 
 
 models = [
-    qwen1_5_0_5b_base
+    qwen1_5_4b_base
 ]
 
 for dataset in datasets:
     dataset['sample_setting'] = dict()
     dataset['infer_cfg']['inferencer']['save_every'] = 8
-    dataset['infer_cfg']['inferencer']['sc_size'] = 2
+    dataset['infer_cfg']['inferencer']['sc_size'] = 3
     dataset['infer_cfg']['inferencer']['max_token_len'] = 20
-    dataset['eval_cfg']['sc_size'] = 2
-    dataset['sample_setting'] = dict(sample_size=2)  # !!!WARNING: Use for testing only!!!
+    dataset['eval_cfg']['sc_size'] = 3
+    dataset['sample_setting'] = dict(sample_size=10)  # !!!WARNING: Use for testing only!!!
 
 
 infer = dict(
7 changes: 1 addition & 6 deletions opencompass/openicl/icl_inferencer/icl_cot_inferencer.py
@@ -54,7 +54,6 @@ def __init__(
         output_json_filepath: Optional[str] = './icl_inference_output',
         output_json_filename: Optional[str] = 'predictions',
         save_every: Optional[int] = None,
-        fix_id_list: Optional[List[int]] = None,
         sc_size: Optional[int] = 1,
         infer_type: Optional[str] = '',
         cot_prompts: Optional[List[str]] = [''],
@@ -73,7 +72,6 @@
         self.generation_kwargs = generation_kwargs
         self.max_out_len = max_out_len
         print(f'self.max_out_len: {self.max_out_len}')
-        self.fix_id_list = fix_id_list
         self.sc_size = sc_size
         self.cot_prompts = cot_prompts
 
@@ -96,10 +94,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate prompts for testing input
         prompt_list, reference_list = self.get_generation_prompt_list_from_retriever_indices(  # noqa
7 changes: 1 addition & 6 deletions opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
@@ -251,7 +251,6 @@ def __init__(
         output_json_filepath: Optional[str] = './icl_inference_output',
         output_json_filename: Optional[str] = 'predictions',
         labels: Optional[List] = None,
-        fix_id_list: Optional[List[int]] = None,
         **kwargs) -> None:
     super().__init__(
         model=model,
@@ -263,7 +262,6 @@
         )
 
         self.labels = labels
-        self.fix_id_list = fix_id_list
 
     def inference(self,
                   retriever: BaseRetriever,
@@ -285,10 +283,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate in-context examples for testing inputs
         # (OpsEval) and get reference
1 change: 1 addition & 0 deletions opencompass/openicl/icl_retriever/icl_fix_k_retriever.py
@@ -43,6 +43,7 @@ def __init__(self,
     def retrieve(self):
         """Retrieve the in-context example index for each test example."""
         num_idx = len(self.index_ds)
+        print(num_idx, self.fix_id_list)
         for idx in self.fix_id_list:
             assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
         rtr_idx_list = []
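
For reference, a simplified stand-in for what FixKRetriever.retrieve does after this change. The real class takes its fix_id_list at construction time, which is why the inferencers above can call retrieve() with no arguments; the sizes below are made up:

from typing import List

class FixKRetrieverSketch:
    """Toy model of FixKRetriever: return the same fixed in-context
    example indices for every test example."""

    def __init__(self, fix_id_list: List[int], num_index: int, num_test: int):
        self.fix_id_list = fix_id_list  # e.g. [0, 1, 2] for 3-shot
        self.num_index = num_index      # size of the in-context example pool
        self.num_test = num_test        # number of test examples

    def retrieve(self) -> List[List[int]]:
        # Same bounds check as the real retrieve() in this diff.
        for idx in self.fix_id_list:
            assert idx < self.num_index, f'Index {idx} is out of range of {self.num_index}'
        return [list(self.fix_id_list) for _ in range(self.num_test)]

retriever = FixKRetrieverSketch(fix_id_list=[0, 1, 2], num_index=50, num_test=4)
assert retriever.retrieve() == [[0, 1, 2]] * 4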
