From 13420be17f43f5ee2fe021a1f01bdb74d03f5216 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 15 Sep 2024 06:49:05 -0400 Subject: [PATCH] fix(clickhouse): move nullable computation into attribute to match sqlglot --- ibis/backends/sql/datatypes.py | 287 +++++++++++++++++++-------------- 1 file changed, 167 insertions(+), 120 deletions(-) diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index 649e155ee705..c4d3317a2970 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -170,10 +170,13 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType ) typecode = typ.this + nullable = typ.args.get( + "nullable", nullable if nullable is not None else cls.default_nullable + ) if method := getattr(cls, f"_from_sqlglot_{typecode.name}", None): - dtype = method(*typ.expressions) + dtype = method(*typ.expressions, nullable=nullable) elif (known_typ := _from_sqlglot_types.get(typecode)) is not None: - dtype = known_typ(nullable=cls.default_nullable) + dtype = known_typ(nullable=nullable) else: dtype = dt.unknown @@ -209,65 +212,83 @@ def to_string(cls, dtype: dt.DataType) -> str: return cls.from_ibis(dtype).sql(dialect=cls.dialect) @classmethod - def _from_sqlglot_ARRAY(cls, value_type: sge.DataType) -> dt.Array: - return dt.Array(cls.to_ibis(value_type), nullable=cls.default_nullable) + def _from_sqlglot_ARRAY( + cls, value_type: sge.DataType, nullable: bool | None = None + ) -> dt.Array: + return dt.Array(cls.to_ibis(value_type), nullable=nullable) @classmethod def _from_sqlglot_MAP( - cls, key_type: sge.DataType, value_type: sge.DataType + cls, + key_type: sge.DataType, + value_type: sge.DataType, + nullable: bool | None = None, ) -> dt.Map: - return dt.Map( - cls.to_ibis(key_type), - cls.to_ibis(value_type), - nullable=cls.default_nullable, - ) + return dt.Map(cls.to_ibis(key_type), cls.to_ibis(value_type), nullable=nullable) @classmethod - def _from_sqlglot_STRUCT(cls, *fields: sge.ColumnDef) -> dt.Struct: + def _from_sqlglot_STRUCT( + cls, *fields: sge.ColumnDef, nullable: bool | None = None + ) -> dt.Struct: types = {} for i, field in enumerate(fields): if isinstance(field, sge.ColumnDef): - types[field.name] = cls.to_ibis(field.args["kind"]) + name = field.name + sgtype = field.args["kind"] else: - types[f"f{i:d}"] = cls.from_string(str(field)) - return dt.Struct(types, nullable=cls.default_nullable) + # handle unnamed fields (e.g., ClickHouse's Tuple type) + assert isinstance(field, sge.DataType), type(field) + name = f"f{i:d}" + sgtype = field + + types[name] = cls.to_ibis(sgtype) + return dt.Struct(types, nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMP(cls, scale=None) -> dt.Timestamp: + def _from_sqlglot_TIMESTAMP( + cls, scale=None, nullable: bool | None = None + ) -> dt.Timestamp: return dt.Timestamp( scale=cls.default_temporal_scale if scale is None else int(scale.this.this), - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod - def _from_sqlglot_TIMESTAMPTZ(cls, scale=None) -> dt.Timestamp: + def _from_sqlglot_TIMESTAMPTZ( + cls, scale=None, nullable: bool | None = None + ) -> dt.Timestamp: return dt.Timestamp( timezone="UTC", scale=cls.default_temporal_scale if scale is None else int(scale.this.this), - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod - def _from_sqlglot_TIMESTAMPLTZ(cls, scale=None) -> dt.Timestamp: + def _from_sqlglot_TIMESTAMPLTZ( + cls, scale=None, nullable: bool | None = None + ) -> dt.Timestamp: return dt.Timestamp( timezone="UTC", scale=cls.default_temporal_scale if scale is None else int(scale.this.this), - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod - def _from_sqlglot_TIMESTAMPNTZ(cls, scale=None) -> dt.Timestamp: + def _from_sqlglot_TIMESTAMPNTZ( + cls, scale=None, nullable: bool | None = None + ) -> dt.Timestamp: return dt.Timestamp( timezone=None, scale=cls.default_temporal_scale if scale is None else int(scale.this.this), - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod def _from_sqlglot_INTERVAL( - cls, precision_or_span: sge.IntervalSpan | None = None + cls, + precision_or_span: sge.IntervalSpan | None = None, + nullable: bool | None = None, ) -> dt.Interval: - nullable = cls.default_nullable if precision_or_span is None: precision_or_span = cls.default_interval_precision @@ -291,6 +312,7 @@ def _from_sqlglot_DECIMAL( cls, precision: sge.DataTypeParam | None = None, scale: sge.DataTypeParam | None = None, + nullable: bool | None = None, ) -> dt.Decimal: if precision is None: precision = cls.default_decimal_precision @@ -302,11 +324,14 @@ def _from_sqlglot_DECIMAL( else: scale = int(scale.this.this) - return dt.Decimal(precision, scale, nullable=cls.default_nullable) + return dt.Decimal(precision, scale, nullable=nullable) @classmethod def _from_sqlglot_GEOMETRY( - cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None + cls, + arg: sge.DataTypeParam | None = None, + srid: sge.DataTypeParam | None = None, + nullable: bool | None = None, ) -> sge.DataType: if arg is not None: typeclass = _geotypes[arg.this.this] @@ -314,11 +339,14 @@ def _from_sqlglot_GEOMETRY( typeclass = dt.GeoSpatial if srid is not None: srid = int(srid.this.this) - return typeclass(geotype="geometry", nullable=cls.default_nullable, srid=srid) + return typeclass(geotype="geometry", nullable=nullable, srid=srid) @classmethod def _from_sqlglot_GEOGRAPHY( - cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None + cls, + arg: sge.DataTypeParam | None = None, + srid: sge.DataTypeParam | None = None, + nullable: bool | None = None, ) -> sge.DataType: if arg is not None: typeclass = _geotypes[arg.this.this] @@ -326,7 +354,7 @@ def _from_sqlglot_GEOGRAPHY( typeclass = dt.GeoSpatial if srid is not None: srid = int(srid.this.this) - return typeclass(geotype="geography", nullable=cls.default_nullable, srid=srid) + return typeclass(geotype="geography", nullable=nullable, srid=srid) @classmethod def _from_ibis_JSON(cls, dtype: dt.JSON) -> sge.DataType: @@ -517,19 +545,23 @@ class MySQLType(SqlglotType): ) @classmethod - def _from_sqlglot_BIT(cls, nbits: sge.DataTypeParam) -> dt.Integer: + def _from_sqlglot_BIT( + cls, nbits: sge.DataTypeParam, nullable: bool | None = None + ) -> dt.Integer: nbits = int(nbits.this.this) if nbits > 32: - return dt.Int64(nullable=cls.default_nullable) + return dt.Int64(nullable=nullable) elif nbits > 16: - return dt.Int32(nullable=cls.default_nullable) + return dt.Int32(nullable=nullable) elif nbits > 8: - return dt.Int16(nullable=cls.default_nullable) + return dt.Int16(nullable=nullable) else: - return dt.Int8(nullable=cls.default_nullable) + return dt.Int8(nullable=nullable) @classmethod - def _from_sqlglot_DATETIME(cls, scale=None) -> dt.Timestamp: + def _from_sqlglot_DATETIME( + cls, scale=None, nullable: bool | None = None + ) -> dt.Timestamp: if scale is not None: scale = int(scale.this.this) return dt.Timestamp( @@ -540,12 +572,12 @@ def _from_sqlglot_DATETIME(cls, scale=None) -> dt.Timestamp: # https://dev.mysql.com/doc/refman/8.4/en/fractional-seconds.html # for details scale=scale or None, - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod - def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp: - return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMP(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(timezone="UTC", nullable=nullable) @classmethod def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: @@ -561,24 +593,24 @@ class DuckDBType(SqlglotType): unknown_type_strings = FrozenDict({"wkb_blob": dt.binary}) @classmethod - def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp: - return dt.Timestamp(scale=6, nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMP(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(scale=6, nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMPTZ(cls) -> dt.Timestamp: - return dt.Timestamp(scale=6, timezone="UTC", nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMPTZ(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(scale=6, timezone="UTC", nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMP_S(cls) -> dt.Timestamp: - return dt.Timestamp(scale=0, nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMP_S(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(scale=0, nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMP_MS(cls) -> dt.Timestamp: - return dt.Timestamp(scale=3, nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMP_MS(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(scale=3, nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMP_NS(cls) -> dt.Timestamp: - return dt.Timestamp(scale=9, nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMP_NS(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(scale=9, nullable=nullable) @classmethod def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial): @@ -628,8 +660,8 @@ def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType: ) @classmethod - def _from_sqlglot_UBIGINT(cls): - return dt.Decimal(precision=19, scale=0, nullable=cls.default_nullable) + def _from_sqlglot_UBIGINT(cls, nullable: bool | None = None): + return dt.Decimal(precision=19, scale=0, nullable=nullable) @classmethod def _from_ibis_UInt64(cls, dtype): @@ -642,24 +674,24 @@ def _from_ibis_UInt64(cls, dtype): ) @classmethod - def _from_sqlglot_UINT(cls): - return dt.Int64(nullable=cls.default_nullable) + def _from_sqlglot_UINT(cls, nullable: bool | None = None): + return dt.Int64(nullable=nullable) @classmethod def _from_ibis_UInt32(cls, dtype): return sge.DataType(this=typecode.BIGINT) @classmethod - def _from_sqlglot_USMALLINT(cls): - return dt.Int32(nullable=cls.default_nullable) + def _from_sqlglot_USMALLINT(cls, nullable: bool | None = None): + return dt.Int32(nullable=nullable) @classmethod def _from_ibis_UInt16(cls, dtype): return sge.DataType(this=typecode.INT) @classmethod - def _from_sqlglot_UTINYINT(cls): - return dt.Int16(nullable=cls.default_nullable) + def _from_sqlglot_UTINYINT(cls, nullable: bool | None = None): + return dt.Int16(nullable=nullable) @classmethod def _from_ibis_UInt8(cls, dtype): @@ -683,13 +715,15 @@ class OracleType(SqlglotType): unknown_type_strings = FrozenDict({"raw": dt.binary}) @classmethod - def _from_sqlglot_FLOAT(cls) -> dt.Float64: - return dt.Float64(nullable=cls.default_nullable) + def _from_sqlglot_FLOAT(cls, nullable: bool | None = None) -> dt.Float64: + return dt.Float64(nullable=nullable) @classmethod - def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal: + def _from_sqlglot_DECIMAL( + cls, precision=None, scale=None, nullable: bool | None = None + ) -> dt.Decimal: if scale is None or int(scale.this.this) == 0: - return dt.Int64(nullable=cls.default_nullable) + return dt.Int64(nullable=nullable) else: return super()._from_sqlglot_DECIMAL(precision, scale) @@ -708,20 +742,24 @@ class SnowflakeType(SqlglotType): default_temporal_scale = 9 @classmethod - def _from_sqlglot_FLOAT(cls) -> dt.Float64: - return dt.Float64(nullable=cls.default_nullable) + def _from_sqlglot_FLOAT(cls, nullable: bool | None = None) -> dt.Float64: + return dt.Float64(nullable=nullable) @classmethod - def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal: + def _from_sqlglot_DECIMAL( + cls, precision=None, scale=None, nullable: bool | None = None + ) -> dt.Decimal: if scale is None or int(scale.this.this) == 0: - return dt.Int64(nullable=cls.default_nullable) + return dt.Int64(nullable=nullable) else: return super()._from_sqlglot_DECIMAL(precision, scale) @classmethod - def _from_sqlglot_ARRAY(cls, value_type=None) -> dt.Array: + def _from_sqlglot_ARRAY( + cls, value_type=None, nullable: bool | None = None + ) -> dt.Array: assert value_type is None - return dt.Array(dt.json, nullable=cls.default_nullable) + return dt.Array(dt.json, nullable=nullable) @classmethod def _from_ibis_JSON(cls, dtype: dt.JSON) -> sge.DataType: @@ -744,12 +782,12 @@ class SQLiteType(SqlglotType): dialect = "sqlite" @classmethod - def _from_sqlglot_INT(cls) -> dt.Int64: - return dt.Int64(nullable=cls.default_nullable) + def _from_sqlglot_INT(cls, nullable: bool | None = None) -> dt.Int64: + return dt.Int64(nullable=nullable) @classmethod - def _from_sqlglot_FLOAT(cls) -> dt.Float64: - return dt.Float64(nullable=cls.default_nullable) + def _from_sqlglot_FLOAT(cls, nullable: bool | None = None) -> dt.Float64: + return dt.Float64(nullable=nullable) @classmethod def _from_ibis_Array(cls, dtype: dt.Array) -> NoReturn: @@ -797,40 +835,39 @@ class BigQueryType(SqlglotType): default_decimal_scale = 9 @classmethod - def _from_sqlglot_NUMERIC(cls) -> dt.Decimal: + def _from_sqlglot_NUMERIC(cls, nullable: bool | None = None) -> dt.Decimal: return dt.Decimal( - cls.default_decimal_precision, - cls.default_decimal_scale, - nullable=cls.default_nullable, + cls.default_decimal_precision, cls.default_decimal_scale, nullable=nullable ) @classmethod - def _from_sqlglot_BIGNUMERIC(cls) -> dt.Decimal: - return dt.Decimal(76, 38, nullable=cls.default_nullable) + def _from_sqlglot_BIGNUMERIC(cls, nullable: bool | None = None) -> dt.Decimal: + return dt.Decimal(76, 38, nullable=nullable) @classmethod - def _from_sqlglot_DATETIME(cls) -> dt.Timestamp: - return dt.Timestamp(timezone=None, nullable=cls.default_nullable) + def _from_sqlglot_DATETIME(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(timezone=None, nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp: - return dt.Timestamp(timezone=None, nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMP(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(timezone=None, nullable=nullable) @classmethod - def _from_sqlglot_TIMESTAMPTZ(cls) -> dt.Timestamp: - return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable) + def _from_sqlglot_TIMESTAMPTZ(cls, nullable: bool | None = None) -> dt.Timestamp: + return dt.Timestamp(timezone="UTC", nullable=nullable) @classmethod def _from_sqlglot_GEOGRAPHY( - cls, arg: sge.DataTypeParam | None = None, srid: sge.DataTypeParam | None = None + cls, + arg: sge.DataTypeParam | None = None, + srid: sge.DataTypeParam | None = None, + nullable: bool | None = None, ) -> dt.GeoSpatial: - return dt.GeoSpatial( - geotype="geography", srid=4326, nullable=cls.default_nullable - ) + return dt.GeoSpatial(geotype="geography", srid=4326, nullable=nullable) @classmethod - def _from_sqlglot_TINYINT(cls) -> dt.Int64: - return dt.Int64(nullable=cls.default_nullable) + def _from_sqlglot_TINYINT(cls, nullable: bool | None = None) -> dt.Int64: + return dt.Int64(nullable=nullable) _from_sqlglot_UINT = _from_sqlglot_USMALLINT = _from_sqlglot_UTINYINT = ( _from_sqlglot_INT @@ -843,8 +880,8 @@ def _from_sqlglot_UBIGINT(cls) -> NoReturn: ) @classmethod - def _from_sqlglot_FLOAT(cls) -> dt.Float64: - return dt.Float64(nullable=cls.default_nullable) + def _from_sqlglot_FLOAT(cls, nullable: bool | None = None) -> dt.Float64: + return dt.Float64(nullable=nullable) @classmethod def _from_sqlglot_MAP(cls) -> NoReturn: @@ -931,6 +968,7 @@ def _from_sqlglot_DECIMAL( cls, precision: sge.DataTypeParam | None = None, scale: sge.DataTypeParam | None = None, + nullable: bool | None = None, ) -> dt.Decimal: if precision is None: precision = cls.default_decimal_precision @@ -944,18 +982,18 @@ def _from_sqlglot_DECIMAL( if not scale: if 0 < precision <= 3: - return dt.Int8(nullable=cls.default_nullable) + return dt.Int8(nullable=nullable) elif 3 < precision <= 9: - return dt.Int16(nullable=cls.default_nullable) + return dt.Int16(nullable=nullable) elif 9 < precision <= 18: - return dt.Int32(nullable=cls.default_nullable) + return dt.Int32(nullable=nullable) elif 18 < precision <= 36: - return dt.Int64(nullable=cls.default_nullable) + return dt.Int64(nullable=nullable) else: raise com.UnsupportedBackendType( "Decimal precision is too large; Exasol supports precision up to 36." ) - return dt.Decimal(precision, scale, nullable=cls.default_nullable) + return dt.Decimal(precision, scale, nullable=nullable) @classmethod def _from_ibis_Array(cls, dtype: dt.Array) -> NoReturn: @@ -993,17 +1031,17 @@ class MSSQLType(SqlglotType): unknown_type_strings = FrozenDict({"hierarchyid": dt.string}) @classmethod - def _from_sqlglot_BIT(cls): - return dt.Boolean(nullable=cls.default_nullable) + def _from_sqlglot_BIT(cls, nullable: bool | None = None): + return dt.Boolean(nullable=nullable) @classmethod - def _from_sqlglot_IMAGE(cls): - return dt.Binary(nullable=cls.default_nullable) + def _from_sqlglot_IMAGE(cls, nullable: bool | None = None): + return dt.Binary(nullable=nullable) @classmethod - def _from_sqlglot_DATETIME(cls, n=None): + def _from_sqlglot_DATETIME(cls, n=None, nullable: bool | None = None): return dt.Timestamp( - scale=n if n is None else int(n.this.this), nullable=cls.default_nullable + scale=n if n is None else int(n.this.this), nullable=nullable ) @classmethod @@ -1062,30 +1100,36 @@ class ClickHouseType(SqlglotType): def from_ibis(cls, dtype: dt.DataType) -> sge.DataType: typ = super().from_ibis(dtype) - if typ.this == typecode.NULLABLE: + if typ.args.get("nullable") is True: return typ - # nested types cannot be nullable in clickhouse - typ.args["nullable"] = False - if dtype.nullable and not ( + typ.args["nullable"] = dtype.nullable and not ( + # nested types cannot be nullable in clickhouse dtype.is_map() or dtype.is_array() or dtype.is_struct() - ): - return sge.DataType(this=typecode.NULLABLE, expressions=[typ]) - else: - return typ + ) + return typ @classmethod - def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType: + def _from_sqlglot_NULLABLE( + cls, + inner_type: sge.DataType, + # nullable is ignored when explicitly wrapped in ClickHouse's Nullable + # type modifier + # + # NULLABLE was removed in sqlglot 25.11, but this remains for backwards + # compatibility in Ibis + nullable: bool | None = None, + ) -> dt.DataType: return cls.to_ibis(inner_type, nullable=True) @classmethod def _from_sqlglot_DATETIME( - cls, timezone: sge.DataTypeParam | None = None + cls, timezone: sge.DataTypeParam | None = None, nullable: bool | None = None ) -> dt.Timestamp: return dt.Timestamp( scale=0, timezone=None if timezone is None else timezone.this.this, - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod @@ -1093,26 +1137,29 @@ def _from_sqlglot_DATETIME64( cls, scale: sge.DataTypeSize | None = None, timezone: sge.Literal | None = None, + nullable: bool | None = None, ) -> dt.Timestamp: return dt.Timestamp( timezone=None if timezone is None else timezone.this.this, scale=int(scale.this.this), - nullable=cls.default_nullable, + nullable=nullable, ) @classmethod - def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType: - return cls.to_ibis(inner_type) + def _from_sqlglot_LOWCARDINALITY( + cls, inner_type: sge.DataType, nullable: bool | None = None + ) -> dt.DataType: + return cls.to_ibis(inner_type, nullable=nullable) @classmethod - def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct: + def _from_sqlglot_NESTED( + cls, *fields: sge.DataType, nullable: bool | None = None + ) -> dt.Struct: fields = { - field.name: dt.Array( - cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable - ) + field.name: dt.Array(cls.to_ibis(field.args["kind"]), nullable=nullable) for field in fields } - return dt.Struct(fields, nullable=cls.default_nullable) + return dt.Struct(fields, nullable=nullable) @classmethod def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: