mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	Fixes #70125. Much of the work was done by #161687. This PR is additional test cleanup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162766 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
		
			
				
	
	
		
			1014 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1014 lines
		
	
	
		
			51 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: nn"]
 | |
| 
 | |
| from itertools import chain, product
 | |
| from inspect import signature, isgenerator
 | |
| from copy import deepcopy
 | |
| import tempfile
 | |
| from operator import methodcaller
 | |
| 
 | |
| import torch
 | |
| 
 | |
| from torch._subclasses.meta_utils import assert_metadata_eq
 | |
| from torch.testing._internal.common_cuda import with_tf32_off
 | |
| from torch.testing._internal.common_device_type import (
 | |
|     instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
 | |
| from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
 | |
| from torch.testing._internal.common_utils import (
 | |
|     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
 | |
|     gradgradcheck, parametrize, wrapSwapTensorsTest, TEST_WITH_ROCM)
 | |
| from unittest.mock import patch, call
 | |
| 
 | |
| 
 | |
if TEST_WITH_ROCM:
    # On ROCm builds, export these MIOpen switches before any test runs.
    # Judging by their names they ask MIOpen to suggest NHWC (channels-last)
    # layouts for convolutions and batch norm -- NOTE(review): confirm against
    # the MIOpen integration in ATen.
    import os
    os.environ.update({
        "PYTORCH_MIOPEN_SUGGEST_NHWC": "1",
        "PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM": "1",
    })
 | |
| 
 | |
| 
 | |
| class TestModule(TestCase):
 | |
|     _do_cuda_memory_leak_check = True
 | |
|     _do_cuda_non_default_stream = True
 | |
|     precision = 1e-5
 | |
|     rel_tol = 1e-5
 | |
| 
 | |
|     def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
 | |
|         # Check device placement and dtype for created parameters and buffers.
 | |
|         # Only verify floating point dtypes since that's what the kwarg or methods
 | |
|         # such as `float()` applies to.
 | |
|         if not isinstance(device, torch.device):
 | |
|             device = torch.device(device)
 | |
| 
 | |
|         def _check_module(items, name, device=device, dtype=dtype):
 | |
|             for item_name, item in items:
 | |
|                 self.assertEqual(
 | |
|                     item.device, device,
 | |
|                     f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
 | |
|                 if item.dtype.is_floating_point:
 | |
|                     self.assertEqual(
 | |
|                         item.dtype, dtype,
 | |
|                         f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}')
 | |
|         _check_module(module.named_parameters(), "Parameter")
 | |
|         _check_module(module.named_buffers(), "Buffer")
 | |
| 
 | |
|     @modules(module_db)
 | |
|     def test_forward(self, device, dtype, module_info, training):
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
|         dtype_to_method_caller = {
 | |
|             torch.float32: methodcaller("float"),
 | |
|             torch.float64: methodcaller("double"),
 | |
|         }
 | |
|         for module_input in module_inputs:
 | |
|             if module_input.forward_input is None:
 | |
|                 continue
 | |
| 
 | |
|             with freeze_rng_state():
 | |
|                 # === Instantiate the module. ===
 | |
|                 args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|                 m = module_cls(*args, **kwargs)
 | |
|                 m.to(device).to(dtype)
 | |
|                 m.train(training)
 | |
| 
 | |
|                 # === Do forward pass. ===
 | |
|                 args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
|                 outputs = m(*args, **kwargs)
 | |
| 
 | |
|                 # === Compare outputs to a reference if one is specified. ===
 | |
|                 # TODO: Handle precision
 | |
|                 reference_fn = module_input.reference_fn
 | |
|                 if reference_fn is not None:
 | |
|                     ref_outputs = reference_fn(m, *args, **kwargs)
 | |
|                     self.assertEqual(outputs, ref_outputs)
 | |
| 
 | |
|                 # === Use the method call and verify the parameters and buffers ===
 | |
|                 if dtype in dtype_to_method_caller:
 | |
|                     dtype_to_method_caller[dtype](m)
 | |
|                     m(*args, **kwargs)
 | |
|                     self._assert_module_parameters_and_buffer_are(m, device, dtype)
 | |
| 
 | |
|     # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
 | |
|     # They should be applied to any created parameters and buffers.
 | |
|     @modules(module_db)
 | |
|     def test_factory_kwargs(self, device, dtype, module_info, training):
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
|         for module_input in module_inputs:
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
| 
 | |
|             # Check if this module creates parameters or registers buffers.
 | |
|             # The mock magic here passes through to the real Parameter / register_buffer
 | |
|             # logic and is only used to check call inputs.
 | |
|             module_creates_params_or_buffers = False
 | |
|             parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
 | |
|             with patch.object(torch.nn.Parameter, '__new__', parameter_new):
 | |
|                 register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
 | |
|                 with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
 | |
|                     m = module_cls(*args, **kwargs)
 | |
|                     m.train(training)
 | |
| 
 | |
|                     # Check if a parameter or buffer was created with a tensor not passed to the constructor.
 | |
|                     constructor_tensors = get_tensors_from(args, kwargs)
 | |
|                     for mock in [parameter_new.mock, register_buffer.mock]:
 | |
|                         for call_args, call_kwargs in mock.call_args_list:
 | |
|                             call_tensors = get_tensors_from(call_args, call_kwargs)
 | |
|                             if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
 | |
|                                 module_creates_params_or_buffers = True
 | |
|                                 break
 | |
| 
 | |
|             if not module_creates_params_or_buffers:
 | |
|                 continue
 | |
| 
 | |
|             # Instantiate module with the factory kwargs.
 | |
|             kwargs.update({
 | |
|                 'device': device,
 | |
|                 'dtype': dtype,
 | |
|             })
 | |
| 
 | |
|             if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
 | |
|                 # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
 | |
|                 uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
 | |
|                 with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
 | |
|                     uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
 | |
|                     with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
 | |
|                         m = module_cls(*args, **kwargs)
 | |
|                         m.train(training)
 | |
|                         uninit_param_new.mock.assert_has_calls(
 | |
|                             [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
 | |
|                         uninit_buffer_new.mock.assert_has_calls(
 | |
|                             [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
 | |
|             else:
 | |
|                 # Check device placement and dtype for created parameters and buffers.
 | |
|                 # Only verify floating point dtypes since that's what the kwarg applies to.
 | |
|                 m = module_cls(*args, **kwargs)
 | |
|                 m.train(training)
 | |
|                 self._assert_module_parameters_and_buffer_are(m, device, dtype)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @modules(module_db)
 | |
|     def test_multiple_device_transfer(self, device, dtype, module_info, training):
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                               requires_grad=False, training=training)
 | |
|         module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
 | |
|                                                            requires_grad=False, training=training)
 | |
|         for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
 | |
|             if module_input_device.forward_input is None:
 | |
|                 continue
 | |
| 
 | |
|             with freeze_rng_state():
 | |
|                 # === Instantiate the module. ===
 | |
|                 args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
 | |
|                 m = module_cls(*args, **kwargs)
 | |
|                 m.to(device).to(dtype)
 | |
|                 m.train(training)
 | |
| 
 | |
|                 # === Do forward pass on GPU ===
 | |
|                 input_device_args = module_input_device.forward_input.args
 | |
|                 input_device_kwargs = module_input_device.forward_input.kwargs
 | |
|                 m(*input_device_args, **input_device_kwargs)
 | |
|                 self._assert_module_parameters_and_buffer_are(m, device, dtype)
 | |
| 
 | |
|                 # === Move to CPU ===
 | |
|                 input_cpu_args = module_input_cpu.forward_input.args
 | |
|                 input_cpu_kwargs = module_input_cpu.forward_input.kwargs
 | |
|                 m.cpu()
 | |
|                 m(*input_cpu_args, **input_cpu_kwargs)
 | |
|                 self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)
 | |
| 
 | |
|                 # === Move back to GPU and forward pass ===
 | |
|                 m.cuda()
 | |
|                 m(*input_device_args, **input_device_kwargs)
 | |
|                 self._assert_module_parameters_and_buffer_are(m, device, dtype)
 | |
| 
 | |
|                 if torch.cuda.device_count() >= 2:
 | |
|                     # === test cross-GPU transfer works
 | |
|                     def _to_device1(objs):
 | |
|                         if isinstance(objs, (tuple, list)):
 | |
|                             return type(objs)(_to_device1(item) for item in objs)
 | |
|                         elif isinstance(objs, dict):
 | |
|                             return {name: _to_device1(item) for name, item in objs.items()}
 | |
|                         elif isinstance(objs, torch.Tensor):
 | |
|                             return objs.cuda(1)
 | |
|                         else:
 | |
|                             return objs
 | |
|                     input_device_1_args = _to_device1(input_device_args)
 | |
|                     input_device_1_kwargs = _to_device1(input_device_kwargs)
 | |
| 
 | |
|                     m.cuda(1)
 | |
|                     with torch.cuda.device(1):
 | |
|                         m(*input_device_1_args, **input_device_1_kwargs)
 | |
|                     self._assert_module_parameters_and_buffer_are(m, torch.device("cuda:1"), dtype)
 | |
| 
 | |
|     @modules(module_db)
 | |
|     def test_repr(self, device, dtype, module_info, training):
 | |
|         # Test module can be represented with repr and str without errors.
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
|         for module_input in module_inputs:
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|             m = module_cls(*args, **kwargs)
 | |
|             m.to(device).to(dtype)
 | |
|             m.train(training)
 | |
| 
 | |
|             # Check that these methods do not raise errors
 | |
|             m.__repr__()
 | |
|             str(m)
 | |
| 
 | |
|     @modules(module_db)
 | |
|     def test_save_load(self, device, dtype, module_info, training):
 | |
|         # Test that module can be pickled and unpickled.
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
|         for module_input in module_inputs:
 | |
|             if module_input.forward_input is None:
 | |
|                 continue
 | |
| 
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
| 
 | |
|             with freeze_rng_state():
 | |
|                 # === Instantiate the module. ===
 | |
|                 args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|                 m = module_cls(*args, **kwargs)
 | |
|                 m.to(device).to(dtype)
 | |
|                 m.train(training)
 | |
|                 sd = m.state_dict()
 | |
| 
 | |
|                 # === Do forward pass. ===
 | |
|                 args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
|                 output = m(*args, **kwargs)
 | |
| 
 | |
|                 # === Check saved/loaded module gives the same output. ===
 | |
|                 with tempfile.TemporaryFile() as f:
 | |
|                     torch.save(m, f)
 | |
|                     f.seek(0)
 | |
|                     # weights_only=False as this is legacy code that saves the model
 | |
|                     m_copy = torch.load(f, weights_only=False)
 | |
|                     output_from_copy = m_copy(*args, **kwargs)
 | |
|                     self.assertEqual(output, output_from_copy)
 | |
| 
 | |
|                 # === Check saved/loaded state_dict are the same (including weights_only load). ===
 | |
|                 with tempfile.TemporaryFile() as f:
 | |
|                     torch.save(sd, f)
 | |
|                     f.seek(0)
 | |
|                     sd_copy = torch.load(f)
 | |
|                     self.assertEqual(sd_copy, sd)
 | |
|                     del sd_copy
 | |
|                     f.seek(0)
 | |
|                     sd_copy_wo = torch.load(f, weights_only=True)
 | |
|                     self.assertEqual(sd_copy_wo, sd)
 | |
| 
 | |
|     @skipMeta
 | |
|     @modules([module_info for module_info in module_db
 | |
|               if 'inplace' in signature(module_info.module_cls).parameters])
 | |
|     def test_check_inplace(self, device, dtype, module_info, training):
 | |
|         # Check if the inplace variant of the module gives the same result as the out of place
 | |
|         # variant.
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=True, training=training)
 | |
|         for module_input in module_inputs:
 | |
|             if module_input.forward_input is None:
 | |
|                 continue
 | |
| 
 | |
|             # === Instantiate the module. ===
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|             m_op = module_cls(*args, **kwargs, inplace=False)
 | |
|             m_op.to(device).to(dtype)
 | |
|             m_op.train(training)
 | |
|             m_inplace = module_cls(*args, **kwargs, inplace=True)
 | |
|             m_inplace.to(device).to(dtype)
 | |
|             m_inplace.train(training)
 | |
| 
 | |
|             # === Inplace modules only supports inplace operations on the first argument ===
 | |
|             input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
| 
 | |
|             # ===  Do not allow the first input to be in input_kwargs ===
 | |
|             forward_sig = signature(m_op).parameters
 | |
|             self.assertGreaterEqual(len(forward_sig), 1)
 | |
|             first_param_name = next(iter(forward_sig.items()))
 | |
|             self.assertNotIn(first_param_name, input_kwargs)
 | |
| 
 | |
|             # === Out of place operation does not write to original tensor ===
 | |
|             self.assertGreaterEqual(len(input_args), 1)
 | |
|             input_version = input_args[0]._version
 | |
|             with freeze_rng_state():
 | |
|                 output_op = m_op(*input_args, **input_kwargs)
 | |
|             self.assertEqual(input_args[0]._version, input_version)
 | |
| 
 | |
|             # === Check that the inplace operation gives the same result ===
 | |
|             input_arg_copy = deepcopy(input_args)
 | |
|             input_arg_clone = tuple(i.clone() for i in input_arg_copy)
 | |
|             input_clone_version = input_arg_clone[0]._version
 | |
|             with freeze_rng_state():
 | |
|                 output_ip = m_inplace(*input_arg_clone, **input_kwargs)
 | |
|             self.assertGreater(input_arg_clone[0]._version, input_clone_version)
 | |
|             self.assertEqual(output_op, output_ip)
 | |
| 
 | |
|             # === Check that the gradients are the same ===
 | |
|             grad = output_op.data.clone().normal_()
 | |
|             output_op.backward(grad)
 | |
|             output_ip.backward(grad)
 | |
|             self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 | |
| 
 | |
|     def _traverse_obj(self, obj, func):
 | |
|         if isinstance(obj, (tuple, list)):
 | |
|             return type(obj)(self._traverse_obj(o, func) for o in obj)
 | |
|         elif isgenerator(obj):
 | |
|             return tuple(self._traverse_obj(o, func) for o in obj)
 | |
|         elif isinstance(obj, dict):
 | |
|             return {name: self._traverse_obj(o, func) for name, o in obj.items()}
 | |
|         elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
 | |
|             return func(obj)
 | |
|         else:
 | |
|             return obj
 | |
| 
 | |
|     def _retain_grad(self, obj):
 | |
|         # gradients needs to be retained to check for grad. This is useful when
 | |
|         # non-leafs are present in the graph.
 | |
|         def inner_retain_grad(obj):
 | |
|             if obj.requires_grad:
 | |
|                 obj.retain_grad()
 | |
|         self._traverse_obj(obj, inner_retain_grad)
 | |
| 
 | |
|     def _get_grads(self, obj):
 | |
|         def inner_get_grad(obj):
 | |
|             if obj.requires_grad:
 | |
|                 return obj.grad
 | |
|         return self._traverse_obj(obj, inner_get_grad)
 | |
| 
 | |
|     def _zero_grad(self, obj):
 | |
|         def inner_zero_grad(obj):
 | |
|             if obj.grad is not None:
 | |
|                 obj.grad = None
 | |
|         self._traverse_obj(obj, inner_zero_grad)
 | |
| 
 | |
    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
        """Check that modules give the same results for non-contiguous tensors.

        For each sample, a reference forward/backward pass is first run on the
        original inputs. Then each combination of {original, non-contiguous}
        inputs and {original, non-contiguous} grad outputs is run. For each
        combination, the outputs, input gradients and parameter gradients are
        compared against the reference.
        """
        # Check modules work with non-contiguous tensors

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)

        def _make_non_contiguous(obj):
            # Return a copy of `obj` in which every non-scalar tensor is replaced by a
            # non-contiguous tensor holding the same values.
            def inner_make_non_contiguous(obj):
                # Scalar tensors can not be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                # Repeat each element twice along the last dim, then keep every second
                # element: same values, but a strided (non-contiguous) view.
                out = torch.repeat_interleave(obj, 2, dim=-1)
                # detach() makes the result a fresh leaf; the requires_grad flag is copied over.
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            # True if `obj` (possibly a nested tuple/list/dict) contains at least one
            # tensor with dim() != 0.
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # scalar tensors can not be non-contiguous
            return isinstance(obj, torch.Tensor) and obj.dim() != 0

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            # Skip samples that contain nothing that could be made non-contiguous.
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # Keep .grad populated on non-leaf inputs too.
            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input
            # Every forward below runs under freeze_rng_state(), presumably so that
            # stochastic modules see the same random stream each time -- see its definition.
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                if isinstance(default_output, torch.Tensor):
                    grad_output = default_output.clone().detach_().normal_()
                    default_output.backward(grad_output, retain_graph=True)
                else:
                    # Several outputs: draw a random grad for each output that requires
                    # grad (None otherwise) and backprop them one at a time.
                    # NOTE(review): if pytree flattens None to no leaves, a None grad placed
                    # before a grad-requiring output shifts the zip() pairing below -- confirm.
                    grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_() if o.requires_grad else None)
                                        for o in default_output)
                    flattened_default_output = torch.utils._pytree.tree_leaves(default_output)
                    flattened_grad_output = torch.utils._pytree.tree_leaves(grad_output)
                    for o, g_o in zip(flattened_default_output, flattened_grad_output):
                        if (o.requires_grad):
                            o.backward(g_o, retain_graph=True)

            # Snapshot the reference gradients before they are cleared in the loop below.
            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            # All four (input, grad_output) contiguity combinations.
            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                # Reset gradients (to None) left over from the previous backward pass.
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    if isinstance(out, torch.Tensor):
                        out.backward(g_out_copy, retain_graph=True)
                    else:
                        flattened_out = torch.utils._pytree.tree_leaves(out)
                        flattened_g_out_copy = torch.utils._pytree.tree_leaves(g_out_copy)
                        for o, g_o in zip(flattened_out, flattened_g_out_copy):
                            if o.requires_grad:
                                o.backward(g_o, retain_graph=True)

                # Input gradients are compared with an absolute tolerance only (rtol=0).
                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)
 | |
