
K-Means clustering performance improvements #10744

@FrancoisFayard


Hi,

I am new to the Scikit-Learn world, but I have talked a lot with Alexandre Gramfort about performance enhancements that could be made to Scikit-Learn. I have known for a long time that the KMeans clustering code was suboptimal, but my own version was written in C++, which does not seem to be a language used in Scikit-Learn. Since I am not a Python programmer, whether the Scikit-Learn KMeans implementation could be made faster within the constraints of the community was an open question to me.

I am going to focus on Intel architectures. The benchmarks compare three versions:

  • vanilla Scikit-Learn (the latest version on Anaconda)
  • Intel Scikit-Learn (the latest version from the intel channel, which is linked against the DAAL)
  • code I have written in Cython and compiled with the Intel compiler

Here are the speedups for clustering about 270 000 points into 1024 clusters in dimension 3. This example comes from the color quantization example in the Scikit-Learn documentation, which uses an image of about 270 000 pixels. The reference is vanilla Scikit-Learn on a 4-core MacBook Pro, and any higher number is a speedup over this configuration:

MacBookPro: 4 cores, Haswell (AVX2)
Scikit-Learn - Vanilla: 1
Scikit-Learn - Intel: x 52
InsideLoop - Icc 18: x 234 (About 129 GFlops/s)

Dual Xeon: 36 cores, Skylake (AVX512)
Scikit-Learn - Vanilla: x 2
Scikit-Learn - Intel: x 85
InsideLoop - Gcc 7.3: x 276
InsideLoop - Icc 18: x 1382 (About 754 GFlops/s)
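As a sanity check on the GFlops/s figures, the flop count per iteration can be worked out from the benchmark parameters in the script further down. This is a minimal sketch. It assumes 8 floating-point operations per point/centroid distance (3 subtractions, 3 multiplications, 2 additions), as in the script's `gflops` formula. It ignores the comparison and the comparatively cheap centroid update.

```python
# Back-of-the-envelope check of the GFlops/s figures quoted above.
# Assumption: 8 flops per point/centroid distance, as in the benchmark script.
width, height = 427, 640
nb_points = width * height          # 273,280 pixels
nb_clusters = 1024
flops_per_distance = 8

# One assignment pass computes every point/centroid distance once.
flops_per_iteration = nb_points * nb_clusters * flops_per_distance


def seconds_per_iteration(gflops_per_s):
    """Time one assignment pass takes at a given sustained rate."""
    return flops_per_iteration / (gflops_per_s * 1.0e9)


print(f"points: {nb_points}")
print(f"GFlop per iteration: {flops_per_iteration / 1e9:.2f}")
for label, rate in [("MacBook Pro, Icc 18", 129.0), ("Dual Xeon, Icc 18", 754.0)]:
    print(f"{label}: {1e3 * seconds_per_iteration(rate):.1f} ms per iteration")
```

At 129 GFlops/s this gives roughly 17 ms per iteration, and at 754 GFlops/s roughly 3 ms.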

This shows that large improvements are possible, even over the Intel version. My code has a few advantages:

  • It knows the dimension (3 here) at compile time
  • Both the vanilla and Intel versions of Scikit-Learn turn out to perform very poorly on low-dimensional k-means clustering
  • It has been compiled with the Intel compiler for the target platform: I asked the compiler to generate AVX2 instructions on the MacBook Pro and AVX512 instructions on the Xeon Skylake
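The first point can be illustrated with a NumPy sketch of the assignment step. The dimension is hard-coded to 3, and the coordinates are stored as three separate contiguous arrays, the same layout as the Cython code.

This does not reproduce the compiled speedup. It only shows the fixed-dimension structure: three explicit terms instead of a loop over the dimension. That structure is what makes it easy for a compiler to vectorize over the centroids.

The function name `assign_fixed_dim3` and its `block` parameter are hypothetical, introduced for illustration.

```python
import numpy as np


def assign_fixed_dim3(pr, pg, pb, cr, cg, cb, block=4096):
    """Assign each point to its nearest centroid, dimension hard-coded to 3.

    pr, pg, pb: coordinates of the points (one 1-D array per dimension).
    cr, cg, cb: coordinates of the centroids (one 1-D array per dimension).
    Points are processed in blocks so the (block x n_clusters) temporary
    arrays stay small.
    """
    m = pr.shape[0]
    labels = np.empty(m, dtype=np.int32)
    for start in range(0, m, block):
        stop = min(start + block, m)
        # Broadcasting (block, 1) against (1, n_clusters) gives (block, n_clusters);
        # with the dimension fixed at 3, the distance is three explicit terms.
        x = pr[start:stop, None] - cr[None, :]
        y = pg[start:stop, None] - cg[None, :]
        z = pb[start:stop, None] - cb[None, :]
        labels[start:stop] = np.argmin(x * x + y * y + z * z, axis=1)
    return labels


rng = np.random.default_rng(0)
points = rng.random((1000, 3), dtype=np.float32)
centroids = rng.random((16, 3), dtype=np.float32)
labels = assign_fixed_dim3(points[:, 0].copy(), points[:, 1].copy(), points[:, 2].copy(),
                           centroids[:, 0].copy(), centroids[:, 1].copy(), centroids[:, 2].copy(),
                           block=300)
```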

Here is the code, if you are interested. A few things are missing, such as the handling of empty clusters, but I have never seen that be a hotspot in k-means clustering. Also, bear in mind that I am more of a C/C++ programmer. This is my first Cython code, and some things may not be idiomatic. The speedup comes from both parallelization (using OpenMP) and vectorization (thanks to memoryviews and the Intel compiler).

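# Compile with OpenMP enabled (e.g. -fopenmp for gcc, -qopenmp for icc),
# otherwise the prange loop below runs serially.
import time
import numpy as np
import sklearn.cluster
from sklearn.cluster import KMeans
from cython.parallel import prange
from libc.float cimport FLT_MAX
cimport cython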


cdef int width = 427
cdef int height = 640
cdef int nb_points = width * height
cdef int nb_clusters = 1024
cdef int nb_iterations_insideloop = 100
cdef int nb_iterations_scikitlearn = 3
cdef int nb_iterations_intel = 20

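# Random points in [0, 1)^3 stand in for the 427 x 640 RGB sample image
pixel = np.random.rand(width * height, 3).astype('float32')

# Structure-of-arrays layout: one contiguous array per coordinate
pr = np.array(pixel[:, 0])
pg = np.array(pixel[:, 1])
pb = np.array(pixel[:, 2])
# Round-robin initial assignment: point i starts in cluster i % nb_clusters
cluster = np.arange(0, nb_points, dtype = 'int32')
cluster = cluster % nb_clusters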
cr = np.zeros(nb_clusters, dtype = 'float32')
cg = np.zeros(nb_clusters, dtype = 'float32')
cb = np.zeros(nb_clusters, dtype = 'float32')
pop = np.zeros(nb_clusters, dtype = 'int32')

cdef float[::1] pr_view = pr
cdef float[::1] pg_view = pg
cdef float[::1] pb_view = pb
cdef int[::1] cluster_view = cluster
cdef float[::1] cr_view = cr
cdef float[::1] cg_view = cg
cdef float[::1] cb_view = cb
cdef int[::1] pop_view = pop

t0 = time.time()
kmeans(pr_view, pg_view, pb_view,
       cluster_view,
       cr_view, cg_view, cb_view,
       pop_view, nb_iterations_insideloop)
t1 = time.time()
# 8 flops per point/centroid distance: 3 subtractions, 3 multiplications, 2 additions
gflops = 1.0e-9 * width * height * nb_iterations_insideloop * nb_clusters * 8 / (t1 - t0)
print('  Time for KMeans clustering, InsideLoop: {} s'.format((t1 - t0) / nb_iterations_insideloop))
print('                             Performance: {} GFlops/s'.format(gflops))

t0 = time.time()
res = sklearn.cluster.k_means(pixel, nb_clusters, init = 'random',
    n_init = 1, tol = 1.0e-16, max_iter = nb_iterations_scikitlearn, return_n_iter = True)
t1 = time.time()
gflops = 1.0e-9 * width * height * nb_iterations_scikitlearn * nb_clusters * 8 / (t1 - t0)
print('Time for KMeans clustering, Scikit-Learn: {} s'.format((t1 - t0) / nb_iterations_scikitlearn))
print('                             Performance: {} GFlops/s'.format(gflops))

t0 = time.time()
res = KMeans(init = 'random', n_init = 1, tol = 1.0e-16, n_clusters = nb_clusters,
    random_state = 0, max_iter = nb_iterations_intel).fit(pixel)
t1 = time.time()
gflops = 1.0e-9 * width * height * nb_iterations_intel * nb_clusters * 8 / (t1 - t0)
print('       Time for KMeans clustering, Intel: {} s'.format((t1 - t0) / nb_iterations_intel))
print('                             Performance: {} GFlops/s'.format(gflops))


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void kmeans(float[::1] pr, float[::1] pg, float[::1] pb,
                  int[::1] cluster,
                  float[::1] cr, float[::1] cg, float[::1] cb,
                  int[::1] pop, int nb_iterations) nogil:
  cdef int m = cluster.shape[0] # m is the number of points
  cdef int n = pop.shape[0]     # n is the number of clusters
  cdef int i, j, k, best_j
  cdef float alpha, distance, best_distance
  cdef float x, y, z

  for k in range(nb_iterations):
    # Update step (serial): each centroid becomes the mean of its points
    for j in range(n):
      cr[j] = 0.0
      cg[j] = 0.0
      cb[j] = 0.0
      pop[j] = 0
    for i in range(m):
      j = cluster[i]
      cr[j] += pr[i]
      cg[j] += pg[i]
      cb[j] += pb[i]
      pop[j] += 1
    for j in range(n):
      if pop[j] > 0:  # empty clusters are left at (0, 0, 0)
        alpha = 1.0 / pop[j]
        cr[j] *= alpha
        cg[j] *= alpha
        cb[j] *= alpha

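    # Assignment step, parallelized over points with OpenMP; variables
    # assigned inside prange are private to each thread.
    for i in prange(m):
      best_distance = FLT_MAX  # larger than any squared distance, whatever the data range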
      best_j = 0
      for j in range(n):
        x = pr[i] - cr[j]
        y = pg[i] - cg[j]
        z = pb[i] - cb[j]
        distance = x * x + y * y + z * z
        if distance < best_distance:
          best_distance = distance
          best_j = j
      cluster[i] = best_j
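A pure NumPy re-implementation can help check the output of the Cython `kmeans` kernel on small inputs. This is a minimal sketch under a few assumptions:

  • it uses the same round-robin initial labels and the same update-then-assign order as the kernel;
  • it assumes at least as many points as clusters;
  • `kmeans_reference` and its `relocate_empty` flag are hypothetical names.

The optional flag sketches one way to fill the missing empty-cluster handling, by moving each empty centroid onto a far-away point. That is an assumption, not what the code above does.

```python
import numpy as np


def kmeans_reference(pr, pg, pb, cluster, n_clusters, nb_iterations,
                     relocate_empty=False):
    """NumPy re-implementation of the Lloyd iterations in the Cython `kmeans`.

    Each iteration recomputes the centroids from the current labels, then
    reassigns every point to its nearest centroid. With relocate_empty=False,
    empty clusters keep a centroid of (0, 0, 0), as in the Cython code. With
    relocate_empty=True (hypothetical option), each empty centroid is moved
    onto one of the points farthest from their own centroid.
    Builds an (m, n_clusters, 3) temporary, so it is only meant for small inputs.
    """
    points = np.stack([pr, pg, pb], axis=1).astype(np.float64)
    labels = np.asarray(cluster).astype(np.intp)
    centroids = np.zeros((n_clusters, 3))
    for _ in range(nb_iterations):
        # Update step: centroid = mean of the points currently assigned to it.
        pop = np.bincount(labels, minlength=n_clusters)
        sums = np.zeros((n_clusters, 3))
        np.add.at(sums, labels, points)
        centroids = np.zeros((n_clusters, 3))
        nonempty = pop > 0
        centroids[nonempty] = sums[nonempty] / pop[nonempty][:, None]
        empty = np.flatnonzero(~nonempty)
        if relocate_empty and empty.size > 0:
            # Squared distance of each point to its own updated centroid;
            # the farthest points (largest first) seed the empty clusters.
            diff = points - centroids[labels]
            far = np.argsort(np.einsum("ij,ij->i", diff, diff))[::-1]
            centroids[empty] = points[far[:empty.size]]
        # Assignment step: nearest centroid in squared Euclidean distance.
        d = points[:, None, :] - centroids[None, :, :]
        labels = np.argmin(np.einsum("ijk,ijk->ij", d, d), axis=1)
    return labels.astype(np.int32), centroids
```

Labels can differ from the Cython output on near-ties, because this version accumulates in float64 while the kernel uses float32. Comparing centroids with a tolerance is therefore more robust than comparing labels exactly.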
