Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Revert "Fix DCE eliminating in-place operations by improving Node.is_impure() (#162267)"
This reverts commit b9a7d0e13b4a34be83c778734dbad437c7c5117b.
Reverted https://github.com/pytorch/pytorch/pull/162267 on behalf of https://github.com/malfet due to Not sure how it happened, but looks like it broke everything, see c2388201fc/1
([comment](https://github.com/pytorch/pytorch/pull/162267#issuecomment-3275164109))
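
For context, a minimal sketch (not part of this commit, names illustrative) of the behavior the reverted PR was addressing: torch.fx dead-code elimination can drop an in-place call_method node whose result has no users, silently changing the traced module's semantics.

    import torch
    import torch.fx

    class AddsInPlace(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x.add_(1)      # in-place; the resulting call_method node has no users
            return x * 2

    gm = torch.fx.symbolic_trace(AddsInPlace())
    gm.graph.eliminate_dead_code()  # without an in-place-aware is_impure(), add_ may be dropped
    gm.recompile()
    print(gm.code)  # inspect whether the add_ call survived DCE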
@@ -1,6 +1,5 @@
# Owner(s): ["module: fx"]

import copy
import inspect
import unittest
from typing import Optional

@@ -39,39 +38,12 @@ class TestDCE(TestCase):
            count += 1
        return count

    @torch.compiler.disable
    def _trace_with_dynamo(self, m: torch.nn.Module) -> torch.fx.GraphModule:
        """Dynamo will keep in-place operations, whereas torch.fx.Tracer will remove them."""
        graph_module: torch.fx.GraphModule | None = None

        def _backend(gm: torch.fx.GraphModule, _):
            nonlocal graph_module
            graph_module = gm
            return gm

        inputs = [
            torch.tensor([1.5])
            for _ in range(len(inspect.signature(m.forward).parameters))
        ]
        torch.compile(
            m,
            backend=_backend,
            fullgraph=True,
        )(*inputs)
        assert graph_module is not None

        # TorchDynamo returns a graph with flattened output; unflatten here for the test
        graph_module.graph.output_node().args = graph_module.graph.output_node().args[0]
        graph_module.recompile()
        return graph_module

    def _run_dce_and_test(
        self,
        m: torch.nn.Module,
        expect_dce_changes: bool,
        modules_to_be_leafs: Optional[set[type]] = None,
        custom: bool = False,
        use_dynamo_for_tracing: bool = False,
    ):
        class TestTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, qualname):
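
As an aside, a standalone sketch (illustrative names, not part of this commit) of the capture-via-backend pattern that _trace_with_dynamo uses above: torch.compile hands the Dynamo-traced torch.fx.GraphModule to the backend, which can stash a reference before returning it unchanged.

    import torch

    captured = []

    def _capture_backend(gm, example_inputs):
        # torch.compile passes the traced torch.fx.GraphModule to the backend;
        # keep it and return it unchanged so the compiled call still runs.
        captured.append(gm)
        return gm

    @torch.compile(backend=_capture_backend, fullgraph=True)
    def fn(x):
        x.add_(1)
        return x * 2

    fn(torch.tensor([1.5]))
    print(captured[0].graph)  # the in-place add_ should appear in the captured graph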
@@ -79,12 +51,7 @@ class TestDCE(TestCase):
                    return True
                return super().trace(m, qualname)

        if use_dynamo_for_tracing:
            traced = self._trace_with_dynamo(m)
        else:
            traced: torch.fx.GraphModule = torch.fx.GraphModule(
                m, TestTracer().trace(m)
            )
        traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
        print(str(traced.graph))

        # Verify there are nodes without users (if expected).
@@ -113,7 +80,7 @@ class TestDCE(TestCase):

        traced.recompile()
        # Make sure we run and get the same results before/after DCE.
        inputs = [torch.tensor([1.5]) for _ in range(new_num_phs)]
        inputs = [torch.tensor([1.5])] * new_num_phs
        inputs_copy = copy.deepcopy(inputs)
        self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))
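
The two adjacent inputs lines in the hunk above are the old and new versions of the same statement; the difference presumably matters once in-place ops are involved, because list multiplication repeats one tensor object while a comprehension builds independent tensors. A small illustration (not from this commit):

    import torch

    shared = [torch.tensor([1.5])] * 2                      # two references to one tensor
    independent = [torch.tensor([1.5]) for _ in range(2)]   # two distinct tensors

    shared[0].add_(1)
    independent[0].add_(1)

    assert shared[1].item() == 2.5       # aliased entry sees the in-place update
    assert independent[1].item() == 1.5  # independent entry is untouched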
@@ -215,57 +182,6 @@ class TestDCE(TestCase):
            TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
        )

    def test_keep_inplace_with_side_effects(self):
        """
        Test that DCE doesn't remove an inplace operation.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x.add_(2)
                y = 2 * x
                x.add_(y)
                return y

        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_keep_inplace_python_operator_with_side_effects(self):
        """
        Test that DCE doesn't remove an inplace operation.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x += y
                x //= y
                x %= y
                x *= y
                x -= y
                x /= y
                x @= y

                x = x.reshape_as(y)
                concat_a = [x]
                concat_b = [y]
                concat_a += concat_b

                a = x.to(dtype=torch.long)
                b = y.to(dtype=torch.long)

                a //= b
                a <<= b
                a %= b
                a |= b
                a **= b
                a >>= b
                a ^= b

                return x + y + concat_a[0] + a + b

        self._run_dce_and_test(
            TestModule(), expect_dce_changes=False, use_dynamo_for_tracing=True
        )

    def test_keep_torch_assert(self):
        """
        Test that DCE doesn't remove torch._assert since it has side effects.
@@ -84,23 +84,6 @@ _side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [
    torch.amp._exit_autocast,
]

_side_effect_inplace: set[Callable[..., Any]] = {
    operator.iadd,
    operator.iand,
    operator.iconcat,
    operator.ifloordiv,
    operator.ilshift,
    operator.imod,
    operator.imul,
    operator.imatmul,
    operator.ior,
    operator.ipow,
    operator.irshift,
    operator.isub,
    operator.itruediv,
    operator.ixor,
}

# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
# or add logic to correctly mark all inplace ops as side effectful.
_side_effectful_functions: set[Callable[..., Any]] = {
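
The operator.i* entries above correspond to Python's augmented assignment protocol; for tensors these calls mutate their first argument in place, which is presumably why the reverted change treated them as side-effectful. A quick illustration (not from this commit):

    import operator
    import torch

    x = torch.tensor([1.0])
    alias = x
    result = operator.iadd(x, torch.tensor([2.0]))  # what `x += y` lowers to

    assert result is alias      # Tensor.__iadd__ mutates and returns the same object
    assert alias.item() == 3.0  # the mutation is visible through the alias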
@@ -116,7 +99,6 @@ _side_effectful_functions: set[Callable[..., Any]] = {
    _ops.profiler._record_function_exit,
    _ops.inductor.accumulate_grad_.default,
    operator.setitem,
    *_side_effect_inplace,
    *_side_effectful_need_to_be_preserved_pre_dispatch,
}
@@ -831,16 +813,6 @@ class Node(_NodeBase):
            )
            return getattr(target_mod, "_is_impure", False)

        if self.op == "call_method":
            target_name = (
                self.target
                if isinstance(self.target, str)
                else torch.typename(self.target)
            )
            # Check for functions with names ending in an underscore (e.g., 'add_') that are inplace in torch
            if target_name.endswith("_"):
                return True

        return False

    @compatibility(is_backward_compatible=False)
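
The call_method branch above relies on PyTorch's naming convention that tensor methods ending in an underscore mutate their receiver. A short illustration of that convention (not part of this commit):

    import torch

    x = torch.zeros(2)
    y = x.add(1)   # out-of-place: returns a new tensor, x is unchanged
    x.add_(1)      # in-place: mutates x, so the node has an effect even if its result is unused

    assert torch.equal(x, torch.ones(2))
    assert torch.equal(y, torch.ones(2))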