Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with join with duplicate obs indices #822

Merged
merged 4 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ def _inner_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData]:
regions, region_column_name, instance_key = get_table_keys(table)
groups_df = table.obs.groupby(by=region_column_name, observed=False)
obs = table.obs.reset_index()
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
for element_type, name_element in element_dict.items():
for name, element in name_element.items():
Expand All @@ -401,7 +402,7 @@ def _inner_join_spatialelement_table(
element_dict[element_type][name] = masked_element

joined_indices = _get_joined_table_indices(
joined_indices, element_indices, table_instance_key_column, match_rows
joined_indices, masked_element.index, table_instance_key_column, match_rows
)
else:
warnings.warn(
Expand All @@ -414,6 +415,7 @@ def _inner_join_spatialelement_table(
joined_indices = joined_indices.dropna() if any(joined_indices.isna()) else joined_indices

joined_table = table[joined_indices, :].copy() if joined_indices is not None else None

_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
return element_dict, joined_table

Expand Down Expand Up @@ -455,7 +457,8 @@ def _left_join_spatialelement_table(
if match_rows == "right":
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
regions, region_column_name, instance_key = get_table_keys(table)
groups_df = table.obs.groupby(by=region_column_name, observed=False)
obs = table.obs.reset_index()
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
for element_type, name_element in element_dict.items():
for name, element in name_element.items():
Expand Down
44 changes: 43 additions & 1 deletion tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import pytest
from anndata import AnnData

from spatialdata import get_values, match_table_to_element
from spatialdata import SpatialData, get_values, match_table_to_element
from spatialdata._core.query.relational_query import (
_locate_value,
_ValueOrigin,
get_element_annotators,
join_spatialelement_table,
)
from spatialdata.models.models import TableModel
from spatialdata.testing import assert_anndata_equal, assert_geodataframe_equal


def test_match_table_to_element(sdata_query_aggregation):
Expand Down Expand Up @@ -376,6 +377,47 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation):
assert all(indices == reversed_instance_id)


# TODO: 'left_exclusive' is currently not working, reported in this issue:
@pytest.mark.parametrize("join_type", ["left", "right", "inner", "right_exclusive"])
def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: SpatialData, join_type: str) -> None:
sdata = sdata_query_aggregation
sdata["table"].obs.index = ["a"] * sdata["table"].n_obs
sdata["values_circles"] = sdata_query_aggregation["values_circles"][:4]
sdata["values_polygons"] = sdata_query_aggregation["values_polygons"][:5]

element_dict, table = join_spatialelement_table(
sdata=sdata,
spatial_element_names=["values_circles", "values_polygons"],
table_name="table",
how=join_type,
)

if join_type in ["left", "inner"]:
# table check
assert table.n_obs == 9
assert np.array_equal(table.obs["instance_id"][:4], sdata["values_circles"].index)
assert np.array_equal(table.obs["instance_id"][4:], sdata["values_polygons"].index)
# shapes check
assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"])
assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"])
elif join_type == "right":
# table check
assert_anndata_equal(table.obs, sdata["table"].obs)
# shapes check
assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"])
assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"])
elif join_type == "left_exclusive":
# TODO: currently not working, reported in this issue
pass
else:
assert join_type == "right_exclusive"
# table check
assert table.n_obs == sdata["table"].n_obs - len(sdata["values_circles"]) - len(sdata["values_polygons"])
# shapes check
assert element_dict["values_circles"] is None
assert element_dict["values_polygons"] is None


# TODO: there is a lot of dublicate code, simplify with a function that tests both the case sdata=None and sdata=sdata
def test_match_rows_join(sdata_query_aggregation):
sdata = sdata_query_aggregation
Expand Down
Loading