diff --git a/src/rnaglib/splitters/__init__.py b/src/rnaglib/splitters/__init__.py
index c039908..2ed5892 100644
--- a/src/rnaglib/splitters/__init__.py
+++ b/src/rnaglib/splitters/__init__.py
@@ -1,9 +1,9 @@
-from .splitting_utils import random_split
+"""imports for splitting module"""
+from .splitting_utils import split_dataset, random_split
 from .splitters import Splitter, RandomSplitter, NameSplitter
 from .splitters import default_splitter_tr60_tr18, get_ribosomal_rnas
 from .splitters import SPLITTING_VARS
-from .splitting_utils import split_dataset
 from .similarity_splitter import ClusterSplitter, RNAalignSplitter, CDHitSplitter
 
 __all__ = [
diff --git a/src/rnaglib/splitters/linear_optimisation.py b/src/rnaglib/splitters/linear_optimisation.py
index ae483b8..ee95a8f 100644
--- a/src/rnaglib/splitters/linear_optimisation.py
+++ b/src/rnaglib/splitters/linear_optimisation.py
@@ -100,7 +100,7 @@ def assign_clusters(
     clusters: List[List[str]],
     cluster_counters: List[Counter],
     split_ratios: Tuple[float, float, float] = (0.7, 0.15, 0.15),
-    ratio_tolerance: float = 0.4,
+    ratio_tolerance: float = 0.5,
     size_weight: float = 1.0,
     label_weight: float = 1.0,
 ) -> Tuple[
diff --git a/src/rnaglib/splitters/similarity_splitter.py b/src/rnaglib/splitters/similarity_splitter.py
index 8bbd45b..ad4fd9f 100644
--- a/src/rnaglib/splitters/similarity_splitter.py
+++ b/src/rnaglib/splitters/similarity_splitter.py
@@ -63,8 +63,6 @@ def forward(self, dataset):
         _, label_counts = label_counter(dataset)
         print(f"dataset:{dataset}")
-        # TODO: simpler to get the proportions in the entire RNADataset, then just pass the relevant ones to the balancer
-
         # here we get the names of the rnas in the clusters.
         named_clusters = []
         for cluster in clusters:
             named_clusters.append(
@@ -75,7 +73,7 @@ def forward(self, dataset):
         train, val, test = self.balancer(
             named_clusters,
             label_counts,
-            dataset,
+            keep_dataset,
             (self.split_train, self.split_valid, self.split_test),
         )
         return train, val, test
@@ -105,52 +103,25 @@ def balancer(self, clusters, label_counts, dataset, fracs, n=0.2):
         overall_counts = reduce(lambda x, y: x + y, labelcounts)
         print(f"overall_counts:{overall_counts}")
 
-        train, val, test = assign_clusters(clusters, labelcounts)
+        train, val, test, metrics = assign_clusters(clusters, labelcounts)
 
-        #######
-        # This is a working splitter that considers desired splits size, but not yet label balance
-        test_size = max(1, int(len(dataset) * fracs[2]))
-        val_size = max(1, int(len(dataset) * fracs[1]))
-
-        random.seed(self.seed)
-
-        test = set()
-        val = set()
-        n_test = max(1, int(test_size * n))
-        n_val = max(1, int(val_size * n))
-
-        pool = list(range(len(dataset)))
-
-        print(f"test size:{test_size}")
-        while len(test) < test_size:
-            cluster = random.choice(clusters)
-            print(f"clusters:{clusters}")
-            clusters.remove(cluster)
-            if len(cluster) > n_test:
-                cluster = random.sample(cluster, n_test)
-            if len(cluster) > (test_size - len(test)):
-                cluster = random.sample(cluster, (test_size - len(test)))
-            test.update(cluster)
-        while len(val) < val_size:
-            cluster = random.choice(clusters)
-            print(f"clusters 2:{clusters}")
-            clusters.remove(cluster)
-            if len(cluster) > n_val:
-                cluster = random.sample(cluster, n_val)
-            if len(cluster) > (val_size - len(val)):
-                cluster = random.sample(cluster, (val_size - len(val)))
-            val.update(cluster)
-        # not readable but flattens list of sets to list (for pool)
-        pool = sorted([elem for cluster in clusters for elem in cluster])
-        test = sorted(list(test))
-        val = sorted(list(val))
-        print(f"train:{pool}")  # DEBUG
-        print(f"test:{test}")  # DEBUG
-        print(f"val:{val}")  # DEBUG
+        print(f"metrics:{metrics}")
         return (
-            [dataset[i] for i in pool],
-            [dataset[i] for i in test],
-            [dataset[i] for i in val],
+            [
+                dataset[x]
+                for x in range(len(dataset))
+                if dataset[x]["rna"].name in sum(train, [])
+            ],
+            [
+                dataset[x]
+                for x in range(len(dataset))
+                if dataset[x]["rna"].name in sum(val, [])
+            ],
+            [
+                dataset[x]
+                for x in range(len(dataset))
+                if dataset[x]["rna"].name in sum(test, [])
+            ],
         )
 
     def compute_similarity_matrix(self, dataset: RNADataset) -> Tuple[np.array, List]:
diff --git a/src/rnaglib/utils/wrappers.py b/src/rnaglib/utils/wrappers.py
index a703070..987f0dd 100644
--- a/src/rnaglib/utils/wrappers.py
+++ b/src/rnaglib/utils/wrappers.py
@@ -44,7 +44,8 @@ def US_align_wrapper(
     ]
 
     result = subprocess.run(command, capture_output=True, text=True)
-    print(result.stdout)
+    # print(result.stdout)
+    # uncomment above for debugging
 
     # Regular expression to find all TM-scores
     tm_scores = re.findall(r"TM-score=\s*([\d.]+)", result.stdout)