From 09cb34c1dce8fe1b880bbf3115d8ddad3401d871 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Mon, 22 Sep 2025 21:12:14 +0000 Subject: [PATCH] [RELAND] Always build USE_DISTRIBUTED (#160449) and Make distributed modules importable even when backend not built (#159889) (#162594) Summary: Original: D81957844 and D81957923 Also, https://github.com/pytorch/pytorch/pull/162142 is patched in as well #buildall Test Plan: sandcastle and oss ci Rollback Plan: Reviewed By: H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/162594 Approved by: https://github.com/H-Huang, https://github.com/dcci --- .ci/pytorch/macos-build.sh | 7 +- .ci/pytorch/macos-test.sh | 4 + .ci/wheel/build_wheel.sh | 3 +- BUILD.bazel | 3 +- CMakeLists.txt | 12 +- buckbuild.bzl | 4 +- c10/ovrsource_defs.bzl | 4 +- caffe2/CMakeLists.txt | 142 +++++----- cmake/Dependencies.cmake | 2 +- cmake/Summary.cmake | 12 +- docs/source/conf.py | 7 - test/cpp/dist_autograd/CMakeLists.txt | 2 +- test/distributed/tensor/test_fake.py | 41 +++ test/export/test_export.py | 10 +- test/test_numa_binding.py | 5 +- tools/build_pytorch_libs.py | 3 +- torch/CMakeLists.txt | 50 ++-- torch/_C/_distributed_c10d.pyi | 9 + torch/csrc/Exceptions.h | 2 - torch/csrc/Module.cpp | 10 - torch/csrc/autograd/functions/init.cpp | 4 - torch/csrc/distributed/c10d/HashStore.cpp | 1 - torch/csrc/distributed/c10d/Work.cpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 1 + torch/csrc/inductor/aoti_torch/shim_cpu.cpp | 4 - torch/csrc/jit/python/pybind_utils.h | 6 +- .../csrc/jit/python/python_sugared_value.cpp | 3 +- torch/csrc/jit/runtime/interpreter.h | 14 +- torch/csrc/jit/serialization/pickler.h | 2 - torch/csrc/jit/serialization/unpickler.h | 2 - .../standalone/execution_trace_observer.cpp | 9 - torch/csrc/profiler/util.cpp | 6 +- torch/csrc/profiler/util.h | 2 - torch/distributed/_C_stubs.py | 150 +++++++++++ torch/distributed/__init__.py | 236 ++++++++--------- torch/distributed/_dist2.py | 2 +- torch/distributed/_distributed_c10d.py | 245 ++++++++++++++++++ torch/distributed/_functional_collectives.py | 12 +- .../_shard/sharded_tensor/reshard.py | 2 +- .../chunk_sharding_spec_ops/embedding_bag.py | 2 +- .../distributed/_symmetric_memory/__init__.py | 22 +- .../_symmetric_memory/_nvshmem_triton.py | 2 +- torch/distributed/_tools/fake_collectives.py | 4 +- .../algorithms/model_averaging/utils.py | 4 - torch/distributed/constants.py | 15 +- torch/distributed/device_mesh.py | 44 +--- torch/distributed/distributed_c10d.py | 70 +++-- torch/distributed/elastic/control_plane.py | 2 +- torch/distributed/nn/functional.py | 4 - torch/distributed/rpc/__init__.py | 2 +- torch/distributed/tensor/_collective_utils.py | 4 +- .../testing/_internal/distributed/fake_pg.py | 2 +- 52 files changed, 766 insertions(+), 446 deletions(-) create mode 100644 test/distributed/tensor/test_fake.py create mode 100644 torch/distributed/_C_stubs.py create mode 100644 torch/distributed/_distributed_c10d.py diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index d7447e7d4858..d41c3c08e628 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh @@ -35,11 +35,10 @@ fi print_cmake_info if [[ ${BUILD_ENVIRONMENT} == *"distributed"* ]]; then - # Needed for inductor benchmarks, as lots of HF networks make `torch.distribtued` calls - USE_DISTRIBUTED=1 USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel + USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel else - # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. 
This also serves as the sole CI config that tests
-  # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448
+  # NB: we always build with distributed; USE_DISTRIBUTED=0 turns off all
+  # backends (specifically the gloo backend), so test that this case works too
   USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel --plat-name macosx_11_0_arm64
 fi
 if which sccache > /dev/null; then
diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh
index a859901191e0..79d47da43171 100755
--- a/.ci/pytorch/macos-test.sh
+++ b/.ci/pytorch/macos-test.sh
@@ -13,9 +13,11 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available()))") ]]; then
 fi
 popd

+python -mpip install --no-input -r requirements.txt
+
 # enable debug asserts in serialization
 export TORCH_SERIALIZATION_DEBUG=1

 setup_test_python() {
   # The CircleCI worker hostname doesn't resolve to an address.
   # This environment variable makes ProcessGroupGloo default to
diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh
index 2d5f4d30b4c8..98b50c0ceeaf 100755
--- a/.ci/wheel/build_wheel.sh
+++ b/.ci/wheel/build_wheel.sh
@@ -177,7 +177,8 @@ source ~/${desired_python}-build/bin/activate
 retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
 retry brew install libomp
-# For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule
+# For USE_DISTRIBUTED=1 on macOS, this enables gloo, which needs libuv, which
+# is built as part of the tensorpipe submodule
 export USE_DISTRIBUTED=1
 export USE_MKLDNN=OFF
diff --git a/BUILD.bazel b/BUILD.bazel
index d4202e7a2c1e..635f39eed2ce 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -22,7 +22,6 @@ COMMON_COPTS = [
    "-DHAVE_SHM_UNLINK=1",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
-    "-DUSE_DISTRIBUTED",
    "-DAT_PER_OPERATOR_HEADERS",
    "-DATEN_THREADING=NATIVE",
    "-DNO_CUDNN_DESTROY_HANDLE",
@@ -811,7 +810,7 @@ cc_library(
    name = "torch_python",
    srcs = libtorch_python_core_sources
        + if_cuda(libtorch_python_cuda_sources)
-        + if_cuda(libtorch_python_distributed_sources)
+        + libtorch_python_distributed_sources
        + GENERATED_AUTOGRAD_PYTHON,
    hdrs = glob([
        "torch/csrc/generic/*.cpp",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 384dd27f9262..8323f310fec4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -181,8 +181,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64le)")
   set(CPU_POWER ON)
 endif()
-# For non-supported platforms, turn USE_DISTRIBUTED off by default. It is not
-# tested and likely won't work without additional changes.
+# For non-supported platforms, turn USE_DISTRIBUTED off by default.
+# NB: USE_DISTRIBUTED simply disables the backend; distributed code +# still gets built if(NOT LINUX AND NOT WIN32) set(USE_DISTRIBUTED OFF @@ -262,11 +263,11 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) -option(USE_DISTRIBUTED "Use distributed" ON) +option(USE_DISTRIBUTED "Enable default distributed backends" ON) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_XCCL "Use XCCL" ON - "USE_XPU;UNIX;NOT APPLE" OFF) + "USE_DISTRIBUTED;USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) @@ -438,11 +439,10 @@ if(WIN32) PATH_SUFFIXES lib NO_DEFAULT_PATH) if(NOT libuv_tmp_LIBRARY) - set(USE_DISTRIBUTED OFF) set(USE_GLOO OFF) message( WARNING - "Libuv is not installed in current conda env. Set USE_DISTRIBUTED to OFF. " + "Libuv is not installed in current conda env. Set USE_GLOO to OFF. " "Please run command 'conda install -c conda-forge libuv=1.39' to install libuv." ) else() diff --git a/buckbuild.bzl b/buckbuild.bzl index 2e5a7611ce97..047ed71ad279 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -156,7 +156,7 @@ ROOT = "//" if IS_OSS else "//xplat/caffe2" # for targets in subfolders ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/" -C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10" +C10 = "//c10:c10" if IS_OSS else ("//xplat/caffe2/c10:c10_ovrsource" if is_arvr_mode() else "//xplat/caffe2/c10:c10") # a dictionary maps third party library name to fbsource and oss target THIRD_PARTY_LIBS = { @@ -948,6 +948,7 @@ def define_buck_targets( [ ("torch/csrc/api/include", "torch/**/*.h"), ("", "torch/csrc/**/*.h"), + ("", "torch/csrc/**/*.hpp"), ("", "torch/nativert/**/*.h"), ("", "torch/headeronly/**/*.h"), ("", "torch/script.h"), @@ -2047,6 +2048,7 @@ def define_buck_targets( ("", "caffe2/utils/*.h"), ("", "caffe2/core/*.h"), ("", "torch/csrc/*.h"), + ("", "torch/csrc/*.hpp"), ("", "torch/csrc/api/include/torch/*.h"), ("", "torch/csrc/autograd/*.h"), ("", "torch/csrc/autograd/*/*.h"), diff --git a/c10/ovrsource_defs.bzl b/c10/ovrsource_defs.bzl index aafe5a4de8c4..532404f21bba 100644 --- a/c10/ovrsource_defs.bzl +++ b/c10/ovrsource_defs.bzl @@ -18,9 +18,9 @@ cuda_supported_platforms = [ def define_c10_ovrsource(name, is_mobile): if is_mobile: - pp_flags = ["-DC10_MOBILE=1"] + pp_flags = ["-DC10_MOBILE=1", "-DC10_USE_GLOG"] else: - pp_flags = [] + pp_flags = ["-DC10_USE_GLOG"] oxx_static_library( name = name, diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index f581e47b36fc..51e4023b0d18 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -540,11 +540,9 @@ if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER) ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp ) - if(USE_DISTRIBUTED) - append_filelist("libtorch_distributed_base_sources" TORCH_SRCS) - if(NOT WIN32) - append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS) - endif() + append_filelist("libtorch_distributed_base_sources" TORCH_SRCS) + if(NOT WIN32) + append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS) endif() endif() @@ -575,32 +573,30 @@ if(USE_CUDA) list(APPEND Caffe2_GPU_SRCS 
${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) endif() - if(USE_DISTRIBUTED) - append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) - if(NOT WIN32) - append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) - set_source_files_properties( - ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp - PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" - ) - endif() + append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) + if(NOT WIN32) + append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) + set_source_files_properties( + ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp + PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" + ) + endif() - set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu") - # Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9 - if(CMAKE_COMPILER_IS_GNUCXX) - set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") - endif() - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") - set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") - endif() + set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu") + # Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9 + if(CMAKE_COMPILER_IS_GNUCXX) + set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") + endif() + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") + set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") endif() set_source_files_properties( ${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -633,11 +629,9 @@ if(USE_ROCM) list(APPEND Caffe2_HIP_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) endif() - if(USE_DISTRIBUTED) - append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS) - if(NOT WIN32) - append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS) - endif() + append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS) + if(NOT WIN32) + append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS) endif() # caffe2_nvrtc's stubs to 
driver APIs are useful for HIP. # See NOTE [ ATen NVRTC Stub and HIP ] @@ -1358,12 +1352,10 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) - if(USE_DISTRIBUTED) - add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) - if(NOT WIN32) - add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd) - add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) - endif() + add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) + if(NOT WIN32) + add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd) + add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) endif() if(NOT NO_API) add_subdirectory(${TORCH_ROOT}/test/cpp/api ${CMAKE_BINARY_DIR}/test_api) @@ -1468,46 +1460,40 @@ if(BUILD_LITE_INTERPRETER) endif() endif() - -# Pass USE_DISTRIBUTED to torch_cpu, as some codes in jit/pickler.cpp and -# jit/unpickler.cpp need to be compiled only when USE_DISTRIBUTED is set -if(USE_DISTRIBUTED) - target_compile_definitions(torch_cpu PUBLIC USE_DISTRIBUTED) - if(USE_GLOO AND USE_C10D_GLOO) - target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) +if(USE_GLOO AND USE_C10D_GLOO) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) +endif() +if(USE_UCC AND USE_C10D_UCC) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) + if(USE_CUDA) + target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) endif() - if(USE_UCC AND USE_C10D_UCC) - target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) - if(USE_CUDA) - target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) - endif() +endif() +if(USE_NCCL AND USE_C10D_NCCL) + if(USE_ROCM) + target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) + else() + target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) endif() - if(USE_NCCL AND USE_C10D_NCCL) - if(USE_ROCM) - target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) - else() - target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) - endif() - endif() - if(USE_MPI AND USE_C10D_MPI) - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set_source_files_properties( - "${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp" - PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) - endif() - target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI) - endif() - # Pass USE_RPC in order to reduce use of - # #if defined(USE_DISTRIBUTED) && !defined(_WIN32) - # need to be removed when RPC is supported - if(NOT WIN32) - target_compile_definitions(torch_cpu PUBLIC USE_RPC) - endif() - # Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp - # can only be compiled with USE_TENSORPIPE is set. 
- if(USE_TENSORPIPE) - target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE) +endif() +if(USE_MPI AND USE_C10D_MPI) + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set_source_files_properties( + "${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp" + PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) endif() + target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI) +endif() +# Pass USE_RPC in order to reduce use of +# #if defined(USE_DISTRIBUTED) && !defined(_WIN32) +# need to be removed when RPC is supported +if(NOT WIN32) + target_compile_definitions(torch_cpu PUBLIC USE_RPC) +endif() +# Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp +# can only be compiled with USE_TENSORPIPE is set. +if(USE_TENSORPIPE) + target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE) endif() if(NOT INTERN_BUILD_MOBILE) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 6ad56d3b9b44..08ffdaf8cf45 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1134,7 +1134,7 @@ if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0) include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() -if(USE_DISTRIBUTED AND USE_TENSORPIPE) +if(USE_TENSORPIPE) if(MSVC) message(WARNING "Tensorpipe cannot be used on Windows.") else() diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 2e2fd370a994..a0bfb22bed80 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -193,13 +193,11 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}") message(STATUS " USE_XNNPACK : ${USE_XNNPACK}") message(STATUS " USE_DISTRIBUTED : ${USE_DISTRIBUTED}") - if(${USE_DISTRIBUTED}) - message(STATUS " USE_MPI : ${USE_MPI}") - message(STATUS " USE_GLOO : ${USE_GLOO}") - message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}") - message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") - message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}") - endif() + message(STATUS " USE_MPI : ${USE_MPI}") + message(STATUS " USE_GLOO : ${USE_GLOO}") + message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}") + message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") + message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}") if(NOT "${SELECTED_OP_LIST}" STREQUAL "") message(STATUS " SELECTED_OP_LIST : ${SELECTED_OP_LIST}") endif() diff --git a/docs/source/conf.py b/docs/source/conf.py index e2431b886418..b6fe60286e2b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3307,13 +3307,6 @@ def coverage_post_process(app, exception): if not isinstance(app.builder, CoverageBuilder): return - if not torch.distributed.is_available(): - raise RuntimeError( - "The coverage tool cannot run with a version " - "of PyTorch that was built with USE_DISTRIBUTED=0 " - "as this module's API changes." 
- ) - # These are all the modules that have "automodule" in an rst file # These modules are the ones for which coverage is checked # Here, we make sure that no module is missing from that list diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 14fd7f7ae9a2..86a6c924288b 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -1,4 +1,4 @@ -if(USE_DISTRIBUTED AND NOT WIN32) +if(NOT WIN32) set(DIST_AUTOGRAD_TEST_DIR "${TORCH_ROOT}/test/cpp/dist_autograd") set(DIST_AUTOGRAD_TEST_SOURCES ${TORCH_ROOT}/test/cpp/common/main.cpp diff --git a/test/distributed/tensor/test_fake.py b/test/distributed/tensor/test_fake.py new file mode 100644 index 000000000000..099c6e87f5f1 --- /dev/null +++ b/test/distributed/tensor/test_fake.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import Shard +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed.fake_pg import FakeStore + + +class TestFakeDTensor(TestCase): + def test_fake_dtensor_operations(self): + # Use FakeTensorMode to handle CUDA tensors without actual CUDA + fake_mode = FakeTensorMode() + world_size = 4 + + fake_store = FakeStore() + torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size + ) + device_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (2, world_size // 2), + ) + + # Create fake CUDA tensor using FakeTensorMode + with fake_mode: + x = torch.randn(1, 1, device="cuda") + x = DTensor.from_local(x, device_mesh, [Shard(0), Shard(1)]) + + # Test basic DTensor operations + self.assertIsInstance(x, DTensor) + + # Test sum operation + r = x.sum(1) + self.assertIsInstance(r, DTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index 950f7e87c9d9..4bcef61e32bb 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -61,10 +61,7 @@ from torch.export.passes import move_to_device_pass from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck -from torch.testing._internal.common_cuda import ( - PLATFORM_SUPPORTS_FLASH_ATTENTION, - xfailIfDistributedNotSupported, -) +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( find_library_location, IS_FBCODE, @@ -15808,7 +15805,6 @@ class GraphModule(torch.nn.Module): finally: torch.distributed.destroy_process_group() - @xfailIfDistributedNotSupported def test_distributed_all_reduce(self): class Foo(torch.nn.Module): def __init__(self): @@ -15826,7 +15822,6 @@ class GraphModule(torch.nn.Module): inp = (torch.randn(4, 4),) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) - @xfailIfDistributedNotSupported def test_distributed_all_gather(self): class Foo(torch.nn.Module): def forward(self, x): @@ -15842,7 +15837,6 @@ class GraphModule(torch.nn.Module): torch.allclose(a, b) for a, b in zip(ep.module()(*inp), m(*inp)) ) - @xfailIfDistributedNotSupported def test_distributed_all_gather_into_tensor(self): class Foo(torch.nn.Module): def forward(self, x): @@ -15856,7 +15850,6 @@ class GraphModule(torch.nn.Module): inp = 
(torch.randn(2),) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) - @xfailIfDistributedNotSupported @testing.expectedFailureCppRuntime def test_distributed_all_to_all_single(self): class Foo(torch.nn.Module): @@ -15874,7 +15867,6 @@ class GraphModule(torch.nn.Module): ) self.assertEqual(len(nodes), 1) - @xfailIfDistributedNotSupported @testing.expectedFailureCppRuntime def test_distributed_reduce_scatter_tensor(self): class Foo(torch.nn.Module): diff --git a/test/test_numa_binding.py b/test/test_numa_binding.py index 764156ff9b98..d38032ba2260 100644 --- a/test/test_numa_binding.py +++ b/test/test_numa_binding.py @@ -7,7 +7,7 @@ import sys from dataclasses import dataclass from multiprocessing.context import SpawnProcess from typing import Any, Optional -from unittest import skipUnless +from unittest import skipIf, skipUnless from unittest.mock import mock_open, patch import torch @@ -22,7 +22,7 @@ from torch.numa.binding import ( AffinityMode, NumaOptions, ) -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase @dataclass(frozen=True) @@ -680,6 +680,7 @@ class NumaBindingTest(TestCase): set(range(0, 2)), ) + @skipIf(IS_MACOS, "sched_getaffinity doesn't exist") def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None: self._add_mock_hardware( num_sockets=1, diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 9d43de80f129..457b224354fb 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -88,8 +88,7 @@ def build_pytorch( ) -> None: my_env = _create_build_env() if ( - not check_negative_env_flag("USE_DISTRIBUTED") - and not check_negative_env_flag("USE_CUDA") + not check_negative_env_flag("USE_CUDA") and not check_negative_env_flag("USE_NCCL") and not check_env_flag("USE_SYSTEM_NCCL") ): diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 866c40ad1c12..adc9aad4a05c 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -276,32 +276,30 @@ add_custom_command( WORKING_DIRECTORY "${TORCH_ROOT}" ) -if(USE_DISTRIBUTED) - if(WIN32) - append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) - else() - append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) - endif() - # Disable certain warnings for GCC-9.X - if(CMAKE_COMPILER_IS_GNUCXX) - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - endif() - # NCCL is a private dependency of libtorch, but libtorch_python includes - # some private headers of libtorch, which in turn include NCCL. As a hacky - # alternative to making NCCL a public dependency of libtorch, we make it - # a private dependency of libtorch_python as well. - if(USE_NCCL) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) - endif() - # Same for MPI. 
- if(USE_MPI) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) - endif() - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) +if(WIN32) + append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) +else() + append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) endif() +# Disable certain warnings for GCC-9.X +if(CMAKE_COMPILER_IS_GNUCXX) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") +endif() +# NCCL is a private dependency of libtorch, but libtorch_python includes +# some private headers of libtorch, which in turn include NCCL. As a hacky +# alternative to making NCCL a public dependency of libtorch, we make it +# a private dependency of libtorch_python as well. +if(USE_NCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) +endif() +# Same for MPI. +if(USE_MPI) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) +endif() +list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) if(USE_NCCL AND NOT WIN32) list(APPEND TORCH_PYTHON_SRCS @@ -369,10 +367,6 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() - if(USE_DISTRIBUTED) - target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) - endif() - if(USE_MPI AND USE_C10D_MPI) target_compile_definitions(torch_python PRIVATE USE_C10D_MPI) endif() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index ad3d8e3abf24..79e437063b8c 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -851,3 +851,12 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... +def _dump_nccl_trace_json( + includeCollectives: Optional[bool] = ..., + onlyActive: Optional[bool] = ..., +) -> bytes: ... +def _dump_nccl_trace( + includeCollectives: Optional[bool] = ..., + includeStackTraces: Optional[bool] = ..., + onlyActive: Optional[bool] = ..., +) -> bytes: ... 
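A minimal consumption sketch for the two entry points declared above (illustrative only, not part of the diff): import them through the torch.distributed._distributed_c10d wrapper this patch introduces, so builds without NCCL transparently get the stub implementations that return empty payloads.

# Sketch: prefer the wrapper module over torch._C._distributed_c10d directly.
from torch.distributed._distributed_c10d import (
    _dump_nccl_trace,
    _dump_nccl_trace_json,
)

json_trace = _dump_nccl_trace_json(includeCollectives=True, onlyActive=False)  # b"{}" on stub builds
full_trace = _dump_nccl_trace(
    includeCollectives=True, includeStackTraces=True, onlyActive=False
)  # b"" on stub builds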
diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 60a7bb644df0..d43d2b02a23e 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -15,9 +15,7 @@ #include #include -#if defined(USE_DISTRIBUTED) #include -#endif inline void PyErr_SetString(PyObject* type, const std::string& message) { PyErr_SetString(type, message.c_str()); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3a04926d5c02..d040e16ba528 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -121,14 +121,10 @@ #endif #endif -#ifdef USE_DISTRIBUTED -#ifdef USE_C10D #include #include #include #include -#endif -#endif #if defined(USE_VALGRIND) #include @@ -553,11 +549,7 @@ static PyObject* THPModule_getBackcompatKeepdimWarn( } static PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) { -#ifdef USE_DISTRIBUTED Py_RETURN_TRUE; -#else - Py_RETURN_FALSE; -#endif } static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) { @@ -2009,7 +2001,6 @@ PyObject* initModule() { #ifdef USE_XPU THPUtils_addPyMethodDefs(methods, THXPModule_methods()); #endif -#if defined(USE_DISTRIBUTED) && defined(USE_C10D) THPUtils_addPyMethodDefs( methods, torch::distributed::c10d::python_functions()); #ifndef _WIN32 @@ -2019,7 +2010,6 @@ PyObject* initModule() { methods, torch::distributed::autograd::python_functions()); THPUtils_addPyMethodDefs( methods, torch::distributed::rpc::testing::python_functions()); -#endif #endif static struct PyModuleDef torchmodule = { diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index 5e19010f9ae3..05c8901e1f60 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -8,9 +8,7 @@ #include #include #include -#ifdef USE_DISTRIBUTED #include -#endif #include #include #include @@ -150,11 +148,9 @@ void THPAutograd_initFunctions() { static PyTypeObject CopyBackwardsClass; addClass(module, CopyBackwardsClass, "CopyBackwards"); -#ifdef USE_DISTRIBUTED static PyTypeObject SendRpcBackwardClass; addClass( module, SendRpcBackwardClass, "SendRpcBackward"); -#endif static PyTypeObject CopySlicesClass; addClass(module, CopySlicesClass, "CopySlices"); diff --git a/torch/csrc/distributed/c10d/HashStore.cpp b/torch/csrc/distributed/c10d/HashStore.cpp index 15befd9ec34e..1055afc4847d 100644 --- a/torch/csrc/distributed/c10d/HashStore.cpp +++ b/torch/csrc/distributed/c10d/HashStore.cpp @@ -1,6 +1,5 @@ #include -#include #include #include diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index cdec9185ce53..2c1ee42727d8 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 7e79fef8392f..128fab6593b3 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -46,6 +46,7 @@ #include #include #include + #include #include diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index b1c864bf3fbb..a610685fe955 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -1,7 +1,5 @@ -#ifdef USE_DISTRIBUTED #include -#endif #include #include @@ -533,7 +531,6 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( }); } -#ifdef USE_DISTRIBUTED AOTITorchError 
aoti_torch_cpu__c10d_functional_all_reduce_( AtenTensorHandle inp, const char* reduce_op, @@ -566,4 +563,3 @@ AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( *ret0 = new_tensor_handle(std::move(tmp_result)); }); } -#endif diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 5ae84e3e0c68..2c0c1ea4b9cf 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include #include #include @@ -24,10 +26,6 @@ #include #include #include -#ifdef USE_DISTRIBUTED -#include -#include -#endif #include #include diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 8b16e089aa50..808fe7d3605b 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1225,7 +1225,7 @@ std::shared_ptr toSugaredValue( } else if (obj.ptr() == py::module::import("torch").attr("_check").ptr()) { return std::make_shared(); #ifdef USE_RPC - // RPC module is only available when build flag "USE_DISTRIBUTED" is on. + // This is not defined on WINDOWS } else if ( isRpcAvailable && obj.ptr() == @@ -1238,7 +1238,6 @@ std::shared_ptr toSugaredValue( return SpecialFormValue::create(prim::rpc_sync); } else if ( isRpcAvailable && - // RPC module is only available when build flag "USE_DISTRIBUTED" is on. obj.ptr() == py::module::import("torch.distributed.rpc").attr("remote").ptr()) { return SpecialFormValue::create(prim::rpc_remote); diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 6ae9f52a0cda..be582cfb7cdd 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -128,13 +128,8 @@ struct InterpreterContinuation { std::optional tls_state = std::nullopt) : state(std::move(state_)), stack(std::move(stack_)), - tls_state_(std::move(tls_state)) -#ifdef USE_DISTRIBUTED - , - dist_autograd_context_id_(dist_autograd_context_id) -#endif - { - } + tls_state_(std::move(tls_state)), + dist_autograd_context_id_(dist_autograd_context_id) {} void operator()(); @@ -142,9 +137,10 @@ struct InterpreterContinuation { InterpreterState state; Stack stack; std::optional tls_state_ = std::nullopt; -#ifdef USE_DISTRIBUTED - int64_t dist_autograd_context_id_; +#ifndef USE_RPC + [[maybe_unused]] #endif + int64_t dist_autograd_context_id_; }; // what is the tensors type, including state from the current execution context diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 526c840bc10e..e3379f4de65a 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -79,9 +79,7 @@ class TORCH_API Pickler { void pushTuple(const IValue& ivalue); void pushString(const std::string& string); void pushDevice(const IValue& ivalue); -#ifdef USE_DISTRIBUTED void pushRRef(const IValue& ivalue); -#endif // unmemoized version void pushStringImpl(const std::string& string); void pushStorageOfTensor(const at::Tensor& tensor); diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 702a1d8816e7..208cf554ad2b 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -140,9 +140,7 @@ class TORCH_API Unpickler { void rebuildParameter(); void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); -#ifdef USE_DISTRIBUTED void rebuildRRef(); -#endif PickleOpCode 
readInstruction(); PickleOpCode readOpCode() { return static_cast(read()); diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 1c88e80d4021..e46c141cd3f4 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -30,15 +30,12 @@ #include #include -#ifdef USE_DISTRIBUTED #include -#endif // USE_DISTRIBUTED using namespace at; // Collective property attributes // https://github.com/pytorch/pytorch/issues/124674 -#ifdef USE_DISTRIBUTED constexpr auto kETCommsName = "collective_name"; constexpr auto kETInMsgNelems = "in_msg_nelems"; constexpr auto kETOutMsgNelems = "out_msg_nelems"; @@ -49,7 +46,6 @@ constexpr auto kETGlobalRankStride = "global_rank_stride"; constexpr auto kETGroupSize = "pg_size"; constexpr auto kETProcessGroupName = "pg_name"; constexpr auto kETProcessGroupDesc = "pg_desc"; -#endif // USE_DISTRIBUTED namespace torch::profiler::impl { @@ -269,7 +265,6 @@ static std::ofstream openOutputFile(const std::string& name) { return stream; } -#ifdef USE_DISTRIBUTED static std::string getAttrJson( const std::string& name, const std::string& type, @@ -282,7 +277,6 @@ static std::string getAttrJson( type, value); } -#endif static void writeJsonNode( std::ofstream& out, @@ -660,7 +654,6 @@ static void handleKernelBackendInfo( inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT std::vector attrs; -#ifdef USE_DISTRIBUTED // We rely on paramcommsdebug object that is available in thread local info auto debugInfo = dynamic_cast( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); @@ -704,8 +697,6 @@ inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT addAttr(kGroupSize, kETGroupSize, "uint64"); -#endif // USE_DISTRIBUTED - // XXX consider using as string stream? return attrs.empty() ? 
"" : fmt::format(", {}", fmt::join(attrs, ", ")); } diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 0b2979e6fb7e..e97699a99fd1 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -11,9 +11,7 @@ #ifdef USE_KINETO #include #endif -#ifdef USE_DISTRIBUTED #include -#endif // USE_DISTRIBUTED namespace torch::profiler::impl { @@ -455,7 +453,7 @@ std::unordered_map saveNcclMeta( // @lint-ignore CLANGTIDY const SaveNcclMetaConfig& config) { std::unordered_map map; -#ifdef USE_DISTRIBUTED +#if !defined(BUILD_LITE_INTERPRETER) && !defined(C10_MOBILE) auto debugInfo = dynamic_cast( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); @@ -565,7 +563,7 @@ std::unordered_map saveNcclMeta( } } } -#endif // USE_DISTRIBUTED +#endif // !defined(BUILD_LITE_INTERPRETER) && !defined(C10_MOBILE) return map; } diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index f2ae57fa0e59..dcb4b866a2de 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -185,7 +185,6 @@ struct HashCombine { } }; -#ifdef USE_DISTRIBUTED constexpr auto kCommsName = "Collective name"; constexpr auto kDtype = "dtype"; constexpr auto kInMsgNelems = "In msg nelems"; @@ -203,6 +202,5 @@ constexpr auto kP2pSrc = "Src Rank"; constexpr auto kP2pDst = "Dst Rank"; constexpr auto kInTensorsStart = "Input Tensors start"; constexpr auto kOutTensorsStart = "Output Tensors start"; -#endif // USE_DISTRIBUTED } // namespace torch::profiler::impl diff --git a/torch/distributed/_C_stubs.py b/torch/distributed/_C_stubs.py new file mode 100644 index 000000000000..b241006372b6 --- /dev/null +++ b/torch/distributed/_C_stubs.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +""" +Python stubs for backend-specific distributed components. + +Since _C._distributed_c10d always exists now, this module only provides +stubs for backend-specific functionality that may not be available in all builds +(e.g., NCCL, UCC, MPI, Gloo, etc.). 
+""" + +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +from torch._C._distributed_c10d import Store + + +if TYPE_CHECKING: + from datetime import timedelta + +import torch + + +# Store classes +class HashStore(Store): + """Stub HashStore for builds without this functionality.""" + + def __init__(self, *args, **kwargs): + self._data = {} + + def set(self, key: str, value: str): + self._data[key] = value + + def get(self, key: str) -> bytes: + return self._data.get(key, "").encode() + + +# Backend-specific process group stubs +class ProcessGroupMPI: + """Stub ProcessGroupMPI for non-MPI builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupNCCL: + """Stub ProcessGroupNCCL for non-NCCL builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupGloo: + """Stub ProcessGroupGloo for non-Gloo builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupUCC: + """Stub ProcessGroupUCC for non-UCC builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupXCCL: + """Stub ProcessGroupXCCL for non-XCCL builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class _ProcessGroupWrapper: + """Stub _ProcessGroupWrapper for non-Gloo builds.""" + + def __init__(self, process_group, *args, **kwargs): + self._process_group = process_group + + def __getattr__(self, name): + return getattr(self._process_group, name) + + +# NCCL-specific function stubs +_DEFAULT_PG_NCCL_TIMEOUT: Optional[timedelta] = None + + +def _hash_tensors(tensors): + """Stub function to hash tensors - returns dummy hash.""" + return 0 + + +def _dump_nccl_trace_json( + includeCollectives: Optional[bool] = None, onlyActive: Optional[bool] = None +) -> bytes: + """Stub function that returns empty JSON trace.""" + return b"{}" + + +def _dump_nccl_trace( + includeCollectives: Optional[bool] = None, + includeStackTraces: Optional[bool] = None, + onlyActive: Optional[bool] = None, +) -> bytes: + """Stub function that returns empty pickle trace.""" + return b"" + + +# NVSHMEM/SymmetricMemory stubs +def _is_nvshmem_available() -> bool: + """Stub function that returns False indicating NVSHMEM is not available.""" + return False + + +def _nvshmemx_cumodule_init(module: int) -> None: + """Stub function for NVSHMEM CU module initialization.""" + + +class _SymmetricMemory: + """Stub _SymmetricMemory class for builds without this functionality.""" + + def __init__(self, *args, **kwargs): + pass + + @classmethod + def empty_strided_p2p(cls, size, stride, dtype, device, group_name=None): + """Stub that returns a regular tensor.""" + return torch.empty(size, dtype=dtype, device=device) + + @classmethod + def rendezvous(cls, tensor, group_name=None): + """Stub that returns None.""" + return None + + @classmethod + def set_group_info(cls, *args, **kwargs): + """Stub that does nothing.""" + + @classmethod + def set_backend(cls, name): + """Stub that does nothing.""" + + @classmethod + def get_backend(cls, device): + """Stub that returns None.""" + return None + + @classmethod + def has_multicast_support(cls, device_type, device_index): + """Stub that returns False.""" + return False diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 38e2fdbee803..836b00c51c3a 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -14,16 +14,10 @@ log = logging.getLogger(__name__) def is_available() -> bool: """ - Return ``True`` if the distributed package is available. 
- - Otherwise, - ``torch.distributed`` does not expose any other APIs. Currently, - ``torch.distributed`` is available on Linux, MacOS and Windows. Set - ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. - Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, - ``USE_DISTRIBUTED=0`` for MacOS. + Always returns ``True``. Note that even if distributed is available, + there may not necessarily be any usable backends. """ - return hasattr(torch._C, "_c10d_init") + return True if is_available() and not torch._C._c10d_init(): @@ -36,132 +30,124 @@ DistNetworkError = torch._C._DistNetworkError DistStoreError = torch._C._DistStoreError QueueEmptyError = torch._C._DistQueueEmptyError -if is_available(): - from torch._C._distributed_c10d import ( - _broadcast_coalesced, - _compute_bucket_assignment_by_size, - _ControlCollectives, - _DEFAULT_FIRST_BUCKET_BYTES, - _make_nccl_premul_sum, - _register_builtin_comm_hook, - _register_comm_hook, - _StoreCollectives, - _test_python_store, - _verify_params_across_processes, - Backend as _Backend, - BuiltinCommHookType, - DebugLevel, - FileStore, - get_debug_level, - GradBucket, - Logger, - PrefixStore, - ProcessGroup as ProcessGroup, - Reducer, - set_debug_level, - set_debug_level_from_env, - Store, - TCPStore, - Work as _Work, - ) +from torch.distributed._distributed_c10d import ( + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, + _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, + DebugLevel, + FileStore, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work as _Work, +) - class _DistributedPdb(pdb.Pdb): - """ - Supports using PDB from inside a multiprocessing child process. - Usage: - _DistributedPdb().set_trace() - """ +class _DistributedPdb(pdb.Pdb): + """ + Supports using PDB from inside a multiprocessing child process. - def interaction(self, *args, **kwargs): - _stdin = sys.stdin - try: - sys.stdin = open("/dev/stdin") - pdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin + Usage: + _DistributedPdb().set_trace() + """ - _breakpoint_cache: dict[int, typing.Any] = {} - - def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): - """ - Set a breakpoint, but only on a single rank. All other ranks will wait for you to be - done with the breakpoint before continuing. - - Args: - rank (int): Which rank to break on. Default: ``0`` - skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. - """ - if skip > 0: - key = hash(str(traceback.format_exc())) - counter = _breakpoint_cache.get(key, 0) + 1 - _breakpoint_cache[key] = counter - if counter <= skip: - log.warning("Skip the breakpoint, counter=%d", counter) - return - - # avoid having the default timeout (if short) interrupt your debug session - if timeout_s is not None: - for group in torch.distributed.distributed_c10d._pg_map: - torch.distributed.distributed_c10d._set_pg_timeout( - timedelta(seconds=timeout_s), group - ) - - if get_rank() == rank: - pdb = _DistributedPdb() - pdb.message( - "\n!!! 
ATTENTION !!!\n\n" - f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" - ) - pdb.set_trace() - # If Meta/Python keys are in the TLS, we want to make sure that we ignore them - # and hit the (default) CPU/CUDA implementation of barrier. - meta_in_tls = torch._C._meta_in_tls_dispatch_include() - guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] - torch._C._set_meta_in_tls_dispatch_include(False) + def interaction(self, *args, **kwargs): + _stdin = sys.stdin try: - barrier() + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) finally: - torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) - del guard + sys.stdin = _stdin - if sys.platform != "win32": - from torch._C._distributed_c10d import HashStore - from .device_mesh import DeviceMesh, init_device_mesh +_breakpoint_cache: dict[int, typing.Any] = {} - # Variables prefixed with underscore are not auto imported - # See the comment in `distributed_c10d.py` above `_backend` on why we expose - # this. - from .distributed_c10d import * # noqa: F403 - from .distributed_c10d import ( - _all_gather_base, - _coalescing_manager, - _CoalescingManager, - _create_process_group_wrapper, - _get_process_group_name, - _rank_not_in_group, - _reduce_scatter_base, - _time_estimator, - get_node_local_rank, - ) - from .remote_device import _remote_device - from .rendezvous import ( - _create_store_from_options, - register_rendezvous_handler, - rendezvous, - ) - set_debug_level_from_env() +def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): + """ + Set a breakpoint, but only on a single rank. All other ranks will wait for you to be + done with the breakpoint before continuing. -else: - # This stub is sufficient to get - # python test/test_public_bindings.py -k test_correct_module_names - # working even when USE_DISTRIBUTED=0. Feel free to add more - # stubs as necessary. - # We cannot define stubs directly because they confuse pyre + Args: + rank (int): Which rank to break on. Default: ``0`` + skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. + """ + if skip > 0: + key = hash(str(traceback.format_exc())) + counter = _breakpoint_cache.get(key, 0) + 1 + _breakpoint_cache[key] = counter + if counter <= skip: + log.warning("Skip the breakpoint, counter=%d", counter) + return - class _ProcessGroupStub: - pass + # avoid having the default timeout (if short) interrupt your debug session + if timeout_s is not None: + for group in torch.distributed.distributed_c10d._pg_map: + torch.distributed.distributed_c10d._set_pg_timeout( + timedelta(seconds=timeout_s), group + ) - sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] + if get_rank() == rank: + pdb = _DistributedPdb() + pdb.message( + "\n!!! ATTENTION !!!\n\n" + f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" + ) + pdb.set_trace() + # If Meta/Python keys are in the TLS, we want to make sure that we ignore them + # and hit the (default) CPU/CUDA implementation of barrier. 
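+        # NB: the guard created below is torn down in the finally block,
+        # together with restoring the meta-in-TLS flag, so dispatch behavior
+        # returns to normal once the barrier completes.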
+ meta_in_tls = torch._C._meta_in_tls_dispatch_include() + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + torch._C._set_meta_in_tls_dispatch_include(False) + try: + barrier() + finally: + torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) + del guard + + +if sys.platform != "win32": + from torch.distributed._distributed_c10d import HashStore + +from .device_mesh import DeviceMesh, init_device_mesh + +# Variables prefixed with underscore are not auto imported +# See the comment in `distributed_c10d.py` above `_backend` on why we expose +# this. +from .distributed_c10d import * # noqa: F403 +from .distributed_c10d import ( + _all_gather_base, + _coalescing_manager, + _CoalescingManager, + _create_process_group_wrapper, + _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, + _time_estimator, + get_node_local_rank, +) +from .remote_device import _remote_device +from .rendezvous import ( + _create_store_from_options, + register_rendezvous_handler, + rendezvous, +) + + +set_debug_level_from_env() diff --git a/torch/distributed/_dist2.py b/torch/distributed/_dist2.py index ce5cb8d7e0cc..1c27bf55d683 100644 --- a/torch/distributed/_dist2.py +++ b/torch/distributed/_dist2.py @@ -10,7 +10,7 @@ from datetime import timedelta from typing import Protocol, Union import torch -from torch._C._distributed_c10d import ( +from torch.distributed._distributed_c10d import ( _current_process_group, _set_process_group, ProcessGroup, diff --git a/torch/distributed/_distributed_c10d.py b/torch/distributed/_distributed_c10d.py new file mode 100644 index 000000000000..beb7830edc1d --- /dev/null +++ b/torch/distributed/_distributed_c10d.py @@ -0,0 +1,245 @@ +# mypy: disable-error-code="assignment" +# noqa: F401 +""" +Centralized module for importing and re-exporting torch._C._distributed_c10d components. + +IMPORTANT PATTERN: +Never access torch._C._distributed_c10d directly in code. Always import from and use +torch.distributed._distributed_c10d which is guaranteed to have all functions available. 
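+(Backend-specific names that were compiled out of the C extension resolve to
+pure-Python fallbacks from torch.distributed._C_stubs, so this import pattern
+never raises ImportError.)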
+ +Example: + # WRONG: torch._C._distributed_c10d._set_global_rank(rank) + # RIGHT: + from torch.distributed._distributed_c10d import _set_global_rank + _set_global_rank(rank) +""" + +from typing import TYPE_CHECKING + +# Import all core distributed components from the C extension +# NB: This list has to be spelled out because the _C module doesn't have __all__ +from torch._C._distributed_c10d import ( + _allow_inflight_collective_as_graph_input, + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _current_process_group, + _DEFAULT_FIRST_BUCKET_BYTES, + _DEFAULT_PG_TIMEOUT, + _DistributedBackendOptions, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _register_process_group, + _register_work, + _resolve_process_group, + _set_allow_inflight_collective_as_graph_input, + _set_global_rank, + _set_process_group, + _StoreCollectives, + _test_python_store, + _unregister_all_process_groups, + _unregister_process_group, + _verify_params_across_processes, + _WorkerServer, + AllgatherOptions, + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + Backend, + BarrierOptions, + BroadcastOptions, + BuiltinCommHookType, + DebugLevel, + FakeProcessGroup, + FakeWork, + FileStore, + GatherOptions, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + Reducer, + ReduceScatterOptions, + ScatterOptions, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work, +) + + +# Backend-specific components that may not be available +_MPI_AVAILABLE = False +_NCCL_AVAILABLE = False +_GLOO_AVAILABLE = False +_UCC_AVAILABLE = False +_XCCL_AVAILABLE = False + +# HashStore +try: + from torch._C._distributed_c10d import HashStore +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import HashStore + +# NVSHMEM/SymmetricMemory components + +# There are multiple backends for SymmetricMemory, as a result, +# _SymmetricMemory should not be imported together with NVSHMEM related modules. 
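+#
+# Every optional import below is guarded individually, so a build that lacks
+# one backend still gets the real bindings for all the others; only the
+# missing names fall back to torch.distributed._C_stubs.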
+try: + from torch._C._distributed_c10d import _SymmetricMemory +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import _SymmetricMemory + +try: + from torch._C._distributed_c10d import ( + _is_nvshmem_available, + _nvshmemx_cumodule_init, + ) +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ( + _is_nvshmem_available, + _nvshmemx_cumodule_init, + ) + +# MPI backend +try: + from torch._C._distributed_c10d import ProcessGroupMPI + + _MPI_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ProcessGroupMPI + +# NCCL backend +try: + from torch._C._distributed_c10d import ( + _DEFAULT_PG_NCCL_TIMEOUT, + _dump_nccl_trace, + _dump_nccl_trace_json, + _hash_tensors, + ProcessGroupNCCL, + ) + + _NCCL_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ( + _DEFAULT_PG_NCCL_TIMEOUT, + _dump_nccl_trace, + _dump_nccl_trace_json, + _hash_tensors, + ProcessGroupNCCL, + ) + +# Gloo backend +try: + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + + _GLOO_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import _ProcessGroupWrapper, ProcessGroupGloo + +# UCC backend +try: + from torch._C._distributed_c10d import ProcessGroupUCC + + _UCC_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ProcessGroupUCC + +# XCCL backend +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + _XCCL_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ProcessGroupXCCL + +# Provide backwards compatibility by making all symbols available at module level +__all__ = [ + # Basic components + "_broadcast_coalesced", + "_compute_bucket_assignment_by_size", + "_ControlCollectives", + "_DEFAULT_FIRST_BUCKET_BYTES", + "_DEFAULT_PG_TIMEOUT", + "_DEFAULT_PG_NCCL_TIMEOUT", + "_make_nccl_premul_sum", + "_register_builtin_comm_hook", + "_register_comm_hook", + "_StoreCollectives", + "_test_python_store", + "_verify_params_across_processes", + "_allow_inflight_collective_as_graph_input", + "_register_work", + "_set_allow_inflight_collective_as_graph_input", + "_is_nvshmem_available", + "_nvshmemx_cumodule_init", + "_SymmetricMemory", + "_hash_tensors", + "_set_global_rank", + "_dump_nccl_trace", + "_dump_nccl_trace_json", + "Backend", + "BuiltinCommHookType", + "DebugLevel", + "FakeProcessGroup", + "FileStore", + "get_debug_level", + "GradBucket", + "HashStore", + "Logger", + "PrefixStore", + "ProcessGroup", + "Reducer", + "ReduceOp", + "set_debug_level", + "set_debug_level_from_env", + "Store", + "TCPStore", + "Work", + "FakeWork", + # Additional distributed_c10d components + "_DistributedBackendOptions", + "_register_process_group", + "_resolve_process_group", + "_unregister_all_process_groups", + "_unregister_process_group", + "_current_process_group", + "_set_process_group", + "_WorkerServer", + "AllgatherOptions", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + # Process group implementations + "ProcessGroupMPI", + "ProcessGroupNCCL", + "ProcessGroupGloo", + "ProcessGroupUCC", + "ProcessGroupXCCL", + "_ProcessGroupWrapper", + # Availability flags + "_MPI_AVAILABLE", + "_NCCL_AVAILABLE", + "_GLOO_AVAILABLE", + "_UCC_AVAILABLE", + 
"_XCCL_AVAILABLE", +] diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index c893794fc301..95feb6cd7971 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -7,6 +7,10 @@ from typing import Any, cast, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist import torch.distributed.distributed_c10d as c10d +from torch.distributed._distributed_c10d import ( + _allow_inflight_collective_as_graph_input, + _set_allow_inflight_collective_as_graph_input, +) from torch.distributed.device_mesh import DeviceMesh from torch.fx.experimental.proxy_tensor import get_proxy_mode @@ -858,15 +862,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True): will be registered in the work registry, and the wait_tensor() in compiled region called on the output tensor of the collective will wait on the correct work object. """ - previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input() + previous = _allow_inflight_collective_as_graph_input() try: - torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value) + _set_allow_inflight_collective_as_graph_input(value) yield finally: - torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input( - previous - ) + _set_allow_inflight_collective_as_graph_input(previous) def _make_all_gather_out_tensor(input, group_size): diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index daef9c358618..2bc3d65e5c8c 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -4,7 +4,7 @@ import copy import torch import torch.distributed as dist import torch.distributed._shard.sharding_spec as shard_spec -from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._distributed_c10d import ProcessGroup from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharding_spec._internals import ( get_chunked_dim_size, diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index 61808d0adf62..f02563619d2f 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -4,7 +4,7 @@ from typing import cast import torch import torch.distributed as dist -from torch._C._distributed_c10d import ReduceOp +from torch.distributed._distributed_c10d import ReduceOp from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharding_spec import ChunkShardingSpec from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index b903a7085e9f..77e05cf9b162 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -15,7 +15,12 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch._C._autograd import DeviceType -from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work +from torch.distributed._distributed_c10d import ( + _register_work, + _SymmetricMemory, + ProcessGroup, + Work 
as _Work, +) _group_name_to_store: dict[str, c10d.Store] = {} @@ -1488,7 +1493,7 @@ def _low_contention_all_gather( src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) chunks[remote_rank].copy_(src_buf) symm_mem.barrier() - torch._C._distributed_c10d._register_work(output, Work()) + _register_work(output, Work()) return output @@ -1536,7 +1541,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input( ret = ret.mean(dim=0) else: raise ValueError(f"reduce_op ({reduce_op}) is not supported") - torch._C._distributed_c10d._register_work(ret, Work()) + _register_work(ret, Work()) return ret @@ -1571,7 +1576,7 @@ def _low_contention_reduce_scatter_with_workspace( ret = ret.mean(dim=0) else: raise ValueError(f"reduce_op ({reduce_op}) is not supported") - torch._C._distributed_c10d._register_work(ret, Work()) + _register_work(ret, Work()) return ret @@ -1649,7 +1654,6 @@ from typing import overload, TYPE_CHECKING, Union if TYPE_CHECKING: - from torch._C._distributed_c10d import ProcessGroup from torch.types import _device, _dtype, _int @@ -1727,8 +1731,6 @@ def rendezvous( group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the participating processes. This can be either a group name or a process group object. """ - from torch._C._distributed_c10d import ProcessGroup - if isinstance(group, str): group_name = group elif isinstance(group, ProcessGroup): @@ -1746,11 +1748,7 @@ def is_nvshmem_available() -> bool: Check if NVSHMEM is available in current build and on current system. """ - try: - from torch._C._distributed_c10d import _is_nvshmem_available - except ImportError: - # Not all builds have NVSHMEM support. - return False + from torch.distributed._distributed_c10d import _is_nvshmem_available # Check if NVSHMEM is available on current system. return _is_nvshmem_available() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 797c611443aa..4bad8ff0ceb8 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -80,7 +80,7 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: """ import triton - from torch._C._distributed_c10d import _nvshmemx_cumodule_init + from torch.distributed._distributed_c10d import _nvshmemx_cumodule_init if lib_dir is not None: lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index 3b201b395334..b89970ab3348 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -2,7 +2,9 @@ import random from typing import Any import torch -from torch._C._distributed_c10d import ( + +# Import centralized distributed components +from torch.distributed._distributed_c10d import ( _resolve_process_group, FakeWork, ProcessGroup, diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index fa8cc184eddc..3e3243002a9c 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -5,10 +5,6 @@ from typing import Union import torch import torch.distributed as dist - -# The two imports below are not always available depending on the -# USE_DISTRIBUTED compile flag. Make sure they raise import error -# if we're trying to use them. 
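+# These imports no longer need an availability guard: distributed is always
+# built, and group/ProcessGroup are importable even when no backend (gloo,
+# NCCL, etc.) is compiled in.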
from torch.distributed import group, ProcessGroup
diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py
index c1e604bc8675..bfa878521864 100644
--- a/torch/distributed/constants.py
+++ b/torch/distributed/constants.py
@@ -1,7 +1,11 @@
 from datetime import timedelta
 from typing import Optional
 
-from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
+# Import from centralized fallback module - no ImportError handling needed
+from torch.distributed._distributed_c10d import (
+    _DEFAULT_PG_NCCL_TIMEOUT,
+    _DEFAULT_PG_TIMEOUT,
+)
 
 __all__ = ["default_pg_timeout", "default_pg_nccl_timeout"]
 
@@ -16,11 +20,4 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
 # Later, we could consider merging them back together at the c++ layer if we can align on a same value.
 # (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1).
 
-try:
-    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
-
-    default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
-except ImportError:
-    # if C++ NCCL support is not compiled, we don't have access to the default nccl value.
-    # if anyone is actually trying to use nccl in this state, it should error.
-    default_pg_nccl_timeout = None
+default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py
index 3a9363090bf7..6ee9263db8cd 100644
--- a/torch/distributed/device_mesh.py
+++ b/torch/distributed/device_mesh.py
@@ -11,35 +11,14 @@ from itertools import zip_longest
 from typing import Optional, TYPE_CHECKING, Union
 
 import torch
-from torch.distributed import is_available
 from torch.utils._typing_utils import not_none
 
 
 __all__ = ["init_device_mesh", "DeviceMesh"]
 
 
-if not is_available():
-    import sys
-
-    # We need to create the stubs when distributed is not available.
-    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
-    # since it would try to import ``torch.distributed.device_mesh`` or
-    # ``torch.distributed.init_device_mesh`` but cannot find them.
-
-    class _DeviceMeshStub:
-        pass
-
-    def _init_device_mesh_stub():
-        pass
-
-    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
-    sys.modules[
-        "torch.distributed.device_mesh"
-    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
-
-
-else:
-    from torch._C._distributed_c10d import Backend as C10dBackend
+if True:  # just to temporarily avoid reindentation
+    from torch.distributed._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
         _get_default_group,
         _resolve_process_group,
@@ -534,15 +513,16 @@ else:
         # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
         # NOTE: This device selection would only work for homogeneous hardware.
         num_devices_per_host = device_handle.device_count()
-        if (
-            world_size > num_devices_per_host
-            and world_size % num_devices_per_host != 0
-        ):
-            raise RuntimeError(
-                f"DeviceMesh only support homogeneous hardware, but found "
-                f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
-            )
-        device_handle.set_device(get_rank() % num_devices_per_host)
+        if num_devices_per_host:
+            if (
+                world_size > num_devices_per_host
+                and world_size % num_devices_per_host != 0
+            ):
+                raise RuntimeError(
+                    f"DeviceMesh only supports homogeneous hardware, but found "
+                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
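+                    # Worked example (hypothetical numbers): with 8 devices
+                    # per host, world_size=12 lands here because 12 > 8 and
+                    # 12 % 8 != 0, so ranks cannot be mapped evenly onto
+                    # local devices.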
+ ) + device_handle.set_device(get_rank() % num_devices_per_host) return _get_default_group() diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index afad129ed939..75c973c4e2a6 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import deprecated import torch +import torch.distributed._distributed_c10d as _c10d from torch._C import _DistStoreError as DistStoreError -from torch._C._distributed_c10d import ( +from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs +from torch.distributed._distributed_c10d import ( # Process group implementations; Availability flags _DistributedBackendOptions, + _GLOO_AVAILABLE, + _MPI_AVAILABLE, + _NCCL_AVAILABLE, + _ProcessGroupWrapper, _register_process_group, _resolve_process_group, + _UCC_AVAILABLE, _unregister_all_process_groups, _unregister_process_group, + _XCCL_AVAILABLE, AllgatherOptions, AllreduceCoalescedOptions, AllreduceOptions, @@ -37,6 +45,11 @@ from torch._C._distributed_c10d import ( get_debug_level, PrefixStore, ProcessGroup, + ProcessGroupGloo, + ProcessGroupMPI, + ProcessGroupNCCL, + ProcessGroupUCC, + ProcessGroupXCCL, ReduceOp, ReduceOptions, ReduceScatterOptions, @@ -44,7 +57,6 @@ from torch._C._distributed_c10d import ( Store, Work, ) -from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs from torch.monitor import _WaitCounter from torch.overrides import handle_torch_function, has_torch_function from torch.utils._typing_utils import not_none @@ -131,17 +143,11 @@ __all__ = [ "split_group", ] -_MPI_AVAILABLE = True -_NCCL_AVAILABLE = True -_GLOO_AVAILABLE = True -_UCC_AVAILABLE = True -_XCCL_AVAILABLE = True - _pickler = pickle.Pickler _unpickler = pickle.Unpickler -# Change __module__ of all imported types from torch._C._distributed_c10d that are public +# Change __module__ of all imported types from the distributed wrapper that are public def _export_c_types() -> None: _public_types_to_change_module = [ AllreduceCoalescedOptions, @@ -167,45 +173,26 @@ def _export_c_types() -> None: _export_c_types() -try: - from torch._C._distributed_c10d import ProcessGroupMPI - +# Add process groups to __all__ and set their module based on availability +if _MPI_AVAILABLE: ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupMPI"] -except ImportError: - _MPI_AVAILABLE = False - -try: - from torch._C._distributed_c10d import ProcessGroupNCCL +if _NCCL_AVAILABLE: ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupNCCL"] -except ImportError: - _NCCL_AVAILABLE = False - -try: - from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo +if _GLOO_AVAILABLE: ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupGloo"] -except ImportError: - _GLOO_AVAILABLE = False - -try: - from torch._C._distributed_c10d import ProcessGroupUCC +if _UCC_AVAILABLE: ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupUCC"] -except ImportError: - _UCC_AVAILABLE = False - -try: - from torch._C._distributed_c10d import ProcessGroupXCCL +if _XCCL_AVAILABLE: ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupXCCL"] -except ImportError: - _XCCL_AVAILABLE = False logger = logging.getLogger(__name__) @@ -1327,7 +1314,8 @@ def 
_get_default_store() -> Store: def _update_default_pg(pg) -> None: _world.default_pg = pg rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 - torch._C._distributed_c10d._set_global_rank(rank) + + _c10d._set_global_rank(rank) def get_backend_config(group: Optional[ProcessGroup] = None) -> str: @@ -1964,7 +1952,7 @@ def _new_process_group_helper( if device_id: pg.bound_device_id = device_id - backend_class: torch._C._distributed_c10d.Backend + backend_class: _c10d.Backend for device, backend_str in backend_config.get_device_backend_map().items(): # Use the group name as prefix in the default store, such that # a single store can be reused by multiple groups. @@ -3079,7 +3067,9 @@ def _object_to_tensor(obj, device, group): if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): backend = get_backend(group) if backend == Backend.NCCL: - hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + from torch.distributed._distributed_c10d import _hash_tensors + + hash = _hash_tensors([byte_tensor]) logger.warning( "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), @@ -3094,7 +3084,9 @@ def _tensor_to_object(tensor, tensor_size, group): if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): backend = get_backend(group) if backend == Backend.NCCL: - hash = torch._C._distributed_c10d._hash_tensors([tensor]) + from torch.distributed._distributed_c10d import _hash_tensors + + hash = _hash_tensors([tensor]) logger.warning( "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash ) @@ -4971,7 +4963,7 @@ def monitored_barrier( def _create_process_group_wrapper( - wrapped_pg: torch._C._distributed_c10d.Backend, + wrapped_pg: _c10d.Backend, store_prefix: str, store: Store, rank: int, diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py index 817255edd23d..63334a0ca3f6 100644 --- a/torch/distributed/elastic/control_plane.py +++ b/torch/distributed/elastic/control_plane.py @@ -14,7 +14,7 @@ TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" @contextmanager def _worker_server(socket_path: str) -> Generator[None, None, None]: - from torch._C._distributed_c10d import _WorkerServer + from torch.distributed._distributed_c10d import _WorkerServer server = _WorkerServer(socket_path) try: diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index eeff877260bc..2bdf3fe2bdff 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -2,10 +2,6 @@ import torch import torch.distributed as dist from torch.autograd import Function - -# The two imports below are not always available depending on the -# USE_DISTRIBUTED compile flag. Make sure they raise import error -# if we're trying to use them. 
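+# group and ReduceOp are now always importable, whether or not a collective
+# backend was built; no ImportError handling is required here.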
from torch.distributed import group, ReduceOp diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index adf901d6b6e3..27a945a92e44 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -37,7 +37,6 @@ if is_available(): import numbers import torch.distributed.autograd as dist_autograd - from torch._C._distributed_c10d import Store from torch._C._distributed_rpc import ( # noqa: F401 _cleanup_python_rpc_handler, _DEFAULT_INIT_METHOD, @@ -70,6 +69,7 @@ if is_available(): RpcBackendOptions, WorkerInfo, ) + from torch.distributed._distributed_c10d import Store if _is_tensorpipe_available: from torch._C._distributed_rpc import ( # noqa: F401 diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 4fce6fea538a..f01836c59592 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -8,8 +8,10 @@ from typing import Optional import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._dtensor_spec as dtensor_spec -from torch._C._distributed_c10d import _resolve_process_group from torch._logging import warning_once + +# Import from centralized fallback module - no conditional imports needed +from torch.distributed._distributed_c10d import _resolve_process_group from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.distributed_c10d import ( _get_group_size_by_name, diff --git a/torch/testing/_internal/distributed/fake_pg.py b/torch/testing/_internal/distributed/fake_pg.py index e160f2fe5061..a36d2da29b4a 100644 --- a/torch/testing/_internal/distributed/fake_pg.py +++ b/torch/testing/_internal/distributed/fake_pg.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import torch.distributed as dist -from torch._C._distributed_c10d import FakeProcessGroup +from torch.distributed._distributed_c10d import FakeProcessGroup class FakeStore(dist.Store):