[pytorch][codegen] migrate gen_variable_type to new data model (#49735)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49735

This is the final wave of the autograd codegen data-model migration.

After this PR:
- autograd codegen no longer depends on Declarations.yaml;
- autograd codegen sources are fully type annotated and pass the mypy-strict check.

To avoid potential merge conflicts with other pending PRs, some structural
changes are intentionally deferred: for example, inner methods are not moved out
of their enclosing functions, and they are not rewritten to avoid reading the
outer function's variables.
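
For context, a minimal sketch of the data-model shift this stack completes (the accessors shown are illustrative, not the full schema): the old codegen walked untyped dicts loaded from Declarations.yaml, while the new model operates on typed `NativeFunction` objects parsed from native_functions.yaml, which is what makes the mypy-strict check meaningful.
```
# Sketch only: contrasts the two data models; accessors shown are illustrative.
from typing import Any, Dict

import tools.codegen.api.cpp as cpp
from tools.codegen.model import NativeFunction

def old_style_name(declaration: Dict[str, Any]) -> str:
    # Old model: plain dicts from Declarations.yaml; a typo in a key
    # only fails when the codegen actually runs.
    return declaration['name']

def new_style_name(f: NativeFunction) -> str:
    # New model: NativeFunction dataclasses parsed from native_functions.yaml;
    # mypy --config mypy-strict.ini can check this access statically.
    return cpp.name(f.func)
```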

Confirmed byte-for-byte compatibility with the old codegen:
```
Run the codegen test script before this PR (baseline) and after it (test):
  .jenkins/pytorch/codegen-test.sh <baseline_output_dir>
  .jenkins/pytorch/codegen-test.sh <test_output_dir>

Then run diff to compare the generated files:
  diff -Naur <baseline_output_dir> <test_output_dir>
```

Confirmed clean mypy-strict run:
```
mypy --config mypy-strict.ini
```

Test Plan: Imported from OSS

Reviewed By: ezyang, bhosmer

Differential Revision: D25678879

Pulled By: ljk53

fbshipit-source-id: ba6e2eb6b9fb744208f7f79a922d933fcc3bde9f
Author: Jiakai Liu
Date: 2021-01-05 14:00:02 -08:00
Committed by: Facebook GitHub Bot
Parent: a272a7eeab
Commit: e71a13e8a3
11 changed files with 497 additions and 454 deletions

View File

@ -31,9 +31,11 @@ strict_equality = True
files = tools/codegen/gen.py,
tools/autograd/gen_annotated_fn_args.py,
tools/autograd/gen_autograd.py,
tools/autograd/gen_python_functions.py,
tools/autograd/gen_trace_type.py,
tools/autograd/gen_variable_factories.py,
tools/autograd/gen_variable_type.py,
tools/autograd/load_derivatives.py,
torch/utils/benchmark/utils/common.py,
torch/utils/benchmark/utils/timer.py,

View File

@ -23,9 +23,6 @@ torch/csrc/autograd/generated/
import argparse
import os
import yaml
import re
from .utils import YamlLoader, op_name_with_overload
from tools.codegen.selective_build.selector import SelectiveBuilder
# See NOTE [ Autograd View Variables ] in variable.h for details.
@ -89,84 +86,14 @@ RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({
'tensor_split', 'swapdims', 'swapaxes'
})
def format_return_type(returns):
if len(returns) == 0:
return 'void'
elif len(returns) == 1:
return returns[0]['type']
else:
return_types = [r['type'] for r in returns]
return 'std::tuple<{}>'.format(','.join(return_types))
def get_simple_type(arg):
simple_type = arg['type']
simple_type = simple_type.replace(' &', '').replace('const ', '')
simple_type = simple_type.replace('Generator *', 'Generator')
opt_match = re.match(r'c10::optional<(.+)>', simple_type)
if opt_match:
simple_type = '{}?'.format(opt_match.group(1))
return simple_type
def has_tensoroptions_argument(declaration):
for argument in declaration['arguments']:
if 'TensorOptions' == argument['dynamic_type']:
return True
return False
def load_aten_declarations(path):
with open(path, 'r') as f:
declarations = yaml.load(f, Loader=YamlLoader)
# enrich declarations with additional information
selected_declarations = []
for declaration in declarations:
if declaration.get('deprecated'):
continue
for arg in declaration['arguments']:
arg['simple_type'] = get_simple_type(arg)
for arg in declaration['schema_order_arguments']:
arg['simple_type'] = get_simple_type(arg)
for ret in declaration['returns']:
ret['simple_type'] = get_simple_type(ret)
declaration['formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['arguments']]
declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['schema_order_arguments']]
declaration['args'] = [arg['name'] for arg in declaration['arguments']]
declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']]
declaration['api_name'] = declaration['name']
if declaration.get('overload_name'):
declaration['type_wrapper_name'] = "{}_{}".format(
declaration['name'], declaration['overload_name'])
else:
declaration['type_wrapper_name'] = declaration['name']
declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
declaration['return_type'] = format_return_type(declaration['returns'])
declaration['base_name'] = declaration['name']
selected_declarations.append(declaration)
return selected_declarations
def gen_autograd(aten_path, native_functions_path, out, autograd_dir, operator_selector: SelectiveBuilder, disable_autograd=False):
full_aten_decls = load_aten_declarations(aten_path)
def filter_decls(aten_decls, operator_selector):
def is_operator_selected_for_training(decl):
op_name = op_name_with_overload(decl)
return operator_selector.is_operator_selected_for_training(op_name)
return [decl for decl in aten_decls if is_operator_selected_for_training(decl)]
aten_decls = filter_decls(full_aten_decls, operator_selector)
def gen_autograd(
aten_path: str,
native_functions_path: str,
out: str,
autograd_dir: str,
operator_selector: SelectiveBuilder,
disable_autograd: bool = False,
) -> None:
# Parse and load derivatives.yaml
from .load_derivatives import load_derivatives
differentiability_infos = load_derivatives(
@ -175,13 +102,13 @@ def gen_autograd(aten_path, native_functions_path, out, autograd_dir, operator_s
template_path = os.path.join(autograd_dir, 'templates')
# Generate VariableType.h/cpp
from .gen_trace_type import gen_trace_type
from .gen_variable_type import gen_variable_type
if not disable_autograd:
from .gen_variable_type import gen_variable_type
gen_variable_type(out, aten_decls, differentiability_infos, template_path)
gen_variable_type(out, native_functions_path, differentiability_infos, template_path, operator_selector)
from . import gen_trace_type
# operator filter not applied as tracing sources are excluded in selective build
gen_trace_type.gen_trace_type(out, native_functions_path, template_path)
gen_trace_type(out, native_functions_path, template_path)
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_lib
@ -193,7 +120,12 @@ def gen_autograd(aten_path, native_functions_path, out, autograd_dir, operator_s
gen_variable_factories(out, native_functions_path, template_path)
def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
def gen_autograd_python(
aten_path: str,
native_functions_path: str,
out: str,
autograd_dir: str,
) -> None:
from .load_derivatives import load_derivatives
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
@ -212,7 +144,7 @@ def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
out, native_functions_path, deprecated_path, template_path)
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description='Generate autograd C++ files script')
parser.add_argument('declarations', metavar='DECL',

View File

@ -422,7 +422,7 @@ def gen_trace_type_shard(
fm: FileManager, native_functions: Sequence[NativeFunction], suffix: str
) -> None:
fm.write_with_template('TraceType%s.cpp' % suffix, 'TraceType.cpp', lambda: {
'generated_comment': f'@generated from {fm.template_dir}/TraceType.cpp',
'generated_comment': '@' + f'generated from {fm.template_dir}/TraceType.cpp',
'trace_method_definitions': list(mapMaybe(method_definition, native_functions)),
'trace_wrapper_registrations': list(mapMaybe(method_registration, native_functions)),
})

View File

@ -22,20 +22,24 @@
# which will in turn dispatch back to VariableType for its
# differentiable subcomponents.
#
from dataclasses import dataclass
from .utils import CodeTemplate, nested_dict, write, make_out_api_name_faithful
from .gen_autograd import VIEW_FUNCTIONS, VIEW_FUNCTIONS_WITH_METADATA_CHANGE, \
MULTI_OUTPUT_SAFE_FUNCTIONS, RETURNS_VIEWS_OF_INPUT
from .gen_autograd_functions import uses_single_grad
from .gen_trace_type import MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, MANUAL_AUTOGRAD
from .gen_trace_type import (
MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, MANUAL_AUTOGRAD,
declare_returned_variables, tie_return_values, get_return_value, type_wrapper_name,
)
from tools.codegen.api.types import *
from tools.codegen.api.autograd import *
import tools.codegen.api.cpp as cpp
import tools.codegen.api.python as python
from tools.codegen.gen import with_native_function
from tools.codegen.code_template import CodeTemplate
from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager, mapMaybe
from tools.codegen.model import *
from typing import Dict, Optional, List, Sequence, Any, Callable
from tools.codegen.selective_build.selector import SelectiveBuilder
from typing import Callable, List, Optional, Sequence, Tuple, Union
# We don't set or modify grad_fn on these methods. Generally, they return
# tensors that have requires_grad=False. In-place functions listed here will
@ -209,9 +213,6 @@ m.impl("${unqual_operator_name_with_overload}",
UNPACK_TENSOR = CodeTemplate("""\
auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""")
LEGACY_WRAP_OPTIONS = CodeTemplate("""\
auto ${arg_name}_ = TensorOptions(${arg_name});""")
DECLARE_GRAD_FN = CodeTemplate("""\
std::shared_ptr<${op}> grad_fn;
""")
@ -304,49 +305,18 @@ ${statements}
#endif
""")
# Methods shared by TraceType and VariableType to handle return variable declaration, tie and tuple.
def format_return_variables(declaration):
name = declaration['name']
arguments = declaration['arguments']
inplace = declaration['inplace']
is_out_fn = name.endswith('_out')
modifies_arguments = inplace or is_out_fn
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
func: NativeFunction
info: Optional[DifferentiabilityInfo]
def declare_returned_variables():
if modifies_arguments:
return ''
if len(declaration['returns']) == 1:
return ''
# TODO: this will be ugly
names = [ret['type'] + ' ' + ret['name'] + ';' for ret in declaration['returns']]
return '\n'.join(names)
def tie_return_values():
if len(declaration['returns']) == 1:
return 'auto {}'.format(declaration['returns'][0]['name'])
names = [ret['name'] for ret in declaration['returns']]
return 'std::tie({})'.format(', '.join(names))
def get_return_value():
if inplace:
return 'self'
if is_out_fn:
return_names = [arg['name'] for arg in arguments
if arg.get('output', False)]
if len(return_names) == 1:
return return_names[0]
return 'std::forward_as_tuple({})'.format(', '.join(return_names))
returns = declaration['returns']
if len(returns) == 1:
return returns[0]['name']
moved = ['std::move({})'.format(r['name']) for r in returns]
return 'std::make_tuple({})'.format(', '.join(moved))
return (declare_returned_variables(), tie_return_values(), get_return_value())
def gen_variable_type(out, aten_declarations, differentiability_infos, template_path):
def gen_variable_type(
out: str,
native_yaml_path: str,
differentiability_infos: Sequence[DifferentiabilityInfo],
template_path: str,
operator_selector: SelectiveBuilder,
) -> None:
"""VariableType.h and VariableType.cpp body
@ -354,154 +324,202 @@ def gen_variable_type(out, aten_declarations, differentiability_infos, template_
implementation of each function dispatches to the base tensor type to
compute the output. The grad_fn is attached to differentiable functions.
"""
fns = list(sorted(filter(
operator_selector.is_native_function_selected_for_training,
parse_native_yaml(native_yaml_path)), key=lambda f: cpp.name(f.func)))
fns_with_infos = match_differentiability_info(fns, differentiability_infos)
aten_declarations = list(sorted(aten_declarations, key=lambda decl: decl['name']))
match_declarations_with_differentiability_info(aten_declarations, differentiability_infos)
gen_variable_type_shard(out, aten_declarations, template_path, None, True)
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h', 'VariableType.h')
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
# template regarding sharding of the generated files.
num_shards = 5
shards = [[] for _ in range(num_shards)]
shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)]
# functions are assigned arbitrarily but stably to a file based on hash
for decl in aten_declarations:
x = sum(ord(c) for c in decl['name']) % num_shards
shards[x].append(decl)
for fn in fns_with_infos:
x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
shards[x].append(fn)
for i, shard in enumerate(shards):
gen_variable_type_shard(out, shard, template_path, '_%d' % i, False)
gen_variable_type_shard(out, aten_declarations, template_path, 'Everything', False)
gen_variable_type_shard(fm, shard, 'VariableType.cpp', f'VariableType_{i}.cpp')
gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp', 'VariableTypeEverything.cpp')
def gen_variable_type_shard(out, aten_declarations, template_path, suffix, header):
VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h')
VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp')
@with_native_function
def gen_formals(f: NativeFunction) -> str:
if f.use_c10_dispatcher.dispatcher_uses_new_style():
formals = ', '.join(
f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
for a in f.func.schema_order_arguments()
)
else:
sig_group = CppSignatureGroup.from_native_function(f, method=False)
formals = ', '.join(f'{a.type} {a.name}' for a in sig_group.signature.arguments())
return formals
type_declarations = []
type_definitions = []
wrapper_registrations = []
@with_native_function
def gen_wrapper_registration(f: NativeFunction) -> str:
if f.use_c10_dispatcher.dispatcher_uses_new_style():
return WRAPPER_REGISTRATION.substitute(
unqual_operator_name_with_overload=f.func.name,
type_wrapper_name=type_wrapper_name(f),
class_type='VariableType',
)
else:
return UNBOXEDONLY_WRAPPER_REGISTRATION.substitute(
unqual_operator_name_with_overload=f.func.name,
type_wrapper_name=type_wrapper_name(f),
class_type='VariableType',
)
for declaration in aten_declarations:
if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
formals = declaration['schema_order_formals']
else:
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
formals = declaration['formals']
type_declarations.append(METHOD_DECLARATION.substitute(declaration, formals=formals))
strategy = dispatch_strategy(declaration)
if declaration['name'] not in MANUAL_AUTOGRAD and strategy == 'use_derived':
body = emit_body(declaration)
def gen_variable_type_shard(
fm: FileManager,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
template_name: str,
output_name: str,
) -> None:
type_declarations: List[str] = []
type_definitions: List[str] = []
wrapper_registrations: List[str] = []
for fn in fns_with_infos:
f = fn.func
name = cpp.name(f.func)
formals = gen_formals(f)
type_declarations.append(METHOD_DECLARATION.substitute(
return_type=cpp.returns_type(f.func.returns),
type_wrapper_name=type_wrapper_name(f),
formals=formals,
))
if name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived':
type_definitions.append(METHOD_DEFINITION.substitute(
declaration, type_definition_body=body, formals=formals))
if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
wrapper_registrations.append(WRAPPER_REGISTRATION.substitute(
declaration, class_type='VariableType'))
else:
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
wrapper_registrations.append(UNBOXEDONLY_WRAPPER_REGISTRATION.substitute(
declaration, class_type='VariableType'))
return_type=cpp.returns_type(f.func.returns),
type_wrapper_name=type_wrapper_name(f),
type_definition_body=emit_body(fn),
formals=formals,
))
wrapper_registrations.append(gen_wrapper_registration(f))
# See Note [Manual Backend kernels]
assert (declaration['name'] in MANUAL_BACKEND) == declaration['manual_kernel_registration']
assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
# If you want to register a kernel to Autograd, you must make the op abstract.
# In other words, this op must have dispatch section in native_functions.yaml.
if declaration['name'] in MANUAL_AUTOGRAD_AND_TRACER or declaration['derivative']:
msg = (f'There\'s a formula for {declaration["name"]}(or its functional variant) in derivatives.yaml. '
if name in MANUAL_AUTOGRAD_AND_TRACER or (fn.info and fn.info.has_derivatives):
msg = (f'There\'s a formula for {name}(or its functional variant) in derivatives.yaml. '
f'It\'s required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA '
f'or DefaultBackend in native_functions.yaml. Please see '
f'https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword '
f'for instructions to choose the right dispatch keyword.')
assert declaration['abstract'], msg
assert f.is_abstract, msg
env = {
fm.write_with_template(output_name, template_name, lambda: {
'generated_comment': '@' + f'generated from {fm.template_dir}/{template_name}',
'type_derived_method_declarations': type_declarations,
'type_derived_method_definitions': type_definitions,
'wrapper_registrations': wrapper_registrations,
}
if header:
write(out, 'VariableType.h', VARIABLE_TYPE_H, env)
else:
write(out, 'VariableType%s.cpp' % suffix, VARIABLE_TYPE_CPP, env)
})
def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
assert dispatch_strategy(fn) == 'use_derived'
f = fn.func
info = fn.info
def emit_body(declaration):
assert dispatch_strategy(declaration) == 'use_derived'
arguments = declaration['arguments']
returns = declaration['returns']
func = declaration['derivative']
name = declaration['name']
inplace = declaration['inplace']
is_out_fn = name.endswith('_out')
modifies_arguments = inplace or is_out_fn
returns_void = len(returns) == 0
base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
name = cpp.name(f.func)
inplace = f.func.kind() == SchemaKind.inplace
is_out_fn = f.func.kind() == SchemaKind.out
returns_void = len(f.func.returns) == 0
base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)?
view_info = VIEW_FUNCTIONS.get(base_name, None)
if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
view_info = "self"
def is_differentiable(arg):
if 'TensorOptions' in arg['type']:
return False
if 'Tensor' not in arg['type']:
return False
if arg['name'] in declaration.get('non_differentiable_arg_names', []):
return False
return True
def is_differentiable(name: str, type: Type) -> bool:
return type.is_tensor_like() and (info is None or name not in info.non_differentiable_arg_names)
def find_args_with_derivatives(differentiable_inputs):
def gen_differentiable_input(
arg: Union[Argument, SelfArgument, TensorOptionsArguments]
) -> Optional[DifferentiableInput]:
if isinstance(arg, TensorOptionsArguments):
return None
a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
# TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
# NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
# not handled properly as they are irrelevant for this codegen.
cpp_type = cpp.argument_type(a, binds=a.name).cpp_type()
if not is_differentiable(a.name, a.type):
return None
return DifferentiableInput(
name=a.name,
type=a.type,
cpp_type=cpp_type,
)
@with_native_function
def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
return list(mapMaybe(gen_differentiable_input, f.func.arguments.non_out))
def find_args_with_derivatives(differentiable_inputs: List[DifferentiableInput]) -> List[DifferentiableInput]:
"""Find arguments that have derivative definitions"""
if func is None:
if info is None or not info.has_derivatives:
return differentiable_inputs
names = set(name for d in func.derivatives for name in d.var_names)
differentiable = [arg for arg in differentiable_inputs if arg['name'] in names]
names = set(name for d in info.derivatives for name in d.var_names)
differentiable = [arg for arg in differentiable_inputs if arg.name in names]
if len(differentiable) != len(names):
missing = names - set(arg['name'] for arg in differentiable)
raise RuntimeError(f'Missing arguments for derivatives: {missing} in {func.name}')
missing = names - set(arg.name for arg in differentiable)
raise RuntimeError(f'Missing arguments for derivatives: {missing} in {info.name}')
return differentiable
inputs = [arg for arg in arguments if not arg.get('output', False)]
differentiable_inputs = list(filter(is_differentiable, inputs))
args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
non_differentiable_arg_names = declaration.get('non_differentiable_arg_names', [])
candidate_differentiable_outputs = list(filter(is_differentiable, returns))
def gen_differentiable_outputs(f: NativeFunction) -> List[DifferentiableOutput]:
outputs: List[DifferentiableOutput] = [
DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret))
for name, ret in zip(cpp.return_names(f), f.func.returns)]
if declaration['output_differentiability'] is not None:
differentiable_outputs = []
output_differentiability = declaration['output_differentiability']
if False in output_differentiability and inplace:
raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
for differentiable, output in zip(output_differentiability, returns):
if differentiable:
differentiable_outputs.append(output)
elif uses_single_grad(func):
differentiable_outputs = candidate_differentiable_outputs[:1]
else:
differentiable_outputs = candidate_differentiable_outputs
output_differentiability = info.output_differentiability if info else None
if output_differentiability is not None:
differentiable_outputs: List[DifferentiableOutput] = []
if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
for differentiable, output in zip(output_differentiability, outputs):
if differentiable:
differentiable_outputs.append(output)
return differentiable_outputs
candidate_differentiable_outputs = list(filter(lambda r: is_differentiable(r.name, r.type), outputs))
if uses_single_grad(info):
return candidate_differentiable_outputs[:1]
else:
return candidate_differentiable_outputs
differentiable_inputs = gen_differentiable_inputs(f)
args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
differentiable_outputs = gen_differentiable_outputs(f)
requires_derivative = (
base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE and
len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0)
if func is not None and not requires_derivative:
raise RuntimeError('ERROR: derivative ignored for {} -- specified an autograd function without derivative'
.format(name))
if info is not None and info.has_derivatives and not requires_derivative:
raise RuntimeError(f'ERROR: derivative ignored for {name} -- specified an autograd function without derivative')
def emit_save_inputs():
setup = []
if func is None:
def emit_save_inputs() -> List[str]:
setup: List[str] = []
if info is None or not info.has_derivatives:
return setup
has_tensorlist_arg = \
any(arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in func.args_with_derivatives)
has_tensorlist_arg = any(is_tensor_list_type(arg.type) for arg in args_with_derivatives)
# We don't want to save tensors if we know that they will never be used
# when computing the derivative, so we add guards to those statements
def guard_for(arg: SavedAttribute) -> Optional[str]:
assert info is not None
# It's hard to determine the edge offset if we have TensorLists
if has_tensorlist_arg:
return None
@ -512,12 +530,12 @@ def emit_body(declaration):
# require_grad if the backward function even gets executed. I don't
# have any good ideas for detecting those cases, so I simply disabled the
# checks.
if 'backward' in func.name:
if 'backward' in info.name:
return None
# If there's a single derivative we could compute, we already have
# a requires_grad check that is sufficient
if len(func.args_with_derivatives) <= 1:
if len(args_with_derivatives) <= 1:
return None
# We really only care about trimming down the amount of tensors we save
@ -526,7 +544,7 @@ def emit_body(declaration):
# We want to emit simple guards, so we only allow that if checking one
# input is enough to determine whether we need that value
used_in = [d for d in func.derivatives if arg in d.saved_inputs]
used_in = [d for d in info.derivatives if arg in d.saved_inputs]
assert len(used_in) > 0
if len(used_in) != 1:
return None
@ -536,75 +554,76 @@ def emit_body(declaration):
derivative_var_name = derivative.var_names[0]
# Figure out the offset of the edge that uses this variable
for edge_off, arg in enumerate(func.args_with_derivatives):
if arg.name == derivative_var_name:
for edge_off, a in enumerate(args_with_derivatives):
if a.name == derivative_var_name:
break
else:
raise AssertionError()
return f'grad_fn->should_compute_output({edge_off})'
setup.extend(save_variables(func.all_saved_inputs, False, guard_for))
for arg in func.args_with_derivatives:
if arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
for arg in args_with_derivatives:
if is_tensor_list_type(arg.type):
setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')
return setup
def setup_derivative(differentiable_inputs):
env = {}
env['args_with_derivatives'] = [arg['name'] for arg in args_with_derivatives]
env['op'] = func.op if func is not None else 'NotImplemented'
env['op_ctor'] = '' if func is not None else '"{}"'.format(declaration['api_name'])
def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
body: List[str] = []
if is_out_fn:
# For out functions, ensure that no input or output requires grad
body = []
body.append(DECLARE_GRAD_FN.substitute(op='Node'))
body.append(SETUP_NONE_REQUIRES_GRAD.substitute(
base_name=base_name,
args_to_check=[arg['name'] for arg in differentiable_inputs]))
args_to_check=[arg.name for arg in differentiable_inputs]))
body.append(SETUP_NONE_REQUIRES_GRAD.substitute(
base_name=base_name,
args_to_check=[arg['name'] for arg in differentiable_outputs]))
args_to_check=[arg.name for arg in differentiable_outputs]))
return body
op = info.op if info is not None and info.has_derivatives else 'NotImplemented'
setup = []
setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n'))
setup.extend(ASSIGN_GRAD_FN.substitute(
op=op,
op_ctor='' if info is not None and info.has_derivatives else f'"{cpp.name(f.func)}"',
args_with_derivatives=[arg.name for arg in args_with_derivatives],
).split('\n'))
setup.extend(emit_save_inputs())
body = []
body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
body.append(DECLARE_GRAD_FN.substitute(env))
body.append(DECLARE_GRAD_FN.substitute(op=op))
body.append(SETUP_DERIVATIVE.substitute(setup=setup))
return body
def emit_check_if_in_complex_autograd_allowlist():
body = []
def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
body: List[str] = []
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
return body
for arg in differentiable_outputs:
name = arg['name']
if arg['type'] in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
body.append('throw_error_for_complex_autograd({}, "{}");'.format(name, base_name))
name = arg.name
# TODO: should be `arg.type.is_tensor_like()`?
if arg.cpp_type in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");')
return body
def emit_check_no_requires_grad(tensor_args, args_with_derivatives):
def emit_check_no_requires_grad(
tensor_args: List[DifferentiableInput],
args_with_derivatives: List[DifferentiableInput],
) -> List[str]:
"""Checks that arguments without derivatives don't require grad"""
body = []
body: List[str] = []
for arg in tensor_args:
if arg in args_with_derivatives:
continue
name = arg['name']
if name in non_differentiable_arg_names:
name = arg.name
if info and name in info.non_differentiable_arg_names:
continue
if name == 'output':
# Double-backwards definitions sometimes take in 'input' and
# 'output', but only define the derivative for input.
continue
if arg['dynamic_type'] in {'IndexTensor', 'ByteTensor', 'BoolTensor'}:
continue
body.append('check_no_requires_grad({}, "{}");'.format(name, name))
body.append(f'check_no_requires_grad({name}, "{name}");')
return body
def save_variables(
@ -644,42 +663,40 @@ def emit_body(declaration):
stmts.append('}')
return stmts
def emit_dispatch_call(api_name, input_base, unpacked_args):
def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
""" Dispatch call via function in a namespace or method on Tensor."""
if 'namespace' in declaration['method_of']:
if declaration['use_c10_dispatcher'] in ['hacky_wrapper_for_legacy_signatures', 'full']:
dispatcher_api_name = make_out_api_name_faithful(api_name)
else:
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
dispatcher_api_name = api_name
if Variant.function in f.variants:
call = CALL_DISPATCH_VIA_NAMESPACE.substitute(
api_name=dispatcher_api_name,
api_name=cpp.name(
f.func,
faithful_name_for_out_overloads=f.use_c10_dispatcher.dispatcher_uses_new_style(),
),
unpacked_args=unpacked_args)
else:
call = CALL_DISPATCH_VIA_METHOD.substitute(
api_name=api_name,
api_name=cpp.name(f.func),
var=input_base,
unpacked_method_args=unpacked_args[1:])
return call
def emit_view_lambda():
def emit_view_lambda(unpacked_bindings: List[Binding]) -> str:
""" Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
input_base = 'input_base'
replay_view_func = ''
updated_unpacked_args = []
combined = nested_dict(env, declaration)
known_view_arg_simple_types = ['int64_t', 'int64_t?', 'bool', 'IntArrayRef']
for arg in combined['unpacked_args']:
updated_unpacked_args: List[str] = []
known_view_arg_simple_types: List[str] = ['int64_t', 'c10::optional<int64_t>', 'bool', 'IntArrayRef']
for unpacked_binding in unpacked_bindings:
arg, arg_type = unpacked_binding.name, unpacked_binding.type
if arg == 'self_':
updated_unpacked_args.append(input_base)
continue
arg_type = combined['unpacked_args_simple_type'][arg]
if arg_type not in known_view_arg_simple_types:
raise TypeError('You are adding an {} {} argument to op {} in addition to known types: {}. '
'Please update the list or materialize it so that it can be closed over by value, '
'also add a test in pytorch/xla/test/test_operations.py where this code is exercised.'
.format(arg_type, arg, declaration['name'], ', '.join(known_view_arg_simple_types)))
known_types_str = ', '.join(known_view_arg_simple_types)
raise TypeError(f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: '
f'{known_types_str}. Please update the list or materialize it so that it can be closed '
'over by value, also add a test in pytorch/xla/test/test_operations.py where this code '
'is exercised.')
if arg_type == 'IntArrayRef':
# It's not safe to close over IntArrayRef by value, since this is a
@ -687,7 +704,7 @@ def emit_body(declaration):
arg_vec = arg + '_vec'
replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
updated_unpacked_args.append(arg_vec)
elif arg_type == 'int64_t?':
elif arg_type == 'c10::optional<int64_t>':
# Materialize int64_t? to int64_t
arg_value = arg + '_val'
replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg, val=arg_value, default='0')
@ -695,7 +712,7 @@ def emit_body(declaration):
else:
updated_unpacked_args.append(arg)
replay_view_call = emit_dispatch_call(combined['api_name'], input_base, updated_unpacked_args)
replay_view_call = emit_dispatch_call(f, input_base, updated_unpacked_args)
replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
input_base=input_base,
replay_view_call=replay_view_call)
@ -706,17 +723,17 @@ def emit_body(declaration):
is_view_with_metadata_change=is_view_with_metadata_change,
replay_view_func=replay_view_func)
def wrap_output(return_values, var):
def wrap_output(f: NativeFunction, unpacked_bindings: List[Binding], var: str) -> str:
call = ''
rhs_value = None
if 'Tensor' not in declaration['return_type']:
rhs_value: Optional[str] = None
if not any(r.type.is_tensor_like() for r in f.func.returns):
rhs_value = var
elif view_info is not None:
# See NOTE [ Autograd View Variables ] in variable.h for details.
differentiable_output_vars = {r['name'] for r in differentiable_outputs}
differentiable_output_vars = {r.name for r in differentiable_outputs}
if not isinstance(view_info, str):
raise TypeError("The view info should be a string for {}, but it is: {}".format(base_name, view_info))
raise TypeError(f'The view info should be a string for {base_name}, but it is: {view_info}')
if len(differentiable_output_vars) == 0:
# no output is differentiable (.indices() for SparseTensors for example)
@ -725,54 +742,55 @@ def emit_body(declaration):
# Single differentiable output (Tensor or Tensor[])
return_info = differentiable_outputs[0]
# We only support simple Tensor or a TensorList for functions that return views
if not return_info['dynamic_type'] in ['Tensor', 'TensorList']:
raise RuntimeError("{} that return differentiable views can only return Tensor or Tensor[]".format(base_name))
if not is_tensor_type(return_info.type) and not is_tensor_list_type(return_info.type):
raise RuntimeError(f'{base_name} that return differentiable views can only return Tensor or Tensor[]')
# Only allow rebasing of the history if we return a single Tensor
# If we are in a no grad block, raise a warning
# See NOTE [ View + Inplace detection ] for more details about this logic
if return_info['dynamic_type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
if is_tensor_list_type(return_info.type):
if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS:
creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
creation_meta = 'CreationMeta::MULTI_OUTPUT_SAFE'
else:
creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
call += ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
"/* creation_meta */ {});").format(view_info, var, creation_meta)
rhs_value = 'std::move({})'.format(var)
creation_meta = 'CreationMeta::MULTI_OUTPUT_NODE'
call += (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, '
'/* is_fw_differentiable */ true, '
f'/* creation_meta */ {creation_meta});')
rhs_value = f'std::move({var})'
else:
call += emit_view_lambda()
creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE"
rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
"/* view_func */ func, /* creation_meta */ {})").format(view_info, var, creation_meta)
call += emit_view_lambda(unpacked_bindings)
creation_meta = 'GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE'
rhs_value = (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, '
'/* is_fw_differentiable */ true, '
f'/* view_func */ func, /* creation_meta */ {creation_meta})')
else:
# This could be supported but we don't need it at the moment, so keeping things simple.
raise RuntimeError("Function that return multiple differentiable output "
"when at least one of them is view is not supported.")
raise RuntimeError('Function that return multiple differentiable output '
'when at least one of them is view is not supported.')
else:
rhs_value = 'std::move({})'.format(var)
rhs_value = f'std::move({var})'
assert rhs_value is not None
call += ASSIGN_RETURN_VALUE.substitute(return_values=return_values,
call += ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f),
rhs_value=rhs_value)
return call
def enforce_same_tensorimpl_and_storage(env, call):
save_ptrs_stmts = []
enforce_same_ptrs_stmts = []
if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
for arg in env.get('unpacked_args', []):
simple_type = env['unpacked_args_simple_type'][arg]
if simple_type == 'TensorList':
def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
save_ptrs_stmts: List[str] = []
enforce_same_ptrs_stmts: List[str] = []
if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
for unpacked_binding in unpacked_bindings:
arg = unpacked_binding.name
noref_cpp_type = unpacked_binding.ctype.cpp_type(strip_ref=True)
if noref_cpp_type == 'TensorList':
save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
elif simple_type == 'c10::List<c10::optional<Tensor>>':
elif noref_cpp_type == 'c10::List<c10::optional<Tensor>>':
save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
elif simple_type == 'Tensor':
elif noref_cpp_type == 'Tensor':
save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
@ -784,74 +802,69 @@ def emit_body(declaration):
RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
return call
def emit_call(env, tie_return_values):
combined = nested_dict(env, declaration)
def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
# We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
# the baseType operations still dispatch to non-Variable type, even if the arguments passed
# in are now Variables.
# See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args'])
if not modifies_arguments and not returns_void:
unpacked_args = [b.name for b in unpacked_bindings]
base_type_call = emit_dispatch_call(f, 'self_', unpacked_args)
if not modifies_arguments(f) and not returns_void:
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
base_type_call=base_type_call)
call += wrap_output(tie_return_values, 'tmp')
call += wrap_output(f, unpacked_bindings, 'tmp')
else:
call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
base_type_call=base_type_call)
call = enforce_same_tensorimpl_and_storage(env, call)
call = enforce_same_tensorimpl_and_storage(call, unpacked_bindings)
return call
def emit_history():
fn = 'rebase' if modifies_arguments and view_info is None else 'set'
output_names = [r['name'] for r in differentiable_outputs]
def emit_history() -> str:
fn = 'rebase' if modifies_arguments(f) and view_info is None else 'set'
output_names = [r.name for r in differentiable_outputs]
# TODO: flatten allocates a std::vector, which could be expensive
outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
def emit_save_outputs():
def emit_save_outputs() -> str:
if is_out_fn:
# out functions don't currently support differentiation
return ''
func = declaration['derivative']
if func is not None:
stmts = save_variables(func.all_saved_outputs, True)
if info is not None and info.has_derivatives:
stmts = save_variables(info.all_saved_outputs, True)
if len(stmts) == 0:
return ''
return CONDITIONAL.substitute(cond='grad_fn', statements=stmts)
return ''
def emit_any_requires_grad():
def emit_any_requires_grad() -> List[str]:
return [SETUP_ANY_REQUIRES_GRAD.substitute(
args_with_derivatives=[arg['name'] for arg in args_with_derivatives]), ]
args_with_derivatives=[arg.name for arg in args_with_derivatives]), ]
def emit_check_inplace():
def emit_check_inplace() -> List[str]:
if not inplace:
return []
return ['check_inplace({}, _any_requires_grad);'.format(arg['name']) for arg in differentiable_outputs]
return [f'check_inplace({arg.name}, _any_requires_grad);' for arg in differentiable_outputs]
def emit_increment_version():
if not modifies_arguments:
def emit_increment_version(f: NativeFunction) -> List[str]:
if not modifies_arguments(f):
return []
return ['increment_version({});'.format(arg['name']) for arg in returns]
return [f'increment_version({r});' for r in cpp.return_names(f)]
env = {}
combined = nested_dict(env, declaration)
body: List[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)
body = []
declare_returned_variables, tie_return_values, get_return_value = format_return_variables(declaration)
body.extend(unpack_args(env, declaration))
body.extend(unpack_args_stats)
if requires_derivative:
body.extend(emit_any_requires_grad())
body.extend(emit_check_inplace())
body.extend(setup_derivative(differentiable_inputs))
body.append(declare_returned_variables)
body.append(declare_returned_variables(f))
body.append(emit_call(env, tie_return_values))
body.extend(emit_increment_version())
body.append(emit_call(f, unpacked_bindings))
body.extend(emit_increment_version(f))
if requires_derivative:
# set_flags has to appear after version_counter, because rebase_history
# requires that the counter is incremented before it is called
@ -866,56 +879,54 @@ def emit_body(declaration):
assert inplace
body.append('reset_grad_accumulator(self);')
if not returns_void:
body.append('return {};'.format(get_return_value))
body.append(f'return {get_return_value(f)};')
return body
@with_native_function
def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
body: List[str] = []
unpacked_bindings: List[Binding] = []
def unpack_args(env, declaration):
def requires_unpack(arg):
return 'Tensor' in arg['dynamic_type'] and 'c10::optional' not in arg['type']
body = []
unpacked_args = []
unpacked_args_simple_type = {}
if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
arguments = declaration['schema_order_arguments']
if f.use_c10_dispatcher.dispatcher_uses_new_style():
bindings = [r for a in f.func.schema_order_arguments()
for r in cpp.argument(a,
method=False,
cpp_no_default_args=set(),
faithful=False,
has_tensor_options=False)]
else:
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
arguments = declaration['arguments']
for i, arg in enumerate(arguments):
if not requires_unpack(arg):
unpacked_args.append(arg['name'])
unpacked_args_simple_type[arg['name']] = arg['simple_type']
sig_group = CppSignatureGroup.from_native_function(f, method=False)
bindings = list(sig_group.signature.arguments())
for i, binding in enumerate(bindings):
assert not isinstance(binding.argument, SelfArgument)
if isinstance(binding.argument, TensorOptionsArguments):
raise RuntimeError("VariableKernel shouldn't take TensorOptions")
is_nullable = binding.argument.type.is_nullable()
if not binding.argument.type.is_tensor_like() or is_nullable:
unpacked_bindings.append(binding)
continue
dynamic_type = arg['dynamic_type']
if 'TensorOptions' not in dynamic_type:
is_nullable = arg.get('is_nullable', False)
ref = (not is_nullable) and dynamic_type != 'TensorList'
suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''
body.append(UNPACK_TENSOR.substitute(
arg_name=arg['name'],
arg_pos=i,
suffix=suffix,
ref='&' if ref else '',
))
else:
# Okay, we are abusing the definition of 'unpack' here a bit,
# although it's still getting the non-variable from the variable
# (in this case via TensorOptions rather than Variable/Tensor).
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper', \
"VariableKernel shouldn't take TensorOptions if the op is c10-full"
body.append(LEGACY_WRAP_OPTIONS.substitute(arg_name=arg['name']))
is_tensor_list = is_tensor_list_type(binding.argument.type)
ref = (not is_nullable) and not is_tensor_list
suffix = '_opt' if is_nullable and not is_tensor_list else ''
body.append(UNPACK_TENSOR.substitute(
arg_name=binding.name,
arg_pos=i,
suffix=suffix,
ref='&' if ref else '',
))
unpacked_bindings.append(Binding(
name=binding.name + '_',
ctype=binding.ctype,
argument=binding.argument,
default=binding.default,
))
unpacked_args.append(arg['name'] + '_')
unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']
return body, unpacked_bindings
env['unpacked_args'] = unpacked_args
env['unpacked_args_simple_type'] = unpacked_args_simple_type
return body
def dispatch_strategy(declaration):
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
"""How are we going to call the underlying implementation of a
declaration? There are two strategies:
@ -935,7 +946,7 @@ def dispatch_strategy(declaration):
get dispatched back to VariableType (which will ensure that they
are differentiable.)
"""
if declaration['abstract'] or declaration['derivative'] is not None:
if fn.func.is_abstract or (fn.info is not None and fn.info.has_derivatives):
# If the function is abstract (not implemented on at::Type), we must
# call the implementation on the derived type with unpacked tensors.
@ -959,62 +970,47 @@ def dispatch_strategy(declaration):
# assumption might not hold, but then you'll see gradcheck fail.)
return 'use_type'
def get_decl_signature(declaration: Dict[Any, Any], use_base_variant: bool = False) -> str:
name = declaration['name']
arguments = declaration['arguments']
if use_base_variant:
if declaration['inplace']:
assert name.endswith('_')
name = name[:-1]
elif name.endswith('_out'):
name = name[:-4]
arguments = [arg for arg in arguments if not arg.get('output', False)]
simple_types = ', '.join(arg['simple_type'] for arg in arguments)
return f'{name}({simple_types})'
def is_tensor_type(t: Type) -> bool:
# TODO: Should handle optional here?
return t.is_tensor_like() and t.is_list_like() is None
@with_native_function
def get_func_signature(f: NativeFunction) -> str:
args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
types = ', '.join(python.argument_type_str(a.argument.type, simple_type=True)
if isinstance(a.argument, Argument) else 'TensorOptions'
for a in args)
return f'{cpp.name(f.func)}({types})'
def is_tensor_list_type(t: Type) -> bool:
# TODO: Should handle optional here?
return t.is_tensor_like() and t.is_list_like() is not None
def match_declarations_with_differentiability_info(
declarations: Dict[Any, Any],
def modifies_arguments(f: NativeFunction) -> bool:
return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
def match_differentiability_info(
native_functions: List[NativeFunction],
differentiability_infos: Sequence[DifferentiabilityInfo],
) -> None:
) -> List[NativeFunctionWithDifferentiabilityInfo]:
"""Sets the "derivative" key on declarations to matching autograd function
In-place functions will use the out-of-place derivative definition if there
is no in-place specific derivative.
"""
info_by_signature = {get_func_signature(info.func): info for info in differentiability_infos}
info_by_schema = {info.func.func: info for info in differentiability_infos}
functional_info_by_signature = {
info.func.func.signature(strip_default=True): info
for info in differentiability_infos
if info.func.func.kind() == SchemaKind.functional}
def find_info(declaration: Dict[Any, Any]) -> Optional[DifferentiabilityInfo]:
signature = get_decl_signature(declaration)
if signature in info_by_signature:
return info_by_signature[signature]
def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]:
if f.func in info_by_schema:
return info_by_schema[f.func], True
# if there is no exact match look for the out-of-place signature.
# i.e mul() for mul_() or mul_out()
signature = get_decl_signature(declaration, use_base_variant=True)
return info_by_signature.get(signature)
return functional_info_by_signature.get(f.func.signature(strip_default=True)), False
for declaration in declarations:
info = find_info(declaration)
declaration['derivative'] = info if info and info.args_with_derivatives else None
result: List[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:
info, is_exact_match = find_info(f)
result.append(NativeFunctionWithDifferentiabilityInfo(
func=f,
info=info,
))
# Currently, the '.strides()' to 'strides_or_error' replacement does not support
# 'self' derivatives of an inplace function, so we must check for this case.
if declaration['inplace'] and (info is not None):
for derivative in info.derivatives:
if 'self' in derivative.var_names:
for saved_input in derivative.saved_inputs:
assert 'strides_or_error' not in saved_input.expr, (
"Calling '.strides()' in the 'self' derivative formula of an "
f"in-place function is not supported: {declaration['name']}")
declaration['non_differentiable_arg_names'] = info.non_differentiable_arg_names if info else []
declaration['output_differentiability'] = info.output_differentiability if info else None
return result
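
Putting the pieces together, a hedged sketch of driving the rewritten generator directly; the paths are illustrative and `SelectiveBuilder.get_nop_selector()` is assumed as the keep-everything selector.
```
# Sketch: driving the new gen_variable_type (paths illustrative).
from tools.autograd.load_derivatives import load_derivatives
from tools.autograd.gen_variable_type import gen_variable_type
from tools.codegen.selective_build.selector import SelectiveBuilder

infos = load_derivatives('tools/autograd/derivatives.yaml',
                         'aten/src/ATen/native/native_functions.yaml')
gen_variable_type(
    'torch/csrc/autograd/generated',               # out
    'aten/src/ATen/native/native_functions.yaml',  # replaces Declarations.yaml
    infos,
    'tools/autograd/templates',
    SelectiveBuilder.get_nop_selector(),           # assumed: keep every operator
)
```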

View File

@ -87,3 +87,36 @@ class DifferentiabilityInfo:
# Raw data read from derivatives.yaml.
output_differentiability: Optional[List[bool]]
@property
def has_derivatives(self) -> bool:
return len(self.args_with_derivatives) > 0
# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
# context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
name: str
type: Type
# TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
cpp_type: str
# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
# `cpp.return_names()` method.
# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
# `output_differentiability` field defined in derivatives.yaml (if specified),
# and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
name: str
type: Type
# TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
cpp_type: str
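
A small sketch of how these dataclasses get populated; the module path and values are illustrative, and in the diff above `gen_differentiable_input` / `gen_differentiable_outputs` build them from a parsed `NativeFunction`.
```
# Sketch: constructing the new autograd-codegen dataclasses (values illustrative).
from tools.codegen.api.autograd import DifferentiableInput, DifferentiableOutput
from tools.codegen.model import BaseType, BaseTy

self_input = DifferentiableInput(
    name='self',
    type=BaseType(BaseTy.Tensor),
    cpp_type='const Tensor &',  # kept only for byte-for-byte parity with the old codegen
)
out = DifferentiableOutput(
    name='result',
    type=BaseType(BaseTy.Tensor),
    cpp_type='Tensor',
)
```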

View File

@ -106,7 +106,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
return BaseCType("DimnameList", binds)
elif str(t.elem) == 'Tensor?':
if local.use_c10_dispatcher().dispatcher_uses_new_style():
return BaseCType("const c10::List<c10::optional<Tensor>> &", binds)
return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
else:
return BaseCType("TensorList", binds)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)

View File

@ -31,14 +31,16 @@ class BaseCType:
type: str
name: ArgName
def cpp_type(self) -> str:
def cpp_type(self, *, strip_ref: bool = False) -> str:
return self.type
@dataclass(frozen=True)
class ConstRefCType:
elem: 'CType'
def cpp_type(self) -> str:
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f'const {self.elem.cpp_type()} &'
@property
@ -49,7 +51,9 @@ class ConstRefCType:
class MutRefCType:
elem: 'CType'
def cpp_type(self) -> str:
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f'{self.elem.cpp_type()} &'
@property
@ -60,7 +64,8 @@ class MutRefCType:
class OptionalCType:
elem: 'CType'
def cpp_type(self) -> str:
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f'c10::optional<{self.elem.cpp_type()}>'
@property
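
A minimal usage sketch of the new `strip_ref` flag (class names as shown above; the argument name is illustrative): `enforce_same_tensorimpl_and_storage` uses it to compare the underlying C++ type of an unpacked binding without its reference qualifier.
```
# Sketch: strip_ref flattens reference wrappers (values illustrative).
from tools.codegen.api.types import BaseCType, ConstRefCType

ctype = ConstRefCType(BaseCType('Tensor', 'self'))
assert ctype.cpp_type() == 'const Tensor &'
assert ctype.cpp_type(strip_ref=True) == 'Tensor'
```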

View File

@ -203,8 +203,7 @@ class RegisterSchema:
@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
op_name = f"aten::{f.func.name}"
if not self.selector.is_operator_selected(op_name):
if not self.selector.is_native_function_selected(f):
return None
return f'm.def({cpp_string(str(f.func))});\n'
@ -399,8 +398,7 @@ struct {class_name} final : public {parent_class} {{
e.expr for e in translate(functional_sig.arguments(), dispatcher.arguments(functional_func), method=False)
)
op_name = f"aten::{f.func.name}"
if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
return None
k = f.func.kind()
@ -480,8 +478,7 @@ struct {class_name} final : public {parent_class} {{
if f.manual_kernel_registration:
return None
op_name = f"aten::{f.func.name}"
if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
return None
name = native.name(f.func)

View File

@ -567,7 +567,7 @@ class FunctionSchema:
else:
return SchemaKind.functional
def signature(self) -> 'FunctionSchema':
def signature(self, *, strip_default: bool = False) -> 'FunctionSchema':
"""
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
@ -582,11 +582,13 @@ class FunctionSchema:
- Out arguments are stripped
- Mutability annotations are stripped (this is sound
because you cannot overload on mutability annotation)
- Return names are stripped since they are not overloadable and
some variants have return names but some not
"""
def strip_ret_annotation(r: Return) -> Return:
return Return(
name=r.name,
name=None,
type=r.type,
annotation=None,
)
@ -600,7 +602,7 @@ class FunctionSchema:
),
overload_name="", # stripped
),
arguments=self.arguments.signature(),
arguments=self.arguments.signature(strip_default=strip_default),
returns=tuple(map(strip_ret_annotation, self.returns)),
)
@ -983,14 +985,14 @@ class Arguments:
ret.extend(self.post_tensor_options_kwarg_only)
return ret
def signature(self) -> 'Arguments':
def signature(self, *, strip_default: bool = False) -> 'Arguments':
# dataclasses.replace could be used here, but it is less
# type safe so for now I've opted to type everything out
def strip_arg_annotation(a: Argument) -> Argument:
return Argument(
name=a.name,
type=a.type,
default=a.default, # hmmm
default=a.default if not strip_default else None,
annotation=None,
)
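
A hedged sketch of why `strip_default` matters for `match_differentiability_info` above: in-place and out variants in native_functions.yaml may omit defaults that the functional variant declares, so defaults are stripped before the signature is used as a lookup key. The schema strings below are illustrative.
```
# Sketch: signature(strip_default=True) as a lookup key (schemas illustrative).
from tools.codegen.model import FunctionSchema

functional = FunctionSchema.parse('mul.Tensor(Tensor self, Tensor other) -> Tensor')
inplace = FunctionSchema.parse('mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)')

# Both collapse to the same defaults-free, annotation-free signature, so an
# in-place op can find the derivative info registered for its functional variant.
assert inplace.signature(strip_default=True) == functional.signature(strip_default=True)
```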

View File

@ -3,6 +3,7 @@ import yaml
from dataclasses import dataclass
from tools.codegen.model import NativeFunction
from tools.codegen.selective_build.operator import *
# A SelectiveBuilder holds information extracted from the selective build
@ -96,6 +97,10 @@ class SelectiveBuilder:
name = strip_operator_overload_name(name)
return name in self.operators and self.operators[name].include_all_overloads
def is_native_function_selected(self, func: NativeFunction) -> bool:
op_name = op_name_from_native_function(func)
return self.is_operator_selected(op_name)
def is_operator_selected_for_training(self, name: str) -> bool:
if not self.is_operator_selected(name):
return False
@ -123,6 +128,10 @@ class SelectiveBuilder:
(base_op.include_all_overloads and base_op.is_used_for_training)
)
def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
op_name = op_name_from_native_function(func)
return self.is_operator_selected_for_training(op_name)
def is_root_operator(self, name: str) -> bool:
if not self.is_operator_selected(name):
return False
@ -158,3 +167,9 @@ def combine_selective_builders(lhs: SelectiveBuilder, rhs: SelectiveBuilder) ->
debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
operators = merge_operator_dicts(lhs.operators, rhs.operators)
return SelectiveBuilder(include_all_operators, debug_info, operators)
def op_name_from_native_function(f: NativeFunction) -> str:
# This was originally read from the 'operator_name_with_overload' field in the
# declaration dict, which was the part before the first '(' in 'schema_string'.
return f'aten::{f.func.name}'

View File

@ -22,9 +22,10 @@ import argparse
import re
from itertools import groupby
from functools import reduce
from ..autograd.gen_autograd import load_aten_declarations
import yaml
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
from ..autograd.utils import CodeTemplate, write, is_out_variant, op_name_with_overload
from ..autograd.utils import CodeTemplate, YamlLoader, write, is_out_variant, op_name_with_overload
from tools.codegen.selective_build.selector import SelectiveBuilder
# JIT has a type system of
@ -279,6 +280,66 @@ def argument_order(decl):
return decl.get('jit_argument_order') or list(range(len(decl['arguments'])))
def format_return_type(returns):
if len(returns) == 0:
return 'void'
elif len(returns) == 1:
return returns[0]['type']
else:
return_types = [r['type'] for r in returns]
return 'std::tuple<{}>'.format(','.join(return_types))
def get_simple_type(arg):
simple_type = arg['type']
simple_type = simple_type.replace(' &', '').replace('const ', '')
simple_type = simple_type.replace('Generator *', 'Generator')
opt_match = re.match(r'c10::optional<(.+)>', simple_type)
if opt_match:
simple_type = '{}?'.format(opt_match.group(1))
return simple_type
def load_aten_declarations(path):
with open(path, 'r') as f:
declarations = yaml.load(f, Loader=YamlLoader)
# enrich declarations with additional information
selected_declarations = []
for declaration in declarations:
if declaration.get('deprecated'):
continue
for arg in declaration['arguments']:
arg['simple_type'] = get_simple_type(arg)
for arg in declaration['schema_order_arguments']:
arg['simple_type'] = get_simple_type(arg)
for ret in declaration['returns']:
ret['simple_type'] = get_simple_type(ret)
declaration['formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['arguments']]
declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['schema_order_arguments']]
declaration['args'] = [arg['name'] for arg in declaration['arguments']]
declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']]
declaration['api_name'] = declaration['name']
if declaration.get('overload_name'):
declaration['type_wrapper_name'] = "{}_{}".format(
declaration['name'], declaration['overload_name'])
else:
declaration['type_wrapper_name'] = declaration['name']
declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
declaration['return_type'] = format_return_type(declaration['returns'])
declaration['base_name'] = declaration['name']
selected_declarations.append(declaration)
return selected_declarations
def gen_unboxing_wrappers(
declarations,
out,