Skip to content

Commit

Permalink
Merge branch 'clinicadl_v2' into caps_dataset_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Dec 17, 2024
2 parents c407d9f + 034e4db commit 1f90670
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 708 deletions.
1 change: 0 additions & 1 deletion clinicadl/transforms/extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .base import Extraction, Sample
from .image import Image
from .patch import Patch
from .roi import ROI
from .slice import Slice
44 changes: 3 additions & 41 deletions clinicadl/transforms/extraction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Extraction(ClinicaDLConfig, ABC):
@property
@abstractmethod
def extract_method(self) -> ExtractionMethod:
"""The method to be used for the extraction process (ROI, Image, Patch, Slice)."""
"""The method to be used for the extraction process (Image, Patch, Slice)."""

@staticmethod
def load_image(input_img: Path) -> torch.Tensor:
Expand Down Expand Up @@ -180,7 +180,7 @@ def num_samples_per_image(self, image: torch.Tensor) -> int:
def _get_sample_description(
self, image_tensor: torch.Tensor, sample_index: int
) -> Any:
"""A description of the sample, e.g. slice position or ROI mask path."""
"""A description of the sample, e.g. slice position or patch index."""

@abstractmethod
def format_output(
Expand Down Expand Up @@ -217,44 +217,6 @@ def format_output(
'label' and 'description'.
"""

@staticmethod
def get_tio_image(
image: torch.Tensor,
label: Optional[Union[float, int, torch.Tensor]],
**masks: torch.Tensor,
) -> tio.Subject:
"""
Creates a TorchIO Subject from the image, the label and possibly
masks related to the image.
Parameters
----------
image : torch.Tensor
the image, as a Pytorch tensor.
label : Optional[Union[float, int, torch.Tensor]]
the label related to the image. Can be None if no label.
**masks : torch.Tensor
any mask related to the image and useful to compute transforms.
Returns
-------
tio.Subject
the TorchIO subject with the image and the label, accessible via
the attributes 'image' and 'label', as well as the masks, accessible
via their names.
"""
tio_image = tio.Subject(image=tio.ScalarImage(tensor=image))

if isinstance(label, torch.Tensor):
tio_image.add_image(tio.LabelMap(tensor=label), "label")
else:
setattr(tio_image, "label", label)

for name, mask in masks.items():
tio_image.add_image(tio.LabelMap(tensor=mask), name)

return tio_image

def extract_tio_sample(
self, tio_image: tio.Subject, sample_index: int
) -> tio.Subject:
Expand Down Expand Up @@ -303,7 +265,7 @@ def extract_tio_sample(

tio_sample.description = self._get_sample_description(
image.tensor, sample_index
) # e.g. roi or slice position
)

tio_sample.sample = tio_sample.image
delattr(tio_sample, "image")
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/transforms/extraction/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Image(Extraction):
@computed_field
@property
def extract_method(self) -> ExtractionMethod:
"""The method to be used for the extraction process (ROI, Image, Patch, Slice)."""
"""The method to be used for the extraction process (Image, Patch, Slice)."""
return ExtractionMethod.IMAGE

def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]:
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/transforms/extraction/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class Patch(Extraction):
@computed_field
@property
def extract_method(self) -> ExtractionMethod:
"""The method to be used for the extraction process (ROI, Image, Patch, Slice)."""
"""The method to be used for the extraction process (Image, Patch, Slice)."""
return ExtractionMethod.PATCH

@field_validator("patch_size", "stride", mode="after")
Expand Down
Loading

0 comments on commit 1f90670

Please sign in to comment.