Default meta device to use swap_tensors in nn.Module._apply (.to_empty and .to('meta')) (#126819)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126819
Approved by: https://github.com/albanD
ghstack dependencies: #127313, #126814
Author:    Mikayla Gawarecki
Date:      2024-05-29 17:14:06 -07:00
Committed: PyTorch MergeBot
Parent:    bfdec93395
Commit:    fa426b096b

2 changed files with 19 additions and 10 deletions
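
In user-facing terms, the change affects parameter identity when a meta-device module is materialized. A minimal sketch of the behavior described in the commit message, assuming a build that includes this change (the module type and sizes are illustrative):

import torch
import torch.nn as nn

# Build the module on the meta device; no real storage is allocated yet.
m = nn.Linear(4, 4, device="meta")
ids_before = [id(p) for p in m.parameters()]

# Materialize on a real device. With this change the conversion goes through
# swap_tensors because the source tensors are on meta, so each Parameter keeps
# its Python identity and only its payload is replaced.
m.to_empty(device="cpu")
ids_after = [id(p) for p in m.parameters()]

assert ids_before == ids_after  # same Python objects, freshly allocated storage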

test/test_modules.py

@@ -983,16 +983,23 @@ class TestModule(TestCase):
        p_ids_after = [id(p) for p in m.parameters()]
        p_cdatas_after = [p._cdata for p in m.parameters()]
        if swap:
            # id same, ._cdata differs --> swapped cdata of THPVariable
            # Technically, meta and device have different shallow copy types, so when swap=False it will create a new
            # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta.
            self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
            self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
        else:
            # id and ._cdata differ
            # meta and device have different shallow copy types, so this will create a new
            # parameter and assign it to the module
            self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
            self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
        # Test the opposite direction device --> meta
        m = m.to(device="meta")
        p_ids_after_meta = [id(p) for p in m.parameters()]
        p_cdatas_after_meta = [p._cdata for p in m.parameters()]
        # id same, ._cdata differs --> swapped cdata of THPVariable
        # Technically, meta and device have different shallow copy types, so when swap=False it will create a new
        # parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta.
        self.assertTrue(all(a == b for a, b in zip(p_ids_after, p_ids_after_meta)))
        self.assertTrue(all(a != b for a, b in zip(p_cdatas_after, p_cdatas_after_meta)))
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
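
The id versus ._cdata assertions distinguish "same Python object with a new payload" (swapped) from "a brand-new Parameter object" (rebound). A small illustration of that distinction using the public torch.utils.swap_tensors helper (the tensor values are illustrative):

import torch

a = torch.zeros(2)
b = torch.ones(2)
id_a = id(a)

# swap_tensors exchanges the underlying payload (the cdata) of two existing
# Python tensor objects rather than creating new wrappers.
torch.utils.swap_tensors(a, b)

assert id(a) == id_a             # the Python object (THPVariable) is unchanged
assert a.tolist() == [1.0, 1.0]  # but it now carries what used to be b's data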

torch/nn/modules/module.py

@@ -798,6 +798,8 @@ class Module:
            return (should_use_swap_tensors
                    # subclasses may have multiple child tensors so we need to use swap_tensors
                    or is_traceable_wrapper_subclass(tensor_applied)
                    or tensor.device.type == 'meta'
                    or tensor_applied.device.type == 'meta'
                    or tensor.device.type == 'xla'
                    or tensor_applied.device.type == 'xla')
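
This hunk is the whole functional change: _apply already had a swap path, but it was taken only with the global torch.__future__ opt-in, for traceable wrapper subclasses, or on XLA; a meta tensor on either side of the conversion now also opts in. A simplified, hypothetical sketch of the two paths per parameter (the helper name apply_to_param is illustrative; the real _apply also handles gradients, subclasses, and error recovery):

import torch
import torch.nn as nn

def apply_to_param(module: nn.Module, name: str, fn) -> None:
    # Hypothetical, simplified analogue of the per-parameter decision in _apply.
    param = module._parameters[name]
    applied = fn(param)
    use_swap = (
        torch.__future__.get_swap_module_params_on_conversion()  # global opt-in
        or param.device.type == "meta"      # new: meta source opts in
        or applied.device.type == "meta"    # new: meta destination opts in
    )
    if use_swap:
        # Keep the existing Parameter object; only its payload is exchanged.
        torch.utils.swap_tensors(
            param, nn.Parameter(applied, requires_grad=param.requires_grad)
        )
    else:
        # Previous default path: rebind the attribute to a brand-new Parameter.
        module._parameters[name] = nn.Parameter(
            applied, requires_grad=param.requires_grad
        )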