[Vulkan] Remove GLSL Code Gen (#91912)

@bypass-github-export-checks

GLSL Code Gen is not used, so this diff removes
- GLSL parts of ShaderSource
- Anything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself
- gen_vulkan_glsl script

Plus some additional refactoring

Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91912
Approved by: https://github.com/mcr229
This commit is contained in:
salilsdesai
2023-01-09 18:08:01 -08:00
committed by PyTorch MergeBot
parent 28eb3c8faf
commit ec94cbc66a
13 changed files with 217 additions and 503 deletions

View File

@ -266,7 +266,6 @@ option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON)
option(USE_LITE_INTERPRETER_PROFILER "Enable " ON) option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
# option USE_XNNPACK: try to enable xnnpack by default. # option USE_XNNPACK: try to enable xnnpack by default.
option(USE_XNNPACK "Use XNNPACK" ON) option(USE_XNNPACK "Use XNNPACK" ON)
option(USE_ZMQ "Use ZMQ" OFF) option(USE_ZMQ "Use ZMQ" OFF)
@ -746,9 +745,6 @@ if(USE_VULKAN)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION") string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
endif() endif()
if(USE_VULKAN_SHADERC_RUNTIME)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME")
endif()
endif() endif()
if(BUILD_LITE_INTERPRETER) if(BUILD_LITE_INTERPRETER)

View File

@ -1,115 +0,0 @@
#!/usr/bin/env python3
import argparse
import glob
import sys
import os
from torchgen.code_template import CodeTemplate
H_NAME = "glsl.h"
CPP_NAME = "glsl.cpp"
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
def findAllGlsls(path):
vexs = glob.glob(os.path.join(path, '**', '*.glsl'), recursive=True)
output = []
for f in vexs:
if len(f) > 1:
output.append(f)
output.sort()
return output
def getName(filePath):
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
def genCppH(hFilePath, cppFilePath, templateGlslPaths, tmpDirPath, env):
print("hFilePath:{}".format(hFilePath))
print("cppFilePath:{}".format(cppFilePath))
h = "#pragma once\n"
nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
nsend = "\n} } } //namespace at::native::vulkan\n"
h += nsbegin
cpp = "#include <ATen/native/vulkan/{}>".format(H_NAME)
cpp += nsbegin
for templateGlslPath in templateGlslPaths:
name = getName(templateGlslPath)
h += "extern const char* " + name + ";\n"
cpp += "const char* " + name + " = \n"
codeTemplate = CodeTemplate.from_file(templateGlslPath)
srcPath = tmpDirPath + "/" + name + ".glsl"
content = codeTemplate.substitute(env)
lines = content.split("\n")
for l in lines:
if (len(l) < 1):
continue
cpp += "\"" + l + "\\n\"\n"
cpp += ";\n"
cpp += nsend
h += nsend
with open(hFilePath, "w") as f:
f.write(h)
with open(cppFilePath, "w") as f:
f.write(cpp)
def parse_arg_env(items):
d = {}
if items:
for item in items:
tokens = item.split("=")
key = tokens[0].strip()
value = tokens[1].strip()
d[key] = value
return d
def main(argv):
parser = argparse.ArgumentParser(description='Generate glsl.cpp and glsl.h containing glsl sources')
parser.add_argument(
'-i',
'--glsl-path',
help='path to directory with glsl to process',
required=True,
default='.')
parser.add_argument(
'-o',
'--output-path',
help='path to directory to generate glsl.h glsl.cpp (cpp namespace at::native::vulkan)',
required=True)
parser.add_argument(
'-t',
'--tmp-dir-path',
required=True,
help='/tmp')
parser.add_argument(
"--env",
metavar="KEY=VALUE",
nargs='*',
help="Set a number of key-value pairs")
options = parser.parse_args()
if not os.path.exists(options.tmp_dir_path):
os.makedirs(options.tmp_dir_path)
env = DEFAULT_ENV
for key, value in parse_arg_env(options.env).items():
env[key] = value
if not os.path.exists(options.output_path):
os.makedirs(options.output_path)
glsls = findAllGlsls(options.glsl_path)
genCppH(
options.output_path + "/" + H_NAME, options.output_path + "/" + CPP_NAME,
glsls,
tmpDirPath=options.tmp_dir_path,
env=env)
if __name__ == '__main__':
sys.exit(main(sys.argv))

