@@ -26,6 +26,8 @@ import org.apache.spark.mllib.stat.impl.MultivariateGaussian
 import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
 import org.apache.spark.SparkContext.DoubleAccumulatorParam
 
+import scala.collection.mutable.IndexedSeqView
+
 /**
  * This class performs expectation maximization for multivariate Gaussian
  * Mixture Models (GMMs). A GMM represents a composite distribution of
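For context, a minimal usage sketch of the class this diff touches, assuming only the members visible in the patch (the public no-arg constructor, run, and the returned model's weight/mu accessors); `sc` is an assumed existing SparkContext and the data values are made up:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD

// Hypothetical driver code: fit a GMM with the defaults
// (2 Gaussians, 100 iterations, 0.01 log-likelihood threshold).
val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0),
  Vectors.dense(1.1, 2.1),
  Vectors.dense(9.0, 8.0),
  Vectors.dense(9.2, 8.1)))

val model = new GaussianMixtureModelEM().run(data)
model.weight.zip(model.mu).foreach { case (w, mean) =>
  println(s"weight=$w mean=$mean")
}
```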
@@ -51,6 +53,7 @@ class GaussianMixtureModelEM private (
   // Type aliases for convenience
   private type DenseDoubleVector = BreezeVector[Double]
   private type DenseDoubleMatrix = BreezeMatrix[Double]
+  private type VectorArrayView = IndexedSeqView[DenseDoubleVector, Array[DenseDoubleVector]]
 
   private type ExpectationSum = (
     Array[Double], // log-likelihood in index 0
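ExpectationSum carries the running E-step totals that the aggregation below fills in. The zero element is built by zeroExpectationSum, which sits outside this hunk, so the following sketch of its shapes (for k components of dimension d) is an assumption inferred from how the slots are used later in the patch:

```scala
import breeze.linalg.{DenseMatrix => BreezeMatrix, DenseVector => BreezeVector}

// Assumed zero element: log-likelihood accumulator in slot 1, per-component
// weight sums, mean sums (d-vectors), and covariance sums (d x d matrices).
def zeroExpectationSumSketch(k: Int, d: Int) = (
  Array(0.0),
  Array.fill(k)(0.0),
  Array.fill(k)(BreezeVector.zeros[Double](d)),
  Array.fill(k)(BreezeMatrix.zeros[Double](d, d)))
```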
@@ -80,10 +83,12 @@ class GaussianMixtureModelEM private (
 
   // compute cluster contributions for each input point
   // (U, T) => U for aggregation
-  private def computeExpectation(weights: Array[Double], dists: Array[MultivariateGaussian])
+  private def computeExpectation(
+      weights: Array[Double],
+      dists: Array[MultivariateGaussian])
       (model: ExpectationSum, x: DenseDoubleVector): ExpectationSum = {
     val k = model._2.length
-    val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(x)).toArray
+    val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(x) }
     val pSum = p.sum
     model._1(0) += math.log(pSum)
     val xxt = x * new Transpose(x)
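The hunk cuts off before the accumulation step, but the responsibilities p(i)/pSum presumably feed the weight, mean, and covariance sums consumed by the M step later in the patch. A standalone sketch of that per-point E-step contribution, under those assumptions (not the patch's own code):

```scala
import breeze.linalg.{DenseMatrix, DenseVector}

// Fold one point's soft assignments into the running sums
// (shapes as in the assumed zeroExpectationSum sketch above).
def accumulate(
    p: Array[Double],             // eps + weight(i) * pdf_i(x), as computed above
    x: DenseVector[Double],
    wSums: Array[Double],
    muSums: Array[DenseVector[Double]],
    sigmaSums: Array[DenseMatrix[Double]]): Unit = {
  val pSum = p.sum
  val xxt = x * x.t               // outer product, same role as x * new Transpose(x)
  for (i <- p.indices) {
    val r = p(i) / pSum           // responsibility of component i for x
    wSums(i) += r
    muSums(i) += x * r
    sigmaSums(i) += xxt * r
  }
}
```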
@@ -106,7 +111,10 @@ class GaussianMixtureModelEM private (
   /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
   def this() = this(2, 0.01, 100)
 
-  /** Set the initial GMM starting point, bypassing the random initialization */
+  /** Set the initial GMM starting point, bypassing the random initialization.
+   * You must call setK() prior to calling this method, and the condition
+   * (gmm.k == this.k) must be met; failure will result in an IllegalArgumentException
+   */
   def setInitialGmm(gmm: GaussianMixtureModel): this.type = {
     if (gmm.k == k) {
       initialGmm = Some(gmm)
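As the new doc comment states, the supplied model's k has to match the estimator's k. A minimal sketch of the intended call order; setK is assumed from the doc comment (it is not part of this hunk), and `previousModel` / `data` are hypothetical driver-side values:

```scala
// Hypothetical warm start from a previously fitted model.
val em = new GaussianMixtureModelEM()
em.setK(previousModel.k)        // must match, otherwise IllegalArgumentException
em.setInitialGmm(previousModel)
val refined = em.run(data)
```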
@@ -156,34 +164,30 @@ class GaussianMixtureModelEM private (
 
   /** Perform expectation maximization */
   def run(data: RDD[Vector]): GaussianMixtureModel = {
-    val ctx = data.sparkContext
+    val sc = data.sparkContext
 
     // we will operate on the data as breeze data
     val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
 
     // Get length of the input vectors
     val d = breezeData.first.length
 
-    // gaussians will be array of (weight, mean, covariance) tuples.
+    // Determine initial weights and corresponding Gaussians.
     // If the user supplied an initial GMM, we use those values, otherwise
     // we start with uniform weights, a random mean from the data, and
     // diagonal covariance matrices using component variances
-    // derived from the samples
-    var gaussians = initialGmm match {
-      case Some(gmm) => (0 until k).map{ i =>
-        (gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix)
-      }.toArray
+    // derived from the samples
+    val (weights, gaussians) = initialGmm match {
+      case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case (mu, sigma) =>
+        new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
+      }.toArray)
 
       case None => {
-        // For each Gaussian, we will initialize the mean as the average
-        // of some random samples from the data
         val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
-
-        (0 until k).map{ i =>
-          (1.0 / k,
-            vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
-            initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
-        }.toArray
+        ((0 until k).map(_ => 1.0 / k).toArray, (0 until k).map{ i =>
+          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
+          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
+        }.toArray)
       }
     }
 
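The None branch now hands vectorMean and initCovariance a view over the sampled array instead of a copied slice, which is what motivates the new VectorArrayView alias. A small standalone illustration of that difference (names and values here are illustrative, not from the patch):

```scala
val samples: Array[Double] = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
val nSamples = 3

// slice copies the elements; view shares the backing array lazily.
val copied = samples.slice(0, nSamples)             // new Array(1.0, 2.0, 3.0)
val shared = samples.view(nSamples, 2 * nSamples)   // no copy; reads 4.0, 5.0, 6.0
println(shared.sum)                                 // 15.0
```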
@@ -192,47 +196,36 @@ class GaussianMixtureModelEM private (
 
     var iter = 0
     do {
-      // pivot gaussians into weight and distribution arrays
-      val weights = (0 until k).map(i => gaussians(i)._1).toArray
-      val dists = (0 until k).map{ i =>
-        new MultivariateGaussian(gaussians(i)._2, gaussians(i)._3)
-      }.toArray
-
       // create and broadcast curried cluster contribution function
-      val compute = ctx.broadcast(computeExpectation(weights, dists)_)
+      val compute = sc.broadcast(computeExpectation(weights, gaussians)_)
 
       // aggregate the cluster contribution for all sample points
-      val sums = breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
-
-      // Assignments to make the code more readable
-      val logLikelihood = sums._1(0)
-      val W = sums._2
-      val MU = sums._3
-      val SIGMA = sums._4
+      val (logLikelihood, wSums, muSums, sigmaSums) =
+        breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
 
       // Create new distributions based on the partial assignments
       // (often referred to as the "M" step in literature)
-      gaussians = (0 until k).map{ i =>
-        val weight = W(i) / W.sum
-        val mu = MU(i) / W(i)
-        val sigma = SIGMA(i) / W(i) - mu * new Transpose(mu)
-        (weight, mu, sigma)
-      }.toArray
-
+      val sumWeights = wSums.sum
+      for (i <- 0 until k) {
+        val mu = muSums(i) / wSums(i)
+        val sigma = sigmaSums(i) / wSums(i) - mu * new Transpose(mu)
+        weights(i) = wSums(i) / sumWeights
+        gaussians(i) = new MultivariateGaussian(mu, sigma)
+      }
+
       llhp = llh // current becomes previous
-      llh = logLikelihood // this is the freshly computed log-likelihood
+      llh = logLikelihood(0) // this is the freshly computed log-likelihood
       iter += 1
     } while (iter < maxIterations && Math.abs(llh - llhp) > convergenceTol)
 
     // Need to convert the breeze matrices to MLlib matrices
-    val weights = (0 until k).map(i => gaussians(i)._1).toArray
-    val means = (0 until k).map(i => Vectors.fromBreeze(gaussians(i)._2)).toArray
-    val sigmas = (0 until k).map(i => Matrices.fromBreeze(gaussians(i)._3)).toArray
+    val means = (0 until k).map(i => Vectors.fromBreeze(gaussians(i).mu)).toArray
+    val sigmas = (0 until k).map(i => Matrices.fromBreeze(gaussians(i).sigma)).toArray
     new GaussianMixtureModel(weights, means, sigmas)
   }
 
   /** Average of dense breeze vectors */
-  private def vectorMean(x: Array[DenseDoubleVector]): DenseDoubleVector = {
+  private def vectorMean(x: VectorArrayView): DenseDoubleVector = {
     val v = BreezeVector.zeros[Double](x(0).length)
     x.foreach(xi => v += xi)
     v / x.length.asInstanceOf[Double]
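The aggregate call also needs the combiner addExpectationSums, which lies outside this diff; presumably it merges two partial E-step sums slot by slot. A sketch under that assumption, using the same assumed tuple shapes as above (not the patch's code):

```scala
import breeze.linalg.{DenseMatrix => BreezeMatrix, DenseVector => BreezeVector}

type ExpectationSumSketch = (
  Array[Double],
  Array[Double],
  Array[BreezeVector[Double]],
  Array[BreezeMatrix[Double]])

// Assumed element-wise merge of two partial E-step sums.
def addExpectationSumsSketch(
    a: ExpectationSumSketch,
    b: ExpectationSumSketch): ExpectationSumSketch = {
  a._1(0) += b._1(0)
  for (i <- a._2.indices) {
    a._2(i) += b._2(i)
    a._3(i) += b._3(i)
    a._4(i) += b._4(i)
  }
  a
}
```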
@@ -242,7 +235,7 @@ class GaussianMixtureModelEM private (
    * Construct matrix where diagonal entries are element-wise
    * variance of input vectors (computes biased variance)
    */
-  private def initCovariance(x: Array[DenseDoubleVector]): DenseDoubleMatrix = {
+  private def initCovariance(x: VectorArrayView): DenseDoubleMatrix = {
     val mu = vectorMean(x)
     val ss = BreezeVector.zeros[Double](x(0).length)
     val cov = BreezeMatrix.eye[Double](ss.length)
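For reference, the biased variance the comment refers to divides by n rather than n - 1. A self-contained sketch of a diagonal-covariance construction under that definition (the patch's own body continues past this hunk, so this is only an illustration):

```scala
import breeze.linalg.{diag, DenseMatrix, DenseVector}

// Diagonal covariance from the element-wise biased variance of the input vectors.
def diagCovarianceSketch(x: Seq[DenseVector[Double]]): DenseMatrix[Double] = {
  val n = x.length.toDouble
  val mu = x.reduce(_ + _) / n
  val ss = x.map(xi => (xi - mu) :* (xi - mu)).reduce(_ + _)  // sum of squared deviations
  diag(ss / n)                                                // divide by n: biased variance
}
```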
@@ -255,15 +248,18 @@ class GaussianMixtureModelEM private (
    * Given the input vectors, return the membership value of each vector
    * to all mixture components.
    */
-  def predictClusters(points: RDD[Vector], mu: Array[Vector], sigma: Array[Matrix],
+  def predictClusters(
+      points: RDD[Vector],
+      mu: Array[Vector],
+      sigma: Array[Matrix],
       weight: Array[Double], k: Int): RDD[Array[Double]] = {
-    val ctx = points.sparkContext
-    val dists = ctx.broadcast{
+    val sc = points.sparkContext
+    val dists = sc.broadcast{
       (0 until k).map{ i =>
         new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
       }.toArray
     }
-    val weights = ctx.broadcast((0 until k).map(i => weight(i)).toArray)
+    val weights = sc.broadcast((0 until k).map(i => weight(i)).toArray)
     points.map{ x =>
       computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
     }
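A small usage sketch for the reformatted predictClusters signature, fed from a fitted model; `em`, `model`, and `data` are the assumed driver-side values from the earlier sketches, not names defined in this patch:

```scala
// Soft cluster memberships: one Array[Double] of length k per input point.
val memberships =
  em.predictClusters(data, model.mu, model.sigma, model.weight, model.k)

memberships.take(3).foreach(p => println(p.mkString(", ")))
```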
@@ -272,11 +268,14 @@ class GaussianMixtureModelEM private (
   /**
    * Compute the partial assignments for each vector
    */
-  def computeSoftAssignments(pt: DenseDoubleVector, dists: Array[MultivariateGaussian],
-      weights: Array[Double], k: Int): Array[Double] = {
-    val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(pt)).toArray
+  private def computeSoftAssignments(
+      pt: DenseDoubleVector,
+      dists: Array[MultivariateGaussian],
+      weights: Array[Double],
+      k: Int): Array[Double] = {
+    val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(pt) }
     val pSum = p.sum
-    for (i <- 0 until k){
+    for (i <- 0 until k){
       p(i) /= pSum
     }
     p
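The normalization at the end is a plain divide-by-the-total; a tiny worked example of that step (standalone and illustrative only, with made-up numbers):

```scala
val p = Array(0.2, 0.6, 0.2)          // eps + weight * pdf for each component
val pSum = p.sum                      // 1.0 here; generally any positive total
for (i <- p.indices) { p(i) /= pSum }
// p now sums to 1.0 and reads as soft membership probabilities
```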