mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland][pytorch] Patch the _is_conv_node function (#154473)
Summary: Add the conv padding ops in pytorch, the corresponding pr in torch ao is https://github.com/pytorch/ao/pull/2257 Test Plan: ``` buck test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_conv_padding_bn_relu (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)' ``` Differential Revision: D75494468 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154473 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
d6cb0fe576
commit
9371491529
@ -3183,12 +3183,13 @@ class TestHelperModules:
|
||||
x = self.adaptive_avg_pool2d(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvWithBNRelu(torch.nn.Module):
|
||||
def __init__(self, relu, dim=2, bn=True, bias=True):
|
||||
def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
|
||||
super().__init__()
|
||||
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
|
||||
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
|
||||
self.conv = convs[dim](3, 3, 3, bias=bias)
|
||||
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
|
||||
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
|
||||
self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
|
||||
|
||||
if bn:
|
||||
self.bn = bns[dim](3)
|
||||
|
||||
Reference in New Issue
Block a user