Skip to content

Commit

Permalink
update spectrogram params
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Feb 26, 2024
1 parent 45ebd80 commit dee6cc5
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 16 deletions.
Binary file modified amt/assets/mel_filters.npz
Binary file not shown.
4 changes: 2 additions & 2 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
assert n_mels in {80, 128, 256}, f"Unsupported n_mels: {n_mels}"

filters_path = os.path.join(
os.path.dirname(__file__), "assets", "mel_filters.npz"
Expand All @@ -127,7 +127,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:

def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
n_mels: int = 256,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
Expand Down
18 changes: 10 additions & 8 deletions amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
N_FRAMES,
)

config = load_config()["data"]
STRIDE_FACTOR = config["stride_factor"]
config = load_config()
STRIDE_FACTOR = config["data"]["stride_factor"]


def setup_logger():
Expand Down Expand Up @@ -47,6 +47,7 @@ def get_features(audio_path: str, mid_path: str = ""):
then it will return an empty list for the mid_feature
"""
tokenizer = AmtTokenizer()
n_mels = config["audio"]["n_mels"]

if not os.path.isfile(audio_path):
return None
Expand All @@ -57,7 +58,7 @@ def get_features(audio_path: str, mid_path: str = ""):
return None

try:
log_spec = log_mel_spectrogram(audio=audio_path)
log_spec = log_mel_spectrogram(audio=audio_path, n_mels=n_mels)
if mid_path != "":
midi_dict = MidiDict.from_midi(mid_path)
else:
Expand Down Expand Up @@ -162,11 +163,12 @@ def build(
cls,
matched_load_paths: list[tuple[str, str]],
save_path: str,
num_processes: int = 4,
num_processes: int = 1,
):
def _get_features(_matched_load_paths: list):
num_paths = len(_matched_load_paths)
for idx, entry in enumerate(_matched_load_paths):
print(idx)
success, res = get_features_mp(entry)
if idx % 10 == 0 and idx != 0:
print(f"Processed audio-mid pairs: {idx}/{num_paths}")
Expand All @@ -175,10 +177,8 @@ def _get_features(_matched_load_paths: list):
for _audio_feature, _mid_feature in res:
yield _audio_feature.tolist(), _mid_feature

# MP CODE DOESN'T WORK FOR SOME REASON !!

# with Pool(num_processes) as pool:
# results = pool.imap(get_features_mp, _matched_load_paths)
# results = pool.imap_unordered(get_features_mp, _matched_load_paths)
# num_paths = len(_matched_load_paths)
# for idx, (success, res) in enumerate(results):
# if idx % 10 == 0 and idx != 0:
Expand All @@ -191,4 +191,6 @@ def _get_features(_matched_load_paths: list):

with jsonlines.open(save_path, mode="w") as writer:
for audio_feature, mid_feature in _get_features(matched_load_paths):
writer.write([audio_feature, mid_feature])
# writer.write([audio_feature, mid_feature])
print(len(mid_feature))
writer.write(mid_feature)
7 changes: 4 additions & 3 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
},
"audio": {
"sample_rate": 16000,
"n_fft": 400,
"n_fft": 2048,
"hop_len": 160,
"chunk_len": 30
"chunk_len": 30,
"n_mels": 256
},
"data": {
"stride_factor": 3,
"stride_factor": 1,
"max_seq_len": 4096
}
}
2 changes: 1 addition & 1 deletion config/models/medium.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"n_mels": 80,
"n_mels": 256,
"n_audio_ctx": 1500,
"n_audio_state": 512,
"n_audio_head": 8,
Expand Down
2 changes: 1 addition & 1 deletion config/models/small.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"n_mels": 80,
"n_mels": 256,
"n_audio_ctx": 1500,
"n_audio_state": 384,
"n_audio_head": 6,
Expand Down
2 changes: 1 addition & 1 deletion config/models/test.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"n_mels": 80,
"n_mels": 256,
"n_audio_ctx": 1500,
"n_audio_state": 64,
"n_audio_head": 4,
Expand Down

0 comments on commit dee6cc5

Please sign in to comment.