Files
pytorch/aten/src/ATen/native/vulkan/api/Shader.h
salilsdesai ec94cbc66a [Vulkan] Remove GLSL Code Gen (#91912)
@bypass-github-export-checks

GLSL Code Gen is not used, so this diff removes
- GLSL parts of ShaderSource
- Anything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself
- gen_vulkan_glsl script

Plus some additional refactoring

Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91912
Approved by: https://github.com/mcr229
2023-01-10 20:29:47 +00:00

199 lines
4.8 KiB
C++

#pragma once
#ifdef USE_VULKAN_API
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/hash.h>
#include <mutex>
namespace at {
namespace native {
namespace vulkan {
namespace api {
/**
 * Wrapper around a VkDescriptorSetLayout created for a given device from a
 * Signature (an ordered list of VkDescriptorType values).
 *
 * Copying is disallowed. Move construction is supported but move assignment
 * is not, so a custom swap is provided for use by the hash map that stores
 * these objects. The constructor, move constructor, destructor and swap are
 * all defined out of line; presumably the destructor releases the layout
 * handle on device_ -- confirm against the implementation.
 */
class ShaderLayout final {
 public:
  // Descriptor types making up the layout, small-buffer optimized for up to
  // six entries before spilling to the heap.
  using Signature = c10::SmallVector<VkDescriptorType, 6u>;

  explicit ShaderLayout(const VkDevice device, const Signature& signature);

  // Non-copyable.
  ShaderLayout(const ShaderLayout&) = delete;
  ShaderLayout& operator=(const ShaderLayout&) = delete;

  // Movable by construction only.
  ShaderLayout(ShaderLayout&&) noexcept;
  ShaderLayout& operator=(ShaderLayout&&) = delete;

  ~ShaderLayout();

  // Returns the underlying Vulkan descriptor set layout handle.
  VkDescriptorSetLayout handle() const {
    return handle_;
  }

  // Move assignment is deleted, so the hash map holding ShaderLayout values
  // relies on this custom swap instead.
  friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept;

 private:
  VkDevice device_;
  VkDescriptorSetLayout handle_;
};
/**
 * Description of a shader: its binary code, kernel name, descriptor layout
 * signature, and dispatch/storage metadata.
 *
 * ShaderInfo is the key type of ShaderCache, whose Hasher hashes
 * `src_code.bin` and `src_code.size`. Those two fields therefore must always
 * hold determinate values, which the default member initializers below
 * guarantee even if a constructor path forgets to set them.
 *
 * All constructors are defined out of line.
 */
struct ShaderInfo final {
  // Pointer to the shader binary and its size.
  // Defaulted to null/zero so a ShaderInfo can never carry an indeterminate
  // pointer or size into hashing or equality; the out-of-line constructors
  // may still aggregate-initialize this struct (allowed with default member
  // initializers since C++14).
  // NOTE(review): presumably `bin` holds SPIR-V words now that the GLSL path
  // is gone, and the unit of `size` (bytes vs. 32-bit words) is fixed by the
  // constructors -- confirm before relying on either.
  struct {
    const uint32_t* bin{nullptr};
    uint32_t size{0u};
  } src_code;

  // Name of the kernel this shader implements.
  std::string kernel_name{""};

  // Descriptor types the shader's layout is built from.
  ShaderLayout::Signature kernel_layout{};

  // Shader Metadata

  // Output tile size; defaults to a 1x1x1 tile.
  utils::uvec3 out_tile_size{1u, 1u, 1u};

  c10::SmallVector<uint32_t, 4> tile_size;

  // Storage types expected for the bias/weight operands; UNKNOWN unless a
  // constructor sets them.
  StorageType bias_storage_type{StorageType::UNKNOWN};
  StorageType weight_storage_type{StorageType::UNKNOWN};

  explicit ShaderInfo();

  // NOTE(review): this (std::string, const char*) overload looks like it may
  // be a leftover of the removed GLSL source path -- confirm it is still used.
  explicit ShaderInfo(std::string, const char*);

  explicit ShaderInfo(
      std::string,
      const uint32_t*,
      const uint32_t,
      const std::vector<VkDescriptorType>&);

  explicit ShaderInfo(
      std::string,
      const uint32_t*,
      const uint32_t,
      const std::vector<VkDescriptorType>&,
      const std::vector<uint32_t>& tile_size,
      const StorageType bias_storage_type,
      const StorageType weight_storage_type);
};

// Equality used when ShaderInfo serves as a ShaderCache key. Defined out of
// line; it must stay consistent with ShaderCache::Hasher (equal ShaderInfos
// must hash equally).
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
/**
 * Wrapper around a VkShaderModule created on a device from a ShaderInfo.
 *
 * Copying is disallowed. Move construction is supported but move assignment
 * is not, so a custom swap is provided for use by the hash map that stores
 * these objects. Construction, destruction, the move constructor and swap
 * are defined out of line; presumably the destructor releases the module on
 * device_ -- confirm against the implementation.
 */
class ShaderModule final {
 public:
  explicit ShaderModule(const VkDevice device, const ShaderInfo& source);

  // Non-copyable.
  ShaderModule(const ShaderModule&) = delete;
  ShaderModule& operator=(const ShaderModule&) = delete;

  // Movable by construction only.
  ShaderModule(ShaderModule&&) noexcept;
  ShaderModule& operator=(ShaderModule&&) = delete;

  ~ShaderModule();

  // Returns the underlying Vulkan shader module handle.
  VkShaderModule handle() const {
    return handle_;
  }

  // Move assignment is deleted, so the hash map holding ShaderModule values
  // relies on this custom swap instead.
  friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept;

 private:
  VkDevice device_;
  VkShaderModule handle_;
};
/**
 * Per-device cache mapping a descriptor-type Signature to its ShaderLayout.
 *
 * retrieve() returns the VkDescriptorSetLayout for a key and purge() empties
 * the cache. Both are defined out of line; presumably retrieve() builds and
 * inserts a layout on a miss -- confirm against the implementation.
 * Insertions may come from multiple threads, so access to cache_ is guarded
 * by cache_mutex_.
 */
class ShaderLayoutCache final {
 public:
  using Key = ShaderLayout::Signature;
  using Value = ShaderLayout;

  // Hashes a Signature by folding the hash of every descriptor type, in
  // sequence order, into a running seed with c10::hash_combine.
  struct Hasher {
    inline size_t operator()(const ShaderLayout::Signature& signature) const {
      size_t seed = 0u;
      for (auto it = signature.begin(); it != signature.end(); ++it) {
        seed = c10::hash_combine(seed, c10::get_hash(*it));
      }
      return seed;
    }
  };

  explicit ShaderLayoutCache(const VkDevice device);

  // Non-copyable.
  ShaderLayoutCache(const ShaderLayoutCache&) = delete;
  ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete;

  // Movable by construction only.
  ShaderLayoutCache(ShaderLayoutCache&&) noexcept;
  ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete;

  ~ShaderLayoutCache();

  VkDescriptorSetLayout retrieve(const Key&);
  void purge();

 private:
  // Serializes concurrent access to cache_.
  std::mutex cache_mutex_;
  VkDevice device_;
  ska::flat_hash_map<Key, Value, Hasher> cache_;
};
/**
 * Per-device cache mapping a ShaderInfo to its compiled ShaderModule.
 *
 * retrieve() returns the VkShaderModule for a key and purge() empties the
 * cache. Both are defined out of line; presumably retrieve() creates and
 * inserts a module on a miss -- confirm against the implementation.
 * Insertions may come from multiple threads, so access to cache_ is guarded
 * by cache_mutex_.
 */
class ShaderCache final {
 public:
  using Key = ShaderInfo;
  using Value = ShaderModule;

  // Hashes a ShaderInfo by its binary pointer and size only. Other fields
  // (kernel name, layout, metadata) do not participate. Assuming
  // c10::get_hash hashes raw pointers by address, this keys on the identity
  // of the binary buffer rather than its contents.
  struct Hasher {
    inline size_t operator()(const ShaderInfo& source) const {
      const auto& code = source.src_code;
      return c10::get_hash(code.bin, code.size);
    }
  };

  explicit ShaderCache(const VkDevice device);

  // Non-copyable.
  ShaderCache(const ShaderCache&) = delete;
  ShaderCache& operator=(const ShaderCache&) = delete;

  // Movable by construction only.
  ShaderCache(ShaderCache&&) noexcept;
  ShaderCache& operator=(ShaderCache&&) = delete;

  ~ShaderCache();

  VkShaderModule retrieve(const Key&);
  void purge();

 private:
  // Serializes concurrent access to cache_.
  std::mutex cache_mutex_;
  VkDevice device_;
  ska::flat_hash_map<Key, Value, Hasher> cache_;
};
} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at
/**
 * Member-wise equality for VkDescriptorSetLayoutBinding.
 *
 * The fields binding, descriptorType, descriptorCount, stageFlags and
 * pImmutableSamplers are checked one at a time in that order, and the first
 * mismatch yields false. pImmutableSamplers is compared as a pointer value,
 * not by the samplers it points to.
 */
inline bool operator==(
    const VkDescriptorSetLayoutBinding& lhs,
    const VkDescriptorSetLayoutBinding& rhs) {
  if (lhs.binding != rhs.binding) {
    return false;
  }
  if (lhs.descriptorType != rhs.descriptorType) {
    return false;
  }
  if (lhs.descriptorCount != rhs.descriptorCount) {
    return false;
  }
  if (lhs.stageFlags != rhs.stageFlags) {
    return false;
  }
  return lhs.pImmutableSamplers == rhs.pImmutableSamplers;
}
#endif /* USE_VULKAN_API */