diff --git a/CMakeLists.txt b/CMakeLists.txt index 359687f9b0ee..53aff333bfdd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -266,7 +266,6 @@ option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON) option(USE_LITE_INTERPRETER_PROFILER "Enable " ON) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) -option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF) # option USE_XNNPACK: try to enable xnnpack by default. option(USE_XNNPACK "Use XNNPACK" ON) option(USE_ZMQ "Use ZMQ" OFF) @@ -746,9 +745,6 @@ if(USE_VULKAN) string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION") endif() - if(USE_VULKAN_SHADERC_RUNTIME) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME") - endif() endif() if(BUILD_LITE_INTERPRETER) diff --git a/aten/src/ATen/gen_vulkan_glsl.py b/aten/src/ATen/gen_vulkan_glsl.py deleted file mode 100644 index b43dcb6cfeff..000000000000 --- a/aten/src/ATen/gen_vulkan_glsl.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import glob -import sys -import os -from torchgen.code_template import CodeTemplate - -H_NAME = "glsl.h" -CPP_NAME = "glsl.cpp" -DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"} - -def findAllGlsls(path): - vexs = glob.glob(os.path.join(path, '**', '*.glsl'), recursive=True) - output = [] - for f in vexs: - if len(f) > 1: - output.append(f) - output.sort() - return output - -def getName(filePath): - return os.path.basename(filePath).replace("/", "_").replace(".", "_") - -def genCppH(hFilePath, cppFilePath, templateGlslPaths, tmpDirPath, env): - print("hFilePath:{}".format(hFilePath)) - print("cppFilePath:{}".format(cppFilePath)) - h = "#pragma once\n" - nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n" - nsend = "\n} } } //namespace at::native::vulkan\n" - - h += nsbegin - - cpp = "#include ".format(H_NAME) - cpp += nsbegin - - for templateGlslPath in templateGlslPaths: - name = getName(templateGlslPath) - h += "extern const char* " + name + ";\n" - cpp += "const char* " + name + " = \n" - - codeTemplate = CodeTemplate.from_file(templateGlslPath) - srcPath = tmpDirPath + "/" + name + ".glsl" - content = codeTemplate.substitute(env) - - lines = content.split("\n") - for l in lines: - if (len(l) < 1): - continue - cpp += "\"" + l + "\\n\"\n" - - cpp += ";\n" - - cpp += nsend - h += nsend - - with open(hFilePath, "w") as f: - f.write(h) - with open(cppFilePath, "w") as f: - f.write(cpp) - - -def parse_arg_env(items): - d = {} - if items: - for item in items: - tokens = item.split("=") - key = tokens[0].strip() - value = tokens[1].strip() - d[key] = value - return d - - -def main(argv): - parser = argparse.ArgumentParser(description='Generate glsl.cpp and glsl.h containing glsl sources') - parser.add_argument( - '-i', - '--glsl-path', - help='path to directory with glsl to process', - required=True, - default='.') - parser.add_argument( - '-o', - '--output-path', - help='path to directory to generate glsl.h glsl.cpp (cpp namespace at::native::vulkan)', - required=True) - parser.add_argument( - '-t', - '--tmp-dir-path', - required=True, - help='/tmp') - parser.add_argument( - "--env", - metavar="KEY=VALUE", - nargs='*', - help="Set a number of key-value pairs") - options = parser.parse_args() - if not os.path.exists(options.tmp_dir_path): - os.makedirs(options.tmp_dir_path) - env = DEFAULT_ENV - for key, value in parse_arg_env(options.env).items(): - env[key] = value - - if not os.path.exists(options.output_path): - os.makedirs(options.output_path) - - glsls = findAllGlsls(options.glsl_path) - genCppH( - options.output_path + "/" + H_NAME, options.output_path + "/" + CPP_NAME, - glsls, - tmpDirPath=options.tmp_dir_path, - env=env) - -if __name__ == '__main__': - sys.exit(main(sys.argv)) diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp index 7b48e18d7dde..302483abd0a5 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.cpp +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -1,9 +1,5 @@ #include -#ifdef USE_VULKAN_SHADERC_RUNTIME -#include -#endif /* USE_VULKAN_SHADERC_RUNTIME */ - namespace at { namespace native { namespace vulkan { @@ -14,38 +10,19 @@ namespace api { // ShaderInfo::ShaderInfo() - : type(ShaderInfo::Type::SPIRV), - src_code{ - .spirv = - { - nullptr, - 0u, - }, + : src_code{ + nullptr, + 0u, } {} -ShaderInfo::ShaderInfo(std::string name, const char* const glsl_src) - : type(ShaderInfo::Type::GLSL), - src_code{ - .glsl = - { - glsl_src, - 0u, - }, - }, - kernel_name{std::move(name)} {} - ShaderInfo::ShaderInfo( std::string name, const uint32_t* const spirv_bin, const uint32_t size, const std::vector& layout) - : type(Type::SPIRV), - src_code{ - .spirv = - { - spirv_bin, - size, - }, + : src_code{ + spirv_bin, + size, }, kernel_name{std::move(name)}, kernel_layout{layout} {} @@ -58,13 +35,9 @@ ShaderInfo::ShaderInfo( const std::vector& tile_size, const StorageType bias_storage_type, const StorageType weight_storage_type) - : type(Type::SPIRV), - src_code{ - .spirv = - { - spirv_bin, - size, - }, + : src_code{ + spirv_bin, + size, }, kernel_name{std::move(name)}, kernel_layout{layout}, @@ -77,17 +50,9 @@ ShaderInfo::ShaderInfo( } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { - if (_1.type != _2.type) { - return false; - } - - if (_1.type == ShaderInfo::Type::SPIRV) { - return ( - _1.src_code.spirv.bin == _2.src_code.spirv.bin && - _1.src_code.spirv.size == _2.src_code.spirv.size); - } else { - return (_1.src_code.glsl.src == _2.src_code.glsl.src); - } + return ( + _1.src_code.bin == _2.src_code.bin && + _1.src_code.size == _2.src_code.size); } // @@ -153,8 +118,8 @@ void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept { ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source) : device_(device), handle_{VK_NULL_HANDLE} { - const uint32_t* code = source.src_code.spirv.bin; - uint32_t size = source.src_code.spirv.size; + const uint32_t* code = source.src_code.bin; + uint32_t size = source.src_code.size; const VkShaderModuleCreateInfo shader_module_create_info{ VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index 1a196785c1cc..03d3d7c11d1e 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -45,17 +45,9 @@ class ShaderLayout final { }; struct ShaderInfo final { - enum class Type { GLSL, SPIRV } type; - - union { - struct { - const char* src; // Null-terminated - uint32_t unused; // padding - } glsl; - struct { - const uint32_t* bin; - uint32_t size; - } spirv; + struct { + const uint32_t* bin; + uint32_t size; } src_code; std::string kernel_name{""}; @@ -171,8 +163,7 @@ class ShaderCache final { struct Hasher { inline size_t operator()(const ShaderInfo& source) const { - return c10::get_hash( - source.type, source.src_code.spirv.bin, source.src_code.spirv.size); + return c10::get_hash(source.src_code.bin, source.src_code.size); } }; diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index 79a0f1edaaba..944977c635aa 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -6,18 +6,9 @@ #include #include #include - -#define CONCAT_LITERALS(a, b) #a #b -#ifdef USE_VULKAN_SHADERC_RUNTIME -#include -#define VK_KERNEL(name) \ - ::at::native::vulkan::api::ShaderInfo { \ - CONCAT_LITERALS(vulkan., name), name##_glsl, \ - } -#else #include + #define VK_KERNEL(name) ::at::native::vulkan::name##_spv -#endif /* USE_VULKAN_SHADERC_RUNTIME */ namespace at { namespace native { diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 279d72a41e66..23c9cd8eeb77 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -175,7 +175,6 @@ function(caffe2_print_configuration_summary) if(${USE_VULKAN}) message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}") message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}") - message(STATUS " USE_VULKAN_SHADERC_RUNTIME : ${USE_VULKAN_SHADERC_RUNTIME}") endif() message(STATUS " USE_PROF : ${USE_PROF}") message(STATUS " USE_QNNPACK : ${USE_QNNPACK}") diff --git a/cmake/VulkanCodegen.cmake b/cmake/VulkanCodegen.cmake index 52279be82ae2..43231ea586ac 100644 --- a/cmake/VulkanCodegen.cmake +++ b/cmake/VulkanCodegen.cmake @@ -13,71 +13,45 @@ if(USE_VULKAN_FP16_INFERENCE) list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f") endif() -if(USE_VULKAN_SHADERC_RUNTIME) - set(PYTHONPATH "$ENV{PYTHONPATH}") - set(NEW_PYTHONPATH ${PYTHONPATH}) - list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..") - set(ENV{PYTHONPATH} ${NEW_PYTHONPATH}) - execute_process( - COMMAND - "${PYTHON_EXECUTABLE}" - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_glsl.py - --glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl - --output-path ${VULKAN_GEN_OUTPUT_PATH} - --tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/glsl - --env ${VULKAN_GEN_ARG_ENV} - RESULT_VARIABLE error_code) - set(ENV{PYTHONPATH} ${PYTHONPATH}) +# Precompiling shaders +if(ANDROID) + if(NOT ANDROID_NDK) + message(FATAL_ERROR "ANDROID_NDK not set") + endif() + + set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc") +else() + find_program( + GLSLC_PATH glslc + PATHS + ENV VULKAN_SDK + PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin" + PATHS "$ENV{VULKAN_SDK}/bin" + ) + + if(NOT GLSLC_PATH) + message(FATAL_ERROR "USE_VULKAN glslc not found") + endif(NOT GLSLC_PATH) +endif() + +set(PYTHONPATH "$ENV{PYTHONPATH}") +set(NEW_PYTHONPATH ${PYTHONPATH}) +list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..") +set(ENV{PYTHONPATH} ${NEW_PYTHONPATH}) +execute_process( + COMMAND + "${PYTHON_EXECUTABLE}" + ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py + --glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl + --output-path ${VULKAN_GEN_OUTPUT_PATH} + --glslc-path=${GLSLC_PATH} + --tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv + --env ${VULKAN_GEN_ARG_ENV} + RESULT_VARIABLE error_code) +set(ENV{PYTHONPATH} ${PYTHONPATH}) if(error_code) - message(FATAL_ERROR "Failed to gen glsl.h and glsl.cpp with shaders sources for Vulkan backend") + message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend") endif() - set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/glsl.cpp) - return() -endif() - -if(NOT USE_VULKAN_SHADERC_RUNTIME) - # Precompiling shaders - if(ANDROID) - if(NOT ANDROID_NDK) - message(FATAL_ERROR "ANDROID_NDK not set") - endif() - - set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc") - else() - find_program( - GLSLC_PATH glslc - PATHS - ENV VULKAN_SDK - PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin" - PATHS "$ENV{VULKAN_SDK}/bin" - ) - - if(NOT GLSLC_PATH) - message(FATAL_ERROR "USE_VULKAN glslc not found") - endif(NOT GLSLC_PATH) - endif() - - set(PYTHONPATH "$ENV{PYTHONPATH}") - set(NEW_PYTHONPATH ${PYTHONPATH}) - list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..") - set(ENV{PYTHONPATH} ${NEW_PYTHONPATH}) - execute_process( - COMMAND - "${PYTHON_EXECUTABLE}" - ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py - --glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl - --output-path ${VULKAN_GEN_OUTPUT_PATH} - --glslc-path=${GLSLC_PATH} - --tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv - --env ${VULKAN_GEN_ARG_ENV} - RESULT_VARIABLE error_code) - set(ENV{PYTHONPATH} ${PYTHONPATH}) - - if(error_code) - message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend") - endif() - - set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp) -endif() +set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp) diff --git a/cmake/VulkanDependencies.cmake b/cmake/VulkanDependencies.cmake index 0569a7e705c1..52de6a4286e7 100644 --- a/cmake/VulkanDependencies.cmake +++ b/cmake/VulkanDependencies.cmake @@ -29,60 +29,6 @@ if(ANDROID) list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR}) list(APPEND Vulkan_LIBS VulkanWrapper) - # Shaderc - if(USE_VULKAN_SHADERC_RUNTIME) - # Shaderc from ANDROID_NDK - set(Shaderc_ANDROID_NDK_INCLUDE_DIR "${ANDROID_NDK}/sources/third_party/shaderc/include") - message(STATUS "Shaderc_ANDROID_NDK_INCLUDE_DIR:${Shaderc_ANDROID_NDK_INCLUDE_DIR}") - - find_path( - GOOGLE_SHADERC_INCLUDE_DIRS - NAMES shaderc/shaderc.hpp - PATHS "${Shaderc_ANDROID_NDK_INCLUDE_DIR}") - - set(Shaderc_ANDROID_NDK_LIB_DIR "${ANDROID_NDK}/sources/third_party/shaderc/libs/${ANDROID_STL}/${ANDROID_ABI}") - message(STATUS "Shaderc_ANDROID_NDK_LIB_DIR:${Shaderc_ANDROID_NDK_LIB_DIR}") - - find_library( - GOOGLE_SHADERC_LIBRARIES - NAMES shaderc - PATHS "${Shaderc_ANDROID_NDK_LIB_DIR}") - - # Shaderc in NDK is not prebuilt - if(NOT GOOGLE_SHADERC_LIBRARIES) - set(NDK_SHADERC_DIR "${ANDROID_NDK}/sources/third_party/shaderc") - set(NDK_BUILD_CMD "${ANDROID_NDK}/ndk-build") - - execute_process( - COMMAND ${NDK_BUILD_CMD} - NDK_PROJECT_PATH=${NDK_SHADERC_DIR} - APP_BUILD_SCRIPT=${NDK_SHADERC_DIR}/Android.mk - APP_PLATFORM=${ANDROID_PLATFORM} - APP_STL=${ANDROID_STL} - APP_ABI=${ANDROID_ABI} - libshaderc_combined -j8 - WORKING_DIRECTORY "${NDK_SHADERC_DIR}" - RESULT_VARIABLE error_code) - - if(error_code) - message(FATAL_ERROR "Failed to build ANDROID_NDK Shaderc error_code:${error_code}") - else() - unset(GOOGLE_SHADERC_LIBRARIES CACHE) - find_library( - GOOGLE_SHADERC_LIBRARIES - NAMES shaderc - HINTS "${Shaderc_ANDROID_NDK_LIB_DIR}") - endif() - endif(NOT GOOGLE_SHADERC_LIBRARIES) - - if(GOOGLE_SHADERC_INCLUDE_DIRS AND GOOGLE_SHADERC_LIBRARIES) - message(STATUS "Shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}") - message(STATUS "Shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}") - endif() - - list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS}) - list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES}) - endif(USE_VULKAN_SHADERC_RUNTIME) else() find_package(Vulkan) @@ -95,32 +41,4 @@ else() set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR}) set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY}) - - if(USE_VULKAN_SHADERC_RUNTIME) - find_path( - GOOGLE_SHADERC_INCLUDE_DIRS - NAMES shaderc/shaderc.hpp - PATHS ${GOOGLE_SHADERC_INCLUDE_SEARCH_PATH}) - - find_library( - GOOGLE_SHADERC_LIBRARIES - NAMES shaderc_combined - PATHS ${GOOGLE_SHADERC_LIBRARY_SEARCH_PATH}) - - find_package_handle_standard_args( - Shaderc - DEFAULT_MSG - GOOGLE_SHADERC_INCLUDE_DIRS - GOOGLE_SHADERC_LIBRARIES) - - if(NOT Shaderc_FOUND) - message(FATAL_ERROR "USE_VULKAN: Shaderc not found in VULKAN_SDK") - else() - message(STATUS "shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}") - message(STATUS "shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}") - endif() - - list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS}) - list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES}) - endif(USE_VULKAN_SHADERC_RUNTIME) endif() diff --git a/scripts/build_android.sh b/scripts/build_android.sh index e2be6c88e989..be018593331d 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -157,9 +157,6 @@ if [ -n "${USE_VULKAN}" ]; then if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") fi - if [ -n "${USE_VULKAN_SHADERC_RUNTIME}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_SHADERC_RUNTIME=ON") - fi fi # Use-specified CMake arguments go last to allow overridding defaults diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index 58a49fded0ee..3d685f4bab15 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -211,18 +211,6 @@ def define_tools_targets( srcs = [ "gen_vulkan_spv.py", ], - base_module = "", - deps = [ - torchgen_deps, - ":gen_aten_vulkan_glsl_lib", - ], - ) - - python_library( - name = "gen_aten_vulkan_glsl_lib", - srcs = [ - "gen_vulkan_glsl.py", - ], base_module = "tools", deps = [ torchgen_deps, @@ -231,12 +219,11 @@ def define_tools_targets( python_binary( name = "gen_aten_vulkan_spv_bin", - main_module = "gen_vulkan_spv", + main_module = "tools.gen_vulkan_spv", visibility = [ "PUBLIC", ], deps = [ - ":gen_aten_vulkan_glsl_lib", ":gen_aten_vulkan_spv_lib", ], ) @@ -249,7 +236,6 @@ def define_tools_targets( contacts = contacts, visibility = ["PUBLIC"], deps = [ - ":gen_aten_vulkan_glsl_lib", ":gen_aten_vulkan_spv_lib", ], ) diff --git a/tools/gen_vulkan_glsl.py b/tools/gen_vulkan_glsl.py deleted file mode 100644 index 6d89da0c743c..000000000000 --- a/tools/gen_vulkan_glsl.py +++ /dev/null @@ -1,111 +0,0 @@ -import copy -import os - -from collections import OrderedDict - -import yaml -from torchgen.code_template import CodeTemplate -from yaml.constructor import ConstructorError -from yaml.nodes import MappingNode - -try: - from yaml import CLoader as Loader -except ImportError: - from yaml import Loader # type: ignore[misc] - -# https://gist.github.com/pypt/94d747fe5180851196eb -class UniqueKeyLoader(Loader): - def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] - if not isinstance(node, MappingNode): - raise ConstructorError( - None, - None, - "expected a mapping node, but found %s" % node.id, - node.start_mark, - ) - mapping = {} - for key_node, value_node in node.value: - key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call] - try: - hash(key) - except TypeError as e: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found unacceptable key ", - key_node.start_mark, - ) from e - # check for duplicate keys - if key in mapping: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found duplicate key", - key_node.start_mark, - ) - value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call] - mapping[key] = value - return mapping - - -class GLSLGenerator(object): - standard_header = """ -#version 450 core -#define PRECISION $precision -#define FORMAT $format - -""" - - def __init__(self): # type: ignore[no-untyped-def] - self.ops_template_params = {} - - def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def] - all_template_params = OrderedDict() - with open(parameters_yaml_file, "r") as f: - contents = yaml.load(f, Loader=UniqueKeyLoader) - for key in contents: - all_template_params[key] = contents[key] - self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call] - - def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def] - for op in all_template_params: - if op in self.ops_template_params: - raise KeyError(f"{op} params file has already been parsed") - op_params_default_vals = all_template_params[op][ - "parameter_names_with_default_values" - ] - template_params_set = set(op_params_default_vals.keys()) - self.ops_template_params[op] = [] - self.ops_template_params[op].append(op_params_default_vals) - op_template_params_values = all_template_params[op]["parameter_values"] - for param_vals in op_template_params_values: - param_vals_set = set(param_vals.keys()) - invalid_keys = param_vals_set - template_params_set - if (len(invalid_keys)) > 0: - raise KeyError(f"Invalid keys {invalid_keys} are found") - param_vals_copy = copy.deepcopy(op_params_default_vals) - for key in param_vals: - param_vals_copy[key] = param_vals[key] - self.ops_template_params[op].append(param_vals_copy) - - def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def] - glsl_template_name = os.path.basename(glsl_template_in) - op_name, extension_name = glsl_template_name.split(".") - if extension_name != "glslt": - raise TypeError(f"invalid file type for glsl template {extension_name}") - if op_name not in self.ops_template_params: - raise KeyError(f"{op_name} params have not been populated") - code_template = CodeTemplate.from_file(glsl_template_in) - for template_params in self.ops_template_params[op_name]: - content = GLSLGenerator.standard_header - param_vals_string = "x".join([str(i) for i in template_params.values()]) - output_file_name = op_name + "_" + param_vals_string + ".glsl" - content += code_template.substitute(template_params) - output_file = os.path.join(out_dir, output_file_name) - with open(output_file, "w") as f: - f.write(content) - - -# Remove this -if __name__ == "__main__": - pass diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py index db4ed965cf84..92f6a5a9f093 100644 --- a/tools/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -2,21 +2,122 @@ import argparse import array +import copy import glob import os import re import sys import subprocess +import yaml +from collections import OrderedDict from torchgen.code_template import CodeTemplate from dataclasses import dataclass -from typing import List +from typing import Any, Dict, List +from yaml.constructor import ConstructorError +from yaml.nodes import MappingNode -from tools.gen_vulkan_glsl import GLSLGenerator +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader # type: ignore[misc] H_NAME = "spv.h" CPP_NAME = "spv.cpp" DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"} +# https://gist.github.com/pypt/94d747fe5180851196eb +class UniqueKeyLoader(Loader): + def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] + if not isinstance(node, MappingNode): + raise ConstructorError( + None, + None, + "expected a mapping node, but found %s" % node.id, + node.start_mark, + ) + mapping = {} + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call] + try: + hash(key) + except TypeError as e: + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found unacceptable key ", + key_node.start_mark, + ) from e + # check for duplicate keys + if key in mapping: + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found duplicate key", + key_node.start_mark, + ) + value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call] + mapping[key] = value + return mapping + + +class VulkanShaderGenerator(object): + standard_header = """ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +""" + + def __init__(self: "VulkanShaderGenerator") -> None: + self.ops_template_params: Dict[Any, Any] = {} + + def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def] + all_template_params = OrderedDict() + with open(parameters_yaml_file, "r") as f: + contents = yaml.load(f, Loader=UniqueKeyLoader) + for key in contents: + all_template_params[key] = contents[key] + self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call] + + def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def] + for op in all_template_params: + if op in self.ops_template_params: + raise KeyError(f"{op} params file has already been parsed") + op_params_default_vals = all_template_params[op][ + "parameter_names_with_default_values" + ] + template_params_set = set(op_params_default_vals.keys()) + self.ops_template_params[op] = [] + self.ops_template_params[op].append(op_params_default_vals) + op_template_params_values = all_template_params[op]["parameter_values"] + for param_vals in op_template_params_values: + param_vals_set = set(param_vals.keys()) + missing_keys = template_params_set - param_vals_set + invalid_keys = param_vals_set - template_params_set + if (len(invalid_keys)) > 0: + raise KeyError(f"Invalid keys {invalid_keys} are found") + param_vals_copy = copy.deepcopy(op_params_default_vals) + for key in param_vals: + param_vals_copy[key] = param_vals[key] + self.ops_template_params[op].append(param_vals_copy) + + def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def] + glsl_template_name = os.path.basename(glsl_template_in) + op_name, extension_name = glsl_template_name.split(".") + if extension_name != "glslt": + raise TypeError(f"invalid file type for glsl template {extension_name}") + if op_name not in self.ops_template_params: + raise KeyError(f"{op_name} params have not been populated") + code_template = CodeTemplate.from_file(glsl_template_in) + for template_params in self.ops_template_params[op_name]: + content = VulkanShaderGenerator.standard_header + param_vals_string = "x".join([str(i) for i in template_params.values()]) + output_file_name = op_name + "_" + param_vals_string + ".glsl" + content += code_template.substitute(template_params) + output_file = os.path.join(out_dir, output_file_name) + with open(output_file, "w") as f: + f.write(content) + @dataclass class ShaderInfo: @@ -25,38 +126,44 @@ class ShaderInfo: weight_storage_type: str = "" bias_storage_type: str = "" -def getName(filePath): +def getName(filePath: str) -> str: return os.path.basename(filePath).replace("/", "_").replace(".", "_") -def isDescriptorLine(lineStr): +def isDescriptorLine(lineStr: str) -> bool: descriptorLineId = r"^layout\(set" - return re.search(descriptorLineId, lineStr) + return re.search(descriptorLineId, lineStr) is not None -def isTileSizeLine(lineStr): +def isTileSizeLine(lineStr: str) -> bool: tile_size_id = r"^ \* TILE_SIZE = \(" - return re.search(tile_size_id, lineStr) + return re.search(tile_size_id, lineStr) is not None -def findTileSizes(lineStr): +def findTileSizes(lineStr: str) -> List[int]: tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)" matches = re.search(tile_size_id, lineStr) + if matches is None: + raise AssertionError("matches is None in findTileSizes") return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))] -def isWeightStorageTypeLine(lineStr): +def isWeightStorageTypeLine(lineStr: str) -> bool: weight_storage_id = r"^ \* WEIGHT_STORAGE = " - return re.search(weight_storage_id, lineStr) + return re.search(weight_storage_id, lineStr) is not None -def getWeightStorageType(lineStr): +def getWeightStorageType(lineStr: str) -> str: weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)" matches = re.search(weight_storage_id, lineStr) + if matches is None: + raise AssertionError("matches is None in getWeightStorageType") return matches.group(1) -def isBiasStorageTypeLine(lineStr): +def isBiasStorageTypeLine(lineStr: str) -> bool: weight_storage_id = r"^ \* BIAS_STORAGE = " - return re.search(weight_storage_id, lineStr) + return re.search(weight_storage_id, lineStr) is not None -def getBiasStorageType(lineStr): +def getBiasStorageType(lineStr: str) -> str: weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)" matches = re.search(weight_storage_id, lineStr) + if matches is None: + raise AssertionError("matches is None in getBiasStorageType") return matches.group(1) typeIdMapping = { @@ -73,12 +180,15 @@ storageTypeToEnum = { "": "api::StorageType::UNKNOWN", } -def determineDescriptorType(lineStr): +def determineDescriptorType(lineStr: str) -> str: for identifier, typeNum in typeIdMapping.items(): if re.search(identifier, lineStr): return typeNum + raise AssertionError( + "No matching descriptor type for " + lineStr + " in determineDescriptorType" + ) -def getShaderInfo(srcFilePath): +def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info = ShaderInfo([], [], "") with open(srcFilePath, 'r') as srcFile: for line in srcFile: @@ -93,14 +203,14 @@ def getShaderInfo(srcFilePath): return shader_info -def genGLSLFromGLSLT(src_dir_path, tmp_dir_path): +def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None: template_dir_path = os.path.join(src_dir_path, "templates") vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True) parameter_yaml_files = [] for f in vexs: if len(f) > 1: parameter_yaml_files.append(f) - generator = GLSLGenerator() + generator = VulkanShaderGenerator() for params_yaml in parameter_yaml_files: generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call] @@ -113,9 +223,20 @@ def genGLSLFromGLSLT(src_dir_path, tmp_dir_path): for glslt in templateSrcPaths: generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call] -def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): - print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format( - hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath)) + +def genCppH( + hFilePath: str, + cppFilePath: str, + srcDirPath: str, + glslcPath: str, + tmpDirPath: str, + env: Dict[Any, Any], +) -> None: + print( + "hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format( + hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath + ) + ) vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True) templateSrcPaths = [] @@ -142,8 +263,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): codeTemplate = CodeTemplate.from_file(templateSrcPath) srcPath = tmpDirPath + "/" + name + ".glsl" content = codeTemplate.substitute(env) - with open(srcPath, 'w') as f: - f.write(content) + with open(srcPath, 'w') as fw: + fw.write(content) spvPath = tmpDirPath + "/" + name + ".spv" print("spvPath {}".format(spvPath)) @@ -188,8 +309,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): name = getName(spvPath) print("spvPath:{}".format(spvPath)) - with open(spvPath, 'rb') as f: - next_bin = array.array('I', f.read()) + with open(spvPath, 'rb') as fr: + next_bin = array.array('I', fr.read()) sizeBytes = 4 * len(next_bin) shader_info_bin_code.append( "const uint32_t {}_bin[] = {{\n {}\n}};".format( @@ -231,13 +352,13 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): cpp += nsend h += nsend - with open(hFilePath, "w") as f: - f.write(h) - with open(cppFilePath, "w") as f: - f.write(cpp) + with open(hFilePath, "w") as fw: + fw.write(h) + with open(cppFilePath, "w") as fw: + fw.write(cpp) -def parse_arg_env(items): +def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]: d = {} if items: for item in items: @@ -248,7 +369,7 @@ def parse_arg_env(items): return d -def main(argv): +def main(argv: List[str]) -> int: parser = argparse.ArgumentParser(description='') parser.add_argument( '-i', @@ -294,5 +415,7 @@ def main(argv): tmpDirPath=options.tmp_dir_path, env=env) + return 0 + if __name__ == '__main__': sys.exit(main(sys.argv)) diff --git a/tools/test/test_vulkan_codegen.py b/tools/test/test_vulkan_codegen.py index 8b0b4b3a13cd..ae87c27e7aeb 100644 --- a/tools/test/test_vulkan_codegen.py +++ b/tools/test/test_vulkan_codegen.py @@ -2,11 +2,11 @@ import os import tempfile import unittest -from tools.gen_vulkan_glsl import GLSLGenerator +from tools.gen_vulkan_spv import VulkanShaderGenerator from yaml.constructor import ConstructorError -class TestGLSLCodegen(unittest.TestCase): +class TestVulkanShaderCodegen(unittest.TestCase): def test_assert_on_duplicate_key_yaml(self) -> None: yaml_with_duplicate_keys = """ conv2d_pw: @@ -37,7 +37,7 @@ conv2d_pw: TILE_SIZE_Y: 4 """ - generator = GLSLGenerator() # type: ignore[no-untyped-call] + generator = VulkanShaderGenerator() # type: ignore[no-untyped-call] with tempfile.NamedTemporaryFile(mode="w") as fp: fp.write(yaml_with_duplicate_keys) fp.flush() @@ -57,7 +57,7 @@ conv2d_pw: TILE_SIZE_Z: 2 """ - generator = GLSLGenerator() # type: ignore[no-untyped-call] + generator = VulkanShaderGenerator() # type: ignore[no-untyped-call] with tempfile.NamedTemporaryFile(mode="w") as fp: fp.write(yaml_with_key_mismatch) fp.flush() @@ -77,7 +77,7 @@ conv2d_pw: x = $TILE_SIZE_X + $TILE_SIZE_Y """ - generator = GLSLGenerator() # type: ignore[no-untyped-call] + generator = VulkanShaderGenerator() # type: ignore[no-untyped-call] with tempfile.NamedTemporaryFile(mode="w") as fp: fp.write(yaml_with_key_mismatch) fp.flush()