From 79fe3b17a3fb6b723b2e0f415a721a15c49c1e6e Mon Sep 17 00:00:00 2001
From: Daniel Franco
Date: Tue, 4 Jun 2024 11:41:34 +0200
Subject: [PATCH] Adapting code to train with BMZ models and solve minor problems

---
 biapy/engine/base_workflow.py       | 38 ++++++++++++++++++++----------
 biapy/engine/check_configuration.py | 11 ++++++---
 2 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py
index 937e17b1..a7bfa6ec 100644
--- a/biapy/engine/base_workflow.py
+++ b/biapy/engine/base_workflow.py
@@ -12,6 +12,8 @@ from sklearn.model_selection import StratifiedKFold
 import torch.multiprocessing as mp
 import torch.distributed as dist
+import xarray as xr
+import bioimageio.core
 
 from biapy.models import build_model, build_torchvision_model
 from biapy.engine import prepare_optimizer, build_callbacks
@@ -131,20 +133,26 @@ def __init__(self, cfg, job_identifier, device, args):
         self.bmz_test_output = None
         self.bmz_model_resource = None
         if self.cfg.MODEL.SOURCE == "bmz":
-            import bioimageio.core
-            import xarray as xr
-
             print("Loading Bioimage Model Zoo pretrained model . . .")
             self.bmz_model_resource = bioimageio.core.load_resource_description(self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI)
 
+            # Temporary adjustment until we find a solution with the BMZ team
+            bs = self.cfg.TRAIN.BATCH_SIZE
+            if self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI == "10.5281/zenodo.5874841":
+                bs = 2
+            elif self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI == "10.5281/zenodo.6028097":
+                bs = 4
+
             # Change PATCH_SIZE with the one stored in the RDF
             input_image = np.load(self.bmz_model_resource.test_inputs[0])
-            opts = ["DATA.PATCH_SIZE", input_image.shape[2:]+(input_image.shape[1],)]
+            opts = ["DATA.PATCH_SIZE", input_image.shape[2:]+(input_image.shape[1],),
+                    "TRAIN.BATCH_SIZE", bs]
             print("[BMZ] Changed 'DATA.PATCH_SIZE' from {} to {} as defined in the RDF"
-                .format(self.cfg.DATA.PATCH_SIZE,opts[1]))
+                .format(self.cfg.DATA.PATCH_SIZE, opts[1]))
+            print("[BMZ] Changed 'TRAIN.BATCH_SIZE' from {} to {} to match the selected BMZ model"
+                .format(self.cfg.TRAIN.BATCH_SIZE, opts[3]))
             self.cfg.merge_from_list(opts)
-
     @abstractmethod
     def define_metrics(self):
         """
@@ -450,12 +458,16 @@ def bmz_model_call(self, in_img, is_train=False):
         # Predict
         prediction_tensors = self.model.predict(*list(in_img.values()))
 
-        # Apply post-processing
-        prediction = dict(zip([out.name for out in self.model.output_specs], prediction_tensors))
-        self.model.apply_postprocessing(prediction, self.bmz_computed_measures)
+        # Apply post-processing (if any)
+        if bool(self.model.output_specs[0].postprocessing):
+            prediction = dict(zip([out.name for out in self.model.output_specs], prediction_tensors))
+            self.model.apply_postprocessing(prediction, self.bmz_computed_measures)
 
-        # Convert back to Tensor
-        prediction = torch.from_numpy(prediction['output0'].to_numpy())
+            # Convert back to Tensor
+            prediction = torch.from_numpy(prediction['output0'].to_numpy())
+        else:
+            # Convert back to Tensor
+            prediction = torch.from_numpy(np.array(prediction_tensors[0]))
 
         return prediction
 
@@ -822,6 +834,10 @@ def apply_model_activations(self, pred, training=False):
         pred : Torch tensor
             Resulting predictions after applying last activation(s).
""" + # Not apply the activation, as it will be done in the BMZ model + if self.cfg.MODEL.SOURCE == "bmz": + return pred + if not isinstance(pred, list): multiple_heads = False pred = [pred] diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index b8bd14c2..9999cb13 100644 --- a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ -170,8 +170,6 @@ def check_configuration(cfg, jobname, check_data_paths=True): raise ValueError("'MODEL.SOURCE' needs to be one between ['biapy', 'bmz', 'torchvision']") if cfg.MODEL.SOURCE == "bmz": - if cfg.TRAIN.ENABLE: - raise ValueError("Currently not supported to train a BMZ model") if cfg.MODEL.BMZ.SOURCE_MODEL_DOI == "": raise ValueError("'MODEL.BMZ.SOURCE_MODEL_DOI' needs to be configured when 'MODEL.SOURCE' is 'bmz'") @@ -572,8 +570,13 @@ def check_configuration(cfg, jobname, check_data_paths=True): raise ValueError("When PROBLEM.NDIM == {} DATA.TEST.PADDING tuple must be length {}, given {}." .format(cfg.PROBLEM.NDIM, dim_count, cfg.DATA.TEST.PADDING)) if len(cfg.DATA.PATCH_SIZE) != dim_count+1: - raise ValueError("When PROBLEM.NDIM == {} DATA.PATCH_SIZE tuple must be length {}, given {}." - .format(cfg.PROBLEM.NDIM, dim_count+1, cfg.DATA.PATCH_SIZE)) + if cfg.MODEL.SOURCE != "bmz": + raise ValueError("When PROBLEM.NDIM == {} DATA.PATCH_SIZE tuple must be length {}, given {}." + .format(cfg.PROBLEM.NDIM, dim_count+1, cfg.DATA.PATCH_SIZE)) + else: + print("WARNING: when PROBLEM.NDIM == {} DATA.PATCH_SIZE tuple must be length {}, given {}. Not an error " + "because you are using a model from Bioimage Model Zoo (BMZ) and the patch size will be determined by the model." + " However, this message is printed so you are aware of this. ") assert cfg.DATA.NORMALIZATION.TYPE in ['div', 'custom'], "DATA.NORMALIZATION.TYPE not in ['div', 'custom']" assert cfg.DATA.NORMALIZATION.APPLICATION_MODE in ["image", "dataset"], "'DATA.NORMALIZATION.APPLICATION_MODE' needs to be one between ['image', 'dataset']" if not cfg.DATA.TRAIN.IN_MEMORY and cfg.DATA.NORMALIZATION.APPLICATION_MODE == "dataset":