-
Notifications
You must be signed in to change notification settings - Fork 0
/
pomdp.py
101 lines (83 loc) · 3.28 KB
/
pomdp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""pomdp.py"""
import os
import subprocess
import tempfile
import shutil
import pomdp_policy
class POMDP(object):
    """Base POMDP class.

    Not meant to be instantiated directly: ``__init__`` and ``solve`` raise
    ``NotImplementedError`` and are expected to be overridden by subclasses.
    ``write_cassandra`` is the shared model-serialization helper.
    """

    def __init__(self):
        raise NotImplementedError

    def solve(self, model):
        """Return parsed policy."""
        raise NotImplementedError

    def write_cassandra(self, fo, states, actions, observations, f_start,
                        f_transition, f_observation, f_reward, discount):
        """Write a Cassandra-style POMDP file.

        Args:
            fo: Writable text file-like object; only ``fo.write`` is used.
            states: Sequence of state labels (iterated several times).
            actions: Sequence of action labels (iterated several times).
            observations: Sequence of observation labels (iterated several
                times).
            f_start: Callable ``f_start(s)`` giving the value written for
                state ``s`` on the ``start:`` line.
            f_transition: Callable ``f_transition(s, a, s1)`` giving the
                value written on the ``T: a : s : s1`` line.
            f_observation: Callable ``f_observation(s, a, o)`` giving the
                value written on the ``O: a : s : o`` line.
            f_reward: Callable ``f_reward(s, a, s1)`` giving the value
                written on the ``R: a : s : s1 : *`` line.
            discount: Discount factor; must be strictly less than 1.0.

        Raises:
            ValueError: If ``discount`` is greater than or equal to 1.0.
        """
        if discount >= 1.0:
            # ValueError is still caught by callers using ``except Exception``.
            raise ValueError('Discount must be less than 1.0')

        def joined(values):
            # Labels may be ints or other objects, so stringify every one.
            return ' '.join(str(v) for v in values)

        # Write header
        fo.write('discount: {}\n'.format(discount))
        fo.write('values: reward\n')
        fo.write('states: {}\n'.format(joined(states)))
        fo.write('actions: {}\n'.format(joined(actions)))
        # Observations are stringified like states and actions; a plain
        # ' '.join(observations) raises TypeError for non-string labels.
        fo.write('observations: {}\n'.format(joined(observations)))
        fo.write('start: {}\n'.format(joined(f_start(s) for s in states)))

        fo.write('\n\n### Transitions\n')
        for s in states:
            for a in actions:
                for s1 in states:
                    fo.write('T: {} : {} : {} {}\n'.format(
                        a, s, s1, f_transition(s, a, s1)))
                # Blank line separates each (state, action) group.
                fo.write('\n')

        fo.write('\n\n### Observations\n')
        for s in states:
            for a in actions:
                for o in observations:
                    fo.write('O: {} : {} : {} {}\n'.format(
                        a, s, o, f_observation(s, a, o)))
                fo.write('\n')

        fo.write('\n\n### Rewards\n')
        for s in states:
            for a in actions:
                for s1 in states:
                    # '*' is written in the observation position of every
                    # reward line.
                    fo.write('R: {} : {} : {} : * {}\n'.format(
                        a, s, s1, f_reward(s, a, s1)))
                fo.write('\n')
class ZPOMDP(POMDP):
    """ZMDP POMDP class.

    Solves a model by writing it to disk and invoking the external solver
    executable named by the ``ZMDP_ALIAS`` environment variable.
    """

    def __init__(self):
        # Override the base constructor, which raises NotImplementedError.
        pass

    def solve(self, states, actions, observations, f_start, f_transition,
              f_observation, f_reward, discount, timeout=None, directory=None):
        """Write the model, run the solver on it and return the parsed policy.

        The model is written to ``m.pomdp`` and the solver writes its policy
        to ``p.policy``. Both files live in ``directory`` when it is given,
        otherwise in a temporary directory that is removed before returning,
        whether or not solving succeeded.

        Args:
            states, actions, observations, f_start, f_transition,
            f_observation, f_reward, discount: Forwarded unchanged to
                ``write_cassandra``.
            timeout: Optional; when truthy it is passed to the solver as
                ``-t <timeout>`` (units are whatever the solver defines).
            directory: Optional pre-existing directory in which to keep the
                model and policy files.

        Returns:
            A ``pomdp_policy.POMDPPolicy`` built from the policy file.

        Raises:
            AssertionError: If ``directory`` is given but is not an existing
                directory, or if the solver exits with a non-zero status.
            KeyError: If the ``ZMDP_ALIAS`` environment variable is not set.
        """
        if directory:
            # save files in user directory
            # Raised explicitly (not via ``assert``) so the check is kept
            # under ``python -O``; the exception type is unchanged.
            if not os.path.isdir(directory):
                raise AssertionError(
                    'Not an existing directory: {}'.format(directory))
            d = directory
        else:
            # temporary, only return policy
            d = tempfile.mkdtemp()

        try:
            model_filename = os.path.join(d, 'm.pomdp')
            policy_filename = os.path.join(d, 'p.policy')
            with open(model_filename, 'w') as f:
                self.write_cassandra(
                    f, states, actions, observations, f_start,
                    f_transition, f_observation, f_reward, discount)

            args = [os.environ['ZMDP_ALIAS'], 'solve', model_filename,
                    '-o', policy_filename]
            if timeout:
                args += ['-t', str(timeout)]
            # Single-argument call form prints the same text on Python 2
            # and is valid syntax on Python 3.
            print(args)

            exit_status = subprocess.call(args)
            # Check that the solver ran successfully.
            if exit_status != 0:
                raise AssertionError(
                    'Solver exited with status {}'.format(exit_status))

            # parse policy output
            return pomdp_policy.POMDPPolicy(
                policy_filename, file_format='zmdp', n_states=len(states))
        finally:
            if not directory:
                # Delete the temporary directory even when solving or
                # parsing failed, so failed runs do not leak directories.
                shutil.rmtree(d)