mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[AOTInductor] ABI-Compatibility for RecordFunction. (#159842)
Summary: Previously, our implementation of RecordFunction injected ATen into codegen, which breaks the ABI contract for AOTInductor. C10::IValue is added to call the full record function. The extension to more profiling info will come in later PRs. Test Plan: Included in commit. Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D79622071](https://our.internmc.facebook.com/intern/diff/D79622071) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159842 Approved by: https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
8ca8b6053c
commit
40311e2ec1
@ -5031,13 +5031,13 @@ class AOTInductorTestsTemplate:
|
||||
_, code = run_and_get_cpp_code(
|
||||
AOTIRunnerUtil.compile, model, example_inputs
|
||||
)
|
||||
shim_fn_codes = (
|
||||
f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef<c10::IValue>());'
|
||||
)
|
||||
shim_fn_codes = f'RAIIAtenRecordFunctionHandle .*\\("{kernel_calls}"'
|
||||
if enable_kernel_profile:
|
||||
FileCheck().check(shim_fn_codes).run(code)
|
||||
FileCheck().check_regex(shim_fn_codes).run(code)
|
||||
else:
|
||||
FileCheck().check_not(shim_fn_codes).run(code)
|
||||
FileCheck().check_not("RAIIAtenRecordFunctionHandle").run(code)
|
||||
|
||||
self.check_model(Model(N, K, self.device), example_inputs)
|
||||
|
||||
def test_aoti_debug_printer_user_defined_triton_kernel(self):
|
||||
if self.device != GPU_TYPE:
|
||||
|
@ -5467,7 +5467,7 @@ class KernelGroup:
|
||||
"win32",
|
||||
]
|
||||
if enable_kernel_profile:
|
||||
code.writelines(["#include <ATen/record_function.h>"])
|
||||
code.writelines(["#include <torch/csrc/inductor/aoti_runtime/utils.h>"])
|
||||
code.writeline("#include <torch/csrc/inductor/cpp_prefix.h>")
|
||||
|
||||
# 2. Function definition
|
||||
@ -5490,7 +5490,10 @@ class KernelGroup:
|
||||
prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
|
||||
code.writelines(
|
||||
[
|
||||
f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
|
||||
(
|
||||
"torch::aot_inductor::RAIIAtenRecordFunctionHandle "
|
||||
f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);'
|
||||
)
|
||||
]
|
||||
)
|
||||
for old, new in self.args.aliases():
|
||||
|
@ -131,7 +131,7 @@ class CppTemplate(KernelTemplate):
|
||||
"win32",
|
||||
]
|
||||
if enable_kernel_profile:
|
||||
res.writelines(["#include <ATen/record_function.h>"])
|
||||
res.writelines(["#include <torch/csrc/inductor/aoti_runtime/utils.h>"])
|
||||
return res
|
||||
|
||||
def render(self, **kwargs) -> str:
|
||||
|
@ -190,7 +190,11 @@ class CppTemplateKernel(CppKernel):
|
||||
if config.cpp.enable_kernel_profile:
|
||||
graph_id = V.graph.graph_id
|
||||
prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
|
||||
return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
|
||||
handle_str = (
|
||||
"torch::aot_inductor::RAIIAtenRecordFunctionHandle "
|
||||
f'record_{prefix}{self.kernel_name}_("{prefix}{self.kernel_name}", nullptr);'
|
||||
)
|
||||
return handle_str
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
@ -233,15 +233,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""")
|
||||
self.header.splice("\n")
|
||||
|
||||
enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
|
||||
"linux",
|
||||
"win32",
|
||||
]
|
||||
if config.profiler_mark_wrapper_call or enable_kernel_profile:
|
||||
# No C shim for profiling APIs, assuming profiling is a debugging feature which
|
||||
# does not provide any ABI compatibility promise.
|
||||
self.header.splice("#include <ATen/record_function.h>")
|
||||
|
||||
def _include_extra_header(self, header: str):
|
||||
# This is needed for cpp to python dtype conversion
|
||||
self.header.splice(f"#include <{header}>")
|
||||
@ -1251,7 +1242,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
shim_fn_codes = textwrap.dedent(
|
||||
f"""
|
||||
{{
|
||||
RECORD_FUNCTION("{shim_fn}", c10::ArrayRef<c10::IValue>());
|
||||
RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);
|
||||
{shim_fn_codes}
|
||||
}}
|
||||
"""
|
||||
@ -1495,7 +1486,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
|
||||
def generate_profiler_mark_wrapper_call(self, stack):
|
||||
self.wrapper_call.writeline(
|
||||
'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
|
||||
'RAIIAtenRecordFunctionHandle record_inductor_wrapper_call_("inductor_wrapper_call", nullptr);'
|
||||
)
|
||||
|
||||
def generate_start_graph(self):
|
||||
|
@ -42,11 +42,80 @@ using DeleterFnPtr = void (*)(void*);
|
||||
|
||||
// Deleter with an empty body: the pointer it receives is left untouched.
inline void noop_deleter(void*) {}
|
||||
|
||||
// Deleter for record-function handles: casts `ptr` to AtenRecordFunctionHandle
// and passes it to aoti_record_function_end, with the call wrapped in
// AOTI_TORCH_ERROR_CODE_CHECK.
// NOTE(review): the macro presumably raises on a non-success status; confirm
// against its definition, since this deleter is invoked from destructors.
inline void delete_record_function_object(void* ptr) {
|
||||
AOTI_TORCH_ERROR_CODE_CHECK(aoti_record_function_end(
|
||||
reinterpret_cast<AtenRecordFunctionHandle>(ptr)));
|
||||
}
|
||||
|
||||
// Deleter for tensor handles: casts `ptr` to AtenTensorHandle and passes it to
// aoti_torch_delete_tensor_object, with the call wrapped in
// AOTI_TORCH_ERROR_CODE_CHECK.
inline void delete_tensor_object(void* ptr) {
|
||||
AOTI_TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
|
||||
}
|
||||
|
||||
class RAIIAtenRecordFunctionHandle {
|
||||
public:
|
||||
RAIIAtenRecordFunctionHandle() : handle_(nullptr, noop_deleter) {}
|
||||
RAIIAtenRecordFunctionHandle(const RAIIAtenRecordFunctionHandle& other) =
|
||||
delete;
|
||||
RAIIAtenRecordFunctionHandle& operator=(
|
||||
const RAIIAtenRecordFunctionHandle& other) = delete;
|
||||
|
||||
// Initiate an RAII RecordFunction without Inputs
|
||||
RAIIAtenRecordFunctionHandle(const char* name, IValueMapHandle kwargs)
|
||||
: handle_(nullptr, delete_record_function_object) {
|
||||
AtenRecordFunctionHandle tmp_handle = nullptr;
|
||||
aoti_record_function_start(name, kwargs, nullptr, 0, &tmp_handle);
|
||||
handle_.reset(tmp_handle);
|
||||
}
|
||||
|
||||
// Initiate an RAII RecordFunction with Inputs
|
||||
RAIIAtenRecordFunctionHandle(
|
||||
const char* name,
|
||||
IValueMapHandle kwargs,
|
||||
std::vector<C10IValueHandle> inputs)
|
||||
: handle_(nullptr, delete_record_function_object) {
|
||||
AtenRecordFunctionHandle tmp_handle = nullptr;
|
||||
aoti_record_function_start(
|
||||
name, kwargs, inputs.data(), inputs.size(), &tmp_handle);
|
||||
handle_.reset(tmp_handle);
|
||||
}
|
||||
|
||||
// Steal the ownership from another RAIIAtenRecordFunctionHandle using
|
||||
// std::move
|
||||
RAIIAtenRecordFunctionHandle(RAIIAtenRecordFunctionHandle&& other) = default;
|
||||
RAIIAtenRecordFunctionHandle& operator=(
|
||||
RAIIAtenRecordFunctionHandle&& other) = default;
|
||||
|
||||
// Steal the ownership from raw AtenRecordFunctionHandle
|
||||
RAIIAtenRecordFunctionHandle(AtenRecordFunctionHandle handle)
|
||||
: handle_(handle, delete_record_function_object) {}
|
||||
|
||||
~RAIIAtenRecordFunctionHandle() {
|
||||
handle_.reset();
|
||||
}
|
||||
|
||||
// Return a raw AtenRecordFunctionHandle to be used by aoti_torch functions
|
||||
// Note: this function does NOT transfer the ownership of the handle
|
||||
operator AtenRecordFunctionHandle() const {
|
||||
return handle_.get();
|
||||
}
|
||||
|
||||
AtenRecordFunctionHandle release() {
|
||||
return handle_.release();
|
||||
}
|
||||
|
||||
AtenRecordFunctionHandle get() const {
|
||||
return handle_.get();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
handle_.reset();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<AtenRecordFunctionOpaque, DeleterFnPtr> handle_;
|
||||
};
|
||||
|
||||
// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
|
||||
class RAIIAtenTensorHandle {
|
||||
public:
|
||||
|
@ -52,6 +52,9 @@ using AtenGeneratorHandle = AtenGeneratorOpaque*;
|
||||
struct AOTIProxyExecutorOpaque;
|
||||
using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*;
|
||||
|
||||
struct C10IValueOpaque;
|
||||
using C10IValueHandle = C10IValueOpaque*;
|
||||
|
||||
using AOTITorchError = int32_t;
|
||||
#define AOTI_TORCH_SUCCESS 0
|
||||
#define AOTI_TORCH_FAILURE 1
|
||||
|
@ -400,6 +400,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor);
|
||||
|
||||
struct AtenRecordFunctionOpaque;
|
||||
using AtenRecordFunctionHandle = AtenRecordFunctionOpaque*;
|
||||
|
||||
struct IValueMapOpaque;
|
||||
using IValueMapHandle = IValueMapOpaque*;
|
||||
|
||||
// Starts a record-function scope and returns an opaque handle to it.
//   name     - label passed to the underlying at::RecordFunction.
//   kwargs   - may be null; when non-null the implementation reinterprets it
//              as a std::unordered_map<std::string, C10IValueHandle>*.
//   inputs   - array of n_inputs handles, each reinterpreted as c10::IValue*;
//              not read when n_inputs is 0.
//   n_inputs - number of entries in `inputs`.
//   guard    - out-parameter that receives the newly allocated handle; it is
//              written last, after the scope has been started.
// The handle is heap-allocated by the implementation and is destroyed by
// aoti_record_function_end.
AOTI_TORCH_EXPORT AOTITorchError aoti_record_function_start(
|
||||
const char* name,
|
||||
IValueMapHandle kwargs,
|
||||
const C10IValueHandle* inputs,
|
||||
const uint64_t n_inputs,
|
||||
AtenRecordFunctionHandle* guard);
|
||||
|
||||
// Ends a record-function scope: the implementation reinterprets `guard` as an
// at::RecordFunction* and deletes it. The handle must not be used afterwards.
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_record_function_end(AtenRecordFunctionHandle guard);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out(
|
||||
AtenTensorHandle out,
|
||||
AtenTensorHandle self,
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <ATen/native/quantized/cpu/qlinear.h>
|
||||
#include <ATen/record_function.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/DispatchKey.h>
|
||||
#include <c10/core/GradMode.h>
|
||||
@ -1091,6 +1092,45 @@ AOTITorchError aoti_torch_check_inf_and_nan(
|
||||
});
|
||||
}
|
||||
|
||||
// Starts an at::RecordFunction scope named `name` and hands it back through
// `guard` as an opaque handle that aoti_record_function_end later deletes.
//
//   name     - label passed to RecordFunction::before.
//   kwargs   - may be null; otherwise it is reinterpreted as a
//              std::unordered_map<std::string, C10IValueHandle>* whose values
//              are each reinterpreted as c10::IValue*.
//   inputs   - array of n_inputs handles, each reinterpreted as c10::IValue*.
//              Not dereferenced when n_inputs is 0 (so it may be null then).
//   n_inputs - number of entries in `inputs`.
//   guard    - out-parameter; must be non-null. Written only after the scope
//              has been started successfully.
//
// The body runs inside AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE, which
// supplies the function's return value.
AOTITorchError aoti_record_function_start(
    const char* name,
    IValueMapHandle kwargs,
    const C10IValueHandle* inputs,
    const uint64_t n_inputs,
    AtenRecordFunctionHandle* guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Build the argument containers before allocating the guard so that an
    // exception thrown while copying IValues cannot leak it.
    std::unordered_map<std::string, c10::IValue> recordKwargs;
    if (kwargs != nullptr) {
      auto wrappedKwargs =
          reinterpret_cast<std::unordered_map<std::string, C10IValueHandle>*>(
              kwargs);
      recordKwargs.reserve(wrappedKwargs->size());
      for (const auto& pair : *wrappedKwargs) {
        recordKwargs.emplace(
            pair.first, *(reinterpret_cast<c10::IValue*>(pair.second)));
      }
    }

    // Reserve capacity instead of constructing the vector with n_inputs
    // elements: sizing it up front and then calling push_back would leave
    // n_inputs default-constructed placeholders ahead of the real inputs.
    std::vector<c10::IValue> recordInputs;
    recordInputs.reserve(static_cast<size_t>(n_inputs));
    for (uint64_t i = 0; i < n_inputs; i++) {
      recordInputs.push_back(*reinterpret_cast<c10::IValue*>(inputs[i]));
    }

    at::RecordFunction* newGuard =
        new at::RecordFunction(at::RecordScope::FUNCTION);
    try {
      // NOTE(review): recordInputs and recordKwargs are locals destroyed when
      // this function returns, while the guard outlives it. Confirm that
      // RecordFunction does not keep referring to them after before().
      newGuard->before(name, &recordInputs, &recordKwargs);
    } catch (...) {
      // Ownership has not been transferred to the caller yet.
      delete newGuard;
      throw;
    }
    *guard = reinterpret_cast<AtenRecordFunctionHandle>(newGuard);
  });
}
|
||||
|
||||
// Ends a record-function scope by deleting the at::RecordFunction that
// `guard` refers to. The body runs inside
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE, which supplies the return value.
AOTITorchError aoti_record_function_end(AtenRecordFunctionHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // The opaque handle is the RecordFunction pointer itself.
    delete reinterpret_cast<at::RecordFunction*>(guard);
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_scatter_out(
|
||||
AtenTensorHandle out,
|
||||
AtenTensorHandle self,
|
||||
|
Reference in New Issue
Block a user