diff --git a/development/train_3d_model_with_lucchi_without_decoder.py b/development/train_3d_model_with_lucchi_without_decoder.py index eb8d6eb3..bac541be 100644 --- a/development/train_3d_model_with_lucchi_without_decoder.py +++ b/development/train_3d_model_with_lucchi_without_decoder.py @@ -177,7 +177,7 @@ def train(args): #label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False) label_transform = torch_em.transform.label.MinSizeLabelTransform ndim = 2 - min_size = 50 + min_size = 100 max_sampling_attempts = 5000 if with_rois: @@ -198,7 +198,8 @@ def train(args): #rois=np.s_[64:, :, :], #n_samples=200, ) - train_ds.max_sampling_attempts = max_sampling_attempts + for ds in train_ds.datasets: + ds.max_sampling_attempts = max_sampling_attempts train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2) val_ds = default_sam_dataset( @@ -212,7 +213,8 @@ def train(args): is_train=False, #n_samples=25, ) - val_ds.max_sampling_attempts = max_sampling_attempts + for ds in val_ds.datasets: + ds.max_sampling_attempts = max_sampling_attempts val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1) #check_loader(train_loader, n_samples=3) @@ -221,7 +223,7 @@ def train(args): # breakpoint() train_sam( - name="mito_model", model_type="vit_b", + name=args.exp_name, model_type="vit_b", train_loader=train_loader, val_loader=val_loader, n_epochs=50, n_objects_per_batch=10, with_segmentation_decoder=False, @@ -252,7 +254,7 @@ def main(): help="The filepath to where the logs and the checkpoints will be saved." ) parser.add_argument( - "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi", + "--exp_name", default="vitb_3d-mitotomo", help="The filepath to where the logs and the checkpoints will be saved." ) parser.add_argument("--without_rois", action="store_true", help="Train without Regions Of Interest (ROI)")