Implement efficient_conv_bn_eval_decomp_graph_transform to handle conv and bn fusion after decomp (#123680)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123680
Approved by: https://github.com/ezyang, https://github.com/youkaichao
Authored by JackCaoG on 2024-04-18 03:06:42 +00:00
Committed by PyTorch MergeBot
parent ca6a0e1348
commit 9ed9b22ec0
2 changed files with 188 additions and 7 deletions
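With this change, a Conv + BN pair that has already been decomposed into aten ops (for example by pre-dispatch tracing) can still be fused for eval mode. A minimal sketch of the flow the new test exercises (the module, shapes, and names here are illustrative, not taken from the diff):

    import torch
    import torch.nn as nn
    from functorch import make_fx
    from torch._dispatch.python import enable_python_dispatcher
    from torch._inductor import config as inductor_config

    class ConvBN(nn.Module):  # illustrative Conv + BN block
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=2)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    inductor_config.efficient_conv_bn_eval_fx_passes = True
    mod = ConvBN().eval()
    inp = torch.rand(4, 3, 32, 32)
    # decompose the nn.Module into aten ops before compiling,
    # as the new decompose_nn_module test branch does
    with enable_python_dispatcher():
        mod = make_fx(mod, pre_dispatch=True)(inp)
    out = torch.compile(mod)(inp)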


@@ -103,7 +103,12 @@ class MultiUserConvOp(nn.Module):
class EfficientConvBNEvalTemplate(TestCase):
@inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
def test_basic(self):
def test_conv_bn_eval(test_class, use_bias, module, sync_bn):
def test_conv_bn_eval(
test_class, use_bias, module, sync_bn, decompose_nn_module
):
from functorch import make_fx
from torch._dispatch.python import enable_python_dispatcher
kwargs = {"kernel_size": 3, "stride": 2} if module[0] != nn.Linear else {}
mod_eager = test_class(
module[0],
@@ -122,7 +127,6 @@ class EfficientConvBNEvalTemplate(TestCase):
mod_optimized
).eval()
torch._dynamo.reset()
mod_optimized = torch.compile(mod_optimized)
inps = [4, 3]
# Conv shape goes from big to small, and ConvTranspose shape goes from small to big
@@ -137,6 +141,11 @@ class EfficientConvBNEvalTemplate(TestCase):
inps += [spatial_d] * 3
inp = torch.rand(inps).to(self.device)
if decompose_nn_module:
with enable_python_dispatcher():
mod_optimized = make_fx(mod_optimized, pre_dispatch=True)(inp)
mod_optimized = torch.compile(mod_optimized)
original_value = counters["inductor"]["efficient_conv_bn_eval"]
optim_eager = torch.optim.SGD(mod_eager.parameters(), lr=1e-3)
@@ -179,10 +188,23 @@ class EfficientConvBNEvalTemplate(TestCase):
]
test_classes = [ConvOp, MultiUserConvOp]
sync_bns = [False, True]
for test_class, use_bias, module, sync_bn in itertools.product(
test_classes, conv_bias, modules, sync_bns
decompose_nn_modules = [False, True]
for (
test_class,
use_bias,
module,
sync_bn,
decompose_nn_module,
) in itertools.product(
test_classes,
conv_bias,
modules,
sync_bns,
decompose_nn_modules,
):
test_conv_bn_eval(test_class, use_bias, module, sync_bn)
test_conv_bn_eval(
test_class, use_bias, module, sync_bn, decompose_nn_module
)
if HAS_CPU and not torch.backends.mps.is_available():


@@ -5,7 +5,12 @@ from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch.func import functional_call
from ..pattern_matcher import CallModuleVarArgs, Match, register_graph_pattern
from ..pattern_matcher import (
CallFunctionVarArgs,
CallModuleVarArgs,
Match,
register_graph_pattern,
)
from .pre_grad import efficient_conv_bn_eval_pass
@@ -15,7 +20,7 @@ def efficient_conv_bn_eval(
):
"""
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
"Efficient ConvBN Blocks for Transfer Learning and Beyond"
It leverages the associative law between convolution and affine transform,
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
It works for Eval mode of ConvBN blocks during validation, and can be used
@@ -70,6 +75,160 @@ def efficient_conv_bn_eval(
return output
def efficient_conv_bn_eval_decomposed(
bn_weight,
bn_bias,
bn_running_mean,
bn_running_var,
bn_eps,
conv: torch._ops.OpOverload,
conv_weight,
conv_bias,
x,
conv_remaining_args,
):
"""
Implementation based on https://arxiv.org/abs/2305.11624
"Efficient ConvBN Blocks for Transfer Learning and Beyond"
It leverages the associative law between convolution and affine transform,
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
It works for Eval mode of ConvBN blocks during validation, and can be used
for **training** as well, but only if one sets `bn.training=False`. It
reduces memory footprint and computation cost, at the cost of slightly
reduced numerical stability.
Args:
    bn_weight, bn_bias: batch norm affine parameters (may be None)
    bn_running_mean, bn_running_var, bn_eps: batch norm statistics
    conv: the aten convolution overload to call
    conv_weight, conv_bias: convolution parameters (bias may be None)
    x: input to the convolution
    conv_remaining_args: remaining positional conv arguments (stride, padding, ...)
"""
assert bn_running_var is not None
# These lines of code are designed to deal with various cases
# like bn without affine transform, and conv without bias
weight_on_the_fly = conv_weight
if conv_bias is not None:
bias_on_the_fly = conv_bias
else:
bias_on_the_fly = torch.zeros_like(bn_running_var)
if bn_weight is None:
    bn_weight = torch.ones_like(bn_running_var)
if bn_bias is None:
    bn_bias = torch.zeros_like(bn_running_var)
# shape of [C_out, 1, 1, 1] in Conv2d
target_shape = [-1] + [1] * (conv_weight.ndim - 1)
if "conv_transpose" in conv.__str__():
# for transposed conv, the C_out dimension should at index 1.
target_shape[:2] = [target_shape[1], target_shape[0]]
weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape)
# shape of [C_out, 1, 1, 1] in Conv2d
coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
# shape of [C_out, C_in, k, k] in Conv2d
weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
# shape of [C_out] in Conv2d
bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() * (
    bias_on_the_fly - bn_running_mean
)
return conv(*((x, weight_on_the_fly, bias_on_the_fly) + conv_remaining_args))
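As a quick sanity check of the rewrite above, the fused call should match a plain conv followed by eval-mode batch norm (a minimal sketch; the shapes and tolerance are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.rand(2, 3, 16, 16)
    conv_w = torch.rand(8, 3, 3, 3)
    conv_b = torch.rand(8)
    bn_w, bn_b = torch.rand(8), torch.rand(8)
    mean, var = torch.rand(8), torch.rand(8) + 0.1  # var must be positive
    eps = 1e-5

    # fused path: the function defined above, with no extra conv args
    fused = efficient_conv_bn_eval_decomposed(
        bn_w, bn_b, mean, var, eps,
        torch.ops.aten.conv2d.default, conv_w, conv_b, x, (),
    )
    # reference path: conv2d followed by batch_norm in eval mode
    ref = F.batch_norm(
        F.conv2d(x, conv_w, conv_b), mean, var,
        weight=bn_w, bias=bn_b, training=False, eps=eps,
    )
    torch.testing.assert_close(fused, ref, rtol=1e-4, atol=1e-4)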
@register_graph_pattern(
CallFunctionVarArgs(
[
torch.ops.aten.batch_norm.default,
]
),
pass_dict=efficient_conv_bn_eval_pass,
extra_check=lambda match: not inductor_config.freezing
and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
bn_node = match.nodes[0]
graph = match.graph
assert len(bn_node.args) == 9
# We can only use efficient conv-bn for eval mode with track_running_stats
# bn_node.args[-4] is `training`
if bn_node.args[-4]:
return
# Check if the input is Conv
input_node = bn_node.args[0]
if input_node.op != "call_function": # type: ignore[union-attr]
return
input_fn = input_node.target # type: ignore[arg-type, union-attr]
supported_convs = [
torch.ops.aten.linear.default,
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv3d.default,
torch.ops.aten.conv_transpose1d.default,
torch.ops.aten.conv_transpose2d.input,
torch.ops.aten.conv_transpose3d.input,
]
if not any(input_fn is op for op in supported_convs):
return
conv_node = input_node
# Output of conv is used by other nodes, cannot optimize
if len(conv_node.users) > 1: # type: ignore[union-attr]
return
counters["inductor"]["efficient_conv_bn_eval"] += 1
with graph.inserting_before(bn_node):
# prepare args for the fused function
bn_weight = bn_node.args[1]
bn_bias = bn_node.args[2]
bn_running_mean = bn_node.args[3]
bn_running_var = bn_node.args[4]
bn_eps = bn_node.args[7]
assert len(conv_node.args) >= 2 # type: ignore[union-attr]
conv_input = conv_node.args[0] # type: ignore[union-attr]
conv_weight = conv_node.args[1] # type: ignore[union-attr]
conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr]
conv_remaining_args = conv_node.args[3:]  # type: ignore[union-attr]
args = (
bn_weight,
bn_bias,
bn_running_mean,
bn_running_var,
bn_eps,
conv_node.target, # type: ignore[union-attr]
conv_weight,
conv_bias,
conv_input,
conv_remaining_args,
)
# create a new node
new_node = graph.create_node(
op="call_function",
target=efficient_conv_bn_eval_decomposed,
args=args,
name="efficient_conv_bn_eval",
)
# this node replaces the original conv + bn, and therefore
# should replace the uses of bn_node
bn_node.replace_all_uses_with(new_node)
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
return
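For reference, pre-dispatch tracing of an eval-mode Conv2d + BatchNorm2d pair should produce exactly the call sequence this pattern targets (a sketch; the module is illustrative):

    import torch
    import torch.nn as nn
    from functorch import make_fx
    from torch._dispatch.python import enable_python_dispatcher

    mod = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
    with enable_python_dispatcher():
        gm = make_fx(mod, pre_dispatch=True)(torch.rand(1, 3, 16, 16))
    # the graph should contain aten.conv2d.default feeding
    # aten.batch_norm.default, the two nodes matched and fused above
    print(gm.graph)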
@register_graph_pattern(
CallModuleVarArgs(
[