Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attempt to improve node info subclass typing #27

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nxontology/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Type definitions. networkx does not declare types.
# https://github.com/networkx/networkx/issues/3988#issuecomment-639969263
NodeT = TypeVar("NodeT", bound=Hashable)
NodeInfoT = TypeVar("NodeInfoT")


class NodeInfo(Freezable, Generic[NodeT]):
Expand All @@ -35,7 +36,7 @@ class NodeInfo(Freezable, Generic[NodeT]):
Each ic_metric has a scaled version accessible by adding a _scaled suffix.
"""

def __init__(self, nxo: NXOntology[NodeT], node: NodeT):
def __init__(self, nxo: NXOntology[NodeT, NodeInfoT], node: NodeT):
if node not in nxo.graph:
raise NodeNotFound(f"{node} not in graph.")
self.nxo = nxo
Expand Down
18 changes: 13 additions & 5 deletions nxontology/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import json
import logging
from abc import abstractmethod
from os import PathLike, fspath
from typing import Any, Generic, Iterable, cast

Expand All @@ -12,7 +13,7 @@
from networkx.algorithms.isolate import isolates
from networkx.readwrite.json_graph import node_link_data, node_link_graph

from nxontology.node import NodeT
from nxontology.node import NodeInfoT, NodeT

from .exceptions import DuplicateError, NodeNotFound
from .node import NodeInfo
Expand All @@ -22,7 +23,7 @@
logger = logging.getLogger(__name__)


class NXOntology(Freezable, Generic[NodeT]):
class NXOntologyBase(Freezable, Generic[NodeT, NodeInfoT]):
"""
Encapsulate a networkx.DiGraph to represent an ontology.
Regarding edge directionality, parent terms should point to child term.
Expand Down Expand Up @@ -77,7 +78,7 @@ def write_node_link_json(self, path: str | PathLike[str]) -> None:
write_file.write("\n") # json.dump does not include a trailing newline

@classmethod
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[NodeT]:
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntologyBase[NodeT]:
"""
Retrun a new graph from node-link format as written by `write_node_link_json`.
"""
Expand Down Expand Up @@ -213,7 +214,8 @@ def compute_similarities(
yield metrics

@classmethod
def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
@abstractmethod
def _get_node_info_cls(cls) -> type[NodeInfoT]:
"""
Return the Node_Info class to use for this ontology.
Subclasses can override this to use a custom Node_Info class.
Expand All @@ -222,7 +224,7 @@ def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
"""
return NodeInfo

def node_info(self, node: NodeT) -> NodeInfo[NodeT]:
def node_info(self, node: NodeT) -> NodeInfoT:
"""
Return Node_Info instance for `node`.
If frozen, cache node info in `self._node_info_cache`.
Expand Down Expand Up @@ -306,3 +308,9 @@ def set_graph_attributes(
self.graph.graph["node_identifier_attribute"] = node_identifier_attribute
if node_url_attribute:
self.graph.graph["node_url_attribute"] = node_url_attribute


class NXOntology(NXOntologyBase[NodeT, NodeInfo[NodeT]]):
@classmethod
def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
return NodeInfo
10 changes: 6 additions & 4 deletions nxontology/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import TYPE_CHECKING, Any, Generic

if TYPE_CHECKING:
from nxontology.ontology import NXOntology
from nxontology.ontology import NXOntologyBase

from networkx import shortest_path_length

from nxontology.node import NodeInfo, NodeT
from nxontology.node import NodeInfo, NodeInfoT, NodeT
from nxontology.utils import Freezable, cache_on_frozen


Expand All @@ -29,7 +29,9 @@ class Similarity(Freezable, Generic[NodeT]):
"batet_log",
]

def __init__(self, nxo: NXOntology[NodeT], node_0: NodeT, node_1: NodeT):
def __init__(
self, nxo: NXOntologyBase[NodeT, NodeInfoT], node_0: NodeT, node_1: NodeT
):
self.nxo = nxo
self.node_0 = node_0
self.node_1 = node_1
Expand Down Expand Up @@ -125,7 +127,7 @@ class SimilarityIC(Similarity[NodeT]):

def __init__(
self,
nxo: NXOntology[NodeT],
nxo: NXOntologyBase[NodeT],
node_0: NodeT,
node_1: NodeT,
ic_metric: str = "intrinsic_ic_sanchez",
Expand Down