Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Update nn.Module._apply to not gate on should_use_set_data when swap_tensors is set (#120659)
This updates the nesting of if statements in `nn.Module._apply` such that if `torch.__future__.set_swap_module_params_on_conversion(True)` is set, we always try to swap, regardless of whether

- `torch._has_compatible_shallow_copy_type(param, fn(param))` returns True
- `torch.__future__.set_overwrite_module_params_on_conversion` is set

This means that `meta_module.to_empty(device=...)` can now use the swap_tensors path.

cc @awgu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120659
Approved by: https://github.com/albanD
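For illustration, here is a minimal sketch (not part of the PR; `nn.Linear` is just a stand-in module) of the behavior this enables, assuming a PyTorch build that includes this change:

```python
import torch
import torch.nn as nn

# With the swap_tensors path, to_empty() keeps the same Parameter Python objects
# and only replaces their underlying storage with uninitialized memory on the
# target device.
torch.__future__.set_swap_module_params_on_conversion(True)

with torch.device("meta"):
    m = nn.Linear(4, 4)  # parameters are allocated on the meta device

ids_before = [id(p) for p in m.parameters()]
m.to_empty(device="cpu")  # materialize (uninitialized) storage on CPU
ids_after = [id(p) for p in m.parameters()]

assert all(a == b for a, b in zip(ids_before, ids_after))  # same Python objects
assert all(p.device.type == "cpu" for p in m.parameters())
```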
Committed by: PyTorch MergeBot
Parent: 213b3ac3f2
Commit: 677e67c399
@@ -869,7 +869,6 @@ class TestModule(TestCase):
         for module_input in module_inputs:
             c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

             m = module_cls(*c_args, **c_kwargs)
@@ -904,7 +903,7 @@ class TestModule(TestCase):
             m.to(device=device_, dtype=dtype_)

-            self.assertTrue(isinstance(p, torch.nn.Parameter) for p in m.parameters())
+            self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
             self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
             self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
             p_ids_after = [id(p) for p in m.parameters()]
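The one-line change above fixes a vacuous assertion: a generator object passed directly to `assertTrue` is always truthy, so the old check could never fail. A minimal illustration (not part of the diff):

```python
# bool() of a generator object is True regardless of its elements,
# so assertTrue(gen) always passes; all(...) actually evaluates them.
gen = (isinstance(x, str) for x in [1, 2, 3])
assert bool(gen) is True
assert all(isinstance(x, str) for x in [1, 2, 3]) is False
```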
@@ -932,6 +931,47 @@ class TestModule(TestCase):
             self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))

+    @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
+    @parametrize('swap', [True, False])
+    @wrapSwapTensorsTest()
+    def test_to_empty(self, device, dtype, module_info, swap, training):
+        module_cls = module_info.module_cls
+
+        with torch.device("meta"):
+            module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
+                                                           requires_grad=False, training=training)
+
+        torch.__future__.set_swap_module_params_on_conversion(swap)
+        device_ = torch.device(device)
+
+        for module_input in module_inputs:
+            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+
+            with torch.device("meta"):
+                m = module_cls(*c_args, **c_kwargs)
+
+            p_ids_before = [id(p) for p in m.parameters()]
+            p_cdatas_before = [p._cdata for p in m.parameters()]
+            m.to_empty(device=device_)
+
+            self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
+            self.assertTrue(all(p.device == device_ for p in m.parameters()))
+            self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
+            p_ids_after = [id(p) for p in m.parameters()]
+            p_cdatas_after = [p._cdata for p in m.parameters()]
+
+            if swap:
+                # id same, ._cdata differs --> swapped cdata of THPVariable
+                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
+                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
+            else:
+                # id and ._cdata differ
+                # meta and device have different shallow copy types, so this will create a new
+                # parameter and assign it to the module
+                self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
+                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
+

 instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

 if __name__ == '__main__':
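The two branches of the new test distinguish the swap path from the default path: with swapping enabled, `_apply` exchanges the C++ payload of each existing Parameter in place (the comment's "swapped cdata of THPVariable"), so the Python objects survive. A minimal sketch of that primitive, assuming `torch.utils.swap_tensors` is available (recent PyTorch):

```python
import torch

a = torch.ones(3)
b = torch.zeros(3)
id_a, cdata_a = id(a), a._cdata

# Exchange the payloads of the two tensors without creating new Python objects.
torch.utils.swap_tensors(a, b)

assert id(a) == id_a           # a is still the same Python object...
assert a._cdata != cdata_a     # ...but it now points at the other TensorImpl
assert torch.equal(a, torch.zeros(3))
```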