Generate op variants for core CIA ops (#141797)

There are four core ATen ops with Composite Implicit Autograd (CIA) dispatch: upsample_bilinear2d.vec, upsample_nearest2d.vec, avg_pool1d, and adaptive_avg_pool1d. Op variant auto-generation is currently skipped for CIA ops. In preparation for disabling the decompositions for the upsample ops by default in export, we need to generate out variants for these ops.

This change enables autogen for core-tagged CIA ops, which produces upsample_bilinear2d.vec_out and upsample_nearest2d.vec_out (along with adaptive_avg_pool1d.out and avg_pool1d.out, as the first diff below shows).
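For reference, the generated out variant follows the pattern asserted by the new codegen test below: the functional schema plus a trailing mutable out argument. A rough Python sketch of the expected pairing for one of the named ops (the functional schema is quoted from native_functions.yaml; the vec_out schema is the expected generated form, written out here for illustration rather than copied from generated code):

# Expected schema pairing (illustrative sketch, not generated output).
FUNCTIONAL_SCHEMA = (
    "upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, "
    "bool align_corners, float[]? scale_factors) -> Tensor"
)
EXPECTED_OUT_SCHEMA = (
    "upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, "
    "bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)"
)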

Test Plan:
Added a new test test_functional_variant_autogen_out_variant_core to cover this case in test_codegen.py.
Confirmed that upsample_bilinear2d.vec_out and upsample_nearest2d.vec_out op overloads are registered (they were previously not available).
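A quick way to spot-check this on a build that includes the change (sketch only, not part of the test suite; the call below assumes the generated out= overload takes the functional arguments plus a trailing out= tensor, which is the convention the new codegen test asserts):

import torch

# The vec_out overloads should now be registered on the overload packets.
assert "vec_out" in torch.ops.aten.upsample_bilinear2d.overloads()
assert "vec_out" in torch.ops.aten.upsample_nearest2d.overloads()

# Hedged usage example: functional args (input, output_size, scale_factors)
# plus the out= tensor appended by the autogenerated out variant.
x = torch.randn(1, 3, 8, 8)
out = torch.empty(1, 3, 16, 16)
torch.ops.aten.upsample_nearest2d.vec_out(x, [16, 16], None, out=out)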

Differential Revision: D66590257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141797
Approved by: https://github.com/larryliu0820
Author: Gregory Comer
Date: 2024-12-03 22:57:44 +00:00
Committed by: PyTorch MergeBot
parent f0b33658f8
commit da5b281f23
3 changed files with 31 additions and 2 deletions


@@ -649,6 +649,7 @@ aten::_weight_int4pack_mm_for_cpu
 aten::_weight_int8pack_mm
 aten::_weight_norm_interface_backward
 aten::_weight_norm_interface_backward.out
+aten::adaptive_avg_pool1d.out
 aten::adaptive_avg_pool2d.out
 aten::adaptive_avg_pool3d.out
 aten::adaptive_avg_pool3d_backward.grad_input
@@ -672,6 +673,7 @@ aten::argmin
 aten::argmin.out
 aten::as_strided
 aten::as_strided_
+aten::avg_pool1d.out
 aten::avg_pool2d
 aten::avg_pool2d.out
 aten::avg_pool2d_backward
@@ -1312,12 +1314,14 @@ aten::unsafe_split_with_sizes.out
 aten::unsqueeze_
 aten::upsample_bicubic2d_backward
 aten::upsample_bicubic2d_backward.grad_input
+aten::upsample_bilinear2d.vec_out
 aten::upsample_bilinear2d_backward
 aten::upsample_bilinear2d_backward.grad_input
 aten::upsample_linear1d_backward
 aten::upsample_linear1d_backward.grad_input
 aten::upsample_nearest1d_backward
 aten::upsample_nearest1d_backward.grad_input
+aten::upsample_nearest2d.vec_out
 aten::upsample_nearest2d_backward
 aten::upsample_nearest2d_backward.grad_input
 aten::upsample_nearest3d_backward


@@ -410,6 +410,17 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
         )
         BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
 
+        self.core_func, core_func_index = NativeFunction.from_yaml(
+            {
+                "func": "op_3.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor",
+                "autogen": "op_3.vec_out",
+                "tags": ["core"],
+            },
+            loc=Location(__file__, 1),
+            valid_tags={"core"},
+        )
+        BackendIndex.grow_index(self.backend_indices, core_func_index)
+
     def test_functional_variant_autogen_out_variant(self) -> None:
         native_functions = [self.one_return_func]
         add_generated_native_functions(native_functions, self.backend_indices)
@@ -438,6 +449,19 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
         ]
         self.assertEqual(backend_metadata.kernel, "op_2_out")
 
+    def test_functional_variant_autogen_out_variant_core(self) -> None:
+        """
+        Tests autogen of out variants for core-tagged ops that are CompositeImplicitAutograd.
+        """
+        native_functions = [self.core_func]
+        add_generated_native_functions(native_functions, self.backend_indices)
+        print(native_functions)
+        self.assertEqual(len(native_functions), 2)
+        self.assertEqual(
+            str(native_functions[1].func),
+            "op_3.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)",
+        )
+
 
 # Test for static_dispatch
 class TestStaticDispatchGeneratrion(unittest.TestCase):


@@ -393,6 +393,7 @@ def add_generated_native_functions(
         has_inplace = SchemaKind.inplace in d
         has_mutable = SchemaKind.mutable in d
         has_out = SchemaKind.out in d
+        is_core = any("core" in variant.tags for variant in d.values())
 
         # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
         # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
@@ -409,14 +410,14 @@
         has_view_ops = any(
             f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
         )
-        # Don't generate the other variants for CompositeImplicitAutograd operators.
+        # Don't generate the other variants for non-core CompositeImplicitAutograd operators.
         # We could probably do this, but the main benefit of generating the function triplets
         # is for transforms that need them, and transforms don't need to act directly
         # on CompositeImplicitAutograd operators (since we let them decompose).
         are_composite_implicit = all(
             f.has_composite_implicit_autograd_kernel for f in d.values()
        )
-        if are_manual or has_view_ops or are_composite_implicit:
+        if are_manual or has_view_ops or are_composite_implicit and not is_core:
             continue
         if has_out and len(d.values()) == 1:
             # Note: [Out ops with functional variants that don't get grouped properly]
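A note on the new skip condition: Python's "and" binds tighter than "or", so it reads as are_manual or has_view_ops or (are_composite_implicit and not is_core), i.e. CompositeImplicitAutograd groups are now only skipped when they are not core-tagged. A small standalone sketch of that gating with hypothetical inputs (not taken from torchgen):

# Illustrative only: evaluate the skip condition above for two hypothetical op groups.
cases = [
    # (are_manual, has_view_ops, are_composite_implicit, is_core)
    (False, False, True, False),  # non-core CIA op -> still skipped
    (False, False, True, True),   # core CIA op     -> variants now generated
]
for are_manual, has_view_ops, are_composite_implicit, is_core in cases:
    skipped = are_manual or has_view_ops or are_composite_implicit and not is_core
    print(f"is_core={is_core} -> skipped={skipped}")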