mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aoti] Update cshim for all backends (#155604)
Fixes https://github.com/pytorch/pytorch/issues/155349 `python torchgen/gen.py --update-aoti-c-shim` will now update all cpu/cuda/mps/xpu shims -- I verified this using `aten._print.default`, but didn't commit the changes since I'm not sure if we actually want to add this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155604 Approved by: https://github.com/desertfire, https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
38bfd462b8
commit
938515fa75
117
torchgen/gen.py
117
torchgen/gen.py
@ -18,7 +18,6 @@ import torchgen.api.meta as meta
|
||||
import torchgen.api.native as native
|
||||
import torchgen.api.structured as structured
|
||||
import torchgen.dest as dest
|
||||
from torchgen.aoti.fallback_ops import inductor_fallback_ops
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
@ -37,10 +36,8 @@ from torchgen.context import (
|
||||
with_native_function_and_indices,
|
||||
)
|
||||
from torchgen.gen_aoti_c_shim import (
|
||||
gen_aoti_c_shim,
|
||||
gen_aoti_c_shim_files,
|
||||
gen_static_dispatch_backend_call_signature,
|
||||
get_fallback_op_name,
|
||||
get_header_for_aoti,
|
||||
)
|
||||
from torchgen.gen_functionalization_type import (
|
||||
gen_functionalization_definition,
|
||||
@ -2395,103 +2392,19 @@ def gen_source_files(
|
||||
else:
|
||||
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
|
||||
|
||||
structured_func_group_dict = {}
|
||||
for func_group in structured_native_functions:
|
||||
for func in func_group.functions():
|
||||
if func.structured_delegate is not None:
|
||||
structured_func_group_dict[func.structured_delegate] = func_group
|
||||
break
|
||||
|
||||
if dispatch_key in aoti_backends:
|
||||
fallbacks = {}
|
||||
for func in native_functions:
|
||||
op_name = get_fallback_op_name(func)
|
||||
if op_name in inductor_fallback_ops:
|
||||
fallbacks[op_name] = func
|
||||
fallback_native_functions = tuple(
|
||||
value for _, value in sorted(fallbacks.items())
|
||||
)
|
||||
|
||||
# header files were checked in for ABI-compatiblilty checking
|
||||
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
|
||||
new_header = gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
inductor_fallback_ops,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=True,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
includes="",
|
||||
)
|
||||
if update_aoti_c_shim:
|
||||
aoti_fm.write(
|
||||
header_file_name,
|
||||
lambda: new_header,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
with open(
|
||||
os.path.join(aoti_fm.install_dir, header_file_name)
|
||||
) as old_file:
|
||||
old_header = old_file.read()
|
||||
assert old_header == new_header, """
|
||||
|
||||
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
|
||||
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
|
||||
Only in a limited number of situations, this is allowed:
|
||||
|
||||
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
|
||||
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
|
||||
existing C shim header files.
|
||||
|
||||
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
|
||||
change in the AOTInductor land. You need to annotate the new default argument in
|
||||
torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
|
||||
update the C shim header files by creating different versions of the fallback op. See
|
||||
https://github.com/pytorch/pytorch/pull/154848 as an example.
|
||||
|
||||
"""
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
|
||||
)
|
||||
|
||||
# cpp files are always generated on-the-fly
|
||||
def headers_for_aoti() -> str:
|
||||
headers = []
|
||||
for func in fallback_native_functions:
|
||||
header = get_header_for_aoti(
|
||||
func,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
return "\n".join(sorted(set(headers)))
|
||||
|
||||
extra_headers = (
|
||||
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
|
||||
)
|
||||
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.cpp",
|
||||
lambda: gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
inductor_fallback_ops,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=False,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
includes=headers_for_aoti() + "\n" + extra_headers,
|
||||
),
|
||||
)
|
||||
|
||||
del fm
|
||||
|
||||
gen_aoti_c_shim_files(
|
||||
aoti_fm=aoti_fm,
|
||||
aoti_backends=aoti_backends,
|
||||
native_functions=native_functions,
|
||||
backend_indices=backend_indices,
|
||||
structured_native_functions=structured_native_functions,
|
||||
extra_cuda_headers=extra_cuda_headers,
|
||||
update_aoti_c_shim=update_aoti_c_shim,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
)
|
||||
|
||||
# BackendSelect is generated specially
|
||||
def gen_backend_select() -> dict[str, list[str]]:
|
||||
relevant_fns = [
|
||||
@ -2997,6 +2910,12 @@ def main() -> None:
|
||||
DispatchKey.CPU,
|
||||
DispatchKey.CUDA,
|
||||
}
|
||||
if options.update_aoti_c_shim:
|
||||
# When updating the shim we want to update all devices, but when just
|
||||
# building/checking the headers, we only want to check the devices that
|
||||
# are available.
|
||||
aoti_backends.add(DispatchKey.XPU)
|
||||
aoti_backends.add(DispatchKey.MPS)
|
||||
|
||||
if options.mps:
|
||||
functions_keys.add(DispatchKey.MPS)
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.aoti.fallback_ops import inductor_fallback_ops
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
|
||||
from torchgen.context import method_with_native_function
|
||||
@ -14,6 +16,7 @@ from torchgen.model import (
|
||||
BaseType,
|
||||
DispatchKey,
|
||||
FunctionSchema,
|
||||
is_cuda_dispatch_key,
|
||||
ListType,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
@ -21,7 +24,7 @@ from torchgen.model import (
|
||||
OptionalType,
|
||||
Type,
|
||||
)
|
||||
from torchgen.utils import mapMaybe
|
||||
from torchgen.utils import FileManager, mapMaybe
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -593,3 +596,107 @@ def gen_aoti_c_shim(
|
||||
""")
|
||||
+ body
|
||||
)
|
||||
|
||||
|
||||
def gen_aoti_c_shim_files(
|
||||
aoti_fm: FileManager,
|
||||
aoti_backends: set[DispatchKey],
|
||||
native_functions: Sequence[NativeFunction],
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
structured_native_functions: Sequence[NativeFunctionsGroup],
|
||||
extra_cuda_headers: str,
|
||||
extend_aoti_c_shim: bool,
|
||||
update_aoti_c_shim: bool,
|
||||
) -> None:
|
||||
structured_func_group_dict = {}
|
||||
for func_group in structured_native_functions:
|
||||
for func in func_group.functions():
|
||||
if func.structured_delegate is not None:
|
||||
structured_func_group_dict[func.structured_delegate] = func_group
|
||||
break
|
||||
|
||||
for dispatch_key in aoti_backends:
|
||||
fallbacks = {}
|
||||
for func in native_functions:
|
||||
op_name = get_fallback_op_name(func)
|
||||
if op_name in inductor_fallback_ops:
|
||||
fallbacks[op_name] = func
|
||||
fallback_native_functions = tuple(
|
||||
value for _, value in sorted(fallbacks.items())
|
||||
)
|
||||
|
||||
# header files were checked in for ABI-compatiblilty checking
|
||||
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
|
||||
new_header = gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
inductor_fallback_ops,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=True,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
includes="",
|
||||
)
|
||||
if update_aoti_c_shim:
|
||||
aoti_fm.write(
|
||||
header_file_name,
|
||||
lambda: new_header,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
with open(
|
||||
os.path.join(aoti_fm.install_dir, header_file_name)
|
||||
) as old_file:
|
||||
old_header = old_file.read()
|
||||
assert old_header == new_header, """
|
||||
|
||||
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
|
||||
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
|
||||
Only in a limited number of situations, this is allowed:
|
||||
|
||||
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
|
||||
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
|
||||
existing C shim header files.
|
||||
|
||||
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
|
||||
change in the AOTInductor land. You need to annotate the new default argument in
|
||||
torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
|
||||
update the C shim header files by creating different versions of the fallback op. See
|
||||
https://github.com/pytorch/pytorch/pull/154848 as an example.
|
||||
|
||||
"""
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
|
||||
)
|
||||
|
||||
# cpp files are always generated on-the-fly
|
||||
def headers_for_aoti() -> str:
|
||||
headers = []
|
||||
for func in fallback_native_functions:
|
||||
header = get_header_for_aoti(
|
||||
func,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
return "\n".join(sorted(set(headers)))
|
||||
|
||||
extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
|
||||
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.cpp",
|
||||
lambda: gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
inductor_fallback_ops,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=False,
|
||||
extend_aoti_c_shim=extend_aoti_c_shim,
|
||||
includes=headers_for_aoti() + "\n" + extra_headers,
|
||||
),
|
||||
)
|
||||
|
Reference in New Issue
Block a user