Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[DeviceMesh] Initialized mesh tensor with CPU context (#124767)
This PR makes sure to construct the `DeviceMesh`'s `mesh` tensor on the CPU device in `init_device_mesh()`. This means that we can call `init_device_mesh()` under a meta-device context and still construct the correct `mesh` tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124767
Approved by: https://github.com/wz337
ghstack dependencies: #124651, #124741
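A minimal sketch of the behavior this enables (not part of the diff; assumes a 4-rank process group is already initialized, e.g. launched via torchrun, and a CUDA backend is available):

import torch
from torch.distributed.device_mesh import init_device_mesh

# Even when the caller is inside a meta-device context, the mesh's rank-layout
# tensor is now constructed under a CPU context, so it is a real CPU tensor.
with torch.device("meta"):
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

assert mesh.mesh.device.type == "cpu"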
Committed by: PyTorch MergeBot
Parent: 674e15ae07
Commit: 1db7d64af2
@@ -487,6 +487,19 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
             self.assertEqual(param.device, torch.device("meta"))
         self._test_to_empty_and_reset_parameters(model, mesh, mlp_dim)
 
+        # Test that we can call `fully_shard` under meta-device context and
+        # that `init_device_mesh` call still works
+        mlp_dim = 8
+        with torch.device("meta"):
+            model = nn.Sequential(MLP(mlp_dim, with_buffer=True), MLP(mlp_dim))
+        for param in model.parameters():
+            self.assertEqual(param.device, torch.device("meta"))
+        for module in (model[0], model[1], model):
+            fully_shard(module)
+        for param in model.parameters():
+            self.assertEqual(param.device, torch.device("meta"))
+        self._test_to_empty_and_reset_parameters(model, mesh, mlp_dim)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_meta_device_2d_init(self):
         assert self.world_size >= 4, f"{self.world_size}"
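The test above follows the usual FSDP2 meta-device initialization flow. A hedged sketch of that flow outside the test harness (assumes torch.distributed is already initialized, e.g. via torchrun; `MLP` and `_test_to_empty_and_reset_parameters` are test-suite helpers, so plain `nn.Linear` modules and an explicit `to_empty`/`reset_parameters` step stand in for them here):

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # composable FSDP (FSDP2)

# Build the model entirely on the meta device: no real memory is allocated.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Sharding is applied while parameters are still meta tensors; the internal
# init_device_mesh() call works because the mesh tensor is pinned to CPU.
for module in (model[0], model[1], model):
    fully_shard(module)

# Materialize real storage on the target device, then re-run parameter init.
model.to_empty(device="cuda")
for module in model.modules():
    if isinstance(module, nn.Linear):
        module.reset_parameters()

The corresponding change in `init_device_mesh()` wraps the mesh tensor construction in a CPU device context: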
@@ -557,7 +557,10 @@ else:
                 f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
             )
 
-        mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
+        # Always initialize the mesh's tensor on CPU, regardless of what the
+        # external device type has been set to be (e.g. meta)
+        with torch.device("cpu"):
+            mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
         device_mesh = DeviceMesh(
             device_type=device_type,
             mesh=mesh,
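For reference, `torch.device` used as a context manager sets the default device for factory functions, and a nested context overrides an outer one. That is why wrapping the `torch.arange` call pins the mesh tensor to CPU even when `init_device_mesh()` runs inside a meta-device context. A standalone sketch (no distributed setup required):

import torch

with torch.device("meta"):
    meta_t = torch.arange(4)       # default device is meta here
    with torch.device("cpu"):
        cpu_t = torch.arange(4)    # nested context pins this one to CPU

print(meta_t.device)  # meta
print(cpu_t.device)   # cpu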