mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytorch] clean up unused util srcs under tools/autograd (#50611)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50611 Removed the unused old-style code to prevent it from being used. Added all autograd/gen_pyi sources to mypy-strict.ini config. Confirmed byte-for-byte compatible with the old codegen: ``` Run it before and after this PR: .jenkins/pytorch/codegen-test.sh <baseline_output_dir> .jenkins/pytorch/codegen-test.sh <test_output_dir> Then run diff to compare the generated files: diff -Naur <baseline_output_dir> <test_output_dir> ``` Confirmed clean mypy-strict run: ``` mypy --config mypy-strict.ini ``` Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D25929730 Pulled By: ljk53 fbshipit-source-id: 1fc94436fd4a6b9b368ee0736e99bfb3c01d38ef
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b75cdceb44
commit
5252e9857a
@ -430,8 +430,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
"${TOOLS_PATH}/autograd/gen_variable_factories.py"
|
||||
"${TOOLS_PATH}/autograd/gen_variable_type.py"
|
||||
"${TOOLS_PATH}/autograd/load_derivatives.py"
|
||||
"${TOOLS_PATH}/autograd/nested_dict.py"
|
||||
"${TOOLS_PATH}/autograd/utils.py"
|
||||
WORKING_DIRECTORY "${TORCH_ROOT}")
|
||||
|
||||
|
||||
|
@ -30,13 +30,8 @@ implicit_reexport = False
|
||||
strict_equality = True
|
||||
|
||||
files = tools/codegen/gen.py,
|
||||
tools/autograd/gen_annotated_fn_args.py,
|
||||
tools/autograd/gen_autograd.py,
|
||||
tools/autograd/gen_python_functions.py,
|
||||
tools/autograd/gen_trace_type.py,
|
||||
tools/autograd/gen_variable_factories.py,
|
||||
tools/autograd/gen_variable_type.py,
|
||||
tools/autograd/load_derivatives.py,
|
||||
tools/autograd/*.py,
|
||||
tools/pyi/*.py,
|
||||
torch/utils/benchmark/utils/common.py,
|
||||
torch/utils/benchmark/utils/timer.py,
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/*.py,
|
||||
|
@ -1,19 +0,0 @@
|
||||
# TODO: refactor nested_dict into common library with ATen
|
||||
class nested_dict(object):
|
||||
"""
|
||||
A nested dict is a dictionary with a parent. If key lookup fails,
|
||||
it recursively continues into the parent. Writes always happen to
|
||||
the top level dict.
|
||||
"""
|
||||
|
||||
def __init__(self, base, parent):
|
||||
self.base, self.parent = base, parent
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.base or item in self.parent
|
||||
|
||||
def __getitem__(self, x):
|
||||
r = self.base.get(x)
|
||||
if r is not None:
|
||||
return r
|
||||
return self.parent[x]
|
@ -1,114 +0,0 @@
|
||||
import re
|
||||
import os
|
||||
import yaml
|
||||
from .nested_dict import nested_dict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
__all__ = [
|
||||
'CodeTemplate', 'IDENT_REGEX', 'YamlLoader', 'nested_dict',
|
||||
'split_name_params', 'write',
|
||||
]
|
||||
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
|
||||
# You should use these lines, rather than doing it manually.
|
||||
# Especially if you see this error!
|
||||
#
|
||||
# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
|
||||
# loader = Loader(stream)
|
||||
# TypeError: 'module' object is not callable
|
||||
try:
|
||||
# use faster C loader if available
|
||||
from yaml import CLoader as YamlLoader
|
||||
except ImportError:
|
||||
from yaml import Loader as YamlLoader
|
||||
|
||||
GENERATED_COMMENT = CodeTemplate(
|
||||
"@" + "generated from ${filename}")
|
||||
|
||||
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
|
||||
# occurrence of a parameter in the derivative formula
|
||||
IDENT_REGEX = r'(^|\W){}($|\W)'
|
||||
|
||||
|
||||
# TODO: Use a real parser here; this will get bamboozled
|
||||
# by signatures that contain things like std::array<bool, 2> (note the space)
|
||||
def split_name_params(prototype):
|
||||
name, overload_name, params = re.match(r'(\w+)(\.\w+)?\((.*)\)', prototype).groups()
|
||||
return name, params.split(', ')
|
||||
|
||||
|
||||
# When tracing, we record inplace operations as out-of-place operations,
|
||||
# because we don't have a story for side effects in the IR yet.
|
||||
#
|
||||
# Doing this un-inplacing is a little delicate however; __and__ is NOT inplace!
|
||||
# TODO: Do something more robust
|
||||
def uninplace_api_name(api_name):
|
||||
if api_name.endswith('_') and not api_name.endswith('__'):
|
||||
api_name = api_name[:-1]
|
||||
return unout_api_name(api_name)
|
||||
|
||||
def make_out_api_name_faithful(api_name):
|
||||
# Variable kernel needs to call the _outf overload instead of the _out overload
|
||||
# because the _outf overload matches the argument order as it's passed into
|
||||
# the variable kernel
|
||||
if api_name.endswith('_out'):
|
||||
api_name = api_name + 'f'
|
||||
return api_name
|
||||
|
||||
|
||||
def write(dirname: str, name: str, template: CodeTemplate, env: Dict[str, List[str]]) -> None:
|
||||
env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename)
|
||||
path = os.path.join(dirname, name)
|
||||
# See Note [Unchanging results for ninja]
|
||||
try:
|
||||
with open(path, 'r') as f:
|
||||
old_val = f.read()
|
||||
except IOError:
|
||||
old_val = None
|
||||
new_val = template.substitute(env)
|
||||
if old_val != new_val:
|
||||
with open(path, 'w') as f:
|
||||
print("Writing {}".format(path))
|
||||
f.write(new_val)
|
||||
else:
|
||||
print("Skipped writing {}".format(path))
|
||||
|
||||
def is_out_variant(decl):
|
||||
return decl['name'].endswith('_out')
|
||||
|
||||
def op_name_with_overload(decl):
|
||||
return decl['operator_name_with_overload']
|
||||
|
||||
def load_op_list_and_strip_overload(op_list, op_list_path):
|
||||
if op_list is None and op_list_path is None:
|
||||
return None
|
||||
if op_list is None:
|
||||
op_list = []
|
||||
if op_list_path is not None:
|
||||
with open(op_list_path, 'r') as f:
|
||||
op_list += yaml.load(f, Loader=YamlLoader)
|
||||
# strip out the overload part
|
||||
return {opname.split('.', 1)[0] for opname in op_list}
|
||||
|
||||
def is_output(arg):
|
||||
return arg.get('output', False)
|
||||
|
||||
def has_outputs(declaration):
|
||||
return any([is_output(arg) for arg in declaration['arguments']])
|
||||
|
||||
def op_name(declaration):
|
||||
name = declaration['name']
|
||||
if has_outputs(declaration):
|
||||
if not name.endswith("_out"):
|
||||
raise RuntimeError(
|
||||
'{} has output params, expecting name ending with \'_out\''.
|
||||
format(declaration['name']))
|
||||
return name[:-4]
|
||||
else:
|
||||
if name.endswith("_out"):
|
||||
raise RuntimeError(
|
||||
'{}: name ends with \'_out\', expecting output params'.
|
||||
format(declaration['name']))
|
||||
return name
|
@ -12,7 +12,7 @@ python -m tools.code_analyzer.op_deps_processor \
|
||||
import argparse
|
||||
import yaml
|
||||
|
||||
from ..autograd.utils import CodeTemplate
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
|
||||
BAZEL_OUTPUT = CodeTemplate("""\
|
||||
TORCH_DEPS = {
|
||||
|
@ -1,5 +1,3 @@
|
||||
|
||||
import os
|
||||
import collections
|
||||
from pprint import pformat
|
||||
|
||||
@ -7,9 +5,9 @@ import argparse
|
||||
|
||||
from tools.codegen.model import *
|
||||
from tools.codegen.api.python import *
|
||||
from tools.codegen.gen import FileManager
|
||||
from typing import Sequence, List, Dict
|
||||
|
||||
from ..autograd.utils import CodeTemplate, write
|
||||
from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads
|
||||
|
||||
"""
|
||||
@ -166,7 +164,7 @@ def sig_for_ops(opname: str) -> List[str]:
|
||||
raise Exception("unknown op", opname)
|
||||
|
||||
def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
|
||||
type_hints = []
|
||||
type_hints: List[str] = []
|
||||
|
||||
# Some deprecated ops that are on the blocklist are still included in pyi
|
||||
if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
|
||||
@ -193,7 +191,7 @@ def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
|
||||
|
||||
return type_hints
|
||||
|
||||
def gen_nn_functional(out: str) -> None:
|
||||
def gen_nn_functional(fm: FileManager) -> None:
|
||||
# Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
|
||||
# through an `_add_docstr` call
|
||||
imports = [
|
||||
@ -241,28 +239,22 @@ def gen_nn_functional(out: str) -> None:
|
||||
import_code = ["from .. import {0} as {0}".format(_) for _ in imports]
|
||||
# TODO make these types more precise
|
||||
dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
|
||||
stubs = CodeTemplate.from_file(os.path.join('torch', 'nn', 'functional.pyi.in'))
|
||||
env = {
|
||||
fm.write_with_template('torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: {
|
||||
'imported_hints': import_code,
|
||||
'dispatched_hints': dispatch_code
|
||||
}
|
||||
write(out, 'torch/nn/functional.pyi', stubs, env)
|
||||
'dispatched_hints': dispatch_code,
|
||||
})
|
||||
|
||||
# functional.pyi already contains the definitions for those functions
|
||||
# so, we don't export then to it
|
||||
from_c.extend(['hardtanh', 'leaky_relu', 'hardsigmoid'])
|
||||
dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
|
||||
env = {
|
||||
fm.write_with_template('torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: {
|
||||
'imported_hints': import_code,
|
||||
'dispatched_hints': dispatch_code
|
||||
}
|
||||
stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in'))
|
||||
write(out, 'torch/_C/_nn.pyi', stubs, env)
|
||||
'dispatched_hints': dispatch_code,
|
||||
})
|
||||
|
||||
def gen_nn_pyi(out: str) -> None:
|
||||
gen_nn_functional(out)
|
||||
|
||||
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None:
|
||||
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
|
||||
"""gen_pyi()
|
||||
|
||||
This function generates a pyi file for torch.
|
||||
@ -550,14 +542,19 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None:
|
||||
'dtype_class_hints': dtype_class_hints,
|
||||
'all_directive': all_directive
|
||||
}
|
||||
TORCH_C_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '_C', '__init__.pyi.in'))
|
||||
TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS = \
|
||||
CodeTemplate.from_file(os.path.join('torch', '_C', '_VariableFunctions.pyi.in'))
|
||||
|
||||
write(out, 'torch/_C/__init__.pyi', TORCH_C_TYPE_STUBS, env)
|
||||
write(out, 'torch/_C/_VariableFunctions.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env)
|
||||
write(out, 'torch/_VF.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env)
|
||||
gen_nn_pyi(out)
|
||||
fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
|
||||
'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in',
|
||||
**env,
|
||||
})
|
||||
fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
|
||||
'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
|
||||
**env,
|
||||
})
|
||||
fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
|
||||
'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
|
||||
**env,
|
||||
})
|
||||
gen_nn_functional(fm)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@ -573,7 +570,8 @@ def main() -> None:
|
||||
default='.',
|
||||
help='path to output directory')
|
||||
args = parser.parse_args()
|
||||
gen_pyi(args.native_functions_path, args.deprecated_functions_path, args.out)
|
||||
fm = FileManager(install_dir=args.out, template_dir='.', dry_run=False)
|
||||
gen_pyi(args.native_functions_path, args.deprecated_functions_path, fm)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,6 +1,13 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
try:
|
||||
# use faster C loader if available
|
||||
from yaml import CLoader as YamlLoader
|
||||
except ImportError:
|
||||
from yaml import Loader as YamlLoader
|
||||
|
||||
source_files = {'.py', '.cpp', '.h'}
|
||||
|
||||
@ -76,15 +83,16 @@ def generate_code(ninja_global=None,
|
||||
python_install_dir,
|
||||
autograd_dir)
|
||||
|
||||
|
||||
def get_selector_from_legacy_operator_selection_list(
|
||||
selected_op_list_path: str,
|
||||
):
|
||||
from tools.autograd.utils import load_op_list_and_strip_overload
|
||||
|
||||
selected_op_list = load_op_list_and_strip_overload(
|
||||
None,
|
||||
selected_op_list_path,
|
||||
)
|
||||
with open(selected_op_list_path, 'r') as f:
|
||||
# strip out the overload part
|
||||
# It's only for legacy config - do NOT copy this code!
|
||||
selected_op_list = {
|
||||
opname.split('.', 1)[0] for opname in yaml.load(f, Loader=YamlLoader)
|
||||
}
|
||||
|
||||
# Internal build doesn't use this flag any more. Only used by OSS
|
||||
# build now. Every operator should be considered a root operator
|
||||
@ -96,14 +104,11 @@ def get_selector_from_legacy_operator_selection_list(
|
||||
is_used_for_training = True
|
||||
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
selector: SelectiveBuilder = SelectiveBuilder.get_nop_selector()
|
||||
if selected_op_list is not None:
|
||||
selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
selected_op_list,
|
||||
is_root_operator,
|
||||
is_used_for_training,
|
||||
)
|
||||
selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
selected_op_list,
|
||||
is_root_operator,
|
||||
is_used_for_training,
|
||||
)
|
||||
|
||||
return selector
|
||||
|
||||
|
Reference in New Issue
Block a user