Commit c0a398f

[fix] pass@k and self-consistency
huyiwen committed Jul 26, 2024
1 parent 310b3bd commit c0a398f
Showing 3 changed files with 58 additions and 40 deletions.
76 changes: 55 additions & 21 deletions utilization/dataset/dataset.py
@@ -220,10 +220,10 @@ def __len__(self):
         return len(self.evaluation_instances)
 
     def __getitem__(self, idx):
-        return self.evaluation_instances[idx]
+        return deepcopy(self.evaluation_instances[idx])
 
     def __iter__(self):
-        yield from self.evaluation_instances
+        yield from deepcopy(self.evaluation_instances)
 
     def format_instance(self, instance: dict) -> dict:
         r"""Format the dataset instance into task format. See [docs](https://github.com/RUCAIBox/LLMBox/blob/main/docs/utilization/how-to-customize-dataset.md#formating-the-instances) for more details.
@@ -832,6 +832,17 @@ def __repr__(self):
 
 
 class DatasetCollection(torch.utils.data.Dataset):
+    r"""The dataset collection class that combines multiple datasets into one.
+
+    Args:
+        - datasets: A dictionary of dataset instances. The keys are the dataset names and the values are the dataset instances.
+
+    Examples:
+        Assume a DatasetCollection composed of two datasets, `sub1` and `sub2`, each with a different number of evaluation instances.
+        - Two subsets: `[sub1, sub2]`
+        - Two subsets with self-consistency = 3: `[sub1, sub1, sub1, sub2, sub2, sub2]`
+        - Two subsets with normalization: `[sub1, sub1-norm, sub2, sub2-norm]`
+    """
 
     def __init__(self, datasets: Dict[str, Dataset]):
         super().__init__()
@@ -951,9 +962,6 @@ def set_subset(l: dict):
         except Exception as e:
             logger.warning(f"Failed to log predictions: {e}")
 
-    def post_processing(self, predictions: List[Union[str, float]]):
-        return sum((d.post_processing(p) for d, p in zip(self._datasets, self._split_by_subset(predictions))), [])
-
     def __getitem__(self, idx):
         if self.args.continue_from:
             idx += self.args.continue_from
@@ -975,14 +983,39 @@ def __iter__(self):
     def __getattr__(self, attr):
         return getattr(self._datasets[self._cur_idx], attr)
 
-    def calculate_metric(self, predictions) -> Tuple[Dict[str, Dict[str, float]], List[Dict[str, List[float]]]]:
-        results = OrderedDict()
+    def calculate_metric(self, raw_predictions: List[Union[str, float]]) -> Dict[str, Dict[str, float]]:
+        r"""Post-process predictions and calculate the metric scores."""
+
+        metric_results = OrderedDict()
+        predictions = []
+        agg_predictions = []
         score_lists = []
-        splitted = self._split_by_subset(predictions, option_num=False, normalization=False, sample_num=False)
-        grouped_display_names = defaultdict(list)  # group by dataset
-        for n, d, p in zip(self.display_names, self._datasets, splitted):
-            subset_results, score_list = d.calculate_metric(p)
-            results.update(subset_results)
+        grouped_display_names = defaultdict(list)
+
+        for n, d, p in zip(self.display_names, self._datasets, self._split_by_subset(raw_predictions)):
+            # post process
+            preds = d.post_processing(p)
+
+            # aggregate self-consistency or pass@k
+            step = d.len(option_num=False, sample_num=False, normalization=False)
+            if self.args.pass_at_k:
+                # [inst1, inst2, inst1, inst2] -> [[inst1, inst1], [inst2, inst2]]
+                agg_preds = [preds[i::step] for i in range(step)]
+            elif len(preds) // step > 1:
+                from statistics import mode
+
+                # [inst1, inst2, inst1, inst2] -> [mode([inst1, inst1]), mode([inst2, inst2])]
+                agg_preds = [mode(preds[i::step]) for i in range(step)]
+            else:
+                # [inst1, inst2]
+                agg_preds = preds
+
+            predictions.extend(preds)
+            agg_predictions.extend(agg_preds)
+
+            # calculate metric
+            subset_results, score_list = d.calculate_metric(agg_preds)
+            metric_results.update(subset_results)
             score_lists.append(score_list)
             grouped_display_names[d.dataset_name].append(n)
 
@@ -995,19 +1028,20 @@ def calculate_metric(self, predictions) -> Tuple[Dict[str, Dict[str, float]], List[Dict[str, List[float]]]]:
                     # skip if not all subsets of a category are available
                     continue
                 fstr = f"{name}[{cat.title().replace('_', ' ')} Macro Average]"
-                results[fstr] = avg_metrics([results[n] for n in c])
+                metric_results[fstr] = avg_metrics([metric_results[n] for n in c])
 
             if name == "gaokao":
-                r, f = zip(*[(results[name + ":" + n], f) for n, f in GAOKAO_CHINESE_TASKS_SCORE.items()])
-                results[name + "[Chinese Weighted Average]"] = avg_metrics(r, f, average_method="weighted")
-                r, f = zip(*[(results[name + ":" + n], f) for n, f in GAOKAO_ENGLISH_TASKS_SCORE.items()])
-                results[name + "[English Weighted Average]"] = avg_metrics(r, f, average_method="weighted")
-                r, f = zip(*[(results[name + ":" + n], f) for n, f in GAOKAO_TASKS_SCORE.items()])
-                results[name + "[Weighted Average]"] = avg_metrics(r, f, average_method="weighted")
+                r, f = zip(*[(metric_results[name + ":" + n], f) for n, f in GAOKAO_CHINESE_TASKS_SCORE.items()])
+                metric_results[name + "[Chinese Weighted Average]"] = avg_metrics(r, f, average_method="weighted")
+                r, f = zip(*[(metric_results[name + ":" + n], f) for n, f in GAOKAO_ENGLISH_TASKS_SCORE.items()])
+                metric_results[name + "[English Weighted Average]"] = avg_metrics(r, f, average_method="weighted")
+                r, f = zip(*[(metric_results[name + ":" + n], f) for n, f in GAOKAO_TASKS_SCORE.items()])
+                metric_results[name + "[Weighted Average]"] = avg_metrics(r, f, average_method="weighted")
 
-            results[name + "[Marco Average]"] = avg_metrics([r for k, r in results.items() if k.startswith(name + ":")])
+            metric_results[name + "[Marco Average]"] = avg_metrics([r for k, r in metric_results.items() if k.startswith(name + ":")])
 
-        return results, score_lists
+        self.log_final_results(raw_predictions, predictions, score_lists)
+        return metric_results
 
     def get_batch_sampler(self, reload_tokenizer: bool = False):
         if reload_tokenizer:
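The heart of the fix is the new `DatasetCollection.calculate_metric`: each sub-dataset now post-processes its own slice of the raw predictions and regroups the repeated samples before scoring, keeping every sample per instance for pass@k and taking a majority vote (`mode`) for self-consistency. The sketch below is not code from this commit; it is a minimal standalone illustration of the `preds[i::step]` regrouping, assuming the block-repeated ordering described in the new `DatasetCollection` docstring (the helper name `aggregate` is made up for the example).

# Standalone sketch (not LLMBox code): regrouping repeated samples before scoring.
from statistics import mode
from typing import List, Union

def aggregate(preds: List[str], step: int, pass_at_k: bool) -> List[Union[str, List[str]]]:
    """Regroup block-repeated predictions: with `step` distinct instances,
    preds[i::step] collects every sampled answer for instance i."""
    if pass_at_k:
        # keep all samples of each instance for a pass@k metric
        return [preds[i::step] for i in range(step)]
    if len(preds) // step > 1:
        # self-consistency: majority vote over each instance's samples
        return [mode(preds[i::step]) for i in range(step)]
    # a single sample per instance: nothing to aggregate
    return preds

# Two instances, three samples each, ordered [inst1, inst2, inst1, inst2, inst1, inst2]
samples = ["A", "C", "A", "B", "B", "B"]
print(aggregate(samples, step=2, pass_at_k=False))  # ['A', 'B']
print(aggregate(samples, step=2, pass_at_k=True))   # [['A', 'A', 'B'], ['C', 'B', 'B']]
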
20 changes: 2 additions & 18 deletions utilization/evaluator.py
@@ -1,5 +1,4 @@
 from logging import getLogger
-from statistics import mode
 from typing import Any, Callable, Dict, List, Optional
 
 from .load_dataset import load_datasets
@@ -129,24 +128,9 @@ def evaluate(self) -> Dict[str, Dict[str, float]]:
                 f"The number of results {len(raw_predictions)} should be equal to the number of samples in the dataset {self.dataset.len()}."
             )
 
-        # post processing and self-consistency
-        predictions = self.dataset.post_processing(raw_predictions)
-        if len(predictions) != self.dataset.len(option_num=False, normalization=False):
-            raise RuntimeError(
-                f"The number of results {len(predictions)} should be equal to the number of samples in the dataset {self.dataset.len(option_num=False, normalization=False)}."
-            )
-
-        step = self.dataset.len(option_num=False, sample_num=False, normalization=False)
-        if self.dataset_args.pass_at_k:
-            mode_predictions = [predictions[i::step] for i in range(step)]
-        elif len(predictions) // step > 1:
-            mode_predictions = [mode(predictions[i::step]) for i in range(step)]
-        else:
-            mode_predictions = predictions
         # calculate metric
-        metric_results, last_score_lists = self.dataset.calculate_metric(mode_predictions)
-        self.dataset.log_final_results(raw_predictions, predictions, last_score_lists)
+        metric_results = self.dataset.calculate_metric(raw_predictions)
 
         msg = f"Evaluation finished successfully:\nevaluation results: {self.dataset_args.evaluation_results_path}"
         for display_name, result in metric_results.items():
             if result is None:
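With this change, `Evaluator.evaluate` simply forwards the raw predictions; the post-processing, pass@k / self-consistency branching, and result logging removed above now live inside `DatasetCollection.calculate_metric`. The grouped samples kept on the pass@k path are what a pass@k metric consumes. For context, a minimal sketch of the standard unbiased pass@k estimator (Chen et al., 2021); this is illustrative only and not necessarily the exact formula LLMBox's metric implements.

# Unbiased pass@k estimator (Chen et al., 2021), shown for context only.
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Probability that at least one of k samples drawn from n is correct,
    given that c of the n samples are correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# One instance with n=3 samples, of which c=1 is correct:
print(pass_at_k(n=3, c=1, k=1))  # ~0.33
print(pass_at_k(n=3, c=1, k=2))  # ~0.67
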
2 changes: 1 addition & 1 deletion utilization/load_dataset.py
@@ -121,7 +121,7 @@ def get_subsets(
                 found_config = True
                 break
         except Exception as e:
-            logger.info(f"Failed when trying to get_dataset_config_names({path}): {e}")
+            logger.info(f"Failed when trying to get_dataset_config_names({path}): {e}. Trying another method...")
 
     logger.debug(f"get_dataset_config_names({path}): {s}")

