diff --git a/src/py_d2/connection.py b/src/py_d2/connection.py index 215a354..b65f8f1 100644 --- a/src/py_d2/connection.py +++ b/src/py_d2/connection.py @@ -26,3 +26,11 @@ def lines(self) -> List[str]: def __repr__(self) -> str: return "\n".join(self.lines()) + + def __hash__(self): + return hash((self.shape_1, self.shape_2, self.label, self.direction)) + + def __eq__(self, other) -> bool: + if ((self.shape_1, self.shape_2, self.direction, self.label) == (other.shape_1, other.shape_2, other.direction, other.label)): + return True + return False diff --git a/tests/test_py_d2/test_d2_connection.py b/tests/test_py_d2/test_d2_connection.py index 954992c..2c03071 100644 --- a/tests/test_py_d2/test_d2_connection.py +++ b/tests/test_py_d2/test_d2_connection.py @@ -31,3 +31,9 @@ def test_d2_connection_direction_both(): def test_d2_connection_direction_none(): connection = D2Connection(shape_1="a", shape_2="b", direction=Direction.NONE) assert str(connection) == "a -- b" + + +def test_d2_connection_uniqueness(): + connection = D2Connection(shape_1="a", shape_2="b", direction=Direction.TO) + connection2 = D2Connection(shape_1="a", shape_2="b", direction=Direction.TO) + assert connection == connection2