Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
SLR722 committed Dec 20, 2024
1 parent 945d0a6 commit fbdeead
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions examples/post_training/supervised_fine_tune_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def run_main(
port: int,
job_uuid: str,
model: str,
model_descriptor: str,
use_https: bool = False,
checkpoint_dir: Optional[str] = None,
cert_path: Optional[str] = None,
Expand Down Expand Up @@ -92,13 +93,15 @@ async def run_main(
print(f"finished the training job: {training_job.job_uuid}")

response = client.models.register(
model_id=f"{model}-sft-{training_config['n_epochs']-1}",
model_id=f"{model_descriptor}-sft-{training_config['n_epochs']-1}",
provider_id="meta-reference-inference",
provider_model_id="null",
metadata={"llama_model": "meta-llama/Llama-3.2-3B-Instruct"},
metadata={"llama_model": f"{model}"},
)

print(f"registerd model {model}-sft-{training_config['n_epochs']-1} successfully")
print(
f"registerd model {model_descriptor}-sft-{training_config['n_epochs']-1} successfully"
)

response = client.datasets.register(
dataset_id="post_training_eval",
Expand All @@ -116,8 +119,7 @@ async def run_main(
},
)

if response:
print("registered dataset post_training_eval successfully")
print("registered dataset post_training_eval successfully")

eval_rows = client.datasetio.get_rows_paginated(
dataset_id="post_training_eval",
Expand All @@ -138,7 +140,7 @@ async def run_main(
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": f"{model}-sft-{training_config['n_epochs']-1}",
"model": f"{model_descriptor}-sft-{training_config['n_epochs']-1}",
"sampling_params": {
"temperature": 0.0,
"max_tokens": 4096,
Expand All @@ -157,13 +159,23 @@ def main(
port: int,
job_uuid: str,
model: str,
model_descriptor: str,
use_https: bool = False,
checkpoint_dir: Optional[str] = "null",
cert_path: Optional[str] = None,
):
job_uuid = str(job_uuid)
asyncio.run(
run_main(host, port, job_uuid, model, use_https, checkpoint_dir, cert_path)
run_main(
host,
port,
job_uuid,
model,
model_descriptor,
use_https,
checkpoint_dir,
cert_path,
)
)


Expand Down

0 comments on commit fbdeead

Please sign in to comment.