fix: Zero shot and aggregation on Leaderboard (#1810)
* Made join_revision filter out no_revision_available when other revisions have been run on the task

* Fixed zero-shot filtering

* Fixed aggregation of task types

* Ran linting
x-tabdeveloping authored Jan 15, 2025
1 parent bcb2cd9 commit 0acc166
Showing 3 changed files with 18 additions and 10 deletions.
mteb/leaderboard/app.py: 14 changes (7 additions & 7 deletions)
@@ -412,7 +412,7 @@ def filter_models(
     compatibility,
     instructions,
     model_size,
-    zero_shot,
+    zero_shot_setting,
 ):
     lower, upper = model_size
     # Setting to None, when the user doesn't specify anything
@@ -432,12 +432,12 @@ def filter_models(
     tasks = mteb.get_tasks(tasks=task_select)
     models_to_keep = set()
     for model_meta in model_metas:
-        is_zero_shot = model_meta.is_zero_shot_on(tasks)
-        if is_zero_shot is None:
-            if zero_shot == "hard":
+        is_model_zero_shot = model_meta.is_zero_shot_on(tasks)
+        if is_model_zero_shot is None:
+            if zero_shot_setting == "hard":
                 continue
-        if not zero_shot:
-            if zero_shot != "off":
+        elif not is_model_zero_shot:
+            if zero_shot_setting != "off":
                 continue
         models_to_keep.add(model_meta.name)
     return list(models_to_keep)
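This hunk untangles two things that previously shared the name zero_shot: the UI setting (now zero_shot_setting) and the model's actual status (is_model_zero_shot); the old code branched on the setting string where it meant to branch on the model's status, so the non-zero-shot check never behaved as intended. A minimal standalone sketch of the fixed semantics, not the leaderboard's actual API (the function name, the models dict, and the "soft" value are assumptions; only "hard" and "off" appear in the diff):

def filter_zero_shot(models: dict[str, bool | None], zero_shot_setting: str) -> list[str]:
    # Sketch: True = zero-shot, False = not zero-shot, None = unknown.
    kept = []
    for name, is_model_zero_shot in models.items():
        if is_model_zero_shot is None:
            # Unknown status: only the strict "hard" setting drops the model.
            if zero_shot_setting == "hard":
                continue
        elif not is_model_zero_shot:
            # Known non-zero-shot: dropped unless filtering is turned off.
            if zero_shot_setting != "off":
                continue
        kept.append(name)
    return kept

models = {"model_a": True, "model_b": False, "model_c": None}
print(filter_zero_shot(models, "soft"))  # ['model_a', 'model_c']
print(filter_zero_shot(models, "hard"))  # ['model_a']
print(filter_zero_shot(models, "off"))   # ['model_a', 'model_b', 'model_c']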
@@ -460,7 +460,7 @@ def update_models(
         compatibility,
         instructions,
         model_size,
-        zero_shot,
+        zero_shot_setting=zero_shot,
     )
     elapsed = time.time() - start_time
     logger.info(f"update_models callback: {elapsed}s")
mteb/leaderboard/table.py: 2 changes (1 addition & 1 deletion)
@@ -88,7 +88,7 @@ def get_means_per_types(per_task: pd.DataFrame):
             dict(
                 model_name=model_name,
                 task_type=task_type,
-                score=scores[tasks].mean(),
+                score=scores[tasks].mean(skipna=False),
             )
         )
     return pd.DataFrame.from_records(records)
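This small change carries the aggregation fix: pandas' mean() skips NaN by default, so a model evaluated on only part of a task type's tasks got a mean over just the tasks it ran, inflating its per-type aggregate. With skipna=False the aggregate is NaN unless every task in the type has a score. A quick illustration (the task names are made up):

import pandas as pd

# A model missing "TaskB" would, by default, still get a per-type mean
# of 0.9 computed over the single task it actually ran.
scores = pd.Series([0.9, float("nan")], index=["TaskA", "TaskB"])
print(scores.mean())              # 0.9  (missing task silently ignored)
print(scores.mean(skipna=False))  # nan  (incomplete task type is flagged)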
mteb/load_results/benchmark_results.py: 12 changes (10 additions & 2 deletions)
@@ -260,8 +260,16 @@ def parse_version(version_str: str) -> Version | None:
 
 def keep_best(group: pd.DataFrame) -> pd.DataFrame:
     is_main_revision = group["revision"] == group["main_revision"]
-    if is_main_revision.sum() == 1:
-        return group[is_main_revision]
+    # If the main revision is present we select that
+    if is_main_revision.sum() > 0:
+        return group[is_main_revision].head(n=1)
+    unique_revisions = group["revision"].unique()
+    # Filtering out no_revision_available if other revisions are present
+    if (len(unique_revisions) > 1) and (
+        "no_revision_available" in unique_revisions
+    ):
+        group = group[group["revision"] != "no_revision_available"]
+    # If there are any not-NA mteb versions, we select the latest one
     if group["mteb_version"].notna().any():
         group = group.dropna(subset=["mteb_version"])
         group = group.sort_values("mteb_version", ascending=False)
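The net effect is that a placeholder no_revision_available result only survives when no concrete revision exists for the group. A toy run of the revision-selection core shown above (the mteb_version tail of the function is omitted, and the example frame is hypothetical; in the commit this runs per group inside join_revision):

import pandas as pd

# Revision-selection core of keep_best as introduced by this commit.
def keep_best(group: pd.DataFrame) -> pd.DataFrame:
    is_main_revision = group["revision"] == group["main_revision"]
    if is_main_revision.sum() > 0:
        return group[is_main_revision].head(n=1)
    unique_revisions = group["revision"].unique()
    if (len(unique_revisions) > 1) and ("no_revision_available" in unique_revisions):
        group = group[group["revision"] != "no_revision_available"]
    return group

# Hypothetical result group: no run matches the main revision, so the
# placeholder row is dropped and the concrete revision wins.
runs = pd.DataFrame(
    {
        "revision": ["no_revision_available", "abc123"],
        "main_revision": ["def456", "def456"],
        "mteb_version": [None, "1.29.0"],
    }
)
print(keep_best(runs)["revision"].tolist())  # ['abc123']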
