Skip to content

Commit

Permalink
Type adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelnehrer02 committed Aug 25, 2024
1 parent da04c10 commit 15384b4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions src/pomdp/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function process_observation(observation::Union{Array{Int}, Tuple{Vararg{Int}}},
end

""" Update Posterior States """
function update_posterior_states(A::Vector{Any}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...)
function update_posterior_states(A::Vector{Array{<:Real}}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...)
num_obs, num_states, num_modalities, num_factors = get_model_dimensions(A)

obs_processed = process_observation(obs, num_modalities, num_obs)
Expand All @@ -87,7 +87,7 @@ end


""" Run State Inference via Fixed-Point Iteration """
function fixed_point_iteration(A::Vector{Any}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol)
function fixed_point_iteration(A::Vector{Array{<:Real}}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol)
n_modalities = length(num_obs)
n_factors = length(num_states)

Expand Down Expand Up @@ -191,9 +191,9 @@ end
""" Update Posterior over Policies """
function update_posterior_policies(
qs::Vector{Any},
A::Vector{Any},
B::Vector{Any},
C::Vector{Any},
A::Vector{Array{<:Real}},
B::Vector{Array{<:Real}},
C::Vector{Array{<:Real}},
policies::Vector{Matrix{Int64}},
use_utility::Bool=true,
use_states_info_gain::Bool=true,
Expand Down Expand Up @@ -243,7 +243,7 @@ function update_posterior_policies(
end

""" Get Expected Observations """
function get_expected_obs(qs_pi, A::Vector{Any})
function get_expected_obs(qs_pi, A::Vector{Array{<:Real}})
n_steps = length(qs_pi)
qo_pi = []

Expand Down
16 changes: 8 additions & 8 deletions src/pomdp/struct.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
""" -------- AIF Mutable Struct -------- """

mutable struct AIF
A::Array{Any,1} # A-matrix
B::Array{Any,1} # B-matrix
C::Array{Any,1} # C-vectors
D::Array{Any,1} # D-vectors
E::Union{Array{Any, 1}, Nothing} # E - vector (Habits)
pA::Union{Array{Any,1}, Nothing} # Dirichlet priors for A-matrix
pB::Union{Array{Any,1}, Nothing} # Dirichlet priors for B-matrix
pD::Union{Array{Any,1}, Nothing} # Dirichlet priors for D-vector
A::Vector{Array{<:Real}} # A-matrix
B::Vector{Array{<:Real}} # B-matrix
C::Vector{Array{<:Real}} # C-vectors
D::Vector{Array{<:Real}} # D-vectors
E::Union{Vector{<:Real}, Nothing} # E-vector (Habits)
pA::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for A-matrix
pB::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for B-matrix
pD::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for D-vector
lr_pA::Real # pA Learning Parameter
fr_pA::Real # pA Forgetting Parameter, 1.0 for no forgetting
lr_pB::Real # pB learning Parameter
Expand Down

0 comments on commit 15384b4

Please sign in to comment.