Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
davebulaval committed Feb 11, 2024
1 parent 5ed9673 commit ef13b39
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 19 deletions.
18 changes: 1 addition & 17 deletions src/training/few_shot_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
log.propagate = False
log.setLevel(logging.ERROR)

device = "cuda:0"
num_epoch = 250

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -124,19 +123,4 @@ def tokenize_function(example):
metric_key_prefix="test/unrelated_sentences",
)

# Local import and model delete to reduce memory usage on GPU
del model
del trainer

from evaluate_metrics import compute_other_metrics_performance

print("----------Test Set Evaluation start of Other Metrics----------")
compute_other_metrics_performance(
test_set=tokenized_csmd_dataset["test"],
holdout_identical_set=tokenize_holdout_identical_dataset["test"],
holdout_unrelated_set=tokenize_holdout_unrelated_dataset["test"],
logger=wandb,
device=device,
)

trainer.save_model(f"test_trainer{seed}")
trainer.save_model(f"meaningbert_best_model")
3 changes: 1 addition & 2 deletions src/training_only_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
evaluate
datasets
tqdm
# sacremoses
# rake_nltk
scipy
wandb
scikit-learn
accelerate
poutyne
torchmetrics
transformers

# pip install torch==1.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116

0 comments on commit ef13b39

Please sign in to comment.