mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Codemod][python/main_function] caffe2: (#113357)
Differential Revision: D51149464 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113357 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
87aeb248c9
commit
9b736c707c
@ -7,14 +7,16 @@ import glob
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
import subprocess
|
||||
import textwrap
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from yaml.constructor import ConstructorError
|
||||
from yaml.nodes import MappingNode
|
||||
|
||||
@ -128,17 +130,21 @@ class ShaderInfo:
|
||||
bias_storage_type: str = ""
|
||||
register_for: Optional[Tuple[str, List[str]]] = None
|
||||
|
||||
|
||||
def getName(filePath: str) -> str:
|
||||
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
|
||||
|
||||
|
||||
def isDescriptorLine(lineStr: str) -> bool:
|
||||
descriptorLineId = r"^layout\(set"
|
||||
return re.search(descriptorLineId, lineStr) is not None
|
||||
|
||||
|
||||
def isTileSizeLine(lineStr: str) -> bool:
|
||||
tile_size_id = r"^ \* TILE_SIZE = \("
|
||||
return re.search(tile_size_id, lineStr) is not None
|
||||
|
||||
|
||||
def findTileSizes(lineStr: str) -> List[int]:
|
||||
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
|
||||
matches = re.search(tile_size_id, lineStr)
|
||||
@ -146,10 +152,12 @@ def findTileSizes(lineStr: str) -> List[int]:
|
||||
raise AssertionError("matches is None in findTileSizes")
|
||||
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
|
||||
|
||||
|
||||
def isWeightStorageTypeLine(lineStr: str) -> bool:
|
||||
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
|
||||
return re.search(weight_storage_id, lineStr) is not None
|
||||
|
||||
|
||||
def getWeightStorageType(lineStr: str) -> str:
|
||||
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
|
||||
matches = re.search(weight_storage_id, lineStr)
|
||||
@ -157,10 +165,12 @@ def getWeightStorageType(lineStr: str) -> str:
|
||||
raise AssertionError("matches is None in getWeightStorageType")
|
||||
return matches.group(1)
|
||||
|
||||
|
||||
def isBiasStorageTypeLine(lineStr: str) -> bool:
|
||||
weight_storage_id = r"^ \* BIAS_STORAGE = "
|
||||
return re.search(weight_storage_id, lineStr) is not None
|
||||
|
||||
|
||||
def getBiasStorageType(lineStr: str) -> str:
|
||||
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
|
||||
matches = re.search(weight_storage_id, lineStr)
|
||||
@ -168,11 +178,15 @@ def getBiasStorageType(lineStr: str) -> str:
|
||||
raise AssertionError("matches is None in getBiasStorageType")
|
||||
return matches.group(1)
|
||||
|
||||
|
||||
def isRegisterForLine(lineStr: str) -> bool:
|
||||
# Check for Shader Name and a list of at least one Registry Key
|
||||
register_for_id = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
|
||||
register_for_id = (
|
||||
r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
|
||||
)
|
||||
return re.search(register_for_id, lineStr) is not None
|
||||
|
||||
|
||||
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
|
||||
register_for_pattern = r"'([A-Za-z0-9_]+)'"
|
||||
matches = re.findall(register_for_pattern, lineStr)
|
||||
@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
|
||||
matches_list = list(matches)
|
||||
return (matches_list[0], matches_list[1:])
|
||||
|
||||
|
||||
typeIdMapping = {
|
||||
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
|
||||
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
|
||||
@ -189,12 +204,13 @@ typeIdMapping = {
|
||||
}
|
||||
|
||||
storageTypeToEnum = {
|
||||
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
|
||||
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
|
||||
"BUFFER" : "api::StorageType::BUFFER",
|
||||
"TEXTURE_2D": "api::StorageType::TEXTURE_2D",
|
||||
"TEXTURE_3D": "api::StorageType::TEXTURE_3D",
|
||||
"BUFFER": "api::StorageType::BUFFER",
|
||||
"": "api::StorageType::UNKNOWN",
|
||||
}
|
||||
|
||||
|
||||
def determineDescriptorType(lineStr: str) -> str:
|
||||
for identifier, typeNum in typeIdMapping.items():
|
||||
if re.search(identifier, lineStr):
|
||||
@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str:
|
||||
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
|
||||
)
|
||||
|
||||
|
||||
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
||||
shader_info = ShaderInfo([], [], "")
|
||||
with open(srcFilePath) as srcFile:
|
||||
@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
||||
|
||||
return shader_info
|
||||
|
||||
|
||||
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
|
||||
template_dir_path = os.path.join(src_dir_path, "templates")
|
||||
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
|
||||
vexs = glob.glob(os.path.join(template_dir_path, "**", "*.yaml"), recursive=True)
|
||||
parameter_yaml_files = []
|
||||
for f in vexs:
|
||||
if len(f) > 1:
|
||||
@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
|
||||
for params_yaml in parameter_yaml_files:
|
||||
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
|
||||
|
||||
vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True)
|
||||
vexs = glob.glob(os.path.join(src_dir_path, "**", "*.glslt"), recursive=True)
|
||||
templateSrcPaths = []
|
||||
for f in vexs:
|
||||
if len(f) > 1:
|
||||
@ -258,7 +276,7 @@ def genCppH(
|
||||
templateSrcPaths = []
|
||||
|
||||
for srcDirPath in srcDirPaths:
|
||||
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
|
||||
vexs = glob.glob(os.path.join(srcDirPath, "**", "*.glsl"), recursive=True)
|
||||
for f in vexs:
|
||||
if len(f) > 1:
|
||||
templateSrcPaths.append(f)
|
||||
@ -267,7 +285,7 @@ def genCppH(
|
||||
# Now add glsl files that are generated from templates
|
||||
genGLSLFromGLSLT(srcDirPath, tmpDirPath)
|
||||
|
||||
vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
|
||||
vexs = glob.glob(os.path.join(tmpDirPath, "**", "*.glsl"), recursive=True)
|
||||
for f in vexs:
|
||||
if len(f) > 1:
|
||||
templateSrcPaths.append(f)
|
||||
@ -283,17 +301,20 @@ def genCppH(
|
||||
codeTemplate = CodeTemplate.from_file(templateSrcPath)
|
||||
srcPath = tmpDirPath + "/" + name + ".glsl"
|
||||
content = codeTemplate.substitute(env)
|
||||
with open(srcPath, 'w') as fw:
|
||||
with open(srcPath, "w") as fw:
|
||||
fw.write(content)
|
||||
|
||||
spvPath = tmpDirPath + "/" + name + ".spv"
|
||||
print(f"spvPath {spvPath}")
|
||||
|
||||
cmd = [
|
||||
glslcPath, "-fshader-stage=compute",
|
||||
srcPath, "-o", spvPath,
|
||||
glslcPath,
|
||||
"-fshader-stage=compute",
|
||||
srcPath,
|
||||
"-o",
|
||||
spvPath,
|
||||
"--target-env=vulkan1.0",
|
||||
"-Werror"
|
||||
"-Werror",
|
||||
] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]
|
||||
|
||||
print("\nglslc cmd:", cmd)
|
||||
@ -323,7 +344,9 @@ def genCppH(
|
||||
h += "extern const ShaderListing shader_infos;\n"
|
||||
h += "extern ShaderRegistry shader_registry;\n"
|
||||
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
|
||||
h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
|
||||
h += (
|
||||
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
|
||||
)
|
||||
|
||||
h += nsend
|
||||
|
||||
@ -341,8 +364,8 @@ def genCppH(
|
||||
name = getName(spvPath).replace("_spv", "")
|
||||
|
||||
print(f"spvPath:{spvPath}")
|
||||
with open(spvPath, 'rb') as fr:
|
||||
next_bin = array.array('I', fr.read())
|
||||
with open(spvPath, "rb") as fr:
|
||||
next_bin = array.array("I", fr.read())
|
||||
sizeBytes = 4 * len(next_bin)
|
||||
shader_info_bin_code.append(
|
||||
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
|
||||
@ -362,7 +385,7 @@ def genCppH(
|
||||
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
|
||||
|
||||
shader_info_args = [
|
||||
f"\"vulkan.{name}\"",
|
||||
f'"vulkan.{name}"',
|
||||
f"{name}_bin",
|
||||
str(sizeBytes),
|
||||
shader_info_layouts,
|
||||
@ -373,7 +396,7 @@ def genCppH(
|
||||
|
||||
shader_info_cpp_code.append(
|
||||
textwrap.indent(
|
||||
"{{\"{}\",\n api::ShaderInfo(\n{})}}".format(
|
||||
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
|
||||
name,
|
||||
textwrap.indent(",\n".join(shader_info_args), " "),
|
||||
),
|
||||
@ -386,7 +409,7 @@ def genCppH(
|
||||
for registry_key in registry_keys:
|
||||
shader_info_registry_code.append(
|
||||
textwrap.indent(
|
||||
f"{{\"{op_name}\", {{{{\"{registry_key}\", \"{name}\"}}}}}}",
|
||||
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
|
||||
" ",
|
||||
),
|
||||
)
|
||||
@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
|
||||
|
||||
def main(argv: List[str]) -> int:
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser = argparse.ArgumentParser(description="")
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--glsl-paths',
|
||||
nargs='+',
|
||||
"-i",
|
||||
"--glsl-paths",
|
||||
nargs="+",
|
||||
help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
|
||||
default=['.'],
|
||||
default=["."],
|
||||
)
|
||||
parser.add_argument("-c", "--glslc-path", required=True, help="")
|
||||
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
|
||||
parser.add_argument("-o", "--output-path", required=True, help="")
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--glslc-path',
|
||||
required=True,
|
||||
help='')
|
||||
parser.add_argument(
|
||||
'-t',
|
||||
'--tmp-dir-path',
|
||||
required=True,
|
||||
help='/tmp')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--output-path',
|
||||
required=True,
|
||||
help='')
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
metavar="KEY=VALUE",
|
||||
nargs='*',
|
||||
help="Set a number of key-value pairs")
|
||||
"--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
|
||||
)
|
||||
options = parser.parse_args()
|
||||
env = DEFAULT_ENV
|
||||
for key, value in parse_arg_env(options.env).items():
|
||||
@ -466,9 +475,15 @@ def main(argv: List[str]) -> int:
|
||||
srcDirPaths=options.glsl_paths,
|
||||
glslcPath=options.glslc_path,
|
||||
tmpDirPath=options.tmp_dir_path,
|
||||
env=env)
|
||||
env=env,
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def invoke_main() -> None:
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_main() # pragma: no cover
|
||||
|
@ -3,7 +3,7 @@ import os
|
||||
import os.path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input-file")
|
||||
parser.add_argument("--output-file")
|
||||
@ -22,3 +22,7 @@ if __name__ == "__main__":
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
f.write(contents)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pragma: no cover
|
||||
|
@ -26,10 +26,10 @@ import functools
|
||||
import itertools
|
||||
import marshal
|
||||
import os
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
import types
|
||||
|
||||
|
||||
PATH_MARKER = "<Generated by torch::deploy>"
|
||||
@ -121,10 +121,10 @@ class Freezer:
|
||||
|
||||
Shared frozen modules evenly across the files.
|
||||
"""
|
||||
bytecode_file_names = [
|
||||
f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)
|
||||
bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)]
|
||||
bytecode_files = [
|
||||
open(os.path.join(install_root, name), "w") for name in bytecode_file_names
|
||||
]
|
||||
bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names]
|
||||
it = itertools.cycle(bytecode_files)
|
||||
for m in self.frozen_modules:
|
||||
self.write_frozen(m, next(it))
|
||||
@ -202,7 +202,6 @@ class Freezer:
|
||||
module_parent = normalized_path.parent.parts
|
||||
return list(module_parent) + [module_basename]
|
||||
|
||||
|
||||
def compile_string(self, file_content: str) -> types.CodeType:
|
||||
# instead of passing in the real build time path to 'compile', we
|
||||
# pass in a marker instead. This prevents the build time path being
|
||||
@ -239,19 +238,26 @@ class Freezer:
|
||||
|
||||
bytecode = marshal.dumps(co)
|
||||
size = len(bytecode)
|
||||
if path.name == '__init__.py':
|
||||
if path.name == "__init__.py":
|
||||
# Python packages are signified by negative size.
|
||||
size = -size
|
||||
self.frozen_modules.append(
|
||||
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Compile py source")
|
||||
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
|
||||
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
|
||||
parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")
|
||||
parser.add_argument(
|
||||
"--install-dir", "--install_dir", help="Root directory for all output files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--oss",
|
||||
action="store_true",
|
||||
help="If it's OSS build, add a fake _PyImport_FrozenModules",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--symbol-name",
|
||||
"--symbol_name",
|
||||
@ -265,7 +271,7 @@ if __name__ == "__main__":
|
||||
|
||||
for p in args.paths:
|
||||
path = Path(p)
|
||||
if path.is_dir() and not Path.exists(path / '__init__.py'):
|
||||
if path.is_dir() and not Path.exists(path / "__init__.py"):
|
||||
# this 'top level path p' is a standard directory containing modules,
|
||||
# not a module itself
|
||||
# each 'mod' could be a dir containing __init__.py or .py file
|
||||
@ -277,3 +283,7 @@ if __name__ == "__main__":
|
||||
|
||||
f.write_bytecode(args.install_dir)
|
||||
f.write_main(args.install_dir, args.oss, args.symbol_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pragma: no cover
|
||||
|
@ -25,11 +25,15 @@ DENY_LIST = [
|
||||
"_bootstrap_external.py",
|
||||
]
|
||||
|
||||
strip_file_dir = ""
|
||||
|
||||
|
||||
def remove_prefix(text, prefix):
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix):]
|
||||
return text[len(prefix) :]
|
||||
return text
|
||||
|
||||
|
||||
def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
|
||||
stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/")
|
||||
path = Path(stripped_file_path)
|
||||
@ -37,28 +41,45 @@ def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
|
||||
return
|
||||
zf.write(file_path, stripped_file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main() -> None:
|
||||
global strip_file_dir
|
||||
parser = argparse.ArgumentParser(description="Zip py source")
|
||||
parser.add_argument("paths", nargs="*", help="Paths to zip.")
|
||||
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
|
||||
parser.add_argument("--strip-dir", "--strip_dir", help="The absolute directory we want to remove from zip")
|
||||
parser.add_argument(
|
||||
"--prepend-str", "--prepend_str", help="A string to prepend onto all paths of a file in the zip", default=""
|
||||
"--install-dir", "--install_dir", help="Root directory for all output files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strip-dir",
|
||||
"--strip_dir",
|
||||
help="The absolute directory we want to remove from zip",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prepend-str",
|
||||
"--prepend_str",
|
||||
help="A string to prepend onto all paths of a file in the zip",
|
||||
default="",
|
||||
)
|
||||
parser.add_argument("--zip-name", "--zip_name", help="Output zip name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
zip_file_name = args.install_dir + '/' + args.zip_name
|
||||
zip_file_name = args.install_dir + "/" + args.zip_name
|
||||
strip_file_dir = args.strip_dir
|
||||
prepend_str = args.prepend_str
|
||||
zf = ZipFile(zip_file_name, mode='w')
|
||||
zf = ZipFile(zip_file_name, mode="w")
|
||||
|
||||
for p in sorted(args.paths):
|
||||
if os.path.isdir(p):
|
||||
files = glob.glob(p + "/**/*.py", recursive=True)
|
||||
for file_path in sorted(files):
|
||||
# strip the absolute path
|
||||
write_to_zip(file_path, strip_file_dir + "/", zf, prepend_str=prepend_str)
|
||||
write_to_zip(
|
||||
file_path, strip_file_dir + "/", zf, prepend_str=prepend_str
|
||||
)
|
||||
else:
|
||||
write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pragma: no cover
|
||||
|
Reference in New Issue
Block a user