Skip to content

Commit

Permalink
set_sequence working properly
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 25, 2024
1 parent db5910c commit c76d108
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 50 deletions.
52 changes: 40 additions & 12 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashMap;
use crate::updates::observations::observation_update;
use crate::utils::get_update_sequence;
use crate::{updates::observations::observation_update, utils::function_pointer::FnType};
use crate::utils::set_sequence::set_update_sequence;
use crate::utils::function_pointer::get_func_map;
use pyo3::types::PyTuple;
use pyo3::{prelude::*, types::{PyList, PyDict}};

#[derive(Debug)]
Expand Down Expand Up @@ -39,9 +41,6 @@ pub enum Node {
Exponential(ExponentialFamiliyStateNode),
}

// Create a default signature for update functions
pub type FnType = fn(&mut Network, usize);

#[derive(Debug)]
pub struct UpdateSequence {
pub predictions: Vec<(usize, FnType)>,
Expand Down Expand Up @@ -79,10 +78,10 @@ impl Network {
/// * `value_children` - The indexes of the node's value children.
/// * `volatility_children` - The indexes of the node's volatility children.
/// * `volatility_parents` - The indexes of the node's volatility parents.
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))]
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))]
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<Vec<usize>>,
value_children: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>,
volatility_parents: Option<Vec<usize>>) {
value_children: Option<Vec<usize>>,
volatility_parents: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>, ) {

// the node ID is equal to the number of nodes already in the network
let node_id: usize = self.nodes.len();
Expand All @@ -93,8 +92,8 @@ impl Network {
}

let edges = AdjacencyLists{
value_parents: value_children,
value_children: value_parents,
value_parents: value_parents,
value_children: value_children,
volatility_parents: volatility_parents,
volatility_children: volatility_children,
};
Expand Down Expand Up @@ -122,8 +121,8 @@ impl Network {
}
}

pub fn get_update_sequence(&mut self) {
self.update_sequence = get_update_sequence(self);
pub fn set_update_sequence(&mut self) {
self.update_sequence = set_update_sequence(self);
}

/// Single time slice belief propagation.
Expand Down Expand Up @@ -189,6 +188,35 @@ impl Network {
}
Ok(py_list)
}

#[getter]
pub fn get_update_sequence<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {

let func_map = get_func_map();
let py_list = PyList::empty(py);
// Iterate over the Rust vector and convert each tuple
for &(num, func) in self.update_sequence.predictions.iter() {
// Retrieve the function name from the map
let func_name = func_map.get(&func).unwrap_or(&"unknown");

// Convert the Rust tuple to a Python tuple with the function name as a string
let py_tuple = PyTuple::new(py, &[num.into_py(py), (*func_name).into_py(py)]);

// Append the Python tuple to the Python list
py_list.append(py_tuple)?;
}
for &(num, func) in self.update_sequence.updates.iter() {
// Retrieve the function name from the map
let func_name = func_map.get(&func).unwrap_or(&"unknown");

// Convert the Rust tuple to a Python tuple with the function name as a string
let py_tuple = PyTuple::new(py, &[num.into_py(py), (*func_name).into_py(py)]);

// Append the Python tuple to the Python list
py_list.append(py_tuple)?;
}
Ok(py_list)
}
}

