-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[jit] allow dict to be mixed between tracing and scripting #39601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 6065f8a (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 3 failures confirmed as flaky and can be ignored:
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 23 times. |
[ghstack-poisoned]
[ghstack-poisoned]
|
Thanks for working on this! Are the test failures related? Also is it possible to support Dict/List recursively, e.g. Dict[str, List[Tensor]]? (sorry for asking too much :( ) |
The test failures seems un-related, let me try to run them again. For recursive support, I think the current solution should work, if you look at the test case I added, it is |
|
Thanks a lot @wanchaol! This will unblock PyPer Feed preproc model tracing, and we urgently need this feature now. cc. @alyssawangqq |
jamesr66a
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, looks fine. But I'd be interested to know:
- Is an embedded ScriptObject call the only place we can see this used? I don't recall any ops that take dict, for example, but it would be nice to double check
- What failure modes might look like, e.g. tracing a program that passes a dynamic dict to this thing, but the trace encode static values
|
With this patch applied on top of current master, I notice that if I traced the module with import torch
from typing import Dict
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
def forward(self, dict_input):
@torch.jit.script
def script_function(dict_input: Dict[str, torch.Tensor]):
return dict_input['x']
return script_function(dict_input)
input_1 = {'x': torch.tensor(1)}
input_2 = {'x': torch.tensor(2), 'y': torch.tensor(3)}
m = TestModule()
m_traced = torch.jit.trace(m, (input_1, ))
print(m_traced(input_1))
print(m_traced(input_2))Output: |
I think it's true that no other real ops is taking dict. this only applies to mix tracing and scripting, so another possibility is to trace a function that embeds a script function (which takes dict as input).
This seems the failure that @yf225 just posted, that we actually encode the static dict inputs here pytorch/torch/csrc/jit/frontend/tracer.cpp Line 331 in caaf49d
|
|
So as this is a separate issue from what we try to do in this PR, record this in #40529 and move the discussion there. |
|
Not supporting dynamic dict should be ok for my use case, as I expect the number of features will remain the same when we trace and actually using it :) |
A combination of pytorch#39601 and pytorch#40424 both are approved and merged in master
# Support unpacking python dictionary in **torch.jit.trace()**
## Problem statement & Motivation
### Problem 1(usability):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=value1, key2=value2, key3=value3)`**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3, key2:value2}`**
The problem is that if you want to trace the model using the dict data by the giving dataset, you need unpack the dictionary and reorder its value manually and make up a tuple as **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.
### Problem 2 (feasibility):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3}`** -> Only **part of** the required value by forward was given, the rest use the default value.
The problem is that if you want to trace the model using the dict data by the giving dataset, it's not feasible at all. Cause neither you can pass a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2 and T2 include **None** type which will be blocked by tracer's type checking. (Of course you can pass **`T3 = (value1,)`** to make the trace function finish without exception, but the traced model you get probably is not what you expect cause the different input may result in different traced result.).
These problems come from the HuggingFace's PT model, especially in text-classification tasks with datasets such as [MRPC,](https://paperswithcode.com/dataset/mrpc) [MNLI](https://paperswithcode.com/dataset/multinli) etc.
## Solution
To address these two issues, we propose to support a new type, that is, python dict as example_inputs parameter for torch.jit.trace(). We can base on the runtime type information of the example_inputs object to determine if we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and problem 2 can be solved by utilizing the "**`**`**"
operator.
## Limitation & Mitigation
1. If we use dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Cause probably we will change the order of debug name of the input parameter in torchscript IR, thus we can't assume the traced model's input parameters order are the same with the original model.). We need highlight this too in the document to mitigate this problem.
For example:
```
# fetch a data from dataloader, and the data is a dictionary
# and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2}
# the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# use the dictionary to trace the model
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False) # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)
# It's OK to use dict as the parameter for traced model
jit_model(**example_inputs_dict)
example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order.
jit_model(*example_inputs_tuple)
```
## Note
1. This PR will make some UT introduced in [39601](#39601) fail, which I think should be classified as unpacking a tuple containing a single dictionary element in our solution.
4. I think there is ambiguity since currently we only specify passing a tuple or a single Tensor as our example_inputs parameter in **torch.jit.trace()**'s documentation, but it seems we can still passing a dictionary.
Pull Request resolved: #81623
Approved by: https://github.com/davidberard98
…) (pytorch#99) # Support unpacking python dictionary in **torch.jit.trace()** ## Problem statement & Motivation ### Problem 1(usability): Say, if you have a model and its forward method defined as follows: **`def forward(self, key1=value1, key2=value2, key3=value3)`** And you have a dataset and each data point in the dataset is a python dict as follows: **`data = {key1:value1, key3:value3, key2:value2}`** The problem is that if you want to trace the model using the dict data by the giving dataset, you need unpack the dictionary and reorder its value manually and make up a tuple as **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly. ### Problem 2 (feasibility): Say, if you have a model and its forward method defined as follows: **`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None** And you have a dataset and each data point in the dataset is a python dict as follows: **`data = {key1:value1, key3:value3}`** -> Only **part of** the required value by forward was given, the rest use the default value. The problem is that if you want to trace the model using the dict data by the giving dataset, it's not feasible at all. Cause neither you can pass a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2 and T2 include **None** type which will be blocked by tracer's type checking. (Of course you can pass **`T3 = (value1,)`** to make the trace function finish without exception, but the traced model you get probably is not what you expect cause the different input may result in different traced result.). These problems come from the HuggingFace's PT model, especially in text-classification tasks with datasets such as [MRPC,](https://paperswithcode.com/dataset/mrpc) [MNLI](https://paperswithcode.com/dataset/multinli) etc. ## Solution To address these two issues, we propose to support a new type, that is, python dict as example_inputs parameter for torch.jit.trace(). We can base on the runtime type information of the example_inputs object to determine if we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and problem 2 can be solved by utilizing the "**`**`**" operator. ## Limitation & Mitigation 1. If we use dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Cause probably we will change the order of debug name of the input parameter in torchscript IR, thus we can't assume the traced model's input parameters order are the same with the original model.). We need highlight this too in the document to mitigate this problem. For example: ``` # fetch a data from dataloader, and the data is a dictionary # and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2} # the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3) example_inputs_dict = next(iter(dataloader)) jit_model = model.eval() # use the dictionary to trace the model jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False) # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2) jit_model = torch.jit.freeze(jit_model) # It's OK to use dict as the parameter for traced model jit_model(**example_inputs_dict) example_inputs_tuple = (value1, value3, value2) # It's wrong to rely on the original args order. jit_model(*example_inputs_tuple) ``` ## Note 1. This PR will make some UT introduced in [39601](pytorch#39601) fail, which I think should be classified as unpacking a tuple containing a single dictionary element in our solution. 4. I think there is ambiguity since currently we only specify passing a tuple or a single Tensor as our example_inputs parameter in **torch.jit.trace()**'s documentation, but it seems we can still passing a dictionary. Pull Request resolved: pytorch#81623 Approved by: https://github.com/davidberard98 Co-authored-by: tangleintel <[email protected]>
Stack from ghstack:
Differential Revision: D22202689