mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[Vulkan] Remove GLSL Code Gen (#91912)
@bypass-github-export-checks GLSL Code Gen is not used, so this diff removes - GLSL parts of ShaderSource - Anything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself - gen_vulkan_glsl script Plus some additional refactoring Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/91912 Approved by: https://github.com/mcr229
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							28eb3c8faf
						
					
				
				
					commit
					ec94cbc66a
				
			@ -266,7 +266,6 @@ option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON)
 | 
				
			|||||||
option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
 | 
					option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
 | 
				
			||||||
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
 | 
					option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
 | 
				
			||||||
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
 | 
					option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
 | 
				
			||||||
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
 | 
					 | 
				
			||||||
# option USE_XNNPACK: try to enable xnnpack by default.
 | 
					# option USE_XNNPACK: try to enable xnnpack by default.
 | 
				
			||||||
option(USE_XNNPACK "Use XNNPACK" ON)
 | 
					option(USE_XNNPACK "Use XNNPACK" ON)
 | 
				
			||||||
option(USE_ZMQ "Use ZMQ" OFF)
 | 
					option(USE_ZMQ "Use ZMQ" OFF)
 | 
				
			||||||
@ -746,9 +745,6 @@ if(USE_VULKAN)
 | 
				
			|||||||
    string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
 | 
					    string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
 | 
				
			||||||
  endif()
 | 
					  endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if(USE_VULKAN_SHADERC_RUNTIME)
 | 
					 | 
				
			||||||
    string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME")
 | 
					 | 
				
			||||||
  endif()
 | 
					 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if(BUILD_LITE_INTERPRETER)
 | 
					if(BUILD_LITE_INTERPRETER)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,115 +0,0 @@
 | 
				
			|||||||
#!/usr/bin/env python3
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import argparse
 | 
					 | 
				
			||||||
import glob
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
from torchgen.code_template import CodeTemplate
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
H_NAME = "glsl.h"
 | 
					 | 
				
			||||||
CPP_NAME = "glsl.cpp"
 | 
					 | 
				
			||||||
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def findAllGlsls(path):
 | 
					 | 
				
			||||||
    vexs = glob.glob(os.path.join(path, '**', '*.glsl'), recursive=True)
 | 
					 | 
				
			||||||
    output = []
 | 
					 | 
				
			||||||
    for f in vexs:
 | 
					 | 
				
			||||||
        if len(f) > 1:
 | 
					 | 
				
			||||||
            output.append(f)
 | 
					 | 
				
			||||||
    output.sort()
 | 
					 | 
				
			||||||
    return output
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def getName(filePath):
 | 
					 | 
				
			||||||
    return os.path.basename(filePath).replace("/", "_").replace(".", "_")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def genCppH(hFilePath, cppFilePath, templateGlslPaths, tmpDirPath, env):
 | 
					 | 
				
			||||||
    print("hFilePath:{}".format(hFilePath))
 | 
					 | 
				
			||||||
    print("cppFilePath:{}".format(cppFilePath))
 | 
					 | 
				
			||||||
    h = "#pragma once\n"
 | 
					 | 
				
			||||||
    nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
 | 
					 | 
				
			||||||
    nsend = "\n} } } //namespace at::native::vulkan\n"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    h += nsbegin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cpp = "#include <ATen/native/vulkan/{}>".format(H_NAME)
 | 
					 | 
				
			||||||
    cpp += nsbegin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for templateGlslPath in templateGlslPaths:
 | 
					 | 
				
			||||||
        name = getName(templateGlslPath)
 | 
					 | 
				
			||||||
        h += "extern const char* " + name + ";\n"
 | 
					 | 
				
			||||||
        cpp += "const char* " + name + " = \n"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        codeTemplate = CodeTemplate.from_file(templateGlslPath)
 | 
					 | 
				
			||||||
        srcPath = tmpDirPath + "/" + name + ".glsl"
 | 
					 | 
				
			||||||
        content = codeTemplate.substitute(env)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        lines = content.split("\n")
 | 
					 | 
				
			||||||
        for l in lines:
 | 
					 | 
				
			||||||
            if (len(l) < 1):
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
            cpp += "\"" + l + "\\n\"\n"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cpp += ";\n"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cpp += nsend
 | 
					 | 
				
			||||||
    h += nsend
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    with open(hFilePath, "w") as f:
 | 
					 | 
				
			||||||
        f.write(h)
 | 
					 | 
				
			||||||
    with open(cppFilePath, "w") as f:
 | 
					 | 
				
			||||||
        f.write(cpp)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def parse_arg_env(items):
 | 
					 | 
				
			||||||
    d = {}
 | 
					 | 
				
			||||||
    if items:
 | 
					 | 
				
			||||||
        for item in items:
 | 
					 | 
				
			||||||
            tokens = item.split("=")
 | 
					 | 
				
			||||||
            key = tokens[0].strip()
 | 
					 | 
				
			||||||
            value = tokens[1].strip()
 | 
					 | 
				
			||||||
            d[key] = value
 | 
					 | 
				
			||||||
    return d
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def main(argv):
 | 
					 | 
				
			||||||
    parser = argparse.ArgumentParser(description='Generate glsl.cpp and glsl.h containing glsl sources')
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        '-i',
 | 
					 | 
				
			||||||
        '--glsl-path',
 | 
					 | 
				
			||||||
        help='path to directory with glsl to process',
 | 
					 | 
				
			||||||
        required=True,
 | 
					 | 
				
			||||||
        default='.')
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        '-o',
 | 
					 | 
				
			||||||
        '--output-path',
 | 
					 | 
				
			||||||
        help='path to directory to generate glsl.h glsl.cpp (cpp namespace at::native::vulkan)',
 | 
					 | 
				
			||||||
        required=True)
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        '-t',
 | 
					 | 
				
			||||||
        '--tmp-dir-path',
 | 
					 | 
				
			||||||
        required=True,
 | 
					 | 
				
			||||||
        help='/tmp')
 | 
					 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        "--env",
 | 
					 | 
				
			||||||
        metavar="KEY=VALUE",
 | 
					 | 
				
			||||||
        nargs='*',
 | 
					 | 
				
			||||||
        help="Set a number of key-value pairs")
 | 
					 | 
				
			||||||
    options = parser.parse_args()
 | 
					 | 
				
			||||||
    if not os.path.exists(options.tmp_dir_path):
 | 
					 | 
				
			||||||
        os.makedirs(options.tmp_dir_path)
 | 
					 | 
				
			||||||
    env = DEFAULT_ENV
 | 
					 | 
				
			||||||
    for key, value in parse_arg_env(options.env).items():
 | 
					 | 
				
			||||||
        env[key] = value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if not os.path.exists(options.output_path):
 | 
					 | 
				
			||||||
        os.makedirs(options.output_path)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    glsls = findAllGlsls(options.glsl_path)
 | 
					 | 
				
			||||||
    genCppH(
 | 
					 | 
				
			||||||
        options.output_path + "/" + H_NAME, options.output_path + "/" + CPP_NAME,
 | 
					 | 
				
			||||||
        glsls,
 | 
					 | 
				
			||||||
        tmpDirPath=options.tmp_dir_path,
 | 
					 | 
				
			||||||
        env=env)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == '__main__':
 | 
					 | 
				
			||||||
    sys.exit(main(sys.argv))
 | 
					 | 
				
			||||||
@ -1,9 +1,5 @@
 | 
				
			|||||||
#include <ATen/native/vulkan/api/Shader.h>
 | 
					#include <ATen/native/vulkan/api/Shader.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef USE_VULKAN_SHADERC_RUNTIME
 | 
					 | 
				
			||||||
