-
Notifications
You must be signed in to change notification settings - Fork 38
/
type_check_Lfun.py
134 lines (128 loc) · 4.9 KB
/
type_check_Lfun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import ast
from ast import *
from type_check_Larray import TypeCheckLarray
from utils import *
import typing
class TypeCheckLfun(TypeCheckLarray):
def check_type_equal(self, t1, t2, e):
if t1 == Bottom() or t2 == Bottom():
return
match t1:
case FunctionType(ps1, rt1):
match t2:
case FunctionType(ps2, rt2):
for (p1,p2) in zip(ps1, ps2):
self.check_type_equal(p1, p2, e)
self.check_type_equal(rt1, rt2, e)
case _:
raise Exception('error: ' + repr(t1) + ' != ' + repr(t2) \
+ ' in ' + repr(e))
case _:
super().check_type_equal(t1, t2, e)
def parse_type_annot(self, annot):
match annot:
case Name(id):
if id == 'int':
return IntType()
elif id == 'bool':
return BoolType()
else:
raise Exception('parse_type_annot: unexpected ' + repr(annot))
case TupleType(ts):
return TupleType([self.parse_type_annot(t) for t in ts])
case ListType(elt_ty):
return ListType(self.parse_type_annot(elt_ty))
case FunctionType(ps, rt):
return FunctionType([self.parse_type_annot(t) for t in ps],
self.parse_type_annot(rt))
case Subscript(Name('Callable'), Tuple([ps, rt])):
return FunctionType([self.parse_type_annot(t) for t in ps.elts],
self.parse_type_annot(rt))
case Subscript(Name('tuple'), Tuple(ts)):
return TupleType([self.parse_type_annot(t) for t in ts])
case Subscript(Name('tuple'), t):
return TupleType([self.parse_type_annot(t)])
case Subscript(Name('list'), ty):
return ListType(self.parse_type_annot(ty))
case IntType():
return annot
case BoolType():
return annot
case VoidType():
return annot
case t if t == int:
return IntType()
case t if t == bool:
return BoolType()
case t if t == type(None):
return VoidType()
case Constant(None):
return VoidType()
case _:
raise Exception('parse_type_annot: unexpected ' + repr(annot))
def type_check_exp(self, e, env):
match e:
case FunRef(id, arity):
return env[id]
case Call(Name(f), args) if f in builtin_functions:
return super().type_check_exp(e, env)
case Call(func, args):
func_t = self.type_check_exp(func, env)
args_t = [self.type_check_exp(arg, env) for arg in args]
match func_t:
case FunctionType(params_t, return_t):
for (arg_t, param_t) in zip(args_t, params_t):
self.check_type_equal(param_t, arg_t, e)
return return_t
case _:
raise Exception('type_check_exp: in call, unexpected ' + \
repr(func_t))
case Constant(None):
return VoidType()
case _:
return super().type_check_exp(e, env)
def type_check_stmts(self, ss, env):
if len(ss) == 0:
return VoidType()
match ss[0]:
case FunctionDef(name, params, body, dl, returns, comment):
new_env = {x: t for (x,t) in env.items()}
if isinstance(params, ast.arguments):
new_params = [(p.arg, self.parse_type_annot(p.annotation)) \
for p in params.args]
ss[0].args = new_params
new_returns = self.parse_type_annot(returns)
ss[0].returns = new_returns
else:
new_params = params
new_returns = returns
unique_names = {x for x,_ in new_params}
if len(unique_names) != len(new_params):
raise Exception('type_check: duplicate parameter name in function ' + repr(name))
for x,t in new_params:
new_env[x] = t
rt = self.type_check_stmts(body, new_env)
self.check_type_equal(new_returns, rt, ss[0])
return self.type_check_stmts(ss[1:], env)
case Return(value):
return self.type_check_exp(value, env)
case _:
return super().type_check_stmts(ss, env)
def type_check(self, p):
match p:
case Module(body):
env = {}
for s in body:
match s:
case FunctionDef(name, params, bod, dl, returns, comment):
if isinstance(params, ast.arguments):
params_t = [self.parse_type_annot(p.annotation) \
for p in params.args]
else:
params_t = [t for (x,t) in params]
if name in env:
raise Exception('type_check: duplicate function name ' + name)
env[name] = FunctionType(params_t, self.parse_type_annot(returns))
self.type_check_stmts(body, env)
case _:
raise Exception('type_check: unexpected ' + repr(p))