1515#include " utils.h"
1616#include " vector.h"
1717
18+ #if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
19+ #include < immintrin.h>
20+ #endif
21+
1822namespace fasttext {
1923
2024DenseMatrix::DenseMatrix () : DenseMatrix(0 , 0 ) {}
@@ -146,6 +150,92 @@ void DenseMatrix::addRowToVector(Vector& x, int32_t i, real a) const {
146150 }
147151}
148152
/* Thin wrappers hiding which x86 SIMD extension is in use. The first of
 * AVX512F, AVX, or SSE whose compiler macro is defined is selected; when none
 * is defined, nothing is declared here. */
#if defined(__AVX512F__)
// 512-bit packed single-precision registers.
using Register = __m512;
inline Register Add(Register lhs, Register rhs) {
  return _mm512_add_ps(lhs, rhs);
}
inline Register Set1(float value) {
  return _mm512_set1_ps(value);
}
inline Register Multiply(Register lhs, Register rhs) {
  return _mm512_mul_ps(lhs, rhs);
}
#elif defined(__AVX__)
// 256-bit packed single-precision registers.
using Register = __m256;
inline Register Add(Register lhs, Register rhs) {
  return _mm256_add_ps(lhs, rhs);
}
inline Register Set1(float value) {
  return _mm256_set1_ps(value);
}
inline Register Multiply(Register lhs, Register rhs) {
  return _mm256_mul_ps(lhs, rhs);
}
#elif defined(__SSE__)
// 128-bit packed single-precision registers.
using Register = __m128;
inline Register Add(Register lhs, Register rhs) {
  return _mm_add_ps(lhs, rhs);
}
inline Register Set1(float value) {
  return _mm_set1_ps(value);
}
inline Register Multiply(Register lhs, Register rhs) {
  return _mm_mul_ps(lhs, rhs);
}
#endif
170+
171+ /* Faster routine for averaging rows of a matrix on x86.
172+ * The idea here is to keep the accumulators in registers if possible. */
173+ #if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
174+ template <unsigned Cols> void averageRowsFast (Vector& x, const std::vector<int32_t >& rows, const DenseMatrix &matrix) {
175+ // Columns must be a multiple of how many floats fit in a register.
176+ static_assert (Cols % (sizeof (Register) / 4 ) == 0 );
177+ constexpr unsigned RegisterCount = Cols / (sizeof (Register) / 4 );
178+ // These should be aligned by aligned.h
179+ assert (reinterpret_cast <uintptr_t >(x.data ()) % sizeof (Register) == 0 );
180+ assert (reinterpret_cast <uintptr_t >(matrix.data ()) % sizeof (Register) == 0 );
181+
182+ // Guard against empty list of rows with default NaN behavior.
183+ if (rows.empty ()) {
184+ x.zero ();
185+ x.mul (1.0 / rows.size ());
186+ return ;
187+ }
188+
189+ // Copy the first row to accumulation registers.
190+ Register accum[RegisterCount];
191+ auto row = rows.cbegin ();
192+ const Register *base = reinterpret_cast <const Register*>(matrix.data () + matrix.cols () * *row);
193+ for (unsigned i = 0 ; i < RegisterCount; ++i) {
194+ accum[i] = base[i];
195+ }
196+ // Add the rows after the first.
197+ for (++row; row != rows.cend (); ++row) {
198+ base = reinterpret_cast <const Register*>(matrix.data () + matrix.cols () * *row);
199+ for (unsigned i = 0 ; i < RegisterCount; ++i) {
200+ accum[i] = Add (accum[i], base[i]);
201+ }
202+ }
203+ // Multiply by (1.0 / rows.size()) and write to x.
204+ Register mul = Set1 (1.0 / rows.size ());
205+ for (unsigned i = 0 ; i < RegisterCount; ++i) {
206+ reinterpret_cast <Register*>(x.data ())[i] = Multiply (accum[i], mul);
207+ }
208+ }
209+ #endif
210+
211+ void DenseMatrix::averageRowsToVector (Vector& x, const std::vector<int32_t >& rows) const {
212+ #if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
213+ switch (cols ()) {
214+ case 512 :
215+ // Maximum number that can fit all in registers on AVX512F.
216+ averageRowsFast<512 >(x, rows, *this );
217+ return ;
218+ case 256 :
219+ averageRowsFast<256 >(x, rows, *this );
220+ return ;
221+ case 64 :
222+ averageRowsFast<64 >(x, rows, *this );
223+ return ;
224+ case 32 :
225+ averageRowsFast<32 >(x, rows, *this );
226+ return ;
227+ case 16 :
228+ averageRowsFast<16 >(x, rows, *this );
229+ return ;
230+ }
231+ #endif
232+ x.zero ();
233+ for (auto it = rows.cbegin (); it != rows.cend (); ++it) {
234+ addRowToVector (x, *it);
235+ }
236+ x.mul (1.0 / rows.size ());
237+ }
238+
149239void DenseMatrix::save (std::ostream& out) const {
150240 out.write ((char *)&m_, sizeof (int64_t ));
151241 out.write ((char *)&n_, sizeof (int64_t ));
@@ -155,7 +245,7 @@ void DenseMatrix::save(std::ostream& out) const {
155245void DenseMatrix::load (std::istream& in) {
156246 in.read ((char *)&m_, sizeof (int64_t ));
157247 in.read ((char *)&n_, sizeof (int64_t ));
158- data_ = std::vector <real>(m_ * n_);
248+ data_ = intgemm::AlignedVector <real>(m_ * n_);
159249 in.read ((char *)data_.data (), m_ * n_ * sizeof (real));
160250}
161251
0 commit comments