Skip to content

Commit

Permalink
Make covid if training parameters consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 17, 2024
1 parent 1819a83 commit ba180ee
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
21 changes: 14 additions & 7 deletions scripts/for_benchmarking_ais/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,28 @@ def get_loaders(path, patch_shape, dataset, for_sam=False):

# Let's get the number of samples extracted, to set the "n_samples" value
# This is done to avoid the time taken to save checkpoints over fewer training images.
_train_loader = get_covid_if_loader(
path=data_path, patch_shape=patch_shape, batch_size=1, sample_range=train_volumes
_loader = get_covid_if_loader(
path=data_path, patch_shape=patch_shape, batch_size=2, sample_range=train_volumes
)

print(
f"Found {len(train_loader)} samples for training. ",
"Hence, we will use {0} samples for training.".format(50 if len(_train_loader) < 50 else len(_train_loader))
f"Found {len(_loader)} samples for training. ",
"Hence, we will use {0} samples for training.".format(100 if len(_loader) < 50 else len(_loader))
)

# Finally, let's get the dataloaders
train_loader = get_covid_if_loader(
path=data_path, batch_size=2, sample_range=train_volumes, n_samples=50 if len(train_loader) < 50 else None,
path=data_path,
batch_size=2,
sample_range=train_volumes,
n_samples=100 if len(_loader) < 50 else None,
**kwargs
)
val_loader = get_covid_if_loader(
path=data_path, batch_size=1, sample_range=val_volumes,
path=data_path,
batch_size=1,
sample_range=val_volumes,
**kwargs
)

else:
Expand Down Expand Up @@ -217,7 +224,7 @@ def prediction_fn(net, inp):

def get_default_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, required=True,)
parser.add_argument("-d", "--dataset", type=str, required=True)
parser.add_argument("-i", "--input_path", type=str, default="/scratch/projects/nim00007/sam/data")
parser.add_argument("-s", "--save_root", type=str, default=None)
parser.add_argument("-p", "--phase", type=str, default=None, choices=["train", "predict"])
Expand Down
2 changes: 2 additions & 0 deletions scripts/for_benchmarking_ais/train_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def main(args):
iterations=args.iterations,
model=model,
device=device,
dataset=args.dataset,
)

if args.phase == "predict":
Expand All @@ -40,6 +41,7 @@ def main(args):
model=model,
device=device,
result_path=result_path,
dataset=args.dataset,
)


Expand Down

0 comments on commit ba180ee

Please sign in to comment.