-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
124 lines (92 loc) · 3.55 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from inner_type import Polymorphic, Composite, Primitive, Slot, Arrow, unify
class Environment(dict):
    """Name-to-type bindings for inference, plus a substitution table.

    A child environment starts as a copy of its parent's bindings (later
    additions to the parent are not seen by the child) and shares the
    parent's ``type_slots`` dict, so substitutions recorded while inferring
    an inner scope are visible to the enclosing one.
    """

    def __init__(self, parent=None):
        """Create an environment, optionally inheriting from ``parent``.

        :param parent: enclosing ``Environment``, or ``None`` for a root
            scope. A plain dict is also accepted as a seed of bindings; it
            gets a fresh substitution table.
        """
        # Compare against None instead of relying on truthiness: an empty
        # parent environment is falsy as a dict, but its type_slots must still
        # be shared with the child.
        if parent is None:
            super(Environment, self).__init__()
            self.type_slots = {}
        else:
            super(Environment, self).__init__(parent)
            self.type_slots = getattr(parent, 'type_slots', {})

    def lookup(self, name):
        """Return the type bound to ``name`` in this scope, or ``None`` if unbound."""
        return self.get(name)
class Form(object):
    """Base class for syntax nodes that take part in type inference.

    Class-level counters hand out unique numeric suffixes for fresh slots.
    """

    # Shared by every Form; each is advanced by its own factory below.
    type_suffix = 0
    var_suffix = 0

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return str(self.name)

    def infer(self, env):
        """Infer this node's type in ``env``.

        Subclasses override this; the base implementation returns ``None``.
        """
        return None

    @staticmethod
    def new_type(name=None):
        """Return a fresh ``Slot`` named ``name`` (default ``''``) + a unique number."""
        Form.type_suffix += 1
        return Slot('{}{}'.format(name or '', Form.type_suffix))

    @staticmethod
    def new_var(name=None):
        """Return a fresh ``Slot`` named ``name`` (default ``'t'``) + a unique number.

        Numbering comes from ``var_suffix``, independent of ``new_type``.
        """
        Form.var_suffix += 1
        # Format with var_suffix (the counter just advanced); using
        # type_suffix here let consecutive calls produce identical names.
        return Slot('{}{}'.format(name or 't', Form.var_suffix))
class Id(Form):
    """A reference to a named binding in the environment."""

    def infer(self, env):
        """Return the type bound to this identifier in ``env``.

        A ``Polymorphic`` binding is instantiated, passing ``new_type`` as
        the factory for fresh slots; any other binding is returned as-is.

        :raises NameError: if the name is not bound in ``env``.
        """
        default_type = env.lookup(self.name)
        # lookup() yields None for unbound names; test identity so a type
        # object that happens to be falsy is not mistaken for "unbound".
        if default_type is None:
            raise NameError('name {} is not defined'.format(self.name))
        if isinstance(default_type, Polymorphic):
            return default_type.instantiate(self.new_type)
        return default_type
class Apply(Form):
    """Application of ``func`` to a single argument ``arg``."""

    def __init__(self, func, arg):
        super(Apply, self).__init__(func)
        self.func = func
        self.arg = arg

    def infer(self, env):
        """Infer the result type of applying ``func`` to ``arg``.

        The function's type is unified with a fresh ``a -> b`` arrow, then
        ``a`` (under the current substitution) is unified with the argument
        type. The result is ``b`` under the substitution in ``env.type_slots``.

        :raises Exception: if ``func``'s type does not unify with an arrow,
            or the argument type does not unify with the parameter type.
        """
        func_type = self.func.infer(env).apply(env.type_slots)
        arg_type = self.arg.infer(env).apply(env.type_slots)
        param_slot, result_slot = self.new_type(), self.new_type()
        if not unify(env.type_slots, Arrow(param_slot, result_slot), func_type):
            # Include the offending type; the old format string had only one
            # placeholder, so the type argument was silently dropped.
            raise Exception('type of {} is not a function: {}'.format(
                self.func, func_type.apply(env.type_slots)))
        expected_arg = param_slot.apply(env.type_slots)
        if not unify(env.type_slots, expected_arg, arg_type):
            raise Exception('Type incompatible for {}'.format(self.arg))
        return result_slot.apply(env.type_slots)

    def __repr__(self):
        # Parenthesize non-identifier arguments so nested applications print
        # unambiguously.
        template = '{} {}' if isinstance(self.arg, Id) else '{} ({})'
        return template.format(self.func, self.arg)
class FunctionDefine(Form):
    """Definition ``function name param = body``.

    ``local`` definitions are bound to their plain (monomorphic) type;
    otherwise the type is wrapped in ``Polymorphic`` over its free slots
    before being bound in ``env``.
    """

    def __init__(self, name, param, body, local=False):
        super(FunctionDefine, self).__init__(name)
        self.param = param
        self.body = body
        self.local = local

    def infer(self, env):
        """Infer the function's type, bind it to ``self.name`` in ``env``, and return it.

        For non-local definitions the returned value is a fresh instantiation
        of the bound ``Polymorphic`` type.

        :raises Exception: if the body's type conflicts with constraints
            already placed on the result slot (e.g. by recursive calls).
        """
        inner_env = Environment(env)
        alpha, beta = self.new_type("a"), self.new_type("b")
        func_type = Arrow(alpha, beta)
        inner_env[self.param.name] = alpha
        # Bind the function's own name so recursive calls in the body type-check.
        inner_env[self.name] = Arrow(alpha, beta)
        body_type = self.body.infer(inner_env)
        # Unify rather than assign: a recursive call in the body may already
        # have constrained beta, and writing type_slots[beta] directly would
        # discard that constraint (or bind beta to itself).
        if not unify(inner_env.type_slots, beta, body_type):
            raise Exception('Type incompatible for body of {}'.format(self.name))
        env.type_slots.update(inner_env.type_slots)
        func_type = func_type.apply(inner_env.type_slots)
        if self.local:
            env[self.name] = func_type
            return func_type
        free_slots = set()
        # NOTE(review): free slots are collected from func_type alone and are
        # not filtered against slots free in env's bindings, so such slots
        # are generalized too — confirm this is intended.
        func_type.get_free_slots(inner_env.type_slots, free_slots)
        poly_type = Polymorphic(free_slots, func_type)
        env[self.name] = poly_type
        return poly_type.instantiate(self.new_type)

    def __repr__(self):
        return 'function {} {} = {}'.format(self.name, self.param, self.body)
class Assign(Form):
    """Binding ``set name = arg``: the type of ``arg`` is recorded under ``name``."""

    def __init__(self, name, arg):
        super(Assign, self).__init__(name)
        self.arg = arg

    def infer(self, env):
        """Infer ``arg``'s type, bind it to ``self.name`` in ``env``, and return it."""
        inferred = self.arg.infer(env)
        env[self.name] = inferred
        return inferred

    def __repr__(self):
        name, value = self.name, self.arg
        return "set {} = {}".format(name, value)