Skip to content

Commit

Permalink
Allow left side as update target in join pushdown
Browse files Browse the repository at this point in the history
Allow left side as update target when try to pushdown join
into table scan. The change prevent the pushdown join into the
table scan instead of throwing exception

Co-Authored-By: Łukasz Osipiuk <[email protected]>
  • Loading branch information
chenjian2664 and losipiuk committed Dec 18, 2024
1 parent 1c5e621 commit 8205581
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
Expand Down Expand Up @@ -103,7 +102,9 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
TableScanNode left = captures.get(LEFT_TABLE_SCAN);
TableScanNode right = captures.get(RIGHT_TABLE_SCAN);

verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");
if (left.isUpdateTarget() && !right.isUpdateTarget()) {
return Result.empty();
}

Expression effectiveFilter = getEffectiveFilter(joinNode);
ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,20 @@ public void testUpdateRowConcurrently()
abort("Kudu doesn't support concurrent update of different columns in a row");
}

@Test
@Override
protected void testUpdateWithSubquery()
{
withTableName("test_update_with_subquery", tableName -> {
createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty());
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000);

assertQuery("SELECT count(*) FROM " + tableName + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 0");
assertUpdate("UPDATE " + tableName + " SET shippriority = 101 WHERE custkey = (SELECT min(custkey) FROM customer)", 9);
assertQuery("SELECT count(*) FROM " + tableName + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 9");
});
}

@Test
@Override
public void testCreateTableWithTableComment()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assumptions.abort;

public abstract class BaseMariaDbFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
Expand All @@ -55,14 +52,6 @@ protected QueryRunner createQueryRunner(List<TpchTable<?>> requiredTpchTables, M
.build();
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan");
abort("skipped");
}

@Test
@Override
protected void testUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assumptions.abort;

public abstract class BaseMySqlFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
Expand Down Expand Up @@ -58,14 +55,6 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan");
abort("skipped");
}

@Test
@Override
protected void testUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assumptions.abort;

public abstract class BaseOracleFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
Expand Down Expand Up @@ -59,14 +56,6 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan");
abort("skipped");
}

@Test
@Override
protected void testUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ protected void testDeleteWithSubquery()
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan");
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled");
abort("skipped");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assumptions.abort;

public abstract class BaseRedshiftFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
Expand Down Expand Up @@ -58,14 +55,6 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan");
abort("skipped");
}

@Test
@Override
protected void testUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assumptions.abort;

public abstract class BaseSqlServerFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
Expand Down Expand Up @@ -58,14 +55,6 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Unexpected Join over for-update table scan");
abort("skipped");
}

@Test
@Override
protected void testUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6801,6 +6801,18 @@ public void testMergeAllColumnsReversed()
assertUpdate("DROP TABLE " + targetTable);
}

@Test
protected void testUpdateWithSubquery()
{
skipTestUnless(hasBehavior(SUPPORTS_MERGE));

try (TestTable table = createTestTableForWrites("test_update_with_subquery", " AS SELECT * FROM orders", "orderkey")) {
assertQuery("SELECT count(*) FROM " + table.getName() + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 0");
assertUpdate("UPDATE " + table.getName() + " SET shippriority = 101 WHERE custkey = (SELECT min(custkey) FROM customer)", 9);
assertQuery("SELECT count(*) FROM " + table.getName() + " WHERE shippriority = 101 AND custkey = (SELECT min(custkey) FROM customer)", "VALUES 9");
}
}

private void verifyUnsupportedTypeException(Throwable exception, String trinoTypeName)
{
String typeNameBase = trinoTypeName.replaceFirst("\\(.*", "");
Expand Down

0 comments on commit 8205581

Please sign in to comment.