-
Notifications
You must be signed in to change notification settings - Fork 7
/
solver.sage
85 lines (68 loc) · 2.18 KB
/
solver.sage
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from sage.modules.free_module_integer import IntegerLattice
# Directly taken from rbtree's LLL repository
# From https://oddcoder.com/LOL-34c3/, https://hackmd.io/@hakatashi/B1OM7HFVI
def Babai_CVP(mat, target):
    """Approximate the lattice vector closest to ``target`` (Babai nearest plane).

    The rows of ``mat`` are used as a lattice basis.  The basis is LLL-reduced
    through ``IntegerLattice``, Gram-Schmidt orthogonalised, and then the
    target is peeled off one basis vector at a time, from the last row to the
    first, rounding each projection coefficient to the nearest integer.

    Parameters:
        mat: integer matrix whose rows span the lattice.
        target: vector with as many entries as ``mat`` has columns.

    Returns:
        The lattice vector reached by the rounding procedure, computed as
        ``target`` minus the residual left after all subtractions.
    """
    basis = IntegerLattice(mat, lll_reduce=True).reduced_basis
    ortho = basis.gram_schmidt()[0]
    residual = target
    # Walk the reduced basis from the last vector down to the first.
    for idx in range(ortho.nrows() - 1, -1, -1):
        direction = ortho[idx]
        # Integer multiple of basis[idx] that best cancels the residual
        # along the idx-th Gram-Schmidt direction.
        coeff = ((residual * direction) / (direction * direction)).round()
        residual -= basis[idx] * coeff
    return target - residual
def solve(M, lbounds, ubounds, weight = None):
    """Search for an integer combination of the rows of ``M`` inside given bounds.

    Each column of ``M`` is treated as one constraint: for the unknown integer
    row vector ``x`` the i-th entry of ``x * M`` should lie in
    ``[lbounds[i], ubounds[i]]``.  Columns are rescaled so that all intervals
    get a comparable width, then the midpoint of the scaled box is handed to
    ``Babai_CVP`` as the closest-vector target.

    Parameters:
        M: matrix with one row per unknown and one column per constraint.
        lbounds: lower bounds, one per column of ``M``.
        ubounds: upper bounds, one per column of ``M``.
        weight: scale factor used for equality constraints
            (``lbounds[i] == ubounds[i]``).  Defaults to
            ``(number of columns) * (largest absolute entry of M)``.

    Returns:
        ``None`` if a sanity check on the bounds fails (a message is printed).
        Otherwise a tuple ``(result, applied_weights, fin)`` where

        * ``result`` is the lattice vector found, in the *scaled* coordinates
          (divide entry ``i`` by ``applied_weights[i]`` to undo the scaling);
        * ``applied_weights`` lists the scale factor applied to each column;
        * ``fin`` is the coefficient vector ``x`` with ``x * M_scaled == result``
          when ``M`` is square with non-zero determinant, else ``None``.

    The inputs are copied first, so the caller's matrix and bound lists are
    not modified.  Diagnostic messages are printed rather than raised.
    """
    mat, lb, ub = copy(M), copy(lbounds), copy(ubounds)
    num_var = mat.nrows()
    num_ineq = mat.ncols()

    # The largest entry is only needed to derive the default weight.
    if weight is None:
        max_element = max(
            (abs(mat[i, j]) for i in range(num_var) for j in range(num_ineq)),
            default=0,
        )
        weight = num_ineq * max_element

    # sanity checker
    if len(lb) != num_ineq:
        print("Fail: len(lb) != num_ineq")
        return None
    if len(ub) != num_ineq:
        print("Fail: len(ub) != num_ineq")
        return None
    for i in range(num_ineq):
        if lb[i] > ub[i]:
            print("Fail: lb[i] > ub[i] at index", i)
            return None

    # Interval widths in the original (unscaled) coordinates; captured now
    # because lb/ub are rescaled in place further down.
    widths = [ub[i] - lb[i] for i in range(num_ineq)]

    # heuristic for number of solutions
    DET = 0
    if num_var == num_ineq:
        DET = abs(mat.det())
    num_sol = 1
    for width in widths:
        num_sol *= width
    if DET == 0:
        print("Zero Determinant")
    else:
        num_sol //= DET
    # + 1 added in for the sake of not making it zero...
    print("Expected Number of Solutions : ", num_sol + 1)

    # scaling process begins
    max_diff = max(widths)
    applied_weights = []
    for i in range(num_ineq):
        # Equalities get the (large) weight; every other column is stretched
        # so that its interval is roughly as wide as the widest one.
        ineq_weight = weight if lb[i] == ub[i] else max_diff // widths[i]
        applied_weights.append(ineq_weight)
        for j in range(num_var):
            mat[j, i] *= ineq_weight
        lb[i] *= ineq_weight
        ub[i] *= ineq_weight

    # Solve CVP, aiming at the centre of the scaled box.
    target = vector([(lb[i] + ub[i]) // 2 for i in range(num_ineq)])
    result = Babai_CVP(mat, target)

    # Report (once) if the vector found falls outside the scaled bounds;
    # the result is still returned so the caller can inspect it.
    for i in range(num_ineq):
        if not (lb[i] <= result[i] <= ub[i]):
            print("Fail : inequality does not hold after solving")
            break

    # recover x: only possible when the (square) system is non-singular.
    fin = None
    if DET != 0:
        fin = mat.transpose().solve_right(result)

    ## recover your result
    return result, applied_weights, fin