# Owner(s): ["module: inductor"]

import ctypes

import torch
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import CUDACodeCache
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
from torch._inductor.exc import CUDACompileError
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import fresh_cache
from torch.testing._internal.triton_utils import requires_cuda_and_triton
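
# These tests exercise CUDACodeCache end to end: compiling raw CUDA C++ to a
# .o or .so with nvcc, loading the resulting shared library through ctypes,
# and the asynchronous variant driven by AsyncCompile.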
_SOURCE_CODE = r"""

#include <stdio.h>

// Device kernel: each thread updates one element of y.
__global__
void saxpy_device(int n, float a, float *x, float *y)
{
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) y[i] = a*x[i] + y[i];
}

extern "C" {

__attribute__((__visibility__("default")))
int saxpy(int n, float a, float *x, float *y) {
    // Perform SAXPY on the default stream.
    saxpy_device<<<(n+255)/256, 256>>>(n, a, x, y);
    return 0;
}

}
"""


class TestCUDACodeCache(InductorTestCase):
    @requires_cuda_and_triton
    def test_cuda_load(self):
        with fresh_cache():
            # Test both .o and .so compilation.
            (
                object_file_path,
                object_hash_key,
                source_code_path0,
            ) = CUDACodeCache.compile(_SOURCE_CODE, "o")
            dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load(
                _SOURCE_CODE, "so"
            )
            # Both builds come from the same cached source file, and the .o
            # and .so builds of the same source share one hash key.
            self.assertEqual(source_code_path0, source_code_path1)
            self.assertEqual(object_hash_key, so_hash_key)

            # Test load and call functions in .so.
            x = torch.rand(10).float().cuda()
            y = torch.rand(10).float().cuda()
            a = 5.0
            expected_y = a * x + y
            dll_wrapper.saxpy(
                ctypes.c_int(10),
                ctypes.c_float(a),
                ctypes.c_void_p(x.data_ptr()),
                ctypes.c_void_p(y.data_ptr()),
            )
            torch.testing.assert_close(y, expected_y)
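
    # The ctypes call above passes raw device pointers (c_void_p over
    # tensor.data_ptr()), and the kernel writes into y's buffer in place on
    # the default stream; that is why the assertion inspects y itself rather
    # than a return value.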

    @requires_cuda_and_triton
    def test_compilation_error(self):
        with fresh_cache():
            # Renaming the kernel's definition but not its call site leaves
            # an undeclared identifier, so nvcc fails and the cache surfaces
            # a CUDACompileError.
            error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
            with self.assertRaises(CUDACompileError):
                CUDACodeCache.compile(error_source_code, "o")

    @requires_cuda_and_triton
    def test_async_compile(self):
        with fresh_cache():
            async_compile = AsyncCompile()
            # async_compile.cuda returns a handle; wait() blocks until the
            # background compile finishes, after which result() yields the
            # loaded DLL wrapper.
            compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
            async_compile.wait(globals())

            # Test load and call functions in .so.
            x = torch.rand(5).float().cuda()
            y = torch.rand(5).float().cuda()
            a = 2.0
            expected_y = a * x + y
            compiled_res.result().saxpy(
                ctypes.c_int(5),
                ctypes.c_float(a),
                ctypes.c_void_p(x.data_ptr()),
                ctypes.c_void_p(y.data_ptr()),
            )
            torch.testing.assert_close(y, expected_y)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if nvcc_exist():
        run_tests("cuda")