[AOTInductor] ABI-Compatibility for RecordFunction. (#159842)

Summary:
Previously, our implementation of RecordFunction injected ATen into the
codegen, which breaks the ABI contract for AOTInductor.

C10::IValue is added to call the full record function. Support for
more profiling info will come in later PRs.

Test Plan:
Included in commit.

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D79622071](https://our.internmc.facebook.com/intern/diff/D79622071)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159842
Approved by: https://github.com/desertfire
This commit is contained in:
Mu-Chu Lee
2025-08-13 21:12:33 +00:00
committed by PyTorch MergeBot
parent 8ca8b6053c
commit 40311e2ec1
9 changed files with 146 additions and 20 deletions

View File

@ -5031,13 +5031,13 @@ class AOTInductorTestsTemplate:
_, code = run_and_get_cpp_code(
AOTIRunnerUtil.compile, model, example_inputs
)
shim_fn_codes = (
f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef<c10::IValue>());'
)
shim_fn_codes = f'RAIIAtenRecordFunctionHandle .*\\("{kernel_calls}"'
if enable_kernel_profile:
FileCheck().check(shim_fn_codes).run(code)
FileCheck().check_regex(shim_fn_codes).run(code)
else:
FileCheck().check_not(shim_fn_codes).run(code)
FileCheck().check_not("RAIIAtenRecordFunctionHandle").run(code)
self.check_model(Model(N, K, self.device), example_inputs)
def test_aoti_debug_printer_user_defined_triton_kernel(self):
if self.device != GPU_TYPE:

View File

@ -5467,7 +5467,7 @@ class KernelGroup:
"win32",
]
if enable_kernel_profile:
code.writelines(["#include <ATen/record_function.h>"])
code.writelines(["#include <torch/csrc/inductor/aoti_runtime/utils.h>"])
code.writeline("#include <torch/csrc/inductor/cpp_prefix.h>")
# 2. Function definition
@ -5490,7 +5490,10 @@ class KernelGroup:
prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
code.writelines(
[
f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
(
"torch::aot_inductor::RAIIAtenRecordFunctionHandle "
f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);'
)
]
)
for old, new in self.args.aliases():

View File

@ -131,7 +131,7 @@ class CppTemplate(KernelTemplate):
"win32",
]
if enable_kernel_profile:
res.writelines(["#include <ATen/record_function.h>"])
res.writelines(["#include <torch/csrc/inductor/aoti_runtime/utils.h>"])
return res
def render(self, **kwargs) -> str:

View File

@ -190,7 +190,11 @@ class CppTemplateKernel(CppKernel):
if config.cpp.enable_kernel_profile:
graph_id = V.graph.graph_id
prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
handle_str = (
"torch::aot_inductor::RAIIAtenRecordFunctionHandle "
f'record_{prefix}{self.kernel_name}_("{prefix}{self.kernel_name}", nullptr);'
)
return handle_str
else:
return ""

View File

