diff --git a/3-Self-Supervised-Eval/patch_extraction_utils.py b/3-Self-Supervised-Eval/patch_extraction_utils.py index 3a92c796..3037fc93 100644 --- a/3-Self-Supervised-Eval/patch_extraction_utils.py +++ b/3-Self-Supervised-Eval/patch_extraction_utils.py @@ -132,7 +132,7 @@ def create_embeddings(embeddings_dir, enc_name, dataset, save_patches=False, spr model = resnet50_trunc_baseline(pretrained=True) eval_t = eval_transforms(pretrained=True) elif 'dino' in enc_name: - ckpt_path = os.path.join(assets_dir, enc_name+'.pt') + ckpt_path = os.path.join(assets_dir, enc_name+'.pth') assert os.path.isfile(ckpt_path) model = vit_small(patch_size=16) state_dict = torch.load(ckpt_path, map_location="cpu")['teacher'] @@ -143,7 +143,7 @@ def create_embeddings(embeddings_dir, enc_name, dataset, save_patches=False, spr #print("Unexpected Keys:", unexpected_keys) eval_t = eval_transforms(pretrained=False) elif 'simclr' in enc_name: - ckpt_path = os.path.join(assets_dir, enc_name+'.pt') + ckpt_path = os.path.join(assets_dir, enc_name+'.pth') assert os.path.isfile(ckpt_path) model = torchvision_ssl_encoder('resnet50', pretrained=True) missing_keys, unexpected_keys = model.load_state_dict(torch.load(ckpt_path), strict=False) @@ -243,4 +243,4 @@ def create_embeddings(embeddings_dir, enc_name, dataset, save_patches=False, spr save_embeddings(model=model, fname=train_fname, dataloader=train_dataloader, save_patches=save_patches, sprite_dim=sprite_dim) save_embeddings(model=model, fname=val_fname, dataloader=val_dataloader, - save_patches=save_patches, sprite_dim=sprite_dim) \ No newline at end of file + save_patches=save_patches, sprite_dim=sprite_dim)