Skip to content

Commit

Permalink
fixed model tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fratajcz committed Oct 14, 2024
1 parent ea27750 commit 4b025c3
Showing 1 changed file with 2 additions and 16 deletions.
18 changes: 2 additions & 16 deletions speos/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_forward_mlp_random_input_features(self):
train_out, loss = self.model.step(dataset.data, dataset.data.train_mask)


class RelationalGeneNetworkTest(unittest.TestCase):
class RelationalGeneNetworkTest(TestSetup):

def setUp(self):
self.config = Config()
Expand Down Expand Up @@ -271,20 +271,6 @@ def test_bootstrap_film(self):
layers = [module for module in model.architectures[0].mp.modules() if not isinstance(module, nn.Sequential)]
self.assertEqual("FiLMConv", str(layers[1].__class__.__name__))

def test_bootstrap_filmtag(self):
config = self.config.deepcopy()
config.model.mp.type = "filmtag"
model = ModelBootstrapper(config, 90, 2).get_model()
layers = [module for module in model.architectures[0].mp.modules() if not isinstance(module, nn.Sequential)]
self.assertEqual("FiLMTAGConv", str(layers[1].__class__.__name__))

def test_bootstrap_rtag(self):
config = self.config.deepcopy()
config.model.mp.type = "rtag"
model = ModelBootstrapper(config, 90, 2).get_model()
layers = [module for module in model.architectures[0].mp.modules() if not isinstance(module, nn.Sequential)]
self.assertEqual("RTAGConv", str(layers[1].__class__.__name__))

def test_bootstrap_gat(self):
config = self.config.deepcopy()
config.model.mp.type = "rgat"
Expand Down Expand Up @@ -360,7 +346,7 @@ def test_forward_filmtag(self):
self.assertFalse(np.isnan(loss))


class SKLearnModelTest(unittest.TestCase):
class SKLearnModelTest(TestSetup):

def setUp(self):
self.config = Config()
Expand Down

0 comments on commit 4b025c3

Please sign in to comment.