 
 namespace mlpack {
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-MultiheadAttentionType<InputType, OutputType, RegularizerType>::
+template <typename MatType, typename RegularizerType>
+MultiheadAttentionType<MatType, RegularizerType>::
 MultiheadAttentionType() :
     tgtSeqLen(0),
     srcSeqLen(0),
     embedDim(0),
     numHeads(0),
     headDim(0),
-    attnMask(InputType()),
-    keyPaddingMask(InputType())
+    attnMask(MatType()),
+    keyPaddingMask(MatType())
 {
   // Nothing to do here.
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-MultiheadAttentionType<InputType, OutputType, RegularizerType>::
+template <typename MatType, typename RegularizerType>
+MultiheadAttentionType<MatType, RegularizerType>::
 MultiheadAttentionType(
     const size_t tgtSeqLen,
     const size_t srcSeqLen,
     const size_t embedDim,
     const size_t numHeads,
-    const InputType& attnMask,
-    const InputType& keyPaddingMask) :
+    const MatType& attnMask,
+    const MatType& keyPaddingMask) :
     tgtSeqLen(tgtSeqLen),
     srcSeqLen(srcSeqLen),
     embedDim(embedDim),
@@ -59,36 +59,36 @@ MultiheadAttentionType(
   headDim = embedDim / numHeads;
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::SetWeights(
-    typename OutputType::elem_type* weightsPtr)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::SetWeights(
+    typename MatType::elem_type* weightsPtr)
 {
-  weights = OutputType(weightsPtr, 1, (4 * embedDim + 4) * embedDim, false,
+  weights = MatType(weightsPtr, 1, (4 * embedDim + 4) * embedDim, false,
       true);
 
-  queryWt = OutputType(weightsPtr, embedDim, embedDim, false, true);
-  keyWt = OutputType(weightsPtr + embedDim * embedDim, embedDim, embedDim,
+  queryWt = MatType(weightsPtr, embedDim, embedDim, false, true);
+  keyWt = MatType(weightsPtr + embedDim * embedDim, embedDim, embedDim,
       false, true);
-  valueWt = OutputType(weightsPtr + 2 * embedDim * embedDim, embedDim, embedDim,
+  valueWt = MatType(weightsPtr + 2 * embedDim * embedDim, embedDim, embedDim,
       false, true);
-  outWt = OutputType(weightsPtr + 3 * embedDim * embedDim, embedDim, embedDim,
+  outWt = MatType(weightsPtr + 3 * embedDim * embedDim, embedDim, embedDim,
       false, true);
 
-  qBias = OutputType(weightsPtr + 4 * embedDim * embedDim, embedDim, 1, false,
+  qBias = MatType(weightsPtr + 4 * embedDim * embedDim, embedDim, 1, false,
       true);
-  kBias = OutputType(weightsPtr + (4 * embedDim + 1) * embedDim, embedDim, 1,
+  kBias = MatType(weightsPtr + (4 * embedDim + 1) * embedDim, embedDim, 1,
      false, true);
-  vBias = OutputType(weightsPtr + (4 * embedDim + 2) * embedDim, embedDim, 1,
+  vBias = MatType(weightsPtr + (4 * embedDim + 2) * embedDim, embedDim, 1,
      false, true);
-  outBias = OutputType(weightsPtr + (4 * embedDim + 3) * embedDim, 1, embedDim,
+  outBias = MatType(weightsPtr + (4 * embedDim + 3) * embedDim, 1, embedDim,
      false, true);
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
-Forward(const InputType& input, OutputType& output)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::
+Forward(const MatType& input, MatType& output)
 {
-  typedef typename arma::Cube<typename InputType::elem_type> CubeType;
+  typedef typename arma::Cube<typename MatType::elem_type> CubeType;
 
   if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
   {
@@ -104,12 +104,12 @@ Forward(const InputType& input, OutputType& output)
   // The shape of q : (embedDim, tgtSeqLen, batchSize).
   // The shape of k : (embedDim, srcSeqLen, batchSize).
   // The shape of v : (embedDim, srcSeqLen, batchSize).
-  const CubeType q(const_cast<InputType&>(input).memptr(),
+  const CubeType q(const_cast<MatType&>(input).memptr(),
       embedDim, tgtSeqLen, batchSize, false, false);
-  const CubeType k(const_cast<InputType&>(input).memptr() +
+  const CubeType k(const_cast<MatType&>(input).memptr() +
       embedDim * tgtSeqLen * batchSize,
       embedDim, srcSeqLen, batchSize, false, false);
-  const CubeType v(const_cast<InputType&>(input).memptr() +
+  const CubeType v(const_cast<MatType&>(input).memptr() +
       embedDim * (tgtSeqLen + srcSeqLen) * batchSize,
       embedDim, srcSeqLen, batchSize, false, false);
 
@@ -167,8 +167,8 @@ Forward(const InputType& input, OutputType& output)
 
   for (size_t i = 0; i < numHeads * batchSize; ++i)
   {
-    softmax.Forward(scores.slice(i), softmax.OutputParameter());
-    scores.slice(i) = softmax.OutputParameter();
+    softmax.Forward(scores.slice(i), softmaxOutput);
+    scores.slice(i) = softmaxOutput;
   }
 
   // Calculate the attention output i.e. matrix multiplication of softmax
@@ -188,13 +188,13 @@ Forward(const InputType& input, OutputType& output)
   }
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
-Backward(const InputType& /* input */,
-         const OutputType& gy,
-         OutputType& g)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::
+Backward(const MatType& /* input */,
+         const MatType& gy,
+         MatType& g)
 {
-  typedef typename arma::Cube<typename OutputType::elem_type> CubeType;
+  typedef typename arma::Cube<typename MatType::elem_type> CubeType;
 
   if (gy.n_rows != tgtSeqLen * embedDim)
   {
@@ -208,7 +208,7 @@ Backward(const InputType& /* input */,
   // The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
   // We need not split it into n heads now because this is the part when
   // output were concatenated from n heads.
-  CubeType gyTemp(const_cast<OutputType&>(gy).memptr(), embedDim,
+  CubeType gyTemp(const_cast<MatType&>(gy).memptr(), embedDim,
       tgtSeqLen, batchSize, true, false);
 
   // The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
@@ -278,13 +278,13 @@ Backward(const InputType& /* input */,
   }
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
-Gradient(const InputType& input,
-         const OutputType& error,
-         OutputType& gradient)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::
+Gradient(const MatType& input,
+         const MatType& error,
+         MatType& gradient)
 {
-  typedef typename arma::Cube<typename InputType::elem_type> CubeType;
+  typedef typename arma::Cube<typename MatType::elem_type> CubeType;
 
   if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
   {
@@ -302,16 +302,16 @@ Gradient(const InputType& input,
   // The shape of gradient : (4 * embedDim * embedDim + 4 * embedDim, 1).
   gradient.set_size(arma::size(weights));
 
-  const CubeType q(const_cast<InputType&>(input).memptr(),
+  const CubeType q(const_cast<MatType&>(input).memptr(),
       embedDim, tgtSeqLen, batchSize, false, false);
-  const CubeType k(const_cast<InputType&>(input).memptr() + q.n_elem,
+  const CubeType k(const_cast<MatType&>(input).memptr() + q.n_elem,
       embedDim, srcSeqLen, batchSize, false, false);
-  const CubeType v(const_cast<InputType&>(input).memptr() + q.n_elem + k.n_elem,
+  const CubeType v(const_cast<MatType&>(input).memptr() + q.n_elem + k.n_elem,
       embedDim, srcSeqLen, batchSize, false, false);
 
   // Reshape the propagated error into a cube.
   // The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
-  CubeType errorTemp(const_cast<OutputType&>(error).memptr(), embedDim,
+  CubeType errorTemp(const_cast<MatType&>(error).memptr(), embedDim,
       tgtSeqLen, batchSize, true, false);
 
   // Gradient wrt. outBias, i.e. dL/d(outBias).
@@ -425,12 +425,12 @@ Gradient(const InputType& input,
   regularizer.Evaluate(weights, gradient);
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
+template <typename MatType, typename RegularizerType>
 template <typename Archive>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
+void MultiheadAttentionType<MatType, RegularizerType>::
 serialize(Archive& ar, const uint32_t /* version */)
 {
-  ar(cereal::base_class<Layer<InputType, OutputType>>(this));
+  ar(cereal::base_class<Layer<MatType>>(this));
 
   ar(CEREAL_NVP(tgtSeqLen));
   ar(CEREAL_NVP(srcSeqLen));
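
For reference, a minimal usage sketch of the layer after this refactor: the class is now parameterized on a single matrix type, so the same MatType describes the weights, the masks, and the input and output matrices. The sizes below are illustrative, and the single-argument instantiation assumes the header's defaults (RegularizerType = NoRegularizer and empty masks), which are not shown in this diff; in a real model the enclosing FFN would allocate the parameters and size the output rather than doing it by hand as here.

#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  // Illustrative sizes; embedDim must be divisible by numHeads.
  const size_t tgtSeqLen = 4, srcSeqLen = 4, embedDim = 8, numHeads = 2;
  const size_t batchSize = 2;

  // One MatType (arma::mat here) now drives the whole layer.
  MultiheadAttentionType<arma::mat> attn(tgtSeqLen, srcSeqLen, embedDim,
      numHeads);

  // The packed parameter vector has (4 * embedDim + 4) * embedDim elements,
  // laid out exactly as the SetWeights() aliases above expect.
  arma::mat weights(1, (4 * embedDim + 4) * embedDim, arma::fill::randn);
  attn.SetWeights(weights.memptr());

  // Forward() expects the query, key, and value embeddings concatenated
  // row-wise: embedDim * (tgtSeqLen + 2 * srcSeqLen) rows, one column per
  // batch item; the output is pre-sized to (embedDim * tgtSeqLen, batchSize).
  arma::mat input(embedDim * (tgtSeqLen + 2 * srcSeqLen), batchSize,
      arma::fill::randn);
  arma::mat output(embedDim * tgtSeqLen, batchSize);
  attn.Forward(input, output);

  return 0;
}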