Skip to content

Commit 0253a86

Browse files
committed
Remove last references to arma::
1 parent 41e93cf commit 0253a86

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

src/mlpack/methods/ann/layer/repeat.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,11 @@ class RepeatType : public Layer<MatType>
135135

136136
// Cache the target indices for a single tensor for use
137137
// in the forward pass.
138-
arma::uvec outIdxs;
138+
std::vector<size_t> outIdxs;
139139

140140
// Cache the contributions of each output element to the
141141
// input elements for use in the backward pass.
142142
size_t sizeMult;
143-
arma::umat backIdxs;
144143
}; // class RepeatType.
145144

146145
// Standard Repeat layer.

src/mlpack/methods/ann/layer/repeat_impl.hpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ RepeatType<MatType>::RepeatType(const RepeatType& other) :
4040
multiples(other.multiples),
4141
interleave(other.interleave),
4242
outIdxs(other.outIdxs),
43-
sizeMult(other.sizeMult),
44-
backIdxs(other.backIdxs)
43+
sizeMult(other.sizeMult)
4544
{
4645
// Nothing else to do.
4746
}
@@ -52,8 +51,7 @@ RepeatType<MatType>::RepeatType(RepeatType&& other) noexcept :
5251
multiples(other.multiples),
5352
interleave(other.interleave),
5453
outIdxs(other.outIdxs),
55-
sizeMult(other.sizeMult),
56-
backIdxs(other.backIdxs)
54+
sizeMult(other.sizeMult)
5755
{
5856
// Nothing else to do.
5957
}
@@ -68,7 +66,6 @@ RepeatType<MatType>& RepeatType<MatType>::operator=(const RepeatType& other)
6866
interleave = other.interleave;
6967
outIdxs = other.outIdxs;
7068
sizeMult = other.sizeMult;
71-
backIdxs = other.backIdxs;
7269
}
7370

7471
return *this;
@@ -84,7 +81,6 @@ RepeatType<MatType>& RepeatType<MatType>::operator=(RepeatType&& other) noexcept
8481
interleave = std::move(other.interleave);
8582
outIdxs = std::move(other.outIdxs);
8683
sizeMult = other.sizeMult;
87-
backIdxs = other.backIdxs;
8884
}
8985

9086
return *this;
@@ -106,7 +102,13 @@ void RepeatType<MatType>::ComputeOutputDimensions()
106102
{
107103
inputSize *= this->inputDimensions[i];
108104
}
109-
arma::umat idxs = arma::regspace<arma::uvec>(0, inputSize - 1);
105+
MatType idxs(inputSize, 1);
106+
for (size_t i=0; i<inputSize; i++)
107+
{
108+
idxs.at(i) = i;
109+
}
110+
// want to do this, but can't leave it without a namespace (e.g. arma::)
111+
// MatType idxs = linspace<MatType>(0, inputSize - 1, inputSize);
110112

111113
// Here, we are going to pre-compute the source index for each output
112114
// for a single tensor. Since the tensors are flattened into 1-d
@@ -150,24 +152,24 @@ void RepeatType<MatType>::ComputeOutputDimensions()
150152
}
151153
outSize *= this->outputDimensions[i];
152154
}
153-
outIdxs = idxs.as_col();
154-
155-
// Now, we are going to pre-compute the contribution of each output
156-
// element to the input elements. This will be used in the backward
157-
// pass with a simple matrix multiplication.
158-
backIdxs.set_size(inputSize, sizeMult);
159-
arma::uvec counts(inputSize, arma::fill::zeros);
160-
for (size_t i = 0; i < outIdxs.n_elem; i++)
155+
outIdxs.resize(idxs.n_elem);
156+
for (size_t i=0; i<idxs.n_elem; i++)
161157
{
162-
auto r = outIdxs.at(i);
163-
backIdxs.at(r, counts.at(r)++) = i;
158+
outIdxs[i] = idxs.at(i);
164159
}
165160
}
166161

167162
template<typename MatType>
168163
void RepeatType<MatType>::Forward(const MatType& input, MatType& output)
169164
{
170-
output = input.rows(outIdxs);
165+
#pragma omp parallel for
166+
for (size_t j=0; j<input.n_cols; j++)
167+
{
168+
for (size_t i = 0; i < outIdxs.size(); i++)
169+
{
170+
output.at(i, j) = input.at(outIdxs.at(i), j);
171+
}
172+
}
171173
}
172174

173175
template<typename MatType>
@@ -177,12 +179,15 @@ void RepeatType<MatType>::Backward(
177179
const MatType& gy,
178180
MatType& g)
179181
{
180-
g = gy.rows(backIdxs.col(0));
181-
for (size_t c = 1; c < sizeMult; c++)
182+
#pragma omp parallel for
183+
for (size_t j=0; j<gy.n_cols; j++)
182184
{
183-
g += gy.rows(backIdxs.col(c));
185+
g.col(j).zeros();
186+
for (size_t i=0; i<outIdxs.size(); i++) {
187+
g.at(outIdxs.at(i), j) += gy.at(i, j);
188+
}
189+
g.col(j) /= sizeMult;
184190
}
185-
g /= sizeMult;
186191
}
187192

188193
template<typename MatType>
@@ -196,7 +201,6 @@ void RepeatType<MatType>::serialize(
196201
ar(CEREAL_NVP(interleave));
197202
ar(CEREAL_NVP(outIdxs));
198203
ar(CEREAL_NVP(sizeMult));
199-
ar(CEREAL_NVP(backIdxs));
200204
}
201205

202206
} // namespace mlpack

0 commit comments

Comments
 (0)