Add implicit optional unwrapping #15587
Conversation
(maybe in a follow-up?) We could have the Python printer emit prim::__is__ nodes as Python `is` expressions and skip emitting these unwraps altogether, so they'd get re-inserted when the code is read in again.
torch/csrc/jit/script/compiler.cpp
Outdated
It looks like this constructor isn't used
test/test_jit.py
Outdated
Add a test for nested optionals (i.e. Optional[Optional[int]])?
Hmm. This isn't schematizable, so it would need special logic. I don't think this is an actual use case - I grep'd around for "Optional[Optional[" and didn't find anything. I think erroring and waiting to see if it ever comes up would be sufficient.
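For what it's worth, Python's own typing module already flattens nested unions, so `Optional[Optional[int]]` collapses to `Optional[int]` in the typing representation; a quick plain-Python check (independent of TorchScript):

```python
from typing import Optional, Union

# Optional[X] is Union[X, None], and nested Unions flatten and deduplicate,
# so Optional[Optional[int]] == Union[Union[int, None], None] == Union[int, None].
nested = Optional[Optional[int]]
flat = Optional[int]
print(nested == flat)  # True: the nesting disappears in the typing representation
```

This supports treating the nested case as something to error on rather than support.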
torch/csrc/jit/script/compiler.cpp
Outdated
Why is another SugaredValue necessary for this?
There needs to be special-casing of the duplicate unwrapped-optional calls, and this is the most self-contained way to achieve it. Additionally, we can't simply skip printing the _unchecked_unwrap_optional nodes, because then we lose the boolean expressions: those get desugared to if statements after export/import.
e.g.:

```python
if x is not None and len(x) < 2:
    print("hi")
```

On the second compilation, the compiler won't know to combine "x is not None and len(x) < 2", since they will no longer be in a single boolean expression, but the unwrap optional inserted by the first compilation will still exist.
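To make the desugaring concern concrete, here is a plain-Python model (not the compiler's actual lowering) of how a short-circuit `and` becomes an if-expression, which is why the two operands end up in separate blocks after export/import:

```python
# Hypothetical model of the lowering `a and b` -> `b if a else a`,
# which is the shape the compiler sees after the boolean expression
# has been desugared into an if-expression.
def short_circuit_and(a, b_thunk):
    # b is only evaluated when a is truthy, mirroring Python's `and`
    return b_thunk() if a else a

x = None
# `x is not None and len(x) < 2` survives lowering because len(x)
# is never reached when x is None:
print(short_circuit_and(x is not None, lambda: len(x) < 2))  # False
```

Once the expression is in this form, the textual round-trip sees nested conditionals rather than one boolean expression, so the refinement can no longer be re-derived from the source text.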
Force-pushed cf50154 to 111be98
```python
@torch.jit.script
def test_ternary(x):
    # type: (Optional[int]) -> int
```
It looks like this doesn't do any refinement (both branches return a literal)
Good catch (test works once updated)
test/test_jit.py
Outdated
```python
@torch.jit.script
def test_not_none(x):
    # type: (Optional[int])
```
The None return should be explicit; there were some mypy issues at some point over this same thing.
torch/csrc/jit/script/compiler.cpp
Outdated
```cpp
if (type != NoneType::get()) {
  auto output = g->insert(prim::_unchecked_unwrap_optional, {v});
  // set name using "a" and not "a.1"
  setVar(fakeRange(), v->uniqueNameBase(), output);
```
What's the reason this can't use v->node()->range()?
Range isn't a method of node
A couple of significant problems here:

- It is an error to rely on the uniqueName for anything other than debugging. ir.h is allowed to rename anything at will and downstream code needs to work. If this needs to remember the variable name, then keep the real variable name in the active_refinements list along with the value.
- Using a fakeRange() is an error here. It will cause a completely unintelligible error message when something goes wrong in the setVar. You should be able to keep the SourceRange of the variable in the active_refinements list.
zdevito left a comment
This is pretty close. I think there is one restructure that we should do to make this easier to maintain and remove some of the gotcha issues that can come up with using Value*'s uniqueNames and saving Value*'s in maps (details are in comments below for why these are bad things to do).

- Refinements should be a map from variables to refined type, not Value* to refined type. This reflects how it works in practice, and avoids the unsafe call to uniqueNameBase. We are, after all, refining the type of variables in the environment, not values.
- Refinements should be gathered by direct examination of the AST of the condition, not the IR. This will remove the "action-at-a-distance" of the Value->Refinement map, and it will also prevent corner cases. Instead you can simply have an independent function `BoolInfo findRefinements(Expr cond)` that works side-by-side with expression emission. That function can be relatively short (it just has to handle is/is not/and/or) and it is independent of all the normal emit logic. This removes the need for changes to the environment. Refinements can be applied directly when beginning an if statement with a function like `insertRefinements(Refinements ref)`, without putting the logic into the environment.
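As a sketch of the proposed structure, here is a plain-Python model of `findRefinements` over a toy condition AST (the class names and the returned `var -> type` maps are illustrative assumptions, not the actual compiler types):

```python
# Toy AST model of the suggested `BoolInfo findRefinements(Expr cond)`:
# it only handles is / is not / and / or and returns the refinements
# that hold when the condition is true vs. when it is false.

class IsNone:            # models `x is None`
    def __init__(self, var): self.var = var

class Not:               # models `x is not None` as Not(IsNone(x))
    def __init__(self, cond): self.cond = cond

class And:
    def __init__(self, lhs, rhs): self.lhs, self.rhs = lhs, rhs

class Or:
    def __init__(self, lhs, rhs): self.lhs, self.rhs = lhs, rhs

def find_refinements(cond):
    """Return (true_refinements, false_refinements): var -> refined type."""
    if isinstance(cond, IsNone):
        # when false, x is not None, so Optional[int] refines to int
        return {}, {cond.var: "int"}
    if isinstance(cond, Not):
        t, f = find_refinements(cond.cond)
        return f, t   # negation swaps the two maps
    if isinstance(cond, (And, Or)):
        lt, lf = find_refinements(cond.lhs)
        rt, rf = find_refinements(cond.rhs)
        union = lambda a, b: {**a, **b}
        # intersect: keep only vars refined identically on both sides
        intersect = lambda a, b: {k: v for k, v in a.items()
                                  if b.get(k) == v}
        if isinstance(cond, And):
            # `a and b` true => both true; false => either could be false
            return union(lt, rt), intersect(lf, rf)
        # `a or b` true => either could be true; false => both false
        return intersect(lt, rt), union(lf, rf)
    return {}, {}

# `x is not None and y is not None`
cond = And(Not(IsNone("x")), Not(IsNone("y")))
print(find_refinements(cond)[0])  # {'x': 'int', 'y': 'int'}
```

The key property is that `not` swaps the two maps, `and` unions the true-side refinements while intersecting the false side, and `or` does the reverse.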
```cpp
_(aten, index_put_) \
_(aten, device) \
_(aten, len) \
_(prim, _unchecked_unwrap_optional) \
```
nit: no need for an underscore, it is in the prim namespace which is not externally accessible.
```python
        return a + b
        ''')

    def test_optional_refinement(self):
```
None of these tests are actually run, can you add some tests that actually make sure unwrap optional is inserted correctly?
torch/csrc/jit/ir.h
Outdated
```cpp
}

TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
inline const SourceRange& fakeRange() {
```
Given a comment below, it is not necessary to expose this, and we do not want to encourage its use.
torch/csrc/jit/register_prim_ops.cpp
Outdated
```cpp
    return 0;
  };
}),
// TODO removed in preprocessing before being run in the interpreter,
```
nit:

```cpp
// This op can be removed in preprocessing before being run in the interpreter (but is currently not removed);
// even when it is removed it needs to remain a registered op so that constant prop can run.
```
torch/csrc/jit/script/compiler.cpp
Outdated
```cpp
auto maybe_t_2 = b.find(v_1);
if (maybe_t_2 != b.end()) {
  TypePtr t_2 = maybe_t_2->second;
  auto maybe_unified_type = unifyTypes(t_1, t_2);
```
This seems wrong. Let's say you have two refinements known to be true:

a -> A
a -> B

with A <: B. Then you know 'a -> A', but unify(A, B) == B. It is possible you cannot cause this to fail yet because we never introduce a refinement with this structure, but it shouldn't be left incorrect. I want to avoid adding a typeMeet operator because it is a burden to maintain, but you can choose t_1 if t_1 <: t_2, t_2 if t_2 <: t_1, and drop the refinement if neither is true.
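The suggested rule (take the more specific type, or drop the refinement when the types are unrelated) can be sketched in plain Python with a toy subtype lattice; the class names are illustrative, not PyTorch's type system:

```python
class Type:
    def is_subtype_of(self, other):
        # toy structural rule: subtyping follows the class hierarchy
        return isinstance(self, type(other))

class NumberType(Type): pass
class IntType(NumberType): pass   # IntType <: NumberType, i.e. A <: B
class BoolType(Type): pass        # unrelated to NumberType

def merge_refinement(t_1, t_2):
    """Combine two refinements known true for the same variable."""
    if t_1.is_subtype_of(t_2):
        return t_1                # t_1 <: t_2: keep the more precise t_1
    if t_2.is_subtype_of(t_1):
        return t_2
    return None                   # neither subtypes the other: drop it

print(type(merge_refinement(IntType(), NumberType())).__name__)  # IntType
print(merge_refinement(IntType(), BoolType()))                   # None
```

Unlike unifyTypes, which moves up to the common supertype B, this keeps the more specific 'a -> A' when A <: B, matching the point above.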
torch/csrc/jit/script/compiler.cpp
Outdated
```cpp
  return emitShortCircuitIf(
      tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
}
case TK_IS:
```
This feels like a lot of duplication just to add the setIsBoolInfo flag at the end. Refactor?
torch/csrc/jit/script/compiler.cpp
Outdated
```cpp
// ordered set, because we want deterministic graph output
std::set<std::string> mutated_variables;
auto true_vars = save_true->definedVariables();
```
This introduces unnecessary copies.
torch/csrc/jit/script/compiler.cpp
Outdated
```cpp
  }
}

void insertRefinements() {
```
This is a weird API. It is only called in one place. Why not have it take active_refinements as an argument?
torch/csrc/jit/script/compiler.cpp
Outdated
```cpp
// has not changed and we omit adding an if node output for it.
// this also prevents jitter when importing an if expression like:
//   x = x if x is not None else 1
bool valueNotWrittenTo(
```
I think this is best handled in dead code elimination rather than here, where it makes the compiler logic more complicated. Dead code elimination already deletes unneeded if/while outputs. It is pretty natural to add a rule there to make it understand that an if output can be eliminated if it is simply combining an unchecked_unwrap(x) with an x or another unchecked_unwrap(x).

Also, I believe this might be buggy. Just because an output of an if is an _unchecked_unwrap_optional does not mean it is the same variable as it originally was. There doesn't appear to be any check that true_v's input is the original value for v. For instance, what happens if someone swaps two unwrapped optionals in an if statement?
Force-pushed d8e159f to a7c7eb3
I'm wary of special-casing this code too much. It seems like the case this is trying to handle (assignment in the else-block of an if-not-None check) is unlikely to appear in user code. I think we should just leave the unwrap_optionals in for those cases; if it turns out to be a common pattern somehow, we can revisit these additions.
For context, leaving them in has issues for jitter in export, if I remember correctly.
There were jitter issues at one point but unrelated to this. This is just trying to remove unneeded if node outputs.
torch/csrc/jit/script/compiler.cpp
Outdated
Do you think it would be simpler to phrase these as in-place non-static methods? e.g. a.intersect(b) instead of Refinements::intersectRefinements(a, b)
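For illustration, the in-place shape being suggested might look like this (a plain-Python sketch; the class and field names are assumptions, not the PR's actual C++):

```python
class Refinements:
    """Maps variable names to refined types; methods mutate in place."""
    def __init__(self, mapping=None):
        self.mapping = dict(mapping or {})

    def intersect(self, other):
        # keep only refinements both sides agree on
        self.mapping = {name: t for name, t in self.mapping.items()
                        if other.mapping.get(name) == t}

    def union(self, other):
        # take refinements known from either side
        self.mapping.update(other.mapping)

a = Refinements({"x": "int", "y": "int"})
a.intersect(Refinements({"x": "int"}))
print(a.mapping)  # {'x': 'int'}
```

(In Python `union` is a fine method name; in C++ `union` is a reserved word, which forces a different spelling there.)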
Force-pushed a7e4e67 to b7a8d73
zdevito left a comment
Cool -- this looks pretty good. The Refinements data structure needs to be cleaned up a bit, and I believe there is a bug in the scope of refinements generated in lazy if statements that needs to be resolved.
Can you refactor the unchecked_unwrap_optional part into a meaningfully named function?

```cpp
if (removePossiblyUnneededUnwrap(...))
  continue;
```
torch/csrc/jit/script/compiler.cpp
Outdated
Can you use a small struct here? There are about 10 lines below where types and ranges get repeated. It will get worse if someone needs to add something later. It also stores two copies of every variable name, two entire red-black trees, and doubles the number of traversals!
torch/csrc/jit/script/compiler.cpp
Outdated
This also makes more sense as a method.
torch/csrc/jit/script/compiler.cpp
Outdated
Please use a struct. The weird data structure is escaping the class it is in!
torch/csrc/jit/script/compiler.cpp
Outdated
Refinements here are not scoped appropriately. They will stay live after the end of the condition, through the end of the block the conditional is in.
Could you elaborate?
These expressions are emitted when each corresponding block is emitted so I don't think they escape the condition. I will add a test to make sure this does not regress.
I see -- I didn't see the call below to emitIfExpr. That should work then.
torch/csrc/jit/script/compiler.cpp
Outdated
Why is this not a method?
It needs access to the environment_stack
I mean, why not false_info.addRefinement(name, type)?
torch/csrc/jit/script/compiler.cpp
Outdated
If it is none, shouldn't it also remap name to None?
Since None does not currently subtype optionals, we can't refine the type to None without erroring on type unification. Inserting prim::None with an optional type could maybe help constant prop in the case that the none variable is used again. It also makes it hard to tell when a user sets a variable to None vs when the None is automatically inserted.
It causes jitter for a reason I haven't looked into. I am going to leave it out of this PR and revisit when None subtypes Optional:

```diff
- opt_5, x_3 = annotate(Optional[Tensor], None), x_1
? ---------  -
+ opt_5, x_3 = annotate(Tensor, None), x_1
```
suo left a comment
The approach looks fine to me, but there was some cleanliness stuff I commented about inline. Also, please clang-format!
torch/csrc/jit/script/compiler.cpp
Outdated
Please add a comment describing the purpose of this data structure
torch/csrc/jit/script/compiler.cpp
Outdated
It doesn't seem like you need to return something here or in unionRefinements; the return value is never used.
Also, intersect and union seem fine as names; the Refinements suffix seems redundant.
`union` is a reserved keyword, so that's why I added unionRefinements.
torch/csrc/jit/script/compiler.cpp
Outdated
This comment seems out of date.
torch/csrc/jit/script/compiler.cpp
Outdated
const
torch/csrc/jit/script/compiler.cpp
Outdated
const here and for maybe_unified_type
torch/csrc/jit/script/compiler.cpp
Outdated
please add a comment describing the purpose of this function
torch/csrc/jit/script/compiler.cpp
Outdated
don't take this by value
torch/csrc/jit/script/compiler.cpp
Outdated
ditto here
torch/csrc/jit/script/compiler.cpp
Outdated
const here and below
torch/csrc/jit/script/compiler.cpp
Outdated
maybe at() for these accesses just in case
Also, please look into the ROCm failure before landing.
…ning types to None.
Force-pushed c75fbd3 to 2162a61
facebook-github-bot left a comment
@eellison is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Add support for type inference for optional type refinement. If a conditional is of the form "x is None" or "x is not None", or is a boolean expression containing multiple None checks, the proper type refinements are inserted in each branch. For example:

```python
if optional_tensor is not None and len(optional_tensor) < 2:
    # optional_tensor is a Tensor

if optional_tensor1 is not None and optional_tensor2 is not None:
    # both optional_tensor1 and optional_tensor2 are Tensors
```

TODO:
- do not run an op for unchecked unwrap optional in the interpreter
- potentially refine types to prim::None (omitted for now to simplify things and because it's not an actual use case)

Pull Request resolved: pytorch/pytorch#15587
Differential Revision: D13733810
Pulled By: eellison
fbshipit-source-id: 57c32be9f5a09ab5542ba0144a6059b96de23d7a