@@ -85,21 +85,22 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
8585 the number of samples. The sklearn.cluster.estimate_bandwidth function
8686 can be used to do this more efficiently.
8787
88- seeds : array-like, shape=[n_seeds, n_features]
89- Point used as initial kernel locations.
88+ seeds : array-like, shape=[n_seeds, n_features] or None
89+ Point used as initial kernel locations. If None and bin_seeding=False,
90+ each data point is used as a seed. If None and bin_seeding=True,
91+ see bin_seeding.
9092
91- bin_seeding : boolean
93+ bin_seeding : boolean, default=False
9294 If true, initial kernel locations are not locations of all
9395 points, but rather the location of the discretized version of
9496 points, where points are binned onto a grid whose coarseness
9597 corresponds to the bandwidth. Setting this option to True will speed
9698 up the algorithm because fewer seeds will be initialized.
97- default value: False
9899 Ignored if seeds argument is not None.
99100
100- min_bin_freq : int, optional
101+ min_bin_freq : int, default=1
101102 To speed up the algorithm, accept only those bins with at least
102- min_bin_freq points as seeds. If not defined, set to 1.
103+ min_bin_freq points as seeds.
103104
104105 cluster_all : boolean, default True
105106 If true, then all points are clustered, even those orphans that are
@@ -133,6 +134,9 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
133134
134135 if bandwidth is None :
135136 bandwidth = estimate_bandwidth (X )
137+ elif bandwidth <= 0 :
138+ raise ValueError ("bandwidth needs to be greater than zero or None, got %f" %
139+ bandwidth )
136140 if seeds is None :
137141 if bin_seeding :
138142 seeds = get_bin_seeds (X , bandwidth , min_bin_freq )
@@ -155,13 +159,19 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
155159 break # Depending on seeding strategy this condition may occur
156160 my_old_mean = my_mean # save the old mean
157161 my_mean = np .mean (points_within , axis = 0 )
158- # If converged or at max_iter, addS the cluster
162+ # If converged or at max_iter, adds the cluster
159163 if (extmath .norm (my_mean - my_old_mean ) < stop_thresh or
160164 completed_iterations == max_iter ):
161165 center_intensity_dict [tuple (my_mean )] = len (points_within )
162166 break
163167 completed_iterations += 1
164168
169+ if not center_intensity_dict :
170+ # nothing near seeds
171+ raise ValueError ("No point was within bandwidth=%f of any seed."
172+ " Try a different seeding strategy or increase the bandwidth."
173+ % bandwidth )
174+
165175 # POST PROCESSING: remove near duplicate points
166176 # If the distance between two kernels is less than the bandwidth,
167177 # then we have to remove one because it is a duplicate. Remove the
@@ -225,12 +235,16 @@ def get_bin_seeds(X, bin_size, min_bin_freq=1):
225235 # Bin points
226236 bin_sizes = defaultdict (int )
227237 for point in X :
228- binned_point = np .cast [ np . int32 ] (point / bin_size )
238+ binned_point = np .round (point / bin_size )
229239 bin_sizes [tuple (binned_point )] += 1
230240
231241 # Select only those bins as seeds which have enough members
232242 bin_seeds = np .array ([point for point , freq in six .iteritems (bin_sizes ) if
233243 freq >= min_bin_freq ], dtype = np .float32 )
244+ if len (bin_seeds ) == len (X ):
245+ warnings .warn ("Binning data failed with provided bin_size=%f, using data"
246+ " points as seeds." % bin_size )
247+ return X
234248 bin_seeds = bin_seeds * bin_size
235249 return bin_seeds
236250
0 commit comments