Skip to content

Commit

Permalink
[opt](Nereids) let lead and lag type coersion behavior same with MySQL
Browse files Browse the repository at this point in the history
For input and default are numeric types, use below priority
- DecimalV3
- DecimalV2
- Double
- Float
- LargeInt
- BigInt
- Int
- SmallInt
- TinyInt

For input and default are Date or DateTime types, user below priority
- DateTimeV2
- DateTime
- DateV2
- Date

other wise, use String
  • Loading branch information
morrySnow committed Nov 26, 2024
1 parent 14d928b commit 68c34c4
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 223 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,17 @@

package org.apache.doris.nereids.trees.expressions.functions.window;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.List;

/** Window function: lag */
public class Lag extends WindowFunction implements TernaryExpression, ExplicitlyCastableSignature,
RequireTrivialTypes {

static {
List<FunctionSignature> signatures = Lists.newArrayList();
trivialTypes.forEach(t ->
signatures.add(FunctionSignature.ret(t).args(t, BigIntType.INSTANCE, t))
);
SIGNATURES = ImmutableList.copyOf(signatures);
}

private static final List<FunctionSignature> SIGNATURES;
/**
* Window function: Lag()
*/
public class Lag extends LeadOrLag {

public Lag(Expression child, Expression offset, Expression defaultValue) {
super("lag", child, offset, defaultValue);
Expand All @@ -55,69 +37,14 @@ private Lag(List<Expression> children) {
super("lag", children);
}

public Expression getOffset() {
if (children().size() <= 1) {
throw new AnalysisException("Not set offset of Lead(): " + this.toSql());
}
return child(1);
}

public Expression getDefaultValue() {
if (children.size() <= 2) {
throw new AnalysisException("Not set default value of Lead(): " + this.toSql());
}
return child(2);
}

@Override
public boolean nullable() {
if (children.size() == 3 && child(2).nullable()) {
return true;
}
return child(0).nullable();
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitLag(this, context);
}

@Override
public Lag withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 1 && children.size() <= 3);
return new Lag(children);
}

@Override
public void checkLegalityBeforeTypeCoercion() {
if (children().size() == 1) {
return;
}
if (children().size() >= 2) {
checkValidParams(getOffset(), true);
if (getOffset() instanceof Literal) {
if (((Literal) getOffset()).getDouble() < 0) {
throw new AnalysisException(
"The offset parameter of LAG must be a constant positive integer: " + this.toSql());
}
} else {
throw new AnalysisException(
"The offset parameter of LAG must be a constant positive integer: " + this.toSql());
}
if (children().size() >= 3) {
checkValidParams(getDefaultValue(), false);
}
}
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitLag(this, context);
}

@Override
public DataType getDataType() {
return child(0).getDataType();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,17 @@

package org.apache.doris.nereids.trees.expressions.functions.window;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.List;

/**
* Window function: Lead()
*/
public class Lead extends WindowFunction implements TernaryExpression, ExplicitlyCastableSignature,
RequireTrivialTypes {

static {
List<FunctionSignature> signatures = Lists.newArrayList();
trivialTypes.forEach(t ->
signatures.add(FunctionSignature.ret(t).args(t, BigIntType.INSTANCE, t))
);
SIGNATURES = ImmutableList.copyOf(signatures);
}

private static final List<FunctionSignature> SIGNATURES;
public class Lead extends LeadOrLag {

public Lead(Expression child, Expression offset, Expression defaultValue) {
super("lead", child, offset, defaultValue);
Expand All @@ -57,68 +37,14 @@ private Lead(List<Expression> children) {
super("lead", children);
}

public Expression getOffset() {
if (children().size() <= 1) {
throw new AnalysisException("Not set offset of Lead(): " + this.toSql());
}
return child(1);
}

public Expression getDefaultValue() {
if (children.size() <= 2) {
throw new AnalysisException("Not set default value of Lead(): " + this.toSql());
}
return child(2);
}

@Override
public boolean nullable() {
if (children.size() == 3 && child(2).nullable()) {
return true;
}
return child(0).nullable();
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitLead(this, context);
}

@Override
public void checkLegalityBeforeTypeCoercion() {
if (children().size() == 1) {
return;
}
if (children().size() >= 2) {
checkValidParams(getOffset(), true);
if (getOffset() instanceof Literal) {
if (((Literal) getOffset()).getDouble() < 0) {
throw new AnalysisException(
"The offset parameter of LEAD must be a constant positive integer: " + this.toSql());
}
} else {
throw new AnalysisException(
"The offset parameter of LAG must be a constant positive integer: " + this.toSql());
}
if (children().size() >= 3) {
checkValidParams(getDefaultValue(), false);
}
}
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public Lead withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 1 && children.size() <= 3);
return new Lead(children);
}

@Override
public DataType getDataType() {
return child(0).getDataType();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.window;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import java.util.List;

/**
* types used for specific window functions.
* This list is equal with Type.trivialTypes used in legacy planner
*/
public abstract class LeadOrLag extends WindowFunction implements TernaryExpression, CustomSignature {

public LeadOrLag(String name, Expression child, Expression offset, Expression defaultValue) {
super(name, child, offset, defaultValue);
}

protected LeadOrLag(String name, List<Expression> children) {
super(name, children);
}

public Expression getOffset() {
if (children().size() <= 1) {
throw new AnalysisException("Not set offset of " + getName() + "(): " + this.toSql());
}
return child(1);
}

public Expression getDefaultValue() {
if (children.size() <= 2) {
throw new AnalysisException("Not set default value of " + getName() + "(): " + this.toSql());
}
return child(2);
}

@Override
public boolean nullable() {
if (children.size() == 3 && child(2).nullable()) {
return true;
}
return child(0).nullable();
}

@Override
public DataType getDataType() {
return child(0).getDataType();
}

@Override
public void checkLegalityBeforeTypeCoercion() {
if (children().size() == 1) {
return;
}
if (children().size() >= 2) {
checkValidParams(getOffset(), true);
if (getOffset() instanceof Literal) {
if (((Literal) getOffset()).getDouble() < 0) {
throw new AnalysisException(
"The offset parameter of " + getName()
+ " must be a constant positive integer: " + this.toSql());
}
} else {
throw new AnalysisException(
"The offset parameter of " + getName()
+ " must be a constant positive integer: " + this.toSql());
}
if (children().size() >= 3) {
checkValidParams(getDefaultValue(), false);
}
}
}

@Override
public FunctionSignature customSignature() {
DataType inputType = getArgument(0).getDataType();
DataType defaultType = getDefaultValue().getDataType();
DataType commonType = inputType;
if (inputType.isNumericType() && defaultType.isNumericType()) {
for (DataType dataType : TypeCoercionUtils.NUMERIC_PRECEDENCE) {
if (inputType.equals(dataType) || defaultType.equals(dataType)) {
commonType = dataType;
break;
}
}
if (commonType.isFloatLikeType() && (inputType.isDecimalV3Type() || defaultType.isDecimalV3Type())) {
commonType = DoubleType.INSTANCE;
}
if (inputType.isDecimalV2Type() || defaultType.isDecimalV2Type()) {
commonType = DecimalV2Type.SYSTEM_DEFAULT;
}
if (inputType.isDecimalV3Type() || defaultType.isDecimalV3Type()) {
commonType = DecimalV3Type.widerDecimalV3Type(
DecimalV3Type.forType(inputType), DecimalV3Type.forType(defaultType), true);
}
} else if (inputType.isDateLikeType() && inputType.isDateLikeType()) {
if (inputType.isDateTimeV2Type() || defaultType.isDateTimeV2Type()) {
commonType = DateTimeV2Type.getWiderDatetimeV2Type(
DateTimeV2Type.forType(inputType), DateTimeV2Type.forType(defaultType));
} else if (inputType.isDateTimeType() || defaultType.isDateTimeType()) {
commonType = DateTimeType.INSTANCE;
} else if (inputType.isDateV2Type() || defaultType.isDateV2Type()) {
commonType = DateV2Type.INSTANCE;
} else {
commonType = DateType.INSTANCE;
}
} else if (!defaultType.isNullType()) {
commonType = StringType.INSTANCE;
}
return FunctionSignature.ret(commonType).args(commonType, BigIntType.INSTANCE, commonType);
}
}
Loading

0 comments on commit 68c34c4

Please sign in to comment.