Skip to content

Commit

Permalink
add check for regions
Browse files Browse the repository at this point in the history
  • Loading branch information
ykirchhoff committed Sep 5, 2024
1 parent ee1bd9b commit f2ff614
Showing 1 changed file with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
from nnunetv2.training.data_augmentation.custom_transforms.skeletonization import SkeletonTransform


class nnUNetTrainerSkeletonRecall(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.weight_srec = 1 # This is the default value, you can change it if you want
if self.label_manager.has_regions:
raise NotImplementedError("trainer not implemented for regions")

def _build_loss(self):
if self.label_manager.ignore_label is not None:
Expand Down

0 comments on commit f2ff614

Please sign in to comment.