mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable generating generic c_shim that doesn't bypass dispatcher (#158974)
Adds `c_shim_aten.{h/cpp}` and uses it for `fill_`
This is the generated `c_shim_aten.cpp` for reference
```cpp
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See 7e86a7c015/torchgen/gen.py (L2424-L2436) for details
// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
#include <ATen/ops/fill.h>
#endif // AT_PER_OPERATOR_HEADERS
using namespace torch::aot_inductor;
AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::fill_(
*tensor_handle_to_tensor_pointer(self), value
);
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158974
Approved by: https://github.com/albanD, https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
bfe6765d6b
commit
e65ab9a868
@ -175,3 +175,12 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
|
||||
"aten.view.dtype": {},
|
||||
"aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
|
||||
}
|
||||
|
||||
# `python torchgen/gen.py --update-aoti-c-shim` will automatically generate
|
||||
# c_shim_aten.{h/cpp} based on the list below.
|
||||
# Operators in this list are intended to be used in torch/csrc/stable/ops.h
|
||||
# Unlike other c_shims, operators in this file do not bypass the dispatcher.
|
||||
# The same BC rules apply as inductor_fallback_ops.
|
||||
aten_shimified_ops: dict[str, dict[str, list[str]]] = {
|
||||
"aten.fill_.Scalar": {},
|
||||
}
|
||||
|
@ -95,6 +95,7 @@ from torchgen.yaml_utils import YamlDumper, YamlLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -2215,7 +2216,7 @@ def gen_source_files(
|
||||
per_operator_headers: bool,
|
||||
skip_dispatcher_op_registration: bool,
|
||||
update_aoti_c_shim: bool,
|
||||
aoti_backends: set[DispatchKey],
|
||||
aoti_backends: set[Optional[DispatchKey]],
|
||||
extend_aoti_c_shim: bool,
|
||||
) -> None:
|
||||
extra_cuda_headers = """\
|
||||
@ -2840,6 +2841,9 @@ def main() -> None:
|
||||
aoti_backends = {
|
||||
DispatchKey.CPU,
|
||||
DispatchKey.CUDA,
|
||||
# None will generate the aten shim based on aten_shimified_ops
|
||||
# which does not bypass the dispatcher
|
||||
None,
|
||||
}
|
||||
|
||||
# TODO: stop generating CUDA kernels for non-CUDA builds
|
||||
|
@ -6,7 +6,7 @@ import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.aoti.fallback_ops import inductor_fallback_ops
|
||||
from torchgen.aoti.fallback_ops import aten_shimified_ops, inductor_fallback_ops
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
|
||||
from torchgen.context import method_with_native_function
|
||||
@ -30,6 +30,7 @@ from torchgen.utils import FileManager, mapMaybe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
|
||||
base_type_to_c_type = {
|
||||
@ -391,21 +392,28 @@ def gen_static_dispatch_backend_call_signature(
|
||||
|
||||
def gen_static_dispatch_backend_call(
|
||||
f: NativeFunction,
|
||||
backend_index: BackendIndex,
|
||||
backend_index: Optional[BackendIndex] = None,
|
||||
) -> str:
|
||||
sig = DispatcherSignature.from_schema(f.func)
|
||||
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
|
||||
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
|
||||
if backend_index is None:
|
||||
return f"at::{cpp_sig.name()}"
|
||||
else:
|
||||
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
|
||||
|
||||
|
||||
def get_backend_index_for_aoti(
|
||||
func: NativeFunction,
|
||||
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
|
||||
dispatch_key: DispatchKey,
|
||||
dispatch_key: Optional[DispatchKey],
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
extend_aoti_c_shim: bool,
|
||||
) -> BackendIndex | None:
|
||||
backend_index = None
|
||||
|
||||
if dispatch_key is None:
|
||||
return backend_index
|
||||
|
||||
if backend_indices[dispatch_key].has_kernel(func) or (
|
||||
func.structured_delegate is not None
|
||||
and func.structured_delegate in func_group_mapping
|
||||
@ -439,18 +447,19 @@ def get_backend_index_for_aoti(
|
||||
def get_header_for_aoti(
|
||||
func: NativeFunction,
|
||||
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
|
||||
dispatch_key: DispatchKey,
|
||||
dispatch_key: Optional[DispatchKey],
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
extend_aoti_c_shim: bool,
|
||||
) -> str | None:
|
||||
backend_index = get_backend_index_for_aoti(
|
||||
func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
|
||||
)
|
||||
return (
|
||||
None
|
||||
if backend_index is None
|
||||
else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
|
||||
)
|
||||
if backend_index is None:
|
||||
if dispatch_key is None:
|
||||
return f"#include <ATen/ops/{func.root_name}.h>"
|
||||
return None
|
||||
|
||||
return f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
|
||||
|
||||
|
||||
def get_fallback_op_name(func: NativeFunction) -> str:
|
||||
@ -465,7 +474,7 @@ def gen_c_shim(
|
||||
func: NativeFunction,
|
||||
version_info: dict[str, list[str]],
|
||||
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
|
||||
dispatch_key: DispatchKey,
|
||||
dispatch_key: Optional[DispatchKey],
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
header: bool,
|
||||
extend_aoti_c_shim: bool,
|
||||
@ -473,11 +482,11 @@ def gen_c_shim(
|
||||
backend_index = get_backend_index_for_aoti(
|
||||
func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
|
||||
)
|
||||
if backend_index is None:
|
||||
if backend_index is None and dispatch_key is not None:
|
||||
return None
|
||||
|
||||
schema = func.func
|
||||
device = dispatch_key.lower()
|
||||
device = "aten" if dispatch_key is None else dispatch_key.lower()
|
||||
backend_call = gen_static_dispatch_backend_call(
|
||||
func,
|
||||
backend_index,
|
||||
@ -503,7 +512,7 @@ def gen_c_shim(
|
||||
class ShimGenerator:
|
||||
inductor_fallback_ops: dict[str, dict[str, list[str]]]
|
||||
func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
|
||||
dispatch_key: DispatchKey
|
||||
dispatch_key: Optional[DispatchKey]
|
||||
backend_indices: dict[DispatchKey, BackendIndex]
|
||||
header: bool # True to generate .h and False to generate .cpp
|
||||
extend_aoti_c_shim: bool
|
||||
@ -530,7 +539,7 @@ def gen_aoti_c_shim(
|
||||
native_functions: Sequence[NativeFunction],
|
||||
inductor_fallback_ops: dict[str, dict[str, list[str]]],
|
||||
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
|
||||
dispatch_key: DispatchKey,
|
||||
dispatch_key: Optional[DispatchKey],
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
header: bool,
|
||||
extend_aoti_c_shim: bool,
|
||||
@ -551,7 +560,19 @@ def gen_aoti_c_shim(
|
||||
)
|
||||
)
|
||||
)
|
||||
device = dispatch_key.lower()
|
||||
device = "aten" if dispatch_key is None else dispatch_key.lower()
|
||||
include_device_functions = (
|
||||
"#include <ATen/Functions.h>"
|
||||
if dispatch_key is None
|
||||
else f"#include <ATen/{str(dispatch_key)}Functions.h>"
|
||||
)
|
||||
aten_warning = (
|
||||
(
|
||||
"\n\n// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py\n"
|
||||
)
|
||||
if dispatch_key is None
|
||||
else ""
|
||||
)
|
||||
warning = """
|
||||
|
||||
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
|
||||
@ -560,6 +581,7 @@ def gen_aoti_c_shim(
|
||||
if header:
|
||||
return (
|
||||
warning
|
||||
+ aten_warning
|
||||
+ textwrap.dedent("""
|
||||
|
||||
#pragma once
|
||||
@ -582,13 +604,14 @@ def gen_aoti_c_shim(
|
||||
else:
|
||||
return (
|
||||
warning
|
||||
+ aten_warning
|
||||
+ textwrap.dedent(f"""
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/{str(dispatch_key)}Functions.h>
|
||||
{include_device_functions}
|
||||
#include <ATen/CompositeExplicitAutogradFunctions.h>
|
||||
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
|
||||
#include <ATen/CompositeImplicitAutogradFunctions.h>
|
||||
@ -607,7 +630,7 @@ def gen_aoti_c_shim(
|
||||
|
||||
def gen_aoti_c_shim_files(
|
||||
aoti_fm: FileManager,
|
||||
aoti_backends: set[DispatchKey],
|
||||
aoti_backends: set[Optional[DispatchKey]],
|
||||
native_functions: Sequence[NativeFunction],
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
structured_native_functions: Sequence[NativeFunctionsGroup],
|
||||
@ -623,20 +646,27 @@ def gen_aoti_c_shim_files(
|
||||
break
|
||||
|
||||
for dispatch_key in aoti_backends:
|
||||
# Use aten_shimified_ops for the aten backend, inductor_fallback_ops for others
|
||||
fallback_ops_dict = (
|
||||
aten_shimified_ops if dispatch_key is None else inductor_fallback_ops
|
||||
)
|
||||
fallbacks = {}
|
||||
for func in native_functions:
|
||||
op_name = get_fallback_op_name(func)
|
||||
if op_name in inductor_fallback_ops:
|
||||
if op_name in fallback_ops_dict:
|
||||
fallbacks[op_name] = func
|
||||
fallback_native_functions = tuple(
|
||||
value for _, value in sorted(fallbacks.items())
|
||||
)
|
||||
|
||||
# Use "aten" as the device name when dispatch_key is None
|
||||
device_name = "aten" if dispatch_key is None else dispatch_key.lower()
|
||||
|
||||
# header files were checked in for ABI-compatibility checking
|
||||
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
|
||||
header_file_name = f"c_shim_{device_name}.h"
|
||||
new_header = gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
inductor_fallback_ops,
|
||||
fallback_ops_dict,
|
||||
structured_func_group_dict,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
@ -704,10 +734,14 @@ https://github.com/pytorch/pytorch/pull/154848 as an example.
|
||||
headers.append(header)
|
||||
return "\n".join(sorted(set(headers)))
|
||||
|
||||
extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
|
||||
extra_headers = (
|
||||
extra_cuda_headers
|
||||
if dispatch_key is not None and is_cuda_dispatch_key(dispatch_key)
|
||||
else ""
|
||||
)
|
||||
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.cpp",
|
||||
f"c_shim_{device_name}.cpp",
|
||||
lambda: gen_aoti_c_shim(
|
||||
fallback_native_functions,
|
||||
inductor_fallback_ops,
|
||||
|
Reference in New Issue
Block a user