# Owner(s): ["module: functorch"] from torch._functorch._activation_checkpointing.graph_info_provider import ( GraphInfoProvider, ) from torch._functorch._activation_checkpointing.knapsack_evaluator import ( KnapsackEvaluator, ) from torch.fx.graph import Graph from torch.testing._internal.common_utils import run_tests, TestCase class TestGraphInfoProvider(TestCase): """ Test class for GraphInfoProvider. The test class sets up a small graph example and tests the methods validating the graph building logic. """ def setUp(self) -> None: super().setUp() self.graph_nodes_in_order = [ "node1", "node2", "node3", "node4", "node5", "output", ] self.graph_edges = [ ("node1", "node2"), ("node2", "node3"), ("node3", "node4"), ("node4", "node5"), ("node5", "output"), ("node1", "output"), ] self.all_recomputable_banned_nodes = ["node1", "node2", "node5"] self.recorded_knapsack_input_memories = [1.0, 1.0, 1.0] self.recorded_knapsack_input_runtimes = [1.0, 1.0, 1.0] self.graph_info_provider = GraphInfoProvider( graph_nodes_in_order=self.graph_nodes_in_order, graph_edges=self.graph_edges, all_recomputable_banned_nodes=self.all_recomputable_banned_nodes, recorded_knapsack_input_memories=self.recorded_knapsack_input_memories, recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes, ) def test_inialize_from_graph(self): joint_graph = Graph() node1 = joint_graph.placeholder("node1") node2 = joint_graph.call_function(lambda x: x, (node1,)) node2.name = "node2" node3 = joint_graph.call_function(lambda x: x, (node2,)) node3.name = "node3" node4 = joint_graph.call_function(lambda x: x, (node3,)) node4.name = "node4" node5 = joint_graph.call_function(lambda x: x, (node4,)) node5.name = "node5" output = joint_graph.call_function(lambda x, y: (x, y), (node5, node1)) output.name = "output" all_recomputable_banned_nodes = [node1, node2, node5] recorded_knapsack_input_memories = [1.0, 1.0, 1.0] recorded_knapsack_input_runtimes = [1.0, 1.0, 1.0] graph_info_provider = GraphInfoProvider.inialize_from_graph( joint_graph=joint_graph, all_recomputable_banned_nodes=all_recomputable_banned_nodes, recorded_knapsack_input_memories=recorded_knapsack_input_memories, recorded_knapsack_input_runtimes=recorded_knapsack_input_runtimes, ) self.assertEqual( graph_info_provider.graph_nodes_in_order, ["node1", "node2", "node3", "node4", "node5", "output"], ) self.assertEqual( sorted(graph_info_provider.graph_edges), sorted( [ ("node1", "node2"), ("node2", "node3"), ("node3", "node4"), ("node4", "node5"), ("node5", "output"), ("node1", "output"), ] ), ) self.assertEqual( graph_info_provider.all_recomputable_banned_nodes, ["node1", "node2", "node5"], ) def test_get_non_ac_peak_memory(self): self.assertEqual( self.graph_info_provider.get_non_ac_peak_memory(), sum(self.recorded_knapsack_input_memories), ) def test_get_theoretical_max_runtime(self): self.assertEqual( self.graph_info_provider.get_theoretical_max_runtime(), sum(self.recorded_knapsack_input_runtimes), ) def test_get_knapsack_memory_input(self): self.assertEqual( self.graph_info_provider.get_knapsack_memory_input(), self.recorded_knapsack_input_memories, ) def test_get_knapsack_runtime_input(self): self.assertEqual( self.graph_info_provider.get_knapsack_runtime_input(), self.recorded_knapsack_input_runtimes, ) def test_recomputable_node_only_graph(self): recomputable_node_only_graph = ( self.graph_info_provider.recomputable_node_only_graph ) expected_nodes = self.all_recomputable_banned_nodes expected_edges = [("node1", "node2")] self.assertEqual(list(recomputable_node_only_graph.nodes), expected_nodes) self.assertEqual( sorted(recomputable_node_only_graph.edges), sorted(expected_edges) ) def test_recomputable_node_only_graph_with_larger_graph_context(self): recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950 expected_nodes = self.all_recomputable_banned_nodes # node1 does not have an indirect path to node5 because of node2 # node2 has an indirect path to node5 expected_edges = [("node1", "node2"), ("node2", "node5")] self.assertEqual( sorted(recomputable_node_only_graph_with_larger_graph_context.nodes), sorted(expected_nodes), ) self.assertEqual( sorted(recomputable_node_only_graph_with_larger_graph_context.edges), sorted(expected_edges), ) def test_full_joint_nx_graph(self): graph_info_provider = GraphInfoProvider( graph_nodes_in_order=self.graph_nodes_in_order, graph_edges=self.graph_edges, all_recomputable_banned_nodes=self.all_recomputable_banned_nodes, recorded_knapsack_input_memories=self.recorded_knapsack_input_memories, recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes, ) full_joint_nx_graph = graph_info_provider.full_joint_nx_graph expected_nodes = [ node for node in self.graph_nodes_in_order if node != "output" ] expected_edges = [ (u, v) for u, v in self.graph_edges if u != "output" and v != "output" ] self.assertEqual(list(full_joint_nx_graph.nodes), expected_nodes) self.assertEqual(sorted(full_joint_nx_graph.edges), sorted(expected_edges)) def test_simplified_fx_joint_graph(self): graph_info_provider = GraphInfoProvider( graph_nodes_in_order=self.graph_nodes_in_order, graph_edges=self.graph_edges, all_recomputable_banned_nodes=self.all_recomputable_banned_nodes, recorded_knapsack_input_memories=self.recorded_knapsack_input_memories, recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes, ) simplified_fx_joint_graph = graph_info_provider.simplified_fx_joint_graph expected_nodes = self.graph_nodes_in_order expected_edges = self.graph_edges self.assertEqual( [node.name for node in simplified_fx_joint_graph.nodes], expected_nodes ) self.assertEqual( sorted( [ (node.name, user.name) for node in simplified_fx_joint_graph.nodes for user in node.users ] ), sorted(expected_edges), ) class TestKnapsackEvaluator(TestCase): """ Test class for KnapsackEvaluator. The test class sets up a small graph example and tests the methods validating the knapsack evaluation logic. """ def setUp(self) -> None: super().setUp() self.graph_nodes_in_order = [ "node1", "node2", "node3", "node4", "node5", "output", ] self.graph_edges = [ ("node1", "node2"), ("node2", "node3"), ("node3", "node4"), ("node4", "node5"), ("node5", "output"), ("node1", "output"), ] self.all_recomputable_banned_nodes = ["node1", "node2", "node5"] self.recorded_knapsack_input_memories = [0.1, 0.2, 0.2] self.recorded_knapsack_input_runtimes = [100.0, 50.0, 51.0] self.graph_info_provider = GraphInfoProvider( graph_nodes_in_order=self.graph_nodes_in_order, graph_edges=self.graph_edges, all_recomputable_banned_nodes=self.all_recomputable_banned_nodes, recorded_knapsack_input_memories=self.recorded_knapsack_input_memories, recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes, ) self.knapsack_evaluator = KnapsackEvaluator( graph_info_provider=self.graph_info_provider ) self.knapsack_algo = lambda memory_values, runtime_values, memory_budget: { 0.1: (101.0, [0], [1, 2]), 0.2: (101.0, [0], [1, 2]), 0.3: (50.0, [0, 2], [1]), 0.4: (50.0, [0, 2], [1]), 0.5: (0.0, [0, 1, 2], []), }.get(memory_budget, (0.0, [0, 1, 2], [])) def test_evaluate_knapsack_output_not_accounting_for_backward_pass(self): saved_nodes_idxs = [0] recomputable_node_idxs = [1, 2] result = self.knapsack_evaluator.evaluate_knapsack_output( saved_nodes_idxs=saved_nodes_idxs, recomputable_node_idxs=recomputable_node_idxs, ) self.assertEqual(result["peak_memory"], 0.1) self.assertEqual(result["recomputation_runtime"], 101.0) def test_evaluate_knapsack_output_accounting_for_backward_pass(self): saved_nodes_idxs = [0] recomputable_node_idxs = [1, 2] result = self.knapsack_evaluator.evaluate_knapsack_output( saved_nodes_idxs=saved_nodes_idxs, recomputable_node_idxs=recomputable_node_idxs, account_for_backward_pass=True, ) self.assertEqual(result["peak_memory"], 0.5) self.assertEqual(result["recomputation_runtime"], 101.0) def test_evaluate_knapsack_output_with_wrong_sized_values(self): saved_nodes_idxs = [0] recomputable_node_idxs = [1] with self.assertRaises(AssertionError): self.knapsack_evaluator.evaluate_knapsack_output( saved_nodes_idxs=saved_nodes_idxs, recomputable_node_idxs=recomputable_node_idxs, ) def test_evaluate_distribution_of_results_for_knapsack_algo(self): memory_budget_values = [0.1, 0.2, 0.3] results = ( self.knapsack_evaluator.evaluate_distribution_of_results_for_knapsack_algo( knapsack_algo=self.knapsack_algo, memory_budget_values=memory_budget_values, ) ) self.assertEqual(len(results), len(memory_budget_values)) self.assertEqual(results[0]["memory_budget"], 0.1) self.assertEqual(results[0]["peak_memory"], 0.1) self.assertEqual(results[0]["recomputation_runtime"], 101) self.assertEqual(results[1]["non_ac_peak_memory"], 0.5) self.assertEqual(results[1]["theoretical_max_runtime"], 201) self.assertEqual(results[2]["percentage_of_theoretical_peak_memory"], 0.3 / 0.5) self.assertEqual( results[2]["percentage_of_theoretical_peak_runtime"], 50.0 / 201 ) def test_get_knee_point_memory_budget(self): """ Checks if the method correctly estimates the knee point in the memory budget where the trade-off between memory usage and recomputation runtime is optimal. If memory budget and runtime are considered as equal cost, then the knee point is where the distance from 0 is smallest. """ max_mem_budget_to_expected_knee_point = { 0.1: 0.1, 0.2: 0.1, 0.3: 0.3, 0.4: 0.4, # 0.3 and 0.4 provide the same algo output so this is arbitrary 0.5: 0.4, } for ( max_mem_budget, expected_knee_point, ) in max_mem_budget_to_expected_knee_point.items(): knee_point_memory_budget = ( self.knapsack_evaluator.get_knee_point_memory_budget( knapsack_algo=self.knapsack_algo, max_mem_budget=max_mem_budget, min_mem_budget=0.1, iterations=5, ) ) self.assertEqual(knee_point_memory_budget, expected_knee_point) def test_get_backward_memory_from_topologically_sorted_graph(self): result = self.knapsack_evaluator._get_backward_memory_from_topologically_sorted_graph( node_graph=self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context, node_memories=self.graph_info_provider.all_node_memories, saved_nodes_set={"node1"}, peak_memory_after_forward_pass=0.1, ) expected_result = [ (0.1, "Initial Peak/Current Memory"), (0.3, "Recomputing Node: node5"), (0.5, "Recomputing Predecessor of node5: node2"), (0.3, "Dropping Node: node5"), (0.1, "Dropping Node(already saved): node2"), (0.0, "Dropping Node(already saved): node1"), ] print(result, expected_result) for result_item, expected_result_item in zip(result, expected_result): self.assertAlmostEqual(result_item[0], expected_result_item[0]) self.assertEqual(result_item[1], expected_result_item[1]) if __name__ == "__main__": run_tests()