Codegen: ADInplaceOrViewType only include operators registered (#68692)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68692

ADInplaceOrViewType is a sharded file, so by only including specific operator headers, we ensure that changing one (non-method) operator only needs one shard to be re-compiled. This also ports the generated code over to the `at::_ops` interface, and the code generator itself to using `write_sharded` instead of re-implementing its own version of sharding.

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D33217916

Pulled By: albanD

fbshipit-source-id: 90f1868f72644f1b5aa023cefd6a102bbbec95af
committed by Facebook GitHub Bot
parent cc55da8a9b
commit ad803936d1
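To make the `at::_ops` port described in the summary concrete, here is a minimal, self-contained sketch of how the old and new redispatch templates expand. It uses Python's string.Template as a stand-in for the code generator's CodeTemplate, and the operator name and unpacked arguments are purely illustrative.

from string import Template

# Old style: dispatch through a namespace-level redispatch function named after the C++ API.
OLD_INPLACE_REDISPATCH = Template("at::redispatch::${api_name}(${unpacked_args});")

# New style (this PR): dispatch through the at::_ops interface keyed on the schema's
# unambiguous name, so the generated shard only needs that operator's header.
NEW_INPLACE_REDISPATCH = Template("at::_ops::${unambiguous_name}::redispatch(${unpacked_args});")

args = "dispatchKeySet, self"  # hypothetical unpacked arguments
print(OLD_INPLACE_REDISPATCH.substitute(api_name="abs_", unpacked_args=args))
# at::redispatch::abs_(dispatchKeySet, self);
print(NEW_INPLACE_REDISPATCH.substitute(unambiguous_name="abs_", unpacked_args=args))
# at::_ops::abs_::redispatch(dispatchKeySet, self);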
@@ -9,16 +9,16 @@ from tools.codegen.api.autograd import (
     NativeFunctionWithDifferentiabilityInfo, gen_differentiable_outputs,
     dispatch_strategy,
 )
-from tools.codegen.api.types import (Binding, DispatcherSignature, CppSignatureGroup, CType,
-                                     BaseCType, OptionalCType, longT, boolT, intArrayRefT)
+from tools.codegen.api.types import (Binding, DispatcherSignature, CType, BaseCType,
+                                     OptionalCType, longT, boolT, intArrayRefT)
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.context import with_native_function
 from tools.codegen.model import (
-    Type, NativeFunction, SelfArgument, TensorOptionsArguments, Variant,
-    SchemaKind, is_foreach_op,
+    Type, NativeFunction, SelfArgument, TensorOptionsArguments, SchemaKind,
+    is_foreach_op,
 )
-from typing import List, Optional, Sequence, Tuple
-from tools.codegen.utils import mapMaybe, FileManager
+from typing import List, Optional, Sequence, Tuple, Dict
+from tools.codegen.utils import FileManager
 from .context import with_native_function_with_differentiability_info
 from .gen_trace_type import (
     MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value
@@ -104,11 +104,8 @@ OPTIONAL_TO_VAL = CodeTemplate("""\
 auto ${val} = ${arg}.value_or(${default});
 """)

-CALL_DISPATCH_VIA_NAMESPACE = CodeTemplate("""\
-at::${api_name}(${unpacked_args})""")
-
-CALL_DISPATCH_VIA_METHOD = CodeTemplate("""\
-${var}.${api_name}(${unpacked_method_args})""")
+CALL_DISPATCH = CodeTemplate("""\
+at::_ops::${unambiguous_name}::call(${unpacked_args})""")

 SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate("""\
 std::function<at::Tensor(const at::Tensor&)> func=nullptr;
@@ -142,7 +139,7 @@ m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImpl
 INPLACE_REDISPATCH = CodeTemplate("""\
 {
   at::AutoDispatchBelowADInplaceOrView guard;
-  at::redispatch::${api_name}(${unpacked_args});
+  at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
 }
 """)

@@ -153,7 +150,7 @@ ${return_values} = ${rhs_value};
 VIEW_REDISPATCH = CodeTemplate("""\
 ${assign_return_values} ([&]() {
   at::AutoDispatchBelowADInplaceOrView guard;
-  return at::redispatch::${api_name}(${unpacked_args});
+  return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
 })();
 """)

@@ -231,19 +228,9 @@ def get_view_info(f: NativeFunction) -> Optional[str]:
 # - The view replay call also is not part of the hot path.
 def emit_view_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
     # View replay functions use the standard Dispatcher::call API.
-    if Variant.function in f.variants:
-        call = CALL_DISPATCH_VIA_NAMESPACE.substitute(
-            api_name=cpp.name(
-                f.func,
-                faithful_name_for_out_overloads=True,
-            ),
-            unpacked_args=unpacked_args)
-    else:
-        call = CALL_DISPATCH_VIA_METHOD.substitute(
-            api_name=cpp.name(f.func),
-            var=input_base,
-            unpacked_method_args=unpacked_args[1:])
-    return call
+    return CALL_DISPATCH.substitute(
+        unambiguous_name=f.func.name.unambiguous_name(),
+        unpacked_args=unpacked_args)

 def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
     """ Generate an additional lambda function to recover views in backward when as_strided is not supported.
@@ -359,14 +346,9 @@ def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> Li

     # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
     # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
-    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)
-    if sig_group.faithful_signature is not None:
-        api_name = sig_group.faithful_signature.name()
-    else:
-        api_name = sig_group.signature.name()
     if modifies_arguments(f):  # inplace op
         inplace_view_body.append(INPLACE_REDISPATCH.substitute(
-            api_name=api_name,
+            unambiguous_name=f.func.name.unambiguous_name(),
             unpacked_args=redispatch_args,
         ))
         for r in cpp.return_names(f):
@@ -375,7 +357,7 @@ def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> Li
         assert(get_view_info(f) is not None)
         inplace_view_body.append(VIEW_REDISPATCH.substitute(
             assign_return_values='auto ' + TMP_VAR + ' = ',
-            api_name=api_name,
+            unambiguous_name=f.func.name.unambiguous_name(),
             unpacked_args=redispatch_args,
         ))
         call, rhs_value = emit_view_body(fn, TMP_VAR)
@@ -425,17 +407,16 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
     name = cpp.name(f.func)
     return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived'

-def gen_inplace_or_view_type_shard(
-    fm: FileManager, fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], suffix: str
-) -> None:
-
-    filtered_fns_with_infos = list(filter(use_derived, fns_with_infos))
-
-    fm.write_with_template('ADInplaceOrViewType%s.cpp' % suffix, 'ADInplaceOrViewType.cpp', lambda: {
-        'generated_comment': f'@generated from {fm.template_dir}/ADInplaceOrViewType.cpp',
-        'inplace_or_view_method_definitions': list(mapMaybe(inplace_or_view_method_definition, filtered_fns_with_infos)),
-        'inplace_or_view_wrapper_registrations': list(mapMaybe(inplace_or_view_method_registration, filtered_fns_with_infos)),
-    })
+def gen_inplace_or_view_type_env(fn: NativeFunctionWithDifferentiabilityInfo) -> Dict[str, List[str]]:
+    definition = inplace_or_view_method_definition(fn)
+    registration = inplace_or_view_method_registration(fn)
+
+    return {
+        'ops_headers': ([f'#include <ATen/ops/{fn.func.root_name}_ops.h>']
+                        if definition is not None else []),
+        'inplace_or_view_method_definitions': [definition] if definition is not None else [],
+        'inplace_or_view_wrapper_registrations': [registration] if registration is not None else [],
+    }

 def gen_inplace_or_view_type(
     out: str,
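For intuition about the per-operator environment introduced above, the following simplified, self-contained sketch mimics the shape of gen_inplace_or_view_type_env's return value; the helper and the example operators are hypothetical, whereas the real generator derives the definition and registration from the parsed native function.

from typing import Dict, List, Optional

def sketch_env(root_name: str, definition: Optional[str], registration: Optional[str]) -> Dict[str, List[str]]:
    # One ops header per operator that actually gets a generated kernel; operators
    # without a definition contribute nothing, so unrelated edits leave the shard alone.
    return {
        'ops_headers': [f'#include <ATen/ops/{root_name}_ops.h>'] if definition is not None else [],
        'inplace_or_view_method_definitions': [definition] if definition is not None else [],
        'inplace_or_view_wrapper_registrations': [registration] if registration is not None else [],
    }

print(sketch_env('abs', definition='/* generated kernel */', registration='/* m.impl(...) */')['ops_headers'])
# ['#include <ATen/ops/abs_ops.h>']
print(sketch_env('angle', definition=None, registration=None)['ops_headers'])
# []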
@@ -446,14 +427,18 @@ def gen_inplace_or_view_type(
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
-    num_shards = 2
-    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)]
-
-    # functions are assigned arbitrarily but stably to a file based on hash
-    for fn in fns_with_infos:
-        x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
-        shards[x].append(fn)
-
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    for i, shard in enumerate(shards):
-        gen_inplace_or_view_type_shard(fm, shard, f'_{i}')
-    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
+    fm.write_sharded(
+        'ADInplaceOrViewType.cpp',
+        [fn for fn in fns_with_infos if use_derived(fn)],
+        key_fn=lambda fn: fn.func.root_name,
+        base_env={
+            'generated_comment':
+            f'@generated from {template_path}/ADInplaceOrViewType.cpp',
+        },
+        env_callable=gen_inplace_or_view_type_env,
+        num_shards=2,
+        sharded_keys={'ops_headers', 'inplace_or_view_method_definitions',
+                      'inplace_or_view_wrapper_registrations'}
+    )
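The deleted hash-and-bucket loop above is what write_sharded now handles generically. Below is a minimal sketch of stable, key-based sharding under that assumption; it is not the actual FileManager.write_sharded implementation, which additionally merges the sharded_keys environments and writes the template once per shard.

from collections import defaultdict
from typing import Callable, Dict, List, TypeVar

T = TypeVar('T')

def shard_by_key(items: List[T], key_fn: Callable[[T], str], num_shards: int) -> Dict[int, List[T]]:
    # Each item is assigned purely from its key, so an operator always lands in the
    # same shard and touching it only forces that one shard to be regenerated and rebuilt.
    shards: Dict[int, List[T]] = defaultdict(list)
    for item in items:
        shards[sum(ord(c) for c in key_fn(item)) % num_shards].append(item)
    return dict(shards)

# Hypothetical operator root names, as key_fn=lambda fn: fn.func.root_name would produce:
print(shard_by_key(['abs', 'add', 'view'], key_fn=lambda name: name, num_shards=2))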
@@ -1,13 +1,15 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include "torch/csrc/autograd/VariableTypeUtils.h"

 #include <torch/library.h>

-#include <ATen/Functions.h>
-#include <ATen/RedispatchFunctions.h>
-
 // ${generated_comment}

+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Operators.h>
+#else
+$ops_headers
+#endif
+
 using namespace at;
 using torch::autograd::CreationMeta;