Skip to content

Commit

Permalink
add kd trainer e2e test
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 8, 2025
1 parent 476eb7e commit 6240804
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions tests/e2e/integrations/kd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
e2e tests for kd trainer support in Axolotl
"""
from pathlib import Path

import pytest
from e2e.utils import check_tensorboard

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "unsloth/Llama-3.2-1B",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": False,
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 2.0,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
"type": "axolotl.integrations.kd.chat_template",
"field_messages": "messages_combined",
"split": "train",
"logprobs_field": "llm_text_generation_vllm_logprobs",
"temperature": 1.0,
"preprocess_shards": 2,
},
],
"val_set_size": 0.0,
"sequence_len": 4096,
"sample_packing": True,
"pad_to_sequence_len": True,
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0001,
"bf16": "auto",
"gradient_checkpointing": True,
"flash_attention": True,
"special_tokens": {
"pad_token": "<|end_of_text|>",
"eos_token": "<|eot_id|>",
},
"max_steps": 5,
"output_dir": temp_dir,
}


class TestKnowledgeDistillation:
"""
Test case for Knowledge Distillation
"""

def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)

0 comments on commit 6240804

Please sign in to comment.