Commit

removed unused code
lufre1 committed Jul 12, 2024
1 parent a90ca2e commit a550893
Showing 3 changed files with 11 additions and 44 deletions.
2 changes: 0 additions & 2 deletions development/train_3d_model_with_lucchi.py
@@ -75,10 +75,8 @@ def __getitem__(self, index):
         raw = raw.view(image_shape)
         raw = raw.squeeze(0)
         raw = raw.repeat(1, 3, 1, 1)
-        # print("raw shape", raw.shape)
         # wanted label shape: (1, z, y, x)
         label = (label != 0).to(torch.float)
-        # print("label shape", label.shape)
         return raw, label


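Note: the two removed lines were debugging prints; the shape handling and the label binarization stay unchanged. A minimal standalone sketch of the retained binarization step, where the patch shape and id range are illustrative assumptions and not values from the script:

import torch

# Hypothetical instance-label patch with ids 0..4 (shape and ids are assumptions).
label = torch.randint(0, 5, (1, 32, 256, 256))

# The step kept in __getitem__: every non-zero instance id becomes foreground (1.0),
# background stays 0.0, and the result is cast to float.
label = (label != 0).to(torch.float)
print(label.unique())  # tensor([0., 1.])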
52 changes: 11 additions & 41 deletions development/train_3d_model_with_lucchi_without_decoder.py
@@ -175,8 +175,10 @@ def train(args):
     patch_shape = args.patch_shape
     bs = args.batch_size
     #label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False)
-    label_transform = torch_em.transform.label.labels_to_binary
-    ndim = 3
+    label_transform = torch_em.transform.label.MinSizeLabelTransform
+    ndim = 2
+    min_size = 50
+    max_sampling_attempts = 5000

     if with_rois:
         data_paths, rois_dict = get_data_paths_and_rois(data_dir, min_shape=patch_shape, with_thresholds=True)
@@ -190,61 +192,29 @@ def train(args):
         raw_paths=data["train"], raw_key="raw",
         label_paths=data["train"], label_key=label_key,
         patch_shape=args.patch_shape, with_segmentation_decoder=False,
-        sampler=MinInstanceSampler(2),
+        sampler=MinInstanceSampler(2, min_size=min_size),
+        min_size=min_size,
         raw_transform=raw_transform,
         #rois=np.s_[64:, :, :],
         #n_samples=200,
     )
+    train_ds.max_sampling_attempts = max_sampling_attempts
     train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)

     val_ds = default_sam_dataset(
         raw_paths=data["val"], raw_key="raw",
         label_paths=data["val"], label_key=label_key,
         patch_shape=args.patch_shape, with_segmentation_decoder=False,
-        sampler=MinInstanceSampler(2),
+        sampler=MinInstanceSampler(2, min_size=min_size),
+        min_size=min_size,
         raw_transform=raw_transform,
         #rois=np.s_[64:, :, :],
         is_train=False,
         #n_samples=25,
     )
+    val_ds.max_sampling_attempts = max_sampling_attempts
     val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1)
-    # if with_rois:
-    #     train_loader = torch_em.default_segmentation_loader(
-    #         raw_paths=data["train"], raw_key="raw",
-    #         label_paths=data["train"], label_key="labels/mitochondria",
-    #         patch_shape=patch_shape, ndim=ndim, batch_size=bs,
-    #         label_transform=label_transform, raw_transform=raw_transform,
-    #         num_workers=n_workers,
-    #         rois=rois_dict["train"]
-    #         #rois=[np.s_[64:, :, :]] * len(data["train"])
-    #     )
-    #     val_loader = torch_em.default_segmentation_loader(
-    #         raw_paths=data["val"], raw_key="raw",
-    #         label_paths=data["val"], label_key="labels/mitochondria",
-    #         patch_shape=patch_shape, ndim=ndim, batch_size=bs,
-    #         label_transform=label_transform, raw_transform=raw_transform,
-    #         num_workers=n_workers,
-    #         rois=rois_dict["val"]
-    #         # rois=[np.s_[64:, :, :]] * len(data["val"])
-    #     )
-    # else:
-    #     train_loader = torch_em.default_segmentation_loader(
-    #         raw_paths=data["train"], raw_key="raw",
-    #         label_paths=data["train"], label_key=label_key,
-    #         patch_shape=patch_shape, ndim=ndim, batch_size=bs,
-    #         label_transform=label_transform, raw_transform=raw_transform,
-    #         num_workers=n_workers,
-    #     )
-    #     print("len data[val]", len(data["val"]))
-    #     val_loader = torch_em.default_segmentation_loader(
-    #         raw_paths=data["val"], raw_key="raw",
-    #         label_paths=data["val"], label_key=label_key,
-    #         patch_shape=patch_shape, ndim=ndim, batch_size=bs,
-    #         label_transform=label_transform, raw_transform=raw_transform,
-    #         num_workers=n_workers,
-    #     )



     #check_loader(train_loader, n_samples=3)
     # x,y =next(iter(train_loader))
     # print("shapes of x and y", x.shape, y.shape)
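Note: taken together, the new pieces tighten patch sampling. MinInstanceSampler(2, min_size=min_size) only accepts patches that contain at least two instances of at least min_size pixels, min_size is also passed to the dataset, and max_sampling_attempts raises how often a dataset may re-draw a patch before it errors out. A condensed sketch of the updated wiring, assuming the imports below are the ones the script uses, with placeholder paths and patch shape (the script's raw_transform is omitted here):

from micro_sam.training import default_sam_dataset
from torch_em import get_data_loader
from torch_em.data import MinInstanceSampler

min_size = 50                 # smallest instance (in pixels) a sampled patch must contain
max_sampling_attempts = 5000  # re-draws allowed before the dataset gives up

train_ds = default_sam_dataset(
    raw_paths=["/path/to/train.h5"], raw_key="raw",                      # placeholder path
    label_paths=["/path/to/train.h5"], label_key="labels/mitochondria",  # key as in the removed loader code
    patch_shape=(1, 512, 512), with_segmentation_decoder=False,          # placeholder patch shape
    sampler=MinInstanceSampler(2, min_size=min_size),
    min_size=min_size,
)
train_ds.max_sampling_attempts = max_sampling_attempts
train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)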
1 change: 0 additions & 1 deletion micro_sam/util.py
@@ -1010,7 +1010,6 @@ def segmentation_to_one_hot(
     masks = segmentation.copy()
     if segmentation_ids is None:
         n_ids = int(segmentation.max())
-
     else:
         assert segmentation_ids[0] != 0, "No objects were found."

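Note: the hunk above only drops a blank line, but for orientation, segmentation_to_one_hot turns an instance segmentation into one binary mask per object id. The following is an illustrative re-implementation of that idea, mirroring the branching shown above; it is not the micro_sam code itself.

import numpy as np

def one_hot_sketch(segmentation, segmentation_ids=None):
    # Illustrative only: build one binary mask per object id.
    if segmentation_ids is None:
        n_ids = int(segmentation.max())
        segmentation_ids = np.arange(1, n_ids + 1)  # assumes ids run from 1 to max
    else:
        assert segmentation_ids[0] != 0, "No objects were found."
    return np.stack([(segmentation == idx).astype(np.float32) for idx in segmentation_ids])

seg = np.array([[0, 1, 1], [2, 2, 0]])
print(one_hot_sketch(seg).shape)  # (2, 2, 3): one mask per object id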
