[Reland take-2] Add JIT graph fuser for oneDNN Graph API (v0.5)

Re-landing #68111/#74596

## Description
v0.5 PR of this [RFC](https://github.com/pytorch/pytorch/issues/49444).

Building on #50256, this PR includes the following improvements:

 * The [v0.5 release branch](https://github.com/oneapi-src/oneDNN/releases/tag/graph-v0.5) of the oneDNN Graph API is used
 * The fuser now works with the profiling graph executor. We have inserted type check nodes to guard the profiled tensor properties.

 ### User API:
The optimization pass is disabled by default. Users can enable it with:

```
 torch.jit.enable_onednn_fusion(True)
```
`torch.jit.freeze` should be used after tracing (recommended) or scripting a model.
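
For reference, a minimal end-to-end sketch of that flow (the toy model and input shape below are placeholders, not part of this PR):

```
import torch
import torch.nn as nn

# enable the oneDNN Graph fusion pass (disabled by default)
torch.jit.enable_onednn_fusion(True)

# placeholder model/input; any FP32 inference model works
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()).eval()
example = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    traced = torch.jit.trace(model, example)
    traced = torch.jit.freeze(traced)
    out = traced(example)   # fusion is applied when the graph is executed
```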

 ### Performance:
 The [pytorch/benchmark](https://github.com/pytorch/benchmark) tool was used to compare performance:

 * SkyLake 8180 (1 socket of 28 cores):
   ![image](https://user-images.githubusercontent.com/65992142/151162305-05e44425-a24e-4d5e-94e1-743b40b87a8c.png)
* SkyLake 8180 (single thread):
   ![image](https://user-images.githubusercontent.com/65992142/151162528-69f90b79-d08d-46b8-8775-d80a6ccbce8a.png)
   * By mapping hardswish to oneDNN Graph, it’s 8% faster than PyTorch JIT (NNC + OFI)
   * We expect further performance gains after mapping transpose, contiguous & view to oneDNN Graph ops

 ### Directory structure of the integration code
 Fuser-related code is placed under:

 ```
 torch/csrc/jit/codegen/onednn/
 ```

 Optimization pass registration is done in:

 ```
 torch/csrc/jit/passes/onednn_graph_fuser.h
 ```

 CMake for the integration code is in:

 ```
 caffe2/CMakeLists.txt
 cmake/public/mkldnn.cmake
 cmake/Modules/FindMKLDNN.cmake
 ```

 ## Limitations
 * In this PR, the PyTorch-oneDNN Graph integration is only supported on Linux. Support for Windows and macOS will be enabled as a next step.
 * We have only optimized the inference use-case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76622
Approved by: https://github.com/eellison
sanchitintel
2022-05-05 16:57:03 +00:00
committed by PyTorch MergeBot
parent 94fc92f288
commit 4ee29d6033
38 changed files with 3283 additions and 73 deletions


@ -682,6 +682,8 @@ if(USE_FBGEMM AND ((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VO
set(USE_FBGEMM OFF)
endif()
set(BUILD_ONEDNN_GRAPH OFF)
include(cmake/Dependencies.cmake)
if(USE_CUDA AND (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 10.2) AND (CMAKE_HOST_SYSTEM_NAME MATCHES "Windows"))


@ -43,6 +43,8 @@ namespace c10 {
_(prim, FusionGroup) \
_(prim, CudaFusionGroup) \
_(prim, CudaFusionGuard) \
_(prim, oneDNNFusionGroup) \
_(prim, oneDNNFusionGuard) \
_(prim, FunctionalGraph) \
_(prim, add_optional) \
_(prim, view_copy) \
@ -319,6 +321,7 @@ namespace c10 {
_(attr, cache_id) \
_(attr, new_axis) \
_(attr, warn_id) \
_(attr, output_layouts) \
_(attr, allowzero) \
_(attr, seen_none) \
_(attr, overload_name)


@ -657,6 +657,26 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1")
endif()
if(USE_MLCOMPUTE)
include(../mlc/mlc_build.cmake)
endif()
if(BUILD_ONEDNN_GRAPH)
list(APPEND Caffe2_CPU_SRCS
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_rewriter.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_helper.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/register_interface.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/defer_size_check.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/layout_propagation.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/prepare_binary.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/guard_shape.cpp
)
endif()
if(USE_ROCM)
list(APPEND Caffe2_HIP_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS})
if(USE_NCCL)


@ -12,86 +12,118 @@
# MKLDNN_USE_NATIVE_ARCH : Whether native CPU instructions should be used in MKLDNN. This should be turned off for
# general packaging to avoid incompatible CPU instructions. Default: OFF.
IF (NOT MKLDNN_FOUND)
IF(NOT MKLDNN_FOUND)
SET(MKLDNN_LIBRARIES)
SET(MKLDNN_INCLUDE_DIR)
SET(MKLDNN_LIBRARIES)
SET(MKLDNN_INCLUDE_DIR)
SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/third_party/oneDNN")
IF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER)
MESSAGE("-- Will build oneDNN Graph")
SET(LLGA_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn")
SET(BUILD_ONEDNN_GRAPH ON)
ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER)
SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
SET(MKLDNN_ROOT "${IDEEP_ROOT}/mkl-dnn/third_party/oneDNN")
FIND_PACKAGE(BLAS)
FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
IF(NOT MKLDNN_INCLUDE_DIR)
EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init --jobs 0 mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
ENDIF(NOT MKLDNN_INCLUDE_DIR)
IF(BUILD_ONEDNN_GRAPH)
FIND_PATH(LLGA_INCLUDE_DIR oneapi/dnnl/dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include)
ENDIF(BUILD_ONEDNN_GRAPH)
FIND_PACKAGE(BLAS)
FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
IF (NOT MKLDNN_INCLUDE_DIR)
EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init --jobs 0 mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
FIND_PATH(MKLDNN_INCLUDE_DIR mkldnn.hpp mkldnn.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
ENDIF(NOT MKLDNN_INCLUDE_DIR)
IF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
MESSAGE(STATUS "MKLDNN source files not found!")
RETURN()
ENDIF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
LIST(APPEND MKLDNN_INCLUDE_DIR ${IDEEP_INCLUDE_DIR})
IF(BUILD_ONEDNN_GRAPH)
LIST(APPEND MKLDNN_INCLUDE_DIR ${LLGA_INCLUDE_DIR})
ENDIF(BUILD_ONEDNN_GRAPH)
IF(MKL_FOUND)
ADD_DEFINITIONS(-DIDEEP_USE_MKL)
# Append to mkldnn dependencies
LIST(APPEND MKLDNN_LIBRARIES ${MKL_LIBRARIES})
LIST(APPEND MKLDNN_INCLUDE_DIR ${MKL_INCLUDE_DIR})
ELSE(MKL_FOUND)
SET(MKLDNN_USE_MKL "NONE" CACHE STRING "" FORCE)
ENDIF(MKL_FOUND)
IF (NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
MESSAGE(STATUS "MKLDNN source files not found!")
RETURN()
ENDIF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
LIST(APPEND MKLDNN_INCLUDE_DIR ${IDEEP_INCLUDE_DIR})
IF(MKL_FOUND)
ADD_DEFINITIONS(-DIDEEP_USE_MKL)
# Append to mkldnn dependencies
LIST(APPEND MKLDNN_LIBRARIES ${MKL_LIBRARIES})
LIST(APPEND MKLDNN_INCLUDE_DIR ${MKL_INCLUDE_DIR})
ELSE(MKL_FOUND)
SET(MKLDNN_USE_MKL "NONE" CACHE STRING "" FORCE)
ENDIF(MKL_FOUND)
SET(MKL_cmake_included TRUE)
IF(NOT MKLDNN_CPU_RUNTIME)
SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "")
ELSEIF(MKLDNN_CPU_RUNTIME STREQUAL "TBB")
IF(USE_TBB)
MESSAGE(STATUS "MKL-DNN is using TBB")
SET(MKL_cmake_included TRUE)
IF (NOT MKLDNN_CPU_RUNTIME)
SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "")
ELSEIF (MKLDNN_CPU_RUNTIME STREQUAL "TBB")
IF (USE_TBB)
MESSAGE(STATUS "MKL-DNN is using TBB")
SET(TBB_cmake_included TRUE)
SET(Threading_cmake_included TRUE)
SET(TBB_cmake_included TRUE)
SET(Threading_cmake_included TRUE)
SET(DNNL_CPU_THREADING_RUNTIME ${MKLDNN_CPU_RUNTIME})
INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIR})
LIST(APPEND EXTRA_SHARED_LIBS TBB::tbb)
ELSE()
MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used")
ENDIF()
ENDIF()
MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}")
SET(MKLDNN_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE)
SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "" FORCE)
SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE)
SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set.
SET(DNNL_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE)
ELSE()
IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
IF(CPU_INTEL)
SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE)
SET(DNNL_CPU_THREADING_RUNTIME ${MKLDNN_CPU_RUNTIME})
INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIR})
LIST(APPEND EXTRA_SHARED_LIBS TBB::tbb)
ELSE()
MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used")
ENDIF()
ELSE()
SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE)
ENDIF()
ENDIF()
MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}")
ADD_SUBDIRECTORY(${MKLDNN_ROOT})
IF(NOT TARGET dnnl)
MESSAGE("Failed to include MKL-DNN target")
RETURN()
ENDIF(NOT TARGET dnnl)
SET(MKLDNN_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE)
SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "" FORCE)
SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE)
SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
IF(BUILD_ONEDNN_GRAPH)
SET(DNNL_GRAPH_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
ENDIF(BUILD_ONEDNN_GRAPH)
IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set.
SET(DNNL_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE)
ELSE()
IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
IF(CPU_INTEL)
SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE)
ENDIF()
ELSE()
SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE)
ENDIF()
ENDIF()
IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow)
ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
LIST(APPEND MKLDNN_LIBRARIES dnnl)
IF(BUILD_ONEDNN_GRAPH)
ADD_SUBDIRECTORY(${LLGA_ROOT})
IF(NOT TARGET dnnl_graph)
MESSAGE("Failed to include LLGA target")
RETURN()
ENDIF(NOT TARGET dnnl_graph)
SET(MKLDNN_FOUND TRUE)
MESSAGE(STATUS "Found MKL-DNN: TRUE")
IF(CMAKE_COMPILER_IS_GNUCC)
TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-maybe-uninitialized)
TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-strict-overflow)
TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-error=strict-overflow)
ENDIF(CMAKE_COMPILER_IS_GNUCC)
ELSE(BUILD_ONEDNN_GRAPH)
ADD_SUBDIRECTORY(${MKLDNN_ROOT})
ENDIF(BUILD_ONEDNN_GRAPH)
IF(NOT TARGET dnnl)
MESSAGE("Failed to include MKL-DNN target")
RETURN()
ENDIF(NOT TARGET dnnl)
IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow)
ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
LIST(APPEND MKLDNN_LIBRARIES ${MKL_OPENMP_LIBRARY})
IF(BUILD_ONEDNN_GRAPH)
LIST(APPEND MKLDNN_LIBRARIES "$<TARGET_FILE:dnnl_graph>")
ENDIF(BUILD_ONEDNN_GRAPH)
LIST(APPEND MKLDNN_LIBRARIES dnnl)
SET(MKLDNN_FOUND TRUE)
MESSAGE(STATUS "Found MKL-DNN: TRUE")
ENDIF(NOT MKLDNN_FOUND)


@ -16,3 +16,15 @@ set_property(
set_property(
TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES
${MKLDNN_LIBRARIES})
if(BUILD_ONEDNN_GRAPH)
if(NOT TARGET caffe2::dnnl_graph)
add_library(caffe2::dnnl_graph INTERFACE IMPORTED)
endif()
set_property(
TARGET caffe2::dnnl_graph PROPERTY INTERFACE_INCLUDE_DIRECTORIES
${MKLDNN_INCLUDE_DIR})
set_property(
TARGET caffe2::dnnl_graph PROPERTY INTERFACE_LINK_LIBRARIES
${MKLDNN_LIBRARIES})
endif()


@ -61,6 +61,8 @@ Creating TorchScript Code
ScriptFunction
freeze
optimize_for_inference
enable_onednn_fusion
onednn_fusion_enabled
set_fusion_strategy
strict_fusion
save

test/test_jit_llga_fuser.py (new file, 519 lines)

@ -0,0 +1,519 @@
# Owner(s): ["module: mkldnn"]
import torch
import unittest
import itertools
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS
LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup'
LLGA_NOT_ENABLED = not torch._C.has_mkldnn or IS_WINDOWS or IS_MACOS
def warmup_forward(f, *args, profiling_count=2):
for i in range(profiling_count):
results = f(*args)
return results
class JitLlgaTestCase(JitTestCase):
def setUp(self):
torch.jit.enable_onednn_fusion(True)
def tearDown(self):
torch.jit.enable_onednn_fusion(False)
def checkTrace(self, m, x, *args, **kwargs):
if isinstance(m, torch.nn.Module):
m.eval()
with torch.no_grad(), \
torch._jit_internal._disable_emit_hooks():
traced = torch.jit.trace(m, x)
if isinstance(m, torch.nn.Module):
traced = torch.jit.freeze(traced)
warmup_forward(traced, *x)
fwd_graph = traced.graph_for(*x)
ref_o = m(*x)
jit_o = traced(*x)
self.assertEqual(jit_o, ref_o)
return traced, fwd_graph
def assertFused(self, graph, fused_patterns):
for pat in fused_patterns:
self.assertGraphContainsExactly(graph, pat, 0)
try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
except RuntimeError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision')
def get_eltwise_fn(name):
if hasattr(torch, name):
return getattr(torch, name)
elif hasattr(F, name):
return getattr(F, name)
else:
raise NameError('Eltwise function %s not found' % name)
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestOp(JitLlgaTestCase):
def test_conv2d(self):
for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product(
[7, 8],
[8, 15],
[7, 16],
[3, 4],
[0, 2],
[1, 2],
[1, 2],
[1, 2],
[True, False]):
m = nn.Conv2d(in_channels=in_channels * g,
out_channels=out_channels * g,
kernel_size=kernel,
padding=padding,
stride=stride,
dilation=dilation,
groups=g,
bias=bias)
x = torch.rand(1, in_channels * g, spatial, spatial)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_bn2d(self):
m = nn.BatchNorm2d(32).eval()
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
# single-op partition shouldn't be created for a standalone batch_norm
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
def test_eltwise(self):
class M(nn.Module):
def __init__(self, eltwise_fn):
super(M, self).__init__()
self.eltwise = eltwise_fn
def forward(self, x):
return self.eltwise(x)
for eltwise in ['relu', 'gelu']:
eltwise_fn = get_eltwise_fn(eltwise)
m = M(eltwise_fn)
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
# single-op partition shouldn't be created.
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
def test_max_pool2d(self):
for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product(
[15, 16, 17, 18, 19],
[4, 5],
[0, 1, 2],
[1, 2], # [1, 2, 4], TODO: fix issue in pad calculation
[1], # [1, 2], TODO: backend support for dilation
[True, False]):
m = nn.MaxPool2d(kernel_size=kernel,
stride=stride,
padding=padding,
dilation=dilation,
ceil_mode=ceil_mode)
x = torch.rand(1, 4, spatial, spatial)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_avg_pool2d(self):
for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product(
[15, 16, 17, 18, 19],
[4, 5],
[0, 1, 2],
[1, 2, 4],
[False], # TODO: oneDNN Graph does not fully support ceil_mode=True
[True, False]):
m = nn.AvgPool2d(kernel_size=kernel,
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
x = torch.rand(1, 4, spatial, spatial)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_variable_kernel_avg_pool2d(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
def forward(self, x):
x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False)
return x
x = torch.randn(1, 1000, 1, 1)
m = M()
_, graph = self.checkTrace(m, [x])
# kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP
# TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
def test_softmax(self):
for dim in [-4, -3, -2, -1, 0, 1, 2, 3]:
m = nn.Softmax(dim=dim)
x = torch.rand(8, 12, 12, 12)
_, graph = self.checkTrace(m, [x])
# single-op partition shouldn't be created for softmax
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
def test_linear(self):
for bias in [True, False]:
x = torch.rand(32, 28)
m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::linear'])
def _gen_binary_inputs(self, gen_permute=True):
for xshape, yshape in [
[[1, 32, 28, 28], [1, 32, 28, 28]],
[[1, 32, 28, 28], [1, 1, 28, 28]],
[[1, 32, 28, 28], [28]],
[[1, 32, 28, 28], [1]],
]:
yield torch.rand(xshape), torch.rand(yshape)
if gen_permute and xshape != yshape:
yield torch.rand(yshape), torch.rand(xshape)
def test_add(self):
def forward_add(x, y):
return torch.add(x, y, alpha=2)
for x, y in self._gen_binary_inputs():
_, graph = self.checkTrace(forward_add, [x, y])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_add_scalar(self):
def add_scalar(x):
return 42 + x + 3.14
x = torch.rand(32, 32)
_, graph = self.checkTrace(add_scalar, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_addmm(self):
def addmm(x, y, z):
# alpha and beta are 1, by default
return torch.addmm(z, x, y)
x = torch.rand(64, 32)
y = torch.rand(32, 32)
z = torch.rand(64, 32)
_, graph = self.checkTrace(addmm, [x, y, z])
# single-op partition should be created for matmul with bias.
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_mul(self):
def forward_mul(x, y):
return torch.mul(x, y) * 3
for x, y in self._gen_binary_inputs():
_, graph = self.checkTrace(forward_mul, [x, y])
# single-op partitions shouldn't be created
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_identity_binary(self):
def forward(x):
return x * 1 + 0.0
x = torch.rand(32)
_, graph = self.checkTrace(forward, [x])
self.assertFused(graph, ['aten::add', 'aten::mul'])
def test_layer_norm(self):
# TODO: support more normalized_shape
m = torch.nn.LayerNorm(10)
x = torch.randn(2, 5, 10, 10)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_cat(self):
def cat_along_dim(d):
def forward_cat(*inputs):
return torch.cat(inputs, d)
return forward_cat
for xshape in [
[8, 8, 8, 8],
[64, 8, 32],
[2048, 64],
]:
for d in range(len(xshape)):
x = torch.rand(xshape)
_, graph = self.checkTrace(cat_along_dim(d), [x, x, x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
def test_typecheck(self):
x = torch.rand(32, 28)
m = torch.nn.Linear(in_features=28, out_features=64, bias=True)
traced, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::linear'])
# change the shape of the input, we should enter fallback graph
x = torch.rand(5, 28)
self.assertEqual(m(x), traced(x))
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestFusionPattern(JitLlgaTestCase):
def test_conv2d_eltwise(self):
class M(nn.Module):
def __init__(self, eltwise_fn):
super(M, self).__init__()
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
self.eltwise = eltwise_fn
def forward(self, x):
x = self.conv1(x)
x = self.eltwise(x)
x = self.conv2(x)
x = self.eltwise(x)
return x
# for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']:
for eltwise in ['relu']:
for inplace in [True, False]:
eltwise_fn_name = eltwise + '_' if inplace else eltwise
eltwise_fn = get_eltwise_fn(eltwise_fn_name)
m = M(eltwise_fn)
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
# test that relu_ is replaced with relu by the mutation removal pass
self.assertFused(graph, ['aten::' + eltwise_fn_name])
# test if relu is fused into the fusion group
self.assertFused(graph, ['aten::' + eltwise])
def test_conv2d_bn(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
self.bn1 = nn.BatchNorm2d(32)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
return x
m = M().eval()
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm'])
def test_conv2d_bn_relu(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
self.bn1 = nn.BatchNorm2d(32)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
return x
m = M().eval()
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
'aten::relu'])
def test_bn2d_eltwise(self):
class M(nn.Module):
def __init__(self, eltwise_fn):
super(M, self).__init__()
self.eltwise = eltwise_fn
self.bn = nn.BatchNorm2d(32)
def forward(self, x):
x = self.bn(x)
x = self.eltwise(x)
return x
for eltwise in ['relu']:
eltwise_fn = get_eltwise_fn(eltwise)
m = M(eltwise_fn).eval()
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::' + eltwise])
def test_linear_eltwise(self):
class M(nn.Module):
def __init__(self, eltwise_fn, bias):
super(M, self).__init__()
self.linear = nn.Linear(28, 64, bias)
self.eltwise = eltwise_fn
def forward(self, x):
x = self.linear(x)
x = self.eltwise(x)
return x
for [has_bias, eltwise] in itertools.product(
[True, False],
['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']):
eltwise_fn = get_eltwise_fn(eltwise)
m = M(eltwise_fn, has_bias)
x = torch.rand(32, 28, requires_grad=False)
_, graph = self.checkTrace(m, [x])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::' + eltwise])
def test_conv2d_sum(self):
class M(nn.Module):
def __init__(self, bias=False):
super(M, self).__init__()
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
self.bn2 = nn.BatchNorm2d(32)
self.relu = nn.ReLU()
self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
self.bn3 = nn.BatchNorm2d(32)
def forward(self, x, y):
x = self.conv1(x)
x = self.bn1(x)
y = self.conv2(y)
y = self.bn2(y)
z = self.relu(x + y)
z = self.conv3(z)
z = self.bn3(z)
return z
for bias in [True, False]:
m = M(bias).eval()
x = torch.rand(1, 32, 16, 16, requires_grad=False)
y = torch.rand(1, 32, 16, 16, requires_grad=False)
_, graph = self.checkTrace(m, [x, y])
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
def test_wildcard(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
self.eltwise = nn.ReLU()
def forward(self, x):
x = self.conv1(x)
y = self.eltwise(x)
return [x, y]
# The pattern is as the following:
# conv
# | \
# eltwise \
# | \
# ListConstruct
#
# The output of conv is used by a wildcard op: ListConstruct.
# Thus conv-eltwise cannot be selected into the same Partition.
m = M()
x = torch.rand(1, 32, 28, 28)
_, graph = self.checkTrace(m, [x])
# conv can exist in a single-op oneDNN Graph partition but not relu
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
self.assertFused(graph, ['aten::_convolution'])
def test_rewrap_tensor_input_to_pytorch(self):
class M(nn.Module):
def __init__(self, eltwise_fn, data_type):
super(M, self).__init__()
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
self.eltwise = eltwise_fn
self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))
def forward(self, x, y):
x = self.conv1(x)
x = self.eltwise(x)
x = self.conv2(x)
x = self.eltwise(x)
x = torch.add(x, y)
x = self.adaptive_avg_pool_2d(x)
return x
eltwise_fn_name = 'relu'
eltwise_fn = get_eltwise_fn(eltwise_fn_name)
# Add bfloat16 later
for data_type in [torch.float]:
m = M(eltwise_fn, data_type)
m = m.to(memory_format=torch.channels_last)
x = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
y = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
# Simply test if the output is accurate
# The output of the second partition is input to adaptive_avg_pool2d, which is
# unsupported by LLGA, so it must be handled by PyTorch, which should receive
# correct strides info of the channels-last tensor.
graph, _ = self.checkTrace(m, [x, y])
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestModel(JitLlgaTestCase):
@skipIfNoTorchVision
def _test_vision(self, model_name):
m = getattr(torchvision.models, model_name)().eval()
x = torch.rand(1, 3, 224, 224) / 10
_, graph = self.checkTrace(m, [x])
self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
'aten::relu', 'aten::linear',
'aten::avg_pool2d', 'aten::max_pool2d'])
for model_name, enabled in [
['resnet50', True],
['resnext50_32x4d', True],
['resnext101_32x8d', True],
['densenet121', True],
['googlenet', TEST_SCIPY],
['mobilenet_v2', True],
['mnasnet1_0', True],
['squeezenet1_0', True],
['vgg16', True],
['alexnet', True],
['shufflenet_v2_x1_0', True],
['wide_resnet50_2', True],
]:
def wrapper(mname):
@unittest.skipIf(not enabled, 'Disabled')
def test(self):
return self._test_vision(mname)
return test
setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))
if __name__ == '__main__':
run_tests()


@ -14,7 +14,7 @@ if(NOT BUILD_PYTHON)
endif()
if(USE_TBB)
include_directories(${TBB_INCLUDE_DIR})
include_directories(${TBB_INCLUDE_DIR})
endif()
set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
@ -423,6 +423,10 @@ target_compile_options(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
target_include_directories(torch_python PUBLIC ${TORCH_PYTHON_INCLUDE_DIRECTORIES})
if(BUILD_ONEDNN_GRAPH)
target_compile_definitions(torch_python PRIVATE "-DBUILD_ONEDNN_GRAPH")
target_compile_definitions(torch_cpu PRIVATE "-DBUILD_ONEDNN_GRAPH")
endif()
if(NOT TORCH_PYTHON_LINK_FLAGS STREQUAL "")
set_target_properties(torch_python PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})


@ -228,6 +228,8 @@ def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
def _jit_texpr_fuser_enabled() -> _bool: ...
def _jit_nvfuser_enabled() -> _bool: ...
def _jit_llga_enabled() -> _bool: ...
def _jit_set_llga_enabled(enable: _bool): ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...


@ -0,0 +1,132 @@
#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <c10/core/CPUAllocator.h>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
dnnl::graph::engine& Engine::getEngine() {
static dnnl::graph::engine cpu_engine(
dnnl::graph::engine::kind::cpu, /* device_id = */ 0);
return cpu_engine;
}
dnnl::graph::stream& Stream::getStream() {
static dnnl::graph::stream cpu_stream{Engine::getEngine(), nullptr};
return cpu_stream;
}
LlgaTensorImpl::LlgaTensorImpl(
at::Storage&& storage,
const caffe2::TypeMeta& data_type,
const LlgaTensorDesc& desc)
: at::TensorImpl(
std::move(storage),
c10::DispatchKeySet(c10::DispatchKey::MkldnnCPU),
data_type),
desc_(desc) {
set_sizes_and_strides(desc.sizes(), desc.strides());
refresh_numel();
}
at::Tensor LlgaTensorImpl::llga_to_aten_tensor(LlgaTensorImpl* llgaImpl) {
auto aten_tensor = at::detail::make_tensor<TensorImpl>(
std::move(llgaImpl->storage_),
c10::DispatchKeySet(c10::DispatchKey::CPU),
llgaImpl->data_type_);
auto impl = aten_tensor.unsafeGetTensorImpl();
impl->set_storage_offset(llgaImpl->storage_offset_);
impl->set_sizes_and_strides(llgaImpl->sizes(), llgaImpl->strides());
return aten_tensor;
}
at::Tensor empty_llga(
const LlgaTensorDesc& desc,
const c10::TensorOptions& options) {
auto nbytes = desc.storage_size();
auto allocator = at::GetCPUAllocator();
auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
nbytes,
allocator->allocate(nbytes),
allocator,
/*resizable=*/false);
return at::detail::make_tensor<LlgaTensorImpl>(
std::move(storage_impl), options.dtype(), desc);
}
const LlgaTensorDesc& get_llga_desc(const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT(
tensor.is_mkldnn(), "get_llga_desc expects Mkldnn tensor input");
return static_cast<LlgaTensorImpl*>(tensor.unsafeGetTensorImpl())->desc();
}
dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor) {
return {
get_llga_desc(tensor).logical_tensor(),
torch::jit::fuser::onednn::Engine::getEngine(),
tensor.data_ptr()};
}
using data_type = dnnl::graph::logical_tensor::data_type;
data_type getLlgaDataType(at::ScalarType dt) {
switch (dt) {
case at::ScalarType::Float:
return data_type::f32;
case at::ScalarType::BFloat16:
return data_type::bf16;
case at::kInt:
return data_type::s32;
case at::ScalarType::QInt8:
return data_type::s8;
case at::ScalarType::QUInt8:
return data_type::u8;
default:
TORCH_CHECK(false, "Not support data type ", dt);
}
}
LlgaTensorDesc LlgaTensorDesc::supplementTensorInfo(const at::Tensor& t) const {
if (t.is_mkldnn()) {
// if the input tensor is an mkldnn tensor, it originated from an upstream
// LLGA partition which carries opaque layout info
return get_llga_desc(t).tid(tid_);
} else {
// if input tensor is not an mkldnn tensor, use default layout
auto sizes = t.sizes().vec();
auto strides = t.strides().vec();
auto dtype = getLlgaDataType(t.scalar_type());
return {tid_, sizes, strides, dtype, property_type_};
}
}
at::ScalarType LlgaTensorDesc::aten_scalar_type() const {
switch (dtype_) {
case data_type::f32:
return at::ScalarType::Float;
case data_type::bf16:
return at::ScalarType::BFloat16;
case data_type::s32:
return at::kInt;
case data_type::s8:
return at::ScalarType::QInt8;
case data_type::u8:
return at::ScalarType::QUInt8;
default:
TORCH_CHECK(false, "Invalid data type ", static_cast<size_t>(dtype_));
}
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
#endif // AT_MKLDNN_ENABLED()


@ -0,0 +1,273 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
// Engine represents a device and its context. From the device kind, the engine
// knows how to generate code for the target device and what kind of device
// object to expect. The device id ensures that there is a unique engine
// being created for each device. The device handle passed from PyTorch allows
// oneDNN Graph implementation to work on the device specified by PyTorch, which
// is currently CPU, so we only have one engine.
// Ref: https://spec.oneapi.io/onednn-graph/latest/programming_model.html#engine
struct Engine {
// CPU engine singleton
static dnnl::graph::engine& getEngine();
Engine(const Engine&) = delete;
void operator=(const Engine&) = delete;
};
// Stream is the logical abstraction for execution units. It is created on top
// of oneDNN Graph engine. A compiled oneDNN Graph partition is submitted to a
// stream for execution.
struct Stream {
// CPU stream singleton
static dnnl::graph::stream& getStream();
Stream(const Stream&) = delete;
void operator=(const Stream&) = delete;
};
struct LlgaTensorDesc {
using desc = dnnl::graph::logical_tensor;
LlgaTensorDesc(
size_t tid,
std::vector<int64_t> sizes,
std::vector<int64_t> strides,
desc::data_type dtype,
desc::property_type property_type)
: tid_(tid),
sizes_(sizes),
strides_(strides),
dtype_(dtype),
property_type_(property_type),
layout_type_(desc::layout_type::strided),
layout_id_(-1) {}
LlgaTensorDesc(const desc& t)
: tid_(t.get_id()),
sizes_(t.get_dims()),
strides_({-1}),
dtype_(t.get_data_type()),
property_type_(t.get_property_type()),
layout_type_(t.get_layout_type()),
layout_id_(-1) {
if (is_opaque()) {
layout_id_ = t.get_layout_id();
}
if (is_strided()) {
strides_ = t.get_strides();
}
}
// TODO: LLGA needs input/output type constraints to be set, but it seems we
// cannot get the dtype at compile time; hard-coded to fp32 for now to be
// able to add_op
LlgaTensorDesc(const torch::jit::Value* v)
: LlgaTensorDesc(
v->unique(),
{},
{},
desc::data_type::f32,
get_property_type(v)) {
if (v->type()->isSubtypeOf(TensorType::get())) {
auto tt = v->type()->cast<TensorType>();
auto sizes = tt->sizes();
if (sizes.sizes()) {
for (auto d : *sizes.sizes()) {
sizes_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM));
}
}
auto strides = tt->strides();
if (strides.sizes()) {
for (auto d : *strides.sizes()) {
strides_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM));
}
}
}
}
LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const;
at::ScalarType aten_scalar_type() const;
const std::vector<int64_t>& sizes() const {
return sizes_;
}
const std::vector<int64_t>& strides() const {
TORCH_CHECK(!is_opaque(), "Cannot get strides on opaque layout");
return strides_;
}
size_t tid() const {
return tid_;
}
LlgaTensorDesc tid(uint64_t new_id) const {
auto ret = *this;
ret.tid_ = new_id;
return ret;
}
desc::data_type dtype() const {
return dtype_;
}
LlgaTensorDesc dtype(desc::data_type new_dtype) const {
return LlgaTensorDesc(tid_, sizes_, strides_, new_dtype, property_type_);
}
desc::layout_type layout_type() const {
return layout_type_;
}
LlgaTensorDesc layout_type(desc::layout_type new_layout_type) {
auto ret = *this;
ret.layout_type_ = new_layout_type;
return ret;
}
desc::property_type get_property_type(const torch::jit::Value* v) {
switch (v->node()->kind()) {
case prim::Constant:
return desc::property_type::constant;
default:
return desc::property_type::variable;
}
}
LlgaTensorDesc any() {
return layout_type(desc::layout_type::any);
}
size_t storage_size() const {
return logical_tensor().get_mem_size();
}
desc logical_tensor() const {
if (is_dimensionality_unknown()) {
return desc(
tid_, dtype_, DNNL_GRAPH_UNKNOWN_NDIMS, layout_type_, property_type_);
} else if (is_opaque()) {
return desc(tid_, dtype_, sizes_, layout_id_, property_type_);
} else if (is_any()) {
return desc(tid_, dtype_, sizes_, layout_type_, property_type_);
} else {
return desc(tid_, dtype_, sizes_, strides_, property_type_);
}
}
bool is_strided() const {
return layout_type_ == desc::layout_type::strided;
}
bool is_any() const {
return layout_type_ == desc::layout_type::any;
}
bool is_opaque() const {
return layout_type_ == desc::layout_type::opaque;
}
bool operator==(const LlgaTensorDesc& desc) const {
return tid_ == desc.tid_ && sizes_ == desc.sizes_ &&
dtype_ == desc.dtype_ && layout_type_ == desc.layout_type_ &&
((is_opaque() && layout_id_ == desc.layout_id_) ||
strides_ == desc.strides_);
}
bool operator!=(const LlgaTensorDesc& desc) const {
return (tid_ != desc.tid_) || (sizes_ != desc.sizes_) ||
(dtype_ != desc.dtype_) || (layout_type_ != desc.layout_type_) ||
!((is_opaque() && (layout_id_ == desc.layout_id_)) ||
(strides_ == desc.strides_));
}
static size_t hash(const LlgaTensorDesc& desc) {
return c10::get_hash(
desc.tid_,
desc.sizes_,
desc.dtype_,
desc.layout_type_,
desc.layout_id_);
}
void set_compute_inplace() {
compute_inplace_ = true;
}
void set_input_tensor_index(size_t index) {
input_tensor_index_ = index;
}
bool reuses_input_tensor() {
return compute_inplace_;
}
size_t get_input_tensor_index() {
return input_tensor_index_;
}
private:
bool is_dimensionality_unknown() const {
return sizes_.size() == 0;
}
size_t tid_;
std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;
desc::data_type dtype_;
desc::property_type property_type_;
desc::layout_type layout_type_;
size_t layout_id_;
// If this is an output tensor, and querying the compiled partition would
// determine that this tensor would reuse its input tensor, then
// compute_inplace would be true, and input_tensor_index would be the index of
// the corresponding input tensor in inputSpecs_ of the LlgaKernel object.
bool compute_inplace_ = false;
size_t input_tensor_index_;
};
// Initially, oneDNN Graph also used to have blocked layout for tensors between
// partitions, and the LlgaTensorImpl wrapper helped us bypass guard checks.
// oneDNN Graph has switched over to using strided tensors between partitions,
// but this wrapper still helps us bypass guard checks because the strides of
// tensors between partitions would be different from the ones the guard is
// otherwise expecting.
struct TORCH_API LlgaTensorImpl : public c10::TensorImpl {
LlgaTensorImpl(
at::Storage&& storage,
const caffe2::TypeMeta& data_type,
const LlgaTensorDesc& desc);
const LlgaTensorDesc& desc() const {
return desc_;
}
static at::Tensor llga_to_aten_tensor(LlgaTensorImpl* llgaImpl);
private:
LlgaTensorDesc desc_;
};
at::Tensor empty_llga(
const LlgaTensorDesc& desc,
const c10::TensorOptions& options);
dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch


@ -0,0 +1,108 @@
# PyTorch - oneDNN Graph API Bridge
This integration adds the infrastructure for a new PyTorch JIT graph fuser based on the [oneDNN Graph API](https://spec.oneapi.io/onednn-graph/latest/programming_model.html), which provides a flexible API for aggressive fusion. The current preview4 version supports fusion for FP32 inference. Currently, the speedup is achieved for static shapes,
although dynamic-shape support will be added soon. When oneDNN Graph is enabled, weights are cached, since they are constant during inference.
## Graph Optimization
We have registered optimization passes in the custom pre-passes set of PyTorch (a short sketch of how to inspect their effect follows this list):
1. Alias and mutation reduction
oneDNN Graph operators are purely functional, whereas PyTorch has in-place operators and operators that create views for buffer sharing.
Due to these semantic gaps between the backend operators and the PyTorch operators, we run a best-effort pass at the beginning to reduce mutation.
2. Graph passing
With a PyTorch TorchScript graph, the integration maps PyTorch operators on the graph to the corresponding oneDNN Graph operators to form a backend graph.
3. Partitioning
The backend selects regions to be fused in the graph and returns a list of partitions. Each partition corresponds to a set of fused operators.
4. Graph rewriting
The original PyTorch JIT graph will be re-written based on the partitions returned from the backend. The operators in one partition will be grouped together to form a JIT operator, referred to as a oneDNN Graph fusion group.
5. Layout propagation
This pass eliminates unnecessary layout conversions at partition boundaries. We set different formats on the outputs of a partition so that the backend can perform layout conversions internally. When `ANY` is set, the layout at the boundary is fully decided by the backend. Otherwise, the backend has to follow the layout set by PyTorch. Currently, we set the `ANY` layout for a tensor that is both an output of one oneDNN Graph partition and an input to another.
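As a minimal sketch (the toy model below is only an illustration), the effect of these passes can be observed by printing the rewritten graph after a couple of warm-up runs under the profiling executor; oneDNN Graph partitions appear as `prim::oneDNNFusionGroup` nodes:
```python
import torch
import torch.nn as nn

torch.jit.enable_onednn_fusion(True)

model = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()).eval()
x = torch.rand(1, 32, 28, 28)

with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(model, x))
    traced(x)  # warm-up runs so the profiling executor records tensor properties
    traced(x)
    # the rewritten graph groups fused ops into prim::oneDNNFusionGroup nodes
    print(traced.graph_for(x))
```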
## Graph Executor
During runtime execution of a (re-written) PyTorch JIT graph, oneDNN Graph partitions are dispatched to the oneDNN Graph JIT variadic operator.
Inside this operator, the input PyTorch tensors of each partition are mapped to oneDNN Graph tensors. The partition is then [compiled](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#partition) and [executed](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#compiled-partition). The output oneDNN Graph tensors are mapped back to PyTorch tensors to be fed to the next operator in the PyTorch JIT graph.
## Tests
```bash
pytest test/test_jit_llga_fuser.py
```
## Quick Start
A simple cascaded Conv-ReLU example is provided in the test suite. Consider enabling log output to familiarize yourself with the whole pipeline:
**Mutation Removal -> Prepare Binary -> Defer Size Check -> Graph Fuser -> Layout Propagation -> Type Guard -> Kernel Execution**
oneDNN Graph was formerly known as LLGA (Low Level Graph API),
and thus LLGA in the codebase corresponds to oneDNN Graph.
```bash
DNNL_VERBOSE=1 PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface" python -u test/test_jit_llga_fuser.py -k test_conv2d_eltwise
```
## Codebase structure
Most of the source code is placed in
```bash
torch/csrc/jit/codegen/onednn/*
```
Tensor-related code is located at
```bash
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp
```
CMake files where bridge code is included:
```bash
caffe2/CMakeLists.txt
```
CMake files where the oneDNN Graph submodule is included:
```bash
third_party/ideep/mkl-dnn
cmake/public/mkldnn.cmake
cmake/Modules/FindMKLDNN.cmake
cmake/Dependencies.cmake
```
To map another op to oneDNN Graph, add an entry for it in createOperator in torch/csrc/jit/codegen/onednn/graph_helper.cpp.
If the op has an in-place variant, also add it to the lambda passed to RemoveTensorMutation in
torch/csrc/jit/codegen/onednn/interface.cpp. You might also want to add it to canFuseNode in torch/csrc/jit/codegen/onednn/register_interface.cpp. A sketch of how such a mapping could be validated is shown below.
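For illustration only, a new mapping could then be validated with a test modeled on the existing ones in `test/test_jit_llga_fuser.py` (the op `torch.exp` and the test class below are hypothetical examples, not part of this PR):
```python
# Hypothetical test, written inside test/test_jit_llga_fuser.py so that
# torch, nn, JitLlgaTestCase, and LLGA_FUSION_GROUP are already in scope.
class TestNewOp(JitLlgaTestCase):
    def test_conv2d_exp(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                return torch.exp(self.conv(x))

        m = M()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        # if the mapping works, conv + exp should form one fusion group
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
```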
## How to use
```python
# enable oneDNN graph fusion globally
torch.jit.enable_onednn_fusion(True)
# define the model
class MyModel(torch.nn.Module):
...
# construct the model
model = MyModel()
with torch.no_grad():
model.eval()
model = torch.jit.trace(model, torch.rand(args.batch_size, 3, 224, 224))
# run the model
with torch.no_grad():
# oneDNN Graph fusion will be triggered at runtime
output = model(images)
```


@ -0,0 +1,87 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
class SizeCheckMover {
private:
Block* block_;
std::shared_ptr<Graph> graph_;
public:
SizeCheckMover(Block* block, std::shared_ptr<Graph> graph)
: block_(block), graph_(std::move(graph)) {}
bool analyzeNode(Node* node, AliasDb& aliasDb) {
//
// %b = addmm(%a)
// %sz = aten::size(%b)
// %c = relu(%b)
// =>
// %b = addmm(%a)
// %c = relu(%b)
// %sz = aten::size(%c)
// ^-- move size check after relu as it preserves input shape
//
if (!node->matches("aten::size(Tensor self) -> int[]"))
return false;
auto* input = node->input(0);
auto& uses = input->uses();
bool onlyUsedByShapePreserveOp =
uses.size() > 1 && std::all_of(uses.begin(), uses.end(), [&](auto& u) {
if (u.user == node) {
return true;
}
// match with shape-preserving unary ops in
// tensorexpr_elementwise_set that's defined in
// torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
OperatorMap<std::string> schemaMap = get_tensorexpr_elementwise_set();
c10::optional<std::string> mapping =
schemaMap.find(u.user->getOperator());
return mapping == "unary";
});
if (!onlyUsedByShapePreserveOp)
return false;
for (const auto& use : uses) {
if (use.user == node)
continue;
auto shapePreserveOp = use.user;
if (aliasDb.moveAfterTopologicallyValid(node, shapePreserveOp)) {
node->replaceInputWith(input, shapePreserveOp->output(0));
return true;
}
}
return false;
}
void run() {
bool changed = true;
while (changed) {
changed = false;
AliasDb aliasDb(graph_);
for (Node* node : block_->nodes()) {
changed |= analyzeNode(node, aliasDb);
}
}
for (Node* node : block_->nodes())
for (Block* subBlock : node->blocks())
SizeCheckMover(subBlock, graph_).run();
}
};
void DeferSizeCheck(std::shared_ptr<Graph>& graph) {
SizeCheckMover(graph->block(), graph).run();
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch


@ -0,0 +1,15 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void DeferSizeCheck(std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch


@ -0,0 +1,31 @@
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph) {
AliasDb db(graph);
GraphRewriter graphRewriter(graph->block(), graph, db);
// We maintain alias db correctness in-place while building up the LLGA
// subgraphs, however it is difficult to preserve correctness when
// un-inlining autodiff subgraphs. We first recursively construct all
// subgraphs and then recursively cleanup & unmerge the small subgraphs
graphRewriter.buildupSubgraphs();
graphRewriter.cleanupSubgraphs();
// Run CSE globally once to eliminate duplicates that may have occurred
// while inlining subgraphs.
EliminateCommonSubexpression(graph);
EliminateDeadCode(graph);
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch


@ -0,0 +1,53 @@
#pragma once
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
struct WorkBlock : public std::pair<Node*, Node*> {
using pair::pair;
Node* begin() {
return this->first;
}
Node* end() {
return this->second;
}
};
class GraphRewriter {
public:
GraphRewriter(Block* block, std::shared_ptr<Graph> graph, AliasDb& aliasDb)
: block_(block),
graph_(std::move(graph)),
aliasDb_(aliasDb),
llgaHelper_(graph_) {}
void cleanupSubgraphs();
void buildupSubgraphs();
private:
Block* block_;
std::shared_ptr<Graph> graph_;
AliasDb& aliasDb_;
LlgaGraphHelper llgaHelper_;
std::vector<WorkBlock> buildWorkBlocks();
std::pair<graph_node_list::iterator, bool> scanNode(
Node* consumer,
graph_node_list::iterator workblock_begin);
c10::optional<Node*> tryMerge(Node* consumer, Node* producer);
};
// This pass creates the subgraphs for oneDNN Graph Fusion Nodes.
// Its code-structure has been vastly inspired from
// torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch


@ -0,0 +1,562 @@
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
using opkind = dnnl::graph::op::kind;
void fixConvOptionalBias(Node* node) {
if (node->namedInput("bias")->mustNotBeNone() == false) {
// Replace non-existent optional bias with const None
auto g = node->owningGraph();
auto n = g->createNone();
auto v = n->insertBefore(node)->output();
node->replaceInput(2, v);
}
}
c10::optional<size_t> getDimensions(Value* v) {
if (v->type()->isSubtypeOf(TensorType::get())) {
return v->type()->cast<TensorType>()->sizes().size();
} else {
return c10::nullopt;
}
}
// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped as
// Wildcards instead. They make the integration code with PyTorch simpler by
// passing every op to the oneDNN Graph library in the add_op call -
// no need to check beforehand whether the op is supported by oneDNN Graph or
// not. oneDNN Graph ops separated by wildcards don't end up in the same
// partition.
Operator makeWildcardOp(Node* node) {
auto o = Operator(node, opkind::Wildcard);
// wildcard op contains only topology info
for (size_t i = 0; i < node->inputs().size(); i++) {
o.setInput(i);
}
for (size_t i = 0; i < node->outputs().size(); i++) {
o.setOutput(i);
}
return o;
}
// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph
// op, then we create a wildcard op corresponding to that PyTorch op instead.
#define REQUIRE(cond) \
if (!(cond)) { \
GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
return makeWildcardOp(node); \
}
Operator makeEltwiseOp(Node* node, opkind kind) {
return Operator(node, kind).setInput(0).setOutput(0);
}
Operator makeBinaryOp(Node* node, opkind kind) {
REQUIRE(
node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
node->input(1)->type()->isSubtypeOf(TensorType::get()))
return Operator(node, kind).setInput(0, 1).setOutput(0);
}
// Map a PyTorch op to its corresponding oneDNN Graph op.
// If mapping isn't possible, then create a wildcard op instead.
// The mapping is done as per oneDNN Graph op schema defined in
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
Operator createOperator(Node* node) {
switch (node->kind()) {
case aten::conv2d: {
fixConvOptionalBias(node);
return Operator(node, opkind::Convolution)
.setInput(0, 1, 2)
.setOutput(0)
.setAttr("strides", Operator::Ints, 3)
.setAttr("pads_begin", Operator::Ints, 4)
.setAttr("pads_end", Operator::Ints, 4)
.setAttr("dilations", Operator::Ints, 5)
.setAttr("groups", Operator::Int, 6)
.setAttr("filter_format", std::string("OIX"))
.setAttr("data_format", std::string("NCX"));
}
case aten::_convolution: {
bool transposed = toIValue(node->namedInput("transposed"))->toBool();
REQUIRE(!transposed);
return Operator(node, opkind::Convolution)
.setInput(0, 1, 2)
.setOutput(0)
.setAttr("strides", Operator::Ints, 3)
.setAttr("pads_begin", Operator::Ints, 4)
.setAttr("pads_end", Operator::Ints, 4)
.setAttr("dilations", Operator::Ints, 5)
.setAttr("groups", Operator::Int, 8)
.setAttr("filter_format", std::string("OIX"))
.setAttr("data_format", std::string("NCX"));
}
case aten::batch_norm: {
auto training = toIValue(node->namedInput("training"));
REQUIRE(
training.has_value()); // cannot get training status in script mode
REQUIRE(!training->toBool()); // TODO: support bn training
return Operator(node, opkind::BatchNormInference)
.setInput(0, 1, 2, 3, 4)
.setOutput(0)
.setAttr("epsilon", Operator::Float, 7)
.setAttr("data_format", std::string("NCX"));
}
case aten::layer_norm: {
auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
REQUIRE(normalized_shape->toIntList().size() == 1);
return Operator(node, opkind::LayerNorm)
.setInput(0, 2, 3)
.setOutput(0)
.setAttr("epsilon", Operator::Float, 4)
.setAttr("keep_stats", false);
}
case aten::addmm: {
auto alpha = toIValue(node->namedInput("alpha"));
auto beta = toIValue(node->namedInput("beta"));
REQUIRE(
alpha.has_value() && beta.has_value() && (alpha->toDouble() == 1.0) &&
(beta->toDouble() == 1.0));
return Operator(node, opkind::MatMul).setInput(1, 2, 0).setOutput(0);
}
case aten::add:
return makeBinaryOp(node, opkind::Add);
case aten::mul:
return makeBinaryOp(node, opkind::Multiply);
case aten::tanh:
return makeEltwiseOp(node, opkind::Tanh);
case aten::relu:
return makeEltwiseOp(node, opkind::ReLU);
case aten::elu:
return makeEltwiseOp(node, opkind::Elu)
.setAttr("alpha", Operator::Float, 1);
case aten::sigmoid:
return makeEltwiseOp(node, opkind::Sigmoid);
case aten::gelu:
return makeEltwiseOp(node, opkind::GELU);
case aten::sqrt:
return makeEltwiseOp(node, opkind::Sqrt);
case aten::abs:
return makeEltwiseOp(node, opkind::Abs);
case aten::square:
return makeEltwiseOp(node, opkind::Square);
case aten::hardtanh:
return makeEltwiseOp(node, opkind::HardTanh)
.setAttr("min", Operator::Float, 1)
.setAttr("max", Operator::Float, 2);
case aten::relu6:
return makeEltwiseOp(node, opkind::HardTanh)
.setAttr("min", 0.f)
.setAttr("max", 6.f);
case aten::softmax: {
auto axis = toIValue(node->namedInput("dim"))->toInt();
return Operator(node, opkind::SoftMax)
.setInput(0)
.setOutput(0)
.setAttr("axis", axis);
}
case aten::cat: {
auto o = Operator(node, opkind::Concat);
REQUIRE(
node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
REQUIRE(node->namedInput("tensors")->uses().size() == 1);
REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
// aten::cat needs a special handling since it takes a Tensor[] as input.
// We set the inputs of ListConstruct as the inputs of cat.
//
// Pytorch IR: LLGA sees:
// %a %b %c %dim %a %b %c
// \ | / | \ | /
// prim::ListConstruct prim::Constant llga::Concat[axis=%dim]
// \ /
// aten::cat
auto listConstruct = node->input(0)->node();
for (auto input : listConstruct->inputs())
o.setInputValue(input);
return o.setOutput(0).setAttr("axis", Operator::Int, 1);
}
case aten::max_pool2d: {
REQUIRE(
node->namedInput("kernel_size")->node()->kind() == prim::Constant);
auto rounding_type =
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
return Operator(node, opkind::MaxPool)
.setInput(0)
.setOutput(0)
.setAttr("kernel", Operator::Ints, 1)
.setAttr("strides", Operator::Ints, 2)
.setAttr("pads_begin", Operator::Ints, 3)
.setAttr("pads_end", Operator::Ints, 3)
.setAttr("dilations", Operator::Ints, 4)
.setAttr("rounding_type", std::string(rounding_type))
.setAttr("data_format", std::string("NCX"));
}
case aten::avg_pool2d: {
// TODO: do we need add checks for all Constants?
REQUIRE(
node->namedInput("kernel_size")->node()->kind() == prim::Constant);
auto rounding_type =
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
auto divisor_override = toIValue(node->namedInput("divisor_override"));
REQUIRE(divisor_override->isNone());
return Operator(node, opkind::AvgPool)
.setInput(0)
.setOutput(0)
.setAttr("kernel", Operator::Ints, 1)
.setAttr("strides", Operator::Ints, 2)
.setAttr("pads_begin", Operator::Ints, 3)
.setAttr("pads_end", Operator::Ints, 3)
.setAttr("exclude_pad", !Operator::Bool(node, 5))
.setAttr("rounding_type", std::string(rounding_type))
.setAttr("data_format", std::string("NCX"));
}
case aten::matmul: {
auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
// TODO: support all shape combinations
REQUIRE(
(dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
(dim0 == 3 && dim1 == 2));
} // fall through
case aten::mm: {
return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
}
case aten::linear: {
return Operator(node, opkind::MatMul)
.setInput(0, 1, 2)
.setOutput(0)
.setAttr("transpose_b", true);
}
default:
return makeWildcardOp(node);
}
}
dnnl::graph::op createLlgaOp(Node* node) {
return createOperator(node).llgaOp();
}
bool isSupported(Node* node) {
return createOperator(node).kind() != opkind::Wildcard;
};
DeviceType inferDeviceFromValue(Value* v) {
auto tt = v->type()->cast<TensorType>();
if (!tt) {
return at::kCPU;
}
auto device = tt->device();
if (!device) {
return at::kCPU;
}
return device->type();
}
DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
auto dt = inferDeviceFromValue(graph->inputs()[0]);
TORCH_CHECK(
std::all_of(
graph->inputs().begin(),
graph->inputs().end(),
[dt](Value* v) { return inferDeviceFromValue(v) == dt; }),
"All inputs must have the same deive type");
return dt;
}
dnnl::graph::engine::kind getLlgaEngineKind(DeviceType type) {
switch (type) {
case DeviceType::CPU:
return dnnl::graph::engine::kind::cpu;
default:
TORCH_CHECK(false, "Not support device type ", type);
}
}
void mayAddListConstructIntoConcatPartition(
Node* n,
OpPartitionMap& opToOwningPartition) {
// Since prim::ListConstruct is not visible to the LLGA,
// it will not be in any partition returned from the partitioning results.
// We need to rewrite opToOwningPartition to make prim::ListConstruct be
// 'virtually' in the same partition as aten::cat, so that
// prim::ListConstruct can be fused into the fusion group by the graph fuser.
// We emphasize 'virtually' because get_num_ops() for cat's partition
// would still return 1.
if (n->kind() == aten::cat && opToOwningPartition.has(n)) {
auto listConstruct = n->namedInput("tensors")->node();
auto partitionId = opToOwningPartition.get(n);
opToOwningPartition.add(listConstruct, partitionId);
}
}
// Verify that input tensors are compatible with oneDNN Graph.
// Scalars would be converted to 1-D tensors later anyway,
// but they shouldn't be complex-double
// If this check fails, convert op to wildcard
bool checkInputCompatibility(Node* node) {
auto allInputs = node->inputs();
for (auto input : allInputs) {
c10::IValue inputIValue = toIValue(input);
if (inputIValue.isTensor()) {
const at::Tensor& tensor = inputIValue.toTensor();
if (tensor.device() != at::kCPU) {
return false;
}
auto dtype = tensor.scalar_type();
if ((dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
return false;
}
} else if (inputIValue.isScalar()) {
if (inputIValue.isComplexDouble()) {
return false;
}
}
}
return true;
}
LlgaGraphHelper::LlgaGraphHelper(
const std::shared_ptr<Graph>& graph,
dnnl::graph::partition::policy policy) {
auto deviceType = inferDevice(graph);
auto engineKind = getLlgaEngineKind(deviceType);
dnnl::graph::graph g{engineKind};
GRAPH_DEBUG("Constructing LLGA graph");
// TODO: select nodes in top-level block for now
for (auto* node : graph->block()->nodes()) {
auto op = createLlgaOp(node);
auto kindOfNode = node->kind();
if (checkInputCompatibility(node)) {
g.add_op(op);
GRAPH_DEBUG(" Added node ", kindOfNode.toQualString());
} else {
GRAPH_DEBUG("The backend failed to add node ", kindOfNode.toQualString());
g.add_op(makeWildcardOp(node).llgaOp());
}
for (Value* input : node->inputs()) {
tensorIdToValue_.emplace(input->unique(), input);
}
}
GRAPH_DEBUG("Get Partitions");
std::vector<dnnl::graph::partition> partitions = g.get_partitions(policy);
// exclude unsupported Wildcard partitions
for (auto& partition : partitions) {
if (partition.is_supported()) {
partitions_.push_back(partition);
}
}
GRAPH_DEBUG(" Got #partitions: ", partitions_.size());
for (size_t partId = 0; partId < partitions_.size(); partId++) {
for (auto opId : partitions_[partId].get_ops()) {
opToOwningPartition_.add(opId, partId);
}
}
// Scan the graph again for post-processing
for (auto* node : graph->block()->nodes()) {
mayAddListConstructIntoConcatPartition(node, opToOwningPartition_);
}
}
bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) {
return node->hasAttribute(attr::Subgraph) &&
node->kind() == prim::oneDNNFusionGroup;
}
bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
TORCH_CHECK(
isLlgaSubgraph(subgraph),
"The consumer node does not contain a subgraph");
if (!shouldConsiderForMerge(toMerge)) {
return false;
}
return opToOwningPartition_.get(toMerge) ==
opToOwningPartition_.get(subgraph);
}
// Except for conv & GEMMs, which should always be handled by oneDNN Graph,
// only use single-op partitions for ops unsupported by NNC, or ops
// that oneDNN executes faster. prim::ListConstruct is an exception, since
// we simply want to fuse it with cat.
bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
return (
(kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) ||
(kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) ||
(kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) ||
(kindOfOp == aten::mm) || (kindOfOp == aten::linear) ||
(kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct));
}
bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) {
if (opToOwningPartition_.has(node)) {
auto partitionId = opToOwningPartition_.get(node);
if (partitions_[partitionId].get_ops_num() == 1) {
auto kindOfNode = node->kind();
return isBetterSuitedForLLGA(kindOfNode);
} else {
// multi-op partition
return true;
}
} else {
// this op isn't present in any partition
return false;
}
}
bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
// if we're already in the process of merging
if (isLlgaSubgraph(node)) {
return true;
}
return checkForSingleOpPartition(node);
}
Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) {
auto partitionId = opToOwningPartition_.get(n);
GRAPH_DEBUG(
"Creating FusionGroup_", partitionId, " for ", n->kind().toQualString());
auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
n, prim::oneDNNFusionGroup, aliasDb);
opToOwningPartition_.add(group, partitionId);
LlgaNodeWrapper(group).initOutputLayouts();
return group;
}
void LlgaGraphHelper::mergeNodeIntoSubgraph(
Node* toMerge,
Node* subgraphNode,
AliasDb& aliasDb) {
if (isLlgaSubgraph(toMerge)) {
GRAPH_DEBUG(
"Merging ",
toMerge->kind().toQualString(),
"_",
opToOwningPartition_.get(toMerge),
" into ",
subgraphNode->kind().toQualString(),
"_",
opToOwningPartition_.get(subgraphNode));
} else {
GRAPH_DEBUG(
"Merging ",
toMerge->kind().toQualString(),
" into ",
subgraphNode->kind().toQualString(),
"_",
opToOwningPartition_.get(subgraphNode));
}
SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
toMerge, subgraphNode, aliasDb);
}
void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) {
TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node");
auto partitionId = opToOwningPartition_.get(subgraphNode);
auto expectOpNum = partitions_[partitionId].get_ops_num();
auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph));
if (expectOpNum != actualOpNum) {
GRAPH_DEBUG(
"Unmerging FusionGroup_",
partitionId,
". Expected ",
expectOpNum,
" ops, but got ",
actualOpNum,
" ops.");
SubgraphUtils::unmergeSubgraph(subgraphNode);
}
}
size_t LlgaGraphHelper::countSupportedOps(
const std::shared_ptr<Graph>& graph) const {
// TODO: count nodes in top-level block for now
size_t cnt = 0;
for (auto* node : graph->block()->nodes()) {
auto nodeKind = node->kind();
if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) {
cnt++;
}
}
return cnt;
}
std::vector<dnnl::graph::partition> LlgaGraphHelper::getPartitions() const {
return partitions_;
}
std::map<size_t, Value*> LlgaGraphHelper::getTensorIdToValue() const {
return tensorIdToValue_;
}
LlgaNodeWrapper::LlgaNodeWrapper(const Node* node)
: n(const_cast<Node*>(node)) { // NOLINT
TORCH_CHECK(
LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node");
}
void LlgaNodeWrapper::setOpaqueLayout(size_t offset) {
TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset);
auto& layouts =
const_cast<std::vector<int64_t>&>(n->is(attr::output_layouts)); // NOLINT
layouts.at(offset) = 1;
}
bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const {
TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset);
return n->is(attr::output_layouts)[offset] == 1;
}
void LlgaNodeWrapper::initOutputLayouts() {
if (n->hasAttribute(attr::output_layouts)) {
return;
}
// Init all output layouts as undef
std::vector<int64_t> layouts(n->outputs().size(), 0);
n->is_(attr::output_layouts, layouts);
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,95 @@
#pragma once
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/operator.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
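// Maps an LLGA op id (the Node's address, see Operator::getId) to the index
// of the owning partition in LlgaGraphHelper::partitions_.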
struct OpPartitionMap {
void add(uint64_t opId, uint64_t partitionId) {
opmap_[opId] = partitionId;
}
void add(Node* n, uint64_t partitionId) {
add(Operator::getId(n), partitionId);
}
bool has(uint64_t opId) {
return opmap_.count(opId) > 0;
}
bool has(Node* n) {
return has(Operator::getId(n));
}
uint64_t get(uint64_t opId) {
return opmap_[opId];
}
uint64_t get(Node* n) {
auto opId = Operator::getId(n);
TORCH_CHECK(
has(opId),
"Node ",
n->kind().toQualString(),
" does not belong to any LLGA partition");
return get(opId);
}
private:
std::unordered_map<uint64_t, uint64_t> opmap_;
};
class LlgaGraphHelper {
public:
LlgaGraphHelper(
const std::shared_ptr<Graph>& graph,
dnnl::graph::partition::policy policy =
dnnl::graph::partition::policy::fusion);
bool shouldMerge(Node* toMerge, Node* subgraph);
bool shouldConsiderForMerge(Node* node);
bool checkForSingleOpPartition(Node* node);
Node* createSingletonSubgraph(Node* n, AliasDb& db);
void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode, AliasDb& db);
void unmergeIfAnyNodeIsMissing(Node* subgraphNode);
static bool isLlgaSubgraph(const Node* node);
std::vector<dnnl::graph::partition> getPartitions() const;
std::map<size_t, Value*> getTensorIdToValue() const;
private:
size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const;
OpPartitionMap opToOwningPartition_;
std::vector<dnnl::graph::partition> partitions_;
std::map<size_t, Value*>
tensorIdToValue_; // map from tensorId to torch::jit::Value
};
class LlgaNodeWrapper {
public:
LlgaNodeWrapper(const Node* node);
void setOpaqueLayout(size_t offset);
bool useOpaqueLayout(size_t offset) const;
friend class LlgaGraphHelper;
private:
void initOutputLayouts();
Node* n;
};
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,144 @@
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void GraphRewriter::cleanupSubgraphs() {
auto curNode = *block_->nodes().rbegin();
while (curNode != *block_->nodes().rend()) {
// Save the previous node, since we might delete `curNode` in the next block
auto prevNode = curNode->prev();
if (llgaHelper_.isLlgaSubgraph(curNode)) {
// Unmerge the subgraph if we don't get every node of a partition
// into the subgraph due to a failed alias check
llgaHelper_.unmergeIfAnyNodeIsMissing(curNode);
}
curNode = prevNode;
}
for (Node* n : block_->nodes()) {
for (Block* b : n->blocks()) {
GraphRewriter(b, graph_, aliasDb_).cleanupSubgraphs();
}
}
}
void GraphRewriter::buildupSubgraphs() {
// We need to run the rewriter multiple times in order to get all merge
// opportunities. This is because moveBeforeTopologicallyValid may reorder
// nodes to be AFTER the current iteration point. In order to properly
// consider those nodes for merging, we need to run the pass until no changes
// have been made.
//
// Example:
// c = f(a, b)
// d = f(c)
// e = f(d) <- iter is here, moving upward
// After c.moveBeforeTopologicallyValid(e), we have:
// c = f(a, b)
// e = f(d) <- iter still here
// d = f(c) <- this is the node that was moved to the other side.
// see [workblocks]
auto workblocks = buildWorkBlocks();
for (auto& workblock : workblocks) {
bool any_changed = true;
while (any_changed) {
any_changed = false;
auto workblock_end = workblock.end()->reverseIterator();
auto workblock_begin = workblock.begin()->reverseIterator();
for (auto it = workblock_end; it != workblock_begin;) {
bool changed = false;
std::tie(it, changed) = scanNode(*it, workblock_begin);
any_changed |= changed;
}
}
}
// Construct Subgraphs Recursively
for (Node* n : block_->nodes()) {
for (auto subBlock : n->blocks()) {
GraphRewriter(subBlock, graph_, aliasDb_).buildupSubgraphs();
}
}
}
std::vector<WorkBlock> GraphRewriter::buildWorkBlocks() {
// [workblocks]
// The IR has many nodes which can never be reordered around, such as a
// prim::Bailout. If a node N is surrounded by two nodes which cannot be
// reordered, A and B, then a fusion group that is created from N
// can only contain nodes from (A, B). The nodes from A to B represent one
// work block for the subgraph rewriter to work on. By creating these up
// front, we avoid retraversing the whole graph block any time scanNode
// returns.
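// A sketch with illustrative names (not taken from a real graph): if the
// block contains
//   %a = f(%x)
//   %b = <node with side effects>
//   %c = f(%a)
// then %b becomes a work-block boundary, and scanNode only looks for merge
// opportunities within a single work block on either side of it.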
Node* end_bound_node = block_->return_node();
Node* curr = end_bound_node->prev();
std::vector<WorkBlock> worklist;
while (curr != block_->param_node()) {
// cannot reorder around side effectful nodes
if (curr->hasSideEffects()) {
worklist.emplace_back(curr, end_bound_node);
end_bound_node = curr;
}
curr = curr->prev();
}
worklist.emplace_back(curr, end_bound_node);
return worklist;
}
std::pair<graph_node_list::iterator, bool> GraphRewriter::scanNode(
Node* consumer,
graph_node_list::iterator workblock_begin) {
GRAPH_DEBUG("Scanning ", consumer->kind().toQualString());
if (llgaHelper_.shouldConsiderForMerge(consumer)) {
if (!llgaHelper_.isLlgaSubgraph(consumer)) {
consumer = llgaHelper_.createSingletonSubgraph(consumer, aliasDb_);
}
// Iterate through the workblock to merge nodes of the
// same partition determined by LLGA graph helper.
// Nodes like B and C below do not share a common input but belong to the
// same partition, and thus we cannot only scan the input nodes
// to find merging opportunities. Instead, we have to scan through
// the whole work block, which might lead to O(N^2) accesses in the worst case.
// A
// + - - / - \ - - +
// | B C |
// | | | |
// | D E |
// + - - \ - / - - +
// F
auto prev = ++consumer->reverseIterator();
for (auto it = prev; it != workblock_begin; it++) {
if (auto group = tryMerge(consumer, *it)) {
// we successfully merged, so the new group's `inputs` may have
// changed. So rescan the new group for more merging opportunities.
return std::make_pair(group.value()->reverseIterator(), true);
}
}
}
return std::make_pair(++consumer->reverseIterator(), false);
}
// Try to merge `producer` into `consumer`. If successful, this destroys
// `producer` and returns the `consumer` group.
c10::optional<Node*> GraphRewriter::tryMerge(Node* consumer, Node* producer) {
AT_ASSERT(llgaHelper_.isLlgaSubgraph(consumer));
bool canMerge = llgaHelper_.shouldMerge(producer, consumer) &&
aliasDb_.moveBeforeTopologicallyValid(producer, consumer);
if (!canMerge) {
return c10::nullopt;
}
llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_);
return consumer;
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,45 @@
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
//! [ Note -- prepareFusionGroupAndGuardOutputs implementation ]
//! This shamelessly copies code from NNC (tensorexpr_fuser) with very little
//! modification; the original code is at:
//! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs`
//!
//! We assume that LLGA has no operators whose behavior depends on the
//! content of a tensor.
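//! As a rough sketch (illustrative only, not an exact IR dump), after this
//! pass each fusion group ends up guarded approximately like:
//!   %checked..., %hit = prim::oneDNNFusionGuard[types=[...]](%inputs...)
//!   %out = prim::If(%hit)
//!     block0(): run prim::oneDNNFusionGroup on the checked inputs
//!     block1(): run the unoptimized fallback graph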
void prepareFusionGroupAndGuardOutputs(Block* block) {
std::vector<Node*> fusion_groups;
for (Node* n : block->nodes()) {
for (Block* b : n->blocks()) {
prepareFusionGroupAndGuardOutputs(b);
}
if (n->kind() == prim::oneDNNFusionGroup) {
fusion_groups.push_back(n);
}
}
for (Node* fusion_group : fusion_groups) {
// TODO: add further optimization pass to removeOutputsUsedOnlyInSize,
// refer to
// `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize`
// removeOutputsUsedOnlyInSize(fusion_group);
insertTypeGuard(
fusion_group,
[](const TensorTypePtr& t) { return t; },
prim::oneDNNFusionGuard);
}
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,15 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void prepareFusionGroupAndGuardOutputs(Block* block);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,172 @@
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/codegen/onednn/interface.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void fuseGraph(std::shared_ptr<Graph>& g) {
// Follow the process of the tensorexpr_fuser in profiling mode:
// Remove prim::profile nodes and embed the profile info directly in the
// IR in value types to avoid breaking the fusion patterns.
// We will add a shape guard after the LLGA optimization passes and
// wipe the tensor type information from the IR, so that it's not
// accidentally used by any other pass.
// We rely on the shape specialization and shape guard to ensure the validity
// of the cached compilation in the kernel; thus, only profiling mode is supported.
// TODO: add check on oneDNNFusionGroup to ensure allShapesAreKnown on nodes
// to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
if (getProfilingMode()) {
GRAPH_DUMP(
"Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
"optimization pass",
g);
RemoveProfileNodesAndSpecializeTypes(g);
GRAPH_DUMP(
"After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
g);
RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
static std::unordered_set<Symbol> supportedOps = {
aten::add_,
aten::mul_,
aten::tanh_,
aten::elu_,
aten::relu_,
aten::relu6_,
aten::gelu_,
aten::sqrt_,
aten::sigmoid_,
aten::hardtanh_,
aten::abs_,
aten::square_,
};
return supportedOps.count(nodeToFunctionalize->kind()) != 0;
});
RemoveListMutation(g);
GRAPH_DUMP("After mutation removal. Before PrepareBinaryForLLGA", g);
PrepareBinaryForLLGA(g);
GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
DeferSizeCheck(g);
GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
CreateLlgaSubgraphs(g);
GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
PropagateLayout(g);
GRAPH_DUMP(
"After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);
// Add shape guard for profiling mode and wipe the tensor type information
// from the IR
prepareFusionGroupAndGuardOutputs(g->block());
GRAPH_DUMP(
"After prepareFusionGroupAndGuardOutputs. Before "
"RemoveTensorTypeSpecializations",
g);
RemoveTensorTypeSpecializations(g);
GRAPH_DUMP(
"After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
g);
}
}
} // namespace onednn
} // namespace fuser
Operation createLlgaKernel(const Node* node) {
auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
return [kernel](Stack* stack) {
RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
kernel->run(*stack);
return 0;
};
}
RegisterOperators oneDNNFusionGroupOp({
torch::jit::Operator(
prim::oneDNNFusionGroup,
createLlgaKernel,
AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
// Currently, we convert some scalar inputs, such as the second argument of
// binary ops, to a 1-D tensor. Other scalar inputs are prim::Constant nodes.
// But if we have any scalar inputs to guard in the future, some logic here
// would have to be changed.
Operation createLlgaGuardKernel(const Node* node) {
return [node](Stack* stack) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
#endif
std::vector<TypePtr> types = node->tys(attr::types);
const auto num_inputs = types.size();
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
#endif
for (size_t i = 0; i < num_inputs; i++) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("checking input ", i);
#endif
auto& input = peek(stack, i, num_inputs);
const c10::TensorTypePtr& guard_tensor_type =
types[i]->cast<TensorType>();
if (!input.isTensor()) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("input ", i, " is not a tensor, return false");
#endif
push(stack, IValue(false));
return;
}
const at::Tensor& tensor = input.toTensor();
// If the input tensor is an MKLDNN tensor, it originated from an upstream
// LLGA partition that has already passed the check on input shapes.
// It is valid to continue here as long as the output shapes from
// oneDNN graph partitions are determined by the input shapes.
if (tensor.is_mkldnn()) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
#endif
continue;
}
if (!guard_tensor_type->matchTensor(tensor)) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("input ", i, " check failed, return false");
#endif
push(stack, IValue(false));
return;
}
}
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("all check done, return true");
#endif
push(stack, IValue(true));
return;
};
}
RegisterOperators oneDNNGuardOp({
torch::jit::Operator(
prim::oneDNNFusionGuard,
createLlgaGuardKernel,
AliasAnalysisKind::FROM_SCHEMA),
});
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,62 @@
#pragma once
#include <ATen/Config.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
static std::atomic<bool> onednn_enabled{false};
std::atomic<bool>& getLlgaEnabled() {
return onednn_enabled;
}
C10_EXPORT void fuseGraph(std::shared_ptr<Graph>& g);
} // namespace onednn
} // namespace fuser
struct C10_EXPORT RegisterLlgaFuseGraph
: public PassManager<RegisterLlgaFuseGraph> {
static bool setEnabled(bool enabled) {
TORCH_CHECK(
AT_MKLDNN_ENABLED(),
"Running oneDNN Graph fuser is only supported with MKLDNN builds.");
bool oldState = fuser::onednn::getLlgaEnabled();
fuser::onednn::getLlgaEnabled() = enabled;
if (enabled) {
registerPass(fuser::onednn::fuseGraph);
} else {
clearPass();
}
return oldState;
}
static bool isEnabled() {
return fuser::onednn::getLlgaEnabled();
}
// override PassManager::registerPass to register pre-pass
static bool registerPass(GraphPass p) {
if (!isRegistered()) {
passID(registerPrePass(std::move(p)), true);
isRegistered(true);
return false;
}
return true;
}
// override PassManager::clearPass to clear pre-pass
static void clearPass() {
if (isRegistered()) {
clearPrePass(passID());
isRegistered(true);
}
}
};
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,292 @@
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
using namespace dnnl::graph;
using data_type = dnnl::graph::logical_tensor::data_type;
LlgaKernel::LlgaKernel(const Node* fusionNode)
: fusionNode_(fusionNode),
graph_(fusionNode->g(attr::Subgraph)),
nGraphInputs_(graph_->inputs().size()),
nOutputs_(graph_->outputs().size()),
debugName_(genDebugName()) {
// TODO: This is a workaround to recreate the partitions here.
// The ideal way is to use the partition serialization API (not yet available
// in LLGA) to carry a serialized string representation from the graph rewrite
// and deserialize it here.
auto llgaGraphHelper = LlgaGraphHelper(graph_);
auto partitions = llgaGraphHelper.getPartitions();
tensorIdToValue_ = llgaGraphHelper.getTensorIdToValue();
TORCH_CHECK(
partitions.size() == 1,
"LLGA subgraph should contain only one partition");
partition_ = partitions[0];
nPartitionInputs_ = partition_.get_in_ports().size();
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString());
#endif
}
bool LlgaKernel::useOpaqueLayout(size_t offset) const {
return LlgaNodeWrapper(fusionNode_).useOpaqueLayout(offset);
}
void LlgaKernel::initializeConstantInputs() {
for (auto& lt : partition_.get_in_ports()) {
auto inputId = lt.get_id();
if (initializedInputIds_.find(inputId) == initializedInputIds_.end()) {
TORCH_CHECK(
tensorIdToValue_.count(inputId) > 0,
"inputs with inputId ",
inputId,
" is missing");
auto* value = tensorIdToValue_[inputId];
TORCH_CHECK(
value->node()->kind() == prim::Constant &&
value->type()->cast<TensorType>(),
"inputs with inputId ",
inputId,
" should be a Constant tensor");
constantValues_.emplace_back(value);
auto const_tensor = toIValue(value)->toTensor();
constantInputs_.emplace_back(const_tensor);
}
}
}
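// Count how many times each tensor id occurs among the partition's input
// ports, since the same graph input Value may feed the partition more than
// once and then needs a duplicated input spec per occurrence.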
std::map<size_t, int64_t> LlgaKernel::initializeTensorIdToOccurence() const {
std::map<size_t, int64_t> tensorIdToOccurence;
for (auto& lt : partition_.get_in_ports()) {
auto inputId = lt.get_id();
std::map<size_t, int64_t>::iterator it(tensorIdToOccurence.find(inputId));
if (it != tensorIdToOccurence.end()) {
it->second++;
} else {
tensorIdToOccurence[inputId] = 1;
}
}
return tensorIdToOccurence;
}
ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
ArgSpecs inputSpecs;
inputSpecs.reserve(nPartitionInputs_);
GRAPH_DEBUG("Initializing graph input logical tensors");
std::map<size_t, int64_t> tensorIdToOccurence =
initializeTensorIdToOccurence();
for (size_t i = 0; i < nGraphInputs_; i++) {
auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]);
initializedInputIds_.insert(spec.tid());
int64_t occurence = tensorIdToOccurence[spec.tid()];
inputSpecs.insert(inputSpecs.end(), occurence, spec);
runArgsIdx_.insert(runArgsIdx_.end(), occurence, i);
}
GRAPH_DEBUG("Initializing constant input tensors");
initializeConstantInputs();
TORCH_CHECK(
inputSpecs.size() + constantValues_.size() == nPartitionInputs_,
"Partition inputs are missing");
GRAPH_DEBUG(
"Concatenating constant input logical tensors to graph input "
"logical tensors");
for (Value* constant_value : constantValues_) {
ArgSpec constantInputSpec(constant_value);
inputSpecs.emplace_back(constantInputSpec);
constantLogicalTensors_.emplace_back(constantInputSpec.logical_tensor());
}
return inputSpecs;
}
ArgSpecs LlgaKernel::initializeOutputSpecs() const {
ArgSpecs outputSpecs;
outputSpecs.reserve(nOutputs_);
for (size_t i = 0; i < nOutputs_; i++) {
auto spec = ArgSpec(graph_->outputs()[i]);
if (useOpaqueLayout(i)) {
spec = spec.any();
}
outputSpecs.emplace_back(spec);
}
return outputSpecs;
}
std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
const TensorArgs& inputs,
TensorArgs& outputs) const {
RunArgs runInputs, runOutputs;
auto numInputs = runArgsIdx_.size();
for (size_t i = 0; i < numInputs; i++) {
auto spec = inputSpecs_[i];
auto input = inputs[runArgsIdx_[i]];
runInputs.push_back(
{spec.logical_tensor(), Engine::getEngine(), input.data_ptr()});
}
auto numConstantInputs = constantInputs_.size();
for (size_t i = 0; i < numConstantInputs; i++) {
// constantInputSpecs are placed after graphInputSpecs
auto constantInputSpecIdx = nGraphInputs_ + i;
auto constantInputSpec = inputSpecs_[constantInputSpecIdx];
runInputs.push_back(
{constantLogicalTensors_[i],
Engine::getEngine(),
constantInputs_[i].data_ptr()});
}
for (size_t i = 0; i < nOutputs_; i++) {
auto spec = outputSpecs_[i];
auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_);
if (spec.reuses_input_tensor()) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("oneDNN Graph would perform inplace computation");
#endif
auto inputTensor = inputs[spec.get_input_tensor_index()];
auto dataType = spec.dtype();
if (C10_UNLIKELY(!useOpaqueLayout(i) && inputTensor.is_mkldnn())) {
// If the input tensor was between two partitions, it would've been
// wrapped with LlgaTensorImpl. But if it's being reused as the output
// tensor that is not between two partitions, then we'd have to re-wrap
// it as a regular TensorImpl, as it'd be fed into a PyTorch op.
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Rewrap tensor");
#endif
auto llgaImpl =
static_cast<LlgaTensorImpl*>(inputTensor.unsafeGetTensorImpl());
switch (dataType) {
case data_type::f32:
case data_type::bf16:
inputTensor = LlgaTensorImpl::llga_to_aten_tensor(llgaImpl);
break;
case data_type::s32:
default:
TORCH_CHECK(
false, "Invalid data type ", static_cast<size_t>(dataType));
}
}
outputs.push_back(inputTensor);
runOutputs.push_back(
{spec.logical_tensor(), Engine::getEngine(), inputTensor.data_ptr()});
} else if (useOpaqueLayout(i)) {
// Wrap tensors between partitions with LlgaTensorImpl wrapper, so that we
// can bypass the guard check, as the strides would be different from those
// expected.
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Between two oneDNN Graph partitions");
#endif
auto tensor = empty_llga(spec, opt);
outputs.push_back(tensor);
runOutputs.push_back(llga_from_aten_tensor(tensor));
} else {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Neither opaque to PyTorch nor inplace-computation");
#endif
auto tensor = at::empty_strided(spec.sizes(), spec.strides(), opt);
outputs.push_back(tensor);
runOutputs.push_back(
{spec.logical_tensor(), Engine::getEngine(), tensor.data_ptr()});
}
}
return std::make_tuple(runInputs, runOutputs);
}
compiled_partition LlgaKernel::compile(const partition& partition) {
auto inputs = fmap(inputSpecs_, toLogicalTensor);
auto outputs = fmap(outputSpecs_, toLogicalTensor);
auto compilation = partition.compile(inputs, outputs, Engine::getEngine());
// Since the layouts of opaque outputs are only known after compilation,
// we need to query them from the compiled partition and update outputSpecs_
for (size_t i = 0; i < nOutputs_; i++) {
auto tid = outputSpecs_[i].tid();
outputSpecs_[i] = compilation.query_logical_tensor(tid);
}
// Build a static mapping from output id to input offset
// in accordance with the available in-place options
for (auto&& option : compilation.get_inplace_ports()) {
size_t inputId = option.first;
size_t outputId = option.second;
auto inputSpecIter =
std::find_if(inputSpecs_.begin(), inputSpecs_.end(), [&](auto& spec) {
return spec.tid() == inputId;
});
TORCH_CHECK(inputSpecIter != inputSpecs_.end(), "In-place input not found");
auto inputOffset = inputSpecIter - inputSpecs_.begin();
auto outputSpecIter =
std::find_if(outputSpecs_.begin(), outputSpecs_.end(), [&](auto& spec) {
return spec.tid() == outputId;
});
auto outputOffset = outputSpecIter - outputSpecs_.begin();
outputSpecs_[outputOffset].set_compute_inplace();
outputSpecs_[outputOffset].set_input_tensor_index(inputOffset);
}
return compilation;
}
void LlgaKernel::run(Stack& stack) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("In ", debugName(), "\n");
#endif
// Grab input values from stack
auto stackInputs = last(stack, nGraphInputs_);
auto inputs = fmap(stackInputs, [&](const IValue& v) {
TORCH_CHECK(
v.isTensor(), "Stack values for LLGA partition must be Tensor type");
return v.toTensor();
});
// Even with concurrent threads, the kernel is initialized only once.
// TODO: Try not using an atomic lock
std::call_once(
initialized_flag,
[&](const TensorArgs& inputs) {
GRAPH_DEBUG("Initializing input logical tensors");
inputSpecs_ = initializeInputSpecs(inputs);
GRAPH_DEBUG("Initializing output logical tensors");
outputSpecs_ = initializeOutputSpecs();
GRAPH_DEBUG("Compiling partition");
compilation_ = compile(partition_);
is_initialized_ = true;
},
inputs);
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Preparing runtime tensors");
#endif
TensorArgs outputs;
RunArgs runInputs, runOutputs;
std::tie(runInputs, runOutputs) = prepareRunArgs(inputs, outputs);
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Executing partition");
#endif
compilation_.execute(Stream::getStream(), runInputs, runOutputs);
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Partition executed");
#endif
// Update the stack.
drop(stack, nGraphInputs_);
for (auto& o : outputs)
push_one(stack, std::move(o));
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Stack updated");
#endif
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,93 @@
#pragma once
#include <unordered_map>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
using ArgSpec = LlgaTensorDesc;
using ArgSpecs = std::vector<ArgSpec>;
using RunArg = dnnl::graph::tensor;
using RunArgs = std::vector<RunArg>;
using TensorArgs = std::vector<at::Tensor>;
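// Runtime kernel for a single prim::oneDNNFusionGroup node: on the first call
// it initializes the input/output logical tensor specs and compiles the
// underlying LLGA partition (guarded by std::call_once), and every call
// executes the cached compiled partition on the stack inputs.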
class LlgaKernel {
public:
explicit LlgaKernel(const Node* fusionNode);
void run(Stack& stack);
void initialize(const TensorArgs& inputs);
const std::string& debugName() const {
return debugName_;
}
private:
bool useOpaqueLayout(size_t offset) const;
// PyTorch copies constants into the subgraph instead of referencing them.
// Constant inputs to the partition are therefore no longer in graph->inputs().
// We need to use the tid retrieved from the partition to find the missing
// constant inputs.
void initializeConstantInputs();
ArgSpecs initializeInputSpecs(const TensorArgs& inputs);
ArgSpecs initializeOutputSpecs() const;
dnnl::graph::compiled_partition compile(
const dnnl::graph::partition& partition);
std::map<size_t, int64_t> initializeTensorIdToOccurence() const;
std::tuple<RunArgs, RunArgs> prepareRunArgs(
const TensorArgs& inputs,
TensorArgs& outputs) const;
static std::string genDebugName() {
static size_t debugId = 0;
return "LlgaPartition_" + std::to_string(debugId++);
}
static dnnl::graph::logical_tensor toLogicalTensor(const ArgSpec& s) {
return s.logical_tensor();
}
at::Device device_ = at::kCPU;
const Node* fusionNode_;
std::shared_ptr<Graph> graph_;
int64_t nGraphInputs_ = 0; // number of inputs to graph_ in the IR
int64_t nOutputs_ = 0;
std::map<size_t, Value*> tensorIdToValue_;
std::vector<int64_t> runArgsIdx_;
dnnl::graph::partition partition_;
// nPartitionInputs_ is the actual number of inputs to partition_ of graph_
// needed by the backend.
// nPartitionInputs_ = nGraphInputs_ + constantInputs_.size(), since Constant
// inputs are copied into the subgraph
int64_t nPartitionInputs_;
dnnl::graph::compiled_partition compilation_;
std::set<size_t> initializedInputIds_;
std::vector<Value*> constantValues_;
TensorArgs constantInputs_;
ArgSpecs inputSpecs_;
ArgSpecs outputSpecs_;
std::vector<dnnl::graph::logical_tensor> constantLogicalTensors_;
std::string debugName_;
std::once_flag initialized_flag;
bool is_initialized_ = false;
};
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,44 @@
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
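// If every use of a value produced by an LLGA subgraph is itself an LLGA
// subgraph, mark that output as opaque so the intermediate tensor can stay in
// the backend's preferred layout between the two partitions.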
void LayoutPropagation(Node* n) {
if (!LlgaGraphHelper::isLlgaSubgraph(n))
return;
for (auto input : n->inputs()) {
auto prev = input->node();
auto offset = input->offset();
if (LlgaGraphHelper::isLlgaSubgraph(prev)) {
bool useOpaqueLayout = true;
for (auto& use : input->uses()) {
if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) {
useOpaqueLayout = false;
break;
}
}
if (useOpaqueLayout) {
LlgaNodeWrapper(prev).setOpaqueLayout(offset);
}
}
}
}
void LayoutPropagation(at::ArrayRef<Block*> blocks) {
for (Block* block : blocks)
for (Node* node : block->nodes())
LayoutPropagation(node);
}
void PropagateLayout(const std::shared_ptr<Graph>& graph) {
LayoutPropagation(graph->block());
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,15 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void PropagateLayout(const std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,101 @@
#pragma once
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
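// Thin builder that maps a torch::jit::Node onto a dnnl::graph::op: it adds
// the node's inputs/outputs as logical tensors, sets op attributes, and uses
// the Node's address as the LLGA op id.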
class Operator {
public:
Operator(const Node* node, dnnl::graph::op::kind kind)
: n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {}
Operator& setInputValue(Value* v) {
if (v->mustNotBeNone())
o.add_input(createLogicalTensor(v));
return *this;
}
Operator& setInput(size_t offset) {
return setInputValue(n->input(offset));
}
template <typename... Ts>
Operator& setInput(size_t offset, Ts... other) {
setInput(offset);
return setInput(other...);
}
Operator& setOutputValue(Value* v) {
if (v->mustNotBeNone())
o.add_output(createLogicalTensor(v));
return *this;
}
Operator& setOutput(size_t offset) {
return setOutputValue(n->output(offset));
}
template <typename... Ts>
Operator& setOutput(size_t offset, Ts... other) {
setOutput(offset);
return setOutput(other...);
}
template <typename Attr>
Operator& setAttr(std::string name, Attr&& attr) {
o.set_attr(name, std::forward<Attr>(attr));
return *this;
}
template <typename F>
Operator& setAttr(std::string name, const F& fn, size_t offset) {
return setAttr(name, fn(n, offset));
}
static std::vector<int64_t> Ints(const Node* node, size_t offset) {
return toIValue(node->input(offset))->toIntVector();
}
static int64_t Int(const Node* node, size_t offset) {
return toIValue(node->input(offset))->toInt();
}
static float Float(const Node* node, size_t offset) {
return static_cast<float>(toIValue(node->input(offset))->toDouble());
}
static bool Bool(const Node* node, size_t offset) {
return toIValue(node->input(offset))->toBool();
}
static uint64_t getId(const Node* node) {
return reinterpret_cast<uint64_t>(node); // cast node address as op id
}
dnnl::graph::op::kind kind() const {
return k;
}
dnnl::graph::op llgaOp() const {
return o;
}
private:
dnnl::graph::logical_tensor createLogicalTensor(Value* value) const {
return LlgaTensorDesc(value).logical_tensor();
}
const Node* n;
dnnl::graph::op o;
dnnl::graph::op::kind k;
};
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,106 @@
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
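// Returns true if v is a constant int or double whose value equals d.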
bool compareConstValue(Value* v, double d) {
auto ival = toIValue(v);
return ival.has_value() &&
((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
(ival->isDouble() && ival->toDouble() == d));
}
void mayConvertScalarInputToTensor(Node* node) {
// We do not handle binary ops with two scalar inputs,
// and we assume the scalar is always the second input.
if (node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
(node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
node->input(1)->type()->isSubtypeOf(IntType::get()))) {
auto scalar = node->input(1);
WithInsertPoint guard(node);
auto g = node->owningGraph();
// 42 : Scalar --> tensor(42.0) : Float([])
auto t = g->insert(
aten::as_tensor, {scalar}, {{"dtype", at::ScalarType::Float}});
// add dim & stride info to IR
c10::optional<size_t> t_dim = 1;
auto target_type = TensorTypePtr(
TensorType::create(at::ScalarType::Float, at::kCPU, t_dim, false));
target_type = target_type->withSizes({1});
t->setType(target_type);
// tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1])
auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
unsqueezed->setType(target_type);
node->replaceInput(1, unsqueezed);
}
}
static void ConvertScalarToTensor(Block* block) {
for (auto node : block->nodes()) {
for (auto sub : node->blocks()) {
ConvertScalarToTensor(sub);
}
if (node->kind() == aten::add || node->kind() == aten::mul) {
mayConvertScalarInputToTensor(node);
}
}
}
void mayDecomposeAdd(Node* node) {
if (toIValue(node->namedInput("alpha")).has_value()) {
auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
if (!alphaEqualsOne) {
WithInsertPoint guard(node);
auto g = node->owningGraph();
auto mul = g->insert(
aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
node->replaceInput(1, mul);
auto one = g->insertConstant(1.0);
node->replaceInput(2, one);
}
}
}
static void DecomposeFusedAdd(Block* block) {
for (auto node : block->nodes()) {
for (auto sub : node->blocks()) {
DecomposeFusedAdd(sub);
}
if (node->kind() == aten::add) {
mayDecomposeAdd(node);
}
}
}
static void EliminateIdentityMulAdd(Block* block) {
for (auto node : block->nodes()) {
for (auto sub : node->blocks()) {
EliminateIdentityMulAdd(sub);
}
if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
(node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
node->output()->replaceAllUsesWith(node->namedInput("self"));
}
}
}
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
DecomposeFusedAdd(graph->block());
EliminateIdentityMulAdd(graph->block());
EliminateDeadCode(graph);
// ConvertScalarToTensor must be placed after EliminateIdentityMulAdd
ConvertScalarToTensor(graph->block());
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,26 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
// Prepare binary ops for LLGA
//
// The pass does the following:
//
// - Convert a scalar input of aten::add and aten::mul into a Float tensor
//   with dimension [1]
//
// - Decompose fused add into aten::mul + aten::add when alpha != 1.0
//
// - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1
//
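// As a rough sketch (illustrative IR names, not an exact dump), the fused-add
// decomposition turns
//   %r = aten::add(%x, %y, %alpha)   (with constant %alpha == 2)
// into
//   %m = aten::mul(%y, %alpha)
//   %r = aten::add(%x, %m, 1)
// and a remaining scalar operand such as 42 becomes a Float tensor of
// shape [1].
//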
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,54 @@
#include <torch/csrc/jit/runtime/profiling_record.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
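// Ops that the oneDNN Graph fuser may claim. Registering this predicate with
// the profiling infrastructure (RegisterProfilingNode below) is meant to keep
// profiled shape information available for these nodes.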
bool canFuseNode(const Node* node) {
switch (node->kind()) {
case aten::conv2d:
case aten::_convolution:
case aten::batch_norm:
case aten::layer_norm:
case aten::add:
case aten::mul:
case aten::tanh:
case aten::relu:
case aten::elu:
case aten::sigmoid:
case aten::gelu:
case aten::sqrt:
case aten::abs:
case aten::square:
case aten::hardtanh:
case aten::relu6:
case aten::softmax:
case aten::max_pool2d:
case aten::avg_pool2d:
case aten::matmul:
case aten::mm:
case aten::linear:
case aten::addmm:
return true;
default:
return false;
}
}
namespace {
class RegisterInterface {
public:
RegisterInterface() {
RegisterProfilingNode(canFuseNode);
}
};
static RegisterInterface register_interface_;
} // namespace
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -623,6 +623,7 @@ void AliasDb::analyzeImpl(Node* node) {
return analyzeLoop(node);
case prim::FusionGroup:
case prim::CudaFusionGroup:
case prim::oneDNNFusionGroup:
case prim::FunctionalGraph:
case prim::DifferentiableGraph:
case prim::FallbackGraph:

View File

@ -508,6 +508,7 @@ void Node::lint() const {
break;
case prim::FusionGroup:
case prim::CudaFusionGroup:
case prim::oneDNNFusionGroup:
checkSameDevice(this);
// TODO: Typecheck the parameters
g(attr::Subgraph)->lint();

View File

@ -21,7 +21,8 @@ bool canRunWithAutograd(Node* node) {
}
return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
kind != prim::TypeCheck && kind != prim::TensorExprGroup &&
kind != prim::CudaFusionGuard && (kind.is_aten() || kind.is_prim());
kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup &&
kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim());
}
namespace {

View File

@ -0,0 +1,64 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <ATen/Config.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
static std::atomic<bool> onednn_enabled{true};
static std::atomic<bool>& getLlgaEnabled() {
return onednn_enabled;
}
TORCH_API void fuseGraph(std::shared_ptr<Graph>& g);
} // namespace onednn
} // namespace fuser
struct C10_EXPORT RegisterLlgaFuseGraph
: public PassManager<RegisterLlgaFuseGraph> {
static bool setEnabled(bool enabled) {
TORCH_CHECK(
AT_MKLDNN_ENABLED(),
"Running oneDNN Graph fuser is only supported with MKLDNN builds.");
bool oldState = fuser::onednn::getLlgaEnabled();
fuser::onednn::getLlgaEnabled() = enabled;
if (enabled) {
registerPass(fuser::onednn::fuseGraph);
} else {
clearPass();
}
return oldState;
}
static bool isEnabled() {
return fuser::onednn::getLlgaEnabled();
}
// override PassManager::registerPass to register pre-pass
static bool registerPass(GraphPass p) {
if (!isRegistered()) {
passID(registerPrePass(std::move(p)), true);
isRegistered(true);
return false;
}
return true;
}
// override PassManager::clearPass to clear pre-pass
static void clearPass() {
if (isRegistered()) {
clearPrePass(passID());
isRegistered(true);
}
}
};
} // namespace jit
} // namespace torch

View File

@ -8,6 +8,9 @@
#include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
@ -630,6 +633,10 @@ void initJITBindings(PyObject* module) {
auto stack = toTraceableStack(args);
checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
})
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
.def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
.def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#endif
.def(
"_jit_set_nvfuser_skip_node_kind",
// Args:

View File

@ -246,6 +246,8 @@ bool printerHasSpecialCaseFor(Symbol sym) {
prim::StaticSubgraph, // optimization pass adds it
prim::ConstantMKLDNNTensor, // optimization pass adds it
prim::BroadcastMKLDNNTensors, // optimization pass adds it
prim::oneDNNFusionGroup, // optimization pass adds it
prim::oneDNNFusionGuard, // optimization pass adds it
prim::StaticRuntimeCopyOuts, // used in SR only
prim::Load, // used in interpreter only
prim::MMTreeReduce, // used as an optimization
@ -282,6 +284,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
prim::Loop,
prim::FusionGroup,
prim::CudaFusionGroup,
prim::oneDNNFusionGroup,
prim::DifferentiableGraph,
prim::TensorExprGroup,
prim::TensorExprDynamicGroup,

View File

@ -229,7 +229,19 @@ def _hide_source_ranges() -> Iterator[None]:
finally:
torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined]
# don't expose Any, TODO: define `__all__`
def enable_onednn_fusion(enabled: bool):
"""
Enables or disables onednn JIT fusion based on the parameter `enabled`.
"""
torch._C._jit_set_llga_enabled(enabled)
def onednn_fusion_enabled():
"""
Returns whether onednn JIT fusion is enabled
"""
return torch._C._jit_llga_enabled()
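# A minimal usage sketch (`model` and `example_input` are placeholders;
# requires an MKLDNN-enabled build):
#
#   torch.jit.enable_onednn_fusion(True)
#   assert torch.jit.onednn_fusion_enabled()
#   frozen = torch.jit.freeze(torch.jit.script(model).eval())
#   frozen(example_input)  # fusion groups form under the profiling executor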
del Any
if not torch._C._jit_init():