Include training for SimCLR and auto-encoder (#17)
include simclr and ae
angerhang authored Nov 14, 2023
1 parent 1c2ce6a commit 1079160
Showing 18 changed files with 1,566 additions and 36 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -17,6 +17,8 @@ data/*
*.log
*.pt
*.mdl
*.png
*.ipynb

# Org-mode
.org-id-locations
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -8,5 +8,5 @@ repos:
rev: v1.2.3
hooks:
- id: flake8
-args: ['--ignore=E203,W503']
+args: ['--ignore=E203,W503,E731']
exclude: mtl.py
25 changes: 25 additions & 0 deletions conf/evaluation/flip_net_ae.yaml
@@ -0,0 +1,25 @@
## baselines
## Pre-trained flip-net, fine-tuning all layers

feat_hand_crafted: false
feat_random_cnn: false

# trained networks
model_path: /data/UKBB/SSL/day_sec_10k/logs/models
flip_net: false
flip_net_ft: true
flip_net_random_mlp: false
load_weights: true
freeze_weight: false
flip_net_path: "/data/UKBB/SSL/final_models/ae.mdl"
input_size: 300 # input size after resampling the raw data
subR: 1


# hyper-parameters
learning_rate: 0.0001
num_workers: 6
patience: 5
num_epoch: 200

evaluation_name: flip_net_ae
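
For context, a minimal sketch of how an evaluation config like this is typically consumed, assuming the project composes it with Hydra/OmegaConf as the conf/ layout suggests; the standalone load call is illustrative.

from omegaconf import OmegaConf

# Load the evaluation config directly; in a Hydra app it would be
# composed automatically from the conf/ tree.
cfg = OmegaConf.load("conf/evaluation/flip_net_ae.yaml")
assert cfg.flip_net_ft and not cfg.freeze_weight  # fine-tune all layers
print(cfg.flip_net_path, cfg.learning_rate, cfg.num_epoch)
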
26 changes: 26 additions & 0 deletions conf/evaluation/flip_net_simclr.yaml
@@ -0,0 +1,26 @@
## baselines
## Pre-trained flip-net, fine-tuning all layers

feat_hand_crafted: false
feat_random_cnn: false

# trained networks
model_path: /data/UKBB/SSL/day_sec_10k/logs/models
flip_net: false
flip_net_ft: true
flip_net_random_mlp: false
load_weights: true
freeze_weight: false
# flip_net_path: "/data/UKBB/SSL/final_models/simclr_first.mdl"
flip_net_path: "/data/UKBB/SSL/day_sec_1k/logs/models/simclr_200e.mdl"
input_size: 300 # input size after resampling the raw data
subR: 1


# hyper-parameters
learning_rate: 0.0001
num_workers: 6
patience: 5
num_epoch: 200

evaluation_name: flip_net_simclr_100k
2 changes: 1 addition & 1 deletion conf/model/resnet.yaml
@@ -6,4 +6,4 @@ resnet_version: 1
warm_up_step: 5
lr_scale: true
patience: 5
-
+is_ae: false
9 changes: 9 additions & 0 deletions conf/task/ae.yaml
@@ -0,0 +1,9 @@
rotation: false
switch_axis: false
time_reversal: false
permutation: false
scale: false
time_warped: false
positive_ratio: 0.5
task_name: 'ae'
multi: false
3 changes: 2 additions & 1 deletion conf/task/permutation.yaml
@@ -5,4 +5,5 @@ permutation: true
scale: false
time_warped: false
positive_ratio: 0.5
-task_name: 'permutation'
+task_name: 'permutation'
+multi: false
1 change: 1 addition & 0 deletions conf/task/scale.yaml
@@ -8,3 +8,4 @@ scale_sigma: 0.25
min_scale_sigma: 0.05
positive_ratio: 0.5
task_name: 'scale'
multi: false
9 changes: 9 additions & 0 deletions conf/task/simclr.yaml
@@ -0,0 +1,9 @@
rotation: false
switch_axis: false
time_reversal: false
permutation: false
scale: false
time_warped: false
positive_ratio: 0.5
task_name: 'simclr'
multi: false
3 changes: 2 additions & 1 deletion conf/task/time_reversal.yaml
@@ -5,4 +5,5 @@ permutation: false
scale: false
time_warped: false
positive_ratio: 0.5
-task_name: 'aot'
+task_name: 'aot'
+multi: false
39 changes: 21 additions & 18 deletions downstream_task_evaluation.py
@@ -15,7 +15,7 @@
import pathlib

# SSL net
-from sslearning.models.accNet import cnn1, SSLNET, Resnet
+from sslearning.models.accNet import cnn1, SSLNET, Resnet, EncoderMLP
from sslearning.scores import classification_scores, classification_report
import copy
from sklearn import preprocessing
@@ -282,7 +282,9 @@ def mlp_predict(model, data_loader, my_device, cfg):


def init_model(cfg, my_device):
-if cfg.model.resnet_version > 0:
+if cfg.model.is_ae:
+    model = EncoderMLP(cfg.data.output_size)
+elif cfg.model.resnet_version > 0:
model = Resnet(
output_size=cfg.data.output_size,
is_eva=True,
@@ -297,6 +299,7 @@ def init_model(cfg, my_device):
if cfg.multi_gpu:
model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

+print(model)
model.to(my_device, dtype=torch.float)
return model

@@ -305,11 +308,8 @@ def setup_model(cfg, my_device):
model = init_model(cfg, my_device)

if cfg.evaluation.load_weights:
-load_weights(
-    cfg.evaluation.flip_net_path,
-    model,
-    my_device
-)
+print("Loading weights from %s" % cfg.evaluation.flip_net_path)
+load_weights(cfg.evaluation.flip_net_path, model, my_device)
if cfg.evaluation.freeze_weight:
freeze_weights(model)
return model
@@ -503,11 +503,11 @@ def handcraft_features(xyz, sample_rate):
feats["std"] = np.std(m)
feats["range"] = np.ptp(m)
feats["mad"] = stats.median_abs_deviation(m)
-if feats['std'] > .01:
-    feats['skew'] = np.nan_to_num(stats.skew(m))
-    feats['kurt'] = np.nan_to_num(stats.kurtosis(m))
+if feats["std"] > 0.01:
+    feats["skew"] = np.nan_to_num(stats.skew(m))
+    feats["kurt"] = np.nan_to_num(stats.kurtosis(m))
else:
-    feats['skew'] = feats['kurt'] = 0
+    feats["skew"] = feats["kurt"] = 0
feats["enmomean"] = np.mean(np.abs(m - 1))

# Spectrum using Welch's method with 3s segment length
@@ -620,9 +620,7 @@ def get_data_with_subject_count(subject_count, X, y, pid):
return filter_X, filter_y, filter_pid


-def load_weights(
-    weight_path, model, my_device
-):
+def load_weights(weight_path, model, my_device):
# only need to change weights name when
# the model is trained in a distributed manner

@@ -632,12 +630,17 @@ def load_weights(
) # v2 has the right para names

# distributed pretraining can be inferred from the keys' module. prefix
-head = next(iter(pretrained_dict_v2)).split('.')[0]  # get head of first key
-if head == 'module':
+head = next(iter(pretrained_dict_v2)).split(".")[
+    0
+]  # get head of first key
+if head == "module":
    # remove module. prefix from dict keys
-    pretrained_dict_v2 = {k.partition('module.')[2]: pretrained_dict_v2[k] for k in pretrained_dict_v2.keys()}
+    pretrained_dict_v2 = {
+        k.partition("module.")[2]: pretrained_dict_v2[k]
+        for k in pretrained_dict_v2.keys()
+    }

-if hasattr(model, 'module'):
+if hasattr(model, "module"):
model_dict = model.module.state_dict()
multi_gpu_ft = True
else:
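
For context: the reformatted load_weights keeps its original logic, which strips the module. prefix that nn.DataParallel/DistributedDataParallel adds to every checkpoint key before loading into an unwrapped model. A minimal standalone sketch of that step (the checkpoint path is illustrative):

import torch

state = torch.load("/data/UKBB/SSL/final_models/ae.mdl", map_location="cpu")
head = next(iter(state)).split(".")[0]  # head of the first key
if head == "module":
    # drop the "module." prefix added by DataParallel wrapping
    state = {k.partition("module.")[2]: v for k, v in state.items()}
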
2 changes: 1 addition & 1 deletion mtl.py
@@ -576,7 +576,7 @@ def main_worker(rank, cfg):
task_losses.append(task_loss)

train_task_losses = np.array(task_losses)
-if epoch < cfg.model.warm_up_step:
+if epoch >= cfg.model.warm_up_step:
scheduler.step()

train_losses = np.array(train_losses)
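
The one-line mtl.py change fixes an inverted warm-up condition: previously the scheduler stepped only during warm-up; with >= it steps once warm-up is over. A self-contained sketch of the corrected behaviour (optimizer and scheduler choices are illustrative):

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

warm_up_step = 5  # stands in for cfg.model.warm_up_step
for epoch in range(20):
    optimizer.step()  # training-loop body elided
    if epoch >= warm_up_step:  # the corrected condition
        scheduler.step()  # decay the LR only after warm-up
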
112 changes: 112 additions & 0 deletions sslearning/data/data_loader.py
@@ -55,10 +55,48 @@ def subject_collate(batch):
return [data, aot_y, scale_y, permutation_y, time_w_y]


def simclr_subject_collate(batch):
x1 = [item[0] for item in batch]
x1 = torch.cat(x1)
x2 = [item[1] for item in batch]
x2 = torch.cat(x2)
return [x1, x2]


def worker_init_fn(worker_id):
np.random.seed(int(time.time()))


def augment_view(X, cfg):
new_X = []
X = X.numpy()

for i in range(len(X)):
current_x = X[i, :, :]

# choice = np.random.choice(
# 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
# )[0]
# current_x = my_transforms.flip(current_x, choice)
# choice = np.random.choice(
# 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
# )[0]
# current_x = my_transforms.permute(current_x, choice)
# choice = np.random.choice(
# 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
# )[0]
# current_x = my_transforms.time_warp(current_x, choice)
choice = np.random.choice(
2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
)[0]
current_x = my_transforms.rotation(current_x, choice)
new_X.append(current_x)

new_X = np.array(new_X)
new_X = torch.Tensor(new_X)
return new_X


def generate_labels(X, shuffle, cfg):
labels = []
new_X = []
@@ -314,6 +352,80 @@ def __getitem__(self, idx):
)


class SIMCLR_dataset:
def __init__(
self,
data_root,
file_list_path,
cfg,
transform=None,
shuffle=False,
is_epoch_data=False,
):
"""
Args:
data_root (string): directory containing all data files
file_list_path (string): file list
cfg (dict): config
shuffle (bool): whether to permute epochs within one subject
is_epoch_data (bool): whether each sample is one
second of data or 10 seconds of data
Returns:
data : transformed sample
labels (dict): labels for available transformations
"""
check_file_list(file_list_path, data_root, cfg)
file_list_df = pd.read_csv(file_list_path)
self.file_list = file_list_df["file_list"].to_list()
self.data_root = data_root
self.cfg = cfg
self.is_epoch_data = is_epoch_data
self.ratio2keep = cfg.data.ratio2keep
self.shuffle = shuffle
self.transform = transform

def __len__(self):
return len(self.file_list)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
print(idx)

# idx starts from zero
file_to_load = self.file_list[idx]
X = np.load(file_to_load, allow_pickle=True)

# to help select a percentage of data per subject
subject_data_count = int(len(X) * self.ratio2keep)
assert subject_data_count >= self.cfg.dataloader.num_sample_per_subject
if self.ratio2keep != 1:
X = X[:subject_data_count, :]

if self.is_epoch_data:
X = weighted_epoch_sample(
X, num_sample=self.cfg.dataloader.num_sample_per_subject
)
else:
X = weighted_sample(
X,
num_sample=self.cfg.dataloader.num_sample_per_subject,
epoch_len=self.cfg.dataloader.epoch_len,
sample_rate=self.cfg.dataloader.sample_rate,
is_weighted_sample=self.cfg.data.weighted_sample,
)

X = torch.from_numpy(X)
if self.transform:
X = self.transform(X)

X1 = augment_view(X, self.cfg)
X2 = augment_view(X, self.cfg)
return (X1, X2)


# Return:
# x: batch_size * feature size (125)
# y: batch_size * label_size (5)
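
Putting the new pieces together, a hedged sketch of how SIMCLR_dataset and simclr_subject_collate might be wired into a DataLoader; the config path, data paths, and batch size are illustrative:

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from sslearning.data.data_loader import SIMCLR_dataset, simclr_subject_collate

cfg = OmegaConf.load("conf/config.yaml")  # illustrative; normally composed by Hydra
dataset = SIMCLR_dataset("/data/ssl_epochs/", "file_list.csv", cfg)
loader = DataLoader(
    dataset,
    batch_size=4,  # subjects per batch; windows are concatenated by the collate fn
    collate_fn=simclr_subject_collate,
    shuffle=True,
)
for x1, x2 in loader:
    # x1, x2: two independently augmented views of the same windows
    break
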
20 changes: 10 additions & 10 deletions sslearning/data/data_transformation.py
@@ -1,7 +1,6 @@
import numpy as np
from transforms3d.axangles import axangle2mat # for rotation
from scipy.interpolate import CubicSpline # for warping
-import math

"""
This file implements a list of transforms for tri-axial raw-accelerometry
@@ -29,17 +28,18 @@ def rotation(sample, choice):
        choice (float): [0, 9] for each axis,
            we can do 4 rotations 0, 90, 180, 270
    """
-    if choice == 9:
-        return sample
+    if choice == 1:
+        # angle_choices = [1 / 4 * np.pi, 1 / 2 * np.pi, 3 / 4 * np.pi]
+        # angle = angle_choices[choice % 3]
+        # axis = axis_choices[math.floor(choice / 3)]

-    axis_choices = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
-    angle_choices = [1 / 4 * np.pi, 1 / 2 * np.pi, 3 / 4 * np.pi]
-    axis = axis_choices[math.floor(choice / 3)]
-    angle = angle_choices[choice % 3]
+        axes = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
+        sample = np.swapaxes(sample, 0, 1)
+        for i in range(3):
+            angle = np.random.uniform(low=-np.pi, high=np.pi)
+            sample = np.matmul(sample, axangle2mat(axes[i], angle))

-    sample = np.swapaxes(sample, 0, 1)
-    sample = np.matmul(sample, axangle2mat(axis, angle))
-    sample = np.swapaxes(sample, 0, 1)
+        sample = np.swapaxes(sample, 0, 1)
    return sample


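
The rewritten rotation transform drops the fixed 10-way choice (note the docstring above still describes the old fixed angles) in favour of a single on/off choice that, when active, rotates the signal about each of the three axes by an angle drawn uniformly from [-pi, pi). A standalone sketch of the new augmentation (the sample shape is illustrative):

import numpy as np
from transforms3d.axangles import axangle2mat

sample = np.random.randn(3, 300)  # (channels, time)

axes = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
out = np.swapaxes(sample, 0, 1)  # -> (time, channels)
for axis in axes:
    angle = np.random.uniform(low=-np.pi, high=np.pi)
    out = np.matmul(out, axangle2mat(axis, angle))  # rotate about this axis
out = np.swapaxes(out, 0, 1)  # back to (channels, time)
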
(The remaining 4 changed files are not shown.)