From e0645b8e60762e0af96b6b91bc4ff7201616f424 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Wed, 21 Feb 2024 13:56:43 -0500 Subject: [PATCH] fixing training bug (#867) --- cellpose/__main__.py | 3 ++- cellpose/models.py | 3 +-- cellpose/train.py | 9 ++++----- cellpose/transforms.py | 4 ++-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cellpose/__main__.py b/cellpose/__main__.py index ff46347d..f13affb7 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -231,7 +231,8 @@ def main(): model.net, images, labels, train_files=image_names, test_data=test_images, test_labels=test_labels, test_files=image_names_test, learning_rate=args.learning_rate, - weight_decay=args.weight_decay, channels=channels, + weight_decay=args.weight_decay, channels=channels, + channel_axis=args.channel_axis, save_path=os.path.realpath(args.dir), save_every=args.save_every, SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size, min_train_masks=args.min_train_masks, diff --git a/cellpose/models.py b/cellpose/models.py index 5eb60958..e6329ec5 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -283,9 +283,8 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, self.nchan = nchan self.nclasses = 3 nbase = [32, 64, 128, 256] - self.nchan = nchan self.nbase = [nchan, *nbase] - + self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, max_pool=True, diam_mean=diam_mean).to(self.device) diff --git a/cellpose/train.py b/cellpose/train.py index 6ecc76a1..4b5a4788 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -60,7 +60,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, if channels is not None: imgs = [ transforms.convert_image(img, channels=channels, - channel_axis=channel_axis) for img in imgs + channel_axis=channel_axis, nchan=None) for img in imgs ] imgs = [img.transpose(2, 0, 1) for img in imgs] if normalize_params["normalize"]: @@ -89,9 +89,9 @@ def _reshape_norm(data, channels=None, channel_axis=None, Returns: list: List of reshaped and normalized data. """ - if channels is not None: + if channels is not None or channel_axis is not None: data = [ - transforms.convert_image(td, channels=channels, channel_axis=channel_axis) + transforms.convert_image(td, channels=channels, channel_axis=channel_axis, nchan=None) for td in data ] data = [td.transpose(2, 0, 1) for td in data] @@ -111,7 +111,7 @@ def _reshape_norm_save(files, channels=None, channel_axis=None, td = io.imread(f) if channels is not None: td = transforms.convert_image(td, channels=channels, - channel_axis=channel_axis) + channel_axis=channel_axis, nchan=None) td = td.transpose(2, 0, 1) if normalize_params["normalize"]: td = transforms.normalize_img(td, normalize=normalize_params, axis=0) @@ -336,7 +336,6 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, Returns: Path: path to saved model weights """ - device = net.device scale_range = 0.5 if rescale else 1.0 diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 2d630c3c..0d5f994e 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -514,7 +514,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha else: # code above put channels last - if x.shape[-1] > nchan: + if nchan is not None and x.shape[-1] > nchan: transforms_logger.warning( "WARNING: more than %d channels given, use 'channels' input for specifying channels - just using first %d channels to run processing" % (nchan, nchan)) @@ -524,7 +524,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha transforms_logger.critical("ERROR: cannot process 4D images in 2D mode") raise ValueError("ERROR: cannot process 4D images in 2D mode") - if x.shape[-1] < nchan: + if nchan is not None and x.shape[-1] < nchan: x = np.concatenate((x, np.tile(np.zeros_like(x), (1, 1, nchan - 1))), axis=-1)