Skip to content

Commit

Permalink
changed training for n iterations to n epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
lufre1 committed Jul 9, 2024
1 parent 63b4654 commit eaacf7a
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions development/train_3d_model_with_lucchi.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,26 @@ def train_on_lucchi(args):
num_workers = args.num_workers
n_classes = args.n_classes
model_type = args.model_type
n_iterations = args.n_iterations
n_epochs = args.n_epochs
save_root = args.save_root



device = "cuda" if torch.cuda.is_available() else "cpu"
sam_3d = get_sam_3d_model(
device, n_classes=n_classes, image_size=patch_shape[1],
model_type=model_type, lora_rank=4)
if args.without_lora:
sam_3d = get_sam_3d_model(
device, n_classes=n_classes, image_size=patch_shape[1],
model_type=model_type, lora_rank=None) # freeze encoder
else:
sam_3d = get_sam_3d_model(
device, n_classes=n_classes, image_size=patch_shape[1],
model_type=model_type, lora_rank=4)
train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape)
optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1)


trainer = SemanticSamTrainer(
name="3d-sam-vith-masamhyp-lucchi",
name=args.exp_name,
model=sam_3d,
convert_inputs=ConvertToSemanticSamInputs(),
num_classes=n_classes,
Expand All @@ -147,7 +152,7 @@ def train_on_lucchi(args):
#logger=None
)
# check_loader(train_loader, n_samples=10)
trainer.fit(epochs=n_iterations)
trainer.fit(epochs=n_epochs)


def main():
Expand All @@ -160,16 +165,22 @@ def main():
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
)
parser.add_argument("--without_lora", action="store_true", help="Whether to use LoRA for finetuning SAM for semantic segmentation.")
parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)")
parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations")

parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs")
parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict")
parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
parser.add_argument("--batch_size", "-bs", type=int, default=3, help="Batch size")
parser.add_argument("--num_workers", type=int, default=4, help="num_workers")
parser.add_argument("--learning_rate", type=float, default=0.0008, help="base learning rate")
parser.add_argument(
"--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
help="The filepath to where the logs and the checkpoints will be saved."
)
parser.add_argument(
"--exp_name", default="vitb_3d_lora4",
help="The filepath to where the logs and the checkpoints will be saved."
)

args = parser.parse_args()
train_on_lucchi(args)
Expand Down

0 comments on commit eaacf7a

Please sign in to comment.