mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD
164 lines
6.7 KiB
Python
164 lines
6.7 KiB
Python
# Owner(s): ["oncall: mobile"]
|
|
|
|
import unittest
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
from torch.testing import FileCheck
|
|
import io
|
|
|
|
@unittest.skipUnless(torch.is_vulkan_available(),
|
|
"Vulkan backend must be available for these tests.")
|
|
class TestVulkanRewritePass(TestCase):
|
|
@staticmethod
|
|
def validate_transformed_module(
|
|
# To please flake
|
|
self,
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=False,
|
|
fuse_clamping_ops=False):
|
|
module_instance = self
|
|
scripted_model = torch.jit.script(module_instance)
|
|
scripted_model.eval()
|
|
input_data = torch.normal(1, 20, size=data_shape)
|
|
scripted_model(input_data)
|
|
torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
|
|
if fuse_clamping_ops or prepack_removal:
|
|
scripted_model._c = torch._C._freeze_module(scripted_model._c)
|
|
if fuse_clamping_ops:
|
|
torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
|
|
if prepack_removal:
|
|
torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
|
|
|
|
buffer = io.BytesIO()
|
|
torch.jit.save(scripted_model, buffer)
|
|
buffer.seek(0)
|
|
deserialized_scripted_model = torch.jit.load(buffer)
|
|
for pattern, v in pattern_count_map.items():
|
|
if (v == 0):
|
|
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
|
|
elif (v == -1):
|
|
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
|
|
else:
|
|
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
|
|
|
|
def test_conv(self):
|
|
# Conv params
|
|
batch_size = 2
|
|
input_channels_per_group = 6
|
|
height = 16
|
|
width = 16
|
|
output_channels_per_group = 6
|
|
groups = 4
|
|
kernel_h = kernel_w = 3
|
|
stride_h = stride_w = 1
|
|
pad_h = pad_w = 1
|
|
dilation = 1
|
|
input_channels = input_channels_per_group * groups
|
|
output_channels = output_channels_per_group * groups
|
|
strides = (stride_h, stride_w)
|
|
paddings = (pad_h, pad_w)
|
|
dilations = (dilation, dilation)
|
|
conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
|
|
conv_bias_shape = (output_channels)
|
|
|
|
class Conv2D(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
|
|
self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
|
|
self.strides = strides
|
|
self.paddings = paddings
|
|
self.dilations = dilations
|
|
self.groups = groups
|
|
|
|
def forward(self, x):
|
|
return F.conv2d(x, self.weight, self.bias,
|
|
self.strides, self.paddings, self.dilations, self.groups)
|
|
|
|
data_shape = (batch_size, input_channels, height, width)
|
|
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
|
"vulkan_prepack::conv2d_clamp_prepack": 1,
|
|
"vulkan_prepack::conv2d_clamp_run": 1}
|
|
TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
|
|
|
|
class Conv2DRelu(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
|
|
self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
|
|
self.strides = strides
|
|
self.paddings = paddings
|
|
self.dilations = dilations
|
|
self.groups = groups
|
|
|
|
def forward(self, x):
|
|
o = F.conv2d(x, self.weight, self.bias,
|
|
self.strides, self.paddings, self.dilations, self.groups)
|
|
o = F.relu(o)
|
|
return o
|
|
|
|
data_shape = (batch_size, input_channels, height, width)
|
|
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
|
"vulkan_prepack::conv2d_clamp_prepack": 1,
|
|
"vulkan_prepack::conv2d_clamp_run": 1}
|
|
TestVulkanRewritePass.validate_transformed_module(
|
|
Conv2DRelu(), pattern_count_map, data_shape)
|
|
|
|
pattern_count_map["aten::relu"] = 1
|
|
pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
|
|
TestVulkanRewritePass.validate_transformed_module(
|
|
Conv2DRelu(),
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=True)
|
|
pattern_count_map["aten::relu"] = -1
|
|
TestVulkanRewritePass.validate_transformed_module(
|
|
Conv2DRelu(),
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=True,
|
|
fuse_clamping_ops=True)
|
|
|
|
|
|
class Conv2DHardtanh(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
|
|
self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
|
|
self.strides = strides
|
|
self.paddings = paddings
|
|
self.dilations = dilations
|
|
self.groups = groups
|
|
|
|
def forward(self, x):
|
|
o = F.conv2d(x, self.weight, self.bias,
|
|
self.strides, self.paddings, self.dilations, self.groups)
|
|
o = F.hardtanh(o)
|
|
return o
|
|
|
|
data_shape = (batch_size, input_channels, height, width)
|
|
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
|
"vulkan_prepack::conv2d_clamp_prepack": 1,
|
|
"vulkan_prepack::conv2d_clamp_run": 1}
|
|
TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
|
|
pattern_count_map["aten::hardtanh"] = 1
|
|
pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
|
|
TestVulkanRewritePass.validate_transformed_module(
|
|
Conv2DHardtanh(),
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=True)
|
|
pattern_count_map["aten::hardtanh"] = -1
|
|
TestVulkanRewritePass.validate_transformed_module(
|
|
Conv2DRelu(),
|
|
pattern_count_map,
|
|
data_shape,
|
|
prepack_removal=True,
|
|
fuse_clamping_ops=True)
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|