Update nn.Module._apply to not gate on should_use_set_data when swap_tensors is set (#120659)

This updates the nesting of if statements in `nn.Module._apply` such that if

`torch.__future__.set_swap_module_params_on_conversion(True)`, we always try to swap regardless of whether
- `torch._has_compatible_shallow_copy_type(param, fn(param)`
- `torch.__future__.set_overwrite_module_params_on_conversion` is set

This means that `meta_module.to_empty('device')` can now use the swap_tensors path cc @awgu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120659
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-02-26 14:20:21 -08:00
committed by PyTorch MergeBot
parent 213b3ac3f2
commit 677e67c399
4 changed files with 89 additions and 46 deletions

View File

@ -869,7 +869,6 @@ class TestModule(TestCase):
for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
m = module_cls(*c_args, **c_kwargs)
@ -904,7 +903,7 @@ class TestModule(TestCase):
m.to(device=device_, dtype=dtype_)
self.assertTrue(isinstance(p, torch.nn.Parameter) for p in m.parameters())
self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
p_ids_after = [id(p) for p in m.parameters()]
@ -932,6 +931,47 @@ class TestModule(TestCase):
self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@wrapSwapTensorsTest()
def test_to_empty(self, device, dtype, module_info, swap, training):
module_cls = module_info.module_cls
with torch.device("meta"):
module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
requires_grad=False, training=training)
torch.__future__.set_swap_module_params_on_conversion(swap)
device_ = torch.device(device)
for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
with torch.device("meta"):
m = module_cls(*c_args, **c_kwargs)
p_ids_before = [id(p) for p in m.parameters()]
p_cdatas_before = [p._cdata for p in m.parameters()]
m.to_empty(device=device_)
self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
self.assertTrue(all(p.device == device_ for p in m.parameters()))
self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
p_ids_after = [id(p) for p in m.parameters()]
p_cdatas_after = [p._cdata for p in m.parameters()]
if swap:
# id same, ._cdata differs --> swapped cdata of THPVariable
self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
else:
# id and ._cdata differ
# meta and device have different shallow copy types, so this will create a new
# parameter and assign it to the module
self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
if __name__ == '__main__':