[1/n] torchtune <> llama-stack integration skeleton #540

Merged
merged 21 commits on Dec 13, 2024
120 changes: 53 additions & 67 deletions llama_stack/apis/post_training/post_training.py
@@ -6,50 +6,60 @@

from datetime import datetime
from enum import Enum

from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, Union

from llama_models.schema_utils import json_schema_type, webmethod

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403


@json_schema_type
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"


@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False


@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
num_warmup_steps: int


@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
batch_size: int
shuffle: bool
n_iters: int

enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False


@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
class TrainingConfig(BaseModel):
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
data_config: DataConfig
optimizer_config: OptimizerConfig
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
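
For reference, a hedged sketch of how a caller might assemble the reworked nested config. It assumes `llama_stack.apis.post_training` re-exports these models; the dataset id and hyperparameter values are placeholders, not recommendations.

```python
# Illustrative only: building the nested TrainingConfig introduced in this PR.
from llama_stack.apis.post_training import (
    DataConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)

training_config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    data_config=DataConfig(
        dataset_id="my-sft-dataset",  # placeholder id registered via the datasets API
        batch_size=4,
        shuffle=True,
    ),
    optimizer_config=OptimizerConfig(
        optimizer_type=OptimizerType.adamw,
        lr=3e-4,
        lr_min=3e-5,
        weight_decay=0.01,
        num_warmup_steps=100,
    ),
    # efficiency_config defaults to None and dtype to "bf16"
)
```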


@json_schema_type
@@ -59,16 +69,19 @@ class LoraFinetuningConfig(BaseModel):
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False


@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
class QATFinetuningConfig(BaseModel):
quantizer_name: str
group_size: int


@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]
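
The `Annotated[..., Field(discriminator="type")]` form is pydantic's tagged-union pattern: it assumes each variant declares a `Literal` "type" field whose value selects the concrete model during validation. A generic, self-contained sketch of the pattern (stand-in classes and the pydantic v2 API, not this PR's models):

```python
# Generic sketch of a pydantic discriminated union; LoraCfg/QATCfg are stand-ins.
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class LoraCfg(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int = 8


class QATCfg(BaseModel):
    type: Literal["QAT"] = "QAT"
    group_size: int = 32


AlgoCfg = Annotated[Union[LoraCfg, QATCfg], Field(discriminator="type")]

# During validation, the "type" tag picks the concrete variant:
cfg = TypeAdapter(AlgoCfg).validate_python({"type": "QAT", "group_size": 64})
assert isinstance(cfg, QATCfg)
```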


@json_schema_type
@@ -100,29 +113,6 @@ class DPOAlignmentConfig(BaseModel):
gamma: float


@json_schema_type
class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""

job_uuid: str

model: str
dataset_id: str
validation_dataset_id: str

algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]

optimizer_config: OptimizerConfig
training_config: TrainingConfig

# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]


@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
@@ -135,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel):
validation_dataset_id: str

algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig]
algorithm_config: DPOAlignmentConfig

optimizer_config: OptimizerConfig
training_config: TrainingConfig
@@ -177,53 +167,49 @@ class PostTrainingJobArtifactsResponse(BaseModel):

class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune(
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...

@webmethod(route="/post-training/preference-optimize")
def preference_optimize(
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: URL,
dataset_id: str,
validation_dataset_id: str,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...

@webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
async def get_training_jobs(self) -> List[PostTrainingJob]: ...

# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...

@webmethod(route="/post-training/job/status")
def get_training_job_status(
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...

@webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None: ...

@webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts(
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...
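
For orientation, a hedged sketch of how a caller might drive the revised supervised fine-tuning entrypoint. Here `post_training_impl` stands in for any object implementing the `PostTraining` protocol, `training_config` is a `TrainingConfig` like the one sketched earlier, and every argument value is a placeholder.

```python
# Hypothetical caller-side sketch; post_training_impl is assumed to implement
# the PostTraining protocol defined above.
from typing import Any


async def kick_off_sft(post_training_impl: Any, training_config: Any) -> None:
    job = await post_training_impl.supervised_fine_tune(
        job_uuid="job-1234",                # placeholder
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        model="Llama3.2-3B-Instruct",       # descriptor from `llama model list`
        checkpoint_dir=None,                # let the provider pick its default
        algorithm_config=None,              # or a LoraFinetuningConfig for LoRA
    )
    print(job)
```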
2 changes: 2 additions & 0 deletions llama_stack/distribution/resolver.py
@@ -24,6 +24,7 @@
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
Api.post_training: PostTraining,
}


1 change: 1 addition & 0 deletions llama_stack/providers/datatypes.py
@@ -28,6 +28,7 @@ class Api(Enum):
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
post_training = "post_training"

telemetry = "telemetry"

@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import TorchtunePostTrainingConfig

# The post_training API and the torchtune provider are still experimental and under heavy development.


async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
):
from .post_training import TorchtunePostTrainingImpl

impl = TorchtunePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl
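
To show how this entrypoint is meant to be wired, a hedged sketch of what the distribution resolver effectively does: build the provider config, then await the entrypoint with the already-resolved datasetio/datasets implementations. The import path assumes this file is the provider package's `__init__.py`; variable names are illustrative.

```python
# Hedged sketch of resolver-side wiring; datasetio_impl/datasets_impl are the
# already-resolved dependency implementations (names are illustrative).
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.torchtune import get_provider_impl
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)


async def build_torchtune_provider(datasetio_impl, datasets_impl):
    config = TorchtunePostTrainingConfig(torch_seed=42)  # seed value is arbitrary
    return await get_provider_impl(
        config,
        {Api.datasetio: datasetio_impl, Api.datasets: datasets_impl},
    )
```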
13 changes: 13 additions & 0 deletions llama_stack/providers/inline/post_training/torchtune/config.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping

import numpy as np

from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform


class SFTDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform

def __len__(self):
return len(self._rows)

def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._rows[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])

tokenized_dict = self._model_transform(transformed_sample)

if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)

# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
tokenized_dict["labels"] = list(
np.where(
tokenized_dict["mask"],
CROSS_ENTROPY_IGNORE_IDX,
tokenized_dict["tokens"],
)
)
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])

return tokenized_dict
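
To make the masking step concrete, a toy illustration of how `labels` are derived from `tokens` and `mask`; the token ids are made up, and torchtune's `CROSS_ENTROPY_IGNORE_IDX` is -100.

```python
# Toy illustration of the label-masking above: positions where mask is True
# (e.g. prompt tokens) are replaced by the ignore index so the loss skips them.
import numpy as np

CROSS_ENTROPY_IGNORE_IDX = -100  # torchtune's ignore index for cross-entropy

tokens = [128000, 9906, 1917, 0, 128001]   # made-up token ids
mask = [True, True, False, False, False]   # True = do not train on this position

labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
# Masked positions become -100 and are ignored by the loss; the rest mirror tokens.
```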