[AOTI] Add int return type support for custom op in proxy executor (#155465)

Summary:
When a custom op has int return type in its schema. The returned value will be specialized and such behaviour is different from a symint return type. This diff **only added support for int return type**.

As the returned int will be specialized and fused into downstream kernels (if being used), we can simply skip the int return type in the proxy executor.

Note that in the eager run, the returned int will be specialized to the value defined in the real impl of the custom op. In exported program or in AOTI, the returned int will be specialized to the value defined in the fake impl of the custom op. So the definitions of the return value should be consistent across real and fake impl of the custom op. Otherwise the eager run and AOTI run will have different results.

Test Plan:
```
buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor_custom_ops -- -r test_fn_with_int_output
```

Rollback Plan:

Differential Revision: D76159406

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155465
Approved by: https://github.com/angelayi
This commit is contained in:
Yiming Zhou
2025-06-10 01:07:15 +00:00
committed by PyTorch MergeBot
parent da50835bde
commit 1851f50866
6 changed files with 97 additions and 3 deletions

View File

@ -41,7 +41,23 @@ std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>> fn_with_optiona
return {t3, t4, t5};
}
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>, int64_t, int64_t> fn_with_int_output_impl(Tensor t1, Tensor t2, int64_t i1) {
Tensor t3 = t1 + t2;
Tensor t4 = t1 - t2;
Tensor t5;
int64_t i2 = 0;
int64_t i3 = 0;
return {t3, t4, t5, i2, i3};
}
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>, int64_t, int64_t> fn_with_int_output_meta(Tensor t1, Tensor t2, int64_t i1) {
Tensor t3 = t1.clone();
Tensor t4 = t1.clone();
Tensor t5;
int64_t i2 = 0;
int64_t i3 = 0;
return {t3, t4, t5, i2, i3};
}
Tensor fn_with_all_inputs_impl(
const Tensor& tensor,
@ -381,6 +397,7 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
m.def("fn_with_optional_tensor_output(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
m.def("fn_with_optional_tensor_output_2(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
m.def("fn_with_int_output(Tensor t1, Tensor t2, int i) -> (Tensor, Tensor?, Tensor?, int, int)");
m.def(
"fn_with_all_inputs(Tensor tensor, "
"Tensor[] tensors, "
@ -428,6 +445,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
m.impl("custom_add", at::custom_add_impl);
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_impl);
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_impl);
m.impl("fn_with_int_output", at::fn_with_int_output_impl);
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
m.impl("fn_with_default_input", at::fn_with_default_input_impl);
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
@ -441,6 +459,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_meta);
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_meta);
m.impl("fn_with_int_output", at::fn_with_int_output_meta);
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
m.impl("fn_with_default_input", at::fn_with_default_input_meta);
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);

View File

@ -161,6 +161,20 @@ class AOTInductorTestsTemplate:
)
self.check_model(m, args)
def test_fn_with_int_output(self) -> None:
class M(torch.nn.Module):
def forward(self, x, y):
i = x.shape[0]
z, _, _, i1, i2 = torch.ops.aoti_custom_ops.fn_with_int_output(x, y, i)
return z, z * (i1 + i2 + i)
m = M().to(device=self.device)
args = (
torch.randn(3, 3, device=self.device),
torch.randn(3, 3, device=self.device),
)
self.check_model(m, args)
def test_custom_op_all_inputs(self) -> None:
class MyModel(torch.nn.Module):
# pyre-fixme[3]: Return type must be annotated.

View File

