
Conversation

@eellison (Contributor) commented Dec 28, 2018:

Add support for type inference for optional type refinement.

If a conditional is of the form "x is None" or "x is not None", or is a boolean expression containing multiple none checks, the proper type refinements are inserted in each branch.

For example:
if optional_tensor is not None and len(optional_tensor) < 2:
# optional_tensor is a Tensor

if optional_tensor1 is not None and optional_tensor2 is not None:
# both optional_tensor1 and optional_tensor2 are Tensors

TODO:

  • do not run an op for unchecked unwrap optional in the interpreter

  • potentially refine types to prim::None (omitted for now to simplify things & because it's not an actual use case).

@eellison changed the title from "Add implicit optional unwrapping" to "Add implicit optional unwrapping [WIP]" on Dec 28, 2018
@facebook-github-bot added the oncall: jit label (Add this issue/PR to JIT oncall triage queue) on Dec 28, 2018
Contributor:

(maybe in a followup?) We could have the Python Printer spit out prim::__is__ nodes as Python is statements and then skip emitting these altogether so they'd get re-inserted when the code is read in again

Contributor:

It looks like this constructor isn't used

test/test_jit.py Outdated
Contributor:

Add a test for nested optionals (i.e. Optional[Optional[int]])?

Contributor Author:

Hmm. This isn't schematizable, so it would need special logic. I don't think this is an actual use case - I grep'd around for "Optional[Optional[" and didn't find anything. I think erroring and waiting to see if it ever comes up would be sufficient.

Contributor:

Why is another SugaredValue necessary for this?

Contributor Author:

There needs to be special casing of the duplicate unwrapped optional calls, and this is the most self-contained way to achieve it. Additionally, we can't just not print out the unchecked_optionals because then we lose the boolean expressions, since those get desugared to if statements after export/import.
e.g.:
if x is not None and len(x) < 2:
print("hi")

On the second compilation, the compiler won't know to combine "x is not None and len(x) < 2" since they will no longer be in a boolean expression. But the unwrap optional inserted during the first compilation will still exist.

@eellison force-pushed the insert_unchecked_optional branch 2 times, most recently from cf50154 to 111be98, on January 3, 2019
@eellison changed the title from "Add implicit optional unwrapping [WIP]" to "Add implicit optional unwrapping" on Jan 3, 2019

@torch.jit.script
def test_ternary(x):
# type: (Optional[int]) -> int
Contributor:

It looks like this doesn't do any refinement (both branches return a literal)

@eellison (Contributor Author) commented Jan 4, 2019:

Good catch (test works once updated)

test/test_jit.py Outdated

@torch.jit.script
def test_not_none(x):
# type: (Optional[int])
Contributor:

The None return should be explicit; there were some mypy issues at some point over this same thing.

if (type != NoneType::get()) {
auto output = g->insert(prim::_unchecked_unwrap_optional, {v});
// set name using "a" and not "a.1"
setVar(fakeRange(), v->uniqueNameBase(), output);
@driazati (Contributor) commented Jan 4, 2019:

What's the reason this can't use v->node()->range()?

Contributor Author:

Range isn't a method of node

Contributor:

A couple significant problems here:

  1. It is an error to rely on the uniqueName for anything other than debugging. ir.h is allowed to rename anything at will and downstream code needs to work. If this needs to remember the variable name, then keep the real variable name in the active_refinements list along with the value.

  2. Using a fakeRange() is an error here. It will cause a completely unintelligible error message when something goes wrong in the setVar call. You should be able to keep the SourceRange of the variable in the active_refinements list.

@zdevito (Contributor) left a comment:

This is pretty close. I think there is one restructure that we should do to make this easier to maintain and remove some of the gotcha issues that can come up with using Value*'s uniqueNames, and saving Value*'s in maps (details are in comments below for why these are bad things to do).

  1. Refinements should be a map from variables to refined type, not Value* to refined type. This reflects how it works in practice, and avoids the unsafe call to uniqueNameBase. We are, after all, refining the type of variables in the environment and not values.
  2. Refinements should be gathered by direct examination of the AST of the condition, not the IR. This will remove the "action-at-a-distance" of the Value->Refinement map and prevent corner cases. Instead you can simply have an independent function BoolInfo findRefinements(Expr cond) that works side-by-side with expression emission (a rough sketch of that shape follows this list). That function can be relatively short (it just has to handle is / is not / and / or) and it is independent of all the normal emit logic. This removes the need for changes to the environment. Refinements can be applied directly when beginning an if statement with a function like insertRefinements(Refinements ref), without putting the logic into the environment.
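The sketch below is only an illustration of the findRefinements shape described in point 2, not code from this PR. The BoolInfo struct and the set helpers are hypothetical, and Expr, BinOp, Var and the TK_* kinds are assumed to behave as in torch/csrc/jit/script/tree_views.h, so exact names and signatures may differ.

struct BoolInfo {
  // variable names known to be non-None when the condition is true / false
  std::set<std::string> true_refinements;
  std::set<std::string> false_refinements;
};

static std::set<std::string> setUnion(
    const std::set<std::string>& a,
    const std::set<std::string>& b) {
  std::set<std::string> out = a;
  out.insert(b.begin(), b.end());
  return out;
}

static std::set<std::string> setIntersection(
    const std::set<std::string>& a,
    const std::set<std::string>& b) {
  std::set<std::string> out;
  for (const auto& s : a)
    if (b.count(s))
      out.insert(s);
  return out;
}

// Walk the AST of the condition and record which variables can be unwrapped
// in each branch; only is / is not / and / or need to be understood.
BoolInfo findRefinements(const Expr& cond) {
  switch (cond.kind()) {
    case TK_IS:
    case TK_ISNOT: {
      auto op = BinOp(cond);
      BoolInfo info;
      if (op.lhs().kind() == TK_VAR && op.rhs().kind() == TK_NONE) {
        const std::string name = Var(op.lhs()).name().name();
        // `x is not None` refines x in the true branch,
        // `x is None` refines x in the false branch.
        auto& target = cond.kind() == TK_ISNOT ? info.true_refinements
                                               : info.false_refinements;
        target.insert(name);
      }
      return info;
    }
    case TK_AND: {
      auto lhs = findRefinements(BinOp(cond).lhs());
      auto rhs = findRefinements(BinOp(cond).rhs());
      // `a and b` true means both were true; false means at least one was false.
      return {setUnion(lhs.true_refinements, rhs.true_refinements),
              setIntersection(lhs.false_refinements, rhs.false_refinements)};
    }
    case TK_OR: {
      auto lhs = findRefinements(BinOp(cond).lhs());
      auto rhs = findRefinements(BinOp(cond).rhs());
      // `a or b` false means both were false; true means at least one was true.
      return {setIntersection(lhs.true_refinements, rhs.true_refinements),
              setUnion(lhs.false_refinements, rhs.false_refinements)};
    }
    default:
      return {};
  }
}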

_(aten, index_put_) \
_(aten, device) \
_(aten, len) \
_(prim, _unchecked_unwrap_optional)\
Contributor:

nit: no need for an underscore, it is in the prim namespace which is not externally accessible.

return a + b
''')

def test_optional_refinement(self):
Contributor:

None of these tests are actually run; can you add some tests that actually make sure unwrap optional is inserted correctly?


}

TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
inline const SourceRange& fakeRange() {
Contributor:

Given a comment below, it is not necessary to expose this, and we do not want to encourage its use.

return 0;
};
}),
// TODO removed in preprocessing before being run in the interpreter,
Contributor:

nit:

// This op can be removed in preprocessing before being run in the interpreter (but is currently not removed),
// even when it is removed it needs to remain a registered op so that constant prop can run.

auto maybe_t_2 = b.find(v_1);
if (maybe_t_2 != b.end()) {
TypePtr t_2 = maybe_t_2->second;
auto maybe_unified_type = unifyTypes(t_1, t_2);
Contributor:

This seems wrong. Let's say you have two refinements known to be true:

a -> A
a -> B
with A <: B

Then you know 'a -> A', but unify(A, B) == B. It is possible you cannot cause this to fail yet because we never introduce a refinement with this structure, but it shouldn't be left incorrect. I want to avoid adding a typeMeet operator because it is a burden to maintain, but you can choose t_1 if t_1 <: t_2, t_2 if t_2 <: t_1, and drop the refinement if neither is true.
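A minimal sketch of that choice, assuming Type::isSubtypeOf from the JIT type system; the helper name is made up for illustration and is not code from this PR:

static TypePtr refinementMeet(const TypePtr& t_1, const TypePtr& t_2) {
  if (t_1->isSubtypeOf(t_2))
    return t_1; // t_1 is the more specific refinement, keep it
  if (t_2->isSubtypeOf(t_1))
    return t_2; // t_2 is the more specific refinement, keep it
  return nullptr; // incomparable: the caller drops this refinement
}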

return emitShortCircuitIf(
tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
}
case TK_IS:
Contributor:

This feels like a lot of duplication just to add the setIsBoolInfo flag at the end. Refactor?


// ordered set, because we want deterministic graph output
std::set<std::string> mutated_variables;
auto true_vars = save_true->definedVariables();
Contributor:

This introduces unnecessary copies.

}
}

void insertRefinements() {
Contributor:

This is a weird API. It is only called in one place. Why not have it take active_refinements as an argument?

// has not changed and we omit adding an if node output for it.
// this also prevents jitter when importing an if expression like:
// x = x if x is not None else 1
bool valueNotWrittenTo(
Contributor:

I think this is best handled in dead code elimination rather than here, where it makes the compiler logic more complicated. Dead code elimination already deletes unneeded if statement/while outputs. It is pretty natural to add a rule there to make it understand that an if output can be eliminated if it is simply combining an unchecked_unwrap(x) with an x or another unchecked_unwrap(x).

Also, I believe this might be buggy. Just because an output of an if is an _unchecked_unwrap_optional does not mean it is the same variable as it originally was. There doesn't appear to be any check that true_v's input is the original value for v. For instance, what happens if someone swaps two unwrapped optionals in an if statement?
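For illustration, one way the suggested dead-code-elimination rule (including the same-variable check asked for above) could look; the helper names are hypothetical and this is a sketch, not the PR's code:

static Value* stripUncheckedUnwrap(Value* v) {
  if (v->node()->kind() == prim::_unchecked_unwrap_optional)
    return v->node()->inputs()[0];
  return v;
}

// An if output is redundant only when both branch values resolve to the same
// underlying value, possibly viewed through an unchecked unwrap.
static bool ifOutputIsRedundant(Value* true_v, Value* false_v) {
  return stripUncheckedUnwrap(true_v) == stripUncheckedUnwrap(false_v);
}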


@eellison force-pushed the insert_unchecked_optional branch 2 times, most recently from d8e159f to a7c7eb3, on January 10, 2019
Member:

I'm wary of special-casing this code too much. It seems like the case this is trying to handle (assignment in the else-block of an if not None check) is unlikely to appear in user code. I think we should just leave the unwrap_optionals in for those cases; if it turns out to be a common pattern somehow, we can revisit these additions.

Contributor:

For context, leaving them in has issues for jitter in export, if I remember correctly.

@eellison (Contributor Author) commented Jan 12, 2019:

There were jitter issues at one point, but they were unrelated to this. This is just trying to remove unneeded if node outputs.

Member:

Do you think it would be simpler to phrase these as in-place non-static methods? e.g. a.intersect(b) instead of Refinements::intersectRefinements(a, b)

@zdevito (Contributor) left a comment:

Cool -- this looks pretty good. The Refinements data structure needs to be cleaned up a bit, and I believe there is a bug in the scope of refinements generated in lazy if statements that needs to be resolved.

Contributor:

Can you refactor the unchecked_unwrap_optional part into a meaningfully named function?

if (removePossiblyUnneededUnwrap(...))
  continue;

Contributor:

Can you use a small struct here? There are about 10 lines below where types and ranges get repeated. It will get worse if someone needs to add something later. It also stores two copies of every variable name, two entire red-black trees, and doubles the number of traversals!
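For example, something along these lines (names are illustrative, a sketch rather than the PR's code), so a single map replaces the two parallel ones:

struct Refinement {
  TypePtr refined_type;
  SourceRange range;
};

// one lookup structure instead of two parallel maps keyed by variable name
std::map<std::string, Refinement> refinements;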

Contributor:

This also makes more sense as a method.

Contributor:

Please use a struct. The weird data structure is escaping the class it is in!

Contributor:

Refinements here are not scoped appropriately. They will stay live after the end of the condition and through the end of the block the conditional is in.

Contributor Author:

Could you elaborate?

Contributor Author:

These expressions are emitted when each corresponding block is emitted so I don't think they escape the condition. I will add a test to make sure this does not regress.

Contributor:

I see -- I didn't see the call below to emitIfExpr. That should work then.

Contributor:

Why is this not a method?

Contributor Author:

It needs access to the environment_stack

Contributor:

I mean, why not false_info.addRefinement(name, type)?

Contributor:

If it is None, shouldn't it also remap name to None?

Contributor Author:

Since None does not currently subtype optionals, we can't refine the type to None without erroring on type unification. Inserting prim::None with an optional type could maybe help constant prop in the case that the none variable is used again. It also makes it hard to tell when a user sets a variable to None vs when the None is automatically inserted.

Contributor Author:

It causes jitter for a reason I haven't looked into. I am going to leave it out of this PR and revisit when None subtypes Optional.

 -     opt_5, x_3 = annotate(Optional[Tensor], None), x_1
?                           ---------      -
+     opt_5, x_3 = annotate(Tensor, None), x_1

@suo (Member) left a comment:

The approach looks fine to me, but there was some cleanliness stuff I commented about inline. Also, please clang-format!

Member:

Please add a comment describing the purpose of this data structure

Member:

It doesn't seem like you need to return something here or in unionRefinements—the return value is never used.

Also, intersect and union seem fine; the "Refinements" suffix in the names seems redundant.

Contributor Author:

union is a reserved keyword in C++, so that's why I added unionRefinements.

Member:

This comment seems out of date.

Member:

const

Member:

const here and for maybe_unified_type

Member:

please add a comment describing the purpose of this function

Member:

don't take this by value

Member:

ditto here

Member:

const here and below

Member:

maybe at() for these accesses just in case

@suo (Member) commented Jan 17, 2019:

Also, please look into the ROCm failure before landing.

@eellison force-pushed the insert_unchecked_optional branch from c75fbd3 to 2162a61 on January 18, 2019
@facebook-github-bot (Contributor) left a comment:

@eellison is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 18, 2019
Summary:
Add support for type inference for optional type refinement.

Pull Request resolved: pytorch/pytorch#15587

Differential Revision: D13733810

Pulled By: eellison

fbshipit-source-id: 57c32be9f5a09ab5542ba0144a6059b96de23d7a
facebook-github-bot pushed a commit that referenced this pull request Jan 25, 2019
Summary:
Now that #15587 has landed, updating docs.

Will close #15278
Pull Request resolved: #16380

Differential Revision: D13825221

Pulled By: eellison

fbshipit-source-id: c5a7a7fbb40ba7be46a80760862468f2c9967169
@ezyang added the merged label on Jun 25, 2019