mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[inductor] Preserve metadata across replace_by_example and register_replacement patterns (#138089)"
This reverts commit fb44658415e50b5be6a187ff3f14243c0fdf3daf.
Reverted https://github.com/pytorch/pytorch/pull/138089 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but the new test_original_aten_preserved_pad_mm test runs OOM in trunk fb44658415 ([comment](https://github.com/pytorch/pytorch/pull/138089#issuecomment-2424297269))
This commit is contained in:
@ -139,15 +139,6 @@ class Multiple:
|
||||
MULTIPLE = Multiple()
|
||||
|
||||
|
||||
def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
|
||||
# transfer metadata after pattern matching occurs.
|
||||
# skip "val" and "tensor_meta" because this info is too specific; it's unlikely
|
||||
# to remain accurate after pattern matching has occurred.
|
||||
new_meta.update(
|
||||
(k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS
|
||||
)
|
||||
|
||||
|
||||
class Match:
|
||||
"""
|
||||
Represents a successfully matched pattern.
|
||||
@ -166,7 +157,7 @@ class Match:
|
||||
nodes: List[torch.fx.Node]
|
||||
targets: Dict[_TargetExpr, torch.fx.node.Target]
|
||||
ctx: MatchContext
|
||||
replacement_graph: Optional[torch.fx.GraphModule]
|
||||
replacement_graph: Optional[torch.fx.Graph]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -262,10 +253,6 @@ class Match:
|
||||
replacement = trace_fn(
|
||||
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
|
||||
)
|
||||
if len(self.nodes) == 1:
|
||||
for n in replacement.graph.nodes:
|
||||
_transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta)
|
||||
|
||||
ReplacementPatternEntry.replace_with_graph(
|
||||
self,
|
||||
self.ctx.graph,
|
||||
@ -1062,7 +1049,6 @@ class ReplacementPatternEntry(PatternEntry):
|
||||
target = node.target
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
||||
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
|
||||
_transfer_meta(new_meta=result.meta, old_meta=node.meta)
|
||||
if "val" in node.meta and "val" not in result.meta:
|
||||
result.meta["val"] = node.meta["val"]
|
||||
if isinstance(node.meta["val"], torch.Tensor):
|
||||
@ -1344,13 +1330,7 @@ def register_replacement(
|
||||
|
||||
if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
|
||||
# trace the pattern using the shapes from the user program
|
||||
match.replacement_graph = trace_fn(replace_fn, args)
|
||||
if len(match.nodes) == 1:
|
||||
for n in match.replacement_graph.graph.nodes:
|
||||
_transfer_meta(
|
||||
new_meta=n.meta,
|
||||
old_meta=match.nodes[0].meta,
|
||||
)
|
||||
match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user