Skip to content

Commit e0605d6

Browse files
committed
rdd -> df
1 parent 91ae454 commit e0605d6

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,20 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
176176
.map(_.headOption.getOrElse(Double.NaN))
177177

178178
case Imputer.mode =>
179-
// Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode).
180-
// If there is more than one mode, choose the smallest one.
181-
val modes = dataset.select(cols: _*).rdd.flatMap { row =>
179+
import spark.implicits._
180+
// If there is more than one mode, choose the smallest one to keep in line
181+
// with sklearn.impute.SimpleImputer (using scipy.stats.mode).
182+
val modes = dataset.select(cols: _*).flatMap { row =>
182183
Iterator.range(0, numCols).flatMap { i =>
183184
// Ignore null.
184-
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, row.getDouble(i)), 1L)
185+
// negative value to apply the default ranking of [Long, Double]
186+
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, -row.getDouble(i)))
185187
}
186-
}.reduceByKey(_ + _).map { case ((i, v), c) =>
187-
// negative value to apply the default ranking of [Long, Double]
188-
(i, (c, -v))
189-
}.reduceByKey(Ordering.apply[(Long, Double)].max
190-
).mapValues(-_._2).collectAsMap()
188+
}.toDF("index", "negative_value")
189+
.groupBy("index", "negative_value").agg(count(lit(0)).as("count"))
190+
.groupBy("index").agg(max(struct("count", "negative_value")).as("mode"))
191+
.select(col("index"), negate(col("mode.negative_value")))
192+
.as[(Int, Double)].collect().toMap
191193
Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN))
192194
}
193195

0 commit comments

Comments (0)