mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add wrappers for synchronous GPUDirect Storage APIs (#130633)
Based in part on https://github.com/NVIDIA/apex/pull/1774 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130633 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
5c78581fc9
commit
5b5e0698a5
@ -413,6 +413,7 @@ cc_library(
|
||||
"@cuda//:nvrtc",
|
||||
"@cudnn",
|
||||
"@cudnn_frontend",
|
||||
"@cuda//:cufile",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
@ -251,6 +251,15 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
||||
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||
"USE_CUDNN" OFF)
|
||||
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
|
||||
# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924
|
||||
# Using TH_BINARY_BUILD to check whether is binary build.
|
||||
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
|
||||
if(DEFINED ENV{TH_BINARY_BUILD})
|
||||
cmake_dependent_option(USE_CUFILE "Use cuFile" ON
|
||||
"USE_CUDA AND NOT $ENV{TH_BINARY_BUILD} AND NOT WIN32" OFF)
|
||||
else()
|
||||
cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF)
|
||||
endif()
|
||||
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
||||
option(USE_KINETO "Use Kineto profiling library" ON)
|
||||
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
||||
|
@ -773,6 +773,7 @@ libtorch_python_cuda_core_sources = [
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
"torch/csrc/cuda/GdsFile.cpp",
|
||||
]
|
||||
|
||||
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
|
||||
|
@ -928,6 +928,10 @@ elseif(USE_CUDA)
|
||||
torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
|
||||
|
||||
if(USE_CUFILE)
|
||||
target_link_libraries(torch_cuda PRIVATE torch::cufile)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_CUFILE)
|
||||
endif()
|
||||
if(USE_CUSPARSELT)
|
||||
target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)
|
||||
|
@ -39,6 +39,7 @@ if(USE_CUDA)
|
||||
set(CAFFE2_USE_CUDA ${USE_CUDA})
|
||||
set(CAFFE2_USE_CUDNN ${USE_CUDNN})
|
||||
set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT})
|
||||
set(CAFFE2_USE_CUFILE ${USE_CUFILE})
|
||||
set(CAFFE2_USE_NVRTC ${USE_NVRTC})
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake)
|
||||
if(CAFFE2_USE_CUDA)
|
||||
@ -60,6 +61,9 @@ if(USE_CUDA)
|
||||
else()
|
||||
caffe2_update_option(USE_CUSPARSELT OFF)
|
||||
endif()
|
||||
if(CAFFE2_USE_CUFILE)
|
||||
list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cufile)
|
||||
endif()
|
||||
find_program(SCCACHE_EXECUTABLE sccache)
|
||||
if(SCCACHE_EXECUTABLE)
|
||||
# Using RSP/--options-file renders output noncacheable by sccache
|
||||
@ -79,6 +83,7 @@ if(USE_CUDA)
|
||||
set(CAFFE2_USE_CUDA OFF)
|
||||
set(CAFFE2_USE_CUDNN OFF)
|
||||
set(CAFFE2_USE_CUSPARSELT OFF)
|
||||
set(CAFFE2_USE_CUFILE OFF)
|
||||
set(CAFFE2_USE_NVRTC OFF)
|
||||
endif()
|
||||
endif()
|
||||
@ -1035,7 +1040,6 @@ if(USE_ROCM)
|
||||
caffe2_update_option(USE_SYSTEM_NCCL ON)
|
||||
endif()
|
||||
|
||||
|
||||
list(APPEND HIP_CXX_FLAGS -fPIC)
|
||||
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
|
||||
list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
|
||||
|
@ -978,6 +978,14 @@ if(CUDAToolkit_FOUND)
|
||||
_CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
|
||||
endif()
|
||||
|
||||
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4)
|
||||
_CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos)
|
||||
_CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos)
|
||||
|
||||
_CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos)
|
||||
_CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos)
|
||||
endif()
|
||||
|
||||
# cuFFTW depends on cuFFT
|
||||
_CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
|
||||
_CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)
|
||||
|
@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
|
||||
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
|
||||
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
|
||||
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
|
||||
message(STATUS " USE_CUFILE : ${USE_CUFILE}")
|
||||
message(STATUS " CUDA version : ${CUDA_VERSION}")
|
||||
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
|
||||
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
|
||||
@ -83,6 +84,9 @@ function(caffe2_print_configuration_summary)
|
||||
if(${USE_CUSPARSELT})
|
||||
message(STATUS " cuSPARSELt version : ${CUSPARSELT_VERSION}")
|
||||
endif()
|
||||
if(${USE_CUFILE})
|
||||
message(STATUS " cufile library : ${CUDA_cuFile_LIBRARY}")
|
||||
endif()
|
||||
message(STATUS " CUDA root directory : ${CUDA_TOOLKIT_ROOT_DIR}")
|
||||
message(STATUS " CUDA library : ${CUDA_cuda_driver_LIBRARY}")
|
||||
message(STATUS " cudart library : ${CUDA_cudart_LIBRARY}")
|
||||
|
@ -244,6 +244,22 @@ else()
|
||||
message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
|
||||
endif()
|
||||
|
||||
# cufile
|
||||
if(CAFFE2_USE_CUFILE)
|
||||
add_library(torch::cufile INTERFACE IMPORTED)
|
||||
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
|
||||
set_property(
|
||||
TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
|
||||
CUDA::cuFile_static)
|
||||
else()
|
||||
set_property(
|
||||
TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
|
||||
CUDA::cuFile)
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
|
||||
endif()
|
||||
|
||||
# curand
|
||||
add_library(caffe2::curand INTERFACE IMPORTED)
|
||||
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
|
||||
|
@ -181,6 +181,7 @@ See the :doc:`documentation <cuda._sanitizer>` for information on how to use it.
|
||||
.. for tracking purposes
|
||||
.. py:module:: torch.cuda.comm
|
||||
.. py:module:: torch.cuda.error
|
||||
.. py:module:: torch.cuda.gds
|
||||
.. py:module:: torch.cuda.graphs
|
||||
.. py:module:: torch.cuda.jiterator
|
||||
.. py:module:: torch.cuda.memory
|
||||
|
3
setup.py
3
setup.py
@ -38,6 +38,9 @@
|
||||
# USE_CUSPARSELT=0
|
||||
# disables the cuSPARSELt build
|
||||
#
|
||||
# USE_CUFILE=0
|
||||
# disables the cuFile build
|
||||
#
|
||||
# USE_FBGEMM=0
|
||||
# disables the FBGEMM build
|
||||
#
|
||||
|
@ -17,6 +17,8 @@ from copy import deepcopy
|
||||
from itertools import product
|
||||
from random import randint
|
||||
|
||||
import psutil
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch import inf, nan
|
||||
@ -62,6 +64,7 @@ from torch.testing._internal.common_utils import (
|
||||
skipIfRocm,
|
||||
slowTest,
|
||||
subtest,
|
||||
TemporaryFileName,
|
||||
TEST_CUDA,
|
||||
TEST_CUDA_GRAPH,
|
||||
TEST_NUMPY,
|
||||
@ -4022,6 +4025,15 @@ print(f"{{r1}}, {{r2}}")
|
||||
x = torch.cuda.device_count()
|
||||
self.assertEqual(f"{x}, 1", r)
|
||||
|
||||
def test_gds_fails_in_ci(self):
|
||||
if IS_WINDOWS or TEST_WITH_ROCM:
|
||||
error_msg = "is not supported on this platform"
|
||||
else:
|
||||
error_msg = "cuFileHandleRegister failed"
|
||||
with TemporaryFileName() as f:
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
|
||||
|
||||
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
class TestCudaMallocAsync(TestCase):
|
||||
@ -5169,6 +5181,40 @@ class TestCudaOptims(TestCase):
|
||||
self.assertEqual(scaler._growth_tracker, growth_tracker)
|
||||
|
||||
|
||||
class TestGDS(TestCase):
|
||||
def _get_tmp_dir_fs_type(self):
|
||||
my_path = os.path.realpath("/tmp")
|
||||
root_type = ""
|
||||
for part in psutil.disk_partitions():
|
||||
if part.mountpoint == "/":
|
||||
root_type = part.fstype
|
||||
continue
|
||||
if part.mountpoint == my_path:
|
||||
return part.fstype
|
||||
return root_type
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or TEST_WITH_ROCM, "Not supported on Windows or ROCm")
|
||||
def test_gds_read_write_tensors(self):
|
||||
if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
|
||||
self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
|
||||
src1 = torch.randn(1024, device="cuda")
|
||||
src2 = torch.randn(2, 1024, device="cuda")
|
||||
torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
|
||||
torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
|
||||
dest1 = torch.empty(1024, device="cuda")
|
||||
dest2 = torch.empty(2, 1024, device="cuda")
|
||||
with TemporaryFileName() as f:
|
||||
file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
|
||||
file.save_storage(src1.untyped_storage(), offset=0)
|
||||
file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
|
||||
file.load_storage(dest1.untyped_storage(), offset=0)
|
||||
file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
|
||||
self.assertEqual(src1, dest1)
|
||||
self.assertEqual(src2, dest2)
|
||||
torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
|
||||
torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCuda)
|
||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||
instantiate_device_type_tests(TestCudaOptims, globals())
|
||||
|
6
third_party/cuda.BUILD
vendored
6
third_party/cuda.BUILD
vendored
@ -60,6 +60,12 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cufile",
|
||||
srcs = ["targets/x86_64-linux/lib/libcufile.so"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nvrtc",
|
||||
srcs = [
|
||||
|
@ -312,6 +312,10 @@ if(USE_NUMPY)
|
||||
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
|
||||
endif()
|
||||
|
||||
if(USE_CUFILE AND NOT USE_ROCM)
|
||||
target_compile_definitions(torch_python PRIVATE USE_CUFILE)
|
||||
endif()
|
||||
|
||||
if(HAVE_SOVERSION)
|
||||
set_target_properties(torch_python PROPERTIES
|
||||
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
||||
|
@ -1974,6 +1974,14 @@ def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
||||
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
||||
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/GdsFile.cpp
|
||||
def _gds_register_buffer(t: Storage) -> None: ...
|
||||
def _gds_deregister_buffer(t: Storage) -> None: ...
|
||||
def _gds_register_handle(fd: _int) -> _int: ...
|
||||
def _gds_deregister_handle(handle: _int) -> None: ...
|
||||
def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ...
|
||||
def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/python_comm.cpp
|
||||
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
|
||||
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
|
||||
|
@ -306,6 +306,7 @@ def _load_global_deps() -> None:
|
||||
"cuda_runtime": "libcudart.so.*[0-9]",
|
||||
"cuda_cupti": "libcupti.so.*[0-9]",
|
||||
"cufft": "libcufft.so.*[0-9]",
|
||||
"cufile": "libcufile.so.*[0-9]",
|
||||
"curand": "libcurand.so.*[0-9]",
|
||||
"cusolver": "libcusolver.so.*[0-9]",
|
||||
"cusparse": "libcusparse.so.*[0-9]",
|
||||
|
134
torch/csrc/cuda/GdsFile.cpp
Normal file
134
torch/csrc/cuda/GdsFile.cpp
Normal file
@ -0,0 +1,134 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#if defined(USE_CUFILE)
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cufile.h>
|
||||
|
||||
namespace {
|
||||
// To get error message for cuFileRead/Write APIs that return ssize_t (-1 for
|
||||
// filesystem error and a negative CUfileOpError enum value otherwise).
|
||||
template <
|
||||
class T,
|
||||
typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
|
||||
nullptr>
|
||||
std::string cuGDSFileGetErrorString(T status) {
|
||||
status = std::abs(status);
|
||||
return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
|
||||
: std::string(std::strerror(errno));
|
||||
}
|
||||
|
||||
// To get error message for Buf/Handle registeration APIs that return
|
||||
// CUfileError_t
|
||||
template <
|
||||
class T,
|
||||
typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
|
||||
nullptr>
|
||||
std::string cuGDSFileGetErrorString(T status) {
|
||||
std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
|
||||
if (IS_CUDA_ERR(status))
|
||||
errStr.append(".").append(
|
||||
cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
|
||||
return errStr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void gds_load_storage(
|
||||
int64_t handle,
|
||||
const at::Storage& storage,
|
||||
off_t offset) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
|
||||
c10::cuda::CUDAGuard gpuGuard(storage.device());
|
||||
|
||||
void* dataPtr = storage.mutable_data();
|
||||
const size_t nbytes = storage.nbytes();
|
||||
|
||||
// Read the binary file
|
||||
ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0);
|
||||
TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
|
||||
}
|
||||
|
||||
void gds_save_storage(
|
||||
int64_t handle,
|
||||
const at::Storage& storage,
|
||||
off_t offset) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
|
||||
c10::cuda::CUDAGuard gpuGuard(storage.device());
|
||||
|
||||
void* dataPtr = storage.mutable_data();
|
||||
const size_t nbytes = storage.nbytes();
|
||||
|
||||
// Write device memory contents to the file
|
||||
ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
|
||||
TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
|
||||
}
|
||||
|
||||
void gds_register_buffer(const at::Storage& storage) {
|
||||
void* dataPtr = storage.mutable_data();
|
||||
const size_t nbytes = storage.nbytes();
|
||||
|
||||
CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0);
|
||||
TORCH_CHECK(
|
||||
status.err == CU_FILE_SUCCESS,
|
||||
"cuFileBufRegister failed: ",
|
||||
cuGDSFileGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
|
||||
void gds_deregister_buffer(const at::Storage& storage) {
|
||||
void* dataPtr = storage.mutable_data();
|
||||
CUfileError_t status = cuFileBufDeregister(dataPtr);
|
||||
TORCH_CHECK(
|
||||
status.err == CU_FILE_SUCCESS,
|
||||
"cuFileBufDeregister failed: ",
|
||||
cuGDSFileGetErrorString(status));
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t gds_register_handle(int fd) {
|
||||
CUfileDescr_t cf_descr;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
CUfileHandle_t cf_handle;
|
||||
memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
|
||||
cf_descr.handle.fd = fd;
|
||||
cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
|
||||
CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
|
||||
if (status.err != CU_FILE_SUCCESS) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"cuFileHandleRegister failed: ",
|
||||
cuGDSFileGetErrorString(status));
|
||||
}
|
||||
|
||||
// Returning cuFileHandle_t as int64_t
|
||||
return reinterpret_cast<int64_t>(cf_handle);
|
||||
}
|
||||
|
||||
void gds_deregister_handle(int64_t handle) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
|
||||
cuFileHandleDeregister(cf_handle);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
namespace torch::cuda::shared {
|
||||
|
||||
void initGdsBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
#if defined(USE_CUFILE)
|
||||
m.def("_gds_register_handle", &gds_register_handle);
|
||||
m.def("_gds_deregister_handle", &gds_deregister_handle);
|
||||
m.def("_gds_register_buffer", &gds_register_buffer);
|
||||
m.def("_gds_deregister_buffer", &gds_deregister_buffer);
|
||||
m.def("_gds_load_storage", &gds_load_storage);
|
||||
m.def("_gds_save_storage", &gds_save_storage);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace torch::cuda::shared
|
7
torch/csrc/cuda/GdsFile.h
Normal file
7
torch/csrc/cuda/GdsFile.h
Normal file
@ -0,0 +1,7 @@
|
||||
#ifndef THCP_GDSFILE_INC
|
||||
#define THCP_GDSFILE_INC
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
void initGdsBindings(PyObject* module);
|
||||
#endif // THCP_GDSFILE_INC
|
@ -34,6 +34,7 @@
|
||||
#include <torch/csrc/CudaIPCTypes.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||
#include <torch/csrc/cuda/GdsFile.h>
|
||||
#include <torch/csrc/cuda/THCP.h>
|
||||
#include <torch/csrc/cuda/memory_snapshot.h>
|
||||
#include <torch/csrc/cuda/python_comm.h>
|
||||
@ -1953,6 +1954,7 @@ namespace shared {
|
||||
|
||||
void initCudartBindings(PyObject* module);
|
||||
void initNvtxBindings(PyObject* module);
|
||||
void initGdsBindings(PyObject* module);
|
||||
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
||||
void initCudnnBindings(PyObject* module);
|
||||
#endif
|
||||
@ -1968,6 +1970,7 @@ void initModule(PyObject* module) {
|
||||
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
||||
shared::initCudnnBindings(module);
|
||||
#endif
|
||||
shared::initGdsBindings(module);
|
||||
registerCudaDeviceProperties(module);
|
||||
registerCudaPluggableAllocator(module);
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ import torch._C
|
||||
from torch.types import Device
|
||||
from .. import device as _device
|
||||
from .._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||
from . import gds
|
||||
from ._utils import _get_device_index
|
||||
from .graphs import (
|
||||
CUDAGraph,
|
||||
|
126
torch/cuda/gds.py
Normal file
126
torch/cuda/gds.py
Normal file
@ -0,0 +1,126 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.types import Storage
|
||||
|
||||
|
||||
def _dummy_fn(name: str) -> Callable:
|
||||
def fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
raise RuntimeError(f"torch._C.{name} is not supported on this platform")
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
if not hasattr(torch._C, "_gds_register_buffer"):
|
||||
assert not hasattr(torch._C, "_gds_deregister_buffer")
|
||||
assert not hasattr(torch._C, "_gds_register_handle")
|
||||
assert not hasattr(torch._C, "_gds_deregister_handle")
|
||||
assert not hasattr(torch._C, "_gds_load_storage")
|
||||
assert not hasattr(torch._C, "_gds_save_storage")
|
||||
# Define functions
|
||||
torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
|
||||
torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
|
||||
torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
|
||||
torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
|
||||
torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
|
||||
torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")
|
||||
|
||||
|
||||
def _gds_register_buffer(s: Storage) -> None:
|
||||
"""Registers a buffer.
|
||||
|
||||
Args:
|
||||
s (Storage): Buffer to register.
|
||||
"""
|
||||
torch._C._gds_register_buffer(s)
|
||||
|
||||
|
||||
def _gds_deregister_buffer(s: Storage) -> None:
|
||||
"""Registers a buffer.
|
||||
|
||||
Args:
|
||||
s (Storage): Buffer to register.
|
||||
"""
|
||||
torch._C._gds_deregister_buffer(s)
|
||||
|
||||
|
||||
class _GdsFile:
|
||||
r"""Wrapper around cuFile.
|
||||
|
||||
cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
|
||||
|
||||
Args:
|
||||
filename (str): Name of the file to open.
|
||||
flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
|
||||
be added automatically.
|
||||
|
||||
.. _CUDA GPUDirect Storage Documentation:
|
||||
https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str, flags: int):
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("GdsFile is not supported on this platform.")
|
||||
self.filename = filename
|
||||
self.flags = flags
|
||||
self.fd = os.open(filename, flags | os.O_DIRECT)
|
||||
self.handle: Optional[int] = None
|
||||
self.register_handle()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.handle is not None:
|
||||
self.deregister_handle()
|
||||
os.close(self.fd)
|
||||
|
||||
def register_handle(self) -> None:
|
||||
"""Registers file descriptor to cuFile Driver.
|
||||
|
||||
This is a wrapper around ``cuFileHandleRegister``.
|
||||
"""
|
||||
assert (
|
||||
self.handle is None
|
||||
), "Cannot register a handle that is already registered."
|
||||
self.handle = torch._C._gds_register_handle(self.fd)
|
||||
|
||||
def deregister_handle(self) -> None:
|
||||
"""Deregisters file descriptor from cuFile Driver.
|
||||
|
||||
This is a wrapper around ``cuFileHandleDeregister``.
|
||||
"""
|
||||
assert (
|
||||
self.handle is not None
|
||||
), "Cannot deregister a handle that is not registered."
|
||||
torch._C._gds_deregister_handle(self.handle)
|
||||
self.handle = None
|
||||
|
||||
def load_storage(self, storage: Storage, offset: int = 0) -> None:
|
||||
"""Loads data from the file into the storage.
|
||||
|
||||
This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
|
||||
will be loaded from the file at ``offset`` into the storage.
|
||||
|
||||
Args:
|
||||
storage (Storage): Storage to load data into.
|
||||
offset (int, optional): Offset into the file to start loading from. (Default: 0)
|
||||
"""
|
||||
assert (
|
||||
self.handle is not None
|
||||
), "Cannot load data from a file that is not registered."
|
||||
torch._C._gds_load_storage(self.handle, storage, offset)
|
||||
|
||||
def save_storage(self, storage: Storage, offset: int = 0) -> None:
|
||||
"""Saves data from the storage into the file.
|
||||
|
||||
This is a wrapper around ``cuFileWrite``. All bytes of the storage
|
||||
will be written to the file at ``offset``.
|
||||
|
||||
Args:
|
||||
storage (Storage): Storage to save data from.
|
||||
offset (int, optional): Offset into the file to start saving to. (Default: 0)
|
||||
"""
|
||||
assert (
|
||||
self.handle is not None
|
||||
), "Cannot save data to a file that is not registered."
|
||||
torch._C._gds_save_storage(self.handle, storage, offset)
|
Reference in New Issue
Block a user