From 4cc9c60e8f200ce99324d1c767e6502365e6f600 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 27 Nov 2023 09:52:32 +0900 Subject: [PATCH] convert `_expr.Constant` into `int` in the unflattened_size --- python/tvm/relay/frontend/pytorch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 03afcfb9260c1..2eb3b9b712925 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1555,15 +1555,21 @@ def unflatten(self, inputs, input_types): dim = dim if dim >= 0 else len(dshape) + dim assert len(dshape) > dim >= 0 - assert unflattened_size.count(-1) <= 1 + new_unflattened_size = [] + for s in unflattened_size: + if isinstance(s, _expr.Constant): + s = s.data.numpy().item() + new_unflattened_size.append(s) + + assert new_unflattened_size.count(-1) <= 1 - mult = np.multiply.reduce(unflattened_size) + mult = np.multiply.reduce(new_unflattened_size) if mult < 0: assert dshape[dim] % mult == 0 else: assert dshape[dim] == mult - new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :] + new_shape = dshape[:dim] + tuple(new_unflattened_size) + dshape[dim + 1 :] out = _op.reshape(data, new_shape) return out