diff --git a/tzrec/acc/trt_utils.py b/tzrec/acc/trt_utils.py index 9ea1279..49e8e7a 100644 --- a/tzrec/acc/trt_utils.py +++ b/tzrec/acc/trt_utils.py @@ -191,10 +191,16 @@ def export_model_trt( values_list_cuda = [] for i, value in enumerate(emb_res): v = value.detach().to("cuda:0") - values_list_cuda.append(v) dict_dy = {0: batch} if v.dim() == 3: + # workaround -> 0/1 specialization + if v.size(1) < 2: + v = torch.zeros(v.size(0), 2, v.size(2), device="cuda:0", dtype=v.dtype) dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=max_seq_len) + + if v.size(0) < 2: + v = torch.zeros((2,) + v.size()[1:], device="cuda:0", dtype=v.dtype) + values_list_cuda.append(v) dynamic_shapes_list.append(dict_dy) # convert dense