mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Apply parts of pyupgrade to torch (starting with the safest changes). This PR only does two things: removes the need to inherit from object and removes unused future imports. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308 Approved by: https://github.com/ezyang, https://github.com/albanD
479 lines
17 KiB
Python
479 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import array
|
|
import copy
|
|
import glob
|
|
import os
|
|
import re
|
|
import sys
|
|
import subprocess
|
|
import textwrap
|
|
import yaml
|
|
from collections import OrderedDict
|
|
from torchgen.code_template import CodeTemplate
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Tuple, Optional
|
|
from yaml.constructor import ConstructorError
|
|
from yaml.nodes import MappingNode
|
|
|
|
try:
|
|
from yaml import CLoader as Loader
|
|
except ImportError:
|
|
from yaml import Loader # type: ignore[misc]
|
|
|
|
# File names of the generated C++ header and source. H_NAME is also the name
# the generated .cpp uses in its #include line.
H_NAME = "spv.h"
CPP_NAME = "spv.cpp"
# Default template substitutions; main() layers the --env KEY=VALUE pairs on top.
DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
|
|
|
|
# https://gist.github.com/pypt/94d747fe5180851196eb
|
|
class UniqueKeyLoader(Loader):
    """YAML loader whose mappings reject unhashable and repeated keys."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Turn the mapping ``node`` into a dict.

        Raises:
            ConstructorError: if ``node`` is not a ``MappingNode``, if a key
                cannot be hashed, or if the same key occurs more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                "expected a mapping node, but found %s" % node.id,
                node.start_mark,
            )

        result = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]

            # Only hashable objects can serve as dict keys.
            try:
                hash(key)
            except TypeError as err:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found unacceptable key ",
                    key_node.start_mark,
                ) from err

            # A repeated key is reported before its value is constructed.
            if key in result:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found duplicate key",
                    key_node.start_mark,
                )

            result[key] = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]

        return result
|
|
|
|
|
|
class VulkanShaderGenerator:
    """Expand ``.glslt`` shader templates into concrete ``.glsl`` files.

    Parameter sets are registered per op via ``add_params_yaml`` (or directly
    via ``validate_and_construct_op_params``); ``generate`` then writes one
    ``.glsl`` file per registered parameter set of the template's op.
    """

    # Prepended verbatim to every generated shader. ``generate`` does not
    # substitute the $precision / $format placeholders in this header.
    standard_header = """
#version 450 core
#define PRECISION $precision
#define FORMAT $format

