Skip to content

Commit

Permalink
Add test for aggregation index join (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
joechenrh authored Aug 2, 2024
1 parent 4f689ce commit 7ca0e0e
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 242 deletions.
21 changes: 15 additions & 6 deletions framework/mainloop.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,16 @@ func (c *DDLCase) statloop() {
for {
select {
case <-tick.C:
subcaseStat := make([]string, 20)
subcaseUseMvindex := make([]string, 20)
subcaseUseCERT := make([]string, 20)
subcaseStat := make([]string, len(c.cases))
subcaseUseMvindex := make([]string, len(c.cases))
subcaseUseCERT := make([]string, len(c.cases))
subcaseUseAggIndexJoin := make([]string, len(c.cases))
for _, c := range c.cases {
subcaseStat = append(subcaseStat, fmt.Sprintf("%d", len(c.queryPlanMap)))
subcaseUseMvindex = append(subcaseUseMvindex, fmt.Sprintf("%d", c.planUseMvIndex))
subcaseUseCERT = append(subcaseUseCERT, fmt.Sprintf("%d", c.checkCERTCnt))
subcaseUseAggIndexJoin = append(subcaseUseAggIndexJoin, fmt.Sprintf("%d", c.aggregationAsInnerSideOfIndexJoin))

//i := 0
//for k, v := range c.queryPlanMap {
// logutil.BgLogger().Warn("sample query plan", zap.String("plan", k), zap.String("query", v))
Expand All @@ -102,8 +105,14 @@ func (c *DDLCase) statloop() {
//}
}

logutil.BgLogger().Info("stat", zap.Int64("run query:", globalRunQueryCnt.Load()), zap.Int64("success:", globalSuccessQueryCnt.Load()), zap.Int64("fetch json row val:", sqlgenerator.GlobalFetchJsonRowValCnt.Load()),
zap.Strings("unique query plan", subcaseStat), zap.Strings("use mv index", subcaseUseMvindex), zap.Strings("use CERT", subcaseUseCERT))
logutil.BgLogger().Info("stat", zap.Int64("run query:", globalRunQueryCnt.Load()),
zap.Int64("success:", globalSuccessQueryCnt.Load()),
zap.Int64("fetch json row val:", sqlgenerator.GlobalFetchJsonRowValCnt.Load()),
zap.Strings("unique query plan", subcaseStat),
zap.Strings("use mv index", subcaseUseMvindex),
zap.Strings("use CERT", subcaseUseCERT),
zap.Strings("use agg index join", subcaseUseAggIndexJoin),
)
}
}
}
Expand Down Expand Up @@ -594,7 +603,7 @@ func (c *testCase) execute(ctx context.Context) error {
}
log.Infof("tableMetas %d", len(tableMetas))
state.SetTableMeta(tableMetas)
state.PrepareIndexJoinColumns()
sqlgenerator.PrepareIndexJoinColumns(state)

cnt := 0

Expand Down
2 changes: 0 additions & 2 deletions framework/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ func Run(dbAddr string, dbName string, concurrency int, tablesToCreate int, mysq
if err := ddl.Execute(ctx, dbss); err != nil {
log.Fatalf("[ddl] execute error %v", err)
}
// Enable index join on aggregation
globalDbs.Exec("set GLOBAL tidb_enable_inl_join_inner_multi_pattern='ON'")
}

var dmlIgnoreList = []string{
Expand Down
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func prepareEnv() {
}
}
}
// Enable index join on aggregation
tidbC.ExecContext(context.Background(), "set GLOBAL tidb_enable_inl_join_inner_multi_pattern='ON'")
tidbC.Close()

mysql.SetLogger(log.Logger())
Expand Down
251 changes: 251 additions & 0 deletions sqlgenerator/aggregate_index_join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package sqlgenerator

import (
"fmt"
"math/rand"
)

type JoinColumn struct {
outerTable *Table
innerTable *Table
outerColumns []*Column
innerColumns []*Column
}

var allJoinColumns map[*State][]*JoinColumn

func init() {
allJoinColumns = make(map[*State][]*JoinColumn)
}

var indexJoinType = map[ColumnType][]ColumnType{
ColumnTypeBoolean: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
ColumnTypeTinyInt: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
ColumnTypeSmallInt: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
ColumnTypeMediumInt: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
ColumnTypeInt: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
ColumnTypeBigInt: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
ColumnTypeFloat: {ColumnTypeFloat, ColumnTypeDouble},
ColumnTypeDouble: {ColumnTypeFloat, ColumnTypeDouble},
ColumnTypeDecimal: {ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal},
ColumnTypeChar: {
ColumnTypeFloat, ColumnTypeDouble, ColumnTypeChar, ColumnTypeVarchar,
ColumnTypeDate, ColumnTypeDatetime, ColumnTypeTimestamp},
ColumnTypeVarchar: {
ColumnTypeFloat, ColumnTypeDouble, ColumnTypeChar, ColumnTypeVarchar,
ColumnTypeDate, ColumnTypeDatetime, ColumnTypeTimestamp},
ColumnTypeDate: {ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDate, ColumnTypeDatetime, ColumnTypeTimestamp},
ColumnTypeTime: {
ColumnTypeFloat, ColumnTypeDouble, ColumnTypeChar, ColumnTypeVarchar,
ColumnTypeDate, ColumnTypeTime, ColumnTypeDatetime, ColumnTypeTimestamp},
ColumnTypeDatetime: {ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDate, ColumnTypeDatetime, ColumnTypeTimestamp},
ColumnTypeTimestamp: {ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDate, ColumnTypeDatetime, ColumnTypeTimestamp},
ColumnTypeYear: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeDate,
ColumnTypeDatetime, ColumnTypeTimestamp, ColumnTypeYear, ColumnTypeBit},
ColumnTypeBit: {
ColumnTypeBoolean, ColumnTypeTinyInt, ColumnTypeSmallInt, ColumnTypeMediumInt, ColumnTypeInt,
ColumnTypeBigInt, ColumnTypeFloat, ColumnTypeDouble, ColumnTypeDecimal, ColumnTypeYear, ColumnTypeBit},
}

