Skip to content

Commit 3dda1ae

Browse files
committed
fix UT
1 parent 6be2ebe commit 3dda1ae

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ case class ApproximatePercentile(
9191
// Rule ImplicitTypeCasts can cast other numeric types to double
9292
case (_, num: Double) => (false, Array(num))
9393
case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
94-
val numericArray = arrayData.toArray(baseType)(baseType.classTag)
95-
(true, numericArray.map(baseType.numeric.toDouble))
94+
val numericArray = arrayData.toObjectArray(baseType)
95+
(true, numericArray.map {x =>
96+
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
97+
})
9698
case other =>
9799
throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
98100
}
@@ -183,8 +185,8 @@ object ApproximatePercentile {
183185
* underlying quantileSummaries is compressed.
184186
*/
185187
class PercentileDigest(
186-
private var summaries: QuantileSummaries,
187-
private var isCompressed: Boolean) {
188+
private var summaries: QuantileSummaries,
189+
private var isCompressed: Boolean) {
188190

189191
// Trigger compression if the QuantileSummaries's buffer length exceeds
190192
// compressThresHoldBufferLength. The buffer length can be get by
@@ -287,10 +289,10 @@ object ApproximatePercentile {
287289
// An case class to wrap fields of QuantileSummaries, so that we can use the expression encoder
288290
// to serialize it.
289291
case class QuantileSummariesData(
290-
val compressThreshold: Int,
291-
val relativeError: Double,
292-
val sampled: Array[Stats] = Array.empty,
293-
val count: Long = 0L) {
292+
compressThreshold: Int,
293+
relativeError: Double,
294+
sampled: Array[Stats],
295+
count: Long) {
294296
def this(summary: QuantileSummaries) = {
295297
this(summary.compressThreshold, summary.relativeError, summary.sampled, summary.count)
296298
}

sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,11 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
155155

156156
test("aggregate functions") {
157157
checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key")
158-
checkSqlGeneration("SELECT percentile_approx(value) FROM t1 GROUP BY key")
158+
checkSqlGeneration("SELECT percentile_approx(value, 0.25) FROM t1 GROUP BY key")
159+
checkSqlGeneration("SELECT percentile_approx(value, array(0.25, 0.75)) FROM t1 GROUP BY key")
160+
checkSqlGeneration("SELECT percentile_approx(value, 0.25, 100) FROM t1 GROUP BY key")
161+
checkSqlGeneration(
162+
"SELECT percentile_approx(value, array(0.25, 0.75), 100) FROM t1 GROUP BY key")
159163
checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key")
160164
checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key")
161165
checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key")

0 commit comments

Comments
 (0)