Skip to content

Commit

Permalink
refactor!: PosqlTimeZone to use sqlparser::ast::TimezoneInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith257 committed Jan 14, 2025
1 parent 8baa9df commit 7621c97
Show file tree
Hide file tree
Showing 30 changed files with 279 additions and 208 deletions.
56 changes: 55 additions & 1 deletion crates/proof-of-sql-parser/src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ use crate::{
OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
TableExpression, UnaryOperator as PoSqlUnaryOperator,
},
posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
Identifier, ResourceId, SelectStatement,
};
use alloc::{boxed::Box, string::ToString, vec};
use alloc::{
boxed::Box,
string::{String, ToString},
vec,
};
use core::fmt::Display;
use sqlparser::ast::{
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
Expand All @@ -28,6 +33,50 @@ fn id(id: Identifier) -> Expr {
Expr::Identifier(id.into())
}

/// Provides an extension for the `TimezoneInfo` type for offsets.
pub trait TimezoneInfoExt {
/// Retrieve the offset in seconds for `TimezoneInfo`.
fn offset(&self, timezone_str: Option<&str>) -> i32;
}

impl TimezoneInfoExt for TimezoneInfo {
fn offset(&self, timezone_str: Option<&str>) -> i32 {
match self {
TimezoneInfo::None | TimezoneInfo::WithoutTimeZone => PoSQLTimeZone::utc().offset(),
TimezoneInfo::WithTimeZone | TimezoneInfo::Tz => match timezone_str {
Some(tz_str) => PoSQLTimeZone::try_from(&Some(tz_str.into()))
.unwrap_or_else(|_| PoSQLTimeZone::utc())
.offset(),
None => PoSQLTimeZone::utc().offset(),
},
}
}
}

/// Utility function to create a `Timestamp` expression.
pub fn timestamp_to_expr(
value: &str,
time_unit: PoSQLTimeUnit,
timezone: TimezoneInfo,
) -> Result<Expr, String> {
let time_unit_as_u64 = u64::from(time_unit);

Ok(Expr::TypedString {
data_type: DataType::Timestamp(Some(time_unit_as_u64), timezone),
value: value.to_string(),
})
}

/// Parses [`PoSQLTimeZone`] into a `TimezoneInfo`.
impl From<PoSQLTimeZone> for TimezoneInfo {
fn from(posql_timezone: PoSQLTimeZone) -> Self {
match posql_timezone.offset() {
0 => TimezoneInfo::None,
_ => TimezoneInfo::WithTimeZone,
}
}
}

impl From<Identifier> for Ident {
fn from(id: Identifier) -> Self {
Ident::new(id.as_str())
Expand Down Expand Up @@ -268,6 +317,11 @@ mod test {
"select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
"select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
);

check_posql_intermediate_ast_to_sqlparser_equivalence(
"select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
"select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
);
}

// Check that PoSQL intermediate AST can be converted to SQL parser AST and that the two are equal.
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/benches/bench_append_rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use proof_of_sql::{
DoryCommitment, DoryProverPublicSetup, DoryScalar, ProverSetup, PublicParameters,
},
};
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use rand::Rng;
use sqlparser::ast::TimezoneInfo;

/// Bench dory performance when appending rows to a table. This includes the computation of
/// commitments. Chose the number of columns to randomly generate across supported `PoSQL`
Expand Down Expand Up @@ -121,7 +122,7 @@ pub fn generate_random_owned_table<S: Scalar>(
"timestamptz" => columns.push(timestamptz(
&*identifier,
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![rng.gen::<i64>(); num_rows],
)),
_ => unreachable!(),
Expand Down
120 changes: 65 additions & 55 deletions crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,60 +202,69 @@ impl ArrayRefExt for ArrayRef {
}
}
// Handle all possible TimeStamp TimeUnit instances
DataType::Timestamp(time_unit, tz) => match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
DataType::Timestamp(time_unit, tz) => {
let timezone = PoSQLTimeZone::try_from(tz)?;
match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Millisecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMillisecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Millisecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMillisecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Microsecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMicrosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Microsecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMicrosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Nanosecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampNanosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Nanosecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampNanosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
},
}
DataType::Utf8 => {
if let Some(array) = self.as_any().downcast_ref::<StringArray>() {
let vals = alloc
Expand Down Expand Up @@ -292,6 +301,7 @@ mod tests {
use alloc::sync::Arc;
use arrow::array::Decimal256Builder;
use core::str::FromStr;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_convert_timestamp_array_normal_range() {
Expand All @@ -305,7 +315,7 @@ mod tests {
let result = array.to_column::<TestScalar>(&alloc, &(1..3), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand All @@ -323,7 +333,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand All @@ -339,7 +349,7 @@ mod tests {
let result = array.to_column::<DoryScalar>(&alloc, &(1..1), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand Down Expand Up @@ -1006,7 +1016,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[..])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[..])
);
}

Expand Down Expand Up @@ -1076,7 +1086,7 @@ mod tests {
array
.to_column::<TestScalar>(&alloc, &(1..3), None)
.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand Down Expand Up @@ -1134,7 +1144,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl TryFrom<DataType> for ColumnType {
};
Ok(ColumnType::TimestampTZ(
posql_time_unit,
PoSQLTimeZone::try_from(&timezone_option)?,
PoSQLTimeZone::try_from(&timezone_option)?.into(),
))
}
DataType::Utf8 => Ok(ColumnType::VarChar),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -252,7 +252,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -266,7 +266,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -280,7 +280,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/base/commitment/column_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ mod tests {
};
use alloc::{string::String, vec};
use itertools::Itertools;
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_construct_bounds_by_method() {
Expand Down Expand Up @@ -563,7 +564,7 @@ mod tests {

let timestamp_column = OwnedColumn::<TestScalar>::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![1_i64, 2, 3, 4],
);
let committable_timestamp_column = CommittableColumn::from(&timestamp_column);
Expand Down
Loading

0 comments on commit 7621c97

Please sign in to comment.