Adding replay into GPT-NeoX #1200

Open. Wants to merge 8 commits into base: main.

@AIproj (Contributor) commented Apr 13, 2024

This PR aims to add replay to GPT-NeoX. I implemented it for the paper Simple and Scalable Strategies to Continually Pre-train Large Language Models, which shows simple ways to continue pretraining efficiently by improving adaptation to new data while mitigating forgetting of previous data. This PR can also serve as a basis for resuming training from a given index in a dataset, since that is how I implemented the feature for replay datasets.

How to use

I tried to make the descriptions of the replay args informative enough to serve as documentation. An example of a config using replay is also provided in tests/config/example_replay_config.yml.

Unsupported/untested features:

  • (UNTESTED) Using replay AND weighting by number of documents. An assert raises an error if both are used.
  • (UNSUPPORTED) Using replay AND splitting the datasets automatically instead of providing separate train, val and test paths. An assert raises an error if both are used.
  • (UNSUPPORTED) Using replay AND label data. An assert raises an error if both are used. As indicated in the comments, this might be doable by adding a replay_label_data arg that specifies the prefix of the idx and data paths of the replay label data, generating the specific replay label data path from that prefix, and treating it like the training data in the block:
    # The concatenate_train_replay_paths bool is necessary to avoid issues when this function gets called a second time.
    if neox_args.is_replay_enabled and concatenate_train_replay_paths:
        # Merge replay data paths into train data paths logic, but need to keep track of
        # what paths in train_data_paths came from replay
        num_replay_data_paths = len(neox_args.replay_data_paths)
        num_non_replay_data_paths = len(neox_args.train_data_paths)
        neox_args.train_data_paths += neox_args.replay_data_paths
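
The merging step in the snippet above can be sketched in isolation. This is a minimal illustration, not the PR's implementation: `ToyArgs`, `merge_replay_paths`, and the example paths are hypothetical stand-ins, and only the three `neox_args` fields that appear in the snippet are modeled. It shows why the two counts are recorded before concatenation (they are what later identifies which entries came from replay) and why the boolean guard matters on a second call.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyArgs:
    # Hypothetical stand-in for neox_args; only the fields used in the snippet.
    train_data_paths: List[str] = field(default_factory=list)
    replay_data_paths: List[str] = field(default_factory=list)
    is_replay_enabled: bool = False


def merge_replay_paths(
    args: ToyArgs, concatenate_train_replay_paths: bool
) -> Optional[Tuple[int, int]]:
    """Append replay paths to the train paths.

    Returns (num_non_replay_data_paths, num_replay_data_paths), or None
    when nothing was merged.
    """
    if not (args.is_replay_enabled and concatenate_train_replay_paths):
        return None
    # Record both counts BEFORE concatenating, so the replay entries can
    # later be recovered as the tail of train_data_paths.
    num_replay_data_paths = len(args.replay_data_paths)
    num_non_replay_data_paths = len(args.train_data_paths)
    # Build a new list rather than extending in place, so the two fields
    # never share or mutate the same list object.
    args.train_data_paths = args.train_data_paths + args.replay_data_paths
    return num_non_replay_data_paths, num_replay_data_paths


# Hypothetical paths, for illustration only.
args = ToyArgs(
    train_data_paths=["data/new_a", "data/new_b"],
    replay_data_paths=["data/old_a"],
    is_replay_enabled=True,
)
n_train, n_replay = merge_replay_paths(args, concatenate_train_replay_paths=True)
replay_slice = args.train_data_paths[n_train:]

# A second call with the flag off leaves the merged list untouched.
second_call = merge_replay_paths(args, concatenate_train_replay_paths=False)
```

Without the guard, a second call would append the replay paths again and the recorded counts would no longer match the list.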

Pending tests

Currently, the tests required are:

  1. Sanity check that first few batches are the same with/without these changes.
  2. As above, check that label data support still works with these changes.
  3. Sanity check that, given two datasets, training without replay with a weight of 0.5 on each is equivalent to using one dataset as the training dataset and the other as the replay dataset with a replay fraction of 0.5.

The tests can follow the procedure in tests/model/test_batch_replicability.py. Tests 1 and 3 passed with the Summit version of NeoX, but I need to rerun them on the replay implementation based on the current main branch. I'll probably need someone else to verify that label data support (test 2) did not break, as I'm unfamiliar with this NeoX feature and currently too busy to take that on.
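
The expectation behind test 3 can be written down as a small weight calculation. This is a sketch under an assumption that is not confirmed by the PR text: that a replay fraction f gives the replay datasets a total sampling weight of f and the training datasets a total weight of 1 - f. The helper names `normalize` and `blended_weights` are hypothetical.

```python
from typing import List


def normalize(weights: List[float]) -> List[float]:
    """Scale weights so they sum to 1."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    return [w / total for w in weights]


def blended_weights(
    train_weights: List[float],
    replay_weights: List[float],
    replay_fraction: float,
) -> List[float]:
    """Effective per-dataset weights, train datasets first, then replay.

    Assumption: replay datasets share a total weight of replay_fraction,
    train datasets share the remaining 1 - replay_fraction.
    """
    if not 0.0 <= replay_fraction <= 1.0:
        raise ValueError("replay_fraction must be in [0, 1]")
    train_part = [(1.0 - replay_fraction) * w for w in normalize(train_weights)]
    replay_part = [replay_fraction * w for w in normalize(replay_weights)]
    return train_part + replay_part


# Test 3 setup: one train dataset, one replay dataset, replay fraction 0.5 ...
with_replay = blended_weights([1.0], [1.0], replay_fraction=0.5)
# ... versus two train datasets weighted 0.5 each, no replay.
without_replay = normalize([0.5, 0.5])
```

Under that assumption both configurations give each dataset a weight of 0.5. Whether the resulting batches match sample for sample also depends on seeding and shuffling, which is what the batch replicability test would have to confirm.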

@AIproj AIproj requested a review from haileyschoelkopf April 13, 2024 00:08
@AIproj AIproj self-assigned this Apr 13, 2024
@AIproj AIproj requested a review from Quentin-Anthony as a code owner April 13, 2024 00:08
@bentherien (Contributor) commented

Please ignore the above commits. I accidentally pushed to upstream when modifying this branch in my fork.


Default = 0.05

Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 19 samples will come from the replay
Member

For example, 0.1 means that in a batch of 100, 19 samples will come from the replay buffer.

Is this a typo? Why wouldn't it be 10 samples?
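
A quick check of the arithmetic in question, assuming the replay count is simply the batch size times the replay fraction, rounded to the nearest integer (the rounding rule is an assumption, and `replay_samples_per_batch` is a hypothetical helper, not a NeoX function):

```python
def replay_samples_per_batch(batch_size: int, replay_fraction: float) -> int:
    """Number of samples in a batch drawn from replay data.

    Assumes a plain product rounded to the nearest integer.
    """
    if not 0.0 <= replay_fraction <= 1.0:
        raise ValueError("replay_fraction must be in [0, 1]")
    return round(batch_size * replay_fraction)


example = replay_samples_per_batch(100, 0.1)
default = replay_samples_per_batch(100, 0.05)
```

Under this reading, a fraction of 0.1 on a batch of 100 gives 10 replay samples, and the default of 0.05 gives 5.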


- **replay_seed**: int

Default = 1234
@StellaAthena (Member) commented Apr 15, 2024

From your other comments, it seems important that the replay seed differs from the general data seed. If that's correct, let's use a different default.

3 participants