Enabling StaticCudaLauncher for ROCm

Committed by: PyTorch MergeBot
Parent: 341e924981
Commit: 2691b25b6a
@@ -12,7 +12,6 @@ from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaK
from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton
from torch._inductor.runtime.triton_helpers import libdevice
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton

@@ -39,8 +38,9 @@ class TestStaticCudaLauncher(TestCase):
# Just used by tests for now.
# TODO: derive cubin_path from wherever triton stores the cubin file on disk.
tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False)
binary_key = "hsaco" if torch.version.hip else "cubin"
with tmp_file:
tmp_file.write(kernel.asm["cubin"])
tmp_file.write(kernel.asm[binary_key])
self.tmp_files.append(tmp_file)
return tmp_file.name

@@ -64,7 +64,6 @@ class TestStaticCudaLauncher(TestCase):
result.load_kernel(device_interface.current_device())
return result

@skipIfRocm
def test_basic(self):
@triton.jit
def simple_kernel(arg0, arg1):
@@ -91,7 +90,6 @@ class TestStaticCudaLauncher(TestCase):
# 2. triton relies on inspect.get_source to get the type annotations
# so I can't even use exec() to generate the test cases.
# So we'll just make a few kernels by hand
@skipIfRocm
def test_unsigned_integers(self):
@triton.jit
def unsigned_integers(
@@ -115,7 +113,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)

@skipIfRocm
def test_signed_integers(self):
@triton.jit
def signed_integers(
@@ -139,7 +136,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
self.assertEqual(new_arg0, arg0)

@skipIfRocm
def test_basic_1arg(self):
@triton.jit
def simple_kernel_1_arg(arg0):
@@ -164,7 +160,6 @@ class TestStaticCudaLauncher(TestCase):
)
self.assertEqual(new_arg0, arg0)

@skipIfRocm
def test_constexpr(self):
# Constexprs are compiled directly into the cubin file,
# so we never need to pass it to StaticCudaLauncher.
@@ -193,7 +188,6 @@ class TestStaticCudaLauncher(TestCase):
)
self.assertEqual(new_arg0, arg0)

@skipIfRocm
def test_implied_constant(self):
"""xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""

@@ -246,7 +240,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, arg0, arg2, 128)
self.assertEqual(arg1, arg2)

@skipIfRocm
def test_kernel_no_args(self):
# Just an easy way to test incompatible number of arguments
@triton.jit
@@ -259,7 +252,6 @@ class TestStaticCudaLauncher(TestCase):
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run(1, 1, 1, stream)

@skipIfRocm
def test_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
@@ -283,7 +275,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, new_arg0, arg1)
self.assertEqual(new_arg0, arg0)

@skipIfRocm
def test_too_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
@@ -303,7 +294,6 @@ class TestStaticCudaLauncher(TestCase):
lambda: self._make_launcher(compiled_kernel),
)

@skipIfRocm
def test_kernel_empty_tensor(self):
# Triton kernel generated by torch.compile of the following:
# @torch.compile()
@@ -364,7 +354,6 @@ class TestStaticCudaLauncher(TestCase):
launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel)
self.assertEqual(buf0, buf1)

@skipIfRocm
def test_kernel_many_args(self):
N = 200
# Make 200 arguments
@@ -405,7 +394,6 @@ class TestStaticTritonCompileResult(TestCase):
Tests static cuda launcher with torch.compile()
"""

@skipIfRocm
def test_basic_compile(self):
@torch.compile
def foo(x, y):
@@ -415,7 +403,6 @@ class TestStaticTritonCompileResult(TestCase):
y = torch.randn(10, device="cuda")
self.assertEqual(foo(x, y), x + y)

@skipIfRocm
# The error gets raised on a worker, so we want to not use a separate process
@torch._inductor.config.patch("compile_threads", 1)
def test_incompatible_code(self):
@@ -438,7 +425,6 @@ class TestStaticTritonCompileResult(TestCase):
lambda: foo(x),
)

@skipIfRocm
# The error gets raised on a worker, so we want to not use a separate process
@torch._inductor.config.patch(
{"compile_threads": 1, "static_launch_user_defined_triton_kernels": True}
@@ -460,7 +446,6 @@ class TestStaticTritonCompileResult(TestCase):
x2 = x.clone().detach_()
self.assertEqual(foo(x), x2 + 5)

@skipIfRocm
def test_empty_tensor(self):
@torch.compile()
def foo(x, y):
@@ -472,7 +457,6 @@ class TestStaticTritonCompileResult(TestCase):
result = foo(x, y)
self.assertEqual(result, torch.cat(((x * 4), y + 10)))

@skipIfRocm
def test_any(self):
def fn(x):
return (
@@ -492,7 +476,6 @@ class TestStaticTritonCompileResult(TestCase):
compiled_result = compiled_fn(arg)
self.assertEqual(eager_result, compiled_result)

@skipIfRocm
def test_disable_static_cuda_launcher(self):
@torch.compile
def fn(x, y):
@@ -37,9 +37,17 @@ class StaticallyLaunchedCudaKernel:
def __init__(self, kernel: CompiledKernel) -> None:
# pyrefly: ignore [missing-attribute]
self.name = kernel.src.fn.__name__
# pyrefly: ignore [missing-attribute]
self.cubin_raw = kernel.asm.get("cubin", None)
# pyrefly: ignore [missing-attribute]
if "hsaco" in kernel.asm:
self.cubin_raw = kernel.asm["hsaco"]
self.is_rocm = True
elif "cubin" in kernel.asm:
self.cubin_raw = kernel.asm["cubin"]
self.is_rocm = False
else:
raise RuntimeError(
"Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm"
)

self.cubin_path = kernel._cubin_path

# Used by torch.compile to filter constants in older triton versions
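For reference, the constructor change above reduces to picking whichever binary the Triton backend produced and remembering which backend it was. A minimal sketch, assuming `asm` stands in for the kernel's `kernel.asm` dict; the helper name and free-function form are illustrative only, not the actual API of StaticallyLaunchedCudaKernel:

```python
def pick_kernel_binary(asm: dict) -> tuple[bytes, bool]:
    """Return (raw_binary, is_rocm) with the same precedence as __init__ above."""
    if "hsaco" in asm:      # ROCm builds emit an HSA code object
        return asm["hsaco"], True
    if "cubin" in asm:      # CUDA builds emit a cubin
        return asm["cubin"], False
    raise RuntimeError("Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm")
```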
@@ -245,13 +253,42 @@ class StaticallyLaunchedCudaKernel:
# thing, it should always match.
# Get rid of constants before passing to cubin launcher

# Add a None if triton wants extra parameters for scratch spaces
arg_tys = self.arg_tys
for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
if has_scratch:
arg_tys = arg_tys + "O"
args = (*args, None)
# pyrefly: ignore [bad-argument-type]

if self.is_rocm:
# ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both
# global_scratch and profile_scratch parameters in the kernel signature,
# even when the kernel doesn't use them (i.e., when has_*_scratch is False).
#
# This differs fundamentally from CUDA, where these parameters are only
# present in the signature if the corresponding has_*_scratch flag is True.
#
# The flags indicate whether memory will be allocated/used:
# - has_global_scratch: Whether global scratch workspace is needed
# - has_profile_scratch: Whether profiling instrumentation is enabled
#
# However, regardless of flag values, we MUST always pass both parameters
# to match the HIP kernel ABI. Passing None is safe:
#
# - If scratch is not needed (has_*_scratch=False or scratch_size=0):
#   The None becomes nullptr, which the kernel never dereferences
#
# - If scratch is needed (has_*_scratch=True and scratch_size>0):
#   The None becomes nullptr initially, but the HIP runtime intercepts
#   the kernel launch, allocates the required scratch memory based on
#   kernel metadata, and replaces the nullptr with a valid pointer before
#   the kernel actually executes
#
# Not passing both parameters causes segmentation faults because the kernel
# expects them at specific positions in the argument array.
arg_tys = arg_tys + "OO"
args = (*args, None, None)

else:
for has_scratch in [self.has_global_scratch, self.has_profile_scratch]:
if has_scratch:
arg_tys = arg_tys + "O"
args = (*args, None)
assert len(args) == len(arg_tys)

# TODO: can handle grid functions here or in C++, so
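The scratch-argument handling described in the comment block above can be summarized as follows. A minimal sketch, assuming `arg_tys` is the launcher's argument-type string where "O" marks a pointer-sized slot and the flags mirror the attributes on StaticallyLaunchedCudaKernel; the free-function form is illustrative only:

```python
def append_scratch_args(arg_tys, args, is_rocm, has_global_scratch, has_profile_scratch):
    if is_rocm:
        # HIP ABI: both scratch parameters are always part of the signature,
        # so always pass two trailing None (nullptr) slots.
        return arg_tys + "OO", (*args, None, None)
    # CUDA: a scratch slot exists only when the corresponding flag is set.
    for has_scratch in (has_global_scratch, has_profile_scratch):
        if has_scratch:
            arg_tys += "O"
            args = (*args, None)
    return arg_tys, args
```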
@@ -1593,9 +1593,8 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
return None

def check_can_launch() -> StaticallyLaunchedCudaKernel:
if triton_meta.get("device_type") != "cuda":
# Only cuda kernels
raise CannotStaticallyLaunchKernel("Non-cuda device")
if triton_meta.get("device_type") not in ("cuda", "hip"):
raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device")

if torch._inductor.config.cpp_wrapper:
# If we're running with cpp wrapper, it doesn't
@@ -1621,10 +1620,11 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
"static launch does not support launch attributes"
)

binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin"
cubin_location = os.path.join(
triton_cache_dir(triton_meta.get("device", 0)),
triton_hash_to_path_key(kernel.hash),
f"{kernel.src.fn.__name__}.cubin",
f"{kernel.src.fn.__name__}.{binary_ext}",
)

if not os.path.exists(cubin_location):
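The on-disk location checked above is just the Triton cache directory plus the hash key and a backend-specific extension. A hedged sketch, with `cache_dir` and `path_key` standing in for the values returned by `triton_cache_dir(...)` and `triton_hash_to_path_key(kernel.hash)`; the example values in the comment are hypothetical:

```python
import os

def expected_binary_location(cache_dir: str, path_key: str, fn_name: str, device_type: str) -> str:
    # Same composition as cubin_location above: hsaco on ROCm, cubin on CUDA.
    ext = "hsaco" if device_type == "hip" else "cubin"
    return os.path.join(cache_dir, path_key, f"{fn_name}.{ext}")

# e.g. expected_binary_location("/tmp/triton_cache", "abc123", "triton_poi_fused_add_0", "hip")
#      -> "/tmp/triton_cache/abc123/triton_poi_fused_add_0.hsaco"
```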
@@ -2159,7 +2159,7 @@ PyObject* initModule() {
#ifdef USE_CUDA
torch::cuda::initModule(module);
#endif
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
ASSERT_TRUE(StaticCudaLauncher_init(module));
#endif
#ifdef USE_MPS

@@ -1,7 +1,4 @@
#if defined(USE_CUDA) && !defined(USE_ROCM)
// We disable this file from being hipified because there are CUDA drivers hip
// has not implemented yet. Also, we're passing in a cubin file directly, so it
// would take more work to support ROCM anyway.
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/utils/pythoncapi_compat.h>

#include <ATen/Context.h>
@@ -16,6 +13,11 @@
#include <torch/csrc/utils/python_numbers.h>
#include <filesystem>
#include <optional>

#if defined(USE_ROCM)
#include <hip/hip_runtime_api.h>
#endif

/**
Implements a static launcher for triton compiled CUDA kernels.
Given a path to a cubin file, a function name, and some metadata,
@@ -56,8 +58,15 @@ const at::cuda::NVRTC& nvrtc() {

CUdeviceptr getPointer(PyObject* obj) {
CUdeviceptr data_ptr = 0;

if (THPUtils_checkLong(obj)) {
data_ptr = THPUtils_unpackUInt64(obj);

#if defined(USE_ROCM)
data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(obj));
#else
data_ptr = THPUtils_unpackUInt64(obj);
#endif

return data_ptr;
}
if (obj == Py_None) {
@@ -73,13 +82,25 @@ CUdeviceptr getPointer(PyObject* obj) {
TORCH_CHECK(
THPUtils_checkLong(ret),
"data_ptr method of Pointer object must return 64-bit int");

#if defined(USE_ROCM)
data_ptr = reinterpret_cast<hipDeviceptr_t>(THPUtils_unpackUInt64(ret));
#else
data_ptr = THPUtils_unpackUInt64(ret);
#endif

if (!data_ptr)
return data_ptr;

CUdeviceptr dev_ptr = 0;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute(
&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute(
&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr));
#endif

return dev_ptr;
}
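For context on `getPointer` above: tensor arguments arrive from Python as 64-bit integers (re-interpreted as `hipDeviceptr_t` under ROCm), and `None` maps to a null device pointer. An illustrative snippet of what the Python side hands over, assuming a CUDA/ROCm device is available:

```python
import torch

t = torch.empty(16, device="cuda")  # ROCm builds also use the "cuda" device string
ptr_as_int = t.data_ptr()           # plain Python int, unpacked via THPUtils_unpackUInt64
maybe_scratch = None                # arrives as Py_None and becomes a nullptr argument slot
```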
@@ -98,6 +119,18 @@ CUfunction loadKernel(
}
CUmodule mod = nullptr;
CUfunction func = nullptr;

#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
hipModuleGetFunction(&func, mod, funcName.c_str()));
int shared_optin = 0;
AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
&shared_optin,
hipDeviceAttributeSharedMemPerBlockOptin,
device));

#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
@@ -106,6 +139,9 @@ CUfunction loadKernel(
&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));

#endif

// Shared memory logic from triton/third-party/nvidia/backend/driver.c
// If we're using more than 48 KB of shared memory, and we have
// access to more than 48 KB of shared memory on the device,
@@ -124,6 +160,23 @@ CUfunction loadKernel(
" Reducing block sizes or `num_stages` may help.");
if (sharedMemBytes > SHARED_MEM_STATIC_MAX &&
shared_optin > SHARED_MEM_STATIC_MAX) {

#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(
hipFuncSetCacheConfig(func, hipFuncCachePreferShared));
int shared_total = 0, shared_static = 0;
AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute(
&shared_total,
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor,
device));
AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
&shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));

#else
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total = 0, shared_static = 0;
@@ -136,7 +189,8 @@ CUfunction loadKernel(
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncSetAttribute(
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
shared_optin - shared_static));
#endif
}
return func;
}
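The opt-in condition in the shared-memory logic above can be restated compactly. A minimal sketch, assuming the conventional 48 KB static limit that `SHARED_MEM_STATIC_MAX` denotes in the surrounding code:

```python
SHARED_MEM_STATIC_MAX = 48 * 1024  # 48 KB; assumption matching the comment above

def needs_dynamic_shared_mem_opt_in(shared_mem_bytes: int, shared_optin: int) -> bool:
    # Raise the max-dynamic-shared-size attribute (CUDA or its HIP equivalent)
    # only when the kernel asks for more than 48 KB and the device actually
    # allows more than 48 KB of opt-in shared memory per block.
    return shared_mem_bytes > SHARED_MEM_STATIC_MAX and shared_optin > SHARED_MEM_STATIC_MAX
```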
@@ -152,6 +206,27 @@ inline void launchKernel(
cudaStream_t stream) {
// cta_args is always 1 for inductor generated triton kernels,
// so we don't need to figure out grid dimension here
#if defined(USE_ROCM)
int device = 0;
AT_CUDA_DRIVER_CHECK(hipGetDevice(&device));
int warp_size = 0;
AT_CUDA_DRIVER_CHECK(
hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device));

AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel(
func,
gridX,
gridY,
gridZ,
warp_size * numWarps, // blockDim.x
1, // blockDim.y
1, // blockDim.z
sharedMemBytes,
stream,
args,
nullptr));

#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
func,
gridX,
@@ -164,6 +239,7 @@ inline void launchKernel(
stream,
args,
nullptr));
#endif
}

template <typename FINAL, typename F>
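On the ROCm path above, block dimensions are passed explicitly to `hipModuleLaunchKernel`, so `blockDim.x` is derived from the queried wavefront size rather than a hardcoded 32. A minimal sketch of that mapping; the helper is illustrative:

```python
def rocm_block_dims(num_warps: int, warp_size: int) -> tuple[int, int, int]:
    # warp_size comes from hipDeviceAttributeWarpSize in the code above
    # (commonly 64 on AMD compute GPUs; stated here as an assumption).
    return (warp_size * num_warps, 1, 1)
```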
@@ -269,11 +345,21 @@ PyObject* load_kernel(PyObject* self, PyObject* args) {
CUdevice device = static_cast<CUdevice>(device_ptr); // NOLINT
CUfunction func = nullptr;
func = loadKernel(filePath, funcName, sharedMemBytes, device);
// Taken from triton/nvidia/backend/driver.c

#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(
hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute(
&n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));

#else
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func));

#endif
n_spills /= 4;
// Return a tuple of CUFunction, n_regs, n_spills
return Py_BuildValue(
@@ -299,7 +385,6 @@ PyObject* launch_kernel_inner(
std::array<uint64_t, MAX_ARGS> argStorage = {};
std::array<void*, MAX_ARGS> kernelArgs = {};
parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data());

launchKernel(
func,
gridX,
@@ -386,13 +471,26 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
CUcontext pctx = nullptr;
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
#endif

if (!pctx) {
// Ensure device context exists
CUdevice device = 0;
#if defined(USE_ROCM)
AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0));
AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device));
AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx));
#else
AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0));
AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device));
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx));

#endif

}
CUfunction func = reinterpret_cast<CUfunction>(func_ptr); // NOLINT
cudaStream_t cudaStream = reinterpret_cast<cudaStream_t>(stream); // NOLINT
@@ -1,5 +1,5 @@
#pragma once
#if defined(USE_CUDA) && !defined(USE_ROCM)
#if defined(USE_CUDA)
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>
#include <torch/csrc/python_headers.h>