[AOTI XPU] Enable Cpp wrapper for Intel GPU. (#135318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135318 Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: c418a9ac75
Commit: 4742080ed9
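For context, a minimal sketch (not part of the diff) of what this commit enables: running Inductor's C++ wrapper codegen on an Intel GPU. An XPU-capable PyTorch build and the "xpu" device are assumptions of the sketch.

    import torch
    import torch._inductor.config as inductor_config

    def f(x, y):
        return (x + y).relu()

    x = torch.randn(16, 16, device="xpu")
    y = torch.randn(16, 16, device="xpu")

    # cpp_wrapper=True makes Inductor emit a C++ (rather than Python) wrapper.
    with inductor_config.patch(cpp_wrapper=True):
        compiled = torch.compile(f)
        print(compiled(x, y).shape)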
@@ -263,6 +263,7 @@ exclude_patterns = [
     'torch/csrc/jit/**/*',
     'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
     'torch/csrc/utils/pythoncapi_compat.h',
+    'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
 ]
 init_command = [
     'python3',
@@ -792,6 +792,7 @@ libtorch_python_xpu_sources = [
     "torch/csrc/xpu/Event.cpp",
     "torch/csrc/xpu/Module.cpp",
     "torch/csrc/xpu/Stream.cpp",
+    "torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
 ]

 libtorch_python_core_sources = [
@@ -1116,6 +1116,10 @@ if(USE_XPU)

     # Set cached ${ATen_XPU_INCLUDE_DIRS} to torch
     include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS})
+    message(INFO "Install ${TORCH_XPU_OPS_DIR}/src/ATen/xpu to ${TORCH_INSTALL_INCLUDE_DIR}/ATen/xpu")
+    install(DIRECTORY "${TORCH_XPU_OPS_DIR}/src/ATen/xpu"
+            DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/ATen/
+            FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")

   endif()
 endif()
setup.py (+1)
@@ -1289,6 +1289,7 @@ def main():
                 "include/torch/csrc/inductor/aoti_torch/*.h",
                 "include/torch/csrc/inductor/aoti_torch/c/*.h",
                 "include/torch/csrc/inductor/aoti_torch/generated/*.h",
+                "include/torch/csrc/inductor/aoti_torch/generated/extend/*.h",
                 "include/torch/csrc/jit/*.h",
                 "include/torch/csrc/jit/backends/*.h",
                 "include/torch/csrc/jit/generated/*.h",
@@ -7,11 +7,12 @@ from typing import NamedTuple
 import torch
 from torch._inductor import config
 from torch._inductor.test_case import TestCase as InductorTestCase
+from torch._inductor.utils import is_gpu
 from torch.testing._internal.common_device_type import (
     get_desired_device_type_test_bases,
 )
 from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


 try:
@@ -38,29 +39,40 @@ except unittest.SkipTest:
     raise


-_desired_test_bases = get_desired_device_type_test_bases()
-RUN_CUDA = (
-    HAS_CUDA
-    and any(getattr(x, "device_type", "") == "cuda" for x in _desired_test_bases)
+_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
+RUN_GPU = (
+    HAS_GPU
+    and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
     and not TEST_WITH_ASAN
 )


-class CudaWrapperTemplate:
+class GpuWrapperTemplate:
     pass


-class TestCudaWrapper(InductorTestCase):
-    device = "cuda"
+class TestGpuWrapper(InductorTestCase):
+    device = GPU_TYPE


-class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
-    device = "cuda"
+class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
+    device = GPU_TYPE


-test_failures_cuda_wrapper = {
+test_failures_gpu_wrapper = {
     "test_mm_plus_mm2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
-        ("cuda_wrapper",), is_skip=True
+        ("gpu_wrapper",), is_skip=True
     ),
+    "test_randint_xpu": test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=False),
+    "test_randint_xpu_dynamic_shapes": test_torchinductor.TestFailure(
+        ("gpu_wrapper",), is_skip=False
+    ),
+    # ATen ops: scaled_dot_product_efficient_attention not implemented on XPU.
+    "test_scaled_dot_product_efficient_attention_xpu": test_torchinductor.TestFailure(
+        ("gpu_wrapper",), is_skip=False
+    ),
+    "test_scaled_dot_product_efficient_attention_xpu_dynamic_shapes": test_torchinductor.TestFailure(
+        ("gpu_wrapper",), is_skip=False
+    ),
 }
@@ -114,20 +126,34 @@ def make_test_case(
         fn.__dict__ = copy.deepcopy(func.__dict__)
         if condition:
             setattr(
-                CudaWrapperTemplate,
+                GpuWrapperTemplate,
                 test_name,
                 fn,
             )


-if RUN_CUDA:
+if RUN_GPU:

     class BaseTest(NamedTuple):
         name: str
-        device: str = "cuda"
+        device: str = GPU_TYPE
         tests: InductorTestCase = test_torchinductor.GPUTests()
         check_code: bool = True

+    # XPU Not implemented yet
+    XPU_BASE_TEST_SKIP = [
+        "test_foreach_cpp_wrapper",
+        "test_enable_dynamic_shapes_cpp_wrapper",
+        "test_dynamic_shapes_persistent_reduction_mixed_x_dim",
+        "test_cat_slice_cat",
+        "test_mm_plus_mm2",
+        "test_mm_plus_mm3",
+        "test_addmm",
+        "test_linear_relu",
+        "test_fft_real_input",
+        "test_fft_real_input_real_output",
+    ]
+
     # Maintain two separate test lists for cuda and cpp for now
     for item in [
         BaseTest("test_add_complex"),
@@ -236,40 +262,41 @@ if RUN_CUDA:
             tests=test_select_algorithm.TestSelectAlgorithm(),
         ),
     ]:
+        if item.device == "xpu" and item.name in XPU_BASE_TEST_SKIP:
+            continue
         make_test_case(item.name, item.device, item.tests, check_code=item.check_code)

     from torch._inductor.utils import is_big_gpu

-    if is_big_gpu(0):
+    if GPU_TYPE == "cuda" and is_big_gpu(0):
         skip_list = ["test_addmm", "test_linear_relu"]
         # need to skip instead of omit, otherwise fbcode ci can be flaky
         for test_name in skip_list:
-            test_failures_cuda_wrapper[
+            test_failures_gpu_wrapper[
                 f"{test_name}_cuda"
-            ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)
-            test_failures_cuda_wrapper[
-                f"{test_name}_cuda_dynamic_shapes"
-            ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)
+            ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)
+            test_failures_gpu_wrapper[
+                f"{test_name}_gpu_dynamic_shapes"
+            ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)

     test_torchinductor.copy_tests(
-        CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper
+        GpuWrapperTemplate, TestGpuWrapper, "gpu_wrapper", test_failures_gpu_wrapper
     )

-    DynamicShapesCudaWrapperTemplate = (
-        test_torchinductor_dynamic_shapes.make_dynamic_cls(CudaWrapperTemplate)
+    DynamicShapesGpuWrapperTemplate = (
+        test_torchinductor_dynamic_shapes.make_dynamic_cls(GpuWrapperTemplate)
     )

     test_torchinductor.copy_tests(
-        DynamicShapesCudaWrapperTemplate,
-        DynamicShapesCudaWrapperCudaTests,
-        "cuda_wrapper",
-        test_failures_cuda_wrapper,
+        DynamicShapesGpuWrapperTemplate,
+        DynamicShapesGpuWrapperGpuTests,
+        "gpu_wrapper",
+        test_failures_gpu_wrapper,
         xfail_prop="_expected_failure_dynamic_wrapper",
     )

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

-    print(f"FS: run_cuda {RUN_CUDA}")
-    if RUN_CUDA:
+    if RUN_GPU:
         run_tests(needs="filelock")
@@ -3,7 +3,6 @@
 import sys
 import unittest

-from torch.testing._internal.common_device_type import expectedFailureXPU
 from torch.testing._internal.common_utils import (
     IS_CI,
     IS_WINDOWS,
@@ -71,7 +70,6 @@ class TestMemoryPlanning(TestCase):
         )
         self.assertTrue(same(f(*args), result))

-    @expectedFailureXPU
     def test_cpp_wrapper(self):
         f, args = self._generate(device=GPU_TYPE)
         compiled = torch.compile(f, dynamic=True)
@@ -3265,7 +3265,6 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
         self.assertEqual(gm(x, x), x + x)

-    @skipIfXpu
     @requires_gpu
     @patch.object(torch._inductor.config, "cpp_wrapper", True)
     @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)
@@ -382,6 +382,7 @@ def init_backend_registration():
             "xpu",
             TritonScheduling,
             PythonWrapperCodegen,
+            CppWrapperGpu,
         )

     private_backend = torch._C._get_privateuse1_backend_name()
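What the registration buys: Inductor can now resolve a C++ wrapper codegen class for "xpu". A hedged sketch of querying that registry — the helpers exist in torch/_inductor/codegen/common.py around this PR, but treat the exact signatures as an assumption:

    from torch._inductor.codegen.common import (
        get_scheduling_for_device,
        get_wrapper_codegen_for_device,
    )

    print(get_scheduling_for_device("xpu"))                         # TritonScheduling
    print(get_wrapper_codegen_for_device("xpu", cpp_wrapper=True))  # CppWrapperGpu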
@@ -81,6 +81,7 @@ DTYPE_TO_ATEN = {
 DEVICE_TO_ATEN = {
     "cpu": "at::kCPU",
     "cuda": "at::kCUDA",
+    "xpu": "at::kXPU",
 }

 LAYOUT_TO_ATEN = {
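Illustration (simplified, not the actual codegen path) of how such a table is consumed: the C++ wrapper turns a torch device string into the corresponding ATen device-type token when emitting source.

    DEVICE_TO_ATEN = {"cpu": "at::kCPU", "cuda": "at::kCUDA", "xpu": "at::kXPU"}

    def device_expr(device: str) -> str:
        # e.g. "xpu" -> "at::kXPU", spliced into generated C++
        return DEVICE_TO_ATEN[device]

    assert device_expr("xpu") == "at::kXPU"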
@@ -198,11 +198,16 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 }}
                 """
             )
-        extend_aoti_path = (
+        extend_aoti_c_shim_include = (
             f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
         )
-        if os.path.exists(extend_aoti_path):
-            self.header.splice(f"#include <{extend_aoti_path}>")
+        extend_aoti_c_shim_path = os.path.join(
+            os.path.dirname(torch.__file__),
+            "include",
+            extend_aoti_c_shim_include,
+        )
+        if os.path.exists(extend_aoti_c_shim_path):
+            self.header.splice(f"#include <{extend_aoti_c_shim_include}>")

         enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
             "linux",
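Why the rewrite matters: the extended c_shim header lives under the installed torch tree, so the existence check must be rooted there rather than at the process working directory, while the emitted #include keeps the relative form. A standalone sketch of the same resolution:

    import os
    import torch

    include = "torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.h"
    path = os.path.join(os.path.dirname(torch.__file__), "include", include)
    # False unless an out-of-tree shim extension is installed
    print(path, os.path.exists(path))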
@@ -782,6 +782,7 @@ class PythonWrapperCodegen(CodeGen):
                 async_compile = AsyncCompile()
                 generate_example_value = AlgorithmSelectorCache.generate_example_value
                 empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+                empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
             """
         )
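The generated wrapper preamble now binds an XPU allocation helper next to the CUDA one. Roughly, the private `_empty_strided_xpu` does what the public API below does, minus Python-level overhead (sketch; assumes an XPU build):

    import torch

    # public-API equivalent of what the generated wrapper's helper allocates
    buf = torch.empty_strided((8, 8), (8, 1), dtype=torch.float32, device="xpu")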
@@ -16,7 +16,7 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
         return f"torch.xpu._DeviceGuard({device_idx})"

     def cpp_device_guard(self):
-        return "at::xpu::XPUGuard"
+        return "at::DeviceGuard"

     def cpp_aoti_device_guard(self):
         return "AOTIXpuGuard"
@@ -30,5 +30,50 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
     def cpp_getStreamFromExternal(self):
         return "at::xpu::getStreamFromExternal"

+    def kernel_header(self):
+        source_codes = """
+        #include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
+        """
+        return source_codes
+
+    def kernel_driver(self):
+        source_codes = """
+            namespace {
+
+            struct Grid {
+                Grid(uint32_t x, uint32_t y, uint32_t z)
+                    : grid_x(x), grid_y(y), grid_z(z) {}
+                uint32_t grid_x;
+                uint32_t grid_y;
+                uint32_t grid_z;
+
+                bool is_non_zero() {
+                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
+                }
+            };
+
+            } // anonymous namespace
+
+        """
+        return source_codes
+
+    def abi_compatible_header(self):
+        return """
+        #include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
+        #include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
+        """
+
+    def cpp_stream_type(self):
+        return "sycl::queue*"
+
+    def aoti_get_stream(self):
+        return "aoti_torch_get_current_xpu_stream"
+
+    def cpp_kernel_type(self):
+        return "std::unique_ptr<sycl::kernel>"
+
+    def cpp_device_ptr(self):
+        return "void *"
+

 register_device_op_overrides("xpu", XPUDeviceOpOverrides())
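These overrides are pure string providers: the C++ wrapper codegen asks them for type and helper names when emitting source for the target device. A hedged sketch of that consumption (the lookup helper lives in torch/_inductor/codegen/common.py; the returned strings come straight from the diff above):

    from torch._inductor.codegen.common import get_device_op_overrides

    xpu_ops = get_device_op_overrides("xpu")
    print(xpu_ops.cpp_kernel_type())  # "std::unique_ptr<sycl::kernel>"
    print(xpu_ops.cpp_stream_type())  # "sycl::queue*"
    print(xpu_ops.aoti_get_stream())  # "aoti_torch_get_current_xpu_stream"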
@@ -405,8 +405,8 @@ class BuildOptionsBase:
         self._passthough_args = _remove_duplication_in_list(self._passthough_args)

     def _finalize_options(self) -> None:
-        self._process_compile_only_options
-        self._remove_duplicate_options
+        self._process_compile_only_options()
+        self._remove_duplicate_options()

     def get_compiler(self) -> str:
         return self._compiler
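The bug fixed here is easy to miss: naming a bound method without parentheses evaluates the method object and discards it, so neither post-processing step ever ran. A self-contained demonstration:

    class Demo:
        def _setup(self):
            print("ran")

        def finalize_broken(self):
            self._setup    # no-op: just a reference to the bound method

        def finalize_fixed(self):
            self._setup()  # actually executes

    Demo().finalize_broken()  # prints nothing
    Demo().finalize_fixed()   # prints "ran"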
@@ -530,7 +530,7 @@ def _get_ffast_math_flags() -> List[str]:
     return flags


-def _get_optimization_cflags() -> List[str]:
+def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
     if _IS_WINDOWS:
         return ["O2"]
     else:
@@ -545,7 +545,7 @@ def _get_optimization_cflags() -> List[str]:

         if sys.platform != "darwin":
             # on macos, unknown argument: '-fno-tree-loop-vectorize'
-            if is_gcc():
+            if _is_gcc(cpp_compiler):
                 cflags.append("fno-tree-loop-vectorize")
         # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1
         # `-march=native` is unrecognized option on M1
@@ -593,7 +593,7 @@ def get_cpp_options(

     cflags = (
         _get_shared_cflag(compile_only)
-        + _get_optimization_cflags()
+        + _get_optimization_cflags(cpp_compiler)
         + _get_warning_all_cflag(warning_all)
         + _get_cpp_std_cflag()
         + _get_os_related_cpp_cflags(cpp_compiler)
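Why the compiler is now threaded through these helpers: flag support differs between gcc and clang-derived compilers such as icpx, so detection must inspect the compiler actually selected for this build rather than a process-wide default. A simplified standalone sketch (helper names here are stand-ins, not the real implementations):

    def _is_gcc(cpp_compiler: str) -> bool:
        # stand-in for the real compiler-family check
        return "gcc" in cpp_compiler or "g++" in cpp_compiler

    def optimization_cflags(cpp_compiler: str) -> list:
        cflags = ["O3"]
        if _is_gcc(cpp_compiler):
            cflags.append("fno-tree-loop-vectorize")  # clang/icpx reject this flag
        return cflags

    print(optimization_cflags("g++"))   # ['O3', 'fno-tree-loop-vectorize']
    print(optimization_cflags("icpx"))  # ['O3']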
@@ -629,9 +629,10 @@ class CppOptions(BuildOptionsBase):
         warning_all: bool = True,
         extra_flags: Sequence[str] = (),
         use_absolute_path: bool = False,
+        compiler: str = "",
     ) -> None:
         super().__init__()
-        self._compiler = get_cpp_compiler()
+        self._compiler = compiler if compiler else get_cpp_compiler()
         self._use_absolute_path = use_absolute_path
         self._compile_only = compile_only
@@ -1116,12 +1117,14 @@ class CppTorchOptions(CppOptions):
         use_mmap_weights: bool = False,
         shared: bool = True,
         extra_flags: Sequence[str] = (),
+        compiler: str = "",
     ) -> None:
         super().__init__(
             compile_only=compile_only,
             warning_all=warning_all,
             extra_flags=extra_flags,
             use_absolute_path=use_absolute_path,
+            compiler=compiler,
         )

         self._aot_mode = aot_mode
@@ -1205,7 +1208,6 @@ def get_cpp_torch_device_options(

     include_dirs = cpp_extension.include_paths(device_type)
     libraries_dirs = cpp_extension.library_paths(device_type)
-
     if device_type == "cuda":
         definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA")

@@ -1223,7 +1225,12 @@ def get_cpp_torch_device_options(

     if device_type == "xpu":
         definations.append(" USE_XPU")
-        cflags += ["fsycl"]
+        # Add "-Wno-unsupported-floating-point-opt" here to
+        # suppress compiler warning:
+        # "warning: overriding currently unsupported use of floating point
+        # exceptions on this target [-Wunsupported-floating-point-opt]",
+        # since the compiler does not yet support some features.
+        cflags += ["fsycl", "Wno-unsupported-floating-point-opt"]
         libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"]

     if aot_mode:
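Note on the flag spelling: entries are stored without the leading dash because the build-command assembly prepends it (an assumption about cpp_builder's command construction, illustrated standalone):

    cflags = ["fsycl", "Wno-unsupported-floating-point-opt"]
    cmd = ["icpx"] + [f"-{flag}" for flag in cflags]
    print(" ".join(cmd))  # icpx -fsycl -Wno-unsupported-floating-point-opt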
@@ -1275,6 +1282,12 @@ class CppTorchDeviceOptions(CppTorchOptions):
         shared: bool = True,
         extra_flags: Sequence[str] = (),
     ) -> None:
+        if device_type == "xpu":
+            from torch.utils.cpp_extension import _join_sycl_home
+
+            compiler = _join_sycl_home("bin", "icpx")
+        else:
+            compiler = ""
         super().__init__(
             vec_isa=vec_isa,
             include_pytorch=include_pytorch,
@@ -1283,11 +1296,8 @@ class CppTorchDeviceOptions(CppTorchOptions):
             use_absolute_path=use_absolute_path,
             use_mmap_weights=use_mmap_weights,
             extra_flags=extra_flags,
+            compiler=compiler,
         )
-        if device_type == "xpu":
-            from torch.utils.cpp_extension import _join_sycl_home
-
-            self._compiler = _join_sycl_home("bin", "icpx")

         device_definations: List[str] = []
         device_include_dirs: List[str] = []
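The two hunks above move compiler selection before super().__init__ runs, so option post-processing sees the final compiler instead of patching it afterwards. The selection logic in isolation (the real code resolves the path via torch.utils.cpp_extension._join_sycl_home):

    def pick_compiler(device_type: str) -> str:
        if device_type == "xpu":
            return "icpx"  # Intel SYCL compiler from the SYCL home
        return ""          # empty string: fall back to get_cpp_compiler()

    print(pick_compiler("xpu"), pick_compiler("cuda"))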
torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h (new file, +178)
@@ -0,0 +1,178 @@
// NOLINT
#pragma once
#ifdef USE_XPU
#include <c10/xpu/XPUFunctions.h>
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <cassert>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>

#define ZE_CHECK(status)                                                   \
  {                                                                        \
    if (status != ZE_RESULT_SUCCESS) {                                     \
      std::stringstream ss;                                                \
      ss << "L0 runtime error: " << std::hex << std::uppercase << status;  \
      throw std::runtime_error(ss.str());                                  \
    }                                                                      \
  }

// Build a Level Zero module from a SPIR-V binary.
static ze_module_handle_t create_module(
    ze_context_handle_t context,
    ze_device_handle_t device,
    const uint8_t* binary_ptr,
    size_t binary_size) {
  const char* build_flags = "";
  const ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV;
  ze_module_desc_t module_description = {};
  module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
  module_description.format = format;
  module_description.inputSize = binary_size;
  module_description.pInputModule = (uint8_t*)binary_ptr;
  module_description.pBuildFlags = build_flags;
  ze_module_build_log_handle_t buildlog = nullptr;
  ze_module_handle_t module = nullptr;
  auto error_no = ZE_RESULT_SUCCESS;
  error_no =
      zeModuleCreate(context, device, &module_description, &module, &buildlog);

  if (error_no != ZE_RESULT_SUCCESS) {
    size_t szLog = 0;
    ZE_CHECK(zeModuleBuildLogGetString(buildlog, &szLog, nullptr));
    char* strLog = (char*)malloc(szLog);
    ZE_CHECK(zeModuleBuildLogGetString(buildlog, &szLog, strLog));
    std::cerr << "L0 build module failed. Log: " << strLog << std::endl;
    free(strLog);
  }
  if (buildlog) {
    ZE_CHECK(zeModuleBuildLogDestroy(buildlog));
  }
  ZE_CHECK(error_no);
  return module;
}

// Look up a kernel entry point inside a Level Zero module.
ze_kernel_handle_t create_function(
    ze_module_handle_t module,
    ze_kernel_flags_t flag,
    const std::string& func_name) {
  ze_kernel_handle_t kernel = nullptr;
  ze_kernel_desc_t kernel_description = {};
  kernel_description.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
  kernel_description.pNext = nullptr;
  kernel_description.flags = flag;
  kernel_description.pKernelName = func_name.c_str();
  assert(module);
  ZE_CHECK(zeKernelCreate(module, &kernel_description, &kernel));
  return kernel;
}

// Read a SPIR-V file from disk and build a Level Zero module for the
// current XPU device.
static ze_module_handle_t loadModule(std::string& spv_path) {
  sycl::device& sycl_device =
      c10::xpu::get_raw_device(c10::xpu::current_device());
  auto sycl_context =
      sycl_device.get_platform().ext_oneapi_get_default_context();
  auto l0_device =
      sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
  auto l0_context =
      sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);

  std::ifstream IFS(spv_path.c_str(), std::ios::binary);
  std::ostringstream OSS;
  OSS << IFS.rdbuf();
  std::string data(OSS.str());

  return create_module(
      l0_context,
      l0_device,
      reinterpret_cast<const uint8_t*>(data.c_str()),
      data.size());
}

// Wrap a Level Zero kernel into a SYCL kernel object, transferring ownership
// of the native handles to the SYCL runtime.
static std::unique_ptr<sycl::kernel> getKernel(
    ze_module_handle_t l0_module,
    const char* kernel_name) {
  assert(l0_module);
  assert(kernel_name);
  auto l0_kernel =
      create_function(l0_module, ZE_KERNEL_FLAG_FORCE_RESIDENCY, kernel_name);

  sycl::device& sycl_device =
      c10::xpu::get_raw_device(c10::xpu::current_device());
  auto sycl_context =
      sycl_device.get_platform().ext_oneapi_get_default_context();

  auto mod = sycl::make_kernel_bundle<
      sycl::backend::ext_oneapi_level_zero,
      sycl::bundle_state::executable>(
      {l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
      sycl_context);

  auto fun = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
      {mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
      sycl_context);
  return std::make_unique<sycl::kernel>(fun);
}

[[maybe_unused]] static std::unique_ptr<sycl::kernel> loadKernel(
    std::string filePath,
    const std::string& funcName,
    uint32_t sharedMemBytes,
    const std::optional<std::string>& binDir = std::nullopt) {
  if (binDir) {
    std::filesystem::path p1{*binDir};
    std::filesystem::path p2{filePath};
    filePath = (p1 / p2.filename()).string();
  }
  auto mod = loadModule(filePath);
  return getKernel(mod, funcName.c_str());
}

[[maybe_unused]] static void launchKernel(
    std::unique_ptr<sycl::kernel>& kernel_ptr,
    uint32_t grid_x,
    uint32_t grid_y,
    uint32_t grid_z,
    uint32_t num_warps,
    uint32_t shared_memory,
    void** params,
    sycl::queue* queue_ptr) {
  std::string kernel_name =
      kernel_ptr->get_info<sycl::info::kernel::function_name>();
  // threads_per_warp is currently hard-coded to 32 throughout the
  // torch.compile-to-Triton stack.
  int threads_per_warp = 32;
  uint32_t num_params = kernel_ptr->get_info<sycl::info::kernel::num_args>();
  size_t global_range_x = grid_x * threads_per_warp * num_warps;
  size_t global_range_y = grid_y;
  size_t global_range_z = grid_z;
  size_t local_range_x = num_warps * threads_per_warp;
  size_t local_range_y = 1;
  size_t local_range_z = 1;
  sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
  sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
  sycl::nd_range<3> parallel_work_size(global_range, local_range);
  if (shared_memory) {
    // num_params from sycl info = user-provided args + shared_memory_buffer
    num_params -= 1;
  }
  // Submit the imported kernel.
  auto cgf = [&](sycl::handler& cgh) {
    for (uint32_t i = 0; i < num_params; ++i) {
      cgh.set_arg(i, *(static_cast<void**>(params[i])));
    }

    if (shared_memory > 0) {
      constexpr int dimensions = 1;
      using share_mem_t = sycl::local_accessor<int8_t, dimensions>;
      share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
      cgh.set_arg(num_params, local_buffer);
      cgh.parallel_for(parallel_work_size, *kernel_ptr);
    } else {
      cgh.parallel_for(parallel_work_size, *kernel_ptr);
    }
  };
  auto event = queue_ptr->submit(cgf);
}
#endif
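The launch math in launchKernel, restated in Python for clarity: a Triton-style (grid, num_warps) launch maps onto a SYCL nd_range whose x dimension carries every thread of the block, and SYCL ranges are ordered (z, y, x).

    THREADS_PER_WARP = 32  # hard-coded contract with the Triton XPU stack

    def sycl_ranges(grid_x, grid_y, grid_z, num_warps):
        local_x = num_warps * THREADS_PER_WARP
        global_range = (grid_z, grid_y, grid_x * local_x)
        local_range = (1, 1, local_x)
        return global_range, local_range

    print(sycl_ranges(4, 2, 1, 8))  # ((1, 2, 1024), (1, 1, 256))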
torch/csrc/inductor/aoti_runtime/utils_xpu.h (new file, +36)
@@ -0,0 +1,36 @@
#pragma once

#ifdef USE_XPU
// WARNING: Be careful when adding new includes here. This header will be used
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim_xpu.h>

namespace torch::aot_inductor {

inline void delete_xpu_guard(void* ptr) {
  AOTI_TORCH_ERROR_CODE_CHECK(
      aoti_torch_delete_xpu_guard(reinterpret_cast<XPUGuardHandle>(ptr)));
}

// RAII wrapper over the C-shim guard: construction switches the current XPU
// device, destruction restores the previous one via delete_xpu_guard.
class AOTIXpuGuard {
 public:
  AOTIXpuGuard(int32_t device_index) : guard_(nullptr, delete_xpu_guard) {
    XPUGuardHandle ptr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_create_xpu_guard(device_index, &ptr));
    guard_.reset(ptr);
  }

  void set_index(int32_t device_index) {
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_xpu_guard_set_index(guard_.get(), device_index));
  }

 private:
  std::unique_ptr<XPUGuardOpaque, DeleterFnPtr> guard_;
};
} // namespace torch::aot_inductor
#endif // USE_XPU
@@ -96,6 +96,7 @@ using AOTITorchError = int32_t;
 // desired for perf reasons.)
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
+AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_xpu();
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1();

 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
torch/csrc/inductor/aoti_torch/c/shim_xpu.h (new file, +36)
@@ -0,0 +1,36 @@
#ifndef AOTI_TORCH_SHIM_XPU
#define AOTI_TORCH_SHIM_XPU

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef USE_XPU
#ifdef __cplusplus
extern "C" {
#endif

struct XPUGuardOpaque;
using XPUGuardHandle = XPUGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_xpu_guard(
    int32_t device_index,
    XPUGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_xpu_guard(XPUGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_xpu_guard_set_index(XPUGuardHandle guard, int32_t device_index);

struct XPUStreamGuardOpaque;
using XPUStreamGuardHandle = XPUStreamGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream);

#ifdef __cplusplus
} // extern "C"
#endif

#endif // USE_XPU
#endif // AOTI_TORCH_SHIM_XPU
@@ -119,6 +119,7 @@ const int AOTI_TORCH_MAX_NUMEL_TO_PRINT = 64;

 AOTI_TORCH_DEVICE_TYPE_IMPL(cpu, CPU)
 AOTI_TORCH_DEVICE_TYPE_IMPL(cuda, CUDA)
+AOTI_TORCH_DEVICE_TYPE_IMPL(xpu, XPU)
 AOTI_TORCH_DEVICE_TYPE_IMPL(privateuse1, PrivateUse1)
 #undef AOTI_TORCH_DEVICE_TYPE_IMPL

torch/csrc/inductor/aoti_torch/shim_xpu.cpp (new file, +37)
@@ -0,0 +1,37 @@

#include <torch/csrc/inductor/aoti_torch/c/shim_xpu.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/DeviceType.h>
#include <c10/core/StreamGuard.h>
#include <c10/xpu/XPUStream.h>

AOTITorchError aoti_torch_create_xpu_guard(
    int32_t device_index,
    XPUGuardHandle* ret_guard // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::DeviceGuard* guard =
        new at::DeviceGuard(at::Device(at::DeviceType::XPU, device_index));
    *ret_guard = reinterpret_cast<XPUGuardHandle>(guard);
  });
}

AOTITorchError aoti_torch_delete_xpu_guard(XPUGuardHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { delete reinterpret_cast<at::DeviceGuard*>(guard); });
}

AOTITorchError aoti_torch_xpu_guard_set_index(
    XPUGuardHandle guard,
    int32_t device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { reinterpret_cast<at::DeviceGuard*>(guard)->set_index(device_index); });
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_stream = &(at::xpu::getCurrentXPUStream(device_index).queue()); });
}
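An end-to-end sketch tying the pieces together (assumes an XPU build; the torch._export entry points match the era of this PR but treat them as assumptions): AOTInductor compiles the module to a shared library whose generated C++ goes through the guard and stream shims above.

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x.sin() + x.cos()

    m = M().to("xpu").eval()
    x = torch.randn(8, device="xpu")

    with torch.no_grad():
        so_path = torch._export.aot_compile(m, (x,))        # builds model.so
    loaded = torch._export.aot_load(so_path, device="xpu")  # dlopens and wraps it
    print(loaded(x))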