diff --git a/pkg/sql/logictest/testdata/logic_test/drop_function b/pkg/sql/logictest/testdata/logic_test/drop_function index d1f06990587f..de532b722b85 100644 --- a/pkg/sql/logictest/testdata/logic_test/drop_function +++ b/pkg/sql/logictest/testdata/logic_test/drop_function @@ -367,3 +367,27 @@ CREATE FUNCTION f_char(c BPCHAR) RETURNS INT LANGUAGE SQL AS 'SELECT 1' statement ok DROP FUNCTION f_char(CHAR(2)) + +statement ok +CREATE FUNCTION f_bit(c BIT) RETURNS INT LANGUAGE SQL AS 'SELECT 1' + +statement ok +DROP FUNCTION f_bit(BIT(0)) + +statement ok +CREATE FUNCTION f_bit(c BIT(2)) RETURNS INT LANGUAGE SQL AS 'SELECT 1' + +statement ok +DROP FUNCTION f_bit(BIT(0)) + +statement ok +CREATE FUNCTION f_bit(c BIT(0)) RETURNS INT LANGUAGE SQL AS 'SELECT 1' + +statement ok +DROP FUNCTION f_bit(BIT(0)) + +statement ok +CREATE FUNCTION f_bit(c BIT(0)) RETURNS INT LANGUAGE SQL AS 'SELECT 1' + +statement ok +DROP FUNCTION f_bit(BIT(2)) diff --git a/pkg/sql/logictest/testdata/logic_test/typing b/pkg/sql/logictest/testdata/logic_test/typing index 4a411a75f888..8c010571fefa 100644 --- a/pkg/sql/logictest/testdata/logic_test/typing +++ b/pkg/sql/logictest/testdata/logic_test/typing @@ -310,6 +310,69 @@ WHERE t108360_1.t = (CASE WHEN t108360_1.t > t108360_2.c THEN t108360_1.t ELSE t ---- tt +# Regression test for #131346. Ensure CASE is typed correctly and that a cast to +# the BIT type with an unspecified length behaves correctly during distributed +# execution. +statement ok +CREATE TABLE t131346v (v VARBIT); +INSERT INTO t131346v VALUES ('11'); + +statement ok +CREATE TABLE t131346b (b BIT); +INSERT INTO t131346b VALUES ('0'); + +query T +SELECT v FROM t131346v, t131346b +WHERE v NOT BETWEEN v AND + (CASE WHEN NULL THEN '0'::BIT ELSE v END) +---- + +query T +SELECT v FROM t131346v, t131346b +WHERE v NOT BETWEEN v AND + (CASE WHEN NULL THEN b ELSE v END) +---- + +query T +SELECT v FROM t131346v, t131346b +WHERE v NOT BETWEEN v AND + IF(NULL, '0'::BIT, v); +---- + +query T +SELECT v FROM t131346v, t131346b +WHERE v NOT BETWEEN v AND + IF(NULL, b, v); +---- + +query T +SELECT (CASE WHEN v > '0'::BIT THEN v ELSE '0'::BIT END) +FROM t131346v, t131346b +WHERE v = (CASE WHEN v > '0'::BIT THEN v ELSE '0'::BIT END) +---- +11 + +query T +SELECT (CASE WHEN v > b THEN v ELSE b END) +FROM t131346v, t131346b +WHERE v = (CASE WHEN v > b THEN v ELSE b END) +---- +11 + +query T +SELECT (CASE WHEN v > '0'::BIT THEN v ELSE '0'::BIT END) +FROM t131346v, t131346b +WHERE v = (CASE WHEN v < '0'::BIT THEN '0'::BIT ELSE v END) +---- +11 + +query T +SELECT (CASE WHEN v > b THEN v ELSE b END) +FROM t131346v, t131346b +WHERE v = (CASE WHEN v < b THEN b ELSE v END) +---- +11 + # Static analysis types should never make it to execution. statement ok CREATE TABLE t83496 ( diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go index 2d9f91b37bde..91a55d7d03df 100644 --- a/pkg/sql/opt/optbuilder/scalar.go +++ b/pkg/sql/opt/optbuilder/scalar.go @@ -244,6 +244,8 @@ func (b *Builder) buildScalar( for i := range t.Whens { condExpr := t.Whens[i].Cond.(tree.TypedExpr) cond := b.buildScalar(condExpr, inScope, nil, nil, colRefs) + // TODO(mgartner): Rather than use WithoutTypeModifiers here, + // consider typing the CaseExpr without a type modifier. valExpr, ok := eval.ReType(t.Whens[i].Val.(tree.TypedExpr), valType.WithoutTypeModifiers()) if !ok { panic(pgerror.Newf( diff --git a/pkg/sql/parser/parse.go b/pkg/sql/parser/parse.go index dd921487781d..57a7e52ac82c 100644 --- a/pkg/sql/parser/parse.go +++ b/pkg/sql/parser/parse.go @@ -449,17 +449,18 @@ func GetTypeFromCastOrCollate(expr tree.Expr) (tree.ResolvableTypeReference, err return cast.Type, nil } -var errBitLengthNotPositive = pgerror.WithCandidateCode( - errors.New("length for type bit must be at least 1"), pgcode.InvalidParameterValue) +var errVarBitLengthNotPositive = pgerror.WithCandidateCode( + errors.New("length for type varbit must be at least 1"), pgcode.InvalidParameterValue) // newBitType creates a new BIT type with the given bit width. func newBitType(width int32, varying bool) (*types.T, error) { - if width < 1 { - return nil, errBitLengthNotPositive - } if varying { + if width < 1 { + return nil, errVarBitLengthNotPositive + } return types.MakeVarBit(width), nil } + // The iconst32 pattern in the parser guarantees that the width is positive. return types.MakeBit(width), nil } diff --git a/pkg/sql/parser/testdata/create_table b/pkg/sql/parser/testdata/create_table index 0238fe007724..16d3d9910a0c 100644 --- a/pkg/sql/parser/testdata/create_table +++ b/pkg/sql/parser/testdata/create_table @@ -176,23 +176,24 @@ CREATE TABLE foo(a CHAR(0)) ^ parse -CREATE TABLE a (b BIT VARYING(2), c BIT(1)) +CREATE TABLE a (b BIT VARYING(2), c BIT(1), c BIT, c BIT(0)) ---- -CREATE TABLE a (b VARBIT(2), c BIT) -- normalized! -CREATE TABLE a (b VARBIT(2), c BIT) -- fully parenthesized -CREATE TABLE a (b VARBIT(2), c BIT) -- literals removed -CREATE TABLE _ (_ VARBIT(2), _ BIT) -- identifiers removed +CREATE TABLE a (b VARBIT(2), c BIT, c BIT, c BIT(0)) -- normalized! +CREATE TABLE a (b VARBIT(2), c BIT, c BIT, c BIT(0)) -- fully parenthesized +CREATE TABLE a (b VARBIT(2), c BIT, c BIT, c BIT(0)) -- literals removed +CREATE TABLE _ (_ VARBIT(2), _ BIT, _ BIT, _ BIT(0)) -- identifiers removed error CREATE TABLE test ( - foo BIT(0) + foo BIT(-1) ) ---- -at or near ")": syntax error: length for type bit must be at least 1 +at or near "-": syntax error DETAIL: source SQL: CREATE TABLE test ( - foo BIT(0) - ^ + foo BIT(-1) + ^ +HINT: try \h CREATE TABLE parse diff --git a/pkg/sql/parser/testdata/select_exprs b/pkg/sql/parser/testdata/select_exprs index 0c252b3f5fb7..8d0d329b83b4 100644 --- a/pkg/sql/parser/testdata/select_exprs +++ b/pkg/sql/parser/testdata/select_exprs @@ -952,6 +952,14 @@ SELECT (('foo')::DECIMAL(2,1)) -- fully parenthesized SELECT '_'::DECIMAL(2,1) -- literals removed SELECT 'foo'::DECIMAL(2,1) -- identifiers removed +parse +SELECT 'foo'::BIT +---- +SELECT 'foo'::BIT +SELECT (('foo')::BIT) -- fully parenthesized +SELECT '_'::BIT -- literals removed +SELECT 'foo'::BIT -- identifiers removed + parse SELECT 'foo'::BIT(3) ---- @@ -960,6 +968,14 @@ SELECT (('foo')::BIT(3)) -- fully parenthesized SELECT '_'::BIT(3) -- literals removed SELECT 'foo'::BIT(3) -- identifiers removed +parse +SELECT 'foo'::BIT(0) +---- +SELECT 'foo'::BIT(0) +SELECT (('foo')::BIT(0)) -- fully parenthesized +SELECT '_'::BIT(0) -- literals removed +SELECT 'foo'::BIT(0) -- identifiers removed + parse SELECT 'foo'::VARBIT(3) ---- diff --git a/pkg/sql/sem/cast/type_name.go b/pkg/sql/sem/cast/type_name.go index 6d872d275ffa..50578cd5a6af 100644 --- a/pkg/sql/sem/cast/type_name.go +++ b/pkg/sql/sem/cast/type_name.go @@ -28,6 +28,9 @@ func CastTypeName(t *types.T) string { case oid.T_text: // SQLString returns `string` return "text" + case oid.T_bit: + // SQLString returns `decimal` + return "bit" } return strings.ToLower(t.SQLString()) } diff --git a/pkg/sql/sem/tree/col_types_test.go b/pkg/sql/sem/tree/col_types_test.go index 929d38470044..3b59319cb8bc 100644 --- a/pkg/sql/sem/tree/col_types_test.go +++ b/pkg/sql/sem/tree/col_types_test.go @@ -25,6 +25,7 @@ func TestParseColumnType(t *testing.T) { expectedType *types.T }{ {"BIT", types.MakeBit(1)}, + {"BIT(0)", types.MakeBit(0)}, {"VARBIT", types.MakeVarBit(0)}, {"BIT(2)", types.MakeBit(2)}, {"VARBIT(2)", types.MakeVarBit(2)}, diff --git a/pkg/sql/sem/tree/overload.go b/pkg/sql/sem/tree/overload.go index 4217ceee3575..2a212286fa76 100644 --- a/pkg/sql/sem/tree/overload.go +++ b/pkg/sql/sem/tree/overload.go @@ -497,7 +497,10 @@ func (p ParamTypes) MatchAtIdentical(typ *types.T, i int) bool { p[i].Typ.Identical(typ) || // Special case for CHAR, CHAR(N), and BPCHAR which are not "identical" // but have the same OID. See #129007. - (p[i].Typ.Oid() == oid.T_bpchar && typ.Oid() == oid.T_bpchar)) + (p[i].Typ.Oid() == oid.T_bpchar && typ.Oid() == oid.T_bpchar) || + // Special case for BIT, BIT(N), and BIT(0) which are not "identical" + // but have the same OID. See #132944. + (p[i].Typ.Oid() == oid.T_bit && typ.Oid() == oid.T_bit)) } // MatchLen is part of the TypeList interface. diff --git a/pkg/sql/types/types.go b/pkg/sql/types/types.go index 96b5f31eff4e..728e11ce4230 100644 --- a/pkg/sql/types/types.go +++ b/pkg/sql/types/types.go @@ -115,6 +115,7 @@ import ( // | FLOAT4 | FLOAT | T_float4 | 0 | 0 | // | | | | | | // | BIT | BIT | T_bit | 0 | 1 | +// | BIT(0) | BIT | T_bit | 0 | 0 | // | BIT(N) | BIT | T_bit | 0 | N | // | VARBIT | BIT | T_varbit | 0 | 0 | // | VARBIT(N) | BIT | T_varbit | 0 | N | @@ -716,9 +717,10 @@ var ( // Unexported wrapper types. var ( - // typeBit is the SQL BIT type. It is not exported to avoid confusion with - // the VarBit type, and confusion over whether its default Width is - // unspecified or is 1. More commonly used instead is the VarBit type. + // typeBit is the SQL BIT type with an unspecified width. It is not exported + // to avoid confusion with the VarBit type, and confusion over whether its + // default Width is unspecified or is 1. More commonly used instead is the + // VarBit type. typeBit = &T{InternalType: InternalType{ Family: BitFamily, Oid: oid.T_bit, Locale: &emptyLocale}} @@ -1928,17 +1930,27 @@ func (t *T) InformationSchemaName() string { func (t *T) SQLString() string { switch t.Family() { case BitFamily: - o := t.Oid() - typName := "BIT" - if o == oid.T_varbit { - typName = "VARBIT" - } - // BIT(1) pretty-prints as just BIT. - if (o != oid.T_varbit && t.Width() > 1) || - (o == oid.T_varbit && t.Width() > 0) { - typName = fmt.Sprintf("%s(%d)", typName, t.Width()) + switch t.Oid() { + case oid.T_bit: + typName := "BIT" + // BIT(1) pretty-prints as just BIT. + // BIT(0) represents a BIT type with unspecified length. This is a + // divergence from Postgres which does not allow this type and has + // no way to represent it in SQL. It is required in order for it to + // be correctly serialized into SQL and evaluated during distributed + // query execution. VARBIT cannot be used because it has a different + // OID. + if t.Width() != 1 { + typName = fmt.Sprintf("%s(%d)", typName, t.Width()) + } + return typName + default: + typName := "VARBIT" + if t.Width() > 0 { + typName = fmt.Sprintf("%s(%d)", typName, t.Width()) + } + return typName } - return typName case IntFamily: switch t.Width() { case 16: