[AOTI] support freezing for MKLDNN (#124350)

## Description
Fixes https://github.com/pytorch/pytorch/issues/114450. This PR builds on the work done by @imzhuhl in https://github.com/pytorch/pytorch/pull/114451.

This PR requires https://github.com/pytorch/pytorch/pull/122472 to land first.

We leverage the serialization and deserialization APIs from oneDNN v3.4.1 to save the opaque MKLDNN tensor during compilation and to restore the opaque tensor when loading the compiled .so.
The ideep version is updated so that we won't break any pipeline even if third_party/ideep is not updated at the same time.
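
As an illustration only (not part of this PR's tests), here is a minimal sketch of the private ops this PR registers for that purpose, assuming a CPU build with MKLDNN enabled and an ideep/oneDNN new enough (>= v3.4.1) to support descriptor serialization:

```python
import torch

# Illustrative sketch: exercise the private mkldnn ops registered in this PR
# for AOTI constant serialization. Requires MKLDNN-enabled PyTorch and
# ideep >= 3.4.1; otherwise the serialization op raises a TORCH_CHECK error.
w = torch.randn(16, 16).to_mkldnn()  # opaque MKLDNN tensor

# 1-D uint8 tensor holding the serialized ideep descriptor ("opaque metadata")
serialized_md = torch.ops.mkldnn._get_mkldnn_serialized_md(w)

nbytes = torch.ops.mkldnn._nbytes(w)     # size in bytes of the opaque payload
data_ptr = torch.ops.mkldnn.data_ptr(w)  # raw data pointer, returned as int

print(serialized_md.dtype, serialized_md.numel(), nbytes, hex(data_ptr))
```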

### Test plan:
```sh
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_conv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_deconv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_linear_freezing_non_abi_compatible_cpu
```
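
For context, a rough sketch of what these tests exercise end to end, modeled on the new test cases (the module below is illustrative, not the exact test code; it assumes a CPU build with MKLDNN enabled):

```python
import torch
from torch._inductor import config


class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(10, 10)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight)


example_inputs = (torch.randn(10, 10),)

# Freezing constant-folds and prepacks the weight into an opaque MKLDNN tensor;
# with this PR the packed weight is serialized into the generated .so and
# restored when the .so is loaded.
with torch.no_grad(), config.patch({"freezing": True}):
    so_path = torch._export.aot_compile(LinearModel(), example_inputs)
print(so_path)
```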

### TODOs in follow-up PRs
1. We found that using `AOTI_TORCH_CHECK` causes a performance drop on several models (`DistillGPT2`, `MBartForConditionalGeneration`, `T5ForConditionalGeneration`, `T5Small`) compared with JIT Inductor, which uses `TORCH_CHECK`. How to address this may need further discussion (`AOTI_TORCH_CHECK` was introduced in https://github.com/pytorch/pytorch/pull/119220).
2. Freezing in non-ABI-compatible mode works with the support in this PR. For ABI-compatible mode, we first need to address this issue: `AssertionError: None, i.e. optional output is not supported` (6c4f43f826/torch/_inductor/codegen/cpp_wrapper_cpu.py (L2023-L2024)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124350
Approved by: https://github.com/jgong5, https://github.com/desertfire
Author: Wu, Chunyuan
Date: 2024-05-24 08:01:27 +00:00
Committed by: PyTorch MergeBot
Parent: 43baabe9b9
Commit: 654afb6f3a
17 changed files with 447 additions and 17 deletions

View File

@ -1,6 +1,7 @@
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>
#if AT_MKLDNN_ENABLED()
@ -61,6 +62,33 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
}
}
int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
return reinterpret_cast<int64_t>(data_ptr);
}
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
std::vector<uint8_t> vector_serialized_md{
opaque_metadata, opaque_metadata + opaque_metadata_size};
ideep::tensor::desc deserialized_ideep_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
// groups is needed for grouped conv
deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
#endif
auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
}
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
@ -81,6 +109,11 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
return mklimpl->unsafe_opaque_handle()->get_target();
}
int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
return t.get_desc().get_size();
}
ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
TORCH_CHECK(
tensor.device().is_cpu(),
@ -167,6 +200,15 @@ int set_verbose(int level) {
return ideep::utils::set_verbose(level);
}
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
TORCH_FN(data_ptr_from_mkldnn));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
TORCH_FN(nbytes_from_mkldnn));
}
}}
#endif // AT_MKLDNN_ENABLED()

View File

@ -28,12 +28,24 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
return get_mkldnn_dtype(t.scalar_type());
}
TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);
TORCH_API at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);
// Construct aten MKL-DNN tensor given an ideep tensor
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device);
// Retrieve `ideep::tensor` from MKL-DNN tensor
TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor);
// Construct an `ideep::tensor` "view" from dense tensor, note the
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false);

View File

@ -12,7 +12,9 @@
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
@ -508,6 +510,25 @@ static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
return {packed_w1, packed_w2};
}
static Tensor get_mkldnn_serialized_md(const Tensor& self) {
const ideep::tensor packed_w = itensor_from_tensor(self);
auto packed_w_desc = packed_w.get_desc();
std::vector<uint8_t> serialized_wei_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
serialized_wei_desc = packed_w_desc.get_blob();
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
auto res = at::empty_like(serialized_md);
// serialized_md shares the buffer with serialized_wei_desc,
// which will be released outside of this function thus invalidating the buffer of serialized_md.
// A copy is needed here so that res has its own buffer, which remains valid even after serialized_wei_desc is released.
res.copy_(serialized_md);
return res;
}
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@ -523,6 +544,12 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
TORCH_FN(get_mkldnn_serialized_md ));
}
#else
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {

View File

@ -74,6 +74,9 @@ TORCH_LIBRARY(mkldnn, m) {
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}
TORCH_LIBRARY(mkldnn_prepacked, m) {

View File

@ -471,6 +471,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import sys
import tempfile
@ -89,6 +90,8 @@ def check_model(
options=None,
dynamic_shapes=None,
disable_constraint_solver=False,
atol=None,
rtol=None,
):
with torch.no_grad(), config.patch(
{
@ -114,7 +117,7 @@ def check_model(
disable_constraint_solver,
)
self.assertTrue(same(actual, expected))
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
def check_model_with_multiple_inputs(
@ -312,6 +315,10 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_freezing(self):
class Model(torch.nn.Module):
def __init__(self, device):
@ -331,6 +338,80 @@ class AOTInductorTestsTemplate:
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_conv_freezing(self):
for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
iC = 2
oC = 3
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
dtype
)
def forward(self, y):
return torch.nn.functional.conv2d(y, self.weight, groups=groups)
example_inputs = (
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_deconv_freezing(self):
dtypes = [torch.float]
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
for dtype, groups in itertools.product(dtypes, [2, 1]):
iC = 4
oC = 2
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
dtype
)
def forward(self, y):
return torch.nn.functional.conv_transpose2d(
y, self.weight, groups=groups
)
example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_linear_freezing(self):
for dtype in [torch.float32, torch.bfloat16]:
class LinearModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)
def forward(self, y):
return torch.nn.functional.linear(y, self.weight)
example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(LinearModel(self.device), example_inputs)
@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
@ -1390,7 +1471,9 @@ class AOTInductorTestsTemplate:
torch.randn(87, 87, device=self.device),
torch.randn(87, 87, device=self.device),
)
self.check_model(Model(), example_inputs)
self.check_model(
Model(), example_inputs, atol=1e-4, rtol=1e-4
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
if self.device == "cuda":
so_path = torch._export.aot_compile(Model(), example_inputs)
@ -2872,6 +2955,12 @@ def fail_non_abi_compatible_cuda(is_skip=False):
# test_failures, xfail by default, set is_skip=True to skip
CPU_TEST_FAILURES = {
"test_add_complex": fail_stack_allocation(is_skip=True),
# TODO: test_conv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_deconv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_duplicate_constant_folding": fail_with_and_without_stack_allocation(
is_skip=True
@ -2885,9 +2974,12 @@ CPU_TEST_FAILURES = {
"test_dynamic_scalar": fail_stack_allocation(is_skip=True),
# https://github.com/pytorch/pytorch/issues/122980
"test_fft_c2c": fail_stack_allocation(is_skip=True),
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
# TODO: test_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_linear_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
# minimal arrayref interface only works with CPU; test crashes.
@ -3129,9 +3221,6 @@ copy_tests(
"test_duplicate_constant_folding": TestFailure(
("non_abi_compatible_cpu",), is_skip=True
),
# TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
# no runtime checks for non_abi_compatible mode
"test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
"test_runtime_checks_dtype_failed": TestFailure(

View File

@ -1522,6 +1522,10 @@ def use_custom_generated_macros() -> str:
def use_fb_internal_macros() -> str:
if config.is_fbcode():
# TODO: this is to avoid FC breakage for fbcode. When using newly
# generated model.so on an older version of PyTorch, need to use
# the v1 version for aoti_torch_create_tensor_from_blob
create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
openmp_lib = build_paths.openmp_lib()
preprocessor_flags = " ".join(
(
@ -1530,7 +1534,7 @@ def use_fb_internal_macros() -> str:
"-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
)
)
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
else:
return ""
@ -2076,7 +2080,9 @@ class AotCodeCompiler:
output_o = os.path.splitext(input_path)[0] + ".o"
consts_size = sum(
tensor.untyped_storage().nbytes()
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
for (name, tensor) in graph.constants.items()
if name not in graph.folded_constants
)
@ -2109,6 +2115,13 @@ class AotCodeCompiler:
if t.numel() == 0:
return b""
if t.is_mkldnn:
raw_array = ctypes.cast(
torch.ops.mkldnn.data_ptr(t),
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)),
)
return bytes(raw_array.contents)
t_cpu = t.untyped_storage().cpu()
raw_array = ctypes.cast(
t_cpu.data_ptr(),

View File

@ -1971,6 +1971,8 @@ class CppKernel(Kernel):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
# TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
# compared with JIT Inductor which uses TORCH_CHECK
return "AOTI_TORCH_CHECK"
else:
return "TORCH_CHECK"

View File

@ -64,6 +64,11 @@ DEVICE_TO_ATEN = {
"cuda": "at::kCUDA",
}
LAYOUT_TO_ATEN = {
torch.strided: "at::kStrided",
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
}
INDEX_TYPE = "long"
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])

View File

@ -18,7 +18,14 @@ from ..utils import cache_on_self, sympy_product
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer
from .cpp_utils import cexpr, CppPrinter, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .cpp_utils import (
cexpr,
CppPrinter,
DEVICE_TO_ATEN,
DTYPE_TO_ATEN,
DTYPE_TO_CPP,
LAYOUT_TO_ATEN,
)
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen
@ -56,6 +63,7 @@ class CppWrapperCpu(WrapperCodeGen):
self.arg_var_id = count()
self.used_cached_devices = set()
self.used_cached_dtypes = set()
self.used_cached_layouts = set()
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False
@ -722,6 +730,11 @@ class CppWrapperCpu(WrapperCodeGen):
self.prefix.writeline(
f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
)
if tensor.is_mkldnn:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};"
)
else:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
)
@ -737,6 +750,23 @@ class CppWrapperCpu(WrapperCodeGen):
self.prefix.writeline(
f"constants_info_[{idx}].stride = {{{stride_str}}};"
)
self.prefix.writeline(
f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
)
if tensor.is_mkldnn:
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
tensor
)
assert (
opaque_metadata_tensor.dim() == 1
), "Expect opaque_metadata_tensor to be 1-D"
opaque_metadata_list = opaque_metadata_tensor.tolist()
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
self.prefix.writeline(
f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
)
if name in V.graph.dynamo_flat_name_to_original_fqn:
original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
name, name
@ -877,6 +907,8 @@ class CppWrapperCpu(WrapperCodeGen):
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
for device in self.used_cached_devices:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
for layout in self.used_cached_layouts:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
cached_dtypes_buffer.splice(self.prefix)
self.prefix = cached_dtypes_buffer
@ -1493,6 +1525,14 @@ class CppWrapperCpu(WrapperCodeGen):
else:
return DTYPE_TO_ATEN[dtype]
def codegen_layout(self, layout):
if config.abi_compatible:
layout_str = str(layout).split(".")[-1]
self.used_cached_layouts.add(layout_str)
return f"cached_torch_layout_{layout_str}"
else:
return LAYOUT_TO_ATEN[layout]
@functools.lru_cache(None)
def codegen_int_array_var(
self,

View File

@ -18,7 +18,7 @@ from ..pattern_matcher import (
KeywordArg,
MULTIPLE,
)
from ..virtualized import ops
from ..virtualized import ops, V
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
from .quantization import (
@ -1146,9 +1146,18 @@ if torch._C._has_mkldnn:
if has_free_symbols(batch_size)
else batch_size,
)
# MKL packed matrix can't be copied to a different address because the internal implementation
# depends on the alignment of internally-stored metadata.
# In AOT mode, we need to first save the packed weight; when it is loaded,
# it will be at a different address, which doesn't work.
# Disable MKL prepack linear in AOT mode.
packed_weight_op = (
mkldnn._reorder_linear_weight
if (is_lp_weight or mkldnn._is_mkldnn_acl_supported())
if (
is_lp_weight
or mkldnn._is_mkldnn_acl_supported()
or V.aot_compilation is True
)
else torch.ops.mkl._mkl_reorder_linear_weight
)
packed_weight_node = graph.create_node(
@ -1156,7 +1165,11 @@ if torch._C._has_mkldnn:
)
packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
if is_lp_weight or mkldnn._is_mkldnn_acl_supported():
if (
is_lp_weight
or mkldnn._is_mkldnn_acl_supported()
or V.aot_compilation is True
):
packed_linear_inputs += (bias, "none", [], "")
packed_linear_op = mkldnn._linear_pointwise.default
else:

View File

@ -222,8 +222,17 @@ class AOTInductorModelBase {
auto size = this->constant_shape(i);
auto stride = this->constant_stride(i);
auto offset = this->constant_offset(i);
auto layout = this->constant_layout(i);
auto opaque_metadata_ptr = this->opaque_metadata(i);
auto opaque_metadata_size = this->opaque_metadata_size(i);
AtenTensorHandle tensor_handle;
#ifdef AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
// When opaque_metadata_size is not 0, we need to have the
// aoti_torch_create_tensor_from_blob_v2 available
AOTI_RUNTIME_CHECK(
opaque_metadata_size == 0,
"Expect opaque_metadata_size to be 0 when AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 is defined");
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
internal_ptr,
ndim,
@ -234,6 +243,21 @@ class AOTInductorModelBase {
device_type_,
device_idx_,
&tensor_handle));
#else
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2(
internal_ptr,
ndim,
size,
stride,
offset,
dtype,
device_type_,
device_idx_,
&tensor_handle,
layout,
opaque_metadata_ptr,
opaque_metadata_size));
#endif // AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
constants_map_->emplace(std::move(name), tensor_handle);
}
if (constants_map_) {
@ -340,6 +364,10 @@ class AOTInductorModelBase {
return constants_info_.at(idx).dtype;
}
int32_t constant_layout(int64_t idx) const {
return constants_info_.at(idx).layout;
}
size_t constant_offset(int64_t idx) const {
return constants_info_.at(idx).offset;
}
@ -352,6 +380,14 @@ class AOTInductorModelBase {
return constants_info_.at(idx).original_fqn;
}
const uint8_t* opaque_metadata(int64_t idx) const {
return constants_info_.at(idx).opaque_metadata.data();
}
size_t opaque_metadata_size(int64_t idx) {
return constants_info_.at(idx).opaque_metadata.size();
}
bool constant_from_folded(int64_t idx) const {
return constants_info_.at(idx).from_folded;
}
@ -485,6 +521,9 @@ class AOTInductorModelBase {
int32_t dtype;
int64_t offset;
size_t data_size;
int32_t layout;
std::vector<uint8_t> opaque_metadata;
int64_t opaque_metadata_size;
const char* original_fqn = nullptr;
bool from_folded;
};

View File

@ -174,4 +174,7 @@ inline AtenTensorHandle wrap_with_raii_handle_if_needed(
static auto cached_torch_device_type_##device = \
aoti_torch_device_type_##device()
#define CACHE_TORCH_LAYOUT(layout) \
static auto cached_torch_layout_##layout = aoti_torch_layout_##layout()
} // namespace torch::aot_inductor

View File

@ -112,6 +112,9 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn();
// Functions for converting a single-element tensor to a scalar value
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value);
@ -270,6 +273,20 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret // returns new reference
);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
AtenTensorHandle* ret, // returns new reference
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
AtenTensorHandle weight,
AtenTensorHandle indices,

View File

@ -0,0 +1,49 @@
#include <ATen/Config.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ideep.hpp>
#endif
namespace torch {
namespace aot_inductor {
#if AT_MKLDNN_ENABLED()
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
return reinterpret_cast<void*>(
at::native::data_ptr_from_mkldnn(*mkldnn_tensor));
}
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
return at::native::mkldnn_tensor_from_data_ptr(
data_ptr, dims, dtype, device, opaque_metadata, opaque_metadata_size);
}
#else
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}
#endif
} // namespace aot_inductor
} // namespace torch

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/Tensor.h>
namespace torch {
namespace aot_inductor {
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor);
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);
} // namespace aot_inductor
} // namespace torch

View File

@ -1,8 +1,10 @@
#include <c10/core/DeviceType.h>
#include <c10/core/GradMode.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
@ -90,6 +92,14 @@ AOTI_TORCH_DTYPE_IMPL(complex64, ComplexFloat)
AOTI_TORCH_DTYPE_IMPL(complex128, ComplexDouble)
#undef AOTI_TORCH_DTYPE_IMPL
int32_t aoti_torch_layout_strided() {
return (int32_t)at::kStrided;
}
int32_t aoti_torch_layout__mkldnn() {
return (int32_t)at::kMkldnn;
}
#define AOTI_TORCH_ITEM_IMPL(dtype, ctype) \
AOTITorchError aoti_torch_item_##dtype( \
AtenTensorHandle tensor, ctype* ret_value) { \
@ -154,7 +164,11 @@ AOTITorchError aoti_torch_get_data_ptr(
void** ret_data_ptr) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
if (t->is_mkldnn()) {
*ret_data_ptr = data_ptr_from_mkldnn(t);
} else {
*ret_data_ptr = t->data_ptr();
}
});
}
@ -325,6 +339,48 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
});
}
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
AtenTensorHandle* ret_new_tensor,
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
if (layout == static_cast<int32_t>(at::kMkldnn)) {
c10::IntArrayRef sizes(sizes_ptr, ndim);
c10::IntArrayRef strides(strides_ptr, ndim);
c10::Device device = c10_device(device_type, device_index);
// Get an mkldnn tensor wrapped by a torch Tensor (OpaqueTensorImpl),
// which is used by later mkldnn ops.
*ret_new_tensor = new_tensor_handle(mkldnn_tensor_from_data_ptr(
data,
sizes,
static_cast<c10::ScalarType>(dtype),
device,
opaque_metadata,
opaque_metadata_size));
} else {
aoti_torch_create_tensor_from_blob(
data,
ndim,
sizes_ptr,
strides_ptr,
storage_offset,
dtype,
device_type,
device_index,
ret_new_tensor);
}
});
}
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
AtenTensorHandle weight,
AtenTensorHandle indices,