Add support for tensor targets in for-in #19380
Conversation
Fixes #19314
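For illustration (this sketch is mine, not from the PR; the function name and shapes are made up), this is the kind of code the change enables: with tensor targets supported, "for row in t" in a scripted function iterates over the outermost dimension of t, matching eager-mode Python semantics.

    import torch

    @torch.jit.script
    def row_sums(t):
        # type: (Tensor) -> Tensor
        # Iterates over dim 0; each `row` is effectively t.select(0, i).
        total = torch.zeros([t.size(1)])
        for row in t:
            total = total + row
        return total

    print(row_sums(torch.ones(3, 4)))  # tensor([3., 3., 3., 3.])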
Force-pushed 9ce9f3f to 0c9ba75
Force-pushed 0c9ba75 to 9d42740
torch/csrc/jit/script/compiler.cpp (Outdated)

    /*required=*/true);

    auto outermost_dim_index =
        graph->insertConstant(0, nullptr, range);
Why is the result_type null? Shouldn't it be IntType::get() or something?
fixed.
Note that this doesn't work for a list of user-defined types either. Not sure how easy it is to handle after the refactor.
Force-pushed dc7a28e to 935e4eb
@eellison that's a very good point! It indeed breaks down if we pass a list of user-defined types into a compiled method. It kinda works if said list is defined inside the compiled function: the example below seems to correctly print three "a"s. Argument specialization seems to degrade.

    def test_for_in_list_of_user_types(self):
        @torch.jit.script
        class FooTest(object):
            def __init__(self, x):
                self.foo = x

            def getFooTest(self):
                print("a")

        @torch.jit.script
        def sum_list():
            t = torch.zeros(2, 3)
            my_list = [FooTest(t), FooTest(t), FooTest(t)]
            sum = 0
            for i in my_list:
                i.getFooTest()
Repro w/ an index op:

    def test_for_in_list_of_user_types_argument(self):
        @torch.jit.script
        class FooTest(object):
            def __init__(self):
                pass

            def getFooTest(self):
                print("a")

        my_list = [FooTest()]

        @torch.jit.script
        def bar(my_list):
            t = my_list[0]
            t.getFooTest()

        bar(my_list)
torch/csrc/jit/register_prim_ops.cpp (Outdated)

    };
    }),
    Operator(
        "aten::select0(int[] self) -> int",
Instead of creating this new op, could you compose it out of our existing ops?

    if self.dim() == 0:
        raise TypeError('iteration over a 0-d tensor')
    select_index = self.size(0)
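For concreteness, a hedged sketch (plain Python, illustrative only, not actual compiler output) of how the whole for-in could be composed from the existing ops named above:

    # Illustrative desugaring of `for x in t: body(x)` using existing ops.
    def desugared_for_in(t, body):
        if t.dim() == 0:                  # aten::dim
            raise TypeError('iteration over a 0-d tensor')
        trip_count = t.size(0)            # aten::size
        for i in range(trip_count):
            body(t.select(0, i))          # aten::select(Tensor self, int dim, int index)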
Hmmm, I don't know. I think this will make the IR more convoluted and less readable. Unless we have optimizations that can automatically infer these kinds of constraints, propagate them, and benefit from them, I'm not sure all this extra control flow is worth it.
We specialize on dimensions, so we do have these optimizations.
I don't think we should be optimizing for IR readability in its unoptimized state, and in the long run adding more and more ad-hoc ops will make our IR less readable. Also, if you look at autodiff.cpp / graph_fuser.cpp, we have existing optimizations for aten::size that this would prevent.
"We specialize on dimensions - so we do have these optimizations"

Could you please point me to the code that records the dimensionality of list types, or any more general constraints on it? I'd like to understand this better.
Sure, look at peephole.cpp, the line node->matches("aten::dim(Tensor self) -> int"). Then in constant propagation we remove if statements that have a false conditional, so the error checking here will be optimized away.
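To make that concrete, a hedged example (mine, not from the thread): once the compiler specializes on the input's dimensionality, t.dim() is rewritten to a constant by the peephole pass, and constant propagation then deletes the provably dead raise branch.

    import torch

    @torch.jit.script
    def guarded_first_row(t):
        # type: (Tensor) -> Tensor
        # For a specialized 2-d input, `t.dim() == 0` folds to False and the
        # whole branch below is removed by constant propagation.
        if t.dim() == 0:
            raise TypeError('iteration over a 0-d tensor')
        return t.select(0, 0)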
Looking into it, we already have an aten::select from ATen that does the right thing; its signature is aten::select(Tensor(a) self, int dim, int index), so there's no need for a new op.
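A quick eager-mode check of that existing op's behavior (illustrative):

    import torch

    t = torch.arange(6).reshape(2, 3)   # tensor([[0, 1, 2], [3, 4, 5]])
    print(t.select(0, 1))               # same as t[1]    -> tensor([3, 4, 5])
    print(t.select(1, 2))               # same as t[:, 2] -> tensor([2, 5])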
Switched to aten::dim. Thanks to @eellison for the suggestion. I didn't realize it was defined on a tensor rather than a list.
@driazati please take another look!
facebook-github-bot left a comment
@Krovatkin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@Krovatkin merged this pull request in 725ef26.
Summary: Fixes pytorch#19314
Pull Request resolved: pytorch#19380
Differential Revision: D15167858
Pulled By: Krovatkin
fbshipit-source-id: e87261bbf3e6f8df0601df80280eb3dba42798cd