View File

@ -1,9 +1,5 @@
#include <ATen/native/vulkan/api/Shader.h> #include <ATen/native/vulkan/api/Shader.h>
#ifdef USE_VULKAN_SHADERC_RUNTIME
#include <shaderc/shaderc.hpp>
#endif /* USE_VULKAN_SHADERC_RUNTIME */
namespace at { namespace at {
namespace native { namespace native {
namespace vulkan { namespace vulkan {
@ -14,38 +10,19 @@ namespace api {
// //
ShaderInfo::ShaderInfo() ShaderInfo::ShaderInfo()
: type(ShaderInfo::Type::SPIRV), : src_code{
src_code{ nullptr,
.spirv = 0u,
{
nullptr,
0u,
},
} {} } {}
ShaderInfo::ShaderInfo(std::string name, const char* const glsl_src)
: type(ShaderInfo::Type::GLSL),
src_code{
.glsl =
{
glsl_src,
0u,
},
},
kernel_name{std::move(name)} {}
ShaderInfo::ShaderInfo( ShaderInfo::ShaderInfo(
std::string name, std::string name,
const uint32_t* const spirv_bin, const uint32_t* const spirv_bin,
const uint32_t size, const uint32_t size,
const std::vector<VkDescriptorType>& layout) const std::vector<VkDescriptorType>& layout)
: type(Type::SPIRV), : src_code{
src_code{ spirv_bin,
.spirv = size,
{
spirv_bin,
size,
},
}, },
kernel_name{std::move(name)}, kernel_name{std::move(name)},
kernel_layout{layout} {} kernel_layout{layout} {}
@ -58,13 +35,9 @@ ShaderInfo::ShaderInfo(
const std::vector<uint32_t>& tile_size, const std::vector<uint32_t>& tile_size,
const StorageType bias_storage_type, const StorageType bias_storage_type,
const StorageType weight_storage_type) const StorageType weight_storage_type)
: type(Type::SPIRV), : src_code{
src_code{ spirv_bin,
.spirv = size,
{
spirv_bin,
size,
},
}, },
kernel_name{std::move(name)}, kernel_name{std::move(name)},
kernel_layout{layout}, kernel_layout{layout},
@ -77,17 +50,9 @@ ShaderInfo::ShaderInfo(
} }
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
if (_1.type != _2.type) { return (
return false; _1.src_code.bin == _2.src_code.bin &&
} _1.src_code.size == _2.src_code.size);
if (_1.type == ShaderInfo::Type::SPIRV) {
return (
_1.src_code.spirv.bin == _2.src_code.spirv.bin &&
_1.src_code.spirv.size == _2.src_code.spirv.size);
} else {
return (_1.src_code.glsl.src == _2.src_code.glsl.src);
}
} }
// //
@ -153,8 +118,8 @@ void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source) ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source)
: device_(device), handle_{VK_NULL_HANDLE} { : device_(device), handle_{VK_NULL_HANDLE} {
const uint32_t* code = source.src_code.spirv.bin; const uint32_t* code = source.src_code.bin;
uint32_t size = source.src_code.spirv.size; uint32_t size = source.src_code.size;
const VkShaderModuleCreateInfo shader_module_create_info{ const VkShaderModuleCreateInfo shader_module_create_info{
VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType

View File

@ -45,17 +45,9 @@ class ShaderLayout final {
}; };
struct ShaderInfo final { struct ShaderInfo final {
enum class Type { GLSL, SPIRV } type; struct {
const uint32_t* bin;
union { uint32_t size;
struct {
const char* src; // Null-terminated
uint32_t unused; // padding
} glsl;
struct {
const uint32_t* bin;
uint32_t size;
} spirv;
} src_code; } src_code;
std::string kernel_name{""}; std::string kernel_name{""};
@ -171,8 +163,7 @@ class ShaderCache final {
struct Hasher { struct Hasher {
inline size_t operator()(const ShaderInfo& source) const { inline size_t operator()(const ShaderInfo& source) const {
return c10::get_hash( return c10::get_hash(source.src_code.bin, source.src_code.size);
source.type, source.src_code.spirv.bin, source.src_code.spirv.size);
} }
}; };

View File

@ -6,18 +6,9 @@
#include <ATen/core/Tensor.h> #include <ATen/core/Tensor.h>
#include <ATen/native/vulkan/api/api.h> #include <ATen/native/vulkan/api/api.h>
#include <ATen/native/vulkan/ops/Convert.h> #include <ATen/native/vulkan/ops/Convert.h>
#define CONCAT_LITERALS(a, b) #a #b
#ifdef USE_VULKAN_SHADERC_RUNTIME
#include <ATen/native/vulkan/glsl.h>
#define VK_KERNEL(name) \
::at::native::vulkan::api::ShaderInfo { \
CONCAT_LITERALS(vulkan., name), name##_glsl, \
}
#else
#include <ATen/native/vulkan/spv.h> #include <ATen/native/vulkan/spv.h>
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv #define VK_KERNEL(name) ::at::native::vulkan::name##_spv
#endif /* USE_VULKAN_SHADERC_RUNTIME */
namespace at { namespace at {
namespace native { namespace native {

View File

@ -175,7 +175,6 @@ function(caffe2_print_configuration_summary)
if(${USE_VULKAN}) if(${USE_VULKAN})
message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}") message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}")
message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}") message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}")
message(STATUS " USE_VULKAN_SHADERC_RUNTIME : ${USE_VULKAN_SHADERC_RUNTIME}")
endif() endif()
message(STATUS " USE_PROF : ${USE_PROF}") message(STATUS " USE_PROF : ${USE_PROF}")
message(STATUS " USE_QNNPACK : ${USE_QNNPACK}") message(STATUS " USE_QNNPACK : ${USE_QNNPACK}")

View File

@ -13,71 +13,45 @@ if(USE_VULKAN_FP16_INFERENCE)
list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f") list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f")
endif() endif()
if(USE_VULKAN_SHADERC_RUNTIME) # Precompiling shaders
set(PYTHONPATH "$ENV{PYTHONPATH}") if(ANDROID)
set(NEW_PYTHONPATH ${PYTHONPATH}) if(NOT ANDROID_NDK)
list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..") message(FATAL_ERROR "ANDROID_NDK not set")
set(ENV{PYTHONPATH} ${NEW_PYTHONPATH}) endif()
execute_process(
COMMAND set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc")
"${PYTHON_EXECUTABLE}" else()
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_glsl.py find_program(
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl GLSLC_PATH glslc
--output-path ${VULKAN_GEN_OUTPUT_PATH} PATHS
--tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/glsl ENV VULKAN_SDK
--env ${VULKAN_GEN_ARG_ENV} PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin"
RESULT_VARIABLE error_code) PATHS "$ENV{VULKAN_SDK}/bin"
set(ENV{PYTHONPATH} ${PYTHONPATH}) )
if(NOT GLSLC_PATH)
message(FATAL_ERROR "USE_VULKAN glslc not found")
endif(NOT GLSLC_PATH)
endif()
set(PYTHONPATH "$ENV{PYTHONPATH}")
set(NEW_PYTHONPATH ${PYTHONPATH})
list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}"
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
--output-path ${VULKAN_GEN_OUTPUT_PATH}
--glslc-path=${GLSLC_PATH}
--tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv
--env ${VULKAN_GEN_ARG_ENV}
RESULT_VARIABLE error_code)
set(ENV{PYTHONPATH} ${PYTHONPATH})
if(error_code) if(error_code)
message(FATAL_ERROR "Failed to gen glsl.h and glsl.cpp with shaders sources for Vulkan backend") message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend")
endif() endif()
set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/glsl.cpp) set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp)
return()
endif()
if(NOT USE_VULKAN_SHADERC_RUNTIME)
# Precompiling shaders
if(ANDROID)
if(NOT ANDROID_NDK)
message(FATAL_ERROR "ANDROID_NDK not set")
endif()
set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc")
else()
find_program(
GLSLC_PATH glslc
PATHS
ENV VULKAN_SDK
PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin"
PATHS "$ENV{VULKAN_SDK}/bin"
)
if(NOT GLSLC_PATH)
message(FATAL_ERROR "USE_VULKAN glslc not found")
endif(NOT GLSLC_PATH)
endif()
set(PYTHONPATH "$ENV{PYTHONPATH}")
set(NEW_PYTHONPATH ${PYTHONPATH})
list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}"
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
--output-path ${VULKAN_GEN_OUTPUT_PATH}
--glslc-path=${GLSLC_PATH}
--tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv
--env ${VULKAN_GEN_ARG_ENV}
RESULT_VARIABLE error_code)
set(ENV{PYTHONPATH} ${PYTHONPATH})
if(error_code)
message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend")
endif()
set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp)
endif()

View File

@ -29,60 +29,6 @@ if(ANDROID)
list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR}) list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR})
list(APPEND Vulkan_LIBS VulkanWrapper) list(APPEND Vulkan_LIBS VulkanWrapper)
# Shaderc
if(USE_VULKAN_SHADERC_RUNTIME)
# Shaderc from ANDROID_NDK
set(Shaderc_ANDROID_NDK_INCLUDE_DIR "${ANDROID_NDK}/sources/third_party/shaderc/include")
message(STATUS "Shaderc_ANDROID_NDK_INCLUDE_DIR:${Shaderc_ANDROID_NDK_INCLUDE_DIR}")
find_path(
GOOGLE_SHADERC_INCLUDE_DIRS
NAMES shaderc/shaderc.hpp
PATHS "${Shaderc_ANDROID_NDK_INCLUDE_DIR}")
set(Shaderc_ANDROID_NDK_LIB_DIR "${ANDROID_NDK}/sources/third_party/shaderc/libs/${ANDROID_STL}/${ANDROID_ABI}")
message(STATUS "Shaderc_ANDROID_NDK_LIB_DIR:${Shaderc_ANDROID_NDK_LIB_DIR}")
find_library(
GOOGLE_SHADERC_LIBRARIES
NAMES shaderc
PATHS "${Shaderc_ANDROID_NDK_LIB_DIR}")
# Shaderc in NDK is not prebuilt
if(NOT GOOGLE_SHADERC_LIBRARIES)
set(NDK_SHADERC_DIR "${ANDROID_NDK}/sources/third_party/shaderc")
set(NDK_BUILD_CMD "${ANDROID_NDK}/ndk-build")
execute_process(
COMMAND ${NDK_BUILD_CMD}
NDK_PROJECT_PATH=${NDK_SHADERC_DIR}
APP_BUILD_SCRIPT=${NDK_SHADERC_DIR}/Android.mk
APP_PLATFORM=${ANDROID_PLATFORM}
APP_STL=${ANDROID_STL}
APP_ABI=${ANDROID_ABI}
libshaderc_combined -j8
WORKING_DIRECTORY "${NDK_SHADERC_DIR}"
RESULT_VARIABLE error_code)
if(error_code)
message(FATAL_ERROR "Failed to build ANDROID_NDK Shaderc error_code:${error_code}")
else()
unset(GOOGLE_SHADERC_LIBRARIES CACHE)
find_library(
GOOGLE_SHADERC_LIBRARIES
NAMES shaderc
HINTS "${Shaderc_ANDROID_NDK_LIB_DIR}")
endif()
endif(NOT GOOGLE_SHADERC_LIBRARIES)
if(GOOGLE_SHADERC_INCLUDE_DIRS AND GOOGLE_SHADERC_LIBRARIES)
message(STATUS "Shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}")
message(STATUS "Shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}")
endif()
list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS})
list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES})
endif(USE_VULKAN_SHADERC_RUNTIME)
else() else()
find_package(Vulkan) find_package(Vulkan)
@ -95,32 +41,4 @@ else()
set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR}) set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR})
set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY}) set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY})
if(USE_VULKAN_SHADERC_RUNTIME)
find_path(
GOOGLE_SHADERC_INCLUDE_DIRS
NAMES shaderc/shaderc.hpp
PATHS ${GOOGLE_SHADERC_INCLUDE_SEARCH_PATH})
find_library(
GOOGLE_SHADERC_LIBRARIES
NAMES shaderc_combined
PATHS ${GOOGLE_SHADERC_LIBRARY_SEARCH_PATH})
find_package_handle_standard_args(
Shaderc
DEFAULT_MSG
GOOGLE_SHADERC_INCLUDE_DIRS
GOOGLE_SHADERC_LIBRARIES)
if(NOT Shaderc_FOUND)
message(FATAL_ERROR "USE_VULKAN: Shaderc not found in VULKAN_SDK")
else()
message(STATUS "shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}")
message(STATUS "shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}")
endif()
list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS})
list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES})
endif(USE_VULKAN_SHADERC_RUNTIME)
endif() endif()

