Skip to content

Commit 65aa16f

Browse files
mikaylagawarecki authored and pytorchmergebot committed
1 parent f994099 commit 65aa16f

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

test/test_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8184,9 +8184,9 @@ def test_batchnorm_large_batch(self, device, dtype):
81848184
@dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128)
81858185
def test_conv_empty_input(self, device, dtype):
81868186
def help(input, conv, memory_format):
8187-
ref_out = conv(input).detach()
8187+
ref_out = conv(input)
81888188
conv_cl = conv.to(memory_format=memory_format)
8189-
out_cl = conv_cl(input).detach()
8189+
out_cl = conv_cl(input)
81908190
self.assertEqual(ref_out, out_cl)
81918191
input_cl = input.to(memory_format=memory_format)
81928192
out_cl2 = conv(input_cl)

torch/nn/modules/module.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -794,13 +794,6 @@ def compute_should_use_set_data(tensor, tensor_applied):
794794

795795
should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
796796

797-
def compute_should_use_swap_tensors(tensor, tensor_applied):
798-
return (should_use_swap_tensors
799-
# subclasses may have multiple child tensors so we need to use swap_tensors
800-
or is_traceable_wrapper_subclass(tensor_applied)
801-
or tensor.device.type == 'xla'
802-
or tensor_applied.device.type == 'xla')
803-
804797
for key, param in self._parameters.items():
805798
if param is None:
806799
continue
@@ -811,7 +804,8 @@ def compute_should_use_swap_tensors(tensor, tensor_applied):
811804
param_applied = fn(param)
812805
p_should_use_set_data = compute_should_use_set_data(param, param_applied)
813806

814-
p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied)
807+
# subclasses may have multiple child tensors so we need to use swap_tensors
808+
p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
815809

816810
param_grad = param.grad
817811
if p_should_use_swap_tensors:

0 commit comments

Comments (0)