Skip to content

Commit

Permalink
Merge pull request #160 from ghost/develop
Browse files Browse the repository at this point in the history
[revised] .ix to .loc AND [add] .astype() for Type error: reduction operation .argmax()
  • Loading branch information
MorvanZhou authored Nov 1, 2020
2 parents e300aa4 + 57ab79d commit 1fd1c08
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions contents/11_Dyna_Q/RL_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,37 @@ def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
self.lr = learning_rate
self.gamma = reward_decay
self.epsilon = e_greedy
self.q_table = pd.DataFrame(columns=self.actions)

## argmax type error
self.q_table = pd.DataFrame(columns=self.actions).astype('float32')

def choose_action(self, observation):
self.check_state_exist(observation)
# action selection
if np.random.uniform() < self.epsilon:
# choose best action
state_action = self.q_table.ix[observation, :]


# state_action = self.q_table.ix[observation, :]
state_action = self.q_table.loc[observation, :] # for label indexing
state_action = state_action.reindex(np.random.permutation(state_action.index)) # some actions have same value
action = state_action.argmax()


else:
# choose random action
action = np.random.choice(self.actions)
return action

def learn(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.ix[s, a]

q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.ix[s_, :].max() # next state is not terminal
q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal
else:
q_target = r # next state is terminal
self.q_table.ix[s, a] += self.lr * (q_target - q_predict) # update
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update

def check_state_exist(self, state):
if state not in self.q_table.index:
Expand Down Expand Up @@ -71,9 +79,9 @@ def store_transition(self, s, a, r, s_):

def sample_s_a(self):
s = np.random.choice(self.database.index)
a = np.random.choice(self.database.ix[s].dropna().index) # filter out the None value
a = np.random.choice(self.database.loc[s].dropna().index) # filter out the None value
return s, a

def get_r_s_(self, s, a):
r, s_ = self.database.ix[s, a]
r, s_ = self.database.loc[s, a]
return r, s_

0 comments on commit 1fd1c08

Please sign in to comment.