Skip to content

Commit

Permalink
day23: part2 multithread
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-ong committed Dec 23, 2023
1 parent 4992ffb commit f1b95d1
Showing 1 changed file with 91 additions and 35 deletions.
126 changes: 91 additions & 35 deletions day23/lib/classes2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""part 2 solution"""
import time
from dataclasses import dataclass, field
from multiprocessing import Pool
from queue import Queue

import colorama
import tqdm

from day23.lib.classes import BasePath, Maze, Path, Position, Solver

colorama.init(convert=True)


@dataclass(eq=True)
class Node:
Expand Down Expand Up @@ -70,6 +73,67 @@ def __len__(self) -> int:
return self.path_length


def expand_node_path(node_path: NodePath, nodes: list[Node]) -> list[NodePath]:
"""Expands a node path, giving back a list of of NodePaths"""
last_node: Node = nodes[node_path.last()]
result = []
for edge in last_node.edges:
target_node_id: int = edge.node2
if node_path.can_add(target_node_id):
to_add = node_path.copy()
to_add.add(target_node_id, len(edge.path))
result.append(to_add)
return result


def worker_solve(
nodes: list[Node],
paths_to_process: list[NodePath],
break_early: bool,
thread_id: int,
) -> tuple[list[BasePath], list[NodePath]]:
results: list[BasePath] = []
unfinished_paths: Queue[NodePath] = Queue()
for item in paths_to_process:
unfinished_paths.put(item)

pbar = tqdm.tqdm(
desc=f"Thread{thread_id}", total=len(paths_to_process), position=thread_id
)
if break_early:
pbar.total = 10000
pbar.set_description("Initial run")

while not unfinished_paths.empty():
path = unfinished_paths.get()
node_id: int = path.last()
if node_id == nodes[-1].name: # end node
results.append(path)

if break_early:
pbar.update()
if pbar.n % 10000 == 0:
break

continue

expansions = expand_node_path(path, nodes)

if not break_early:
pbar.total += len(expansions)
pbar.update()

for p in expansions:
unfinished_paths.put(p)
pbar.close()
return results, list(unfinished_paths.queue)


def split_list(items: list[NodePath], num_chunks: int) -> list[list[NodePath]]:
chunk_size = (len(items) // num_chunks) + 1
return [items[i * chunk_size : (i + 1) * chunk_size] for i in range(num_chunks)]


class Solver2(Solver):
maze: Maze

Expand Down Expand Up @@ -160,19 +224,6 @@ def expand_path(self, path: Path) -> list[Path]:

return result

def expand_node_path(
self, node_path: NodePath, nodes: list[Node]
) -> list[NodePath]:
last_node: Node = nodes[node_path.last()]
result = []
for edge in last_node.edges:
target_node_id: int = edge.node2
if node_path.can_add(target_node_id):
to_add = node_path.copy()
to_add.add(target_node_id, len(edge.path))
result.append(to_add)
return result

def build_nodes(self) -> list[Node]:
nodes: dict[Position, Node] = self.get_nodes()
print(self.maze)
Expand All @@ -183,29 +234,34 @@ def build_nodes(self) -> list[Node]:

def solve(self) -> list[BasePath]:
nodes: list[Node] = self.build_nodes()
# print our nodes out:
print("\n".join(str(node) for node in nodes))

first_path = NodePath()
first_path.add(0)
paths: Queue[NodePath] = Queue()
paths.put(first_path)
print("\n".join(str(node) for node in nodes))
last = time.time()
results: list[BasePath] = []
count = 0
while not paths.empty():
path: NodePath = paths.get()
node_id: int = path.last()
if node_id == nodes[-1].name: # end node
# reached an edge
count += 1
results.append(path)
if count % 10000 == 0:
print(paths.qsize(), path, time.time() - last)
last = time.time()
continue
expansions = self.expand_node_path(path, nodes)
for path in expansions:
paths.put(path)

unfinished_paths: list[NodePath] = []
unfinished_paths.append(first_path)

results, unfinished_paths = worker_solve(nodes, unfinished_paths, True, 0)

# time for multithreading!
num_workers = 8
unfinished_chunks: list[list[NodePath]] = split_list(
unfinished_paths, num_workers
)

with Pool(num_workers) as pool:
worker_args = [
(nodes, unfinished_chunks[i], False, i) for i in range(num_workers)
]
result_objects = pool.starmap_async(worker_solve, worker_args)
pool_results = result_objects.get()
for pool_result in pool_results:
paths = pool_result[0]
results.extend(paths)
print("\n" * num_workers * 2) # fix bug in progress bars
# split unfinished_paths:
results.sort(key=lambda x: len(x), reverse=True)

print("total results:", len(results))
return results

0 comments on commit f1b95d1

Please sign in to comment.