From 5f68606f14c6b56b9b37df5f624a5be06071b944 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 25 Jan 2024 10:14:48 +0100 Subject: [PATCH] Add test --- .../tests/unit/namespaces/array/test_array.py | 58 --------------- .../unit/namespaces/array/test_contains.py | 72 +++++++++++++++++++ py-polars/tests/unit/namespaces/test_list.py | 6 ++ 3 files changed, 78 insertions(+), 58 deletions(-) create mode 100644 py-polars/tests/unit/namespaces/array/test_contains.py diff --git a/py-polars/tests/unit/namespaces/array/test_array.py b/py-polars/tests/unit/namespaces/array/test_array.py index add031e16f47..7e04b5e04546 100644 --- a/py-polars/tests/unit/namespaces/array/test_array.py +++ b/py-polars/tests/unit/namespaces/array/test_array.py @@ -314,64 +314,6 @@ def test_array_explode() -> None: assert_series_equal(out_s, expected_s) -@pytest.mark.parametrize( - ("array", "data", "expected", "dtype"), - [ - ([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64), - ([[True, False], [True, True]], [True, False], [True, False], pl.Boolean), - ([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String), - ([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary), - ( - [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]], - [{"a": 1}, {"a": 2}], - [True, False], - pl.Struct([pl.Field("a", pl.Int64)]), - ), - ], -) -def test_array_contains_expr( - array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType -) -> None: - df = pl.DataFrame( - { - "array": array, - "data": data, - }, - schema={ - "array": pl.Array(dtype, 2), - "data": dtype, - }, - ) - out = df.select(contains=pl.col("array").arr.contains(pl.col("data"))).to_series() - expected_series = pl.Series("contains", expected) - assert_series_equal(out, expected_series) - - -@pytest.mark.parametrize( - ("array", "data", "expected", "dtype"), - [ - ([[1, 2], [3, 4]], 1, [True, False], pl.Int64), - ([[True, False], [True, True]], True, [True, True], pl.Boolean), - ([["a", "b"], ["c", "d"]], "a", [True, False], pl.String), - ([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary), - ], -) -def test_array_contains_literal( - array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType -) -> None: - df = pl.DataFrame( - { - "array": array, - }, - schema={ - "array": pl.Array(dtype, 2), - }, - ) - out = df.select(contains=pl.col("array").arr.contains(data)).to_series() - expected_series = pl.Series("contains", expected) - assert_series_equal(out, expected_series) - - @pytest.mark.parametrize( ("arr", "data", "expected", "dtype"), [ diff --git a/py-polars/tests/unit/namespaces/array/test_contains.py b/py-polars/tests/unit/namespaces/array/test_contains.py new file mode 100644 index 000000000000..daba5177828e --- /dev/null +++ b/py-polars/tests/unit/namespaces/array/test_contains.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +@pytest.mark.parametrize( + ("array", "data", "expected", "dtype"), + [ + ([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64), + ([[True, False], [True, True]], [True, False], [True, False], pl.Boolean), + ([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String), + ([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary), + ( + [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]], + [{"a": 1}, {"a": 2}], + [True, False], + pl.Struct([pl.Field("a", pl.Int64)]), + ), + ], +) +def test_array_contains_expr( + array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType +) -> None: + df = pl.DataFrame( + { + "array": array, + "data": data, + }, + schema={ + "array": pl.Array(dtype, 2), + "data": dtype, + }, + ) + out = df.select(contains=pl.col("array").arr.contains(pl.col("data"))).to_series() + expected_series = pl.Series("contains", expected) + assert_series_equal(out, expected_series) + + +@pytest.mark.parametrize( + ("array", "data", "expected", "dtype"), + [ + ([[1, 2], [3, 4]], 1, [True, False], pl.Int64), + ([[True, False], [True, True]], True, [True, True], pl.Boolean), + ([["a", "b"], ["c", "d"]], "a", [True, False], pl.String), + ([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary), + ], +) +def test_array_contains_literal( + array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType +) -> None: + df = pl.DataFrame( + { + "array": array, + }, + schema={ + "array": pl.Array(dtype, 2), + }, + ) + out = df.select(contains=pl.col("array").arr.contains(data)).to_series() + expected_series = pl.Series("contains", expected) + assert_series_equal(out, expected_series) + + +def test_array_contains_invalid_datatype() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.List(pl.Int8)}) + with pytest.raises(pl.SchemaError, match="invalid series dtype: expected `Array`"): + df.select(pl.col("a").arr.contains(2)) diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index e6893f7a0a03..1e2a661d15d9 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -79,6 +79,12 @@ def test_contains() -> None: assert_series_equal(out, expected) +def test_list_contains_invalid_datatype() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.Array(pl.Int8, width=2)}) + with pytest.raises(pl.SchemaError, match="invalid series dtype: expected `List`"): + df.select(pl.col("a").list.contains(2)) + + def test_list_concat() -> None: df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]})