Files
pytorch/test/edge/test_operator_registration.cpp
Mengwei Liu 41865bd8ed [executorch] Add RuntimeContext to generated C++ API Signature (#94570)
Summary:
Pass the runtime context all the way down to the kernel level.

RegisterCodegenUnboxedKernels.cpp:

```
static Operator operators_to_register[] = {
    Operator(
        "aten::add.out",
        [](torch::executor::RuntimeContext & context, EValue** stack) {

            EValue& self = *stack[0];
    	EValue& other = *stack[1];
    	EValue& alpha = *stack[2];
    	EValue& out = *stack[3];
    	const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
    	const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
    	const torch::executor::Scalar & alpha_base = alpha.to<torch::executor::Scalar>();
    	torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

            EXECUTORCH_SCOPE_PROF("native_call_add.out");
            torch::executor::aten::add_outf(context, self_base, other_base, alpha_base, out_base);

        }
    ),
}
```

Functions.h:
```

// aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
TORCH_API inline at::Tensor & add_outf(torch::executor::RuntimeContext & context, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
    return torch::executor::native::add_out(self, other, alpha, out);
}

```

Test Plan: TBD

Differential Revision: D41325633

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94570
Approved by: https://github.com/cccclai
2023-02-16 02:43:18 +00:00

52 lines
1.6 KiB
C++

#include "operator_registry.h"
#include <gtest/gtest.h>
namespace torch {
namespace executor {
// add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
TEST(OperatorRegistrationTest, Add) {
EValue values[4];
values[0] = EValue(at::ones({2, 3}));
values[1] = EValue(at::ones({2, 3}));
values[2] = EValue(int64_t(1));
values[3] = EValue(at::zeros({2, 3}));
ASSERT_TRUE(hasOpsFn("aten::add.out"));
auto op = getOpsFn("aten::add.out");
EValue* kernel_values[4];
for (size_t i = 0; i < 4; i++) {
kernel_values[i] = &values[i];
}
RuntimeContext context{};
op(context, kernel_values);
at::Tensor expected = at::ones({2, 3});
expected = at::fill(expected, 2);
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
}
// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
TEST(OperatorRegistrationTest, CustomAdd3) {
EValue values[4];
values[0] = EValue(at::ones({2, 3}));
values[1] = EValue(at::ones({2, 3}));
values[2] = EValue(at::ones({2, 3}));
values[3] = EValue(at::zeros({2, 3}));
ASSERT_TRUE(hasOpsFn("custom::add_3.out"));
auto op = getOpsFn("custom::add_3.out");
EValue* kernel_values[4];
for (size_t i = 0; i < 4; i++) {
kernel_values[i] = &values[i];
}
RuntimeContext context{};
op(context, kernel_values);
at::Tensor expected = at::ones({2, 3});
expected = at::fill(expected, 3);
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
}
} // namespace executor
} // namespace torch