func PrepareIndexJoinColumns(s *State) {
var CheckFunc = func(left, right ColumnType) bool {
if matches, ok := indexJoinType[left]; ok {
for _, match := range matches {
if right == match {
return true
}
}
}
return false
}

joinColumns := make([]*JoinColumn, 0)

// Enumerate possible join columns
for i := 0; i < len(s.Tables); i++ {
for j := 0; j < len(s.Tables); j++ {
outerTable := s.Tables[i]
innerTable := s.Tables[j]

var comb = JoinColumn{
outerTable: outerTable,
innerTable: innerTable,
outerColumns: make([]*Column, 0),
innerColumns: make([]*Column, 0),
}

for _, outerCol := range outerTable.Columns {
for _, innerIndex := range innerTable.Indexes {
if CheckFunc(outerCol.Tp, innerIndex.Columns[0].Tp) {
comb.outerColumns = append(comb.outerColumns, outerCol)
comb.innerColumns = append(comb.innerColumns, innerIndex.Columns[0])
}
}
}

if len(comb.innerColumns) > 0 {
joinColumns = append(joinColumns, &comb)
}
}
}

allJoinColumns[s] = joinColumns
}

func RandJoinColumn(s *State) (*Table, *Table, *Column, *Column) {
joinColumns := allJoinColumns[s]

totalNum := 0
for _, joinColumn := range joinColumns {
totalNum += len(joinColumn.innerColumns)
}

if totalNum == 0 {
return nil, nil, nil, nil
}

idx := rand.Intn(totalNum)
for _, joinColumn := range joinColumns {
if idx < len(joinColumn.innerColumns) {
return joinColumn.outerTable, joinColumn.innerTable,
joinColumn.outerColumns[idx], joinColumn.innerColumns[idx]
}
idx -= len(joinColumn.innerColumns)
}

return nil, nil, nil, nil
}

