Skip to content

Commit

Permalink
add alpaca math mix datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
shuishen112 committed Dec 28, 2024
1 parent 3036dd6 commit c6a5ef7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
8 changes: 8 additions & 0 deletions mttl/dataloader/alpaca_dataset_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,11 @@ def __init__(self):
self.dataset = DatasetLibrary.pull_dataset(
"zhan1993/code_alpaca_20k", split="train"
)


class MathQaAlpacaCodeDataset(AlpacaDataset):
    """Alpaca-style dataset read from ``zhan1993/metamath_code_alpaca_10k``."""

    def __init__(self):
        """Run the parent initializer, then point ``dataset`` at the train split."""
        super().__init__()
        # Assigned after the parent initializer, so this is the value that
        # remains on the instance.
        train_split = DatasetLibrary.pull_dataset(
            "zhan1993/metamath_code_alpaca_10k", split="train"
        )
        self.dataset = train_split
30 changes: 19 additions & 11 deletions mttl/datamodule/alpaca_data_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from mttl.dataloader.alpaca_dataset_readers import AlpacaCodeDataset, AlpacaDataset
from mttl.dataloader.alpaca_dataset_readers import (
AlpacaCodeDataset,
AlpacaDataset,
MathQaAlpacaCodeDataset,
)
from mttl.datamodule.base import DataModule, DatasetConfig


Expand Down Expand Up @@ -34,26 +38,30 @@ def setup_dataset(self):
self.test_dataset = self.dev_dataset


class AlpacaPretrainDataModule(AlpacaDataModule):
    """Subclass of ``AlpacaDataModule`` that adds no behaviour of its own."""
@DataModule.register("mathqa_alpaca_code", config_cls=DatasetConfig)
class MathQaAlpacaCodeDataModule(AlpacaDataModule):
    """Data module registered as ``"mathqa_alpaca_code"``.

    It is backed by ``MathQaAlpacaCodeDataset``.
    """

    def setup_dataset(self):
        """Build the dataset and populate the train, dev and test splits."""
        source = MathQaAlpacaCodeDataset()
        train_split, dev_split = self.create_train_valid_split(source)
        self.train_dataset = train_split
        self.dev_dataset = dev_split
        # The test split is an alias of the dev split.
        self.test_dataset = self.dev_dataset


class AlpacaFinetuneDataModule(AlpacaDataModule):
    """Subclass of ``AlpacaDataModule`` that adds no behaviour of its own."""


# Manual smoke test: construct a data module, call setup_dataset(), and print
# the batches produced by its validation dataloader.
# NOTE(review): this listing appears to interleave removed and added lines of
# the commit diff (indentation and +/- markers are missing) -- confirm against
# the repository before treating it as runnable code.
if __name__ == "__main__":
alpaca_data_module = AlpacaDataModule(
DatasetConfig(model="meta-llama/Llama-2-7b-hf")
)
alpaca_data_module.setup_dataset()
print(alpaca_data_module.train_dataset)
# alpaca_data_module = AlpacaDataModule(
# DatasetConfig(model="meta-llama/Llama-2-7b-hf")
# )
# alpaca_data_module.setup_dataset()
# print(alpaca_data_module.train_dataset)

alpaca_code_data_module = AlpacaCodeDataModule(
mathqa_alpaca_code_data_module = MathQaAlpacaCodeDataModule(
DatasetConfig(model="meta-llama/Llama-2-7b-hf")
)
alpaca_code_data_module.setup_dataset()
val_dataloder = alpaca_code_data_module.val_dataloader()
mathqa_alpaca_code_data_module.setup_dataset()
val_dataloder = mathqa_alpaca_code_data_module.val_dataloader()
for batch in val_dataloder:
print(batch)
# Stops in the interactive debugger; presumably meant for inspecting a batch
# (whether it sits inside the loop cannot be told from this flattened view).
breakpoint()

0 comments on commit c6a5ef7

Please sign in to comment.