Revert "[AOTI] support freezing for MKLDNN (#124350)"

This reverts commit 654afb6f3ae3ddbd926a753f9af95a6f6e22131c.

Reverted https://github.com/pytorch/pytorch/pull/124350 on behalf of https://github.com/clee2000 due to Seems to have broken inductor/test_aot_inductor.py::AOTInductorTestNonABICompatibleCpu::test_freezing_non_abi_compatible_cpu 654afb6f3a https://github.com/pytorch/pytorch/actions/runs/9224838183/job/25382780192 ([comment](https://github.com/pytorch/pytorch/pull/124350#issuecomment-2129889809))
This commit is contained in:
PyTorch MergeBot
2024-05-24 16:03:07 +00:00
parent 2ac739cc80
commit 5ae9daa4a2
17 changed files with 17 additions and 447 deletions

View File

@ -1,7 +1,6 @@
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>
#include <torch/library.h>
#if AT_MKLDNN_ENABLED()
@ -62,33 +61,6 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
}
}
int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
return reinterpret_cast<int64_t>(data_ptr);
}
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
std::vector<uint8_t> vector_serialized_md{
opaque_metadata, opaque_metadata + opaque_metadata_size};
ideep::tensor::desc deserialized_ideep_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
// groups is needed for grouped conv
deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
#endif
auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
}
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
@ -109,11 +81,6 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
return mklimpl->unsafe_opaque_handle()->get_target();
}
int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
return t.get_desc().get_size();
}
ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
TORCH_CHECK(
tensor.device().is_cpu(),
@ -200,15 +167,6 @@ int set_verbose(int level) {
return ideep::utils::set_verbose(level);
}
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
TORCH_FN(data_ptr_from_mkldnn));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
TORCH_FN(nbytes_from_mkldnn));
}
}}
#endif // AT_MKLDNN_ENABLED()

View File

@ -28,24 +28,12 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
return get_mkldnn_dtype(t.scalar_type());
}
TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);
TORCH_API at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);
// Construct aten MKL-DNN tensor given an ideep tensor
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device);
// Retrieve `ideep::tensor` from MKL-DNN tensor
TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor);
// Construct an `ideep::tensor` "view" from dense tensor, note the
// ideep::tensor will share the underlying buffer
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false);

View File

