Revert "[TorchTidy] Add pattern to detect if bias is enabled in conv2d followed by batchnorm2d (#81941)"

This reverts commit 615f2fda4f40be098da16be075d192f36820353f.

Reverted https://github.com/pytorch/pytorch/pull/81941 on behalf of https://github.com/ZainRizvi due to New test failed on ROCm builds
This commit is contained in:
PyTorch MergeBot
2022-07-28 16:14:18 +00:00
parent 688b971876
commit 2fe73164b6
2 changed files with 2 additions and 68 deletions

View File

@ -96,13 +96,6 @@ class Pattern:
prev_events, _ = self.siblings_of(event)
return prev_events[-1] if prev_events else None
def go_up_until(self, event: _ProfilerEvent, predicate):
if not event:
return None
while event.parent and not predicate(event):
event = event.parent
return event
# Patterns
@ -434,46 +427,6 @@ class GradNotSetToNonePattern(Pattern):
return False
class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
'''
This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d.
Bias doesn't do anything when followed by batchnorm.
Pattern:
nn.Module: Conv2d | nn.Module: BatchNorm2d
...
aten::_convolution
... | aten::add_
# This pattern only works when using CUDA
Algorithm:
String match
'''
def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
def match(self, event: _ProfilerEvent):
if event.name() != "aten::_convolution":
return False
if not event.children:
return False
event = event.children[-1]
if event.name() != "aten::add_":
return False
# This means bias=True
event = self.go_up_until(
event, lambda e: e.name().startswith("nn.Module: Conv2d"))
if not event:
return False
event = self.next_of(event)
if not event:
return False
return event.name().startswith("nn.Module: BatchNorm2d")
def source_code_location(event: _ProfilerEvent):
while event:
if event_type(event) == _EventType.PyCall or event_type(
@ -523,8 +476,7 @@ def report_all_anti_patterns(prof, should_benchmark: bool = False):
FP32MatMulPattern(prof, should_benchmark),
OptimizerSingleTensorPattern(prof, should_benchmark),
SynchronizedDataLoaderPattern(prof, should_benchmark),
GradNotSetToNonePattern(prof, should_benchmark),
Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark)
GradNotSetToNonePattern(prof, should_benchmark)
]
reported = set()
summaries = []