
Conversation


@malvika2147 malvika2147 commented Jul 1, 2019

Stack from ghstack:

Summary: The ready queue prioritizes the most deeply nested reentrant tasks. Reentrant tasks are executed recursively on the current thread until the recursion depth gets close to max_recursion_depth, which is derived from the Python recursion limit. Once that limit is reached, further reentrant backwards tasks are run on a different thread.

Test Plan: Added a test for reentrant backwards with checkpoint, a test for a recursive backwards function (which should fail if we ran all the reentrant tasks recursively in the same thread), and a test for the priority of reentrant tasks; a sketch of such a reentrant workload follows below.
~~Will add a test for priority of reentrant tasks in a future PR.~~

Differential Revision: D16131955
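
For illustration only, and not the PR's actual test code: a minimal sketch of a reentrant backward driven by torch.utils.checkpoint, the kind of workload whose nested tasks this change prioritizes and, once the depth budget is used up, hands off to another thread.

import torch
from torch.utils.checkpoint import checkpoint

# A checkpointed segment re-runs its forward and calls backward again from
# inside the autograd engine, producing the nested (reentrant) tasks that
# the ready queue now prioritizes by depth.
def block(x):
    return x.sin().cos()

x = torch.randn(16, requires_grad=True)
y = checkpoint(block, x)   # backward through y re-enters the engine
y.sum().backward()
print(x.grad.shape)        # torch.Size([16])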

@pytorchbot pytorchbot added labels module: autograd (Related to torch.autograd, and the autograd engine in general) and module: pybind (Related to our Python bindings / interactions with other Python libraries) on Jul 1, 2019
@malvika2147 malvika2147 requested review from apaszke and ezyang July 1, 2019 15:28
namespace torch { namespace autograd { namespace python {

PythonEngine::PythonEngine() {
  max_recursion_depth_ = 0.1 * Py_GetRecursionLimit();
Contributor

A little note explaining how we got the magic number 0.1 would be appreciated here :)

Contributor

So, I guess the constructor for PythonEngine must now be GIL-protected. Have we checked that this invariant is respected at all the use sites?

@ezyang ezyang commented Jul 1, 2019

Test Plan

A formally written out test plan would be very useful here, since there are no test suite changes (and indeed, it's a bit hard to think of how you would actually go about writing unit tests; probably only manual tests are possible here.)

mal added 2 commits July 1, 2019 21:07
Prioritize reentrant tasks and execute them recursively until close to limit (gh-metadata: pytorch pytorch 22397 gh/mal2147/10/head)
@ezyang ezyang commented Jul 2, 2019

Don't forget to rerequest review when you need it!

@ezyang ezyang requested a review from VitalyFedyunin July 2, 2019 21:00
@ezyang ezyang requested a review from mrshenli July 2, 2019 21:17
mal added 4 commits July 2, 2019 18:14
Prioritize reentrant tasks and execute them recursively until close to limit (gh-metadata: pytorch pytorch 22397 gh/mal2147/10/head)
if (current_depth >= max_recursion_depth_) {
  // See Note [Reentrant backwards]
  // If reached the max depth, switch to a different thread
  add_thread_pool_task(&graph_task);
Contributor

What will happen if we exceed thread_pool_shared_->graphtasks_queue_.size() here?

Author

std::queue should resize dynamically, or am I misunderstanding your question?
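
As a hypothetical Python analogue of this hand-off (the real implementation is C++, and the names below are only illustrative): tasks that exceed the per-thread depth budget are pushed onto a shared, dynamically growing queue and picked up by pool threads, much like the std::queue backing graphtasks_queue_.

import queue
import threading

# graphtasks_queue here is an illustrative stand-in for the engine's shared
# queue: it grows as needed, so adding tasks beyond its current size is fine.
graphtasks_queue = queue.Queue()

def worker():
    while True:
        task = graphtasks_queue.get()
        if task is None:
            break
        task()  # run the spilled reentrant graph task on this pool thread
        graphtasks_queue.task_done()

threading.Thread(target=worker, daemon=True).start()
graphtasks_queue.put(lambda: print("ran spilled reentrant task"))
graphtasks_queue.join()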

@VitalyFedyunin

I suggest having a clearly defined constant for how deep we can go with recursion.

@malvika2147

I suggest having a clearly defined constant for how deep we can go with recursion.

But in that case, we will always create new threads even if there is enough stack space to run on a single thread.

Prioritize reentrant tasks and execute them recursively until close to limit (gh-metadata: pytorch pytorch 22397 gh/mal2147/10/head)
@malvika2147 malvika2147 requested a review from ezyang July 3, 2019 20:19
with torch.enable_grad():
    ctx.x = Variable(x.data, requires_grad=True)
    ctx.x = ctx.x - 1
return ctx.x.detach()
Contributor

This is a pretty funky forward function. (It's funky because we don't normally do autograd operations inside of a forward function.) It seems like it's doing two things: you want to return x (forward is just identity), but you also want to create a leaf variable on context with some non-trivial autograd history. Is there a reason x has to be used in both cases? I'll keep reading and see if I can figure out why you create leaf variables in forward ;)

Contributor

Oh I see, you're also using ctx.x to keep track of how many times you recurse.
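
For context, a minimal assumed sketch (not the PR's exact test) of how such a recursive reentrant Function fits together: forward builds a leaf with autograd history and uses ctx.x as the countdown, while backward re-enters the engine until the counter goes negative.

import torch
from torch.autograd import Function, Variable

class Recurse(Function):
    @staticmethod
    def forward(ctx, x):
        with torch.enable_grad():
            ctx.x = Variable(x.data, requires_grad=True)
            ctx.x = ctx.x - 1  # doubles as the recursion counter
        return ctx.x.detach()

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.x < 0:
            return grad_output
        with torch.enable_grad():
            Recurse.apply(ctx.x).sum().backward()  # reentrant backward
        return grad_output

# A depth far beyond the Python recursion limit: without handing tasks off
# to other threads, running every reentrant task recursively on one thread
# would overflow the stack.
v = torch.tensor(2000.0, requires_grad=True)
Recurse.apply(v).sum().backward()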

Prioritize reentrant tasks and execute them recursively until close to limit (gh-metadata: pytorch pytorch 22397 gh/mal2147/10/head)
}

Engine::Engine() = default;
// This limit is based on the default python recursion limit which is 1000
Contributor

The comments do not match the code: 1000 vs. 100.

Author

We need to set it lower than the actual Python limit to account for the other Python function calls already on the stack.
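
As a rough, assumed illustration of the numbers (the engine computes this in C++): with CPython's default recursion limit of 1000, the 0.1 factor leaves a budget of roughly 100 nested reentrant calls per thread, keeping the rest of the limit free for the ordinary Python frames each reentrant backward pushes.

import sys

# Mirror of the headroom calculation discussed above, for illustration only.
recursion_limit = sys.getrecursionlimit()          # 1000 by default
max_recursion_depth = int(0.1 * recursion_limit)   # ~100
print(recursion_limit, max_recursion_depth)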

Prioritize reentrant tasks and execute them recursively until close to limit (gh-metadata: pytorch pytorch 22397 gh/mal2147/10/head)
@zou3519 zou3519 deleted the gh/mal2147/10/head branch July 5, 2019 15:53
@facebook-github-bot
This pull request has been merged in 0140a75.


xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
Prioritize reentrant tasks and execute them recursively until close to limit

Summary: Pull Request resolved: pytorch#22397

Test Plan:
Added a test for reentrant backwards with checkpoint, a test for a recursive backwards function (which should fail if we ran all the reentrant tasks recursively in the same thread), and a test for the priority of reentrant tasks.
~~Will add a test for priority of reentrant tasks in a future PR.~~

Imported from OSS

Differential Revision: D16131955

fbshipit-source-id: 18301d45c1ec9fbeb566b1016dbaf7a84a09c7ac