Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Medhelm #3038

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Medhelm #3038

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/helm/benchmark/adaptation/adapter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL: str = "multiple_choice_separate_original"
ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED: str = "multiple_choice_separate_calibrated"
ADAPT_RANKING_BINARY: str = "ranking_binary"

ADAPT_EHR_INSTRUCTION: str = "ehr_instruction"
ADAPT_MULTIPLE_CHOICE_SEPARATE_METHODS: List[str] = [
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
Expand Down
6 changes: 5 additions & 1 deletion src/helm/benchmark/adaptation/adapters/adapter_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from helm.benchmark.adaptation.adapter_spec import (
ADAPT_EHR_INSTRUCTION,
ADAPT_GENERATION,
ADAPT_GENERATION_MULTIMODAL,
ADAPT_LANGUAGE_MODELING,
Expand All @@ -21,6 +22,7 @@
from helm.benchmark.adaptation.adapters.multiple_choice_joint_adapter import MultipleChoiceJointAdapter
from helm.benchmark.adaptation.adapters.multiple_choice_separate_adapter import MultipleChoiceSeparateAdapter
from helm.benchmark.window_services.tokenizer_service import TokenizerService
from helm.benchmark.adaptation.adapters.ehr_instruction_adapter import EHRInstructionAdapter


class AdapterFactory:
Expand All @@ -32,7 +34,9 @@ def get_adapter(adapter_spec: AdapterSpec, tokenizer_service: TokenizerService)
method: str = adapter_spec.method
adapter: Adapter

if method == ADAPT_GENERATION:
if method == ADAPT_EHR_INSTRUCTION:
adapter = EHRInstructionAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_GENERATION:
adapter = GenerationAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_LANGUAGE_MODELING:
adapter = LanguageModelingAdapter(adapter_spec, tokenizer_service)
Expand Down
106 changes: 106 additions & 0 deletions src/helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import List, Optional

from helm.benchmark.adaptation.prompt import Prompt
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.scenarios.scenario import TRAIN_SPLIT, EHRInstructionInput, Instance
from helm.benchmark.window_services.window_service import EncodeResult
from helm.common.tokenization_request import TokenizationToken

from .generation_adapter import GenerationAdapter

# in the prompt templates for EHR instructions, this is the placeholder for the EHR part
# which we use to compute accurate tokenized sequence lengths
PROMPT_TEMPLATE_EHR_PLACEHOLDER = "{ehr}"


class EHRInstructionAdapter(GenerationAdapter):
"""
Each instance consists of the following:

EHRInstructionInput:
question: the question to answer or instruction to follow
ehr: the XML-tagged EHR to use as context to answer the question
prompt_template: a string template for how to combine the question + ehr

Reference output:
text: the 'golden' clinician response to the question

This Adapter combines the above into RequestStates with logic to truncate the EHR specifically
to fit in the context window with enough room for the instruction/question and the specified
amount of generated tokens.
"""

def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]:
"""
Main adaptation method which takes all instances and turns them into `RequestState` objects.
"""
# sanity check, since for now we assume that there are no training instances at all
if any(instance.split == TRAIN_SPLIT for instance in instances):
raise RuntimeError(f"Got train instances for {self.__class__.__name__} - expected only eval instances.")

# use superclass implementation here
return super().adapt(instances, parallelism)

def construct_prompt(
self,
train_instances: List[Instance], # unused
eval_instance: Instance,
include_output: bool, # unused
reference_index: Optional[int], # unused
) -> Prompt:
"""
Uses the instance to construct a prompt for a given eval instance.

Parameters
----------
eval_instance: Instance
the instance we wish to use to construct the prompt
"""
# start by simply getting the inputs
ehr_input: EHRInstructionInput = eval_instance.input # type: ignore
ehr_text: str = ehr_input.ehr
full_prompt_text: str = ehr_input.prompt_template.format(question=ehr_input.question, ehr=ehr_text)

# insert the question and see how many tokens we have so far
prompt_with_instr_no_ehr_placeholder = ehr_input.prompt_template.format(question=ehr_input.question, ehr="")
num_tokens_no_ehr = self.window_service.get_num_tokens(prompt_with_instr_no_ehr_placeholder)

# number of tokens we can allow the EHR part to be
target_ehr_num_tokens = (
self.window_service.max_request_length - self.adapter_spec.max_tokens - num_tokens_no_ehr
)

# round-trip tokenization to get the correct token length we need
# NOTE: we truncate from the left side so that the most recent pieces of the EHR are included in the context
# as opposed to the canonical way of truncating from the right. This is done to match the MedAlign method.
full_ehr_tokens: EncodeResult = self.window_service.encode(ehr_text, max_length=None, truncation=False)
truncated_ehr_tokens: List[TokenizationToken] = full_ehr_tokens.tokens[-target_ehr_num_tokens:]
ehr_truncated: str
ehr_truncated = self.window_service.decode(truncated_ehr_tokens)

# create the truncated prompt
truncated_prompt_text = ehr_input.prompt_template.format(question=ehr_input.question, ehr=ehr_truncated)
num_truncations = 1
while (
num_extra_tokens := self.adapter_spec.max_tokens
+ self.window_service.get_num_tokens(truncated_prompt_text)
- self.window_service.max_request_length
) > 0:
truncated_ehr_tokens = truncated_ehr_tokens[num_extra_tokens:]
ehr_truncated = self.window_service.decode(truncated_ehr_tokens)
truncated_prompt_text = ehr_input.prompt_template.format(question=ehr_input.question, ehr=ehr_truncated)
num_truncations += 1

# naively construct the full non-truncated prompt
prompt = Prompt(
global_prefix=self.adapter_spec.global_prefix,
global_suffix=self.adapter_spec.global_suffix,
instance_prefix=self.adapter_spec.instance_prefix,
substitutions=self.adapter_spec.substitutions,
instructions_block=self.adapter_spec.instructions,
train_instance_blocks=[],
eval_instance_block=full_prompt_text,
truncated_text=truncated_prompt_text,
)

return prompt
126 changes: 126 additions & 0 deletions src/helm/benchmark/metrics/comet_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import logging
from typing import List

import comet
from more_itertools import one
from torch import nn

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.common.hierarchical_logger import hlog
from helm.common.request import RequestResult

from .metric import Metric, MetricResult
from .metric_name import MetricName
from .metric_service import MetricService
from .statistic import Stat


class CometMetric(Metric):
"""COMET machine translation metric using a regression model.
The model takes a triplet of source sentence, translation, and reference
and computes a score in the range [0, 1] reflecting the quality of the predicted
translation.

Paper:
@inproceedings{rei-etal-2022-comet,
title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
author = "Rei, Ricardo and
C. de Souza, Jos{\'e} G. and
Alves, Duarte and
Zerva, Chrysoula and
Farinha, Ana C and
Glushkova, Taisiya and
Lavie, Alon and
Coheur, Luisa and
Martins, Andr{\'e} F. T.",
editor = {Koehn, Philipp and
Barrault, Lo{\"\i}c and
Bojar, Ond{\v{r}}ej and
Bougares, Fethi and
Chatterjee, Rajen and
Costa-juss{\`a}, Marta R. and
Federmann, Christian and
Fishel, Mark and
Fraser, Alexander and
Freitag, Markus and
Graham, Yvette and
Grundkiewicz, Roman and
Guzman, Paco and
Haddow, Barry and
Huck, Matthias and
Jimeno Yepes, Antonio and
Kocmi, Tom and
Martins, Andr{\'e} and
Morishita, Makoto and
Monz, Christof and
Nagata, Masaaki and
Nakazawa, Toshiaki and
Negri, Matteo and
N{\'e}v{\'e}ol, Aur{\'e}lie and
Neves, Mariana and
Popel, Martin and
Turchi, Marco and
Zampieri, Marcos},
booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
month = dec,
year = "2022",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.wmt-1.52",
}
"""

METRIC_NAME = "comet"

def __init__(self, task: str, model_name: str = "Unbabel/wmt22-comet-da", device: str = "cpu"):
self.model_name = model_name
self.comet_scorer: nn.Module = self._load_model(model_name)
self.num_gpus = 0 if device == "cpu" else 1

# suppress warnings from PyTorch Lightning which spams terminal
logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)

