From 82055816d6420465851f8a2ca8d3722ce161379d Mon Sep 17 00:00:00 2001 From: chenjian2664 Date: Wed, 11 Dec 2024 22:14:47 +0800 Subject: [PATCH] Allow left side as update target in join pushdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow left side as update target when try to pushdown join into table scan. The change prevent the pushdown join into the table scan instead of throwing exception Co-Authored-By: Ɓukasz Osipiuk --- .../iterative/rule/PushJoinIntoTableScan.java | 5 +++-- .../trino/plugin/kudu/TestKuduConnectorTest.java | 14 ++++++++++++++ .../mariadb/BaseMariaDbFailureRecoveryTest.java | 11 ----------- .../plugin/mysql/BaseMySqlFailureRecoveryTest.java | 11 ----------- .../oracle/BaseOracleFailureRecoveryTest.java | 11 ----------- .../BasePostgresFailureRecoveryTest.java | 2 +- .../redshift/BaseRedshiftFailureRecoveryTest.java | 11 ----------- .../BaseSqlServerFailureRecoveryTest.java | 11 ----------- .../java/io/trino/testing/BaseConnectorTest.java | 12 ++++++++++++ 9 files changed, 30 insertions(+), 58 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java index e6a71664f953..37f0054ae67d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java @@ -46,7 +46,6 @@ import java.util.Map; import java.util.Optional; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors; @@ -103,7 +102,9 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) TableScanNode left = captures.get(LEFT_TABLE_SCAN); TableScanNode right = captures.get(RIGHT_TABLE_SCAN); - verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan"); + if (left.isUpdateTarget() && !right.isUpdateTarget()) { + return Result.empty(); + } Expression effectiveFilter = getEffectiveFilter(joinNode); ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts( diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java index 73de900f28ba..6150f160a19d 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java @@ -995,6 +995,20 @@ public void testUpdateRowConcurrently() abort("Kudu doesn't support concurrent update of different columns in a row"); } + @Test + @Override + protected void testUpdateWithSubquery() + { + withTableName("test_update_with_subquery", tableName -> { + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); + + assertQuery("SELECT count(*) FROM " + tableName + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 0"); + assertUpdate("UPDATE " + tableName + " SET shippriority = 101 WHERE custkey = (SELECT min(custkey) FROM customer)", 9); + assertQuery("SELECT count(*) FROM " + tableName + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 9"); + }); + } + @Test @Override public void testCreateTableWithTableComment() diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbFailureRecoveryTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbFailureRecoveryTest.java index e8eb5128ba7d..dcde5b95b11d 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbFailureRecoveryTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbFailureRecoveryTest.java @@ -26,9 +26,6 @@ import java.util.Map; import java.util.Optional; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assumptions.abort; - public abstract class BaseMariaDbFailureRecoveryTest extends BaseJdbcFailureRecoveryTest { @@ -55,14 +52,6 @@ protected QueryRunner createQueryRunner(List> requiredTpchTables, M .build(); } - @Test - @Override - protected void testUpdateWithSubquery() - { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan"); - abort("skipped"); - } - @Test @Override protected void testUpdate() diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlFailureRecoveryTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlFailureRecoveryTest.java index 1b474e83f11c..494e04deead5 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlFailureRecoveryTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlFailureRecoveryTest.java @@ -26,9 +26,6 @@ import java.util.Map; import java.util.Optional; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assumptions.abort; - public abstract class BaseMySqlFailureRecoveryTest extends BaseJdbcFailureRecoveryTest { @@ -58,14 +55,6 @@ protected QueryRunner createQueryRunner( .build(); } - @Test - @Override - protected void testUpdateWithSubquery() - { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan"); - abort("skipped"); - } - @Test @Override protected void testUpdate() diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleFailureRecoveryTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleFailureRecoveryTest.java index 6836fcce8859..490dff0c94eb 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleFailureRecoveryTest.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleFailureRecoveryTest.java @@ -26,9 +26,6 @@ import java.util.Map; import java.util.Optional; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assumptions.abort; - public abstract class BaseOracleFailureRecoveryTest extends BaseJdbcFailureRecoveryTest { @@ -59,14 +56,6 @@ protected QueryRunner createQueryRunner( .build(); } - @Test - @Override - protected void testUpdateWithSubquery() - { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan"); - abort("skipped"); - } - @Test @Override protected void testUpdate() diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java index c5f49addd7db..53b2211c17d4 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java @@ -74,7 +74,7 @@ protected void testDeleteWithSubquery() @Override protected void testUpdateWithSubquery() { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan"); + assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled"); abort("skipped"); } diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/BaseRedshiftFailureRecoveryTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/BaseRedshiftFailureRecoveryTest.java index e38fa4cf3d78..2856863c5b3a 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/BaseRedshiftFailureRecoveryTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/BaseRedshiftFailureRecoveryTest.java @@ -26,9 +26,6 @@ import java.util.Map; import java.util.Optional; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assumptions.abort; - public abstract class BaseRedshiftFailureRecoveryTest extends BaseJdbcFailureRecoveryTest { @@ -58,14 +55,6 @@ protected QueryRunner createQueryRunner( .build(); } - @Test - @Override - protected void testUpdateWithSubquery() - { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan"); - abort("skipped"); - } - @Test @Override protected void testUpdate() diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerFailureRecoveryTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerFailureRecoveryTest.java index 4dbce2a01c03..ac8d2967479a 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerFailureRecoveryTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerFailureRecoveryTest.java @@ -26,9 +26,6 @@ import java.util.Map; import java.util.Optional; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assumptions.abort; - public abstract class BaseSqlServerFailureRecoveryTest extends BaseJdbcFailureRecoveryTest { @@ -58,14 +55,6 @@ protected QueryRunner createQueryRunner( .build(); } - @Test - @Override - protected void testUpdateWithSubquery() - { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan"); - abort("skipped"); - } - @Test @Override protected void testUpdate() diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index e53c808fa9bf..c9fbbc3bffa3 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -6801,6 +6801,18 @@ public void testMergeAllColumnsReversed() assertUpdate("DROP TABLE " + targetTable); } + @Test + protected void testUpdateWithSubquery() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + try (TestTable table = createTestTableForWrites("test_update_with_subquery", " AS SELECT * FROM orders", "orderkey")) { + assertQuery("SELECT count(*) FROM " + table.getName() + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 0"); + assertUpdate("UPDATE " + table.getName() + " SET shippriority = 101 WHERE custkey = (SELECT min(custkey) FROM customer)", 9); + assertQuery("SELECT count(*) FROM " + table.getName() + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 9"); + } + } + private void verifyUnsupportedTypeException(Throwable exception, String trinoTypeName) { String typeNameBase = trinoTypeName.replaceFirst("\\(.*", "");