[Codemod][python/main_function] caffe2: (#113357)

Differential Revision: D51149464

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113357
Approved by: https://github.com/huydhn
This commit is contained in:
Zsolt Dollenstein
2023-11-15 22:17:28 +00:00
committed by PyTorch MergeBot
parent 87aeb248c9
commit 9b736c707c
4 changed files with 116 additions and 66 deletions

View File

@ -7,14 +7,16 @@ import glob
import os
import re
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import textwrap
import yaml
from collections import OrderedDict
from torchgen.code_template import CodeTemplate
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Optional, Tuple
import yaml
from torchgen.code_template import CodeTemplate
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
@ -128,17 +130,21 @@ class ShaderInfo:
bias_storage_type: str = ""
register_for: Optional[Tuple[str, List[str]]] = None
def getName(filePath: str) -> str:
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
def isDescriptorLine(lineStr: str) -> bool:
descriptorLineId = r"^layout\(set"
return re.search(descriptorLineId, lineStr) is not None
def isTileSizeLine(lineStr: str) -> bool:
tile_size_id = r"^ \* TILE_SIZE = \("
return re.search(tile_size_id, lineStr) is not None
def findTileSizes(lineStr: str) -> List[int]:
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
matches = re.search(tile_size_id, lineStr)
@ -146,10 +152,12 @@ def findTileSizes(lineStr: str) -> List[int]:
raise AssertionError("matches is None in findTileSizes")
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
def isWeightStorageTypeLine(lineStr: str) -> bool:
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
return re.search(weight_storage_id, lineStr) is not None
def getWeightStorageType(lineStr: str) -> str:
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
matches = re.search(weight_storage_id, lineStr)
@ -157,10 +165,12 @@ def getWeightStorageType(lineStr: str) -> str:
raise AssertionError("matches is None in getWeightStorageType")
return matches.group(1)
def isBiasStorageTypeLine(lineStr: str) -> bool:
weight_storage_id = r"^ \* BIAS_STORAGE = "
return re.search(weight_storage_id, lineStr) is not None
def getBiasStorageType(lineStr: str) -> str:
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
matches = re.search(weight_storage_id, lineStr)
@ -168,11 +178,15 @@ def getBiasStorageType(lineStr: str) -> str:
raise AssertionError("matches is None in getBiasStorageType")
return matches.group(1)
def isRegisterForLine(lineStr: str) -> bool:
# Check for Shader Name and a list of at least one Registry Key
register_for_id = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
register_for_id = (
r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
)
return re.search(register_for_id, lineStr) is not None
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
register_for_pattern = r"'([A-Za-z0-9_]+)'"
matches = re.findall(register_for_pattern, lineStr)
@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
matches_list = list(matches)
return (matches_list[0], matches_list[1:])
typeIdMapping = {
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
@ -189,12 +204,13 @@ typeIdMapping = {
}
storageTypeToEnum = {
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
"BUFFER" : "api::StorageType::BUFFER",
"TEXTURE_2D": "api::StorageType::TEXTURE_2D",
"TEXTURE_3D": "api::StorageType::TEXTURE_3D",
"BUFFER": "api::StorageType::BUFFER",
"": "api::StorageType::UNKNOWN",
}
def determineDescriptorType(lineStr: str) -> str:
for identifier, typeNum in typeIdMapping.items():
if re.search(identifier, lineStr):
@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str:
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
)
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info = ShaderInfo([], [], "")
with open(srcFilePath) as srcFile:
@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
return shader_info
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
template_dir_path = os.path.join(src_dir_path, "templates")
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
vexs = glob.glob(os.path.join(template_dir_path, "**", "*.yaml"), recursive=True)
parameter_yaml_files = []
for f in vexs:
if len(f) > 1:
@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
for params_yaml in parameter_yaml_files:
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True)
vexs = glob.glob(os.path.join(src_dir_path, "**", "*.glslt"), recursive=True)
templateSrcPaths = []
for f in vexs:
if len(f) > 1:
@ -258,7 +276,7 @@ def genCppH(
templateSrcPaths = []
for srcDirPath in srcDirPaths:
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
vexs = glob.glob(os.path.join(srcDirPath, "**", "*.glsl"), recursive=True)
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
@ -267,7 +285,7 @@ def genCppH(
# Now add glsl files that are generated from templates
genGLSLFromGLSLT(srcDirPath, tmpDirPath)
vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
vexs = glob.glob(os.path.join(tmpDirPath, "**", "*.glsl"), recursive=True)
for f in vexs:
if len(f) > 1:
templateSrcPaths.append(f)
@ -283,17 +301,20 @@ def genCppH(
codeTemplate = CodeTemplate.from_file(templateSrcPath)
srcPath = tmpDirPath + "/" + name + ".glsl"
content = codeTemplate.substitute(env)
with open(srcPath, 'w') as fw:
with open(srcPath, "w") as fw:
fw.write(content)
spvPath = tmpDirPath + "/" + name + ".spv"
print(f"spvPath {spvPath}")
cmd = [
glslcPath, "-fshader-stage=compute",
srcPath, "-o", spvPath,
glslcPath,
"-fshader-stage=compute",
srcPath,
"-o",
spvPath,
"--target-env=vulkan1.0",
"-Werror"
"-Werror",
] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]
print("\nglslc cmd:", cmd)
@ -323,7 +344,9 @@ def genCppH(
h += "extern const ShaderListing shader_infos;\n"
h += "extern ShaderRegistry shader_registry;\n"
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
h += (
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
)
h += nsend
@ -341,8 +364,8 @@ def genCppH(
name = getName(spvPath).replace("_spv", "")
print(f"spvPath:{spvPath}")
with open(spvPath, 'rb') as fr:
next_bin = array.array('I', fr.read())
with open(spvPath, "rb") as fr:
next_bin = array.array("I", fr.read())
sizeBytes = 4 * len(next_bin)
shader_info_bin_code.append(
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
@ -362,7 +385,7 @@ def genCppH(
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
shader_info_args = [
f"\"vulkan.{name}\"",
f'"vulkan.{name}"',
f"{name}_bin",
str(sizeBytes),
shader_info_layouts,
@ -373,7 +396,7 @@ def genCppH(
shader_info_cpp_code.append(
textwrap.indent(
"{{\"{}\",\n api::ShaderInfo(\n{})}}".format(
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
name,
textwrap.indent(",\n".join(shader_info_args), " "),
),
@ -386,7 +409,7 @@ def genCppH(
for registry_key in registry_keys:
shader_info_registry_code.append(
textwrap.indent(
f"{{\"{op_name}\", {{{{\"{registry_key}\", \"{name}\"}}}}}}",
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
" ",
),
)
@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
def main(argv: List[str]) -> int:
parser = argparse.ArgumentParser(description='')
parser = argparse.ArgumentParser(description="")
parser.add_argument(
'-i',
'--glsl-paths',
nargs='+',
"-i",
"--glsl-paths",
nargs="+",
help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
default=['.'],
default=["."],
)
parser.add_argument("-c", "--glslc-path", required=True, help="")
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
parser.add_argument("-o", "--output-path", required=True, help="")
parser.add_argument(
'-c',
'--glslc-path',
required=True,
help='')
parser.add_argument(
'-t',
'--tmp-dir-path',
required=True,
help='/tmp')
parser.add_argument(
'-o',
'--output-path',
required=True,
help='')
parser.add_argument(
"--env",
metavar="KEY=VALUE",
nargs='*',
help="Set a number of key-value pairs")
"--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
)
options = parser.parse_args()
env = DEFAULT_ENV
for key, value in parse_arg_env(options.env).items():
@ -466,9 +475,15 @@ def main(argv: List[str]) -> int:
srcDirPaths=options.glsl_paths,
glslcPath=options.glslc_path,
tmpDirPath=options.tmp_dir_path,
env=env)
env=env,
)
return 0
if __name__ == '__main__':
def invoke_main() -> None:
sys.exit(main(sys.argv))
if __name__ == "__main__":
invoke_main() # pragma: no cover

View File

@ -3,7 +3,7 @@ import os
import os.path
if __name__ == "__main__":
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--input-file")
parser.add_argument("--output-file")
@ -22,3 +22,7 @@ if __name__ == "__main__":
with open(output_file, "w") as f:
f.write(contents)
if __name__ == "__main__":
main() # pragma: no cover

View File

@ -26,10 +26,10 @@ import functools
import itertools
import marshal
import os
import types
from dataclasses import dataclass
from pathlib import Path
from typing import List
import types
PATH_MARKER = "<Generated by torch::deploy>"
@ -121,10 +121,10 @@ class Freezer:
Shared frozen modules evenly across the files.
"""
bytecode_file_names = [
f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)
bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)]
bytecode_files = [
open(os.path.join(install_root, name), "w") for name in bytecode_file_names
]
bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names]
it = itertools.cycle(bytecode_files)
for m in self.frozen_modules:
self.write_frozen(m, next(it))
@ -202,7 +202,6 @@ class Freezer:
module_parent = normalized_path.parent.parts
return list(module_parent) + [module_basename]
def compile_string(self, file_content: str) -> types.CodeType:
# instead of passing in the real build time path to 'compile', we
# pass in a marker instead. This prevents the build time path being
@ -239,19 +238,26 @@ class Freezer:
bytecode = marshal.dumps(co)
size = len(bytecode)
if path.name == '__init__.py':
if path.name == "__init__.py":
# Python packages are signified by negative size.
size = -size
self.frozen_modules.append(
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
)
if __name__ == "__main__":
def main() -> None:
parser = argparse.ArgumentParser(description="Compile py source")
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")
parser.add_argument(
"--install-dir", "--install_dir", help="Root directory for all output files"
)
parser.add_argument(
"--oss",
action="store_true",
help="If it's OSS build, add a fake _PyImport_FrozenModules",
)
parser.add_argument(
"--symbol-name",
"--symbol_name",
@ -265,7 +271,7 @@ if __name__ == "__main__":
for p in args.paths:
path = Path(p)
if path.is_dir() and not Path.exists(path / '__init__.py'):
if path.is_dir() and not Path.exists(path / "__init__.py"):
# this 'top level path p' is a standard directory containing modules,
# not a module itself
# each 'mod' could be a dir containing __init__.py or .py file
@ -277,3 +283,7 @@ if __name__ == "__main__":
f.write_bytecode(args.install_dir)
f.write_main(args.install_dir, args.oss, args.symbol_name)
if __name__ == "__main__":
main() # pragma: no cover

View File

@ -25,11 +25,15 @@ DENY_LIST = [
"_bootstrap_external.py",
]
strip_file_dir = ""
def remove_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix):]
return text[len(prefix) :]
return text
def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/")
path = Path(stripped_file_path)
@ -37,28 +41,45 @@ def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
return
zf.write(file_path, stripped_file_path)
if __name__ == "__main__":
def main() -> None:
global strip_file_dir
parser = argparse.ArgumentParser(description="Zip py source")
parser.add_argument("paths", nargs="*", help="Paths to zip.")
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
parser.add_argument("--strip-dir", "--strip_dir", help="The absolute directory we want to remove from zip")
parser.add_argument(
"--prepend-str", "--prepend_str", help="A string to prepend onto all paths of a file in the zip", default=""
"--install-dir", "--install_dir", help="Root directory for all output files"
)
parser.add_argument(
"--strip-dir",
"--strip_dir",
help="The absolute directory we want to remove from zip",
)
parser.add_argument(
"--prepend-str",
"--prepend_str",
help="A string to prepend onto all paths of a file in the zip",
default="",
)
parser.add_argument("--zip-name", "--zip_name", help="Output zip name")
args = parser.parse_args()
zip_file_name = args.install_dir + '/' + args.zip_name
zip_file_name = args.install_dir + "/" + args.zip_name
strip_file_dir = args.strip_dir
prepend_str = args.prepend_str
zf = ZipFile(zip_file_name, mode='w')
zf = ZipFile(zip_file_name, mode="w")
for p in sorted(args.paths):
if os.path.isdir(p):
files = glob.glob(p + "/**/*.py", recursive=True)
for file_path in sorted(files):
# strip the absolute path
write_to_zip(file_path, strip_file_dir + "/", zf, prepend_str=prepend_str)
write_to_zip(
file_path, strip_file_dir + "/", zf, prepend_str=prepend_str
)
else:
write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str)
if __name__ == "__main__":
main() # pragma: no cover