Files
pytorch/test/functorch/test_ac_logging.py
Aaron Orenstein db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00

151 lines
5.7 KiB
Python

# Owner(s): ["module: functorch"]
from unittest.mock import MagicMock, patch
from torch._functorch._activation_checkpointing.ac_logging_utils import (
create_activation_checkpointing_logging_structure_payload,
create_joint_graph_edges,
create_joint_graph_node_information,
create_structured_trace_for_min_cut_info,
)
from torch.fx import Graph, Node
from torch.testing._internal.common_utils import run_tests, TestCase
class TestAcLogging(TestCase):
def setUp(self) -> None:
self.graph: MagicMock = MagicMock(spec=Graph)
self.node1: MagicMock = MagicMock(spec=Node)
self.node2: MagicMock = MagicMock(spec=Node)
self.node1.name = "node1"
self.node1.target = "target1"
self.node1.meta = {
"tensor_meta": MagicMock(shape=(2, 2)),
"stack_trace": "trace1",
}
self.node1.all_input_nodes = []
self.node2.name = "node2"
self.node2.target = "target2"
self.node2.meta = {"tensor_meta": None, "stack_trace": "trace2"}
self.node2.all_input_nodes = [self.node1]
self.graph.nodes = [self.node1, self.node2]
self.all_recomputable_banned_nodes: list[Node] = [self.node1]
self.saved_node_idxs: list[int] = [0]
self.recomputable_node_idxs: list[int] = []
self.expected_runtime: int = 100
self.memories_banned_nodes: list[int] = [50]
self.runtimes_banned_nodes: list[int] = [10]
self.min_cut_saved_values: list[Node] = [self.node1]
def test_create_joint_graph_node_information(self) -> None:
recomputable_node_info: dict[str, int] = {"node1": 0}
expected_output: dict[str, dict] = {
"node1": {
"index": 0,
"name": "node1",
"is_recomputable_candidate": True,
"target": "target1",
"shape": "(2, 2)",
"input_arguments": [],
"stack_trace": "trace1",
"recomputable_candidate_info": {"recomputable_node_idx": 0},
},
"node2": {
"index": 1,
"name": "node2",
"is_recomputable_candidate": False,
"target": "target2",
"shape": "[]",
"input_arguments": ["node1"],
"stack_trace": "trace2",
},
}
result = create_joint_graph_node_information(self.graph, recomputable_node_info)
self.assertEqual(result, expected_output)
def test_create_joint_graph_edges(self) -> None:
expected_edges: list[tuple[str, str]] = [("node1", "node2")]
result = create_joint_graph_edges(self.graph)
self.assertEqual(result, expected_edges)
def test_create_activation_checkpointing_logging_structure_payload(self) -> None:
input_joint_graph_node_information: dict[str, dict] = {
"node1": {
"index": 0,
"name": "node1",
"is_recomputable_candidate": True,
"target": "target1",
"shape": "(2, 2)",
"input_arguments": [],
"stack_trace": "trace1",
"recomputable_candidate_info": {"recomputable_node_idx": 0},
}
}
joint_graph_edges: list[tuple[str, str]] = [("node1", "node2")]
expected_payload: dict[str, any] = {
"Joint Graph Size": 2,
"Joint Graph Edges": {"Total": 1, "Edges": joint_graph_edges},
"Joint Graph Node Information": input_joint_graph_node_information,
"Recomputable Banned Nodes Order": ["node1"],
"Expected Runtime": self.expected_runtime,
"Knapsack Saved Nodes": self.saved_node_idxs,
"Knapsack Recomputed Nodes": self.recomputable_node_idxs,
"Knapsack Input Memories": self.memories_banned_nodes,
"Knapsack Input Runtimes": self.runtimes_banned_nodes,
"Min Cut Solution Saved Values": ["node1"],
}
result = create_activation_checkpointing_logging_structure_payload(
self.graph,
input_joint_graph_node_information,
joint_graph_edges,
self.all_recomputable_banned_nodes,
self.expected_runtime,
self.saved_node_idxs,
self.recomputable_node_idxs,
self.memories_banned_nodes,
self.runtimes_banned_nodes,
self.min_cut_saved_values,
)
self.assertEqual(result, expected_payload)
@patch(
"torch._functorch._activation_checkpointing.ac_logging_utils.trace_structured"
)
@patch("json.dumps", return_value="mocked_payload")
def test_create_structured_trace_for_min_cut_info(
self, mock_json_dumps: MagicMock, mock_trace_structured: MagicMock
) -> None:
create_structured_trace_for_min_cut_info(
self.graph,
self.all_recomputable_banned_nodes,
self.saved_node_idxs,
self.recomputable_node_idxs,
self.expected_runtime,
self.memories_banned_nodes,
self.runtimes_banned_nodes,
self.min_cut_saved_values,
)
self.assertEqual(mock_trace_structured.call_count, 1)
metadata_fn_result = mock_trace_structured.call_args[1]["metadata_fn"]()
payload_fn_result = mock_trace_structured.call_args[1]["payload_fn"]()
self.assertEqual(
metadata_fn_result,
{
"name": "min_cut_information",
"encoding": "json",
},
)
self.assertEqual(payload_fn_result, "mocked_payload")
mock_json_dumps.assert_called_once()
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()