From 381d0cb239fa35f23b34ffb51f94784a3f9798b4 Mon Sep 17 00:00:00 2001
From: Saurabh Mishra
Date: Tue, 18 Mar 2025 16:08:37 +0000
Subject: [PATCH] [DCP] Avoid in-place update and deepcopy during dedupe
 (#149320)

Summary:
Avoid in-place updates and deepcopy during dedupe. Deepcopy becomes
prohibitively expensive for models with a huge number of FQNs; this was
manifested in the Ads 2K experiment as well. Here are the results from the
TextRay model in Mitra:

#### Control job with deepcopy regression:
First save: ~24.8s
Global step latency: ~7-8s

#### Test job with the new fix to avoid deepcopy:
First save: ~21s
Global step latency: ~2s

A short, self-contained sketch of the non-mutating dedup approach appears
after the diff below.

Test Plan:
```
buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/distributed/checkpoint:test_planner
```

https://www.internalfb.com/intern/testinfra/testrun/3940649945104822

Differential Revision: D71245218

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149320
Approved by: https://github.com/MeetVadakkanchery
---
 test/distributed/checkpoint/test_planner.py   | 30 ++++++++++----
 .../checkpoint/_dedup_save_plans.py           | 40 ++++++++++---------
 .../distributed/checkpoint/default_planner.py | 17 +++-----
 3 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/test/distributed/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py
index a8d673ca8e4c..0d2183df05b8 100644
--- a/test/distributed/checkpoint/test_planner.py
+++ b/test/distributed/checkpoint/test_planner.py
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]

-import copy
 import sys

 import torch
@@ -201,6 +200,26 @@ class TestSavePlan(TestCase):
                 item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
             )

+    def test_dedup_plans(self):
+        def create_data(rank):
+            with with_dist(rank=rank, world_size=4):
+                tensor = torch.rand(10)
+                val = [1, 2, 3]
+                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
+                state_dict = {"tensor": tensor, "value": val, "st": st}
+                return create_default_local_save_plan(state_dict, rank == 0)
+
+        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
+        deduped_plans = dedup_save_plans(all_plans)
+
+        # Number of plans should remain unchanged
+        self.assertEqual(len(all_plans), len(deduped_plans))
+
+        # Number of items in each deduped plan should be less than in the original plan
+        for new_plan, old_plan in zip(deduped_plans, all_plans):
+            self.assertFalse(_compare_save_plans(new_plan, old_plan))
+            self.assertTrue(len(new_plan.items) < len(old_plan.items))
+
     def test_global_plan_with_caching(self):
         def create_data(rank):
             with with_dist(rank=rank, world_size=4):
@@ -213,7 +232,6 @@ class TestSavePlan(TestCase):
                 return planner.create_local_plan()

         all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
-        expected_all_plans = copy.deepcopy(all_plans)
         planner = DefaultSavePlanner(enable_plan_caching=True)
         # First iteration, should create a new plan
         first_global_plan, first_metadata = planner.create_global_plan(all_plans)
@@ -224,12 +242,11 @@ class TestSavePlan(TestCase):

         # Validate that all_plans are cached
         cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
-        self.assertEqual(cached_all_plans, expected_all_plans)
+        self.assertEqual(cached_all_plans, all_plans)

         # Second iteration, should return empty plans
         # Recreate the plans as the previous ones are deduped.
         all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
-        expected_all_plans = copy.deepcopy(all_plans)
         second_global_plan, second_metadata = planner.create_global_plan(all_plans)
         # All the plans should be empty and usable
         for plan in second_global_plan:
@@ -242,7 +259,7 @@ class TestSavePlan(TestCase):

         # Validate that all_plans are cached and remain unchanged.
         cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
-        self.assertEqual(cached_all_plans, expected_all_plans)
+        self.assertEqual(cached_all_plans, all_plans)

         # Third iteration with changed plans
         def create_data_v2(rank):
@@ -261,7 +278,6 @@ class TestSavePlan(TestCase):
             create_data_v2(2),
             create_data_v2(3),
         ]
-        expected_all_plans = copy.deepcopy(all_plans)
         third_global_plan, third_metadata = planner.create_global_plan(all_plans)
         # Only the rank 0 plan should be non-empty. The rest should be empty
         tensor_plan = third_global_plan[0]
@@ -270,7 +286,7 @@ class TestSavePlan(TestCase):

         # Validate that all_plans are updated and cached
         cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
-        self.assertEqual(cached_all_plans, expected_all_plans)
+        self.assertEqual(cached_all_plans, all_plans)

         for plan in third_global_plan[1:]:
             self.assertFalse(plan.usable)
diff --git a/torch/distributed/checkpoint/_dedup_save_plans.py b/torch/distributed/checkpoint/_dedup_save_plans.py
index e416030a816a..3e2cf954c409 100644
--- a/torch/distributed/checkpoint/_dedup_save_plans.py
+++ b/torch/distributed/checkpoint/_dedup_save_plans.py
@@ -19,18 +19,25 @@ def dedup_save_plans(
     """
     Removes duplicate entries appearing in multiple SavePlans. For each duplicate
     across a set of SavePlans, only the smallest SavePlan in terms of planned
     storage keeps the entry.
+
+    Please note that this function does not modify the original SavePlans, but
+    rather returns new SavePlans with the duplicate entries removed.
     """
+    # Map to query the plan indices that a write item is duplicated in
     write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
+    # Map to query the write item from its index
     write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
+    # Set of write item indices that are present in each plan.
+    # After deduplication, this will be the set of write item indices that are
+    # present in the final plans.
+    plan_to_item_indices: list[set[MetadataIndex]] = [
+        {item.index for item in plan.items} for plan in all_plans
+    ]
+
     for plan_idx, plan in enumerate(all_plans):
         for write_item in plan.items:
             # map each write item to its plan
             write_item_to_plan_indices[write_item.index].add(plan_idx)
             write_item_idx_to_write_item[write_item.index] = write_item
-
-    # put item in the plan with the smallest size and remove it from the other plan_indices
-    to_remove: list[set] = [set() for _ in range(len(all_plans))]
     plan_to_size = [0] * len(all_plans)
     for write_item_idx, plan_indices in write_item_to_plan_indices.items():
         if save_to_lowest_rank:
@@ -41,20 +48,17 @@ def dedup_save_plans(
             )

         write_item = write_item_idx_to_write_item[write_item_idx]
-        # essentially ignores the storage size of anything that is not a tensor, since
+        # Ignore the storage size of anything that is not a tensor, since
         # we don't know how much storage they represent
         plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
-
-        plan_indices.remove(select_plan_idx)
-        for plan_idx in plan_indices:
-            to_remove[plan_idx].add(write_item_idx)
-
-    for plan_idx, remove_set in enumerate(to_remove):
-        new_items = [
-            write_item
-            for write_item in all_plans[plan_idx].items
-            if write_item.index not in remove_set
-        ]
-        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
-
-    return all_plans
+        for plan_idx in plan_indices - {select_plan_idx}:
+            plan_to_item_indices[plan_idx].discard(write_item_idx)
+
+    # Sanity check
+    assert len(all_plans) == len(plan_to_item_indices)
+    # Create new plans with the updated write items post deduplication
+    return [
+        dataclasses.replace(
+            plan, items=[item for item in plan.items if item.index in item_indexes]
+        )
+        for plan, item_indexes in zip(all_plans, plan_to_item_indices)
+    ]
diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py
index ffeb5a01ec72..f1ba5e18dbb6 100644
--- a/torch/distributed/checkpoint/default_planner.py
+++ b/torch/distributed/checkpoint/default_planner.py
@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates

-import copy
 import dataclasses
 import io
 import logging
@@ -129,9 +128,9 @@ class DefaultSavePlanner(SavePlanner):
     def _create_global_plan(
         self, all_plans: list[SavePlan]
     ) -> tuple[list[SavePlan], Metadata]:
-        all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
+        deduped_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)

-        global_plan, metadata = create_default_global_save_plan(all_plans)
+        global_plan, metadata = create_default_global_save_plan(deduped_plans)

         if self.flatten_state_dict:
             # | does not work for Python 3.8 or older version.
@@ -157,10 +156,8 @@ class DefaultSavePlanner(SavePlanner):
         global_plan_delta: list[SavePlan] = []

         if self._cached_plans_key not in SavePlanner._cached_all_plans:
-            # Make a deepcopy of all_plans to avoid caching the modified plans post de-dupe
-            SavePlanner._cached_all_plans[self._cached_plans_key] = copy.deepcopy(
-                all_plans
-            )
+            # Cache all_plans directly; dedup no longer mutates them
+            SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
             global_plan, metadata = self._create_global_plan(all_plans)
             SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
             # If plans are not cached, global_plan delta will be the same as global plan.
@@ -171,10 +168,8 @@ class DefaultSavePlanner(SavePlanner):
             merged_plans = _merge_delta_local_plans(
                 SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
             )
-            # Make a deepcopy of merged_plans to avoid caching the modified plans post de-dupe
-            SavePlanner._cached_all_plans[self._cached_plans_key] = copy.deepcopy(
-                merged_plans
-            )
+            # Cache merged_plans directly; dedup no longer mutates them
+            SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
             global_plan, metadata = self._create_global_plan(merged_plans)
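For reference, here is a minimal, self-contained sketch of the non-mutating dedup technique this patch adopts. The `WriteItem` and `SavePlan` dataclasses below are simplified stand-ins for the real `torch.distributed.checkpoint` types, plain strings stand in for `MetadataIndex`, a `storage_size` field stands in for `tensor_storage_size()`, and the `save_to_lowest_rank` branch is omitted; only the shape of the algorithm mirrors the patch.

```
# Minimal sketch: dedup across plans without mutating the inputs.
import dataclasses
from collections import defaultdict


@dataclasses.dataclass(frozen=True)
class WriteItem:
    index: str  # stand-in for MetadataIndex
    storage_size: int = 1  # stand-in for tensor_storage_size()


@dataclasses.dataclass(frozen=True)
class SavePlan:
    items: tuple


def dedup_save_plans(all_plans):
    # Map each item index to the set of plans that contain it.
    item_to_plans = defaultdict(set)
    item_by_index = {}
    # Per-plan surviving item indices; starts as the full set and only
    # shrinks, so the input plans themselves are never touched.
    surviving = [{it.index for it in plan.items} for plan in all_plans]
    for plan_idx, plan in enumerate(all_plans):
        for it in plan.items:
            item_to_plans[it.index].add(plan_idx)
            item_by_index[it.index] = it

    plan_size = [0] * len(all_plans)
    for index, plan_indices in item_to_plans.items():
        # Keep each duplicate on the plan that is currently smallest.
        keep = min(plan_indices, key=lambda i: (plan_size[i], i))
        plan_size[keep] += item_by_index[index].storage_size
        for drop in plan_indices - {keep}:
            surviving[drop].discard(index)

    # Build brand-new plans instead of mutating the inputs, which is what
    # lets the caller cache the originals without an expensive deepcopy.
    return [
        dataclasses.replace(
            plan, items=tuple(it for it in plan.items if it.index in keep_set)
        )
        for plan, keep_set in zip(all_plans, surviving)
    ]


if __name__ == "__main__":
    a = SavePlan((WriteItem("x", 4), WriteItem("y", 2)))
    b = SavePlan((WriteItem("x", 4), WriteItem("z", 1)))
    deduped = dedup_save_plans([a, b])
    assert len(a.items) == 2  # the original plan is untouched
    assert sum(len(p.items) for p in deduped) == 3  # "x" is kept exactly once
```

Because the function returns fresh plans built with `dataclasses.replace` and only ever shrinks the per-plan index sets, the caller can cache the original `all_plans` as-is; that is what lets `DefaultSavePlanner` drop the `copy.deepcopy` calls and recover the global-step latency reported above.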