mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Patch the _is_conv_node function (#153749)"
This reverts commit c985cec5b2545d46af682d486b18866eee5dffd5. Reverted https://github.com/pytorch/pytorch/pull/153749 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/153749#issuecomment-2905504697))
This commit is contained in:
@ -3184,11 +3184,11 @@ class TestHelperModules:
|
||||
return x
|
||||
|
||||
class ConvWithBNRelu(torch.nn.Module):
|
||||
def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
|
||||
def __init__(self, relu, dim=2, bn=True, bias=True):
|
||||
super().__init__()
|
||||
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
|
||||
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
|
||||
self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
|
||||
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
|
||||
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
|
||||
self.conv = convs[dim](3, 3, 3, bias=bias)
|
||||
|
||||
if bn:
|
||||
self.bn = bns[dim](3)
|
||||
|
Reference in New Issue
Block a user