-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Document BayesianLinearRegression
#3578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Document BayesianLinearRegression
#3578
Conversation
zoq
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work.
shrit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good 👍 Thanks for the hardwork
| // Get the row vector type corresponding to a given MatType. | ||
|
|
||
| template<typename MatType> | ||
| struct GetRowType | ||
| { | ||
| typedef arma::Row<typename MatType::elem_type> type; | ||
| }; | ||
|
|
||
| template<typename eT> | ||
| struct GetRowType<arma::Mat<eT>> | ||
| { | ||
| typedef arma::Row<eT> type; | ||
| }; | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have already added these right ? I think it will be resolved while merging
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I will resolve this during the merge. It should be the same in several branches. 👍
| */ | ||
| double Train(const arma::mat& data, | ||
| const arma::rowvec& responses); | ||
| // Many overloads necessary here until std::optional is available with C++17. | ||
| // The first overload is also necessary to avoid confusing the hyperparameter | ||
| // tuner, so that this can be correctly detected as a regression algorithm. | ||
| template<typename MatType> | ||
| ElemType Train(const MatType& data, | ||
| const arma::rowvec& responses); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to have the comment, I was wondering why and then I read it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It took me hours to debug it the first time I encountered it... the algorithm was being detected as a classifier and it turns out the way to fix it was this extra overload. I wonder if it might be better to have each classifier/regressor come with some traits that describe some things about it, but that can be a project for another time.
This is the last of mlpack's simple regression techniques (at least I think), and before I document the rest of mlpack, I'm going to turn my next efforts towards putting together the other parts of the test framework. So, this'll be the last algorithm documented for at least a little while.
Rendered documentation can be seen at https://www.ratml.org/misc/mlpack-markdown-doc/bayesian_linear_regression.html
@mercierc if you have time to take a look at what I've done to the code, please let me know if I made any errors or if you have any suggestions 😄
Here are a summary of the changes I made:
Predict()overloads.Train(). A bunch of overloads are necessary, because until we require C++17 support, we can't usestd::optional.ModelMatTypethat specifies the type of the model to be stored. This required a number of internal changes, including adding template types to allTrain()andPredict()overloads to allow more flexibility in what users pass in.At the moment
BayesianLinearRegressionwill not work if you pass in sparse data (arma::sp_mat), but that is something that can be improved in a future day. Actually, for that, Armadillo is missing sparsestddev()support, which I'll do at some point in the future.