forked from beefoo/media-tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
samples_to_fingerprints.py
129 lines (112 loc) · 4.09 KB
/
samples_to_fingerprints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -*- coding: utf-8 -*-
# https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
import argparse
import os
import sys
from collections import defaultdict
from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
from pprint import pprint

import audioread
import librosa
import numpy as np
from skimage.measure import block_reduce

from lib.audio_utils import *
from lib.cache_utils import *
from lib.collection_utils import *
from lib.io_utils import *
from lib.math_utils import *
from lib.processing_utils import *
# Command-line input
parser = argparse.ArgumentParser()
parser.add_argument('-in', dest="INPUT_FILE", default="tmp/samples.csv", help="Input file")
# fix: help text said "Input file" (copy-paste from -in); this is a directory
parser.add_argument('-dir', dest="AUDIO_DIRECTORY", default="media/sample/", help="Input directory containing the audio files")
parser.add_argument('-out', dest="OUTPUT_FILE", default="tmp/features.p", help="Output file")
parser.add_argument('-cellw', dest="CELL_W", default=32, type=int, help="Width of each cell")
parser.add_argument('-cellh', dest="CELL_H", default=32, type=int, help="Height of each cell")
parser.add_argument('-threads', dest="THREADS", default=4, type=int, help="Number of threads")
parser.add_argument('-log', dest="USE_LOG", action="store_true", help="Use log for fingerprint?")
a = parser.parse_args()
# Read sample rows and prefix each with its source audio path
fieldNames, rows = readCsv(a.INPUT_FILE)
rowCount = len(rows)
rows = addIndices(rows)
rows = prependAll(rows, ("filename", a.AUDIO_DIRECTORY))

# Make sure output dirs exist
makeDirectories(a.OUTPUT_FILE)

# Group samples by their source audio file so each file is loaded only once.
# A single pass with a dict is O(n); the previous approach re-filtered every
# row for each unique filename (O(n * unique files)).
print("Matching samples to files...")
samplesByFile = defaultdict(list)
for row in rows:
    samplesByFile[row["filename"]].append(row)
params = [{
    "samples": samples,
    "filename": fn
} for fn, samples in samplesByFile.items()]
fileCount = len(params)
progress = 0
# Adapted from: https://github.com/kylemcdonald/AudioNotebooks/blob/master/Samples%20to%20Fingerprints.ipynb
def getFingerPrint(y, sr, start, dur, n_fft=2048, hop_length=512, window=None, use_logamp=False):
    """Compute a (a.CELL_H x a.CELL_W) spectrogram fingerprint for one sample.

    y: full decoded audio buffer; sr: its sample rate
    start, dur: sample position and length in milliseconds
    window: optional STFT window array; defaults to a Hann window of n_fft
    use_logamp: convert power to dB before normalizing
    Returns a 2D numpy array normalized to [0, 1], low frequencies on the
    bottom row; an all-zeros array when the audio buffer is empty.
    """
    global a
    if len(y) < 1:
        return np.zeros((a.CELL_H, a.CELL_W))
    # take at most one second
    dur = min(dur, 1000)
    # analyze just the sample
    i0 = int(round(start / 1000.0 * sr))
    i1 = int(round((start+dur) / 1000.0 * sr))
    y = y[i0:i1]
    reduce_rows = 10 # how many frequency bands to average into one
    reduce_cols = 1 # how many time steps to average into one
    crop_rows = a.CELL_H # limit how many frequency bands to use
    crop_cols = a.CELL_W # limit how many time steps to use
    # BUG FIX: `if not window:` raises "truth value of an array is ambiguous"
    # whenever a caller actually passes a numpy window array; test for None.
    if window is None:
        window = np.hanning(n_fft)
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, window=window)
    amp = np.abs(S)
    if reduce_rows > 1 or reduce_cols > 1:
        amp = block_reduce(amp, (reduce_rows, reduce_cols), func=np.mean)
    # pad both axes so short samples still produce a full crop_rows x crop_cols
    # cell (the original only padded columns)
    if amp.shape[1] < crop_cols:
        amp = np.pad(amp, ((0, 0), (0, crop_cols-amp.shape[1])), 'constant')
    if amp.shape[0] < crop_rows:
        amp = np.pad(amp, ((0, crop_rows-amp.shape[0]), (0, 0)), 'constant')
    amp = amp[:crop_rows, :crop_cols]
    if use_logamp:
        amp = librosa.amplitude_to_db(amp**2)
    # normalize to [0, 1]
    amp -= amp.min()
    if amp.max() > 0:
        amp /= amp.max()
    amp = np.flipud(amp) # for visualization, put low frequencies on bottom
    return amp
def processFile(p):
    """Load one audio file and fingerprint all of its samples.

    p: dict with "filename" (audio path) and "samples" (rows with "index",
       "start", "dur").
    Returns a list of {"index", "fingerprint"} dicts, one per sample.
    Also advances the module-level progress counter for display.
    """
    global progress
    global rowCount
    global a
    fingerprints = []
    # load audio
    fn = getAudioFile(p["filename"])
    try:
        y, sr = loadAudioData(fn)
    except Exception as e:
        # BUG FIX: the original caught audioread.macca.MacError, but the
        # `macca` backend module only exists on macOS — on other platforms a
        # load failure raised AttributeError inside the except clause.
        # Catch broadly, report, and fall back to empty audio (getFingerPrint
        # returns zeros) so sample indices stay aligned for the whole run.
        print("Error loading %s: %s" % (fn, e))
        y = []
        sr = 48000
    for sample in p["samples"]:
        fingerprint = getFingerPrint(y, sr, sample["start"], sample["dur"], use_logamp=a.USE_LOG)
        fingerprints.append({
            "index": sample["index"],
            "fingerprint": fingerprint
        })
        # not lock-protected; fine for display-only progress under a ThreadPool
        progress += 1
        printProgress(progress, rowCount)
    return fingerprints
print("Processing fingerprints...")
data = []
if a.THREADS == 1:
    # BUG FIX: the original discarded each processFile() result here, so the
    # single-threaded path saved an empty fingerprint list.
    for p in params:
        data.append(processFile(p))
else:
    threads = getThreadCount(a.THREADS)
    pool = ThreadPool(threads)
    data = pool.map(processFile, params)
    pool.close()
    pool.join()
# one result list per file -> one flat list of {"index", "fingerprint"} dicts
data = flattenList(data)
# files may finish in any order; restore the original sample order
data = sorted(data, key=lambda d: d["index"])
fingerprints = [d["fingerprint"] for d in data]
saveCacheFile(a.OUTPUT_FILE, fingerprints, overwrite=True)
print("Done.")