| 
 | |
|     def _test_gradients_helper(self, device, dtype, module_info, training, check):
 | |
|         # Check gradients
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=True, training=training)
 | |
|         # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
 | |
|         gradcheck_nondet_tol = 0.0
 | |
|         if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
 | |
|             gradcheck_nondet_tol = module_info.gradcheck_nondet_tol
 | |
| 
 | |
|         for module_input in module_inputs:
 | |
|             if module_input.forward_input is None:
 | |
|                 continue
 | |
| 
 | |
|             # === Instantiate the module. ===
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|             m = module_cls(*args, **kwargs)
 | |
|             m.to(device).to(dtype)
 | |
|             m.train(training)
 | |
| 
 | |
|             params = tuple(m.parameters())
 | |
| 
 | |
|             # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
 | |
|             input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
|             if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
 | |
|                 with torch.no_grad():
 | |
|                     m(*input_args, **input_kwargs)
 | |
| 
 | |
|             # === Perform gradient check on the input_args ===
 | |
|             other_kwargs = {}
 | |
|             kwarg_tensors = []
 | |
|             for name, obj in input_kwargs.items():
 | |
|                 if isinstance(obj, torch.Tensor):
 | |
|                     kwarg_tensors.append((name, obj))
 | |
|                 else:
 | |
