mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
211 lines
7.2 KiB
Python
211 lines
7.2 KiB
Python
import io
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
|
|
|
|
class TestMetalRewritePass(TestCase):
|
|
@staticmethod
|
|
def validate_transformed_module(
|
|
# To please flake
|
|
self,
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=False,
|
|
fuse_clamping_ops=False,
|
|
):
|
|
module_instance = self
|
|
scripted_model = torch.jit.script(module_instance)
|
|
scripted_model.eval()
|
|
input_data = torch.normal(1, 20, size=data_shape)
|
|
ref_result = scripted_model(input_data)
|
|
torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c)
|
|
if fuse_clamping_ops or prepack_removal:
|
|
scripted_model._c = torch._C._freeze_module(scripted_model._c)
|
|
if fuse_clamping_ops:
|
|
torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c)
|
|
if prepack_removal:
|
|
torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c)
|
|
|
|
buffer = io.BytesIO()
|
|
torch.jit.save(scripted_model, buffer)
|
|
buffer.seek(0)
|
|
deserialized_scripted_model = torch.jit.load(buffer)
|
|
for pattern, v in pattern_count_map.items():
|
|
if v == 0:
|
|
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
|
|
elif v == -1:
|
|
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
|
|
else:
|
|
FileCheck().check_count(pattern, v, exactly=True).run(
|
|
deserialized_scripted_model.graph
|
|
)
|
|
|
|
def test_conv(self):
|
|
# Conv params
|
|
batch_size = 2
|
|
input_channels_per_group = 6
|
|
height = 16
|
|
width = 16
|
|
output_channels_per_group = 6
|
|
groups = 4
|
|
kernel_h = kernel_w = 3
|
|
stride_h = stride_w = 1
|
|
pad_h = pad_w = 1
|
|
dilation = 1
|
|
input_channels = input_channels_per_group * groups
|
|
output_channels = output_channels_per_group * groups
|
|
kernels = (kernel_h, kernel_w)
|
|
strides = (stride_h, stride_w)
|
|
paddings = (pad_h, pad_w)
|
|
dilations = (dilation, dilation)
|
|
conv_weight_shape = (
|
|
output_channels,
|
|
input_channels_per_group,
|
|
kernel_h,
|
|
kernel_w,
|
|
)
|
|
conv_bias_shape = output_channels
|
|
|
|
class Conv2D(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Conv2D, self).__init__()
|
|
self.weight = torch.nn.Parameter(
|
|
torch.rand(conv_weight_shape), requires_grad=False
|
|
)
|
|
self.bias = torch.nn.Parameter(
|
|
torch.rand(conv_bias_shape), requires_grad=False
|
|
)
|
|
self.strides = strides
|
|
self.paddings = paddings
|
|
self.dilations = dilations
|
|
self.groups = groups
|
|
|
|
def forward(self, x):
|
|
return F.conv2d(
|
|
x,
|
|
self.weight,
|
|
self.bias,
|
|
self.strides,
|
|
self.paddings,
|
|
self.dilations,
|
|
self.groups,
|
|
)
|
|
|
|
data_shape = (batch_size, input_channels, height, width)
|
|
pattern_count_map = {
|
|
"Tensor = aten::conv2d": -1,
|
|
"metal_prepack::conv2d_prepack": 1,
|
|
"metal_prepack::conv2d_run": 1,
|
|
}
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2D(), pattern_count_map, data_shape
|
|
)
|
|
|
|
class Conv2DRelu(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Conv2DRelu, self).__init__()
|
|
self.weight = torch.nn.Parameter(
|
|
torch.rand(conv_weight_shape), requires_grad=False
|
|
)
|
|
self.bias = torch.nn.Parameter(
|
|
torch.rand(conv_bias_shape), requires_grad=False
|
|
)
|
|
self.strides = strides
|
|
self.paddings = paddings
|
|
self.dilations = dilations
|
|
self.groups = groups
|
|
|
|
def forward(self, x):
|
|
o = F.conv2d(
|
|
x,
|
|
self.weight,
|
|
self.bias,
|
|
self.strides,
|
|
self.paddings,
|
|
self.dilations,
|
|
self.groups,
|
|
)
|
|
o = F.relu(o)
|
|
return o
|
|
|
|
data_shape = (batch_size, input_channels, height, width)
|
|
pattern_count_map = {
|
|
"Tensor = aten::conv2d": -1,
|
|
"metal_prepack::conv2d_prepack": 1,
|
|
"metal_prepack::conv2d_run": 1,
|
|
}
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2DRelu(), pattern_count_map, data_shape
|
|
)
|
|
|
|
pattern_count_map["aten::relu"] = 1
|
|
pattern_count_map["metal_prepack::conv2d_prepack"] = -1
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2DRelu(), pattern_count_map, data_shape, prepack_removal=True
|
|
)
|
|
pattern_count_map["aten::relu"] = -1
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2DRelu(),
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=True,
|
|
fuse_clamping_ops=True,
|
|
)
|
|
|
|
class Conv2DHardtanh(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Conv2DHardtanh, self).__init__()
|
|
self.weight = torch.nn.Parameter(
|
|
torch.rand(conv_weight_shape), requires_grad=False
|
|
)
|
|
self.bias = torch.nn.Parameter(
|
|
torch.rand(conv_bias_shape), requires_grad=False
|
|
)
|
|
self.strides = strides
|
|
self.paddings = paddings
|
|
self.dilations = dilations
|
|
self.groups = groups
|
|
|
|
def forward(self, x):
|
|
o = F.conv2d(
|
|
x,
|
|
self.weight,
|
|
self.bias,
|
|
self.strides,
|
|
self.paddings,
|
|
self.dilations,
|
|
self.groups,
|
|
)
|
|
o = F.hardtanh(o)
|
|
return o
|
|
|
|
data_shape = (batch_size, input_channels, height, width)
|
|
pattern_count_map = {
|
|
"Tensor = aten::conv2d": -1,
|
|
"metal_prepack::conv2d_prepack": 1,
|
|
"metal_prepack::conv2d_run": 1,
|
|
}
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2DHardtanh(), pattern_count_map, data_shape
|
|
)
|
|
pattern_count_map["aten::hardtanh"] = 1
|
|
pattern_count_map["metal_prepack::conv2d_prepack"] = -1
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2DHardtanh(), pattern_count_map, data_shape, prepack_removal=True
|
|
)
|
|
pattern_count_map["aten::hardtanh"] = -1
|
|
TestMetalRewritePass.validate_transformed_module(
|
|
Conv2DRelu(),
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=True,
|
|
fuse_clamping_ops=True,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: delegate to the test runner imported from
    # torch.testing._internal.common_utils.
    run_tests()
|