Skip to content

Commit c15405c

Browse files
author
Travis Galoppo
committed
SPARK-4156
1 parent c7ad085 commit c15405c

File tree

3 files changed

+325
-0
lines changed

3 files changed

+325
-0
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib
19+
20+
import org.apache.spark.{SparkConf, SparkContext}
21+
import org.apache.spark.mllib.clustering.GaussianMixtureModel
22+
import org.apache.spark.mllib.clustering.GMMExpectationMaximization
23+
import org.apache.spark.mllib.linalg.Vectors
24+
25+
/**
 * Example driver: fits a Gaussian mixture model to dense, space-delimited
 * numeric data with expectation maximization and prints the fitted
 * weight, mean and covariance of every mixture component.
 */
object DenseGmmEM {
  /**
   * Command-line entry point.
   *
   * Expects exactly three arguments: `<input file> <k> <delta>`. When the
   * argument count differs, a usage line is printed and nothing is run.
   */
  def main(args: Array[String]): Unit = {
    if (args.length != 3) {
      println("usage: DenseGmmEM <input file> <k> <delta>")
    } else {
      run(args(0), args(1).toInt, args(2).toDouble)
    }
  }

  /**
   * Loads the data, trains the mixture model and prints each component.
   *
   * @param inputFile path of a text file with one point per line, coordinates
   *                  separated by single spaces
   * @param k number of Gaussians in the mixture
   * @param tol change in log-likelihood below which training stops
   */
  def run(inputFile: String, k: Int, tol: Double): Unit = {
    val conf = new SparkConf().setAppName("Spark EM Sample")
    val ctx = new SparkContext(conf)

    // Always release the context, even when parsing or training fails.
    try {
      val data = ctx.textFile(inputFile).map(line =>
        Vectors.dense(line.trim.split(' ').map(_.toDouble))).cache()

      // Configure the trainer directly so the user-supplied threshold is
      // honoured; train(data, k) would silently fall back to the default.
      val clusters = new GMMExpectationMaximization()
        .setK(k)
        .setDelta(tol)
        .run(data)

      for (i <- 0 until clusters.k) {
        println("w=%f mu=%s sigma=\n%s\n".format(
          clusters.w(i), clusters.mu(i), clusters.sigma(i)))
      }
    } finally {
      ctx.stop()
    }
  }
}
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.clustering
19+
20+
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
21+
import breeze.linalg.{Transpose, det, inv}
22+
import org.apache.spark.rdd.RDD
23+
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
24+
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
25+
import org.apache.spark.SparkContext.DoubleAccumulatorParam
26+
27+
/**
28+
* Expectation-Maximization for multivariate Gaussian Mixture Models.
29+
*
30+
*/
31+
object GMMExpectationMaximization {
  /**
   * Fits a Gaussian mixture model with every setting supplied by the caller.
   *
   * @param data training points stored as RDD[Vector]
   * @param k the number of Gaussians in the mixture
   * @param maxIterations the maximum number of iterations to perform
   * @param delta change in log-likelihood at which convergence is considered achieved
   * @return the fitted mixture model
   */
  def train(
      data: RDD[Vector],
      k: Int,
      maxIterations: Int,
      delta: Double): GaussianMixtureModel = {
    val estimator = new GMMExpectationMaximization()
    estimator.setK(k)
    estimator.setMaxIterations(maxIterations)
    estimator.setDelta(delta)
    estimator.run(data)
  }

  /**
   * Fits a Gaussian mixture model, leaving the convergence threshold at the
   * value chosen by the no-argument constructor.
   *
   * @param data training points stored as RDD[Vector]
   * @param k the number of Gaussians in the mixture
   * @param maxIterations the maximum number of iterations to perform
   * @return the fitted mixture model
   */
  def train(data: RDD[Vector], k: Int, maxIterations: Int): GaussianMixtureModel = {
    val estimator = new GMMExpectationMaximization()
    estimator.setK(k)
    estimator.setMaxIterations(maxIterations)
    estimator.run(data)
  }

