
Deberta can now be exported to TorchScript #27734

Open
Szustarol wants to merge 4 commits into huggingface:main from Szustarol:fix-export-deberta-to-torchscript

Conversation

@Szustarol (Contributor) commented Nov 27, 2023

What does this PR do?

Fixes #20815

Generally, torch.autograd.Functions cannot be traced by PyTorch, as per this open issue: pytorch/pytorch#32822
This is therefore more of a PyTorch problem, but it can nevertheless be worked around. 🤗 Transformers' implementation is basically the same as the original https://github.com/microsoft/DeBERTa, which was traceable thanks to a dirty trick of using a tracing context:
https://github.com/microsoft/DeBERTa/blob/4d7fe0bd4fb3c7d4f4005a7cafabde9800372098/DeBERTa/utils/jit_tracing.py#L10C1-L17C6
Of course, such a solution is not applicable here, as it would conflict with the existing API and usage of 🤗 Transformers. I decided to explore recent developments in PyTorch, and it seems that is_tracing is now publicly accessible through torch.jit (though it is not yet documented), which gets rid of the context problem. So I have basically implemented the original solution, but with the newly available is_tracing call.
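The is_tracing dispatch described above can be sketched roughly like this (a simplified illustration, not the PR's exact code; `XSoftmaxFn` and `masked_softmax` are made-up names standing in for DeBERTa's masked softmax):

```python
import torch


class XSoftmaxFn(torch.autograd.Function):
    """Illustrative stand-in for DeBERTa's masked-softmax autograd.Function."""

    @staticmethod
    def forward(ctx, input, mask, dim: int):
        rmask = ~mask.to(torch.bool)
        output = input.masked_fill(rmask, torch.finfo(input.dtype).min)
        output = torch.softmax(output, dim)
        output = output.masked_fill(rmask, 0.0)
        ctx.save_for_backward(output)
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        grad = output * (grad_output - (output * grad_output).sum(ctx.dim, keepdim=True))
        return grad, None, None


def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    if torch.jit.is_tracing():
        # Plain-tensor path taken while torch.jit.trace is recording, so the
        # trace never sees the untraceable autograd.Function call.
        rmask = ~mask.to(torch.bool)
        output = input.masked_fill(rmask, torch.finfo(input.dtype).min)
        output = torch.softmax(output, dim)
        return output.masked_fill(rmask, 0.0)
    # Eager/training path keeps the memory-efficient custom backward.
    return XSoftmaxFn.apply(input, mask, dim)
```

Both branches compute the same values; the tracing branch simply avoids the `autograd.Function.apply` call that the tracer cannot record.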

I have also added tests that check whether the traced model outputs the same tensors as the model being traced.
This was not mentioned in the issue, but I have applied the same changes to DeBERTa-v2, since it is obviously also affected.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker

@demq commented Nov 27, 2023

Great work @Szustarol ! I think this solution still pins the device for the traced model. One solution is to create an ONNX model instead. Perhaps you might have a better solution?

@Szustarol (Contributor, Author)

Thanks for the response @demq. I would be more than happy to expand on my solution, but I am not quite sure I understand your suggestion.
I wanted to fix the issue of the DeBERTa model being untraceable with the PyTorch torch.jit.trace API; do you think I am missing something here? Also, I believe that if the device is pinned, it is an effect of torch.jit.trace usage, but the only place where this happens is the testing code, where it should not be a problem, since it is run in a single test setup and not actually saved anywhere. Unless you mean a complete reimplementation of the DeBERTa model to avoid calls that lead to a pinned device (if this actually happens, I can check for it tomorrow)?
I might have misunderstood something since I'm just learning the ropes of 🤗 Transformers, so I apologize in advance!

@demq commented Nov 28, 2023

> Thanks for the response @demq, I would be more than happy to expand on my solution, however I am not quite sure if I get your suggestion right. […] I might have misunderstood something since I'm just learning the ropes of 🤗 Transformers, I am sorry in advance!

Yes, your solution here correctly addresses the open issue. You are absolutely right that the device pinning is caused by jit.trace(), and it is a separate issue from the one you have addressed here. Previously, I experienced device pinning on a traced (using the decorator trick on XSoftmax) fine-tuned DeBERTa model, so we had to export to ONNX instead of TorchScript to get around it.

My comment was inspired by seeing how quick you were to submit a solution to this issue :) I would imagine one way to ensure no device pinning is to export the model through jit.script() instead of jit.trace().
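The trace/script distinction @demq alludes to can be seen on a toy example: jit.trace records the single execution path taken by the example input (so data-dependent branches, shapes, and devices get frozen in), while jit.script compiles the source itself. A minimal sketch (`flip_if_negative` is a made-up function, not from the PR):

```python
import torch


def flip_if_negative(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: trace freezes whichever path the example
    # input takes; script compiles both paths.
    if x.sum() > 0:
        return x
    return -x


# Traced with a positive example, so the "return x" branch is baked in
# (torch.jit.trace also emits a TracerWarning about the tensor-to-bool
# conversion for exactly this reason).
traced = torch.jit.trace(flip_if_negative, torch.ones(3))
scripted = torch.jit.script(flip_if_negative)
```

Calling both with a negative input shows the difference: the scripted version flips the sign, while the traced one replays the recorded identity path.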

@Szustarol (Contributor, Author)

Okay, I think I understand your point now, thanks! I will have a look and report back if I can find a solution. One way to fix this is to check which parts of the code cause device pinning and try to get rid of them, which is what I will try. Should we open an issue for this?

@Szustarol force-pushed the fix-export-deberta-to-torchscript branch from 844a1da to 7657821 on November 28, 2023 21:48
@Szustarol (Contributor, Author) commented Nov 28, 2023

Okay, I have done some extensive research and testing on the subject mentioned by @demq, and it appears that during tracing all devices in the forward calls are pinned, so no .to(device=...) or tensor(..., device=some_tensor.device) lines should be present in the traced code.
Since there is no review yet, I decided to solve it in the same PR.
Luckily, traced and scripted code can be intertwined, and the solution is to move all device-dependent tensor creation into a separate scripted callable, which is exactly what I did.
Sadly, I cannot provide a test for this case, since I can never be sure what devices are available on a test bench, but if someone wants to try it out, it can be done with this snippet, which traces the model on the GPU but executes it on the CPU:

import torch
import io
from transformers import AutoTokenizer, AutoModel


# tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge")
# model = AutoModel.from_pretrained("microsoft/deberta-v2-xlarge", torchscript=True).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = AutoModel.from_pretrained("microsoft/deberta-base", torchscript=True).to("cuda")

tokenized_dict = tokenizer(
    ["Is this working",], ["Not yet",], 
    return_tensors="pt"
).to('cuda')
input_tuple = (tokenized_dict['input_ids'], tokenized_dict['attention_mask'])


traced_model = torch.jit.trace(model, input_tuple)

model_bytes = io.BytesIO()
torch.jit.save(traced_model, model_bytes)
model_bytes.seek(0)

print("######### tracing and saving done")

loaded_model = torch.jit.load(model_bytes, map_location='cpu')
tokenized_dict = tokenized_dict.to('cpu')
input_tuple = (tokenized_dict['input_ids'], tokenized_dict['attention_mask'])
print(loaded_model(*input_tuple)[0].device) # outputs cpu

I would love to hear some feedback on this code; with the device problem extending the original task, the model modifications have now become quite extensive.
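The scripted-callable trick can also be illustrated in isolation (`token_type_zeros` is a hypothetical helper name, not the PR's actual function): a tensor created inside a jit.script function takes its device from the runtime argument, whereas the same line recorded by jit.trace would pin whichever device was in use during tracing.

```python
import torch


@torch.jit.script
def token_type_zeros(input_ids: torch.Tensor) -> torch.Tensor:
    # Scripted: the device expression is evaluated at run time, so the
    # zeros land on whatever device input_ids lives on, not on the
    # device that happened to be used while exporting.
    return torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
```

Because scripted functions can be called from traced code, a traced model that routes its device-dependent tensor creation through helpers like this one no longer carries hard-coded device constants.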

@demq commented Nov 29, 2023

Thanks for the great update. I can confirm that tracing now works without device pinning in my local Linux environment; I used a slightly modified test script:

import torch
import io
from transformers import AutoTokenizer, AutoModel

# tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge")
# model = AutoModel.from_pretrained("microsoft/deberta-v2-xlarge", torchscript=True).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = AutoModel.from_pretrained("microsoft/deberta-base", torchscript=True).to("cuda")

encodings_cuda = tokenizer(
    ["The DeBerta tracing works without device pinning!"],
    return_token_type_ids=False,
    return_tensors="pt"
).to("cuda")

traced_model = torch.jit.trace(model, list(encodings_cuda.values()))
print(f"{traced_model(*encodings_cuda.values())[0].device=}")

model_bytes = io.BytesIO()
torch.jit.save(traced_model, model_bytes)
model_bytes.seek(0)

print("######### tracing and saving done")

loaded_model = torch.jit.load(model_bytes, map_location='cpu')
encodings_cpu = encodings_cuda.copy().to('cpu')
print(f"{loaded_model(*encodings_cpu.values())[0].device=}") # outputs cpu

It works fine, but I get the following warnings during the jit.trace():

>>> traced_model = torch.jit.trace(model, list(encodings_cuda.values()))
./transformers/src/transformers/models/deberta/modeling_deberta.py:694: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
./transformers/src/transformers/models/deberta/modeling_deberta.py:694: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
./transformers/src/transformers/models/deberta/modeling_deberta.py:733: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions)
./transformers/src/transformers/models/deberta/modeling_deberta.py:754: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
./transformers/src/transformers/models/deberta/modeling_deberta.py:754: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
./transformers/src/transformers/models/deberta/modeling_deberta.py:755: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if query_layer.size(-2) != key_layer.size(-2):
./transformers/src/transformers/models/deberta/modeling_deberta.py:765: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if query_layer.size(-2) != key_layer.size(-2):
./transformers/src/transformers/models/deberta/modeling_deberta.py:140: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))

@Szustarol (Contributor, Author)

Yes, of course you are right. Seeing those warnings for the first time, I erroneously assumed that the tensor shape would be constant for a given config, which might be true for size(-1) but is certainly not true for size(-2). Those parts will have to be exported as a script as well. I will take care of it as soon as possible.
Thank you for bringing this to my attention!

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator)

Hey! Do you need a review on this?

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Szustarol force-pushed the fix-export-deberta-to-torchscript branch from 7657821 to 16858cb on January 27, 2024 09:16
@Szustarol (Contributor, Author) commented Jan 27, 2024

The reason it took so long is that, while trying to fix the constant-size pinning issue mentioned above, I noticed that the HF tests use an HFProxy object during tracing, which is not compatible with torch.jit.script; at the same time, script is required to get rid of the constant sizes being pinned during model tracing. As I have found no way to reliably resolve this, I am asking here for counsel.
Please see the test run below for the test I am referring to.

@demq commented Jan 30, 2024

> The reason it took so long is that while trying to fix the issue mentioned above, with constant size pinning, I have noticed that the HF tests use a HPFProxy object during tracing, which is not compatible with torch.jit.script […] Please see the test run below to see what test I am referring to.

The latest version seems to work fine for me, @Szustarol; thank you very much for your efforts! I can trace a model on a CPU without any warnings and run inference on GPUs; there is no apparent device pinning and no warnings about constant size pinning.

Perhaps it is worth asking @ArthurZucker to review the PR and merge it into main?

@ArthurZucker (Collaborator)

Sure let me review it!

@ArthurZucker (Collaborator) left a review comment


As I'm not a TorchScript expert, would you mind educating me a bit?
Could you walk me through the motivations behind some of the changes?
(LGTM overall!)

return self.config.hidden_size


# copied from transformers.models.deberta.modeling_deberta._traceable

Suggested change
# copied from transformers.models.deberta.modeling_deberta._traceable
# Copied from transformers.models.deberta.modeling_deberta._traceable


All other copied from need this capital 😉

Comment on lines +133 to +137
if tensor.dtype in [torch.float16, torch.float32, torch.float64]:
# Will not be baked in during tracing
return _get_float_min_value(tensor)
# Will be baked in during tracing
return torch.tensor(torch.finfo(tensor.dtype).min)

why does tracing have an issue with this? Known bug or expected behaviour?


Return:
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
`torch.LongTensor`: A tensor with shape [1, query_layer.size, key-layer.size]

Suggested change
`torch.LongTensor`: A tensor with shape [1, query_layer.size, key-layer.size]
`torch.LongTensor`: A tensor with shape [1, query_length, key_length]

q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
rel_pos_ids = rel_pos_ids[:query_size, :]

Suggested change
rel_pos_ids = rel_pos_ids[:query_size, :]
rel_pos_ids = rel_pos_ids[q_ids, :]

Pretty sure this is faster when compiling. I don't know exactly for scripting, but it's slow for CPU inference. Just noting it here!

Comment on lines +614 to +616
@torch.jit.script
def scaled_size_sqrt(query_layer, scale_factor: int):
return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)

Why does this have to be defined as such? (FMI) Known TorchScript issue, or the expected way to do it?


if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
token_type_ids = get_token_type_ids_on_device(device_discriminator, using_ids)

If we do this after computing the word embeddings, we can always use the same device discriminator, no?



@torch.jit.script
# copied from transformers.models.deberta.modeling_deberta._get_float_min_value

Suggested change
# copied from transformers.models.deberta.modeling_deberta._get_float_min_value
# Copied from transformers.models.deberta.modeling_deberta._get_float_min_value


def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(

We have a test_torchscript = False attribute at line 243 of the test file that should be set to True!

@demq commented Mar 12, 2024

@Szustarol and @ArthurZucker, would it be possible to resolve the few review questions/improvement suggestions and get this PR merged in the near future?

This is a very valuable contribution to the library, eagerly anticipated by many who want to use fine-tuned DeBERTa outside of the pure-Python ecosystem.

@ArthurZucker (Collaborator)

Down for this! 🔥 If I don't get an answer, I'll just dive in a bit! 🤗 Sorry all for the many delays.

@YongWookHa
I tested the PR and it works for jit.script.
But it's still not convertible to ONNX with a dynamic batch size (no problem with static, fixed-size batching).

@Muks14x commented Mar 27, 2024

+1 Also waiting on this, it would help my project a ton if we can merge. Thank you for working on this!

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@dustinaxman

Sorry to reopen this! Please feel free to close it if no one else wants it, but this would help my project a lot, and it seems just about done. Can it be merged, or is there more work left? Thank you!

@ArthurZucker ArthurZucker reopened this Apr 30, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jun 2, 2024
@ArthurZucker ArthurZucker reopened this Jun 5, 2024
@github-actions github-actions bot closed this Jun 13, 2024
@ArthurZucker ArthurZucker reopened this Sep 27, 2024
@ArthurZucker (Collaborator) left a comment

BTW, I incorporated most of these changes in #22105 and will add you as a co-author! Gonna close this once merged.



Development

Successfully merging this pull request may close these issues.

Cannot export Deberta to TorchScript

7 participants