-
Notifications
You must be signed in to change notification settings - Fork 0
/
autocomp.py
78 lines (68 loc) · 2.38 KB
/
autocomp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Automatic function compositions so you don't have to write them yourself.
Algorithm: Unweighted Floyd-Warshall. All pairs shortest path.
Remark: If the base functions have costs like runtime, a weighted version is possible.
"""
from functools import reduce
from itertools import product
from typing import Any, List, Callable, Tuple, Dict
# Signature of a single conversion step: one value in, one value out.
Unary = Callable[[Any], Any]

# Signature of the dispatcher returned by createComps:
# comp(value, src_label, dst_label) -> converted value.
AutoCompOutput = Callable[[Any, str, str], Any]

# Argument accepted by createComps: a list of
# (conversion function, source label, destination label) triples.
AutoCompInput = List[Tuple[Unary, str, str]]
def compose(*funcs: Unary) -> Unary:
    """Return the left-to-right composition of *funcs*.

    ``compose(f, g, h)(x)`` evaluates ``h(g(f(x)))``: the first function
    given is applied first.

    - With no arguments, returns the identity function.
    - With exactly one argument, returns that function unchanged.
    - Otherwise returns a closure that applies the functions in a loop, so
      the Python call depth does not grow with the number of functions.
    """
    if not funcs:
        return lambda x: x
    if len(funcs) == 1:
        return funcs[0]

    def composed(x: Any) -> Any:
        # Thread the value through each function in order.
        for fn in funcs:
            x = fn(x)
        return x

    return composed
def createComps(funcs: AutoCompInput) -> AutoCompOutput:
    """Precompute shortest compositions between every pair of labels.

    Each ``(f, src, dst)`` in *funcs* is a directed edge ``src -> dst``
    carried by the one-argument function ``f``. An unweighted
    Floyd-Warshall pass finds, for every ordered pair of labels, a path
    with the fewest edges and stores the composition of the functions
    along it.

    Notes:
        - If several entries share the same ``(src, dst)``, the last one wins.
        - Converting a label to itself always uses the identity function,
          even if an explicit ``(f, x, x)`` entry was supplied.

    Args:
        funcs: ``(function, source label, destination label)`` triples.

    Returns:
        ``comp(v, src, dst)``, which applies the stored composition from
        *src* to *dst* to *v*. It raises ``LookupError`` (a subclass of
        ``Exception``) when no chain of functions leads from *src* to
        *dst*, including when either label never appeared in *funcs*.
    """
    # Give every distinct label a dense index, in order of first appearance.
    vert_ids: Dict[str, int] = {}
    vert_list: List[str] = []
    for _, src, dst in funcs:
        for label in (src, dst):
            if label not in vert_ids:
                vert_ids[label] = len(vert_list)
                vert_list.append(label)
    n = len(vert_list)

    # table[i][j] is the edge count of the best known path from label i to
    # label j; infinity means "no path found yet".
    unreachable = float('inf')
    table = [[unreachable] * n for _ in range(n)]
    comp_map: Dict[Tuple[str, str], Unary] = {}
    for f, src, dst in funcs:
        table[vert_ids[src]][vert_ids[dst]] = 1
        comp_map[(src, dst)] = f
    # Done after the edges so that a self-edge (f, x, x) is overridden by
    # the zero-length identity path.
    for i, label in enumerate(vert_list):
        table[i][i] = 0
        comp_map[(label, label)] = lambda x: x

    # Floyd-Warshall: try routing every pair i -> j through intermediate k.
    for k, via in enumerate(vert_list):
        for i, j in product(range(n), range(n)):
            cost = table[i][k] + table[k][j]
            if cost < table[i][j]:
                table[i][j] = cost
                src, dst = vert_list[i], vert_list[j]
                comp_map[(src, dst)] = compose(
                    comp_map[(src, via)], comp_map[(via, dst)]
                )

    def _comp(v: Any, src: str, dst: str) -> Any:
        # Look up first and call outside the try, so a KeyError raised by a
        # user conversion function is not misreported as a missing path.
        try:
            fn = comp_map[(src, dst)]
        except KeyError:
            raise LookupError(
                f'Composition does not exist: {src!r} -> {dst!r}'
            ) from None
        return fn(v)

    return _comp
def runTests() -> None:
    """Smoke-test createComps, printing an ``Error:`` line for each failure.

    Prints nothing when every check passes.
    """
    # A three-label cycle: zero -> one -> ten -> zero, so every ordered pair
    # of labels is reachable.
    funcs1 = [
        (lambda x: x + 1, 'zero', 'one'),
        (lambda x: x + 9, 'one', 'ten'),
        (lambda x: x - 10, 'ten', 'zero'),
    ]
    comps1 = createComps(funcs1)
    samples1 = [
        ((1, 'one', 'zero'), 0),    # one -> ten -> zero
        ((10, 'ten', 'zero'), 0),   # direct edge
        ((0, 'zero', 'ten'), 10),   # zero -> one -> ten
        ((7, 'one', 'one'), 7),     # same label: identity
    ]
    for args, expected in samples1:
        actual = comps1(*args)
        if actual != expected:
            print('Error:', args, expected, 'got', actual)

    # A single one-way edge: the reverse direction has no composition and
    # must be rejected rather than silently returning a value.
    comps2 = createComps([(lambda x: x * 2, 'a', 'b')])
    if comps2(3, 'a', 'b') != 6:
        print('Error:', (3, 'a', 'b'), 6)
    try:
        comps2(3, 'b', 'a')
    except Exception:
        pass  # expected: there is no path from 'b' to 'a'
    else:
        print('Error: expected no composition for', ('b', 'a'))
# Run the smoke tests only when executed as a script, not on import.
if __name__ == "__main__":
    runTests()