|                     other_kwargs[name] = obj
 | |
| 
 | |
|             def fn_to_gradcheck(*flat_input_and_params):
 | |
|                 input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
 | |
|                 new_input_args = input_and_params[:len(input_args)]
 | |
|                 kwarg_args = input_and_params[-len(kwarg_tensors):]
 | |
|                 new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
 | |
| 
 | |
|                 with freeze_rng_state():
 | |
|                     output = m(*new_input_args, **new_kwargs, **other_kwargs)
 | |
|                     output_flattened = torch.utils._pytree.tree_leaves(output)
 | |
|                     return output_flattened
 | |
| 
 | |
|             def do_check(flat_input):
 | |
|                 self.assertTrue(
 | |
|                     check(
 | |
|                         fn_to_gradcheck,
 | |
|                         flat_input,
 | |
|                         nondet_tol=gradcheck_nondet_tol,
 | |
|                         fast_mode=module_info.gradcheck_fast_mode
 | |
|                     ))
 | |
| 
 | |
|             # check total derivative
 | |
|             grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
 | |
|             flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
 | |
|             do_check(flat_input)
 | |
| 
 | |
|             # check partial derivatives
 | |
|             old_params_requires_grad = [p.requires_grad for p in params]
 | |
|             for p in params:
 | |
