mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: The two implementations are functionally equivalent. They both calculate the memory budget at the knee point in the Pareto frontier using the same algorithm. 1. np.linspace -> basic list comprehension 2. runtime and memory values -> lists instead of numpy arrays 3. np.ptp -> max - min 4. np.norm -> diff with min value / range 5. np.sqrt -> **0.5 6. np.argmin -> .index(min(_)) Test Plan: # Unit Testing ``` buck test mode/opt //caffe2/test/functorch:test_ac_knapsack; pingme "tests done" Buck UI: https://www.internalfb.com/buck2/f4e41eb8-e775-4f04-b4e7-8e567599deb8 Test UI: https://www.internalfb.com/intern/testinfra/testrun/10133099236155875 Network: Up: 24KiB Down: 1.9GiB (reSessionID-7cd11487-f3e7-43ab-982a-805510771c8d) Executing actions. Remaining 0/259826 98:15:40.5s exec time total Command: test. Finished 3 local, 5 remote, 103467 cache (99% hit) 98:15:14.8s exec time cached (99%) Time elapsed: 1:09.9s Tests finished: Pass 15. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` # End to End Testing ### Baseline Run with DP Let's confirm everything we are running on works. 
- Optimization Algo: DP - Memory Budget: 0.05 - AIX Link: apf_local-basilwong-2025-03-22_20:39:10 - TLParse rank 0: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpDJaWp5/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 - TLParse rank 1: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpDJaWp5/rank_1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ### Dynamic Memory Budget (Before Change) - Revision: 2c95489b7f79 - Optimization Algo: Dynamic Memory Budget - Memory Budget: 0.05 - AIX Link: https://www.internalfb.com/mlhub/pipeline/4088035428184866 - TLParse: - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpykEy8U/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpykEy8U/rank_1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ### Dynamic Memory Budget (After Change) - Revision: 14353eef3c9e - Optimization Algo: Dynamic Memory Budget - Memory Budget: 0.05 - AIX Link: https://www.internalfb.com/mlhub/pipeline/1613558749306737 - TLParse Links: - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpZKNWFw/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpZKNWFw/rank_1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 As a sanity check lets take the AC information for the following compile id: 7_0_0 from the rank 0 of each TLParse. 
{F1976883124} * Baseline: P1779400819 * Saved node values show we are storing much more compared to dynamic memory: ``` "Knapsack Saved Nodes": [ 16, 17, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60 ] ``` * Before Change: P1779401775 * Saved nodes are similar to after change but not exactly. ``` "Knapsack Saved Nodes": [ 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 49, 50 ] ``` * After Change: P1779402106 * Here we see that the largest nodes that are saved are around the same, but there is a small discrepancy for the smallest nodes. ``` "Knapsack Saved Nodes": [ 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 50, 51, 57, 58, 59, 60, 61, 62 ], ``` The discrepancy can be explained by looking at the estimated memory values. This is the non-deterministic part (below are the top 5 memory values for considered candidates): ``` 0.05774741703905514, 0.007333005338292718, 0.007333005338292718, 0.007333005338292718, 0.007333005338292718, ``` vs ``` 0.049254204820440746, 0.006254502199421049, 0.006254502199421049, 0.006254502199421049, 0.006254502199421049, ``` Given that the dynamic memory implementations performed similarly in an E2E test and that the memory estimation is non-deterministic, we should be good to go to land. Differential Revision: D71692245 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150825 Approved by: https://github.com/seemethere, https://github.com/jansel
333 lines
13 KiB
Python
333 lines
13 KiB
Python
# Owner(s): ["module: functorch"]
|
|
from torch._functorch._activation_checkpointing.graph_info_provider import (
|
|
GraphInfoProvider,
|
|
)
|
|
from torch._functorch._activation_checkpointing.knapsack_evaluator import (
|
|
KnapsackEvaluator,
|
|
)
|
|
from torch.fx.graph import Graph
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class TestGraphInfoProvider(TestCase):
|
|
"""
|
|
Test class for GraphInfoProvider.
|
|
The test class sets up a small graph example and tests the methods validating the graph building logic.
|
|
"""
|
|
|
|
def setUp(self) -> None:
|
|
super().setUp()
|
|
self.graph_nodes_in_order = [
|
|
"node1",
|
|
"node2",
|
|
"node3",
|
|
"node4",
|
|
"node5",
|
|
"output",
|
|
]
|
|
self.graph_edges = [
|
|
("node1", "node2"),
|
|
("node2", "node3"),
|
|
("node3", "node4"),
|
|
("node4", "node5"),
|
|
("node5", "output"),
|
|
("node1", "output"),
|
|
]
|
|
self.all_recomputable_banned_nodes = ["node1", "node2", "node5"]
|
|
self.recorded_knapsack_input_memories = [1.0, 1.0, 1.0]
|
|
self.recorded_knapsack_input_runtimes = [1.0, 1.0, 1.0]
|
|
self.graph_info_provider = GraphInfoProvider(
|
|
graph_nodes_in_order=self.graph_nodes_in_order,
|
|
graph_edges=self.graph_edges,
|
|
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
|
|
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
|
|
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
|
|
)
|
|
|
|
def test_inialize_from_graph(self):
|
|
joint_graph = Graph()
|
|
node1 = joint_graph.placeholder("node1")
|
|
node2 = joint_graph.call_function(lambda x: x, (node1,))
|
|
node2.name = "node2"
|
|
node3 = joint_graph.call_function(lambda x: x, (node2,))
|
|
node3.name = "node3"
|
|
node4 = joint_graph.call_function(lambda x: x, (node3,))
|
|
node4.name = "node4"
|
|
node5 = joint_graph.call_function(lambda x: x, (node4,))
|
|
node5.name = "node5"
|
|
output = joint_graph.call_function(lambda x, y: (x, y), (node5, node1))
|
|
output.name = "output"
|
|
all_recomputable_banned_nodes = [node1, node2, node5]
|
|
recorded_knapsack_input_memories = [1.0, 1.0, 1.0]
|
|
recorded_knapsack_input_runtimes = [1.0, 1.0, 1.0]
|
|
graph_info_provider = GraphInfoProvider.inialize_from_graph(
|
|
joint_graph=joint_graph,
|
|
all_recomputable_banned_nodes=all_recomputable_banned_nodes,
|
|
recorded_knapsack_input_memories=recorded_knapsack_input_memories,
|
|
recorded_knapsack_input_runtimes=recorded_knapsack_input_runtimes,
|
|
)
|
|
self.assertEqual(
|
|
graph_info_provider.graph_nodes_in_order,
|
|
["node1", "node2", "node3", "node4", "node5", "output"],
|
|
)
|
|
self.assertEqual(
|
|
sorted(graph_info_provider.graph_edges),
|
|
sorted(
|
|
[
|
|
("node1", "node2"),
|
|
("node2", "node3"),
|
|
("node3", "node4"),
|
|
("node4", "node5"),
|
|
("node5", "output"),
|
|
("node1", "output"),
|
|
]
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
graph_info_provider.all_recomputable_banned_nodes,
|
|
["node1", "node2", "node5"],
|
|
)
|
|
|
|
def test_get_non_ac_peak_memory(self):
|
|
self.assertEqual(
|
|
self.graph_info_provider.get_non_ac_peak_memory(),
|
|
sum(self.recorded_knapsack_input_memories),
|
|
)
|
|
|
|
def test_get_theoretical_max_runtime(self):
|
|
self.assertEqual(
|
|
self.graph_info_provider.get_theoretical_max_runtime(),
|
|
sum(self.recorded_knapsack_input_runtimes),
|
|
)
|
|
|
|
def test_get_knapsack_memory_input(self):
|
|
self.assertEqual(
|
|
self.graph_info_provider.get_knapsack_memory_input(),
|
|
self.recorded_knapsack_input_memories,
|
|
)
|
|
|
|
def test_get_knapsack_runtime_input(self):
|
|
self.assertEqual(
|
|
self.graph_info_provider.get_knapsack_runtime_input(),
|
|
self.recorded_knapsack_input_runtimes,
|
|
)
|
|
|
|
def test_recomputable_node_only_graph(self):
|
|
recomputable_node_only_graph = (
|
|
self.graph_info_provider.recomputable_node_only_graph
|
|
)
|
|
expected_nodes = self.all_recomputable_banned_nodes
|
|
expected_edges = [("node1", "node2")]
|
|
self.assertEqual(list(recomputable_node_only_graph.nodes), expected_nodes)
|
|
self.assertEqual(
|
|
sorted(recomputable_node_only_graph.edges), sorted(expected_edges)
|
|
)
|
|
|
|
def test_recomputable_node_only_graph_with_larger_graph_context(self):
|
|
recomputable_node_only_graph_with_larger_graph_context = (
|
|
self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context
|
|
)
|
|
expected_nodes = self.all_recomputable_banned_nodes
|
|
# node1 does not have an indirect path to node5 because of node2
|
|
# node2 has an indirect path to node5
|
|
expected_edges = [("node1", "node2"), ("node2", "node5")]
|
|
self.assertEqual(
|
|
sorted(recomputable_node_only_graph_with_larger_graph_context.nodes),
|
|
sorted(expected_nodes),
|
|
)
|
|
self.assertEqual(
|
|
sorted(recomputable_node_only_graph_with_larger_graph_context.edges),
|
|
sorted(expected_edges),
|
|
)
|
|
|
|
def test_full_joint_nx_graph(self):
|
|
graph_info_provider = GraphInfoProvider(
|
|
graph_nodes_in_order=self.graph_nodes_in_order,
|
|
graph_edges=self.graph_edges,
|
|
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
|
|
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
|
|
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
|
|
)
|
|
full_joint_nx_graph = graph_info_provider.full_joint_nx_graph
|
|
expected_nodes = [
|
|
node for node in self.graph_nodes_in_order if node != "output"
|
|
]
|
|
expected_edges = [
|
|
(u, v) for u, v in self.graph_edges if u != "output" and v != "output"
|
|
]
|
|
self.assertEqual(list(full_joint_nx_graph.nodes), expected_nodes)
|
|
self.assertEqual(sorted(full_joint_nx_graph.edges), sorted(expected_edges))
|
|
|
|
def test_simplified_fx_joint_graph(self):
|
|
graph_info_provider = GraphInfoProvider(
|
|
graph_nodes_in_order=self.graph_nodes_in_order,
|
|
graph_edges=self.graph_edges,
|
|
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
|
|
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
|
|
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
|
|
)
|
|
simplified_fx_joint_graph = graph_info_provider.simplified_fx_joint_graph
|
|
expected_nodes = self.graph_nodes_in_order
|
|
expected_edges = self.graph_edges
|
|
self.assertEqual(
|
|
[node.name for node in simplified_fx_joint_graph.nodes], expected_nodes
|
|
)
|
|
self.assertEqual(
|
|
sorted(
|
|
[
|
|
(node.name, user.name)
|
|
for node in simplified_fx_joint_graph.nodes
|
|
for user in node.users
|
|
]
|
|
),
|
|
sorted(expected_edges),
|
|
)
|
|
|
|
|
|
class TestKnapsackEvaluator(TestCase):
|
|
"""
|
|
Test class for KnapsackEvaluator.
|
|
The test class sets up a small graph example and tests the methods validating the knapsack evaluation logic.
|
|
"""
|
|
|
|
def setUp(self) -> None:
|
|
super().setUp()
|
|
self.graph_nodes_in_order = [
|
|
"node1",
|
|
"node2",
|
|
"node3",
|
|
"node4",
|
|
"node5",
|
|
"output",
|
|
]
|
|
self.graph_edges = [
|
|
("node1", "node2"),
|
|
("node2", "node3"),
|
|
("node3", "node4"),
|
|
("node4", "node5"),
|
|
("node5", "output"),
|
|
("node1", "output"),
|
|
]
|
|
self.all_recomputable_banned_nodes = ["node1", "node2", "node5"]
|
|
self.recorded_knapsack_input_memories = [0.1, 0.2, 0.2]
|
|
self.recorded_knapsack_input_runtimes = [100.0, 50.0, 51.0]
|
|
self.graph_info_provider = GraphInfoProvider(
|
|
graph_nodes_in_order=self.graph_nodes_in_order,
|
|
graph_edges=self.graph_edges,
|
|
all_recomputable_banned_nodes=self.all_recomputable_banned_nodes,
|
|
recorded_knapsack_input_memories=self.recorded_knapsack_input_memories,
|
|
recorded_knapsack_input_runtimes=self.recorded_knapsack_input_runtimes,
|
|
)
|
|
self.knapsack_evaluator = KnapsackEvaluator(
|
|
graph_info_provider=self.graph_info_provider
|
|
)
|
|
self.knapsack_algo = lambda memory_values, runtime_values, memory_budget: {
|
|
0.1: (101.0, [0], [1, 2]),
|
|
0.2: (101.0, [0], [1, 2]),
|
|
0.3: (50.0, [0, 2], [1]),
|
|
0.4: (50.0, [0, 2], [1]),
|
|
0.5: (0.0, [0, 1, 2], []),
|
|
}.get(memory_budget, (0.0, [0, 1, 2], []))
|
|
|
|
def test_evaluate_knapsack_output_not_accounting_for_backward_pass(self):
|
|
saved_nodes_idxs = [0]
|
|
recomputable_node_idxs = [1, 2]
|
|
result = self.knapsack_evaluator.evaluate_knapsack_output(
|
|
saved_nodes_idxs=saved_nodes_idxs,
|
|
recomputable_node_idxs=recomputable_node_idxs,
|
|
)
|
|
self.assertEqual(result["peak_memory"], 0.1)
|
|
self.assertEqual(result["recomputation_runtime"], 101.0)
|
|
|
|
def test_evaluate_knapsack_output_accounting_for_backward_pass(self):
|
|
saved_nodes_idxs = [0]
|
|
recomputable_node_idxs = [1, 2]
|
|
result = self.knapsack_evaluator.evaluate_knapsack_output(
|
|
saved_nodes_idxs=saved_nodes_idxs,
|
|
recomputable_node_idxs=recomputable_node_idxs,
|
|
account_for_backward_pass=True,
|
|
)
|
|
self.assertEqual(result["peak_memory"], 0.5)
|
|
self.assertEqual(result["recomputation_runtime"], 101.0)
|
|
|
|
def test_evaluate_knapsack_output_with_wrong_sized_values(self):
|
|
saved_nodes_idxs = [0]
|
|
recomputable_node_idxs = [1]
|
|
with self.assertRaises(AssertionError):
|
|
self.knapsack_evaluator.evaluate_knapsack_output(
|
|
saved_nodes_idxs=saved_nodes_idxs,
|
|
recomputable_node_idxs=recomputable_node_idxs,
|
|
)
|
|
|
|
def test_evaluate_distribution_of_results_for_knapsack_algo(self):
|
|
memory_budget_values = [0.1, 0.2, 0.3]
|
|
results = (
|
|
self.knapsack_evaluator.evaluate_distribution_of_results_for_knapsack_algo(
|
|
knapsack_algo=self.knapsack_algo,
|
|
memory_budget_values=memory_budget_values,
|
|
)
|
|
)
|
|
self.assertEqual(len(results), len(memory_budget_values))
|
|
self.assertEqual(results[0]["memory_budget"], 0.1)
|
|
self.assertEqual(results[0]["peak_memory"], 0.1)
|
|
self.assertEqual(results[0]["recomputation_runtime"], 101)
|
|
self.assertEqual(results[1]["non_ac_peak_memory"], 0.5)
|
|
self.assertEqual(results[1]["theoretical_max_runtime"], 201)
|
|
self.assertEqual(results[2]["percentage_of_theoretical_peak_memory"], 0.3 / 0.5)
|
|
self.assertEqual(
|
|
results[2]["percentage_of_theoretical_peak_runtime"], 50.0 / 201
|
|
)
|
|
|
|
def test_get_knee_point_memory_budget(self):
|
|
"""
|
|
Checks if the method correctly estimates the knee point in the memory budget
|
|
where the trade-off between memory usage and recomputation runtime is optimal.
|
|
|
|
If memory budget and runtime are considered as equal cost, then the knee point
|
|
is where the distance from 0 is smallest.
|
|
"""
|
|
max_mem_budget_to_expected_knee_point = {
|
|
0.1: 0.1,
|
|
0.2: 0.1,
|
|
0.3: 0.3,
|
|
0.4: 0.4, # 0.3 and 0.4 provide the same algo output so this is arbitrary
|
|
0.5: 0.4,
|
|
}
|
|
for (
|
|
max_mem_budget,
|
|
expected_knee_point,
|
|
) in max_mem_budget_to_expected_knee_point.items():
|
|
knee_point_memory_budget = (
|
|
self.knapsack_evaluator.get_knee_point_memory_budget(
|
|
knapsack_algo=self.knapsack_algo,
|
|
max_mem_budget=max_mem_budget,
|
|
min_mem_budget=0.1,
|
|
iterations=5,
|
|
)
|
|
)
|
|
self.assertEqual(knee_point_memory_budget, expected_knee_point)
|
|
|
|
def test_get_backward_memory_from_topologically_sorted_graph(self):
|
|
result = self.knapsack_evaluator._get_backward_memory_from_topologically_sorted_graph(
|
|
node_graph=self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context,
|
|
node_memories=self.graph_info_provider.all_node_memories,
|
|
saved_nodes_set={"node1"},
|
|
peak_memory_after_forward_pass=0.1,
|
|
)
|
|
expected_result = [
|
|
(0.1, "Initial Peak/Current Memory"),
|
|
(0.3, "Recomputing Node: node5"),
|
|
(0.5, "Recomputing Predecessor of node5: node2"),
|
|
(0.3, "Dropping Node: node5"),
|
|
(0.1, "Dropping Node(already saved): node2"),
|
|
(0.0, "Dropping Node(already saved): node1"),
|
|
]
|
|
print(result, expected_result)
|
|
for result_item, expected_result_item in zip(result, expected_result):
|
|
self.assertAlmostEqual(result_item[0], expected_result_item[0])
|
|
self.assertEqual(result_item[1], expected_result_item[1])
|
|
|
|
|
|
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
|