Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25598
Test Plan: CI. Imported from OSS.
Differential Revision: D17192467
fbshipit-source-id: 9ee93b02cc293bb71ed114534d92eedda3ddee88
162 lines · 6.4 KiB · Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import torch
|
|
from torch.nn import Conv2d, BatchNorm2d, ReLU
|
|
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
|
|
from torch.quantization.QConfig import default_qat_qconfig
|
|
from torch.utils.mkldnn import disable_mkldnn_conv
|
|
from common_utils import TestCase, run_tests
|
|
from hypothesis import given
|
|
from hypothesis import strategies as st
|
|
from hypothesis_utils import no_deadline
|
|
from functools import reduce
|
|
|
|
|
|
class IntrinsicQATModuleTest(TestCase):
    """Numerical tests for the fused QAT modules ConvBn2d / ConvBnReLU2d.

    Each fused module (with fake quantization disabled) is compared against a
    reference pipeline built from separate float Conv2d, BatchNorm2d and ReLU
    modules, checking forward outputs, backward gradients, and the BatchNorm
    running statistics.
    """

    # NOTE: Tests in this class are decorated with no_deadline
    # to prevent spurious failures due to cuda runtime initialization.

    @no_deadline
    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn
    ):
        """Check a hypothesis-sampled ConvBn(ReLU)2d configuration against the
        equivalent unfused Conv2d + BatchNorm2d (+ ReLU) reference.

        Args (all supplied by hypothesis via @given):
            batch_size, height, width: input tensor geometry.
            input_channels_per_group, output_channels_per_group, groups:
                channel counts; total channels are per-group counts * groups.
            kernel_h/w, stride_h/w, pad_h/w, dilation: conv hyperparameters
                (dilation is fixed at 1 by its strategy).
            padding_mode: 'zeros' or 'circular' conv padding.
            use_relu: whether to test the ConvBnReLU2d variant.
            eps, momentum: BatchNorm2d numerical parameters.
            freeze_bn: whether the fused module runs with frozen BN stats.
        """
        # NOTE(review): mkldnn conv is disabled for the whole test —
        # presumably so the reference conv and the fused module take the same
        # (non-mkldnn) kernel path at double precision; confirm against the
        # disable_mkldnn_conv docs.
        with disable_mkldnn_conv():
            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            dilation_h = dilation_w = dilation

            # Reference pipeline: separate float Conv2d / BatchNorm2d / ReLU.
            # Everything runs in float64 so the 1e-10 tolerance below is
            # attainable.
            conv_op = Conv2d(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                False,  # No bias
                padding_mode
            ).to(dtype=torch.double)
            bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
            relu_op = ReLU()

            # Module under test: the fused QAT conv+bn(+relu).  Fake
            # quantization is disabled so its output is directly comparable
            # to the float reference.
            cls = ConvBnReLU2d if use_relu else ConvBn2d
            qat_op = cls(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                padding_mode,
                eps,
                momentum,
                freeze_bn,
                default_qat_qconfig
            ).to(dtype=torch.double).disable_fake_quant()

            # align inputs and internal parameters
            # Copy the fused module's randomly-initialized weight and BN
            # parameters into the reference modules so both start identical.
            input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
            conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
            bn_op.running_mean = qat_op.running_mean.clone()
            bn_op.running_var = qat_op.running_var.clone()
            bn_op.weight = torch.nn.Parameter(qat_op.gamma.detach())
            bn_op.bias = torch.nn.Parameter(qat_op.beta.detach())

            def compose(functions):
                # functions are reversed for natural reading order
                return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

            # When ReLU isn't under test, substitute the identity so the
            # same composed pipeline shape can be used either way.
            if not use_relu:
                def relu_op(x):
                    return x

            if freeze_bn:
                # Frozen BN: normalize with the (fixed) running statistics
                # rather than batch statistics — an explicit implementation of
                # y = (x - running_mean) * gamma / sqrt(running_var + eps) + beta,
                # with per-channel tensors reshaped for NCHW broadcasting.
                def ref_op(x):
                    x = conv_op(x)
                    x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                        (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
                        .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
                    x = relu_op(x)
                    return x
            else:
                # Training-mode BN: the stock modules already do the right
                # thing, so just chain them.
                ref_op = compose([conv_op, bn_op, relu_op])

            # Separate leaf tensor for the QAT path so each pipeline
            # accumulates its own input gradient.
            input_clone = input.clone().detach().requires_grad_()
            # Two iterations so running-stat updates (and gradient
            # accumulation) are exercised across steps, not just once.
            for i in range(2):
                result_ref = ref_op(input)
                result_actual = qat_op(input_clone)
                self.assertEqual(result_ref, result_actual)

                # backward
                # NOTE(review): dout is subtracted as a constant, so
                # d(loss)/d(result) is uniformly 1 — the random dout does not
                # actually vary the upstream gradient.
                dout = torch.randn(result_ref.size(), dtype=torch.double)
                loss = (result_ref - dout).sum()
                loss.backward()
                input_grad_ref = input.grad.cpu()
                weight_grad_ref = conv_op.weight.grad.cpu()
                gamma_grad_ref = bn_op.weight.grad.cpu()
                beta_grad_ref = bn_op.bias.grad.cpu()
                running_mean_ref = bn_op.running_mean
                running_var_ref = bn_op.running_var
                num_batches_tracked_ref = bn_op.num_batches_tracked
                loss = (result_actual - dout).sum()
                loss.backward()
                input_grad_actual = input_clone.grad.cpu()
                weight_grad_actual = qat_op.weight.grad.cpu()
                gamma_grad_actual = qat_op.gamma.grad.cpu()
                beta_grad_actual = qat_op.beta.grad.cpu()
                running_mean_actual = qat_op.running_mean
                running_var_actual = qat_op.running_var
                num_batches_tracked_actual = qat_op.num_batches_tracked
                # Tight tolerance is feasible because both paths run in
                # float64 on the same kernel path.
                precision = 1e-10
                self.assertEqual(input_grad_ref, input_grad_actual, prec=precision)
                self.assertEqual(weight_grad_ref, weight_grad_actual, prec=precision)
                self.assertEqual(gamma_grad_ref, gamma_grad_actual, prec=precision)
                self.assertEqual(beta_grad_ref, beta_grad_actual, prec=precision)
                self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, prec=precision)
                self.assertEqual(running_mean_ref, running_mean_actual, prec=precision)
                self.assertEqual(running_var_ref, running_var_actual, prec=precision)
|
|
|
|
|
|
if __name__ == '__main__':
    # Entry point: dispatch to the shared test harness from common_utils.
    run_tests()
|