NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling @ INTERSPEECH 2021

maum-ai/nuwave

NU-Wave — Official PyTorch Implementation

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling
Junhyeok Lee, Seungu Han @ MINDsLab Inc., SNU

Official PyTorch + Lightning implementation of NU-Wave.

Update: fixed a typo in lightning_model.py, line 36 (10 --> 20).
Errata have been added for the ISCA Archive and arXiv versions.

Checkpoint contribution: thanks to freds0, who released his checkpoint in issue #18!

Official checkpoints for the single-speaker setting have been released on Google Drive.

Since the NU-Wave 2 repository is now open, we will try to handle issues in the new repository.

NU-Wave 2 has been accepted to Interspeech 2022! Code and checkpoints are available on GitHub!

Requirements

Preprocessing

Before running the project, download the dataset and preprocess it into .pt files:

  1. Download the VCTK dataset.
  2. Remove speakers p280 and p315.
  3. Set data:dir in hparameter.yaml to the path of the downloaded dataset.
  4. Run utils/wav2pt.py:
python utils/wav2pt.py
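Before running wav2pt.py, it can help to confirm that step 2 took effect. The sketch below is a hypothetical helper (list_speaker_files is not part of this repository) that walks one sub-directory per speaker, following the dir/spk/format layout noted in the data section of hparameter.yaml, and skips p280 and p315.

```python
from pathlib import Path

# Speakers that this README says to remove before preprocessing.
EXCLUDED_SPEAKERS = {"p280", "p315"}


def list_speaker_files(root, pattern="*.wav", excluded=EXCLUDED_SPEAKERS):
    """Return sorted paths matching root/<speaker>/<pattern>, skipping excluded speakers.

    Hypothetical helper: assumes one sub-directory per VCTK speaker,
    as in the dir/spk/format layout of hparameter.yaml.
    """
    files = []
    for spk_dir in sorted(Path(root).iterdir()):
        # Ignore stray files and the speakers that should have been removed.
        if not spk_dir.is_dir() or spk_dir.name in excluded:
            continue
        files.extend(sorted(spk_dir.glob(pattern)))
    return files
```

For example, list_speaker_files('/DATA1/VCTK/VCTK-Corpus/wav48', '*mic1.pt') would list the files matched by the '*mic1.pt' format after preprocessing, assuming the layout from the data section.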

Training

  1. Adjust hparameter.yaml, especially the train section.
train:
  batch_size: 18 # Dependent on GPU memory size
  lr: 0.00003
  weight_decay: 0.00
  num_workers: 64 # Dependent on CPU cores
  gpus: 2 # number of GPUs
  opt_eps: 1e-9
  beta1: 0.5
  beta2: 0.999
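The optimizer fields in the train section (lr, weight_decay, opt_eps, beta1, beta2) read like the arguments of an Adam-style optimizer. The sketch below assumes that mapping (it is not copied from lightning_model.py) and collects them into keyword arguments. It also coerces every value with float(), because if the file is loaded with PyYAML's default resolver, an exponent written without a decimal point, such as 1e-9, comes back as the string "1e-9".

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    """Mirror of the train section of hparameter.yaml (field names and defaults copied from it)."""
    batch_size: int = 18
    lr: float = 0.00003
    weight_decay: float = 0.0
    num_workers: int = 64
    gpus: int = 2
    opt_eps: float = 1e-9
    beta1: float = 0.5
    beta2: float = 0.999


def adam_kwargs(cfg):
    """Collect optimizer arguments, assuming the fields map onto an Adam-style optimizer.

    float() guards against values such as "1e-9" that a YAML loader may return as strings.
    """
    return {
        "lr": float(cfg.lr),
        "betas": (float(cfg.beta1), float(cfg.beta2)),
        "eps": float(cfg.opt_eps),
        "weight_decay": float(cfg.weight_decay),
    }
```

With PyTorch installed, the result could be passed as torch.optim.Adam(model.parameters(), **adam_kwargs(cfg)); whether the repository uses exactly this optimizer is an assumption.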
  • To train on a single speaker, use VCTKSingleSpkDataset instead of VCTKMultiSpkDataset as the dataset in dataloader.py, and set batch_size=1 for the validation dataloader.
  • Adjust the data section in hparameter.yaml accordingly.
data:
  dir: '/DATA1/VCTK/VCTK-Corpus/wav48/p225' #dir/spk/format
  format: '*mic1.pt'
  cv_ratio: (223./231., 8./231., 0.00) #train/val/test
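cv_ratio gives the train/validation/test fractions. How dataloader.py turns it into a split is not shown here, so the hypothetical helper split_counts below is only one plausible reading: it rounds the train and validation shares to whole files and gives the remainder to test. With 231 files (the denominator used in the ratios), it reproduces the 223/8/0 split.

```python
def split_counts(n_files, cv_ratio):
    """Turn (train, val, test) fractions into file counts.

    Hypothetical helper: train and val are rounded to the nearest integer,
    val is capped so the total never exceeds n_files, and test takes the
    remainder (the test fraction itself is not used).
    """
    r_train, r_val, _r_test = cv_ratio
    n_train = min(round(n_files * r_train), n_files)
    n_val = min(round(n_files * r_val), n_files - n_train)
    n_test = n_files - n_train - n_val
    return n_train, n_val, n_test


print(split_counts(231, (223. / 231., 8. / 231., 0.00)))  # (223, 8, 0)
```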
  2. Run trainer.py.
$ python trainer.py
  • To resume training from a checkpoint, check the argument parser:
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int,
                        required=False, help="Resume Checkpoint epoch number")
    parser.add_argument('-s', '--restart', action="store_true",
                        required=False, help="Significant change occurred, use this")
    parser.add_argument('-e', '--ema', action="store_true",
                        required=False, help="Start from ema checkpoint")
    args = parser.parse_args()
  • During training, the TensorBoard logger records the loss, spectrograms, and audio. To view them, run:
$ tensorboard --logdir=./tensorboard --bind_all
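The resume options can be checked without launching training. The sketch below wraps the same three options from the training parser in a hypothetical build_train_parser function and feeds it an explicit argument list to show what each flag sets; the epoch number 120 is only an example, and how trainer.py turns these values into a checkpoint path is not covered here.

```python
import argparse


def build_train_parser():
    # Same options as the training parser in trainer.py's resume instructions.
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int, required=False,
                        help="Resume Checkpoint epoch number")
    parser.add_argument('-s', '--restart', action="store_true", required=False,
                        help="Significant change occurred, use this")
    parser.add_argument('-e', '--ema', action="store_true", required=False,
                        help="Start from ema checkpoint")
    return parser


# Equivalent to: python trainer.py -r 120 -e
args = build_train_parser().parse_args(['-r', '120', '-e'])
print(args.resume_from, args.restart, args.ema)  # 120 False True
```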

Evaluation

Run for_test.py or test.py:

python test.py -r {checkpoint_number} [-e] [--save]
or
python for_test.py -r {checkpoint_number} [-e] [--save]

Add -e to load the EMA checkpoint and --save to save the output files.

See the argument parser for details:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int,
                        required=True, help="Resume Checkpoint epoch number")
    parser.add_argument('-e', '--ema', action="store_true",
                        required=False, help="Start from ema checkpoint")
    parser.add_argument('--save', action="store_true",
                        required=False, help="Save file")
    args = parser.parse_args()

We also provide Lightning-style test code in test.py, but it has a device dependency, so we recommend using for_test.py instead.

References

This implementation uses code from the following repositories:

This README and the webpage for the audio samples are inspired by:

The audio samples on our webpage are partially derived from:

Repository Structure

.
├── Dockerfile
├── dataloader.py           # Dataloader for train/val(=test)
├── filters.py              # Filter implementation
├── test.py                 # Test with lightning_loop.
├── for_test.py             # Test with for_loop. Recommended due to device dependency of lightning
├── hparameter.yaml         # Config
├── lightning_model.py      # NU-Wave implementation. DDPM is based on ivanvovk's WaveGrad implementation
├── model.py                # NU-Wave model based on lmnt-com's DiffWave implementation
├── requirement.txt         # Required libraries
├── sampling.py             # Sampling a file
├── trainer.py              # Lightning trainer
├── README.md           
├── LICSENSE
├── utils
│  ├── stft.py              # STFT layer
│  ├── tblogger.py          # Tensorboard Logger for lightning
│  └── wav2pt.py            # Preprocessing
└── docs                    # For github.io
   └─ ...
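In the tree above, lightning_model.py holds the diffusion (DDPM) part of NU-Wave. For readers new to DDPMs, the sketch below illustrates the generic forward (noising) process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, in plain Python. The linear beta schedule and its endpoints are placeholders rather than NU-Wave's hyperparameters, and NU-Wave's actual noise-level sampling (inherited from WaveGrad) may differ.

```python
import math
import random


def linear_betas(n_steps, beta_start, beta_end):
    """Linearly spaced noise variances beta_1..beta_T (placeholder schedule)."""
    if n_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (n_steps - 1)
    return [beta_start + i * step for i in range(n_steps)]


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0, alpha_bar, rng=random):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar) * x_0, (1 - alpha_bar) * I).

    Returns the noisy signal and the Gaussian noise used, which a DDPM
    denoiser is trained to predict.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    xt = [a * x + b * e for x, e in zip(x0, eps)]
    return xt, eps


betas = linear_betas(1000, 1e-4, 0.02)  # hypothetical schedule values
abar = alpha_bars(betas)
xt, eps = q_sample([0.1, -0.2, 0.3], abar[500], random.Random(0))
```

As t grows, alpha_bar_t shrinks toward zero, so x_t moves from the clean waveform toward pure Gaussian noise; the model learns to reverse this step by step.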

Citation & Contact

If this repository is useful for your research, please consider citing it!

@inproceedings{lee21nuwave,
  author={Junhyeok Lee and Seungu Han},
  title={{NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling}},
  year=2021,
  booktitle={Proc. Interspeech 2021},
  pages={1634--1638},
  doi={10.21437/Interspeech.2021-36}
}

If you have any questions or inquiries, please contact Junhyeok Lee at jun3518@mindslab.ai.