Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Fix torchgen update-aoti-shim (#156323)
Will remove the fill changes before landing and let Jane merge her changes!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156323
Approved by: https://github.com/janeyx99
Commit c37ddcaefb (parent f7a5ad6c29), committed by PyTorch MergeBot.
@@ -2824,15 +2824,40 @@ def main() -> None:
     from torchgen.model import dispatch_keys

+    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
+    # for them; this is the set
+    functions_keys = {
+        DispatchKey.CPU,
+        DispatchKey.CUDA,
+        DispatchKey.CompositeImplicitAutograd,
+        DispatchKey.CompositeImplicitAutogradNestedTensor,
+        DispatchKey.CompositeExplicitAutograd,
+        DispatchKey.CompositeExplicitAutogradNonFunctional,
+        DispatchKey.Meta,
+        DispatchKey.MTIA,
+    }
+
+    aoti_backends = {
+        DispatchKey.CPU,
+        DispatchKey.CUDA,
+    }
+
     # TODO: stop generating CUDA kernels for non-CUDA builds
     ignore_keys = set()
-    if not options.mps:
+
+    if options.mps or options.update_aoti_c_shim:
+        functions_keys.add(DispatchKey.MPS)
+        aoti_backends.add(DispatchKey.MPS)
+    else:
         ignore_keys.add(DispatchKey.MPS)

         if DispatchKey.MPS in dispatch_keys:
             del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]

-    if not options.xpu:
+    if options.xpu or options.update_aoti_c_shim:
+        functions_keys.add(DispatchKey.XPU)
+        aoti_backends.add(DispatchKey.XPU)
+    else:
         ignore_keys.add(DispatchKey.XPU)

         if DispatchKey.XPU in dispatch_keys:
@@ -2844,6 +2869,13 @@ def main() -> None:
         if DispatchKey.MTIA in dispatch_keys:
             del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)]

+    if options.backend_whitelist:
+        dispatch_keys = [
+            k
+            for k in dispatch_keys
+            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
+        ]
+
     parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
     valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
     native_functions, backend_indices = (
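The backend_whitelist handling added in this hunk is a plain comprehension that keeps generic keys plus whatever the whitelist names. A toy illustration follows; the key strings and the is_generic stand-in are invented for this example and are not torchgen's actual is_generic_dispatch_key.

# Toy version of the whitelist filter; "generic" keys always survive.
def filter_keys(dispatch_keys: list[str], whitelist: set[str]) -> list[str]:
    def is_generic(key: str) -> bool:
        # Stand-in heuristic: treat Composite*/Meta-style keys as generic.
        return key.startswith("Composite") or key == "Meta"

    return [k for k in dispatch_keys if is_generic(k) or k in whitelist]


print(filter_keys(["CPU", "CUDA", "Meta", "CompositeExplicitAutograd"], {"CPU"}))
# -> ['CPU', 'Meta', 'CompositeExplicitAutograd']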
@@ -2893,45 +2925,6 @@ def main() -> None:
     if options.xpu:
         device_fms["xpu"] = make_file_manager(options=options)

-    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
-    # for them; this is the set
-    functions_keys = {
-        DispatchKey.CPU,
-        DispatchKey.CUDA,
-        DispatchKey.CompositeImplicitAutograd,
-        DispatchKey.CompositeImplicitAutogradNestedTensor,
-        DispatchKey.CompositeExplicitAutograd,
-        DispatchKey.CompositeExplicitAutogradNonFunctional,
-        DispatchKey.Meta,
-        DispatchKey.MTIA,
-    }
-
-    aoti_backends = {
-        DispatchKey.CPU,
-        DispatchKey.CUDA,
-    }
-    if options.update_aoti_c_shim:
-        # When updating the shim we want to update all devices, but when just
-        # building/checking the headers, we only want to check the devices that
-        # are available.
-        aoti_backends.add(DispatchKey.XPU)
-        aoti_backends.add(DispatchKey.MPS)
-
-    if options.mps:
-        functions_keys.add(DispatchKey.MPS)
-        aoti_backends.add(DispatchKey.MPS)
-
-    if options.xpu:
-        functions_keys.add(DispatchKey.XPU)
-        aoti_backends.add(DispatchKey.XPU)
-
-    if options.backend_whitelist:
-        dispatch_keys = [
-            k
-            for k in dispatch_keys
-            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
-        ]
-
     static_dispatch_idx: list[BackendIndex] = []
     if options.static_dispatch_backend:
         static_dispatch_idx = [
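Taken together, the hunks above move the functions_keys / aoti_backends setup ahead of parse_native_yaml, so that when the shim headers are regenerated, MPS and XPU are included and are not dropped into ignore_keys. Below is a minimal, self-contained sketch of that selection step; DispatchKey, Options, and decide_keys are simplified stand-ins invented here for illustration, not torchgen's real types.

# Minimal sketch of the new key-selection logic (hypothetical stand-in types).
from dataclasses import dataclass
from enum import Enum, auto


class DispatchKey(Enum):  # stand-in for torchgen.model.DispatchKey
    CPU = auto()
    CUDA = auto()
    MPS = auto()
    XPU = auto()


@dataclass
class Options:  # stand-in for torchgen's parsed CLI options
    mps: bool = False
    xpu: bool = False
    update_aoti_c_shim: bool = False


def decide_keys(options: Options):
    functions_keys = {DispatchKey.CPU, DispatchKey.CUDA}
    aoti_backends = {DispatchKey.CPU, DispatchKey.CUDA}
    ignore_keys = set()

    # MPS/XPU are included when the backend is enabled OR when the shim
    # headers are being regenerated, mirroring the diff above.
    if options.mps or options.update_aoti_c_shim:
        functions_keys.add(DispatchKey.MPS)
        aoti_backends.add(DispatchKey.MPS)
    else:
        ignore_keys.add(DispatchKey.MPS)

    if options.xpu or options.update_aoti_c_shim:
        functions_keys.add(DispatchKey.XPU)
        aoti_backends.add(DispatchKey.XPU)
    else:
        ignore_keys.add(DispatchKey.XPU)

    return functions_keys, aoti_backends, ignore_keys


if __name__ == "__main__":
    _, aoti, ignore = decide_keys(Options(update_aoti_c_shim=True))
    # In the shim-update path, MPS and XPU are shimmed and nothing is ignored,
    # even though neither the MPS nor the XPU option was passed.
    assert DispatchKey.MPS in aoti and DispatchKey.XPU in aoti
    assert not ignore

The consequence visible in the second hunk is that ignore_keys stays empty when updating the shim, so parse_native_yaml still parses the MPS/XPU entries the shim headers need.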
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import difflib
 import os
 import textwrap
 from dataclasses import dataclass
@@ -648,9 +649,20 @@ def gen_aoti_c_shim_files(
                 os.path.join(aoti_fm.install_dir, header_file_name)
             ) as old_file:
                 old_header = old_file.read()
-                assert old_header == new_header, """
-
-WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
+                if old_header != new_header:
+                    diff = "\n".join(
+                        difflib.unified_diff(
+                            old_header.splitlines(),
+                            new_header.splitlines(),
+                            fromfile="expected",
+                            tofile="actual",
+                            lineterm="",
+                        )
+                    )
+
+                    raise RuntimeError(f"""
+The generated AOTInductor C shim header files have unexpectedly changed. This
 indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
 Only in a limited number of situations, this is allowed:

@@ -664,7 +676,8 @@ torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aot
 update the C shim header files by creating different versions of the fallback op. See
 https://github.com/pytorch/pytorch/pull/154848 as an example.

-"""
+{diff}
+""")
         except FileNotFoundError:
             print(
                 f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
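As a standalone illustration of the check introduced above: when the on-disk shim header and the freshly generated one differ, the generator now raises with a unified diff embedded in the message instead of failing a bare assert. A simplified, self-contained sketch follows; check_header and the example header strings are hypothetical, not torchgen API.

# Simplified sketch of the header-drift check, using plain strings as headers.
import difflib


def check_header(old_header: str, new_header: str) -> None:
    # Hypothetical helper mirroring the behavior added in the diff above.
    if old_header != new_header:
        diff = "\n".join(
            difflib.unified_diff(
                old_header.splitlines(),
                new_header.splitlines(),
                fromfile="expected",
                tofile="actual",
                lineterm="",
            )
        )
        raise RuntimeError(
            "Generated AOTInductor C shim header changed unexpectedly:\n" + diff
        )


if __name__ == "__main__":
    try:
        check_header("void f(int x);\n", "void f(int x, int y);\n")
    except RuntimeError as err:
        # The message pinpoints the exact lines that drifted.
        print(err)

Compared with the previous assert, the RuntimeError still fires under `python -O`, and the embedded diff makes the suspected ABI breakage visible directly in the failure message.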