Upgrade submodule oneDNN to v3.7.1 (#148293)

This PR is to upgrade submodule oneDNN to v3.7.1.

## Improvements

- Improved performance of convolution and matmul primitives on Intel Xeon processors with Intel AMX instruction set support (formerly Sapphire Rapids and Granite Rapids).
- Improved performance of int8 and fp32 forward convolution primitive on processors with Intel AVX2 instruction set support.
- Improved performance of fp8 matmul primitives with bf16 and fp16 bias data type on Intel Xeon processors with Intel AMX instruction set support (formerly Sapphire Rapids and Granite Rapids).
- Introduced initial optimizations for Intel GPUs based on Xe3 architecture.
- Added bfloat16 support for SDPA, implemented fp16 and bf16 gemm kernel in SDPA.
- Fixed f16 matmul accuracy, the issue of SDPA cannot dispatched to ukernel, bf16/fp16/fp32 conv performance, INT8 Kernel trigger page fault, deconvolution precision issue on complex128 and fp64 and gemm correctness issue in float16 issues.
- Improved bf16 matmul performance with fp32 destination with Arm Compute Library (ACL).
- Improved bf16 to fp32 reorder performance.
- Improved bf16 reorder performance.
- Improved bf16 convolution with ACL.

Fixes https://github.com/pytorch/pytorch/issues/136348.

## Validation results on CPU

1. NLP models accuracy/inference/training
![image](https://github.com/user-attachments/assets/859279b8-1631-4268-b226-7de9ac5870d8)

![image](https://github.com/user-attachments/assets/30ec7151-41ca-482a-9d2d-0c4850e75bab)

2. Torchbench cpu userbenchmark inference & training

![image](https://github.com/user-attachments/assets/71c9807c-caf9-4385-9990-d2ab637031cd)

3. Inductor quantization

![image](https://github.com/user-attachments/assets/3d2a3bd3-82fa-4566-8050-7ea5d6b61675)

4. Dynamo benchmarks
![image](https://github.com/user-attachments/assets/554ecce3-c85c-4a0e-88f1-2e73983c5dcd)
![image](https://github.com/user-attachments/assets/148c88f8-4367-4428-bb54-ce8a4deefd1b)
![image](https://github.com/user-attachments/assets/f2e744f4-d710-4699-acf4-1f130ecfadf1)
![image](https://github.com/user-attachments/assets/97128b80-4d0e-495a-aeda-dde3e70c96fd)
![image](https://github.com/user-attachments/assets/a9afce37-684c-45c0-b938-6dd7e0383805)
![image](https://github.com/user-attachments/assets/b8714236-9681-4fbe-8d98-be93deedab88)
![image](https://github.com/user-attachments/assets/4423061f-d133-45ba-98bd-d2f739e50431)
![image](https://github.com/user-attachments/assets/7955da10-3d23-493e-99fa-658f7f40035b)

## Validation results on XPU
Accuracy is same as baseline. Performance is shown below.
![image](https://github.com/user-attachments/assets/7645304d-5b1d-43f9-b840-9f846ed380a0)

## Validation results on ARM
![image](https://github.com/user-attachments/assets/080f7c02-0238-436f-ad20-5a9e3f6aafbb)
![image](https://github.com/user-attachments/assets/443742aa-ca61-41de-ae80-5d4c65cd0c87)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148293
Approved by: https://github.com/mingfeima, https://github.com/atalman
This commit is contained in:
Jiang, Yanbing
2025-03-04 13:56:45 +00:00
committed by PyTorch MergeBot
parent f339e41a38
commit f2f25a5444
3 changed files with 38 additions and 6 deletions

View File

@ -40,7 +40,16 @@ class BinaryFoldingTemplate(TestCase):
@skipCUDAIf(TEST_CUDNN, "CUDNN has accuracy issues for this test")
def test_conv_binary_folding(self):
@torch.no_grad()
def test_conv_fusion(use_bias, module, op, scalar, add_tensor, expect_success):
def test_conv_fusion(
use_bias,
module,
op,
scalar,
add_tensor,
expect_success,
rtol=None,
atol=None,
):
class ConvOp(nn.Module):
__constants__ = ["use_scalar"]
@ -82,7 +91,7 @@ class BinaryFoldingTemplate(TestCase):
inp = torch.rand(inps).to(self.device)
out_eager = mod_eager(inp)
out_optimized = out_optimized(inp)
self.assertEqual(out_optimized, out_eager)
self.assertEqual(out_optimized, out_eager, rtol=rtol, atol=atol)
if expect_success:
self.assertEqual(counters["inductor"]["binary_folding"], 1)
else:
@ -137,6 +146,12 @@ class BinaryFoldingTemplate(TestCase):
False,
add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
expect_success=False,
# This test is for float32 conv fusion with different dtype, like float64,
# which will not be fused. The tolerance of float64 is too tight
# for float32 conv post fusion with float64 tensor. Will relax the tolerance
# for this case.
rtol=1.3e-6,
atol=1e-5,
)
@inductor_config.patch({"freezing": True})

View File

@ -5,17 +5,20 @@ _DNNL_RUNTIME_OMP = {
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
"#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "/* undef DNNL_GPU_VENDOR */",
"#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "/* undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE */",
"#cmakedefine DNNL_WITH_SYCL": "/* #undef DNNL_WITH_SYCL */",
"#cmakedefine DNNL_WITH_LEVEL_ZERO": "/* #undef DNNL_WITH_LEVEL_ZERO */",
"#cmakedefine DNNL_SYCL_CUDA": "/* #undef DNNL_SYCL_CUDA */",
"#cmakedefine DNNL_SYCL_HIP": "/* #undef DNNL_SYCL_HIP */",
"#cmakedefine DNNL_SYCL_GENERIC": "/* #undef DNNL_SYCL_GENERIC */",
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
"#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "/* undef DNNL_EXPERIMENTAL_UKERNEL */",
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE",
"#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
"#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING",
"#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#undef DNNL_DISABLE_GPU_REF_KERNELS",
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
"#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
"#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
@ -36,6 +39,7 @@ _DNNL_RUNTIME_OMP = {
"#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
"#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
"#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
"#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 0",
"#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
"#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
"#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
@ -52,6 +56,7 @@ _DNNL_RUNTIME_OMP = {
"#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
"#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
"#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0",
"#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0",
"#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 0",
"#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0",
"#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 0",
@ -65,9 +70,8 @@ template_rule(
out = "include/oneapi/dnnl/dnnl_version.h",
substitutions = {
"@DNNL_VERSION_MAJOR@": "3",
"@DNNL_VERSION_MINOR@": "5",
"@DNNL_VERSION_PATCH@": "3",
"@DNNL_VERSION_HASH@": "66f0cb9eb66affd2da3bf5f8d897376f04aae6af",
"@DNNL_VERSION_MINOR@": "7",
"@DNNL_VERSION_PATCH@": "1",
},
)
@ -78,14 +82,23 @@ template_rule(
substitutions = _DNNL_RUNTIME_OMP,
)
template_rule(
name = "include_dnnl_version_hash",
src = "include/oneapi/dnnl/dnnl_version_hash.h.in",
out = "include/oneapi/dnnl/dnnl_version_hash.h",
substitutions = {"@DNNL_VERSION_HASH@": "8d263e693366ef8db40acc569cc7d8edf644556d",}
)
cc_library(
name = "mkl-dnn",
srcs = glob([
"src/common/*.cpp",
"src/cpu/**/*.cpp",
"src/cpu/**/**/*.cpp",
], exclude=[
"src/cpu/aarch64/**/*.cpp",
"src/cpu/rv64/**/*.cpp",
"src/cpu/sycl/**/*.cpp",
]),
hdrs = glob([
"include/oneapi/dnnl/*.h",
@ -94,16 +107,20 @@ cc_library(
"include/*.hpp",
"src/cpu/**/*.hpp",
"src/cpu/**/*.h",
"src/cpu/**/**/*.h",
"src/common/*.hpp",
"src/common/**/**/*.h",
"src/common/ittnotify/jitprofiling.h",
], exclude=[
"src/cpu/aarch64/**/*.hpp",
"src/cpu/aarch64/**/*.h",
"src/cpu/rv64/**/*.hpp",
"src/cpu/rv64/**/*.h",
"src/cpu/sycl/**/*.hpp",
]) + [
"include/oneapi/dnnl/dnnl_config.h",
"include/oneapi/dnnl/dnnl_version.h",
"include/oneapi/dnnl/dnnl_version_hash.h",
],
copts = [
"-DDNNL_DLL",