[executorch] Add RuntimeContext to generated C++ API Signature (#94570)

Summary:
Pass the runtime context all the way down to the kernel level: the generated unboxing wrappers and the generated C++ API (`Functions.h`) now take a `torch::executor::RuntimeContext &` as their first argument.

RegisterCodegenUnboxedKernels.cpp:

```
static Operator operators_to_register[] = {
    Operator(
        "aten::add.out",
        [](torch::executor::RuntimeContext & context, EValue** stack) {
            EValue& self = *stack[0];
            EValue& other = *stack[1];
            EValue& alpha = *stack[2];
            EValue& out = *stack[3];
            const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
            const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
            const torch::executor::Scalar & alpha_base = alpha.to<torch::executor::Scalar>();
            torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

            EXECUTORCH_SCOPE_PROF("native_call_add.out");
            torch::executor::aten::add_outf(context, self_base, other_base, alpha_base, out_base);
        }
    ),
};
```
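The runtime then has to supply the context whenever it invokes a registered operator, matching the new `OpFunction` signature (`std::function<void(RuntimeContext&, EValue**)>`, see the diff below). A minimal sketch of that calling convention, using hypothetical stand-in types rather than the real ExecuTorch runtime:

```
#include <cstring>
#include <functional>

// Hypothetical stand-ins for the real ExecuTorch types, only to illustrate
// the new calling convention; the real definitions live in the runtime.
struct RuntimeContext {};
struct EValue {};
using OpFunction = std::function<void(RuntimeContext&, EValue**)>;

struct Operator {
  const char* name;
  OpFunction op;
};

// Dispatch sketch: find a registered operator by name and invoke it with the
// context as the new first argument, alongside the EValue stack.
bool call_op(const Operator* table, std::size_t n, const char* name,
             RuntimeContext& context, EValue** stack) {
  for (std::size_t i = 0; i < n; ++i) {
    if (std::strcmp(table[i].name, name) == 0) {
      table[i].op(context, stack);
      return true;
    }
  }
  return false;
}
```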

Functions.h:

```
// aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
TORCH_API inline at::Tensor & add_outf(torch::executor::RuntimeContext & context, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
    return torch::executor::native::add_out(self, other, alpha, out);
}
```

Note that the generated wrapper accepts the context for signature uniformity but does not yet forward it to the native kernel, since existing kernels have not been migrated to the new convention.

Test Plan: TBD

Differential Revision: D41325633

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94570
Approved by: https://github.com/cccclai
Author: Mengwei Liu, 2023-02-16 02:43:14 +00:00; committed by PyTorch MergeBot. Commit 41865bd8ed, parent e5c2a35d83.

8 changed files with 139 additions and 19 deletions:

```
@@ -0,0 +1,22 @@
+#pragma once
+
+namespace torch {
+namespace executor {
+
+/**
+ * Bucket type abstraction that contains many elements of runtime state that
+ * a kernel author may want available, but would otherwise be unable to access.
+ *
+ * Forwarded along to all operators when running in lean mode.
+ * NOTE: Will not be forwarded to operators if running in ATen mode
+ * as those operators do not expect to receive a RuntimeContext and would not
+ * use it.
+ *
+ * This includes things like setting an error state, a scratch allocator for
+ * operators that need more than constant space, and a TensorResizer for dynamic
+ * shape tensors allowing programs to be more flexible with Tensor shape.
+ */
+class RuntimeContext {};
+
+} // namespace executor
+} // namespace torch
```
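For illustration, a kernel that opts into the new convention simply accepts the context as its first parameter. A minimal sketch, assuming a hypothetical kernel name (`RuntimeContext` is still an empty placeholder here, so the kernel can only ignore it for now):

```
#include "RuntimeContext.h"

namespace torch {
namespace executor {
namespace native {

// Hypothetical kernel following the new convention: the RuntimeContext is the
// first parameter. Since the class has no API yet, the kernel ignores it.
inline long add_scalars(RuntimeContext& context, long a, long b) {
  (void)context; // unused until RuntimeContext grows an API
  return a + b;
}

} // namespace native
} // namespace executor
} // namespace torch
```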

```
@@ -4,13 +4,14 @@
 #include <c10/util/ArrayRef.h>
 #include "Evalue.h"
+#include "RuntimeContext.h"
 #include <functional>
 #include <map>
 
 namespace torch {
 namespace executor {
 
-using OpFunction = std::function<void(EValue**)>;
+using OpFunction = std::function<void(RuntimeContext&, EValue**)>;
 
 template<typename T>
 using ArrayRef = at::ArrayRef<T>;
 
```

```
@@ -18,7 +18,8 @@ TEST(OperatorRegistrationTest, Add) {
   for (size_t i = 0; i < 4; i++) {
     kernel_values[i] = &values[i];
   }
-  op(kernel_values);
+  RuntimeContext context{};
+  op(context, kernel_values);
   at::Tensor expected = at::ones({2, 3});
   expected = at::fill(expected, 2);
   ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
@@ -39,7 +40,8 @@ TEST(OperatorRegistrationTest, CustomAdd3) {
   for (size_t i = 0; i < 4; i++) {
     kernel_values[i] = &values[i];
   }
-  op(kernel_values);
+  RuntimeContext context{};
+  op(context, kernel_values);
   at::Tensor expected = at::ones({2, 3});
   expected = at::fill(expected, 3);
   ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
```

```
@@ -181,8 +181,8 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
 namespace custom_1 {
 
 // custom_1::op_1() -> bool
-TORCH_API inline bool op_1() {
-    return ::at::native::kernel_1();
+TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
+    return ::at::native::kernel_1(context);
 }
 
 } // namespace custom_1
@@ -195,8 +195,8 @@ TORCH_API inline bool op_1() {
 namespace custom_2 {
 
 // custom_2::op_2() -> bool
-TORCH_API inline bool op_2() {
-    return ::at::native::kernel_2();
+TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) {
+    return ::at::native::kernel_2(context);
 }
 
 } // namespace custom_2
```

```
@@ -0,0 +1,58 @@
+import unittest
+
+from torchgen.executorch.api.types import ExecutorchCppSignature
+from torchgen.local import parametrize
+from torchgen.model import Location, NativeFunction
+
+DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
+    {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
+    loc=Location(__file__, 1),
+    valid_tags=set(),
+)
+
+
+class ExecutorchCppSignatureTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION)
+
+    def test_runtime_signature_contains_runtime_context(self) -> None:
+        # test if `RuntimeContext` argument exists in `RuntimeSignature`
+        with parametrize(
+            use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+        ):
+            args = self.sig.arguments(include_context=True)
+            self.assertEquals(len(args), 3)
+            self.assertTrue(any(a.name == "context" for a in args))
+
+    def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
+        # test if `RuntimeContext` argument is missing in `RuntimeSignature`
+        with parametrize(
+            use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+        ):
+            args = self.sig.arguments(include_context=False)
+            self.assertEquals(len(args), 2)
+            self.assertFalse(any(a.name == "context" for a in args))
+
+    def test_runtime_signature_declaration_correct(self) -> None:
+        with parametrize(
+            use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+        ):
+            decl = self.sig.decl(include_context=True)
+            self.assertEquals(
+                decl,
+                (
+                    "torch::executor::Tensor & foo_outf("
+                    "torch::executor::RuntimeContext & context, "
+                    "const torch::executor::Tensor & input, "
+                    "torch::executor::Tensor & out)"
+                ),
+            )
+            no_context_decl = self.sig.decl(include_context=False)
+            self.assertEquals(
+                no_context_decl,
+                (
+                    "torch::executor::Tensor & foo_outf("
+                    "const torch::executor::Tensor & input, "
+                    "torch::executor::Tensor & out)"
+                ),
+            )
```

```
@@ -6,12 +6,15 @@ import torchgen.api.cpp as aten_cpp
 from torchgen.api.types import Binding, CType
 from torchgen.model import FunctionSchema, NativeFunction
 
+from .types import contextArg
+
 
 @dataclass(frozen=True)
 class ExecutorchCppSignature:
     """
-    This signature is merely a CppSignature with Executorch types. The inline definition
-    of CppSignature is generated in Functions.h and it's used by unboxing functions.
+    This signature is merely a CppSignature with Executorch types (optionally contains
+    RuntimeContext as well). The inline definition of CppSignature is generated in Functions.h
+    and it's used by unboxing functions.
     """
 
     # The schema this signature is derived from
@@ -25,8 +28,8 @@ class ExecutorchCppSignature:
     # and need to avoid naming collisions.
     prefix: str = ""
 
-    def arguments(self) -> List[Binding]:
-        return et_cpp.arguments(
+    def arguments(self, *, include_context: bool = True) -> List[Binding]:
+        return ([contextArg] if include_context else []) + et_cpp.arguments(
             self.func.arguments,
             faithful=True,  # always faithful, out argument at the end
             method=False,  # method not supported
@@ -39,8 +42,10 @@ class ExecutorchCppSignature:
             faithful_name_for_out_overloads=True,
         )
 
-    def decl(self, name: Optional[str] = None) -> str:
-        args_str = ", ".join(a.decl() for a in self.arguments())
+    def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str:
+        args_str = ", ".join(
+            a.decl() for a in self.arguments(include_context=include_context)
+        )
         if name is None:
             name = self.name()
         return f"{self.returns_type().cpp_type()} {name}({args_str})"
```

```
@@ -1,7 +1,18 @@
 from dataclasses import dataclass
 from typing import Dict
 
-from torchgen.api.types import BaseCppType, boolT, CType, doubleT, longT
+from torchgen.api.types import (
+    BaseCppType,
+    BaseCType,
+    Binding,
+    boolT,
+    CType,
+    doubleT,
+    Expr,
+    longT,
+    MutRefCType,
+    NamedCType,
+)
 from torchgen.model import BaseTy
 
 halfT = BaseCppType("torch::executor", "Half")
@@ -14,6 +25,19 @@ scalarT = BaseCppType("torch::executor", "Scalar")
 memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
 intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
 optionalT = BaseCppType("torch::executor", "optional")
+contextT = BaseCppType("torch::executor", "RuntimeContext")
+
+contextExpr = Expr(
+    expr="context",
+    type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
+)
+
+contextArg = Binding(
+    name="context",
+    nctype=contextExpr.type,
+    argument=None,  # type: ignore[arg-type]
+    default=None,
+)
 
 BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
     BaseTy.int: longT,
```

```
@@ -17,7 +17,7 @@ from torchgen.executorch.api.custom_ops import (
     ComputeNativeFunctionStub,
     gen_custom_ops_registration,
 )
-from torchgen.executorch.api.types import ExecutorchCppSignature
+from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature
 from torchgen.executorch.api.unboxing import Unboxing
 from torchgen.gen import (
     get_custom_build_selector,
@@ -149,14 +149,16 @@ class ComputeCodegenUnboxedKernels:
             ).most_faithful_signature()
             argument_type_gen = aten_cpp.argumenttype_type
             return_type_gen = aten_cpp.returns_type
+            arguments = sig.arguments()
         else:
             sig = ExecutorchCppSignature.from_native_function(f)
             argument_type_gen = et_cpp.argumenttype_type
             return_type_gen = et_cpp.returns_type
+            arguments = sig.arguments(include_context=False)
 
         # parse arguments into C++ code
         binding_list, code_list = Unboxing(
             argument_type_gen=argument_type_gen
-        ).convert_arguments(sig.arguments())
+        ).convert_arguments(arguments)
 
         # for each C++ argument, generate the conversion code
         code_connector = "\n\t"
@@ -185,11 +187,12 @@ class ComputeCodegenUnboxedKernels:
         return f"""
 Operator(
     "{f.namespace}::{f.func.name}",
-    [](EValue** stack) {{
+    []({contextArg.defn()}, EValue** stack) {{
+        {"(void)context;" if self.use_aten_lib else ""}
         {code_connector.join(code_list)}
 
         EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
-        {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str});
+        {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"" if self.use_aten_lib else "context, "}{args_str});
         {return_assignment}
     }}
@@ -229,7 +232,12 @@ def compute_native_function_declaration(
     if metadata is None:
         return []
     prefix = "static" if backend_index.external else "TORCH_API"
-    return [f"{prefix} {sig.decl(name=metadata.kernel)};"]
+    # for kernels in lean mode, we declare two versions, one with context and one without.
+    # In the end we will clean up the unused one.
+    return [
+        f"{prefix} {sig.decl(name=metadata.kernel)};",
+        f"{prefix} {sig.decl(name=metadata.kernel, include_context=False)};",
+    ]
 
 
 def gen_functions_declarations(
```
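Putting the pieces together, the dual declarations emitted by `compute_native_function_declaration` would look roughly like this for the `add.out` kernel from the summary (a sketch assembled from the signatures above, not verbatim generator output):

```
// Context-aware declaration (new convention) and context-free declaration
// (legacy) are emitted side by side so unmigrated kernels keep compiling;
// the unused one is meant to be cleaned up later, per the comment in the diff.
TORCH_API torch::executor::Tensor & add_out(
    torch::executor::RuntimeContext & context,
    const torch::executor::Tensor & self,
    const torch::executor::Tensor & other,
    const torch::executor::Scalar & alpha,
    torch::executor::Tensor & out);

TORCH_API torch::executor::Tensor & add_out(
    const torch::executor::Tensor & self,
    const torch::executor::Tensor & other,
    const torch::executor::Scalar & alpha,
    torch::executor::Tensor & out);
```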