Skip to content

Commit 7259b71

Browse files
akroppakropp
authored andcommitted
Fix LayerNorm
1 parent e64b454 commit 7259b71

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

src/mlpack/methods/ann/layer/layer_norm.hpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ namespace mlpack {
5454
*
5555
* @tparam MatType Type of the input data (arma::colvec, arma::mat,
5656
* arma::sp_mat or arma::cube).
57-
* @tparam MatType Type of the output data (arma::colvec, arma::mat,
58-
* arma::sp_mat or arma::cube).
5957
*/
6058
template <
6159
typename MatType = arma::mat
@@ -71,17 +69,12 @@ class LayerNormType : public Layer<MatType>
7169
*
7270
* @param eps The epsilon added to variance to ensure numerical stability.
7371
*/
74-
LayerNormType(const double eps = 1e-8);
72+
LayerNormType(const double eps);
7573

7674
//! Clone the LayerNormType object. This handles polymorphism correctly.
7775
LayerNormType* Clone() const { return new LayerNormType(*this); }
7876

79-
/**
80-
* Reset the layer parameters.
81-
*/
82-
void Reset();
83-
84-
/**
77+
/**
8578
* Forward pass of Layer Normalization. Transforms the input data
8679
* into zero mean and unit variance, scales the data by a factor gamma and
8780
* shifts it by beta.
@@ -138,9 +131,11 @@ class LayerNormType : public Layer<MatType>
138131
// as the input.
139132
this->outputDimensions = this->inputDimensions;
140133
size = this->inputDimensions[0];
141-
for (size_t i=1; i<this->inputDimensions.size(); i++) size += this->inputDimensions[i];
134+
for (size_t i=1; i<this->inputDimensions.size(); i++) size *= this->inputDimensions[i];
142135
}
143136

137+
void SetWeights(typename MatType::elem_type* /* weightsPtr */) override;
138+
144139
void CustomInitialize(
145140
MatType& /* W */,
146141
const size_t /* elements */) override;

0 commit comments

Comments
 (0)