Skip to content

Conversation

@rcurtin
Copy link
Member

@rcurtin rcurtin commented Dec 15, 2023

This is the last of mlpack's simple regression techniques (at least I think), and before I document the rest of mlpack, I'm going to turn my next efforts towards putting together the other parts of the test framework. So, this'll be the last algorithm documented for at least a little while.

Rendered documentation can be seen at https://www.ratml.org/misc/mlpack-markdown-doc/bayesian_linear_regression.html

@mercierc if you have time to take a look at what I've done to the code, please let me know if I made any errors or if you have any suggestions 😄

Here are a summary of the changes I made:

  • Added a version of the constructor that can also train.
  • Added single-point Predict() overloads.
  • Added support for passing hyperparameters to Train(). A bunch of overloads are necessary, because until we require C++17 support, we can't use std::optional.
  • Added a template parameter ModelMatType that specifies the type of the model to be stored. This required a number of internal changes, including adding template types to all Train() and Predict() overloads to allow more flexibility in what users pass in.

At the moment BayesianLinearRegression will not work if you pass in sparse data (arma::sp_mat), but that is something that can be improved in a future day. Actually, for that, Armadillo is missing sparse stddev() support, which I'll do at some point in the future.

Copy link
Member

@zoq zoq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work.

Copy link
Member

@shrit shrit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good 👍 Thanks for the hardwork

Comment on lines 114 to 127
// Get the row vector type corresponding to a given MatType.

template<typename MatType>
struct GetRowType
{
typedef arma::Row<typename MatType::elem_type> type;
};

template<typename eT>
struct GetRowType<arma::Mat<eT>>
{
typedef arma::Row<eT> type;
};

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have already added these right ? I think it will be resolved while merging

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I will resolve this during the merge. It should be the same in several branches. 👍

Comment on lines 167 to +173
*/
double Train(const arma::mat& data,
const arma::rowvec& responses);
// Many overloads necessary here until std::optional is available with C++17.
// The first overload is also necessary to avoid confusing the hyperparameter
// tuner, so that this can be correctly detected as a regression algorithm.
template<typename MatType>
ElemType Train(const MatType& data,
const arma::rowvec& responses);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to have the comment, I was wondering why and then I read it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me hours to debug it the first time I encountered it... the algorithm was being detected as a classifier and it turns out the way to fix it was this extra overload. I wonder if it might be better to have each classifier/regressor come with some traits that describe some things about it, but that can be a project for another time.

@rcurtin rcurtin merged commit f743369 into mlpack:master Dec 22, 2023
@rcurtin rcurtin deleted the bayesian_linear_regression_doc branch December 22, 2023 14:45
@rcurtin rcurtin mentioned this pull request May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants