[DeviceMesh] Add extra check in flatten result cache lookup

[ghstack-poisoned]
This commit is contained in:
fduwjj
2025-09-18 13:16:59 -07:00
parent 0178ff92a6
commit 07891e11a9
2 changed files with 22 additions and 1 deletions

View File

@ -903,6 +903,18 @@ class TestDeviceMeshGetItem(DTensorTestBase):
cp_tp_mesh._flatten("dummy")
self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy")
# Test flatten into an existing mesh_dim_name inside the mesh
with self.assertRaisesRegex(
RuntimeError,
"dp already exists for submesh of the DeviceMes",
):
mesh_3d._flatten("dp")
with self.assertRaisesRegex(
RuntimeError,
"Flatten mesh with mesh_dim_name dp_tp has been created before",
):
mesh_3d["cp", "tp"]._flatten("dp_tp")
@with_comms(eager_init=True)
def test_flatten_mesh_4d(self):
mesh_shape = (2, 2, 2, 1)

View File

@ -185,7 +185,16 @@ if True: # just to temporarily avoid reindentation
root_mesh in self.root_to_flatten_mapping
and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
):
return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
if (
tuple(flatten_dims_in_root)
== self.flatten_name_to_root_dims[root_mesh][mesh_dim_name]
):
return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
else:
raise RuntimeError(
f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, "
f"Please specify another valid mesh_dim_name.",
)
flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())