Add wrappers for synchronous GPUDirect Storage APIs (#130633)

Based in part on https://github.com/NVIDIA/apex/pull/1774

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130633
Approved by: https://github.com/albanD
Author: Mikayla Gawarecki
Date: 2024-07-19 13:36:03 -07:00
Committed by: PyTorch MergeBot
Parent: 5c78581fc9
Commit: 5b5e0698a5
20 changed files with 388 additions and 1 deletion
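At a high level, the commit adds a private torch.cuda.gds module that wraps the synchronous cuFile APIs (cuFileHandleRegister, cuFileBufRegister, cuFileRead, cuFileWrite). A minimal sketch of the intended round trip, distilled from the test added below — the mount path is hypothetical, and a CUDA device plus a GDS-capable ext4/xfs filesystem are assumed:

import os
import torch

src = torch.randn(1024, device="cuda")
dest = torch.empty(1024, device="cuda")

# Optional: pin the source storage with cuFileBufRegister for the direct path.
torch.cuda.gds._gds_register_buffer(src.untyped_storage())

# _GdsFile opens the file with O_DIRECT and registers the fd with cuFile.
f = torch.cuda.gds._GdsFile("/mnt/ext4/tensor.bin", os.O_CREAT | os.O_RDWR)
f.save_storage(src.untyped_storage(), offset=0)  # cuFileWrite
f.load_storage(dest.untyped_storage(), offset=0)  # cuFileRead
del f  # __del__ deregisters the handle and closes the fd

torch.cuda.gds._gds_deregister_buffer(src.untyped_storage())
assert torch.equal(src, dest)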


@@ -413,6 +413,7 @@ cc_library(
"@cuda//:nvrtc",
"@cudnn",
"@cudnn_frontend",
"@cuda//:cufile",
],
alwayslink = True,
)


@@ -251,6 +251,15 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
    "USE_CUDNN" OFF)
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924
# Using TH_BINARY_BUILD to check whether this is a binary build.
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
if(DEFINED ENV{TH_BINARY_BUILD})
  cmake_dependent_option(USE_CUFILE "Use cuFile" ON
      "USE_CUDA AND NOT $ENV{TH_BINARY_BUILD} AND NOT WIN32" OFF)
else()
  cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF)
endif()
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)


@@ -773,6 +773,7 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/utils.cpp",
"torch/csrc/cuda/GdsFile.cpp",
]
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [


@@ -928,6 +928,10 @@ elseif(USE_CUDA)
torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
if(USE_CUFILE)
target_link_libraries(torch_cuda PRIVATE torch::cufile)
target_compile_definitions(torch_cuda PRIVATE USE_CUFILE)
endif()
if(USE_CUSPARSELT)
target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)


@@ -39,6 +39,7 @@ if(USE_CUDA)
set(CAFFE2_USE_CUDA ${USE_CUDA})
set(CAFFE2_USE_CUDNN ${USE_CUDNN})
set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT})
set(CAFFE2_USE_CUFILE ${USE_CUFILE})
set(CAFFE2_USE_NVRTC ${USE_NVRTC})
include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake)
if(CAFFE2_USE_CUDA)
@@ -60,6 +61,9 @@ if(USE_CUDA)
else()
caffe2_update_option(USE_CUSPARSELT OFF)
endif()
if(CAFFE2_USE_CUFILE)
list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cufile)
endif()
find_program(SCCACHE_EXECUTABLE sccache)
if(SCCACHE_EXECUTABLE)
# Using RSP/--options-file renders output noncacheable by sccache
@@ -79,6 +83,7 @@ if(USE_CUDA)
set(CAFFE2_USE_CUDA OFF)
set(CAFFE2_USE_CUDNN OFF)
set(CAFFE2_USE_CUSPARSELT OFF)
set(CAFFE2_USE_CUFILE OFF)
set(CAFFE2_USE_NVRTC OFF)
endif()
endif()
@@ -1035,7 +1040,6 @@ if(USE_ROCM)
caffe2_update_option(USE_SYSTEM_NCCL ON)
endif()
list(APPEND HIP_CXX_FLAGS -fPIC)
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)


@@ -978,6 +978,14 @@ if(CUDAToolkit_FOUND)
_CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
endif()
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4)
_CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos)
_CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos)
_CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos)
_CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos)
endif()
# cuFFTW depends on cuFFT
_CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
_CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)


@@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
message(STATUS " USE_CUFILE : ${USE_CUFILE}")
message(STATUS " CUDA version : ${CUDA_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
@@ -83,6 +84,9 @@ function(caffe2_print_configuration_summary)
if(${USE_CUSPARSELT})
message(STATUS " cuSPARSELt version : ${CUSPARSELT_VERSION}")
endif()
if(${USE_CUFILE})
message(STATUS " cufile library : ${CUDA_cuFile_LIBRARY}")
endif()
message(STATUS " CUDA root directory : ${CUDA_TOOLKIT_ROOT_DIR}")
message(STATUS " CUDA library : ${CUDA_cuda_driver_LIBRARY}")
message(STATUS " cudart library : ${CUDA_cudart_LIBRARY}")


@@ -244,6 +244,22 @@ else()
  message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
endif()

# cufile
if(CAFFE2_USE_CUFILE)
  add_library(torch::cufile INTERFACE IMPORTED)
  if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
    set_property(
        TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
        CUDA::cuFile_static)
  else()
    set_property(
        TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
        CUDA::cuFile)
  endif()
else()
  message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
endif()
# curand
add_library(caffe2::curand INTERFACE IMPORTED)
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)


@@ -181,6 +181,7 @@ See the :doc:`documentation <cuda._sanitizer>` for information on how to use it.
.. for tracking purposes
.. py:module:: torch.cuda.comm
.. py:module:: torch.cuda.error
.. py:module:: torch.cuda.gds
.. py:module:: torch.cuda.graphs
.. py:module:: torch.cuda.jiterator
.. py:module:: torch.cuda.memory


@@ -38,6 +38,9 @@
# USE_CUSPARSELT=0
# disables the cuSPARSELt build
#
# USE_CUFILE=0
# disables the cuFile build
#
# USE_FBGEMM=0
# disables the FBGEMM build
#
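As with the neighboring flags, setting USE_CUFILE=0 in the environment before building skips the cuFile dependency entirely; the CMakeLists.txt hunk earlier additionally forces it off for binary builds (TH_BINARY_BUILD) and on Windows.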


@@ -17,6 +17,8 @@ from copy import deepcopy
from itertools import product
from random import randint
import psutil
import torch
import torch.cuda
from torch import inf, nan
@@ -62,6 +64,7 @@ from torch.testing._internal.common_utils import (
skipIfRocm,
slowTest,
subtest,
TemporaryFileName,
TEST_CUDA,
TEST_CUDA_GRAPH,
TEST_NUMPY,
@@ -4022,6 +4025,15 @@ print(f"{{r1}}, {{r2}}")
        x = torch.cuda.device_count()
        self.assertEqual(f"{x}, 1", r)

    def test_gds_fails_in_ci(self):
        if IS_WINDOWS or TEST_WITH_ROCM:
            error_msg = "is not supported on this platform"
        else:
            error_msg = "cuFileHandleRegister failed"
        with TemporaryFileName() as f:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaMallocAsync(TestCase):
@@ -5169,6 +5181,40 @@ class TestCudaOptims(TestCase):
        self.assertEqual(scaler._growth_tracker, growth_tracker)


class TestGDS(TestCase):
    def _get_tmp_dir_fs_type(self):
        my_path = os.path.realpath("/tmp")
        root_type = ""
        for part in psutil.disk_partitions():
            if part.mountpoint == "/":
                root_type = part.fstype
                continue
            if part.mountpoint == my_path:
                return part.fstype
        return root_type

    @unittest.skipIf(IS_WINDOWS or TEST_WITH_ROCM, "Not supported on Windows or ROCm")
    def test_gds_read_write_tensors(self):
        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
        src1 = torch.randn(1024, device="cuda")
        src2 = torch.randn(2, 1024, device="cuda")
        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
        dest1 = torch.empty(1024, device="cuda")
        dest2 = torch.empty(2, 1024, device="cuda")
        with TemporaryFileName() as f:
            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
            file.save_storage(src1.untyped_storage(), offset=0)
            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
            file.load_storage(dest1.untyped_storage(), offset=0)
            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
        self.assertEqual(src1, dest1)
        self.assertEqual(src2, dest2)
        torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())


instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_device_type_tests(TestCudaOptims, globals())


@@ -60,6 +60,12 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "cufile",
srcs = ["targets/x86_64-linux/lib/libcufile.so"],
visibility = ["//visibility:public"],
)
cc_library(
name = "nvrtc",
srcs = [


@@ -312,6 +312,10 @@ if(USE_NUMPY)
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
endif()
if(USE_CUFILE AND NOT USE_ROCM)
target_compile_definitions(torch_python PRIVATE USE_CUFILE)
endif()
if(HAVE_SOVERSION)
set_target_properties(torch_python PROPERTIES
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})


@@ -1974,6 +1974,14 @@ def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
# Defined in torch/csrc/cuda/GdsFile.cpp
def _gds_register_buffer(t: Storage) -> None: ...
def _gds_deregister_buffer(t: Storage) -> None: ...
def _gds_register_handle(fd: _int) -> _int: ...
def _gds_deregister_handle(handle: _int) -> None: ...
def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ...
def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ...
# Defined in torch/csrc/cuda/python_comm.cpp
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
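Taken together, these stubs expose the full low-level sequence that torch/cuda/gds.py (below) drives for you. A hedged sketch against the raw bindings, with a hypothetical path and no error handling:

import os
import torch

fd = os.open("/mnt/ext4/data.bin", os.O_CREAT | os.O_RDWR | os.O_DIRECT)
handle = torch._C._gds_register_handle(fd)  # cuFileHandleRegister -> opaque int64

s = torch.randn(256, device="cuda").untyped_storage()
torch._C._gds_save_storage(handle, s, 0)  # cuFileWrite at file offset 0
torch._C._gds_load_storage(handle, s, 0)  # cuFileRead back into the storage

torch._C._gds_deregister_handle(handle)  # cuFileHandleDeregister
os.close(fd)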


@@ -306,6 +306,7 @@ def _load_global_deps() -> None:
"cuda_runtime": "libcudart.so.*[0-9]",
"cuda_cupti": "libcupti.so.*[0-9]",
"cufft": "libcufft.so.*[0-9]",
"cufile": "libcufile.so.*[0-9]",
"curand": "libcurand.so.*[0-9]",
"cusolver": "libcusolver.so.*[0-9]",
"cusparse": "libcusparse.so.*[0-9]",

torch/csrc/cuda/GdsFile.cpp (new file, 134 lines)

@@ -0,0 +1,134 @@
#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h>

#if defined(USE_CUFILE)

#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cufile.h>

namespace {

// To get error message for cuFileRead/Write APIs that return ssize_t (-1 for
// filesystem error and a negative CUfileOpError enum value otherwise).
template <
    class T,
    typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
        nullptr>
std::string cuGDSFileGetErrorString(T status) {
  status = std::abs(status);
  return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
                               : std::string(std::strerror(errno));
}

// To get error message for Buf/Handle registration APIs that return
// CUfileError_t
template <
    class T,
    typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
        nullptr>
std::string cuGDSFileGetErrorString(T status) {
  std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
  if (IS_CUDA_ERR(status))
    errStr.append(".").append(
        cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
  return errStr;
}

} // namespace

void gds_load_storage(
    int64_t handle,
    const at::Storage& storage,
    off_t offset) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  c10::cuda::CUDAGuard gpuGuard(storage.device());

  void* dataPtr = storage.mutable_data();
  const size_t nbytes = storage.nbytes();

  // Read the binary file
  ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0);
  TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
}

void gds_save_storage(
    int64_t handle,
    const at::Storage& storage,
    off_t offset) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  c10::cuda::CUDAGuard gpuGuard(storage.device());

  void* dataPtr = storage.mutable_data();
  const size_t nbytes = storage.nbytes();

  // Write device memory contents to the file
  ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
  TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
}

void gds_register_buffer(const at::Storage& storage) {
  void* dataPtr = storage.mutable_data();
  const size_t nbytes = storage.nbytes();

  CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0);
  TORCH_CHECK(
      status.err == CU_FILE_SUCCESS,
      "cuFileBufRegister failed: ",
      cuGDSFileGetErrorString(status));
  return;
}

void gds_deregister_buffer(const at::Storage& storage) {
  void* dataPtr = storage.mutable_data();
  CUfileError_t status = cuFileBufDeregister(dataPtr);
  TORCH_CHECK(
      status.err == CU_FILE_SUCCESS,
      "cuFileBufDeregister failed: ",
      cuGDSFileGetErrorString(status));
  return;
}

int64_t gds_register_handle(int fd) {
  CUfileDescr_t cf_descr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  CUfileHandle_t cf_handle;
  memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
  cf_descr.handle.fd = fd;
  cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
  CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
  if (status.err != CU_FILE_SUCCESS) {
    TORCH_CHECK(
        false,
        "cuFileHandleRegister failed: ",
        cuGDSFileGetErrorString(status));
  }

  // Returning cuFileHandle_t as int64_t
  return reinterpret_cast<int64_t>(cf_handle);
}

void gds_deregister_handle(int64_t handle) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  cuFileHandleDeregister(cf_handle);
}

#endif

namespace torch::cuda::shared {

void initGdsBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

#if defined(USE_CUFILE)
  m.def("_gds_register_handle", &gds_register_handle);
  m.def("_gds_deregister_handle", &gds_deregister_handle);
  m.def("_gds_register_buffer", &gds_register_buffer);
  m.def("_gds_deregister_buffer", &gds_deregister_buffer);
  m.def("_gds_load_storage", &gds_load_storage);
  m.def("_gds_save_storage", &gds_save_storage);
#endif
}

} // namespace torch::cuda::shared
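Two details worth noting in this file: the CUfileHandle_t crosses the Python boundary as a plain int64_t (hence the paired reinterpret_casts and NOLINT annotations), and error formatting is split into two std::enable_if-selected overloads — the integral one for the ssize_t returns of cuFileRead/cuFileWrite (a negative errno or CUfileOpError), the other for the CUfileError_t struct returned by the registration APIs.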


@@ -0,0 +1,7 @@
#ifndef THCP_GDSFILE_INC
#define THCP_GDSFILE_INC

#include <torch/csrc/python_headers.h>

void initGdsBindings(PyObject* module);
#endif // THCP_GDSFILE_INC


@@ -34,6 +34,7 @@
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/csrc/cuda/GdsFile.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/memory_snapshot.h>
#include <torch/csrc/cuda/python_comm.h>
@@ -1953,6 +1954,7 @@ namespace shared {
void initCudartBindings(PyObject* module);
void initNvtxBindings(PyObject* module);
void initGdsBindings(PyObject* module);
#if defined(USE_CUDNN) || defined(USE_ROCM)
void initCudnnBindings(PyObject* module);
#endif
@@ -1968,6 +1970,7 @@ void initModule(PyObject* module) {
#if defined(USE_CUDNN) || defined(USE_ROCM)
shared::initCudnnBindings(module);
#endif
shared::initGdsBindings(module);
registerCudaDeviceProperties(module);
registerCudaPluggableAllocator(module);
}


@@ -27,6 +27,7 @@ import torch._C
from torch.types import Device
from .. import device as _device
from .._utils import _dummy_type, _LazySeedTracker, classproperty
from . import gds
from ._utils import _get_device_index
from .graphs import (
CUDAGraph,

torch/cuda/gds.py (new file, 126 lines)

@@ -0,0 +1,126 @@
import os
import sys
from typing import Callable, Optional

import torch
from torch.types import Storage


def _dummy_fn(name: str) -> Callable:
    def fn(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(f"torch._C.{name} is not supported on this platform")

    return fn


if not hasattr(torch._C, "_gds_register_buffer"):
    assert not hasattr(torch._C, "_gds_deregister_buffer")
    assert not hasattr(torch._C, "_gds_register_handle")
    assert not hasattr(torch._C, "_gds_deregister_handle")
    assert not hasattr(torch._C, "_gds_load_storage")
    assert not hasattr(torch._C, "_gds_save_storage")

    # Define functions
    torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
    torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
    torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
    torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
    torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
    torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")


def _gds_register_buffer(s: Storage) -> None:
    """Registers a buffer.

    Args:
        s (Storage): Buffer to register.
    """
    torch._C._gds_register_buffer(s)


def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)


class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file.
            ``os.O_DIRECT`` will be added automatically.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int):
        if sys.platform == "win32":
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        self.fd = os.open(filename, flags | os.O_DIRECT)
        self.handle: Optional[int] = None
        self.register_handle()

    def __del__(self) -> None:
        if self.handle is not None:
            self.deregister_handle()
        os.close(self.fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
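On builds without cuFile support (Windows, ROCm, or USE_CUFILE=0), the _dummy_fn shims and the sys.platform check above make every entry point fail loudly rather than silently no-op. A small sketch of the expected failure modes, mirroring test_gds_fails_in_ci:

import os
import torch

try:
    f = torch.cuda.gds._GdsFile("/tmp/example.bin", os.O_CREAT | os.O_RDWR)
except RuntimeError as e:
    # Without cuFile: "... is not supported on this platform".
    # With cuFile but an unsupported filesystem: "cuFileHandleRegister failed: ...".
    print(e)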