[Reland take-2] Add JIT graph fuser for oneDNN Graph API (v0.5)
Re-landing #68111 / #74596

## Description
v0.5 PR of this [RFC](https://github.com/pytorch/pytorch/issues/49444). On the basis of #50256, the improvements below are included:
* The [v0.5 release branch](https://github.com/oneapi-src/oneDNN/releases/tag/graph-v0.5) of the oneDNN Graph API is used.
* The fuser now works with the profiling graph executor. We have inserted type-check nodes to guard the profiled tensor properties.

### User API
The optimization pass is disabled by default. Users can enable it with:
```
torch.jit.enable_onednn_fusion(True)
```
`torch.jit.freeze` should be used after tracing (recommended) or scripting a model.

### Performance
The [pytorch/benchmark](https://github.com/pytorch/benchmark) tool is used to compare performance:
* SkyLake 8180 (1 socket of 28 cores):
* SkyLake 8180 (single thread):
  * By mapping hardswish to oneDNN Graph, it's 8% faster than PyTorch JIT (NNC + OFI).
  * We expect further performance gains after mapping transpose, contiguous & view to oneDNN Graph ops.

### Directory structure of the integration code
Fuser-related code is placed under:
```
torch/csrc/jit/codegen/onednn/
```
Optimization pass registration is done in:
```
torch/csrc/jit/passes/onednn_graph_fuser.h
```
CMake for the integration code is in:
```
caffe2/CMakeLists.txt
cmake/public/mkldnn.cmake
cmake/Modules/FindMKLDNN.cmake
```

## Limitations
* In this PR, we only support the PyTorch-oneDNN Graph integration on Linux. Support for Windows and macOS will be enabled as a next step.
* We have only optimized the inference use case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76622
Approved by: https://github.com/eellison
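For reference, a minimal sketch of the workflow described above (the model and input shape below are placeholders chosen for illustration, not part of this PR):

```python
import torch
import torch.nn as nn

# The oneDNN Graph fusion pass is disabled by default.
torch.jit.enable_onednn_fusion(True)

# Placeholder model: any inference-only nn.Module works here.
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU()).eval()
x = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    traced = torch.jit.trace(model, x)   # tracing is the recommended path
    traced = torch.jit.freeze(traced)    # freeze after tracing (or scripting)
    for _ in range(2):                   # warm-up runs so the profiling executor can
        traced(x)                        # record tensor properties and apply fusion
    out = traced(x)
```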
Committed by PyTorch MergeBot
parent 94fc92f288
commit 4ee29d6033

test/test_jit_llga_fuser.py (new file, 519 lines)
@@ -0,0 +1,519 @@
# Owner(s): ["module: mkldnn"]
import torch
import unittest
import itertools

import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS

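# Node kind used for subgraphs fused by the oneDNN Graph (LLGA) fuser.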
LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup'
LLGA_NOT_ENABLED = not torch._C.has_mkldnn or IS_WINDOWS or IS_MACOS


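# Warm-up runs let the profiling graph executor record tensor profiles so that
# the fusion pass can rewrite the graph before the outputs are checked.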
def warmup_forward(f, *args, profiling_count=2):
    for i in range(profiling_count):
        results = f(*args)

    return results


class JitLlgaTestCase(JitTestCase):
    def setUp(self):
        torch.jit.enable_onednn_fusion(True)

    def tearDown(self):
        torch.jit.enable_onednn_fusion(False)

    def checkTrace(self, m, x, *args, **kwargs):
        if isinstance(m, torch.nn.Module):
            m.eval()
        with torch.no_grad(), \
                torch._jit_internal._disable_emit_hooks():
            traced = torch.jit.trace(m, x)
            if isinstance(m, torch.nn.Module):
                traced = torch.jit.freeze(traced)
            warmup_forward(traced, *x)
            fwd_graph = traced.graph_for(*x)

            ref_o = m(*x)
            jit_o = traced(*x)
            self.assertEqual(jit_o, ref_o)
            return traced, fwd_graph

    def assertFused(self, graph, fused_patterns):
        for pat in fused_patterns:
            self.assertGraphContainsExactly(graph, pat, 0)


try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision')

def get_eltwise_fn(name):
    if hasattr(torch, name):
        return getattr(torch, name)
    elif hasattr(F, name):
        return getattr(F, name)
    else:
        raise NameError('Eltwise function %s not found' % name)


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestOp(JitLlgaTestCase):
    def test_conv2d(self):
        for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product(
                [7, 8],
                [8, 15],
                [7, 16],
                [3, 4],
                [0, 2],
                [1, 2],
                [1, 2],
                [1, 2],
                [True, False]):

            m = nn.Conv2d(in_channels=in_channels * g,
                          out_channels=out_channels * g,
                          kernel_size=kernel,
                          padding=padding,
                          stride=stride,
                          dilation=dilation,
                          groups=g,
                          bias=bias)

            x = torch.rand(1, in_channels * g, spatial, spatial)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_bn2d(self):
        m = nn.BatchNorm2d(32).eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        # single-op partition shouldn't be created for batch_norm
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.eltwise = eltwise_fn

            def forward(self, x):
                return self.eltwise(x)

        for eltwise in ['relu', 'gelu']:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn)
            x = torch.rand(1, 32, 28, 28)
            _, graph = self.checkTrace(m, [x])
            # single-op partition shouldn't be created.
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_max_pool2d(self):
        for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product(
                [15, 16, 17, 18, 19],
                [4, 5],
                [0, 1, 2],
                [1, 2],  # [1, 2, 4], TODO: fix issue in pad calculation
                [1],  # [1, 2], TODO: backend support for dilation
                [True, False]):

            m = nn.MaxPool2d(kernel_size=kernel,
                             stride=stride,
                             padding=padding,
                             dilation=dilation,
                             ceil_mode=ceil_mode)

            x = torch.rand(1, 4, spatial, spatial)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_avg_pool2d(self):
        for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product(
                [15, 16, 17, 18, 19],
                [4, 5],
                [0, 1, 2],
                [1, 2, 4],
                [False],  # TODO: oneDNN Graph does not fully support ceil_mode=True
                [True, False]):

            m = nn.AvgPool2d(kernel_size=kernel,
                             stride=stride,
                             padding=padding,
                             ceil_mode=ceil_mode,
                             count_include_pad=count_include_pad)

            x = torch.rand(1, 4, spatial, spatial)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_variable_kernel_avg_pool2d(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False)
                return x

        x = torch.randn(1, 1000, 1, 1)
        m = M()
        _, graph = self.checkTrace(m, [x])
        # kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP
        # TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_softmax(self):
        for dim in [-4, -3, -2, -1, 0, 1, 2, 3]:
            m = nn.Softmax(dim=dim)
            x = torch.rand(8, 12, 12, 12)
            _, graph = self.checkTrace(m, [x])
            # single-op partition shouldn't be created for softmax
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_linear(self):
        for bias in [True, False]:
            x = torch.rand(32, 28)
            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::linear'])

    def _gen_binary_inputs(self, gen_permute=True):
        for xshape, yshape in [
            [[1, 32, 28, 28], [1, 32, 28, 28]],
            [[1, 32, 28, 28], [1, 1, 28, 28]],
            [[1, 32, 28, 28], [28]],
            [[1, 32, 28, 28], [1]],

        ]:
            yield torch.rand(xshape), torch.rand(yshape)
            if gen_permute and xshape != yshape:
                yield torch.rand(yshape), torch.rand(xshape)

    def test_add(self):
        def forward_add(x, y):
            return torch.add(x, y, alpha=2)

        for x, y in self._gen_binary_inputs():
            _, graph = self.checkTrace(forward_add, [x, y])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_add_scalar(self):
        def add_scalar(x):
            return 42 + x + 3.14

        x = torch.rand(32, 32)
        _, graph = self.checkTrace(add_scalar, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_addmm(self):
        def addmm(x, y, z):
            # alpha and beta are 1, by default
            return torch.addmm(z, x, y)

        x = torch.rand(64, 32)
        y = torch.rand(32, 32)
        z = torch.rand(64, 32)
        _, graph = self.checkTrace(addmm, [x, y, z])
        # single-op partition should be created for matmul with bias.
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_mul(self):
        def forward_mul(x, y):
            return torch.mul(x, y) * 3

        for x, y in self._gen_binary_inputs():
            _, graph = self.checkTrace(forward_mul, [x, y])
            # single-op partitions shouldn't be created
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_identity_binary(self):
        def forward(x):
            return x * 1 + 0.0

        x = torch.rand(32)
        _, graph = self.checkTrace(forward, [x])
        self.assertFused(graph, ['aten::add', 'aten::mul'])

    def test_layer_norm(self):
        # TODO: support more normalized_shape
        m = torch.nn.LayerNorm(10)
        x = torch.randn(2, 5, 10, 10)
        _, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_cat(self):
        def cat_along_dim(d):
            def forward_cat(*inputs):
                return torch.cat(inputs, d)
            return forward_cat

        for xshape in [
            [8, 8, 8, 8],
            [64, 8, 32],
            [2048, 64],
        ]:
            for d in range(len(xshape)):
                x = torch.rand(xshape)
                _, graph = self.checkTrace(cat_along_dim(d), [x, x, x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    def test_typecheck(self):
        x = torch.rand(32, 28)
        m = torch.nn.Linear(in_features=28, out_features=64, bias=True)
        traced, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::linear'])
        # change the shape of the input; we should enter the fallback graph
        x = torch.rand(5, 28)
        self.assertEqual(m(x), traced(x))


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestFusionPattern(JitLlgaTestCase):
    def test_conv2d_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                return x

        # for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']:
        for eltwise in ['relu']:
            for inplace in [True, False]:
                eltwise_fn_name = eltwise + '_' if inplace else eltwise
                eltwise_fn = get_eltwise_fn(eltwise_fn_name)

                m = M(eltwise_fn)
                x = torch.rand(1, 32, 28, 28)
                _, graph = self.checkTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                # test if relu_ is replaced with relu by the mutation removal pass
                self.assertFused(graph, ['aten::' + eltwise_fn_name])
                # test if relu is fused into the fusion group
                self.assertFused(graph, ['aten::' + eltwise])

    def test_conv2d_bn(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                return x

        m = M().eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm'])


    def test_conv2d_bn_relu(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = F.relu(x)
                return x

        m = M().eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
                                 'aten::relu'])

    def test_bn2d_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.eltwise = eltwise_fn
                self.bn = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.bn(x)
                x = self.eltwise(x)
                return x

        for eltwise in ['relu']:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn).eval()
            x = torch.rand(1, 32, 28, 28)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::' + eltwise])

    def test_linear_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        for [has_bias, eltwise] in itertools.product(
                [True, False],
                ['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']):

            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn, has_bias)
            x = torch.rand(32, 28, requires_grad=False)
            _, graph = self.checkTrace(m, [x])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::' + eltwise])

    def test_conv2d_sum(self):
        class M(nn.Module):
            def __init__(self, bias=False):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn1 = nn.BatchNorm2d(32)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn2 = nn.BatchNorm2d(32)
                self.relu = nn.ReLU()
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn3 = nn.BatchNorm2d(32)

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.bn1(x)
                y = self.conv2(y)
                y = self.bn2(y)
                z = self.relu(x + y)
                z = self.conv3(z)
                z = self.bn3(z)
                return z

        for bias in [True, False]:
            m = M(bias).eval()
            x = torch.rand(1, 32, 16, 16, requires_grad=False)
            y = torch.rand(1, 32, 16, 16, requires_grad=False)
            _, graph = self.checkTrace(m, [x, y])
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)

    def test_wildcard(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                y = self.eltwise(x)
                return [x, y]

        # The pattern is as the following:
        #      conv
        #     |    \
        # eltwise   \
        #    |       \
        #  ListConstruct
        #
        # The output of conv is used by a wildcard op: ListConstruct.
        # Thus conv-eltwise cannot be selected into the same Partition.
        m = M()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x])
        # conv can exist in a single-op oneDNN Graph partition but not relu
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution'])

    def test_rewrap_tensor_input_to_pytorch(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, data_type):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type)
                self.eltwise = eltwise_fn
                self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                x = torch.add(x, y)
                x = self.adaptive_avg_pool_2d(x)
                return x

        eltwise_fn_name = 'relu'
        eltwise_fn = get_eltwise_fn(eltwise_fn_name)
        # Add bfloat16 later
        for data_type in [torch.float]:
            m = M(eltwise_fn, data_type)
            m = m.to(memory_format=torch.channels_last)
            x = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
            y = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last)
            # Simply test if the output is accurate
            # The output of the second partition is input to adaptive_avg_pool2d, which is
            # unsupported by LLGA, so it must be handled by PyTorch, which should receive
            # correct strides info of the channels-last tensor.
            graph, _ = self.checkTrace(m, [x, y])


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestModel(JitLlgaTestCase):
    @skipIfNoTorchVision
    def _test_vision(self, model_name):
        m = getattr(torchvision.models, model_name)().eval()
        x = torch.rand(1, 3, 224, 224) / 10
        _, graph = self.checkTrace(m, [x])
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
                                 'aten::relu', 'aten::linear',
                                 'aten::avg_pool2d', 'aten::max_pool2d'])


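# Dynamically attach one test_vision_<model_name> method to TestModel for each
# torchvision model listed below.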
for model_name, enabled in [
    ['resnet50', True],
    ['resnext50_32x4d', True],
    ['resnext101_32x8d', True],
    ['densenet121', True],
    ['googlenet', TEST_SCIPY],
    ['mobilenet_v2', True],
    ['mnasnet1_0', True],
    ['squeezenet1_0', True],
    ['vgg16', True],
    ['alexnet', True],
    ['shufflenet_v2_x1_0', True],
    ['wide_resnet50_2', True],
]:
    def wrapper(mname):
        @unittest.skipIf(not enabled, 'Disabled')
        def test(self):
            return self._test_vision(mname)
        return test

    setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))

if __name__ == '__main__':
    run_tests()