mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Default meta device to use swap_tensors in nn.Module._apply (.to_empty and .to('meta')) (#126819)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126819 Approved by: https://github.com/albanD ghstack dependencies: #127313, #126814
This commit is contained in:
committed by
PyTorch MergeBot
parent
bfdec93395
commit
fa426b096b
@ -983,16 +983,23 @@ class TestModule(TestCase):
|
||||
p_ids_after = [id(p) for p in m.parameters()]
|
||||
p_cdatas_after = [p._cdata for p in m.parameters()]
|
||||
|
||||
if swap:
|
||||
# id same, ._cdata differs --> swapped cdata of THPVariable
|
||||
self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
|
||||
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
|
||||
else:
|
||||
# id and ._cdata differ
|
||||
# meta and device have different shallow copy types, so this will create a new
|
||||
# parameter and assign it to the module
|
||||
self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
|
||||
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
|
||||
# id same, ._cdata differs --> swapped cdata of THPVariable
|
||||
# Technically, meta and device have different shallow copy types, so when swap=False it will create a new
|
||||
# parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta.
|
||||
self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
|
||||
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
|
||||
|
||||
# Test the opposite direction device --> meta
|
||||
m = m.to(device="meta")
|
||||
|
||||
p_ids_after_meta = [id(p) for p in m.parameters()]
|
||||
p_cdatas_after_meta = [p._cdata for p in m.parameters()]
|
||||
|
||||
# id same, ._cdata differs --> swapped cdata of THPVariable
|
||||
# Technically, meta and device have different shallow copy types, so when swap=False it will create a new
|
||||
# parameter and assign it to the module BUT we opt into swap_tensors when either one is on meta.
|
||||
self.assertTrue(all(a == b for a, b in zip(p_ids_after, p_ids_after_meta)))
|
||||
self.assertTrue(all(a != b for a, b in zip(p_cdatas_after, p_cdatas_after_meta)))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
|
||||
|
Reference in New Issue
Block a user