var MultiSelectWithIndexJoin = NewFn(func(state *State) Fn {
tbl1, tbl2, col1, col2 := RandJoinColumn(state)
if tbl1 == nil {
return NoneBecauseOf(fmt.Errorf("not initialized"))
}

// Generate subquery
state.IncSubQueryDeep()
st := state.GenSubQuery()
state.env.QState = &QueryState{
SelectedCols: map[*Table]QueryStateColumns{
tbl2: {
Columns: tbl2.Columns,
Attr: make([]string, len(tbl2.Columns)),
},
},
AggCols: make(map[*Table]Columns),
}
state.env.QState.SelectedCols[tbl2].Attr[col2.Idx] = ChosenSelection

def, err := SimpleAggSelect.Eval(state)
if err != nil {
return NoneBecauseOf(err)
}
var cts []ColumnType
cts, err = getTypeOfExpressions(def, "test", state.tableMeta)
if err != nil {
return NoneBecauseOf(err)
}
for _, t := range cts {
st.AppendColumn(state.GenNewColumnWithType(t))
}
// Reset column name
for i, c := range st.Columns {
c.Name = fmt.Sprintf("r%d", i)
}
state.PushSubQuery(st)

tbl2Str := fmt.Sprintf("(%s) %s", def, st.Name)
sq := state.PopSubQuery()
sq[0].SubQueryDef = tbl2Str
state.env.QState = &QueryState{
SelectedCols: map[*Table]QueryStateColumns{
tbl1: {
Columns: tbl1.Columns,
Attr: make([]string, len(tbl1.Columns)),
},
sq[0]: {
Columns: sq[0].Columns,
Attr: make([]string, len(sq[0].Columns)),
},
},
AggCols: make(map[*Table]Columns),
}

joinHint := Str(fmt.Sprintf("/*+ inl_join(%s) */", sq[0].Name))
joinPredicate := Str(
fmt.Sprintf("on %s.%s = %s.%s",
tbl1.Name, col1.Name, st.Name, sq[0].Columns[0].Name))

tblNames := []Fn{Str(tbl1.Name), Str(tbl2Str)}
join := Join(tblNames, Or(Str("left join"), Str("inner join")))

return And(
Str("select"), joinHint, SelectFields,
Str("from"), join, joinPredicate, Opt(OrderBy), Opt(Limit),
)
})

var SimpleAggSelect = NewFn(func(state *State) Fn {
state.env.QState.IsAgg = true
tbl := state.env.QState.GetRandTable()

groupByColsCnt := rand.Intn(3)
groupByCols := tbl.Columns.RandGiveN(groupByColsCnt)
for i, attr := range state.env.QState.SelectedCols[tbl].Attr {
if attr == ChosenSelection {
groupByCols = append([]*Column{tbl.Columns[i]}, groupByCols...)
}
}
state.env.QState.AggCols[tbl] = groupByCols

return And(
Str("select"), Opt(HintAggToCop), SimpleSelectFields, Str("from"),
TableReference, WhereClause, GroupByColumns,
)
})

var SimpleSelectFields = NewFn(func(state *State) Fn {
queryState := state.env.QState
queryState.FieldNumHint = 2 + rand.Intn(4)

tbl := queryState.GetRandTable()

var fns []Fn

// We need at least one column for join and one aggregation function
fns = append(fns, NewFn(func(state *State) Fn {
state.env.Table = tbl
return Str(fmt.Sprintf("%s.%s as r0", tbl.Name, state.env.QState.AggCols[tbl][0].Name))
}))
fns = append(fns, Str(","))
fns = append(fns, NewFn(func(state *State) Fn {
state.env.Table = tbl
state.env.QColumns = queryState.SelectedCols[state.env.Table]
return And(AggFunction, Str("as r1"))
}))

for i := 2; i < queryState.FieldNumHint; i++ {
fieldID := fmt.Sprintf("r%d", i)
fns = append(fns, Str(","))
fns = append(fns, NewFn(func(state *State) Fn {
state.env.Table = tbl
state.env.QColumns = queryState.SelectedCols[state.env.Table]
return And(Or(SelectFieldName, AggFunction), Str("as"), Str(fieldID))
}))
}
return And(fns...)
})
Loading

0 comments on commit 7ca0e0e

Please sign in to comment.