Compare commits

...

2 Commits

Author SHA1 Message Date
569fc2f868 Add check in Graph.eliminate_dead_code 2025-02-26 20:54:22 -08:00
89577b3e07 Remove unused rand call 2025-02-26 20:54:08 -08:00
3 changed files with 13 additions and 3 deletions

View File

@ -6497,6 +6497,9 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
torch.cuda.manual_seed_all(54321)
expected = f(torch.randn((2, 12, 16, 32, 32))).sum()
# https://github.com/pytorch/pytorch/issues/147171
torch._inductor.config.fallback_random = True
for backend in ["eager", "aot_eager"]:
torch.manual_seed(54321)
torch.cuda.manual_seed_all(54321)

View File

@ -1854,10 +1854,14 @@ class Graph:
# DCE below will not behave as expected.
self.lint()
impure_random = True
if torch._guards.TracingContext.try_get():
impure_random = torch._inductor.config.fallback_random
def has_side_effect(node):
if is_impure_node is not None:
return is_impure_node(node)
return node.is_impure()
return node.is_impure(impure_random)
# Reverse iterate so that when we remove a node, any nodes used as an
# input to that node have an updated user count that no longer reflects

View File

@ -758,11 +758,14 @@ class Node(_NodeBase):
return [n for n in to_process if n not in skipped]
@compatibility(is_backward_compatible=False)
def is_impure(self) -> bool:
def is_impure(self, impure_random: bool = True) -> bool:
"""
Returns whether this op is impure, i.e. if its op is a placeholder or
output, or if a call_function or call_module which is impure.
Args:
impure_random (bool): Whether to treat rand op as impure.
Returns:
bool: If the op is impure or not.
@ -778,7 +781,7 @@ class Node(_NodeBase):
if getattr(self.target, "_nondeterministic_seeded", False):
# impure since it mutates RNG state
return True
return impure_random
return self.target in _side_effectful_functions