mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Revert "Move functional collectives to the right namespace (#97793)"
This reverts commit 184bfbc3d7b37e8f202f4938f6ea9ba557c93b1e. Reverted https://github.com/pytorch/pytorch/pull/97793 on behalf of https://github.com/atalman due to breaks internal builds
This commit is contained in:
		@ -14718,25 +14718,29 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Collectives
 | 
					# Collectives
 | 
				
			||||||
- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
 | 
					- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
 | 
				
			||||||
  python_module: dist
 | 
					  # This should be changed to distributed but it requires changes all over the place to work
 | 
				
			||||||
 | 
					  python_module: nn
 | 
				
			||||||
  dispatch:
 | 
					  dispatch:
 | 
				
			||||||
    CompositeExplicitAutograd: all_reduce
 | 
					    CompositeExplicitAutograd: all_reduce
 | 
				
			||||||
  variants: function
 | 
					  variants: function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- func: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
 | 
					- func: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
 | 
				
			||||||
  python_module: dist
 | 
					  # This should be changed to distributed but it requires changes all over the place to work
 | 
				
			||||||
 | 
					  python_module: nn
 | 
				
			||||||
  dispatch:
 | 
					  dispatch:
 | 
				
			||||||
    CompositeExplicitAutograd: all_gather_into_tensor
 | 
					    CompositeExplicitAutograd: all_gather_into_tensor
 | 
				
			||||||
  variants: function
 | 
					  variants: function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- func: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
 | 
					- func: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
 | 
				
			||||||
  python_module: dist
 | 
					  # This should be changed to distributed but it requires changes all over the place to work
 | 
				
			||||||
 | 
					  python_module: nn
 | 
				
			||||||
  dispatch:
 | 
					  dispatch:
 | 
				
			||||||
    CompositeExplicitAutograd: reduce_scatter_tensor
 | 
					    CompositeExplicitAutograd: reduce_scatter_tensor
 | 
				
			||||||
  variants: function
 | 
					  variants: function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- func: wait_tensor(Tensor self) -> Tensor
 | 
					- func: wait_tensor(Tensor self) -> Tensor
 | 
				
			||||||
  python_module: dist
 | 
					  # This should be changed to distributed but it requires changes all over the place to work
 | 
				
			||||||
 | 
					  python_module: nn
 | 
				
			||||||
  dispatch:
 | 
					  dispatch:
 | 
				
			||||||
    CompositeExplicitAutograd: wait_tensor
 | 
					    CompositeExplicitAutograd: wait_tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -260,7 +260,6 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
 | 
				
			|||||||
    "torch/csrc/autograd/generated/python_nested_functions.cpp",
 | 
					    "torch/csrc/autograd/generated/python_nested_functions.cpp",
 | 
				
			||||||
    "torch/csrc/autograd/generated/python_fft_functions.cpp",
 | 
					    "torch/csrc/autograd/generated/python_fft_functions.cpp",
 | 
				
			||||||
    "torch/csrc/autograd/generated/python_linalg_functions.cpp",
 | 
					    "torch/csrc/autograd/generated/python_linalg_functions.cpp",
 | 
				
			||||||
    "torch/csrc/autograd/generated/python_dist_functions.cpp",
 | 
					 | 
				
			||||||
    "torch/csrc/autograd/generated/python_return_types.cpp",
 | 
					    "torch/csrc/autograd/generated/python_return_types.cpp",
 | 
				
			||||||
    "torch/csrc/autograd/generated/python_enum_tag.cpp",
 | 
					    "torch/csrc/autograd/generated/python_enum_tag.cpp",
 | 
				
			||||||
    "torch/csrc/autograd/generated/python_sparse_functions.cpp",
 | 
					    "torch/csrc/autograd/generated/python_sparse_functions.cpp",
 | 
				
			||||||
 | 
				
			|||||||
@ -930,7 +930,6 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
 | 
				
			|||||||
        "torch/csrc/autograd/generated/python_nn_functions.cpp",
 | 
					        "torch/csrc/autograd/generated/python_nn_functions.cpp",
 | 
				
			||||||
        "torch/csrc/autograd/generated/python_fft_functions.cpp",
 | 
					        "torch/csrc/autograd/generated/python_fft_functions.cpp",
 | 
				
			||||||
        "torch/csrc/autograd/generated/python_linalg_functions.cpp",
 | 
					        "torch/csrc/autograd/generated/python_linalg_functions.cpp",
 | 
				
			||||||
        "torch/csrc/autograd/generated/python_dist_functions.cpp",
 | 
					 | 
				
			||||||
        "torch/csrc/autograd/generated/python_enum_tag.cpp",
 | 
					        "torch/csrc/autograd/generated/python_enum_tag.cpp",
 | 
				
			||||||
        "torch/csrc/autograd/generated/python_return_types.cpp",
 | 
					        "torch/csrc/autograd/generated/python_return_types.cpp",
 | 
				
			||||||
        "torch/csrc/autograd/generated/python_sparse_functions.cpp",
 | 
					        "torch/csrc/autograd/generated/python_sparse_functions.cpp",
 | 
				
			||||||
 | 
				
			|||||||
