Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit b733943

Browse files
kpu authored and facebook-github-bot committed
Predict 1.9-4.2x faster (#1341)
Summary: I made prediction 1.9x to 4.2x faster than before. # Motivation I want to use https://tinyurl.com/nllblid218e and similarly parametrized models to run language classification on petabytes of web data. # Methodology The costliest operation is summing the rows for each model input. I've optimized this in three ways: 1. `addRowToVector` was a virtual function call for each row. I've replaced this with one virtual function call per prediction by adding `averageRowsToVector` to the `Matrix` classes. 2. `Vector` and `DenseMatrix` were not 64-byte aligned, so the CPU was doing a lot of unaligned memory access. I've brought in my own `vector` replacement that does 64-byte alignment. 3. I wrote `averageRowsToVector` with intrinsics for common vector sizes. This works on SSE, AVX, and AVX512F. See the commit history for a breakdown of the speed improvement from each change. # Experiments Test set: [docs1000.txt.gz](https://github.com/facebookresearch/fastText/files/11832996/docs1000.txt.gz), which is a bunch of random documents from https://data.statmt.org/heafield/classified-fasttext/. CPU: AMD Ryzen 9 7950X 16-Core. Model https://tinyurl.com/nllblid218e with 256-dimensional vectors: Before: real 0m8.757s user 0m8.434s sys 0m0.327s. After: real 0m2.046s user 0m1.717s sys 0m0.334s. Model https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin with 16-dimensional vectors: Before: real 0m0.926s user 0m0.889s sys 0m0.037s. After: real 0m0.477s user 0m0.436s sys 0m0.040s. Pull Request resolved: #1341 Reviewed By: graemenail Differential Revision: D52134736 Pulled By: kpuatfb fbshipit-source-id: 42067161f4c968c34612934b48a562399a267f3b
1 parent 6c2204b commit b733943

File tree

14 files changed

+283
-32
lines changed

14 files changed

+283
-32
lines changed

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
cmake_minimum_required(VERSION 2.8.9)
1010
project(fasttext)
1111

12+
set(CMAKE_CXX_STANDARD 17)
13+
1214
# The version number.
1315
set (fasttext_VERSION_MAJOR 0)
1416
set (fasttext_VERSION_MINOR 1)
1517

1618
include_directories(fasttext)
1719

18-
set(CMAKE_CXX_FLAGS " -pthread -std=c++11 -funroll-loops -O3 -march=native")
20+
set(CMAKE_CXX_FLAGS " -pthread -std=c++17 -funroll-loops -O3 -march=native")
1921

2022
set(HEADER_FILES
2123
src/args.h

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#
88

99
CXX = c++
10-
CXXFLAGS = -pthread -std=c++11 -march=native
10+
CXXFLAGS = -pthread -std=c++17 -march=native
1111
OBJS = args.o autotune.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
1212
INCLUDES = -I.
1313

setup.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,14 @@ def has_flag(compiler, flags):
9898

9999

100100
def cpp_flag(compiler):
101-
"""Return the -std=c++[11/14] compiler flag.
102-
The c++14 is preferred over c++11 (when it is available).
101+
"""Return the -std=c++17 compiler flag.
103102
"""
104-
standards = ['-std=c++11']
103+
standards = ['-std=c++17']
105104
for standard in standards:
106105
if has_flag(compiler, [standard]):
107106
return standard
108107
raise RuntimeError(
109-
'Unsupported compiler -- at least C++11 support '
108+
'Unsupported compiler -- at least C++17 support '
110109
'is needed!'
111110
)
112111

src/aligned.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#pragma once
2+
#include <algorithm>
#include <cstdlib>
3+
#include <new>
4+
#ifdef _MSC_VER
5+
// Ensure _HAS_EXCEPTIONS is defined
6+
#include <vcruntime.h>
7+
#include <malloc.h>
8+
#endif
9+
10+
#if !((defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS))
11+
#include <cstdlib>
12+
#endif
13+
14+
// Aligned simple vector.
15+
16+
namespace intgemm {

/**
 * Move-only, fixed-size buffer of T whose storage is obtained from an
 * aligned allocator (64-byte alignment by default).
 *
 * The storage is raw memory: this class never runs constructors or
 * destructors of T and never initializes the elements, so it is meant for
 * plain arithmetic element types such as float.
 *
 * On allocation failure the class throws std::bad_alloc when exceptions are
 * enabled and calls std::abort() otherwise.
 */
template <class T> class AlignedVector {
  public:
    /** Creates an empty vector that owns no memory. */
    AlignedVector() noexcept : mem_(nullptr), size_(0) {}

    /**
     * Allocates uninitialized storage for `size` elements.
     *
     * @param size       number of elements of T.
     * @param alignment  requested alignment in bytes, forwarded unchanged to
     *                   the platform's aligned allocator.
     */
    explicit AlignedVector(std::size_t size, std::size_t alignment = 64 /* CPU cares about this */)
      : mem_(nullptr), size_(size) {
      // Refuse element counts whose byte size would wrap around std::size_t;
      // otherwise a tiny buffer would be allocated behind a huge size().
      if (size > static_cast<std::size_t>(-1) / sizeof(T)) {
        allocationFailed();
      }
#ifdef _MSC_VER
      mem_ = static_cast<T*>(_aligned_malloc(size * sizeof(T), alignment));
      if (!mem_) {
        allocationFailed();
      }
#else
      // Receive the pointer through a real void* instead of type-punning &mem_.
      void *raw = nullptr;
      if (posix_memalign(&raw, alignment, size * sizeof(T))) {
        allocationFailed();
      }
      mem_ = static_cast<T*>(raw);
#endif
    }

    /** Allocates `last - first` elements and copies [first, last) into them. */
    template <class InputIt> AlignedVector(InputIt first, InputIt last)
      : AlignedVector(last - first) {
      std::copy(first, last, begin());
    }

    /** Takes over the buffer of `from`, leaving it empty. */
    AlignedVector(AlignedVector &&from) noexcept : mem_(from.mem_), size_(from.size_) {
      from.mem_ = nullptr;
      from.size_ = 0;
    }

    /** Frees the current buffer and takes over the buffer of `from`. */
    AlignedVector &operator=(AlignedVector &&from) noexcept {
      if (this == &from) return *this;
      release();
      mem_ = from.mem_;
      size_ = from.size_;
      from.mem_ = nullptr;
      from.size_ = 0;
      return *this;
    }

    // Copying is disabled; the buffer has a single owner.
    AlignedVector(const AlignedVector&) = delete;
    AlignedVector& operator=(const AlignedVector&) = delete;

    ~AlignedVector() { release(); }

    /** Number of elements (not bytes). */
    std::size_t size() const { return size_; }

    // Unchecked element access.
    T &operator[](std::size_t offset) { return mem_[offset]; }
    const T &operator[](std::size_t offset) const { return mem_[offset]; }

    T *begin() { return mem_; }
    const T *begin() const { return mem_; }
    T *end() { return mem_ + size_; }
    const T *end() const { return mem_ + size_; }

    T *data() { return mem_; }
    const T *data() const { return mem_; }

    /** Reinterprets the start of the buffer as a pointer to ReturnType. */
    template <typename ReturnType>
    ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); }

  private:
    T *mem_;
    std::size_t size_;

    /** Single place that reports an allocation failure; never returns. */
    [[noreturn]] static void allocationFailed() {
#if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS)
      throw std::bad_alloc();
#else
      std::abort();
#endif
    }

    /** Returns the buffer to the allocator it came from; mem_ may be null. */
    void release() {
#ifdef _MSC_VER
      _aligned_free(mem_);
#else
      std::free(mem_);
#endif
    }
};

} // namespace intgemm

src/densematrix.cc

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
#include "utils.h"
1616
#include "vector.h"
1717

18+
#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
19+
#include <immintrin.h>
20+
#endif
21+
1822
namespace fasttext {
1923

2024
DenseMatrix::DenseMatrix() : DenseMatrix(0, 0) {}
@@ -146,6 +150,92 @@ void DenseMatrix::addRowToVector(Vector& x, int32_t i, real a) const {
146150
}
147151
}
148152

153+
/* Uniform names for the widest float SIMD type enabled at compile time.
 * Exactly one branch is selected, in order of preference AVX512F, AVX, SSE;
 * when none of the macros is defined nothing is declared here. */
#if defined(__AVX512F__)
using Register = __m512;
inline Register Add(Register lhs, Register rhs) {
  return _mm512_add_ps(lhs, rhs);
}
inline Register Set1(float value) {
  return _mm512_set1_ps(value);
}
inline Register Multiply(Register lhs, Register rhs) {
  return _mm512_mul_ps(lhs, rhs);
}
#elif defined(__AVX__)
using Register = __m256;
inline Register Add(Register lhs, Register rhs) {
  return _mm256_add_ps(lhs, rhs);
}
inline Register Set1(float value) {
  return _mm256_set1_ps(value);
}
inline Register Multiply(Register lhs, Register rhs) {
  return _mm256_mul_ps(lhs, rhs);
}
#elif defined(__SSE__)
using Register = __m128;
inline Register Add(Register lhs, Register rhs) {
  return _mm_add_ps(lhs, rhs);
}
inline Register Set1(float value) {
  return _mm_set1_ps(value);
}
inline Register Multiply(Register lhs, Register rhs) {
  return _mm_mul_ps(lhs, rhs);
}
#endif
170+
171+
/* Faster routine for averaging rows of a matrix on x86.
172+
* The idea here is to keep the accumulators in registers if possible. */
173+
#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
174+
template <unsigned Cols> void averageRowsFast(Vector& x, const std::vector<int32_t>& rows, const DenseMatrix &matrix) {
175+
// Columns must be a multiple of how many floats fit in a register.
176+
static_assert(Cols % (sizeof(Register) / 4) == 0);
177+
constexpr unsigned RegisterCount = Cols / (sizeof(Register) / 4);
178+
// These should be aligned by aligned.h
179+
assert(reinterpret_cast<uintptr_t>(x.data()) % sizeof(Register) == 0);
180+
assert(reinterpret_cast<uintptr_t>(matrix.data()) % sizeof(Register) == 0);
181+
182+
// Guard against empty list of rows with default NaN behavior.
183+
if (rows.empty()) {
184+
x.zero();
185+
x.mul(1.0 / rows.size());
186+
return;
187+
}
188+
189+
// Copy the first row to accumulation registers.
190+
Register accum[RegisterCount];
191+
auto row = rows.cbegin();
192+
const Register *base = reinterpret_cast<const Register*>(matrix.data() + matrix.cols() * *row);
193+
for (unsigned i = 0; i < RegisterCount; ++i) {
194+
accum[i] = base[i];
195+
}
196+
// Add the rows after the first.
197+
for (++row; row != rows.cend(); ++row) {
198+
base = reinterpret_cast<const Register*>(matrix.data() + matrix.cols() * *row);
199+
for (unsigned i = 0; i < RegisterCount; ++i) {
200+
accum[i] = Add(accum[i], base[i]);
201+
}
202+
}
203+
// Multiply by (1.0 / rows.size()) and write to x.
204+
Register mul = Set1(1.0 / rows.size());
205+
for (unsigned i = 0; i < RegisterCount; ++i) {
206+
reinterpret_cast<Register*>(x.data())[i] = Multiply(accum[i], mul);
207+
}
208+
}
209+
#endif
210+
211+
// Writes into x the average of the matrix rows listed in `rows`.
// On x86 builds with SSE/AVX/AVX512F, the column counts 512, 256, 64, 32 and
// 16 are dispatched to the fixed-width averageRowsFast<> instantiations;
// every other width (and every non-x86 build) takes the generic path below:
// zero x, add each row, then scale by 1.0 / rows.size().
void DenseMatrix::averageRowsToVector(Vector& x, const std::vector<int32_t>& rows) const {
#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
  switch (cols()) {
    // 512 is the maximum number that can fit all in registers on AVX512F.
    case 512: return averageRowsFast<512>(x, rows, *this);
    case 256: return averageRowsFast<256>(x, rows, *this);
    case 64:  return averageRowsFast<64>(x, rows, *this);
    case 32:  return averageRowsFast<32>(x, rows, *this);
    case 16:  return averageRowsFast<16>(x, rows, *this);
  }
#endif
  x.zero();
  for (const int32_t row : rows) {
    addRowToVector(x, row);
  }
  x.mul(1.0 / rows.size());
}
238+
149239
void DenseMatrix::save(std::ostream& out) const {
150240
out.write((char*)&m_, sizeof(int64_t));
151241
out.write((char*)&n_, sizeof(int64_t));
@@ -155,7 +245,7 @@ void DenseMatrix::save(std::ostream& out) const {
155245
void DenseMatrix::load(std::istream& in) {
156246
in.read((char*)&m_, sizeof(int64_t));
157247
in.read((char*)&n_, sizeof(int64_t));
158-
data_ = std::vector<real>(m_ * n_);
248+
data_ = intgemm::AlignedVector<real>(m_ * n_);
159249
in.read((char*)data_.data(), m_ * n_ * sizeof(real));
160250
}
161251

src/densematrix.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <stdexcept>
1616
#include <vector>
1717

18+
#include "aligned.h"
1819
#include "matrix.h"
1920
#include "real.h"
2021

@@ -24,7 +25,7 @@ class Vector;
2425

2526
class DenseMatrix : public Matrix {
2627
protected:
27-
std::vector<real> data_;
28+
intgemm::AlignedVector<real> data_;
2829
void uniformThread(real, int, int32_t);
2930

3031
public:
@@ -71,6 +72,7 @@ class DenseMatrix : public Matrix {
7172
void addVectorToRow(const Vector&, int64_t, real) override;
7273
void addRowToVector(Vector& x, int32_t i) const override;
7374
void addRowToVector(Vector& x, int32_t i, real a) const override;
75+
void averageRowsToVector(Vector& x, const std::vector<int32_t>& rows) const override;
7476
void save(std::ostream&) const override;
7577
void load(std::istream&) override;
7678
void dump(std::ostream&) const override;

src/dictionary.cc

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ Dictionary::Dictionary(std::shared_ptr<Args> args, std::istream& in)
4242
load(in);
4343
}
4444

45-
int32_t Dictionary::find(const std::string& w) const {
45+
int32_t Dictionary::find(const std::string_view w) const {
4646
return find(w, hash(w));
4747
}
4848

49-
int32_t Dictionary::find(const std::string& w, uint32_t h) const {
49+
int32_t Dictionary::find(const std::string_view w, uint32_t h) const {
5050
int32_t word2intsize = word2int_.size();
5151
int32_t id = h % word2intsize;
5252
while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
@@ -126,12 +126,12 @@ bool Dictionary::discard(int32_t id, real rand) const {
126126
return rand > pdiscard_[id];
127127
}
128128

129-
int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
129+
int32_t Dictionary::getId(const std::string_view w, uint32_t h) const {
130130
int32_t id = find(w, h);
131131
return word2int_[id];
132132
}
133133

134-
int32_t Dictionary::getId(const std::string& w) const {
134+
int32_t Dictionary::getId(const std::string_view w) const {
135135
int32_t h = find(w);
136136
return word2int_[h];
137137
}
@@ -142,7 +142,7 @@ entry_type Dictionary::getType(int32_t id) const {
142142
return words_[id].type;
143143
}
144144

145-
entry_type Dictionary::getType(const std::string& w) const {
145+
entry_type Dictionary::getType(const std::string_view w) const {
146146
return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;
147147
}
148148

@@ -160,7 +160,7 @@ std::string Dictionary::getWord(int32_t id) const {
160160
// Since all fasttext models that were already released were trained
161161
// using signed char, we fixed the hash function to make models
162162
// compatible whatever compiler is used.
163-
uint32_t Dictionary::hash(const std::string& str) const {
163+
uint32_t Dictionary::hash(const std::string_view str) const {
164164
uint32_t h = 2166136261;
165165
for (size_t i = 0; i < str.size(); i++) {
166166
h = h ^ uint32_t(int8_t(str[i]));
@@ -324,11 +324,16 @@ void Dictionary::addWordNgrams(
324324

325325
void Dictionary::addSubwords(
326326
std::vector<int32_t>& line,
327-
const std::string& token,
327+
const std::string_view token,
328328
int32_t wid) const {
329329
if (wid < 0) { // out of vocab
330330
if (token != EOS) {
331-
computeSubwords(BOW + token + EOW, line);
331+
std::string concat;
332+
concat.reserve(BOW.size() + token.size() + EOW.size());
333+
concat += BOW;
334+
concat.append(token.data(), token.size());
335+
concat += EOW;
336+
computeSubwords(concat, line);
332337
}
333338
} else {
334339
if (args_->maxn <= 0) { // in vocab w/o subwords
@@ -406,6 +411,51 @@ int32_t Dictionary::getLine(
406411
return ntokens;
407412
}
408413

414+
namespace {
// Extracts the next separator-delimited token from `in`.
//
// Separators are space, \n, \r, \t, \v, \f and the NUL byte. Leading
// separators are skipped, `word` is set to the token that follows, and `in`
// is advanced to just past that token.
//
// Returns true when a token was found. Returns false when only separators
// (or nothing) remain; `in` is then emptied and `word` is left untouched.
bool readWordNoNewline(std::string_view& in, std::string_view& word) {
  // The length is given explicitly: building a string_view from a bare
  // const char* stops at the first NUL and would silently drop the '\0'
  // separator at the end of this list.
  static constexpr char kSeparators[] = " \n\r\t\v\f\0";
  const std::string_view spaces(kSeparators, sizeof(kSeparators) - 1);

  const std::string_view::size_type begin = in.find_first_not_of(spaces);
  if (begin == std::string_view::npos) {
    in.remove_prefix(in.size());
    return false;
  }
  in.remove_prefix(begin);
  // find_first_of yields npos when no separator follows, in which case
  // substr returns the whole remainder.
  word = in.substr(0, in.find_first_of(spaces));
  in.remove_prefix(word.size());
  return true;
}
} // namespace
428+
429+
int32_t Dictionary::getStringNoNewline(
430+
std::string_view in,
431+
std::vector<int32_t>& words,
432+
std::vector<int32_t>& labels) const {
433+
std::vector<int32_t> word_hashes;
434+
std::string_view token;
435+
int32_t ntokens = 0;
436+
437+
words.clear();
438+
labels.clear();
439+
while (readWordNoNewline(in, token)) {
440+
uint32_t h = hash(token);
441+
int32_t wid = getId(token, h);
442+
entry_type type = wid < 0 ? getType(token) : getType(wid);
443+
444+
ntokens++;
445+
if (type == entry_type::word) {
446+
addSubwords(words, token, wid);
447+
word_hashes.push_back(h);
448+
} else if (type == entry_type::label && wid >= 0) {
449+
labels.push_back(wid - nwords_);
450+
}
451+
if (token == EOS) {
452+
break;
453+
}
454+
}
455+
addWordNgrams(words, word_hashes, args_->wordNgrams);
456+
return ntokens;
457+
}
458+
409459
void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
410460
if (pruneidx_size_ == 0 || id < 0) {
411461
return;

0 commit comments

Comments
 (0)