Skip to content

Commit

Permalink
Minor updates to finetuning notebook (#753)
Browse files Browse the repository at this point in the history
* Minor updates to splitting additional arguments

* Debug kwargs

* Capture dataset arguments from loader_kwargs

* Minor fix to typo in workshops

* Update finetuning notebook

* Remove unnecessary imports
  • Loading branch information
anwai98 authored Oct 20, 2024
1 parent aebb23e commit 6766019
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 211 deletions.
52 changes: 37 additions & 15 deletions micro_sam/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def default_sam_dataset(
is_train: bool = True,
min_size: int = 25,
max_sampling_attempts: Optional[int] = None,
is_seg_dataset: Optional[bool] = None,
**kwargs,
) -> Dataset:
"""Create a PyTorch Dataset for training a SAM model.
Expand All @@ -413,11 +412,10 @@ def default_sam_dataset(
is_train: Whether this dataset is used for training or validation.
min_size: Minimal object size. Smaller objects will be filtered.
max_sampling_attempts: Number of sampling attempts to make from a dataset.
is_seg_dataset: Whether the dataset is built 'from torch_em.data import SegmentationDataset'
or 'from torch_em.data import ImageCollectionDataset'
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The dataset.
The segmentation dataset.
"""

# Set the data transformations.
Expand Down Expand Up @@ -446,18 +444,29 @@ def default_sam_dataset(
# Set a minimum number of samples per epoch.
if n_samples is None:
loader = torch_em.default_segmentation_loader(
raw_paths, raw_key, label_paths, label_key, batch_size=1,
patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
raw_paths=raw_paths,
raw_key=raw_key,
label_paths=label_paths,
label_key=label_key,
batch_size=1,
patch_shape=patch_shape,
ndim=2,
**kwargs
)
n_samples = max(len(loader), 100 if is_train else 5)

dataset = torch_em.default_segmentation_dataset(
raw_paths, raw_key, label_paths, label_key,
raw_paths=raw_paths,
raw_key=raw_key,
label_paths=label_paths,
label_key=label_key,
patch_shape=patch_shape,
raw_transform=raw_transform, label_transform=label_transform,
with_channels=with_channels, ndim=2,
sampler=sampler, n_samples=n_samples,
is_seg_dataset=is_seg_dataset,
raw_transform=raw_transform,
label_transform=label_transform,
with_channels=with_channels,
ndim=2,
sampler=sampler,
n_samples=n_samples,
**kwargs,
)

Expand All @@ -472,10 +481,23 @@ def default_sam_dataset(


def default_sam_loader(**kwargs) -> DataLoader:
ds_kwargs, loader_kwargs = split_kwargs(default_sam_dataset, **kwargs)
"""Create a PyTorch DataLoader for training a SAM model.
Args:
kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
"""
sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

# There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
# which the users can provide to get their desired segmentation dataset.
extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

ds = default_sam_dataset(**ds_kwargs)
loader = torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
return loader
return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
Expand Down Expand Up @@ -517,7 +539,7 @@ def train_sam_for_configuration(
model_type: Over-ride the default model type.
This can be used to use one of the micro_sam models as starting point
instead of a default sam model.
kwargs: Additional keyword parameterts that will be passed to `train_sam`.
kwargs: Additional keyword parameters that will be passed to `train_sam`.
"""
if configuration in CONFIGURATIONS:
train_kwargs = CONFIGURATIONS[configuration]
Expand Down
345 changes: 150 additions & 195 deletions notebooks/sam_finetuning.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion workshops/i2k_2024/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ $ git clone https://github.com/computational-cell-analytics/micro-sam
then go to this directory:

```bash
$ cd micro_sam/workshops/i2k_2024
$ cd micro-sam/workshops/i2k_2024
```

and download the precomputed embeddings:
Expand Down

0 comments on commit 6766019

Please sign in to comment.