Skip to content

Commit

Permalink
[fix](Nereids) move tables from connect context to statement context (#…
Browse files Browse the repository at this point in the history
…44568)

Problem Summary:
When using tables in connect context, it would keep on memory in next
run in the same session, but it should not be in memory when running
next sql statement
  • Loading branch information
LiBinfeng-01 authored Dec 3, 2024
1 parent 4b84de4 commit 80c2b8d
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,9 @@ public static class Lock implements AutoCloseable {
public Lock(LogicalPlan plan, CascadesContext cascadesContext) {
this.cascadesContext = cascadesContext;
// tables can also be load from dump file
if (cascadesContext.tables == null) {
if (cascadesContext.getTables() == null || cascadesContext.getTables().isEmpty()) {
cascadesContext.extractTables(plan);
cascadesContext.getStatementContext().setTables(cascadesContext.getTables());
}
for (TableIf table : cascadesContext.tables.values()) {
if (!table.needReadLockWhenPlan()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ private void setRuntimeFilterWaitTimeByTableRowCountAndType() {

private void initCascadesContext(LogicalPlan plan, PhysicalProperties requireProperties) {
cascadesContext = CascadesContext.initContext(statementContext, plan, requireProperties);
if (statementContext.getConnectContext().getTables() != null) {
cascadesContext.setTables(statementContext.getConnectContext().getTables());
if (statementContext.getTables() != null) {
cascadesContext.setTables(statementContext.getTables());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.doris.datasource.mvcc.MvccSnapshot;
import org.apache.doris.datasource.mvcc.MvccTable;
import org.apache.doris.datasource.mvcc.MvccTableInfo;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.hint.Hint;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.rules.analysis.ColumnAliasGenerator;
Expand All @@ -53,6 +54,7 @@
import org.apache.doris.system.Backend;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
Expand Down Expand Up @@ -150,6 +152,9 @@ public class StatementContext implements Closeable {
// placeholder params for prepared statement
private List<Placeholder> placeholders;

// tables used for plan replayer
private Map<List<String>, TableIf> tables = null;

// for create view support in nereids
// key is the start and end position of the sql substring that needs to be replaced,
// and value is the new string used for replacement.
Expand Down Expand Up @@ -213,6 +218,30 @@ public StatementContext(ConnectContext connectContext, OriginStatement originSta
}
}

public Map<List<String>, TableIf> getTables() {
if (tables == null) {
tables = Maps.newHashMap();
}
return tables;
}

public void setTables(Map<List<String>, TableIf> tables) {
this.tables = tables;
}

/** get table by table name, try to get from information from dumpfile first */
public TableIf getTableInMinidumpCache(List<String> tableQualifier) {
if (!getConnectContext().getSessionVariable().isPlayNereidsDump()) {
return null;
}
Preconditions.checkState(tables != null, "tables should not be null");
TableIf table = tables.getOrDefault(tableQualifier, null);
if (getConnectContext().getSessionVariable().isPlayNereidsDump() && table == null) {
throw new AnalysisException("Minidump cache can not find table:" + tableQualifier);
}
return table;
}

public void setConnectContext(ConnectContext connectContext) {
this.connectContext = connectContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public static void main(String[] args) {

StatementContext statementContext = new StatementContext(ConnectContext.get(),
new OriginStatement(minidump.getSql(), 0));
statementContext.setTables(minidump.getTables());
ConnectContext.get().setStatementContext(statementContext);
JSONObject resultPlan = MinidumpUtils.executeSql(minidump.getSql());
JSONObject minidumpResult = new JSONObject(minidump.getResultPlanJson());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ public static void setConnectContext(Minidump minidump) {
connectContext.setThreadLocalInfo();
Env.getCurrentEnv().setColocateTableIndex(minidump.getColocateTableIndex());
connectContext.setSessionVariable(minidump.getSessionVariable());
connectContext.setTables(minidump.getTables());
connectContext.setDatabase(minidump.getDbName());
connectContext.getSessionVariable().setPlanNereidsDump(true);
connectContext.getSessionVariable().enableNereidsTimeout = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, UnboundRe
List<String> tableQualifier = RelationUtil.getQualifierName(cascadesContext.getConnectContext(),
unboundRelation.getNameParts());
TableIf table = null;
table = ConnectContext.get().getTableInMinidumpCache(tableQualifier);
table = ConnectContext.get().getStatementContext().getTableInMinidumpCache(tableQualifier);
if (table == null) {
if (customTableResolver.isPresent()) {
table = customTableResolver.get().apply(tableQualifier);
Expand All @@ -182,7 +182,7 @@ private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, UnboundRe
if (table == null) {
table = RelationUtil.getTable(tableQualifier, cascadesContext.getConnectContext().getEnv());
}
ConnectContext.get().getTables().put(tableQualifier, table);
ConnectContext.get().getStatementContext().getTables().put(tableQualifier, table);

// TODO: should generate different Scan sub class according to table's type
LogicalPlan scan = getLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
Expand All @@ -201,13 +201,13 @@ private LogicalPlan bind(CascadesContext cascadesContext, UnboundRelation unboun
if (customTableResolver.isPresent()) {
table = customTableResolver.get().apply(tableQualifier);
}
table = ConnectContext.get().getTableInMinidumpCache(tableQualifier);
table = ConnectContext.get().getStatementContext().getTableInMinidumpCache(tableQualifier);
// In some cases even if we have already called the "cascadesContext.getTableByName",
// it also gets the null. So, we just check it in the catalog again for safety.
if (table == null) {
table = RelationUtil.getTable(tableQualifier, cascadesContext.getConnectContext().getEnv());
}
ConnectContext.get().getTables().put(tableQualifier, table);
ConnectContext.get().getStatementContext().getTables().put(tableQualifier, table);
return getLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private void handleLoad() throws Exception {
// 3. run nereids planner with sql in minidump file
StatementContext statementContext = new StatementContext(ConnectContext.get(),
new OriginStatement(minidump.getSql(), 0));
statementContext.setTables(minidump.getTables());
ConnectContext.get().setStatementContext(statementContext);
JSONObject resultPlan = MinidumpUtils.executeSql(minidump.getSql());
JSONObject minidumpResult = new JSONObject(minidump.getResultPlanJson());
Expand Down
29 changes: 0 additions & 29 deletions fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.doris.catalog.DatabaseIf;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.catalog.Type;
import org.apache.doris.cloud.qe.ComputeGroupException;
import org.apache.doris.cloud.system.CloudSystemInfoService;
Expand All @@ -54,7 +53,6 @@
import org.apache.doris.mysql.ProxyMysqlChannel;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.stats.StatsErrorEstimator;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.plsql.Exec;
Expand All @@ -73,7 +71,6 @@
import org.apache.doris.transaction.TransactionEntry;
import org.apache.doris.transaction.TransactionStatus;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
Expand Down Expand Up @@ -267,8 +264,6 @@ public void setUserInsertTimeout(int insertTimeout) {
// new planner
private Map<String, PreparedStatementContext> preparedStatementContextMap = Maps.newHashMap();

private Map<List<String>, TableIf> tables = null;

private Map<String, ColumnStatistic> totalColumnStatisticMap = new HashMap<>();

public Map<String, ColumnStatistic> getTotalColumnStatisticMap() {
Expand Down Expand Up @@ -433,30 +428,6 @@ public PreparedStatementContext getPreparedStementContext(String stmtName) {
return this.preparedStatementContextMap.get(stmtName);
}

public Map<List<String>, TableIf> getTables() {
if (tables == null) {
tables = Maps.newHashMap();
}
return tables;
}

public void setTables(Map<List<String>, TableIf> tables) {
this.tables = tables;
}

/** get table by table name, try to get from information from dumpfile first */
public TableIf getTableInMinidumpCache(List<String> tableQualifier) {
if (!getSessionVariable().isPlayNereidsDump()) {
return null;
}
Preconditions.checkState(tables != null, "tables should not be null");
TableIf table = tables.getOrDefault(tableQualifier, null);
if (getSessionVariable().isPlayNereidsDump() && table == null) {
throw new AnalysisException("Minidump cache can not find table:" + tableQualifier);
}
return table;
}

public void closeTxn() {
if (isTxnModel()) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.nereids.util;

import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
Expand Down Expand Up @@ -48,10 +47,8 @@ public void testSimple() {
parser.parseSingle(sql),
PhysicalProperties.ANY
);
CascadesContext cascadesContext = planner.getCascadesContext();

Map<List<String>, TableIf> f = cascadesContext.getTables();
Assertions.assertEquals(2, f.size());
Map<List<String>, TableIf> f = statementContext.getTables();
Assertions.assertEquals(1, f.size());
Set<String> tableNames = new HashSet<>();
for (Map.Entry<List<String>, TableIf> entry : f.entrySet()) {
TableIf table = entry.getValue();
Expand All @@ -75,8 +72,7 @@ public void testCTE() {
parser.parseSingle(sql),
PhysicalProperties.ANY
);
CascadesContext cascadesContext = planner.getCascadesContext();
Map<List<String>, TableIf> f = cascadesContext.getTables();
Map<List<String>, TableIf> f = statementContext.getTables();
Assertions.assertEquals(1, f.size());
for (Map.Entry<List<String>, TableIf> entry : f.entrySet()) {
TableIf table = entry.getValue();
Expand All @@ -93,8 +89,7 @@ public void testSubQuery() {
parser.parseSingle(sql),
PhysicalProperties.ANY
);
CascadesContext cascadesContext = planner.getCascadesContext();
Map<List<String>, TableIf> f = cascadesContext.getTables();
Map<List<String>, TableIf> f = statementContext.getTables();
Assertions.assertEquals(1, f.size());
for (Map.Entry<List<String>, TableIf> entry : f.entrySet()) {
TableIf table = entry.getValue();
Expand All @@ -111,8 +106,7 @@ public void testScalarSubQuery() {
parser.parseSingle(sql),
PhysicalProperties.ANY
);
CascadesContext cascadesContext = planner.getCascadesContext();
Map<List<String>, TableIf> f = cascadesContext.getTables();
Map<List<String>, TableIf> f = statementContext.getTables();
Assertions.assertEquals(2, f.size());
Set<String> tableNames = new HashSet<>();
for (Map.Entry<List<String>, TableIf> entry : f.entrySet()) {
Expand All @@ -134,15 +128,14 @@ public void testInserInto() {
(LogicalPlan) insertIntoTableCommand.getExplainPlan(connectContext),
PhysicalProperties.ANY
);
CascadesContext cascadesContext = planner.getCascadesContext();
Map<List<String>, TableIf> f = cascadesContext.getTables();
Assertions.assertEquals(2, f.size());
Map<List<String>, TableIf> f = statementContext.getTables();
// when table in insert would not be added to statement context, but be lock when insert
Assertions.assertEquals(1, f.size());
Set<String> tableNames = new HashSet<>();
for (Map.Entry<List<String>, TableIf> entry : f.entrySet()) {
TableIf table = entry.getValue();
tableNames.add(table.getName());
}
Assertions.assertTrue(tableNames.contains("supplier"));
Assertions.assertTrue(tableNames.contains("lineorder"));
}
}

0 comments on commit 80c2b8d

Please sign in to comment.