This repository has been archived by the owner on Feb 12, 2022. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update requirements to conda
* update env
* add argument to change model
* update models
* remove model details as we now have multiple models
Showing 15 changed files with 291 additions and 452 deletions.
This file was deleted.
@@ -0,0 +1,49 @@
name: countnet
channels:
  - defaults
dependencies:
  - ca-certificates=2020.1.1
  - certifi=2020.4.5.2
  - intel-openmp=2019.4
  - libcxx=10.0.0
  - libedit=3.1.20191231
  - libffi=3.3
  - mkl=2019.4
  - mkl-service=2.3.0
  - ncurses=6.2
  - openssl=1.1.1g
  - pip=20.1.1
  - python=3.6.10
  - readline=8.0
  - setuptools=47.3.0
  - six=1.15.0
  - sqlite=3.32.2
  - tk=8.6.10
  - wheel=0.34.2
  - xz=5.2.5
  - zlib=1.2.11
  - pip:
    - audioread==2.1.8
    - backports-weakref==1.0rc1
    - bleach==1.5.0
    - cffi==1.14.0
    - decorator==4.4.2
    - h5py==2.10.0
    - html5lib==0.9999999
    - joblib==0.15.1
    - keras==1.2.2
    - librosa==0.7.2
    - llvmlite==0.32.1
    - markdown==2.2.0
    - numba==0.43.0
    - numpy==1.18.5
    - protobuf==3.12.2
    - pycparser==2.20
    - pyyaml==5.3.1
    - resampy==0.2.2
    - scikit-learn==0.22
    - scipy==1.4.1
    - soundfile==0.10.3.post1
    - theano==0.9.0
    - werkzeug==1.0.1
    - tqdm
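The environment above can be created with the standard conda workflow. The file name used below is an assumption (the diff view does not show it); the environment name countnet comes from the file itself:

conda env create -f environment.yml   # assumes the file is saved as environment.yml
conda activate countnet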
@@ -0,0 +1,101 @@
import numpy as np
import soundfile as sf
import argparse
import os
import keras
import sklearn
import glob
import predict
import json
from keras import backend as K

import tqdm

eps = np.finfo(np.float).eps


def mae(y, p):
    return np.mean([abs(a - b) for a, b in zip(p, y)])


def mae_by_count(y, p):
    diffs = []
    for c in range(0, int(np.max(y)) + 1):
        ind = np.where(y == c)
        diff = mae(y[ind], np.round(p[ind]))
        diffs.append(diff)

    return diffs


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Load keras model and predict speaker count'
    )
    parser.add_argument(
        'root',
        help='root dir to evaluation data set'
    )

    parser.add_argument(
        '--model', default='CRNN',
        help='model name'
    )

    args = parser.parse_args()

    # load model
    model = keras.models.load_model(
        os.path.join('models', args.model + '.h5'),
        custom_objects={
            'class_mae': predict.class_mae,
            'exp': K.exp
        }
    )

    # print model configuration
    model.summary()

    # load standardisation parameters
    scaler = sklearn.preprocessing.StandardScaler()
    with np.load(os.path.join("models", 'scaler.npz')) as data:
        scaler.mean_ = data['arr_0']
        scaler.scale_ = data['arr_1']

    input_files = glob.glob(os.path.join(
        args.root, 'test', '*.wav'
    ))

    y_trues = []
    y_preds = []

    for input_file in tqdm.tqdm(input_files):

        metadata_file = os.path.splitext(
            os.path.basename(input_file)
        )[0] + ".json"
        metadata_path = os.path.join(args.root, 'test', metadata_file)

        with open(metadata_path) as data_file:
            data = json.load(data_file)
        # add ground truth
        y_trues.append(len(data))

        # compute audio
        audio, rate = sf.read(input_file, always_2d=True)

        # downmix to mono
        audio = np.mean(audio, axis=1)

        count = predict.count(audio, model, scaler)
        # add prediction
        y_preds.append(count)

    y_preds = np.array(y_preds)
    y_trues = np.array(y_trues)

    mae_k = mae_by_count(y_trues, y_preds)
    print("MAE per Count: ", {k: v for k, v in enumerate(mae_k)})
    print("Mean MAE", mae(y_trues, y_preds))
Binary files not shown (5 files).
This file was deleted.
@@ -0,0 +1,88 @@
import numpy as np
import soundfile as sf
import argparse
import os
import keras
import sklearn
import librosa
from keras import backend as K


eps = np.finfo(np.float).eps


def class_mae(y_true, y_pred):
    return K.mean(
        K.abs(
            K.argmax(y_pred, axis=-1) - K.argmax(y_true, axis=-1)
        ),
        axis=-1
    )


def count(audio, model, scaler):
    # compute STFT
    X = np.abs(librosa.stft(audio, n_fft=400, hop_length=160)).T

    # apply global (featurewise) standardization to mean 0, var 1
    X = scaler.transform(X)

    # cut to input shape length (500 frames x 201 STFT bins)
    X = X[:500, :]

    # apply l2 normalization
    Theta = np.linalg.norm(X, axis=1) + eps
    X /= np.mean(Theta)

    # add sample dimension
    X = X[np.newaxis, ...]

    if len(model.input_shape) == 4:
        X = X[:, np.newaxis, ...]

    ys = model.predict(X, verbose=0)
    return np.argmax(ys, axis=1)[0]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Load keras model and predict speaker count'
    )

    parser.add_argument(
        'audio',
        help='audio file (samplerate 16 kHz) of 5 seconds duration'
    )

    parser.add_argument(
        '--model', default='CRNN',
        help='model name'
    )

    args = parser.parse_args()

    # load model
    model = keras.models.load_model(
        os.path.join('models', args.model + '.h5'),
        custom_objects={
            'class_mae': class_mae,
            'exp': K.exp
        }
    )

    # print model configuration
    model.summary()

    # load standardisation parameters
    scaler = sklearn.preprocessing.StandardScaler()
    with np.load(os.path.join("models", 'scaler.npz')) as data:
        scaler.mean_ = data['arr_0']
        scaler.scale_ = data['arr_1']

    # read audio
    audio, rate = sf.read(args.audio, always_2d=True)

    # downmix to mono
    audio = np.mean(audio, axis=1)
    estimate = count(audio, model, scaler)
    print("Speaker Count Estimate: ", estimate)
8e30524
Accidentally deleted requirements.txt?
8e30524
I went with conda only, but if you think it's helpful, I can revert that.
More importantly, if you have any idea how to go beyond keras+theano without losing performance, let me know ;-)
8e30524
It breaks the Docker build, so I think it's useful :-) I'll send a PR once I have it fixed. Some requirements need updating.
8e30524
Right, I didn't check that.