-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlkanren.lua
303 lines (258 loc) · 5.56 KB
/
lkanren.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
-- Implementation of microKanren, plus some idiosyncratic extensions.
-- See: http://webyrd.net/scheme-2013/papers/HemannMuKanren2013.pdf
-- Variables
-- Metatable shared by every logic variable.
local variable_mt = {
  -- Render a variable as "_<id>". walk/extend use this string as the key
  -- into state.substitutions.
  __tostring = function(var)
    return "_" .. var.id
  end,
  -- Two logic variables are the same variable iff their ids match.
  __eq = function(left, right)
    return left.id == right.id
  end,
}
-- Create a logic variable with the given id.
-- Variables compare equal by id and stringify as "_<id>" (see variable_mt).
-- The table is built and tagged in one expression: the previous version
-- assigned it to an undeclared name, leaking a global `v` on every call.
local function variable(id)
  return setmetatable({id = id}, variable_mt)
end
-- Return true iff `v` is a logic variable, i.e. its metatable is
-- variable_mt. Every other value, including a plain table that happens to
-- have an `id` field, is treated as a concrete term.
local function isVar(v)
  local mt = getmetatable(v)
  return mt == variable_mt
end
-- State
-- Build a fresh, empty search state.
local function empty()
  local state = {}
  -- Maps the string form of a variable (tostring(var), e.g. "_0") to the
  -- term it is bound to, which may be another variable or a concrete value.
  state.substitutions = {}
  -- Number of variables allocated so far; call_fresh_n uses it as the next id.
  state.count = 0
  return state
end
-- Return a copy of `state` whose variable counter is one higher.
-- NOTE(review): relies on `table_copy`, which is not defined in this file;
-- presumably it returns a distinct table with the same fields. Confirm
-- against its definition.
local function increment(state)
  local next_state = table_copy(state)
  next_state.count = next_state.count + 1
  return next_state
end
-- Resolve term `u` against the bindings in `state`.
--
-- * A variable bound in state.substitutions is replaced by its binding, which
--   is itself walked, so a chain such as _0 -> _1 -> 5 resolves to 5.
-- * An unbound variable is returned unchanged.
-- * A non-variable table is walked element-wise into a NEW table (same keys,
--   walked values). This resolves nested terms such as lists, and the caller
--   never receives the input table itself.
-- * Any other value is returned as-is.
--
-- There is no cycle detection. A variable bound, directly or indirectly, to
-- a term that contains itself makes this recurse without end.
local function walk(u, state)
  if isVar(u) then
    local bound = state.substitutions[tostring(u)]
    if bound == nil then
      return u
    end
    return walk(bound, state)
  end
  if type(u) ~= 'table' then
    return u
  end
  local resolved = {}
  for key, value in pairs(u) do
    resolved[key] = walk(value, state)
  end
  return resolved
end
-- Return a new state in which logic variable `u` is bound to term `v`.
--
-- The input state is never modified. The substitution map is copied
-- explicitly before the new binding is added, so sibling search branches
-- that still hold the old state never see it. The explicit copy avoids
-- relying on `table_copy`, which is defined outside this file and whose
-- copy depth cannot be confirmed here; a shallow copy would alias the map.
-- Bound terms are shared by reference; nothing in this file mutates a term
-- after it has been bound.
--
-- No occurs check is performed.
--
-- Raises an error (reported at the caller) if `u` is not a logic variable.
local function extend(u, v, state)
  if not isVar(u) then
    -- tostring() keeps the message itself from failing when u is a table,
    -- nil or a boolean; plain `..` would raise "attempt to concatenate".
    error("can only extend a set of substitutions with variables, got "
      .. tostring(u), 2)
  end
  local substitutions = {}
  for key, bound in pairs(state.substitutions) do
    substitutions[key] = bound
  end
  substitutions[tostring(u)] = v
  local new = table_copy(state)
  new.substitutions = substitutions
  return new
end
-- Streams
-- Construct a stream node.
--
-- A stream is a table {first = ..., rest = ...}, interpreted as follows:
--   first == nil, rest == nil  -> empty stream
--   first == nil, rest = thunk -> immature (suspended) stream
--   first ~= nil               -> mature stream whose head is `first`, with
--                                 an optional tail thunk `rest`
-- `rest`, when present, must be a zero-argument function returning the
-- next stream (or nil).
local function stream(first, rest)
  local rest_ok = (rest == nil) or (type(rest) == 'function')
  if not rest_ok then
    error "When construction a stream, 'rest' must be a function if non-nil"
  end
  return { first = first, rest = rest }
end
-- Wrap a single state in a mature, one-element stream (no tail thunk).
local function unit(state)
  return stream(state)