@ -1885,6 +1885,11 @@ class CppWrapperCpu(PythonWrapperCodegen):
output_args: Optional[list[str]] = None,
raw_outputs: Optional[list[ir.Buffer]] = None,
):
"""
Generates declarations for external kernel arguments if needed, based on the provided
operator and its arguments. It processes both input and output arguments, categorizing
them into tensor and integer arguments for further code generation.
"""
schema = None
if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
obj = raw_args[0]
@ -2006,7 +2011,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
# TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet
for return_type in return_types:
if isinstance(return_type, (torch.TensorType, torch.NoneType)):
if isinstance(
return_type, (torch.TensorType, torch.NoneType, torch.IntType)
):
pass
elif isinstance(return_type, torch.OptionalType):
assert isinstance(return_type.getElementType(), torch.TensorType)
@ -2021,6 +2028,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
# None output is supported, but Optional return types are not yet supported
if output_arg is None:
continue
elif isinstance(raw_output_arg, int):
new_int_args.append(str(raw_output_arg))
elif isinstance(output_arg, (list, tuple)):
for out in output_arg:
fill_output_arg(
@ -2060,6 +2069,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
return mutated_buf_names[0]
elif isinstance(out, (list, tuple)):
return type(out)(extract_output_name(o) for o in out)
elif isinstance(out, int):
return str(out)
else:
raise AssertionError(f"Unexpected output: {type(out)}")

View File

@ -6753,6 +6753,12 @@ class ExternKernelNode:
class FallbackKernel(ExternKernelAlloc):
"""
A class that represents a fallback kernel for handling operators that are not
directly support by inductor. It currently supports functional ops, view ops,
implace aten ops, and mutating ops that are auto-functionalizable.
"""
def __init__( # type: ignore[no-untyped-def]
self,
layout,
@ -7023,6 +7029,8 @@ class FallbackKernel(ExternKernelAlloc):
)
)
)
elif isinstance(return_type, torch.IntType):
return export_schema.Argument.create(as_int=output)
else:
raise RuntimeError(f"Unsupported return type {type(return_type)}")

View File

@ -536,6 +536,17 @@ void OSSProxyExecutor::get_output_info_from_serialized(
}
break;
}
case c10::TypeKind::IntType: {
TORCH_CHECK(
serialized_output_type == "as_int",
"Expected extern kernel ",
serialized_node["target"],
" to have serialized output type as_int, ",
" but got ",
serialized_output_type);
outputs.emplace_back(output_index, DynamicArgType::IntType, 1);
break;
}
default: {
TORCH_CHECK(
false,
@ -800,12 +811,14 @@ void OSSProxyExecutor::call_function(
tensor_id,
", expected num = ",
num_tensors - num_output_tensors);
int num_output_ints = op_kernel->num_output_ints();
TORCH_CHECK(
int_id == num_ints,
int_id == num_ints - num_output_ints,
"Mismatch between ints consumed and num_ints, got int_id = ",
int_id,
", num_ints = ",
num_ints);
num_ints - num_output_ints);
// Call the op with the prepared stack.
op_kernel->run(stack);
@ -851,6 +864,18 @@ void OSSProxyExecutor::call_function(
} else {
index++;
}
} else if (schema_return.real_type()->kind() == c10::TypeKind::IntType) {
// need to use real_type() to differentiate between IntType and SymIntType
// for int type, it is already specialized in downstream kernels. So we
// don't need to do anything here.
auto returned_int_value = stack[index++].toInt();
auto serialized_int_value = flatten_int_args[int_id++];
TORCH_CHECK(
returned_int_value == serialized_int_value,
"Expect returned int value to match the serialized int value, but got retured int value: ",
returned_int_value,
" and serialized int value: ",
serialized_int_value);
} else {
TORCH_CHECK(
false,
@ -865,6 +890,13 @@ void OSSProxyExecutor::call_function(
tensor_id,
", expected num = ",
num_tensors);
TORCH_CHECK(
int_id == num_ints,
"Mismatch between tensors consumed and num_ints, got tensor_id = ",
int_id,
", expected num = ",
num_ints);
}
} // namespace torch::aot_inductor

View File

@ -82,6 +82,16 @@ struct OSSOpKernel {
return num_output_tensors;
}
int num_output_ints() const {
int num_output_ints = 0;
for (const auto& output : outputs_) {
if (output.arg_type == DynamicArgType::IntType) {
num_output_ints += output.length;
}
}
return num_output_ints;
}
virtual void run(std::vector<c10::IValue>& stack) = 0;
virtual c10::FunctionSchema schema() const = 0;
virtual ~OSSOpKernel() = default;