Refactor partitioner and clean it up (#126318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126318
Approved by: https://github.com/anijain2305
chilli
2024-05-16 14:44:55 -07:00
committed by PyTorch MergeBot
parent 5756b53dd8
commit f9a7033194
4 changed files with 668 additions and 605 deletions


@@ -25,7 +25,6 @@ from torch._functorch.fx_minifier import minifier
 from torch._functorch.partitioners import (
     default_partition,
     draw_graph,
-    draw_joint_graph,
     min_cut_rematerialization_partition,
 )
 from torch._functorch.python_key import pythonkey_decompose


@@ -4835,70 +4835,6 @@ class TestPartitioning(AOTTestCase):
         self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
         self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))
 
-    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
-    def test_min_cut_partitioner_recomputable_ops(self):
-        def f(x):
-            return x * x * x
-
-        recomputable_ops = []
-        partition_fn = partial(
-            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
-        )
-
-        fw_graph, bw_graph = get_fw_bw_graph(
-            f, [torch.randn(3, requires_grad=True)], partition_fn
-        )
-        # Expected forward graph:
-        # opcode         name       target           args                        kwargs
-        # -------------  ---------  ---------------  --------------------------  --------
-        # placeholder    primals_1  primals_1        ()                          {}
-        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
-        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
-        # output         output     output           ([mul_1, primals_1, mul],)  {}
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
-        # Expected backward graph:
-        # opcode         name        target           args                     kwargs
-        # -------------  ----------  ---------------  -----------------------  --------
-        # placeholder    primals_1   primals_1        ()                       {}
-        # placeholder    mul         mul              ()                       {}
-        # placeholder    tangents_1  tangents_1       ()                       {}
-        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
-        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
-        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
-        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
-        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
-        # output         output      output           ([add_1],)               {}
-        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
-
-        recomputable_ops = [torch.ops.aten.mul]
-        partition_fn = partial(
-            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
-        )
-        fw_graph, bw_graph = get_fw_bw_graph(
-            f, [torch.randn(3, requires_grad=True)], partition_fn
-        )
-        # Expected forward graph:
-        # opcode         name       target           args                    kwargs
-        # -------------  ---------  ---------------  ----------------------  --------
-        # placeholder    primals_1  primals_1        ()                      {}
-        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
-        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
-        # output         output     output           ([mul_1, primals_1],)   {}
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
-        # Expected backward graph:
-        # opcode         name        target           args                     kwargs
-        # -------------  ----------  ---------------  -----------------------  --------
-        # placeholder    primals_1   primals_1        ()                       {}
-        # placeholder    tangents_1  tangents_1       ()                       {}
-        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
-        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
-        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
-        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
-        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
-        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
-        # output         output      output           ([add_1],)               {}
-        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
-
     def test_contiguous(self):
         # The test simulates the condition where transpose followed by view
         # happens in the backward pass.


@@ -1,6 +1,8 @@
 # mypy: ignore-errors
 
+from typing import Callable
+
 import torch
 import torch.fx as fx
 from torch.utils import _pytree as pytree
@@ -9,7 +11,7 @@ from torch.utils._pytree import tree_flatten
 aten = torch.ops.aten
 
 
-def get_aten_target(node):
+def get_aten_target(node: fx.Node) -> Callable:
     if hasattr(node.target, "overloadpacket"):
         return node.target.overloadpacket
     return node.target
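
For context, a minimal self-contained sketch of what get_aten_target does on a traced FX graph (the traced function, example input, and use of make_fx below are illustrative, not taken from this PR): for a node whose target is an OpOverload such as aten.mul.Tensor, it returns the parent OpOverloadPacket (aten.mul); any other target is returned unchanged.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# The helper from the hunk above, reproduced so this sketch runs on its own.
def get_aten_target(node):
    if hasattr(node.target, "overloadpacket"):
        return node.target.overloadpacket
    return node.target

def f(x):
    return x * x + 1

# Trace f into an FX graph of aten ops (example input is illustrative).
gm = make_fx(f)(torch.randn(3))

for node in gm.graph.nodes:
    if node.op == "call_function":
        # e.g. aten.mul.Tensor -> aten.mul, aten.add.Tensor -> aten.add
        print(node.target, "->", get_aten_target(node))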

File diff suppressed because it is too large