[DCP] Avoid in-place update and deepcopy during dedupe (#149320)

Summary:
Avoid in-place update and deepcopy during dedupe. Deepcopy becomes prohibitively expensive with models having a huge number of FQNs. This was manifested in the Ads 2K experiment as well. Here are the results from the TextRay model in Mitra:

#### Control job with deepcopy regression:
First save ~24.8s
Global step latency is ~7-8s

#### Test job with the new fix to avoid deepcopy:
First save is ~21s
Global step latency is ~2s

Test Plan:
```
buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/distributed/checkpoint:test_planner
```
https://www.internalfb.com/intern/testinfra/testrun/3940649945104822

Differential Revision: D71245218

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149320
Approved by: https://github.com/MeetVadakkanchery
This commit is contained in:
Saurabh Mishra
2025-03-18 16:08:37 +00:00
committed by PyTorch MergeBot
parent c41196a4d0
commit 381d0cb239
3 changed files with 51 additions and 36 deletions

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
import copy
import sys import sys
import torch import torch
@ -201,6 +200,26 @@ class TestSavePlan(TestCase):
item_md.chunks[new_item.index.index], old_item.tensor_data.chunk item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
) )
def test_dedup_plans(self):
def create_data(rank):
with with_dist(rank=rank, world_size=4):
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
state_dict = {"tensor": tensor, "value": val, "st": st}
return create_default_local_save_plan(state_dict, rank == 0)
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
deduped_plans = dedup_save_plans(all_plans)
# Number of plans should remain unchanged
self.assertEqual(len(all_plans), len(deduped_plans))
# Number of items in the deduped plans should be less than the original plans
for new_plan, old_plan in zip(deduped_plans, all_plans):
self.assertFalse(_compare_save_plans(new_plan, old_plan))
self.assertTrue(len(new_plan.items) < len(old_plan.items))
def test_global_plan_with_caching(self): def test_global_plan_with_caching(self):
def create_data(rank): def create_data(rank):
with with_dist(rank=rank, world_size=4): with with_dist(rank=rank, world_size=4):
@ -213,7 +232,6 @@ class TestSavePlan(TestCase):
return planner.create_local_plan() return planner.create_local_plan()
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)] all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
expected_all_plans = copy.deepcopy(all_plans)
planner = DefaultSavePlanner(enable_plan_caching=True) planner = DefaultSavePlanner(enable_plan_caching=True)
# First iteration, should create a new plan # First iteration, should create a new plan
first_global_plan, first_metadata = planner.create_global_plan(all_plans) first_global_plan, first_metadata = planner.create_global_plan(all_plans)
@ -224,12 +242,11 @@ class TestSavePlan(TestCase):
# Validate that all_plans are cached # Validate that all_plans are cached
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key] cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
self.assertEqual(cached_all_plans, expected_all_plans) self.assertEqual(cached_all_plans, all_plans)
# Second iteration, should return empty plans # Second iteration, should return empty plans
# Recreate the plans as the previous ones are deduped. # Recreate the plans as the previous ones are deduped.
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)] all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
expected_all_plans = copy.deepcopy(all_plans)
second_global_plan, second_metadata = planner.create_global_plan(all_plans) second_global_plan, second_metadata = planner.create_global_plan(all_plans)
# All the plans should be empty and usable # All the plans should be empty and usable
for plan in second_global_plan: for plan in second_global_plan:
@ -242,7 +259,7 @@ class TestSavePlan(TestCase):
# Validate that all_plans are cached and remain unchanged. # Validate that all_plans are cached and remain unchanged.
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key] cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
self.assertEqual(cached_all_plans, expected_all_plans) self.assertEqual(cached_all_plans, all_plans)
# Third iteration with changed plans # Third iteration with changed plans
def create_data_v2(rank): def create_data_v2(rank):
@ -261,7 +278,6 @@ class TestSavePlan(TestCase):
create_data_v2(2), create_data_v2(2),
create_data_v2(3), create_data_v2(3),
] ]
expected_all_plans = copy.deepcopy(all_plans)
third_global_plan, third_metadata = planner.create_global_plan(all_plans) third_global_plan, third_metadata = planner.create_global_plan(all_plans)
# Only the rank 0 plan should be non-empty. The rest should be empty # Only the rank 0 plan should be non-empty. The rest should be empty
tensor_plan = third_global_plan[0] tensor_plan = third_global_plan[0]
@ -270,7 +286,7 @@ class TestSavePlan(TestCase):
# Validate that all_plans are updated and cached # Validate that all_plans are updated and cached
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key] cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
self.assertEqual(cached_all_plans, expected_all_plans) self.assertEqual(cached_all_plans, all_plans)
for plan in third_global_plan[1:]: for plan in third_global_plan[1:]:
self.assertFalse(plan.usable) self.assertFalse(plan.usable)

