Skip to content

Conversation

@gpetters94
Copy link
Collaborator

This adds a few small changes that are needed for OPT support, namely:

  • Support for aten::view when the output shape is statically known

  • Folding away torch::type_as when both arguments are the same type

  • Support for torch::masked_fill when the mask is a float type

@ramiro050
Copy link
Collaborator

Can you split the changes into 3 PRs? They are all quite independent from one another. It would also make the commit titles a lot more descriptive, since each PR could get as a title the description you have in the bullet points above.

@gpetters94 gpetters94 changed the title Add cases for view, type_as, and masked_fill Add support for float mask to aten::masked_fill Sep 9, 2022
@gpetters94
Copy link
Collaborator Author

I've split the PRs, this one is for masked_fill now.

@gpetters94 gpetters94 requested a review from ramiro050 September 9, 2022 17:40
Copy link
Collaborator

@ramiro050 ramiro050 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just have a small change request, but other than that it LGTM

@gpetters94 gpetters94 force-pushed the masked_fill branch 3 times, most recently from 003853f to f73493e Compare September 14, 2022 16:15
Value input = payloadArgs[0];
Value mask = payloadArgs[1];
if (mask.getType().isa<mlir::FloatType>())
mask = b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I missed this the first time I reviewed your changes. Why is mask being set to false here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems to be expected behavior. I was casting to Int1 at first, but further testing shows that it seems to treat all floats as false. I haven't found anything in the documentation about it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a bug upstream. I would actually expect float mask to result in a runtime error, since this is the behavior that aten.masked_select has:

https://github.com/pytorch/pytorch/blob/5b58140d1a471b144baf66cc61a45a86746f0215/aten/src/ATen/native/TensorAdvancedIndexing.cpp#L1720-L1721

Can you file a bug upstream for this?

Copy link
Collaborator Author

@gpetters94 gpetters94 Sep 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the bug report is up here. I'm going to leave this as-is for now in case it's expected behavior but I'll add an assert if it isn't.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants