make functorch CSE respect mutations as barriers (like fsdp.set_) (#132243)

Fixes https://github.com/pytorch/pytorch/issues/132200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132243
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/yf225
This commit is contained in:
Brian Hirsh
2024-08-02 11:59:38 -07:00
committed by PyTorch MergeBot
parent ee0ae11b34
commit 4db368a475
3 changed files with 55 additions and 3 deletions

View File

@ -1614,6 +1614,12 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
return node.kwargs.get("out") is not None
def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool:
assert "mutation_region_id" in a.meta
assert "mutation_region_id" in b.meta
return a.meta["mutation_region_id"] == b.meta["mutation_region_id"]
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
n = node
while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):