Skip to content

Conversation

@MarkFischinger
Copy link
Contributor

Following up on the Hamerly K-means optimization, I've also applied OpenMP to the naive K-means implementation. Here are the results.

Benchmark Results
I ran a benchmark with the following settings:

Number of data points: 10,000
Number of dimensions: 50
Number of clusters: 10
Maximum iterations: 10

Here's what I found:

Original version: 13.753 seconds
OpenMP version: 12.8868 seconds

This represents a 6.3% decrease in execution time.
Looking forward to your thoughts :)

@conradsnicta
Copy link
Contributor

conradsnicta commented Jul 11, 2024

@rcurtin
Copy link
Member

rcurtin commented Jul 12, 2024

Have you checked whether the distance computation itself is parallelized? If it's not, simply parallelizing that instead might give significantly more speedup. (Note that mlpack doesn't require OpenMP to be available or enabled, so going beyond #pragma omp parallel for can be a little bit tricky, if you want to go the route of the paper @conradsnicta linked to, but it can definitely be done.)

@MarkFischinger
Copy link
Contributor Author

@conradsnicta @rcurtin I experimented with various OpenMP approaches. The most effective method is dividing the calculations into blocks.
On my machine, the benchmark results are:

~/gsoc/mlpack-benchmarks/src/kmeans$ ./naive_kmeans
Running benchmark for naive KMeans...
Benchmark for naive KMeans completed.
Total time: 5.10537 seconds

The problem is that blockSize needs to be adjusted for the specific system (That's why the memory tests currently fail). I could add it as another parameter, but I'm not sure if that's a valid/best practise. Is there maybe a solution where I could set that size dynamically?

@conradsnicta
Copy link
Contributor

conradsnicta commented Jul 15, 2024

@MarkFischinger I don't think hardcoding the block size is the way to go. It would be better to divide the available data into segments based on the number of available threads, which is determined at run-time.

The nominal segment size can be simply number_of_data_vectors / number_of_threads in integer form. The last segment will need to have a possibly different size than the nominal size, to take into account that the division may not be clean.

As an example, say we have 123 vectors and 4 threads, resulting in a nominal segment size of 30. The first 3 segments would each have a size of 30, thereby taking up 90 vectors in total. The last segment size would then be 123 - 90 = 33.

A bunch of corner cases would also need to be taken care of, such as ensuring a minimum number of vectors per thread, and dealing with the possibility of having more available threads than vectors.

@MarkFischinger
Copy link
Contributor Author

I fixed the error and implemented a new minVectorsPerThread variable. Here are the new benchmarks:

mark@mark:~/gsoc/mlpack-benchmarks/src/kmeans$ ./naive_kmeans
Running benchmark for naive KMeans...
Benchmark for naive KMeans completed.
Total time: 7.37996 seconds

Copy link
Member

@rcurtin rcurtin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there @MarkFischinger, the numbers are looking better. I took a look through the code and I think that there are several places where we could get additional speedup while changing less of the original underlying code. Let me know what you think or if I can clarify any of my comments. 👍


#ifdef MLPACK_USE_OPENMP
#include <omp.h>
#endif
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already done in core.hpp, so shouldn't be needed here (so long as naive_kmeans.hpp includes core.hpp.

// Pre-compute squared norms of centroids
arma::vec centroidNorms(clusters);
#ifdef MLPACK_USE_OPENMP
#pragma omp parallel for schedule(static)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to wrap just a #pragma in the #ifdef---the compiler will ignore it if OpenMP is not enabled. Only where the actual code involves OpenMP do we need to use those (like if you are calling omp_get_num_threads() or something).

closestCluster = j;
}
const arma::vec& centroid = centroids.col(j);
distances(j) = std::max(0.0, dataNorm + centroidNorms(j) - 2 * arma::dot(dataPoint, centroid));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is true that a distance can be expressed this way, but in general it is not more efficient to compute distances like this. Note that computing the data norm takes one pass over the data column, but then when you take the dot product of the data point and the centroid, it's a second pass over the data column. Plus, given that this computes the norm of every data point at every iteration of k-means, I'd strongly expect that using distance.Evaluate() would be a lot faster. What was the reason you came to this solution?