"""

    def __init__(self: "VulkanShaderGenerator") -> None:
        # op name -> [default parameter dict, variant dict, variant dict, ...]
        self.ops_template_params: Dict[Any, Any] = {}

    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
        """Load op parameter sets from ``parameters_yaml_file`` and register them.

        Raises:
            KeyError: propagated from ``validate_and_construct_op_params``.
            yaml ``ConstructorError``: if the file contains duplicate keys.
        """
        all_template_params = OrderedDict()
        with open(parameters_yaml_file, "r") as f:
            # NOTE(review): UniqueKeyLoader derives from yaml's full Loader,
            # which can construct arbitrary Python objects; only feed it
            # trusted, in-repo parameter files.
            # An empty YAML document loads as None; treat it as "no ops".
            contents = yaml.load(f, Loader=UniqueKeyLoader) or {}
        for key in contents:
            all_template_params[key] = contents[key]
        self.validate_and_construct_op_params(all_template_params)  # type: ignore[no-untyped-call]

    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
        """Validate and store the parameter sets of every op in ``all_template_params``.

        Each op entry must provide ``parameter_names_with_default_values`` (a
        dict) and ``parameter_values`` (a list of dicts overriding some of the
        defaults). The stored list starts with the defaults dict itself,
        followed by one merged copy per variant.

        Raises:
            KeyError: if an op was already registered, or a variant uses a key
                that is not among the default parameter names. An op that
                fails validation is not registered.
        """
        for op in all_template_params:
            if op in self.ops_template_params:
                raise KeyError(f"{op} params file has already been parsed")
            op_params_default_vals = all_template_params[op][
                "parameter_names_with_default_values"
            ]
            template_params_set = set(op_params_default_vals.keys())

            # Build the list locally so a validation failure leaves no
            # half-populated entry behind in self.ops_template_params.
            op_param_sets = [op_params_default_vals]
            for param_vals in all_template_params[op]["parameter_values"]:
                invalid_keys = set(param_vals.keys()) - template_params_set
                if invalid_keys:
                    raise KeyError(f"Invalid keys {invalid_keys} are found")
                param_vals_copy = copy.deepcopy(op_params_default_vals)
                param_vals_copy.update(param_vals)
                op_param_sets.append(param_vals_copy)
            self.ops_template_params[op] = op_param_sets

    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
        """Write one ``.glsl`` file into ``out_dir`` per parameter set of the template.

        The op name is the template's file name without its final ``.glslt``
        extension. Output files are named ``<op>_<values joined by "x">.glsl``,
        where the value of a ``REGISTER_FOR`` parameter is left out of the name.

        Raises:
            TypeError: if the file extension is not ``glslt``.
            KeyError: if no parameters were registered for the op.
        """
        glsl_template_name = os.path.basename(glsl_template_in)
        # Split on the last dot only, so op names may themselves contain dots.
        op_name, _, extension_name = glsl_template_name.rpartition(".")
        if extension_name != "glslt":
            raise TypeError(f"invalid file type for glsl template {extension_name}")
        if op_name not in self.ops_template_params:
            raise KeyError(f"{op_name} params have not been populated")

        code_template = CodeTemplate.from_file(glsl_template_in)
        for template_params in self.ops_template_params[op_name]:
            param_vals_string = "x".join(
                str(value)
                for name, value in template_params.items()
                if name != "REGISTER_FOR"
            )
            output_file_name = op_name + "_" + param_vals_string + ".glsl"
            content = VulkanShaderGenerator.standard_header
            content += code_template.substitute(template_params)
            output_file = os.path.join(out_dir, output_file_name)
            with open(output_file, "w") as f:
                f.write(content)
|
|
|
|
|
|
@dataclass
class ShaderInfo:
    """Metadata that ``getShaderInfo`` collects from one GLSL source file."""

    # Three integers parsed from a " * TILE_SIZE = (x, y, z)" line; stays empty
    # when the file has no such line.
    tile_size: List[int]
    # Vulkan descriptor type names, one per "layout(set" line, in file order.
    layouts: List[str]
    # Storage type named on the WEIGHT_STORAGE line; "" when there is none.
    weight_storage_type: str = ""
    # Storage type named on the BIAS_STORAGE line; "" when there is none.
    bias_storage_type: str = ""
    # (first quoted name, remaining quoted names) from the REGISTER_FOR line,
    # or None when the file has no such line.
    register_for: Optional[Tuple[str, List[str]]] = None
|
|
|
|
def getName(filePath: str) -> str:
    """Return the base name of ``filePath`` with each "/" and "." replaced by "_"."""
    base_name = os.path.basename(filePath)
    return base_name.translate(str.maketrans("/.", "__"))
|
|
|
|
def isDescriptorLine(lineStr: str) -> bool:
    """Tell whether ``lineStr`` starts with ``layout(set``."""
    return bool(re.search(r"^layout\(set", lineStr))
|
|
|
|
def isTileSizeLine(lineStr: str) -> bool:
    """Tell whether ``lineStr`` starts with the TILE_SIZE comment marker."""
    return bool(re.search(r"^ \* TILE_SIZE = \(", lineStr))
|
|
|
|
def findTileSizes(lineStr: str) -> List[int]:
    """Parse the three integers of a TILE_SIZE line.

    Raises:
        AssertionError: if ``lineStr`` does not have the expected form.
    """
    found = re.search(r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)", lineStr)
    if found is None:
        raise AssertionError("matches is None in findTileSizes")
    return [int(component) for component in found.groups()]
|
|
|
|
def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Tell whether ``lineStr`` starts with the WEIGHT_STORAGE comment marker."""
    return bool(re.search(r"^ \* WEIGHT_STORAGE = ", lineStr))
|
|
|
|
def getWeightStorageType(lineStr: str) -> str:
    """Extract the storage type named on a WEIGHT_STORAGE line.

    Raises:
        AssertionError: if no ``<letters>_<digit>D`` name follows the marker.
    """
    # NOTE(review): this pattern cannot match the "BUFFER" key that
    # storageTypeToEnum also lists — confirm whether that is intended.
    found = re.search(r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)", lineStr)
    if found is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return found.group(1)
|
|
|
|
def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Tell whether ``lineStr`` starts with the BIAS_STORAGE comment marker."""
    return bool(re.search(r"^ \* BIAS_STORAGE = ", lineStr))
|
|
|
|
def getBiasStorageType(lineStr: str) -> str:
    """Extract the storage type named on a BIAS_STORAGE line.

    Raises:
        AssertionError: if no ``<letters>_<digit>D`` name follows the marker.
    """
    found = re.search(r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)", lineStr)
    if found is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return found.group(1)
|
|
|
|
def isRegisterForLine(lineStr: str) -> bool:
    """Tell whether ``lineStr`` is a REGISTER_FOR line.

    Such a line carries a quoted shader name followed by a bracketed list
    holding at least one quoted registry key.
    """
    found = re.search(
        r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)",
        lineStr,
    )
    return bool(found)
|
|
|
|
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
    """Split the quoted names of a REGISTER_FOR line into (first, rest).

    Every ``'name'`` token (letters, digits, underscores) in ``lineStr`` is
    collected in order of appearance.

    Returns:
        A tuple of the first quoted name and the list of all following ones.

    Raises:
        AssertionError: if ``lineStr`` contains no quoted name at all.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    matches = re.findall(register_for_pattern, lineStr)
    # re.findall returns an empty list (never None) when nothing matches.
    if not matches:
        raise AssertionError("no quoted names found in findRegisterFor")
    return (matches[0], matches[1:])
|
|
|
|
# Regex -> Vulkan descriptor type name. determineDescriptorType walks this
# dict in insertion order and returns the value of the first pattern found in
# a line, so the order of the entries matters.
typeIdMapping = {
    r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
    r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
    r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
    r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
}
|
|
|
|
# Storage type name found in a shader's header comment -> C++ enum expression
# written into the generated source. The "" key covers shaders that declare
# no storage type (ShaderInfo's default).
storageTypeToEnum = {
    "TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
    "TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
    "BUFFER" : "api::StorageType::BUFFER",
    "": "api::StorageType::UNKNOWN",
}
|
|
|
|
def determineDescriptorType(lineStr: str) -> str:
    """Return the descriptor type of the first ``typeIdMapping`` pattern found in ``lineStr``.

    Raises:
        AssertionError: if none of the patterns occurs in ``lineStr``.
    """
    descriptor_type = next(
        (
            type_name
            for pattern, type_name in typeIdMapping.items()
            if re.search(pattern, lineStr)
        ),
        None,
    )
    if descriptor_type is None:
        raise AssertionError(
            "No matching descriptor type for " + lineStr + " in determineDescriptorType"
        )
    return descriptor_type
|
|
|
|
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan the shader source at ``srcFilePath`` and collect its metadata.

    Every line is tested against each recognised marker independently:
    descriptor lines are appended to ``layouts``, while the tile size,
    storage types and REGISTER_FOR entry are overwritten by the latest
    matching line.
    """
    info = ShaderInfo(tile_size=[], layouts=[], weight_storage_type="")
    with open(srcFilePath, 'r') as shader_source:
        for src_line in shader_source:
            if isDescriptorLine(src_line):
                info.layouts.append(determineDescriptorType(src_line))
            if isTileSizeLine(src_line):
                info.tile_size = findTileSizes(src_line)
            if isWeightStorageTypeLine(src_line):
                info.weight_storage_type = getWeightStorageType(src_line)
            if isBiasStorageTypeLine(src_line):
                info.bias_storage_type = getBiasStorageType(src_line)
            if isRegisterForLine(src_line):
                info.register_for = findRegisterFor(src_line)
    return info
|
|
|
|
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
    """Expand every ``.glslt`` template under ``src_dir_path`` into ``tmp_dir_path``.

    Parameter YAML files are searched recursively below
    ``<src_dir_path>/templates``; templates are searched recursively below
    ``src_dir_path`` itself and processed in sorted path order.
    """
    yaml_glob = os.path.join(src_dir_path, "templates", '**', '*.yaml')
    yaml_paths = [p for p in glob.glob(yaml_glob, recursive=True) if len(p) > 1]

    shader_generator = VulkanShaderGenerator()
    for yaml_path in yaml_paths:
        shader_generator.add_params_yaml(yaml_path)  # type: ignore[no-untyped-call]

    glslt_glob = os.path.join(src_dir_path, '**', '*.glslt')
    glslt_paths = sorted(
        p for p in glob.glob(glslt_glob, recursive=True) if len(p) > 1
    )
    for glslt_path in glslt_paths:
        shader_generator.generate(glslt_path, tmp_dir_path)  # type: ignore[no-untyped-call]
|
|
|
|
|
|
def genCppH(
    hFilePath: str,
    cppFilePath: str,
    srcDirPaths: List[str],
    glslcPath: str,
    tmpDirPath: str,
    env: Dict[Any, Any],
) -> None:
    """Compile all shaders and write the C++ header/source that embed them.

    Steps:
      1. Collect ``*.glsl`` files below every directory in ``srcDirPaths`` and
         expand that directory's ``.glslt`` templates into ``tmpDirPath``.
      2. Collect the ``*.glsl`` files below ``tmpDirPath`` as well.
      3. For each collected file, substitute ``env`` into it, write the result
         to ``tmpDirPath`` and compile it to SPIR-V with ``glslcPath``.
      4. Emit the header to ``hFilePath`` and the source (binary arrays,
         shader info table, shader registry) to ``cppFilePath``.

    Args:
        hFilePath: Destination of the generated header.
        cppFilePath: Destination of the generated source file.
        srcDirPaths: Directories searched for shaders; also passed to the
            compiler as ``-I`` include directories.
        glslcPath: Path of the ``glslc`` executable.
        tmpDirPath: Directory for intermediate ``.glsl`` and ``.spv`` files.
        env: Substitutions applied to every shader source before compiling.

    Raises:
        subprocess.CalledProcessError: if the compiler exits with an error.
        KeyError: if a shader names a storage type missing from
            ``storageTypeToEnum``.
    """
    print(
        "hFilePath:{} cppFilePath:{} srcDirPaths:{} glslcPath:{} tmpDirPath:{}".format(
            hFilePath, cppFilePath, srcDirPaths, glslcPath, tmpDirPath
        )
    )

    templateSrcPaths = []

    for srcDirPath in srcDirPaths:
        vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
        templateSrcPaths.extend(f for f in vexs if len(f) > 1)

        # Now add glsl files that are generated from templates
        genGLSLFromGLSLT(srcDirPath, tmpDirPath)

    vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
    templateSrcPaths.extend(f for f in vexs if len(f) > 1)

    # One sort after collecting everything yields the same order as
    # re-sorting after each append.
    templateSrcPaths.sort()
    print("templateSrcPaths:{}".format(templateSrcPaths))

    # Maps each compiled .spv path to the shader source it was built from.
    spvPaths = {}
    for templateSrcPath in templateSrcPaths:
        print("templateSrcPath {}".format(templateSrcPath))
        name = getName(templateSrcPath).replace("_glsl", "")
        print("name {}".format(name))

        codeTemplate = CodeTemplate.from_file(templateSrcPath)
        srcPath = tmpDirPath + "/" + name + ".glsl"
        content = codeTemplate.substitute(env)
        with open(srcPath, 'w') as fw:
            fw.write(content)

        spvPath = tmpDirPath + "/" + name + ".spv"
        print("spvPath {}".format(spvPath))

        cmd = [
            glslcPath, "-fshader-stage=compute",
            srcPath, "-o", spvPath,
            "--target-env=vulkan1.0",
            "-Werror"
        ] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]

        print("\nglslc cmd:", cmd)

        subprocess.check_call(cmd)
        spvPaths[spvPath] = templateSrcPath

    h = "#pragma once\n"
    h += "#include <ATen/native/vulkan/api/Types.h>\n"
    h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
    h += "#include <c10/util/flat_hash_map.h>\n"
    h += "#include <string>\n"

    nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
    nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"

    anon_ns_begin = "namespace {\n"
    anon_ns_end = "} // namespace\n"

    h += nsbegin

    # Forward declaration of ShaderInfo
    h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
    h += "typedef ska::flat_hash_map<std::string, api::ShaderInfo> ShaderListing;\n"
    h += "typedef ska::flat_hash_map<std::string, std::string> RegistryKeyMap;\n"
    h += "typedef ska::flat_hash_map<std::string, RegistryKeyMap> ShaderRegistry;\n"
    h += "extern const ShaderListing shader_infos;\n"
    h += "extern ShaderRegistry shader_registry;\n"
    h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
    h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"

    h += nsend

    cpp = "#include <ATen/native/vulkan/api/Shader.h>\n"
    cpp += "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
    cpp += "#include <stdint.h>\n"
    cpp += "#include <vector>\n"
    cpp += nsbegin

    shader_info_bin_code = []
    shader_info_cpp_code = []
    shader_info_registry_code = []

    for spvPath, srcPath in spvPaths.items():
        name = getName(spvPath).replace("_spv", "")

        print("spvPath:{}".format(spvPath))
        with open(spvPath, 'rb') as fr:
            # NOTE(review): typecode 'I' is assumed to be 4 bytes wide, which
            # the size computation below and the uint32_t array rely on.
            next_bin = array.array('I', fr.read())
            sizeBytes = 4 * len(next_bin)
            shader_info_bin_code.append(
                "const uint32_t {}_bin[] = {{\n{}\n}};".format(
                    name,
                    textwrap.indent(",\n".join(str(x) for x in next_bin), " "),
                ),
            )

        shader_info = getShaderInfo(srcPath)

        tile_size = (
            "{{{}}}".format(", ".join(str(x) for x in shader_info.tile_size))
            if (len(shader_info.tile_size) > 0)
            else "std::vector<uint32_t>()"
        )

        shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))

        # Constructor arguments of api::ShaderInfo, in positional order.
        shader_info_args = [
            "\"vulkan.{}\"".format(name),
            "{}_bin".format(name),
            str(sizeBytes),
            shader_info_layouts,
            tile_size,
            storageTypeToEnum[shader_info.weight_storage_type],
            storageTypeToEnum[shader_info.bias_storage_type],
        ]

        shader_info_cpp_code.append(
            textwrap.indent(
                "{{\"{}\",\n api::ShaderInfo(\n{})}}".format(
                    name,
                    textwrap.indent(",\n".join(shader_info_args), " "),
                ),
                " ",
            ),
        )

        if shader_info.register_for is not None:
            (op_name, registry_keys) = shader_info.register_for
            # One registry entry per key: {op_name, {{registry_key, shader name}}}.
            for registry_key in registry_keys:
                shader_info_registry_code.append(
                    textwrap.indent(
                        "{{\"{}\", {{{{\"{}\", \"{}\"}}}}}}".format(
                            op_name,
                            registry_key,
                            name,
                        ),
                        " ",
                    ),
                )

    cpp += anon_ns_begin
    cpp += "\n".join(shader_info_bin_code) + "\n"
    cpp += anon_ns_end

    cpp += "const ShaderListing shader_infos = {{\n{}}};\n".format(
        ",\n".join(shader_info_cpp_code),
    )
    cpp += "ShaderRegistry shader_registry = {{\n{}}};\n".format(
        ",\n".join(shader_info_registry_code),
    )
    cpp += nsend

    with open(hFilePath, "w") as fw:
        fw.write(h)
    with open(cppFilePath, "w") as fw:
        fw.write(cpp)
|
|
|
|
|
|
def parse_arg_env(items: Optional[List[str]]) -> Dict[Any, Any]:
    """Turn ``KEY=VALUE`` strings into a dict.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``. Keys and values are stripped of surrounding whitespace;
    a later occurrence of a key overrides an earlier one.

    Args:
        items: The strings to parse, or ``None``/empty for no entries.

    Returns:
        A new dict mapping each key to its value.

    Raises:
        ValueError: if an item contains no ``=``.
    """
    d: Dict[Any, Any] = {}
    for item in items or []:
        key, separator, value = item.partition("=")
        if not separator:
            raise ValueError(f"expected KEY=VALUE but got {item!r}")
        d[key.strip()] = value.strip()
    return d
|
|
|
|
|
|
def main(argv: List[str]) -> int:
    """Parse the command line and generate the shader header/source files.

    Args:
        argv: Full argument vector including the program name at index 0
            (as in ``sys.argv``); everything after it is parsed.

    Returns:
        0 on success. Argument errors exit via argparse; failures in the
        generation step propagate as exceptions.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument(
        '-i',
        '--glsl-paths',
        nargs='+',
        help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
        default=['.'],
    )
    parser.add_argument(
        '-c',
        '--glslc-path',
        required=True,
        help='')
    parser.add_argument(
        '-t',
        '--tmp-dir-path',
        required=True,
        help='/tmp')
    parser.add_argument(
        '-o',
        '--output-path',
        required=True,
        help='')
    parser.add_argument(
        "--env",
        metavar="KEY=VALUE",
        nargs='*',
        help="Set a number of key-value pairs")
    # Parse the vector we were handed instead of implicitly reading sys.argv.
    options = parser.parse_args(argv[1:])

    # Copy the defaults so --env overrides never leak into DEFAULT_ENV.
    env = dict(DEFAULT_ENV)
    env.update(parse_arg_env(options.env))

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    genCppH(
        hFilePath=options.output_path + "/" + H_NAME,
        cppFilePath=options.output_path + "/" + CPP_NAME,
        srcDirPaths=options.glsl_paths,
        glslcPath=options.glslc_path,
        tmpDirPath=options.tmp_dir_path,
        env=env)

    return 0
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main(sys.argv))
|