Move functional collectives to the right namespace (#97793)
This moves them from `torch._C._nn` to `torch._C._dist`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97793
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 45acfc8574
Commit: 184bfbc3d7
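Caller-side effect of the namespace move, as a minimal hedged sketch (it mirrors the wrapper updates in the Python hunks further down; only the module path changes, not the operator schemas):

    import torch

    def wait_tensor(tensor):
        # Before this PR the generated binding lived on torch._C._nn:
        #   return torch._C._nn.wait_tensor(tensor)
        # After it, the same generated op is exposed on torch._C._dist:
        return torch._C._dist.wait_tensor(tensor)  # type: ignore[attr-defined]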
@@ -14718,29 +14718,25 @@
 # Collectives
 - func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
-  # This should be changed to distributed but it requires changes all over the place to work
-  python_module: nn
+  python_module: dist
   dispatch:
     CompositeExplicitAutograd: all_reduce
   variants: function

 - func: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
-  # This should be changed to distributed but it requires changes all over the place to work
-  python_module: nn
+  python_module: dist
   dispatch:
     CompositeExplicitAutograd: all_gather_into_tensor
   variants: function

 - func: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
-  # This should be changed to distributed but it requires changes all over the place to work
-  python_module: nn
+  python_module: dist
   dispatch:
     CompositeExplicitAutograd: reduce_scatter_tensor
   variants: function

 - func: wait_tensor(Tensor self) -> Tensor
-  # This should be changed to distributed but it requires changes all over the place to work
-  python_module: nn
+  python_module: dist
   dispatch:
     CompositeExplicitAutograd: wait_tensor
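The `python_module` field is what routes an entry into a particular generated binding file and C extension submodule; with `dist`, these collectives land in `python_dist_functions.cpp` and are registered on `torch._C._dist` (see the codegen and `initDistFunctions` changes below). A rough check on a build that includes this commit (an assumption, not part of the diff):

    import torch

    # The collectives should now be exposed on torch._C._dist rather than torch._C._nn.
    print(hasattr(torch._C, "_dist"))             # expected: True
    print(hasattr(torch._C._dist, "all_reduce"))  # expected: True
    print(hasattr(torch._C._nn, "all_reduce"))    # expected: False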
@@ -260,6 +260,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
     "torch/csrc/autograd/generated/python_nested_functions.cpp",
     "torch/csrc/autograd/generated/python_fft_functions.cpp",
     "torch/csrc/autograd/generated/python_linalg_functions.cpp",
+    "torch/csrc/autograd/generated/python_dist_functions.cpp",
     "torch/csrc/autograd/generated/python_return_types.cpp",
     "torch/csrc/autograd/generated/python_enum_tag.cpp",
     "torch/csrc/autograd/generated/python_sparse_functions.cpp",
@@ -930,6 +930,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
         "torch/csrc/autograd/generated/python_nn_functions.cpp",
         "torch/csrc/autograd/generated/python_fft_functions.cpp",
         "torch/csrc/autograd/generated/python_linalg_functions.cpp",
+        "torch/csrc/autograd/generated/python_dist_functions.cpp",
         "torch/csrc/autograd/generated/python_enum_tag.cpp",
         "torch/csrc/autograd/generated/python_return_types.cpp",
         "torch/csrc/autograd/generated/python_sparse_functions.cpp",
@@ -394,6 +394,7 @@ set(GENERATED_CXX_PYTHON
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nested_functions.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
+    "${TORCH_SRC_DIR}/csrc/autograd/generated/python_dist_functions.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
     )
@@ -149,6 +149,7 @@ def get_generate_code_bin_outs():
         "autograd/generated/python_return_types.cpp": ["autograd/generated/python_return_types.cpp"],
         "autograd/generated/python_sparse_functions.cpp": ["autograd/generated/python_sparse_functions.cpp"],
         "autograd/generated/python_special_functions.cpp": ["autograd/generated/python_special_functions.cpp"],
+        "autograd/generated/python_dist_functions.cpp": ["autograd/generated/python_dist_functions.cpp"],
         "autograd/generated/python_torch_functions_0.cpp": ["autograd/generated/python_torch_functions_0.cpp"],
         "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"],
         "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"],
@@ -135,6 +135,7 @@ def define_tools_targets(
         "autograd/templates/python_return_types.cpp",
         "autograd/templates/python_sparse_functions.cpp",
         "autograd/templates/python_special_functions.cpp",
+        "autograd/templates/python_dist_functions.cpp",
         "autograd/templates/python_torch_functions.cpp",
         "autograd/templates/python_variable_methods.cpp",
         "autograd/templates/variable_factories.h",
@@ -239,6 +239,10 @@ def is_py_special_function(f: NativeFunction) -> bool:
     return f.python_module == "special"


+def is_py_dist_function(f: NativeFunction) -> bool:
+    return f.python_module == "dist"
+
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
 #                              Main Function
@@ -345,6 +349,15 @@ def gen(
         symint=symint,
     )

+    create_python_bindings(
+        fm,
+        functions,
+        is_py_dist_function,
+        "torch.distributed.functional",
+        "python_dist_functions.cpp",
+        method=False,
+    )
+
     # Currently, we only use `functions` to generate `return_types` bindings.
     # All methods which return namedtuple have function variant at this point.
     # If any method only operator with namedtuple is added in the future,
@@ -902,6 +915,7 @@ if(check_has_torch_function(self_)) {{
         "torch.nested": "THPNestedVariableFunctionsModule",
         "torch.sparse": "THPSparseVariableFunctionsModule",
         "torch.special": "THPSpecialVariableFunctionsModule",
+        "torch.distributed.functional": "THPDistVariableFunctionsModule",
     }[module]
     if module
     else "THPVariableClass"
tools/autograd/templates/python_dist_functions.cpp (new file, 68 lines)
@@ -0,0 +1,68 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_dist_functions.h"
+#include "torch/csrc/autograd/python_return_types.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/tensor_memoryformats.h"
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Scalar;
+using at::MemoryFormat;
+using at::Generator;
+using at::IntArrayRef;
+using at::ArrayRef;
+
+using namespace torch::autograd::utils;
+
+namespace torch { namespace autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef dist_functions[] = {
+  ${py_method_defs}
+  {NULL}
+};
+
+static PyObject* THPDistVariableFunctionsModule = NULL;
+
+void initDistFunctions(PyObject* module) {
+  static struct PyModuleDef def = {
+      PyModuleDef_HEAD_INIT,
+      "torch._C._dist",
+      NULL,
+      -1,
+      dist_functions
+  };
+  PyObject* dist = PyModule_Create(&def);
+  THPDistVariableFunctionsModule = dist;
+  if (!dist) {
+    throw python_error();
+  }
+  // steals a reference to dist
+  if (PyModule_AddObject(module, "_dist", dist) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+}} // namespace torch::autograd
@@ -41,6 +41,7 @@
 #include <torch/csrc/TypeInfo.h>
 #include <torch/csrc/api/include/torch/python/init.h>
 #include <torch/csrc/autograd/python_cpp_function.h>
+#include <torch/csrc/autograd/python_dist_functions.h>
 #include <torch/csrc/autograd/python_enum_tag.h>
 #include <torch/csrc/autograd/python_fft_functions.h>
 #include <torch/csrc/autograd/python_function.h>
@@ -1328,6 +1329,7 @@ PyObject* initModule() {
   torch::autograd::initNestedFunctions(module);
   torch::autograd::initSparseFunctions(module);
   torch::autograd::initSpecialFunctions(module);
+  torch::autograd::initDistFunctions(module);
   torch::autograd::init_legacy_variable(module);
   torch::profiler::initPythonBindings(module);
   torch::python::init_bindings(module);
torch/csrc/autograd/python_dist_functions.h (new file, 9 lines)
@@ -0,0 +1,9 @@
+#pragma once
+
+namespace torch {
+namespace autograd {
+
+void initDistFunctions(PyObject* module);
+
+} // namespace autograd
+} // namespace torch
@@ -105,7 +105,7 @@ class AsyncCollectiveTensor(torch.Tensor):
     Use it inside functional collective pytorch wrappers like the following:
     def functional_collective(self, group, tag):
         tag, rankset, group_size = _expand_group(group, tag)
-        tensor = torch._C._nn.{collective}(self, tag, rankset, group_size)
+        tensor = torch._C._dist.{collective}(self, tag, rankset, group_size)
         res = AsyncCollectiveTensor(tensor)
         _register_wrapper_tensor(res, tensor)
         return res
@@ -254,7 +254,7 @@ def wait_tensor(tensor):

     Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
     """
-    return torch._C._nn.wait_tensor(tensor)  # type: ignore[attr-defined]
+    return torch._C._dist.wait_tensor(tensor)  # type: ignore[attr-defined]


 def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
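A minimal sketch of the call-then-wait pattern this docstring describes, using the schemas from the Collectives hunk above; the tag, ranks, and group_size values are placeholders and assume an already-initialized two-rank process group:

    import torch

    # Issue the functional collective via the relocated binding, then wait on its
    # output; waiting blocks on CPU and synchronizes the stream on CUDA.
    out = torch._C._dist.all_reduce(torch.ones(4), "sum", "", [0, 1], 2)  # type: ignore[attr-defined]
    out = torch._C._dist.wait_tensor(out)  # type: ignore[attr-defined]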
@@ -275,7 +275,7 @@ def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str =
     that information and perform collective algebraic optimization. Use other forms of input for that.
     """
     tag, rankset, group_size = _expand_group(group, tag)
-    tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
+    tensor = torch._C._dist.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
     res = AsyncCollectiveTensor(tensor)
     _register_wrapper_tensor(res, tensor)
     return res
@@ -307,7 +307,9 @@ def reduce_scatter_tensor(
     assert (
         self.size(0) % group_size == 0
     ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
-    tensor = torch._C._nn.reduce_scatter_tensor(self, reduceOp, scatter_dim, tag, rankset, group_size)  # type: ignore[attr-defined]
+    tensor = torch._C._dist.reduce_scatter_tensor(  # type: ignore[attr-defined]
+        self, reduceOp, scatter_dim, tag, rankset, group_size
+    )
     res = AsyncCollectiveTensor(tensor)
     _register_wrapper_tensor(res, tensor)
     return res