@@ -40,8 +40,7 @@ RepeatType<MatType>::RepeatType(const RepeatType& other) :
4040 multiples (other.multiples),
4141 interleave(other.interleave),
4242 outIdxs(other.outIdxs),
43- sizeMult(other.sizeMult),
44- backIdxs(other.backIdxs)
43+ sizeMult(other.sizeMult)
4544{
4645 // Nothing else to do.
4746}
@@ -52,8 +51,7 @@ RepeatType<MatType>::RepeatType(RepeatType&& other) noexcept :
5251 multiples (other.multiples),
5352 interleave(other.interleave),
5453 outIdxs(other.outIdxs),
55- sizeMult(other.sizeMult),
56- backIdxs(other.backIdxs)
54+ sizeMult(other.sizeMult)
5755{
5856 // Nothing else to do.
5957}
@@ -68,7 +66,6 @@ RepeatType<MatType>& RepeatType<MatType>::operator=(const RepeatType& other)
6866 interleave = other.interleave ;
6967 outIdxs = other.outIdxs ;
7068 sizeMult = other.sizeMult ;
71- backIdxs = other.backIdxs ;
7269 }
7370
7471 return *this ;
@@ -84,7 +81,6 @@ RepeatType<MatType>& RepeatType<MatType>::operator=(RepeatType&& other) noexcept
8481 interleave = std::move (other.interleave );
8582 outIdxs = std::move (other.outIdxs );
8683 sizeMult = other.sizeMult ;
87- backIdxs = other.backIdxs ;
8884 }
8985
9086 return *this ;
@@ -106,7 +102,13 @@ void RepeatType<MatType>::ComputeOutputDimensions()
106102 {
107103 inputSize *= this ->inputDimensions [i];
108104 }
109- arma::umat idxs = arma::regspace<arma::uvec>(0 , inputSize - 1 );
105+ MatType idxs (inputSize, 1 );
106+ for (size_t i=0 ; i<inputSize; i++)
107+ {
108+ idxs.at (i) = i;
109+ }
110+ // Ideally the loop above would be the one-liner below, but linspace()
111+ // cannot be called without a namespace qualifier (e.g. arma::):
111+ // MatType idxs = linspace<MatType>(0, inputSize - 1, inputSize);
110112
111113 // Here, we are going to pre-compute the source index for each output
112114 // for a single tensor. Since the tensors are flattened into 1-d
@@ -150,24 +152,24 @@ void RepeatType<MatType>::ComputeOutputDimensions()
150152 }
151153 outSize *= this ->outputDimensions [i];
152154 }
153- outIdxs = idxs.as_col ();
154-
155- // Now, we are going to pre-compute the contribution of each output
156- // element to the input elements. This will be used in the backward
157- // pass with a simple matrix multiplication.
158- backIdxs.set_size (inputSize, sizeMult);
159- arma::uvec counts (inputSize, arma::fill::zeros);
160- for (size_t i = 0 ; i < outIdxs.n_elem ; i++)
155+ outIdxs.resize (idxs.n_elem );
156+ for (size_t i=0 ; i<idxs.n_elem ; i++)
161157 {
162- auto r = outIdxs.at (i);
163- backIdxs.at (r, counts.at (r)++) = i;
158+ outIdxs[i] = idxs.at (i);
164159 }
165160}
166161
167162template <typename MatType>
168163void RepeatType<MatType>::Forward(const MatType& input, MatType& output)
169164{
170- output = input.rows (outIdxs);
165+ #pragma omp parallel for
166+ for (size_t j=0 ; j<input.n_cols ; j++)
167+ {
168+ for (size_t i = 0 ; i < outIdxs.size (); i++)
169+ {
170+ output.at (i, j) = input.at (outIdxs.at (i), j);
171+ }
172+ }
171173}
172174
173175template <typename MatType>
@@ -177,12 +179,15 @@ void RepeatType<MatType>::Backward(
177179 const MatType& gy,
178180 MatType& g)
179181{
180- g = gy. rows (backIdxs. col ( 0 ));
181- for (size_t c = 1 ; c < sizeMult; c ++)
182+ # pragma omp parallel for
183+ for (size_t j= 0 ; j<gy. n_cols ; j ++)
182184 {
183- g += gy.rows (backIdxs.col (c));
185+ g.col (j).zeros ();
186+ for (size_t i=0 ; i<outIdxs.size (); i++) {
187+ g.at (outIdxs.at (i), j) += gy.at (i, j);
188+ }
189+ g.col (j) /= sizeMult;
184190 }
185- g /= sizeMult;
186191}
187192
188193template <typename MatType>
@@ -196,7 +201,6 @@ void RepeatType<MatType>::serialize(
196201 ar (CEREAL_NVP (interleave));
197202 ar (CEREAL_NVP (outIdxs));
198203 ar (CEREAL_NVP (sizeMult));
199- ar (CEREAL_NVP (backIdxs));
200204}
201205
202206} // namespace mlpack
0 commit comments