
Optimize FastGelu with float2 and float4 vectorized kernels on ROCm#11491

Merged
zhangyaobit merged 21 commits into microsoft:master from ROCm:hubertlu/fastgelu
Jun 24, 2022
Conversation

@hubertlu-tw
Contributor

Description:
Optimized FastGeluKernel on ROCm.

It is related to the earlier PR: #11390.
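
For context, FastGelu computes the tanh approximation of GELU elementwise (plus an optional bias add):

FastGelu(x) = 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))

The computation per element is cheap, which is why wider (float2/float4) vectorized loads and stores help.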

Motivation and Context

| batch-size | seq-len | inter-dim | input_length | Original: total time (us) over 110 runs | float2: total time (us) over 110 runs | float2: improvement over original (%) | float4 (after moving some computations to LaunchFastGeluKernel): total time (us) over 110 runs | float4: improvement over original (%) | float4 vs. float2: improvement delta (pct. points) | Conditional swap of float2/float4 kernels: total time (us) over 110 runs | Conditional swap: improvement over original (%) | Conditional swap vs. float2: improvement delta (pct. points) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 128 | 3072 | 393216 | 967 | 897 | 7.238883144 | 903 | 6.618407446 | -0.620475698 | 915 | 5.37745605 | -1.86143 |
| 1 | 128 | 4096 | 524288 | 994 | 956 | 3.822937626 | 955 | 3.923541247 | 0.100603621 | 944 | 5.030181087 | 1.207243 |
| 1 | 384 | 3072 | 1179648 | 1134 | 1110 | 2.116402116 | 1110 | 2.116402116 | 4.02116E-10 | 1090 | 3.880070547 | 1.763668 |
| 1 | 384 | 4096 | 1572864 | 1320 | 1177 | 10.83333333 | 1247 | 5.53030303 | -5.3030303 | 1156 | 12.42424242 | 1.590909 |
| 1 | 512 | 3072 | 1572864 | 1307 | 1171 | 10.4055088 | 1210 | 7.421576129 | -2.983932671 | 1157 | 11.47666412 | 1.071155 |
| 1 | 512 | 4096 | 2097152 | 1559 | 1351 | 13.34188582 | 1535 | 1.539448364 | -11.80243746 | 1368 | 12.25144323 | -1.09044 |
| 4 | 128 | 3072 | 1572864 | 1315 | 1172 | 10.87452471 | 1242 | 5.551330798 | -5.323193912 | 1191 | 9.429657795 | -1.44487 |
| 4 | 128 | 4096 | 2097152 | 1557 | 1335 | 14.25818882 | 1527 | 1.926782274 | -12.33140655 | 1342 | 13.80860629 | -0.44958 |
| 4 | 384 | 3072 | 4718592 | 2762 | 2253 | 18.42867487 | 2133 | 22.77335264 | 4.344677773 | 2142 | 22.44750181 | 4.018827 |
| 4 | 384 | 4096 | 6291456 | 3487 | 2831 | 18.81273301 | 2722 | 21.93862919 | 3.125896184 | 2720 | 21.99598509 | 3.183252 |
| 4 | 512 | 3072 | 6291456 | 3478 | 2819 | 18.94767108 | 2743 | 21.13283496 | 2.185163883 | 2719 | 21.82288672 | 2.875216 |
| 4 | 512 | 4096 | 8388608 | 4476 | 3555 | 20.57640751 | 3430 | 23.36907954 | 2.792672025 | 3305 | 26.16175156 | 5.585344 |
| 8 | 128 | 3072 | 3145728 | 2032 | 1699 | 16.38779528 | 1622 | 20.17716535 | 3.789370074 | 1621 | 20.22637795 | 3.838583 |
| 8 | 128 | 4096 | 4194304 | 2519 | 2059 | 18.26121477 | 1963 | 22.07225089 | 3.811036123 | 1968 | 21.87375943 | 3.612545 |
| 8 | 384 | 3072 | 9437184 | 4952 | 3918 | 20.88045234 | 3656 | 26.17124394 | 5.290791602 | 3640 | 26.49434572 | 5.613893 |
| 8 | 384 | 4096 | 12582912 | 6403 | 5112 | 20.16242386 | 4660 | 27.22161487 | 7.059191008 | 4647 | 27.4246447 | 7.262221 |
| 8 | 512 | 3072 | 12582912 | 6427 | 5101 | 20.63170997 | 4642 | 27.77345573 | 7.141745764 | 4645 | 27.72677766 | 7.095068 |
| 8 | 512 | 4096 | 16777216 | 8382 | 6603 | 21.22405154 | 6003 | 28.38224767 | 7.158196134 | 5988 | 28.56120258 | 7.337151 |
| 32 | 128 | 3072 | 12582912 | 6391 | 5097 | 20.24722266 | 4629 | 27.57002034 | 7.322797681 | 4633 | 27.50743233 | 7.26021 |
| 32 | 128 | 4096 | 16777216 | 8388 | 6603 | 21.28040057 | 6003 | 28.43347639 | 7.153075825 | 6004 | 28.4215546 | 7.141154 |
| 32 | 384 | 3072 | 37748736 | 18114 | 14167 | 21.78977586 | 13001 | 28.22678591 | 6.437010051 | 13010 | 28.17710059 | 6.387325 |
| 32 | 384 | 4096 | 50331648 | 23999 | 18700 | 22.08008667 | 17324 | 27.8136589 | 5.733572232 | 17245 | 28.14283928 | 6.062753 |
| 32 | 512 | 3072 | 50331648 | 23983 | 18717 | 21.9572197 | 17253 | 28.06154359 | 6.104323893 | 17219 | 28.20331068 | 6.246091 |
| 32 | 512 | 4096 | 67108864 | 31790 | 24728 | 22.21453287 | 22958 | 27.78232148 | 5.567788615 | 23022 | 27.58100031 | 5.366467 |
| 64 | 128 | 3072 | 25165824 | 12265 | 9625 | 21.52466368 | 9020 | 26.4573991 | 4.932735423 | 8924 | 27.24011415 | 5.71545 |
| 64 | 128 | 4096 | 33554432 | 16172 | 12661 | 21.71036359 | 11620 | 28.14741529 | 6.437051696 | 11639 | 28.02992827 | 6.319565 |
| 64 | 384 | 3072 | 75497472 | 35649 | 27914 | 21.69766333 | 25535 | 28.3710623 | 6.673398972 | 25541 | 28.35423154 | 6.656568 |
| 64 | 384 | 4096 | 100663296 | 47332 | 36811 | 22.22809093 | 33809 | 28.57052311 | 6.342432183 | 33820 | 28.54728302 | 6.319192 |
| 64 | 512 | 3072 | 100663296 | 47322 | 36829 | 22.17361904 | 33847 | 28.47512785 | 6.301508808 | 33857 | 28.45399603 | 6.280377 |
| 64 | 512 | 4096 | 134217728 | 62939 | 48858 | 22.37245587 | 44991 | 28.5165001 | 6.144044233 | 44975 | 28.54192154 | 6.169466 |
| 128 | 128 | 3072 | 50331648 | 23958 | 18709 | 21.90917439 | 17197 | 28.22021872 | 6.311044326 | 17192 | 28.24108857 | 6.331914 |
| 128 | 128 | 4096 | 67108864 | 31775 | 24735 | 22.15578285 | 22957 | 27.75137687 | 5.595594019 | 22990 | 27.64752164 | 5.491739 |
| 128 | 384 | 3072 | 150994944 | 70674 | 55196 | 21.90055749 | 50813 | 28.10227241 | 6.201714916 | 50937 | 27.92681892 | 6.026261 |
| 128 | 384 | 4096 | 201326592 | 94078 | 73116 | 22.28151109 | 67133 | 28.64112757 | 6.359616485 | 67120 | 28.6549459 | 6.373435 |
| 128 | 512 | 3072 | 201326592 | 94023 | 73165 | 22.18393372 | 67210 | 28.5174904 | 6.333556681 | 67222 | 28.50472757 | 6.320794 |
| 128 | 512 | 4096 | 268435456 | 125250 | 97554 | 22.11257485 | 89731 | 28.35848303 | 6.245908184 | 89809 | 28.29620758 | 6.183633 |
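
For readers less familiar with the technique, below is a rough sketch of what a float4-vectorized FastGelu kernel looks like. It is an illustration only, not the exact code added in this PR: the kernel name FastGeluFloat4Kernel is hypothetical, bias handling is omitted, and the element count is assumed to be a multiple of 4 (real kernels handle the remainder separately).

// Illustrative float4-vectorized FastGelu kernel (HIP/CUDA), not the PR's exact code.
// Assumes vec_count = input_length / 4 and that input_length is a multiple of 4.
__global__ void FastGeluFloat4Kernel(const float4* input, float4* output, int vec_count) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= vec_count) return;

  const float4 in = input[idx];   // one 128-bit load instead of four 32-bit loads
  const float v[4] = {in.x, in.y, in.z, in.w};
  float r[4];

  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    const float x = v[i];
    // tanh approximation of GELU
    r[i] = 0.5f * x * (1.0f + tanhf(0.7978845608f * x * (1.0f + 0.044715f * x * x)));
  }

  output[idx] = make_float4(r[0], r[1], r[2], r[3]);  // one 128-bit store
}

The float2 variant is the same except for the vector width; the last column group in the table corresponds to a launcher that conditionally swaps between the float2 and float4 kernels.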

@hubertlu-tw changed the title from "Hubertlu/fastgelu" to "Optimize FastGelu with float2 and float4 vectorized kernels on ROCm" on May 11, 2022
@hubertlu-tw
Contributor Author

We might be able to use aligned_vector to replace manual vectorized memory access. The example code can be found here:

using LoadT = aligned_vector<T, UNROLL>;
using MaskLoadT = aligned_vector<bool, UNROLL>;

for (CUDA_LONG id = idx * UNROLL; id < N; id += step_size) {
  rand = curand_uniform4(&state);

  // vectorized load into storage
  T src[UNROLL];
  LoadT* value = reinterpret_cast<LoadT*>(&src);
  *value = *reinterpret_cast<const LoadT*>(&X_data[id]);

  T r[UNROLL];
  bool mask[UNROLL];

  // actual computation
  #pragma unroll
  for (int ii = 0; ii < UNROLL; ii++) {
    mask[ii] = (&rand.x)[ii] < p;
    r[ii] = T(float(src[ii]) * mask[ii] * scale);
  }

  // Vectorized writes for mask_data & Y_data
  *(reinterpret_cast<LoadT*>(&Y_data[id])) = *reinterpret_cast<LoadT*>(&r[0]);
  *(reinterpret_cast<MaskLoadT*>(&mask_data[id])) = *reinterpret_cast<MaskLoadT*>(&mask[0]);

  __syncthreads();
}

@hubertlu-tw
Contributor Author

@tianleiwu could you please help me review this PR? Thanks.

@zhangyaobit requested review from PeixuanZuo and pengwa on May 27, 2022 04:13
@zhangyaobit
Contributor

We might be able to use aligned_vector to replace manual vectorized memory access (example code quoted in the comment above).

Are there any problems with using aligned_vector here?

It looks like we should use it, as Peixuan previously recommended. I believe that with aligned_vector we don't need to maintain two copies of mostly similar code for float2 and float4; we just need a single templated kernel that can be instantiated twice.

Example code can be found here:
onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh:270
onnxruntime/core/providers/cuda/math/softmax_impl.cu:120
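
To make the suggestion concrete, here is a minimal sketch of a single kernel templated on the vector width, assuming an aligned_vector<T, ILP> helper along the lines of the snippet quoted above. Names and signatures are illustrative only; this is not the code that was eventually merged.

// Illustrative only: one templated kernel covering both the float2-like (ILP = 2)
// and float4-like (ILP = 4) cases via aligned_vector-style loads and stores.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {
  T val[ILP];
};

// Assumes n is a multiple of ILP; the tail would be handled separately in practice.
template <typename T, int ILP>
__global__ void VectorizedFastGeluKernel(const T* input, T* output, int n) {
  using LoadT = aligned_vector<T, ILP>;
  const int id = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
  if (id >= n) return;

  T src[ILP];
  *reinterpret_cast<LoadT*>(src) = *reinterpret_cast<const LoadT*>(&input[id]);  // vectorized load

  T dst[ILP];
  #pragma unroll
  for (int i = 0; i < ILP; ++i) {
    const float x = static_cast<float>(src[i]);
    dst[i] = static_cast<T>(0.5f * x * (1.0f + tanhf(0.7978845608f * x * (1.0f + 0.044715f * x * x))));
  }

  *reinterpret_cast<LoadT*>(&output[id]) = *reinterpret_cast<LoadT*>(dst);  // vectorized store
}

Instantiating this once with ILP = 2 and once with ILP = 4 gives both variants from one body, instead of maintaining two hand-written kernels.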

zhangyaobit previously approved these changes on Jun 20, 2022
@zhangyaobit reopened this on Jun 21, 2022
@zhangyaobit
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@zhangyaobit
Contributor

/azp run Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@ytaous
Contributor

ytaous commented Jun 23, 2022

The failed tests seem to be related to opset 17 changes. Can you please merge your code with master again and see if the failed tests go away?

@azure-pipelines

Commenter does not have sufficient privileges for PR 11491 in repo microsoft/onnxruntime

@zhangyaobit
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline

@zhangyaobit
Contributor

/azp run Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

tianleiwu previously approved these changes on Jun 23, 2022
@tianleiwu
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline

@tianleiwu
Contributor

/azp run Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@zhangyaobit merged commit f4ba199 into microsoft:master on Jun 24, 2022