# pytorch/test/inductor/test_custom_post_grad_passes.py
# Owner(s): ["module: inductor"]
import contextlib
import operator
from collections import defaultdict

import torch
import torch._inductor.pattern_matcher as pattern_matcher
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
from torch._inductor.lowering import lowerings as L
from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CPU


@config.patch({"freezing": True})
class TestCustomPassBase(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return tuple(clone(x) for x in inputs)

    # Compile `mod`, compare its output against eager mode, then check the
    # inductor pattern-matcher counters against the expected match counts.
    def _test_common(
        self,
        mod,
        inputs,
        matcher_count,
        matcher_nodes,
        atol=1e-5,
        rtol=1.3e-6,
    ):
        counters.clear()
        maybe_autocast = contextlib.nullcontext()
        with torch.no_grad(), maybe_autocast:
            clone_inputs = self._clone_inputs(inputs)
            expected = mod(*inputs)
            actual = torch.compile(mod)(*clone_inputs)
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
            self.assertEqual(
                counters["inductor"]["pattern_matcher_count"], matcher_count
            )
            self.assertEqual(
                counters["inductor"]["pattern_matcher_nodes"],
                matcher_nodes,
            )


aten = torch.ops.aten
mkldnn = torch.ops.mkldnn


# Trivial custom graph pass: rewrite every aten.cos call into aten.sin in place.
def change_cos_pass(graph):
    for node in graph.nodes:
        if node.op == "call_function" and node.target == aten.cos.default:
            node.target = aten.sin.default
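
# Usage sketch (mirrors the joint-pass tests below): once installed via
# config.patch, torch.compile applies the pass to the joint graph, so
# compiling f(x) = x.cos() produces the result of x.sin():
#
#   with config.patch(joint_custom_pre_pass=change_cos_pass):
#       y = torch.compile(lambda x: x.cos())(x)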


class TestPostGradCustomPrePostPass(TestCustomPassBase):
    # Reuse the mkldnn conv+relu fusion pattern from the built-in
    # pattern_matcher (torch/_inductor/fx_passes/mkldnn_fusion.py) and apply
    # it through custom post_grad_passes.
    def _register_mkldnn_conv_relu_fusion(self, custom_pass_dict):
        # Pattern: relu applied to the output of a single-use
        # mkldnn._convolution_pointwise node.
        def _mkldnn_conv_relu_pattern():
            return CallFunction(
                aten.relu,
                CallFunction(
                    mkldnn._convolution_pointwise.default,
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    _users=1,
                ),
            )

        # Helper that registers the pattern as a lowering on the given
        # custom pass dict.
        def _register_fusion_lowering(pattern, custom_pass_dict):
            def dummy_check(m):
                return True

            def register_custom_lowering_pattern(
                pattern, extra_check, custom_pass_dict
            ):
                return pattern_matcher.register_lowering_pattern(
                    pattern, extra_check, pass_dict=custom_pass_dict
                )

            @register_custom_lowering_pattern(pattern, dummy_check, custom_pass_dict)
            def fn(match, *args, **kwargs):
                computation_args = list(args)[:-3] + ["relu", [], ""]
                return L[mkldnn._convolution_pointwise.default](*computation_args)

            return fn

        _register_fusion_lowering(_mkldnn_conv_relu_pattern(), custom_pass_dict)
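
    # After registration, any relu(mkldnn._convolution_pointwise(...)) match
    # in the post-grad graph is lowered to a single fused
    # _convolution_pointwise call. The handler above drops the last three
    # matched arguments and substitutes ("relu", [], ""): by assumption the
    # attr, scalars, and algorithm parameters, following the built-in mkldnn
    # fusion pass this mirrors.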

    # Custom post-grad pass: a PatternMatcherPass that also satisfies the
    # CustomGraphPass interface required by the config hooks.
    class _CustomPass(PatternMatcherPass, CustomGraphPass):
        def __init__(self) -> None:
            super().__init__()

        def __call__(self, g: torch.fx.graph.Graph):
            self.apply(g)

        def uuid(self) -> bytes:
            return get_hash_for_files((__file__,))
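
    # uuid() feeds inductor's compilation cache: hashing this file makes the
    # cache key change whenever the pass implementation changes. A minimal
    # standalone usage sketch (an assumption mirroring the tests below):
    #
    #   with config.patch(post_grad_custom_post_pass=_CustomPass()):
    #       torch.compile(model)(*inputs)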

    # Test model: conv2d followed by relu, the shape the fusion above matches.
    class _ConvReLU(torch.nn.Module):
        def __init__(self, ic, oc):
            super().__init__()
            self.conv = torch.nn.Conv2d(ic, oc, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            x1 = self.conv(x)
            return x1.relu()

    def test_custom_joint_pass_pre(self):
        with config.patch(joint_custom_pre_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_joint_pass_post(self):
        with config.patch(joint_custom_post_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_pre_pass(self):
        with config.patch(
            # disable the built-in pattern matcher so only the custom pass
            # runs in post_grad_passes()
            pattern_matcher=False,
            # install the pattern match as a custom post-grad pre pass
            post_grad_custom_pre_pass=self._CustomPass(),
            post_grad_custom_post_pass=None,
        ):
            # register the mkldnn conv+relu fusion on the custom pass
            self._register_mkldnn_conv_relu_fusion(config.post_grad_custom_pre_pass)

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)
            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_post_pass(self):
        with config.patch(
            # disable the built-in pattern matcher so only the custom pass
            # runs in post_grad_passes()
            pattern_matcher=False,
            post_grad_custom_pre_pass=None,
            # install the pattern match as a custom post-grad post pass
            post_grad_custom_post_pass=self._CustomPass(),
        ):
            # register the mkldnn conv+relu fusion on the custom pass
            self._register_mkldnn_conv_relu_fusion(config.post_grad_custom_post_pass)

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)
            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_pre_grad_pass(self):
        saved_graph = [None]

        def merge_mm_shared_rhs(graph: fx.Graph):
            """
            Bad POC of merging mms that share an RHS,
            i.e. [mm(x, W), mm(x2, W)] => mm(cat(x, x2), W).split().

            This isn't actually safe, for a couple of reasons. For example, it
            doesn't handle the case where the LHS inputs depend on each other.
            """
            saved_graph[0] = graph
            matmuls = [n for n in graph.nodes if n.target == torch.mm]
            rhs_vals = defaultdict(set)
            for m in matmuls:
                rhs_vals[m.args[1]].add(m)

            order = {}
            for idx, n in enumerate(graph.nodes):
                order[n] = idx

            for rhs, matmuls in rhs_vals.items():
                if len(matmuls) == 1:
                    continue
                matmuls = sorted(matmuls, key=lambda x: order[x])
                with graph.inserting_before(matmuls[0]):
                    # concatenate all LHS operands, run one big mm, then split
                    # the result back into the original row counts
                    lhs_vals = [m.args[0] for m in matmuls]
                    new_cat = graph.create_node(
                        "call_function", torch.cat, args=(lhs_vals, 0)
                    )
                    new_mm = graph.create_node(
                        "call_function", torch.mm, args=(new_cat, rhs)
                    )
                    split_vals = graph.create_node(
                        "call_function",
                        torch.split,
                        args=(
                            new_mm,
                            [l.meta["example_value"].shape[0] for l in lhs_vals],
                        ),
                    )
                # rewrite each original mm into a getitem on the split output
                for idx, m in enumerate(matmuls):
                    m.target = operator.getitem
                    m.args = (split_vals, idx)
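
        # Effect of the pass, sketched for two matmuls that share W (a worked
        # illustration, not executed here):
        #
        #   y1 = torch.mm(x1, W); y2 = torch.mm(x2, W)
        # becomes
        #   joined = torch.mm(torch.cat([x1, x2], 0), W)
        #   y1, y2 = torch.split(joined, [x1.shape[0], x2.shape[0]])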

        @config.patch(pre_grad_custom_pass=merge_mm_shared_rhs)
        def inner_test():
            @torch.compile
            def f(W, nested_seqs):
                outs = [torch.mm(s, W) for s in nested_seqs]
                return outs

            W = torch.randn(16, 16, dtype=torch.bfloat16)
            nested_seqs = [
                torch.randn(l, 16, dtype=torch.bfloat16) for l in [4, 8, 5, 3]
            ]
            f(W, nested_seqs)

            assert saved_graph[0] is not None
            matmuls = [n for n in saved_graph[0].nodes if n.target == torch.mm]
            # after the pass, the four matmuls have been merged into one
            assert len(matmuls) == 1

        inner_test()


if __name__ == "__main__":
    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
        run_tests()