Conversation

@lara-hdr
Contributor

@lara-hdr lara-hdr commented Jun 5, 2019

No description provided.

@pytorchbot pytorchbot added module: nn Related to torch.nn module: onnx Related to torch.onnx labels Jun 5, 2019
@lara-hdr
Contributor Author

lara-hdr commented Jun 5, 2019

waiting on #20533 to be merged (tests will fail until merged with the referenced PR)

@ezyang ezyang requested a review from houseroad June 6, 2019 16:46
@ezyang ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 6, 2019
return symbolic_fn


upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
Member

Fusing the upsamples together is great!

Maybe we should have some convention for the prefix of the helper functions, such as _interpolate and slice_op; they are not real symbolics, just helper functions.

Contributor Author

I renamed slice to follow the same convention as _interpolate and other ops (_max_pool, _avg_pool).
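As a rough sketch of the naming convention being discussed (underscore-prefixed helpers that manufacture the real symbolic functions), something like this simplified, hypothetical factory:

```python
# Hypothetical, simplified sketch of the _interpolate helper-factory pattern;
# the real implementation in torch/onnx emits a proper Upsample/Resize node.
def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        # A real symbolic would compute scales from output_size and emit
        # the ONNX op via g.op(...); this stub only records the intent.
        return g.op("Resize", input, mode_s=interpolate_mode)
    symbolic_fn.__name__ = name
    return symbolic_fn

# The factory is called once per variant, as in the diff above:
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
```

The leading underscore then signals "helper, not a symbolic itself", which is the convention suggested above.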

Member

@houseroad houseroad left a comment

Rebase onto master?

-        return [(torch.floor(input.size(i + 2) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
+        return [(torch.floor(float(input.size(i + 2)) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
     else:
         return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
Contributor Author

This cast was making the input of the Floor operator an integer, which is not supported (Floor only accepts float and double input types), so the model fails in onnxruntime.

Member

I am a bit confused here: it seems the tracer won't take the else branch, so why does it affect the ONNX export? Did you test the script module? But it's an equivalent change, so I am okay with it.

upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")

Contributor Author

Removed support for linear mode in both opsets 9 and 10.
Linear mode behaves differently in ONNX and PyTorch (with align_corners both enabled and disabled), so the exported model did not match onnxruntime's numbers
(the difference is explained here: microsoft/onnxruntime#1179).

Will try to add support for linear mode in a separate PR.
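For context, my reading of the linked issue is that the mismatch comes from the coordinate transformation: PyTorch's linear interpolation with align_corners=False uses a half-pixel mapping from output to input coordinates, while ONNX Resize in opset 10 specifies an asymmetric one. A pure-Python sketch of the two mappings (function names are illustrative):

```python
# Illustrative only: map an output pixel index back to an input coordinate.
def half_pixel(x_out, scale):
    # PyTorch linear mode, align_corners=False
    return (x_out + 0.5) / scale - 0.5

def asymmetric(x_out, scale):
    # ONNX Resize as specified in opset 10
    return x_out / scale

# For a 2x upscale the two mappings already disagree at pixel 0
# (-0.25 vs 0.0), so interpolated values differ between backends.
```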

@lara-hdr
Contributor Author

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@lara-hdr
Contributor Author

@houseroad any feedback on this?

@houseroad
Member

I think the failing tests are related. Could you take a look?

Contributor

@facebook-github-bot facebook-github-bot left a comment

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Member

@houseroad houseroad left a comment

Looks good. Please address the questions before merge.


     # make scale_factor a tensor in tracing so constant doesn't get baked in
     if torch._C._get_tracing_state():
-        return [(torch.floor(input.size(i + 2) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
+        return [(torch.floor((input.size(i + 2) * torch.tensor(float(scale_factors[i]))).float())) for i in range(dim)]
Member

It seems the old code should generate a float tensor as well; is the .float() conversion necessary?

Contributor Author

Without this, the generated ONNX graph fails in ORT, which complains that the input of the "Floor" node is an Int.
I can submit an example of the ONNX file and the error if required.

For the following line, I made the same change for consistency and as a precaution.

Member

Yeah, I am wondering what the ONNX model looks like. Could you attach the exported model?

@lara-hdr
Contributor Author

lara-hdr commented Jun 19, 2019

Resize.zip

@houseroad I attached the model generated without the changes in functional.py, for the code below.
I also added a screenshot of the error I mentioned above.
(For opset 9 I also get a similar model (with an Upsample node instead of Resize) and a similar error.)

def Test_Upsample_op():
    class Test_Upsample(nn.Module):
        def forward(self, input):
            return nn.functional.interpolate(input, mode="nearest", scale_factor=2)

    model = Test_Upsample()
    input = torch.tensor([[[[1., 2], [3, 4]]]])
    output = model(input)
    onnx_helper.Save('upsample', 'upsample', model, input, output, 10)

@facebook-github-bot
Contributor

@houseroad merged this pull request in 34aee93.

Labels

Merged · module: nn (Related to torch.nn) · module: onnx (Related to torch.onnx) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


6 participants