feat: Part 1 - Add FineTuning API (#201)
PEFT - LoRA for Fine-Tuning LLMs

This PR is Part 1 of a series that will integrate fine-tuning into Kaito.
It adds the base API code. Future PRs will allow you to specify a custom
dataset, load config from a ConfigMap, and upload training results as an
image to ACR.

---------

Signed-off-by: Ishaan Sehgal <[email protected]>
ishaansehgal99 committed Mar 15, 2024
1 parent ea83144 commit ec8a8e2
Showing 15 changed files with 417 additions and 15 deletions.
20 changes: 10 additions & 10 deletions .github/matrix-configs.json
@@ -3,7 +3,7 @@
"model": {
"runs_on": "self-hosted",
"name": "falcon-7b",
"dockerfile": "docker/presets/falcon/Dockerfile",
"dockerfile": "docker/presets/inference/falcon/Dockerfile",
"build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-7b"
},
"shouldBuildFalcon": "true"
@@ -12,7 +12,7 @@
"model": {
"runs_on": "self-hosted",
"name": "falcon-7b-instruct",
"dockerfile": "docker/presets/falcon/Dockerfile",
"dockerfile": "docker/presets/inference/falcon/Dockerfile",
"build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-7b-instruct"
},
"shouldBuildFalcon": "true"
@@ -22,7 +22,7 @@
"model": {
"runs_on": "self-hosted",
"name": "falcon-40b",
"dockerfile": "docker/presets/falcon/Dockerfile",
"dockerfile": "docker/presets/inference/falcon/Dockerfile",
"build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-40b"
},
"shouldBuildFalcon": "true"
@@ -32,7 +32,7 @@
"model": {
"runs_on": "self-hosted",
"name": "falcon-40b-instruct",
"dockerfile": "docker/presets/falcon/Dockerfile",
"dockerfile": "docker/presets/inference/falcon/Dockerfile",
"build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-40b-instruct"
},
"shouldBuildFalcon": "true"
@@ -42,7 +42,7 @@
"model": {
"runs_on": "self-hosted",
"name": "llama-2-7b",
"dockerfile": "docker/presets/llama-2/Dockerfile",
"dockerfile": "docker/presets/inference/llama-2/Dockerfile",
"build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-7b --build-arg SRC_DIR=/home/presets/llama-2"
},
"shouldBuildLlama2": "true"
@@ -52,7 +52,7 @@
"model": {
"runs_on": "self-hosted",
"name": "llama-2-13b",
"dockerfile": "docker/presets/llama-2/Dockerfile",
"dockerfile": "docker/presets/inference/llama-2/Dockerfile",
"build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-13b --build-arg SRC_DIR=/home/presets/llama-2"
},
"shouldBuildLlama2": "true"
@@ -62,7 +62,7 @@
"model": {
"runs_on": "self-hosted",
"name": "llama-2-70b",
"dockerfile": "docker/presets/llama-2/Dockerfile",
"dockerfile": "docker/presets/inference/llama-2/Dockerfile",
"build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-70b --build-arg SRC_DIR=/home/presets/llama-2"
},
"shouldBuildLlama2": "true"
@@ -72,7 +72,7 @@
"model": {
"runs_on": "self-hosted",
"name": "llama-2-7b-chat",
"dockerfile": "docker/presets/llama-2/Dockerfile",
"dockerfile": "docker/presets/inference/llama-2/Dockerfile",
"build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-7b-chat --build-arg SRC_DIR=/home/presets/llama-2-chat"
},
"shouldBuildLlama2Chat": "true"
@@ -82,7 +82,7 @@
"model": {
"runs_on": "self-hosted",
"name": "llama-2-13b-chat",
"dockerfile": "docker/presets/llama-2/Dockerfile",
"dockerfile": "docker/presets/inference/llama-2/Dockerfile",
"build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-13b-chat --build-arg SRC_DIR=/home/presets/llama-2-chat"
},
"shouldBuildLlama2Chat": "true"
@@ -92,7 +92,7 @@
"model": {
"runs_on": "self-hosted",
"name": "llama-2-70b-chat",
"dockerfile": "docker/presets/llama-2/Dockerfile",
"dockerfile": "docker/presets/inference/llama-2/Dockerfile",
"build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-70b-chat --build-arg SRC_DIR=/home/presets/llama-2-chat"
},
"shouldBuildLlama2Chat": "true"
2 changes: 1 addition & 1 deletion .github/workflows/kind-cluster/main.py
@@ -12,7 +12,7 @@ def get_weights_path(model_name):
return f"/datadrive/{model_name}/weights"

def get_dockerfile_path(model_runtime):
return f"/kaito/docker/presets/{model_runtime}/Dockerfile"
return f"/kaito/docker/presets/inference/{model_runtime}/Dockerfile"

def generate_unique_id():
"""Generate a unique identifier for a job."""
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
23 changes: 23 additions & 0 deletions docker/presets/tuning/Dockerfile
@@ -0,0 +1,23 @@
FROM python:3.10-slim

ARG WEIGHTS_PATH
ARG MODEL_TYPE
ARG VERSION

# Set the working directory
WORKDIR /workspace/tfs

# Write the version to a file
RUN echo $VERSION > /workspace/tfs/version.txt

# First, copy just the preset files and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
# avoid reinstalling dependencies unless the requirements file changes.
COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

COPY kaito/presets/tuning/${MODEL_TYPE}/cli.py /workspace/tfs/cli.py
COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning_api.py /workspace/tfs/tuning_api.py

# Copy the entire model weights to the weights directory
COPY ${WEIGHTS_PATH} /workspace/tfs/weights
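For context, a minimal build sketch for this tuning image, mirroring the inference-image builds in the READMEs below; the MODEL_TYPE value, weights path, tag, and registry are illustrative assumptions, and the build context is assumed to be the parent directory of the kaito checkout because the COPY paths begin with kaito/:
```
# Hypothetical build of the tuning image (all values are placeholders, not from this PR)
docker build \
  --file kaito/docker/presets/tuning/Dockerfile \
  --build-arg WEIGHTS_PATH=/path/to/downloaded/weights \
  --build-arg MODEL_TYPE=tfs \
  --build-arg VERSION=0.0.1 \
  -t <your-registry>.azurecr.io/tuning-falcon-7b:0.0.1 \
  .
```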
2 changes: 1 addition & 1 deletion pkg/controllers/workspace_controller.go
@@ -130,7 +130,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka

func (c *WorkspaceReconciler) deleteWorkspace(ctx context.Context, wObj *kaitov1alpha1.Workspace) (reconcile.Result, error) {
klog.InfoS("deleteWorkspace", "workspace", klog.KObj(wObj))
- // TODO delete workspace, machine(s), training and inference (deployment, service) obj ( ok to delete machines? which will delete nodes??)
+ // TODO delete workspace, machine(s), fine_tuning and inference (deployment, service) obj ( ok to delete machines? which will delete nodes??)
err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeDeleting, metav1.ConditionTrue, "workspaceDeleted", "workspace is being deleted")
if err != nil {
klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj))
2 changes: 1 addition & 1 deletion presets/models/llama2/README.md
@@ -29,7 +29,7 @@ export LLAMA_WEIGHTS_PATH=<path to your downloaded model weight files>
Use the following command to build the llama2 inference service image from the root of the repo.
```
docker build \
- --file docker/presets/llama-2/Dockerfile \
+ --file docker/presets/inference/llama-2/Dockerfile \
--build-arg WEIGHTS_PATH=$LLAMA_WEIGHTS_PATH \
--build-arg MODEL_TYPE=llama2-completion \
--build-arg VERSION=0.0.1 \
2 changes: 1 addition & 1 deletion presets/models/llama2chat/README.md
@@ -29,7 +29,7 @@ export LLAMA_WEIGHTS_PATH=<path to your downloaded model weight files>
Use the following command to build the llama2chat inference service image from the root of the repo.
```
docker build \
- --file docker/presets/llama-2/Dockerfile \
+ --file docker/presets/inference/llama-2/Dockerfile \
--build-arg WEIGHTS_PATH=$LLAMA_WEIGHTS_PATH \
--build-arg MODEL_TYPE=llama2-chat \
--build-arg VERSION=0.0.1 \
2 changes: 1 addition & 1 deletion presets/test/falcon-benchmark/README.md
@@ -23,7 +23,7 @@ Ensure your `accelerate` configuration aligns with the values provided during be
- If you haven't already, you can use the Azure CLI or the Azure Portal to create and configure a GPU node pool in your AKS cluster.
<!-- markdown-link-check-disable -->
2. Building and Pushing the Docker Image:
- - First, you need to build a Docker image from the provided [Dockerfile](https://github.com/Azure/kaito/blob/main/docker/presets/tfs/Dockerfile) and push it to a container registry accessible by your AKS cluster
+ - First, you need to build a Docker image from the provided [Dockerfile](https://github.com/Azure/kaito/blob/main/docker/presets/inference/tfs/Dockerfile) and push it to a container registry accessible by your AKS cluster
<!-- markdown-link-check-enable -->
- Example:
```
103 changes: 103 additions & 0 deletions presets/test/manifests/tuning/falcon/falcon-7b.yaml
@@ -0,0 +1,103 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: falcon-7b-tuning
spec:
replicas: 1
selector:
matchLabels:
app: falcon
template:
metadata:
labels:
app: falcon
spec:
containers:
- name: falcon-container
image: aimodelsregistrytest.azurecr.io/tuning-falcon-7b:0.0.1
command: ["/bin/sh", "-c", "sleep infinity"]
resources:
requests:
nvidia.com/gpu: 2
limits:
nvidia.com/gpu: 2 # Requesting 2 GPUs
volumeMounts:
- name: dshm
mountPath: /dev/shm
- name: workspace
mountPath: /workspace

- name: docker-sidecar
image: docker:dind
securityContext:
privileged: true # Allows container to manage its own containers
volumeMounts:
- name: workspace
mountPath: /workspace
env:
- name: ACR_USERNAME
value: "{{ACR_USERNAME}}"
- name: ACR_PASSWORD
value: "{{ACR_PASSWORD}}"
- name: TAG
value: "{{TAG}}"
command: ["/bin/sh"]
args:
- -c
- |
# Start the Docker daemon in the background with specific options for DinD
dockerd &
# Wait for the Docker daemon to be ready
while ! docker info > /dev/null 2>&1; do
echo "Waiting for Docker daemon to start..."
sleep 1
done
echo 'Docker daemon started'
while true; do
FILE_PATH=$(find /workspace/tfs -name 'fine_tuning_completed.txt')
if [ ! -z "$FILE_PATH" ]; then
echo "FOUND TRAINING COMPLETED FILE at $FILE_PATH"
PARENT_DIR=$(dirname "$FILE_PATH")
echo "Parent directory is $PARENT_DIR"
TEMP_CONTEXT=$(mktemp -d)
cp "$PARENT_DIR/adapter_config.json" "$TEMP_CONTEXT/adapter_config.json"
cp -r "$PARENT_DIR/adapter_model.safetensors" "$TEMP_CONTEXT/adapter_model.safetensors"
# Create a minimal Dockerfile
echo 'FROM scratch
ADD adapter_config.json /
ADD adapter_model.safetensors /' > "$TEMP_CONTEXT/Dockerfile"
# Login to Docker registry
echo $ACR_PASSWORD | docker login $ACR_USERNAME.azurecr.io -u $ACR_USERNAME --password-stdin
docker build -t $ACR_USERNAME.azurecr.io/adapter-falcon-7b:$TAG "$TEMP_CONTEXT"
docker push $ACR_USERNAME.azurecr.io/adapter-falcon-7b:$TAG
# Cleanup: Remove the temporary directory
rm -rf "$TEMP_CONTEXT"
# Remove the file to prevent repeated builds, or handle as needed
# rm "$FILE_PATH"
fi
sleep 10 # Check every 10 seconds
done
volumes:
- name: dshm
emptyDir:
medium: Memory
- name: workspace
emptyDir: {}

tolerations:
- effect: NoSchedule
key: sku
operator: Equal
value: gpu
- effect: NoSchedule
key: nvidia.com/gpu
operator: Exists
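One way to render the {{ACR_USERNAME}}, {{ACR_PASSWORD}}, and {{TAG}} placeholders and deploy this manifest — a minimal sketch assuming a bash shell with those values exported; the substitution step the test harness actually uses is not part of this diff:
```
# Illustrative only: fill in the template placeholders and apply the tuning deployment
export ACR_USERNAME=<registry-name> ACR_PASSWORD=<registry-password> TAG=0.0.1
sed -e "s/{{ACR_USERNAME}}/$ACR_USERNAME/g" \
    -e "s/{{ACR_PASSWORD}}/$ACR_PASSWORD/g" \
    -e "s/{{TAG}}/$TAG/g" \
    presets/test/manifests/tuning/falcon/falcon-7b.yaml | kubectl apply -f -
```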
129 changes: 129 additions & 0 deletions presets/tuning/tfs/cli.py
@@ -0,0 +1,129 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

import torch
from peft import LoraConfig
from transformers import (BitsAndBytesConfig, DataCollatorForLanguageModeling,
PreTrainedTokenizer, TrainerCallback)


@dataclass
class ExtDataCollator(DataCollatorForLanguageModeling):
tokenizer: Optional[PreTrainedTokenizer] = field(default=None, metadata={"help": "Tokenizer for DataCollatorForLanguageModeling"})

@dataclass
class ExtLoraConfig(LoraConfig):
"""
Lora Config
"""
init_lora_weights: bool = field(default=True, metadata={"help": "Enable initialization of LoRA weights"})
target_modules: Optional[List[str]] = field(default=None, metadata={"help": ("List of module names to replace with LoRA.")})
layers_to_transform: Optional[List[int]] = field(default=None, metadata={"help": "Layer indices to apply LoRA"})
layers_pattern: Optional[List[str]] = field(default=None, metadata={"help": "Pattern to match layers for LoRA"})
loftq_config: Dict[str, Any] = field(default_factory=dict, metadata={"help": "LoftQ configuration for quantization"})

@dataclass
class DatasetConfig:
"""
Config for Dataset
"""
dataset_name: str = field(metadata={"help": "Name of Dataset"})
shuffle_dataset: bool = field(default=True, metadata={"help": "Whether to shuffle dataset"})
shuffle_seed: int = field(default=42, metadata={"help": "Seed for shuffling data"})
context_column: str = field(default="Context", metadata={"help": "Example human input column in the dataset"})
response_column: str = field(default="Response", metadata={"help": "Example bot response output column in the dataset"})
train_test_split: float = field(default=0.8, metadata={"help": "Split between test and training data (e.g. 0.8 means 80/20% train/test split)"})

@dataclass
class TokenizerParams:
"""
Tokenizer params
"""
add_special_tokens: bool = field(default=True, metadata={"help": ""})
padding: bool = field(default=False, metadata={"help": ""})
truncation: Optional[bool] = field(default=None, metadata={"help": ""})
max_length: Optional[int] = field(default=None, metadata={"help": ""})
stride: int = field(default=0, metadata={"help": ""})
is_split_into_words: bool = field(default=False, metadata={"help": ""})
tok_pad_to_multiple_of: Optional[int] = field(default=None, metadata={"help": ""})
tok_return_tensors: Optional[str] = field(default=None, metadata={"help": ""})
return_token_type_ids: Optional[bool] = field(default=None, metadata={"help": ""})
return_attention_mask: Optional[bool] = field(default=None, metadata={"help": ""})
return_overflowing_tokens: bool = field(default=False, metadata={"help": ""})
return_special_tokens_mask: bool = field(default=False, metadata={"help": ""})
return_offsets_mapping: bool = field(default=False, metadata={"help": ""})
return_length: bool = field(default=False, metadata={"help": ""})
verbose: bool = field(default=True, metadata={"help": ""})

@dataclass
class ModelConfig:
"""
Transformers Model Configuration Parameters
"""
pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"})
state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"})
cache_dir: Optional[str] = field(default=None, metadata={"help": "Cache directory for the model"})
from_tf: bool = field(default=False, metadata={"help": "Load model from a TensorFlow checkpoint"})
force_download: bool = field(default=False, metadata={"help": "Force the download of the model"})
resume_download: bool = field(default=False, metadata={"help": "Resume an interrupted download"})
proxies: Optional[str] = field(default=None, metadata={"help": "Proxy configuration for downloading the model"})
output_loading_info: bool = field(default=False, metadata={"help": "Output additional loading information"})
allow_remote_files: bool = field(default=False, metadata={"help": "Allow using remote files, default is local only"})
m_revision: str = field(default="main", metadata={"help": "Specific model version to use"})
trust_remote_code: bool = field(default=False, metadata={"help": "Enable trusting remote code when loading the model"})
m_load_in_4bit: bool = field(default=False, metadata={"help": "Load model in 4-bit mode"})
m_load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"})
device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"})

def __post_init__(self):
"""
Post-initialization to validate some ModelConfig values
"""
if self.torch_dtype and not hasattr(torch, self.torch_dtype):
raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None

@dataclass
class QuantizationConfig(BitsAndBytesConfig):
"""
Quantization Configuration
"""
quant_method: str = field(default="bitsandbytes", metadata={"help": "Quantization Method {bitsandbytes,gptq,awq}"})
load_in_8bit: bool = field(default=False, metadata={"help": "Enable 8-bit quantization"})
load_in_4bit: bool = field(default=False, metadata={"help": "Enable 4-bit quantization"})
llm_int8_threshold: float = field(default=6.0, metadata={"help": "LLM.int8 threshold"})
llm_int8_skip_modules: Optional[List[str]] = field(default=None, metadata={"help": "Modules to skip for 8-bit conversion"})
llm_int8_enable_fp32_cpu_offload: bool = field(default=False, metadata={"help": "Enable FP32 CPU offload for 8-bit"})
llm_int8_has_fp16_weight: bool = field(default=False, metadata={"help": "Use FP16 weights for LLM.int8"})
bnb_4bit_compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype for 4-bit quantization"})
bnb_4bit_quant_type: str = field(default="fp4", metadata={"help": "Quantization type for 4-bit"})
bnb_4bit_use_double_quant: bool = field(default=False, metadata={"help": "Use double quantization for 4-bit"})

@dataclass
class TrainingConfig:
"""
Configuration for fine_tuning process
"""
save_output_path: str = field(default=".", metadata={"help": "Path where fine_tuning output is saved"})
# Other fine_tuning-related configurations can go here

# class CheckpointCallback(TrainerCallback):
# def on_train_end(self, args, state, control, **kwargs):
# model_path = args.output_dir
# timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
# img_tag = f"ghcr.io/YOUR_USERNAME/LoRA-Adapter:{timestamp}"

# # Write a file to indicate fine_tuning completion
# completion_indicator_path = os.path.join(model_path, "training_completed.txt")
# with open(completion_indicator_path, 'w') as f:
# f.write(f"Training completed at {timestamp}\n")
# f.write(f"Image Tag: {img_tag}\n")

# This method is called whenever a checkpoint is saved.
# def on_save(self, args, state, control, **kwargs):
# docker_build_and_push()
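To show how these config dataclasses might fit together at runtime, here is a minimal sketch; the argument-parsing entry point actually used by cli.py is not shown in this diff, so every call below is an assumption. It parses the dataclasses with transformers.HfArgumentParser, loads the preset weights baked into the image, and wraps the model with LoRA adapters via peft.
```
# Illustrative sketch only: wiring ExtLoraConfig/ModelConfig into a PEFT LoRA setup (assumed usage)
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from peft import get_peft_model

# Parse the dataclasses defined above from command-line flags
parser = HfArgumentParser((ExtLoraConfig, ModelConfig, DatasetConfig, TrainingConfig))
lora_config, model_config, ds_config, train_config = parser.parse_args_into_dataclasses()

# Load the preset weights baked into the image at /workspace/tfs/weights
tokenizer = AutoTokenizer.from_pretrained(model_config.pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_config.pretrained_model_name_or_path,
    torch_dtype=model_config.torch_dtype,
    device_map=model_config.device_map,
)

# Attach LoRA adapters; only the adapter weights are updated during fine_tuning
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```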