Commit 85dadbd

Update on "Move bf16_gemv_trans to ReducedPrecisionFloatGemvFastPathKernel"
Following the previous move of fp16_gemv_trans.

Testing: checked for a performance regression with llm_benchmarks' `python benchmarks/benchmark_torch_mm.py llm` and didn't find one.

Differential Revision: [D64930872](https://our.internmc.facebook.com/intern/diff/D64930872/)

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
2 parents 13b88aa + 6b7adba commit 85dadbd
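
For background, bf16_gemv_trans computes a transposed matrix-vector product over bfloat16 inputs. Below is a minimal scalar sketch of that computation, assuming fp32 accumulation as is typical for reduced-precision fast paths; the names and the row-major layout are illustrative, not taken from the PyTorch kernel.

#include <cstdint>
#include <cstring>

// Widen one bf16 value (stored as uint16_t): bf16 is the top 16 bits of an
// IEEE-754 binary32, so shift into the high half and bit-copy into a float.
static float bf16_to_fp32(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// y = A^T * x for an m-by-n row-major bf16 matrix `a` (leading dimension
// lda), bf16 vector `x` of length m, and fp32 output `y` of length n.
void bf16_gemv_trans_ref(int64_t m, int64_t n, const uint16_t* a, int64_t lda,
                         const uint16_t* x, float* y) {
  for (int64_t j = 0; j < n; ++j) {
    float acc = 0.0f;  // fp32 accumulator limits bf16 rounding error
    for (int64_t i = 0; i < m; ++i) {
      acc += bf16_to_fp32(a[i * lda + j]) * bf16_to_fp32(x[i]);
    }
    y[j] = acc;
  }
}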

File tree

354 files changed: +2652, -2016 lines


.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -131,3 +131,6 @@
 	path = third_party/composable_kernel
 	url = https://github.com/ROCm/composable_kernel.git
 	branch = develop
+[submodule "third_party/x86-simd-sort"]
+	path = third_party/x86-simd-sort
+	url = https://github.com/intel/x86-simd-sort.git

.lintrunner.toml

Lines changed: 5 additions & 6 deletions
@@ -58,6 +58,7 @@ include_patterns = [
     'aten/src/ATen/mps/**/*.mm',
     'aten/src/ATen/xpu/**/*.h',
     'aten/src/ATen/xpu/**/*.cpp',
+    'aten/src/ATen/native/mps/**/*.metal',
     'aten/src/ATen/native/mps/**/*.mm',
     'aten/src/ATen/native/vulkan/**/*.h',
     'aten/src/ATen/native/vulkan/**/*.cpp',
@@ -69,8 +70,6 @@ include_patterns = [
     'aten/src/ATen/native/cudnn/*.cpp',
     'c10/**/*.h',
     'c10/**/*.cpp',
-    'distributed/c10d/*DMAConnectivity.*',
-    'distributed/c10d/*SymmetricMemory.*',
     'torch/csrc/**/*.h',
     'torch/csrc/**/*.hpp',
     'torch/csrc/**/*.cpp',
@@ -79,6 +78,7 @@ include_patterns = [
 ]
 exclude_patterns = [
     'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h',
+    'aten/src/ATen/native/mps/kernels/Quantized.metal',
     'c10/util/strong_type.h',
     '**/fb/**',
     'torch/csrc/inductor/aoti_torch/generated/**',
@@ -224,6 +224,9 @@ exclude_patterns = [
     '**/fb/**',
     '**/generated/**',
     '**/*pb.h',
+    '**/*inl.h',
+    'aten/src/ATen/CPUFixedAllocator.h',
+    'aten/src/ATen/Parallel*.h',
     'c10/xpu/**/*.h',
     'c10/xpu/**/*.cpp',
     'c10/benchmark/intrusive_ptr_benchmark.cpp',
@@ -236,15 +239,12 @@ exclude_patterns = [
     'c10/util/strong_type.h',
     'c10/util/SmallVector.h',
     'c10/util/win32-headers.h',
-    'c10/util/*inl.h',
     'c10/test/**/*.h',
     'third_party/**/*',
     'torch/csrc/api/include/torch/nn/modules/common.h',
     'torch/csrc/api/include/torch/linalg.h',
-    'torch/csrc/api/include/torch/nn/pimpl-inl.h',
     'torch/csrc/autograd/generated/**',
     'torch/csrc/distributed/**/*.cu',
-    'torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h',
     'torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp',
     'torch/csrc/distributed/c10d/WinSockUtils.hpp',
     'torch/csrc/distributed/c10d/quantization/quantization_gpu.h',
@@ -253,7 +253,6 @@ exclude_patterns = [
     'torch/csrc/jit/**/*',
     'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
     'torch/csrc/utils/pythoncapi_compat.h',
-    'torch/csrc/utils/throughput_benchmark-inl.h',
 ]
 init_command = [
     'python3',

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -262,6 +262,7 @@ else()
   cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF)
 endif()
 option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
+option(USE_X86_SIMD_SORT "Use x86-simd-sort to accelerate sorting and topk for AVX2/AVX512" ON)
 option(USE_KINETO "Use Kineto profiling library" ON)
 option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
 option(USE_FAKELOWP "Use FakeLowp operators" OFF)
@@ -903,6 +904,13 @@ if(USE_FBGEMM)
   string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
 endif()
 
+if(USE_X86_SIMD_SORT)
+  string(APPEND CMAKE_CXX_FLAGS " -DUSE_X86_SIMD_SORT")
+  if(USE_XSS_OPENMP)
+    string(APPEND CMAKE_CXX_FLAGS " -DXSS_USE_OPENMP")
+  endif()
+endif()
+
 if(USE_PYTORCH_QNNPACK)
   string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
 endif()
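
The new option only appends a -DUSE_X86_SIMD_SORT preprocessor definition, so call sites would typically guard the vectorized path behind that macro and keep a portable fallback. A hedged sketch of such a guard, assuming the submodule's x86simdsort.h dispatch header and its qsort entry point; the wrapper function itself is hypothetical:

#include <algorithm>
#include <vector>

#ifdef USE_X86_SIMD_SORT
#include <x86simdsort.h>  // dynamic-dispatch header from the submodule (assumed)
#endif

// Hypothetical wrapper: vectorized sort when the flag is on, std::sort otherwise.
void sort_floats(std::vector<float>& v) {
#ifdef USE_X86_SIMD_SORT
  x86simdsort::qsort(v.data(), v.size());  // AVX2/AVX512 path
#else
  std::sort(v.begin(), v.end());           // portable fallback
#endif
}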

NOTICE

Lines changed: 34 additions & 0 deletions
@@ -454,3 +454,37 @@ and reference the following license:
 LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 PERFORMANCE OF THIS SOFTWARE.
+
+=======================================================================
+x86-simd-sort BSD 3-Clause License
+=======================================================================
+
+Code derived from implementations in x86-simd-sort should mention its
+derivation and reference the following license:
+
+Copyright (c) 2022, Intel. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

aten/src/ATen/Context.cpp

Lines changed: 10 additions & 15 deletions
@@ -5,9 +5,10 @@
 #include <c10/core/CPUAllocator.h>
 
 #include <algorithm>
+#include <array>
 #include <cctype>
-#include <string>
 #include <stdexcept>
+#include <string>
 
 #include <ATen/cpu/FlushDenormal.h>
 
@@ -72,7 +73,7 @@ bool Context::deterministicAlgorithmsWarnOnly() const {
   return _deterministic_algorithms_warn_only;
 }
 
-void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
+void Context::setDeterministicAlgorithms(bool b, bool warn_only = false) {
   _deterministic_algorithms = b;
   _deterministic_algorithms_warn_only = warn_only;
 }
@@ -169,27 +170,21 @@ bool Context::userEnabledOverrideableSDP() const {
   return enabled_overrideable;
 }
 
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };
+static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
+static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
 
 bool Context::checkCuBLASConfigDeterministic() {
-  bool cublas_config_deterministic = true;
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
   // is set to deterministic setting
-  if (hasCUDART() && (versionCUDART() >= 10020)) {
-    char* workspace_config = std::getenv(cublas_config_var_name);
-    cublas_config_deterministic = (workspace_config != nullptr) && (
-      (strcmp(workspace_config, cublas_deterministic_configs[0]) == 0)
-      || (strcmp(workspace_config, cublas_deterministic_configs[1]) == 0)
-    );
+  if (hasCUDART()) {
+    const auto workspace_config = c10::utils::get_env(cublas_config_var_name);
+    return (workspace_config == cublas_deterministic_configs[0] || workspace_config == cublas_deterministic_configs[1]);
   }
-  return cublas_config_deterministic;
+  return true;
 }
 
 void Context::alertCuBLASConfigNotDeterministic() const {
-  static bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
+  static const bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
   if (C10_LIKELY(!deterministicAlgorithms() || cublas_config_deterministic)) {
     return;
   }
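
The rewrite trades std::getenv plus strcmp for c10::utils::get_env. Judging by the comparisons above, get_env returns something like std::optional<std::string>: an empty optional never compares equal to a string, so the old explicit nullptr check is subsumed. A self-contained sketch of those semantics with a stand-in get_env (not the c10 implementation):

#include <cstdlib>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for c10::utils::get_env (assumed return type).
std::optional<std::string> get_env(const char* name) {
  const char* v = std::getenv(name);
  return v ? std::optional<std::string>(v) : std::nullopt;
}

int main() {
  const auto cfg = get_env("CUBLAS_WORKSPACE_CONFIG");
  // optional<string> == const char* is false when the optional is empty,
  // so no separate null check is needed.
  const bool deterministic = (cfg == ":4096:8" || cfg == ":16:8");
  std::cout << std::boolalpha << deterministic << '\n';
}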

aten/src/ATen/ParallelThreadPoolNative.cpp

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
 
 } // namespace
 
-C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
+C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool)
 
 void set_num_interop_threads(int nthreads) {
   TORCH_CHECK(nthreads > 0, "Expected positive number of threads");

aten/src/ATen/Version.cpp

Lines changed: 3 additions & 3 deletions
@@ -24,9 +24,9 @@ std::string get_mkl_version() {
   {
     // Magic buffer number is from MKL documentation
     // https://software.intel.com/en-us/mkl-developer-reference-c-mkl-get-version-string
-    char buf[198];
-    mkl_get_version_string(buf, 198);
-    version = buf;
+    version.resize(198,'\0');
+    mkl_get_version_string(version.data(), 198);
+    version.resize(strlen(version.c_str()));
   }
 #else
   version = "MKL not found";

aten/src/ATen/WrapDimUtils.h

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ inline int64_t maybe_wrap_dim(
     // if necessary
     return dim;
   }
-  return maybe_wrap_dim(dim, tensor_sizes[0].size());
+  return maybe_wrap_dim(dim, static_cast<int64_t>(tensor_sizes[0].size()));
 }
 
 // Given an array of dimensions `dims` of length `ndims`, this function "Wraps"
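
The cast exists because size() returns an unsigned size_t while maybe_wrap_dim takes a signed int64_t; making the conversion explicit documents it and silences sign-conversion warnings. An illustrative snippet of the same conversion, not the actual header:

#include <cstdint>
#include <vector>

int64_t dim_count(const std::vector<int64_t>& sizes) {
  return static_cast<int64_t>(sizes.size());  // explicit size_t -> int64_t
}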

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ void OperatorEntry::reportSignatureError(const CppSignature& call_signature, con
     "This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). ",
     "Please make sure that the function signature matches the signature in the operator registration call."
   );
-};
+}
 
 #ifndef STRIP_ERROR_MESSAGES
 static std::string post_process_dispatch_key_str(std::string dispatch_key) {

aten/src/ATen/core/ivalue.h

Lines changed: 1 addition & 0 deletions
@@ -1360,6 +1360,7 @@ struct TORCH_API IValue final {
     Payload(Payload&&) = delete;
     Payload& operator=(const Payload&) = delete;
     Payload& operator=(Payload&&) = delete;
+    // NOLINTNEXTLINE(modernize-use-equals-default)
     ~Payload() {}
   };
 