# post training CLI (#51)
## What does this PR do?
Add post-training-related CLI commands to the client SDK.

## user experience
Kicking off a supervised fine-tuning job requires setting up several configs and
hyperparameters, so to keep the user experience friendly we provide an example
script under `examples/post_training/supervised_fine_tune_client.py` that kicks
off the post-training job.
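
For reference, the script takes host, port, job UUID, and model as positional arguments (via `fire`), plus optional `use_https`, `checkpoint_dir`, and `cert_path` flags. A sketch of an invocation with placeholder values:

```shell
# Placeholder host/port/job-uuid values; checkpoint_dir and cert_path are optional.
python examples/post_training/supervised_fine_tune_client.py \
  <host> <port> <job_uuid> meta-llama/Llama-3.2-3B-Instruct \
  --checkpoint_dir=/path/to/checkpoints
```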

## test 
**kick off training**
```shell
python supervised_fine_tune_client.py "devgpu018.nha2.facebook.com" 5000 "1236" "meta-llama/Llama-3.2-3B-Instruct"
```
<img width="880" alt="Screenshot 2024-12-18 at 11 39 35 AM"
src="https://github.com/user-attachments/assets/f5f011c1-005d-4887-8eeb-41c3aa1604d6"
/>




**get job list**
```shell
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training list
```

<img width="880" alt="Screenshot 2024-12-18 at 11 40 27 AM"
src="https://github.com/user-attachments/assets/433b7b82-b3fa-473b-8a8f-3a0d046f2959"
/>




**get job status**
```shell
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training status --job-uuid "1235"
```


<img width="894" alt="Screenshot 2024-12-18 at 11 41 34 AM"
src="https://github.com/user-attachments/assets/2275cd56-9367-4903-85cb-6000be35f0d9"
/>

**get job artifacts**
```shell
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training artifacts --job-uuid "1235"
```
<img width="888" alt="Screenshot 2024-12-18 at 11 42 22 AM"
src="https://github.com/user-attachments/assets/dad2d521-e9b1-409a-93a9-f040c4d6cd24"
/>
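
**cancel a job** (not exercised in the screenshots above; the commit also registers a `cancel` subcommand, so an analogous invocation would be)

```shell
llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training cancel --job-uuid "1235"
```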
SLR722 authored Dec 18, 2024
1 parent a462997 commit b982fec
Showing 4 changed files with 239 additions and 0 deletions.
111 changes: 111 additions & 0 deletions examples/post_training/supervised_fine_tune_client.py
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import asyncio
from typing import Optional

import fire
from llama_stack_client import LlamaStackClient

from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    AlgorithmConfigLoraFinetuningConfig,
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigEfficiencyConfig,
    TrainingConfigOptimizerConfig,
)


async def run_main(
    host: str,
    port: int,
    job_uuid: str,
    model: str,
    use_https: bool = False,
    checkpoint_dir: Optional[str] = None,
    cert_path: Optional[str] = None,
):
    # Construct the base URL with the appropriate protocol
    protocol = "https" if use_https else "http"
    base_url = f"{protocol}://{host}:{port}"

    # Configure client with SSL certificate if provided
    client_kwargs = {"base_url": base_url}
    if use_https and cert_path:
        client_kwargs["verify"] = cert_path

    client = LlamaStackClient(**client_kwargs)

    algorithm_config = AlgorithmConfigLoraFinetuningConfig(
        type="LoRA",
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    )

    data_config = TrainingConfigDataConfig(
        dataset_id="alpaca",
        validation_dataset_id="alpaca",
        batch_size=1,
        shuffle=False,
    )

    optimizer_config = TrainingConfigOptimizerConfig(
        optimizer_type="adamw",
        lr=3e-4,
        weight_decay=0.1,
        num_warmup_steps=100,
    )

    efficiency_config = TrainingConfigEfficiencyConfig(
        enable_activation_checkpointing=True,
    )

    training_config = TrainingConfig(
        n_epochs=1,
        data_config=data_config,
        efficiency_config=efficiency_config,
        optimizer_config=optimizer_config,
        max_steps_per_epoch=30,
        gradient_accumulation_steps=1,
    )

    training_job = client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        algorithm_config=algorithm_config,
        training_config=training_config,
        checkpoint_dir=checkpoint_dir,
        # logger_config and hyperparam_search_config haven't been used yet
        logger_config={},
        hyperparam_search_config={},
    )

    print(f"finished the training job: {training_job.job_uuid}")


def main(
    host: str,
    port: int,
    job_uuid: str,
    model: str,
    use_https: bool = False,
    checkpoint_dir: Optional[str] = "null",
    cert_path: Optional[str] = None,
):
    job_uuid = str(job_uuid)
    asyncio.run(
        run_main(host, port, job_uuid, model, use_https, checkpoint_dir, cert_path)
    )


if __name__ == "__main__":
    fire.Fire(main)
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -19,6 +19,7 @@
from .inference import inference
from .memory_banks import memory_banks
from .models import models
from .post_training import post_training
from .providers import providers
from .scoring_functions import scoring_functions
from .shields import shields
@@ -75,6 +76,7 @@ def cli(ctx, endpoint: str, config: str | None):
cli.add_command(scoring_functions, "scoring_functions")
cli.add_command(eval, "eval")
cli.add_command(inference, "inference")
cli.add_command(post_training, "post_training")


def main():
9 changes: 9 additions & 0 deletions src/llama_stack_client/lib/cli/post_training/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .post_training import post_training

__all__ = ["post_training"]
117 changes: 117 additions & 0 deletions src/llama_stack_client/lib/cli/post_training/post_training.py
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

import click

from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    AlgorithmConfig,
    TrainingConfig,
)
from rich.console import Console

from ..common.utils import handle_client_errors


@click.group()
def post_training():
    """Query details about available post_training endpoints on distribution."""
    pass


@click.command("supervised_fine_tune")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.option("--model", required=True, help="Model ID")
@click.option("--algorithm-config", required=True, help="Algorithm Config")
@click.option("--training-config", required=True, help="Training Config")
@click.option(
"--checkpoint-dir", required=False, help="Checkpoint Config", default=None
)
@click.pass_context
@handle_client_errors("post_training supervised_fine_tune")
def supervised_fine_tune(
ctx,
job_uuid: str,
model: str,
algorithm_config: AlgorithmConfig,
training_config: TrainingConfig,
checkpoint_dir: Optional[str],
):
"""Kick off a supervised fine tune job"""
client = ctx.obj["client"]
console = Console()

post_training_job = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
model=model,
algorithm_config=algorithm_config,
training_config=training_config,
checkpoint_dir=checkpoint_dir,
# logger_config and hyperparam_search_config haven't been used yet
logger_config={},
hyperparam_search_config={},
)
console.print(post_training_job.job_uuid)


@click.command("list")
@click.pass_context
@handle_client_errors("post_training get_training_jobs")
def get_training_jobs(ctx):
"""Show the list of available post training jobs"""
client = ctx.obj["client"]
console = Console()

post_training_jobs = client.post_training.job.list()
console.print(
[post_training_job.job_uuid for post_training_job in post_training_jobs]
)


@click.command("status")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_status")
def get_training_job_status(ctx, job_uuid: str):
"""Show the status of a specific post training job"""
client = ctx.obj["client"]
console = Console()

job_status_reponse = client.post_training.job.status(job_uuid=job_uuid)
console.print(job_status_reponse)


@click.command("artifacts")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_artifacts")
def get_training_job_artifacts(ctx, job_uuid: str):
"""Get the training artifacts of a specific post training job"""
client = ctx.obj["client"]
console = Console()

job_artifacts = client.post_training.job.artifacts(job_uuid=job_uuid)
console.print(job_artifacts)


@click.command("cancel")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training cancel_training_job")
def cancel_training_job(ctx, job_uuid: str):
"""Cancel the training job"""
client = ctx.obj["client"]

client.post_training.job.cancel(job_uuid=job_uuid)


# Register subcommands
post_training.add_command(supervised_fine_tune)
post_training.add_command(get_training_jobs)
post_training.add_command(get_training_job_status)
post_training.add_command(get_training_job_artifacts)
post_training.add_command(cancel_training_job)
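
For completeness, the new subcommands are thin wrappers over the SDK's `client.post_training.job.*` methods; a minimal sketch of calling them directly (endpoint and job UUID are placeholders):

```python
from llama_stack_client import LlamaStackClient

# Placeholder endpoint; point this at a running llama-stack distribution.
client = LlamaStackClient(base_url="http://localhost:5000")

# Same calls the CLI subcommands above wrap.
jobs = client.post_training.job.list()
print([job.job_uuid for job in jobs])

status = client.post_training.job.status(job_uuid="1236")  # placeholder job UUID
print(status)

artifacts = client.post_training.job.artifacts(job_uuid="1236")
print(artifacts)
```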
