Skip to content

Commit 755d358

Browse files
committed
Addressing reviewers comments @mengxr
1 parent a6ad82a commit 755d358

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,42 @@
1717

1818
package org.apache.spark.mllib.feature
1919

20+
import scala.collection.mutable.ArrayBuilder
21+
2022
import org.apache.spark.annotation.Experimental
21-
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
23+
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2224
import org.apache.spark.mllib.regression.LabeledPoint
2325
import org.apache.spark.mllib.stat.Statistics
2426
import org.apache.spark.rdd.RDD
2527

26-
import scala.collection.mutable.ArrayBuilder
27-
2828
/**
2929
* :: Experimental ::
3030
* Chi Squared selector model.
3131
*
32-
* @param indices list of indices to select (filter). Must be ordered asc
32+
* @param selectedFeatures list of feature indices to select (filter). Must be sorted in ascending order.
3333
*/
3434
@Experimental
35-
class ChiSqSelectorModel private[mllib] (indices: Array[Int]) extends VectorTransformer {
35+
class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer {
36+
37+
require(isSorted(selectedFeatures), "Array has to be sorted asc")
38+
39+
/**
 * Checks that an array of indices is in non-decreasing order.
 *
 * @param array indices to check.
 * @return true if no element is smaller than its predecessor (empty and single-element
 *         arrays are trivially sorted), false otherwise.
 */
protected def isSorted(array: Array[Int]): Boolean = {
  var pos = 1
  var ascending = true
  // Stop scanning as soon as the first out-of-order adjacent pair is found.
  while (ascending && pos < array.length) {
    ascending = array(pos - 1) <= array(pos)
    pos += 1
  }
  ascending
}
47+
3648
/**
3749
* Applies transformation on a vector.
3850
*
3951
* @param vector vector to be transformed.
4052
* @return transformed vector.
4153
*/
4254
override def transform(vector: Vector): Vector =
  // Delegates to `compress`, passing the model's selected feature indices as the filter.
  compress(vector, selectedFeatures)
4557

4658
/**
@@ -56,23 +68,27 @@ class ChiSqSelectorModel private[mllib] (indices: Array[Int]) extends VectorTra
5668
val newSize = filterIndices.length
5769
val newValues = new ArrayBuilder.ofDouble
5870
val newIndices = new ArrayBuilder.ofInt
59-
var i: Int = 0
60-
var j: Int = 0
61-
while(i < indices.length && j < filterIndices.length) {
62-
if(indices(i) == filterIndices(j)) {
71+
var i = 0
72+
var j = 0
73+
var indicesIdx = 0
74+
var filterIndicesIdx = 0
75+
while (i < indices.length && j < filterIndices.length) {
76+
indicesIdx = indices(i)
77+
filterIndicesIdx = filterIndices(j)
78+
if (indicesIdx == filterIndicesIdx) {
6379
newIndices += j
6480
newValues += values(i)
6581
j += 1
6682
i += 1
6783
} else {
68-
if(indices(i) > filterIndices(j)) {
84+
if (indicesIdx > filterIndicesIdx) {
6985
j += 1
7086
} else {
7187
i += 1
7288
}
7389
}
7490
}
75-
/** Sparse representation might be ineffective if (newSize ~= newValues.size) */
91+
// TODO: Sparse representation might be inefficient if (newSize ~= newValues.size)
7692
Vectors.sparse(newSize, newIndices.result(), newValues.result())
7793
case DenseVector(values) =>
7894
val values = features.toArray
@@ -96,13 +112,15 @@ class ChiSqSelector (val numTopFeatures: Int) {
96112
/**
97113
* Returns a ChiSquared feature selector.
98114
*
99-
* @param data data used to compute the Chi Squared statistic.
115+
* @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
116+
* Real-valued features will be treated as categorical for each distinct value.
117+
* Apply a feature discretizer before using this function.
100118
*/
101119
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
  // Pair each chi-squared test result with its feature index, largest statistic first.
  val rankedByStatistic = Statistics.chiSqTest(data)
    .zipWithIndex
    .sortBy { case (testResult, _) => -testResult.statistic }
  // Keep the indices of the top `numTopFeatures` results, restored to ascending order.
  val topFeatureIndices = rankedByStatistic
    .take(numTopFeatures)
    .map { case (_, featureIndex) => featureIndex }
    .sorted
  new ChiSqSelectorModel(topFeatureIndices)
}

mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,18 @@ class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext {
4949

5050
test("ChiSqSelector transform test (sparse & dense vector)") {
  // Four labeled points over three features, mixing sparse and dense encodings.
  val inputPoints = Seq(
    LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
    LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
    LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
    LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))))
  val labeledDiscreteData = sc.parallelize(inputPoints, 2)

  // Expected output keeps only the value at feature index 2 of every input point.
  val preFilteredData = Set(
    LabeledPoint(0.0, Vectors.dense(Array(0.0))),
    LabeledPoint(1.0, Vectors.dense(Array(6.0))),
    LabeledPoint(1.0, Vectors.dense(Array(8.0))),
    LabeledPoint(2.0, Vectors.dense(Array(5.0))))

  val model = new ChiSqSelector(1).fit(labeledDiscreteData)
  val filteredData = labeledDiscreteData
    .map(lp => LabeledPoint(lp.label, model.transform(lp.features)))
    .collect()
    .toSet

  assert(filteredData == preFilteredData)
}

0 commit comments

Comments
 (0)