@@ -39,23 +39,24 @@ class PythonMLLibAPI extends Serializable {
3939 private val DENSE_VECTOR_MAGIC : Byte = 1
4040 private val SPARSE_VECTOR_MAGIC : Byte = 2
4141 private val DENSE_MATRIX_MAGIC : Byte = 3
42+ private val LABELED_POINT_MAGIC : Byte = 4
4243
43- private def deserializeDoubleVector (bytes : Array [Byte ]): Vector = {
44- require(bytes.length >= 5 , " Byte array too short" )
45- val magic = bytes(0 )
44+ private def deserializeDoubleVector (bytes : Array [Byte ], offset : Int = 0 ): Vector = {
45+ require(bytes.length - offset >= 5 , " Byte array too short" )
46+ val magic = bytes(offset )
4647 if (magic == DENSE_VECTOR_MAGIC ) {
47- deserializeDenseVector(bytes)
48+ deserializeDenseVector(bytes, offset )
4849 } else if (magic == SPARSE_VECTOR_MAGIC ) {
49- deserializeSparseVector(bytes)
50+ deserializeSparseVector(bytes, offset )
5051 } else {
5152 throw new IllegalArgumentException (" Magic " + magic + " is wrong." )
5253 }
5354 }
5455
55- private def deserializeDenseVector (bytes : Array [Byte ]): Vector = {
56- val packetLength = bytes.length
56+ private def deserializeDenseVector (bytes : Array [Byte ], offset : Int = 0 ): Vector = {
57+ val packetLength = bytes.length - offset
5758 require(packetLength >= 5 , " Byte array too short" )
58- val bb = ByteBuffer .wrap(bytes)
59+ val bb = ByteBuffer .wrap(bytes, offset, bytes.length - offset )
5960 bb.order(ByteOrder .nativeOrder())
6061 val magic = bb.get()
6162 require(magic == DENSE_VECTOR_MAGIC , " Invalid magic: " + magic)
@@ -67,10 +68,10 @@ class PythonMLLibAPI extends Serializable {
6768 Vectors .dense(ans)
6869 }
6970
70- private def deserializeSparseVector (bytes : Array [Byte ]): Vector = {
71- val packetLength = bytes.length
71+ private def deserializeSparseVector (bytes : Array [Byte ], offset : Int = 0 ): Vector = {
72+ val packetLength = bytes.length - offset
7273 require(packetLength >= 9 , " Byte array too short" )
73- val bb = ByteBuffer .wrap(bytes)
74+ val bb = ByteBuffer .wrap(bytes, offset, bytes.length - offset )
7475 bb.order(ByteOrder .nativeOrder())
7576 val magic = bb.get()
7677 require(magic == SPARSE_VECTOR_MAGIC , " Invalid magic: " + magic)
@@ -166,14 +167,23 @@ class PythonMLLibAPI extends Serializable {
166167 bytes
167168 }
168169
170+ private def deserializeLabeledPoint (bytes : Array [Byte ]): LabeledPoint = {
171+ require(bytes.length >= 9 , " Byte array too short" )
172+ val magic = bytes(0 )
173+ if (magic != LABELED_POINT_MAGIC ) {
174+ throw new IllegalArgumentException (" Magic " + magic + " is wrong." )
175+ }
176+ val labelBytes = ByteBuffer .wrap(bytes, 1 , 8 )
177+ labelBytes.order(ByteOrder .nativeOrder())
178+ val label = labelBytes.asDoubleBuffer().get(0 )
179+ LabeledPoint (label, deserializeDoubleVector(bytes, 9 ))
180+ }
181+
169182 private def trainRegressionModel (
170183 trainFunc : (RDD [LabeledPoint ], Vector ) => GeneralizedLinearModel ,
171184 dataBytesJRDD : JavaRDD [Array [Byte ]],
172185 initialWeightsBA : Array [Byte ]): java.util.LinkedList [java.lang.Object ] = {
173- val data = dataBytesJRDD.rdd.map(xBytes => {
174- val x = deserializeDoubleVector(xBytes)
175- LabeledPoint (x(0 ), x.slice(1 , x.size))
176- })
186+ val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
177187 val initialWeights = deserializeDoubleVector(initialWeightsBA)
178188 val model = trainFunc(data, initialWeights)
179189 val ret = new java.util.LinkedList [java.lang.Object ]()
@@ -299,10 +309,7 @@ class PythonMLLibAPI extends Serializable {
299309 def trainNaiveBayes (
300310 dataBytesJRDD : JavaRDD [Array [Byte ]],
301311 lambda : Double ): java.util.List [java.lang.Object ] = {
302- val data = dataBytesJRDD.rdd.map(xBytes => {
303- val x = deserializeDoubleVector(xBytes)
304- LabeledPoint (x(0 ), x.slice(1 , x.size))
305- })
312+ val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
306313 val model = NaiveBayes .train(data, lambda)
307314 val ret = new java.util.LinkedList [java.lang.Object ]()
308315 ret.add(serializeDoubleVector(Vectors .dense(model.labels)))
@@ -320,7 +327,7 @@ class PythonMLLibAPI extends Serializable {
320327 maxIterations : Int ,
321328 runs : Int ,
322329 initializationMode : String ): java.util.List [java.lang.Object ] = {
323- val data = dataBytesJRDD.rdd.map(deserializeDoubleVector)
330+ val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes) )
324331 val model = KMeans .train(data, k, maxIterations, runs, initializationMode)
325332 val ret = new java.util.LinkedList [java.lang.Object ]()
326333 ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
0 commit comments