Fix torchgen update-aoti-shim (#156323)

will remove the fill changes before landing and let Jane merge her changes!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156323
Approved by: https://github.com/janeyx99
This commit is contained in:
angelayi
2025-06-20 05:23:06 +00:00
committed by PyTorch MergeBot
parent f7a5ad6c29
commit c37ddcaefb
2 changed files with 50 additions and 44 deletions

View File

@ -2824,15 +2824,40 @@ def main() -> None:
from torchgen.model import dispatch_keys
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
# for them; this is the set
functions_keys = {
DispatchKey.CPU,
DispatchKey.CUDA,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.Meta,
DispatchKey.MTIA,
}
aoti_backends = {
DispatchKey.CPU,
DispatchKey.CUDA,
}
# TODO: stop generating CUDA kernels for non-CUDA builds
ignore_keys = set()
if not options.mps:
if options.mps or options.update_aoti_c_shim:
functions_keys.add(DispatchKey.MPS)
aoti_backends.add(DispatchKey.MPS)
else:
ignore_keys.add(DispatchKey.MPS)
if DispatchKey.MPS in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
if not options.xpu:
if options.xpu or options.update_aoti_c_shim:
functions_keys.add(DispatchKey.XPU)
aoti_backends.add(DispatchKey.XPU)
else:
ignore_keys.add(DispatchKey.XPU)
if DispatchKey.XPU in dispatch_keys:
@ -2844,6 +2869,13 @@ def main() -> None:
if DispatchKey.MTIA in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)]
if options.backend_whitelist:
dispatch_keys = [
k
for k in dispatch_keys
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
]
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
native_functions, backend_indices = (
@ -2893,45 +2925,6 @@ def main() -> None:
if options.xpu:
device_fms["xpu"] = make_file_manager(options=options)
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
# for them; this is the set
functions_keys = {
DispatchKey.CPU,
DispatchKey.CUDA,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.Meta,
DispatchKey.MTIA,
}
aoti_backends = {
DispatchKey.CPU,
DispatchKey.CUDA,
}
if options.update_aoti_c_shim:
# When updating the shim we want to update all devices, but when just
# building/checking the headers, we only want to check the devices that
# are available.
aoti_backends.add(DispatchKey.XPU)
aoti_backends.add(DispatchKey.MPS)
if options.mps:
functions_keys.add(DispatchKey.MPS)
aoti_backends.add(DispatchKey.MPS)
if options.xpu:
functions_keys.add(DispatchKey.XPU)
aoti_backends.add(DispatchKey.XPU)
if options.backend_whitelist:
dispatch_keys = [
k
for k in dispatch_keys
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
]
static_dispatch_idx: list[BackendIndex] = []
if options.static_dispatch_backend:
static_dispatch_idx = [

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import difflib
import os
import textwrap
from dataclasses import dataclass
@ -648,9 +649,20 @@ def gen_aoti_c_shim_files(
os.path.join(aoti_fm.install_dir, header_file_name)
) as old_file:
old_header = old_file.read()
assert old_header == new_header, """
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
if old_header != new_header:
diff = "\n".join(
difflib.unified_diff(
old_header.splitlines(),
new_header.splitlines(),
fromfile="expected",
tofile="actual",
lineterm="",
)
)
raise RuntimeError(f"""
The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
Only in a limited number of situations, this is allowed:
@ -664,7 +676,8 @@ torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aot
update the C shim header files by creating different versions of the fallback op. See
https://github.com/pytorch/pytorch/pull/154848 as an example.
"""
{diff}
""")
except FileNotFoundError:
print(
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"