mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Codemod][python/main_function] caffe2: (#113357)
Differential Revision: D51149464 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113357 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
87aeb248c9
commit
9b736c707c
@ -7,14 +7,16 @@ import glob
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
import subprocess
|
import subprocess
|
||||||
import textwrap
|
import textwrap
|
||||||
import yaml
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torchgen.code_template import CodeTemplate
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from torchgen.code_template import CodeTemplate
|
||||||
from yaml.constructor import ConstructorError
|
from yaml.constructor import ConstructorError
|
||||||
from yaml.nodes import MappingNode
|
from yaml.nodes import MappingNode
|
||||||
|
|
||||||
@ -128,17 +130,21 @@ class ShaderInfo:
|
|||||||
bias_storage_type: str = ""
|
bias_storage_type: str = ""
|
||||||
register_for: Optional[Tuple[str, List[str]]] = None
|
register_for: Optional[Tuple[str, List[str]]] = None
|
||||||
|
|
||||||
|
|
||||||
def getName(filePath: str) -> str:
|
def getName(filePath: str) -> str:
|
||||||
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
|
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
|
||||||
|
|
||||||
|
|
||||||
def isDescriptorLine(lineStr: str) -> bool:
|
def isDescriptorLine(lineStr: str) -> bool:
|
||||||
descriptorLineId = r"^layout\(set"
|
descriptorLineId = r"^layout\(set"
|
||||||
return re.search(descriptorLineId, lineStr) is not None
|
return re.search(descriptorLineId, lineStr) is not None
|
||||||
|
|
||||||
|
|
||||||
def isTileSizeLine(lineStr: str) -> bool:
|
def isTileSizeLine(lineStr: str) -> bool:
|
||||||
tile_size_id = r"^ \* TILE_SIZE = \("
|
tile_size_id = r"^ \* TILE_SIZE = \("
|
||||||
return re.search(tile_size_id, lineStr) is not None
|
return re.search(tile_size_id, lineStr) is not None
|
||||||
|
|
||||||
|
|
||||||
def findTileSizes(lineStr: str) -> List[int]:
|
def findTileSizes(lineStr: str) -> List[int]:
|
||||||
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
|
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
|
||||||
matches = re.search(tile_size_id, lineStr)
|
matches = re.search(tile_size_id, lineStr)
|
||||||
@ -146,10 +152,12 @@ def findTileSizes(lineStr: str) -> List[int]:
|
|||||||
raise AssertionError("matches is None in findTileSizes")
|
raise AssertionError("matches is None in findTileSizes")
|
||||||
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
|
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
|
||||||
|
|
||||||
|
|
||||||
def isWeightStorageTypeLine(lineStr: str) -> bool:
|
def isWeightStorageTypeLine(lineStr: str) -> bool:
|
||||||
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
|
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
|
||||||
return re.search(weight_storage_id, lineStr) is not None
|
return re.search(weight_storage_id, lineStr) is not None
|
||||||
|
|
||||||
|
|
||||||
def getWeightStorageType(lineStr: str) -> str:
|
def getWeightStorageType(lineStr: str) -> str:
|
||||||
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
|
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
|
||||||
matches = re.search(weight_storage_id, lineStr)
|
matches = re.search(weight_storage_id, lineStr)
|
||||||
@ -157,10 +165,12 @@ def getWeightStorageType(lineStr: str) -> str:
|
|||||||
raise AssertionError("matches is None in getWeightStorageType")
|
raise AssertionError("matches is None in getWeightStorageType")
|
||||||
return matches.group(1)
|
return matches.group(1)
|
||||||
|
|
||||||
|
|
||||||
def isBiasStorageTypeLine(lineStr: str) -> bool:
|
def isBiasStorageTypeLine(lineStr: str) -> bool:
|
||||||
weight_storage_id = r"^ \* BIAS_STORAGE = "
|
weight_storage_id = r"^ \* BIAS_STORAGE = "
|
||||||
return re.search(weight_storage_id, lineStr) is not None
|
return re.search(weight_storage_id, lineStr) is not None
|
||||||
|
|
||||||
|
|
||||||
def getBiasStorageType(lineStr: str) -> str:
|
def getBiasStorageType(lineStr: str) -> str:
|
||||||
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
|
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
|
||||||
matches = re.search(weight_storage_id, lineStr)
|
matches = re.search(weight_storage_id, lineStr)
|
||||||
@ -168,11 +178,15 @@ def getBiasStorageType(lineStr: str) -> str:
|
|||||||
raise AssertionError("matches is None in getBiasStorageType")
|
raise AssertionError("matches is None in getBiasStorageType")
|
||||||
return matches.group(1)
|
return matches.group(1)
|
||||||
|
|
||||||
|
|
||||||
def isRegisterForLine(lineStr: str) -> bool:
|
def isRegisterForLine(lineStr: str) -> bool:
|
||||||
# Check for Shader Name and a list of at least one Registry Key
|
# Check for Shader Name and a list of at least one Registry Key
|
||||||
register_for_id = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
|
register_for_id = (
|
||||||
|
r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
|
||||||
|
)
|
||||||
return re.search(register_for_id, lineStr) is not None
|
return re.search(register_for_id, lineStr) is not None
|
||||||
|
|
||||||
|
|
||||||
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
|
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
|
||||||
register_for_pattern = r"'([A-Za-z0-9_]+)'"
|
register_for_pattern = r"'([A-Za-z0-9_]+)'"
|
||||||
matches = re.findall(register_for_pattern, lineStr)
|
matches = re.findall(register_for_pattern, lineStr)
|
||||||
@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
|
|||||||
matches_list = list(matches)
|
matches_list = list(matches)
|
||||||
return (matches_list[0], matches_list[1:])
|
return (matches_list[0], matches_list[1:])
|
||||||
|
|
||||||
|
|
||||||
typeIdMapping = {
|
typeIdMapping = {
|
||||||
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
|
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
|
||||||
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
|
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
|
||||||
@ -189,12 +204,13 @@ typeIdMapping = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
storageTypeToEnum = {
|
storageTypeToEnum = {
|
||||||
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
|
"TEXTURE_2D": "api::StorageType::TEXTURE_2D",
|
||||||
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
|
"TEXTURE_3D": "api::StorageType::TEXTURE_3D",
|
||||||
"BUFFER" : "api::StorageType::BUFFER",
|
"BUFFER": "api::StorageType::BUFFER",
|
||||||
"": "api::StorageType::UNKNOWN",
|
"": "api::StorageType::UNKNOWN",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def determineDescriptorType(lineStr: str) -> str:
|
def determineDescriptorType(lineStr: str) -> str:
|
||||||
for identifier, typeNum in typeIdMapping.items():
|
for identifier, typeNum in typeIdMapping.items():
|
||||||
if re.search(identifier, lineStr):
|
if re.search(identifier, lineStr):
|
||||||
@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str:
|
|||||||
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
|
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
||||||
shader_info = ShaderInfo([], [], "")
|
shader_info = ShaderInfo([], [], "")
|
||||||
with open(srcFilePath) as srcFile:
|
with open(srcFilePath) as srcFile:
|
||||||
@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
|||||||
|
|
||||||
return shader_info
|
return shader_info
|
||||||
|
|
||||||
|
|
||||||
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
|
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
|
||||||
template_dir_path = os.path.join(src_dir_path, "templates")
|
template_dir_path = os.path.join(src_dir_path, "templates")
|
||||||
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
|
vexs = glob.glob(os.path.join(template_dir_path, "**", "*.yaml"), recursive=True)
|
||||||
parameter_yaml_files = []
|
parameter_yaml_files = []
|
||||||
for f in vexs:
|
for f in vexs:
|
||||||
if len(f) > 1:
|
if len(f) > 1:
|
||||||
@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
|
|||||||
for params_yaml in parameter_yaml_files:
|
for params_yaml in parameter_yaml_files:
|
||||||
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
|
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True)
|
vexs = glob.glob(os.path.join(src_dir_path, "**", "*.glslt"), recursive=True)
|
||||||
templateSrcPaths = []
|
templateSrcPaths = []
|
||||||
for f in vexs:
|
for f in vexs:
|
||||||
if len(f) > 1:
|
if len(f) > 1:
|
||||||
@ -258,7 +276,7 @@ def genCppH(
|
|||||||
templateSrcPaths = []
|
templateSrcPaths = []
|
||||||
|
|
||||||
for srcDirPath in srcDirPaths:
|
for srcDirPath in srcDirPaths:
|
||||||
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
|
vexs = glob.glob(os.path.join(srcDirPath, "**", "*.glsl"), recursive=True)
|
||||||
for f in vexs:
|
for f in vexs:
|
||||||
if len(f) > 1:
|
if len(f) > 1:
|
||||||
templateSrcPaths.append(f)
|
templateSrcPaths.append(f)
|
||||||
@ -267,7 +285,7 @@ def genCppH(
|
|||||||
# Now add glsl files that are generated from templates
|
# Now add glsl files that are generated from templates
|
||||||
genGLSLFromGLSLT(srcDirPath, tmpDirPath)
|
genGLSLFromGLSLT(srcDirPath, tmpDirPath)
|
||||||
|
|
||||||
vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
|
vexs = glob.glob(os.path.join(tmpDirPath, "**", "*.glsl"), recursive=True)
|
||||||
for f in vexs:
|
for f in vexs:
|
||||||
if len(f) > 1:
|
if len(f) > 1:
|
||||||
templateSrcPaths.append(f)
|
templateSrcPaths.append(f)
|
||||||
@ -283,17 +301,20 @@ def genCppH(
|
|||||||
codeTemplate = CodeTemplate.from_file(templateSrcPath)
|
codeTemplate = CodeTemplate.from_file(templateSrcPath)
|
||||||
srcPath = tmpDirPath + "/" + name + ".glsl"
|
srcPath = tmpDirPath + "/" + name + ".glsl"
|
||||||
content = codeTemplate.substitute(env)
|
content = codeTemplate.substitute(env)
|
||||||
with open(srcPath, 'w') as fw:
|
with open(srcPath, "w") as fw:
|
||||||
fw.write(content)
|
fw.write(content)
|
||||||
|
|
||||||
spvPath = tmpDirPath + "/" + name + ".spv"
|
spvPath = tmpDirPath + "/" + name + ".spv"
|
||||||
print(f"spvPath {spvPath}")
|
print(f"spvPath {spvPath}")
|
||||||
|
|
||||||
cmd = [
|
cmd = [
|
||||||
glslcPath, "-fshader-stage=compute",
|
glslcPath,
|
||||||
srcPath, "-o", spvPath,
|
"-fshader-stage=compute",
|
||||||
|
srcPath,
|
||||||
|
"-o",
|
||||||
|
spvPath,
|
||||||
"--target-env=vulkan1.0",
|
"--target-env=vulkan1.0",
|
||||||
"-Werror"
|
"-Werror",
|
||||||
] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]
|
] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]
|
||||||
|
|
||||||
print("\nglslc cmd:", cmd)
|
print("\nglslc cmd:", cmd)
|
||||||
@ -323,7 +344,9 @@ def genCppH(
|
|||||||
h += "extern const ShaderListing shader_infos;\n"
|
h += "extern const ShaderListing shader_infos;\n"
|
||||||
h += "extern ShaderRegistry shader_registry;\n"
|
h += "extern ShaderRegistry shader_registry;\n"
|
||||||
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
|
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
|
||||||
h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
|
h += (
|
||||||
|
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
|
||||||
|
)
|
||||||
|
|
||||||
h += nsend
|
h += nsend
|
||||||
|
|
||||||
@ -341,8 +364,8 @@ def genCppH(
|
|||||||
name = getName(spvPath).replace("_spv", "")
|
name = getName(spvPath).replace("_spv", "")
|
||||||
|
|
||||||
print(f"spvPath:{spvPath}")
|
print(f"spvPath:{spvPath}")
|
||||||
with open(spvPath, 'rb') as fr:
|
with open(spvPath, "rb") as fr:
|
||||||
next_bin = array.array('I', fr.read())
|
next_bin = array.array("I", fr.read())
|
||||||
sizeBytes = 4 * len(next_bin)
|
sizeBytes = 4 * len(next_bin)
|
||||||
shader_info_bin_code.append(
|
shader_info_bin_code.append(
|
||||||
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
|
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
|
||||||
@ -362,7 +385,7 @@ def genCppH(
|
|||||||
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
|
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
|
||||||
|
|
||||||
shader_info_args = [
|
shader_info_args = [
|
||||||
f"\"vulkan.{name}\"",
|
f'"vulkan.{name}"',
|
||||||
f"{name}_bin",
|
f"{name}_bin",
|
||||||
str(sizeBytes),
|
str(sizeBytes),
|
||||||
shader_info_layouts,
|
shader_info_layouts,
|
||||||
@ -373,7 +396,7 @@ def genCppH(
|
|||||||
|
|
||||||
shader_info_cpp_code.append(
|
shader_info_cpp_code.append(
|
||||||
textwrap.indent(
|
textwrap.indent(
|
||||||
"{{\"{}\",\n api::ShaderInfo(\n{})}}".format(
|
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
|
||||||
name,
|
name,
|
||||||
textwrap.indent(",\n".join(shader_info_args), " "),
|
textwrap.indent(",\n".join(shader_info_args), " "),
|
||||||
),
|
),
|
||||||
@ -386,7 +409,7 @@ def genCppH(
|
|||||||
for registry_key in registry_keys:
|
for registry_key in registry_keys:
|
||||||
shader_info_registry_code.append(
|
shader_info_registry_code.append(
|
||||||
textwrap.indent(
|
textwrap.indent(
|
||||||
f"{{\"{op_name}\", {{{{\"{registry_key}\", \"{name}\"}}}}}}",
|
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
|
||||||
" ",
|
" ",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def main(argv: List[str]) -> int:
|
def main(argv: List[str]) -> int:
|
||||||
parser = argparse.ArgumentParser(description='')
|
parser = argparse.ArgumentParser(description="")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-i',
|
"-i",
|
||||||
'--glsl-paths',
|
"--glsl-paths",
|
||||||
nargs='+',
|
nargs="+",
|
||||||
help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
|
help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
|
||||||
default=['.'],
|
default=["."],
|
||||||
)
|
)
|
||||||
|
parser.add_argument("-c", "--glslc-path", required=True, help="")
|
||||||
|
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
|
||||||
|
parser.add_argument("-o", "--output-path", required=True, help="")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c',
|
"--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
|
||||||
'--glslc-path',
|
)
|
||||||
required=True,
|
|
||||||
help='')
|
|
||||||
parser.add_argument(
|
|
||||||
'-t',
|
|
||||||
'--tmp-dir-path',
|
|
||||||
required=True,
|
|
||||||
help='/tmp')
|
|
||||||
parser.add_argument(
|
|
||||||
'-o',
|
|
||||||
'--output-path',
|
|
||||||
required=True,
|
|
||||||
help='')
|
|
||||||
parser.add_argument(
|
|
||||||
"--env",
|
|
||||||
metavar="KEY=VALUE",
|
|
||||||
nargs='*',
|
|
||||||
help="Set a number of key-value pairs")
|
|
||||||
options = parser.parse_args()
|
options = parser.parse_args()
|
||||||
env = DEFAULT_ENV
|
env = DEFAULT_ENV
|
||||||
for key, value in parse_arg_env(options.env).items():
|
for key, value in parse_arg_env(options.env).items():
|
||||||
@ -466,9 +475,15 @@ def main(argv: List[str]) -> int:
|
|||||||
srcDirPaths=options.glsl_paths,
|
srcDirPaths=options.glsl_paths,
|
||||||
glslcPath=options.glslc_path,
|
glslcPath=options.glslc_path,
|
||||||
tmpDirPath=options.tmp_dir_path,
|
tmpDirPath=options.tmp_dir_path,
|
||||||
env=env)
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
def invoke_main() -> None:
|
||||||
sys.exit(main(sys.argv))
|
sys.exit(main(sys.argv))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
invoke_main() # pragma: no cover
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--input-file")
|
parser.add_argument("--input-file")
|
||||||
parser.add_argument("--output-file")
|
parser.add_argument("--output-file")
|
||||||
@ -22,3 +22,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
f.write(contents)
|
f.write(contents)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main() # pragma: no cover
|
||||||
|
@ -26,10 +26,10 @@ import functools
|
|||||||
import itertools
|
import itertools
|
||||||
import marshal
|
import marshal
|
||||||
import os
|
import os
|
||||||
|
import types
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
import types
|
|
||||||
|
|
||||||
|
|
||||||
PATH_MARKER = "<Generated by torch::deploy>"
|
PATH_MARKER = "<Generated by torch::deploy>"
|
||||||
@ -121,10 +121,10 @@ class Freezer:
|
|||||||
|
|
||||||
Shared frozen modules evenly across the files.
|
Shared frozen modules evenly across the files.
|
||||||
"""
|
"""
|
||||||
bytecode_file_names = [
|
bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)]
|
||||||
f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)
|
bytecode_files = [
|
||||||
|
open(os.path.join(install_root, name), "w") for name in bytecode_file_names
|
||||||
]
|
]
|
||||||
bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names]
|
|
||||||
it = itertools.cycle(bytecode_files)
|
it = itertools.cycle(bytecode_files)
|
||||||
for m in self.frozen_modules:
|
for m in self.frozen_modules:
|
||||||
self.write_frozen(m, next(it))
|
self.write_frozen(m, next(it))
|
||||||
@ -202,7 +202,6 @@ class Freezer:
|
|||||||
module_parent = normalized_path.parent.parts
|
module_parent = normalized_path.parent.parts
|
||||||
return list(module_parent) + [module_basename]
|
return list(module_parent) + [module_basename]
|
||||||
|
|
||||||
|
|
||||||
def compile_string(self, file_content: str) -> types.CodeType:
|
def compile_string(self, file_content: str) -> types.CodeType:
|
||||||
# instead of passing in the real build time path to 'compile', we
|
# instead of passing in the real build time path to 'compile', we
|
||||||
# pass in a marker instead. This prevents the build time path being
|
# pass in a marker instead. This prevents the build time path being
|
||||||
@ -239,19 +238,26 @@ class Freezer:
|
|||||||
|
|
||||||
bytecode = marshal.dumps(co)
|
bytecode = marshal.dumps(co)
|
||||||
size = len(bytecode)
|
size = len(bytecode)
|
||||||
if path.name == '__init__.py':
|
if path.name == "__init__.py":
|
||||||
# Python packages are signified by negative size.
|
# Python packages are signified by negative size.
|
||||||
size = -size
|
size = -size
|
||||||
self.frozen_modules.append(
|
self.frozen_modules.append(
|
||||||
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
|
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description="Compile py source")
|
parser = argparse.ArgumentParser(description="Compile py source")
|
||||||
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
|
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
|
||||||
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
|
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
|
||||||
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
|
parser.add_argument(
|
||||||
parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")
|
"--install-dir", "--install_dir", help="Root directory for all output files"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--oss",
|
||||||
|
action="store_true",
|
||||||
|
help="If it's OSS build, add a fake _PyImport_FrozenModules",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--symbol-name",
|
"--symbol-name",
|
||||||
"--symbol_name",
|
"--symbol_name",
|
||||||
@ -265,7 +271,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
for p in args.paths:
|
for p in args.paths:
|
||||||
path = Path(p)
|
path = Path(p)
|
||||||
if path.is_dir() and not Path.exists(path / '__init__.py'):
|
if path.is_dir() and not Path.exists(path / "__init__.py"):
|
||||||
# this 'top level path p' is a standard directory containing modules,
|
# this 'top level path p' is a standard directory containing modules,
|
||||||
# not a module itself
|
# not a module itself
|
||||||
# each 'mod' could be a dir containing __init__.py or .py file
|
# each 'mod' could be a dir containing __init__.py or .py file
|
||||||
@ -277,3 +283,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
f.write_bytecode(args.install_dir)
|
f.write_bytecode(args.install_dir)
|
||||||
f.write_main(args.install_dir, args.oss, args.symbol_name)
|
f.write_main(args.install_dir, args.oss, args.symbol_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main() # pragma: no cover
|
||||||
|
@ -25,11 +25,15 @@ DENY_LIST = [
|
|||||||
"_bootstrap_external.py",
|
"_bootstrap_external.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
strip_file_dir = ""
|
||||||
|
|
||||||
|
|
||||||
def remove_prefix(text, prefix):
|
def remove_prefix(text, prefix):
|
||||||
if text.startswith(prefix):
|
if text.startswith(prefix):
|
||||||
return text[len(prefix):]
|
return text[len(prefix) :]
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
|
def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
|
||||||
stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/")
|
stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/")
|
||||||
path = Path(stripped_file_path)
|
path = Path(stripped_file_path)
|
||||||
@ -37,28 +41,45 @@ def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
|
|||||||
return
|
return
|
||||||
zf.write(file_path, stripped_file_path)
|
zf.write(file_path, stripped_file_path)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
def main() -> None:
|
||||||
|
global strip_file_dir
|
||||||
parser = argparse.ArgumentParser(description="Zip py source")
|
parser = argparse.ArgumentParser(description="Zip py source")
|
||||||
parser.add_argument("paths", nargs="*", help="Paths to zip.")
|
parser.add_argument("paths", nargs="*", help="Paths to zip.")
|
||||||
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
|
|
||||||
parser.add_argument("--strip-dir", "--strip_dir", help="The absolute directory we want to remove from zip")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prepend-str", "--prepend_str", help="A string to prepend onto all paths of a file in the zip", default=""
|
"--install-dir", "--install_dir", help="Root directory for all output files"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--strip-dir",
|
||||||
|
"--strip_dir",
|
||||||
|
help="The absolute directory we want to remove from zip",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prepend-str",
|
||||||
|
"--prepend_str",
|
||||||
|
help="A string to prepend onto all paths of a file in the zip",
|
||||||
|
default="",
|
||||||
)
|
)
|
||||||
parser.add_argument("--zip-name", "--zip_name", help="Output zip name")
|
parser.add_argument("--zip-name", "--zip_name", help="Output zip name")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
zip_file_name = args.install_dir + '/' + args.zip_name
|
zip_file_name = args.install_dir + "/" + args.zip_name
|
||||||
strip_file_dir = args.strip_dir
|
strip_file_dir = args.strip_dir
|
||||||
prepend_str = args.prepend_str
|
prepend_str = args.prepend_str
|
||||||
zf = ZipFile(zip_file_name, mode='w')
|
zf = ZipFile(zip_file_name, mode="w")
|
||||||
|
|
||||||
for p in sorted(args.paths):
|
for p in sorted(args.paths):
|
||||||
if os.path.isdir(p):
|
if os.path.isdir(p):
|
||||||
files = glob.glob(p + "/**/*.py", recursive=True)
|
files = glob.glob(p + "/**/*.py", recursive=True)
|
||||||
for file_path in sorted(files):
|
for file_path in sorted(files):
|
||||||
# strip the absolute path
|
# strip the absolute path
|
||||||
write_to_zip(file_path, strip_file_dir + "/", zf, prepend_str=prepend_str)
|
write_to_zip(
|
||||||
|
file_path, strip_file_dir + "/", zf, prepend_str=prepend_str
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str)
|
write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main() # pragma: no cover
|
||||||
|
Reference in New Issue
Block a user