From 14619ec5c1336e2db500aa9189a177ed2a5f6725 Mon Sep 17 00:00:00 2001 From: Chris Mutel Date: Wed, 27 Nov 2024 08:50:00 +0100 Subject: [PATCH] Add more filter options to `link_iterable_by_fields` --- bw2io/importers/base_lci.py | 35 +- bw2io/importers/base_lcia.py | 2 +- bw2io/importers/ecospold1.py | 2 +- bw2io/importers/excel.py | 2 +- bw2io/importers/excel_lcia.py | 2 +- bw2io/importers/json_ld.py | 4 +- bw2io/importers/simapro_block_csv.py | 4 +- bw2io/importers/simapro_csv.py | 2 +- bw2io/strategies/generic.py | 46 +- tests/strategies/link_iterable.py | 841 +++++++++++++++------------ 10 files changed, 544 insertions(+), 396 deletions(-) diff --git a/bw2io/importers/base_lci.py b/bw2io/importers/base_lci.py index 6c0e35c0..3fe1c5de 100644 --- a/bw2io/importers/base_lci.py +++ b/bw2io/importers/base_lci.py @@ -389,12 +389,16 @@ def write_excel( def match_database( self, - db_name=None, - fields=None, - ignore_categories=False, - relink=False, - kind=None, - ): + db_name: Optional[str] = None, + fields: Optional[List[str]] = None, + ignore_categories: bool = False, + relink: bool = False, + kind: Optional[Union[List[str], str]] = None, + edge_kinds: Optional[List[str]] = None, + this_node_kinds: Optional[List[str]] = None, + other_node_kinds: Optional[List[str]] = None, + processes_to_products: bool = False, + ) -> None: """Match current database against itself or another database. If ``db_name`` is None, match against current data. Otherwise, ``db_name`` should be the name of an existing ``Database``. @@ -410,9 +414,24 @@ def match_database( Nothing is returned, but ``self.data`` is changed. 
""" + if kind is not None: + warnings.warn( + "`kind` is deprecated, please use `edge_kinds` instead", + DeprecationWarning, + ) + edge_kinds = [kind] if isinstance(kind, str) else list(kind) + + if processes_to_products: + this_node_kinds = labels.process_node_types + [ + labels.multifunctional_node_default + ] + other_node_kinds = labels.product_node_types + kwargs = { "fields": fields, - "kind": kind, + "edge_kinds": edge_kinds, + "this_node_kinds": this_node_kinds, + "other_node_kinds": other_node_kinds, "relink": relink, } if fields and ignore_categories: @@ -648,7 +667,7 @@ def reformat(exc): for obj in Database(biosphere_name) if obj.get("type") == "emission" ), - kind="biosphere", + edge_kinds=["biosphere"], ), ) diff --git a/bw2io/importers/base_lcia.py b/bw2io/importers/base_lcia.py index 58565893..4bda2841 100644 --- a/bw2io/importers/base_lcia.py +++ b/bw2io/importers/base_lcia.py @@ -41,7 +41,7 @@ def __init__(self, filepath, biosphere=None): for obj in Database(self.biosphere_name) if obj.get("type") == "emission" ), - kind="biosphere", + edge_kinds=["biosphere"], ), functools.partial( match_subcategories, biosphere_db_name=self.biosphere_name diff --git a/bw2io/importers/ecospold1.py b/bw2io/importers/ecospold1.py index 9f91c3cd..cba16532 100644 --- a/bw2io/importers/ecospold1.py +++ b/bw2io/importers/ecospold1.py @@ -72,7 +72,7 @@ def __init__( functools.partial( link_iterable_by_fields, other=Database(config.biosphere), - kind="biosphere", + edge_kinds=["biosphere"], ), functools.partial( link_technosphere_by_activity_hash, diff --git a/bw2io/importers/excel.py b/bw2io/importers/excel.py index 8af81709..d28004e3 100644 --- a/bw2io/importers/excel.py +++ b/bw2io/importers/excel.py @@ -93,7 +93,7 @@ def __init__(self, filepath): functools.partial( link_iterable_by_fields, other=Database(config.biosphere), - kind="biosphere", + edge_kinds=["biosphere"], ), assign_only_product_as_production, link_technosphere_by_activity_hash, diff --git 
a/bw2io/importers/excel_lcia.py
b/bw2io/importers/excel_lcia.py index 885b6d95..696beba7 100644 --- a/bw2io/importers/excel_lcia.py +++ b/bw2io/importers/excel_lcia.py @@ -67,7 +67,7 @@ def __init__(self, filepath, name, description, unit, **metadata): functools.partial( link_iterable_by_fields, other=Database(config.biosphere), - kind="biosphere", + edge_kinds=["biosphere"], fields=("name", "categories"), ), drop_falsey_uncertainty_fields_but_keep_zeros, diff --git a/bw2io/importers/json_ld.py b/bw2io/importers/json_ld.py index f9c930f1..ba125a7c 100644 --- a/bw2io/importers/json_ld.py +++ b/bw2io/importers/json_ld.py @@ -65,14 +65,14 @@ def __init__(self, dirpath, database_name, preferred_allocation=None): partial( link_iterable_by_fields, fields=["code"], - kind={"production", "technosphere"}, + edge_kinds=["production", "technosphere"], internal=True, ), partial( link_iterable_by_fields, other=self.biosphere_database, fields=["code"], - kind={"biosphere"}, + edge_kinds=["biosphere"], ), normalize_units, ] diff --git a/bw2io/importers/simapro_block_csv.py b/bw2io/importers/simapro_block_csv.py index 5a44d2f7..dd68c95a 100644 --- a/bw2io/importers/simapro_block_csv.py +++ b/bw2io/importers/simapro_block_csv.py @@ -70,7 +70,7 @@ def __init__( functools.partial( link_iterable_by_fields, other=Database(biosphere_database_name or config.biosphere), - kind=labels.biosphere_edge_types, + edge_kinds=labels.biosphere_edge_types, ), match_internal_simapro_simapro_with_unit_conversion, ] @@ -226,7 +226,7 @@ def use_ecoinvent_strategies(self) -> None: other=Database( self.default_biosphere_database_name or config.biosphere ), - kind="biosphere", + edge_kinds=["biosphere"], ), ] diff --git a/bw2io/importers/simapro_csv.py b/bw2io/importers/simapro_csv.py index 71031404..9d4b9851 100644 --- a/bw2io/importers/simapro_csv.py +++ b/bw2io/importers/simapro_csv.py @@ -92,7 +92,7 @@ def __init__( functools.partial( link_iterable_by_fields, other=Database(biosphere_db or config.biosphere), - kind="biosphere", + 
edge_kinds=["biosphere"], ), convert_activity_parameters_to_list, ] diff --git a/bw2io/strategies/generic.py b/bw2io/strategies/generic.py index 4a396c61..0751ef81 100644 --- a/bw2io/strategies/generic.py +++ b/bw2io/strategies/generic.py @@ -1,5 +1,6 @@ import numbers import pprint +import warnings from collections import defaultdict from copy import deepcopy from typing import Iterable, List, Optional, Union @@ -75,6 +76,9 @@ def link_iterable_by_fields( other: Optional[Iterable[dict]] = None, fields: Optional[List[str]] = None, kind: Optional[Union[str, List[str]]] = None, + edge_kinds: Optional[List[str]] = None, + this_node_kinds: Optional[List[str]] = None, + other_node_kinds: Optional[List[str]] = None, internal: bool = False, relink: bool = False, ) -> List[dict]: @@ -91,9 +95,15 @@ def link_iterable_by_fields( fields : iterable[str], optional An iterable of strings indicating which fields should be used to match objects. If not specified, all fields will be used. - kind : str|list[string], optional + kind : deprecated + Use `edge_kinds` instead + edge_kinds : str|list[string], optional If specified, limit the exchange to objects of the given kind. `kind` can be a string or an iterable of strings. + this_node_kinds : list[string], optional + If specified, limit linking to objects in `unlinked` which have `type` in `this_node_kinds`. + other_node_kinds : list[string], optional + If specified, limit linking to objects in `other` which have `type` in `other_node_kinds`. internal : bool, optional If `True`, link objects in `unlinked` to other objects in `unlinked`. Each object must have the attributes `database` and `code`. 
@@ -153,17 +163,33 @@ def link_iterable_by_fields( >>> linked[1]["exchanges"][0]["input"] ('db2', 'C') """ - if kind: - kind = {kind} if isinstance(kind, str) else kind + if kind is not None: + warnings.warn( + "`kind` is deprecated, please use `edge_kinds` instead", DeprecationWarning + ) + edge_kinds = kind + + if edge_kinds: + edge_kinds = {edge_kinds} if isinstance(edge_kinds, str) else edge_kinds if relink: - filter_func = lambda x: x.get("type") in kind + edge_filter_func = lambda x: x.get("type") in edge_kinds else: - filter_func = lambda x: x.get("type") in kind and not x.get("input") + edge_filter_func = lambda x: x.get("type") in edge_kinds and not x.get( + "input" + ) else: if relink: - filter_func = lambda x: True + edge_filter_func = lambda x: True else: - filter_func = lambda x: not x.get("input") + edge_filter_func = lambda x: not x.get("input") + if this_node_kinds: + this_filter_func = lambda x: x.get("type") in this_node_kinds + else: + this_filter_func = lambda x: True + if other_node_kinds: + other_filter_func = lambda x: x.get("type") in other_node_kinds + else: + other_filter_func = lambda x: True if internal: other = unlinked @@ -171,7 +197,7 @@ def link_iterable_by_fields( duplicates, candidates = {}, {} try: # Other can be a generator, so a bit convoluted - for ds in other: + for ds in filter(other_filter_func, other): key = activity_hash(ds, fields) if key in candidates: duplicates.setdefault(key, []).append(ds) @@ -183,8 +209,8 @@ def link_iterable_by_fields( "``database`` or ``code`` attributes" ) - for container in unlinked: - for obj in filter(filter_func, container.get("exchanges", [])): + for container in filter(this_filter_func, unlinked): + for obj in filter(edge_filter_func, container.get("exchanges", [])): key = activity_hash(obj, fields) if key in duplicates: raise StrategyError( diff --git a/tests/strategies/link_iterable.py b/tests/strategies/link_iterable.py index 188bfbdd..c273199b 100644 --- 
a/tests/strategies/link_iterable.py +++ b/tests/strategies/link_iterable.py @@ -1,387 +1,490 @@ import copy import unittest +import pytest + from bw2io.errors import StrategyError from bw2io.strategies import link_iterable_by_fields -class LinkIterableTestCase(unittest.TestCase): - def test_all_datasets_in_target_have_database_field(self): - self.assertEqual( - link_iterable_by_fields([], [{"database": "foo", "code": "bar"}]), [] - ) - with self.assertRaises(StrategyError): - link_iterable_by_fields([], [{"code": "bar"}]) +def test_all_datasets_in_target_have_database_field(): + assert link_iterable_by_fields([], [{"database": "foo", "code": "bar"}]) == [] + with pytest.raises(StrategyError): + link_iterable_by_fields([], [{"code": "bar"}]) - def test_all_datasets_in_target_have_code_field(self): - self.assertEqual( - link_iterable_by_fields([], [{"database": "foo", "code": "bar"}]), [] - ) - with self.assertRaises(StrategyError): - link_iterable_by_fields([], [{"database": "foo"}]) - def test_nonunique_target_but_not_linked_no_error(self): - data = [ - {"name": "foo", "database": "a", "code": "b"}, - {"name": "foo", "database": "a", "code": "c"}, - {"name": "bar", "database": "a", "code": "d"}, - ] - self.assertEqual( - link_iterable_by_fields([{"exchanges": [{"name": "bar"}]}], data), - [{"exchanges": [{"name": "bar", "input": ("a", "d")}]}], - ) +def test_all_datasets_in_target_have_code_field(): + assert link_iterable_by_fields([], [{"database": "foo", "code": "bar"}]) == [] + with pytest.raises(StrategyError): + link_iterable_by_fields([], [{"database": "foo"}]) - def test_nonunique_target_raises_error(self): - data = [ - {"name": "foo", "database": "a", "code": "b"}, - {"name": "foo", "database": "a", "code": "c"}, - {"name": "bar", "database": "a", "code": "d"}, - ] - with self.assertRaises(StrategyError): - link_iterable_by_fields([{"exchanges": [{"name": "foo"}]}], data) - def test_generic_linking_no_kind_no_relink(self): - unlinked = [ - { - "exchanges": 
[ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - }, - { - "name": "foo", - "unit": "kilogram", - "type": "b", - }, - ] - } - ] - other = [ - {"name": "foo", "categories": ("bar",), "database": "db", "code": "first"}, - {"name": "baz", "categories": ("bar",), "database": "db", "code": "second"}, - ] - expected = [ - { - "exchanges": [ - { - "name": "foo", - "type": "a", - "categories": ("bar",), - "input": ("db", "first"), - }, - {"name": "foo", "type": "b", "unit": "kilogram"}, - ] - } - ] - self.assertEqual(expected, link_iterable_by_fields(unlinked, other)) +def test_nonunique_target_but_not_linked_no_error(): + data = [ + {"name": "foo", "database": "a", "code": "b"}, + {"name": "foo", "database": "a", "code": "c"}, + {"name": "bar", "database": "a", "code": "d"}, + ] + assert link_iterable_by_fields([{"exchanges": [{"name": "bar"}]}], data) == [ + {"exchanges": [{"name": "bar", "input": ("a", "d")}]} + ] - def test_internal_linking(self): - unlinked = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - {"name": "foo", "categories": ("bar",)}, - {"name": "foo", "categories": ("baz",)}, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - expected = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "input": ("db", "first"), - }, - { - "name": "foo", - "categories": ("baz",), - "input": ("db", "second"), - }, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - self.assertEqual(expected, link_iterable_by_fields(unlinked, internal=True)) - def test_kind_filter(self): - unlinked = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - 
}, - {"name": "foo", "categories": ("baz",)}, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - expected = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - "input": ("db", "first"), - }, - { - "name": "foo", - "categories": ("baz",), - }, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - self.assertEqual( - expected, link_iterable_by_fields(unlinked, internal=True, kind="a") - ) - self.assertEqual( - expected, link_iterable_by_fields(unlinked, internal=True, kind=["a"]) - ) +def test_nonunique_target_raises_error(): + data = [ + {"name": "foo", "database": "a", "code": "b"}, + {"name": "foo", "database": "a", "code": "c"}, + {"name": "bar", "database": "a", "code": "d"}, + ] + with pytest.raises(StrategyError): + link_iterable_by_fields([{"exchanges": [{"name": "foo"}]}], data) - def test_kind_filter_and_relink(self): - unlinked = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - "input": ("something", "else"), - }, - {"name": "foo", "categories": ("baz",)}, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - expected = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - "input": ("db", "first"), - }, - { - "name": "foo", - "categories": ("baz",), - }, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - self.assertEqual( - expected, - link_iterable_by_fields(unlinked, internal=True, kind="a", relink=True), - ) - def 
test_relink(self): - unlinked = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - "input": ("something", "else"), - }, - { - "name": "foo", - "type": "b", - "input": ("something", "else"), - "categories": ("baz",), - }, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - expected = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - "input": ("db", "first"), - }, - { - "name": "foo", - "type": "b", - "input": ("db", "second"), - "categories": ("baz",), - }, - ], - }, - { - "database": "db", - "code": "second", - "name": "foo", - "categories": ("baz",), - "exchanges": [], - }, - ] - self.assertEqual( - expected, link_iterable_by_fields(unlinked, internal=True, relink=True) - ) +def test_generic_linking_no_kind_no_relink(): + unlinked = [ + { + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + }, + { + "name": "foo", + "unit": "kilogram", + "type": "b", + }, + ] + } + ] + other = [ + {"name": "foo", "categories": ("bar",), "database": "db", "code": "first"}, + {"name": "baz", "categories": ("bar",), "database": "db", "code": "second"}, + ] + expected = [ + { + "exchanges": [ + { + "name": "foo", + "type": "a", + "categories": ("bar",), + "input": ("db", "first"), + }, + {"name": "foo", "type": "b", "unit": "kilogram"}, + ] + } + ] + assert link_iterable_by_fields(unlinked, other) == expected + + +def test_internal_linking(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + {"name": "foo", "categories": ("bar",)}, + {"name": "foo", "categories": ("baz",)}, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + 
"exchanges": [], + }, + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "input": ("db", "first"), + }, + { + "name": "foo", + "categories": ("baz",), + "input": ("db", "second"), + }, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + assert link_iterable_by_fields(unlinked, internal=True) == expected + + +def test_edge_kinds_filter(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + }, + {"name": "foo", "categories": ("baz",)}, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + "input": ("db", "first"), + }, + { + "name": "foo", + "categories": ("baz",), + }, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + assert ( + link_iterable_by_fields(unlinked, internal=True, edge_kinds=["a"]) == expected + ) + assert link_iterable_by_fields(unlinked, internal=True, edge_kinds="a") == expected - def test_linking_with_fields(self): - unlinked = [ - { - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "type": "a", - }, - { - "name": "foo", - "categories": ("baz",), - "unit": "kilogram", - "type": "b", - }, - ] - } - ] - other = [ - {"name": "foo", "categories": ("bar",), "database": "db", "code": "first"}, - {"name": "foo", "categories": ("baz",), "database": "db", "code": "second"}, - ] - expected = [ - { - "exchanges": [ - { - "name": "foo", - "type": "a", - "categories": ("bar",), - "input": ("db", 
"first"), - }, - { - "name": "foo", - "type": "b", - "categories": ("baz",), - "input": ("db", "second"), - "unit": "kilogram", - }, - ] - } - ] - self.assertEqual( - expected, - link_iterable_by_fields(unlinked, other, fields=["name", "categories"]), - ) - def test_no_relink_skips_linking(self): - unlinked = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "input": ("something", "else"), - } - ], - } - ] - expected = [ - { - "database": "db", - "code": "first", - "name": "foo", - "categories": ("bar",), - "exchanges": [ - { - "name": "foo", - "categories": ("bar",), - "input": ("db", "first"), - } - ], - } - ] - self.assertEqual( - unlinked, link_iterable_by_fields(copy.deepcopy(unlinked), internal=True) +def test_edge_kinds_filter_and_relink(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + "input": ("something", "else"), + }, + {"name": "foo", "categories": ("baz",)}, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + "input": ("db", "first"), + }, + { + "name": "foo", + "categories": ("baz",), + }, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + assert ( + link_iterable_by_fields(unlinked, internal=True, edge_kinds=["a"], relink=True) + == expected + ) + + +def test_relink(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + "input": ("something", "else"), + }, + { + 
"name": "foo", + "type": "b", + "input": ("something", "else"), + "categories": ("baz",), + }, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + "input": ("db", "first"), + }, + { + "name": "foo", + "type": "b", + "input": ("db", "second"), + "categories": ("baz",), + }, + ], + }, + { + "database": "db", + "code": "second", + "name": "foo", + "categories": ("baz",), + "exchanges": [], + }, + ] + assert link_iterable_by_fields(unlinked, internal=True, relink=True) == expected + + +def test_linking_with_fields(): + unlinked = [ + { + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "type": "a", + }, + { + "name": "foo", + "categories": ("baz",), + "unit": "kilogram", + "type": "b", + }, + ] + } + ] + other = [ + {"name": "foo", "categories": ("bar",), "database": "db", "code": "first"}, + {"name": "foo", "categories": ("baz",), "database": "db", "code": "second"}, + ] + expected = [ + { + "exchanges": [ + { + "name": "foo", + "type": "a", + "categories": ("bar",), + "input": ("db", "first"), + }, + { + "name": "foo", + "type": "b", + "categories": ("baz",), + "input": ("db", "second"), + "unit": "kilogram", + }, + ] + } + ] + assert ( + link_iterable_by_fields(unlinked, other, fields=["name", "categories"]) + == expected + ) + + +def test_no_relink_skips_linking(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "input": ("something", "else"), + } + ], + } + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "categories": ("bar",), + "exchanges": [ + { + "name": "foo", + "categories": ("bar",), + "input": ("db", "first"), + } + ], + } + ] + assert 
link_iterable_by_fields(copy.deepcopy(unlinked), internal=True) == unlinked + + del unlinked[0]["exchanges"][0]["input"] + assert link_iterable_by_fields(unlinked, internal=True) == expected + + +def test_node_filters(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "type": "process", + "exchanges": [ + { + "name": "bar", + } + ], + }, + { + "database": "db", + "code": "second", + "name": "bar", + "type": "product", + "exchanges": [ + { + "name": "foo", + } + ], + }, + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "type": "process", + "exchanges": [{"name": "bar", "input": ("db", "second")}], + }, + { + "database": "db", + "code": "second", + "name": "bar", + "type": "product", + "exchanges": [ + { + "name": "foo", + } + ], + }, + ] + + assert ( + link_iterable_by_fields( + unlinked, + this_node_kinds=["process"], + other_node_kinds=["product"], + internal=True, ) - del unlinked[0]["exchanges"][0]["input"] - self.assertEqual(expected, link_iterable_by_fields(unlinked, internal=True)) + == expected + ) + + +def test_without_node_filters(): + unlinked = [ + { + "database": "db", + "code": "first", + "name": "foo", + "type": "process", + "exchanges": [ + { + "name": "bar", + } + ], + }, + { + "database": "db", + "code": "second", + "name": "bar", + "type": "product", + "exchanges": [ + { + "name": "foo", + } + ], + }, + ] + expected = [ + { + "database": "db", + "code": "first", + "name": "foo", + "type": "process", + "exchanges": [{"name": "bar", "input": ("db", "second")}], + }, + { + "database": "db", + "code": "second", + "name": "bar", + "type": "product", + "exchanges": [{"name": "foo", "input": ("db", "first")}], + }, + ] + + assert link_iterable_by_fields(unlinked, internal=True) == expected