Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RandAugment_T to pipelines #2154

Open
wants to merge 1 commit into
base: 0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmaction/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@
'PyAVDecodeMotionVector', 'Rename', 'Imgaug', 'UniformSampleFrames',
'PoseDecode', 'LoadKineticsPose', 'GeneratePoseTarget', 'PIMSInit',
'PIMSDecode', 'TorchvisionTrans', 'PytorchVideoTrans', 'PoseNormalize',
'FormatGCNInput', 'PaddingWithLoop', 'ArrayDecode', 'JointToBone'
'FormatGCNInput', 'PaddingWithLoop', 'ArrayDecode', 'JointToBone', 'RandAugment_T'
]
54 changes: 54 additions & 0 deletions mmaction/datasets/pipelines/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import random
import warnings
from collections.abc import Sequence
from PIL import Image

import cv2
import mmcv
import numpy as np
from mmcv.utils import digit_version
from randaugment_utils import Augment
from torch.nn.modules.utils import _pair

from ..builder import PIPELINES
Expand Down Expand Up @@ -266,6 +268,58 @@ def __repr__(self):
f'allow_imgpad={self.allow_imgpad})')
return repr_str

@PIPELINES.register_module()
class RandAugment_T(Augment):
    """Apply RandAugment whose magnitude changes linearly over the clip.

    ``n`` operations are drawn (with replacement) from the list returned by
    ``Augment.augment_list`` and applied one after another to every frame.
    When ``temp_degree`` is enabled, each operation receives a start and an
    end magnitude and the operation interpolates between them frame by frame;
    otherwise every frame receives the same magnitude.

    See paper "Learning Temporally Invariant and Localizable Features via
    Data Augmentation for Video Recognition", Taeoh Kim et al., 2020
    (https://arxiv.org/pdf/2008.05721.pdf) for details.

    Required keys are "imgs", modified keys are "imgs".

    Args:
        n (int): Number of augments to be applied sequentially. Default: 2.
        m (int): Magnitude of each augment, in the range [0, 30].
            Default: 7.
        temp_degree (bool): Whether to change the augment intensity
            temporally. Default: True.
        range (float): Highest relative change in magnitude between the
            first and the last frame, in [0, 1.0]. Default: 1.0.
    """

    def __init__(self, n=2, m=7, temp_degree=True, range=1.0):
        super().__init__()
        # Severity that maps onto the upper bound of every operation range.
        self.max_severity = 30
        self.temp_degree = temp_degree
        self.n = n
        self.m = m  # usually values in the range [5, 30] work best
        self.range = range
        # NOTE: this replaces the inherited bound method on the instance with
        # the list it returns; the attribute name is kept for compatibility.
        self.augment_list = self.augment_list()

    def _to_value(self, severity, minval, maxval):
        """Map ``severity`` in [0, max_severity] linearly to [minval, maxval].
        """
        return ((float(severity) / self.max_severity) *
                float(maxval - minval) + minval)

    def _sample_values(self, minval, maxval):
        """Build the magnitude list handed to a single operation.

        Returns a one-element list when the magnitude is constant over time
        and a ``[start, end]`` pair when it changes temporally.
        """
        if not self.temp_degree:
            return [self._to_value(self.m, minval, maxval)]

        # Half-width of the temporal change, at most ``0.5 * range * m``.
        tval = float(
            np.random.uniform(low=0.0, high=0.5 * self.range * self.m))
        # Clamp so the operations never receive a value outside their
        # documented range (they assert on it).
        weak = max(float(self.m) - tval, 0.0)
        strong = min(float(self.m) + tval, float(self.max_severity))
        # Pick at random whether the clip gets stronger or weaker over time.
        if random.random() > 0.5:
            severities = (weak, strong)
        else:
            severities = (strong, weak)
        return [self._to_value(s, minval, maxval) for s in severities]

    def __call__(self, results):
        """Augment ``results['imgs']`` and write the frames back.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        # The PIL based operations work on uint8 images.
        buffer = [
            Image.fromarray(img.astype('uint8'))
            for img in np.array(results['imgs'])
        ]

        ops = random.choices(self.augment_list, k=self.n)
        for op, minval, maxval in ops:
            buffer = op(buffer, self._sample_values(minval, maxval))

        # NOTE(review): frames are returned as one stacked int64 array, as in
        # the original implementation - confirm downstream transforms accept
        # this dtype (uint8 would be far more compact).
        results['imgs'] = np.array(
            [np.array(img, np.dtype('int64')) for img in buffer])

        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'n={self.n}, m={self.m}, '
                    f'temp_degree={self.temp_degree}, '
                    f'range={self.range})')
        return repr_str

@PIPELINES.register_module()
class Imgaug:
Expand Down
189 changes: 189 additions & 0 deletions mmaction/datasets/pipelines/randaugment_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from ..builder import PIPELINES
import numpy as np
import random
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw


def temporal_interpolate(v_list, t, n):
    """Return the magnitude for frame ``t`` of a clip.

    Args:
        v_list (list): Either ``[value]`` for a magnitude that is constant
            over time, or ``[start, end]`` for a magnitude that changes
            linearly from the first to the last frame.
        t (int): Index of the current frame.
        n (int): Index of the last frame, i.e. ``num_frames - 1``.

    Returns:
        The interpolated magnitude ``start + (end - start) * t / n``, or the
        single value when only one is given.

    Raises:
        NotImplementedError: If ``v_list`` does not hold one or two values.
    """
    if len(v_list) == 1:
        return v_list[0]
    if len(v_list) == 2:
        start, end = v_list
        if n == 0:
            # Single-frame clip: there is nothing to interpolate towards,
            # and dividing by ``n`` would fail.
            return start
        return start + (end - start) * t / n
    raise NotImplementedError('Invalid degree')


class Augment:
    """Collection of clip-level augmentation operations.

    Every operation takes ``imgs``, a list of PIL images (the frames of one
    clip), and ``v_list``, a list holding either one magnitude (constant over
    time) or a ``[start, end]`` pair that is interpolated linearly across the
    frames. Each operation returns a new list of PIL images.
    """

    def __init__(self):
        pass

    def __call__(self, buffer):
        """Apply the augmentation; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _maybe_negate(v_list):
        """Flip the sign of all magnitudes with probability 0.5."""
        if random.random() > 0.5:
            return [-v for v in v_list]
        return v_list

    @staticmethod
    def _frame_values(v_list, num_frames):
        """Return one interpolated magnitude per frame.

        The last-frame index is floored at 1 so that a single-frame clip
        receives the start magnitude instead of dividing by zero.
        """
        last = max(num_frames - 1, 1)
        return [
            temporal_interpolate(v_list, t, last) for t in range(num_frames)
        ]

    def _affine(self, imgs, v_list, make_coeffs):
        """Apply a per-frame affine transform.

        ``make_coeffs`` turns the magnitude of a frame into the 6-tuple of
        affine coefficients passed to ``Image.transform``.
        """
        values = self._frame_values(v_list, len(imgs))
        return [
            img.transform(img.size, PIL.Image.Transform.AFFINE,
                          make_coeffs(v)) for img, v in zip(imgs, values)
        ]

    def ShearX(self, imgs, v_list):  # [-0.3, 0.3]
        """Shear the frames horizontally, with a random sign."""
        for v in v_list:
            assert -0.3 <= v <= 0.3
        v_list = self._maybe_negate(v_list)
        return self._affine(imgs, v_list, lambda v: (1, v, 0, 0, 1, 0))

    def ShearY(self, imgs, v_list):  # [-0.3, 0.3]
        """Shear the frames vertically, with a random sign."""
        for v in v_list:
            assert -0.3 <= v <= 0.3
        v_list = self._maybe_negate(v_list)
        return self._affine(imgs, v_list, lambda v: (1, 0, 0, v, 1, 0))

    # [-150, 150] => percentage: [-0.45, 0.45]
    def TranslateX(self, imgs, v_list):
        """Shift the frames horizontally by a fraction of their width."""
        for v in v_list:
            assert -0.45 <= v <= 0.45
        v_list = self._maybe_negate(v_list)
        if not imgs:
            return []
        # ``imgs`` is a list, so the size has to be read from a frame; PIL
        # reports it as (width, height).
        width = imgs[0].size[0]
        v_list = [v * width for v in v_list]
        return self._affine(imgs, v_list, lambda v: (1, 0, v, 0, 1, 0))

    def TranslateXabs(self, imgs, v_list):
        """Shift the frames horizontally by an absolute offset (>= 0)."""
        for v in v_list:
            assert 0 <= v
        v_list = self._maybe_negate(v_list)
        return self._affine(imgs, v_list, lambda v: (1, 0, v, 0, 1, 0))

    # [-150, 150] => percentage: [-0.45, 0.45]
    def TranslateY(self, imgs, v_list):
        """Shift the frames vertically by a fraction of their height."""
        for v in v_list:
            assert -0.45 <= v <= 0.45
        v_list = self._maybe_negate(v_list)
        if not imgs:
            return []
        # PIL reports the size as (width, height).
        height = imgs[0].size[1]
        v_list = [v * height for v in v_list]
        return self._affine(imgs, v_list, lambda v: (1, 0, 0, 0, 1, v))

    def TranslateYabs(self, imgs, v_list):
        """Shift the frames vertically by an absolute offset (>= 0)."""
        for v in v_list:
            assert 0 <= v
        v_list = self._maybe_negate(v_list)
        return self._affine(imgs, v_list, lambda v: (1, 0, 0, 0, 1, v))

    def Rotate(self, imgs, v_list):  # [-30, 30]
        """Rotate the frames by an angle in degrees, with a random sign."""
        for v in v_list:
            assert -30 <= v <= 30
        v_list = self._maybe_negate(v_list)
        angles = self._frame_values(v_list, len(imgs))
        return [img.rotate(angle) for img, angle in zip(imgs, angles)]

    def AutoContrast(self, imgs, _):
        """Apply ``ImageOps.autocontrast`` to every frame."""
        return [PIL.ImageOps.autocontrast(img) for img in imgs]

    def Invert(self, imgs, _):
        """Apply ``ImageOps.invert`` to every frame."""
        return [PIL.ImageOps.invert(img) for img in imgs]

    def Equalize(self, imgs, _):
        """Apply ``ImageOps.equalize`` to every frame."""
        return [PIL.ImageOps.equalize(img) for img in imgs]

    def Flip(self, imgs, _):  # not from the paper
        """Mirror every frame with ``ImageOps.mirror``."""
        return [PIL.ImageOps.mirror(img) for img in imgs]

    def Solarize(self, imgs, v_list):  # [0, 256]
        """Solarize the frames using the magnitude as threshold."""
        for v in v_list:
            assert 0 <= v <= 256
        thresholds = self._frame_values(v_list, len(imgs))
        return [
            PIL.ImageOps.solarize(img, thr)
            for img, thr in zip(imgs, thresholds)
        ]

    def Posterize(self, imgs, v_list):  # [4, 8]
        """Posterize the frames, keeping at least one bit per channel."""
        v_list = [max(1, int(v)) for v in v_list]
        bits = self._frame_values(v_list, len(imgs))
        # The interpolated value may be fractional; posterize needs an int.
        return [
            PIL.ImageOps.posterize(img, int(b)) for img, b in zip(imgs, bits)
        ]

    def _enhance(self, enhancer, imgs, v_list):
        """Apply a ``PIL.ImageEnhance`` enhancer with per-frame factors."""
        for v in v_list:
            assert 0.1 <= v <= 1.9
        factors = self._frame_values(v_list, len(imgs))
        return [enhancer(img).enhance(f) for img, f in zip(imgs, factors)]

    def Contrast(self, imgs, v_list):  # [0.1,1.9]
        """Change the contrast of the frames."""
        return self._enhance(PIL.ImageEnhance.Contrast, imgs, v_list)

    def Color(self, imgs, v_list):  # [0.1,1.9]
        """Change the color balance of the frames."""
        return self._enhance(PIL.ImageEnhance.Color, imgs, v_list)

    def Brightness(self, imgs, v_list):  # [0.1,1.9]
        """Change the brightness of the frames."""
        return self._enhance(PIL.ImageEnhance.Brightness, imgs, v_list)

    def Sharpness(self, imgs, v_list):  # [0.1,1.9]
        """Change the sharpness of the frames."""
        return self._enhance(PIL.ImageEnhance.Sharpness, imgs, v_list)

    def Identity(self, imgs, _):
        """Return the frames unchanged."""
        return imgs

    def augment_list(self):
        """Return the available operations and their magnitude ranges.

        Returns:
            list[tuple]: ``(operation, min_value, max_value)`` triples.
        """
        # list of data augmentations and their ranges
        ops = [
            (self.Identity, 0, 1),
            (self.AutoContrast, 0, 1),
            (self.Equalize, 0, 1),
            (self.Invert, 0, 1),
            (self.Rotate, 0, 30),
            (self.Posterize, 0, 4),
            (self.Solarize, 0, 256),
            (self.Color, 0.1, 1.9),
            (self.Contrast, 0.1, 1.9),
            (self.Brightness, 0.1, 1.9),
            (self.Sharpness, 0.1, 1.9),
            (self.ShearX, 0., 0.3),
            (self.ShearY, 0., 0.3),
            (self.TranslateXabs, 0., 100),
            (self.TranslateYabs, 0., 100),
        ]

        return ops