diff --git a/test/test_modules.py b/test/test_modules.py
index e854eec8add7..601cf5cefdf9 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -983,16 +983,23 @@ class TestModule(TestCase):
             p_ids_after = [id(p) for p in m.parameters()]
             p_cdatas_after = [p._cdata for p in m.parameters()]
 
-            if swap:
-                # id same, ._cdata differs --> swapped cdata of THPVariable
-                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
-                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
-            else:
-                # id and ._cdata differ
-                # meta and device have different shallow copy types, so this will create a new
-                # parameter and assign it to the module
-                self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
-                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
+            # id same, ._cdata differs --> swapped cdata of THPVariable
+            # Technically, meta and device have different shallow copy types, so when swap=False it will create a new
+            # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta.
+            self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
+            self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
+
+            # Test the opposite direction device --> meta
+            m = m.to(device="meta")
+
+            p_ids_after_meta = [id(p) for p in m.parameters()]
+            p_cdatas_after_meta = [p._cdata for p in m.parameters()]
+
+            # id same, ._cdata differs --> swapped cdata of THPVariable
+            # Technically, meta and device have different shallow copy types, so when swap=False it will create a new
+            # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta.
+            self.assertTrue(all(a == b for a, b in zip(p_ids_after, p_ids_after_meta)))
+            self.assertTrue(all(a != b for a, b in zip(p_cdatas_after, p_cdatas_after_meta)))
 
 
 instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 80f7876f28fd..2e65bb97c659 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -798,6 +798,8 @@ class Module:
             return (should_use_swap_tensors
                     # subclasses may have multiple child tensors so we need to use swap_tensors
                     or is_traceable_wrapper_subclass(tensor_applied)
+                    or tensor.device.type == 'meta'
+                    or tensor_applied.device.type == 'meta'
                     or tensor.device.type == 'xla'
                     or tensor_applied.device.type == 'xla')
 