@ -12,9 +12,7 @@
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
@ -510,25 +508,6 @@ static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
return {packed_w1, packed_w2};
}
static Tensor get_mkldnn_serialized_md(const Tensor& self) {
const ideep::tensor packed_w = itensor_from_tensor(self);
auto packed_w_desc = packed_w.get_desc();
std::vector<uint8_t> serialized_wei_desc;
#if IDEEP_PREREQ(3, 4, 1, 2)
serialized_wei_desc = packed_w_desc.get_blob();
#else
TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
auto res = at::empty_like(serialized_md);
// serialized_md shares the buffer with serialized_wei_desc,
// which will be released outside of this function thus invalidating the buffer of serialized_md.
// A copy is needed here so that res has its own buffer, which remains valid even after serialized_wei_desc is released.
res.copy_(serialized_md);
return res;
}
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
@ -544,12 +523,6 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
TORCH_FN(get_mkldnn_serialized_md ));
}
#else
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {

View File

@ -74,9 +74,6 @@ TORCH_LIBRARY(mkldnn, m) {
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}
TORCH_LIBRARY(mkldnn_prepacked, m) {

View File

@ -471,7 +471,6 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import sys
import tempfile
@ -90,8 +89,6 @@ def check_model(
options=None,
dynamic_shapes=None,
disable_constraint_solver=False,
atol=None,
rtol=None,
):
with torch.no_grad(), config.patch(
{
@ -117,7 +114,7 @@ def check_model(
disable_constraint_solver,
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertTrue(same(actual, expected))
def check_model_with_multiple_inputs(
@ -315,10 +312,6 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_freezing(self):
class Model(torch.nn.Module):
def __init__(self, device):
@ -338,80 +331,6 @@ class AOTInductorTestsTemplate:
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_conv_freezing(self):
for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
iC = 2
oC = 3
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
dtype
)
def forward(self, y):
return torch.nn.functional.conv2d(y, self.weight, groups=groups)
example_inputs = (
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_deconv_freezing(self):
dtypes = [torch.float]
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
for dtype, groups in itertools.product(dtypes, [2, 1]):
iC = 4
oC = 2
class Model(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
dtype
)
def forward(self, y):
return torch.nn.functional.conv_transpose2d(
y, self.weight, groups=groups
)
example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
@unittest.skipIf(
IS_FBCODE,
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_linear_freezing(self):
for dtype in [torch.float32, torch.bfloat16]:
class LinearModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)
def forward(self, y):
return torch.nn.functional.linear(y, self.weight)
example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)
with config.patch({"freezing": True}):
self.check_model(LinearModel(self.device), example_inputs)
@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
@ -1471,9 +1390,7 @@ class AOTInductorTestsTemplate:
torch.randn(87, 87, device=self.device),
torch.randn(87, 87, device=self.device),
)
self.check_model(
Model(), example_inputs, atol=1e-4, rtol=1e-4
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
self.check_model(Model(), example_inputs)
if self.device == "cuda":
so_path = torch._export.aot_compile(Model(), example_inputs)
@ -2955,12 +2872,6 @@ def fail_non_abi_compatible_cuda(is_skip=False):
# test_failures, xfail by default, set is_skip=True to skip
CPU_TEST_FAILURES = {
"test_add_complex": fail_stack_allocation(is_skip=True),
# TODO: test_conv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_deconv_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_duplicate_constant_folding": fail_with_and_without_stack_allocation(
is_skip=True
@ -2974,12 +2885,9 @@ CPU_TEST_FAILURES = {
"test_dynamic_scalar": fail_stack_allocation(is_skip=True),
# https://github.com/pytorch/pytorch/issues/122980
"test_fft_c2c": fail_stack_allocation(is_skip=True),
# TODO: test_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# TODO: test_linear_freezing_abi_compatible_cpu fails,
# AssertionError: None, i.e. optional output is not supported
"test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True),
# FIXME: failed with Segfault while exiting the Python runtime
"test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
# minimal arrayref interface only works with CPU; test crashes.
@ -3221,6 +3129,9 @@ copy_tests(
"test_duplicate_constant_folding": TestFailure(
("non_abi_compatible_cpu",), is_skip=True
),
# TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
# no runtime checks for non_abi_compatible mode
"test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
"test_runtime_checks_dtype_failed": TestFailure(

View File

@ -1522,10 +1522,6 @@ def use_custom_generated_macros() -> str:
def use_fb_internal_macros() -> str:
if config.is_fbcode():
# TODO: this is to avoid FC breakage for fbcode. When using newly
# generated model.so on an older verion of PyTorch, need to use
# the v1 version for aoti_torch_create_tensor_from_blob
create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
openmp_lib = build_paths.openmp_lib()
preprocessor_flags = " ".join(
(
@ -1534,7 +1530,7 @@ def use_fb_internal_macros() -> str:
"-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
)
)
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
else:
return ""
@ -2080,9 +2076,7 @@ class AotCodeCompiler:
output_o = os.path.splitext(input_path)[0] + ".o"
consts_size = sum(
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
tensor.untyped_storage().nbytes()
for (name, tensor) in graph.constants.items()
if name not in graph.folded_constants
)
@ -2115,13 +2109,6 @@ class AotCodeCompiler:
if t.numel() == 0:
return b""
if t.is_mkldnn:
raw_array = ctypes.cast(
torch.ops.mkldnn.data_ptr(t),
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)),
)
return bytes(raw_array.contents)
t_cpu = t.untyped_storage().cpu()
raw_array = ctypes.cast(
t_cpu.data_ptr(),

View File

@ -1971,8 +1971,6 @@ class CppKernel(Kernel):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
# TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
# compared with JIT Inductor which uses TORCH_CHECK
return "AOTI_TORCH_CHECK"
else:
return "TORCH_CHECK"

View File

@ -64,11 +64,6 @@ DEVICE_TO_ATEN = {
"cuda": "at::kCUDA",
}
LAYOUT_TO_ATEN = {
torch.strided: "at::kStrided",
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
}
INDEX_TYPE = "long"
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])

View File

@ -18,14 +18,7 @@ from ..utils import cache_on_self, sympy_product
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer
from .cpp_utils import (
cexpr,
CppPrinter,
DEVICE_TO_ATEN,
DTYPE_TO_ATEN,
DTYPE_TO_CPP,
LAYOUT_TO_ATEN,
)
from .cpp_utils import cexpr, CppPrinter, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen
@ -63,7 +56,6 @@ class CppWrapperCpu(WrapperCodeGen):
self.arg_var_id = count()
self.used_cached_devices = set()
self.used_cached_dtypes = set()
self.used_cached_layouts = set()
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False
@ -730,14 +722,9 @@ class CppWrapperCpu(WrapperCodeGen):
self.prefix.writeline(
f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
)
if tensor.is_mkldnn:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};"
)
else:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
)
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
)
from_folded = "true" if name in V.graph.folded_constants else "false"
self.prefix.writeline(
f"constants_info_[{idx}].from_folded = {from_folded};"
@ -750,23 +737,6 @@ class CppWrapperCpu(WrapperCodeGen):
self.prefix.writeline(
f"constants_info_[{idx}].stride = {{{stride_str}}};"
)
self.prefix.writeline(
f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
)
if tensor.is_mkldnn:
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
tensor
)
assert (
opaque_metadata_tensor.dim() == 1
), "Expect opaque_metadata_tensor to be 1-D"
opaque_metadata_list = opaque_metadata_tensor.tolist()
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
self.prefix.writeline(
f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
)
if name in V.graph.dynamo_flat_name_to_original_fqn:
original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
name, name
@ -907,8 +877,6 @@ class CppWrapperCpu(WrapperCodeGen):
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
for device in self.used_cached_devices:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
for layout in self.used_cached_layouts:
cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
cached_dtypes_buffer.splice(self.prefix)
self.prefix = cached_dtypes_buffer
@ -1525,14 +1493,6 @@ class CppWrapperCpu(WrapperCodeGen):
else:
return DTYPE_TO_ATEN[dtype]
def codegen_layout(self, layout):
if config.abi_compatible:
layout_str = str(layout).split(".")[-1]
self.used_cached_layouts.add(layout_str)
return f"cached_torch_layout_{layout_str}"
else:
return LAYOUT_TO_ATEN[layout]
@functools.lru_cache(None)
def codegen_int_array_var(
self,

View File

@ -18,7 +18,7 @@ from ..pattern_matcher import (
KeywordArg,
MULTIPLE,
)
from ..virtualized import ops, V
from ..virtualized import ops
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
from .quantization import (
@ -1146,18 +1146,9 @@ if torch._C._has_mkldnn:
if has_free_symbols(batch_size)
else batch_size,
)
# MKL packed matrix can't be copied to a different address because the internal implementation
# depends on the alignment of internally-stored metadata.
# In aot mode, we need to firstly save the packed weight, when loading it,
# it will be in a different address which doesn't work.
# Disable MKL prepack linear in AOT mode
packed_weight_op = (
mkldnn._reorder_linear_weight
if (
is_lp_weight
or mkldnn._is_mkldnn_acl_supported()
or V.aot_compilation is True
)
if (is_lp_weight or mkldnn._is_mkldnn_acl_supported())
else torch.ops.mkl._mkl_reorder_linear_weight
)
packed_weight_node = graph.create_node(
@ -1165,11 +1156,7 @@ if torch._C._has_mkldnn:
)
packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
if (
is_lp_weight
or mkldnn._is_mkldnn_acl_supported()
or V.aot_compilation is True
):
if is_lp_weight or mkldnn._is_mkldnn_acl_supported():
packed_linear_inputs += (bias, "none", [], "")
packed_linear_op = mkldnn._linear_pointwise.default
else:

View File

@ -222,17 +222,8 @@ class AOTInductorModelBase {
auto size = this->constant_shape(i);
auto stride = this->constant_stride(i);
auto offset = this->constant_offset(i);
auto layout = this->constant_layout(i);
auto opaque_metadata_ptr = this->opaque_metadata(i);
auto opaque_metadata_size = this->opaque_metadata_size(i);
AtenTensorHandle tensor_handle;
#ifdef AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
// When opaque_metadata_size is not 0, we need to have the
// aoti_torch_create_tensor_from_blob_v2 available
AOTI_RUNTIME_CHECK(
opaque_metadata_size == 0,
"Expect opaque_metadata_size to be 0 when AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 is defined");
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
internal_ptr,
ndim,
@ -243,21 +234,6 @@ class AOTInductorModelBase {
device_type_,
device_idx_,
&tensor_handle));
#else
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2(
internal_ptr,
ndim,
size,
stride,
offset,
dtype,
device_type_,
device_idx_,
&tensor_handle,
layout,
opaque_metadata_ptr,
opaque_metadata_size));
#endif // AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
constants_map_->emplace(std::move(name), tensor_handle);
}
if (constants_map_) {
@ -364,10 +340,6 @@ class AOTInductorModelBase {
return constants_info_.at(idx).dtype;
}
int32_t constant_layout(int64_t idx) const {
return constants_info_.at(idx).layout;
}
size_t constant_offset(int64_t idx) const {
return constants_info_.at(idx).offset;
}
@ -380,14 +352,6 @@ class AOTInductorModelBase {
return constants_info_.at(idx).original_fqn;
}
const uint8_t* opaque_metadata(int64_t idx) const {
return constants_info_.at(idx).opaque_metadata.data();
}
size_t opaque_metadata_size(int64_t idx) {
return constants_info_.at(idx).opaque_metadata.size();
}
bool constant_from_folded(int64_t idx) const {
return constants_info_.at(idx).from_folded;
}
@ -521,9 +485,6 @@ class AOTInductorModelBase {
int32_t dtype;
int64_t offset;
size_t data_size;
int32_t layout;
std::vector<uint8_t> opaque_metadata;
int64_t opaque_metadata_size;
const char* original_fqn = nullptr;
bool from_folded;
};

View File

@ -174,7 +174,4 @@ inline AtenTensorHandle wrap_with_raii_handle_if_needed(
static auto cached_torch_device_type_##device = \
aoti_torch_device_type_##device()
#define CACHE_TORCH_LAYOUT(layout) \
static auto cached_torch_layout_##layout = aoti_torch_layout_##layout()
} // namespace torch::aot_inductor

View File

@ -112,9 +112,6 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn();
// Functions for converting a single-element tensor to a scalar value
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value);
@ -273,20 +270,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret // returns new reference
);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
AtenTensorHandle* ret, // returns new reference
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
AtenTensorHandle weight,
AtenTensorHandle indices,

