Merge pull request #406 from bioimage-io/predict_cmd
add predict command
FynnBe authored Sep 13, 2024
2 parents db331df + 7306f98 commit c59a1f8
Showing 27 changed files with 1,486 additions and 448 deletions.
329 changes: 289 additions & 40 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bioimageio/core/VERSION
@@ -1,3 +1,3 @@
{
"version": "0.6.8"
"version": "0.6.9"
}
5 changes: 4 additions & 1 deletion bioimageio/core/__init__.py
@@ -4,17 +4,20 @@

from bioimageio.spec import build_description as build_description
from bioimageio.spec import dump_description as dump_description
from bioimageio.spec import load_dataset_description as load_dataset_description
from bioimageio.spec import load_description as load_description
from bioimageio.spec import (
load_description_and_validate_format_only as load_description_and_validate_format_only,
)
from bioimageio.spec import load_model_description as load_model_description
from bioimageio.spec import save_bioimageio_package as save_bioimageio_package
from bioimageio.spec import (
save_bioimageio_package_as_folder as save_bioimageio_package_as_folder,
)
from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only
from bioimageio.spec import validate_format as validate_format

from . import digest_spec as digest_spec
from ._prediction_pipeline import PredictionPipeline as PredictionPipeline
from ._prediction_pipeline import (
create_prediction_pipeline as create_prediction_pipeline,
@@ -38,4 +41,4 @@
# aliases
test_resource = test_description
load_resource = load_description
load_model = load_description
load_model = load_model_description
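Note that this also corrects the `load_model` alias, which previously pointed at the generic `load_description`; it now resolves to the model-specific `load_model_description`. A minimal sketch of the corrected alias in use; the model path is a hypothetical placeholder:

```python
from bioimageio.core import load_model, load_model_description

assert load_model is load_model_description  # alias now targets the model-specific loader

model_descr = load_model("path/to/model/rdf.yaml")  # hypothetical path to a model RDF
```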
8 changes: 7 additions & 1 deletion bioimageio/core/__main__.py
@@ -1,4 +1,10 @@
from bioimageio.core.commands import main
from bioimageio.core.cli import Bioimageio


def main():
cli = Bioimageio() # pyright: ignore[reportCallIssue]
cli.run()


if __name__ == "__main__":
main()
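For orientation: `Bioimageio` comes from the new `bioimageio.core.cli` module introduced by this PR, which also adds the `predict` subcommand named in the PR title. A hedged sketch of reaching the same entry point from a script; the subcommand flags are defined in `cli.py` and are not shown in this diff:

```python
# Presumed command-line form (flags defined in bioimageio/core/cli.py, not shown here):
#
#   python -m bioimageio.core predict ...
#
# Programmatic equivalent, mirroring __main__.main() above:
from bioimageio.core.cli import Bioimageio

cli = Bioimageio()  # pyright: ignore[reportCallIssue]
cli.run()
```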
78 changes: 48 additions & 30 deletions bioimageio/core/_prediction_pipeline.py
@@ -55,8 +55,8 @@ def __init__(
postprocessing: List[Processing],
model_adapter: ModelAdapter,
default_ns: Union[
v0_5.ParameterizedSize.N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
v0_5.ParameterizedSize_N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
] = 10,
default_batch_size: int = 1,
) -> None:
@@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
self.model_description.id or self.model_description.name
)

def predict_sample_with_blocking(
def predict_sample_with_fixed_blocking(
self,
sample: Sample,
input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
*,
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
ns: Optional[
Union[
v0_5.ParameterizedSize.N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
]
] = None,
batch_size: Optional[int] = None,
) -> Sample:
"""predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
if not skip_preprocessing:
self.apply_preprocessing(sample)

if isinstance(self.model_description, v0_4.ModelDescr):
raise NotImplementedError(
"predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
)

ns = ns or self._default_ns
if isinstance(ns, int):
ns = {
(ipt.id, a.id): ns
for ipt in self.model_description.inputs
for a in ipt.axes
if isinstance(a.size, v0_5.ParameterizedSize)
}
input_block_shape = self.model_description.get_tensor_sizes(
ns, batch_size or self._default_batch_size
).inputs

n_blocks, input_blocks = sample.split_into_blocks(
input_block_shape,
halo=self._default_input_halo,
@@ -239,6 +216,47 @@ def predict_sample_with_blocking(

return predicted_sample

def predict_sample_with_blocking(
self,
sample: Sample,
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
ns: Optional[
Union[
v0_5.ParameterizedSize_N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
]
] = None,
batch_size: Optional[int] = None,
) -> Sample:
"""predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

if isinstance(self.model_description, v0_4.ModelDescr):
raise NotImplementedError(
"`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
+ f" {self.model_description.name}."
+ " Consider using `predict_sample_with_fixed_blocking`"
)

ns = ns or self._default_ns
if isinstance(ns, int):
ns = {
(ipt.id, a.id): ns
for ipt in self.model_description.inputs
for a in ipt.axes
if isinstance(a.size, v0_5.ParameterizedSize)
}
input_block_shape = self.model_description.get_tensor_sizes(
ns, batch_size or self._default_batch_size
).inputs

return self.predict_sample_with_fixed_blocking(
sample,
input_block_shape=input_block_shape,
skip_preprocessing=skip_preprocessing,
skip_postprocessing=skip_postprocessing,
)

# def predict(
# self,
# inputs: Predict_IO,
@@ -310,8 +328,8 @@ def create_prediction_pipeline(
),
model_adapter: Optional[ModelAdapter] = None,
ns: Union[
v0_5.ParameterizedSize.N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
v0_5.ParameterizedSize_N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
] = 10,
**deprecated_kwargs: Any,
) -> PredictionPipeline:
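The refactor splits blocked prediction in two: `predict_sample_with_fixed_blocking` consumes explicit per-input block shapes, while `predict_sample_with_blocking` derives those shapes from `ns` (parameter steps per parameterized axis) via `get_tensor_sizes` and then delegates. A hedged usage sketch; the model source is a placeholder, and `get_test_inputs` from `bioimageio.core.digest_spec` (used by the test utilities) stands in for however you would build a `Sample` from your own data:

```python
from bioimageio.core import create_prediction_pipeline, load_model_description
from bioimageio.core.digest_spec import get_test_inputs  # assumption: helper used by the test utilities

model = load_model_description("path/to/model/rdf.yaml")  # hypothetical model source (v0.5 for the `ns` path)
pp = create_prediction_pipeline(model)

# build a Sample; here the model's bundled test tensors stand in for real input data
sample = get_test_inputs(model)

# derive per-input block shapes from `ns` and `batch_size`, then delegate to
# predict_sample_with_fixed_blocking -- this is what the refactored method does internally
prediction = pp.predict_sample_with_blocking(sample, ns=10, batch_size=1)

# with known block shapes (mapping: input member id -> {axis id: block size}) call it directly:
# prediction = pp.predict_sample_with_fixed_blocking(sample, input_block_shape=block_shape)
```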
16 changes: 8 additions & 8 deletions bioimageio/core/_resource_tests.py
@@ -1,7 +1,7 @@
import traceback
import warnings
from itertools import product
from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union
from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union

import numpy as np
from loguru import logger
@@ -57,7 +57,7 @@ def test_description(
*,
format_version: Union[Literal["discover", "latest"], str] = "discover",
weight_format: Optional[WeightsFormat] = None,
devices: Optional[List[str]] = None,
devices: Optional[Sequence[str]] = None,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
@@ -83,7 +83,7 @@ def load_description_and_test(
*,
format_version: Union[Literal["discover", "latest"], str] = "discover",
weight_format: Optional[WeightsFormat] = None,
devices: Optional[List[str]] = None,
devices: Optional[Sequence[str]] = None,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
@@ -138,12 +138,12 @@ def load_description_and_test(
def _test_model_inference(
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
weight_format: WeightsFormat,
devices: Optional[List[str]],
devices: Optional[Sequence[str]],
absolute_tolerance: float,
relative_tolerance: float,
decimal: Optional[int],
) -> None:
test_name = "Reproduce test outputs from test inputs"
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
logger.info("starting '{}'", test_name)
error: Optional[str] = None
tb: List[str] = []
@@ -209,15 +209,15 @@ def _test_model_inference(
def _test_model_inference_parametrized(
model: v0_5.ModelDescr,
weight_format: WeightsFormat,
devices: Optional[List[str]],
devices: Optional[Sequence[str]],
) -> None:
if not any(
isinstance(a.size, v0_5.ParameterizedSize)
for ipt in model.inputs
for a in ipt.axes
):
# no parameterized sizes => set n=0
ns: Set[v0_5.ParameterizedSize.N] = {0}
ns: Set[v0_5.ParameterizedSize_N] = {0}
else:
ns = {0, 1, 2}

@@ -236,7 +236,7 @@ def _test_model_inference_parametrized(
# no batch axis
batch_sizes = {1}

test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = {
(n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
}
logger.info(
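In these test utilities, `devices` now accepts any `Sequence[str]` (a tuple works), and the inference test name includes the weight format, which disambiguates summary entries when several weight formats are tested. A hedged usage sketch; the RDF source is a placeholder and the chosen weight format is an assumption:

```python
from bioimageio.core import test_description

summary = test_description(
    "path/to/model/rdf.yaml",            # hypothetical model source
    weight_format="pytorch_state_dict",  # assumption: restrict testing to one weights entry
    devices=("cpu",),                    # any Sequence[str] is accepted, not only a list
)
print(summary.status)  # overall status of the returned validation summary
```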
27 changes: 6 additions & 21 deletions bioimageio/core/axis.py
@@ -26,19 +26,6 @@ def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]):
S = TypeVar("S", bound=str)


def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
if a == "b":
return AxisId("batch")
elif a == "t":
return AxisId("time")
elif a == "i":
return AxisId("index")
elif a == "c":
return AxisId("channel")
else:
return AxisId(a)


AxisId = v0_5.AxisId

T = TypeVar("T")
@@ -47,7 +34,7 @@ def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
BatchSize = int

AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
AxisLike = Union[AxisLetter, v0_5.AnyAxis, "Axis"]
AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"]


@dataclass
@@ -62,7 +49,7 @@ def create(cls, axis: AxisLike) -> Axis:
elif isinstance(axis, Axis):
return Axis(id=axis.id, type=axis.type)
elif isinstance(axis, str):
return Axis(id=_get_axis_id(axis), type=_get_axis_type(axis))
return Axis(id=AxisId(axis), type=_get_axis_type(axis))
elif isinstance(axis, v0_5.AxisBase):
return Axis(id=AxisId(axis.id), type=axis.type)
else:
@@ -71,7 +58,7 @@ def create(cls, axis: AxisLike) -> Axis:

@dataclass
class AxisInfo(Axis):
maybe_singleton: bool
maybe_singleton: bool # TODO: replace 'maybe_singleton' with size min/max for better axis guessing

@classmethod
def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
@@ -80,18 +67,16 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:

axis_base = super().create(axis)
if maybe_singleton is None:
if isinstance(axis, Axis):
maybe_singleton = False
elif isinstance(axis, str):
maybe_singleton = axis == "b"
if isinstance(axis, (Axis, str)):
maybe_singleton = True
else:
if axis.size is None:
maybe_singleton = True
elif isinstance(axis.size, int):
maybe_singleton = axis.size == 1
elif isinstance(axis.size, v0_5.SizeReference):
maybe_singleton = (
False # TODO: check if singleton is ok for a `SizeReference`
True # TODO: check if singleton is ok for a `SizeReference`
)
elif isinstance(
axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
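With `_get_axis_id` removed, a plain axis letter now keeps its literal id (only the axis type is still inferred from the letter), `AxisLike` additionally accepts an `AxisId`, and `AxisInfo.create` defaults `maybe_singleton` to `True` for letters and `Axis` objects. A small sketch of the resulting behavior; the printed type assumes `_get_axis_type` maps "b" to "batch" as before:

```python
from bioimageio.core.axis import Axis, AxisInfo

ax = Axis.create("b")
print(ax.id, ax.type)  # expected: b batch -- the id is no longer expanded to "batch"

info = AxisInfo.create("y")
print(info.maybe_singleton)  # True -- previously only "b" defaulted to a possible singleton
```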