Skip to content

Commit

Permalink
feat: Extend recognised EXTRACT and DATE_PART SQL part abbreviati…
Browse files Browse the repository at this point in the history
…ons (pola-rs#16767)
  • Loading branch information
alexander-beedie authored and Wouittone committed Jun 22, 2024
1 parent 79e1be1 commit 7307448
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 102 deletions.
15 changes: 11 additions & 4 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use polars_plan::prelude::col;
use polars_plan::prelude::LiteralValue::Null;
use polars_plan::prelude::{lit, StrptimeOptions};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Value as SQLValue,
WindowSpec, WindowType,
DateTimeField, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Ident,
Value as SQLValue, WindowSpec, WindowType,
};

use crate::sql_expr::{parse_date_part, parse_sql_expr};
use crate::sql_expr::{parse_extract_date_part, parse_sql_expr};
use crate::SQLContext;

pub(crate) struct SQLFunctionVisitor<'a> {
Expand Down Expand Up @@ -889,7 +889,14 @@ impl SQLFunctionVisitor<'_> {
},
DatePart => self.try_visit_binary(|part, e| {
match part {
Expr::Literal(LiteralValue::String(p)) => parse_date_part(e, &p),
Expr::Literal(LiteralValue::String(p)) => {
// note: 'DATE_PART' and 'EXTRACT' are minor syntactic
// variations on otherwise identical functionality
parse_extract_date_part(e, &DateTimeField::Custom(Ident {
value: p,
quote_style: None,
}))
},
_ => {
polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART: {}", function.args[1]);
}
Expand Down
74 changes: 38 additions & 36 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ impl SQLExprVisitor<'_> {
} => self.visit_cast(expr, data_type, format, true),
SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
SQLExpr::Extract { field, expr } => parse_extract(self.visit_expr(expr)?, field),
SQLExpr::Extract { field, expr } => {
parse_extract_date_part(self.visit_expr(expr)?, field)
},
SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
SQLExpr::Function(function) => self.visit_function(function),
SQLExpr::Identifier(ident) => self.visit_identifier(ident),
Expand Down Expand Up @@ -1171,7 +1173,41 @@ pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsRes
}
}

fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
let field = match field {
// handle 'DATE_PART' and all valid abbreviations/alternates
DateTimeField::Custom(Ident { value, .. }) => {
let value = value.to_ascii_lowercase();
match value.as_str() {
"millennium" | "millennia" => &DateTimeField::Millennium,
"century" | "centuries" => &DateTimeField::Century,
"decade" | "decades" => &DateTimeField::Decade,
"isoyear" => &DateTimeField::Isoyear,
"year" | "years" | "y" => &DateTimeField::Year,
"quarter" | "quarters" => &DateTimeField::Quarter,
"month" | "months" | "mon" | "mons" => &DateTimeField::Month,
"dayofyear" | "doy" => &DateTimeField::DayOfYear,
"dayofweek" | "dow" => &DateTimeField::DayOfWeek,
"isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
"isodow" => &DateTimeField::Isodow,
"day" | "days" | "d" => &DateTimeField::Day,
"hour" | "hours" | "h" => &DateTimeField::Hour,
"minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
"second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
"millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
"microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
"nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
#[cfg(feature = "timezones")]
"timezone" => &DateTimeField::Timezone,
"time" => &DateTimeField::Time,
"epoch" => &DateTimeField::Epoch,
_ => {
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
},
}
},
_ => field,
};
Ok(match field {
DateTimeField::Millennium => expr.dt().millennium(),
DateTimeField::Century => expr.dt().century(),
Expand Down Expand Up @@ -1226,40 +1262,6 @@ fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
})
}

