Skip to content

Commit c0acf16

Browse files
committed
Address comments
1 parent da87bfd commit c0acf16

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileApprox.scala

Lines changed: 24 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
3232
* large numbers of rows where the regular percentile() UDAF might run out of memory.
3333
*
3434
* The input is a single double value or an array of double values representing the percentiles
35-
* requested. The output, corresponding to the input, is either an single double value or an
35+
* requested. The output, corresponding to the input, is either a single double value or an
3636
* array of doubles that are the percentile values.
3737
*/
3838
@ExpressionDescription(
@@ -177,26 +177,32 @@ object PercentileApprox {
177177
private[sql] def bToRelativeError(B: Int): Double = Math.max(1.0d / B, 0.001)
178178

179179
/**
180-
* Validates the percentile(s) expression and extract the percentile(s).
180+
* Validates the percentile(s) expression and extracts the percentile(s).
181181
* Returns the extracted percentile(s) and an indicator of whether it's an array.
182182
*/
183183
private def validatePercentilesLiteral(exp: Expression): (Seq[Double], Boolean) = {
184184
def withinRange(v: Double): Boolean = 0.0 <= v && v <= 1.0
185-
exp match {
186-
case Literal(f: Float, FloatType) if withinRange(f) => (Seq(f.toDouble), false)
187-
case Literal(d: Double, DoubleType) if withinRange(d) => (Seq(d), false)
188-
case Literal(dec: Decimal, _) if withinRange(dec.toDouble) => (Seq(dec.toDouble), false)
189-
190-
case CreateArray(children: Seq[Expression]) if (children.length > 0) =>
191-
(children.map(_ match {
192-
case Literal(f: Float, FloatType) if withinRange(f) => f.toDouble
193-
case Literal(d: Double, DoubleType) if withinRange(d) => d
194-
case Literal(dec: Decimal, _) if withinRange(dec.toDouble) => dec.toDouble
195-
case _ =>
196-
throw new AnalysisException(
197-
"The second argument should be a double literal or an array of doubles, and should " +
198-
"be within range [0.0, 1.0]")
199-
}), true)
185+
exp.eval() match {
186+
case f: Float if withinRange(f) => (Seq(f.toDouble), false)
187+
case d: Double if withinRange(d) => (Seq(d), false)
188+
case dec: Decimal if withinRange(dec.toDouble) => (Seq(dec.toDouble), false)
189+
190+
case arrayData: GenericArrayData if arrayData.numElements() > 0 => {
191+
val ret =
192+
// arrayData.array.getClass.getComponentType() doesn't help here because it always gives
193+
// us java.lang.Object
194+
arrayData.array(0) match {
195+
case _: Float => arrayData.toFloatArray().map(_.toDouble)
196+
case _: Double => arrayData.toDoubleArray()
197+
case _: Decimal => arrayData.array.map(_.asInstanceOf[Decimal].toDouble)
198+
}
199+
if (!ret.forall(withinRange)) {
200+
throw new AnalysisException(
201+
"The second argument should be a double literal or an array of doubles, and should " +
202+
"be within range [0.0, 1.0]")
203+
}
204+
(ret, true)
205+
}
200206

201207
case _ =>
202208
throw new AnalysisException(
@@ -205,7 +211,7 @@ object PercentileApprox {
205211
}
206212
}
207213

208-
/** Validates the B expression and extract its value. */
214+
/** Validates the B expression and extracts its value. */
209215
private def validateBLiteral(exp: Expression): Int = exp match {
210216
case Literal(i: Int, IntegerType) if i > 0 => i
211217

0 commit comments

Comments
 (0)