Skip to content

Commit

Permalink
Add more filter options to link_iterable_by_fields
Browse files Browse the repository at this point in the history
  • Loading branch information
cmutel committed Nov 27, 2024
1 parent 1ef1539 commit 14619ec
Show file tree
Hide file tree
Showing 10 changed files with 544 additions and 396 deletions.
35 changes: 27 additions & 8 deletions bw2io/importers/base_lci.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,16 @@ def write_excel(

def match_database(
self,
db_name=None,
fields=None,
ignore_categories=False,
relink=False,
kind=None,
):
db_name: Optional[str] = None,
fields: Optional[List[str]] = None,
ignore_categories: bool = False,
relink: bool = False,
kind: Optional[Union[List[str], str]] = None,
edge_kinds: Optional[List[str]] = None,
this_node_kinds: Optional[List[str]] = None,
other_node_kinds: Optional[List[str]] = None,
processes_to_products: bool = False,
) -> None:
"""Match current database against itself or another database.
If ``db_name`` is None, match against current data. Otherwise, ``db_name`` should be the name of an existing ``Database``.
Expand All @@ -410,9 +414,24 @@ def match_database(
Nothing is returned, but ``self.data`` is changed.
"""
if kind is not None:
warnings.warn(
"`kind` is deprecated, please use `edge_kinds` instead",
DeprecationWarning,
)
edge_kinds = list(kind)

if processes_to_products:
this_node_kinds = labels.process_node_types + [
labels.multifunctional_node_default
]
other_node_kinds = labels.product_node_types

kwargs = {
"fields": fields,
"kind": kind,
"edge_kinds": edge_kinds,
"this_node_kinds": this_node_kinds,
"other_node_kinds": other_node_kinds,
"relink": relink,
}
if fields and ignore_categories:
Expand Down Expand Up @@ -648,7 +667,7 @@ def reformat(exc):
for obj in Database(biosphere_name)
if obj.get("type") == "emission"
),
kind="biosphere",
edge_kinds=["biosphere"],
),
)

Expand Down
2 changes: 1 addition & 1 deletion bw2io/importers/base_lcia.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, filepath, biosphere=None):
for obj in Database(self.biosphere_name)
if obj.get("type") == "emission"
),
kind="biosphere",
edge_kinds=["biosphere"],
),
functools.partial(
match_subcategories, biosphere_db_name=self.biosphere_name
Expand Down
2 changes: 1 addition & 1 deletion bw2io/importers/ecospold1.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
functools.partial(
link_iterable_by_fields,
other=Database(config.biosphere),
kind="biosphere",
edge_kinds=["biosphere"],
),
functools.partial(
link_technosphere_by_activity_hash,
Expand Down
2 changes: 1 addition & 1 deletion bw2io/importers/excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, filepath):
functools.partial(
link_iterable_by_fields,
other=Database(config.biosphere),
kind="biosphere",
edge_kinds=["biosphere"],
),
assign_only_product_as_production,
link_technosphere_by_activity_hash,
Expand Down
2 changes: 1 addition & 1 deletion bw2io/importers/excel_lcia.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, filepath, name, description, unit, **metadata):
functools.partial(
link_iterable_by_fields,
other=Database(config.biosphere),
kind="biosphere",
edge_kinds=["biosphere"],
fields=("name", "categories"),
),
drop_falsey_uncertainty_fields_but_keep_zeros,
Expand Down
4 changes: 2 additions & 2 deletions bw2io/importers/json_ld.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def __init__(self, dirpath, database_name, preferred_allocation=None):
partial(
link_iterable_by_fields,
fields=["code"],
kind={"production", "technosphere"},
edge_kinds=["production", "technosphere"],
internal=True,
),
partial(
link_iterable_by_fields,
other=self.biosphere_database,
fields=["code"],
kind={"biosphere"},
edge_kinds=["biosphere"],
),
normalize_units,
]
Expand Down
4 changes: 2 additions & 2 deletions bw2io/importers/simapro_block_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
functools.partial(
link_iterable_by_fields,
other=Database(biosphere_database_name or config.biosphere),
kind=labels.biosphere_edge_types,
edge_kinds=labels.biosphere_edge_types,
),
match_internal_simapro_simapro_with_unit_conversion,
]
Expand Down Expand Up @@ -226,7 +226,7 @@ def use_ecoinvent_strategies(self) -> None:
other=Database(
self.default_biosphere_database_name or config.biosphere
),
kind="biosphere",
edge_kinds=["biosphere"],
),
]

