[dynamo][device_mesh] Support mesh_dim_names (#164200)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164200
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel
This commit is contained in:
Animesh Jain
2025-09-29 19:33:27 -07:00
committed by PyTorch MergeBot
parent 7afcb030d8
commit 5274753873
2 changed files with 7 additions and 0 deletions

View File

@ -127,6 +127,8 @@ class GraphModule(torch.nn.Module):
def fn(x):
local_rank = device_mesh.get_local_rank()
global_rank = device_mesh.get_rank()
if "dp" not in device_mesh.mesh_dim_names:
x = x * 2
return x + local_rank + global_rank
x = torch.ones(10)

View File

@ -251,6 +251,11 @@ class DeviceMeshVariable(DistributedVariable):
return ConstantVariable.create(self.value.ndim)
if name == "device_type":
return ConstantVariable.create(self.value.device_type)
if name == "mesh_dim_names":
source = self.source
if source:
source = AttrSource(base=source, member="mesh_dim_names")
return VariableTracker.build(tx, self.value.mesh_dim_names, source)
return super().var_getattr(tx, name)
def call_method(