Use presence of _symint in kernel name to generate symint sig or not (#84579)

Something people found confusing was that whether or not a native::
signature would get SymInt or not in its type was based on the dispatch
key.  This changes it so that SymInt or not in type is based on whether
or not you have _symint in the name of the kernel or not.  This means
that even when we make operators support SymInt, you no longer have to
go and update all the preexisting definitions; instead, you now
selectively write _symint to opt individual kernels into SymInt support.

I then go and update a bunch of kernels that don't have proper SymInt
support to make use of this convention.  There is some hacking around
for view generation code.

I also add support for external backends to specify 'symint' operators, for which we generate SymInt signatures instead of regular signatures.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39310060](https://our.internmc.facebook.com/intern/diff/D39310060)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84579
Approved by: https://github.com/wconstab
This commit is contained in:
Eli Uriegas
2022-09-09 11:29:07 -07:00
committed by PyTorch MergeBot
parent 18a31cc044
commit 93aef3a010
22 changed files with 173 additions and 130 deletions

View File

@ -304,6 +304,8 @@ def generate_function(
if func.kind() == SchemaKind.out
else cpp.name(func)
)
if f.func.has_symint():
kernel_name += "_symint"
backend_metadata = {
DispatchKey.CompositeExplicitAutograd: {
func.name: BackendMetadata(
@ -555,7 +557,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
return f"""
{sig.defn()} {{
{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
{clone_mutable_inputs_str}
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
{ret_str}
@ -615,7 +617,7 @@ def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# Kernel name needs to follow the naming convention defined in `generate_function()`
return f"""
{sig.defn(name=g.out.func.name.unambiguous_name())} {{
{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
{copy_outs_str}
{return_str(g.out.func.returns, rets)}