mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make TORCH_(CUDABLAS|CUSOLVER)_CHECK usable in custom extensions (#67161)
Summary: Make `TORCH_CUDABLAS_CHECK` and `TORCH_CUSOLVER_CHECK` available in custom extensions by exporting the internal functions called by the both macros. Rel: https://github.com/pytorch/pytorch/issues/67073 cc xwang233 ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/67161 Reviewed By: jbschlosser Differential Revision: D31984694 Pulled By: ngimel fbshipit-source-id: 0035ecd1398078cf7d3abc23aaefda57aaa31106
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ad89d994c9
commit
d4493b27ee
@ -82,6 +82,24 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@common.skipIfRocm
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
def test_cublas_extension(self):
|
||||
from torch_test_cpp_extension import cublas_extension
|
||||
|
||||
x = torch.zeros(100, device="cuda", dtype=torch.float32)
|
||||
z = cublas_extension.noop_cublas_function(x)
|
||||
self.assertEqual(z, x)
|
||||
|
||||
@common.skipIfRocm
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
def test_cusolver_extension(self):
|
||||
from torch_test_cpp_extension import cusolver_extension
|
||||
|
||||
x = torch.zeros(100, device="cuda", dtype=torch.float32)
|
||||
z = cusolver_extension.noop_cusolver_function(x)
|
||||
self.assertEqual(z, x)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Not available on Windows")
|
||||
def test_no_python_abi_suffix_sets_the_correct_library_name(self):
|
||||
# For this test, run_test.py will call `python setup.py install` in the
|
||||
|
||||
Reference in New Issue
Block a user