mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Fix mesh.get_local_rank when it is > 1d (#164473)"
This reverts commit 83d71dfb2fd993a6242372b8123549acaa85ffdb. Reverted https://github.com/pytorch/pytorch/pull/164473 on behalf of https://github.com/izaitsevfb due to appears to be causing vision_maskrcnn regression ([comment](https://github.com/pytorch/pytorch/pull/164473#issuecomment-3374738997))
This commit is contained in:
@ -378,7 +378,7 @@ vgg16,pass,0
|
||||
|
||||
|
||||
|
||||
vision_maskrcnn,pass,18
|
||||
vision_maskrcnn,pass,20
|
||||
|
||||
|
||||
|
||||
|
|
@ -286,7 +286,7 @@ vgg16,pass,6
|
||||
|
||||
|
||||
|
||||
vision_maskrcnn,pass,37
|
||||
vision_maskrcnn,pass,39
|
||||
|
||||
|
||||
|
||||
|
|
@ -271,11 +271,7 @@ class DeviceMeshVariable(DistributedVariable):
|
||||
if name == "get_rank":
|
||||
return ConstantVariable.create(self.value.get_rank())
|
||||
if name == "get_local_rank":
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
return ConstantVariable.create(
|
||||
self.value.get_local_rank(*const_args, **const_kwargs)
|
||||
)
|
||||
return ConstantVariable.create(self.value.get_local_rank())
|
||||
if name == "get_group":
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
|
Reference in New Issue
Block a user