Skip to content

Commit

Permalink
[SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 E…
Browse files Browse the repository at this point in the history
…xpression

### What changes were proposed in this pull request?
The `hashCode()` of `UserDefinedScalarFunc` and `GeneralScalarExpression` is not good enough. For example, `GeneralScalarExpression` uses `Objects.hash(name, children)`, which combines the hash code of `name` with the hash code of the `children` array reference to produce the `GeneralScalarExpression`'s hash code.
Instead, the hash code should be computed from each element in `children`.

Because `UserDefinedAggregateFunc` and `GeneralAggregateFunc` are missing `hashCode()`, this PR also adds it to them.

This PR also improves the toString for `UserDefinedAggregateFunc` and `GeneralAggregateFunc` by using primitive boolean comparison instead of `Objects.equals`, because primitive boolean comparison performs better than `Objects.equals`.

### Why are the changes needed?
Improve the hash code for some DS V2 Expression.

### Does this PR introduce _any_ user-facing change?
'Yes'.

### How was this patch tested?
N/A

Closes apache#41543 from beliefer/SPARK-44018.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 8c84d2c)
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
beliefer authored and catalinii committed Oct 10, 2023
1 parent 02919fa commit 9dbd37a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.connector.expressions;

import java.util.Arrays;
import java.util.Objects;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
Expand Down Expand Up @@ -441,12 +440,17 @@ public GeneralScalarExpression(String name, Expression[] children) {
// Equality for GeneralScalarExpression: same runtime class, equal name, and
// element-wise equal children.
// NOTE(review): this listing looks like a rendered diff with the removed and the
// added lines interleaved; as shown, the statements after the first `return` on the
// cast object are unreachable -- confirm against the real file.
public boolean equals(Object o) {
// Same reference: equal without inspecting any field.
if (this == o) return true;
// Null or an instance of a different runtime class is never equal.
if (o == null || getClass() != o.getClass()) return false;

GeneralScalarExpression that = (GeneralScalarExpression) o;
// Null-tolerant name comparison combined with element-wise comparison of children.
return Objects.equals(name, that.name) && Arrays.equals(children, that.children);

// Direct name comparison (dereferences name), then element-wise comparison of children.
if (!name.equals(that.name)) return false;
return Arrays.equals(children, that.children);
}

// Hash code for GeneralScalarExpression, derived from name and children.
// NOTE(review): this listing looks like a rendered diff with the removed and the
// added lines interleaved; as shown, everything after the first `return` is
// unreachable -- confirm against the real file.
@Override
public int hashCode() {
// Objects.hash receives the children array as a single element, so the array's own
// hashCode() is mixed in rather than a hash of its contents.
return Objects.hash(name, children);
// Start from name's hash, then fold in a content-based hash of children with multiplier 31.
int result = name.hashCode();
result = 31 * result + Arrays.hashCode(children);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.connector.expressions;

import java.util.Arrays;
import java.util.Objects;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
Expand Down Expand Up @@ -51,13 +50,19 @@ public UserDefinedScalarFunc(String name, String canonicalName, Expression[] chi
// Equality for UserDefinedScalarFunc: same runtime class, equal name, equal
// canonicalName, and element-wise equal children.
// NOTE(review): this listing looks like a rendered diff with the removed and the
// added lines interleaved; as shown, the statements after the first `return` on the
// cast object are unreachable -- confirm against the real file.
public boolean equals(Object o) {
// Same reference: equal without inspecting any field.
if (this == o) return true;
// Null or an instance of a different runtime class is never equal.
if (o == null || getClass() != o.getClass()) return false;

UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
// Null-tolerant comparison of name and canonicalName, then element-wise comparison of children.
return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
Arrays.equals(children, that.children);

// Direct comparisons (dereference name and canonicalName), then element-wise comparison of children.
if (!name.equals(that.name)) return false;
if (!canonicalName.equals(that.canonicalName)) return false;
return Arrays.equals(children, that.children);
}

// Hash code for UserDefinedScalarFunc, derived from name, canonicalName and children.
// NOTE(review): this listing looks like a rendered diff with the removed and the
// added lines interleaved; as shown, everything after the first `return` is
// unreachable -- confirm against the real file.
@Override
public int hashCode() {
// Objects.hash receives the children array as a single element, so the array's own
// hashCode() is mixed in rather than a hash of its contents.
return Objects.hash(name, canonicalName, children);
// Fold name, canonicalName and a content-based hash of children together with multiplier 31.
int result = name.hashCode();
result = 31 * result + canonicalName.hashCode();
result = 31 * result + Arrays.hashCode(children);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
Expand Down Expand Up @@ -60,4 +62,24 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr

/**
 * Returns the child expressions held by this aggregate function.
 *
 * @return the {@code children} array stored on this instance (not copied)
 */
@Override
public Expression[] children() {
  return children;
}

/**
 * Compares this aggregate function with another object.
 *
 * Two instances are equal when they have the same runtime class, the same
 * {@code isDistinct} flag, equal names, and element-wise equal children.
 *
 * @param o the object to compare against; may be {@code null}
 * @return {@code true} if {@code o} is equal to this instance
 */
@Override
public boolean equals(Object o) {
  if (this == o) {
    return true;
  }
  if (o == null || getClass() != o.getClass()) {
    return false;
  }

  GeneralAggregateFunc other = (GeneralAggregateFunc) o;
  // Cheapest check first: primitive flag, then the name, then the children array contents.
  return isDistinct == other.isDistinct
    && name.equals(other.name)
    && Arrays.equals(children, other.children);
}

/**
 * Computes a hash code from the name, the {@code isDistinct} flag and the
 * contents of the children array, combined with the multiplier 31.
 *
 * @return the hash code of this instance
 */
@Override
public int hashCode() {
  // The distinct flag contributes 1 when set and 0 otherwise.
  int distinctBit = isDistinct ? 1 : 0;
  int hash = 31 * name.hashCode() + distinctBit;
  return 31 * hash + Arrays.hashCode(children);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
Expand Down Expand Up @@ -50,4 +52,26 @@ public UserDefinedAggregateFunc(

/**
 * Returns the child expressions held by this user-defined aggregate function.
 *
 * @return the {@code children} array stored on this instance (not copied)
 */
@Override
public Expression[] children() {
  return children;
}

/**
 * Compares this user-defined aggregate function with another object.
 *
 * Two instances are equal when they have the same runtime class, the same
 * {@code isDistinct} flag, equal names, equal canonical names, and
 * element-wise equal children.
 *
 * @param o the object to compare against; may be {@code null}
 * @return {@code true} if {@code o} is equal to this instance
 */
@Override
public boolean equals(Object o) {
  if (this == o) {
    return true;
  }
  if (o == null || getClass() != o.getClass()) {
    return false;
  }

  UserDefinedAggregateFunc other = (UserDefinedAggregateFunc) o;
  // Cheapest check first: primitive flag, then the two strings, then the children array contents.
  return isDistinct == other.isDistinct
    && name.equals(other.name)
    && canonicalName.equals(other.canonicalName)
    && Arrays.equals(children, other.children);
}

/**
 * Computes a hash code from the name, the canonical name, the
 * {@code isDistinct} flag and the contents of the children array, combined
 * with the multiplier 31.
 *
 * @return the hash code of this instance
 */
@Override
public int hashCode() {
  // The distinct flag contributes 1 when set and 0 otherwise.
  int distinctBit = isDistinct ? 1 : 0;
  int hash = 31 * name.hashCode() + canonicalName.hashCode();
  hash = 31 * hash + distinctBit;
  return 31 * hash + Arrays.hashCode(children);
}
}

0 comments on commit 9dbd37a

Please sign in to comment.