Skip to content

Commit

Permalink
Making a variety of adjustments in wrappers and unit tests to account…
Browse files Browse the repository at this point in the history
… for the switch from string to string_view as default
  • Loading branch information
timsaucer committed Nov 9, 2024
1 parent 73cfddf commit fcb5f96
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 27 deletions.
2 changes: 1 addition & 1 deletion python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr:
_to_pyarrow_types = {
float: pa.float64(),
int: pa.int64(),
str: pa.string_view(),
str: pa.string(),
bool: pa.bool_(),
}

Expand Down
13 changes: 9 additions & 4 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def decode(input: Expr, encoding: Expr) -> Expr:

def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
"""Converts each element to its text representation."""
return Expr(f.array_to_string(expr.expr, delimiter.expr))
return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string())))


def array_join(expr: Expr, delimiter: Expr) -> Expr:
Expand Down Expand Up @@ -924,7 +924,7 @@ def to_timestamp(arg: Expr, *formatters: Expr) -> Expr:
return f.to_timestamp(arg.expr)

formatters = [f.expr for f in formatters]
return Expr(f.to_timestamp(arg.expr, *formatters))
return Expr(f.to_timestamp(arg.expr.cast(pa.string()), *formatters))


def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr:
Expand Down Expand Up @@ -1065,7 +1065,10 @@ def struct(*args: Expr) -> Expr:

def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
"""Returns a struct with the given names and arguments pairs."""
name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
name_pair_exprs = [
[Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]]
for pair in name_pairs
]

# flatten
name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
Expand Down Expand Up @@ -1422,7 +1425,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
return Expr(
f.array_sort(
array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
array.expr,
Expr.literal(pa.scalar(desc, type=pa.string())).expr,
Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr,
)
)

Expand Down
8 changes: 6 additions & 2 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ def test_relational_expr(test_ctx):
ctx = SessionContext()

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
[
pa.array([1, 2, 3]),
pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], name="batch_array")
Expand All @@ -145,7 +148,8 @@ def test_relational_expr(test_ctx):
assert df.filter(col("b") == "beta").count() == 1
assert df.filter(col("b") != "beta").count() == 2

assert df.filter(col("a") == "beta").count() == 0
with pytest.raises(Exception):
df.filter(col("a") == "beta").count()


def test_expr_to_variant():
Expand Down
67 changes: 47 additions & 20 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def df():
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array(["Hello", "World", "!"]),
pa.array(["Hello", "World", "!"], type=pa.string_view()),
pa.array([4, 5, 6]),
pa.array(["hello ", " world ", " !"]),
pa.array(["hello ", " world ", " !"], type=pa.string_view()),
pa.array(
[
datetime(2022, 12, 31),
Expand Down Expand Up @@ -88,16 +88,18 @@ def test_literal(df):
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([1] * 3)
assert result.column(1) == pa.array(["1"] * 3)
assert result.column(2) == pa.array(["OK"] * 3)
assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
assert result.column(3) == pa.array([3.14] * 3)
assert result.column(4) == pa.array([True] * 3)
assert result.column(5) == pa.array([b"hello world"] * 3)


def test_lit_arith(df):
"""Test literals with arithmetic operations"""
df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
df = df.select(
literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
)
result = df.collect()
assert len(result) == 1
result = result[0]
Expand Down Expand Up @@ -578,21 +580,33 @@ def test_array_function_obj_tests(stmt, py_expr):
f.ascii(column("a")),
pa.array([72, 87, 33], type=pa.int32()),
), # H = 72; W = 87; ! = 33
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
(
f.bit_length(column("a").cast(pa.string())),
pa.array([40, 40, 8], type=pa.int32()),
),
(
f.btrim(literal(" World ")),
pa.array(["World", "World", "World"], type=pa.string_view()),
),
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
(
f.concat_ws("-", column("a"), literal("test")),
pa.array(["Hello-test", "World-test", "!-test"]),
),
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
(
f.concat(column("a").cast(pa.string()), literal("?")),
pa.array(["Hello?", "World?", "!?"]),
),
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
(
f.ltrim(column("c")),
pa.array(["hello ", "world ", "!"], type=pa.string_view()),
),
(
f.md5(column("a")),
pa.array(
Expand All @@ -618,19 +632,25 @@ def test_array_function_obj_tests(stmt, py_expr):
f.rpad(column("a"), literal(8)),
pa.array(["Hello ", "World ", "! "]),
),
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
(
f.rtrim(column("c")),
pa.array(["hello", " world", " !"], type=pa.string_view()),
),
(
f.split_part(column("a"), literal("l"), literal(1)),
pa.array(["He", "Wor", "!"]),
),
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
(
f.substr(column("a"), literal(3)),
pa.array(["llo", "rld", ""], type=pa.string_view()),
),
(
f.translate(column("a"), literal("or"), literal("ld")),
pa.array(["Helll", "Wldld", "!"]),
),
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
(f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
(
Expand Down Expand Up @@ -772,9 +792,9 @@ def test_temporal_functions(df):
f.date_trunc(literal("month"), column("d")),
f.datetrunc(literal("day"), column("d")),
f.date_bin(
literal("15 minutes"),
literal("15 minutes").cast(pa.string()),
column("d"),
literal("2001-01-01 00:02:30"),
literal("2001-01-01 00:02:30").cast(pa.string()),
),
f.from_unixtime(literal(1673383974)),
f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
Expand Down Expand Up @@ -836,8 +856,8 @@ def test_case(df):
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([10, 8, 8])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
assert result.column(2) == pa.array(["Hola", "Mundo", None])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())


def test_when_with_no_base(df):
Expand All @@ -855,8 +875,10 @@ def test_when_with_no_base(df):
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([4, 5, 6])
assert result.column(1) == pa.array(["too small", "just right", "too big"])
assert result.column(2) == pa.array(["Hello", None, None])
assert result.column(1) == pa.array(
["too small", "just right", "too big"], type=pa.string_view()
)
assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())


def test_regr_funcs_sql(df):
Expand Down Expand Up @@ -999,8 +1021,13 @@ def test_regr_funcs_df(func, expected):

def test_binary_string_functions(df):
df = df.select(
f.encode(column("a"), literal("base64")),
f.decode(f.encode(column("a"), literal("base64")), literal("base64")),
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
f.decode(
f.encode(
column("a").cast(pa.string()), literal("base64").cast(pa.string())
),
literal("base64").cast(pa.string()),
),
)
result = df.collect()
assert len(result) == 1
Expand Down

0 comments on commit fcb5f96

Please sign in to comment.