mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[TorchTidy] Add pattern to detect if bias is enabled in conv2d followed by batchnorm2d (#81941)"
This reverts commit 615f2fda4f40be098da16be075d192f36820353f. Reverted https://github.com/pytorch/pytorch/pull/81941 on behalf of https://github.com/ZainRizvi due to New test failed on ROCm builds
This commit is contained in:
@ -96,13 +96,6 @@ class Pattern:
|
||||
prev_events, _ = self.siblings_of(event)
|
||||
return prev_events[-1] if prev_events else None
|
||||
|
||||
def go_up_until(self, event: _ProfilerEvent, predicate):
|
||||
if not event:
|
||||
return None
|
||||
while event.parent and not predicate(event):
|
||||
event = event.parent
|
||||
return event
|
||||
|
||||
|
||||
# Patterns
|
||||
|
||||
@ -434,46 +427,6 @@ class GradNotSetToNonePattern(Pattern):
|
||||
return False
|
||||
|
||||
|
||||
class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
|
||||
'''
|
||||
This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d.
|
||||
Bias doesn't do anything when followed by batchnorm.
|
||||
|
||||
Pattern:
|
||||
nn.Module: Conv2d | nn.Module: BatchNorm2d
|
||||
...
|
||||
aten::_convolution
|
||||
... | aten::add_
|
||||
# This pattern only works when using CUDA
|
||||
|
||||
Algorithm:
|
||||
String match
|
||||
'''
|
||||
|
||||
def __init__(self, prof: profile, should_benchmark: bool = False):
|
||||
super().__init__(prof, should_benchmark)
|
||||
self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
|
||||
self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
|
||||
|
||||
def match(self, event: _ProfilerEvent):
|
||||
if event.name() != "aten::_convolution":
|
||||
return False
|
||||
if not event.children:
|
||||
return False
|
||||
event = event.children[-1]
|
||||
if event.name() != "aten::add_":
|
||||
return False
|
||||
# This means bias=True
|
||||
event = self.go_up_until(
|
||||
event, lambda e: e.name().startswith("nn.Module: Conv2d"))
|
||||
if not event:
|
||||
return False
|
||||
event = self.next_of(event)
|
||||
if not event:
|
||||
return False
|
||||
return event.name().startswith("nn.Module: BatchNorm2d")
|
||||
|
||||
|
||||
def source_code_location(event: _ProfilerEvent):
|
||||
while event:
|
||||
if event_type(event) == _EventType.PyCall or event_type(
|
||||
@ -523,8 +476,7 @@ def report_all_anti_patterns(prof, should_benchmark: bool = False):
|
||||
FP32MatMulPattern(prof, should_benchmark),
|
||||
OptimizerSingleTensorPattern(prof, should_benchmark),
|
||||
SynchronizedDataLoaderPattern(prof, should_benchmark),
|
||||
GradNotSetToNonePattern(prof, should_benchmark),
|
||||
Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark)
|
||||
GradNotSetToNonePattern(prof, should_benchmark)
|
||||
]
|
||||
reported = set()
|
||||
summaries = []
|
||||
|
Reference in New Issue
Block a user