mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use presence of _symint in kernel name to generate symint sig or not (#84579)
Something people found confusing was that whether or not a native:: signature would get SymInt or not in its type was based on the dispatch key. This changes it so that SymInt or not in type is based on whether or not you have _symint in the name of the kernel or not. This means that even when we make operators support SymInt, you no longer have to go and update all the preexisting definitions; instead, you now selectively write _symint to opt individual kernels into SymInt support. I then go and update a bunch of kernels that don't have proper SymInt support to make use of this convention. There is some hacking around for view generation code. I also add support for external backends to specify 'symint' operators, for which we generate SymInt signatures instead of regular signatures. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: [D39310060](https://our.internmc.facebook.com/intern/diff/D39310060) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84579 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
18a31cc044
commit
93aef3a010
@ -304,6 +304,8 @@ def generate_function(
|
||||
if func.kind() == SchemaKind.out
|
||||
else cpp.name(func)
|
||||
)
|
||||
if f.func.has_symint():
|
||||
kernel_name += "_symint"
|
||||
backend_metadata = {
|
||||
DispatchKey.CompositeExplicitAutograd: {
|
||||
func.name: BackendMetadata(
|
||||
@ -555,7 +557,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
|
||||
|
||||
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
|
||||
return f"""
|
||||
{sig.defn()} {{
|
||||
{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
|
||||
{clone_mutable_inputs_str}
|
||||
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
|
||||
{ret_str}
|
||||
@ -615,7 +617,7 @@ def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
|
||||
|
||||
# Kernel name needs to follow the naming convention defined in `generate_function()`
|
||||
return f"""
|
||||
{sig.defn(name=g.out.func.name.unambiguous_name())} {{
|
||||
{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
|
||||
auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
|
||||
{copy_outs_str}
|
||||
{return_str(g.out.func.returns, rets)}
|
||||
|
Reference in New Issue
Block a user