
Commit 20f646e

akropp authored and committed

Adapt multihead_attention

1 parent 202e7f5 · commit 20f646e

File tree: 2 files changed (+86, -89 lines)

src/mlpack/methods/ann/layer/not_adapted/multihead_attention.hpp

Lines changed: 37 additions & 40 deletions
```diff
@@ -48,18 +48,15 @@ namespace mlpack {
  * of shape `(embedDim * tgtSeqLen, batchSize)`. The embeddings are stored
  * consecutively.
  *
- * @tparam InputType Type of the input data (arma::colvec, arma::mat,
- *     arma::sp_mat or arma::cube).
- * @tparam OutputType Type of the output data (arma::colvec, arma::mat,
+ * @tparam MatType Type of the input/output data (arma::colvec, arma::mat,
  *     arma::sp_mat or arma::cube).
  * @tparam RegularizerType Type of the regularizer to be used.
  */
 template <
-    typename InputType = arma::mat,
-    typename OutputType = arma::mat,
+    typename MatType = arma::mat,
     typename RegularizerType = NoRegularizer
 >
-class MultiheadAttentionType : public Layer<InputType, OutputType>
+class MultiheadAttentionType : public Layer<MatType>
 {
  public:
   /**
```

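The net effect of this hunk is that callers supply a single matrix type instead of separate input/output types. A minimal instantiation sketch (hedged: it assumes the layer builds once it leaves `not_adapted/`; the dimensions are illustrative only):

```c++
#include <mlpack.hpp>

using namespace mlpack;

// One MatType now drives both input and output; a single-precision variant:
MultiheadAttentionType<arma::fmat, NoRegularizer> attnF(
    /* tgtSeqLen */ 10, /* srcSeqLen */ 10, /* embedDim */ 64, /* numHeads */ 8);

// The default spelling (arma::mat, NoRegularizer):
MultiheadAttentionType<> attnD(10, 10, 64, 8);
```
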
```diff
@@ -82,20 +79,20 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
                          const size_t srcSeqLen,
                          const size_t embedDim,
                          const size_t numHeads,
-                         const InputType& attnmask = InputType(),
-                         const InputType& keyPaddingMask = InputType());
+                         const MatType& attnmask = MatType(),
+                         const MatType& keyPaddingMask = MatType());
 
   //! Clone the MultiheadAttentionType object. This handles polymorphism
   //! correctly.
-  MultiheadAttentionType* Clone() const
+  MultiheadAttentionType* Clone() const override
   {
     return new MultiheadAttentionType(*this);
   }
 
   /**
    * Reset the layer parameters.
    */
-  void SetWeights(typename OutputType::elem_type* weightsPtr);
+  void SetWeights(typename MatType::elem_type* weightsPtr);
 
   /**
    * Ordinary feed forward pass of a neural network, evaluating the function
```

```diff
@@ -104,7 +101,7 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
    * @param input The query matrix.
    * @param output Resulting output activation.
    */
-  void Forward(const InputType& input, OutputType& output);
+  void Forward(const MatType& input, MatType& output) override;
 
   /**
    * Ordinary feed backward pass of a neural network, calculating the function
```

```diff
@@ -114,9 +111,9 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
    * @param gy The backpropagated error.
    * @param g The calculated gradient.
    */
-  void Backward(const InputType& /* input */,
-                const OutputType& gy,
-                OutputType& g);
+  void Backward(const MatType& /* input */,
+                const MatType& gy,
+                MatType& g) override;
 
   /**
    * Calculate the gradient using the output delta and the input activation.
```

```diff
@@ -125,12 +122,12 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
    * @param error The calculated error.
    * @param gradient The calculated gradient.
    */
-  void Gradient(const InputType& input,
-                const OutputType& error,
-                OutputType& gradient);
+  void Gradient(const MatType& input,
+                const MatType& error,
+                MatType& gradient) override;
 
   //! Get the size of the weights.
-  size_t WeightSize() const { return 4 * (embedDim + 1) * embedDim; }
+  size_t WeightSize() const override { return 4 * (embedDim + 1) * embedDim; }
 
   /**
    * Serialize the layer.
```

```diff
@@ -159,22 +156,20 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
   size_t& NumHeads() { return numHeads; }
 
   //! Get the two dimensional Attention Mask.
-  OutputType const& AttentionMask() const { return attnMask; }
+  MatType const& AttentionMask() const { return attnMask; }
   //! Modify the two dimensional Attention Mask.
-  OutputType& AttentionMask() { return attnMask; }
+  MatType& AttentionMask() { return attnMask; }
 
   //! Get Key Padding Mask.
-  OutputType const& KeyPaddingMask() const { return keyPaddingMask; }
+  MatType const& KeyPaddingMask() const { return keyPaddingMask; }
   //! Modify the Key Padding Mask.
-  OutputType& KeyPaddingMask() { return keyPaddingMask; }
-
-  const size_t WeightSize() const { return (4 * embedDim + 4) * embedDim; }
+  MatType& KeyPaddingMask() { return keyPaddingMask; }
 
   const std::vector<size_t> OutputDimensions() const
   {
     // This returns the output as a 2-dimensional (embedDim * tgtSeqLen)
     // matrix.
-    std::vector<size_t> outputDimensions(inputDimensions.size(), 1);
+    std::vector<size_t> outputDimensions(this->inputDimensions.size(), 1);
     outputDimensions[0] = embedDim;
     outputDimensions[1] = tgtSeqLen;
```

```diff
@@ -188,7 +183,7 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
 
  private:
   //! Element Type of the output.
-  typedef typename OutputType::elem_type ElemType;
+  typedef typename MatType::elem_type ElemType;
 
   //! Target sequence length.
   size_t tgtSeqLen;
```

```diff
@@ -206,37 +201,37 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
   size_t headDim;
 
   //! Two dimensional Attention Mask of shape (tgtSeqLen, srcSeqLen).
-  OutputType attnMask;
+  MatType attnMask;
 
   //! Key Padding Mask.
-  OutputType keyPaddingMask;
+  MatType keyPaddingMask;
 
   //! Locally-stored weight matrix associated with query.
-  OutputType queryWt;
+  MatType queryWt;
 
   //! Locally-stored weight matrix associated with key.
-  OutputType keyWt;
+  MatType keyWt;
 
   //! Locally-stored weight matrix associated with value.
-  OutputType valueWt;
+  MatType valueWt;
 
   //! Locally-stored weight matrix associated with attnWt.
-  OutputType outWt;
+  MatType outWt;
 
   //! Locally-stored bias associated with query.
-  OutputType qBias;
+  MatType qBias;
 
   //! Locally-stored bias associated with key.
-  OutputType kBias;
+  MatType kBias;
 
   //! Locally-stored bias associated with value.
-  OutputType vBias;
+  MatType vBias;
 
   //! Locally-stored bias associated with attnWt.
-  OutputType outBias;
+  MatType outBias;
 
   //! Locally-stored weights parameter.
-  OutputType weights;
+  MatType weights;
 
   //! Locally-stored projected query matrix over linear layer.
   arma::Cube<ElemType> qProj;
```

```diff
@@ -254,15 +249,17 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
   arma::Cube<ElemType> attnOut;
 
   //! Softmax layer to represent the probabilities of next sequence.
-  Softmax softmax;
+  SoftmaxType<MatType> softmax;
+
+  // Temporary storage for softmax output.
+  MatType softmaxOutput;
 
   //! Locally-stored regularizer object.
   RegularizerType regularizer;
 }; // class MultiheadAttention
 
 // Standard MultiheadAttention layer using no regularization.
-typedef MultiheadAttentionType<arma::mat, arma::mat, NoRegularizer>
-    MultiheadAttention;
+typedef MultiheadAttentionType<arma::mat, NoRegularizer> MultiheadAttention;
 
 } // namespace mlpack
```

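Taken together, the header now exposes the layer as `MultiheadAttention` with one matrix template argument. A shape-oriented usage sketch (hedged: dimensions are illustrative, and in practice the layer sits inside an `FFN` that owns the weight memory; here the weight aliases are backed by hand via `SetWeights()`):

```c++
#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  const size_t tgtSeqLen = 10, srcSeqLen = 10, embedDim = 64, numHeads = 8;
  const size_t batchSize = 32;

  MultiheadAttention attn(tgtSeqLen, srcSeqLen, embedDim, numHeads);

  // Query, key, and value embeddings are packed row-wise, so Forward()
  // expects embedDim * (tgtSeqLen + 2 * srcSeqLen) rows per batch column.
  arma::mat input(embedDim * (tgtSeqLen + 2 * srcSeqLen), batchSize,
      arma::fill::randu);

  // Back the layer's weight aliases with real memory before Forward().
  arma::mat weights(attn.WeightSize(), 1, arma::fill::randn);
  attn.SetWeights(weights.memptr());

  arma::mat output;
  attn.Forward(input, output);  // output: (embedDim * tgtSeqLen, batchSize)
}
```
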

src/mlpack/methods/ann/layer/not_adapted/multihead_attention_impl.hpp

Lines changed: 49 additions & 49 deletions
```diff
@@ -20,29 +20,29 @@
 
 namespace mlpack {
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-MultiheadAttentionType<InputType, OutputType, RegularizerType>::
+template <typename MatType, typename RegularizerType>
+MultiheadAttentionType<MatType, RegularizerType>::
 MultiheadAttentionType() :
     tgtSeqLen(0),
     srcSeqLen(0),
     embedDim(0),
     numHeads(0),
     headDim(0),
-    attnMask(InputType()),
-    keyPaddingMask(InputType())
+    attnMask(MatType()),
+    keyPaddingMask(MatType())
 {
   // Nothing to do here.
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-MultiheadAttentionType<InputType, OutputType, RegularizerType>::
+template <typename MatType, typename RegularizerType>
+MultiheadAttentionType<MatType, RegularizerType>::
 MultiheadAttentionType(
     const size_t tgtSeqLen,
     const size_t srcSeqLen,
     const size_t embedDim,
     const size_t numHeads,
-    const InputType& attnMask,
-    const InputType& keyPaddingMask) :
+    const MatType& attnMask,
+    const MatType& keyPaddingMask) :
     tgtSeqLen(tgtSeqLen),
     srcSeqLen(srcSeqLen),
     embedDim(embedDim),
```

```diff
@@ -59,36 +59,36 @@ MultiheadAttentionType(
   headDim = embedDim / numHeads;
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::SetWeights(
-    typename OutputType::elem_type* weightsPtr)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::SetWeights(
+    typename MatType::elem_type* weightsPtr)
 {
-  weights = OutputType(weightsPtr, 1, (4 * embedDim + 4) * embedDim, false,
+  weights = MatType(weightsPtr, 1, (4 * embedDim + 4) * embedDim, false,
       true);
 
-  queryWt = OutputType(weightsPtr, embedDim, embedDim, false, true);
-  keyWt = OutputType(weightsPtr + embedDim * embedDim, embedDim, embedDim,
+  queryWt = MatType(weightsPtr, embedDim, embedDim, false, true);
+  keyWt = MatType(weightsPtr + embedDim * embedDim, embedDim, embedDim,
       false, true);
-  valueWt = OutputType(weightsPtr + 2 * embedDim * embedDim, embedDim, embedDim,
+  valueWt = MatType(weightsPtr + 2 * embedDim * embedDim, embedDim, embedDim,
       false, true);
-  outWt = OutputType(weightsPtr + 3 * embedDim * embedDim, embedDim, embedDim,
+  outWt = MatType(weightsPtr + 3 * embedDim * embedDim, embedDim, embedDim,
       false, true);
 
-  qBias = OutputType(weightsPtr + 4 * embedDim * embedDim, embedDim, 1, false,
+  qBias = MatType(weightsPtr + 4 * embedDim * embedDim, embedDim, 1, false,
       true);
-  kBias = OutputType(weightsPtr + (4 * embedDim + 1) * embedDim, embedDim, 1,
+  kBias = MatType(weightsPtr + (4 * embedDim + 1) * embedDim, embedDim, 1,
       false, true);
-  vBias = OutputType(weightsPtr + (4 * embedDim + 2) * embedDim, embedDim, 1,
+  vBias = MatType(weightsPtr + (4 * embedDim + 2) * embedDim, embedDim, 1,
       false, true);
-  outBias = OutputType(weightsPtr + (4 * embedDim + 3) * embedDim, 1, embedDim,
+  outBias = MatType(weightsPtr + (4 * embedDim + 3) * embedDim, 1, embedDim,
       false, true);
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
-Forward(const InputType& input, OutputType& output)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::
+Forward(const MatType& input, MatType& output)
 {
-  typedef typename arma::Cube<typename InputType::elem_type> CubeType;
+  typedef typename arma::Cube<typename MatType::elem_type> CubeType;
 
   if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
   {
```

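All eight assignments in `SetWeights()` are non-copying Armadillo aliases (`copy_aux_mem = false`, `strict = true`) into one flat buffer of `(4 * embedDim + 4) * embedDim` elements. Spelled out as offsets into `weightsPtr` (a reading aid, not additional API):

```c++
// Layout of the flat weights buffer, in elements from weightsPtr:
//
//   0                             : queryWt  (embedDim x embedDim)
//   1 * embedDim * embedDim       : keyWt    (embedDim x embedDim)
//   2 * embedDim * embedDim       : valueWt  (embedDim x embedDim)
//   3 * embedDim * embedDim       : outWt    (embedDim x embedDim)
//   4 * embedDim * embedDim       : qBias    (embedDim x 1)
//   (4 * embedDim + 1) * embedDim : kBias    (embedDim x 1)
//   (4 * embedDim + 2) * embedDim : vBias    (embedDim x 1)
//   (4 * embedDim + 3) * embedDim : outBias  (1 x embedDim)
```
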
```diff
@@ -104,12 +104,12 @@ Forward(const InputType& input, OutputType& output)
   // The shape of q : (embedDim, tgtSeqLen, batchSize).
   // The shape of k : (embedDim, srcSeqLen, batchSize).
   // The shape of v : (embedDim, srcSeqLen, batchSize).
-  const CubeType q(const_cast<InputType&>(input).memptr(),
+  const CubeType q(const_cast<MatType&>(input).memptr(),
       embedDim, tgtSeqLen, batchSize, false, false);
-  const CubeType k(const_cast<InputType&>(input).memptr() +
+  const CubeType k(const_cast<MatType&>(input).memptr() +
       embedDim * tgtSeqLen * batchSize,
       embedDim, srcSeqLen, batchSize, false, false);
-  const CubeType v(const_cast<InputType&>(input).memptr() +
+  const CubeType v(const_cast<MatType&>(input).memptr() +
       embedDim * (tgtSeqLen + srcSeqLen) * batchSize,
       embedDim, srcSeqLen, batchSize, false, false);
```

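The `const_cast` is needed only because the advanced `arma::Cube` constructor takes a non-const pointer; with `copy_aux_mem = false` these cubes are views, not copies. A standalone sketch of the same aliasing trick (a single batch column keeps the packed layout trivially correct):

```c++
#include <armadillo>

int main()
{
  const arma::uword embedDim = 4, tgtSeqLen = 3, srcSeqLen = 2;

  // Packed (q; k; v) input, one batch column, as Forward() receives it.
  arma::mat input(embedDim * (tgtSeqLen + 2 * srcSeqLen), 1,
      arma::fill::randu);

  // Non-copying cube views over consecutive regions of the same memory.
  const arma::cube q(input.memptr(), embedDim, tgtSeqLen, 1, false, false);
  const arma::cube k(input.memptr() + q.n_elem, embedDim, srcSeqLen, 1,
      false, false);
  const arma::cube v(input.memptr() + q.n_elem + k.n_elem, embedDim,
      srcSeqLen, 1, false, false);
}
```
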

```diff
@@ -167,8 +167,8 @@ Forward(const InputType& input, OutputType& output)
 
   for (size_t i = 0; i < numHeads * batchSize; ++i)
   {
-    softmax.Forward(scores.slice(i), softmax.OutputParameter());
-    scores.slice(i) = softmax.OutputParameter();
+    softmax.Forward(scores.slice(i), softmaxOutput);
+    scores.slice(i) = softmaxOutput;
   }
 
   // Calculate the attention output i.e. matrix multiplication of softmax
```

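Writing into the new `softmaxOutput` member replaces the old `softmax.OutputParameter()` scratch space, which the adapted `Layer<MatType>` interface no longer provides. Functionally the loop still softmax-normalizes every score slice; a plain-Armadillo sketch of that normalization (hedged: assuming the column-wise convention of mlpack's `Softmax`):

```c++
#include <armadillo>

// Column-wise softmax of one scores slice, with the usual max-shift for
// numerical stability.
arma::mat SoftmaxColumns(const arma::mat& s)
{
  arma::mat e = s;
  e.each_row() -= arma::max(s, 0);  // subtract each column's maximum
  e = arma::exp(e);
  e.each_row() /= arma::sum(e, 0);  // divide each column by its sum
  return e;
}
```
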
```diff
@@ -188,13 +188,13 @@ Forward(const InputType& input, OutputType& output)
   }
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
-Backward(const InputType& /* input */,
-         const OutputType& gy,
-         OutputType& g)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::
+Backward(const MatType& /* input */,
+         const MatType& gy,
+         MatType& g)
 {
-  typedef typename arma::Cube<typename OutputType::elem_type> CubeType;
+  typedef typename arma::Cube<typename MatType::elem_type> CubeType;
 
   if (gy.n_rows != tgtSeqLen * embedDim)
   {
```

```diff
@@ -208,7 +208,7 @@ Backward(const InputType& /* input */,
   // The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
   // We need not split it into n heads now because this is the part where
   // the outputs were concatenated from n heads.
-  CubeType gyTemp(const_cast<OutputType&>(gy).memptr(), embedDim,
+  CubeType gyTemp(const_cast<MatType&>(gy).memptr(), embedDim,
       tgtSeqLen, batchSize, true, false);
 
   // The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
```

```diff
@@ -278,13 +278,13 @@ Backward(const InputType& /* input */,
   }
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
-Gradient(const InputType& input,
-         const OutputType& error,
-         OutputType& gradient)
+template <typename MatType, typename RegularizerType>
+void MultiheadAttentionType<MatType, RegularizerType>::
+Gradient(const MatType& input,
+         const MatType& error,
+         MatType& gradient)
 {
-  typedef typename arma::Cube<typename InputType::elem_type> CubeType;
+  typedef typename arma::Cube<typename MatType::elem_type> CubeType;
 
   if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
   {
```

```diff
@@ -302,16 +302,16 @@ Gradient(const InputType& input,
   // The shape of gradient : (4 * embedDim * embedDim + 4 * embedDim, 1).
   gradient.set_size(arma::size(weights));
 
-  const CubeType q(const_cast<InputType&>(input).memptr(),
+  const CubeType q(const_cast<MatType&>(input).memptr(),
       embedDim, tgtSeqLen, batchSize, false, false);
-  const CubeType k(const_cast<InputType&>(input).memptr() + q.n_elem,
+  const CubeType k(const_cast<MatType&>(input).memptr() + q.n_elem,
       embedDim, srcSeqLen, batchSize, false, false);
-  const CubeType v(const_cast<InputType&>(input).memptr() + q.n_elem + k.n_elem,
+  const CubeType v(const_cast<MatType&>(input).memptr() + q.n_elem + k.n_elem,
       embedDim, srcSeqLen, batchSize, false, false);
 
   // Reshape the propagated error into a cube.
   // The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
-  CubeType errorTemp(const_cast<OutputType&>(error).memptr(), embedDim,
+  CubeType errorTemp(const_cast<MatType&>(error).memptr(), embedDim,
       tgtSeqLen, batchSize, true, false);
 
   // Gradient wrt. outBias, i.e. dL/d(outBias).
```

```diff
@@ -425,12 +425,12 @@ Gradient(const InputType& input,
   regularizer.Evaluate(weights, gradient);
 }
 
-template <typename InputType, typename OutputType, typename RegularizerType>
+template <typename MatType, typename RegularizerType>
 template <typename Archive>
-void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
+void MultiheadAttentionType<MatType, RegularizerType>::
 serialize(Archive& ar, const uint32_t /* version */)
 {
-  ar(cereal::base_class<Layer<InputType, OutputType>>(this));
+  ar(cereal::base_class<Layer<MatType>>(this));
 
   ar(CEREAL_NVP(tgtSeqLen));
   ar(CEREAL_NVP(srcSeqLen));
```

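With the base class registered as `Layer<MatType>`, a cereal round trip works the usual mlpack way. A hedged sketch (assuming the adapted layer compiles and a binary archive; not part of this commit):

```c++
#include <mlpack.hpp>
#include <sstream>

using namespace mlpack;

int main()
{
  MultiheadAttention attn(10, 10, 64, 8);
  MultiheadAttention restored;

  std::stringstream ss;
  {
    cereal::BinaryOutputArchive oar(ss);
    oar(CEREAL_NVP(attn));   // writes shape members via serialize()
  }
  {
    cereal::BinaryInputArchive iar(ss);
    iar(CEREAL_NVP(restored));
  }
}
```
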
