Deberta can now be exported to TorchScript #27734
Szustarol wants to merge 4 commits into huggingface:main
Conversation
Great work @Szustarol! I think this solution still pins the device for the traced model. One workaround is to create an ONNX model instead. Perhaps you might have a better solution?

Thanks for the response @demq. I would be more than happy to expand on my solution, however I am not quite sure I understand your suggestion correctly.

Yes - your solution here correctly addresses the open issue. You are absolutely right that the device pinning is caused by jit.trace(), and it is a separate issue from the one you have addressed here. Previously, I ran into device pinning on a traced (using the decorator trick on XSoftmax) fine-tuned DeBERTa model, so we had to export to ONNX instead of TorchScript to get around it. My comment was inspired by seeing how quickly you submitted a solution to this issue :) I would imagine one way to ensure no device pinning is to export the model through jit.script() instead of jit.trace().

Okay, I think I understand your point now, thanks! I will have a look into it and report back if I can find a solution. One way to fix this is to check which parts of the code cause device pinning and try to get rid of them, which is what I will try. Should we open an issue for this?
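For readers following along, a minimal toy sketch (not DeBERTa code) of the pinning behaviour behind @demq's jit.script() suggestion: under torch.jit.trace, the device of any tensor created inside forward() is evaluated eagerly and recorded as a constant, while torch.jit.script keeps `x.device` as an expression evaluated at run time.

```python
import torch

# Toy module (not from the PR): forward() creates a fresh tensor, the classic
# source of device pinning under tracing.
class MakeMask(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.jit.trace evaluates x.device eagerly and bakes the result in;
        # torch.jit.script keeps it as an expression evaluated per call.
        mask = torch.zeros([x.size(0), x.size(1)], device=x.device)
        return x + mask

m = MakeMask()
example = torch.ones(2, 3)

traced = torch.jit.trace(m, example)  # graph stores the tracing device
scripted = torch.jit.script(m)        # graph keeps "x.device" symbolic

# On the tracing device both agree; moving the traced model to another device
# is where the pinned constant would surface.
assert torch.equal(traced(example), scripted(example))
```
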
Force-pushed 844a1da to 7657821
Okay, I have done some extensive research and testing on the subject mentioned by @demq. With the changes in this PR, the script below traces the model on CUDA, saves it, and reloads it for CPU inference:

```python
import io
import torch
from transformers import AutoTokenizer, AutoModel

# tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge")
# model = AutoModel.from_pretrained("microsoft/deberta-v2-xlarge", torchscript=True).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = AutoModel.from_pretrained("microsoft/deberta-base", torchscript=True).to("cuda")

tokenized_dict = tokenizer(
    ["Is this working"], ["Not yet"],
    return_tensors="pt",
).to("cuda")
input_tuple = (tokenized_dict["input_ids"], tokenized_dict["attention_mask"])

traced_model = torch.jit.trace(model, input_tuple)

model_bytes = io.BytesIO()
torch.jit.save(traced_model, model_bytes)
model_bytes.seek(0)
print("######### tracing and saving done")

loaded_model = torch.jit.load(model_bytes, map_location="cpu")
tokenized_dict = tokenized_dict.to("cpu")
input_tuple = (tokenized_dict["input_ids"], tokenized_dict["attention_mask"])
print(loaded_model(*input_tuple)[0].device)  # outputs cpu
```

I would love to hear some feedback on this code, as with the extension of this task by the device problem the model modifications have now become quite extensive.
Thanks for the great update. I can confirm that tracing now works without device pinning in my local Linux env; I used a slightly changed test script. It works fine, but I get the following warnings during jit.trace():

Yes, of course you are right. Seeing those errors for the first time, I erroneously assumed that the tensor shape would be constant for a given config, which might be true for
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Hey! Do you need a review on this?

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Force-pushed 7657821 to 16858cb
The reason it took so long is that, while trying to fix the issue mentioned above with constant size pinning, I noticed that the HF tests use an HFProxy object during tracing, which is not compatible with

The latest version seems to work fine for me @Szustarol, thank you very much for your efforts! I can trace a model on a CPU without any warnings and run inference on GPUs; there is no apparent device pinning and no warnings about constant size pinning. Perhaps it is worth asking @ArthurZucker to review the PR and merge it to master?

Sure, let me review it!
ArthurZucker left a comment

As I'm not a TorchScript expert, would you mind educating me a bit?
Could you walk me through the motivations behind some of the changes?
(LGTM overall!)
> return self.config.hidden_size
```python
# copied from transformers.models.deberta.modeling_deberta._traceable
```

Suggested change:

```diff
-# copied from transformers.models.deberta.modeling_deberta._traceable
+# Copied from transformers.models.deberta.modeling_deberta._traceable
```

All the other `Copied from` comments need this capital 😉
```python
if tensor.dtype in [torch.float16, torch.float32, torch.float64]:
    # Will not be baked in during tracing
    return _get_float_min_value(tensor)
# Will be baked in during tracing
return torch.tensor(torch.finfo(tensor.dtype).min)
```

Why does tracing have an issue with this? Known bug or expected behaviour?
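For what it's worth, a toy illustration of the behaviour this question is about (my sketch, not code from the PR): `torch.finfo` executes in plain Python, so the resulting scalar is captured once at trace time and stored as a constant for the example input's dtype.

```python
import torch

def min_value(t: torch.Tensor) -> torch.Tensor:
    # torch.finfo(...) runs in Python during tracing, so the returned tensor
    # becomes a constant recorded for the dtype of the example input.
    return torch.tensor(torch.finfo(t.dtype).min)

traced = torch.jit.trace(min_value, torch.ones(2))  # traced with a float32 example

# Even when called with a float16 input, the traced graph still returns the
# float32 minimum that was baked in at trace time.
assert traced(torch.ones(2, dtype=torch.float16)).item() == torch.finfo(torch.float32).min
```
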
```diff
 Return:
-    `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+    `torch.LongTensor`: A tensor with shape [1, query_layer.size, key-layer.size]
```

Suggested change:

```diff
-    `torch.LongTensor`: A tensor with shape [1, query_layer.size, key-layer.size]
+    `torch.LongTensor`: A tensor with shape [1, query_length, key_length]
```
```python
q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
rel_pos_ids = rel_pos_ids[:query_size, :]
```

Suggested change:

```diff
-rel_pos_ids = rel_pos_ids[:query_size, :]
+rel_pos_ids = rel_pos_ids[q_ids, :]
```

Pretty sure this is faster when compiling. I don't know exactly for scripting, but it's slow for CPU inference. Just noting it here!
```python
@torch.jit.script
def scaled_size_sqrt(query_layer, scale_factor: int):
    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
```

Why does this have to be defined as such? (FMI) Known TorchScript issue or expected way to do so?
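For context, my understanding of why the helper is scripted (a sketch of the reasoning, not an authoritative answer): under plain tracing, `query_layer.size(-1)` is a Python int and the whole square root would fold into a constant for the example input, while compiling just this helper with `@torch.jit.script` keeps the size symbolic so it is recomputed per call.

```python
import torch

@torch.jit.script
def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int) -> torch.Tensor:
    # Inside a scripted function, size(-1) stays a graph value instead of
    # being folded into a constant as it would be under torch.jit.trace.
    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)

assert scaled_size_sqrt(torch.ones(1, 4), 1).item() == 2.0
assert scaled_size_sqrt(torch.ones(1, 9), 1).item() == 3.0  # follows the input size
```
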
```diff
 if token_type_ids is None:
-    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+    token_type_ids = get_token_type_ids_on_device(device_discriminator, using_ids)
```

If we do this after computing the word embeddings, we can always use the same device discriminator, no?
```python
@torch.jit.script
# copied from transformers.models.deberta.modeling_deberta._get_float_min_value
```

Suggested change:

```diff
-# copied from transformers.models.deberta.modeling_deberta._get_float_min_value
+# Copied from transformers.models.deberta.modeling_deberta._get_float_min_value
```
```python
def prepare_config_and_inputs_for_common(self):
    config_and_inputs = self.prepare_config_and_inputs()
    (
```

We have a `test_torchscript = False` attribute at line 243 of the test file that should be set to True!
@Szustarol and @ArthurZucker - would it be possible to resolve the few review questions/improvement suggestions and get this PR merged in the near future? This is a very valuable contribution to the library, eagerly anticipated by many who want to use fine-tuned DeBERTa outside of the pure Python ecosystem.

Down for this! 🔥 If I don't get an answer I'll just dive in a bit! 🤗 Sorry all for the many delays.

I tested the PR and it works for jit.script.

+1 Also waiting on this, it would help my project a ton if we can merge. Thank you for working on this!

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Sorry to reopen this! Please feel free to close it if no one else wants this, but it would help my project a lot and it seems just about done. Can it be merged or is there more work left? Thank you!

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
ArthurZucker left a comment

BTW, I incorporated most of these changes in #22105 and will add you as a co-author! Gonna close this once merged.
What does this PR do?

Fixes #20815

Generally, torch.autograd.Functions cannot be traced in Torch, as per this open issue: pytorch/pytorch#32822. This issue is thus more of a PyTorch problem, but it can nevertheless be resolved. 🤗 Transformers' implementation is basically the same as the original https://github.com/microsoft/DeBERTa, which was traceable with a dirty trick of using a tracing context:

https://github.com/microsoft/DeBERTa/blob/4d7fe0bd4fb3c7d4f4005a7cafabde9800372098/DeBERTa/utils/jit_tracing.py#L10C1-L17C6

Of course such a solution is not applicable here, as it would conflict with the existing API and usage of 🤗 Transformers. I decided to explore recent developments in PyTorch, and it seems is_tracing is now publicly accessible through torch.jit (though it is not yet documented), which gets rid of the context problem. So I have basically implemented the original solution, but with the newly available is_tracing call. I have also added tests to check that the traced model outputs the same tensors as the model being traced.

This was not mentioned in the issue, but I have applied the same changes to Deberta_v2 since it is obviously also affected.
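A toy sketch of the pattern the description refers to (illustrative names, not the PR's actual code): branch on torch.jit.is_tracing() so the autograd.Function is bypassed with an equivalent expression while a trace is being recorded.

```python
import torch

class DoubleFn(torch.autograd.Function):
    # Stand-in for ops like XSoftmax; autograd.Function.apply cannot be traced.
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2.0

def double(x: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_tracing():
        return x * 2.0          # traceable equivalent, used only while tracing
    return DoubleFn.apply(x)    # normal eager/autograd path

eager_out = double(torch.ones(3))
traced = torch.jit.trace(double, torch.ones(3))
assert torch.equal(traced(torch.ones(3)), eager_out)
```
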
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker