Update train script with new interleave function (#171)
* new name for interleave dataset

* jat-small to jat

* change output dir
qgallouedec authored Jun 6, 2024
1 parent 912dc9c · commit e8a1964
Showing 3 changed files with 12 additions and 5 deletions.
README.md: 2 additions & 2 deletions
```diff
@@ -82,8 +82,8 @@ Here are some examples of how you might use JAT in both evaluation and fine-tuni
 - **Training JAT**: Train your own JAT model from scratch (run on 8xA100)
   ```shell
   accelerate launch scripts/train_jat_tokenized.py \
-      --output_dir checkpoints/jat_small_v100 \
-      --model_name_or_path jat-project/jat-small \
+      --output_dir checkpoints/jat \
+      --model_name_or_path jat-project/jat \
       --tasks all \
       --trust_remote_code \
       --per_device_train_batch_size 20 \
```
jat/processing_jat.py: 1 addition & 0 deletions
```diff
@@ -182,6 +182,7 @@ class JatProcessor(ProcessorMixin):
         tokenizer ([`AutoTokenizer`]):
             The tokenizer is a required input.
     """
+
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
```
scripts/train_jat.py: 9 additions & 3 deletions
```diff
@@ -15,7 +15,7 @@
 
 from jat.eval.rl.core import TASK_NAME_TO_ENV_ID
 from jat.modeling_jat import JatModel
-from jat.utils import mix_iterable_datasets
+from jat.utils_interleave_datasets import interleave_datasets
 
 
 # Sometimes, the server is down; increasing the number of
@@ -185,9 +185,15 @@ def add_loss_weight(example, loss_weight):
            eval_dataset[key] = eval_dataset[key].take(data_args.eval_num_samples)
 
     weights = [SAMPLE_WEIGHTS.get(t, 1.0) for t in train_dataset.keys()]
-    train_dataset = mix_iterable_datasets(
-        list(train_dataset.values()), batch_size=training_args.per_device_train_batch_size, weights=weights
+
+    train_dataset = interleave_datasets(
+        list(train_dataset.values()),
+        probabilities=[w / sum(weights) for w in weights],
+        seed=training_args.seed,
+        stopping_strategy="all_exhausted",
+        n_contiguous=training_args.per_device_train_batch_size,
     )
+
     # Due to the train dataset's structure, where every 'n' consecutive samples share the same modalities, we can't
     # load all samples at once. Different sets of 'n' samples have different modalities. Therefore, we must load and
     # process each set of 'n' samples separately.
```
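The new call mixes the per-task iterable datasets by normalizing the sample weights into probabilities and drawing `n_contiguous` consecutive samples (here, one per-device batch) from a single dataset at a time, which is why every `n` consecutive samples share the same modalities. As a rough illustration of that behaviour, here is a minimal, self-contained sketch of a chunk-wise weighted interleaver; `interleave_chunkwise` and everything inside it are hypothetical stand-ins, not the actual `jat.utils_interleave_datasets.interleave_datasets` implementation.

```python
import random
from typing import Iterator, Optional, Sequence


def interleave_chunkwise(
    sources: Sequence[Sequence],     # hypothetical helper; sources must be non-empty and re-iterable
    probabilities: Sequence[float],  # relative sampling weight per source
    seed: Optional[int] = None,
    stopping_strategy: str = "all_exhausted",
    n_contiguous: int = 1,
) -> Iterator:
    """Pick a source at random with the given probabilities, yield `n_contiguous`
    consecutive items from it, then pick again."""
    if stopping_strategy not in ("first_exhausted", "all_exhausted"):
        raise ValueError(f"unknown stopping_strategy: {stopping_strategy}")
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    exhausted = [False] * len(sources)

    while True:
        # The next block of `n_contiguous` samples all comes from the same source.
        (idx,) = rng.choices(range(len(sources)), weights=probabilities, k=1)
        for _ in range(n_contiguous):
            try:
                yield next(iterators[idx])
            except StopIteration:
                exhausted[idx] = True
                if stopping_strategy == "first_exhausted" or all(exhausted):
                    return
                # "all_exhausted": restart this source and keep sampling until
                # every source has been fully consumed at least once.
                iterators[idx] = iter(sources[idx])
                yield next(iterators[idx])
```

For example, with two toy task streams and `n_contiguous=2`, every pair of consecutive samples comes from a single task, matching the comment in the training script about processing each set of `n` samples separately:

```python
atari = [("atari", i) for i in range(4)]
metaworld = [("metaworld", i) for i in range(6)]
mixed = list(interleave_chunkwise([atari, metaworld], probabilities=[0.4, 0.6], seed=0, n_contiguous=2))
```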
