diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 28a10e8..1deb3bc 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -1048,3 +1048,35 @@ def test_triple_split(self): _SubGraphData(split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2"), ], ) + + @pytest.mark.parametrize( + ["primary_backend", "secondary_graph"], + [ + ("b1", _SubGraphData(split_node="lc2", node_ids={"lc2"}, backend_id="b2")), + ("b2", _SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1")), + ], + ) + def test_split_with_primary_backend(self, primary_backend, secondary_graph): + """Test `primary_backend` argument of DeepGraphSplitter""" + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}), + primary_backend=primary_backend, + ) + flat = { + # lc1 lc2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc1", "lc2", "merge"}, + primary_backend_id=primary_backend, + secondary_graphs=[secondary_graph], + )