mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit b9a7d0e13b4a34be83c778734dbad437c7c5117b.
Reverted https://github.com/pytorch/pytorch/pull/162267 on behalf of https://github.com/malfet due to Not sure how it happened, but looks like it broke everything, see c2388201fc/1
([comment](https://github.com/pytorch/pytorch/pull/162267#issuecomment-3275164109))
388 lines
14 KiB
Python
388 lines
14 KiB
Python
# Owner(s): ["module: fx"]
|
|
import copy
|
|
import unittest
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.fx
|
|
from torch.testing._internal.common_utils import (
|
|
IS_MACOS,
|
|
raise_on_run_directly,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
class TestDCE(TestCase):
|
|
def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
|
|
if node.is_impure():
|
|
return True
|
|
# a custom function that defines add operators as impure.
|
|
if node.target == torch.ops.aten.add:
|
|
return True
|
|
return False
|
|
|
|
def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
|
|
for node in m.graph.nodes:
|
|
if (not custom and node.is_impure()) or (
|
|
custom and self._custom_is_impure_node(node)
|
|
):
|
|
continue
|
|
if len(node.users) == 0:
|
|
return True
|
|
return False
|
|
|
|
def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
|
|
count = 0
|
|
for node in m.graph.nodes:
|
|
if node.op == "placeholder":
|
|
count += 1
|
|
return count
|
|
|
|
def _run_dce_and_test(
|
|
self,
|
|
m: torch.nn.Module,
|
|
expect_dce_changes: bool,
|
|
modules_to_be_leafs: Optional[set[type]] = None,
|
|
custom: bool = False,
|
|
):
|
|
class TestTracer(torch.fx.Tracer):
|
|
def is_leaf_module(self, m, qualname):
|
|
if modules_to_be_leafs and type(m) in modules_to_be_leafs:
|
|
return True
|
|
return super().trace(m, qualname)
|
|
|
|
traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
|
|
print(str(traced.graph))
|
|
|
|
# Verify there are nodes without users (if expected).
|
|
has_nodes_without_users = self._has_nodes_without_users(traced, custom=custom)
|
|
if expect_dce_changes:
|
|
self.assertTrue(has_nodes_without_users)
|
|
else:
|
|
self.assertFalse(has_nodes_without_users)
|
|
|
|
# Get the original number of placeholders to verify it doesn't change
|
|
# during DCE.
|
|
orig_num_phs = self._get_num_placeholders(traced)
|
|
if custom:
|
|
changed = traced.graph.eliminate_dead_code(
|
|
is_impure_node=self._custom_is_impure_node
|
|
)
|
|
else:
|
|
changed = traced.graph.eliminate_dead_code()
|
|
|
|
self.assertTrue(changed if expect_dce_changes else not changed)
|
|
|
|
# Verify there are no nodes without users after DCE is run.
|
|
self.assertFalse(self._has_nodes_without_users(traced, custom=custom))
|
|
new_num_phs = self._get_num_placeholders(traced)
|
|
self.assertEqual(orig_num_phs, new_num_phs)
|
|
|
|
traced.recompile()
|
|
# Make sure we run and get the same results before/after DCE.
|
|
inputs = [torch.tensor([1.5])] * new_num_phs
|
|
inputs_copy = copy.deepcopy(inputs)
|
|
self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))
|
|
|
|
def test_simple(self):
|
|
"""
|
|
Tests that a single node in the graph is DCE'd correctly.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
|
|
|
|
def forward(self, x):
|
|
a = x + 1 # noqa: F841
|
|
return x + self.attr_1
|
|
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
|
|
|
|
def test_dead_chain(self):
|
|
"""
|
|
Tests that a chain of two nodes in the graph are DCE'd correctly.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
|
|
|
|
def forward(self, x):
|
|
a = x + 1
|
|
b = a * 7 # noqa: F841
|
|
return x + self.attr_1
|
|
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
|
|
|
|
def test_dead_getattr(self):
|
|
"""
|
|
Tests that a getatrr in the graph is DCE'd correctly.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
|
|
|
|
def forward(self, x):
|
|
a = x + 1
|
|
b = a * self.attr_1 # noqa: F841
|
|
return x + 11
|
|
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
|
|
|
|
def test_dead_placeholder(self):
|
|
"""
|
|
Tests that a placeholder in the graph is not DCE'd, as that would change
|
|
the function signature.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x + 7
|
|
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
|
|
|
def test_dead_placeholder_with_user(self):
|
|
"""
|
|
Tests that a placeholder in the graph is not DCE'd, as that would change
|
|
the function signature. Also verifies that a dead node that uses the
|
|
placeholder is DCE'd.
|
|
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
a = y + 2 # noqa: F841
|
|
return x + 7
|
|
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
|
|
|
|
def test_keep_module_with_side_effects(self):
|
|
"""
|
|
Test that DCE doesn't remove a module if it's specified as having side effects.
|
|
"""
|
|
|
|
class ReLUImpure(torch.nn.ReLU):
|
|
_is_impure = True
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.relu = ReLUImpure()
|
|
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
r = self.relu(a) # noqa: F841
|
|
return a * 2
|
|
|
|
self._run_dce_and_test(
|
|
TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
|
|
)
|
|
|
|
def test_keep_torch_assert(self):
|
|
"""
|
|
Test that DCE doesn't remove torch._assert since it has side effects.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
torch._assert(torch.equal(a, a), "a must equal a")
|
|
return a * 2
|
|
|
|
# Note: Don't need to specify torch._assert as having side effects
|
|
# because it's known to.
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
|
|
|
def test_keep_setitem(self):
|
|
"""
|
|
Fix issue: https://github.com/pytorch/pytorch/issues/145697
|
|
Test that DCE doesn't remove operator.setitem since it has side effects.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
a[0, 0, 0, 0] *= 2.0
|
|
return a * 2
|
|
|
|
def dce_backend(gm, inputs, **kwargs):
|
|
import torch._inductor.constant_folding
|
|
|
|
torch._inductor.constant_folding.constant_fold(gm)
|
|
return gm
|
|
|
|
x = torch.randn(1, 3, 224, 224)
|
|
dce_x = x.detach().clone()
|
|
model = TestModule().eval()
|
|
dce_mod = torch.compile(copy.deepcopy(model), backend=dce_backend)
|
|
|
|
with torch.inference_mode():
|
|
eager_out = model(x)
|
|
out = dce_mod(dce_x)
|
|
self.assertEqual(eager_out, out, atol=1e-5, rtol=1e-5)
|
|
|
|
def test_impure_nodes_args(self):
|
|
"""
|
|
Test that DCE doesn't remove call_function nodes with side effects.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
torch._ops.ops.aten.add_.Tensor(a, 1)
|
|
return a * 2
|
|
|
|
# %add_ node should not be removed because it has side effects.
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
|
|
|
def test_impure_random(self):
|
|
"""
|
|
Test that DCE doesn't remove call_function for torch.rand and other random functions.
|
|
Tests both FX tracing and AOT compilation (issue #151524).
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
x = torch.rand([10]) # noqa: F841
|
|
return a * 2
|
|
|
|
# Test FX tracing + DCE
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
|
|
|
# Test comprehensive random functions in AOT compilation
|
|
class ComprehensiveRandomModule(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
# Test various random functions that should be preserved
|
|
a = torch.rand(1) # noqa: F841
|
|
b = torch.randn(1) # noqa: F841
|
|
c = torch.randint(0, 10, (1,)) # noqa: F841
|
|
d = torch.randperm(5) # noqa: F841
|
|
e = torch.normal(0, 1, (1,)) # noqa: F841
|
|
f = torch.poisson(torch.tensor([1.0])) # noqa: F841
|
|
g = torch.rand(1) # Used
|
|
|
|
# Test that random operations with explicit generators are also preserved
|
|
gen = torch.Generator().manual_seed(123)
|
|
h = torch.rand(1, generator=gen) # noqa: F841
|
|
i = torch.randn(1, generator=gen) # noqa: F841
|
|
j = torch.rand(1, generator=gen) # Used
|
|
return x + g + j
|
|
|
|
def aot_backend(gm, example_inputs):
|
|
def count_random_ops():
|
|
return len(
|
|
[
|
|
n
|
|
for n in gm.graph.nodes
|
|
if n.op == "call_function"
|
|
and any(
|
|
fn in str(n.target)
|
|
for fn in [
|
|
"rand",
|
|
"randn",
|
|
"randint",
|
|
"randperm",
|
|
"normal",
|
|
"poisson",
|
|
]
|
|
)
|
|
]
|
|
)
|
|
|
|
rand_count = count_random_ops()
|
|
gm.graph.eliminate_dead_code()
|
|
self.assertEqual(
|
|
count_random_ops(), rand_count, "Random ops should be preserved"
|
|
)
|
|
return gm.forward
|
|
|
|
model = ComprehensiveRandomModule()
|
|
torch.manual_seed(42)
|
|
eager_result = model(torch.tensor([1.0]))
|
|
torch.manual_seed(42)
|
|
compiled_result = torch.compile(model, backend=aot_backend)(torch.tensor([1.0]))
|
|
self.assertEqual(eager_result, compiled_result)
|
|
|
|
def test_impure_kwargs(self):
|
|
"""
|
|
Test that DCE doesn't remove call_function nodes with side effects on kwargs.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
b = a + 1
|
|
torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
|
|
return a
|
|
|
|
# %add_out node should not be removed because it has side effects.
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
|
|
|
def test_impure_custom(self):
|
|
"""
|
|
Test that DCE doesn't remove nodes marked as impure by a custom function.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
b = a + 1
|
|
c = torch._ops.ops.aten.add(b, b) # noqa: F841
|
|
return a
|
|
|
|
# %add_out node should not be removed because it has side effects.
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
|
|
|
|
@unittest.skipIf(IS_MACOS, "Not working on macos")
|
|
def test_keep_collectives(self):
|
|
"""
|
|
Test that DCE doesn't remote collective ops even the results are not used.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(
|
|
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
|
|
) -> torch.Tensor:
|
|
d = torch.ops.aten.mul.Tensor(a, b)
|
|
e = torch.ops.aten.mul.Tensor(a, c)
|
|
future = torch.ops._c10d_functional.all_reduce.default(e, "sum", "0")
|
|
torch.ops._c10d_functional.wait_tensor.default(future)
|
|
return d
|
|
|
|
torch.distributed.init_process_group(
|
|
backend="fake",
|
|
world_size=2,
|
|
rank=0,
|
|
)
|
|
# collective nodes should not be removed because they have side effects.
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
|
|
torch.distributed.destroy_process_group()
|
|
|
|
@unittest.skipIf(IS_MACOS, "Not working on macos")
|
|
def test_keep_collectives_no_overload(self):
|
|
"""
|
|
Test that DCE doesn't remote collective ops (no overload version) even the results are not used.
|
|
"""
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def forward(
|
|
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
|
|
) -> torch.Tensor:
|
|
d = torch.ops.aten.mul(a, b)
|
|
e = torch.ops.aten.mul(a, c)
|
|
future = torch.ops._c10d_functional.all_reduce(e, "sum", "0")
|
|
torch.ops._c10d_functional.wait_tensor(future)
|
|
return d
|
|
|
|
torch.distributed.init_process_group(
|
|
backend="fake",
|
|
world_size=2,
|
|
rank=0,
|
|
)
|
|
# collective nodes should not be removed because they have side effects.
|
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
|
|
torch.distributed.destroy_process_group()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hands raise_on_run_directly the path
    # "test/test_fx.py".
    # NOTE(review): the helper is imported from common_utils. Judging by its
    # name, it refuses direct execution of this file and points at
    # test/test_fx.py as the runner for these tests -- confirm in
    # torch.testing._internal.common_utils.
    raise_on_run_directly("test/test_fx.py")
|