  /**
   * Fits a Gaussian mixture model, leaving both the iteration cap and the
   * convergence threshold at the values chosen by the no-argument constructor.
   *
   * @param data training points stored as RDD[Vector]
   * @param k the number of Gaussians in the mixture
   * @return the fitted mixture model
   */
  def train(data: RDD[Vector], k: Int): GaussianMixtureModel = {
    val estimator = new GMMExpectationMaximization()
    estimator.setK(k)
    estimator.run(data)
  }
}
68+
69+
/**
70+
* This class performs multivariate Gaussian expectation maximization. It will
71+
* maximize the log-likelihood for a mixture of k Gaussians, iterating until
72+
* the log-likelihood changes by less than delta, or until it has reached
73+
* the max number of iterations.
74+
*/
75+
class GMMExpectationMaximization private (
    private var k: Int,
    private var delta: Double,
    private var maxIterations: Int) extends Serializable {

  // Type aliases for convenience
  private type DenseDoubleVector = BreezeVector[Double]
  private type DenseDoubleMatrix = BreezeMatrix[Double]

  /** A default instance: 2 Gaussians, 0.01 log-likelihood threshold, 100 iterations. */
  def this() = this(2, 0.01, 100)

  /** Set the number of Gaussians in the mixture model. Default: 2 */
  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  /** Set the maximum number of iterations to run. Default: 100 */
  def setMaxIterations(maxIterations: Int): this.type = {
    this.maxIterations = maxIterations
    this
  }

  /**
   * Set the largest change in log-likelihood at which convergence is
   * considered to have occurred. Default: 0.01
   */
  def setDelta(delta: Double): this.type = {
    this.delta = delta
    this
  }

  /**
   * Machine precision value (2^-52). It is added to every weighted density in
   * the E step so that the per-point normalizing sum is never exactly zero.
   */
  private val eps = math.pow(2.0, -52)

  /**
   * Perform expectation maximization.
   *
   * Iterates E and M steps until either `maxIterations` passes have been made
   * or the log-likelihood changes by no more than `delta` between two passes.
   *
   * @param data training points; each is converted to a dense Breeze vector
   * @return the fitted weights, means and covariances as a GaussianMixtureModel
   */
  def run(data: RDD[Vector]): GaussianMixtureModel = {
    val ctx = data.sparkContext

    // we will operate on the data as breeze data
    val breezeData = data.map{ u => u.toBreeze.toDenseVector }.cache()

    // Get length of the input vectors
    val d = breezeData.first.length

    // For each Gaussian, we will initialize the mean as some random
    // point from the data. (This could be improved)
    val samples = breezeData.takeSample(true, k, scala.util.Random.nextInt)

    // C will be array of (weight, mean, covariance) tuples
    // we start with uniform weights, a random mean from the data, and
    // identity matrices for covariance
    var C = (0 until k).map(i => (1.0/k,
                                  samples(i),
                                  BreezeMatrix.eye[Double](d))).toArray

    val acc_w = new Array[Accumulator[Double]](k)
    val acc_mu = new Array[Accumulator[DenseDoubleVector]](k)
    val acc_sigma = new Array[Accumulator[DenseDoubleMatrix]](k)

    var llh = Double.MinValue // current log-likelihood
    var llhp = 0.0            // previous log-likelihood

    var iter = 0
    do {
      // reset accumulators
      for (i <- 0 until k) {
        acc_w(i) = ctx.accumulator(0.0)
        acc_mu(i) = ctx.accumulator(
                      BreezeVector.zeros[Double](d))(DenseDoubleVectorAccumulatorParam)
        acc_sigma(i) = ctx.accumulator(
                      BreezeMatrix.zeros[Double](d, d))(DenseDoubleMatrixAccumulatorParam)
      }

      val log_likelihood = ctx.accumulator(0.0)

      // broadcast the current weights and distributions to all nodes
      val dists = ctx.broadcast((0 until k).map(i =>
                                  new MultivariateGaussian(C(i)._2, C(i)._3)).toArray)
      val weights = ctx.broadcast((0 until k).map(i => C(i)._1).toArray)

      // calculate partial assignments for each sample in the data
      // (often referred to as the "E" step in literature)
      breezeData.foreach(x => {
        val p = (0 until k).map(i =>
                  eps + weights.value(i) * dists.value(i).pdf(x)).toArray
        val norm = p.sum

        log_likelihood += math.log(norm)

        // accumulate weighted sums
        for (i <- 0 until k) {
          p(i) /= norm
          acc_w(i) += p(i)
          acc_mu(i) += x * p(i)
          acc_sigma(i) += x * new Transpose(x) * p(i)
        }
      })

      // Collect the computed sums
      val W = (0 until k).map(i => acc_w(i).value).toArray
      val MU = (0 until k).map(i => acc_mu(i).value).toArray
      val SIGMA = (0 until k).map(i => acc_sigma(i).value).toArray

      // The total responsibility mass is the same for every component, so it
      // is computed once instead of once per component.
      val totalWeight = W.sum

      // Create new distributions based on the partial assignments
      // (often referred to as the "M" step in literature)
      C = (0 until k).map(i => {
            val weight = W(i) / totalWeight
            val mu = MU(i) / W(i)
            val sigma = SIGMA(i) / W(i) - mu * new Transpose(mu)
            (weight, mu, sigma)
          }).toArray

      llhp = llh // current becomes previous
      llh = log_likelihood.value // this is the freshly computed log-likelihood
      iter += 1
    } while (iter < maxIterations && Math.abs(llh - llhp) > delta)

    // The cached copy was created by this method and is no longer needed;
    // release it so repeated runs do not accumulate cached RDDs.
    breezeData.unpersist()

    // Need to convert the breeze matrices to MLlib matrices
    val weights = (0 until k).map(i => C(i)._1).toArray
    val means = (0 until k).map(i => Vectors.fromBreeze(C(i)._2)).toArray
    val sigmas = (0 until k).map(i => Matrices.fromBreeze(C(i)._3)).toArray
    new GaussianMixtureModel(weights, means, sigmas)
  }

  /** AccumulatorParam for Dense Breeze Vectors */
  private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam[DenseDoubleVector] {
    /** A zero vector with the same length as `initialVector`. */
    def zero(initialVector: DenseDoubleVector): DenseDoubleVector = {
      BreezeVector.zeros[Double](initialVector.length)
    }

    /** Adds `b` into `a` in place and returns `a`. */
    def addInPlace(a: DenseDoubleVector, b: DenseDoubleVector): DenseDoubleVector = {
      a += b
    }
  }

  /** AccumulatorParam for Dense Breeze Matrices */
  private object DenseDoubleMatrixAccumulatorParam extends AccumulatorParam[DenseDoubleMatrix] {
    /** A zero matrix with the same dimensions as `initialVector`. */
    def zero(initialVector: DenseDoubleMatrix): DenseDoubleMatrix = {
      BreezeMatrix.zeros[Double](initialVector.rows, initialVector.cols)
    }

    /** Adds `b` into `a` in place and returns `a`. */
    def addInPlace(a: DenseDoubleMatrix, b: DenseDoubleMatrix): DenseDoubleMatrix = {
      a += b
    }
  }

  /**
   * Utility class to implement the density function for multivariate Gaussian distribution.
   * Breeze provides this functionality, but it requires the Apache Commons Math library,
   * so this class is here so as not to introduce a new dependency in Spark.
   *
   * NOTE(review): the inverse and determinant of `sigma` are taken directly, so a
   * singular covariance would presumably yield non-finite densities; confirm whether
   * callers need a guard.
   */
  private class MultivariateGaussian(val mu: DenseDoubleVector, val sigma: DenseDoubleMatrix)
      extends Serializable {
    // -1/2 * sigma^-1, precomputed once for the exponent of the density.
    private val sigma_inv_2 = inv(sigma) * -0.5
    // Normalizing constant: (2*pi)^(-d/2) * det(sigma)^(-1/2).
    private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)

    /** Density of this Gaussian evaluated at `x`. */
    def pdf(x: DenseDoubleVector): Double = {
      val delta = x - mu
      val delta_t = new Transpose(delta)
      U * math.exp(delta_t * sigma_inv_2 * delta)
    }
  }
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.clustering
19+
20+
import org.apache.spark.mllib.linalg.Matrix
21+
import org.apache.spark.mllib.linalg.Vector
22+
23+
/**
24+
* Multivariate Gaussian mixture model consisting of k Gaussians, where points are drawn
25+
* from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are the respective
26+
* mean and covariance for each Gaussian distribution i=1..k.
27+
*/
28+
class GaussianMixtureModel(
    val w: Array[Double],
    val mu: Array[Vector],
    val sigma: Array[Matrix]) {

  /** Size of the mixture, taken from the length of the weight array `w`. */
  def k: Int = {
    w.length
  }
}

0 commit comments

Comments
 (0)