mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[executorch] Add RuntimeContext to generated C++ API Signature (#94570)
Summary: Pass runtime context all the way to kernel level. RegisterCodegenUnboxedKernels.cpp: ``` static Operator operators_to_register[] = { Operator( "aten::add.out", [](torch::executor::RuntimeContext & context, EValue** stack) { EValue& self = *stack[0]; EValue& other = *stack[1]; EValue& alpha = *stack[2]; EValue& out = *stack[3]; const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>(); const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>(); const torch::executor::Scalar & alpha_base = alpha.to<torch::executor::Scalar>(); torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>(); EXECUTORCH_SCOPE_PROF("native_call_add.out"); torch::executor::aten::add_outf(context, self_base, other_base, alpha_base, out_base); } ), } ``` Functions.h ``` // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) TORCH_API inline at::Tensor & add_outf(torch::executor::RuntimeContext & context, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { return torch::executor::native::add_out(self, other, alpha, out); } ``` Test Plan: TBD Differential Revision: D41325633 Pull Request resolved: https://github.com/pytorch/pytorch/pull/94570 Approved by: https://github.com/cccclai
This commit is contained in:
committed by
PyTorch MergeBot
parent
e5c2a35d83
commit
41865bd8ed
22
test/edge/RuntimeContext.h
Normal file
22
test/edge/RuntimeContext.h
Normal file
@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch {
|
||||
namespace executor {
|
||||
|
||||
/**
|
||||
* Bucket type abstraction that contains many elements of runtime state that
|
||||
* a kernel author may want available, but would otherwise be unable to access.
|
||||
*
|
||||
* Forwarded along to all operators when running in lean mode.
|
||||
* NOTE: Will not be forwarded to operators if running in ATen mode
|
||||
* as those operators do not expect to receive a RuntimeContext and would not
|
||||
* use it.
|
||||
*
|
||||
* This includes things like setting an error state, a scratch allocator for
|
||||
* operators that need more than constant space, and a TensorResizer for dynamic
|
||||
* shape tensors allowing programs to be more flexible with Tensor shape.
|
||||
*/
|
||||
class RuntimeContext {};
|
||||
|
||||
} // namespace executor
|
||||
} // namespace torch
|
@ -4,13 +4,14 @@
|
||||
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include "Evalue.h"
|
||||
#include "RuntimeContext.h"
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
namespace torch {
|
||||
namespace executor {
|
||||
|
||||
using OpFunction = std::function<void(EValue**)>;
|
||||
using OpFunction = std::function<void(RuntimeContext&, EValue**)>;
|
||||
|
||||
template<typename T>
|
||||
using ArrayRef = at::ArrayRef<T>;
|
||||
|
@ -18,7 +18,8 @@ TEST(OperatorRegistrationTest, Add) {
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
kernel_values[i] = &values[i];
|
||||
}
|
||||
op(kernel_values);
|
||||
RuntimeContext context{};
|
||||
op(context, kernel_values);
|
||||
at::Tensor expected = at::ones({2, 3});
|
||||
expected = at::fill(expected, 2);
|
||||
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
|
||||
@ -39,7 +40,8 @@ TEST(OperatorRegistrationTest, CustomAdd3) {
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
kernel_values[i] = &values[i];
|
||||
}
|
||||
op(kernel_values);
|
||||
RuntimeContext context{};
|
||||
op(context, kernel_values);
|
||||
at::Tensor expected = at::ones({2, 3});
|
||||
expected = at::fill(expected, 3);
|
||||
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
|
||||
|
@ -181,8 +181,8 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
|
||||
namespace custom_1 {
|
||||
|
||||
// custom_1::op_1() -> bool
|
||||
TORCH_API inline bool op_1() {
|
||||
return ::at::native::kernel_1();
|
||||
TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
|
||||
return ::at::native::kernel_1(context);
|
||||
}
|
||||
|
||||
} // namespace custom_1
|
||||
@ -195,8 +195,8 @@ TORCH_API inline bool op_1() {
|
||||
namespace custom_2 {
|
||||
|
||||
// custom_2::op_2() -> bool
|
||||
TORCH_API inline bool op_2() {
|
||||
return ::at::native::kernel_2();
|
||||
TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) {
|
||||
return ::at::native::kernel_2(context);
|
||||
}
|
||||
|
||||
} // namespace custom_2
|
||||
|
58
tools/test/test_executorch_signatures.py
Normal file
58
tools/test/test_executorch_signatures.py
Normal file
@ -0,0 +1,58 @@
|
||||
import unittest
|
||||
|
||||
from torchgen.executorch.api.types import ExecutorchCppSignature
|
||||
from torchgen.local import parametrize
|
||||
from torchgen.model import Location, NativeFunction
|
||||
|
||||
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
|
||||
{"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
|
||||
loc=Location(__file__, 1),
|
||||
valid_tags=set(),
|
||||
)
|
||||
|
||||
|
||||
class ExecutorchCppSignatureTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION)
|
||||
|
||||
def test_runtime_signature_contains_runtime_context(self) -> None:
|
||||
# test if `RuntimeContext` argument exists in `RuntimeSignature`
|
||||
with parametrize(
|
||||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
|
||||
):
|
||||
args = self.sig.arguments(include_context=True)
|
||||
self.assertEquals(len(args), 3)
|
||||
self.assertTrue(any(a.name == "context" for a in args))
|
||||
|
||||
def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
|
||||
# test if `RuntimeContext` argument is missing in `RuntimeSignature`
|
||||
with parametrize(
|
||||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
|
||||
):
|
||||
args = self.sig.arguments(include_context=False)
|
||||
self.assertEquals(len(args), 2)
|
||||
self.assertFalse(any(a.name == "context" for a in args))
|
||||
|
||||
def test_runtime_signature_declaration_correct(self) -> None:
|
||||
with parametrize(
|
||||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
|
||||
):
|
||||
decl = self.sig.decl(include_context=True)
|
||||
self.assertEquals(
|
||||
decl,
|
||||
(
|
||||
"torch::executor::Tensor & foo_outf("
|
||||
"torch::executor::RuntimeContext & context, "
|
||||
"const torch::executor::Tensor & input, "
|
||||
"torch::executor::Tensor & out)"
|
||||
),
|
||||
)
|
||||
no_context_decl = self.sig.decl(include_context=False)
|
||||
self.assertEquals(
|
||||
no_context_decl,
|
||||
(
|
||||
"torch::executor::Tensor & foo_outf("
|
||||
"const torch::executor::Tensor & input, "
|
||||
"torch::executor::Tensor & out)"
|
||||
),
|
||||
)
|
@ -6,12 +6,15 @@ import torchgen.api.cpp as aten_cpp
|
||||
from torchgen.api.types import Binding, CType
|
||||
from torchgen.model import FunctionSchema, NativeFunction
|
||||
|
||||
from .types import contextArg
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExecutorchCppSignature:
|
||||
"""
|
||||
This signature is merely a CppSignature with Executorch types. The inline definition
|
||||
of CppSignature is generated in Functions.h and it's used by unboxing functions.
|
||||
This signature is merely a CppSignature with Executorch types (optionally contains
|
||||
RuntimeContext as well). The inline definition of CppSignature is generated in Functions.h
|
||||
and it's used by unboxing functions.
|
||||
"""
|
||||
|
||||
# The schema this signature is derived from
|
||||
@ -25,8 +28,8 @@ class ExecutorchCppSignature:
|
||||
# and need to avoid naming collisions.
|
||||
prefix: str = ""
|
||||
|
||||
def arguments(self) -> List[Binding]:
|
||||
return et_cpp.arguments(
|
||||
def arguments(self, *, include_context: bool = True) -> List[Binding]:
|
||||
return ([contextArg] if include_context else []) + et_cpp.arguments(
|
||||
self.func.arguments,
|
||||
faithful=True, # always faithful, out argument at the end
|
||||
method=False, # method not supported
|
||||
@ -39,8 +42,10 @@ class ExecutorchCppSignature:
|
||||
faithful_name_for_out_overloads=True,
|
||||
)
|
||||
|
||||
def decl(self, name: Optional[str] = None) -> str:
|
||||
args_str = ", ".join(a.decl() for a in self.arguments())
|
||||
def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str:
|
||||
args_str = ", ".join(
|
||||
a.decl() for a in self.arguments(include_context=include_context)
|
||||
)
|
||||
if name is None:
|
||||
name = self.name()
|
||||
return f"{self.returns_type().cpp_type()} {name}({args_str})"
|
||||
|
@ -1,7 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
from torchgen.api.types import BaseCppType, boolT, CType, doubleT, longT
|
||||
from torchgen.api.types import (
|
||||
BaseCppType,
|
||||
BaseCType,
|
||||
Binding,
|
||||
boolT,
|
||||
CType,
|
||||
doubleT,
|
||||
Expr,
|
||||
longT,
|
||||
MutRefCType,
|
||||
NamedCType,
|
||||
)
|
||||
from torchgen.model import BaseTy
|
||||
|
||||
halfT = BaseCppType("torch::executor", "Half")
|
||||
@ -14,6 +25,19 @@ scalarT = BaseCppType("torch::executor", "Scalar")
|
||||
memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
|
||||
intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
|
||||
optionalT = BaseCppType("torch::executor", "optional")
|
||||
contextT = BaseCppType("torch::executor", "RuntimeContext")
|
||||
|
||||
contextExpr = Expr(
|
||||
expr="context",
|
||||
type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
|
||||
)
|
||||
|
||||
contextArg = Binding(
|
||||
name="context",
|
||||
nctype=contextExpr.type,
|
||||
argument=None, # type: ignore[arg-type]
|
||||
default=None,
|
||||
)
|
||||
|
||||
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
|
||||
BaseTy.int: longT,
|
||||
|
@ -17,7 +17,7 @@ from torchgen.executorch.api.custom_ops import (
|
||||
ComputeNativeFunctionStub,
|
||||
gen_custom_ops_registration,
|
||||
)
|
||||
from torchgen.executorch.api.types import ExecutorchCppSignature
|
||||
from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature
|
||||
from torchgen.executorch.api.unboxing import Unboxing
|
||||
from torchgen.gen import (
|
||||
get_custom_build_selector,
|
||||
@ -149,14 +149,16 @@ class ComputeCodegenUnboxedKernels:
|
||||
).most_faithful_signature()
|
||||
argument_type_gen = aten_cpp.argumenttype_type
|
||||
return_type_gen = aten_cpp.returns_type
|
||||
arguments = sig.arguments()
|
||||
else:
|
||||
sig = ExecutorchCppSignature.from_native_function(f)
|
||||
argument_type_gen = et_cpp.argumenttype_type
|
||||
return_type_gen = et_cpp.returns_type
|
||||
arguments = sig.arguments(include_context=False)
|
||||
# parse arguments into C++ code
|
||||
binding_list, code_list = Unboxing(
|
||||
argument_type_gen=argument_type_gen
|
||||
).convert_arguments(sig.arguments())
|
||||
).convert_arguments(arguments)
|
||||
|
||||
# for each C++ argument, generate the conversion code
|
||||
code_connector = "\n\t"
|
||||
@ -185,11 +187,12 @@ class ComputeCodegenUnboxedKernels:
|
||||
return f"""
|
||||
Operator(
|
||||
"{f.namespace}::{f.func.name}",
|
||||
[](EValue** stack) {{
|
||||
[]({contextArg.defn()}, EValue** stack) {{
|
||||
{"(void)context;" if self.use_aten_lib else ""}
|
||||
{code_connector.join(code_list)}
|
||||
|
||||
EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
|
||||
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str});
|
||||
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"" if self.use_aten_lib else "context, "}{args_str});
|
||||
|
||||
{return_assignment}
|
||||
}}
|
||||
@ -229,7 +232,12 @@ def compute_native_function_declaration(
|
||||
if metadata is None:
|
||||
return []
|
||||
prefix = "static" if backend_index.external else "TORCH_API"
|
||||
return [f"{prefix} {sig.decl(name=metadata.kernel)};"]
|
||||
# for kernels in lean mode, we declare two versions, one with context and one without.
|
||||
# In the end we will clean up the unused one.
|
||||
return [
|
||||
f"{prefix} {sig.decl(name=metadata.kernel)};",
|
||||
f"{prefix} {sig.decl(name=metadata.kernel, include_context=False)};",
|
||||
]
|
||||
|
||||
|
||||
def gen_functions_declarations(
|
||||
|
Reference in New Issue
Block a user