Skip to content

Commit

Permalink
Redundant function array_of_any() fully removed
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan7773 committed Sep 5, 2024
1 parent 5754037 commit 8fff040
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 14 deletions.
1 change: 0 additions & 1 deletion src/ActiveInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ export # utils/create_matrix_templates.jl
normalize_arrays,

# utils/utils.jl
array_of_any,
array_of_any_zeros,
array_of_any_uniform,
onehot,
Expand Down
4 changes: 2 additions & 2 deletions src/Environments/TMazeEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function reset_TMaze!(env::TMazeEnv; state=nothing)
env.reward_condition = onehot(env._reward_condition_idx, env.num_reward_conditions)

# Initialize the full state array
full_state = array_of_any(env.num_factors)
full_state = Vector{Any}(undef, env.num_factors)
full_state[env.location_factor_id] = loc_state
full_state[env.trial_factor_id] = env.reward_condition

Expand Down Expand Up @@ -182,7 +182,7 @@ end

function construct_state(env::TMazeEnv, state_tuple)
# Create an array of any
state = array_of_any(env.num_factors)
state = Vector{Any}(undef, env.num_factors)

# Populate the state array with one-hot encoded vectors
for (f, ns) in enumerate(env.num_states)
Expand Down
4 changes: 2 additions & 2 deletions src/pomdp/POMDP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64})
n_factors = length(aif.settings["num_controls"])

# Initialize empty arrays for action distribution per factor
action_p = array_of_any(n_factors)
action_p = Vector{Any}(undef, n_factors)
action_distribution = Vector(undef, n_factors)

#If there was a previous action
Expand Down Expand Up @@ -59,7 +59,7 @@ function action_pomdp!(aif::AIF, obs::Vector{Int64})
n_factors = length(aif.settings["num_controls"])

# Initialize an empty arrays for action distribution per factor
action_p = array_of_any(n_factors)
action_p = Vector{Any}(undef, n_factors)
action_distribution = Vector(undef, n_factors)

### Infer states & policies
Expand Down
6 changes: 3 additions & 3 deletions src/pomdp/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ function get_expected_obs(qs_pi, A::Vector{Array{<:Real}})
qo_pi = []

for t in 1:n_steps
qo_pi_t = array_of_any(length(A))
qo_pi_t = Vector{Any}(undef, length(A))
qo_pi = push!(qo_pi, qo_pi_t)
end

Expand Down Expand Up @@ -305,7 +305,7 @@ function calc_pA_info_gain(pA, qo_pi, qs_pi)
n_steps = length(qo_pi)
num_modalities = length(pA)

wA = array_of_any(num_modalities)
wA = Vector{Any}(undef, num_modalities)
for (modality, pA_m) in enumerate(pA)
wA[modality] = spm_wnorm(pA[modality])
end
Expand All @@ -327,7 +327,7 @@ function calc_pB_info_gain(pB, qs_pi, qs_prev, policy)
n_steps = length(qs_pi)
num_factors = length(pB)

wB = array_of_any(num_factors)
wB = Vector{Any}(undef, num_factors)
for (factor, pB_f) in enumerate(pB)
wB[factor] = spm_wnorm(pB_f)
end
Expand Down
7 changes: 1 addition & 6 deletions src/utils/utils.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
""" -------- Utility Functions -------- """

""" Creates an array of "Any" with the desired number of sub-arrays"""
function array_of_any(num_arr::Int)
return Array{Any}(undef, num_arr) #saves it as {Any} e.g. can be any kind of data type.
end

""" Creates an array of "Any" with the desired number of sub-arrays filled with zeros"""
function array_of_any_zeros(shape_list)
arr = Array{Any}(undef, length(shape_list))
Expand Down Expand Up @@ -79,7 +74,7 @@ end
function get_log_action_marginals(aif)
num_factors = length(aif.num_controls)
action_marginals = create_matrix_templates(aif.num_controls, "zeros")
log_action_marginals = array_of_any(num_factors)
log_action_marginals = Vector{Any}(undef, num_factors)
q_pi = get_states(aif, "posterior_policies")
policies = get_states(aif, "policies")

Expand Down

0 comments on commit 8fff040

Please sign in to comment.