-
Notifications
You must be signed in to change notification settings - Fork 3
/
get_transforms.py
104 lines (94 loc) · 4.04 KB
/
get_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from monai.transforms import (
AsDiscrete,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandFlipd,
RandCropByPosNegLabeld,
RandShiftIntensityd,
ScaleIntensityRanged,
ScaleIntensityRangePercentilesd,
Spacingd,
RandRotate90d,
RandGaussianNoised,
RandRotated,
SpatialPadd,
NormalizeIntensityd,
RandScaleIntensityd,
ScaleIntensityd,
ToTensord,
AddChanneld,
SpatialCropd,
CenterSpatialCropd,
Activationsd,
AsDiscreted,
Invertd,
)
def get_test_transforms(params):
    """Create the inference pre-processing and post-processing pipelines.

    ``params`` is not used by this function.

    Pre-processing (the first returned pipeline) acts on the ``"image"`` key:
    load it, move the channel axis first, reorient to RAS, then map the
    0.5th-99.5th intensity percentiles to [0, 1] per channel, clipping values
    outside that range.

    Post-processing (the second returned pipeline) acts on the ``"pred"`` key:
    apply a softmax, invert the pre-processing pipeline on ``"pred"`` using the
    transform information recorded on ``"image"``, then take the argmax to get
    discrete labels.

    Returns:
        Tuple ``(test_transforms, post_transforms)`` of ``Compose`` pipelines.
    """
    # Disabled alternatives kept for reference:
    #   CenterSpatialCropd(keys=["image"], roi_size=(160, 218, 182))
    #   NormalizeIntensityd(keys=["image"], channel_wise=True, nonzero=False)
    preprocessing_steps = [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        ScaleIntensityRangePercentilesd(
            keys=["image"],
            lower=0.5,
            upper=99.5,
            b_min=0,
            b_max=1,
            clip=True,
            channel_wise=True,
        ),
    ]
    test_transforms = Compose(preprocessing_steps)

    to_probabilities = Activationsd(keys="pred", sigmoid=False, softmax=True)
    # Undo the pre-processing on "pred", reusing the history stored for the
    # "image" key. Interpolation is deliberately NOT forced to nearest so the
    # inverted probabilities stay smooth; discretisation happens afterwards.
    # The result is converted to a PyTorch tensor after inversion.
    undo_preprocessing = Invertd(
        keys="pred",
        transform=test_transforms,
        orig_keys="image",
        nearest_interp=False,
        to_tensor=True,
    )
    to_labels = AsDiscreted(keys="pred", argmax=True)
    post_transforms = Compose([to_probabilities, undo_preprocessing, to_labels])

    return test_transforms, post_transforms
def get_trainval_transforms(params):
    """Build the MONAI training and validation pipelines for image/label pairs.

    Both pipelines load ``"image"`` and ``"label"``, move the channel axis
    first, reorient to RAS and map the 0.5th-99.5th ``"image"`` intensity
    percentiles to [0, 1] per channel (clipped).

    Training additionally:
      * pads both keys up to at least ``patch_size`` so every sampled patch
        has exactly that shape,
      * samples ``samples_per_case`` patches per case, balanced 1:1 between
        label-positive and label-negative centres,
      * applies random flips on each spatial axis, a random rotation,
        a random intensity shift and additive Gaussian noise.

    Args:
        params: Mapping read for ``params['patch_size']`` (the patch
            ``spatial_size``) and ``params['samples_per_case']`` (patches
            sampled per loaded case).

    Returns:
        Tuple ``(train_transforms, valid_transforms)`` of ``Compose`` pipelines.
    """
    import math  # only needed for the rotation range below

    patch_size = params['patch_size']

    # Disabled alternatives kept for reference:
    #   CenterSpatialCropd(keys=["image", "label"], roi_size=(160, 218, 182))
    #   NormalizeIntensityd(keys=["image"], channel_wise=True, nonzero=False)
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRangePercentilesd(
                keys=["image"],
                lower=0.5,
                upper=99.5,
                b_min=0,
                b_max=1,
                clip=True,
                channel_wise=True,
            ),
            # Without this pad, a volume smaller than patch_size in any
            # dimension makes the crop below either raise or return an
            # undersized patch that cannot be batched with the others.
            # It is a no-op for volumes already at least patch_size.
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size,
                pos=1,
                neg=1,
                num_samples=params['samples_per_case'],
                image_key=None,
                image_threshold=0,
            ),
            RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.5),
            # range_y is in radians: math.pi / 2 is the exact value of the
            # previously hand-typed 1.0/2*3.14159. Labels use nearest
            # interpolation so class ids are never blended.
            RandRotated(
                keys=["image", "label"],
                range_y=math.pi / 2,
                mode=["bilinear", "nearest"],
                prob=0.5,
            ),
            RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            RandGaussianNoised(keys=["image"], std=0.1, prob=0.5),
        ]
    )

    valid_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRangePercentilesd(
                keys=["image"],
                lower=0.5,
                upper=99.5,
                b_min=0,
                b_max=1,
                clip=True,
                channel_wise=True,
            ),
        ]
    )

    return train_transforms, valid_transforms