Codegen: ADInplaceOrViewType only include operators registered (#68692)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68692

ADInplaceOrViewType is a sharded file, so by including only the headers for
the specific operators each shard needs, we ensure that changing one
(non-method) operator requires recompiling only a single shard.
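
For illustration, the include block of a generated shard ends up looking
roughly like this (a sketch based on the template change below; `add` and
`mul` are placeholder operator names, not taken from this PR):

    // Hypothetical include block of one generated shard.
    #ifndef AT_PER_OPERATOR_HEADERS
    // Monolithic build: a single header declaring every operator.
    #include <ATen/Operators.h>
    #else
    // Per-operator build: only the operators generated into this shard,
    // so editing a different operator's header leaves this shard's
    // dependencies untouched.
    #include <ATen/ops/add_ops.h>
    #include <ATen/ops/mul_ops.h>
    #endif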

This also ports the generated code over to the `at::_ops` interface, and
moves the code generator itself to `write_sharded` instead of re-implementing
its own version of sharding.
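
For a concrete sense of the `at::_ops` interface, here is a sketch of the two
entry points the generated code now uses, with `add.Tensor` as an assumed
example operator (the wrapper function names are made up for illustration):

    #include <ATen/Operators.h>

    // Full dispatch, as used by the CALL_DISPATCH view-replay template.
    // Previously: at::add(a, b) or a.add(b), depending on the op's variants.
    at::Tensor call_sketch(const at::Tensor& a, const at::Tensor& b) {
      return at::_ops::add_Tensor::call(a, b, /*alpha=*/1);
    }

    // Keyset-plumbed dispatch, as used by the *_REDISPATCH templates.
    // Previously: at::redispatch::add(ks, a, b, ...).
    at::Tensor redispatch_sketch(c10::DispatchKeySet ks,
                                 const at::Tensor& a, const at::Tensor& b) {
      return at::_ops::add_Tensor::redispatch(ks, a, b, /*alpha=*/1);
    }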

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D33217916

Pulled By: albanD

fbshipit-source-id: 90f1868f72644f1b5aa023cefd6a102bbbec95af
Author: Peter Bell
Date: 2022-01-12 15:33:23 -08:00
Committed by: Facebook GitHub Bot
Parent: cc55da8a9b
Commit: ad803936d1
2 changed files with 43 additions and 56 deletions

tools/autograd/gen_inplace_or_view_type.py

@@ -9,16 +9,16 @@ from tools.codegen.api.autograd import (
     NativeFunctionWithDifferentiabilityInfo, gen_differentiable_outputs,
     dispatch_strategy,
 )
-from tools.codegen.api.types import (Binding, DispatcherSignature, CppSignatureGroup, CType,
-                                     BaseCType, OptionalCType, longT, boolT, intArrayRefT)
+from tools.codegen.api.types import (Binding, DispatcherSignature, CType, BaseCType,
+                                     OptionalCType, longT, boolT, intArrayRefT)
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.context import with_native_function
 from tools.codegen.model import (
-    Type, NativeFunction, SelfArgument, TensorOptionsArguments, Variant,
-    SchemaKind, is_foreach_op,
+    Type, NativeFunction, SelfArgument, TensorOptionsArguments, SchemaKind,
+    is_foreach_op,
 )
-from typing import List, Optional, Sequence, Tuple
-from tools.codegen.utils import mapMaybe, FileManager
+from typing import List, Optional, Sequence, Tuple, Dict
+from tools.codegen.utils import FileManager
 from .context import with_native_function_with_differentiability_info
 from .gen_trace_type import (
     MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value
@@ -104,11 +104,8 @@ OPTIONAL_TO_VAL = CodeTemplate("""\
 auto ${val} = ${arg}.value_or(${default});
 """)
 
-CALL_DISPATCH_VIA_NAMESPACE = CodeTemplate("""\
-at::${api_name}(${unpacked_args})""")
-
-CALL_DISPATCH_VIA_METHOD = CodeTemplate("""\
-${var}.${api_name}(${unpacked_method_args})""")
+CALL_DISPATCH = CodeTemplate("""\
+at::_ops::${unambiguous_name}::call(${unpacked_args})""")
 
 SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate("""\
 std::function<at::Tensor(const at::Tensor&)> func=nullptr;
@@ -142,7 +139,7 @@ m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImpl
 INPLACE_REDISPATCH = CodeTemplate("""\
 {
   at::AutoDispatchBelowADInplaceOrView guard;
-  at::redispatch::${api_name}(${unpacked_args});
+  at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
 }
 """)
@@ -153,7 +150,7 @@ ${return_values} = ${rhs_value};
 VIEW_REDISPATCH = CodeTemplate("""\
 ${assign_return_values} ([&]() {
   at::AutoDispatchBelowADInplaceOrView guard;
-  return at::redispatch::${api_name}(${unpacked_args});
+  return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
 })();
 """)
@@ -231,19 +228,9 @@ def get_view_info(f: NativeFunction) -> Optional[str]:
 # - The view replay call also is not part of the hot path.
 def emit_view_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
     # View replay functions use the standard Dispatcher::call API.
-    if Variant.function in f.variants:
-        call = CALL_DISPATCH_VIA_NAMESPACE.substitute(
-            api_name=cpp.name(
-                f.func,
-                faithful_name_for_out_overloads=True,
-            ),
-            unpacked_args=unpacked_args)
-    else:
-        call = CALL_DISPATCH_VIA_METHOD.substitute(
-            api_name=cpp.name(f.func),
-            var=input_base,
-            unpacked_method_args=unpacked_args[1:])
-    return call
+    return CALL_DISPATCH.substitute(
+        unambiguous_name=f.func.name.unambiguous_name(),
+        unpacked_args=unpacked_args)
 
 def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
     """ Generate an additional lambda function to recover views in backward when as_strided is not supported.
@@ -359,14 +346,9 @@ def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> Li
     # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
     # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
-    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)
-    if sig_group.faithful_signature is not None:
-        api_name = sig_group.faithful_signature.name()
-    else:
-        api_name = sig_group.signature.name()
     if modifies_arguments(f):  # inplace op
         inplace_view_body.append(INPLACE_REDISPATCH.substitute(
-            api_name=api_name,
+            unambiguous_name=f.func.name.unambiguous_name(),
             unpacked_args=redispatch_args,
         ))
         for r in cpp.return_names(f):
@@ -375,7 +357,7 @@ def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> Li
         assert(get_view_info(f) is not None)
         inplace_view_body.append(VIEW_REDISPATCH.substitute(
             assign_return_values='auto ' + TMP_VAR + ' = ',
-            api_name=api_name,
+            unambiguous_name=f.func.name.unambiguous_name(),
             unpacked_args=redispatch_args,
         ))
         call, rhs_value = emit_view_body(fn, TMP_VAR)
@@ -425,17 +407,16 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
     name = cpp.name(f.func)
     return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived'
 
-def gen_inplace_or_view_type_shard(
-    fm: FileManager, fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], suffix: str
-) -> None:
+def gen_inplace_or_view_type_env(fn: NativeFunctionWithDifferentiabilityInfo) -> Dict[str, List[str]]:
+    definition = inplace_or_view_method_definition(fn)
+    registration = inplace_or_view_method_registration(fn)
 
-    filtered_fns_with_infos = list(filter(use_derived, fns_with_infos))
-    fm.write_with_template('ADInplaceOrViewType%s.cpp' % suffix, 'ADInplaceOrViewType.cpp', lambda: {
-        'generated_comment': f'@generated from {fm.template_dir}/ADInplaceOrViewType.cpp',
-        'inplace_or_view_method_definitions': list(mapMaybe(inplace_or_view_method_definition, filtered_fns_with_infos)),
-        'inplace_or_view_wrapper_registrations': list(mapMaybe(inplace_or_view_method_registration, filtered_fns_with_infos)),
-    })
+    return {
+        'ops_headers': ([f'#include <ATen/ops/{fn.func.root_name}_ops.h>']
+                        if definition is not None else []),
+        'inplace_or_view_method_definitions': [definition] if definition is not None else [],
+        'inplace_or_view_wrapper_registrations': [registration] if registration is not None else [],
+    }
 
 def gen_inplace_or_view_type(
     out: str,
@@ -446,14 +427,18 @@ def gen_inplace_or_view_type(
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
-    num_shards = 2
-    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)]
-    # functions are assigned arbitrarily but stably to a file based on hash
-    for fn in fns_with_infos:
-        x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
-        shards[x].append(fn)
 
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    for i, shard in enumerate(shards):
-        gen_inplace_or_view_type_shard(fm, shard, f'_{i}')
-    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
+    fm.write_sharded(
+        'ADInplaceOrViewType.cpp',
+        [fn for fn in fns_with_infos if use_derived(fn)],
+        key_fn=lambda fn: fn.func.root_name,
+        base_env={
+            'generated_comment':
+                f'@generated from {template_path}/ADInplaceOrViewType.cpp',
+        },
+        env_callable=gen_inplace_or_view_type_env,
+        num_shards=2,
+        sharded_keys={'ops_headers', 'inplace_or_view_method_definitions',
+                      'inplace_or_view_wrapper_registrations'}
+    )
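
Note that `key_fn=lambda fn: fn.func.root_name` preserves the property of the
old hand-rolled sharding: a function's shard depends only on its name, so the
assignment is arbitrary but stable. Combined with the template change below, a
generated shard might look roughly like this (a sketch, assuming `add_.Tensor`
lands in shard 0; the real generated wrapper also handles views and other
bookkeeping this sketch omits):

    // ADInplaceOrViewType_0.cpp (hypothetical contents)
    #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
    #include "torch/csrc/autograd/VariableTypeUtils.h"
    #include <torch/library.h>
    #include <ATen/ops/add_ops.h>  // this shard's only per-operator header

    namespace torch { namespace autograd {
    namespace {

    at::Tensor& add__Tensor(c10::DispatchKeySet ks, at::Tensor& self,
                            const at::Tensor& other, const at::Scalar& alpha) {
      {
        at::AutoDispatchBelowADInplaceOrView guard;
        at::_ops::add__Tensor::redispatch(
            ks & c10::after_ADInplaceOrView_keyset, self, other, alpha);
      }
      increment_version(self);  // inplace ops bump the version counter
      return self;
    }

    }  // anonymous namespace

    TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
      m.impl("add_.Tensor", TORCH_FN(add__Tensor));
    }

    }}  // namespace torch::autograd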

tools/autograd/templates/ADInplaceOrViewType.cpp

@@ -1,13 +1,15 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include "torch/csrc/autograd/VariableTypeUtils.h"
 
 #include <torch/library.h>
 
-#include <ATen/Functions.h>
-#include <ATen/RedispatchFunctions.h>
-
 // ${generated_comment}
 
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Operators.h>
+#else
+$ops_headers
+#endif
+
 using namespace at;
 using torch::autograd::CreationMeta;
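
For reference, each `ATen/ops/<op>_ops.h` header declares only that operator's
`at::_ops` struct. Very roughly (a from-memory sketch; details such as the
string-constant macros differ in the real generated header):

    namespace at { namespace _ops {

    struct TORCH_API add_Tensor {
      using schema = at::Tensor (const at::Tensor&, const at::Tensor&,
                                 const at::Scalar&);
      static constexpr const char* name = "aten::add";
      static constexpr const char* overload_name = "Tensor";
      // Dispatcher::call entry point (what CALL_DISPATCH expands to).
      static at::Tensor call(const at::Tensor& self, const at::Tensor& other,
                             const at::Scalar& alpha);
      // Keyset-plumbing entry point (what the *_REDISPATCH templates use).
      static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet,
                                   const at::Tensor& self,
                                   const at::Tensor& other,
                                   const at::Scalar& alpha);
    };

    }}  // namespace at::_ops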