Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Reland take-2] Add JIT graph fuser for oneDNN Graph API (v0.5)
Re-landing #68111/#74596

## Description
v0.5 PR of this [RFC](https://github.com/pytorch/pytorch/issues/49444). On the basis of #50256, the improvements below are included:

* The [v0.5 release branch](https://github.com/oneapi-src/oneDNN/releases/tag/graph-v0.5) of the oneDNN Graph API is used.
* The fuser now works with the profiling graph executor. We have inserted type-check nodes to guard the profiled tensor properties.

### User API
The optimization pass is disabled by default. Users can enable it with:
```
torch.jit.enable_onednn_fusion(True)
```
`torch.jit.freeze` should be used after tracing (recommended) or scripting a model.

### Performance
The [pytorch/benchmark](https://github.com/pytorch/benchmark) tool is used to compare performance:
* SkyLake 8180 (1 socket of 28 cores)
* SkyLake 8180 (single thread)
  * By mapping hardswish to oneDNN Graph, it's 8% faster than PyTorch JIT (NNC + OFI)
  * We expect a performance gain after mapping transpose, contiguous & view to oneDNN Graph ops

### Directory structure of the integration code
Fuser-related code is placed under:
```
torch/csrc/jit/codegen/onednn/
```
Optimization pass registration is done in:
```
torch/csrc/jit/passes/onednn_graph_fuser.h
```
CMake for the integration code is in:
```
caffe2/CMakeLists.txt
cmake/public/mkldnn.cmake
cmake/Modules/FindMKLDNN.cmake
```

## Limitations
* In this PR, we only support the PyTorch-oneDNN-Graph integration on the Linux platform. Support on Windows and macOS will be enabled as a next step.
* We have only optimized the inference use case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76622
Approved by: https://github.com/eellison
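For reference, a minimal sketch of the user-facing flow described above (the module and input shape are illustrative, not from this PR; the `prim::oneDNNFusionGroup` check mirrors what the new test suite asserts):

```python
import torch
import torch.nn as nn

# Enable the oneDNN Graph fusion pass (disabled by default)
torch.jit.enable_onednn_fusion(True)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()  # illustrative model
x = torch.rand(1, 3, 32, 32)  # illustrative input

with torch.no_grad():
    traced = torch.jit.trace(model, x)
    traced = torch.jit.freeze(traced)  # freeze after tracing, as recommended
    traced(x)
    traced(x)  # warm up the profiling executor so the fusion pass runs
    # Fused regions appear as prim::oneDNNFusionGroup nodes
    print(traced.graph_for(x))
```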
Committed by PyTorch MergeBot
parent 94fc92f288, commit 4ee29d6033
@@ -682,6 +682,8 @@ if(USE_FBGEMM AND ((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VO
set(USE_FBGEMM OFF)
endif()

set(BUILD_ONEDNN_GRAPH OFF)

include(cmake/Dependencies.cmake)

if(USE_CUDA AND (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 10.2) AND (CMAKE_HOST_SYSTEM_NAME MATCHES "Windows"))
@@ -43,6 +43,8 @@ namespace c10 {
_(prim, FusionGroup) \
_(prim, CudaFusionGroup) \
_(prim, CudaFusionGuard) \
_(prim, oneDNNFusionGroup) \
_(prim, oneDNNFusionGuard) \
_(prim, FunctionalGraph) \
_(prim, add_optional) \
_(prim, view_copy) \
@@ -319,6 +321,7 @@ namespace c10 {
_(attr, cache_id) \
_(attr, new_axis) \
_(attr, warn_id) \
_(attr, output_layouts) \
_(attr, allowzero) \
_(attr, seen_none) \
_(attr, overload_name)
@@ -657,6 +657,26 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1")
endif()

if(USE_MLCOMPUTE)
  include(../mlc/mlc_build.cmake)
endif()

if(BUILD_ONEDNN_GRAPH)
  list(APPEND Caffe2_CPU_SRCS
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_fuser.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_rewriter.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_helper.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/register_interface.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/interface.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/kernel.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/defer_size_check.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/layout_propagation.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/prepare_binary.cpp
    ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/guard_shape.cpp
  )
endif()

if(USE_ROCM)
  list(APPEND Caffe2_HIP_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS})
  if(USE_NCCL)
@@ -12,86 +12,118 @@
# MKLDNN_USE_NATIVE_ARCH : Whether native CPU instructions should be used in MKLDNN. This should be turned off for
# general packaging to avoid incompatible CPU instructions. Default: OFF.

IF (NOT MKLDNN_FOUND)
IF(NOT MKLDNN_FOUND)
SET(MKLDNN_LIBRARIES)
SET(MKLDNN_INCLUDE_DIR)

SET(MKLDNN_LIBRARIES)
SET(MKLDNN_INCLUDE_DIR)
SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/third_party/oneDNN")
IF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER)
MESSAGE("-- Will build oneDNN Graph")
SET(LLGA_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn")
SET(BUILD_ONEDNN_GRAPH ON)
ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER)

SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
SET(MKLDNN_ROOT "${IDEEP_ROOT}/mkl-dnn/third_party/oneDNN")
FIND_PACKAGE(BLAS)
FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
IF(NOT MKLDNN_INCLUDE_DIR)
EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init --jobs 0 mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
ENDIF(NOT MKLDNN_INCLUDE_DIR)
IF(BUILD_ONEDNN_GRAPH)
FIND_PATH(LLGA_INCLUDE_DIR oneapi/dnnl/dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include)
ENDIF(BUILD_ONEDNN_GRAPH)

FIND_PACKAGE(BLAS)
FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
IF (NOT MKLDNN_INCLUDE_DIR)
EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init --jobs 0 mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
FIND_PATH(MKLDNN_INCLUDE_DIR mkldnn.hpp mkldnn.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
ENDIF(NOT MKLDNN_INCLUDE_DIR)
IF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
MESSAGE(STATUS "MKLDNN source files not found!")
RETURN()
ENDIF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
LIST(APPEND MKLDNN_INCLUDE_DIR ${IDEEP_INCLUDE_DIR})
IF(BUILD_ONEDNN_GRAPH)
LIST(APPEND MKLDNN_INCLUDE_DIR ${LLGA_INCLUDE_DIR})
ENDIF(BUILD_ONEDNN_GRAPH)
IF(MKL_FOUND)
ADD_DEFINITIONS(-DIDEEP_USE_MKL)
# Append to mkldnn dependencies
LIST(APPEND MKLDNN_LIBRARIES ${MKL_LIBRARIES})
LIST(APPEND MKLDNN_INCLUDE_DIR ${MKL_INCLUDE_DIR})
ELSE(MKL_FOUND)
SET(MKLDNN_USE_MKL "NONE" CACHE STRING "" FORCE)
ENDIF(MKL_FOUND)

IF (NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
MESSAGE(STATUS "MKLDNN source files not found!")
RETURN()
ENDIF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
LIST(APPEND MKLDNN_INCLUDE_DIR ${IDEEP_INCLUDE_DIR})
IF(MKL_FOUND)
ADD_DEFINITIONS(-DIDEEP_USE_MKL)
# Append to mkldnn dependencies
LIST(APPEND MKLDNN_LIBRARIES ${MKL_LIBRARIES})
LIST(APPEND MKLDNN_INCLUDE_DIR ${MKL_INCLUDE_DIR})
ELSE(MKL_FOUND)
SET(MKLDNN_USE_MKL "NONE" CACHE STRING "" FORCE)
ENDIF(MKL_FOUND)
SET(MKL_cmake_included TRUE)
IF(NOT MKLDNN_CPU_RUNTIME)
SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "")
ELSEIF(MKLDNN_CPU_RUNTIME STREQUAL "TBB")
IF(USE_TBB)
MESSAGE(STATUS "MKL-DNN is using TBB")

SET(MKL_cmake_included TRUE)
IF (NOT MKLDNN_CPU_RUNTIME)
SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "")
ELSEIF (MKLDNN_CPU_RUNTIME STREQUAL "TBB")
IF (USE_TBB)
MESSAGE(STATUS "MKL-DNN is using TBB")
SET(TBB_cmake_included TRUE)
SET(Threading_cmake_included TRUE)

SET(TBB_cmake_included TRUE)
SET(Threading_cmake_included TRUE)

SET(DNNL_CPU_THREADING_RUNTIME ${MKLDNN_CPU_RUNTIME})
INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIR})
LIST(APPEND EXTRA_SHARED_LIBS TBB::tbb)
ELSE()
MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used")
ENDIF()
ENDIF()
MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}")

SET(MKLDNN_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE)
SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "" FORCE)
SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE)
SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set.
SET(DNNL_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE)
ELSE()
IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
IF(CPU_INTEL)
SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE)
SET(DNNL_CPU_THREADING_RUNTIME ${MKLDNN_CPU_RUNTIME})
INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIR})
LIST(APPEND EXTRA_SHARED_LIBS TBB::tbb)
ELSE()
MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used")
ENDIF()
ELSE()
SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE)
ENDIF()
ENDIF()
MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}")

ADD_SUBDIRECTORY(${MKLDNN_ROOT})
IF(NOT TARGET dnnl)
MESSAGE("Failed to include MKL-DNN target")
RETURN()
ENDIF(NOT TARGET dnnl)
SET(MKLDNN_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE)
SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "" FORCE)
SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE)
SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
IF(BUILD_ONEDNN_GRAPH)
SET(DNNL_GRAPH_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
ENDIF(BUILD_ONEDNN_GRAPH)
IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set.
SET(DNNL_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE)
ELSE()
IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
IF(CPU_INTEL)
SET(DNNL_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE)
ENDIF()
ELSE()
SET(DNNL_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE)
ENDIF()
ENDIF()

IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow)
ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
LIST(APPEND MKLDNN_LIBRARIES dnnl)
IF(BUILD_ONEDNN_GRAPH)
ADD_SUBDIRECTORY(${LLGA_ROOT})
IF(NOT TARGET dnnl_graph)
MESSAGE("Failed to include LLGA target")
RETURN()
ENDIF(NOT TARGET dnnl_graph)

SET(MKLDNN_FOUND TRUE)
MESSAGE(STATUS "Found MKL-DNN: TRUE")
IF(CMAKE_COMPILER_IS_GNUCC)
TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-maybe-uninitialized)
TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-strict-overflow)
TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-error=strict-overflow)
ENDIF(CMAKE_COMPILER_IS_GNUCC)
ELSE(BUILD_ONEDNN_GRAPH)
ADD_SUBDIRECTORY(${MKLDNN_ROOT})
ENDIF(BUILD_ONEDNN_GRAPH)

IF(NOT TARGET dnnl)
MESSAGE("Failed to include MKL-DNN target")
RETURN()
ENDIF(NOT TARGET dnnl)

IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow)
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow)
ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
LIST(APPEND MKLDNN_LIBRARIES ${MKL_OPENMP_LIBRARY})
IF(BUILD_ONEDNN_GRAPH)
LIST(APPEND MKLDNN_LIBRARIES "$<TARGET_FILE:dnnl_graph>")
ENDIF(BUILD_ONEDNN_GRAPH)
LIST(APPEND MKLDNN_LIBRARIES dnnl)

SET(MKLDNN_FOUND TRUE)
MESSAGE(STATUS "Found MKL-DNN: TRUE")

ENDIF(NOT MKLDNN_FOUND)
@@ -16,3 +16,15 @@ set_property(
set_property(
  TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES
  ${MKLDNN_LIBRARIES})
if(BUILD_ONEDNN_GRAPH)
  if(NOT TARGET caffe2::dnnl_graph)
    add_library(caffe2::dnnl_graph INTERFACE IMPORTED)
  endif()

  set_property(
    TARGET caffe2::dnnl_graph PROPERTY INTERFACE_INCLUDE_DIRECTORIES
    ${MKLDNN_INCLUDE_DIR})
  set_property(
    TARGET caffe2::dnnl_graph PROPERTY INTERFACE_LINK_LIBRARIES
    ${MKLDNN_LIBRARIES})
endif()
@@ -61,6 +61,8 @@ Creating TorchScript Code
    ScriptFunction
    freeze
    optimize_for_inference
    enable_onednn_fusion
    onednn_fusion_enabled
    set_fusion_strategy
    strict_fusion
    save
test/test_jit_llga_fuser.py (new file, 519 lines)
@@ -0,0 +1,519 @@
# Owner(s): ["module: mkldnn"]
import torch
import unittest
import itertools

import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS

LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup'
LLGA_NOT_ENABLED = not torch._C.has_mkldnn or IS_WINDOWS or IS_MACOS


def warmup_forward(f, *args, profiling_count=2):
    for i in range(profiling_count):
        results = f(*args)

    return results


class JitLlgaTestCase(JitTestCase):
    def setUp(self):
        torch.jit.enable_onednn_fusion(True)

    def tearDown(self):
        torch.jit.enable_onednn_fusion(False)

    def checkTrace(self, m, x, *args, **kwargs):
        if isinstance(m, torch.nn.Module):
            m.eval()
        with torch.no_grad(), \
                torch._jit_internal._disable_emit_hooks():
            traced = torch.jit.trace(m, x)
            if isinstance(m, torch.nn.Module):
                traced = torch.jit.freeze(traced)
            warmup_forward(traced, *x)
            fwd_graph = traced.graph_for(*x)

            ref_o = m(*x)
            jit_o = traced(*x)
            self.assertEqual(jit_o, ref_o)
            return traced, fwd_graph

    def assertFused(self, graph, fused_patterns):
        for pat in fused_patterns:
            self.assertGraphContainsExactly(graph, pat, 0)


try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision')

def get_eltwise_fn(name):
    if hasattr(torch, name):
        return getattr(torch, name)
    elif hasattr(F, name):
        return getattr(F, name)
    else:
        raise NameError('Eltwise function %s not found' % name)


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestOp(JitLlgaTestCase):
    def test_conv2d(self):
        for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product(
                [7, 8],
                [8, 15],
                [7, 16],
                [3, 4],
                [0, 2],
                [1, 2],
                [1, 2],
                [1, 2],
                [True, False]):

            m = nn.Conv2d(in_channels=in_channels * g,
                          out_channels=out_channels * g,
                          kernel_size=kernel,
                          padding=padding,
                          stride=stride,
                          dilation=dilation,
                          groups=g,
                          bias=bias)

            x = torch.rand(1, in_channels * g, spatial, spatial)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_bn2d(self):
        m = nn.BatchNorm2d(32).eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        # single-op partition shouldn't be created for batch_norm
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.eltwise = eltwise_fn

            def forward(self, x):
                return self.eltwise(x)

        for eltwise in ['relu', 'gelu']:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn)
            x = torch.rand(1, 32, 28, 28)
            _, graph = self.checkTrace(m, [x])
            # single-op partition shouldn't be created.
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_max_pool2d(self):
        for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product(
                [15, 16, 17, 18, 19],
                [4, 5],
                [0, 1, 2],
                [1, 2],  # [1, 2, 4], TODO: fix issue in pad calculation
                [1],  # [1, 2], TODO: backend support for dilation
                [True, False]):

            m = nn.MaxPool2d(kernel_size=kernel,
                             stride=stride,
                             padding=padding,
                             dilation=dilation,
                             ceil_mode=ceil_mode)

            x = torch.rand(1, 4, spatial, spatial)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_avg_pool2d(self):
        for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product(
                [15, 16, 17, 18, 19],
                [4, 5],
                [0, 1, 2],
                [1, 2, 4],
                [False],  # TODO: oneDNN Graph does not fully support ceil_mode=True
                [True, False]):

            m = nn.AvgPool2d(kernel_size=kernel,
                             stride=stride,
                             padding=padding,
                             ceil_mode=ceil_mode,
                             count_include_pad=count_include_pad)

            x = torch.rand(1, 4, spatial, spatial)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_variable_kernel_avg_pool2d(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False)
                return x

        x = torch.randn(1, 1000, 1, 1)
        m = M()
        _, graph = self.checkTrace(m, [x])
        # kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP
        # TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_softmax(self):
        for dim in [-4, -3, -2, -1, 0, 1, 2, 3]:
            m = nn.Softmax(dim=dim)
            x = torch.rand(8, 12, 12, 12)
            _, graph = self.checkTrace(m, [x])
            # single-op partition shouldn't be created for softmax
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_linear(self):
        for bias in [True, False]:
            x = torch.rand(32, 28)
            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::linear'])

    def _gen_binary_inputs(self, gen_permute=True):
        for xshape, yshape in [
            [[1, 32, 28, 28], [1, 32, 28, 28]],
            [[1, 32, 28, 28], [1, 1, 28, 28]],
            [[1, 32, 28, 28], [28]],
            [[1, 32, 28, 28], [1]],

        ]:
            yield torch.rand(xshape), torch.rand(yshape)
            if gen_permute and xshape != yshape:
                yield torch.rand(yshape), torch.rand(xshape)

    def test_add(self):
        def forward_add(x, y):
            return torch.add(x, y, alpha=2)

        for x, y in self._gen_binary_inputs():
            _, graph = self.checkTrace(forward_add, [x, y])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_add_scalar(self):
        def add_scalar(x):
            return 42 + x + 3.14

        x = torch.rand(32, 32)
        _, graph = self.checkTrace(add_scalar, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_addmm(self):
        def addmm(x, y, z):
            # alpha and beta are 1, by default
            return torch.addmm(z, x, y)

        x = torch.rand(64, 32)
        y = torch.rand(32, 32)
        z = torch.rand(64, 32)
        _, graph = self.checkTrace(addmm, [x, y, z])
        # single-op partition should be created for matmul with bias.
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_mul(self):
        def forward_mul(x, y):
            return torch.mul(x, y) * 3

        for x, y in self._gen_binary_inputs():
            _, graph = self.checkTrace(forward_mul, [x, y])
            # single-op partitions shouldn't be created
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_identity_binary(self):
        def forward(x):
            return x * 1 + 0.0

        x = torch.rand(32)
        _, graph = self.checkTrace(forward, [x])
        self.assertFused(graph, ['aten::add', 'aten::mul'])

    def test_layer_norm(self):
        # TODO: support more normalized_shape
        m = torch.nn.LayerNorm(10)
        x = torch.randn(2, 5, 10, 10)
        _, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_cat(self):
        def cat_along_dim(d):
            def forward_cat(*inputs):
                return torch.cat(inputs, d)
            return forward_cat

        for xshape in [
            [8, 8, 8, 8],
            [64, 8, 32],
            [2048, 64],
        ]:
            for d in range(len(xshape)):
                x = torch.rand(xshape)
                _, graph = self.checkTrace(cat_along_dim(d), [x, x, x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_typecheck(self):
        x = torch.rand(32, 28)
        m = torch.nn.Linear(in_features=28, out_features=64, bias=True)
        traced, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::linear'])
        # change the shape of the input, we should enter fallback graph
        x = torch.rand(5, 28)
        self.assertEqual(m(x), traced(x))


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestFusionPattern(JitLlgaTestCase):
    def test_conv2d_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                return x

        # for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']:
        for eltwise in ['relu']:
            for inplace in [True, False]:
                eltwise_fn_name = eltwise + '_' if inplace else eltwise
                eltwise_fn = get_eltwise_fn(eltwise_fn_name)

                m = M(eltwise_fn)
                x = torch.rand(1, 32, 28, 28)
                _, graph = self.checkTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                # test if relu_ is replaced with relu by the mutation removal pass
                self.assertFused(graph, ['aten::' + eltwise_fn_name])
                # test if relu is fused into the fusion group
                self.assertFused(graph, ['aten::' + eltwise])

    def test_conv2d_bn(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                return x

        m = M().eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm'])

    def test_conv2d_bn_relu(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = F.relu(x)
                return x

        m = M().eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
                                 'aten::relu'])

    def test_bn2d_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.eltwise = eltwise_fn
                self.bn = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.bn(x)
                x = self.eltwise(x)
                return x

        for eltwise in ['relu']:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn).eval()
            x = torch.rand(1, 32, 28, 28)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::' + eltwise])

    def test_linear_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        for [has_bias, eltwise] in itertools.product(
                [True, False],
                ['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']):

            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn, has_bias)
            x = torch.rand(32, 28, requires_grad=False)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::' + eltwise])

    def test_conv2d_sum(self):
        class M(nn.Module):
            def __init__(self, bias=False):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn1 = nn.BatchNorm2d(32)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn2 = nn.BatchNorm2d(32)
                self.relu = nn.ReLU()
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn3 = nn.BatchNorm2d(32)

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.bn1(x)
                y = self.conv2(y)
                y = self.bn2(y)
                z = self.relu(x + y)
                z = self.conv3(z)
                z = self.bn3(z)
                return z

        for bias in [True, False]:
            m = M(bias).eval()
            x = torch.rand(1, 32, 16, 16, requires_grad=False)
            y = torch.rand(1, 32, 16, 16, requires_grad=False)
            _, graph = self.checkTrace(m, [x, y])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)

    def test_wildcard(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                y = self.eltwise(x)
                return [x, y]

        # The pattern is as the following:
        #      conv
        #     |    \
        # eltwise   \
        #    |       \
        #  ListConstruct
        #
        # The output of conv is used by a wildcard op: ListConstruct.
        # Thus conv-eltwise cannot be selected into the same Partition.
        m = M()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        # conv can exist in a single-op oneDNN Graph partition but not relu
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution'])

    def test_rewrap_tensor_input_to_pytorch(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, data_type):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
                self.eltwise = eltwise_fn
                self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                x = torch.add(x, y)
                x = self.adaptive_avg_pool_2d(x)
                return x

        eltwise_fn_name = 'relu'
        eltwise_fn = get_eltwise_fn(eltwise_fn_name)
        # Add bfloat16 later
        for data_type in [torch.float]:
            m = M(eltwise_fn, data_type)
            m = m.to(memory_format=torch.channels_last)
            x = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
            y = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
            # Simply test if the output is accurate.
            # The output of the second partition is input to adaptive_avg_pool2d, which is
            # unsupported by LLGA, so it must be handled by PyTorch, which should receive
            # correct strides info of the channels-last tensor.
            graph, _ = self.checkTrace(m, [x, y])


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestModel(JitLlgaTestCase):
    @skipIfNoTorchVision
    def _test_vision(self, model_name):
        m = getattr(torchvision.models, model_name)().eval()
        x = torch.rand(1, 3, 224, 224) / 10
        _, graph = self.checkTrace(m, [x])
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
                                 'aten::relu', 'aten::linear',
                                 'aten::avg_pool2d', 'aten::max_pool2d'])


for model_name, enabled in [
    ['resnet50', True],
    ['resnext50_32x4d', True],
    ['resnext101_32x8d', True],
    ['densenet121', True],
    ['googlenet', TEST_SCIPY],
    ['mobilenet_v2', True],
    ['mnasnet1_0', True],
    ['squeezenet1_0', True],
    ['vgg16', True],
    ['alexnet', True],
    ['shufflenet_v2_x1_0', True],
    ['wide_resnet50_2', True],
]:
    def wrapper(mname):
        @unittest.skipIf(not enabled, 'Disabled')
        def test(self):
            return self._test_vision(mname)
        return test

    setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))

if __name__ == '__main__':
    run_tests()
@@ -14,7 +14,7 @@ if(NOT BUILD_PYTHON)
endif()

if(USE_TBB)
include_directories(${TBB_INCLUDE_DIR})
include_directories(${TBB_INCLUDE_DIR})
endif()

set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
@@ -423,6 +423,10 @@ target_compile_options(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})

target_include_directories(torch_python PUBLIC ${TORCH_PYTHON_INCLUDE_DIRECTORIES})

if(BUILD_ONEDNN_GRAPH)
  target_compile_definitions(torch_python PRIVATE "-DBUILD_ONEDNN_GRAPH")
  target_compile_definitions(torch_cpu PRIVATE "-DBUILD_ONEDNN_GRAPH")
endif()

if(NOT TORCH_PYTHON_LINK_FLAGS STREQUAL "")
  set_target_properties(torch_python PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})
@@ -228,6 +228,8 @@ def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
def _jit_texpr_fuser_enabled() -> _bool: ...
def _jit_nvfuser_enabled() -> _bool: ...
def _jit_llga_enabled() -> _bool: ...
def _jit_set_llga_enabled(enable: _bool): ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp (new file, 132 lines)
@@ -0,0 +1,132 @@
#include <ATen/Config.h>

#if AT_MKLDNN_ENABLED()
#include <c10/core/CPUAllocator.h>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

dnnl::graph::engine& Engine::getEngine() {
  static dnnl::graph::engine cpu_engine(
      dnnl::graph::engine::kind::cpu, /* device_id = */ 0);
  return cpu_engine;
}

dnnl::graph::stream& Stream::getStream() {
  static dnnl::graph::stream cpu_stream{Engine::getEngine(), nullptr};
  return cpu_stream;
}

LlgaTensorImpl::LlgaTensorImpl(
    at::Storage&& storage,
    const caffe2::TypeMeta& data_type,
    const LlgaTensorDesc& desc)
    : at::TensorImpl(
          std::move(storage),
          c10::DispatchKeySet(c10::DispatchKey::MkldnnCPU),
          data_type),
      desc_(desc) {
  set_sizes_and_strides(desc.sizes(), desc.strides());
  refresh_numel();
}

at::Tensor LlgaTensorImpl::llga_to_aten_tensor(LlgaTensorImpl* llgaImpl) {
  auto aten_tensor = at::detail::make_tensor<TensorImpl>(
      std::move(llgaImpl->storage_),
      c10::DispatchKeySet(c10::DispatchKey::CPU),
      llgaImpl->data_type_);
  auto impl = aten_tensor.unsafeGetTensorImpl();
  impl->set_storage_offset(llgaImpl->storage_offset_);
  impl->set_sizes_and_strides(llgaImpl->sizes(), llgaImpl->strides());
  return aten_tensor;
}

at::Tensor empty_llga(
    const LlgaTensorDesc& desc,
    const c10::TensorOptions& options) {
  auto nbytes = desc.storage_size();

  auto allocator = at::GetCPUAllocator();
  auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      nbytes,
      allocator->allocate(nbytes),
      allocator,
      /*resizable=*/false);

  return at::detail::make_tensor<LlgaTensorImpl>(
      std::move(storage_impl), options.dtype(), desc);
}

const LlgaTensorDesc& get_llga_desc(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      tensor.is_mkldnn(), "get_llga_desc expects Mkldnn tensor input");
  return static_cast<LlgaTensorImpl*>(tensor.unsafeGetTensorImpl())->desc();
}

dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor) {
  return {
      get_llga_desc(tensor).logical_tensor(),
      torch::jit::fuser::onednn::Engine::getEngine(),
      tensor.data_ptr()};
}

using data_type = dnnl::graph::logical_tensor::data_type;

data_type getLlgaDataType(at::ScalarType dt) {
  switch (dt) {
    case at::ScalarType::Float:
      return data_type::f32;
    case at::ScalarType::BFloat16:
      return data_type::bf16;
    case at::kInt:
      return data_type::s32;
    case at::ScalarType::QInt8:
      return data_type::s8;
    case at::ScalarType::QUInt8:
      return data_type::u8;
    default:
      TORCH_CHECK(false, "Unsupported data type ", dt);
  }
}

LlgaTensorDesc LlgaTensorDesc::supplementTensorInfo(const at::Tensor& t) const {
  if (t.is_mkldnn()) {
    // if input tensor is of mkldnn, it's originated from an upstream
    // LLGA partition which carries opaque layout info
    return get_llga_desc(t).tid(tid_);
  } else {
    // if input tensor is not an mkldnn tensor, use default layout
    auto sizes = t.sizes().vec();
    auto strides = t.strides().vec();
    auto dtype = getLlgaDataType(t.scalar_type());
    return {tid_, sizes, strides, dtype, property_type_};
  }
}

at::ScalarType LlgaTensorDesc::aten_scalar_type() const {
  switch (dtype_) {
    case data_type::f32:
      return at::ScalarType::Float;
    case data_type::bf16:
      return at::ScalarType::BFloat16;
    case data_type::s32:
      return at::kInt;
    case data_type::s8:
      return at::ScalarType::QInt8;
    case data_type::u8:
      return at::ScalarType::QUInt8;
    default:
      TORCH_CHECK(false, "Invalid data type ", static_cast<size_t>(dtype_));
  }
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

#endif // AT_MKLDNN_ENABLED()
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h (new file, 273 lines)
@@ -0,0 +1,273 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Config.h>

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// Engine represents a device and its context. From the device kind, the engine
// knows how to generate code for the target device and what kind of device
// object to expect. The device id ensures that there is a unique engine
// being created for each device. The device handle passed from PyTorch allows
// the oneDNN Graph implementation to work on the device specified by PyTorch,
// which is currently CPU, so we only have one engine.
// Ref: https://spec.oneapi.io/onednn-graph/latest/programming_model.html#engine
struct Engine {
  // CPU engine singleton
  static dnnl::graph::engine& getEngine();
  Engine(const Engine&) = delete;
  void operator=(const Engine&) = delete;
};

// Stream is the logical abstraction for execution units. It is created on top
// of the oneDNN Graph engine. A compiled oneDNN Graph partition is submitted to
// a stream for execution.
struct Stream {
  // CPU stream singleton
  static dnnl::graph::stream& getStream();
  Stream(const Stream&) = delete;
  void operator=(const Stream&) = delete;
};

struct LlgaTensorDesc {
  using desc = dnnl::graph::logical_tensor;

  LlgaTensorDesc(
      size_t tid,
      std::vector<int64_t> sizes,
      std::vector<int64_t> strides,
      desc::data_type dtype,
      desc::property_type property_type)
      : tid_(tid),
        sizes_(sizes),
        strides_(strides),
        dtype_(dtype),
        property_type_(property_type),
        layout_type_(desc::layout_type::strided),
        layout_id_(-1) {}

  LlgaTensorDesc(const desc& t)
      : tid_(t.get_id()),
        sizes_(t.get_dims()),
        strides_({-1}),
        dtype_(t.get_data_type()),
        property_type_(t.get_property_type()),
        layout_type_(t.get_layout_type()),
        layout_id_(-1) {
    if (is_opaque()) {
      layout_id_ = t.get_layout_id();
    }
    if (is_strided()) {
      strides_ = t.get_strides();
    }
  }

  // TODO: llga needs to set input/output type constraints, but it seems that
  // we cannot get the dtype at compile time; hard-coded to fp32 for now to be
  // able to add_op
  LlgaTensorDesc(const torch::jit::Value* v)
      : LlgaTensorDesc(
            v->unique(),
            {},
            {},
            desc::data_type::f32,
            get_property_type(v)) {
    if (v->type()->isSubtypeOf(TensorType::get())) {
      auto tt = v->type()->cast<TensorType>();

      auto sizes = tt->sizes();
      if (sizes.sizes()) {
        for (auto d : *sizes.sizes()) {
          sizes_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM));
        }
      }

      auto strides = tt->strides();
      if (strides.sizes()) {
        for (auto d : *strides.sizes()) {
          strides_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM));
        }
      }
    }
  }

  LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const;

  at::ScalarType aten_scalar_type() const;

  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }

  const std::vector<int64_t>& strides() const {
    TORCH_CHECK(!is_opaque(), "Cannot get strides on opaque layout");
    return strides_;
  }

  size_t tid() const {
    return tid_;
  }

  LlgaTensorDesc tid(uint64_t new_id) const {
    auto ret = *this;
    ret.tid_ = new_id;
    return ret;
  }

  desc::data_type dtype() const {
    return dtype_;
  }

  LlgaTensorDesc dtype(desc::data_type new_dtype) const {
    return LlgaTensorDesc(tid_, sizes_, strides_, new_dtype, property_type_);
  }

  desc::layout_type layout_type() const {
    return layout_type_;
  }

  LlgaTensorDesc layout_type(desc::layout_type new_layout_type) {
    auto ret = *this;
    ret.layout_type_ = new_layout_type;
    return ret;
  }

  desc::property_type get_property_type(const torch::jit::Value* v) {
    switch (v->node()->kind()) {
      case prim::Constant:
        return desc::property_type::constant;
      default:
        return desc::property_type::variable;
    }
  }

  LlgaTensorDesc any() {
    return layout_type(desc::layout_type::any);
  }

  size_t storage_size() const {
    return logical_tensor().get_mem_size();
  }

  desc logical_tensor() const {
    if (is_dimensionality_unknown()) {
      return desc(
          tid_, dtype_, DNNL_GRAPH_UNKNOWN_NDIMS, layout_type_, property_type_);
    } else if (is_opaque()) {
      return desc(tid_, dtype_, sizes_, layout_id_, property_type_);
    } else if (is_any()) {
      return desc(tid_, dtype_, sizes_, layout_type_, property_type_);
    } else {
      return desc(tid_, dtype_, sizes_, strides_, property_type_);
    }
  }

  bool is_strided() const {
    return layout_type_ == desc::layout_type::strided;
  }

  bool is_any() const {
    return layout_type_ == desc::layout_type::any;
  }

  bool is_opaque() const {
    return layout_type_ == desc::layout_type::opaque;
  }

  bool operator==(const LlgaTensorDesc& desc) const {
    return tid_ == desc.tid_ && sizes_ == desc.sizes_ &&
        dtype_ == desc.dtype_ && layout_type_ == desc.layout_type_ &&
        ((is_opaque() && layout_id_ == desc.layout_id_) ||
         strides_ == desc.strides_);
  }

  bool operator!=(const LlgaTensorDesc& desc) const {
    return (tid_ != desc.tid_) || (sizes_ != desc.sizes_) ||
        (dtype_ != desc.dtype_) || (layout_type_ != desc.layout_type_) ||
        !((is_opaque() && (layout_id_ == desc.layout_id_)) ||
          (strides_ == desc.strides_));
  }

  static size_t hash(const LlgaTensorDesc& desc) {
    return c10::get_hash(
        desc.tid_,
        desc.sizes_,
        desc.dtype_,
        desc.layout_type_,
        desc.layout_id_);
  }

  void set_compute_inplace() {
    compute_inplace_ = true;
  }

  void set_input_tensor_index(size_t index) {
    input_tensor_index_ = index;
  }

  bool reuses_input_tensor() {
    return compute_inplace_;
  }

  size_t get_input_tensor_index() {
    return input_tensor_index_;
  }

 private:
  bool is_dimensionality_unknown() const {
    return sizes_.size() == 0;
  }

  size_t tid_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;
  desc::data_type dtype_;
  desc::property_type property_type_;
  desc::layout_type layout_type_;
  size_t layout_id_;
  // If this is an output tensor, and querying the compiled partition would
  // determine that this tensor would reuse its input tensor, then
  // compute_inplace_ would be true, and input_tensor_index_ would be the index
  // of the corresponding input tensor in inputSpecs_ of the LlgaKernel object.
  bool compute_inplace_ = false;
  size_t input_tensor_index_;
};

// Initially, oneDNN Graph also used to have blocked layout for tensors between
// partitions, and the LlgaTensorImpl wrapper helped us bypass guard checks.
// oneDNN Graph has switched over to using strided tensors between partitions,
// but this wrapper still helps us bypass guard checks because the strides of
// tensors between partitions would be different from the ones the guard is
// otherwise expecting.
struct TORCH_API LlgaTensorImpl : public c10::TensorImpl {
  LlgaTensorImpl(
      at::Storage&& storage,
      const caffe2::TypeMeta& data_type,
      const LlgaTensorDesc& desc);

  const LlgaTensorDesc& desc() const {
    return desc_;
  }

  static at::Tensor llga_to_aten_tensor(LlgaTensorImpl* llgaImpl);

 private:
  LlgaTensorDesc desc_;
};

at::Tensor empty_llga(
    const LlgaTensorDesc& desc,
    const c10::TensorOptions& options);

dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor);

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/README.md (new file, 108 lines)
@@ -0,0 +1,108 @@
# PyTorch - oneDNN Graph API Bridge
This integration adds the infrastructure for a new PyTorch JIT graph fuser based on the [oneDNN Graph API](https://spec.oneapi.io/onednn-graph/latest/programming_model.html), which provides a flexible API for aggressive fusion. The current preview4 version supports fusion for FP32 inference. Currently, the speedup is achieved for static shapes, although we will soon add dynamic-shape support. When oneDNN Graph is enabled, weights are cached, as they're constant during inference.

## Graph Optimization
We have registered optimization passes in the custom pre-passes set of PyTorch:

1. Alias and mutation reduction

   The operators of oneDNN Graph are purely functional, while PyTorch has operators in in-place forms or that create views for buffer sharing. Due to the semantic gaps between the backend operators and the PyTorch operators, we have a pass that reduces mutation with best effort at the beginning.

2. Graph passing

   Given a PyTorch TorchScript graph, the integration maps PyTorch operators on the graph to the corresponding oneDNN Graph operators to form a backend graph.

3. Partitioning

   The backend selects regions to be fused in the graph and returns a list of partitions. Each partition corresponds to a set of fused operators.

4. Graph rewriting

   The original PyTorch JIT graph is rewritten based on the partitions returned from the backend. The operators in one partition are grouped together to form a JIT operator, referred to as a oneDNN Graph fusion group.

5. Layout propagation

   This pass eliminates unnecessary layout conversions at partition boundaries. We set different formats for the output of a partition so that the backend can perform layout conversion internally. When `ANY` is set, the layout at the boundary is fully decided by the backend. Otherwise, the backend should follow the layout set by PyTorch. Currently, we set the `ANY` layout for a tensor that's an output of a oneDNN Graph partition and an input to another.

## Graph Executor
During runtime execution of a (rewritten) PyTorch JIT graph, oneDNN Graph partitions are dispatched to the oneDNN Graph JIT variadic operator. Inside the oneDNN Graph JIT op, the input PyTorch tensors of each partition are mapped to oneDNN Graph tensors. The partition is then [compiled](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#partition) and [executed](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#compiled-partition). The output oneDNN Graph tensors are mapped back to PyTorch tensors to be fed to the next operator on the PyTorch JIT graph.
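Each fusion group is also guarded by a type check on the profiled tensor properties; when an input stops matching (for example, a different shape arrives), execution falls back to the unfused graph. A minimal sketch of that behavior, modeled on `test_typecheck` in the new test suite (the module and shapes are illustrative):

```python
import torch

torch.jit.enable_onednn_fusion(True)
m = torch.nn.Linear(in_features=28, out_features=64, bias=True).eval()

with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(m, torch.rand(32, 28)))
    traced(torch.rand(32, 28))
    traced(torch.rand(32, 28))  # profiling runs; the fuser specializes on this shape

    # A different batch size fails the type guard, so the fallback graph runs,
    # but the result still matches eager mode.
    x = torch.rand(5, 28)
    assert torch.allclose(m(x), traced(x))
```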
## Tests

```bash
pytest test/test_jit_llga_fuser.py
```

## Quick Start

A simple cascaded Conv-Relu example is provided in the tests. Please consider enabling log outputs to familiarize yourself with the whole pipeline:

**Mutation Removal -> Prepare Binary -> Defer Size Check -> Graph Fuser -> Layout Propagation -> Type Guard -> Kernel Execution**

oneDNN Graph was formerly known as LLGA (Low Level Graph API), and thus LLGA in the codebase corresponds to oneDNN Graph.

```bash
DNNL_VERBOSE=1 PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface" python -u test/test_jit_llga_fuser.py -k test_conv2d_eltwise
```

## Codebase structure

Most of the source code is placed in

```bash
torch/csrc/jit/codegen/onednn/*
```

Tensor-related code is located at

```bash
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp
```

CMake files where the bridge code is included:

```bash
caffe2/CMakeLists.txt
```

CMake files where the oneDNN Graph submodule is included:

```bash
third_party/ideep/mkl-dnn
cmake/public/mkldnn.cmake
cmake/Modules/FindMKLDNN.cmake
cmake/Dependencies.cmake
```

To map another op to oneDNN Graph, you should add an entry for it in createOperator in torch/csrc/jit/codegen/onednn/graph_helper.cpp. If it has an in-place variant, you should add it in the lambda passed to RemoveTensorMutation in torch/csrc/jit/codegen/onednn/interface.cpp. You might also want to add it to canFuseNode in torch/csrc/jit/codegen/onednn/register_interface.cpp.
## How to use

```python
# enable oneDNN graph fusion globally
torch.jit.enable_onednn_fusion(True)

# define the model
class MyModel(torch.nn.Module):
    ...

# construct the model
model = MyModel(…)
with torch.no_grad():
    model.eval()
    model = torch.jit.trace(model, torch.rand(args.batch_size, 3, 224, 224))

# run the model
with torch.no_grad():
    # oneDNN graph fusion will be triggered during runtime
    output = model(images)
```
torch/csrc/jit/codegen/onednn/defer_size_check.cpp (new file, 87 lines)
@@ -0,0 +1,87 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

class SizeCheckMover {
 private:
  Block* block_;
  std::shared_ptr<Graph> graph_;

 public:
  SizeCheckMover(Block* block, std::shared_ptr<Graph> graph)
      : block_(block), graph_(std::move(graph)) {}

  bool analyzeNode(Node* node, AliasDb& aliasDb) {
    //
    // %b = addmm(%a)
    // %sz = aten::size(%b)
    // %c = relu(%b)
    //  =>
    // %b = addmm(%a)
    // %c = relu(%b)
    // %sz = aten::size(%c)
    //        ^-- move size check after relu as it preserves input shape
    //
    if (!node->matches("aten::size(Tensor self) -> int[]"))
      return false;

    auto* input = node->input(0);
    auto& uses = input->uses();
    bool onlyUsedByShapePreserveOp =
        uses.size() > 1 && std::all_of(uses.begin(), uses.end(), [&](auto& u) {
          if (u.user == node) {
            return true;
          }
          // match with shape-preserving unary ops in
          // tensorexpr_elementwise_set that's defined in
          // torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp
          OperatorMap<std::string> schemaMap = get_tensorexpr_elementwise_set();
          c10::optional<std::string> mapping =
              schemaMap.find(u.user->getOperator());
          return mapping == "unary";
        });

    if (!onlyUsedByShapePreserveOp)
      return false;

    for (const auto& use : uses) {
      if (use.user == node)
        continue;
      auto shapePreserveOp = use.user;
      if (aliasDb.moveAfterTopologicallyValid(node, shapePreserveOp)) {
        node->replaceInputWith(input, shapePreserveOp->output(0));
        return true;
      }
    }

    return false;
  }

  void run() {
    bool changed = true;
    while (changed) {
      changed = false;
      AliasDb aliasDb(graph_);
      for (Node* node : block_->nodes()) {
        changed |= analyzeNode(node, aliasDb);
      }
    }

    for (Node* node : block_->nodes())
      for (Block* subBlock : node->blocks())
        SizeCheckMover(subBlock, graph_).run();
  }
};

void DeferSizeCheck(std::shared_ptr<Graph>& graph) {
  SizeCheckMover(graph->block(), graph).run();
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/defer_size_check.h (new file, 15 lines)
@@ -0,0 +1,15 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void DeferSizeCheck(std::shared_ptr<Graph>& graph);

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/graph_fuser.cpp (new file, 31 lines)
@@ -0,0 +1,31 @@
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph) {
  AliasDb db(graph);
  GraphRewriter graphRewriter(graph->block(), graph, db);
  // We maintain alias db correctness in-place while building up the LLGA
  // subgraphs, however it is difficult to preserve correctness when
  // un-inlining autodiff subgraphs. We first recursively construct all
  // subgraphs and then recursively cleanup & unmerge the small subgraphs
  graphRewriter.buildupSubgraphs();
  graphRewriter.cleanupSubgraphs();
  // Run CSE globally once to eliminate duplicates that may have occurred
  // while inlining subgraphs.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/graph_fuser.h (new file, 53 lines)
@@ -0,0 +1,53 @@
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
struct WorkBlock : public std::pair<Node*, Node*> {
|
||||
using pair::pair;
|
||||
|
||||
Node* begin() {
|
||||
return this->first;
|
||||
}
|
||||
Node* end() {
|
||||
return this->second;
|
||||
}
|
||||
};
|
||||
|
||||
class GraphRewriter {
|
||||
public:
|
||||
GraphRewriter(Block* block, std::shared_ptr<Graph> graph, AliasDb& aliasDb)
|
||||
: block_(block),
|
||||
graph_(std::move(graph)),
|
||||
aliasDb_(aliasDb),
|
||||
llgaHelper_(graph_) {}
|
||||
|
||||
void cleanupSubgraphs();
|
||||
void buildupSubgraphs();
|
||||
|
||||
private:
|
||||
Block* block_;
|
||||
std::shared_ptr<Graph> graph_;
|
||||
AliasDb& aliasDb_;
|
||||
LlgaGraphHelper llgaHelper_;
|
||||
std::vector<WorkBlock> buildWorkBlocks();
|
||||
std::pair<graph_node_list::iterator, bool> scanNode(
|
||||
Node* consumer,
|
||||
graph_node_list::iterator workblock_begin);
|
||||
c10::optional<Node*> tryMerge(Node* consumer, Node* producer);
|
||||
};
|
||||
|
||||
// This pass creates the subgraphs for oneDNN Graph Fusion Nodes.
|
||||
// Its code-structure has been vastly inspired from
|
||||
// torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
|
||||
void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
torch/csrc/jit/codegen/onednn/graph_helper.cpp (new file, 562 lines)
@ -0,0 +1,562 @@
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using opkind = dnnl::graph::op::kind;

void fixConvOptionalBias(Node* node) {
  if (node->namedInput("bias")->mustNotBeNone() == false) {
    // Replace a non-existent optional bias with const None
    auto g = node->owningGraph();
    auto n = g->createNone();
    auto v = n->insertBefore(node)->output();
    node->replaceInput(2, v);
  }
}

c10::optional<size_t> getDimensions(Value* v) {
  if (v->type()->isSubtypeOf(TensorType::get())) {
    return v->type()->cast<TensorType>()->sizes().size();
  } else {
    return c10::nullopt;
  }
}

// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped as
// Wildcards instead. They make the integration code with PyTorch simpler by
// passing every op to the oneDNN Graph library in the add_op call -
// no need to check beforehand whether the op is supported by oneDNN Graph or
// not. oneDNN Graph ops separated by wildcards don't end up in the same
// partition.
Operator makeWildcardOp(Node* node) {
  auto o = Operator(node, opkind::Wildcard);
  // a wildcard op contains only topology info
  for (size_t i = 0; i < node->inputs().size(); i++) {
    o.setInput(i);
  }
  for (size_t i = 0; i < node->outputs().size(); i++) {
    o.setOutput(i);
  }
  return o;
}

// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph
// op, then we create a wildcard op corresponding to that PyTorch op instead.
#define REQUIRE(cond)                                 \
  if (!(cond)) {                                      \
    GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
    return makeWildcardOp(node);                      \
  }

Operator makeEltwiseOp(Node* node, opkind kind) {
  return Operator(node, kind).setInput(0).setOutput(0);
}

Operator makeBinaryOp(Node* node, opkind kind) {
  REQUIRE(
      node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
      node->input(1)->type()->isSubtypeOf(TensorType::get()))
  return Operator(node, kind).setInput(0, 1).setOutput(0);
}

// Map a PyTorch op to its corresponding oneDNN Graph op.
// If the mapping isn't possible, then create a wildcard op instead.
// The mapping is done as per the oneDNN Graph op schema defined in
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
Operator createOperator(Node* node) {
  switch (node->kind()) {
    case aten::conv2d: {
      fixConvOptionalBias(node);
      return Operator(node, opkind::Convolution)
          .setInput(0, 1, 2)
          .setOutput(0)
          .setAttr("strides", Operator::Ints, 3)
          .setAttr("pads_begin", Operator::Ints, 4)
          .setAttr("pads_end", Operator::Ints, 4)
          .setAttr("dilations", Operator::Ints, 5)
          .setAttr("groups", Operator::Int, 6)
          .setAttr("filter_format", std::string("OIX"))
          .setAttr("data_format", std::string("NCX"));
    }

    case aten::_convolution: {
      bool transposed = toIValue(node->namedInput("transposed"))->toBool();
      REQUIRE(!transposed);

      return Operator(node, opkind::Convolution)
          .setInput(0, 1, 2)
          .setOutput(0)
          .setAttr("strides", Operator::Ints, 3)
          .setAttr("pads_begin", Operator::Ints, 4)
          .setAttr("pads_end", Operator::Ints, 4)
          .setAttr("dilations", Operator::Ints, 5)
          .setAttr("groups", Operator::Int, 8)
          .setAttr("filter_format", std::string("OIX"))
          .setAttr("data_format", std::string("NCX"));
    }

    case aten::batch_norm: {
      auto training = toIValue(node->namedInput("training"));
      REQUIRE(
          training.has_value()); // cannot get training status in script mode
      REQUIRE(!training->toBool()); // TODO: support bn training
      return Operator(node, opkind::BatchNormInference)
          .setInput(0, 1, 2, 3, 4)
          .setOutput(0)
          .setAttr("epsilon", Operator::Float, 7)
          .setAttr("data_format", std::string("NCX"));
    }

    case aten::layer_norm: {
      auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
      REQUIRE(normalized_shape->toIntList().size() == 1);
      return Operator(node, opkind::LayerNorm)
          .setInput(0, 2, 3)
          .setOutput(0)
          .setAttr("epsilon", Operator::Float, 4)
          .setAttr("keep_stats", false);
    }

    case aten::addmm: {
      auto alpha = toIValue(node->namedInput("alpha"));
      auto beta = toIValue(node->namedInput("beta"));
      REQUIRE(
          alpha.has_value() && beta.has_value() && (alpha->toDouble() == 1.0) &&
          (beta->toDouble() == 1.0));
      return Operator(node, opkind::MatMul).setInput(1, 2, 0).setOutput(0);
    }

    case aten::add:
      return makeBinaryOp(node, opkind::Add);

    case aten::mul:
      return makeBinaryOp(node, opkind::Multiply);

    case aten::tanh:
      return makeEltwiseOp(node, opkind::Tanh);

    case aten::relu:
      return makeEltwiseOp(node, opkind::ReLU);

    case aten::elu:
      return makeEltwiseOp(node, opkind::Elu)
          .setAttr("alpha", Operator::Float, 1);

    case aten::sigmoid:
      return makeEltwiseOp(node, opkind::Sigmoid);

    case aten::gelu:
      return makeEltwiseOp(node, opkind::GELU);

    case aten::sqrt:
      return makeEltwiseOp(node, opkind::Sqrt);

    case aten::abs:
      return makeEltwiseOp(node, opkind::Abs);

    case aten::square:
      return makeEltwiseOp(node, opkind::Square);

    case aten::hardtanh:
      return makeEltwiseOp(node, opkind::HardTanh)
          .setAttr("min", Operator::Float, 1)
          .setAttr("max", Operator::Float, 2);

    case aten::relu6:
      return makeEltwiseOp(node, opkind::HardTanh)
          .setAttr("min", 0.f)
          .setAttr("max", 6.f);

    case aten::softmax: {
      auto axis = toIValue(node->namedInput("dim"))->toInt();
      return Operator(node, opkind::SoftMax)
          .setInput(0)
          .setOutput(0)
          .setAttr("axis", axis);
    }

    case aten::cat: {
      auto o = Operator(node, opkind::Concat);
      REQUIRE(
          node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
      REQUIRE(node->namedInput("tensors")->uses().size() == 1);
      REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
      // aten::cat needs special handling since it takes a Tensor[] as input.
      // We set the inputs of ListConstruct as the inputs of cat.
      //
      // Pytorch IR:                              LLGA sees:
      //     %a    %b    %c          %dim           %a    %b    %c
      //      \     |    /             |             \     |    /
      //   prim::ListConstruct   prim::Constant   llga::Concat[axis=%dim]
      //                  \         /
      //                   aten::cat
      auto listConstruct = node->input(0)->node();
      for (auto input : listConstruct->inputs())
        o.setInputValue(input);
      return o.setOutput(0).setAttr("axis", Operator::Int, 1);
    }

    case aten::max_pool2d: {
      REQUIRE(
          node->namedInput("kernel_size")->node()->kind() == prim::Constant);

      auto rounding_type =
          toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
      return Operator(node, opkind::MaxPool)
          .setInput(0)
          .setOutput(0)
          .setAttr("kernel", Operator::Ints, 1)
          .setAttr("strides", Operator::Ints, 2)
          .setAttr("pads_begin", Operator::Ints, 3)
          .setAttr("pads_end", Operator::Ints, 3)
          .setAttr("dilations", Operator::Ints, 4)
          .setAttr("rounding_type", std::string(rounding_type))
          .setAttr("data_format", std::string("NCX"));
    }

    case aten::avg_pool2d: {
      // TODO: do we need to add checks for all Constants?
      REQUIRE(
          node->namedInput("kernel_size")->node()->kind() == prim::Constant);
      auto rounding_type =
          toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
      auto divisor_override = toIValue(node->namedInput("divisor_override"));
      REQUIRE(divisor_override->isNone());
      return Operator(node, opkind::AvgPool)
          .setInput(0)
          .setOutput(0)
          .setAttr("kernel", Operator::Ints, 1)
          .setAttr("strides", Operator::Ints, 2)
          .setAttr("pads_begin", Operator::Ints, 3)
          .setAttr("pads_end", Operator::Ints, 3)
          .setAttr("exclude_pad", !Operator::Bool(node, 5))
          .setAttr("rounding_type", std::string(rounding_type))
          .setAttr("data_format", std::string("NCX"));
    }

    case aten::matmul: {
      auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
      auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
      // TODO: support all shape combinations
      REQUIRE(
          (dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
          (dim0 == 3 && dim1 == 2));
    } // fall through
    case aten::mm: {
      return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
    }

    case aten::linear: {
      return Operator(node, opkind::MatMul)
          .setInput(0, 1, 2)
          .setOutput(0)
          .setAttr("transpose_b", true);
    }

    default:
      return makeWildcardOp(node);
  }
}

dnnl::graph::op createLlgaOp(Node* node) {
  return createOperator(node).llgaOp();
}

bool isSupported(Node* node) {
  return createOperator(node).kind() != opkind::Wildcard;
}

DeviceType inferDeviceFromValue(Value* v) {
  auto tt = v->type()->cast<TensorType>();
  if (!tt) {
    return at::kCPU;
  }
  auto device = tt->device();
  if (!device) {
    return at::kCPU;
  }
  return device->type();
}

DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
  auto dt = inferDeviceFromValue(graph->inputs()[0]);
  TORCH_CHECK(
      std::all_of(
          graph->inputs().begin(),
          graph->inputs().end(),
          [dt](Value* v) { return inferDeviceFromValue(v) == dt; }),
      "All inputs must have the same device type");
  return dt;
}

dnnl::graph::engine::kind getLlgaEngineKind(DeviceType type) {
  switch (type) {
    case DeviceType::CPU:
      return dnnl::graph::engine::kind::cpu;
    default:
      TORCH_CHECK(false, "Unsupported device type ", type);
  }
}

void mayAddListConstructIntoConcatPartition(
    Node* n,
    OpPartitionMap& opToOwningPartition) {
  // Since prim::ListConstruct is not visible to the LLGA,
  // it will not be in any partition returned from the partitioning results.
  // We need to rewrite opToOwningPartition so that prim::ListConstruct is
  // 'virtually' in the same partition as the aten::cat, so that
  // prim::ListConstruct can be fused into the fusion group by the graph fuser.
  // We emphasize 'virtually' because get_num_ops() for cat's partition
  // would still return 1.
  if (n->kind() == aten::cat && opToOwningPartition.has(n)) {
    auto listConstruct = n->namedInput("tensors")->node();
    auto partitionId = opToOwningPartition.get(n);
    opToOwningPartition.add(listConstruct, partitionId);
  }
}

// Verify that the input tensors are compatible with oneDNN Graph.
// Scalars would be converted to 1-D tensors later anyway,
// but they shouldn't be complex-double.
// If this check fails, convert the op to a wildcard.
bool checkInputCompatibility(Node* node) {
  auto allInputs = node->inputs();
  for (auto input : allInputs) {
    c10::IValue inputIValue = toIValue(input);
    if (inputIValue.isTensor()) {
      const at::Tensor& tensor = inputIValue.toTensor();
      if (tensor.device() != at::kCPU) {
        return false;
      }
      auto dtype = tensor.scalar_type();
      if ((dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
        return false;
      }
    } else if (inputIValue.isScalar()) {
      if (inputIValue.isComplexDouble()) {
        return false;
      }
    }
  }
  return true;
}

LlgaGraphHelper::LlgaGraphHelper(
    const std::shared_ptr<Graph>& graph,
    dnnl::graph::partition::policy policy) {
  auto deviceType = inferDevice(graph);
  auto engineKind = getLlgaEngineKind(deviceType);
  dnnl::graph::graph g{engineKind};

  GRAPH_DEBUG("Constructing LLGA graph");
  // TODO: select nodes in top-level block for now
  for (auto* node : graph->block()->nodes()) {
    auto op = createLlgaOp(node);
    auto kindOfNode = node->kind();
    if (checkInputCompatibility(node)) {
      g.add_op(op);
      GRAPH_DEBUG("  Added node ", kindOfNode.toQualString());
    } else {
      GRAPH_DEBUG("The backend failed to add node ", kindOfNode.toQualString());
      g.add_op(makeWildcardOp(node).llgaOp());
    }

    for (Value* input : node->inputs()) {
      tensorIdToValue_.emplace(input->unique(), input);
    }
  }

  GRAPH_DEBUG("Get Partitions");
  std::vector<dnnl::graph::partition> partitions = g.get_partitions(policy);
  // exclude unsupported Wildcard partitions
  for (auto& partition : partitions) {
    if (partition.is_supported()) {
      partitions_.push_back(partition);
    }
  }

  GRAPH_DEBUG("  Got #partitions: ", partitions_.size());
  for (size_t partId = 0; partId < partitions_.size(); partId++) {
    for (auto opId : partitions_[partId].get_ops()) {
      opToOwningPartition_.add(opId, partId);
    }
  }

  // Scan the graph again for post-processing
  for (auto* node : graph->block()->nodes()) {
    mayAddListConstructIntoConcatPartition(node, opToOwningPartition_);
  }
}

bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) {
  return node->hasAttribute(attr::Subgraph) &&
      node->kind() == prim::oneDNNFusionGroup;
}

bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
  TORCH_CHECK(
      isLlgaSubgraph(subgraph),
      "The consumer node does not contain a subgraph");
  if (!shouldConsiderForMerge(toMerge)) {
    return false;
  }
  return opToOwningPartition_.get(toMerge) ==
      opToOwningPartition_.get(subgraph);
}

// Except for conv & GEMMs, which should always be handled by oneDNN Graph,
// only use single-op partitions for ops unsupported by NNC, or ops
// that oneDNN executes faster. prim::ListConstruct is an exception, since
// we simply want to fuse it with cat.
bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
  return (
      (kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) ||
      (kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) ||
      (kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) ||
      (kindOfOp == aten::mm) || (kindOfOp == aten::linear) ||
      (kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct));
}

bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) {
  if (opToOwningPartition_.has(node)) {
    auto partitionId = opToOwningPartition_.get(node);
    if (partitions_[partitionId].get_ops_num() == 1) {
      auto kindOfNode = node->kind();
      return isBetterSuitedForLLGA(kindOfNode);
    } else {
      // multi-op partition
      return true;
    }
  } else {
    // this op isn't present in any partition
    return false;
  }
}

bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
  // if we're already in the process of merging
  if (isLlgaSubgraph(node)) {
    return true;
  }
  return checkForSingleOpPartition(node);
}

Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) {
  auto partitionId = opToOwningPartition_.get(n);
  GRAPH_DEBUG(
      "Creating FusionGroup_", partitionId, " for ", n->kind().toQualString());
  auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
      n, prim::oneDNNFusionGroup, aliasDb);
  opToOwningPartition_.add(group, partitionId);
  LlgaNodeWrapper(group).initOutputLayouts();
  return group;
}

void LlgaGraphHelper::mergeNodeIntoSubgraph(
    Node* toMerge,
    Node* subgraphNode,
    AliasDb& aliasDb) {
  if (isLlgaSubgraph(toMerge)) {
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        "_",
        opToOwningPartition_.get(toMerge),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  } else {
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  }

  SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
      toMerge, subgraphNode, aliasDb);
}

void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) {
  TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node");

  auto partitionId = opToOwningPartition_.get(subgraphNode);
  auto expectOpNum = partitions_[partitionId].get_ops_num();
  auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph));

  if (expectOpNum != actualOpNum) {
    GRAPH_DEBUG(
        "Unmerging FusionGroup_",
        partitionId,
        ". Expected ",
        expectOpNum,
        " ops, but got ",
        actualOpNum,
        " ops.");
    SubgraphUtils::unmergeSubgraph(subgraphNode);
  }
}

size_t LlgaGraphHelper::countSupportedOps(
    const std::shared_ptr<Graph>& graph) const {
  // TODO: count nodes in top-level block for now
  size_t cnt = 0;
  for (auto* node : graph->block()->nodes()) {
    auto nodeKind = node->kind();
    if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) {
      cnt++;
    }
  }
  return cnt;
}

std::vector<dnnl::graph::partition> LlgaGraphHelper::getPartitions() const {
  return partitions_;
}

std::map<size_t, Value*> LlgaGraphHelper::getTensorIdToValue() const {
  return tensorIdToValue_;
}

LlgaNodeWrapper::LlgaNodeWrapper(const Node* node)
    : n(const_cast<Node*>(node)) { // NOLINT
  TORCH_CHECK(
      LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node");
}

void LlgaNodeWrapper::setOpaqueLayout(size_t offset) {
  TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset);
  auto& layouts =
      const_cast<std::vector<int64_t>&>(n->is(attr::output_layouts)); // NOLINT
  layouts.at(offset) = 1;
}

bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const {
  TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset);
  return n->is(attr::output_layouts)[offset] == 1;
}

void LlgaNodeWrapper::initOutputLayouts() {
  if (n->hasAttribute(attr::output_layouts)) {
    return;
  }

  // Init all output layouts as undef
  std::vector<int64_t> layouts(n->outputs().size(), 0);
  n->is_(attr::output_layouts, layouts);
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
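The `output_layouts` node attribute used above is just a vector of `int64_t` flags, one per fusion-group output. A minimal sketch of querying it (the helper name is invented, and the `torch::jit::fuser::onednn` context is assumed to be in scope; it mirrors `LlgaNodeWrapper::useOpaqueLayout()`):

```
bool outputIsOpaque(const Node* fusionGroup, size_t offset) {
  const auto& layouts = fusionGroup->is(attr::output_layouts);
  return layouts.at(offset) == 1; // 0 = strided (undef), 1 = opaque/blocked
}
```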
torch/csrc/jit/codegen/onednn/graph_helper.h (new file, 95 lines)
@ -0,0 +1,95 @@
#pragma once

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/operator.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

struct OpPartitionMap {
  void add(uint64_t opId, uint64_t partitionId) {
    opmap_[opId] = partitionId;
  }
  void add(Node* n, uint64_t partitionId) {
    add(Operator::getId(n), partitionId);
  }
  bool has(uint64_t opId) {
    return opmap_.count(opId) > 0;
  }
  bool has(Node* n) {
    return has(Operator::getId(n));
  }
  uint64_t get(uint64_t opId) {
    return opmap_[opId];
  }
  uint64_t get(Node* n) {
    auto opId = Operator::getId(n);
    TORCH_CHECK(
        has(opId),
        "Node ",
        n->kind().toQualString(),
        " does not belong to any LLGA partition");
    return get(opId);
  }

 private:
  std::unordered_map<uint64_t, uint64_t> opmap_;
};

class LlgaGraphHelper {
 public:
  LlgaGraphHelper(
      const std::shared_ptr<Graph>& graph,
      dnnl::graph::partition::policy policy =
          dnnl::graph::partition::policy::fusion);

  bool shouldMerge(Node* toMerge, Node* subgraph);

  bool shouldConsiderForMerge(Node* node);

  bool checkForSingleOpPartition(Node* node);

  Node* createSingletonSubgraph(Node* n, AliasDb& db);

  void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode, AliasDb& db);

  void unmergeIfAnyNodeIsMissing(Node* subgraphNode);

  static bool isLlgaSubgraph(const Node* node);

  std::vector<dnnl::graph::partition> getPartitions() const;

  std::map<size_t, Value*> getTensorIdToValue() const;

 private:
  size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const;

  OpPartitionMap opToOwningPartition_;
  std::vector<dnnl::graph::partition> partitions_;
  std::map<size_t, Value*>
      tensorIdToValue_; // map from tensorId to torch::jit::Value
};

class LlgaNodeWrapper {
 public:
  LlgaNodeWrapper(const Node* node);

  void setOpaqueLayout(size_t offset);

  bool useOpaqueLayout(size_t offset) const;

  friend class LlgaGraphHelper;

 private:
  void initOutputLayouts();

  Node* n;
};

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
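`OpPartitionMap` keys ops by the id that `Operator::getId()` derives from the node address, so both overloads resolve to the same entry. A small usage sketch (assuming the declarations above are in scope):

```
OpPartitionMap m;
m.add(node, /*partitionId=*/3); // same entry as m.add(Operator::getId(node), 3)
if (m.has(node)) {
  uint64_t pid = m.get(node);   // 3; get() TORCH_CHECKs for unmapped nodes
}
```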
torch/csrc/jit/codegen/onednn/graph_rewriter.cpp (new file, 144 lines)
@ -0,0 +1,144 @@
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void GraphRewriter::cleanupSubgraphs() {
  auto curNode = *block_->nodes().rbegin();
  while (curNode != *block_->nodes().rend()) {
    // Save the previous node, since we might delete `curNode` in the next block
    auto prevNode = curNode->prev();
    if (llgaHelper_.isLlgaSubgraph(curNode)) {
      // Unmerge the subgraph if we didn't get all nodes of a partition
      // into the subgraph due to a failed alias check
      llgaHelper_.unmergeIfAnyNodeIsMissing(curNode);
    }
    curNode = prevNode;
  }
  for (Node* n : block_->nodes()) {
    for (Block* b : n->blocks()) {
      GraphRewriter(b, graph_, aliasDb_).cleanupSubgraphs();
    }
  }
}

void GraphRewriter::buildupSubgraphs() {
  // We need to run the rewriter multiple times in order to get all merge
  // opportunities. This is because moveBeforeTopologicallyValid may reorder
  // nodes to be AFTER the current iteration point. In order to properly
  // consider those nodes for merging, we need to run the pass until no changes
  // have been made.
  //
  // Example:
  //   c = f(a, b)
  //   d = f(c)
  //   e = f(d)  <- iter is here, moving upward
  // After c.moveBeforeTopologicallyValid(e), we have:
  //   c = f(a, b)
  //   e = f(d)  <- iter still here
  //   d = f(c)  <- this is the node that was moved to the other side.
  // see [workblocks]
  auto workblocks = buildWorkBlocks();
  for (auto& workblock : workblocks) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      auto workblock_end = workblock.end()->reverseIterator();
      auto workblock_begin = workblock.begin()->reverseIterator();
      for (auto it = workblock_end; it != workblock_begin;) {
        bool changed = false;
        std::tie(it, changed) = scanNode(*it, workblock_begin);
        any_changed |= changed;
      }
    }
  }

  // Construct subgraphs recursively
  for (Node* n : block_->nodes()) {
    for (auto subBlock : n->blocks()) {
      GraphRewriter(subBlock, graph_, aliasDb_).buildupSubgraphs();
    }
  }
}

std::vector<WorkBlock> GraphRewriter::buildWorkBlocks() {
  // [workblocks]
  // the IR has many nodes which can never be reordered around, such as a
  // prim::Bailout. If a node N is surrounded by two nodes which cannot be
  // reordered, A and B, then a fusion group that is created from N
  // can only contain nodes from (A, B). The nodes from A to B represent one
  // work block for the subgraph rewriter to work on. By creating these up
  // front, we avoid retraversing the whole graph block any time scanNode
  // returns.
  Node* end_bound_node = block_->return_node();
  Node* curr = end_bound_node->prev();
  std::vector<WorkBlock> worklist;
  while (curr != block_->param_node()) {
    // cannot reorder around side-effectful nodes
    if (curr->hasSideEffects()) {
      worklist.emplace_back(curr, end_bound_node);
      end_bound_node = curr;
    }
    curr = curr->prev();
  }
  worklist.emplace_back(curr, end_bound_node);
  return worklist;
}

std::pair<graph_node_list::iterator, bool> GraphRewriter::scanNode(
    Node* consumer,
    graph_node_list::iterator workblock_begin) {
  GRAPH_DEBUG("Scanning ", consumer->kind().toQualString());
  if (llgaHelper_.shouldConsiderForMerge(consumer)) {
    if (!llgaHelper_.isLlgaSubgraph(consumer)) {
      consumer = llgaHelper_.createSingletonSubgraph(consumer, aliasDb_);
    }
    // Iterate through the workblock to merge nodes of the
    // same partition determined by the LLGA graph helper.
    // Nodes like B and C below do not share a common input but belong to the
    // same partition, so we cannot scan only the input nodes
    // to find merging opportunities. Instead, we have to scan through
    // the whole workblock, which might lead to O(N^2) accesses in the worst case
    //              A
    //      + - - / - \ - - +
    //      |    B     C    |
    //      |    |     |    |
    //      |    D     E    |
    //      + - - \ - / - - +
    //              F
    auto prev = ++consumer->reverseIterator();
    for (auto it = prev; it != workblock_begin; it++) {
      if (auto group = tryMerge(consumer, *it)) {
        // we successfully merged, so the new group's `inputs` may have
        // changed. So rescan the new group for more merging opportunities.
        return std::make_pair(group.value()->reverseIterator(), true);
      }
    }
  }
  return std::make_pair(++consumer->reverseIterator(), false);
}

// Try to merge `producer` into `consumer`. If successful, this destroys
// `producer` and returns the `consumer` group.
c10::optional<Node*> GraphRewriter::tryMerge(Node* consumer, Node* producer) {
  AT_ASSERT(llgaHelper_.isLlgaSubgraph(consumer));
  bool canMerge = llgaHelper_.shouldMerge(producer, consumer) &&
      aliasDb_.moveBeforeTopologicallyValid(producer, consumer);
  if (!canMerge) {
    return c10::nullopt;
  }
  llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_);
  return consumer;
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
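To make the [workblocks] note concrete, a hand-made example (nodes invented, not from this diff):

```
Block contents, top to bottom:
  param_node, n1, n2, print (side-effectful), n3, return_node
buildWorkBlocks() walks backwards from return_node and yields:
  WorkBlock(print, return_node)   // the scan covers n3
  WorkBlock(param_node, print)    // the scan covers n1, n2
```

Since `scanNode` only iterates within one WorkBlock, no merge is ever attempted across the side-effectful `print` node.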
torch/csrc/jit/codegen/onednn/guard_shape.cpp (new file, 45 lines)
@ -0,0 +1,45 @@
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

//! [ Note -- prepareFusionGroupAndGuardOutputs implementation ]
//! shamelessly copying code from NNC (tensorexpr_fuser) with very little
//! modification, original code at:
//! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs`
//!
//! We assume that LLGA does not have operators
//! whose behavior depends on the content of the tensor.
void prepareFusionGroupAndGuardOutputs(Block* block) {
  std::vector<Node*> fusion_groups;
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      prepareFusionGroupAndGuardOutputs(b);
    }
    if (n->kind() == prim::oneDNNFusionGroup) {
      fusion_groups.push_back(n);
    }
  }
  for (Node* fusion_group : fusion_groups) {
    // TODO: add a further optimization pass, removeOutputsUsedOnlyInSize;
    // refer to
    // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize`
    // removeOutputsUsedOnlyInSize(fusion_group);
    insertTypeGuard(
        fusion_group,
        [](const TensorTypePtr& t) { return t; },
        prim::oneDNNFusionGuard);
  }
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
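Roughly, the IR that `insertTypeGuard` leaves behind looks like the sketch below (value names invented, passthrough outputs elided; the exact fallback construction lives in tensorexpr_fuser.cpp). The guard kernel registered in interface.cpp pushes the boolean consumed by the `prim::If`:

```
%ok  = prim::oneDNNFusionGuard[types=[...]](%x)
%out = prim::If(%ok)
  block0():                      // types matched: run the compiled LLGA kernel
    %r = prim::oneDNNFusionGroup_0(%x)
    -> (%r)
  block1():                      // types changed: fall back to the original ops
    ...
```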
torch/csrc/jit/codegen/onednn/guard_shape.h (new file, 15 lines)
@ -0,0 +1,15 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void prepareFusionGroupAndGuardOutputs(Block* block);

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/interface.cpp (new file, 172 lines)
@ -0,0 +1,172 @@
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/codegen/onednn/interface.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void fuseGraph(std::shared_ptr<Graph>& g) {
  // Follow the process of the tensorexpr_fuser in profiling mode:
  // Remove prim::profile nodes and embed the profile info directly in the
  // IR in value types to avoid breaking the fusion patterns.
  // We will add a shape guard after the LLGA optimization passes and
  // wipe the tensor type information from the IR, so that it's not
  // accidentally used by any other pass.

  // We rely on the shape specialization and the shape guard to ensure the
  // validity of the cached compilation in the kernel, thus we only support
  // profiling mode.
  // TODO: add a check on oneDNNFusionGroup to ensure allShapesAreKnown on nodes
  // to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
  if (getProfilingMode()) {
    GRAPH_DUMP(
        "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
        "optimization pass",
        g);
    RemoveProfileNodesAndSpecializeTypes(g);
    GRAPH_DUMP(
        "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
        g);

    RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
      static std::unordered_set<Symbol> supportedOps = {
          aten::add_,
          aten::mul_,
          aten::tanh_,
          aten::elu_,
          aten::relu_,
          aten::relu6_,
          aten::gelu_,
          aten::sqrt_,
          aten::sigmoid_,
          aten::hardtanh_,
          aten::abs_,
          aten::square_,
      };
      return supportedOps.count(nodeToFunctionalize->kind()) != 0;
    });
    RemoveListMutation(g);
    GRAPH_DUMP("After mutation removal. Before PrepareBinaryForLLGA", g);
    PrepareBinaryForLLGA(g);
    GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
    DeferSizeCheck(g);
    GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
    CreateLlgaSubgraphs(g);
    GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
    PropagateLayout(g);
    GRAPH_DUMP(
        "After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);

    // Add a shape guard for profiling mode and wipe the tensor type
    // information from the IR
    prepareFusionGroupAndGuardOutputs(g->block());
    GRAPH_DUMP(
        "After prepareFusionGroupAndGuardOutputs. Before "
        "RemoveTensorTypeSpecializations",
        g);
    RemoveTensorTypeSpecializations(g);
    GRAPH_DUMP(
        "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
        g);
  }
}

} // namespace onednn
} // namespace fuser

Operation createLlgaKernel(const Node* node) {
  auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
  return [kernel](Stack* stack) {
    RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
    kernel->run(*stack);
    return 0;
  };
}

RegisterOperators oneDNNFusionGroupOp({
    torch::jit::Operator(
        prim::oneDNNFusionGroup,
        createLlgaKernel,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

// Currently, we convert some scalar inputs, such as the second argument of
// binary ops, to a 1-D tensor. Other scalar inputs are prim::Constant nodes.
// But if we have any scalar inputs to guard in the future, some logic here
// would have to be changed.
Operation createLlgaGuardKernel(const Node* node) {
  return [node](Stack* stack) {
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
#endif
    std::vector<TypePtr> types = node->tys(attr::types);
    const auto num_inputs = types.size();
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
#endif
    for (size_t i = 0; i < num_inputs; i++) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("checking input ", i);
#endif
      auto& input = peek(stack, i, num_inputs);
      const c10::TensorTypePtr& guard_tensor_type =
          types[i]->cast<TensorType>();

      if (!input.isTensor()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is not a tensor, return false");
#endif
        push(stack, IValue(false));
        return;
      }
      const at::Tensor& tensor = input.toTensor();

      // If the input tensor is of mkldnn, it originated from an upstream
      // LLGA partition that has passed the check on input shapes.
      // It is valid to continue here as long as the output shapes from
      // oneDNN graph partitions are determined by the input shapes.
      if (tensor.is_mkldnn()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
#endif
        continue;
      }

      if (!guard_tensor_type->matchTensor(tensor)) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " check failed, return false");
#endif
        push(stack, IValue(false));
        return;
      }
    }
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("all checks done, return true");
#endif
    push(stack, IValue(true));
    return;
  };
}

RegisterOperators oneDNNGuardOp({
    torch::jit::Operator(
        prim::oneDNNFusionGuard,
        createLlgaGuardKernel,
        AliasAnalysisKind::FROM_SCHEMA),
});
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/interface.h (new file, 62 lines)
@ -0,0 +1,62 @@
#pragma once
#include <ATen/Config.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

static std::atomic<bool> onednn_enabled{false};

std::atomic<bool>& getLlgaEnabled() {
  return onednn_enabled;
}

C10_EXPORT void fuseGraph(std::shared_ptr<Graph>& g);

} // namespace onednn
} // namespace fuser

struct C10_EXPORT RegisterLlgaFuseGraph
    : public PassManager<RegisterLlgaFuseGraph> {
  static bool setEnabled(bool enabled) {
    TORCH_CHECK(
        AT_MKLDNN_ENABLED(),
        "Running the oneDNN Graph fuser is only supported with MKLDNN builds.");
    bool oldState = fuser::onednn::getLlgaEnabled();
    fuser::onednn::getLlgaEnabled() = enabled;
    if (enabled) {
      registerPass(fuser::onednn::fuseGraph);
    } else {
      clearPass();
    }
    return oldState;
  }

  static bool isEnabled() {
    return fuser::onednn::getLlgaEnabled();
  }

  // override PassManager::registerPass to register a pre-pass
  static bool registerPass(GraphPass p) {
    if (!isRegistered()) {
      passID(registerPrePass(std::move(p)), true);
      isRegistered(true);
      return false;
    }
    return true;
  }

  // override PassManager::clearPass to clear the pre-pass
  static void clearPass() {
    if (isRegistered()) {
      clearPrePass(passID());
      isRegistered(true);
    }
  }
};

} // namespace jit
} // namespace torch
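A usage sketch of the struct above from C++ (the binding that exposes it to Python is outside this excerpt): `setEnabled()` flips the flag, (de)registers `fuseGraph` as a profiling-executor pre-pass, and returns the previous state.

```
#include <torch/csrc/jit/codegen/onednn/interface.h>

void enableOneDnnGraphFusion() {
  // Registers fuser::onednn::fuseGraph as a pre-pass; later graph executions
  // in profiling mode will run the LLGA optimization pipeline.
  bool wasEnabled = torch::jit::RegisterLlgaFuseGraph::setEnabled(true);
  (void)wasEnabled; // previous state, in case the caller wants to restore it
}
```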
torch/csrc/jit/codegen/onednn/kernel.cpp (new file, 292 lines)
@ -0,0 +1,292 @@
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using namespace dnnl::graph;
using data_type = dnnl::graph::logical_tensor::data_type;

LlgaKernel::LlgaKernel(const Node* fusionNode)
    : fusionNode_(fusionNode),
      graph_(fusionNode->g(attr::Subgraph)),
      nGraphInputs_(graph_->inputs().size()),
      nOutputs_(graph_->outputs().size()),
      debugName_(genDebugName()) {
  // TODO: This is a workaround to recreate the partitions here.
  // The ideal way is to use the partition serialization API (not available in
  // LLGA yet) to carry a serialized string representation from the graph
  // rewrite and deserialize it here.
  auto llgaGraphHelper = LlgaGraphHelper(graph_);
  auto partitions = llgaGraphHelper.getPartitions();
  tensorIdToValue_ = llgaGraphHelper.getTensorIdToValue();
  TORCH_CHECK(
      partitions.size() == 1,
      "LLGA subgraph should contain only one partition");
  partition_ = partitions[0];
  nPartitionInputs_ = partition_.get_in_ports().size();
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString());
#endif
}

bool LlgaKernel::useOpaqueLayout(size_t offset) const {
  return LlgaNodeWrapper(fusionNode_).useOpaqueLayout(offset);
}

void LlgaKernel::initializeConstantInputs() {
  for (auto& lt : partition_.get_in_ports()) {
    auto inputId = lt.get_id();
    if (initializedInputIds_.find(inputId) == initializedInputIds_.end()) {
      TORCH_CHECK(
          tensorIdToValue_.count(inputId) > 0,
          "input with inputId ",
          inputId,
          " is missing");
      auto* value = tensorIdToValue_[inputId];

      TORCH_CHECK(
          value->node()->kind() == prim::Constant &&
              value->type()->cast<TensorType>(),
          "input with inputId ",
          inputId,
          " should be a Constant tensor");
      constantValues_.emplace_back(value);

      auto const_tensor = toIValue(value)->toTensor();
      constantInputs_.emplace_back(const_tensor);
    }
  }
}

std::map<size_t, int64_t> LlgaKernel::initializeTensorIdToOccurence() const {
  std::map<size_t, int64_t> tensorIdToOccurence;
  for (auto& lt : partition_.get_in_ports()) {
    auto inputId = lt.get_id();
    std::map<size_t, int64_t>::iterator it(tensorIdToOccurence.find(inputId));
    if (it != tensorIdToOccurence.end()) {
      it->second++;
    } else {
      tensorIdToOccurence[inputId] = 1;
    }
  }
  return tensorIdToOccurence;
}

ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
  ArgSpecs inputSpecs;
  inputSpecs.reserve(nPartitionInputs_);
  GRAPH_DEBUG("Initializing graph input logical tensors");
  std::map<size_t, int64_t> tensorIdToOccurence =
      initializeTensorIdToOccurence();
  for (size_t i = 0; i < nGraphInputs_; i++) {
    auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]);
    initializedInputIds_.insert(spec.tid());
    int64_t occurence = tensorIdToOccurence[spec.tid()];
    inputSpecs.insert(inputSpecs.end(), occurence, spec);
    runArgsIdx_.insert(runArgsIdx_.end(), occurence, i);
  }
  GRAPH_DEBUG("Initializing constant input tensors");
  initializeConstantInputs();

  TORCH_CHECK(
      inputSpecs.size() + constantValues_.size() == nPartitionInputs_,
      "Partition inputs are missing");
  GRAPH_DEBUG(
      "Concatenating constant input logical tensors to graph input "
      "logical tensors");
  for (Value* constant_value : constantValues_) {
    ArgSpec constantInputSpec(constant_value);
    inputSpecs.emplace_back(constantInputSpec);
    constantLogicalTensors_.emplace_back(constantInputSpec.logical_tensor());
  }
  return inputSpecs;
}

ArgSpecs LlgaKernel::initializeOutputSpecs() const {
  ArgSpecs outputSpecs;
  outputSpecs.reserve(nOutputs_);
  for (size_t i = 0; i < nOutputs_; i++) {
    auto spec = ArgSpec(graph_->outputs()[i]);
    if (useOpaqueLayout(i)) {
      spec = spec.any();
    }
    outputSpecs.emplace_back(spec);
  }
  return outputSpecs;
}

std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
    const TensorArgs& inputs,
    TensorArgs& outputs) const {
  RunArgs runInputs, runOutputs;
  auto numInputs = runArgsIdx_.size();
  for (size_t i = 0; i < numInputs; i++) {
    auto spec = inputSpecs_[i];
    auto input = inputs[runArgsIdx_[i]];
    runInputs.push_back(
        {spec.logical_tensor(), Engine::getEngine(), input.data_ptr()});
  }
  auto numConstantInputs = constantInputs_.size();
  for (size_t i = 0; i < numConstantInputs; i++) {
    // constantInputSpecs are placed after graphInputSpecs
    auto constantInputSpecIdx = nGraphInputs_ + i;
    auto constantInputSpec = inputSpecs_[constantInputSpecIdx];
    runInputs.push_back(
        {constantLogicalTensors_[i],
         Engine::getEngine(),
         constantInputs_[i].data_ptr()});
  }

  for (size_t i = 0; i < nOutputs_; i++) {
    auto spec = outputSpecs_[i];
    auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_);

    if (spec.reuses_input_tensor()) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("oneDNN Graph would perform inplace computation");
#endif
      auto inputTensor = inputs[spec.get_input_tensor_index()];
      auto dataType = spec.dtype();
      if (C10_UNLIKELY(!useOpaqueLayout(i) && inputTensor.is_mkldnn())) {
        // If the input tensor was between two partitions, it would've been
        // wrapped with LlgaTensorImpl. But if it's being reused as the output
        // tensor which is not between two partitions, then we'd have to
        // re-wrap it with TensorImpl, as it'd be fed into a PyTorch op.
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("Rewrap tensor");
#endif
        auto llgaImpl =
            static_cast<LlgaTensorImpl*>(inputTensor.unsafeGetTensorImpl());
        switch (dataType) {
          case data_type::f32:
          case data_type::bf16:
            inputTensor = LlgaTensorImpl::llga_to_aten_tensor(llgaImpl);
            break;
          case data_type::s32:
          default:
            TORCH_CHECK(
                false, "Invalid data type ", static_cast<size_t>(dataType));
        }
      }
      outputs.push_back(inputTensor);
      runOutputs.push_back(
          {spec.logical_tensor(), Engine::getEngine(), inputTensor.data_ptr()});
    } else if (useOpaqueLayout(i)) {
      // Wrap tensors between partitions with the LlgaTensorImpl wrapper, so
      // that we can bypass the guard check, as strides would be different
      // from those expected.
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("Between two oneDNN Graph partitions");
#endif
      auto tensor = empty_llga(spec, opt);
      outputs.push_back(tensor);
      runOutputs.push_back(llga_from_aten_tensor(tensor));
    } else {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("Neither opaque to PyTorch nor inplace computation");
#endif
      auto tensor = at::empty_strided(spec.sizes(), spec.strides(), opt);
      outputs.push_back(tensor);
      runOutputs.push_back(
          {spec.logical_tensor(), Engine::getEngine(), tensor.data_ptr()});
    }
  }

  return std::make_tuple(runInputs, runOutputs);
}

compiled_partition LlgaKernel::compile(const partition& partition) {
  auto inputs = fmap(inputSpecs_, toLogicalTensor);
  auto outputs = fmap(outputSpecs_, toLogicalTensor);
  auto compilation = partition.compile(inputs, outputs, Engine::getEngine());

  // Since the layouts of opaque outputs are only known after compilation,
  // we need to query them from the compiled partition and update outputSpecs
  for (size_t i = 0; i < nOutputs_; i++) {
    auto tid = outputSpecs_[i].tid();
    outputSpecs_[i] = compilation.query_logical_tensor(tid);
  }

  // Build a static mapping from output id to input offset
  // in accordance with the available inplace options
  for (auto&& option : compilation.get_inplace_ports()) {
    size_t inputId = option.first;
    size_t outputId = option.second;
    auto inputSpecIter =
        std::find_if(inputSpecs_.begin(), inputSpecs_.end(), [&](auto& spec) {
          return spec.tid() == inputId;
        });
    TORCH_CHECK(inputSpecIter != inputSpecs_.end(), "In-place input not found");
    auto inputOffset = inputSpecIter - inputSpecs_.begin();
    auto outputSpecIter =
        std::find_if(outputSpecs_.begin(), outputSpecs_.end(), [&](auto& spec) {
          return spec.tid() == outputId;
        });
    auto outputOffset = outputSpecIter - outputSpecs_.begin();
    outputSpecs_[outputOffset].set_compute_inplace();
    outputSpecs_[outputOffset].set_input_tensor_index(inputOffset);
  }

  return compilation;
}

void LlgaKernel::run(Stack& stack) {
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("In ", debugName(), "\n");
#endif

  // Grab input values from the stack
  auto stackInputs = last(stack, nGraphInputs_);
  auto inputs = fmap(stackInputs, [&](const IValue& v) {
    TORCH_CHECK(
        v.isTensor(), "Stack values for LLGA partition must be Tensor type");
    return v.toTensor();
  });

  // Even in the case of concurrent threads, the kernel would be initialized
  // only once.
  // TODO: Try not using an atomic lock
  std::call_once(
      initialized_flag,
      [&](const TensorArgs& inputs) {
        GRAPH_DEBUG("Initializing input logical tensors");
        inputSpecs_ = initializeInputSpecs(inputs);
        GRAPH_DEBUG("Initializing output logical tensors");
        outputSpecs_ = initializeOutputSpecs();
        GRAPH_DEBUG("Compiling partition");
        compilation_ = compile(partition_);
        is_initialized_ = true;
      },
      inputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Preparing runtime tensors");
#endif
  TensorArgs outputs;
  RunArgs runInputs, runOutputs;
  std::tie(runInputs, runOutputs) = prepareRunArgs(inputs, outputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Executing partition");
#endif
  compilation_.execute(Stream::getStream(), runInputs, runOutputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Partition executed");
#endif

  // Update the stack.
  drop(stack, nGraphInputs_);
  for (auto& o : outputs)
    push_one(stack, std::move(o));
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Stack updated");
#endif
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
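The lazy one-time initialization in `run()` is the standard `std::call_once` idiom; a stripped-down sketch of the same pattern (names invented, not part of this diff):

```
#include <mutex>

struct KernelLikeCache {
  std::once_flag once;
  void run() {
    // The first caller pays for compilation; concurrent callers block until
    // it finishes; every subsequent call goes straight to execution.
    std::call_once(once, [&] { compile(); });
    execute();
  }
  void compile() {}
  void execute() {}
};
```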
torch/csrc/jit/codegen/onednn/kernel.h (new file, 93 lines)
@ -0,0 +1,93 @@
#pragma once

#include <unordered_map>

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using ArgSpec = LlgaTensorDesc;
using ArgSpecs = std::vector<ArgSpec>;
using RunArg = dnnl::graph::tensor;
using RunArgs = std::vector<RunArg>;
using TensorArgs = std::vector<at::Tensor>;

class LlgaKernel {
 public:
  explicit LlgaKernel(const Node* fusionNode);

  void run(Stack& stack);

  void initialize(const TensorArgs& inputs);

  const std::string& debugName() const {
    return debugName_;
  }

 private:
  bool useOpaqueLayout(size_t offset) const;

  // PyTorch copies constants inside the subgraph instead of referencing them.
  // Constant inputs to the partition are no longer in graph->inputs().
  // We need to use the tid retrieved from the partition to find the missing
  // constant inputs.
  void initializeConstantInputs();

  ArgSpecs initializeInputSpecs(const TensorArgs& inputs);

  ArgSpecs initializeOutputSpecs() const;

  dnnl::graph::compiled_partition compile(
      const dnnl::graph::partition& partition);

  std::map<size_t, int64_t> initializeTensorIdToOccurence() const;

  std::tuple<RunArgs, RunArgs> prepareRunArgs(
      const TensorArgs& inputs,
      TensorArgs& outputs) const;

  static std::string genDebugName() {
    static size_t debugId = 0;
    return "LlgaPartition_" + std::to_string(debugId++);
  }

  static dnnl::graph::logical_tensor toLogicalTensor(const ArgSpec& s) {
    return s.logical_tensor();
  }

  at::Device device_ = at::kCPU;
  const Node* fusionNode_;
  std::shared_ptr<Graph> graph_;
  int64_t nGraphInputs_ = 0; // number of inputs to graph_ on the IR
  int64_t nOutputs_ = 0;
  std::map<size_t, Value*> tensorIdToValue_;
  std::vector<int64_t> runArgsIdx_;
  dnnl::graph::partition partition_;
  // nPartitionInputs_ is the actual number of inputs to partition_ of graph_
  // needed by the backend.
  // nPartitionInputs_ = nGraphInputs_ + constantInputs_.size() since Constant
  // inputs are copied to the inside of the subgraph
  int64_t nPartitionInputs_;
  dnnl::graph::compiled_partition compilation_;
  std::set<size_t> initializedInputIds_;
  std::vector<Value*> constantValues_;
  TensorArgs constantInputs_;
  ArgSpecs inputSpecs_;
  ArgSpecs outputSpecs_;
  std::vector<dnnl::graph::logical_tensor> constantLogicalTensors_;
  std::string debugName_;
  std::once_flag initialized_flag;
  bool is_initialized_ = false;
};

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/layout_propagation.cpp (new file, 44 lines)
@ -0,0 +1,44 @@
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void LayoutPropagation(Node* n) {
  if (!LlgaGraphHelper::isLlgaSubgraph(n))
    return;

  for (auto input : n->inputs()) {
    auto prev = input->node();
    auto offset = input->offset();
    if (LlgaGraphHelper::isLlgaSubgraph(prev)) {
      bool useOpaqueLayout = true;
      for (auto& use : input->uses()) {
        if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) {
          useOpaqueLayout = false;
          break;
        }
      }
      if (useOpaqueLayout) {
        LlgaNodeWrapper(prev).setOpaqueLayout(offset);
      }
    }
  }
}

void LayoutPropagation(at::ArrayRef<Block*> blocks) {
  for (Block* block : blocks)
    for (Node* node : block->nodes())
      LayoutPropagation(node);
}

void PropagateLayout(const std::shared_ptr<Graph>& graph) {
  LayoutPropagation(graph->block());
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
|
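In effect, a value produced by one `prim::oneDNNFusionGroup` and consumed only by other fusion groups keeps the backend's opaque (blocked) layout, avoiding a reorder at the partition boundary. A minimal sketch of invoking the pass (the `fuseGraph` pipeline in this PR calls it after the fusion groups are formed; the exact ordering here is an assumption):

```cpp
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>

// Sketch: run layout propagation once over a graph that already contains
// prim::oneDNNFusionGroup nodes.
void runLayoutPropagation(std::shared_ptr<torch::jit::Graph>& graph) {
  torch::jit::fuser::onednn::PropagateLayout(graph);
}
```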
torch/csrc/jit/codegen/onednn/layout_propagation.h (new file, 15 lines)
@ -0,0 +1,15 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void PropagateLayout(const std::shared_ptr<Graph>& graph);

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/operator.h (new file, 101 lines)
@ -0,0 +1,101 @@
#pragma once

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

class Operator {
 public:
  Operator(const Node* node, dnnl::graph::op::kind kind)
      : n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {}

  Operator& setInputValue(Value* v) {
    if (v->mustNotBeNone())
      o.add_input(createLogicalTensor(v));
    return *this;
  }

  Operator& setInput(size_t offset) {
    return setInputValue(n->input(offset));
  }

  template <typename... Ts>
  Operator& setInput(size_t offset, Ts... other) {
    setInput(offset);
    return setInput(other...);
  }

  Operator& setOutputValue(Value* v) {
    if (v->mustNotBeNone())
      o.add_output(createLogicalTensor(v));
    return *this;
  }

  Operator& setOutput(size_t offset) {
    return setOutputValue(n->output(offset));
  }

  template <typename... Ts>
  Operator& setOutput(size_t offset, Ts... other) {
    setOutput(offset);
    return setOutput(other...);
  }

  template <typename Attr>
  Operator& setAttr(std::string name, Attr&& attr) {
    o.set_attr(name, std::forward<Attr>(attr));
    return *this;
  }

  template <typename F>
  Operator& setAttr(std::string name, const F& fn, size_t offset) {
    return setAttr(name, fn(n, offset));
  }

  static std::vector<int64_t> Ints(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toIntVector();
  }

  static int64_t Int(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toInt();
  }

  static float Float(const Node* node, size_t offset) {
    return static_cast<float>(toIValue(node->input(offset))->toDouble());
  }

  static bool Bool(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toBool();
  }

  static uint64_t getId(const Node* node) {
    return reinterpret_cast<uint64_t>(node); // cast node address as op id
  }

  dnnl::graph::op::kind kind() const {
    return k;
  }

  dnnl::graph::op llgaOp() const {
    return o;
  }

 private:
  dnnl::graph::logical_tensor createLogicalTensor(Value* value) const {
    return LlgaTensorDesc(value).logical_tensor();
  }

  const Node* n;
  dnnl::graph::op o;
  dnnl::graph::op::kind k;
};

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
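As a usage note, a hedged sketch of how a bridge function might describe an `aten::add` node with this builder (illustrative only, not part of this diff; the real op mappings live elsewhere in this PR):

```cpp
// Sketch: describe aten::add as a oneDNN Graph Add op. Input offsets 0 and 1
// are self/other; offset 2 (alpha) is expected to equal 1 after the
// PrepareBinaryForLLGA pass later in this diff, so it is not mapped here.
using torch::jit::fuser::onednn::Operator;

Operator makeAdd(const torch::jit::Node* node) {
  return Operator(node, dnnl::graph::op::kind::Add)
      .setInput(0, 1)
      .setOutput(0);
}
```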
torch/csrc/jit/codegen/onednn/prepare_binary.cpp (new file, 106 lines)
@ -0,0 +1,106 @@
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

bool compareConstValue(Value* v, double d) {
  auto ival = toIValue(v);
  return ival.has_value() &&
      ((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
       (ival->isDouble() && ival->toDouble() == d));
}

void mayConvertScalarInputToTensor(Node* node) {
  // We do not handle binary ops with two scalar inputs,
  // and we assume the scalar is always in the second position.
  if (node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
      (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
       node->input(1)->type()->isSubtypeOf(IntType::get()))) {
    auto scalar = node->input(1);
    WithInsertPoint guard(node);
    auto g = node->owningGraph();
    // 42 : Scalar --> tensor(42.0) : Float([])
    auto t = g->insert(
        aten::as_tensor, {scalar}, {{"dtype", at::ScalarType::Float}});
    // add dim & stride info to IR
    c10::optional<size_t> t_dim = 1;
    auto target_type = TensorTypePtr(
        TensorType::create(at::ScalarType::Float, at::kCPU, t_dim, false));
    target_type = target_type->withSizes({1});
    t->setType(target_type);

    // tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1])
    auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
    unsqueezed->setType(target_type);
    node->replaceInput(1, unsqueezed);
  }
}

static void ConvertScalarToTensor(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      ConvertScalarToTensor(sub);
    }

    if (node->kind() == aten::add || node->kind() == aten::mul) {
      mayConvertScalarInputToTensor(node);
    }
  }
}

void mayDecomposeAdd(Node* node) {
  if (toIValue(node->namedInput("alpha")).has_value()) {
    auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
    if (!alphaEqualsOne) {
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      auto mul = g->insert(
          aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
      node->replaceInput(1, mul);
      auto one = g->insertConstant(1.0);
      node->replaceInput(2, one);
    }
  }
}

static void DecomposeFusedAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      DecomposeFusedAdd(sub);
    }

    if (node->kind() == aten::add) {
      mayDecomposeAdd(node);
    }
  }
}

static void EliminateIdentityMulAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      EliminateIdentityMulAdd(sub);
    }

    if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
      node->output()->replaceAllUsesWith(node->namedInput("self"));
    }
  }
}

void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
  DecomposeFusedAdd(graph->block());
  EliminateIdentityMulAdd(graph->block());
  EliminateDeadCode(graph);
  // ConvertScalarToTensor must be placed after EliminateIdentityMulAdd
  ConvertScalarToTensor(graph->block());
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
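A quick way to see the effect of this pass is a test-style sketch (not part of the diff; the IR literal and the expected outcome are assumptions based on `mayDecomposeAdd` above):

```cpp
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/ir/irparser.h>

// aten::add(%x, %y, alpha=2) should be rewritten into aten::mul(%y, 2)
// feeding an aten::add whose alpha is the constant 1.
void demoPrepareBinary() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
graph(%x : Tensor, %y : Tensor):
  %alpha : int = prim::Constant[value=2]()
  %z : Tensor = aten::add(%x, %y, %alpha)
  return (%z))IR",
      g.get());
  torch::jit::fuser::onednn::PrepareBinaryForLLGA(g);
  g->dump(); // expect an aten::mul node and aten::add with alpha == 1
}
```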
torch/csrc/jit/codegen/onednn/prepare_binary.h (new file, 26 lines)
@ -0,0 +1,26 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// Prepare binary ops for LLGA
//
// The pass does the following:
//
// - Convert a scalar input of aten::add and aten::mul into a Float tensor
//   with dimension [1]
//
// - Decompose fused add into aten::mul + aten::add when alpha != 1.0
//
// - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1
//
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph);

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/onednn/register_interface.cpp (new file, 54 lines)
@ -0,0 +1,54 @@
#include <torch/csrc/jit/runtime/profiling_record.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

bool canFuseNode(const Node* node) {
  switch (node->kind()) {
    case aten::conv2d:
    case aten::_convolution:
    case aten::batch_norm:
    case aten::layer_norm:
    case aten::add:
    case aten::mul:
    case aten::tanh:
    case aten::relu:
    case aten::elu:
    case aten::sigmoid:
    case aten::gelu:
    case aten::sqrt:
    case aten::abs:
    case aten::square:
    case aten::hardtanh:
    case aten::relu6:
    case aten::softmax:
    case aten::max_pool2d:
    case aten::avg_pool2d:
    case aten::matmul:
    case aten::mm:
    case aten::linear:
    case aten::addmm:
      return true;

    default:
      return false;
  }
}

namespace {
class RegisterInterface {
 public:
  RegisterInterface() {
    RegisterProfilingNode(canFuseNode);
  }
};

static RegisterInterface register_interface_;
} // namespace

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
@ -623,6 +623,7 @@ void AliasDb::analyzeImpl(Node* node) {
      return analyzeLoop(node);
    case prim::FusionGroup:
    case prim::CudaFusionGroup:
    case prim::oneDNNFusionGroup:
    case prim::FunctionalGraph:
    case prim::DifferentiableGraph:
    case prim::FallbackGraph:
@ -508,6 +508,7 @@ void Node::lint() const {
      break;
    case prim::FusionGroup:
    case prim::CudaFusionGroup:
    case prim::oneDNNFusionGroup:
      checkSameDevice(this);
      // TODO: Typecheck the parameters
      g(attr::Subgraph)->lint();
@ -21,7 +21,8 @@ bool canRunWithAutograd(Node* node) {
  }
  return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
      kind != prim::TypeCheck && kind != prim::TensorExprGroup &&
      kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup &&
      kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim());
}

namespace {
torch/csrc/jit/passes/onednn_graph_fuser.h (new file, 64 lines)
@ -0,0 +1,64 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>

#include <ATen/Config.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

static std::atomic<bool> onednn_enabled{true};

static std::atomic<bool>& getLlgaEnabled() {
  return onednn_enabled;
}

TORCH_API void fuseGraph(std::shared_ptr<Graph>& g);

} // namespace onednn
} // namespace fuser

struct C10_EXPORT RegisterLlgaFuseGraph
    : public PassManager<RegisterLlgaFuseGraph> {
  static bool setEnabled(bool enabled) {
    TORCH_CHECK(
        AT_MKLDNN_ENABLED(),
        "Running oneDNN Graph fuser is only supported with MKLDNN builds.");
    bool oldState = fuser::onednn::getLlgaEnabled();
    fuser::onednn::getLlgaEnabled() = enabled;
    if (enabled) {
      registerPass(fuser::onednn::fuseGraph);
    } else {
      clearPass();
    }
    return oldState;
  }

  static bool isEnabled() {
    return fuser::onednn::getLlgaEnabled();
  }

  // override PassManager::registerPass to register pre-pass
  static bool registerPass(GraphPass p) {
    if (!isRegistered()) {
      passID(registerPrePass(std::move(p)), true);
      isRegistered(true);
      return false;
    }
    return true;
  }

  // override PassManager::clearPass to clear pre-pass
  static void clearPass() {
    if (isRegistered()) {
      clearPrePass(passID());
      isRegistered(true);
    }
  }
};

} // namespace jit
} // namespace torch
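For completeness, a sketch of flipping the fuser on from C++ using the API declared above (equivalent to the Python binding added below):

```cpp
#include <torch/csrc/jit/passes/onednn_graph_fuser.h>

// Enables the LLGA fusion pre-pass and returns the previous state so a
// caller can restore it. Requires an MKLDNN-enabled build; setEnabled()
// raises a TORCH_CHECK failure otherwise.
void enableLlgaFuser() {
  bool wasEnabled = torch::jit::RegisterLlgaFuseGraph::setEnabled(true);
  (void)wasEnabled;
}
```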
@ -8,6 +8,9 @@
#include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
@ -630,6 +633,10 @@ void initJITBindings(PyObject* module) {
        auto stack = toTraceableStack(args);
        checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
      })
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
      .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
      .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#endif
      .def(
          "_jit_set_nvfuser_skip_node_kind",
          // Args:
@ -246,6 +246,8 @@ bool printerHasSpecialCaseFor(Symbol sym) {
      prim::StaticSubgraph, // optimization pass adds it
      prim::ConstantMKLDNNTensor, // optimization pass adds it
      prim::BroadcastMKLDNNTensors, // optimization pass adds it
      prim::oneDNNFusionGroup, // optimization pass adds it
      prim::oneDNNFusionGuard, // optimization pass adds it
      prim::StaticRuntimeCopyOuts, // used in SR only
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
@ -282,6 +284,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
      prim::Loop,
      prim::FusionGroup,
      prim::CudaFusionGroup,
      prim::oneDNNFusionGroup,
      prim::DifferentiableGraph,
      prim::TensorExprGroup,
      prim::TensorExprDynamicGroup,
@ -229,7 +229,19 @@ def _hide_source_ranges() -> Iterator[None]:
    finally:
        torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges)  # type: ignore[attr-defined]

# don't expose Any; TODO: define `__all__`
def enable_onednn_fusion(enabled: bool):
    """
    Enables or disables oneDNN JIT fusion based on the parameter `enabled`.
    """

    torch._C._jit_set_llga_enabled(enabled)

def onednn_fusion_enabled():
    """
    Returns whether oneDNN JIT fusion is enabled.
    """
    return torch._C._jit_llga_enabled()

del Any

if not torch._C._jit_init():