Revert "Fix mesh.get_local_rank when it is > 1d (#164473)"

This reverts commit 83d71dfb2fd993a6242372b8123549acaa85ffdb.

Reverted https://github.com/pytorch/pytorch/pull/164473 on behalf of https://github.com/izaitsevfb due to appears to be causing vision_maskrcnn regression ([comment](https://github.com/pytorch/pytorch/pull/164473#issuecomment-3374738997))
This commit is contained in:
PyTorch MergeBot
2025-10-07 00:37:41 +00:00
parent e89d12bf5d
commit afee8062d5
3 changed files with 3 additions and 7 deletions

View File

@ -378,7 +378,7 @@ vgg16,pass,0
vision_maskrcnn,pass,18
vision_maskrcnn,pass,20

1 name accuracy graph_breaks
378
379
380
381
382
383
384

View File

@ -286,7 +286,7 @@ vgg16,pass,6
vision_maskrcnn,pass,37
vision_maskrcnn,pass,39

1 name accuracy graph_breaks
286
287
288
289
290
291
292

View File

@ -271,11 +271,7 @@ class DeviceMeshVariable(DistributedVariable):
if name == "get_rank":
return ConstantVariable.create(self.value.get_rank())
if name == "get_local_rank":
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
return ConstantVariable.create(
self.value.get_local_rank(*const_args, **const_kwargs)
)
return ConstantVariable.create(self.value.get_local_rank())
if name == "get_group":
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}