diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py
index 3bfe32d0..4b41dd1a 100644
--- a/development/train_3d_model_with_lucchi.py
+++ b/development/train_3d_model_with_lucchi.py
@@ -120,21 +120,26 @@ def train_on_lucchi(args):
     num_workers = args.num_workers
     n_classes = args.n_classes
     model_type = args.model_type
-    n_iterations = args.n_iterations
+    n_epochs = args.n_epochs
     save_root = args.save_root
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    sam_3d = get_sam_3d_model(
-        device, n_classes=n_classes, image_size=patch_shape[1],
-        model_type=model_type, lora_rank=4)
+    if args.without_lora:
+        sam_3d = get_sam_3d_model(
+            device, n_classes=n_classes, image_size=patch_shape[1],
+            model_type=model_type, lora_rank=None)  # freeze encoder
+    else:
+        sam_3d = get_sam_3d_model(
+            device, n_classes=n_classes, image_size=patch_shape[1],
+            model_type=model_type, lora_rank=4)
 
     train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape)
 
     optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1)
 
     trainer = SemanticSamTrainer(
-        name="3d-sam-vith-masamhyp-lucchi",
+        name=args.exp_name,
         model=sam_3d,
         convert_inputs=ConvertToSemanticSamInputs(),
         num_classes=n_classes,
@@ -147,7 +152,7 @@ def train_on_lucchi(args):
         #logger=None
     )
     # check_loader(train_loader, n_samples=10)
-    trainer.fit(epochs=n_iterations)
+    trainer.fit(epochs=n_epochs)
 
 
 def main():
@@ -160,16 +165,22 @@ def main():
         "--model_type", "-m", default="vit_b",
         help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
     )
+    parser.add_argument("--without_lora", action="store_true", help="Disable LoRA and train with a frozen image encoder instead.")
     parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)")
-    parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations")
+
+    parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs")
     parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict")
-    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
+    parser.add_argument("--batch_size", "-bs", type=int, default=3, help="Batch size")
     parser.add_argument("--num_workers", type=int, default=4, help="num_workers")
     parser.add_argument("--learning_rate", type=float, default=0.0008, help="base learning rate")
     parser.add_argument(
         "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
         help="The filepath to where the logs and the checkpoints will be saved."
     )
+    parser.add_argument(
+        "--exp_name", default="vitb_3d_lora4",
+        help="The name of the experiment, used for naming the checkpoints and logs."
+    )
 
     args = parser.parse_args()
     train_on_lucchi(args)
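
A minimal sketch of how the updated CLI could be invoked after this change, using only the flags visible in the diff (the input-path argument is defined elsewhere in main() and must be supplied as well; the experiment names below are placeholders):

    # LoRA (rank 4) finetuning, the default path:
    python development/train_3d_model_with_lucchi.py \
        --model_type vit_b --n_epochs 400 --batch_size 3 --exp_name vitb_3d_lora4

    # frozen-encoder baseline instead of LoRA:
    python development/train_3d_model_with_lucchi.py \
        --model_type vit_b --without_lora --exp_name vitb_3d_frozen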