Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Enables load balancing duplicates in DCP (#116469)
Enables the deduplication of saved entries by load balancing duplicates across ranks. Tested with existing and modified tests. Additionally tested with the following code snippet, which saves a 20GB DDP model in **~3 seconds on 8 ranks**; before this PR, the same operation was measured at ~19 seconds.

```python
def run(local_rank, world_size, param_size, num_params, work_dir):
    os.environ["RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)

    model = Model(param_size=param_size, num_params=num_params)
    model = DistributedDataParallel(model, gradient_as_bucket_view=True)
    _patch_model_state_dict(model)

    sz = sum(t.nelement() * t.element_size() for t in model.parameters())
    rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")

    rank_0_print("Saving the model with DCP...")
    checkpointer = _FileSystemCheckpointer(
        f"{args.work_dir}/dcp",
        sync_files=False,
        single_file_per_rank=False,
        thread_count=1,
    )
    begin_ts = time.monotonic()
    checkpointer.save(state_dict={"model": model})
    end_ts = time.monotonic()
    rank_0_print(f"Took {end_ts - begin_ts} seconds with DCP")
```

Differential Revision: [D52435926](https://our.internmc.facebook.com/intern/diff/D52435926/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116469
Approved by: https://github.com/fegin, https://github.com/wz337
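The snippet above relies on internal helpers (`_patch_model_state_dict`, `_FileSystemCheckpointer`) that are not public API. Here is a rough, hedged sketch of a comparable measurement built only on public `torch.distributed.checkpoint` calls; the function name `timed_dcp_save`, the `checkpoint_dir` argument, and the setup assumptions (process group already initialized, model on the local CUDA device) are illustrative, not code from this PR:

```python
# Hedged sketch (not from the PR): time a DCP save of a DDP model's state_dict
# using only public APIs. Assumes torch.distributed is already initialized
# (e.g. via torchrun) and that `model` lives on the local CUDA device.
import time

import torch
import torch.distributed.checkpoint as dcp
from torch.nn.parallel import DistributedDataParallel as DDP


def timed_dcp_save(model: torch.nn.Module, checkpoint_dir: str) -> float:
    ddp_model = DDP(model)
    # A DDP state dict holds plain tensors that are identical on every rank.
    state_dict = {"model": ddp_model.state_dict()}

    begin = time.monotonic()
    dcp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dcp.FileSystemWriter(checkpoint_dir),
    )
    return time.monotonic() - begin
```

Because every rank proposes the same write items for such a state dict, spreading those duplicates across ranks instead of concentrating them on a single writer is what the ~19 s to ~3 s improvement above reflects.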
Committed by: PyTorch MergeBot
Parent: b95c45fbf7
Commit: ff8e33556e
@@ -58,7 +58,7 @@ The following types define the planner interface used during checkpoint:
 .. autoclass:: torch.distributed.checkpoint.SavePlan
   :members:
 
-.. autoclass:: torch.distributed.checkpoint.WriteItem
+.. autoclass:: torch.distributed.checkpoint.planner.WriteItem
   :members:
 
 We provide a filesystem based storage layer:
@@ -2,6 +2,7 @@
 
 import time
 from enum import auto, Enum
+from functools import partial
 
 import torch
 import torch.distributed as dist
@@ -23,6 +24,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
     RowwiseParallel,
 )
+from torch.nn.parallel import DistributedDataParallel
 
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -77,6 +79,7 @@ class ModelType(Enum):
     FSDP = auto()
     HSDP = auto()
     FSDP_TP = auto()
+    DDP = auto()
     NONE = auto()  # no parallelization
 
 
@@ -128,6 +131,9 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
             }
             model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
             model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
+        elif model_type == ModelType.DDP:
+            model = DistributedDataParallel(dummy_model)
+            model.get_input = partial(TestDummyModel.get_input, model)
         else:
             model = dummy_model
 
@@ -152,10 +158,10 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
     @skip_if_lt_x_gpu(4)
     @with_temp_dir
     @parametrize("compile", [True, False])
-    # TODO: Previously PariwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
+    # TODO: Previously PairwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
     # should have failed. Disabling the failed test temporarily to unblock the deprecation of PairwiseParallel.
     # @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.FSDP_TP])
-    @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP])
+    @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.DDP])
     def test_e2e(self, compile, model_type):
         self._run_e2e_test(compile, model_type)
 
@@ -433,6 +433,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
             "sharded": sharded_tensor.rand(save_spec, tensor_size),
             "replicated": torch.rand(tensor_size, device="cpu"),
         }
+        dist.broadcast(save_dict["replicated"], src=0)
 
         fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
         save_state_dict(state_dict=save_dict, storage_writer=fs_writer)
@@ -454,6 +455,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
                 torch.allclose(save_dict_sharded, load_dict["sharded"]),
                 f"save-spec {save_spec} load-spec {load_spec}",
             )
+
             self.assertTrue(
                 torch.allclose(save_dict["replicated"], load_dict_replicated),
                 f"save-spec {save_spec} load-spec {load_spec}",
@@ -13,7 +13,7 @@ from torch.distributed._shard.sharded_tensor import (
 from torch.distributed._shard.sharded_tensor.metadata import (
     TensorProperties as TensorProperties_Shard,
 )
-from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
+from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
 
 from torch.distributed.checkpoint.default_planner import (
     _create_default_local_metadata,
@@ -143,7 +143,7 @@ class TestSavePlan(TestCase):
             return create_default_local_save_plan(state_dict, rank == 0)
 
         all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
-        all_plans = dedup_tensors(all_plans)
+        all_plans = dedup_save_plans(all_plans)
         final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)
 
         # The default global plan updates all indexes to include hints
torch/distributed/checkpoint/_dedup_save_plans.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from collections import defaultdict
from typing import Dict, List, Set

from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan, WriteItem

__all__ = ["dedup_save_plans"]


def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]:
    """
    Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
    a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
    """

    write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
    write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            # map each write item to its plan
            write_item_to_plan_indices[write_item.index].add(plan_idx)
            write_item_idx_to_write_item[write_item.index] = write_item

    # put item in the plan with the smallest size and remove it from the other plan_indices
    to_remove: List[Set] = [set() for _ in range(len(all_plans))]
    plan_to_size = [0] * len(all_plans)
    for write_item_idx, plan_indices in write_item_to_plan_indices.items():
        select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_to_size[plan_idx])

        write_item = write_item_idx_to_write_item[write_item_idx]
        # essentially ignores the storage size of anything that is not a tensor, since
        # we don't know how much storage they represent
        plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1

        plan_indices.remove(select_plan_idx)
        for plan_idx in plan_indices:
            to_remove[plan_idx].add(write_item_idx)

    for plan_idx, remove_set in enumerate(to_remove):
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in remove_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
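The docstring above is terse, so here is a small self-contained sketch of the same greedy idea (assign each duplicated item to whichever plan currently has the smallest accumulated size), using plain tuples and hypothetical item names rather than DCP's `SavePlan`/`WriteItem` dataclasses:

```python
from collections import defaultdict

# Hypothetical stand-ins for SavePlans: one list of (item_key, size_bytes) per rank.
# Every rank holds the same replicated entries, so each key appears in every plan.
plans = [
    [("net1.weight", 400), ("net2.weight", 800), ("step", 1)],  # rank 0
    [("net1.weight", 400), ("net2.weight", 800), ("step", 1)],  # rank 1
    [("net1.weight", 400), ("net2.weight", 800), ("step", 1)],  # rank 2
]

# Map each item to the plans that contain it, mirroring write_item_to_plan_indices.
item_to_plans = defaultdict(set)
item_size = {}
for plan_idx, plan in enumerate(plans):
    for key, size in plan:
        item_to_plans[key].add(plan_idx)
        item_size[key] = size

# Greedy load balancing: keep each item only in the plan whose running total is smallest.
plan_size = [0] * len(plans)
keep = [set() for _ in plans]
for key, owners in item_to_plans.items():
    winner = min(owners, key=lambda idx: plan_size[idx])
    plan_size[winner] += item_size[key]
    keep[winner].add(key)

deduped = [
    [(key, size) for key, size in plan if key in keep[plan_idx]]
    for plan_idx, plan in enumerate(plans)
]
print(deduped)    # each entry now appears in exactly one plan
print(plan_size)  # e.g. [400, 800, 1] -- planned storage is spread across ranks
```

In the real implementation, items without tensor data have unknown size and are counted as one byte (`write_item.tensor_storage_size() or 1`), so they barely influence the balance.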
@@ -6,12 +6,12 @@ import logging
 import operator
 from collections import ChainMap
 from functools import reduce
-from typing import Any, cast, Dict, List, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed._shard._utils import narrow_tensor_by_index
 from torch.distributed._tensor import DTensor
-from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
+from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
 from torch.distributed.checkpoint._nested_dict import (
     FLATTEN_MAPPING,
     flatten_state_dict,
@@ -65,13 +65,19 @@ class DefaultSavePlanner(SavePlanner):
         self,
         flatten_state_dict: bool = True,
         flatten_sharded_tensors: bool = True,
-        dedup_replicated_tensors: bool = True,
+        dedup_replicated_tensors: Optional[bool] = None,
     ) -> None:
         self.flatten_state_dict = flatten_state_dict
         self.flatten_sharded_tensors = flatten_sharded_tensors
-        self.dedup_replicated_tensors = dedup_replicated_tensors
         self.mappings = {}
 
+        if dedup_replicated_tensors is not None:
+            logger.warning(
+                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
+                "deprecated, and no longer has any effect. Please remove this argument "
+                "from your call."
+            )
+
     def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
         if self.flatten_state_dict:
             state_dict, self.mappings = flatten_state_dict(state_dict)
@@ -91,8 +97,7 @@ class DefaultSavePlanner(SavePlanner):
     def create_global_plan(
         self, all_plans: List[SavePlan]
     ) -> Tuple[List[SavePlan], Metadata]:
-        if self.dedup_replicated_tensors:
-            all_plans = dedup_tensors(all_plans)
+        all_plans = dedup_save_plans(all_plans)
 
         global_plan, metadata = create_default_global_save_plan(all_plans)
 
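Taken together with the `__init__` change above, deduplication is no longer controlled by a flag: `create_global_plan` now always runs `dedup_save_plans`, and the old argument only triggers the warning. A brief illustration of what callers see (the variable names are illustrative):

```python
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

# Recommended after this change: duplicates are always load-balanced away.
planner = DefaultSavePlanner()

# Still accepted for backward compatibility, but it only logs the deprecation
# warning shown above; it no longer changes planning behavior.
legacy_planner = DefaultSavePlanner(dedup_replicated_tensors=True)
```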
@@ -2,6 +2,7 @@ import abc
 import io
 from dataclasses import dataclass
 from enum import auto, Enum
+from functools import reduce
 from typing import Any, List, Optional, Tuple, Union
 
 import torch
@@ -48,12 +49,28 @@ class TensorWriteData:
 
 @dataclass(frozen=True)
 class WriteItem:
+    """Dataclass which holds information about what needs to be written to storage."""
+
     index: MetadataIndex
     type: WriteItemType
 
     # Value present if it's a tensor write
     tensor_data: Optional[TensorWriteData] = None
 
+    def tensor_storage_size(self) -> Optional[int]:
+        """
+        Calculates the storage size of the underlying tensor, or None if this is not a tensor write.
+
+        Returns:
+            Optional[int] storage size, in bytes of underlying tensor if any.
+        """
+        if self.tensor_data is None:
+            return None
+
+        numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1)
+        dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
+        return numels * dtype_size
+
 
 @dataclass(frozen=True)
 class ReadItem:
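As a quick check of the arithmetic in `tensor_storage_size` (element count times per-element size), here is a hedged stand-alone computation for an assumed `[1024, 1024]` float32 tensor; the shape and dtype are illustrative:

```python
from functools import reduce

import torch

# Same math as WriteItem.tensor_storage_size() for an assumed [1024, 1024] float32 tensor.
size = torch.Size([1024, 1024])
numels = reduce(lambda x, y: x * y, size, 1)                      # 1_048_576 elements
dtype_size = torch.empty((), dtype=torch.float32).element_size()  # 4 bytes per element
print(numels * dtype_size)                                        # 4194304 bytes (4 MiB)
```

The method itself uses the private helper `torch._utils._element_size`; the sketch substitutes `Tensor.element_size()`, which returns the same per-element byte count.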