Skip to content

Commit

Permalink
CR: new methods for templates
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelnehrer02 committed Aug 25, 2024
1 parent e69cabf commit 40af899
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/pomdp/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ function fixed_point_iteration(A::Vector{Any}, obs::Vector{Vector{Real}}, num_ob
end

if prior === nothing
prior = array_of_any_uniform(num_states)
prior = create_matrix_templates(num_states)
end

prior = spm_log_array_any(prior)
Expand Down Expand Up @@ -355,7 +355,7 @@ end
""" Sample Action [Stochastic or Deterministic] """
function sample_action(q_pi, policies, num_controls; action_selection="stochastic", alpha=16.0)
num_factors = length(num_controls)
action_marginals = array_of_any_zeros(num_controls)
action_marginals = create_matrix_templates(num_controls, "zeros")
selected_policy = zeros(Real,num_factors)

for (pol_idx, policy) in enumerate(policies)
Expand Down
6 changes: 3 additions & 3 deletions src/pomdp/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ function create_aif(A, B;

# If C-vectors are not provided
if isnothing(C)
C = array_of_any_zeros(num_obs)
C = create_matrix_templates(num_obs, "zeros")
end

# If D-vectors are not provided
if isnothing(D)
D = array_of_any_uniform(num_states)
D = create_matrix_templates(num_states)
end

# if num_controls are not given, they are inferred from the B matrix
Expand All @@ -100,7 +100,7 @@ function create_aif(A, B;
error("Length of E-vector must match the number of policies.")
end

qs_current = array_of_any_uniform(num_states)
qs_current = create_matrix_templates(num_states)
prior = D
Q_pi = ones(Real,length(policies)) / length(policies)
G = zeros(Real,length(policies))
Expand Down
114 changes: 111 additions & 3 deletions src/utils/create_matrix_templates.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
######################## Create Templates Based on states, observations, controls and policy length ########################

"""
create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector{Int64}, n_controls::Vector{Int64}, policy_length::Int64, template_type::String = "uniform")
Expand All @@ -14,8 +16,6 @@ Creates templates for the A, B, C, D, and E matrices based on the specified para
- `A, B, C, D, E`: The generative model as matrices and vectors.
"""

### Function for creating uniform matrices
function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector{Int64}, n_controls::Vector{Int64}, policy_length::Int64)

# Calculate the number of policies based on the policy length
Expand All @@ -39,7 +39,6 @@ function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector
return A, B, C, D, E
end

### Function for specific template type, can be either 'random' or 'zeros'
function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector{Int64}, n_controls::Vector{Int64}, policy_length::Int64, template_type::String)

# If the template_type is uniform
Expand Down Expand Up @@ -82,4 +81,113 @@ function create_matrix_templates(n_states::Vector{Int64}, n_observations::Vector
end

return A, B, C, D, E
end

######################## Create Templates Based on Shapes ########################

### Single Array Input

"""
create_matrix_templates(shapes::Vector{Int64})
Creates uniform templates based on the specified shapes vector.
# Arguments
- `shapes::Vector{Int64}`: A vector specifying the dimensions of each template to create.
# Returns
- A vector of normalized arrays.
"""
function create_matrix_templates(shapes::Vector{Int64})

# Create arrays filled with ones and then normalize
return [norm_dist(ones(n)) for n in shapes]
end

"""
create_matrix_templates(shapes::Vector{Int64}, template_type::String)
Creates templates based on the specified shapes vector and template type. Templates can be uniform, random, or filled with zeros.
# Arguments
- `shapes::Vector{Int64}`: A vector specifying the dimensions of each template to create.
- `template_type::String`: The type of templates to create. Can be "uniform" (default), "random", or "zeros".
# Returns
- A vector of arrays, each corresponding to the shape given by the input vector.
"""
function create_matrix_templates(shapes::Vector{Int64}, template_type::String)

if template_type == "uniform"
# Create arrays filled with ones and then normalize
return [norm_dist(ones(n)) for n in shapes]

elseif template_type == "random"
# Create arrays filled with random values
return [norm_dist(rand(n)) for n in shapes]

elseif template_type == "zeros"
# Create arrays filled with zeros
return [zeros(n) for n in shapes]

else
# Throw error for invalid template type
throw(ArgumentError("Invalid type: $template_type. Choose either 'uniform', 'random' or 'zeros'."))
end
end

### Vector of Arrays Input

"""
create_matrix_templates(shapes::Vector{Vector{Int64}})
Creates a uniform, multidimensional template based on the specified shapes vector.
# Arguments
- `shapes::Vector{Vector{Int64}}`: A vector of vectors, where each vector represent a dimension of the template to create.
# Returns
- A vector of normalized arrays (uniform distributions), each having the multi-dimensional shape specified in the input vector.
"""
function create_matrix_templates(shapes::Vector{Vector{Int64}})

# Create arrays filled with ones and then normalize
return [norm_dist(ones(shape...)) for shape in shapes]
end

"""
create_matrix_templates(shapes::Vector{Vector{Int64}}, template_type::String)
Creates a multidimensional template based on the specified vector of shape vectors and template type. Templates can be uniform, random, or filled with zeros.
# Arguments
- `shapes::Vector{Vector{Int64}}`: A vector of vectors, where each vector represent a dimension of the template to create.
- `template_type::String`: The type of templates to create. Can be "uniform" (default), "random", or "zeros".
# Returns
- A vector of arrays, each having the multi-dimensional shape specified in the input vector.
"""
function create_matrix_templates(shapes::Vector{Vector{Int64}}, template_type::String)

if template_type == "uniform"
# Create arrays filled with ones and then normalize
return [norm_dist(ones(shape...)) for shape in shapes]

elseif template_type == "random"
# Create arrays filled with random values
return [norm_dist(rand(shape...)) for shape in shapes]

elseif template_type == "zeros"
# Create arrays filled with zeros
return [zeros(shape...) for shape in shapes]

else
# Throw error for invalid template type
throw(ArgumentError("Invalid type: $template_type. Choose either 'uniform', 'random' or 'zeros'."))
end
end
2 changes: 1 addition & 1 deletion src/utils/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end
""" Function to get log marginal probabilities of actions """
function get_log_action_marginals(aif)
num_factors = length(aif.num_controls)
action_marginals = array_of_any_zeros(aif.num_controls)
action_marginals = create_matrix_templates(aif.num_controls, "zeros")
log_action_marginals = array_of_any(num_factors)
q_pi = get_states(aif, "posterior_policies")
policies = get_states(aif, "policies")
Expand Down

0 comments on commit 40af899

Please sign in to comment.