Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jan 25, 2024
1 parent 17e973a commit 5f68606
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 58 deletions.
58 changes: 0 additions & 58 deletions py-polars/tests/unit/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,64 +314,6 @@ def test_array_explode() -> None:
assert_series_equal(out_s, expected_s)


@pytest.mark.parametrize(
("array", "data", "expected", "dtype"),
[
([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64),
([[True, False], [True, True]], [True, False], [True, False], pl.Boolean),
([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String),
([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary),
(
[[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]],
[{"a": 1}, {"a": 2}],
[True, False],
pl.Struct([pl.Field("a", pl.Int64)]),
),
],
)
def test_array_contains_expr(
array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType
) -> None:
df = pl.DataFrame(
{
"array": array,
"data": data,
},
schema={
"array": pl.Array(dtype, 2),
"data": dtype,
},
)
out = df.select(contains=pl.col("array").arr.contains(pl.col("data"))).to_series()
expected_series = pl.Series("contains", expected)
assert_series_equal(out, expected_series)


@pytest.mark.parametrize(
("array", "data", "expected", "dtype"),
[
([[1, 2], [3, 4]], 1, [True, False], pl.Int64),
([[True, False], [True, True]], True, [True, True], pl.Boolean),
([["a", "b"], ["c", "d"]], "a", [True, False], pl.String),
([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary),
],
)
def test_array_contains_literal(
array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType
) -> None:
df = pl.DataFrame(
{
"array": array,
},
schema={
"array": pl.Array(dtype, 2),
},
)
out = df.select(contains=pl.col("array").arr.contains(data)).to_series()
expected_series = pl.Series("contains", expected)
assert_series_equal(out, expected_series)


@pytest.mark.parametrize(
("arr", "data", "expected", "dtype"),
[
Expand Down
72 changes: 72 additions & 0 deletions py-polars/tests/unit/namespaces/array/test_contains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from typing import Any

import pytest

import polars as pl
from polars.testing import assert_series_equal


@pytest.mark.parametrize(
("array", "data", "expected", "dtype"),
[
([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64),
([[True, False], [True, True]], [True, False], [True, False], pl.Boolean),
([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String),
([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary),
(
[[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]],
[{"a": 1}, {"a": 2}],
[True, False],
pl.Struct([pl.Field("a", pl.Int64)]),
),
],
)
def test_array_contains_expr(
array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType
) -> None:
df = pl.DataFrame(
{
"array": array,
"data": data,
},
schema={
"array": pl.Array(dtype, 2),
"data": dtype,
},
)
out = df.select(contains=pl.col("array").arr.contains(pl.col("data"))).to_series()
expected_series = pl.Series("contains", expected)
assert_series_equal(out, expected_series)


@pytest.mark.parametrize(
("array", "data", "expected", "dtype"),
[
([[1, 2], [3, 4]], 1, [True, False], pl.Int64),
([[True, False], [True, True]], True, [True, True], pl.Boolean),
([["a", "b"], ["c", "d"]], "a", [True, False], pl.String),
([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary),
],
)
def test_array_contains_literal(
array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType
) -> None:
df = pl.DataFrame(
{
"array": array,
},
schema={
"array": pl.Array(dtype, 2),
},
)
out = df.select(contains=pl.col("array").arr.contains(data)).to_series()
expected_series = pl.Series("contains", expected)
assert_series_equal(out, expected_series)


def test_array_contains_invalid_datatype() -> None:
df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.List(pl.Int8)})
with pytest.raises(pl.SchemaError, match="invalid series dtype: expected `Array`"):
df.select(pl.col("a").arr.contains(2))
6 changes: 6 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def test_contains() -> None:
assert_series_equal(out, expected)


def test_list_contains_invalid_datatype() -> None:
df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.Array(pl.Int8, width=2)})
with pytest.raises(pl.SchemaError, match="invalid series dtype: expected `List`"):
df.select(pl.col("a").list.contains(2))


def test_list_concat() -> None:
df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]})

Expand Down

0 comments on commit 5f68606

Please sign in to comment.