// Determine the number of threads and calculate segment size
size_t effectiveThreads = 1;
#ifdef MLPACK_USE_OPENMP
const size_t numThreads = static_cast<size_t>(std::max(1, omp_get_max_threads()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it even allowed that omp_get_max_threads() could return 0? I read the documentation and it doesn't explicitly disallow it, but it says it returns an 'upper bound'... which 0 could never be. Personally I think it would be safe to use omp_get_max_threads() directly.

size_t effectiveThreads = 1;
#ifdef MLPACK_USE_OPENMP
const size_t numThreads = static_cast<size_t>(std::max(1, omp_get_max_threads()));
const size_t minVectorsPerThread = 100;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I would just take numThreads = 1 when OpenMP isn't available and numThreads = omp_get_max_threads() otherwise. If the user has so few points that each OpenMP thread can't get enough work, then k-means is going to run so fast anyway that the difference won't really matter.

threadId = omp_get_thread_num();
#endif
const size_t segmentStart = threadId * nominalSegmentSize;
const size_t segmentEnd = (threadId == effectiveThreads - 1) ? points : (threadId + 1) * nominalSegmentSize;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't you just use an OpenMP for loop over the points with a static schedule, just making sure to update the correct element in threadCentroids and threadCounts? I think that could simplify this code a good bit (or at least make it much closer to the original code)

arma::mat& localCentroids = threadCentroids[threadId];
arma::Col<size_t>& localCounts = threadCounts[threadId];

arma::vec distances(clusters);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a slowdown to cache all the distances like this, and then later iterate over them all with .index_min() to find the minimum cluster. It should be quicker to use the previous if (dist < minDistance) approach.

{
cNorm += std::pow(distance.Evaluate(centroids.col(i), newCentroids.col(i)),
2.0);
cNorm += arma::norm(centroids.col(j) - newCentroids.col(j), 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually makes this computation incorrect if distance is not EuclideanDistance.

return std::sqrt(cNorm);
distanceCalculations += clusters * points;

return cNorm;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow why you have made these changes here?

// Normalize the centroids
for (size_t j = 0; j < clusters; ++j)
{
if (counts(j) > eps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But counts(j) is a size_t and so you would want to compare it with 0. (The previous version of this loop was just fine, as far as I can tell.)

// Find the closest centroid to this point.
double minDistance = std::numeric_limits<double>::infinity();
size_t closestCluster = centroids.n_cols; // Invalid value.
const auto dataPoint = dataset.col(i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarkFischinger Do not use auto with Armadillo expressions. As part of its meta-programming framework, Armadillo uses short-lived temporaries that are not properly handled by auto.

{
const arma::vec& centroid = centroids.col(j);
distances(j) = std::max(0.0, dataNorm + centroidNorms(j) - 2 * arma::dot(dataPoint, centroid));
// Optimized Euclidean distance calculation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarkFischinger I'm not convinced this is optimized; can you clarify what led you to the solution of expressing the Euclidean distance d(x, y) as <x, x> + <y, y> - 2 <x, y>? There are a number of disadvantages to this approach including extra memory usage, multiple passes over data points, and so forth.

Copy link
Member

@rcurtin rcurtin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates @MarkFischinger. Can you be sure to address all the comments I leave please? If you can also provide new benchmarking numbers when you update the code it would be really helpful. 👍

arma::mat& localCentroids = threadCentroids[threadId];
arma::Col<size_t>& localCounts = threadCounts[threadId];

for (size_t i = segmentStart; i < segmentEnd; ++i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does a simple omp parallel for not work here? I don't think you need the segmentStart and segmentEnd variables---I think all you need to do is get the thread id and then add to the local centroids.

{
const double dist = distance.Evaluate(dataset.col(i),
centroids.unsafe_col(j));
const double dist = std::max(0.0, dataNorm + centroidNorms(j) - 2 * arma::dot(dataPoint, centroids.col(j)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still unclear on why you are doing distance computations like this. Can you clarify your thinking and any advantage to doing it like this? I can see several disadvantages and I don't think it will be faster.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rcurtin The idea was to optimize performance for high-dimensional data by precomputing squared norms for the centroids centroidNorms and data points dataNorm. This lets us reduce the number of dot products, which can be expensive, and potentially speed things up. The formula I used:

const double dist = std::max(0.0, dataNorm + centroidNorms(j) - 2 * arma::dot(dataPoint, centroids.col(j)));

aims to minimize the calculations and help numerical stability.

However, I see your point. While this method might be faster in certain cases, it adds some complexity and overhead, especially for smaller datasets. Your suggestion to just compute the distance directly

const double dist = distance.Evaluate(dataPoint, centroids.col(j));

is definitely simpler and easier to maintain. I’ll run some benchmarks to compare the two. If my approach doesn’t show a significant speedup, switching to the simpler method makes sense :)

@MarkFischinger
Copy link
Contributor Author

@rcurtin I updated the code following the feedback and optimized the distance calculation and OpenMP parallelization. Below are the additional benchmarks comparing the performance before and after these changes:

Before Optimization:

Running benchmark for naive KMeans...
Benchmark for naive KMeans completed.
Total time: 11.494 seconds
Running test for KMeans...
Assignments:         1        1        0        0

Centroids:
   5.5000   1.5000
   5.5000   1.5000

Test passed successfully!

After Optimization:

Running benchmark for naive KMeans...
Benchmark for naive KMeans completed.
Total time: 6.35895 seconds
Running test for KMeans...
Assignments:         1        1        0        0

Centroids:
   5.5000   1.5000
   5.5000   1.5000

Test passed successfully!

I went back to the optimized Euclidean distance formula, precomputing squared norms for both centroids and data points. This cuts down on redundant calculations, particularly when dealing with high-dimensional data. I also improved the OpenMP parallelization by making the parallel loops more fine-grained. Using a static schedule helps minimize synchronization overhead, leading to better thread utilization and faster execution. Finally, I focused on memory and cache efficiency by reusing precomputed values and that the data access is more cache friendly.

@rcurtin
Copy link
Member

rcurtin commented Aug 22, 2024

The reason I am hammering so hard (now for I believe the fifth time) about the distance calculation is because it is not always the right thing to simply observe that something is faster and then do it. The real question, when observing something like this, is understanding why the other thing is slower. (This is going to be a long comment...)

Starting from the basics, we are talking about the following two approaches:

  • Sum (a[i] - b[i])^2 for each dimension i
  • Precompute (a[i] * a[i]) and (b[i] * b[i]), then compute (a_norm[i] + b_norm[i] - 2 * a[i] * b[i]).

From a standpoint of what a computer should do, the first approach should be faster for several reasons:

  • Only one pass over each a and b (nicer on memory)
  • One subtract and multiply per dimension, vs. three multiplies and two add/sub operations
  • No need for auxiliary memory to be allocated

My intuition is that memory latency will be the biggest issue here.

So the question, from my perspective, is not "which is faster", but instead "why is the first one not faster", if that's what you are observing. One key observation is that the dot-based approach will use OpenBLAS (which is very heavily optimized, and also uses OpenMP for parallelization, but, I think that should not come into play here). I would imagine that Armadillo has the same optimizations going on, but perhaps the compiler is not emitting the code we want. I therefore wrote up the following code:

#include <armadillo>

double compute_ed(const arma::mat& set1, const arma::mat& set2, const size_t points)
{
  double sum = 0.0;
  for (size_t i = 0; i < points; ++i)
  {
    for (size_t j = 0; j < points; ++j)
    {
      sum += arma::norm(set1.col(i) - set2.col(j), 2);
    }
  }
  return sum;
}

double compute_ed2(const arma::mat& set1, const arma::mat& set2, const size_t points)
{
  double sum = 0.0;
  const size_t n_rows = set1.n_rows;
  for (size_t i = 0; i < points; ++i)
  {
    for (size_t j = 0; j < points; ++j)
    {
      double point_sum = 0.0;
      for (size_t k = 0; k < set1.n_rows; k += 4)
      {
        // take 4 elements at a time to try and "trick" the vectorizer into actually working
        const double a1 = set1.mem[k + i * n_rows];
        const double a2 = set1.mem[k + 1 + i * n_rows];
        const double a3 = set1.mem[k + 2 + i * n_rows];
        const double a4 = set1.mem[k + 3 + i * n_rows];

        const double b1 = set2.mem[k + j * n_rows];
        const double b2 = set2.mem[k + 1 + j * n_rows];
        const double b3 = set2.mem[k + 2 + j * n_rows];
        const double b4 = set2.mem[k + 3 + j * n_rows];

        const double ps1 = (a1 - b1) * (a1 - b1);
        const double ps2 = (a2 - b2) * (a2 - b2);
        const double ps3 = (a3 - b3) * (a3 - b3);
        const double ps4 = (a4 - b4) * (a4 - b4);

        point_sum += ps1 + ps2 + ps3 + ps4;
      }
      sum += std::sqrt(point_sum);
    }
  }
  return sum;
}

double compute_ip(const arma::mat& set1, const arma::mat& set2, const arma::vec& set1_norms, const arma::vec& set2_norms, const size_t points)
{
  double sum = 0.0;
  for (size_t i = 0; i < points; ++i)
  {
    for (size_t j = 0; j < points; ++j)
    {
      sum += std::sqrt(set1_norms[i] + set2_norms[j] - 2 * arma::dot(set1.col(i), set2.col(j)));
    }
  }
  return sum;
}

int main()
{
  const size_t dims = 100;
  const size_t points = 1000;

  arma::mat set1(dims, points, arma::fill::randu);
  arma::mat set2(dims, points, arma::fill::randu);

  arma::wall_clock c;

  // Compute distances the EuclideanDistance way.
  c.tic();
  double sum = compute_ed(set1, set2, points);
  const double euclidean_time = c.toc();
  std::cout << "norm()-based calculation: " << euclidean_time << "s (result "
      << sum << ")." << std::endl;

  c.tic();
  sum = compute_ed2(set1, set2, points);
  const double euclidean_2_time = c.toc();
  std::cout << "norm()-based manual calculation: " << euclidean_2_time << "s (result "
      << sum << ")." << std::endl;

  // Now compute the norm-based way.
  c.tic();
  arma::vec set1_norms(points, arma::fill::none);
  arma::vec set2_norms(points, arma::fill::none);

  for (size_t i = 0; i < points; ++i)
  {
    set1_norms[i] = arma::dot(set1.col(i), set1.col(i));
  }
  for (size_t i = 0; i < points; ++i)
  {
    set2_norms[i] = arma::dot(set2.col(i), set2.col(i));
  }
  const double selfnorm_time = c.toc();
  std::cout << "dot()-based norm preparation time: " << selfnorm_time << "s."
      << std::endl;

  c.tic();
  sum = compute_ip(set1, set2, set1_norms, set2_norms, points);
  const double norm_time = c.toc();
  std::cout << "dot()-based calculation: " << norm_time << "s (result " << sum
      << ")." << std::endl;
}

Here I implemented compute_ed2() because I noticed that the Armadillo code in compute_ed() was not being vectorized properly to use SIMD instructions. When I compile with g++ -std=c++17 -march=native -DARMA_NO_DEBUG -O3 -o dot_test dot_test.cpp -larmadillo and run, I see:

$ ./dot_test 
norm()-based calculation: 0.0977767s (result 4.07441e+06).
norm()-based manual calculation: 0.0344042s (result 4.07441e+06).
dot()-based norm preparation time: 0.000114631s.
dot()-based calculation: 0.0290754s (result 4.07441e+06).

So you can see that I am getting compute_ed2() close-ish (like within 20% roughly). The big difference between compute_ed1() (which uses arma::norm() directly) and compute_ed2(), I believe, is that the vectorizer fails to properly vectorize the arma::norm() loop. Glancing into the code, I am wondering whether a temporary Armadillo object is being made for the expression set1.col(i) - set2.col(j), but I have not dug deep enough to be sure either way.

Digging deeper would involve compiling with -fopt-info-vec-missed and other similar things to see what exactly is going on. But where I am going with this is that the original distance computation strategy should be at least equivalently fast, and if it is not, we need to put in time to understand why instead of switching it out.

CC: @conradsnicta, would be interested to see what you think. I suppose we could be using norm() wrong or there is something we should use instead. I don't mind putting in some time to debug and optimize this (since distance computations are a huge part of mlpack!) but I think it would be useful to get your quick thoughts on the problem first.

@conradsnicta
Copy link
Contributor

@rcurtin @MarkFischinger Interim hazy perspective below. I'm currently recovering from surgery, so brain is working at reduced capacity.

For simple expressions where the memory is directly accessible in a linear column-major order (so excluding subrows and any compound expressions), the functions arma::dot() and arma::norm() generally just call the corresponding BLAS functions for common cases. This may result in using accelerated OpenBLAS implementations. For arma::norm() there is also an extra check for underflows and overflows, with a fallback to a slow but robust algorithm.

For compound expressions and subrows, arma::dot() and arma::norm() will use built-in algorithms to avoid making copies of data. In some cases these algorithms can be vectorised by the compiler, though the speed difference is around 20% in my (limited) observations on an ancient laptop. The speedup may be higher with CPUs that handle AVX-512.

There is a bunch of caveats here:

  • Various compilers will differ in what they're willing to vectorise, in order to stick with the IEEE floating point standard. For GCC, more vectorisation can occur if the -ffinite-math-only flag is used, but this can come with the possible side-effect of not treating NaN and Inf values properly. Haven't fully checked how picky clang is.

  • AVX instructions generally don't get used with GCC (and possibly clang) unless the -march=native flag is used.

  • For GCC, need to use the -O3 optimisation level for vectorisation to be fully enabled. I think clang enables vectorisation at the -O2 level, though its vectoriser may not be as advanced as GCC's vectoriser at -O3.

Other observations:

@MarkFischinger
Copy link
Contributor Author

@conradsnicta and @rcurtin, thanks for your feedback. I’ve been digging deeper into the performance of the sum of squared differences versus the dot-product based approach. Despite initial expectations that the sum of squared differences would be faster, my benchmarks show that both methods perform almost identically, prompting further investigation.

Here are the initial benchmark results:

Sum of squared differences time: 1.02083 seconds, result: 1.65892e+08
Dot-product based time: 1.037 seconds, result: 1.65892e+08

After applying the -fopt-info-vec flag, both methods were optimized with SIMD instructions using 16-byte vectors from the AVX instruction set. The updated benchmarks were:

Sum of squared differences time: 1.04866 seconds, result: 1.66398e+08
Dot-product based time: 1.03961 seconds, result: 1.66398e+08

Enabling the -ffast-math flag led to a significant reduction in execution times:

Sum of squared differences time: 0.272451 seconds, result: 1.67351e+08
Dot-product based time: 0.253304 seconds, result: 1.67351e+08

These results show how aggressive compiler optimizations impact both methods, bringing their performance closer together.

I am wondering whether a temporary Armadillo object is being made for the expression set1.col(i) - set2.col(j)

I also considered whether temporary Armadillo objects during operations like set1.col(i) - set2.col(j) could be affecting performance. However, due to the use of Proxy<T1> P(X);, which efficiently manages references, this was ruled out as a factor.

For further comparison, I benchmarked our implementation against arma::kmeans().

Benchmark for mlpack KMeans completed.
Total time: 1.09256 seconds
Benchmark for Armadillo KMeans completed.
Total time: 0.00981316 seconds

I'm a bit unsure why the performance diffenerence is so huge. Armadillo is more optimized, but I'm looking into it.

When reviewing the assembly code, I observed that both the sum_of_squared_differences and dot_product_distance functions are using SIMD instructions, particularly with AVX (Advanced Vector Extensions) and FMA operations like vfmadd231pd.
However, the performance gap between the mlpack implementation and arma::kmeans() suggests there’s considerable room for improvement.

Currently, I’m focusing on improving the arma::norm function to see if we can apply similar optimizations.

@rcurtin
Copy link
Member

rcurtin commented Aug 27, 2024

It's not possible to directly compare mlpack's and Armadillo's k-means implementations without further digging. They will run for different numbers of iterations, to different tolerance measures, etc., etc... the way to compare them with each other would be to time how long a single iteration takes, or how long a specific number of iterations takes. (That's not to say that the iteration or termination strategies in mlpack couldn't be improved. But here we are considering only the distance computations, so ensuring they are equal across both implementations is needed.)

@shrit
Copy link
Member

shrit commented Aug 28, 2024

@MarkFischinger, looking at this PR, it seems to me that it is getting further from the objective. For the sake of GSoC, please could you do the following:

  1. Revert the modifications related to the distance calculation
  2. keep the original update regarding the OpenMP
  3. Do not compare mlpack and armadillo implementations since they are not relevant.

Please do all of the above requests by the end of today, if you are interested in optimising the distance calculation let us do that in a different PR and merge this one with OpenMP modification only

Thank you very much.

@rcurtin
Copy link
Member

rcurtin commented Aug 28, 2024

I do think that we need to get to the bottom of why arma::norm() isn't giving good performance in our case, but for the sake of getting something merged, I think just OpenMP-izing the distance calculations (and leaving the optimization of the individual distance calculations for later) is sufficient. We can return to the arma::norm() issue elsewhere---and working that out will have significant benefits throughout the library because of how widely distance calculations are used. 👍

@rcurtin
Copy link
Member

rcurtin commented Aug 29, 2024

There are a lot of changes here that aren't actually functional changes---do you mind going through and reducing the changes to only the actual functionality changes? Thanks 👍

Copy link
Member

@rcurtin rcurtin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recognize that most of my comments will seem trivial, but the underlying idea is to only change code when there is a good reason to do so. If you go through the 'files' tab of the PR that shows the diff, you will see what I mean---there are lots of changes, but they aren't substantial; they are just naming changes or comment changes (or comment deletions, which is really generally not a good idea unless you have a good reason why). If we can reduce this to the minimum possible set of changes, then that will look good to me and we should merge it. 👍

(I still haven't had a chance to look deeper into the Armadillo norm() issues. If you want to open another issue for that feel free, otherwise I will do it when I have a chance.)

arma::Col<size_t> localCounts(centroids.n_cols, arma::fill::zeros);
// Thread-local storage for partial sums
arma::mat threadCentroids(centroids.n_rows, centroids.n_cols, arma::fill::zeros);
arma::Col<size_t> threadCounts(centroids.n_cols, arma::fill::zeros);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes don't actually change anything. Can you revert any changes where there is no actual functionality change? The only change here is a change to the name of the variables, and I don't see any reason to do that (plus it makes the diffs across history unnecessarily larger).

counts.zeros(centroids.n_cols);

// Find the closest centroid to each point and update the new centroids.
// Computed in parallel over the complete dataset
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any reason to remove this comment.

#pragma omp for
for (size_t i = 0; i < (size_t) dataset.n_cols; ++i)
#pragma omp for schedule(static) nowait
for (size_t i = 0; i < dataset.n_cols; ++i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The (size_t) cast fixed a compilation warning for some compilers and we should therefore keep it.

{
// Find the closest centroid to this point.
double minDistance = std::numeric_limits<double>::infinity();
double minDistance = std::numeric_limits<double>::max();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why make this change? infinity() was fine. I mean max() works too, but again, better to avoid making changes when possible.

{
const double dist = distance.Evaluate(dataset.col(i),
centroids.unsafe_col(j));
const double dist = distance.Evaluate(dataset.col(i), centroids.col(j));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe now this line is too long. Can you revert this to the previous version?

localCounts(closestCluster)++;
// Update thread-local centroids and counts
threadCentroids.col(closestCluster) += dataset.col(i);
threadCounts(closestCluster)++;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a reason to remove the assertion or change the comments. I think the switch from .unsafe_col() to .col() is a good change, though. (I believe once upon a time that did not compile correctly, many many years ago.)

{
cNorm += std::pow(distance.Evaluate(centroids.col(i), newCentroids.col(i)),
2.0);
cNorm += std::pow(distance.Evaluate(centroids.col(i), newCentroids.col(i)), 2.0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that now this line is too long too.

} // namespace mlpack

#endif
#endif
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update your editor setting to leave newlines at the end of files?

@MarkFischinger
Copy link
Contributor Author

@rcurtin @shrit I added all the important comments back in, replaced .unsafe_col() with .col() like you suggested, fixed the line issue at the end and removed the arma::zeros, because of this comment.

Running the benchmark again using this benchmark script and this visualization script, I generated this graphic:

benchmark_comparison

Copy link
Member

@rcurtin rcurtin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up---I know that it's not the original level of optimization you were hoping for, but it is still a nice (albeit more modest) speedup, and we uncovered an issue with arma::norm() that is good to know about and indicates that there is a lot of speedup available for various mlpack algorithms if we can get arma::norm() to properly compile into SIMD instructions. We can investigate that later, but for now, let's get this merged. 🚀

Co-authored-by: Ryan Curtin <[email protected]>
@rcurtin
Copy link
Member

rcurtin commented Aug 31, 2024

Thanks for the hard work @MarkFischinger! I opened #3789 to follow up on the EuclideanDistance efficiency issue. We can continue that there.

I think I need to fix the documentation build so that it issues warnings and not errors---in this case, looks like the CS department's website for Boston College has gone offline, causing that build to fail. Oh well.

@rcurtin rcurtin merged commit 40f8414 into mlpack:master Aug 31, 2024
This was referenced Sep 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants