[2/n][torchtune integration] implement job management and return training artifacts (#593)

### Context
In this PR, we:
- Implement the post-training job management and training-artifacts APIs (a client-side sketch follows this list):
  - `get_training_jobs`
  - `get_training_job_status`
  - `get_training_job_artifacts`
- Remove `get_training_job_logstream`, since traces can be viewed directly in the Jaeger UI:
  https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post-training and training type definitions to make them more intuitive.
- Rewrite the checkpointer so that its output is compatible with the llama-stack file system layout and can be recognized during inference.
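
A minimal, hypothetical sketch of driving the new endpoints through the `PostTraining` protocol; the polling helper, its arguments, and the printed text are illustrative and not part of this PR:

```python
# Hypothetical polling helper; `post_training` is assumed to be any
# implementation of the PostTraining protocol (or a client proxy for it).
import asyncio

from llama_stack.apis.common.job_types import JobStatus


async def wait_for_artifacts(post_training, job_uuid: str, poll_seconds: float = 30.0):
    # GET /post-training/jobs -- list all known post-training jobs.
    jobs = await post_training.get_training_jobs()
    print(f"{len(jobs)} job(s) registered")

    # GET /post-training/job/status -- poll until the job completes or fails.
    while True:
        status = await post_training.get_training_job_status(job_uuid=job_uuid)
        if status is not None and status.status in (JobStatus.completed, JobStatus.failed):
            break
        await asyncio.sleep(poll_seconds)

    # GET /post-training/job/artifacts -- fetch the checkpoints produced by the job.
    return await post_training.get_training_job_artifacts(job_uuid=job_uuid)
```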


### Test
Unit test:
`pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM"
src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">


E2E test with a client-side call:

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM"
src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
SLR722 authored Dec 13, 2024
1 parent 5764a95 commit c294a01
Showing 8 changed files with 331 additions and 67 deletions.
2 changes: 2 additions & 0 deletions llama_stack/apis/common/job_types.py
@@ -18,3 +18,5 @@ class Job(BaseModel):
class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"
    failed = "failed"
    scheduled = "scheduled"
19 changes: 16 additions & 3 deletions llama_stack/apis/common/training_types.py
@@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.llama3.api.datatypes import URL
from datetime import datetime
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel


@json_schema_type
class PostTrainingMetric(BaseModel):
    epoch: int
    train_loss: float
    validation_loss: float
    perplexity: float


@json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel):
    iters: int
    path: URL
    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str
    training_metrics: Optional[PostTrainingMetric] = None
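
For reference, a `Checkpoint` under the reworked schema might be constructed like this (a hypothetical sketch; every value below is made up):

```python
# Illustrative only -- all identifiers, paths, and metric values are placeholders.
from datetime import datetime

from llama_stack.apis.common.training_types import Checkpoint, PostTrainingMetric

checkpoint = Checkpoint(
    identifier="Llama3.2-3B-Instruct-sft-0",
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="job-1234",
    path="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0",
    training_metrics=PostTrainingMetric(
        epoch=0,
        train_loss=1.72,
        validation_loss=1.81,
        perplexity=6.10,
    ),
)
```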
38 changes: 14 additions & 24 deletions llama_stack/apis/post_training/post_training.py
@@ -6,6 +6,7 @@

from datetime import datetime
from enum import Enum

from typing import Any, Dict, List, Optional, Protocol, Union

from llama_models.schema_utils import json_schema_type, webmethod
@@ -14,6 +15,7 @@
from typing_extensions import Annotated

from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403

@@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):

@json_schema_type
class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: List[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool
@@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):

@json_schema_type
class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]


@@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
    log_lines: List[str]


@json_schema_type
class PostTrainingJobStatus(Enum):
    running = "running"
    completed = "completed"
    failed = "failed"
    scheduled = "scheduled"


@json_schema_type
class RLHFAlgorithm(Enum):
    dpo = "dpo"
@@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
    """Status of a finetuning job."""

    job_uuid: str
    status: PostTrainingJobStatus
    status: JobStatus

    scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
@@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):


class PostTraining(Protocol):
    @webmethod(route="/post-training/supervised-fine-tune")
    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
    async def supervised_fine_tune(
        self,
        job_uuid: str,
@@ -181,7 +177,7 @@ async def supervised_fine_tune(
        algorithm_config: Optional[AlgorithmConfig] = None,
    ) -> PostTrainingJob: ...

    @webmethod(route="/post-training/preference-optimize")
    @webmethod(route="/post-training/preference-optimize", method="POST")
    async def preference_optimize(
        self,
        job_uuid: str,
@@ -192,24 +188,18 @@ async def preference_optimize(
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob: ...

    @webmethod(route="/post-training/jobs")
    @webmethod(route="/post-training/jobs", method="GET")
    async def get_training_jobs(self) -> List[PostTrainingJob]: ...

    # sends SSE stream of logs
    @webmethod(route="/post-training/job/logs")
    async def get_training_job_logstream(
        self, job_uuid: str
    ) -> PostTrainingJobLogStream: ...

    @webmethod(route="/post-training/job/status")
    @webmethod(route="/post-training/job/status", method="GET")
    async def get_training_job_status(
        self, job_uuid: str
    ) -> PostTrainingJobStatusResponse: ...
    ) -> Optional[PostTrainingJobStatusResponse]: ...

    @webmethod(route="/post-training/job/cancel")
    @webmethod(route="/post-training/job/cancel", method="POST")
    async def cancel_training_job(self, job_uuid: str) -> None: ...

    @webmethod(route="/post-training/job/artifacts")
    @webmethod(route="/post-training/job/artifacts", method="GET")
    async def get_training_job_artifacts(
        self, job_uuid: str
    ) -> PostTrainingJobArtifactsResponse: ...
    ) -> Optional[PostTrainingJobArtifactsResponse]: ...
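
One side fix worth calling out: `AlgorithmConfig` previously listed `LoraFinetuningConfig` twice, leaving `QATFinetuningConfig` out of the union entirely; with the new `type` discriminators a QAT payload now resolves to the right model. A small hypothetical check, assuming pydantic v2's `TypeAdapter` is available in this environment:

```python
# Hypothetical round-trip through the discriminated union; the quantizer name
# and group size are placeholder values.
from pydantic import TypeAdapter

from llama_stack.apis.post_training.post_training import (
    AlgorithmConfig,
    QATFinetuningConfig,
)

payload = {"type": "QAT", "quantizer_name": "int4_weight_only", "group_size": 32}
config = TypeAdapter(AlgorithmConfig).validate_python(payload)
assert isinstance(config, QATFinetuningConfig)
```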
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger

logger = get_logger("DEBUG")


class TorchtuneCheckpointer:
    def __init__(
        self,
        model_id: str,
        training_algorithm: str,
        checkpoint_dir: str,
        checkpoint_files: List[str],
        output_dir: str,
        model_type: str,
    ) -> None:
        # Fail fast if ``checkpoint_files`` is invalid
        # TODO: support loading more than one file
        if len(checkpoint_files) != 1:
            raise ValueError(
                "Currently we only support reading from a single torchtune checkpoint file. "
                f"Got {len(checkpoint_files)} files instead."
            )
        self._checkpoint_file = checkpoint_files[0]
        self._model_id = model_id
        self._training_algorithm = training_algorithm
        self._checkpoint_dir = Path(checkpoint_dir)
        self._model_type = ModelType[model_type]
        self._output_dir = output_dir
        # get ckpt paths
        self._checkpoint_path = Path.joinpath(
            self._checkpoint_dir, self._checkpoint_file
        )

    def load_checkpoint(self) -> Dict[str, Any]:
        """
        Load Meta checkpoint from file. Currently only loading from a single file is supported.
        """
        state_dict: Dict[str, Any] = {}
        model_state_dict = safe_torch_load(self._checkpoint_path)
        if self._model_type == ModelType.LLAMA3_VISION:
            from torchtune.models.llama3_2_vision._convert_weights import (
                llama3_vision_meta_to_tune,
            )

            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
                model_state_dict
            )
        else:
            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
                model_state_dict
            )

        # llama3_2 has tied weights, so we need to remove the output.weight key
        if self._model_type == ModelType.LLAMA3_2:
            logger.info(
                "Identified model_type = Llama3_2. Ignoring output.weight in"
                " checkpoint in favor of the tok_embedding.weight"
                " tied weights."
            )
            state_dict[training.MODEL_KEY].pop("output.weight")

        return state_dict

    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        epoch: int,
        adapter_only: bool = False,
    ) -> str:
        model_file_path = (
            Path(self._output_dir)
            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
        )

        model_file_path.mkdir(parents=True, exist_ok=True)

        # copy the related files for inference
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "params.json"),
            Path.joinpath(model_file_path, "params.json"),
        )
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
            Path.joinpath(model_file_path, "tokenizer.model"),
        )
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
            Path.joinpath(model_file_path, "orig_params.json"),
        )

        if not adapter_only:
            model_state_dict = state_dict[training.MODEL_KEY]
            if self._model_type == ModelType.LLAMA3_VISION:
                from torchtune.models.llama3_2_vision._convert_weights import (
                    llama3_vision_tune_to_meta,
                )

                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
                    model_state_dict
                )
            else:
                # llama3_2 has tied weights, so we need to add the output.weight key
                if (
                    self._model_type == ModelType.LLAMA3_2
                    and "output.weight" not in model_state_dict
                ):
                    model_state_dict["output.weight"] = model_state_dict[
                        "tok_embeddings.weight"
                    ]

                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
                    model_state_dict
                )

            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

            torch.save(state_dict[training.MODEL_KEY], model_file_name)
            logger.info(
                "Model checkpoint of size "
                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
                f"saved to {model_file_name}"
            )

        if training.ADAPTER_KEY in state_dict:
            adapter_file_path = model_file_path / "adapter"
            adapter_file_path.mkdir(parents=True, exist_ok=True)
            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
            logger.info(
                "Adapter checkpoint of size "
                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
                f"saved to {adapter_file_name}"
            )

        elif adapter_only:
            raise ValueError(
                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
            )

        print("model_file_path", str(model_file_path))

        return str(model_file_path)
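
For context, a rough usage sketch of the new checkpointer (hypothetical; all paths, names, and the epoch below are placeholders, and the training loop itself is elided; `TorchtuneCheckpointer` from the file above is assumed to be importable):

```python
# Hypothetical driver code; paths, model names, and the epoch are placeholders.
checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",
    training_algorithm="sft",
    checkpoint_dir="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct",
    checkpoint_files=["consolidated.00.pth"],
    output_dir="/home/user/.llama/checkpoints/post_training",
    model_type="LLAMA3_2",
)

# Meta-format weights -> torchtune-native state dict.
state_dict = checkpointer.load_checkpoint()

# ... training loop updates the model (and optionally adapter) weights in state_dict ...

# Write the result back out in a llama-stack-recognizable layout:
# <output_dir>/<model_id>-<algorithm>-<epoch>/consolidated.00.pth plus copies
# of params.json, tokenizer.model, and orig_params.json.
output_path = checkpointer.save_checkpoint(state_dict, epoch=0)
print("checkpoint written to", output_path)
```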