mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add BFloat16 dtype support for oneDNN Graph JIT fuser (#85591)
## BFloat16 dtype support for faster inference with TorchScript using oneDNN Graph Intel Xeon Cooper Lake platform & beyond support the `AVX512_BF16` ISA, which is essentially native BFloat16 support. oneDNN Graph delivers high inference performance with BFloat16 on such machines. While oneDNN Graph can still be used with BFloat16 on older machines that lack `avx512_bf16` ISA but support `avx512bw`, `avx512vl` & `avx512dq` ISAs, the BF16 performance on these older machines will be significantly poorer (probably even poorer than Float32), as they lack native BF16 support. Currently, [AMP support for eager mode & JIT mode is divergent in PyTorch](https://github.com/pytorch/pytorch/issues/75956). So, for using oneDNN Graph with BFloat16, eager-mode AMP should be leveraged by turning off AMP for JIT mode, using `torch._C._jit_set_autocast_mode(False)` in python code, so as to avoid conflicts. Please use the following environment variable to view JIT logs - `PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface"` ## Changes being made in this PR 1. This PR does NOT change the `oneDNN` commit or the `ideep` files. While the `ideep` commit is being updated, only files pertaining to oneDNN Graph are being updated. oneDNN Graph is being upgraded to version 0.5.2 (alpha patch release 2). To put things into perspective, `ideep` is a git submodule of PyTorch. `oneDNN Graph` is a git submodule of `ideep` (`ideep/mkl-dnn`), and oneDNN is a git submodule of oneDNN Graph (`ideep/mkl-dnn/third_party/oneDNN`). 2. Unit-tests are being updated. We now use the [existing dtypes decorator](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L123-L131). 3. Suggestions made by @eellison in the [FP32 PR](https://github.com/pytorch/pytorch/pull/68111#pullrequestreview-896719477) are being incorporated/addressed - | Action-item | Status | | :--- | ---: | |checkInputCompatibility follow up | Fixed | |the mayConvertScalarInputToTensor logic we can consider | Added type promotion code | |fix up fixConvOptionalBias| The current approach seems correct | |Use opinfo tests| using dtypes decorator. Will use `OpInfo` in a subsequent PR, if that'd be possible. Should we create a list of ops from opDB that are supported by oneDNN Graph, and add it to `common_methods_invocations.py`? | |inferDevice torch_check call | not necessary now, perhaps, as only CPU is supported, for now? We'd add it by the beta release of oneDNN Graph, though, so that by then, users might be able to use other fusers with oneDNN Graph (NNC/TensorExpr are already compatible with the oneDNN Graph fuser). We can still add it, if you'd insist. | |not checking shapes of input mkldnn tensor to llga guard | Those checks should not be present because oneDNN Graph may use blocked or channels-last layout, so those strides would be different. They're only skipped if an LLGA subgraph's output is input to another LLGA subgraph, which enables LLGA to choose an optimal layout between them. | |fix test failures with respect to unsupported inputs | We'll address them with the upcoming release of oneDNN Graph beta version| 4. More PyTorch ops are being been mapped to oneDNN Graph ## Example of using oneDNN Graph with BFloat16 ```python # Assuming we have a model of the name 'model' example_input = torch.rand(1, 3, 224, 224) # enable oneDNN Graph torch.jit.enable_onednn_fusion(True) # Disable AMP for JIT torch._C._jit_set_autocast_mode(False) with torch.no_grad(), torch.cpu.amp.autocast(): model = torch.jit.trace(model, (example_input)) model = torch.jit.freeze(model) # 2 warm-ups (2 for tracing/scripting with an example, 3 without an example) model(example_input) model(example_input) # speedup would be observed in subsequent runs. model(example_input) ``` ## TorchBench based Benchmarks **URL:** https://github.com/sanchitintel/benchmark/tree/onednn_graph_benchmark (instructions present at URL). **Batch-size(s):** TorchBench-default for each model **Baseline :** PyTorch JIT OFI FP32 **Machine:** Intel(R) Xeon(R) Platinum 8371HC (Cooper Lake) **Sockets used**: 1 **Number of cores on one socket**: 26 Intel OpenMP & tcmalloc were preloaded #### Benchmark results with single thread | name | latency of PyTorch JIT OFI FP32 (s) | Latency of oneDNN Graph BF16 (s) | % change | | :--- | ---: | ---: | ---: | | test_eval[alexnet-cpu-jit] | 1.063851 | 0.509820 | -52.1% | | test_eval[mnasnet1_0-cpu-jit] | 0.218435 | 0.107100 | -51.0% | | test_eval[mobilenet_v2-cpu-jit] | 0.114467 | 0.058359 | -49.0% | | test_eval[mobilenet_v3_large-cpu-jit] | 0.233873 | 0.117614 | -49.7% | | test_eval[resnet18-cpu-jit] | 0.160584 | 0.075854 | -52.8% | | test_eval[resnet50-cpu-jit] | 1.652846 | 0.713373 | -56.8% | | test_eval[resnext50_32x4d-cpu-jit] | 0.471174 | 0.209431 | -55.6% | |test_eval[shufflenet_v2_x1_0-cpu-jit] | 0.310306 | 0.167090 | -46.2% | | test_eval[squeezenet1_1-cpu-jit] | 0.161247 | 0.045684 | -71.7% | | test_eval[timm_efficientnet-cpu-jit] | 1.643772 | 0.800099 | -51.3% | | test_eval[timm_regnet-cpu-jit] | 5.732272 | 2.333417 | -59.3% | | test_eval[timm_resnest-cpu-jit] | 1.366464 | 0.715252 | -47.7% | | test_eval[timm_vision_transformer-cpu-jit] | 0.508521 | 0.271598 | -46.6% | | test_eval[timm_vovnet-cpu-jit] | 2.756692 | 1.125033 | -59.2% | | test_eval[vgg16-cpu-jit] | 0.711533 | 0.312344 | -56.1% | #### Benchmark results with 26 threads: | name | latency of PyTorch JIT OFI FP32 (s) | Latency of oneDNN Graph BF16 (s) | % change | | :--- | ---: | ---: | ---: | | test_eval[alexnet-cpu-jit] | 0.062871 | 0.034198 | -45.6% | | test_eval[mnasnet1_0-cpu-jit] | 0.022490 | 0.008172 | -63.7% | | test_eval[mobilenet_v2-cpu-jit] | 0.012730 | 0.005866 | -53.9% | | test_eval[mobilenet_v3_large-cpu-jit] | 0.025948 | 0.010346 | -60.1% | | test_eval[resnet18-cpu-jit] | 0.011194 | 0.005726 | -48.9% | | test_eval[resnet50-cpu-jit] | 0.124662 | 0.045599 | -63.4% | | test_eval[resnext50_32x4d-cpu-jit] | 0.034737 | 0.015214 | -56.2% | |test_eval[shufflenet_v2_x1_0-cpu-jit] | 0.028820 | 0.012517 | -56.6% | | test_eval[squeezenet1_1-cpu-jit] | 0.012557 | 0.003876 | -69.1% | | test_eval[timm_efficientnet-cpu-jit] | 0.203177 | 0.051879 | -74.5% | | test_eval[timm_regnet-cpu-jit] | 0.452050 | 0.151113 | -66.6% | | test_eval[timm_resnest-cpu-jit] | 0.117072 | 0.052848 | -54.9% | | test_eval[timm_vision_transformer-cpu-jit] | 0.046048 | 0.023275 | -49.5% | | test_eval[timm_vovnet-cpu-jit] | 0.213187 | 0.077482 | -63.7% | | test_eval[vgg16-cpu-jit] | 0.044726 | 0.021998 | -50.8% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/85591 Approved by: https://github.com/jgong5, https://github.com/frank-wei, https://github.com/chunyuan-w
This commit is contained in:
committed by
PyTorch MergeBot
parent
14dd5db2f5
commit
974ad8fa6c
@ -637,6 +637,7 @@ if(USE_CUDA)
|
||||
else()
|
||||
set(DELAY_LOAD_FLAGS "")
|
||||
endif()
|
||||
|
||||
target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB} ${DELAY_LOAD_FLAGS})
|
||||
target_include_directories(caffe2_nvrtc PRIVATE ${CUDA_INCLUDE_DIRS})
|
||||
install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
|
||||
@ -664,6 +665,7 @@ if(BUILD_ONEDNN_GRAPH)
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_rewriter.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/graph_helper.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/register_interface.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/decompose_silu.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/interface.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/kernel.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/defer_size_check.cpp
|
||||
|
@ -1,51 +1,114 @@
|
||||
# Owner(s): ["module: mkldnn"]
|
||||
import sys
|
||||
import torch
|
||||
import unittest
|
||||
import itertools
|
||||
|
||||
import torch.nn as nn
|
||||
from functools import wraps
|
||||
from concurrent import futures
|
||||
import torch.nn.functional as F
|
||||
import torch.fx.experimental.optimization as optimization
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
dtypes
|
||||
)
|
||||
|
||||
# We use this wrapper to run UTs of TorchVision models because of a memory-leak
|
||||
# issue with JIT tracing that causes traced model objects to persist in the
|
||||
# memory. Ref: https://github.com/pytorch/pytorch/issues/35600
|
||||
# Memory requirement for running these UTs was thus increasing cumulatively, and
|
||||
# invoked the Linux kernel OOM killer on linux.2xlarge PyTorch CI runners, which
|
||||
# only have 16 GB RAM. Cumulatively, these UTs had been using more than 14 GB
|
||||
# memory (as per psutils). So now we run each TorchVision model UTs in separate processes.
|
||||
def separate_process(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with futures.ProcessPoolExecutor() as executor:
|
||||
future = executor.submit(func, *args, **kwargs)
|
||||
futures.wait([future])
|
||||
return wrapper
|
||||
|
||||
def is_avx512_supported():
|
||||
if sys.platform != 'linux':
|
||||
return False
|
||||
with open("/proc/cpuinfo", encoding="ascii") as f:
|
||||
lines = f.read()
|
||||
return "avx512" in lines
|
||||
|
||||
IS_AVX512_UNSUPPORTED = not is_avx512_supported()
|
||||
|
||||
LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup'
|
||||
LLGA_NOT_ENABLED = not torch._C.has_mkldnn or IS_WINDOWS or IS_MACOS
|
||||
|
||||
|
||||
def warmup_forward(f, *args, profiling_count=2):
|
||||
def warmup_forward(f, *args, profiling_count=3):
|
||||
for i in range(profiling_count):
|
||||
results = f(*args)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class JitLlgaTestCase(JitTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# PyTorch has divergent op support for AMP in JIT & eager modes
|
||||
# so we disable AMP for JIT & leverage eager-mode AMP.
|
||||
# Ref: https://github.com/pytorch/pytorch/issues/75956
|
||||
self.original_autocast_mode = torch._C._jit_set_autocast_mode(False)
|
||||
torch.jit.enable_onednn_fusion(True)
|
||||
|
||||
def tearDown(self):
|
||||
torch.jit.enable_onednn_fusion(False)
|
||||
torch._C._jit_set_autocast_mode(self.original_autocast_mode)
|
||||
|
||||
def checkTrace(self, m, x, *args, **kwargs):
|
||||
def checkTrace(self, m, x, dtype=torch.float32, *args, **kwargs):
|
||||
if isinstance(m, torch.nn.Module):
|
||||
m.eval()
|
||||
with torch.no_grad(), \
|
||||
torch._jit_internal._disable_emit_hooks():
|
||||
traced = torch.jit.trace(m, x)
|
||||
if isinstance(m, torch.nn.Module):
|
||||
traced = torch.jit.freeze(traced)
|
||||
warmup_forward(traced, *x)
|
||||
fwd_graph = traced.graph_for(*x)
|
||||
with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
|
||||
if dtype == torch.bfloat16:
|
||||
# We rely upon eager-mode AMP support for BF16
|
||||
with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
|
||||
traced = torch.jit.trace(m, x)
|
||||
if isinstance(m, torch.nn.Module):
|
||||
traced = torch.jit.freeze(traced)
|
||||
warmup_forward(traced, *x)
|
||||
ref_o = m(*x)
|
||||
fwd_graph = traced.graph_for(*x)
|
||||
else:
|
||||
traced = torch.jit.trace(m, x)
|
||||
if isinstance(m, torch.nn.Module):
|
||||
traced = torch.jit.freeze(traced)
|
||||
warmup_forward(traced, *x)
|
||||
ref_o = m(*x)
|
||||
fwd_graph = traced.graph_for(*x)
|
||||
|
||||
ref_o = m(*x)
|
||||
jit_o = traced(*x)
|
||||
self.assertEqual(jit_o, ref_o)
|
||||
return traced, fwd_graph
|
||||
return traced, fwd_graph
|
||||
|
||||
|
||||
def assertFused(self, graph, fused_patterns):
|
||||
for pat in fused_patterns:
|
||||
self.assertGraphContainsExactly(graph, pat, 0)
|
||||
|
||||
def findFusionGroups(self, graph):
|
||||
result = []
|
||||
for n in graph.nodes():
|
||||
if n.kind() == LLGA_FUSION_GROUP:
|
||||
result.append(n.g('Subgraph'))
|
||||
continue
|
||||
for block in n.blocks():
|
||||
result += self.findFusionGroups(block)
|
||||
return result
|
||||
|
||||
def checkPatterns(self, graph, patterns):
|
||||
fusion_groups = self.findFusionGroups(graph)
|
||||
assert len(fusion_groups) == len(patterns), "length of subgraphs not equal to length of given patterns"
|
||||
|
||||
for i in range(len(fusion_groups)):
|
||||
for pattern in patterns[i]:
|
||||
self.assertGraphContains(fusion_groups[i], pattern)
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
@ -61,13 +124,18 @@ def get_eltwise_fn(name):
|
||||
return getattr(torch, name)
|
||||
elif hasattr(F, name):
|
||||
return getattr(F, name)
|
||||
elif name == 'hardswish_':
|
||||
return torch.nn.Hardswish(inplace=True)
|
||||
else:
|
||||
raise NameError('Eltwise function %s not found' % name)
|
||||
|
||||
|
||||
@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.")
|
||||
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
|
||||
class TestOp(JitLlgaTestCase):
|
||||
def test_conv2d(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d(self, dtype):
|
||||
for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product(
|
||||
[7, 8],
|
||||
[8, 15],
|
||||
@ -89,17 +157,21 @@ class TestOp(JitLlgaTestCase):
|
||||
bias=bias)
|
||||
|
||||
x = torch.rand(1, in_channels * g, spatial, spatial)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_bn2d(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_bn2d(self, dtype):
|
||||
m = nn.BatchNorm2d(32).eval()
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
# single-op partition shouldn't be created for softmax
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
|
||||
|
||||
def test_eltwise(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_eltwise(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn):
|
||||
super(M, self).__init__()
|
||||
@ -112,11 +184,13 @@ class TestOp(JitLlgaTestCase):
|
||||
eltwise_fn = get_eltwise_fn(eltwise)
|
||||
m = M(eltwise_fn)
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
# single-op partition shouldn't be created.
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
|
||||
|
||||
def test_max_pool2d(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_max_pool2d(self, dtype):
|
||||
for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product(
|
||||
[15, 16, 17, 18, 19],
|
||||
[4, 5],
|
||||
@ -132,10 +206,12 @@ class TestOp(JitLlgaTestCase):
|
||||
ceil_mode=ceil_mode)
|
||||
|
||||
x = torch.rand(1, 4, spatial, spatial)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_avg_pool2d(self, dtype):
|
||||
for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product(
|
||||
[15, 16, 17, 18, 19],
|
||||
[4, 5],
|
||||
@ -151,10 +227,12 @@ class TestOp(JitLlgaTestCase):
|
||||
count_include_pad=count_include_pad)
|
||||
|
||||
x = torch.rand(1, 4, spatial, spatial)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_variable_kernel_avg_pool2d(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_variable_kernel_avg_pool2d(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
@ -165,27 +243,32 @@ class TestOp(JitLlgaTestCase):
|
||||
|
||||
x = torch.randn(1, 1000, 1, 1)
|
||||
m = M()
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
# kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP
|
||||
# TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
|
||||
|
||||
def test_softmax(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_softmax(self, dtype):
|
||||
for dim in [-4, -3, -2, -1, 0, 1, 2, 3]:
|
||||
m = nn.Softmax(dim=dim)
|
||||
x = torch.rand(8, 12, 12, 12)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
# single-op partition shouldn't be created for softmax
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
|
||||
|
||||
def test_linear(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_linear(self, dtype):
|
||||
for bias in [True, False]:
|
||||
x = torch.rand(32, 28)
|
||||
m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::linear'])
|
||||
|
||||
|
||||
def _gen_binary_inputs(self, gen_permute=True):
|
||||
for xshape, yshape in [
|
||||
[[1, 32, 28, 28], [1, 32, 28, 28]],
|
||||
@ -198,23 +281,32 @@ class TestOp(JitLlgaTestCase):
|
||||
if gen_permute and xshape != yshape:
|
||||
yield torch.rand(yshape), torch.rand(xshape)
|
||||
|
||||
def test_add(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_add(self, dtype):
|
||||
def forward_add(x, y):
|
||||
return torch.add(x, y, alpha=2)
|
||||
|
||||
for x, y in self._gen_binary_inputs():
|
||||
_, graph = self.checkTrace(forward_add, [x, y])
|
||||
_, graph = self.checkTrace(forward_add, [x, y], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_add_scalar(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_add_scalar(self, dtype):
|
||||
def add_scalar(x):
|
||||
return 42 + x + 3.14
|
||||
|
||||
x = torch.rand(32, 32)
|
||||
_, graph = self.checkTrace(add_scalar, [x])
|
||||
_, graph = self.checkTrace(add_scalar, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_addmm(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_addmm(self, dtype):
|
||||
# Just a sidenote - comparison of eager-mode & oneDNN Graph JIT outputs of
|
||||
# addmm (which entails matmul-bias-add fusion) might require higher tolerance
|
||||
# bounds for BF16. This is subject to change in the near future.
|
||||
def addmm(x, y, z):
|
||||
# alpha and beta are 1, by default
|
||||
return torch.addmm(z, x, y)
|
||||
@ -222,35 +314,43 @@ class TestOp(JitLlgaTestCase):
|
||||
x = torch.rand(64, 32)
|
||||
y = torch.rand(32, 32)
|
||||
z = torch.rand(64, 32)
|
||||
_, graph = self.checkTrace(addmm, [x, y, z])
|
||||
_, graph = self.checkTrace(addmm, [x, y, z], dtype)
|
||||
# single-op partition should be created for matmul with bias.
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_mul(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_mul(self, dtype):
|
||||
def forward_mul(x, y):
|
||||
return torch.mul(x, y) * 3
|
||||
|
||||
for x, y in self._gen_binary_inputs():
|
||||
_, graph = self.checkTrace(forward_mul, [x, y])
|
||||
_, graph = self.checkTrace(forward_mul, [x, y], dtype)
|
||||
# single-op partitions shouldn't be created
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_identity_binary(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_identity_binary(self, dtype):
|
||||
def forward(x):
|
||||
return x * 1 + 0.0
|
||||
|
||||
x = torch.rand(32)
|
||||
_, graph = self.checkTrace(forward, [x])
|
||||
_, graph = self.checkTrace(forward, [x], dtype)
|
||||
self.assertFused(graph, ['aten::add', 'aten::mul'])
|
||||
|
||||
def test_layer_norm(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_layer_norm(self, dtype):
|
||||
# TODO: support more normalized_shape
|
||||
m = torch.nn.LayerNorm(10)
|
||||
x = torch.randn(2, 5, 10, 10)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_cat(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_cat(self, dtype):
|
||||
def cat_along_dim(d):
|
||||
def forward_cat(*inputs):
|
||||
return torch.cat(inputs, d)
|
||||
@ -263,23 +363,28 @@ class TestOp(JitLlgaTestCase):
|
||||
]:
|
||||
for d in range(len(xshape)):
|
||||
x = torch.rand(xshape)
|
||||
_, graph = self.checkTrace(cat_along_dim(d), [x, x, x])
|
||||
_, graph = self.checkTrace(cat_along_dim(d), [x, x, x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
|
||||
def test_typecheck(self):
|
||||
x = torch.rand(32, 28)
|
||||
m = torch.nn.Linear(in_features=28, out_features=64, bias=True)
|
||||
traced, graph = self.checkTrace(m, [x])
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_typecheck(self, dtype):
|
||||
x = torch.rand(32, 28, dtype=dtype)
|
||||
m = torch.nn.Linear(in_features=28, out_features=64, bias=True, dtype=dtype)
|
||||
traced, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::linear'])
|
||||
# change the shape of the input, we should enter fallback graph
|
||||
x = torch.rand(5, 28)
|
||||
x = torch.rand(5, 28, dtype=dtype)
|
||||
self.assertEqual(m(x), traced(x))
|
||||
|
||||
|
||||
@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.")
|
||||
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
|
||||
class TestFusionPattern(JitLlgaTestCase):
|
||||
def test_conv2d_eltwise(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d_eltwise(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn):
|
||||
super(M, self).__init__()
|
||||
@ -294,22 +399,128 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
x = self.eltwise(x)
|
||||
return x
|
||||
|
||||
# for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']:
|
||||
for eltwise in ['relu']:
|
||||
for eltwise in ['relu', 'leaky_relu', 'sigmoid', 'square',
|
||||
'abs', 'exp', 'hardswish', 'tanh', 'hardtanh']:
|
||||
for inplace in [True, False]:
|
||||
eltwise_fn_name = eltwise + '_' if inplace else eltwise
|
||||
eltwise_fn = get_eltwise_fn(eltwise_fn_name)
|
||||
|
||||
m = M(eltwise_fn)
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype=dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
|
||||
# test if relu_ is replace with relu by mutation removal pass
|
||||
self.assertFused(graph, ['aten::' + eltwise_fn_name])
|
||||
# test if relu is fused into the fusion group
|
||||
self.assertFused(graph, ['aten::' + eltwise])
|
||||
|
||||
def test_conv2d_bn(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d_silu(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, inplace):
|
||||
super(M, self).__init__()
|
||||
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.eltwise = nn.SiLU(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.eltwise(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
for inplace in [False, True]:
|
||||
for memory_format in [torch.contiguous_format, torch.channels_last]:
|
||||
m = M(inplace)
|
||||
x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
|
||||
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
|
||||
# oneDNN graph does not have silu OP. The bridge will convert silu to sigmoid - mul
|
||||
# Inplace op will become outplace op on the JIT graph
|
||||
patterns = [
|
||||
["aten::_convolution", 'aten::sigmoid', 'aten::mul'],
|
||||
["aten::_convolution"]
|
||||
]
|
||||
silu_op = 'aten::silu_' if inplace else 'aten::silu'
|
||||
self.assertFused(graph, ['aten::_convolution', silu_op])
|
||||
self.checkPatterns(graph, patterns)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_ensure_tensor_is_rewrapped(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn):
|
||||
super(M, self).__init__()
|
||||
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.eltwise = eltwise_fn
|
||||
self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))
|
||||
|
||||
def forward(self, x, y):
|
||||
x = self.conv1(x)
|
||||
x = self.eltwise(x)
|
||||
x = self.conv2(x)
|
||||
x = self.eltwise(x)
|
||||
y = self.conv3(y)
|
||||
y = self.eltwise(y)
|
||||
y = self.conv4(y)
|
||||
y = self.eltwise(y)
|
||||
|
||||
x = torch.add(x, y)
|
||||
x = self.adaptive_avg_pool_2d(x)
|
||||
return x
|
||||
|
||||
eltwise_fn_name = 'relu'
|
||||
eltwise_fn = get_eltwise_fn(eltwise_fn_name)
|
||||
m = M(eltwise_fn)
|
||||
m = m.to(memory_format=torch.channels_last)
|
||||
x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
|
||||
y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
|
||||
# Simply test if the output is accurate
|
||||
# The output of the second partition is input to adaptive_avg_pool2d, which is
|
||||
# unsupported by LLGA. In resnext101 32x16d, we encountered an accuracy issue.
|
||||
_, graph = self.checkTrace(m, [x, y], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d_clamp(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = torch.clamp(x, min=float('-inf'))
|
||||
x = self.conv2(x)
|
||||
x = torch.clamp(x, min=-5)
|
||||
x = self.conv3(x)
|
||||
x = torch.clamp(x, min=0, max=float('inf'))
|
||||
x = self.conv4(x)
|
||||
x = torch.clamp(x, min=1, max=5)
|
||||
x = self.conv5(x)
|
||||
x = torch.clamp(x, max=2)
|
||||
return x
|
||||
|
||||
for inplace in [False, True]:
|
||||
for memory_format in [torch.contiguous_format, torch.channels_last]:
|
||||
x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
|
||||
m = M()
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
|
||||
self.assertFused(graph, ['aten::_convolution', "aten::clamp"])
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d_bn(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
@ -322,13 +533,16 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
return x
|
||||
|
||||
m = M().eval()
|
||||
if dtype == torch.bfloat16:
|
||||
m = optimization.fuse(m)
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm'])
|
||||
|
||||
|
||||
def test_conv2d_bn_relu(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d_bn_relu(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
@ -342,13 +556,17 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
return x
|
||||
|
||||
m = M().eval()
|
||||
if dtype == torch.bfloat16:
|
||||
m = optimization.fuse(m)
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
|
||||
'aten::relu'])
|
||||
|
||||
def test_bn2d_eltwise(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_bn2d_eltwise(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn):
|
||||
super(M, self).__init__()
|
||||
@ -364,11 +582,13 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
eltwise_fn = get_eltwise_fn(eltwise)
|
||||
m = M(eltwise_fn).eval()
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::' + eltwise])
|
||||
|
||||
def test_linear_eltwise(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_linear_eltwise(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn, bias):
|
||||
super(M, self).__init__()
|
||||
@ -387,11 +607,13 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
eltwise_fn = get_eltwise_fn(eltwise)
|
||||
m = M(eltwise_fn, has_bias)
|
||||
x = torch.rand(32, 28, requires_grad=False)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::' + eltwise])
|
||||
|
||||
def test_conv2d_sum(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_conv2d_sum(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, bias=False):
|
||||
super(M, self).__init__()
|
||||
@ -415,12 +637,16 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
|
||||
for bias in [True, False]:
|
||||
m = M(bias).eval()
|
||||
if dtype == torch.bfloat16:
|
||||
m = optimization.fuse(m)
|
||||
x = torch.rand(1, 32, 16, 16, requires_grad=False)
|
||||
y = torch.rand(1, 32, 16, 16, requires_grad=False)
|
||||
_, graph = self.checkTrace(m, [x, y])
|
||||
_, graph = self.checkTrace(m, [x, y], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
|
||||
|
||||
def test_wildcard(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_wildcard(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
@ -443,17 +669,43 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
# Thus conv-eltwise cannot be selected into the same Partition.
|
||||
m = M()
|
||||
x = torch.rand(1, 32, 28, 28)
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
# conv can exist in a single-op oneDNN Graph partition but not relu
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
|
||||
self.assertFused(graph, ['aten::_convolution'])
|
||||
|
||||
def test_rewrap_tensor_input_to_pytorch(self):
|
||||
@onlyCPU
|
||||
@dtypes(torch.int32)
|
||||
def test_wildcard_unsupported_dtype(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn, data_type):
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
|
||||
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
|
||||
|
||||
def forward(self, x):
|
||||
y = x // 2
|
||||
return y
|
||||
|
||||
# In shufflenet_v2_x1_0, channels_per_groups is computed as:
|
||||
# channels_per_group = num_channels // groups
|
||||
# JIT IR converts groups to Long dtype, which is unsupported
|
||||
# by oneDNN Graph, viz. Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]()
|
||||
# This test just ensures that the bridge code can handle
|
||||
# unsupported dtypes for inputs to ops unsupported
|
||||
# by oneDNN Graph. In this particular UT, aten::floor_divide
|
||||
# would be added as a wildcard in graph-construction stage.
|
||||
m = M()
|
||||
x = torch.tensor([32], dtype=dtype)
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.bfloat16)
|
||||
def test_rewrap_tensor_input_to_pytorch(self, dtype):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn):
|
||||
super(M, self).__init__()
|
||||
self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
||||
self.eltwise = eltwise_fn
|
||||
self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))
|
||||
|
||||
@ -468,18 +720,15 @@ class TestFusionPattern(JitLlgaTestCase):
|
||||
|
||||
eltwise_fn_name = 'relu'
|
||||
eltwise_fn = get_eltwise_fn(eltwise_fn_name)
|
||||
# Add bfloat16 later
|
||||
for data_type in [torch.float]:
|
||||
m = M(eltwise_fn, data_type)
|
||||
m = m.to(memory_format=torch.channels_last)
|
||||
x = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
|
||||
y = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
|
||||
# Simply test if the output is accurate
|
||||
# The output of the second partition is input to adaptive_avg_pool2d, which is
|
||||
# unsupported by LLGA, so it must be handled by PyTorch, which should receive
|
||||
# correct strides info of the channels-last tensor.
|
||||
graph, _ = self.checkTrace(m, [x, y])
|
||||
|
||||
m = M(eltwise_fn)
|
||||
m = m.to(memory_format=torch.channels_last)
|
||||
x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
|
||||
y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
|
||||
# Simply test if the output is accurate
|
||||
# The output of the second partition is input to adaptive_avg_pool2d, which is
|
||||
# unsupported by LLGA, so it must be handled by PyTorch, which should receive
|
||||
# correct strides info of the channels-last tensor.
|
||||
graph, _ = self.checkTrace(m, [x, y], dtype)
|
||||
|
||||
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
|
||||
class TestEnableDisableLlgaFuser(JitTestCase):
|
||||
@ -525,25 +774,40 @@ class TestEnableDisableLlgaFuser(JitTestCase):
|
||||
self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), LLGA_FUSION_GROUP, 0)
|
||||
|
||||
|
||||
@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.")
|
||||
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
|
||||
class TestModel(JitLlgaTestCase):
|
||||
@skipIfNoTorchVision
|
||||
def _test_vision(self, model_name):
|
||||
def _test_vision(self, model_name, dtype):
|
||||
m = getattr(torchvision.models, model_name)().eval()
|
||||
if dtype == torch.bfloat16:
|
||||
m = optimization.fuse(m)
|
||||
x = torch.rand(1, 3, 224, 224) / 10
|
||||
_, graph = self.checkTrace(m, [x])
|
||||
_, graph = self.checkTrace(m, [x], dtype)
|
||||
self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
|
||||
'aten::relu', 'aten::linear',
|
||||
'aten::avg_pool2d', 'aten::max_pool2d'])
|
||||
|
||||
|
||||
for model_name, enabled in [
|
||||
['resnet50', True],
|
||||
['resnext50_32x4d', True],
|
||||
['resnext101_32x8d', True],
|
||||
['densenet121', True],
|
||||
['densenet161', True],
|
||||
['densenet169', True],
|
||||
['densenet201', True],
|
||||
['efficientnet_b0', True],
|
||||
['efficientnet_b1', True],
|
||||
['efficientnet_b2', True],
|
||||
['efficientnet_b3', True],
|
||||
['efficientnet_b4', True],
|
||||
['efficientnet_b5', True],
|
||||
['efficientnet_b6', True],
|
||||
['efficientnet_b7', True],
|
||||
['regnet_y_400mf', True],
|
||||
['googlenet', TEST_SCIPY],
|
||||
['mobilenet_v2', True],
|
||||
['mobilenet_v3_large', True],
|
||||
['mnasnet1_0', True],
|
||||
['squeezenet1_0', True],
|
||||
['vgg16', True],
|
||||
@ -551,13 +815,19 @@ for model_name, enabled in [
|
||||
['shufflenet_v2_x1_0', True],
|
||||
['wide_resnet50_2', True],
|
||||
]:
|
||||
def wrapper(mname):
|
||||
def _wrapper(mname, dtype):
|
||||
@unittest.skipIf(not enabled, 'Disabled')
|
||||
def test(self):
|
||||
return self._test_vision(mname)
|
||||
@separate_process
|
||||
def test(self, dtype=dtype):
|
||||
return self._test_vision(mname, dtype)
|
||||
return test
|
||||
|
||||
setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))
|
||||
for dtype in [torch.bfloat16, torch.float32]:
|
||||
setattr(TestModel, 'test_vision_%s_%s' % (model_name, str(dtype).split("torch.")[1]), _wrapper(model_name, dtype))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestFusionPattern, globals())
|
||||
instantiate_device_type_tests(TestOp, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
2
third_party/ideep
vendored
2
third_party/ideep
vendored
Submodule third_party/ideep updated: 77d662b313...76d2b0dd18
@ -16,7 +16,7 @@ dnnl::graph::engine& Engine::getEngine() {
|
||||
}
|
||||
|
||||
dnnl::graph::stream& Stream::getStream() {
|
||||
static dnnl::graph::stream cpu_stream{Engine::getEngine(), nullptr};
|
||||
static dnnl::graph::stream cpu_stream{Engine::getEngine()};
|
||||
return cpu_stream;
|
||||
}
|
||||
|
||||
@ -76,7 +76,7 @@ dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor) {
|
||||
|
||||
using data_type = dnnl::graph::logical_tensor::data_type;
|
||||
|
||||
data_type getLlgaDataType(at::ScalarType dt) {
|
||||
data_type LlgaTensorDesc::getLlgaDataType(at::ScalarType dt) const {
|
||||
switch (dt) {
|
||||
case at::ScalarType::Float:
|
||||
return data_type::f32;
|
||||
@ -89,7 +89,12 @@ data_type getLlgaDataType(at::ScalarType dt) {
|
||||
case at::ScalarType::QUInt8:
|
||||
return data_type::u8;
|
||||
default:
|
||||
TORCH_CHECK(false, "Not support data type ", dt);
|
||||
// If a dtype is unsupported, oneDNN Graph will make that op a wildcard in
|
||||
// the graph construction stage. Then when we would execute oneDNN Graph
|
||||
// kernels pertaining to oneDNN Graph partitions, such an op would not be
|
||||
// inside a oneDNN Graph partition, so we would not encounter inputs with
|
||||
// unsupported dtypes at the time of executing compiled partitions.
|
||||
return data_type::undef;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,9 +68,6 @@ struct LlgaTensorDesc {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: llga need set input/output type constraints while it seems that we
|
||||
// cannot get the dtype during compile time, hard-coded to fp32 for now to be
|
||||
// able to add_op
|
||||
LlgaTensorDesc(const torch::jit::Value* v)
|
||||
: LlgaTensorDesc(
|
||||
v->unique(),
|
||||
@ -81,6 +78,10 @@ struct LlgaTensorDesc {
|
||||
if (v->type()->isSubtypeOf(TensorType::get())) {
|
||||
auto tt = v->type()->cast<TensorType>();
|
||||
|
||||
if (tt->scalarType()) {
|
||||
dtype_ = getLlgaDataType(tt->scalarType().value());
|
||||
}
|
||||
|
||||
auto sizes = tt->sizes();
|
||||
if (sizes.sizes()) {
|
||||
for (auto d : *sizes.sizes()) {
|
||||
@ -99,6 +100,8 @@ struct LlgaTensorDesc {
|
||||
|
||||
LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const;
|
||||
|
||||
desc::data_type getLlgaDataType(at::ScalarType dt) const;
|
||||
|
||||
at::ScalarType aten_scalar_type() const;
|
||||
|
||||
const std::vector<int64_t>& sizes() const {
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Pytorch - oneDNN Graph API Bridge
|
||||
This integration will add the infrastructure of a new PyTorch JIT graph fuser based on [oneDNN Graph API](https://spec.oneapi.io/onednn-graph/latest/programming_model.html), which provides a flexible API for aggressive fusion. The current preview4 version supports fusion for FP32 inference. Currently, the speedup is achieved for static shapes,
|
||||
although we'd soon add dynamic-shape support. When oneDNN Graph is enabled, weights are cached, as they're constant during inference.
|
||||
This is a PyTorch JIT graph fuser based on [oneDNN Graph API](https://spec.oneapi.io/onednn-graph/latest/programming_model.html), which provides a flexible API for aggressive fusion. Float & BFloat16 inference is supported. However, BFloat16 only performs well on Intel Xeon Cooper Lake platform & beyond, as they have native BFloat16 support. Also, currently, PyTorch has divergent AMP support in JIT & eager modes, so one should disable JIT AMP support & leverage eager mode AMP support to use BFloat16. Please refer to the BFloat16 example below.
|
||||
|
||||
Currently, speedup is achieved only for static shapes, although we'd soon add dynamic-shape support. When oneDNN Graph is enabled, weights are cached, as they're constant during inference.
|
||||
|
||||
## Graph Optimization
|
||||
We have registered optimization passes in the custom pre-passes set of PyTorch:
|
||||
@ -84,7 +85,7 @@ To map another op to oneDNN Graph, you should add an entry for it in in createOp
|
||||
If it has an inplace variant, you should add it in the lambda being passed to RemoveTensorMutation in
|
||||
torch/csrc/jit/codegen/onednn/interface.cpp. You might also want to add it to canFuseNode in torch/csrc/jit/codegen/onednn/register_interface.cpp.
|
||||
|
||||
## How to use
|
||||
## Example with Float
|
||||
|
||||
|
||||
```python
|
||||
@ -106,3 +107,25 @@ with torch.no_grad():
|
||||
# oneDNN graph fusion will be trigerred during runtime
|
||||
output = model(images)
|
||||
```
|
||||
|
||||
## Example with BFloat16
|
||||
|
||||
```python
|
||||
# Assuming we have a model of the name 'model'
|
||||
|
||||
example_input = torch.rand(1, 3, 224, 224)
|
||||
|
||||
# enable oneDNN Graph
|
||||
torch.jit.enable_onednn_fusion(True)
|
||||
# Disable AMP for JIT
|
||||
torch._C._jit_set_autocast_mode(False)
|
||||
with torch.no_grad(), torch.cpu.amp.autocast():
|
||||
model = torch.jit.trace(model, (example_input))
|
||||
model = torch.jit.freeze(model)
|
||||
# 2 warm-ups (2 for tracing/scripting with an example, 3 without an example)
|
||||
model(example_input)
|
||||
model(example_input)
|
||||
|
||||
# speedup would be observed in subsequent runs.
|
||||
model(example_input)
|
||||
```
|
||||
|
65
torch/csrc/jit/codegen/onednn/decompose_silu.cpp
Normal file
65
torch/csrc/jit/codegen/onednn/decompose_silu.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/operator.h>
|
||||
|
||||
#include <ATen/code_template.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
bool shouldDecomposeSilu(Node* node) {
|
||||
if (node->kind() != aten::silu) {
|
||||
return false;
|
||||
}
|
||||
auto inputToSilu = node->input(0)->node();
|
||||
if (inputToSilu->kind() == aten::_convolution) {
|
||||
// TODO: remove transpose check once the bridge supported ConvTranspose
|
||||
bool transposed = Operator::Bool(inputToSilu, 6);
|
||||
return !transposed;
|
||||
}
|
||||
if (inputToSilu->kind() == aten::linear) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeSilu(Node* node) {
|
||||
if (shouldDecomposeSilu(node)) {
|
||||
auto dtype = node->input(0)->type()->expect<TensorType>();
|
||||
|
||||
WithInsertPoint guard(node);
|
||||
auto g = node->owningGraph();
|
||||
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
|
||||
sigmoid->setType(dtype);
|
||||
|
||||
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
|
||||
mul->setType(dtype);
|
||||
|
||||
node->output()->replaceAllUsesWith(mul);
|
||||
}
|
||||
}
|
||||
|
||||
static void DecomposeSilu(Block* block) {
|
||||
for (auto node : block->nodes()) {
|
||||
for (auto sub : node->blocks()) {
|
||||
DecomposeSilu(sub);
|
||||
}
|
||||
|
||||
if (node->kind() == aten::silu) {
|
||||
DecomposeSilu(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
|
||||
DecomposeSilu(graph->block());
|
||||
EliminateDeadCode(graph);
|
||||
}
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
15
torch/csrc/jit/codegen/onednn/decompose_silu.h
Normal file
15
torch/csrc/jit/codegen/onednn/decompose_silu.h
Normal file
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -40,7 +40,7 @@ Operator makeWildcardOp(Node* node) {
|
||||
auto o = Operator(node, opkind::Wildcard);
|
||||
// wildcard op contains only topology info
|
||||
for (size_t i = 0; i < node->inputs().size(); i++) {
|
||||
o.setInput(i);
|
||||
o.setInput(static_cast<size_t>(NULL), i);
|
||||
}
|
||||
for (size_t i = 0; i < node->outputs().size(); i++) {
|
||||
o.setOutput(i);
|
||||
@ -56,223 +56,257 @@ Operator makeWildcardOp(Node* node) {
|
||||
return makeWildcardOp(node); \
|
||||
}
|
||||
|
||||
Operator makeEltwiseOp(Node* node, opkind kind) {
|
||||
return Operator(node, kind).setInput(0).setOutput(0);
|
||||
Operator LlgaGraphHelper::makeEltwiseOp(Node* node, opkind kind) {
|
||||
return Operator(node, kind).setInput(0).setOutput(dnnl_graph_, 0);
|
||||
}
|
||||
|
||||
Operator makeBinaryOp(Node* node, opkind kind) {
|
||||
Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) {
|
||||
REQUIRE(
|
||||
node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
|
||||
node->input(1)->type()->isSubtypeOf(TensorType::get()))
|
||||
return Operator(node, kind).setInput(0, 1).setOutput(0);
|
||||
return Operator(node, kind).setInput(0, 1).setOutput(dnnl_graph_, 0);
|
||||
}
|
||||
|
||||
// Map a PyTorch op to its corresponding oneDNN Graph op.
|
||||
// If mapping isn't possible, then create a wildcard op instead.
|
||||
// The mapping is done as per oneDNN Graph op schema defined in
|
||||
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
|
||||
Operator createOperator(Node* node) {
|
||||
switch (node->kind()) {
|
||||
case aten::conv2d: {
|
||||
fixConvOptionalBias(node);
|
||||
return Operator(node, opkind::Convolution)
|
||||
.setInput(0, 1, 2)
|
||||
.setOutput(0)
|
||||
.setAttr("strides", Operator::Ints, 3)
|
||||
.setAttr("pads_begin", Operator::Ints, 4)
|
||||
.setAttr("pads_end", Operator::Ints, 4)
|
||||
.setAttr("dilations", Operator::Ints, 5)
|
||||
.setAttr("groups", Operator::Int, 6)
|
||||
.setAttr("filter_format", std::string("OIX"))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
}
|
||||
|
||||
case aten::_convolution: {
|
||||
bool transposed = toIValue(node->namedInput("transposed"))->toBool();
|
||||
REQUIRE(!transposed);
|
||||
|
||||
return Operator(node, opkind::Convolution)
|
||||
.setInput(0, 1, 2)
|
||||
.setOutput(0)
|
||||
.setAttr("strides", Operator::Ints, 3)
|
||||
.setAttr("pads_begin", Operator::Ints, 4)
|
||||
.setAttr("pads_end", Operator::Ints, 4)
|
||||
.setAttr("dilations", Operator::Ints, 5)
|
||||
.setAttr("groups", Operator::Int, 8)
|
||||
.setAttr("filter_format", std::string("OIX"))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
}
|
||||
|
||||
case aten::batch_norm: {
|
||||
auto training = toIValue(node->namedInput("training"));
|
||||
REQUIRE(
|
||||
training.has_value()); // cannot get training status in script mode
|
||||
REQUIRE(!training->toBool()); // TODO: support bn training
|
||||
Operator LlgaGraphHelper::createOperator(Node* node) {
|
||||
auto nodeKind = node->kind();
|
||||
// we're using an if-else clause instead of a switch staement
|
||||
// because we would soon be adding custom ops with function schemas.
|
||||
// We would have to use Symbol::fromQualString at that time anyway,
|
||||
// but we are okay with this choice, since this code is not in the hot-path.
|
||||
if (nodeKind == Symbol::fromQualString("aten::conv2d")) {
|
||||
fixConvOptionalBias(node);
|
||||
return Operator(node, opkind::Convolution)
|
||||
.setInput(0, 1, 2)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("strides", Operator::Ints, 3)
|
||||
.setAttr("pads_begin", Operator::Ints, 4)
|
||||
.setAttr("pads_end", Operator::Ints, 4)
|
||||
.setAttr("dilations", Operator::Ints, 5)
|
||||
.setAttr("groups", Operator::Int, 6)
|
||||
.setAttr("filter_format", std::string("OIX"))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
} else if (
|
||||
(nodeKind == Symbol::fromQualString("aten::_convolution")) ||
|
||||
(nodeKind == Symbol::fromQualString("aten::convolution"))) {
|
||||
bool transposed = toIValue(node->namedInput("transposed"))->toBool();
|
||||
REQUIRE(!transposed);
|
||||
return Operator(node, opkind::Convolution)
|
||||
.setInput(0, 1, 2)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("strides", Operator::Ints, 3)
|
||||
.setAttr("pads_begin", Operator::Ints, 4)
|
||||
.setAttr("pads_end", Operator::Ints, 4)
|
||||
.setAttr("dilations", Operator::Ints, 5)
|
||||
.setAttr("groups", Operator::Int, 8)
|
||||
.setAttr("filter_format", std::string("OIX"))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::batch_norm")) {
|
||||
auto training = toIValue(node->namedInput("training"));
|
||||
REQUIRE(training.has_value()); // cannot get training status in script mode
|
||||
if (!training->toBool()) {
|
||||
return Operator(node, opkind::BatchNormInference)
|
||||
.setInput(0, 1, 2, 3, 4)
|
||||
.setOutput(0)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("epsilon", Operator::Float, 7)
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
}
|
||||
|
||||
case aten::layer_norm: {
|
||||
auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
|
||||
REQUIRE(normalized_shape->toIntList().size() == 1);
|
||||
return Operator(node, opkind::LayerNorm)
|
||||
.setInput(0, 2, 3)
|
||||
.setOutput(0)
|
||||
.setAttr("epsilon", Operator::Float, 4)
|
||||
.setAttr("keep_stats", false);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::layer_norm")) {
|
||||
auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
|
||||
REQUIRE(normalized_shape->toIntList().size() == 1);
|
||||
return Operator(node, opkind::LayerNorm)
|
||||
.setInput(0, 2, 3)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("epsilon", Operator::Float, 4)
|
||||
.setAttr("keep_stats", false);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::addmm")) {
|
||||
auto alpha = toIValue(node->namedInput("alpha"));
|
||||
auto beta = toIValue(node->namedInput("beta"));
|
||||
if (alpha.has_value() && beta.has_value()) {
|
||||
if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 1.0)) {
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(1, 2, 0)
|
||||
.setOutput(dnnl_graph_, 0);
|
||||
} else if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 0.0)) {
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(1, 2)
|
||||
.setOutput(dnnl_graph_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
case aten::addmm: {
|
||||
auto alpha = toIValue(node->namedInput("alpha"));
|
||||
auto beta = toIValue(node->namedInput("beta"));
|
||||
REQUIRE(
|
||||
alpha.has_value() && beta.has_value() && (alpha->toDouble() == 1.0) &&
|
||||
(beta->toDouble() == 1.0));
|
||||
return Operator(node, opkind::MatMul).setInput(1, 2, 0).setOutput(0);
|
||||
}
|
||||
|
||||
case aten::add:
|
||||
return makeBinaryOp(node, opkind::Add);
|
||||
|
||||
case aten::mul:
|
||||
return makeBinaryOp(node, opkind::Multiply);
|
||||
|
||||
case aten::tanh:
|
||||
return makeEltwiseOp(node, opkind::Tanh);
|
||||
|
||||
case aten::relu:
|
||||
return makeEltwiseOp(node, opkind::ReLU);
|
||||
|
||||
case aten::elu:
|
||||
return makeEltwiseOp(node, opkind::Elu)
|
||||
.setAttr("alpha", Operator::Float, 1);
|
||||
|
||||
case aten::sigmoid:
|
||||
return makeEltwiseOp(node, opkind::Sigmoid);
|
||||
case aten::gelu:
|
||||
return makeEltwiseOp(node, opkind::GELU);
|
||||
|
||||
case aten::sqrt:
|
||||
return makeEltwiseOp(node, opkind::Sqrt);
|
||||
|
||||
case aten::abs:
|
||||
return makeEltwiseOp(node, opkind::Abs);
|
||||
|
||||
case aten::square:
|
||||
return makeEltwiseOp(node, opkind::Square);
|
||||
|
||||
case aten::hardtanh:
|
||||
return makeEltwiseOp(node, opkind::HardTanh)
|
||||
.setAttr("min", Operator::Float, 1)
|
||||
.setAttr("max", Operator::Float, 2);
|
||||
|
||||
case aten::relu6:
|
||||
return makeEltwiseOp(node, opkind::HardTanh)
|
||||
.setAttr("min", 0.f)
|
||||
.setAttr("max", 6.f);
|
||||
|
||||
case aten::softmax: {
|
||||
auto axis = toIValue(node->namedInput("dim"))->toInt();
|
||||
return Operator(node, opkind::SoftMax)
|
||||
.setInput(0)
|
||||
.setOutput(0)
|
||||
.setAttr("axis", axis);
|
||||
}
|
||||
|
||||
case aten::cat: {
|
||||
auto o = Operator(node, opkind::Concat);
|
||||
REQUIRE(
|
||||
node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
|
||||
REQUIRE(node->namedInput("tensors")->uses().size() == 1);
|
||||
REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
|
||||
// aten::cat needs a special handling since it takes a Tensor[] as input.
|
||||
// We set the inputs of ListConstruct as the inputs of cat.
|
||||
//
|
||||
// Pytorch IR: LLGA sees:
|
||||
// %a %b %c %dim %a %b %c
|
||||
// \ | / | \ | /
|
||||
// prim::ListConstruct prim::Constant llga::Concat[axis=%dim]
|
||||
// \ /
|
||||
// aten::cat
|
||||
auto listConstruct = node->input(0)->node();
|
||||
for (auto input : listConstruct->inputs())
|
||||
o.setInputValue(input);
|
||||
return o.setOutput(0).setAttr("axis", Operator::Int, 1);
|
||||
}
|
||||
|
||||
case aten::max_pool2d: {
|
||||
REQUIRE(
|
||||
node->namedInput("kernel_size")->node()->kind() == prim::Constant);
|
||||
|
||||
auto rounding_type =
|
||||
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
|
||||
return Operator(node, opkind::MaxPool)
|
||||
.setInput(0)
|
||||
.setOutput(0)
|
||||
.setAttr("kernel", Operator::Ints, 1)
|
||||
.setAttr("strides", Operator::Ints, 2)
|
||||
.setAttr("pads_begin", Operator::Ints, 3)
|
||||
.setAttr("pads_end", Operator::Ints, 3)
|
||||
.setAttr("dilations", Operator::Ints, 4)
|
||||
.setAttr("rounding_type", std::string(rounding_type))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
}
|
||||
|
||||
case aten::avg_pool2d: {
|
||||
// TODO: do we need add checks for all Constants?
|
||||
REQUIRE(
|
||||
node->namedInput("kernel_size")->node()->kind() == prim::Constant);
|
||||
auto rounding_type =
|
||||
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
|
||||
auto divisor_override = toIValue(node->namedInput("divisor_override"));
|
||||
REQUIRE(divisor_override->isNone());
|
||||
return Operator(node, opkind::AvgPool)
|
||||
.setInput(0)
|
||||
.setOutput(0)
|
||||
.setAttr("kernel", Operator::Ints, 1)
|
||||
.setAttr("strides", Operator::Ints, 2)
|
||||
.setAttr("pads_begin", Operator::Ints, 3)
|
||||
.setAttr("pads_end", Operator::Ints, 3)
|
||||
.setAttr("exclude_pad", !Operator::Bool(node, 5))
|
||||
.setAttr("rounding_type", std::string(rounding_type))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
}
|
||||
|
||||
case aten::matmul: {
|
||||
auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
|
||||
auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
|
||||
// TODO: support all shape combinations
|
||||
REQUIRE(
|
||||
(dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
|
||||
(dim0 == 3 && dim1 == 2));
|
||||
} // fall through
|
||||
case aten::mm: {
|
||||
return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
|
||||
}
|
||||
|
||||
case aten::linear: {
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(0, 1, 2)
|
||||
.setOutput(0)
|
||||
.setAttr("transpose_b", true);
|
||||
}
|
||||
|
||||
default:
|
||||
return makeWildcardOp(node);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::add"))
|
||||
return makeBinaryOp(node, opkind::Add);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::mul"))
|
||||
return makeBinaryOp(node, opkind::Multiply);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::div"))
|
||||
return makeBinaryOp(node, opkind::Divide);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::tanh"))
|
||||
return makeEltwiseOp(node, opkind::Tanh);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::relu"))
|
||||
return makeEltwiseOp(node, opkind::ReLU);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::elu"))
|
||||
return makeEltwiseOp(node, opkind::Elu)
|
||||
.setAttr("alpha", Operator::Float, 1);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::sigmoid"))
|
||||
return makeEltwiseOp(node, opkind::Sigmoid);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::gelu"))
|
||||
return makeEltwiseOp(node, opkind::GELU);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::round"))
|
||||
return makeEltwiseOp(node, opkind::Round);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::exp"))
|
||||
return makeEltwiseOp(node, opkind::Exp);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::sqrt"))
|
||||
return makeEltwiseOp(node, opkind::Sqrt);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::abs"))
|
||||
return makeEltwiseOp(node, opkind::Abs);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::square"))
|
||||
return makeEltwiseOp(node, opkind::Square);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::clamp")) {
|
||||
// PyTorch API already checks that both min & max are not None.
|
||||
// But we can check it nevertheless.
|
||||
auto clamp_min = toIValue(node->input(1));
|
||||
auto clamp_max = toIValue(node->input(2));
|
||||
REQUIRE(!(clamp_max->isNone() && clamp_min->isNone()));
|
||||
auto clamp_min_value = (clamp_min->isNone())
|
||||
? -std::numeric_limits<float>::infinity()
|
||||
: Operator::ScalarToFloat(node, 1);
|
||||
auto clamp_max_value = (clamp_max->isNone())
|
||||
? std::numeric_limits<float>::infinity()
|
||||
: Operator::ScalarToFloat(node, 2);
|
||||
return makeEltwiseOp(node, opkind::Clamp)
|
||||
.setAttr("min", clamp_min_value)
|
||||
.setAttr("max", clamp_max_value);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::hardtanh")) {
|
||||
return makeEltwiseOp(node, opkind::Clamp)
|
||||
.setAttr("min", Operator::ScalarToFloat, 1)
|
||||
.setAttr("max", Operator::ScalarToFloat, 2);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::hardswish"))
|
||||
return makeEltwiseOp(node, opkind::HardSwish);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::log"))
|
||||
return makeEltwiseOp(node, opkind::Log);
|
||||
else if (nodeKind == Symbol::fromQualString("aten::leaky_relu")) {
|
||||
return makeEltwiseOp(node, opkind::LeakyReLU)
|
||||
.setAttr("alpha", Operator::Float, 1);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::relu6")) {
|
||||
return makeEltwiseOp(node, opkind::Clamp)
|
||||
.setAttr("min", 0.f)
|
||||
.setAttr("max", 6.f);
|
||||
} else if (
|
||||
(nodeKind == Symbol::fromQualString("aten::softmax")) ||
|
||||
(nodeKind == Symbol::fromQualString("aten::_softmax"))) {
|
||||
auto axis = toIValue(node->namedInput("dim"))->toInt();
|
||||
return Operator(node, opkind::SoftMax)
|
||||
.setInput(0)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("axis", axis);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::_log_softmax")) {
|
||||
auto axis = toIValue(node->namedInput("dim"))->toInt();
|
||||
return Operator(node, opkind::LogSoftmax)
|
||||
.setInput(0)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("axis", axis);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::cat")) {
|
||||
auto o = Operator(node, opkind::Concat);
|
||||
REQUIRE(node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
|
||||
REQUIRE(node->namedInput("tensors")->uses().size() == 1);
|
||||
REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
|
||||
// aten::cat needs a special handling since it takes a Tensor[] as input.
|
||||
// We set the inputs of ListConstruct as the inputs of cat.
|
||||
//
|
||||
// Pytorch IR: LLGA sees:
|
||||
// %a %b %c %dim %a %b %c
|
||||
// \ | / | \ | /
|
||||
// prim::ListConstruct prim::Constant llga::Concat[axis=%dim]
|
||||
// \ /
|
||||
// aten::cat
|
||||
auto listConstruct = node->input(0)->node();
|
||||
for (auto input : listConstruct->inputs())
|
||||
o.setInputValue(input);
|
||||
return o.setOutput(dnnl_graph_, 0).setAttr("axis", Operator::Int, 1);
|
||||
} else if (
|
||||
(nodeKind == Symbol::fromQualString("aten::max_pool2d")) ||
|
||||
(nodeKind == Symbol::fromQualString("aten::max_pool2d_with_indices"))) {
|
||||
// Currently, LLGA lacks support to create indices mask.
|
||||
// Once it's supported, max_pool2d_with_indices should be mapped differently
|
||||
REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
|
||||
auto rounding_type =
|
||||
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
|
||||
return Operator(node, opkind::MaxPool)
|
||||
.setInput(0)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("kernel", Operator::Ints, 1)
|
||||
.setAttr("strides", Operator::Ints, 2)
|
||||
.setAttr("pads_begin", Operator::Ints, 3)
|
||||
.setAttr("pads_end", Operator::Ints, 3)
|
||||
.setAttr("dilations", Operator::Ints, 4)
|
||||
.setAttr("rounding_type", std::string(rounding_type))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::avg_pool2d")) {
|
||||
// TODO: do we need add checks for all Constants?
|
||||
REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
|
||||
auto rounding_type =
|
||||
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
|
||||
auto divisor_override = toIValue(node->namedInput("divisor_override"));
|
||||
REQUIRE(divisor_override->isNone());
|
||||
return Operator(node, opkind::AvgPool)
|
||||
.setInput(0)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("kernel", Operator::Ints, 1)
|
||||
.setAttr("strides", Operator::Ints, 2)
|
||||
.setAttr("pads_begin", Operator::Ints, 3)
|
||||
.setAttr("pads_end", Operator::Ints, 3)
|
||||
.setAttr("exclude_pad", !Operator::Bool(node, 5))
|
||||
.setAttr("rounding_type", std::string(rounding_type))
|
||||
.setAttr("data_format", std::string("NCX"));
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::matmul")) {
|
||||
auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
|
||||
auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
|
||||
// TODO: support all shape combinations
|
||||
REQUIRE(
|
||||
(dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
|
||||
(dim0 == 3 && dim1 == 2));
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(0, 1)
|
||||
.setOutput(dnnl_graph_, 0);
|
||||
} // fall through
|
||||
else if (nodeKind == Symbol::fromQualString("aten::mm")) {
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(0, 1)
|
||||
.setOutput(dnnl_graph_, 0);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::bmm")) {
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(0, 1)
|
||||
.setOutput(dnnl_graph_, 0);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::linear")) {
|
||||
return Operator(node, opkind::MatMul)
|
||||
.setInput(0, 1, 2)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("transpose_b", true);
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::permute")) {
|
||||
REQUIRE(aliasDb_->hasInputWriters(node) == false);
|
||||
return Operator(node, opkind::StaticTranspose)
|
||||
.setInput(0)
|
||||
.setOutput(dnnl_graph_, 0)
|
||||
.setAttr("order", toIValue(node->namedInput("dims"))->toIntVector());
|
||||
} else if (nodeKind == Symbol::fromQualString("aten::contiguous")) {
|
||||
// Contiguous should only be mapped to oneDNN Graph if the destination
|
||||
// memory-layout is different than the source memory-format
|
||||
// Strides would be different, but shape would be same
|
||||
auto typeOfInput = node->input(0)->type()->expect<TensorType>();
|
||||
auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
|
||||
auto inputStrides = typeOfInput->strides().concrete_sizes();
|
||||
auto outputStrides = typeOfOutput->strides().concrete_sizes();
|
||||
REQUIRE(inputStrides != outputStrides);
|
||||
return Operator(node, opkind::Reorder)
|
||||
.setInput(0)
|
||||
.setOutput(dnnl_graph_, 0);
|
||||
}
|
||||
GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
|
||||
return makeWildcardOp(node);
|
||||
}
|
||||
|
||||
dnnl::graph::op createLlgaOp(Node* node) {
|
||||
return createOperator(node).llgaOp();
|
||||
}
|
||||
|
||||
bool isSupported(Node* node) {
|
||||
return createOperator(node).kind() != opkind::Wildcard;
|
||||
};
|
||||
|
||||
DeviceType inferDeviceFromValue(Value* v) {
|
||||
auto tt = v->type()->cast<TensorType>();
|
||||
if (!tt) {
|
||||
@ -336,13 +370,26 @@ bool checkInputCompatibility(Node* node) {
|
||||
return false;
|
||||
}
|
||||
auto dtype = tensor.scalar_type();
|
||||
if ((dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
|
||||
if ((dtype != at::ScalarType::BFloat16) &&
|
||||
(dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
|
||||
// We've allowed Long dtype here although oneDNN Graph does not support
|
||||
// Long dtype because oneDNN Graph will end up not handling the op that
|
||||
// has an input with Long dtype, so it'd be handled by PyTorch.
|
||||
return false;
|
||||
}
|
||||
} else if (inputIValue.isScalar()) {
|
||||
if (inputIValue.isComplexDouble()) {
|
||||
return false;
|
||||
}
|
||||
} else if (input->type()->isSubtypeOf(TensorType::get())) {
|
||||
auto input_typeptr = input->type()->cast<TensorType>();
|
||||
if (input_typeptr->scalarType().has_value()) {
|
||||
at::ScalarType dtype = input_typeptr->scalarType().value();
|
||||
if ((dtype != at::ScalarType::Float) &&
|
||||
(dtype != at::ScalarType::BFloat16)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@ -353,19 +400,21 @@ LlgaGraphHelper::LlgaGraphHelper(
|
||||
dnnl::graph::partition::policy policy) {
|
||||
auto deviceType = inferDevice(graph);
|
||||
auto engineKind = getLlgaEngineKind(deviceType);
|
||||
dnnl::graph::graph g{engineKind};
|
||||
|
||||
dnnl_graph_ =
|
||||
std::unique_ptr<dnnl::graph::graph>(new dnnl::graph::graph(engineKind));
|
||||
aliasDb_ = std::make_unique<torch::jit::AliasDb>(graph);
|
||||
GRAPH_DEBUG("Constructing LLGA graph");
|
||||
// TODO: select nodes in top-level block for now
|
||||
for (auto* node : graph->block()->nodes()) {
|
||||
auto op = createLlgaOp(node);
|
||||
auto kindOfNode = node->kind();
|
||||
GRAPH_DEBUG("Trying to add ", kindOfNode.toQualString());
|
||||
if (checkInputCompatibility(node)) {
|
||||
g.add_op(op);
|
||||
auto op = createOperator(node);
|
||||
dnnl_graph_->add_op(op.llgaOp());
|
||||
GRAPH_DEBUG(" Added node ", kindOfNode.toQualString());
|
||||
} else {
|
||||
GRAPH_DEBUG("The backend failed to add node ", kindOfNode.toQualString());
|
||||
g.add_op(makeWildcardOp(node).llgaOp());
|
||||
GRAPH_DEBUG("Incompatible inputs for ", kindOfNode.toQualString());
|
||||
dnnl_graph_->add_op(makeWildcardOp(node).llgaOp());
|
||||
}
|
||||
|
||||
for (Value* input : node->inputs()) {
|
||||
@ -374,7 +423,8 @@ LlgaGraphHelper::LlgaGraphHelper(
|
||||
}
|
||||
|
||||
GRAPH_DEBUG("Get Partitions");
|
||||
std::vector<dnnl::graph::partition> partitions = g.get_partitions(policy);
|
||||
std::vector<dnnl::graph::partition> partitions =
|
||||
dnnl_graph_->get_partitions(policy);
|
||||
// excluded unsupported Wildcard partitions
|
||||
for (auto& partition : partitions) {
|
||||
if (partition.is_supported()) {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <oneapi/dnnl/dnnl_graph.hpp>
|
||||
#include <torch/csrc/jit/codegen/onednn/operator.h>
|
||||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
@ -60,13 +61,20 @@ class LlgaGraphHelper {
|
||||
|
||||
static bool isLlgaSubgraph(const Node* node);
|
||||
|
||||
Operator makeEltwiseOp(Node* node, dnnl::graph::op::kind kind);
|
||||
|
||||
Operator makeBinaryOp(Node* node, dnnl::graph::op::kind kind);
|
||||
|
||||
std::vector<dnnl::graph::partition> getPartitions() const;
|
||||
|
||||
std::map<size_t, Value*> getTensorIdToValue() const;
|
||||
|
||||
Operator createOperator(Node* node);
|
||||
|
||||
private:
|
||||
size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const;
|
||||
|
||||
std::unique_ptr<dnnl::graph::graph> dnnl_graph_ = nullptr;
|
||||
std::unique_ptr<torch::jit::AliasDb> aliasDb_ = nullptr;
|
||||
OpPartitionMap opToOwningPartition_;
|
||||
std::vector<dnnl::graph::partition> partitions_;
|
||||
std::map<size_t, Value*>
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <oneapi/dnnl/dnnl_graph.hpp>
|
||||
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
|
||||
@ -56,11 +57,19 @@ void fuseGraph(std::shared_ptr<Graph>& g) {
|
||||
aten::hardtanh_,
|
||||
aten::abs_,
|
||||
aten::square_,
|
||||
};
|
||||
aten::pow_,
|
||||
aten::leaky_relu_,
|
||||
aten::round_,
|
||||
aten::exp_,
|
||||
aten::abs_,
|
||||
aten::hardswish_,
|
||||
aten::silu_};
|
||||
return supportedOps.count(nodeToFunctionalize->kind()) != 0;
|
||||
});
|
||||
RemoveListMutation(g);
|
||||
GRAPH_DUMP("After mutation removal. Before PrepareBinaryForLLGA", g);
|
||||
GRAPH_DUMP("After mutation removal. Before DecomposeSiluForLlga", g);
|
||||
DecomposeSiluForLLGA(g);
|
||||
GRAPH_DUMP("After DecomposeSiluForLlga. Before PrepareBinaryForLLGA", g);
|
||||
PrepareBinaryForLLGA(g);
|
||||
GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
|
||||
DeferSizeCheck(g);
|
||||
|
@ -149,35 +149,42 @@ std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
|
||||
|
||||
if (spec.reuses_input_tensor()) {
|
||||
#ifdef GRAPH_DEBUG_ENABLED
|
||||
GRAPH_DEBUG("oneDNN Graph would perform inplace computation");
|
||||
GRAPH_DEBUG("inplace computation - input tensor would be reused");
|
||||
#endif
|
||||
auto inputTensor = inputs[spec.get_input_tensor_index()];
|
||||
auto dataType = spec.dtype();
|
||||
if (C10_UNLIKELY(!useOpaqueLayout(i) && inputTensor.is_mkldnn())) {
|
||||
// If the input tensor was between two partitions, it would've been
|
||||
// wrapped with LlgaTensorImpl. But if it's being reused as the output
|
||||
// tensor which is not between two partitions, then we'd have to re-wrap
|
||||
// it with TensorImpl, as it'd be fed into a PyTorch op.
|
||||
if (inputTensor.is_mkldnn()) {
|
||||
auto dataType = spec.dtype();
|
||||
if (C10_UNLIKELY(!useOpaqueLayout(i))) {
|
||||
// If the input tensor was between two partitions, it would've been
|
||||
// wrapped with LlgaTensorImpl. But if it's being reused as the output
|
||||
// tensor, which is not between two partitions, then we'd have to
|
||||
// re-wrap it with a sub-class of TensorImpl, as it'd be fed into a
|
||||
// PyTorch op.
|
||||
#ifdef GRAPH_DEBUG_ENABLED
|
||||
GRAPH_DEBUG("Rewrap tensor");
|
||||
GRAPH_DEBUG("rewrap tensors");
|
||||
#endif
|
||||
auto llgaImpl =
|
||||
static_cast<LlgaTensorImpl*>(inputTensor.unsafeGetTensorImpl());
|
||||
switch (dataType) {
|
||||
case data_type::f32:
|
||||
case data_type::bf16:
|
||||
inputTensor = LlgaTensorImpl::llga_to_aten_tensor(llgaImpl);
|
||||
break;
|
||||
case data_type::s32:
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Invalid data type ", static_cast<size_t>(dataType));
|
||||
auto llgaImpl =
|
||||
static_cast<LlgaTensorImpl*>(inputTensor.unsafeGetTensorImpl());
|
||||
switch (dataType) {
|
||||
case data_type::f32:
|
||||
case data_type::bf16:
|
||||
inputTensor = LlgaTensorImpl::llga_to_aten_tensor(llgaImpl);
|
||||
break;
|
||||
case data_type::s32:
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Invalid data type ", static_cast<size_t>(dataType));
|
||||
}
|
||||
}
|
||||
outputs.push_back(inputTensor);
|
||||
runOutputs.push_back(
|
||||
{spec.logical_tensor(),
|
||||
Engine::getEngine(),
|
||||
inputTensor.data_ptr()});
|
||||
return std::make_tuple(runInputs, runOutputs);
|
||||
}
|
||||
outputs.push_back(inputTensor);
|
||||
runOutputs.push_back(
|
||||
{spec.logical_tensor(), Engine::getEngine(), inputTensor.data_ptr()});
|
||||
} else if (useOpaqueLayout(i)) {
|
||||
}
|
||||
if (useOpaqueLayout(i)) {
|
||||
// Wrap tensors between partitions with LlgaTensorImpl wrapper, so that we
|
||||
// can bypass guard-check, as strides would be different than those
|
||||
// expected.
|
||||
|
@ -14,9 +14,25 @@ class Operator {
|
||||
Operator(const Node* node, dnnl::graph::op::kind kind)
|
||||
: n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {}
|
||||
|
||||
// Returns output index if the Value is a graph output.
|
||||
// Otherwise returns -1
|
||||
int32_t graphOutputIdx(Value* v) {
|
||||
int32_t i = 0;
|
||||
for (const Value* output : v->owningGraph()->outputs()) {
|
||||
if (v == output) {
|
||||
return i;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
Operator& setInputValue(Value* v) {
|
||||
if (v->mustNotBeNone())
|
||||
o.add_input(createLogicalTensor(v));
|
||||
if (v->mustNotBeNone()) {
|
||||
if (v->type()->kind() == c10::TensorType::Kind) {
|
||||
o.add_input(createLogicalTensor(v));
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -31,19 +47,50 @@ class Operator {
|
||||
}
|
||||
|
||||
Operator& setOutputValue(Value* v) {
|
||||
if (v->mustNotBeNone())
|
||||
if (v->mustNotBeNone()) {
|
||||
o.add_output(createLogicalTensor(v));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// setOutputValue & setOutput require a pointer to the LLGA graph, as output
|
||||
// logical tensors that are graph outputs should be connected to an End LLGA
|
||||
// op. A value of NULL can be provided for the graph pointer in order to
|
||||
// maintain the legacy functionality of this function.
|
||||
Operator& setOutputValue(Value* v, std::unique_ptr<dnnl::graph::graph>& g) {
|
||||
if (v->mustNotBeNone()) {
|
||||
auto output_tensor = createLogicalTensor(v);
|
||||
o.add_output(output_tensor);
|
||||
if (g) {
|
||||
int32_t outputIndex = graphOutputIdx(v);
|
||||
if (outputIndex != -1) {
|
||||
dnnl::graph::op newEndNode(
|
||||
LONG_MAX - outputIndex,
|
||||
dnnl::graph::op::kind::End,
|
||||
"EndNodeForGraphOutput");
|
||||
newEndNode.add_input(output_tensor);
|
||||
g->add_op(newEndNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Operator& setOutput(std::unique_ptr<dnnl::graph::graph>& g, size_t offset) {
|
||||
return setOutputValue(n->output(offset), g);
|
||||
}
|
||||
|
||||
Operator& setOutput(size_t offset) {
|
||||
return setOutputValue(n->output(offset));
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
Operator& setOutput(size_t offset, Ts... other) {
|
||||
setOutput(offset);
|
||||
return setOutput(other...);
|
||||
Operator& setOutput(
|
||||
std::unique_ptr<dnnl::graph::graph>& g,
|
||||
size_t offset,
|
||||
Ts... other) {
|
||||
setOutput(g, offset);
|
||||
return setOutput(g, other...);
|
||||
}
|
||||
|
||||
template <typename Attr>
|
||||
@ -57,6 +104,10 @@ class Operator {
|
||||
return setAttr(name, fn(n, offset));
|
||||
}
|
||||
|
||||
static float ScalarToFloat(const Node* node, size_t offset) {
|
||||
return toIValue(node->input(offset))->toScalar().to<float>();
|
||||
}
|
||||
|
||||
static std::vector<int64_t> Ints(const Node* node, size_t offset) {
|
||||
return toIValue(node->input(offset))->toIntVector();
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
#include <aten/src/ATen/core/jit_type.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
@ -14,29 +15,98 @@ bool compareConstValue(Value* v, double d) {
|
||||
(ival->isDouble() && ival->toDouble() == d));
|
||||
}
|
||||
|
||||
void mayConvertScalarInputToTensor(Node* node) {
|
||||
void handleBinaryOpInputs(Node* node) {
|
||||
// We do not handle binary ops with two scalar inputs,
|
||||
// and we assume scalar is always at the second place.
|
||||
if (node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
|
||||
(node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
|
||||
node->input(1)->type()->isSubtypeOf(IntType::get()))) {
|
||||
auto scalar = node->input(1);
|
||||
WithInsertPoint guard(node);
|
||||
auto g = node->owningGraph();
|
||||
// 42 : Scalar --> tensor(42.0) : Float([])
|
||||
auto t = g->insert(
|
||||
aten::as_tensor, {scalar}, {{"dtype", at::ScalarType::Float}});
|
||||
// add dim & stride info to IR
|
||||
c10::optional<size_t> t_dim = 1;
|
||||
auto target_type = TensorTypePtr(
|
||||
TensorType::create(at::ScalarType::Float, at::kCPU, t_dim, false));
|
||||
target_type = target_type->withSizes({1});
|
||||
t->setType(target_type);
|
||||
if (node->input(0)->type()->isSubtypeOf(TensorType::get())) {
|
||||
auto dtypeOfFirstInput =
|
||||
node->input(0)->type()->cast<TensorType>()->scalarType().value();
|
||||
if (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
|
||||
node->input(1)->type()->isSubtypeOf(IntType::get())) {
|
||||
// If a scalar is added to be a tensor, we would assume that the
|
||||
// scalar is of the same dtype as the tensor, as oneDNN graph
|
||||
// currently requires inputs of binary ops to have the same dtype.
|
||||
// We create a 1D tensor from the scalar input & "promote" its
|
||||
// dtype to that of the first input. Doing so helps us satisfy PyTorch's
|
||||
// type promotion rules.
|
||||
// Although we convert the scalar to a tensor, we still need to promote
|
||||
// types, as if the second input were still a scalar.
|
||||
// The following sample code-snippet illustrates that converting a scalar
|
||||
// input to a 1-D tensor may result in a different output dtype than would
|
||||
// otherwise have been the case.
|
||||
// clang-format off
|
||||
// >>> (1. + torch.rand([2]).half()).dtype
|
||||
// torch.float16
|
||||
// >>> (torch.tensor(1.).unsqueeze(0) + (torch.rand([2]).half())).dtype
|
||||
// torch.float32
|
||||
// clang-format on
|
||||
auto promotedDtype = dtypeOfFirstInput;
|
||||
auto scalar = node->input(1);
|
||||
WithInsertPoint guard(node);
|
||||
auto g = node->owningGraph();
|
||||
// 42 : Scalar --> tensor(42.0) : Float([])
|
||||
auto t = g->insert(aten::as_tensor, {scalar}, {{"dtype", promotedDtype}});
|
||||
// add dim & stride info to IR
|
||||
c10::optional<size_t> t_dim = 1;
|
||||
auto target_type = TensorTypePtr(
|
||||
TensorType::create(promotedDtype, at::kCPU, t_dim, false));
|
||||
target_type = target_type->withSizes({1});
|
||||
t->setType(target_type);
|
||||
|
||||
// tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1])
|
||||
auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
|
||||
unsqueezed->setType(target_type);
|
||||
node->replaceInput(1, unsqueezed);
|
||||
// tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1])
|
||||
auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
|
||||
unsqueezed->setType(target_type);
|
||||
node->replaceInput(1, unsqueezed);
|
||||
|
||||
// dtype might have changed, so needs to be updated in IR as well
|
||||
node->output()->setType(
|
||||
node->output()->type()->expect<TensorType>()->withScalarType(
|
||||
promotedDtype));
|
||||
} else if (node->input(1)->type()->isSubtypeOf(TensorType::get())) {
|
||||
// Here, both inputs are tensors, and we just wanna make sure that they
|
||||
// are the same dtype, as oneDNN Graph requires both inputs to have the
|
||||
// same dtype. We'll follow PyTorch's type-promotion rules here.
|
||||
auto second_input_typeptr = node->input(1)->type()->expect<TensorType>();
|
||||
c10::optional<at::ScalarType> second_input_type =
|
||||
second_input_typeptr->scalarType();
|
||||
if (second_input_type != c10::nullopt) {
|
||||
// dtype of the second tensor might not be available in the IR
|
||||
auto dtypeOfSecondInput = second_input_type.value();
|
||||
if (dtypeOfFirstInput != dtypeOfSecondInput) {
|
||||
// Type promotion is required
|
||||
auto promotedDtype =
|
||||
c10::promoteTypes(dtypeOfFirstInput, dtypeOfSecondInput);
|
||||
WithInsertPoint guard(node);
|
||||
auto g = node->owningGraph();
|
||||
if (promotedDtype == dtypeOfFirstInput) {
|
||||
auto to_node_output = g->insert(
|
||||
aten::to, {node->input(1)}, {{"dtype", promotedDtype}});
|
||||
to_node_output->setType(
|
||||
node->input(1)->type()->expect<TensorType>()->withScalarType(
|
||||
promotedDtype));
|
||||
node->replaceInput(1, to_node_output);
|
||||
} else {
|
||||
auto to_node_output = g->insert(
|
||||
aten::to, {node->input(0)}, {{"dtype", promotedDtype}});
|
||||
to_node_output->setType(
|
||||
node->input(0)->type()->expect<TensorType>()->withScalarType(
|
||||
promotedDtype));
|
||||
node->replaceInput(0, to_node_output);
|
||||
}
|
||||
// dtype might have changed, so needs to be updated in IR as well
|
||||
node->output()->setType(
|
||||
node->output()->type()->expect<TensorType>()->withScalarType(
|
||||
promotedDtype));
|
||||
} else {
|
||||
// both dtypes are same
|
||||
// IR info of dtypes is missing sometimes in JIT IR,
|
||||
// and we shouldn't treat those tensors as FP32 tensors by default.
|
||||
node->output()->setType(
|
||||
node->output()->type()->expect<TensorType>()->withScalarType(
|
||||
dtypeOfFirstInput));
|
||||
}
|
||||
} // end inner if block
|
||||
} // end outer if block
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,13 +116,18 @@ static void ConvertScalarToTensor(Block* block) {
|
||||
ConvertScalarToTensor(sub);
|
||||
}
|
||||
|
||||
if (node->kind() == aten::add || node->kind() == aten::mul) {
|
||||
mayConvertScalarInputToTensor(node);
|
||||
if (node->kind() == aten::add || node->kind() == aten::mul ||
|
||||
node->kind() == aten::div) {
|
||||
handleBinaryOpInputs(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mayDecomposeAdd(Node* node) {
|
||||
if (node->inputs().size() < 3) {
|
||||
return; // corner-case in BERT-mrpc that's not in line with
|
||||
// native_functions.yaml
|
||||
}
|
||||
if (toIValue(node->namedInput("alpha")).has_value()) {
|
||||
auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
|
||||
if (!alphaEqualsOne) {
|
||||
@ -60,6 +135,10 @@ void mayDecomposeAdd(Node* node) {
|
||||
auto g = node->owningGraph();
|
||||
auto mul = g->insert(
|
||||
aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
|
||||
if (node->namedInput("other")->type()->isSubtypeOf(TensorType::get())) {
|
||||
auto mulTensorTypePtr = node->namedInput("other")->type();
|
||||
mul->setType(mulTensorTypePtr);
|
||||
}
|
||||
node->replaceInput(1, mul);
|
||||
auto one = g->insertConstant(1.0);
|
||||
node->replaceInput(2, one);
|
||||
|
Reference in New Issue
Block a user