From 40311e2ec15b991bd9fa7942973fdc015366e140 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Wed, 13 Aug 2025 21:12:33 +0000 Subject: [PATCH] [AOTInductor] ABI-Compatibility for RecordFunction. (#159842) Summary: Previous our implementation for RecordFunction injects Aten into codegen, which is breaking the ABI contract for AOTInductor. C10::IValue is aded to call the full record function. The extension of more profiling info will come in later PRs. Test Plan: Included in commit. Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D79622071](https://our.internmc.facebook.com/intern/diff/D79622071) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159842 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 10 +-- torch/_inductor/codegen/cpp.py | 7 +- torch/_inductor/codegen/cpp_template.py | 2 +- .../_inductor/codegen/cpp_template_kernel.py | 6 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 13 +--- torch/csrc/inductor/aoti_runtime/utils.h | 69 +++++++++++++++++++ torch/csrc/inductor/aoti_torch/c/macros.h | 3 + torch/csrc/inductor/aoti_torch/c/shim.h | 16 +++++ .../csrc/inductor/aoti_torch/shim_common.cpp | 40 +++++++++++ 9 files changed, 146 insertions(+), 20 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 3903d6e5ec48..81a218d5c42e 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -5031,13 +5031,13 @@ class AOTInductorTestsTemplate: _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) - shim_fn_codes = ( - f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef());' - ) + shim_fn_codes = f'RAIIAtenRecordFunctionHandle .*\\("{kernel_calls}"' if enable_kernel_profile: - FileCheck().check(shim_fn_codes).run(code) + FileCheck().check_regex(shim_fn_codes).run(code) else: - FileCheck().check_not(shim_fn_codes).run(code) + FileCheck().check_not("RAIIAtenRecordFunctionHandle").run(code) + + self.check_model(Model(N, K, self.device), example_inputs) def test_aoti_debug_printer_user_defined_triton_kernel(self): if self.device != GPU_TYPE: diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e71a1d91b0ff..a585cb6951a8 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -5467,7 +5467,7 @@ class KernelGroup: "win32", ] if enable_kernel_profile: - code.writelines(["#include "]) + code.writelines(["#include "]) code.writeline("#include ") # 2. Function definition @@ -5490,7 +5490,10 @@ class KernelGroup: prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" code.writelines( [ - f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ( + "torch::aot_inductor::RAIIAtenRecordFunctionHandle " + f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);' + ) ] ) for old, new in self.args.aliases(): diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 09ee0b184892..d72f13a3e3fa 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -131,7 +131,7 @@ class CppTemplate(KernelTemplate): "win32", ] if enable_kernel_profile: - res.writelines(["#include "]) + res.writelines(["#include "]) return res def render(self, **kwargs) -> str: diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 184c0fe889af..b0dee69b012b 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -190,7 +190,11 @@ class CppTemplateKernel(CppKernel): if config.cpp.enable_kernel_profile: graph_id = V.graph.graph_id prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" - return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' + handle_str = ( + "torch::aot_inductor::RAIIAtenRecordFunctionHandle " + f'record_{prefix}{self.kernel_name}_("{prefix}{self.kernel_name}", nullptr);' + ) + return handle_str else: return "" diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 794a971adf08..9b1b0ac075ed 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -233,15 +233,6 @@ class CppWrapperCpu(PythonWrapperCodegen): self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""") self.header.splice("\n") - enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ - "linux", - "win32", - ] - if config.profiler_mark_wrapper_call or enable_kernel_profile: - # No C shim for profiling APIs, assuming profiling is a debugging feature which - # does not provide any ABI compatibility promise. - self.header.splice("#include ") - def _include_extra_header(self, header: str): # This is needed for cpp to python dtype conversion self.header.splice(f"#include <{header}>") @@ -1251,7 +1242,7 @@ class CppWrapperCpu(PythonWrapperCodegen): shim_fn_codes = textwrap.dedent( f""" {{ - RECORD_FUNCTION("{shim_fn}", c10::ArrayRef()); + RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr); {shim_fn_codes} }} """ @@ -1495,7 +1486,7 @@ class CppWrapperCpu(PythonWrapperCodegen): def generate_profiler_mark_wrapper_call(self, stack): self.wrapper_call.writeline( - 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' + 'RAIIAtenRecordFunctionHandle record_inductor_wrapper_call_("inductor_wrapper_call", nullptr);' ) def generate_start_graph(self): diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index 8d1dd116afe5..b813b3f6f745 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -42,11 +42,80 @@ using DeleterFnPtr = void (*)(void*); inline void noop_deleter(void*) {} +inline void delete_record_function_object(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_record_function_end( + reinterpret_cast(ptr))); +} + inline void delete_tensor_object(void* ptr) { AOTI_TORCH_ERROR_CODE_CHECK( aoti_torch_delete_tensor_object(reinterpret_cast(ptr))); } +class RAIIAtenRecordFunctionHandle { + public: + RAIIAtenRecordFunctionHandle() : handle_(nullptr, noop_deleter) {} + RAIIAtenRecordFunctionHandle(const RAIIAtenRecordFunctionHandle& other) = + delete; + RAIIAtenRecordFunctionHandle& operator=( + const RAIIAtenRecordFunctionHandle& other) = delete; + + // Initiate an RAII RecordFunction without Inputs + RAIIAtenRecordFunctionHandle(const char* name, IValueMapHandle kwargs) + : handle_(nullptr, delete_record_function_object) { + AtenRecordFunctionHandle tmp_handle = nullptr; + aoti_record_function_start(name, kwargs, nullptr, 0, &tmp_handle); + handle_.reset(tmp_handle); + } + + // Initiate an RAII RecordFunction with Inputs + RAIIAtenRecordFunctionHandle( + const char* name, + IValueMapHandle kwargs, + std::vector inputs) + : handle_(nullptr, delete_record_function_object) { + AtenRecordFunctionHandle tmp_handle = nullptr; + aoti_record_function_start( + name, kwargs, inputs.data(), inputs.size(), &tmp_handle); + handle_.reset(tmp_handle); + } + + // Steal the ownership from another RAIIAtenRecordFunctionHandle using + // std::move + RAIIAtenRecordFunctionHandle(RAIIAtenRecordFunctionHandle&& other) = default; + RAIIAtenRecordFunctionHandle& operator=( + RAIIAtenRecordFunctionHandle&& other) = default; + + // Steal the ownership from raw AtenRecordFunctionHandle + RAIIAtenRecordFunctionHandle(AtenRecordFunctionHandle handle) + : handle_(handle, delete_record_function_object) {} + + ~RAIIAtenRecordFunctionHandle() { + handle_.reset(); + } + + // Return a raw AtenRecordFunctionHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenRecordFunctionHandle() const { + return handle_.get(); + } + + AtenRecordFunctionHandle release() { + return handle_.release(); + } + + AtenRecordFunctionHandle get() const { + return handle_.get(); + } + + void reset() { + handle_.reset(); + } + + private: + std::unique_ptr handle_; +}; + // RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI class RAIIAtenTensorHandle { public: diff --git a/torch/csrc/inductor/aoti_torch/c/macros.h b/torch/csrc/inductor/aoti_torch/c/macros.h index 6f1346cdcf86..e49cd39deac0 100644 --- a/torch/csrc/inductor/aoti_torch/c/macros.h +++ b/torch/csrc/inductor/aoti_torch/c/macros.h @@ -52,6 +52,9 @@ using AtenGeneratorHandle = AtenGeneratorOpaque*; struct AOTIProxyExecutorOpaque; using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; +struct C10IValueOpaque; +using C10IValueHandle = C10IValueOpaque*; + using AOTITorchError = int32_t; #define AOTI_TORCH_SUCCESS 0 #define AOTI_TORCH_FAILURE 1 diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index a5083bb1405f..8bda9bcc28a2 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -400,6 +400,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); +struct AtenRecordFunctionOpaque; +using AtenRecordFunctionHandle = AtenRecordFunctionOpaque*; + +struct IValueMapOpaque; +using IValueMapHandle = IValueMapOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_record_function_start( + const char* name, + IValueMapHandle kwargs, + const C10IValueHandle* inputs, + const uint64_t n_inputs, + AtenRecordFunctionHandle* guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_record_function_end(AtenRecordFunctionHandle guard); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out( AtenTensorHandle out, AtenTensorHandle self, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 89218e4e5c98..b52fc3f363cb 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -1091,6 +1092,45 @@ AOTITorchError aoti_torch_check_inf_and_nan( }); } +AOTITorchError aoti_record_function_start( + const char* name, + IValueMapHandle kwargs, + const C10IValueHandle* inputs, + const uint64_t n_inputs, + AtenRecordFunctionHandle* guard) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::RecordFunction* newGuard = + new at::RecordFunction(at::RecordScope::FUNCTION); + std::unordered_map recordKwargs; + + if (kwargs != nullptr) { + auto wrappedKwargs = + reinterpret_cast*>( + kwargs); + for (const auto& pair : *wrappedKwargs) { + recordKwargs.emplace( + pair.first, *(reinterpret_cast(pair.second))); + } + } + + std::vector recordInputs(n_inputs); + for (size_t i = 0; i < n_inputs; i++) { + recordInputs.push_back(*reinterpret_cast(inputs[i])); + } + + newGuard->before(name, &recordInputs, &recordKwargs); + *guard = reinterpret_cast(newGuard); + }); +} + +AOTITorchError aoti_record_function_end(AtenRecordFunctionHandle guard) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::RecordFunction* t = reinterpret_cast(guard); + + delete t; + }); +} + AOTITorchError aoti_torch_scatter_out( AtenTensorHandle out, AtenTensorHandle self,