
Commit 308c8ad

Numerous changes to improve code

1 parent cff73e0 commit 308c8ad

3 files changed: +64 -56 lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 10 additions & 1 deletion
@@ -21,6 +21,13 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.mllib.clustering.GaussianMixtureModelEM
 import org.apache.spark.mllib.linalg.Vectors
 
+/**
+ * An example Gaussian Mixture Model EM app. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <convergenceTol>
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
 object DenseGmmEM {
   def main(args: Array[String]): Unit = {
     if (args.length != 3) {
@@ -44,13 +51,15 @@ object DenseGmmEM {
       .run(data)
 
     for (i <- 0 until clusters.k) {
-      println("weight=%f mu=%s sigma=\n%s\n" format
+      println("weight=%f\nmu=%s\nsigma=\n%s\n" format
         (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
     }
 
+    println("Cluster labels:")
     val (responsibilityMatrix, clusterLabels) = clusters.predict(data)
     for (x <- clusterLabels.collect) {
       print(" " + x)
     }
+    println
   }
 }
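
For context beyond the changed lines, the call pattern the updated example follows is sketched below. This is an illustration, not the committed file: the input path, the line parsing, and the value of k are assumptions, while setK, run, predict, weight, mu, and sigma are the methods visible in this commit.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixtureModelEM
import org.apache.spark.mllib.linalg.Vectors

object DenseGmmSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("Gaussian Mixture Model EM example"))

    // Assumed input format: one whitespace-separated dense vector per line
    val data = sc.textFile("data/gmm_data.txt")
      .map(line => Vectors.dense(line.trim.split(' ').map(_.toDouble)))
      .cache()

    // Fit a 2-component mixture (k = 2 is an assumption for this sketch)
    val clusters = new GaussianMixtureModelEM().setK(2).run(data)

    // Print each component on separate lines, as the updated example now does
    for (i <- 0 until clusters.k) {
      println("weight=%f\nmu=%s\nsigma=\n%s\n" format
        (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
    }

    // Hard cluster labels come from predict(), alongside the responsibility matrix
    println("Cluster labels:")
    val (responsibilityMatrix, clusterLabels) = clusters.predict(data)
    clusterLabels.collect.foreach(x => print(" " + x))
    println()
  }
}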

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

Lines changed: 52 additions & 53 deletions
@@ -26,6 +26,8 @@ import org.apache.spark.mllib.stat.impl.MultivariateGaussian
 import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
 import org.apache.spark.SparkContext.DoubleAccumulatorParam
 
+import scala.collection.mutable.IndexedSeqView
+
 /**
  * This class performs expectation maximization for multivariate Gaussian
  * Mixture Models (GMMs). A GMM represents a composite distribution of
@@ -51,6 +53,7 @@ class GaussianMixtureModelEM private (
   // Type aliases for convenience
   private type DenseDoubleVector = BreezeVector[Double]
   private type DenseDoubleMatrix = BreezeMatrix[Double]
+  private type VectorArrayView = IndexedSeqView[DenseDoubleVector, Array[DenseDoubleVector]]
 
   private type ExpectationSum = (
     Array[Double], // log-likelihood in index 0
@@ -80,10 +83,12 @@ class GaussianMixtureModelEM private (
 
   // compute cluster contributions for each input point
   // (U, T) => U for aggregation
-  private def computeExpectation(weights: Array[Double], dists: Array[MultivariateGaussian])
+  private def computeExpectation(
+      weights: Array[Double],
+      dists: Array[MultivariateGaussian])
     (model: ExpectationSum, x: DenseDoubleVector): ExpectationSum = {
     val k = model._2.length
-    val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(x)).toArray
+    val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(x) }
     val pSum = p.sum
     model._1(0) += math.log(pSum)
     val xxt = x * new Transpose(x)
@@ -106,7 +111,10 @@ class GaussianMixtureModelEM private (
   /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
   def this() = this(2, 0.01, 100)
 
-  /** Set the initial GMM starting point, bypassing the random initialization */
+  /** Set the initial GMM starting point, bypassing the random initialization.
+   * You must call setK() prior to calling this method, and the condition
+   * (gmm.k == this.k) must be met; failure will result in an IllegalArgumentException
+   */
   def setInitialGmm(gmm: GaussianMixtureModel): this.type = {
     if (gmm.k == k) {
       initialGmm = Some(gmm)
@@ -156,34 +164,30 @@ class GaussianMixtureModelEM private (
 
   /** Perform expectation maximization */
   def run(data: RDD[Vector]): GaussianMixtureModel = {
-    val ctx = data.sparkContext
+    val sc = data.sparkContext
 
     // we will operate on the data as breeze data
     val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
 
     // Get length of the input vectors
     val d = breezeData.first.length
 
-    // gaussians will be array of (weight, mean, covariance) tuples.
+    // Determine initial weights and corresponding Gaussians.
     // If the user supplied an initial GMM, we use those values, otherwise
     // we start with uniform weights, a random mean from the data, and
     // diagonal covariance matrices using component variances
-    // derived from the samples
-    var gaussians = initialGmm match {
-      case Some(gmm) => (0 until k).map{ i =>
-        (gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix)
-      }.toArray
+    // derived from the samples
+    val (weights, gaussians) = initialGmm match {
+      case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map{ case(mu, sigma) =>
+        new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
+      }.toArray)
 
       case None => {
-        // For each Gaussian, we will initialize the mean as the average
-        // of some random samples from the data
         val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
-
-        (0 until k).map{ i =>
-          (1.0 / k,
-            vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
-            initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
-        }.toArray
+        ((0 until k).map(_ => 1.0 / k).toArray, (0 until k).map{ i =>
+          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
+          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
+        }.toArray)
       }
     }
 
@@ -192,47 +196,36 @@ class GaussianMixtureModelEM private (
 
     var iter = 0
     do {
-      // pivot gaussians into weight and distribution arrays
-      val weights = (0 until k).map(i => gaussians(i)._1).toArray
-      val dists = (0 until k).map{ i =>
-        new MultivariateGaussian(gaussians(i)._2, gaussians(i)._3)
-      }.toArray
-
       // create and broadcast curried cluster contribution function
-      val compute = ctx.broadcast(computeExpectation(weights, dists)_)
+      val compute = sc.broadcast(computeExpectation(weights, gaussians)_)
 
       // aggregate the cluster contribution for all sample points
-      val sums = breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
-
-      // Assignments to make the code more readable
-      val logLikelihood = sums._1(0)
-      val W = sums._2
-      val MU = sums._3
-      val SIGMA = sums._4
+      val (logLikelihood, wSums, muSums, sigmaSums) =
+        breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
 
       // Create new distributions based on the partial assignments
      // (often referred to as the "M" step in literature)
-      gaussians = (0 until k).map{ i =>
-        val weight = W(i) / W.sum
-        val mu = MU(i) / W(i)
-        val sigma = SIGMA(i) / W(i) - mu * new Transpose(mu)
-        (weight, mu, sigma)
-      }.toArray
-
+      val sumWeights = wSums.sum
+      for (i <- 0 until k) {
+        val mu = muSums(i) / wSums(i)
+        val sigma = sigmaSums(i) / wSums(i) - mu * new Transpose(mu)
+        weights(i) = wSums(i) / sumWeights
+        gaussians(i) = new MultivariateGaussian(mu, sigma)
+      }
+
       llhp = llh // current becomes previous
-      llh = logLikelihood // this is the freshly computed log-likelihood
+      llh = logLikelihood(0) // this is the freshly computed log-likelihood
       iter += 1
     } while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol)
 
     // Need to convert the breeze matrices to MLlib matrices
-    val weights = (0 until k).map(i => gaussians(i)._1).toArray
-    val means = (0 until k).map(i => Vectors.fromBreeze(gaussians(i)._2)).toArray
-    val sigmas = (0 until k).map(i => Matrices.fromBreeze(gaussians(i)._3)).toArray
+    val means = (0 until k).map(i => Vectors.fromBreeze(gaussians(i).mu)).toArray
+    val sigmas = (0 until k).map(i => Matrices.fromBreeze(gaussians(i).sigma)).toArray
     new GaussianMixtureModel(weights, means, sigmas)
   }
 
   /** Average of dense breeze vectors */
-  private def vectorMean(x: Array[DenseDoubleVector]): DenseDoubleVector = {
+  private def vectorMean(x: VectorArrayView): DenseDoubleVector = {
     val v = BreezeVector.zeros[Double](x(0).length)
     x.foreach(xi => v += xi)
     v / x.length.asInstanceOf[Double]
@@ -242,7 +235,7 @@ class GaussianMixtureModelEM private (
   * Construct matrix where diagonal entries are element-wise
   * variance of input vectors (computes biased variance)
   */
-  private def initCovariance(x: Array[DenseDoubleVector]): DenseDoubleMatrix = {
+  private def initCovariance(x: VectorArrayView): DenseDoubleMatrix = {
     val mu = vectorMean(x)
     val ss = BreezeVector.zeros[Double](x(0).length)
     val cov = BreezeMatrix.eye[Double](ss.length)
@@ -255,15 +248,18 @@ class GaussianMixtureModelEM private (
   * Given the input vectors, return the membership value of each vector
   * to all mixture components.
   */
-  def predictClusters(points: RDD[Vector], mu: Array[Vector], sigma: Array[Matrix],
+  def predictClusters(
+      points: RDD[Vector],
+      mu: Array[Vector],
+      sigma: Array[Matrix],
       weight: Array[Double], k: Int): RDD[Array[Double]] = {
-    val ctx = points.sparkContext
-    val dists = ctx.broadcast{
+    val sc = points.sparkContext
+    val dists = sc.broadcast{
       (0 until k).map{ i =>
        new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
      }.toArray
    }
-    val weights = ctx.broadcast((0 until k).map(i => weight(i)).toArray)
+    val weights = sc.broadcast((0 until k).map(i => weight(i)).toArray)
     points.map{ x =>
      computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
    }
@@ -272,11 +268,14 @@ class GaussianMixtureModelEM private (
   /**
   * Compute the partial assignments for each vector
   */
-  def computeSoftAssignments(pt: DenseDoubleVector, dists: Array[MultivariateGaussian],
-      weights: Array[Double], k: Int): Array[Double] = {
-    val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(pt)).toArray
+  private def computeSoftAssignments(
+      pt: DenseDoubleVector,
+      dists: Array[MultivariateGaussian],
+      weights: Array[Double],
+      k: Int): Array[Double] = {
+    val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(pt) }
     val pSum = p.sum
-    for(i<- 0 until k){
+    for (i <- 0 until k){
      p(i) /= pSum
    }
    p
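
The rewritten loop updates `weights` and `gaussians` in place from the aggregated sums instead of pivoting tuple arrays. Below is a minimal, standalone sketch of that M-step arithmetic using Breeze directly; the per-cluster aggregates are made up for a k = 2, 2-dimensional case, and `.t` stands in for the `new Transpose(...)` used in the diff.

import breeze.linalg.{DenseMatrix, DenseVector}

object MStepSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical per-cluster aggregates: sum of responsibilities (wSums),
    // responsibility-weighted sums of x (muSums) and of x * x^T (sigmaSums).
    val wSums = Array(3.0, 1.0)
    val muSums = Array(DenseVector(3.0, 6.0), DenseVector(2.0, 2.0))
    val sigmaSums = Array(
      DenseMatrix((4.0, 7.0), (7.0, 13.0)),
      DenseMatrix((5.0, 4.0), (4.0, 5.0)))

    val sumWeights = wSums.sum
    for (i <- 0 until wSums.length) {
      val mu = muSums(i) / wSums(i)                      // new component mean
      val sigma = sigmaSums(i) / wSums(i) - mu * mu.t    // E[x x^T] - mu mu^T
      val weight = wSums(i) / sumWeights                 // normalized mixing weight
      println(s"weight=$weight\nmu=$mu\nsigma=\n$sigma\n")
    }
  }
}

This mirrors the in-place `for (i <- 0 until k)` update in the hunk above; the tuple-pivoting code it replaces computed the same quantities through the intermediate W, MU, and SIGMA arrays.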

mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.mllib.stat.impl
 
 import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
-import breeze.linalg.{Transpose, det, inv}
+import breeze.linalg.{Transpose, det, pinv}
 
 /**
  * Utility class to implement the density function for multivariate Gaussian distribution.
@@ -28,7 +28,7 @@ import breeze.linalg.{Transpose, det, inv}
 private[mllib] class MultivariateGaussian(
     val mu: BreezeVector[Double],
     val sigma: BreezeMatrix[Double]) extends Serializable {
-  private val sigmaInv2 = inv(sigma) * -0.5
+  private val sigmaInv2 = pinv(sigma) * -0.5
   private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)
 
   def pdf(x: BreezeVector[Double]): Double = {
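
The only functional change to this class is the switch from `inv` to `pinv`. Mathematically, the Moore-Penrose pseudo-inverse coincides with the ordinary inverse for non-singular matrices but remains defined when a covariance matrix is rank-deficient, which is presumably the motivation here. The standalone Breeze sketch below only checks the non-singular case, where the two agree up to floating-point error; behavior on exactly singular inputs is left to Breeze's implementation.

import breeze.linalg.{DenseMatrix, inv, pinv}

object PinvSketch {
  def main(args: Array[String]): Unit = {
    // A well-conditioned, symmetric positive-definite 2x2 "covariance" matrix.
    val sigma = DenseMatrix((2.0, 0.5), (0.5, 1.0))

    // For an invertible matrix the pseudo-inverse and the inverse coincide,
    // so the switch is behavior-preserving in the common case.
    println(inv(sigma))
    println(pinv(sigma))

    // The class above caches pinv(sigma) * -0.5 for use in the pdf exponent.
    val sigmaInv2 = pinv(sigma) * -0.5
    println(sigmaInv2)
  }
}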
