From b7034e9c924412bfbe8ee25a22d7e95239b5ca65 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Mon, 1 Sep 2025 14:59:40 -0400 Subject: [PATCH] Always build USE_DISTRIBUTED. (#160449) Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/160449 Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/dcci --- .ci/pytorch/macos-build.sh | 7 +- .ci/pytorch/macos-test.sh | 2 + .ci/wheel/build_wheel.sh | 3 +- BUILD.bazel | 1 - CMakeLists.txt | 12 +- caffe2/CMakeLists.txt | 142 ++++++++---------- cmake/Dependencies.cmake | 2 +- cmake/Summary.cmake | 12 +- docs/source/conf.py | 7 - test/cpp/dist_autograd/CMakeLists.txt | 2 +- test/export/test_export.py | 10 +- tools/build_pytorch_libs.py | 3 +- torch/CMakeLists.txt | 50 +++--- torch/csrc/Exceptions.h | 2 - torch/csrc/Module.cpp | 8 +- torch/csrc/autograd/functions/init.cpp | 4 - torch/csrc/inductor/aoti_torch/shim_cpu.cpp | 4 - torch/csrc/jit/python/pybind_utils.h | 6 +- .../csrc/jit/python/python_sugared_value.cpp | 3 +- torch/csrc/jit/runtime/interpreter.h | 14 +- torch/csrc/jit/serialization/pickler.h | 2 - torch/csrc/jit/serialization/unpickler.h | 2 - .../standalone/execution_trace_observer.cpp | 9 -- torch/csrc/profiler/util.cpp | 4 - torch/csrc/profiler/util.h | 2 - torch/distributed/__init__.py | 12 +- .../algorithms/model_averaging/utils.py | 4 - torch/distributed/nn/functional.py | 4 - 28 files changed, 120 insertions(+), 213 deletions(-) diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index d7447e7d4858..d41c3c08e628 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh @@ -35,11 +35,10 @@ fi print_cmake_info if [[ ${BUILD_ENVIRONMENT} == *"distributed"* ]]; then - # Needed for inductor benchmarks, as lots of HF networks make `torch.distribtued` calls - USE_DISTRIBUTED=1 USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel + USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel else - # Explicitly set 
USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests - # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 + # NB: we always build with distributed; USE_DISTRIBUTED=0 turns off all + # backends (specifically the gloo backend), so test that this case works too USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel --plat-name macosx_11_0_arm64 fi if which sccache > /dev/null; then diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index f7a7f950e453..401749cc94f7 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -16,6 +16,8 @@ popd # enable debug asserts in serialization export TORCH_SERIALIZATION_DEBUG=1 +python -mpip install --no-input -r requirements.txt + setup_test_python() { # The CircleCI worker hostname doesn't resolve to an address. # This environment variable makes ProcessGroupGloo default to diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index b9b6448ae208..9ce81a883126 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -213,7 +213,8 @@ pip install requests ninja typing-extensions retry pip install -r "${pytorch_rootdir}/requirements.txt" || true retry brew install libomp -# For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule +# For USE_DISTRIBUTED=1 on macOS, this enables gloo, which needs libuv, which +# is built as part of tensorpipe submodule export USE_DISTRIBUTED=1 export USE_MKLDNN=OFF diff --git a/BUILD.bazel b/BUILD.bazel index d4202e7a2c1e..2cbd36f06761 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,7 +22,6 @@ COMMON_COPTS = [ "-DHAVE_SHM_UNLINK=1", "-D_FILE_OFFSET_BITS=64", "-DUSE_FBGEMM", - "-DUSE_DISTRIBUTED", "-DAT_PER_OPERATOR_HEADERS", "-DATEN_THREADING=NATIVE", "-DNO_CUDNN_DESTROY_HANDLE", diff --git a/CMakeLists.txt 
b/CMakeLists.txt index 6b6d6be45941..3825cc494ab6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,8 +181,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64le)") set(CPU_POWER ON) endif() -# For non-supported platforms, turn USE_DISTRIBUTED off by default. It is not -# tested and likely won't work without additional changes. +# For non-supported platforms, turn USE_DISTRIBUTED off by default. +# NB: USE_DISTRIBUTED simply disables the backend; distributed code +# still gets built if(NOT LINUX AND NOT WIN32) set(USE_DISTRIBUTED OFF @@ -261,11 +262,11 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) -option(USE_DISTRIBUTED "Use distributed" ON) +option(USE_DISTRIBUTED "Enable default distributed backends" ON) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_XCCL "Use XCCL" ON - "USE_XPU;UNIX;NOT APPLE" OFF) + "USE_DISTRIBUTED;USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) @@ -430,11 +431,10 @@ if(WIN32) PATH_SUFFIXES lib NO_DEFAULT_PATH) if(NOT libuv_tmp_LIBRARY) - set(USE_DISTRIBUTED OFF) set(USE_GLOO OFF) message( WARNING - "Libuv is not installed in current conda env. Set USE_DISTRIBUTED to OFF. " + "Libuv is not installed in current conda env. Set USE_GLOO to OFF. " "Please run command 'conda install -c conda-forge libuv=1.39' to install libuv." 
) else() diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 86a57264d253..378cb73a225e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -540,11 +540,9 @@ if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER) ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp ) - if(USE_DISTRIBUTED) - append_filelist("libtorch_distributed_base_sources" TORCH_SRCS) - if(NOT WIN32) - append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS) - endif() + append_filelist("libtorch_distributed_base_sources" TORCH_SRCS) + if(NOT WIN32) + append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS) endif() endif() @@ -568,32 +566,30 @@ if(USE_CUDA) list(APPEND Caffe2_GPU_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) endif() - if(USE_DISTRIBUTED) - append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) - if(NOT WIN32) - append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) - set_source_files_properties( - ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp - PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" - ) - endif() + append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) + if(NOT WIN32) + append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) + set_source_files_properties( + ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp + 
${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp + PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" + ) + endif() - set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu") - # Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9 - if(CMAKE_COMPILER_IS_GNUCXX) - set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") - endif() - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") - set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") - endif() + set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu") + # Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9 + if(CMAKE_COMPILER_IS_GNUCXX) + set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") + endif() + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") + set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") endif() set_source_files_properties( ${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -626,11 +622,9 @@ if(USE_ROCM) list(APPEND Caffe2_HIP_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) endif() - if(USE_DISTRIBUTED) - append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS) - if(NOT WIN32) - 
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS) - endif() + append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS) + if(NOT WIN32) + append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS) endif() # caffe2_nvrtc's stubs to driver APIs are useful for HIP. # See NOTE [ ATen NVRTC Stub and HIP ] @@ -1351,12 +1345,10 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) - if(USE_DISTRIBUTED) - add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) - if(NOT WIN32) - add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd) - add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) - endif() + add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) + if(NOT WIN32) + add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd) + add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) endif() if(NOT NO_API) add_subdirectory(${TORCH_ROOT}/test/cpp/api ${CMAKE_BINARY_DIR}/test_api) @@ -1461,46 +1453,40 @@ if(BUILD_LITE_INTERPRETER) endif() endif() - -# Pass USE_DISTRIBUTED to torch_cpu, as some codes in jit/pickler.cpp and -# jit/unpickler.cpp need to be compiled only when USE_DISTRIBUTED is set -if(USE_DISTRIBUTED) - target_compile_definitions(torch_cpu PUBLIC USE_DISTRIBUTED) - if(USE_GLOO AND USE_C10D_GLOO) - target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) +if(USE_GLOO AND USE_C10D_GLOO) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) +endif() +if(USE_UCC AND USE_C10D_UCC) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) + if(USE_CUDA) + target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) endif() - if(USE_UCC AND 
USE_C10D_UCC) - target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) - if(USE_CUDA) - target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) - endif() +endif() +if(USE_NCCL AND USE_C10D_NCCL) + if(USE_ROCM) + target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) + else() + target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) endif() - if(USE_NCCL AND USE_C10D_NCCL) - if(USE_ROCM) - target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) - else() - target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) - endif() - endif() - if(USE_MPI AND USE_C10D_MPI) - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set_source_files_properties( - "${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp" - PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) - endif() - target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI) - endif() - # Pass USE_RPC in order to reduce use of - # #if defined(USE_DISTRIBUTED) && !defined(_WIN32) - # need to be removed when RPC is supported - if(NOT WIN32) - target_compile_definitions(torch_cpu PUBLIC USE_RPC) - endif() - # Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp - # can only be compiled with USE_TENSORPIPE is set. 
- if(USE_TENSORPIPE) - target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE) +endif() +if(USE_MPI AND USE_C10D_MPI) + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set_source_files_properties( + "${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp" + PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) endif() + target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI) +endif() +# Pass USE_RPC in order to reduce use of +# #if defined(USE_DISTRIBUTED) && !defined(_WIN32) +# need to be removed when RPC is supported +if(NOT WIN32) + target_compile_definitions(torch_cpu PUBLIC USE_RPC) +endif() +# Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp +# can only be compiled when USE_TENSORPIPE is set. +if(USE_TENSORPIPE) + target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE) endif() if(NOT INTERN_BUILD_MOBILE) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 944c7821f667..3354c18dd3af 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1126,7 +1126,7 @@ if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0) include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() -if(USE_DISTRIBUTED AND USE_TENSORPIPE) +if(USE_TENSORPIPE) if(MSVC) message(WARNING "Tensorpipe cannot be used on Windows.") else() diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 745d9ea05868..3d388fea772c 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -191,13 +191,11 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}") message(STATUS " USE_XNNPACK : ${USE_XNNPACK}") message(STATUS " USE_DISTRIBUTED : ${USE_DISTRIBUTED}") - if(${USE_DISTRIBUTED}) - message(STATUS " USE_MPI : ${USE_MPI}") - message(STATUS " USE_GLOO : ${USE_GLOO}") - message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}") - message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") - message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}") - 
endif() + message(STATUS " USE_MPI : ${USE_MPI}") + message(STATUS " USE_GLOO : ${USE_GLOO}") + message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}") + message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") + message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}") if(NOT "${SELECTED_OP_LIST}" STREQUAL "") message(STATUS " SELECTED_OP_LIST : ${SELECTED_OP_LIST}") endif() diff --git a/docs/source/conf.py b/docs/source/conf.py index 4f47652e88d2..fd923a7c4da3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3331,13 +3331,6 @@ def coverage_post_process(app, exception): if not isinstance(app.builder, CoverageBuilder): return - if not torch.distributed.is_available(): - raise RuntimeError( - "The coverage tool cannot run with a version " - "of PyTorch that was built with USE_DISTRIBUTED=0 " - "as this module's API changes." - ) - # These are all the modules that have "automodule" in an rst file # These modules are the ones for which coverage is checked # Here, we make sure that no module is missing from that list diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 14fd7f7ae9a2..86a6c924288b 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -1,4 +1,4 @@ -if(USE_DISTRIBUTED AND NOT WIN32) +if(NOT WIN32) set(DIST_AUTOGRAD_TEST_DIR "${TORCH_ROOT}/test/cpp/dist_autograd") set(DIST_AUTOGRAD_TEST_SOURCES ${TORCH_ROOT}/test/cpp/common/main.cpp diff --git a/test/export/test_export.py b/test/export/test_export.py index 6fb39cfdbb65..5c87afb5551b 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -63,10 +63,7 @@ from torch.export.passes import move_to_device_pass from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck -from torch.testing._internal.common_cuda import ( - PLATFORM_SUPPORTS_FLASH_ATTENTION, - xfailIfDistributedNotSupported, -) 
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( find_library_location, IS_FBCODE, @@ -15360,7 +15357,6 @@ class GraphModule(torch.nn.Module): finally: torch.distributed.destroy_process_group() - @xfailIfDistributedNotSupported def test_distributed_all_reduce(self): class Foo(torch.nn.Module): def __init__(self): @@ -15378,7 +15374,6 @@ class GraphModule(torch.nn.Module): inp = (torch.randn(4, 4),) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) - @xfailIfDistributedNotSupported def test_distributed_all_gather(self): class Foo(torch.nn.Module): def forward(self, x): @@ -15394,7 +15389,6 @@ class GraphModule(torch.nn.Module): torch.allclose(a, b) for a, b in zip(ep.module()(*inp), m(*inp)) ) - @xfailIfDistributedNotSupported def test_distributed_all_gather_into_tensor(self): class Foo(torch.nn.Module): def forward(self, x): @@ -15408,7 +15402,6 @@ class GraphModule(torch.nn.Module): inp = (torch.randn(2),) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) - @xfailIfDistributedNotSupported @testing.expectedFailureCppRuntime def test_distributed_all_to_all_single(self): class Foo(torch.nn.Module): @@ -15426,7 +15419,6 @@ class GraphModule(torch.nn.Module): ) self.assertEqual(len(nodes), 1) - @xfailIfDistributedNotSupported @testing.expectedFailureCppRuntime def test_distributed_reduce_scatter_tensor(self): class Foo(torch.nn.Module): diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 9d43de80f129..457b224354fb 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -88,8 +88,7 @@ def build_pytorch( ) -> None: my_env = _create_build_env() if ( - not check_negative_env_flag("USE_DISTRIBUTED") - and not check_negative_env_flag("USE_CUDA") + not check_negative_env_flag("USE_CUDA") and not check_negative_env_flag("USE_NCCL") and not check_env_flag("USE_SYSTEM_NCCL") ): diff --git a/torch/CMakeLists.txt 
b/torch/CMakeLists.txt index 1632147f0220..fc51329bbac6 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -273,32 +273,30 @@ add_custom_command( WORKING_DIRECTORY "${TORCH_ROOT}" ) -if(USE_DISTRIBUTED) - if(WIN32) - append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) - else() - append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) - endif() - # Disable certain warnings for GCC-9.X - if(CMAKE_COMPILER_IS_GNUCXX) - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - endif() - # NCCL is a private dependency of libtorch, but libtorch_python includes - # some private headers of libtorch, which in turn include NCCL. As a hacky - # alternative to making NCCL a public dependency of libtorch, we make it - # a private dependency of libtorch_python as well. - if(USE_NCCL) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) - endif() - # Same for MPI. 
- if(USE_MPI) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) - endif() - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) +if(WIN32) + append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) +else() + append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) endif() +# Disable certain warnings for GCC-9.X +if(CMAKE_COMPILER_IS_GNUCXX) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") +endif() +# NCCL is a private dependency of libtorch, but libtorch_python includes +# some private headers of libtorch, which in turn include NCCL. As a hacky +# alternative to making NCCL a public dependency of libtorch, we make it +# a private dependency of libtorch_python as well. +if(USE_NCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) +endif() +# Same for MPI. 
+if(USE_MPI) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) +endif() +list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) if(USE_NCCL AND NOT WIN32) list(APPEND TORCH_PYTHON_SRCS @@ -366,10 +364,6 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() - if(USE_DISTRIBUTED) - target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) - endif() - if(USE_MPI AND USE_C10D_MPI) target_compile_definitions(torch_python PRIVATE USE_C10D_MPI) endif() diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 60a7bb644df0..d43d2b02a23e 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -15,9 +15,7 @@ #include #include -#if defined(USE_DISTRIBUTED) #include -#endif inline void PyErr_SetString(PyObject* type, const std::string& message) { PyErr_SetString(type, message.c_str()); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 675a4c431005..6f052b0331ed 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -120,14 +120,12 @@ #endif #endif -#ifdef USE_DISTRIBUTED #ifdef USE_C10D #include #include #include #include #endif -#endif #if defined(USE_VALGRIND) #include @@ -552,11 +550,7 @@ static PyObject* THPModule_getBackcompatKeepdimWarn( } static PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) { -#ifdef USE_DISTRIBUTED Py_RETURN_TRUE; -#else - Py_RETURN_FALSE; -#endif } static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) { @@ -1993,7 +1987,7 @@ PyObject* initModule() { #ifdef USE_XPU THPUtils_addPyMethodDefs(methods, THXPModule_methods()); #endif -#if defined(USE_DISTRIBUTED) && defined(USE_C10D) +#ifdef USE_C10D THPUtils_addPyMethodDefs( methods, torch::distributed::c10d::python_functions()); #ifndef _WIN32 diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index 5e19010f9ae3..05c8901e1f60 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ 
b/torch/csrc/autograd/functions/init.cpp @@ -8,9 +8,7 @@ #include #include #include -#ifdef USE_DISTRIBUTED #include -#endif #include #include #include @@ -150,11 +148,9 @@ void THPAutograd_initFunctions() { static PyTypeObject CopyBackwardsClass; addClass(module, CopyBackwardsClass, "CopyBackwards"); -#ifdef USE_DISTRIBUTED static PyTypeObject SendRpcBackwardClass; addClass( module, SendRpcBackwardClass, "SendRpcBackward"); -#endif static PyTypeObject CopySlicesClass; addClass(module, CopySlicesClass, "CopySlices"); diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index b1c864bf3fbb..a610685fe955 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -1,7 +1,5 @@ -#ifdef USE_DISTRIBUTED #include -#endif #include #include @@ -533,7 +531,6 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( }); } -#ifdef USE_DISTRIBUTED AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_( AtenTensorHandle inp, const char* reduce_op, @@ -566,4 +563,3 @@ AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( *ret0 = new_tensor_handle(std::move(tmp_result)); }); } -#endif diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index f80ae1b9481c..605e98a2a106 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include #include #include @@ -24,10 +26,6 @@ #include #include #include -#ifdef USE_DISTRIBUTED -#include -#include -#endif #include #include diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 8b16e089aa50..808fe7d3605b 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1225,7 +1225,7 @@ std::shared_ptr toSugaredValue( } else if (obj.ptr() == 
py::module::import("torch").attr("_check").ptr()) { return std::make_shared(); #ifdef USE_RPC - // RPC module is only available when build flag "USE_DISTRIBUTED" is on. + // USE_RPC is not defined on Windows } else if ( isRpcAvailable && obj.ptr() == @@ -1238,7 +1238,6 @@ std::shared_ptr toSugaredValue( return SpecialFormValue::create(prim::rpc_sync); } else if ( isRpcAvailable && - // RPC module is only available when build flag "USE_DISTRIBUTED" is on. obj.ptr() == py::module::import("torch.distributed.rpc").attr("remote").ptr()) { return SpecialFormValue::create(prim::rpc_remote); diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 6ae9f52a0cda..be582cfb7cdd 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -128,13 +128,8 @@ struct InterpreterContinuation { std::optional tls_state = std::nullopt) : state(std::move(state_)), stack(std::move(stack_)), - tls_state_(std::move(tls_state)) -#ifdef USE_DISTRIBUTED - , - dist_autograd_context_id_(dist_autograd_context_id) -#endif - { - } + tls_state_(std::move(tls_state)), + dist_autograd_context_id_(dist_autograd_context_id) {} void operator()(); @@ -142,9 +137,10 @@ struct InterpreterContinuation { InterpreterState state; Stack stack; std::optional tls_state_ = std::nullopt; -#ifdef USE_DISTRIBUTED - int64_t dist_autograd_context_id_; +#ifndef USE_RPC + [[maybe_unused]] #endif + int64_t dist_autograd_context_id_; }; // what is the tensors type, including state from the current execution context diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 526c840bc10e..e3379f4de65a 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -79,9 +79,7 @@ class TORCH_API Pickler { void pushTuple(const IValue& ivalue); void pushString(const std::string& string); void pushDevice(const IValue& 
ivalue); -#endif // unmemoized version void pushStringImpl(const std::string& string); void pushStorageOfTensor(const at::Tensor& tensor); diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 702a1d8816e7..208cf554ad2b 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -140,9 +140,7 @@ class TORCH_API Unpickler { void rebuildParameter(); void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); -#ifdef USE_DISTRIBUTED void rebuildRRef(); -#endif PickleOpCode readInstruction(); PickleOpCode readOpCode() { return static_cast(read()); diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 1c88e80d4021..e46c141cd3f4 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -30,15 +30,12 @@ #include #include -#ifdef USE_DISTRIBUTED #include -#endif // USE_DISTRIBUTED using namespace at; // Collective property attributes // https://github.com/pytorch/pytorch/issues/124674 -#ifdef USE_DISTRIBUTED constexpr auto kETCommsName = "collective_name"; constexpr auto kETInMsgNelems = "in_msg_nelems"; constexpr auto kETOutMsgNelems = "out_msg_nelems"; @@ -49,7 +46,6 @@ constexpr auto kETGlobalRankStride = "global_rank_stride"; constexpr auto kETGroupSize = "pg_size"; constexpr auto kETProcessGroupName = "pg_name"; constexpr auto kETProcessGroupDesc = "pg_desc"; -#endif // USE_DISTRIBUTED namespace torch::profiler::impl { @@ -269,7 +265,6 @@ static std::ofstream openOutputFile(const std::string& name) { return stream; } -#ifdef USE_DISTRIBUTED static std::string getAttrJson( const std::string& name, const std::string& type, @@ -282,7 +277,6 @@ static std::string getAttrJson( type, value); } -#endif static void writeJsonNode( std::ofstream& out, @@ -660,7 +654,6 @@ static void handleKernelBackendInfo( inline 
std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT std::vector attrs; -#ifdef USE_DISTRIBUTED // We rely on paramcommsdebug object that is available in thread local info auto debugInfo = dynamic_cast( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); @@ -704,8 +697,6 @@ inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT addAttr(kGroupSize, kETGroupSize, "uint64"); -#endif // USE_DISTRIBUTED - // XXX consider using as string stream? return attrs.empty() ? "" : fmt::format(", {}", fmt::join(attrs, ", ")); } diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 0b2979e6fb7e..4ed0ac45b04d 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -11,9 +11,7 @@ #ifdef USE_KINETO #include #endif -#ifdef USE_DISTRIBUTED #include -#endif // USE_DISTRIBUTED namespace torch::profiler::impl { @@ -455,7 +453,6 @@ std::unordered_map saveNcclMeta( // @lint-ignore CLANGTIDY const SaveNcclMetaConfig& config) { std::unordered_map map; -#ifdef USE_DISTRIBUTED auto debugInfo = dynamic_cast( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); @@ -565,7 +562,6 @@ std::unordered_map saveNcclMeta( } } } -#endif // USE_DISTRIBUTED return map; } diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index f2ae57fa0e59..dcb4b866a2de 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -185,7 +185,6 @@ struct HashCombine { } }; -#ifdef USE_DISTRIBUTED constexpr auto kCommsName = "Collective name"; constexpr auto kDtype = "dtype"; constexpr auto kInMsgNelems = "In msg nelems"; @@ -203,6 +202,5 @@ constexpr auto kP2pSrc = "Src Rank"; constexpr auto kP2pDst = "Dst Rank"; constexpr auto kInTensorsStart = "Input Tensors start"; constexpr auto kOutTensorsStart = "Output Tensors start"; -#endif // USE_DISTRIBUTED } // namespace torch::profiler::impl diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 
38e2fdbee803..bfb4175d61e0 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -14,16 +14,10 @@ log = logging.getLogger(__name__) def is_available() -> bool: """ - Return ``True`` if the distributed package is available. - - Otherwise, - ``torch.distributed`` does not expose any other APIs. Currently, - ``torch.distributed`` is available on Linux, MacOS and Windows. Set - ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. - Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, - ``USE_DISTRIBUTED=0`` for MacOS. + Always returns ``True``. Note that even if distributed is available, + there may not necessarily be any usable backends. """ - return hasattr(torch._C, "_c10d_init") + return True if is_available() and not torch._C._c10d_init(): diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index fa8cc184eddc..3e3243002a9c 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -5,10 +5,6 @@ from typing import Union import torch import torch.distributed as dist - -# The two imports below are not always available depending on the -# USE_DISTRIBUTED compile flag. Make sure they raise import error -# if we're trying to use them. from torch.distributed import group, ProcessGroup diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index eeff877260bc..2bdf3fe2bdff 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -2,10 +2,6 @@ import torch import torch.distributed as dist from torch.autograd import Function - -# The two imports below are not always available depending on the -# USE_DISTRIBUTED compile flag. Make sure they raise import error -# if we're trying to use them. from torch.distributed import group, ReduceOp