From 40af899c6c86b704d7d5adeb13930ff744a196b2 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Sun, 25 Aug 2024 15:44:23 +0200 Subject: [PATCH] CR: new methods for templates --- src/pomdp/inference.jl | 4 +- src/pomdp/struct.jl | 6 +- src/utils/create_matrix_templates.jl | 114 ++++++++++++++++++++++++++- src/utils/utils.jl | 2 +- 4 files changed, 117 insertions(+), 9 deletions(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index ed7e0fb..e220712 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -102,7 +102,7 @@ function fixed_point_iteration(A::Vector{Any}, obs::Vector{Vector{Real}}, num_ob end if prior === nothing - prior = array_of_any_uniform(num_states) + prior = create_matrix_templates(num_states) end prior = spm_log_array_any(prior) @@ -355,7 +355,7 @@ end """ Sample Action [Stochastic or Deterministic] """ function sample_action(q_pi, policies, num_controls; action_selection="stochastic", alpha=16.0) num_factors = length(num_controls) - action_marginals = array_of_any_zeros(num_controls) + action_marginals = create_matrix_templates(num_controls, "zeros") selected_policy = zeros(Real,num_factors) for (pol_idx, policy) in enumerate(policies) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 6ca03db..d915cca 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -75,12 +75,12 @@ function create_aif(A, B; # If C-vectors are not provided if isnothing(C) - C = array_of_any_zeros(num_obs) + C = create_matrix_templates(num_obs, "zeros") end # If D-vectors are not provided if isnothing(D) - D = array_of_any_uniform(num_states) + D = create_matrix_templates(num_states) end # if num_controls are not given, they are inferred from the B matrix @@ -100,7 +100,7 @@ function create_aif(A, B; error("Length of E-vector must match the number of policies.") end - qs_current = array_of_any_uniform(num_states) + qs_current = create_matrix_templates(num_states) prior = D Q_pi = ones(Real,length(policies)) / length(policies) G = zeros(Real,length(policies)) diff --git a/src/utils/create_matrix_templates.jl b/src/utils/create_matrix_templates.jl index b7061ad..c747714 100644 --- a/src/utils/create_matrix_templates.jl +++ b/src/utils/create_matrix_templates.jl @@ -1,3 +1,5 @@ +######################## Create Templates Based on states, observations, controls and policy length ######################## + """ create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector{Int64}, n_controls::Vector{Int64}, policy_length::Int64, template_type::String = "uniform") @@ -14,8 +16,6 @@ Creates templates for the A, B, C, D, and E matrices based on the specified para - `A, B, C, D, E`: The generative model as matrices and vectors. """ - -### Function for creating uniform matrices function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector{Int64}, n_controls::Vector{Int64}, policy_length::Int64) # Calculate the number of policies based on the policy length @@ -39,7 +39,6 @@ function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector return A, B, C, D, E end -### Function for specific template type, can be either 'random' or 'zeros' function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector{Int64}, n_controls::Vector{Int64}, policy_length::Int64, template_type::String) # If the template_type is uniform @@ -82,4 +81,113 @@ function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector end return A, B, C, D, E +end + +######################## Create Templates Based on Shapes ######################## + +### Single Array Input + +""" + create_matrix_templates(shapes::Vector{Int64}) + +Creates uniform templates based on the specified shapes vector. + +# Arguments +- `shapes::Vector{Int64}`: A vector specifying the dimensions of each template to create. + +# Returns +- A vector of normalized arrays. + +""" +function create_matrix_templates(shapes::Vector{Int64}) + + # Create arrays filled with ones and then normalize + return [norm_dist(ones(n)) for n in shapes] +end + +""" + create_matrix_templates(shapes::Vector{Int64}, template_type::String) + +Creates templates based on the specified shapes vector and template type. Templates can be uniform, random, or filled with zeros. + +# Arguments +- `shapes::Vector{Int64}`: A vector specifying the dimensions of each template to create. +- `template_type::String`: The type of templates to create. Can be "uniform" (default), "random", or "zeros". + +# Returns +- A vector of arrays, each corresponding to the shape given by the input vector. + + +""" +function create_matrix_templates(shapes::Vector{Int64}, template_type::String) + + if template_type == "uniform" + # Create arrays filled with ones and then normalize + return [norm_dist(ones(n)) for n in shapes] + + elseif template_type == "random" + # Create arrays filled with random values + return [norm_dist(rand(n)) for n in shapes] + + elseif template_type == "zeros" + # Create arrays filled with zeros + return [zeros(n) for n in shapes] + + else + # Throw error for invalid template type + throw(ArgumentError("Invalid type: $template_type. Choose either 'uniform', 'random' or 'zeros'.")) + end +end + +### Vector of Arrays Input + +""" + create_matrix_templates(shapes::Vector{Vector{Int64}}) + +Creates a uniform, multidimensional template based on the specified shapes vector. + +# Arguments +- `shapes::Vector{Vector{Int64}}`: A vector of vectors, where each vector represent a dimension of the template to create. + +# Returns +- A vector of normalized arrays (uniform distributions), each having the multi-dimensional shape specified in the input vector. + +""" +function create_matrix_templates(shapes::Vector{Vector{Int64}}) + + # Create arrays filled with ones and then normalize + return [norm_dist(ones(shape...)) for shape in shapes] +end + +""" + create_matrix_templates(shapes::Vector{Vector{Int64}}, template_type::String) + +Creates a multidimensional template based on the specified vector of shape vectors and template type. Templates can be uniform, random, or filled with zeros. + +# Arguments +- `shapes::Vector{Vector{Int64}}`: A vector of vectors, where each vector represent a dimension of the template to create. +- `template_type::String`: The type of templates to create. Can be "uniform" (default), "random", or "zeros". + +# Returns +- A vector of arrays, each having the multi-dimensional shape specified in the input vector. + +""" +function create_matrix_templates(shapes::Vector{Vector{Int64}}, template_type::String) + + if template_type == "uniform" + # Create arrays filled with ones and then normalize + return [norm_dist(ones(shape...)) for shape in shapes] + + elseif template_type == "random" + # Create arrays filled with random values + return [norm_dist(rand(shape...)) for shape in shapes] + + elseif template_type == "zeros" + # Create arrays filled with zeros + return [zeros(shape...) for shape in shapes] + + else + # Throw error for invalid template type + throw(ArgumentError("Invalid type: $template_type. Choose either 'uniform', 'random' or 'zeros'.")) + end end \ No newline at end of file diff --git a/src/utils/utils.jl b/src/utils/utils.jl index a02fba0..bbae64e 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -78,7 +78,7 @@ end """ Function to get log marginal probabilities of actions """ function get_log_action_marginals(aif) num_factors = length(aif.num_controls) - action_marginals = array_of_any_zeros(aif.num_controls) + action_marginals = create_matrix_templates(aif.num_controls, "zeros") log_action_marginals = array_of_any(num_factors) q_pi = get_states(aif, "posterior_policies") policies = get_states(aif, "policies")