mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI][XPU] Refactor AOTInductor runtime API for Intel GPU. (#153929)
Simplify and improve code format for sycl_runtime_wrappers.h

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153929
Approved by: https://github.com/desertfire
ghstack dependencies: #153924
This commit is contained in:
committed by
PyTorch MergeBot
parent
531d8f5fb6
commit
dcb3edd30d
@ -17,87 +17,74 @@
|
||||
} \
|
||||
}
|
||||
|
||||
static ze_module_handle_t create_module(
|
||||
const uint8_t* binary_ptr,
|
||||
size_t binary_size) {
|
||||
sycl::device& sycl_device =
|
||||
static ze_module_handle_t _createModule(
|
||||
const uint8_t* binaryPtr,
|
||||
size_t binarySize) {
|
||||
sycl::device& syclDevice =
|
||||
c10::xpu::get_raw_device(c10::xpu::current_device());
|
||||
auto sycl_context =
|
||||
sycl_device.get_platform().ext_oneapi_get_default_context();
|
||||
auto l0_device =
|
||||
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
|
||||
auto l0_context =
|
||||
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
|
||||
auto syclContext = syclDevice.get_platform().ext_oneapi_get_default_context();
|
||||
auto device =
|
||||
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(syclDevice);
|
||||
auto context =
|
||||
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(syclContext);
|
||||
|
||||
const char* build_flags = "";
|
||||
const char* buildFlags = "";
|
||||
const ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV;
|
||||
ze_module_desc_t module_description = {};
|
||||
module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
|
||||
module_description.format = format;
|
||||
module_description.inputSize = binary_size;
|
||||
module_description.pInputModule = (uint8_t*)binary_ptr;
|
||||
module_description.pBuildFlags = build_flags;
|
||||
ze_module_build_log_handle_t buildlog = nullptr;
|
||||
ze_module_desc_t moduleDescription = {};
|
||||
moduleDescription.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
|
||||
moduleDescription.format = format;
|
||||
moduleDescription.inputSize = binarySize;
|
||||
moduleDescription.pInputModule = (uint8_t*)binaryPtr;
|
||||
moduleDescription.pBuildFlags = buildFlags;
|
||||
ze_module_build_log_handle_t buildLog = nullptr;
|
||||
ze_module_handle_t module = nullptr;
|
||||
auto error_no = ZE_RESULT_SUCCESS;
|
||||
error_no = zeModuleCreate(
|
||||
l0_context, l0_device, &module_description, &module, &buildlog);
|
||||
error_no =
|
||||
zeModuleCreate(context, device, &moduleDescription, &module, &buildLog);
|
||||
|
||||
if (error_no != ZE_RESULT_SUCCESS) {
|
||||
size_t szLog = 0;
|
||||
ZE_CHECK(zeModuleBuildLogGetString(buildlog, &szLog, nullptr));
|
||||
ZE_CHECK(zeModuleBuildLogGetString(buildLog, &szLog, nullptr));
|
||||
char* strLog = (char*)malloc(szLog);
|
||||
ZE_CHECK(zeModuleBuildLogGetString(buildlog, &szLog, strLog));
|
||||
ZE_CHECK(zeModuleBuildLogGetString(buildLog, &szLog, strLog));
|
||||
std::cerr << "L0 build module failed. Log: " << strLog << std::endl;
|
||||
free(strLog);
|
||||
}
|
||||
if (buildlog) {
|
||||
ZE_CHECK(zeModuleBuildLogDestroy(buildlog));
|
||||
if (buildLog) {
|
||||
ZE_CHECK(zeModuleBuildLogDestroy(buildLog));
|
||||
}
|
||||
ZE_CHECK(error_no);
|
||||
return module;
|
||||
}
|
||||
|
||||
ze_kernel_handle_t create_function(
|
||||
static std::unique_ptr<sycl::kernel> _createKernel(
|
||||
ze_module_handle_t module,
|
||||
ze_kernel_flags_t flag,
|
||||
const std::string& func_name) {
|
||||
ze_kernel_handle_t kernel = nullptr;
|
||||
ze_kernel_desc_t kernel_description = {};
|
||||
kernel_description.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
|
||||
kernel_description.pNext = nullptr;
|
||||
kernel_description.flags = flag;
|
||||
kernel_description.pKernelName = func_name.c_str();
|
||||
const char* kernelName) {
|
||||
assert(module);
|
||||
ZE_CHECK(zeKernelCreate(module, &kernel_description, &kernel));
|
||||
return kernel;
|
||||
}
|
||||
assert(kernelName);
|
||||
ze_kernel_handle_t kernel = nullptr;
|
||||
ze_kernel_desc_t kernelDescription = {};
|
||||
kernelDescription.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
|
||||
kernelDescription.pNext = nullptr;
|
||||
kernelDescription.flags = ZE_KERNEL_FLAG_FORCE_RESIDENCY;
|
||||
kernelDescription.pKernelName = kernelName;
|
||||
ZE_CHECK(zeKernelCreate(module, &kernelDescription, &kernel));
|
||||
|
||||
static std::unique_ptr<sycl::kernel> getKernel(
|
||||
ze_module_handle_t l0_module,
|
||||
const char* kernel_name) {
|
||||
assert(l0_module);
|
||||
assert(kernel_name);
|
||||
auto l0_kernel =
|
||||
create_function(l0_module, ZE_KERNEL_FLAG_FORCE_RESIDENCY, kernel_name);
|
||||
|
||||
sycl::device& sycl_device =
|
||||
sycl::device& syclDevice =
|
||||
c10::xpu::get_raw_device(c10::xpu::current_device());
|
||||
auto sycl_context =
|
||||
sycl_device.get_platform().ext_oneapi_get_default_context();
|
||||
|
||||
auto syclContext = syclDevice.get_platform().ext_oneapi_get_default_context();
|
||||
auto mod = sycl::make_kernel_bundle<
|
||||
sycl::backend::ext_oneapi_level_zero,
|
||||
sycl::bundle_state::executable>(
|
||||
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
|
||||
sycl_context);
|
||||
|
||||
{module, sycl::ext::oneapi::level_zero::ownership::transfer},
|
||||
syclContext);
|
||||
auto fun = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
|
||||
{mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
|
||||
sycl_context);
|
||||
{mod, kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
|
||||
syclContext);
|
||||
return std::make_unique<sycl::kernel>(fun);
|
||||
}
|
||||
|
||||
// GPU Cpp Wrapper API
|
||||
[[maybe_unused]] static std::unique_ptr<sycl::kernel> loadKernel(
|
||||
std::string filePath,
|
||||
const std::string& funcName,
|
||||
@ -114,12 +101,13 @@ static std::unique_ptr<sycl::kernel> getKernel(
|
||||
OSS << IFS.rdbuf();
|
||||
std::string data(OSS.str());
|
||||
|
||||
auto mod = create_module(
|
||||
auto mod = _createModule(
|
||||
reinterpret_cast<const uint8_t*>(data.c_str()), data.size());
|
||||
|
||||
return getKernel(mod, funcName.c_str());
|
||||
return _createKernel(mod, funcName.c_str());
|
||||
}
|
||||
|
||||
// GPU Cpp Wrapper API
|
||||
[[maybe_unused]] static std::unique_ptr<sycl::kernel> loadKernel(
|
||||
const void* start,
|
||||
const void* end,
|
||||
@ -128,55 +116,56 @@ static std::unique_ptr<sycl::kernel> getKernel(
|
||||
size_t size = reinterpret_cast<const uint8_t*>(end) -
|
||||
reinterpret_cast<const uint8_t*>(start);
|
||||
|
||||
auto mod = create_module(reinterpret_cast<const uint8_t*>(start), size);
|
||||
auto mod = _createModule(reinterpret_cast<const uint8_t*>(start), size);
|
||||
|
||||
return getKernel(mod, funcName.c_str());
|
||||
return _createKernel(mod, funcName.c_str());
|
||||
}
|
||||
|
||||
// GPU Cpp Wrapper API
|
||||
[[maybe_unused]] static void launchKernel(
|
||||
std::unique_ptr<sycl::kernel>& kernel_ptr,
|
||||
uint32_t grid_x,
|
||||
uint32_t grid_y,
|
||||
uint32_t grid_z,
|
||||
uint32_t num_warps,
|
||||
uint32_t shared_memory,
|
||||
std::unique_ptr<sycl::kernel>& kernelPtr,
|
||||
uint32_t gridX,
|
||||
uint32_t gridY,
|
||||
uint32_t gridZ,
|
||||
uint32_t numWarps,
|
||||
uint32_t sharedMemory,
|
||||
void** params,
|
||||
sycl::queue* queue_ptr) {
|
||||
std::string kernel_name =
|
||||
kernel_ptr->get_info<sycl::info::kernel::function_name>();
|
||||
// Currently threads_per_warp is hard-coded to 32 from torch.compile to triton
|
||||
sycl::queue* queuePtr) {
|
||||
std::string kernelName =
|
||||
kernelPtr->get_info<sycl::info::kernel::function_name>();
|
||||
// Currently threadsPerWarp is hard-coded to 32 from torch.compile to triton
|
||||
// stack.
|
||||
int threads_per_warp = 32;
|
||||
uint32_t num_params = kernel_ptr->get_info<sycl::info::kernel::num_args>();
|
||||
size_t global_range_x = grid_x * threads_per_warp * num_warps;
|
||||
size_t global_range_y = grid_y;
|
||||
size_t global_range_z = grid_z;
|
||||
size_t local_range_x = num_warps * threads_per_warp;
|
||||
size_t local_range_y = 1;
|
||||
size_t local_range_z = 1;
|
||||
sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
|
||||
sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
|
||||
sycl::nd_range<3> parallel_work_size(global_range, local_range);
|
||||
if (shared_memory) {
|
||||
// num_params from sycl info = user-provided args + shared_memory_buffer
|
||||
num_params -= 1;
|
||||
int threadsPerWarp = 32;
|
||||
uint32_t numParams = kernelPtr->get_info<sycl::info::kernel::num_args>();
|
||||
size_t globalRangeX = gridX * threadsPerWarp * numWarps;
|
||||
size_t globalRangeY = gridY;
|
||||
size_t globalRangeZ = gridZ;
|
||||
size_t localRangeX = numWarps * threadsPerWarp;
|
||||
size_t localRangeY = 1;
|
||||
size_t localRangeZ = 1;
|
||||
sycl::range<3> globalRange(globalRangeZ, globalRangeY, globalRangeX);
|
||||
sycl::range<3> localRange(localRangeZ, localRangeY, localRangeX);
|
||||
sycl::nd_range<3> parallelWorkSize(globalRange, localRange);
|
||||
if (sharedMemory) {
|
||||
// numParams from sycl info = user-provided args + sharedMemoryBuffer
|
||||
numParams -= 1;
|
||||
}
|
||||
// Submit the imported kernel.
|
||||
auto cgf = [&](sycl::handler& cgh) {
|
||||
for (uint32_t i = 0; i < num_params; ++i) {
|
||||
for (uint32_t i = 0; i < numParams; ++i) {
|
||||
cgh.set_arg(i, *(static_cast<void**>(params[i])));
|
||||
}
|
||||
|
||||
if (shared_memory > 0) {
|
||||
if (sharedMemory > 0) {
|
||||
constexpr int dimensions = 1;
|
||||
using share_mem_t = sycl::local_accessor<int8_t, dimensions>;
|
||||
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
|
||||
cgh.set_arg(num_params, local_buffer);
|
||||
cgh.parallel_for(parallel_work_size, *kernel_ptr);
|
||||
share_mem_t localBuffer = share_mem_t(sharedMemory, cgh);
|
||||
cgh.set_arg(numParams, localBuffer);
|
||||
cgh.parallel_for(parallelWorkSize, *kernelPtr);
|
||||
} else {
|
||||
cgh.parallel_for(parallel_work_size, *kernel_ptr);
|
||||
cgh.parallel_for(parallelWorkSize, *kernelPtr);
|
||||
}
|
||||
};
|
||||
auto event = queue_ptr->submit(cgf);
|
||||
auto event = queuePtr->submit(cgf);
|
||||
}
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user