diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetSchemaPruningSuite.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetSchemaPruningSuite.scala
index 689448fb7f0..32fa94a300a 100644
--- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetSchemaPruningSuite.scala
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/suites/RapidsParquetSchemaPruningSuite.scala
@@ -19,9 +19,35 @@
 spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.suites
 
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruningSuite
+import org.apache.spark.sql.rapids.GpuFileSourceScanExec
 import org.apache.spark.sql.rapids.utils.RapidsSQLTestsBaseTrait
 
 class RapidsParquetSchemaPruningSuite extends ParquetSchemaPruningSuite
-  with RapidsSQLTestsBaseTrait {}
+  with RapidsSQLTestsBaseTrait {
+
+  // Overrides the upstream schema check so pruning is also validated when the
+  // plan contains a GpuFileSourceScanExec instead of a CPU FileSourceScanExec.
+  override protected def checkScanSchemata(df: DataFrame,
+      expectedSchemaCatalogStrings: String*): Unit = {
+    val fileSourceScanSchemata =
+      collect(df.queryExecution.executedPlan) {
+        case scan: FileSourceScanExec => scan.requiredSchema
+        case gpuScan: GpuFileSourceScanExec => gpuScan.requiredSchema
+      }
+    assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
+      s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+        s"but expected $expectedSchemaCatalogStrings")
+    fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
+      case (scanSchema, expectedScanSchemaCatalogString) =>
+        val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
+        implicit val equality = schemaEquality
+        assert(scanSchema === expectedScanSchema)
+    }
+  }
+
+}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
index 8b76e350fef..98bf8f35fd7 100644
--- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
@@ -102,10 +102,6 @@
     .exclude("SPARK-31159: rebasing dates in write", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11404"))
     .exclude("SPARK-35427: datetime rebasing in the EXCEPTION mode", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11404"))
   enableSuite[RapidsParquetSchemaPruningSuite]
-    .excludeByPrefix("Spark vectorized reader", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11405"))
-    .excludeByPrefix("Non-vectorized reader", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11405"))
-    .excludeByPrefix("Case-insensitive parser", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11405"))
-    .excludeByPrefix("Case-sensitive parser", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11405"))
   enableSuite[RapidsParquetSchemaSuite]
     .exclude("schema mismatch failure error message for parquet reader", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11434"))
     .exclude("schema mismatch failure error message for parquet vectorized reader", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11446"))