From 9dbd37adb806f7deb18929315cc54dfc9549ac41 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 19 Jun 2023 15:55:06 +0800 Subject: [PATCH] [SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression ### What changes were proposed in this pull request? The `hashCode()` of `UserDefinedScalarFunc` and `GeneralScalarExpression` is not good enough. For example, `GeneralScalarExpression` uses `Objects.hash(name, children)`, which combines the hash code of `name` with the hash code of the `children` array reference to produce the `GeneralScalarExpression`'s hash code. Instead, the hash code should be computed from each element in `children`. Because `UserDefinedAggregateFunc` and `GeneralAggregateFunc` are missing `hashCode()`, this PR also adds it to them. This PR also improves `equals` for `UserDefinedAggregateFunc` and `GeneralAggregateFunc` by using primitive boolean comparison instead of `Objects.equals`, because primitive boolean comparison performs better than `Objects.equals`. ### Why are the changes needed? Improve the hash code for some DS V2 Expression. ### Does this PR introduce _any_ user-facing change? 'Yes'. ### How was this patch tested? N/A Closes #41543 from beliefer/SPARK-44018. 
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan (cherry picked from commit 8c84d2c9349d7b607db949c2e114df781f23e438) Signed-off-by: Wenchen Fan --- .../expressions/GeneralScalarExpression.java | 10 +++++--- .../expressions/UserDefinedScalarFunc.java | 13 ++++++---- .../aggregate/GeneralAggregateFunc.java | 22 +++++++++++++++++ .../aggregate/UserDefinedAggregateFunc.java | 24 +++++++++++++++++++ 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index cb9bf6d69e2ea..859660600214d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector.expressions; import java.util.Arrays; -import java.util.Objects; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.filter.Predicate; @@ -441,12 +440,17 @@ public GeneralScalarExpression(String name, Expression[] children) { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + GeneralScalarExpression that = (GeneralScalarExpression) o; - return Objects.equals(name, that.name) && Arrays.equals(children, that.children); + + if (!name.equals(that.name)) return false; + return Arrays.equals(children, that.children); } @Override public int hashCode() { - return Objects.hash(name, children); + int result = name.hashCode(); + result = 31 * result + Arrays.hashCode(children); + return result; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java index 
b7f603cd43162..cbf3941d77d6c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector.expressions; import java.util.Arrays; -import java.util.Objects; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.internal.connector.ExpressionWithToString; @@ -51,13 +50,19 @@ public UserDefinedScalarFunc(String name, String canonicalName, Expression[] chi public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + UserDefinedScalarFunc that = (UserDefinedScalarFunc) o; - return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) && - Arrays.equals(children, that.children); + + if (!name.equals(that.name)) return false; + if (!canonicalName.equals(that.canonicalName)) return false; + return Arrays.equals(children, that.children); } @Override public int hashCode() { - return Objects.hash(name, canonicalName, children); + int result = name.hashCode(); + result = 31 * result + canonicalName.hashCode(); + result = 31 * result + Arrays.hashCode(children); + return result; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 1abf386565913..4ef5b7f97e926 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions.aggregate; +import java.util.Arrays; + import org.apache.spark.annotation.Evolving; import 
org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.internal.connector.ExpressionWithToString; @@ -60,4 +62,24 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr @Override public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GeneralAggregateFunc that = (GeneralAggregateFunc) o; + + if (isDistinct != that.isDistinct) return false; + if (!name.equals(that.name)) return false; + return Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + (isDistinct ? 1 : 0); + result = 31 * result + Arrays.hashCode(children); + return result; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java index d166ba16ba52c..10a62d0478b6d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions.aggregate; +import java.util.Arrays; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.internal.connector.ExpressionWithToString; @@ -50,4 +52,26 @@ public UserDefinedAggregateFunc( @Override public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + UserDefinedAggregateFunc that = (UserDefinedAggregateFunc) o; + + if (isDistinct != that.isDistinct) return 
false; + if (!name.equals(that.name)) return false; + if (!canonicalName.equals(that.canonicalName)) return false; + return Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + canonicalName.hashCode(); + result = 31 * result + (isDistinct ? 1 : 0); + result = 31 * result + Arrays.hashCode(children); + return result; + } }