From 0d3e79adc7bade905c112160781fed0feef3f595 Mon Sep 17 00:00:00 2001 From: Googler Date: Wed, 21 Aug 2024 12:06:44 -0700 Subject: [PATCH] fix(components): Pass model name to eval_runner to process batch prediction's output as per the output schema of the model used Signed-off-by: Googler PiperOrigin-RevId: 665977093 --- .../model_evaluation/llm_evaluation/component.py | 3 +++ .../evaluation_llm_text_generation_pipeline.py | 1 + 2 files changed, 4 insertions(+) diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py index e0d118bcb26..fe362b230e9 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py @@ -32,6 +32,7 @@ def model_evaluation_text_generation( row_based_metrics: Output[Metrics], project: str, location: str, + model_name: str, evaluation_task: str = 'text-generation', target_field_name: str = 'instance.output_text', prediction_field_name: str = 'predictions.content', @@ -55,6 +56,7 @@ def model_evaluation_text_generation( Args: project: The GCP project that runs the pipeline component. location: The GCP region that runs the pipeline component. + model_name: The name of the model to be evaluated. evaluation_task: The task that the large language model will be evaluated on. The evaluation component computes a set of metrics relevant to that specific task. 
Currently supported tasks are: `summarization`, @@ -124,6 +126,7 @@ def model_evaluation_text_generation( machine_type=machine_type, image_uri=version.LLM_EVAL_IMAGE_TAG, args=[ + f'--model_name={model_name}', f'--evaluation_task={evaluation_task}', f'--target_field_name={target_field_name}', f'--prediction_field_name={prediction_field_name}', diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py index e9022932463..5a2f75e11b5 100644 --- a/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +++ b/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py @@ -151,6 +151,7 @@ def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-defaul eval_task = LLMEvaluationTextGenerationOp( project=project, location=location, + model_name=model_name, evaluation_task=evaluation_task, target_field_name=target_field_name, predictions_format=batch_predict_predictions_format,