Skip to content

Commit

Permalink
Support integral cast projection pushdown in redshift
Browse files Browse the repository at this point in the history
  • Loading branch information
krvikash authored and Praveen2112 committed Oct 16, 2024
1 parent f93e105 commit e4607a3
Show file tree
Hide file tree
Showing 6 changed files with 723 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
*/
package io.trino.plugin.jdbc;

import io.trino.Session;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.sql.SqlExecutor;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -69,10 +72,29 @@ public void testJoinPushdownWithCast()
@Test
public void testInvalidCast()
{
for (InvalidCastTestCase testCase : invalidCast()) {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
.failure()
.hasMessageMatching(testCase.errorMessage());
assertInvalidCast(leftTable(), invalidCast());
}

protected void assertInvalidCast(String tableName, List<InvalidCastTestCase> invalidCastTestCases)
{
Session withoutPushdown = Session.builder(getSession())
.setSystemProperty("allow_pushdown_into_connectors", "false")
.build();

for (InvalidCastTestCase testCase : invalidCastTestCases) {
if (testCase.pushdownErrorMessage().isPresent()) {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), tableName)))
.failure()
.hasMessageMatching(testCase.pushdownErrorMessage().get());
assertThat(query(withoutPushdown, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), tableName)))
.failure()
.hasMessageMatching(testCase.errorMessage());
}
else {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), tableName)))
.failure()
.hasMessageMatching(testCase.errorMessage());
}
}
}

Expand All @@ -86,18 +108,29 @@ public record CastTestCase(String sourceColumn, String castType, String targetCo
}
}

public record InvalidCastTestCase(String sourceColumn, String castType, String errorMessage)
public record InvalidCastTestCase(String sourceColumn, String castType, String errorMessage, Optional<String> pushdownErrorMessage)
{
public InvalidCastTestCase(String sourceColumn, String castType)
{
this(sourceColumn, castType, "(.*)Cannot cast (.*) to (.*)");
}

public InvalidCastTestCase(String sourceColumn, String castType, String errorMessage)
{
this(sourceColumn, castType, errorMessage, Optional.empty());
}

public InvalidCastTestCase(String sourceColumn, String castType, String errorMessage, @Language("RegExp") String pushdownErrorMessage)
{
this(sourceColumn, castType, errorMessage, Optional.of(pushdownErrorMessage));
}

public InvalidCastTestCase
{
requireNonNull(sourceColumn, "sourceColumn is null");
requireNonNull(castType, "castType is null");
requireNonNull(errorMessage, "errorMessage is null");
requireNonNull(pushdownErrorMessage, "pushdownErrorMessage is null");
}
}
}
2 changes: 2 additions & 0 deletions plugin/trino-redshift/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
<configuration>
<excludes>
<exclude>**/TestRedshiftAutomaticJoinPushdown.java</exclude>
<exclude>**/TestRedshiftCastPushdown.java</exclude>
<exclude>**/TestRedshiftConnectorTest.java</exclude>
<exclude>**/TestRedshiftConnectorSmokeTest.java</exclude>
<exclude>**/TestRedshiftTableStatisticsReader.java</exclude>
Expand Down Expand Up @@ -262,6 +263,7 @@
<!-- Run only the smoke tests of the connector on the CI environment due to unpredictable -->
<!-- locations of GitHub runners which can lead to increased client latency on the -->
<!-- JDBC operations performed on the ephemeral AWS Redshift cluster. -->
<include>**/TestRedshiftCastPushdown.java</include>
<include>**/TestRedshiftConnectorSmokeTest.java</include>
</includes>
</configuration>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.base.projection.ProjectFunctionRewriter;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
Expand Down Expand Up @@ -222,6 +224,7 @@ public class RedshiftClient
.toFormatter();
private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC);

private final ProjectFunctionRewriter<JdbcExpression, ParameterizedExpression> projectFunctionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
private final boolean statisticsEnabled;
private final RedshiftTableStatisticsReader statisticsReader;
Expand All @@ -248,6 +251,12 @@ public RedshiftClient(
.map("$greater_than_or_equal(left, right)").to("left >= right")
.build();

this.projectFunctionRewriter = new ProjectFunctionRewriter<>(
connectorExpressionRewriter,
ImmutableSet.<ProjectFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
.add(new RewriteCast((session, type) -> toWriteMapping(session, type).getDataType()))
.build());

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());

aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
Expand Down Expand Up @@ -359,6 +368,12 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public Optional<JdbcExpression> convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return projectFunctionRewriter.rewrite(session, handle, expression, assignments);
}

@Override
public Optional<ParameterizedExpression> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.redshift;

import com.google.common.collect.ImmutableList;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.AbstractRewriteCast;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;

import static java.sql.Types.BIGINT;
import static java.sql.Types.BIT;
import static java.sql.Types.INTEGER;
import static java.sql.Types.NUMERIC;
import static java.sql.Types.SMALLINT;

public class RewriteCast
extends AbstractRewriteCast
{
private static final List<Integer> SUPPORTED_SOURCE_TYPE_FOR_INTEGRAL_CAST = ImmutableList.of(BIT, SMALLINT, INTEGER, BIGINT, NUMERIC);

public RewriteCast(BiFunction<ConnectorSession, Type, String> jdbcTypeProvider)
{
super(jdbcTypeProvider);
}

@Override
protected Optional<JdbcTypeHandle> toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType)
{
if (!pushdownSupported(sourceType, targetType)) {
return Optional.empty();
}

return switch (targetType) {
case SmallintType smallintType ->
Optional.of(new JdbcTypeHandle(SMALLINT, Optional.of(smallintType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
case IntegerType integerType ->
Optional.of(new JdbcTypeHandle(INTEGER, Optional.of(integerType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
case BigintType bigintType ->
Optional.of(new JdbcTypeHandle(BIGINT, Optional.of(bigintType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
default -> Optional.empty();
};
}

private boolean pushdownSupported(JdbcTypeHandle sourceType, Type targetType)
{
return switch (targetType) {
case SmallintType _, IntegerType _, BigintType _ ->
SUPPORTED_SOURCE_TYPE_FOR_INTEGRAL_CAST.contains(sourceType.jdbcType());
default -> false;
};
}

@Override
protected String buildCast(Type sourceType, Type targetType, String expression, String castType)
{
if (sourceType instanceof DecimalType && isIntegralType(targetType)) {
// Trino rounds up to nearest integral value, whereas Redshift does not.
// So using ROUND() to make pushdown same as the trino behavior
return "CAST(ROUND(%s) AS %s)".formatted(expression, castType);
}
return "CAST(%s AS %s)".formatted(expression, castType);
}

private boolean isIntegralType(Type type)
{
return type instanceof SmallintType
|| type instanceof IntegerType
|| type instanceof BigintType;
}
}
Loading

0 comments on commit e4607a3

Please sign in to comment.