|
1 | | -from inspect import signature |
| 1 | +from itertools import product |
| 2 | +from inspect import signature, isgenerator |
2 | 3 | from copy import deepcopy |
3 | 4 | import tempfile |
4 | 5 |
|
@@ -205,6 +206,116 @@ def test_check_inplace(self, device, dtype, module_info): |
205 | 206 | output_ip.backward(grad) |
206 | 207 | self.assertEqual(input_args[0].grad, input_arg_copy[0].grad) |
207 | 208 |
|
| 209 | + def _traverse_obj(self, obj, func): |
| 210 | + if isinstance(obj, (tuple, list)): |
| 211 | + return type(obj)(self._traverse_obj(o, func) for o in obj) |
| 212 | + elif isgenerator(obj): |
| 213 | + return tuple(self._traverse_obj(o, func) for o in obj) |
| 214 | + elif isinstance(obj, dict): |
| 215 | + return {name: self._traverse_obj(o, func) for name, o in obj.items()} |
| 216 | + elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)): |
| 217 | + return func(obj) |
| 218 | + |
| 219 | + def _retain_grad(self, obj): |
| 220 | + # gradients needs to be retained to check for grad. This is useful when |
| 221 | + # non-leafs are present in the graph. |
| 222 | + def inner_retain_grad(obj): |
| 223 | + if obj.requires_grad: |
| 224 | + obj.retain_grad() |
| 225 | + self._traverse_obj(obj, inner_retain_grad) |
| 226 | + |
| 227 | + def _get_grads(self, obj): |
| 228 | + def inner_get_grad(obj): |
| 229 | + if obj.requires_grad: |
| 230 | + return obj.grad |
| 231 | + return self._traverse_obj(obj, inner_get_grad) |
| 232 | + |
| 233 | + def _zero_grad(self, obj): |
| 234 | + def inner_zero_grad(obj): |
| 235 | + if obj.grad is not None: |
| 236 | + obj.grad = None |
| 237 | + self._traverse_obj(obj, inner_zero_grad) |
| 238 | + |
    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info):
        # Check modules work with non-contiguous tensors
        #
        # Strategy: run forward/backward once with the original (contiguous)
        # inputs as the reference, then re-run every combination of
        # {contiguous, non-contiguous} inputs x {contiguous, non-contiguous}
        # grad_output and assert that outputs and all gradients match the
        # reference run.

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True)

        def _make_non_contiguous(obj):
            # Replace every non-scalar tensor in `obj` with a value-equal but
            # non-contiguous copy (repeat along the last dim, then take every
            # other element so the result has stride 2 in that dim).
            def inner_make_non_contiguous(obj):
                # Scalar tensors can not be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                # detach() drops requires_grad, so restore it from the source.
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            # True iff `obj` contains at least one tensor that can actually be
            # made non-contiguous (i.e. at least one non-scalar tensor).
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # scalar tensors can not be non-contiguous
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return False
            return True


        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            # Skip sample inputs where nothing could be made non-contiguous.
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            # Retain grads so non-leaf tensors in the inputs also expose .grad.
            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input
            # freeze_rng_state makes stochastic modules reproducible across
            # the reference run and the comparison runs below.
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                # NOTE(review): assumes the module returns a single tensor —
                # .clone()/.backward(grad) would not work on tuple outputs.
                grad_output = default_output.clone().detach_().normal_()
                default_output.backward(grad_output, retain_graph=True)

            # deepcopy the reference grads: the backward passes in the loop
            # below mutate .grad in place.
            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            # All four (input, grad_output) combinations must reproduce the
            # reference output and gradients.
            for (in_args, in_kwargs), g_out in product(inputs, grads):
                # Copy so backward() never consumes a shared grad tensor.
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    out.backward(g_out_copy, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                # Loose absolute tolerance: non-contiguous kernels may take
                # different code paths with slightly different rounding.
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)
208 | 319 |
|
209 | 320 | def _test_gradients_helper(self, device, dtype, module_info, check): |
210 | 321 | # Check gradients |
|
0 commit comments