mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This should be the final PR before we can enable RUFF UP006. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392 Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
# Owner(s): ["module: functorch"]
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from torch._functorch._activation_checkpointing.ac_logging_utils import (
|
|
create_activation_checkpointing_logging_structure_payload,
|
|
create_joint_graph_edges,
|
|
create_joint_graph_node_information,
|
|
create_structured_trace_for_min_cut_info,
|
|
)
|
|
from torch.fx import Graph, Node
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class TestAcLogging(TestCase):
|
|
def setUp(self) -> None:
|
|
self.graph: MagicMock = MagicMock(spec=Graph)
|
|
self.node1: MagicMock = MagicMock(spec=Node)
|
|
self.node2: MagicMock = MagicMock(spec=Node)
|
|
|
|
self.node1.name = "node1"
|
|
self.node1.target = "target1"
|
|
self.node1.meta = {
|
|
"tensor_meta": MagicMock(shape=(2, 2)),
|
|
"stack_trace": "trace1",
|
|
}
|
|
self.node1.all_input_nodes = []
|
|
|
|
self.node2.name = "node2"
|
|
self.node2.target = "target2"
|
|
self.node2.meta = {"tensor_meta": None, "stack_trace": "trace2"}
|
|
self.node2.all_input_nodes = [self.node1]
|
|
|
|
self.graph.nodes = [self.node1, self.node2]
|
|
|
|
self.all_recomputable_banned_nodes: list[Node] = [self.node1]
|
|
self.saved_node_idxs: list[int] = [0]
|
|
self.recomputable_node_idxs: list[int] = []
|
|
self.expected_runtime: int = 100
|
|
self.memories_banned_nodes: list[int] = [50]
|
|
self.runtimes_banned_nodes: list[int] = [10]
|
|
self.min_cut_saved_values: list[Node] = [self.node1]
|
|
|
|
def test_create_joint_graph_node_information(self) -> None:
|
|
recomputable_node_info: dict[str, int] = {"node1": 0}
|
|
expected_output: dict[str, dict] = {
|
|
"node1": {
|
|
"index": 0,
|
|
"name": "node1",
|
|
"is_recomputable_candidate": True,
|
|
"target": "target1",
|
|
"shape": "(2, 2)",
|
|
"input_arguments": [],
|
|
"stack_trace": "trace1",
|
|
"recomputable_candidate_info": {"recomputable_node_idx": 0},
|
|
},
|
|
"node2": {
|
|
"index": 1,
|
|
"name": "node2",
|
|
"is_recomputable_candidate": False,
|
|
"target": "target2",
|
|
"shape": "[]",
|
|
"input_arguments": ["node1"],
|
|
"stack_trace": "trace2",
|
|
},
|
|
}
|
|
result = create_joint_graph_node_information(self.graph, recomputable_node_info)
|
|
self.assertEqual(result, expected_output)
|
|
|
|
def test_create_joint_graph_edges(self) -> None:
|
|
expected_edges: list[tuple[str, str]] = [("node1", "node2")]
|
|
result = create_joint_graph_edges(self.graph)
|
|
self.assertEqual(result, expected_edges)
|
|
|
|
def test_create_activation_checkpointing_logging_structure_payload(self) -> None:
|
|
input_joint_graph_node_information: dict[str, dict] = {
|
|
"node1": {
|
|
"index": 0,
|
|
"name": "node1",
|
|
"is_recomputable_candidate": True,
|
|
"target": "target1",
|
|
"shape": "(2, 2)",
|
|
"input_arguments": [],
|
|
"stack_trace": "trace1",
|
|
"recomputable_candidate_info": {"recomputable_node_idx": 0},
|
|
}
|
|
}
|
|
joint_graph_edges: list[tuple[str, str]] = [("node1", "node2")]
|
|
expected_payload: dict[str, any] = {
|
|
"Joint Graph Size": 2,
|
|
"Joint Graph Edges": {"Total": 1, "Edges": joint_graph_edges},
|
|
"Joint Graph Node Information": input_joint_graph_node_information,
|
|
"Recomputable Banned Nodes Order": ["node1"],
|
|
"Expected Runtime": self.expected_runtime,
|
|
"Knapsack Saved Nodes": self.saved_node_idxs,
|
|
"Knapsack Recomputed Nodes": self.recomputable_node_idxs,
|
|
"Knapsack Input Memories": self.memories_banned_nodes,
|
|
"Knapsack Input Runtimes": self.runtimes_banned_nodes,
|
|
"Min Cut Solution Saved Values": ["node1"],
|
|
}
|
|
result = create_activation_checkpointing_logging_structure_payload(
|
|
self.graph,
|
|
input_joint_graph_node_information,
|
|
joint_graph_edges,
|
|
self.all_recomputable_banned_nodes,
|
|
self.expected_runtime,
|
|
self.saved_node_idxs,
|
|
self.recomputable_node_idxs,
|
|
self.memories_banned_nodes,
|
|
self.runtimes_banned_nodes,
|
|
self.min_cut_saved_values,
|
|
)
|
|
self.assertEqual(result, expected_payload)
|
|
|
|
@patch(
|
|
"torch._functorch._activation_checkpointing.ac_logging_utils.trace_structured"
|
|
)
|
|
@patch("json.dumps", return_value="mocked_payload")
|
|
def test_create_structured_trace_for_min_cut_info(
|
|
self, mock_json_dumps: MagicMock, mock_trace_structured: MagicMock
|
|
) -> None:
|
|
create_structured_trace_for_min_cut_info(
|
|
self.graph,
|
|
self.all_recomputable_banned_nodes,
|
|
self.saved_node_idxs,
|
|
self.recomputable_node_idxs,
|
|
self.expected_runtime,
|
|
self.memories_banned_nodes,
|
|
self.runtimes_banned_nodes,
|
|
self.min_cut_saved_values,
|
|
)
|
|
|
|
self.assertEqual(mock_trace_structured.call_count, 1)
|
|
|
|
metadata_fn_result = mock_trace_structured.call_args[1]["metadata_fn"]()
|
|
payload_fn_result = mock_trace_structured.call_args[1]["payload_fn"]()
|
|
|
|
self.assertEqual(
|
|
metadata_fn_result,
|
|
{
|
|
"name": "min_cut_information",
|
|
"encoding": "json",
|
|
},
|
|
)
|
|
self.assertEqual(payload_fn_result, "mocked_payload")
|
|
|
|
mock_json_dumps.assert_called_once()
|
|
|
|
|
|
# Script entry point: hand control to the shared test runner only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    run_tests()
|