mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Refine the C shim autogen mechanism (#125589)
Summary: Based on the discussions in https://github.com/pytorch/pytorch/pull/120513. Instead of auto-generating C shim fallback ops for thousands of ops, we maintain a list of fallback ops based on torch/_inductor/lowering.py, and only generate C shim functions for those ops. At torchgen time, we will re-generate C shim files and compare the header file contents against the existing C shim headers. If there is any change, the compilation will fail with a prompt on how to proceed. This makes sure the ABI-compatible C shim layer is small enough to maintain in the long run. Differential Revision: [D57004046](https://our.internmc.facebook.com/intern/diff/D57004046) Pull Request resolved: https://github.com/pytorch/pytorch/pull/125589 Approved by: https://github.com/frank-wei, https://github.com/chenyang78, https://github.com/albanD, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0bde9c08ef
commit
ed48ea9997
105
torchgen/gen.py
105
torchgen/gen.py
@ -3,6 +3,7 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
@ -27,6 +28,7 @@ import torchgen.api.native as native
|
||||
import torchgen.api.structured as structured
|
||||
import torchgen.dest as dest
|
||||
|
||||
from torchgen.aoti.fallback_ops import inductor_fallback_ops
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
@ -48,6 +50,7 @@ from torchgen.gen_aoti_c_shim import (
|
||||
gen_aoti_c_shim,
|
||||
gen_static_dispatch_backend_call_signature,
|
||||
get_backend_index_for_aoti,
|
||||
get_fallback_op_name,
|
||||
)
|
||||
from torchgen.gen_functionalization_type import (
|
||||
gen_functionalization_definition,
|
||||
@ -2190,6 +2193,7 @@ def gen_source_files(
|
||||
force_schema_registration: bool,
|
||||
per_operator_headers: bool,
|
||||
skip_dispatcher_op_registration: bool,
|
||||
update_aoti_c_shim: bool,
|
||||
) -> None:
|
||||
extra_cuda_headers = """\
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -2349,53 +2353,99 @@ def gen_source_files(
|
||||
else:
|
||||
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
|
||||
|
||||
structured_func_group_dict = {
|
||||
f"{func_group.functional.namespace}.{func_group.functional.func.name}": func_group
|
||||
for func_group in structured_native_functions
|
||||
}
|
||||
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
|
||||
fallbacks = dict()
|
||||
for func in native_functions:
|
||||
op_name = get_fallback_op_name(func)
|
||||
if op_name in inductor_fallback_ops:
|
||||
fallbacks[op_name] = (
|
||||
func,
|
||||
structured_func_group_dict.get(
|
||||
f"{func.namespace}.{func.func.name.name}", None
|
||||
),
|
||||
)
|
||||
fallback_native_functions = tuple(
|
||||
value for _, value in sorted(fallbacks.items())
|
||||
)
|
||||
|
||||
def get_header(
|
||||
f: NativeFunction,
|
||||
func: NativeFunction,
|
||||
func_group: Optional[NativeFunctionsGroup],
|
||||
) -> Optional[str]:
|
||||
backend_index = get_backend_index_for_aoti(
|
||||
f, dispatch_key, backend_indices
|
||||
func, func_group, dispatch_key, backend_indices
|
||||
)
|
||||
return (
|
||||
None
|
||||
if backend_index is None
|
||||
else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
|
||||
else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
|
||||
)
|
||||
|
||||
def headers_for_aoti() -> str:
|
||||
headers = []
|
||||
for g in grouped_native_functions:
|
||||
if isinstance(g, NativeFunctionsGroup):
|
||||
for f in g.functions():
|
||||
# some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
|
||||
header = get_header(f)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
else:
|
||||
header = get_header(g)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
for func, func_group in fallback_native_functions:
|
||||
header = get_header(func, func_group)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
return "\n".join(sorted(set(headers)))
|
||||
|
||||
extra_headers = (
|
||||
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
|
||||
)
|
||||
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.h",
|
||||
lambda: gen_aoti_c_shim(
|
||||
native_functions,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=True,
|
||||
includes="",
|
||||
),
|
||||
# header files were checked in for ABI-compatibility checking
|
||||
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
|
||||
new_header = gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=True,
|
||||
includes="",
|
||||
)
|
||||
if update_aoti_c_shim:
|
||||
aoti_fm.write(
|
||||
header_file_name,
|
||||
lambda: new_header,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
with open(
|
||||
os.path.join(aoti_fm.install_dir, header_file_name)
|
||||
) as old_file:
|
||||
old_header = old_file.read()
|
||||
assert (
|
||||
old_header == new_header
|
||||
), """
|
||||
|
||||
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
|
||||
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
|
||||
Only in a limited number of situations, this is allowed:
|
||||
|
||||
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
|
||||
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
|
||||
C shim header files.
|
||||
|
||||
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
|
||||
change in the AOTInductor land. In this case, you need to keep a manual copy of that existing
|
||||
fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
|
||||
number of that fallback op in the newly generated C shim files, and update the cpp wrapper
|
||||
codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
|
||||
|
||||
"""
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
|
||||
)
|
||||
|
||||
# cpp files are always generated on-the-fly
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.cpp",
|
||||
lambda: gen_aoti_c_shim(
|
||||
native_functions,
|
||||
fallback_native_functions,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=False,
|
||||
@ -2780,6 +2830,12 @@ def main() -> None:
|
||||
default=["headers", "sources", "declarations_yaml"],
|
||||
help="Generate only a subset of files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-aoti-c-shim",
|
||||
action="store_true",
|
||||
help="Update AOTInductor C shim after changing torchgen/aoti/fallback_ops.py. "
|
||||
"WARNING: Do not use this unless you are sure what you are doing!!!",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
@ -2898,6 +2954,7 @@ def main() -> None:
|
||||
force_schema_registration=options.force_schema_registration,
|
||||
per_operator_headers=options.per_operator_headers,
|
||||
skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
|
||||
update_aoti_c_shim=options.update_aoti_c_shim,
|
||||
)
|
||||
|
||||
if "headers" in options.generate:
|
||||
|
||||
Reference in New Issue
Block a user