|                 p.requires_grad = False
 | |
| 
 | |
|             old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
 | |
|             for (_, obj) in kwarg_tensors:
 | |
|                 obj.requires_grad = False
 | |
| 
 | |
|             for p, old in zip(params, old_params_requires_grad):
 | |
|                 p.requires_grad = old
 | |
|                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
 | |
|                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
 | |
|                 do_check(flat_input)
 | |
|                 p.requires_grad = False
 | |
| 
 | |
|             for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
 | |
|                 obj.requires_grad = old
 | |
|                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
 | |
|                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
 | |
|                 do_check(flat_input)
 | |
|                 obj.requires_grad = False
 | |
| 
 | |
|     @modules(module_db, allowed_dtypes=[torch.double])
 | |
|     def test_grad(self, device, dtype, module_info, training):
 | |
|         self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
 | |
| 
 | |
|     @modules([m for m in module_db if m.supports_gradgrad],
 | |
|              allowed_dtypes=[torch.double])
 | |
|     def test_gradgrad(self, device, dtype, module_info, training):
 | |
|         self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @with_tf32_off  # Turn off TF32 to compute at full precision https://github.com/pytorch/pytorch/issues/86798
 | |
|     @toleranceOverride({torch.float32: tol(5e-2, 0),
 | |
|                         torch.float64: tol(4e-4, 0)})
 | |
|     @modules(module_db)
 | |
|     def test_cpu_gpu_parity(self, device, dtype, module_info, training):
 | |
|         # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
 | |
|         # nicer way for eval mode only.
 | |
|         # See https://github.com/pytorch/pytorch/issues/79161
 | |
|         rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
 | |
|         if (module_info.module_cls in rnn_modules
 | |
|                 and not training
 | |
|                 and 'cuda' in device
 | |
|                 and torch.backends.cudnn.enabled):
 | |
|             return
 | |
| 
 | |
|         # Test cpu and gpu results are the same
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
 | |
|                                                            requires_grad=True, training=training)
 | |
| 
 | |
|         def _to_device(obj):
 | |
|             if isinstance(obj, torch.Tensor):
 | |
|                 res = obj.detach().to(device=device)
 | |
|                 res.requires_grad = obj.requires_grad
 | |
|                 return res
 | |
|             elif isinstance(obj, tuple):
 | |
|                 return tuple(_to_device(o) for o in obj)
 | |
|             elif isinstance(obj, dict):
 | |
|                 return {key: _to_device(o) for key, o in obj.items()}
 | |
|             else:
 | |
|                 return deepcopy(obj)
 | |
| 
 | |
|         for module_input in module_inputs_cpu:
 | |
|             # === Move input from cpu to device ===
 | |
|             cpu_forward_args = module_input.forward_input.args
 | |
|             cpu_forward_kwargs = module_input.forward_input.kwargs
 | |
| 
 | |
|             gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))
 | |
| 
 | |
|             self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))
 | |
| 
 | |
|             # === Construct module on cpu and gpu ===
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
| 
 | |
|             cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
 | |
|             cpu_module.train(training)
 | |
|             gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
 | |
|             gpu_module.train(training)
 | |
| 
 | |
|             # === Lazy modules need to see an input to initialize params ===
 | |
|             if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
 | |
|                 with torch.no_grad():
 | |
|                     cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
 | |
|                     gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
 | |
| 
 | |
|             for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
 | |
|                 gpu_p.data.copy_(cpu_p)
 | |
| 
 | |
|             # === Compare forward output between cpu and gpu ===
 | |
|             cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
 | |
|             gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
 | |
| 
 | |
|             self.assertEqual(cpu_outputs, gpu_outputs)
 | |
| 
 | |
|             # === Run backwards on CPU and GPU and compare results ===
 | |
|             def check_backward(cpu_output, gpu_output):
 | |
|                 cpu_grad_output = cpu_output.clone().normal_()
 | |
|                 gpu_grad_output = cpu_grad_output.type_as(gpu_output)
 | |
| 
 | |
|                 cpu_output.backward(cpu_grad_output, retain_graph=True)
 | |
|                 gpu_output.backward(gpu_grad_output, retain_graph=True)
 | |
| 
 | |
|                 cpu_grad_input = self._get_grads(cpu_forward_args)
 | |
|                 gpu_grad_input = self._get_grads(gpu_forward_args)
 | |
|                 self.assertEqual(cpu_grad_input, gpu_grad_input)
 | |
| 
 | |
|                 for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
 | |
