Move functional collectives to the right namespace (#97793)

This moves them from `torch._C._nn` to `torch._C._dist`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97793
Approved by: https://github.com/albanD
This commit is contained in:
Rodrigo Kumpera
2023-03-30 22:18:09 +00:00
committed by PyTorch MergeBot
parent 45acfc8574
commit 184bfbc3d7
11 changed files with 108 additions and 12 deletions

View File

@ -239,6 +239,10 @@ def is_py_special_function(f: NativeFunction) -> bool:
return f.python_module == "special"
def is_py_dist_function(f: NativeFunction) -> bool:
return f.python_module == "dist"
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
@ -345,6 +349,15 @@ def gen(
symint=symint,
)
create_python_bindings(
fm,
functions,
is_py_dist_function,
"torch.distributed.functional",
"python_dist_functions.cpp",
method=False,
)
# Currently, we only use `functions` to generate `return_types` bindings.
# All methods which return namedtuple have function variant at this point.
# If any method only operator with namedtuple is added in the future,
@ -902,6 +915,7 @@ if(check_has_torch_function(self_)) {{
"torch.nested": "THPNestedVariableFunctionsModule",
"torch.sparse": "THPSparseVariableFunctionsModule",
"torch.special": "THPSpecialVariableFunctionsModule",
"torch.distributed.functional": "THPDistVariableFunctionsModule",
}[module]
if module
else "THPVariableClass"