SFTTrainer Raises NotImplementedError with IterableDataset #2138

Open
2 of 4 tasks
research-boy opened this issue Sep 27, 2024 · 4 comments
Labels
🐛 bug Something isn't working

Comments


research-boy commented Sep 27, 2024

System Info

Google Colab

Description

When attempting to fine-tune a model using the SFTTrainer with an IterableDataset, an error occurs because the SFTTrainer expects a dataset that supports random access (__getitem__). This is problematic when working with large datasets that cannot be loaded into memory at once and require streaming.
Error Message

NotImplementedError: Subclasses of Dataset should implement __getitem__.

Context: This issue is especially relevant when fine-tuning on very large datasets, where memory constraints make it impractical to load the full dataset into memory.
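To illustrate the distinction between random-access and streaming datasets, here is a minimal sketch in plain Python. The classes are simplified, hypothetical stand-ins, not the real torch or datasets classes; the only thing assumed from the report is that the base class raises the NotImplementedError quoted above when indexed.

```python
class Dataset:
    """Hypothetical stand-in for a map-style base class."""

    def __getitem__(self, index):
        # Same message as in the error report; raised unless a subclass overrides it.
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")


class StreamingDataset(Dataset):
    """Iterable-style dataset: yields rows lazily and never defines __getitem__."""

    def __init__(self, rows):
        self._rows = rows

    def __iter__(self):
        yield from self._rows


def indexing_fails(dataset) -> bool:
    """Return True if dataset[0] raises NotImplementedError."""
    try:
        dataset[0]
    except NotImplementedError:
        return True
    return False


def first_example(dataset):
    """Fetch one example without assuming random access."""
    if indexing_fails(dataset):
        return next(iter(dataset))
    return dataset[0]


stream = StreamingDataset([{"text": "hello"}, {"text": "world"}])
in_memory = [{"text": "hello"}, {"text": "world"}]
```

Any code path that peeks at `dataset[0]` behaves like `indexing_fails` here: it works for the in-memory list but raises for the streaming object, which can only be consumed through iteration.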

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

  1. Load a large dataset using the datasets library with streaming enabled, like this:
from datasets import load_dataset

# Load dataset in streaming mode
dataset = load_dataset('csv', data_files='path_to_large_files/*.csv', streaming=True)
  2. Attempt to fine-tune a model using SFTTrainer with the streaming dataset:
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import is_bfloat16_supported

# Define the model and tokenizer
model = ... # Load your model here
tokenizer = ... # Load your tokenizer here

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    formatting_func = format_example,
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        
        # Use num_train_epochs = 1, warmup_ratio for full training runs!
        warmup_steps = 5,
        max_steps = 320,

        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        eval_strategy="no"
    ),
)

Expected behavior

SFTTrainer should be able to train on the streaming IterableDataset. Instead, the NotImplementedError is raised when the trainer tries to access the dataset.

research-boy added the 🐛 bug (Something isn't working) label Sep 27, 2024

dame-cell commented Sep 27, 2024

The trl library can handle IterableDataset; this was actually fixed, check out this PR.
If you get any error related to unsloth, try fine-tuning the model without unsloth.


research-boy commented Sep 27, 2024

@dame-cell Yes, I did check the PR. The last few comments there suggest running pip install git+https://github.com/huggingface/trl.git, which also didn't work for me.

@research-boy
Author

This is what is happening: the function _prepare_non_packed_dataloader doesn't handle IterableDataset properly, while _prepare_packed_dataloader does. So the trainer can be constructed when you set packing=True, but running trainer_stats = trainer.train() then gives another error: AttributeError: 'ConstantLengthDataset' object has no attribute 'column_names'
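As a rough illustration of what an iterable-friendly code path needs, here is a minimal sketch in plain Python. It is not the TRL implementation or its fix; `get_column_names`, `format_lazily`, and `FakeStream` are hypothetical names introduced only for this example. The idea is to read `column_names` defensively and to apply the formatting function by iteration alone.

```python
def get_column_names(dataset):
    """Return column names if the object exposes them, else None (no AttributeError)."""
    return getattr(dataset, "column_names", None)


def format_lazily(dataset, formatting_func):
    """Apply formatting_func one example at a time; only __iter__ is required."""
    for example in dataset:
        yield formatting_func(example)


class FakeStream:
    """Hypothetical streaming dataset: iterable, with no column_names attribute."""

    def __init__(self, rows):
        self._rows = rows

    def __iter__(self):
        return iter(self._rows)


def format_example(example):
    # Hypothetical formatting function, standing in for the one in the reproduction.
    return f"### Input: {example['prompt']}\n### Output: {example['answer']}"


stream = FakeStream([{"prompt": "2+2", "answer": "4"}])
formatted = list(format_lazily(stream, format_example))
```

Because `format_lazily` is a generator, nothing is read from the stream until a consumer iterates over it, so the dataset never has to fit in memory or support indexing.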


dame-cell commented Sep 28, 2024

Hmm, did you try running the code without unsloth, i.e. using just the trl library?
