Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor spanner to avoid creating large array #773

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

XiaohanZhangCMU
Copy link
Collaborator

@XiaohanZhangCMU XiaohanZhangCMU commented Sep 3, 2024

Description of changes:

Refactor spanner init to avoid creating a large array which may lead to OOM.

Issue #, if available:

OOM issue

Merge Checklist:

Put an x without space in the boxes that apply. If you are unsure about any checklist, please don't hesitate to ask. We are here to help! This is simply a reminder of what we are going to look for before merging your pull request.

General

  • I have read the contributor guidelines
  • This is a documentation change or typo fix. If so, skip the rest of this checklist.
  • I certify that the changes I am introducing will be backward compatible, and I have discussed concerns about this, if any, with the MosaicML team.
  • I have updated any necessary documentation, including README and API docs (if appropriate).

Tests

  • I ran pre-commit on my change. (check out the pre-commit section of prerequisites)
  • I have added tests that prove my fix is effective or that my feature works (if appropriate).
  • I ran the tests locally to make sure it pass. (check out testing)
  • I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes.

Copy link

@mihir-db mihir-db left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have tests

Copy link
Collaborator

@snarayan21 snarayan21 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok i think this makes sense, but can you add the following:

  1. tests comparing the self.spans object with old method vs new method
  2. quick local perf benchmark for old vs. new spanner object creation -- want to make sure this doesn't increase init time excessively

@gluonfield
Copy link

gluonfield commented Sep 4, 2024

Hi @XiaohanZhangCMU incredible job! I just tested and can confirm that this solves a problem with large number of shards.

However, there may be another bottleneck with StreamingDataLoader on datasets with large number of shards.

The following code takes ~5 minutes and CPU memory utilisation climbs to ~80GB before the first batch can be retrieved. As the dataloader is being used it consumes ~50GB of memory. If I end up using 8 GPUs there may be not enough CPU memory in H100s to be able to use it.

dataset = StreamingDataset(
    remote="/mnt/disks/raw/train",
    local="./local",
    batch_size=1024,
    cache_limit="10GB",
)
dataloader = StreamingDataLoader(dataset, batch_size=1024)
for batch in dataloader:
    print(batch)
    break

Any ideas what could be going wrong?

@gluonfield
Copy link

Sharing profiler results on time to first batch

image

@XiaohanZhangCMU
Copy link
Collaborator Author

@AugustDev yeah, I wouldn't be surprised with those profiling numbers. There are places that some large arrays of int32 are created, which may add up.....

@gluonfield
Copy link

gluonfield commented Sep 5, 2024

I see... So MosaicML currently is not suitable for datasets with large number of shards 😢. We're using litdata, but been thinking to migrate to MosaicML as has great features. Do you have any suggestions what could work for us? Currently we generate dataset using default size_limit. Increasing it would reduce number of shards, but may reduce performance?

@snarayan21
Copy link
Collaborator

Hey @AugustDev, in your case, you have a whole lot of samples and storing the sample partition array is taking up a good amount of space. However, according to your issue, you have ~27B samples, right? So the sample partition array, even in int64, will be ~216GB, which should be possible. The array is created but then saved to a shared memory file which all workers can access, so this should be a temporary memory cost. Regardless, most H100 systems have nodes with much more CPU ram than 216 GB so this should still be feasible...

Reducing the number of shards can help but will probably not make a significant difference for memory savings. It's more about the sheer number of samples you have.

@XiaohanZhangCMU
Copy link
Collaborator Author

@AugustDev I would still recommend give it a try using your current set up for 8 gpus (with or without this PR). And let us know if you hit any memory issue.

To clarify, the number of shards is not the issue, having smaller shards is actually preferred when streaming from cloud. In your case, the partition array (a shared array of int) was causing the large mem allocation. This PR is to reduce the mem cost in initialization but the partition array alloc is not avoidable. But as @snarayan21 mentioned, the partition array is shared across processes so I think you should be good to scale up to 8 gpus.

@gluonfield
Copy link

Thank you for the message. Will try on H100 - 8GPUs and report back the findings.

@gluonfield
Copy link

gluonfield commented Sep 23, 2024

Hi guys, for me to load the dataset very large dataset (2.4B rows) even with the spanner fix takes 2min 28sec.

When it comes to loading first batch

for batch in dataloader:
    print(batch)
    break

it's been 15min (and still running). It seems I won't be able to train on this dataset using Mosaic. Any plans to support Mosaic Streaming on large datasets? I'm not sure where the problem with dataloading is, but happy to help fixing if you can guide where the performance issue might come from?

Perhaps it's time to merge this?

@snarayan21
Copy link
Collaborator

Hey @AugustDev, we've been able to train on datasets that have that many (or more) samples -- this is likely an issue particular to your dataset. Are you trying to retrieve a batch locally on your laptop or from the GPUs themselves? Have you tried model training while dataloading instead of dataloading alone? For performance tuning, you can use the streaming simulator to input your dataset characteristics and understand what performance you can expect. Happy to help further.

@snarayan21
Copy link
Collaborator

@XiaohanZhangCMU Mind adding the tests mentioned above and we can get this one in?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants