Skip to content

Commit

Permalink
fixing training bug (#867)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Feb 21, 2024
1 parent ee96e17 commit e0645b8
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
3 changes: 2 additions & 1 deletion cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def main():
model.net, images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels,
test_files=image_names_test, learning_rate=args.learning_rate,
weight_decay=args.weight_decay, channels=channels,
weight_decay=args.weight_decay, channels=channels,
channel_axis=args.channel_axis,
save_path=os.path.realpath(args.dir), save_every=args.save_every,
SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size,
min_train_masks=args.min_train_masks,
Expand Down
3 changes: 1 addition & 2 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,8 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
self.nchan = nchan
self.nclasses = 3
nbase = [32, 64, 128, 256]
self.nchan = nchan
self.nbase = [nchan, *nbase]

self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn,
max_pool=True, diam_mean=diam_mean).to(self.device)

Expand Down
9 changes: 4 additions & 5 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
if channels is not None:
imgs = [
transforms.convert_image(img, channels=channels,
channel_axis=channel_axis) for img in imgs
channel_axis=channel_axis, nchan=None) for img in imgs
]
imgs = [img.transpose(2, 0, 1) for img in imgs]
if normalize_params["normalize"]:
Expand Down Expand Up @@ -89,9 +89,9 @@ def _reshape_norm(data, channels=None, channel_axis=None,
Returns:
list: List of reshaped and normalized data.
"""
if channels is not None:
if channels is not None or channel_axis is not None:
data = [
transforms.convert_image(td, channels=channels, channel_axis=channel_axis)
transforms.convert_image(td, channels=channels, channel_axis=channel_axis, nchan=None)
for td in data
]
data = [td.transpose(2, 0, 1) for td in data]
Expand All @@ -111,7 +111,7 @@ def _reshape_norm_save(files, channels=None, channel_axis=None,
td = io.imread(f)
if channels is not None:
td = transforms.convert_image(td, channels=channels,
channel_axis=channel_axis)
channel_axis=channel_axis, nchan=None)
td = td.transpose(2, 0, 1)
if normalize_params["normalize"]:
td = transforms.normalize_img(td, normalize=normalize_params, axis=0)
Expand Down Expand Up @@ -336,7 +336,6 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
Returns:
Path: path to saved model weights
"""

device = net.device

scale_range = 0.5 if rescale else 1.0
Expand Down
4 changes: 2 additions & 2 deletions cellpose/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha

else:
# code above put channels last
if x.shape[-1] > nchan:
if nchan is not None and x.shape[-1] > nchan:
transforms_logger.warning(
"WARNING: more than %d channels given, use 'channels' input for specifying channels - just using first %d channels to run processing"
% (nchan, nchan))
Expand All @@ -524,7 +524,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha
transforms_logger.critical("ERROR: cannot process 4D images in 2D mode")
raise ValueError("ERROR: cannot process 4D images in 2D mode")

if x.shape[-1] < nchan:
if nchan is not None and x.shape[-1] < nchan:
x = np.concatenate((x, np.tile(np.zeros_like(x), (1, 1, nchan - 1))),
axis=-1)

Expand Down

0 comments on commit e0645b8

Please sign in to comment.