mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Vulkan] Remove GLSL Code Gen (#91912)
@bypass-github-export-checks GLSL Code Gen is not used, so this diff removes - GLSL parts of ShaderSource - Anything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself - gen_vulkan_glsl script Plus some additional refactoring Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/91912 Approved by: https://github.com/mcr229
This commit is contained in:
committed by
PyTorch MergeBot
parent
28eb3c8faf
commit
ec94cbc66a
@ -266,7 +266,6 @@ option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON)
|
|||||||
option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
|
option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
|
||||||
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
|
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
|
||||||
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
|
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
|
||||||
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
|
|
||||||
# option USE_XNNPACK: try to enable xnnpack by default.
|
# option USE_XNNPACK: try to enable xnnpack by default.
|
||||||
option(USE_XNNPACK "Use XNNPACK" ON)
|
option(USE_XNNPACK "Use XNNPACK" ON)
|
||||||
option(USE_ZMQ "Use ZMQ" OFF)
|
option(USE_ZMQ "Use ZMQ" OFF)
|
||||||
@ -746,9 +745,6 @@ if(USE_VULKAN)
|
|||||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
|
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_VULKAN_SHADERC_RUNTIME)
|
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME")
|
|
||||||
endif()
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(BUILD_LITE_INTERPRETER)
|
if(BUILD_LITE_INTERPRETER)
|
||||||
|
@ -1,115 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import glob
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from torchgen.code_template import CodeTemplate
|
|
||||||
|
|
||||||
H_NAME = "glsl.h"
|
|
||||||
CPP_NAME = "glsl.cpp"
|
|
||||||
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
|
|
||||||
|
|
||||||
def findAllGlsls(path):
|
|
||||||
vexs = glob.glob(os.path.join(path, '**', '*.glsl'), recursive=True)
|
|
||||||
output = []
|
|
||||||
for f in vexs:
|
|
||||||
if len(f) > 1:
|
|
||||||
output.append(f)
|
|
||||||
output.sort()
|
|
||||||
return output
|
|
||||||
|
|
||||||
def getName(filePath):
|
|
||||||
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
|
|
||||||
|
|
||||||
def genCppH(hFilePath, cppFilePath, templateGlslPaths, tmpDirPath, env):
|
|
||||||
print("hFilePath:{}".format(hFilePath))
|
|
||||||
print("cppFilePath:{}".format(cppFilePath))
|
|
||||||
h = "#pragma once\n"
|
|
||||||
nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
|
|
||||||
nsend = "\n} } } //namespace at::native::vulkan\n"
|
|
||||||
|
|
||||||
h += nsbegin
|
|
||||||
|
|
||||||
cpp = "#include <ATen/native/vulkan/{}>".format(H_NAME)
|
|
||||||
cpp += nsbegin
|
|
||||||
|
|
||||||
for templateGlslPath in templateGlslPaths:
|
|
||||||
name = getName(templateGlslPath)
|
|
||||||
h += "extern const char* " + name + ";\n"
|
|
||||||
cpp += "const char* " + name + " = \n"
|
|
||||||
|
|
||||||
codeTemplate = CodeTemplate.from_file(templateGlslPath)
|
|
||||||
srcPath = tmpDirPath + "/" + name + ".glsl"
|
|
||||||
content = codeTemplate.substitute(env)
|
|
||||||
|
|
||||||
lines = content.split("\n")
|
|
||||||
for l in lines:
|
|
||||||
if (len(l) < 1):
|
|
||||||
continue
|
|
||||||
cpp += "\"" + l + "\\n\"\n"
|
|
||||||
|
|
||||||
cpp += ";\n"
|
|
||||||
|
|
||||||
cpp += nsend
|
|
||||||
h += nsend
|
|
||||||
|
|
||||||
with open(hFilePath, "w") as f:
|
|
||||||
f.write(h)
|
|
||||||
with open(cppFilePath, "w") as f:
|
|
||||||
f.write(cpp)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_arg_env(items):
|
|
||||||
d = {}
|
|
||||||
if items:
|
|
||||||
for item in items:
|
|
||||||
tokens = item.split("=")
|
|
||||||
key = tokens[0].strip()
|
|
||||||
value = tokens[1].strip()
|
|
||||||
d[key] = value
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
parser = argparse.ArgumentParser(description='Generate glsl.cpp and glsl.h containing glsl sources')
|
|
||||||
parser.add_argument(
|
|
||||||
'-i',
|
|
||||||
'--glsl-path',
|
|
||||||
help='path to directory with glsl to process',
|
|
||||||
required=True,
|
|
||||||
default='.')
|
|
||||||
parser.add_argument(
|
|
||||||
'-o',
|
|
||||||
'--output-path',
|
|
||||||
help='path to directory to generate glsl.h glsl.cpp (cpp namespace at::native::vulkan)',
|
|
||||||
required=True)
|
|
||||||
parser.add_argument(
|
|
||||||
'-t',
|
|
||||||
'--tmp-dir-path',
|
|
||||||
required=True,
|
|
||||||
help='/tmp')
|
|
||||||
parser.add_argument(
|
|
||||||
"--env",
|
|
||||||
metavar="KEY=VALUE",
|
|
||||||
nargs='*',
|
|
||||||
help="Set a number of key-value pairs")
|
|
||||||
options = parser.parse_args()
|
|
||||||
if not os.path.exists(options.tmp_dir_path):
|
|
||||||
os.makedirs(options.tmp_dir_path)
|
|
||||||
env = DEFAULT_ENV
|
|
||||||
for key, value in parse_arg_env(options.env).items():
|
|
||||||
env[key] = value
|
|
||||||
|
|
||||||
if not os.path.exists(options.output_path):
|
|
||||||
os.makedirs(options.output_path)
|
|
||||||
|
|
||||||
glsls = findAllGlsls(options.glsl_path)
|
|
||||||
genCppH(
|
|
||||||
options.output_path + "/" + H_NAME, options.output_path + "/" + CPP_NAME,
|
|
||||||
glsls,
|
|
||||||
tmpDirPath=options.tmp_dir_path,
|
|
||||||
env=env)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sys.exit(main(sys.argv))
|
|
@ -1,9 +1,5 @@
|
|||||||
#include <ATen/native/vulkan/api/Shader.h>
|
#include <ATen/native/vulkan/api/Shader.h>
|
||||||
|
|
||||||
#ifdef USE_VULKAN_SHADERC_RUNTIME
|
|
||||||
#include <shaderc/shaderc.hpp>
|
|
||||||
#endif /* USE_VULKAN_SHADERC_RUNTIME */
|
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace native {
|
namespace native {
|
||||||
namespace vulkan {
|
namespace vulkan {
|
||||||
@ -14,38 +10,19 @@ namespace api {
|
|||||||
//
|
//
|
||||||
|
|
||||||
ShaderInfo::ShaderInfo()
|
ShaderInfo::ShaderInfo()
|
||||||
: type(ShaderInfo::Type::SPIRV),
|
: src_code{
|
||||||
src_code{
|
nullptr,
|
||||||
.spirv =
|
0u,
|
||||||
{
|
|
||||||
nullptr,
|
|
||||||
0u,
|
|
||||||
},
|
|
||||||
} {}
|
} {}
|
||||||
|
|
||||||
ShaderInfo::ShaderInfo(std::string name, const char* const glsl_src)
|
|
||||||
: type(ShaderInfo::Type::GLSL),
|
|
||||||
src_code{
|
|
||||||
.glsl =
|
|
||||||
{
|
|
||||||
glsl_src,
|
|
||||||
0u,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
kernel_name{std::move(name)} {}
|
|
||||||
|
|
||||||
ShaderInfo::ShaderInfo(
|
ShaderInfo::ShaderInfo(
|
||||||
std::string name,
|
std::string name,
|
||||||
const uint32_t* const spirv_bin,
|
const uint32_t* const spirv_bin,
|
||||||
const uint32_t size,
|
const uint32_t size,
|
||||||
const std::vector<VkDescriptorType>& layout)
|
const std::vector<VkDescriptorType>& layout)
|
||||||
: type(Type::SPIRV),
|
: src_code{
|
||||||
src_code{
|
spirv_bin,
|
||||||
.spirv =
|
size,
|
||||||
{
|
|
||||||
spirv_bin,
|
|
||||||
size,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
kernel_name{std::move(name)},
|
kernel_name{std::move(name)},
|
||||||
kernel_layout{layout} {}
|
kernel_layout{layout} {}
|
||||||
@ -58,13 +35,9 @@ ShaderInfo::ShaderInfo(
|
|||||||
const std::vector<uint32_t>& tile_size,
|
const std::vector<uint32_t>& tile_size,
|
||||||
const StorageType bias_storage_type,
|
const StorageType bias_storage_type,
|
||||||
const StorageType weight_storage_type)
|
const StorageType weight_storage_type)
|
||||||
: type(Type::SPIRV),
|
: src_code{
|
||||||
src_code{
|
spirv_bin,
|
||||||
.spirv =
|
size,
|
||||||
{
|
|
||||||
spirv_bin,
|
|
||||||
size,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
kernel_name{std::move(name)},
|
kernel_name{std::move(name)},
|
||||||
kernel_layout{layout},
|
kernel_layout{layout},
|
||||||
@ -77,17 +50,9 @@ ShaderInfo::ShaderInfo(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
|
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
|
||||||
if (_1.type != _2.type) {
|
return (
|
||||||
return false;
|
_1.src_code.bin == _2.src_code.bin &&
|
||||||
}
|
_1.src_code.size == _2.src_code.size);
|
||||||
|
|
||||||
if (_1.type == ShaderInfo::Type::SPIRV) {
|
|
||||||
return (
|
|
||||||
_1.src_code.spirv.bin == _2.src_code.spirv.bin &&
|
|
||||||
_1.src_code.spirv.size == _2.src_code.spirv.size);
|
|
||||||
} else {
|
|
||||||
return (_1.src_code.glsl.src == _2.src_code.glsl.src);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -153,8 +118,8 @@ void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
|
|||||||
|
|
||||||
ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source)
|
ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source)
|
||||||
: device_(device), handle_{VK_NULL_HANDLE} {
|
: device_(device), handle_{VK_NULL_HANDLE} {
|
||||||
const uint32_t* code = source.src_code.spirv.bin;
|
const uint32_t* code = source.src_code.bin;
|
||||||
uint32_t size = source.src_code.spirv.size;
|
uint32_t size = source.src_code.size;
|
||||||
|
|
||||||
const VkShaderModuleCreateInfo shader_module_create_info{
|
const VkShaderModuleCreateInfo shader_module_create_info{
|
||||||
VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
|
VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
|
||||||
|
@ -45,17 +45,9 @@ class ShaderLayout final {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ShaderInfo final {
|
struct ShaderInfo final {
|
||||||
enum class Type { GLSL, SPIRV } type;
|
struct {
|
||||||
|
const uint32_t* bin;
|
||||||
union {
|
uint32_t size;
|
||||||
struct {
|
|
||||||
const char* src; // Null-terminated
|
|
||||||
uint32_t unused; // padding
|
|
||||||
} glsl;
|
|
||||||
struct {
|
|
||||||
const uint32_t* bin;
|
|
||||||
uint32_t size;
|
|
||||||
} spirv;
|
|
||||||
} src_code;
|
} src_code;
|
||||||
|
|
||||||
std::string kernel_name{""};
|
std::string kernel_name{""};
|
||||||
@ -171,8 +163,7 @@ class ShaderCache final {
|
|||||||
|
|
||||||
struct Hasher {
|
struct Hasher {
|
||||||
inline size_t operator()(const ShaderInfo& source) const {
|
inline size_t operator()(const ShaderInfo& source) const {
|
||||||
return c10::get_hash(
|
return c10::get_hash(source.src_code.bin, source.src_code.size);
|
||||||
source.type, source.src_code.spirv.bin, source.src_code.spirv.size);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -6,18 +6,9 @@
|
|||||||
#include <ATen/core/Tensor.h>
|
#include <ATen/core/Tensor.h>
|
||||||
#include <ATen/native/vulkan/api/api.h>
|
#include <ATen/native/vulkan/api/api.h>
|
||||||
#include <ATen/native/vulkan/ops/Convert.h>
|
#include <ATen/native/vulkan/ops/Convert.h>
|
||||||
|
|
||||||
#define CONCAT_LITERALS(a, b) #a #b
|
|
||||||
#ifdef USE_VULKAN_SHADERC_RUNTIME
|
|
||||||
#include <ATen/native/vulkan/glsl.h>
|
|
||||||
#define VK_KERNEL(name) \
|
|
||||||
::at::native::vulkan::api::ShaderInfo { \
|
|
||||||
CONCAT_LITERALS(vulkan., name), name##_glsl, \
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#include <ATen/native/vulkan/spv.h>
|
#include <ATen/native/vulkan/spv.h>
|
||||||
|
|
||||||
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
|
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
|
||||||
#endif /* USE_VULKAN_SHADERC_RUNTIME */
|
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace native {
|
namespace native {
|
||||||
|
@ -175,7 +175,6 @@ function(caffe2_print_configuration_summary)
|
|||||||
if(${USE_VULKAN})
|
if(${USE_VULKAN})
|
||||||
message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}")
|
message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}")
|
||||||
message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}")
|
message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}")
|
||||||
message(STATUS " USE_VULKAN_SHADERC_RUNTIME : ${USE_VULKAN_SHADERC_RUNTIME}")
|
|
||||||
endif()
|
endif()
|
||||||
message(STATUS " USE_PROF : ${USE_PROF}")
|
message(STATUS " USE_PROF : ${USE_PROF}")
|
||||||
message(STATUS " USE_QNNPACK : ${USE_QNNPACK}")
|
message(STATUS " USE_QNNPACK : ${USE_QNNPACK}")
|
||||||
|
@ -13,71 +13,45 @@ if(USE_VULKAN_FP16_INFERENCE)
|
|||||||
list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f")
|
list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_VULKAN_SHADERC_RUNTIME)
|
# Precompiling shaders
|
||||||
set(PYTHONPATH "$ENV{PYTHONPATH}")
|
if(ANDROID)
|
||||||
set(NEW_PYTHONPATH ${PYTHONPATH})
|
if(NOT ANDROID_NDK)
|
||||||
list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
|
message(FATAL_ERROR "ANDROID_NDK not set")
|
||||||
set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
|
endif()
|
||||||
execute_process(
|
|
||||||
COMMAND
|
set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc")
|
||||||
"${PYTHON_EXECUTABLE}"
|
else()
|
||||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_glsl.py
|
find_program(
|
||||||
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
|
GLSLC_PATH glslc
|
||||||
--output-path ${VULKAN_GEN_OUTPUT_PATH}
|
PATHS
|
||||||
--tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/glsl
|
ENV VULKAN_SDK
|
||||||
--env ${VULKAN_GEN_ARG_ENV}
|
PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin"
|
||||||
RESULT_VARIABLE error_code)
|
PATHS "$ENV{VULKAN_SDK}/bin"
|
||||||
set(ENV{PYTHONPATH} ${PYTHONPATH})
|
)
|
||||||
|
|
||||||
|
if(NOT GLSLC_PATH)
|
||||||
|
message(FATAL_ERROR "USE_VULKAN glslc not found")
|
||||||
|
endif(NOT GLSLC_PATH)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(PYTHONPATH "$ENV{PYTHONPATH}")
|
||||||
|
set(NEW_PYTHONPATH ${PYTHONPATH})
|
||||||
|
list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
|
||||||
|
set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
|
||||||
|
execute_process(
|
||||||
|
COMMAND
|
||||||
|
"${PYTHON_EXECUTABLE}"
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
|
||||||
|
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
|
||||||
|
--output-path ${VULKAN_GEN_OUTPUT_PATH}
|
||||||
|
--glslc-path=${GLSLC_PATH}
|
||||||
|
--tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv
|
||||||
|
--env ${VULKAN_GEN_ARG_ENV}
|
||||||
|
RESULT_VARIABLE error_code)
|
||||||
|
set(ENV{PYTHONPATH} ${PYTHONPATH})
|
||||||
|
|
||||||
if(error_code)
|
if(error_code)
|
||||||
message(FATAL_ERROR "Failed to gen glsl.h and glsl.cpp with shaders sources for Vulkan backend")
|
message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/glsl.cpp)
|
set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp)
|
||||||
return()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(NOT USE_VULKAN_SHADERC_RUNTIME)
|
|
||||||
# Precompiling shaders
|
|
||||||
if(ANDROID)
|
|
||||||
if(NOT ANDROID_NDK)
|
|
||||||
message(FATAL_ERROR "ANDROID_NDK not set")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc")
|
|
||||||
else()
|
|
||||||
find_program(
|
|
||||||
GLSLC_PATH glslc
|
|
||||||
PATHS
|
|
||||||
ENV VULKAN_SDK
|
|
||||||
PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin"
|
|
||||||
PATHS "$ENV{VULKAN_SDK}/bin"
|
|
||||||
)
|
|
||||||
|
|
||||||
if(NOT GLSLC_PATH)
|
|
||||||
message(FATAL_ERROR "USE_VULKAN glslc not found")
|
|
||||||
endif(NOT GLSLC_PATH)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(PYTHONPATH "$ENV{PYTHONPATH}")
|
|
||||||
set(NEW_PYTHONPATH ${PYTHONPATH})
|
|
||||||
list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
|
|
||||||
set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
|
|
||||||
execute_process(
|
|
||||||
COMMAND
|
|
||||||
"${PYTHON_EXECUTABLE}"
|
|
||||||
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
|
|
||||||
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
|
|
||||||
--output-path ${VULKAN_GEN_OUTPUT_PATH}
|
|
||||||
--glslc-path=${GLSLC_PATH}
|
|
||||||
--tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv
|
|
||||||
--env ${VULKAN_GEN_ARG_ENV}
|
|
||||||
RESULT_VARIABLE error_code)
|
|
||||||
set(ENV{PYTHONPATH} ${PYTHONPATH})
|
|
||||||
|
|
||||||
if(error_code)
|
|
||||||
message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp)
|
|
||||||
endif()
|
|
||||||
|
@ -29,60 +29,6 @@ if(ANDROID)
|
|||||||
list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR})
|
list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR})
|
||||||
list(APPEND Vulkan_LIBS VulkanWrapper)
|
list(APPEND Vulkan_LIBS VulkanWrapper)
|
||||||
|
|
||||||
# Shaderc
|
|
||||||
if(USE_VULKAN_SHADERC_RUNTIME)
|
|
||||||
# Shaderc from ANDROID_NDK
|
|
||||||
set(Shaderc_ANDROID_NDK_INCLUDE_DIR "${ANDROID_NDK}/sources/third_party/shaderc/include")
|
|
||||||
message(STATUS "Shaderc_ANDROID_NDK_INCLUDE_DIR:${Shaderc_ANDROID_NDK_INCLUDE_DIR}")
|
|
||||||
|
|
||||||
find_path(
|
|
||||||
GOOGLE_SHADERC_INCLUDE_DIRS
|
|
||||||
NAMES shaderc/shaderc.hpp
|
|
||||||
PATHS "${Shaderc_ANDROID_NDK_INCLUDE_DIR}")
|
|
||||||
|
|
||||||
set(Shaderc_ANDROID_NDK_LIB_DIR "${ANDROID_NDK}/sources/third_party/shaderc/libs/${ANDROID_STL}/${ANDROID_ABI}")
|
|
||||||
message(STATUS "Shaderc_ANDROID_NDK_LIB_DIR:${Shaderc_ANDROID_NDK_LIB_DIR}")
|
|
||||||
|
|
||||||
find_library(
|
|
||||||
GOOGLE_SHADERC_LIBRARIES
|
|
||||||
NAMES shaderc
|
|
||||||
PATHS "${Shaderc_ANDROID_NDK_LIB_DIR}")
|
|
||||||
|
|
||||||
# Shaderc in NDK is not prebuilt
|
|
||||||
if(NOT GOOGLE_SHADERC_LIBRARIES)
|
|
||||||
set(NDK_SHADERC_DIR "${ANDROID_NDK}/sources/third_party/shaderc")
|
|
||||||
set(NDK_BUILD_CMD "${ANDROID_NDK}/ndk-build")
|
|
||||||
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${NDK_BUILD_CMD}
|
|
||||||
NDK_PROJECT_PATH=${NDK_SHADERC_DIR}
|
|
||||||
APP_BUILD_SCRIPT=${NDK_SHADERC_DIR}/Android.mk
|
|
||||||
APP_PLATFORM=${ANDROID_PLATFORM}
|
|
||||||
APP_STL=${ANDROID_STL}
|
|
||||||
APP_ABI=${ANDROID_ABI}
|
|
||||||
libshaderc_combined -j8
|
|
||||||
WORKING_DIRECTORY "${NDK_SHADERC_DIR}"
|
|
||||||
RESULT_VARIABLE error_code)
|
|
||||||
|
|
||||||
if(error_code)
|
|
||||||
message(FATAL_ERROR "Failed to build ANDROID_NDK Shaderc error_code:${error_code}")
|
|
||||||
else()
|
|
||||||
unset(GOOGLE_SHADERC_LIBRARIES CACHE)
|
|
||||||
find_library(
|
|
||||||
GOOGLE_SHADERC_LIBRARIES
|
|
||||||
NAMES shaderc
|
|
||||||
HINTS "${Shaderc_ANDROID_NDK_LIB_DIR}")
|
|
||||||
endif()
|
|
||||||
endif(NOT GOOGLE_SHADERC_LIBRARIES)
|
|
||||||
|
|
||||||
if(GOOGLE_SHADERC_INCLUDE_DIRS AND GOOGLE_SHADERC_LIBRARIES)
|
|
||||||
message(STATUS "Shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}")
|
|
||||||
message(STATUS "Shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS})
|
|
||||||
list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES})
|
|
||||||
endif(USE_VULKAN_SHADERC_RUNTIME)
|
|
||||||
else()
|
else()
|
||||||
find_package(Vulkan)
|
find_package(Vulkan)
|
||||||
|
|
||||||
@ -95,32 +41,4 @@ else()
|
|||||||
|
|
||||||
set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR})
|
set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR})
|
||||||
set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY})
|
set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY})
|
||||||
|
|
||||||
if(USE_VULKAN_SHADERC_RUNTIME)
|
|
||||||
find_path(
|
|
||||||
GOOGLE_SHADERC_INCLUDE_DIRS
|
|
||||||
NAMES shaderc/shaderc.hpp
|
|
||||||
PATHS ${GOOGLE_SHADERC_INCLUDE_SEARCH_PATH})
|
|
||||||
|
|
||||||
find_library(
|
|
||||||
GOOGLE_SHADERC_LIBRARIES
|
|
||||||
NAMES shaderc_combined
|
|
||||||
PATHS ${GOOGLE_SHADERC_LIBRARY_SEARCH_PATH})
|
|
||||||
|
|
||||||
find_package_handle_standard_args(
|
|
||||||
Shaderc
|
|
||||||
DEFAULT_MSG
|
|
||||||
GOOGLE_SHADERC_INCLUDE_DIRS
|
|
||||||
GOOGLE_SHADERC_LIBRARIES)
|
|
||||||
|
|
||||||
if(NOT Shaderc_FOUND)
|
|
||||||
message(FATAL_ERROR "USE_VULKAN: Shaderc not found in VULKAN_SDK")
|
|
||||||
else()
|
|
||||||
message(STATUS "shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}")
|
|
||||||
message(STATUS "shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS})
|
|
||||||
list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES})
|
|
||||||
endif(USE_VULKAN_SHADERC_RUNTIME)
|
|
||||||
endif()
|
endif()
|
||||||
|
@ -157,9 +157,6 @@ if [ -n "${USE_VULKAN}" ]; then
|
|||||||
if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then
|
if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then
|
||||||
CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON")
|
CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON")
|
||||||
fi
|
fi
|
||||||
if [ -n "${USE_VULKAN_SHADERC_RUNTIME}" ]; then
|
|
||||||
CMAKE_ARGS+=("-DUSE_VULKAN_SHADERC_RUNTIME=ON")
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Use-specified CMake arguments go last to allow overridding defaults
|
# Use-specified CMake arguments go last to allow overridding defaults
|
||||||
|
@ -211,18 +211,6 @@ def define_tools_targets(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"gen_vulkan_spv.py",
|
"gen_vulkan_spv.py",
|
||||||
],
|
],
|
||||||
base_module = "",
|
|
||||||
deps = [
|
|
||||||
torchgen_deps,
|
|
||||||
":gen_aten_vulkan_glsl_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python_library(
|
|
||||||
name = "gen_aten_vulkan_glsl_lib",
|
|
||||||
srcs = [
|
|
||||||
"gen_vulkan_glsl.py",
|
|
||||||
],
|
|
||||||
base_module = "tools",
|
base_module = "tools",
|
||||||
deps = [
|
deps = [
|
||||||
torchgen_deps,
|
torchgen_deps,
|
||||||
@ -231,12 +219,11 @@ def define_tools_targets(
|
|||||||
|
|
||||||
python_binary(
|
python_binary(
|
||||||
name = "gen_aten_vulkan_spv_bin",
|
name = "gen_aten_vulkan_spv_bin",
|
||||||
main_module = "gen_vulkan_spv",
|
main_module = "tools.gen_vulkan_spv",
|
||||||
visibility = [
|
visibility = [
|
||||||
"PUBLIC",
|
"PUBLIC",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":gen_aten_vulkan_glsl_lib",
|
|
||||||
":gen_aten_vulkan_spv_lib",
|
":gen_aten_vulkan_spv_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -249,7 +236,6 @@ def define_tools_targets(
|
|||||||
contacts = contacts,
|
contacts = contacts,
|
||||||
visibility = ["PUBLIC"],
|
visibility = ["PUBLIC"],
|
||||||
deps = [
|
deps = [
|
||||||
":gen_aten_vulkan_glsl_lib",
|
|
||||||
":gen_aten_vulkan_spv_lib",
|
":gen_aten_vulkan_spv_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,111 +0,0 @@
|
|||||||
import copy
|
|
||||||
import os
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from torchgen.code_template import CodeTemplate
|
|
||||||
from yaml.constructor import ConstructorError
|
|
||||||
from yaml.nodes import MappingNode
|
|
||||||
|
|
||||||
try:
|
|
||||||
from yaml import CLoader as Loader
|
|
||||||
except ImportError:
|
|
||||||
from yaml import Loader # type: ignore[misc]
|
|
||||||
|
|
||||||
# https://gist.github.com/pypt/94d747fe5180851196eb
|
|
||||||
class UniqueKeyLoader(Loader):
|
|
||||||
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
|
|
||||||
if not isinstance(node, MappingNode):
|
|
||||||
raise ConstructorError(
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
"expected a mapping node, but found %s" % node.id,
|
|
||||||
node.start_mark,
|
|
||||||
)
|
|
||||||
mapping = {}
|
|
||||||
for key_node, value_node in node.value:
|
|
||||||
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
|
|
||||||
try:
|
|
||||||
hash(key)
|
|
||||||
except TypeError as e:
|
|
||||||
raise ConstructorError(
|
|
||||||
"while constructing a mapping",
|
|
||||||
node.start_mark,
|
|
||||||
"found unacceptable key ",
|
|
||||||
key_node.start_mark,
|
|
||||||
) from e
|
|
||||||
# check for duplicate keys
|
|
||||||
if key in mapping:
|
|
||||||
raise ConstructorError(
|
|
||||||
"while constructing a mapping",
|
|
||||||
node.start_mark,
|
|
||||||
"found duplicate key",
|
|
||||||
key_node.start_mark,
|
|
||||||
)
|
|
||||||
value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call]
|
|
||||||
mapping[key] = value
|
|
||||||
return mapping
|
|
||||||
|
|
||||||
|
|
||||||
class GLSLGenerator(object):
|
|
||||||
standard_header = """
|
|
||||||
#version 450 core
|
|
||||||
#define PRECISION $precision
|
|
||||||
#define FORMAT $format
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self): # type: ignore[no-untyped-def]
|
|
||||||
self.ops_template_params = {}
|
|
||||||
|
|
||||||
def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
|
|
||||||
all_template_params = OrderedDict()
|
|
||||||
with open(parameters_yaml_file, "r") as f:
|
|
||||||
contents = yaml.load(f, Loader=UniqueKeyLoader)
|
|
||||||
for key in contents:
|
|
||||||
all_template_params[key] = contents[key]
|
|
||||||
self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call]
|
|
||||||
|
|
||||||
def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def]
|
|
||||||
for op in all_template_params:
|
|
||||||
if op in self.ops_template_params:
|
|
||||||
raise KeyError(f"{op} params file has already been parsed")
|
|
||||||
op_params_default_vals = all_template_params[op][
|
|
||||||
"parameter_names_with_default_values"
|
|
||||||
]
|
|
||||||
template_params_set = set(op_params_default_vals.keys())
|
|
||||||
self.ops_template_params[op] = []
|
|
||||||
self.ops_template_params[op].append(op_params_default_vals)
|
|
||||||
op_template_params_values = all_template_params[op]["parameter_values"]
|
|
||||||
for param_vals in op_template_params_values:
|
|
||||||
param_vals_set = set(param_vals.keys())
|
|
||||||
invalid_keys = param_vals_set - template_params_set
|
|
||||||
if (len(invalid_keys)) > 0:
|
|
||||||
raise KeyError(f"Invalid keys {invalid_keys} are found")
|
|
||||||
param_vals_copy = copy.deepcopy(op_params_default_vals)
|
|
||||||
for key in param_vals:
|
|
||||||
param_vals_copy[key] = param_vals[key]
|
|
||||||
self.ops_template_params[op].append(param_vals_copy)
|
|
||||||
|
|
||||||
def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def]
|
|
||||||
glsl_template_name = os.path.basename(glsl_template_in)
|
|
||||||
op_name, extension_name = glsl_template_name.split(".")
|
|
||||||
if extension_name != "glslt":
|
|
||||||
raise TypeError(f"invalid file type for glsl template {extension_name}")
|
|
||||||
if op_name not in self.ops_template_params:
|
|
||||||
raise KeyError(f"{op_name} params have not been populated")
|
|
||||||
code_template = CodeTemplate.from_file(glsl_template_in)
|
|
||||||
for template_params in self.ops_template_params[op_name]:
|
|
||||||
content = GLSLGenerator.standard_header
|
|
||||||
param_vals_string = "x".join([str(i) for i in template_params.values()])
|
|
||||||
output_file_name = op_name + "_" + param_vals_string + ".glsl"
|
|
||||||
content += code_template.substitute(template_params)
|
|
||||||
output_file = os.path.join(out_dir, output_file_name)
|
|
||||||
with open(output_file, "w") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
|
|
||||||
# Remove this
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pass
|
|
@ -2,21 +2,122 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import array
|
import array
|
||||||
|
import copy
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import yaml
|
||||||
|
from collections import OrderedDict
|
||||||
from torchgen.code_template import CodeTemplate
|
from torchgen.code_template import CodeTemplate
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import Any, Dict, List
|
||||||
|
from yaml.constructor import ConstructorError
|
||||||
|
from yaml.nodes import MappingNode
|
||||||
|
|
||||||
from tools.gen_vulkan_glsl import GLSLGenerator
|
try:
|
||||||
|
from yaml import CLoader as Loader
|
||||||
|
except ImportError:
|
||||||
|
from yaml import Loader # type: ignore[misc]
|
||||||
|
|
||||||
H_NAME = "spv.h"
|
H_NAME = "spv.h"
|
||||||
CPP_NAME = "spv.cpp"
|
CPP_NAME = "spv.cpp"
|
||||||
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
|
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
|
||||||
|
|
||||||
|
# https://gist.github.com/pypt/94d747fe5180851196eb
|
||||||
|
class UniqueKeyLoader(Loader):
|
||||||
|
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
|
||||||
|
if not isinstance(node, MappingNode):
|
||||||
|
raise ConstructorError(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
"expected a mapping node, but found %s" % node.id,
|
||||||
|
node.start_mark,
|
||||||
|
)
|
||||||
|
mapping = {}
|
||||||
|
for key_node, value_node in node.value:
|
||||||
|
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
|
||||||
|
try:
|
||||||
|
hash(key)
|
||||||
|
except TypeError as e:
|
||||||
|
raise ConstructorError(
|
||||||
|
"while constructing a mapping",
|
||||||
|
node.start_mark,
|
||||||
|
"found unacceptable key ",
|
||||||
|
key_node.start_mark,
|
||||||
|
) from e
|
||||||
|
# check for duplicate keys
|
||||||
|
if key in mapping:
|
||||||
|
raise ConstructorError(
|
||||||
|
"while constructing a mapping",
|
||||||
|
node.start_mark,
|
||||||
|
"found duplicate key",
|
||||||
|
key_node.start_mark,
|
||||||
|
)
|
||||||
|
value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call]
|
||||||
|
mapping[key] = value
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
class VulkanShaderGenerator(object):
|
||||||
|
standard_header = """
|
||||||
|
#version 450 core
|
||||||
|
#define PRECISION $precision
|
||||||
|
#define FORMAT $format
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self: "VulkanShaderGenerator") -> None:
|
||||||
|
self.ops_template_params: Dict[Any, Any] = {}
|
||||||
|
|
||||||
|
def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
|
||||||
|
all_template_params = OrderedDict()
|
||||||
|
with open(parameters_yaml_file, "r") as f:
|
||||||
|
contents = yaml.load(f, Loader=UniqueKeyLoader)
|
||||||
|
for key in contents:
|
||||||
|
all_template_params[key] = contents[key]
|
||||||
|
self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
|
def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def]
|
||||||
|
for op in all_template_params:
|
||||||
|
if op in self.ops_template_params:
|
||||||
|
raise KeyError(f"{op} params file has already been parsed")
|
||||||
|
op_params_default_vals = all_template_params[op][
|
||||||
|
"parameter_names_with_default_values"
|
||||||
|
]
|
||||||
|
template_params_set = set(op_params_default_vals.keys())
|
||||||
|
self.ops_template_params[op] = []
|
||||||
|
self.ops_template_params[op].append(op_params_default_vals)
|
||||||
|
op_template_params_values = all_template_params[op]["parameter_values"]
|
||||||
|
for param_vals in op_template_params_values:
|
||||||
|
param_vals_set = set(param_vals.keys())
|
||||||
|
missing_keys = template_params_set - param_vals_set
|
||||||
|
invalid_keys = param_vals_set - template_params_set
|
||||||
|
if (len(invalid_keys)) > 0:
|
||||||
|
raise KeyError(f"Invalid keys {invalid_keys} are found")
|
||||||
|
param_vals_copy = copy.deepcopy(op_params_default_vals)
|
||||||
|
for key in param_vals:
|
||||||
|
param_vals_copy[key] = param_vals[key]
|
||||||
|
self.ops_template_params[op].append(param_vals_copy)
|
||||||
|
|
||||||
|
def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def]
|
||||||
|
glsl_template_name = os.path.basename(glsl_template_in)
|
||||||
|
op_name, extension_name = glsl_template_name.split(".")
|
||||||
|
if extension_name != "glslt":
|
||||||
|
raise TypeError(f"invalid file type for glsl template {extension_name}")
|
||||||
|
if op_name not in self.ops_template_params:
|
||||||
|
raise KeyError(f"{op_name} params have not been populated")
|
||||||
|
code_template = CodeTemplate.from_file(glsl_template_in)
|
||||||
|
for template_params in self.ops_template_params[op_name]:
|
||||||
|
content = VulkanShaderGenerator.standard_header
|
||||||
|
param_vals_string = "x".join([str(i) for i in template_params.values()])
|
||||||
|
output_file_name = op_name + "_" + param_vals_string + ".glsl"
|
||||||
|
content += code_template.substitute(template_params)
|
||||||
|
output_file = os.path.join(out_dir, output_file_name)
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ShaderInfo:
|
class ShaderInfo:
|
||||||
@ -25,38 +126,44 @@ class ShaderInfo:
|
|||||||
weight_storage_type: str = ""
|
weight_storage_type: str = ""
|
||||||
bias_storage_type: str = ""
|
bias_storage_type: str = ""
|
||||||
|
|
||||||
def getName(filePath):
|
def getName(filePath: str) -> str:
|
||||||
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
|
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
|
||||||
|
|
||||||
def isDescriptorLine(lineStr):
|
def isDescriptorLine(lineStr: str) -> bool:
|
||||||
descriptorLineId = r"^layout\(set"
|
descriptorLineId = r"^layout\(set"
|
||||||
return re.search(descriptorLineId, lineStr)
|
return re.search(descriptorLineId, lineStr) is not None
|
||||||
|
|
||||||
def isTileSizeLine(lineStr):
|
def isTileSizeLine(lineStr: str) -> bool:
|
||||||
tile_size_id = r"^ \* TILE_SIZE = \("
|
tile_size_id = r"^ \* TILE_SIZE = \("
|
||||||
return re.search(tile_size_id, lineStr)
|
return re.search(tile_size_id, lineStr) is not None
|
||||||
|
|
||||||
def findTileSizes(lineStr):
|
def findTileSizes(lineStr: str) -> List[int]:
|
||||||
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
|
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
|
||||||
matches = re.search(tile_size_id, lineStr)
|
matches = re.search(tile_size_id, lineStr)
|
||||||
|
if matches is None:
|
||||||
|
raise AssertionError("matches is None in findTileSizes")
|
||||||
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
|
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
|
||||||
|
|
||||||
def isWeightStorageTypeLine(lineStr):
|
def isWeightStorageTypeLine(lineStr: str) -> bool:
|
||||||
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
|
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
|
||||||
return re.search(weight_storage_id, lineStr)
|
return re.search(weight_storage_id, lineStr) is not None
|
||||||
|
|
||||||
def getWeightStorageType(lineStr):
|
def getWeightStorageType(lineStr: str) -> str:
|
||||||
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
|
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
|
||||||
matches = re.search(weight_storage_id, lineStr)
|
matches = re.search(weight_storage_id, lineStr)
|
||||||
|
if matches is None:
|
||||||
|
raise AssertionError("matches is None in getWeightStorageType")
|
||||||
return matches.group(1)
|
return matches.group(1)
|
||||||
|
|
||||||
def isBiasStorageTypeLine(lineStr):
|
def isBiasStorageTypeLine(lineStr: str) -> bool:
|
||||||
weight_storage_id = r"^ \* BIAS_STORAGE = "
|
weight_storage_id = r"^ \* BIAS_STORAGE = "
|
||||||
return re.search(weight_storage_id, lineStr)
|
return re.search(weight_storage_id, lineStr) is not None
|
||||||
|
|
||||||
def getBiasStorageType(lineStr):
|
def getBiasStorageType(lineStr: str) -> str:
|
||||||
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
|
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
|
||||||
matches = re.search(weight_storage_id, lineStr)
|
matches = re.search(weight_storage_id, lineStr)
|
||||||
|
if matches is None:
|
||||||
|
raise AssertionError("matches is None in getBiasStorageType")
|
||||||
return matches.group(1)
|
return matches.group(1)
|
||||||
|
|
||||||
typeIdMapping = {
|
typeIdMapping = {
|
||||||
@ -73,12 +180,15 @@ storageTypeToEnum = {
|
|||||||
"": "api::StorageType::UNKNOWN",
|
"": "api::StorageType::UNKNOWN",
|
||||||
}
|
}
|
||||||
|
|
||||||
def determineDescriptorType(lineStr):
|
def determineDescriptorType(lineStr: str) -> str:
|
||||||
for identifier, typeNum in typeIdMapping.items():
|
for identifier, typeNum in typeIdMapping.items():
|
||||||
if re.search(identifier, lineStr):
|
if re.search(identifier, lineStr):
|
||||||
return typeNum
|
return typeNum
|
||||||
|
raise AssertionError(
|
||||||
|
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
|
||||||
|
)
|
||||||
|
|
||||||
def getShaderInfo(srcFilePath):
|
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
||||||
shader_info = ShaderInfo([], [], "")
|
shader_info = ShaderInfo([], [], "")
|
||||||
with open(srcFilePath, 'r') as srcFile:
|
with open(srcFilePath, 'r') as srcFile:
|
||||||
for line in srcFile:
|
for line in srcFile:
|
||||||
@ -93,14 +203,14 @@ def getShaderInfo(srcFilePath):
|
|||||||
|
|
||||||
return shader_info
|
return shader_info
|
||||||
|
|
||||||
def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
|
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
|
||||||
template_dir_path = os.path.join(src_dir_path, "templates")
|
template_dir_path = os.path.join(src_dir_path, "templates")
|
||||||
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
|
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
|
||||||
parameter_yaml_files = []
|
parameter_yaml_files = []
|
||||||
for f in vexs:
|
for f in vexs:
|
||||||
if len(f) > 1:
|
if len(f) > 1:
|
||||||
parameter_yaml_files.append(f)
|
parameter_yaml_files.append(f)
|
||||||
generator = GLSLGenerator()
|
generator = VulkanShaderGenerator()
|
||||||
for params_yaml in parameter_yaml_files:
|
for params_yaml in parameter_yaml_files:
|
||||||
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
|
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
@ -113,9 +223,20 @@ def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
|
|||||||
for glslt in templateSrcPaths:
|
for glslt in templateSrcPaths:
|
||||||
generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call]
|
generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
|
|
||||||
print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
|
def genCppH(
|
||||||
hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))
|
hFilePath: str,
|
||||||
|
cppFilePath: str,
|
||||||
|
srcDirPath: str,
|
||||||
|
glslcPath: str,
|
||||||
|
tmpDirPath: str,
|
||||||
|
env: Dict[Any, Any],
|
||||||
|
) -> None:
|
||||||
|
print(
|
||||||
|
"hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
|
||||||
|
hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
|
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
|
||||||
templateSrcPaths = []
|
templateSrcPaths = []
|
||||||
@ -142,8 +263,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
|
|||||||
codeTemplate = CodeTemplate.from_file(templateSrcPath)
|
codeTemplate = CodeTemplate.from_file(templateSrcPath)
|
||||||
srcPath = tmpDirPath + "/" + name + ".glsl"
|
srcPath = tmpDirPath + "/" + name + ".glsl"
|
||||||
content = codeTemplate.substitute(env)
|
content = codeTemplate.substitute(env)
|
||||||
with open(srcPath, 'w') as f:
|
with open(srcPath, 'w') as fw:
|
||||||
f.write(content)
|
fw.write(content)
|
||||||
|
|
||||||
spvPath = tmpDirPath + "/" + name + ".spv"
|
spvPath = tmpDirPath + "/" + name + ".spv"
|
||||||
print("spvPath {}".format(spvPath))
|
print("spvPath {}".format(spvPath))
|
||||||
@ -188,8 +309,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
|
|||||||
name = getName(spvPath)
|
name = getName(spvPath)
|
||||||
|
|
||||||
print("spvPath:{}".format(spvPath))
|
print("spvPath:{}".format(spvPath))
|
||||||
with open(spvPath, 'rb') as f:
|
with open(spvPath, 'rb') as fr:
|
||||||
next_bin = array.array('I', f.read())
|
next_bin = array.array('I', fr.read())
|
||||||
sizeBytes = 4 * len(next_bin)
|
sizeBytes = 4 * len(next_bin)
|
||||||
shader_info_bin_code.append(
|
shader_info_bin_code.append(
|
||||||
"const uint32_t {}_bin[] = {{\n {}\n}};".format(
|
"const uint32_t {}_bin[] = {{\n {}\n}};".format(
|
||||||
@ -231,13 +352,13 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
|
|||||||
cpp += nsend
|
cpp += nsend
|
||||||
h += nsend
|
h += nsend
|
||||||
|
|
||||||
with open(hFilePath, "w") as f:
|
with open(hFilePath, "w") as fw:
|
||||||
f.write(h)
|
fw.write(h)
|
||||||
with open(cppFilePath, "w") as f:
|
with open(cppFilePath, "w") as fw:
|
||||||
f.write(cpp)
|
fw.write(cpp)
|
||||||
|
|
||||||
|
|
||||||
def parse_arg_env(items):
|
def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||||
d = {}
|
d = {}
|
||||||
if items:
|
if items:
|
||||||
for item in items:
|
for item in items:
|
||||||
@ -248,7 +369,7 @@ def parse_arg_env(items):
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv: List[str]) -> int:
|
||||||
parser = argparse.ArgumentParser(description='')
|
parser = argparse.ArgumentParser(description='')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-i',
|
'-i',
|
||||||
@ -294,5 +415,7 @@ def main(argv):
|
|||||||
tmpDirPath=options.tmp_dir_path,
|
tmpDirPath=options.tmp_dir_path,
|
||||||
env=env)
|
env=env)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
sys.exit(main(sys.argv))
|
sys.exit(main(sys.argv))
|
||||||
|
@ -2,11 +2,11 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tools.gen_vulkan_glsl import GLSLGenerator
|
from tools.gen_vulkan_spv import VulkanShaderGenerator
|
||||||
from yaml.constructor import ConstructorError
|
from yaml.constructor import ConstructorError
|
||||||
|
|
||||||
|
|
||||||
class TestGLSLCodegen(unittest.TestCase):
|
class TestVulkanShaderCodegen(unittest.TestCase):
|
||||||
def test_assert_on_duplicate_key_yaml(self) -> None:
|
def test_assert_on_duplicate_key_yaml(self) -> None:
|
||||||
yaml_with_duplicate_keys = """
|
yaml_with_duplicate_keys = """
|
||||||
conv2d_pw:
|
conv2d_pw:
|
||||||
@ -37,7 +37,7 @@ conv2d_pw:
|
|||||||
TILE_SIZE_Y: 4
|
TILE_SIZE_Y: 4
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generator = GLSLGenerator() # type: ignore[no-untyped-call]
|
generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
|
||||||
with tempfile.NamedTemporaryFile(mode="w") as fp:
|
with tempfile.NamedTemporaryFile(mode="w") as fp:
|
||||||
fp.write(yaml_with_duplicate_keys)
|
fp.write(yaml_with_duplicate_keys)
|
||||||
fp.flush()
|
fp.flush()
|
||||||
@ -57,7 +57,7 @@ conv2d_pw:
|
|||||||
TILE_SIZE_Z: 2
|
TILE_SIZE_Z: 2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generator = GLSLGenerator() # type: ignore[no-untyped-call]
|
generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
|
||||||
with tempfile.NamedTemporaryFile(mode="w") as fp:
|
with tempfile.NamedTemporaryFile(mode="w") as fp:
|
||||||
fp.write(yaml_with_key_mismatch)
|
fp.write(yaml_with_key_mismatch)
|
||||||
fp.flush()
|
fp.flush()
|
||||||
@ -77,7 +77,7 @@ conv2d_pw:
|
|||||||
x = $TILE_SIZE_X + $TILE_SIZE_Y
|
x = $TILE_SIZE_X + $TILE_SIZE_Y
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generator = GLSLGenerator() # type: ignore[no-untyped-call]
|
generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
|
||||||
with tempfile.NamedTemporaryFile(mode="w") as fp:
|
with tempfile.NamedTemporaryFile(mode="w") as fp:
|
||||||
fp.write(yaml_with_key_mismatch)
|
fp.write(yaml_with_key_mismatch)
|
||||||
fp.flush()
|
fp.flush()
|
||||||
|
Reference in New Issue
Block a user