View File

@ -1,49 +0,0 @@
#include <ATen/Config.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ideep.hpp>
#endif
namespace torch {
namespace aot_inductor {
#if AT_MKLDNN_ENABLED()
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
return reinterpret_cast<void*>(
at::native::data_ptr_from_mkldnn(*mkldnn_tensor));
}
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
return at::native::mkldnn_tensor_from_data_ptr(
data_ptr, dims, dtype, device, opaque_metadata, opaque_metadata_size);
}
#else
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}
#endif
} // namespace aot_inductor
} // namespace torch

View File

@ -1,19 +0,0 @@
#pragma once
#include <ATen/Tensor.h>
namespace torch {
namespace aot_inductor {
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor);
at::Tensor mkldnn_tensor_from_data_ptr(
void* data_ptr,
at::IntArrayRef dims,
at::ScalarType dtype,
at::Device device,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);
} // namespace aot_inductor
} // namespace torch

View File

@ -1,10 +1,8 @@
#include <c10/core/DeviceType.h>
#include <c10/core/GradMode.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
@ -92,14 +90,6 @@ AOTI_TORCH_DTYPE_IMPL(complex64, ComplexFloat)
AOTI_TORCH_DTYPE_IMPL(complex128, ComplexDouble)
#undef AOTI_TORCH_DTYPE_IMPL
int32_t aoti_torch_layout_strided() {
return (int32_t)at::kStrided;
}
int32_t aoti_torch_layout__mkldnn() {
return (int32_t)at::kMkldnn;
}
#define AOTI_TORCH_ITEM_IMPL(dtype, ctype) \
AOTITorchError aoti_torch_item_##dtype( \
AtenTensorHandle tensor, ctype* ret_value) { \
@ -164,11 +154,7 @@ AOTITorchError aoti_torch_get_data_ptr(
void** ret_data_ptr) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
if (t->is_mkldnn()) {
*ret_data_ptr = data_ptr_from_mkldnn(t);
} else {
*ret_data_ptr = t->data_ptr();
}
*ret_data_ptr = t->data_ptr();
});
}
@ -339,48 +325,6 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
});
}
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
AtenTensorHandle* ret_new_tensor,
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
if (layout == static_cast<int32_t>(at::kMkldnn)) {
c10::IntArrayRef sizes(sizes_ptr, ndim);
c10::IntArrayRef strides(strides_ptr, ndim);
c10::Device device = c10_device(device_type, device_index);
// get a mkldnn tensor wrapped by a torch Tensor(OpaqueTensorImpl),
// which used by later mkldnn op.
*ret_new_tensor = new_tensor_handle(mkldnn_tensor_from_data_ptr(
data,
sizes,
static_cast<c10::ScalarType>(dtype),
device,
opaque_metadata,
opaque_metadata_size));
} else {
aoti_torch_create_tensor_from_blob(
data,
ndim,
sizes_ptr,
strides_ptr,
storage_offset,
dtype,
device_type,
device_index,
ret_new_tensor);
}
});
}
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
AtenTensorHandle weight,
AtenTensorHandle indices,