Skip to content

Commit 74eefe7

Browse files
committed
Added LabeledPoint class in Python
1 parent 889dde8 commit 74eefe7

File tree

5 files changed

+261
-127
lines changed

5 files changed

+261
-127
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,24 @@ class PythonMLLibAPI extends Serializable {
3939
private val DENSE_VECTOR_MAGIC: Byte = 1
4040
private val SPARSE_VECTOR_MAGIC: Byte = 2
4141
private val DENSE_MATRIX_MAGIC: Byte = 3
42+
private val LABELED_POINT_MAGIC: Byte = 4
4243

43-
/**
 * Deserializes a vector from the Python byte-level format, dispatching on the
 * magic byte at `offset`.
 *
 * @param bytes  serialized payload produced by the Python side
 * @param offset index of the magic byte within `bytes` (defaults to 0 so
 *               existing single-vector callers are unchanged)
 * @return the decoded dense or sparse [[Vector]]
 * @throws IllegalArgumentException if the payload is too short or the magic
 *                                  byte is neither dense nor sparse
 */
private def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
  // At minimum: 1 magic byte + a 4-byte length field.
  require(bytes.length - offset >= 5, "Byte array too short")
  // DENSE_VECTOR_MAGIC / SPARSE_VECTOR_MAGIC start with an uppercase letter,
  // so the pattern matcher treats them as constants, not bindings.
  bytes(offset) match {
    case DENSE_VECTOR_MAGIC  => deserializeDenseVector(bytes, offset)
    case SPARSE_VECTOR_MAGIC => deserializeSparseVector(bytes, offset)
    case magic =>
      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
  }
}
5455

55-
private def deserializeDenseVector(bytes: Array[Byte]): Vector = {
56-
val packetLength = bytes.length
56+
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
57+
val packetLength = bytes.length - offset
5758
require(packetLength >= 5, "Byte array too short")
58-
val bb = ByteBuffer.wrap(bytes)
59+
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
5960
bb.order(ByteOrder.nativeOrder())
6061
val magic = bb.get()
6162
require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
@@ -67,10 +68,10 @@ class PythonMLLibAPI extends Serializable {
6768
Vectors.dense(ans)
6869
}
6970

70-
private def deserializeSparseVector(bytes: Array[Byte]): Vector = {
71-
val packetLength = bytes.length
71+
private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
72+
val packetLength = bytes.length - offset
7273
require(packetLength >= 9, "Byte array too short")
73-
val bb = ByteBuffer.wrap(bytes)
74+
val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
7475
bb.order(ByteOrder.nativeOrder())
7576
val magic = bb.get()
7677
require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
@@ -166,14 +167,23 @@ class PythonMLLibAPI extends Serializable {
166167
bytes
167168
}
168169

170+
/**
 * Deserializes a [[LabeledPoint]] from the Python byte-level format.
 *
 * Layout: byte 0 is LABELED_POINT_MAGIC, bytes 1-8 hold the label as a
 * native-byte-order double, and the bytes from index 9 onward hold the
 * serialized feature vector, decoded by deserializeDoubleVector.
 *
 * @param bytes serialized labeled point produced by the Python side
 * @return the decoded label and feature vector
 * @throws IllegalArgumentException if the payload is too short or the magic
 *                                  byte does not mark a labeled point
 */
private def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
  // Enough bytes for the magic (1) plus the label (8); the vector tail is
  // length-checked by deserializeDoubleVector itself.
  require(bytes.length >= 9, "Byte array too short")
  val magic = bytes(0)
  if (magic != LABELED_POINT_MAGIC) {
    throw new IllegalArgumentException("Magic " + magic + " is wrong.")
  }
  // View exactly the 8 label bytes that follow the magic byte.
  val labelBuffer = ByteBuffer.wrap(bytes, 1, 8)
  labelBuffer.order(ByteOrder.nativeOrder())
  // getDouble reads from the buffer's position (1), matching
  // asDoubleBuffer().get(0) on the same window.
  LabeledPoint(labelBuffer.getDouble, deserializeDoubleVector(bytes, 9))
}
181+
169182
private def trainRegressionModel(
170183
trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
171184
dataBytesJRDD: JavaRDD[Array[Byte]],
172185
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
173-
val data = dataBytesJRDD.rdd.map(xBytes => {
174-
val x = deserializeDoubleVector(xBytes)
175-
LabeledPoint(x(0), x.slice(1, x.size))
176-
})
186+
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
177187
val initialWeights = deserializeDoubleVector(initialWeightsBA)
178188
val model = trainFunc(data, initialWeights)
179189
val ret = new java.util.LinkedList[java.lang.Object]()
@@ -299,10 +309,7 @@ class PythonMLLibAPI extends Serializable {
299309
def trainNaiveBayes(
300310
dataBytesJRDD: JavaRDD[Array[Byte]],
301311
lambda: Double): java.util.List[java.lang.Object] = {
302-
val data = dataBytesJRDD.rdd.map(xBytes => {
303-
val x = deserializeDoubleVector(xBytes)
304-
LabeledPoint(x(0), x.slice(1, x.size))
305-
})
312+
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
306313
val model = NaiveBayes.train(data, lambda)
307314
val ret = new java.util.LinkedList[java.lang.Object]()
308315
ret.add(serializeDoubleVector(Vectors.dense(model.labels)))
@@ -320,7 +327,7 @@ class PythonMLLibAPI extends Serializable {
320327
maxIterations: Int,
321328
runs: Int,
322329
initializationMode: String): java.util.List[java.lang.Object] = {
323-
val data = dataBytesJRDD.rdd.map(deserializeDoubleVector)
330+
val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes))
324331
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
325332
val ret = new java.util.LinkedList[java.lang.Object]()
326333
ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))

0 commit comments

Comments (0)