Showing 9 changed files with 889 additions and 0 deletions.
@@ -0,0 +1,148 @@
import numpy as np
import torch as t
import models.utils.dist_adapter as dist
import soundfile
import librosa
from models.utils.dist_utils import print_once


class DefaultSTFTValues:
    def __init__(self, hps):
        self.sr = hps.sr
        self.n_fft = 2048
        self.hop_length = 256
        self.window_size = 6 * self.hop_length


class STFTValues:
    def __init__(self, hps, n_fft, hop_length, window_size):
        self.sr = hps.sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window_size = window_size

def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        # pass the STFT parameters by keyword; newer librosa releases expect keyword arguments here
        stft = librosa.core.stft(np.mean(samples, axis=1), n_fft=hps.n_fft, hop_length=hps.hop_length, win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        # import from the same dist_utils module used above (was `jukebox.utils.dist_utils`)
        from models.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean ** 2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth

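# Usage sketch (hedged): `dataset` is assumed to support integer indexing and to
# expose an `sr` attribute, and `hps.sr` must be set. The returned dict is
# typically used to scale the corresponding reconstruction-loss terms, e.g.:
#
#   bandwidth = calculate_bandwidth(dataset, hps, duration=600)
#   # -> {'l2': ..., 'l1': ..., 'spec': ...}
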
def audio_preprocess(x, hps):
    # Extra layer in case we want to experiment with different preprocessing
    # For two channel, blend randomly into mono (standard is .5 left, .5 right)

    # x: NTC
    # x = x.float()
    # if x.shape[-1]==2:
    #     if hps.aug_blend:
    #         mix=t.rand((x.shape[0],1), device=x.device) #np.random.rand()
    #     else:
    #         mix = 0.5
    #     x=(mix*x[:,:,0]+(1-mix)*x[:,:,1])
    # elif x.shape[-1]==1:
    #     x=x[:,:,0]
    # else:
    #     assert False, f'Expected channels {hps.channels}. Got unknown {x.shape[-1]} channels'

    # # x: NT -> NTC
    # x = x.unsqueeze(2)
    return x


def audio_postprocess(x, hps):
    return x

def stft(sig, hps):
    # return_complex=True is required by recent versions of torch.stft; spec() then takes
    # the complex magnitude, which equals the old 2-norm over the (real, imag) dimension
    return t.stft(sig, hps.n_fft, hps.hop_length, win_length=hps.window_size,
                  window=t.hann_window(hps.window_size, device=sig.device), return_complex=True)


def spec(x, hps):
    return t.abs(stft(x, hps))


def norm(x):
    return (x.view(x.shape[0], -1) ** 2).sum(dim=-1).sqrt()


def squeeze(x):
    if len(x.shape) == 3:
        assert x.shape[-1] in [1, 2]
        x = t.mean(x, -1)
    if len(x.shape) != 2:
        raise ValueError(f'Unknown input shape {x.shape}')
    return x

def spectral_loss(x_in, x_out, hps):
    hps = DefaultSTFTValues(hps)
    spec_in = spec(squeeze(x_in.float()), hps)
    spec_out = spec(squeeze(x_out.float()), hps)
    return norm(spec_in - spec_out)


def multispectral_loss(x_in, x_out, hps):
    losses = []
    assert len(hps.multispec_loss_n_fft) == len(hps.multispec_loss_hop_length) == len(hps.multispec_loss_window_size)
    args = [hps.multispec_loss_n_fft,
            hps.multispec_loss_hop_length,
            hps.multispec_loss_window_size]
    for n_fft, hop_length, window_size in zip(*args):
        stft_hps = STFTValues(hps, n_fft, hop_length, window_size)  # use a local name rather than rebinding hps inside the loop
        spec_in = spec(squeeze(x_in.float()), stft_hps)
        spec_out = spec(squeeze(x_out.float()), stft_hps)
        losses.append(norm(spec_in - spec_out))
    return sum(losses) / len(losses)

def spectral_convergence(x_in, x_out, hps, epsilon=2e-3):
    hps = DefaultSTFTValues(hps)
    spec_in = spec(squeeze(x_in.float()), hps)
    spec_out = spec(squeeze(x_out.float()), hps)

    gt_norm = norm(spec_in)
    residual_norm = norm(spec_in - spec_out)
    mask = (gt_norm > epsilon).float()
    return (residual_norm * mask) / t.clamp(gt_norm, min=epsilon)


def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4):
    hps = DefaultSTFTValues(hps)
    spec_in = t.log(spec(squeeze(x_in.float()), hps) + epsilon)
    spec_out = t.log(spec(squeeze(x_out.float()), hps) + epsilon)
    return t.mean(t.abs(spec_in - spec_out))

def load_audio(file, sr, offset, duration, mono=False):
    # Librosa loads more filetypes than soundfile
    x, _ = librosa.load(file, sr=sr, mono=mono, offset=offset/sr, duration=duration/sr)
    if len(x.shape) == 1:
        x = x.reshape((1, -1))
    return x


def save_wav(fname, aud, sr):
    # clip before saving?
    aud = t.clamp(aud, -1, 1).cpu().numpy()
    for i in range(aud.shape[0]):
        soundfile.write(f'{fname}/item_{i}.wav', aud[i], samplerate=sr, format='wav')

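A minimal usage sketch for the loss helpers above (assuming they are in scope); the hyperparameter values and tensor shapes are illustrative assumptions, not values taken from this commit:

from types import SimpleNamespace
import torch as t

hps = SimpleNamespace(sr=22050,
                      multispec_loss_n_fft=(2048, 1024, 512),
                      multispec_loss_hop_length=(240, 120, 50),
                      multispec_loss_window_size=(1200, 600, 240))

x_in = t.randn(4, 32768, 1)                 # batch of mono waveforms, NTC layout
x_out = x_in + 0.01 * t.randn_like(x_in)    # stand-in for a decoder reconstruction

sl = spectral_loss(x_in, x_out, hps)        # per-item spectral L2 distance, shape (4,)
msl = multispectral_loss(x_in, x_out, hps)  # mean over the three STFT resolutions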
@@ -0,0 +1,101 @@
import os
from time import sleep
import torch
import models.utils.dist_adapter as dist


def print_once(msg):
    if (not dist.is_available()) or dist.get_rank() == 0:
        print(msg)


def print_all(msg):
    if (not dist.is_available()):
        print(msg)
    elif dist.get_rank() % 8 == 0:
        print(f'{dist.get_rank()//8}: {msg}')


def allgather(x):
    xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(xs, x)
    xs = torch.cat(xs, dim=0)
    return xs

def allreduce(x, op=dist.ReduceOp.SUM):
    x = torch.tensor(x).float().cuda()
    dist.all_reduce(x, op=op)
    return x.item()


def allgather_lists(xs):
    bs = len(xs)
    total_bs = dist.get_world_size() * len(xs)
    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long, device='cuda')  # this module imports torch, not t
    lengths = allgather(lengths)
    assert lengths.shape == (total_bs,)
    max_length = torch.max(lengths).item()

    xs = torch.tensor([[*x, *[0]*(max_length - len(x))] for x in xs], device='cuda')
    assert xs.shape == (bs, max_length), f'Expected {(bs, max_length)}, got {xs.shape}'
    xs = allgather(xs)
    assert xs.shape == (total_bs, max_length), f'Expected {(total_bs, max_length)}, got {xs.shape}'

    return [xs[i][:lengths[i]].cpu().numpy().tolist() for i in range(total_bs)]

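# Usage sketch (hedged; assumes the process group is already initialised and each
# rank has a CUDA device):
#
#   n_local = len(batch)                  # hypothetical per-rank value
#   n_global = allreduce(n_local)         # summed over all ranks, returned as a float
#   all_tokens = allgather_lists(tokens)  # gathers variable-length lists of ints from every rank
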
def setup_dist_from_mpi(
    master_addr="127.0.0.1", backend="nccl", port=29500, n_attempts=5, verbose=False
):
    if dist.is_available():
        return _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose)
    else:
        use_cuda = torch.cuda.is_available()
        print(f'Using cuda {use_cuda}')

        mpi_rank = 0
        local_rank = 0

        device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
        if use_cuda:
            torch.cuda.set_device(local_rank)  # only pin a GPU when one is actually available

        return mpi_rank, local_rank, device

def _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose):
    from mpi4py import MPI  # This must be imported in order to get errors from all ranks to show up

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    mpi_size = MPI.COMM_WORLD.Get_size()

    os.environ["RANK"] = str(mpi_rank)
    os.environ["WORLD_SIZE"] = str(mpi_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    os.environ["NCCL_LL_THRESHOLD"] = "0"
    os.environ["NCCL_NSOCKS_PERTHREAD"] = "2"
    os.environ["NCCL_SOCKET_NTHREADS"] = "8"

    # Pin this rank to a specific GPU on the node
    local_rank = mpi_rank % 8
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    if verbose:
        print(f"Connecting to master_addr: {master_addr}")

    # There is a race condition when initializing NCCL with a large number of ranks (e.g. 500 ranks)
    # We guard against the failure and then retry
    for attempt_idx in range(n_attempts):
        try:
            dist.init_process_group(backend=backend, init_method="env://")
            assert dist.get_rank() == mpi_rank

            use_cuda = torch.cuda.is_available()
            print(f'Using cuda {use_cuda}')
            local_rank = mpi_rank % 8
            device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
            if use_cuda:
                torch.cuda.set_device(local_rank)

            return mpi_rank, local_rank, device
        except RuntimeError as e:
            print(f"Caught error during NCCL init (attempt {attempt_idx} of {n_attempts}): {e}")
            sleep(1 + (0.01 * mpi_rank))  # Sleep to avoid thundering herd

    raise RuntimeError("Failed to initialize NCCL")
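
A hedged sketch of how a training script might initialise distributed state with these helpers; the module path is an assumption based on the imports above, not confirmed by this commit:

import models.utils.dist_adapter as dist
from models.utils.dist_utils import setup_dist_from_mpi, print_once  # assumed module path

rank, local_rank, device = setup_dist_from_mpi(port=29500)
world_size = dist.get_world_size() if dist.is_available() else 1
print_once(f"rank {rank}/{world_size} on device {device}")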
@@ -0,0 +1,94 @@
import torch
from torch._utils import _flatten_dense_tensors
import numpy as np


# EMA always in float, as accumulation needs lots of bits
class EMA:
    def __init__(self, params, mu=0.999):
        self.mu = mu
        self.state = [(p, self.get_model_state(p)) for p in params if p.requires_grad]

    def get_model_state(self, p):
        return p.data.float().detach().clone()

    def step(self):
        for p, state in self.state:
            state.mul_(self.mu).add_(p.data.float(), alpha=1 - self.mu)  # current add_ signature: add_(tensor, alpha=scalar)

    def swap(self):
        # swap ema and model params
        for p, state in self.state:
            other_state = self.get_model_state(p)
            p.data.copy_(state.type_as(p.data))
            state.copy_(other_state)

class CPUEMA:
    def __init__(self, params, mu=0.999, freq=1):
        self.mu = mu**freq
        self.state = [(p, self.get_model_state(p)) for p in params if p.requires_grad]
        self.freq = freq
        self.steps = 0

    def get_model_state(self, p):
        with torch.no_grad():
            state = p.data.float().detach().cpu().numpy()
        return state

    def step(self):
        with torch.no_grad():
            self.steps += 1
            if self.steps % self.freq == 0:
                for i in range(len(self.state)):
                    p, state = self.state[i]
                    state = torch.from_numpy(state).cuda()
                    state.mul_(self.mu).add_(p.data.float(), alpha=1 - self.mu)
                    self.state[i] = (p, state.cpu().numpy())

    def swap(self):
        with torch.no_grad():
            # swap ema and model params
            for p, state in self.state:
                other_state = self.get_model_state(p)
                p.data.copy_(torch.from_numpy(state).type_as(p.data))
                np.copyto(state, other_state)

class FusedEMA:
    def __init__(self, params, mu=0.999):
        self.mu = mu
        params = list(params)
        self.params = {}
        self.params['fp16'] = [p for p in params if p.requires_grad and p.data.dtype == torch.float16]
        self.params['fp32'] = [p for p in params if p.requires_grad and p.data.dtype != torch.float16]
        self.groups = [group for group in self.params.keys() if len(self.params[group]) > 0]
        self.state = {}
        for group in self.groups:
            self.state[group] = self.get_model_state(group)

    def get_model_state(self, group):
        params = self.params[group]
        return _flatten_dense_tensors([p.data.float() for p in params])
        # if self.fp16:
        #     return _flatten_dense_tensors([p.data.half() for p in self.param_group if p.dtype])
        # else:
        #     return _flatten_dense_tensors([p.data for p in self.param_group])

    def step(self):
        for group in self.groups:
            self.state[group].mul_(self.mu).add_(self.get_model_state(group), alpha=1 - self.mu)

    def swap(self):
        # swap ema and model params
        for group in self.groups:
            other_state = self.get_model_state(group)
            state = self.state[group]
            params = self.params[group]
            offset = 0
            for p in params:
                numel = p.data.numel()
                p.data = state.narrow(0, offset, numel).view_as(p.data).type_as(p.data)
                offset += numel

            self.state[group] = other_state

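A hedged sketch of the intended EMA workflow in a training loop; the model, optimizer, and loss below are placeholders, not code from this commit:

import torch

model = torch.nn.Linear(16, 16)                     # placeholder model
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
ema = EMA(model.parameters(), mu=0.999)

for _ in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # placeholder loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.step()    # update the shadow weights after each optimizer step

ema.swap()        # evaluate with EMA weights ...
# ... run validation here ...
ema.swap()        # ... then restore the live weights before resuming training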