Skip to content

Commit

Permalink
convert _expr.Constant into int in the unflattened_size
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Nov 27, 2023
1 parent c018d43 commit 4cc9c60
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,15 +1555,21 @@ def unflatten(self, inputs, input_types):
dim = dim if dim >= 0 else len(dshape) + dim
assert len(dshape) > dim >= 0

assert unflattened_size.count(-1) <= 1
new_unflattened_size = []
for s in unflattened_size:
if isinstance(s, _expr.Constant):
s = s.data.numpy().item()
new_unflattened_size.append(s)

assert new_unflattened_size.count(-1) <= 1

mult = np.multiply.reduce(unflattened_size)
mult = np.multiply.reduce(new_unflattened_size)
if mult < 0:
assert dshape[dim] % mult == 0
else:
assert dshape[dim] == mult

new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
new_shape = dshape[:dim] + tuple(new_unflattened_size) + dshape[dim + 1 :]
out = _op.reshape(data, new_shape)
return out

Expand Down

0 comments on commit 4cc9c60

Please sign in to comment.