pub(crate) fn parse_date_part(expr: Expr, part: &str) -> PolarsResult<Expr> {
let part = part.to_ascii_lowercase();
parse_extract(
expr,
match part.as_str() {
"millennium" => &DateTimeField::Millennium,
"century" => &DateTimeField::Century,
"decade" => &DateTimeField::Decade,
"isoyear" => &DateTimeField::Isoyear,
"year" => &DateTimeField::Year,
"quarter" => &DateTimeField::Quarter,
"month" => &DateTimeField::Month,
"dayofyear" | "doy" => &DateTimeField::DayOfYear,
"dayofweek" | "dow" => &DateTimeField::DayOfWeek,
"isoweek" | "week" => &DateTimeField::IsoWeek,
"isodow" => &DateTimeField::Isodow,
"day" => &DateTimeField::Day,
"hour" => &DateTimeField::Hour,
"minute" => &DateTimeField::Minute,
"second" => &DateTimeField::Second,
"millisecond" | "milliseconds" => &DateTimeField::Millisecond,
"microsecond" | "microseconds" => &DateTimeField::Microsecond,
"nanosecond" | "nanoseconds" => &DateTimeField::Nanosecond,
#[cfg(feature = "timezones")]
"timezone" => &DateTimeField::Timezone,
"time" => &DateTimeField::Time,
"epoch" => &DateTimeField::Epoch,
_ => {
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", part)
},
},
)
}

