support for train_file and validation_file
vidyasiv committed Jun 4, 2024
1 parent 3069818 commit 9e96f2b
Showing 1 changed file with 47 additions and 29 deletions.
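In short, the change lets run_lora_clm.py build its "input"/"output" columns from user-supplied train_file/validation_file data instead of only from a Hub dataset. A minimal standalone sketch of that new path follows (file names here are hypothetical, and the files are assumed to expose a "text" column as in the diff):

from datasets import load_dataset

# Hypothetical local files; in the script they arrive via data_args.train_file
# and data_args.validation_file.
data_files = {"train": "train.json", "validation": "validation.json"}
raw_datasets = load_dataset("json", data_files=data_files)

# Mirror the added branch: empty "input", the raw "text" copied into "output",
# then every other column dropped.
raw_datasets = raw_datasets.map(lambda x: {"input": "", "output": x["text"]})
raw_datasets = raw_datasets.remove_columns(
    [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
)
print(raw_datasets)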
76 changes: 47 additions & 29 deletions examples/language-modeling/run_lora_clm.py
@@ -489,6 +489,33 @@ def main():
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
        if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
            # Preprocessing the datasets.
            for key in raw_datasets:
                prompts = (
                    create_prompts(raw_datasets[key])
                    if not data_args.sql_prompt
                    else create_sql_prompts(raw_datasets[key])
                )
                columns_to_be_removed = list(raw_datasets[key].features.keys())
                raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
                raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
                raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
        elif (
            data_args.dataset_name == "timdettmers/openassistant-guanaco"
        ):  # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
            raw_datasets = raw_datasets.map(
                lambda x: {
                    "input": "",
                    "output": x["text"],
                }
            )
            # Remove unused columns.
            raw_datasets = raw_datasets.remove_columns(
                [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
            )
        else:
            raise ValueError("Unsupported dataset")
    else:
        data_files = {}
        dataset_args = {}
@@ -512,6 +539,19 @@ def main():
            **dataset_args,
        )

        if data_args.train_file and training_args.do_train:
            print([x for x in raw_datasets])
            raw_datasets = raw_datasets.map(
                lambda x: {
                    "input": "",
                    "output": x["text"],
                }
            )
            # Remove unused columns.
            raw_datasets = raw_datasets.remove_columns(
                [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
            )

        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys() and training_args.do_eval:
            raw_datasets["validation"] = load_dataset(
@@ -530,34 +570,6 @@ def main():
                token=model_args.token,
                **dataset_args,
            )

    if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
        # Preprocessing the datasets.
        for key in raw_datasets:
            prompts = (
                create_prompts(raw_datasets[key])
                if not data_args.sql_prompt
                else create_sql_prompts(raw_datasets[key])
            )
            columns_to_be_removed = list(raw_datasets[key].features.keys())
            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
            raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
    elif (
        data_args.dataset_name == "timdettmers/openassistant-guanaco"
    ):  # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
        raw_datasets = raw_datasets.map(
            lambda x: {
                "input": "",
                "output": x["text"],
            }
        )
        # Remove unused columns.
        raw_datasets = raw_datasets.remove_columns(
            [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
        )
    else:
        raise ValueError("Unsupported dataset")
    # Load model
    if model_args.model_name_or_path:
        model_dtype = torch.bfloat16 if training_args.bf16 else None
@@ -671,9 +683,15 @@ def concatenate_data(dataset, max_seq_length):
            tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
            if training_args.do_eval:
                tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
        elif data_args.dataset_name is None:
            if training_args.do_train:
                tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
            if training_args.do_eval:
                tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(["input", "output"])
        else:
            raise ValueError("Unsupported dataset")
        tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
        if training_args.do_train:
            tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
        if training_args.do_eval:
            tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length)
    if training_args.do_train:
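For context on the fallback noted in the diff's comment (when no validation data is provided, validation_split_percentage divides the training data), here is a hedged sketch of how such a split can be expressed with the datasets library's split syntax. The 5% value and the file name are assumptions, and the exact split strings used in run_lora_clm.py are not shown in this diff.

from datasets import load_dataset

validation_split_percentage = 5  # assumed example value
data_files = {"train": "train.json"}  # hypothetical file

# First N% of the train file serves as validation, the remainder as train.
raw_validation = load_dataset(
    "json", data_files=data_files, split=f"train[:{validation_split_percentage}%]"
)
raw_train = load_dataset(
    "json", data_files=data_files, split=f"train[{validation_split_percentage}%:]"
)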
