Skip to content

Conversation

@wanchaol
Copy link
Collaborator

@wanchaol wanchaol commented Jun 5, 2020

Stack from ghstack:

Differential Revision: D22202689

@wanchaol wanchaol requested a review from apaszke as a code owner June 5, 2020 21:00
wanchaol added a commit that referenced this pull request Jun 5, 2020
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jun 5, 2020
@dr-ci
Copy link

dr-ci bot commented Jun 5, 2020

💊 CI failures summary and remediations

As of commit 6065f8a (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


3 failures confirmed as flaky and can be ignored:

  • pytorch_xla_linux_bionic_py3_6_clang9_build
  • pytorch_macos_10_13_py3_test
  • pytorch_linux_xenial_py3_6_gcc5_4_build

This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 23 times.

wanchaol added a commit that referenced this pull request Jun 9, 2020
wanchaol added a commit that referenced this pull request Jun 23, 2020
@wanchaol wanchaol requested review from houseroad and jamesr66a June 23, 2020 17:04
@xw285cornell
Copy link
Contributor

Thanks for working on this! Are the test failures related?

Also is it possible to support Dict/List recursively, e.g. Dict[str, List[Tensor]]? (sorry for asking too much :( )

@wanchaol
Copy link
Collaborator Author

Thanks for working on this! Are the test failures related?

Also is it possible to support Dict/List recursively, e.g. Dict[str, List[Tensor]]? (sorry for asking too much :( )

The test failures seems un-related, let me try to run them again.

For recursive support, I think the current solution should work, if you look at the test case I added, it is Dict[str, List[Tensor]

@yf225
Copy link
Contributor

yf225 commented Jun 24, 2020

Thanks a lot @wanchaol! This will unblock PyPer Feed preproc model tracing, and we urgently need this feature now. cc. @alyssawangqq

Copy link
Collaborator

@jamesr66a jamesr66a left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, looks fine. But I'd be interested to know:

  1. Is an embedded ScriptObject call the only place we can see this used? I don't recall any ops that take dict, for example, but it would be nice to double check
  2. What failure modes might look like, e.g. tracing a program that passes a dynamic dict to this thing, but the trace encode static values

@yf225
Copy link
Contributor

yf225 commented Jun 24, 2020

With this patch applied on top of current master, I notice that if I traced the module with {'x': torch.tensor(1)} (i.e. dict with 1 key), I won't be able to run the traced module with {'x': torch.tensor(1), 'y': torch.tensor(2)} (i.e. dict with 2 keys). Curious is this expected? Thanks!

import torch

from typing import Dict

class TestModule(torch.nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()

    def forward(self, dict_input):
        @torch.jit.script
        def script_function(dict_input: Dict[str, torch.Tensor]):
            return dict_input['x']
        return script_function(dict_input)

input_1 = {'x': torch.tensor(1)}
input_2 = {'x': torch.tensor(2), 'y': torch.tensor(3)}

m = TestModule()
m_traced = torch.jit.trace(m, (input_1, ))
print(m_traced(input_1))
print(m_traced(input_2))

Output:

tensor(1)
Traceback (most recent call last):
  File "test_yf225.py", line 21, in <module>
    print(m_traced(input_2))
  File "/data/miniconda3/envs/working3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Expected 1 elements in a list but found 2

@wanchaol
Copy link
Collaborator Author

@jamesr66a

  1. Is an embedded ScriptObject call the only place we can see this used? I don't recall any ops that take dict, for example, but it would be nice to double check

I think it's true that no other real ops is taking dict. this only applies to mix tracing and scripting, so another possibility is to trace a function that embeds a script function (which takes dict as input).

  1. What failure modes might look like, e.g. tracing a program that passes a dynamic dict to this thing, but the trace encode static values

This seems the failure that @yf225 just posted, that we actually encode the static dict inputs here

state->graph->createListUnpack(unpack_to_list, dict_size);
when add the tracer inputs. IIRC this is because we need to connect the dict value tensors to the symbolic values generated, so that we could use it later on. This is something weird behavior that we impose to the user though

@wanchaol
Copy link
Collaborator Author

So as this is a separate issue from what we try to do in this PR, record this in #40529 and move the discussion there.

@xw285cornell
Copy link
Contributor

Not supporting dynamic dict should be ok for my use case, as I expect the number of features will remain the same when we trace and actually using it :)

@facebook-github-bot
Copy link
Contributor

@wanchaol merged this pull request in 3e09268.

@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/109/head branch June 28, 2020 14:17
wanchaol added a commit to wanchaol/pytorch that referenced this pull request Jun 30, 2020
A combination of pytorch#39601 and
pytorch#40424 both are approved and
merged in master
malfet pushed a commit that referenced this pull request Jul 1, 2020
A combination of #39601 and
#40424 both are approved and
merged in master
pytorchmergebot pushed a commit that referenced this pull request Oct 15, 2022
# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & Motivation
### Problem 1(usability):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=value1, key2=value2, key3=value3)`**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3, key2:value2}`**

The problem is that if you want to trace the model using the dict data by the giving dataset, you need unpack the dictionary and reorder its value manually and make up a tuple as **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.

### Problem 2 (feasibility):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3}`** -> Only **part of** the required value by forward was given, the rest use the default value.

The problem is that if you want to trace the model using the dict data by the giving dataset, it's not feasible at all. Cause neither you can pass a tuple like **`T1 = (value1, value3)`**  nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2 and T2 include **None** type which will be blocked by tracer's type checking. (Of course you can pass **`T3 = (value1,)`**  to make the trace function finish without exception, but the traced model you get probably is not what you expect cause the different input may result in different traced result.).

These problems come from the HuggingFace's PT model, especially in text-classification tasks with datasets such as [MRPC,](https://paperswithcode.com/dataset/mrpc)  [MNLI](https://paperswithcode.com/dataset/multinli) etc.

## Solution
To address these two issues, we propose to support a new type, that is, python dict as example_inputs parameter for torch.jit.trace(). We can base on the runtime type information of the example_inputs object to determine if we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and  problem 2 can be solved by utilizing the "**`**`**"
operator.

## Limitation & Mitigation

1. If we use dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Cause probably we will change the order of debug name of the input parameter in torchscript IR, thus we can't assume the traced model's input parameters order are the same with the original model.). We need highlight this too in the document to mitigate this problem.

    For example:
```
# fetch a data from dataloader, and the data is a dictionary
# and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2}
# the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# use the dictionary to trace the model
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)  # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)

# It's OK to use dict as the parameter for traced model
jit_model(**example_inputs_dict)

example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order.
jit_model(*example_inputs_tuple)

```
## Note
1. This PR will make some UT introduced in [39601](#39601) fail, which I think should be classified as unpacking a tuple containing a single dictionary element in our solution.
4. I think there is ambiguity since currently we only specify passing a tuple or a single Tensor as our example_inputs parameter in **torch.jit.trace()**'s documentation, but it seems we can still passing a dictionary.

Pull Request resolved: #81623
Approved by: https://github.com/davidberard98
neggles pushed a commit to neggles/pytorch that referenced this pull request Mar 9, 2023
…) (pytorch#99)

# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & Motivation
### Problem 1(usability):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=value1, key2=value2, key3=value3)`**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3, key2:value2}`**

The problem is that if you want to trace the model using the dict data by the giving dataset, you need unpack the dictionary and reorder its value manually and make up a tuple as **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.

### Problem 2 (feasibility):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3}`** -> Only **part of** the required value by forward was given, the rest use the default value.

The problem is that if you want to trace the model using the dict data by the giving dataset, it's not feasible at all. Cause neither you can pass a tuple like **`T1 = (value1, value3)`**  nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2 and T2 include **None** type which will be blocked by tracer's type checking. (Of course you can pass **`T3 = (value1,)`**  to make the trace function finish without exception, but the traced model you get probably is not what you expect cause the different input may result in different traced result.).

These problems come from the HuggingFace's PT model, especially in text-classification tasks with datasets such as [MRPC,](https://paperswithcode.com/dataset/mrpc)  [MNLI](https://paperswithcode.com/dataset/multinli) etc.

## Solution
To address these two issues, we propose to support a new type, that is, python dict as example_inputs parameter for torch.jit.trace(). We can base on the runtime type information of the example_inputs object to determine if we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and  problem 2 can be solved by utilizing the "**`**`**"
operator.

## Limitation & Mitigation

1. If we use dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Cause probably we will change the order of debug name of the input parameter in torchscript IR, thus we can't assume the traced model's input parameters order are the same with the original model.). We need highlight this too in the document to mitigate this problem.

    For example:
```
# fetch a data from dataloader, and the data is a dictionary
# and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2}
# the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# use the dictionary to trace the model
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)  # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)

# It's OK to use dict as the parameter for traced model
jit_model(**example_inputs_dict)

example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order.
jit_model(*example_inputs_tuple)

```
## Note
1. This PR will make some UT introduced in [39601](pytorch#39601) fail, which I think should be classified as unpacking a tuple containing a single dictionary element in our solution.
4. I think there is ambiguity since currently we only specify passing a tuple or a single Tensor as our example_inputs parameter in **torch.jit.trace()**'s documentation, but it seems we can still passing a dictionary.

Pull Request resolved: pytorch#81623
Approved by: https://github.com/davidberard98

Co-authored-by: tangleintel <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants