mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[DCP] Avoid in-place update and deepcopy during dudpe (#149320)
Summary: Avoid in-place update and deepcopy during dudpe. Deepcopy becomes prohibitively expensive with models having a huge number of FQNs. This was manifestd in the Ads 2K experiment as well. Here are the results from the TextRay model in Mitra: #### Control job with deepcopy regression: First save ~24.8s Global step latency is ~7-8s Test job with the new fix to avoid deepcopy: First save is ~21s global step latency ~2s Test Plan: ``` buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/distributed/checkpoint:test_planner ``` https://www.internalfb.com/intern/testinfra/testrun/3940649945104822 Differential Revision: D71245218 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149320 Approved by: https://github.com/MeetVadakkanchery
This commit is contained in:
committed by
PyTorch MergeBot
parent
c41196a4d0
commit
381d0cb239
@ -1,6 +1,5 @@
|
|||||||
# Owner(s): ["oncall: distributed"]
|
# Owner(s): ["oncall: distributed"]
|
||||||
|
|
||||||
import copy
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -201,6 +200,26 @@ class TestSavePlan(TestCase):
|
|||||||
item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
|
item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_dedup_plans(self):
|
||||||
|
def create_data(rank):
|
||||||
|
with with_dist(rank=rank, world_size=4):
|
||||||
|
tensor = torch.rand(10)
|
||||||
|
val = [1, 2, 3]
|
||||||
|
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
|
||||||
|
state_dict = {"tensor": tensor, "value": val, "st": st}
|
||||||
|
return create_default_local_save_plan(state_dict, rank == 0)
|
||||||
|
|
||||||
|
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
||||||
|
deduped_plans = dedup_save_plans(all_plans)
|
||||||
|
|
||||||
|
# Number of plans should remain unchanged
|
||||||
|
self.assertEqual(len(all_plans), len(deduped_plans))
|
||||||
|
|
||||||
|
# Numer of items in the deduped plans should be less than the original plans
|
||||||
|
for new_plan, old_plan in zip(deduped_plans, all_plans):
|
||||||
|
self.assertFalse(_compare_save_plans(new_plan, old_plan))
|
||||||
|
self.assertTrue(len(new_plan.items) < len(old_plan.items))
|
||||||
|
|
||||||
def test_global_plan_with_caching(self):
|
def test_global_plan_with_caching(self):
|
||||||
def create_data(rank):
|
def create_data(rank):
|
||||||
with with_dist(rank=rank, world_size=4):
|
with with_dist(rank=rank, world_size=4):
|
||||||
@ -213,7 +232,6 @@ class TestSavePlan(TestCase):
|
|||||||
return planner.create_local_plan()
|
return planner.create_local_plan()
|
||||||
|
|
||||||
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
||||||
expected_all_plans = copy.deepcopy(all_plans)
|
|
||||||
planner = DefaultSavePlanner(enable_plan_caching=True)
|
planner = DefaultSavePlanner(enable_plan_caching=True)
|
||||||
# First iteration, should create a new plan
|
# First iteration, should create a new plan
|
||||||
first_global_plan, first_metadata = planner.create_global_plan(all_plans)
|
first_global_plan, first_metadata = planner.create_global_plan(all_plans)
|
||||||
@ -224,12 +242,11 @@ class TestSavePlan(TestCase):
|
|||||||
|
|
||||||
# Validate that all_plans are cached
|
# Validate that all_plans are cached
|
||||||
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
|
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
|
||||||
self.assertEqual(cached_all_plans, expected_all_plans)
|
self.assertEqual(cached_all_plans, all_plans)
|
||||||
|
|
||||||
# Second iteration, should return empty plans
|
# Second iteration, should return empty plans
|
||||||
# Recreate the plans as the previous ones are deduped.
|
# Recreate the plans as the previous ones are deduped.
|
||||||
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
||||||
expected_all_plans = copy.deepcopy(all_plans)
|
|
||||||
second_global_plan, second_metadata = planner.create_global_plan(all_plans)
|
second_global_plan, second_metadata = planner.create_global_plan(all_plans)
|
||||||
# All the plans should be empty and usable
|
# All the plans should be empty and usable
|
||||||
for plan in second_global_plan:
|
for plan in second_global_plan:
|
||||||
@ -242,7 +259,7 @@ class TestSavePlan(TestCase):
|
|||||||
|
|
||||||
# Validate that all_plans are cached and remain unchanged.
|
# Validate that all_plans are cached and remain unchanged.
|
||||||
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
|
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
|
||||||
self.assertEqual(cached_all_plans, expected_all_plans)
|
self.assertEqual(cached_all_plans, all_plans)
|
||||||
|
|
||||||
# Third iteration with changed plans
|
# Third iteration with changed plans
|
||||||
def create_data_v2(rank):
|
def create_data_v2(rank):
|
||||||
@ -261,7 +278,6 @@ class TestSavePlan(TestCase):
|
|||||||
create_data_v2(2),
|
create_data_v2(2),
|
||||||
create_data_v2(3),
|
create_data_v2(3),
|
||||||
]
|
]
|
||||||
expected_all_plans = copy.deepcopy(all_plans)
|
|
||||||
third_global_plan, third_metadata = planner.create_global_plan(all_plans)
|
third_global_plan, third_metadata = planner.create_global_plan(all_plans)
|
||||||
# Only the rank 0 plan should be non-empty. The rest should be empty
|
# Only the rank 0 plan should be non-empty. The rest should be empty
|
||||||
tensor_plan = third_global_plan[0]
|
tensor_plan = third_global_plan[0]
|
||||||
@ -270,7 +286,7 @@ class TestSavePlan(TestCase):
|
|||||||
|
|
||||||
# Validate that all_plans are updated and cached
|
# Validate that all_plans are updated and cached
|
||||||
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
|
cached_all_plans = SavePlanner._cached_all_plans[planner._cached_plans_key]
|
||||||
self.assertEqual(cached_all_plans, expected_all_plans)
|
self.assertEqual(cached_all_plans, all_plans)
|
||||||
|
|
||||||
for plan in third_global_plan[1:]:
|
for plan in third_global_plan[1:]:
|
||||||
self.assertFalse(plan.usable)
|
self.assertFalse(plan.usable)
|
||||||
|
@ -19,18 +19,25 @@ def dedup_save_plans(
|
|||||||
"""
|
"""
|
||||||
Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
|
Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
|
||||||
a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
|
a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
|
||||||
|
|
||||||
|
Please note that this function does not modify the original SavePlans, but rather returns
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Map to query the plan indices that a write item is duplicated in
|
||||||
write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
|
write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set)
|
||||||
|
# Map to query the write item from its index
|
||||||
write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
|
write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {}
|
||||||
|
# Set of write item indices that are present in each plan
|
||||||
|
# After deduplication, this will be the set of write item indices that are present in the final plans
|
||||||
|
plan_to_item_indices: list[set[MetadataIndex]] = [
|
||||||
|
{item.index for item in plan.items} for plan in all_plans
|
||||||
|
]
|
||||||
|
|
||||||
for plan_idx, plan in enumerate(all_plans):
|
for plan_idx, plan in enumerate(all_plans):
|
||||||
for write_item in plan.items:
|
for write_item in plan.items:
|
||||||
# map each write item to its plan
|
# map each write item to its plan
|
||||||
write_item_to_plan_indices[write_item.index].add(plan_idx)
|
write_item_to_plan_indices[write_item.index].add(plan_idx)
|
||||||
write_item_idx_to_write_item[write_item.index] = write_item
|
write_item_idx_to_write_item[write_item.index] = write_item
|
||||||
|
|
||||||
# put item in the plan with the smallest size and remove it from the other plan_indices
|
|
||||||
to_remove: list[set] = [set() for _ in range(len(all_plans))]
|
|
||||||
plan_to_size = [0] * len(all_plans)
|
plan_to_size = [0] * len(all_plans)
|
||||||
for write_item_idx, plan_indices in write_item_to_plan_indices.items():
|
for write_item_idx, plan_indices in write_item_to_plan_indices.items():
|
||||||
if save_to_lowest_rank:
|
if save_to_lowest_rank:
|
||||||
@ -41,20 +48,17 @@ def dedup_save_plans(
|
|||||||
)
|
)
|
||||||
|
|
||||||
write_item = write_item_idx_to_write_item[write_item_idx]
|
write_item = write_item_idx_to_write_item[write_item_idx]
|
||||||
# essentially ignores the storage size of anything that is not a tensor, since
|
# Ignore the storage size of anything that is not a tensor, since
|
||||||
# we don't know how much storage they represent
|
# we don't know how much storage they represent
|
||||||
plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
|
plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
|
||||||
|
for plan_idx in plan_indices - {select_plan_idx}:
|
||||||
plan_indices.remove(select_plan_idx)
|
plan_to_item_indices[plan_idx].discard(write_item_idx)
|
||||||
for plan_idx in plan_indices:
|
# Sanity check
|
||||||
to_remove[plan_idx].add(write_item_idx)
|
assert len(all_plans) == len(plan_to_item_indices)
|
||||||
|
# Create new plans with the updated write items post deduplication
|
||||||
for plan_idx, remove_set in enumerate(to_remove):
|
return [
|
||||||
new_items = [
|
dataclasses.replace(
|
||||||
write_item
|
plan, items=[item for item in plan.items if item.index in item_indexes]
|
||||||
for write_item in all_plans[plan_idx].items
|
)
|
||||||
if write_item.index not in remove_set
|
for plan, item_indexes in zip(all_plans, plan_to_item_indices)
|
||||||
]
|
]
|
||||||
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
|
|
||||||
|
|
||||||
return all_plans
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||||
|
|
||||||
import copy
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
@ -129,9 +128,9 @@ class DefaultSavePlanner(SavePlanner):
|
|||||||
def _create_global_plan(
|
def _create_global_plan(
|
||||||
self, all_plans: list[SavePlan]
|
self, all_plans: list[SavePlan]
|
||||||
) -> tuple[list[SavePlan], Metadata]:
|
) -> tuple[list[SavePlan], Metadata]:
|
||||||
all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
|
deduped_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
|
||||||
|
|
||||||
global_plan, metadata = create_default_global_save_plan(all_plans)
|
global_plan, metadata = create_default_global_save_plan(deduped_plans)
|
||||||
|
|
||||||
if self.flatten_state_dict:
|
if self.flatten_state_dict:
|
||||||
# | does not work for Python 3.8 or older version.
|
# | does not work for Python 3.8 or older version.
|
||||||
@ -157,10 +156,8 @@ class DefaultSavePlanner(SavePlanner):
|
|||||||
global_plan_delta: list[SavePlan] = []
|
global_plan_delta: list[SavePlan] = []
|
||||||
|
|
||||||
if self._cached_plans_key not in SavePlanner._cached_all_plans:
|
if self._cached_plans_key not in SavePlanner._cached_all_plans:
|
||||||
# Make a deepcopy of all_plans to avoid caching the modified plans post de-dupe
|
# Cache the all_plans
|
||||||
SavePlanner._cached_all_plans[self._cached_plans_key] = copy.deepcopy(
|
SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
|
||||||
all_plans
|
|
||||||
)
|
|
||||||
global_plan, metadata = self._create_global_plan(all_plans)
|
global_plan, metadata = self._create_global_plan(all_plans)
|
||||||
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
|
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
|
||||||
# If plans are not cached, global_plan delta will be the same as global plan.
|
# If plans are not cached, global_plan delta will be the same as global plan.
|
||||||
@ -171,10 +168,8 @@ class DefaultSavePlanner(SavePlanner):
|
|||||||
merged_plans = _merge_delta_local_plans(
|
merged_plans = _merge_delta_local_plans(
|
||||||
SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
|
SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
|
||||||
)
|
)
|
||||||
# Make a deepcopy of merged_plans to avoid caching the modified plans post de-dupe
|
# Cache the merged_plans
|
||||||
SavePlanner._cached_all_plans[self._cached_plans_key] = copy.deepcopy(
|
SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
|
||||||
merged_plans
|
|
||||||
)
|
|
||||||
|
|
||||||
global_plan, metadata = self._create_global_plan(merged_plans)
|
global_plan, metadata = self._create_global_plan(merged_plans)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user