Skip to content

Commit

Permalink
fixed model tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fratajcz committed Oct 14, 2024
1 parent f136897 commit 80bc9ec
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
10 changes: 5 additions & 5 deletions speos/preprocessing/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@


class GeneDataset(InMemoryDataset):
def __init__(self, name, config, holdout_size: float = 0.5, transform=None, pre_transform=None):
def __init__(self, name, config, holdout_size: float = 0.5, transform=None, pre_transform=None, preprocessor_kwargs={}):
self.root = config.input.save_dir
self.save = config.input.save_data
self.name = name
self.config = config
self.holdout_size = holdout_size
self.preprocessor = InputHandler(config).get_preprocessor()
self.preprocessor = InputHandler(config, preprocessor_kwargs=preprocessor_kwargs).get_preprocessor()
self.logger_args = [config, __name__]
self.num_relations = self.preprocessor.get_num_relations()
logger = setup_logger(*self.logger_args)
Expand Down Expand Up @@ -143,15 +143,15 @@ def process(self):


class DatasetBootstrapper:
def __init__(self, name, config, holdout_size: float = 0.05):
def __init__(self, name, config, holdout_size: float = 0.05, preprocessor_kwargs={}):

# Unfortunately, we have to check here how many adjacencies we are going to get. The actual preprocessing starts within the dataset class.
adjacencies = AdjacencyMapper(config.input.adjacency_mappings, blacklist=config.input.adjacency_blacklist).get_mappings(config.input.adjacency, fields=config.input.adjacency_field)

if len(adjacencies) > 1 or config.input.force_multigraph:
self.dataset = MultiGeneDataset(name, config, holdout_size)
self.dataset = MultiGeneDataset(name, config, holdout_size, preprocessor_kwargs=preprocessor_kwargs)
else:
self.dataset = GeneDataset(name, config, holdout_size)
self.dataset = GeneDataset(name, config, holdout_size, preprocessor_kwargs=preprocessor_kwargs)

def get_dataset(self):
return self.dataset
Expand Down
2 changes: 1 addition & 1 deletion speos/preprocessing/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from speos.preprocessing.preprocessor import PreProcessor

class InputHandler:
def __init__(self, config, **preprocessor_kwargs):
def __init__(self, config, preprocessor_kwargs):
""" Utility class that strings together gwas and adjacency mapping and feeds it into the preprocessor
Args:
Expand Down
31 changes: 20 additions & 11 deletions speos/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,25 @@ def test_repackage_into_one_sequential(self):
self.assertTrue(torch.eq(old_param, new_param).all())

def test_forward(self):
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

model = ModelBootstrapper(self.config, dataset.data.x.shape[1], 1).get_model()

train_out, loss = model.step(dataset.data, dataset.data.train_mask)

def test_forward_concat(self):
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()
self.config.model.concat_after_mp = True

model = ModelBootstrapper(self.config, dataset.data.x.shape[1], 1).get_model()

train_out, loss = model.step(dataset.data, dataset.data.train_mask)

def test_forward_skip(self):
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

config = self.config.deepcopy()
config.model.skip_mp = True
Expand Down Expand Up @@ -190,11 +193,12 @@ def test_balance_classes(self):
def test_forward_random_input_features(self):

config = self.config.deepcopy()
config.input.adjacency = ["BioPlex 3.0 293T"]
config.input.adjacency = ["DummyUndirected"]
config.input.use_gwas = False
config.input.use_expression = False

dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

self.model = ModelBootstrapper(config, dataset.data.x.shape[1], 1).get_model()

Expand All @@ -203,12 +207,13 @@ def test_forward_random_input_features(self):
def test_forward_mlp_random_input_features(self):

config = self.config.deepcopy()
config.input.adjacency = ["BioPlex 3.0 293T"]
config.input.adjacency = ["DummyUndirected"]
config.input.use_gwas = False
config.input.use_expression = False
config.model.mp.n_layers = 0

dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

self.model = ModelBootstrapper(config, dataset.data.x.shape[1], 1).get_model()

Expand Down Expand Up @@ -288,7 +293,8 @@ def test_bootstrap_gat(self):

def test_forward_rgcn(self):
self.config.input.save_data = True
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

model = ModelBootstrapper(self.config, dataset.data.x.shape[1], 2).get_model()

Expand All @@ -301,7 +307,8 @@ def test_forward_force_multigraph(self):
config.input.force_multigraph = True
config.model.mp.type = "film"

dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

self.model = ModelBootstrapper(config, dataset.data.x.shape[1], 1).get_model()

Expand All @@ -310,7 +317,8 @@ def test_forward_force_multigraph(self):
def test_forward_rtag(self):
import numpy as np
self.config.input.save_data = True
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

config = self.config.deepcopy()
config.model.mp.type = "rtag"
Expand All @@ -331,7 +339,8 @@ def test_forward_rtag(self):
def test_forward_filmtag(self):
import numpy as np
self.config.input.save_data = True
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config).get_dataset()
dataset = DatasetBootstrapper(holdout_size=self.config.input.holdout_size, name=self.config.name, config=self.config,
preprocessor_kwargs=self.prepro_kwargs).get_dataset()

config = self.config.deepcopy()
config.model.mp.type = "filmtag"
Expand Down

0 comments on commit 80bc9ec

Please sign in to comment.