@@ -181,12 +181,6 @@ def _check_fitted_model(km):
181181 % km .n_clusters , km .fit , [[0. , 1. ]])
182182
183183
184- def test_k_means_plus_plus_init ():
185- km = KMeans (init = "k-means++" , n_clusters = n_clusters ,
186- random_state = 42 ).fit (X )
187- _check_fitted_model (km )
188-
189-
190184def test_k_means_new_centers ():
191185 # Explore the part of the code where a new center is reassigned
192186 X = np .array ([[0 , 0 , 1 , 1 ],
@@ -229,24 +223,6 @@ def test_k_means_precompute_distances_flag():
229223 assert_raises (ValueError , km .fit , X )
230224
231225
232- def test_k_means_plus_plus_init_sparse ():
233- km = KMeans (init = "k-means++" , n_clusters = n_clusters , random_state = 42 )
234- km .fit (X_csr )
235- _check_fitted_model (km )
236-
237-
238- def test_k_means_random_init ():
239- km = KMeans (init = "random" , n_clusters = n_clusters , random_state = 42 )
240- km .fit (X )
241- _check_fitted_model (km )
242-
243-
244- def test_k_means_random_init_sparse ():
245- km = KMeans (init = "random" , n_clusters = n_clusters , random_state = 42 )
246- km .fit (X_csr )
247- _check_fitted_model (km )
248-
249-
250226def test_k_means_plus_plus_init_not_precomputed ():
251227 km = KMeans (init = "k-means++" , n_clusters = n_clusters , random_state = 42 ,
252228 precompute_distances = False ).fit (X )
@@ -259,10 +235,11 @@ def test_k_means_random_init_not_precomputed():
259235 _check_fitted_model (km )
260236
261237
262- def test_k_means_perfect_init ():
263- km = KMeans (init = centers .copy (), n_clusters = n_clusters , random_state = 42 ,
264- n_init = 1 )
265- km .fit (X )
238+ @pytest .mark .parametrize ('data' , [X , X_csr ], ids = ['dense' , 'sparse' ])
239+ @pytest .mark .parametrize ('init' , ['random' , 'k-means++' , centers .copy ()])
240+ def test_k_means_init (data , init ):
241+ km = KMeans (init = init , n_clusters = n_clusters , random_state = 42 , n_init = 1 )
242+ km .fit (data )
266243 _check_fitted_model (km )
267244
268245
@@ -315,13 +292,6 @@ def test_k_means_fortran_aligned_data():
315292 assert_array_equal (km .labels_ , labels )
316293
317294
318- def test_mb_k_means_plus_plus_init_dense_array ():
319- mb_k_means = MiniBatchKMeans (init = "k-means++" , n_clusters = n_clusters ,
320- random_state = 42 )
321- mb_k_means .fit (X )
322- _check_fitted_model (mb_k_means )
323-
324-
325295def test_mb_kmeans_verbose ():
326296 mb_k_means = MiniBatchKMeans (init = "k-means++" , n_clusters = n_clusters ,
327297 random_state = 42 , verbose = 1 )
@@ -333,49 +303,25 @@ def test_mb_kmeans_verbose():
333303 sys .stdout = old_stdout
334304
335305
336- def test_mb_k_means_plus_plus_init_sparse_matrix ():
337- mb_k_means = MiniBatchKMeans (init = "k-means++" , n_clusters = n_clusters ,
338- random_state = 42 )
339- mb_k_means .fit (X_csr )
340- _check_fitted_model (mb_k_means )
341-
342-
343306def test_minibatch_init_with_large_k ():
344307 mb_k_means = MiniBatchKMeans (init = 'k-means++' , init_size = 10 , n_clusters = 20 )
345308 # Check that a warning is raised, as the number clusters is larger
346309 # than the init_size
347310 assert_warns (RuntimeWarning , mb_k_means .fit , X )
348311
349312
350- def test_minibatch_k_means_random_init_dense_array ():
351- # increase n_init to make random init stable enough
352- mb_k_means = MiniBatchKMeans (init = "random" , n_clusters = n_clusters ,
353- random_state = 42 , n_init = 10 ).fit (X )
354- _check_fitted_model (mb_k_means )
355-
356-
357- def test_minibatch_k_means_random_init_sparse_csr ():
358- # increase n_init to make random init stable enough
359- mb_k_means = MiniBatchKMeans (init = "random" , n_clusters = n_clusters ,
360- random_state = 42 , n_init = 10 ).fit (X_csr )
361- _check_fitted_model (mb_k_means )
362-
363-
364- def test_minibatch_k_means_perfect_init_dense_array ():
365- mb_k_means = MiniBatchKMeans (init = centers .copy (), n_clusters = n_clusters ,
366- random_state = 42 , n_init = 1 ).fit (X )
367- _check_fitted_model (mb_k_means )
368-
369-
370313def test_minibatch_k_means_init_multiple_runs_with_explicit_centers ():
371314 mb_k_means = MiniBatchKMeans (init = centers .copy (), n_clusters = n_clusters ,
372315 random_state = 42 , n_init = 10 )
373316 assert_warns (RuntimeWarning , mb_k_means .fit , X )
374317
375318
376- def test_minibatch_k_means_perfect_init_sparse_csr ():
377- mb_k_means = MiniBatchKMeans (init = centers .copy (), n_clusters = n_clusters ,
378- random_state = 42 , n_init = 1 ).fit (X_csr )
319+ @pytest .mark .parametrize ('data' , [X , X_csr ], ids = ['dense' , 'sparse' ])
320+ @pytest .mark .parametrize ('init' , ["random" , 'k-means++' , centers .copy ()])
321+ def test_minibatch_k_means_init (data , init ):
322+ mb_k_means = MiniBatchKMeans (init = init , n_clusters = n_clusters ,
323+ random_state = 42 , n_init = 10 )
324+ mb_k_means .fit (data )
379325 _check_fitted_model (mb_k_means )
380326
381327
@@ -585,64 +531,39 @@ def test_predict():
585531 assert_array_equal (pred , km .labels_ )
586532
587533
588- def test_score ():
589-
590- km1 = KMeans (n_clusters = n_clusters , max_iter = 1 , random_state = 42 , n_init = 1 )
591- s1 = km1 .fit (X ).score (X )
592- km2 = KMeans (n_clusters = n_clusters , max_iter = 10 , random_state = 42 , n_init = 1 )
593- s2 = km2 .fit (X ).score (X )
594- assert_greater (s2 , s1 )
595-
534+ @pytest .mark .parametrize ('algo' , ['full' , 'elkan' ])
535+ def test_score (algo ):
536+ # Check that fitting k-means with multiple inits gives better score
596537 km1 = KMeans (n_clusters = n_clusters , max_iter = 1 , random_state = 42 , n_init = 1 ,
597- algorithm = 'elkan' )
538+ algorithm = algo )
598539 s1 = km1 .fit (X ).score (X )
599540 km2 = KMeans (n_clusters = n_clusters , max_iter = 10 , random_state = 42 , n_init = 1 ,
600- algorithm = 'elkan' )
541+ algorithm = algo )
601542 s2 = km2 .fit (X ).score (X )
602543 assert_greater (s2 , s1 )
603544
604545
605- def test_predict_minibatch_dense_input ():
606- mb_k_means = MiniBatchKMeans (n_clusters = n_clusters , random_state = 40 ).fit (X )
607-
608- # sanity check: predict centroid labels
609- pred = mb_k_means .predict (mb_k_means .cluster_centers_ )
610- assert_array_equal (pred , np .arange (n_clusters ))
611-
612- # sanity check: re-predict labeling for training set samples
613- pred = mb_k_means .predict (X )
614- assert_array_equal (mb_k_means .predict (X ), mb_k_means .labels_ )
615-
616-
617- def test_predict_minibatch_kmeanspp_init_sparse_input ():
618- mb_k_means = MiniBatchKMeans (n_clusters = n_clusters , init = 'k-means++' ,
619- n_init = 10 ).fit (X_csr )
546+ @pytest .mark .parametrize ('data' , [X , X_csr ], ids = ['dense' , 'sparse' ])
547+ @pytest .mark .parametrize ('init' , ['random' , 'k-means++' , centers .copy ()])
548+ def test_predict_minibatch (data , init ):
549+ mb_k_means = MiniBatchKMeans (n_clusters = n_clusters , init = init ,
550+ n_init = 10 , random_state = 0 ).fit (data )
620551
621552 # sanity check: re-predict labeling for training set samples
622- assert_array_equal (mb_k_means .predict (X_csr ), mb_k_means .labels_ )
553+ assert_array_equal (mb_k_means .predict (data ), mb_k_means .labels_ )
623554
624555 # sanity check: predict centroid labels
625556 pred = mb_k_means .predict (mb_k_means .cluster_centers_ )
626557 assert_array_equal (pred , np .arange (n_clusters ))
627558
628- # check that models trained on sparse input also works for dense input at
629- # predict time
630- assert_array_equal (mb_k_means .predict (X ), mb_k_means .labels_ )
631-
632-
633- def test_predict_minibatch_random_init_sparse_input ():
634- mb_k_means = MiniBatchKMeans (n_clusters = n_clusters , init = 'random' ,
635- n_init = 10 ).fit (X_csr )
636-
637- # sanity check: re-predict labeling for training set samples
638- assert_array_equal (mb_k_means .predict (X_csr ), mb_k_means .labels_ )
639-
640- # sanity check: predict centroid labels
641- pred = mb_k_means .predict (mb_k_means .cluster_centers_ )
642- assert_array_equal (pred , np .arange (n_clusters ))
643559
560+ @pytest .mark .parametrize ('init' , ['random' , 'k-means++' , centers .copy ()])
561+ def test_predict_minibatch_dense_sparse (init ):
644562 # check that models trained on sparse input also works for dense input at
645563 # predict time
564+ mb_k_means = MiniBatchKMeans (n_clusters = n_clusters , init = init ,
565+ n_init = 10 , random_state = 0 ).fit (X_csr )
566+
646567 assert_array_equal (mb_k_means .predict (X ), mb_k_means .labels_ )
647568
648569
@@ -694,27 +615,19 @@ def test_fit_transform():
694615 assert_array_almost_equal (X1 , X2 )
695616
696617
697- def test_predict_equal_labels ():
698- km = KMeans (random_state = 13 , n_jobs = 1 , n_init = 1 , max_iter = 1 ,
699- algorithm = 'full' )
700- km .fit (X )
701- assert_array_equal (km .predict (X ), km .labels_ )
702-
618+ @pytest .mark .parametrize ('algo' , ['full' , 'elkan' ])
619+ def test_predict_equal_labels (algo ):
703620 km = KMeans (random_state = 13 , n_jobs = 1 , n_init = 1 , max_iter = 1 ,
704- algorithm = 'elkan' )
621+ algorithm = algo )
705622 km .fit (X )
706623 assert_array_equal (km .predict (X ), km .labels_ )
707624
708625
709626def test_full_vs_elkan ():
627+ km1 = KMeans (algorithm = 'full' , random_state = 13 ).fit (X )
628+ km2 = KMeans (algorithm = 'elkan' , random_state = 13 ).fit (X )
710629
711- km1 = KMeans (algorithm = 'full' , random_state = 13 )
712- km2 = KMeans (algorithm = 'elkan' , random_state = 13 )
713-
714- km1 .fit (X )
715- km2 .fit (X )
716-
717- homogeneity_score (km1 .predict (X ), km2 .predict (X )) == 1.0
630+ assert homogeneity_score (km1 .predict (X ), km2 .predict (X )) == 1.0
718631
719632
720633def test_n_init ():
0 commit comments