diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py index 1c09980652e1..1d0f8b2c61fd 100644 --- a/tools/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -7,14 +7,16 @@ import glob import os import re import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import subprocess import textwrap -import yaml from collections import OrderedDict -from torchgen.code_template import CodeTemplate from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple + +import yaml +from torchgen.code_template import CodeTemplate from yaml.constructor import ConstructorError from yaml.nodes import MappingNode @@ -128,17 +130,21 @@ class ShaderInfo: bias_storage_type: str = "" register_for: Optional[Tuple[str, List[str]]] = None + def getName(filePath: str) -> str: return os.path.basename(filePath).replace("/", "_").replace(".", "_") + def isDescriptorLine(lineStr: str) -> bool: descriptorLineId = r"^layout\(set" return re.search(descriptorLineId, lineStr) is not None + def isTileSizeLine(lineStr: str) -> bool: tile_size_id = r"^ \* TILE_SIZE = \(" return re.search(tile_size_id, lineStr) is not None + def findTileSizes(lineStr: str) -> List[int]: tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)" matches = re.search(tile_size_id, lineStr) @@ -146,10 +152,12 @@ def findTileSizes(lineStr: str) -> List[int]: raise AssertionError("matches is None in findTileSizes") return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))] + def isWeightStorageTypeLine(lineStr: str) -> bool: weight_storage_id = r"^ \* WEIGHT_STORAGE = " return re.search(weight_storage_id, lineStr) is not None + def getWeightStorageType(lineStr: str) -> str: weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)" matches = re.search(weight_storage_id, lineStr) @@ -157,10 +165,12 @@ def getWeightStorageType(lineStr: str) -> str: raise AssertionError("matches is None in getWeightStorageType") return matches.group(1) + def isBiasStorageTypeLine(lineStr: str) -> bool: weight_storage_id = r"^ \* BIAS_STORAGE = " return re.search(weight_storage_id, lineStr) is not None + def getBiasStorageType(lineStr: str) -> str: weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)" matches = re.search(weight_storage_id, lineStr) @@ -168,11 +178,15 @@ def getBiasStorageType(lineStr: str) -> str: raise AssertionError("matches is None in getBiasStorageType") return matches.group(1) + def isRegisterForLine(lineStr: str) -> bool: # Check for Shader Name and a list of at least one Registry Key - register_for_id = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)" + register_for_id = ( + r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)" + ) return re.search(register_for_id, lineStr) is not None + def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: register_for_pattern = r"'([A-Za-z0-9_]+)'" matches = re.findall(register_for_pattern, lineStr) @@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: matches_list = list(matches) return (matches_list[0], matches_list[1:]) + typeIdMapping = { r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER", @@ -189,12 +204,13 @@ typeIdMapping = { } storageTypeToEnum = { - "TEXTURE_2D" : "api::StorageType::TEXTURE_2D", - "TEXTURE_3D" : "api::StorageType::TEXTURE_3D", - "BUFFER" : "api::StorageType::BUFFER", + "TEXTURE_2D": "api::StorageType::TEXTURE_2D", + "TEXTURE_3D": "api::StorageType::TEXTURE_3D", + "BUFFER": "api::StorageType::BUFFER", "": "api::StorageType::UNKNOWN", } + def determineDescriptorType(lineStr: str) -> str: for identifier, typeNum in typeIdMapping.items(): if re.search(identifier, lineStr): @@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str: "No matching descriptor type for " + lineStr + " in determineDescriptorType" ) + def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info = ShaderInfo([], [], "") with open(srcFilePath) as srcFile: @@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: return shader_info + def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None: template_dir_path = os.path.join(src_dir_path, "templates") - vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True) + vexs = glob.glob(os.path.join(template_dir_path, "**", "*.yaml"), recursive=True) parameter_yaml_files = [] for f in vexs: if len(f) > 1: @@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None: for params_yaml in parameter_yaml_files: generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call] - vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True) + vexs = glob.glob(os.path.join(src_dir_path, "**", "*.glslt"), recursive=True) templateSrcPaths = [] for f in vexs: if len(f) > 1: @@ -258,7 +276,7 @@ def genCppH( templateSrcPaths = [] for srcDirPath in srcDirPaths: - vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True) + vexs = glob.glob(os.path.join(srcDirPath, "**", "*.glsl"), recursive=True) for f in vexs: if len(f) > 1: templateSrcPaths.append(f) @@ -267,7 +285,7 @@ def genCppH( # Now add glsl files that are generated from templates genGLSLFromGLSLT(srcDirPath, tmpDirPath) - vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True) + vexs = glob.glob(os.path.join(tmpDirPath, "**", "*.glsl"), recursive=True) for f in vexs: if len(f) > 1: templateSrcPaths.append(f) @@ -283,17 +301,20 @@ def genCppH( codeTemplate = CodeTemplate.from_file(templateSrcPath) srcPath = tmpDirPath + "/" + name + ".glsl" content = codeTemplate.substitute(env) - with open(srcPath, 'w') as fw: + with open(srcPath, "w") as fw: fw.write(content) spvPath = tmpDirPath + "/" + name + ".spv" print(f"spvPath {spvPath}") cmd = [ - glslcPath, "-fshader-stage=compute", - srcPath, "-o", spvPath, + glslcPath, + "-fshader-stage=compute", + srcPath, + "-o", + spvPath, "--target-env=vulkan1.0", - "-Werror" + "-Werror", ] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]] print("\nglslc cmd:", cmd) @@ -323,7 +344,9 @@ def genCppH( h += "extern const ShaderListing shader_infos;\n" h += "extern ShaderRegistry shader_registry;\n" h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n" - h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n" + h += ( + "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n" + ) h += nsend @@ -341,8 +364,8 @@ def genCppH( name = getName(spvPath).replace("_spv", "") print(f"spvPath:{spvPath}") - with open(spvPath, 'rb') as fr: - next_bin = array.array('I', fr.read()) + with open(spvPath, "rb") as fr: + next_bin = array.array("I", fr.read()) sizeBytes = 4 * len(next_bin) shader_info_bin_code.append( "const uint32_t {}_bin[] = {{\n{}\n}};".format( @@ -362,7 +385,7 @@ def genCppH( shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts)) shader_info_args = [ - f"\"vulkan.{name}\"", + f'"vulkan.{name}"', f"{name}_bin", str(sizeBytes), shader_info_layouts, @@ -373,7 +396,7 @@ def genCppH( shader_info_cpp_code.append( textwrap.indent( - "{{\"{}\",\n api::ShaderInfo(\n{})}}".format( + '{{"{}",\n api::ShaderInfo(\n{})}}'.format( name, textwrap.indent(",\n".join(shader_info_args), " "), ), @@ -386,7 +409,7 @@ def genCppH( for registry_key in registry_keys: shader_info_registry_code.append( textwrap.indent( - f"{{\"{op_name}\", {{{{\"{registry_key}\", \"{name}\"}}}}}}", + f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}', " ", ), ) @@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]: def main(argv: List[str]) -> int: - parser = argparse.ArgumentParser(description='') + parser = argparse.ArgumentParser(description="") parser.add_argument( - '-i', - '--glsl-paths', - nargs='+', + "-i", + "--glsl-paths", + nargs="+", help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"', - default=['.'], + default=["."], ) + parser.add_argument("-c", "--glslc-path", required=True, help="") + parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp") + parser.add_argument("-o", "--output-path", required=True, help="") parser.add_argument( - '-c', - '--glslc-path', - required=True, - help='') - parser.add_argument( - '-t', - '--tmp-dir-path', - required=True, - help='/tmp') - parser.add_argument( - '-o', - '--output-path', - required=True, - help='') - parser.add_argument( - "--env", - metavar="KEY=VALUE", - nargs='*', - help="Set a number of key-value pairs") + "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs" + ) options = parser.parse_args() env = DEFAULT_ENV for key, value in parse_arg_env(options.env).items(): @@ -466,9 +475,15 @@ def main(argv: List[str]) -> int: srcDirPaths=options.glsl_paths, glslcPath=options.glslc_path, tmpDirPath=options.tmp_dir_path, - env=env) + env=env, + ) return 0 -if __name__ == '__main__': + +def invoke_main() -> None: sys.exit(main(sys.argv)) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/tools/substitute.py b/tools/substitute.py index e9c05990c75f..c988e78f97c3 100644 --- a/tools/substitute.py +++ b/tools/substitute.py @@ -3,7 +3,7 @@ import os import os.path -if __name__ == "__main__": +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--input-file") parser.add_argument("--output-file") @@ -22,3 +22,7 @@ if __name__ == "__main__": with open(output_file, "w") as f: f.write(contents) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py index 6acb7a075d1a..c7be90a4baee 100644 --- a/torch/utils/_freeze.py +++ b/torch/utils/_freeze.py @@ -26,10 +26,10 @@ import functools import itertools import marshal import os +import types from dataclasses import dataclass from pathlib import Path from typing import List -import types PATH_MARKER = "" @@ -121,10 +121,10 @@ class Freezer: Shared frozen modules evenly across the files. """ - bytecode_file_names = [ - f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES) + bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)] + bytecode_files = [ + open(os.path.join(install_root, name), "w") for name in bytecode_file_names ] - bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names] it = itertools.cycle(bytecode_files) for m in self.frozen_modules: self.write_frozen(m, next(it)) @@ -202,7 +202,6 @@ class Freezer: module_parent = normalized_path.parent.parts return list(module_parent) + [module_basename] - def compile_string(self, file_content: str) -> types.CodeType: # instead of passing in the real build time path to 'compile', we # pass in a marker instead. This prevents the build time path being @@ -239,19 +238,26 @@ class Freezer: bytecode = marshal.dumps(co) size = len(bytecode) - if path.name == '__init__.py': + if path.name == "__init__.py": # Python packages are signified by negative size. size = -size self.frozen_modules.append( FrozenModule(".".join(module_qualname), c_name, size, bytecode) ) -if __name__ == "__main__": + +def main() -> None: parser = argparse.ArgumentParser(description="Compile py source") parser.add_argument("paths", nargs="*", help="Paths to freeze.") parser.add_argument("--verbose", action="store_true", help="Print debug logs") - parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files") - parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules") + parser.add_argument( + "--install-dir", "--install_dir", help="Root directory for all output files" + ) + parser.add_argument( + "--oss", + action="store_true", + help="If it's OSS build, add a fake _PyImport_FrozenModules", + ) parser.add_argument( "--symbol-name", "--symbol_name", @@ -265,7 +271,7 @@ if __name__ == "__main__": for p in args.paths: path = Path(p) - if path.is_dir() and not Path.exists(path / '__init__.py'): + if path.is_dir() and not Path.exists(path / "__init__.py"): # this 'top level path p' is a standard directory containing modules, # not a module itself # each 'mod' could be a dir containing __init__.py or .py file @@ -277,3 +283,7 @@ if __name__ == "__main__": f.write_bytecode(args.install_dir) f.write_main(args.install_dir, args.oss, args.symbol_name) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/torch/utils/_zip.py b/torch/utils/_zip.py index 9514bb342b7e..f37ddb449878 100644 --- a/torch/utils/_zip.py +++ b/torch/utils/_zip.py @@ -25,11 +25,15 @@ DENY_LIST = [ "_bootstrap_external.py", ] +strip_file_dir = "" + + def remove_prefix(text, prefix): if text.startswith(prefix): - return text[len(prefix):] + return text[len(prefix) :] return text + def write_to_zip(file_path, strip_file_path, zf, prepend_str=""): stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/") path = Path(stripped_file_path) @@ -37,28 +41,45 @@ def write_to_zip(file_path, strip_file_path, zf, prepend_str=""): return zf.write(file_path, stripped_file_path) -if __name__ == "__main__": + +def main() -> None: + global strip_file_dir parser = argparse.ArgumentParser(description="Zip py source") parser.add_argument("paths", nargs="*", help="Paths to zip.") - parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files") - parser.add_argument("--strip-dir", "--strip_dir", help="The absolute directory we want to remove from zip") parser.add_argument( - "--prepend-str", "--prepend_str", help="A string to prepend onto all paths of a file in the zip", default="" + "--install-dir", "--install_dir", help="Root directory for all output files" + ) + parser.add_argument( + "--strip-dir", + "--strip_dir", + help="The absolute directory we want to remove from zip", + ) + parser.add_argument( + "--prepend-str", + "--prepend_str", + help="A string to prepend onto all paths of a file in the zip", + default="", ) parser.add_argument("--zip-name", "--zip_name", help="Output zip name") args = parser.parse_args() - zip_file_name = args.install_dir + '/' + args.zip_name + zip_file_name = args.install_dir + "/" + args.zip_name strip_file_dir = args.strip_dir prepend_str = args.prepend_str - zf = ZipFile(zip_file_name, mode='w') + zf = ZipFile(zip_file_name, mode="w") for p in sorted(args.paths): if os.path.isdir(p): files = glob.glob(p + "/**/*.py", recursive=True) for file_path in sorted(files): # strip the absolute path - write_to_zip(file_path, strip_file_dir + "/", zf, prepend_str=prepend_str) + write_to_zip( + file_path, strip_file_dir + "/", zf, prepend_str=prepend_str + ) else: write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str) + + +if __name__ == "__main__": + main() # pragma: no cover