#pragma once

#include <ATen/Config.h>

#include <c10/core/Device.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUFunctions.h>
#include <c10/xpu/XPUStream.h>

#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_sycl.hpp>

#include <array>
#include <vector>

namespace at::native::onednn {

// Wraps an existing device allocation into a dnnl::memory object with the
// given descriptor, on the given engine.
TORCH_XPU_API dnnl::memory make_onednn_memory(
    dnnl::memory::desc md,
    dnnl::engine& engine,
    void* ptr);
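
// Illustrative usage (a sketch, not part of this header): wrap a USM device
// allocation into a oneDNN memory object. `buf` and the shape are
// hypothetical.
//
//   dnnl::memory::desc md(
//       {8, 16}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
//   auto& eng = GpuEngineManager::Instance().get_engine({c10::kXPU, 0});
//   void* buf = /* USM allocation of at least md.get_size() bytes */;
//   dnnl::memory mem = make_onednn_memory(md, eng, buf);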

// Sets the oneDNN verbose level. Keep non-static and non-inline.
bool set_onednn_verbose(int level);
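
// Illustrative usage (a sketch): oneDNN verbose levels are assumed to follow
// the ONEDNN_VERBOSE convention (0 = off, 1 = primitive execution,
// 2 = execution plus creation).
//
//   at::native::onednn::set_onednn_verbose(2); // trace a workload
//   // ... run XPU oneDNN kernels ...
//   at::native::onednn::set_onednn_verbose(0); // back to silent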

// GpuEngineManager singleton
struct TORCH_XPU_API GpuEngineManager {
  static GpuEngineManager& Instance(); // Singleton

  dnnl::engine& get_engine(const Device& device) {
    TORCH_INTERNAL_ASSERT(device.type() == kXPU);
    TORCH_INTERNAL_ASSERT(device.index() < c10::xpu::device_count());
    return *engine_pool[device.index()];
  }

  GpuEngineManager(GpuEngineManager const&) = delete;
  GpuEngineManager& operator=(GpuEngineManager const&) = delete;
  GpuEngineManager(GpuEngineManager&&) = default;
  GpuEngineManager& operator=(GpuEngineManager&&) = default;

 protected:
  GpuEngineManager() {
    c10::DeviceIndex device_count = c10::xpu::device_count();
    TORCH_INTERNAL_ASSERT(device_count > 0);
    for (const auto i : c10::irange(device_count)) {
      engine_pool.push_back(
          std::make_shared<dnnl::engine>(dnnl::sycl_interop::make_engine(
              c10::xpu::get_raw_device(i), c10::xpu::get_device_context())));
    }
  }
  ~GpuEngineManager() = default;

 private:
  std::vector<std::shared_ptr<dnnl::engine>> engine_pool;
};
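
// Illustrative usage (a sketch): fetch the cached engine for a device. The
// constructor builds one dnnl::engine per XPU device up front, all sharing
// the single driver-level SYCL context.
//
//   auto dev = c10::xpu::current_device(); // current device index
//   dnnl::engine& eng =
//       GpuEngineManager::Instance().get_engine({c10::kXPU, dev});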

// GpuStreamManager singleton
struct TORCH_XPU_API GpuStreamManager {
  static GpuStreamManager& Instance(); // Singleton

  dnnl::stream get_stream() {
    auto stream = c10::xpu::getCurrentXPUStream();
    auto priority = stream.priority();
    auto device_index = stream.device_index();
    if (stream_pool[device_index][priority].find(stream) ==
        stream_pool[device_index][priority].end()) {
      stream_pool[device_index][priority][stream] =
          std::make_shared<dnnl::stream>(dnnl::sycl_interop::make_stream(
              GpuEngineManager::Instance().get_engine(
                  {c10::kXPU, device_index}),
              stream.queue()));
    }
    return *stream_pool[device_index][priority][stream];
  }

  GpuStreamManager(GpuStreamManager const&) = delete;
  GpuStreamManager& operator=(GpuStreamManager const&) = delete;
  GpuStreamManager(GpuStreamManager&&) = default;
  GpuStreamManager& operator=(GpuStreamManager&&) = default;

 protected:
  GpuStreamManager() {
    c10::DeviceIndex device_count = c10::xpu::device_count();
    TORCH_INTERNAL_ASSERT(device_count > 0);
    stream_pool.resize(device_count);
  }
  ~GpuStreamManager() = default;

 private:
  using stream_hash_map =
      ska::flat_hash_map<c10::xpu::XPUStream, std::shared_ptr<dnnl::stream>>;
  std::vector<
      std::array<stream_hash_map, c10::xpu::max_compile_time_stream_priorities>>
      stream_pool;
};
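
// Illustrative usage (a sketch): get a dnnl::stream wrapping the current XPU
// stream's SYCL queue, so primitive execution is ordered with other work
// submitted on that stream. Streams are memoized per device and priority.
//
//   dnnl::stream strm = GpuStreamManager::Instance().get_stream();
//   // some_primitive.execute(strm, args); // hypothetical primitive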

} // namespace at::native::onednn