View File

@ -19,18 +19,25 @@ def dedup_save_plans(
""" """
Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
Please note that this function does not modify the original SavePlans, but rather returns new SavePlans with the duplicate entries removed.
""" """
# Map to query the plan indices that a write item is duplicated in
write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set) write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
# Map to query the write item from its index
write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {} write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
# Set of write item indices that are present in each plan
# After deduplication, this will be the set of write item indices that are present in the final plans
plan_to_item_indices: list[set[MetadataIndex]] = [
{item.index for item in plan.items} for plan in all_plans
]
for plan_idx, plan in enumerate(all_plans): for plan_idx, plan in enumerate(all_plans):
for write_item in plan.items: for write_item in plan.items:
# map each write item to its plan # map each write item to its plan
write_item_to_plan_indices[write_item.index].add(plan_idx) write_item_to_plan_indices[write_item.index].add(plan_idx)
write_item_idx_to_write_item[write_item.index] = write_item write_item_idx_to_write_item[write_item.index] = write_item
# put item in the plan with the smallest size and remove it from the other plan_indices
to_remove: list[set] = [set() for _ in range(len(all_plans))]
plan_to_size = [0] * len(all_plans) plan_to_size = [0] * len(all_plans)
for write_item_idx, plan_indices in write_item_to_plan_indices.items(): for write_item_idx, plan_indices in write_item_to_plan_indices.items():
if save_to_lowest_rank: if save_to_lowest_rank:
@ -41,20 +48,17 @@ def dedup_save_plans(
) )
write_item = write_item_idx_to_write_item[write_item_idx] write_item = write_item_idx_to_write_item[write_item_idx]
# essentially ignores the storage size of anything that is not a tensor, since # Ignore the storage size of anything that is not a tensor, since
# we don't know how much storage they represent # we don't know how much storage they represent
plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1 plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
for plan_idx in plan_indices - {select_plan_idx}:
plan_indices.remove(select_plan_idx) plan_to_item_indices[plan_idx].discard(write_item_idx)
for plan_idx in plan_indices: # Sanity check
to_remove[plan_idx].add(write_item_idx) assert len(all_plans) == len(plan_to_item_indices)
# Create new plans with the updated write items post deduplication
for plan_idx, remove_set in enumerate(to_remove): return [
new_items = [ dataclasses.replace(
write_item plan, items=[item for item in plan.items if item.index in item_indexes]
for write_item in all_plans[plan_idx].items )
if write_item.index not in remove_set for plan, item_indexes in zip(all_plans, plan_to_item_indices)
] ]
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import dataclasses import dataclasses
import io import io
import logging import logging
@ -129,9 +128,9 @@ class DefaultSavePlanner(SavePlanner):
def _create_global_plan( def _create_global_plan(
self, all_plans: list[SavePlan] self, all_plans: list[SavePlan]
) -> tuple[list[SavePlan], Metadata]: ) -> tuple[list[SavePlan], Metadata]:
all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) deduped_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
global_plan, metadata = create_default_global_save_plan(all_plans) global_plan, metadata = create_default_global_save_plan(deduped_plans)
if self.flatten_state_dict: if self.flatten_state_dict:
# | does not work for Python 3.8 or older version. # | does not work for Python 3.8 or older version.
@ -157,10 +156,8 @@ class DefaultSavePlanner(SavePlanner):
global_plan_delta: list[SavePlan] = [] global_plan_delta: list[SavePlan] = []
if self._cached_plans_key not in SavePlanner._cached_all_plans: if self._cached_plans_key not in SavePlanner._cached_all_plans:
# Make a deepcopy of all_plans to avoid caching the modified plans post de-dupe # Cache the all_plans
SavePlanner._cached_all_plans[self._cached_plans_key] = copy.deepcopy( SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
all_plans
)
global_plan, metadata = self._create_global_plan(all_plans) global_plan, metadata = self._create_global_plan(all_plans)
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
# If plans are not cached, global_plan delta will be the same as global plan. # If plans are not cached, global_plan delta will be the same as global plan.
@ -171,10 +168,8 @@ class DefaultSavePlanner(SavePlanner):
merged_plans = _merge_delta_local_plans( merged_plans = _merge_delta_local_plans(
SavePlanner._cached_all_plans[self._cached_plans_key], all_plans SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
) )
# Make a deepcopy of merged_plans to avoid caching the modified plans post de-dupe # Cache the merged_plans
SavePlanner._cached_all_plans[self._cached_plans_key] = copy.deepcopy( SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
merged_plans
)
global_plan, metadata = self._create_global_plan(merged_plans) global_plan, metadata = self._create_global_plan(merged_plans)