Skip to content

Commit

Permalink
Fix return lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 4, 2024
1 parent 80646b9 commit 32ac8e9
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions tests/brevitas/nn/test_nn_quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,18 @@ def test_quant_lstm_rnn_full(model_input, current_cases):
if h is not None:
if return_quant_tensor and kwargs['io_quant'] is not None:
assert isinstance(h, QuantTensor)
assert h.scale is not None
assert h.bit_width is not None
else:
assert isinstance(h, torch.Tensor)

if c is not None:
if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']:
if not kwargs['bidirectional']:
if not kwargs['return_quant_tensor']:
assert isinstance(c, torch.Tensor)
elif kwargs['return_quant_tensor'] and kwargs['signed_act'] is None and kwargs[
'num_layers'] == 2:
assert isinstance(c, torch.Tensor)
else:
assert isinstance(c, QuantTensor)
else:
assert isinstance(c, torch.Tensor)
if kwargs['signed_act'] is None or not return_quant_tensor:
assert isinstance(c, torch.Tensor)
else:
assert isinstance(c, QuantTensor)
assert c.scale is not None
assert c.bit_width is not None


@pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn])
Expand Down

0 comments on commit 32ac8e9

Please sign in to comment.