Revert "[inductor] Preserve metadata across replace_by_example and register_replacement patterns (#138089)"

This reverts commit fb44658415e50b5be6a187ff3f14243c0fdf3daf.

Reverted https://github.com/pytorch/pytorch/pull/138089 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but the new test_original_aten_preserved_pad_mm test runs OOM in trunk fb44658415 ([comment](https://github.com/pytorch/pytorch/pull/138089#issuecomment-2424297269))
This commit is contained in:
PyTorch MergeBot
2024-10-19 23:55:01 +00:00
parent fcedf93d1e
commit 47e80abc7a
4 changed files with 3 additions and 84 deletions

View File

@ -139,15 +139,6 @@ class Multiple:
MULTIPLE = Multiple()
def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
# transfer metadata after pattern matching occurs.
# skip "val" and "tensor_meta" because this info is too specific; it's unlikely
# to remain accurate after pattern matching has occurred.
new_meta.update(
(k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS
)
class Match:
"""
Represents a successfully matched pattern.
@ -166,7 +157,7 @@ class Match:
nodes: List[torch.fx.Node]
targets: Dict[_TargetExpr, torch.fx.node.Target]
ctx: MatchContext
replacement_graph: Optional[torch.fx.GraphModule]
replacement_graph: Optional[torch.fx.Graph]
def __init__(
self,
@ -262,10 +253,6 @@ class Match:
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
)
if len(self.nodes) == 1:
for n in replacement.graph.nodes:
_transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta)
ReplacementPatternEntry.replace_with_graph(
self,
self.ctx.graph,
@ -1062,7 +1049,6 @@ class ReplacementPatternEntry(PatternEntry):
target = node.target
args, kwargs = self.fetch_args_kwargs_from_env(node)
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
_transfer_meta(new_meta=result.meta, old_meta=node.meta)
if "val" in node.meta and "val" not in result.meta:
result.meta["val"] = node.meta["val"]
if isinstance(node.meta["val"], torch.Tensor):
@ -1344,13 +1330,7 @@ def register_replacement(
if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
# trace the pattern using the shapes from the user program
match.replacement_graph = trace_fn(replace_fn, args)
if len(match.nodes) == 1:
for n in match.replacement_graph.graph.nodes:
_transfer_meta(
new_meta=n.meta,
old_meta=match.nodes[0].meta,
)
match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment]
return True
return False