Skip to content

Commit c1c4014

Browse files
Will Fengfacebook-github-bot
authored andcommitted
Add warning for legacy autograd function (#22922)
Summary: When working on #22762, we discovered that we haven't actually deprecated legacy autograd function. This PR puts up the deprecation warning for 1.2, with the goal to remove legacy function support completely in the near future. Pull Request resolved: #22922 Differential Revision: D16363916 Pulled By: yf225 fbshipit-source-id: 4b554010a3d1f87a3fa45cc1aa29d019c8f1033c
1 parent a2b3403 commit c1c4014

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

test/test_autograd.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,26 @@ def backward(self, grad_output):
167167
MyFunction()(y).sum().backward()
168168
self.assertEqual(v.grad.data, torch.zeros(shape))
169169

170+
def test_legacy_function_deprecation_warning(self):
171+
with warnings.catch_warnings(record=True) as w:
172+
# Ensure warnings are being shown
173+
warnings.simplefilter("always")
174+
175+
# Trigger Warning
176+
class MyFunction(Function):
177+
def forward(self, x):
178+
return x
179+
180+
def backward(self, grad_output):
181+
return grad_output
182+
183+
MyFunction()(torch.randn(3, 4))
184+
185+
# Check warning occurs
186+
self.assertIn(
187+
'Legacy autograd function with non-static forward method is deprecated',
188+
str(w[0]))
189+
170190
def test_invalid_gradients(self):
171191
class MyFunction(Function):
172192
@staticmethod

torch/csrc/autograd/python_function.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,10 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
631631
std::vector<c10::IValue>(),
632632
autograd::Function::peek_at_next_sequence_nr());
633633

634+
TORCH_WARN("Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. ",
635+
"Please use new-style autograd function with static forward method. ",
636+
"(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)");
637+
634638
auto info_pair = unpack_input<true>(_inputs);
635639
auto& unpacked_input = info_pair.first;
636640
auto& input_info = info_pair.second;

0 commit comments

Comments
 (0)