Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/cgoliver/rnaglib
Browse files Browse the repository at this point in the history
  • Loading branch information
wisskarrou committed Jan 7, 2025
2 parents ba4b728 + 00a6732 commit 3a4ec11
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 51 deletions.
4 changes: 2 additions & 2 deletions src/rnaglib/splitters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .splitting_utils import random_split
"""imports for splitting module"""

from .splitting_utils import split_dataset, random_split
from .splitters import Splitter, RandomSplitter, NameSplitter
from .splitters import default_splitter_tr60_tr18, get_ribosomal_rnas
from .splitters import SPLITTING_VARS
from .splitting_utils import split_dataset
from .similarity_splitter import ClusterSplitter, RNAalignSplitter, CDHitSplitter

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion src/rnaglib/splitters/linear_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def assign_clusters(
clusters: List[List[str]],
cluster_counters: List[Counter],
split_ratios: Tuple[float, float, float] = (0.7, 0.15, 0.15),
ratio_tolerance: float = 0.4,
ratio_tolerance: float = 0.5,
size_weight: float = 1.0,
label_weight: float = 1.0,
) -> Tuple[
Expand Down
65 changes: 18 additions & 47 deletions src/rnaglib/splitters/similarity_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def forward(self, dataset):
_, label_counts = label_counter(dataset)
print(f"dataset:{dataset}")

# TODO: it would be simpler to compute the proportions over the entire RNADataset once, then pass only the relevant ones to the balancer.
# Here we collect the names of the RNAs in each cluster.
named_clusters = []
for cluster in clusters:
named_clusters.append(
Expand All @@ -75,7 +73,7 @@ def forward(self, dataset):
train, val, test = self.balancer(
named_clusters,
label_counts,
dataset,
keep_dataset,
(self.split_train, self.split_valid, self.split_test),
)
return train, val, test
Expand Down Expand Up @@ -105,52 +103,25 @@ def balancer(self, clusters, label_counts, dataset, fracs, n=0.2):
overall_counts = reduce(lambda x, y: x + y, labelcounts)
print(f"overall_counts:{overall_counts}")

train, val, test = assign_clusters(clusters, labelcounts)
train, val, test, metrics = assign_clusters(clusters, labelcounts)

#######
# This is a working splitter that respects the desired split sizes but does not yet balance labels.
test_size = max(1, int(len(dataset) * fracs[2]))
val_size = max(1, int(len(dataset) * fracs[1]))

random.seed(self.seed)

test = set()
val = set()
n_test = max(1, int(test_size * n))
n_val = max(1, int(val_size * n))

pool = list(range(len(dataset)))

print(f"test size:{test_size}")
while len(test) < test_size:
cluster = random.choice(clusters)
print(f"clusters:{clusters}")
clusters.remove(cluster)
if len(cluster) > n_test:
cluster = random.sample(cluster, n_test)
if len(cluster) > (test_size - len(test)):
cluster = random.sample(cluster, (test_size - len(test)))
test.update(cluster)
while len(val) < val_size:
cluster = random.choice(clusters)
print(f"clusters 2:{clusters}")
clusters.remove(cluster)
if len(cluster) > n_val:
cluster = random.sample(cluster, n_val)
if len(cluster) > (val_size - len(val)):
cluster = random.sample(cluster, (val_size - len(val)))
val.update(cluster)
# Flatten the remaining clusters into a single sorted list (used as the pool).
pool = sorted([elem for cluster in clusters for elem in cluster])
test = sorted(list(test))
val = sorted(list(val))
print(f"train:{pool}") # DEBUG
print(f"test:{test}") # DEBUG
print(f"val:{val}") # DEBUG
print(f"metrics:{metrics}")
return (
[dataset[i] for i in pool],
[dataset[i] for i in test],
[dataset[i] for i in val],
[
dataset[x]
for x in range(len(dataset))
if dataset[x]["rna"].name in sum(train, [])
],
[
dataset[x]
for x in range(len(dataset))
if dataset[x]["rna"].name in sum(val, [])
],
[
dataset[x]
for x in range(len(dataset))
if dataset[x]["rna"].name in sum(test, [])
],
)

def compute_similarity_matrix(self, dataset: RNADataset) -> Tuple[np.array, List]:
Expand Down
3 changes: 2 additions & 1 deletion src/rnaglib/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def US_align_wrapper(
]

result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
# print(result.stdout)
# Uncomment the line above to print the command's raw stdout when debugging.
# Regular expression to find all TM-scores
tm_scores = re.findall(r"TM-score=\s*([\d.]+)", result.stdout)

Expand Down

0 comments on commit 3a4ec11

Please sign in to comment.