mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[DeviceMesh] Add extra check in flatten result cache lookup
[ghstack-poisoned]
This commit is contained in:
		@ -903,6 +903,18 @@ class TestDeviceMeshGetItem(DTensorTestBase):
 | 
			
		||||
        cp_tp_mesh._flatten("dummy")
 | 
			
		||||
        self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy")
 | 
			
		||||
 | 
			
		||||
        # Test flatten into an existing mesh_dim_name inside the mesh
 | 
			
		||||
        with self.assertRaisesRegex(
 | 
			
		||||
            RuntimeError,
 | 
			
		||||
            "dp already exists for submesh of the DeviceMes",
 | 
			
		||||
        ):
 | 
			
		||||
            mesh_3d._flatten("dp")
 | 
			
		||||
        with self.assertRaisesRegex(
 | 
			
		||||
            RuntimeError,
 | 
			
		||||
            "Flatten mesh with mesh_dim_name dp_tp has been created before",
 | 
			
		||||
        ):
 | 
			
		||||
            mesh_3d["cp", "tp"]._flatten("dp_tp")
 | 
			
		||||
 | 
			
		||||
    @with_comms(eager_init=True)
 | 
			
		||||
    def test_flatten_mesh_4d(self):
 | 
			
		||||
        mesh_shape = (2, 2, 2, 1)
 | 
			
		||||
 | 
			
		||||
@ -185,7 +185,16 @@ if True:  # just to temporarily avoid reindentation
 | 
			
		||||
                root_mesh in self.root_to_flatten_mapping
 | 
			
		||||
                and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
 | 
			
		||||
            ):
 | 
			
		||||
                return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
 | 
			
		||||
                if (
 | 
			
		||||
                    tuple(flatten_dims_in_root)
 | 
			
		||||
                    == self.flatten_name_to_root_dims[root_mesh][mesh_dim_name]
 | 
			
		||||
                ):
 | 
			
		||||
                    return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
 | 
			
		||||
                else:
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, "
 | 
			
		||||
                        f"Please specify another valid mesh_dim_name.",
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user