forked from alexreuter/Causal-Model-Task-Planner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathabstracttypes.py
148 lines (116 loc) · 3.93 KB
/
abstracttypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from abc import ABC, abstractmethod
from copy import deepcopy
import functools
import itertools

import customerrors
class Domain():
    """A planning domain: its registered actions, current state and goal.

    Action instances register themselves by appending to ``self.actions``
    when they are constructed. ``isGoalSatisfied`` is meant to be overridden
    by concrete domains (it is marked ``@abstractmethod``).
    """

    def __init__(self, state):
        # List of Action instances that are possible in this domain
        self.actions = []
        # Stores the current state object
        self.state = state
        self.causal_models = []
        self.goal = None

    def checkActionExists(self, name):
        """Return True if a registered action's class is named ``name``."""
        return any(type(a).__name__ == name for a in self.actions)

    @staticmethod
    def _candidateObjects(state, param_type):
        """Return the names of objects in ``state`` acceptable for one parameter.

        ``param_type`` is either a single type name or a list of acceptable
        type names (the two forms ``checkParams`` accepts). A type that has
        no objects in the state contributes no candidates instead of raising
        KeyError.
        """
        if isinstance(param_type, list):
            names = []
            # dict.fromkeys de-duplicates the type names while keeping order
            for type_name in dict.fromkeys(param_type):
                names.extend(state.obj_types.get(type_name, []))
            return names
        return list(state.obj_types.get(param_type, []))

    def getValidActions(self, state):
        """Enumerate every applicable grounding of every registered action.

        For each action ``a``, every combination of object names whose types
        match ``a.param_types`` is tried. Combinations for which
        ``a.checkPredicates(state, params)`` raises
        ``customerrors.PredicateFailed`` are skipped.

        Args:
            state: state object exposing ``obj_types`` (type name -> names).

        Returns:
            list of SpecificAction. Each one holds its own deep copy of
            ``state``.
        """
        actions = []
        for a in self.actions:
            # One list of candidate object names per parameter slot
            candidates = [self._candidateObjects(state, param)
                          for param in a.param_types]
            # Every type-correct assignment of objects to the parameters
            for params in itertools.product(*candidates):
                # Type-correct parameters must also satisfy the predicates
                # before the action can be considered valid
                try:
                    a.checkPredicates(state, list(params))
                except customerrors.PredicateFailed:
                    continue
                # Create a new specific action with the parameters and a
                # snapshot of the state, and add it to the possible actions
                actions.append(SpecificAction(a, list(params), deepcopy(state)))
        return actions

    # NOTE(review): Domain does not inherit from ABC, so this decorator is
    # not enforced; Domain itself can still be instantiated.
    @abstractmethod
    def isGoalSatisfied(self, state):
        """Return whether ``state`` satisfies the domain goal (subclass hook)."""
        pass
# This superclass maintains a list of objects and their names in a domain state
class State(ABC):
    """Registry of the named objects that make up a domain state.

    Attributes:
        obj_dict: maps object name -> object.
        obj_names: list of names in insertion order (rebuilt on every add).
        objects: list of objects, in the same order as ``obj_names``.
        obj_types: maps class name -> list of names of objects of that class.
    """

    def __init__(self):
        self.obj_dict = {}
        self.obj_names = []
        self.objects = []
        self.obj_types = {}

    def addObject(self, obj):
        """Register ``obj`` under ``obj.name``, replacing any object of that name.

        Re-adding an existing name keeps ``obj_types`` free of duplicates:
        afterwards the name is listed exactly once, under the class of the
        most recently added object.
        """
        name = obj.name
        obj_type = type(obj).__name__
        old_type = (type(self.obj_dict[name]).__name__
                    if name in self.obj_dict else None)
        self.obj_dict[name] = obj
        self.obj_names = list(self.obj_dict.keys())
        self.objects = list(self.obj_dict.values())
        if old_type == obj_type:
            # Already indexed under this type; don't list the name twice
            return
        if old_type is not None:
            # The name now belongs to a different class: drop the stale entry.
            # The (possibly empty) list is kept so lookups by that type
            # still find a key.
            old_names = self.obj_types.get(old_type, [])
            if name in old_names:
                old_names.remove(name)
        self.obj_types.setdefault(obj_type, []).append(name)

    def get(self, name):
        """Return the object registered as ``name`` (KeyError if absent)."""
        return self.obj_dict[name]
