[RELAND] Always build USE_DISTRIBUTED (#160449) and Make distributed modules importable even when backend not built (#159889) (#162594)

Summary:
Original: D81957844 and D81957923

Additionally, https://github.com/pytorch/pytorch/pull/162142 is patched in.

#buildall

Test Plan:
Sandcastle and OSS CI

Rollback Plan:

Reviewed By: H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162594
Approved by: https://github.com/H-Huang, https://github.com/dcci
Authored by Edward Yang on 2025-09-22 21:12:14 +00:00; committed by PyTorch MergeBot
parent 4027e97791
commit 09cb34c1dc
52 changed files with 766 additions and 446 deletions
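
In practice the change means `torch.distributed` is always importable and `is_available()` unconditionally returns `True`; only the communication backends remain optional. A minimal sketch of the resulting check (illustrative only, not part of the diff):

import torch.distributed as dist

# After this change the module always imports and is_available() is
# unconditionally True, even when no communication backend was built.
assert dist.is_available()

# Backend availability is now a separate question; check it explicitly
# before initializing a process group.
if dist.is_nccl_available():
    backend = "nccl"
elif dist.is_gloo_available():
    backend = "gloo"
else:
    backend = None  # no real backend compiled into this build
print("usable backend:", backend)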

View File

@ -35,11 +35,10 @@ fi
print_cmake_info
if [[ ${BUILD_ENVIRONMENT} == *"distributed"* ]]; then
# Needed for inductor benchmarks, as lots of HF networks make `torch.distributed` calls
USE_DISTRIBUTED=1 USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel
USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel
else
# Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests
# that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448
# NB: we always build with distributed; USE_DISTRIBUTED turns off all
# backends (specifically the gloo backend), so test that this case works too
USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel --plat-name macosx_11_0_arm64
fi
if which sccache > /dev/null; then

View File

@ -13,9 +13,13 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
fi
popd
python -mpip install -r requirements.txt
# enable debug asserts in serialization
export TORCH_SERIALIZATION_DEBUG=1
python -mpip install --no-input -r requirements.txt
setup_test_python() {
# The CircleCI worker hostname doesn't resolve to an address.
# This environment variable makes ProcessGroupGloo default to

View File

@ -177,7 +177,8 @@ source ~/${desired_python}-build/bin/activate
retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
retry brew install libomp
# For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule
# For USE_DISTRIBUTED=1 on macOS, this enables gloo, which needs libuv, which
# is built as part of tensorpipe submodule
export USE_DISTRIBUTED=1
export USE_MKLDNN=OFF

View File

@ -22,7 +22,6 @@ COMMON_COPTS = [
"-DHAVE_SHM_UNLINK=1",
"-D_FILE_OFFSET_BITS=64",
"-DUSE_FBGEMM",
"-DUSE_DISTRIBUTED",
"-DAT_PER_OPERATOR_HEADERS",
"-DATEN_THREADING=NATIVE",
"-DNO_CUDNN_DESTROY_HANDLE",
@ -811,7 +810,7 @@ cc_library(
name = "torch_python",
srcs = libtorch_python_core_sources
+ if_cuda(libtorch_python_cuda_sources)
+ if_cuda(libtorch_python_distributed_sources)
+ libtorch_python_distributed_sources
+ GENERATED_AUTOGRAD_PYTHON,
hdrs = glob([
"torch/csrc/generic/*.cpp",

View File

@ -181,8 +181,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64le)")
set(CPU_POWER ON)
endif()
# For non-supported platforms, turn USE_DISTRIBUTED off by default. It is not
# tested and likely won't work without additional changes.
# For non-supported platforms, turn USE_DISTRIBUTED off by default.
# NB: USE_DISTRIBUTED simply disables the backend; distributed code
# still gets built
if(NOT LINUX AND NOT WIN32)
set(USE_DISTRIBUTED
OFF
@ -262,11 +263,11 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF)
option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF)
option(USE_NATIVE_ARCH "Use -march=native" OFF)
cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
option(USE_DISTRIBUTED "Use distributed" ON)
option(USE_DISTRIBUTED "Enable default distributed backends" ON)
cmake_dependent_option(USE_NCCL "Use NCCL" ON
"USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_XCCL "Use XCCL" ON
"USE_XPU;UNIX;NOT APPLE" OFF)
"USE_DISTRIBUTED;USE_XPU;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF)
cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
@ -438,11 +439,10 @@ if(WIN32)
PATH_SUFFIXES lib
NO_DEFAULT_PATH)
if(NOT libuv_tmp_LIBRARY)
set(USE_DISTRIBUTED OFF)
set(USE_GLOO OFF)
message(
WARNING
"Libuv is not installed in current conda env. Set USE_DISTRIBUTED to OFF. "
"Libuv is not installed in current conda env. Set USE_GLOO to OFF. "
"Please run command 'conda install -c conda-forge libuv=1.39' to install libuv."
)
else()

View File

@ -156,7 +156,7 @@ ROOT = "//" if IS_OSS else "//xplat/caffe2"
# for targets in subfolders
ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/"
C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10"
C10 = "//c10:c10" if IS_OSS else ("//xplat/caffe2/c10:c10_ovrsource" if is_arvr_mode() else "//xplat/caffe2/c10:c10")
# a dictionary maps third party library name to fbsource and oss target
THIRD_PARTY_LIBS = {
@ -948,6 +948,7 @@ def define_buck_targets(
[
("torch/csrc/api/include", "torch/**/*.h"),
("", "torch/csrc/**/*.h"),
("", "torch/csrc/**/*.hpp"),
("", "torch/nativert/**/*.h"),
("", "torch/headeronly/**/*.h"),
("", "torch/script.h"),
@ -2047,6 +2048,7 @@ def define_buck_targets(
("", "caffe2/utils/*.h"),
("", "caffe2/core/*.h"),
("", "torch/csrc/*.h"),
("", "torch/csrc/*.hpp"),
("", "torch/csrc/api/include/torch/*.h"),
("", "torch/csrc/autograd/*.h"),
("", "torch/csrc/autograd/*/*.h"),

View File

@ -18,9 +18,9 @@ cuda_supported_platforms = [
def define_c10_ovrsource(name, is_mobile):
if is_mobile:
pp_flags = ["-DC10_MOBILE=1"]
pp_flags = ["-DC10_MOBILE=1", "-DC10_USE_GLOG"]
else:
pp_flags = []
pp_flags = ["-DC10_USE_GLOG"]
oxx_static_library(
name = name,

View File

@ -540,13 +540,11 @@ if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER)
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
)
if(USE_DISTRIBUTED)
append_filelist("libtorch_distributed_base_sources" TORCH_SRCS)
if(NOT WIN32)
append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS)
endif()
endif()
endif()
if(USE_CUDA OR USE_ROCM)
append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS)
@ -575,7 +573,6 @@ if(USE_CUDA)
list(APPEND Caffe2_GPU_SRCS
${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp)
endif()
if(USE_DISTRIBUTED)
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
if(NOT WIN32)
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
@ -601,7 +598,6 @@ if(USE_CUDA)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*")
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
endif()
endif()
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
@ -633,12 +629,10 @@ if(USE_ROCM)
list(APPEND Caffe2_HIP_SRCS
${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp)
endif()
if(USE_DISTRIBUTED)
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS)
if(NOT WIN32)
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS)
endif()
endif()
# caffe2_nvrtc's stubs to driver APIs are useful for HIP.
# See NOTE [ ATen NVRTC Stub and HIP ]
add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
@ -1358,13 +1352,11 @@ if(BUILD_TEST)
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
if(USE_DISTRIBUTED)
add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
if(NOT WIN32)
add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd)
add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc)
endif()
endif()
if(NOT NO_API)
add_subdirectory(${TORCH_ROOT}/test/cpp/api ${CMAKE_BINARY_DIR}/test_api)
endif()
@ -1468,11 +1460,6 @@ if(BUILD_LITE_INTERPRETER)
endif()
endif()
# Pass USE_DISTRIBUTED to torch_cpu, as some codes in jit/pickler.cpp and
# jit/unpickler.cpp need to be compiled only when USE_DISTRIBUTED is set
if(USE_DISTRIBUTED)
target_compile_definitions(torch_cpu PUBLIC USE_DISTRIBUTED)
if(USE_GLOO AND USE_C10D_GLOO)
target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO)
endif()
@ -1508,7 +1495,6 @@ if(USE_DISTRIBUTED)
if(USE_TENSORPIPE)
target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE)
endif()
endif()
if(NOT INTERN_BUILD_MOBILE)
if(${CAFFE2_LINK_LOCAL_PROTOBUF})

View File

@ -1134,7 +1134,7 @@ if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0)
include_directories(SYSTEM ${CUB_INCLUDE_DIRS})
endif()
if(USE_DISTRIBUTED AND USE_TENSORPIPE)
if(USE_TENSORPIPE)
if(MSVC)
message(WARNING "Tensorpipe cannot be used on Windows.")
else()

View File

@ -193,13 +193,11 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}")
message(STATUS " USE_XNNPACK : ${USE_XNNPACK}")
message(STATUS " USE_DISTRIBUTED : ${USE_DISTRIBUTED}")
if(${USE_DISTRIBUTED})
message(STATUS " USE_MPI : ${USE_MPI}")
message(STATUS " USE_GLOO : ${USE_GLOO}")
message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}")
message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}")
message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}")
endif()
if(NOT "${SELECTED_OP_LIST}" STREQUAL "")
message(STATUS " SELECTED_OP_LIST : ${SELECTED_OP_LIST}")
endif()

View File

@ -3307,13 +3307,6 @@ def coverage_post_process(app, exception):
if not isinstance(app.builder, CoverageBuilder):
return
if not torch.distributed.is_available():
raise RuntimeError(
"The coverage tool cannot run with a version "
"of PyTorch that was built with USE_DISTRIBUTED=0 "
"as this module's API changes."
)
# These are all the modules that have "automodule" in an rst file
# These modules are the ones for which coverage is checked
# Here, we make sure that no module is missing from that list

View File

@ -1,4 +1,4 @@
if(USE_DISTRIBUTED AND NOT WIN32)
if(NOT WIN32)
set(DIST_AUTOGRAD_TEST_DIR "${TORCH_ROOT}/test/cpp/dist_autograd")
set(DIST_AUTOGRAD_TEST_SOURCES
${TORCH_ROOT}/test/cpp/common/main.cpp

View File

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed.fake_pg import FakeStore
class TestFakeDTensor(TestCase):
def test_fake_dtensor_operations(self):
# Use FakeTensorMode to handle CUDA tensors without actual CUDA
fake_mode = FakeTensorMode()
world_size = 4
fake_store = FakeStore()
torch.distributed.init_process_group(
"fake", store=fake_store, rank=0, world_size=world_size
)
device_mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda",
(2, world_size // 2),
)
# Create fake CUDA tensor using FakeTensorMode
with fake_mode:
x = torch.randn(1, 1, device="cuda")
x = DTensor.from_local(x, device_mesh, [Shard(0), Shard(1)])
# Test basic DTensor operations
self.assertIsInstance(x, DTensor)
# Test sum operation
r = x.sum(1)
self.assertIsInstance(r, DTensor)
if __name__ == "__main__":
run_tests()

View File

@ -61,10 +61,7 @@ from torch.export.passes import move_to_device_pass
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
xfailIfDistributedNotSupported,
)
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
@ -15808,7 +15805,6 @@ class GraphModule(torch.nn.Module):
finally:
torch.distributed.destroy_process_group()
@xfailIfDistributedNotSupported
def test_distributed_all_reduce(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -15826,7 +15822,6 @@ class GraphModule(torch.nn.Module):
inp = (torch.randn(4, 4),)
self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
@xfailIfDistributedNotSupported
def test_distributed_all_gather(self):
class Foo(torch.nn.Module):
def forward(self, x):
@ -15842,7 +15837,6 @@ class GraphModule(torch.nn.Module):
torch.allclose(a, b) for a, b in zip(ep.module()(*inp), m(*inp))
)
@xfailIfDistributedNotSupported
def test_distributed_all_gather_into_tensor(self):
class Foo(torch.nn.Module):
def forward(self, x):
@ -15856,7 +15850,6 @@ class GraphModule(torch.nn.Module):
inp = (torch.randn(2),)
self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
@xfailIfDistributedNotSupported
@testing.expectedFailureCppRuntime
def test_distributed_all_to_all_single(self):
class Foo(torch.nn.Module):
@ -15874,7 +15867,6 @@ class GraphModule(torch.nn.Module):
)
self.assertEqual(len(nodes), 1)
@xfailIfDistributedNotSupported
@testing.expectedFailureCppRuntime
def test_distributed_reduce_scatter_tensor(self):
class Foo(torch.nn.Module):

View File

@ -7,7 +7,7 @@ import sys
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
from typing import Any, Optional
from unittest import skipUnless
from unittest import skipIf, skipUnless
from unittest.mock import mock_open, patch
import torch
@ -22,7 +22,7 @@ from torch.numa.binding import (
AffinityMode,
NumaOptions,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
@dataclass(frozen=True)
@ -680,6 +680,7 @@ class NumaBindingTest(TestCase):
set(range(0, 2)),
)
@skipIf(IS_MACOS, "sched_getaffinity doesn't exist")
def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
self._add_mock_hardware(
num_sockets=1,

View File

@ -88,8 +88,7 @@ def build_pytorch(
) -> None:
my_env = _create_build_env()
if (
not check_negative_env_flag("USE_DISTRIBUTED")
and not check_negative_env_flag("USE_CUDA")
not check_negative_env_flag("USE_CUDA")
and not check_negative_env_flag("USE_NCCL")
and not check_env_flag("USE_SYSTEM_NCCL")
):

View File

@ -276,7 +276,7 @@ add_custom_command(
WORKING_DIRECTORY
"${TORCH_ROOT}"
)
if(USE_DISTRIBUTED)
if(WIN32)
append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS)
else()
@ -301,8 +301,6 @@ if(USE_DISTRIBUTED)
endif()
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
endif()
if(USE_NCCL AND NOT WIN32)
list(APPEND TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp)
@ -369,10 +367,6 @@ if(BUILD_LIBTORCHLESS)
target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL)
endif()
if(USE_DISTRIBUTED)
target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED)
endif()
if(USE_MPI AND USE_C10D_MPI)
target_compile_definitions(torch_python PRIVATE USE_C10D_MPI)
endif()

View File

@ -851,3 +851,12 @@ class ProcessGroupXCCL(Backend):
def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...
def _dump_nccl_trace_json(
includeCollectives: Optional[bool] = ...,
onlyActive: Optional[bool] = ...,
) -> bytes: ...
def _dump_nccl_trace(
includeCollectives: Optional[bool] = ...,
includeStackTraces: Optional[bool] = ...,
onlyActive: Optional[bool] = ...,
) -> bytes: ...

View File

@ -15,9 +15,7 @@
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <torch/csrc/utils/pybind.h>
#if defined(USE_DISTRIBUTED)
#include <torch/csrc/distributed/c10d/exception.h>
#endif
inline void PyErr_SetString(PyObject* type, const std::string& message) {
PyErr_SetString(type, message.c_str());

View File

@ -121,14 +121,10 @@
#endif
#endif
#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
#include <torch/csrc/distributed/autograd/python_autograd.h>
#include <torch/csrc/distributed/c10d/c10d.h>
#include <torch/csrc/distributed/rpc/rpc.h>
#include <torch/csrc/distributed/rpc/testing/testing.h>
#endif
#endif
#if defined(USE_VALGRIND)
#include <callgrind.h>
@ -553,11 +549,7 @@ static PyObject* THPModule_getBackcompatKeepdimWarn(
}
static PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) {
#ifdef USE_DISTRIBUTED
Py_RETURN_TRUE;
#else
Py_RETURN_FALSE;
#endif
}
static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) {
@ -2009,7 +2001,6 @@ PyObject* initModule() {
#ifdef USE_XPU
THPUtils_addPyMethodDefs(methods, THXPModule_methods());
#endif
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
THPUtils_addPyMethodDefs(
methods, torch::distributed::c10d::python_functions());
#ifndef _WIN32
@ -2019,7 +2010,6 @@ PyObject* initModule() {
methods, torch::distributed::autograd::python_functions());
THPUtils_addPyMethodDefs(
methods, torch::distributed::rpc::testing::python_functions());
#endif
#endif
static struct PyModuleDef torchmodule = {

View File

@ -8,9 +8,7 @@
#include <torch/csrc/autograd/python_autograd.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_variable.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#endif
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
@ -150,11 +148,9 @@ void THPAutograd_initFunctions() {
static PyTypeObject CopyBackwardsClass;
addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");
#ifdef USE_DISTRIBUTED
static PyTypeObject SendRpcBackwardClass;
addClass<torch::distributed::autograd::SendRpcBackward, NoCtor>(
module, SendRpcBackwardClass, "SendRpcBackward");
#endif
static PyTypeObject CopySlicesClass;
addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices");

View File

@ -1,6 +1,5 @@
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#include <unistd.h>
#include <cstdint>
#include <chrono>

View File

@ -1,5 +1,5 @@
#include <ATen/ThreadLocalState.h>
#include <distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

View File

@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>

View File

@ -1,7 +1,5 @@
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/Functional.hpp>
#endif
#include <torch/csrc/inductor/aoti_torch/c/shim_cpu.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
@ -533,7 +531,6 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
});
}
#ifdef USE_DISTRIBUTED
AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
AtenTensorHandle inp,
const char* reduce_op,
@ -566,4 +563,3 @@ AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor(
*ret0 = new_tensor_handle(std::move(tmp_result));
});
}
#endif

View File

@ -13,6 +13,8 @@
#include <torch/csrc/Layout.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/tracer.h>
@ -24,10 +26,6 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#endif
#include <ATen/core/function_schema.h>
#include <c10/core/Stream.h>

View File

@ -1225,7 +1225,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
} else if (obj.ptr() == py::module::import("torch").attr("_check").ptr()) {
return std::make_shared<TorchCheckValue>();
#ifdef USE_RPC
// RPC module is only available when build flag "USE_DISTRIBUTED" is on.
// This is not defined on WINDOWS
} else if (
isRpcAvailable &&
obj.ptr() ==
@ -1238,7 +1238,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
return SpecialFormValue::create(prim::rpc_sync);
} else if (
isRpcAvailable &&
// RPC module is only available when build flag "USE_DISTRIBUTED" is on.
obj.ptr() ==
py::module::import("torch.distributed.rpc").attr("remote").ptr()) {
return SpecialFormValue::create(prim::rpc_remote);

View File

@ -128,13 +128,8 @@ struct InterpreterContinuation {
std::optional<at::ThreadLocalState> tls_state = std::nullopt)
: state(std::move(state_)),
stack(std::move(stack_)),
tls_state_(std::move(tls_state))
#ifdef USE_DISTRIBUTED
,
dist_autograd_context_id_(dist_autograd_context_id)
#endif
{
}
tls_state_(std::move(tls_state)),
dist_autograd_context_id_(dist_autograd_context_id) {}
void operator()();
@ -142,9 +137,10 @@ struct InterpreterContinuation {
InterpreterState state;
Stack stack;
std::optional<at::ThreadLocalState> tls_state_ = std::nullopt;
#ifdef USE_DISTRIBUTED
int64_t dist_autograd_context_id_;
#ifndef USE_RPC
[[maybe_unused]]
#endif
int64_t dist_autograd_context_id_;
};
// what is the tensors type, including state from the current execution context

View File

@ -79,9 +79,7 @@ class TORCH_API Pickler {
void pushTuple(const IValue& ivalue);
void pushString(const std::string& string);
void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
void pushRRef(const IValue& ivalue);
#endif
// unmemoized version
void pushStringImpl(const std::string& string);
void pushStorageOfTensor(const at::Tensor& tensor);

View File

@ -140,9 +140,7 @@ class TORCH_API Unpickler {
void rebuildParameter();
void rebuildTensorFromTypeV2();
void rebuildSparseTensor();
#ifdef USE_DISTRIBUTED
void rebuildRRef();
#endif
PickleOpCode readInstruction();
PickleOpCode readOpCode() {
return static_cast<PickleOpCode>(read<uint8_t>());

View File

@ -30,15 +30,12 @@
#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
#include <torch/csrc/profiler/util.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_DISTRIBUTED
using namespace at;
// Collective property attributes
// https://github.com/pytorch/pytorch/issues/124674
#ifdef USE_DISTRIBUTED
constexpr auto kETCommsName = "collective_name";
constexpr auto kETInMsgNelems = "in_msg_nelems";
constexpr auto kETOutMsgNelems = "out_msg_nelems";
@ -49,7 +46,6 @@ constexpr auto kETGlobalRankStride = "global_rank_stride";
constexpr auto kETGroupSize = "pg_size";
constexpr auto kETProcessGroupName = "pg_name";
constexpr auto kETProcessGroupDesc = "pg_desc";
#endif // USE_DISTRIBUTED
namespace torch::profiler::impl {
@ -269,7 +265,6 @@ static std::ofstream openOutputFile(const std::string& name) {
return stream;
}
#ifdef USE_DISTRIBUTED
static std::string getAttrJson(
const std::string& name,
const std::string& type,
@ -282,7 +277,6 @@ static std::string getAttrJson(
type,
value);
}
#endif
static void writeJsonNode(
std::ofstream& out,
@ -660,7 +654,6 @@ static void handleKernelBackendInfo(
inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT
std::vector<std::string> attrs;
#ifdef USE_DISTRIBUTED
// We rely on paramcommsdebug object that is available in thread local info
auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
@ -704,8 +697,6 @@ inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT
addAttr(kGroupSize, kETGroupSize, "uint64");
#endif // USE_DISTRIBUTED
// XXX consider using as string stream?
return attrs.empty() ? "" : fmt::format(", {}", fmt::join(attrs, ", "));
}

View File

@ -11,9 +11,7 @@
#ifdef USE_KINETO
#include <libkineto.h>
#endif
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_DISTRIBUTED
namespace torch::profiler::impl {
@ -455,7 +453,7 @@ std::unordered_map<std::string, std::string> saveNcclMeta(
// @lint-ignore CLANGTIDY
const SaveNcclMetaConfig& config) {
std::unordered_map<std::string, std::string> map;
#ifdef USE_DISTRIBUTED
#if !defined(BUILD_LITE_INTERPRETER) && !defined(C10_MOBILE)
auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
@ -565,7 +563,7 @@ std::unordered_map<std::string, std::string> saveNcclMeta(
}
}
}
#endif // USE_DISTRIBUTED
#endif // !defined(BUILD_LITE_INTERPRETER) && !defined(C10_MOBILE)
return map;
}

View File

@ -185,7 +185,6 @@ struct HashCombine {
}
};
#ifdef USE_DISTRIBUTED
constexpr auto kCommsName = "Collective name";
constexpr auto kDtype = "dtype";
constexpr auto kInMsgNelems = "In msg nelems";
@ -203,6 +202,5 @@ constexpr auto kP2pSrc = "Src Rank";
constexpr auto kP2pDst = "Dst Rank";
constexpr auto kInTensorsStart = "Input Tensors start";
constexpr auto kOutTensorsStart = "Output Tensors start";
#endif // USE_DISTRIBUTED
} // namespace torch::profiler::impl

View File

@ -0,0 +1,150 @@
# mypy: allow-untyped-defs
"""
Python stubs for backend-specific distributed components.
Since _C._distributed_c10d always exists now, this module only provides
stubs for backend-specific functionality that may not be available in all builds
(e.g., NCCL, UCC, MPI, Gloo, etc.).
"""
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from torch._C._distributed_c10d import Store
if TYPE_CHECKING:
from datetime import timedelta
import torch
# Store classes
class HashStore(Store):
"""Stub HashStore for builds without this functionality."""
def __init__(self, *args, **kwargs):
self._data = {}
def set(self, key: str, value: str):
self._data[key] = value
def get(self, key: str) -> bytes:
return self._data.get(key, "").encode()
# Backend-specific process group stubs
class ProcessGroupMPI:
"""Stub ProcessGroupMPI for non-MPI builds."""
def __init__(self, *args, **kwargs):
pass
class ProcessGroupNCCL:
"""Stub ProcessGroupNCCL for non-NCCL builds."""
def __init__(self, *args, **kwargs):
pass
class ProcessGroupGloo:
"""Stub ProcessGroupGloo for non-Gloo builds."""
def __init__(self, *args, **kwargs):
pass
class ProcessGroupUCC:
"""Stub ProcessGroupUCC for non-UCC builds."""
def __init__(self, *args, **kwargs):
pass
class ProcessGroupXCCL:
"""Stub ProcessGroupXCCL for non-XCCL builds."""
def __init__(self, *args, **kwargs):
pass
class _ProcessGroupWrapper:
"""Stub _ProcessGroupWrapper for non-Gloo builds."""
def __init__(self, process_group, *args, **kwargs):
self._process_group = process_group
def __getattr__(self, name):
return getattr(self._process_group, name)
# NCCL-specific function stubs
_DEFAULT_PG_NCCL_TIMEOUT: Optional[timedelta] = None
def _hash_tensors(tensors):
"""Stub function to hash tensors - returns dummy hash."""
return 0
def _dump_nccl_trace_json(
includeCollectives: Optional[bool] = None, onlyActive: Optional[bool] = None
) -> bytes:
"""Stub function that returns empty JSON trace."""
return b"{}"
def _dump_nccl_trace(
includeCollectives: Optional[bool] = None,
includeStackTraces: Optional[bool] = None,
onlyActive: Optional[bool] = None,
) -> bytes:
"""Stub function that returns empty pickle trace."""
return b""
# NVSHMEM/SymmetricMemory stubs
def _is_nvshmem_available() -> bool:
"""Stub function that returns False indicating NVSHMEM is not available."""
return False
def _nvshmemx_cumodule_init(module: int) -> None:
"""Stub function for NVSHMEM CU module initialization."""
class _SymmetricMemory:
"""Stub _SymmetricMemory class for builds without this functionality."""
def __init__(self, *args, **kwargs):
pass
@classmethod
def empty_strided_p2p(cls, size, stride, dtype, device, group_name=None):
"""Stub that returns a regular tensor."""
return torch.empty(size, dtype=dtype, device=device)
@classmethod
def rendezvous(cls, tensor, group_name=None):
"""Stub that returns None."""
return None
@classmethod
def set_group_info(cls, *args, **kwargs):
"""Stub that does nothing."""
@classmethod
def set_backend(cls, name):
"""Stub that does nothing."""
@classmethod
def get_backend(cls, device):
"""Stub that returns None."""
return None
@classmethod
def has_multicast_support(cls, device_type, device_index):
"""Stub that returns False."""
return False

View File

@ -14,16 +14,10 @@ log = logging.getLogger(__name__)
def is_available() -> bool:
"""
Return ``True`` if the distributed package is available.
Otherwise,
``torch.distributed`` does not expose any other APIs. Currently,
``torch.distributed`` is available on Linux, MacOS and Windows. Set
``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
``USE_DISTRIBUTED=0`` for MacOS.
Always returns ``True``. Note that even if distributed is available,
there may not necessarily be any usable backends.
"""
return hasattr(torch._C, "_c10d_init")
return True
if is_available() and not torch._C._c10d_init():
@ -36,8 +30,7 @@ DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError
QueueEmptyError = torch._C._DistQueueEmptyError
if is_available():
from torch._C._distributed_c10d import (
from torch.distributed._distributed_c10d import (
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_ControlCollectives,
@ -65,6 +58,7 @@ if is_available():
Work as _Work,
)
class _DistributedPdb(pdb.Pdb):
"""
Supports using PDB from inside a multiprocessing child process.
@ -81,8 +75,10 @@ if is_available():
finally:
sys.stdin = _stdin
_breakpoint_cache: dict[int, typing.Any] = {}
def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
"""
Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
@ -125,8 +121,9 @@ if is_available():
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
del guard
if sys.platform != "win32":
from torch._C._distributed_c10d import HashStore
from torch.distributed._distributed_c10d import HashStore
from .device_mesh import DeviceMesh, init_device_mesh
@ -152,16 +149,5 @@ if is_available():
rendezvous,
)
set_debug_level_from_env()
else:
# This stub is sufficient to get
# python test/test_public_bindings.py -k test_correct_module_names
# working even when USE_DISTRIBUTED=0. Feel free to add more
# stubs as necessary.
# We cannot define stubs directly because they confuse pyre
class _ProcessGroupStub:
pass
sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined]
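
With the stub branch above deleted, the `HashStore` import in `torch/distributed/__init__.py` resolves through the new centralized module, which falls back to the pure-Python stub when the C++ binding is absent. A hedged usage sketch (non-Windows, as in the diff):

from torch.distributed._distributed_c10d import HashStore

# Works with either the C++ HashStore binding or the Python stub from
# torch/distributed/_C_stubs.py; both expose str-keyed set()/get().
store = HashStore()
store.set("rank0_ready", "1")
assert store.get("rank0_ready") == b"1"  # get() returns bytes in both cases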

View File

@ -10,7 +10,7 @@ from datetime import timedelta
from typing import Protocol, Union
import torch
from torch._C._distributed_c10d import (
from torch.distributed._distributed_c10d import (
_current_process_group,
_set_process_group,
ProcessGroup,

View File

@ -0,0 +1,245 @@
# mypy: disable-error-code="assignment"
# noqa: F401
"""
Centralized module for importing and re-exporting torch._C._distributed_c10d components.
IMPORTANT PATTERN:
Never access torch._C._distributed_c10d directly in code. Always import from and use
torch.distributed._distributed_c10d which is guaranteed to have all functions available.
Example:
# WRONG: torch._C._distributed_c10d._set_global_rank(rank)
# RIGHT:
from torch.distributed._distributed_c10d import _set_global_rank
_set_global_rank(rank)
"""
from typing import TYPE_CHECKING
# Import all core distributed components from the C extension
# NB: This list has to be spelled out because the _C module doesn't have __all__
from torch._C._distributed_c10d import (
_allow_inflight_collective_as_graph_input,
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_ControlCollectives,
_current_process_group,
_DEFAULT_FIRST_BUCKET_BYTES,
_DEFAULT_PG_TIMEOUT,
_DistributedBackendOptions,
_make_nccl_premul_sum,
_register_builtin_comm_hook,
_register_comm_hook,
_register_process_group,
_register_work,
_resolve_process_group,
_set_allow_inflight_collective_as_graph_input,
_set_global_rank,
_set_process_group,
_StoreCollectives,
_test_python_store,
_unregister_all_process_groups,
_unregister_process_group,
_verify_params_across_processes,
_WorkerServer,
AllgatherOptions,
AllreduceCoalescedOptions,
AllreduceOptions,
AllToAllOptions,
Backend,
BarrierOptions,
BroadcastOptions,
BuiltinCommHookType,
DebugLevel,
FakeProcessGroup,
FakeWork,
FileStore,
GatherOptions,
get_debug_level,
GradBucket,
Logger,
PrefixStore,
ProcessGroup,
ReduceOp,
ReduceOptions,
Reducer,
ReduceScatterOptions,
ScatterOptions,
set_debug_level,
set_debug_level_from_env,
Store,
TCPStore,
Work,
)
# Backend-specific components that may not be available
_MPI_AVAILABLE = False
_NCCL_AVAILABLE = False
_GLOO_AVAILABLE = False
_UCC_AVAILABLE = False
_XCCL_AVAILABLE = False
# HashStore
try:
from torch._C._distributed_c10d import HashStore
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import HashStore
# NVSHMEM/SymmetricMemory components
# There are multiple backends for SymmetricMemory, as a result,
# _SymmetricMemory should not be imported together with NVSHMEM related modules.
try:
from torch._C._distributed_c10d import _SymmetricMemory
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import _SymmetricMemory
try:
from torch._C._distributed_c10d import (
_is_nvshmem_available,
_nvshmemx_cumodule_init,
)
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import (
_is_nvshmem_available,
_nvshmemx_cumodule_init,
)
# MPI backend
try:
from torch._C._distributed_c10d import ProcessGroupMPI
_MPI_AVAILABLE = True
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import ProcessGroupMPI
# NCCL backend
try:
from torch._C._distributed_c10d import (
_DEFAULT_PG_NCCL_TIMEOUT,
_dump_nccl_trace,
_dump_nccl_trace_json,
_hash_tensors,
ProcessGroupNCCL,
)
_NCCL_AVAILABLE = True
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import (
_DEFAULT_PG_NCCL_TIMEOUT,
_dump_nccl_trace,
_dump_nccl_trace_json,
_hash_tensors,
ProcessGroupNCCL,
)
# Gloo backend
try:
from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
_GLOO_AVAILABLE = True
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import _ProcessGroupWrapper, ProcessGroupGloo
# UCC backend
try:
from torch._C._distributed_c10d import ProcessGroupUCC
_UCC_AVAILABLE = True
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import ProcessGroupUCC
# XCCL backend
try:
from torch._C._distributed_c10d import ProcessGroupXCCL
_XCCL_AVAILABLE = True
except ImportError:
if not TYPE_CHECKING:
from torch.distributed._C_stubs import ProcessGroupXCCL
# Provide backwards compatibility by making all symbols available at module level
__all__ = [
# Basic components
"_broadcast_coalesced",
"_compute_bucket_assignment_by_size",
"_ControlCollectives",
"_DEFAULT_FIRST_BUCKET_BYTES",
"_DEFAULT_PG_TIMEOUT",
"_DEFAULT_PG_NCCL_TIMEOUT",
"_make_nccl_premul_sum",
"_register_builtin_comm_hook",
"_register_comm_hook",
"_StoreCollectives",
"_test_python_store",
"_verify_params_across_processes",
"_allow_inflight_collective_as_graph_input",
"_register_work",
"_set_allow_inflight_collective_as_graph_input",
"_is_nvshmem_available",
"_nvshmemx_cumodule_init",
"_SymmetricMemory",
"_hash_tensors",
"_set_global_rank",
"_dump_nccl_trace",
"_dump_nccl_trace_json",
"Backend",
"BuiltinCommHookType",
"DebugLevel",
"FakeProcessGroup",
"FileStore",
"get_debug_level",
"GradBucket",
"HashStore",
"Logger",
"PrefixStore",
"ProcessGroup",
"Reducer",
"ReduceOp",
"set_debug_level",
"set_debug_level_from_env",
"Store",
"TCPStore",
"Work",
"FakeWork",
# Additional distributed_c10d components
"_DistributedBackendOptions",
"_register_process_group",
"_resolve_process_group",
"_unregister_all_process_groups",
"_unregister_process_group",
"_current_process_group",
"_set_process_group",
"_WorkerServer",
"AllgatherOptions",
"AllreduceCoalescedOptions",
"AllreduceOptions",
"AllToAllOptions",
"BarrierOptions",
"BroadcastOptions",
"GatherOptions",
"ReduceOptions",
"ReduceScatterOptions",
"ScatterOptions",
# Process group implementations
"ProcessGroupMPI",
"ProcessGroupNCCL",
"ProcessGroupGloo",
"ProcessGroupUCC",
"ProcessGroupXCCL",
"_ProcessGroupWrapper",
# Availability flags
"_MPI_AVAILABLE",
"_NCCL_AVAILABLE",
"_GLOO_AVAILABLE",
"_UCC_AVAILABLE",
"_XCCL_AVAILABLE",
]
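
Downstream callers are expected to consume these re-exports together with the availability flags rather than importing from `torch._C._distributed_c10d` directly; a short sketch of that pattern (illustrative only, assuming a build without NCCL):

from torch.distributed._distributed_c10d import _NCCL_AVAILABLE, ProcessGroupNCCL

def nccl_options_or_none():
    # The import above never fails: ProcessGroupNCCL is either the real
    # binding or the stub from _C_stubs. Gate backend-specific usage on
    # the availability flag instead of catching ImportError.
    if not _NCCL_AVAILABLE:
        return None
    return ProcessGroupNCCL.Options()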

View File

@ -7,6 +7,10 @@ from typing import Any, cast, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch.distributed._distributed_c10d import (
_allow_inflight_collective_as_graph_input,
_set_allow_inflight_collective_as_graph_input,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.fx.experimental.proxy_tensor import get_proxy_mode
@ -858,15 +862,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
will be registered in the work registry, and the wait_tensor() in compiled region called on
the output tensor of the collective will wait on the correct work object.
"""
previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()
previous = _allow_inflight_collective_as_graph_input()
try:
torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
_set_allow_inflight_collective_as_graph_input(value)
yield
finally:
torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
previous
)
_set_allow_inflight_collective_as_graph_input(previous)
def _make_all_gather_out_tensor(input, group_size):

View File

@ -4,7 +4,7 @@ import copy
import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch._C._distributed_c10d import ProcessGroup
from torch.distributed._distributed_c10d import ProcessGroup
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import (
get_chunked_dim_size,

View File

@ -4,7 +4,7 @@ from typing import cast
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import ReduceOp
from torch.distributed._distributed_c10d import ReduceOp
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op

View File

@ -15,7 +15,12 @@ import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
from torch.distributed._distributed_c10d import (
_register_work,
_SymmetricMemory,
ProcessGroup,
Work as _Work,
)
_group_name_to_store: dict[str, c10d.Store] = {}
@ -1488,7 +1493,7 @@ def _low_contention_all_gather(
src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype)
chunks[remote_rank].copy_(src_buf)
symm_mem.barrier()
torch._C._distributed_c10d._register_work(output, Work())
_register_work(output, Work())
return output
@ -1536,7 +1541,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
ret = ret.mean(dim=0)
else:
raise ValueError(f"reduce_op ({reduce_op}) is not supported")
torch._C._distributed_c10d._register_work(ret, Work())
_register_work(ret, Work())
return ret
@ -1571,7 +1576,7 @@ def _low_contention_reduce_scatter_with_workspace(
ret = ret.mean(dim=0)
else:
raise ValueError(f"reduce_op ({reduce_op}) is not supported")
torch._C._distributed_c10d._register_work(ret, Work())
_register_work(ret, Work())
return ret
@ -1649,7 +1654,6 @@ from typing import overload, TYPE_CHECKING, Union
if TYPE_CHECKING:
from torch._C._distributed_c10d import ProcessGroup
from torch.types import _device, _dtype, _int
@ -1727,8 +1731,6 @@ def rendezvous(
group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
participating processes. This can be either a group name or a process group object.
"""
from torch._C._distributed_c10d import ProcessGroup
if isinstance(group, str):
group_name = group
elif isinstance(group, ProcessGroup):
@ -1746,11 +1748,7 @@ def is_nvshmem_available() -> bool:
Check if NVSHMEM is available in current build and on current system.
"""
try:
from torch._C._distributed_c10d import _is_nvshmem_available
except ImportError:
# Not all builds have NVSHMEM support.
return False
from torch.distributed._distributed_c10d import _is_nvshmem_available
# Check if NVSHMEM is available on current system.
return _is_nvshmem_available()

View File

@ -80,7 +80,7 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]:
"""
import triton
from torch._C._distributed_c10d import _nvshmemx_cumodule_init
from torch.distributed._distributed_c10d import _nvshmemx_cumodule_init
if lib_dir is not None:
lib_path = os.path.join(lib_dir, "libnvshmem_device.bc")

View File

@ -2,7 +2,9 @@ import random
from typing import Any
import torch
from torch._C._distributed_c10d import (
# Import centralized distributed components
from torch.distributed._distributed_c10d import (
_resolve_process_group,
FakeWork,
ProcessGroup,

View File

@ -5,10 +5,6 @@ from typing import Union
import torch
import torch.distributed as dist
# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise import error
# if we're trying to use them.
from torch.distributed import group, ProcessGroup

View File

@ -1,7 +1,11 @@
from datetime import timedelta
from typing import Optional
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
# Import from centralized fallback module - no ImportError handling needed
from torch.distributed._distributed_c10d import (
_DEFAULT_PG_NCCL_TIMEOUT,
_DEFAULT_PG_TIMEOUT,
)
__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"]
@ -16,11 +20,4 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
# Later, we could consider merging them back together at the c++ layer if we can align on a same value.
# (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1).
try:
from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
except ImportError:
# if C++ NCCL support is not compiled, we don't have access to the default nccl value.
# if anyone is actually trying to use nccl in this state, it should error.
default_pg_nccl_timeout = None

View File

@ -11,35 +11,14 @@ from itertools import zip_longest
from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.utils._typing_utils import not_none
__all__ = ["init_device_mesh", "DeviceMesh"]
if not is_available():
import sys
# We need to create the stubs when distributed is not available.
# Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
# since it would try to import ``torch.distributed.device_mesh`` or
# ``torch.distributed.init_device_mesh`` but cannot find them.
class _DeviceMeshStub:
pass
def _init_device_mesh_stub():
pass
sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined]
sys.modules[
"torch.distributed.device_mesh"
].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined]
else:
from torch._C._distributed_c10d import Backend as C10dBackend
if True: # just to temporarily avoid reindentation
from torch.distributed._distributed_c10d import Backend as C10dBackend
from torch.distributed.distributed_c10d import (
_get_default_group,
_resolve_process_group,
@ -534,6 +513,7 @@ else:
# heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
# NOTE: This device selection would only work for homogeneous hardware.
num_devices_per_host = device_handle.device_count()
if num_devices_per_host:
if (
world_size > num_devices_per_host
and world_size % num_devices_per_host != 0

View File

@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch
import torch.distributed._distributed_c10d as _c10d
from torch._C import _DistStoreError as DistStoreError
from torch._C._distributed_c10d import (
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.distributed._distributed_c10d import ( # Process group implementations; Availability flags
_DistributedBackendOptions,
_GLOO_AVAILABLE,
_MPI_AVAILABLE,
_NCCL_AVAILABLE,
_ProcessGroupWrapper,
_register_process_group,
_resolve_process_group,
_UCC_AVAILABLE,
_unregister_all_process_groups,
_unregister_process_group,
_XCCL_AVAILABLE,
AllgatherOptions,
AllreduceCoalescedOptions,
AllreduceOptions,
@ -37,6 +45,11 @@ from torch._C._distributed_c10d import (
get_debug_level,
PrefixStore,
ProcessGroup,
ProcessGroupGloo,
ProcessGroupMPI,
ProcessGroupNCCL,
ProcessGroupUCC,
ProcessGroupXCCL,
ReduceOp,
ReduceOptions,
ReduceScatterOptions,
@ -44,7 +57,6 @@ from torch._C._distributed_c10d import (
Store,
Work,
)
from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
from torch.monitor import _WaitCounter
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._typing_utils import not_none
@ -131,17 +143,11 @@ __all__ = [
"split_group",
]
_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_XCCL_AVAILABLE = True
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
# Change __module__ of all imported types from torch._C._distributed_c10d that are public
# Change __module__ of all imported types from the distributed wrapper that are public
def _export_c_types() -> None:
_public_types_to_change_module = [
AllreduceCoalescedOptions,
@ -167,45 +173,26 @@ def _export_c_types() -> None:
_export_c_types()
try:
from torch._C._distributed_c10d import ProcessGroupMPI
# Add process groups to __all__ and set their module based on availability
if _MPI_AVAILABLE:
ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupMPI"]
except ImportError:
_MPI_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupNCCL
if _NCCL_AVAILABLE:
ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupNCCL"]
except ImportError:
_NCCL_AVAILABLE = False
try:
from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
if _GLOO_AVAILABLE:
ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupGloo"]
except ImportError:
_GLOO_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupUCC
if _UCC_AVAILABLE:
ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupUCC"]
except ImportError:
_UCC_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupXCCL
if _XCCL_AVAILABLE:
ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupXCCL"]
except ImportError:
_XCCL_AVAILABLE = False
logger = logging.getLogger(__name__)
@ -1327,7 +1314,8 @@ def _get_default_store() -> Store:
def _update_default_pg(pg) -> None:
_world.default_pg = pg
rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
torch._C._distributed_c10d._set_global_rank(rank)
_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
@ -1964,7 +1952,7 @@ def _new_process_group_helper(
if device_id:
pg.bound_device_id = device_id
backend_class: torch._C._distributed_c10d.Backend
backend_class: _c10d.Backend
for device, backend_str in backend_config.get_device_backend_map().items():
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
@ -3079,7 +3067,9 @@ def _object_to_tensor(obj, device, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([byte_tensor])
logger.warning(
"_object_to_tensor size: %s hash value: %s",
byte_tensor.numel(),
@ -3094,7 +3084,9 @@ def _tensor_to_object(tensor, tensor_size, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([tensor])
from torch.distributed._distributed_c10d import _hash_tensors
hash = _hash_tensors([tensor])
logger.warning(
"_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
)
@ -4971,7 +4963,7 @@ def monitored_barrier(
def _create_process_group_wrapper(
wrapped_pg: torch._C._distributed_c10d.Backend,
wrapped_pg: _c10d.Backend,
store_prefix: str,
store: Store,
rank: int,

View File

@ -14,7 +14,7 @@ TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed._distributed_c10d import _WorkerServer
server = _WorkerServer(socket_path)
try:

View File

@ -2,10 +2,6 @@
import torch
import torch.distributed as dist
from torch.autograd import Function
# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise import error
# if we're trying to use them.
from torch.distributed import group, ReduceOp

View File

@ -37,7 +37,6 @@ if is_available():
import numbers
import torch.distributed.autograd as dist_autograd
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import ( # noqa: F401
_cleanup_python_rpc_handler,
_DEFAULT_INIT_METHOD,
@ -70,6 +69,7 @@ if is_available():
RpcBackendOptions,
WorkerInfo,
)
from torch.distributed._distributed_c10d import Store
if _is_tensorpipe_available:
from torch._C._distributed_rpc import ( # noqa: F401

View File

@ -8,8 +8,10 @@ from typing import Optional
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch._logging import warning_once
# Import from centralized fallback module - no conditional imports needed
from torch.distributed._distributed_c10d import _resolve_process_group
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
_get_group_size_by_name,

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import torch.distributed as dist
from torch._C._distributed_c10d import FakeProcessGroup
from torch.distributed._distributed_c10d import FakeProcessGroup
class FakeStore(dist.Store):