Skip to content

Commit

Permalink
update demos of different tasks and PartitionFormDict
Browse files Browse the repository at this point in the history
  • Loading branch information
wisskarrou committed Jan 7, 2025
1 parent 10cbe54 commit af9aa34
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 43 deletions.
2 changes: 2 additions & 0 deletions src/rnaglib/config/feature_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
"binding_protein_Rz": FloatEncoder(),
"is_modified": BoolEncoder(),
"is_broken": BoolEncoder(),
"protein_binding": BoolEncoder(),
"protein_content": ListEncoder(list_length=3)
}

# TODO : include edge information, but it's not trivial to deal with edges beyond RGCN...
Expand Down
4 changes: 2 additions & 2 deletions src/rnaglib/tasks/RBP_Node/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from rnaglib.tasks import ProteinBindingSiteDetection
from rnaglib.learning.task_models import PygModel

ta = ProteinBindingSiteDetection("RBP-Node", recompute=False, debug=False)
ta = ProteinBindingSiteDetection("RBP-Node", recompute=True, debug=False, filter_by_size=True, filter_by_resolution=True)

# Add representation
ta.dataset.add_representation(GraphRepresentation(framework="pyg"))

# Splitting dataset
ta.get_split_loaders(recompute=False)
ta.get_split_loaders(recompute=True)

# Train model
# Either by hand:
Expand Down
17 changes: 8 additions & 9 deletions src/rnaglib/tasks/RBP_Node/protein_binding_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
from rnaglib.data_loading import RNADataset
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.transforms import FeaturesComputer
from rnaglib.transforms import ComposeFilters
from rnaglib.transforms import RibosomalFilter
from rnaglib.transforms import DummyFilter
from rnaglib.transforms import ComposeFilters, RibosomalFilter, DummyFilter, ResidueAttributeFilter
from rnaglib.transforms import PDBIDNameTransform
from rnaglib.transforms import ResidueAttributeFilter
from rnaglib.utils import dump_json


