diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py
index 9aad1beecc..3a13200619 100644
--- a/examples/language-modeling/run_lora_clm.py
+++ b/examples/language-modeling/run_lora_clm.py
@@ -489,6 +489,33 @@ def main():
             cache_dir=model_args.cache_dir,
             token=model_args.token,
         )
+        if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
+            # Preprocessing the datasets.
+            for key in raw_datasets:
+                prompts = (
+                    create_prompts(raw_datasets[key])
+                    if not data_args.sql_prompt
+                    else create_sql_prompts(raw_datasets[key])
+                )
+                columns_to_be_removed = list(raw_datasets[key].features.keys())
+                raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
+                raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
+                raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
+        elif (
+            data_args.dataset_name == "timdettmers/openassistant-guanaco"
+        ):  # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
+            raw_datasets = raw_datasets.map(
+                lambda x: {
+                    "input": "",
+                    "output": x["text"],
+                }
+            )
+            # Remove unused columns.
+            raw_datasets = raw_datasets.remove_columns(
+                [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
+            )
+        else:
+            raise ValueError("Unsupported dataset")
     else:
         data_files = {}
         dataset_args = {}
@@ -512,6 +539,19 @@ def main():
             **dataset_args,
         )
 
+        if data_args.train_file and training_args.do_train:
+            print([x for x in raw_datasets])
+            raw_datasets = raw_datasets.map(
+                lambda x: {
+                    "input": "",
+                    "output": x["text"],
+                }
+            )
+            # Remove unused columns.
+            raw_datasets = raw_datasets.remove_columns(
+                [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
+            )
+
         # If no validation data is there, validation_split_percentage will be used to divide the dataset.
         if "validation" not in raw_datasets.keys() and training_args.do_eval:
             raw_datasets["validation"] = load_dataset(
@@ -530,34 +570,6 @@ def main():
                 token=model_args.token,
                 **dataset_args,
             )
-
-    if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
-        # Preprocessing the datasets.
-        for key in raw_datasets:
-            prompts = (
-                create_prompts(raw_datasets[key])
-                if not data_args.sql_prompt
-                else create_sql_prompts(raw_datasets[key])
-            )
-            columns_to_be_removed = list(raw_datasets[key].features.keys())
-            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
-            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
-            raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
-    elif (
-        data_args.dataset_name == "timdettmers/openassistant-guanaco"
-    ):  # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
-        raw_datasets = raw_datasets.map(
-            lambda x: {
-                "input": "",
-                "output": x["text"],
-            }
-        )
-        # Remove unused columns.
-        raw_datasets = raw_datasets.remove_columns(
-            [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
-        )
-    else:
-        raise ValueError("Unsupported dataset")
     # Load model
     if model_args.model_name_or_path:
         model_dtype = torch.bfloat16 if training_args.bf16 else None
@@ -671,9 +683,15 @@ def concatenate_data(dataset, max_seq_length):
             tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
             if training_args.do_eval:
                 tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
+        elif data_args.dataset_name is None:
+            if training_args.do_train:
+                tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
+            if training_args.do_eval:
+                tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(["input", "output"])
         else:
             raise ValueError("Unsupported dataset")
-        tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
+        if training_args.do_train:
+            tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
         if training_args.do_eval:
             tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length)
         if training_args.do_train:
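Standalone illustration (not part of the patch): a minimal sketch of the preprocessing path this diff adds for a custom `--train_file`, which mirrors the openassistant-guanaco handling. The toy in-memory split below is hypothetical and stands in for a user-supplied text file; the point is that every record is mapped to an empty "input" plus its raw "text" as "output", and all other columns are dropped.

```python
from datasets import Dataset, DatasetDict

# Hypothetical toy split standing in for a dataset loaded from --train_file.
raw_datasets = DatasetDict(
    {"train": Dataset.from_dict({"text": ["hello world", "foo bar"], "meta": ["a", "b"]})}
)

# Map each record to the (input, output) schema the rest of the script expects.
raw_datasets = raw_datasets.map(lambda x: {"input": "", "output": x["text"]})

# Drop everything except the two columns used downstream.
raw_datasets = raw_datasets.remove_columns(
    [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
)

print(raw_datasets["train"].column_names)  # ['input', 'output']
```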