mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] Added unit test to test the cuda_pluggable allocator (#154041)
Added unit test to include the cuda_pluggable allocator and replicate the apex setup.py to build nccl_allocator extension This test to check if this commit https://github.com/pytorch/pytorch/pull/152179 helps to build the cuda pluggable allocator in Rocm/Apex Pull Request resolved: https://github.com/pytorch/pytorch/pull/154041 Approved by: https://github.com/atalman, https://github.com/jeffdaily Co-authored-by: Jithun Nair <jithun.nair@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
5b8f422561
commit
c2660d29a5
@ -1194,6 +1194,48 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
self.assertEqual(abs_t, torch.abs(t))
|
||||
self.assertEqual(floor_t, torch.floor(t))
|
||||
|
||||
@unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
|
||||
def test_cuda_pluggable_allocator_include(self):
|
||||
"""
|
||||
This method creates a minimal example to replicate the apex setup.py to build nccl_allocator extension
|
||||
"""
|
||||
|
||||
# the cpp source includes CUDAPluggableAllocator and has an empty exported function
|
||||
cpp_source = """
|
||||
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||
#include <torch/extension.h>
|
||||
int get_nccl_allocator() {
|
||||
return 0;
|
||||
}
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("get_nccl_allocator", []() { return get_nccl_allocator(); });
|
||||
}
|
||||
"""
|
||||
|
||||
build_dir = tempfile.mkdtemp()
|
||||
src_path = os.path.join(build_dir, "NCCLAllocator.cpp")
|
||||
|
||||
with open(src_path, mode="w") as f:
|
||||
f.write(cpp_source)
|
||||
|
||||
# initially success is false
|
||||
success = False
|
||||
try:
|
||||
# try to build the module
|
||||
torch.utils.cpp_extension.load(
|
||||
name="nccl_allocator",
|
||||
sources=src_path,
|
||||
verbose=True,
|
||||
with_cuda=True,
|
||||
)
|
||||
# set success as true if built successfully
|
||||
success = True
|
||||
except Exception as e:
|
||||
print(f"Failed to load the module: {e}")
|
||||
|
||||
# test if build was successful
|
||||
self.assertEqual(success, True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
||||
Reference in New Issue
Block a user