Skip to content

Commit

Permalink
[STYLE] PR medium eval (ANNOTATOR_COLUMN) (#98)
Browse files Browse the repository at this point in the history
* ANNOTATOR_COLUMN all caps

* annotator_columns -> annotator_column

* is_randomize_output_order
  • Loading branch information
YannDubs authored Jul 25, 2023
1 parent af29371 commit f1a2f21
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/alpaca_eval/annotators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class BaseAnnotator(abc.ABC):
"""

DEFAULT_BASE_DIR = constants.EVALUATORS_CONFIG_DIR
ANNOTATOR_COLUMN = "annotator"
annotator_column = "annotator"
TMP_MISSING_ANNOTATION = -1
DEFAULT_ANNOTATION_TYPE = int

Expand All @@ -94,7 +94,7 @@ def __init__(
self.seed = seed
self.is_avoid_reannotations = is_avoid_reannotations
self.primary_keys = list(primary_keys)
self.all_keys = self.primary_keys + [self.ANNOTATOR_COLUMN]
self.all_keys = self.primary_keys + [self.annotator_column]
self.other_keys_to_keep = list(other_keys_to_keep)
self.is_store_missing_annotations = is_store_missing_annotations
self.is_raise_if_missing_primary_keys = is_raise_if_missing_primary_keys
Expand Down Expand Up @@ -216,7 +216,7 @@ def _preprocess(self, to_annotate: utils.AnyData) -> pd.DataFrame:
df_to_annotate = df_to_annotate.drop_duplicates(subset=self.primary_keys)

        # set the annotator for each example
df_to_annotate[self.ANNOTATOR_COLUMN] = df_to_annotate.apply(
df_to_annotate[self.annotator_column] = df_to_annotate.apply(
lambda x: utils.random_seeded_choice(
# we add "annotator" at the beginning to not use the same seed for all tasks
seed="annotator" + "".join(x[self.random_seed_key]) + str(self.seed),
Expand All @@ -236,7 +236,7 @@ def _annotate(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataF
df_annotated = df_to_annotate
for annotator in self.annotators.keys():
# only annotate examples that have not been annotated yet
curr_idcs = df_annotated[self.ANNOTATOR_COLUMN] == annotator
curr_idcs = df_annotated[self.annotator_column] == annotator
if self.annotation_key in df_annotated.columns:
curr_idcs &= df_annotated[self.annotation_key].isna()

Expand Down Expand Up @@ -571,11 +571,11 @@ def __call__(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataFr
### Private methods ###
def _search_fn_completion_parser(self, name: str) -> Callable:
"""Search for a completion parser by name."""
return getattr(completion_parsers, name)
return utils.get_module_attribute(completion_parsers, name)

def _search_processor(self, name: str) -> Type["processors.BaseProcessor"]:
"""Search for a Processor class by name."""
return getattr(processors, name)
return utils.get_module_attribute(processors, name)

def _get_prompt_template(self, prompt_template: utils.AnyPath):
return utils.read_or_return(self.base_dir / prompt_template)
Expand Down
1 change: 1 addition & 0 deletions src/alpaca_eval/annotators/pairwise_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def __init__(
**kwargs,
):
processors_to_kwargs = processors_to_kwargs or {}
self.is_randomize_output_order = is_randomize_output_order
if is_randomize_output_order:
            # switch output columns by default
processors_to_kwargs["RandomSwitchTwoColumnsProcessor"] = dict(
Expand Down
2 changes: 2 additions & 0 deletions src/alpaca_eval/completion_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from . import utils

__all__ = ["regex_parser", "lmsys_parser", "ranking_parser", "json_parser", "eval_parser"]


def regex_parser(completion: str, outputs_to_match: dict[str, Any]) -> list[Any]:
r"""Parse a single batch of completions, by returning a sequence of keys in the order in which outputs_to_match
Expand Down
1 change: 1 addition & 0 deletions src/alpaca_eval/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from . import utils

DUMMY_EXAMPLE = dict(instruction="1+1=", output_1="2", input="", output_2="3")
__all__ = ["RandomSwitchTwoColumnsProcessor", "PaddingForBatchesProcessor"]


class BaseProcessor(abc.ABC):
Expand Down
13 changes: 13 additions & 0 deletions src/alpaca_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,16 @@ def get_generator_name(name, model_outputs):
except:
name = "Current model"
return name


def get_module_attribute(module, func_name):
    """``getattr`` restricted to names exported in the module's ``__all__``.

    Acts as a safe lookup for user-supplied names (e.g. completion parsers or
    processors referenced by name in a config): only attributes explicitly
    exported via ``__all__`` may be retrieved.

    Args:
        module: module (or any object defining ``__all__``) to look the
            attribute up on.
        func_name: name of the attribute to retrieve.

    Returns:
        The attribute ``getattr(module, func_name)``.

    Raises:
        AttributeError: if the attribute exists but is not exported in
            ``__all__`` (access refused), or if it does not exist at all.
    """
    if func_name in module.__all__:
        return getattr(module, func_name)
    elif hasattr(module, func_name):
        # The attribute exists but is not exported => refuse access rather
        # than silently exposing a private/internal name.
        raise AttributeError(
            f"The function {func_name} is not allowed, add it to __all__ if needed."
            f" Available functions: {module.__all__}"
        )
    else:
        raise AttributeError(f"The function {func_name} does not exist. Available functions: {module.__all__}")

0 comments on commit f1a2f21

Please sign in to comment.