From 654afb6f3ae3ddbd926a753f9af95a6f6e22131c Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 24 May 2024 08:01:27 +0000 Subject: [PATCH] [AOTI] support freezing for MKLDNN (#124350) ## Description Fixes https://github.com/pytorch/pytorch/issues/114450. This PR builds upon the work from @imzhuhl done in https://github.com/pytorch/pytorch/pull/114451. This PR requires https://github.com/pytorch/pytorch/pull/122472 to land firstly. We leverage the serialization and deserialization API from oneDNN v3.4.1 to save the opaque MKLDNN tensor during the compilation and restore the opaque tensor when loading the compiled .so. ideep version is updated so that we won't break any pipeline even if third_party/ideep is not updated at the same time. ### Test plan: ```sh python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_freezing_non_abi_compatible_cpu python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_conv_freezing_non_abi_compatible_cpu python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_deconv_freezing_non_abi_compatible_cpu python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_linear_freezing_non_abi_compatible_cpu ``` ### TODOs in follow-up PRs 1. We found that using `AOTI_TORCH_CHECK` will cause performance drop on several models (`DistillGPT2`, `MBartForConditionalGeneration`, `T5ForConditionalGeneration`, `T5Small`) compared with JIT Inductor which uses `TORCH_CHECK`. This may need further discussion how to address (`AOTI_TORCH_CHECK` is introduced in https://github.com/pytorch/pytorch/pull/119220). 2. Freezing in non-ABI compatible mode will work with the support in this PR. While for ABI compatible mode, we need to firstly address this issue: `AssertionError: None, i.e. optional output is not supported`. https://github.com/pytorch/pytorch/blob/6c4f43f82675b5fcfe8cf3e5983d0c0f326408aa/torch/_inductor/codegen/cpp_wrapper_cpu.py#L2023-L2024 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124350 Approved by: https://github.com/jgong5, https://github.com/desertfire --- aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp | 42 +++++++ aten/src/ATen/native/mkldnn/MKLDNNCommon.h | 12 ++ .../ATen/native/mkldnn/MKLDNNConversions.cpp | 27 +++++ .../mkldnn/RegisterMkldnnOpContextClass.cpp | 3 + build_variables.bzl | 1 + test/inductor/test_aot_inductor.py | 103 ++++++++++++++++-- torch/_inductor/codecache.py | 17 ++- torch/_inductor/codegen/cpp.py | 2 + torch/_inductor/codegen/cpp_utils.py | 5 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 48 +++++++- torch/_inductor/fx_passes/mkldnn_fusion.py | 19 +++- torch/csrc/inductor/aoti_runtime/model.h | 39 +++++++ torch/csrc/inductor/aoti_runtime/utils.h | 3 + torch/csrc/inductor/aoti_torch/c/shim.h | 17 +++ .../inductor/aoti_torch/mkldnn_tensor.cpp | 49 +++++++++ .../csrc/inductor/aoti_torch/mkldnn_tensor.h | 19 ++++ .../csrc/inductor/aoti_torch/shim_common.cpp | 58 +++++++++- 17 files changed, 447 insertions(+), 17 deletions(-) create mode 100644 torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp create mode 100644 torch/csrc/inductor/aoti_torch/mkldnn_tensor.h diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index b0bb5c3fcebc..954a5289cdfe 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #if AT_MKLDNN_ENABLED() @@ -61,6 +62,33 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) { } } +int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor) { + MKLDNNTensorImpl *mklimpl = static_cast(mkldnn_tensor.unsafeGetTensorImpl()); + void* data_ptr = mklimpl->unsafe_opaque_handle()->get_target().get_data_handle(); + return reinterpret_cast(data_ptr); +} + +at::Tensor mkldnn_tensor_from_data_ptr( + void* data_ptr, + at::IntArrayRef dims, + at::ScalarType dtype, + at::Device device, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + std::vector vector_serialized_md{ + opaque_metadata, opaque_metadata + opaque_metadata_size}; + ideep::tensor::desc deserialized_ideep_desc; +#if IDEEP_PREREQ(3, 4, 1, 2) + // groups is needed for grouped conv + deserialized_ideep_desc = ideep::tensor::desc(vector_serialized_md); +#else + TORCH_CHECK(false, "Unexpected IDeep version to do weight deserialization."); +#endif + + auto a = ideep::tensor(deserialized_ideep_desc, data_ptr); + return at::native::new_with_itensor_mkldnn(std::move(a), dtype, device); +} + Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional dtype, std::optional device) { // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t // TODO: support int64_t dims in ideep::tensor to avoid extra conversion @@ -81,6 +109,11 @@ ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) { return mklimpl->unsafe_opaque_handle()->get_target(); } +int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor) { + ideep::tensor t = itensor_from_mkldnn(mkldnn_tensor); + return t.get_desc().get_size(); +} + ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr) { TORCH_CHECK( tensor.device().is_cpu(), @@ -167,6 +200,15 @@ int set_verbose(int level) { return ideep::utils::set_verbose(level); } +TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::data_ptr"), + TORCH_FN(data_ptr_from_mkldnn)); + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_nbytes"), + TORCH_FN(nbytes_from_mkldnn)); +} + }} #endif // AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h index 4009a144c766..cc5739825d7e 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h @@ -28,12 +28,24 @@ static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) { return get_mkldnn_dtype(t.scalar_type()); } +TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor); + +TORCH_API at::Tensor mkldnn_tensor_from_data_ptr( + void* data_ptr, + at::IntArrayRef dims, + at::ScalarType dtype, + at::Device device, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + // Construct aten MKL-DNN tensor given an ideep tensor TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional dtype, std::optional device); // Retrieve `ideep::tensor` from MKL-DNN tensor TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor); +TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor); + // Construct an `ideep::tensor` "view" from dense tensor, note the // ideep::tensor will share the underlying buffer TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false); diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp index dd0ccb66ff1d..5478b1e91e98 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp @@ -12,7 +12,9 @@ #else #include #include +#include #include +#include #include #include #include @@ -508,6 +510,25 @@ static std::vector mkldnn_reorder_mkldnn_rnn_layer_weight( return {packed_w1, packed_w2}; } +static Tensor get_mkldnn_serialized_md(const Tensor& self) { + const ideep::tensor packed_w = itensor_from_tensor(self); + auto packed_w_desc = packed_w.get_desc(); + std::vector serialized_wei_desc; + +#if IDEEP_PREREQ(3, 4, 1, 2) + serialized_wei_desc = packed_w_desc.get_blob(); +#else + TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization."); +#endif + Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte)); + auto res = at::empty_like(serialized_md); + // serialized_md shares the buffer with serialized_wei_desc, + // which will be released outside of this function thus invalidating the buffer of serialized_md. + // A copy is needed here so that res has its own buffer, which remains valid even after serialized_wei_desc is released. + res.copy_(serialized_md); + return res; +} + TORCH_LIBRARY_IMPL(mkldnn, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"), @@ -523,6 +544,12 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) { TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight)); } +TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"), + TORCH_FN(get_mkldnn_serialized_md )); +} + #else Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional dtype, std::optional masked_grad) { diff --git a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp index 6ca39632818a..b8dc4ecd9ce7 100644 --- a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp +++ b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp @@ -74,6 +74,9 @@ TORCH_LIBRARY(mkldnn, m) { m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported); m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported); m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported); + m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int"); + m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor"); + m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int"); } TORCH_LIBRARY(mkldnn_prepacked, m) { diff --git a/build_variables.bzl b/build_variables.bzl index 152324a4d90c..ccd09b8fea93 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -471,6 +471,7 @@ inductor_core_resources = [ "torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp", "torch/csrc/inductor/aoti_torch/shim_common.cpp", "torch/csrc/inductor/aoti_torch/tensor_converter.cpp", + "torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp", "torch/csrc/inductor/inductor_ops.cpp", ] diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 8e20506139c5..193ca06a1d8d 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import copy +import itertools import os import sys import tempfile @@ -89,6 +90,8 @@ def check_model( options=None, dynamic_shapes=None, disable_constraint_solver=False, + atol=None, + rtol=None, ): with torch.no_grad(), config.patch( { @@ -114,7 +117,7 @@ def check_model( disable_constraint_solver, ) - self.assertTrue(same(actual, expected)) + self.assertEqual(actual, expected, atol=atol, rtol=rtol) def check_model_with_multiple_inputs( @@ -312,6 +315,10 @@ class AOTInductorTestsTemplate: ) self.check_model(Model(self.device), example_inputs) + @unittest.skipIf( + IS_FBCODE, + "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", + ) def test_freezing(self): class Model(torch.nn.Module): def __init__(self, device): @@ -331,6 +338,80 @@ class AOTInductorTestsTemplate: with config.patch({"freezing": True}): self.check_model(Model(self.device), example_inputs) + @unittest.skipIf( + IS_FBCODE, + "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", + ) + def test_conv_freezing(self): + for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]): + iC = 2 + oC = 3 + + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to( + dtype + ) + + def forward(self, y): + return torch.nn.functional.conv2d(y, self.weight, groups=groups) + + example_inputs = ( + torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype), + ) + + with config.patch({"freezing": True}): + self.check_model(Model(self.device), example_inputs) + + @unittest.skipIf( + IS_FBCODE, + "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", + ) + def test_deconv_freezing(self): + dtypes = [torch.float] + if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + dtypes.append(torch.bfloat16) + for dtype, groups in itertools.product(dtypes, [2, 1]): + iC = 4 + oC = 2 + + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to( + dtype + ) + + def forward(self, y): + return torch.nn.functional.conv_transpose2d( + y, self.weight, groups=groups + ) + + example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),) + with config.patch({"freezing": True}): + self.check_model(Model(self.device), example_inputs) + + @unittest.skipIf( + IS_FBCODE, + "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", + ) + def test_linear_freezing(self): + for dtype in [torch.float32, torch.bfloat16]: + + class LinearModel(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn(10, 10, device=device).to(dtype) + + def forward(self, y): + return torch.nn.functional.linear(y, self.weight) + + example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),) + + with config.patch({"freezing": True}): + self.check_model(LinearModel(self.device), example_inputs) + @torch._inductor.config.patch( pre_grad_fusion_options={ "normalization_pass": {}, @@ -1390,7 +1471,9 @@ class AOTInductorTestsTemplate: torch.randn(87, 87, device=self.device), torch.randn(87, 87, device=self.device), ) - self.check_model(Model(), example_inputs) + self.check_model( + Model(), example_inputs, atol=1e-4, rtol=1e-4 + ) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py if self.device == "cuda": so_path = torch._export.aot_compile(Model(), example_inputs) @@ -2872,6 +2955,12 @@ def fail_non_abi_compatible_cuda(is_skip=False): # test_failures, xfail by default, set is_skip=True to skip CPU_TEST_FAILURES = { "test_add_complex": fail_stack_allocation(is_skip=True), + # TODO: test_conv_freezing_abi_compatible_cpu fails, + # AssertionError: None, i.e. optional output is not supported + "test_conv_freezing": fail_with_and_without_stack_allocation(is_skip=True), + # TODO: test_deconv_freezing_abi_compatible_cpu fails, + # AssertionError: None, i.e. optional output is not supported + "test_deconv_freezing": fail_with_and_without_stack_allocation(is_skip=True), # FIXME: failed with Segfault while exiting the Python runtime "test_duplicate_constant_folding": fail_with_and_without_stack_allocation( is_skip=True @@ -2885,9 +2974,12 @@ CPU_TEST_FAILURES = { "test_dynamic_scalar": fail_stack_allocation(is_skip=True), # https://github.com/pytorch/pytorch/issues/122980 "test_fft_c2c": fail_stack_allocation(is_skip=True), - # TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally, - # NotImplementedError: Cannot access storage of OpaqueTensorImpl + # TODO: test_freezing_abi_compatible_cpu fails, + # AssertionError: None, i.e. optional output is not supported "test_freezing": fail_with_and_without_stack_allocation(is_skip=True), + # TODO: test_linear_freezing_abi_compatible_cpu fails, + # AssertionError: None, i.e. optional output is not supported + "test_linear_freezing": fail_with_and_without_stack_allocation(is_skip=True), # FIXME: failed with Segfault while exiting the Python runtime "test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True), # minimal arrayref interface only works with CPU; test crashes. @@ -3129,9 +3221,6 @@ copy_tests( "test_duplicate_constant_folding": TestFailure( ("non_abi_compatible_cpu",), is_skip=True ), - # TODO: test_freezing_non_abi_compatible_cpu somehow fails on CI but not locally, - # NotImplementedError: Cannot access storage of OpaqueTensorImpl - "test_freezing": TestFailure(("non_abi_compatible_cpu",), is_skip=True), # no runtime checks for non_abi_compatible mode "test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True), "test_runtime_checks_dtype_failed": TestFailure( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 810ffa40255a..3c7e9f741d5a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1522,6 +1522,10 @@ def use_custom_generated_macros() -> str: def use_fb_internal_macros() -> str: if config.is_fbcode(): + # TODO: this is to avoid FC breakage for fbcode. When using newly + # generated model.so on an older verion of PyTorch, need to use + # the v1 version for aoti_torch_create_tensor_from_blob + create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1" openmp_lib = build_paths.openmp_lib() preprocessor_flags = " ".join( ( @@ -1530,7 +1534,7 @@ def use_fb_internal_macros() -> str: "-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY", ) ) - return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}" + return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}" else: return "" @@ -2076,7 +2080,9 @@ class AotCodeCompiler: output_o = os.path.splitext(input_path)[0] + ".o" consts_size = sum( - tensor.untyped_storage().nbytes() + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() for (name, tensor) in graph.constants.items() if name not in graph.folded_constants ) @@ -2109,6 +2115,13 @@ class AotCodeCompiler: if t.numel() == 0: return b"" + if t.is_mkldnn: + raw_array = ctypes.cast( + torch.ops.mkldnn.data_ptr(t), + ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)), + ) + return bytes(raw_array.contents) + t_cpu = t.untyped_storage().cpu() raw_array = ctypes.cast( t_cpu.data_ptr(), diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 52f92bd0becb..9fcd952e37b6 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1971,6 +1971,8 @@ class CppKernel(Kernel): @property def assert_function(self) -> str: if V.graph.aot_mode: + # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models + # compared with JIT Inductor which uses TORCH_CHECK return "AOTI_TORCH_CHECK" else: return "TORCH_CHECK" diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index fdebe9929158..04fec1d56221 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -64,6 +64,11 @@ DEVICE_TO_ATEN = { "cuda": "at::kCUDA", } +LAYOUT_TO_ATEN = { + torch.strided: "at::kStrided", + torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined] +} + INDEX_TYPE = "long" GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 18a6c9967f2b..1259418fc09e 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -18,7 +18,14 @@ from ..utils import cache_on_self, sympy_product from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import IndentedBuffer -from .cpp_utils import cexpr, CppPrinter, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP +from .cpp_utils import ( + cexpr, + CppPrinter, + DEVICE_TO_ATEN, + DTYPE_TO_ATEN, + DTYPE_TO_CPP, + LAYOUT_TO_ATEN, +) from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen @@ -56,6 +63,7 @@ class CppWrapperCpu(WrapperCodeGen): self.arg_var_id = count() self.used_cached_devices = set() self.used_cached_dtypes = set() + self.used_cached_layouts = set() self.cached_output_id = count() self.scalar_to_tensor_id = count() self.custom_op_wrapper_loaded = False @@ -722,9 +730,14 @@ class CppWrapperCpu(WrapperCodeGen): self.prefix.writeline( f"constants_info_[{idx}].offset = {tensor.storage_offset()};" ) - self.prefix.writeline( - f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};" - ) + if tensor.is_mkldnn: + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};" + ) + else: + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};" + ) from_folded = "true" if name in V.graph.folded_constants else "false" self.prefix.writeline( f"constants_info_[{idx}].from_folded = {from_folded};" @@ -737,6 +750,23 @@ class CppWrapperCpu(WrapperCodeGen): self.prefix.writeline( f"constants_info_[{idx}].stride = {{{stride_str}}};" ) + self.prefix.writeline( + f"constants_info_[{idx}].layout = static_cast({self.codegen_layout(tensor.layout)});" + ) + + if tensor.is_mkldnn: + opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md( + tensor + ) + assert ( + opaque_metadata_tensor.dim() == 1 + ), "Expect opaque_metadata_tensor to be 1-D" + + opaque_metadata_list = opaque_metadata_tensor.tolist() + opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list) + self.prefix.writeline( + f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};" + ) if name in V.graph.dynamo_flat_name_to_original_fqn: original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get( name, name @@ -877,6 +907,8 @@ class CppWrapperCpu(WrapperCodeGen): cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});") for device in self.used_cached_devices: cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});") + for layout in self.used_cached_layouts: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});") cached_dtypes_buffer.splice(self.prefix) self.prefix = cached_dtypes_buffer @@ -1493,6 +1525,14 @@ class CppWrapperCpu(WrapperCodeGen): else: return DTYPE_TO_ATEN[dtype] + def codegen_layout(self, layout): + if config.abi_compatible: + layout_str = str(layout).split(".")[-1] + self.used_cached_layouts.add(layout_str) + return f"cached_torch_layout_{layout_str}" + else: + return LAYOUT_TO_ATEN[layout] + @functools.lru_cache(None) def codegen_int_array_var( self, diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 4ca9879d94a8..3edb4a397932 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -18,7 +18,7 @@ from ..pattern_matcher import ( KeywordArg, MULTIPLE, ) -from ..virtualized import ops +from ..virtualized import ops, V from .freezing_patterns import register_freezing_graph_pattern from .post_grad import register_lowering_pattern from .quantization import ( @@ -1146,9 +1146,18 @@ if torch._C._has_mkldnn: if has_free_symbols(batch_size) else batch_size, ) + # MKL packed matrix can't be copied to a different address because the internal implementation + # depends on the alignment of internally-stored metadata. + # In aot mode, we need to firstly save the packed weight, when loading it, + # it will be in a different address which doesn't work. + # Disable MKL prepack linear in AOT mode packed_weight_op = ( mkldnn._reorder_linear_weight - if (is_lp_weight or mkldnn._is_mkldnn_acl_supported()) + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation is True + ) else torch.ops.mkl._mkl_reorder_linear_weight ) packed_weight_node = graph.create_node( @@ -1156,7 +1165,11 @@ if torch._C._has_mkldnn: ) packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node) - if is_lp_weight or mkldnn._is_mkldnn_acl_supported(): + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation is True + ): packed_linear_inputs += (bias, "none", [], "") packed_linear_op = mkldnn._linear_pointwise.default else: diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index 7ea53dc24b41..d9f78dcc5b78 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -222,8 +222,17 @@ class AOTInductorModelBase { auto size = this->constant_shape(i); auto stride = this->constant_stride(i); auto offset = this->constant_offset(i); + auto layout = this->constant_layout(i); + auto opaque_metadata_ptr = this->opaque_metadata(i); + auto opaque_metadata_size = this->opaque_metadata_size(i); AtenTensorHandle tensor_handle; +#ifdef AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 + // When opaque_metadata_size is not 0, we need to have the + // aoti_torch_create_tensor_from_blob_v2 available + AOTI_RUNTIME_CHECK( + opaque_metadata_size == 0, + "Expect opaque_metadata_size to be 0 when AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 is defined"); AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( internal_ptr, ndim, @@ -234,6 +243,21 @@ class AOTInductorModelBase { device_type_, device_idx_, &tensor_handle)); +#else + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + internal_ptr, + ndim, + size, + stride, + offset, + dtype, + device_type_, + device_idx_, + &tensor_handle, + layout, + opaque_metadata_ptr, + opaque_metadata_size)); +#endif // AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 constants_map_->emplace(std::move(name), tensor_handle); } if (constants_map_) { @@ -340,6 +364,10 @@ class AOTInductorModelBase { return constants_info_.at(idx).dtype; } + int32_t constant_layout(int64_t idx) const { + return constants_info_.at(idx).layout; + } + size_t constant_offset(int64_t idx) const { return constants_info_.at(idx).offset; } @@ -352,6 +380,14 @@ class AOTInductorModelBase { return constants_info_.at(idx).original_fqn; } + const uint8_t* opaque_metadata(int64_t idx) const { + return constants_info_.at(idx).opaque_metadata.data(); + } + + size_t opaque_metadata_size(int64_t idx) { + return constants_info_.at(idx).opaque_metadata.size(); + } + bool constant_from_folded(int64_t idx) const { return constants_info_.at(idx).from_folded; } @@ -485,6 +521,9 @@ class AOTInductorModelBase { int32_t dtype; int64_t offset; size_t data_size; + int32_t layout; + std::vector opaque_metadata; + int64_t opaque_metadata_size; const char* original_fqn = nullptr; bool from_folded; }; diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index 8020004b06bc..f7af5ffcfc70 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -174,4 +174,7 @@ inline AtenTensorHandle wrap_with_raii_handle_if_needed( static auto cached_torch_device_type_##device = \ aoti_torch_device_type_##device() +#define CACHE_TORCH_LAYOUT(layout) \ + static auto cached_torch_layout_##layout = aoti_torch_layout_##layout() + } // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 6fa7df75c056..ba716e213a0f 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -112,6 +112,9 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn(); + // Functions for converting a single-element tensor to a scalar value AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value); @@ -270,6 +273,20 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( AtenTensorHandle* ret // returns new reference ); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret, // returns new reference + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( AtenTensorHandle weight, AtenTensorHandle indices, diff --git a/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp new file mode 100644 index 000000000000..7f0811f0d88b --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp @@ -0,0 +1,49 @@ +#include +#include + +#if AT_MKLDNN_ENABLED() +#include +#include +#endif + +namespace torch { +namespace aot_inductor { + +#if AT_MKLDNN_ENABLED() + +void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) { + return reinterpret_cast( + at::native::data_ptr_from_mkldnn(*mkldnn_tensor)); +} + +at::Tensor mkldnn_tensor_from_data_ptr( + void* data_ptr, + at::IntArrayRef dims, + at::ScalarType dtype, + at::Device device, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + return at::native::mkldnn_tensor_from_data_ptr( + data_ptr, dims, dtype, device, opaque_metadata, opaque_metadata_size); +} + +#else + +void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) { + TORCH_CHECK(false, "MKL-DNN build is disabled"); +} + +at::Tensor mkldnn_tensor_from_data_ptr( + void* data_ptr, + at::IntArrayRef dims, + at::ScalarType dtype, + at::Device device, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + TORCH_CHECK(false, "MKL-DNN build is disabled"); +} + +#endif + +} // namespace aot_inductor +} // namespace torch diff --git a/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h new file mode 100644 index 000000000000..08712833d8ae --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace torch { +namespace aot_inductor { + +void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor); + +at::Tensor mkldnn_tensor_from_data_ptr( + void* data_ptr, + at::IntArrayRef dims, + at::ScalarType dtype, + at::Device device, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +} // namespace aot_inductor +} // namespace torch diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 79cea0cb45ec..6f93407aa467 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1,8 +1,10 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -90,6 +92,14 @@ AOTI_TORCH_DTYPE_IMPL(complex64, ComplexFloat) AOTI_TORCH_DTYPE_IMPL(complex128, ComplexDouble) #undef AOTI_TORCH_DTYPE_IMPL +int32_t aoti_torch_layout_strided() { + return (int32_t)at::kStrided; +} + +int32_t aoti_torch_layout__mkldnn() { + return (int32_t)at::kMkldnn; +} + #define AOTI_TORCH_ITEM_IMPL(dtype, ctype) \ AOTITorchError aoti_torch_item_##dtype( \ AtenTensorHandle tensor, ctype* ret_value) { \ @@ -154,7 +164,11 @@ AOTITorchError aoti_torch_get_data_ptr( void** ret_data_ptr) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); - *ret_data_ptr = t->data_ptr(); + if (t->is_mkldnn()) { + *ret_data_ptr = data_ptr_from_mkldnn(t); + } else { + *ret_data_ptr = t->data_ptr(); + } }); } @@ -325,6 +339,48 @@ AOTITorchError aoti_torch_create_tensor_from_blob( }); } +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + if (layout == static_cast(at::kMkldnn)) { + c10::IntArrayRef sizes(sizes_ptr, ndim); + c10::IntArrayRef strides(strides_ptr, ndim); + c10::Device device = c10_device(device_type, device_index); + // get a mkldnn tensor wrapped by a torch Tensor(OpaqueTensorImpl), + // which used by later mkldnn op. + *ret_new_tensor = new_tensor_handle(mkldnn_tensor_from_data_ptr( + data, + sizes, + static_cast(dtype), + device, + opaque_metadata, + opaque_metadata_size)); + } else { + aoti_torch_create_tensor_from_blob( + data, + ndim, + sizes_ptr, + strides_ptr, + storage_offset, + dtype, + device_type, + device_index, + ret_new_tensor); + } + }); +} + AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( AtenTensorHandle weight, AtenTensorHandle indices,