mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make functorch CSE respect mutations as barriers (like fsdp.set_) (#132243)
Fixes https://github.com/pytorch/pytorch/issues/132200 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132243 Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/yf225
This commit is contained in:
committed by
PyTorch MergeBot
parent
ee0ae11b34
commit
4db368a475
@ -1614,6 +1614,12 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
|
||||
return node.kwargs.get("out") is not None
|
||||
|
||||
|
||||
def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool:
|
||||
assert "mutation_region_id" in a.meta
|
||||
assert "mutation_region_id" in b.meta
|
||||
return a.meta["mutation_region_id"] == b.meta["mutation_region_id"]
|
||||
|
||||
|
||||
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
|
||||
n = node
|
||||
while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
|
||||
|
Reference in New Issue
Block a user