Skip to content

Conversation

@Krovatkin
Copy link
Contributor

@Krovatkin Krovatkin commented Apr 29, 2019

This PR adds a new trace API trace_module that will allow us to trace multiple methods as a part of a single ScriptModule

See the example below.

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

            def weighted_kernel_sum(self, weight):
                return weight * self.conv.weight

        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)
        n = Net()
        inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
        module = torch.jit.trace_module(n, inputs)

@Krovatkin Krovatkin requested a review from zdevito April 29, 2019 05:26
@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 29, 2019
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the implementation is correct, but the API is weird at the moment. Not sure users will be able to find trace_dict and it doesn't feel simple. I also think the way def trace is organized can be improved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not a fan of this API. I think the problem is that you can't distinguish a dict example input from a dict with string_key->example input. Instead, we can have users pass in the method rather than the field name:

        module = torch.jit.trace(n, {n.forward: example_forward_input, n.weight_sum_kernel: example_weight})

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is going on here? The method_inputs.items() loop is only sensible if the thing is a module. I feel like this could be refactored to be a lot more readable.

@Krovatkin Krovatkin requested a review from zdevito April 30, 2019 05:50
Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't like the dict-overloaded API too much. What's the context for this change? It's usually helpful to include some details in the PR description.

@Krovatkin
Copy link
Contributor Author

@apaszke added the description and a simple example.

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Ready to go once check_inputs has a test case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this line is covered by the tests. check_inputs would need to be different for each method but it looks like we are using the same inputs here.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Krovatkin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@Krovatkin merged this pull request in 7ddd5d0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants