Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7db8ce6
Enable Intel®-AMX/oneDNN to accelerate IndexFlat search
guangzegu Feb 27, 2024
b35a0f2
formatted distances.cpp and onednn_utils.h
guangzegu Mar 11, 2024
781f178
Add descriptions of Intel®-AMX/oneDNN optimization to INSTALL.md
guangzegu Mar 19, 2024
2f3fdf9
Add oneDNN/AMX optimization for distance calculation using Blas for I…
guangzegu Mar 19, 2024
a15c5cc
Merge branch 'facebookresearch:main' into main
guangzegu Jun 19, 2024
116fc01
Restructure the AMX integration with faiss
guangzegu Sep 24, 2024
e3ea518
Merge remote-tracking branch 'upstream/main'
guangzegu Oct 14, 2024
9e34323
Refactor and optimize the code structure to support AMX/OneDNN comput…
guangzegu Oct 21, 2024
f556407
Format distances_dnnl.h
guangzegu Oct 22, 2024
ed7b184
Merge branch 'main' into main
guangzegu Oct 24, 2024
78857d9
Merge remote-tracking branch 'upstream/main'
guangzegu Feb 14, 2025
fc447da
Merge branch 'facebookresearch:main' into main
guangzegu Feb 21, 2025
f43b0fa
Merge branch 'facebookresearch:main' into main
guangzegu Mar 13, 2025
e14be69
Merge branch 'facebookresearch:main' into main
guangzegu Mar 18, 2025
eda99c0
Add DNNL compilation flags to support low-precision testing
guangzegu Mar 19, 2025
53fa4ad
Add unit tests for low-precision IndexFlatIP
guangzegu Mar 19, 2025
6dd3a8a
Skip certain high precision tests using DNNL compile option
guangzegu Mar 20, 2025
fd45d23
Merge branch 'facebookresearch:main' into main
guangzegu Mar 20, 2025
3933a75
Merge branch 'main' into main
mnorris11 Oct 30, 2025
36ffdd0
Merge branch 'main' into main
mnorris11 Nov 3, 2025
1fd9826
reformat AMX-accelerated code to match Faiss coding style
guangzegu Nov 4, 2025
2b59c3f
Update onednn_utils.h
guangzegu Nov 4, 2025
394fdeb
Update distances_dnnl.h
guangzegu Nov 4, 2025
033ae6f
Format comments in AMX-accelerated code to comply with Faiss style
guangzegu Nov 4, 2025
413a22b
Reformat AMX-accelerated code comments to match Faiss coding style
guangzegu Nov 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

# Valid values are "generic", "avx2", "avx512", "avx512_spr", "sve".
option(FAISS_OPT_LEVEL "" "generic")
option(FAISS_ENABLE_DNNL "Enable support for onednn to accelerate indexflat search." OFF)
option(FAISS_ENABLE_GPU "Enable support for GPU indexes." ON)
option(FAISS_GPU_STATIC "Link GPU libraries statically." OFF)
option(FAISS_ENABLE_CUVS "Enable cuVS for GPU indexes." OFF)
Expand Down
5 changes: 5 additions & 0 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ The optional requirements are:
- for GPU indices:
- nvcc,
- the CUDA toolkit,
- for Intel®-AMX/oneDNN acceleration:
- oneDNN,
- 4th+ Gen Intel® Xeon® Scalable processor.
- for AMD GPUs:
- AMD ROCm,
- for using NVIDIA cuVS implementations:
Expand Down Expand Up @@ -127,6 +130,8 @@ Several options can be passed to CMake, among which:
- general options:
- `-DFAISS_ENABLE_GPU=OFF` in order to disable building GPU indices (possible
values are `ON` and `OFF`),
- `-DFAISS_ENABLE_DNNL=OFF` in order to disable Intel®-AMX/oneDNN support, which accelerates `IndexFlat` inner-product search (possible
values are `ON` and `OFF`; before setting this option to `ON`, install oneDNN by following [these instructions](https://oneapi-src.github.io/oneDNN/dev_guide_build.html)),
- `-DFAISS_ENABLE_PYTHON=OFF` in order to disable building python bindings
(possible values are `ON` and `OFF`),
- `-DFAISS_ENABLE_CUVS=ON` in order to use the NVIDIA cuVS implementations
Expand Down
4 changes: 4 additions & 0 deletions c_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ endif()
add_executable(example_c EXCLUDE_FROM_ALL example_c.c)
target_link_libraries(example_c PRIVATE faiss_c)

# Define ENABLE_DNNL for the C API sources so the oneDNN block-size
# getters/setters in utils/distances_c.cpp/.h are compiled in.
if(FAISS_ENABLE_DNNL)
add_compile_definitions(ENABLE_DNNL)
endif()

if(FAISS_ENABLE_GPU)
add_subdirectory(gpu)
endif()
18 changes: 18 additions & 0 deletions c_api/utils/distances_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,21 @@ void faiss_set_distance_compute_min_k_reservoir(int value) {
int faiss_get_distance_compute_min_k_reservoir() {
return faiss::distance_compute_min_k_reservoir;
}

#ifdef ENABLE_DNNL
/// C wrapper: sets the query-side block size used by the oneDNN/AMX
/// blocked inner-product path (faiss::distance_compute_dnnl_query_bs).
void faiss_set_distance_compute_dnnl_query_bs(int value) {
    faiss::distance_compute_dnnl_query_bs = value;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for you to move these to cpi/cppcontrib/amx/distances_dnnl_c.h and if not feasible, gate it behind a compilation flag?

/// C wrapper: returns the current query-side block size of the
/// oneDNN/AMX blocked inner-product path.
int faiss_get_distance_compute_dnnl_query_bs() {
    return faiss::distance_compute_dnnl_query_bs;
}

/// C wrapper: sets the database-side block size used by the oneDNN/AMX
/// blocked inner-product path (faiss::distance_compute_dnnl_database_bs).
void faiss_set_distance_compute_dnnl_database_bs(int value) {
    faiss::distance_compute_dnnl_database_bs = value;
}

/// C wrapper: returns the current database-side block size of the
/// oneDNN/AMX blocked inner-product path.
int faiss_get_distance_compute_dnnl_database_bs() {
    return faiss::distance_compute_dnnl_database_bs;
}
#endif
14 changes: 14 additions & 0 deletions c_api/utils/distances_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ void faiss_set_distance_compute_min_k_reservoir(int value);
/// rather than a heap
int faiss_get_distance_compute_min_k_reservoir();

#ifdef ENABLE_DNNL
/// Setter of block sizes value for oneDNN/AMX distance computations
void faiss_set_distance_compute_dnnl_query_bs(int value);

/// Getter of block sizes value for oneDNN/AMX distance computations
int faiss_get_distance_compute_dnnl_query_bs();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for you to move these to cpi/cppcontrib/amx/distances_dnnl_c.h and if not feasible, gate it behind a compilation flag?

/// Setter of block sizes value for oneDNN/AMX distance computations
void faiss_set_distance_compute_dnnl_database_bs(int value);

/// Getter of block sizes value for oneDNN/AMX distance computations
int faiss_get_distance_compute_dnnl_database_bs();
#endif

#ifdef __cplusplus
}
#endif
Expand Down
18 changes: 18 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,16 @@ if(NOT WIN32)
list(APPEND FAISS_HEADERS invlists/OnDiskInvertedLists.h)
endif()

# oneDNN/AMX support: install the AMX headers and define ENABLE_DNNL for
# the faiss targets. (Merged from two consecutive if() blocks that tested
# the same condition.)
if(FAISS_ENABLE_DNNL)
  list(APPEND FAISS_HEADERS cppcontrib/amx/onednn_utils.h)
  list(APPEND FAISS_HEADERS cppcontrib/amx/distances_dnnl.h)
  add_compile_definitions(ENABLE_DNNL)
endif()


# Export FAISS_HEADERS variable to parent scope.
set(FAISS_HEADERS ${FAISS_HEADERS} PARENT_SCOPE)

Expand Down Expand Up @@ -388,6 +398,14 @@ if(FAISS_USE_LTO)
endif()
endif()

# Link librt and libdnnl into every faiss target when DNNL support is on.
# NOTE(review): find_library without REQUIRED silently yields *-NOTFOUND
# when oneDNN (or librt) is missing, deferring the failure to link time —
# consider making these required. Also assumes the faiss_avx2 and
# faiss_avx512 targets exist, which depends on FAISS_OPT_LEVEL — confirm.
if(FAISS_ENABLE_DNNL)
find_library(RT_LIB rt)
find_library(DNNL_LIB dnnl)
target_link_libraries(faiss PRIVATE ${RT_LIB} ${DNNL_LIB})
target_link_libraries(faiss_avx2 PRIVATE ${RT_LIB} ${DNNL_LIB})
target_link_libraries(faiss_avx512 PRIVATE ${RT_LIB} ${DNNL_LIB})
endif()

find_package(OpenMP REQUIRED)
target_link_libraries(faiss PRIVATE OpenMP::OpenMP_CXX)
target_link_libraries(faiss_avx2 PRIVATE OpenMP::OpenMP_CXX)
Expand Down
116 changes: 116 additions & 0 deletions faiss/cppcontrib/amx/distances_dnnl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

/* oneDNN/AMX-accelerated inner-product distance functions for IndexFlat
 * search. Presumably included from utils/distances.cpp when
 * FAISS_ENABLE_DNNL is set — verify against the build.
 * NOTE(review): this header defines FAISS_API globals and has no
 * #pragma once / include guard — confirm it is included from exactly one
 * translation unit, or add a guard.
 */

#include <faiss/cppcontrib/amx/onednn_utils.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/impl/platform_macros.h>
#include <omp.h>

#ifndef FINTEGER
#define FINTEGER long
#endif

namespace faiss {

// block sizes for oneDNN/AMX distance computations
FAISS_API int distance_compute_dnnl_query_bs = 10240;
FAISS_API int distance_compute_dnnl_database_bs = 10240;

/**
 * Find the nearest neighbors for nx queries in a set of ny vectors,
 * accelerated via oneDNN/AMX.
 *
 * Computes the full nx*ny inner-product matrix in one oneDNN call, then
 * feeds each row to a per-thread SingleResultHandler.
 *
 * @param x   query vectors, size nx * d
 * @param y   database vectors, size ny * d
 * @param d   dimensionality
 * @param nx  number of queries
 * @param ny  number of database vectors
 * @param res result handler collecting the best matches
 */
template <class BlockResultHandler>
void exhaustive_inner_product_seq_dnnl(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        BlockResultHandler& res) {
    using SingleResultHandler =
            typename BlockResultHandler::SingleResultHandler;

    // Nothing to do; also avoids num_threads(0) below, which is invalid
    // (the OpenMP num_threads clause requires a positive value).
    if (nx == 0 || ny == 0) {
        return;
    }

    // Don't spawn more threads than there are queries. Clamp to >= 1 so
    // the num_threads clause always receives a positive value.
    int nt = std::max(1, std::min(int(nx), omp_get_max_threads()));

    // Full result matrix; note this allocates nx * ny floats, which can
    // be large — callers are expected to block the inputs (see the blas
    // variant) for big datasets.
    std::unique_ptr<float[]> res_arr(new float[nx * ny]);

    comput_f32bf16f32_inner_product(
            nx,
            d,
            ny,
            d,
            const_cast<float*>(x),
            const_cast<float*>(y),
            res_arr.get());

#pragma omp parallel num_threads(nt)
    {
        SingleResultHandler resi(res);
#pragma omp for
        for (size_t i = 0; i < nx; i++) {
            resi.begin(i);
            for (size_t j = 0; j < ny; j++) {
                float ip = res_arr[i * ny + j];
                resi.add_result(ip, j);
            }
            resi.end();
        }
    }
}

/**
 * Find the nearest neighbors for nx queries in a set of ny vectors,
 * accelerated via oneDNN/AMX.
 *
 * Blocked variant: processes the problem in bs_x * bs_y tiles so the
 * intermediate inner-product buffer stays bounded regardless of nx/ny.
 *
 * @param x   query vectors, size nx * d
 * @param y   database vectors, size ny * d
 * @param d   dimensionality
 * @param nx  number of queries
 * @param ny  number of database vectors
 * @param res result handler collecting the best matches
 */
template <class BlockResultHandler>
void exhaustive_inner_product_blas_dnnl(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        BlockResultHandler& res) {
    /* block sizes */
    // NOTE(review): these come from mutable int globals; a value <= 0
    // would make the loops below never advance — confirm callers cannot
    // set them to non-positive values.
    const size_t bs_x = distance_compute_dnnl_query_bs;
    const size_t bs_y = distance_compute_dnnl_database_bs;
    // One tile of inner products, reused across all (i0, j0) blocks.
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        // [i0, i1) is the current query block, clipped at nx.
        size_t i1 = i0 + bs_x;
        if (i1 > nx)
            i1 = nx;

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            // [j0, j1) is the current database block, clipped at ny.
            size_t j1 = j0 + bs_y;
            if (j1 > ny)
                j1 = ny;
            /* compute the actual dot products */
            FINTEGER nyi = j1 - j0, nxi = i1 - i0;
            comput_f32bf16f32_inner_product(
                    nxi,
                    d,
                    nyi,
                    d,
                    const_cast<float*>(x + i0 * d),
                    const_cast<float*>(y + j0 * d),
                    ip_block.get());

            // Partial results for queries [i0, i1) against db [j0, j1).
            res.add_results(j0, j1, ip_block.get());
        }
        res.end_multiple();
        // Allow user interruption between query blocks.
        InterruptCallback::check();
    }
}

} // namespace faiss
142 changes: 142 additions & 0 deletions faiss/cppcontrib/amx/onednn_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

/* Helpers wrapping oneDNN so float32 inner products can be computed with
 * bf16 intermediate precision, enabling AMX acceleration on supported
 * Intel Xeon processors.
 */

#pragma once
#include <stdlib.h>
#include <mutex>
#include <shared_mutex>
#include "oneapi/dnnl/dnnl.hpp"

namespace faiss {

static dnnl::engine cpu_engine;
static dnnl::stream engine_stream;
static bool is_onednn_init = false;
static std::mutex init_mutex;

// Returns true if the CPU reports AMX-BF16: CPUID leaf 7, sub-leaf 0,
// EDX bit 22.
// NOTE(review): does not first check that leaf 7 is supported (max basic
// leaf via CPUID leaf 0), and this function is never called in this
// header — confirm whether a runtime fallback for non-AMX CPUs was
// intended.
static bool is_amxbf16_supported() {
    unsigned int eax, ebx, ecx, edx;
    __asm__ __volatile__("cpuid"
                         : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
                         : "a"(7), "c"(0));
    return edx & (1 << 22);
}

// Creates the process-wide oneDNN CPU engine and stream exactly once.
// The mutex + flag guard makes concurrent callers safe; repeat calls are
// no-ops.
static void init_onednn() {
    std::unique_lock<std::mutex> lock(init_mutex);

    if (is_onednn_init) {
        return;
    }

    // init dnnl engine
    cpu_engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
    engine_stream = dnnl::stream(cpu_engine);

    is_onednn_init = true;
}

// Runs automatically when the shared library is loaded (GCC/Clang
// constructor attribute), so the oneDNN engine is initialized before the
// first distance computation.
__attribute__((constructor)) static void library_load() {
    init_onednn();
}

/**
 * @brief Compute a float32 matrix inner product (X * Y^T) with bf16
 * intermediate precision so oneDNN can dispatch to AMX kernels.
 * @details The main idea is:
 * 1. Define float32 memory layouts for input and output
 * 2. Create low precision bf16 memory descriptors as inner product input
 * 3. Generate an inner product primitive descriptor
 * 4. Execute float32 => (reorder) => bf16 => (inner product) => float32,
 *    isolating the reduced-precision data on the accelerated path
 * 5. Wait on the stream so results are complete on return
 *
 * NOTE(review): the function name carries a typo ("comput") — kept for
 * API compatibility with existing callers. Not thread-safe: it shares
 * the global engine_stream — confirm callers serialize invocations.
 *
 * @param xrow Row number of input matrix X
 * @param xcol Column number of input matrix X
 * @param yrow Row number of weight matrix Y
 * @param ycol Column number of weight matrix Y (must equal xcol)
 * @param in_f32_1 Input matrix pointer in float32 type
 * @param in_f32_2 Weight matrix pointer in float32 type
 * @param out_f32 Output matrix pointer (xrow * yrow) in float32 type
 * @return None
 */
static void comput_f32bf16f32_inner_product(
        uint32_t xrow,
        uint32_t xcol,
        uint32_t yrow,
        uint32_t ycol,
        float* in_f32_1,
        float* in_f32_2,
        float* out_f32) {
    // Plain row-major ("ab") f32 descriptors wrapping the caller's buffers.
    dnnl::memory::desc f32_md1 = dnnl::memory::desc(
            {xrow, xcol},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);
    dnnl::memory::desc f32_md2 = dnnl::memory::desc(
            {yrow, ycol},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);
    dnnl::memory::desc f32_dst_md2 = dnnl::memory::desc(
            {xrow, yrow},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);

    dnnl::memory f32_mem1 = dnnl::memory(f32_md1, cpu_engine, in_f32_1);
    dnnl::memory f32_mem2 = dnnl::memory(f32_md2, cpu_engine, in_f32_2);
    dnnl::memory f32_dst_mem = dnnl::memory(f32_dst_md2, cpu_engine, out_f32);

    // Inner bf16 descriptors; format_tag::any lets oneDNN pick the layout
    // preferred by the selected (AMX) implementation.
    dnnl::memory::desc bf16_md1 = dnnl::memory::desc(
            {xrow, xcol},
            dnnl::memory::data_type::bf16,
            dnnl::memory::format_tag::any);
    dnnl::memory::desc bf16_md2 = dnnl::memory::desc(
            {yrow, ycol},
            dnnl::memory::data_type::bf16,
            dnnl::memory::format_tag::any);

    // forward_inference (not forward_training): no backward pass is ever
    // run here, and inference mode lets oneDNN choose the fastest
    // inference-only implementation.
    dnnl::inner_product_forward::primitive_desc inner_product_pd =
            dnnl::inner_product_forward::primitive_desc(
                    cpu_engine,
                    dnnl::prop_kind::forward_inference,
                    bf16_md1,
                    bf16_md2,
                    f32_dst_md2);

    // NOTE(review): the primitive descriptor and primitive are rebuilt on
    // every call; caching per (xrow, xcol, yrow) shape could help — left
    // as-is to keep this change minimal.
    dnnl::inner_product_forward inner_product_prim =
            dnnl::inner_product_forward(inner_product_pd);

    // Reorder f32 inputs into the implementation-chosen bf16 layouts.
    dnnl::memory bf16_mem1 =
            dnnl::memory(inner_product_pd.src_desc(), cpu_engine);
    dnnl::reorder(f32_mem1, bf16_mem1)
            .execute(engine_stream, f32_mem1, bf16_mem1);

    dnnl::memory bf16_mem2 =
            dnnl::memory(inner_product_pd.weights_desc(), cpu_engine);
    dnnl::reorder(f32_mem2, bf16_mem2)
            .execute(engine_stream, f32_mem2, bf16_mem2);

    inner_product_prim.execute(
            engine_stream,
            {{DNNL_ARG_SRC, bf16_mem1},
             {DNNL_ARG_WEIGHTS, bf16_mem2},
             {DNNL_ARG_DST, f32_dst_mem}});

    // Wait for the computation to finalize so out_f32 is valid on return.
    engine_stream.wait();
}

} // namespace faiss
Loading
Loading