[AOTI] Refine the C shim autogen mechanism (#125589)

Summary: Based on the discussions in https://github.com/pytorch/pytorch/pull/120513, instead of auto-generating C shim fallback ops for thousands of ops, we maintain a list of fallback ops based on torch/_inductor/lowering.py and only generate C shim functions for those ops. At torchgen time, we re-generate the C shim files and compare the header contents against the existing checked-in C shim headers. If there is any change, the compilation fails with a prompt explaining how to proceed. This keeps the ABI-compatible C shim layer small enough to maintain in the long run.
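
As a rough illustration of the maintained list (a sketch only; the real entries in torchgen/aoti/fallback_ops.py are curated from torch/_inductor/lowering.py and their exact names and container shape may differ), the keys follow the "<namespace>.<op>.<overload>" form that get_fallback_op_name() produces for each NativeFunction:

# Hypothetical sketch of torchgen/aoti/fallback_ops.py
inductor_fallback_ops = {
    "aten.convolution.default",
    "aten._scaled_dot_product_flash_attention.default",
    "aten.repeat_interleave.Tensor",
    # only ops that Inductor actually falls back to at runtime
}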

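A minimal sketch of the check-or-update step that the new gen.py logic performs at torchgen time (a paraphrase of the diff below; the function name and argument list are illustrative, not the exact torchgen API):

import os

def check_or_update_c_shim_header(
    install_dir: str, header_file_name: str, new_header: str, update: bool
) -> None:
    # Compare the freshly generated C shim header with the checked-in copy;
    # only overwrite it when --update-aoti-c-shim was passed explicitly.
    path = os.path.join(install_dir, header_file_name)
    if update:
        with open(path, "w") as f:
            f.write(new_header)
        return
    with open(path) as f:
        old_header = f.read()
    if old_header != new_header:
        raise AssertionError(
            "AOTInductor C shim header changed: this is a fallback-op ABI break. "
            "If the change is intentional, rerun `python torchgen/gen.py --update-aoti-c-shim`."
        )
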
Differential Revision: [D57004046](https://our.internmc.facebook.com/intern/diff/D57004046)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125589
Approved by: https://github.com/frank-wei, https://github.com/chenyang78, https://github.com/albanD, https://github.com/ezyang
Bin Bao
2024-05-08 13:38:28 -07:00
committed by PyTorch MergeBot
parent 0bde9c08ef
commit ed48ea9997
10 changed files with 501 additions and 60 deletions

@@ -3,6 +3,7 @@ import functools
import json
import os
import pathlib
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from typing import (
@@ -27,6 +28,7 @@ import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
@@ -48,6 +50,7 @@ from torchgen.gen_aoti_c_shim import (
gen_aoti_c_shim,
gen_static_dispatch_backend_call_signature,
get_backend_index_for_aoti,
get_fallback_op_name,
)
from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
@@ -2190,6 +2193,7 @@ def gen_source_files(
force_schema_registration: bool,
per_operator_headers: bool,
skip_dispatcher_op_registration: bool,
update_aoti_c_shim: bool,
) -> None:
extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
@@ -2349,53 +2353,99 @@ def gen_source_files(
else:
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
structured_func_group_dict = {
f"{func_group.functional.namespace}.{func_group.functional.func.name}": func_group
for func_group in structured_native_functions
}
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
fallbacks = dict()
for func in native_functions:
op_name = get_fallback_op_name(func)
if op_name in inductor_fallback_ops:
fallbacks[op_name] = (
func,
structured_func_group_dict.get(
f"{func.namespace}.{func.func.name.name}", None
),
)
fallback_native_functions = tuple(
value for _, value in sorted(fallbacks.items())
)
def get_header(
f: NativeFunction,
func: NativeFunction,
func_group: Optional[NativeFunctionsGroup],
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(
f, dispatch_key, backend_indices
func, func_group, dispatch_key, backend_indices
)
return (
None
if backend_index is None
else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
)
def headers_for_aoti() -> str:
headers = []
for g in grouped_native_functions:
if isinstance(g, NativeFunctionsGroup):
for f in g.functions():
# some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
header = get_header(f)
if header is not None:
headers.append(header)
else:
header = get_header(g)
if header is not None:
headers.append(header)
for func, func_group in fallback_native_functions:
header = get_header(func, func_group)
if header is not None:
headers.append(header)
return "\n".join(sorted(set(headers)))
extra_headers = (
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
)
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.h",
lambda: gen_aoti_c_shim(
native_functions,
dispatch_key,
backend_indices,
header=True,
includes="",
),
# header files are checked in for ABI-compatibility checking
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
new_header = gen_aoti_c_shim(
fallback_native_functions,
dispatch_key,
backend_indices,
header=True,
includes="",
)
if update_aoti_c_shim:
aoti_fm.write(
header_file_name,
lambda: new_header,
)
else:
try:
with open(
os.path.join(aoti_fm.install_dir, header_file_name)
) as old_file:
old_header = old_file.read()
assert (
old_header == new_header
), """
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
This is allowed only in a limited number of situations:
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
C shim header files.
2. You added a new default argument to an existing fallback op. This is clearly a BC-breaking
change for AOTInductor. In this case, you need to keep a manual copy of that existing
fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
number of that fallback op in the newly generated C shim files, and update the cpp wrapper
codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
"""
except FileNotFoundError:
print(
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
)
# cpp files are always generated on-the-fly
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.cpp",
lambda: gen_aoti_c_shim(
native_functions,
fallback_native_functions,
dispatch_key,
backend_indices,
header=False,
@@ -2780,6 +2830,12 @@ def main() -> None:
default=["headers", "sources", "declarations_yaml"],
help="Generate only a subset of files",
)
parser.add_argument(
"--update-aoti-c-shim",
action="store_true",
help="Update AOTInductor C shim after changing torchgen/aoti/fallback_ops.py. "
"WARNING: Do not use this unless you are sure what you are doing!!!",
)
options = parser.parse_args()
@@ -2898,6 +2954,7 @@ def main() -> None:
force_schema_registration=options.force_schema_registration,
per_operator_headers=options.per_operator_headers,
skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
update_aoti_c_shim=options.update_aoti_c_shim,
)
if "headers" in options.generate: