
Bfloat16 support for MatMulBnb4, Training support bitsandbytes>=0.41.2 #18484

Merged
jambayk merged 5 commits into main from jambayk/matmulbnb4-bf16
Nov 20, 2023

Conversation


@jambayk jambayk commented Nov 17, 2023

Description

Add bfloat16 support for MatMulBnb4 contrib op. This is useful for QLoRA fine-tuning.

  • On GPUs with SM80+ (A100, etc.), it uses the native CUDA bfloat16 dtype, `nv_bfloat16`. On other GPUs, it falls back to the onnxruntime `BFloat16` type, which uses float for compute.
  • I have validated the op in a llama2-7b training scenario. The losses match PyTorch training, and the training throughput is better.
  • A bfloat16 case cannot be added to the op unit test: casting BFloat16 to and from float multiple times during the test makes the required tolerances unachievable.
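The SM80+ dispatch criterion above can be sketched as follows. This is an illustrative helper, not the actual kernel-selection code; the function name and the suggested use of `torch.cuda.get_device_capability` are assumptions for the example.

```python
def supports_native_bf16(major: int, minor: int) -> bool:
    """SM80+ (Ampere and newer, e.g. A100) has native bfloat16 arithmetic."""
    return (major, minor) >= (8, 0)

# Illustrative usage with PyTorch (assumes a CUDA device is available):
#   major, minor = torch.cuda.get_device_capability(0)
#   use_nv_bfloat16 = supports_native_bf16(major, minor)

print(supports_native_bf16(8, 0))  # A100 (SM 8.0) -> True
print(supports_native_bf16(7, 5))  # T4 (SM 7.5)   -> False
```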

The custom autograd function exporter in onnxruntime-training is updated to support the latest version of bitsandbytes, which changed how the `quant_state` is stored.
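The exporter change can be pictured with a small adapter that accepts both layouts. The attribute name and the sequence layout below are assumptions for illustration, not the exact bitsandbytes structures:

```python
def get_absmax(quant_state):
    # Newer bitsandbytes versions wrap quantization metadata in an object
    # with named attributes; older versions stored it as a plain sequence.
    # Both layouts here are illustrative assumptions.
    if hasattr(quant_state, "absmax"):
        return quant_state.absmax   # object-style quant_state
    return quant_state[0]           # legacy sequence-style quant_state

class FakeQuantState:  # stand-in for the object form, for demonstration only
    def __init__(self, absmax):
        self.absmax = absmax

print(get_absmax(FakeQuantState([1.5])))  # [1.5]
print(get_absmax(([2.5], "nf4")))         # [2.5]
```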

Motivation and Context

Enable QLoRA fine-tuning with bfloat16.


askhade commented Nov 17, 2023

"Cannot add a bfloat16 case in the op unit test since casting BFloat16 to and from float multiple times during the test causes the required tolerances to be unachievable." - Can you add a test which will only run if the hardware supports the data type? We plan to move the CI pipelines on A10, once this is done we will be able to run the test.


jambayk commented Nov 17, 2023

"Cannot add a bfloat16 case in the op unit test since casting BFloat16 to and from float multiple times during the test causes the required tolerances to be unachievable." - Can you add a test which will only run if the hardware supports the data type? We plan to move the CI pipelines on A10, once this is done we will be able to run the test.

@askhade I tried this too but because of how the test is setup, the tolerance issue still remains.

  • The reference value for the matrix multiplication output is calculated in float and then cast to bfloat16.
  • The actual value is computed by:
    • casting the input to bfloat16
    • computing in bfloat16
    • accumulating in float, then casting to bfloat16

Because the cast to bfloat16 happens at different points in the two paths, the outputs differ significantly in some places. This was not an issue for float16, since casting from float to float16 loses far less precision (float16 keeps a 10-bit mantissa versus bfloat16's 7 bits).

Edit: The dequantization step also introduces differences.
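The tolerance problem can be reproduced outside the op with a quick sketch. Plain bit-mask truncation stands in for the bfloat16 cast here (the real kernels may round rather than truncate), and the matrix sizes are arbitrary:

```python
import numpy as np

def bf16(x):
    # Truncate float32 values to bfloat16 precision by zeroing the low
    # 16 bits (bfloat16 keeps float32's exponent and top 7 mantissa bits).
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    return ((u >> 16) << 16).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

ref = bf16(a @ b)                 # reference path: matmul in float32, one final cast
actual = bf16(bf16(a) @ bf16(b))  # kernel-like path: cast inputs first, cast result

# The extra input casts perturb individual products enough that tight
# elementwise tolerances on the outputs become unachievable.
print(np.max(np.abs(actual - ref)))
```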

pengwa previously approved these changes Nov 17, 2023

@pengwa pengwa left a comment

LGTM. Thanks for fixing the exporter issue BTW.

@jambayk jambayk requested a review from pengwa November 17, 2023 21:18
@jambayk jambayk merged commit 1af0681 into main Nov 20, 2023
@jambayk jambayk deleted the jambayk/matmulbnb4-bf16 branch November 20, 2023 17:52
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024

Bfloat16 support for MatMulBnb4, Training support bitsandbytes>=0.41.2 (microsoft#18484)

