Skip to content

Commit 120662e

Browse files
committed
address comments.
1 parent c53a0c7 commit 120662e

File tree

2 files changed

+11 additions, -2 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
6060
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
6161
val prunedFsRelation =
6262
fsRelation.copy(location = prunedFileIndex)(sparkSession)
63+
// Change table stats based on the sizeInBytes of pruned files
6364
val withStats = logicalRelation.catalogTable.map(_.copy(
6465
stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes)))))
6566
val prunedLogicalRelation = logicalRelation.copy(

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.hive.execution
1919

2020
import org.apache.spark.sql.QueryTest
21+
import org.apache.spark.sql.catalyst.TableIdentifier
2122
import org.apache.spark.sql.catalyst.dsl.expressions._
2223
import org.apache.spark.sql.catalyst.dsl.plans._
2324
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
@@ -81,19 +82,26 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
8182
""".stripMargin)
8283
}
8384

85+
val tableName = "partTbl"
86+
sql(s"analyze table partTbl compute STATISTICS")
87+
88+
val tableStats =
89+
spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats
90+
assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost")
91+
8492
withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") {
8593
val df = sql("SELECT * FROM partTbl where part = 1")
8694
val query = df.queryExecution.analyzed.analyze
8795
val sizes1 = query.collect {
8896
case relation: LogicalRelation => relation.computeStats(conf).sizeInBytes
8997
}
9098
assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}")
91-
assert(sizes1(0) > 5000, s"expected > 5000 for test table 'src', got: ${sizes1(0)}")
99+
assert(sizes1(0) == tableStats.get.sizeInBytes)
92100
val sizes2 = Optimize.execute(query).collect {
93101
case relation: LogicalRelation => relation.computeStats(conf).sizeInBytes
94102
}
95103
assert(sizes2.size === 1, s"Size wrong for:\n ${df.queryExecution}")
96-
assert(sizes2(0) < 5000, s"expected < 5000 for test table 'src', got: ${sizes2(0)}")
104+
assert(sizes2(0) < tableStats.get.sizeInBytes)
97105
}
98106
}
99107
}

0 commit comments

Comments (0)