From 49327f153ce65c34e631f0e14f96e045df435ee5 Mon Sep 17 00:00:00 2001
From: Louis
Date: Sat, 27 Apr 2024 11:03:14 +0100
Subject: [PATCH] Fix fp16 support (#33)

* update README

* update README

* add fp16
---
 amt/inference/model.py      | 3 +++
 amt/inference/transcribe.py | 8 ++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/amt/inference/model.py b/amt/inference/model.py
index e8390f4..44655c6 100644
--- a/amt/inference/model.py
+++ b/amt/inference/model.py
@@ -386,6 +386,7 @@ def setup_cache(
         batch_size,
         max_seq_len=4096,
         max_audio_len=1500,
+        dtype=torch.bfloat16,
     ):
         self.causal_mask = torch.tril(
             torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
@@ -397,12 +398,14 @@
                 max_seq_length=max_seq_len,
                 n_heads=8,
                 head_dim=64,
+                dtype=dtype,
             ).cuda()
             b.cross_attn.kv_cache = KVCache(
                 max_batch_size=batch_size,
                 max_seq_length=max_audio_len,
                 n_heads=8,
                 head_dim=64,
+                dtype=dtype,
             ).cuda()
diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py
index 622b005..ba50102 100644
--- a/amt/inference/transcribe.py
+++ b/amt/inference/transcribe.py
@@ -132,7 +132,7 @@ def wrapper(*args, **kwargs):
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 return func(*args, **kwargs)
         else:
-            with torch.autocast("cuda", dtype=torch.float32):
+            with torch.autocast("cuda", dtype=torch.float16):
                 return func(*args, **kwargs)
 
     return wrapper
@@ -265,7 +265,11 @@ def gpu_manager(
     if gpu_id is not None:
         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 
-    model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN)
+    model.decoder.setup_cache(
+        batch_size=batch_size,
+        max_seq_len=MAX_BLOCK_LEN,
+        dtype=torch.bfloat16 if is_bf16_supported() else torch.float16,
+    )
     model.cuda()
     model.eval()
     if compile is True:
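
Note (after the diff, not part of the applied patch): the change prefers bfloat16 and falls back to float16, rather than float32, when the GPU lacks bf16 support, and it threads the same dtype through to the KV caches so the cache tensors match the autocast precision. Below is a minimal sketch of that selection pattern, assuming torch.cuda.is_bf16_supported() as the capability check; the patch itself calls an unprefixed is_bf16_supported() defined or imported elsewhere in transcribe.py, and the decorator name here is illustrative only.

from functools import wraps

import torch


def autocast_bf16_or_fp16(func):
    # Illustrative sketch: prefer bfloat16 where the hardware supports it,
    # otherwise fall back to float16 so inference still runs in half precision.
    @wraps(func)
    def wrapper(*args, **kwargs):
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        # Any caches allocated for the wrapped call should use this same dtype,
        # mirroring how the patch passes dtype into setup_cache().
        with torch.autocast("cuda", dtype=dtype):
            return func(*args, **kwargs)

    return wrapper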