[aoti] Update cshim for all backends (#155604)

Fixes https://github.com/pytorch/pytorch/issues/155349
`python torchgen/gen.py --update-aoti-c-shim` will now update all cpu/cuda/mps/xpu shims -- I verified this using `aten._print.default`, but didn't commit the changes since I'm not sure if we actually want to add this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155604
Approved by: https://github.com/desertfire, https://github.com/janeyx99
This commit is contained in:
angelayi
2025-06-12 22:10:53 +00:00
committed by PyTorch MergeBot
parent 38bfd462b8
commit 938515fa75
2 changed files with 126 additions and 100 deletions

View File

@ -18,7 +18,6 @@ import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
@ -37,10 +36,8 @@ from torchgen.context import (
with_native_function_and_indices,
)
from torchgen.gen_aoti_c_shim import (
gen_aoti_c_shim,
gen_aoti_c_shim_files,
gen_static_dispatch_backend_call_signature,
get_fallback_op_name,
get_header_for_aoti,
)
from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
@ -2395,103 +2392,19 @@ def gen_source_files(
else:
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
structured_func_group_dict = {}
for func_group in structured_native_functions:
for func in func_group.functions():
if func.structured_delegate is not None:
structured_func_group_dict[func.structured_delegate] = func_group
break
if dispatch_key in aoti_backends:
fallbacks = {}
for func in native_functions:
op_name = get_fallback_op_name(func)
if op_name in inductor_fallback_ops:
fallbacks[op_name] = func
fallback_native_functions = tuple(
value for _, value in sorted(fallbacks.items())
)
# header files were checked in for ABI-compatiblilty checking
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
new_header = gen_aoti_c_shim(
fallback_native_functions,
inductor_fallback_ops,
structured_func_group_dict,
dispatch_key,
backend_indices,
header=True,
extend_aoti_c_shim=extend_aoti_c_shim,
includes="",
)
if update_aoti_c_shim:
aoti_fm.write(
header_file_name,
lambda: new_header,
)
else:
try:
with open(
os.path.join(aoti_fm.install_dir, header_file_name)
) as old_file:
old_header = old_file.read()
assert old_header == new_header, """
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
Only in a limited number of situations, this is allowed:
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
existing C shim header files.
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
change in the AOTInductor land. You need to annotate the new default argument in
torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
update the C shim header files by creating different versions of the fallback op. See
https://github.com/pytorch/pytorch/pull/154848 as an example.
"""
except FileNotFoundError:
print(
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
)
# cpp files are always generated on-the-fly
def headers_for_aoti() -> str:
headers = []
for func in fallback_native_functions:
header = get_header_for_aoti(
func,
structured_func_group_dict,
dispatch_key,
backend_indices,
extend_aoti_c_shim=extend_aoti_c_shim,
)
if header is not None:
headers.append(header)
return "\n".join(sorted(set(headers)))
extra_headers = (
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
)
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.cpp",
lambda: gen_aoti_c_shim(
fallback_native_functions,
inductor_fallback_ops,
structured_func_group_dict,
dispatch_key,
backend_indices,
header=False,
extend_aoti_c_shim=extend_aoti_c_shim,
includes=headers_for_aoti() + "\n" + extra_headers,
),
)
del fm
gen_aoti_c_shim_files(
aoti_fm=aoti_fm,
aoti_backends=aoti_backends,
native_functions=native_functions,
backend_indices=backend_indices,
structured_native_functions=structured_native_functions,
extra_cuda_headers=extra_cuda_headers,
update_aoti_c_shim=update_aoti_c_shim,
extend_aoti_c_shim=extend_aoti_c_shim,
)
# BackendSelect is generated specially
def gen_backend_select() -> dict[str, list[str]]:
relevant_fns = [
@ -2997,6 +2910,12 @@ def main() -> None:
DispatchKey.CPU,
DispatchKey.CUDA,
}
if options.update_aoti_c_shim:
# When updating the shim we want to update all devices, but when just
# building/checking the headers, we only want to check the devices that
# are available.
aoti_backends.add(DispatchKey.XPU)
aoti_backends.add(DispatchKey.MPS)
if options.mps:
functions_keys.add(DispatchKey.MPS)

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import os
import textwrap
from dataclasses import dataclass
from typing import TYPE_CHECKING
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
@ -14,6 +16,7 @@ from torchgen.model import (
BaseType,
DispatchKey,
FunctionSchema,
is_cuda_dispatch_key,
ListType,
NativeFunction,
NativeFunctionsGroup,
@ -21,7 +24,7 @@ from torchgen.model import (
OptionalType,
Type,
)
from torchgen.utils import mapMaybe
from torchgen.utils import FileManager, mapMaybe
if TYPE_CHECKING:
@ -593,3 +596,107 @@ def gen_aoti_c_shim(
""")
+ body
)
def gen_aoti_c_shim_files(
aoti_fm: FileManager,
aoti_backends: set[DispatchKey],
native_functions: Sequence[NativeFunction],
backend_indices: dict[DispatchKey, BackendIndex],
structured_native_functions: Sequence[NativeFunctionsGroup],
extra_cuda_headers: str,
extend_aoti_c_shim: bool,
update_aoti_c_shim: bool,
) -> None:
structured_func_group_dict = {}
for func_group in structured_native_functions:
for func in func_group.functions():
if func.structured_delegate is not None:
structured_func_group_dict[func.structured_delegate] = func_group
break
for dispatch_key in aoti_backends:
fallbacks = {}
for func in native_functions:
op_name = get_fallback_op_name(func)
if op_name in inductor_fallback_ops:
fallbacks[op_name] = func
fallback_native_functions = tuple(
value for _, value in sorted(fallbacks.items())
)
# header files were checked in for ABI-compatiblilty checking
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
new_header = gen_aoti_c_shim(
fallback_native_functions,
inductor_fallback_ops,
structured_func_group_dict,
dispatch_key,
backend_indices,
header=True,
extend_aoti_c_shim=extend_aoti_c_shim,
includes="",
)
if update_aoti_c_shim:
aoti_fm.write(
header_file_name,
lambda: new_header,
)
else:
try:
with open(
os.path.join(aoti_fm.install_dir, header_file_name)
) as old_file:
old_header = old_file.read()
assert old_header == new_header, """
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
Only in a limited number of situations, this is allowed:
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
existing C shim header files.
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
change in the AOTInductor land. You need to annotate the new default argument in
torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
update the C shim header files by creating different versions of the fallback op. See
https://github.com/pytorch/pytorch/pull/154848 as an example.
"""
except FileNotFoundError:
print(
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
)
# cpp files are always generated on-the-fly
def headers_for_aoti() -> str:
headers = []
for func in fallback_native_functions:
header = get_header_for_aoti(
func,
structured_func_group_dict,
dispatch_key,
backend_indices,
extend_aoti_c_shim=extend_aoti_c_shim,
)
if header is not None:
headers.append(header)
return "\n".join(sorted(set(headers)))
extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.cpp",
lambda: gen_aoti_c_shim(
fallback_native_functions,
inductor_fallback_ops,
structured_func_group_dict,
dispatch_key,
backend_indices,
header=False,
extend_aoti_c_shim=extend_aoti_c_shim,
includes=headers_for_aoti() + "\n" + extra_headers,
),
)