1717
1818package org .apache .spark .mllib .feature
1919
20+ import scala .collection .mutable .ArrayBuilder
21+
2022import org .apache .spark .annotation .Experimental
21- import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vectors , Vector }
23+ import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector , Vectors }
2224import org .apache .spark .mllib .regression .LabeledPoint
2325import org .apache .spark .mllib .stat .Statistics
2426import org .apache .spark .rdd .RDD
2527
26- import scala .collection .mutable .ArrayBuilder
27-
2828/**
2929 * :: Experimental ::
3030 * Chi Squared selector model.
3131 *
32- * @param indices list of indices to select (filter). Must be ordered asc
32+ * @param selectedFeatures list of indices to select (filter). Must be ordered asc
3333 */
3434@ Experimental
35- class ChiSqSelectorModel private [mllib] (indices : Array [Int ]) extends VectorTransformer {
35+ class ChiSqSelectorModel (val selectedFeatures : Array [Int ]) extends VectorTransformer {
36+
37+ require(isSorted(selectedFeatures), " Array has to be sorted asc" )
38+
39+ protected def isSorted (array : Array [Int ]): Boolean = {
40+ var i = 1
41+ while (i < array.length) {
42+ if (array(i) < array(i- 1 )) return false
43+ i += 1
44+ }
45+ true
46+ }
47+
3648 /**
3749 * Applies transformation on a vector.
3850 *
3951 * @param vector vector to be transformed.
4052 * @return transformed vector.
4153 */
4254 override def transform (vector : Vector ): Vector = {
43- compress(vector, indices )
55+ compress(vector, selectedFeatures )
4456 }
4557
4658 /**
@@ -56,23 +68,27 @@ class ChiSqSelectorModel private[mllib] (indices: Array[Int]) extends VectorTra
5668 val newSize = filterIndices.length
5769 val newValues = new ArrayBuilder .ofDouble
5870 val newIndices = new ArrayBuilder .ofInt
59- var i : Int = 0
60- var j : Int = 0
61- while (i < indices.length && j < filterIndices.length) {
62- if (indices(i) == filterIndices(j)) {
71+ var i = 0
72+ var j = 0
73+ var indicesIdx = 0
74+ var filterIndicesIdx = 0
75+ while (i < indices.length && j < filterIndices.length) {
76+ indicesIdx = indices(i)
77+ filterIndicesIdx = filterIndices(j)
78+ if (indicesIdx == filterIndicesIdx) {
6379 newIndices += j
6480 newValues += values(i)
6581 j += 1
6682 i += 1
6783 } else {
68- if (indices(i) > filterIndices(j) ) {
84+ if (indicesIdx > filterIndicesIdx ) {
6985 j += 1
7086 } else {
7187 i += 1
7288 }
7389 }
7490 }
75- /** Sparse representation might be ineffective if (newSize ~= newValues.size) */
91+ // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size)
7692 Vectors .sparse(newSize, newIndices.result(), newValues.result())
7793 case DenseVector (values) =>
7894 val values = features.toArray
@@ -96,13 +112,15 @@ class ChiSqSelector (val numTopFeatures: Int) {
96112 /**
97113 * Returns a ChiSquared feature selector.
98114 *
99- * @param data data used to compute the Chi Squared statistic.
115+ * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
116+ * Real-valued features will be treated as categorical for each distinct value.
117+ * Apply feature discretizer before using this function.
100118 */
101119 def fit (data : RDD [LabeledPoint ]): ChiSqSelectorModel = {
102120 val indices = Statistics .chiSqTest(data)
103- .zipWithIndex.sortBy { case (res, _) => - res.statistic }
121+ .zipWithIndex.sortBy { case (res, _) => - res.statistic }
104122 .take(numTopFeatures)
105- .map{ case (_, indices) => indices }
123+ .map { case (_, indices) => indices }
106124 .sorted
107125 new ChiSqSelectorModel (indices)
108126 }
0 commit comments