diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 2999d43af6..3c6918786b 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,50 +6,60 @@
 from datetime import datetime
 from enum import Enum
-
-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
+@json_schema_type
 class OptimizerType(Enum):
     adam = "adam"
     adamw = "adamw"
     sgd = "sgd"
 
 
+@json_schema_type
+class DataConfig(BaseModel):
+    dataset_id: str
+    batch_size: int
+    shuffle: bool
+    validation_dataset_id: Optional[str] = None
+    packed: Optional[bool] = False
+    train_on_input: Optional[bool] = False
+
+
 @json_schema_type
 class OptimizerConfig(BaseModel):
     optimizer_type: OptimizerType
     lr: float
-    lr_min: float
     weight_decay: float
+    num_warmup_steps: int
 
 
 @json_schema_type
-class TrainingConfig(BaseModel):
-    n_epochs: int
-    batch_size: int
-    shuffle: bool
-    n_iters: int
-
-    enable_activation_checkpointing: bool
-    memory_efficient_fsdp_wrap: bool
-    fsdp_cpu_offload: bool
+class EfficiencyConfig(BaseModel):
+    enable_activation_checkpointing: Optional[bool] = False
+    enable_activation_offloading: Optional[bool] = False
+    memory_efficient_fsdp_wrap: Optional[bool] = False
+    fsdp_cpu_offload: Optional[bool] = False
 
 
 @json_schema_type
-class FinetuningAlgorithm(Enum):
-    full = "full"
-    lora = "lora"
-    qlora = "qlora"
-    dora = "dora"
+class TrainingConfig(BaseModel):
+    n_epochs: int
+    max_steps_per_epoch: int
+    gradient_accumulation_steps: int
+    data_config: DataConfig
+    optimizer_config: OptimizerConfig
+    efficiency_config: Optional[EfficiencyConfig] = None
+    dtype: Optional[str] = "bf16"
 
 
 @json_schema_type
@@ -59,16 +69,19 @@ class LoraFinetuningConfig(BaseModel):
     apply_lora_to_output: bool
     rank: int
     alpha: int
+    type: Literal["LoRA"] = "LoRA"
+    use_dora: Optional[bool] = False
+    quantize_base: Optional[bool] = False
 
 
 @json_schema_type
-class QLoraFinetuningConfig(LoraFinetuningConfig):
-    pass
+class QATFinetuningConfig(BaseModel):
+    type: Literal["QAT"] = "QAT"
+    quantizer_name: str
+    group_size: int
 
 
-@json_schema_type
-class DoraFinetuningConfig(LoraFinetuningConfig):
-    pass
+AlgorithmConfig = Annotated[
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
+]
 
 
 @json_schema_type
@@ -100,29 +113,6 @@ class DPOAlignmentConfig(BaseModel):
     gamma: float
 
 
-@json_schema_type
-class PostTrainingSFTRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-
-    model: str
-    dataset_id: str
-    validation_dataset_id: str
-
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Union[
-        LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
-    ]
-
-    optimizer_config: OptimizerConfig
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 @json_schema_type
 class PostTrainingRLHFRequest(BaseModel):
     """Request to finetune a model."""
@@ -135,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel):
     validation_dataset_id: str
 
     algorithm: RLHFAlgorithm
-    algorithm_config: Union[DPOAlignmentConfig]
+    algorithm_config: DPOAlignmentConfig
 
     optimizer_config: OptimizerConfig
     training_config: TrainingConfig
 
@@ -177,53 +167,49 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
-    def supervised_fine_tune(
+    async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: Union[
-            LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
-        ],
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        checkpoint_dir: Optional[str] = None,
+        algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize")
-    def preference_optimize(
+    async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
-        algorithm_config: Union[DPOAlignmentConfig],
-        optimizer_config: OptimizerConfig,
+        finetuned_model: str,
+        algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
     # sends SSE stream of logs
     @webmethod(route="/post-training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...
 
     @webmethod(route="/post-training/job/status")
-    def get_training_job_status(
+    async def get_training_job_status(
         self, job_uuid: str
     ) -> PostTrainingJobStatusResponse: ...
 
     @webmethod(route="/post-training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None: ...
 
     @webmethod(route="/post-training/job/artifacts")
-    def get_training_job_artifacts(
+    async def get_training_job_artifacts(
        self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
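For orientation, here is a rough sketch of how a caller would exercise the new `supervised_fine_tune` signature once a provider is wired up. This mirrors the integration test added later in this PR; `impl` stands in for an already-resolved `PostTraining` provider, and the hyperparameter values are just the test's defaults:

```python
from llama_stack.apis.post_training import (
    DataConfig,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)


async def kick_off_sft(impl):
    # impl: a resolved PostTraining provider (e.g. the torchtune one below)
    training_config = TrainingConfig(
        n_epochs=1,
        max_steps_per_epoch=10,
        gradient_accumulation_steps=1,
        data_config=DataConfig(dataset_id="alpaca", batch_size=1, shuffle=False),
        optimizer_config=OptimizerConfig(
            optimizer_type=OptimizerType.adamw,
            lr=3e-4,
            weight_decay=0.1,
            num_warmup_steps=100,
        ),
    )
    job = await impl.supervised_fine_tune(
        job_uuid="1234",
        model="Llama3.2-3B-Instruct",
        algorithm_config=LoraFinetuningConfig(
            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
            apply_lora_to_mlp=True,
            apply_lora_to_output=False,
            rank=8,
            alpha=16,
        ),
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
    )
    return job.job_uuid
```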
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 9b3812e9ee..4541b01eb2 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -24,6 +24,7 @@
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
+from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
@@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.scoring_functions: ScoringFunctions,
         Api.eval: Eval,
         Api.eval_tasks: EvalTasks,
+        Api.post_training: PostTraining,
     }
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 080204e450..25c967812f 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -28,6 +28,7 @@ class Api(Enum):
     datasetio = "datasetio"
     scoring = "scoring"
     eval = "eval"
+    post_training = "post_training"
 
     telemetry = "telemetry"
diff --git a/llama_stack/providers/inline/post_training/torchtune/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py
new file mode 100644
index 0000000000..7ef8eee013
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import TorchtunePostTrainingConfig
+
+# The post_training API and the torchtune provider are still experimental and under heavy development
+
+
+async def get_provider_impl(
+    config: TorchtunePostTrainingConfig,
+    deps: Dict[Api, ProviderSpec],
+):
+    from .post_training import TorchtunePostTrainingImpl
+
+    impl = TorchtunePostTrainingImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+    )
+    return impl
diff --git a/llama_stack/providers/inline/post_training/torchtune/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py
new file mode 100644
index 0000000000..3ffa55c707
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/config.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class TorchtunePostTrainingConfig(BaseModel):
+    torch_seed: Optional[int] = None
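Taken together with the registry entry added later in this PR, the `get_provider_impl` hook above means the resolver can construct this provider roughly as follows. A minimal sketch, assuming `datasetio_impl` and `datasets_impl` are already-resolved dependency implementations:

```python
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.post_training.torchtune import (
    TorchtunePostTrainingConfig,
    get_provider_impl,
)


async def build_post_training_impl(datasetio_impl, datasets_impl):
    # torch_seed is the provider's only config knob right now
    config = TorchtunePostTrainingConfig(torch_seed=42)
    deps = {Api.datasetio: datasetio_impl, Api.datasets: datasets_impl}
    return await get_provider_impl(config, deps)
```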
diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
new file mode 100644
index 0000000000..1f91dc73ff
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, List, Mapping
+
+import numpy as np
+
+from torch.utils.data import Dataset
+from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
+from torchtune.data._messages import validate_messages
+from torchtune.modules.transforms import Transform
+
+
+class SFTDataset(Dataset):
+    def __init__(
+        self,
+        rows: List[Dict[str, Any]],
+        message_transform: Transform,
+        model_transform: Transform,
+    ) -> None:
+        self._rows = rows
+        self._message_transform = message_transform
+        self._model_transform = model_transform
+
+    def __len__(self):
+        return len(self._rows)
+
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        sample = self._rows[index]
+        return self._prepare_sample(sample)
+
+    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+        transformed_sample = self._message_transform(sample)
+        if "messages" in transformed_sample:
+            validate_messages(transformed_sample["messages"])
+
+        tokenized_dict = self._model_transform(transformed_sample)
+
+        if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
+            keys_str = ", ".join(tokenized_dict.keys())
+            error_message = (
+                "model_transform returned the following keys: "
+                f"{keys_str}. Must return 'tokens' and 'mask' as keys."
+            )
+            raise ValueError(error_message)
+
+        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
+        tokenized_dict["labels"] = list(
+            np.where(
+                tokenized_dict["mask"],
+                CROSS_ENTROPY_IGNORE_IDX,
+                tokenized_dict["tokens"],
+            )
+        )
+        assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])
+
+        return tokenized_dict
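The `labels` construction in `_prepare_sample` is the core of the SFT masking: wherever the tokenizer's mask is `True` (prompt tokens), the label becomes torchtune's ignore index (-100), so those positions contribute nothing to the cross-entropy loss. A toy illustration, with made-up token ids:

```python
import numpy as np

CROSS_ENTROPY_IGNORE_IDX = -100  # torchtune's cross-entropy ignore index

tokens = [42, 77, 105, 9, 2]              # hypothetical token ids
mask = [True, True, False, False, False]  # True = prompt, False = completion

# Same np.where as in _prepare_sample: mask out prompt positions
labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
assert labels == [-100, -100, 105, 9, 2]
```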
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
new file mode 100644
index 0000000000..1987086e10
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
+)
+from llama_stack.apis.post_training import *  # noqa
+from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
+    LoraFinetuningSingleDevice,
+)
+
+
+class TorchtunePostTrainingImpl:
+    def __init__(
+        self,
+        config: TorchtunePostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets
+
+    async def supervised_fine_tune(
+        self,
+        job_uuid: str,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+    ) -> PostTrainingJob:
+        if isinstance(algorithm_config, LoraFinetuningConfig):
+            recipe = LoraFinetuningSingleDevice(
+                self.config,
+                training_config,
+                hyperparam_search_config,
+                logger_config,
+                model,
+                checkpoint_dir,
+                algorithm_config,
+                self.datasetio_api,
+                self.datasets_api,
+            )
+            await recipe.setup()
+            await recipe.train()
+        else:
+            raise NotImplementedError()
+
+        return PostTrainingJob(job_uuid=job_uuid)
+
+    async def preference_optimize(
+        self,
+        job_uuid: str,
+        finetuned_model: str,
+        algorithm_config: DPOAlignmentConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+    ) -> PostTrainingJob: ...
+
+    # TODO @SLR722 implement the APIs below
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
+
+    # sends SSE stream of logs
+    @webmethod(route="/post-training/job/logs")
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...
+
+    @webmethod(route="/post-training/job/status")
+    async def get_training_job_status(
+        self, job_uuid: str
+    ) -> PostTrainingJobStatusResponse: ...
+
+    @webmethod(route="/post-training/job/cancel")
+    async def cancel_training_job(self, job_uuid: str) -> None: ...
+
+    @webmethod(route="/post-training/job/artifacts")
+    async def get_training_job_artifacts(
+        self, job_uuid: str
+    ) -> PostTrainingJobArtifactsResponse: ...
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
new file mode 100644
index 0000000000..7873c7c6f5
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -0,0 +1,506 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+import os
+import time
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from llama_models.sku_list import resolve_model
+from llama_stack.apis.datasetio import DatasetIO
+from torch import nn
+from torchtune import utils as torchtune_utils
+from torchtune.training.metric_logging import DiskLogger
+from llama_stack.apis.post_training import *  # noqa
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
+)
+from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training
+from torchtune.data import AlpacaToMessages, padded_collate_sft
+
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    load_dora_magnitudes,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+
+log = logging.getLogger(__name__)
+
+
+class LoraFinetuningSingleDevice:
+    # This recipe only supports GPU training.
+
+    # This recipe doesn't include some training-efficiency settings from the
+    # original torchtune recipe, e.g. compile.
+
+    # Resuming from a checkpoint isn't supported yet.
+    # Validation isn't supported yet.
+
+    # Currently we only log a limited set of training metrics to local disk;
+    # we will figure out richer logging and how it integrates with telemetry
+    # in future PRs.
+    def __init__(
+        self,
+        config: TorchtunePostTrainingConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+    ) -> None:
+        self.training_config = training_config
+        self.algorithm_config = algorithm_config
+        self._device = torchtune_utils.get_device(device="cuda")
+        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
+        self.model_id = model
+
+        def model_checkpoint_dir(model) -> str:
+            checkpoint_dir = Path(model_local_dir(model.descriptor()))
+
+            paths = [
+                Path(checkpoint_dir / f"consolidated.{ext}")
+                for ext in ["pth", "00.pth"]
+            ]
+            if not any(p.exists() for p in paths):
+                checkpoint_dir = checkpoint_dir / "original"
+
+            assert checkpoint_dir.exists(), (
+                f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
+                f"Please download the model using `llama download --model-id {model.descriptor()}`"
+            )
+            return str(checkpoint_dir)
+
+        if checkpoint_dir and checkpoint_dir != "null":
+            self.checkpoint_dir = checkpoint_dir
+        else:
+            model = resolve_model(self.model_id)
+            self.checkpoint_dir = model_checkpoint_dir(model)
+
+        # TODO @SLR722 make it work with get_training_job_artifacts
+        self._output_dir = self.checkpoint_dir + "/post_training/"
+
+        self.seed = training.set_seed(seed=config.torch_seed)
+        self.epochs_run = 0
+        self.total_epochs = training_config.n_epochs
+        self._shuffle = training_config.data_config.shuffle
+        self._batch_size = training_config.data_config.batch_size
+
+        # this is important for debugging purposes
+        self.max_steps_per_epoch = training_config.max_steps_per_epoch
+        self.global_step = 0
+
+        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+
+        self._clip_grad_norm = 1.0
+        self._enable_activation_checkpointing = (
+            (training_config.efficiency_config.enable_activation_checkpointing)
+            if training_config.efficiency_config
+            else False
+        )
+        self._enable_activation_offloading = (
+            (training_config.efficiency_config.enable_activation_offloading)
+            if training_config.efficiency_config
+            else False
+        )
+
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+
+    async def load_checkpoint(self):
+        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
+            try:
+                # List all files in the given directory
+                files = os.listdir(checkpoint_dir)
+                # Filter files that end with .pth
+                pth_files = [file for file in files if file.endswith(".pth")]
+                return pth_files
+            except FileNotFoundError:
+                raise FileNotFoundError(
+                    f"Checkpoint directory '{checkpoint_dir}' does not exist."
+                ) from None
+
+        self._checkpointer = training.FullModelMetaCheckpointer(
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
+            output_dir=self._output_dir,
+            model_type=await utils.get_checkpointer_model_type(self.model_id),
+        )
+        checkpoint_dict = self._checkpointer.load_checkpoint()
+        return checkpoint_dict
+
+    async def setup(self) -> None:
+        self._metric_logger = DiskLogger(log_dir=self._output_dir)
+
+        checkpoint_dict = await self.load_checkpoint()
+
+        self._model = await self._setup_model(
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
+            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
+            lora_weights_state_dict=None,
+        )
+        log.info(f"Model is initialized with precision {self._dtype}.")
+
+        self._tokenizer = await self._setup_tokenizer()
+        log.info("Tokenizer is initialized.")
+
+        self._optimizer = await self._setup_optimizer(
+            optimizer_config=self.training_config.optimizer_config
+        )
+        log.info("Optimizer is initialized.")
+
+        self._loss_fn = CEWithChunkedOutputLoss()
+        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+        log.info("Loss is initialized.")
+
+        self._sampler, self._dataloader = await self._setup_data(
+            tokenizer=self._tokenizer,
+            shuffle=self._shuffle,
+            batch_size=self._batch_size,
+        )
+        log.info("Dataset and Sampler are initialized.")
+
+        # The number of training steps in each epoch depends on the number of
+        # batches produced by the dataloader and the max_steps_per_epoch param
+        # set by the user, and is used for logging and tracking training state.
+        # It should be computed after the dataloader has been set up.
+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
+            self._steps_per_epoch = self.max_steps_per_epoch
+        self.global_step = self.epochs_run * self._steps_per_epoch
+
+        # The learning rate scheduler can only be set up after the number of
+        # steps has been computed
+        self._lr_scheduler = await self._setup_lr_scheduler(
+            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+        log.info("Learning rate scheduler is initialized.")
+
+        # Used to ignore labels for loss computation
+        self.ignore_labels_cache = torch.full(
+            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
+        )
+
+    async def _setup_model(
+        self,
+        enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        base_model_state_dict: Dict[str, Any],
+        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> nn.Module:
+        self._lora_rank = self.algorithm_config.rank
+        self._lora_alpha = self.algorithm_config.alpha
+        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
+        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
+        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
+        self._use_dora = self.algorithm_config.use_dora or False
+
+        with training.set_default_dtype(self._dtype), self._device:
+            model_type = await utils.get_model_definition(self.model_id)
+            model = model_type(
+                lora_attn_modules=self._lora_attn_modules,
+                apply_lora_to_mlp=self._apply_lora_to_mlp,
+                apply_lora_to_output=self._apply_lora_to_output,
+                lora_rank=self._lora_rank,
+                lora_alpha=self._lora_alpha,
+                quantize_base=False,
+                use_dora=self._use_dora,
+            )
+
+        self.adapter_params = get_adapter_params(model)
+        self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
+
+        set_trainable_params(model, self.adapter_params)
+
+        if enable_activation_checkpointing:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
+        base_missing, base_unexpected = model.load_state_dict(
+            base_model_state_dict, strict=False
+        )
+
+        # This is for any adapters that need to be initialized after base weights
+        # have been loaded (e.g. DoRA).
+        if self._is_dora:
+            for m in model.modules():
+                if hasattr(m, "initialize_dora_magnitude"):
+                    m.initialize_dora_magnitude()
+            load_dora_magnitudes(model)
+        if lora_weights_state_dict:
+            lora_missing, lora_unexpected = model.load_state_dict(
+                lora_weights_state_dict, strict=False
+            )
+        else:
+            lora_missing, lora_unexpected = None, None
+        validate_missing_and_unexpected_for_lora(
+            lora_attn_modules=self._lora_attn_modules,
+            apply_lora_to_mlp=self._apply_lora_to_mlp,
+            apply_lora_to_output=self._apply_lora_to_output,
+            base_missing=base_missing,
+            base_unexpected=base_unexpected,
+            lora_missing=lora_missing,
+            lora_unexpected=lora_unexpected,
+        )
+
+        # Validate model adapter params were loaded in with the expected dtype
+        training.validate_expected_param_dtype(
+            self.adapter_params.items(), dtype=self._dtype
+        )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        memory_stats = training.get_memory_stats(device=self._device)
+        training.log_memory_stats(memory_stats)
+
+        return model
+
+    async def _setup_tokenizer(
+        self,
+    ) -> Llama3Tokenizer:
+        tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
+        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
+        return tokenizer_type(path=tokenizer_path)
+
+    async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
+        optimizer = torch.optim.AdamW(
+            params=self._model.parameters(),
+            lr=optimizer_config.lr,
+            betas=(0.9, 0.95),
+            eps=1e-8,
+            weight_decay=optimizer_config.weight_decay,
+        )
+        return optimizer
+
+    async def _setup_data(
+        self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
+    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataset_id = self.training_config.data_config.dataset_id
+
+        async def fetch_rows():
+            return await self.datasetio_api.get_rows_paginated(
+                dataset_id=dataset_id,
+                rows_in_page=-1,
+            )
+
+        all_rows = await fetch_rows()
+        rows = all_rows.rows
+
+        # Currently only supports the alpaca instruct dataset
+        # TODO @SLR722 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
+        await utils.validate_input_dataset_schema(
+            datasets_api=self.datasets_api,
+            dataset_id=dataset_id,
+            dataset_type="alpaca",
+        )
+        ds = SFTDataset(
+            rows,
+            message_transform=AlpacaToMessages(
+                train_on_input=self.training_config.data_config.train_on_input
+            ),
+            model_transform=tokenizer,
+        )
+
+        sampler = DistributedSampler(
+            ds,
+            num_replicas=1,
+            rank=0,
+            shuffle=shuffle,
+            seed=0,
+        )
+        dataloader = DataLoader(
+            dataset=ds,
+            sampler=sampler,
+            batch_size=batch_size,
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+            ),
+        )
+
+        return sampler, dataloader
+
+    async def _setup_lr_scheduler(
+        self,
+        num_warmup_steps: int,
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> torch.optim.lr_scheduler.LambdaLR:
+        lr_scheduler = get_cosine_schedule_with_warmup(
+            self._optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+        return lr_scheduler
+
+    async def save_checkpoint(self, epoch: int) -> None:
+        ckpt_dict = {}
+
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
+        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
+
+        # Construct the full state dict with LoRA weights merged into base LLM weights
+        # Move to CPU to avoid a copy on GPU
+        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
+
+        merged_state_dict = get_merged_lora_ckpt(
+            state_dict,
+            rank=self._lora_rank,
+            alpha=self._lora_alpha,
+        )
+
+        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
+
+        adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
+        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
+
+        self._checkpointer.save_checkpoint(
+            ckpt_dict,
+            epoch=epoch,
+        )
+
+    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+        # Shape [b, s], needed for the loss, not the model
+        labels = batch.pop("labels")
+        # run model
+        with self.activations_handling_ctx:
+            logits = self._model(**batch)
+
+        # Shift labels to compute loss: equivalent to labels[..., 1:] and
+        # logits[..., :-1, :], but this way we don't need to slice the logits;
+        # we just append an ignore-index column to the labels.
+        labels = torch.hstack(
+            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+        )
+        if not isinstance(logits, list):
+            labels = labels.reshape(-1)
+            logits = logits.reshape(-1, logits.size(-1))
+
+        loss = self._loss_fn(logits, labels)
+
+        # free logits, otherwise they peak backward memory
+        del logits
+
+        return loss
+
+    async def train(self) -> None:
+        """
+        The core training loop.
+        """
+        # Initialize token count and running loss (for grad accumulation)
+        t0 = time.perf_counter()
+        running_loss = 0
+        num_tokens = 0
+
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
+        for curr_epoch in range(self.epochs_run, self.total_epochs):
+            # Update the sampler to ensure data is correctly shuffled across epochs
+            # in case shuffle is True
+            self._sampler.set_epoch(curr_epoch)
+
+            for idx, batch in enumerate(self._dataloader):
+                if (
+                    self.max_steps_per_epoch is not None
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
+                ):
+                    break
+
+                torchtune_utils.batch_to_device(batch, self._device)
+
+                # Calculate the number of unmasked tokens in the current batch
+                # and increment the total number of tokens seen in the step
+                current_num_tokens = (
+                    batch["labels"] != self._loss_fn.ignore_index
+                ).sum()
+                num_tokens += current_num_tokens
+
+                # Loss is normalized by default, so we multiply by the number of
+                # tokens. This way we can normalize by the total number of tokens
+                # if we're accumulating gradients.
+                current_loss = await self._loss_step(batch) * current_num_tokens
+                running_loss += current_loss
+                current_loss.backward()
+
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    training.scale_grads(self._model, 1 / num_tokens)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                        self._model.parameters(),
+                        max_norm=float(self._clip_grad_norm),
+                    )
+                    self._optimizer.step()
+                    self._optimizer.zero_grad(set_to_none=True)
+                    self._lr_scheduler.step()
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.item() / num_tokens
+                    time_per_step = time.perf_counter() - t0
+                    log_dict = {
+                        "loss": loss_to_log,
+                        "lr": self._optimizer.param_groups[0]["lr"],
+                        "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                    }
+                    log_dict.update(training.get_memory_stats(device=self._device))
+                    if self._clip_grad_norm is not None:
+                        log_dict.update({"grad_norm": grad_norm})
+                    self._metric_logger.log_dict(
+                        log_dict,
+                        step=self.global_step,
+                    )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+            self.epochs_run += 1
+            log.info("Starting checkpoint save...")
+            await self.save_checkpoint(epoch=curr_epoch)
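One subtlety in `train()` worth calling out: because `_loss_step` returns a token-mean loss, the loop re-multiplies it by the unmasked-token count before accumulating, and gradients are later rescaled by `1 / num_tokens`. This weights every token equally across micro-batches of different lengths. A toy sketch of the arithmetic (the numbers are invented):

```python
# Two accumulated micro-batches with different unmasked-token counts.
batches = [
    {"mean_loss": 2.0, "num_tokens": 10},
    {"mean_loss": 1.0, "num_tokens": 30},
]

running_loss = 0.0
num_tokens = 0
for b in batches:
    running_loss += b["mean_loss"] * b["num_tokens"]  # undo the per-batch mean
    num_tokens += b["num_tokens"]

# After scale_grads(model, 1 / num_tokens), the effective loss is the
# token-weighted mean, not the mean of the two batch means:
assert running_loss / num_tokens == 1.25  # vs. (2.0 + 1.0) / 2 == 1.5
```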
diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py
new file mode 100644
index 0000000000..462cbc21ed
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/utils.py
@@ -0,0 +1,139 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Callable, Dict, List
+
+import torch
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.common.type_system import *  # noqa
+from llama_models.datatypes import Model
+from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.type_system import ParamType
+from pydantic import BaseModel
+
+from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_2 import lora_llama3_2_3b
+
+
+class ColumnName(Enum):
+    instruction = "instruction"
+    input = "input"
+    output = "output"
+    text = "text"
+
+
+class ModelConfig(BaseModel):
+    model_definition: Any
+    tokenizer_type: Any
+    checkpoint_type: str
+
+
+class DatasetSchema(BaseModel):
+    alpaca: List[Dict[str, ParamType]]
+
+
+MODEL_CONFIGS: Dict[str, ModelConfig] = {
+    "Llama3.2-3B-Instruct": ModelConfig(
+        model_definition=lora_llama3_2_3b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3_2",
+    ),
+    "Llama-3-8B-Instruct": ModelConfig(
+        model_definition=lora_llama3_8b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3",
+    ),
+}
+
+
+EXPECTED_DATASET_SCHEMA = DatasetSchema(
+    alpaca=[
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+            ColumnName.text.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+    ]
+)
+
+BuildLoraModelCallable = Callable[..., torch.nn.Module]
+BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
+
+
+def _validate_model_id(model_id: str) -> Model:
+    model = resolve_model(model_id)
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_id} is not supported.")
+    return model
+
+
+async def get_model_definition(
+    model_id: str,
+) -> BuildLoraModelCallable:
+    model = _validate_model_id(model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    if not hasattr(model_config, "model_definition"):
+        raise ValueError(f"Model {model_id} does not have a model definition.")
+    return model_config.model_definition
+
+
+async def get_tokenizer_type(
+    model_id: str,
+) -> BuildTokenizerCallable:
+    model = _validate_model_id(model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    if not hasattr(model_config, "tokenizer_type"):
+        raise ValueError(f"Model {model_id} does not have a tokenizer_type.")
+    return model_config.tokenizer_type
+
+
+async def get_checkpointer_model_type(
+    model_id: str,
+) -> str:
+    """
+    The checkpointer model type is used by the checkpointer to apply special
+    handling for specific model families; for example, the llama3.2 models use
+    tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041).
+    """
+    model = _validate_model_id(model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    if not hasattr(model_config, "checkpoint_type"):
+        raise ValueError(f"Model {model_id} does not have a checkpoint_type.")
+    return model_config.checkpoint_type
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+        raise ValueError(f"Dataset type {dataset_type} is not supported.")
+
+    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+        raise ValueError(
+            f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
+        )
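For reference, the three accepted `alpaca` schemas above correspond to rows like the following (the strings are invented; the third, minimal schema only requires `instruction` and `output`):

```python
# Full alpaca-style row (matches the first schema).
full_row = {
    "instruction": "Summarize the passage.",
    "input": "Llama Stack standardizes the core building blocks ...",
    "output": "It standardizes building blocks for AI apps.",
    "text": "Below is an instruction that describes a task ...",
}

# Minimal row (matches the third schema).
minimal_row = {
    "instruction": "Write a haiku about GPUs.",
    "output": "Warm silicon hums ...",
}
```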
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
new file mode 100644
index 0000000000..af8b660fa7
--- /dev/null
+++ b/llama_stack/providers/registry/post_training.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.post_training,
+            provider_type="inline::torchtune",
+            pip_packages=["torch", "torchtune", "torchao", "numpy"],
+            module="llama_stack.providers.inline.post_training.torchtune",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+            ],
+        ),
+    ]
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 8b73500d0f..4d7831ae3a 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -156,4 +156,5 @@ def pytest_itemcollected(item):
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
+    "llama_stack.providers.tests.post_training.fixtures",
 ]
diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py
index f0c8cbbe10..d288198ca8 100644
--- a/llama_stack/providers/tests/datasetio/fixtures.py
+++ b/llama_stack/providers/tests/datasetio/fixtures.py
@@ -10,6 +10,7 @@
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.tests.resolver import construct_stack_for_test
+
 from ..conftest import ProviderFixture, remote_stack_fixture
diff --git a/llama_stack/providers/tests/post_training/__init__.py b/llama_stack/providers/tests/post_training/__init__.py
new file mode 100644
index 0000000000..756f351d88
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/tests/post_training/conftest.py b/llama_stack/providers/tests/post_training/conftest.py
new file mode 100644
index 0000000000..14d349106b
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/conftest.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+
+from ..datasetio.fixtures import DATASETIO_FIXTURES
+
+from .fixtures import POST_TRAINING_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "post_training": "torchtune",
+            "datasetio": "huggingface",
+        },
+        id="torchtune_post_training_huggingface_datasetio",
+        marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
+    ),
+]
+
+
+def pytest_configure(config):
+    combined_fixtures = "torchtune_post_training_huggingface_datasetio"
+    config.addinivalue_line(
+        "markers",
+        f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "post_training_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "post_training": POST_TRAINING_FIXTURES,
+            "datasetio": DATASETIO_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("post_training_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py
new file mode 100644
index 0000000000..3ca48d847f
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/fixtures.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasets import DatasetInput
+from llama_stack.apis.models import ModelInput
+
+from llama_stack.distribution.datatypes import Api, Provider
+
+from llama_stack.providers.tests.resolver import construct_stack_for_test
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def post_training_torchtune() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="torchtune",
+                provider_type="inline::torchtune",
+                config={},
+            )
+        ],
+    )
+
+
+POST_TRAINING_FIXTURES = ["torchtune"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def post_training_stack(request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["post_training", "datasetio"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+
+    test_stack = await construct_stack_for_test(
+        [Api.post_training, Api.datasetio],
+        providers,
+        provider_data,
+        models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
+        datasets=[
+            DatasetInput(
+                dataset_id="alpaca",
+                provider_id="huggingface",
+                url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
+                metadata={
+                    "path": "tatsu-lab/alpaca",
+                    "split": "train",
+                },
+                dataset_schema={
+                    "instruction": StringType(),
+                    "input": StringType(),
+                    "output": StringType(),
+                    "text": StringType(),
+                },
+            ),
+        ],
+    )
+
+    return test_stack.impls[Api.post_training]
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py
new file mode 100644
index 0000000000..a4e2d55c9a
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/test_post_training.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.post_training import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+# How to run this test:
+#
+# pytest llama_stack/providers/tests/post_training/test_post_training.py
+#   -m "torchtune_post_training_huggingface_datasetio"
+#   -v -s --tb=short --disable-warnings
+
+
+class TestPostTraining:
+    @pytest.mark.asyncio
+    async def test_supervised_fine_tune(self, post_training_stack):
+        algorithm_config = LoraFinetuningConfig(
+            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
+            apply_lora_to_mlp=True,
+            apply_lora_to_output=False,
+            rank=8,
+            alpha=16,
+        )
+
+        data_config = DataConfig(
+            dataset_id="alpaca",
+            batch_size=1,
+            shuffle=False,
+        )
+
+        optimizer_config = OptimizerConfig(
+            optimizer_type="adamw",
+            lr=3e-4,
+            weight_decay=0.1,
+            num_warmup_steps=100,
+        )
+
+        training_config = TrainingConfig(
+            n_epochs=1,
+            data_config=data_config,
+            optimizer_config=optimizer_config,
+            max_steps_per_epoch=1,
+            gradient_accumulation_steps=1,
+        )
+        post_training_impl = post_training_stack
+        response = await post_training_impl.supervised_fine_tune(
+            job_uuid="1234",
+            model="Llama3.2-3B-Instruct",
+            algorithm_config=algorithm_config,
+            training_config=training_config,
+            hyperparam_search_config={},
+            logger_config={},
+            checkpoint_dir="null",
+        )
+        assert isinstance(response, PostTrainingJob)
+        assert response.job_uuid == "1234"
diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml
new file mode 100644
index 0000000000..1461d05961
--- /dev/null
+++ b/llama_stack/templates/experimental-post-training/build.yaml
@@ -0,0 +1,13 @@
+version: '2'
+name: experimental-post-training
+distribution_spec:
+  description: Experimental template for post training
+  docker_image: null
+  providers:
+    post_training:
+    - inline::torchtune
+    datasetio:
+    - remote::huggingface
+    telemetry:
+    - inline::meta-reference
+image_type: conda
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
new file mode 100644
index 0000000000..4bdde7aa68
--- /dev/null
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -0,0 +1,53 @@
+version: '2'
+image_name: experimental-post-training
+docker_image: null
+conda_env: experimental-post-training
+apis:
+- telemetry
+- datasetio
+- post_training
+providers:
+  datasetio:
+  - provider_id: huggingface-0
+    provider_type: remote::huggingface
+    config: {}
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  post_training:
+  - provider_id: torchtune-post-training
+    provider_type: inline::torchtune
+    config: {}
+
+metadata_store:
+  namespace: null
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
+models:
+- metadata: {}
+  model_id: ${env.POST_TRAINING_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
+shields: []
+memory_banks: []
+datasets:
+  - dataset_id: alpaca
+    provider_id: huggingface-0
+    url:
+      uri: https://huggingface.co/datasets/tatsu-lab/alpaca
+    metadata:
+      path: tatsu-lab/alpaca
+      name:
+      split: train
+    dataset_schema:
+      instruction:
+        type: string
+      input:
+        type: string
+      output:
+        type: string
+      text:
+        type: string
+scoring_fns: []
+eval_tasks: []