Skip to content

Commit

Permalink
[feat] torch-tensorrt: support compile inputs size=0/1 (#41)
Browse files (browse the repository at this point in the history)
  • Loading branch information
yjjinjie authored Nov 22, 2024
1 parent 7b331c2 commit 3bb4b22
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tzrec/acc/trt_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -191,10 +191,16 @@ def export_model_trt(
values_list_cuda = []
for i, value in enumerate(emb_res):
v = value.detach().to("cuda:0")
values_list_cuda.append(v)
dict_dy = {0: batch}
if v.dim() == 3:
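# workaround for 0/1 specialization: a dim of size 0 or 1 is presumably treated as static by torch.export, so substitute a size-2 zero placeholder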
if v.size(1) < 2:
v = torch.zeros(v.size(0), 2, v.size(2), device="cuda:0", dtype=v.dtype)
dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=max_seq_len)

if v.size(0) < 2:
v = torch.zeros((2,) + v.size()[1:], device="cuda:0", dtype=v.dtype)
values_list_cuda.append(v)
dynamic_shapes_list.append(dict_dy)

# convert dense
Expand Down

0 comments on commit 3bb4b22

Please sign in to comment.