diff --git a/quivr/tables.py b/quivr/tables.py index 6ec16ce..88cd2d7 100644 --- a/quivr/tables.py +++ b/quivr/tables.py @@ -1216,3 +1216,53 @@ def set_column(self, name: str, data: DataSourceType) -> Self: table = column._set_on_pyarrow_table(self.table, data) return self.from_pyarrow(table=table, validate=True, permit_nulls=False) + + def unique_indices( + self, subset: Optional[List[str]] = None, keep: Literal["first", "last"] = "first" + ) -> pa.Array: + """ + Get the indices of the first or last occurrence of each unique row in the table. A subset of + columns can be specified to consider when determining uniqueness. If no subset is specified then + all columns are used. + + :param subset: Subset of columns to consider when determining uniqueness. + :param keep: If there are duplicate rows then keep the first or last row. + """ + # Flatten the table so nested columns are dot-delimited at the top level + flattened_table = self.flattened_table() + + # If subset is not specified then use all the columns + if subset is None: + subset = [c for c in flattened_table.column_names] + + # Add an index column to the flattened table + flattened_table = flattened_table.add_column(0, "index", pa.array(np.arange(len(flattened_table)))) + + if keep not in ["first", "last"]: + raise ValueError(f"keep must be 'first' or 'last', got {keep}.") + + agg_func = keep + indices = ( + flattened_table.group_by(subset, use_threads=False) + .aggregate([("index", agg_func)]) + .column(f"index_{agg_func}") + ).combine_chunks() + return indices + + def drop_duplicates( + self, + subset: Optional[List[str]] = None, + keep: Literal["first", "last"] = "first", + ) -> Self: + """ + Drop duplicate rows from a `~quivr.Table`. This function is similar to + `~pandas.DataFrame.drop_duplicates` but it supports nested columns (representing + nested tables). + + :param subset: Subset of columns to consider when dropping duplicates. If not specified then + all columns are used. + :param keep: If there are duplicate rows then keep the first or last row. + + """ + indices = self.unique_indices(subset=subset, keep=keep) + return self.take(indices) diff --git a/test/test_tables.py b/test/test_tables.py index 5d66960..60725fb 100644 --- a/test/test_tables.py +++ b/test/test_tables.py @@ -1513,6 +1513,126 @@ class Outer(qv.Table): assert o2.inner.name == "b" +def test_get_unique_indices_first(): + t = Pair.from_kwargs(x=[1, 2, 3, 1, 2, 3], y=[4, 5, 6, 4, 5, 6]) + + indices = t.unique_indices(keep="first") + assert indices.equals(pa.array([0, 1, 2], pa.int64())) + + indices = t.unique_indices(keep="last") + assert indices.equals(pa.array([3, 4, 5], pa.int64())) + + +def test_get_unique_indices_subset(): + t = Pair.from_kwargs(x=[1, 2, 3, 1, 2, 3], y=[1, 1, 2, 2, 3, 3]) + + indices = t.unique_indices(subset=["y"], keep="first") + assert indices.equals(pa.array([0, 2, 4], pa.int64())) + + indices = t.unique_indices(subset=["y"], keep="last") + assert indices.equals(pa.array([1, 3, 5], pa.int64())) + + indices = t.unique_indices(subset=["x"], keep="first") + assert indices.equals(pa.array([0, 1, 2], pa.int64())) + + indices = t.unique_indices(subset=["x"], keep="last") + assert indices.equals(pa.array([3, 4, 5], pa.int64())) + + indices = t.unique_indices(subset=["x", "y"], keep="first") + assert indices.equals(pa.array([0, 1, 2, 3, 4, 5], pa.int64())) + + indices = t.unique_indices(subset=["x", "y"], keep="last") + assert indices.equals(pa.array([0, 1, 2, 3, 4, 5], pa.int64())) + + +def test_get_unique_indices_nested(): + t = Pair.from_kwargs(x=[1, 2, 3, 1, 2, 3], y=[4, 5, 6, 4, 5, 6]) + w = Wrapper.from_kwargs(id=["a", "b", "c", "a", "b", "c"], pair=t) + + indices = w.unique_indices(subset=["pair.x"], keep="first") + assert indices.equals(pa.array([0, 1, 2], pa.int64())) + + indices = w.unique_indices(subset=["pair.x"], keep="last") + assert indices.equals(pa.array([3, 4, 5], pa.int64())) + + indices = w.unique_indices(subset=["pair.y"], keep="first") + assert indices.equals(pa.array([0, 1, 2], pa.int64())) + + indices = w.unique_indices(subset=["pair.y"], keep="last") + assert indices.equals(pa.array([3, 4, 5], pa.int64())) + + indices = w.unique_indices(subset=["pair.x", "pair.y"], keep="first") + assert indices.equals(pa.array([0, 1, 2], pa.int64())) + + indices = w.unique_indices(subset=["id"], keep="last") + assert indices.equals(pa.array([3, 4, 5], pa.int64())) + + +def test_get_unique_indices_raises(): + with pytest.raises(ValueError, match="keep must be 'first' or 'last', got invalid"): + t = Pair.from_kwargs(x=[1, 2, 3, 1, 2, 3], y=[4, 5, 6, 4, 5, 6]) + t.unique_indices(keep="invalid") + + +def test_drop_duplicates(): + + t = Pair.from_kwargs(x=[1, 2, 3, 1, 2, 3], y=[4, 5, 6, 4, 5, 6]) + + t2 = t.drop_duplicates(subset=["x"]) + assert t2.x.equals(pa.array([1, 2, 3], pa.int64())) + assert t2.y.equals(pa.array([4, 5, 6], pa.int64())) + + t2 = t.drop_duplicates(subset=["y"]) + assert t2.x.equals(pa.array([1, 2, 3], pa.int64())) + assert t2.y.equals(pa.array([4, 5, 6], pa.int64())) + + t2 = t.drop_duplicates(subset=["x", "y"]) + assert t2.x.equals(pa.array([1, 2, 3], pa.int64())) + assert t2.y.equals(pa.array([4, 5, 6], pa.int64())) + + t2 = t.drop_duplicates(subset=["x"], keep="last") + assert t2.x.equals(pa.array([1, 2, 3], pa.int64())) + assert t2.y.equals(pa.array([4, 5, 6], pa.int64())) + + t2 = t.drop_duplicates(subset=["y"], keep="last") + assert t2.x.equals(pa.array([1, 2, 3], pa.int64())) + assert t2.y.equals(pa.array([4, 5, 6], pa.int64())) + + t2 = t.drop_duplicates(subset=["x", "y"], keep="last") + assert t2.x.equals(pa.array([1, 2, 3], pa.int64())) + assert t2.y.equals(pa.array([4, 5, 6], pa.int64())) + + +def test_drop_duplicates_nested(): + + t = Pair.from_kwargs(x=[1, 2, 3, 1, 2, 3], y=[4, 5, 6, 4, 5, 6]) + w = Wrapper.from_kwargs(id=["a", "a", "b", "b", "c", "c"], pair=t) + + w2 = w.drop_duplicates(subset=["pair.x"]) + assert w2.id.equals(pa.array(["a", "a", "b"], pa.string())) + assert w2.pair.x.equals(pa.array([1, 2, 3], pa.int64())) + assert w2.pair.y.equals(pa.array([4, 5, 6], pa.int64())) + + w2 = w.drop_duplicates(subset=["pair.y"]) + assert w2.id.equals(pa.array(["a", "a", "b"], pa.string())) + assert w2.pair.x.equals(pa.array([1, 2, 3], pa.int64())) + + w2 = w.drop_duplicates(subset=["pair.x", "pair.y"]) + assert w2.id.equals(pa.array(["a", "a", "b"], pa.string())) + assert w2.pair.x.equals(pa.array([1, 2, 3], pa.int64())) + assert w2.pair.y.equals(pa.array([4, 5, 6], pa.int64())) + + w2 = w.drop_duplicates(subset=["id"], keep="last") + assert w2.id.equals(pa.array(["a", "b", "c"], pa.string())) + assert w2.pair.x.equals(pa.array([2, 1, 3], pa.int64())) + assert w2.pair.y.equals(pa.array([5, 4, 6], pa.int64())) + + w2 = w.drop_duplicates(subset=["id"], keep="first") + assert w2.id.equals(pa.array(["a", "b", "c"], pa.string())) + assert w2.pair.x.equals(pa.array([1, 3, 2], pa.int64())) + assert w2.pair.y.equals(pa.array([4, 6, 5], pa.int64())) + + @pytest.mark.benchmark(group="column-access") class TestColumnAccessBenchmark: def test_access_f64(self, benchmark):