
Commit 0d3bf97

thomasjpfan authored and facebook-github-bot committed
TST Adds test for non-contiguous tensors (#64954)
Summary:
Follow-up to #61935. This PR:

1. Adds a test for non-contiguous tensors.
2. Fixes a bug in `NLLLoss` that was caught by the new test.

The reason this was not caught by `common_nn` is that `CriterionTest` overrides `test_cuda` but does not call `test_noncontig`.

cc albanD mruberry jbschlosser walterddr

Pull Request resolved: #64954

Reviewed By: zou3519

Differential Revision: D31174149

Pulled By: jbschlosser

fbshipit-source-id: a16073e59b40ccc01c82ede016b63a8db2e810f5
1 parent a839cec commit 0d3bf97
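
For context, here is a minimal Python sketch of the scenario the new test exercises; it is illustrative only, not part of the commit, and assumes a CUDA device is available. With the kernel fix, `F.nll_loss` should give the same result whether its input is contiguous or not.

import torch
import torch.nn.functional as F

# Illustrative sketch (not from this commit); assumes a CUDA device.
logp = F.log_softmax(torch.randn(8, 20, device="cuda"), dim=-1)[:, ::2]  # stride (20, 2): non-contiguous
target = torch.randint(0, 10, (8,), device="cuda")

# Non-contiguous and contiguous inputs should agree after the fix.
print(torch.allclose(F.nll_loss(logp, target), F.nll_loss(logp.contiguous(), target)))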

File tree

3 files changed: +132 -15 lines changed

aten/src/ATen/native/cuda/Loss.cu
test/test_modules.py
torch/testing/_internal/common_modules.py


aten/src/ATen/native/cuda/Loss.cu

Lines changed: 18 additions & 14 deletions
@@ -276,11 +276,14 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
 void nll_loss_forward_out_cuda_template(
     const Tensor& output,
     const Tensor& total_weight,
-    const Tensor& input,
-    const Tensor& target,
+    const Tensor& input_,
+    const Tensor& target_,
     const Tensor& weight,
     int64_t reduction,
     int64_t ignore_index) {
+  auto input = *input_.expect_contiguous();
+  auto target = *target_.expect_contiguous();
+
   int64_t n_classes = input.size(-1);
   int64_t n_dims = input.dim();
   int64_t batch_size = n_dims == 1 ? 1 : input.size(0);
@@ -327,9 +330,6 @@ void nll_loss_forward_out_cuda_template(
   output.resize_({});
   total_weight.resize_({});
 
-  auto input_ = input.contiguous();
-  auto target_ = target.contiguous();
-
   if (n_dims == 1) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -345,8 +345,8 @@ void nll_loss_forward_out_cuda_template(
           <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
               output.data_ptr<scalar_t>(),
               total_weight.data_ptr<scalar_t>(),
-              input_.data_ptr<scalar_t>(),
-              target_.data_ptr<index_t>(),
+              input.data_ptr<scalar_t>(),
+              target.data_ptr<index_t>(),
               weight_.defined() ? weight_.data_ptr<scalar_t>()
                                 : nullptr,
               reduction == at::Reduction::Mean,
@@ -374,8 +374,8 @@ void nll_loss_forward_out_cuda_template(
              at::cuda::getCurrentCUDAStream()>>>(
               output.data_ptr<scalar_t>(),
               total_weight.data_ptr<scalar_t>(),
-              input_.data_ptr<scalar_t>(),
-              target_.data_ptr<index_t>(),
+              input.data_ptr<scalar_t>(),
+              target.data_ptr<index_t>(),
               weight_.defined() ? weight_.data_ptr<scalar_t>()
                                 : nullptr,
               reduction == at::Reduction::Mean,
@@ -459,14 +459,19 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_2d(
 };
 
 void nll_loss_backward_out_cuda_template(
-    const Tensor& grad_input,
-    const Tensor& grad_output,
-    const Tensor& input,
-    const Tensor& target,
+    const Tensor& grad_input_,
+    const Tensor& grad_output_,
+    const Tensor& input_,
+    const Tensor& target_,
     const Tensor& total_weight,
     const Tensor& weight,
     int64_t reduction,
     int64_t ignore_index) {
+  auto target = *target_.expect_contiguous();
+  auto input = *input_.expect_contiguous();
+  auto grad_input = *grad_input_.expect_contiguous();
+  auto grad_output = *grad_output_.expect_contiguous();
+
   int64_t n_dims = input.dim();
   int64_t n_classes = input.size(-1);
   int64_t batch_size = n_dims == 1 ? 1 : input.size(0);
@@ -508,7 +513,6 @@ void nll_loss_backward_out_cuda_template(
     return;
   }
 
-  auto target_ = target.contiguous();
   TORCH_CHECK(grad_output.numel() == 1);
 
   if (n_dims == 1) {
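
A note on the pattern used above: `expect_contiguous()` returns a `c10::MaybeOwned<Tensor>` (hence the dereference with `*`) that borrows the tensor when it is already contiguous and only materializes a contiguous copy otherwise, so every later use of `input`, `target`, `grad_input`, and `grad_output` can assume densely packed memory. The snippet below is a Python-side illustration of what contiguity means in terms of strides; it is not part of the diff.

import torch

# Python-side illustration (not from the diff): a transposed view shares storage
# but is laid out non-contiguously, so a dense copy is needed before a kernel
# that assumes packed rows can safely index it.
x = torch.arange(12.).reshape(3, 4)
nc = x.t()                       # shape (4, 3), stride (1, 4): non-contiguous view
print(nc.is_contiguous())        # False
print(nc.contiguous().stride())  # (3, 1): freshly packed copy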

test/test_modules.py

Lines changed: 112 additions & 1 deletion
@@ -1,4 +1,5 @@
-from inspect import signature
+from itertools import product
+from inspect import signature, isgenerator
 from copy import deepcopy
 import tempfile
 
@@ -205,6 +206,116 @@ def test_check_inplace(self, device, dtype, module_info):
             output_ip.backward(grad)
             self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 
+    def _traverse_obj(self, obj, func):
+        if isinstance(obj, (tuple, list)):
+            return type(obj)(self._traverse_obj(o, func) for o in obj)
+        elif isgenerator(obj):
+            return tuple(self._traverse_obj(o, func) for o in obj)
+        elif isinstance(obj, dict):
+            return {name: self._traverse_obj(o, func) for name, o in obj.items()}
+        elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
+            return func(obj)
+
+    def _retain_grad(self, obj):
+        # gradients needs to be retained to check for grad. This is useful when
+        # non-leafs are present in the graph.
+        def inner_retain_grad(obj):
+            if obj.requires_grad:
+                obj.retain_grad()
+        self._traverse_obj(obj, inner_retain_grad)
+
+    def _get_grads(self, obj):
+        def inner_get_grad(obj):
+            if obj.requires_grad:
+                return obj.grad
+        return self._traverse_obj(obj, inner_get_grad)
+
+    def _zero_grad(self, obj):
+        def inner_zero_grad(obj):
+            if obj.grad is not None:
+                obj.grad = None
+        self._traverse_obj(obj, inner_zero_grad)
+
+    @modules(module_db)
+    def test_non_contiguous_tensors(self, device, dtype, module_info):
+        # Check modules work with non-contiguous tensors
+
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True)
+
+        def _make_non_contiguous(obj):
+            def inner_make_non_contiguous(obj):
+                # Scalar tensors can not be made non-contiguous
+                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
+                    return obj
+
+                out = torch.repeat_interleave(obj, 2, dim=-1)
+                out = out[..., ::2].detach()
+                out.requires_grad = obj.requires_grad
+                return out
+            return self._traverse_obj(obj, inner_make_non_contiguous)
+
+        def _can_be_noncontiguous(obj):
+            if isinstance(obj, (tuple, list)):
+                return any(_can_be_noncontiguous(o) for o in obj)
+            elif isinstance(obj, dict):
+                return any(_can_be_noncontiguous(o) for o in obj.values())
+            # scalar tensors can not be non-contiguous
+            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
+                return False
+            return True
+
+
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+
+            self._retain_grad((input_args, input_kwargs))
+
+            # === Forward with default input
+            with freeze_rng_state():
+                default_output = m(*input_args, **input_kwargs)
+                grad_output = default_output.clone().detach_().normal_()
+                default_output.backward(grad_output, retain_graph=True)
+
+            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
+            default_param_grad = deepcopy([p.grad for p in m.parameters()])
+
+            # === Construct non-contiguous tensors ===
+            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
+            nc_grad_output = _make_non_contiguous(grad_output)
+
+            # === Compare results with non-contiguous and contiguous tensors ===
+            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
+            grads = [grad_output, nc_grad_output]
+
+            for (in_args, in_kwargs), g_out in product(inputs, grads):
+                g_out_copy = deepcopy(g_out)
+                self._zero_grad((in_args, in_kwargs))
+                self._zero_grad(m.parameters())
+
+                with freeze_rng_state():
+                    out = m(*in_args, **in_kwargs)
+                    out.backward(g_out_copy, retain_graph=True)
+
+                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
+                self.assertEqual(out, default_output)
+                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
+                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)
+
+                param_grad = [p.grad for p in m.parameters()]
+                self.assertEqual(param_grad, default_param_grad)
+
 
     def _test_gradients_helper(self, device, dtype, module_info, check):
         # Check gradients
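
The `_make_non_contiguous` helper above relies on a small trick: `repeat_interleave(obj, 2, dim=-1)` doubles every element along the last dimension and slicing with `[..., ::2]` takes every other column back out, yielding a tensor with the original values but a stride of 2 in that dimension. A standalone sketch of that construction (for illustration, not part of the diff):

import torch

# Standalone sketch of the trick used by _make_non_contiguous in the test above.
obj = torch.randn(4, 10)
out = torch.repeat_interleave(obj, 2, dim=-1)[..., ::2]

print(torch.equal(out, obj))   # True: same values
print(out.is_contiguous())     # False
print(out.stride())            # (20, 2) vs. obj.stride() == (10, 1)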

torch/testing/_internal/common_modules.py

Lines changed: 2 additions & 0 deletions
@@ -195,6 +195,8 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **
 
     cases: List[Tuple[str, dict]] = [
         ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_none', {'reduction': 'none'}),
         ('ignore_index', {'ignore_index': 2}),
         ('weights', {'weight': make_weight(10).abs()}),
         ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),