Skip to content

Commit beda3f4

Browse files
author
Fabian Pedregosa
committed
FIX: inheritance in DenseBaseSVM
Sequel for d079dde, including a test
1 parent d079dde commit beda3f4

File tree

2 files changed

+33
-19
lines changed

2 files changed

+33
-19
lines changed

sklearn/svm/base.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -323,18 +323,20 @@ def _dense_predict(self, X):
323323
"the number of features at training time" %
324324
(n_features, self.shape_fit_[1]))
325325

326-
params = self.get_params()
327-
if 'scale_C' in params:
328-
del params['scale_C']
329-
if "sparse" in params:
330-
del params["sparse"]
326+
epsilon = self.epsilon
327+
if epsilon == None:
328+
epsilon = 0.1
331329

332330
svm_type = LIBSVM_IMPL.index(self.impl)
333331
return libsvm.predict(
334332
X, self.support_, self.support_vectors_, self.n_support_,
335333
self.dual_coef_, self.intercept_,
336334
self.label_, self.probA_, self.probB_,
337-
svm_type=svm_type, **params)
335+
svm_type=svm_type,
336+
kernel=self.kernel, C=self.C, nu=self.nu,
337+
probability=self.probability, degree=self.degree,
338+
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
339+
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
338340

339341
def _sparse_predict(self, X):
340342
X = sp.csr_matrix(X, dtype=np.float64)
@@ -393,18 +395,19 @@ def predict_proba(self, X):
393395
def _dense_predict_proba(self, X):
394396
X = self._compute_kernel(X)
395397

396-
params = self.get_params()
397-
if 'scale_C' in params:
398-
del params['scale_C']
399-
if "sparse" in params:
400-
del params["sparse"]
398+
epsilon = self.epsilon
399+
if epsilon == None:
400+
epsilon = 0.1
401401

402402
svm_type = LIBSVM_IMPL.index(self.impl)
403403
pprob = libsvm.predict_proba(
404404
X, self.support_, self.support_vectors_, self.n_support_,
405405
self.dual_coef_, self.intercept_, self.label_,
406406
self.probA_, self.probB_,
407-
svm_type=svm_type, **params)
407+
svm_type=svm_type, kernel=self.kernel, C=self.C, nu=self.nu,
408+
probability=self.probability, degree=self.degree,
409+
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
410+
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
408411

409412
return pprob
410413

@@ -478,18 +481,18 @@ def decision_function(self, X):
478481

479482
X = array2d(X, dtype=np.float64, order="C")
480483

481-
params = self.get_params()
482-
if 'scale_C' in params:
483-
del params['scale_C']
484-
if "sparse" in params:
485-
del params["sparse"]
486-
484+
epsilon = self.epsilon
485+
if epsilon == None:
486+
epsilon = 0.1
487487
dec_func = libsvm.decision_function(
488488
X, self.support_, self.support_vectors_, self.n_support_,
489489
self.dual_coef_, self.intercept_, self.label_,
490490
self.probA_, self.probB_,
491491
svm_type=LIBSVM_IMPL.index(self.impl),
492-
**params)
492+
kernel=self.kernel, C=self.C, nu=self.nu,
493+
probability=self.probability, degree=self.degree,
494+
shrinking=self.shrinking, tol=self.tol, cache_size=self.cache_size,
495+
coef0=self.coef0, gamma=self.gamma, epsilon=epsilon)
493496

494497
return dec_func
495498

sklearn/svm/tests/test_svm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,17 @@ def test_immutable_coef_property():
561561
assert_raises(AttributeError, clf.__setattr__, 'coef_', np.arange(3))
562562
assert_raises(RuntimeError, clf.coef_.__setitem__, (0, 0), 0)
563563

564+
def test_inheritance():
565+
# check that SVC classes can do inheritance
566+
class ChildSVC(svm.SVC):
567+
def __init__(self, foo=0):
568+
self.foo = foo
569+
svm.SVC.__init__(self)
570+
571+
clf = ChildSVC()
572+
clf.fit(iris.data, iris.target)
573+
clf.predict(iris.data[-1])
574+
clf.decision_function(iris.data[-1])
564575

565576
if __name__ == '__main__':
566577
import nose

0 commit comments

Comments
 (0)