[Metal] Add the Python binding for optimize_for_mobile (#46456)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46456

Add the Python binding and hook it up in CMake. The general workflow is:

- Build PyTorch with Metal enabled: `USE_PYTORCH_METAL=ON python setup.py install --cmake`
- Run `optimize_for_mobile` with the `metal` backend:

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

scripted_model = torch.jit.load('./mobilenetv2.pt')
optimized_model = optimize_for_mobile(scripted_model, backend='metal')
torch.jit.export_opnames(optimized_model)
torch.jit.save(optimized_model, './mobilenetv2_metal.bc')
```
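Note that the optimized model now contains `metal::` and `metal_prepack::` ops, so it is only expected to run on a Metal-capable device with a Metal-enabled build; the rewritten graph is not expected to execute on CPU.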
The exported ops are:

```
['aten::adaptive_avg_pool2d', 'aten::add.Tensor', 'aten::addmm', 'aten::reshape', 'aten::size.int', 'metal::copy_to_host', 'metal_prepack::conv2d_run']
```
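To guard against unsupported ops slipping into a deployment, the exported op names can be checked against an allowlist before saving. A minimal sketch using only the calls shown above; the `EXPECTED_OPS` set here is illustrative, copied from this particular dump:

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# Illustrative allowlist, copied from the op dump above; adjust per model.
EXPECTED_OPS = {
    'aten::adaptive_avg_pool2d', 'aten::add.Tensor', 'aten::addmm',
    'aten::reshape', 'aten::size.int',
    'metal::copy_to_host', 'metal_prepack::conv2d_run',
}

scripted_model = torch.jit.load('./mobilenetv2.pt')
optimized_model = optimize_for_mobile(scripted_model, backend='metal')
unexpected = set(torch.jit.export_opnames(optimized_model)) - EXPECTED_OPS
assert not unexpected, "unexpected ops after Metal rewrite: {}".format(unexpected)
```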
ghstack-source-id: 114559878

Test Plan:
- Sandcastle CI
- Circle CI

Reviewed By: kimishpatel

Differential Revision: D24356768

fbshipit-source-id: fb5c4c4b6316347b67edb4132da044a81470ddfd
Author: Tao Xu
Date: 2020-10-17 10:22:55 -07:00
Committed by: Facebook GitHub Bot
Parent: e8ff0f6c5c
Commit: 495070b388

6 changed files with 212 additions and 6 deletions

test/test_metal.py (new file, 159 lines)

@@ -0,0 +1,159 @@
import torch
from torch.nn import functional as F
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing import FileCheck
import io


class TestMetalRewritePass(TestCase):
    @staticmethod
    def validate_transformed_module(
            # The module under test arrives via `self` so that the
            # staticmethod signature keeps flake8 happy.
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        # Run once to make sure the scripted module executes before rewriting.
        ref_result = scripted_model(input_data)
        torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c)
        if fuse_clamping_ops or prepack_removal:
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
        if fuse_clamping_ops:
            torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c)
        if prepack_removal:
            torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c)
        # Round-trip through serialization to check that the rewritten
        # graph survives save/load.
        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)
        # pattern_count_map semantics: 0 -> pattern must appear,
        # -1 -> pattern must not appear, otherwise an exact match count.
        for pattern, v in pattern_count_map.items():
            if v == 0:
                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
            elif v == -1:
                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)

    def test_conv(self):
        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        class Conv2D(torch.nn.Module):
            def __init__(self):
                super(Conv2D, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        class Conv2DRelu(torch.nn.Module):
            def __init__(self):
                super(Conv2DRelu, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(), pattern_count_map, data_shape)
        # After freezing and folding the prepack ops, conv2d_prepack
        # disappears but the standalone relu is still present.
        pattern_count_map["aten::relu"] = 1
        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        # With clamp fusion enabled, the relu is folded into the
        # prepacked conv as well.
        pattern_count_map["aten::relu"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class Conv2DHardtanh(torch.nn.Module):
            def __init__(self):
                super(Conv2DHardtanh, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.hardtanh(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
        pattern_count_map["aten::hardtanh"] = 1
        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        # Hardtanh is a clamp op, so it too should fuse into the
        # prepacked conv when clamp fusion is on.
        pattern_count_map["aten::hardtanh"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)


if __name__ == "__main__":
    run_tests()