
torch.embedding: Trying to convert BFloat16 to the MPS backend but it does not have support for that dtype. #104191

@Willian-Zhang

Description


🐛 Describe the bug

Code to reproduce

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "gpt2" # any LM would result the same
tokenizer = AutoTokenizer.from_pretrained(path) 
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map={"":"mps"})

t = tokenizer("anything", return_attention_mask=False, return_tensors='pt')
with torch.inference_mode():
    model(**t)

results in

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 2
      1 with torch.inference_mode():
----> 2     model(**t)

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:1080, in GPT2LMHeadModel.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1072 r"""
   1073 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
   1074     Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
   1075     `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
   1076     are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
   1077 """
   1078 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1080 transformer_outputs = self.transformer(
   1081     input_ids,
   1082     past_key_values=past_key_values,
   1083     attention_mask=attention_mask,
   1084     token_type_ids=token_type_ids,
   1085     position_ids=position_ids,
   1086     head_mask=head_mask,
   1087     inputs_embeds=inputs_embeds,
   1088     encoder_hidden_states=encoder_hidden_states,
   1089     encoder_attention_mask=encoder_attention_mask,
   1090     use_cache=use_cache,
   1091     output_attentions=output_attentions,
   1092     output_hidden_states=output_hidden_states,
   1093     return_dict=return_dict,
   1094 )
   1095 hidden_states = transformer_outputs[0]
   1097 # Set device for model parallelism

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:846, in GPT2Model.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)
    843 head_mask = self.get_head_mask(head_mask, self.config.n_layer)
    845 if inputs_embeds is None:
--> 846     inputs_embeds = self.wte(input_ids)
    847 position_embeds = self.wpe(position_ids)
    848 hidden_states = inputs_embeds + position_embeds

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162, in Embedding.forward(self, input)
    161 def forward(self, input: Tensor) -> Tensor:
--> 162     return F.embedding(
    163         input, self.weight, self.padding_idx, self.max_norm,
    164         self.norm_type, self.scale_grad_by_freq, self.sparse)

File /opt/homebrew/Caskroom/miniconda/base/envs/torch/lib/python3.10/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2204     # Note [embedding_renorm set_grad_enabled]
   2205     # XXX: equivalent to
   2206     # with torch.no_grad():
   2207     #   torch.embedding_renorm_
   2208     # remove once script supports set_grad_enabled
   2209     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

TypeError: Trying to convert BFloat16 to the MPS backend but it does not have support for that dtype.

Probably here, in the ScalarType-to-MPSDataType mapping:

case ScalarType::Short:
  return MPSDataTypeInt16;
case ScalarType::Char:
  return MPSDataTypeInt8;
case ScalarType::Byte:
  return MPSDataTypeUInt8;
case ScalarType::Bool:
  return MPSDataTypeBool;
case ScalarType::Double:
  TORCH_CHECK_TYPE(false,
                   "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
                   "Please use float32 instead.")
default:
  TORCH_CHECK_TYPE(
      false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")

I wasn't able to test this on a nightly build, because BF16 on MPS is currently blocked there explicitly:

TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16, "BFloat16 is not supported on MPS");

BF16 support was recently added to the OS version I am using (macOS Sonoma), as referenced here (with timestamps):
https://developer.apple.com/wwdc23/10050?time=590

Starting with macOS Sonoma, MPSGraph adds support for a new data type, bfloat16.

https://developer.apple.com/wwdc23/10050?time=659

Adding Automatic Mixed Precision support to your network is a very easy process. First, add autocast. Both float16 and bfloat16 are supported.
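
In the meantime, a workaround that avoids the error (a sketch of the obvious stopgap, not a substitute for proper bfloat16 support) is to load the model in float16, which the MPS backend does support:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(path)
# float16 has an MPS dtype mapping, so the embedding lookup should no longer raise the TypeError.
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, device_map={"": "mps"})

t = tokenizer("anything", return_attention_mask=False, return_tensors="pt")
with torch.inference_mode():
    out = model(**t)

Note that float16 has a much narrower exponent range than bfloat16, so weights trained in bf16 can overflow when cast to fp16; this is only a stopgap until the MPS backend accepts BFloat16.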

Versions

Collecting environment information...
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.0 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.28.1.1)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.11 (main, May 17 2023, 14:30:36) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2

Versions of relevant libraries:
[pip3] numpy==1.25.0
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchvision==0.15.2
[conda] numpy                     1.25.0                   pypi_0    pypi
[conda] torch                     2.0.1                    pypi_0    pypi
[conda] torchaudio                2.0.2                    pypi_0    pypi
[conda] torchvision               0.15.2                   pypi_0    pypi

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

    Labels

    enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix)
    module: bfloat16
    module: mps (Related to Apple Metal Performance Shaders framework)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