Expand Down
2 changes: 1 addition & 1 deletion bw2io/importers/simapro_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
functools.partial(
link_iterable_by_fields,
other=Database(biosphere_db or config.biosphere),
kind="biosphere",
edge_kinds=["biosphere"],
),
convert_activity_parameters_to_list,
]
Expand Down
46 changes: 36 additions & 10 deletions bw2io/strategies/generic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numbers
import pprint
import warnings
from collections import defaultdict
from copy import deepcopy
from typing import Iterable, List, Optional, Union
Expand Down Expand Up @@ -75,6 +76,9 @@ def link_iterable_by_fields(
other: Optional[Iterable[dict]] = None,
fields: Optional[List[str]] = None,
kind: Optional[Union[str, List[str]]] = None,
edge_kinds: Optional[List[str]] = None,
this_node_kinds: Optional[List[str]] = None,
other_node_kinds: Optional[List[str]] = None,
internal: bool = False,
relink: bool = False,
) -> List[dict]:
Expand All @@ -91,9 +95,15 @@ def link_iterable_by_fields(
fields : iterable[str], optional
An iterable of strings indicating which fields should be used to match objects. If not
specified, all fields will be used.
kind : str|list[string], optional
kind : deprecated
Use `edge_kinds` instead
edge_kinds : str|list[string], optional
If specified, limit linking to exchanges whose `type` is in `edge_kinds`. `edge_kinds` can be a
string or an iterable of strings.
this_node_kinds : str|list[string], optional
If specified, limit linking to objects in `unlinked` which have `type` in `this_node_kinds`.
other_node_kinds : str|list[string], optional
If specified, limit linking to objects in `other` which have `type` in `other_node_kinds`.
internal : bool, optional
If `True`, link objects in `unlinked` to other objects in `unlinked`. Each object must have
the attributes `database` and `code`.
Expand Down Expand Up @@ -153,25 +163,41 @@ def link_iterable_by_fields(
>>> linked[1]["exchanges"][0]["input"]
('db2', 'C')
"""
if kind:
kind = {kind} if isinstance(kind, str) else kind
if kind is not None:
warnings.warn(
"`kind` is deprecated, please use `edge_kinds` instead", DeprecationWarning
)
edge_kinds = kind

if edge_kinds:
edge_kinds = {edge_kinds} if isinstance(edge_kinds, str) else edge_kinds
if relink:
filter_func = lambda x: x.get("type") in kind
edge_filter_func = lambda x: x.get("type") in edge_kinds
else:
filter_func = lambda x: x.get("type") in kind and not x.get("input")
edge_filter_func = lambda x: x.get("type") in edge_kinds and not x.get(
"input"
)
else:
if relink:
filter_func = lambda x: True
edge_filter_func = lambda x: True
else:
filter_func = lambda x: not x.get("input")
edge_filter_func = lambda x: not x.get("input")
if this_node_kinds:
this_filter_func = lambda x: x.get("type") in this_node_kinds
else:
this_filter_func = lambda x: True
if other_node_kinds:
other_filter_func = lambda x: x.get("type") in other_node_kinds
else:
other_filter_func = lambda x: True

if internal:
other = unlinked

duplicates, candidates = {}, {}
try:
# Other can be a generator, so a bit convoluted
for ds in other:
for ds in filter(other_filter_func, other):
key = activity_hash(ds, fields)
if key in candidates:
duplicates.setdefault(key, []).append(ds)
Expand All @@ -183,8 +209,8 @@ def link_iterable_by_fields(
"``database`` or ``code`` attributes"
)

for container in unlinked:
for obj in filter(filter_func, container.get("exchanges", [])):
for container in filter(this_filter_func, unlinked):
for obj in filter(edge_filter_func, container.get("exchanges", [])):
key = activity_hash(obj, fields)
if key in duplicates:
raise StrategyError(
Expand Down
Loading

0 comments on commit 14619ec

Please sign in to comment.