op should NOT be static in aoti_torch_call_dispatcher (#149208)

aoti_torch_call_dispatcher is meant to call different ops, so the op must not be static. Otherwise, every call to this API will call the first op that was ever called, which is not the intended behavior of any human being.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149208
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/malfet
This commit is contained in:
Jane Xu
2025-03-14 10:09:47 -07:00
committed by PyTorch MergeBot
parent 578160c875
commit 740ce0fa5f
2 changed files with 40 additions and 2 deletions

View File

@ -1124,6 +1124,45 @@ class TestCppExtensionJIT(common.TestCase):
self.assertEqual(pch_exist, True)
self.assertEqual(signature_exist, True)
def test_aoti_torch_call_dispatcher(self):
source = """
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
at::Tensor my_abs(at::Tensor x) {
StableIValue stack[1];
RAIIATH raii(torch::aot_inductor::new_tensor_handle(std::move(x)));
stack[0] = from(raii.release());
aoti_torch_call_dispatcher("aten::abs", "", stack);
RAIIATH res(to<AtenTensorHandle>(stack[0]));
return *reinterpret_cast<at::Tensor*>(res.release());
}
at::Tensor my_floor(at::Tensor x) {
StableIValue stack[1];
RAIIATH raii(torch::aot_inductor::new_tensor_handle(std::move(x)));
stack[0] = from(raii.release());
aoti_torch_call_dispatcher("aten::floor", "", stack);
RAIIATH res(to<AtenTensorHandle>(stack[0]));
return *reinterpret_cast<at::Tensor*>(res.release());
}
"""
module = torch.utils.cpp_extension.load_inline(
name="inline_extension_using_shim_dispatcher",
cpp_sources=[source],
functions=["my_abs", "my_floor"],
)
t = torch.rand(2, 3) - 1.0
floor_t = module.my_floor(t)
abs_t = module.my_abs(t)
self.assertEqual(abs_t, torch.abs(t))
self.assertEqual(floor_t, torch.floor(t))
if __name__ == "__main__":
common.run_tests()

View File

@ -1471,9 +1471,8 @@ AOTITorchError aoti_torch_call_dispatcher(
const char* overloadName,
StableIValue* stack) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
static auto op =
const auto op =
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();