mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] support freezing for MKLDNN (#124350)
## Description
Fixes https://github.com/pytorch/pytorch/issues/114450. This PR builds upon the work from @imzhuhl done in https://github.com/pytorch/pytorch/pull/114451.
This PR requires https://github.com/pytorch/pytorch/pull/122472 to land firstly.
We leverage the serialization and deserialization API from oneDNN v3.4.1 to save the opaque MKLDNN tensor during the compilation and restore the opaque tensor when loading the compiled .so.
ideep version is updated so that we won't break any pipeline even if third_party/ideep is not updated at the same time.
### Test plan:
```sh
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_conv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_deconv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_linear_freezing_non_abi_compatible_cpu
```
### TODOs in follow-up PRs
1. We found that using `AOTI_TORCH_CHECK` will cause performance drop on several models (`DistillGPT2`, `MBartForConditionalGeneration`, `T5ForConditionalGeneration`, `T5Small`) compared with JIT Inductor which uses `TORCH_CHECK`. This may need further discussion how to address (`AOTI_TORCH_CHECK` is introduced in
https://github.com/pytorch/pytorch/pull/119220).
2. Freezing in non-ABI compatible mode will work with the support in this PR. While for ABI compatible mode, we need to firstly address this issue: `AssertionError: None, i.e. optional output is not supported`.
6c4f43f826/torch/_inductor/codegen/cpp_wrapper_cpu.py (L2023-L2024)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124350
Approved by: https://github.com/jgong5, https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
43baabe9b9
commit
654afb6f3a
@ -1,6 +1,7 @@
|
||||
#include <ATen/native/mkldnn/MKLDNNCommon.h>
|
||||
#include <ATen/OpaqueTensorImpl.h>
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
|
||||
@ -61,6 +62,33 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
|
||||
}
|
||||
}
|
||||
|
||||
int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) {
|
||||
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
|
||||
void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle();
|
||||
return reinterpret_cast<int64_t>(data_ptr);
|
||||
}
|
||||
|
||||
at::Tensor mkldnn_tensor_from_data_ptr(
|
||||
void* data_ptr,
|
||||
at::IntArrayRef dims,
|
||||
at::ScalarType dtype,
|
||||
at::Device device,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size) {
|
||||
std::vector<uint8_t> vector_serialized_md{
|
||||
opaque_metadata, opaque_metadata + opaque_metadata_size};
|
||||
ideep::tensor::desc deserialized_ideep_desc;
|
||||
#if IDEEP_PREREQ(3, 4, 1, 2)
|
||||
// groups is needed for grouped conv
|
||||
deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md);
|
||||
#else
|
||||
TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization.");
|
||||
#endif
|
||||
|
||||
auto a = ideep::tensor(deserialized_ideep_desc, data_ptr);
|
||||
return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device);
|
||||
}
|
||||
|
||||
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device) {
|
||||
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
|
||||
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
|
||||
@ -81,6 +109,11 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
|
||||
return mklimpl->unsafe_opaque_handle()->get_target();
|
||||
}
|
||||
|
||||
int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) {
|
||||
ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor);
|
||||
return t.get_desc().get_size();
|
||||
}
|
||||
|
||||
ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) {
|
||||
TORCH_CHECK(
|
||||
tensor.device().is_cpu(),
|
||||
@ -167,6 +200,15 @@ int set_verbose(int level) {
|
||||
return ideep::utils::set_verbose(level);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("mkldnn::data_ptr"),
|
||||
TORCH_FN(data_ptr_from_mkldnn));
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("mkldnn::_nbytes"),
|
||||
TORCH_FN(nbytes_from_mkldnn));
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
#endif // AT_MKLDNN_ENABLED()
|
||||
|
@ -28,12 +28,24 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
|
||||
return get_mkldnn_dtype(t.scalar_type());
|
||||
}
|
||||
|
||||
TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);
|
||||
|
||||
TORCH_API at::Tensor mkldnn_tensor_from_data_ptr(
|
||||
void* data_ptr,
|
||||
at::IntArrayRef dims,
|
||||
at::ScalarType dtype,
|
||||
at::Device device,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size);
|
||||
|
||||
// Construct aten MKL-DNN tensor given an ideep tensor
|
||||
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device);
|
||||
|
||||
// Retrieve `ideep::tensor` from MKL-DNN tensor
|
||||
TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
|
||||
|
||||
TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor);
|
||||
|
||||
// Construct an `ideep::tensor` "view" from dense tensor, note the
|
||||
// ideep::tensor will share the underlying buffer
|
||||
TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false);
|
||||
|
@ -12,7 +12,9 @@
|
||||
#else
|
||||
#include <ATen/ops/_to_dense_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#include <ATen/ops/empty_native.h>
|
||||
#include <ATen/ops/from_blob.h>
|
||||
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
|
||||
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
|
||||
#include <ATen/ops/to_mkldnn_native.h>
|
||||
@ -508,6 +510,25 @@ static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
|
||||
return {packed_w1, packed_w2};
|
||||
}
|
||||
|
||||
static Tensor get_mkldnn_serialized_md(const Tensor& self) {
|
||||
const ideep::tensor packed_w = itensor_from_tensor(self);
|
||||
auto packed_w_desc = packed_w.get_desc();
|
||||
std::vector<uint8_t> serialized_wei_desc;
|
||||
|
||||
#if IDEEP_PREREQ(3, 4, 1, 2)
|
||||
serialized_wei_desc = packed_w_desc.get_blob();
|
||||
#else
|
||||
TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
|
||||
#endif
|
||||
Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
|
||||
auto res = at::empty_like(serialized_md);
|
||||
// serialized_md shares the buffer with serialized_wei_desc,
|
||||
// which will be released outside of this function thus invalidating the buffer of serialized_md.
|
||||
// A copy is needed here so that res has its own buffer, which remains valid even after serialized_wei_desc is released.
|
||||
res.copy_(serialized_md);
|
||||
return res;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
|
||||
@ -523,6 +544,12 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
|
||||
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
|
||||
TORCH_FN(get_mkldnn_serialized_md ));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
|
||||
|
@ -74,6 +74,9 @@ TORCH_LIBRARY(mkldnn, m) {
|
||||
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
|
||||
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
|
||||
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
|
||||
m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
|
||||
m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
|
||||
m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(mkldnn_prepacked, m) {
|
||||
|
@ -471,6 +471,7 @@ inductor_core_resources = [
|
||||
"torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
|
||||
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
|
||||
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
|
||||
"torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
|
||||
"torch/csrc/inductor/inductor_ops.cpp",
|
||||
]
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
@ -89,6 +90,8 @@ def check_model(
|
||||
options=None,
|
||||
dynamic_shapes=None,
|
||||
disable_constraint_solver=False,
|
||||
atol=None,
|
||||
rtol=None,
|
||||
):
|
||||
with torch.no_grad(), config.patch(
|
||||
{
|
||||
@ -114,7 +117,7 @@ def check_model(
|
||||
disable_constraint_solver,
|
||||
)
|
||||
|
||||
self.assertTrue(same(actual, expected))
|
||||
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def check_model_with_multiple_inputs(
|
||||
@ -312,6 +315,10 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_FBCODE,
|
||||
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
||||
)
|
||||
def test_freezing(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
@ -331,6 +338,80 @@ class AOTInductorTestsTemplate:
|
||||
with config.patch({"freezing": True}):
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_FBCODE,
|
||||
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
||||
)
|
||||
def test_conv_freezing(self):
|
||||
for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]):
|
||||
iC = 2
|
||||
oC = 3
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
|
||||
dtype
|
||||
)
|
||||
|
||||
def forward(self, y):
|
||||
return torch.nn.functional.conv2d(y, self.weight, groups=groups)
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
|
||||
)
|
||||
|
||||
with config.patch({"freezing": True}):
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_FBCODE,
|
||||
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
||||
)
|
||||
def test_deconv_freezing(self):
|
||||
dtypes = [torch.float]
|
||||
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
|
||||
dtypes.append(torch.bfloat16)
|
||||
for dtype, groups in itertools.product(dtypes, [2, 1]):
|
||||
iC = 4
|
||||
oC = 2
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
|
||||
dtype
|
||||
)
|
||||
|
||||
def forward(self, y):
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
y, self.weight, groups=groups
|
||||
)
|
||||
|
||||
example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
|
||||
with config.patch({"freezing": True}):
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_FBCODE,
|
||||
"Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
|
||||
)
|
||||
def test_linear_freezing(self):
|
||||
for dtype in [torch.float32, torch.bfloat16]:
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.weight = torch.randn(10, 10, device=device).to(dtype)
|
||||
|
||||
def forward(self, y):
|
||||
return torch.nn.functional.linear(y, self.weight)
|
||||
|
||||
example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)
|
||||
|
||||
with config.patch({"freezing": True}):
|
||||
self.check_model(LinearModel(self.device), example_inputs)
|
||||
|
||||
@torch._inductor.config.patch(
|
||||
pre_grad_fusion_options={
|
||||
"normalization_pass": {},
|
||||
@ -1390,7 +1471,9 @@ class AOTInductorTestsTemplate:
|
||||
torch.randn(87, 87, device=self.device),
|
||||
torch.randn(87, 87, device=self.device),
|
||||
)
|
||||
self.check_model(Model(), example_inputs)
|
||||
self.check_model(
|
||||
Model(), example_inputs, atol=1e-4, rtol=1e-4
|
||||
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
|
||||
|
||||
if self.device == "cuda":
|
||||
so_path = torch._export.aot_compile(Model(), example_inputs)
|
||||
@ -2872,6 +2955,12 @@ def fail_non_abi_compatible_cuda(is_skip=False):
|
||||
# test_failures, xfail by default, set is_skip=True to skip
|
||||
CPU_TEST_FAILURES = {
|
||||
"test_add_complex": fail_stack_allocation(is_skip=True),
|
||||
# TODO: test_conv_freezing_abi_compatible_cpu fails,
|
||||
# AssertionError: None, i.e. optional output is not supported
|
||||
"test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
# TODO: test_deconv_freezing_abi_compatible_cpu fails,
|
||||
# AssertionError: None, i.e. optional output is not supported
|
||||
"test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
# FIXME: failed with Segfault while exiting the Python runtime
|
||||
"test_duplicate_constant_folding": fail_with_and_without_stack_allocation(
|
||||
is_skip=True
|
||||
@ -2885,9 +2974,12 @@ CPU_TEST_FAILURES = {
|
||||
"test_dynamic_scalar": fail_stack_allocation(is_skip=True),
|
||||
# https://github.com/pytorch/pytorch/issues/122980
|
||||
"test_fft_c2c": fail_stack_allocation(is_skip=True),
|
||||
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
|
||||
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
|
||||
# TODO: test_freezing_abi_compatible_cpu fails,
|
||||
# AssertionError: None, i.e. optional output is not supported
|
||||
"test_freezing": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
# TODO: test_linear_freezing_abi_compatible_cpu fails,
|
||||
# AssertionError: None, i.e. optional output is not supported
|
||||
"test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
# FIXME: failed with Segfault while exiting the Python runtime
|
||||
"test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True),
|
||||
# minimal arrayref interface only works with CPU; test crashes.
|
||||
@ -3129,9 +3221,6 @@ copy_tests(
|
||||
"test_duplicate_constant_folding": TestFailure(
|
||||
("non_abi_compatible_cpu",), is_skip=True
|
||||
),
|
||||
# TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally,
|
||||
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
|
||||
"test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
|
||||
# no runtime checks for non_abi_compatible mode
|
||||
"test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
|
||||
"test_runtime_checks_dtype_failed": TestFailure(
|
||||
|
@ -1522,6 +1522,10 @@ def use_custom_generated_macros() -> str:
|
||||
|
||||
def use_fb_internal_macros() -> str:
|
||||
if config.is_fbcode():
|
||||
# TODO: this is to avoid FC breakage for fbcode. When using newly
|
||||
# generated model.so on an older verion of PyTorch, need to use
|
||||
# the v1 version for aoti_torch_create_tensor_from_blob
|
||||
create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
|
||||
openmp_lib = build_paths.openmp_lib()
|
||||
preprocessor_flags = " ".join(
|
||||
(
|
||||
@ -1530,7 +1534,7 @@ def use_fb_internal_macros() -> str:
|
||||
"-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
|
||||
)
|
||||
)
|
||||
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
|
||||
return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@ -2076,7 +2080,9 @@ class AotCodeCompiler:
|
||||
|
||||
output_o = os.path.splitext(input_path)[0] + ".o"
|
||||
consts_size = sum(
|
||||
tensor.untyped_storage().nbytes()
|
||||
torch.ops.mkldnn._nbytes(tensor)
|
||||
if tensor.is_mkldnn
|
||||
else tensor.untyped_storage().nbytes()
|
||||
for (name, tensor) in graph.constants.items()
|
||||
if name not in graph.folded_constants
|
||||
)
|
||||
@ -2109,6 +2115,13 @@ class AotCodeCompiler:
|
||||
if t.numel() == 0:
|
||||
return b""
|
||||
|
||||
if t.is_mkldnn:
|
||||
raw_array = ctypes.cast(
|
||||
torch.ops.mkldnn.data_ptr(t),
|
||||
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)),
|
||||
)
|
||||
return bytes(raw_array.contents)
|
||||
|
||||
t_cpu = t.untyped_storage().cpu()
|
||||
raw_array = ctypes.cast(
|
||||
t_cpu.data_ptr(),
|
||||
|
@ -1971,6 +1971,8 @@ class CppKernel(Kernel):
|
||||
@property
|
||||
def assert_function(self) -> str:
|
||||
if V.graph.aot_mode:
|
||||
# TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
|
||||
# compared with JIT Inductor which uses TORCH_CHECK
|
||||
return "AOTI_TORCH_CHECK"
|
||||
else:
|
||||
return "TORCH_CHECK"
|
||||
|
@ -64,6 +64,11 @@ DEVICE_TO_ATEN = {
|
||||
"cuda": "at::kCUDA",
|
||||
}
|
||||
|
||||
LAYOUT_TO_ATEN = {
|
||||
torch.strided: "at::kStrided",
|
||||
torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined]
|
||||
}
|
||||
|
||||
INDEX_TYPE = "long"
|
||||
|
||||
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
|
||||
|
@ -18,7 +18,14 @@ from ..utils import cache_on_self, sympy_product
|
||||
from ..virtualized import V
|
||||
from .aoti_hipify_utils import maybe_hipify_code_wrapper
|
||||
from .common import IndentedBuffer
|
||||
from .cpp_utils import cexpr, CppPrinter, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
|
||||
from .cpp_utils import (
|
||||
cexpr,
|
||||
CppPrinter,
|
||||
DEVICE_TO_ATEN,
|
||||
DTYPE_TO_ATEN,
|
||||
DTYPE_TO_CPP,
|
||||
LAYOUT_TO_ATEN,
|
||||
)
|
||||
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen
|
||||
|
||||
|
||||
@ -56,6 +63,7 @@ class CppWrapperCpu(WrapperCodeGen):
|
||||
self.arg_var_id = count()
|
||||
self.used_cached_devices = set()
|
||||
self.used_cached_dtypes = set()
|
||||
self.used_cached_layouts = set()
|
||||
self.cached_output_id = count()
|
||||
self.scalar_to_tensor_id = count()
|
||||
self.custom_op_wrapper_loaded = False
|
||||
@ -722,6 +730,11 @@ class CppWrapperCpu(WrapperCodeGen):
|
||||
self.prefix.writeline(
|
||||
f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
|
||||
)
|
||||
if tensor.is_mkldnn:
|
||||
self.prefix.writeline(
|
||||
f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};"
|
||||
)
|
||||
else:
|
||||
self.prefix.writeline(
|
||||
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
|
||||
)
|
||||
@ -737,6 +750,23 @@ class CppWrapperCpu(WrapperCodeGen):
|
||||
self.prefix.writeline(
|
||||
f"constants_info_[{idx}].stride = {{{stride_str}}};"
|
||||
)
|
||||
self.prefix.writeline(
|
||||
f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
|
||||
)
|
||||
|
||||
if tensor.is_mkldnn:
|
||||
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
|
||||
tensor
|
||||
)
|
||||
assert (
|
||||
opaque_metadata_tensor.dim() == 1
|
||||
), "Expect opaque_metadata_tensor to be 1-D"
|
||||
|
||||
opaque_metadata_list = opaque_metadata_tensor.tolist()
|
||||
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
|
||||
self.prefix.writeline(
|
||||
f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
|
||||
)
|
||||
if name in V.graph.dynamo_flat_name_to_original_fqn:
|
||||
original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
|
||||
name, name
|
||||
@ -877,6 +907,8 @@ class CppWrapperCpu(WrapperCodeGen):
|
||||
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
|
||||
for device in self.used_cached_devices:
|
||||
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
|
||||
for layout in self.used_cached_layouts:
|
||||
cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
|
||||
cached_dtypes_buffer.splice(self.prefix)
|
||||
self.prefix = cached_dtypes_buffer
|
||||
|
||||
@ -1493,6 +1525,14 @@ class CppWrapperCpu(WrapperCodeGen):
|
||||
else:
|
||||
return DTYPE_TO_ATEN[dtype]
|
||||
|
||||
def codegen_layout(self, layout):
|
||||
if config.abi_compatible:
|
||||
layout_str = str(layout).split(".")[-1]
|
||||
self.used_cached_layouts.add(layout_str)
|
||||
return f"cached_torch_layout_{layout_str}"
|
||||
else:
|
||||
return LAYOUT_TO_ATEN[layout]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def codegen_int_array_var(
|
||||
self,
|
||||
|
@ -18,7 +18,7 @@ from ..pattern_matcher import (
|
||||
KeywordArg,
|
||||
MULTIPLE,
|
||||
)
|
||||
from ..virtualized import ops
|
||||
from ..virtualized import ops, V
|
||||
from .freezing_patterns import register_freezing_graph_pattern
|
||||
from .post_grad import register_lowering_pattern
|
||||
from .quantization import (
|
||||
@ -1146,9 +1146,18 @@ if torch._C._has_mkldnn:
|
||||
if has_free_symbols(batch_size)
|
||||
else batch_size,
|
||||
)
|
||||
# MKL packed matrix can't be copied to a different address because the internal implementation
|
||||
# depends on the alignment of internally-stored metadata.
|
||||
# In aot mode, we need to firstly save the packed weight, when loading it,
|
||||
# it will be in a different address which doesn't work.
|
||||
# Disable MKL prepack linear in AOT mode
|
||||
packed_weight_op = (
|
||||
mkldnn._reorder_linear_weight
|
||||
if (is_lp_weight or mkldnn._is_mkldnn_acl_supported())
|
||||
if (
|
||||
is_lp_weight
|
||||
or mkldnn._is_mkldnn_acl_supported()
|
||||
or V.aot_compilation is True
|
||||
)
|
||||
else torch.ops.mkl._mkl_reorder_linear_weight
|
||||
)
|
||||
packed_weight_node = graph.create_node(
|
||||
@ -1156,7 +1165,11 @@ if torch._C._has_mkldnn:
|
||||
)
|
||||
|
||||
packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
|
||||
if is_lp_weight or mkldnn._is_mkldnn_acl_supported():
|
||||
if (
|
||||
is_lp_weight
|
||||
or mkldnn._is_mkldnn_acl_supported()
|
||||
or V.aot_compilation is True
|
||||
):
|
||||
packed_linear_inputs += (bias, "none", [], "")
|
||||
packed_linear_op = mkldnn._linear_pointwise.default
|
||||
else:
|
||||
|
@ -222,8 +222,17 @@ class AOTInductorModelBase {
|
||||
auto size = this->constant_shape(i);
|
||||
auto stride = this->constant_stride(i);
|
||||
auto offset = this->constant_offset(i);
|
||||
auto layout = this->constant_layout(i);
|
||||
auto opaque_metadata_ptr = this->opaque_metadata(i);
|
||||
auto opaque_metadata_size = this->opaque_metadata_size(i);
|
||||
|
||||
AtenTensorHandle tensor_handle;
|
||||
#ifdef AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
|
||||
// When opaque_metadata_size is not 0, we need to have the
|
||||
// aoti_torch_create_tensor_from_blob_v2 available
|
||||
AOTI_RUNTIME_CHECK(
|
||||
opaque_metadata_size == 0,
|
||||
"Expect opaque_metadata_size to be 0 when AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 is defined");
|
||||
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
|
||||
internal_ptr,
|
||||
ndim,
|
||||
@ -234,6 +243,21 @@ class AOTInductorModelBase {
|
||||
device_type_,
|
||||
device_idx_,
|
||||
&tensor_handle));
|
||||
#else
|
||||
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2(
|
||||
internal_ptr,
|
||||
ndim,
|
||||
size,
|
||||
stride,
|
||||
offset,
|
||||
dtype,
|
||||
device_type_,
|
||||
device_idx_,
|
||||
&tensor_handle,
|
||||
layout,
|
||||
opaque_metadata_ptr,
|
||||
opaque_metadata_size));
|
||||
#endif // AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1
|
||||
constants_map_->emplace(std::move(name), tensor_handle);
|
||||
}
|
||||
if (constants_map_) {
|
||||
@ -340,6 +364,10 @@ class AOTInductorModelBase {
|
||||
return constants_info_.at(idx).dtype;
|
||||
}
|
||||
|
||||
int32_t constant_layout(int64_t idx) const {
|
||||
return constants_info_.at(idx).layout;
|
||||
}
|
||||
|
||||
size_t constant_offset(int64_t idx) const {
|
||||
return constants_info_.at(idx).offset;
|
||||
}
|
||||
@ -352,6 +380,14 @@ class AOTInductorModelBase {
|
||||
return constants_info_.at(idx).original_fqn;
|
||||
}
|
||||
|
||||
const uint8_t* opaque_metadata(int64_t idx) const {
|
||||
return constants_info_.at(idx).opaque_metadata.data();
|
||||
}
|
||||
|
||||
size_t opaque_metadata_size(int64_t idx) {
|
||||
return constants_info_.at(idx).opaque_metadata.size();
|
||||
}
|
||||
|
||||
bool constant_from_folded(int64_t idx) const {
|
||||
return constants_info_.at(idx).from_folded;
|
||||
}
|
||||
@ -485,6 +521,9 @@ class AOTInductorModelBase {
|
||||
int32_t dtype;
|
||||
int64_t offset;
|
||||
size_t data_size;
|
||||
int32_t layout;
|
||||
std::vector<uint8_t> opaque_metadata;
|
||||
int64_t opaque_metadata_size;
|
||||
const char* original_fqn = nullptr;
|
||||
bool from_folded;
|
||||
};
|
||||
|
@ -174,4 +174,7 @@ inline AtenTensorHandle wrap_with_raii_handle_if_needed(
|
||||
static auto cached_torch_device_type_##device = \
|
||||
aoti_torch_device_type_##device()
|
||||
|
||||
#define CACHE_TORCH_LAYOUT(layout) \
|
||||
static auto cached_torch_layout_##layout = aoti_torch_layout_##layout()
|
||||
|
||||
} // namespace torch::aot_inductor
|
||||
|
@ -112,6 +112,9 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();
|
||||
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn();
|
||||
|
||||
// Functions for converting a single-element tensor to a scalar value
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value);
|
||||
@ -270,6 +273,20 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
|
||||
AtenTensorHandle* ret // returns new reference
|
||||
);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
|
||||
void* data,
|
||||
int64_t ndim,
|
||||
const int64_t* sizes_ptr,
|
||||
const int64_t* strides_ptr,
|
||||
int64_t storage_offset,
|
||||
int32_t dtype,
|
||||
int32_t device_type,
|
||||
int32_t device_index,
|
||||
AtenTensorHandle* ret, // returns new reference
|
||||
int32_t layout,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
|
||||
AtenTensorHandle weight,
|
||||
AtenTensorHandle indices,
|
||||
|
49
torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp
Normal file
49
torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
#include <ATen/Config.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
#include <ATen/native/mkldnn/MKLDNNCommon.h>
|
||||
#include <ideep.hpp>
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
namespace aot_inductor {
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
|
||||
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
|
||||
return reinterpret_cast<void*>(
|
||||
at::native::data_ptr_from_mkldnn(*mkldnn_tensor));
|
||||
}
|
||||
|
||||
at::Tensor mkldnn_tensor_from_data_ptr(
|
||||
void* data_ptr,
|
||||
at::IntArrayRef dims,
|
||||
at::ScalarType dtype,
|
||||
at::Device device,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size) {
|
||||
return at::native::mkldnn_tensor_from_data_ptr(
|
||||
data_ptr, dims, dtype, device, opaque_metadata, opaque_metadata_size);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
|
||||
TORCH_CHECK(false, "MKL-DNN build is disabled");
|
||||
}
|
||||
|
||||
at::Tensor mkldnn_tensor_from_data_ptr(
|
||||
void* data_ptr,
|
||||
at::IntArrayRef dims,
|
||||
at::ScalarType dtype,
|
||||
at::Device device,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size) {
|
||||
TORCH_CHECK(false, "MKL-DNN build is disabled");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace aot_inductor
|
||||
} // namespace torch
|
19
torch/csrc/inductor/aoti_torch/mkldnn_tensor.h
Normal file
19
torch/csrc/inductor/aoti_torch/mkldnn_tensor.h
Normal file
@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace torch {
|
||||
namespace aot_inductor {
|
||||
|
||||
void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor);
|
||||
|
||||
at::Tensor mkldnn_tensor_from_data_ptr(
|
||||
void* data_ptr,
|
||||
at::IntArrayRef dims,
|
||||
at::ScalarType dtype,
|
||||
at::Device device,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size);
|
||||
|
||||
} // namespace aot_inductor
|
||||
} // namespace torch
|
@ -1,8 +1,10 @@
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/GradMode.h>
|
||||
#include <c10/core/Layout.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
@ -90,6 +92,14 @@ AOTI_TORCH_DTYPE_IMPL(complex64, ComplexFloat)
|
||||
AOTI_TORCH_DTYPE_IMPL(complex128, ComplexDouble)
|
||||
#undef AOTI_TORCH_DTYPE_IMPL
|
||||
|
||||
int32_t aoti_torch_layout_strided() {
|
||||
return (int32_t)at::kStrided;
|
||||
}
|
||||
|
||||
int32_t aoti_torch_layout__mkldnn() {
|
||||
return (int32_t)at::kMkldnn;
|
||||
}
|
||||
|
||||
#define AOTI_TORCH_ITEM_IMPL(dtype, ctype) \
|
||||
AOTITorchError aoti_torch_item_##dtype( \
|
||||
AtenTensorHandle tensor, ctype* ret_value) { \
|
||||
@ -154,7 +164,11 @@ AOTITorchError aoti_torch_get_data_ptr(
|
||||
void** ret_data_ptr) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
|
||||
if (t->is_mkldnn()) {
|
||||
*ret_data_ptr = data_ptr_from_mkldnn(t);
|
||||
} else {
|
||||
*ret_data_ptr = t->data_ptr();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -325,6 +339,48 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
|
||||
void* data,
|
||||
int64_t ndim,
|
||||
const int64_t* sizes_ptr,
|
||||
const int64_t* strides_ptr,
|
||||
int64_t storage_offset,
|
||||
int32_t dtype,
|
||||
int32_t device_type,
|
||||
int32_t device_index,
|
||||
AtenTensorHandle* ret_new_tensor,
|
||||
int32_t layout,
|
||||
const uint8_t* opaque_metadata,
|
||||
int64_t opaque_metadata_size) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
if (layout == static_cast<int32_t>(at::kMkldnn)) {
|
||||
c10::IntArrayRef sizes(sizes_ptr, ndim);
|
||||
c10::IntArrayRef strides(strides_ptr, ndim);
|
||||
c10::Device device = c10_device(device_type, device_index);
|
||||
// get a mkldnn tensor wrapped by a torch Tensor(OpaqueTensorImpl),
|
||||
// which used by later mkldnn op.
|
||||
*ret_new_tensor = new_tensor_handle(mkldnn_tensor_from_data_ptr(
|
||||
data,
|
||||
sizes,
|
||||
static_cast<c10::ScalarType>(dtype),
|
||||
device,
|
||||
opaque_metadata,
|
||||
opaque_metadata_size));
|
||||
} else {
|
||||
aoti_torch_create_tensor_from_blob(
|
||||
data,
|
||||
ndim,
|
||||
sizes_ptr,
|
||||
strides_ptr,
|
||||
storage_offset,
|
||||
dtype,
|
||||
device_type,
|
||||
device_index,
|
||||
ret_new_tensor);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
|
||||
AtenTensorHandle weight,
|
||||
AtenTensorHandle indices,
|
||||
|
Reference in New Issue
Block a user