Remove AdapterSpec from metrics #2244

Draft
wants to merge 1 commit into base: main
74 changes: 25 additions & 49 deletions src/helm/benchmark/metrics/basic_metrics.py
@@ -7,19 +7,13 @@
import numpy as np
import scipy
import calibration as cal
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.metrics.evaluate_reference_metrics import compute_reference_metrics
from helm.benchmark.metrics.efficiency_metrics import EfficiencyMetric

from helm.common.hierarchical_logger import hlog
from helm.common.request import Token, Sequence
from helm.benchmark.adaptation.adapters.adapter_factory import (
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
ADAPT_RANKING_BINARY,
)
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.metrics.metric import group_request_states_by_train_trial
from helm.benchmark.window_services.window_service import WindowService
from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
from helm.benchmark.window_services.tokenizer_service import TokenizerService
@@ -107,20 +101,18 @@ class InstancesPerSplitMetric(MetricInterface):
"""Report the average num_instances in each MetricContext across train_trials."""

def evaluate(
self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
self, request_states: List[RequestState], metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
adapter_spec = scenario_state.adapter_spec
global_stats: Dict[MetricName, Stat] = {}

for train_trial_index in range(adapter_spec.num_train_trials):
for trial_request_states in group_request_states_by_train_trial(request_states):
trial_stats: Dict[MetricName, Stat] = {} # Statistics just for this trial
# Group instances in this train_trial by context.
instances_per_metric_context: Dict[MetricContext, Set[Instance]] = defaultdict(set)
for request_state in scenario_state.request_states:
if request_state.train_trial_index == train_trial_index:
instances_per_metric_context[MetricContext.from_instance(request_state.instance)].add(
request_state.instance
)
for request_state in trial_request_states:
instances_per_metric_context[MetricContext.from_instance(request_state.instance)].add(
request_state.instance
)
for context, instance_set in instances_per_metric_context.items():
stat = Stat(MetricName("num_instances")).add(len(instance_set))
merge_stat(trial_stats, add_context(stat, context))
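
Note: `group_request_states_by_train_trial` is imported from `helm.benchmark.metrics.metric`, but its implementation is not part of this diff. A minimal sketch of what it is assumed to do, replacing the old loop over `adapter_spec.num_train_trials`:

```python
from collections import defaultdict
from typing import Dict, List

from helm.benchmark.adaptation.request_state import RequestState


def group_request_states_by_train_trial(request_states: List[RequestState]) -> List[List[RequestState]]:
    """Hypothetical sketch: bucket request states by train_trial_index so that
    metrics no longer need adapter_spec.num_train_trials to iterate over trials."""
    grouped: Dict[int, List[RequestState]] = defaultdict(list)
    for request_state in request_states:
        grouped[request_state.train_trial_index].append(request_state)
    # One list of request states per train trial, in trial order.
    return [grouped[trial_index] for trial_index in sorted(grouped)]
```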
Expand Down Expand Up @@ -151,25 +143,23 @@ def __repr__(self):

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
"""Compute all metrics."""
stats: List[Stat] = []
stats.extend(compute_request_state_metrics(self.efficiency_metric, adapter_spec, request_state, metric_service))
stats.extend(compute_request_state_metrics(self.efficiency_metric, request_state, metric_service))

if len(request_state.instance.references) > 0:
stats.extend(compute_reference_metrics(self.names, adapter_spec, request_state, metric_service))
stats.extend(compute_reference_metrics(self.names, request_state, metric_service))

stats.extend(compute_language_modeling_metrics(adapter_spec, request_state, metric_service))
stats.extend(compute_language_modeling_metrics(request_state, metric_service))

return stats

def evaluate_references(
self,
adapter_spec: AdapterSpec,
reference_request_states: List[RequestState],
metric_service: MetricService,
eval_cache_path: str,
@@ -218,37 +208,34 @@ def compute_logprob_and_length(request_state: RequestState, window_service: Wind
num_choices = len(references)

tokenizer_service: TokenizerService = metric_service
window_service: WindowService = WindowServiceFactory.get_window_service(
adapter_spec.model_deployment, tokenizer_service
)
model_deployment: str = reference_request_states[0].request.model_deployment
window_service: WindowService = WindowServiceFactory.get_window_service(model_deployment, tokenizer_service)
reference_stats: Dict[ReferenceKey, ReferenceStat] = {}
for request_state in reference_request_states:
assert request_state.reference_index is not None and request_state.request_mode is not None
reference_key = ReferenceKey(request_state.reference_index, request_state.request_mode)
reference_stats[reference_key] = compute_logprob_and_length(request_state, window_service)

if adapter_spec.method in [ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL, ADAPT_RANKING_BINARY]:
is_calibrated = any([request_state.request_mode == "calibration" for request_state in reference_request_states])
Contributor:
Why "any" here but using reference_request_states[0] to decide model_deployment?

If we are asserting in both cases that they are universal values, maybe we should write a helper to do that assertion?
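
One possible shape for the helper suggested above; the name, signature, and placement are illustrative only, not part of this PR:

```python
from typing import Callable, List, TypeVar

from helm.benchmark.adaptation.request_state import RequestState

T = TypeVar("T")


def get_unique_request_state_value(
    request_states: List[RequestState], getter: Callable[[RequestState], T]
) -> T:
    """Hypothetical helper: assert that a per-request attribute is identical
    across all request states and return that single value."""
    values = {getter(request_state) for request_state in request_states}
    assert len(values) == 1, f"Expected one unique value across request states, got: {values}"
    return values.pop()


# Possible usage in evaluate_references:
# model_deployment = get_unique_request_state_value(
#     reference_request_states, lambda request_state: request_state.request.model_deployment
# )
```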


if is_calibrated:
reference_scores = [
reference_stats[ReferenceKey(i, "original")].logprob
/ reference_stats[ReferenceKey(i, "original")].num_tokens
- reference_stats[ReferenceKey(i, "calibration")].logprob
for i in range(num_choices)
]
elif adapter_spec.method == ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED:
else:
reference_scores = [
reference_stats[ReferenceKey(i, "original")].logprob
- reference_stats[ReferenceKey(i, "calibration")].logprob
/ reference_stats[ReferenceKey(i, "original")].num_tokens
for i in range(num_choices)
]
else:
raise ValueError(f"Unknown adapter method: {adapter_spec.method}")

stats: List[Stat] = []

general_metrics: Dict[MetricName, Stat] = {}
for request_state in reference_request_states:
for stat in compute_request_state_metrics(
self.efficiency_metric, adapter_spec, request_state, metric_service
):
for stat in compute_request_state_metrics(self.efficiency_metric, request_state, metric_service):
merge_stat(general_metrics, stat)
stats.extend(general_metrics.values())
max_prob = np.max(scipy.special.softmax(reference_scores))
Expand Down Expand Up @@ -284,7 +271,6 @@ def derive_per_instance_stats(self, per_instance_stats: Dict[Instance, List[Stat

def compute_request_state_metrics(
efficiency_metric: EfficiencyMetric,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
) -> List[Stat]:
@@ -294,20 +280,14 @@ def compute_request_state_metrics(
stats: List[Stat] = []

stats.append(Stat(MetricName("num_references")).add(len(request_state.instance.references)))

# Copy from adapter spec
stats.append(Stat(MetricName("num_train_trials")).add(adapter_spec.num_train_trials))
Contributor:
Is this Stat not needed?
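
If the `num_train_trials` stat does turn out to be needed, one option (a sketch only, assuming the usual HELM module locations) is to derive it from the grouped request states at an aggregate level rather than copying it from the `AdapterSpec` for every request:

```python
from typing import List

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import group_request_states_by_train_trial
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.statistic import Stat


def num_train_trials_stat(request_states: List[RequestState]) -> Stat:
    # Count distinct train trials directly from the request states. This would
    # live in an aggregate-level method (e.g. evaluate), not in
    # compute_request_state_metrics, which only sees a single request state.
    num_train_trials = len(list(group_request_states_by_train_trial(request_states)))
    return Stat(MetricName("num_train_trials")).add(num_train_trials)
```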


stats.extend(efficiency_metric.compute_efficiency_metrics(adapter_spec, request_state, metric_service))
stats.extend(_compute_finish_reason_metrics(adapter_spec, request_state, metric_service))
stats.extend(_compute_truncation_metrics(adapter_spec, request_state, metric_service))
stats.extend(efficiency_metric.compute_efficiency_metrics(request_state, metric_service))
stats.extend(_compute_finish_reason_metrics(request_state, metric_service))
stats.extend(_compute_truncation_metrics(request_state, metric_service))

return stats


def _compute_finish_reason_metrics(
adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
) -> List[Stat]:
def _compute_finish_reason_metrics(request_state: RequestState, metric_service: MetricService) -> List[Stat]:
"""Record how often generation finished due to reaching token limit, stop token(s), or end of text"""
assert request_state.result is not None
sequence = request_state.result.completions[0]
@@ -327,9 +307,7 @@ def _compute_finish_reason_metrics(
]


def _compute_truncation_metrics(
adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
) -> List[Stat]:
def _compute_truncation_metrics(request_state: RequestState, metric_service: MetricService) -> List[Stat]:
"""
Record the number of training instances used in the prompt and whether
even the prompt needed to be truncated (once we hit zero training instances).
@@ -340,9 +318,7 @@
]


def compute_language_modeling_metrics(
adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
) -> List[Stat]:
def compute_language_modeling_metrics(request_state: RequestState, metric_service: MetricService) -> List[Stat]:
"""Compute the logprob and normalization factors for the first completion"""
assert request_state.result is not None
sequence = request_state.result.completions[0]
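
For context, callers that previously handed a metric the whole `ScenarioState` are presumably updated to pass its request states instead; the actual call-site change is not shown in this diff, so the following is only a sketch:

```python
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.metrics.metric import MetricInterface, MetricResult
from helm.benchmark.metrics.metric_service import MetricService


def run_metric(
    metric: MetricInterface,
    scenario_state: ScenarioState,
    metric_service: MetricService,
    eval_cache_path: str,
    parallelism: int,
) -> MetricResult:
    # Assumed caller update: pass scenario_state.request_states rather than scenario_state.
    return metric.evaluate(scenario_state.request_states, metric_service, eval_cache_path, parallelism)
```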
2 changes: 0 additions & 2 deletions src/helm/benchmark/metrics/cleva_harms_metrics.py
@@ -9,7 +9,6 @@
from helm.common.request import RequestResult
from helm.common.hierarchical_logger import hlog
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.metrics.cleva_metrics_helper import ChineseTokenizer
from helm.proxy.clients.perspective_api_client import PerspectiveAPIClientCredentialsError
from helm.common.general import ensure_file_downloaded, ensure_directory_exists
@@ -167,7 +166,6 @@ def __repr__(self):

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
7 changes: 2 additions & 5 deletions src/helm/benchmark/metrics/code_metrics.py
@@ -6,9 +6,7 @@

from helm.common.hierarchical_logger import hlog
from helm.common.request import RequestResult
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.scenarios.code_scenario import CodeReference
from . import code_metrics_helper
from .metric import Metric, MetricResult
@@ -60,17 +58,16 @@ def __init__(self, names, timeout):
# resource.setrlimit(resource.RLIMIT_AS, (MAXIMUM_MEMORY_BYTES, MAXIMUM_MEMORY_BYTES))

def evaluate(
self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
self, request_states: List[RequestState], metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
# Running with parallelism > 1 causes the run to get stuck.
hlog(
f"Setting parallelism from {parallelism} to 1, since evaluating code with parallelism > 1 isn't supported."
)
return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=1)
return super().evaluate(request_states, metric_service, eval_cache_path, parallelism=1)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
2 changes: 0 additions & 2 deletions src/helm/benchmark/metrics/copyright_metrics.py
@@ -5,7 +5,6 @@
from nltk.tokenize.treebank import TreebankWordTokenizer

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.scenarios.scenario import Reference
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import RequestResult
@@ -119,7 +118,6 @@ def __init__(self, name: str, normalize_by_prefix_length=False, normalize_newlin

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
16 changes: 5 additions & 11 deletions src/helm/benchmark/metrics/disinformation_metrics.py
@@ -10,7 +10,6 @@
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import RequestResult, Sequence
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from .metric import Metric
from .metric_name import MetricName
from .metric_service import MetricService
@@ -76,17 +75,14 @@ def _fetch_human_evaluation_results(eval_cache_path: str, file_name: str) -> Dic
return json.load(f)


def _compute_wedging_human_eval(
adapter_spec: AdapterSpec, request_state: RequestState, eval_cache_path: str
) -> List[Stat]:
def _compute_wedging_human_eval(request_state: RequestState, eval_cache_path: str) -> List[Stat]:
"""
Reads the file with the human evaluation results for the narrative wedging scenario, finds the annotations
for the instance currently being evaluated, and outputs the human evaluation metrics for that instance.
"""
results: List[Stat] = []
instance_first_line = request_state.instance.input.text.splitlines()[0]
human_evaluations = _fetch_human_evaluation_results(eval_cache_path, WEDGING_HUMAN_EVAL_FILE)
model_results = human_evaluations.get(adapter_spec.model_deployment)
model_results = human_evaluations.get(request_state.request.model_deployment)

if not model_results:
# Trying to evaluate a model we don't have annotations for
@@ -115,7 +112,6 @@ def _compute_wedging_human_eval(


def _compute_reiteration_human_eval(
adapter_spec: AdapterSpec,
request_state: RequestState,
eval_cache_path: str,
) -> List[Stat]:
@@ -125,7 +121,7 @@ def _compute_reiteration_human_eval(
"""
results: List[Stat] = []
human_evaluations = _fetch_human_evaluation_results(eval_cache_path, REITERATION_HUMAN_EVAL_FILE)
model_results = human_evaluations.get(adapter_spec.model_deployment)
model_results = human_evaluations.get(request_state.request.model_deployment)
if not model_results:
# Trying to evaluate a model we don't have annotations for
return results
@@ -152,7 +148,7 @@ def _compute_reiteration_human_eval(
"monte_carlo_entropy": _monte_carlo_entropy,
}

human_metric_fns: Dict[str, Callable[[AdapterSpec, RequestState, str], List[Stat]]] = {
human_metric_fns: Dict[str, Callable[[RequestState, str], List[Stat]]] = {
"wedging": _compute_wedging_human_eval,
"reiteration": _compute_reiteration_human_eval,
}
@@ -167,7 +163,6 @@ def __init__(self, name):

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
@@ -190,10 +185,9 @@ def __init__(self, name):

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
metrics = self._metric_fn(adapter_spec, request_state, eval_cache_path)
metrics = self._metric_fn(request_state, eval_cache_path)
return metrics
9 changes: 4 additions & 5 deletions src/helm/benchmark/metrics/dry_run_metrics.py
@@ -3,7 +3,6 @@

from helm.common.general import parallel_map
from helm.common.request import Request
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.statistic import Stat, merge_stat
from helm.benchmark.window_services.window_service import WindowService
@@ -58,7 +57,7 @@ def __repr__(self):

def evaluate(
self,
scenario_state: ScenarioState,
request_states: List[RequestState],
metric_service: MetricService,
eval_cache_path: str,
parallelism: int,
@@ -69,7 +68,7 @@ def evaluate(
processor = Processor(token_cost_estimator=self.token_cost_estimator, metric_service=metric_service)
results: List[List[Stat]] = parallel_map(
processor.process,
scenario_state.request_states,
request_states,
parallelism=parallelism,
)

@@ -81,7 +80,7 @@ def evaluate(
request_state.train_trial_index,
stats,
)
for request_state, stats in zip(scenario_state.request_states, results)
for request_state, stats in zip(request_states, results)
]

# Aggregate
Expand All @@ -90,6 +89,6 @@ def evaluate(
for stat in instance_stats:
merge_stat(stats, stat)

merge_stat(stats, Stat(MetricName("num_requests")).add(len(scenario_state.request_states)))
merge_stat(stats, Stat(MetricName("num_requests")).add(len(request_states)))

return MetricResult(list(stats.values()), per_instance_stats)
7 changes: 2 additions & 5 deletions src/helm/benchmark/metrics/efficiency_metrics.py
@@ -5,7 +5,6 @@

from helm.common.hierarchical_logger import hlog
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.window_services.window_service import WindowService
from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
from helm.benchmark.window_services.tokenizer_service import TokenizerService
@@ -59,9 +58,7 @@ def __init__(self):
with data_package.joinpath(TRAINING_EFFICIENCY_JSON_FILENAME).open("r") as f:
self.training_efficiency_dict = json.load(f)

def compute_efficiency_metrics(
self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
) -> List[Stat]:
def compute_efficiency_metrics(self, request_state: RequestState, metric_service: MetricService) -> List[Stat]:
"""Compute efficiency metrics for both inference and training.
For inference, we record both the actual runtime and an estimated idealized runtime
for the given request with an optimized software implementation run on A100 GPU(s),
@@ -89,7 +86,7 @@ def compute_efficiency_metrics(
# and calculate the number of tokens in the prompt.
tokenizer_service: TokenizerService = metric_service
window_service: WindowService = WindowServiceFactory.get_window_service(
adapter_spec.model_deployment, tokenizer_service
request_state.request.model_deployment, tokenizer_service
)
prompt: str = request_state.request.prompt
num_prompt_tokens: int = window_service.get_num_tokens(prompt)
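
Side note, not part of this PR: basic_metrics.py and efficiency_metrics.py now repeat the same per-request window-service lookup, which could be factored into a small shared helper such as the hypothetical one below:

```python
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.window_services.tokenizer_service import TokenizerService
from helm.benchmark.window_services.window_service import WindowService
from helm.benchmark.window_services.window_service_factory import WindowServiceFactory


def get_window_service_for_request(
    request_state: RequestState, tokenizer_service: TokenizerService
) -> WindowService:
    # Derive the model deployment from the request itself instead of the AdapterSpec.
    return WindowServiceFactory.get_window_service(
        request_state.request.model_deployment, tokenizer_service
    )
```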