-
Notifications
You must be signed in to change notification settings - Fork 0
/
torch_maze_solver.py
52 lines (46 loc) · 1.63 KB
/
torch_maze_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn.functional as F
import numpy as np
import time
def create_directional_kernels(K):
    """Build the four 4-neighbour propagation kernels used by solve_maze.

    Each kernel is a K x K tensor that is 1 at the centre cell and at exactly
    one of the four cells orthogonally adjacent to the centre (up, left,
    right, down -- in that order), and 0 everywhere else. Because only the
    cells adjacent to the centre are set, a larger K yields the same kernels
    surrounded by extra zeros.

    Args:
        K: kernel side length; must be an odd int >= 3 so the kernel has a
            single centre cell with room for all four neighbours.

    Returns:
        Tensor of shape (4, 1, K, K), i.e. (out_channels, in_channels,
        height, width) as expected for the ``weight`` of ``F.conv2d``.

    Raises:
        ValueError: if K is smaller than 3 or even.
    """
    # K <= 1 would yield no kernels at all and even K has no centre cell;
    # both previously failed later with an unrelated-looking tensor error.
    if K < 3 or K % 2 == 0:
        raise ValueError(f"K must be an odd integer >= 3, got {K}")
    middle = K // 2
    kernels = []
    # Same order as scanning the K x K grid row by row: up, left, right, down.
    for di, dj in ((-1, 0), (0, -1), (0, 1), (1, 0)):
        kernel = torch.zeros((K, K))
        kernel[middle, middle] = 1
        kernel[middle + di, middle + dj] = 1
        kernels.append(kernel)
    return torch.stack(kernels).unsqueeze(1)
def solve_maze(maze, K):
    """Return True if the bottom-right cell is reachable from the top-left cell.

    Performs a flood fill from cell (0, 0). Each step convolves the current
    set of reached cells with the kernels from ``create_directional_kernels``,
    so every reached cell also reaches its orthogonal neighbours, and then
    masks out the walls. The fill stops as soon as the bottom-right cell is
    reached (solvable) or a step adds no new cells (unsolvable).

    Args:
        maze: 2-D tensor; nonzero entries are open cells, zeros are walls.
        K: kernel size passed to ``create_directional_kernels``. Only the
            cells adjacent to the kernel centre are non-zero, so for any
            valid (odd) K the answer is the same; a larger K only adds zero
            padding to the convolution.

    Returns:
        True if a path of orthogonally adjacent open cells connects (0, 0)
        to the bottom-right cell, otherwise False. A walled start cell makes
        the maze unsolvable.
    """
    kernels = create_directional_kernels(K).to(maze.device)
    # Dense 0/1 mask of open cells, computed once (instead of rebuilding a
    # sparse tensor and densifying it every iteration), in the kernels' dtype
    # so it can gate the convolution output; shaped (1, 1, H, W) for conv2d.
    mask = (maze != 0).to(kernels.dtype).unsqueeze(0).unsqueeze(0)

    # The start cell counts as reached only if it is actually open.
    state = torch.zeros_like(mask)
    state[0, 0, 0, 0] = mask[0, 0, 0, 0]

    while True:
        propagation = F.conv2d(state, kernels, padding=K // 2)
        # Binarise: a cell is reached if it, or any neighbour, was reached.
        # Without this the additive kernels make values roughly double each
        # step, so the fixed-point test below never fires and values overflow
        # to inf, then NaN (inf * 0 on walls), looping forever.
        reached = (propagation > 0).any(dim=1, keepdim=True)
        new_state = reached.to(mask.dtype) * mask
        if new_state[0, 0, -1, -1] > 0:
            return True
        # The reached set only grows (each kernel includes the centre), so an
        # unchanged set means no further cells can be reached.
        if torch.equal(new_state, state):
            return False
        state = new_state
if __name__ == '__main__':
    # Demo: build one random N x N maze (cells are 1 with probability 0.7),
    # then solve it with several kernel sizes and report each result and
    # its elapsed time.
    N = 10
    maze = torch.tensor(np.random.choice([0, 1], size=(N, N), p=[0.3, 0.7]))
    # Force the start and goal cells open so only the walls in between matter.
    maze[0, 0], maze[N-1, N-1] = 1, 1
    print(f"Maze: \n{maze.numpy()}")
    for K in [3, 5, 7]:
        print(f"\nTesting with K = {K}")
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # and the interval now covers only the solve, not the print below.
        start = time.perf_counter()
        result = solve_maze(maze, K)
        elapsed_ms = (time.perf_counter() - start) * 1e3
        print(f"Maze is {'solvable' if result else 'not solvable'}")
        print(f"time: {elapsed_ms:.2f} ms")