Conversation

@pbelevich pbelevich commented Sep 20, 2019

Stack from ghstack:

Differential Revision: D17516856

@pytorchbot added labels on Sep 20, 2019: module: autograd (Related to torch.autograd, and the autograd engine in general), module: cpp (Related to C++ API), module: internals (Related to internal abstractions in c10 and ATen)
pbelevich added commits that referenced this pull request (Pull Request resolved: #26568):
Sep 20, 2019: ghstack-source-id 68b6629
Sep 21, 2019: ghstack-source-id aa75548
Sep 22, 2019: ghstack-source-id 036936a
Sep 23, 2019: ghstack-source-id 417671a
Sep 26, 2019: ghstack-source-id 96c7671
Oct 3, 2019: ghstack-source-id c4b4e61
Oct 3, 2019: ghstack-source-id b5b0f31
Oct 9, 2019: ghstack-source-id cdede33
Oct 23, 2019: ghstack-source-id c715ea1
Oct 24, 2019: ghstack-source-id 3b37813
Oct 24, 2019: ghstack-source-id 88b6e0c
@ssbotelh commented:

Hi there, do you have any kind of timeline for this task? I'm eager to use this feature in our C++ code base, so I'm curious whether there's a completion target. Thanks!

@yf225 (Contributor) commented Feb 13, 2020

@ssbotelh Thanks a lot for your interest! I will work on this feature and have it ready by our 1.5 release.

@yf225 (Contributor) commented Feb 13, 2020

Update: at::Tensor::register_hook has already been added in #28287, with usage examples in:

TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  int counter = 0;
  std::function<void(int, Variable)> bw_hook([&counter](int inc, Variable grad) {
    counter += inc;
  });

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
  auto hook_1 = z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 = z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 5);

  std::function<Variable(Variable)> bw_hook_modify([](Variable grad) {
    return grad.mul(2);
  });
  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}

TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
      return {x, y};
    }
    static variable_list backward(AutogradContext* ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;
  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});
  auto out = NoneGradientFunction::apply(x, y);
  // Fixed from x[0]/x[1]: the hooks belong on the function's outputs.
  Variable rx = out[0], ry = out[1];
  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}

@yf225 yf225 closed this Feb 13, 2020
@facebook-github-bot facebook-github-bot deleted the gh/pbelevich/10/head branch March 15, 2020 14:18