Refactor partitioner and clean it up (#126318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126318 Approved by: https://github.com/anijain2305
@@ -25,7 +25,6 @@ from torch._functorch.fx_minifier import minifier
 from torch._functorch.partitioners import (
     default_partition,
     draw_graph,
-    draw_joint_graph,
     min_cut_rematerialization_partition,
 )
 from torch._functorch.python_key import pythonkey_decompose
@@ -4835,70 +4835,6 @@ class TestPartitioning(AOTTestCase):
         self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
         self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))
 
-    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
-    def test_min_cut_partitioner_recomputable_ops(self):
-        def f(x):
-            return x * x * x
-
-        recomputable_ops = []
-        partition_fn = partial(
-            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
-        )
-
-        fw_graph, bw_graph = get_fw_bw_graph(
-            f, [torch.randn(3, requires_grad=True)], partition_fn
-        )
-        # Expected forward graph:
-        # opcode         name       target           args                        kwargs
-        # -------------  ---------  ---------------  --------------------------  --------
-        # placeholder    primals_1  primals_1        ()                          {}
-        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
-        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
-        # output         output     output           ([mul_1, primals_1, mul],)  {}
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
-        # Expected backward graph:
-        # opcode         name        target           args                     kwargs
-        # -------------  ----------  ---------------  -----------------------  --------
-        # placeholder    primals_1   primals_1        ()                       {}
-        # placeholder    mul         mul              ()                       {}
-        # placeholder    tangents_1  tangents_1       ()                       {}
-        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
-        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
-        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
-        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
-        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
-        # output         output      output           ([add_1],)               {}
-        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
-
-        recomputable_ops = [torch.ops.aten.mul]
-        partition_fn = partial(
-            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
-        )
-        fw_graph, bw_graph = get_fw_bw_graph(
-            f, [torch.randn(3, requires_grad=True)], partition_fn
-        )
-        # Expected forward graph:
-        # opcode         name       target           args                    kwargs
-        # -------------  ---------  ---------------  ----------------------  --------
-        # placeholder    primals_1  primals_1        ()                      {}
-        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
-        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
-        # output         output     output           ([mul_1, primals_1],)   {}
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
-        # Expected backward graph:
-        # opcode         name        target           args                     kwargs
-        # -------------  ----------  ---------------  -----------------------  --------
-        # placeholder    primals_1   primals_1        ()                       {}
-        # placeholder    tangents_1  tangents_1       ()                       {}
-        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
-        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
-        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
-        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
-        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
-        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
-        # output         output      output           ([add_1],)               {}
-        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
-
     def test_contiguous(self):
         # The test simulates the condition where transpose followed by view
         # happens in the backward pass.
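For context, the deleted test drove min_cut_rematerialization_partition through a test helper (get_fw_bw_graph) with a recomputable_ops allow-list to control what the backward graph may recompute instead of saving. Below is a minimal, stand-alone sketch of the same idea wired up through aot_function; it assumes the pre-refactor recomputable_ops keyword that the deleted test relied on, and print_graph is a made-up compiler callback used only to dump the partitioned graphs, not an API from this PR.

# Sketch only: assumes min_cut_rematerialization_partition still accepts the
# recomputable_ops keyword exercised by the deleted test.
from functools import partial

import torch
from torch._functorch.aot_autograd import aot_function
from torch._functorch.partitioners import min_cut_rematerialization_partition


def f(x):
    return x * x * x


def print_graph(gm, example_inputs):
    # Hypothetical fw/bw compiler callback: print the partitioned graph and run it as-is.
    print(gm.code)
    return gm.forward


# Only aten.mul may be recomputed in the backward pass instead of being saved.
partition_fn = partial(
    min_cut_rematerialization_partition,
    recomputable_ops=[torch.ops.aten.mul],
)

compiled_f = aot_function(
    f,
    fw_compiler=print_graph,
    bw_compiler=print_graph,
    partition_fn=partition_fn,
)

compiled_f(torch.randn(3, requires_grad=True)).sum().backward()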
@@ -1,6 +1,8 @@
 # mypy: ignore-errors
 
 
+from typing import Callable
+
 import torch
 import torch.fx as fx
 from torch.utils import _pytree as pytree
@@ -9,7 +11,7 @@ from torch.utils._pytree import tree_flatten
 aten = torch.ops.aten
 
 
-def get_aten_target(node):
+def get_aten_target(node: fx.Node) -> Callable:
     if hasattr(node.target, "overloadpacket"):
         return node.target.overloadpacket
     return node.target
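The annotation change documents what get_aten_target does: it takes an fx.Node and normalizes an ATen OpOverload target (e.g. aten.mul.Tensor) to its OpOverloadPacket (aten.mul), while passing any other target through unchanged, which is why the return type is the broader Callable. A tiny illustration of that behavior, not part of the diff:

import torch

overload = torch.ops.aten.mul.Tensor   # the kind of target an ATen fx node carries
packet = overload.overloadpacket       # strip the overload: the OpOverloadPacket
print(overload, "->", packet)          # aten.mul.Tensor -> aten.mul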
File diff suppressed because it is too large