
Commit ddfa80b

Update on "Fix NJT linear_backward() memory usage"

Fixes #141112. The formula we use for `linear_backward()` is inefficient for higher-dim input sizes, even when the input is only trivially higher-dim (e.g. via `unsqueeze()`). This PR updates the formula to match the more efficient version employed by NST; specifically, note the leading dim collapse for `grad_output`'s values before the various matmuls are computed:

https://github.com/pytorch/pytorch/blob/d5ee1d1b581da8399d604bd661ea5fe454b485d6/aten/src/ATen/native/nested/NestedTensorBackward.cpp#L37-L70

Correctness is covered by existing gradcheck tests (e.g. `test_backward_nn_functional_linear`). A memory usage test is also added, though there is likely a better way to structure it.

[ghstack-poisoned]
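For intuition, here is a minimal self-contained sketch of the trick (illustrative names and shapes only, not the PR's code): folding the leading dims and doing one 2D matmul yields the same weight grad as the batched matmul followed by a sum, without ever materializing the batched intermediate.

import torch

# Illustrative shapes: "total" stands in for the packed ragged dim of an NJT.
total, extra, in_f, out_f = 64, 4, 32, 16
inp_values = torch.randn(total, extra, in_f)    # stand-in for inp._values
grad_values = torch.randn(total, extra, out_f)  # stand-in for grad_output._values

# Naive weight grad: materializes a (total, out_f, in_f) intermediate,
# then immediately reduces over the leading dim.
dw_naive = torch.matmul(grad_values.transpose(-2, -1), inp_values).sum(0)

# Folded weight grad: collapse leading dims first; a single 2D matmul whose
# only output is the final (out_f, in_f) result.
dw_folded = torch.matmul(
    grad_values.reshape(-1, out_f).t(),
    inp_values.reshape(-1, in_f),
)
torch.testing.assert_close(dw_naive, dw_folded)

# Bias grad uses the same fold-then-reduce idea; the PR instead returns the
# unreduced grad values and lets the autograd engine sum over leading dims.
db = grad_values.reshape(-1, out_f).sum(0)
assert db.shape == (out_f,)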
Merge commit ddfa80b (2 parents: 5b66b88 + ee17028)

1 file changed (+11, -5 lines)

torch/nested/_internal/ops.py

@@ -544,16 +544,22 @@ def linear_backward_default(func, *args, **kwargs):
     ds, dw, db = None, None, None
     check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
-    reshaped_grad = grad_output._values.reshape(-1, weight.size(0))
     if output_mask[0]:
         ds = NestedTensor(
-            torch.matmul(reshaped_grad, weight).view_as(inp._values),
-            **extract_kwargs(grad_output),
+            torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output)
         )
     if output_mask[1]:
-        dw = torch.matmul(reshaped_grad.t(), inp._values.reshape(-1, weight.size(1)))
+        # NB: Fold dims of values for input and grad_output to treat them as 2D. This
+        # trick avoids materializing large intermediates and immediately reducing over
+        # them via sum(). This is equivalent to computing:
+        #     torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
+        # and then summing over the leading dimensions to get a 2D weight grad.
+        grad_2d = grad_output._values.reshape(-1, weight.size(0))
+        input_2d = inp._values.reshape(-1, weight.size(1))
+        dw = torch.matmul(grad_2d.t(), input_2d)
     if output_mask[2]:
-        db = reshaped_grad.sum(0)
+        # NB: autograd engine will sum over all but the last dim to get a 1D bias grad.
+        db = grad_output._values
     return (ds, dw, db)
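As a quick usage sketch (my own, not from the PR; it assumes a recent PyTorch build where the `torch.jagged` layout, `unsqueeze()` on NJT, and `F.linear` over NJT are all supported), this exercises the patched backward with a trivially higher-dim input, mirroring the `unsqueeze()` case described above:

import torch
import torch.nn.functional as F

# Two variable-length sequences of 16-dim features, jagged layout.
nt = torch.nested.nested_tensor(
    [torch.randn(3, 16), torch.randn(5, 16)],
    layout=torch.jagged,
    requires_grad=True,
)
weight = torch.randn(32, 16, requires_grad=True)
bias = torch.randn(32, requires_grad=True)

# unsqueeze() makes the input trivially higher-dim: (B, j1, 16) -> (B, j1, 1, 16).
out = F.linear(nt.unsqueeze(2), weight, bias)  # (B, j1, 1, 32)

# Reduce to a scalar and backprop; this routes through linear_backward_default.
out.values().sum().backward()
print(weight.grad.shape, bias.grad.shape)  # torch.Size([32, 16]) torch.Size([32])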