|                     self.assertEqual(cpu_p.grad, gpu_p.grad)
 | |
| 
 | |
|                 cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
 | |
|                 gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
 | |
|                 self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)
 | |
| 
 | |
|             for _ in range(5):
 | |
|                 if isinstance(cpu_outputs, torch.Tensor):
 | |
|                     check_backward(cpu_outputs, gpu_outputs)
 | |
|                 else:
 | |
|                     flatten_cpu_outputs = torch.utils._pytree.tree_leaves(cpu_outputs)
 | |
|                     flatten_gpu_outputs = torch.utils._pytree.tree_leaves(gpu_outputs)
 | |
|                     for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
 | |
|                         if cpu_output.requires_grad:
 | |
|                             check_backward(cpu_output, gpu_output)
 | |
| 
 | |
    @with_tf32_off
    @modules(module_db)
    def test_memory_format(self, device, dtype, module_info, training):
        """Check module outputs and gradients under non-default memory formats.

        For each forward input, a reference forward pass is run with the original inputs and
        module. In training mode with differentiable outputs, a reference backward pass is run
        as well. The inputs are then converted to each candidate input memory format, and the
        module to each candidate module memory format. The outputs and the input/parameter
        gradients must match the reference (using ``atol``/``rtol``), and their memory formats
        are validated by ``_check_out_mem_format``.
        """
        # Looser tolerances on CUDA devices reporting compute capability 8.6 or 8.0.
        # NOTE(review): the capability is queried for CUDA device 0, not necessarily `device`.
        is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
                                                     or torch.cuda.get_device_capability(0) == (8, 0))
        # TODO tighten it to a specific module
        atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        module_memformat_affects_out = module_info.module_memformat_affects_out

        # Return (input memory formats, module memory formats) to try. When channels_last
        # (resp. channels_last_3d) is supported, it is added to both lists. channels_last
        # takes precedence if both flags are set.
        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last_3d])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        # (recurses into tuples/lists only; dicts and other containers are not inspected).
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        # Tensors whose rank does not fit mem_format (4 for channels_last, 5 for
        # channels_last_3d) are only cloned. Either way the result is a detached copy that
        # keeps requires_grad. NOTE(review): relies on self._traverse_obj applying the function
        # to each tensor in obj; confirm its container handling.
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                if ((mem_format == torch.channels_last and d != 4)
                   or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj.detach().clone().requires_grad_(obj.requires_grad)
                return obj.clone().to(memory_format=mem_format).detach().requires_grad_(obj.requires_grad)

            return self._traverse_obj(obj, inner_to_mem_format)

        # Assert the expected memory format of each tensor in `output`:
        # - a 4-D (resp. 5-D) tensor must be channels_last (resp. channels_last_3d) when the
        #   input used that format, or when the module did and the ModuleInfo declares that the
        #   module's memory format affects its output (empty tensors are exempt);
        # - every other tensor must be contiguous.
        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last)
                                or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
                elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
                                  or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())
            return self._traverse_obj(output, inner_check_out_mem_format)

        # True for tensors that participate in autograd.
        def _req_grad(t):
            return isinstance(t, torch.Tensor) and t.requires_grad

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # Only positional args decide which extra memory formats are exercised.
            supports_channels_last = _check_dims(module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(supports_channels_last, supports_channels_last_3d)

            # NOTE(review): see freeze_rng_state for its exact RNG save/restore semantics.
            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)
                # === Do backward pass. ===
                # Reference gradients w.r.t. every grad-requiring input and parameter.
                # NOTE(review): ref_grad_outputs / ref_grad_inputs are bound only when this branch
                # runs; the per-format backward below assumes they exist whenever it runs.
                ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(desired_outputs) if _req_grad(t))
                if training and len(ref_diff_outputs) > 0:
                    params = tuple(p for p in m.parameters())
                    ref_diff_inputs = tuple(
                        t
                        for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                        if _req_grad(t)
                    )
                    ref_grad_outputs = tuple(
                        torch.rand_like(t)
                        for t in ref_diff_outputs
                    )
                    ref_grad_inputs = torch.autograd.grad(
                        ref_diff_outputs,
                        ref_diff_inputs,
                        grad_outputs=ref_grad_outputs,
                    )

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    d_args = _to_mem_format(input_mem_format, module_input.forward_input.args)
                    d_kwargs = _to_mem_format(input_mem_format, module_input.forward_input.kwargs)

                    # See https://github.com/pytorch/pytorch/issues/107861
                    # When inductor tests are turned on, the setting of requires_grad will be lost
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_args),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.args),
                    ):
                        t1.requires_grad_(t2.requires_grad)
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_kwargs),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.kwargs),
                    ):
                        t1.requires_grad_(t2.requires_grad)

                    # The converted copies replace the stored forward inputs, so every later
                    # use of module_input (including the next input_mem_format) sees them.
                    module_input.forward_input.args = d_args
                    module_input.forward_input.kwargs = d_kwargs

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                            self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format, module_mem_format)

                        # === Do backward pass. ===
                        diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(outputs) if _req_grad(t))
                        if training and len(diff_outputs) > 0:
                            params = tuple(p for p in m.parameters())
                            diff_inputs = tuple(
                                t
                                for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                                if _req_grad(t)
                            )
                            # Same grad_output values as the reference run, laid out like the
                            # current outputs (empty_like defaults to preserve_format).
                            grad_outputs = tuple(
                                torch.empty_like(t1).copy_(t2)
                                for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
                            )

                            grad_inputs = torch.autograd.grad(
                                diff_outputs,
                                diff_inputs,
                                grad_outputs=grad_outputs,
                            )

                            if (
                                input_mem_format != torch.contiguous_format
                                or module_mem_format != torch.contiguous_format
                            ):
                                self.assertEqual(
                                    grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
                                )

                            # === Check mem format of grad_inputs. ===
                            _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)
 | |
