
Commit 7362518

hariharans29, github-actions[bot], and edgchen1 authored and committed
Internal Dupe of #25255 - [MLAS] Optimize MlasConv using thread partition opt (#26103)
### Description

This is an internal branch dupe of #25255 + some minor cosmetic changes to account for Copilot feedback.

### Motivation and Context

Improve performance of NCHW Conv - both grouped convolutions and batched inputs should benefit from this change. For a detailed understanding of the perf improvement, please refer to the numbers in #25255.

Credit to @zoeczy and team for this improvement and code change.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Edward Chen <[email protected]>
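The heart of the change is how the flat (batch count * group count) range of work items is divided across worker threads. A rough illustration of that partition scheme follows as a standalone sketch with hypothetical names; it is not the MLAS routine itself, just the same arithmetic as the new worker function in the diff below.

// Illustration only (hypothetical helper, not the MLAS source): split
// BatchGroupCount work items across ThreadCount threads. The first
// (BatchGroupCount % ThreadCount) threads take one extra item so the
// shards stay balanced.
#include <cstddef>

void ComputeBatchGroupShard(size_t BatchGroupCount, size_t ThreadCount, size_t Index,
                            size_t& Start, size_t& End) {
    const size_t PerThread = BatchGroupCount / ThreadCount;
    const size_t Extra = BatchGroupCount % ThreadCount;
    if (Index < Extra) {
        Start = (PerThread + 1) * Index;
        End = Start + PerThread + 1;
    } else {
        Start = PerThread * Index + Extra;
        End = Start + PerThread;
    }
}

Each thread then runs the existing single-image convolution path on its own shard, writing into its own slice of the shared working buffer.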
1 parent d9b2048 commit 7362518

File tree

3 files changed (+278, -2 lines):

onnxruntime/core/mlas/lib/convolve.cpp
onnxruntime/test/mlas/bench/bench_sconv.cpp
onnxruntime/test/providers/cpu/nn/conv_op_test.cc

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 114 additions & 2 deletions
@@ -729,6 +729,82 @@ Return Value:
     }
 }
 
+void
+MlasConvExpandThenGemmSegmentedThreaded(
+    void* Context,
+    ptrdiff_t Index
+    )
+/*++
+
+Routine Description:
+
+    This routine is invoked from a worker thread to execute a segment of a
+    convolution operation.
+
+    If using this, the entire convolution operation is parallelized on the
+    (batch size * group count) parameter and this routine has logic to
+    perform a specific thread's shard of the entire Convolution operation.
+
+Arguments:
+
+    Context - Supplies the pointer to the context for the threaded operation.
+
+    Index - Supplies the current index of the threaded operation.
+
+Return Value:
+
+    None.
+
+--*/
+{
+    MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
+
+    const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
+
+    const size_t GroupCount = Parameters->GroupCount;
+    const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
+
+    const size_t TargetThreadCount = WorkBlock->TargetThreadCount;
+
+    const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
+    const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;
+
+    size_t BatchGroupStart;
+    size_t BatchGroupEnd;
+
+    if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
+        BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
+    } else {
+        BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
+    }
+
+    const size_t FilterCount = Parameters->FilterCount;
+    const size_t OutputSize = Parameters->OutputSize;
+    const size_t K = Parameters->K;
+
+    const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
+    const size_t OutputGroupSize = FilterCount * OutputSize;
+    const size_t FilterGroupSize = FilterCount * K;
+
+    for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
+        size_t group = bg % GroupCount;
+
+        const float* input = WorkBlock->Input + bg * InputGroupSize;
+        const float* filter = WorkBlock->Filter + group * FilterGroupSize;
+        float* output = WorkBlock->Output + bg * OutputGroupSize;
+        const float* bias = WorkBlock->Bias;
+        if (bias != nullptr) {
+            bias += group * FilterCount;
+        }
+        float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K;
+
+        MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize);
+    }
+}
+
 inline
 bool
 MlasConvTryMultithread(
@@ -890,8 +966,8 @@ Return Value:
 
     ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
 
-    if (size_t(TargetThreadCount) >= BatchGroupCount) {
-        TargetThreadCount = ptrdiff_t(BatchGroupCount);
+    if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+        TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
     }
 
     MLAS_CONV_WORK_BLOCK WorkBlock;
@@ -919,6 +995,30 @@ Return Value:
 
 #endif
 
+    if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) {
+        const size_t BatchGroupCount = BatchCount * GroupCount;
+
+        ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+
+        if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+            TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
+        }
+
+        MLAS_CONV_WORK_BLOCK WorkBlock;
+
+        WorkBlock.Parameters = Parameters;
+        WorkBlock.Input = Input;
+        WorkBlock.Filter = Filter;
+        WorkBlock.Bias = Bias;
+        WorkBlock.WorkingBuffer = WorkingBuffer;
+        WorkBlock.Output = Output;
+        WorkBlock.TargetThreadCount = TargetThreadCount;
+
+        MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool);
+
+        return;
+    }
+
     //
     // Iterate over each batch and group.
     //
@@ -1308,6 +1408,18 @@ Return Value:
         Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN;
 
         *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
+
+        if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {
+
+            size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
+                                                          Parameters->FilterCount * Parameters->OutputSize,
+                                                          static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
+            TargetThreadCount = MaximumThreadCount;
+            if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {
+                TargetThreadCount = static_cast<ptrdiff_t>(Parameters->BatchCount * Parameters->GroupCount);
+            }
+            *WorkingBufferSize = TargetThreadCount * WorkingBufferSizePerThread;
+        }
     }
 }
 #if defined(_MSC_VER) && !defined(__clang__)
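Because every worker thread in the new path expands its shard into its own slice of the shared working buffer, the buffer must be sized for one slice per thread actually used. A minimal sketch of the sizing rule added in the last hunk above, written with hypothetical standalone names rather than the real MlasConvPrepare internals:

// Hedged sketch (hypothetical helper, not the MLAS API): size the shared
// working buffer, in floats, for the batched/grouped segmented path.
#include <algorithm>
#include <cstddef>

size_t SegmentedWorkingBufferSize(size_t BatchCount, size_t GroupCount,
                                  size_t OutputSize, size_t K, size_t FilterCount,
                                  size_t MaximumThreadCount, size_t LegacyPerThreadSize) {
    // Never use more threads than there are (batch, group) shards.
    const size_t Threads = std::min(MaximumThreadCount, BatchCount * GroupCount);
    // Each slice must fit the largest of: the im2col column buffer
    // (OutputSize * K), a FilterCount * OutputSize scratch region, and the
    // pre-existing per-thread minimum used as a floor.
    const size_t PerThread = std::max({OutputSize * K,
                                       FilterCount * OutputSize,
                                       LegacyPerThreadSize});
    return Threads * PerThread;
}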

onnxruntime/test/mlas/bench/bench_sconv.cpp

Lines changed: 109 additions & 0 deletions
@@ -3,6 +3,7 @@
 
 #include "mlas.h"
 #include "bench_util.h"
+#include "core/util/thread_utils.h"
 
 #include <stdexcept>
 #include <numeric>
@@ -138,6 +139,113 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
   }
 }
 
+static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark(void) {
+  static auto threadpool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
+      &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true);
+  return threadpool.get();
+}
+
+void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) {
+  MLAS_THREADPOOL* tp = GetMlasThreadPoolForConvBenchmark();
+
+  const int64_t rank = state.range(0);                       // Rank
+  const int64_t batch_size = state.range(1);                 // N
+  const int64_t groups = state.range(2);                     // G
+  const int64_t input_channels_per_group = state.range(3);   // Cpg
+  const int64_t output_channels_per_group = state.range(4);  // Fpg
+
+  if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!");
+  if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!");
+  if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!");
+  if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!");
+  if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!");
+
+  size_t arg_position = 5;
+  const auto input_shape = BenchArgsVector(state, arg_position, rank);
+  const auto kernel_shape = BenchArgsVector(state, arg_position, rank);
+  const auto paddings = BenchArgsVector(state, arg_position, rank * 2);
+  const auto strides = BenchArgsVector(state, arg_position, rank);
+  const auto dilations = BenchArgsVector(state, arg_position, rank);
+
+  // do not check the size of each vector as they are forced from args.
+  if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all input image dim must > 0");
+  }
+
+  if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all kernel dim must > 0");
+  }
+
+  if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all strides dim must > 0");
+  }
+
+  if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all dilations dim must > 0");
+  }
+
+  const int64_t GC = groups * input_channels_per_group;
+  const int64_t GF = groups * output_channels_per_group;
+  std::vector<int64_t> x_shape = {batch_size, GC};
+  x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end());
+  std::vector<int64_t> f_shape = {GF, input_channels_per_group};
+  f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end());
+
+  std::vector<int64_t> output_shape((size_t)rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    auto km = 1 + dilations[i] * (kernel_shape[i] - 1);
+    output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1;
+  }
+  std::vector<int64_t> y_shape = {batch_size, GF};
+  y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end());
+
+  MLAS_ACTIVATION activation;
+  activation.ActivationKind = MlasIdentityActivation;
+  MLAS_CONV_PARAMETERS Parameters;
+  size_t WorkingBufferSize = 0;
+  MlasConvPrepare(&Parameters,
+                  static_cast<size_t>(rank),
+                  static_cast<size_t>(batch_size),
+                  static_cast<size_t>(groups),
+                  static_cast<size_t>(input_channels_per_group),
+                  input_shape.data(),
+                  kernel_shape.data(),
+                  dilations.data(),
+                  paddings.data(),
+                  strides.data(),
+                  output_shape.data(),
+                  static_cast<size_t>(output_channels_per_group),
+                  &activation,
+                  &WorkingBufferSize,
+                  0.0f,
+                  tp);
+
+  auto X = RandomVectorUniform(x_shape, -2.0, 2.0);
+  auto F = RandomVectorUniform(f_shape, -1.0, 1.0);
+  int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies<int64_t>());
+  std::vector<float> Y(static_cast<size_t>(y_size));
+  std::vector<float> working_buffer(WorkingBufferSize);
+
+  // warm up first round.
+  MlasConv(&Parameters,
+           X.data(),
+           F.data(),
+           nullptr,
+           working_buffer.data(),
+           Y.data(),
+           tp);
+
+  for (auto _ : state) {
+    MlasConv(&Parameters,
+             X.data(),
+             F.data(),
+             nullptr,
+             working_buffer.data(),
+             Y.data(),
+             tp);
+  }
+}
+
 static void ResNet50(benchmark::internal::Benchmark* b) {
   b->ArgNames(ArgNamesForConv(2));
 
@@ -221,6 +329,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) {
 }
 
 BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
+BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
 
 static void General_Conv2d(benchmark::internal::Benchmark* b) {
   b->ArgNames(ArgNamesForConv(2));
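Assuming the MLAS micro-benchmark binary is built as usual (commonly named onnxruntime_mlas_benchmark; that name is an assumption, not stated in this diff), the new case can be run in isolation with Google Benchmark's standard filter, for example --benchmark_filter=SCONV_NCHW_THREADED.*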

onnxruntime/test/providers/cpu/nn/conv_op_test.cc

Lines changed: 55 additions & 0 deletions
@@ -339,6 +339,61 @@ TEST(ConvTest, Conv2D_2) {
   TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
 }
 
+TEST(ConvTest, Conv2D_3) {
+  ConvOpAndTestAttributes attrs = {
+      "",                           // auto_pad
+      vector<int64_t>{1, 1},        // dilations
+      2,                            // group
+      vector<int64_t>{2, 2},        // kernel_shape
+      vector<int64_t>{0, 0, 0, 0},  // pads
+      vector<int64_t>{1, 1},        // strides
+      {}                            // excluded EPs
+  };
+
+  vector<int64_t> X_shape = {2, 2, 3, 3};
+  vector<float> X = {1.f, 2.f, 3.f,
+                     4.f, 5.f, 6.f,
+                     7.f, 8.f, 9.f,
+
+                     10.f, 11.f, 12.f,
+                     13.f, 14.f, 15.f,
+                     16.f, 17.f, 18.f,
+
+                     1.f, 2.f, 3.f,
+                     7.f, 8.f, 9.f,
+                     4.f, 5.f, 6.f,
+
+                     13.f, 14.f, 15.f,
+                     10.f, 11.f, 12.f,
+                     16.f, 17.f, 18.f};
+
+  vector<int64_t> W_shape = {2, 1, 2, 2};
+  vector<float> W = {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f};
+
+  vector<int64_t> Y_shape = {2, 2, 2, 2};
+  auto Y = {
+      37.f,
+      47.f,
+      67.f,
+      77.f,
+      254.f,
+      274.f,
+      314.f,
+      334.f,
+      58.f,
+      68.f,
+      55.f,
+      65.f,
+      230.f,
+      250.f,
+      296.f,
+      316.f,
+  };
+
+  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape);
+  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
+}
+
 TEST(ConvTest, Conv2D_Bias_1) {
   ConvOpAndTestAttributes attrs = {
       "",  // auto_pad
