Bfloat16 support for MatMulBnb4, Training support bitsandbytes>=0.41.2 #18484
reset unit test alphabetical order
"Cannot add a bfloat16 case in the op unit test since casting BFloat16 to and from float multiple times during the test causes the required tolerances to be unachievable." - Can you add a test that will only run if the hardware supports the data type? We plan to move the CI pipelines to A10; once this is done, we will be able to run the test.
@askhade I tried this too, but because of how the test is set up, the tolerance issue remains.
Because of differences in where and how the cast to bfloat16 happens, the outputs vary significantly in some places. This was not an issue for float16, since casting from float to float16 does not change the value much. Edit: There are also differences introduced by the dequantization step.
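The tolerance issue comes down to precision: bfloat16 keeps only 7 mantissa bits (relative rounding error up to about 2**-8), while float16 keeps 10 (about 2**-11), so each float-to-bfloat16 round trip perturbs values far more than the equivalent float16 cast. A minimal standalone sketch (not onnxruntime code) of the float32-to-bfloat16 round trip illustrates the loss:

```python
import struct

def float_to_bfloat16_bits(x: float) -> int:
    """Round a float32 value to bfloat16 (round-to-nearest-even); return the 16 raw bits."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # Add a bias so that truncating the low 16 bits performs round-to-nearest-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bfloat16_bits_to_float(b: int) -> float:
    """Expand 16 bfloat16 bits back to float32 by zero-filling the low mantissa bits."""
    return struct.unpack(">f", struct.pack(">I", (b & 0xFFFF) << 16))[0]

x = 1.2345678
roundtrip = bfloat16_bits_to_float(float_to_bfloat16_bits(x))
# roundtrip is 1.234375: only 7 mantissa bits survive, so values that differ in
# where the cast happens can diverge beyond tight test tolerances.
```

Repeating this round trip at different points in the computation (as the test harness and the kernel do) produces outputs that legitimately disagree at the bfloat16 precision level, which is why a shared tight tolerance cannot be met.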
pengwa
left a comment
LGTM. Thanks for fixing the exporter issue BTW.
This reverts commit 21ddce3.
microsoft#18484)

### Description

Add bfloat16 support for the `MatMulBnb4` contrib op. This is useful for QLoRA fine-tuning.

- On GPUs with SM80+ (A100, etc.), it uses the native CUDA bfloat16 dtype, `nv_bfloat16`. On other GPUs, it uses the onnxruntime `BFloat16` type, which uses float for compute.
- The op has been validated in a llama2-7b training scenario: the losses match PyTorch training, and the training throughput is better.
- A bfloat16 case cannot be added to the op unit test, since casting BFloat16 to and from float multiple times during the test makes the required tolerances unachievable.

The custom autograd function exporter in onnxruntime-training is updated to support the latest version of bitsandbytes, which changed how the `quant_state` is stored.

### Motivation and Context

Enable QLoRA fine-tuning with bfloat16.
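Because newer bitsandbytes releases changed how `quant_state` is stored (from a plain sequence to an object with named fields), an exporter that supports both old and new versions needs a small compatibility shim. The sketch below is hypothetical: the attribute names and sequence positions are assumptions for illustration, not the exact bitsandbytes layout, and the stand-in objects replace real bitsandbytes state.

```python
from types import SimpleNamespace

def get_absmax_and_blocksize(quant_state):
    """Normalize access to a 4-bit quant_state across bitsandbytes versions.

    Hypothetical compatibility shim: newer bitsandbytes versions (assumed
    here) expose quant_state as an object with named attributes, while older
    versions stored it as a plain sequence. Names/positions are illustrative.
    """
    if hasattr(quant_state, "absmax"):
        # New-style object layout (attribute names assumed for this sketch).
        return quant_state.absmax, quant_state.blocksize
    # Old-style sequence layout (positions assumed for this sketch).
    return quant_state[0], quant_state[1]

# Stand-ins for the two layouts, in place of real bitsandbytes state:
new_style = SimpleNamespace(absmax=[1.0, 2.0], blocksize=64)
old_style = ([1.0, 2.0], 64)
```

Dispatching on the stored shape rather than on the bitsandbytes version string keeps the exporter working across the version boundary without pinning a minimum release.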