## What does this PR do?

Adds post training related CLI commands to the client SDK.

## User experience

Kicking off a supervised fine-tune job requires setting up several configs and hyper-parameters. To keep the user experience friendly, we provide an example script under `examples/post_training/supervised_fine_tune_client.py` that kicks off the post training job.

## Test

**Kick off training**

```
python supervised_fine_tune_client.py "devgpu018.nha2.facebook.com" 5000 "1236" "meta-llama/Llama-3.2-3B-Instruct"
```

<img width="880" alt="Screenshot 2024-12-18 at 11 39 35 AM" src="https://github.com/user-attachments/assets/f5f011c1-005d-4887-8eeb-41c3aa1604d6" />

**Get job list**

```
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training list
```

<img width="880" alt="Screenshot 2024-12-18 at 11 40 27 AM" src="https://github.com/user-attachments/assets/433b7b82-b3fa-473b-8a8f-3a0d046f2959" />

**Get job status**

```
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training status --job-uuid "1235"
```

<img width="894" alt="Screenshot 2024-12-18 at 11 41 34 AM" src="https://github.com/user-attachments/assets/2275cd56-9367-4903-85cb-6000be35f0d9" />

**Get job artifacts**

```
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training artifacts --job-uuid "1235"
```

<img width="888" alt="Screenshot 2024-12-18 at 11 42 22 AM" src="https://github.com/user-attachments/assets/dad2d521-e9b1-409a-93a9-f040c4d6cd24" />
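For reference, the job-management commands above are thin wrappers over SDK calls. A minimal sketch of the programmatic equivalent, where the endpoint and job UUID are placeholders:

```python
from llama_stack_client import LlamaStackClient

# Placeholder endpoint; point this at your running distribution.
client = LlamaStackClient(base_url="http://localhost:5000")

# List known post training jobs, then inspect one of them by UUID.
jobs = client.post_training.job.list()
print([job.job_uuid for job in jobs])

print(client.post_training.job.status(job_uuid="1235"))
print(client.post_training.job.artifacts(job_uuid="1235"))
```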
Showing 4 changed files with 239 additions and 0 deletions.
111 changes: 111 additions & 0 deletions
examples/post_training/supervised_fine_tune_client.py

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import asyncio
from typing import Optional

import fire
from llama_stack_client import LlamaStackClient

from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    AlgorithmConfigLoraFinetuningConfig,
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigEfficiencyConfig,
    TrainingConfigOptimizerConfig,
)


async def run_main(
    host: str,
    port: int,
    job_uuid: str,
    model: str,
    use_https: bool = False,
    checkpoint_dir: Optional[str] = None,
    cert_path: Optional[str] = None,
):
    # Construct the base URL with the appropriate protocol
    protocol = "https" if use_https else "http"
    base_url = f"{protocol}://{host}:{port}"

    # Configure client with SSL certificate if provided
    client_kwargs = {"base_url": base_url}
    if use_https and cert_path:
        client_kwargs["verify"] = cert_path

    client = LlamaStackClient(**client_kwargs)

    algorithm_config = AlgorithmConfigLoraFinetuningConfig(
        type="LoRA",
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    )

    data_config = TrainingConfigDataConfig(
        dataset_id="alpaca",
        validation_dataset_id="alpaca",
        batch_size=1,
        shuffle=False,
    )

    optimizer_config = TrainingConfigOptimizerConfig(
        optimizer_type="adamw",
        lr=3e-4,
        weight_decay=0.1,
        num_warmup_steps=100,
    )

    efficiency_config = TrainingConfigEfficiencyConfig(
        enable_activation_checkpointing=True,
    )

    training_config = TrainingConfig(
        n_epochs=1,
        data_config=data_config,
        efficiency_config=efficiency_config,
        optimizer_config=optimizer_config,
        max_steps_per_epoch=30,
        gradient_accumulation_steps=1,
    )

    training_job = client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        algorithm_config=algorithm_config,
        training_config=training_config,
        checkpoint_dir=checkpoint_dir,
        # logger_config and hyperparam_search_config haven't been used yet
        logger_config={},
        hyperparam_search_config={},
    )

    print(f"finished the training job: {training_job.job_uuid}")


def main(
    host: str,
    port: int,
    job_uuid: str,
    model: str,
    use_https: bool = False,
    checkpoint_dir: Optional[str] = "null",
    cert_path: Optional[str] = None,
):
    job_uuid = str(job_uuid)
    asyncio.run(
        run_main(host, port, job_uuid, model, use_https, checkpoint_dir, cert_path)
    )


if __name__ == "__main__":
    fire.Fire(main)
```
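Because the script hands `main` to `fire.Fire`, the positional arguments in the test command above map directly onto `main`'s parameters. A minimal sketch of the equivalent direct call from Python; the host, port, and job UUID are placeholders, and the import assumes the `examples/post_training` directory is on your path:

```python
# Assumes examples/post_training is on sys.path so the script can be imported.
from supervised_fine_tune_client import main

# Equivalent to:
#   python supervised_fine_tune_client.py localhost 5000 "1236" "meta-llama/Llama-3.2-3B-Instruct"
main(
    host="localhost",
    port=5000,
    job_uuid="1236",
    model="meta-llama/Llama-3.2-3B-Instruct",
)
```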
9 changes: 9 additions & 0 deletions
src/llama_stack_client/lib/cli/post_training/__init__.py

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .post_training import post_training

__all__ = ["post_training"]
```
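The `__init__.py` re-exports the `post_training` click group so the top-level CLI can register it. A hypothetical sketch of that wiring follows; the parent group name `cli` is an assumption for illustration only, not the actual entry point in this repository:

```python
import click

from llama_stack_client.lib.cli.post_training import post_training


@click.group()
def cli():
    """Hypothetical top-level CLI group, used only for illustration."""


# Attach the post_training subcommands under the parent group.
cli.add_command(post_training)

if __name__ == "__main__":
    cli()
```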
117 changes: 117 additions & 0 deletions
src/llama_stack_client/lib/cli/post_training/post_training.py
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

import click

from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    AlgorithmConfig,
    TrainingConfig,
)
from rich.console import Console

from ..common.utils import handle_client_errors


@click.group()
def post_training():
    """Query details about available post_training endpoints on distribution."""
    pass


@click.command("supervised_fine_tune")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.option("--model", required=True, help="Model ID")
@click.option("--algorithm-config", required=True, help="Algorithm Config")
@click.option("--training-config", required=True, help="Training Config")
@click.option(
    "--checkpoint-dir", required=False, help="Checkpoint Config", default=None
)
@click.pass_context
@handle_client_errors("post_training supervised_fine_tune")
def supervised_fine_tune(
    ctx,
    job_uuid: str,
    model: str,
    algorithm_config: AlgorithmConfig,
    training_config: TrainingConfig,
    checkpoint_dir: Optional[str],
):
    """Kick off a supervised fine tune job"""
    client = ctx.obj["client"]
    console = Console()

    post_training_job = client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        algorithm_config=algorithm_config,
        training_config=training_config,
        checkpoint_dir=checkpoint_dir,
        # logger_config and hyperparam_search_config haven't been used yet
        logger_config={},
        hyperparam_search_config={},
    )
    console.print(post_training_job.job_uuid)


@click.command("list")
@click.pass_context
@handle_client_errors("post_training get_training_jobs")
def get_training_jobs(ctx):
    """Show the list of available post training jobs"""
    client = ctx.obj["client"]
    console = Console()

    post_training_jobs = client.post_training.job.list()
    console.print(
        [post_training_job.job_uuid for post_training_job in post_training_jobs]
    )


@click.command("status")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_status")
def get_training_job_status(ctx, job_uuid: str):
    """Show the status of a specific post training job"""
    client = ctx.obj["client"]
    console = Console()

    job_status_response = client.post_training.job.status(job_uuid=job_uuid)
    console.print(job_status_response)


@click.command("artifacts")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_artifacts")
def get_training_job_artifacts(ctx, job_uuid: str):
    """Get the training artifacts of a specific post training job"""
    client = ctx.obj["client"]
    console = Console()

    job_artifacts = client.post_training.job.artifacts(job_uuid=job_uuid)
    console.print(job_artifacts)


@click.command("cancel")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training cancel_training_job")
def cancel_training_job(ctx, job_uuid: str):
    """Cancel the training job"""
    client = ctx.obj["client"]

    client.post_training.job.cancel(job_uuid=job_uuid)


# Register subcommands
post_training.add_command(supervised_fine_tune)
post_training.add_command(get_training_jobs)
post_training.add_command(get_training_job_status)
post_training.add_command(get_training_job_artifacts)
post_training.add_command(cancel_training_job)
```
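The subcommands read the configured client from `ctx.obj`, which the top-level CLI is expected to populate. A minimal sketch of exercising the group in-process with click's test runner; the endpoint is a placeholder, and passing `obj` directly stands in for whatever the real CLI sets up:

```python
from click.testing import CliRunner

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.post_training import post_training

runner = CliRunner()
# Invoke the "list" subcommand; `obj` seeds ctx.obj with a configured client.
result = runner.invoke(
    post_training,
    ["list"],
    obj={"client": LlamaStackClient(base_url="http://localhost:5000")},
)
print(result.output)
```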