Skip to content

Commit

Permalink
Fix nullability checks in evalengine (vitessio#14556)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
GuptaManan100 committed Nov 22, 2023
1 parent 72443ea commit 9badb73
Show file tree
Hide file tree
Showing 19 changed files with 142 additions and 69 deletions.
2 changes: 1 addition & 1 deletion go/mysql/collations/integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func mysqlconn(t *testing.T) *mysql.Conn {
if err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(conn.ServerVersion, "8.0.") {
if !strings.HasPrefix(conn.ServerVersion, "8.") {
conn.Close()
t.Skipf("collation integration tests are only supported in MySQL 8.0+")
}
Expand Down
24 changes: 17 additions & 7 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
asm.emit(func(env *ExpressionEnv) int {
end := env.vm.sp - elseOffset
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
if env.vm.stack[sp].(*evalInt64).i != 0 {
if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation)
goto done
}
Expand Down Expand Up @@ -780,16 +780,18 @@ func (asm *assembler) Convert_bB(offset int) {
var f float64
if arg != nil {
f, _ = fastparse.ParseFloat64(arg.(*evalBytes).string())
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0)
}
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0)
return 1
}, "CONV VARBINARY(SP-%d), BOOL", offset)
}

func (asm *assembler) Convert_TB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalTemporal).isZero())
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalTemporal).isZero())
}
return 1
}, "CONV SQLTYPES(SP-%d), BOOL", offset)
}
Expand Down Expand Up @@ -837,7 +839,9 @@ func (asm *assembler) Convert_Tj(offset int) {
func (asm *assembler) Convert_dB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero())
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalDecimal).dec.IsZero())
}
return 1
}, "CONV DECIMAL(SP-%d), BOOL", offset)
}
Expand All @@ -857,7 +861,9 @@ func (asm *assembler) Convert_dbit(offset int) {
func (asm *assembler) Convert_fB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0)
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalFloat).f != 0.0)
}
return 1
}, "CONV FLOAT64(SP-%d), BOOL", offset)
}
Expand Down Expand Up @@ -904,7 +910,9 @@ func (asm *assembler) Convert_Tf(offset int) {
func (asm *assembler) Convert_iB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0)
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalInt64).i != 0)
}
return 1
}, "CONV INT64(SP-%d), BOOL", offset)
}
Expand Down Expand Up @@ -984,7 +992,9 @@ func (asm *assembler) Convert_Nj(offset int) {
func (asm *assembler) Convert_uB(offset int) {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-offset]
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0)
if arg != nil {
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalUint64).u != 0)
}
return 1
}, "CONV UINT64(SP-%d), BOOL", offset)
}
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,38 @@ func TestCompilerSingle(t *testing.T) {
expression: `INTERVAL(0, 0, 0, -1, NULL, NULL, 1)`,
result: `INT64(5)`,
},
{
expression: `cast(null * 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null + 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null - 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null / 1 as CHAR)`,
result: `NULL`,
},
{
expression: `cast(null % 1 as CHAR)`,
result: `NULL`,
},
{
expression: `1 AND NULL * 1`,
result: `NULL`,
},
{
expression: `case 0 when NULL then 1 else 0 end`,
result: `INT64(0)`,
},
{
expression: `case when null is null then 23 else null end`,
result: `INT64(23)`,
},
}

for _, tc := range testCases {
Expand Down
12 changes: 8 additions & 4 deletions go/vt/vtgate/evalengine/expr_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sumtype, Col: collationNumeric}, nil
return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (op *opArithSub) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -274,7 +274,7 @@ func (op *opArithSub) compile(c *compiler, left, right Expr) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: subtype, Col: collationNumeric}, nil
return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (op *opArithMul) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -334,7 +334,7 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: multype, Col: collationNumeric}, nil
return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (op *opArithDiv) eval(left, right eval) (eval, error) {
Expand Down Expand Up @@ -604,5 +604,9 @@ func (expr *NegateExpr) compile(c *compiler) (ctype, error) {
}

c.asm.jumpDestination(skip)
return ctype{Type: neg, Col: collationNumeric}, nil
return ctype{Type: neg, Flag: nullableFlags(arg.Flag), Col: collationNumeric}, nil
}

