Revert "Fix DCE eliminating in-place operations by improving Node.is_impure() (#162267)"

This reverts commit b9a7d0e13b4a34be83c778734dbad437c7c5117b.

Reverted https://github.com/pytorch/pytorch/pull/162267 on behalf of https://github.com/malfet due to Not sure how it happened, but looks like it broke everything, see c2388201fc/1 ([comment](https://github.com/pytorch/pytorch/pull/162267#issuecomment-3275164109))
This commit is contained in:
PyTorch MergeBot
2025-09-10 14:12:22 +00:00
parent c2388201fc
commit fc1b09a52a
2 changed files with 2 additions and 114 deletions

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: fx"]
import copy
import inspect
import unittest
from typing import Optional
@ -39,39 +38,12 @@ class TestDCE(TestCase):
count += 1
return count
@torch.compiler.disable
def _trace_with_dynamo(self, m: torch.nn.Module) -> torch.fx.GraphModule:
"""Dynamo will keep in-place operations, whereas torch.fx.Tracer will remove them."""
graph_module: torch.fx.GraphModule | None = None
def _backend(gm: torch.fx.GraphModule, _):
nonlocal graph_module
graph_module = gm
return gm
inputs = [
torch.tensor([1.5])
for _ in range(len(inspect.signature(m.forward).parameters))
]
torch.compile(
m,
backend=_backend,
fullgraph=True,
)(*inputs)
assert graph_module is not None
# TorchDynamo returns a graph with flattened output; unflatten here for the test
graph_module.graph.output_node().args = graph_module.graph.output_node().args[0]
graph_module.recompile()
return graph_module
def _run_dce_and_test(
self,
m: torch.nn.Module,
expect_dce_changes: bool,
modules_to_be_leafs: Optional[set[type]] = None,
custom: bool = False,
use_dynamo_for_tracing: bool = False,
):
class TestTracer(torch.fx.Tracer):
def is_leaf_module(self, m, qualname):
@ -79,12 +51,7 @@ class TestDCE(TestCase):
return True
return super().trace(m, qualname)
if use_dynamo_for_tracing:
traced = self._trace_with_dynamo(m)
else:
traced: torch.fx.GraphModule = torch.fx.GraphModule(
m, TestTracer().trace(m)
)
traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
print(str(traced.graph))
# Verify there are nodes without users (if expected).
@ -113,7 +80,7 @@ class TestDCE(TestCase):
traced.recompile()
# Make sure we run and get the same results before/after DCE.
inputs = [torch.tensor([1.5]) for _ in range(new_num_phs)]
inputs = [torch.tensor([1.5])] * new_num_phs
inputs_copy = copy.deepcopy(inputs)
self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))
@ -215,57 +182,6 @@ class TestDCE(TestCase):
TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
)
def test_keep_inplace_with_side_effects(self):
"""
Test that DCE doesn't remove an inplace operation.
"""
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(2)
y = 2 * x
x.add_(y)
return y
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
def test_keep_inplace_python_operator_with_side_effects(self):
"""
Test that DCE doesn't remove an inplace operation.
"""
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x += y
x //= y
x %= y
x *= y
x -= y
x /= y
x @= y
x = x.reshape_as(y)
concat_a = [x]
concat_b = [y]
concat_a += concat_b
a = x.to(dtype=torch.long)
b = y.to(dtype=torch.long)
a //= b
a <<= b
a %= b
a |= b
a **= b
a >>= b
a ^= b
return x + y + concat_a[0] + a + b
self._run_dce_and_test(
TestModule(), expect_dce_changes=False, use_dynamo_for_tracing=True
)
def test_keep_torch_assert(self):
"""
Test that DCE doesn't remove torch._assert since it has side effects.

View File

@ -84,23 +84,6 @@ _side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [
torch.amp._exit_autocast,
]
_side_effect_inplace: set[Callable[..., Any]] = {
operator.iadd,
operator.iand,
operator.iconcat,
operator.ifloordiv,
operator.ilshift,
operator.imod,
operator.imul,
operator.imatmul,
operator.ior,
operator.ipow,
operator.irshift,
operator.isub,
operator.itruediv,
operator.ixor,
}
# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
# or add logic to correctly mark all inplace ops as side effectful.
_side_effectful_functions: set[Callable[..., Any]] = {
@ -116,7 +99,6 @@ _side_effectful_functions: set[Callable[..., Any]] = {
_ops.profiler._record_function_exit,
_ops.inductor.accumulate_grad_.default,
operator.setitem,
*_side_effect_inplace,
*_side_effectful_need_to_be_preserved_pre_dispatch,
}
@ -831,16 +813,6 @@ class Node(_NodeBase):
)
return getattr(target_mod, "_is_impure", False)
if self.op == "call_method":
target_name = (
self.target
if isinstance(self.target, str)
else torch.typename(self.target)
)
# Check for functions with names ending in an underscore (e.g., 'add_') that are inplace in torch
if target_name.endswith("_"):
return True
return False
@compatibility(is_backward_compatible=False)