From c2d939af91933ab4cdd01090d64b6ac42765b500 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 27 Apr 2024 09:48:52 +0000 Subject: [PATCH 1/3] update README --- README.md | 13 +++++++++---- amt/data.py | 3 +++ scripts/{eval => }/split.py | 7 ++++--- 3 files changed, 16 insertions(+), 7 deletions(-) rename scripts/{eval => }/split.py (89%) diff --git a/README.md b/README.md index 831948c..2ea35e7 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,14 @@ pip install -e . Download the preliminary model weights: +Piano (not final) ``` -wget https://storage.googleapis.com/aria-checkpoints/amt/small-0.safetensors +wget https://storage.googleapis.com/aria-checkpoints/amt/guitar-temp.safetensors +``` + +Classical guitar (not final) +``` +wget https://storage.googleapis.com/aria-checkpoints/amt/piano-temp.safetensors ``` ## Usage @@ -39,7 +45,6 @@ aria-amt transcribe \ -q8 ``` -If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling may take some time, but provides a significant speedup. - -NOTE: Currently only bf16 is supported. +If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling and may take some time, but provides a significant speedup. Quantizing (`-q8`) further speeds up inference when the `-compile` flag is also used. +NOTE: Int8 quantization is only supported on GPUs that support BF16. diff --git a/amt/data.py b/amt/data.py index 4bf3ccb..6abb2f3 100644 --- a/amt/data.py +++ b/amt/data.py @@ -371,6 +371,9 @@ def build( num_processes: int = 1, ): assert os.path.isfile(save_path) is False, f"{save_path} already exists" + assert ( + len(save_path.rsplit(".", 1)) == 2 + ), "path is missing a file extension" index_path = AmtDataset._get_index_path(load_path=save_path) if os.path.isfile(index_path): diff --git a/scripts/eval/split.py b/scripts/split.py similarity index 89% rename from scripts/eval/split.py rename to scripts/split.py index c912cbc..ef688b3 100644 --- a/scripts/eval/split.py +++ b/scripts/split.py @@ -33,13 +33,13 @@ def get_matched_paths(audio_dir: str, mid_dir: str): return res -def create_csv(matched_paths, csv_path): +def create_csv(matched_paths, csv_path, ratio): split_csv = open(csv_path, "w") csv_writer = csv.writer(split_csv) csv_writer.writerow(["mid_path", "audio_path", "split"]) for audio_path, mid_path in matched_paths: - if random.random() < 0.1: + if random.random() < ratio: csv_writer.writerow([mid_path, audio_path, "test"]) else: csv_writer.writerow([mid_path, audio_path, "train"]) @@ -50,8 +50,9 @@ def create_csv(matched_paths, csv_path): parser.add_argument("-mid_dir", type=str) parser.add_argument("-audio_dir", type=str) parser.add_argument("-csv_path", type=str) + parser.add_argument("-ratio", type=int, default=0.1) args = parser.parse_args() matched_paths = get_matched_paths(args.audio_dir, args.mid_dir) - create_csv(matched_paths, args.csv_path) + create_csv(matched_paths, args.csv_path, args.ratio) From 148db775c6d4eda654ce27b4e6f9b93a791fd1ba Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 27 Apr 2024 09:50:34 +0000 Subject: [PATCH 2/3] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2ea35e7..7327587 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,6 @@ aria-amt transcribe \ -q8 ``` -If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling and may take some time, but provides a significant speedup. Quantizing (`-q8`) further speeds up inference when the `-compile` flag is also used. +If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling and may take some time, but provides a significant speedup. Quantizing (`-q8` flag) further speeds up inference when the `-compile` flag is also used. NOTE: Int8 quantization is only supported on GPUs that support BF16. From 69dca6358acfdf7e84c4629d485eb716da52ac9b Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 27 Apr 2024 09:59:30 +0000 Subject: [PATCH 3/3] add fp16 --- amt/inference/model.py | 3 +++ amt/inference/transcribe.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/amt/inference/model.py b/amt/inference/model.py index e8390f4..44655c6 100644 --- a/amt/inference/model.py +++ b/amt/inference/model.py @@ -386,6 +386,7 @@ def setup_cache( batch_size, max_seq_len=4096, max_audio_len=1500, + dtype=torch.bfloat16, ): self.causal_mask = torch.tril( torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) @@ -397,12 +398,14 @@ def setup_cache( max_seq_length=max_seq_len, n_heads=8, head_dim=64, + dtype=dtype, ).cuda() b.cross_attn.kv_cache = KVCache( max_batch_size=batch_size, max_seq_length=max_audio_len, n_heads=8, head_dim=64, + dtype=dtype, ).cuda() diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 622b005..ba50102 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -132,7 +132,7 @@ def wrapper(*args, **kwargs): with torch.autocast("cuda", dtype=torch.bfloat16): return func(*args, **kwargs) else: - with torch.autocast("cuda", dtype=torch.float32): + with torch.autocast("cuda", dtype=torch.float16): return func(*args, **kwargs) return wrapper @@ -265,7 +265,11 @@ def gpu_manager( if gpu_id is not None: os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN) + model.decoder.setup_cache( + batch_size=batch_size, + max_seq_len=MAX_BLOCK_LEN, + dtype=torch.bfloat16 if is_bf16_supported() else torch.float16, + ) model.cuda() model.eval() if compile is True: