diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 13054bd561..6062b5352f 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -62,7 +62,7 @@ def test_access_attn_matrix(self): # be not able to access the matrix no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate) no_matrix_acess_blk(torch.randn(input_shape)) - assert type(no_matrix_acess_blk.att_mat) == torch.Tensor + assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) # no of elements is zero assert no_matrix_acess_blk.att_mat.nelement() == 0 diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index a966dcbfdc..914336668d 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -66,7 +66,7 @@ def test_access_attn_matrix(self): hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate ) no_matrix_acess_blk(torch.randn(input_shape)) - assert type(no_matrix_acess_blk.attn.att_mat) == torch.Tensor + assert isinstance(no_matrix_acess_blk.attn.att_mat, torch.Tensor) # no of elements is zero assert no_matrix_acess_blk.attn.att_mat.nelement() == 0 diff --git a/tests/test_vit.py b/tests/test_vit.py index d193e6d222..f911c2d5c9 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -160,7 +160,7 @@ def test_access_attn_matrix(self): # no data in the matrix no_matrix_acess_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size) no_matrix_acess_blk(torch.randn(in_shape)) - assert type(no_matrix_acess_blk.blocks[0].attn.att_mat) == torch.Tensor + assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor) # no of elements is zero assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0