Skip to content

Commit

Permalink
fix code for training
Browse files Browse the repository at this point in the history
  • Loading branch information
davebulaval committed Feb 11, 2024
1 parent f972dc9 commit 3b08e6d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/training/few_shot_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
parser.add_argument(
"--seed",
type=int,
default=42,
default=45,
help="The seed to use for training.",
)

Expand All @@ -42,7 +42,7 @@
parser.add_argument(
"--data_augmentation",
type=bool_parse,
default=False,
default=True,
help="Either or not to do data augmentation.",
)

Expand Down Expand Up @@ -76,7 +76,7 @@ def tokenize_function(example):
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
output_dir="test_trainer",
output_dir="meaning_bert_train",
logging_strategy="epoch",
evaluation_strategy="epoch",
per_device_train_batch_size=16,
Expand All @@ -86,7 +86,7 @@ def tokenize_function(example):
save_strategy="epoch",
load_best_model_at_end=True, # By default, use the eval loss to retrieve the best model.
seed=seed,
metric_for_best_model="loss",
metric_for_best_model="eval_loss",
)

# num_labels to 1 to create a regression head
Expand Down

0 comments on commit 3b08e6d

Please sign in to comment.