Skip to content

Commit

Permalink
Allow aggregated tasks within benchmarks
Browse files Browse the repository at this point in the history
Fixes #1231
  • Loading branch information
KennethEnevoldsen committed Jan 17, 2025
1 parent 8bb9026 commit 60a8f0f
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mteb/load_results/task_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ def get_score_fast(self, splits: str | None, languages: str | None) -> float:
raise ValueError("No splits had scores for the specified languages.")
return val_sum / n_val

@classmethod
def from_validated(cls, **data) -> TaskResult:
return cls.model_construct(**data)

def __repr__(self) -> str:
return f"TaskResult(task_name={self.task_name}, scores=...)"

Expand All @@ -497,7 +501,7 @@ def only_main_score(self) -> TaskResult:
}
)
new_res = {**self.to_dict(), "scores": new_scores}
new_res = TaskResult(**new_res)
new_res = TaskResult.from_validated(**new_res)
return new_res

def validate_and_filter_scores(self, task: AbsTask | None = None) -> TaskResult:
Expand Down Expand Up @@ -548,5 +552,5 @@ def validate_and_filter_scores(self, task: AbsTask | None = None) -> TaskResult:
f"{task.metadata.name}: Missing splits {set(splits) - seen_splits}"
)
new_res = {**self.to_dict(), "scores": new_scores}
new_res = TaskResult(**new_res)
new_res = TaskResult.from_validated(**new_res)
return new_res

0 comments on commit 60a8f0f

Please sign in to comment.