@staticmethod
def _load_model(model_name: str) -> nn.Module:
"""Load Comet model from the checkpoint.

Returns:
The loaded model.
"""
return comet.load_from_checkpoint(comet.download_model(model_name))

def evaluate(
self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
hlog(
f"Setting parallelism from {parallelism} to 1, since "
f"evaluating {self.__class__.__name__} with parallelism > 1 seg faults."
)
return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=1)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
"""Compute the COMET score for this instance"""
ref = one(r.output.text for r in request_state.instance.references)
src = request_state.instance.input.text

result = request_state.result
if not isinstance(result, RequestResult):
raise TypeError(f"Expected a valid result, but got {result}!")
mt = result.completions[0].text.strip()

# comet requires this exac5 format
data = [dict(ref=ref, src=src, mt=mt)]
output = self.comet_scorer.predict(data, gpus=self.num_gpus, progress_bar=False) # type: ignore
comet_score = output[0][0] # extract the actual score

metric_result = [Stat(MetricName(self.METRIC_NAME)).add(comet_score)]

return metric_result
36 changes: 36 additions & 0 deletions src/helm/benchmark/run_specs/classic_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
ADAPT_RANKING_BINARY,
AdapterSpec,
ADAPT_EHR_INSTRUCTION,
)
from helm.benchmark.adaptation.adapters.binary_ranking_adapter import BinaryRankingAdapter
from helm.benchmark.adaptation.common_adapter_specs import (
Expand Down Expand Up @@ -1218,6 +1219,41 @@ def get_medication_qa_spec() -> RunSpec:
)


def get_medalign_adapter_spec() -> AdapterSpec:
return AdapterSpec(
method=ADAPT_EHR_INSTRUCTION,
max_train_instances=0,
max_eval_instances=None,
num_outputs=1,
max_tokens=256, # MedAlign default number of generation tokens
)


def get_comet_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
return [MetricSpec(class_name="helm.benchmark.metrics.comet_metric.CometMetric", args=args)]


@run_spec_function("medalign")
def get_medalign_spec(prompt_template: str = "generic.txt") -> RunSpec:
from helm.common.gpu_utils import get_torch_device_name

scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.medalign_scenario.MedAlignScenario",
args={"prompt_template": prompt_template},
)

adapter_spec = get_medalign_adapter_spec()

metric_args = {"task": "medalign", "device": get_torch_device_name()}
return RunSpec(
name="medalign",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_summarization_metric_specs(metric_args) + get_comet_metric_specs(metric_args),
groups=["medalign", "med_helm"],
)


@run_spec_function("lextreme")
def get_lextreme_spec(subset: str) -> RunSpec:
from helm.benchmark.scenarios.lextreme_scenario import (
Expand Down
Loading