diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index d41c3c08e628..d7447e7d4858 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh @@ -35,10 +35,11 @@ fi print_cmake_info if [[ ${BUILD_ENVIRONMENT} == *"distributed"* ]]; then - USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel + # Needed for inductor benchmarks, as lots of HF networks make `torch.distribtued` calls + USE_DISTRIBUTED=1 USE_OPENMP=1 WERROR=1 python setup.py bdist_wheel else - # NB: we always build with distributed; USE_DISTRIBUTED turns off all - # backends (specifically the gloo backend), so test that this case works too + # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests + # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel --plat-name macosx_11_0_arm64 fi if which sccache > /dev/null; then diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 64ea8a1c2554..a859901191e0 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -16,8 +16,6 @@ popd # enable debug asserts in serialization export TORCH_SERIALIZATION_DEBUG=1 -python -mpip install --no-input -r requirements.txt - setup_test_python() { # The CircleCI worker hostname doesn't resolve to an address. # This environment variable makes ProcessGroupGloo default to diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 763fce4b73e1..e63a68e4f193 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -189,8 +189,7 @@ pip install requests ninja typing-extensions retry pip install -r "${pytorch_rootdir}/requirements.txt" || true retry brew install libomp -# For USE_DISTRIBUTED=1 on macOS, this enables gloo, which needs libuv, which -# is build as part of tensorpipe submodule +# For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule export USE_DISTRIBUTED=1 export USE_MKLDNN=OFF diff --git a/BUILD.bazel b/BUILD.bazel index 2cbd36f06761..d4202e7a2c1e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,6 +22,7 @@ COMMON_COPTS = [ "-DHAVE_SHM_UNLINK=1", "-D_FILE_OFFSET_BITS=64", "-DUSE_FBGEMM", + "-DUSE_DISTRIBUTED", "-DAT_PER_OPERATOR_HEADERS", "-DATEN_THREADING=NATIVE", "-DNO_CUDNN_DESTROY_HANDLE", diff --git a/CMakeLists.txt b/CMakeLists.txt index 4120e621bdd0..ce7890f002d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,9 +181,8 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64le)") set(CPU_POWER ON) endif() -# For non-supported platforms, turn USE_DISTRIBUTED off by default. -# NB: USE_DISTRIBUTED simply disables the backend; distributed code -# still gets built +# For non-supported platforms, turn USE_DISTRIBUTED off by default. It is not +# tested and likely won't work without additional changes. if(NOT LINUX AND NOT WIN32) set(USE_DISTRIBUTED OFF @@ -262,11 +261,11 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) -option(USE_DISTRIBUTED "Enable default distributed backends" ON) +option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_XCCL "Use XCCL" ON - "USE_DISTRIBUTED;USE_XPU;UNIX;NOT APPLE" OFF) + "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) @@ -431,10 +430,11 @@ if(WIN32) PATH_SUFFIXES lib NO_DEFAULT_PATH) if(NOT libuv_tmp_LIBRARY) + set(USE_DISTRIBUTED OFF) set(USE_GLOO OFF) message( WARNING - "Libuv is not installed in current conda env. Set USE_GLOO to OFF. " + "Libuv is not installed in current conda env. Set USE_DISTRIBUTED to OFF. " "Please run command 'conda install -c conda-forge libuv=1.39' to install libuv." ) else() diff --git a/buckbuild.bzl b/buckbuild.bzl index 218fd747301f..e079d9839544 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -948,7 +948,6 @@ def define_buck_targets( [ ("torch/csrc/api/include", "torch/**/*.h"), ("", "torch/csrc/**/*.h"), - ("", "torch/csrc/**/*.hpp"), ("", "torch/nativert/**/*.h"), ("", "torch/headeronly/**/*.h"), ("", "torch/script.h"), @@ -2034,7 +2033,6 @@ def define_buck_targets( ("", "caffe2/utils/*.h"), ("", "caffe2/core/*.h"), ("", "torch/csrc/*.h"), - ("", "torch/csrc/*.hpp"), ("", "torch/csrc/api/include/torch/*.h"), ("", "torch/csrc/autograd/*.h"), ("", "torch/csrc/autograd/*/*.h"), diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 504dbf5a4fad..b4a94fb9fe76 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -540,9 +540,11 @@ if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER) ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp ) - append_filelist("libtorch_distributed_base_sources" TORCH_SRCS) - if(NOT WIN32) - append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS) + if(USE_DISTRIBUTED) + append_filelist("libtorch_distributed_base_sources" TORCH_SRCS) + if(NOT WIN32) + append_filelist("libtorch_distributed_extra_sources" TORCH_SRCS) + endif() endif() endif() @@ -566,30 +568,32 @@ if(USE_CUDA) list(APPEND Caffe2_GPU_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) endif() - append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) - if(NOT WIN32) - append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) - set_source_files_properties( - ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp - PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" - ) - endif() + if(USE_DISTRIBUTED) + append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) + if(NOT WIN32) + append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) + set_source_files_properties( + ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp + PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" + ) + endif() - set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu") - # Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9 - if(CMAKE_COMPILER_IS_GNUCXX) - set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") - endif() - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") - set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") + set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu") + # Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9 + if(CMAKE_COMPILER_IS_GNUCXX) + set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") + endif() + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") + set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") + endif() endif() set_source_files_properties( ${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -622,9 +626,11 @@ if(USE_ROCM) list(APPEND Caffe2_HIP_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) endif() - append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS) - if(NOT WIN32) - append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS) + if(USE_DISTRIBUTED) + append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS) + if(NOT WIN32) + append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS) + endif() endif() # caffe2_nvrtc's stubs to driver APIs are useful for HIP. # See NOTE [ ATen NVRTC Stub and HIP ] @@ -1345,10 +1351,12 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) - add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) - if(NOT WIN32) - add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd) - add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) + if(USE_DISTRIBUTED) + add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) + if(NOT WIN32) + add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd) + add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) + endif() endif() if(NOT NO_API) add_subdirectory(${TORCH_ROOT}/test/cpp/api ${CMAKE_BINARY_DIR}/test_api) @@ -1453,40 +1461,46 @@ if(BUILD_LITE_INTERPRETER) endif() endif() -if(USE_GLOO AND USE_C10D_GLOO) - target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) -endif() -if(USE_UCC AND USE_C10D_UCC) - target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) - if(USE_CUDA) - target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) + +# Pass USE_DISTRIBUTED to torch_cpu, as some codes in jit/pickler.cpp and +# jit/unpickler.cpp need to be compiled only when USE_DISTRIBUTED is set +if(USE_DISTRIBUTED) + target_compile_definitions(torch_cpu PUBLIC USE_DISTRIBUTED) + if(USE_GLOO AND USE_C10D_GLOO) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) endif() -endif() -if(USE_NCCL AND USE_C10D_NCCL) - if(USE_ROCM) - target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) - else() - target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) + if(USE_UCC AND USE_C10D_UCC) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) + if(USE_CUDA) + target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) + endif() endif() -endif() -if(USE_MPI AND USE_C10D_MPI) - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set_source_files_properties( - "${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp" - PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) + if(USE_NCCL AND USE_C10D_NCCL) + if(USE_ROCM) + target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) + else() + target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) + endif() + endif() + if(USE_MPI AND USE_C10D_MPI) + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set_source_files_properties( + "${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp" + PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) + endif() + target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI) + endif() + # Pass USE_RPC in order to reduce use of + # #if defined(USE_DISTRIBUTED) && !defined(_WIN32) + # need to be removed when RPC is supported + if(NOT WIN32) + target_compile_definitions(torch_cpu PUBLIC USE_RPC) + endif() + # Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp + # can only be compiled with USE_TENSORPIPE is set. + if(USE_TENSORPIPE) + target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE) endif() - target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI) -endif() -# Pass USE_RPC in order to reduce use of -# #if defined(USE_DISTRIBUTED) && !defined(_WIN32) -# need to be removed when RPC is supported -if(NOT WIN32) - target_compile_definitions(torch_cpu PUBLIC USE_RPC) -endif() -# Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp -# can only be compiled with USE_TENSORPIPE is set. -if(USE_TENSORPIPE) - target_compile_definitions(torch_cpu PUBLIC USE_TENSORPIPE) endif() if(NOT INTERN_BUILD_MOBILE) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index e4e82b16f410..ef5c2fd4e97d 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1126,7 +1126,7 @@ if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0) include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() -if(USE_TENSORPIPE) +if(USE_DISTRIBUTED AND USE_TENSORPIPE) if(MSVC) message(WARNING "Tensorpipe cannot be used on Windows.") else() diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 3d388fea772c..745d9ea05868 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -191,11 +191,13 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}") message(STATUS " USE_XNNPACK : ${USE_XNNPACK}") message(STATUS " USE_DISTRIBUTED : ${USE_DISTRIBUTED}") - message(STATUS " USE_MPI : ${USE_MPI}") - message(STATUS " USE_GLOO : ${USE_GLOO}") - message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}") - message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") - message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}") + if(${USE_DISTRIBUTED}) + message(STATUS " USE_MPI : ${USE_MPI}") + message(STATUS " USE_GLOO : ${USE_GLOO}") + message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}") + message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") + message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}") + endif() if(NOT "${SELECTED_OP_LIST}" STREQUAL "") message(STATUS " SELECTED_OP_LIST : ${SELECTED_OP_LIST}") endif() diff --git a/docs/source/conf.py b/docs/source/conf.py index d1504757f9c5..44ad4de8115f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3333,6 +3333,13 @@ def coverage_post_process(app, exception): if not isinstance(app.builder, CoverageBuilder): return + if not torch.distributed.is_available(): + raise RuntimeError( + "The coverage tool cannot run with a version " + "of PyTorch that was built with USE_DISTRIBUTED=0 " + "as this module's API changes." + ) + # These are all the modules that have "automodule" in an rst file # These modules are the ones for which coverage is checked # Here, we make sure that no module is missing from that list diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 86a6c924288b..14fd7f7ae9a2 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -1,4 +1,4 @@ -if(NOT WIN32) +if(USE_DISTRIBUTED AND NOT WIN32) set(DIST_AUTOGRAD_TEST_DIR "${TORCH_ROOT}/test/cpp/dist_autograd") set(DIST_AUTOGRAD_TEST_SOURCES ${TORCH_ROOT}/test/cpp/common/main.cpp diff --git a/test/export/test_export.py b/test/export/test_export.py index 62b4e4d09242..f22c016dba3a 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -65,7 +65,10 @@ from torch.export.passes import move_to_device_pass from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, + xfailIfDistributedNotSupported, +) from torch.testing._internal.common_utils import ( find_library_location, IS_FBCODE, @@ -15552,6 +15555,7 @@ class GraphModule(torch.nn.Module): finally: torch.distributed.destroy_process_group() + @xfailIfDistributedNotSupported def test_distributed_all_reduce(self): class Foo(torch.nn.Module): def __init__(self): @@ -15569,6 +15573,7 @@ class GraphModule(torch.nn.Module): inp = (torch.randn(4, 4),) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) + @xfailIfDistributedNotSupported def test_distributed_all_gather(self): class Foo(torch.nn.Module): def forward(self, x): @@ -15584,6 +15589,7 @@ class GraphModule(torch.nn.Module): torch.allclose(a, b) for a, b in zip(ep.module()(*inp), m(*inp)) ) + @xfailIfDistributedNotSupported def test_distributed_all_gather_into_tensor(self): class Foo(torch.nn.Module): def forward(self, x): @@ -15597,6 +15603,7 @@ class GraphModule(torch.nn.Module): inp = (torch.randn(2),) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) + @xfailIfDistributedNotSupported @testing.expectedFailureCppRuntime def test_distributed_all_to_all_single(self): class Foo(torch.nn.Module): @@ -15614,6 +15621,7 @@ class GraphModule(torch.nn.Module): ) self.assertEqual(len(nodes), 1) + @xfailIfDistributedNotSupported @testing.expectedFailureCppRuntime def test_distributed_reduce_scatter_tensor(self): class Foo(torch.nn.Module): diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 457b224354fb..9d43de80f129 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -88,7 +88,8 @@ def build_pytorch( ) -> None: my_env = _create_build_env() if ( - not check_negative_env_flag("USE_CUDA") + not check_negative_env_flag("USE_DISTRIBUTED") + and not check_negative_env_flag("USE_CUDA") and not check_negative_env_flag("USE_NCCL") and not check_env_flag("USE_SYSTEM_NCCL") ): diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index fc51329bbac6..1632147f0220 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -273,30 +273,32 @@ add_custom_command( WORKING_DIRECTORY "${TORCH_ROOT}" ) +if(USE_DISTRIBUTED) + if(WIN32) + append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) + else() + append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) + endif() + # Disable certain warnings for GCC-9.X + if(CMAKE_COMPILER_IS_GNUCXX) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + endif() + # NCCL is a private dependency of libtorch, but libtorch_python includes + # some private headers of libtorch, which in turn include NCCL. As a hacky + # alternative to making NCCL a public dependency of libtorch, we make it + # a private dependency of libtorch_python as well. + if(USE_NCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) + endif() + # Same for MPI. + if(USE_MPI) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) + endif() + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) -if(WIN32) - append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) -else() - append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) endif() -# Disable certain warnings for GCC-9.X -if(CMAKE_COMPILER_IS_GNUCXX) - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") -endif() -# NCCL is a private dependency of libtorch, but libtorch_python includes -# some private headers of libtorch, which in turn include NCCL. As a hacky -# alternative to making NCCL a public dependency of libtorch, we make it -# a private dependency of libtorch_python as well. -if(USE_NCCL) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) -endif() -# Same for MPI. -if(USE_MPI) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) -endif() -list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) if(USE_NCCL AND NOT WIN32) list(APPEND TORCH_PYTHON_SRCS @@ -364,6 +366,10 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() + if(USE_DISTRIBUTED) + target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) + endif() + if(USE_MPI AND USE_C10D_MPI) target_compile_definitions(torch_python PRIVATE USE_C10D_MPI) endif() diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index d43d2b02a23e..60a7bb644df0 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -15,7 +15,9 @@ #include #include +#if defined(USE_DISTRIBUTED) #include +#endif inline void PyErr_SetString(PyObject* type, const std::string& message) { PyErr_SetString(type, message.c_str()); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 6f052b0331ed..675a4c431005 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -120,12 +120,14 @@ #endif #endif +#ifdef USE_DISTRIBUTED #ifdef USE_C10D #include #include #include #include #endif +#endif #if defined(USE_VALGRIND) #include @@ -550,7 +552,11 @@ static PyObject* THPModule_getBackcompatKeepdimWarn( } static PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) { +#ifdef USE_DISTRIBUTED Py_RETURN_TRUE; +#else + Py_RETURN_FALSE; +#endif } static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) { @@ -1987,7 +1993,7 @@ PyObject* initModule() { #ifdef USE_XPU THPUtils_addPyMethodDefs(methods, THXPModule_methods()); #endif -#ifdef USE_C10D +#if defined(USE_DISTRIBUTED) && defined(USE_C10D) THPUtils_addPyMethodDefs( methods, torch::distributed::c10d::python_functions()); #ifndef _WIN32 diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index 05c8901e1f60..5e19010f9ae3 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -8,7 +8,9 @@ #include #include #include +#ifdef USE_DISTRIBUTED #include +#endif #include #include #include @@ -148,9 +150,11 @@ void THPAutograd_initFunctions() { static PyTypeObject CopyBackwardsClass; addClass(module, CopyBackwardsClass, "CopyBackwards"); +#ifdef USE_DISTRIBUTED static PyTypeObject SendRpcBackwardClass; addClass( module, SendRpcBackwardClass, "SendRpcBackward"); +#endif static PyTypeObject CopySlicesClass; addClass(module, CopySlicesClass, "CopySlices"); diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index a610685fe955..b1c864bf3fbb 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -1,5 +1,7 @@ +#ifdef USE_DISTRIBUTED #include +#endif #include #include @@ -531,6 +533,7 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( }); } +#ifdef USE_DISTRIBUTED AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_( AtenTensorHandle inp, const char* reduce_op, @@ -563,3 +566,4 @@ AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( *ret0 = new_tensor_handle(std::move(tmp_result)); }); } +#endif diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 605e98a2a106..f80ae1b9481c 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -13,8 +13,6 @@ #include #include #include -#include -#include #include #include #include @@ -26,6 +24,10 @@ #include #include #include +#ifdef USE_DISTRIBUTED +#include +#include +#endif #include #include diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 808fe7d3605b..8b16e089aa50 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1225,7 +1225,7 @@ std::shared_ptr toSugaredValue( } else if (obj.ptr() == py::module::import("torch").attr("_check").ptr()) { return std::make_shared(); #ifdef USE_RPC - // This is not defined on WINDOWS + // RPC module is only available when build flag "USE_DISTRIBUTED" is on. } else if ( isRpcAvailable && obj.ptr() == @@ -1238,6 +1238,7 @@ std::shared_ptr toSugaredValue( return SpecialFormValue::create(prim::rpc_sync); } else if ( isRpcAvailable && + // RPC module is only available when build flag "USE_DISTRIBUTED" is on. obj.ptr() == py::module::import("torch.distributed.rpc").attr("remote").ptr()) { return SpecialFormValue::create(prim::rpc_remote); diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index be582cfb7cdd..6ae9f52a0cda 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -128,8 +128,13 @@ struct InterpreterContinuation { std::optional tls_state = std::nullopt) : state(std::move(state_)), stack(std::move(stack_)), - tls_state_(std::move(tls_state)), - dist_autograd_context_id_(dist_autograd_context_id) {} + tls_state_(std::move(tls_state)) +#ifdef USE_DISTRIBUTED + , + dist_autograd_context_id_(dist_autograd_context_id) +#endif + { + } void operator()(); @@ -137,10 +142,9 @@ struct InterpreterContinuation { InterpreterState state; Stack stack; std::optional tls_state_ = std::nullopt; -#ifndef USE_RPC - [[maybe_unused]] -#endif +#ifdef USE_DISTRIBUTED int64_t dist_autograd_context_id_; +#endif }; // what is the tensors type, including state from the current execution context diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index e3379f4de65a..526c840bc10e 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -79,7 +79,9 @@ class TORCH_API Pickler { void pushTuple(const IValue& ivalue); void pushString(const std::string& string); void pushDevice(const IValue& ivalue); +#ifdef USE_DISTRIBUTED void pushRRef(const IValue& ivalue); +#endif // unmemoized version void pushStringImpl(const std::string& string); void pushStorageOfTensor(const at::Tensor& tensor); diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 208cf554ad2b..702a1d8816e7 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -140,7 +140,9 @@ class TORCH_API Unpickler { void rebuildParameter(); void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); +#ifdef USE_DISTRIBUTED void rebuildRRef(); +#endif PickleOpCode readInstruction(); PickleOpCode readOpCode() { return static_cast(read()); diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index e46c141cd3f4..1c88e80d4021 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -30,12 +30,15 @@ #include #include +#ifdef USE_DISTRIBUTED #include +#endif // USE_DISTRIBUTED using namespace at; // Collective property attributes // https://github.com/pytorch/pytorch/issues/124674 +#ifdef USE_DISTRIBUTED constexpr auto kETCommsName = "collective_name"; constexpr auto kETInMsgNelems = "in_msg_nelems"; constexpr auto kETOutMsgNelems = "out_msg_nelems"; @@ -46,6 +49,7 @@ constexpr auto kETGlobalRankStride = "global_rank_stride"; constexpr auto kETGroupSize = "pg_size"; constexpr auto kETProcessGroupName = "pg_name"; constexpr auto kETProcessGroupDesc = "pg_desc"; +#endif // USE_DISTRIBUTED namespace torch::profiler::impl { @@ -265,6 +269,7 @@ static std::ofstream openOutputFile(const std::string& name) { return stream; } +#ifdef USE_DISTRIBUTED static std::string getAttrJson( const std::string& name, const std::string& type, @@ -277,6 +282,7 @@ static std::string getAttrJson( type, value); } +#endif static void writeJsonNode( std::ofstream& out, @@ -654,6 +660,7 @@ static void handleKernelBackendInfo( inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT std::vector attrs; +#ifdef USE_DISTRIBUTED // We rely on paramcommsdebug object that is available in thread local info auto debugInfo = dynamic_cast( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); @@ -697,6 +704,8 @@ inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT addAttr(kGroupSize, kETGroupSize, "uint64"); +#endif // USE_DISTRIBUTED + // XXX consider using as string stream? return attrs.empty() ? "" : fmt::format(", {}", fmt::join(attrs, ", ")); } diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index e97699a99fd1..0b2979e6fb7e 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -11,7 +11,9 @@ #ifdef USE_KINETO #include #endif +#ifdef USE_DISTRIBUTED #include +#endif // USE_DISTRIBUTED namespace torch::profiler::impl { @@ -453,7 +455,7 @@ std::unordered_map saveNcclMeta( // @lint-ignore CLANGTIDY const SaveNcclMetaConfig& config) { std::unordered_map map; -#if !defined(BUILD_LITE_INTERPRETER) && !defined(C10_MOBILE) +#ifdef USE_DISTRIBUTED auto debugInfo = dynamic_cast( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); @@ -563,7 +565,7 @@ std::unordered_map saveNcclMeta( } } } -#endif // !defined(BUILD_LITE_INTERPRETER) && !defined(C10_MOBILE) +#endif // USE_DISTRIBUTED return map; } diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index dcb4b866a2de..f2ae57fa0e59 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -185,6 +185,7 @@ struct HashCombine { } }; +#ifdef USE_DISTRIBUTED constexpr auto kCommsName = "Collective name"; constexpr auto kDtype = "dtype"; constexpr auto kInMsgNelems = "In msg nelems"; @@ -202,5 +203,6 @@ constexpr auto kP2pSrc = "Src Rank"; constexpr auto kP2pDst = "Dst Rank"; constexpr auto kInTensorsStart = "Input Tensors start"; constexpr auto kOutTensorsStart = "Output Tensors start"; +#endif // USE_DISTRIBUTED } // namespace torch::profiler::impl diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index bfb4175d61e0..38e2fdbee803 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -14,10 +14,16 @@ log = logging.getLogger(__name__) def is_available() -> bool: """ - Always returns ``True``. Note that even if distributed is available, - there may not necessarily be any usable backends. + Return ``True`` if the distributed package is available. + + Otherwise, + ``torch.distributed`` does not expose any other APIs. Currently, + ``torch.distributed`` is available on Linux, MacOS and Windows. Set + ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. + Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, + ``USE_DISTRIBUTED=0`` for MacOS. """ - return True + return hasattr(torch._C, "_c10d_init") if is_available() and not torch._C._c10d_init(): diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index 3e3243002a9c..fa8cc184eddc 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -5,6 +5,10 @@ from typing import Union import torch import torch.distributed as dist + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. from torch.distributed import group, ProcessGroup diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index 2bdf3fe2bdff..eeff877260bc 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -2,6 +2,10 @@ import torch import torch.distributed as dist from torch.autograd import Function + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. from torch.distributed import group, ReduceOp