mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fixes a ton of false negatives throughout the codebase. RUFF also properly validates NOQA comments now and most of the changes are fixing typos there or removing filewide flake8 suppressions that were also silencing ruff issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249 Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
# Owner(s): ["oncall: quantization"]
|
|
import copy
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._dynamo as torchdynamo
|
|
from torch.ao.quantization.pt2e.graph_utils import (
|
|
find_sequential_partitions,
|
|
get_equivalent_types,
|
|
update_equivalent_types_dict,
|
|
)
|
|
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
|
|
|
|
|
|
class TestGraphUtils(TestCase):
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
|
|
def test_conv_bn_conv_relu(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3)
|
|
self.bn1 = torch.nn.BatchNorm2d(3)
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3)
|
|
self.relu2 = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
bn_out = self.bn1(self.conv1(x))
|
|
relu_out = torch.nn.functional.relu(bn_out)
|
|
return self.relu2(self.conv2(relu_out))
|
|
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(1, 3, 5, 5),)
|
|
|
|
# program capture
|
|
m, guards = torchdynamo.export( # noqa: F841
|
|
m,
|
|
*copy.deepcopy(example_inputs),
|
|
aten_graph=True,
|
|
)
|
|
fused_partitions = find_sequential_partitions(
|
|
m, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
|
|
)
|
|
self.assertEqual(len(fused_partitions), 1)
|
|
fused_partitions = find_sequential_partitions(
|
|
m, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
|
|
)
|
|
self.assertEqual(len(fused_partitions), 1)
|
|
|
|
def x():
|
|
find_sequential_partitions(
|
|
m,
|
|
[
|
|
torch.nn.Conv2d,
|
|
torch.nn.BatchNorm2d,
|
|
torch.nn.ReLU,
|
|
torch.nn.functional.conv2d,
|
|
],
|
|
)
|
|
|
|
self.assertRaises(ValueError, x)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
|
|
def test_conv_bn_relu(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.bn1 = torch.nn.BatchNorm2d(3)
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3)
|
|
self.relu2 = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
bn_out = self.bn1(x)
|
|
return self.relu2(self.conv2(bn_out))
|
|
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(1, 3, 5, 5),)
|
|
|
|
# program capture
|
|
m, guards = torchdynamo.export( # noqa: F841
|
|
m,
|
|
*copy.deepcopy(example_inputs),
|
|
aten_graph=True,
|
|
)
|
|
fused_partitions = find_sequential_partitions(
|
|
m, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
|
|
)
|
|
self.assertEqual(len(fused_partitions), 0)
|
|
fused_partitions = find_sequential_partitions(
|
|
m, [torch.nn.BatchNorm2d, torch.nn.Conv2d]
|
|
)
|
|
self.assertEqual(len(fused_partitions), 1)
|
|
fused_partitions = find_sequential_partitions(
|
|
m, [torch.nn.BatchNorm2d, torch.nn.ReLU]
|
|
)
|
|
self.assertEqual(len(fused_partitions), 0)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
|
|
def test_customized_equivalet_types_dict(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.relu6(self.conv(x))
|
|
|
|
m = M().eval()
|
|
example_inputs = (torch.randn(1, 3, 5, 5),)
|
|
|
|
# program capture
|
|
m, guards = torchdynamo.export( # noqa: F841
|
|
m,
|
|
*copy.deepcopy(example_inputs),
|
|
aten_graph=True,
|
|
)
|
|
customized_equivalent_types = get_equivalent_types()
|
|
customized_equivalent_types.append({torch.nn.ReLU6, torch.nn.functional.relu6})
|
|
update_equivalent_types_dict(customized_equivalent_types)
|
|
fused_partitions = find_sequential_partitions(
|
|
m,
|
|
[torch.nn.Conv2d, torch.nn.ReLU6],
|
|
)
|
|
self.assertEqual(len(fused_partitions), 1)
|