[AOTI][XPU] Refactor AOTInductor runtime API for Intel GPU. (#153929)

Simplify and improve code formatting for sycl_runtime_wrappers.h

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153929
Approved by: https://github.com/desertfire
ghstack dependencies: #153924
This commit is contained in:
xinan.lin
2025-05-20 00:13:47 -07:00
committed by PyTorch MergeBot
parent 531d8f5fb6
commit dcb3edd30d

View File

@ -17,87 +17,74 @@
} \
}
// Builds a Level Zero module from a SPIR-V binary on the current XPU device.
//
// @param binaryPtr  Pointer to the SPIR-V image (ZE_MODULE_FORMAT_IL_SPIRV).
// @param binarySize Size of the image in bytes.
// @return The created ze_module_handle_t; ZE_CHECK aborts on failure after
//         printing the Level Zero build log to stderr.
static ze_module_handle_t _createModule(
    const uint8_t* binaryPtr,
    size_t binarySize) {
  sycl::device& syclDevice =
      c10::xpu::get_raw_device(c10::xpu::current_device());
  auto syclContext = syclDevice.get_platform().ext_oneapi_get_default_context();
  // Unwrap the native Level Zero handles backing the SYCL device/context.
  auto device =
      sycl::get_native<sycl::backend::ext_oneapi_level_zero>(syclDevice);
  auto context =
      sycl::get_native<sycl::backend::ext_oneapi_level_zero>(syclContext);
  const char* buildFlags = "";
  const ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV;
  ze_module_desc_t moduleDescription = {};
  moduleDescription.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
  moduleDescription.format = format;
  moduleDescription.inputSize = binarySize;
  moduleDescription.pInputModule = (uint8_t*)binaryPtr;
  moduleDescription.pBuildFlags = buildFlags;
  ze_module_build_log_handle_t buildLog = nullptr;
  ze_module_handle_t module = nullptr;
  auto error_no = ZE_RESULT_SUCCESS;
  error_no =
      zeModuleCreate(context, device, &moduleDescription, &module, &buildLog);
  if (error_no != ZE_RESULT_SUCCESS) {
    // Surface the compiler log before failing: first query the size, then
    // fetch the text.
    size_t szLog = 0;
    ZE_CHECK(zeModuleBuildLogGetString(buildLog, &szLog, nullptr));
    char* strLog = (char*)malloc(szLog);
    ZE_CHECK(zeModuleBuildLogGetString(buildLog, &szLog, strLog));
    std::cerr << "L0 build module failed. Log: " << strLog << std::endl;
    free(strLog);
  }
  if (buildLog) {
    ZE_CHECK(zeModuleBuildLogDestroy(buildLog));
  }
  ZE_CHECK(error_no);
  return module;
}
ze_kernel_handle_t create_function(
static std::unique_ptr<sycl::kernel> _createKernel(
ze_module_handle_t module,
ze_kernel_flags_t flag,
const std::string& func_name) {
ze_kernel_handle_t kernel = nullptr;
ze_kernel_desc_t kernel_description = {};
kernel_description.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
kernel_description.pNext = nullptr;
kernel_description.flags = flag;
kernel_description.pKernelName = func_name.c_str();
const char* kernelName) {
assert(module);
ZE_CHECK(zeKernelCreate(module, &kernel_description, &kernel));
return kernel;
}
assert(kernelName);
ze_kernel_handle_t kernel = nullptr;
ze_kernel_desc_t kernelDescription = {};
kernelDescription.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
kernelDescription.pNext = nullptr;
kernelDescription.flags = ZE_KERNEL_FLAG_FORCE_RESIDENCY;
kernelDescription.pKernelName = kernelName;
ZE_CHECK(zeKernelCreate(module, &kernelDescription, &kernel));
static std::unique_ptr<sycl::kernel> getKernel(
ze_module_handle_t l0_module,
const char* kernel_name) {
assert(l0_module);
assert(kernel_name);
auto l0_kernel =
create_function(l0_module, ZE_KERNEL_FLAG_FORCE_RESIDENCY, kernel_name);
sycl::device& sycl_device =
sycl::device& syclDevice =
c10::xpu::get_raw_device(c10::xpu::current_device());
auto sycl_context =
sycl_device.get_platform().ext_oneapi_get_default_context();
auto syclContext = syclDevice.get_platform().ext_oneapi_get_default_context();
auto mod = sycl::make_kernel_bundle<
sycl::backend::ext_oneapi_level_zero,
sycl::bundle_state::executable>(
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
sycl_context);
{module, sycl::ext::oneapi::level_zero::ownership::transfer},
syclContext);
auto fun = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
{mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
sycl_context);
{mod, kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
syclContext);
return std::make_unique<sycl::kernel>(fun);
}
// GPU Cpp Wrapper API
[[maybe_unused]] static std::unique_ptr<sycl::kernel> loadKernel(
std::string filePath,
const std::string& funcName,
@ -114,12 +101,13 @@ static std::unique_ptr<sycl::kernel> getKernel(
OSS << IFS.rdbuf();
std::string data(OSS.str());
auto mod = create_module(
auto mod = _createModule(
reinterpret_cast<const uint8_t*>(data.c_str()), data.size());
return getKernel(mod, funcName.c_str());
return _createKernel(mod, funcName.c_str());
}
// GPU Cpp Wrapper API
[[maybe_unused]] static std::unique_ptr<sycl::kernel> loadKernel(
const void* start,
const void* end,
@ -128,55 +116,56 @@ static std::unique_ptr<sycl::kernel> getKernel(
size_t size = reinterpret_cast<const uint8_t*>(end) -
reinterpret_cast<const uint8_t*>(start);
auto mod = create_module(reinterpret_cast<const uint8_t*>(start), size);
auto mod = _createModule(reinterpret_cast<const uint8_t*>(start), size);
return getKernel(mod, funcName.c_str());
return _createKernel(mod, funcName.c_str());
}
// GPU Cpp Wrapper API
[[maybe_unused]] static void launchKernel(
std::unique_ptr<sycl::kernel>& kernel_ptr,
uint32_t grid_x,
uint32_t grid_y,
uint32_t grid_z,
uint32_t num_warps,
uint32_t shared_memory,
std::unique_ptr<sycl::kernel>& kernelPtr,
uint32_t gridX,
uint32_t gridY,
uint32_t gridZ,
uint32_t numWarps,
uint32_t sharedMemory,
void** params,
sycl::queue* queue_ptr) {
std::string kernel_name =
kernel_ptr->get_info<sycl::info::kernel::function_name>();
// Currently threads_per_warp is hard code to 32 from torch.compile to triton
sycl::queue* queuePtr) {
std::string kernelName =
kernelPtr->get_info<sycl::info::kernel::function_name>();
// Currently threadsPerWarp is hard code to 32 from torch.compile to triton
// stack.
int threads_per_warp = 32;
uint32_t num_params = kernel_ptr->get_info<sycl::info::kernel::num_args>();
size_t global_range_x = grid_x * threads_per_warp * num_warps;
size_t global_range_y = grid_y;
size_t global_range_z = grid_z;
size_t local_range_x = num_warps * threads_per_warp;
size_t local_range_y = 1;
size_t local_range_z = 1;
sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
sycl::nd_range<3> parallel_work_size(global_range, local_range);
if (shared_memory) {
// num_params from sycl info = user provided args + shared_memroy_buffer
num_params -= 1;
int threadsPerWarp = 32;
uint32_t numParams = kernelPtr->get_info<sycl::info::kernel::num_args>();
size_t globalRangeX = gridX * threadsPerWarp * numWarps;
size_t globalRangeY = gridY;
size_t globalRangeZ = gridZ;
size_t localRangeX = numWarps * threadsPerWarp;
size_t localRangeY = 1;
size_t localRangeZ = 1;
sycl::range<3> globalRange(globalRangeZ, globalRangeY, globalRangeX);
sycl::range<3> localRange(localRangeZ, localRangeY, localRangeX);
sycl::nd_range<3> parallelWorkSize(globalRange, localRange);
if (sharedMemory) {
// numParams from sycl info = user provided args + sharedMemroyBuffer
numParams -= 1;
}
// Submit the imported kernel.
auto cgf = [&](sycl::handler& cgh) {
for (uint32_t i = 0; i < num_params; ++i) {
for (uint32_t i = 0; i < numParams; ++i) {
cgh.set_arg(i, *(static_cast<void**>(params[i])));
}
if (shared_memory > 0) {
if (sharedMemory > 0) {
constexpr int dimensions = 1;
using share_mem_t = sycl::local_accessor<int8_t, dimensions>;
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
cgh.set_arg(num_params, local_buffer);
cgh.parallel_for(parallel_work_size, *kernel_ptr);
share_mem_t localBuffer = share_mem_t(sharedMemory, cgh);
cgh.set_arg(numParams, localBuffer);
cgh.parallel_for(parallelWorkSize, *kernelPtr);
} else {
cgh.parallel_for(parallel_work_size, *kernel_ptr);
cgh.parallel_for(parallelWorkSize, *kernelPtr);
}
};
auto event = queue_ptr->submit(cgf);
auto event = queuePtr->submit(cgf);
}
#endif