Merged
34 changes: 26 additions & 8 deletions sklearn/dummy.py
@@ -14,6 +14,7 @@
from .utils.validation import check_consistent_length
from .utils import deprecated
from .utils.random import random_choice_csc
from .utils.stats import _weighted_percentile
from .utils.multiclass import class_distribution


@@ -366,7 +367,7 @@ def y_mean_(self):
return self.constant_
raise AttributeError

def fit(self, X, y):
def fit(self, X, y, sample_weight=None):
"""Fit the random regressor.

Parameters
@@ -378,6 +379,9 @@ def fit(self, X, y):
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
Target values.

sample_weight : array-like of shape = [n_samples], optional
Sample weights.

Returns
-------
self : object
@@ -389,25 +393,40 @@
"'mean', 'median', 'quantile' or 'constant'"
% self.strategy)

y = check_array(y, accept_sparse='csr', ensure_2d=False)
y = check_array(y, ensure_2d=False)
Member Author:
I removed accept_sparse='csr' since it's not supported.

if len(y) == 0:
raise ValueError("y must not be empty.")
self.output_2d_ = (y.ndim == 2)

check_consistent_length(X, y)
self.output_2d_ = y.ndim == 2
if y.ndim == 1:
y = np.reshape(y, (-1, 1))
Member:
Noob question: out of curiosity, is there any difference between doing this and

`y = y[:, np.newaxis]`

I always use the latter.

Member Author:
`y = y[:, np.newaxis]` doesn't preserve contiguity. This is a known bug and will be solved in the current or next release of numpy.

Member:
oh yes, I remember now :)

self.n_outputs_ = y.shape[1]

check_consistent_length(X, y, sample_weight)

if self.strategy == "mean":
self.constant_ = np.mean(y, axis=0)
self.constant_ = np.average(y, axis=0, weights=sample_weight)

elif self.strategy == "median":
self.constant_ = np.median(y, axis=0)
if sample_weight is None:
self.constant_ = np.median(y, axis=0)
else:
self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
percentile=50.)
for k in range(self.n_outputs_)]

Member:
self.constant_ = np.asarray(.... ) ?

Member Author:
I think this is handled by the reshape. I will check tomorrow.

Member:
yes, sorry I looked at only the diff. I take back my comment.

Member:
Just a nitpick: why not set it to np.empty initially and then fill it? That way the reshape can be avoided.

Member Author:
The reshape must still be done for the np.mean and np.median cases. I think the keepdims keyword is not yet available in the lowest numpy version we support.

Member:
Sorry about my pretentious comments then ;)

Member Author:
No worry :-)

elif self.strategy == "quantile":
if self.quantile is None or not np.isscalar(self.quantile):
raise ValueError("Quantile must be a scalar in the range "
"[0.0, 1.0], but got %s." % self.quantile)

self.constant_ = np.percentile(y, axis=0, q=self.quantile * 100.0)
percentile = self.quantile * 100.0
if sample_weight is None:
self.constant_ = np.percentile(y, axis=0, q=percentile)
else:
self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
percentile=percentile)
for k in range(self.n_outputs_)]

elif self.strategy == "constant":
if self.constant is None:
@@ -426,7 +445,6 @@ def fit(self, X, y):
self.constant_ = self.constant

self.constant_ = np.reshape(self.constant_, (1, -1))
self.n_outputs_ = np.size(self.constant_) # y.shape[1] is not safe
return self

def predict(self, X):
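A quick pure-numpy sketch of what the updated `fit` now computes for each strategy. The `y` and `w` values are made up, and `weighted_percentile` below just mirrors the logic of the `_weighted_percentile` helper inline so the snippet is self-contained:

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 1.0, 2.0])  # made-up sample weights

# strategy="mean": weighted average of the targets
mean_const = np.average(y, weights=w)  # (1*1 + 1*2 + 2*3) / 4 = 2.25

# strategy="median" / "quantile": weighted percentile
# (same logic as the _weighted_percentile helper)
def weighted_percentile(a, sw, percentile=50.0):
    sorted_idx = np.argsort(a)
    weight_cdf = sw[sorted_idx].cumsum()
    idx = (weight_cdf >= (percentile / 100.0) * weight_cdf[-1]).argmax()
    return a[sorted_idx[idx]]

median_const = weighted_percentile(y, w, percentile=50.0)  # → 2.0
```

The weighted median picks the first target whose cumulative weight reaches half of the total weight, which is why `2.0` (cumulative weight 2 out of 4) is returned here rather than the unweighted median.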
13 changes: 1 addition & 12 deletions sklearn/ensemble/gradient_boosting.py
@@ -37,6 +37,7 @@
from ..base import RegressorMixin
from ..utils import check_random_state, check_array, check_X_y, column_or_1d
from ..utils.extmath import logsumexp
from ..utils.stats import _weighted_percentile
from ..externals import six
from ..feature_selection.from_model import _LearntSelectorMixin

@@ -50,18 +51,6 @@
from ._gradient_boosting import _random_sample_mask


def _weighted_percentile(array, sample_weight, percentile=50):
"""Compute the weighted ``percentile`` of ``array`` with ``sample_weight``. """
sorted_idx = np.argsort(array)

# Find index of median prediction for each sample
weight_cdf = sample_weight[sorted_idx].cumsum()
percentile_or_above = weight_cdf >= (percentile / 100.0) * weight_cdf[-1]
percentile_idx = percentile_or_above.argmax()

return array[sorted_idx[percentile_idx]]


class QuantileEstimator(BaseEstimator):
"""An estimator predicting the alpha-quantile of the training targets."""
def __init__(self, alpha=0.9):
19 changes: 19 additions & 0 deletions sklearn/tests/test_dummy.py
@@ -11,6 +11,7 @@
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.stats import _weighted_percentile

from sklearn.dummy import DummyClassifier, DummyRegressor

@@ -572,6 +573,24 @@ def test_most_frequent_strategy_sparse_target():
np.zeros((n_samples, 1))]))


def test_dummy_regressor_sample_weight(n_samples=10):
random_state = np.random.RandomState(seed=1)

X = [[0]] * n_samples
Member:
Would it be better to generate X randomly? Just for a sanity check.

y = random_state.rand(n_samples)
sample_weight = random_state.rand(n_samples)

est = DummyRegressor(strategy="mean").fit(X, y, sample_weight)
assert_equal(est.constant_, np.average(y, weights=sample_weight))

est = DummyRegressor(strategy="median").fit(X, y, sample_weight)
assert_equal(est.constant_, _weighted_percentile(y, sample_weight, 50.))

est = DummyRegressor(strategy="quantile", quantile=.95).fit(X, y,
sample_weight)
assert_equal(est.constant_, _weighted_percentile(y, sample_weight, 95.))


if __name__ == '__main__':
import nose
nose.runmodule()
13 changes: 13 additions & 0 deletions sklearn/utils/stats.py
@@ -44,3 +44,16 @@ def _rankdata(a, method="average"):

except TypeError as e:
rankdata = _rankdata


def _weighted_percentile(array, sample_weight, percentile=50):
Member:
Any reason why this is private?

Member Author:
Because it is in utils.

Member:
Ok, I was confused since there are many functions in utils which are not private.

"""Compute the weighted ``percentile`` of ``array`` with ``sample_weight``. """
sorted_idx = np.argsort(array)

# Find the first index whose cumulative weight reaches the requested percentile
weight_cdf = sample_weight[sorted_idx].cumsum()
percentile_or_above = weight_cdf >= (percentile / 100.0) * weight_cdf[-1]
Contributor:
Sorry for being stupid, but I am not able to get this to work. My arguments are [3, 2, 4] and [1, 2, 3] for array and sample_weight respectively. sorted_idx is an array and thus throws a TypeError. I wonder what the expected arguments are here.

Member:
sample_weight should be a numpy array

Contributor:
Thanks!

Member Author:
Thanks @MechCoder !
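To illustrate the point in the thread above: `sample_weight[sorted_idx]` is fancy indexing, which works on a numpy array but raises a TypeError on a plain Python list. The inputs here are the ones from the thread ([3, 2, 4] and [1, 2, 3]), and `weighted_percentile` just reproduces the helper's logic inline:

```python
import numpy as np

def weighted_percentile(array, sample_weight, percentile=50.0):
    # same logic as _weighted_percentile in sklearn/utils/stats.py
    sorted_idx = np.argsort(array)
    weight_cdf = sample_weight[sorted_idx].cumsum()
    idx = (weight_cdf >= (percentile / 100.0) * weight_cdf[-1]).argmax()
    return array[sorted_idx[idx]]

a = np.asarray([3, 2, 4])
w = [1, 2, 3]

try:
    weighted_percentile(a, w)  # plain list: list[ndarray] indexing fails
except TypeError:
    print("TypeError: sample_weight must be a numpy array")

print(weighted_percentile(a, np.asarray(w)))  # → 3
```

With the array conversion, the sorted values are [2, 3, 4] with cumulative weights [2, 3, 6]; the 50th percentile target is 3.0, which is first reached at value 3.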

percentile_idx = percentile_or_above.argmax()
Member:
If I'm understanding right, these two lines can be replaced by

percentile_idx = np.searchsorted(weight_cdf, (percentile / 100.) * weight_cdf[-1])

or am I wrong?

Member Author:
Do you think this could be optimized in another PR? I have just taken what @pprett did previously and moved it here so it can be useful beyond gradient boosting.

Member:
Okay, unless @pprett thinks it is ok to change this over here.


return array[sorted_idx[percentile_idx]]
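The `np.searchsorted` suggestion from the thread above can be sanity-checked against the boolean-mask `argmax` approach. Since `weight_cdf` is a cumulative sum of positive weights and therefore non-decreasing, `searchsorted` with its default `side='left'` returns the first index whose value reaches the target, which is exactly what `(weight_cdf >= target).argmax()` computes. A quick check on made-up random weights:

```python
import numpy as np

rng = np.random.RandomState(0)
for _ in range(100):
    # cumulative sum of positive weights: a non-decreasing CDF
    weight_cdf = rng.rand(20).cumsum()
    for percentile in (10.0, 50.0, 95.0):
        target = (percentile / 100.0) * weight_cdf[-1]
        idx_argmax = (weight_cdf >= target).argmax()
        idx_search = np.searchsorted(weight_cdf, target)  # side='left'
        assert idx_argmax == idx_search
print("argmax and searchsorted agree")
```

`searchsorted` avoids materialising the boolean mask and uses binary search instead of a linear scan, so the suggested rewrite would be both equivalent and slightly cheaper.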