Skip to content

Commit

Permalink
Update docstrings to document recent functionalities (#664)
Browse files Browse the repository at this point in the history
Update docstrings
  • Loading branch information
anwai98 authored Jul 24, 2024
1 parent 83c8313 commit 8add576
Show file tree
Hide file tree
Showing 37 changed files with 190 additions and 111 deletions.
6 changes: 3 additions & 3 deletions finetuning/livecell_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):


def finetune_livecell(args):
"""Example code for finetuning SAM on LiveCELL"""
"""Example code for finetuning SAM on LIVECell"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -84,10 +84,10 @@ def finetune_livecell(args):


def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
Expand Down
2 changes: 2 additions & 0 deletions micro_sam/_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
The license type of the thrid party software project must be compatible with
the software license the micro-sam project is distributed under.
"""

from typing import Any, Dict, List

import numpy as np

import torch

try:
Expand Down
12 changes: 6 additions & 6 deletions micro_sam/bioimageio/model_export.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import os
import tempfile

from pathlib import Path
from typing import Optional, Union

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
import matplotlib.pyplot as plt
import xarray
import numpy as np
import matplotlib.pyplot as plt

import torch
import xarray

import bioimageio.core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.digest_spec import create_sample_for_model


from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box
from ..prompt_based_segmentation import _compute_logits_from_mask
from .predictor_adaptor import PredictorAdaptor


