@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
3232 * large numbers of rows where the regular percentile() UDAF might run out of memory.
3333 *
3434 * The input is a single double value or an array of double values representing the percentiles
35- * requested. The output, corresponding to the input, is either an single double value or an
35+ * requested. The output, corresponding to the input, is either a single double value or an
3636 * array of doubles that are the percentile values.
3737 */
3838@ ExpressionDescription (
@@ -177,26 +177,32 @@ object PercentileApprox {
177177 private [sql] def bToRelativeError (B : Int ): Double = Math .max(1.0d / B , 0.001 )
178178
179179 /**
180- * Validates the percentile(s) expression and extract the percentile(s).
180+ * Validates the percentile(s) expression and extracts the percentile(s).
181181 * Returns the extracted percentile(s) and an indicator of whether it's an array.
182182 */
183183 private def validatePercentilesLiteral (exp : Expression ): (Seq [Double ], Boolean ) = {
184184 def withinRange (v : Double ): Boolean = 0.0 <= v && v <= 1.0
185- exp match {
186- case Literal (f : Float , FloatType ) if withinRange(f) => (Seq (f.toDouble), false )
187- case Literal (d : Double , DoubleType ) if withinRange(d) => (Seq (d), false )
188- case Literal (dec : Decimal , _) if withinRange(dec.toDouble) => (Seq (dec.toDouble), false )
189-
190- case CreateArray (children : Seq [Expression ]) if (children.length > 0 ) =>
191- (children.map(_ match {
192- case Literal (f : Float , FloatType ) if withinRange(f) => f.toDouble
193- case Literal (d : Double , DoubleType ) if withinRange(d) => d
194- case Literal (dec : Decimal , _) if withinRange(dec.toDouble) => dec.toDouble
195- case _ =>
196- throw new AnalysisException (
197- " The second argument should be a double literal or an array of doubles, and should " +
198- " be within range [0.0, 1.0]" )
199- }), true )
185+ exp.eval() match {
186+ case f : Float if withinRange(f) => (Seq (f.toDouble), false )
187+ case d : Double if withinRange(d) => (Seq (d), false )
188+ case dec : Decimal if withinRange(dec.toDouble) => (Seq (dec.toDouble), false )
189+
190+ case arrayData : GenericArrayData if arrayData.numElements() > 0 => {
191+ val ret =
192+ // arrayData.array.getClass.getComponentType() doesn't help here because it always gives
193+ // us java.lang.Object
194+ arrayData.array(0 ) match {
195+ case _ : Float => arrayData.toFloatArray().map(_.toDouble)
196+ case _ : Double => arrayData.toDoubleArray()
197+ case _ : Decimal => arrayData.array.map(_.asInstanceOf [Decimal ].toDouble)
198+ }
199+ if (! ret.forall(withinRange)) {
200+ throw new AnalysisException (
201+ " The second argument should be a double literal or an array of doubles, and should " +
202+ " be within range [0.0, 1.0]" )
203+ }
204+ (ret, true )
205+ }
200206
201207 case _ =>
202208 throw new AnalysisException (
@@ -205,7 +211,7 @@ object PercentileApprox {
205211 }
206212 }
207213
208- /** Validates the B expression and extract its value. */
214+ /** Validates the B expression and extracts its value. */
209215 private def validateBLiteral (exp : Expression ): Int = exp match {
210216 case Literal (i : Int , IntegerType ) if i > 0 => i
211217
0 commit comments