
Litdata optimize is very slow #417

Closed
nightingal3 opened this issue Nov 15, 2024 · 7 comments
Labels
bug (Something isn't working) · help wanted (Extra attention is needed)

Comments

@nightingal3

🐛 Bug

I'm processing 10B tokens of text with the optimize function in litdata, but progress is very slow and often stalls altogether. The time estimate keeps increasing as more samples are processed (for instance, with 8 workers the last estimate was 7 hours, and the run now appears to be stuck). I also tried using 1 worker, which produced a more stable estimate but would take 50+ hours.

To Reproduce

I used the following script (the temp directory setup was just to troubleshoot a previous issue). I have the latest version of litdata installed (0.2.29).

import argparse
import os
import shutil
import tempfile
from functools import partial
from pathlib import Path

from datasets import Dataset, load_dataset
from litdata import optimize
from litgpt import Tokenizer

tokenizer = Tokenizer("<path>/checkpoints/EleutherAI/pythia-1b")


def tokenize(data: Dataset, index: int):
    yield tokenizer.encode(data[index]["text"], eos=True)

def setup_directories(job_id: str):
    """Set up and clean temp and cache directories for a specific job."""
    # Create job-specific paths
    temp_dir = f"<path>/tmp_job{job_id}"
    cache_dir = f"<path>/tmp/huggingface_cache_job{job_id}"
    
    # Clean up existing directories if they exist
    for dir_path in [temp_dir, cache_dir]:
        if os.path.exists(dir_path):
            try:
                shutil.rmtree(dir_path)
                print(f"Cleaned up existing directory: {dir_path}")
            except Exception as e:
                print(f"Error cleaning {dir_path}: {e}")
    
    # Create fresh directories
    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(cache_dir, exist_ok=True)
    
    # Set environment variables
    os.environ["HF_HOME"] = cache_dir
    os.environ["TMPDIR"] = temp_dir
    os.environ["TEMP"] = temp_dir
    os.environ["TMP"] = temp_dir
    tempfile.tempdir = temp_dir
    
    print(f"Set up fresh directories for job {job_id}:")
    print(f"Temp dir: {temp_dir}")
    print(f"Cache dir: {cache_dir}")
    
    return cache_dir


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default=".")
    parser.add_argument("--data_split", type=str, default="sample-10BT")
    parser.add_argument("--job_id", type=str, required=True, help="Unique identifier for this job")
    args = parser.parse_args()

    cache_dir = setup_directories(args.job_id)

    temp_dir = f"<path>/tmp_job{args.job_id}"
    os.makedirs(temp_dir, exist_ok=True)

    #Set all temp variables to be safe
    os.environ["TMPDIR"] = temp_dir
    os.environ["TEMP"] = temp_dir
    os.environ["TMP"] = temp_dir

    tempfile.tempdir = temp_dir

    print(f"Temporary directory is: {tempfile.gettempdir()}")
    
    dataset = load_dataset(
        "HuggingFaceFW/fineweb",
        num_proc=(os.cpu_count() - 1),
        name=args.data_split,
        split="train",
        cache_dir=None,
        download_mode="force_redownload"
    )
    print("Total examples:", len(dataset))

    # Split the data in training and validation
    split_dataset = dataset.train_test_split(test_size=0.003, seed=42, shuffle=True)
    split_dataset["val"] = split_dataset.pop("test")  # rename the test split to val
    output_dir = Path(args.data_path)
    output_dir.mkdir(parents=True, exist_ok=True)


    optimize(
        fn=partial(tokenize, split_dataset["train"]),
        inputs=list(range(len(split_dataset["train"]))),
        output_dir=f"{args.data_path}/train",
        num_workers=8,
        chunk_bytes="500MB",
    )
    optimize(
        fn=partial(tokenize, split_dataset["val"]),
        inputs=list(range(len(split_dataset["val"]))),
        output_dir=f"{args.data_path}/val",
        num_workers=8,
        chunk_bytes="500MB",
    )

Expected behavior

I expected processing to take a consistent amount of time and not get stuck when using multiple workers. The number of CPUs available is much larger than 8.

Additional context

Environment detail
  • PyTorch Version (e.g., 1.0): 2.5.1+cu124
  • OS (e.g., Linux): linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.9
  • CUDA/cuDNN version: 12.4
  • GPU models and configuration: N/A
  • Any other relevant information:
@nightingal3 added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on Nov 15, 2024

Hi! Thanks for your contribution, great first issue!

@tchaton
Collaborator

tchaton commented Nov 16, 2024

Yes, this is expected. LitData orders the files based on their size. To process 10B tokens, I strongly recommend using a multi-machine job on the Lightning-AI platform. Processing 1B tokens took us 4 hours on 8 machines.

@nightingal3
Author

Thanks for the response, that makes sense! How many CPUs would you recommend for processing 10B, 100B, 500B, and 1T tokens in a reasonable timespan (say <=3 days)?

@bhimrazy
Collaborator

bhimrazy commented Nov 16, 2024

Adding this studio here just in case it could serve as a helpful reference: Prepare the TinyLlama 1T Token Dataset. 😊

@tchaton
Collaborator

tchaton commented Nov 16, 2024

We used 16 machines with 32 CPUs each to prepare a 1T-token dataset. It is quite a heavy data-processing job; it took 4 hours. I would expect this to be even more time consuming for 15T tokens.


@nightingal3
Author

Thanks for the response! I'll definitely scale up the number of machines I'm using in that case!

@nightingal3
Author

Update: this may be a niche issue, but I realized that since I'm in an HPC environment, the os.cpu_count() value I was using was not accurate (it reports the node's total CPUs, not my SLURM allocation), and this could have caused a lot of issues in data processing. Putting this here for future reference in case anyone hits the same issue.

>>> os.environ["SLURM_CPUS_ON_NODE"]
'2'
>>> os.cpu_count()
256
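
A minimal sketch of how one might derive a worker count that respects the actual allocation instead of os.cpu_count(); the SLURM_CPUS_ON_NODE variable and the os.sched_getaffinity fallback assume a Linux/SLURM setup like the one above, so adjust for your scheduler:

import os

def available_cpus() -> int:
    """Best-effort count of CPUs actually allotted to this job.

    Prefers the SLURM allocation, then the process CPU affinity mask
    (Linux only), and finally falls back to os.cpu_count().
    """
    slurm = os.environ.get("SLURM_CPUS_ON_NODE")
    if slurm:
        return int(slurm)
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1

# e.g. pass this to load_dataset(num_proc=...) and optimize(num_workers=...)
num_workers = max(1, available_cpus() - 1)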
