From aa2f2d302564e648fbb2cbc1dcac24904e770969 Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Wed, 31 Jul 2024 17:17:07 -0400 Subject: [PATCH 1/5] Port substitute_node_with_subgraph to core. --- rustworkx-core/src/graph_ext/mod.rs | 2 + rustworkx-core/src/graph_ext/substitution.rs | 300 +++++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 rustworkx-core/src/graph_ext/substitution.rs diff --git a/rustworkx-core/src/graph_ext/mod.rs b/rustworkx-core/src/graph_ext/mod.rs index 256d6ac0a..00bd5f156 100644 --- a/rustworkx-core/src/graph_ext/mod.rs +++ b/rustworkx-core/src/graph_ext/mod.rs @@ -71,12 +71,14 @@ use petgraph::{EdgeType, Graph}; pub mod contraction; pub mod multigraph; +pub mod substitution; pub use contraction::{ ContractNodesDirected, ContractNodesSimpleDirected, ContractNodesSimpleUndirected, ContractNodesUndirected, }; pub use multigraph::{HasParallelEdgesDirected, HasParallelEdgesUndirected}; +pub use substitution::SubstituteNodeWithGraph; /// A graph whose nodes may be removed. pub trait NodeRemovable: Data { diff --git a/rustworkx-core/src/graph_ext/substitution.rs b/rustworkx-core/src/graph_ext/substitution.rs new file mode 100644 index 000000000..48b55a393 --- /dev/null +++ b/rustworkx-core/src/graph_ext/substitution.rs @@ -0,0 +1,300 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! This module defines graph traits for node contraction. + +use crate::dictmap::{DictMap, InitWithHasher}; +use petgraph::data::DataMap; +use petgraph::stable_graph; +use petgraph::visit::{ + Data, EdgeRef, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeRef, +}; +use petgraph::{Directed, Direction}; +use std::convert::Infallible; +use std::error::Error; +use std::fmt::{Display, Formatter}; +use std::hash::Hash; + +#[derive(Debug)] +pub enum SubstituteNodeWithGraphError { + EdgeMapErr(EME), + NodeFilterErr(NFE), + EdgeWeightTransformErr(ETE), +} + +impl Display for SubstituteNodeWithGraphError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SubstituteNodeWithGraphError::EdgeMapErr(e) => { + write!(f, "Edge map callback failed with: {}", e) + } + SubstituteNodeWithGraphError::NodeFilterErr(e) => { + write!(f, "Node filter callback failed with: {}", e) + } + SubstituteNodeWithGraphError::EdgeWeightTransformErr(e) => { + write!(f, "Edge weight transform callback failed with: {}", e) + } + } + } +} + +impl Error for SubstituteNodeWithGraphError {} + +pub struct NoCallback; + +pub trait NodeFilter { + type Error; + fn enabled(&self) -> bool; + fn filter(&mut self, _g0: &G0, _n0: G0::NodeId) -> Result; +} + +impl NodeFilter for NoCallback { + type Error = Infallible; + #[inline] + fn enabled(&self) -> bool { + false + } + #[inline] + fn filter(&mut self, _g0: &G0, _n0: G0::NodeId) -> Result { + Ok(true) + } +} + +impl NodeFilter for F +where + G0: GraphBase + DataMap, + F: FnMut(&G0::NodeWeight) -> Result, +{ + type Error = E; + #[inline] + fn enabled(&self) -> bool { + true + } + #[inline] + fn filter(&mut self, g0: &G0, n0: G0::NodeId) -> Result { + if let Some(x) = g0.node_weight(n0) { + self(x) + } else { + Ok(false) + } + } +} + +pub trait EdgeWeightMapper { + type Error; + type MappedWeight; + + fn map(&mut self, g: &G, e: G::EdgeId) -> Result; +} + +impl EdgeWeightMapper for NoCallback +where + G::EdgeWeight: Clone, +{ + type Error = Infallible; + type MappedWeight = G::EdgeWeight; + #[inline] + fn map(&mut self, g: &G, e: G::EdgeId) -> Result { + Ok(g.edge_weight(e).unwrap().clone()) + } +} + +impl EdgeWeightMapper for F +where + G0: GraphBase + DataMap, + F: FnMut(&G0::EdgeWeight) -> Result, +{ + type Error = E; + type MappedWeight = EW; + + #[inline] + fn map(&mut self, g0: &G0, e0: G0::EdgeId) -> Result { + if let Some(x) = g0.edge_weight(e0) { + self(x) + } else { + panic!("Edge MUST exist in graph.") + } + } +} +pub trait SubstituteNodeWithGraph: DataMap { + /// The error type returned by the substitution. + type Error: Error; + + /// Substitute a node with a Graph. + /// + /// The specified `node` is replaced with the Graph `other`. + /// + /// To control the + /// + /// :param int node: The node to replace with the PyDiGraph object + /// :param PyDiGraph other: The other graph to replace ``node`` with + /// :param callable edge_map_fn: A callable object that will take 3 position + /// parameters, ``(source, target, weight)`` to represent an edge either to + /// or from ``node`` in this graph. The expected return value from this + /// callable is the node index of the node in ``other`` that an edge should + /// be to/from. If None is returned, that edge will be skipped and not + /// be copied. + /// :param callable node_filter: An optional callable object that when used + /// will receive a node's payload object from ``other`` and return + /// ``True`` if that node is to be included in the graph or not. + /// :param callable edge_weight_map: An optional callable object that when + /// used will receive an edge's weight/data payload from ``other`` and + /// will return an object to use as the weight for a newly created edge + /// after the edge is mapped from ``other``. If not specified the weight + /// from the edge in ``other`` will be copied by reference and used. + /// + /// :returns: A mapping of node indices in ``other`` to the equivalent node + /// in this graph. + /// :rtype: NodeMap + /// + /// .. note:: + /// + /// The return type is a :class:`rustworkx.NodeMap` which is an unordered + /// type. So it does not provide a deterministic ordering between objects + /// when iterated over (although the same object will have a consistent + /// order when iterated over multiple times). + fn substitute_node_with_graph( + &mut self, + node: Self::NodeId, + other: &G1, + edge_map_fn: EM, + node_filter: NF, + edge_weight_map: ET, + ) -> Result, Self::Error> + where + G1: Data + DataMap + NodeCount, + G1::NodeId: Hash + Eq, + G1::NodeWeight: Clone, + for<'a> &'a G1: GraphBase + + Data + + IntoNodeReferences + + IntoEdgeReferences, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, + NF: NodeFilter, + ET: EdgeWeightMapper, + NF::Error: Error, + ET::Error: Error; +} + +impl SubstituteNodeWithGraph for stable_graph::StableGraph +where + Ix: stable_graph::IndexType, + E: Clone, +{ + type Error = SubstituteNodeWithGraphError; + + fn substitute_node_with_graph( + &mut self, + node: Self::NodeId, + other: &G1, + mut edge_map_fn: EM, + mut node_filter: NF, + mut edge_weight_map: ET, + ) -> Result, Self::Error> + where + G1: Data + DataMap + NodeCount, + G1::NodeId: Hash + Eq, + G1::NodeWeight: Clone, + for<'a> &'a G1: GraphBase + + Data + + IntoNodeReferences + + IntoEdgeReferences, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, + NF: NodeFilter, + ET: EdgeWeightMapper, + NF::Error: Error, + ET::Error: Error, + { + let node_index = node; + if self.node_weight(node_index).is_none() { + panic!("Node `node` MUST be present in graph."); + } + // Copy nodes from other to self + let mut out_map: DictMap = + DictMap::with_capacity(other.node_count()); + for node in other.node_references() { + if node_filter.enabled() + && !node_filter + .filter(other, node.id()) + .map_err(|e| SubstituteNodeWithGraphError::NodeFilterErr(e))? + { + continue; + } + let new_index = self.add_node(node.weight().clone()); + out_map.insert(node.id(), new_index); + } + // If no nodes are copied bail here since there is nothing left + // to do. + if out_map.is_empty() { + self.remove_node(node_index); + // Return a new empty map to clear allocation from out_map + return Ok(DictMap::new()); + } + // Copy edges from other to self + for edge in other.edge_references().filter(|edge| { + out_map.contains_key(&edge.target()) && out_map.contains_key(&edge.source()) + }) { + self.add_edge( + out_map[&edge.source()], + out_map[&edge.target()], + edge_weight_map + .map(other, edge.id()) + .map_err(|e| SubstituteNodeWithGraphError::EdgeWeightTransformErr(e))?, + ); + } + // Add edges to/from node to nodes in other + let in_edges: Vec> = self + .edges_directed(node_index, petgraph::Direction::Incoming) + .map(|edge| { + let Some(target_in_other) = + edge_map_fn(Direction::Incoming, edge.source(), edge.weight()) + .map_err(|e| SubstituteNodeWithGraphError::EdgeMapErr(e))? + else { + return Ok(None); + }; + let target_in_self = out_map.get(&target_in_other).unwrap(); + Ok(Some(( + edge.source(), + *target_in_self, + edge.weight().clone(), + ))) + }) + .collect::>()?; + let out_edges: Vec> = self + .edges_directed(node_index, petgraph::Direction::Outgoing) + .map(|edge| { + let Some(source_in_other) = + edge_map_fn(Direction::Outgoing, edge.target(), edge.weight()) + .map_err(|e| SubstituteNodeWithGraphError::EdgeMapErr(e))? + else { + return Ok(None); + }; + let source_in_self = out_map.get(&source_in_other).unwrap(); + Ok(Some(( + *source_in_self, + edge.target(), + edge.weight().clone(), + ))) + }) + .collect::>()?; + for (source, target, weight) in in_edges + .into_iter() + .flatten() + .chain(out_edges.into_iter().flatten()) + { + self.add_edge(source, target, weight); + } + // Remove node + self.remove_node(node_index); + Ok(out_map) + } +} From f4b19d7f5a02e5ae764f8e0d700cdff813317c37 Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Thu, 1 Aug 2024 22:20:29 -0400 Subject: [PATCH 2/5] Use core method from PyDiGraph. --- rustworkx-core/src/graph_ext/substitution.rs | 187 ++++++++++--------- src/digraph.rs | 146 +++++---------- src/lib.rs | 19 +- 3 files changed, 171 insertions(+), 181 deletions(-) diff --git a/rustworkx-core/src/graph_ext/substitution.rs b/rustworkx-core/src/graph_ext/substitution.rs index 48b55a393..4d85eb14e 100644 --- a/rustworkx-core/src/graph_ext/substitution.rs +++ b/rustworkx-core/src/graph_ext/substitution.rs @@ -21,19 +21,25 @@ use petgraph::visit::{ use petgraph::{Directed, Direction}; use std::convert::Infallible; use std::error::Error; -use std::fmt::{Display, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; #[derive(Debug)] -pub enum SubstituteNodeWithGraphError { +pub enum SubstituteNodeWithGraphError { + ReplacementGraphIndexError(N), EdgeMapErr(EME), NodeFilterErr(NFE), EdgeWeightTransformErr(ETE), } -impl Display for SubstituteNodeWithGraphError { +impl Display + for SubstituteNodeWithGraphError +{ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { + SubstituteNodeWithGraphError::ReplacementGraphIndexError(n) => { + write!(f, "Node {:?} was not found in the replacement graph.", n) + } SubstituteNodeWithGraphError::EdgeMapErr(e) => { write!(f, "Edge map callback failed with: {}", e) } @@ -47,24 +53,22 @@ impl Display for SubstituteNodeWithGraphErro } } -impl Error for SubstituteNodeWithGraphError {} +impl Error + for SubstituteNodeWithGraphError +{ +} pub struct NoCallback; -pub trait NodeFilter { +pub trait NodeFilter { type Error; - fn enabled(&self) -> bool; - fn filter(&mut self, _g0: &G0, _n0: G0::NodeId) -> Result; + fn filter(&mut self, graph: &G, node: G::NodeId) -> Result; } -impl NodeFilter for NoCallback { +impl NodeFilter for NoCallback { type Error = Infallible; #[inline] - fn enabled(&self) -> bool { - false - } - #[inline] - fn filter(&mut self, _g0: &G0, _n0: G0::NodeId) -> Result { + fn filter(&mut self, _graph: &G, _node: G::NodeId) -> Result { Ok(true) } } @@ -76,12 +80,8 @@ where { type Error = E; #[inline] - fn enabled(&self) -> bool { - true - } - #[inline] - fn filter(&mut self, g0: &G0, n0: G0::NodeId) -> Result { - if let Some(x) = g0.node_weight(n0) { + fn filter(&mut self, graph: &G0, node: G0::NodeId) -> Result { + if let Some(x) = graph.node_weight(node) { self(x) } else { Ok(false) @@ -92,8 +92,7 @@ where pub trait EdgeWeightMapper { type Error; type MappedWeight; - - fn map(&mut self, g: &G, e: G::EdgeId) -> Result; + fn map(&mut self, graph: &G, edge: G::EdgeId) -> Result; } impl EdgeWeightMapper for NoCallback @@ -103,84 +102,94 @@ where type Error = Infallible; type MappedWeight = G::EdgeWeight; #[inline] - fn map(&mut self, g: &G, e: G::EdgeId) -> Result { - Ok(g.edge_weight(e).unwrap().clone()) + fn map(&mut self, graph: &G, edge: G::EdgeId) -> Result { + Ok(graph.edge_weight(edge).unwrap().clone()) } } -impl EdgeWeightMapper for F +impl EdgeWeightMapper for F where - G0: GraphBase + DataMap, - F: FnMut(&G0::EdgeWeight) -> Result, + G: GraphBase + DataMap, + F: FnMut(&G::EdgeWeight) -> Result, { type Error = E; type MappedWeight = EW; #[inline] - fn map(&mut self, g0: &G0, e0: G0::EdgeId) -> Result { - if let Some(x) = g0.edge_weight(e0) { + fn map(&mut self, graph: &G, edge: G::EdgeId) -> Result { + if let Some(x) = graph.edge_weight(edge) { self(x) } else { panic!("Edge MUST exist in graph.") } } } + pub trait SubstituteNodeWithGraph: DataMap { /// The error type returned by the substitution. - type Error: Error; + type Error: Error; /// Substitute a node with a Graph. /// - /// The specified `node` is replaced with the Graph `other`. + /// The nodes and edges of Graph `other` are cloned into this + /// graph and connected to its preexisting nodes using an edge mapping + /// function, `edge_map_fn`. /// - /// To control the + /// The specified `edge_map_fn` is called for each of the edges between + /// the `node` being replaced and the rest of the graph and is expected + /// to return an index in `other` that the edge should be connected + /// to after the replacement, i.e. the node in `graph` that the edge + /// should be connected to once `node` is gone. It is also acceptable + /// for `edge_map_fn` to return `None`, in which case the edge is + /// ignored and will be dropped. /// - /// :param int node: The node to replace with the PyDiGraph object - /// :param PyDiGraph other: The other graph to replace ``node`` with - /// :param callable edge_map_fn: A callable object that will take 3 position - /// parameters, ``(source, target, weight)`` to represent an edge either to - /// or from ``node`` in this graph. The expected return value from this - /// callable is the node index of the node in ``other`` that an edge should - /// be to/from. If None is returned, that edge will be skipped and not - /// be copied. - /// :param callable node_filter: An optional callable object that when used - /// will receive a node's payload object from ``other`` and return - /// ``True`` if that node is to be included in the graph or not. - /// :param callable edge_weight_map: An optional callable object that when - /// used will receive an edge's weight/data payload from ``other`` and - /// will return an object to use as the weight for a newly created edge - /// after the edge is mapped from ``other``. If not specified the weight - /// from the edge in ``other`` will be copied by reference and used. + /// It accepts the following three arguments: + /// - The [Direction], which designates whether the original edge was + /// incoming or outgoing to `node`. + /// - The [Self::NodeId] of the _other_ node of the original edge (i.e. the + /// one that isn't `node`). + /// - A reference to the edge weight of the original edge. /// - /// :returns: A mapping of node indices in ``other`` to the equivalent node - /// in this graph. - /// :rtype: NodeMap + /// An optional `node_filter` can be provided to ignore nodes in `other` that + /// should not be copied into this graph. This parameter accepts implementations + /// of the trait [NodeFilter], which has a blanket implementation for callables + /// which are `FnMut(&G1::NodeWeight) -> Result`, i.e. functions which + /// take a reference to a node weight in `other` and return a boolean to indicate + /// if the node corresponding to this weight should be included or not. To disable + /// filtering, simply provide [NoCallback]. /// - /// .. note:: + /// A _sometimes_ optional `edge_weight_map` can be provided to transform edge weights from + /// the source graph `other` into weights of this graph. This parameter accepts + /// implementations of the trait [EdgeWeightMapper], which has a blanket + /// implementation for callables which are + /// `F: FnMut(&G1::EdgeWeight) -> Result`, + /// i.e. functions which take a reference to an edge weight in `graph` and return + /// an owned weight typed for this graph. An `edge_weight_map` must be provided + /// when `other` uses a different type for its edge weights, but can otherwise + /// be specified as [NoCallback] to disable mapping. /// - /// The return type is a :class:`rustworkx.NodeMap` which is an unordered - /// type. So it does not provide a deterministic ordering between objects - /// when iterated over (although the same object will have a consistent - /// order when iterated over multiple times). - fn substitute_node_with_graph( + /// This method returns a mapping of nodes in `other` to the copied node in + /// this graph. + #[allow(clippy::type_complexity)] + fn substitute_node_with_graph( &mut self, node: Self::NodeId, - other: &G1, + other: &G, edge_map_fn: EM, node_filter: NF, edge_weight_map: ET, - ) -> Result, Self::Error> + ) -> Result, Self::Error> where - G1: Data + DataMap + NodeCount, - G1::NodeId: Hash + Eq, - G1::NodeWeight: Clone, - for<'a> &'a G1: GraphBase - + Data + G: Data + DataMap + NodeCount, + G::NodeId: Debug + Hash + Eq, + G::NodeWeight: Clone, + for<'a> &'a G: GraphBase + + Data + IntoNodeReferences + IntoEdgeReferences, - EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, - NF: NodeFilter, - ET: EdgeWeightMapper, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, + NF: NodeFilter, + ET: EdgeWeightMapper, NF::Error: Error, ET::Error: Error; } @@ -190,27 +199,28 @@ where Ix: stable_graph::IndexType, E: Clone, { - type Error = SubstituteNodeWithGraphError; + type Error = + SubstituteNodeWithGraphError; - fn substitute_node_with_graph( + fn substitute_node_with_graph( &mut self, node: Self::NodeId, - other: &G1, + other: &G, mut edge_map_fn: EM, mut node_filter: NF, mut edge_weight_map: ET, - ) -> Result, Self::Error> + ) -> Result, Self::Error> where - G1: Data + DataMap + NodeCount, - G1::NodeId: Hash + Eq, - G1::NodeWeight: Clone, - for<'a> &'a G1: GraphBase - + Data + G: Data + DataMap + NodeCount, + G::NodeId: Debug + Hash + Eq, + G::NodeWeight: Clone, + for<'a> &'a G: GraphBase + + Data + IntoNodeReferences + IntoEdgeReferences, - EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, - NF: NodeFilter, - ET: EdgeWeightMapper, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, + NF: NodeFilter, + ET: EdgeWeightMapper, NF::Error: Error, ET::Error: Error, { @@ -219,13 +229,12 @@ where panic!("Node `node` MUST be present in graph."); } // Copy nodes from other to self - let mut out_map: DictMap = + let mut out_map: DictMap = DictMap::with_capacity(other.node_count()); for node in other.node_references() { - if node_filter.enabled() - && !node_filter - .filter(other, node.id()) - .map_err(|e| SubstituteNodeWithGraphError::NodeFilterErr(e))? + if !node_filter + .filter(other, node.id()) + .map_err(|e| SubstituteNodeWithGraphError::NodeFilterErr(e))? { continue; } @@ -261,7 +270,11 @@ where else { return Ok(None); }; - let target_in_self = out_map.get(&target_in_other).unwrap(); + let Some(target_in_self) = out_map.get(&target_in_other) else { + return Err(SubstituteNodeWithGraphError::ReplacementGraphIndexError( + target_in_other, + )); + }; Ok(Some(( edge.source(), *target_in_self, @@ -278,7 +291,11 @@ where else { return Ok(None); }; - let source_in_self = out_map.get(&source_in_other).unwrap(); + let Some(source_in_self) = out_map.get(&source_in_other) else { + return Err(SubstituteNodeWithGraphError::ReplacementGraphIndexError( + source_in_other, + )); + }; Ok(Some(( *source_in_self, edge.target(), diff --git a/src/digraph.rs b/src/digraph.rs index b15b3dea0..88e771e2a 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2550,6 +2550,10 @@ impl PyDiGraph { /// when iterated over (although the same object will have a consistent /// order when iterated over multiple times). /// + /// If the replacement graph ``other`` contains cycles or is not a + /// multigraph, then this graph will also contain cycles or become + /// a multigraph after the substitution. + /// #[pyo3( text_signature = "(self, node, other, edge_map_fn, /, node_filter=None, edge_weight_map=None)" )] @@ -2561,109 +2565,61 @@ impl PyDiGraph { edge_map_fn: PyObject, node_filter: Option, edge_weight_map: Option, - ) -> PyResult { - let weight_map_fn = |obj: &PyObject, weight_fn: &Option| -> PyResult { - match weight_fn { - Some(weight_fn) => weight_fn.call1(py, (obj,)), - None => Ok(obj.clone_ref(py)), - } - }; - let map_fn = |source: usize, target: usize, weight: &PyObject| -> PyResult> { - let res = edge_map_fn.call1(py, (source, target, weight))?; - res.extract(py) + ) -> RxPyResult { + let node_index: NodeIndex = NodeIndex::new(node); + if self.graph.node_weight(node_index).is_none() { + return Err(PyIndexError::new_err(format!( + "Specified node {} is not in this graph", + node + )) + .into()); + } + + let edge_map_fn = |direction: Direction, + node: NodeIndex, + weight: &PyObject| + -> PyResult> { + let edge = match direction { + Direction::Incoming => (node.index(), node_index.index(), weight), + Direction::Outgoing => (node_index.index(), node.index(), weight), + }; + let res = edge_map_fn.call1(py, edge)?; + let index: Option = res.extract(py)?; + Ok(index.map(|i| NodeIndex::new(i))) }; - let filter_fn = |obj: &PyObject, filter_fn: &Option| -> PyResult { - match filter_fn { - Some(filter) => { + + let node_filter = move |obj: &PyObject| -> PyResult { + match node_filter { + Some(ref filter) => { let res = filter.call1(py, (obj,))?; res.extract(py) } None => Ok(true), } }; - let node_index: NodeIndex = NodeIndex::new(node); - if self.graph.node_weight(node_index).is_none() { - return Err(PyIndexError::new_err(format!( - "Specified node {} is not in this graph", - node - ))); - } - // Copy nodes from other to self - let mut out_map: DictMap = DictMap::with_capacity(other.node_count()); - for node in other.graph.node_indices() { - let node_weight = other.graph[node].clone_ref(py); - if !filter_fn(&node_weight, &node_filter)? { - continue; + + let weight_map_fn = move |obj: &PyObject| -> PyResult { + match edge_weight_map { + Some(ref weight_fn) => weight_fn.call1(py, (obj,)), + None => Ok(obj.clone_ref(py)), } - let new_index = self.graph.add_node(node_weight); - out_map.insert(node.index(), new_index.index()); - } - // If no nodes are copied bail here since there is nothing left - // to do. - if out_map.is_empty() { - self.remove_node(node_index.index())?; - // Return a new empty map to clear allocation from out_map - return Ok(NodeMap { - node_map: DictMap::new(), - }); - } - // Copy edges from other to self - for edge in other.graph.edge_references().filter(|edge| { - out_map.contains_key(&edge.target().index()) - && out_map.contains_key(&edge.source().index()) - }) { - self._add_edge( - NodeIndex::new(out_map[&edge.source().index()]), - NodeIndex::new(out_map[&edge.target().index()]), - weight_map_fn(edge.weight(), &edge_weight_map)?, - )?; - } - // Add edges to/from node to nodes in other - let in_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self - .graph - .edges_directed(node_index, petgraph::Direction::Incoming) - .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py))) - .collect(); - let out_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self - .graph - .edges_directed(node_index, petgraph::Direction::Outgoing) - .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py))) - .collect(); - for (source, target, weight) in in_edges { - let old_index = map_fn(source.index(), target.index(), &weight)?; - let target_out = match old_index { - Some(old_index) => match out_map.get(&old_index) { - Some(new_index) => NodeIndex::new(*new_index), - None => { - return Err(PyIndexError::new_err(format!( - "No mapped index {} found", - old_index - ))) - } - }, - None => continue, - }; - self._add_edge(source, target_out, weight)?; - } - for (source, target, weight) in out_edges { - let old_index = map_fn(source.index(), target.index(), &weight)?; - let source_out = match old_index { - Some(old_index) => match out_map.get(&old_index) { - Some(new_index) => NodeIndex::new(*new_index), - None => { - return Err(PyIndexError::new_err(format!( - "No mapped index {} found", - old_index - ))) - } - }, - None => continue, - }; - self._add_edge(source_out, target, weight)?; - } - // Remove node - self.remove_node(node_index.index())?; - Ok(NodeMap { node_map: out_map }) + }; + + let out_map = self.graph.substitute_node_with_graph( + node_index, + &other.graph, + edge_map_fn, + node_filter, + weight_map_fn, + )?; + + self.node_removed = true; + Ok(NodeMap { + node_map: out_map + .into_iter() + .map(|(k, v)| (k.index(), v.index())) + .collect(), + }) } /// Substitute a set of nodes with a single new node. diff --git a/src/lib.rs b/src/lib.rs index 79f183462..a78251d85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,8 +70,8 @@ use hashbrown::HashMap; use numpy::Complex64; use pyo3::create_exception; -use pyo3::exceptions::PyException; use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyException, PyIndexError}; use pyo3::import_exception; use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -88,9 +88,11 @@ use petgraph::EdgeType; use rustworkx_core::dag_algo::TopologicalSortError; use std::convert::TryFrom; +use std::fmt::Debug; use rustworkx_core::dictmap::*; use rustworkx_core::err::{ContractError, ContractSimpleError}; +use rustworkx_core::graph_ext::substitution::SubstituteNodeWithGraphError; /// An ergonomic error type used to map Rustworkx core errors to /// [PyErr] automatically, via [From::from]. @@ -144,6 +146,21 @@ impl From> for RxPyErr { } } +impl From> for RxPyErr { + fn from(value: SubstituteNodeWithGraphError) -> Self { + RxPyErr { + pyerr: match value { + SubstituteNodeWithGraphError::EdgeMapErr(e) + | SubstituteNodeWithGraphError::NodeFilterErr(e) + | SubstituteNodeWithGraphError::EdgeWeightTransformErr(e) => e, + SubstituteNodeWithGraphError::ReplacementGraphIndexError(_) => { + PyIndexError::new_err(format!("{}", value)) + } + }, + } + } +} + impl From> for RxPyErr { fn from(value: TopologicalSortError) -> Self { RxPyErr { From a2f4cafa6847fadef2820b33aee68e13252640de Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Thu, 22 Aug 2024 12:11:53 -0400 Subject: [PATCH 3/5] Update comment. --- rustworkx-core/src/graph_ext/substitution.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustworkx-core/src/graph_ext/substitution.rs b/rustworkx-core/src/graph_ext/substitution.rs index 4d85eb14e..49073290a 100644 --- a/rustworkx-core/src/graph_ext/substitution.rs +++ b/rustworkx-core/src/graph_ext/substitution.rs @@ -10,7 +10,7 @@ // License for the specific language governing permissions and limitations // under the License. -//! This module defines graph traits for node contraction. +//! This module defines graph traits for node substitution. use crate::dictmap::{DictMap, InitWithHasher}; use petgraph::data::DataMap; From aeb8e30665b4e98f6493506730ac564d3086c284 Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Mon, 16 Sep 2024 15:09:15 -0400 Subject: [PATCH 4/5] Combine callback failure error variants. This makes the interface a lot simpler, especially given that most clients are unlikely to need to raise errors in the callbacks at all (we only need it to support Python's error propagation). --- rustworkx-core/src/graph_ext/substitution.rs | 116 +++++++++---------- src/lib.rs | 8 +- 2 files changed, 61 insertions(+), 63 deletions(-) diff --git a/rustworkx-core/src/graph_ext/substitution.rs b/rustworkx-core/src/graph_ext/substitution.rs index 49073290a..467e215d7 100644 --- a/rustworkx-core/src/graph_ext/substitution.rs +++ b/rustworkx-core/src/graph_ext/substitution.rs @@ -25,50 +25,45 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; #[derive(Debug)] -pub enum SubstituteNodeWithGraphError { +pub enum SubstituteNodeWithGraphError { ReplacementGraphIndexError(N), - EdgeMapErr(EME), - NodeFilterErr(NFE), - EdgeWeightTransformErr(ETE), + CallbackError(E), } -impl Display - for SubstituteNodeWithGraphError -{ +impl Display for SubstituteNodeWithGraphError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { SubstituteNodeWithGraphError::ReplacementGraphIndexError(n) => { write!(f, "Node {:?} was not found in the replacement graph.", n) } - SubstituteNodeWithGraphError::EdgeMapErr(e) => { - write!(f, "Edge map callback failed with: {}", e) - } - SubstituteNodeWithGraphError::NodeFilterErr(e) => { - write!(f, "Node filter callback failed with: {}", e) - } - SubstituteNodeWithGraphError::EdgeWeightTransformErr(e) => { - write!(f, "Edge weight transform callback failed with: {}", e) + SubstituteNodeWithGraphError::CallbackError(e) => { + write!(f, "Callback failed with: {}", e) } } } } -impl Error - for SubstituteNodeWithGraphError -{ -} +impl Error for SubstituteNodeWithGraphError {} pub struct NoCallback; pub trait NodeFilter { - type Error; - fn filter(&mut self, graph: &G, node: G::NodeId) -> Result; + type CallbackError; + fn filter( + &mut self, + graph: &G, + node: G::NodeId, + ) -> Result>; } impl NodeFilter for NoCallback { - type Error = Infallible; + type CallbackError = Infallible; #[inline] - fn filter(&mut self, _graph: &G, _node: G::NodeId) -> Result { + fn filter( + &mut self, + _graph: &G, + _node: G::NodeId, + ) -> Result> { Ok(true) } } @@ -78,11 +73,15 @@ where G0: GraphBase + DataMap, F: FnMut(&G0::NodeWeight) -> Result, { - type Error = E; + type CallbackError = E; #[inline] - fn filter(&mut self, graph: &G0, node: G0::NodeId) -> Result { + fn filter( + &mut self, + graph: &G0, + node: G0::NodeId, + ) -> Result> { if let Some(x) = graph.node_weight(node) { - self(x) + self(x).map_err(|e| SubstituteNodeWithGraphError::CallbackError(e)) } else { Ok(false) } @@ -90,19 +89,28 @@ where } pub trait EdgeWeightMapper { - type Error; + type CallbackError; type MappedWeight; - fn map(&mut self, graph: &G, edge: G::EdgeId) -> Result; + fn map( + &mut self, + graph: &G, + edge: G::EdgeId, + ) -> Result>; } impl EdgeWeightMapper for NoCallback where G::EdgeWeight: Clone, { - type Error = Infallible; + type CallbackError = Infallible; type MappedWeight = G::EdgeWeight; #[inline] - fn map(&mut self, graph: &G, edge: G::EdgeId) -> Result { + fn map( + &mut self, + graph: &G, + edge: G::EdgeId, + ) -> Result> + { Ok(graph.edge_weight(edge).unwrap().clone()) } } @@ -112,23 +120,27 @@ where G: GraphBase + DataMap, F: FnMut(&G::EdgeWeight) -> Result, { - type Error = E; + type CallbackError = E; type MappedWeight = EW; #[inline] - fn map(&mut self, graph: &G, edge: G::EdgeId) -> Result { + fn map( + &mut self, + graph: &G, + edge: G::EdgeId, + ) -> Result> + { if let Some(x) = graph.edge_weight(edge) { - self(x) + self(x).map_err(|e| SubstituteNodeWithGraphError::CallbackError(e)) } else { panic!("Edge MUST exist in graph.") } } } -pub trait SubstituteNodeWithGraph: DataMap { - /// The error type returned by the substitution. - type Error: Error; +pub type SubstituteNodeWithGraphResult = Result>; +pub trait SubstituteNodeWithGraph: DataMap { /// Substitute a node with a Graph. /// /// The nodes and edges of Graph `other` are cloned into this @@ -178,7 +190,7 @@ pub trait SubstituteNodeWithGraph: DataMap { edge_map_fn: EM, node_filter: NF, edge_weight_map: ET, - ) -> Result, Self::Error> + ) -> Result, SubstituteNodeWithGraphError> where G: Data + DataMap + NodeCount, G::NodeId: Debug + Hash + Eq, @@ -188,10 +200,8 @@ pub trait SubstituteNodeWithGraph: DataMap { + IntoNodeReferences + IntoEdgeReferences, EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, - NF: NodeFilter, - ET: EdgeWeightMapper, - NF::Error: Error, - ET::Error: Error; + NF: NodeFilter, + ET: EdgeWeightMapper; } impl SubstituteNodeWithGraph for stable_graph::StableGraph @@ -199,9 +209,6 @@ where Ix: stable_graph::IndexType, E: Clone, { - type Error = - SubstituteNodeWithGraphError; - fn substitute_node_with_graph( &mut self, node: Self::NodeId, @@ -209,7 +216,7 @@ where mut edge_map_fn: EM, mut node_filter: NF, mut edge_weight_map: ET, - ) -> Result, Self::Error> + ) -> Result, SubstituteNodeWithGraphError> where G: Data + DataMap + NodeCount, G::NodeId: Debug + Hash + Eq, @@ -219,10 +226,8 @@ where + IntoNodeReferences + IntoEdgeReferences, EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, - NF: NodeFilter, - ET: EdgeWeightMapper, - NF::Error: Error, - ET::Error: Error, + NF: NodeFilter, + ET: EdgeWeightMapper, { let node_index = node; if self.node_weight(node_index).is_none() { @@ -232,10 +237,7 @@ where let mut out_map: DictMap = DictMap::with_capacity(other.node_count()); for node in other.node_references() { - if !node_filter - .filter(other, node.id()) - .map_err(|e| SubstituteNodeWithGraphError::NodeFilterErr(e))? - { + if !node_filter.filter(other, node.id())? { continue; } let new_index = self.add_node(node.weight().clone()); @@ -255,9 +257,7 @@ where self.add_edge( out_map[&edge.source()], out_map[&edge.target()], - edge_weight_map - .map(other, edge.id()) - .map_err(|e| SubstituteNodeWithGraphError::EdgeWeightTransformErr(e))?, + edge_weight_map.map(other, edge.id())?, ); } // Add edges to/from node to nodes in other @@ -266,7 +266,7 @@ where .map(|edge| { let Some(target_in_other) = edge_map_fn(Direction::Incoming, edge.source(), edge.weight()) - .map_err(|e| SubstituteNodeWithGraphError::EdgeMapErr(e))? + .map_err(|e| SubstituteNodeWithGraphError::CallbackError(e))? else { return Ok(None); }; @@ -287,7 +287,7 @@ where .map(|edge| { let Some(source_in_other) = edge_map_fn(Direction::Outgoing, edge.target(), edge.weight()) - .map_err(|e| SubstituteNodeWithGraphError::EdgeMapErr(e))? + .map_err(|e| SubstituteNodeWithGraphError::CallbackError(e))? else { return Ok(None); }; diff --git a/src/lib.rs b/src/lib.rs index a78251d85..8866523a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -146,13 +146,11 @@ impl From> for RxPyErr { } } -impl From> for RxPyErr { - fn from(value: SubstituteNodeWithGraphError) -> Self { +impl From> for RxPyErr { + fn from(value: SubstituteNodeWithGraphError) -> Self { RxPyErr { pyerr: match value { - SubstituteNodeWithGraphError::EdgeMapErr(e) - | SubstituteNodeWithGraphError::NodeFilterErr(e) - | SubstituteNodeWithGraphError::EdgeWeightTransformErr(e) => e, + SubstituteNodeWithGraphError::CallbackError(e) => e, SubstituteNodeWithGraphError::ReplacementGraphIndexError(_) => { PyIndexError::new_err(format!("{}", value)) } From 93097684bb3b7ff6f576eaee1bb43b700b649cc9 Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Mon, 16 Sep 2024 15:21:36 -0400 Subject: [PATCH 5/5] Add custom type alias for return type. --- rustworkx-core/src/graph_ext/substitution.rs | 48 ++++++++++---------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/rustworkx-core/src/graph_ext/substitution.rs b/rustworkx-core/src/graph_ext/substitution.rs index 467e215d7..987de68fc 100644 --- a/rustworkx-core/src/graph_ext/substitution.rs +++ b/rustworkx-core/src/graph_ext/substitution.rs @@ -45,6 +45,8 @@ impl Display for SubstituteNodeWithGraphError { impl Error for SubstituteNodeWithGraphError {} +pub type SubstitutionResult = Result>; + pub struct NoCallback; pub trait NodeFilter { @@ -53,7 +55,7 @@ pub trait NodeFilter { &mut self, graph: &G, node: G::NodeId, - ) -> Result>; + ) -> SubstitutionResult; } impl NodeFilter for NoCallback { @@ -63,23 +65,23 @@ impl NodeFilter for NoCallback { &mut self, _graph: &G, _node: G::NodeId, - ) -> Result> { + ) -> SubstitutionResult { Ok(true) } } -impl NodeFilter for F +impl NodeFilter for F where - G0: GraphBase + DataMap, - F: FnMut(&G0::NodeWeight) -> Result, + G: GraphBase + DataMap, + F: FnMut(&G::NodeWeight) -> Result, { type CallbackError = E; #[inline] fn filter( &mut self, - graph: &G0, - node: G0::NodeId, - ) -> Result> { + graph: &G, + node: G::NodeId, + ) -> SubstitutionResult { if let Some(x) = graph.node_weight(node) { self(x).map_err(|e| SubstituteNodeWithGraphError::CallbackError(e)) } else { @@ -95,7 +97,7 @@ pub trait EdgeWeightMapper { &mut self, graph: &G, edge: G::EdgeId, - ) -> Result>; + ) -> SubstitutionResult; } impl EdgeWeightMapper for NoCallback @@ -109,8 +111,7 @@ where &mut self, graph: &G, edge: G::EdgeId, - ) -> Result> - { + ) -> SubstitutionResult { Ok(graph.edge_weight(edge).unwrap().clone()) } } @@ -128,8 +129,7 @@ where &mut self, graph: &G, edge: G::EdgeId, - ) -> Result> - { + ) -> SubstitutionResult { if let Some(x) = graph.edge_weight(edge) { self(x).map_err(|e| SubstituteNodeWithGraphError::CallbackError(e)) } else { @@ -138,8 +138,6 @@ where } } -pub type SubstituteNodeWithGraphResult = Result>; - pub trait SubstituteNodeWithGraph: DataMap { /// Substitute a node with a Graph. /// @@ -183,14 +181,14 @@ pub trait SubstituteNodeWithGraph: DataMap { /// This method returns a mapping of nodes in `other` to the copied node in /// this graph. #[allow(clippy::type_complexity)] - fn substitute_node_with_graph( + fn substitute_node_with_graph( &mut self, node: Self::NodeId, other: &G, edge_map_fn: EM, node_filter: NF, edge_weight_map: ET, - ) -> Result, SubstituteNodeWithGraphError> + ) -> SubstitutionResult, G::NodeId, E> where G: Data + DataMap + NodeCount, G::NodeId: Debug + Hash + Eq, @@ -199,9 +197,9 @@ pub trait SubstituteNodeWithGraph: DataMap { + Data + IntoNodeReferences + IntoEdgeReferences, - EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, - NF: NodeFilter, - ET: EdgeWeightMapper; + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, E>, + NF: NodeFilter, + ET: EdgeWeightMapper; } impl SubstituteNodeWithGraph for stable_graph::StableGraph @@ -209,14 +207,14 @@ where Ix: stable_graph::IndexType, E: Clone, { - fn substitute_node_with_graph( + fn substitute_node_with_graph( &mut self, node: Self::NodeId, other: &G, mut edge_map_fn: EM, mut node_filter: NF, mut edge_weight_map: ET, - ) -> Result, SubstituteNodeWithGraphError> + ) -> SubstitutionResult, G::NodeId, ER> where G: Data + DataMap + NodeCount, G::NodeId: Debug + Hash + Eq, @@ -225,9 +223,9 @@ where + Data + IntoNodeReferences + IntoEdgeReferences, - EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, EME>, - NF: NodeFilter, - ET: EdgeWeightMapper, + EM: FnMut(Direction, Self::NodeId, &Self::EdgeWeight) -> Result, ER>, + NF: NodeFilter, + ET: EdgeWeightMapper, { let node_index = node; if self.node_weight(node_index).is_none() {