pytorch/test/inductor/test_efficient_conv_bn_eval.py
Yuanyuan Chen 8de85896e0 Enable ruff rule E721 (#165162)
`E721` flags object type comparisons written with `==` or other comparison operators. This is useful because `is`/`is not` (or `isinstance()`) are the recommended ways to compare types.
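
A minimal illustration of what the rule flags (`a`, `b`, and `SomeClass` are hypothetical names):

    # flagged by E721
    if type(a) == type(b):
        ...
    # preferred: identity check for exact type equality
    if type(a) is type(b):
        ...
    # or an instance check when subclasses should also match
    if isinstance(a, SomeClass):
        ...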

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
2025-10-13 01:48:55 +00:00


# Owner(s): ["module: inductor"]
import copy
import importlib
import itertools
import os
import sys

import torch
from torch import nn

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU

importlib.import_module("functorch")
importlib.import_module("filelock")

from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    copy_tests,
)
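
# Background: in eval mode BatchNorm applies a fixed affine transform (frozen
# running statistics), so bn(conv(x)) can be computed as a single conv with a
# per-output-channel rescaled weight and shifted bias, roughly:
#     scale = bn.weight / sqrt(bn.running_var + bn.eps)
#     W' = W * scale  (broadcast over the output-channel dimension)
#     b' = (b - bn.running_mean) * scale + bn.bias
# The efficient_conv_bn_eval fx pass rewrites matching conv -> bn call sites to
# compute this fused form on the fly, keeping the original parameters trainable.
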
class ConvOp(nn.Module):
expected_optimization_count = 1
def __init__(
self,
conv_class,
bn_class,
use_bias,
in_channels,
out_channels,
device,
**kwargs,
):
super().__init__()
self.conv = conv_class(in_channels, out_channels, bias=use_bias, **kwargs).to(
device
)
self.bn = bn_class(out_channels).to(device)
def forward(self, x):
x = self.conv(x)
return self.bn(x)

class MultiUserConvOp(nn.Module):
expected_optimization_count = 3
def __init__(
self,
conv_class,
bn_class,
use_bias,
in_channels,
out_channels,
device,
**kwargs,
):
super().__init__()
self.conv1 = conv_class(in_channels, out_channels, bias=use_bias, **kwargs).to(
device
)
self.bn1 = bn_class(out_channels).to(device)
self.conv2 = conv_class(out_channels, out_channels, bias=use_bias, **kwargs).to(
device
)
self.bn2 = bn_class(out_channels).to(device)
self.conv3 = conv_class(out_channels, out_channels, bias=use_bias, **kwargs).to(
device
)
self.bn3 = bn_class(out_channels).to(device)
def forward(self, x):
# this conv-bn pair can use efficient_conv_bn_eval
x = self.bn1(self.conv1(input=x))
        # this conv-bn pair can use the efficient_conv_bn_eval feature only for
        # the second (outer) forward of `self.conv2`; the inner call's output
        # feeds another conv, not a bn, so it cannot be rewritten
x = self.bn2(input=self.conv2(self.conv2(x)))
        # this conv-bn pair can use the efficient_conv_bn_eval feature only for
        # the first forward of `self.bn3`; the second call below tests multiple
        # users of one computation node
x = self.bn3(input=self.conv3(input=x))
x = self.bn3(x) + x
return x
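
# The three rewritable call sites above (conv1 -> bn1, the outer conv2 -> bn2,
# and the first conv3 -> bn3) account for expected_optimization_count = 3.
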
class EfficientConvBNEvalTemplate(TestCase):
@tf32_on_and_off(0.003)
@inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
def test_basic(self):
def test_conv_bn_eval(
test_class, use_bias, module, sync_bn, decompose_nn_module
):
from functorch import make_fx
from torch._dispatch.python import enable_python_dispatcher
kwargs = {"kernel_size": 3, "stride": 2} if module[0] != nn.Linear else {}
mod_eager = test_class(
module[0],
module[1],
use_bias,
3,
32,
self.device,
**kwargs,
).eval()
            # Deep-copy so the to-be-compiled module starts from identical
            # parameters; both copies are trained one step below to test backward
mod_optimized = copy.deepcopy(mod_eager)
if sync_bn:
mod_eager = nn.SyncBatchNorm.convert_sync_batchnorm(mod_eager).eval()
mod_optimized = nn.SyncBatchNorm.convert_sync_batchnorm(
mod_optimized
).eval()
torch._dynamo.reset()
inps = [4, 3]
# Conv shape goes from big to small, and ConvTranspose shape goes from small to big
spatial_d = (
4 if issubclass(module[0], nn.modules.conv._ConvTransposeNd) else 96
)
if module[0] is nn.Conv1d or module[0] is nn.ConvTranspose1d:
inps += [spatial_d] * 1
if module[0] is nn.Conv2d or module[0] is nn.ConvTranspose2d:
inps += [spatial_d] * 2
if module[0] is nn.Conv3d or module[0] is nn.ConvTranspose3d:
inps += [spatial_d] * 3
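            # e.g. Linear gets input (4, 3); Conv2d gets (4, 3, 96, 96);
            # ConvTranspose3d gets (4, 3, 4, 4, 4)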
inp = torch.rand(inps).to(self.device)
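            # Optionally trace the module to an FX graph of pre-dispatch ATen
            # ops first, so the pass is exercised on functional conv/batch_norm
            # calls rather than on nn.Module call sites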
if decompose_nn_module:
with enable_python_dispatcher():
mod_optimized = make_fx(mod_optimized, pre_dispatch=True)(inp)
mod_optimized = torch.compile(mod_optimized)
original_value = counters["inductor"]["efficient_conv_bn_eval"]
optim_eager = torch.optim.SGD(mod_eager.parameters(), lr=1e-3)
optim_optimized = torch.optim.SGD(mod_optimized.parameters(), lr=1e-3)
optim_eager.zero_grad()
optim_optimized.zero_grad()
# test forward
out_eager = mod_eager(inp)
out_optimized = mod_optimized(inp)
self.assertEqual(out_optimized, out_eager)
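            # test backward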
out_eager.mean().backward()
out_optimized.mean().backward()
optim_eager.step()
optim_optimized.step()
            # test backward (by running forward again after one training
            # iteration; the outputs only match if the parameter updates did)
inp_bw = torch.rand_like(inp)
out_eager_bw = mod_eager(inp_bw)
out_optimized_bw = mod_optimized(inp_bw)
self.assertEqual(out_eager_bw, out_optimized_bw)
current_value = counters["inductor"]["efficient_conv_bn_eval"]
self.assertEqual(
current_value - original_value, test_class.expected_optimization_count
)
conv_bias = [True, False]
modules = [
(nn.Linear, nn.BatchNorm1d),
(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d),
(nn.ConvTranspose1d, nn.BatchNorm1d),
(nn.ConvTranspose2d, nn.BatchNorm2d),
(nn.ConvTranspose3d, nn.BatchNorm3d),
]
test_classes = [ConvOp, MultiUserConvOp]
sync_bns = [False, True]
decompose_nn_modules = [False, True]
for (
test_class,
use_bias,
module,
sync_bn,
decompose_nn_module,
) in itertools.product(
test_classes,
conv_bias,
modules,
sync_bns,
decompose_nn_modules,
):
test_conv_bn_eval(
test_class, use_bias, module, sync_bn, decompose_nn_module
)
if HAS_CPU and not torch.backends.mps.is_available():

    class EfficientConvBNEvalCpuTests(TestCase):
        device = "cpu"

    copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalCpuTests, "cpu")

if HAS_GPU:

    class EfficientConvBNEvalGpuTests(TestCase):
        device = GPU_TYPE

    copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalGpuTests, GPU_TYPE)

del EfficientConvBNEvalTemplate

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")