func nullableFlags(flag typeFlag) typeFlag {
return flag & (flagNull | flagNullable)
}
8 changes: 4 additions & 4 deletions go/vt/vtgate/evalengine/expr_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (expr *BitwiseExpr) compileBinary(c *compiler, asm_ins_bb, asm_ins_uu func(

asm_ins_uu()
c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
Expand Down Expand Up @@ -327,8 +327,8 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil
}

_ = c.compileToBitwiseUint64(lt, 2)
_ = c.compileToUint64(rt, 1)
lt = c.compileToBitwiseUint64(lt, 2)
rt = c.compileToUint64(rt, 1)

if i < 0 {
c.asm.BitShiftLeft_uu()
Expand All @@ -337,7 +337,7 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
}

c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
}

func (expr *BitwiseExpr) compile(c *compiler) (ctype, error) {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ func (bv *BindVariable) typeof(env *ExpressionEnv, _ []*querypb.Field) (sqltypes
case sqltypes.Null:
return sqltypes.Null, flagNull | flagNullable
case sqltypes.HexNum, sqltypes.HexVal:
return sqltypes.VarBinary, flagHex
return sqltypes.VarBinary, flagHex | flagNullable
default:
return tt, 0
return tt, flagNullable
}
}

Expand Down
11 changes: 8 additions & 3 deletions go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,13 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {

swapped := false
var skip2 *jump
nullable := true

switch expr.Op.(type) {
case compareNullSafeEQ:
skip2 = c.asm.jumpFrom()
c.asm.Cmp_nullsafe(skip2)
nullable = false
default:
skip2 = c.compileNullCheck1r(rt)
}
Expand Down Expand Up @@ -392,6 +394,9 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {
}

cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}
if nullable {
cmptype.Flag |= nullableFlags(lt.Flag | rt.Flag)
}

switch expr.Op.(type) {
case compareEQ:
Expand Down Expand Up @@ -530,17 +535,17 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) {
}

rhs := expr.Right.(TupleExpr)

var rt ctype
if table := expr.compileTable(lhs, rhs); table != nil {
c.asm.In_table(expr.Negate, table)
} else {
_, err := rhs.compile(c)
rt, err = rhs.compile(c)
if err != nil {
return ctype{}, err
}
c.asm.In_slow(expr.Negate)
}
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil
}

func (l *LikeExpr) matchWildcard(left, right []byte, coll collations.ID) bool {
Expand Down
10 changes: 7 additions & 3 deletions go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (expr *NotExpr) compile(c *compiler) (ctype, error) {
c.asm.Not_i()
}
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil
}

func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) {
Expand Down Expand Up @@ -331,7 +331,7 @@ func (expr *LogicalExpr) compile(c *compiler) (ctype, error) {

c.asm.LogicalRight(expr.opname)
c.asm.jumpDestination(jump)
return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil
return ctype{Type: sqltypes.Int64, Flag: ((lt.Flag | rt.Flag) & flagNullable) | flagIsBoolean, Col: collationNumeric}, nil
}

func intervalCompare(n, val eval) (int, bool, error) {
Expand Down Expand Up @@ -629,7 +629,11 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
}
}

ct := ctype{Type: ta.result(), Col: ca.result()}
var f typeFlag
if ta.nullable {
f |= flagNullable
}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()}
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col)
return ct, nil
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/fn_base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (call *builtinToBase64) compile(c *compiler) (ctype, error) {
c.asm.Fn_TO_BASE64(t, col)
c.asm.jumpDestination(skip)

return ctype{Type: t, Col: col}, nil
return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil
}

func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) {
Expand Down Expand Up @@ -172,5 +172,5 @@ func (call *builtinFromBase64) compile(c *compiler) (ctype, error) {
c.asm.Fn_FROM_BASE64(t)
c.asm.jumpDestination(skip)

return ctype{Type: t, Col: collationBinary}, nil
return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: collationBinary}, nil
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/fn_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ func (expr *builtinBitCount) compile(c *compiler) (ctype, error) {
if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() {
c.asm.BitCount_b()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil
}

_ = c.compileToBitwiseUint64(ct, 1)
c.asm.BitCount_u()
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil
}
Loading

0 comments on commit 9badb73

Please sign in to comment.