Files
pytorch/test/quantization/pt2e/test_graph_utils.py
Aaron Gokaslan 3555ebb63d [BE]: Update ruff to 0.11.8 (#153249)
Fixes a ton of false negatives throughout the codebase. RUFF also properly validates NOQA comments now and most of the changes are fixing typos there or removing filewide flake8 suppressions that were also silencing ruff issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
2025-05-12 18:30:52 +00:00

124 lines
4.1 KiB
Python

# Owner(s): ["oncall: quantization"]
import copy
import unittest
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization.pt2e.graph_utils import (
find_sequential_partitions,
get_equivalent_types,
update_equivalent_types_dict,
)
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
class TestGraphUtils(TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
def test_conv_bn_conv_relu(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.bn1 = torch.nn.BatchNorm2d(3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)
self.relu2 = torch.nn.ReLU()
def forward(self, x):
bn_out = self.bn1(self.conv1(x))
relu_out = torch.nn.functional.relu(bn_out)
return self.relu2(self.conv2(relu_out))
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m, guards = torchdynamo.export( # noqa: F841
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
)
fused_partitions = find_sequential_partitions(
m, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
)
self.assertEqual(len(fused_partitions), 1)
fused_partitions = find_sequential_partitions(
m, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
)
self.assertEqual(len(fused_partitions), 1)
def x():
find_sequential_partitions(
m,
[
torch.nn.Conv2d,
torch.nn.BatchNorm2d,
torch.nn.ReLU,
torch.nn.functional.conv2d,
],
)
self.assertRaises(ValueError, x)
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
def test_conv_bn_relu(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bn1 = torch.nn.BatchNorm2d(3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)
self.relu2 = torch.nn.ReLU()
def forward(self, x):
bn_out = self.bn1(x)
return self.relu2(self.conv2(bn_out))
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m, guards = torchdynamo.export( # noqa: F841
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
)
fused_partitions = find_sequential_partitions(
m, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
)
self.assertEqual(len(fused_partitions), 0)
fused_partitions = find_sequential_partitions(
m, [torch.nn.BatchNorm2d, torch.nn.Conv2d]
)
self.assertEqual(len(fused_partitions), 1)
fused_partitions = find_sequential_partitions(
m, [torch.nn.BatchNorm2d, torch.nn.ReLU]
)
self.assertEqual(len(fused_partitions), 0)
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
def test_customized_equivalet_types_dict(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
return torch.nn.functional.relu6(self.conv(x))
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m, guards = torchdynamo.export( # noqa: F841
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
)
customized_equivalent_types = get_equivalent_types()
customized_equivalent_types.append({torch.nn.ReLU6, torch.nn.functional.relu6})
update_equivalent_types_dict(customized_equivalent_types)
fused_partitions = find_sequential_partitions(
m,
[torch.nn.Conv2d, torch.nn.ReLU6],
)
self.assertEqual(len(fused_partitions), 1)