Fix mesh.get_local_rank when it is > 1d (#164473)

Previously, we would not take the arguments passed by get_local_rank into account. This means that we wouldn't be able to trace this call if we had a device_mesh > 1d

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164473
Approved by: https://github.com/xmfan, https://github.com/Skylion007
This commit is contained in:
Francisco Massa
2025-10-04 11:27:53 +00:00
committed by PyTorch MergeBot
parent 5103ecc5d8
commit 83d71dfb2f
3 changed files with 7 additions and 3 deletions

View File

@ -378,7 +378,7 @@ vgg16,pass,0
vision_maskrcnn,pass,20
vision_maskrcnn,pass,18

1 name accuracy graph_breaks
378
379
380
381
382
383
384

View File

@ -286,7 +286,7 @@ vgg16,pass,6
vision_maskrcnn,pass,39
vision_maskrcnn,pass,37

1 name accuracy graph_breaks
286
287
288
289
290
291
292

View File

@ -274,7 +274,11 @@ class DeviceMeshVariable(DistributedVariable):
if name == "get_rank":
return ConstantVariable.create(self.value.get_rank())
if name == "get_local_rank":
return ConstantVariable.create(self.value.get_local_rank())
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
return ConstantVariable.create(
self.value.get_local_rank(*const_args, **const_kwargs)
)
if name == "get_group":
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}