From cde4b685aa8380356b6dca34e791e5c627a436bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=91=BE=E6=B4=81?= Date: Thu, 21 Nov 2024 20:54:23 +0800 Subject: [PATCH] torch-tensorrt: support compile inputs size=0/1 --- tzrec/acc/trt_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tzrec/acc/trt_utils.py b/tzrec/acc/trt_utils.py index 9ea1279..1ed06a1 100644 --- a/tzrec/acc/trt_utils.py +++ b/tzrec/acc/trt_utils.py @@ -191,10 +191,17 @@ def export_model_trt( values_list_cuda = [] for i, value in enumerate(emb_res): v = value.detach().to("cuda:0") - values_list_cuda.append(v) dict_dy = {0: batch} if v.dim() == 3: + # workaround -> 0/1 specialization + if v.size(1) < 2: + v = torch.zeros(v.size(0), 2, v.size(2), device="cuda:0") dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=max_seq_len) + + if v.size(0) < 2: + v = torch.zeros((2,) + v.size()[1:], device="cuda:0") + + values_list_cuda.append(v) dynamic_shapes_list.append(dict_dy) # convert dense