diff --git a/.gitignore b/.gitignore index e00a302..fe3cf06 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,8 @@ data/* *.log *.pt *.mdl +*.png +*.ipynb # Org-mode .org-id-locations diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb9ece6..7cca3af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,5 +8,5 @@ repos: rev: v1.2.3 hooks: - id: flake8 - args: ['--ignore=E203,W503'] + args: ['--ignore=E203,W503,E731'] exclude: mtl.py diff --git a/conf/evaluation/flip_net_ae.yaml b/conf/evaluation/flip_net_ae.yaml new file mode 100644 index 0000000..094ccf4 --- /dev/null +++ b/conf/evaluation/flip_net_ae.yaml @@ -0,0 +1,25 @@ +## baselines +## Pre-trained flip-net fine-tuned fine tune all layers + +feat_hand_crafted: false +feat_random_cnn: false + +# trained networks +model_path: /data/UKBB/SSL/day_sec_10k/logs/models +flip_net: false +flip_net_ft: true +flip_net_random_mlp: false +load_weights: true +freeze_weight: false +flip_net_path: "/data/UKBB/SSL/final_models/ae.mdl" +input_size: 300 # input size after resampling the raw data +subR: 1 + + +# hyper-parameters +learning_rate: 0.0001 +num_workers: 6 +patience: 5 +num_epoch: 200 + +evaluation_name: flip_net_ae diff --git a/conf/evaluation/flip_net_simclr.yaml b/conf/evaluation/flip_net_simclr.yaml new file mode 100644 index 0000000..cd4cab2 --- /dev/null +++ b/conf/evaluation/flip_net_simclr.yaml @@ -0,0 +1,26 @@ +## baselines +## Pre-trained flip-net fine-tuned fine tune all layers + +feat_hand_crafted: false +feat_random_cnn: false + +# trained networks +model_path: /data/UKBB/SSL/day_sec_10k/logs/models +flip_net: false +flip_net_ft: true +flip_net_random_mlp: false +load_weights: true +freeze_weight: false +# flip_net_path: "/data/UKBB/SSL/final_models/simclr_first.mdl" +flip_net_path: "/data/UKBB/SSL/day_sec_1k/logs/models/simclr_200e.mdl" +input_size: 300 # input size after resampling the raw data +subR: 1 + + +# hyper-parameters +learning_rate: 0.0001 +num_workers: 6 +patience: 5 +num_epoch: 200 + +evaluation_name: flip_net_simclr_100k diff --git a/conf/model/resnet.yaml b/conf/model/resnet.yaml index 267f367..caba922 100644 --- a/conf/model/resnet.yaml +++ b/conf/model/resnet.yaml @@ -6,4 +6,4 @@ resnet_version: 1 warm_up_step: 5 lr_scale: true patience: 5 - +is_ae: false diff --git a/conf/task/ae.yaml b/conf/task/ae.yaml new file mode 100644 index 0000000..ef3bed2 --- /dev/null +++ b/conf/task/ae.yaml @@ -0,0 +1,9 @@ +rotation: false +switch_axis: false +time_reversal: false +permutation: false +scale: false +time_warped: false +positive_ratio: 0.5 +task_name: 'ae' +multi: false diff --git a/conf/task/permutation.yaml b/conf/task/permutation.yaml index 4d8d0ab..fcea5ea 100644 --- a/conf/task/permutation.yaml +++ b/conf/task/permutation.yaml @@ -5,4 +5,5 @@ permutation: true scale: false time_warped: false positive_ratio: 0.5 -task_name: 'permutation' \ No newline at end of file +task_name: 'permutation' +multi: false diff --git a/conf/task/scale.yaml b/conf/task/scale.yaml index 49b1a54..98624db 100644 --- a/conf/task/scale.yaml +++ b/conf/task/scale.yaml @@ -8,3 +8,4 @@ scale_sigma: 0.25 min_scale_sigma: 0.05 positive_ratio: 0.5 task_name: 'scale' +multi: false diff --git a/conf/task/simclr.yaml b/conf/task/simclr.yaml new file mode 100644 index 0000000..5d9cdc6 --- /dev/null +++ b/conf/task/simclr.yaml @@ -0,0 +1,9 @@ +rotation: false +switch_axis: false +time_reversal: false +permutation: false +scale: false +time_warped: false +positive_ratio: 0.5 +task_name: 'simclr' +multi: false 
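The new config groups above compose with the existing Hydra entry points (@hydra.main(config_path="conf", config_name="config")). The invocations below are a usage sketch only, since the top-level conf/config.yaml that declares the evaluation, model and task groups is not part of this diff; option names are taken from the files added here.

    # pre-train the two new pretext tasks added further down in this diff
    python train_ae.py task=ae model=resnet
    python train_simclr.py task=simclr model=resnet

    # fine-tune / evaluate the pre-trained encoders on a downstream task
    python downstream_task_evaluation.py evaluation=flip_net_ae model.is_ae=true
    python downstream_task_evaluation.py evaluation=flip_net_simclr

Setting model.is_ae=true routes init_model() in downstream_task_evaluation.py to the EncoderMLP head so that the autoencoder weights in flip_net_path can be loaded.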
diff --git a/conf/task/time_reversal.yaml b/conf/task/time_reversal.yaml index 4382683..dddbbc6 100644 --- a/conf/task/time_reversal.yaml +++ b/conf/task/time_reversal.yaml @@ -5,4 +5,5 @@ permutation: false scale: false time_warped: false positive_ratio: 0.5 -task_name: 'aot' \ No newline at end of file +task_name: 'aot' +multi: false diff --git a/downstream_task_evaluation.py b/downstream_task_evaluation.py index bcc55a1..b4b9e77 100644 --- a/downstream_task_evaluation.py +++ b/downstream_task_evaluation.py @@ -15,7 +15,7 @@ import pathlib # SSL net -from sslearning.models.accNet import cnn1, SSLNET, Resnet +from sslearning.models.accNet import cnn1, SSLNET, Resnet, EncoderMLP from sslearning.scores import classification_scores, classification_report import copy from sklearn import preprocessing @@ -282,7 +282,9 @@ def mlp_predict(model, data_loader, my_device, cfg): def init_model(cfg, my_device): - if cfg.model.resnet_version > 0: + if cfg.model.is_ae: + model = EncoderMLP(cfg.data.output_size) + elif cfg.model.resnet_version > 0: model = Resnet( output_size=cfg.data.output_size, is_eva=True, @@ -297,6 +299,7 @@ def init_model(cfg, my_device): if cfg.multi_gpu: model = nn.DataParallel(model, device_ids=cfg.gpu_ids) + print(model) model.to(my_device, dtype=torch.float) return model @@ -305,11 +308,8 @@ def setup_model(cfg, my_device): model = init_model(cfg, my_device) if cfg.evaluation.load_weights: - load_weights( - cfg.evaluation.flip_net_path, - model, - my_device - ) + print("Loading weights from %s" % cfg.evaluation.flip_net_path) + load_weights(cfg.evaluation.flip_net_path, model, my_device) if cfg.evaluation.freeze_weight: freeze_weights(model) return model @@ -503,11 +503,11 @@ def handcraft_features(xyz, sample_rate): feats["std"] = np.std(m) feats["range"] = np.ptp(m) feats["mad"] = stats.median_abs_deviation(m) - if feats['std'] > .01: - feats['skew'] = np.nan_to_num(stats.skew(m)) - feats['kurt'] = np.nan_to_num(stats.kurtosis(m)) + if feats["std"] > 0.01: + feats["skew"] = np.nan_to_num(stats.skew(m)) + feats["kurt"] = np.nan_to_num(stats.kurtosis(m)) else: - feats['skew'] = feats['kurt'] = 0 + feats["skew"] = feats["kurt"] = 0 feats["enmomean"] = np.mean(np.abs(m - 1)) # Spectrum using Welch's method with 3s segment length @@ -620,9 +620,7 @@ def get_data_with_subject_count(subject_count, X, y, pid): return filter_X, filter_y, filter_pid -def load_weights( - weight_path, model, my_device -): +def load_weights(weight_path, model, my_device): # only need to change weights name when # the model is trained in a distributed manner @@ -632,12 +630,17 @@ def load_weights( ) # v2 has the right para names # distributed pretraining can be inferred from the keys' module. prefix - head = next(iter(pretrained_dict_v2)).split('.')[0] # get head of first key - if head == 'module': + head = next(iter(pretrained_dict_v2)).split(".")[ + 0 + ] # get head of first key + if head == "module": # remove module. 
prefix from dict keys - pretrained_dict_v2 = {k.partition('module.')[2]: pretrained_dict_v2[k] for k in pretrained_dict_v2.keys()} + pretrained_dict_v2 = { + k.partition("module.")[2]: pretrained_dict_v2[k] + for k in pretrained_dict_v2.keys() + } - if hasattr(model, 'module'): + if hasattr(model, "module"): model_dict = model.module.state_dict() multi_gpu_ft = True else: diff --git a/mtl.py b/mtl.py index 7205493..42ca33f 100644 --- a/mtl.py +++ b/mtl.py @@ -576,7 +576,7 @@ def main_worker(rank, cfg): task_losses.append(task_loss) train_task_losses = np.array(task_losses) - if epoch < cfg.model.warm_up_step: + if epoch >= cfg.model.warm_up_step: scheduler.step() train_losses = np.array(train_losses) diff --git a/sslearning/data/data_loader.py b/sslearning/data/data_loader.py index fa3fa5a..506aee0 100644 --- a/sslearning/data/data_loader.py +++ b/sslearning/data/data_loader.py @@ -55,10 +55,48 @@ def subject_collate(batch): return [data, aot_y, scale_y, permutation_y, time_w_y] +def simclr_subject_collate(batch): + x1 = [item[0] for item in batch] + x1 = torch.cat(x1) + x2 = [item[1] for item in batch] + x2 = torch.cat(x2) + return [x1, x2] + + def worker_init_fn(worker_id): np.random.seed(int(time.time())) +def augment_view(X, cfg): + new_X = [] + X = X.numpy() + + for i in range(len(X)): + current_x = X[i, :, :] + + # choice = np.random.choice( + # 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio] + # )[0] + # current_x = my_transforms.flip(current_x, choice) + # choice = np.random.choice( + # 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio] + # )[0] + # current_x = my_transforms.permute(current_x, choice) + # choice = np.random.choice( + # 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio] + # )[0] + # current_x = my_transforms.time_warp(current_x, choice) + choice = np.random.choice( + 2, 1, p=[cfg.task.positive_ratio, 1 - cfg.task.positive_ratio] + )[0] + current_x = my_transforms.rotation(current_x, choice) + new_X.append(current_x) + + new_X = np.array(new_X) + new_X = torch.Tensor(new_X) + return new_X + + def generate_labels(X, shuffle, cfg): labels = [] new_X = [] @@ -314,6 +352,80 @@ def __getitem__(self, idx): ) +class SIMCLR_dataset: + def __init__( + self, + data_root, + file_list_path, + cfg, + transform=None, + shuffle=False, + is_epoch_data=False, + ): + """ + Args: + data_root (string): directory containing all data files + file_list_path (string): file list + cfg (dict): config + shuffle (bool): whether permute epoches within one subject + is_epoch_data (bool): whether each sample is one + second of data or 10 seconds of data + + + Returns: + data : transformed sample + labels (dict) : labels for avalaible transformations + """ + check_file_list(file_list_path, data_root, cfg) + file_list_df = pd.read_csv(file_list_path) + self.file_list = file_list_df["file_list"].to_list() + self.data_root = data_root + self.cfg = cfg + self.is_epoch_data = is_epoch_data + self.ratio2keep = cfg.data.ratio2keep + self.shuffle = shuffle + self.transform = transform + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + print(idx) + + # idx starts from zero + file_to_load = self.file_list[idx] + X = np.load(file_to_load, allow_pickle=True) + + # to help select a percentage of data per subject + subject_data_count = int(len(X) * self.ratio2keep) + assert subject_data_count >= self.cfg.dataloader.num_sample_per_subject + if self.ratio2keep != 1: + X = X[:subject_data_count, 
:] + + if self.is_epoch_data: + X = weighted_epoch_sample( + X, num_sample=self.cfg.dataloader.num_sample_per_subject + ) + else: + X = weighted_sample( + X, + num_sample=self.cfg.dataloader.num_sample_per_subject, + epoch_len=self.cfg.dataloader.epoch_len, + sample_rate=self.cfg.dataloader.sample_rate, + is_weighted_sample=self.cfg.data.weighted_sample, + ) + + X = torch.from_numpy(X) + if self.transform: + X = self.transform(X) + + X1 = augment_view(X, self.cfg) + X2 = augment_view(X, self.cfg) + return (X1, X2) + + # Return: # x: batch_size * feature size (125) # y: batch_size * label_size (5) diff --git a/sslearning/data/data_transformation.py b/sslearning/data/data_transformation.py index 75d6f6c..b834c2d 100644 --- a/sslearning/data/data_transformation.py +++ b/sslearning/data/data_transformation.py @@ -1,7 +1,6 @@ import numpy as np from transforms3d.axangles import axangle2mat # for rotation from scipy.interpolate import CubicSpline # for warping -import math """ This file implements a list of transforms for tri-axial raw-accelerometry @@ -29,17 +28,18 @@ def rotation(sample, choice): choice (float): [0, 9] for each axis, we can do 4 rotations 0, 90 180, 270 """ - if choice == 9: - return sample + if choice == 1: + # angle_choices = [1 / 4 * np.pi, 1 / 2 * np.pi, 3 / 4 * np.pi] + # angle = angle_choices[choice % 3] + # axis = axis_choices[math.floor(choice / 3)] - axis_choices = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] - angle_choices = [1 / 4 * np.pi, 1 / 2 * np.pi, 3 / 4 * np.pi] - axis = axis_choices[math.floor(choice / 3)] - angle = angle_choices[choice % 3] + axes = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] + sample = np.swapaxes(sample, 0, 1) + for i in range(3): + angle = np.random.uniform(low=-np.pi, high=np.pi) + sample = np.matmul(sample, axangle2mat(axes[i], angle)) - sample = np.swapaxes(sample, 0, 1) - sample = np.matmul(sample, axangle2mat(axis, angle)) - sample = np.swapaxes(sample, 0, 1) + sample = np.swapaxes(sample, 0, 1) return sample diff --git a/sslearning/models/accNet.py b/sslearning/models/accNet.py index 404433a..2f02fda 100644 --- a/sslearning/models/accNet.py +++ b/sslearning/models/accNet.py @@ -19,6 +19,19 @@ def forward(self, x): return y_pred +class ProjectionHead(nn.Module): + def __init__(self, input_size=1024, nn_size=256, encoding_size=100): + super(ProjectionHead, self).__init__() + self.linear1 = torch.nn.Linear(input_size, nn_size) + self.linear2 = torch.nn.Linear(nn_size, encoding_size) + + def forward(self, x): + x = self.linear1(x) + x = F.relu(x) + x = self.linear2(x) + return x + + class EvaClassifier(nn.Module): def __init__(self, input_size=1024, nn_size=512, output_size=2): super(EvaClassifier, self).__init__() @@ -655,6 +668,7 @@ def __init__( resnet_version=1, epoch_len=10, is_mtl=False, + is_simclr=False, ): super(Resnet, self).__init__() @@ -748,9 +762,9 @@ def __init__( self.time_w_h = Classifier( input_size=out_channels, output_size=output_size ) - else: - self.classifier = Classifier( - input_size=out_channels, output_size=output_size + elif is_simclr: + self.classifier = ProjectionHead( + input_size=out_channels, encoding_size=output_size ) weight_init(self) @@ -845,3 +859,82 @@ def weight_init(self, mode="fan_out", nonlinearity="relu"): elif isinstance(m, (nn.BatchNorm1d)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) + + +class Autoencoder(nn.Module): + def __init__(self): + super(Autoencoder, self).__init__() + self.encoder = nn.Sequential( + nn.Conv1d(3, 64, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(64, 64, 5, 
stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(64, 128, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(128, 128, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(128, 256, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(256, 256, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(256, 512, 3, stride=3, padding=1), + nn.ReLU(True), + ) + + self.decoder = nn.Sequential( + nn.ConvTranspose1d(512, 256, 3, stride=2, padding=1), + nn.ReLU(True), + nn.ConvTranspose1d(256, 256, 5, stride=2, padding=1), + nn.ReLU(True), + nn.ConvTranspose1d(256, 128, 5, stride=2, padding=1), + nn.ReLU(True), + nn.ConvTranspose1d(128, 128, 5, stride=2, padding=1), + nn.ReLU(True), + nn.ConvTranspose1d(128, 64, 7, stride=2, padding=1), + nn.ReLU(True), + nn.ConvTranspose1d(64, 64, 7, stride=3, padding=1), + nn.ReLU(True), + nn.ConvTranspose1d( + 64, 3, 5, stride=3, padding=3, output_padding=1 + ), + nn.ReLU(True), + ) + + weight_init(self) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +class EncoderMLP(nn.Module): + def __init__(self, output_size): + super(EncoderMLP, self).__init__() + self.encoder = nn.Sequential( + nn.Conv1d(3, 64, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(64, 64, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(64, 128, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(128, 128, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(128, 256, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(256, 256, 5, stride=2, padding=1), + nn.ReLU(True), + nn.Conv1d(256, 512, 3, stride=3, padding=1), + nn.ReLU(True), + ) + + self.classifier = EvaClassifier( + input_size=512, output_size=output_size + ) + + weight_init(self) + + def forward(self, x): + feats = self.encoder(x) + y = self.classifier(feats.view(x.shape[0], -1)) + return y diff --git a/sslearning/models/lars.py b/sslearning/models/lars.py new file mode 100644 index 0000000..60f5aee --- /dev/null +++ b/sslearning/models/lars.py @@ -0,0 +1,159 @@ +import torch +from torch.optim.optimizer import Optimizer, required +import re + +EETA_DEFAULT = 0.001 + + +class LARS(Optimizer): + """ + Layer-wise Adaptive Rate Scaling for large batch training. + Introduced by "Large Batch Training of Convolutional Networks" by Y. You, + I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888) + From https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/lars.py + """ + + def __init__( + self, + params, + lr=required, + momentum=0.9, + use_nesterov=False, + weight_decay=0.0, + exclude_from_weight_decay=None, + exclude_from_layer_adaptation=None, + classic_momentum=True, + eeta=EETA_DEFAULT, + ): + """Constructs a LARSOptimizer. + Args: + lr: A `float` for learning rate. + momentum: A `float` for momentum. + use_nesterov: A 'Boolean' for whether to use nesterov momentum. + weight_decay: A `float` for weight decay. + exclude_from_weight_decay: A list of `string` for variable screening, if + any of the string appears in a variable's name, the variable will be + excluded for computing weight decay. For example, one could specify + the list like ['batch_normalization', 'bias'] to exclude BN and bias + from weight decay. + exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but + for layer adaptation. If it is None, it will be defaulted the same as + exclude_from_weight_decay. + classic_momentum: A `boolean` for whether to use classic (or popular) + momentum. 
The learning rate is applied during momeuntum update in + classic momentum, but after momentum for popular momentum. + eeta: A `float` for scaling of learning rate when computing trust ratio. + name: The name for the scope. + """ + + self.epoch = 0 + defaults = dict( + lr=lr, + momentum=momentum, + use_nesterov=use_nesterov, + weight_decay=weight_decay, + exclude_from_weight_decay=exclude_from_weight_decay, + exclude_from_layer_adaptation=exclude_from_layer_adaptation, + classic_momentum=classic_momentum, + eeta=eeta, + ) + + super(LARS, self).__init__(params, defaults) + self.lr = lr + self.momentum = momentum + self.weight_decay = weight_decay + self.use_nesterov = use_nesterov + self.classic_momentum = classic_momentum + self.eeta = eeta + self.exclude_from_weight_decay = exclude_from_weight_decay + # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the + # arg is None. + if exclude_from_layer_adaptation: + self.exclude_from_layer_adaptation = exclude_from_layer_adaptation + else: + self.exclude_from_layer_adaptation = exclude_from_weight_decay + + def step(self, epoch=None, closure=None): + loss = None + if closure is not None: + loss = closure() + + if epoch is None: + epoch = self.epoch + self.epoch += 1 + + for group in self.param_groups: + # weight_decay = group["weight_decay"] + momentum = group["momentum"] + # eeta = group["eeta"] + lr = group["lr"] + + for p in group["params"]: + if p.grad is None: + continue + + param = p.data + grad = p.grad.data + + param_state = self.state[p] + + # TODO: get param names + # if self._use_weight_decay(param_name): + grad += self.weight_decay * param + + if self.classic_momentum: + trust_ratio = 1.0 + + # TODO: get param names + # if self._do_layer_adaptation(param_name): + w_norm = torch.norm(param) + g_norm = torch.norm(grad) + + device = g_norm.get_device() + trust_ratio = torch.where( + w_norm.ge(0), + torch.where( + g_norm.ge(0), + (self.eeta * w_norm / g_norm), + torch.Tensor([1.0]).to(device), + ), + torch.Tensor([1.0]).to(device), + ).item() + + scaled_lr = lr * trust_ratio + if "momentum_buffer" not in param_state: + next_v = param_state[ + "momentum_buffer" + ] = torch.zeros_like(p.data) + else: + next_v = param_state["momentum_buffer"] + + next_v.mul_(momentum).add_(scaled_lr, grad) + if self.use_nesterov: + update = (self.momentum * next_v) + (scaled_lr * grad) + else: + update = next_v + + p.data.add_(-update) + else: + raise NotImplementedError + + return loss + + def _use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True + + def _do_layer_adaptation(self, param_name): + """Whether to do layer-wise learning rate adaptation for `param_name`.""" + if self.exclude_from_layer_adaptation: + for r in self.exclude_from_layer_adaptation: + if re.search(r, param_name) is not None: + return False + return True diff --git a/train_ae.py b/train_ae.py new file mode 100644 index 0000000..6760725 --- /dev/null +++ b/train_ae.py @@ -0,0 +1,509 @@ +import os +import numpy as np +import hydra +from omegaconf import OmegaConf +from sslearning.data.data_loader import check_file_list +from torchvision import transforms +from torchsummary import summary + +# Model utils +from sslearning.models.accNet import Autoencoder +from sslearning.data.datautils import ( + RandomSwitchAxisTimeSeries, + 
RotationAxisTimeSeries, +) + +# Data utils +from sslearning.data.data_loader import ( + SSL_dataset, + subject_collate, + worker_init_fn, +) + +# Torch +import torch +from torch.utils.data import DataLoader +from torch.autograd import Variable +from torch import nn +import torch.optim as optim + +# Torch DDP +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist + +# Plotting +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime + +import signal +import time +import sys +from sslearning.pytorchtools import EarlyStopping + +import warnings + +cuda = torch.cuda.is_available() +now = datetime.now() + +"""" +Muti-tasking learning for self-supervised wearable models + +Our input data will be unlabelled. This script can assign pre-text +task labels to all the data. All the task labels +will be generated all the time but by specifying which tasks to use, +we can train on only a subset of these tasks. + +Whenever we introduce a new task, there are several things to change. +1. Dataloader and dataset classes to handle the data generation +2. In the train step, update the `compute_loss` and `get_task_loss` functions. +3. Update the inference step + +Example usage: + python mtl.py data=day_sec_test task=time_reversal augmentation=all + + # multi-processed distributed parallel (DPP) + python mtl.py data=day_sec_10k task=time_reversal + augmentation=all model=resnet + dataloader.num_sample_per_subject=1500 data.batch_subject_num=14 + dataloader=ten_sec model.lr_scale=True + runtime.distributed=True + +""" + + +################################ +# +# +# DDP functions +# +# +################################ +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def run_program(): + while True: + time.sleep(1) + print("a") + + +def signal_handler(signal, frame): + # your code here + cleanup() + sys.exit(0) + + +################################ +# +# +# helper functions +# +# +################################ +def set_seed(my_seed=0): + random_seed = my_seed + np.random.seed(random_seed) + torch.manual_seed(random_seed) + if cuda: + torch.cuda.manual_seed_all(random_seed) + + +def set_up_data4train( + my_X, aot_y, scale_y, permute_y, time_w_y, cfg, my_device, rank +): + aot_y, scale_y, permute_y, time_w_y = ( + Variable(aot_y), + Variable(scale_y), + Variable(permute_y), + Variable(time_w_y), + ) + my_X = Variable(my_X) + + if cfg.runtime.distributed: + my_X = my_X.to(rank, dtype=torch.float) + aot_y = aot_y.to(rank, dtype=torch.long) + scale_y = scale_y.to(rank, dtype=torch.long) + permute_y = permute_y.to(rank, dtype=torch.long) + time_w_y = time_w_y.to(rank, dtype=torch.long) + else: + my_X = my_X.to(my_device, dtype=torch.float) + aot_y = aot_y.to(my_device, dtype=torch.long) + scale_y = scale_y.to(my_device, dtype=torch.long) + permute_y = permute_y.to(my_device, dtype=torch.long) + time_w_y = time_w_y.to(my_device, dtype=torch.long) + return my_X, aot_y, scale_y, permute_y, time_w_y + + +def evaluate_model(model, data_loader, my_device, cfg, rank, criterion): + model.eval() + losses = [] + + for i, (my_X, aot_y, scale_y, permute_y, time_w_y) in enumerate( + data_loader + ): + with torch.no_grad(): + my_X, aot_y, scale_y, permute_y, time_w_y = set_up_data4train( + my_X, aot_y, scale_y, 
permute_y, time_w_y, cfg, my_device, rank + ) + logits = model(my_X) + + loss = criterion(logits, my_X) + + losses.append(loss.item()) + losses = np.array(losses) + return losses + + +def log_performance(current_loss, writer, mode, epoch, task_name): + # We want to have individual task performance + # and an average loss performance + # train_loss: numpy array + # mode (str): train or test + # overall = np.mean(np.mean(train_loss)) + # rotataion_loss = np.mean(train_loss[:, ROTATION_IDX]) + # task_loss: is only true for all task config + loss = np.mean(current_loss) + writer.add_scalar(mode + "/" + task_name + "_loss", loss, epoch) + + return loss + + +def set_linear_scale_lr(model, cfg): + """Allow for large minibatch + https://arxiv.org/abs/1706.02677 + 1. Linear scale learning rate in proportion to minibatch size + 2. Linear learning scheduler to allow for warm up for the first 5 epoches + """ + if cfg.model.lr_scale: + # reference batch size and learning rate + # lr: 0.0001 batch_size: 512 + reference_lr = 0.0001 + ref_batch_size = 512.0 + optimizer = optim.Adam( + model.parameters(), lr=reference_lr, amsgrad=True + ) + k = ( + 1.0 + * cfg.dataloader.num_sample_per_subject + * cfg.data.batch_subject_num + ) / ref_batch_size + scale_ratio = k ** (1.0 / 5.0) + # linear warm up to account for large batch size + lambda1 = lambda epoch: scale_ratio**epoch + else: + optimizer = optim.Adam( + model.parameters(), lr=cfg.model.learning_rate, amsgrad=True + ) + lambda1 = lambda epoch: 1.0**epoch + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) + return optimizer, scheduler + + +def compute_acc(logits, true_y): + pred_y = torch.argmax(logits, dim=1) + acc = torch.sum(pred_y == true_y) + acc = 1.0 * acc / (pred_y.size()[0]) + return acc + + +@hydra.main(config_path="conf", config_name="config") +def main(cfg): + n_gpus = torch.cuda.device_count() + signal.signal(signal.SIGINT, signal_handler) + + if cfg.runtime.distributed: + if n_gpus < 4: + print(f"Requires at least 4 GPUs to run, but got {n_gpus}.") + else: + cfg.runtime.multi_gpu = True + mp.spawn(main_worker, nprocs=n_gpus, args=(cfg,), join=True) + else: + main_worker(-1, cfg) + + +def main_worker(rank, cfg): + if cfg.runtime.distributed: + setup(rank, 4) + set_seed() + print(OmegaConf.to_yaml(cfg)) + + #################### + # Setting macros + ################### + num_epochs = cfg.runtime.num_epoch + lr = cfg.model.learning_rate # learning rate in SGD + batch_subject_num = cfg.data.batch_subject_num + GPU = cfg.runtime.gpu + multi_gpu = cfg.runtime.multi_gpu + gpu_ids = cfg.runtime.gpu_ids + is_epoch_data = cfg.runtime.is_epoch_data + # mixed_precision = cfg.model.mixed_precision + # useAugment = cfg.runtime.augment + + # data config + train_data_root = cfg.data.train_data_root + test_data_root = cfg.data.test_data_root + train_file_list_path = cfg.data.train_file_list + test_file_list_path = cfg.data.test_file_list + log_interval = cfg.data.log_interval + gpu_id2save = 0 + if cfg.runtime.distributed is False or ( + cfg.runtime.distributed and rank == gpu_id2save + ): + main_log_dir = cfg.data.log_path + dt_string = now.strftime("%d-%m-%Y_%H:%M:%S") + log_dir = os.path.join( + main_log_dir, + cfg.model.name + "_" + cfg.task.task_name + "_" + dt_string, + ) + writer = SummaryWriter(log_dir) + + switch_aug = cfg.augmentation.axis_switch + rotation_aug = cfg.augmentation.rotation + + check_file_list(train_file_list_path, train_data_root, cfg) + check_file_list(test_file_list_path, test_data_root, cfg) + + # y_path 
= cfg.data.y_path + main_log_dir = cfg.data.log_path + dt_string = now.strftime("%d-%m-%Y_%H:%M:%S") + log_dir = os.path.join(main_log_dir, cfg.model.name + "_" + dt_string) + general_model_path = os.path.join( + main_log_dir, + "models", + cfg.model.name + + "_len_" + + str(cfg.dataloader.epoch_len) + + "_sR_" + + str(cfg.data.ratio2keep) + + "_" + + dt_string + + "_" + + str(cfg.task.task_name), + ) + model_path = general_model_path + ".mdl" + num_workers = 8 + true_batch_size = batch_subject_num * cfg.dataloader.num_sample_per_subject + if true_batch_size > 2000 and cfg.model.lr_scale is False: + warnings.warn( + "Batch size > 2000 but learning rate not using linear scale. \n " + + "Model performance is going to be worse. Fix: run with " + + "cfg.model.lr_scale=True" + ) + + print("Model name: %s" % cfg.model.name) + print("Learning rate: %f" % lr) + print("Number of epoches: %d" % num_epochs) + print("GPU usage: %d" % GPU) + print("Subjects per batch: %d" % batch_subject_num) + print("True batch size : %d" % true_batch_size) + print("Tensor log dir: %s" % log_dir) + + #################### + # Model construction + ################### + if GPU >= -1: + my_device = "cuda:" + str(GPU) + elif multi_gpu is True and cfg.runtime.distributed is False: + my_device = "cuda:0" # use the first GPU as master + else: + my_device = "cpu" + + model = Autoencoder() + model = model.float() + print(model) + + pytorch_total_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + print("Num of paras %d " % pytorch_total_params) + # check if each process is having the same input + if cfg.runtime.distributed: + print("Training using DDP") + torch.cuda.set_device(rank) + model.cuda(rank) + ngpus_per_node = 4 + cfg.data.batch_subject_num = int( + cfg.data.batch_subject_num / ngpus_per_node + ) + num_workers = int(num_workers / ngpus_per_node) + model = DDP(model, device_ids=[rank], output_device=rank) + elif multi_gpu: + print("Training using multiple GPUS") + model = nn.DataParallel(model, device_ids=gpu_ids) + model.to(my_device) + else: + print("Training using device %s" % my_device) + model.to(my_device, dtype=torch.float) + model.to(my_device, dtype=torch.float) + + if GPU == -1 and multi_gpu is False: + summary( + model, + (3, cfg.dataloader.sample_rate * cfg.dataloader.epoch_len), + device="cpu", + ) + elif GPU == 0: + summary( + model, + (3, cfg.dataloader.sample_rate * cfg.dataloader.epoch_len), + device="cuda", + ) + + #################### + # Set up data + ################### + my_transform = None + if switch_aug and rotation_aug: + my_transform = transforms.Compose( + [RandomSwitchAxisTimeSeries(), RotationAxisTimeSeries()] + ) + elif switch_aug: + my_transform = RandomSwitchAxisTimeSeries() + elif rotation_aug: + my_transform = RotationAxisTimeSeries() + + train_dataset = SSL_dataset( + train_data_root, + train_file_list_path, + cfg, + is_epoch_data=is_epoch_data, + transform=my_transform, + ) + if cfg.runtime.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset + ) + else: + train_sampler = None + train_loader = DataLoader( + train_dataset, + batch_size=cfg.data.batch_subject_num, + collate_fn=subject_collate, + shuffle=(train_sampler is None), + sampler=train_sampler, + pin_memory=True, + worker_init_fn=worker_init_fn, + num_workers=num_workers, + ) + + test_dataset = SSL_dataset( + test_data_root, test_file_list_path, cfg, is_epoch_data=is_epoch_data + ) + test_loader = DataLoader( + test_dataset, + 
batch_size=cfg.data.batch_subject_num, + collate_fn=subject_collate, + shuffle=False, + pin_memory=True, + worker_init_fn=worker_init_fn, + num_workers=num_workers, + ) + + #################### + # Set up Training + ################### + criterion = nn.MSELoss() + optimizer, scheduler = set_linear_scale_lr(model, cfg) + total_step = len(train_loader) + + print("Start training") + # scaler = torch.cuda.amp.GradScaler() + early_stopping = EarlyStopping( + patience=cfg.model.patience, path=model_path, verbose=True + ) + + for epoch in range(num_epochs): + if cfg.runtime.distributed: + train_sampler.set_epoch(epoch) + + model.train() + train_losses = [] + + for i, (my_X, aot_y, scale_y, permute_y, time_w_y) in enumerate( + train_loader + ): + # the labels for all tasks are always generated + my_X, aot_y, scale_y, permute_y, time_w_y = set_up_data4train( + my_X, aot_y, scale_y, permute_y, time_w_y, cfg, my_device, rank + ) + + logit = model(my_X) + + loss = criterion(my_X, logit) + + loss.backward() + optimizer.step() + + optimizer.zero_grad() + + if i % log_interval == 0: + msg = ( + "Train: Epoch [{}/{}], Step [{}/{}], Loss: {:.4f} ".format( + epoch + 1, + num_epochs, + i, + total_step, + loss.item(), + ) + ) + print(msg) + train_losses.append(loss.cpu().detach().numpy()) + + if epoch >= cfg.model.warm_up_step: + scheduler.step() + + train_losses = np.array(train_losses) + test_losses = evaluate_model( + model, test_loader, my_device, cfg, rank, criterion + ) + + # logging + if cfg.runtime.distributed is False or ( + cfg.runtime.distributed and rank == gpu_id2save + ): + log_performance( + train_losses, + writer, + "train", + epoch, + cfg.task.task_name, + ) + test_loss = log_performance( + test_losses, + writer, + "test", + epoch, + cfg.task.task_name, + ) + + # save regularly + if cfg.runtime.distributed is False or ( + cfg.runtime.distributed and rank == gpu_id2save + ): + if epoch % 5 == 0 and cfg.data.data_name == "100k": + epoch_model_path = general_model_path + str(epoch) + ".mdl" + torch.save(model.state_dict(), epoch_model_path) + + early_stopping(test_loss, model) + + if early_stopping.early_stop: + print("Early stopping") + break + + if cfg.runtime.distributed: + cleanup() + + +if __name__ == "__main__": + main() diff --git a/train_simclr.py b/train_simclr.py new file mode 100644 index 0000000..c03744e --- /dev/null +++ b/train_simclr.py @@ -0,0 +1,580 @@ +import os +import numpy as np +import hydra +from omegaconf import OmegaConf +from sslearning.data.data_loader import check_file_list +from torchvision import transforms +from torchsummary import summary + +# Model utils +from sslearning.models.accNet import SSLNET, Resnet +from sslearning.models.lars import LARS +from sslearning.data.datautils import ( + RandomSwitchAxisTimeSeries, +) + +# Data utils +from sslearning.data.data_loader import ( + simclr_subject_collate, + worker_init_fn, + SIMCLR_dataset, +) + +# Torch +import torch +from torch.utils.data import DataLoader +from torch.autograd import Variable +from torch import nn + +# Torch DDP +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist + +# Plotting +from datetime import datetime + +import signal +import time +import sys + +import warnings + +cuda = torch.cuda.is_available() +now = datetime.now() + +"""" +Muti-tasking learning for self-supervised wearable models + +Our input data will be unlabelled. This script can assign pre-text +task labels to all the data. 
All the task labels +will be generated all the time but by specifying which tasks to use, +we can train on only a subset of these tasks. + +Whenever we introduce a new task, there are several things to change. +1. Dataloader and dataset classes to handle the data generation +2. In the train step, update the `compute_loss` and `get_task_loss` functions. +3. Update the inference step + +Example usage: + python mtl.py data=day_sec_test task=time_reversal augmentation=all + + # multi-processed distributed parallel (DPP) + python mtl.py data=day_sec_10k task=time_reversal + augmentation=all model=resnet + dataloader.num_sample_per_subject=1500 data.batch_subject_num=14 + dataloader=ten_sec model.lr_scale=True + runtime.distributed=True + +""" + + +################################ +# +# +# DDP functions +# +# +################################ +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def run_program(): + while True: + time.sleep(1) + print("a") + + +def signal_handler(signal, frame): + # your code here + cleanup() + sys.exit(0) + + +################################ +# +# +# helper functions +# +# +################################ +def set_seed(my_seed=0): + random_seed = my_seed + np.random.seed(random_seed) + torch.manual_seed(random_seed) + if cuda: + torch.cuda.manual_seed_all(random_seed) + + +def set_up_data4train(X1, X2, cfg, my_device, rank): + X1, X2 = ( + Variable(X1), + Variable(X2), + ) + + if cfg.runtime.distributed: + X1 = X1.to(rank, dtype=torch.float) + X2 = X2.to(rank, dtype=torch.float) + else: + X1 = X1.to(my_device, dtype=torch.float) + X2 = X2.to(my_device, dtype=torch.float) + return X1, X2 + + +def evaluate_model(model, data_loader, cfg, my_device, rank, my_criterion): + model.eval() + losses = [] + + for i, (X1, X2) in enumerate(data_loader): + with torch.no_grad(): + ( + X1, + X2, + ) = set_up_data4train(X1, X2, cfg, my_device, rank) + + # obtain two views of the same data + h1 = model(X1) + h2 = model(X2) + loss = my_criterion(h1, h2) + losses.append(loss.item()) + + losses = np.array(losses) + + return (losses,) + + +def log_performance(current_loss, writer, mode, epoch, task_name): + # We want to have individual task performance + # and an average loss performance + # train_loss: numpy array + # mode (str): train or test + # overall = np.mean(np.mean(train_loss)) + # rotataion_loss = np.mean(train_loss[:, ROTATION_IDX]) + # task_loss: is only true for all task config + loss = np.mean(current_loss) + + writer.add_scalar(mode + "/" + task_name + "_loss", loss, epoch) + + return loss + + +def load_optimizer(opt, batch_size, weight_decay, training_epoch, model): + + scheduler = None + if opt == "Adam": + optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) # TODO: LARS + elif opt == "LARS": + # optimized using LARS with linear learning rate scaling + # (i.e. LearningRate = 0.3 × BatchSize/256) and weight decay of 10−6. 
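+        # NOTE: the scaling actually applied below is 0.1 * batch_size / 256
+        # (the comment above quotes the 0.3 coefficient from the SimCLR paper);
+        # weight_decay is whatever the caller passes in (10e-6 in main_worker).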
+ learning_rate = 0.1 * batch_size / 256 + optimizer = LARS( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + exclude_from_weight_decay=["batch_normalization", "bias"], + ) + + # "decay the learning rate with the cosine decay schedule without restarts" + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, training_epoch, eta_min=0, last_epoch=-1 + ) + else: + raise NotImplementedError + + return optimizer, scheduler + + +class GatherLayer(torch.autograd.Function): + """Gather tensors from all process, supporting backward propagation.""" + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + output = [ + torch.zeros_like(input) for _ in range(dist.get_world_size()) + ] + dist.all_gather(output, input) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + (input,) = ctx.saved_tensors + grad_out = torch.zeros_like(input) + grad_out[:] = grads[dist.get_rank()] + return grad_out + + +class NT_Xent(nn.Module): + def __init__(self, batch_size, temperature, world_size): + super(NT_Xent, self).__init__() + self.batch_size = batch_size + self.temperature = temperature + self.world_size = world_size + + self.mask = self.mask_correlated_samples(batch_size, world_size) + self.criterion = nn.CrossEntropyLoss(reduction="sum") + self.similarity_f = nn.CosineSimilarity(dim=2) + + def mask_correlated_samples(self, batch_size, world_size): + N = 2 * batch_size * world_size + mask = torch.ones((N, N), dtype=bool) + mask = mask.fill_diagonal_(0) + for i in range(batch_size * world_size): + mask[i, batch_size * world_size + i] = 0 + mask[batch_size * world_size + i, i] = 0 + return mask + + def forward(self, z_i, z_j): + """ + We do not sample negative examples explicitly. + Instead, given a positive pair, similar to (Chen et al., 2017), + we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. 
+ """ + N = 2 * self.batch_size * self.world_size + + z = torch.cat((z_i, z_j), dim=0) + if self.world_size > 1: + z = torch.cat(GatherLayer.apply(z), dim=0) + + sim = ( + self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) + / self.temperature + ) + + sim_i_j = torch.diag(sim, self.batch_size * self.world_size) + sim_j_i = torch.diag(sim, -self.batch_size * self.world_size) + + # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN + positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) + negative_samples = sim[self.mask].reshape(N, -1) + + labels = torch.zeros(N).to(positive_samples.device).long() + logits = torch.cat((positive_samples, negative_samples), dim=1) + loss = self.criterion(logits, labels) + loss /= N + return loss + + +# contrastive loss +# Taken from https://medium.com/the-owl/simclr-in-pytorch-5f290cb11dd7 +class SimCLR_Loss(nn.Module): + def __init__(self, batch_size, temperature): + super().__init__() + self.batch_size = batch_size + self.temperature = temperature + + self.mask = self.mask_correlated_samples(batch_size) + self.criterion = nn.CrossEntropyLoss(reduction="sum") + self.similarity_f = nn.CosineSimilarity(dim=2) + + def mask_correlated_samples(self, batch_size): + N = 2 * batch_size + mask = torch.ones((N, N), dtype=bool) + mask = mask.fill_diagonal_(0) + + for i in range(batch_size): + mask[i, batch_size + i] = 0 + mask[batch_size + i, i] = 0 + return mask + + def forward(self, z_i, z_j): + N = 2 * self.batch_size + + z = torch.cat((z_i, z_j), dim=0) + + sim = ( + self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) + / self.temperature + ) + + sim_i_j = torch.diag(sim, self.batch_size) + sim_j_i = torch.diag(sim, -self.batch_size) + + # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN + positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) + negative_samples = sim[self.mask].reshape(N, -1) + + # SIMCLR + labels = ( + torch.from_numpy(np.array([0] * N)) + .reshape(-1) + .to(positive_samples.device) + .long() + ) # .float() + + logits = torch.cat((positive_samples, negative_samples), dim=1) + loss = self.criterion(logits, labels) + loss /= N + + return loss + + +@hydra.main(config_path="conf", config_name="config") +def main(cfg): + n_gpus = torch.cuda.device_count() + signal.signal(signal.SIGINT, signal_handler) + + if cfg.runtime.distributed: + if n_gpus < 4: + print(f"Requires at least 4 GPUs to run, but got {n_gpus}.") + else: + cfg.runtime.multi_gpu = True + mp.spawn(main_worker, nprocs=n_gpus, args=(cfg,), join=True) + else: + main_worker(-1, cfg) + + +def main_worker(rank, cfg): + if cfg.runtime.distributed: + setup(rank, 4) + set_seed() + print(OmegaConf.to_yaml(cfg)) + + #################### + # Setting macros + ################### + num_epochs = cfg.runtime.num_epoch + lr = cfg.model.learning_rate # learning rate in SGD + batch_subject_num = cfg.data.batch_subject_num + GPU = cfg.runtime.gpu + multi_gpu = cfg.runtime.multi_gpu + gpu_ids = cfg.runtime.gpu_ids + is_epoch_data = cfg.runtime.is_epoch_data + # mixed_precision = cfg.model.mixed_precision + # useAugment = cfg.runtime.augment + + # data config + train_data_root = cfg.data.train_data_root + test_data_root = cfg.data.test_data_root + train_file_list_path = cfg.data.train_file_list + test_file_list_path = cfg.data.test_file_list + log_interval = cfg.data.log_interval + gpu_id2save = 0 + if cfg.runtime.distributed is False or ( + cfg.runtime.distributed and rank == 
gpu_id2save + ): + main_log_dir = cfg.data.log_path + dt_string = now.strftime("%d-%m-%Y_%H:%M:%S") + log_dir = os.path.join( + main_log_dir, + cfg.model.name + "_" + cfg.task.task_name + "_" + dt_string, + ) + # writer = SummaryWriter(log_dir) + + check_file_list(train_file_list_path, train_data_root, cfg) + check_file_list(test_file_list_path, test_data_root, cfg) + + # y_path = cfg.data.y_path + main_log_dir = cfg.data.log_path + dt_string = now.strftime("%d-%m-%Y_%H:%M:%S") + log_dir = os.path.join(main_log_dir, cfg.model.name + "_" + dt_string) + general_model_path = os.path.join( + main_log_dir, + "models", + cfg.model.name + + "_len_" + + str(cfg.dataloader.epoch_len) + + "_sR_" + + str(cfg.data.ratio2keep) + + "_" + + dt_string + + "_", + ) + model_path = general_model_path + ".mdl" + num_workers = 8 + true_batch_size = batch_subject_num * cfg.dataloader.num_sample_per_subject + if true_batch_size > 2000 and cfg.model.lr_scale is False: + warnings.warn( + "Batch size > 2000 but learning rate not using linear scale. \n " + + "Model performance is going to be worse. Fix: run with " + + "cfg.model.lr_scale=True" + ) + + print("Model name: %s" % cfg.model.name) + print("Learning rate: %f" % lr) + print("Number of epoches: %d" % num_epochs) + print("GPU usage: %d" % GPU) + print("Subjects per batch: %d" % batch_subject_num) + print("True batch size : %d" % true_batch_size) + print("Tensor log dir: %s" % log_dir) + + #################### + # Model construction + ################### + if GPU >= -1: + my_device = "cuda:" + str(GPU) + elif multi_gpu is True and cfg.runtime.distributed is False: + my_device = "cuda:0" # use the first GPU as master + else: + my_device = "cpu" + + if cfg.task.task_name == "simclr": + z_size = 64 + model = Resnet( + output_size=z_size, + resnet_version=cfg.model.resnet_version, + epoch_len=cfg.dataloader.epoch_len, + is_simclr=True, + ) + criterion = NT_Xent( + batch_size=true_batch_size, temperature=0.1, world_size=1 + ) + else: + model = SSLNET(output_size=2, flatten_size=1024) # VGG + model = model.float() + print(model) + + pytorch_total_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + print("Num of paras %d " % pytorch_total_params) + # check if each process is having the same input + if cfg.runtime.distributed: + print("Training using DDP") + torch.cuda.set_device(rank) + model.cuda(rank) + ngpus_per_node = 4 + cfg.data.batch_subject_num = int( + cfg.data.batch_subject_num / ngpus_per_node + ) + num_workers = int(num_workers / ngpus_per_node) + model = DDP(model, device_ids=[rank], output_device=rank) + elif multi_gpu: + print("Training using multiple GPUS") + model = nn.DataParallel(model, device_ids=gpu_ids) + model.to(my_device) + else: + print("Training using device %s" % my_device) + model.to(my_device, dtype=torch.float) + model.to(my_device, dtype=torch.float) + + if GPU == -1 and multi_gpu is False: + summary( + model, + (3, cfg.dataloader.sample_rate * cfg.dataloader.epoch_len), + device="cpu", + ) + elif GPU == 0: + summary( + model, + (3, cfg.dataloader.sample_rate * cfg.dataloader.epoch_len), + device="cuda", + ) + + #################### + # Set up data + ################### + # my_transform = transforms.Compose( + # [RandomSwitchAxisTimeSeries(), RotationAxisTimeSeries()] + # ) + my_transform = transforms.Compose( + [ + RandomSwitchAxisTimeSeries(), + transforms.standard_normalization, + ] + ) + + train_dataset = SIMCLR_dataset( + train_data_root, + train_file_list_path, + cfg, + is_epoch_data=is_epoch_data, + 
transform=my_transform, + ) + if cfg.runtime.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset + ) + train_shuffle = False + else: + train_sampler = None + train_shuffle = True + + train_loader = DataLoader( + train_dataset, + batch_size=cfg.data.batch_subject_num, + collate_fn=simclr_subject_collate, + shuffle=train_shuffle, + sampler=train_sampler, + pin_memory=True, + worker_init_fn=worker_init_fn, + num_workers=num_workers, + ) + + #################### + # Set up Training + ################### + weight_decay = 10e-6 + epoch_without_restart = 200 + optimizer, scheduler = load_optimizer( + "LARS", true_batch_size, weight_decay, epoch_without_restart, model + ) + total_step = len(train_loader) + + print("Start training") + # scaler = torch.cuda.amp.GradScaler() + # early_stopping = EarlyStopping( + # patience=cfg.model.patience, path=model_path, verbose=True + # ) + print("saving model weight to %s" % model_path) + + for epoch in range(num_epochs): + if cfg.runtime.distributed: + train_sampler.set_epoch(epoch) + + model.train() + train_losses = [] + + for i, (X1, X2) in enumerate(train_loader): + # the labels for all tasks are always generated + + ( + X1, + X2, + ) = set_up_data4train(X1, X2, cfg, my_device, rank) + + # obtain two views of the same data + h1 = model(X1) + h2 = model(X2) + loss = criterion(h1, h2) + + loss.backward() + optimizer.step() + + optimizer.zero_grad() + + if i % log_interval == 0: + msg = ( + "Train: Epoch [{}/{}], Step [{}/{}], Loss: {:.4f},".format( + epoch + 1, + num_epochs, + i, + total_step, + loss.item(), + ) + ) + print(msg) + train_losses.append(loss.cpu().detach().numpy()) + + if epoch >= cfg.model.warm_up_step: + scheduler.step() + + loss = np.mean(np.array(train_losses)) + print("Epoch: %d, training Loss: %f" % (epoch, loss)) + torch.save(model.state_dict(), model_path) + + if cfg.runtime.distributed: + cleanup() + + +if __name__ == "__main__": + main()
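A minimal sanity check for the NT-Xent objective defined above. This is a sketch only: it assumes the project dependencies are installed and that train_simclr.py is importable from the repository root, and it uses small random tensors as stand-ins for the two augmented accelerometer views.

    import torch
    from train_simclr import NT_Xent

    torch.manual_seed(0)
    batch_size, embedding_dim = 8, 64
    loss_fn = NT_Xent(batch_size=batch_size, temperature=0.1, world_size=1)

    # two nearly identical "views" of the same batch ...
    z = torch.randn(batch_size, embedding_dim)
    loss_matched = loss_fn(z, z + 0.01 * torch.randn(batch_size, embedding_dim))

    # ... versus two unrelated batches
    loss_random = loss_fn(z, torch.randn(batch_size, embedding_dim))

    # the contrastive loss should be much lower for the matched views
    print(loss_matched.item(), loss_random.item())
    assert loss_matched < loss_random

The same check applies to SimCLR_Loss, which is the single-GPU variant of the same objective without the GatherLayer all-gather.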