@ -233,15 +233,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
self.header.splice("\n")
enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
"linux",
"win32",
]
if config.profiler_mark_wrapper_call or enable_kernel_profile:
# No C shim for profiling APIs, assuming profiling is a debugging feature which
# does not provide any ABI compatibility promise.
self.header.splice("#include <ATen/record_function.h>")
def _include_extra_header(self, header: str):
# This is needed for cpp to python dtype conversion
self.header.splice(f"#include <{header}>")
@ -1251,7 +1242,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
shim_fn_codes = textwrap.dedent(
f"""
{{
RECORD_FUNCTION("{shim_fn}", c10::ArrayRef<c10::IValue>());
RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);
{shim_fn_codes}
}}
"""
@ -1495,7 +1486,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
def generate_profiler_mark_wrapper_call(self, stack):
self.wrapper_call.writeline(
'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
'RAIIAtenRecordFunctionHandle record_inductor_wrapper_call_("inductor_wrapper_call", nullptr);'
)
def generate_start_graph(self):

View File

@ -42,11 +42,80 @@ using DeleterFnPtr = void (*)(void*);
// Deleter that intentionally does nothing; used for handles that own no
// object (e.g. the default-constructed RAIIAtenRecordFunctionHandle).
inline void noop_deleter(void*) {}
// Deleter for record-function handles: casts `ptr` back to an
// AtenRecordFunctionHandle and passes it to aoti_record_function_end.
// The returned error code goes through AOTI_TORCH_ERROR_CODE_CHECK.
// NOTE(review): that macro is defined elsewhere; presumably it raises on a
// non-success code, which would happen inside a deleter -- confirm.
inline void delete_record_function_object(void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(aoti_record_function_end(
reinterpret_cast<AtenRecordFunctionHandle>(ptr)));
}
// Deleter for tensor handles: casts `ptr` back to an AtenTensorHandle and
// passes it to aoti_torch_delete_tensor_object. The returned error code goes
// through AOTI_TORCH_ERROR_CODE_CHECK.
// NOTE(review): the macro's failure behavior is defined elsewhere -- confirm.
inline void delete_tensor_object(void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
}
class RAIIAtenRecordFunctionHandle {
public:
RAIIAtenRecordFunctionHandle() : handle_(nullptr, noop_deleter) {}
RAIIAtenRecordFunctionHandle(const RAIIAtenRecordFunctionHandle& other) =
delete;
RAIIAtenRecordFunctionHandle& operator=(
const RAIIAtenRecordFunctionHandle& other) = delete;
// Initiate an RAII RecordFunction without Inputs
RAIIAtenRecordFunctionHandle(const char* name, IValueMapHandle kwargs)
: handle_(nullptr, delete_record_function_object) {
AtenRecordFunctionHandle tmp_handle = nullptr;
aoti_record_function_start(name, kwargs, nullptr, 0, &tmp_handle);
handle_.reset(tmp_handle);
}
// Initiate an RAII RecordFunction with Inputs
RAIIAtenRecordFunctionHandle(
const char* name,
IValueMapHandle kwargs,
std::vector<C10IValueHandle> inputs)
: handle_(nullptr, delete_record_function_object) {
AtenRecordFunctionHandle tmp_handle = nullptr;
aoti_record_function_start(
name, kwargs, inputs.data(), inputs.size(), &tmp_handle);
handle_.reset(tmp_handle);
}
// Steal the ownership from another RAIIAtenRecordFunctionHandle using
// std::move
RAIIAtenRecordFunctionHandle(RAIIAtenRecordFunctionHandle&& other) = default;
RAIIAtenRecordFunctionHandle& operator=(
RAIIAtenRecordFunctionHandle&& other) = default;
// Steal the ownership from raw AtenRecordFunctionHandle
RAIIAtenRecordFunctionHandle(AtenRecordFunctionHandle handle)
: handle_(handle, delete_record_function_object) {}
~RAIIAtenRecordFunctionHandle() {
handle_.reset();
}
// Return a raw AtenRecordFunctionHandle to be used by aoti_torch functions
// Note: this function does NOT transfer the ownership of the handle
operator AtenRecordFunctionHandle() const {
return handle_.get();
}
AtenRecordFunctionHandle release() {
return handle_.release();
}
AtenRecordFunctionHandle get() const {
return handle_.get();
}
void reset() {
handle_.reset();
}
private:
std::unique_ptr<AtenRecordFunctionOpaque, DeleterFnPtr> handle_;
};
// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
class RAIIAtenTensorHandle {
public:

View File

@ -52,6 +52,9 @@ using AtenGeneratorHandle = AtenGeneratorOpaque*;
struct AOTIProxyExecutorOpaque;
using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*;
// Opaque handle to an IValue for use across the C ABI. The shim
// implementation of aoti_record_function_start reinterprets it as a
// c10::IValue*.
struct C10IValueOpaque;
using C10IValueHandle = C10IValueOpaque*;
using AOTITorchError = int32_t;
#define AOTI_TORCH_SUCCESS 0
#define AOTI_TORCH_FAILURE 1

View File

@ -400,6 +400,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor);
// Opaque handle to a record-function guard. It is produced by
// aoti_record_function_start and consumed by aoti_record_function_end.
struct AtenRecordFunctionOpaque;
using AtenRecordFunctionHandle = AtenRecordFunctionOpaque*;
// Opaque handle to a string-keyed map of C10IValueHandle values, passed as
// the `kwargs` of a record function. May be null when there are no kwargs.
struct IValueMapOpaque;
using IValueMapHandle = IValueMapOpaque*;
// Starts a record function called `name`.
//   kwargs   - optional keyword arguments; nullptr means none.
//   inputs   - array of `n_inputs` IValue handles; not read when n_inputs is 0.
//   n_inputs - number of entries in `inputs`.
//   guard    - out-parameter that receives the new handle. The caller must
//              pass it to aoti_record_function_end to release it.
AOTI_TORCH_EXPORT AOTITorchError aoti_record_function_start(
const char* name,
IValueMapHandle kwargs,
const C10IValueHandle* inputs,
const uint64_t n_inputs,
AtenRecordFunctionHandle* guard);
// Releases a guard previously returned by aoti_record_function_start.
AOTI_TORCH_EXPORT AOTITorchError
aoti_record_function_end(AtenRecordFunctionHandle guard);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out(
AtenTensorHandle out,
AtenTensorHandle self,

View File

@ -1,4 +1,5 @@
#include <ATen/native/quantized/cpu/qlinear.h>
#include <ATen/record_function.h>
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/GradMode.h>
@ -1091,6 +1092,45 @@ AOTITorchError aoti_torch_check_inf_and_nan(
});
}
// Starts an at::RecordFunction (FUNCTION scope) called `name` and returns it
// through `guard` as an opaque handle.
//
//   kwargs   - optional; when non-null it is reinterpreted as a
//              std::unordered_map<std::string, C10IValueHandle>* and each
//              value is copied into a c10::IValue map.
//   inputs   - array of `n_inputs` handles, each reinterpreted as a
//              c10::IValue* and copied. Not read when n_inputs is 0.
//   guard    - out-parameter; written only on success. The caller releases
//              the handle with aoti_record_function_end.
//
// Exceptions are converted to an error code by
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE.
AOTITorchError aoti_record_function_start(
    const char* name,
    IValueMapHandle kwargs,
    const C10IValueHandle* inputs,
    const uint64_t n_inputs,
    AtenRecordFunctionHandle* guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::unordered_map<std::string, c10::IValue> recordKwargs;
    if (kwargs != nullptr) {
      auto wrappedKwargs =
          reinterpret_cast<std::unordered_map<std::string, C10IValueHandle>*>(
              kwargs);
      for (const auto& pair : *wrappedKwargs) {
        recordKwargs.emplace(
            pair.first, *(reinterpret_cast<c10::IValue*>(pair.second)));
      }
    }

    // Reserve instead of size-constructing: `recordInputs(n_inputs)` followed
    // by push_back would produce 2 * n_inputs entries, the first n_inputs of
    // them default-constructed IValues.
    std::vector<c10::IValue> recordInputs;
    recordInputs.reserve(n_inputs);
    for (uint64_t i = 0; i < n_inputs; i++) {
      recordInputs.push_back(*reinterpret_cast<c10::IValue*>(inputs[i]));
    }

    // Allocate the guard only after the argument containers are built, and
    // free it if before() throws, so a failure never leaks the guard.
    // NOTE(review): recordInputs / recordKwargs are locals passed by pointer;
    // confirm RecordFunction does not keep referring to them after before()
    // returns.
    at::RecordFunction* newGuard =
        new at::RecordFunction(at::RecordScope::FUNCTION);
    try {
      newGuard->before(name, &recordInputs, &recordKwargs);
    } catch (...) {
      delete newGuard;
      throw;
    }
    *guard = reinterpret_cast<AtenRecordFunctionHandle>(newGuard);
  });
}
// Destroys the at::RecordFunction behind `guard` (a handle obtained from
// aoti_record_function_start). A null `guard` is a no-op, since deleting a
// null pointer does nothing. Exceptions are converted to an error code by
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE.
AOTITorchError aoti_record_function_end(AtenRecordFunctionHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { delete reinterpret_cast<at::RecordFunction*>(guard); });
}
AOTITorchError aoti_torch_scatter_out(
AtenTensorHandle out,
AtenTensorHandle self,