@@ -54,8 +54,6 @@ namespace mlpack {
5454 *
5555 * @tparam MatType Type of the input data (arma::colvec, arma::mat,
5656 * arma::sp_mat or arma::cube).
57- * @tparam MatType Type of the output data (arma::colvec, arma::mat,
58- * arma::sp_mat or arma::cube).
5957 */
6058template <
6159 typename MatType = arma::mat
@@ -71,17 +69,12 @@ class LayerNormType : public Layer<MatType>
7169 *
7270 * @param eps The epsilon added to variance to ensure numerical stability.
7371 */
74- LayerNormType (const double eps = 1e-8 );
72+ LayerNormType (const double eps);
7573
7674 // ! Clone the LayerNormType object. This handles polymorphism correctly.
7775 LayerNormType* Clone () const { return new LayerNormType (*this ); }
7876
79- /* *
80- * Reset the layer parameters.
81- */
82- void Reset ();
83-
84- /* *
77+ /* *
8578 * Forward pass of Layer Normalization. Transforms the input data
8679 * into zero mean and unit variance, scales the data by a factor gamma and
8780 * shifts it by beta.
@@ -138,9 +131,11 @@ class LayerNormType : public Layer<MatType>
138131 // as the input.
139132 this ->outputDimensions = this ->inputDimensions ;
140133 size = this ->inputDimensions [0 ];
141- for (size_t i=1 ; i<this ->inputDimensions .size (); i++) size + = this ->inputDimensions [i];
134+ for (size_t i=1 ; i<this ->inputDimensions .size (); i++) size * = this ->inputDimensions [i];
142135 }
143136
137+ void SetWeights (typename MatType::elem_type* /* weightsPtr */ ) override ;
138+
144139 void CustomInitialize (
145140 MatType& /* W */ ,
146141 const size_t /* elements */ ) override ;
0 commit comments