mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
op should NOT be static in aoti_torch_call_dispatcher (#149208)
aoti_torch_call_dispatcher is meant to call different ops, so the op must not be static. Otherwise, every call to this API will call the first op that was ever called, which is not the intended behavior of any human being. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149208 Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
578160c875
commit
740ce0fa5f
@ -1124,6 +1124,45 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
self.assertEqual(pch_exist, True)
|
||||
self.assertEqual(signature_exist, True)
|
||||
|
||||
def test_aoti_torch_call_dispatcher(self):
    # Regression test: compile an inline extension whose ops are routed
    # through aoti_torch_call_dispatcher, then invoke TWO different aten
    # ops (floor, abs). If the dispatcher handle were cached (`static`),
    # the second op would incorrectly reuse the first lookup.
    cpp_source = """
    #include <torch/csrc/inductor/aoti_runtime/utils.h>
    #include <torch/csrc/inductor/aoti_torch/utils.h>
    #include <torch/csrc/inductor/aoti_torch/c/shim.h>
    #include <torch/csrc/stable/library.h>

    using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

    at::Tensor my_abs(at::Tensor x) {
      StableIValue stack[1];
      RAIIATH raii(torch::aot_inductor::new_tensor_handle(std::move(x)));
      stack[0] = from(raii.release());
      aoti_torch_call_dispatcher("aten::abs", "", stack);
      RAIIATH res(to<AtenTensorHandle>(stack[0]));
      return *reinterpret_cast<at::Tensor*>(res.release());
    }

    at::Tensor my_floor(at::Tensor x) {
      StableIValue stack[1];
      RAIIATH raii(torch::aot_inductor::new_tensor_handle(std::move(x)));
      stack[0] = from(raii.release());
      aoti_torch_call_dispatcher("aten::floor", "", stack);
      RAIIATH res(to<AtenTensorHandle>(stack[0]));
      return *reinterpret_cast<at::Tensor*>(res.release());
    }
    """

    ext = torch.utils.cpp_extension.load_inline(
        name="inline_extension_using_shim_dispatcher",
        cpp_sources=[cpp_source],
        functions=["my_abs", "my_floor"],
    )

    # Values in [-1, 0) make abs and floor produce distinct, nontrivial results.
    inp = torch.rand(2, 3) - 1.0
    floored = ext.my_floor(inp)
    absed = ext.my_abs(inp)

    self.assertEqual(absed, torch.abs(inp))
    self.assertEqual(floored, torch.floor(inp))
|
||||
|
||||
|
||||
# Standard PyTorch test-suite entry point: run this file's tests directly.
if __name__ == "__main__":
    common.run_tests()
|
||||
|
@ -1471,9 +1471,8 @@ AOTITorchError aoti_torch_call_dispatcher(
|
||||
const char* overloadName,
|
||||
StableIValue* stack) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
static auto op =
|
||||
const auto op =
|
||||
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
|
||||
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
Reference in New Issue
Block a user