Skip to content

Commit

Permalink
CR: removed to_array_of_any
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelnehrer02 committed Aug 21, 2024
1 parent 5b7f52a commit f5b60ed
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/pomdp/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function fixed_point_iteration(A::Vector{Any}, obs::Vector{Vector{Real}}, num_ob
likelihood = spm_log_single(likelihood)

# Initialize posterior and prior
qs = Array{Any}(undef, n_factors)
qs = Vector{Vector{Real}}(undef, n_factors)
for factor in 1:n_factors
qs[factor] = ones(Real,num_states[factor]) / num_states[factor]
end
Expand All @@ -113,7 +113,7 @@ function fixed_point_iteration(A::Vector{Any}, obs::Vector{Vector{Real}}, num_ob
# Single factor condition
if n_factors == 1
qL = spm_dot(likelihood, qs[1])
return to_array_of_any(softmax(qL .+ prior[1]))
return [softmax(qL .+ prior[1])]
else
# Run Iteration
curr_iter = 0
Expand Down
13 changes: 0 additions & 13 deletions src/utils/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,6 @@ function get_model_dimensions(A = nothing, B = nothing)
end


""" Equivalent to pymdp's "to_obj_array" """
function to_array_of_any(arr::Array)
# Check if arr is already an array of arrays
if typeof(arr) == Array{Array,1}
return arr
end
# Create an array_out and assign squeezed array to the first element
obj_array_out = Array{Any,1}(undef, 1)
obj_array_out[1] = dropdims(arr, dims = tuple(findall(size(arr) .== 1)...))
return obj_array_out
end


""" Selects the highest value from Array -- used for deterministic action sampling """
function select_highest(options_array::Array{Float64})
options_with_idx = [(i, option) for (i, option) in enumerate(options_array)]
Expand Down

0 comments on commit f5b60ed

Please sign in to comment.