Skip to content

Commit

Permalink
day20: test case and coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-ong committed Dec 28, 2023
1 parent 173a951 commit 8825640
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 27 deletions.
55 changes: 37 additions & 18 deletions day20/day20.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
import os
from multiprocessing import Pool
import shutil
from queue import Queue
from typing import Type, TypeVar, cast

import graphviz
import tqdm
from tqdm.contrib.concurrent import process_map

from day20.lib.classes import (
BaseModule,
Expand All @@ -22,6 +22,7 @@

FILE_A = "day20/input-a.txt"
FILE_B = "day20/input-b.txt"
FILE_PT2 = "day20/input-test2.txt"
FILE_PROD = "day20/input.txt"

FILE = FILE_PROD
Expand Down Expand Up @@ -155,6 +156,18 @@ def graph_modules(module_groups: ModuleGroups, index: int) -> graphviz.Digraph:
return dot


def export_graph(
dots: list[graphviz.Graph],
module_groups: ModuleGroups,
simulation_counter: int,
export_graphs: bool,
) -> None:
"""export a graphviz datatype if graphing is enabled"""
if export_graphs:
dot = graph_modules(module_groups, simulation_counter)
dots.append(dot)


def part2(
modules: list[BaseModule], export_graphs: bool = False
) -> tuple[int, list[graphviz.Graph]]:
Expand All @@ -164,10 +177,12 @@ def part2(

# graph modules in initial state
dots: list[graphviz.Graph] = []
dot: graphviz.Graph

simulation_counter = 0
loop_counter: LoopCounter = LoopCounter(len(module_groups.loops))

# output our initial state:
export_graph(dots, module_groups, simulation_counter, export_graphs)

# run simulation, screenshotting everytime one of the paths "loops"
while not loop_counter.finished:
simulate(module_map)
Expand All @@ -176,9 +191,7 @@ def part2(
if path_is_start_state(loop_path):
loop_end_name = loop_path[-1].name
loop_counter.add_result(loop_end_name, simulation_counter)
if export_graphs:
dot = graph_modules(module_groups, simulation_counter)
dots.append(dot)
export_graph(dots, module_groups, simulation_counter, export_graphs)

print(loop_counter)
result = math.lcm(*list(loop_counter.loop_lengths.values()))
Expand All @@ -198,24 +211,30 @@ def part1(modules: list[BaseModule]) -> int:
return low_total * high_total


def output_graph(dot: graphviz.Graph) -> None:
def output_graph(dot: graphviz.Graph, directory: str) -> None:
"""Saves a dot to file"""
dot.render(directory=VIS_FOLDER)
dot.render(directory=directory)


def output_graph_wrapper(args: tuple[graphviz.Graph, str]) -> None:
"""Since process_map doesnt support star_args, we gotta use this"""
dot, directory = args
output_graph(dot, directory)


def output_files(dots: list[graphviz.Graph]) -> None:
def output_files(dots: list[graphviz.Graph], directory: str) -> None:
"""Saves a list of dots to file"""
if len(dots) == 0:
return
os.makedirs(VIS_FOLDER, exist_ok=True)
with Pool() as pool:
for _ in tqdm.tqdm(pool.imap_unordered(output_graph, dots), total=len(dots)):
pass
shutil.rmtree(directory, ignore_errors=True)
os.makedirs(directory, exist_ok=True)
dot_dirs = [(dot, directory) for dot in dots]
process_map(output_graph_wrapper, dot_dirs, chunksize=4) # type: ignore

# Cleanup *.gv files
for item in os.listdir(VIS_FOLDER):
for item in os.listdir(directory):
if item.endswith(".gv"):
os.remove(os.path.join(VIS_FOLDER, item))
os.remove(os.path.join(directory, item))


def main() -> None:
Expand All @@ -227,9 +246,9 @@ def main() -> None:
# Reload because part1 ruins stuff

modules = get_modules(FILE)
result, dots = part2(modules)
result, dots = part2(modules, EXPORT_GRAPHS)
print(result)
output_files(dots)
output_files(dots, VIS_FOLDER)


if __name__ == "__main__":
Expand Down
26 changes: 26 additions & 0 deletions day20/input-test2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
broadcaster -> loop1a, loop2a, loop3a, loop4a
%loop1a -> loop1_end, loop1b
%loop1b -> loop1c
%loop1c -> loop1d
%loop1d -> loop1_end
&loop1_end -> loop1a, loop1b, loop1c, loop1_tail
&loop1_tail -> final_conj
%loop2a -> loop2_end, loop2b
%loop2b -> loop2c, loop2_end
%loop2c -> loop2d
%loop2d -> loop2_end
&loop2_end -> loop2_tail, loop2a, loop2c
&loop2_tail -> final_conj
%loop3a -> loop3_end, loop3b
%loop3b -> loop3_end, loop3c
%loop3c -> loop3d
%loop3d -> loop3_end
&loop3_end -> loop3a, loop3_tail, loop3c
&loop3_tail -> final_conj
%loop4a -> loop4_end, loop4b
%loop4b -> loop4_end, loop4c
%loop4c -> loop4d, loop4_end
%loop4d -> loop4_end
&loop4_end -> loop4_tail, loop4a
&loop4_tail -> final_conj
&final_conj -> rx
6 changes: 3 additions & 3 deletions day20/lib/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class Pulse(Flag):
HIGH = True

def __str__(self) -> str:
if self.name is None:
raise ValueError("no valid value for this Pulse")
return self.name.lower()
if self.name is not None:
return self.name.lower()
raise AssertionError("no valid value for this Pulse")


@dataclass
Expand Down
4 changes: 1 addition & 3 deletions day20/lib/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ def parse_line(line: str) -> BaseModule:
return FlipFlopModule(module_name, destination_list)
if module_type_name.startswith("&"):
return ConjunctionModule(module_name, destination_list)
raise ValueError(f"Unparsable line: {line}")
raise AssertionError(f"Unparsable line: {line}")


def get_modules(filename: str) -> list[BaseModule]:
modules: list[BaseModule] = []
with open(filename, encoding="utf8") as file:
for line in file:
if len(line.strip()) == 0:
break
module: BaseModule = parse_line(line)
modules.append(module)

Expand Down
21 changes: 18 additions & 3 deletions day20/tests/test_day20.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from day20.day20 import FILE_A, FILE_B, FILE_PROD, part1, part2
import os
import tempfile

from day20.day20 import FILE_A, FILE_B, FILE_PT2, output_files, part1, part2
from day20.lib.parsers import get_modules


Expand All @@ -11,5 +14,17 @@ def test_day20() -> None:


def test_part2() -> None:
modules = get_modules(FILE_PROD)
assert part2(modules)[0] == 252667369442479
modules = get_modules(FILE_PT2)
result, dots = part2(modules, True)
assert result == 495
with tempfile.TemporaryDirectory(prefix="unit_test_outputs") as temp_dir:
output_files(dots, temp_dir)
assert len(os.listdir(temp_dir)) == 16

# run it without exporting.
modules = get_modules(FILE_PT2)
result, dots = part2(modules, False)
assert result == 495

with tempfile.TemporaryDirectory(prefix="unit_test_outputs") as temp_dir:
output_files(dots, temp_dir)

0 comments on commit 8825640

Please sign in to comment.