// Create a module to expose the class to Python
Expand Down
2 changes: 1 addition & 1 deletion src/updates/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod posterior;
pub mod prediction;
pub mod prediction_error;
pub mod observations;
pub mod observations;
3 changes: 1 addition & 2 deletions src/updates/posterior/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
pub mod continuous;
pub mod exponential;
pub mod continuous;
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::math::sufficient_statistics;
///
/// # Returns
/// * `network` - The network after message passing.
pub fn posterior_update_exponential_state_node(network: &mut Network, node_idx: usize) {
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {

match network.nodes.get_mut(&node_idx) {
Some(Node::Exponential(ref mut node)) => {
Expand Down
1 change: 1 addition & 0 deletions src/updates/prediction_error/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod continuous;
pub mod exponential;
18 changes: 18 additions & 0 deletions src/utils/function_pointer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use std::collections::HashMap;

use crate::{model::Network, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}};

// Create a default signature for update functions
pub type FnType = for<'a> fn(&'a mut Network, usize);

pub fn get_func_map() -> HashMap<FnType, &'static str> {
let function_map: HashMap<FnType, &str> = [
(posterior_update_continuous_state_node as FnType, "posterior_update_continuous_state_node"),
(prediction_continuous_state_node as FnType, "prediction_continuous_state_node"),
(prediction_error_continuous_state_node as FnType, "prediction_error_continuous_state_node"),
(prediction_error_exponential_state_node as FnType, "prediction_error_exponential_state_node"),
]
.into_iter()
.collect();
function_map
}
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod set_sequence;
pub mod function_pointer;
98 changes: 64 additions & 34 deletions src/utils.rs → src/utils/set_sequence.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{model::{FnType, Network, Node, UpdateSequence}, updates::{posterior::{continuous::posterior_update_continuous_state_node, exponential::posterior_update_exponential_state_node}, prediction::continuous::prediction_continuous_state_node, prediction_error::continuous::prediction_error_continuous_state_node}};
use crate::{model::{Network, Node, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}};
use crate::utils::function_pointer::FnType;

pub fn get_update_sequence(network: &Network) -> UpdateSequence {
pub fn set_update_sequence(network: &Network) -> UpdateSequence {
let predictions = get_predictions_sequence(network);
let updates = get_updates_sequence(network);

Expand Down Expand Up @@ -40,8 +41,8 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> {
// If both are Some, merge the vectors
(Some(ref vec1), Some(ref vec2)) => {
// Create a new vector by merging the two
let merged_vec: Vec<usize> = vec1.iter().chain(vec2.iter()).cloned().collect();
Some(merged_vec) // Return the merged vector wrapped in Some
let vec: Vec<usize> = vec1.iter().chain(vec2.iter()).cloned().collect();
Some(vec) // Return the merged vector wrapped in Some
}
// If one is Some and the other is None, return the one that's Some
(Some(vec), None) | (None, Some(vec)) => Some(vec.clone()),
Expand Down Expand Up @@ -99,40 +100,58 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> {
po_nodes_idxs.retain(|x| !network.inputs.contains(x));

// iterate over all nodes and add the prediction step if all criteria are met
let mut n_remaining = 2 * pe_nodes_idxs.len(); // posterior updates + prediction errors
let mut n_remaining = po_nodes_idxs.len() + pe_nodes_idxs.len(); // posterior updates + prediction errors

while n_remaining > 0 {

// were we able to add an update step in the list on that iteration?
let mut has_update = false;

// loop over all the remaining nodes for prediction errors ---------------------
// -----------------------------------------------------------------------------
for i in 0..pe_nodes_idxs.len() {

let idx = pe_nodes_idxs[i];

// to send a prediction error, this node should have been updated first
if !(po_nodes_idxs.contains(&idx)) {

// only send a prediction error if this node has any parent
let value_parents_idxs = &network.edges[idx].value_parents;
let volatility_parents_idxs = &network.edges[idx].volatility_parents;

let has_parents = match (value_parents_idxs, volatility_parents_idxs) {
// If both are None, return false
(None, None) => false,
_ => true,
};

// add the node in the update list
match network.nodes.get(&idx) {
Some(Node::Continuous(_)) => {
match (network.nodes.get(&idx), has_parents) {
(Some(Node::Continuous(_)), true) => {
updates.push((idx, prediction_error_continuous_state_node));
// remove the node from the to-be-updated list
pe_nodes_idxs.retain(|&x| x != idx);
n_remaining -= 1;
has_update = true;
break;
}
Some(Node::Exponential(_)) => (),
None => ()
(Some(Node::Exponential(_)), _) => {
updates.push((idx, prediction_error_exponential_state_node));
// remove the node from the to-be-updated list
pe_nodes_idxs.retain(|&x| x != idx);
n_remaining -= 1;
has_update = true;
break;
}
_ => ()

}

// remove the node from the to-be-updated list
pe_nodes_idxs.retain(|&x| x != idx);
n_remaining -= 1;
has_update = true;
break;
}
}

// loop over all the remaining nodes for posterior updates ---------------------
// -----------------------------------------------------------------------------
for i in 0..po_nodes_idxs.len() {

let idx = po_nodes_idxs[i];
Expand Down Expand Up @@ -163,17 +182,14 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> {
};

// 3. if false, add the posterior update to the list
if !(missing_pe) {
if !missing_pe {

// add the node in the update list
match network.nodes.get(&idx) {
Some(Node::Continuous(_)) => {
updates.push((idx, posterior_update_continuous_state_node));
}
Some(Node::Exponential(_)) => {
updates.push((idx, posterior_update_exponential_state_node));
}
None => ()
_ => ()

}

Expand All @@ -183,9 +199,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> {
has_update = true;
break;
}
}
// 2. get update sequence ----------------------------------------------------------

}
if !(has_update) {
break;
}
Expand All @@ -197,46 +211,62 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> {
// Tests module for unit tests
#[cfg(test)] // Only compile and include this module when running tests
mod tests {
use crate::utils::function_pointer::get_func_map;

use super::*; // Import the parent module's items to test them

#[test]
fn test_get_update_order() {

let func_map = get_func_map();

// initialize network
let mut network = Network::new();
let mut hgf_network = Network::new();

// create a network
network.add_nodes(
hgf_network.add_nodes(
"continuous-state",
Some(vec![1]),
None,
None,
Some(vec![2]),
None,
);
network.add_nodes(
hgf_network.add_nodes(
"continuous-state",
None,
Some(vec![0]),
None,
None,
);
network.add_nodes(
hgf_network.add_nodes(
"continuous-state",
None,
None,
Some(vec![0]),
None,
Some(vec![0]),
);
network.add_nodes(
"exponential-node",
hgf_network.set_update_sequence();

println!("Prediction sequence ----------");
println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[0].0, func_map.get(&hgf_network.update_sequence.predictions[0].1).unwrap_or(&"unknown"));
println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[1].0, func_map.get(&hgf_network.update_sequence.predictions[1].1).unwrap_or(&"unknown"));
println!("Node: {} - Function name: {}", &hgf_network.update_sequence.predictions[2].0, func_map.get(&hgf_network.update_sequence.predictions[2].1).unwrap_or(&"unknown"));
println!("Update sequence ----------");
println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[0].0, func_map.get(&hgf_network.update_sequence.updates[0].1).unwrap_or(&"unknown"));
println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[1].0, func_map.get(&hgf_network.update_sequence.updates[1].1).unwrap_or(&"unknown"));
println!("Node: {} - Function name: {}", &hgf_network.update_sequence.updates[2].0, func_map.get(&hgf_network.update_sequence.updates[2].1).unwrap_or(&"unknown"));

// initialize network
let mut exp_network = Network::new();
exp_network.add_nodes(
"exponential-state",
None,
None,
None,
None,
);
exp_network.set_update_sequence();
println!("Node: {} - Function name: {}", &exp_network.update_sequence.updates[0].0, func_map.get(&exp_network.update_sequence.updates[0].1).unwrap_or(&"unknown"));

println!("Network: {:?}", network);
println!("Update order: {:?}", get_update_sequence(&network));

}
}

0 comments on commit c76d108

Please sign in to comment.