Skip to content

Commit

Permalink
remote changes
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Apr 9, 2024
1 parent 59dbb4c commit 1f423d8
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 132 deletions.
21 changes: 13 additions & 8 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
noise_ratio: float = 0.75,
reverb_ratio: float = 0.75,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.5,
detune_ratio: float = 0.0,
detune_max_shift: float = 0.0,
spec_aug_ratio: float = 0.9,
):
super().__init__()
self.tokenizer = AmtTokenizer()
Expand All @@ -223,7 +223,10 @@ def __init__(
self.detune_ratio = detune_ratio
self.detune_max_shift = detune_max_shift
self.spec_aug_ratio = spec_aug_ratio
self.reduction_resample_rate = 6000 # Hardcoded?

self.time_mask_param = 2500
self.freq_mask_param = 15
self.reduction_resample_rate = 6000

# Audio aug
impulse_paths = self._get_paths(
Expand Down Expand Up @@ -263,10 +266,10 @@ def __init__(
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=15, iid_masks=True
freq_mask_param=self.freq_mask_param, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=1000, iid_masks=True
time_mask_param=self.time_mask_param, iid_masks=True
),
)

Expand All @@ -281,6 +284,8 @@ def get_params(self):
"detune_ratio": self.detune_ratio,
"detune_max_shift": self.detune_max_shift,
"spec_aug_ratio": self.spec_aug_ratio,
"time_mask_param": self.time_mask_param,
"freq_mask_param": self.freq_mask_param,
}

def _get_paths(self, dir_path):
Expand Down
58 changes: 37 additions & 21 deletions amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
from amt.audio import pad_or_trim


# Occasionally the worker util goes to 0 for some reason, debug this
def _check_onset_threshold(seq: list, onset: int):
for tok_1, tok_2 in zip(seq, seq[1:]):
if isinstance(tok_1, tuple) and tok_1[0] in ("on", "off"):
_onset = tok_2[1]
if _onset > onset:
return True

return False


def get_wav_mid_segments(
Expand Down Expand Up @@ -80,6 +87,12 @@ def get_wav_mid_segments(
end_ms=(idx + num_samples) / samples_per_ms,
max_pedal_len_ms=10000,
)

# Hardcoded to 2.5s
if _check_onset_threshold(mid_feature, 2500) is False:
print("No note messages after 2.5s - skipping")
continue

else:
mid_feature = []

Expand Down Expand Up @@ -136,7 +149,7 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
safe_wav_path = shlex.quote(wav_path)

# Construct the command
command = f"/home/mchorse/amt/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}"
command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}"

return command

Expand Down Expand Up @@ -192,23 +205,22 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
if os.path.isfile(audio_path_temp):
os.remove(audio_path_temp)

print(f"Found {len(features)}")

with open(save_path, mode="a") as file:
try:
for wav, seq in features:
wav_buffer = io.BytesIO()
torch.save(wav, wav_buffer)
wav_buffer.seek(0)
wav_bytes = wav_buffer.read()
wav_str = base64.b64encode(wav_bytes).decode("utf-8")
file.write(wav_str)
file.write("\n")

seq_bytes = orjson.dumps(seq)
seq_str = base64.b64encode(seq_bytes).decode("utf-8")
file.write(seq_str)
file.write("\n")
except Exception as e:
return
for wav, seq in features:
wav_buffer = io.BytesIO()
torch.save(wav, wav_buffer)
wav_buffer.seek(0)
wav_bytes = wav_buffer.read()
wav_str = base64.b64encode(wav_bytes).decode("utf-8")
file.write(wav_str)
file.write("\n")

seq_bytes = orjson.dumps(seq)
seq_str = base64.b64encode(seq_bytes).decode("utf-8")
file.write(seq_str)
file.write("\n")


def build_worker_fn(load_path_queue, save_path_queue, _save_path: str):
Expand All @@ -234,7 +246,11 @@ def build_synth_worker_fn(

while not load_path_queue.empty():
mid_path = load_path_queue.get()
write_synth_features(cli_cmd, mid_path, worker_save_path)
try:
write_synth_features(cli_cmd, mid_path, worker_save_path)
except Exception as e:
print("Failed")
print(e)

save_path_queue.put(worker_save_path)

Expand Down Expand Up @@ -299,7 +315,7 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx

def _build_index(self):
self.file_mmap.seek(0)
Expand All @@ -314,7 +330,7 @@ def _build_index(self):

return index

def _save_index(self, index: list[int], save_path: str):
def _save_index(self, index: list, save_path: str):
with open(save_path, "w") as file:
for idx in index:
file.write(f"{idx}\n")
Expand Down
20 changes: 10 additions & 10 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ def process_segments(
[MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int
).cuda()

for idx in (
pbar := tqdm(
range(min_prefix_len, MAX_BLOCK_LEN - 1),
total=MAX_BLOCK_LEN - (min_prefix_len + 1),
leave=False,
)
):
# for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1):
# for idx in (
# pbar := tqdm(
# range(min_prefix_len, MAX_BLOCK_LEN - 1),
# total=MAX_BLOCK_LEN - (min_prefix_len + 1),
# leave=False,
# )
# ):
for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1):
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False, enable_math=True
):
Expand Down Expand Up @@ -274,7 +274,7 @@ def gpu_manager(
decode_token = torch.compile(
decode_token,
# mode="reduce-overhead",
mode="max-autotune",
# mode="max-autotune",
fullgraph=True,
)

Expand Down Expand Up @@ -743,7 +743,7 @@ def batch_transcribe(
for p in gpu_manager_processes:
p.start()
watchdog_process = multiprocessing.Process(
target=watchdog, args=(gpu_batch_manager_process[0].pid, child_pids)
target=watchdog, args=(gpu_batch_manager_process.pid, child_pids)
)
watchdog_process.start()
else:
Expand Down
14 changes: 7 additions & 7 deletions amt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,16 @@ def build_maestro(maestro_dir, train_file, val_file, test_file, num_procs):

print(f"Building {train_file}")
AmtDataset.build(
load_paths=matched_paths_train,
load_paths=matched_paths_train + matched_paths_val,
save_path=train_file,
num_processes=num_procs,
)
print(f"Building {val_file}")
AmtDataset.build(
load_paths=matched_paths_val,
save_path=val_file,
num_processes=num_procs,
)
# print(f"Building {val_file}")
# AmtDataset.build(
# load_paths=matched_paths_val,
# save_path=val_file,
# num_processes=num_procs,
# )
print(f"Building {test_file}")
AmtDataset.build(
load_paths=matched_paths_test,
Expand Down
Loading

0 comments on commit 1f423d8

Please sign in to comment.