From c2660d29a5185cf5f24aa280ab3edbf29b960431 Mon Sep 17 00:00:00 2001 From: skishore Date: Thu, 22 May 2025 18:22:15 +0000 Subject: [PATCH] [ROCm] Added unit test to test the cuda_pluggable allocator (#154041) Added unit test to include the cuda_pluggable allocator and replicate the apex setup.py to build nccl_allocator extension This test to check if this commit https://github.com/pytorch/pytorch/pull/152179 helps to build the cuda pluggable allocator in Rocm/Apex Pull Request resolved: https://github.com/pytorch/pytorch/pull/154041 Approved by: https://github.com/atalman, https://github.com/jeffdaily Co-authored-by: Jithun Nair --- test/test_cpp_extensions_jit.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index ef4e5b5cb1e3..dc7269c865f9 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -1194,6 +1194,48 @@ class TestCppExtensionJIT(common.TestCase): self.assertEqual(abs_t, torch.abs(t)) self.assertEqual(floor_t, torch.floor(t)) + @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found") + def test_cuda_pluggable_allocator_include(self): + """ + This method creates a minimal example to replicate the apex setup.py to build nccl_allocator extension + """ + + # the cpp source includes CUDAPluggableAllocator and has an empty exported function + cpp_source = """ + #include + #include + int get_nccl_allocator() { + return 0; + } + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_nccl_allocator", []() { return get_nccl_allocator(); }); + } + """ + + build_dir = tempfile.mkdtemp() + src_path = os.path.join(build_dir, "NCCLAllocator.cpp") + + with open(src_path, mode="w") as f: + f.write(cpp_source) + + # initially success is false + success = False + try: + # try to build the module + torch.utils.cpp_extension.load( + name="nccl_allocator", + sources=src_path, + verbose=True, + with_cuda=True, + ) + # set success as true if built successfully + success = True + except Exception as e: + print(f"Failed to load the module: {e}") + + # test if build was successful + self.assertEqual(success, True) + if __name__ == "__main__": common.run_tests()