Skip to content

Commit a07ba10

Browse files
committed
Fix some typos and calculation of initial weights
1 parent 74eefe7 commit a07ba10

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

python/pyspark/mllib/_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,9 @@ def _get_initial_weights(initial_weights, data):
329329
if initial_weights.ndim != 1:
330330
raise TypeError("At least one data element has "
331331
+ initial_weights.ndim + " dimensions, which is not 1")
332-
initial_weights = numpy.ones([initial_weights.shape[0] - 1])
332+
initial_weights = numpy.ones([initial_weights.shape[0]])
333333
elif type(initial_weights) == SparseVector:
334-
initial_weights = numpy.ones([initial_weights.size - 1])
334+
initial_weights = numpy.ones([initial_weights.size])
335335
return initial_weights
336336

337337

python/pyspark/mllib/classification.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
_serialize_double_matrix, _deserialize_double_matrix, \
2525
_serialize_double_vector, _deserialize_double_vector, \
2626
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \
27-
LinearModel, _linear_predictor_typecheck
27+
LinearModel, _linear_predictor_typecheck, _get_unmangled_labeled_point_rdd
2828
from pyspark.mllib.linalg import SparseVector
2929
from pyspark.mllib.regression import LabeledPoint
3030
from math import exp, log
@@ -135,9 +135,9 @@ class NaiveBayesModel(object):
135135
>>> model.predict(array([1.0, 0.0]))
136136
1.0
137137
>>> sparse_data = [
138-
... LabeledPoint(0.0, SparseVector(2, {1: 0.0}),
139-
... LabeledPoint(0.0, SparseVector(2, {1: 1.0}),
140-
... LabeledPoint(1.0, SparseVector(2, {0: 1.0})
138+
... LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
139+
... LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
140+
... LabeledPoint(1.0, SparseVector(2, {0: 1.0}))
141141
... ]
142142
>>> model = NaiveBayes.train(sc.parallelize(sparse_data))
143143
>>> model.predict(SparseVector(2, {1: 1.0}))
@@ -173,7 +173,7 @@ def train(cls, data, lambda_=1.0):
173173
@param lambda_: The smoothing parameter
174174
"""
175175
sc = data.context
176-
dataBytes = _get_unmangled_double_vector_rdd(data)
176+
dataBytes = _get_unmangled_labeled_point_rdd(data)
177177
ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_)
178178
return NaiveBayesModel(
179179
_deserialize_double_vector(ans[0]),

0 commit comments

Comments
 (0)