Expand All @@ -17,7 +14,7 @@ class ProteinBindingSiteDetection(ResidueClassificationTask):
a protein-binding interface
"""

target_var = "binding_protein"
target_var = "protein_binding"
input_var = "nt_code"

def __init__(self, root, splitter=None, **kwargs):
Expand All @@ -29,21 +26,23 @@ def get_task_vars(self):
def process(self):
# build the filters
ribo_filter = RibosomalFilter()
non_bind_filter = ResidueAttributeFilter(attribute=self.target_var, value_checker=lambda val: val is not None)
filters = ComposeFilters([ribo_filter, non_bind_filter])
non_bind_filter = ResidueAttributeFilter(attribute=self.target_var, value_checker=lambda val: val)
self.filters_list += [ribo_filter, non_bind_filter]
filters = ComposeFilters(self.filters_list)
if self.debug:
filters = DummyFilter()

# Define your transforms
add_name = PDBIDNameTransform()

# Run through database, applying our filters
dataset = RNADataset(debug=self.debug, in_memory=self.in_memory)
dataset = RNADataset(debug=self.debug, in_memory=False)
all_rnas = []
os.makedirs(self.dataset_path, exist_ok=True)
for rna in dataset:
if filters.forward(rna):
rna = add_name(rna)["rna"]
rna = add_name(rna)
rna = rna["rna"]
if self.in_memory:
all_rnas.append(rna)
else:
Expand Down
6 changes: 4 additions & 2 deletions src/rnaglib/tasks/RNA_CM/chemical_modification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rnaglib.data_loading import RNADataset
from rnaglib.tasks import ResidueClassificationTask
from rnaglib.transforms import FeaturesComputer
from rnaglib.transforms import ResidueAttributeFilter
from rnaglib.transforms import ResidueAttributeFilter, ComposeFilters
from rnaglib.transforms import DummyFilter
from rnaglib.transforms import PDBIDNameTransform
from rnaglib.utils import dump_json
Expand All @@ -25,7 +25,9 @@ def get_task_vars(self):

def process(self):
# Define your transforms
rna_filter = ResidueAttributeFilter(attribute=self.target_var, value_checker=lambda val: val == True)
residue_attribute_filter = ResidueAttributeFilter(attribute=self.target_var, value_checker=lambda val: val == True)
self.filters_list.append(residue_attribute_filter)
rna_filter = ComposeFilters(self.filters_list)
if self.debug:
rna_filter = DummyFilter()
add_name = PDBIDNameTransform()
Expand Down
4 changes: 2 additions & 2 deletions src/rnaglib/tasks/RNA_CM/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from rnaglib.transforms import GraphRepresentation
from rnaglib.learning.task_models import PygModel

ta = ChemicalModification("RNA-CM")
ta = ChemicalModification(root="RNA-CM", recompute=True, filter_by_size=True, filter_by_resolution=True)

# Add representation
ta.dataset.add_representation(GraphRepresentation(framework="pyg"))

# Splitting dataset
ta.get_split_loaders(recompute=False)
ta.get_split_loaders(recompute=True)

# Train model
# Either by hand:
Expand Down
4 changes: 2 additions & 2 deletions src/rnaglib/tasks/RNA_Family/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from rnaglib.transforms import GraphRepresentation
from rnaglib.learning.task_models import PygModel

ta = RNAFamily(root="RNA-Family", recompute=False, debug=True)
ta = RNAFamily(root="RNA-Family", recompute=True, debug=False, filter_by_size=True, filter_by_resolution=True)

ta.dataset.add_representation(GraphRepresentation(framework="pyg"))

# Splitting dataset
ta.get_split_loaders(recompute=False, batch_size=1)
ta.get_split_loaders(recompute=True)

# Train model
model = PygModel(ta.metadata["description"]["num_node_features"], ta.metadata["description"]["num_classes"], graph_level=True)
Expand Down
55 changes: 41 additions & 14 deletions src/rnaglib/tasks/RNA_Family/rfam.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from rnaglib.data_loading import RNADataset
from rnaglib.tasks import RNAClassificationTask
from rnaglib.encoders import IntMappingEncoder
Expand All @@ -6,8 +8,11 @@
RfamTransform,
ChainNameTransform,
RNAAttributeFilter,
ComposeFilters
)
from rnaglib.transforms import FeaturesComputer
from rnaglib.utils import dump_json



class RNAFamily(RNAClassificationTask):
Expand All @@ -20,31 +25,53 @@ class RNAFamily(RNAClassificationTask):
target_var = "rfam" # graph level attribute
input_var = "nt_code" # node level attribute

def __init__(self, max_size: int = 200, **kwargs):
self.max_size = max_size
def __init__(self, **kwargs):
super().__init__(**kwargs)

def get_task_vars(self):
return FeaturesComputer(
nt_features=self.input_var,
rna_targets=self.target_var,
custom_encoders={self.target_var: IntMappingEncoder(self.metadata["label_mapping"])}, )
custom_encoders={self.target_var: IntMappingEncoder(self.metadata["label_mapping"])}
)

def process(self):
# Create dataset
full_dataset = RNADataset(debug=self.debug)
# compute rfam annotation, only keep ones with an Rfam annot.
# init filters
rna_filter = ComposeFilters(self.filters_list)
# Initialize dataset with in_memory=False to avoid loading everything at once
dataset = RNADataset(debug=self.debug, in_memory=False)
tr_rfam = RfamTransform(parallel=True)
rnas = tr_rfam(full_dataset)
rnas = list(RNAAttributeFilter(attribute=tr_rfam.name, value_checker=lambda val: val is not None)(rnas))
rfam_filter = RNAAttributeFilter(attribute=tr_rfam.name, value_checker=lambda val: val is not None)
chain_split = ChainSplitTransform()
chain_annotator = ChainNameTransform()

# Run through database, applying our filters
all_rnas = []
rfams_set = set()
os.makedirs(self.dataset_path, exist_ok=True)
for rna in dataset:
if rna_filter.forward(rna) and len(rna["rna"].nodes())>0:
annotated_rna = tr_rfam(rna)
if rfam_filter.forward(annotated_rna):
rfams_set.add(annotated_rna["rna"].graph["rfam"])
for chain in chain_split(annotated_rna):
annotated_chain = chain_annotator(chain)["rna"]
if self.in_memory:
all_rnas.append(annotated_chain)
else:
all_rnas.append(annotated_chain.name)
dump_json(
os.path.join(self.dataset_path, f"{annotated_chain.name}.json"),
annotated_chain,
)
# compute one-hot mapping of labels
labels = sorted(set([r["rna"].graph["rfam"] for r in rnas]))
labels = sorted(rfams_set)
rfam_mapping = {rfam: i for i, rfam in enumerate(labels)}
self.metadata["label_mapping"] = rfam_mapping

# split by chain
rnas = ChainSplitTransform()(rnas)
rnas = ChainNameTransform()(rnas)
if self.in_memory:
dataset = RNADataset(rnas=all_rnas)
else:
dataset = RNADataset(dataset_path=self.dataset_path, rna_id_subset=all_rnas)

new_dataset = RNADataset(rnas=list((r["rna"] for r in rnas)))
return new_dataset
return dataset
2 changes: 1 addition & 1 deletion src/rnaglib/tasks/RNA_Ligand/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
data = pd.read_csv(os.path.join(os.path.dirname(__file__), "data/gmsm_dataset.csv"))

# Creating task
ta = LigandIdentification('RNA-Ligand', data, recompute=True)
ta = LigandIdentification('RNA-Ligand', data, recompute=True, filter_by_size=True, filter_by_resolution=True)

# Splitting dataset
print("Splitting Dataset")
Expand Down
27 changes: 19 additions & 8 deletions src/rnaglib/tasks/RNA_Ligand/ligand_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rnaglib.tasks import RNAClassificationTask
from rnaglib.data_loading import RNADataset
from rnaglib.encoders import IntEncoder
from rnaglib.transforms import FeaturesComputer, AnnotatorFromDict, PartitionFromDict, ResidueNameFilter
from rnaglib.transforms import FeaturesComputer, AnnotatorFromDict, PartitionFromDict, ResidueNameFilter, RBPTransform, ComposeFilters, ResidueAttributeFilter
from rnaglib.utils import dump_json


Expand All @@ -14,10 +14,12 @@ class LigandIdentification(RNAClassificationTask):
target_var = "ligand_code"
num_classes = 44

def __init__(self, root, data, splitter=None, **kwargs):
def __init__(self, root, data, filter_by_size=False, filter_by_resolution=False, splitter=None, **kwargs):
self.data = data
self.nodes_keep = set(data.nid.values)
self.bp_dict, self.ligands_dict = self.parse_data()
self.filter_by_size = filter_by_size
self.filter_by_resolution = filter_by_resolution
super().__init__(root=root, splitter=splitter, **kwargs)

def parse_data(self):
Expand All @@ -32,21 +34,30 @@ def parse_data(self):
return bp_dict, ligands_dict

def process(self):
rna_filter = ResidueNameFilter(value_checker=lambda name: name in self.nodes_keep, min_valid=1)
nt_partition = PartitionFromDict(partition_dict=self.bp_dict)
annotator = AnnotatorFromDict(annotation_dict=self.ligands_dict, name="ligand_code")

# Run through database, applying our filters
# Initialize dataset with in_memory=False to avoid loading everything at once
dataset = RNADataset(
debug=self.debug, in_memory=False, redundancy="all", rna_id_subset=list(self.data["RNA"].unique())
)

# Instantiate filters to apply
rna_set_filter = ResidueNameFilter(value_checker=lambda name: name in self.nodes_keep, min_valid=1)
non_bind_filter = ResidueAttributeFilter(attribute="protein_binding", value_checker=lambda val: val==False)
self.filters_list += [rna_set_filter, non_bind_filter]
filters = ComposeFilters(self.filters_list)

# Instantiate transforms to apply
nt_partition = PartitionFromDict(partition_dict=self.bp_dict)
annotator = AnnotatorFromDict(annotation_dict=self.ligands_dict, name="ligand_code")
#protein_content_annotator = RBPTransform(structures_dir=dataset.structures_path, protein_number_annotations=False, distances=[4.,6.,8.])

# Run through database, applying our filters
all_binding_pockets = []
os.makedirs(self.dataset_path, exist_ok=True)
for rna in dataset:
if rna_filter.forward(rna):
if filters.forward(rna):
for binding_pocket_dict in nt_partition(rna):
annotated_binding_pocket = annotator(binding_pocket_dict)["rna"]
#annotated_binding_pocket = protein_content_annotator(annotated_binding_pocket_dict)["rna"]
if self.in_memory:
all_binding_pockets.append(annotated_binding_pocket)
else:
Expand Down
7 changes: 4 additions & 3 deletions src/rnaglib/transforms/partition/from_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def forward(self, rna_dict: dict) -> Iterator[dict]:
subgraph_idx = 0
for current_subgraph_nodes in self.partition_dict[g.graph["name"]]:
subgraph = g.subgraph(current_subgraph_nodes).copy()
subgraph.name += "_" + str(subgraph_idx)
yield {"rna": subgraph}
subgraph_idx += 1
if len(subgraph.nodes())>0:
subgraph.name += "_" + str(subgraph_idx)
yield {"rna": subgraph}
subgraph_idx += 1

0 comments on commit af9aa34

Please sign in to comment.