mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Add int return type support for custom op in proxy executor (#155465)
Summary: When a custom op has int return type in its schema. The returned value will be specialized and such behaviour is different from a symint return type. This diff **only added support for int return type**. As the returned int will be specialized and fused into downstream kernels (if being used), we can simply skip the int return type in the proxy executor. Note that in the eager run, the returned int will be specialized to the value defined in the real impl of the custom op. In exported program or in AOTI, the returned int will be specialized to the value defined in the fake impl of the custom op. So the definitions of the return value should be consistent across real and fake impl of the custom op. Otherwise the eager run and AOTI run will have different results. Test Plan: ``` buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor_custom_ops -- -r test_fn_with_int_output ``` Rollback Plan: Differential Revision: D76159406 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155465 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
da50835bde
commit
1851f50866
@ -41,7 +41,23 @@ std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>> fn_with_optiona
|
||||
return {t3, t4, t5};
|
||||
}
|
||||
|
||||
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>, int64_t, int64_t> fn_with_int_output_impl(Tensor t1, Tensor t2, int64_t i1) {
|
||||
Tensor t3 = t1 + t2;
|
||||
Tensor t4 = t1 - t2;
|
||||
Tensor t5;
|
||||
int64_t i2 = 0;
|
||||
int64_t i3 = 0;
|
||||
return {t3, t4, t5, i2, i3};
|
||||
}
|
||||
|
||||
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>, int64_t, int64_t> fn_with_int_output_meta(Tensor t1, Tensor t2, int64_t i1) {
|
||||
Tensor t3 = t1.clone();
|
||||
Tensor t4 = t1.clone();
|
||||
Tensor t5;
|
||||
int64_t i2 = 0;
|
||||
int64_t i3 = 0;
|
||||
return {t3, t4, t5, i2, i3};
|
||||
}
|
||||
|
||||
Tensor fn_with_all_inputs_impl(
|
||||
const Tensor& tensor,
|
||||
@ -381,6 +397,7 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
|
||||
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
|
||||
m.def("fn_with_optional_tensor_output(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
|
||||
m.def("fn_with_optional_tensor_output_2(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
|
||||
m.def("fn_with_int_output(Tensor t1, Tensor t2, int i) -> (Tensor, Tensor?, Tensor?, int, int)");
|
||||
m.def(
|
||||
"fn_with_all_inputs(Tensor tensor, "
|
||||
"Tensor[] tensors, "
|
||||
@ -428,6 +445,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
|
||||
m.impl("custom_add", at::custom_add_impl);
|
||||
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_impl);
|
||||
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_impl);
|
||||
m.impl("fn_with_int_output", at::fn_with_int_output_impl);
|
||||
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
|
||||
m.impl("fn_with_default_input", at::fn_with_default_input_impl);
|
||||
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
|
||||
@ -441,6 +459,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
|
||||
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
|
||||
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_meta);
|
||||
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_meta);
|
||||
m.impl("fn_with_int_output", at::fn_with_int_output_meta);
|
||||
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
|
||||
m.impl("fn_with_default_input", at::fn_with_default_input_meta);
|
||||
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);
|
||||
|
||||
@ -161,6 +161,20 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
self.check_model(m, args)
|
||||
|
||||
def test_fn_with_int_output(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
i = x.shape[0]
|
||||
z, _, _, i1, i2 = torch.ops.aoti_custom_ops.fn_with_int_output(x, y, i)
|
||||
return z, z * (i1 + i2 + i)
|
||||
|
||||
m = M().to(device=self.device)
|
||||
args = (
|
||||
torch.randn(3, 3, device=self.device),
|
||||
torch.randn(3, 3, device=self.device),
|
||||
)
|
||||
self.check_model(m, args)
|
||||
|
||||
def test_custom_op_all_inputs(self) -> None:
|
||||
class MyModel(torch.nn.Module):
|
||||
# pyre-fixme[3]: Return type must be annotated.
|
||||
|
||||
@ -1885,6 +1885,11 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
output_args: Optional[list[str]] = None,
|
||||
raw_outputs: Optional[list[ir.Buffer]] = None,
|
||||
):
|
||||
"""
|
||||
Generates declarations for external kernel arguments if needed, based on the provided
|
||||
operator and its arguments. It processes both input and output arguments, categorizing
|
||||
them into tensor and integer arguments for further code generation.
|
||||
"""
|
||||
schema = None
|
||||
if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
|
||||
obj = raw_args[0]
|
||||
@ -2006,7 +2011,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
|
||||
# TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet
|
||||
for return_type in return_types:
|
||||
if isinstance(return_type, (torch.TensorType, torch.NoneType)):
|
||||
if isinstance(
|
||||
return_type, (torch.TensorType, torch.NoneType, torch.IntType)
|
||||
):
|
||||
pass
|
||||
elif isinstance(return_type, torch.OptionalType):
|
||||
assert isinstance(return_type.getElementType(), torch.TensorType)
|
||||
@ -2021,6 +2028,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
# None output is supported, but Optional return types are not yet supported
|
||||
if output_arg is None:
|
||||
continue
|
||||
elif isinstance(raw_output_arg, int):
|
||||
new_int_args.append(str(raw_output_arg))
|
||||
elif isinstance(output_arg, (list, tuple)):
|
||||
for out in output_arg:
|
||||
fill_output_arg(
|
||||
@ -2060,6 +2069,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
return mutated_buf_names[0]
|
||||
elif isinstance(out, (list, tuple)):
|
||||
return type(out)(extract_output_name(o) for o in out)
|
||||
elif isinstance(out, int):
|
||||
return str(out)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected output: {type(out)}")
|
||||
|
||||
|
||||
@ -6753,6 +6753,12 @@ class ExternKernelNode:
|
||||
|
||||
|
||||
class FallbackKernel(ExternKernelAlloc):
|
||||
"""
|
||||
A class that represents a fallback kernel for handling operators that are not
|
||||
directly support by inductor. It currently supports functional ops, view ops,
|
||||
implace aten ops, and mutating ops that are auto-functionalizable.
|
||||
"""
|
||||
|
||||
def __init__( # type: ignore[no-untyped-def]
|
||||
self,
|
||||
layout,
|
||||
@ -7023,6 +7029,8 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(return_type, torch.IntType):
|
||||
return export_schema.Argument.create(as_int=output)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported return type {type(return_type)}")
|
||||
|
||||
|
||||
@ -536,6 +536,17 @@ void OSSProxyExecutor::get_output_info_from_serialized(
|
||||
}
|
||||
break;
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
TORCH_CHECK(
|
||||
serialized_output_type == "as_int",
|
||||
"Expected extern kernel ",
|
||||
serialized_node["target"],
|
||||
" to have serialized output type as_int, ",
|
||||
" but got ",
|
||||
serialized_output_type);
|
||||
outputs.emplace_back(output_index, DynamicArgType::IntType, 1);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
@ -800,12 +811,14 @@ void OSSProxyExecutor::call_function(
|
||||
tensor_id,
|
||||
", expected num = ",
|
||||
num_tensors - num_output_tensors);
|
||||
|
||||
int num_output_ints = op_kernel->num_output_ints();
|
||||
TORCH_CHECK(
|
||||
int_id == num_ints,
|
||||
int_id == num_ints - num_output_ints,
|
||||
"Mismatch between ints consumed and num_ints, got int_id = ",
|
||||
int_id,
|
||||
", num_ints = ",
|
||||
num_ints);
|
||||
num_ints - num_output_ints);
|
||||
|
||||
// Call the op with the prepared stack.
|
||||
op_kernel->run(stack);
|
||||
@ -851,6 +864,18 @@ void OSSProxyExecutor::call_function(
|
||||
} else {
|
||||
index++;
|
||||
}
|
||||
} else if (schema_return.real_type()->kind() == c10::TypeKind::IntType) {
|
||||
// need to use real_type() to differentiate between IntType and SymIntType
|
||||
// for int type, it is already specialized in downstream kernels. So we
|
||||
// don't need to do anything here.
|
||||
auto returned_int_value = stack[index++].toInt();
|
||||
auto serialized_int_value = flatten_int_args[int_id++];
|
||||
TORCH_CHECK(
|
||||
returned_int_value == serialized_int_value,
|
||||
"Expect returned int value to match the serialized int value, but got retured int value: ",
|
||||
returned_int_value,
|
||||
" and serialized int value: ",
|
||||
serialized_int_value);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
@ -865,6 +890,13 @@ void OSSProxyExecutor::call_function(
|
||||
tensor_id,
|
||||
", expected num = ",
|
||||
num_tensors);
|
||||
|
||||
TORCH_CHECK(
|
||||
int_id == num_ints,
|
||||
"Mismatch between tensors consumed and num_ints, got tensor_id = ",
|
||||
int_id,
|
||||
", expected num = ",
|
||||
num_ints);
|
||||
}
|
||||
|
||||
} // namespace torch::aot_inductor
|
||||
|
||||
@ -82,6 +82,16 @@ struct OSSOpKernel {
|
||||
return num_output_tensors;
|
||||
}
|
||||
|
||||
int num_output_ints() const {
|
||||
int num_output_ints = 0;
|
||||
for (const auto& output : outputs_) {
|
||||
if (output.arg_type == DynamicArgType::IntType) {
|
||||
num_output_ints += output.length;
|
||||
}
|
||||
}
|
||||
return num_output_ints;
|
||||
}
|
||||
|
||||
virtual void run(std::vector<c10::IValue>& stack) = 0;
|
||||
virtual c10::FunctionSchema schema() const = 0;
|
||||
virtual ~OSSOpKernel() = default;
|
||||
|
||||
Reference in New Issue
Block a user