@ -394,7 +394,6 @@ set(GENERATED_CXX_PYTHON
 | 
				
			|||||||
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nested_functions.cpp"
 | 
					  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nested_functions.cpp"
 | 
				
			||||||
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
 | 
					  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
 | 
				
			||||||
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
 | 
					  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
 | 
				
			||||||
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_dist_functions.cpp"
 | 
					 | 
				
			||||||
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
 | 
					  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
 | 
				
			||||||
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
 | 
					  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
 | 
				
			|||||||
@ -149,7 +149,6 @@ def get_generate_code_bin_outs():
 | 
				
			|||||||
            "autograd/generated/python_return_types.cpp": ["autograd/generated/python_return_types.cpp"],
 | 
					            "autograd/generated/python_return_types.cpp": ["autograd/generated/python_return_types.cpp"],
 | 
				
			||||||
            "autograd/generated/python_sparse_functions.cpp": ["autograd/generated/python_sparse_functions.cpp"],
 | 
					            "autograd/generated/python_sparse_functions.cpp": ["autograd/generated/python_sparse_functions.cpp"],
 | 
				
			||||||
            "autograd/generated/python_special_functions.cpp": ["autograd/generated/python_special_functions.cpp"],
 | 
					            "autograd/generated/python_special_functions.cpp": ["autograd/generated/python_special_functions.cpp"],
 | 
				
			||||||
            "autograd/generated/python_dist_functions.cpp": ["autograd/generated/python_dist_functions.cpp"],
 | 
					 | 
				
			||||||
            "autograd/generated/python_torch_functions_0.cpp": ["autograd/generated/python_torch_functions_0.cpp"],
 | 
					            "autograd/generated/python_torch_functions_0.cpp": ["autograd/generated/python_torch_functions_0.cpp"],
 | 
				
			||||||
            "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"],
 | 
					            "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"],
 | 
				
			||||||
            "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"],
 | 
					            "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"],
 | 
				
			||||||
 | 
				
			|||||||
