Revert "[inductor] Preserve metadata across replace_by_example and register_replacement patterns (#138089)"

This reverts commit fb44658415e50b5be6a187ff3f14243c0fdf3daf.

Reverted https://github.com/pytorch/pytorch/pull/138089 on behalf of https://github.com/huydhn due to: Sorry for reverting your PR, but the new test_original_aten_preserved_pad_mm test runs OOM in trunk (fb44658415) ([comment](https://github.com/pytorch/pytorch/pull/138089#issuecomment-2424297269))
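For context on the OOM: the test in question (removed in the diff below) multiplies two float16 matrices of shape 2**14 x (2**16 - 1) and (2**16 - 1) x 2**14 under max-autotune with shape padding enabled. A rough back-of-the-envelope estimate, not an exact accounting of what the compiled run allocates, already puts the working set in the multi-GiB range once padded copies of the inputs are materialized:

# Rough sizing of the tensors used by the removed test_original_aten_preserved_pad_mm
# (shapes taken from the diff below). This is only an estimate of why the test can
# OOM on CI GPUs; the real peak also depends on autotuning and allocator behavior.
M, K, N = 2**14, 2**16 - 1, 2**14
fp16 = 2  # bytes per element

inputs = (M * K + K * N) * fp16  # the two operands of x @ y
output = M * N * fp16            # the mm result
# shape_padding=True likely pads the misaligned K dim (65535 -> 65536),
# materializing padded copies of both operands.
padded = (M * 2**16 + 2**16 * N) * fp16

print(f"inputs ~{inputs / 2**30:.1f} GiB")   # ~4.0 GiB
print(f"output ~{output / 2**30:.1f} GiB")   # ~0.5 GiB
print(f"padded ~{padded / 2**30:.1f} GiB")   # ~4.0 GiB
print(f"total  ~{(inputs + output + padded) / 2**30:.1f} GiB")  # ~8.5 GiB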
PyTorch MergeBot
2024-10-19 23:55:01 +00:00
parent fcedf93d1e
commit 47e80abc7a
4 changed files with 3 additions and 84 deletions


@@ -279,7 +279,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             for out_code in [code, code2]:
                 FileCheck().check("def call").check_count(
                     "empty_strided_cuda", 1, exactly=True
-                ).check("triton_tem_fused_addmm_relu_0.run").check_count(
+                ).check("triton_tem_fused_relu_0.run").check_count(
                     "del", 3, exactly=True
                 ).check(
                     "return"


@@ -4,7 +4,6 @@ import unittest
 import torch
 import torch._inductor.config as inductor_config
 from torch._dynamo.testing import rand_strided
-from torch._dynamo.utils import counters
 from torch._inductor.fx_passes.pad_mm import (
     get_alignment_size,
     get_pad_cache,
@@ -490,36 +489,6 @@ class PadMMTest(TestCase):
         assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
 
-    @fresh_inductor_cache()
-    @inductor_config.patch(
-        {
-            "triton.unique_kernel_names": "original_aten",
-            "max_autotune_gemm_backends": "TRITON",
-            "shape_padding": True,
-        }
-    )
-    def test_original_aten_preserved_pad_mm(self):
-        def fn(x, y):
-            return x @ y
-
-        args = [
-            torch.randn(2**14, 2**16 - 1, device="cuda", dtype=torch.float16),
-            torch.randn(2**16 - 1, 2**14, device="cuda", dtype=torch.float16),
-        ]
-
-        counters.clear()
-        with unittest.mock.patch(
-            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
-        ):
-            opt_fn = torch.compile(fn, mode="max-autotune")
-            ret, code = run_and_get_code(opt_fn, *args)
-            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
-
-        # The mm kernel should use a template (because we set max_autotune_gemm_backends = TRITON).
-        # Its name should contain `mm` because `mm` was the original aten op where the mm came from.
-        FileCheck().check("def triton_tem_fused_mm").run(code[0])
-
 
 if __name__ == "__main__":
     if HAS_CUDA:


@@ -1234,36 +1234,6 @@ class TestPatternMatcher(TestCase):
         # of search_fn).
         self.assertTrue(pattern.pattern_eq(search_fn_pattern))
 
-    @inductor_config.patch(
-        {
-            "triton.unique_kernel_names": "original_aten",
-            "fx_graph_remote_cache": False,
-            "max_autotune_gemm_backends": "TRITON",
-        }
-    )
-    def test_original_aten_preserved_split_addmm(self):
-        # addmm -> elementwise should be decomposed into mm -> add -> elementwise
-        def fn(x, y, z):
-            return torch.addmm(z, x, y).sin()
-
-        args = [
-            torch.randn(16, 24, device=GPU_TYPE),
-            torch.randn(24, 32, device=GPU_TYPE),
-            torch.randn(16, 32, device=GPU_TYPE),
-        ]
-
-        counters.clear()
-        opt_fn = torch.compile(fn, mode="max-autotune")
-        ret, code = run_and_get_code(opt_fn, *args)
-        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
-
-        # The mm kernel should use a template (because we set max_autotune_gemm_backends = TRITON).
-        # Its name should contain `addmm` because `addmm` was the original aten op where the mm came from.
-        FileCheck().check_not("extern_kernels.addmm(").check(
-            "def triton_tem_fused_addmm"
-        ).run(code[0])
-
     @inductor_config.patch(fx_graph_remote_cache=False)
     def test_match_equivalent_function_invocations1(self):
         counter = 0
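
Both deleted tests follow the same recipe: compile a small GEMM with max-autotune, restrict GEMM backends to Triton templates, set triton.unique_kernel_names = "original_aten", and then inspect the names of the generated template kernels. A hedged sketch of that recipe is below; it assumes a CUDA GPU with Triton available, and after this revert the fused kernel's name is not expected to carry the addmm part, which is exactly what the deleted assertion checked for.

# Sketch only: compile a tiny addmm + sin and print the Triton kernel names that
# appear in the generated wrapper code. Assumes CUDA + Triton; output varies by build.
import torch
import torch._inductor.config as inductor_config
from torch._inductor.utils import run_and_get_code


def fn(x, y, z):
    return torch.addmm(z, x, y).sin()


with inductor_config.patch(
    {
        "triton.unique_kernel_names": "original_aten",
        "max_autotune_gemm_backends": "TRITON",
    }
):
    args = [
        torch.randn(16, 24, device="cuda"),
        torch.randn(24, 32, device="cuda"),
        torch.randn(16, 32, device="cuda"),
    ]
    _, code = run_and_get_code(torch.compile(fn, mode="max-autotune"), *args)
    for line in code[0].splitlines():
        if "def triton_" in line:  # template kernels show up as triton_tem_fused_*
            print(line)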


@@ -139,15 +139,6 @@ class Multiple:
 MULTIPLE = Multiple()
 
-
-def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
-    # transfer metadata after pattern matching occurs.
-    # skip "val" and "tensor_meta" because this info is too specific; it's unlikely
-    # to remain accurate after pattern matching has occurred.
-    new_meta.update(
-        (k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS
-    )
-
 
 class Match:
     """
     Represents a successfully matched pattern.
@@ -166,7 +157,7 @@ class Match:
     nodes: List[torch.fx.Node]
     targets: Dict[_TargetExpr, torch.fx.node.Target]
     ctx: MatchContext
-    replacement_graph: Optional[torch.fx.GraphModule]
+    replacement_graph: Optional[torch.fx.Graph]
 
     def __init__(
         self,
@@ -262,10 +253,6 @@ class Match:
             replacement = trace_fn(
                 replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
             )
-            if len(self.nodes) == 1:
-                for n in replacement.graph.nodes:
-                    _transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta)
-
             ReplacementPatternEntry.replace_with_graph(
                 self,
                 self.ctx.graph,
@@ -1062,7 +1049,6 @@ class ReplacementPatternEntry(PatternEntry):
                     target = node.target
                     args, kwargs = self.fetch_args_kwargs_from_env(node)
                     result = graph.call_function(target, args, kwargs)  # type: ignore[arg-type]
-                    _transfer_meta(new_meta=result.meta, old_meta=node.meta)
                     if "val" in node.meta and "val" not in result.meta:
                         result.meta["val"] = node.meta["val"]
                         if isinstance(node.meta["val"], torch.Tensor):
@@ -1344,13 +1330,7 @@ def register_replacement(
         if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
             # trace the pattern using the shapes from the user program
-            match.replacement_graph = trace_fn(replace_fn, args)
-            if len(match.nodes) == 1:
-                for n in match.replacement_graph.graph.nodes:
-                    _transfer_meta(
-                        new_meta=n.meta,
-                        old_meta=match.nodes[0].meta,
-                    )
+            match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
             return True
         return False
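
For reference, the behavior being reverted: without the hand-off above, the FX nodes created during a pattern replacement lose metadata such as node.meta["original_aten"], which is what the "original_aten" kernel-naming mode keys off. A minimal, self-contained sketch of what the removed _transfer_meta helper did, copying only the fields listed in torch.fx.proxy._COPY_META_FIELDS and deliberately skipping "val" and "tensor_meta"; the node setup below is fabricated for the example:

# Illustrative reconstruction of the removed helper plus a tiny demonstration on
# hand-built FX nodes.
from typing import Any, Dict

import torch


def transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
    # Copy only whitelisted fields; "val"/"tensor_meta" are shape-specific and are
    # unlikely to stay accurate for the replacement nodes.
    new_meta.update(
        (k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS
    )


graph = torch.fx.Graph()
x = graph.placeholder("x")

old_node = graph.call_function(torch.ops.aten.addmm.default, (x, x, x))
old_node.meta["original_aten"] = torch.ops.aten.addmm.default
old_node.meta["val"] = "shape-specific info that should not be copied"

new_node = graph.call_function(torch.ops.aten.mm.default, (x, x))
transfer_meta(new_node.meta, old_node.meta)
print(new_node.meta)  # keeps "original_aten", drops "val"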