mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix mesh.get_local_rank when it is > 1d (#164473)
Previously, we would not take the arguments passed by get_local_rank into account. This means that we wouldn't be able to trace this call if we had a device_mesh > 1d Pull Request resolved: https://github.com/pytorch/pytorch/pull/164473 Approved by: https://github.com/xmfan, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
5103ecc5d8
commit
83d71dfb2f
@ -378,7 +378,7 @@ vgg16,pass,0
|
||||
|
||||
|
||||
|
||||
vision_maskrcnn,pass,20
|
||||
vision_maskrcnn,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
@ -286,7 +286,7 @@ vgg16,pass,6
|
||||
|
||||
|
||||
|
||||
vision_maskrcnn,pass,39
|
||||
vision_maskrcnn,pass,37
|
||||
|
||||
|
||||
|
||||
|
|
@ -274,7 +274,11 @@ class DeviceMeshVariable(DistributedVariable):
|
||||
if name == "get_rank":
|
||||
return ConstantVariable.create(self.value.get_rank())
|
||||
if name == "get_local_rank":
|
||||
return ConstantVariable.create(self.value.get_local_rank())
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
return ConstantVariable.create(
|
||||
self.value.get_local_rank(*const_args, **const_kwargs)
|
||||
)
|
||||
if name == "get_group":
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
|
Reference in New Issue
Block a user