Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates to finetuning notebook #753

Merged
merged 7 commits on Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions micro_sam/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def default_sam_dataset(
is_train: bool = True,
min_size: int = 25,
max_sampling_attempts: Optional[int] = None,
is_seg_dataset: Optional[bool] = None,
**kwargs,
) -> Dataset:
"""Create a PyTorch Dataset for training a SAM model.
Expand All @@ -413,11 +412,10 @@ def default_sam_dataset(
is_train: Whether this dataset is used for training or validation.
min_size: Minimal object size. Smaller objects will be filtered.
max_sampling_attempts: Number of sampling attempts to make from a dataset.
is_seg_dataset: Whether the dataset is built 'from torch_em.data import SegmentationDataset'
or 'from torch_em.data import ImageCollectionDataset'
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

Returns:
The dataset.
The segmentation dataset.
"""

# Set the data transformations.
Expand Down Expand Up @@ -446,18 +444,29 @@ def default_sam_dataset(
# Set a minimum number of samples per epoch.
if n_samples is None:
loader = torch_em.default_segmentation_loader(
raw_paths, raw_key, label_paths, label_key, batch_size=1,
patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
raw_paths=raw_paths,
raw_key=raw_key,
label_paths=label_paths,
label_key=label_key,
batch_size=1,
patch_shape=patch_shape,
ndim=2,
**kwargs
)
n_samples = max(len(loader), 100 if is_train else 5)

dataset = torch_em.default_segmentation_dataset(
raw_paths, raw_key, label_paths, label_key,
raw_paths=raw_paths,
raw_key=raw_key,
label_paths=label_paths,
label_key=label_key,
patch_shape=patch_shape,
raw_transform=raw_transform, label_transform=label_transform,
with_channels=with_channels, ndim=2,
sampler=sampler, n_samples=n_samples,
is_seg_dataset=is_seg_dataset,
raw_transform=raw_transform,
label_transform=label_transform,
with_channels=with_channels,
ndim=2,
sampler=sampler,
n_samples=n_samples,
**kwargs,
)

Expand All @@ -472,10 +481,23 @@ def default_sam_dataset(


def default_sam_loader(**kwargs) -> DataLoader:
ds_kwargs, loader_kwargs = split_kwargs(default_sam_dataset, **kwargs)
"""Create a PyTorch DataLoader for training a SAM model.

Args:
kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

Returns:
The DataLoader.
"""
sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

# There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
# which the users can provide to get their desired segmentation dataset.
extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

ds = default_sam_dataset(**ds_kwargs)
loader = torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
return loader
return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
Expand Down Expand Up @@ -517,7 +539,7 @@ def train_sam_for_configuration(
model_type: Over-ride the default model type.
This can be used to use one of the micro_sam models as starting point
instead of a default sam model.
kwargs: Additional keyword parameterts that will be passed to `train_sam`.
kwargs: Additional keyword parameters that will be passed to `train_sam`.
"""
if configuration in CONFIGURATIONS:
train_kwargs = CONFIGURATIONS[configuration]
Expand Down
345 changes: 150 additions & 195 deletions notebooks/sam_finetuning.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion workshops/i2k_2024/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ $ git clone https://github.com/computational-cell-analytics/micro-sam
then go to this directory:

```bash
$ cd micro_sam/workshops/i2k_2024
$ cd micro-sam/workshops/i2k_2024
```

and download the precomputed embeddings:
Expand Down
Loading