1515# limitations under the License.
1616#
1717
18+ import struct
19+ import numpy
1820from numpy import ndarray , float64 , int64 , int32 , ones , array_equal , array , dot , shape , complex , issubdtype
1921from pyspark import SparkContext , RDD
20- import numpy as np
21-
22+ from pyspark .mllib .linalg import SparseVector
2223from pyspark .serializers import Serializer
23- import struct
2424
25- # Double vector format:
25+ # Dense double vector format:
2626#
2727# [1-byte 1] [4-byte length] [length*8 bytes of data]
2828#
29+ # Sparse double vector format:
30+ #
31+ # [8-byte 2] [8-byte size] [8-byte entries] [entries*4 bytes of indices] [entries*8 bytes of values]
32+ #
2933# Double matrix format:
3034#
31- # [8-byte 2 ] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
35+ # [1-byte 3] [4-byte rows] [4-byte cols] [rows*cols*8 bytes of data]
3236#
3337# This is all in machine-endian. That means that the Java interpreter and the
3438# Python interpreter must agree on what endian the machine is.
@@ -43,8 +47,7 @@ def _deserialize_byte_array(shape, ba, offset):
4347 >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
4448 True
4549 """
46- ar = ndarray (shape = shape , buffer = ba , offset = offset , dtype = "float64" ,
47- order = 'C' )
50+ ar = ndarray (shape = shape , buffer = ba , offset = offset , dtype = "float64" , order = 'C' )
4851 return ar .copy ()
4952
5053def _serialize_double_vector (v ):
@@ -58,21 +61,20 @@ def _serialize_double_vector(v):
5861 if type (v ) != ndarray :
5962 raise TypeError ("_serialize_double_vector called on a %s; "
6063 "wanted ndarray" % type (v ))
61- """complex is only datatype that can't be converted to float64"""
62- if issubdtype (v .dtype , complex ):
63- raise TypeError ("_serialize_double_vector called on a %s; "
64- "wanted ndarray" % type (v ))
65- if v .dtype != float64 :
66- v = v .astype (float64 )
6764 if v .ndim != 1 :
6865 raise TypeError ("_serialize_double_vector called on a %ddarray; "
6966 "wanted a 1darray" % v .ndim )
67+ if v .dtype != float64 :
68+ if numpy .issubdtype (v .dtype , numpy .complex ):
69+ raise TypeError ("_serialize_double_vector called on an ndarray of %s; "
70+ "wanted ndarray of float64" % v .dtype )
71+ v = v .astype ('float64' )
7072 length = v .shape [0 ]
71- ba = bytearray (16 + 8 * length )
72- header = ndarray ( shape = [ 2 ], buffer = ba , dtype = "int64" )
73- header [ 0 ] = 1
74- header [ 1 ] = length
75- arr_mid = ndarray (shape = [length ], buffer = ba , offset = 16 , dtype = "float64" )
73+ ba = bytearray (5 + 8 * length )
74+ ba [ 0 ] = 1
75+ length_bytes = ndarray ( shape = [ 1 ], buffer = ba , offset = 1 , dtype = "int32" )
76+ length_bytes [ 0 ] = length
77+ arr_mid = ndarray (shape = [length ], buffer = ba , offset = 5 , dtype = "float64" )
7678 arr_mid [...] = v
7779 return ba
7880
@@ -86,34 +88,34 @@ def _deserialize_double_vector(ba):
8688 if type (ba ) != bytearray :
8789 raise TypeError ("_deserialize_double_vector called on a %s; "
8890 "wanted bytearray" % type (ba ))
89- if len (ba ) < 16 :
91+ if len (ba ) < 5 :
9092 raise TypeError ("_deserialize_double_vector called on a %d-byte array, "
9193 "which is too short" % len (ba ))
92- if (len (ba ) & 7 ) != 0 :
93- raise TypeError ("_deserialize_double_vector called on a %d-byte array, "
94- "which is not a multiple of 8" % len (ba ))
95- header = ndarray (shape = [2 ], buffer = ba , dtype = "int64" )
96- if header [0 ] != 1 :
94+ if ba [0 ] != 1 :
9795 raise TypeError ("_deserialize_double_vector called on bytearray "
9896 "with wrong magic" )
99- length = header [ 1 ]
100- if len (ba ) != 8 * length + 16 :
97+ length = ndarray ( shape = [ 1 ], buffer = ba , offset = 1 , dtype = "int32" )[ 0 ]
98+ if len (ba ) != 8 * length + 5 :
10199 raise TypeError ("_deserialize_double_vector called on bytearray "
102100 "with wrong length" )
103- return _deserialize_byte_array ([length ], ba , 16 )
101+ return _deserialize_byte_array ([length ], ba , 5 )
104102
105103def _serialize_double_matrix (m ):
106104 """Serialize a double matrix into a mutually understood format."""
107- if (type (m ) == ndarray and m .dtype == float64 and m .ndim == 2 ):
105+ if (type (m ) == ndarray and m .ndim == 2 ):
106+ if m .dtype != float64 :
107+ if numpy .issubdtype (m .dtype , numpy .complex ):
108+ raise TypeError ("_serialize_double_matrix called on an ndarray of %s; "
109+ "wanted ndarray of float64" % m .dtype )
110+ m = m .astype ('float64' )
108111 rows = m .shape [0 ]
109112 cols = m .shape [1 ]
110- ba = bytearray (24 + 8 * rows * cols )
111- header = ndarray (shape = [3 ], buffer = ba , dtype = "int64" )
112- header [0 ] = 2
113- header [1 ] = rows
114- header [2 ] = cols
115- arr_mid = ndarray (shape = [rows , cols ], buffer = ba , offset = 24 ,
116- dtype = "float64" , order = 'C' )
113+ ba = bytearray (9 + 8 * rows * cols )
114+ ba [0 ] = 2
115+ lengths = ndarray (shape = [3 ], buffer = ba , offset = 1 , dtype = "int32" )
116+ lengths [0 ] = rows
117+ lengths [1 ] = cols
118+ arr_mid = ndarray (shape = [rows , cols ], buffer = ba , offset = 9 , dtype = "float64" , order = 'C' )
117119 arr_mid [...] = m
118120 return ba
119121 else :
@@ -125,22 +127,19 @@ def _deserialize_double_matrix(ba):
125127 if type (ba ) != bytearray :
126128 raise TypeError ("_deserialize_double_matrix called on a %s; "
127129 "wanted bytearray" % type (ba ))
128- if len (ba ) < 24 :
130+ if len (ba ) < 9 :
129131 raise TypeError ("_deserialize_double_matrix called on a %d-byte array, "
130132 "which is too short" % len (ba ))
131- if (len (ba ) & 7 ) != 0 :
132- raise TypeError ("_deserialize_double_matrix called on a %d-byte array, "
133- "which is not a multiple of 8" % len (ba ))
134- header = ndarray (shape = [3 ], buffer = ba , dtype = "int64" )
135- if (header [0 ] != 2 ):
133+ if ba [0 ] != 2 :
136134 raise TypeError ("_deserialize_double_matrix called on bytearray "
137135 "with wrong magic" )
138- rows = header [1 ]
139- cols = header [2 ]
140- if (len (ba ) != 8 * rows * cols + 24 ):
136+ lengths = ndarray (shape = [2 ], buffer = ba , offset = 1 , dtype = "int32" )
137+ rows = lengths [0 ]
138+ cols = lengths [1 ]
139+ if (len (ba ) != 8 * rows * cols + 9 ):
141140 raise TypeError ("_deserialize_double_matrix called on bytearray "
142141 "with wrong length" )
143- return _deserialize_byte_array ([rows , cols ], ba , 24 )
142+ return _deserialize_byte_array ([rows , cols ], ba , 9 )
144143
145144def _linear_predictor_typecheck (x , coeffs ):
146145 """Check that x is a one-dimensional vector of the right shape.
@@ -151,7 +150,7 @@ def _linear_predictor_typecheck(x, coeffs):
151150 pass
152151 else :
153152 raise RuntimeError ("Got array of %d elements; wanted %d"
154- % (shape (x )[0 ], shape (coeffs )[0 ]))
153+ % (numpy . shape (x )[0 ], numpy . shape (coeffs )[0 ]))
155154 else :
156155 raise RuntimeError ("Bulk predict not yet supported." )
157156 elif (type (x ) == RDD ):
@@ -187,7 +186,7 @@ def predict(self, x):
187186 """Predict the value of the dependent variable given a vector x"""
188187 """containing values for the independent variables."""
189188 _linear_predictor_typecheck (x , self ._coeff )
190- return dot (self ._coeff , x ) + self ._intercept
189+ return numpy . dot (self ._coeff , x ) + self ._intercept
191190
192191# If we weren't given initial weights, take a vector of ones of the
193192# appropriate length.
@@ -200,7 +199,7 @@ def _get_initial_weights(initial_weights, data):
200199 if initial_weights .ndim != 1 :
201200 raise TypeError ("At least one data element has "
202201 + initial_weights .ndim + " dimensions, which is not 1" )
203- initial_weights = ones ([initial_weights .shape [0 ] - 1 ])
202+ initial_weights = numpy . ones ([initial_weights .shape [0 ] - 1 ])
204203 return initial_weights
205204
206205# train_func should take two parameters, namely data and initial_weights, and
0 commit comments