# Owner(s): ["module: inductor"]

import contextlib
import operator
from collections import defaultdict

import torch
import torch._inductor.pattern_matcher as pattern_matcher
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
from torch._inductor.lowering import lowerings as L
from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CPU


@config.patch({"freezing": True})
class TestCustomPassBase(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return tuple(clone(x) for x in inputs)

    def _test_common(
        self,
        mod,
        inputs,
        matcher_count,
        matcher_nodes,
        atol=1e-5,
        rtol=1.3e-6,
    ):
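        # run eager and compiled once each, verify numerics, then verify how
        # many pattern-matcher replacements fired and how many nodes they hit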
        counters.clear()
        maybe_autocast = contextlib.nullcontext()
        with torch.no_grad(), maybe_autocast:
            clone_inputs = self._clone_inputs(inputs)
            expected = mod(*inputs)
            actual = torch.compile(mod)(*clone_inputs)
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
            self.assertEqual(
                counters["inductor"]["pattern_matcher_count"], matcher_count
            )
            self.assertEqual(
                counters["inductor"]["pattern_matcher_nodes"],
                matcher_nodes,
            )


aten = torch.ops.aten
mkldnn = torch.ops.mkldnn
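

# A deliberately simple graph pass: rewrite every aten.cos into aten.sin.
# The joint-graph tests below use it to confirm the custom pass hooks run.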
def change_cos_pass(graph):
    for node in graph.nodes:
        if node.op == "call_function" and node.target == aten.cos.default:
            node.target = aten.sin.default


class TestPostGradCustomPrePostPass(TestCustomPassBase):
    # Reuse the conv+relu pattern from mkldnn fusion's pattern matcher
    # (torch/_inductor/fx_passes/mkldnn_fusion.py) and register it with a
    # custom post_grad pass.
    def _register_mkldnn_conv_relu_fusion(self, custom_pass_dict):
        # pattern
        def _mkldnn_conv_relu_pattern():
            return CallFunction(
                aten.relu,
                CallFunction(
                    mkldnn._convolution_pointwise.default,
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
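                    # _users=1: match only when the relu is the convolution's
                    # sole consumer, so fusing cannot orphan another use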
                    _users=1,
                ),
            )

        # helper for registering the pattern with the matcher
        def _register_fusion_lowering(pattern, custom_pass_dict):
            def dummy_check(m):
                return True

            def register_custom_lowering_pattern(
                pattern, extra_check, custom_pass_dict
            ):
                return pattern_matcher.register_lowering_pattern(
                    pattern, extra_check, pass_dict=custom_pass_dict
                )

            @register_custom_lowering_pattern(pattern, dummy_check, custom_pass_dict)
            def fn(match, *args, **kwargs):
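                # the matched args follow mkldnn._convolution_pointwise's
                # schema; drop the trailing (attr, scalars, algorithm) triple
                # and splice in "relu" so the lowering emits a fused conv+relu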
                computation_args = list(args)[:-3] + ["relu", [], ""]
                return L[mkldnn._convolution_pointwise.default](*computation_args)

            return fn

        _register_fusion_lowering(_mkldnn_conv_relu_pattern(), custom_pass_dict)

    # custom post grad pass
    class _CustomPass(PatternMatcherPass, CustomGraphPass):
        def __init__(self) -> None:
            super().__init__()

        def __call__(self, g: torch.fx.graph.Graph):
            self.apply(g)

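        # CustomGraphPass implementations must provide a stable uuid so the
        # compile cache can key on the pass; hashing this file invalidates
        # cached artifacts whenever the pass implementation changes.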
        def uuid(self) -> bytes:
            return get_hash_for_files((__file__,))

    # model under test
    class _ConvReLU(torch.nn.Module):
        def __init__(self, ic, oc):
            super().__init__()
            self.conv = torch.nn.Conv2d(ic, oc, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            x1 = self.conv(x)
            return x1.relu()
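
    # change_cos_pass should rewrite f's cos chain into sin on the joint
    # graph, making compiled f numerically match the eager sin reference g.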
    def test_custom_joint_pass_pre(self):
        with config.patch(joint_custom_pre_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_joint_pass_post(self):
        with config.patch(joint_custom_post_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))
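
    # With the stock pattern matcher disabled, any counter activity must come
    # from the custom pass: one conv+relu match spanning two nodes, plus one
    # single-node match from conv weight prepacking under freezing.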
    def test_custom_pre_pass(self):
        with config.patch(
            # leave only the custom pass in post_grad_passes()
            pattern_matcher=False,
            post_grad_custom_pre_pass=self._CustomPass(),
            # define the pattern match as a custom post_grad opt pass
            post_grad_custom_post_pass=None,
        ):
            # register the mkldnn conv+relu fusion on the custom matcher
            self._register_mkldnn_conv_relu_fusion(config.post_grad_custom_pre_pass)

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)

            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_post_pass(self):
        with config.patch(
            # leave only the custom pass in post_grad_passes()
            pattern_matcher=False,
            # define the pattern match as a custom post_grad opt pass
            post_grad_custom_pre_pass=None,
            post_grad_custom_post_pass=self._CustomPass(),
        ):
            # register the mkldnn conv+relu fusion on the custom matcher
            self._register_mkldnn_conv_relu_fusion(config.post_grad_custom_post_pass)

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)

            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_pre_grad_pass(self):
        saved_graph = [None]

        def merge_mm_shared_rhs(graph: fx.Graph):
            """
            Bad proof-of-concept of merging mms that share an RHS,
            i.e. [mm(x, W), mm(x2, W)] => mm(cat(x, x2), W).split().

            It isn't actually safe, for a couple of reasons. For example, it
            doesn't handle the case where the LHS inputs depend on each other.
            """
            saved_graph[0] = graph
            matmuls = [n for n in graph.nodes if n.target == torch.mm]
            rhs_vals = defaultdict(set)
            for m in matmuls:
                rhs_vals[m.args[1]].add(m)

            order = {}
            for idx, n in enumerate(graph.nodes):
                order[n] = idx

            for rhs, matmuls in rhs_vals.items():
                if len(matmuls) == 1:
                    continue
                matmuls = sorted(matmuls, key=lambda x: order[x])
                with graph.inserting_before(matmuls[0]):
                    lhs_vals = [m.args[0] for m in matmuls]
                    new_cat = graph.create_node(
                        "call_function", torch.cat, args=(lhs_vals, 0)
                    )
                    new_mm = graph.create_node(
                        "call_function", torch.mm, args=(new_cat, rhs)
                    )
                    split_vals = graph.create_node(
                        "call_function",
                        torch.split,
                        args=(
                            new_mm,
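                            # pre-grad nodes carry dynamo's "example_value"
                            # meta, which gives each LHS's concrete row count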
                            [l.meta["example_value"].shape[0] for l in lhs_vals],
                        ),
                    )
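                    # repoint each original mm at its slice of the split so
                    # every downstream user is rewired automatically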
                    for idx, m in enumerate(matmuls):
                        m.target = operator.getitem
                        m.args = (split_vals, idx)
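
        # pre_grad_custom_pass runs on the dynamo-captured FX graph, before
        # aot_autograd decomposes ops, so torch.mm nodes are still torch.mm.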
        @config.patch(pre_grad_custom_pass=merge_mm_shared_rhs)
        def inner_test():
            @torch.compile
            def f(W, nested_seqs):
                outs = [torch.mm(s, W) for s in nested_seqs]
                return outs

            W = torch.randn(16, 16, dtype=torch.bfloat16)
            nested_seqs = [
                torch.randn(l, 16, dtype=torch.bfloat16) for l in [4, 8, 5, 3]
            ]

            f(W, nested_seqs)
            assert saved_graph[0] is not None
            matmuls = [n for n in saved_graph[0].nodes if n.target == torch.mm]
            assert len(matmuls) == 1

        inner_test()


if __name__ == "__main__":
    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
        run_tests()