KD trainer w/ logprobs #2202

Draft · wants to merge 66 commits into base: main

Commits (66)
88b3198
refactor trainer to prevent circular dependencies later
winglian Dec 16, 2024
303cfa7
KD dataset loading and KD with logprobs
winglian Dec 18, 2024
d584354
filter bad rows
winglian Dec 18, 2024
e633a12
make batch smaller
winglian Dec 18, 2024
ddcf5c6
handle padding/collation for KD datasets
winglian Dec 18, 2024
7fe0ad0
make it work
winglian Dec 19, 2024
b592c05
flipped the slice
winglian Dec 19, 2024
ae545e0
cross entropy loss coefficient during KD
winglian Dec 19, 2024
00ce77e
make sure to multiply against the correct loss
winglian Dec 19, 2024
ed49051
chore: lint
winglian Dec 19, 2024
0b59a24
triton wip
winglian Dec 21, 2024
c73acd7
no where support
winglian Dec 21, 2024
119d586
v2 trial
winglian Dec 21, 2024
18a46c3
no torch.exp inside triton kernel
winglian Dec 21, 2024
dc90c93
no log etc
winglian Dec 21, 2024
081928e
no torch.tensor
winglian Dec 21, 2024
e565694
v3
winglian Dec 21, 2024
c0757e8
fix kwarg
winglian Dec 21, 2024
d8d817e
don't use triton for now
winglian Dec 21, 2024
7366efc
better rescaling for temperatures
winglian Dec 24, 2024
3416302
hash for temperature too
winglian Dec 24, 2024
ca5e397
use kd_alpha in the correct loss method
winglian Dec 25, 2024
6314630
fix kd loss so it's causal (fixes repeating tokens)
winglian Dec 25, 2024
a5c085e
var naming and add todo
winglian Dec 26, 2024
689e1c1
chore: lint
winglian Dec 28, 2024
f09b5da
refactor so we can easily add new loss functions
winglian Dec 29, 2024
746891e
add license block
winglian Dec 29, 2024
f60c623
remove references to triton kd for now
winglian Dec 30, 2024
fa055f9
handle token/logprob shifting
winglian Dec 30, 2024
c51b033
support for custom trainer classes from plugins
winglian Dec 30, 2024
27faacb
refactor kd chat template loader
winglian Dec 30, 2024
885653d
move more things to kd plugin
winglian Dec 30, 2024
cdfcd69
remove moved class from import
winglian Dec 30, 2024
cba6165
make plugin setup concise
winglian Dec 30, 2024
feed96f
increase logging around loading plugins
winglian Dec 30, 2024
92c6c10
add copyrights
winglian Dec 30, 2024
d5bc214
remove duplicate code
winglian Dec 30, 2024
6e409d2
more info on preprocess for kd and fix import
winglian Dec 30, 2024
93dfff9
be a bit pickier about loading dynamic prompt strategies
winglian Dec 30, 2024
d3c2b7c
kd sample packing
winglian Dec 31, 2024
204d6c4
make loss torch script compat
winglian Dec 31, 2024
e659c01
support streaming for processing sft datasets?
winglian Jan 1, 2025
01896b1
improve iterable support
winglian Jan 2, 2025
684b382
ensure that batch vs single is done properly
winglian Jan 7, 2025
6784822
tweak check for batched prompt data
winglian Jan 7, 2025
808328e
reward can use same batch check
winglian Jan 7, 2025
47932f2
fix reward trainer calls for tokenization
winglian Jan 7, 2025
ab690f3
improve check for batched
winglian Jan 7, 2025
317f290
reward model doesn't work well with batched
winglian Jan 7, 2025
ff2fb0f
add kd trainer e2e test
winglian Jan 8, 2025
b9a42b3
linting
winglian Jan 8, 2025
1d039f5
rename test files so it gets picked up
winglian Jan 8, 2025
432f65f
make the kd e2e fit in vram for ci and add lora version
winglian Jan 8, 2025
158071e
set lora_dropout explicitly
winglian Jan 8, 2025
261e4fb
lower lr
winglian Jan 8, 2025
5303478
make sure to set tokenizer from l3 70b and save safetensors
winglian Jan 8, 2025
513ec9e
make sure to use the correct tokenizer
winglian Jan 8, 2025
b984755
fix adapter model check
winglian Jan 9, 2025
a5e0671
make sure to use tensorboard to capture loss for checks
winglian Jan 9, 2025
e8fceb7
chore: lint
winglian Jan 10, 2025
7232cbd
chore: lint
winglian Jan 13, 2025
510cf45
improve logprob masking and shift in trainer
winglian Jan 15, 2025
35a84f2
more fixes
winglian Jan 15, 2025
483defb
try tests for kd on l40s
winglian Jan 15, 2025
04efcb1
don't shift student logits for kd
winglian Jan 15, 2025
32258c2
no batching for kd chat templates
winglian Jan 15, 2025
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -207,7 +207,7 @@ jobs:
   - cuda: 124
     cuda_version: 12.4.1
     python_version: "3.11"
-    pytorch: 2.4.1
+    pytorch: 2.5.1
     num_gpus: 1
     axolotl_extras:
   steps:
@@ -253,7 +253,7 @@ jobs:
   - cuda: 124
     cuda_version: 12.4.1
     python_version: "3.11"
-    pytorch: 2.5.1
+    pytorch: 2.4.1
     num_gpus: 1
     axolotl_extras:
   steps:
2 changes: 1 addition & 1 deletion cicd/tests.py
@@ -59,7 +59,7 @@
 }

 N_GPUS = int(os.environ.get("N_GPUS", 1))
-GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
+GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)


 def run_cmd(cmd: str, run_folder: str):
2 changes: 1 addition & 1 deletion docs/rlhf.qmd
@@ -29,7 +29,7 @@ datasets:
     type: chatml.intel
   - path: argilla/ultrafeedback-binarized-preferences
     split: train
-    type: chatml.argilla
+    type: chatml
 ```

 #### IPO
6 changes: 6 additions & 0 deletions src/axolotl/cli/args.py
@@ -13,6 +13,12 @@ class PreprocessCliArgs:
     debug_num_examples: int = field(default=1)
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
+    iterable: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Use IterableDataset for streaming processing of large datasets"
+        },
+    )


 @dataclass
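
As background on the new `iterable` flag, here is a rough, standalone illustration (not the PR's code) of what streaming preprocessing looks like with Hugging Face `datasets`: streaming mode returns an `IterableDataset` whose `.map` runs lazily, so large SFT datasets never need to be fully materialized up front. The dataset name and the mapped function are only examples.

```python
from datasets import load_dataset

# streaming=True yields a datasets.IterableDataset instead of downloading and
# materializing the full dataset before preprocessing starts.
ds = load_dataset("tatsu-lab/alpaca", split="train", streaming=True)

# .map on an IterableDataset is applied lazily, one example at a time.
tokenized = ds.map(lambda ex: {"n_chars": len(ex["text"])})

print(next(iter(tokenized))["n_chars"])
```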
2 changes: 2 additions & 0 deletions src/axolotl/cli/main.py
@@ -39,6 +39,8 @@ def preprocess(config: str, **kwargs) -> None:
     kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
         config options.
     """
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     from axolotl.cli.preprocess import do_cli

     do_cli(config=config, **kwargs)
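
A small illustration (not code from the PR) of why `None`-valued kwargs are dropped here: the new CLI options default to `None`, so filtering them out means an option the user never passed cannot override whatever the YAML config specifies when the two are later merged.

```python
config_values = {"iterable": True}                  # value coming from the YAML config
cli_kwargs = {"iterable": None, "download": True}   # user only passed --download

# Same filtering as in the hunk above: drop options the user did not set.
cli_kwargs = {k: v for k, v in cli_kwargs.items() if v is not None}
merged = {**config_values, **cli_kwargs}
print(merged)  # {'iterable': True, 'download': True} -- the config value survives
```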
5 changes: 4 additions & 1 deletion src/axolotl/cli/preprocess.py
@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     )


-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+def do_cli(
+    config: Union[Path, str] = Path("examples/"),
+    **kwargs,
+) -> None:
     """
     Parses `axolotl` config, CLI args, and calls `do_preprocess`.
6 changes: 6 additions & 0 deletions src/axolotl/common/datasets.py
@@ -63,11 +63,17 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
+    preprocess_iterable = (
+        hasattr(cli_args, "iterable")
+        and cli_args.iterable is not None
+        and cli_args.iterable
+    )

     train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
         cfg,
         tokenizer,
         processor=processor,
+        preprocess_iterable=preprocess_iterable,
     )

     if (
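
As a side note on the hunk above (an observation, not part of the PR), the three-clause check is truth-equivalent to a single `getattr` call, since `None` is already falsy:

```python
class _CliArgs:
    # hypothetical stand-in for PreprocessCliArgs; the attribute may also be absent
    iterable = None

cli_args = _CliArgs()
preprocess_iterable = bool(getattr(cli_args, "iterable", None))
assert preprocess_iterable is False
```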