diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py
index ab8ab22dc6..c945bd8ff7 100644
--- a/llama_stack/apis/common/job_types.py
+++ b/llama_stack/apis/common/job_types.py
@@ -18,3 +18,5 @@ class Job(BaseModel):
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
+    failed = "failed"
+    scheduled = "scheduled"
diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index fd74293ebc..b4bd1b0c6c 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -4,13 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_models.llama3.api.datatypes import URL
+from datetime import datetime
+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
 
+@json_schema_type
+class PostTrainingMetric(BaseModel):
+    epoch: int
+    train_loss: float
+    validation_loss: float
+    perplexity: float
+
+
 @json_schema_type(schema={"description": "Checkpoint created during training runs"})
 class Checkpoint(BaseModel):
-    iters: int
-    path: URL
+    identifier: str
+    created_at: datetime
     epoch: int
+    post_training_job_id: str
+    path: str
+    training_metrics: Optional[PostTrainingMetric] = None
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 3c6918786b..fdbaa364d7 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,6 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
+
 from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
@@ -14,6 +15,7 @@ from typing_extensions import Annotated
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
@@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):
 
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
+    type: Literal["LoRA"] = "LoRA"
     lora_attn_modules: List[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
@@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
 
 @json_schema_type
 class QATFinetuningConfig(BaseModel):
+    type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 
 
 AlgorithmConfig = Annotated[
-    Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
 ]
 
 
@@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
     log_lines: List[str]
 
 
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
@@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: PostTrainingJobStatus
+    status: JobStatus
 
     scheduled_at: Optional[datetime] = None
     started_at: Optional[datetime] = None
@@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 
 
 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -181,7 +177,7 @@ async def supervised_fine_tune(
         algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
@@ -192,24 +188,18 @@ async def preference_optimize(
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]: ...
 
-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> Optional[PostTrainingJobArtifactsResponse]: ...
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
new file mode 100644
index 0000000000..688a03c257
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List
+
+import torch
+from torchtune import training
+from torchtune.models import convert_weights
+from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.utils._logging import get_logger
+
+logger = get_logger("DEBUG")
+
+
+class TorchtuneCheckpointer:
+    def __init__(
+        self,
+        model_id: str,
+        training_algorithm: str,
+        checkpoint_dir: str,
+        checkpoint_files: List[str],
+        output_dir: str,
+        model_type: str,
+    ) -> None:
+        # Fail fast if ``checkpoint_files`` is invalid
+        # TODO: support loading more than one file
+        if len(checkpoint_files) != 1:
+            raise ValueError(
+                "Currently we only support reading from a single torchtune checkpoint file. "
+                f"Got {len(checkpoint_files)} files instead."
+            )
+        self._checkpoint_file = checkpoint_files[0]
+        self._model_id = model_id
+        self._training_algorithm = training_algorithm
+        self._checkpoint_dir = Path(checkpoint_dir)
+        self._model_type = ModelType[model_type]
+        self._output_dir = output_dir
+        # get ckpt paths
+        self._checkpoint_path = Path.joinpath(
+            self._checkpoint_dir, self._checkpoint_file
+        )
+
+    def load_checkpoint(self) -> Dict[str, Any]:
+        """
+        Load Meta checkpoint from file. Currently only loading from a single file is supported.
+ """ + state_dict: Dict[str:Any] = {} + model_state_dict = safe_torch_load(self._checkpoint_path) + if self._model_type == ModelType.LLAMA3_VISION: + from torchtune.models.llama3_2_vision._convert_weights import ( + llama3_vision_meta_to_tune, + ) + + state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune( + model_state_dict + ) + else: + state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune( + model_state_dict + ) + + # llama3_2 has tied weights, so we need to remove the output.weight key + if self._model_type == ModelType.LLAMA3_2: + logger.info( + "Identified model_type = Llama3_2. Ignoring output.weight in" + " checkpoint in favor of the tok_embedding.weight" + " tied weights." + ) + state_dict[training.MODEL_KEY].pop("output.weight") + + return state_dict + + def save_checkpoint( + self, + state_dict: Dict[str, Any], + epoch: int, + adapter_only: bool = False, + ) -> str: + model_file_path = ( + Path(self._output_dir) + / f"{self._model_id}-{self._training_algorithm}-{epoch}" + ) + + model_file_path.mkdir(parents=True, exist_ok=True) + + # copy the related files for inference + shutil.copy( + Path.joinpath(self._checkpoint_dir, "params.json"), + Path.joinpath(model_file_path, "params.json"), + ) + shutil.copy( + Path.joinpath(self._checkpoint_dir, "tokenizer.model"), + Path.joinpath(model_file_path, "tokenizer.model"), + ) + shutil.copy( + Path.joinpath(self._checkpoint_dir, "orig_params.json"), + Path.joinpath(model_file_path, "orig_params.json"), + ) + + if not adapter_only: + model_state_dict = state_dict[training.MODEL_KEY] + if self._model_type == ModelType.LLAMA3_VISION: + from torchtune.models.llama3_2_vision._convert_weights import ( + llama3_vision_tune_to_meta, + ) + + state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta( + model_state_dict + ) + else: + # llama3_2 has tied weights, so we need to add the output.weight key + if ( + self._model_type == ModelType.LLAMA3_2 + and "output.weight" not in model_state_dict + ): + model_state_dict["output.weight"] = model_state_dict[ + "tok_embeddings.weight" + ] + + state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta( + model_state_dict + ) + + model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth") + + torch.save(state_dict[training.MODEL_KEY], model_file_name) + logger.info( + "Model checkpoint of size " + f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB " + f"saved to {model_file_name}" + ) + + if training.ADAPTER_KEY in state_dict: + adapter_file_path = model_file_path / "adapter" + adapter_file_path.mkdir(parents=True, exist_ok=True) + adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth") + torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB " + f"saved to {adapter_file_name}" + ) + + elif adapter_only: + raise ValueError( + "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights." 
+            )
+
+        print("model_file_path", str(model_file_path))
+
+        return str(model_file_path)
diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
similarity index 100%
rename from llama_stack/providers/inline/post_training/torchtune/utils.py
rename to llama_stack/providers/inline/post_training/torchtune/common/utils.py
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 1987086e10..9b1269f16c 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -24,6 +24,11 @@ def __init__(
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
 
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs_status = {}
+        self.jobs_list = []
+        self.checkpoints_dict = {}
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -32,26 +37,57 @@ async def supervised_fine_tune(
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
+        for job in self.jobs_list:
+            if job_uuid == job.job_uuid:
+                raise ValueError(f"Job {job_uuid} already exists")
+
+        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+
+        job_status_response = PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=JobStatus.scheduled,
+            scheduled_at=datetime.now(),
+        )
+
+        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
-            recipe = LoraFinetuningSingleDevice(
-                self.config,
-                training_config,
-                hyperparam_search_config,
-                logger_config,
-                model,
-                checkpoint_dir,
-                algorithm_config,
-                self.datasetio_api,
-                self.datasets_api,
-            )
-            await recipe.setup()
-            await recipe.train()
+            try:
+                recipe = LoraFinetuningSingleDevice(
+                    self.config,
+                    job_uuid,
+                    training_config,
+                    hyperparam_search_config,
+                    logger_config,
+                    model,
+                    checkpoint_dir,
+                    algorithm_config,
+                    self.datasetio_api,
+                    self.datasets_api,
+                )
+
+                job_status_response.status = JobStatus.in_progress
+                job_status_response.started_at = datetime.now()
+
+                await recipe.setup()
+                resources_allocated, checkpoints = await recipe.train()
+
+                self.checkpoints_dict[job_uuid] = checkpoints
+                job_status_response.resources_allocated = resources_allocated
+                job_status_response.checkpoints = checkpoints
+                job_status_response.status = JobStatus.completed
+                job_status_response.completed_at = datetime.now()
+
+            except Exception:
+                job_status_response.status = JobStatus.failed
+                raise
         else:
             raise NotImplementedError()
 
-        return PostTrainingJob(job_uuid=job_uuid)
+        self.jobs_status[job_uuid] = job_status_response
+
+        return post_training_job
 
     async def preference_optimize(
         self,
@@ -63,24 +99,28 @@ async def preference_optimize(
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    # TODO @SLR722 impelment below APIs
-    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
-
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]:
+        return self.jobs_list
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]:
+        if job_uuid in self.jobs_status:
+            return self.jobs_status[job_uuid]
+        return None
 
     @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> Optional[PostTrainingJobArtifactsResponse]:
+        if job_uuid in self.checkpoints_dict:
+            checkpoints = self.checkpoints_dict.get(job_uuid, [])
+            return PostTrainingJobArtifactsResponse(
+                job_uuid=job_uuid, checkpoints=checkpoints
+            )
+        return None
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 7873c7c6f5..0714046bf5 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -13,14 +13,20 @@
 import torch
 from llama_models.sku_list import resolve_model
+
 from llama_stack.apis.datasetio import DatasetIO
+
+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 
@@ -62,16 +68,22 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
        datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
+        if not isinstance(algorithm_config, LoraFinetuningConfig):
+            raise ValueError(
+                "You need to specify LoraFinetuningConfig for LoRA finetuning"
+            )
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -99,8 +111,7 @@ def model_checkpoint_dir(model) -> str:
             model = resolve_model(self.model_id)
             self.checkpoint_dir = model_checkpoint_dir(model)
 
-        # TODO @SLR722 make it work with get_training_job_artifacts
-        self._output_dir = self.checkpoint_dir + "/posting_training/"
+        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -140,7 +151,9 @@ def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
             except FileNotFoundError:
                 return [f"Error: The directory '{checkpoint_dir}' does not exist."]
 
-        self._checkpointer = training.FullModelMetaCheckpointer(
+        self._checkpointer = TorchtuneCheckpointer(
+            model_id=self.model_id,
+            training_algorithm="sft",
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
@@ -150,8 +163,6 @@ def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
         return checkpoint_dict
 
     async def setup(self) -> None:
-        self._metric_logger = DiskLogger(log_dir=self._output_dir)
-
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(
@@ -370,7 +381,7 @@ async def _setup_lr_scheduler(
         )
         return lr_scheduler
 
-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
 
         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@@ -400,7 +411,7 @@ async def save_checkpoint(self, epoch: int) -> None:
             }
             ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
 
-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
            ckpt_dict,
             epoch=epoch,
         )
@@ -429,20 +440,26 @@ async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
 
         return loss
 
-    async def train(self) -> None:
+    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
         # Initialize tokens count and running loss (for grad accumulation)
-        # t0 = time.perf_counter()
         t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
 
+        # training artifacts
+        checkpoints = []
+        memory_stats = {}
+
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
+            metric_logger = DiskLogger(
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+            )
             self._sampler.set_epoch(curr_epoch)
 
             for idx, batch in enumerate(self._dataloader):
@@ -488,10 +505,14 @@ async def train(self) -> None:
                         "lr": self._optimizer.param_groups[0]["lr"],
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
-                    log_dict.update(training.get_memory_stats(device=self._device))
+
+                    memory_stats = training.get_memory_stats(device=self._device)
+                    log_dict.update(memory_stats)
+
                     if self._clip_grad_norm is not None:
                         log_dict.update({"grad_norm": grad_norm})
-                    self._metric_logger.log_dict(
+
+                    metric_logger.log_dict(
                        log_dict,
                         step=self.global_step,
                     )
@@ -503,4 +524,14 @@ async def train(self) -> None:
 
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
-            await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint = Checkpoint(
+                identifier=f"{self.model_id}-sft-{curr_epoch}",
+                created_at=datetime.now(),
+                epoch=curr_epoch,
+                post_training_job_id=self.job_uuid,
+                path=checkpoint_path,
+            )
+            checkpoints.append(checkpoint)
+
+        return (memory_stats, checkpoints)
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py
index a4e2d55c9a..4ecc05187a 100644
--- a/llama_stack/providers/tests/post_training/test_post_training.py
+++ b/llama_stack/providers/tests/post_training/test_post_training.py
@@ -19,6 +19,7 @@ class TestPostTraining:
     @pytest.mark.asyncio
     async def test_supervised_fine_tune(self, post_training_stack):
         algorithm_config = LoraFinetuningConfig(
+            type="LoRA",
             lora_attn_modules=["q_proj", "v_proj", "output_proj"],
             apply_lora_to_mlp=True,
             apply_lora_to_output=False,
@@ -59,3 +60,33 @@ async def test_supervised_fine_tune(self, post_training_stack):
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
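
Reviewer note: the sketch below is a minimal, illustrative walk-through of the reworked flow, mirroring the tests above. The helper name, the literal `job_uuid`/`model` values, and passing `checkpoint_dir=None` are assumptions made for the example; `training_config` and `algorithm_config` are expected to be built the same way `test_supervised_fine_tune` builds them.

```python
from llama_stack.apis.common.job_types import JobStatus


async def run_sft_and_collect_checkpoints(post_training_impl, training_config, algorithm_config):
    # Schedule the fine-tuning job; the inline torchtune provider runs it synchronously,
    # so the job has finished (or raised) by the time this call returns.
    job = await post_training_impl.supervised_fine_tune(
        job_uuid="1234",
        model="Llama3.2-3B-Instruct",
        checkpoint_dir=None,
        algorithm_config=algorithm_config,  # e.g. LoraFinetuningConfig(type="LoRA", ...)
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
    )

    # Status and artifacts lookups now return Optional[...]; both yield None for unknown job ids.
    status = await post_training_impl.get_training_job_status(job.job_uuid)
    if status is None or status.status != JobStatus.completed:
        return []

    artifacts = await post_training_impl.get_training_job_artifacts(job.job_uuid)
    return artifacts.checkpoints if artifacts is not None else []
```

On the wire this corresponds to POST /post-training/supervised-fine-tune followed by GET /post-training/job/status and GET /post-training/job/artifacts, per the webmethod routes added above.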