diff --git a/examples/post_training/supervised_fine_tune_client.py b/examples/post_training/supervised_fine_tune_client.py
index 705dfad..a8b3cc9 100644
--- a/examples/post_training/supervised_fine_tune_client.py
+++ b/examples/post_training/supervised_fine_tune_client.py
@@ -26,8 +26,9 @@ async def run_main(
     port: int,
     job_uuid: str,
     model: str,
-    model_descriptor: str,
-    use_https: bool = False,
+    run_eval: bool = False,
+    model_descriptor: Optional[str] = "null",
+    use_https: Optional[bool] = False,
     checkpoint_dir: Optional[str] = None,
     cert_path: Optional[str] = None,
 ):
@@ -92,6 +93,10 @@ async def run_main(
 
     print(f"finished the training job: {training_job.job_uuid}")
 
+    if not run_eval:
+        return
+
+    # register the finetuned model before eval
     response = client.models.register(
         model_id=f"{model_descriptor}-sft-{training_config['n_epochs']-1}",
         provider_id="meta-reference-inference",
@@ -103,6 +108,10 @@ async def run_main(
         f"registerd model {model_descriptor}-sft-{training_config['n_epochs']-1} successfully"
     )
 
+    # register the eval dataset; see https://llama-stack.readthedocs.io/en/latest/benchmark_evaluations/index.html
+    # for more details and examples
+    # this is just a simple example to showcase how to run eval on a finetuned model
+    # you can register your own eval task based on your needs
     response = client.datasets.register(
         dataset_id="post_training_eval",
         provider_id="huggingface-0",
@@ -159,8 +168,9 @@ def main(
     port: int,
     job_uuid: str,
     model: str,
-    model_descriptor: str,
-    use_https: bool = False,
+    run_eval: bool = False,
+    model_descriptor: Optional[str] = "null",
+    use_https: Optional[bool] = False,
     checkpoint_dir: Optional[str] = "null",
     cert_path: Optional[str] = None,
 ):
@@ -171,6 +181,7 @@ def main(
         port,
         job_uuid,
         model,
+        run_eval,
         model_descriptor,
         use_https,
         checkpoint_dir,
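
For context, a minimal driver sketch for the updated entry point with the new `run_eval` flag. The leading `host` parameter and all concrete values are assumptions, since they are not visible in these hunks; this is not part of the patch itself.

# Minimal sketch of invoking the updated run_main directly (assumed values).
import asyncio

from supervised_fine_tune_client import run_main  # assumes the module is importable

asyncio.run(
    run_main(
        host="localhost",                 # hypothetical server host (param not shown in hunks)
        port=8321,                        # hypothetical llama-stack port
        job_uuid="sft-example-job",       # hypothetical job id
        model="meta-llama/Llama-3.2-3B-Instruct",  # hypothetical model id
        run_eval=True,                    # new flag: when False, return right after training
        model_descriptor="llama-3.2-3b",  # names the checkpoint: <descriptor>-sft-<last epoch>
    )
)

Because `run_eval` defaults to False, existing training-only invocations keep their old behavior; passing `run_eval=True` opts in to registering the finetuned checkpoint and the eval dataset after training completes.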