[DeviceMesh] Initialized mesh tensor with CPU context (#124767)

This PR makes sure to construct the `DeviceMesh`'s `mesh` tensor on the CPU device in `init_device_mesh()`. This means that we can call `init_device_mesh()` under a meta-device context and still construct the correct `mesh` tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124767
Approved by: https://github.com/wz337
ghstack dependencies: #124651, #124741
This commit is contained in:
Andrew Gu
2024-04-23 12:45:31 -07:00
committed by PyTorch MergeBot
parent 674e15ae07
commit 1db7d64af2
2 changed files with 17 additions and 1 deletions

View File

@ -487,6 +487,19 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param.device, torch.device("meta"))
self._test_to_empty_and_reset_parameters(model, mesh, mlp_dim)
# Test that we can call `fully_shard` under meta-device context and
# that `init_device_mesh` call still works
mlp_dim = 8
with torch.device("meta"):
model = nn.Sequential(MLP(mlp_dim, with_buffer=True), MLP(mlp_dim))
for param in model.parameters():
self.assertEqual(param.device, torch.device("meta"))
for module in (model[0], model[1], model):
fully_shard(module)
for param in model.parameters():
self.assertEqual(param.device, torch.device("meta"))
self._test_to_empty_and_reset_parameters(model, mesh, mlp_dim)
@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_meta_device_2d_init(self):
assert self.world_size >= 4, f"{self.world_size}"

View File

@ -557,6 +557,9 @@ else:
f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
)
# Always initialize the mesh's tensor on CPU, regardless of what the
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
device_mesh = DeviceMesh(
device_type=device_type,