# Owner(s): ["oncall: mobile"]
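
# Tests for the TorchScript graph rewrite passes used by the Vulkan backend:
# prepacked-op insertion, prepack folding, and clamp (relu/hardtanh) fusion.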

import unittest
import torch
from torch.nn import functional as F

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing import FileCheck
import io

@unittest.skipUnless(torch.is_vulkan_available(),
                     "Vulkan backend must be available for these tests.")
class TestVulkanRewritePass(TestCase):
    @staticmethod
    def validate_transformed_module(
            # To please flake
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
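        # pattern_count_map maps a FileCheck pattern to an expected count:
        #   0  -> the pattern must appear at least once,
        #   -1 -> the pattern must not appear at all,
        #   n  -> the pattern must appear exactly n times.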
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        scripted_model(input_data)
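        # Rewrite eligible aten ops (e.g. aten::conv2d) into
        # vulkan_prepack::conv2d_clamp_prepack / conv2d_clamp_run pairs.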
        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
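        # The fuse and fold passes below operate on a frozen module.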
        if fuse_clamping_ops or prepack_removal:
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
        if fuse_clamping_ops:
            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
        if prepack_removal:
            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
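
        # Round-trip through serialization: the rewritten graph must survive
        # torch.jit.save / torch.jit.load.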
        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)
        for pattern, v in pattern_count_map.items():
            if v == 0:
                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
            elif v == -1:
                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)

    def test_conv(self):
        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        class Conv2D(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)
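
        # aten::conv2d must be rewritten away, leaving exactly one
        # prepack/run pair in the graph.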
        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        class Conv2DRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(), pattern_count_map, data_shape)
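
        # After prepack folding on the frozen module, the prepack op is
        # constant-folded away, but the standalone relu remains.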
        pattern_count_map["aten::relu"] = 1
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
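        # With clamp fusion, the relu is folded into the prepacked conv.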
        pattern_count_map["aten::relu"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class Conv2DHardtanh(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.hardtanh(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
        pattern_count_map["aten::hardtanh"] = 1
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["aten::hardtanh"] = -1
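        # Clamp fusion folds the hardtanh into the prepacked conv as well.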
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

if __name__ == "__main__":
    run_tests()