Skip to content

Commit

Permalink
Test model bands (#2383)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Nov 1, 2024
1 parent a473222 commit f76cb9d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/models/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def test_resnet(self) -> None:
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet18(weights=mocked_weights)

def test_bands(self, mocked_weights: WeightsEnum) -> None:
if 'bands' in mocked_weights.meta:
assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans']

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
Expand Down Expand Up @@ -88,6 +92,10 @@ def test_resnet(self) -> None:
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet50(weights=mocked_weights)

def test_bands(self, mocked_weights: WeightsEnum) -> None:
if 'bands' in mocked_weights.meta:
assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans']

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
Expand Down Expand Up @@ -128,6 +136,10 @@ def test_resnet(self) -> None:
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet152(weights=mocked_weights)

def test_bands(self, mocked_weights: WeightsEnum) -> None:
if 'bands' in mocked_weights.meta:
assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans']

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
Expand Down
8 changes: 8 additions & 0 deletions tests/models/test_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def test_swin_v2_t(self) -> None:
def test_swin_v2_t_weights(self, mocked_weights: WeightsEnum) -> None:
swin_v2_t(weights=mocked_weights)

def test_bands(self, mocked_weights: WeightsEnum) -> None:
if 'bands' in mocked_weights.meta:
assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans']

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
Expand Down Expand Up @@ -91,6 +95,10 @@ def test_swin_v2_b(self) -> None:
def test_swin_v2_b_weights(self, mocked_weights: WeightsEnum) -> None:
swin_v2_b(weights=mocked_weights)

def test_bands(self, mocked_weights: WeightsEnum) -> None:
if 'bands' in mocked_weights.meta:
assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans']

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
Expand Down
4 changes: 4 additions & 0 deletions tests/models/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def test_vit(self) -> None:
def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=mocked_weights)

def test_bands(self, mocked_weights: WeightsEnum) -> None:
if 'bands' in mocked_weights.meta:
assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans']

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
Expand Down

0 comments on commit f76cb9d

Please sign in to comment.