[AOTI] Extend torchgen to generate C shim with version number (#147745)

Summary: While it is OK to add a new argument with a default value to a fallback op in Python, doing so is BC-breaking for the C shim. This PR adds an automated way to update the C shim files: specify a version number for the modified op together with the list of its new arguments, and torchgen regenerates the shims accordingly. See https://github.com/pytorch/pytorch/pull/154848 for an example of how to do that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147745
Approved by: https://github.com/yushangdi
Author: Bin Bao
Date: 2025-06-02 08:13:40 -07:00
Committed by: PyTorch MergeBot
Parent: 1d67849e43
Commit: 197080337b
3 changed files with 304 additions and 226 deletions
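For reference, the version info consumed by this change comes from the inductor fallback-ops table, typed `dict[str, dict[str, list[str]]]` in the diff below: each op maps to a (possibly empty) dict of version entries, in the format documented in the code (`{"v2" : ["new_arg1"], "v3": ["new_arg2, new_arg3"]}`). A hypothetical sketch of such entries (the op and argument names are made up for illustration):

    # Hypothetical entries in the inductor fallback-ops table.
    # An empty dict means only the original (v1) C shim is generated.
    inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
        "aten.foo.default": {},
        # "bar" gained a new arg "alpha" in v2; torchgen will emit both
        # aoti_torch_cpu_bar (v1, passing std::nullopt for "alpha") and
        # aoti_torch_cpu_bar_v2 (with "alpha" in its signature).
        "aten.bar.default": {"v2": ["alpha"]},
    }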


@@ -199,11 +199,16 @@ def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
 # Generate argument declarations and callsite expressions
-def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
-    types = []
-    new_names = []
-    callsite_exprs = []
+def gen_arguments(
+    flat_arguments: Sequence[Argument], skipped_args: set[str]
+) -> tuple[list[str], list[str]]:
+    types: list[str] = []
+    new_names: list[str] = []
+    callsite_exprs: list[str] = []
     for arg in flat_arguments:
+        if arg.name in skipped_args:
+            callsite_exprs.append("std::nullopt")
+            continue
         new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
             arg.type, arg.name, arg.is_write
         )
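The new `skipped_args` handling above means that when an older shim version omits a newer argument from its C signature, the generated callsite still feeds the backend a value, namely `std::nullopt`. A minimal sketch of that behavior, with plain strings standing in for torchgen's `Argument` objects:

    # Simplified sketch of the skipped-args behavior above (plain strings
    # stand in for torchgen's Argument objects).
    def gen_callsite_exprs(arg_names: list[str], skipped_args: set[str]) -> list[str]:
        callsite_exprs: list[str] = []
        for name in arg_names:
            if name in skipped_args:
                # Older shim versions omit the newer argument from their C
                # signature, so the backend call gets std::nullopt instead.
                callsite_exprs.append("std::nullopt")
                continue
            callsite_exprs.append(name)
        return callsite_exprs

    assert gen_callsite_exprs(["self", "alpha"], {"alpha"}) == ["self", "std::nullopt"]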
@@ -230,7 +235,7 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
 def convert_return(typ: BaseType, val: str) -> str:
     if typ.name == BaseTy.Tensor:
-        return f"new_tensor_handle(std::move({val}));"
+        return f"new_tensor_handle(std::move({val}))"
     elif typ.name == BaseTy.SymInt:
         return f"{val}.expect_int()"
     elif typ.name == BaseTy.Scalar:
@@ -269,47 +274,93 @@ declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}
 def gen_declaration_and_definition(
-    schema: FunctionSchema, device: str, backend_call: str
+    schema: FunctionSchema,
+    device: str,
+    backend_call: str,
+    version_info: dict[str, list[str]],
 ) -> tuple[str, str]:
-    func_name = schema.name.unambiguous_name()
+    base_name = schema.name.unambiguous_name()
 
     global declaration_definition_cache
-    if (func_name, device, backend_call) in declaration_definition_cache:
-        return declaration_definition_cache[(func_name, device, backend_call)]
+    if (base_name, device, backend_call) in declaration_definition_cache:
+        return declaration_definition_cache[(base_name, device, backend_call)]
 
-    if schema.is_out_fn():
-        # out_variant has out arguments in the front, and it's ok to ignore return values
-        # because C shim functions only return AOTITorchError
-        args, callsite_exprs = gen_arguments(
-            [*schema.arguments.out, *schema.arguments.flat_non_out]
-        )
-        ret_assignments: list[str] = []
-    else:
-        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
-        # ignore return values for inplace ops
-        ret_declarations, ret_assignments = (
-            ([], []) if schema.name.name.inplace else gen_returns(schema)
-        )
-        args.extend(ret_declarations)
-
-    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
-
-    tmp_result = "auto tmp_result = " if ret_assignments else ""
-    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
-    definition = f"""
-{declaration} {{
-    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
-        {tmp_result}{backend_call}(
-{textwrap.indent(", ".join(callsite_exprs), "            ")}
-        );{textwrap.indent(ret_assignments_str, "        ")}
-    }});
-}}
-"""
-    declaration_definition_cache[(func_name, device, backend_call)] = (
-        declaration,
-        definition,
-    )
-    return declaration, definition
+    # Check the validity of version_info. The format should look like
+    # {"v2" : ["new_arg1"], "v3": ["new_arg2, new_arg3"]}.
+    indexed_version_info: dict[int, list[str]] = {1: []}
+    for ver_str, new_args in sorted(version_info.items()):
+        assert ver_str.startswith("v"), (
+            f"Version number for {base_name} is {ver_str}, not starting with 'v'"
+        )
+        try:
+            ver_id = int(ver_str[1:])
+        except ValueError as e:
+            raise AssertionError(
+                f"Version number for {base_name} is {ver_str}, not a valid integer after 'v'"
+            ) from e
+        assert ver_id not in indexed_version_info, (
+            f"{ver_str} for {base_name} has already been defined"
+        )
+        indexed_version_info[ver_id] = new_args
+
+    declarations: list[str] = []
+    definitions: list[str] = []
+    skipped_args: set[str] = set()
+
+    for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True):
+        # Iterate in the reverse order, so the latest version of an op will get generated first
+        # with all the arguments included, while a set of to-be-trimmed args is carried down
+        # to generate earlier version of the op.
+        func_name = base_name if ver_id == 1 else f"{base_name}_v{ver_id}"
+        if schema.is_out_fn():
+            # out_variant has out arguments in the front, and it's ok to ignore return values
+            # because C shim functions only return AOTITorchError
+            args, callsite_exprs = gen_arguments(
+                [*schema.arguments.out, *schema.arguments.flat_non_out], skipped_args
+            )
+            ret_assignments: list[str] = []
+        else:
+            args, callsite_exprs = gen_arguments(
+                schema.arguments.flat_all, skipped_args
+            )
+            # ignore return values for inplace ops
+            ret_declarations, ret_assignments = (
+                ([], []) if schema.name.name.inplace else gen_returns(schema)
+            )
+            args.extend(ret_declarations)
+
+        declaration = textwrap.dedent(
+            f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
+        )
+
+        tmp_result = "auto tmp_result = " if ret_assignments else ""
+        indent = "\t\t"
+        ret_assignments_str = (
+            "\n".join(indent + r for r in ret_assignments) if ret_assignments else ""
+        )
+        definition = (
+            textwrap.dedent(f"""
+            {declaration} {{
+                AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
+                    {tmp_result}{backend_call}(
+                        {", ".join(callsite_exprs)}
+                    );
+            """)
+            + ret_assignments_str
+            + textwrap.dedent("""
+                });
+            }
+            """)
+        )
+        skipped_args.update(new_args)
+        declarations.append(f"AOTI_TORCH_EXPORT {declaration};")
+        definitions.append(definition)
+
+    declaration_definition_cache[(base_name, device, backend_call)] = (
+        "\n".join(declarations),
+        "\n".join(definitions),
+    )
+    return declaration_definition_cache[(base_name, device, backend_call)]
 
 
 def gen_static_dispatch_backend_call_signature(
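The version bookkeeping above is the heart of the change: version keys like "v2" are validated and indexed, then iterated newest-first so the latest shim keeps every argument while each older version accumulates args to skip. A self-contained sketch of just that bookkeeping (the per-version declaration/definition generation is elided):

    # Simplified sketch of the version handling in gen_declaration_and_definition.
    def versioned_names(base_name: str, version_info: dict[str, list[str]]) -> list[str]:
        indexed: dict[int, list[str]] = {1: []}
        for ver_str, new_args in sorted(version_info.items()):
            assert ver_str.startswith("v"), f"bad version key: {ver_str}"
            indexed[int(ver_str[1:])] = new_args

        names: list[str] = []
        skipped: set[str] = set()
        # Newest version first: it keeps every argument; each older version
        # then skips the args introduced after it.
        for ver_id, new_args in sorted(indexed.items(), reverse=True):
            names.append(base_name if ver_id == 1 else f"{base_name}_v{ver_id}")
            skipped.update(new_args)
        return names

    assert versioned_names("bar", {"v2": ["alpha"]}) == ["bar_v2", "bar"]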
@@ -402,6 +453,7 @@ def get_fallback_op_name(func: NativeFunction) -> str:
 def gen_c_shim(
     func: NativeFunction,
+    version_info: dict[str, list[str]],
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],

@@ -424,11 +476,13 @@ def gen_c_shim(
     try:
         if header:
             declaration, _ = gen_declaration_and_definition(
-                schema, device, backend_call
+                schema, device, backend_call, version_info
             )
-            return f"AOTI_TORCH_EXPORT {declaration};"
+            return declaration
         else:
-            _, definition = gen_declaration_and_definition(schema, device, backend_call)
+            _, definition = gen_declaration_and_definition(
+                schema, device, backend_call, version_info
+            )
             return definition
     except NotImplementedError:
@@ -437,6 +491,7 @@ def gen_c_shim(
 @dataclass(frozen=True)
 class ShimGenerator:
+    inductor_fallback_ops: dict[str, dict[str, list[str]]]
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
     dispatch_key: DispatchKey
     backend_indices: dict[DispatchKey, BackendIndex]
@@ -448,8 +503,10 @@ class ShimGenerator:
         self,
         func: NativeFunction,
     ) -> str | None:
+        version_info = self.inductor_fallback_ops[get_fallback_op_name(func)]
         result = gen_c_shim(
             func,
+            version_info,
             self.func_group_mapping,
             self.dispatch_key,
             self.backend_indices,
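One consequence of the `ShimGenerator` change above: because the lookup indexes the table directly rather than falling back to a default, every fallback op must now have an entry in `inductor_fallback_ops`, even if it is an empty dict. A small sketch of the lookup (the op name is hypothetical):

    # Sketch of the per-op lookup performed in ShimGenerator.__call__
    # (the op name below is hypothetical).
    inductor_fallback_ops = {"aten.bar.default": {"v2": ["alpha"]}}
    version_info = inductor_fallback_ops["aten.bar.default"]  # KeyError if the op is missing
    assert version_info == {"v2": ["alpha"]}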
@@ -461,6 +518,7 @@ class ShimGenerator:
 def gen_aoti_c_shim(
     native_functions: Sequence[NativeFunction],
+    inductor_fallback_ops: dict[str, dict[str, list[str]]],
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
@@ -472,6 +530,7 @@ def gen_aoti_c_shim(
         list(
             mapMaybe(
                 ShimGenerator(
+                    inductor_fallback_ops,
                     func_group_mapping,
                     dispatch_key,
                     backend_indices,
@@ -484,44 +543,53 @@ def gen_aoti_c_shim(
     )
     device = dispatch_key.lower()
     warning = """
 // WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
 // See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""
 
     if header:
-        return f"""
-{warning}
-
-#pragma once
-
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
-
-#ifdef __cplusplus
-extern "C" {{
-#endif
-
-{body}
-
-#ifdef __cplusplus
-}} // extern "C"
-#endif
-"""
+        return (
+            warning
+            + textwrap.dedent("""
+                #pragma once
+
+                #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+
+                #ifdef __cplusplus
+                extern "C" {
+                #endif
+
+                """)
+            + body
+            + textwrap.dedent("""
+                #ifdef __cplusplus
+                } // extern "C"
+                #endif
+                """)
+        )
     else:
-        return f"""
-{warning}
-
-#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
-#include <torch/csrc/inductor/aoti_torch/utils.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/{str(dispatch_key)}Functions.h>
-#include <ATen/CompositeExplicitAutogradFunctions.h>
-#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
-#include <ATen/CompositeImplicitAutogradFunctions.h>
-#else
-{includes}
-#endif
-
-using namespace torch::aot_inductor;
-
-{body}"""
+        return (
+            warning
+            + textwrap.dedent(f"""
+                #include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
+                #include <torch/csrc/inductor/aoti_torch/utils.h>
+
+                #ifndef AT_PER_OPERATOR_HEADERS
+                #include <ATen/{str(dispatch_key)}Functions.h>
+                #include <ATen/CompositeExplicitAutogradFunctions.h>
+                #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
+                #include <ATen/CompositeImplicitAutogradFunctions.h>
+                #else
+                """)
+            + includes
+            + textwrap.dedent("""
+                #endif // AT_PER_OPERATOR_HEADERS
+
+                using namespace torch::aot_inductor;
+
+                """)
+            + body
+        )
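The switch above from one large f-string to concatenated textwrap.dedent blocks lets the template source be indented naturally and, for the non-f-string pieces, avoids doubling literal braces such as extern "C" {. A minimal sketch of the pattern with a stand-in body:

    import textwrap

    # Minimal sketch of the dedent-and-concatenate pattern used above.
    body = "AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bar(...);"  # stand-in
    header = (
        textwrap.dedent("""
            #ifdef __cplusplus
            extern "C" {
            #endif
            """)  # plain (non-f) string: no need to escape { as {{
        + body
        + textwrap.dedent("""
            #ifdef __cplusplus
            } // extern "C"
            #endif
            """)
    )
    print(header)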