diff --git a/setup.py b/setup.py index ec6bcbe93..1af316128 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ def get_requirements(require_name=None): "sleap-track=sleap.nn.inference:main", "sleap-inspect=sleap.info.labels:main", "sleap-diagnostic=sleap.diagnostic:main", + "sleap-export=sleap.nn.inference:export_cli", ], }, python_requires=">=3.6", diff --git a/sleap/__init__.py b/sleap/__init__.py index 35ca4b460..7e506b10a 100644 --- a/sleap/__init__.py +++ b/sleap/__init__.py @@ -14,7 +14,7 @@ import sleap.nn from sleap.nn.data import pipelines from sleap.nn import inference -from sleap.nn.inference import load_model +from sleap.nn.inference import load_model, export_model from sleap.nn.system import use_cpu_only, disable_preallocation from sleap.nn.system import summary as system_summary from sleap.nn.config import TrainingJobConfig, load_config diff --git a/sleap/nn/data/resizing.py b/sleap/nn/data/resizing.py index ad0c19d81..f5def38d5 100644 --- a/sleap/nn/data/resizing.py +++ b/sleap/nn/data/resizing.py @@ -30,6 +30,7 @@ def find_padding_for_stride( return pad_bottom, pad_right +@tf.function def pad_to_stride(image: tf.Tensor, max_stride: int) -> tf.Tensor: """Pad an image to meet a max stride constraint. diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 3df127a4a..72a15c8f1 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -66,6 +66,9 @@ from sleap.io.dataset import Labels from sleap.util import frame_list +from tensorflow.python.framework.convert_to_constants import ( + convert_variables_to_constants_v2, +) logger = logging.getLogger(__name__) @@ -440,6 +443,51 @@ def predict( # Just return the raw results. return list(generator) + def export_model( + self, + save_path: str, + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, + ): + + """Export a trained SLEAP model as a frozen graph. Initializes model, + creates a dummy tracing batch and passes it through the model. The + frozen graph is saved along with training meta info. + + Args: + save_path: Path to output directory to store the frozen graph + signatures: String defining the input and output types for + computation. + save_traces: If `True` (default) the SavedModel will store the + function traces for each layer + model_name: (Optional) Name to give the model. If given, will be + added to the output json file containing meta information about the + model + tensors: (Optional) Dictionary describing the predicted tensors (see + sleap.nn.data.utils.describe_tensors as an example) + + """ + + self._initialize_inference_model() + + first_inference_layer = self.inference_model.layers[0] + keras_model_shape = first_inference_layer.keras_model.input.shape + + sample_shape = tuple( + ( + np.array(keras_model_shape[1:3]) / first_inference_layer.input_scale + ).astype(int) + ) + (keras_model_shape[3],) + + tracing_batch = np.zeros((1,) + sample_shape, dtype="uint8") + outputs = self.inference_model.predict(tracing_batch) + + self.inference_model.export_model( + save_path, signatures, save_traces, model_name, tensors + ) + # TODO: Rewrite this class. @attr.s(auto_attribs=True) @@ -925,6 +973,78 @@ def predict_on_batch( return outs + def export_model( + self, + save_path: str, + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, + ): + """Save the frozen graph of a model. + + Args: + save_path: Path to output directory to store the frozen graph + signatures: String defining the input and output types for + computation. + save_traces: If `True` (default) the SavedModel will store the + function traces for each layer + model_name: (Optional) Name to give the model. If given, will be + added to the output json file containing meta information about the + model + tensors: (Optional) Dictionary describing the predicted tensors (see + sleap.nn.data.utils.describe_tensors as an example) + + + Notes: + This function call writes relevant meta data to an `info.json` file + in the given save_path in addition to the frozen_graph.pb file + + """ + os.makedirs(save_path, exist_ok=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + + self.save(tmp_dir, save_format="tf", save_traces=save_traces) + + imported = tf.saved_model.load(tmp_dir) + + model = imported.signatures[signatures] + + info = { + "model_structured_input_signature": model.structured_input_signature, + "model_structured_outputs_signature": model.structured_outputs, + } + + if model_name: + info["model_name"] = model_name + if tensors: + info["predicted_tensors"] = tensors + + full_model = tf.function(lambda x: model(x)) + + full_model = full_model.get_concrete_function( + tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype) + ) + + frozen_func = convert_variables_to_constants_v2(full_model) + frozen_func.graph.as_graph_def() + + info["frozen_model_inputs"] = frozen_func.inputs + info["frozen_model_outputs"] = frozen_func.outputs + + with (Path(save_path) / "info.json").open("w") as fp: + json.dump( + info, fp, indent=4, sort_keys=True, separators=(",", ": "), default=str + ) + + tf.io.write_graph( + graph_or_graph_def=frozen_func.graph, + logdir=save_path, + name="frozen_graph.pb", + as_text=False, + ) + def get_model_output_stride( model: tf.keras.Model, input_ind: int = 0, output_ind: int = -1 @@ -1331,6 +1451,19 @@ def _make_labeled_frames_from_generator( return predicted_frames + def export_model( + self, + save_path: str, + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, + ): + + super().export_model(save_path, signatures, save_traces, model_name, tensors) + + self.confmap_config.save_json(os.path.join(save_path, "confmap_config.json")) + class CentroidCrop(InferenceLayer): """Inference layer for applying centroid crop-based models. @@ -1373,6 +1506,10 @@ class CentroidCrop(InferenceLayer): automatically by searching for the first tensor that contains `"OffsetRefinementHead"` in its name. If the head is not present, the method specified in the `refinement` attribute will be used. + return_crops: If `True`, the crops and offsets will be returned together with + the predicted peaks. This is true by default since crops are used + for finding instance peaks in a top down model. If using a centroid + only inference model, this should be set to `False`. """ def __init__( @@ -1388,6 +1525,7 @@ def __init__( return_confmaps: bool = False, confmaps_ind: Optional[int] = None, offsets_ind: Optional[int] = None, + return_crops: bool = True, **kwargs, ): super().__init__( @@ -1424,7 +1562,9 @@ def __init__( self.refinement = refinement self.integral_patch_size = integral_patch_size self.return_confmaps = return_confmaps + self.return_crops = return_crops + @tf.function def call(self, inputs): """Predict centroid confidence maps and crop around peaks. @@ -1439,11 +1579,6 @@ def call(self, inputs): Returns: A dictionary of outputs grouped by sample with keys: - `"crops"`: Cropped images of shape - `(samples, ?, crop_size, crop_size, channels)`. - `"crop_offsets"`: Coordinates of the top-left of the crops as `(x, y)` - offsets of shape `(samples, ?, 2)` for adjusting the predicted peak - coordinates. `"centroids"`: The predicted centroids of shape `(samples, ?, 2)`. `"centroid_vals": The centroid confidence values of shape `(samples, ?)`. @@ -1451,6 +1586,14 @@ def call(self, inputs): contain a key named `"centroid_confmaps"` containing a `tf.RaggedTensor` of shape `(samples, ?, output_height, output_width, 1)` containing the confidence maps predicted by the model. + + If the `return_crops` attribute is set to `True`, the output will + also contain keys named `crops` and `crop_offsets`. The former is a + `tf.RaggedTensor` of cropped images of shape `(samples, ?, + crop_size, crop_size, channels)`. The latter is a `tf.RaggedTensor` + of Coordinates of the top-left of the crops as `(x, y)` offsets of + shape `(samples, ?, 2)` for adjusting the predicted peak + coordinates. """ if isinstance(inputs, dict): # Pull out image from example dictionary. @@ -1540,28 +1683,30 @@ def call(self, inputs): centroids = tf.RaggedTensor.from_value_rowids( centroid_points, crop_sample_inds, nrows=samples ) - crops = tf.RaggedTensor.from_value_rowids( - crops, crop_sample_inds, nrows=samples - ) - crop_offsets = tf.RaggedTensor.from_value_rowids( - crop_offsets, crop_sample_inds, nrows=samples - ) centroid_vals = tf.RaggedTensor.from_value_rowids( centroid_vals, crop_sample_inds, nrows=samples ) - outputs = dict( - centroids=centroids, - centroid_vals=centroid_vals, - crops=crops, - crop_offsets=crop_offsets, - ) + outputs = dict(centroids=centroids, centroid_vals=centroid_vals) if self.return_confmaps: # Return confidence maps with outputs. cms = tf.RaggedTensor.from_value_rowids( cms, crop_sample_inds, nrows=samples ) outputs["centroid_confmaps"] = cms + + if self.return_crops: + # return crops and offsets + crops = tf.RaggedTensor.from_value_rowids( + crops, crop_sample_inds, nrows=samples + ) + crop_offsets = tf.RaggedTensor.from_value_rowids( + crop_offsets, crop_sample_inds, nrows=samples + ) + + outputs["crops"] = crops + outputs["crop_offsets"] = crop_offsets + return outputs @@ -1790,6 +1935,49 @@ def call( return outputs +class CentroidInferenceModel(InferenceModel): + """Centroid only instance prediction model. + + This model encapsulates the first step in a top-down approach where instances are detected by + local peak detection of an anchor point and then cropped. + + Attributes: + centroid_crop: A centroid cropping layer. This can be either `CentroidCrop` or + `CentroidCropGroundTruth`. This layer takes the full image as input and + outputs a set of centroids and cropped boxes. + """ + + def __init__(self, centroid_crop: Union[CentroidCrop, CentroidCropGroundTruth]): + super().__init__() + self.centroid_crop = centroid_crop + + def call( + self, example: Union[Dict[str, tf.Tensor], tf.Tensor] + ) -> Dict[str, tf.Tensor]: + """Predict instances for one batch of images. + + Args: + example: This may be either a single batch of images as a 4-D tensor of + shape `(batch_size, height, width, channels)`, or a dictionary + containing the image batch in the `"images"` key. If using a ground + truth model for either centroid cropping or instance peaks, the full + example from a `Pipeline` is required for providing the metadata. + + Returns: + The predicted instances as a dictionary of tensors with keys: + + `"centroids": (batch_size, n_instances, 2)`: Instance centroids. + `"centroid_vals": (batch_size, n_instances)`: Instance centroid confidence + values. + """ + if isinstance(example, tf.Tensor): + example = dict(image=example) + + crop_output = self.centroid_crop(example) + + return crop_output + + class TopDownInferenceModel(InferenceModel): """Top-down instance prediction model. @@ -2181,6 +2369,26 @@ def _make_labeled_frames_from_generator( return predicted_frames + def export_model( + self, + save_path: str, + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, + ): + + super().export_model(save_path, signatures, save_traces, model_name, tensors) + + if self.confmap_config is not None: + self.confmap_config.save_json( + os.path.join(save_path, "confmap_config.json") + ) + if self.centroid_config is not None: + self.centroid_config.save_json( + os.path.join(save_path, "centroid_config.json") + ) + class BottomUpInferenceLayer(InferenceLayer): """Keras layer that predicts instances from images using a trained model. @@ -3240,6 +3448,9 @@ class TopDownMultiClassFindPeaks(InferenceLayer): classification vectors. If `None` (the default), this will be detected automatically by searching for the first tensor that contains `"ClassVectorsHead"` in its name. + optimal_grouping: If `True` (the default), group peaks from classification + probabilities. If saving a frozen graph of the model, this will be + overridden to `False`. """ def __init__( @@ -3255,6 +3466,7 @@ def __init__( confmaps_ind: Optional[int] = None, offsets_ind: Optional[int] = None, class_vectors_ind: Optional[int] = None, + optimal_grouping: bool = True, **kwargs, ): super().__init__( @@ -3268,6 +3480,7 @@ def __init__( self.confmaps_ind = confmaps_ind self.class_vectors_ind = class_vectors_ind self.offsets_ind = offsets_ind + self.optimal_grouping = optimal_grouping if self.confmaps_ind is None: self.confmaps_ind = find_head( @@ -3340,9 +3553,14 @@ def call( shape `(samples, ?, output_height, output_width, nodes)` containing the confidence maps predicted by the model. - If the `return_class_vectors` attribe is set to `True`, the output will also + If the `return_class_vectors` attribute is set to `True`, the output will also contain a key named `"class_vectors"` containing the full classification probabilities for all crops. + + If the `optimal_grouping` attribute is set to `True`, peaks are + grouped from classification properties. This is overridden to False + if exporting a frozen graph to allow for tracing. Note: If set to False + this will change the output dict keys and shapes. """ if isinstance(inputs, dict): crops = inputs["crops"] @@ -3415,17 +3633,30 @@ def call( crop_offsets = inputs["crop_offsets"].merge_dims(0, 1) peak_points = peak_points + tf.expand_dims(crop_offsets, axis=1) - # Group peaks from classification probabilities. - points, point_vals, class_probs = sleap.nn.identity.classify_peaks_from_vectors( - peak_points, peak_vals, peak_class_probs, crop_sample_inds, samples - ) + if self.optimal_grouping: + # Group peaks from classification probabilities. + ( + points, + point_vals, + class_probs, + ) = sleap.nn.identity.classify_peaks_from_vectors( + peak_points, peak_vals, peak_class_probs, crop_sample_inds, samples + ) + + # Build outputs. + outputs = { + "instance_peaks": points, + "instance_peak_vals": point_vals, + "instance_scores": class_probs, + } + + else: + outputs = { + "instance_peaks": peak_points, + "instance_peak_vals": peak_vals, + "instance_scores": peak_class_probs, + } - # Build outputs. - outputs = { - "instance_peaks": points, - "instance_peak_vals": point_vals, - "instance_scores": class_probs, - } if "centroids" in inputs: outputs["centroids"] = inputs["centroids"] if "centroids" in inputs: @@ -3435,7 +3666,7 @@ def call( cms, crop_sample_inds, nrows=samples ) outputs["instance_confmaps"] = cms - if self.return_class_vectors: + if self.return_class_vectors and self.optimal_grouping: outputs["class_vectors"] = peak_class_probs return outputs @@ -3497,6 +3728,19 @@ def call( peaks_output = self.instance_peaks(crop_output) return peaks_output + def export_model( + self, + save_path: str, + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, + ): + + self.instance_peaks.optimal_grouping = False + + super().export_model(save_path, signatures, save_traces, model_name, tensors) + @attr.s(auto_attribs=True) class TopDownMultiClassPredictor(Predictor): @@ -3815,6 +4059,26 @@ def _make_labeled_frames_from_generator( return predicted_frames + def export_model( + self, + save_path: str, + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, + ): + + super().export_model(save_path, signatures, save_traces, model_name, tensors) + + if self.confmap_config is not None: + self.confmap_config.save_json( + os.path.join(save_path, "confmap_config.json") + ) + if self.centroid_config is not None: + self.centroid_config.save_json( + os.path.join(save_path, "centroid_config.json") + ) + def load_model( model_path: Union[str, List[str]], @@ -3921,6 +4185,63 @@ def load_model( return predictor +def export_model( + model_path: Union[str, List[str]], + save_path: str = "exported_model", + signatures: str = "serving_default", + save_traces: bool = True, + model_name: Optional[str] = None, + tensors: Optional[Dict[str, str]] = None, +): + + """High level export of a trained SLEAP model as a frozen graph. + + Args: + model_path: Path to model or list of path to models that were trained by SLEAP. + These should be the directories that contain `training_job.json` and + `best_model.h5`. + save_path: Path to output directory to store the frozen graph + signatures: String defining the input and output types for + computation. + save_traces: If `True` (default) the SavedModel will store the + function traces for each layer + model_name: (Optional) Name to give the model. If given, will be + added to the output json file containing meta information about the + model + tensors: (Optional) Dictionary describing the predicted tensors (see + sleap.nn.data.utils.describe_tensors as an example) + + """ + + predictor = load_model(model_path) + predictor.export_model(save_path, signatures, save_traces, model_name, tensors) + + +def export_cli(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + dest="models", + action="append", + help=( + "Path to trained model directory (with training_config.json). " + "Multiple models can be specified, each preceded by --model." + ), + ) + parser.add_argument( + "-e", + "export_path", + type=str, + nargs="?", + default="exported_model", + help=("Path to data export model to."), + ) + + args, _ = parser.parse_known_args() + export_model(args["models"], args["export_path"]) + + def _make_cli_parser() -> argparse.ArgumentParser: """Create argument parser for CLI. diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index b6b56ab92..84dca00ae 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -334,6 +334,7 @@ def integral_regression( return x_hat, y_hat +@tf.function def find_global_peaks( cms: tf.Tensor, threshold: float = 0.2, @@ -447,6 +448,7 @@ def find_global_peaks_integral( ) +@tf.function def find_local_peaks( cms: tf.Tensor, threshold: float = 0.2, @@ -561,6 +563,7 @@ def find_local_peaks_integral( ) +@tf.function def find_global_peaks_with_offsets( cms: tf.Tensor, offsets: tf.Tensor, threshold: float = 0.2 ) -> Tuple[tf.Tensor, tf.Tensor]: @@ -640,6 +643,7 @@ def find_global_peaks_with_offsets( return refined_peaks, peak_vals +@tf.function def find_local_peaks_with_offsets( cms: tf.Tensor, offsets: tf.Tensor, diff --git a/tests/nn/data/test_resizing.py b/tests/nn/data/test_resizing.py index c8d04a3d6..e0f63ebbb 100644 --- a/tests/nn/data/test_resizing.py +++ b/tests/nn/data/test_resizing.py @@ -27,7 +27,7 @@ def test_find_padding_for_stride(): def test_pad_to_stride(): np.testing.assert_array_equal( - resizing.pad_to_stride(tf.ones([3, 5, 1]), max_stride=2), + resizing.pad_to_stride.__wrapped__(tf.ones([3, 5, 1]), max_stride=2), tf.expand_dims( [ [1, 1, 1, 1, 1, 0], @@ -39,14 +39,20 @@ def test_pad_to_stride(): ), ) assert ( - resizing.pad_to_stride(tf.ones([3, 5, 1], dtype=tf.uint8), max_stride=2).dtype + resizing.pad_to_stride.__wrapped__( + tf.ones([3, 5, 1], dtype=tf.uint8), max_stride=2 + ).dtype == tf.uint8 ) assert ( - resizing.pad_to_stride(tf.ones([3, 5, 1], dtype=tf.float32), max_stride=2).dtype + resizing.pad_to_stride.__wrapped__( + tf.ones([3, 5, 1], dtype=tf.float32), max_stride=2 + ).dtype == tf.float32 ) - assert resizing.pad_to_stride(tf.ones([4, 4, 1]), max_stride=2).shape == (4, 4, 1) + assert resizing.pad_to_stride.__wrapped__( + tf.ones([4, 4, 1]), max_stride=2 + ).shape == (4, 4, 1) def test_resize_image(): diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 7090cb7f0..85773436c 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1,5 +1,7 @@ +import ast import pytest import numpy as np +import json from sleap.io.dataset import Labels import tensorflow as tf import sleap @@ -22,14 +24,18 @@ SingleInstancePredictor, CentroidCropGroundTruth, CentroidCrop, + CentroidInferenceModel, FindInstancePeaksGroundTruth, FindInstancePeaks, + TopDownMultiClassFindPeaks, TopDownInferenceModel, + TopDownMultiClassInferenceModel, TopDownPredictor, BottomUpPredictor, BottomUpMultiClassPredictor, TopDownMultiClassPredictor, load_model, + export_model, _make_cli_parser, _make_tracker_from_cli, main as sleap_track, @@ -205,6 +211,13 @@ def test_centroid_crop_layer(): return_confmaps=False, ) + # For Codecov to realize the wrapped CentroidCrop.call is tested/covered, + # we need to unbind CentroidCrop.call from its bind with TfMethodTarget object + # and then rebind the standalone function with the CentroidCrop object + TfMethodTarget_object = layer.call.__wrapped__.__self__ # Get the bound object + original_func = TfMethodTarget_object.weakrefself_func__() # Get unbound function + layer.call = original_func.__get__(layer, layer.__class__) # Bind function + out = layer(cms) assert tuple(out["centroids"].shape) == (1, None, 2) assert tuple(out["centroid_vals"].shape) == (1, None) @@ -797,6 +810,245 @@ def test_ensure_numpy( assert type(out["n_valid"]) == np.ndarray +def test_centroid_inference(): + + xv, yv = make_grid_vectors(image_height=12, image_width=12, output_stride=1) + points = tf.cast([[[1.75, 2.75]], [[3.75, 4.75]], [[5.75, 6.75]]], tf.float32) + cms = tf.expand_dims(make_multi_confmaps(points, xv, yv, sigma=1.5), axis=0) + + x_in = tf.keras.layers.Input([12, 12, 1]) + x_out = tf.keras.layers.Lambda(lambda x: x, name="CentroidConfmapsHead")(x_in) + model = tf.keras.Model(inputs=x_in, outputs=x_out) + + layer = CentroidCrop( + keras_model=model, + input_scale=1.0, + crop_size=3, + pad_to_stride=1, + output_stride=None, + refinement="local", + integral_patch_size=5, + peak_threshold=0.2, + return_confmaps=False, + return_crops=False, + ) + + # For Codecov to realize the wrapped CentroidCrop.call is tested/covered, + # we need to unbind CentroidCrop.call from its bind with TfMethodTarget object + # and then rebind the standalone function with the CentroidCrop object + TfMethodTarget_object = layer.call.__wrapped__.__self__ # Get the bound object + original_func = TfMethodTarget_object.weakrefself_func__() # Get unbound function + layer.call = original_func.__get__(layer, layer.__class__) # Bind function + + out = layer(cms) + assert tuple(out["centroids"].shape) == (1, None, 2) + assert tuple(out["centroid_vals"].shape) == (1, None) + + assert tuple(out["centroids"].bounding_shape()) == (1, 3, 2) + assert tuple(out["centroid_vals"].bounding_shape()) == (1, 3) + + assert_allclose(out["centroids"][0].numpy(), points.numpy().squeeze(axis=1)) + assert_allclose(out["centroid_vals"][0].numpy(), [1, 1, 1], atol=0.1) + + model = CentroidInferenceModel(layer) + + preds = model.predict(cms) + + assert preds["centroids"].shape == (1, 3, 2) + assert preds["centroid_vals"].shape == (1, 3) + + +def export_frozen_graph(model, preds, output_path): + + tensors = {} + + for key, val in preds.items(): + dtype = str(val.dtype) if isinstance(val.dtype, np.dtype) else repr(val.dtype) + tensors[key] = { + "type": f"{type(val).__name__}", + "shape": f"{val.shape}", + "dtype": dtype, + "device": f"{val.device if hasattr(val, 'device') else 'N/A'}", + } + + with output_path as d: + model.export_model(d.as_posix(), tensors=tensors) + + tf.compat.v1.reset_default_graph() + with tf.compat.v2.io.gfile.GFile(f"{d}/frozen_graph.pb", "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def) + + with open(f"{d}/info.json") as json_file: + info = json.load(json_file) + + for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]: + + saved_name = ( + tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "") + ) + saved_shape = ast.literal_eval( + tensor_info.split("shape=", 1)[1].split("), ")[0] + ")" + ) + saved_dtype = tensor_info.split("dtype=")[1].split(")")[0] + + loaded_shape = tuple(graph.get_tensor_by_name(f"import/{saved_name}").shape) + loaded_dtype = graph.get_tensor_by_name(f"import/{saved_name}").dtype.name + + assert saved_shape == loaded_shape + assert saved_dtype == loaded_dtype + + +def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): + + single_instance_model = tf.keras.models.load_model( + min_single_instance_robot_model_path + "/best_model.h5", compile=False + ) + + model = SingleInstanceInferenceModel( + SingleInstanceInferenceLayer(keras_model=single_instance_model) + ) + + preds = model.predict(np.zeros((4, 160, 280, 3), dtype="uint8")) + + export_frozen_graph(model, preds, tmp_path) + + +def test_centroid_save(min_centroid_model_path, tmp_path): + + centroid_model = tf.keras.models.load_model( + min_centroid_model_path + "/best_model.h5", compile=False + ) + + centroid = CentroidCrop( + keras_model=centroid_model, crop_size=160, return_crops=False + ) + + model = CentroidInferenceModel(centroid) + + preds = model.predict(np.zeros((4, 384, 384, 1), dtype="uint8")) + + export_frozen_graph(model, preds, tmp_path) + + +def test_topdown_save( + min_centroid_model_path, min_centered_instance_model_path, min_labels_slp, tmp_path +): + + centroid_model = tf.keras.models.load_model( + min_centroid_model_path + "/best_model.h5", compile=False + ) + + top_down_model = tf.keras.models.load_model( + min_centered_instance_model_path + "/best_model.h5", compile=False + ) + + centroid = CentroidCrop(keras_model=centroid_model, crop_size=96) + + instance_peaks = FindInstancePeaks(keras_model=top_down_model) + + model = TopDownInferenceModel(centroid, instance_peaks) + + imgs = min_labels_slp.video[:4] + preds = model.predict(imgs) + + export_frozen_graph(model, preds, tmp_path) + + +def test_topdown_id_save( + min_centroid_model_path, min_topdown_multiclass_model_path, min_labels_slp, tmp_path +): + + centroid_model = tf.keras.models.load_model( + min_centroid_model_path + "/best_model.h5", compile=False + ) + + top_down_id_model = tf.keras.models.load_model( + min_topdown_multiclass_model_path + "/best_model.h5", compile=False + ) + + centroid = CentroidCrop(keras_model=centroid_model, crop_size=128) + + instance_peaks = TopDownMultiClassFindPeaks(keras_model=top_down_id_model) + + model = TopDownMultiClassInferenceModel(centroid, instance_peaks) + + imgs = min_labels_slp.video[:4] + preds = model.predict(imgs) + + export_frozen_graph(model, preds, tmp_path) + + +def test_single_instance_predictor_save(min_single_instance_robot_model_path, tmp_path): + + # directly initialize predictor + predictor = SingleInstancePredictor.from_trained_models( + min_single_instance_robot_model_path + ) + + predictor.export_model(save_path=tmp_path.as_posix()) + + # high level load to predictor + predictor = load_model(min_single_instance_robot_model_path) + + predictor.export_model(save_path=tmp_path.as_posix()) + + # high level export + + export_model(min_single_instance_robot_model_path, save_path=tmp_path.as_posix()) + + +def test_topdown_predictor_save( + min_centroid_model_path, min_centered_instance_model_path, tmp_path +): + + # directly initialize predictor + predictor = TopDownPredictor.from_trained_models( + centroid_model_path=min_centroid_model_path, + confmap_model_path=min_centered_instance_model_path, + ) + + predictor.export_model(save_path=tmp_path.as_posix()) + + # high level load to predictor + predictor = load_model([min_centroid_model_path, min_centered_instance_model_path]) + + predictor.export_model(save_path=tmp_path.as_posix()) + + # high level export + export_model( + [min_centroid_model_path, min_centered_instance_model_path], + save_path=tmp_path.as_posix(), + ) + + +def test_topdown_id_predictor_save( + min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path +): + + # directly initialize predictor + predictor = TopDownMultiClassPredictor.from_trained_models( + centroid_model_path=min_centroid_model_path, + confmap_model_path=min_topdown_multiclass_model_path, + ) + + predictor.export_model(save_path=tmp_path.as_posix()) + + # high level load to predictor + predictor = load_model([min_centroid_model_path, min_topdown_multiclass_model_path]) + + predictor.export_model(save_path=tmp_path.as_posix()) + + # high level export + export_model( + [min_centroid_model_path, min_topdown_multiclass_model_path], + save_path=tmp_path.as_posix(), + ) + + @pytest.mark.parametrize( "output_path,tracker_method", [("not_default", "flow"), (None, "simple")] ) @@ -807,6 +1059,7 @@ def test_retracking( labels: Labels = Labels.save(centered_pair_predictions, slp_path) # Create sleap-track command + cmd = f"{slp_path} --tracking.tracker {tracker_method} --frames 1-3 --cpu" cmd = ( f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 " "--cpu" @@ -852,6 +1105,7 @@ def test_sleap_track( labels: Labels = Labels.save(centered_pair_predictions, slp_path) # Create sleap-track command + args = f"{slp_path} --model {min_centered_instance_model_path} --frames 1-3 --cpu".split() args = ( f"{slp_path} --model {min_centroid_model_path} " f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" diff --git a/tests/nn/test_peak_finding.py b/tests/nn/test_peak_finding.py index 9e8f8c590..93beaa193 100644 --- a/tests/nn/test_peak_finding.py +++ b/tests/nn/test_peak_finding.py @@ -53,7 +53,9 @@ def test_find_global_peaks_rough(): points2 = points + 1 cms = tf.stack([cm, make_confmaps(points2, xv, yv, sigma=1.0)]) - peaks, peak_vals = find_global_peaks(cms, threshold=0.1, refinement=None) + peaks, peak_vals = find_global_peaks.__wrapped__( + cms, threshold=0.1, refinement=None + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peaks.shape == (2, 3, 2) assert peak_vals.shape == (2, 3) @@ -76,35 +78,35 @@ def test_find_global_peaks_integral(): points = tf.cast([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]], tf.float32) cm = make_confmaps(points, xv, yv, sigma=1.0) - peaks, peak_vals = find_global_peaks( + peaks, peak_vals = find_global_peaks.__wrapped__( tf.expand_dims(cm, axis=0), threshold=0.1, refinement="integral", integral_patch_size=5, - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peaks.shape == (1, 3, 2) assert peak_vals.shape == (1, 3) assert_allclose(peaks[0].numpy(), points.numpy(), atol=0.1) assert_allclose(peak_vals[0].numpy(), [1, 1, 1], atol=0.3) - peaks, peak_vals = find_global_peaks( + peaks, peak_vals = find_global_peaks.__wrapped__( tf.zeros((1, 8, 8, 3), dtype=tf.float32), threshold=0.1, refinement="integral", integral_patch_size=5, - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peaks.shape == (1, 3, 2) assert peak_vals.shape == (1, 3) assert tf.reduce_all(tf.math.is_nan(peaks)) assert_array_equal(peak_vals, [[0, 0, 0]]) - peaks, peak_vals = find_global_peaks( + peaks, peak_vals = find_global_peaks.__wrapped__( tf.stack([tf.zeros([12, 12, 3], dtype=tf.float32), cm], axis=0), threshold=0.1, refinement="integral", integral_patch_size=5, - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peaks.shape == (2, 3, 2) assert tf.reduce_all(tf.math.is_nan(peaks[0])) assert_allclose(peaks[1].numpy(), points.numpy(), atol=0.1) @@ -124,9 +126,9 @@ def test_find_global_peaks_local(): points = tf.cast([[1.6, 2.6], [3.6, 4.6], [5.6, 6.6]], tf.float32) cm = make_confmaps(points, xv, yv, sigma=1.0) - peaks, peak_vals = find_global_peaks( + peaks, peak_vals = find_global_peaks.__wrapped__( tf.expand_dims(cm, axis=0), threshold=0.1, refinement="local" - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peaks.shape == (1, 3, 2) assert peak_vals.shape == (1, 3) @@ -152,9 +154,12 @@ def test_find_local_peaks_rough(): [cms, make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.0)], axis=0 ) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( - cms, threshold=0.1, refinement=None - ) + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks.__wrapped__(cms, threshold=0.1, refinement=None) assert peak_points.shape == (9, 2) assert peak_vals.shape == (9,) @@ -179,9 +184,14 @@ def test_find_local_peaks_rough(): assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) assert_array_equal(peak_channel_inds, [0, 1, 0, 1, 1, 0, 1, 0, 1]) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks.__wrapped__( tf.zeros([1, 4, 4, 3], tf.float32), threshold=0.1, refinement=None - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peak_points.shape == (0, 2) assert peak_vals.shape == (0,) assert peak_sample_inds.shape == (0,) @@ -208,9 +218,14 @@ def test_find_local_peaks_integral(): [cms, make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.0)], axis=0 ) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks.__wrapped__( cms, threshold=0.1, refinement="integral", integral_patch_size=5 - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peak_points.shape == (9, 2) assert peak_vals.shape == (9,) @@ -240,9 +255,14 @@ def test_find_local_peaks_integral(): assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) assert_array_equal(peak_channel_inds, [0, 1, 0, 1, 1, 0, 1, 0, 1]) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks.__wrapped__( tf.zeros([1, 4, 4, 3], tf.float32), refinement="integral", integral_patch_size=5 - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peak_points.shape == (0, 2) assert peak_vals.shape == (0,) assert peak_sample_inds.shape == (0,) @@ -280,9 +300,14 @@ def test_find_local_peaks_local(): [cms, make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.0)], axis=0 ) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks.__wrapped__( cms, threshold=0.1, refinement="local" - ) + ) # Use __wrapped__ for codecov to catch coverage of unwrapped function assert peak_points.shape == (9, 2) assert peak_vals.shape == (9,)