end
-- Interleave two streams into one (the paper's mplus).
--
-- Heads of stream1 come out first. When stream1 is immature (no head, only a
-- thunk), the result is itself immature, and forcing it swaps the arguments.
-- stream2 therefore gets a turn, which keeps disj fair when one side never
-- produces an answer. All tails are forced lazily inside thunks.
local function merge(stream1, stream2)
  if stream1 == nil then
    return stream2
  end
  local has_first = stream1.first ~= nil
  local has_rest = stream1.rest ~= nil

  -- stream1 is empty: the result is just stream2.
  if not has_first and not has_rest then
    return stream2
  end

  -- Immature stream1: suspend, and alternate to stream2 when forced.
  if not has_first then
    return stream(nil, function()
      return merge(stream2, stream1.rest())
    end)
  end

  -- Last element of stream1: stream2 follows it unchanged.
  if not has_rest then
    return stream(stream1.first, function()
      return stream2
    end)
  end

  -- Mature stream1 with a tail: emit its head, then keep merging.
  return stream(stream1.first, function()
    return merge(stream1.rest(), stream2)
  end)
end
-- Goals are functions: state -> stream (a nil stream means failure).
--
-- bind(stream1, goal) runs `goal` on every state in `stream1` and merges the
-- resulting streams. It is lazy in the tail of `stream1`. An immature input
-- stays immature, and the remainder of a mature input is wrapped in an
-- immature stream; neither is forced until a consumer (pull/merge) asks for
-- more. Forcing the tail here would walk the whole input stream up front, so
-- a conjunction over an infinite stream would never return, and merge could
-- never interleave its answers with a sibling disjunct.
local function bind(stream1, goal)
  -- Empty or failed input: there is nothing to run the goal on.
  if stream1 == nil or (stream1.first == nil and stream1.rest == nil) then
    return stream(nil, nil)
  end
  -- Immature input: defer forcing it, mirroring the paper's
  -- (lambda () (bind ($) g)).
  if stream1.first == nil then
    return stream(nil, function()
      return bind(stream1.rest(), goal)
    end)
  end
  -- Only one state left: the result is just the goal applied to it.
  if stream1.rest == nil then
    return goal(stream1.first)
  end
  -- Mature input with a tail: answers for the head first, tail bound lazily.
  return merge(
    goal(stream1.first),
    stream(nil, function()
      return bind(stream1.rest(), goal)
    end)
  )
end
-- unify
-- Unify terms `u` and `v` under `state`.
--
-- Returns the (possibly extended) state on success, or nil if the terms
-- cannot be unified. Both terms are walked first, so bound variables are
-- replaced by their values.
--   * the same unbound variable on both sides: state unchanged
--   * a variable on either side: bind it to the other term
--   * two non-variable tables: unify value-by-value per key
--   * anything else: defer to table_eq
-- NOTE(review): `check_keys` and `table_eq` are not defined in this file.
-- Presumably check_keys(u, v) is true when both tables have the same key set,
-- and table_eq compares non-table-pair values; confirm against their
-- definitions.
local function unify(u, v, state)
  local a = walk(u, state)
  local b = walk(v, state)
  local a_is_var, b_is_var = isVar(a), isVar(b)

  if a_is_var and b_is_var and a == b then
    return state
  end
  if a_is_var then
    return extend(a, b, state)
  end
  if b_is_var then
    return extend(b, a, state)
  end

  if type(a) == 'table' and type(b) == 'table' then
    -- Tables with different keys can never unify.
    if not check_keys(a, b) then
      return nil
    end
    local current = state
    for key, a_value in pairs(a) do
      current = unify(a_value, b[key], current)
      if current == nil then
        return nil
      end
    end
    return current
  end

  if table_eq(a, b) then
    return state
  end
  return nil
end
-- goal constructors
-- Goal constructor. The goal succeeds with exactly one state when `u` and
-- `v` unify under the incoming state, and fails (nil stream) otherwise.
local function equal(u, v)
  return function(state)
    local unified = unify(u, v, state)
    if unified == nil then
      return nil
    end
    return unit(unified)
  end
end
-- Structural equality on already-walked terms, used by notequal below.
-- Two variables are equal iff they have the same id, and a variable never
-- equals a non-variable. Plain tables are equal iff they have the same keys
-- with pairwise-equal values. Everything else compares with ==.
local function terms_equal(a, b)
  local a_is_var, b_is_var = isVar(a), isVar(b)
  if a_is_var or b_is_var then
    return a_is_var and b_is_var and a.id == b.id
  end
  if type(a) == 'table' and type(b) == 'table' then
    for key, a_value in pairs(a) do
      if not terms_equal(a_value, b[key]) then
        return false
      end
    end
    for key in pairs(b) do
      if a[key] == nil then
        return false
      end
    end
    return true
  end
  return a == b
end

-- Goal constructor. The goal succeeds, leaving the state unchanged, when `u`
-- and `v` are not structurally equal after walking; it fails (nil) when they
-- are equal.
--
-- walk() builds a fresh table for every compound term. Comparing walked
-- tables with `~=` (reference inequality) would therefore treat identical
-- terms such as {1, 2} and {1, 2} as different, so a structural comparison
-- is used instead.
--
-- This is a one-shot check, not a stored constraint. Nothing is recorded in
-- the state, so a later unification that makes u and v equal is not
-- rejected. An unbound variable counts as different from every term except
-- itself.
local function notequal(u, v)
  return function(state)
    if terms_equal(walk(u, state), walk(v, state)) then
      return nil
    end
    return unit(state)
  end
end
-- Disjunction of two goals. Both goals run on the incoming state, and their
-- answer streams are interleaved with merge.
local function disj(goal1, goal2)
  return function(state)
    local left = goal1(state)
    local right = goal2(state)
    return merge(left, right)
  end
end
-- Conjunction of two goals. goal2 runs on every state produced by goal1.
local function conj(goal1, goal2)
  return function(state)
    local first_answers = goal1(state)
    return bind(first_answers, goal2)
  end
end
-- Return a goal that never succeeds: it yields a nil (empty) stream for any
-- state.
local function fail()
  local function never(_)
    return nil
  end
  return never
end
-- Return a goal that always succeeds exactly once, passing the incoming
-- state through unchanged as a one-element stream.
local function succeed()
  return function(state)
    local singleton = unit(state)
    return singleton
  end
end
-- Conjunction of any number of goals: a state is an answer only if every goal
-- accepts it.
--
-- Goals are folded left, so all(g1, g2, g3) == conj(conj(g1, g2), g3). The
-- left-nested fold places the first goal deepest in the combinator tree.
-- With no goals this returns succeed(), the identity element of conj (an
-- empty conjunction imposes no condition). It previously defaulted to
-- fail(), which is the identity of disj, not conj.
--
-- Goals are collected with ipairs, so a nil argument ends the list.
local function all(...)
  local res = nil
  for _, g in ipairs({...}) do
    if res == nil then
      res = g
    else
      res = conj(res, g)
    end
  end
  if res == nil then
    return succeed()
  end
  return res
end
-- Disjunction of any number of goals. The resulting stream interleaves the
-- answers of every goal (via disj/merge).
--
-- Goals are folded left, so any(g1, g2, g3) == disj(disj(g1, g2), g3). The
-- left-nested fold places the first goal deepest in the combinator tree.
-- With no goals this returns fail(), the identity element of disj (an empty
-- disjunction has no way to succeed). It previously defaulted to succeed(),
-- which is the identity of conj, not disj.
--
-- Goals are collected with ipairs, so a nil argument ends the list.
local function any(...)
  local res = nil
  for _, g in ipairs({...}) do
    if res == nil then
      res = g
    else
      res = disj(res, g)
    end
  end
  if res == nil then
    return fail()
  end
  return res
end
-- main entry point
-- Build a goal that introduces `n` fresh logic variables and hands them to
-- `f`.
--
-- Variable ids are taken from state.count, and the state is advanced once
-- per variable, so ids are unique along a search path. `f` receives the n
-- variables as separate arguments and must return a goal (state -> stream).
--
-- Lua 5.2 moved `unpack` into the table library (`table.unpack`). The global
-- only survives in Lua 5.1, LuaJIT, or compat builds, so resolve whichever
-- exists.
local function call_fresh_n(n, f)
  local unpack_vars = table.unpack or unpack
  return function(state)
    local vars = {}
    for i = 1, n do
      vars[i] = variable(state.count)
      state = increment(state)
    end
    -- Explicit bounds keep the call well-defined for n == 0 as well.
    return f(unpack_vars(vars, 1, n))(state)
  end
end
-- Introduce exactly one fresh variable.
-- f: variable -> goal, where a goal is state -> stream.
local call_fresh = function(f)
  return call_fresh_n(1, f)
end
-- helpers
-- Resolve the variable with id 0 against `state`. That is the first variable
-- call_fresh_n allocates from a state created by empty(), whose count starts
-- at 0.
local function reify1st(state)
  local query_var = variable(0)
  return walk(query_var, state)
end
-- Collect up to `n` answers from `stream`, forcing tail thunks as needed.
--
-- Immature nodes (first == nil) are skipped without counting toward `n`.
-- Each collected answer is passed through `e` (identity by default). If `e`
-- returns nil, nothing is appended for that answer. Collection stops early
-- when the stream ends, i.e. when it is nil or a node has no `rest` thunk.
local function pull(stream, n, e)
  local transform = e or function(x) return x end
  local results = {}
  local current = stream
  while #results < n and current ~= nil do
    local head = current.first
    if head ~= nil then
      results[#results + 1] = transform(head)
    end
    if current.rest == nil then
      return results
    end
    current = current.rest()
  end
  return results
end
-- Pull up to `n` answers from `stream` and reify the first query variable
-- (id 0) in each one.
local function reifyN1st(stream, n)
  local extractor = reify1st
  return pull(stream, n, extractor)
end
-- Public interface of the module.
local M = {
  -- Logic variables
  variable = variable,
  isVar = isVar,

  -- Substitution states
  empty = empty,
  extend = extend,
  -- TODO(andrew): reconsider exporting this.
  unify = unify,

  -- Streams
  stream = stream,
  unit = unit,

  -- Goal constructors and combinators
  equal = equal,
  notequal = notequal,
  conj = conj,
  disj = disj,
  all = all,
  any = any,
  succeed = succeed,
  fail = fail,
  call_fresh = call_fresh,
  call_fresh_n = call_fresh_n,

  -- Result extraction
  reifyN1st = reifyN1st,
}

return M