Revert D30279364: [codemod][lint][fbcode/c*] Enable BLACK by default

Test Plan: revert-hammer

Differential Revision: D30279364 (b004307252)

Original commit changeset: c1ed77dfe43a

fbshipit-source-id: eab50857675c51e0088391af06ec0ecb14e2347e
Committed by: Facebook GitHub Bot
Parent: ed0b8a3e83
Commit: 1022443168
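
For context, the formatting change being reverted is Black's call-wrapping style. A minimal illustration of the two conventions (not taken from the commit; sum() stands in for any multi-argument call):

    # Black's output (what this revert removes): one argument per line,
    # a magic trailing comma, and the closing parenthesis on its own line.
    total = sum(
        [1, 2, 3],
        0,
    )

    # Pre-Black style (what this revert restores): arguments packed onto
    # fewer lines, continuations aligned under the opening parenthesis.
    total = sum([1, 2, 3],
                0)

The reverted diff of test/test_metal.py follows.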
@@ -1,21 +1,19 @@
-import io
-
 import torch
 from torch.nn import functional as F
-from torch.testing import FileCheck
-from torch.testing._internal.common_utils import TestCase, run_tests
 
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing import FileCheck
+import io
 
 class TestMetalRewritePass(TestCase):
     @staticmethod
     def validate_transformed_module(
-        # To please flake
-        self,
-        pattern_count_map,
-        data_shape,
-        prepack_removal=False,
-        fuse_clamping_ops=False,
-    ):
+            # To please flake
+            self,
+            pattern_count_map,
+            data_shape,
+            prepack_removal=False,
+            fuse_clamping_ops=False):
         module_instance = self
         scripted_model = torch.jit.script(module_instance)
         scripted_model.eval()
@@ -34,14 +32,12 @@ class TestMetalRewritePass(TestCase):
         buffer.seek(0)
         deserialized_scripted_model = torch.jit.load(buffer)
         for pattern, v in pattern_count_map.items():
-            if v == 0:
+            if (v == 0):
                 FileCheck().check(pattern).run(deserialized_scripted_model.graph)
-            elif v == -1:
+            elif (v == -1):
                 FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
             else:
-                FileCheck().check_count(pattern, v, exactly=True).run(
-                    deserialized_scripted_model.graph
-                )
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
 
     def test_conv(self):
         # Conv params
@@ -61,150 +57,103 @@ class TestMetalRewritePass(TestCase):
         strides = (stride_h, stride_w)
         paddings = (pad_h, pad_w)
         dilations = (dilation, dilation)
-        conv_weight_shape = (
-            output_channels,
-            input_channels_per_group,
-            kernel_h,
-            kernel_w,
-        )
-        conv_bias_shape = output_channels
+        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
+        conv_bias_shape = (output_channels)
 
         class Conv2D(torch.nn.Module):
             def __init__(self):
                 super(Conv2D, self).__init__()
-                self.weight = torch.nn.Parameter(
-                    torch.rand(conv_weight_shape), requires_grad=False
-                )
-                self.bias = torch.nn.Parameter(
-                    torch.rand(conv_bias_shape), requires_grad=False
-                )
+                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                 self.strides = strides
                 self.paddings = paddings
                 self.dilations = dilations
                 self.groups = groups
 
             def forward(self, x):
-                return F.conv2d(
-                    x,
-                    self.weight,
-                    self.bias,
-                    self.strides,
-                    self.paddings,
-                    self.dilations,
-                    self.groups,
-                )
+                return F.conv2d(x, self.weight, self.bias,
+                                self.strides, self.paddings, self.dilations, self.groups)
 
         data_shape = (batch_size, input_channels, height, width)
-        pattern_count_map = {
-            "Tensor = aten::conv2d": -1,
-            "metal_prepack::conv2d_prepack": 1,
-            "metal_prepack::conv2d_run": 1,
-        }
-        TestMetalRewritePass.validate_transformed_module(
-            Conv2D(), pattern_count_map, data_shape
-        )
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "metal_prepack::conv2d_prepack": 1,
+                             "metal_prepack::conv2d_run": 1}
+        TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
 
         class Conv2DRelu(torch.nn.Module):
             def __init__(self):
                 super(Conv2DRelu, self).__init__()
-                self.weight = torch.nn.Parameter(
-                    torch.rand(conv_weight_shape), requires_grad=False
-                )
-                self.bias = torch.nn.Parameter(
-                    torch.rand(conv_bias_shape), requires_grad=False
-                )
+                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                 self.strides = strides
                 self.paddings = paddings
                 self.dilations = dilations
                 self.groups = groups
 
             def forward(self, x):
-                o = F.conv2d(
-                    x,
-                    self.weight,
-                    self.bias,
-                    self.strides,
-                    self.paddings,
-                    self.dilations,
-                    self.groups,
-                )
+                o = F.conv2d(x, self.weight, self.bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
                 o = F.relu(o)
                 return o
 
         data_shape = (batch_size, input_channels, height, width)
-        pattern_count_map = {
-            "Tensor = aten::conv2d": -1,
-            "metal_prepack::conv2d_prepack": 1,
-            "metal_prepack::conv2d_run": 1,
-        }
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "metal_prepack::conv2d_prepack": 1,
+                             "metal_prepack::conv2d_run": 1}
         TestMetalRewritePass.validate_transformed_module(
-            Conv2DRelu(), pattern_count_map, data_shape
-        )
+            Conv2DRelu(), pattern_count_map, data_shape)
 
         pattern_count_map["aten::relu"] = 1
         pattern_count_map["metal_prepack::conv2d_prepack"] = -1
         TestMetalRewritePass.validate_transformed_module(
-            Conv2DRelu(), pattern_count_map, data_shape, prepack_removal=True
-        )
+            Conv2DRelu(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
         pattern_count_map["aten::relu"] = -1
         TestMetalRewritePass.validate_transformed_module(
             Conv2DRelu(),
             pattern_count_map,
             data_shape,
             prepack_removal=True,
-            fuse_clamping_ops=True,
-        )
+            fuse_clamping_ops=True)
 
 
         class Conv2DHardtanh(torch.nn.Module):
             def __init__(self):
                 super(Conv2DHardtanh, self).__init__()
-                self.weight = torch.nn.Parameter(
-                    torch.rand(conv_weight_shape), requires_grad=False
-                )
-                self.bias = torch.nn.Parameter(
-                    torch.rand(conv_bias_shape), requires_grad=False
-                )
+                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                 self.strides = strides
                 self.paddings = paddings
                 self.dilations = dilations
                 self.groups = groups
 
             def forward(self, x):
-                o = F.conv2d(
-                    x,
-                    self.weight,
-                    self.bias,
-                    self.strides,
-                    self.paddings,
-                    self.dilations,
-                    self.groups,
-                )
+                o = F.conv2d(x, self.weight, self.bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
                 o = F.hardtanh(o)
                 return o
 
         data_shape = (batch_size, input_channels, height, width)
-        pattern_count_map = {
-            "Tensor = aten::conv2d": -1,
-            "metal_prepack::conv2d_prepack": 1,
-            "metal_prepack::conv2d_run": 1,
-        }
-        TestMetalRewritePass.validate_transformed_module(
-            Conv2DHardtanh(), pattern_count_map, data_shape
-        )
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "metal_prepack::conv2d_prepack": 1,
+                             "metal_prepack::conv2d_run": 1}
+        TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
         pattern_count_map["aten::hardtanh"] = 1
         pattern_count_map["metal_prepack::conv2d_prepack"] = -1
         TestMetalRewritePass.validate_transformed_module(
-            Conv2DHardtanh(), pattern_count_map, data_shape, prepack_removal=True
-        )
+            Conv2DHardtanh(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
         pattern_count_map["aten::hardtanh"] = -1
         TestMetalRewritePass.validate_transformed_module(
             Conv2DRelu(),
             pattern_count_map,
             data_shape,
             prepack_removal=True,
-            fuse_clamping_ops=True,
-        )
-
+            fuse_clamping_ops=True)
 
 if __name__ == "__main__":
     run_tests()
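
For readers unfamiliar with the harness above: validate_transformed_module scripts a module, serializes and reloads it, then checks the resulting graph against pattern_count_map, where 0 means the pattern must appear, -1 means it must not appear, and any other count must match exactly. A standalone sketch of that convention (the Toy module is hypothetical, not part of the commit):

    import io

    import torch
    from torch.testing import FileCheck

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    # Script, serialize, and reload, mirroring the round-trip in the test.
    scripted = torch.jit.script(Toy())
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    deserialized = torch.jit.load(buffer)

    # 1 -> "aten::relu" must appear exactly once; -1 -> no conv2d at all.
    pattern_count_map = {"aten::relu": 1, "Tensor = aten::conv2d": -1}
    for pattern, v in pattern_count_map.items():
        if v == 0:
            FileCheck().check(pattern).run(deserialized.graph)
        elif v == -1:
            FileCheck().check_not(pattern).run(deserialized.graph)
        else:
            FileCheck().check_count(pattern, v, exactly=True).run(deserialized.graph)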