Files
pytorch/torch/csrc/inductor/aoti_runtime/utils.h
Wu, Chunyuan 654afb6f3a [AOTI] support freezing for MKLDNN (#124350)
## Description
Fixes https://github.com/pytorch/pytorch/issues/114450. This PR builds upon the work from @imzhuhl done in https://github.com/pytorch/pytorch/pull/114451.

This PR requires https://github.com/pytorch/pytorch/pull/122472 to land first.

We leverage the serialization and deserialization API from oneDNN v3.4.1 to save the opaque MKLDNN tensor during the compilation and restore the opaque tensor when loading the compiled .so.
ideep version is updated so that we won't break any pipeline even if third_party/ideep is not updated at the same time.

### Test plan:
```sh
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_conv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_deconv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_linear_freezing_non_abi_compatible_cpu
```

### TODOs in follow-up PRs
1. We found that using `AOTI_TORCH_CHECK` causes a performance drop on several models (`DistillGPT2`, `MBartForConditionalGeneration`, `T5ForConditionalGeneration`, `T5Small`) compared with JIT Inductor, which uses `TORCH_CHECK`. How to address this may need further discussion (`AOTI_TORCH_CHECK` was introduced in
 https://github.com/pytorch/pytorch/pull/119220).
2. Freezing in non-ABI-compatible mode works with the support added in this PR. For ABI-compatible mode, we first need to address this issue: `AssertionError: None, i.e. optional output is not supported`.
6c4f43f826/torch/_inductor/codegen/cpp_wrapper_cpu.py (L2023-L2024)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124350
Approved by: https://github.com/jgong5, https://github.com/desertfire
2024-05-24 13:34:04 +00:00

181 lines
4.8 KiB
C++

#pragma once
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
// WARNING: Be careful when adding new includes here. This header will be used
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
// AOTI_NOINLINE asks the compiler not to inline the annotated function. It
// expands to the GCC/Clang or the MSVC spelling when one of those compilers is
// detected, and to nothing otherwise.
#if defined(__GNUC__) || defined(__clang__)
#define AOTI_NOINLINE __attribute__((noinline))
#elif defined(_MSC_VER)
#define AOTI_NOINLINE __declspec(noinline)
#else
#define AOTI_NOINLINE
#endif

// Reports a failed API call by throwing std::runtime_error.
//
//   call: text of the call expression that failed
//   file: name of the source file containing the call
//   line: line number of the call
//
// The exception message has the form
//   "<call> API call failed at <file>, line <line>".
// This function always throws and never returns normally, hence [[noreturn]].
[[noreturn]] AOTI_NOINLINE static void throw_exception(
    const char* call,
    const char* file,
    int64_t line) {
  std::ostringstream message;
  message << call << " API call failed at " << file << ", line " << line;
  throw std::runtime_error(message.str());
}
// Evaluates `call` exactly once and compares its result with
// AOTI_TORCH_SUCCESS. On any other result it calls throw_exception() with the
// stringified call expression and the __FILE__/__LINE__ of the expansion site.
// NOTE: this expands to a bare `if` statement (it is not wrapped in
// do { } while (0)), so put it inside braces when it is the body of an
// if/else.
#define AOTI_TORCH_ERROR_CODE_CHECK(call) \
if ((call) != AOTI_TORCH_SUCCESS) { \
throw_exception(#call, __FILE__, __LINE__); \
}
// 32-bit integer type for runtime status codes. The two values defined below
// are AOTI_RUNTIME_SUCCESS (0) and AOTI_RUNTIME_FAILURE (1).
using AOTIRuntimeError = int32_t;
#define AOTI_RUNTIME_SUCCESS 0
#define AOTI_RUNTIME_FAILURE 1
// Evaluates `call` exactly once and compares its result with
// AOTI_RUNTIME_SUCCESS. On any other result it calls throw_exception() with
// the stringified call expression and the __FILE__/__LINE__ of the expansion
// site.
// NOTE: this expands to a bare `if` statement (it is not wrapped in
// do { } while (0)), so put it inside braces when it is the body of an
// if/else.
#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \
if ((call) != AOTI_RUNTIME_SUCCESS) { \
throw_exception(#call, __FILE__, __LINE__); \
}
namespace torch::aot_inductor {
// Function-pointer signature for deleters that receive the object to destroy
// as an untyped pointer.
using DeleterFnPtr = void (*)(void*);

// Deleter that deliberately does nothing with the pointer it is given.
inline void noop_deleter(void* /*unused*/) {}
// Deleter matching the DeleterFnPtr signature: casts `ptr` to
// AtenTensorHandle and passes it to aoti_torch_delete_tensor_object. A
// non-success result is turned into an exception by
// AOTI_TORCH_ERROR_CODE_CHECK.
// NOTE(review): RAIIAtenTensorHandle installs this function as its unique_ptr
// deleter, so a failure here would throw while a handle is being reset or
// destroyed -- confirm that this is acceptable.
inline void delete_tensor_object(void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
}
// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
class RAIIAtenTensorHandle {
public:
RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {}
RAIIAtenTensorHandle(const RAIIAtenTensorHandle& other) = delete;
RAIIAtenTensorHandle& operator=(const RAIIAtenTensorHandle& other) = delete;
// Steal the ownership from another RAIIAtenTensorHandle using std::move
RAIIAtenTensorHandle(RAIIAtenTensorHandle&& other) = default;
RAIIAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) = default;
// Steal the ownership from raw AtenTensorHandle
RAIIAtenTensorHandle(AtenTensorHandle handle)
: handle_(handle, delete_tensor_object) {}
~RAIIAtenTensorHandle() {
handle_.reset();
}
// Return a raw AtenTensorHandle to be used by aoti_torch functions
// Note: this function does NOT transfer the ownership of the handle
operator AtenTensorHandle() const {
return handle_.get();
}
AtenTensorHandle release() {
return handle_.release();
}
AtenTensorHandle get() const {
return handle_.get();
}
void reset() {
handle_.reset();
}
int64_t size(int64_t d) {
int64_t size;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size));
return size;
}
int64_t stride(int64_t d) {
int64_t stride;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_get_stride(handle_.get(), d, &stride));
return stride;
}
int64_t storage_offset() {
int64_t storage_offset;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_get_storage_offset(handle_.get(), &storage_offset));
return storage_offset;
}
private:
std::unique_ptr<AtenTensorOpaque, DeleterFnPtr> handle_;
};
// Moves ownership of `size` raw handles, starting at `handles`, into a vector
// of RAIIAtenTensorHandle. Every source slot is overwritten with nullptr once
// its handle has been adopted, so the raw array no longer refers to the
// tensors afterwards.
inline std::vector<RAIIAtenTensorHandle> steal_from_raw_handles_to_raii_handles(
    AtenTensorHandle* handles,
    size_t size) {
  std::vector<RAIIAtenTensorHandle> owned;
  owned.reserve(size);
  AtenTensorHandle* const last = handles + size;
  for (AtenTensorHandle* slot = handles; slot != last; ++slot) {
    owned.emplace_back(*slot);
    *slot = nullptr;
  }
  return owned;
}
// Lightweight wrapper around an AtenTensorHandle that also caches the
// tensor's data pointer. This class never deletes the tensor it refers to.
class ConstantHandle {
 public:
  // Creates an empty wrapper: both the tensor handle and the cached data
  // pointer are null.
  ConstantHandle() = default;

