diff --git a/speos/tests/test_models.py b/speos/tests/test_models.py index a070955..fd25f56 100644 --- a/speos/tests/test_models.py +++ b/speos/tests/test_models.py @@ -221,7 +221,7 @@ def test_forward_mlp_random_input_features(self): train_out, loss = self.model.step(dataset.data, dataset.data.train_mask) -class RelationalGeneNetworkTest(unittest.TestCase): +class RelationalGeneNetworkTest(TestSetup): def setUp(self): self.config = Config() @@ -271,20 +271,6 @@ def test_bootstrap_film(self): layers = [module for module in model.architectures[0].mp.modules() if not isinstance(module, nn.Sequential)] self.assertEqual("FiLMConv", str(layers[1].__class__.__name__)) - def test_bootstrap_filmtag(self): - config = self.config.deepcopy() - config.model.mp.type = "filmtag" - model = ModelBootstrapper(config, 90, 2).get_model() - layers = [module for module in model.architectures[0].mp.modules() if not isinstance(module, nn.Sequential)] - self.assertEqual("FiLMTAGConv", str(layers[1].__class__.__name__)) - - def test_bootstrap_rtag(self): - config = self.config.deepcopy() - config.model.mp.type = "rtag" - model = ModelBootstrapper(config, 90, 2).get_model() - layers = [module for module in model.architectures[0].mp.modules() if not isinstance(module, nn.Sequential)] - self.assertEqual("RTAGConv", str(layers[1].__class__.__name__)) - def test_bootstrap_gat(self): config = self.config.deepcopy() config.model.mp.type = "rgat" @@ -360,7 +346,7 @@ def test_forward_filmtag(self): self.assertFalse(np.isnan(loss)) -class SKLearnModelTest(unittest.TestCase): +class SKLearnModelTest(TestSetup): def setUp(self): self.config = Config()