
Conversation

@MarkFischinger
Contributor

Following our discussion in issue #3662, I've implemented the Padé approximant in the log softmax layer. Due to time constraints, I haven't run the tests yet, but I plan to do so shortly and update you with the results.

@MarkFischinger MarkFischinger changed the title Implementation - Pade Approximant in Log Softmax Layer Implementation - Padé Approximant in Log Softmax Layer Apr 10, 2024
Member

@shrit shrit left a comment


Would you use mlpack style for the variables?

@MarkFischinger
Contributor Author

@shrit, thank you for pointing that out. The new commit includes the fix :)

@shrit
Member

shrit commented Apr 11, 2024

I approved this one too quickly; I did not see that the tests were not passing.
@MarkFischinger could you try to run the tests locally?
It would be nice to compare the matrices generated by the original fast method and Padé, because I think there is a good amount of difference; otherwise the tests would not have failed.

Member

@rcurtin rcurtin left a comment


Just preventing mlpack-bot from auto-approving until we get the fixes worked out. I'm guessing that the level of approximation might be too high, and things are not converging? (Maybe a threshold like x < 13 is needed?)

};

output.transform([padeApproximant](double x) {
  return padeApproximant(x);
Member


I think it would be a bit cleaner to just inline the whole approximant into the lambda, but it's up to you.

@MarkFischinger
Contributor Author

@shrit @rcurtin Sorry for the delay with the benchmarks; I needed to run a more detailed analysis to find an effective solution.
Here’s what I found:

MNIST Simple

Old implementation:
Validation loss: 567.577
Duration: 79709 ms
Loss: 0.0331713
Accuracy: train = 98.463%, valid = 97.0707%

New implementation (scale 4, x < 13.0):
Validation loss: 563.98
Duration: 80412 ms
Loss: 0.049433
Accuracy: train = 98.3571%, valid = 96.9517%

The initial idea of only adding the x < 13.0 check proved too broad, leading to uncontrolled error spikes from the large $x$ values I had been concerned about. The example in the discussion issue featured only small $x$ values, which work perfectly with the Padé approximation, but large values (above 8.0) do not work nearly as well. By scaling $x$ by $4$, I reduced the error to a level notably smaller than in the old version, as you can see in this graph:

(Image: errors_and_time_4_4_fair)

Although the graph shows a seemingly doubled runtime, the actual difference in the CNN run is minor. Could this improvement be a viable option for the implementation? What do you think?

auto scaledPadeApproxExpMinusX = [](double x) {
  if (x < 13.0) {
    // Scale x down so the Padé approximant stays accurate, then undo the
    // scaling at the end via exp(-x) = (exp(-x / s))^s.
    double s = 4.0;
    double xs = x / s;

    double numerator = 24 - 12 * xs + 4 * xs * xs - xs * xs * xs;
    double denominator = 24 + 12 * xs + 4 * xs * xs + xs * xs * xs;

    double pade = numerator / denominator;
    return std::pow(pade, s);
  }

  // For x >= 13, exp(-x) is small enough to treat as zero.
  return 0.0;
};

output.transform([scaledPadeApproxExpMinusX](double x) {
  return scaledPadeApproxExpMinusX(x);
});

I think I will also test the algorithm on mnist_cnn soon.

@shrit
Member

shrit commented Apr 16, 2024

@MarkFischinger give it a try. What I find weird is that when we tested this separately, the time was way faster, while here it looks much slower than the original one.
This is worth investigating.

@MarkFischinger
Contributor Author

Hey @shrit, I think the trouble we're seeing comes from the higher $x$ values in our MNIST examples. Originally, we only saw $x$ values above $4$ about $2.275$% of the time (the probability of a draw exceeding two standard deviations), based on our normal-distribution setup with arma::mat output = arma::randn(1000, 1000, arma::distr_param(0, 2)). But the MNIST data frequently shows much higher $x$ values, which caused those spikes and ultimately broke the code. That's why I had to scale them down, which did slow our runtime a bit.

Here are some of the $x$ values from the MNIST example (output):

x = 8.78431
x = 23.1975
x = 22.0821
x = 16.2784
x = 18.5682

I'm thinking, since lower $x$ values are more common and are handled better, maybe we should try the original Padé approximation for values up to, say, $4$, and keep our current method as a fallback for anything higher. This way we can handle the typical cases fast and still catch any outliers without problems. What do you think? Should I run some benchmarks on this mixed approach?
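For concreteness, a rough sketch of what that mixed approach could look like (hypothetical: the 4.0 cutoff and the std::exp fallback are only placeholders for whatever threshold and fast method we actually settle on):

#include <cmath>

// Sketch of the mixed approach: plain Padé approximant of exp(-x) for small
// x, and a stand-in fallback for everything larger.
auto hybridExpMinusX = [](double x) {
  if (x < 4.0) {
    // Unscaled Padé approximant, as in the discussion issue.
    double numerator = 24 - 12 * x + 4 * x * x - x * x * x;
    double denominator = 24 + 12 * x + 4 * x * x + x * x * x;
    return numerator / denominator;
  }

  // Placeholder fallback; in the layer this would be the existing fast
  // approximation rather than std::exp.
  return std::exp(-x);
};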

@rcurtin
Member

rcurtin commented Apr 16, 2024

Yeah, a switch to the existing implementation at about x > 4 would probably do the trick for convergence too. I would be interested to see if it would be faster, too; although, to check that, you'd need to ensure that the number of epochs used for training is constant (or just time a single epoch; that's fine too).

The scaling trick is definitely a good one for convergence, but I suspect the std::pow is painful and is what causes it to be slower.
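If the scaling trick stays, one way to sidestep std::pow might be repeated squaring, since the scale is the fixed integer 4; a quick sketch of the lambda from above (not benchmarked):

auto scaledPadeApproxExpMinusX = [](double x) {
  if (x < 13.0) {
    double xs = x / 4.0;

    double numerator = 24 - 12 * xs + 4 * xs * xs - xs * xs * xs;
    double denominator = 24 + 12 * xs + 4 * xs * xs + xs * xs * xs;
    double pade = numerator / denominator;

    // pade^4 via two multiplications instead of std::pow(pade, 4.0).
    double p2 = pade * pade;
    return p2 * p2;
  }

  return 0.0;
};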

@MarkFischinger
Contributor Author

@rcurtin I did some backtesting, and the results unfortunately showed no or only minor improvements in runtime. Statistically, combining those two algorithms should reduce the error, but I'm still looking for faster implementations because I'm hopeful that I can find a better solution. I'll update you as soon as possible, though my available time will be limited for the next few days due to exams :/

@rcurtin
Member

rcurtin commented May 6, 2024

I also wonder if it would be possible to get additional speedup by writing a loop that the compiler can autovectorize. I doubt that .transform() on arbitrary lambda functions could make use of it. In Armadillo this is often done with a for loop that computes values for 2 or 4 items at once, and is written such that each operation for each element has no data dependency on previous iterations of the loop. I can probably find an example if you like, just let me know if you want to try that out. 👍
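Roughly, the kind of loop I have in mind (just a sketch, assuming output is an arma::mat; applyApprox() is a placeholder for whatever element-wise function ends up being used):

double* mem = output.memptr();
const arma::uword n = output.n_elem;
arma::uword i = 0;
// Four elements per iteration, and no iteration depends on the result of a
// previous one, so the compiler is free to vectorize.
for (; i + 3 < n; i += 4)
{
  mem[i]     = applyApprox(mem[i]);
  mem[i + 1] = applyApprox(mem[i + 1]);
  mem[i + 2] = applyApprox(mem[i + 2]);
  mem[i + 3] = applyApprox(mem[i + 3]);
}
// Cleanup loop for the remaining elements.
for (; i < n; ++i)
  mem[i] = applyApprox(mem[i]);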

@mlpack-bot

mlpack-bot bot commented Jun 5, 2024

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@MarkFischinger
Contributor Author

@shrit @rcurtin I'm adding targeted OpenMP parallelization to the Padé approximant computation. I've applied the parallelization only to this approximant, not to the fast approximation function, which doesn't gain from parallel execution here.

Here are the benchmarks of the commits:
(Image: benchmarks of the commits)

The two have almost identical error, so that should not be an issue.
(Image: error comparison)

I'm also working on and testing one more optimization :)

@rcurtin
Member

rcurtin commented Jun 24, 2024

Nice, the numbers look good. 👍 Two quick checks:

  • Does the mnist_cnn example still converge to the same accuracy? That is just a quick sanity check to ensure that the changed error rates are not a problem.

  • Was the original implementation OpenMP-ized? If not, it could be worth a quick test to add OpenMP there too, just to make sure that the speedup you are seeing is not entirely just due to OpenMP.

@MarkFischinger
Contributor Author

@rcurtin You're right, the default fast approx with OpenMP is quicker. I've gone ahead and committed the OpenMP-ized fast approx version :)

Member

@rcurtin rcurtin left a comment


It's unfortunate the Padé approximant didn't work out, but sometimes that's how it goes; I definitely know the feeling of spending many weeks investigating something only to find at the end of the day that the clever idea didn't work out. Do you think you want to call it case closed on the Padé approximant, or do you still have some ideas or thoughts that might help out?

As a side note, I would focus not on microbenchmarks (although they are definitely useful!) but on the actual real-world measured time for training in, e.g., the mnist_cnn example. Personally, I use microbenchmarks to guide my development, but the "ground truth" I use to decide whether anything really is faster is an actual fully-working example, since this is closer to what a user will see in practice. 👍

maxInput.each_row() += log(sum(output));
#pragma omp parallel for collapse(2)
for (size_t i = 0; i < output.n_rows; i++) {
  for (size_t j = 0; j < output.n_cols; j++) {
Member


You could just iterate over output.n_elem; the code might be a little simpler. Also, do you think you can fix the style? e.g.

if (condition)
{
  stuff
}
else
{
  other stuff
}

instead of

if (condition) {
  stuff
} else {
  other stuff
}
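
Putting both suggestions together, something like this (just a sketch; someElementwiseFunction() is a placeholder for whatever the loop body actually computes):

#pragma omp parallel for
for (size_t i = 0; i < output.n_elem; i++)
{
  // Linear indexing covers every element, so the nested row/column loops
  // and collapse(2) are not needed.
  output[i] = someElementwiseFunction(output[i]);
}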

return *this;
}


Member


Suggested change

No need for an extra line 👍


template<typename MatType>
void LogSoftMaxType<MatType>::Forward(const MatType& input, MatType& output)
void LogSoftMaxType<MatType>::ForwardImpl(const MatType& input,
Member


How does ForwardImpl() get called? I think we also need a generic implementation that will work with Bandicoot (although I know that will be slower).

@rcurtin
Member

rcurtin commented Jun 25, 2024

Also, it would be great if you could provide some updated numbers on how much OpenMP helps, etc., just to give an idea; I haven't run the updated code myself, but I'm sure that you have in your experiments. 👍

@MarkFischinger MarkFischinger force-pushed the feat/pade_approximant_implementation branch from f0646a6 to 822c499 on June 26, 2024 10:18
@MarkFischinger
Contributor Author

@rcurtin I'll return to this if I get another idea, but the current commit should be merged, because the OpenMP fast approximation shows a noticeable speed boost. Here are the MNIST dataset benchmarks :)

Fast Approximation without OpenMP

756/756 [====================================================================================================] 100% - 79.897s/epoch; 105ms/step; loss: 9.42397

Validation loss: 10191.9.
Time elapsed: 82497ms
Accuracy: train = 70.4048%,      valid = 70.1429%

Fast Approximation with OpenMP

756/756 [====================================================================================================] 100% - 73.3507s/epoch; 97ms/step; loss: 9.42397
Validation loss: 10191.9.
Time elapsed: 75912ms
Accuracy: train = 70.4048%,      valid = 70.1429%

Padé approximant with OpenMP

756/756 [====================================================================================================] 100% - 74.6628s/epoch; 98ms/step; loss: 19.4617
Validation loss: 76287.9.
Time elapsed: 77906ms

Member

@rcurtin rcurtin left a comment


@MarkFischinger nice work, great to see we did end up getting an improvement here (even if it's not the one you were originally looking for :)).

Do you want to add a small bullet point to HISTORY.md indicating the speedup?

@conradsnicta just an FYI: we got some speedup in this PR over using .transform() and .each_col(), but the primary source of the speedup is OpenMP parallelization. I notice that neither .transform() nor .each_col() is OpenMP-ized; wanted to call it to your attention in case you wanted to incorporate that into the next Armadillo release.

Member

@shrit shrit left a comment


In the end, we started with Padé but ended up with the original implementation parallelized and the transform function removed.

@shrit shrit merged commit 0cf9a82 into mlpack:master Jul 3, 2024