View File

@ -157,9 +157,6 @@ if [ -n "${USE_VULKAN}" ]; then
if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then
CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON")
fi fi
if [ -n "${USE_VULKAN_SHADERC_RUNTIME}" ]; then
CMAKE_ARGS+=("-DUSE_VULKAN_SHADERC_RUNTIME=ON")
fi
fi fi
# Use-specified CMake arguments go last to allow overridding defaults # Use-specified CMake arguments go last to allow overridding defaults

View File

@ -211,18 +211,6 @@ def define_tools_targets(
srcs = [ srcs = [
"gen_vulkan_spv.py", "gen_vulkan_spv.py",
], ],
base_module = "",
deps = [
torchgen_deps,
":gen_aten_vulkan_glsl_lib",
],
)
python_library(
name = "gen_aten_vulkan_glsl_lib",
srcs = [
"gen_vulkan_glsl.py",
],
base_module = "tools", base_module = "tools",
deps = [ deps = [
torchgen_deps, torchgen_deps,
@ -231,12 +219,11 @@ def define_tools_targets(
python_binary( python_binary(
name = "gen_aten_vulkan_spv_bin", name = "gen_aten_vulkan_spv_bin",
main_module = "gen_vulkan_spv", main_module = "tools.gen_vulkan_spv",
visibility = [ visibility = [
"PUBLIC", "PUBLIC",
], ],
deps = [ deps = [
":gen_aten_vulkan_glsl_lib",
":gen_aten_vulkan_spv_lib", ":gen_aten_vulkan_spv_lib",
], ],
) )
@ -249,7 +236,6 @@ def define_tools_targets(
contacts = contacts, contacts = contacts,
visibility = ["PUBLIC"], visibility = ["PUBLIC"],
deps = [ deps = [
":gen_aten_vulkan_glsl_lib",
":gen_aten_vulkan_spv_lib", ":gen_aten_vulkan_spv_lib",
], ],
) )

View File

@ -1,111 +0,0 @@
import copy
import os
from collections import OrderedDict
import yaml
from torchgen.code_template import CodeTemplate
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[misc]
# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
mapping = {}
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
try:
hash(key)
except TypeError as e:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key ",
key_node.start_mark,
) from e
# check for duplicate keys
if key in mapping:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found duplicate key",
key_node.start_mark,
)
value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call]
mapping[key] = value
return mapping
class GLSLGenerator(object):
standard_header = """
#version 450 core
#define PRECISION $precision
#define FORMAT $format
"""
def __init__(self): # type: ignore[no-untyped-def]
self.ops_template_params = {}
def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
all_template_params = OrderedDict()
with open(parameters_yaml_file, "r") as f:
contents = yaml.load(f, Loader=UniqueKeyLoader)
for key in contents:
all_template_params[key] = contents[key]
self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call]
def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def]
for op in all_template_params:
if op in self.ops_template_params:
raise KeyError(f"{op} params file has already been parsed")
op_params_default_vals = all_template_params[op][
"parameter_names_with_default_values"
]
template_params_set = set(op_params_default_vals.keys())
self.ops_template_params[op] = []
self.ops_template_params[op].append(op_params_default_vals)
op_template_params_values = all_template_params[op]["parameter_values"]
for param_vals in op_template_params_values:
param_vals_set = set(param_vals.keys())
invalid_keys = param_vals_set - template_params_set
if (len(invalid_keys)) > 0:
raise KeyError(f"Invalid keys {invalid_keys} are found")
param_vals_copy = copy.deepcopy(op_params_default_vals)
for key in param_vals:
param_vals_copy[key] = param_vals[key]
self.ops_template_params[op].append(param_vals_copy)
def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def]
glsl_template_name = os.path.basename(glsl_template_in)
op_name, extension_name = glsl_template_name.split(".")
if extension_name != "glslt":
raise TypeError(f"invalid file type for glsl template {extension_name}")
if op_name not in self.ops_template_params:
raise KeyError(f"{op_name} params have not been populated")
code_template = CodeTemplate.from_file(glsl_template_in)
for template_params in self.ops_template_params[op_name]:
content = GLSLGenerator.standard_header
param_vals_string = "x".join([str(i) for i in template_params.values()])
output_file_name = op_name + "_" + param_vals_string + ".glsl"
content += code_template.substitute(template_params)
output_file = os.path.join(out_dir, output_file_name)
with open(output_file, "w") as f:
f.write(content)
# Remove this
if __name__ == "__main__":
pass