#include <shaderc/shaderc.hpp>
 | 
					 | 
				
			||||||
#endif /* USE_VULKAN_SHADERC_RUNTIME */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace at {
 | 
					namespace at {
 | 
				
			||||||
namespace native {
 | 
					namespace native {
 | 
				
			||||||
namespace vulkan {
 | 
					namespace vulkan {
 | 
				
			||||||
@ -14,39 +10,20 @@ namespace api {
 | 
				
			|||||||
//
 | 
					//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ShaderInfo::ShaderInfo()
 | 
					ShaderInfo::ShaderInfo()
 | 
				
			||||||
    : type(ShaderInfo::Type::SPIRV),
 | 
					    : src_code{
 | 
				
			||||||
      src_code{
 | 
					 | 
				
			||||||
          .spirv =
 | 
					 | 
				
			||||||
              {
 | 
					 | 
				
			||||||
          nullptr,
 | 
					          nullptr,
 | 
				
			||||||
          0u,
 | 
					          0u,
 | 
				
			||||||
              },
 | 
					 | 
				
			||||||
      } {}
 | 
					      } {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ShaderInfo::ShaderInfo(std::string name, const char* const glsl_src)
 | 
					 | 
				
			||||||
    : type(ShaderInfo::Type::GLSL),
 | 
					 | 
				
			||||||
      src_code{
 | 
					 | 
				
			||||||
          .glsl =
 | 
					 | 
				
			||||||
              {
 | 
					 | 
				
			||||||
                  glsl_src,
 | 
					 | 
				
			||||||
                  0u,
 | 
					 | 
				
			||||||
              },
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      kernel_name{std::move(name)} {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
ShaderInfo::ShaderInfo(
 | 
					ShaderInfo::ShaderInfo(
 | 
				
			||||||
    std::string name,
 | 
					    std::string name,
 | 
				
			||||||
    const uint32_t* const spirv_bin,
 | 
					    const uint32_t* const spirv_bin,
 | 
				
			||||||
    const uint32_t size,
 | 
					    const uint32_t size,
 | 
				
			||||||
    const std::vector<VkDescriptorType>& layout)
 | 
					    const std::vector<VkDescriptorType>& layout)
 | 
				
			||||||
    : type(Type::SPIRV),
 | 
					    : src_code{
 | 
				
			||||||
      src_code{
 | 
					 | 
				
			||||||
          .spirv =
 | 
					 | 
				
			||||||
              {
 | 
					 | 
				
			||||||
          spirv_bin,
 | 
					          spirv_bin,
 | 
				
			||||||
          size,
 | 
					          size,
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      kernel_name{std::move(name)},
 | 
					      kernel_name{std::move(name)},
 | 
				
			||||||
      kernel_layout{layout} {}
 | 
					      kernel_layout{layout} {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -58,14 +35,10 @@ ShaderInfo::ShaderInfo(
 | 
				
			|||||||
    const std::vector<uint32_t>& tile_size,
 | 
					    const std::vector<uint32_t>& tile_size,
 | 
				
			||||||
    const StorageType bias_storage_type,
 | 
					    const StorageType bias_storage_type,
 | 
				
			||||||
    const StorageType weight_storage_type)
 | 
					    const StorageType weight_storage_type)
 | 
				
			||||||
    : type(Type::SPIRV),
 | 
					    : src_code{
 | 
				
			||||||
      src_code{
 | 
					 | 
				
			||||||
          .spirv =
 | 
					 | 
				
			||||||
              {
 | 
					 | 
				
			||||||
          spirv_bin,
 | 
					          spirv_bin,
 | 
				
			||||||
          size,
 | 
					          size,
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      kernel_name{std::move(name)},
 | 
					      kernel_name{std::move(name)},
 | 
				
			||||||
      kernel_layout{layout},
 | 
					      kernel_layout{layout},
 | 
				
			||||||
      tile_size(tile_size),
 | 
					      tile_size(tile_size),
 | 
				
			||||||
@ -77,17 +50,9 @@ ShaderInfo::ShaderInfo(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
 | 
					bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
 | 
				
			||||||
  if (_1.type != _2.type) {
 | 
					 | 
				
			||||||
    return false;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (_1.type == ShaderInfo::Type::SPIRV) {
 | 
					 | 
				
			||||||
  return (
 | 
					  return (
 | 
				
			||||||
        _1.src_code.spirv.bin == _2.src_code.spirv.bin &&
 | 
					      _1.src_code.bin == _2.src_code.bin &&
 | 
				
			||||||
        _1.src_code.spirv.size == _2.src_code.spirv.size);
 | 
					      _1.src_code.size == _2.src_code.size);
 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    return (_1.src_code.glsl.src == _2.src_code.glsl.src);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
@ -153,8 +118,8 @@ void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source)
 | 
					ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source)
 | 
				
			||||||
    : device_(device), handle_{VK_NULL_HANDLE} {
 | 
					    : device_(device), handle_{VK_NULL_HANDLE} {
 | 
				
			||||||
  const uint32_t* code = source.src_code.spirv.bin;
 | 
					  const uint32_t* code = source.src_code.bin;
 | 
				
			||||||
  uint32_t size = source.src_code.spirv.size;
 | 
					  uint32_t size = source.src_code.size;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const VkShaderModuleCreateInfo shader_module_create_info{
 | 
					  const VkShaderModuleCreateInfo shader_module_create_info{
 | 
				
			||||||
      VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
 | 
					      VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
 | 
				
			||||||
 | 
				
			|||||||
@ -45,17 +45,9 @@ class ShaderLayout final {
 | 
				
			|||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ShaderInfo final {
 | 
					struct ShaderInfo final {
 | 
				
			||||||
  enum class Type { GLSL, SPIRV } type;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  union {
 | 
					 | 
				
			||||||
    struct {
 | 
					 | 
				
			||||||
      const char* src; // Null-terminated
 | 
					 | 
				
			||||||
      uint32_t unused; // padding
 | 
					 | 
				
			||||||
    } glsl;
 | 
					 | 
				
			||||||
  struct {
 | 
					  struct {
 | 
				
			||||||
    const uint32_t* bin;
 | 
					    const uint32_t* bin;
 | 
				
			||||||
    uint32_t size;
 | 
					    uint32_t size;
 | 
				
			||||||
    } spirv;
 | 
					 | 
				
			||||||
  } src_code;
 | 
					  } src_code;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  std::string kernel_name{""};
 | 
					  std::string kernel_name{""};
 | 
				
			||||||
@ -171,8 +163,7 @@ class ShaderCache final {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  struct Hasher {
 | 
					  struct Hasher {
 | 
				
			||||||
    inline size_t operator()(const ShaderInfo& source) const {
 | 
					    inline size_t operator()(const ShaderInfo& source) const {
 | 
				
			||||||
      return c10::get_hash(
 | 
					      return c10::get_hash(source.src_code.bin, source.src_code.size);
 | 
				
			||||||
          source.type, source.src_code.spirv.bin, source.src_code.spirv.size);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -6,18 +6,9 @@
 | 
				
			|||||||
#include <ATen/core/Tensor.h>
 | 
					#include <ATen/core/Tensor.h>
 | 
				
			||||||
#include <ATen/native/vulkan/api/api.h>
 | 
					#include <ATen/native/vulkan/api/api.h>
 | 
				
			||||||
#include <ATen/native/vulkan/ops/Convert.h>
 | 
					#include <ATen/native/vulkan/ops/Convert.h>
 | 
				
			||||||
 | 
					 | 
				
			||||||
#define CONCAT_LITERALS(a, b) #a #b
 | 
					 | 
				
			||||||
#ifdef USE_VULKAN_SHADERC_RUNTIME
 | 
					 | 
				
			||||||
#include <ATen/native/vulkan/glsl.h>
 | 
					 | 
				
			||||||
#define VK_KERNEL(name)                          \
 | 
					 | 
				
			||||||
  ::at::native::vulkan::api::ShaderInfo {        \
 | 
					 | 
				
			||||||
    CONCAT_LITERALS(vulkan., name), name##_glsl, \
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
#else
 | 
					 | 
				
			||||||
#include <ATen/native/vulkan/spv.h>
 | 
					#include <ATen/native/vulkan/spv.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
 | 
					#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
 | 
				
			||||||
#endif /* USE_VULKAN_SHADERC_RUNTIME */
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace at {
 | 
					namespace at {
 | 
				
			||||||
namespace native {
 | 
					namespace native {
 | 
				
			||||||
 | 
				
			|||||||
@ -175,7 +175,6 @@ function(caffe2_print_configuration_summary)
 | 
				
			|||||||
  if(${USE_VULKAN})
 | 
					  if(${USE_VULKAN})
 | 
				
			||||||
    message(STATUS "    USE_VULKAN_FP16_INFERENCE    : ${USE_VULKAN_FP16_INFERENCE}")
 | 
					    message(STATUS "    USE_VULKAN_FP16_INFERENCE    : ${USE_VULKAN_FP16_INFERENCE}")
 | 
				
			||||||
    message(STATUS "    USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}")
 | 
					    message(STATUS "    USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}")
 | 
				
			||||||
    message(STATUS "    USE_VULKAN_SHADERC_RUNTIME   : ${USE_VULKAN_SHADERC_RUNTIME}")
 | 
					 | 
				
			||||||
  endif()
 | 
					  endif()
 | 
				
			||||||
  message(STATUS "  USE_PROF              : ${USE_PROF}")
 | 
					  message(STATUS "  USE_PROF              : ${USE_PROF}")
 | 
				
			||||||
  message(STATUS "  USE_QNNPACK           : ${USE_QNNPACK}")
 | 
					  message(STATUS "  USE_QNNPACK           : ${USE_QNNPACK}")
 | 
				
			||||||
 | 
				
			|||||||
@ -13,39 +13,14 @@ if(USE_VULKAN_FP16_INFERENCE)
 | 
				
			|||||||
  list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f")
 | 
					  list(APPEND VULKAN_GEN_ARG_ENV "format=rgba16f")
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if(USE_VULKAN_SHADERC_RUNTIME)
 | 
					# Precompiling shaders
 | 
				
			||||||
  set(PYTHONPATH "$ENV{PYTHONPATH}")
 | 
					if(ANDROID)
 | 
				
			||||||
  set(NEW_PYTHONPATH ${PYTHONPATH})
 | 
					 | 
				
			||||||
  list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
 | 
					 | 
				
			||||||
  set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
 | 
					 | 
				
			||||||
  execute_process(
 | 
					 | 
				
			||||||
    COMMAND
 | 
					 | 
				
			||||||
    "${PYTHON_EXECUTABLE}"
 | 
					 | 
				
			||||||
    ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_glsl.py
 | 
					 | 
				
			||||||
    --glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
 | 
					 | 
				
			||||||
    --output-path ${VULKAN_GEN_OUTPUT_PATH}
 | 
					 | 
				
			||||||
    --tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/glsl
 | 
					 | 
				
			||||||
    --env ${VULKAN_GEN_ARG_ENV}
 | 
					 | 
				
			||||||
    RESULT_VARIABLE error_code)
 | 
					 | 
				
			||||||
  set(ENV{PYTHONPATH} ${PYTHONPATH})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if(error_code)
 | 
					 | 
				
			||||||
    message(FATAL_ERROR "Failed to gen glsl.h and glsl.cpp with shaders sources for Vulkan backend")
 | 
					 | 
				
			||||||
  endif()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/glsl.cpp)
 | 
					 | 
				
			||||||
  return()
 | 
					 | 
				
			||||||
endif()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if(NOT USE_VULKAN_SHADERC_RUNTIME)
 | 
					 | 
				
			||||||
  # Precompiling shaders
 | 
					 | 
				
			||||||
  if(ANDROID)
 | 
					 | 
				
			||||||
  if(NOT ANDROID_NDK)
 | 
					  if(NOT ANDROID_NDK)
 | 
				
			||||||
    message(FATAL_ERROR "ANDROID_NDK not set")
 | 
					    message(FATAL_ERROR "ANDROID_NDK not set")
 | 
				
			||||||
  endif()
 | 
					  endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc")
 | 
					  set(GLSLC_PATH "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc")
 | 
				
			||||||
  else()
 | 
					else()
 | 
				
			||||||
  find_program(
 | 
					  find_program(
 | 
				
			||||||
    GLSLC_PATH glslc
 | 
					    GLSLC_PATH glslc
 | 
				
			||||||
    PATHS
 | 
					    PATHS
 | 
				
			||||||
@ -57,13 +32,13 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME)
 | 
				
			|||||||
  if(NOT GLSLC_PATH)
 | 
					  if(NOT GLSLC_PATH)
 | 
				
			||||||
    message(FATAL_ERROR "USE_VULKAN glslc not found")
 | 
					    message(FATAL_ERROR "USE_VULKAN glslc not found")
 | 
				
			||||||
  endif(NOT GLSLC_PATH)
 | 
					  endif(NOT GLSLC_PATH)
 | 
				
			||||||
  endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  set(PYTHONPATH "$ENV{PYTHONPATH}")
 | 
					set(PYTHONPATH "$ENV{PYTHONPATH}")
 | 
				
			||||||
  set(NEW_PYTHONPATH ${PYTHONPATH})
 | 
					set(NEW_PYTHONPATH ${PYTHONPATH})
 | 
				
			||||||
  list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
 | 
					list(APPEND NEW_PYTHONPATH "${CMAKE_CURRENT_LIST_DIR}/..")
 | 
				
			||||||
  set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
 | 
					set(ENV{PYTHONPATH} ${NEW_PYTHONPATH})
 | 
				
			||||||
  execute_process(
 | 
					execute_process(
 | 
				
			||||||
  COMMAND
 | 
					  COMMAND
 | 
				
			||||||
  "${PYTHON_EXECUTABLE}"
 | 
					  "${PYTHON_EXECUTABLE}"
 | 
				
			||||||
  ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
 | 
					  ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
 | 
				
			||||||
@ -73,11 +48,10 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME)
 | 
				
			|||||||
  --tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv
 | 
					  --tmp-dir-path=${CMAKE_BINARY_DIR}/vulkan/spv
 | 
				
			||||||
  --env ${VULKAN_GEN_ARG_ENV}
 | 
					  --env ${VULKAN_GEN_ARG_ENV}
 | 
				
			||||||
  RESULT_VARIABLE error_code)
 | 
					  RESULT_VARIABLE error_code)
 | 
				
			||||||
  set(ENV{PYTHONPATH} ${PYTHONPATH})
 | 
					set(ENV{PYTHONPATH} ${PYTHONPATH})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if(error_code)
 | 
					  if(error_code)
 | 
				
			||||||
    message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend")
 | 
					    message(FATAL_ERROR "Failed to gen spv.h and spv.cpp with precompiled shaders for Vulkan backend")
 | 
				
			||||||
  endif()
 | 
					  endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp)
 | 
					set(vulkan_generated_cpp ${VULKAN_GEN_OUTPUT_PATH}/spv.cpp)
 | 
				
			||||||
endif()
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -29,60 +29,6 @@ if(ANDROID)
 | 
				
			|||||||
  list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR})
 | 
					  list(APPEND Vulkan_INCLUDES ${VULKAN_WRAPPER_DIR})
 | 
				
			||||||
  list(APPEND Vulkan_LIBS VulkanWrapper)
 | 
					  list(APPEND Vulkan_LIBS VulkanWrapper)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Shaderc
 | 
					 | 
				
			||||||
  if(USE_VULKAN_SHADERC_RUNTIME)
 | 
					 | 
				
			||||||
    # Shaderc from ANDROID_NDK
 | 
					 | 
				
			||||||
    set(Shaderc_ANDROID_NDK_INCLUDE_DIR "${ANDROID_NDK}/sources/third_party/shaderc/include")
 | 
					 | 
				
			||||||
    message(STATUS "Shaderc_ANDROID_NDK_INCLUDE_DIR:${Shaderc_ANDROID_NDK_INCLUDE_DIR}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    find_path(
 | 
					 | 
				
			||||||
      GOOGLE_SHADERC_INCLUDE_DIRS
 | 
					 | 
				
			||||||
      NAMES shaderc/shaderc.hpp
 | 
					 | 
				
			||||||
      PATHS "${Shaderc_ANDROID_NDK_INCLUDE_DIR}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    set(Shaderc_ANDROID_NDK_LIB_DIR "${ANDROID_NDK}/sources/third_party/shaderc/libs/${ANDROID_STL}/${ANDROID_ABI}")
 | 
					 | 
				
			||||||
    message(STATUS "Shaderc_ANDROID_NDK_LIB_DIR:${Shaderc_ANDROID_NDK_LIB_DIR}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    find_library(
 | 
					 | 
				
			||||||
      GOOGLE_SHADERC_LIBRARIES
 | 
					 | 
				
			||||||
      NAMES shaderc
 | 
					 | 
				
			||||||
      PATHS "${Shaderc_ANDROID_NDK_LIB_DIR}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Shaderc in NDK is not prebuilt
 | 
					 | 
				
			||||||
    if(NOT GOOGLE_SHADERC_LIBRARIES)
 | 
					 | 
				
			||||||
      set(NDK_SHADERC_DIR "${ANDROID_NDK}/sources/third_party/shaderc")
 | 
					 | 
				
			||||||
      set(NDK_BUILD_CMD "${ANDROID_NDK}/ndk-build")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      execute_process(
 | 
					 | 
				
			||||||
        COMMAND ${NDK_BUILD_CMD}
 | 
					 | 
				
			||||||
        NDK_PROJECT_PATH=${NDK_SHADERC_DIR}
 | 
					 | 
				
			||||||
        APP_BUILD_SCRIPT=${NDK_SHADERC_DIR}/Android.mk
 | 
					 | 
				
			||||||
        APP_PLATFORM=${ANDROID_PLATFORM}
 | 
					 | 
				
			||||||
        APP_STL=${ANDROID_STL}
 | 
					 | 
				
			||||||
        APP_ABI=${ANDROID_ABI}
 | 
					 | 
				
			||||||
        libshaderc_combined -j8
 | 
					 | 
				
			||||||
        WORKING_DIRECTORY "${NDK_SHADERC_DIR}"
 | 
					 | 
				
			||||||
        RESULT_VARIABLE error_code)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      if(error_code)
 | 
					 | 
				
			||||||
        message(FATAL_ERROR "Failed to build ANDROID_NDK Shaderc error_code:${error_code}")
 | 
					 | 
				
			||||||
      else()
 | 
					 | 
				
			||||||
        unset(GOOGLE_SHADERC_LIBRARIES CACHE)
 | 
					 | 
				
			||||||
        find_library(
 | 
					 | 
				
			||||||
          GOOGLE_SHADERC_LIBRARIES
 | 
					 | 
				
			||||||
          NAMES shaderc
 | 
					 | 
				
			||||||
          HINTS "${Shaderc_ANDROID_NDK_LIB_DIR}")
 | 
					 | 
				
			||||||
      endif()
 | 
					 | 
				
			||||||
    endif(NOT GOOGLE_SHADERC_LIBRARIES)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if(GOOGLE_SHADERC_INCLUDE_DIRS AND GOOGLE_SHADERC_LIBRARIES)
 | 
					 | 
				
			||||||
      message(STATUS "Shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}")
 | 
					 | 
				
			||||||
      message(STATUS "Shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}")
 | 
					 | 
				
			||||||
    endif()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS})
 | 
					 | 
				
			||||||
    list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES})
 | 
					 | 
				
			||||||
  endif(USE_VULKAN_SHADERC_RUNTIME)
 | 
					 | 
				
			||||||
else()
 | 
					else()
 | 
				
			||||||
  find_package(Vulkan)
 | 
					  find_package(Vulkan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -95,32 +41,4 @@ else()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR})
 | 
					  set(GOOGLE_SHADERC_INCLUDE_SEARCH_PATH ${Vulkan_INCLUDE_DIR})
 | 
				
			||||||
  set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY})
 | 
					  set(GOOGLE_SHADERC_LIBRARY_SEARCH_PATH ${Vulkan_LIBRARY})
 | 
				
			||||||
 | 
					 | 
				
			||||||
  if(USE_VULKAN_SHADERC_RUNTIME)
 | 
					 | 
				
			||||||
    find_path(
 | 
					 | 
				
			||||||
        GOOGLE_SHADERC_INCLUDE_DIRS
 | 
					 | 
				
			||||||
        NAMES shaderc/shaderc.hpp
 | 
					 | 
				
			||||||
        PATHS ${GOOGLE_SHADERC_INCLUDE_SEARCH_PATH})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    find_library(
 | 
					 | 
				
			||||||
        GOOGLE_SHADERC_LIBRARIES
 | 
					 | 
				
			||||||
        NAMES shaderc_combined
 | 
					 | 
				
			||||||
        PATHS ${GOOGLE_SHADERC_LIBRARY_SEARCH_PATH})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    find_package_handle_standard_args(
 | 
					 | 
				
			||||||
        Shaderc
 | 
					 | 
				
			||||||
        DEFAULT_MSG
 | 
					 | 
				
			||||||
        GOOGLE_SHADERC_INCLUDE_DIRS
 | 
					 | 
				
			||||||
        GOOGLE_SHADERC_LIBRARIES)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if(NOT Shaderc_FOUND)
 | 
					 | 
				
			||||||
      message(FATAL_ERROR "USE_VULKAN: Shaderc not found in VULKAN_SDK")
 | 
					 | 
				
			||||||
    else()
 | 
					 | 
				
			||||||
      message(STATUS "shaderc FOUND include:${GOOGLE_SHADERC_INCLUDE_DIRS}")
 | 
					 | 
				
			||||||
      message(STATUS "shaderc FOUND libs:${GOOGLE_SHADERC_LIBRARIES}")
 | 
					 | 
				
			||||||
    endif()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    list(APPEND Vulkan_INCLUDES ${GOOGLE_SHADERC_INCLUDE_DIRS})
 | 
					 | 
				
			||||||
    list(APPEND Vulkan_LIBS ${GOOGLE_SHADERC_LIBRARIES})
 | 
					 | 
				
			||||||
  endif(USE_VULKAN_SHADERC_RUNTIME)
 | 
					 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
				
			|||||||
@ -157,9 +157,6 @@ if [ -n "${USE_VULKAN}" ]; then
 | 
				
			|||||||
  if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then
 | 
					  if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then
 | 
				
			||||||
    CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON")
 | 
					    CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON")
 | 
				
			||||||
  fi
 | 
					  fi
 | 
				
			||||||
  if [ -n "${USE_VULKAN_SHADERC_RUNTIME}" ]; then
 | 
					 | 
				
			||||||
    CMAKE_ARGS+=("-DUSE_VULKAN_SHADERC_RUNTIME=ON")
 | 
					 | 
				
			||||||
  fi
 | 
					 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Use-specified CMake arguments go last to allow overridding defaults
 | 
					# Use-specified CMake arguments go last to allow overridding defaults
 | 
				
			||||||
 | 
				
			|||||||
@ -211,18 +211,6 @@ def define_tools_targets(
 | 
				
			|||||||
        srcs = [
 | 
					        srcs = [
 | 
				
			||||||
            "gen_vulkan_spv.py",
 | 
					            "gen_vulkan_spv.py",
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
        base_module = "",
 | 
					 | 
				
			||||||
        deps = [
 | 
					 | 
				
			||||||
            torchgen_deps,
 | 
					 | 
				
			||||||
            ":gen_aten_vulkan_glsl_lib",
 | 
					 | 
				
			||||||
        ],
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    python_library(
 | 
					 | 
				
			||||||
        name = "gen_aten_vulkan_glsl_lib",
 | 
					 | 
				
			||||||
        srcs = [
 | 
					 | 
				
			||||||
            "gen_vulkan_glsl.py",
 | 
					 | 
				
			||||||
        ],
 | 
					 | 
				
			||||||
        base_module = "tools",
 | 
					        base_module = "tools",
 | 
				
			||||||
        deps = [
 | 
					        deps = [
 | 
				
			||||||
            torchgen_deps,
 | 
					            torchgen_deps,
 | 
				
			||||||
@ -231,12 +219,11 @@ def define_tools_targets(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    python_binary(
 | 
					    python_binary(
 | 
				
			||||||
        name = "gen_aten_vulkan_spv_bin",
 | 
					        name = "gen_aten_vulkan_spv_bin",
 | 
				
			||||||
        main_module = "gen_vulkan_spv",
 | 
					        main_module = "tools.gen_vulkan_spv",
 | 
				
			||||||
        visibility = [
 | 
					        visibility = [
 | 
				
			||||||
            "PUBLIC",
 | 
					            "PUBLIC",
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
        deps = [
 | 
					        deps = [
 | 
				
			||||||
            ":gen_aten_vulkan_glsl_lib",
 | 
					 | 
				
			||||||
            ":gen_aten_vulkan_spv_lib",
 | 
					            ":gen_aten_vulkan_spv_lib",
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@ -249,7 +236,6 @@ def define_tools_targets(
 | 
				
			|||||||
        contacts = contacts,
 | 
					        contacts = contacts,
 | 
				
			||||||
        visibility = ["PUBLIC"],
 | 
					        visibility = ["PUBLIC"],
 | 
				
			||||||
        deps = [
 | 
					        deps = [
 | 
				
			||||||
            ":gen_aten_vulkan_glsl_lib",
 | 
					 | 
				
			||||||
            ":gen_aten_vulkan_spv_lib",
 | 
					            ":gen_aten_vulkan_spv_lib",
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
				
			|||||||
@ -1,111 +0,0 @@
 | 
				
			|||||||
import copy
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from collections import OrderedDict
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import yaml
 | 
					 | 
				
			||||||
from torchgen.code_template import CodeTemplate
 | 
					 | 
				
			||||||
from yaml.constructor import ConstructorError
 | 
					 | 
				
			||||||
from yaml.nodes import MappingNode
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
try:
 | 
					 | 
				
			||||||
    from yaml import CLoader as Loader
 | 
					 | 
				
			||||||
except ImportError:
 | 
					 | 
				
			||||||
    from yaml import Loader  # type: ignore[misc]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# https://gist.github.com/pypt/94d747fe5180851196eb
 | 
					 | 
				
			||||||
class UniqueKeyLoader(Loader):
 | 
					 | 
				
			||||||
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
 | 
					 | 
				
			||||||
        if not isinstance(node, MappingNode):
 | 
					 | 
				
			||||||
            raise ConstructorError(
 | 
					 | 
				
			||||||
                None,
 | 
					 | 
				
			||||||
                None,
 | 
					 | 
				
			||||||
                "expected a mapping node, but found %s" % node.id,
 | 
					 | 
				
			||||||
                node.start_mark,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        mapping = {}
 | 
					 | 
				
			||||||
        for key_node, value_node in node.value:
 | 
					 | 
				
			||||||
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                hash(key)
 | 
					 | 
				
			||||||
            except TypeError as e:
 | 
					 | 
				
			||||||
                raise ConstructorError(
 | 
					 | 
				
			||||||
                    "while constructing a mapping",
 | 
					 | 
				
			||||||
                    node.start_mark,
 | 
					 | 
				
			||||||
                    "found unacceptable key ",
 | 
					 | 
				
			||||||
                    key_node.start_mark,
 | 
					 | 
				
			||||||
                ) from e
 | 
					 | 
				
			||||||
            # check for duplicate keys
 | 
					 | 
				
			||||||
            if key in mapping:
 | 
					 | 
				
			||||||
                raise ConstructorError(
 | 
					 | 
				
			||||||
                    "while constructing a mapping",
 | 
					 | 
				
			||||||
                    node.start_mark,
 | 
					 | 
				
			||||||
                    "found duplicate key",
 | 
					 | 
				
			||||||
                    key_node.start_mark,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
 | 
					 | 
				
			||||||
            mapping[key] = value
 | 
					 | 
				
			||||||
        return mapping
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class GLSLGenerator(object):
 | 
					 | 
				
			||||||
    standard_header = """
 | 
					 | 
				
			||||||
#version 450 core
 | 
					 | 
				
			||||||
#define PRECISION $precision
 | 
					 | 
				
			||||||
#define FORMAT $format
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self):  # type: ignore[no-untyped-def]
 | 
					 | 
				
			||||||
        self.ops_template_params = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
 | 
					 | 
				
			||||||
        all_template_params = OrderedDict()
 | 
					 | 
				
			||||||
        with open(parameters_yaml_file, "r") as f:
 | 
					 | 
				
			||||||
            contents = yaml.load(f, Loader=UniqueKeyLoader)
 | 
					 | 
				
			||||||
            for key in contents:
 | 
					 | 
				
			||||||
                all_template_params[key] = contents[key]
 | 
					 | 
				
			||||||
        self.validate_and_construct_op_params(all_template_params)  # type: ignore[no-untyped-call]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
 | 
					 | 
				
			||||||
        for op in all_template_params:
 | 
					 | 
				
			||||||
            if op in self.ops_template_params:
 | 
					 | 
				
			||||||
                raise KeyError(f"{op} params file has already been parsed")
 | 
					 | 
				
			||||||
            op_params_default_vals = all_template_params[op][
 | 
					 | 
				
			||||||
                "parameter_names_with_default_values"
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            template_params_set = set(op_params_default_vals.keys())
 | 
					 | 
				
			||||||
            self.ops_template_params[op] = []
 | 
					 | 
				
			||||||
            self.ops_template_params[op].append(op_params_default_vals)
 | 
					 | 
				
			||||||
            op_template_params_values = all_template_params[op]["parameter_values"]
 | 
					 | 
				
			||||||
            for param_vals in op_template_params_values:
 | 
					 | 
				
			||||||
                param_vals_set = set(param_vals.keys())
 | 
					 | 
				
			||||||
                invalid_keys = param_vals_set - template_params_set
 | 
					 | 
				
			||||||
                if (len(invalid_keys)) > 0:
 | 
					 | 
				
			||||||
                    raise KeyError(f"Invalid keys {invalid_keys} are found")
 | 
					 | 
				
			||||||
                param_vals_copy = copy.deepcopy(op_params_default_vals)
 | 
					 | 
				
			||||||
                for key in param_vals:
 | 
					 | 
				
			||||||
                    param_vals_copy[key] = param_vals[key]
 | 
					 | 
				
			||||||
                self.ops_template_params[op].append(param_vals_copy)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
 | 
					 | 
				
			||||||
        glsl_template_name = os.path.basename(glsl_template_in)
 | 
					 | 
				
			||||||
        op_name, extension_name = glsl_template_name.split(".")
 | 
					 | 
				
			||||||
        if extension_name != "glslt":
 | 
					 | 
				
			||||||
            raise TypeError(f"invalid file type for glsl template {extension_name}")
 | 
					 | 
				
			||||||
        if op_name not in self.ops_template_params:
 | 
					 | 
				
			||||||
            raise KeyError(f"{op_name} params have not been populated")
 | 
					 | 
				
			||||||
        code_template = CodeTemplate.from_file(glsl_template_in)
 | 
					 | 
				
			||||||
        for template_params in self.ops_template_params[op_name]:
 | 
					 | 
				
			||||||
            content = GLSLGenerator.standard_header
 | 
					 | 
				
			||||||
            param_vals_string = "x".join([str(i) for i in template_params.values()])
 | 
					 | 
				
			||||||
            output_file_name = op_name + "_" + param_vals_string + ".glsl"
 | 
					 | 
				
			||||||
            content += code_template.substitute(template_params)
 | 
					 | 
				
			||||||
            output_file = os.path.join(out_dir, output_file_name)
 | 
					 | 
				
			||||||
            with open(output_file, "w") as f:
 | 
					 | 
				
			||||||
                f.write(content)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Remove this
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    pass
 | 
					 | 
				
			||||||
@ -2,21 +2,122 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
import array
 | 
					import array
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
import glob
 | 
					import glob
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
 | 
					import yaml
 | 
				
			||||||
 | 
					from collections import OrderedDict
 | 
				
			||||||
from torchgen.code_template import CodeTemplate
 | 
					from torchgen.code_template import CodeTemplate
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
from typing import List
 | 
					from typing import Any, Dict, List
 | 
				
			||||||
 | 
					from yaml.constructor import ConstructorError
 | 
				
			||||||
 | 
					from yaml.nodes import MappingNode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tools.gen_vulkan_glsl import GLSLGenerator
 | 
					try:
 | 
				
			||||||
 | 
					    from yaml import CLoader as Loader
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    from yaml import Loader  # type: ignore[misc]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
H_NAME = "spv.h"
 | 
					H_NAME = "spv.h"
 | 
				
			||||||
CPP_NAME = "spv.cpp"
 | 
					CPP_NAME = "spv.cpp"
 | 
				
			||||||
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
 | 
					DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# https://gist.github.com/pypt/94d747fe5180851196eb
 | 
				
			||||||
 | 
					class UniqueKeyLoader(Loader):
 | 
				
			||||||
 | 
					    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
 | 
				
			||||||
 | 
					        if not isinstance(node, MappingNode):
 | 
				
			||||||
 | 
					            raise ConstructorError(
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
 | 
					                None,
 | 
				
			||||||
 | 
					                "expected a mapping node, but found %s" % node.id,
 | 
				
			||||||
 | 
					                node.start_mark,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        mapping = {}
 | 
				
			||||||
 | 
					        for key_node, value_node in node.value:
 | 
				
			||||||
 | 
					            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                hash(key)
 | 
				
			||||||
 | 
					            except TypeError as e:
 | 
				
			||||||
 | 
					                raise ConstructorError(
 | 
				
			||||||
 | 
					                    "while constructing a mapping",
 | 
				
			||||||
 | 
					                    node.start_mark,
 | 
				
			||||||
 | 
					                    "found unacceptable key ",
 | 
				
			||||||
 | 
					                    key_node.start_mark,
 | 
				
			||||||
 | 
					                ) from e
 | 
				
			||||||
 | 
					            # check for duplicate keys
 | 
				
			||||||
 | 
					            if key in mapping:
 | 
				
			||||||
 | 
					                raise ConstructorError(
 | 
				
			||||||
 | 
					                    "while constructing a mapping",
 | 
				
			||||||
 | 
					                    node.start_mark,
 | 
				
			||||||
 | 
					                    "found duplicate key",
 | 
				
			||||||
 | 
					                    key_node.start_mark,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
 | 
				
			||||||
 | 
					            mapping[key] = value
 | 
				
			||||||
 | 
					        return mapping
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VulkanShaderGenerator(object):
 | 
				
			||||||
 | 
					    standard_header = """
 | 
				
			||||||
 | 
					#version 450 core
 | 
				
			||||||
 | 
					#define PRECISION $precision
 | 
				
			||||||
 | 
					#define FORMAT $format
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self: "VulkanShaderGenerator") -> None:
 | 
				
			||||||
 | 
					        self.ops_template_params: Dict[Any, Any] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
 | 
				
			||||||
 | 
					        all_template_params = OrderedDict()
 | 
				
			||||||
 | 
					        with open(parameters_yaml_file, "r") as f:
 | 
				
			||||||
 | 
					            contents = yaml.load(f, Loader=UniqueKeyLoader)
 | 
				
			||||||
 | 
					            for key in contents:
 | 
				
			||||||
 | 
					                all_template_params[key] = contents[key]
 | 
				
			||||||
 | 
					        self.validate_and_construct_op_params(all_template_params)  # type: ignore[no-untyped-call]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
 | 
				
			||||||
 | 
					        for op in all_template_params:
 | 
				
			||||||
 | 
					            if op in self.ops_template_params:
 | 
				
			||||||
 | 
					                raise KeyError(f"{op} params file has already been parsed")
 | 
				
			||||||
 | 
					            op_params_default_vals = all_template_params[op][
 | 
				
			||||||
 | 
					                "parameter_names_with_default_values"
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            template_params_set = set(op_params_default_vals.keys())
 | 
				
			||||||
 | 
					            self.ops_template_params[op] = []
 | 
				
			||||||
 | 
					            self.ops_template_params[op].append(op_params_default_vals)
 | 
				
			||||||
 | 
					            op_template_params_values = all_template_params[op]["parameter_values"]
 | 
				
			||||||
 | 
					            for param_vals in op_template_params_values:
 | 
				
			||||||
 | 
					                param_vals_set = set(param_vals.keys())
 | 
				
			||||||
 | 
					                missing_keys = template_params_set - param_vals_set
 | 
				
			||||||
 | 
					                invalid_keys = param_vals_set - template_params_set
 | 
				
			||||||
 | 
					                if (len(invalid_keys)) > 0:
 | 
				
			||||||
 | 
					                    raise KeyError(f"Invalid keys {invalid_keys} are found")
 | 
				
			||||||
 | 
					                param_vals_copy = copy.deepcopy(op_params_default_vals)
 | 
				
			||||||
 | 
					                for key in param_vals:
 | 
				
			||||||
 | 
					                    param_vals_copy[key] = param_vals[key]
 | 
				
			||||||
 | 
					                self.ops_template_params[op].append(param_vals_copy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
 | 
				
			||||||
 | 
					        glsl_template_name = os.path.basename(glsl_template_in)
 | 
				
			||||||
 | 
					        op_name, extension_name = glsl_template_name.split(".")
 | 
				
			||||||
 | 
					        if extension_name != "glslt":
 | 
				
			||||||
 | 
					            raise TypeError(f"invalid file type for glsl template {extension_name}")
 | 
				
			||||||
 | 
					        if op_name not in self.ops_template_params:
 | 
				
			||||||
 | 
					            raise KeyError(f"{op_name} params have not been populated")
 | 
				
			||||||
 | 
					        code_template = CodeTemplate.from_file(glsl_template_in)
 | 
				
			||||||
 | 
					        for template_params in self.ops_template_params[op_name]:
 | 
				
			||||||
 | 
					            content = VulkanShaderGenerator.standard_header
 | 
				
			||||||
 | 
					            param_vals_string = "x".join([str(i) for i in template_params.values()])
 | 
				
			||||||
 | 
					            output_file_name = op_name + "_" + param_vals_string + ".glsl"
 | 
				
			||||||
 | 
					            content += code_template.substitute(template_params)
 | 
				
			||||||
 | 
					            output_file = os.path.join(out_dir, output_file_name)
 | 
				
			||||||
 | 
					            with open(output_file, "w") as f:
 | 
				
			||||||
 | 
					                f.write(content)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class ShaderInfo:
 | 
					class ShaderInfo:
 | 
				
			||||||
@ -25,38 +126,44 @@ class ShaderInfo:
 | 
				
			|||||||
    weight_storage_type: str = ""
 | 
					    weight_storage_type: str = ""
 | 
				
			||||||
    bias_storage_type: str = ""
 | 
					    bias_storage_type: str = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getName(filePath):
 | 
					def getName(filePath: str) -> str:
 | 
				
			||||||
    return os.path.basename(filePath).replace("/", "_").replace(".", "_")
 | 
					    return os.path.basename(filePath).replace("/", "_").replace(".", "_")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def isDescriptorLine(lineStr):
 | 
					def isDescriptorLine(lineStr: str) -> bool:
 | 
				
			||||||
    descriptorLineId = r"^layout\(set"
 | 
					    descriptorLineId = r"^layout\(set"
 | 
				
			||||||
    return re.search(descriptorLineId, lineStr)
 | 
					    return re.search(descriptorLineId, lineStr) is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def isTileSizeLine(lineStr):
 | 
					def isTileSizeLine(lineStr: str) -> bool:
 | 
				
			||||||
    tile_size_id = r"^ \* TILE_SIZE = \("
 | 
					    tile_size_id = r"^ \* TILE_SIZE = \("
 | 
				
			||||||
    return re.search(tile_size_id, lineStr)
 | 
					    return re.search(tile_size_id, lineStr) is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def findTileSizes(lineStr):
 | 
					def findTileSizes(lineStr: str) -> List[int]:
 | 
				
			||||||
    tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
 | 
					    tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
 | 
				
			||||||
    matches = re.search(tile_size_id, lineStr)
 | 
					    matches = re.search(tile_size_id, lineStr)
 | 
				
			||||||
 | 
					    if matches is None:
 | 
				
			||||||
 | 
					        raise AssertionError("matches is None in findTileSizes")
 | 
				
			||||||
    return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
 | 
					    return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def isWeightStorageTypeLine(lineStr):
 | 
					def isWeightStorageTypeLine(lineStr: str) -> bool:
 | 
				
			||||||
    weight_storage_id = r"^ \* WEIGHT_STORAGE = "
 | 
					    weight_storage_id = r"^ \* WEIGHT_STORAGE = "
 | 
				
			||||||
    return re.search(weight_storage_id, lineStr)
 | 
					    return re.search(weight_storage_id, lineStr) is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getWeightStorageType(lineStr):
 | 
					def getWeightStorageType(lineStr: str) -> str:
 | 
				
			||||||
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
 | 
					    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
 | 
				
			||||||
    matches = re.search(weight_storage_id, lineStr)
 | 
					    matches = re.search(weight_storage_id, lineStr)
 | 
				
			||||||
 | 
					    if matches is None:
 | 
				
			||||||
 | 
					        raise AssertionError("matches is None in getWeightStorageType")
 | 
				
			||||||
    return matches.group(1)
 | 
					    return matches.group(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def isBiasStorageTypeLine(lineStr):
 | 
					def isBiasStorageTypeLine(lineStr: str) -> bool:
 | 
				
			||||||
    weight_storage_id = r"^ \* BIAS_STORAGE = "
 | 
					    weight_storage_id = r"^ \* BIAS_STORAGE = "
 | 
				
			||||||
    return re.search(weight_storage_id, lineStr)
 | 
					    return re.search(weight_storage_id, lineStr) is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getBiasStorageType(lineStr):
 | 
					def getBiasStorageType(lineStr: str) -> str:
 | 
				
			||||||
    weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
 | 
					    weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
 | 
				
			||||||
    matches = re.search(weight_storage_id, lineStr)
 | 
					    matches = re.search(weight_storage_id, lineStr)
 | 
				
			||||||
 | 
					    if matches is None:
 | 
				
			||||||
 | 
					        raise AssertionError("matches is None in getBiasStorageType")
 | 
				
			||||||
    return matches.group(1)
 | 
					    return matches.group(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
typeIdMapping = {
 | 
					typeIdMapping = {
 | 
				
			||||||
@ -73,12 +180,15 @@ storageTypeToEnum = {
 | 
				
			|||||||
    "": "api::StorageType::UNKNOWN",
 | 
					    "": "api::StorageType::UNKNOWN",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def determineDescriptorType(lineStr):
 | 
					def determineDescriptorType(lineStr: str) -> str:
 | 
				
			||||||
    for identifier, typeNum in typeIdMapping.items():
 | 
					    for identifier, typeNum in typeIdMapping.items():
 | 
				
			||||||
        if re.search(identifier, lineStr):
 | 
					        if re.search(identifier, lineStr):
 | 
				
			||||||
            return typeNum
 | 
					            return typeNum
 | 
				
			||||||
 | 
					    raise AssertionError(
 | 
				
			||||||
 | 
					        "No matching descriptor type for " + lineStr + " in determineDescriptorType"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getShaderInfo(srcFilePath):
 | 
					def getShaderInfo(srcFilePath: str) -> ShaderInfo:
 | 
				
			||||||
    shader_info = ShaderInfo([], [], "")
 | 
					    shader_info = ShaderInfo([], [], "")
 | 
				
			||||||
    with open(srcFilePath, 'r') as srcFile:
 | 
					    with open(srcFilePath, 'r') as srcFile:
 | 
				
			||||||
        for line in srcFile:
 | 
					        for line in srcFile:
 | 
				
			||||||
@ -93,14 +203,14 @@ def getShaderInfo(srcFilePath):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return shader_info
 | 
					    return shader_info
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
 | 
					def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
 | 
				
			||||||
    template_dir_path = os.path.join(src_dir_path, "templates")
 | 
					    template_dir_path = os.path.join(src_dir_path, "templates")
 | 
				
			||||||
    vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
 | 
					    vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
 | 
				
			||||||
    parameter_yaml_files = []
 | 
					    parameter_yaml_files = []
 | 
				
			||||||
    for f in vexs:
 | 
					    for f in vexs:
 | 
				
			||||||
        if len(f) > 1:
 | 
					        if len(f) > 1:
 | 
				
			||||||
            parameter_yaml_files.append(f)
 | 
					            parameter_yaml_files.append(f)
 | 
				
			||||||
    generator = GLSLGenerator()
 | 
					    generator = VulkanShaderGenerator()
 | 
				
			||||||
    for params_yaml in parameter_yaml_files:
 | 
					    for params_yaml in parameter_yaml_files:
 | 
				
			||||||
        generator.add_params_yaml(params_yaml)  # type: ignore[no-untyped-call]
 | 
					        generator.add_params_yaml(params_yaml)  # type: ignore[no-untyped-call]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -113,9 +223,20 @@ def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
 | 
				
			|||||||
    for glslt in templateSrcPaths:
 | 
					    for glslt in templateSrcPaths:
 | 
				
			||||||
        generator.generate(glslt, tmp_dir_path)  # type: ignore[no-untyped-call]
 | 
					        generator.generate(glslt, tmp_dir_path)  # type: ignore[no-untyped-call]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
 | 
					
 | 
				
			||||||
    print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
 | 
					def genCppH(
 | 
				
			||||||
        hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))
 | 
					    hFilePath: str,
 | 
				
			||||||
 | 
					    cppFilePath: str,
 | 
				
			||||||
 | 
					    srcDirPath: str,
 | 
				
			||||||
 | 
					    glslcPath: str,
 | 
				
			||||||
 | 
					    tmpDirPath: str,
 | 
				
			||||||
 | 
					    env: Dict[Any, Any],
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
 | 
					    print(
 | 
				
			||||||
 | 
					        "hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
 | 
				
			||||||
 | 
					            hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
 | 
					    vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
 | 
				
			||||||
    templateSrcPaths = []
 | 
					    templateSrcPaths = []
 | 
				
			||||||
@ -142,8 +263,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
 | 
				
			|||||||
        codeTemplate = CodeTemplate.from_file(templateSrcPath)
 | 
					        codeTemplate = CodeTemplate.from_file(templateSrcPath)
 | 
				
			||||||
        srcPath = tmpDirPath + "/" + name + ".glsl"
 | 
					        srcPath = tmpDirPath + "/" + name + ".glsl"
 | 
				
			||||||
        content = codeTemplate.substitute(env)
 | 
					        content = codeTemplate.substitute(env)
 | 
				
			||||||
        with open(srcPath, 'w') as f:
 | 
					        with open(srcPath, 'w') as fw:
 | 
				
			||||||
            f.write(content)
 | 
					            fw.write(content)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        spvPath = tmpDirPath + "/" + name + ".spv"
 | 
					        spvPath = tmpDirPath + "/" + name + ".spv"
 | 
				
			||||||
        print("spvPath {}".format(spvPath))
 | 
					        print("spvPath {}".format(spvPath))
 | 
				
			||||||
@ -188,8 +309,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
 | 
				
			|||||||
        name = getName(spvPath)
 | 
					        name = getName(spvPath)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        print("spvPath:{}".format(spvPath))
 | 
					        print("spvPath:{}".format(spvPath))
 | 
				
			||||||
        with open(spvPath, 'rb') as f:
 | 
					        with open(spvPath, 'rb') as fr:
 | 
				
			||||||
            next_bin = array.array('I', f.read())
 | 
					            next_bin = array.array('I', fr.read())
 | 
				
			||||||
            sizeBytes = 4 * len(next_bin)
 | 
					            sizeBytes = 4 * len(next_bin)
 | 
				
			||||||
            shader_info_bin_code.append(
 | 
					            shader_info_bin_code.append(
 | 
				
			||||||
                "const uint32_t {}_bin[] = {{\n  {}\n}};".format(
 | 
					                "const uint32_t {}_bin[] = {{\n  {}\n}};".format(
 | 
				
			||||||
@ -231,13 +352,13 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
 | 
				
			|||||||
    cpp += nsend
 | 
					    cpp += nsend
 | 
				
			||||||
    h += nsend
 | 
					    h += nsend
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with open(hFilePath, "w") as f:
 | 
					    with open(hFilePath, "w") as fw:
 | 
				
			||||||
        f.write(h)
 | 
					        fw.write(h)
 | 
				
			||||||
    with open(cppFilePath, "w") as f:
 | 
					    with open(cppFilePath, "w") as fw:
 | 
				
			||||||
        f.write(cpp)
 | 
					        fw.write(cpp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def parse_arg_env(items):
 | 
					def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
 | 
				
			||||||
    d = {}
 | 
					    d = {}
 | 
				
			||||||
    if items:
 | 
					    if items:
 | 
				
			||||||
        for item in items:
 | 
					        for item in items:
 | 
				
			||||||
@ -248,7 +369,7 @@ def parse_arg_env(items):
 | 
				
			|||||||
    return d
 | 
					    return d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main(argv):
 | 
					def main(argv: List[str]) -> int:
 | 
				
			||||||
    parser = argparse.ArgumentParser(description='')
 | 
					    parser = argparse.ArgumentParser(description='')
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        '-i',
 | 
					        '-i',
 | 
				
			||||||
@ -294,5 +415,7 @@ def main(argv):
 | 
				
			|||||||
        tmpDirPath=options.tmp_dir_path,
 | 
					        tmpDirPath=options.tmp_dir_path,
 | 
				
			||||||
        env=env)
 | 
					        env=env)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
    sys.exit(main(sys.argv))
 | 
					    sys.exit(main(sys.argv))
 | 
				
			||||||
 | 
				
			|||||||
@ -2,11 +2,11 @@ import os
 | 
				
			|||||||
import tempfile
 | 
					import tempfile
 | 
				
			||||||
import unittest
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from tools.gen_vulkan_glsl import GLSLGenerator
 | 
					from tools.gen_vulkan_spv import VulkanShaderGenerator
 | 
				
			||||||
from yaml.constructor import ConstructorError
 | 
					from yaml.constructor import ConstructorError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestGLSLCodegen(unittest.TestCase):
 | 
					class TestVulkanShaderCodegen(unittest.TestCase):
 | 
				
			||||||
    def test_assert_on_duplicate_key_yaml(self) -> None:
 | 
					    def test_assert_on_duplicate_key_yaml(self) -> None:
 | 
				
			||||||
        yaml_with_duplicate_keys = """
 | 
					        yaml_with_duplicate_keys = """
 | 
				
			||||||
conv2d_pw:
 | 
					conv2d_pw:
 | 
				
			||||||
@ -37,7 +37,7 @@ conv2d_pw:
 | 
				
			|||||||
      TILE_SIZE_Y: 4
 | 
					      TILE_SIZE_Y: 4
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
 | 
					        generator = VulkanShaderGenerator()  # type: ignore[no-untyped-call]
 | 
				
			||||||
        with tempfile.NamedTemporaryFile(mode="w") as fp:
 | 
					        with tempfile.NamedTemporaryFile(mode="w") as fp:
 | 
				
			||||||
            fp.write(yaml_with_duplicate_keys)
 | 
					            fp.write(yaml_with_duplicate_keys)
 | 
				
			||||||
            fp.flush()
 | 
					            fp.flush()
 | 
				
			||||||
@ -57,7 +57,7 @@ conv2d_pw:
 | 
				
			|||||||
      TILE_SIZE_Z: 2
 | 
					      TILE_SIZE_Z: 2
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
 | 
					        generator = VulkanShaderGenerator()  # type: ignore[no-untyped-call]
 | 
				
			||||||
        with tempfile.NamedTemporaryFile(mode="w") as fp:
 | 
					        with tempfile.NamedTemporaryFile(mode="w") as fp:
 | 
				
			||||||
            fp.write(yaml_with_key_mismatch)
 | 
					            fp.write(yaml_with_key_mismatch)
 | 
				
			||||||
            fp.flush()
 | 
					            fp.flush()
 | 
				
			||||||
@ -77,7 +77,7 @@ conv2d_pw:
 | 
				
			|||||||
x = $TILE_SIZE_X + $TILE_SIZE_Y
 | 
					x = $TILE_SIZE_X + $TILE_SIZE_Y
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
 | 
					        generator = VulkanShaderGenerator()  # type: ignore[no-untyped-call]
 | 
				
			||||||
        with tempfile.NamedTemporaryFile(mode="w") as fp:
 | 
					        with tempfile.NamedTemporaryFile(mode="w") as fp:
 | 
				
			||||||
            fp.write(yaml_with_key_mismatch)
 | 
					            fp.write(yaml_with_key_mismatch)
 | 
				
			||||||
            fp.flush()
 | 
					            fp.flush()
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user