This repository contains the official PyTorch implementation for the paper [1]:
- Tal Peer, Simon Welker, Timo Gerkmann. "DiffPhase: Generative Diffusion-based STFT Phase Retrieval", 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Rhodes Island, Greece, Jun. 2023.
Audio examples are available on our project page.
DiffPhase is an adaptation of the SGMSE+ diffusion-based speech enhancement method to phase retrieval. SGMSE+ is described in [2] and [3] and has its own repository.
- Clone this repository along with the sgmse repository, which is included as a submodule:
git clone --recurse-submodules https://github.com/sp-uhh/diffphase.git
- Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
- Install the package dependencies via
pip install -r requirements.txt
- If using W&B logging (default):
  - Set up a wandb.ai account.
  - Log in via
    wandb login
    before running our code.
- If not using W&B logging:
  - Pass the option --no_wandb to train.py.
  - Your logs will be stored as local TensorBoard logs. Run
    tensorboard --logdir logs/
    to see them.
We provide two pretrained checkpoints:
- DiffPhase, using the default SGMSE configuration. This model has ~65M parameters.
- DiffPhase-small, with ~22M parameters.
Usage:
- For resuming training, use the --resume_from_checkpoint option of train.py.
- For performing phase reconstruction with these checkpoints, use the --ckpt option of reconstruct.py (see the Evaluation section below).
Training is done by executing train.py. A minimal running example with default settings (as in our paper [1]) can be run with
python train.py --base_dir <your_base_dir>
where <your_base_dir> should be a path to a folder containing the subdirectories train/ and valid/. Each subdirectory must itself contain a directory clean/. We currently only support training with .wav files sampled at 16 kHz.
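Before launching train.py, it can help to check that the folder passed as --base_dir matches this layout. The following is a minimal sketch, not part of the repository: check_base_dir is a hypothetical helper that assumes the train/clean and valid/clean layout described above and only inspects .wav files directly inside each clean/ folder. It uses Python's standard wave module, which reads only PCM WAV files, so files in other encodings (e.g. 32-bit float) are reported as unreadable even if train.py itself could load them.

```python
"""Hypothetical sanity check for a DiffPhase training folder (not part of the repository).

Assumes the layout <base_dir>/{train,valid}/clean/*.wav with every file at 16 kHz.
"""
import wave
from pathlib import Path

EXPECTED_RATE = 16_000  # training currently only supports 16 kHz .wav files
SPLITS = ("train", "valid")


def check_base_dir(base_dir):
    """Return a list of human-readable problems; an empty list means the layout looks OK."""
    problems = []
    base = Path(base_dir)
    for split in SPLITS:
        clean_dir = base / split / "clean"
        if not clean_dir.is_dir():
            problems.append(f"missing directory: {clean_dir}")
            continue
        # Only files directly inside clean/ are checked (non-recursive).
        wav_files = sorted(clean_dir.glob("*.wav"))
        if not wav_files:
            problems.append(f"no .wav files in {clean_dir}")
        for path in wav_files:
            try:
                # The stdlib wave module only understands PCM WAV files.
                with wave.open(str(path), "rb") as wav:
                    rate = wav.getframerate()
            except (wave.Error, EOFError) as err:
                problems.append(f"unreadable wav file {path}: {err}")
                continue
            if rate != EXPECTED_RATE:
                problems.append(
                    f"{path}: sampled at {rate} Hz, expected {EXPECTED_RATE} Hz"
                )
    return problems
```

For example, print(check_base_dir("/path/to/your_base_dir")) should print an empty list for a folder that matches the expected layout.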
For the DiffPhase-small variant, use the following options:
python train.py --num_res_blocks 1 --attn_resolutions 0 --ch_mult 1 1 2 2 1 --base_dir <your_base_dir>
To see all available training options, run python train.py --help. Also see the sgmse repository for more information.
We provide an example script, reconstruct.py, that takes a .wav file as input, discards its phase, and writes the reconstructed signal to another .wav file. Reconstruction follows the same procedure described in our paper. To use it, run
python reconstruct.py --input <input_wav> --output <reconstructed_wav> --ckpt <path_to_model_checkpoint> --N <number_of_reverse_steps>
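To reconstruct a whole folder of files, a thin wrapper that calls reconstruct.py once per file may be enough. The sketch below is hypothetical: build_commands and run_all are illustrative helpers, not part of the repository, and they pass only the four flags shown above (--input, --output, --ckpt, --N). Output files keep the names of their inputs.

```python
"""Hypothetical batch wrapper around reconstruct.py (not part of the repository)."""
import subprocess
import sys
from pathlib import Path


def build_commands(input_dir, output_dir, ckpt, num_steps, script="reconstruct.py"):
    """Return one reconstruct.py argument list per .wav file in input_dir."""
    out_dir = Path(output_dir)
    commands = []
    for wav_path in sorted(Path(input_dir).glob("*.wav")):
        commands.append([
            sys.executable, script,
            "--input", str(wav_path),
            "--output", str(out_dir / wav_path.name),  # same file name in output_dir
            "--ckpt", str(ckpt),
            "--N", str(num_steps),  # number of reverse diffusion steps
        ])
    return commands


def run_all(input_dir, output_dir, ckpt, num_steps):
    """Run the calls sequentially; check=True stops at the first failing call."""
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    for cmd in build_commands(input_dir, output_dir, ckpt, num_steps):
        subprocess.run(cmd, check=True)
```

Because each call starts a new process, the checkpoint is loaded again for every file. This keeps the wrapper independent of the script's internals, at the cost of speed on large folders.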
We kindly ask you to cite our paper in your publication when using any of our research or code:
@inproceedings{peerDiffPhase2023,
title = {{DiffPhase: Generative Diffusion-based STFT Phase Retrieval}},
booktitle = {{2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}},
author = {Peer, Tal and Welker, Simon and Gerkmann, Timo},
date = {2023-06},
doi = {10.1109/ICASSP49357.2023.10095396}
}
[1] Tal Peer, Simon Welker, Timo Gerkmann. "DiffPhase: Generative Diffusion-based STFT Phase Retrieval", 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Rhodes Island, Greece, Jun. 2023.
[2] Simon Welker, Julius Richter, Timo Gerkmann. "Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain", ISCA Interspeech, Incheon, Korea, Sep. 2022.
[3] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.