Enable generating generic c_shim that doesn't bypass dispatcher (#158974)

Adds `c_shim_aten.{h/cpp}` and uses this for `fill_`

This is the generated `c_shim_aten.cpp` for reference

```cpp

// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See 7e86a7c015/torchgen/gen.py (L2424-L2436) for details

// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
#include <ATen/ops/fill.h>

#endif // AT_PER_OPERATOR_HEADERS

using namespace torch::aot_inductor;

AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value) {
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
        at::fill_(
            *tensor_handle_to_tensor_pointer(self), value
        );
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158974
Approved by: https://github.com/albanD, https://github.com/janeyx99
This commit is contained in:
Mikayla Gawarecki
2025-07-25 08:38:52 -07:00
committed by PyTorch MergeBot
parent bfe6765d6b
commit e65ab9a868
9 changed files with 143 additions and 24 deletions

View File

@ -175,3 +175,12 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
"aten.view.dtype": {},
"aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
}
# `python torchgen/gen.py --update-aoti-c-shim` will automatically generate
# c_shim_aten.{h/cpp} based on the list below.
# Operators in this list are intended to be used in torch/csrc/stable/ops.h
# Unlike other c_shims, operators in this file do not bypass the dispatcher.
# The same BC rules apply as inductor_fallback_ops.
aten_shimified_ops: dict[str, dict[str, list[str]]] = {
"aten.fill_.Scalar": {},
}

View File

@ -95,6 +95,7 @@ from torchgen.yaml_utils import YamlDumper, YamlLoader
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Optional
T = TypeVar("T")
@ -2215,7 +2216,7 @@ def gen_source_files(
per_operator_headers: bool,
skip_dispatcher_op_registration: bool,
update_aoti_c_shim: bool,
aoti_backends: set[DispatchKey],
aoti_backends: set[Optional[DispatchKey]],
extend_aoti_c_shim: bool,
) -> None:
extra_cuda_headers = """\
@ -2840,6 +2841,9 @@ def main() -> None:
aoti_backends = {
DispatchKey.CPU,
DispatchKey.CUDA,
# None will generate the aten shim based on aten_shimified_ops
# which does not bypass the dispatcher
None,
}
# TODO: stop generating CUDA kernels for non-CUDA builds

View File

@ -6,7 +6,7 @@ import textwrap
from dataclasses import dataclass
from typing import TYPE_CHECKING
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.aoti.fallback_ops import aten_shimified_ops, inductor_fallback_ops
from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
@ -30,6 +30,7 @@ from torchgen.utils import FileManager, mapMaybe
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Optional
base_type_to_c_type = {
@ -391,21 +392,28 @@ def gen_static_dispatch_backend_call_signature(
def gen_static_dispatch_backend_call(
f: NativeFunction,
backend_index: BackendIndex,
backend_index: Optional[BackendIndex] = None,
) -> str:
sig = DispatcherSignature.from_schema(f.func)
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
if backend_index is None:
return f"at::{cpp_sig.name()}"
else:
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
def get_backend_index_for_aoti(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
dispatch_key: Optional[DispatchKey],
backend_indices: dict[DispatchKey, BackendIndex],
extend_aoti_c_shim: bool,
) -> BackendIndex | None:
backend_index = None
if dispatch_key is None:
return backend_index
if backend_indices[dispatch_key].has_kernel(func) or (
func.structured_delegate is not None
and func.structured_delegate in func_group_mapping
@ -439,18 +447,19 @@ def get_backend_index_for_aoti(
def get_header_for_aoti(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
dispatch_key: Optional[DispatchKey],
backend_indices: dict[DispatchKey, BackendIndex],
extend_aoti_c_shim: bool,
) -> str | None:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
)
return (
None
if backend_index is None
else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
)
if backend_index is None:
if dispatch_key is None:
return f"#include <ATen/ops/{func.root_name}.h>"
return None
return f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
def get_fallback_op_name(func: NativeFunction) -> str:
@ -465,7 +474,7 @@ def gen_c_shim(
func: NativeFunction,
version_info: dict[str, list[str]],
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
dispatch_key: Optional[DispatchKey],
backend_indices: dict[DispatchKey, BackendIndex],
header: bool,
extend_aoti_c_shim: bool,
@ -473,11 +482,11 @@ def gen_c_shim(
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
)
if backend_index is None:
if backend_index is None and dispatch_key is not None:
return None
schema = func.func
device = dispatch_key.lower()
device = "aten" if dispatch_key is None else dispatch_key.lower()
backend_call = gen_static_dispatch_backend_call(
func,
backend_index,
@ -503,7 +512,7 @@ def gen_c_shim(
class ShimGenerator:
inductor_fallback_ops: dict[str, dict[str, list[str]]]
func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
dispatch_key: DispatchKey
dispatch_key: Optional[DispatchKey]
backend_indices: dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp
extend_aoti_c_shim: bool
@ -530,7 +539,7 @@ def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction],
inductor_fallback_ops: dict[str, dict[str, list[str]]],
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
dispatch_key: Optional[DispatchKey],
backend_indices: dict[DispatchKey, BackendIndex],
header: bool,
extend_aoti_c_shim: bool,
@ -551,7 +560,19 @@ def gen_aoti_c_shim(
)
)
)
device = dispatch_key.lower()
device = "aten" if dispatch_key is None else dispatch_key.lower()
include_device_functions = (
"#include <ATen/Functions.h>"
if dispatch_key is None
else f"#include <ATen/{str(dispatch_key)}Functions.h>"
)
aten_warning = (
(
"\n\n// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py\n"
)
if dispatch_key is None
else ""
)
warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
@ -560,6 +581,7 @@ def gen_aoti_c_shim(
if header:
return (
warning
+ aten_warning
+ textwrap.dedent("""
#pragma once
@ -582,13 +604,14 @@ def gen_aoti_c_shim(
else:
return (
warning
+ aten_warning
+ textwrap.dedent(f"""
#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
{include_device_functions}
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
@ -607,7 +630,7 @@ def gen_aoti_c_shim(
def gen_aoti_c_shim_files(
aoti_fm: FileManager,
aoti_backends: set[DispatchKey],
aoti_backends: set[Optional[DispatchKey]],
native_functions: Sequence[NativeFunction],
backend_indices: dict[DispatchKey, BackendIndex],
structured_native_functions: Sequence[NativeFunctionsGroup],
@ -623,20 +646,27 @@ def gen_aoti_c_shim_files(
break
for dispatch_key in aoti_backends:
# Use aten_shimified_ops for the aten backend, inductor_fallback_ops for others
fallback_ops_dict = (
aten_shimified_ops if dispatch_key is None else inductor_fallback_ops
)
fallbacks = {}
for func in native_functions:
op_name = get_fallback_op_name(func)
if op_name in inductor_fallback_ops:
if op_name in fallback_ops_dict:
fallbacks[op_name] = func
fallback_native_functions = tuple(
value for _, value in sorted(fallbacks.items())
)
# Use "aten" as the device name when dispatch_key is None
device_name = "aten" if dispatch_key is None else dispatch_key.lower()
# header files were checked in for ABI-compatibility checking
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
header_file_name = f"c_shim_{device_name}.h"
new_header = gen_aoti_c_shim(
fallback_native_functions,
inductor_fallback_ops,
fallback_ops_dict,
structured_func_group_dict,
dispatch_key,
backend_indices,
@ -704,10 +734,14 @@ https://github.com/pytorch/pytorch/pull/154848 as an example.
headers.append(header)
return "\n".join(sorted(set(headers)))
extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
extra_headers = (
extra_cuda_headers
if dispatch_key is not None and is_cuda_dispatch_key(dispatch_key)
else ""
)
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.cpp",
f"c_shim_{device_name}.cpp",
lambda: gen_aoti_c_shim(
fallback_native_functions,
inductor_fallback_ops,