Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
This PR allows users to author a CUDA kernel in Python.

```
from torch.cuda.jiterator import create_jit_fn

code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)

a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```

Limitations:
- Only elementwise kernels are supported
- 1~8 tensor inputs (empty input, e.g. factory methods, is not supported)
- Input tensors must live on a CUDA device
- CPU Scalars are not supported
- kwargs must be pre-declared when calling create_jit_fn
- kwargs must be convertible to at::Scalar, one of float64, int64_t, bool (complex is not supported for now)

TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in the Python interface, e.g. <8 inputs, kwarg types
- [x] add user-facing documentation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
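As a further illustration of the kwarg rules above, here is a minimal sketch using a bool extra argument. It assumes the `_create_jit_fn` alias that the test file below imports; the kernel body and the names `scaled`, `scaled_fn`, and `flip` are illustrative only, not part of the PR.

```
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

# bool is one of the supported at::Scalar kwarg types; the kwarg must be
# declared with a default at creation time before it can be overridden per call.
code_string = "template <typename T> T scaled(T x, T y, bool flip) { return flip ? -x * y : x * y; }"
scaled_fn = create_jit_fn(code_string, flip=False)

a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
out = scaled_fn(a, b, flip=True)  # OK: flip was pre-declared above
# Passing an undeclared kwarg (e.g. gamma=2.0) would violate the
# "kwargs must be pre-declared" limitation listed above.
```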
133 lines | 5.0 KiB | Python
# Owner(s): ["module: cuda"]

import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
    skipCUDAIfRocm, skipCUDAIf, instantiate_device_type_tests, dtypes, toleranceOverride, tol)
from torch.testing._internal.common_cuda import _get_torch_cuda_version

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = object  # noqa: F811

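# Module-level kernel and eager-mode reference shared by the dtype and
# extra-args tests below; both compute alpha * x + beta * y.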
code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)


def ref_fn(x, y, alpha=1, beta=1):
    return alpha * x + beta * y


class TestPythonJiterator(TestCase):
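    # Compares the jiterated kernel against ref_fn over every (dtype, dtype)
    # input pair on contiguous inputs.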
    @skipCUDAIfRocm
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

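    # Same dtype sweep, but on non-contiguous inputs, which require dynamic
    # casting (see the skip message below).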
    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On cuda 11.3, nvrtcCompileProgram is taking too long to "
                "compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

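    # alpha/beta declared at creation time can be overridden per call;
    # None means the kwarg is left at its default.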
    @skipCUDAIfRocm
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

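    # bool kwargs are accepted as extra args and can be flipped at call time.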
    @skipCUDAIfRocm
    def test_bool_extra_args(self, device):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=True)
        result = jitted_fn(a, b, is_train=True)
        self.assertEqual(expected, result)

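    # Generates a summing kernel for each supported arity (1 to 8 tensor inputs).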
    @skipCUDAIfRocm
    @parametrize("num_inputs", list(range(1, 9)))
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

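    # Malformed function names in the code string should make create_jit_fn raise.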
    @skipCUDAIfRocm
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)


instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == '__main__':
    run_tests()