
Using vectorized loads (float2) for fp16 to improve performance#11390

Merged
hariharans29 merged 6 commits into microsoft:master from ROCm:hubertlu/fastgelu
May 5, 2022

Conversation

@hubertlu-tw
Contributor

Description:
Optimized LaunchFastGeluKernel for fp16 on AMD GPUs.

Motivation and Context
A performance gap was observed in the FastGelu kernels with a microbenchmark on MI200 vs. A100. With the optimized kernels, we are able to see a significant performance improvement compared to A100.

@tianleiwu
Contributor

There are a few warnings from cpplint. Please format the code (see .clang-format in the root directory). One quick way is to rename fast_gelu_impl.cu to fast_gelu_impl.cc, then use Visual Studio to format the file, then undo the rename.

const float2* bias_cast = reinterpret_cast<const float2*>(bias);
float2* output_cast = reinterpret_cast<float2*>(output);

const half2 two2 = __floats2half2_rn(two, two);
Member


Curious: Why not just [__float2half2_rn](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____HALF__MISC.html#group__CUDA__MATH____HALF__MISC_1ge40813c17ab4b0779764e2e5e3014019) since both halves are going to be populated with the same value ?

Contributor Author


Yes, thanks for pointing this out.

template <unsigned TPB>
__global__ void FastGeluKernel4Bias(const half2 a, const half2 b, const half2 c, int input_length, int bias_length,
const half* input, const half* bias, half* output) {
const half2 two2 = __float2half2_rn(two);
Contributor


Hari mentioned that the following code could be moved inside the if block:

  const half2 two2 = __float2half2_rn(two);
  const half2 one2 = __float2half2_rn(one);
  const float2* input_cast = reinterpret_cast<const float2*>(input);
  const float2* bias_cast = reinterpret_cast<const float2*>(bias);
  float2* output_cast = reinterpret_cast<float2*>(output);

That could save computation in some cases. Similar change can be made in other functions.

@hubertlu-tw
Contributor Author

hubertlu-tw commented May 4, 2022

The perf numbers in the following table were collected on MI200.

| batch-size | seq-len | inter-dim | (Original) total time (us) over 110 runs | (Optimized) total time (us) over 110 runs | Performance improvement (%) | (Optimized_1) total time (us) over 110 runs | Performance improvement (%) after moving some computations to LaunchFastGeluKernel |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 128 | 3072 | 967 | 934 | 3.412616 | 897 | 7.238883144 |
| 1 | 128 | 4096 | 994 | 959 | 3.521127 | 956 | 3.822937626 |
| 1 | 384 | 3072 | 1134 | 1146 | -1.0582 | 1110 | 2.116402116 |
| 1 | 384 | 4096 | 1320 | 1176 | 10.90909 | 1177 | 10.83333333 |
| 1 | 512 | 3072 | 1307 | 1174 | 10.17598 | 1171 | 10.4055088 |
| 1 | 512 | 4096 | 1559 | 1349 | 13.47017 | 1351 | 13.34188582 |
| 4 | 128 | 3072 | 1315 | 1220 | 7.224335 | 1172 | 10.87452471 |
| 4 | 128 | 4096 | 1557 | 1339 | 14.00128 | 1335 | 14.25818882 |
| 4 | 384 | 3072 | 2762 | 2268 | 17.88559 | 2253 | 18.42867487 |
| 4 | 384 | 4096 | 3487 | 2844 | 18.43992 | 2831 | 18.81273301 |
| 4 | 512 | 3072 | 3478 | 2852 | 17.99885 | 2819 | 18.94767108 |
| 4 | 512 | 4096 | 4476 | 3579 | 20.04021 | 3555 | 20.57640751 |
| 8 | 128 | 3072 | 2032 | 1700 | 16.33858 | 1699 | 16.38779528 |
| 8 | 128 | 4096 | 2519 | 2074 | 17.66574 | 2059 | 18.26121477 |
| 8 | 384 | 3072 | 4952 | 3933 | 20.57754 | 3918 | 20.88045234 |
| 8 | 384 | 4096 | 6403 | 5120 | 20.03748 | 5112 | 20.16242386 |
| 8 | 512 | 3072 | 6427 | 5141 | 20.00934 | 5101 | 20.63170997 |
| 8 | 512 | 4096 | 8382 | 6638 | 20.80649 | 6603 | 21.22405154 |
| 32 | 128 | 3072 | 6391 | 5127 | 19.77781 | 5097 | 20.24722266 |
| 32 | 128 | 4096 | 8388 | 6632 | 20.93467 | 6603 | 21.28040057 |
| 32 | 384 | 3072 | 18114 | 14264 | 21.25428 | 14167 | 21.78977586 |
| 32 | 384 | 4096 | 23999 | 18811 | 21.61757 | 18700 | 22.08008667 |
| 32 | 512 | 3072 | 23983 | 18825 | 21.5069 | 18717 | 21.9572197 |
| 32 | 512 | 4096 | 31790 | 24876 | 21.74898 | 24728 | 22.21453287 |
| 64 | 128 | 3072 | 12265 | 9703 | 20.88871 | 9625 | 21.52466368 |
| 64 | 128 | 4096 | 16172 | 12736 | 21.2466 | 12661 | 21.71036359 |
| 64 | 384 | 3072 | 35649 | 28052 | 21.31056 | 27914 | 21.69766333 |
| 64 | 384 | 4096 | 47332 | 37037 | 21.75061 | 36811 | 22.22809093 |
| 64 | 512 | 3072 | 47322 | 37036 | 21.73619 | 36829 | 22.17361904 |
| 64 | 512 | 4096 | 62939 | 49192 | 21.84178 | 48858 | 22.37245587 |
| 128 | 128 | 3072 | 23958 | 18820 | 21.44586 | 18709 | 21.90917439 |
| 128 | 128 | 4096 | 31775 | 24890 | 21.66798 | 24735 | 22.15578285 |
| 128 | 384 | 3072 | 70674 | 55384 | 21.63455 | 55196 | 21.90055749 |
| 128 | 384 | 4096 | 94078 | 73592 | 21.77555 | 73116 | 22.28151109 |
| 128 | 512 | 3072 | 94023 | 73681 | 21.63513 | 73165 | 22.18393372 |
| 128 | 512 | 4096 | 125250 | 97950 | 21.79641 | 97554 | 22.11257485 |

@hariharans29
Member

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline

@hariharans29
Member

/azp run Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@hariharans29
Member

@tianleiwu @weixingzhang : Any more comments ?

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@tianleiwu
Contributor

> @tianleiwu @weixingzhang : Any more comments ?

LGTM
