Commit

add utils files
lisiyao21 committed Apr 25, 2022
1 parent d9d31a0 commit 5f40dda
Showing 9 changed files with 889 additions and 0 deletions.
Empty file added models/utils/__init__.py
148 changes: 148 additions & 0 deletions models/utils/audio_utils.py
@@ -0,0 +1,148 @@
import numpy as np
import torch as t
import models.utils.dist_adapter as dist
import soundfile
import librosa
from models.utils.dist_utils import print_once

class DefaultSTFTValues:
    def __init__(self, hps):
        self.sr = hps.sr
        self.n_fft = 2048
        self.hop_length = 256
        self.window_size = 6 * self.hop_length

class STFTValues:
    def __init__(self, hps, n_fft, hop_length, window_size):
        self.sr = hps.sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window_size = window_size

def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        stft = librosa.core.stft(np.mean(samples, axis=1), hps.n_fft, hop_length=hps.hop_length, win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from models.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean ** 2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth

def audio_preprocess(x, hps):
    # Extra layer in case we want to experiment with different preprocessing
    # For two channels, blend randomly into mono (standard is .5 left, .5 right)

    # x: NTC
    # x = x.float()
    # if x.shape[-1] == 2:
    #     if hps.aug_blend:
    #         mix = t.rand((x.shape[0], 1), device=x.device)  # np.random.rand()
    #     else:
    #         mix = 0.5
    #     x = (mix * x[:, :, 0] + (1 - mix) * x[:, :, 1])
    # elif x.shape[-1] == 1:
    #     x = x[:, :, 0]
    # else:
    #     assert False, f'Expected channels {hps.channels}. Got unknown {x.shape[-1]} channels'

    # # x: NT -> NTC
    # x = x.unsqueeze(2)
    return x

def audio_postprocess(x, hps):
    return x

def stft(sig, hps):
    return t.stft(sig, hps.n_fft, hps.hop_length, win_length=hps.window_size, window=t.hann_window(hps.window_size, device=sig.device))

def spec(x, hps):
    return t.norm(stft(x, hps), p=2, dim=-1)

def norm(x):
    return (x.view(x.shape[0], -1) ** 2).sum(dim=-1).sqrt()

def squeeze(x):
    if len(x.shape) == 3:
        assert x.shape[-1] in [1, 2]
        x = t.mean(x, -1)
    if len(x.shape) != 2:
        raise ValueError(f'Unknown input shape {x.shape}')
    return x

def spectral_loss(x_in, x_out, hps):
    hps = DefaultSTFTValues(hps)
    spec_in = spec(squeeze(x_in.float()), hps)
    spec_out = spec(squeeze(x_out.float()), hps)
    return norm(spec_in - spec_out)

def multispectral_loss(x_in, x_out, hps):
    losses = []
    assert len(hps.multispec_loss_n_fft) == len(hps.multispec_loss_hop_length) == len(hps.multispec_loss_window_size)
    args = [hps.multispec_loss_n_fft,
            hps.multispec_loss_hop_length,
            hps.multispec_loss_window_size]
    for n_fft, hop_length, window_size in zip(*args):
        hps = STFTValues(hps, n_fft, hop_length, window_size)
        spec_in = spec(squeeze(x_in.float()), hps)
        spec_out = spec(squeeze(x_out.float()), hps)
        losses.append(norm(spec_in - spec_out))
    return sum(losses) / len(losses)

def spectral_convergence(x_in, x_out, hps, epsilon=2e-3):
    hps = DefaultSTFTValues(hps)
    spec_in = spec(squeeze(x_in.float()), hps)
    spec_out = spec(squeeze(x_out.float()), hps)

    gt_norm = norm(spec_in)
    residual_norm = norm(spec_in - spec_out)
    mask = (gt_norm > epsilon).float()
    return (residual_norm * mask) / t.clamp(gt_norm, min=epsilon)

def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4):
    hps = DefaultSTFTValues(hps)
    spec_in = t.log(spec(squeeze(x_in.float()), hps) + epsilon)
    spec_out = t.log(spec(squeeze(x_out.float()), hps) + epsilon)
    return t.mean(t.abs(spec_in - spec_out))

def load_audio(file, sr, offset, duration, mono=False):
    # Librosa loads more filetypes than soundfile
    x, _ = librosa.load(file, sr=sr, mono=mono, offset=offset/sr, duration=duration/sr)
    if len(x.shape) == 1:
        x = x.reshape((1, -1))
    return x


def save_wav(fname, aud, sr):
    # clip before saving?
    aud = t.clamp(aud, -1, 1).cpu().numpy()
    for i in list(range(aud.shape[0])):
        soundfile.write(f'{fname}/item_{i}.wav', aud[i], samplerate=sr, format='wav')
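
For reference, a minimal usage sketch of the spectral losses above, assuming an older PyTorch where torch.stft returns a real tensor with a trailing real/imaginary dimension (which is what spec() takes the norm over). The hps namespace and the multi-resolution values below are illustrative stand-ins, not part of this commit.

from types import SimpleNamespace
import torch as t
from models.utils.audio_utils import spectral_loss, multispectral_loss

# Illustrative hyperparameters: field names mirror what the functions above read.
hps = SimpleNamespace(
    sr=22050,
    multispec_loss_n_fft=(2048, 1024, 512),
    multispec_loss_hop_length=(240, 120, 50),
    multispec_loss_window_size=(1200, 600, 240),
)

x_in = t.randn(4, 44100)   # batch of mono waveforms, N x T (squeeze() also accepts N x T x C)
x_out = t.randn(4, 44100)

sl = spectral_loss(x_in, x_out, hps)        # per-sample L2 distance between STFT magnitudes, shape (N,)
msl = multispectral_loss(x_in, x_out, hps)  # same loss averaged over the three STFT resolutions
print(sl.shape, msl.shape)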


101 changes: 101 additions & 0 deletions models/utils/dist_utils.py
@@ -0,0 +1,101 @@
import os
from time import sleep
import torch
import models.utils.dist_adapter as dist

def print_once(msg):
    if (not dist.is_available()) or dist.get_rank() == 0:
        print(msg)

def print_all(msg):
    if not dist.is_available():
        print(msg)
    elif dist.get_rank() % 8 == 0:
        print(f'{dist.get_rank() // 8}: {msg}')

def allgather(x):
    xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(xs, x)
    xs = torch.cat(xs, dim=0)
    return xs

def allreduce(x, op=dist.ReduceOp.SUM):
    x = torch.tensor(x).float().cuda()
    dist.all_reduce(x, op=op)
    return x.item()

def allgather_lists(xs):
    bs = len(xs)
    total_bs = dist.get_world_size() * len(xs)
    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long, device='cuda')
    lengths = allgather(lengths)
    assert lengths.shape == (total_bs,)
    max_length = torch.max(lengths).item()

    xs = torch.tensor([[*x, *[0] * (max_length - len(x))] for x in xs], device='cuda')
    assert xs.shape == (bs, max_length), f'Expected {(bs, max_length)}, got {xs.shape}'
    xs = allgather(xs)
    assert xs.shape == (total_bs, max_length), f'Expected {(total_bs, max_length)}, got {xs.shape}'

    return [xs[i][:lengths[i]].cpu().numpy().tolist() for i in range(total_bs)]

def setup_dist_from_mpi(
    master_addr="127.0.0.1", backend="nccl", port=29500, n_attempts=5, verbose=False
):
    if dist.is_available():
        return _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose)
    else:
        use_cuda = torch.cuda.is_available()
        print(f'Using cuda {use_cuda}')

        mpi_rank = 0
        local_rank = 0

        device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
        torch.cuda.set_device(local_rank)

        return mpi_rank, local_rank, device

def _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose):
    from mpi4py import MPI  # This must be imported in order to get errors from all ranks to show up

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    mpi_size = MPI.COMM_WORLD.Get_size()

    os.environ["RANK"] = str(mpi_rank)
    os.environ["WORLD_SIZE"] = str(mpi_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    os.environ["NCCL_LL_THRESHOLD"] = "0"
    os.environ["NCCL_NSOCKS_PERTHREAD"] = "2"
    os.environ["NCCL_SOCKET_NTHREADS"] = "8"

    # Pin this rank to a specific GPU on the node
    local_rank = mpi_rank % 8
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    if verbose:
        print(f"Connecting to master_addr: {master_addr}")

    # There is a race condition when initializing NCCL with a large number of ranks (e.g. 500 ranks)
    # We guard against the failure and then retry
    for attempt_idx in range(n_attempts):
        try:
            dist.init_process_group(backend=backend, init_method="env://")
            assert dist.get_rank() == mpi_rank

            use_cuda = torch.cuda.is_available()
            print(f'Using cuda {use_cuda}')
            local_rank = mpi_rank % 8
            device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
            torch.cuda.set_device(local_rank)

            return mpi_rank, local_rank, device
        except RuntimeError as e:
            print(f"Caught error during NCCL init (attempt {attempt_idx} of {n_attempts}): {e}")
            sleep(1 + (0.01 * mpi_rank))  # Sleep to avoid thundering herd

    raise RuntimeError("Failed to initialize NCCL")
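
A hedged sketch of how these helpers might be wired together at the top of a training script; the launch setup, port, and metric below are illustrative assumptions, not part of this commit.

import models.utils.dist_adapter as dist
from models.utils.dist_utils import setup_dist_from_mpi, print_once, allreduce

# Assumes the script is launched with one process per GPU (e.g. via mpirun).
rank, local_rank, device = setup_dist_from_mpi(port=29500)
print_once(f'initialized rank {rank} on {device}')  # printed by rank 0 only

# Average a scalar metric across ranks: allreduce() sums, so divide by the world size.
local_loss = 0.123
mean_loss = allreduce(local_loss) / dist.get_world_size()
print_once(f'mean loss across ranks: {mean_loss:.4f}')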
94 changes: 94 additions & 0 deletions models/utils/ema.py
@@ -0,0 +1,94 @@
import torch
from torch._utils import _flatten_dense_tensors
import numpy as np

# EMA always in float, as accumulation needs lots of bits
class EMA:
    def __init__(self, params, mu=0.999):
        self.mu = mu
        self.state = [(p, self.get_model_state(p)) for p in params if p.requires_grad]

    def get_model_state(self, p):
        return p.data.float().detach().clone()

    def step(self):
        for p, state in self.state:
            state.mul_(self.mu).add_(1 - self.mu, p.data.float())

    def swap(self):
        # swap ema and model params
        for p, state in self.state:
            other_state = self.get_model_state(p)
            p.data.copy_(state.type_as(p.data))
            state.copy_(other_state)


class CPUEMA:
    def __init__(self, params, mu=0.999, freq=1):
        self.mu = mu**freq
        self.state = [(p, self.get_model_state(p)) for p in params if p.requires_grad]
        self.freq = freq
        self.steps = 0

    def get_model_state(self, p):
        with torch.no_grad():
            state = p.data.float().detach().cpu().numpy()
        return state

    def step(self):
        with torch.no_grad():
            self.steps += 1
            if self.steps % self.freq == 0:
                for i in range(len(self.state)):
                    p, state = self.state[i]
                    state = torch.from_numpy(state).cuda()
                    state.mul_(self.mu).add_(1 - self.mu, p.data.float())
                    self.state[i] = (p, state.cpu().numpy())

    def swap(self):
        with torch.no_grad():
            # swap ema and model params
            for p, state in self.state:
                other_state = self.get_model_state(p)
                p.data.copy_(torch.from_numpy(state).type_as(p.data))
                np.copyto(state, other_state)

class FusedEMA:
    def __init__(self, params, mu=0.999):
        self.mu = mu
        params = list(params)
        self.params = {}
        self.params['fp16'] = [p for p in params if p.requires_grad and p.data.dtype == torch.float16]
        self.params['fp32'] = [p for p in params if p.requires_grad and p.data.dtype != torch.float16]
        self.groups = [group for group in self.params.keys() if len(self.params[group]) > 0]
        self.state = {}
        for group in self.groups:
            self.state[group] = self.get_model_state(group)

    def get_model_state(self, group):
        params = self.params[group]
        return _flatten_dense_tensors([p.data.float() for p in params])
        # if self.fp16:
        #     return _flatten_dense_tensors([p.data.half() for p in self.param_group if p.dtype])
        # else:
        #     return _flatten_dense_tensors([p.data for p in self.param_group])

    def step(self):
        for group in self.groups:
            self.state[group].mul_(self.mu).add_(1 - self.mu, self.get_model_state(group))

    def swap(self):
        # swap ema and model params
        for group in self.groups:
            other_state = self.get_model_state(group)
            state = self.state[group]
            params = self.params[group]
            offset = 0
            for p in params:
                numel = p.data.numel()
                p.data = state.narrow(0, offset, numel).view_as(p.data).type_as(p.data)
                offset += numel

            self.state[group] = other_state
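
A hedged sketch of the intended EMA workflow, assuming a PyTorch version that still accepts the deprecated add_(scalar, tensor) overload used in step(); the toy model and loop are illustrative only, not part of this commit.

import torch
from models.utils.ema import EMA

model = torch.nn.Linear(16, 16)
ema = EMA(model.parameters(), mu=0.999)

for _ in range(100):   # training loop; loss and optimizer step omitted
    # optimizer.step() would go here
    ema.step()         # shadow <- mu * shadow + (1 - mu) * current params

ema.swap()             # evaluate with the smoothed weights
# ... run validation ...
ema.swap()             # swap back to the raw training weights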

