From 3b08e6d13aec4504c34523c1e9f1526076d3ab2e Mon Sep 17 00:00:00 2001 From: davebulaval Date: Sun, 11 Feb 2024 10:34:02 -0500 Subject: [PATCH] fix code for training --- src/training/few_shot_training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/training/few_shot_training.py b/src/training/few_shot_training.py index f551de4..8f43a0c 100644 --- a/src/training/few_shot_training.py +++ b/src/training/few_shot_training.py @@ -28,7 +28,7 @@ parser.add_argument( "--seed", type=int, - default=42, + default=45, help="The seed to use for training.", ) @@ -42,7 +42,7 @@ parser.add_argument( "--data_augmentation", type=bool_parse, - default=False, + default=True, help="Either or not to do data augmentation.", ) @@ -76,7 +76,7 @@ def tokenize_function(example): data_collator = DataCollatorWithPadding(tokenizer=tokenizer) training_args = TrainingArguments( - output_dir="test_trainer", + output_dir="meaning_bert_train", logging_strategy="epoch", evaluation_strategy="epoch", per_device_train_batch_size=16, @@ -86,7 +86,7 @@ def tokenize_function(example): save_strategy="epoch", load_best_model_at_end=True, # By default, use the eval loss to retrieve the best model. seed=seed, - metric_for_best_model="loss", + metric_for_best_model="eval_loss", ) # num_labels to 1 to create a regression head