mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Extend torchgen to generate C shim with version number (#147745)
Summary: While it is OK to add a new argument with a default value to a fallback op in Python, doing so is BC-breaking for the C shim. This PR adds an automatic way to update the C shim files: specify a version number for the modified op together with the list of newly added arguments. See https://github.com/pytorch/pytorch/pull/154848 for an example of how to do that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147745
Approved by: https://github.com/yushangdi
committed by PyTorch MergeBot
parent 1d67849e43
commit 197080337b
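For illustration only (the op and argument names below are hypothetical, not taken from this PR), a versioned entry in torchgen's inductor_fallback_ops table could look like the following sketch; the {"vN": [...]} format matches the comment added in the diff below.

# Minimal sketch, assuming a hypothetical op "aten.foo.default" that originally
# took (self, dim) and later gained a defaulted argument "alpha".
inductor_fallback_ops = {
    # Unversioned entry: a single C shim is generated, as before.
    "aten.bar.default": {},
    # Versioned entry: torchgen keeps the original shim (without "alpha") and
    # additionally emits a "_v2" shim whose signature includes "alpha".
    "aten.foo.default": {"v2": ["alpha"]},
}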
@@ -199,11 +199,16 @@ def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:


 # Generate argument declarations and callsite expressions
-def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
-    types = []
-    new_names = []
-    callsite_exprs = []
+def gen_arguments(
+    flat_arguments: Sequence[Argument], skipped_args: set[str]
+) -> tuple[list[str], list[str]]:
+    types: list[str] = []
+    new_names: list[str] = []
+    callsite_exprs: list[str] = []
     for arg in flat_arguments:
+        if arg.name in skipped_args:
+            callsite_exprs.append("std::nullopt")
+            continue
         new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
             arg.type, arg.name, arg.is_write
         )
@@ -230,7 +235,7 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:

     def convert_return(typ: BaseType, val: str) -> str:
         if typ.name == BaseTy.Tensor:
-            return f"new_tensor_handle(std::move({val}));"
+            return f"new_tensor_handle(std::move({val}))"
         elif typ.name == BaseTy.SymInt:
             return f"{val}.expect_int()"
         elif typ.name == BaseTy.Scalar:
@@ -269,47 +274,93 @@ declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


 def gen_declaration_and_definition(
-    schema: FunctionSchema, device: str, backend_call: str
+    schema: FunctionSchema,
+    device: str,
+    backend_call: str,
+    version_info: dict[str, list[str]],
 ) -> tuple[str, str]:
-    func_name = schema.name.unambiguous_name()
+    base_name = schema.name.unambiguous_name()

     global declaration_definition_cache
-    if (func_name, device, backend_call) in declaration_definition_cache:
-        return declaration_definition_cache[(func_name, device, backend_call)]
+    if (base_name, device, backend_call) in declaration_definition_cache:
+        return declaration_definition_cache[(base_name, device, backend_call)]

-    if schema.is_out_fn():
-        # out_variant has out arguments in the front, and it's ok to ignore return values
-        # because C shim functions only return AOTITorchError
-        args, callsite_exprs = gen_arguments(
-            [*schema.arguments.out, *schema.arguments.flat_non_out]
-        )
-        ret_assignments: list[str] = []
-    else:
-        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
-        # ignore return values for inplace ops
-        ret_declarations, ret_assignments = (
-            ([], []) if schema.name.name.inplace else gen_returns(schema)
-        )
-        args.extend(ret_declarations)
-
-    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
-
-    tmp_result = "auto tmp_result = " if ret_assignments else ""
-    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
-    definition = f"""
-{declaration} {{
-    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
-        {tmp_result}{backend_call}(
-{textwrap.indent(", ".join(callsite_exprs), " ")}
-        );{textwrap.indent(ret_assignments_str, " ")}
-    }});
-}}
-"""
-    declaration_definition_cache[(func_name, device, backend_call)] = (
-        declaration,
-        definition,
+    # Check the validity of version_info. The format should look like
+    # {"v2" : ["new_arg1"], "v3": ["new_arg2, new_arg3"]}.
+    indexed_version_info: dict[int, list[str]] = {1: []}
+    for ver_str, new_args in sorted(version_info.items()):
+        assert ver_str.startswith("v"), (
+            f"Version number for {base_name} is {ver_str}, not starting with 'v'"
+        )
+        try:
+            ver_id = int(ver_str[1:])
+        except ValueError as e:
+            raise AssertionError(
+                f"Version number for {base_name} is {ver_str}, not a valid integer after 'v'"
+            ) from e
+        assert ver_id not in indexed_version_info, (
+            f"{ver_str} for {base_name} has already been defined"
+        )
+        indexed_version_info[ver_id] = new_args
+
+    declarations: list[str] = []
+    definitions: list[str] = []
+    skipped_args: set[str] = set()
+
+    for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True):
+        # Iterate in the reverse order, so the latest version of an op will get generated first
+        # with all the arguments included, while a set of to-be-trimmed args is carried down
+        # to generate earlier version of the op.
+        func_name = base_name if ver_id == 1 else f"{base_name}_v{ver_id}"
+        if schema.is_out_fn():
+            # out_variant has out arguments in the front, and it's ok to ignore return values
+            # because C shim functions only return AOTITorchError
+            args, callsite_exprs = gen_arguments(
+                [*schema.arguments.out, *schema.arguments.flat_non_out], skipped_args
+            )
+            ret_assignments: list[str] = []
+        else:
+            args, callsite_exprs = gen_arguments(
+                schema.arguments.flat_all, skipped_args
+            )
+            # ignore return values for inplace ops
+            ret_declarations, ret_assignments = (
+                ([], []) if schema.name.name.inplace else gen_returns(schema)
+            )
+            args.extend(ret_declarations)
+
+        declaration = textwrap.dedent(
+            f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
+        )
+
+        tmp_result = "auto tmp_result = " if ret_assignments else ""
+        indent = "\t\t"
+        ret_assignments_str = (
+            "\n".join(indent + r for r in ret_assignments) if ret_assignments else ""
+        )
+        definition = (
+            textwrap.dedent(f"""
+            {declaration} {{
+                AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
+                    {tmp_result}{backend_call}(
+                        {", ".join(callsite_exprs)}
+                    );
+            """)
+            + ret_assignments_str
+            + textwrap.dedent("""
+                });
+            }
+            """)
+        )
+        skipped_args.update(new_args)
+        declarations.append(f"AOTI_TORCH_EXPORT {declaration};")
+        definitions.append(definition)
+
+    declaration_definition_cache[(base_name, device, backend_call)] = (
+        "\n".join(declarations),
+        "\n".join(definitions),
     )
-    return declaration, definition
+    return declaration_definition_cache[(base_name, device, backend_call)]


 def gen_static_dispatch_backend_call_signature(
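To make the reverse-order iteration above concrete, the following standalone sketch (simplified from gen_declaration_and_definition; the op name "foo" and the cpu device are made up) shows how each version's new arguments are carried into skipped_args, so older shims replace them with std::nullopt at the callsite:

# Simplified sketch of the versioning loop: the newest version is generated
# first with all arguments; its new args are then skipped in older shims.
indexed_version_info = {1: [], 2: ["alpha"]}  # hypothetical op with one added arg
skipped_args: set[str] = set()
for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True):
    suffix = "" if ver_id == 1 else f"_v{ver_id}"
    # The v2 shim keeps "alpha"; the v1 shim, generated afterwards, drops it.
    print(f"aoti_torch_cpu_foo{suffix} skips {sorted(skipped_args)}")
    skipped_args.update(new_args)
# Prints: aoti_torch_cpu_foo_v2 skips []
#         aoti_torch_cpu_foo skips ['alpha']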
@@ -402,6 +453,7 @@ def get_fallback_op_name(func: NativeFunction) -> str:

 def gen_c_shim(
     func: NativeFunction,
+    version_info: dict[str, list[str]],
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
@@ -424,11 +476,13 @@ def gen_c_shim(
     try:
         if header:
             declaration, _ = gen_declaration_and_definition(
-                schema, device, backend_call
+                schema, device, backend_call, version_info
             )
-            return f"AOTI_TORCH_EXPORT {declaration};"
+            return declaration
         else:
-            _, definition = gen_declaration_and_definition(schema, device, backend_call)
+            _, definition = gen_declaration_and_definition(
+                schema, device, backend_call, version_info
+            )
             return definition

     except NotImplementedError:
@@ -437,6 +491,7 @@ def gen_c_shim(

 @dataclass(frozen=True)
 class ShimGenerator:
+    inductor_fallback_ops: dict[str, dict[str, list[str]]]
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
     dispatch_key: DispatchKey
     backend_indices: dict[DispatchKey, BackendIndex]
@@ -448,8 +503,10 @@ class ShimGenerator:
         self,
         func: NativeFunction,
     ) -> str | None:
+        version_info = self.inductor_fallback_ops[get_fallback_op_name(func)]
         result = gen_c_shim(
             func,
+            version_info,
             self.func_group_mapping,
             self.dispatch_key,
             self.backend_indices,
@@ -461,6 +518,7 @@ class ShimGenerator:

 def gen_aoti_c_shim(
     native_functions: Sequence[NativeFunction],
+    inductor_fallback_ops: dict[str, dict[str, list[str]]],
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
@@ -472,6 +530,7 @@ def gen_aoti_c_shim(
         list(
             mapMaybe(
                 ShimGenerator(
+                    inductor_fallback_ops,
                     func_group_mapping,
                     dispatch_key,
                     backend_indices,
@@ -484,44 +543,53 @@ def gen_aoti_c_shim(
     )
     device = dispatch_key.lower()
     warning = """

 // WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
 // See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

     if header:
-        return f"""
-{warning}
-
-#pragma once
-
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
-
-#ifdef __cplusplus
-extern "C" {{
-#endif
-
-{body}
-
-#ifdef __cplusplus
-}} // extern "C"
-#endif
-"""
+        return (
+            warning
+            + textwrap.dedent("""
+
+            #pragma once
+
+            #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+
+            #ifdef __cplusplus
+            extern "C" {
+            #endif
+
+            """)
+            + body
+            + textwrap.dedent("""
+
+            #ifdef __cplusplus
+            } // extern "C"
+            #endif
+            """)
+        )
     else:
-        return f"""
-{warning}
-
-#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
-#include <torch/csrc/inductor/aoti_torch/utils.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/{str(dispatch_key)}Functions.h>
-#include <ATen/CompositeExplicitAutogradFunctions.h>
-#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
-#include <ATen/CompositeImplicitAutogradFunctions.h>
-#else
-{includes}
-#endif
-
-using namespace torch::aot_inductor;
-
-{body}"""
+        return (
+            warning
+            + textwrap.dedent(f"""
+
+            #include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
+            #include <torch/csrc/inductor/aoti_torch/utils.h>
+
+            #ifndef AT_PER_OPERATOR_HEADERS
+            #include <ATen/{str(dispatch_key)}Functions.h>
+            #include <ATen/CompositeExplicitAutogradFunctions.h>
+            #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
+            #include <ATen/CompositeImplicitAutogradFunctions.h>
+            #else
+            """)
+            + includes
+            + textwrap.dedent("""
+            #endif // AT_PER_OPERATOR_HEADERS
+
+            using namespace torch::aot_inductor;
+
+            """)
+            + body
+        )