mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
@bypass-github-export-checks GLSL code gen is not used, so this diff removes: (1) the GLSL parts of ShaderSource; (2) anything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself; and (3) the gen_vulkan_glsl script. It also includes some additional refactoring. Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/91912 Approved by: https://github.com/mcr229
199 lines
4.8 KiB
C++
199 lines
4.8 KiB
C++
#pragma once
|
|
|
|
#ifdef USE_VULKAN_API
|
|
|
|
#include <ATen/native/vulkan/api/Common.h>
|
|
#include <ATen/native/vulkan/api/Types.h>
|
|
#include <ATen/native/vulkan/api/Utils.h>
|
|
#include <c10/util/flat_hash_map.h>
|
|
#include <c10/util/hash.h>
|
|
|
|
#include <mutex>
|
|
|
|
namespace at {
|
|
namespace native {
|
|
namespace vulkan {
|
|
namespace api {
|
|
|
|
class ShaderLayout final {
|
|
public:
|
|
using Signature = c10::SmallVector<VkDescriptorType, 6u>;
|
|
|
|
explicit ShaderLayout(const VkDevice, const Signature&);
|
|
|
|
ShaderLayout(const ShaderLayout&) = delete;
|
|
ShaderLayout& operator=(const ShaderLayout&) = delete;
|
|
|
|
ShaderLayout(ShaderLayout&&) noexcept;
|
|
ShaderLayout& operator=(ShaderLayout&&) = delete;
|
|
|
|
~ShaderLayout();
|
|
|
|
private:
|
|
VkDevice device_;
|
|
VkDescriptorSetLayout handle_;
|
|
|
|
public:
|
|
VkDescriptorSetLayout handle() const {
|
|
return handle_;
|
|
}
|
|
|
|
// We need to define a custom swap function since this class
|
|
// does not allow for move assignment. The swap function will
|
|
// be used in the hash map.
|
|
friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept;
|
|
};
|
|
|
|
struct ShaderInfo final {
|
|
struct {
|
|
const uint32_t* bin;
|
|
uint32_t size;
|
|
} src_code;
|
|
|
|
std::string kernel_name{""};
|
|
ShaderLayout::Signature kernel_layout{};
|
|
|
|
// Shader Metadata
|
|
utils::uvec3 out_tile_size{1u, 1u, 1u};
|
|
|
|
c10::SmallVector<uint32_t, 4> tile_size;
|
|
StorageType bias_storage_type{StorageType::UNKNOWN};
|
|
StorageType weight_storage_type{StorageType::UNKNOWN};
|
|
|
|
explicit ShaderInfo();
|
|
explicit ShaderInfo(std::string, const char*);
|
|
explicit ShaderInfo(
|
|
std::string,
|
|
const uint32_t*,
|
|
const uint32_t,
|
|
const std::vector<VkDescriptorType>&);
|
|
explicit ShaderInfo(
|
|
std::string,
|
|
const uint32_t*,
|
|
const uint32_t,
|
|
const std::vector<VkDescriptorType>&,
|
|
const std::vector<uint32_t>& tile_size,
|
|
const StorageType bias_storage_type,
|
|
const StorageType weight_storage_type);
|
|
};
|
|
|
|
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
|
|
|
|
class ShaderModule final {
|
|
public:
|
|
explicit ShaderModule(const VkDevice device, const ShaderInfo& source);
|
|
|
|
ShaderModule(const ShaderModule&) = delete;
|
|
ShaderModule& operator=(const ShaderModule&) = delete;
|
|
|
|
ShaderModule(ShaderModule&&) noexcept;
|
|
ShaderModule& operator=(ShaderModule&&) = delete;
|
|
|
|
~ShaderModule();
|
|
|
|
private:
|
|
VkDevice device_;
|
|
VkShaderModule handle_;
|
|
|
|
public:
|
|
inline VkShaderModule handle() const {
|
|
return handle_;
|
|
}
|
|
|
|
// We need to define a custom swap function since this class
|
|
// does not allow for move assignment. The swap function will
|
|
// be used in the hash map.
|
|
friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept;
|
|
};
|
|
|
|
class ShaderLayoutCache final {
|
|
public:
|
|
explicit ShaderLayoutCache(const VkDevice device);
|
|
|
|
ShaderLayoutCache(const ShaderLayoutCache&) = delete;
|
|
ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete;
|
|
|
|
ShaderLayoutCache(ShaderLayoutCache&&) noexcept;
|
|
ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete;
|
|
|
|
~ShaderLayoutCache();
|
|
|
|
using Key = ShaderLayout::Signature;
|
|
using Value = ShaderLayout;
|
|
|
|
struct Hasher {
|
|
inline size_t operator()(const ShaderLayout::Signature& signature) const {
|
|
size_t hashed = 0u;
|
|
|
|
for (const VkDescriptorType type : signature) {
|
|
hashed = c10::hash_combine(hashed, c10::get_hash(type));
|
|
}
|
|
|
|
return hashed;
|
|
}
|
|
};
|
|
|
|
private:
|
|
// Multiple threads could potentially be adding entries into the cache, so use
|
|
// a mutex to manage access
|
|
std::mutex cache_mutex_;
|
|
|
|
VkDevice device_;
|
|
ska::flat_hash_map<Key, Value, Hasher> cache_;
|
|
|
|
public:
|
|
VkDescriptorSetLayout retrieve(const Key&);
|
|
void purge();
|
|
};
|
|
|
|
class ShaderCache final {
|
|
public:
|
|
explicit ShaderCache(const VkDevice device);
|
|
|
|
ShaderCache(const ShaderCache&) = delete;
|
|
ShaderCache& operator=(const ShaderCache&) = delete;
|
|
|
|
ShaderCache(ShaderCache&&) noexcept;
|
|
ShaderCache& operator=(ShaderCache&&) = delete;
|
|
|
|
~ShaderCache();
|
|
|
|
using Key = ShaderInfo;
|
|
using Value = ShaderModule;
|
|
|
|
struct Hasher {
|
|
inline size_t operator()(const ShaderInfo& source) const {
|
|
return c10::get_hash(source.src_code.bin, source.src_code.size);
|
|
}
|
|
};
|
|
|
|
private:
|
|
// Multiple threads could potentially be adding entries into the cache, so use
|
|
// a mutex to manage access
|
|
std::mutex cache_mutex_;
|
|
|
|
VkDevice device_;
|
|
ska::flat_hash_map<Key, Value, Hasher> cache_;
|
|
|
|
public:
|
|
VkShaderModule retrieve(const Key&);
|
|
void purge();
|
|
};
|
|
|
|
} // namespace api
|
|
} // namespace vulkan
|
|
} // namespace native
|
|
} // namespace at
|
|
|
|
inline bool operator==(
|
|
const VkDescriptorSetLayoutBinding& _1,
|
|
const VkDescriptorSetLayoutBinding& _2) {
|
|
return (
|
|
_1.binding == _2.binding && _1.descriptorType == _2.descriptorType &&
|
|
_1.descriptorCount == _2.descriptorCount &&
|
|
_1.stageFlags == _2.stageFlags &&
|
|
_1.pImmutableSamplers == _2.pImmutableSamplers);
|
|
}
|
|
|
|
#endif /* USE_VULKAN_API */
|