DEFAULTS = {
"authors": [
spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"),
Expand Down
3 changes: 2 additions & 1 deletion micro_sam/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import numpy as np
import pandas as pd
import imageio.v3 as imageio

from skimage.measure import label

from elf.evaluation import mean_segmentation_accuracy


Expand Down Expand Up @@ -88,6 +88,7 @@ def run_evaluation_for_iterative_prompting(
prediction_root: The folder with the iterative prompt-based instance segmentations to evaluate.
experiment_folder: The folder where all the experiment results are stored.
start_with_box_prompt: Whether to evaluate on experiments with iterative prompting starting with box.
overwrite_results: Whether to overwrite the results to update them with the new evaluation run.
Returns:
A DataFrame that contains the evaluation results.
Expand Down
1 change: 1 addition & 0 deletions micro_sam/evaluation/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import Dict, List, Optional


# TODO fully define the dict type
ExperimentSetting = Dict
ExperimentSettings = List[ExperimentSetting]
Expand Down
10 changes: 5 additions & 5 deletions micro_sam/evaluation/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@

import os
from glob import glob
from itertools import product
from tqdm import tqdm
from pathlib import Path
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio
import numpy as np
import pandas as pd
import imageio.v3 as imageio

from elf.evaluation import mean_segmentation_accuracy
from elf.io import open_file
from tqdm import tqdm
from elf.evaluation import mean_segmentation_accuracy

from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
from .. import util
from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation


def _get_range_of_search_values(input_vals, step):
Expand Down
3 changes: 2 additions & 1 deletion micro_sam/evaluation/livecell.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Inference and evaluation for the [LIVECell dataset](https://www.nature.com/articles/s41592-021-01249-6) and
the different cell lines contained in it.
"""

import os
import json
import argparse
Expand Down Expand Up @@ -422,7 +423,7 @@ def run_livecell_inference() -> None:


def run_livecell_evaluation() -> None:
"""Run LiveCELL evaluation with command line tool."""
"""Run LIVECell evaluation with command line tool."""
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input", required=True, help="Provide the data directory for LIVECell Dataset"
Expand Down
13 changes: 6 additions & 7 deletions micro_sam/evaluation/model_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@
"""

import os
from functools import partial
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Optional, Union

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import matplotlib.pyplot as plt
import skimage.draw as draw
from scipy.ndimage import binary_dilation
from skimage import exposure
from scipy.ndimage import binary_dilation
from skimage.segmentation import relabel_sequential, find_boundaries

from tqdm import tqdm
from typing import Optional, Union
import torch

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
Expand Down
5 changes: 3 additions & 2 deletions micro_sam/inference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from typing import Optional, Union

import torch
import numpy as np

import segment_anything.utils.amg as amg_utils
import torch

from segment_anything import SamPredictor
import segment_anything.utils.amg as amg_utils
from segment_anything.utils.transforms import ResizeLongestSide

from . import util
Expand Down
18 changes: 10 additions & 8 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,25 @@

import os
from abc import ABC
from collections import OrderedDict
from copy import deepcopy
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import segment_anything.utils.amg as amg_utils
import vigra

from nifty.tools import blocking
from segment_anything.predictor import SamPredictor

import numpy as np
from skimage.measure import regionprops

import torch
from torchvision.ops.boxes import batched_nms, box_area

from torch_em.model import UNETR
from torch_em.util.segmentation import watershed_from_center_and_boundary_distances

from nifty.tools import blocking

import segment_anything.utils.amg as amg_utils
from segment_anything.predictor import SamPredictor

from . import util
from ._vendored import batched_mask_to_box, mask_to_rle_pytorch

Expand Down Expand Up @@ -56,6 +57,7 @@ def mask_data_to_segmentation(
object in the output will be mapped to zero (the background value).
min_object_size: The minimal size of an object in pixels.
max_object_size: The maximal size of an object in pixels.
Returns:
The instance segmentation.
"""
Expand Down
1 change: 1 addition & 0 deletions micro_sam/models/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything.
Expand Down
4 changes: 4 additions & 0 deletions micro_sam/models/peft_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class LoRASurgery(nn.Module):
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
Args:
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing lora.
"""
def __init__(
self,
Expand Down
5 changes: 3 additions & 2 deletions micro_sam/models/sam_3d_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
import torch.nn as nn

from segment_anything.modeling.image_encoder import window_partition, window_unpartition
from segment_anything.modeling import Sam
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

from ..util import get_sam_model

Expand Down Expand Up @@ -42,6 +42,7 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool):
Args:
sam_model: The Sam model to be wrapped.
freeze_encoder: Whether to freeze the image encoder.
"""
super().__init__()
sam_model.image_encoder = ImageEncoderViT3DWrapper(
Expand All @@ -64,7 +65,7 @@ def forward(
Unlike original SAM this model only supports automatic segmentation and does not support prompts.
Args:
batched_input: A list over input images, each a dictionary with the following keys.L
batched_input: A list over input images, each a dictionary with the following keys.
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
'original_size': The original size of the image (HxW) before transformation.
multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder.
Expand Down
1 change: 0 additions & 1 deletion micro_sam/models/simple_sam_3d_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def get_simple_sam_3d_model(
model_type="vit_b",
checkpoint_path=None,
):

_, sam = get_sam_model(
model_type=model_type,
device=device,
Expand Down
9 changes: 6 additions & 3 deletions micro_sam/multi_dimensional_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,26 @@
from typing import Optional, Union, Tuple

import numpy as np

import nifty
import elf.tracking.tracking_utils as track_utils

import elf.segmentation as seg_utils
import elf.tracking.tracking_utils as track_utils

from segment_anything.predictor import SamPredictor
from scipy.ndimage import binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential

from segment_anything.predictor import SamPredictor

try:
from napari.utils import progress as tqdm
except ImportError:
from tqdm import tqdm

from . import util
from .instance_segmentation import AMGBase, mask_data_to_segmentation
from .prompt_based_segmentation import segment_from_mask
from .instance_segmentation import AMGBase, mask_data_to_segmentation

PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point")

Expand Down
5 changes: 3 additions & 2 deletions micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

import os
import pickle

from functools import partial
from glob import glob
from pathlib import Path
from functools import partial
from typing import Optional, Tuple, Union, List

import h5py
import numpy as np

import torch
import torch.nn as nn

from segment_anything.predictor import SamPredictor

try:
Expand Down
9 changes: 6 additions & 3 deletions micro_sam/prompt_based_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
from typing import Optional, Tuple

import numpy as np
import torch
from nifty.tools import blocking
from skimage.feature import peak_local_max
from skimage.filters import gaussian
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries
from scipy.ndimage import distance_transform_edt

import torch

from nifty.tools import blocking

from segment_anything.predictor import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

from . import util


Expand Down
4 changes: 2 additions & 2 deletions micro_sam/sam_annotator/_annotator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import napari
import numpy as np

from magicgui.widgets import Widget, Container, FunctionGui
from qtpy import QtWidgets
from magicgui.widgets import Widget, Container, FunctionGui

from . import _widgets as widgets
from . import util as vutil
from . import _widgets as widgets
from ._state import AnnotatorState


Expand Down
7 changes: 4 additions & 3 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
https://itnext.io/deciding-the-best-singleton-approach-in-python-65c61e90cdc4
"""

from dataclasses import dataclass, field
from functools import partial
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import zarr
import numpy as np
from qtpy.QtWidgets import QWidget

import torch.nn as nn
import zarr

import micro_sam.util as util
from micro_sam.instance_segmentation import AMGBase, get_decoder
from micro_sam.precompute_state import cache_amg_state, cache_is_state
from qtpy.QtWidgets import QWidget

from segment_anything import SamPredictor

Expand Down
Loading

0 comments on commit 8add576

Please sign in to comment.