@ -135,7 +135,6 @@ def define_tools_targets(
 | 
				
			|||||||
            "autograd/templates/python_return_types.cpp",
 | 
					            "autograd/templates/python_return_types.cpp",
 | 
				
			||||||
            "autograd/templates/python_sparse_functions.cpp",
 | 
					            "autograd/templates/python_sparse_functions.cpp",
 | 
				
			||||||
            "autograd/templates/python_special_functions.cpp",
 | 
					            "autograd/templates/python_special_functions.cpp",
 | 
				
			||||||
            "autograd/templates/python_dist_functions.cpp",
 | 
					 | 
				
			||||||
            "autograd/templates/python_torch_functions.cpp",
 | 
					            "autograd/templates/python_torch_functions.cpp",
 | 
				
			||||||
            "autograd/templates/python_variable_methods.cpp",
 | 
					            "autograd/templates/python_variable_methods.cpp",
 | 
				
			||||||
            "autograd/templates/variable_factories.h",
 | 
					            "autograd/templates/variable_factories.h",
 | 
				
			||||||
 | 
				
			|||||||
@ -239,10 +239,6 @@ def is_py_special_function(f: NativeFunction) -> bool:
 | 
				
			|||||||
    return f.python_module == "special"
 | 
					    return f.python_module == "special"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_py_dist_function(f: NativeFunction) -> bool:
 | 
					 | 
				
			||||||
    return f.python_module == "dist"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | 
					# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
#                            Main Function
 | 
					#                            Main Function
 | 
				
			||||||
@ -349,15 +345,6 @@ def gen(
 | 
				
			|||||||
        symint=symint,
 | 
					        symint=symint,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    create_python_bindings(
 | 
					 | 
				
			||||||
        fm,
 | 
					 | 
				
			||||||
        functions,
 | 
					 | 
				
			||||||
        is_py_dist_function,
 | 
					 | 
				
			||||||
        "torch.distributed.functional",
 | 
					 | 
				
			||||||
        "python_dist_functions.cpp",
 | 
					 | 
				
			||||||
        method=False,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Currently, we only use `functions` to generate `return_types` bindings.
 | 
					    # Currently, we only use `functions` to generate `return_types` bindings.
 | 
				
			||||||
    # All methods which return namedtuple have function variant at this point.
 | 
					    # All methods which return namedtuple have function variant at this point.
 | 
				
			||||||
    # If any method only operator with namedtuple is added in the future,
 | 
					    # If any method only operator with namedtuple is added in the future,
 | 
				
			||||||
@ -915,7 +902,6 @@ if(check_has_torch_function(self_)) {{
 | 
				
			|||||||
            "torch.nested": "THPNestedVariableFunctionsModule",
 | 
					            "torch.nested": "THPNestedVariableFunctionsModule",
 | 
				
			||||||
            "torch.sparse": "THPSparseVariableFunctionsModule",
 | 
					            "torch.sparse": "THPSparseVariableFunctionsModule",
 | 
				
			||||||
            "torch.special": "THPSpecialVariableFunctionsModule",
 | 
					            "torch.special": "THPSpecialVariableFunctionsModule",
 | 
				
			||||||
            "torch.distributed.functional": "THPDistVariableFunctionsModule",
 | 
					 | 
				
			||||||
        }[module]
 | 
					        }[module]
 | 
				
			||||||
        if module
 | 
					        if module
 | 
				
			||||||
        else "THPVariableClass"
 | 
					        else "THPVariableClass"
 | 
				
			||||||
 | 
				
			|||||||
@ -1,68 +0,0 @@
 | 
				
			|||||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 | 
					 | 
				
			||||||
// ${generated_comment}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "torch/csrc/Device.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/DynamicTypes.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/Exceptions.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/autograd/python_dist_functions.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/autograd/python_return_types.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/autograd/python_variable.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/autograd/utils/wrap_outputs.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/utils/pycfunction_helpers.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/utils/python_arg_parser.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/utils/structseq.h"
 | 
					 | 
				
			||||||
#include "torch/csrc/utils/tensor_memoryformats.h"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#ifndef AT_PER_OPERATOR_HEADERS
 | 
					 | 
				
			||||||
#include <ATen/Functions.h>
 | 
					 | 
				
			||||||
#else
 | 
					 | 
				
			||||||
$ops_headers
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
using at::Tensor;
 | 
					 | 
				
			||||||
using at::Scalar;
 | 
					 | 
				
			||||||
using at::MemoryFormat;
 | 
					 | 
				
			||||||
using at::Generator;
 | 
					 | 
				
			||||||
using at::IntArrayRef;
 | 
					 | 
				
			||||||
using at::ArrayRef;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
using namespace torch::autograd::utils;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace torch { namespace autograd {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// generated forward declarations start here
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
${py_forwards}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static PyMethodDef dist_functions[] = {
 | 
					 | 
				
			||||||
  ${py_method_defs}
 | 
					 | 
				
			||||||
  {NULL}
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static PyObject* THPDistVariableFunctionsModule = NULL;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void initDistFunctions(PyObject* module) {
 | 
					 | 
				
			||||||
  static struct PyModuleDef def = {
 | 
					 | 
				
			||||||
     PyModuleDef_HEAD_INIT,
 | 
					 | 
				
			||||||
     "torch._C._dist",
 | 
					 | 
				
			||||||
     NULL,
 | 
					 | 
				
			||||||
     -1,
 | 
					 | 
				
			||||||
     dist_functions
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
  PyObject* dist = PyModule_Create(&def);
 | 
					 | 
				
			||||||
  THPDistVariableFunctionsModule = dist;
 | 
					 | 
				
			||||||
  if (!dist) {
 | 
					 | 
				
			||||||
    throw python_error();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  // steals a reference to dist
 | 
					 | 
				
			||||||
  if (PyModule_AddObject(module, "_dist", dist) != 0) {
 | 
					 | 
				
			||||||
    throw python_error();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// generated methods start here
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
${py_methods}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}} // namespace torch::autograd
 | 
					 | 
				
			||||||
@ -41,7 +41,6 @@
 | 
				
			|||||||
#include <torch/csrc/TypeInfo.h>
 | 
					#include <torch/csrc/TypeInfo.h>
 | 
				
			||||||
#include <torch/csrc/api/include/torch/python/init.h>
 | 
					#include <torch/csrc/api/include/torch/python/init.h>
 | 
				
			||||||
#include <torch/csrc/autograd/python_cpp_function.h>
 | 
					#include <torch/csrc/autograd/python_cpp_function.h>
 | 
				
			||||||
#include <torch/csrc/autograd/python_dist_functions.h>
 | 
					 | 
				
			||||||
#include <torch/csrc/autograd/python_enum_tag.h>
 | 
					#include <torch/csrc/autograd/python_enum_tag.h>
 | 
				
			||||||
#include <torch/csrc/autograd/python_fft_functions.h>
 | 
					#include <torch/csrc/autograd/python_fft_functions.h>
 | 
				
			||||||
#include <torch/csrc/autograd/python_function.h>
 | 
					#include <torch/csrc/autograd/python_function.h>
 | 
				
			||||||
@ -1329,7 +1328,6 @@ PyObject* initModule() {
 | 
				
			|||||||
  torch::autograd::initNestedFunctions(module);
 | 
					  torch::autograd::initNestedFunctions(module);
 | 
				
			||||||
  torch::autograd::initSparseFunctions(module);
 | 
					  torch::autograd::initSparseFunctions(module);
 | 
				
			||||||
  torch::autograd::initSpecialFunctions(module);
 | 
					  torch::autograd::initSpecialFunctions(module);
 | 
				
			||||||
  torch::autograd::initDistFunctions(module);
 | 
					 | 
				
			||||||
  torch::autograd::init_legacy_variable(module);
 | 
					  torch::autograd::init_legacy_variable(module);
 | 
				
			||||||
  torch::profiler::initPythonBindings(module);
 | 
					  torch::profiler::initPythonBindings(module);
 | 
				
			||||||
  torch::python::init_bindings(module);
 | 
					  torch::python::init_bindings(module);
 | 
				
			||||||
 | 
				
			|||||||
@ -1,9 +0,0 @@
 | 
				
			|||||||
#pragma once
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace torch {
 | 
					 | 
				
			||||||
namespace autograd {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void initDistFunctions(PyObject* module);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
} // namespace autograd
 | 
					 | 
				
			||||||
} // namespace torch
 | 
					 | 
				
			||||||
@ -105,7 +105,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 | 
				
			|||||||
    Use it inside functional collective pytorch wrappers like the following:
 | 
					    Use it inside functional collective pytorch wrappers like the following:
 | 
				
			||||||
    def functional_collective(self, group, tag):
 | 
					    def functional_collective(self, group, tag):
 | 
				
			||||||
        tag, rankset, group_size = _expand_group(group, tag)
 | 
					        tag, rankset, group_size = _expand_group(group, tag)
 | 
				
			||||||
        tensor = torch._C._dist.{collective}(self, tag, rankset, group_size)
 | 
					        tensor = torch._C._nn.{collective}(self, tag, rankset, group_size)
 | 
				
			||||||
        res = AsyncCollectiveTensor(tensor)
 | 
					        res = AsyncCollectiveTensor(tensor)
 | 
				
			||||||
        _register_wrapper_tensor(res, tensor)
 | 
					        _register_wrapper_tensor(res, tensor)
 | 
				
			||||||
        return res
 | 
					        return res
 | 
				
			||||||
@ -254,7 +254,7 @@ def wait_tensor(tensor):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
 | 
					    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return torch._C._dist.wait_tensor(tensor)  # type: ignore[attr-defined]
 | 
					    return torch._C._nn.wait_tensor(tensor)  # type: ignore[attr-defined]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
 | 
					def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
 | 
				
			||||||
@ -275,7 +275,7 @@ def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str =
 | 
				
			|||||||
    that information and perform collective algebraic optimization. Use other forms of input for that.
 | 
					    that information and perform collective algebraic optimization. Use other forms of input for that.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    tag, rankset, group_size = _expand_group(group, tag)
 | 
					    tag, rankset, group_size = _expand_group(group, tag)
 | 
				
			||||||
    tensor = torch._C._dist.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
 | 
					    tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
 | 
				
			||||||
    res = AsyncCollectiveTensor(tensor)
 | 
					    res = AsyncCollectiveTensor(tensor)
 | 
				
			||||||
    _register_wrapper_tensor(res, tensor)
 | 
					    _register_wrapper_tensor(res, tensor)
 | 
				
			||||||
    return res
 | 
					    return res
 | 
				
			||||||
@ -307,9 +307,7 @@ def reduce_scatter_tensor(
 | 
				
			|||||||
    assert (
 | 
					    assert (
 | 
				
			||||||
        self.size(0) % group_size == 0
 | 
					        self.size(0) % group_size == 0
 | 
				
			||||||
    ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
 | 
					    ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
 | 
				
			||||||
    tensor = torch._C._dist.reduce_scatter_tensor(  # type: ignore[attr-defined]
 | 
					    tensor = torch._C._nn.reduce_scatter_tensor(self, reduceOp, scatter_dim, tag, rankset, group_size)  # type: ignore[attr-defined]
 | 
				
			||||||
        self, reduceOp, scatter_dim, tag, rankset, group_size
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    res = AsyncCollectiveTensor(tensor)
 | 
					    res = AsyncCollectiveTensor(tensor)
 | 
				
			||||||
    _register_wrapper_tensor(res, tensor)
 | 
					    _register_wrapper_tensor(res, tensor)
 | 
				
			||||||
    return res
 | 
					    return res
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user