@@ -176,18 +176,20 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
176176 .map(_.headOption.getOrElse(Double .NaN ))
177177
178178 case Imputer .mode =>
179- // Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode).
180- // If there is more than one mode, choose the smallest one.
181- val modes = dataset.select(cols : _* ).rdd.flatMap { row =>
179+ import spark .implicits ._
180+ // If there is more than one mode, choose the smallest one to keep in line
181+ // with sklearn.impute.SimpleImputer (using scipy.stats.mode).
182+ val modes = dataset.select(cols : _* ).flatMap { row =>
182183 Iterator .range(0 , numCols).flatMap { i =>
183184 // Ignore null.
184- if (row.isNullAt(i)) Iterator .empty else Iterator .single((i, row.getDouble(i)), 1L )
185+ // negative value to apply the default ranking of [Long, Double]
186+ if (row.isNullAt(i)) Iterator .empty else Iterator .single((i, - row.getDouble(i)))
185187 }
186- }.reduceByKey(_ + _).map { case ((i, v), c) =>
187- // negative value to apply the default ranking of [Long, Double]
188- (i, (c, - v ))
189- }.reduceByKey( Ordering .apply[( Long , Double )].max
190- ).mapValues( - _._2).collectAsMap()
188+ }.toDF( " index " , " negative_value " )
189+ .groupBy( " index " , " negative_value " ).agg(count(lit( 0 )).as( " count " ))
190+ .groupBy( " index " ).agg(max(struct( " count " , " negative_value " )).as( " mode " ))
191+ .select(col( " index " ), negate(col( " mode.negative_value " )))
192+ .as[( Int , Double )].collect().toMap
191193 Array .tabulate(numCols)(i => modes.getOrElse(i, Double .NaN ))
192194 }
193195
0 commit comments