Skip to content

Commit

Permalink
added epistchainenv and gridworldenv
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan7773 committed Dec 25, 2023
1 parent 536e08f commit b4071e0
Showing 1 changed file with 45 additions and 45 deletions.
90 changes: 45 additions & 45 deletions src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using IterTools

"""Grid World for Epistemic Chaining"""

mutable struct GridWorldEnv
mutable struct EpistChainEnv
init_loc::Tuple{Int, Int}
current_loc::Tuple{Int, Int}
cue1_loc::Tuple{Int, Int}
Expand All @@ -15,14 +15,14 @@ mutable struct GridWorldEnv
len_x::Int
grid_locations::Matrix{Tuple{Int, Int}}

function GridWorldEnv(starting_loc::Tuple{Int, Int}, cue1_loc::Tuple{Int, Int}, cue2::String, reward_condition::String, grid_locations::Matrix{Tuple{Int, Int}})
function EpistChainEnv(starting_loc::Tuple{Int, Int}, cue1_loc::Tuple{Int, Int}, cue2::String, reward_condition::String, grid_locations::Matrix{Tuple{Int, Int}})
len_y, len_x = size(grid_locations)
new(starting_loc, starting_loc, cue1_loc, cue2, reward_condition)
end
end


function step!(env::GridWorldEnv, action_label::String)
function step!(env::EpistChainEnv, action_label::String)
y, x = env.current_state
next_y, next_x = y, x

Expand Down Expand Up @@ -85,52 +85,52 @@ function step!(env::GridWorldEnv, action_label::String)
return loc_obs, cue1_obs, cue2_obs, reward_obs
end

function reset!(env::GridWorldEnv)
function reset!(env::EpistChainEnv)
env.current_loc = env.init_loc

return env.current_loc
end


# """Mutable structure creating the environment"""
# mutable struct GridWorldEnv
# init_state::Tuple{Int, Int}
# current_state::Tuple{Int, Int}
# len_y::Int
# len_x::Int

# function GridWorldEnv(starting_state::Tuple{Int, Int}, grid_locations)
# len_y, len_x = maximum(first.(grid_locations)), maximum(last.(grid_locations))
# new(starting_state, starting_state, len_y, len_x)
# end
# end

# """Function for how to "step" in the Grid World"""
# function step!(env::GridWorldEnv, action_label::String)
# y, x = env.current_state
# next_y, next_x = y, x

# if action_label == "DOWN" # Y-axis reversed
# next_y = y < env.len_y ? y + 1 : y
# elseif action_label == "UP"
# next_y = y > 1 ? y - 1 : y
# elseif action_label == "LEFT"
# next_x = x > 1 ? x - 1 : x
# elseif action_label == "RIGHT"
# next_x = x < env.len_x ? x + 1 : x
# elseif action_label == "STAY"
# end

# env.current_state = (next_y, next_x)

# return env.current_state
# end

# """Reset function"""

# function reset!(env::GridWorldEnv)
# env.current_state = env.init_state
# println("Re-initialized location to ", env.init_state)
# return env.current_state
# end
"""Gridworld Simple"""
mutable struct GridWorldEnv
init_state::Tuple{Int, Int}
current_state::Tuple{Int, Int}
len_y::Int
len_x::Int

function GridWorldEnv(starting_state::Tuple{Int, Int}, grid_locations)
len_y, len_x = maximum(first.(grid_locations)), maximum(last.(grid_locations))
new(starting_state, starting_state, len_y, len_x)
end
end

"""Function for how to "step" in the Grid World"""
function step!(env::GridWorldEnv, action_label::String)
y, x = env.current_state
next_y, next_x = y, x

if action_label == "DOWN" # Y-axis reversed
next_y = y < env.len_y ? y + 1 : y
elseif action_label == "UP"
next_y = y > 1 ? y - 1 : y
elseif action_label == "LEFT"
next_x = x > 1 ? x - 1 : x
elseif action_label == "RIGHT"
next_x = x < env.len_x ? x + 1 : x
elseif action_label == "STAY"
end

env.current_state = (next_y, next_x)

return env.current_state
end

"""Reset function"""

function reset!(env::GridWorldEnv)
env.current_state = env.init_state
println("Re-initialized location to ", env.init_state)
return env.current_state
end

0 comments on commit b4071e0

Please sign in to comment.