View File

@ -2,21 +2,122 @@
import argparse import argparse
import array import array
import copy
import glob import glob
import os import os
import re import re
import sys import sys
import subprocess import subprocess
import yaml
from collections import OrderedDict
from torchgen.code_template import CodeTemplate from torchgen.code_template import CodeTemplate
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import Any, Dict, List
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
from tools.gen_vulkan_glsl import GLSLGenerator try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[misc]
H_NAME = "spv.h" H_NAME = "spv.h"
CPP_NAME = "spv.cpp" CPP_NAME = "spv.cpp"
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"} DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
mapping = {}
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
try:
hash(key)
except TypeError as e:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key ",
key_node.start_mark,
) from e
# check for duplicate keys
if key in mapping:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found duplicate key",
key_node.start_mark,
)
value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call]
mapping[key] = value
return mapping
class VulkanShaderGenerator(object):
standard_header = """
#version 450 core
#define PRECISION $precision
#define FORMAT $format
"""
def __init__(self: "VulkanShaderGenerator") -> None:
self.ops_template_params: Dict[Any, Any] = {}
def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
all_template_params = OrderedDict()
with open(parameters_yaml_file, "r") as f:
contents = yaml.load(f, Loader=UniqueKeyLoader)
for key in contents:
all_template_params[key] = contents[key]
self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call]
def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def]
for op in all_template_params:
if op in self.ops_template_params:
raise KeyError(f"{op} params file has already been parsed")
op_params_default_vals = all_template_params[op][
"parameter_names_with_default_values"
]
template_params_set = set(op_params_default_vals.keys())
self.ops_template_params[op] = []
self.ops_template_params[op].append(op_params_default_vals)
op_template_params_values = all_template_params[op]["parameter_values"]
for param_vals in op_template_params_values:
param_vals_set = set(param_vals.keys())
missing_keys = template_params_set - param_vals_set
invalid_keys = param_vals_set - template_params_set
if (len(invalid_keys)) > 0:
raise KeyError(f"Invalid keys {invalid_keys} are found")
param_vals_copy = copy.deepcopy(op_params_default_vals)
for key in param_vals:
param_vals_copy[key] = param_vals[key]
self.ops_template_params[op].append(param_vals_copy)
def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def]
glsl_template_name = os.path.basename(glsl_template_in)
op_name, extension_name = glsl_template_name.split(".")
if extension_name != "glslt":
raise TypeError(f"invalid file type for glsl template {extension_name}")
if op_name not in self.ops_template_params:
raise KeyError(f"{op_name} params have not been populated")
code_template = CodeTemplate.from_file(glsl_template_in)
for template_params in self.ops_template_params[op_name]:
content = VulkanShaderGenerator.standard_header
param_vals_string = "x".join([str(i) for i in template_params.values()])
output_file_name = op_name + "_" + param_vals_string + ".glsl"
content += code_template.substitute(template_params)
output_file = os.path.join(out_dir, output_file_name)
with open(output_file, "w") as f:
f.write(content)
@dataclass @dataclass
class ShaderInfo: class ShaderInfo:
@ -25,38 +126,44 @@ class ShaderInfo:
weight_storage_type: str = "" weight_storage_type: str = ""
bias_storage_type: str = "" bias_storage_type: str = ""
def getName(filePath): def getName(filePath: str) -> str:
return os.path.basename(filePath).replace("/", "_").replace(".", "_") return os.path.basename(filePath).replace("/", "_").replace(".", "_")
def isDescriptorLine(lineStr): def isDescriptorLine(lineStr: str) -> bool:
descriptorLineId = r"^layout\(set" descriptorLineId = r"^layout\(set"
return re.search(descriptorLineId, lineStr) return re.search(descriptorLineId, lineStr) is not None
def isTileSizeLine(lineStr): def isTileSizeLine(lineStr: str) -> bool:
tile_size_id = r"^ \* TILE_SIZE = \(" tile_size_id = r"^ \* TILE_SIZE = \("
return re.search(tile_size_id, lineStr) return re.search(tile_size_id, lineStr) is not None
def findTileSizes(lineStr): def findTileSizes(lineStr: str) -> List[int]:
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)" tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
matches = re.search(tile_size_id, lineStr) matches = re.search(tile_size_id, lineStr)
if matches is None:
raise AssertionError("matches is None in findTileSizes")
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))] return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
def isWeightStorageTypeLine(lineStr): def isWeightStorageTypeLine(lineStr: str) -> bool:
weight_storage_id = r"^ \* WEIGHT_STORAGE = " weight_storage_id = r"^ \* WEIGHT_STORAGE = "
return re.search(weight_storage_id, lineStr) return re.search(weight_storage_id, lineStr) is not None
def getWeightStorageType(lineStr): def getWeightStorageType(lineStr: str) -> str:
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)" weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
matches = re.search(weight_storage_id, lineStr) matches = re.search(weight_storage_id, lineStr)
if matches is None:
raise AssertionError("matches is None in getWeightStorageType")
return matches.group(1) return matches.group(1)
def isBiasStorageTypeLine(lineStr): def isBiasStorageTypeLine(lineStr: str) -> bool:
weight_storage_id = r"^ \* BIAS_STORAGE = " weight_storage_id = r"^ \* BIAS_STORAGE = "
return re.search(weight_storage_id, lineStr) return re.search(weight_storage_id, lineStr) is not None
def getBiasStorageType(lineStr): def getBiasStorageType(lineStr: str) -> str:
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)" weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
matches = re.search(weight_storage_id, lineStr) matches = re.search(weight_storage_id, lineStr)
if matches is None:
raise AssertionError("matches is None in getBiasStorageType")
return matches.group(1) return matches.group(1)
typeIdMapping = { typeIdMapping = {
@ -73,12 +180,15 @@ storageTypeToEnum = {
"": "api::StorageType::UNKNOWN", "": "api::StorageType::UNKNOWN",
} }
def determineDescriptorType(lineStr): def determineDescriptorType(lineStr: str) -> str:
for identifier, typeNum in typeIdMapping.items(): for identifier, typeNum in typeIdMapping.items():
if re.search(identifier, lineStr): if re.search(identifier, lineStr):
return typeNum return typeNum
raise AssertionError(
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
)
def getShaderInfo(srcFilePath): def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info = ShaderInfo([], [], "") shader_info = ShaderInfo([], [], "")
with open(srcFilePath, 'r') as srcFile: with open(srcFilePath, 'r') as srcFile:
for line in srcFile: for line in srcFile:
@ -93,14 +203,14 @@ def getShaderInfo(srcFilePath):
return shader_info return shader_info
def genGLSLFromGLSLT(src_dir_path, tmp_dir_path): def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
template_dir_path = os.path.join(src_dir_path, "templates") template_dir_path = os.path.join(src_dir_path, "templates")
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True) vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
parameter_yaml_files = [] parameter_yaml_files = []
for f in vexs: for f in vexs:
if len(f) > 1: if len(f) > 1:
parameter_yaml_files.append(f) parameter_yaml_files.append(f)
generator = GLSLGenerator() generator = VulkanShaderGenerator()
for params_yaml in parameter_yaml_files: for params_yaml in parameter_yaml_files:
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call] generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
@ -113,9 +223,20 @@ def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
for glslt in templateSrcPaths: for glslt in templateSrcPaths:
generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call] generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call]
def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format( def genCppH(
hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath)) hFilePath: str,
cppFilePath: str,
srcDirPath: str,
glslcPath: str,
tmpDirPath: str,
env: Dict[Any, Any],
) -> None:
print(
"hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath
)
)
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True) vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
templateSrcPaths = [] templateSrcPaths = []
@ -142,8 +263,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
codeTemplate = CodeTemplate.from_file(templateSrcPath) codeTemplate = CodeTemplate.from_file(templateSrcPath)
srcPath = tmpDirPath + "/" + name + ".glsl" srcPath = tmpDirPath + "/" + name + ".glsl"
content = codeTemplate.substitute(env) content = codeTemplate.substitute(env)
with open(srcPath, 'w') as f: with open(srcPath, 'w') as fw:
f.write(content) fw.write(content)
spvPath = tmpDirPath + "/" + name + ".spv" spvPath = tmpDirPath + "/" + name + ".spv"
print("spvPath {}".format(spvPath)) print("spvPath {}".format(spvPath))
@ -188,8 +309,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
name = getName(spvPath) name = getName(spvPath)
print("spvPath:{}".format(spvPath)) print("spvPath:{}".format(spvPath))
with open(spvPath, 'rb') as f: with open(spvPath, 'rb') as fr:
next_bin = array.array('I', f.read()) next_bin = array.array('I', fr.read())
sizeBytes = 4 * len(next_bin) sizeBytes = 4 * len(next_bin)
shader_info_bin_code.append( shader_info_bin_code.append(
"const uint32_t {}_bin[] = {{\n {}\n}};".format( "const uint32_t {}_bin[] = {{\n {}\n}};".format(
@ -231,13 +352,13 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
cpp += nsend cpp += nsend
h += nsend h += nsend
with open(hFilePath, "w") as f: with open(hFilePath, "w") as fw:
f.write(h) fw.write(h)
with open(cppFilePath, "w") as f: with open(cppFilePath, "w") as fw:
f.write(cpp) fw.write(cpp)
def parse_arg_env(items): def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
d = {} d = {}
if items: if items:
for item in items: for item in items:
@ -248,7 +369,7 @@ def parse_arg_env(items):
return d return d
def main(argv): def main(argv: List[str]) -> int:
parser = argparse.ArgumentParser(description='') parser = argparse.ArgumentParser(description='')
parser.add_argument( parser.add_argument(
'-i', '-i',
@ -294,5 +415,7 @@ def main(argv):
tmpDirPath=options.tmp_dir_path, tmpDirPath=options.tmp_dir_path,
env=env) env=env)
return 0
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main(sys.argv)) sys.exit(main(sys.argv))

View File

@ -2,11 +2,11 @@ import os
import tempfile import tempfile
import unittest import unittest
from tools.gen_vulkan_glsl import GLSLGenerator from tools.gen_vulkan_spv import VulkanShaderGenerator
from yaml.constructor import ConstructorError from yaml.constructor import ConstructorError
class TestGLSLCodegen(unittest.TestCase): class TestVulkanShaderCodegen(unittest.TestCase):
def test_assert_on_duplicate_key_yaml(self) -> None: def test_assert_on_duplicate_key_yaml(self) -> None:
yaml_with_duplicate_keys = """ yaml_with_duplicate_keys = """
conv2d_pw: conv2d_pw:
@ -37,7 +37,7 @@ conv2d_pw:
TILE_SIZE_Y: 4 TILE_SIZE_Y: 4
""" """
generator = GLSLGenerator() # type: ignore[no-untyped-call] generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
with tempfile.NamedTemporaryFile(mode="w") as fp: with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_with_duplicate_keys) fp.write(yaml_with_duplicate_keys)
fp.flush() fp.flush()
@ -57,7 +57,7 @@ conv2d_pw:
TILE_SIZE_Z: 2 TILE_SIZE_Z: 2
""" """
generator = GLSLGenerator() # type: ignore[no-untyped-call] generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
with tempfile.NamedTemporaryFile(mode="w") as fp: with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_with_key_mismatch) fp.write(yaml_with_key_mismatch)
fp.flush() fp.flush()
@ -77,7 +77,7 @@ conv2d_pw:
x = $TILE_SIZE_X + $TILE_SIZE_Y x = $TILE_SIZE_X + $TILE_SIZE_Y
""" """
generator = GLSLGenerator() # type: ignore[no-untyped-call] generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
with tempfile.NamedTemporaryFile(mode="w") as fp: with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_with_key_mismatch) fp.write(yaml_with_key_mismatch)
fp.flush() fp.flush()