| 
 | |
|     # Test whether train and eval modes differ for each module. Use to verify
 | |
|     # that the ModuleInfo entry flag is correct.
 | |
|     @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
 | |
|     def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
| 
 | |
|         # Run forward inputs through to see if the training flag is accessed during forward.
 | |
|         for module_input in module_inputs:
 | |
|             if module_input.forward_input is None:
 | |
|                 continue
 | |
| 
 | |
|             # === Instantiate the module. ===
 | |
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|             m = module_cls(*args, **kwargs)
 | |
|             m.to(device).to(dtype)
 | |
|             m.train(training)
 | |
| 
 | |
|             # Remove training attribute and see if forward still works.
 | |
|             delattr(m, 'training')
 | |
| 
 | |
|             # === Do forward pass. ===
 | |
|             try:
 | |
|                 args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
|                 m(*args, **kwargs)
 | |
|             except AttributeError as e:
 | |
|                 if "'training'" in str(e):
 | |
|                     self.assertTrue(module_info.train_and_eval_differ,
 | |
|                                     f"The ModuleInfo entry for {module_info.name} has "
 | |
|                                     "train_and_eval_differ=False, but the training mode was found to "
 | |
|                                     "affect the forward pass. Consider setting train_and_eval_differ=True "
 | |
|                                     "for this ModuleInfo entry.")
 | |
|                 else:
 | |
|                     raise e
 | |
| 
 | |
| 
 | |
|     @onlyCPU
 | |
|     @modules(module_db)
 | |
|     def test_device_ctx_init(self, device, dtype, module_info, training):
 | |
|         module_cls = module_info.module_cls
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
|         with torch.device('meta'):
 | |
|             module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
 | |
|                                                                 requires_grad=False, training=training)
 | |
| 
 | |
|         for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
 | |
|             c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
| 
 | |
|             c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
 | |
| 
 | |
|             m_cpu = module_cls(*c_args, **c_kwargs)
 | |
| 
 | |
|             with torch.device('meta'):
 | |
|                 m = module_cls(*c_args_meta, **c_kwargs_meta)
 | |
| 
 | |
|             for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
 | |
|                                          zip(m.buffers(), m_cpu.buffers())):
 | |
|                 if torch.nn.parameter.is_lazy(p_meta):
 | |
|                     continue
 | |
|                 self.assertTrue(p_meta.is_meta)
 | |
|                 assert_metadata_eq(self.assertEqual, p_meta, p_cpu)
 | |
| 
 | |
| 
 | |
|     @modules([module for module in module_db if module.module_error_inputs_func is not None])
 | |
|     def test_errors(self, device, dtype, module_info, training):
 | |
|         module_cls = module_info.module_cls
 | |
|         error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                             requires_grad=False, training=training)
 | |
|         for error_input in error_inputs:
 | |
|             module_input = error_input.module_error_input
 | |
|             c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|             if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
 | |
|                 with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
 | |
|                     m = module_cls(*c_args, **c_kwargs)
 | |
|             elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
 | |
|                 m = module_cls(*c_args, **c_kwargs)
 | |
|                 fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
|                 with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
 | |
|                     m(*fw_args, **fw_kwargs)
 | |
|             else:
 | |
|                 raise NotImplementedError(f"Unknown error type {error_input.error_on}")
 | |
| 
 | |
|     # Only run this test for float32 because the test loops over all the dtypes
 | |
|     @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
 | |
|     @parametrize('swap', [True, False])
 | |
|     @parametrize('set_grad', [True, False])
 | |
|     @wrapSwapTensorsTest()
 | |
|     def test_to(self, device, dtype, module_info, training, swap, set_grad):
 | |
|         module_cls = module_info.module_cls
 | |
|         devices = ['cpu']
 | |
|         if torch.cuda.is_available():
 | |
|             devices += ['cuda']
 | |
|         dtypes = module_info.dtypes
 | |
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | |
|                                                        requires_grad=False, training=training)
 | |
|         torch.__future__.set_swap_module_params_on_conversion(swap)
 | |
| 
 | |
|         for module_input in module_inputs:
 | |
|             c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
|             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | |
| 
 | |
|             m = module_cls(*c_args, **c_kwargs)
 | |
| 
 | |
|             # Avoid using `module.to()` when constructing module since that is the method we are testing
 | |
|             def _to(m, set_grad=False):
 | |
|                 for c in m.children():
 | |
