-
Notifications
You must be signed in to change notification settings - Fork 0
/
augmentations.py
140 lines (118 loc) · 5.27 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import random
import tensorflow as tf
from tensorflow.keras import layers
# the implementation of these image augmentations follow the torchvision library:
# https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py
# https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional_tensor.py
# however these augmentations:
# -run on batches of images
# -run on gpu
# -can be part of a model
# randomly crops a region of each image and resizes it back to the input resolution
class RandomResizedCrop(layers.Layer):
    def __init__(self, scale, ratio, **kwargs):
        super().__init__(**kwargs)
        # (min area, max area) of the crop relative to the image, sampled uniformly
        self.scale = scale
        # aspect ratios are sampled log-uniformly, so store the log-endpoints
        self.log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1]))

    def call(self, images, training=True):
        if not training:
            return images
        shape = tf.shape(images)
        batch_size, height, width = shape[0], shape[1], shape[2]
        # one (area, aspect-ratio) pair per image, sampled independently
        areas = tf.random.uniform((batch_size,), self.scale[0], self.scale[1])
        aspect_ratios = tf.exp(
            tf.random.uniform((batch_size,), self.log_ratio[0], self.log_ratio[1])
        )
        # normalized crop sizes, clipped so each crop fits inside the image
        crop_heights = tf.clip_by_value(tf.sqrt(areas / aspect_ratios), 0, 1)
        crop_widths = tf.clip_by_value(tf.sqrt(areas * aspect_ratios), 0, 1)
        # top-left corners, sampled so each box stays inside the unit square
        tops = tf.random.uniform((batch_size,), 0, 1 - crop_heights)
        lefts = tf.random.uniform((batch_size,), 0, 1 - crop_widths)
        # (y1, x1, y2, x2) boxes in normalized coordinates, one per image
        boxes = tf.stack(
            [tops, lefts, tops + crop_heights, lefts + crop_widths],
            axis=1,
        )
        # crop each image by its own box and resize back to the original size
        return tf.image.crop_and_resize(
            images, boxes, tf.range(batch_size), (height, width)
        )
# randomly distorts the color distribution of images
class RandomColorJitter(layers.Layer):
    def __init__(self, brightness, contrast, saturation, hue, **kwargs):
        super().__init__(**kwargs)
        # maximal jitter strength for each color property
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        # the augmentations applied in call(), in a freshly shuffled order
        self.color_augmentations = [
            self.random_brightness,
            self.random_contrast,
            self.random_saturation,
            self.random_hue,
        ]
        # the tf.image.random_* equivalents are unsuitable here because they
        # apply one shared transformation to the whole batch

    def blend(self, images_1, images_2, ratios):
        # per-image linear interpolation (or extrapolation, for ratios outside
        # [0, 1]) between two batches, clipped to the valid pixel range
        return tf.clip_by_value(ratios * images_1 + (1.0 - ratios) * images_2, 0, 1)

    def random_brightness(self, images):
        # interpolate/extrapolate between each image and black
        batch_size = tf.shape(images)[0]
        ratios = tf.random.uniform(
            (batch_size, 1, 1, 1), 1 - self.brightness, 1 + self.brightness
        )
        return self.blend(images, 0, ratios)

    def random_contrast(self, images):
        # interpolate/extrapolate between each image and its mean intensity
        mean = tf.reduce_mean(
            tf.image.rgb_to_grayscale(images), axis=(1, 2), keepdims=True
        )
        batch_size = tf.shape(images)[0]
        ratios = tf.random.uniform(
            (batch_size, 1, 1, 1), 1 - self.contrast, 1 + self.contrast
        )
        return self.blend(images, mean, ratios)

    def random_saturation(self, images):
        # interpolate/extrapolate between each image and its grayscale version
        batch_size = tf.shape(images)[0]
        ratios = tf.random.uniform(
            (batch_size, 1, 1, 1), 1 - self.saturation, 1 + self.saturation
        )
        return self.blend(images, tf.image.rgb_to_grayscale(images), ratios)

    def random_hue(self, images):
        # shift hue in HSV space; the saturation and value channels receive a
        # zero shift (their bounds are both 0)
        hsv = tf.image.rgb_to_hsv(images)
        batch_size = tf.shape(hsv)[0]
        hsv += tf.random.uniform(
            (batch_size, 1, 1, 3), (-self.hue, 0, 0), (self.hue, 0, 0)
        )
        # wrap hue back into [0, 1]; tf.math.floormod(hsv, 1.0) would be the
        # obvious choice, but it introduces artifacts
        hsv = tf.where(hsv < 0.0, hsv + 1.0, hsv)
        hsv = tf.where(hsv > 1.0, hsv - 1.0, hsv)
        return tf.image.hsv_to_rgb(hsv)

    def call(self, images, training=True):
        if not training:
            return images
        # apply all four color augmentations in a random order
        # NOTE(review): random.sample runs at Python level, so inside a traced
        # tf.function the order is fixed at trace time — confirm this is intended
        for augment in random.sample(self.color_augmentations, 4):
            images = augment(images)
        return images