fn bitstring_to_bytes_literal(b: &String) -> PolarsResult<Expr> {
let n_bits = b.len();
if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 {
Expand Down
10 changes: 3 additions & 7 deletions crates/polars-time/src/windows/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,9 @@ impl Duration {
},
_ if as_interval => match &*unit {
// interval-only (verbose/sql) matches
"nanosec" | "nanosecs" | "nanosecond" | "nanoseconds" => nsecs += n,
"microsec" | "microsecs" | "microsecond" | "microseconds" => {
nsecs += n * NS_MICROSECOND
},
"millisec" | "millisecs" | "millisecond" | "milliseconds" => {
nsecs += n * NS_MILLISECOND
},
"nanosecond" | "nanoseconds" => nsecs += n,
"microsecond" | "microseconds" => nsecs += n * NS_MICROSECOND,
"millisecond" | "milliseconds" => nsecs += n * NS_MILLISECOND,
"sec" | "secs" | "second" | "seconds" => nsecs += n * NS_SECOND,
"min" | "mins" | "minute" | "minutes" => nsecs += n * NS_MINUTE,
"hour" | "hours" => nsecs += n * NS_HOUR,
Expand Down
66 changes: 36 additions & 30 deletions py-polars/docs/source/reference/sql/functions/temporal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,27 @@ DATE_PART
Extracts a part of a date (or datetime) such as 'year', 'month', etc.

**Supported parts/fields:**
- "day"
- "dayofweek" | "dow"
- "millennium" | "millennia"
- "century" | "centuries"
- "decade" | "decades"
- "isoyear"
- "year" | "years" | "y"
- "quarter" | "quarters"
- "month" | "months" | "mon" | "mons"
- "dayofyear" | "doy"
- "decade"
- "epoch"
- "hour"
- "isodow"
- "dayofweek" | "dow"
- "isoweek" | "week"
- "isoyear"
- "microsecond(s)"
- "millisecond(s)"
- "nanosecond(s)"
- "minute"
- "month"
- "quarter"
- "second"
- "isodow"
- "day" | "days" | "d"
- "hour" | "hours" | "h"
- "minute" | "minutes" | "mins" | "min" | "m"
- "second" | "seconds" | "sec" | "secs" | "s"
- "millisecond" | "milliseconds" | "ms"
- "microsecond" | "microseconds" | "us"
- "nanosecond" | "nanoseconds" | "ns"
- "timezone"
- "time"
- "year"
- "epoch"

**Example:**

Expand Down Expand Up @@ -106,24 +109,27 @@ EXTRACT
Extracts a part of a date (or datetime) such as 'year', 'month', etc.

**Supported parts/fields:**
- "day"
- "dayofweek" | "dow"
- "millennium" | "millennia"
- "century" | "centuries"
- "decade" | "decades"
- "isoyear"
- "year" | "years" | "y"
- "quarter" | "quarters"
- "month" | "months" | "mon" | "mons"
- "dayofyear" | "doy"
- "decade"
- "epoch"
- "hour"
- "isodow"
- "dayofweek" | "dow"
- "isoweek" | "week"
- "isoyear"
- "microsecond(s)"
- "millisecond(s)"
- "nanosecond(s)"
- "minute"
- "month"
- "quarter"
- "second"
- "isodow"
- "day" | "days" | "d"
- "hour" | "hours" | "h"
- "minute" | "minutes" | "mins" | "min" | "m"
- "second" | "seconds" | "sec" | "secs" | "s"
- "millisecond" | "milliseconds" | "ms"
- "microsecond" | "microseconds" | "us"
- "nanosecond" | "nanoseconds" | "ns"
- "timezone"
- "time"
- "year"
- "epoch"


.. code-block:: python
Expand Down
65 changes: 40 additions & 25 deletions py-polars/tests/unit/sql/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,37 +56,49 @@ def test_datetime_to_time(time_unit: Literal["ns", "us", "ms"]) -> None:


@pytest.mark.parametrize(
("part", "dtype", "expected"),
("parts", "dtype", "expected"),
[
("decade", pl.Int32, [202, 202, 200]),
("isoyear", pl.Int32, [2024, 2020, 2005]),
("year", pl.Int32, [2024, 2020, 2006]),
("quarter", pl.Int8, [1, 4, 1]),
("month", pl.Int8, [1, 12, 1]),
("week", pl.Int8, [1, 53, 52]),
("doy", pl.Int16, [7, 365, 1]),
("isodow", pl.Int8, [7, 3, 7]),
("dow", pl.Int8, [0, 3, 0]),
("day", pl.Int8, [7, 30, 1]),
("hour", pl.Int8, [1, 10, 23]),
("minute", pl.Int8, [2, 30, 59]),
("second", pl.Int8, [3, 45, 59]),
("millisecond", pl.Float64, [3123.456, 45987.654, 59555.555]),
("microsecond", pl.Float64, [3123456.0, 45987654.0, 59555555.0]),
("nanosecond", pl.Float64, [3123456000.0, 45987654000.0, 59555555000.0]),
(["decade", "decades"], pl.Int32, [202, 202, 200]),
(["isoyear"], pl.Int32, [2024, 2020, 2005]),
(["year", "y"], pl.Int32, [2024, 2020, 2006]),
(["quarter"], pl.Int8, [1, 4, 1]),
(["month", "months", "mon", "mons"], pl.Int8, [1, 12, 1]),
(["week", "weeks"], pl.Int8, [1, 53, 52]),
(["doy"], pl.Int16, [7, 365, 1]),
(["isodow"], pl.Int8, [7, 3, 7]),
(["dow"], pl.Int8, [0, 3, 0]),
(["day", "days", "d"], pl.Int8, [7, 30, 1]),
(["hour", "hours", "h"], pl.Int8, [1, 10, 23]),
(["minute", "min", "mins", "m"], pl.Int8, [2, 30, 59]),
(["second", "seconds", "secs", "sec"], pl.Int8, [3, 45, 59]),
(
"time",
["millisecond", "milliseconds", "ms"],
pl.Float64,
[3123.456, 45987.654, 59555.555],
),
(
["microsecond", "microseconds", "us"],
pl.Float64,
[3123456.0, 45987654.0, 59555555.0],
),
(
["nanosecond", "nanoseconds", "ns"],
pl.Float64,
[3123456000.0, 45987654000.0, 59555555000.0],
),
(
["time"],
pl.Time,
[time(1, 2, 3, 123456), time(10, 30, 45, 987654), time(23, 59, 59, 555555)],
),
(
"epoch",
["epoch"],
pl.Float64,
[1704589323.123456, 1609324245.987654, 1136159999.555555],
),
],
)
def test_extract(part: str, dtype: pl.DataType, expected: list[Any]) -> None:
def test_extract(parts: list[str], dtype: pl.DataType, expected: list[Any]) -> None:
df = pl.DataFrame(
{
"dt": [
Expand All @@ -100,11 +112,14 @@ def test_extract(part: str, dtype: pl.DataType, expected: list[Any]) -> None:
}
)
with pl.SQLContext(frame_data=df, eager=True) as ctx:
for func in (f"EXTRACT({part} FROM dt)", f"DATE_PART('{part}',dt)"):
res = ctx.execute(f"SELECT {func} AS {part} FROM frame_data").to_series()

assert res.dtype == dtype
assert res.to_list() == expected
for part in parts:
for fn in (
f"EXTRACT({part} FROM dt)",
f"DATE_PART('{part}',dt)",
):
res = ctx.execute(f"SELECT {fn} AS {part} FROM frame_data").to_series()
assert res.dtype == dtype
assert res.to_list() == expected


def test_extract_errors() -> None:
Expand Down

0 comments on commit 7307448

Please sign in to comment.