Skip to content

Commit 4629366

Browse files
committed
Merge pull request #4176 from amueller/mean_shift_no_centers
[MRG + 1] Better error messages in MeanShift, slightly more robust to bad binning.
2 parents 95681ee + a9b3965 commit 4629366

File tree

3 files changed

+45
-12
lines changed

3 files changed

+45
-12
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ Enhancements
179179
- Allow the fitting and scoring of all clustering algorithms in
180180
:class:`pipeline.Pipeline`. By `Andreas Müller`_.
181181

182+
- More robust seeding and improved error messages in :class:`cluster.MeanShift`
183+
by `Andreas Müller`_.
184+
182185
Documentation improvements
183186
..........................
184187

sklearn/cluster/mean_shift_.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,22 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
8585
the number of samples. The sklearn.cluster.estimate_bandwidth function
8686
can be used to do this more efficiently.
8787
88-
seeds : array-like, shape=[n_seeds, n_features]
89-
Point used as initial kernel locations.
88+
seeds : array-like, shape=[n_seeds, n_features] or None
89+
Points used as initial kernel locations. If None and bin_seeding=False,
90+
each data point is used as a seed. If None and bin_seeding=True,
91+
see bin_seeding.
9092
91-
bin_seeding : boolean
93+
bin_seeding : boolean, default=False
9294
If true, initial kernel locations are not locations of all
9395
points, but rather the location of the discretized version of
9496
points, where points are binned onto a grid whose coarseness
9597
corresponds to the bandwidth. Setting this option to True will speed
9698
up the algorithm because fewer seeds will be initialized.
97-
default value: False
9899
Ignored if seeds argument is not None.
99100
100-
min_bin_freq : int, optional
101+
min_bin_freq : int, default=1
101102
To speed up the algorithm, accept only those bins with at least
102-
min_bin_freq points as seeds. If not defined, set to 1.
103+
min_bin_freq points as seeds.
103104
104105
cluster_all : boolean, default True
105106
If true, then all points are clustered, even those orphans that are
@@ -133,6 +134,9 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
133134

134135
if bandwidth is None:
135136
bandwidth = estimate_bandwidth(X)
137+
elif bandwidth <= 0:
138+
raise ValueError("bandwidth needs to be greater than zero or None, got %f" %
139+
bandwidth)
136140
if seeds is None:
137141
if bin_seeding:
138142
seeds = get_bin_seeds(X, bandwidth, min_bin_freq)
@@ -155,13 +159,19 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
155159
break # Depending on seeding strategy this condition may occur
156160
my_old_mean = my_mean # save the old mean
157161
my_mean = np.mean(points_within, axis=0)
158-
# If converged or at max_iter, addS the cluster
162+
# If converged or at max_iter, adds the cluster
159163
if (extmath.norm(my_mean - my_old_mean) < stop_thresh or
160164
completed_iterations == max_iter):
161165
center_intensity_dict[tuple(my_mean)] = len(points_within)
162166
break
163167
completed_iterations += 1
164168

169+
if not center_intensity_dict:
170+
# no seed had any data point within bandwidth, so no center was found
171+
raise ValueError("No point was within bandwidth=%f of any seed."
172+
" Try a different seeding strategy or increase the bandwidth."
173+
% bandwidth)
174+
165175
# POST PROCESSING: remove near duplicate points
166176
# If the distance between two kernels is less than the bandwidth,
167177
# then we have to remove one because it is a duplicate. Remove the
@@ -225,12 +235,16 @@ def get_bin_seeds(X, bin_size, min_bin_freq=1):
225235
# Bin points
226236
bin_sizes = defaultdict(int)
227237
for point in X:
228-
binned_point = np.cast[np.int32](point / bin_size)
238+
binned_point = np.round(point / bin_size)
229239
bin_sizes[tuple(binned_point)] += 1
230240

231241
# Select only those bins as seeds which have enough members
232242
bin_seeds = np.array([point for point, freq in six.iteritems(bin_sizes) if
233243
freq >= min_bin_freq], dtype=np.float32)
244+
if len(bin_seeds) == len(X):
245+
warnings.warn("Binning data failed with provided bin_size=%f, using data"
246+
" points as seeds." % bin_size)
247+
return X
234248
bin_seeds = bin_seeds * bin_size
235249
return bin_seeds
236250

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
"""
55

66
import numpy as np
7+
import warnings
78

89
from sklearn.utils.testing import assert_equal
910
from sklearn.utils.testing import assert_false
1011
from sklearn.utils.testing import assert_true
1112
from sklearn.utils.testing import assert_array_equal
13+
from sklearn.utils.testing import assert_raise_message
1214

1315
from sklearn.cluster import MeanShift
1416
from sklearn.cluster import mean_shift
@@ -52,6 +54,13 @@ def test_meanshift_predict():
5254
assert_array_equal(labels, labels2)
5355

5456

57+
def test_meanshift_all_orphans():
58+
# init away from the data, fail with a sensible error message
59+
ms = MeanShift(bandwidth=0.1, seeds=[[-9, -9], [-10, -10]])
60+
msg = "No point was within bandwidth=0.1"
61+
assert_raise_message(ValueError, msg, ms.fit, X,)
62+
63+
5564
def test_unfitted():
5665
"""Non-regression: before fit, there should be no fitted attributes."""
5766
ms = MeanShift()
@@ -65,7 +74,7 @@ def test_bin_seeds():
6574
algorithm
6675
"""
6776
# Data is just 6 points in the plane
68-
X = np.array([[1., 1.], [1.5, 1.5], [1.8, 1.2],
77+
X = np.array([[1., 1.], [1.4, 1.4], [1.8, 1.2],
6978
[2., 1.], [2.1, 1.1], [0., 0.]])
7079

7180
# With a bin coarseness of 1.0 and min_bin_freq of 1, 3 bins should be
@@ -83,6 +92,13 @@ def test_bin_seeds():
8392
assert_true(len(ground_truth.symmetric_difference(test_result)) == 0)
8493

8594
# With a bin size of 0.01 and min_bin_freq of 1, 6 bins should be found
86-
test_bins = get_bin_seeds(X, 0.01, 1)
87-
test_result = set([tuple(p) for p in test_bins])
88-
assert_true(len(test_result) == 6)
95+
# (one bin per data point), so binning gains nothing: we bail and use the whole data here.
96+
with warnings.catch_warnings(record=True):
97+
test_bins = get_bin_seeds(X, 0.01, 1)
98+
assert_array_equal(test_bins, X)
99+
100+
# tight clusters around [0, 0] and [1, 1], only get two bins
101+
X, _ = make_blobs(n_samples=100, n_features=2, centers=[[0, 0], [1, 1]],
102+
cluster_std=0.1, random_state=0)
103+
test_bins = get_bin_seeds(X, 1)
104+
assert_array_equal(test_bins, [[0, 0], [1, 1]])

0 commit comments

Comments
 (0)