Commit ddfa80b
Update on "Fix NJT linear_backward() memory usage"
Fixes #141112
The formula we're currently using for `linear_backward()` is memory-inefficient for higher-dim input sizes, even when the input is only trivially higher-dim (e.g. via `unsqueeze()`). This PR updates the formula to match the more efficient version employed by NST. Specifically, note the collapse of the leading dims of `grad_output`'s values before the various matmuls are computed.
https://github.com/pytorch/pytorch/blob/d5ee1d1b581da8399d604bd661ea5fe454b485d6/aten/src/ATen/native/nested/NestedTensorBackward.cpp#L37-L70
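For illustration, here's a minimal sketch of the collapsed formula on plain dense tensors (the function name and shapes are assumptions for exposition, not the NJT implementation, which operates on the NJT's values buffer):

```python
import torch

def linear_backward_sketch(inp, weight, grad_output):
    # Collapse all leading dims so each matmul runs on a 2D view rather than
    # materializing large higher-dim intermediates.
    inp_2d = inp.reshape(-1, inp.size(-1))                       # (N, in_features)
    grad_out_2d = grad_output.reshape(-1, grad_output.size(-1))  # (N, out_features)

    grad_input = (grad_out_2d @ weight).reshape(inp.shape)  # dL/d(input)
    grad_weight = grad_out_2d.t() @ inp_2d                   # dL/d(weight)
    grad_bias = grad_out_2d.sum(0)                           # dL/d(bias)
    return grad_input, grad_weight, grad_bias
```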
Correctness is covered by existing gradcheck tests (e.g. `test_backward_nn_functional_linear`). I also added a memory usage test, although there's likely a better way to write it.
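For reference, one way such a memory check can be structured (a sketch assuming a CUDA device; not necessarily how the test in this PR is written):

```python
import torch

def peak_cuda_memory(fn):
    # Run fn() and report the peak CUDA memory allocated during the call.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

# Example usage: assert the backward pass stays within some budget.
# peak = peak_cuda_memory(lambda: out.backward(grad, retain_graph=True))
# assert peak < expected_budget
```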
[ghstack-poisoned]

1 file changed: +11, −5 lines (diff spans lines 544–565; contents not shown)