From fa1d381fa6f53d340bd74f4404c00b7995706bdc Mon Sep 17 00:00:00 2001 From: Daniel Himmelstein Date: Wed, 12 Jul 2023 13:36:32 -0400 Subject: [PATCH] attempt to improve node info subclass typing method from https://github.com/related-sciences/nxontology/pull/26#discussion_r1260172182 --- nxontology/imports.py | 12 ++++----- nxontology/node.py | 25 ++++++++++--------- nxontology/ontology.py | 54 +++++++++++++++++++++++----------------- nxontology/similarity.py | 30 ++++++++++++---------- nxontology/viz.py | 8 +++--- 5 files changed, 71 insertions(+), 58 deletions(-) diff --git a/nxontology/imports.py b/nxontology/imports.py index c408bfa..eb180fa 100644 --- a/nxontology/imports.py +++ b/nxontology/imports.py @@ -11,7 +11,7 @@ from nxontology import NXOntology from nxontology.exceptions import NodeNotFound -from nxontology.node import Node +from nxontology.node import NodeT logger = logging.getLogger(__name__) @@ -86,21 +86,21 @@ def from_file(handle: BinaryIO | str | PathLike[AnyStr]) -> NXOntology[str]: def _pronto_edges_for_term( term: Term, default_rel_type: str = "is a" -) -> list[tuple[Node, Node, str]]: +) -> list[tuple[NodeT, NodeT, str]]: """ Extract edges including "is a" relationships for a Pronto term. https://github.com/althonos/pronto/issues/119#issuecomment-956541286 """ rels = [] - source_id = cast(Node, term.id) + source_id = cast(NodeT, term.id) for target in term.superclasses(distance=1, with_self=False): - rels.append((source_id, cast(Node, target.id), default_rel_type)) + rels.append((source_id, cast(NodeT, target.id), default_rel_type)) for rel_type, targets in term.relationships.items(): for target in sorted(targets): rels.append( ( - cast(Node, term.id), - cast(Node, target.id), + cast(NodeT, term.id), + cast(NodeT, target.id), rel_type.name or rel_type.id, ) ) diff --git a/nxontology/node.py b/nxontology/node.py index b1ec4c4..673c65a 100644 --- a/nxontology/node.py +++ b/nxontology/node.py @@ -15,10 +15,11 @@ # Type definitions. networkx does not declare types. # https://github.com/networkx/networkx/issues/3988#issuecomment-639969263 -Node = TypeVar("Node", bound=Hashable) +NodeT = TypeVar("NodeT", bound=Hashable) +NodeInfoT = TypeVar("NodeInfoT") -class Node_Info(Freezable, Generic[Node]): +class Node_Info(Freezable, Generic[NodeT]): """ Compute metrics and values for a node of an NXOntology. Includes intrinsic information content (IC) metrics. @@ -35,7 +36,7 @@ class Node_Info(Freezable, Generic[Node]): Each ic_metric has a scaled version accessible by adding a _scaled suffix. """ - def __init__(self, nxo: NXOntology[Node], node: Node): + def __init__(self, nxo: NXOntology[NodeT, NodeInfoT], node: NodeT): if node not in nxo.graph: raise NodeNotFound(f"{node} not in graph.") self.nxo = nxo @@ -98,12 +99,12 @@ def data(self) -> dict[Any, Any]: return data @property - def parents(self) -> set[Node]: + def parents(self) -> set[NodeT]: """Direct parent nodes of this node.""" return set(self.nxo.graph.predecessors(self.node)) @property - def parent(self) -> Node | None: + def parent(self) -> NodeT | None: """ Sole parent of this node, or None if this node is a root. If this node has multiple parents, raise ValueError. @@ -118,13 +119,13 @@ def parent(self) -> Node | None: raise ValueError(f"Node {self!r} has multiple parents.") @property - def children(self) -> set[Node]: + def children(self) -> set[NodeT]: """Direct child nodes of this node.""" return set(self.nxo.graph.successors(self.node)) @property @cache_on_frozen - def ancestors(self) -> set[Node]: + def ancestors(self) -> set[NodeT]: """ Get ancestors of node in graph, including the node itself. Ancestors refers to more general concepts in an ontology, @@ -137,7 +138,7 @@ def ancestors(self) -> set[Node]: @property @cache_on_frozen - def descendants(self) -> set[Node]: + def descendants(self) -> set[NodeT]: """ Get descendants of node in graph, including the node itself. Descendants refers to more specific concepts in an ontology, @@ -160,12 +161,12 @@ def n_descendants(self) -> int: @property @cache_on_frozen - def roots(self) -> set[Node]: + def roots(self) -> set[NodeT]: """Ancestors of this node that are roots (top-level).""" return self.ancestors & self.nxo.roots @property - def leaves(self) -> set[Node]: + def leaves(self) -> set[NodeT]: """Descendents of this node that are leaves.""" return self.descendants & self.nxo.leaves @@ -181,14 +182,14 @@ def depth(self) -> int: return depth @property - def paths_from_roots(self) -> Iterator[list[Node]]: + def paths_from_roots(self) -> Iterator[list[NodeT]]: for root in self.roots: yield from nx.all_simple_paths( self.nxo.graph, source=root, target=self.node ) @property - def paths_to_leaves(self) -> Iterator[list[Node]]: + def paths_to_leaves(self) -> Iterator[list[NodeT]]: yield from nx.all_simple_paths( self.nxo.graph, source=self.node, target=self.leaves ) diff --git a/nxontology/ontology.py b/nxontology/ontology.py index e44ecdc..10c2ca4 100644 --- a/nxontology/ontology.py +++ b/nxontology/ontology.py @@ -3,6 +3,7 @@ import itertools import json import logging +from abc import abstractmethod from os import PathLike, fspath from typing import Any, Generic, Iterable, cast @@ -12,7 +13,7 @@ from networkx.algorithms.isolate import isolates from networkx.readwrite.json_graph import node_link_data, node_link_graph -from nxontology.node import Node +from nxontology.node import NodeInfoT, NodeT from .exceptions import DuplicateError, NodeNotFound from .node import Node_Info @@ -22,7 +23,7 @@ logger = logging.getLogger(__name__) -class NXOntology(Freezable, Generic[Node]): +class NXOntologyBase(Freezable, Generic[NodeT, NodeInfoT]): """ Encapsulate a networkx.DiGraph to represent an ontology. Regarding edge directionality, parent terms should point to child term. @@ -39,7 +40,7 @@ def __init__( # in case there are compatability issues in the future. self._add_nxontology_metadata() self.check_is_dag() - self._node_info_cache: dict[Node, Node_Info[Node]] = {} + self._node_info_cache: dict[NodeT, Node_Info[NodeT]] = {} def _add_nxontology_metadata(self) -> None: self.graph.graph["nxontology_version"] = get_nxontology_version() @@ -77,7 +78,7 @@ def write_node_link_json(self, path: str | PathLike[str]) -> None: write_file.write("\n") # json.dump does not include a trailing newline @classmethod - def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[Node]: + def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntologyBase[NodeT]: """ Retrun a new graph from node-link format as written by `write_node_link_json`. """ @@ -90,7 +91,7 @@ def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[Node]: nxo = cls(digraph) return nxo - def add_node(self, node_for_adding: Node, **attr: Any) -> None: + def add_node(self, node_for_adding: NodeT, **attr: Any) -> None: """ Like networkx.DiGraph.add_node but raises a DuplicateError if the node already exists. @@ -99,7 +100,7 @@ def add_node(self, node_for_adding: Node, **attr: Any) -> None: raise DuplicateError(f"node already in graph: {node_for_adding}") self.graph.add_node(node_for_adding, **attr) - def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr: Any) -> None: + def add_edge(self, u_of_edge: NodeT, v_of_edge: NodeT, **attr: Any) -> None: """ Like networkx.DiGraph.add_edge but raises a NodeNotFound if either node does not exist @@ -116,7 +117,7 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr: Any) -> None: @property @cache_on_frozen - def roots(self) -> set[Node]: + def roots(self) -> set[NodeT]: """ Return all top-level nodes, including isolates. """ @@ -127,7 +128,7 @@ def roots(self) -> set[Node]: return roots @property - def root(self) -> Node: + def root(self) -> NodeT: """ Sole root of this directed acyclic graph. If this ontology has multiple roots, raise ValueError. @@ -142,7 +143,7 @@ def root(self) -> Node: @property @cache_on_frozen - def leaves(self) -> set[Node]: + def leaves(self) -> set[NodeT]: """ Return all bottom-level nodes, including isolates. """ @@ -154,7 +155,7 @@ def leaves(self) -> set[Node]: @property @cache_on_frozen - def isolates(self) -> set[Node]: + def isolates(self) -> set[NodeT]: """ Return disconnected nodes. """ @@ -175,17 +176,17 @@ def frozen(self) -> bool: def similarity( self, - node_0: Node, - node_1: Node, + node_0: NodeT, + node_1: NodeT, ic_metric: str = "intrinsic_ic_sanchez", - ) -> SimilarityIC[Node]: + ) -> SimilarityIC[NodeT]: """SimilarityIC instance for the specified nodes""" return SimilarityIC(self, node_0, node_1, ic_metric) def similarity_metrics( self, - node_0: Node, - node_1: Node, + node_0: NodeT, + node_1: NodeT, ic_metric: str = "intrinsic_ic_sanchez", keys: list[str] | None = None, ) -> dict[str, Any]: @@ -197,8 +198,8 @@ def similarity_metrics( def compute_similarities( self, - source_nodes: Iterable[Node], - target_nodes: Iterable[Node], + source_nodes: Iterable[NodeT], + target_nodes: Iterable[NodeT], ic_metrics: list[str] | tuple[str, ...] = ("intrinsic_ic_sanchez",), ) -> Iterable[dict[str, Any]]: """ @@ -213,16 +214,17 @@ def compute_similarities( yield metrics @classmethod - def _get_node_info_cls(cls) -> type[Node_Info[Node]]: + @abstractmethod + def _get_node_info_cls(cls) -> type[NodeInfoT]: """ Return the Node_Info class to use for this ontology. Subclasses can override this to use a custom Node_Info class. For the complexity of typing this method, see . """ - return Node_Info + ... - def node_info(self, node: Node) -> Node_Info[Node]: + def node_info(self, node: NodeT) -> NodeInfoT: """ Return Node_Info instance for `node`. If frozen, cache node info in `self._node_info_cache`. @@ -235,8 +237,8 @@ def node_info(self, node: Node) -> Node_Info[Node]: return self._node_info_cache[node] @cache_on_frozen - def _get_name_to_node_info(self) -> dict[str, Node_Info[Node]]: - name_to_node_info: dict[str, Node_Info[Node]] = {} + def _get_name_to_node_info(self) -> dict[str, Node_Info[NodeT]]: + name_to_node_info: dict[str, Node_Info[NodeT]] = {} for node in self.graph: info = self.node_info(node) name = info.name @@ -249,7 +251,7 @@ def _get_name_to_node_info(self) -> dict[str, Node_Info[Node]]: name_to_node_info[name] = info return name_to_node_info - def node_info_by_name(self, name: str) -> Node_Info[Node]: + def node_info_by_name(self, name: str) -> Node_Info[NodeT]: """ Return Node_Info instance using a lookup by name. """ @@ -306,3 +308,9 @@ def set_graph_attributes( self.graph.graph["node_identifier_attribute"] = node_identifier_attribute if node_url_attribute: self.graph.graph["node_url_attribute"] = node_url_attribute + + +class NXOntology(NXOntologyBase[NodeT, Node_Info[NodeT]]): + @classmethod + def _get_node_info_cls(cls) -> type[Node_Info[NodeT]]: + return Node_Info diff --git a/nxontology/similarity.py b/nxontology/similarity.py index 5fbe596..700c808 100644 --- a/nxontology/similarity.py +++ b/nxontology/similarity.py @@ -3,16 +3,18 @@ import math from typing import TYPE_CHECKING, Any, Generic +from nxontology.ontology import NXOntologyBase + if TYPE_CHECKING: - from nxontology.ontology import NXOntology + pass from networkx import shortest_path_length -from nxontology.node import Node, Node_Info +from nxontology.node import Node_Info, NodeInfoT, NodeT from nxontology.utils import Freezable, cache_on_frozen -class Similarity(Freezable, Generic[Node]): +class Similarity(Freezable, Generic[NodeT]): """ Compute intrinsic similarity metrics for a pair of nodes. """ @@ -29,7 +31,9 @@ class Similarity(Freezable, Generic[Node]): "batet_log", ] - def __init__(self, nxo: NXOntology[Node], node_0: Node, node_1: Node): + def __init__( + self, nxo: NXOntologyBase[NodeT, NodeInfoT], node_0: NodeT, node_1: NodeT + ): self.nxo = nxo self.node_0 = node_0 self.node_1 = node_1 @@ -68,12 +72,12 @@ def depth(self) -> int | None: @property @cache_on_frozen - def common_ancestors(self) -> set[Node]: + def common_ancestors(self) -> set[NodeT]: return self.info_0.ancestors & self.info_1.ancestors @property @cache_on_frozen - def union_ancestors(self) -> set[Node]: + def union_ancestors(self) -> set[NodeT]: return self.info_0.ancestors | self.info_1.ancestors @property @@ -116,7 +120,7 @@ def results(self, keys: list[str] | None = None) -> dict[str, Any]: return {key: getattr(self, key) for key in keys} -class SimilarityIC(Similarity[Node]): +class SimilarityIC(Similarity[NodeT]): """ Compute intrinsic similarity metrics for a pair of nodes, including Information Content (IC) derived metrics. @@ -125,9 +129,9 @@ class SimilarityIC(Similarity[Node]): def __init__( self, - nxo: NXOntology[Node], - node_0: Node, - node_1: Node, + nxo: NXOntologyBase[NodeT], + node_0: NodeT, + node_1: NodeT, ic_metric: str = "intrinsic_ic_sanchez", ): super().__init__(nxo, node_0, node_1) @@ -151,7 +155,7 @@ def __init__( "jiang_seco", ] - def _get_ic(self, node_info: Node_Info[Node], ic_metric: str) -> float: + def _get_ic(self, node_info: Node_Info[NodeT], ic_metric: str) -> float: ic = getattr(node_info, ic_metric) assert isinstance(ic, float) return ic @@ -174,7 +178,7 @@ def node_1_ic_scaled(self) -> float: @property @cache_on_frozen - def _resnik_mica(self) -> tuple[float, Node | None]: + def _resnik_mica(self) -> tuple[float, NodeT | None]: if not self.common_ancestors: return 0.0, None resnik, mica = max( @@ -185,7 +189,7 @@ def _resnik_mica(self) -> tuple[float, Node | None]: return resnik, mica @property - def mica(self) -> Node | None: + def mica(self) -> NodeT | None: """ Most informative common ancestor. None if no common ancestors exist. diff --git a/nxontology/viz.py b/nxontology/viz.py index 6a3be1b..2bd8076 100644 --- a/nxontology/viz.py +++ b/nxontology/viz.py @@ -5,13 +5,13 @@ from networkx.drawing.nx_agraph import to_agraph from pygraphviz.agraph import AGraph -from nxontology.node import Node, Node_Info +from nxontology.node import Node_Info, NodeT from nxontology.similarity import SimilarityIC def create_similarity_graphviz( - sim: SimilarityIC[Node], - nodes: Iterable[Node] | None = None, + sim: SimilarityIC[NodeT], + nodes: Iterable[NodeT] | None = None, ) -> AGraph: """ Create a pygraphviz AGraph to render the similarity subgraph with graphviz. @@ -81,7 +81,7 @@ def create_similarity_graphviz( return gviz -def get_verbose_node_label(info: Node_Info[Node]) -> str: +def get_verbose_node_label(info: Node_Info[NodeT]) -> str: """Return verbose label like 'name (identifier)'.""" verbose_label = info.name assert isinstance(verbose_label, str)