feat(components): Add role_field_name and model_name as input parameters to llm_evaluation_preprocessor component to support the Gemini model's input and output schema

Signed-off-by: Googler <[email protected]>
PiperOrigin-RevId: 641377116
Googler authored and DharmitD committed Jun 10, 2024
1 parent 8e4dfec commit 1cee088
Showing 3 changed files with 73 additions and 0 deletions.
55 changes: 55 additions & 0 deletions backend/src/v2/compiler/argocompiler/container.go
@@ -15,7 +15,9 @@
package argocompiler

import (
"fmt"
"os"
"strings"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/golang/protobuf/jsonpb"
@@ -361,6 +363,59 @@ func (c *workflowCompiler) addContainerExecutorTemplate(refName string) string {
			extendPodMetadata(&executor.Metadata, k8sExecCfg)
		}
	}
	caBundleCfgMapName := os.Getenv("EXECUTOR_CABUNDLE_CONFIGMAP_NAME")
	caBundleCfgMapKey := os.Getenv("EXECUTOR_CABUNDLE_CONFIGMAP_KEY")
	caBundleMountPath := os.Getenv("EXECUTOR_CABUNDLE_MOUNTPATH")
	if caBundleCfgMapName != "" && caBundleCfgMapKey != "" {
		caFile := fmt.Sprintf("%s/%s", caBundleMountPath, caBundleCfgMapKey)
		var certDirectories = []string{
			caBundleMountPath,
			"/etc/ssl/certs",
			"/etc/pki/tls/certs",
		}
		// Add to REQUESTS_CA_BUNDLE for the Python requests library.
		executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{
			Name: "REQUESTS_CA_BUNDLE",
			Value: caFile,
		})
		// For AWS utilities like the CLI and related packages.
		executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{
			Name: "AWS_CA_BUNDLE",
			Value: caFile,
		})
		// OpenSSL default cert file env variable.
		// https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_default_verify_paths.html
		executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{
			Name: "SSL_CERT_FILE",
			Value: caFile,
		})
		sslCertDir := strings.Join(certDirectories, ":")
		executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{
			Name: "SSL_CERT_DIR",
			Value: sslCertDir,
		})
		volume := k8score.Volume{
			Name: volumeNameCABUndle,
Check failure on line 398 in backend/src/v2/compiler/argocompiler/container.go (GitHub Actions: backend-tests, run-go-unittests): undefined: volumeNameCABUndle
			VolumeSource: k8score.VolumeSource{
				ConfigMap: &k8score.ConfigMapVolumeSource{
					LocalObjectReference: k8score.LocalObjectReference{
						Name: caBundleCfgMapName,
					},
				},
			},
		}

		executor.Volumes = append(executor.Volumes, volume)

		volumeMount := k8score.VolumeMount{
			Name: volumeNameCABUndle,
Check failure on line 411 in backend/src/v2/compiler/argocompiler/container.go (GitHub Actions: backend-tests, run-go-unittests): undefined: volumeNameCABUndle
			MountPath: caFile,
			SubPath: caBundleCfgMapKey,
		}

		executor.Container.VolumeMounts = append(executor.Container.VolumeMounts, volumeMount)

	}
	c.templates[nameContainerImpl] = executor
	c.wf.Spec.Templates = append(c.wf.Spec.Templates, *container, *executor)
	return nameContainerExecutor
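For context (not part of this commit): once the compiler sets these environment variables and mounts the ConfigMap key, code running inside the executor container can use the custom CA bundle without extra flags. A minimal Python sketch, assuming a pipeline step that calls an internal HTTPS endpoint (the URL and bundle path are placeholders):

# Illustrative only, not from this commit.
import os
import ssl

import requests

# `requests` honors REQUESTS_CA_BUNDLE automatically; this just shows which
# file the mounted ConfigMap key resolved to.
print('CA bundle for requests:', os.environ.get('REQUESTS_CA_BUNDLE'))

# Python's ssl module consults SSL_CERT_FILE / SSL_CERT_DIR for its defaults.
print('Default verify paths:', ssl.get_default_verify_paths())

# A TLS call to an endpoint signed by the custom CA should now verify without
# passing verify= explicitly. Placeholder URL.
resp = requests.get('https://internal.example.com/healthz')
resp.raise_for_status()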
@@ -38,6 +38,8 @@ def evaluation_dataset_preprocessor_internal(
    output_dirs: dsl.OutputPath(list),
    gcp_resources: dsl.OutputPath(str),
    input_field_name: str = 'input_text',
    role_field_name: str = 'role',
    model_name: str = 'publishers/google/models/text-bison@002',
    display_name: str = 'llm_evaluation_dataset_preprocessor_component',
    machine_type: str = 'e2-highmem-16',
    service_account: str = '',
@@ -56,6 +58,9 @@ def evaluation_dataset_preprocessor_internal(
    gcs_source_uris: A json escaped list of GCS URIs of the input eval dataset.
    input_field_name: The field name of the input eval dataset instances that
      contains the input prompts to the LLM.
    role_field_name: The field name of the role in the input eval dataset
      instances that contain the input prompts to the LLM.
    model_name: Name of the model being used to create model-specific schemas.
    machine_type: The machine type of this custom job. If not set, defaulted
      to `e2-highmem-16`. More details:
      https://cloud.google.com/compute/docs/machine-resource
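To make the two new fields concrete, a purely hypothetical single line of an eval dataset in JSONL form, using the defaults above (`input_text` for the prompt, `role` for the role, `output_text` for the reference); the exact schema the preprocessor expects is not shown in this diff:

# Hypothetical example instance, not taken from this commit.
import json

instance = {
    'input_text': 'Summarize the following article: ...',
    'role': 'user',
    'output_text': 'A short reference summary.',
}
print(json.dumps(instance))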
@@ -92,6 +97,8 @@ def evaluation_dataset_preprocessor_internal(
          f'--eval_dataset_preprocessor={True}',
          f'--gcs_source_uris={gcs_source_uris}',
          f'--input_field_name={input_field_name}',
          f'--role_field_name={role_field_name}',
          f'--model_name={model_name}',
          f'--output_dirs={output_dirs}',
          '--executor_input={{$.json_escape[1]}}',
      ],
@@ -109,6 +116,8 @@ def llm_evaluation_dataset_preprocessor_graph_component(
    location: str,
    gcs_source_uris: List[str],
    input_field_name: str = 'input_text',
    role_field_name: str = 'role',
    model_name: str = 'publishers/google/models/text-bison@002',
    display_name: str = 'llm_evaluation_dataset_preprocessor_component',
    machine_type: str = 'e2-standard-4',
    service_account: str = '',
@@ -126,6 +135,9 @@ def llm_evaluation_dataset_preprocessor_graph_component(
    gcs_source_uris: A list of GCS URIs of the input eval dataset.
    input_field_name: The field name of the input eval dataset instances that
      contains the input prompts to the LLM.
    role_field_name: The field name of the role in the input eval dataset
      instances that contain the input prompts to the LLM.
    model_name: Name of the model being used to create model-specific schemas.
    display_name: The name of the Evaluation job.
    machine_type: The machine type of this custom job. If not set, defaulted
      to `e2-standard-4`. More details:
@@ -163,6 +175,8 @@ def llm_evaluation_dataset_preprocessor_graph_component(
          input_list=gcs_source_uris
      ).output,
      input_field_name=input_field_name,
      role_field_name=role_field_name,
      model_name=model_name,
      display_name=display_name,
      machine_type=machine_type,
      service_account=service_account,
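A minimal usage sketch of the graph component above with the two new parameters, as it might appear in a caller's pipeline. The import path, bucket URI, and Gemini model resource name are assumptions for illustration, not values from this commit:

# Sketch only; adjust the import to wherever this component lives in your
# installed google-cloud-pipeline-components package (the module name below
# is hypothetical).
from kfp import compiler, dsl

from llm_evaluation_preprocessor_component import (  # hypothetical module name
    llm_evaluation_dataset_preprocessor_graph_component,
)


@dsl.pipeline(name='preprocessor-demo')
def preprocessor_demo(project: str, location: str = 'us-central1'):
  llm_evaluation_dataset_preprocessor_graph_component(
      project=project,
      location=location,
      gcs_source_uris=['gs://my-bucket/eval/instances.jsonl'],  # placeholder URI
      input_field_name='input_text',
      # New in this commit: pick the role field and the target model schema.
      role_field_name='role',
      model_name='publishers/google/models/gemini-1.0-pro',  # example model
  )


compiler.Compiler().compile(preprocessor_demo, 'preprocessor_demo.yaml')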
@@ -38,6 +38,7 @@ def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-defaul
    batch_predict_gcs_destination_output_uri: str,
    model_name: str = 'publishers/google/models/text-bison@002',
    evaluation_task: str = 'text-generation',
    role_field_name: str = 'role',
    input_field_name: str = 'input_text',
    target_field_name: str = 'output_text',
    batch_predict_instances_format: str = 'jsonl',
@@ -76,6 +77,7 @@ def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-defaul
    batch_predict_gcs_destination_output_uri: Required. The Google Cloud Storage location of the directory where the eval pipeline output is to be written to.
    model_name: The Model name used to run evaluation. Must be a publisher Model or a managed Model sharing the same ancestor location. Starting this job has no impact on any existing deployments of the Model and their resources.
    evaluation_task: The task that the large language model will be evaluated on. The evaluation component computes a set of metrics relevant to that specific task. Currently supported tasks are: `summarization`, `question-answering`, `text-generation`.
    role_field_name: The field name of the role in the input eval dataset instances that contain the input prompts to the LLM.
    input_field_name: The field name of the input eval dataset instances that contains the input prompts to the LLM.
    target_field_name: The field name of the eval dataset instance that contains an example reference text response. Alternatively referred to as the ground truth (or ground_truth_column) field. If not set, defaulted to `output_text`.
    batch_predict_instances_format: The format in which instances are given, must be one of the Model's supportedInputStorageFormats. Only "jsonl" is currently supported. For more details about this input config, see https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
@@ -124,6 +126,8 @@ def evaluation_llm_text_generation_pipeline( # pylint: disable=dangerous-defaul
      location=location,
      gcs_source_uris=batch_predict_gcs_source_uris,
      input_field_name=input_field_name,
      role_field_name=role_field_name,
      model_name=model_name,
      machine_type=machine_type,
      service_account=service_account,
      network=network,
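End to end, the new parameter can be passed when launching the prebuilt pipeline on Vertex AI Pipelines. A hedged sketch: the project, bucket, template path, and any parameter not visible in this diff (project, location, batch_predict_gcs_source_uris) are assumptions, not values from this commit:

# Sketch only; values below are placeholders.
from google.cloud import aiplatform

aiplatform.init(project='my-project', location='us-central1')

job = aiplatform.PipelineJob(
    display_name='llm-text-generation-eval',
    template_path='evaluation_llm_text_generation_pipeline.yaml',  # assumed compiled template
    parameter_values={
        'project': 'my-project',
        'location': 'us-central1',
        'batch_predict_gcs_source_uris': ['gs://my-bucket/eval/instances.jsonl'],
        'batch_predict_gcs_destination_output_uri': 'gs://my-bucket/eval/output',
        'model_name': 'publishers/google/models/gemini-1.0-pro',  # example Gemini model
        'role_field_name': 'role',  # new parameter from this commit
        'input_field_name': 'input_text',
        'target_field_name': 'output_text',
    },
)
job.submit()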
