Skip to content

Commit

Permalink
exponential 1d node working
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 29, 2024
1 parent d46529d commit 8c5b0ca
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 56 deletions.
38 changes: 24 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ path = "src/lib.rs" # The source file of the target.

[dependencies]
pyo3 = { version = "0.21.2", features = ["extension-module"] }
ndarray = "0.16.1"
numpy = "0.21"
20 changes: 10 additions & 10 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,6 @@ Continuous nodes
continuous_node_posterior_update
continuous_node_posterior_update_ehgf

Exponential family
------------------

.. currentmodule:: pyhgf.updates.posterior.exponential

.. autosummary::
:toctree: generated/pyhgf.updates.posterior.exponential

posterior_update_exponential_family

Prediction steps
================

Expand Down Expand Up @@ -146,6 +136,16 @@ Dirichlet state nodes
likely_cluster_proposal
clusters_likelihood

Exponential family
^^^^^^^^^^^^^^^^^^

.. currentmodule:: pyhgf.updates.prediction_error.exponential

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.exponential

prediction_error_update_exponential_family

Distribution
************

Expand Down
4 changes: 2 additions & 2 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@
"\n",
"### Using a fixed $\\nu$\n",
"\n",
"This operation can be achieved using a continuous state node that implements the exponential family updates on the values that are passed by the value child nodes. Such nodes are referred to as `ef-` nodes, with the type of distribution (here a simple one-dimensional Gaussian distribution, therefore the kind is set to `\"ef-normal\"`). The input node is set to generic, which means that this input simply passes the observed value to the value parents without any additional computation. We can define such a model as follows:"
"This operation can be achieved using a continuous state node that implements the exponential family updates on the values that are passed by the value child nodes. Such nodes are referred to as `exponential-state` nodes, with the type of distribution (here a simple one-dimensional Gaussian distribution). The input node is set to generic, which means that this input simply passes the observed value to the value parents without any additional computation. We can define such a model as follows:"
]
},
{
Expand All @@ -340,7 +340,7 @@
"generalised_filter = (\n",
" Network()\n",
" .add_nodes(kind=\"generic-state\")\n",
" .add_nodes(kind=\"ef-normal\", value_children=0, xis=np.array([0, 1 / 8]))\n",
" .add_nodes(kind=\"exponential-state\", value_children=0, xis=np.array([0, 1 / 8]))\n",
")"
]
},
Expand Down
8 changes: 4 additions & 4 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ def add_nodes(
raise ValueError(
(
"Invalid node type. Should be one of the following: "
"'DP-state', 'continuous-state', 'binary-state', 'ef-normal'."
"'generic-state' or 'categorical-state'"
"'DP-state', 'continuous-state', 'binary-state', "
"'exponential-state', 'generic-state' or 'categorical-state'"
)
)

Expand Down Expand Up @@ -473,7 +473,7 @@ def add_nodes(
"nus": 3.0,
"xis": jnp.array([0.0, 1.0]),
"mean": 0.0,
"observed": 1.0,
"observed": 1,
}
elif kind == "categorical-state":
if "n_categories" in node_parameters:
Expand Down Expand Up @@ -562,7 +562,7 @@ def add_nodes(
node_type = 1
elif kind == "continuous-state":
node_type = 2
elif kind == "ef-normal":
elif kind == "exponential-state":
node_type = 3
elif kind == "DP-state":
node_type = 4
Expand Down
2 changes: 1 addition & 1 deletion pyhgf/updates/prediction/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def dirichlet_node_prediction(
Static parameters of the Dirichlet process node.
"""
# get the parameter (mean and variance) from the EF-normal parent nodes
# get the parameter (mean and variance) from the exponential state parent nodes
value_parent_idxs = edges[node_idx].value_parents
if value_parent_idxs is not None:
parameters = jnp.array(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def posterior_update_exponential_family(
def prediction_error_update_exponential_family(
attributes: Dict, edges: Edges, node_idx: int, sufficient_stats_fn: Callable, **args
) -> Attributes:
r"""Update the parameters of an exponential family distribution.
Expand Down
31 changes: 17 additions & 14 deletions pyhgf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.posterior.exponential import posterior_update_exponential_family
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
Expand All @@ -28,6 +27,9 @@
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)
from pyhgf.updates.prediction_error.generic import generic_state_prediction_error

if TYPE_CHECKING:
Expand Down Expand Up @@ -374,16 +376,6 @@ def get_update_sequence(
elif update_type == "standard":
update_fn = continuous_node_posterior_update

elif network.edges[idx].node_type == 3:

# create the sufficient statistic function
# for the exponential family node
ef_update = Partial(
posterior_update_exponential_family,
sufficient_stats_fn=Normal().sufficient_statistics,
)
update_fn = ef_update

elif network.edges[idx].node_type == 4:

update_fn = None
Expand All @@ -407,8 +399,21 @@ def get_update_sequence(
]

# if this node has no parent, no need to compute prediction errors
# unless this is an exponential family state node
if len(all_parents) == 0:
nodes_without_prediction_error.remove(idx)
if network.edges[idx].node_type == 3:
# create the sufficient statistic function
# for the exponential family node
ef_update = Partial(
prediction_error_update_exponential_family,
sufficient_stats_fn=Normal().sufficient_statistics,
)
update_fn = ef_update
no_update = False
update_sequence.append((idx, update_fn))
nodes_without_prediction_error.remove(idx)
else:
nodes_without_prediction_error.remove(idx)
else:
# if this node has been updated
if idx not in nodes_without_posterior_update:
Expand All @@ -419,8 +424,6 @@ def get_update_sequence(
update_fn = binary_state_node_prediction_error
elif network.edges[idx].node_type == 2:
update_fn = continuous_node_prediction_error
elif network.edges[idx].node_type == 3:
update_fn = None
elif network.edges[idx].node_type == 4:
update_fn = dirichlet_node_prediction_error
elif network.edges[idx].node_type == 5:
Expand Down
42 changes: 38 additions & 4 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::utils::set_sequence::set_update_sequence;
use crate::utils::function_pointer::get_func_map;
use pyo3::types::PyTuple;
use pyo3::{prelude::*, types::{PyList, PyDict}};
use ndarray::{Array2, Axis, stack};
use numpy::{PyArray1, PyArray};

#[derive(Debug)]
#[pyclass]
Expand Down Expand Up @@ -165,6 +165,8 @@ impl Network {

// initialize the belief trajectories result struture
let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()};

// add empty vectors in the floats hashmap
for (node_idx, node) in &self.attributes.floats {
let new_map: HashMap<String, Vec<f64>> = HashMap::new();
node_trajectories.floats.insert(*node_idx, new_map);
Expand All @@ -174,13 +176,25 @@ impl Network {
}
}
}
// add empty vectors in the vectors hashmap
for (node_idx, node) in &self.attributes.vectors {
let new_map: HashMap<String, Vec<Vec<f64>>> = HashMap::new();
node_trajectories.vectors.insert(*node_idx, new_map);
if let Some(attr) = node_trajectories.vectors.get_mut(node_idx) {
for key in node.keys() {
attr.insert(key.clone(), Vec::new());
}
}
}

// iterate over the observations
for observation in input_data {

// 1. belief propagation for one time slice
self.belief_propagation(vec![observation]);

// 2. append the new states in the result vector
// 2. append the new beliefs in the trajectories structure
// iterate over the float hashmap
for (new_node_idx, new_node) in &self.attributes.floats {
for (new_key, new_value) in new_node {
// If the key exists in map1, append the vector from map2
Expand All @@ -191,6 +205,17 @@ impl Network {
}
}
}
// iterate over the vector hashmap
for (new_node_idx, new_node) in &self.attributes.vectors {
for (new_key, new_value) in new_node {
// If the key exists in map1, append the vector from map2
if let Some(old_node) = node_trajectories.vectors.get_mut(&new_node_idx) {
if let Some(old_value) = old_node.get_mut(new_key) {
old_value.push(new_value.clone());
}
}
}
}
}

self.node_trajectories = node_trajectories;
Expand All @@ -201,15 +226,24 @@ impl Network {
let py_list = PyList::empty(py);


// Iterate over the Rust HashMap and insert key-value pairs into the PyDict
// Iterate over the float hashmap and insert key-value pairs into the list as PyDict
for (node_idx, node) in &self.node_trajectories.floats {
let py_dict = PyDict::new(py);
for (key, value) in node {
// Create a new Python dictionary
py_dict.set_item(key, value).expect("Failed to set item in PyDict");
py_dict.set_item(key, PyArray1::from_vec(py, value.clone()).to_owned()).expect("Failed to set item in PyDict");
}

// Iterate over the vector hashmap if any and insert key-value pairs into the list as PyDict
if let Some(vector_node) = self.node_trajectories.vectors.get(node_idx) {
for (vector_key, vector_value) in vector_node {
// Create a new Python dictionary
py_dict.set_item(vector_key, PyArray::from_vec2_bound(py, &vector_value).unwrap()).expect("Failed to set item in PyDict");
}
}
py_list.append(py_dict)?;
}

// Create a PyList from Vec<usize>
Ok(py_list)
}
Expand Down
17 changes: 13 additions & 4 deletions tests/test_exponential_family.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Author: Nicolas Legrand <[email protected]>

import numpy as np
from rshgf import Network as RsNetwork

from pyhgf import load_data
Expand All @@ -13,12 +14,20 @@ def test_1d_gaussain():
# Rust -----------------------------------------------------------------------------
rs_network = RsNetwork()
rs_network.add_nodes(kind="exponential-state")
rs_network.inputs
rs_network.edges
rs_network.set_update_sequence()

rs_network.input_data(timeseries)

# Python ---------------------------------------------------------------------------
py_network = PyNetwork().add_nodes(kind="exponential-state")
py_network.attributes
py_network.input_data(timeseries)

# Ensure identical results
assert np.isclose(
py_network.node_trajectories[0]["xis"], rs_network.node_trajectories[0]["xis"]
).all()
assert np.isclose(
py_network.node_trajectories[0]["mean"], rs_network.node_trajectories[0]["mean"]
).all()
assert np.isclose(
py_network.node_trajectories[0]["nus"], rs_network.node_trajectories[0]["nus"]
).all()
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_set_update_sequence():
network3 = (
Network()
.add_nodes(kind="generic-state")
.add_nodes(kind="ef-normal", value_children=0)
.add_nodes(kind="exponential-state", value_children=0)
.create_belief_propagation_fn()
)
predictions, updates = network3.update_sequence
Expand Down

0 comments on commit 8c5b0ca

Please sign in to comment.