mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][device_mesh] Support mesh_dim_names (#164200)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164200 Approved by: https://github.com/SherlockNoMad, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
7afcb030d8
commit
5274753873
@ -127,6 +127,8 @@ class GraphModule(torch.nn.Module):
|
||||
def fn(x):
|
||||
local_rank = device_mesh.get_local_rank()
|
||||
global_rank = device_mesh.get_rank()
|
||||
if "dp" not in device_mesh.mesh_dim_names:
|
||||
x = x * 2
|
||||
return x + local_rank + global_rank
|
||||
|
||||
x = torch.ones(10)
|
||||
|
@ -251,6 +251,11 @@ class DeviceMeshVariable(DistributedVariable):
|
||||
return ConstantVariable.create(self.value.ndim)
|
||||
if name == "device_type":
|
||||
return ConstantVariable.create(self.value.device_type)
|
||||
if name == "mesh_dim_names":
|
||||
source = self.source
|
||||
if source:
|
||||
source = AttrSource(base=source, member="mesh_dim_names")
|
||||
return VariableTracker.build(tx, self.value.mesh_dim_names, source)
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
|
Reference in New Issue
Block a user