mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: This PR greatly simplifies `mypy-strict.ini` by strictly typing everything in `.github` and `tools`, rather than picking and choosing only specific files in those two dirs. It also removes `warn_unused_ignores` from `mypy-strict.ini`, for reasons described in https://github.com/pytorch/pytorch/pull/56402#issuecomment-822743795: basically, that setting makes life more difficult depending on what libraries you have installed locally vs in CI (e.g. `ruamel`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/59117 Test Plan: ``` flake8 mypy --config mypy-strict.ini ``` Reviewed By: malfet Differential Revision: D28765386 Pulled By: samestep fbshipit-source-id: 3e744e301c7a464f8a2a2428fcdbad534e231f2e
138 lines
3.5 KiB
Python
138 lines
3.5 KiB
Python
"""
|
|
This util is used to parse op_deps_pass output (in yaml) and convert it into
|
|
other formats for downstream use cases. It is not used by OSS cmake build.
|
|
|
|
To run this file by hand from the root of the PyTorch repository, run:
|
|
|
|
python -m tools.code_analyzer.op_deps_processor \
|
|
--op-dependency build_code_analyzer/work/torch_result.yaml \
|
|
--output pt_deps.bzl
|
|
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
from typing import Any, List
|
|
|
|
from tools.codegen.code_template import CodeTemplate
|
|
|
|
# Bazel (.bzl) output: wraps the rendered per-op entries in a `TORCH_DEPS`
# dict literal. `${ops}` receives the list of rendered ops built by
# `convert`; presumably CodeTemplate expands a list one item per line --
# confirm in tools.codegen.code_template.
BAZEL_OUTPUT = CodeTemplate("""\
TORCH_DEPS = {
${ops}
}
""")

# One dict entry per op: the quoted op name mapped to a list of its
# rendered dependencies (`${op_deps}` is the list of BAZEL_OP_DEP renders).
BAZEL_OP = CodeTemplate("""\
"${op_name}": [
${op_deps}
],
""")

# One quoted dependency name inside an op's list. `convert` also passes
# `op_name` when rendering this template, but the template does not
# reference it.
BAZEL_OP_DEP = CodeTemplate("""\
"${dep_name}",
""")

# Graphviz DOT output: a `digraph` using the "circo" layout, containing the
# edges of every rendered op.
DOT_OUTPUT = CodeTemplate("""\
digraph {
layout="circo";
${ops}
}
""")

# DOT needs no per-op grouping, so an op renders as just its list of edges.
DOT_OP = CodeTemplate("""\
${op_deps}
""")

# One directed edge from an op to one of its dependencies.
DOT_OP_DEP = CodeTemplate("""\
"${op_name}" -> "${dep_name}";
""")
|
|
|
|
|
|
def load_op_deps(fname: str) -> Any:
    """Read the op dependency graph from a YAML file.

    Args:
        fname: path of the YAML file produced by op_deps_pass.

    Returns:
        The object ``yaml.safe_load`` parses from the file (``safe_load`` is
        used so the file cannot instantiate arbitrary Python objects).
    """
    with open(fname) as yaml_file:
        parsed_graph = yaml.safe_load(yaml_file)
    return parsed_graph
|
|
|
|
|
|
def process_base_ops(graph: Any, base_ops: List[str]) -> None:
    """Factor the always-kept ``base_ops`` out of ``graph``, in place.

    Every op's ``depends`` list is rewritten without the entries whose
    ``name`` is a base op, which compresses the output graph. A synthetic
    ``__BASE__`` node that depends on every base op (in the given order) is
    then inserted at the front of ``graph``.

    Args:
        graph: list of op dicts, each with a ``name`` key and an optional
            ``depends`` list of dicts carrying a ``name`` key.
        base_ops: op names to factor out (``namespace::name``, no overload).
    """
    # Hoist a set so each membership test is O(1) instead of a list scan per
    # dependency edge.
    base_op_set = set(base_ops)
    for op in graph:
        # `or []` also covers a YAML `depends:` key with no value (None).
        op['depends'] = [
            dep for dep in (op.get('depends') or [])
            if dep['name'] not in base_op_set
        ]

    # add base ops section at the beginning
    graph.insert(0, {
        'name': '__BASE__',
        'depends': [{'name': name} for name in base_ops]})
|
|
|
|
|
|
def convert(
    fname: str,
    graph: Any,
    output_template: CodeTemplate,
    op_template: CodeTemplate,
    op_dep_template: CodeTemplate,
) -> None:
    """Render ``graph`` with the given templates and write the result to ``fname``.

    Each dependency edge is rendered with ``op_dep_template`` (given
    ``op_name`` and ``dep_name``). Each op's rendered edges are rendered with
    ``op_template`` (given ``op_name`` and ``op_deps``). All rendered ops are
    passed as ``ops`` to ``output_template``, whose result is written to
    ``fname``, overwriting it.

    Self references are dropped, and ops left with no dependencies are
    omitted from the output entirely.

    Args:
        fname: output file path.
        graph: list of op dicts with a ``name`` key and an optional
            ``depends`` list of dicts carrying a ``name`` key.
        output_template: template for the whole file.
        op_template: template for a single op.
        op_dep_template: template for a single dependency edge.
    """
    ops = []
    for op in graph:
        op_name = op['name']
        # `or []` also covers a YAML `depends:` key with no value (None).
        op_deps = [
            op_dep_template.substitute(op_name=op_name, dep_name=dep['name'])
            for dep in (op.get('depends') or [])
            # skip self references
            if dep['name'] != op_name
        ]

        if not op_deps:
            # skip ops without any fanout
            continue

        ops.append(op_template.substitute(op_name=op_name, op_deps=op_deps))

    with open(fname, 'w') as out:
        out.write(output_template.substitute(ops=ops))
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Util to parse & convert op_deps_pass output')
    # The module docstring documents the hyphenated `--op-dependency` flag;
    # accept both spellings so the documented command works and existing
    # underscore-style invocations keep working.
    parser.add_argument(
        '--op-dependency', '--op_dependency',
        dest='op_dependency',
        required=True,
        help='input yaml file of op dependency graph produced by op_deps_pass')
    # `choices` rejects unknown formats up front with a usage error, before
    # any input is read.
    parser.add_argument(
        '--format',
        default='bazel',
        choices=['bazel', 'dot'],
        help='output file format [bazel, dot]')
    parser.add_argument(
        '--base-ops', '--base_ops',
        dest='base_ops',
        nargs='*',
        help='optional list of `base` ops that should always be kept in '
             'custom build, to make the output stable from trivial changes; '
             'each item is `namespace`::`operator name` without overload; '
             'e.g.: aten::empty aten::size ...')
    parser.add_argument(
        '--output',
        required=True,
        help='output file')
    args = parser.parse_args()

    deps = load_op_deps(args.op_dependency)

    if args.base_ops:
        process_base_ops(deps, args.base_ops)

    # (output, op, op_dep) template triple for each supported format; the
    # keys must match the `choices` of `--format` above.
    templates_by_format = {
        'bazel': (BAZEL_OUTPUT, BAZEL_OP, BAZEL_OP_DEP),
        'dot': (DOT_OUTPUT, DOT_OP, DOT_OP_DEP),
    }
    convert(args.output, deps, *templates_by_format[args.format])
|