[torchgen] Rename executorch's RuntimeContext to KernelRuntimeContext (#104892)

Rename the context type to match changes in executorch.

Differential Revision: [D46977359](https://our.internmc.facebook.com/intern/diff/D46977359/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104892
Approved by: https://github.com/larryliu0820
This commit is contained in:
Dave Bort
2023-07-13 18:30:21 -07:00
committed by PyTorch MergeBot
parent 99ab2ad677
commit d06e1df1aa
7 changed files with 28 additions and 25 deletions

View File

@ -7,16 +7,15 @@ namespace executor {
* Bucket type abstraction that contains many elements of runtime state that
* a kernel author may want available, but would otherwise be unable to access.
*
* Forwarded along to all operators when running in lean mode.
* NOTE: Will not be forwarded to operators if running in ATen mode
* as those operators do not expect to receive a RuntimeContext and would not
* use it.
* Forwarded along to all operators when running in lean mode. NOTE: Will not be
* forwarded to operators if running in ATen mode as those operators do not
* expect to receive a KernelRuntimeContext and would not use it.
*
* This includes things like setting an error state, a scratch allocator for
* operators that need more then constant space, and a TensorResizer for dynamic
* shape tensors allowing programs to be more flexible with Tensor shape.
*/
class RuntimeContext {};
class KernelRuntimeContext {};
} // namespace executor
} // namespace torch

View File

@ -1,17 +1,18 @@
#pragma once
#include <cstring>
#include <c10/util/ArrayRef.h>
#include "Evalue.h"
#include "RuntimeContext.h"
#include <functional>
#include <map>
#include "Evalue.h"
#include "kernel_runtime_context.h"
#include <c10/util/ArrayRef.h>
namespace torch {
namespace executor {
using KernelFunction = std::function<void(RuntimeContext&, EValue**)>;
using KernelFunction = std::function<void(KernelRuntimeContext&, EValue**)>;
template<typename T>
using ArrayRef = at::ArrayRef<T>;

View File

@ -1,4 +1,6 @@
#include "kernel_runtime_context.h"
#include "operator_registry.h"
#include <gtest/gtest.h>
namespace torch {
@ -18,7 +20,7 @@ TEST(OperatorRegistrationTest, Add) {
for (size_t i = 0; i < 4; i++) {
kernel_values[i] = &values[i];
}
RuntimeContext context{};
KernelRuntimeContext context{};
op(context, kernel_values);
at::Tensor expected = at::ones({2, 3});
expected = at::fill(expected, 2);
@ -40,7 +42,7 @@ TEST(OperatorRegistrationTest, CustomAdd3) {
for (size_t i = 0; i < 4; i++) {
kernel_values[i] = &values[i];
}
RuntimeContext context{};
KernelRuntimeContext context{};
op(context, kernel_values);
at::Tensor expected = at::ones({2, 3});
expected = at::fill(expected, 3);

View File

@ -366,7 +366,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
namespace custom_1 {
// custom_1::op_1() -> bool
TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) {
return ::at::native::kernel_1(context);
}
@ -380,7 +380,7 @@ TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
namespace custom_2 {
// custom_2::op_2() -> bool
TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) {
TORCH_API inline bool op_2(torch::executor::KernelRuntimeContext & context) {
return ::at::native::kernel_2(context);
}
@ -403,7 +403,7 @@ TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) {
namespace custom_1 {
// custom_1::op_1() -> bool
TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) {
return at::op_1();
}
@ -463,7 +463,7 @@ class TestComputeCodegenUnboxedKernels(unittest.TestCase):
Kernel(
"custom_1::op_1",
"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
[](torch::executor::RuntimeContext & context, EValue** stack) {
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
"""
+ """
@ -548,7 +548,7 @@ Kernel(
"""
Kernel(
"custom_1::op_1",
[](torch::executor::RuntimeContext & context, EValue** stack) {
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
"""
+ """
@ -582,7 +582,7 @@ Kernel(
"""
Kernel(
"custom_1::op_1",
[](torch::executor::RuntimeContext & context, EValue** stack) {
[](torch::executor::KernelRuntimeContext & context, EValue** stack) {
"""
+ """

View File

@ -16,7 +16,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION)
def test_runtime_signature_contains_runtime_context(self) -> None:
# test if `RuntimeContext` argument exists in `RuntimeSignature`
# test if `KernelRuntimeContext` argument exists in `RuntimeSignature`
with parametrize(
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
@ -25,7 +25,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
self.assertTrue(any(a.name == "context" for a in args))
def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
# test if `RuntimeContext` argument is missing in `RuntimeSignature`
# test if `KernelRuntimeContext` argument is missing in `RuntimeSignature`
with parametrize(
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
@ -42,7 +42,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
decl,
(
"torch::executor::Tensor & foo_outf("
"torch::executor::RuntimeContext & context, "
"torch::executor::KernelRuntimeContext & context, "
"const torch::executor::Tensor & input, "
"torch::executor::Tensor & out)"
),

View File

@ -12,9 +12,10 @@ from .types import contextArg
@dataclass(frozen=True)
class ExecutorchCppSignature:
"""
This signature is merely a CppSignature with Executorch types (optionally contains
RuntimeContext as well). The inline definition of CppSignature is generated in Functions.h
and it's used by unboxing functions.
This signature is merely a CppSignature with Executorch types (optionally
contains KernelRuntimeContext as well). The inline definition of
CppSignature is generated in Functions.h and it's used by unboxing
functions.
"""
# The schema this signature is derived from

View File

@ -25,7 +25,7 @@ scalarT = BaseCppType("torch::executor", "Scalar")
memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
optionalT = BaseCppType("torch::executor", "optional")
contextT = BaseCppType("torch::executor", "RuntimeContext")
contextT = BaseCppType("torch::executor", "KernelRuntimeContext")
contextExpr = Expr(
expr="context",