# policyIteration.py
import random
# Arguments
REWARD = -0.01 # constant reward for non-terminal states
DISCOUNT = 0.99
MAX_ERROR = 10**(-3)
# Set up the initial environment
NUM_ACTIONS = 4
ACTIONS = [(1, 0), (0, -1), (-1, 0), (0, 1)] # Down, Left, Up, Right
NUM_ROW = 3
NUM_COL = 4
U = [[0, 0, 0, 1], [0, 0, 0, -1], [0, 0, 0, 0]] # initial utilities; +1 and -1 are the terminal rewards
policy = [[random.randint(0, 3) for j in range(NUM_COL)] for i in range(NUM_ROW)] # construct a random policy
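# Layout (inferred from the constants above): a 3-row by 4-column grid world where
# (0, 3) is a +1 terminal, (1, 3) is a -1 terminal, (1, 1) is a wall, and every
# other state incurs the constant step reward REWARD. Row 0 is the top row when
# the environment is printed; actions are indices into ACTIONS.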
# Visualization: print either a utility grid or a policy grid (policy=True)
def printEnvironment(arr, policy=False):
    res = ""
    for r in range(NUM_ROW):
        res += "|"
        for c in range(NUM_COL):
            if r == c == 1:
                val = "WALL"
            elif r <= 1 and c == 3:
                val = "+1" if r == 0 else "-1"
            elif policy:
                val = ["Down", "Left", "Up", "Right"][arr[r][c]]
            else:
                val = str(arr[r][c])
            res += " " + val[:5].ljust(5) + " |"  # fixed-width cell
        res += "\n"
    print(res)
# Get the utility of the state reached by performing the given action from the given state
def getU(U, r, c, action):
    dr, dc = ACTIONS[action]
    newR, newC = r + dr, c + dc
    if newR < 0 or newC < 0 or newR >= NUM_ROW or newC >= NUM_COL or (newR == newC == 1):
        return U[r][c]  # bounced back by the boundary or the wall: stay in place
    else:
        return U[newR][newC]
# Expected utility of taking the given action in state (r, c): the agent moves in the
# intended direction with probability 0.8 and slips to each perpendicular direction
# with probability 0.1
def calculateU(U, r, c, action):
    u = REWARD
    u += 0.1 * DISCOUNT * getU(U, r, c, (action - 1) % 4)
    u += 0.8 * DISCOUNT * getU(U, r, c, action)
    u += 0.1 * DISCOUNT * getU(U, r, c, (action + 1) % 4)
    return u
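# Quick sanity check of the transition model above (worked out by hand, not part of
# the original script): taking action 3 (Right) from state (2, 0), the intended move
# reaches (2, 1), the Up slip reaches (1, 0), and the Down slip hits the boundary and
# stays at (2, 0), so
#   calculateU(U, 2, 0, 3) == REWARD + DISCOUNT * (0.8*U[2][1] + 0.1*U[1][0] + 0.1*U[2][0])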
# Policy evaluation: perform simplified value-iteration (Bellman) updates under the
# fixed policy to approximate its utilities
def policyEvaluation(policy, U):
    while True:
        nextU = [[0, 0, 0, 1], [0, 0, 0, -1], [0, 0, 0, 0]]
        error = 0
        for r in range(NUM_ROW):
            for c in range(NUM_COL):
                if (r <= 1 and c == 3) or (r == c == 1):
                    continue  # skip terminal states and the wall
                nextU[r][c] = calculateU(U, r, c, policy[r][c])  # simplified Bellman update
                error = max(error, abs(nextU[r][c] - U[r][c]))
        U = nextU
        if error < MAX_ERROR * (1 - DISCOUNT) / DISCOUNT:
            break
    return U
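# Note (not in the original script): stopping once the largest per-state change falls
# below MAX_ERROR * (1 - DISCOUNT) / DISCOUNT is the usual contraction-based criterion;
# it bounds the distance between the returned estimate and the true utilities of the
# policy by roughly MAX_ERROR.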
# Policy iteration: alternate policy evaluation and greedy policy improvement until
# the policy no longer changes
def policyIteration(policy, U):
    print("During the policy iteration:\n")
    while True:
        U = policyEvaluation(policy, U)
        unchanged = True
        for r in range(NUM_ROW):
            for c in range(NUM_COL):
                if (r <= 1 and c == 3) or (r == c == 1):
                    continue  # skip terminal states and the wall
                # Policy improvement: find the action with the highest expected utility
                maxAction, maxU = None, -float("inf")
                for action in range(NUM_ACTIONS):
                    u = calculateU(U, r, c, action)
                    if u > maxU:
                        maxAction, maxU = action, u
                if maxU > calculateU(U, r, c, policy[r][c]):
                    policy[r][c] = maxAction  # switch to the action that maximizes the utility
                    unchanged = False
        if unchanged:
            break
        printEnvironment(policy, True)
    return policy
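# Note (not in the original script): the policy is improved in place, so the list
# passed in is mutated. With exact policy evaluation, policy iteration terminates
# because there are finitely many policies and each improvement step yields a strictly
# better one; here evaluation is only approximate, so the strict-inequality check
# above is what keeps the loop from flipping between near-equivalent actions.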
# Print the initial random policy
print("The initial random policy is:\n")
printEnvironment(policy, True)
# Policy iteration
policy = policyIteration(policy, U)
# Print the optimal policy
print("The optimal policy is:\n")
printEnvironment(policy, True)