[AOTI XPU] Enable Cpp wrapper for Intel GPU. (#135318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135318
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire
Author: xinan.lin
Date: 2024-11-25 21:41:28 -08:00
Committed by: PyTorch MergeBot
Parent: c418a9ac75
Commit: 4742080ed9
19 changed files with 432 additions and 49 deletions
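For context, this change wires Intel GPU ("xpu") support into Inductor's C++ wrapper codegen and the AOTI runtime. A minimal sketch of how the new path might be exercised, assuming an XPU-enabled build with an "xpu" device available:

    import torch

    def f(x, y):
        # any compilable function works; matmul + relu is just an example
        return (x @ y).relu()

    x = torch.randn(64, 64, device="xpu")
    y = torch.randn(64, 64, device="xpu")

    # "cpp_wrapper" switches Inductor from the default Python wrapper to the
    # C++ wrapper; on XPU this now routes through CppWrapperGpu.
    compiled = torch.compile(f, options={"cpp_wrapper": True})
    out = compiled(x, y)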

View File

@ -263,6 +263,7 @@ exclude_patterns = [
'torch/csrc/jit/**/*',
'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
'torch/csrc/utils/pythoncapi_compat.h',
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
]
init_command = [
'python3',

View File

@ -792,6 +792,7 @@ libtorch_python_xpu_sources = [
"torch/csrc/xpu/Event.cpp",
"torch/csrc/xpu/Module.cpp",
"torch/csrc/xpu/Stream.cpp",
"torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
]
libtorch_python_core_sources = [

View File

@ -1116,6 +1116,10 @@ if(USE_XPU)
# Set cached ${ATen_XPU_INCLUDE_DIRS} to torch
include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS})
message(INFO "Install ${TORCH_XPU_OPS_DIR}/src/ATen/xpu to ${TORCH_INSTALL_INCLUDE_DIR}/ATen/xpu")
install(DIRECTORY "${TORCH_XPU_OPS_DIR}/src/ATen/xpu"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/ATen/
FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")
endif()
endif()

View File

@ -1289,6 +1289,7 @@ def main():
"include/torch/csrc/inductor/aoti_torch/*.h",
"include/torch/csrc/inductor/aoti_torch/c/*.h",
"include/torch/csrc/inductor/aoti_torch/generated/*.h",
"include/torch/csrc/inductor/aoti_torch/generated/extend/*.h",
"include/torch/csrc/jit/*.h",
"include/torch/csrc/jit/backends/*.h",
"include/torch/csrc/jit/generated/*.h",

View File

@ -7,11 +7,12 @@ from typing import NamedTuple
import torch
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import is_gpu
from torch.testing._internal.common_device_type import (
get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
try:
@ -38,29 +39,40 @@ except unittest.SkipTest:
raise
_desired_test_bases = get_desired_device_type_test_bases()
RUN_CUDA = (
HAS_CUDA
and any(getattr(x, "device_type", "") == "cuda" for x in _desired_test_bases)
_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
RUN_GPU = (
HAS_GPU
and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
and not TEST_WITH_ASAN
)
class CudaWrapperTemplate:
class GpuWrapperTemplate:
pass
class TestCudaWrapper(InductorTestCase):
device = "cuda"
class TestGpuWrapper(InductorTestCase):
device = GPU_TYPE
class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
device = "cuda"
class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
device = GPU_TYPE
test_failures_cuda_wrapper = {
test_failures_gpu_wrapper = {
"test_mm_plus_mm2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
("cuda_wrapper",), is_skip=True
("gpu_wrapper",), is_skip=True
),
"test_randint_xpu": test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=False),
"test_randint_xpu_dynamic_shapes": test_torchinductor.TestFailure(
("gpu_wrapper",), is_skip=False
),
# ATen ops: scaled_dot_product_efficient_attention not implemented on XPU.
"test_scaled_dot_product_efficient_attention_xpu": test_torchinductor.TestFailure(
("gpu_wrapper",), is_skip=False
),
"test_scaled_dot_product_efficient_attention_xpu_dynamic_shapes": test_torchinductor.TestFailure(
("gpu_wrapper",), is_skip=False
),
}
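For readers unfamiliar with the harness: each entry above is keyed by the generated test name and interpreted when the template tests are copied onto the concrete test classes further down. A small illustration (semantics inferred from how copy_tests consumes these entries):

    # The suffix tuple must match the suffix later passed to copy_tests()
    # ("gpu_wrapper"); is_skip=False marks the cloned test as an expected
    # failure (xfail), while is_skip=True skips it outright.
    test_failures_gpu_wrapper["test_randint_xpu"] = test_torchinductor.TestFailure(
        ("gpu_wrapper",), is_skip=False
    )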
@ -114,20 +126,34 @@ def make_test_case(
fn.__dict__ = copy.deepcopy(func.__dict__)
if condition:
setattr(
CudaWrapperTemplate,
GpuWrapperTemplate,
test_name,
fn,
)
if RUN_CUDA:
if RUN_GPU:
class BaseTest(NamedTuple):
name: str
device: str = "cuda"
device: str = GPU_TYPE
tests: InductorTestCase = test_torchinductor.GPUTests()
check_code: bool = True
# XPU Not implemented yet
XPU_BASE_TEST_SKIP = [
"test_foreach_cpp_wrapper",
"test_enable_dynamic_shapes_cpp_wrapper",
"test_dynamic_shapes_persistent_reduction_mixed_x_dim",
"test_cat_slice_cat",
"test_mm_plus_mm2",
"test_mm_plus_mm3",
"test_addmm",
"test_linear_relu",
"test_fft_real_input",
"test_fft_real_input_real_output",
]
# Maintain two separate test lists for cuda and cpp for now
for item in [
BaseTest("test_add_complex"),
@ -236,40 +262,41 @@ if RUN_CUDA:
tests=test_select_algorithm.TestSelectAlgorithm(),
),
]:
if item.device == "xpu" and item.name in XPU_BASE_TEST_SKIP:
continue
make_test_case(item.name, item.device, item.tests, check_code=item.check_code)
from torch._inductor.utils import is_big_gpu
if is_big_gpu(0):
if GPU_TYPE == "cuda" and is_big_gpu(0):
skip_list = ["test_addmm", "test_linear_relu"]
# need to skip instead of omit, otherwise fbcode ci can be flaky
for test_name in skip_list:
test_failures_cuda_wrapper[
test_failures_gpu_wrapper[
f"{test_name}_cuda"
] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)
test_failures_cuda_wrapper[
f"{test_name}_cuda_dynamic_shapes"
] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)
] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)
test_failures_gpu_wrapper[
f"{test_name}_gpu_dynamic_shapes"
] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)
test_torchinductor.copy_tests(
CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper
GpuWrapperTemplate, TestGpuWrapper, "gpu_wrapper", test_failures_gpu_wrapper
)
DynamicShapesCudaWrapperTemplate = (
test_torchinductor_dynamic_shapes.make_dynamic_cls(CudaWrapperTemplate)
DynamicShapesGpuWrapperTemplate = (
test_torchinductor_dynamic_shapes.make_dynamic_cls(GpuWrapperTemplate)
)
test_torchinductor.copy_tests(
DynamicShapesCudaWrapperTemplate,
DynamicShapesCudaWrapperCudaTests,
"cuda_wrapper",
test_failures_cuda_wrapper,
DynamicShapesGpuWrapperTemplate,
DynamicShapesGpuWrapperGpuTests,
"gpu_wrapper",
test_failures_gpu_wrapper,
xfail_prop="_expected_failure_dynamic_wrapper",
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
print(f"FS: run_cuda {RUN_CUDA}")
if RUN_CUDA:
if RUN_GPU:
run_tests(needs="filelock")

View File

@ -3,7 +3,6 @@
import sys
import unittest
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.common_utils import (
IS_CI,
IS_WINDOWS,
@ -71,7 +70,6 @@ class TestMemoryPlanning(TestCase):
)
self.assertTrue(same(f(*args), result))
@expectedFailureXPU
def test_cpp_wrapper(self):
f, args = self._generate(device=GPU_TYPE)
compiled = torch.compile(f, dynamic=True)

View File

@ -3265,7 +3265,6 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
self.assertEqual(gm(x, x), x + x)
@skipIfXpu
@requires_gpu
@patch.object(torch._inductor.config, "cpp_wrapper", True)
@patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)

View File

@ -382,6 +382,7 @@ def init_backend_registration():
"xpu",
TritonScheduling,
PythonWrapperCodegen,
CppWrapperGpu,
)
private_backend = torch._C._get_privateuse1_backend_name()
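The hunk above extends the existing XPU registration so the backend gains a C++ wrapper codegen alongside its Python one. A paraphrased sketch of the full call (imports for the codegen classes are not shown in this hunk and are assumed):

    from torch._inductor.codegen.common import register_backend_for_device

    register_backend_for_device(
        "xpu",
        TritonScheduling,      # Triton kernel scheduling/codegen
        PythonWrapperCodegen,  # default Python wrapper codegen
        CppWrapperGpu,         # new: C++ wrapper codegen used when cpp_wrapper=True
    )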

View File

@ -81,6 +81,7 @@ DTYPE_TO_ATEN = {
DEVICE_TO_ATEN = {
"cpu": "at::kCPU",
"cuda": "at::kCUDA",
"xpu": "at::kXPU",
}
LAYOUT_TO_ATEN = {

View File

@ -198,11 +198,16 @@ class CppWrapperCpu(PythonWrapperCodegen):
}}
"""
)
extend_aoti_path = (
extend_aoti_c_shim_include = (
f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
)
if os.path.exists(extend_aoti_path):
self.header.splice(f"#include <{extend_aoti_path}>")
extend_aoti_c_shim_path = os.path.join(
os.path.dirname(torch.__file__),
"include",
extend_aoti_c_shim_include,
)
if os.path.exists(extend_aoti_c_shim_path):
self.header.splice(f"#include <{extend_aoti_c_shim_include}>")
enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
"linux",

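The rewritten check resolves the extended C shim header against the installed torch/include tree instead of a path relative to the current working directory, while the #include itself stays relative so the compiler's include path resolves it. A worked sketch of the same logic (device value assumed to be "xpu" for illustration):

    import os
    import torch

    include = "torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.h"
    abs_path = os.path.join(os.path.dirname(torch.__file__), "include", include)
    if os.path.exists(abs_path):          # check the installed copy...
        print(f"#include <{include}>")    # ...but emit the relative include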
View File

@ -782,6 +782,7 @@ class PythonWrapperCodegen(CodeGen):
async_compile = AsyncCompile()
generate_example_value = AlgorithmSelectorCache.generate_example_value
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
"""
)

View File

@ -16,7 +16,7 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
return f"torch.xpu._DeviceGuard({device_idx})"
def cpp_device_guard(self):
return "at::xpu::XPUGuard"
return "at::DeviceGuard"
def cpp_aoti_device_guard(self):
return "AOTIXpuGuard"
@ -30,5 +30,50 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
def cpp_getStreamFromExternal(self):
return "at::xpu::getStreamFromExternal"
def kernel_header(self):
source_codes = """
#include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
"""
return source_codes
def kernel_driver(self):
source_codes = """
namespace {
struct Grid {
Grid(uint32_t x, uint32_t y, uint32_t z)
: grid_x(x), grid_y(y), grid_z(z) {}
uint32_t grid_x;
uint32_t grid_y;
uint32_t grid_z;
bool is_non_zero() {
return grid_x > 0 && grid_y > 0 && grid_z > 0;
}
};
} // anonymous namespace
"""
return source_codes
def abi_compatible_header(self):
return """
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
#include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
"""
def cpp_stream_type(self):
return "sycl::queue*"
def aoti_get_stream(self):
return "aoti_torch_get_current_xpu_stream"
def cpp_kernel_type(self):
return "std::unique_ptr<sycl::kernel>"
def cpp_device_ptr(self):
return "void *"
register_device_op_overrides("xpu", XPUDeviceOpOverrides())
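These overrides are looked up per device by the wrapper codegen to decide which C++ types and helper calls to emit. A small illustration of how they might be consumed (lookup helper name assumed):

    from torch._inductor.codegen.common import get_device_op_overrides

    xpu_ops = get_device_op_overrides("xpu")
    xpu_ops.cpp_kernel_type()   # "std::unique_ptr<sycl::kernel>"
    xpu_ops.cpp_stream_type()   # "sycl::queue*"
    xpu_ops.aoti_get_stream()   # "aoti_torch_get_current_xpu_stream"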

View File

@ -405,8 +405,8 @@ class BuildOptionsBase:
self._passthough_args = _remove_duplication_in_list(self._passthough_args)
def _finalize_options(self) -> None:
self._process_compile_only_options
self._remove_duplicate_options
self._process_compile_only_options()
self._remove_duplicate_options()
def get_compiler(self) -> str:
return self._compiler
@ -530,7 +530,7 @@ def _get_ffast_math_flags() -> List[str]:
return flags
def _get_optimization_cflags() -> List[str]:
def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
if _IS_WINDOWS:
return ["O2"]
else:
@ -545,7 +545,7 @@ def _get_optimization_cflags() -> List[str]:
if sys.platform != "darwin":
# on macos, unknown argument: '-fno-tree-loop-vectorize'
if is_gcc():
if _is_gcc(cpp_compiler):
cflags.append("fno-tree-loop-vectorize")
# https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1
# `-march=native` is unrecognized option on M1
@ -593,7 +593,7 @@ def get_cpp_options(
cflags = (
_get_shared_cflag(compile_only)
+ _get_optimization_cflags()
+ _get_optimization_cflags(cpp_compiler)
+ _get_warning_all_cflag(warning_all)
+ _get_cpp_std_cflag()
+ _get_os_related_cpp_cflags(cpp_compiler)
@ -629,9 +629,10 @@ class CppOptions(BuildOptionsBase):
warning_all: bool = True,
extra_flags: Sequence[str] = (),
use_absolute_path: bool = False,
compiler: str = "",
) -> None:
super().__init__()
self._compiler = get_cpp_compiler()
self._compiler = compiler if compiler else get_cpp_compiler()
self._use_absolute_path = use_absolute_path
self._compile_only = compile_only
@ -1116,12 +1117,14 @@ class CppTorchOptions(CppOptions):
use_mmap_weights: bool = False,
shared: bool = True,
extra_flags: Sequence[str] = (),
compiler: str = "",
) -> None:
super().__init__(
compile_only=compile_only,
warning_all=warning_all,
extra_flags=extra_flags,
use_absolute_path=use_absolute_path,
compiler=compiler,
)
self._aot_mode = aot_mode
@ -1205,7 +1208,6 @@ def get_cpp_torch_device_options(
include_dirs = cpp_extension.include_paths(device_type)
libraries_dirs = cpp_extension.library_paths(device_type)
if device_type == "cuda":
definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA")
@ -1223,7 +1225,12 @@ def get_cpp_torch_device_options(
if device_type == "xpu":
definations.append(" USE_XPU")
cflags += ["fsycl"]
# Add "-Wno-unsupported-floating-point-opt" here to
# suppress compiler warning:
# "warning: overriding currently unsupported use of floating point
# exceptions on this target [-Wunsupported-floating-point-opt]".
# This is needed because the compiler does not yet support some features.
cflags += ["fsycl", "Wno-unsupported-floating-point-opt"]
libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"]
if aot_mode:
@ -1275,6 +1282,12 @@ class CppTorchDeviceOptions(CppTorchOptions):
shared: bool = True,
extra_flags: Sequence[str] = (),
) -> None:
if device_type == "xpu":
from torch.utils.cpp_extension import _join_sycl_home
compiler = _join_sycl_home("bin", "icpx")
else:
compiler = ""
super().__init__(
vec_isa=vec_isa,
include_pytorch=include_pytorch,
@ -1283,11 +1296,8 @@ class CppTorchDeviceOptions(CppTorchOptions):
use_absolute_path=use_absolute_path,
use_mmap_weights=use_mmap_weights,
extra_flags=extra_flags,
compiler=compiler,
)
if device_type == "xpu":
from torch.utils.cpp_extension import _join_sycl_home
self._compiler = _join_sycl_home("bin", "icpx")
device_definations: List[str] = []
device_include_dirs: List[str] = []
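Taken together, these cpp_builder changes let the device options choose the compiler: for XPU the SYCL toolchain's icpx is selected up front, and -fsycl plus the warning suppression are added to the flags. A rough sketch of how this is exercised, assuming a SYCL toolchain is installed (constructor arguments abbreviated):

    from torch._inductor.cpp_builder import CppBuilder, CppTorchDeviceOptions

    options = CppTorchDeviceOptions(device_type="xpu", aot_mode=True)
    builder = CppBuilder(name="model", sources=["model.cpp"], BuildOption=options)
    # Expect something like: icpx ... -fsycl ... -lc10_xpu -lsycl -lze_loader -ltorch_xpu
    print(builder.get_command_line())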

View File

@ -0,0 +1,178 @@
// NOLINT
#pragma once
#ifdef USE_XPU
#include <c10/xpu/XPUFunctions.h>
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <fstream>
#include <iostream>
#include <string>
#define ZE_CHECK(status) \
{ \
if (status != ZE_RESULT_SUCCESS) { \
std::stringstream ss; \
ss << "L0 runtime error: " << std::hex << std::uppercase << status; \
throw std::runtime_error(ss.str()); \
} \
}
static ze_module_handle_t create_module(
ze_context_handle_t context,
ze_device_handle_t device,
const uint8_t* binary_ptr,
size_t binary_size) {
const char* build_flags = "";
const ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV;
ze_module_desc_t module_description = {};
module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
module_description.format = format;
module_description.inputSize = binary_size;
module_description.pInputModule = (uint8_t*)binary_ptr;
module_description.pBuildFlags = build_flags;
ze_module_build_log_handle_t buildlog = nullptr;
ze_module_handle_t module = nullptr;
auto context_initial = context;
auto device_initial = device;
auto error_no = ZE_RESULT_SUCCESS;
error_no =
zeModuleCreate(context, device, &module_description, &module, &buildlog);
if (error_no != ZE_RESULT_SUCCESS) {
size_t szLog = 0;
ZE_CHECK(zeModuleBuildLogGetString(buildlog, &szLog, nullptr));
char* strLog = (char*)malloc(szLog);
ZE_CHECK(zeModuleBuildLogGetString(buildlog, &szLog, strLog));
std::cerr << "L0 build module failed. Log: " << strLog << std::endl;
free(strLog);
}
if (buildlog) {
ZE_CHECK(zeModuleBuildLogDestroy(buildlog));
}
ZE_CHECK(error_no);
return module;
}
ze_kernel_handle_t create_function(
ze_module_handle_t module,
ze_kernel_flags_t flag,
const std::string& func_name) {
ze_kernel_handle_t kernel = nullptr;
ze_kernel_desc_t kernel_description = {};
kernel_description.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
kernel_description.pNext = nullptr;
kernel_description.flags = flag;
kernel_description.pKernelName = func_name.c_str();
assert(module);
ZE_CHECK(zeKernelCreate(module, &kernel_description, &kernel));
return kernel;
}
static ze_module_handle_t loadModule(std::string& spv_path) {
sycl::device& sycl_device =
c10::xpu::get_raw_device(c10::xpu::current_device());
auto sycl_context =
sycl_device.get_platform().ext_oneapi_get_default_context();
auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
auto l0_context =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
std::ifstream IFS(spv_path.c_str(), std::ios::binary);
std::ostringstream OSS;
OSS << IFS.rdbuf();
std::string data(OSS.str());
return create_module(
l0_context,
l0_device,
reinterpret_cast<const uint8_t*>(data.c_str()),
data.size());
}
static std::unique_ptr<sycl::kernel> getKernel(
ze_module_handle_t l0_module,
const char* kernel_name) {
assert(l0_module);
assert(kernel_name);
auto l0_kernel =
create_function(l0_module, ZE_KERNEL_FLAG_FORCE_RESIDENCY, kernel_name);
sycl::device& sycl_device =
c10::xpu::get_raw_device(c10::xpu::current_device());
auto sycl_context =
sycl_device.get_platform().ext_oneapi_get_default_context();
auto mod = sycl::make_kernel_bundle<
sycl::backend::ext_oneapi_level_zero,
sycl::bundle_state::executable>(
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
sycl_context);
auto fun = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
{mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
sycl_context);
return std::make_unique<sycl::kernel>(fun);
}
[[maybe_unused]] static std::unique_ptr<sycl::kernel> loadKernel(
std::string filePath,
const std::string& funcName,
uint32_t sharedMemBytes,
const std::optional<std::string>& binDir = std::nullopt) {
if (binDir) {
std::filesystem::path p1{*binDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
auto mod = loadModule(filePath);
return getKernel(mod, funcName.c_str());
}
[[maybe_unused]] static void launchKernel(
std::unique_ptr<sycl::kernel>& kernel_ptr,
uint32_t grid_x,
uint32_t grid_y,
uint32_t grid_z,
uint32_t num_warps,
uint32_t shared_memory,
void** params,
sycl::queue* queue_ptr) {
std::string kernel_name =
kernel_ptr->get_info<sycl::info::kernel::function_name>();
// Currently threads_per_warp is hard-coded to 32 throughout the
// torch.compile-to-Triton stack.
int threads_per_warp = 32;
uint32_t num_params = kernel_ptr->get_info<sycl::info::kernel::num_args>();
size_t global_range_x = grid_x * threads_per_warp * num_warps;
size_t global_range_y = grid_y;
size_t global_range_z = grid_z;
size_t local_range_x = num_warps * threads_per_warp;
size_t local_range_y = 1;
size_t local_range_z = 1;
sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
sycl::nd_range<3> parallel_work_size(global_range, local_range);
if (shared_memory) {
// num_params from sycl info = user-provided args + shared_memory_buffer
num_params -= 1;
}
// Submit the imported kernel.
auto cgf = [&](sycl::handler& cgh) {
for (uint32_t i = 0; i < num_params; ++i) {
cgh.set_arg(i, *(static_cast<void**>(params[i])));
}
if (shared_memory > 0) {
constexpr int dimensions = 1;
using share_mem_t = sycl::local_accessor<int8_t, dimensions>;
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
cgh.set_arg(num_params, local_buffer);
cgh.parallel_for(parallel_work_size, *kernel_ptr);
} else {
cgh.parallel_for(parallel_work_size, *kernel_ptr);
}
};
auto event = queue_ptr->submit(cgf);
}
#endif
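To make the launch geometry above concrete: launchKernel maps a Triton-style (grid, num_warps) configuration onto a 3-D SYCL nd_range whose x dimension carries both the grid extent and the work-group size. Illustrative arithmetic only (numbers chosen arbitrarily):

    grid_x, grid_y, grid_z = 128, 4, 1
    num_warps, threads_per_warp = 8, 32   # threads_per_warp is fixed at 32

    local_range  = (1, 1, num_warps * threads_per_warp)        # (1, 1, 256)
    global_range = (grid_z, grid_y, grid_x * local_range[2])   # (1, 4, 32768)
    # Each work-group of 256 work-items corresponds to one Triton program instance.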

View File

@ -0,0 +1,36 @@
#pragma once
#ifdef USE_XPU
// WARNING: Be careful when adding new includes here. This header will be used
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim_xpu.h>
namespace torch::aot_inductor {
inline void delete_xpu_guard(void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_delete_xpu_guard(reinterpret_cast<XPUGuardHandle>(ptr)));
}
class AOTIXpuGuard {
public:
AOTIXpuGuard(int32_t device_index) : guard_(nullptr, delete_xpu_guard) {
XPUGuardHandle ptr = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_create_xpu_guard(device_index, &ptr));
guard_.reset(ptr);
}
void set_index(int32_t device_index) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_xpu_guard_set_index(guard_.get(), device_index));
}
private:
std::unique_ptr<XPUGuardOpaque, DeleterFnPtr> guard_;
};
} // namespace torch::aot_inductor
#endif // USE_XPU

View File

@ -96,6 +96,7 @@ using AOTITorchError = int32_t;
// desired for perf reasons.)
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_xpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();

View File

@ -0,0 +1,36 @@
#ifndef AOTI_TORCH_SHIM_XPU
#define AOTI_TORCH_SHIM_XPU
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef USE_XPU
#ifdef __cplusplus
extern "C" {
#endif
struct XPUGuardOpaque;
using XPUGuardHandle = XPUGuardOpaque*;
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_xpu_guard(
int32_t device_index,
XPUGuardHandle* ret_guard // returns new reference
);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_xpu_guard(XPUGuardHandle guard);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_xpu_guard_set_index(XPUGuardHandle guard, int32_t device_index);
struct XPUStreamGuardOpaque;
using XPUStreamGuardHandle = XPUStreamGuardOpaque*;
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // USE_XPU
#endif // AOTI_TORCH_SHIM_XPU

View File

@ -119,6 +119,7 @@ const int AOTI_TORCH_MAX_NUMEL_TO_PRINT = 64;
AOTI_TORCH_DEVICE_TYPE_IMPL(cpu, CPU)
AOTI_TORCH_DEVICE_TYPE_IMPL(cuda, CUDA)
AOTI_TORCH_DEVICE_TYPE_IMPL(xpu, XPU)
AOTI_TORCH_DEVICE_TYPE_IMPL(privateuse1, PrivateUse1)
#undef AOTI_TORCH_DEVICE_TYPE_IMPL

View File

@ -0,0 +1,37 @@
#include <torch/csrc/inductor/aoti_torch/c/shim_xpu.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/DeviceType.h>
#include <c10/core/StreamGuard.h>
#include <c10/xpu/XPUStream.h>
AOTITorchError aoti_torch_create_xpu_guard(
int32_t device_index,
XPUGuardHandle* ret_guard // returns new reference
) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::DeviceGuard* guard =
new at::DeviceGuard(at::Device(at::DeviceType::XPU, device_index));
*ret_guard = reinterpret_cast<XPUGuardHandle>(guard);
});
}
AOTITorchError aoti_torch_delete_xpu_guard(XPUGuardHandle guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<at::DeviceGuard*>(guard); });
}
AOTITorchError aoti_torch_xpu_guard_set_index(
XPUGuardHandle guard,
int32_t device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ reinterpret_cast<at::DeviceGuard*>(guard)->set_index(device_index); });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_stream = &(at::xpu::getCurrentXPUStream(device_index).queue()); });
}