diff --git a/models/utils/__init__.py b/models/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/models/utils/audio_utils.py b/models/utils/audio_utils.py
new file mode 100644
index 00000000..39f428af
--- /dev/null
+++ b/models/utils/audio_utils.py
@@ -0,0 +1,148 @@
+import numpy as np
+import torch as t
+import models.utils.dist_adapter as dist
+import soundfile
+import librosa
+from models.utils.dist_utils import print_once
+
+class DefaultSTFTValues:
+    def __init__(self, hps):
+        self.sr = hps.sr
+        self.n_fft = 2048
+        self.hop_length = 256
+        self.window_size = 6 * self.hop_length
+
+class STFTValues:
+    def __init__(self, hps, n_fft, hop_length, window_size):
+        self.sr = hps.sr
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.window_size = window_size
+
+def calculate_bandwidth(dataset, hps, duration=600):
+    hps = DefaultSTFTValues(hps)
+    n_samples = int(dataset.sr * duration)
+    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
+    spec_norm_total, spec_nelem = 0.0, 0.0
+    while n_seen < n_samples:
+        x = dataset[idx]
+        if isinstance(x, (tuple, list)):
+            x, y = x
+        samples = x.astype(np.float64)
+        stft = librosa.core.stft(np.mean(samples, axis=1), n_fft=hps.n_fft, hop_length=hps.hop_length, win_length=hps.window_size)
+        spec = np.absolute(stft)
+        spec_norm_total += np.linalg.norm(spec)
+        spec_nelem += 1
+        n_seen += int(np.prod(samples.shape))
+        l1 += np.sum(np.abs(samples))
+        total += np.sum(samples)
+        total_sq += np.sum(samples ** 2)
+        idx += max(16, dist.get_world_size())
+
+    if dist.is_available():
+        from models.utils.dist_utils import allreduce
+        n_seen = allreduce(n_seen)
+        total = allreduce(total)
+        total_sq = allreduce(total_sq)
+        l1 = allreduce(l1)
+        spec_nelem = allreduce(spec_nelem)
+        spec_norm_total = allreduce(spec_norm_total)
+
+    mean = total / n_seen
+    bandwidth = dict(l2 = total_sq / n_seen - mean ** 2,
+                     l1 = l1 / n_seen,
+                     spec = spec_norm_total / spec_nelem)
+    print_once(bandwidth)
+    return bandwidth
+
+def audio_preprocess(x, hps):
+    # Extra layer in case we want to experiment with different preprocessing
+    # For two channels, blend randomly into mono (standard is .5 left, .5 right)
+
+    # x: NTC
+    # x = x.float()
+    # if x.shape[-1]==2:
+    #     if hps.aug_blend:
+    #         mix=t.rand((x.shape[0],1), device=x.device) #np.random.rand()
+    #     else:
+    #         mix = 0.5
+    #     x=(mix*x[:,:,0]+(1-mix)*x[:,:,1])
+    # elif x.shape[-1]==1:
+    #     x=x[:,:,0]
+    # else:
+    #     assert False, f'Expected channels {hps.channels}. Got unknown {x.shape[-1]} channels'
+
+    # # x: NT -> NTC
+    # x = x.unsqueeze(2)
+    return x
+
+def audio_postprocess(x, hps):
+    return x
+
+def stft(sig, hps):
+    # return_complex=False keeps the real-valued (..., 2) layout that spec() expects
+    return t.stft(sig, hps.n_fft, hps.hop_length, win_length=hps.window_size,
+                  window=t.hann_window(hps.window_size, device=sig.device), return_complex=False)
+
+def spec(x, hps):
+    return t.norm(stft(x, hps), p=2, dim=-1)
+
+def norm(x):
+    return (x.view(x.shape[0], -1) ** 2).sum(dim=-1).sqrt()
+
+def squeeze(x):
+    if len(x.shape) == 3:
+        assert x.shape[-1] in [1, 2]
+        x = t.mean(x, -1)
+    if len(x.shape) != 2:
+        raise ValueError(f'Unknown input shape {x.shape}')
+    return x
+
+def spectral_loss(x_in, x_out, hps):
+    hps = DefaultSTFTValues(hps)
+    spec_in = spec(squeeze(x_in.float()), hps)
+    spec_out = spec(squeeze(x_out.float()), hps)
+    return norm(spec_in - spec_out)
+
+def multispectral_loss(x_in, x_out, hps):
+    losses = []
+    assert len(hps.multispec_loss_n_fft) == len(hps.multispec_loss_hop_length) == len(hps.multispec_loss_window_size)
+    args = [hps.multispec_loss_n_fft,
+            hps.multispec_loss_hop_length,
+            hps.multispec_loss_window_size]
+    for n_fft, hop_length, window_size in zip(*args):
+        hps = STFTValues(hps, n_fft, hop_length, window_size)
+        spec_in = spec(squeeze(x_in.float()), hps)
+        spec_out = spec(squeeze(x_out.float()), hps)
+        losses.append(norm(spec_in - spec_out))
+    return sum(losses) / len(losses)
+
+def spectral_convergence(x_in, x_out, hps, epsilon=2e-3):
+    hps = DefaultSTFTValues(hps)
+    spec_in = spec(squeeze(x_in.float()), hps)
+    spec_out = spec(squeeze(x_out.float()), hps)
+
+    gt_norm = norm(spec_in)
+    residual_norm = norm(spec_in - spec_out)
+    mask = (gt_norm > epsilon).float()
+    return (residual_norm * mask) / t.clamp(gt_norm, min=epsilon)
+
+def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4):
+    hps = DefaultSTFTValues(hps)
+    spec_in = t.log(spec(squeeze(x_in.float()), hps) + epsilon)
+    spec_out = t.log(spec(squeeze(x_out.float()), hps) + epsilon)
+    return t.mean(t.abs(spec_in - spec_out))
+
+def load_audio(file, sr, offset, duration, mono=False):
+    # Librosa loads more filetypes than soundfile
+    x, _ = librosa.load(file, sr=sr, mono=mono, offset=offset/sr, duration=duration/sr)
+    if len(x.shape) == 1:
+        x = x.reshape((1, -1))
+    return x
+
+
+def save_wav(fname, aud, sr):
+    # Clip to [-1, 1] before saving to avoid wrap-around distortion
+    aud = t.clamp(aud, -1, 1).cpu().numpy()
+    for i in range(aud.shape[0]):
+        soundfile.write(f'{fname}/item_{i}.wav', aud[i], samplerate=sr, format='wav')
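
A minimal usage sketch (not part of the diff) of how the spectral helpers compose. The `hps` namespace and its field values are illustrative assumptions; the functions only read `sr` and the `multispec_loss_*` tuples.

import torch as t
from types import SimpleNamespace
from models.utils.audio_utils import spectral_loss, multispectral_loss

# Hypothetical hyperparameters; field names mirror what the functions read.
hps = SimpleNamespace(sr=22050,
                      multispec_loss_n_fft=(2048, 1024, 512),
                      multispec_loss_hop_length=(240, 120, 50),
                      multispec_loss_window_size=(1200, 600, 240))

x_in = t.randn(4, 16384, 1)                 # N x T x C batch of waveforms
x_out = x_in + 0.01 * t.randn_like(x_in)    # e.g. a reconstruction

sl = spectral_loss(x_in, x_out, hps)        # per-example STFT-magnitude L2, shape (4,)
msl = multispectral_loss(x_in, x_out, hps)  # averaged over the three STFT configs, shape (4,)
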
diff --git a/models/utils/dist_utils.py b/models/utils/dist_utils.py
new file mode 100644
index 00000000..f0e7a1f0
--- /dev/null
+++ b/models/utils/dist_utils.py
@@ -0,0 +1,101 @@
+import os
+from time import sleep
+import torch
+import models.utils.dist_adapter as dist
+
+def print_once(msg):
+    if (not dist.is_available()) or dist.get_rank() == 0:
+        print(msg)
+
+def print_all(msg):
+    if (not dist.is_available()):
+        print(msg)
+    elif dist.get_rank() % 8 == 0:
+        print(f'{dist.get_rank()//8}: {msg}')
+
+def allgather(x):
+    xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
+    dist.all_gather(xs, x)
+    xs = torch.cat(xs, dim=0)
+    return xs
+
+def allreduce(x, op=dist.ReduceOp.SUM):
+    x = torch.tensor(x).float().cuda()
+    dist.all_reduce(x, op=op)
+    return x.item()
+
+def allgather_lists(xs):
+    bs = len(xs)
+    total_bs = dist.get_world_size() * len(xs)
+    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long, device='cuda')
+    lengths = allgather(lengths)
+    assert lengths.shape == (total_bs,)
+    max_length = torch.max(lengths).item()
+
+    xs = torch.tensor([[*x, *[0]*(max_length - len(x))] for x in xs], device='cuda')
+    assert xs.shape == (bs, max_length), f'Expected {(bs, max_length)}, got {xs.shape}'
+    xs = allgather(xs)
+    assert xs.shape == (total_bs, max_length), f'Expected {(total_bs, max_length)}, got {xs.shape}'
+
+    return [xs[i][:lengths[i]].cpu().numpy().tolist() for i in range(total_bs)]
+
+def setup_dist_from_mpi(
+    master_addr="127.0.0.1", backend="nccl", port=29500, n_attempts=5, verbose=False
+):
+    if dist.is_available():
+        return _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose)
+    else:
+        use_cuda = torch.cuda.is_available()
+        print(f'Using cuda {use_cuda}')
+
+        mpi_rank = 0
+        local_rank = 0
+
+        device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
+        if use_cuda:
+            torch.cuda.set_device(local_rank)
+
+        return mpi_rank, local_rank, device
+
+def _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose):
+    from mpi4py import MPI  # This must be imported in order to get errors from all ranks to show up
+
+    mpi_rank = MPI.COMM_WORLD.Get_rank()
+    mpi_size = MPI.COMM_WORLD.Get_size()
+
+    os.environ["RANK"] = str(mpi_rank)
+    os.environ["WORLD_SIZE"] = str(mpi_size)
+    os.environ["MASTER_ADDR"] = master_addr
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["NCCL_LL_THRESHOLD"] = "0"
+    os.environ["NCCL_NSOCKS_PERTHREAD"] = "2"
+    os.environ["NCCL_SOCKET_NTHREADS"] = "8"
+
+    # Pin this rank to a specific GPU on the node
+    local_rank = mpi_rank % 8
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
+
+    if verbose:
+        print(f"Connecting to master_addr: {master_addr}")
+
+    # There is a race condition when initializing NCCL with a large number of ranks (e.g. 500 ranks)
+    # We guard against the failure and then retry
+    for attempt_idx in range(n_attempts):
+        try:
+            dist.init_process_group(backend=backend, init_method="env://")
+            assert dist.get_rank() == mpi_rank
+
+            use_cuda = torch.cuda.is_available()
+            print(f'Using cuda {use_cuda}')
+            local_rank = mpi_rank % 8
+            device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
+            torch.cuda.set_device(local_rank)
+
+            return mpi_rank, local_rank, device
+        except RuntimeError as e:
+            print(f"Caught error during NCCL init (attempt {attempt_idx} of {n_attempts}): {e}")
+            sleep(1 + (0.01 * mpi_rank))  # Sleep to avoid thundering herd
+
+    raise RuntimeError("Failed to initialize NCCL")
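
An illustrative sketch (not part of the diff) of how a training script might call into this module. It assumes the script is launched under mpiexec with mpi4py installed (otherwise the single-process CPU fallback path is taken) and a CUDA device is present, since allreduce stages values on the GPU.

from models.utils.dist_utils import setup_dist_from_mpi, allreduce

rank, local_rank, device = setup_dist_from_mpi(port=29500)
print(f"rank={rank} local_rank={local_rank} device={device}")

# Reduce a python scalar across ranks; every rank gets the summed float back.
total = allreduce(1.0)
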
diff --git a/models/utils/ema.py b/models/utils/ema.py
new file mode 100644
index 00000000..94f3b47b
--- /dev/null
+++ b/models/utils/ema.py
@@ -0,0 +1,94 @@
+import torch
+from torch._utils import _flatten_dense_tensors
+import numpy as np
+
+# EMA always in float, as accumulation needs lots of bits
+class EMA:
+    def __init__(self, params, mu=0.999):
+        self.mu = mu
+        self.state = [(p, self.get_model_state(p)) for p in params if p.requires_grad]
+
+    def get_model_state(self, p):
+        return p.data.float().detach().clone()
+
+    def step(self):
+        for p, state in self.state:
+            state.mul_(self.mu).add_(p.data.float(), alpha=1 - self.mu)
+
+    def swap(self):
+        # swap ema and model params
+        for p, state in self.state:
+            other_state = self.get_model_state(p)
+            p.data.copy_(state.type_as(p.data))
+            state.copy_(other_state)
+
+
+class CPUEMA:
+    def __init__(self, params, mu=0.999, freq=1):
+        self.mu = mu**freq
+        self.state = [(p, self.get_model_state(p)) for p in params if p.requires_grad]
+        self.freq = freq
+        self.steps = 0
+
+    def get_model_state(self, p):
+        with torch.no_grad():
+            state = p.data.float().detach().cpu().numpy()
+        return state
+
+    def step(self):
+        with torch.no_grad():
+            self.steps += 1
+            if self.steps % self.freq == 0:
+                for i in range(len(self.state)):
+                    p, state = self.state[i]
+                    state = torch.from_numpy(state).cuda()
+                    state.mul_(self.mu).add_(p.data.float(), alpha=1 - self.mu)
+                    self.state[i] = (p, state.cpu().numpy())
+
+    def swap(self):
+        with torch.no_grad():
+            # swap ema and model params
+            for p, state in self.state:
+                other_state = self.get_model_state(p)
+                p.data.copy_(torch.from_numpy(state).type_as(p.data))
+                np.copyto(state, other_state)
+
+class FusedEMA:
+    def __init__(self, params, mu=0.999):
+        self.mu = mu
+        params = list(params)
+        self.params = {}
+        self.params['fp16'] = [p for p in params if p.requires_grad and p.data.dtype == torch.float16]
+        self.params['fp32'] = [p for p in params if p.requires_grad and p.data.dtype != torch.float16]
+        self.groups = [group for group in self.params.keys() if len(self.params[group]) > 0]
+        self.state = {}
+        for group in self.groups:
+            self.state[group] = self.get_model_state(group)
+
+    def get_model_state(self, group):
+        params = self.params[group]
+        return _flatten_dense_tensors([p.data.float() for p in params])
+        # if self.fp16:
+        #     return _flatten_dense_tensors([p.data.half() for p in self.param_group if p.dtype])
+        # else:
+        #     return _flatten_dense_tensors([p.data for p in self.param_group])
+
+    def step(self):
+        for group in self.groups:
+            self.state[group].mul_(self.mu).add_(self.get_model_state(group), alpha=1 - self.mu)
+
+    def swap(self):
+        # swap ema and model params
+        for group in self.groups:
+            other_state = self.get_model_state(group)
+            state = self.state[group]
+            params = self.params[group]
+            offset = 0
+            for p in params:
+                numel = p.data.numel()
+                p.data = state.narrow(0, offset, numel).view_as(p.data).type_as(p.data)
+                offset += numel
+
+            self.state[group] = other_state
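
A small sketch (not part of the diff) of the intended EMA workflow: keep a shadow copy of the weights, swap it in for evaluation or checkpointing, then swap back before training continues. The toy model, optimizer, and loop are assumptions for illustration only.

import torch
from torch import nn, optim
from models.utils.ema import EMA

model = nn.Linear(16, 1)                      # toy stand-in for the real model
opt = optim.SGD(model.parameters(), lr=1e-2)
ema = EMA(model.parameters(), mu=0.999)

for _ in range(10):                           # assumed training loop
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    loss = ((model(x) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.step()                                # update shadow weights after each optimizer step

ema.swap()                                    # model now carries the EMA weights (e.g. for eval or saving)
# ... evaluate or checkpoint here ...
ema.swap()                                    # restore live weights before resuming training
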
diff --git a/models/utils/fp16.py b/models/utils/fp16.py
new file mode 100644
index 00000000..e7290c67
--- /dev/null
+++ b/models/utils/fp16.py
@@ -0,0 +1,303 @@
+# Utils for fp16 training.
+import importlib
+import math
+import numpy as np
+import torch
+import models.utils.dist_adapter as dist
+from torch.optim import Optimizer
+from torch._utils import _flatten_dense_tensors
+
+from models.utils.dist_utils import allreduce
+
+def adam_step(p: torch.Tensor, out_p: torch.Tensor, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, grad: torch.Tensor,
+              lr: float, beta1: float, beta2: float, eps: float, scale: float, step: int, eps_mode: int, bias_correction: int, weight_decay: float):
+    assert bias_correction == 1
+    assert eps_mode == 1
+
+    grad = grad.float()
+    grad.div_(scale)
+
+    # Decay the first and second moment running average coefficient
+    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+    denom = exp_avg_sq.sqrt().add_(eps)
+
+    bias_correction1 = 1 - beta1 ** step
+    bias_correction2 = 1 - beta2 ** step
+    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
+
+    p.add_(exp_avg/denom + weight_decay*p.float(), alpha=-step_size)
+
+# Import fused_adam if we have apex, otherwise use regular adam
+try:
+    fused_adam_cuda = importlib.import_module("fused_adam_cuda")
+    fused_adam_step = fused_adam_cuda.adam
+    print("Using apex fused_adam_cuda")
+except ModuleNotFoundError:
+    fused_adam_step = adam_step
+
+def backward(loss, params, scalar, fp16, logger):
+    # Perform backward
+    if not fp16:
+        scale = 1.0
+        loss.backward()
+        gn = grad_norm(params, scale)
+        return loss, scale, gn, False, False
+    else:
+        scale = scalar.get_scale()
+        loss = (loss.float()) * scale
+        overflow_loss = check_overflow(loss.item())
+        overflow_loss = allreduce(int(overflow_loss), op=dist.ReduceOp.MAX) > 0
+        if not overflow_loss:
+            loss.backward()
+            gn = grad_norm(params, scale)
+            overflow_grad = check_overflow(gn)
+            overflow_grad = allreduce(int(overflow_grad), op=dist.ReduceOp.MAX) > 0
+            scalar.update_scale(overflow_grad)
+        else:
+            gn = 0.0
+            overflow_grad = True
+        loss = (loss.detach().float()) / scale  # Should delete computation graph for overflow
+        if logger.rank == 0:
+            if loss > 12.: print(f"\nWarning. Loss is {loss}")
+            if overflow_loss: print(f"\nOverflow in forward. Loss {loss}, lgscale {np.log2(scale)}. Skipping batch completely (no backward, scale update)")
+            elif overflow_grad: print(f"\nOverflow in backward. Loss {loss}, grad norm {gn}, lgscale {np.log2(scale)}, new lgscale {np.log2(scalar.get_scale())}")
+        return loss, scale, gn, overflow_loss, overflow_grad
+
+# Automatic loss scaling
+class LossScalar(object):
+    def __init__(self,
+                 loss_scale,
+                 init_scale=2. ** 16,
+                 scale_factor=2. ** (1. / 1000),
+                 scale_window=1):
+        if loss_scale is None:
+            # Use dynamic loss scaling
+            self.dynamic = True
+            self.loss_scale = init_scale
+        else:
+            self.dynamic = False
+            self.loss_scale = loss_scale
+        self.max_loss_scale = 2.**24
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+        self.unskipped = 0
+        self.overflow = False
+
+    def get_scale(self):
+        return self.loss_scale
+
+    def update_scale(self, overflow):
+        if overflow and self.dynamic:
+            self.loss_scale /= 2.
+            self.unskipped = 0
+        else:
+            self.unskipped += 1
+
+        if self.unskipped == self.scale_window and self.dynamic:
+            self.loss_scale = min(self.max_loss_scale, self.loss_scale * self.scale_factor)
+            self.unskipped = 0
+
+def check_overflow(val):
+    return (val == float('inf')) or (val == -float('inf')) or (val != val)
+
+def grad_norm(params, scale, flat=False):
+    params = list(params)
+    if flat:
+        # Faster but more memory
+        fp16_grads = [p.grad for p in params if p.grad is not None and p.data.dtype == torch.float16]
+        fp16_norm = 0.0 if len(fp16_grads) == 0 else float(_flatten_dense_tensors(fp16_grads).norm(p=2, dtype=torch.float32))
+        fp32_grads = [p.grad for p in params if p.grad is not None and p.data.dtype != torch.float16]
+        fp32_norm = 0.0 if len(fp32_grads) == 0 else float(_flatten_dense_tensors(fp32_grads).norm(p=2))
+        grad_norm = (fp16_norm**2 + fp32_norm**2)**0.5
+    else:
+        # Slightly slower but less memory
+        grad_norm = 0.0
+        for p in params:
+            if p.grad is not None:
+                grad_norm += p.grad.norm(p=2, dtype=torch.float32)**2
+        grad_norm = float(grad_norm**0.5)
+    return grad_norm / scale
+
+def clipped_grad_scale(grad_norm, max_grad_norm, scale):
+    clip = grad_norm / max_grad_norm
+    if clip > 1:
+        scale = clip * scale
+    return scale
+
+class FP16FusedAdam(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        eps_inside_sqrt=False,
+        weight_decay=0.0,
+        amsgrad=False,
+    ):
+        if amsgrad:
+            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
+        defaults = dict(
+            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay
+        )
+        super(FP16FusedAdam, self).__init__(params, defaults)
+        self.eps_mode = 0 if eps_inside_sqrt else 1
+        self.FLOAT16_MAX = 65504.0
+        self.init_state()
+
+    def init_state(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                assert p.requires_grad
+                state = self.state[p]
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p.data)
+                    if p.data.dtype == torch.float16:
+                        state["scale_exp_avg"] = 1.0
+                        state["scale_exp_avg_sq"] = 1.0
+
+    def step(self, closure=None, scale=1.0):
+        """Performs a single optimization step. Scales gradients down by scale.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            scale (float, optional): factor to divide gradient tensor values
+                by before applying to weights. (default: 1)
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            bias_correction = 1 if group["bias_correction"] else 0
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+
+                state = self.state[p]
+
+                if p.data.dtype == torch.float16:
+                    exp_avg, exp_avg_sq = (
+                        state["exp_avg"].float() * state["scale_exp_avg"],
+                        state["exp_avg_sq"].float() * state["scale_exp_avg_sq"],
+                    )
+                else:
+                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                out_p = torch.tensor([], dtype=torch.float)
+                fused_adam_step(
+                    p.data,
+                    out_p,
+                    exp_avg,
+                    exp_avg_sq,
+                    grad,
+                    group["lr"],
+                    beta1,
+                    beta2,
+                    group["eps"],
+                    scale,
+                    state["step"],
+                    self.eps_mode,
+                    bias_correction,
+                    group["weight_decay"],
+                )
+
+                if p.data.dtype == torch.float16:
+                    state["scale_exp_avg"] = (
+                        1e-8 + float(torch.norm(exp_avg, float("inf"))) / self.FLOAT16_MAX
+                    )
+                    state["scale_exp_avg_sq"] = (
+                        1e-8 + float(torch.norm(exp_avg_sq, float("inf"))) / self.FLOAT16_MAX
+                    )
+                    state["exp_avg"] = (exp_avg / state["scale_exp_avg"]).half()
+                    state["exp_avg_sq"] = (exp_avg_sq / state["scale_exp_avg_sq"]).half()
+
+        return loss
+
+
+class FusedAdam(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        eps_inside_sqrt=False,
+        weight_decay=0.0,
+        amsgrad=False,
+    ):
+        if amsgrad:
+            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
+        defaults = dict(
+            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay
+        )
+        super(FusedAdam, self).__init__(params, defaults)
+        self.eps_mode = 0 if eps_inside_sqrt else 1
+
+    def step(self, closure=None, scale=1.0):
+        """Performs a single optimization step. Scales gradients down by scale.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            scale (float, optional): factor to divide gradient tensor values
+                by before applying to weights. (default: 1)
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            bias_correction = 1 if group["bias_correction"] else 0
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p.data).float()
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p.data).float()
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                out_p = torch.tensor([], dtype=torch.float)
+                fused_adam_step(
+                    p.data,
+                    out_p,
+                    exp_avg,
+                    exp_avg_sq,
+                    grad,
+                    group["lr"],
+                    beta1,
+                    beta2,
+                    group["eps"],
+                    scale,
+                    state["step"],
+                    self.eps_mode,
+                    bias_correction,
+                    group["weight_decay"],
+                )
+
+        return loss
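
A sketch (not part of the diff) of how one fp16-style training step might wire these pieces together. It keeps the master weights in fp32 (which the pure-Python adam_step fallback expects), and assumes a CUDA device plus an initialized process group, since backward() all-reduces overflow flags across ranks. The toy model, logger stand-in, and hyperparameters are illustrative assumptions.

import torch
from torch import nn
from models.utils.fp16 import FusedAdam, LossScalar, backward, clipped_grad_scale

class _ToyLogger:                              # minimal stand-in for models.utils.logger.Logger
    rank = 0

model = nn.Linear(64, 64).cuda()               # master weights kept in fp32 here
params = list(model.parameters())
opt = FusedAdam(params, lr=3e-4, weight_decay=0.01)
scalar = LossScalar(None)                      # None -> dynamic loss scaling
max_grad_norm = 1.0

x = torch.randn(32, 64, device='cuda')
loss = model(x).pow(2).mean()

loss, scale, gn, overflow_loss, overflow_grad = backward(
    loss=loss, params=params, scalar=scalar, fp16=True, logger=_ToyLogger())
if not (overflow_loss or overflow_grad):
    scale = clipped_grad_scale(gn, max_grad_norm, scale)   # fold gradient clipping into the divisor
    opt.step(scale=scale)
model.zero_grad(set_to_none=True)
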
diff --git a/models/utils/logger.py b/models/utils/logger.py
new file mode 100644
index 00000000..65f5b5cd
--- /dev/null
+++ b/models/utils/logger.py
@@ -0,0 +1,147 @@
+import torch as t
+import models.utils.dist_adapter as dist
+from tqdm import tqdm
+from datetime import date
+import os
+import sys
+
+def def_tqdm(x):
+    return tqdm(x, leave=True, file=sys.stdout, bar_format="{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
+
+def get_range(x):
+    if dist.get_rank() == 0:
+        return def_tqdm(x)
+    else:
+        return x
+
+def init_logging(hps, local_rank, rank):
+    logdir = f"{hps.local_logdir}/{hps.name}"
+    if local_rank == 0:
+        if not os.path.exists(logdir):
+            os.makedirs(logdir)
+        with open(f'{logdir}/argv.txt', 'w') as f:
+            f.write(hps.argv + '\n')
+        print("Logging to", logdir)
+    logger = Logger(logdir, rank)
+    metrics = Metrics()
+    logger.add_text('hps', str(hps))
+    return logger, metrics
+
+def get_name(hps):
+    name = ""
+    for key, value in hps.items():
+        name += f"{key}_{value}_"
+    return name
+
+def average_metrics(_metrics):
+    metrics = {}
+    for _metric in _metrics:
+        for key, val in _metric.items():
+            if key not in metrics:
+                metrics[key] = []
+            metrics[key].append(val)
+    return {key: sum(vals) / len(vals) for key, vals in metrics.items()}
+
+class Metrics:
+    def __init__(self):
+        self.sum = {}
+        self.n = {}
+
+    def update(self, tag, val, batch):
+        # val is the average value over the batch
+        # store the total value and total batch size, return the dist average
+        sum = t.tensor(val * batch).float().cuda()
+        n = t.tensor(batch).float().cuda()
+        dist.all_reduce(sum)
+        dist.all_reduce(n)
+        sum = sum.item()
+        n = n.item()
+        self.sum[tag] = self.sum.get(tag, 0.0) + sum
+        self.n[tag] = self.n.get(tag, 0.0) + n
+        return sum / n
+
+    def avg(self, tag):
+        if tag in self.sum:
+            return self.sum[tag] / self.n[tag]
+        else:
+            return 0.0
+
+    def reset(self):
+        self.sum = {}
+        self.n = {}
+
+class Logger:
+    def __init__(self, logdir, rank):
+        if rank == 0:
+            from tensorboardX import SummaryWriter
+            self.sw = SummaryWriter(f"{logdir}/logs")
+        self.iters = 0
+        self.rank = rank
+        self.works = []
+        self.logdir = logdir
+
+    def step(self):
+        self.iters += 1
+
+    def flush(self):
+        if self.rank == 0:
+            self.sw.flush()
+
+    def add_text(self, tag, text):
+        if self.rank == 0:
+            self.sw.add_text(tag, text, self.iters)
+
+    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
+        if self.rank == 0:
+            for i in range(min(len(auds), max_log)):
+                if max_len:
+                    self.sw.add_audio(f"{i}/{tag}", auds[i][:max_len * sample_rate], self.iters, sample_rate)
+                else:
+                    self.sw.add_audio(f"{i}/{tag}", auds[i], self.iters, sample_rate)
+
+    def add_audio(self, tag, aud, sample_rate=22050):
+        if self.rank == 0:
+            self.sw.add_audio(tag, aud, self.iters, sample_rate)
+
+    def add_images(self, tag, img, dataformats="NHWC"):
+        if self.rank == 0:
+            self.sw.add_images(tag, img, self.iters, dataformats=dataformats)
+
+    def add_image(self, tag, img):
+        if self.rank == 0:
+            self.sw.add_image(tag, img, self.iters)
+
+    def add_scalar(self, tag, val):
+        if self.rank == 0:
+            self.sw.add_scalar(tag, val, self.iters)
+
+    def get_range(self, loader):
+        if self.rank == 0:
+            self.trange = def_tqdm(loader)
+        else:
+            self.trange = loader
+        return enumerate(self.trange)
+
+    def close_range(self):
+        if self.rank == 0:
+            self.trange.close()
+
+    def set_postfix(self, *args, **kwargs):
+        if self.rank == 0:
+            self.trange.set_postfix(*args, **kwargs)
+
+    # For logging summaries of various graph ops
+    def add_reduce_scalar(self, tag, layer, val):
+        if self.iters % 100 == 0:
+            with t.no_grad():
+                val = val.float().norm() / float(val.numel())
+            work = dist.reduce(val, 0, async_op=True)
+            self.works.append((tag, layer, val, work))
+
+    def finish_reduce(self):
+        for tag, layer, val, work in self.works:
+            work.wait()
+            if self.rank == 0:
+                val = val.item() / dist.get_world_size()
+                # Write to the main SummaryWriter under a per-layer tag; no per-layer writers exist in this module
+                self.sw.add_scalar(f'{layer}/{tag}', val, self.iters)
+        self.works = []
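
A sketch (not part of the diff) of the intended logging flow. It assumes tensorboardX is installed and that the `hps` fields (`local_logdir`, `name`, `argv`) are supplied by the caller; the loop body is a placeholder.

from types import SimpleNamespace
from models.utils.logger import init_logging

hps = SimpleNamespace(local_logdir='logs', name='demo_run', argv='train.py --demo')
logger, metrics = init_logging(hps, local_rank=0, rank=0)

for i, batch in logger.get_range(range(100)):   # rank 0 wraps the loop in tqdm
    loss = 0.1                                   # placeholder scalar
    logger.add_scalar('train/loss', loss)
    logger.step()                                # advances the global step used by all add_* calls
logger.close_range()
logger.flush()
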
diff --git a/models/utils/remote_utils.py b/models/utils/remote_utils.py
new file mode 100644
index 00000000..7bdf953f
--- /dev/null
+++ b/models/utils/remote_utils.py
@@ -0,0 +1,42 @@
+import sys
+import subprocess
+
+def download(remote_path, local_path, async_download=False):
+    args = ['wget', '-O', local_path, remote_path]
+    print("Running ", " ".join(args))
+    if async_download:
+        subprocess.Popen(args)
+    else:
+        subprocess.call(args)
+
+# GCE
+def gs_download(gs_path, local_path, async_download=False):
+    args = ['gsutil',
+            '-o', 'GSUtil:parallel_thread_count=1',
+            '-o', 'GSUtil:sliced_object_download_max_components=8',
+            'cp', gs_path, local_path]
+    if async_download:
+        subprocess.Popen(args)
+    else:
+        subprocess.call(args)
+
+
+def gs_upload(local_path, gs_path, async_upload=False):
+    # NOTE: Download and upload use different -o flags.
+    # We also use -n to prevent clobbering checkpoints by mistake
+    assert not local_path.startswith("gs://")
+    assert gs_path.startswith("gs://")
+    args = ['gsutil',
+            '-o', 'GSUtil:parallel_composite_upload_threshold=150M',
+            'cp', '-n', local_path, gs_path]
+    if async_upload:
+        subprocess.Popen(args)
+    else:
+        subprocess.call(args)
+
+def ls(regex):
+    outputs = subprocess.check_output(['gsutil', 'ls', regex]).decode(sys.stdout.encoding)
+    outputs = outputs.split('\n')
+    outputs = [output for output in outputs if output != '']
+    return outputs
diff --git a/models/utils/sample_utils.py b/models/utils/sample_utils.py
new file mode 100644
index 00000000..0ae41b16
--- /dev/null
+++ b/models/utils/sample_utils.py
@@ -0,0 +1,22 @@
+import torch as t
+
+def split_batch(obj, n_samples, split_size):
+    n_passes = (n_samples + split_size - 1) // split_size
+    if isinstance(obj, t.Tensor):
+        return t.split(obj, split_size, dim=0)
+    elif isinstance(obj, list):
+        return list(zip(*[t.split(item, split_size, dim=0) for item in obj]))
+    elif obj is None:
+        return [None] * n_passes
+    else:
+        raise TypeError('Unknown input type')
+
+# Break total_length into hops/windows of size n_ctx separated by hop_length
+def get_starts(total_length, n_ctx, hop_length):
+    starts = []
+    for start in range(0, total_length - n_ctx + hop_length, hop_length):
+        if start + n_ctx >= total_length:
+            # Last hop could be smaller; we make it n_ctx to maximise context
+            start = total_length - n_ctx
+        starts.append(start)
+    return starts
diff --git a/models/utils/torch_utils.py b/models/utils/torch_utils.py
new file mode 100644
index 00000000..5d020814
--- /dev/null
+++ b/models/utils/torch_utils.py
@@ -0,0 +1,32 @@
+import gc
+import torch as t
+
+def freeze_model(model):
+    model.eval()
+    for params in model.parameters():
+        params.requires_grad = False
+
+
+def unfreeze_model(model):
+    model.train()
+    for params in model.parameters():
+        params.requires_grad = True
+
+def zero_grad(model):
+    for p in model.parameters():
+        if p.requires_grad and p.grad is not None:
+            p.grad = None
+
+def empty_cache():
+    gc.collect()
+    t.cuda.empty_cache()
+
+def assert_shape(x, exp_shape):
+    assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}"
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+def count_state(model):
+    return sum(s.numel() for s in model.state_dict().values())
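
A short sketch (not part of the diff) of how the sampling helpers compose; the window sizes and batch shape are arbitrary illustrative values.

import torch as t
from models.utils.sample_utils import get_starts, split_batch

total_length, n_ctx, hop_length = 10000, 4096, 2048
starts = get_starts(total_length, n_ctx, hop_length)
print(starts)   # [0, 2048, 4096, 5904] -- the last window is pulled back so it still spans n_ctx

zs = t.zeros(9, n_ctx, dtype=t.long)            # a batch of 9 token windows
for chunk in split_batch(zs, n_samples=9, split_size=4):
    print(chunk.shape)                          # (4, 4096), (4, 4096), (1, 4096) -- one pass per chunk
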