Skip to content

Commit

Permalink
temp commit
Browse files Browse the repository at this point in the history
  • Loading branch information
SLR722 committed Dec 20, 2024
1 parent 443bec6 commit 945d0a6
Showing 1 changed file with 59 additions and 50 deletions.
109 changes: 59 additions & 50 deletions examples/post_training/supervised_fine_tune_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,56 +91,65 @@ async def run_main(

print(f"finished the training job: {training_job.job_uuid}")

# response = client.datasets.register(
# dataset_id="post_training_eval",
# provider_id="huggingface",
# url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
# metadata={
# "path": "llamastack/evals",
# "name": "evals__simpleqa",
# "split": "train",
# },
# dataset_schema={
# "input_query": {"type": "string"},
# "expected_answer": {"type": "string"},
# "chat_completion_input": {"type": "chat_completion_input"},
# },
# )

# if response:
# print("registered dataset post_training_eval successfully")

# eval_rows = client.datasetio.get_rows_paginated(
# dataset_id="post_training_eval",
# rows_in_page=5,
# )

# client.eval_tasks.register(
# eval_task_id="torchtune::evals",
# dataset_id=f"post_training_eval",
# scoring_functions=["basic::regex_parser_multiple_choice_answer"],
# )

# response = client.eval.evaluate_rows(
# task_id="torchtune::evals",
# input_rows=eval_rows.rows,
# scoring_functions=["basic::regex_parser_multiple_choice_answer"],
# task_config={
# "type": "benchmark",
# "eval_candidate": {
# "type": "model",
# "model": "meta-llama/Llama-3.2-3B-Instruct",
# "sampling_params": {
# "temperature": 0.0,
# "max_tokens": 4096,
# "top_p": 0.9,
# "repeat_penalty": 1.0,
# },
# },
# },
# )

# print(response)
response = client.models.register(
model_id=f"{model}-sft-{training_config['n_epochs']-1}",
provider_id="meta-reference-inference",
provider_model_id="null",
metadata={"llama_model": "meta-llama/Llama-3.2-3B-Instruct"},
)

print(f"registerd model {model}-sft-{training_config['n_epochs']-1} successfully")

response = client.datasets.register(
dataset_id="post_training_eval",
provider_id="huggingface-0",
url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
metadata={
"path": "llamastack/evals",
"name": "evals__simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
},
)

if response:
print("registered dataset post_training_eval successfully")

eval_rows = client.datasetio.get_rows_paginated(
dataset_id="post_training_eval",
rows_in_page=5,
)

client.eval_tasks.register(
eval_task_id="torchtune::evals",
dataset_id=f"post_training_eval",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

response = client.eval.evaluate_rows(
task_id="torchtune::evals",
input_rows=eval_rows.rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": f"{model}-sft-{training_config['n_epochs']-1}",
"sampling_params": {
"temperature": 0.0,
"max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0,
},
},
},
)

print(response)


def main(
Expand Down

0 comments on commit 945d0a6

Please sign in to comment.