Skip to content

Commit

Permalink
Adapting code to train with BMZ models and solve minor problems
Browse files Browse the repository at this point in the history
  • Loading branch information
danifranco committed Jun 4, 2024
1 parent a2b7627 commit 79fe3b1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
38 changes: 27 additions & 11 deletions biapy/engine/base_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sklearn.model_selection import StratifiedKFold
import torch.multiprocessing as mp
import torch.distributed as dist
import xarray as xr
import bioimageio.core

from biapy.models import build_model, build_torchvision_model
from biapy.engine import prepare_optimizer, build_callbacks
Expand Down Expand Up @@ -131,20 +133,26 @@ def __init__(self, cfg, job_identifier, device, args):
self.bmz_test_output = None
self.bmz_model_resource = None
if self.cfg.MODEL.SOURCE == "bmz":
import bioimageio.core
import xarray as xr

print("Loading Bioimage Model Zoo pretrained model . . .")
self.bmz_model_resource = bioimageio.core.load_resource_description(self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI)

# Temporal adjust until we find a solution with the BMZ team
bs = self.cfg.TRAIN.BATCH_SIZE
if self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI == "10.5281/zenodo.5874841":
bs = 2
elif self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI == "10.5281/zenodo.6028097":
bs = 4

# Change PATCH_SIZE with the one stored in the RDF
input_image = np.load(self.bmz_model_resource.test_inputs[0])
opts = ["DATA.PATCH_SIZE", input_image.shape[2:]+(input_image.shape[1],)]
opts = ["DATA.PATCH_SIZE", input_image.shape[2:]+(input_image.shape[1],),
"TRAIN.BATCH_SIZE", bs]
print("[BMZ] Changed 'DATA.PATCH_SIZE' from {} to {} as defined in the RDF"
.format(self.cfg.DATA.PATCH_SIZE,opts[1]))
.format(self.cfg.DATA.PATCH_SIZE,opts[1]))
print("[BMZ] Changed 'TRAIN.BATCH_SIZE' from {} to {} as defined in the RDF"
.format(self.cfg.TRAIN.BATCH_SIZE,opts[3]))
self.cfg.merge_from_list(opts)


@abstractmethod
def define_metrics(self):
"""
Expand Down Expand Up @@ -450,12 +458,16 @@ def bmz_model_call(self, in_img, is_train=False):
# Predict
prediction_tensors = self.model.predict(*list(in_img.values()))

# Apply post-processing
prediction = dict(zip([out.name for out in self.model.output_specs], prediction_tensors))
self.model.apply_postprocessing(prediction, self.bmz_computed_measures)
# Apply post-processing (if any)
if bool(self.model.output_specs[0].postprocessing):
prediction = dict(zip([out.name for out in self.model.output_specs], prediction_tensors))
self.model.apply_postprocessing(prediction, self.bmz_computed_measures)

# Convert back to Tensor
prediction = torch.from_numpy(prediction['output0'].to_numpy())
# Convert back to Tensor
prediction = torch.from_numpy(prediction['output0'].to_numpy())
else:
# Convert back to Tensor
prediction = torch.from_numpy(np.array(prediction_tensors[0]))

return prediction

Expand Down Expand Up @@ -822,6 +834,10 @@ def apply_model_activations(self, pred, training=False):
pred : Torch tensor
Resulting predictions after applying last activation(s).
"""
# Not apply the activation, as it will be done in the BMZ model
if self.cfg.MODEL.SOURCE == "bmz":
return pred

if not isinstance(pred, list):
multiple_heads = False
pred = [pred]
Expand Down
11 changes: 7 additions & 4 deletions biapy/engine/check_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ def check_configuration(cfg, jobname, check_data_paths=True):
raise ValueError("'MODEL.SOURCE' needs to be one between ['biapy', 'bmz', 'torchvision']")

if cfg.MODEL.SOURCE == "bmz":
if cfg.TRAIN.ENABLE:
raise ValueError("Currently not supported to train a BMZ model")
if cfg.MODEL.BMZ.SOURCE_MODEL_DOI == "":
raise ValueError("'MODEL.BMZ.SOURCE_MODEL_DOI' needs to be configured when 'MODEL.SOURCE' is 'bmz'")

Expand Down Expand Up @@ -572,8 +570,13 @@ def check_configuration(cfg, jobname, check_data_paths=True):
raise ValueError("When PROBLEM.NDIM == {} DATA.TEST.PADDING tuple must be length {}, given {}."
.format(cfg.PROBLEM.NDIM, dim_count, cfg.DATA.TEST.PADDING))
if len(cfg.DATA.PATCH_SIZE) != dim_count+1:
raise ValueError("When PROBLEM.NDIM == {} DATA.PATCH_SIZE tuple must be length {}, given {}."
.format(cfg.PROBLEM.NDIM, dim_count+1, cfg.DATA.PATCH_SIZE))
if cfg.MODEL.SOURCE != "bmz":
raise ValueError("When PROBLEM.NDIM == {} DATA.PATCH_SIZE tuple must be length {}, given {}."
.format(cfg.PROBLEM.NDIM, dim_count+1, cfg.DATA.PATCH_SIZE))
else:
print("WARNING: when PROBLEM.NDIM == {} DATA.PATCH_SIZE tuple must be length {}, given {}. Not an error "
"because you are using a model from Bioimage Model Zoo (BMZ) and the patch size will be determined by the model."
" However, this message is printed so you are aware of this. ")
assert cfg.DATA.NORMALIZATION.TYPE in ['div', 'custom'], "DATA.NORMALIZATION.TYPE not in ['div', 'custom']"
assert cfg.DATA.NORMALIZATION.APPLICATION_MODE in ["image", "dataset"], "'DATA.NORMALIZATION.APPLICATION_MODE' needs to be one between ['image', 'dataset']"
if not cfg.DATA.TRAIN.IN_MEMORY and cfg.DATA.NORMALIZATION.APPLICATION_MODE == "dataset":
Expand Down

0 comments on commit 79fe3b1

Please sign in to comment.