Skip to content

Commit 881fef7

Browse files
committed
Added a sparse vector in Python and made Java-Python format more compact
1 parent 07d72fe commit 881fef7

File tree

3 files changed

+176
-64
lines changed

3 files changed

+176
-64
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,24 @@ import org.apache.spark.rdd.RDD
3131
/**
3232
* :: DeveloperApi ::
3333
* The Java stubs necessary for the Python mllib bindings.
34+
*
35+
* See mllib/python/pyspark._common.py for the mutually agreed upon data format.
3436
*/
3537
@DeveloperApi
3638
class PythonMLLibAPI extends Serializable {
3739
private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
3840
val packetLength = bytes.length
39-
if (packetLength < 16) {
41+
if (packetLength < 5) {
4042
throw new IllegalArgumentException("Byte array too short.")
4143
}
4244
val bb = ByteBuffer.wrap(bytes)
4345
bb.order(ByteOrder.nativeOrder())
44-
val magic = bb.getLong()
46+
val magic = bb.get()
4547
if (magic != 1) {
4648
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
4749
}
48-
val length = bb.getLong()
49-
if (packetLength != 16 + 8 * length) {
50+
val length = bb.getInt()
51+
if (packetLength != 5 + 8 * length) {
5052
throw new IllegalArgumentException("Length " + length + " is wrong.")
5153
}
5254
val db = bb.asDoubleBuffer()
@@ -57,30 +59,30 @@ class PythonMLLibAPI extends Serializable {
5759

5860
private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
5961
val len = doubles.length
60-
val bytes = new Array[Byte](16 + 8 * len)
62+
val bytes = new Array[Byte](5 + 8 * len)
6163
val bb = ByteBuffer.wrap(bytes)
6264
bb.order(ByteOrder.nativeOrder())
63-
bb.putLong(1)
64-
bb.putLong(len)
65+
bb.put(1: Byte)
66+
bb.putInt(len)
6567
val db = bb.asDoubleBuffer()
6668
db.put(doubles)
6769
bytes
6870
}
6971

7072
private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
7173
val packetLength = bytes.length
72-
if (packetLength < 24) {
74+
if (packetLength < 9) {
7375
throw new IllegalArgumentException("Byte array too short.")
7476
}
7577
val bb = ByteBuffer.wrap(bytes)
7678
bb.order(ByteOrder.nativeOrder())
77-
val magic = bb.getLong()
79+
val magic = bb.get()
7880
if (magic != 2) {
7981
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
8082
}
81-
val rows = bb.getLong()
82-
val cols = bb.getLong()
83-
if (packetLength != 24 + 8 * rows * cols) {
83+
val rows = bb.getInt()
84+
val cols = bb.getInt()
85+
if (packetLength != 9 + 8 * rows * cols) {
8486
throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
8587
}
8688
val db = bb.asDoubleBuffer()
@@ -98,12 +100,12 @@ class PythonMLLibAPI extends Serializable {
98100
if (rows > 0) {
99101
cols = doubles(0).length
100102
}
101-
val bytes = new Array[Byte](24 + 8 * rows * cols)
103+
val bytes = new Array[Byte](9 + 8 * rows * cols)
102104
val bb = ByteBuffer.wrap(bytes)
103105
bb.order(ByteOrder.nativeOrder())
104-
bb.putLong(2)
105-
bb.putLong(rows)
106-
bb.putLong(cols)
106+
bb.put(2: Byte)
107+
bb.putInt(rows)
108+
bb.putInt(cols)
107109
val db = bb.asDoubleBuffer()
108110
for (i <- 0 until rows) {
109111
db.put(doubles(i))

python/pyspark/mllib/_common.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,24 @@
1515
# limitations under the License.
1616
#
1717

18+
import struct
19+
import numpy
1820
from numpy import ndarray, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
1921
from pyspark import SparkContext, RDD
20-
import numpy as np
21-
22+
from pyspark.mllib.linalg import SparseVector
2223
from pyspark.serializers import Serializer
23-
import struct
2424

25-
# Double vector format:
25+
# Dense double vector format:
2626
#
2727
# [1-byte 1] [4-byte length] [length*8 bytes of data]
2828
#
29+
# Sparse double vector format:
30+
#
31+
# [1-byte 2] [4-byte size] [4-byte entries] [entries*4 bytes of indices] [entries*8 bytes of values]
32+
#
2933
# Double matrix format:
3034
#
31-
# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
35+
# [1-byte 3] [4-byte rows] [4-byte cols] [rows*cols*8 bytes of data]
3236
#
3337
# This is all in machine-endian. That means that the Java interpreter and the
3438
# Python interpreter must agree on what endian the machine is.
@@ -43,8 +47,7 @@ def _deserialize_byte_array(shape, ba, offset):
4347
>>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
4448
True
4549
"""
46-
ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64",
47-
order='C')
50+
ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", order='C')
4851
return ar.copy()
4952

5053
def _serialize_double_vector(v):
@@ -58,21 +61,20 @@ def _serialize_double_vector(v):
5861
if type(v) != ndarray:
5962
raise TypeError("_serialize_double_vector called on a %s; "
6063
"wanted ndarray" % type(v))
61-
"""complex is only datatype that can't be converted to float64"""
62-
if issubdtype(v.dtype, complex):
63-
raise TypeError("_serialize_double_vector called on a %s; "
64-
"wanted ndarray" % type(v))
65-
if v.dtype != float64:
66-
v = v.astype(float64)
6764
if v.ndim != 1:
6865
raise TypeError("_serialize_double_vector called on a %ddarray; "
6966
"wanted a 1darray" % v.ndim)
67+
if v.dtype != float64:
68+
if numpy.issubdtype(v.dtype, numpy.complex):
69+
raise TypeError("_serialize_double_vector called on an ndarray of %s; "
70+
"wanted ndarray of float64" % v.dtype)
71+
v = v.astype('float64')
7072
length = v.shape[0]
71-
ba = bytearray(16 + 8*length)
72-
header = ndarray(shape=[2], buffer=ba, dtype="int64")
73-
header[0] = 1
74-
header[1] = length
75-
arr_mid = ndarray(shape=[length], buffer=ba, offset=16, dtype="float64")
73+
ba = bytearray(5 + 8 * length)
74+
ba[0] = 1
75+
length_bytes = ndarray(shape=[1], buffer=ba, offset=1, dtype="int32")
76+
length_bytes[0] = length
77+
arr_mid = ndarray(shape=[length], buffer=ba, offset=5, dtype="float64")
7678
arr_mid[...] = v
7779
return ba
7880

@@ -86,34 +88,34 @@ def _deserialize_double_vector(ba):
8688
if type(ba) != bytearray:
8789
raise TypeError("_deserialize_double_vector called on a %s; "
8890
"wanted bytearray" % type(ba))
89-
if len(ba) < 16:
91+
if len(ba) < 5:
9092
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
9193
"which is too short" % len(ba))
92-
if (len(ba) & 7) != 0:
93-
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
94-
"which is not a multiple of 8" % len(ba))
95-
header = ndarray(shape=[2], buffer=ba, dtype="int64")
96-
if header[0] != 1:
94+
if ba[0] != 1:
9795
raise TypeError("_deserialize_double_vector called on bytearray "
9896
"with wrong magic")
99-
length = header[1]
100-
if len(ba) != 8*length + 16:
97+
length = ndarray(shape=[1], buffer=ba, offset=1, dtype="int32")[0]
98+
if len(ba) != 8*length + 5:
10199
raise TypeError("_deserialize_double_vector called on bytearray "
102100
"with wrong length")
103-
return _deserialize_byte_array([length], ba, 16)
101+
return _deserialize_byte_array([length], ba, 5)
104102

105103
def _serialize_double_matrix(m):
106104
"""Serialize a double matrix into a mutually understood format."""
107-
if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2):
105+
if (type(m) == ndarray and m.ndim == 2):
106+
if m.dtype != float64:
107+
if numpy.issubdtype(m.dtype, numpy.complex):
108+
raise TypeError("_serialize_double_matrix called on an ndarray of %s; "
109+
"wanted ndarray of float64" % m.dtype)
110+
m = m.astype('float64')
108111
rows = m.shape[0]
109112
cols = m.shape[1]
110-
ba = bytearray(24 + 8 * rows * cols)
111-
header = ndarray(shape=[3], buffer=ba, dtype="int64")
112-
header[0] = 2
113-
header[1] = rows
114-
header[2] = cols
115-
arr_mid = ndarray(shape=[rows, cols], buffer=ba, offset=24,
116-
dtype="float64", order='C')
113+
ba = bytearray(9 + 8 * rows * cols)
114+
ba[0] = 2
115+
lengths = ndarray(shape=[3], buffer=ba, offset=1, dtype="int32")
116+
lengths[0] = rows
117+
lengths[1] = cols
118+
arr_mid = ndarray(shape=[rows, cols], buffer=ba, offset=9, dtype="float64", order='C')
117119
arr_mid[...] = m
118120
return ba
119121
else:
@@ -125,22 +127,19 @@ def _deserialize_double_matrix(ba):
125127
if type(ba) != bytearray:
126128
raise TypeError("_deserialize_double_matrix called on a %s; "
127129
"wanted bytearray" % type(ba))
128-
if len(ba) < 24:
130+
if len(ba) < 9:
129131
raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
130132
"which is too short" % len(ba))
131-
if (len(ba) & 7) != 0:
132-
raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
133-
"which is not a multiple of 8" % len(ba))
134-
header = ndarray(shape=[3], buffer=ba, dtype="int64")
135-
if (header[0] != 2):
133+
if ba[0] != 2:
136134
raise TypeError("_deserialize_double_matrix called on bytearray "
137135
"with wrong magic")
138-
rows = header[1]
139-
cols = header[2]
140-
if (len(ba) != 8*rows*cols + 24):
136+
lengths = ndarray(shape=[2], buffer=ba, offset=1, dtype="int32")
137+
rows = lengths[0]
138+
cols = lengths[1]
139+
if (len(ba) != 8 * rows * cols + 9):
141140
raise TypeError("_deserialize_double_matrix called on bytearray "
142141
"with wrong length")
143-
return _deserialize_byte_array([rows, cols], ba, 24)
142+
return _deserialize_byte_array([rows, cols], ba, 9)
144143

145144
def _linear_predictor_typecheck(x, coeffs):
146145
"""Check that x is a one-dimensional vector of the right shape.
@@ -151,7 +150,7 @@ def _linear_predictor_typecheck(x, coeffs):
151150
pass
152151
else:
153152
raise RuntimeError("Got array of %d elements; wanted %d"
154-
% (shape(x)[0], shape(coeffs)[0]))
153+
% (numpy.shape(x)[0], numpy.shape(coeffs)[0]))
155154
else:
156155
raise RuntimeError("Bulk predict not yet supported.")
157156
elif (type(x) == RDD):
@@ -187,7 +186,7 @@ def predict(self, x):
187186
"""Predict the value of the dependent variable given a vector x"""
188187
"""containing values for the independent variables."""
189188
_linear_predictor_typecheck(x, self._coeff)
190-
return dot(self._coeff, x) + self._intercept
189+
return numpy.dot(self._coeff, x) + self._intercept
191190

192191
# If we weren't given initial weights, take a zero vector of the appropriate
193192
# length.
@@ -200,7 +199,7 @@ def _get_initial_weights(initial_weights, data):
200199
if initial_weights.ndim != 1:
201200
raise TypeError("At least one data element has "
202201
+ initial_weights.ndim + " dimensions, which is not 1")
203-
initial_weights = ones([initial_weights.shape[0] - 1])
202+
initial_weights = numpy.ones([initial_weights.shape[0] - 1])
204203
return initial_weights
205204

206205
# train_func should take two parameters, namely data and initial_weights, and

python/pyspark/mllib/linalg.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
MLlib utilities for working with vectors. For dense vectors, MLlib
20+
uses the NumPy C{array} type, so you can simply pass NumPy arrays
21+
around. For sparse vectors, users can construct a L{SparseVector}
22+
object from MLlib or pass SciPy C{scipy.sparse} column vectors if
23+
SciPy is available in their environment.
24+
"""
25+
26+
from numpy import array
27+
28+
29+
class SparseVector(object):
    """
    A simple sparse vector class for passing data to MLlib. Users may
    alternatively pass SciPy's C{scipy.sparse} data types.
    """

    def __init__(self, size, *args):
        """
        Create a sparse vector, using either an array of (index, value) pairs
        or two separate arrays of indices and values.

        :param size: dimension of the full (dense) vector
        :param args: either one iterable of (index, value) pairs, or two
                     parallel iterables of indices and values

        >>> print(SparseVector(4, [(1, 1.0), (3, 5.5)]))
        [1: 1.0, 3: 5.5]
        >>> print(SparseVector(4, [1, 3], [1.0, 5.5]))
        [1: 1.0, 3: 5.5]
        """
        self.size = size
        assert 1 <= len(args) <= 2, "must pass either 1 or 2 arguments"
        if len(args) == 1:
            # Single argument: an iterable of (index, value) pairs.
            pairs = args[0]
            self.indices = array([p[0] for p in pairs], dtype='int32')
            self.values = array([p[1] for p in pairs], dtype='float64')
        else:
            # Two arguments: parallel arrays of indices and values.
            assert len(args[0]) == len(args[1]), "index and value arrays not same length"
            self.indices = array(args[0], dtype='int32')
            self.values = array(args[1], dtype='float64')

    def __str__(self):
        """Return a compact "[index: value, ...]" rendering of the entries."""
        inds = self.indices
        vals = self.values
        # range() instead of Python 2's xrange(): identical behavior here,
        # and it also works on Python 3 where xrange no longer exists.
        entries = ", ".join(["{0}: {1}".format(inds[i], vals[i]) for i in range(len(inds))])
        return "[" + entries + "]"

    def __repr__(self):
        """Return a constructor-style representation of this vector."""
        inds = self.indices
        vals = self.values
        entries = ", ".join(["({0}, {1})".format(inds[i], vals[i]) for i in range(len(inds))])
        return "SparseVector({0}, [{1}])".format(self.size, entries)
67+
68+
69+
70+
class Vectors(object):
    """
    Factory methods to create MLlib vectors. Note that dense vectors
    are simply represented as NumPy array objects, so there is no need
    to convert them for use in MLlib. For sparse vectors, the factory
    methods in this class create an MLlib-compatible type, or users
    can pass in SciPy's C{scipy.sparse} column vectors.
    """

    @staticmethod
    def sparse(size, *args):
        """
        Create a sparse vector, using either an array of (index, value) pairs
        or two separate arrays of indices and values.

        >>> print(Vectors.sparse(4, [(1, 1.0), (3, 5.5)]))
        [1: 1.0, 3: 5.5]
        >>> print(Vectors.sparse(4, [1, 3], [1.0, 5.5]))
        [1: 1.0, 3: 5.5]
        """
        return SparseVector(size, *args)

    @staticmethod
    def dense(elements):
        """
        Create a dense vector of 64-bit floats from a Python list. Always
        returns a NumPy array.

        >>> Vectors.dense([1, 2, 3])
        array([1., 2., 3.])
        """
        return array(elements, dtype='float64')
102+
103+
104+
def _test():
    """Run this module's doctests; exit with a nonzero status on failure."""
    import doctest
    results = doctest.testmod(optionflags=doctest.ELLIPSIS)
    if results.failed:
        exit(-1)


if __name__ == "__main__":
    _test()

0 commit comments

Comments
 (0)