|                     _to(c, set_grad=set_grad)
 | |
|                 for n, p in m.named_parameters(recurse=False):
 | |
|                     new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype))
 | |
|                     setattr(m, n, new_p)
 | |
|                     if set_grad:
 | |
|                         new_p.grad = torch.randn_like(new_p)
 | |
|                 for n, b in m.named_buffers(recurse=False):
 | |
|                     new_b = b.detach().clone().to(device, dtype)
 | |
|                     setattr(m, n, new_b)
 | |
|             _to(m, set_grad=set_grad)
 | |
| 
 | |
|             # Check .to() can be run after forward and backward with swap
 | |
|             has_params = len(list(m.parameters())) > 0
 | |
|             if swap and not set_grad and has_params:
 | |
|                 out = m(*args, **kwargs)
 | |
|                 if isinstance(out, tuple):
 | |
|                     out = out[0]
 | |
|                 out.sum().backward()
 | |
|                 m.to(dtype=torch.half)
 | |
|                 # reset
 | |
|                 m.to(dtype=torch.float32)
 | |
| 
 | |
|             prev_device, prev_dtype = device, dtype
 | |
|             for device_, dtype_ in product(devices, dtypes):
 | |
|                 # if device/dtype do not change, grad.to(device, dtype) is a no-op so
 | |
|                 # swapping will not change ._cdata
 | |
|                 # parameters will be wrapped in an nn.Parameter before swapping
 | |
|                 # which will cause the ._cdata to change
 | |
|                 g_no_swap = device_ == prev_device and dtype_ == prev_dtype
 | |
|                 prev_device, prev_dtype = device_, dtype_
 | |
| 
 | |
|                 p_ids_before = [id(p) for p in m.parameters()]
 | |
|                 p_cdatas_before = [p._cdata for p in m.parameters()]
 | |
|                 if set_grad:
 | |
|                     g_ids_before = [id(p.grad) for p in m.parameters()]
 | |
|                     g_cdatas_before = [p.grad._cdata for p in m.parameters()]
 | |
| 
 | |
|                 m.to(device=device_, dtype=dtype_)
 | |
| 
 | |
|                 self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
 | |
|                 self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
 | |
|                 self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
 | |
|                 p_ids_after = [id(p) for p in m.parameters()]
 | |
|                 p_cdatas_after = [p._cdata for p in m.parameters()]
 | |
| 
 | |
|                 if set_grad:
 | |
|                     self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters()))
 | |
|                     self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters()))
 | |
|                     g_ids_after = [id(p.grad) for p in m.parameters()]
 | |
|                     g_cdatas_after = [p.grad._cdata for p in m.parameters()]
 | |
| 
 | |
|                 if swap:
 | |
|                     # id same, ._cdata differs --> swapped cdata of THPVariable
 | |
|                     self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
 | |
|                     self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
 | |
|                     if set_grad:
 | |
|                         self.assertTrue(
 | |
|                             all(a == b if g_no_swap else a != b for a, b in zip(g_cdatas_before, g_cdatas_after)))
 | |
|                 else:
 | |
|                     # id and _cdata remain the same --> .data setting
 | |
|                     self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after)))
 | |
|                     self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
 | |
|                     if set_grad:
 | |
|                         self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
 | |
|                         self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
 | |
| 
 | |
|     @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
 | |
|     @parametrize('swap', [True, False])
 | |
|     @wrapSwapTensorsTest()
 | |
|     def test_to_empty(self, device, dtype, module_info, swap, training):
 | |
|         module_cls = module_info.module_cls
 | |
| 
 | |
|         with torch.device("meta"):
 | |
|             module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
 | |
|                                                            requires_grad=False, training=training)
 | |
| 
 | |
|         torch.__future__.set_swap_module_params_on_conversion(swap)
 | |
|         device_ = torch.device(device)
 | |
| 
 | |
|         for module_input in module_inputs:
 | |
|             c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | |
| 
 | |
|             with torch.device("meta"):
 | |
|                 m = module_cls(*c_args, **c_kwargs)
 | |
| 
 | |
|             p_ids_before = [id(p) for p in m.parameters()]
 | |
|             p_cdatas_before = [p._cdata for p in m.parameters()]
 | |
|             m.to_empty(device=device_)
 | |
| 
 | |
|             self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
 | |
|             self.assertTrue(all(p.device == device_ for p in m.parameters()))
 | |
|             self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
 | |
|             p_ids_after = [id(p) for p in m.parameters()]
 | |
|             p_cdatas_after = [p._cdata for p in m.parameters()]
 | |
| 
 | |
|             if swap:
 | |
|                 # id same, ._cdata differs --> swapped cdata of THPVariable
 | |
|                 self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
 | |
|                 self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
 | |
|             else:
 | |
|                 # id and ._cdata differ
 | |
|                 # meta and device have different shallow copy types, so this will create a new
 | |
|                 # parameter and assign it to the module
 | |
|                 self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
 | |
|                 self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
 | |
| 
 | |
| 
 | |
# Instantiate device-specific variants of TestModule into this module's globals();
# allow_mps=True opts MPS in as well.
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()
 |