Implement efficient_conv_bn_eval_decomp_graph_transform to handle conv and bn fusion after decomp (#123680)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123680 Approved by: https://github.com/ezyang, https://github.com/youkaichao
Committed by: PyTorch MergeBot
Parent: ca6a0e1348
Commit: 9ed9b22ec0
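For context before the diff: in eval mode, the BatchNorm that follows a convolution is just a per-output-channel affine transform, so it can be folded into the convolution's weight and bias and only a single convolution needs to run. Below is a minimal, self-contained sketch of that folding for Conv2d; the helper name `fold_bn_into_conv` and the toy shapes are illustrative only and are not part of this PR.

```python
import torch
import torch.nn.functional as F

def fold_bn_into_conv(conv_weight, conv_bias, gamma, beta, mean, var, eps):
    # per-output-channel scale gamma / sqrt(var + eps), reshaped to broadcast
    # over the [C_out, C_in, k, k] conv weight
    coeff = gamma * torch.rsqrt(var + eps)
    fused_weight = conv_weight * coeff.reshape(-1, 1, 1, 1)
    fused_bias = beta + coeff * (conv_bias - mean)
    return fused_weight, fused_bias

# toy check: conv followed by eval-mode batch_norm vs. a single fused conv
x = torch.randn(2, 3, 8, 8)
w, b = torch.randn(4, 3, 3, 3), torch.randn(4)
gamma, beta = torch.randn(4), torch.randn(4)
mean, var, eps = torch.randn(4), torch.rand(4) + 0.1, 1e-5

ref = F.batch_norm(F.conv2d(x, w, b), mean, var, gamma, beta, training=False, eps=eps)
fw, fb = fold_bn_into_conv(w, b, gamma, beta, mean, var, eps)
torch.testing.assert_close(F.conv2d(x, fw, fb), ref, rtol=1e-4, atol=1e-4)
```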
@@ -103,7 +103,12 @@ class MultiUserConvOp(nn.Module):
 class EfficientConvBNEvalTemplate(TestCase):
     @inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
     def test_basic(self):
-        def test_conv_bn_eval(test_class, use_bias, module, sync_bn):
+        def test_conv_bn_eval(
+            test_class, use_bias, module, sync_bn, decompose_nn_module
+        ):
+            from functorch import make_fx
+            from torch._dispatch.python import enable_python_dispatcher
+
             kwargs = {"kernel_size": 3, "stride": 2} if module[0] != nn.Linear else {}
             mod_eager = test_class(
                 module[0],
@@ -122,7 +127,6 @@ class EfficientConvBNEvalTemplate(TestCase):
                 mod_optimized
             ).eval()
             torch._dynamo.reset()
-            mod_optimized = torch.compile(mod_optimized)
 
             inps = [4, 3]
             # Conv shape goes from big to small, and ConvTranspose shape goes from small to big
@@ -137,6 +141,11 @@ class EfficientConvBNEvalTemplate(TestCase):
                 inps += [spatial_d] * 3
             inp = torch.rand(inps).to(self.device)
 
+            if decompose_nn_module:
+                with enable_python_dispatcher():
+                    mod_optimized = make_fx(mod_optimized, pre_dispatch=True)(inp)
+            mod_optimized = torch.compile(mod_optimized)
+
             original_value = counters["inductor"]["efficient_conv_bn_eval"]
 
             optim_eager = torch.optim.SGD(mod_eager.parameters(), lr=1e-3)
@@ -179,10 +188,23 @@ class EfficientConvBNEvalTemplate(TestCase):
         ]
         test_classes = [ConvOp, MultiUserConvOp]
         sync_bns = [False, True]
-        for test_class, use_bias, module, sync_bn in itertools.product(
-            test_classes, conv_bias, modules, sync_bns
+        decompose_nn_modules = [False, True]
+        for (
+            test_class,
+            use_bias,
+            module,
+            sync_bn,
+            decompose_nn_module,
+        ) in itertools.product(
+            test_classes,
+            conv_bias,
+            modules,
+            sync_bns,
+            decompose_nn_modules,
         ):
-            test_conv_bn_eval(test_class, use_bias, module, sync_bn)
+            test_conv_bn_eval(
+                test_class, use_bias, module, sync_bn, decompose_nn_module
+            )
 
 
 if HAS_CPU and not torch.backends.mps.is_available():
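The new `decompose_nn_module` branch above traces the module to a decomposed, aten-level graph before compiling it. A minimal, hedged sketch of that path follows; the toy `nn.Sequential` model and shapes are stand-ins for the test's `mod_optimized`.

```python
import torch
import torch.nn as nn
from functorch import make_fx
from torch._dispatch.python import enable_python_dispatcher

# toy eval-mode Conv+BN model standing in for the test's mod_optimized
mod = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8)).eval()
inp = torch.rand(4, 3, 32, 32)

with enable_python_dispatcher():
    # pre_dispatch tracing keeps aten-level conv / batch_norm nodes in the graph,
    # i.e. the decomposed form the new pattern in this PR matches against
    mod_decomposed = make_fx(mod, pre_dispatch=True)(inp)

compiled = torch.compile(mod_decomposed)
out = compiled(inp)
```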
@@ -5,7 +5,12 @@ from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch.func import functional_call
 
-from ..pattern_matcher import CallModuleVarArgs, Match, register_graph_pattern
+from ..pattern_matcher import (
+    CallFunctionVarArgs,
+    CallModuleVarArgs,
+    Match,
+    register_graph_pattern,
+)
 
 from .pre_grad import efficient_conv_bn_eval_pass
 
@@ -15,7 +20,7 @@ def efficient_conv_bn_eval(
 ):
     """
     Implementation based on https://arxiv.org/abs/2305.11624
-    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
+    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
     It leverages the associative law between convolution and affine transform,
     i.e., normalize (weight conv feature) = (normalize weight) conv feature.
     It works for Eval mode of ConvBN blocks during validation, and can be used
@@ -70,6 +75,160 @@ def efficient_conv_bn_eval(
     return output
 
 
+def efficient_conv_bn_eval_decomposed(
+    bn_weight,
+    bn_bias,
+    bn_running_mean,
+    bn_running_var,
+    bn_eps,
+    conv: torch._ops.OpOverload,
+    conv_weight,
+    conv_bias,
+    x,
+    conv_remainging_args,
+):
+    """
+    Implementation based on https://arxiv.org/abs/2305.11624
+    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
+    It leverages the associative law between convolution and affine transform,
+    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
+    It works for Eval mode of ConvBN blocks during validation, and can be used
+    for **training** as well, but only if one sets `bn.training=False`. It
+    reduces memory footprint and computation cost, at the cost of slightly
+    reduced numerical stability.
+    Args:
+    """
+    assert bn_running_var is not None
+
+    # These lines of code are designed to deal with various cases
+    # like bn without affine transform, and conv without bias
+    weight_on_the_fly = conv_weight
+    if conv_bias is not None:
+        bias_on_the_fly = conv_bias
+    else:
+        bias_on_the_fly = torch.zeros_like(bn_running_var)
+
+    if bn_weight is not None:
+        bn_weight = bn_weight
+    else:
+        bn_weight = torch.ones_like(bn_running_var)
+
+    if bn_bias is not None:
+        bn_bias = bn_bias
+    else:
+        bn_bias = torch.zeros_like(bn_running_var)
+
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    target_shape = [-1] + [1] * (conv_weight.ndim - 1)
+    if "conv_transpose" in conv.__str__():
+        # for transposed conv, the C_out dimension should at index 1.
+        target_shape[:2] = [target_shape[1], target_shape[0]]
+    weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape)
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
+
+    # shape of [C_out, C_in, k, k] in Conv2d
+    weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
+    # shape of [C_out] in Conv2d
+    bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (
+        bias_on_the_fly - bn_running_mean
+    )
+
+    input = x
+    return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args))
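As an editor's aside on the function above: the `weight_on_the_fly` / `bias_on_the_fly` folding is the associative-law identity from the docstring written out, where γ, β, μ, σ², ε stand for `bn_weight`, `bn_bias`, `bn_running_mean`, `bn_running_var`, `bn_eps`, and ⊙ broadcasts over the output-channel dimension:

```latex
\mathrm{BN}\bigl(W * x + b\bigr)
  = \gamma \odot \frac{(W * x + b) - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  = \underbrace{\Bigl(\tfrac{\gamma}{\sqrt{\sigma^{2}+\epsilon}} \odot W\Bigr)}_{\text{weight\_on\_the\_fly}} * x
    + \underbrace{\beta + \tfrac{\gamma}{\sqrt{\sigma^{2}+\epsilon}} \odot (b - \mu)}_{\text{bias\_on\_the\_fly}}
```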
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(
+        [
+            torch.ops.aten.batch_norm.default,
+        ]
+    ),
+    pass_dict=efficient_conv_bn_eval_pass,
+    extra_check=lambda match: not inductor_config.freezing
+    and inductor_config.efficient_conv_bn_eval_fx_passes,
+)
+def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
+    bn_node = match.nodes[0]
+    graph = match.graph
+    assert len(bn_node.args) == 9
+
+    # We can only use efficient conv-bn for eval mode with track_running_stats
+    # bn_node.args is `training`
+    if bn_node.args[-4]:
+        return
+
+    # Check if the input is Conv
+    input_node = bn_node.args[0]
+
+    if input_node.op != "call_function":  # type: ignore[union-attr]
+        return
+
+    input_fn = input_node.target  # type: ignore[arg-type, union-attr]
+    supported_convs = [
+        torch.ops.aten.linear.default,
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv_transpose3d.input,
+    ]
+
+    if not any(input_fn is cls for cls in supported_convs):
+        return
+
+    conv_node = input_node
+    # Output of conv is used by other nodes, cannot optimize
+    if len(conv_node.users) > 1:  # type: ignore[union-attr]
+        return
+
+    counters["inductor"]["efficient_conv_bn_eval"] += 1
+
+    with graph.inserting_before(bn_node):
+        # prepare args for the fused function
+        bn_weight = bn_node.args[1]
+        bn_bias = bn_node.args[2]
+        bn_running_mean = bn_node.args[3]
+        bn_running_var = bn_node.args[4]
+        bn_eps = bn_node.args[7]
+        assert len(conv_node.args) >= 2  # type: ignore[union-attr]
+        conv_input = conv_node.args[0]  # type: ignore[union-attr]
+        conv_weight = conv_node.args[1]  # type: ignore[union-attr]
+        conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None  # type: ignore[union-attr]
+        conv_remainging_args = conv_node.args[3:]  # type: ignore[union-attr]
+        args = (
+            bn_weight,
+            bn_bias,
+            bn_running_mean,
+            bn_running_var,
+            bn_eps,
+            conv_node.target,  # type: ignore[union-attr]
+            conv_weight,
+            conv_bias,
+            conv_input,
+            conv_remainging_args,
+        )
+
+        # create a new node
+        new_node = graph.create_node(
+            op="call_function",
+            target=efficient_conv_bn_eval_decomposed,
+            args=args,
+            name="efficient_conv_bn_eval",
+        )
+
+    # this node replaces the original conv + bn, and therefore
+    # should replace the uses of bn_node
+    bn_node.replace_all_uses_with(new_node)
+    # take care of the deletion order:
+    # delete bn_node first, and then conv_node
+    graph.erase_node(bn_node)
+    graph.erase_node(conv_node)
+
+    return
+
+
 @register_graph_pattern(
     CallModuleVarArgs(
         [
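To round out the diff, here is a hedged sketch of how the new pass gets exercised end to end, mirroring the test setup above: enable the inductor config flag that the `extra_check` requires, compile an eval-mode Conv+BN model, and read back the counter the transform bumps. The toy model and shapes are illustrative.

```python
import torch
import torch.nn as nn
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config

# the extra_check above requires this flag (and freezing left off, its default)
inductor_config.efficient_conv_bn_eval_fx_passes = True

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8)).eval()
compiled = torch.compile(model)

before = counters["inductor"]["efficient_conv_bn_eval"]
out = compiled(torch.rand(4, 3, 32, 32))
# the counter is bumped once per conv+bn pair the transform fuses
print(counters["inductor"]["efficient_conv_bn_eval"] - before)
```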