mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
harden fabric checks for symmetric memory (#160790)
Now we check only that fabric allocation succeeded, but sometimes we fail during export or import afterwards, with no recourse. Check the full cycle before attempting to allocate memory with the fabric. TODO: move it to c10/cuda so that it can be used from CUDACachingAllocator too Pull Request resolved: https://github.com/pytorch/pytorch/pull/160790 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
b439675ae2
commit
0254646654
@ -4,6 +4,9 @@
|
||||
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#endif
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
@ -12,6 +15,7 @@
|
||||
namespace at::cuda {
|
||||
|
||||
static std::vector<int8_t> p2pAccessEnabled_;
|
||||
static std::vector<int8_t> fabricAccessEnabled_;
|
||||
static int64_t num_devices_ = -1;
|
||||
|
||||
namespace detail {
|
||||
@ -29,6 +33,8 @@ void init_p2p_access_cache(int64_t num_devices) {
|
||||
for (const auto i : c10::irange(num_devices)) {
|
||||
p2pAccessEnabled_[i * num_devices + i] = 1;
|
||||
}
|
||||
fabricAccessEnabled_.clear();
|
||||
fabricAccessEnabled_.resize(num_devices, -1);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
@ -36,13 +42,14 @@ void init_p2p_access_cache(int64_t num_devices) {
|
||||
bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
|
||||
|
||||
TORCH_CHECK(dev >= 0 || dev < num_devices_,
|
||||
dev, " is not a device");
|
||||
TORCH_CHECK(dev_to_access >= 0 || dev_to_access < num_devices_,
|
||||
dev_to_access, " is not a device");
|
||||
TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device");
|
||||
TORCH_CHECK(
|
||||
dev_to_access >= 0 || dev_to_access < num_devices_,
|
||||
dev_to_access,
|
||||
" is not a device");
|
||||
TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized");
|
||||
|
||||
auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access];
|
||||
auto& cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access];
|
||||
|
||||
if (cache != -1) {
|
||||
return cache;
|
||||
@ -58,4 +65,118 @@ bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
return cache;
|
||||
}
|
||||
|
||||
} // namespace at::cuda::detail
|
||||
namespace {
|
||||
#if !defined USE_ROCM && defined CUDA_VERSION && CUDA_VERSION >= 12040 && defined PYTORCH_C10_DRIVER_API_SUPPORTED
|
||||
|
||||
nvmlDevice_t get_nvml_device(c10::DeviceIndex dev) {
|
||||
static bool nvml_init [[maybe_unused]] = []() {
|
||||
TORCH_INTERNAL_ASSERT(NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_());
|
||||
return true;
|
||||
}();
|
||||
|
||||
auto prop = at::cuda::getDeviceProperties(dev);
|
||||
char pci_id // NOLINT(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
|
||||
snprintf(
|
||||
pci_id,
|
||||
sizeof(pci_id),
|
||||
NVML_DEVICE_PCI_BUS_ID_FMT,
|
||||
prop->pciDomainID,
|
||||
prop->pciBusID,
|
||||
prop->pciDeviceID);
|
||||
|
||||
nvmlDevice_t nvml_device = nullptr;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
NVML_SUCCESS ==
|
||||
DriverAPI::get()->nvmlDeviceGetHandleByPciBusId_v2_(
|
||||
pci_id, &nvml_device));
|
||||
return nvml_device;
|
||||
}
|
||||
|
||||
bool isFabricSupported() {
|
||||
// 1. try allocating memory
|
||||
CUmemGenericAllocationHandle handle = 0;
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
|
||||
size_t granularity{};
|
||||
const auto driver_api = c10::cuda::DriverAPI::get();
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
|
||||
&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
|
||||
auto status = driver_api->cuMemCreate_(&handle, granularity, &prop, 0);
|
||||
if (status != CUDA_SUCCESS) {
|
||||
LOG(INFO)
|
||||
<< "status " << status
|
||||
<< " Could not allocate memory with FABRIC handle, falling back to fd handle exchange\n";
|
||||
return false;
|
||||
}
|
||||
// 2. check export
|
||||
CUmemFabricHandle sharedHandle;
|
||||
status = driver_api->cuMemExportToShareableHandle_(
|
||||
&sharedHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
|
||||
if (status != CUDA_SUCCESS) {
|
||||
LOG(INFO)
|
||||
<< "status " << status
|
||||
<< " Could not export FABRIC handle, falling back to fd handle exchange\n";
|
||||
driver_api->cuMemRelease_(handle);
|
||||
return false;
|
||||
}
|
||||
// 3. check import
|
||||
CUmemGenericAllocationHandle import_handle = 0;
|
||||
status = driver_api->cuMemImportFromShareableHandle_(
|
||||
&import_handle, &sharedHandle, CU_MEM_HANDLE_TYPE_FABRIC);
|
||||
if (status != CUDA_SUCCESS) {
|
||||
LOG(INFO)
|
||||
<< "status " << status
|
||||
<< " Could not import FABRIC handle, falling back to fd handle exchange\n";
|
||||
driver_api->cuMemRelease_(handle);
|
||||
return false;
|
||||
}
|
||||
driver_api->cuMemRelease_(import_handle);
|
||||
driver_api->cuMemRelease_(handle);
|
||||
LOG(INFO) << "using fabric to exchange memory handles\n";
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
bool get_fabric_access(c10::DeviceIndex dev) {
|
||||
#if !defined USE_ROCM && defined CUDA_VERSION && CUDA_VERSION >= 12040 && defined PYTORCH_C10_DRIVER_API_SUPPORTED
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
|
||||
|
||||
TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device");
|
||||
auto& cache = fabricAccessEnabled_[dev];
|
||||
if (cache != -1) {
|
||||
return cache;
|
||||
}
|
||||
auto nvml_device = get_nvml_device(dev);
|
||||
if (nvml_device != nullptr) {
|
||||
nvmlGpuFabricInfoV_t fabricInfo;
|
||||
fabricInfo.state = NVML_GPU_FABRIC_STATE_NOT_SUPPORTED;
|
||||
fabricInfo.version = nvmlGpuFabricInfo_v2;
|
||||
if (DriverAPI::get()->nvmlDeviceGetGpuFabricInfoV_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
NVML_SUCCESS ==
|
||||
DriverAPI::get()->nvmlDeviceGetGpuFabricInfoV_(
|
||||
nvml_device, &fabricInfo));
|
||||
auto state = fabricInfo.state != NVML_GPU_FABRIC_STATE_NOT_SUPPORTED;
|
||||
if (state) {
|
||||
// now perform the full cycle of allocating - exporting - importing memory
|
||||
state = isFabricSupported();
|
||||
}
|
||||
cache = state ? 1 : 0;
|
||||
return cache;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
|
@ -8,5 +8,6 @@ void init_p2p_access_cache(int64_t num_devices);
|
||||
}
|
||||
|
||||
TORCH_CUDA_CPP_API bool get_p2p_access(c10::DeviceIndex source_dev, c10::DeviceIndex dest_dev);
|
||||
TORCH_CUDA_CPP_API bool get_fabric_access(c10::DeviceIndex device);
|
||||
|
||||
} // namespace at::cuda
|
||||
|
@ -38,6 +38,13 @@ DriverAPI create_driver_api() {
|
||||
C10_NVML_DRIVER_API(LOOKUP_NVML_ENTRY)
|
||||
#undef LOOKUP_NVML_ENTRY
|
||||
}
|
||||
|
||||
if (handle_1) {
|
||||
#define LOOKUP_NVML_ENTRY_OPTIONAL(name) \
|
||||
r.name##_ = ((decltype(&name))dlsym(handle_1, #name));
|
||||
C10_NVML_DRIVER_API_OPTIONAL(LOOKUP_NVML_ENTRY_OPTIONAL)
|
||||
#undef LOOKUP_NVML_ENTRY_OPTIONAL
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
|
@ -67,6 +67,8 @@
|
||||
_(nvmlDeviceGetComputeRunningProcesses) \
|
||||
_(nvmlSystemGetCudaDriverVersion_v2)
|
||||
|
||||
#define C10_NVML_DRIVER_API_OPTIONAL(_) _(nvmlDeviceGetGpuFabricInfoV)
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
struct DriverAPI {
|
||||
@ -75,6 +77,7 @@ struct DriverAPI {
|
||||
C10_LIBCUDA_DRIVER_API_REQUIRED(CREATE_MEMBER_VERSIONED)
|
||||
C10_LIBCUDA_DRIVER_API_OPTIONAL(CREATE_MEMBER_VERSIONED)
|
||||
C10_NVML_DRIVER_API(CREATE_MEMBER)
|
||||
C10_NVML_DRIVER_API_OPTIONAL(CREATE_MEMBER)
|
||||
#undef CREATE_MEMBER_VERSIONED
|
||||
#undef CREATE_MEMBER
|
||||
|
||||
|
@ -1122,6 +1122,11 @@ elseif(USE_CUDA)
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
|
||||
endif()
|
||||
# Set driver api defined for PeerToPeerAccess
|
||||
if(NOT WIN32)
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/PeerToPeerAccess.cpp PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if(USE_XPU)
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/PeerToPeerAccess.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/util/error.h>
|
||||
@ -420,23 +421,11 @@ void* CUDASymmetricMemoryAllocator::alloc(
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
|
||||
prop.location.id = device_idx;
|
||||
const auto driver_api = c10::cuda::DriverAPI::get();
|
||||
|
||||
bool has_fabric_support = at::cuda::get_fabric_access(device_idx);
|
||||
LOG(INFO) << "CUDASymmetricMemoryAllocator::alloc: has_fabric_support " << has_fabric_support;
|
||||
if (handle_type_ == Expandable_Segments_Handle_Type::UNSPECIFIED) {
|
||||
// Initialize NVML
|
||||
if (driver_api->nvmlInit_v2_() == NVML_SUCCESS) {
|
||||
// Get the driver version
|
||||
int version = -1;
|
||||
const auto res = driver_api->nvmlSystemGetCudaDriverVersion_v2_(&version);
|
||||
if (res == NVML_SUCCESS) {
|
||||
// Check if driver is sufficiently new
|
||||
if (version < 12040) {
|
||||
handle_type_ = Expandable_Segments_Handle_Type::POSIX_FD;
|
||||
handle_type_ = has_fabric_support ? Expandable_Segments_Handle_Type::FABRIC_HANDLE : Expandable_Segments_Handle_Type::POSIX_FD;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (handle_type_ == Expandable_Segments_Handle_Type::POSIX_FD) {
|
||||
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
} else {
|
||||
@ -444,22 +433,13 @@ void* CUDASymmetricMemoryAllocator::alloc(
|
||||
}
|
||||
|
||||
size_t granularity;
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
|
||||
&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
block_size = at::round_up(block_size, granularity);
|
||||
|
||||
HandleType handle;
|
||||
auto status = driver_api->cuMemCreate_(&handle, block_size, &prop, 0);
|
||||
if (handle_type_ == Expandable_Segments_Handle_Type::UNSPECIFIED) {
|
||||
if (status != CUDA_SUCCESS) {
|
||||
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
handle_type_ = Expandable_Segments_Handle_Type::POSIX_FD;
|
||||
status = driver_api->cuMemCreate_(&handle, block_size, &prop, 0);
|
||||
} else {
|
||||
handle_type_ = Expandable_Segments_Handle_Type::FABRIC_HANDLE;
|
||||
}
|
||||
}
|
||||
C10_CUDA_DRIVER_CHECK(status);
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemCreate_(&handle, block_size, &prop, 0));
|
||||
|
||||
#elif defined(USE_ROCM)
|
||||
hipMemAllocationProp prop = {};
|
||||
|
Reference in New Issue
Block a user