Skip to content

Commit

Permalink
fix #6799
Browse files Browse the repository at this point in the history
Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu committed Jul 31, 2023
1 parent 5feb353 commit 3d3e267
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_access_attn_matrix(self):
# be not able to access the matrix
no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
no_matrix_acess_blk(torch.randn(input_shape))
assert type(no_matrix_acess_blk.att_mat) == torch.Tensor
assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor)
# no of elements is zero
assert no_matrix_acess_blk.att_mat.nelement() == 0

Expand Down
2 changes: 1 addition & 1 deletion tests/test_transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_access_attn_matrix(self):
hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate
)
no_matrix_acess_blk(torch.randn(input_shape))
assert type(no_matrix_acess_blk.attn.att_mat) == torch.Tensor
assert isinstance(no_matrix_acess_blk.attn.att_mat, torch.Tensor)
# no of elements is zero
assert no_matrix_acess_blk.attn.att_mat.nelement() == 0

Expand Down
2 changes: 1 addition & 1 deletion tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_access_attn_matrix(self):
# no data in the matrix
no_matrix_acess_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
no_matrix_acess_blk(torch.randn(in_shape))
assert type(no_matrix_acess_blk.blocks[0].attn.att_mat) == torch.Tensor
assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor)

Check warning on line 163 in tests/test_vit.py

View check run for this annotation

Codecov / codecov/patch

tests/test_vit.py#L163

Added line #L163 was not covered by tests
# no of elements is zero
assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0

Expand Down

0 comments on commit 3d3e267

Please sign in to comment.