Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Refactor partitioner and clean it up (#126318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126318
Approved by: https://github.com/anijain2305
@@ -25,7 +25,6 @@ from torch._functorch.fx_minifier import minifier
from torch._functorch.partitioners import (
    default_partition,
    draw_graph,
    draw_joint_graph,
    min_cut_rematerialization_partition,
)
from torch._functorch.python_key import pythonkey_decompose
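For context, the partitioners imported above are the functions AOTAutograd uses to split a traced joint forward/backward graph into a separate forward graph and backward graph. Below is a minimal, illustrative sketch (not part of this diff) of how such a partition function is typically wired in, assuming the functorch.compile entry points aot_function and min_cut_rematerialization_partition; print_compiler is a hypothetical no-op compiler callback.

# Illustrative only -- not part of this PR.
# min_cut_rematerialization_partition needs networkx (see the skipIf in the test below).
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition


def print_compiler(fx_module, example_inputs):
    # Hypothetical "compiler": dump the partitioned graph and run it unmodified.
    print(fx_module.code)
    return fx_module


def f(x):
    return x * x * x


compiled_f = aot_function(
    f,
    fw_compiler=print_compiler,
    bw_compiler=print_compiler,
    partition_fn=min_cut_rematerialization_partition,
)
out = compiled_f(torch.randn(3, requires_grad=True))
out.sum().backward()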
@@ -4835,70 +4835,6 @@ class TestPartitioning(AOTTestCase):
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner_recomputable_ops(self):
        def f(x):
            return x * x * x

        recomputable_ops = []
        partition_fn = partial(
            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
        )

        fw_graph, bw_graph = get_fw_bw_graph(
            f, [torch.randn(3, requires_grad=True)], partition_fn
        )
        # Expected forward graph:
        # opcode         name       target           args                        kwargs
        # -------------  ---------  ---------------  --------------------------  --------
        # placeholder    primals_1  primals_1        ()                          {}
        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
        # output         output     output           ([mul_1, primals_1, mul],)  {}
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
        # Expected backward graph:
        # opcode         name        target           args                     kwargs
        # -------------  ----------  ---------------  -----------------------  --------
        # placeholder    primals_1   primals_1        ()                       {}
        # placeholder    mul         mul              ()                       {}
        # placeholder    tangents_1  tangents_1       ()                       {}
        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
        # output         output      output           ([add_1],)               {}
        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))

        recomputable_ops = [torch.ops.aten.mul]
        partition_fn = partial(
            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
        )
        fw_graph, bw_graph = get_fw_bw_graph(
            f, [torch.randn(3, requires_grad=True)], partition_fn
        )
        # Expected forward graph:
        # opcode         name       target           args                    kwargs
        # -------------  ---------  ---------------  ----------------------  --------
        # placeholder    primals_1  primals_1        ()                      {}
        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
        # output         output     output           ([mul_1, primals_1],)   {}
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        # Expected backward graph:
        # opcode         name        target           args                     kwargs
        # -------------  ----------  ---------------  -----------------------  --------
        # placeholder    primals_1   primals_1        ()                       {}
        # placeholder    tangents_1  tangents_1       ()                       {}
        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
        # output         output      output           ([add_1],)               {}
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))

    def test_contiguous(self):
        # The test simulates the condition where transpose followed by view
        # happens in the backward pass.
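The deleted test above was the main in-tree exercise of the recomputable_ops knob: it builds the partition function with functools.partial and then counts how many tensors the forward graph returns for the backward. A minimal sketch of that pattern follows, using only names visible in this diff; note that this refactor may change how recomputable_ops is handled, so treat the keyword as it stood before the PR.

# Pattern taken from the deleted test (illustrative): restrict which ops the
# min-cut partitioner may recompute in the backward pass.
from functools import partial

import torch
from torch._functorch.partitioners import min_cut_rematerialization_partition

# recomputable_ops takes overload packets such as torch.ops.aten.mul. With
# aten.mul recomputable, the forward for f(x) = x * x * x no longer has to
# return the intermediate `mul` for the backward, so its outputs drop from
# 3 (mul_1, primals_1, mul) to 2 (mul_1, primals_1), as asserted above.
partition_fn = partial(
    min_cut_rematerialization_partition,
    recomputable_ops=[torch.ops.aten.mul],
)
# partition_fn is then handed to AOTAutograd, e.g. via aot_function's
# partition_fn argument as sketched after the import hunk above.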
@@ -1,6 +1,8 @@
# mypy: ignore-errors

from typing import Callable

import torch
import torch.fx as fx
from torch.utils import _pytree as pytree
@@ -9,7 +11,7 @@ from torch.utils._pytree import tree_flatten
aten = torch.ops.aten


-def get_aten_target(node):
+def get_aten_target(node: fx.Node) -> Callable:
    if hasattr(node.target, "overloadpacket"):
        return node.target.overloadpacket
    return node.target
File diff suppressed because it is too large.