From c76d10818ce2122dad12b354980d7d315ef40213 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Fri, 25 Oct 2024 12:36:28 +0200 Subject: [PATCH] set_sequence working properly --- src/model.rs | 52 +++++++--- src/updates/mod.rs | 2 +- src/updates/posterior/mod.rs | 3 +- .../exponential.rs | 2 +- src/updates/prediction_error/mod.rs | 1 + src/utils/function_pointer.rs | 18 ++++ src/utils/mod.rs | 2 + src/{utils.rs => utils/set_sequence.rs} | 98 ++++++++++++------- 8 files changed, 128 insertions(+), 50 deletions(-) rename src/updates/{posterior => prediction_error}/exponential.rs (91%) create mode 100644 src/utils/function_pointer.rs create mode 100644 src/utils/mod.rs rename src/{utils.rs => utils/set_sequence.rs} (65%) diff --git a/src/model.rs b/src/model.rs index 2b5845f8..5f23e09e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; -use crate::updates::observations::observation_update; -use crate::utils::get_update_sequence; +use crate::{updates::observations::observation_update, utils::function_pointer::FnType}; +use crate::utils::set_sequence::set_update_sequence; +use crate::utils::function_pointer::get_func_map; +use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; #[derive(Debug)] @@ -39,9 +41,6 @@ pub enum Node { Exponential(ExponentialFamiliyStateNode), } -// Create a default signature for update functions -pub type FnType = fn(&mut Network, usize); - #[derive(Debug)] pub struct UpdateSequence { pub predictions: Vec<(usize, FnType)>, @@ -79,10 +78,10 @@ impl Network { /// * `value_children` - The indexes of the node's value children. /// * `volatility_children` - The indexes of the node's volatility children. /// * `volatility_parents` - The indexes of the node's volatility parents. - #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] + #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))] pub fn add_nodes(&mut self, kind: &str, value_parents: Option>, - value_children: Option>, volatility_children: Option>, - volatility_parents: Option>) { + value_children: Option>, + volatility_parents: Option>, volatility_children: Option>, ) { // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); @@ -93,8 +92,8 @@ impl Network { } let edges = AdjacencyLists{ - value_parents: value_children, - value_children: value_parents, + value_parents: value_parents, + value_children: value_children, volatility_parents: volatility_parents, volatility_children: volatility_children, }; @@ -122,8 +121,8 @@ impl Network { } } - pub fn get_update_sequence(&mut self) { - self.update_sequence = get_update_sequence(self); + pub fn set_update_sequence(&mut self) { + self.update_sequence = set_update_sequence(self); } /// Single time slice belief propagation. @@ -189,6 +188,35 @@ impl Network { } Ok(py_list) } + + #[getter] + pub fn get_update_sequence<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + + let func_map = get_func_map(); + let py_list = PyList::empty(py); + // Iterate over the Rust vector and convert each tuple + for &(num, func) in self.update_sequence.predictions.iter() { + // Retrieve the function name from the map + let func_name = func_map.get(&func).unwrap_or(&"unknown"); + + // Convert the Rust tuple to a Python tuple with the function name as a string + let py_tuple = PyTuple::new(py, &[num.into_py(py), (*func_name).into_py(py)]); + + // Append the Python tuple to the Python list + py_list.append(py_tuple)?; + } + for &(num, func) in self.update_sequence.updates.iter() { + // Retrieve the function name from the map + let func_name = func_map.get(&func).unwrap_or(&"unknown"); + + // Convert the Rust tuple to a Python tuple with the function name as a string + let py_tuple = PyTuple::new(py, &[num.into_py(py), (*func_name).into_py(py)]); + + // Append the Python tuple to the Python list + py_list.append(py_tuple)?; + } + Ok(py_list) + } } // Create a module to expose the class to Python diff --git a/src/updates/mod.rs b/src/updates/mod.rs index 28479cc7..ad43beac 100644 --- a/src/updates/mod.rs +++ b/src/updates/mod.rs @@ -1,4 +1,4 @@ pub mod posterior; pub mod prediction; pub mod prediction_error; -pub mod observations; \ No newline at end of file +pub mod observations; diff --git a/src/updates/posterior/mod.rs b/src/updates/posterior/mod.rs index d9281618..6817d49e 100644 --- a/src/updates/posterior/mod.rs +++ b/src/updates/posterior/mod.rs @@ -1,2 +1 @@ -pub mod continuous; -pub mod exponential; \ No newline at end of file +pub mod continuous; \ No newline at end of file diff --git a/src/updates/posterior/exponential.rs b/src/updates/prediction_error/exponential.rs similarity index 91% rename from src/updates/posterior/exponential.rs rename to src/updates/prediction_error/exponential.rs index 20dcf2d1..35695b9c 100644 --- a/src/updates/posterior/exponential.rs +++ b/src/updates/prediction_error/exponential.rs @@ -9,7 +9,7 @@ use crate::math::sufficient_statistics; /// /// # Returns /// * `network` - The network after message passing. -pub fn posterior_update_exponential_state_node(network: &mut Network, node_idx: usize) { +pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) { match network.nodes.get_mut(&node_idx) { Some(Node::Exponential(ref mut node)) => { diff --git a/src/updates/prediction_error/mod.rs b/src/updates/prediction_error/mod.rs index c53f4e7a..810699ed 100644 --- a/src/updates/prediction_error/mod.rs +++ b/src/updates/prediction_error/mod.rs @@ -1 +1,2 @@ pub mod continuous; +pub mod exponential; diff --git a/src/utils/function_pointer.rs b/src/utils/function_pointer.rs new file mode 100644 index 00000000..163c12d1 --- /dev/null +++ b/src/utils/function_pointer.rs @@ -0,0 +1,18 @@ +use std::collections::HashMap; + +use crate::{model::Network, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; + +// Create a default signature for update functions +pub type FnType = for<'a> fn(&'a mut Network, usize); + +pub fn get_func_map() -> HashMap { + let function_map: HashMap = [ + (posterior_update_continuous_state_node as FnType, "posterior_update_continuous_state_node"), + (prediction_continuous_state_node as FnType, "prediction_continuous_state_node"), + (prediction_error_continuous_state_node as FnType, "prediction_error_continuous_state_node"), + (prediction_error_exponential_state_node as FnType, "prediction_error_exponential_state_node"), + ] + .into_iter() + .collect(); + function_map +} \ No newline at end of file diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 00000000..1918ac54 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod set_sequence; +pub mod function_pointer; \ No newline at end of file diff --git a/src/utils.rs b/src/utils/set_sequence.rs similarity index 65% rename from src/utils.rs rename to src/utils/set_sequence.rs index 178d1ba6..2a130980 100644 --- a/src/utils.rs +++ b/src/utils/set_sequence.rs @@ -1,6 +1,7 @@ -use crate::{model::{FnType, Network, Node, UpdateSequence}, updates::{posterior::{continuous::posterior_update_continuous_state_node, exponential::posterior_update_exponential_state_node}, prediction::continuous::prediction_continuous_state_node, prediction_error::continuous::prediction_error_continuous_state_node}}; +use crate::{model::{Network, Node, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; +use crate::utils::function_pointer::FnType; -pub fn get_update_sequence(network: &Network) -> UpdateSequence { +pub fn set_update_sequence(network: &Network) -> UpdateSequence { let predictions = get_predictions_sequence(network); let updates = get_updates_sequence(network); @@ -40,8 +41,8 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { // If both are Some, merge the vectors (Some(ref vec1), Some(ref vec2)) => { // Create a new vector by merging the two - let merged_vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); - Some(merged_vec) // Return the merged vector wrapped in Some + let vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); + Some(vec) // Return the merged vector wrapped in Some } // If one is Some and the other is None, return the one that's Some (Some(vec), None) | (None, Some(vec)) => Some(vec.clone()), @@ -99,7 +100,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { po_nodes_idxs.retain(|x| !network.inputs.contains(x)); // iterate over all nodes and add the prediction step if all criteria are met - let mut n_remaining = 2 * pe_nodes_idxs.len(); // posterior updates + prediction errors + let mut n_remaining = po_nodes_idxs.len() + pe_nodes_idxs.len(); // posterior updates + prediction errors while n_remaining > 0 { @@ -107,6 +108,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { let mut has_update = false; // loop over all the remaining nodes for prediction errors --------------------- + // ----------------------------------------------------------------------------- for i in 0..pe_nodes_idxs.len() { let idx = pe_nodes_idxs[i]; @@ -114,25 +116,42 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // to send a prediction error, this node should have been updated first if !(po_nodes_idxs.contains(&idx)) { + // only send a prediction error if this node has any parent + let value_parents_idxs = &network.edges[idx].value_parents; + let volatility_parents_idxs = &network.edges[idx].volatility_parents; + + let has_parents = match (value_parents_idxs, volatility_parents_idxs) { + // If both are None, return false + (None, None) => false, + _ => true, + }; + // add the node in the update list - match network.nodes.get(&idx) { - Some(Node::Continuous(_)) => { + match (network.nodes.get(&idx), has_parents) { + (Some(Node::Continuous(_)), true) => { updates.push((idx, prediction_error_continuous_state_node)); + // remove the node from the to-be-updated list + pe_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; } - Some(Node::Exponential(_)) => (), - None => () + (Some(Node::Exponential(_)), _) => { + updates.push((idx, prediction_error_exponential_state_node)); + // remove the node from the to-be-updated list + pe_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; + } + _ => () } - - // remove the node from the to-be-updated list - pe_nodes_idxs.retain(|&x| x != idx); - n_remaining -= 1; - has_update = true; - break; } } // loop over all the remaining nodes for posterior updates --------------------- + // ----------------------------------------------------------------------------- for i in 0..po_nodes_idxs.len() { let idx = po_nodes_idxs[i]; @@ -163,17 +182,14 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { }; // 3. if false, add the posterior update to the list - if !(missing_pe) { + if !missing_pe { // add the node in the update list match network.nodes.get(&idx) { Some(Node::Continuous(_)) => { updates.push((idx, posterior_update_continuous_state_node)); } - Some(Node::Exponential(_)) => { - updates.push((idx, posterior_update_exponential_state_node)); - } - None => () + _ => () } @@ -183,9 +199,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { has_update = true; break; } - } - // 2. get update sequence ---------------------------------------------------------- - + } if !(has_update) { break; } @@ -197,46 +211,62 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // Tests module for unit tests #[cfg(test)] // Only compile and include this module when running tests mod tests { + use crate::utils::function_pointer::get_func_map; + use super::*; // Import the parent module's items to test them #[test] fn test_get_update_order() { + let func_map = get_func_map(); + // initialize network - let mut network = Network::new(); + let mut hgf_network = Network::new(); // create a network - network.add_nodes( + hgf_network.add_nodes( "continuous-state", Some(vec![1]), None, - None, Some(vec![2]), + None, ); - network.add_nodes( + hgf_network.add_nodes( "continuous-state", None, Some(vec![0]), None, None, ); - network.add_nodes( + hgf_network.add_nodes( "continuous-state", None, None, - Some(vec![0]), None, + Some(vec![0]), ); - network.add_nodes( - "exponential-node", + hgf_network.set_update_sequence(); + + println!("Prediction sequence ----------"); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[0].0, func_map.get(&hgf_network.update_sequence.predictions[0].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[1].0, func_map.get(&hgf_network.update_sequence.predictions[1].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[2].0, func_map.get(&hgf_network.update_sequence.predictions[2].1).unwrap_or(&"unknown")); + println!("Update sequence ----------"); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[0].0, func_map.get(&hgf_network.update_sequence.updates[0].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[1].0, func_map.get(&hgf_network.update_sequence.updates[1].1).unwrap_or(&"unknown")); + println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[2].0, func_map.get(&hgf_network.update_sequence.updates[2].1).unwrap_or(&"unknown")); + + // initialize network + let mut exp_network = Network::new(); + exp_network.add_nodes( + "exponential-state", None, None, None, None, ); + exp_network.set_update_sequence(); + println!("Node: {} - Function name: {}", &exp_network.update_sequence.updates[0].0, func_map.get(&exp_network.update_sequence.updates[0].1).unwrap_or(&"unknown")); - println!("Network: {:?}", network); - println!("Update order: {:?}", get_update_sequence(&network)); - } }