Include training for SimCLR and auto-encoder #17

Merged: 2 commits, Nov 14, 2023
2 changes: 2 additions & 0 deletions .gitignore
@@ -17,6 +17,8 @@ data/*
*.log
*.pt
*.mdl
*.png
*.ipynb

# Org-mode
.org-id-locations
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -8,5 +8,5 @@ repos:
rev: v1.2.3
hooks:
- id: flake8
args: ['--ignore=E203,W503']
args: ['--ignore=E203,W503,E731']
exclude: mtl.py
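Note: E731 is flake8's "do not assign a lambda expression, use a def" check; adding it to the ignore list lets the codebase keep lambda assignments. A minimal illustration (the variable here is hypothetical, not from this repo):

```python
# Permitted once flake8 ignores E731:
normalize = lambda x: (x - x.mean()) / x.std()
```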
25 changes: 25 additions & 0 deletions conf/evaluation/flip_net_ae.yaml
@@ -0,0 +1,25 @@
## baselines
## Pre-trained flip-net, fine-tuning all layers

feat_hand_crafted: false
feat_random_cnn: false

# trained networks
model_path: /data/UKBB/SSL/day_sec_10k/logs/models
flip_net: false
flip_net_ft: true
flip_net_random_mlp: false
load_weights: true
freeze_weight: false
flip_net_path: "/data/UKBB/SSL/final_models/ae.mdl"
input_size: 300 # input size after resampling the raw data
subR: 1


# hyper-parameters
learning_rate: 0.0001
num_workers: 6
patience: 5
num_epoch: 200

evaluation_name: flip_net_ae
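Note: the conf/ layout (evaluation/, model/, task/ groups) suggests these YAML files are composed with Hydra. A minimal sketch of how an evaluation config like this might be consumed; the root config name and entry point are assumptions, only the key names come from the file above:

```python
# Hypothetical consumer of conf/evaluation/flip_net_ae.yaml, assuming Hydra
# with a root conf/config.yaml whose defaults select `evaluation: flip_net_ae`.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    print(cfg.evaluation.evaluation_name)  # flip_net_ae
    print(cfg.evaluation.flip_net_path)    # /data/UKBB/SSL/final_models/ae.mdl
    print(cfg.evaluation.learning_rate)    # 0.0001


if __name__ == "__main__":
    main()
```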
26 changes: 26 additions & 0 deletions conf/evaluation/flip_net_simclr.yaml
@@ -0,0 +1,26 @@
## baselines
## Pre-trained flip-net, fine-tuning all layers

feat_hand_crafted: false
feat_random_cnn: false

# trained networks
model_path: /data/UKBB/SSL/day_sec_10k/logs/models
flip_net: false
flip_net_ft: true
flip_net_random_mlp: false
load_weights: true
freeze_weight: false
# flip_net_path: "/data/UKBB/SSL/final_models/simclr_first.mdl"
flip_net_path: "/data/UKBB/SSL/day_sec_1k/logs/models/simclr_200e.mdl"
input_size: 300 # input size after resampling the raw data
subR: 1


# hyper-parameters
learning_rate: 0.0001
num_workers: 6
patience: 5
num_epoch: 200

evaluation_name: flip_net_simclr_100k
2 changes: 1 addition & 1 deletion conf/model/resnet.yaml
@@ -6,4 +6,4 @@ resnet_version: 1
warm_up_step: 5
lr_scale: true
patience: 5

is_ae: false
9 changes: 9 additions & 0 deletions conf/task/ae.yaml
@@ -0,0 +1,9 @@
rotation: false
switch_axis: false
time_reversal: false
permutation: false
scale: false
time_warped: false
positive_ratio: 0.5
task_name: 'ae'
multi: false
3 changes: 2 additions & 1 deletion conf/task/permutation.yaml
@@ -5,4 +5,5 @@ permutation: true
scale: false
time_warped: false
positive_ratio: 0.5
task_name: 'permutation'
task_name: 'permutation'
multi: false
1 change: 1 addition & 0 deletions conf/task/scale.yaml
@@ -8,3 +8,4 @@ scale_sigma: 0.25
min_scale_sigma: 0.05
positive_ratio: 0.5
task_name: 'scale'
multi: false
9 changes: 9 additions & 0 deletions conf/task/simclr.yaml
@@ -0,0 +1,9 @@
rotation: false
switch_axis: false
time_reversal: false
permutation: false
scale: false
time_warped: false
positive_ratio: 0.5
task_name: 'simclr'
multi: false
3 changes: 2 additions & 1 deletion conf/task/time_reversal.yaml
@@ -5,4 +5,5 @@ permutation: false
scale: false
time_warped: false
positive_ratio: 0.5
task_name: 'aot'
task_name: 'aot'
multi: false
39 changes: 21 additions & 18 deletions downstream_task_evaluation.py
@@ -15,7 +15,7 @@
import pathlib

# SSL net
from sslearning.models.accNet import cnn1, SSLNET, Resnet
from sslearning.models.accNet import cnn1, SSLNET, Resnet, EncoderMLP
from sslearning.scores import classification_scores, classification_report
import copy
from sklearn import preprocessing
@@ -282,7 +282,9 @@ def mlp_predict(model, data_loader, my_device, cfg):


def init_model(cfg, my_device):
if cfg.model.resnet_version > 0:
if cfg.model.is_ae:
model = EncoderMLP(cfg.data.output_size)
elif cfg.model.resnet_version > 0:
model = Resnet(
output_size=cfg.data.output_size,
is_eva=True,
@@ -297,6 +299,7 @@ def init_model(cfg, my_device):
if cfg.multi_gpu:
model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

print(model)
model.to(my_device, dtype=torch.float)
return model

@@ -305,11 +308,8 @@ def setup_model(cfg, my_device):
model = init_model(cfg, my_device)

if cfg.evaluation.load_weights:
load_weights(
cfg.evaluation.flip_net_path,
model,
my_device
)
print("Loading weights from %s" % cfg.evaluation.flip_net_path)
load_weights(cfg.evaluation.flip_net_path, model, my_device)
if cfg.evaluation.freeze_weight:
freeze_weights(model)
return model
@@ -503,11 +503,11 @@ def handcraft_features(xyz, sample_rate):
feats["std"] = np.std(m)
feats["range"] = np.ptp(m)
feats["mad"] = stats.median_abs_deviation(m)
if feats['std'] > .01:
feats['skew'] = np.nan_to_num(stats.skew(m))
feats['kurt'] = np.nan_to_num(stats.kurtosis(m))
if feats["std"] > 0.01:
feats["skew"] = np.nan_to_num(stats.skew(m))
feats["kurt"] = np.nan_to_num(stats.kurtosis(m))
else:
feats['skew'] = feats['kurt'] = 0
feats["skew"] = feats["kurt"] = 0
feats["enmomean"] = np.mean(np.abs(m - 1))

# Spectrum using Welch's method with 3s segment length
@@ -620,9 +620,7 @@ def get_data_with_subject_count(subject_count, X, y, pid):
return filter_X, filter_y, filter_pid


def load_weights(
weight_path, model, my_device
):
def load_weights(weight_path, model, my_device):
# only need to change weights name when
# the model is trained in a distributed manner

@@ -632,12 +630,17 @@ def load_weights(
) # v2 has the right para names

# distributed pretraining can be inferred from the keys' module. prefix
head = next(iter(pretrained_dict_v2)).split('.')[0] # get head of first key
if head == 'module':
head = next(iter(pretrained_dict_v2)).split(".")[
0
] # get head of first key
if head == "module":
# remove module. prefix from dict keys
pretrained_dict_v2 = {k.partition('module.')[2]: pretrained_dict_v2[k] for k in pretrained_dict_v2.keys()}
pretrained_dict_v2 = {
k.partition("module.")[2]: pretrained_dict_v2[k]
for k in pretrained_dict_v2.keys()
}

if hasattr(model, 'module'):
if hasattr(model, "module"):
model_dict = model.module.state_dict()
multi_gpu_ft = True
else:
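Note: setup_model calls freeze_weights when cfg.evaluation.freeze_weight is set, but that helper is outside this diff. A minimal sketch of what such a helper typically does for linear evaluation; the `classifier` attribute name is a hypothetical stand-in for whatever head the model exposes:

```python
import torch.nn as nn


def freeze_weights(model: nn.Module) -> None:
    # Sketch only: freeze the pre-trained encoder, keep the head trainable.
    # The head attribute name is assumed, not taken from this repo.
    for name, param in model.named_parameters():
        if not name.startswith("classifier"):
            param.requires_grad = False
```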
2 changes: 1 addition & 1 deletion mtl.py
@@ -576,7 +576,7 @@ def main_worker(rank, cfg):
task_losses.append(task_loss)

train_task_losses = np.array(task_losses)
if epoch < cfg.model.warm_up_step:
if epoch >= cfg.model.warm_up_step:
scheduler.step()

train_losses = np.array(train_losses)
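Note: this one-line fix inverts a warm-up condition. Before, the scheduler stepped only during the first warm_up_step epochs; after, the learning rate is held for warm_up_step epochs and decays from then on. A self-contained sketch of the corrected pattern (the model, optimizer, and scheduler choices here are stand-ins, not this repo's):

```python
import torch

model = torch.nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
warm_up_step = 5  # matches conf/model/resnet.yaml

for epoch in range(20):
    optimizer.step()  # stand-in for one epoch of training
    if epoch >= warm_up_step:  # the corrected condition
        scheduler.step()
```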
112 changes: 112 additions & 0 deletions sslearning/data/data_loader.py
@@ -55,10 +55,48 @@ def subject_collate(batch):
return [data, aot_y, scale_y, permutation_y, time_w_y]


def simclr_subject_collate(batch):
x1 = [item[0] for item in batch]
x1 = torch.cat(x1)
x2 = [item[1] for item in batch]
x2 = torch.cat(x2)
return [x1, x2]


def worker_init_fn(worker_id):
# offset by worker_id so workers spawned in the same second get distinct seeds
np.random.seed(int(time.time()) + worker_id)


def augment_view(X, cfg):
new_X = []
X = X.numpy()

for i in range(len(X)):
current_x = X[i, :, :]

# choice = np.random.choice(
# 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
# )[0]
# current_x = my_transforms.flip(current_x, choice)
# choice = np.random.choice(
# 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
# )[0]
# current_x = my_transforms.permute(current_x, choice)
# choice = np.random.choice(
# 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
# )[0]
# current_x = my_transforms.time_warp(current_x, choice)
choice = np.random.choice(
2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio]
)[0]
current_x = my_transforms.rotation(current_x, choice)
new_X.append(current_x)

new_X = np.array(new_X)
new_X = torch.Tensor(new_X)
return new_X


def generate_labels(X, shuffle, cfg):
labels = []
new_X = []
@@ -314,6 +352,80 @@ def __getitem__(self, idx):
)


class SIMCLR_dataset:
def __init__(
self,
data_root,
file_list_path,
cfg,
transform=None,
shuffle=False,
is_epoch_data=False,
):
"""
Args:
data_root (string): directory containing all data files
file_list_path (string): file list
cfg (dict): config
shuffle (bool): whether to permute epochs within one subject
is_epoch_data (bool): whether each sample is one
second of data or 10 seconds of data


Returns:
(X1, X2): two augmented views of the same sample
"""
check_file_list(file_list_path, data_root, cfg)
file_list_df = pd.read_csv(file_list_path)
self.file_list = file_list_df["file_list"].to_list()
self.data_root = data_root
self.cfg = cfg
self.is_epoch_data = is_epoch_data
self.ratio2keep = cfg.data.ratio2keep
self.shuffle = shuffle
self.transform = transform

def __len__(self):
return len(self.file_list)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

# idx starts from zero
file_to_load = self.file_list[idx]
X = np.load(file_to_load, allow_pickle=True)

# to help select a percentage of data per subject
subject_data_count = int(len(X) * self.ratio2keep)
assert subject_data_count >= self.cfg.dataloader.num_sample_per_subject
if self.ratio2keep != 1:
X = X[:subject_data_count, :]

if self.is_epoch_data:
X = weighted_epoch_sample(
X, num_sample=self.cfg.dataloader.num_sample_per_subject
)
else:
X = weighted_sample(
X,
num_sample=self.cfg.dataloader.num_sample_per_subject,
epoch_len=self.cfg.dataloader.epoch_len,
sample_rate=self.cfg.dataloader.sample_rate,
is_weighted_sample=self.cfg.data.weighted_sample,
)

X = torch.from_numpy(X)
if self.transform:
X = self.transform(X)

X1 = augment_view(X, self.cfg)
X2 = augment_view(X, self.cfg)
return (X1, X2)


# Return:
# x: batch_size * feature size (125)
# y: batch_size * label_size (5)
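Note: SIMCLR_dataset returns two independently augmented views per subject, and simclr_subject_collate concatenates the views across a batch; the contrastive objective itself is not part of this diff. A minimal sketch of the standard SimCLR NT-Xent loss these pieces would feed, where the loss code and the `encoder` wiring are assumptions rather than this repo's implementation:

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5):
    # z1, z2: (N, D) embeddings of the two augmented views of N samples.
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2]), dim=1)  # (2N, D), unit norm
    sim = z @ z.t() / temperature                # (2N, 2N) cosine similarities
    sim.fill_diagonal_(float("-inf"))            # exclude self-pairs
    # Row i's positive is row i + N (and vice versa); the rest are negatives.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)


# Hypothetical wiring with this PR's dataset and collate function:
# loader = DataLoader(simclr_dataset, batch_size=4, num_workers=6,
#                     collate_fn=simclr_subject_collate,
#                     worker_init_fn=worker_init_fn)
# for x1, x2 in loader:
#     loss = nt_xent(encoder(x1), encoder(x2))
```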
20 changes: 10 additions & 10 deletions sslearning/data/data_transformation.py
@@ -1,7 +1,6 @@
import numpy as np
from transforms3d.axangles import axangle2mat # for rotation
from scipy.interpolate import CubicSpline # for warping
import math

"""
This file implements a list of transforms for tri-axial raw-accelerometry
@@ -29,17 +28,18 @@ def rotation(sample, choice):
choice (float): [0, 9] for each axis,
we can do 4 rotations: 0, 90, 180, 270
"""
if choice == 9:
return sample
if choice == 1:
# angle_choices = [1 / 4 * np.pi, 1 / 2 * np.pi, 3 / 4 * np.pi]
# angle = angle_choices[choice % 3]
# axis = axis_choices[math.floor(choice / 3)]

axis_choices = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
angle_choices = [1 / 4 * np.pi, 1 / 2 * np.pi, 3 / 4 * np.pi]
axis = axis_choices[math.floor(choice / 3)]
angle = angle_choices[choice % 3]
axes = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
sample = np.swapaxes(sample, 0, 1)
for i in range(3):
angle = np.random.uniform(low=-np.pi, high=np.pi)
sample = np.matmul(sample, axangle2mat(axes[i], angle))

sample = np.swapaxes(sample, 0, 1)
sample = np.matmul(sample, axangle2mat(axis, angle))
sample = np.swapaxes(sample, 0, 1)
sample = np.swapaxes(sample, 0, 1)
return sample


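Note: the rewritten rotation drops the fixed axis/angle lookup; when choice == 1 it composes three random rotations, one about each coordinate axis, with angles drawn uniformly from [-π, π]. A quick sanity-check sketch of that composition using the same transforms3d helper (the (3, 300) sample shape is an assumption matching input_size: 300 in the evaluation configs above):

```python
import numpy as np
from transforms3d.axangles import axangle2mat

sample = np.random.randn(3, 300)  # assumed (axes, time) tri-axial window

x = np.swapaxes(sample, 0, 1)     # (time, 3): each row is one xyz reading
for axis in ([0, 0, 1], [0, 1, 0], [1, 0, 0]):
    angle = np.random.uniform(low=-np.pi, high=np.pi)
    x = np.matmul(x, axangle2mat(axis, angle))  # rotate about one axis
rotated = np.swapaxes(x, 0, 1)    # back to (3, time)

# A composition of rotations preserves each reading's magnitude:
assert np.allclose(np.linalg.norm(rotated, axis=0),
                   np.linalg.norm(sample, axis=0))
```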