From 7660e8a91cc192ff722e1efa04c2d88e1acf1292 Mon Sep 17 00:00:00 2001
From: Googler
Date: Tue, 23 Jul 2024 13:40:56 -0700
Subject: [PATCH] chore(components): Add target_field_name as input parameters
 to llm_evaluation_preprocessor component to support gemini model's input and
 output schema

Signed-off-by: Googler
PiperOrigin-RevId: 655287601
---
Review note: `--target_field_name=...` and `--model_name=...` are passed as two
separate elements of the container `args` list. Placing both f-strings inside
one parenthesized expression without a separating comma makes Python join them
into a single argv token, so the `--model_name` flag would never be seen as its
own argument. The post-image blob id on the `index` line below was not
recomputed after this correction.

 .../llm_evaluation_preprocessor/component.py         | 10 ++++++++++
 .../evaluation_llm_text_generation_pipeline.py       |  1 +
 2 files changed, 11 insertions(+)

diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py
index 5c2b6f2e2da..f102fe541ac 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation_preprocessor/component.py
@@ -39,6 +39,7 @@ def evaluation_dataset_preprocessor_internal(
     gcp_resources: dsl.OutputPath(str),
     input_field_name: str = 'input_text',
     role_field_name: str = 'role',
+    target_field_name: str = 'ground_truth',
     model_name: str = 'publishers/google/model/text-bison@002',
     display_name: str = 'llm_evaluation_dataset_preprocessor_component',
     machine_type: str = 'e2-highmem-16',
@@ -60,6 +61,8 @@ def evaluation_dataset_preprocessor_internal(
       contains the input prompts to the LLM.
     role_field_name: The field name of the role for input eval dataset
       instances that contains the input prompts to the LLM.
+    target_field_name: The field name of the target for input eval dataset
+      instances.
     model_name: Name of the model being used to create model-specific schemas.
     machine_type: The machine type of this custom job. If not set, defaulted to
       `e2-highmem-16`. More details:
@@ -98,7 +101,8 @@ def evaluation_dataset_preprocessor_internal(
               f'--gcs_source_uris={gcs_source_uris}',
               f'--input_field_name={input_field_name}',
               f'--role_field_name={role_field_name}',
+              f'--target_field_name={target_field_name}',
               f'--model_name={model_name}',
               f'--output_dirs={output_dirs}',
               '--executor_input={{$.json_escape[1]}}',
           ],
@@ -117,6 +121,7 @@ def llm_evaluation_dataset_preprocessor_graph_component(
     gcs_source_uris: List[str],
     input_field_name: str = 'input_text',
     role_field_name: str = 'role',
+    target_field_name: str = 'ground_truth',
     model_name: str = 'publishers/google/model/text-bison@002',
     display_name: str = 'llm_evaluation_dataset_preprocessor_component',
     machine_type: str = 'e2-standard-4',
@@ -137,6 +142,8 @@ def llm_evaluation_dataset_preprocessor_graph_component(
       contains the input prompts to the LLM.
     role_field_name: The field name of the role for input eval dataset
       instances that contains the input prompts to the LLM.
+    target_field_name: The field name of the target for input eval dataset
+      instances.
     model_name: Name of the model being used to create model-specific schemas.
     display_name: The name of the Evaluation job.
     machine_type: The machine type of this custom job. If not set, defaulted
@@ -176,6 +183,7 @@ def llm_evaluation_dataset_preprocessor_graph_component(
       ).output,
       input_field_name=input_field_name,
       role_field_name=role_field_name,
+      target_field_name=target_field_name,
       model_name=model_name,
       display_name=display_name,
       machine_type=machine_type,
diff --git a/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py
index 534e3afde0a..ba8fabe757b 100644
--- a/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py
+++ b/components/google-cloud/google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py
@@ -127,6 +127,7 @@ def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-defaul
           gcs_source_uris=batch_predict_gcs_source_uris,
           input_field_name=input_field_name,
           role_field_name=role_field_name,
+          target_field_name=target_field_name,
           model_name=model_name,
           machine_type=machine_type,
           service_account=service_account,