class Action(ABC):
    """Abstract base class for the actions of a planning domain.

    Enforces the method pattern every action must follow: subclasses must
    implement ``checkPredicates`` and ``doAction``. Both are abstract, so a
    subclass that omits either cannot be instantiated. Constructing an action
    appends it to ``domain.actions``; the planner enumerates that list.
    """

    def __init__(self, domain):
        self.domain = domain
        # Self-register with the owning domain's action list
        registered = self.domain.actions
        registered.append(self)

    @abstractmethod
    def checkPredicates(self, state):
        """Validate this action's preconditions against ``state``.

        NOTE(review): ``Domain.getValidActions`` calls this as
        ``checkPredicates(state, params)``, so concrete overrides presumably
        need a ``params`` argument as well; confirm the intended signature.
        """

    @abstractmethod
    def doAction(self, state):
        """Apply this action to ``state`` (implemented by subclasses)."""
class SpecificAction():
    """An action applied to concrete parameters, paired with a state.

    Attributes:
        action: the Action instance being applied (may be None).
        parameters: the parameter values; ``Domain.getValidActions`` passes
            a list of object names.
        state: the state stored with this grounding;
            ``Domain.getValidActions`` passes a deep copy of the current state.
    """

    # Defining __eq__ already makes instances unhashable; stated explicitly
    # because ``parameters`` and ``state`` are mutable.
    __hash__ = None

    def __init__(self, action, parameters, state):
        self.action = action
        self.parameters = parameters
        self.state = state

    def __str__(self):
        """Return the parameters rendered with ``str``, e.g. ``"['b1']"``."""
        return str(self.parameters)

    def __eq__(self, obj):
        """Return True when action, parameters and state all compare equal.

        Returns NotImplemented for operands that are not SpecificAction, so
        Python falls back to its default comparison instead of raising
        AttributeError.
        """
        if not isinstance(obj, SpecificAction):
            return NotImplemented
        # NOTE: State defines no __eq__ in this file, so unless a subclass
        # adds one, states compare by identity.
        return (self.action == obj.action
                and self.parameters == obj.parameters
                and self.state == obj.state)
def getType(obj):
    """Return the name of ``obj``'s class, e.g. ``getType(3) == "int"``."""
    cls = type(obj)
    return cls.__name__
def checkType(obj, expected):
    """Raise ``customerrors.WrongInputType`` unless ``obj``'s class is named ``expected``.

    The exception is constructed with the actual class name followed by
    ``expected``. Returns None when the names match.
    """
    actual_name = getType(obj)
    mismatch = actual_name != expected
    if mismatch:
        raise customerrors.WrongInputType(actual_name, expected)
def checkPredicateTrue(lambd, obj):
    """Raise ``customerrors.PredicateFailed`` when ``lambd(obj)`` is falsy.

    The predicate is evaluated exactly once; returns None when it holds.
    """
    satisfied = lambd(obj)
    if satisfied:
        return
    raise customerrors.PredicateFailed()
def checkParams(func):
    """Decorator that validates an action method's ``params`` before calling it.

    The decorated method is invoked as ``method(state, params)``, where
    ``params`` is a sequence of object names. Each name is resolved with
    ``state.get`` and its class name is checked against the matching entry
    of ``self.param_types``. That entry is either a single type name or a
    list of acceptable type names. If every check passes,
    ``func(self, state, *params)`` is called with the names (not the
    resolved objects) and its result is returned.

    Raises:
        ValueError: if the number of params differs from
            ``len(self.param_types)``, or if a parameter has the wrong type.
        Whatever ``state.get`` raises for an unknown name.
    """
    @functools.wraps(func)
    def func_wrapper(self, state, params):
        if len(params) != len(self.param_types):
            raise ValueError("Wrong number of parameters")
        # Resolve every name first, so an unknown name fails before any type check
        objs = [state.get(p) for p in params]
        for x, (correct_type, obj) in enumerate(zip(self.param_types, objs)):
            passed_type = getType(obj)
            if isinstance(correct_type, list):
                # Multiple types are valid for this parameter
                if passed_type not in correct_type:
                    raise ValueError("Wrong parameter type. Param number: " + str(x) + " Expected: " + str(correct_type) + " Found: " + passed_type)
            elif correct_type != passed_type:
                # Single expected type; str() keeps the message buildable
                # even if correct_type is not a string
                raise ValueError("Wrong parameter type. Param number: " + str(x) + " Expected: " + str(correct_type) + " Found: " + passed_type)
        # Spread the names positionally; works for any sequence, not only lists
        return func(self, state, *params)
    return func_wrapper