mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275 In preparation for addressing https://github.com/pytorch/pytorch/issues/73212 Diff was generated with: ``` git mv tools/codegen torchgen git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g' sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt ``` and a manual edits to: * tools/test/test_gen_backend_stubs.py * torchgen/build.bzl * torchgen/gen_backend_stubs.py aka this diff: ``` diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 3dc26c6d2d..104054575e 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 path = os.path.dirname(os.path.realpath(__file__)) -gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py') +gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py') # gen_backend_stubs.py is an integration point that is called directly by external backends. # The tests here are to confirm that badly formed inputs result in reasonable error messages. 
diff --git a/torchgen/build.bzl b/torchgen/build.bzl index ed04e35a43..d00078a3cf 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -1,6 +1,6 @@ def define_targets(rules): rules.py_library( - name = "codegen", + name = "torchgen", srcs = rules.glob(["**/*.py"]), deps = [ rules.requirement("PyYAML"), @@ -11,6 +11,6 @@ def define_targets(rules): rules.py_binary( name = "gen", - srcs = [":codegen"], + srcs = [":torchgen"], visibility = ["//visibility:public"], ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index c1a672a655..beee7a15e0 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -474,7 +474,7 @@ def run( ) -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py - pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute() + pytorch_root = pathlib.Path(__file__).parent.parent.absolute() template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") def make_file_manager(install_dir: str) -> FileManager: ``` run_all_fbandroid_tests Test Plan: sandcastle Reviewed By: albanD, ngimel Differential Revision: D35770317 fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf (cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
326 lines
12 KiB
Python
Executable File
326 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
import sys
|
|
import yaml
|
|
import argparse
|
|
import os
|
|
from copy import deepcopy
|
|
from typing import Dict, List, Set
|
|
|
|
# Command-line options for the generator. parse_known_args() is used, so any
# unrecognised arguments are collected into the discarded second value
# instead of causing an error.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--template_dir",
    default=".",
    help="where template.h is",
)
parser.add_argument(
    "--yaml_dir",
    default="aten/src/ATen/ATen",
    help="where ATen yaml files are",
)
parser.add_argument(
    "--output_prefix",
    default="",
    help="",
)
parser.add_argument(
    "--install_dir",
    default=".",
    help="where to put generated file",
)
parser.add_argument(
    "--aten_root",
    default="",
    help="root directory of aten",
)
args, _ = parser.parse_known_args()
|
|
|
|
# When --aten_root is given it must exist; its parent directory is then
# prepended to sys.path (searched first) before torchgen is imported —
# presumably so torchgen resolves from that source checkout.
if args.aten_root and not os.path.exists(args.aten_root):
    raise ValueError('aten_root ({}) does not exist'.format(
        args.aten_root))
if args.aten_root:
    sys.path.insert(0, os.path.join(args.aten_root, '..'))

from torchgen.code_template import CodeTemplate as CT
|
|
|
|
# Outer header template; at the end of the run its ${mappings},
# ${implementations} and ${cases} placeholders are filled with the
# generated snippets to produce aten_op.h.
OP_TEMPLATE = CT.from_file(os.path.join(args.template_dir, 'aten_op_template.h'))
|
|
|
|
|
|
# YAML loader used to parse Declarations.yaml: prefer the C-accelerated
# CSafeLoader (only importable when PyYAML was built with libyaml), else the
# pure-Python SafeLoader. Both are "safe" loaders.
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    # No libyaml bindings: fall back to the pure-Python implementation.
    from yaml import SafeLoader as Loader  # type: ignore[misc]
|
|
|
|
|
|
def write(filename, s):
    """Overwrite ``filename`` with the text ``s`` (returns None)."""
    with open(filename, mode="w") as out:
        out.write(s)
|
|
|
|
|
|
def read(filename):
    """Return the entire text content of ``filename``."""
    with open(filename) as src:
        contents = src.read()
    return contents
|
|
|
|
|
|
def value_has_tensors(v):
    """True if ``v``'s dynamic_type mentions Tensor and is not a Sparse type."""
    dyn_type = v['dynamic_type']
    # Sparse shouldn't appear in public API, seems to be temporary bug
    return 'Sparse' not in dyn_type and 'Tensor' in dyn_type
|
|
|
|
|
|
def value_is_tensor_type(v):
    """True if ``v`` is a single tensor value, i.e. tensor-typed but not a tensor list."""
    if not value_has_tensors(v):
        return False
    list_types = ('at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &')
    return v['dynamic_type'] not in list_types
|
|
|
|
|
|
# How a return value of each ATen type is copied into caffe2 outputs.
# ${offset} is the output index and ${output} the C++ expression holding the
# value (the whole result, or ::std::get<i> of it for multi-result ops).
RETURN_MAP = {
    # tensors / scalars
    'at::Tensor': 'assignTo(Output(${offset}),${output});',
    'at::Scalar': 'assignTo(Output(${offset}),${output}.type(), ${output});',
    # integral values are stored as int64
    'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
    'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
    # a vector of tensors fills consecutive outputs starting at ${offset}
    '::std::vector<at::Tensor>': 'assignListStartingAt(${offset}, ${output});',
}
|
|
|
|
# How each supported non-Tensor ATen argument type is read from the caffe2
# operator's attribute list. ${arg} is the argument name, which is also used
# as the attribute name. Most of the read* helpers are runtime functions
# defined in the template class (aten_op_template.h).
ARGUMENT_MAP = {
    # scalar
    'const at::Scalar &': 'at::Scalar ${arg} = readScalarAttribute("${arg}");',
    # bool/int are read from int64 attributes
    'bool': 'bool ${arg} = readAttribute<int64_t>("${arg}");',
    'int': 'int ${arg} = readAttribute<int64_t>("${arg}");',
    # note: double is read from a float attribute
    'double': 'double ${arg} = readAttribute<float>("${arg}");',
    'int64_t': 'int64_t ${arg} = readAttribute<int64_t>("${arg}");',
    # list-like values
    'at::IntArrayRef': 'auto ${arg} = readIntArrayRef("${arg}");',
    '::std::array<bool,2>': 'auto ${arg} = readBoolMask<2>("${arg}");',
    '::std::array<bool,3>': 'auto ${arg} = readBoolMask<3>("${arg}");',
}
|
|
|
|
# Backward-compatibility overrides: ops listed here are invoked through the
# given C++ function (called with the op's arguments) instead of the default
# at::<name>(...) / self.<name>(...) call.
SPECIAL_IMPLEMENTATIONS = dict(
    index='internal::index_with_uint8_handling',
)
|
|
|
|
def expand(o):
    """Return ``o`` followed by one variant per droppable trailing default.

    With k defaulted arguments the result is [o, o minus last arg, ...,
    o minus last k args]; each variant is a deep copy, so ``o`` is not
    modified. Asserts that the defaulted arguments are exactly the last k.
    """
    arguments = o['arguments']
    n_defaults = sum('default' in a for a in arguments)
    variants = [o]
    for dropped in range(1, n_defaults + 1):
        # last num_default values should be default
        assert 'default' in arguments[-dropped]
        trimmed = deepcopy(o)
        trimmed['arguments'] = trimmed['arguments'][:-dropped]
        variants.append(trimmed)
    return variants
|
|
|
|
|
|
# filter the list of declarations removing things we cannot support
def supports(o, factory_methods):
    """Return True if declaration ``o`` can be wrapped as a caffe2 operator.

    Rejected: tensor factory methods (names in ``factory_methods``), in-place
    ops, anything whose name contains ``_out``, ops with no returns, and ops
    with a non-tensor return/argument type missing from RETURN_MAP /
    ARGUMENT_MAP. ``factory_methods`` is mutated: its counter for a rejected
    factory name is incremented, and the skip message is printed only while
    that counter is still 0.
    """
    op_name = o['name']

    # Ignore all families (!) of functions that have TensorOptions (i.e. tensor factory methods).
    if op_name in factory_methods:
        if factory_methods[op_name] == 0:
            print("Skipping {} because it is a factory method".format(op_name))
        factory_methods[op_name] += 1
        return False

    # In-place ops and _out variants write into tensors handed to them; aten
    # cannot resize caffe2 memory inside an operator, so both are skipped.
    if o['inplace'] or "_out" in op_name:
        return False

    # skip if no return, previously it is 'void'
    if len(o['returns']) == 0:
        return False

    # every non-tensor return needs an assignment rule ...
    for ret in o['returns']:
        if value_has_tensors(ret) or ret['type'] in RETURN_MAP:
            continue
        print("Skipping {} Because of Ret: {} ({})".format(
            op_name, ret['type'], ret['dynamic_type']))
        return False

    # ... and every non-tensor argument an attribute reader
    for arg in o['arguments']:
        if value_has_tensors(arg) or arg['type'] in ARGUMENT_MAP:
            continue
        print("Skipping {} Because of Arg: {} ({}) ".format(
            op_name, arg['type'], arg['dynamic_type']))
        return False

    return True
|
|
|
|
|
|
# template for each potential operator.
# Each operator has an integer 'key' associated with it and a generated
# implementation_${key}() method that installs a `run_op` lambda:
#   * ${initialization} executes in the method body when
#     implementation_${key}() is called; it declares the non-tensor
#     attributes, which the lambda then captures by value ([=]).
#   * ${statements} execute inside the lambda each time run_op is called and
#     read the tensor Inputs.
#   * ${invocation} is the ATen call; ${assignments} copy its result(s) to
#     the Outputs.
# The ATen call runs with an at::AutoDispatchBelowAutograd guard in scope.
#
# each implementation is defined in a separate method annotated with
# C10_NOINLINE to avoid inlining into the ATenOp constructor, which would
# trigger pathological compile times.
IMPLEMENTATION_TEMPLATE = CT("""\
C10_NOINLINE void implementation_${key}() { // ${name}
${initialization}
run_op = [=] {
at::AutoDispatchBelowAutograd guard;
${statements}
auto the_result = ${invocation};
${assignments}
return true;
};
}
""")
|
|
|
|
CASE_TEMPLATE = CT("""\
|
|
case ${key}: // ${name}
|
|
implementation_${key}();
|
|
break;
|
|
""")
|
|
|
|
ASSIGN_CHECK_SIZE_TEMPLATE = CT("""\
|
|
if(OutputSize() > ${offset}) {${assignment}}
|
|
""")
|
|
|
|
|
|
def get_output(o, i):
    """Return the C++ expression for the i-th result of op ``o``.

    A single-result op uses ``the_result`` directly; otherwise the i-th
    element is extracted with ``::std::get<i>``.
    """
    if len(o['returns']) != 1:
        return '::std::get<{}>(the_result)'.format(i)
    return 'the_result'
|
|
|
|
|
|
def attribute_names(o):
    """Return the sorted names of ``o``'s non-tensor arguments (its caffe2 attributes)."""
    names = [arg['name'] for arg in o['arguments'] if not value_has_tensors(arg)]
    names.sort()
    return names
|
|
|
|
|
|
def required_attribute_names(o):
    """Return the sorted names of ``o``'s non-tensor arguments that have no default."""
    required = []
    for arg in o['arguments']:
        if value_has_tensors(arg) or 'default' in arg:
            continue
        required.append(arg['name'])
    return sorted(required)
|
|
|
|
|
|
def self_as_first_argument(arguments):
    """Return a new list with arguments named 'self' moved to the front.

    The relative order within each group is preserved (sorted() is stable).
    """
    return sorted(arguments, key=lambda a: a['name'] != 'self')
|
|
|
|
|
|
def get_num_inputs(o):
    """Return the tensor-input count of ``o`` as a string.

    Returns '*' (variadic) as soon as a tensor-list argument is seen;
    otherwise counts the arguments for which value_has_tensors() holds.
    """
    tensor_list_types = ('at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &')
    count = 0
    for arg in o['arguments']:
        if arg['type'] in tensor_list_types:
            return '*'
        if value_has_tensors(arg):
            count += 1
    return str(count)
|
|
|
|
|
|
def find_factory_methods(decls):
    """Map the name of every declaration taking an at::TensorOptions argument to 0.

    The value is a counter that supports() increments each time it rejects
    a declaration with that name.
    """
    return {
        decl['name']: 0
        for decl in decls
        if any(a['dynamic_type'] == 'at::TensorOptions' for a in decl['arguments'])
    }
|
|
|
|
|
|
def emit_assignments(o, env):
    """Append to env['assignments'] one guarded output assignment per return of ``o``.

    Single-tensor returns use the 'at::Tensor' rule; every other return uses
    the RETURN_MAP entry for its 'type'. Each assignment is wrapped in
    ASSIGN_CHECK_SIZE_TEMPLATE so it only runs when that output exists.
    """
    for idx, ret in enumerate(o['returns']):
        map_key = 'at::Tensor' if value_is_tensor_type(ret) else ret['type']
        assign = CT(RETURN_MAP[map_key]).substitute(
            env, offset=idx, output=get_output(o, idx))
        env['assignments'].append(
            ASSIGN_CHECK_SIZE_TEMPLATE.substitute(env, offset=idx, assignment=assign))
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse <yaml_dir>/Declarations.yaml, keep the
    # overloads caffe2 can wrap, and write <install_dir>/<output_prefix>aten_op.h
    # by substituting the generated snippets into OP_TEMPLATE.
    decls = yaml.load(read(os.path.join(args.yaml_dir, 'Declarations.yaml')), Loader=Loader)
    factory_methods = find_factory_methods(decls)
    # Each declaration yields itself plus one variant per droppable trailing
    # default argument (expand); supports() filters those variants.
    filtered = [expanded for o in decls for expanded in expand(o) if supports(expanded, factory_methods)]
    # Lists of generated C++ snippets, substituted into OP_TEMPLATE at the end.
    top_env: Dict[str, List] = {
        'mappings': [],
        'implementations': [],
        'cases': [],
    }
    seen: Set[str] = set()
    key = 0
    for o in filtered:
        # [DESCRIPTORS]
        # each option is associated with a descriptor string that is used
        # to figure out which version of an op is being used.
        # The format (as built below) is:
        #   opname-attribute_1-attribute_2-...-num_inputs
        # where the attributes are the sorted non-tensor argument names and
        # num_inputs is the tensor-input count, or '*' when the op takes a
        # tensor list.
        # Example:
        #   lerp-weight-2
        # the operator lerp takes 2 tensor inputs and has the attribute weight
        attr_names = attribute_names(o)
        num_inputs = get_num_inputs(o)
        descriptor = '-'.join([o['name']] + attr_names + [num_inputs])
        # Different overloads can map to the same descriptor; only the first
        # one encountered is generated.
        if descriptor in seen:
            continue
        seen.add(descriptor)

        # map from descriptor string to the integer key in the switch statements
        # that initializes the operators; emits a C++ initializer entry
        # of the form { "<descriptor>", <key> },
        top_env['mappings'].append('{{ "{}", {} }},'.format(descriptor, key))
        # Per-operator substitution environment for the templates; the list
        # values are filled in below.
        env = {
            'name': o['name'],
            'statements': [],
            'arguments': [],
            'assignments': [],
            'initialization': [],
            'key': str(key),
        }

        if 'namespace' not in o['method_of'] and 'Tensor' not in o['method_of']:
            # methods on type like 'ones' or 'zeros' always take a
            # string attribute that is translated into the at::Type object
            # e.g. "Float" is at::kFloat
            assert('Type' in o['method_of'])

        # Number of individual (non-list) tensor arguments.
        static_tensor_inputs = sum(arg['type'] not in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &'] and value_is_tensor_type(arg) for arg in o['arguments'])
        # Whether any argument is a tensor list (plain or of optional tensors),
        # and if so the position of the first one.
        has_tensorlist = any(arg['type'] in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &'] for arg in o['arguments'])
        if has_tensorlist:
            tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &']][0]

        # Count of single-tensor inputs consumed so far.
        real_inputs = 0
        for i, arg in enumerate(o['arguments']):
            env['arguments'].append(arg['name'])
            # Pretend the flat argument list is a stack where the end is the top.
            # Tensors before a tensor list are peeked with length InputSize(),
            # those after it with length static_tensor_inputs.
            # NOTE(review): this assumes peek(i, N) addresses
            # Input(InputSize() - N + i) in aten_op_template.h — confirm there.
            view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
            if arg['type'] == 'at::TensorList':
                # NOTE: do not advance real_inputs here. After this we will
                # switch to indexing the "stack" from the end
                env['statements'].append(
                    'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
                    .format(arg['name'], real_inputs, static_tensor_inputs))
            elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
                # NOTE: do not advance real_inputs here. After this we will
                # switch to indexing the "stack" from the end
                env['statements'].append(
                    'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
                    .format(arg['name'], real_inputs, static_tensor_inputs))
            elif value_is_tensor_type(arg):
                # load tensor inputs from Caffe2
                env['statements'].append(
                    'auto {} = peek({}, {});'.format(arg['name'], real_inputs, view_length))
                real_inputs += 1
            else:
                # Non-tensor argument: read it from an operator attribute
                # using its ARGUMENT_MAP rule (supports() guaranteed one exists).
                init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
                env['initialization'].append(init)

        emit_assignments(o, env)

        # Build the ATen call: a BC override, a free at:: function, or a
        # Tensor method invoked on the first argument (assumed to be `self`).
        if o['name'] in SPECIAL_IMPLEMENTATIONS:
            env['invocation'] = "{}({})".format(SPECIAL_IMPLEMENTATIONS[o['name']], ','.join(env['arguments']))
        elif 'namespace' in o['method_of']:
            env['invocation'] = CT("at::${name}(${arguments})").substitute(env)
        else:
            assert('Tensor' in o['method_of'])
            env['invocation'] = "self.{}({})".format(
                o['name'], ', '.join(env['arguments'][1:]))

        top_env['implementations'].append(IMPLEMENTATION_TEMPLATE.substitute(env))
        top_env['cases'].append(CASE_TEMPLATE.substitute(env))
        key += 1
    write(os.path.join(args.install_dir, args.output_prefix + "aten_op.h"), OP_TEMPLATE.substitute(top_env))
|