  // Wraps `handle` and looks up its data pointer once through
  // aoti_torch_get_data_ptr. Throws std::runtime_error (via
  // AOTI_TORCH_ERROR_CODE_CHECK) if that call does not report success.
  explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) {
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_));
  }

  // Implicit conversion to the wrapped raw handle.
  operator AtenTensorHandle() const {
    return handle_;
  }

  // Returns the wrapped raw handle.
  AtenTensorHandle tensor() const {
    return handle_;
  }

  // Returns the data pointer captured at construction (nullptr for a
  // default-constructed wrapper).
  void* data_ptr() const {
    return data_;
  }

 private:
  // Value-initialized so that a default-constructed ConstantHandle never
  // exposes an indeterminate handle through tensor() or the conversion
  // operator.
  AtenTensorHandle handle_{};
  void* data_ = nullptr;
};
// Free-function form of ConstantHandle::data_ptr(): returns the data pointer
// cached inside `constant`.
inline void* get_data_ptr_wrapper(const ConstantHandle& constant) {
  void* const cached = constant.data_ptr();
  return cached;
}
// Identity overload for ConstantHandle: there is nothing to unwrap, so the
// very same object is handed back by reference.
inline const ConstantHandle& unwrap_raii_handle_if_needed(
    const ConstantHandle& handle) {
  const ConstantHandle& same_object = handle;
  return same_object;
}
// Shouldn't be called.
// Declared as deleted so that any call resolving to this ConstantHandle
// overload is rejected at compile time.
inline AtenTensorHandle wrap_with_raii_handle_if_needed(
const ConstantHandle& handle) = delete;
// Each CACHE_TORCH_* macro declares a `static` variable named
// cached_torch_<category>_<name> and initializes it with the result of
// calling the matching aoti_torch_<category>_<name>() function. No trailing
// semicolon is emitted; the caller supplies it.

// Caches aoti_torch_dtype_<dtype>() in cached_torch_dtype_<dtype>.
#define CACHE_TORCH_DTYPE(dtype) \
  static auto cached_torch_dtype_##dtype = aoti_torch_dtype_##dtype()

// Caches aoti_torch_device_type_<dev>() in cached_torch_device_type_<dev>.
#define CACHE_TORCH_DEVICE(dev)              \
  static auto cached_torch_device_type_##dev = \
      aoti_torch_device_type_##dev()

// Caches aoti_torch_layout_<name>() in cached_torch_layout_<name>.
#define CACHE_TORCH_LAYOUT(name) \
  static auto cached_torch_layout_##name = aoti_torch_layout_##name()
} // namespace torch::aot_inductor