diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 584f0a47e4a9..168a9104dcdb 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1027,6 +1027,7 @@ elseif(USE_CUDA)
       nvshmem_device
     )
     target_compile_definitions(torch_cuda PUBLIC USE_NVSHMEM)
+    target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM)
     target_link_libraries(torch_cuda PRIVATE nvshmem_extension)
     install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib)
   endif()
diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py
index cdd68591e7fe..7679eba8a4e3 100644
--- a/test/distributed/test_nvshmem.py
+++ b/test/distributed/test_nvshmem.py
@@ -2,11 +2,7 @@
 
 # To run:
 # TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
-# OR
-# TORCH_SYMMMEM=NVSHMEM torchrun --nproc-per-node 4 test/distributed/test_nvshmem.py
-import os
-import sys
 
 import torch
 import torch.distributed as dist
 
@@ -24,21 +20,11 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import requires_triton
 
 
-symm_mem_backend = os.getenv("TORCH_SYMMMEM")
-
-if symm_mem_backend != "NVSHMEM":
-    print(
-        "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`, skipping tests",
-        file=sys.stderr,
-    )
-    sys.exit(0)
-
-
 # Decorator
 def requires_nvshmem():
     return skip_but_pass_in_sandcastle_if(
-        symm_mem_backend != "NVSHMEM",
-        "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`",
+        not symm_mem.is_nvshmem_available(),
+        "test_nvshmem requires NVSHMEM, skipping tests",
     )
 
 
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 99da27087bbe..f5e5c666b58f 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -705,6 +705,9 @@ def _unregister_process_group(group_name: str) -> None: ...
 # Python. At C++ interface, it is converted to a uintptr_t.
 def _nvshmemx_cumodule_init(module: int) -> None: ...
 
+# Check if NVSHMEM is available on current system.
+def _is_nvshmem_available() -> bool: ...
+
 class _SymmetricMemory:
     @staticmethod
     def set_group_info(
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index b478a93df458..fddc374cd637 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1005,13 +1005,17 @@ This class does not support ``__members__`` property.)");
         return ::c10d::unregister_all_process_groups();
       });
 
+#ifdef USE_NVSHMEM
   // Intializes the device state in CUmodule so that it’s able to perform
   // NVSHMEM operations.
-#ifdef USE_NVSHMEM
   module.def(
       "_nvshmemx_cumodule_init",
       ::c10d::nvshmem_extension::nvshmemx_cumodule_init,
       py::arg("module"));
+
+  // Check if NVSHMEM is available on current system.
+  module.def(
+      "_is_nvshmem_available", ::c10d::nvshmem_extension::is_nvshmem_available);
 #endif
 
   py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
index dee189d58aa4..f13941ba5a27 100644
--- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
+++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
@@ -32,15 +32,23 @@ bool allow_overlapping_devices() {
 
 // Query environment variable to get the backend used for CUDA Symmetric Memory.
 std::string getSymmMemBackendCUDA() {
+  // TORCH_SYMMMEM environment variable can be used to indicate the preferred
+  // backend.
   static auto val = c10::utils::get_env("TORCH_SYMMMEM");
-  if (!val.has_value()) {
-    // In-house implementation: `CUDASymmetricMemory`
-    return "CUDA";
-  } else {
-    // Other backends:
-    // - "NVSHMEM": `NVSHMEMSymmetricMemory`
+  if (val.has_value()) {
+    TORCH_CHECK(
+        val.value() == "CUDA" || val.value() == "NVSHMEM" ||
+            val.value() == "NCCL",
+        "TORCH_SYMMMEM environment variable must be one of 'CUDA', 'NVSHMEM', 'NCCL'.")
     return val.value();
   }
+  // If TORCH_SYMMMEM is not set, check if NVSHMEM is available (for broader
+  // support).
+  // TODO: uncomment this once all single-node tests work with NVSHMEM
+  // if (is_nvshmem_available()) {
+  //   return "NVSHMEM";
+  // }
+  return "CUDA";
 }
 
 IpcChannel::IpcChannel()
diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
index 4e4aa81ac926..e6f4f7972d27 100644
--- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
+++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
@@ -1,3 +1,4 @@
+#include <dlfcn.h>
 #include
 
 #include
@@ -20,6 +21,28 @@ static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
 
 constexpr int MiB = 1024 * 1024;
 
+// Check if NVSHMEM is available
+bool is_nvshmem_available() {
+  // Runtime check
+  static std::mutex mutex;
+  static int is_available = -2;
+  std::lock_guard<std::mutex> lock(mutex);
+  if (is_available == -2) {
+    void* handle{};
+    // Open the shared library, RTLD_LAZY defers symbol resolution until needed
+    handle = dlopen("libnvshmem_host.so.3", RTLD_LAZY);
+    if (!handle) {
+      std::cerr << dlerror() << "\n";
+      is_available = 0;
+    } else {
+      is_available = 1;
+      // Close the shared library
+      dlclose(handle);
+    }
+  }
+  return is_available == 1;
+}
+
 // Bootstrap based on user's setting for NCCL
 // Long term, this may be a bit unclean; short term, it improves UX
 void maybe_initialize_env_vars() {
@@ -71,6 +94,11 @@ void initialize_nvshmem_with_store(
       "nvshmemx_init_attr failed");
 
   is_initialized = true;
+
+  // Print version
+  int major, minor;
+  ::nvshmem_info_get_version(&major, &minor);
+  LOG(INFO) << "NVSHMEM is available, version: " << major << "." << minor;
 }
diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
index 9b537caaa623..fd51ded49dbe 100644
--- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
+++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
@@ -11,6 +11,9 @@ void initialize_nvshmem_with_store(
     int rank,
     int world_size);
 
+// Check if NVSHMEM is available
+TORCH_API bool is_nvshmem_available();
+
 // Intializes the device state in CUmodule so that it’s able to perform NVSHMEM
 // operations.
 TORCH_API void nvshmemx_cumodule_init(uintptr_t module);
diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py
index dcc4a4119490..ccaf955e216d 100644
--- a/torch/distributed/_symmetric_memory/__init__.py
+++ b/torch/distributed/_symmetric_memory/__init__.py
@@ -1704,4 +1704,20 @@ def rendezvous(
     return _SymmetricMemory.rendezvous(tensor, group_name)
 
 
-__all__ = ["empty", "rendezvous"]
+def is_nvshmem_available() -> bool:
+    r"""
+    is_nvshmem_available() -> bool
+
+    Check if NVSHMEM is available in current build and on current system.
+    """
+    try:
+        from torch._C._distributed_c10d import _is_nvshmem_available
+    except ImportError:
+        # Not all builds have NVSHMEM support.
+        return False
+
+    # Check if NVSHMEM is available on current system.
+    return _is_nvshmem_available()
+
+
+__all__ = ["empty", "rendezvous", "is_nvshmem_available"]
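
For reference, a minimal sketch (not part of the patch) of how the new runtime probe might be consumed from Python. Only torch.distributed._symmetric_memory.is_nvshmem_available() and the accepted TORCH_SYMMMEM values ('CUDA', 'NVSHMEM', 'NCCL') come from the diff above; the pick_backend helper is hypothetical and simply mirrors the env-var-first, probe-second logic of the updated getSymmMemBackendCUDA(), whose NVSHMEM fallback is still commented out pending single-node test coverage.

import os

import torch.distributed._symmetric_memory as symm_mem


def pick_backend() -> str:
    # An explicit TORCH_SYMMMEM setting wins, matching the C++ change above.
    env = os.getenv("TORCH_SYMMMEM")
    if env is not None:
        assert env in ("CUDA", "NVSHMEM", "NCCL"), f"unsupported TORCH_SYMMMEM: {env}"
        return env
    # Otherwise fall back to the dlopen()-based availability check added by this patch.
    return "NVSHMEM" if symm_mem.is_nvshmem_available() else "CUDA"


if __name__ == "__main__":
    print("Symmetric memory backend:", pick_backend())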