Enable generating generic c_shim that doesn't bypass dispatcher (#158974)
Adds `c_shim_aten.{h/cpp}` and uses it for `fill_`.
This is the generated `c_shim_aten.cpp`, for reference:
```cpp
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details
// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
#include <ATen/ops/fill.h>
#endif // AT_PER_OPERATOR_HEADERS
using namespace torch::aot_inductor;
AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::fill_(
        *tensor_handle_to_tensor_pointer(self), value
    );
  });
}
```
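For contrast with the per-backend shims (e.g. `c_shim_cpu`, whose entry points statically dispatch to a backend-specific function), here is a minimal sketch of calling the new entry point through the C ABI; the caller function and its name are illustrative, not from the PR:

```cpp
// Minimal caller sketch. Assumes a valid AtenTensorHandle `t`; AOTITorchError
// and AOTI_TORCH_SUCCESS come from torch/csrc/inductor/aoti_torch/c/shim.h,
// which the generated header includes.
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>

bool fill_with_two(AtenTensorHandle t) {
  // Internally calls at::fill_, so the call still goes through the
  // dispatcher (autograd, functionalization, and backend fallbacks apply).
  AOTITorchError err = aoti_torch_aten_fill__Scalar(t, /*value=*/2.0);
  return err == AOTI_TORCH_SUCCESS;
}
```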
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158974
Approved by: https://github.com/albanD, https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: bfe6765d6b
Commit: e65ab9a868
```diff
@@ -267,6 +267,7 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
     "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_0.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_1.cpp"
     "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp"
+    "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_aten.cpp"
   )
   if(BUILD_LAZY_TS_BACKEND)
     list(APPEND GENERATED_CXX_TORCH
```
```diff
@@ -278,14 +278,29 @@ void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
   stack[0] = from(res);
 }
 
+Tensor fill_infinity(Tensor t) {
+  auto value = std::numeric_limits<float>::infinity();
+  return fill_(t, value);
+}
+
+void boxed_fill_infinity(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = fill_infinity(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
+  m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_transpose", &boxed_my_transpose);
   m.impl("my_empty_like", &boxed_empty_like);
+  m.impl("fill_infinity", &boxed_fill_infinity);
 }
```
```diff
@@ -152,3 +152,15 @@ def my_zero_(t) -> Tensor:
     Returns: my_zero_(t)
     """
     return torch.ops.libtorch_agnostic.my_zero_.default(t)
+
+
+def fill_infinity(t) -> Tensor:
+    """
+    Fills the tensor with inf.
+
+    Args:
+        t: Tensor to fill
+
+    Returns: The modified tensor (same as input)
+    """
+    return torch.ops.libtorch_agnostic.fill_infinity.default(t)
```
```diff
@@ -1,5 +1,6 @@
 # Owner(s): ["module: cpp"]
 
+import math
 from pathlib import Path
 
 import torch
```
```diff
@@ -207,6 +208,16 @@ if not IS_WINDOWS:
             self.assertEqual(id(out), id(t))
             self.assertEqual(out, torch.zeros_like(t))
 
+        def test_fill_infinity(self, device):
+            import libtorch_agnostic
+
+            t = torch.rand(3, 4, device=device)
+            out = libtorch_agnostic.ops.fill_infinity(t)
+
+            self.assertEqual(id(out), id(t))
+            expected = torch.full_like(t, math.inf)
+            self.assertEqual(out, expected)
+
     instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":
```
torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h (new file, 21 lines)
```diff
@@ -0,0 +1,21 @@
+
+// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
+// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details
+
+// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py
+
+#pragma once
+
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
```
torch/csrc/stable/ops.h:

```diff
@@ -5,6 +5,8 @@
 #include <cstdint>
 #include <optional>
 
+#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
+
 using torch::stable::Tensor;
 
 // We expect this to be the stable version of the empty_like op that takes in
@@ -24,6 +26,16 @@ inline Tensor empty_like(const Tensor& self) {
   return to<Tensor>(stack[0]);
 }
 
+// We expect this to be the stable version of the fill_.Scalar op
+// with identical semantics to the existing fill_.Scalar op.
+// A subtle nuance is that `value` is typed as a double, but it is
+// actually a Scalar. This is because Scalar.h is currently not
+// header-only.
+inline Tensor fill_(const Tensor& self, double value) {
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value));
+  return self;
+}
+
 // We expect this to be the stable version of the transpose op with identical
 // semantics to the existing transpose.int op.
 inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
```
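The double-vs-Scalar nuance in the comment above is visible at call sites. A small sketch modeled on the `fill_infinity` example from this PR's test extension; the include path and the unqualified `fill_` call are assumptions based on that file:

```cpp
// Sketch, not from the PR: filling with an integral value goes through a
// double at the C shim boundary and is rebuilt into a Scalar by at::fill_.
#include <torch/csrc/stable/ops.h>

using torch::stable::Tensor;

Tensor fill_with_answer(Tensor t) {
  return fill_(t, 42);  // int -> double: exact here; very large int64 values
                        // would lose precision at this boundary
}
```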
torchgen/aoti/fallback_ops.py:

```diff
@@ -175,3 +175,12 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
     "aten.view.dtype": {},
     "aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
 }
+
+# `python torchgen/gen.py --update-aoti-c-shim` will automatically generate
+# c_shim_aten.{h/cpp} based on the list below.
+# Operators in this list are intended to be used in torch/csrc/stable/ops.h
+# Unlike other c_shims, operators in this file do not bypass the dispatcher.
+# The same BC rules apply as inductor_fallback_ops.
+aten_shimified_ops: dict[str, dict[str, list[str]]] = {
+    "aten.fill_.Scalar": {},
+}
```
torchgen/gen.py:

```diff
@@ -95,6 +95,7 @@ from torchgen.yaml_utils import YamlDumper, YamlLoader
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
+    from typing import Optional
 
 
 T = TypeVar("T")
@@ -2215,7 +2216,7 @@ def gen_source_files(
     per_operator_headers: bool,
     skip_dispatcher_op_registration: bool,
     update_aoti_c_shim: bool,
-    aoti_backends: set[DispatchKey],
+    aoti_backends: set[Optional[DispatchKey]],
     extend_aoti_c_shim: bool,
 ) -> None:
     extra_cuda_headers = """\
@@ -2840,6 +2841,9 @@ def main() -> None:
     aoti_backends = {
         DispatchKey.CPU,
         DispatchKey.CUDA,
+        # None will generate the aten shim based on aten_shimified_ops
+        # which does not bypass the dispatcher
+        None,
     }
 
     # TODO: stop generating CUDA kernels for non-CUDA builds
```
torchgen/gen_aoti_c_shim.py:

```diff
@@ -6,7 +6,7 @@ import textwrap
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
-from torchgen.aoti.fallback_ops import inductor_fallback_ops
+from torchgen.aoti.fallback_ops import aten_shimified_ops, inductor_fallback_ops
 from torchgen.api.types import DispatcherSignature
 from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
 from torchgen.context import method_with_native_function
@@ -30,6 +30,7 @@ from torchgen.utils import FileManager, mapMaybe
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
+    from typing import Optional
 
 
 base_type_to_c_type = {
```
```diff
@@ -391,21 +392,28 @@ def gen_static_dispatch_backend_call_signature(
 
 def gen_static_dispatch_backend_call(
     f: NativeFunction,
-    backend_index: BackendIndex,
+    backend_index: Optional[BackendIndex] = None,
 ) -> str:
     sig = DispatcherSignature.from_schema(f.func)
     cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
-    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
+    if backend_index is None:
+        return f"at::{cpp_sig.name()}"
+    else:
+        return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
 
 
 def get_backend_index_for_aoti(
     func: NativeFunction,
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
-    dispatch_key: DispatchKey,
+    dispatch_key: Optional[DispatchKey],
     backend_indices: dict[DispatchKey, BackendIndex],
     extend_aoti_c_shim: bool,
 ) -> BackendIndex | None:
     backend_index = None
+
+    if dispatch_key is None:
+        return backend_index
+
     if backend_indices[dispatch_key].has_kernel(func) or (
         func.structured_delegate is not None
         and func.structured_delegate in func_group_mapping
```
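This is the crux of the change: with no `BackendIndex`, the generated shim calls the plain `at::` function, which routes through the dispatcher. A hedged illustration of the two call forms, using `add` since its CPU-namespaced variant is known to exist (whether a given op has one depends on where its kernel is registered):

```cpp
// Illustration only (assumes an ATen build where these generated headers are
// available); not code from this PR.
#include <ATen/ATen.h>          // at::add - what the backend_index-is-None
                                // branch emits; goes through the dispatcher
#include <ATen/CPUFunctions.h>  // at::cpu::add - what the per-backend branch
                                // emits; statically dispatched, bypasses it

at::Tensor both_paths(const at::Tensor& a, const at::Tensor& b) {
  at::Tensor via_dispatcher = at::add(a, b);   // dispatcher-visiting call
  at::Tensor via_static = at::cpu::add(a, b);  // static dispatch to CPU kernel
  return at::add(via_dispatcher, via_static);
}
```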
```diff
@@ -439,18 +447,19 @@
 def get_header_for_aoti(
     func: NativeFunction,
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
-    dispatch_key: DispatchKey,
+    dispatch_key: Optional[DispatchKey],
     backend_indices: dict[DispatchKey, BackendIndex],
     extend_aoti_c_shim: bool,
 ) -> str | None:
     backend_index = get_backend_index_for_aoti(
         func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
     )
-    return (
-        None
-        if backend_index is None
-        else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
-    )
+    if backend_index is None:
+        if dispatch_key is None:
+            return f"#include <ATen/ops/{func.root_name}.h>"
+        return None
+
+    return f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
 
 
 def get_fallback_op_name(func: NativeFunction) -> str:
```
```diff
@@ -465,7 +474,7 @@ def gen_c_shim(
     func: NativeFunction,
     version_info: dict[str, list[str]],
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
-    dispatch_key: DispatchKey,
+    dispatch_key: Optional[DispatchKey],
     backend_indices: dict[DispatchKey, BackendIndex],
     header: bool,
     extend_aoti_c_shim: bool,
@@ -473,11 +482,11 @@
     backend_index = get_backend_index_for_aoti(
         func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
     )
-    if backend_index is None:
+    if backend_index is None and dispatch_key is not None:
         return None
 
     schema = func.func
-    device = dispatch_key.lower()
+    device = "aten" if dispatch_key is None else dispatch_key.lower()
     backend_call = gen_static_dispatch_backend_call(
         func,
         backend_index,
```
```diff
@@ -503,7 +512,7 @@
 class ShimGenerator:
     inductor_fallback_ops: dict[str, dict[str, list[str]]]
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
-    dispatch_key: DispatchKey
+    dispatch_key: Optional[DispatchKey]
     backend_indices: dict[DispatchKey, BackendIndex]
     header: bool  # True to generate .h and False to generate .cpp
     extend_aoti_c_shim: bool
```
```diff
@@ -530,7 +539,7 @@ def gen_aoti_c_shim(
     native_functions: Sequence[NativeFunction],
     inductor_fallback_ops: dict[str, dict[str, list[str]]],
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
-    dispatch_key: DispatchKey,
+    dispatch_key: Optional[DispatchKey],
     backend_indices: dict[DispatchKey, BackendIndex],
     header: bool,
     extend_aoti_c_shim: bool,
```
```diff
@@ -551,7 +560,19 @@
             )
         )
     )
-    device = dispatch_key.lower()
+    device = "aten" if dispatch_key is None else dispatch_key.lower()
+    include_device_functions = (
+        "#include <ATen/Functions.h>"
+        if dispatch_key is None
+        else f"#include <ATen/{str(dispatch_key)}Functions.h>"
+    )
+    aten_warning = (
+        (
+            "\n\n// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py\n"
+        )
+        if dispatch_key is None
+        else ""
+    )
     warning = """
 
 // WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
```
```diff
@@ -560,6 +581,7 @@ def gen_aoti_c_shim(
     if header:
         return (
             warning
+            + aten_warning
             + textwrap.dedent("""
 
                 #pragma once
```
```diff
@@ -582,13 +604,14 @@
     else:
         return (
             warning
+            + aten_warning
             + textwrap.dedent(f"""
 
                 #include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
                 #include <torch/csrc/inductor/aoti_torch/utils.h>
 
                 #ifndef AT_PER_OPERATOR_HEADERS
-                #include <ATen/{str(dispatch_key)}Functions.h>
+                {include_device_functions}
                 #include <ATen/CompositeExplicitAutogradFunctions.h>
                 #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
                 #include <ATen/CompositeImplicitAutogradFunctions.h>
```
```diff
@@ -607,7 +630,7 @@
 
 def gen_aoti_c_shim_files(
     aoti_fm: FileManager,
-    aoti_backends: set[DispatchKey],
+    aoti_backends: set[Optional[DispatchKey]],
     native_functions: Sequence[NativeFunction],
     backend_indices: dict[DispatchKey, BackendIndex],
     structured_native_functions: Sequence[NativeFunctionsGroup],
```
```diff
@@ -623,20 +646,27 @@ def gen_aoti_c_shim_files(
             break
 
     for dispatch_key in aoti_backends:
+        # Use aten_shimified_ops for the aten backend, inductor_fallback_ops for others
+        fallback_ops_dict = (
+            aten_shimified_ops if dispatch_key is None else inductor_fallback_ops
+        )
         fallbacks = {}
         for func in native_functions:
             op_name = get_fallback_op_name(func)
-            if op_name in inductor_fallback_ops:
+            if op_name in fallback_ops_dict:
                 fallbacks[op_name] = func
         fallback_native_functions = tuple(
             value for _, value in sorted(fallbacks.items())
         )
 
+        # Use "aten" as the device name when dispatch_key is None
+        device_name = "aten" if dispatch_key is None else dispatch_key.lower()
+
         # header files were checked in for ABI-compatibility checking
-        header_file_name = f"c_shim_{dispatch_key.lower()}.h"
+        header_file_name = f"c_shim_{device_name}.h"
         new_header = gen_aoti_c_shim(
             fallback_native_functions,
-            inductor_fallback_ops,
+            fallback_ops_dict,
             structured_func_group_dict,
             dispatch_key,
             backend_indices,
```
```diff
@@ -704,10 +734,14 @@ https://github.com/pytorch/pytorch/pull/154848 as an example.
                 headers.append(header)
             return "\n".join(sorted(set(headers)))
 
-        extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
+        extra_headers = (
+            extra_cuda_headers
+            if dispatch_key is not None and is_cuda_dispatch_key(dispatch_key)
+            else ""
+        )
 
         aoti_fm.write(
-            f"c_shim_{dispatch_key.lower()}.cpp",
+            f"c_shim_{device_name}.cpp",
             lambda: gen_aoti_c_shim(
                 fallback_native_functions,
-                inductor_fallback_ops,
+                fallback_ops_dict,
```