Enables load balancing duplicates in DCP (#116469)

Enables deduplication of saved entries by load-balancing duplicates across ranks: an entry that appears in multiple save plans is written by only one rank, chosen so that the planned write sizes stay balanced.
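
The core of the change is a greedy assignment (the new `dedup_save_plans` pass shown in the diff below): each write item that appears in more than one rank's save plan is kept only in the plan whose accumulated write size is currently smallest. A minimal standalone sketch of that idea — `assign_duplicates`, `plans`, and `sizes` are illustrative stand-ins, not the actual DCP types:

```python
from collections import defaultdict
from typing import Dict, List, Set


def assign_duplicates(plans: List[List[str]], sizes: Dict[str, int]) -> List[List[str]]:
    # Map each entry name to the set of ranks that planned to write it.
    entry_to_ranks: Dict[str, Set[int]] = defaultdict(set)
    for rank, entries in enumerate(plans):
        for name in entries:
            entry_to_ranks[name].add(rank)

    # Greedily keep each entry only on the rank whose planned bytes are smallest so far.
    rank_bytes = [0] * len(plans)
    keep: List[Set[str]] = [set() for _ in plans]
    for name, ranks in entry_to_ranks.items():
        winner = min(ranks, key=lambda r: rank_bytes[r])
        rank_bytes[winner] += sizes.get(name, 1)  # non-tensor entries count as 1, as in the PR
        keep[winner].add(name)

    return [[name for name in entries if name in keep[rank]] for rank, entries in enumerate(plans)]
```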

Tested with existing and modified tests. Additionally tested with the following code snippet, which saves a 20 GB DDP model in **~3 seconds on 8 ranks**. Before this PR, the same operation was measured at ~19 seconds.

```python
import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


# Note: Model, rank_0_print, _patch_model_state_dict, and _FileSystemCheckpointer are
# benchmark/DCP helpers assumed to be in scope.
def run(local_rank, world_size, param_size, num_params, work_dir):
    os.environ["RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)

    model = Model(param_size=param_size, num_params=num_params)
    model = DistributedDataParallel(model, gradient_as_bucket_view=True)
    _patch_model_state_dict(model)

    sz = sum(t.nelement() * t.element_size() for t in model.parameters())
    rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")
    rank_0_print("Saving the model with DCP...")

    checkpointer = _FileSystemCheckpointer(
        f"{work_dir}/dcp",
        sync_files=False,
        single_file_per_rank=False,
        thread_count=1,
    )

    begin_ts = time.monotonic()
    checkpointer.save(state_dict={"model": model})
    end_ts = time.monotonic()
    rank_0_print(f"Took {end_ts - begin_ts} seconds with DCP")
```
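
The snippet above is a per-rank entry point; one way it could be driven (not part of the PR, argument values are placeholders for illustration) is `torch.multiprocessing.spawn`:

```python
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 8
    param_size = 5_000_000
    num_params = 1_000
    work_dir = "/tmp/dcp_bench"
    # mp.spawn invokes run(local_rank, *args) in each of the nprocs processes.
    mp.spawn(run, args=(world_size, param_size, num_params, work_dir), nprocs=world_size, join=True)
```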

Differential Revision: [D52435926](https://our.internmc.facebook.com/intern/diff/D52435926/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116469
Approved by: https://github.com/fegin, https://github.com/wz337
Author: Lucas Pasqualin
Date: 2024-01-26 11:25:49 -08:00
Committed by: PyTorch MergeBot
Parent: b95c45fbf7
Commit: ff8e33556e
7 changed files with 90 additions and 11 deletions


@@ -58,7 +58,7 @@ The following types define the planner interface used during checkpoint:
.. autoclass:: torch.distributed.checkpoint.SavePlan
:members:
.. autoclass:: torch.distributed.checkpoint.WriteItem
.. autoclass:: torch.distributed.checkpoint.planner.WriteItem
:members:
We provide a filesystem based storage layer:


@@ -2,6 +2,7 @@
import time
from enum import auto, Enum
from functools import partial
import torch
import torch.distributed as dist
@@ -23,6 +24,7 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -77,6 +79,7 @@ class ModelType(Enum):
FSDP = auto()
HSDP = auto()
FSDP_TP = auto()
DDP = auto()
NONE = auto() # no parallelization
@@ -128,6 +131,9 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
}
model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
elif model_type == ModelType.DDP:
model = DistributedDataParallel(dummy_model)
model.get_input = partial(TestDummyModel.get_input, model)
else:
model = dummy_model
@@ -152,10 +158,10 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("compile", [True, False])
# TODO: Previously PariwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
# TODO: Previously PairwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
# should have failed. Disabling the failed test temporarily to unblock the deprecation of PairwiseParallel.
# @parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.FSDP_TP])
@parametrize("model_type", [ModelType.FSDP, ModelType.HSDP])
@parametrize("model_type", [ModelType.FSDP, ModelType.HSDP, ModelType.DDP])
def test_e2e(self, compile, model_type):
self._run_e2e_test(compile, model_type)


@@ -433,6 +433,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
"sharded": sharded_tensor.rand(save_spec, tensor_size),
"replicated": torch.rand(tensor_size, device="cpu"),
}
dist.broadcast(save_dict["replicated"], src=0)
fs_writer = FileSystemWriter(path=path, thread_count=thread_count)
save_state_dict(state_dict=save_dict, storage_writer=fs_writer)
@@ -454,6 +455,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
torch.allclose(save_dict_sharded, load_dict["sharded"]),
f"save-spec {save_spec} load-spec {load_spec}",
)
self.assertTrue(
torch.allclose(save_dict["replicated"], load_dict_replicated),
f"save-spec {save_spec} load-spec {load_spec}",


@@ -13,7 +13,7 @@ from torch.distributed._shard.sharded_tensor import (
from torch.distributed._shard.sharded_tensor.metadata import (
TensorProperties as TensorProperties_Shard,
)
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint.default_planner import (
_create_default_local_metadata,
@@ -143,7 +143,7 @@ class TestSavePlan(TestCase):
return create_default_local_save_plan(state_dict, rank == 0)
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
all_plans = dedup_tensors(all_plans)
all_plans = dedup_save_plans(all_plans)
final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)
# The default global plan updates all indexes to include hints


@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from collections import defaultdict
from typing import Dict, List, Set
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan, WriteItem
__all__ = ["dedup_save_plans"]
def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]:
"""
Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across
a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry.
"""
write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
for plan_idx, plan in enumerate(all_plans):
for write_item in plan.items:
# map each write item to its plan
write_item_to_plan_indices[write_item.index].add(plan_idx)
write_item_idx_to_write_item[write_item.index] = write_item
# put item in the plan with the smallest size and remove it from the other plan_indices
to_remove: List[Set] = [set() for _ in range(len(all_plans))]
plan_to_size = [0] * len(all_plans)
for write_item_idx, plan_indices in write_item_to_plan_indices.items():
select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_to_size[plan_idx])
write_item = write_item_idx_to_write_item[write_item_idx]
# essentially ignores the storage size of anything that is not a tensor, since
# we don't know how much storage they represent
plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
plan_indices.remove(select_plan_idx)
for plan_idx in plan_indices:
to_remove[plan_idx].add(write_item_idx)
for plan_idx, remove_set in enumerate(to_remove):
new_items = [
write_item
for write_item in all_plans[plan_idx].items
if write_item.index not in remove_set
]
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans


@@ -6,12 +6,12 @@ import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
FLATTEN_MAPPING,
flatten_state_dict,
@@ -65,13 +65,19 @@ class DefaultSavePlanner(SavePlanner):
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
dedup_replicated_tensors: bool = True,
dedup_replicated_tensors: Optional[bool] = None,
) -> None:
self.flatten_state_dict = flatten_state_dict
self.flatten_sharded_tensors = flatten_sharded_tensors
self.dedup_replicated_tensors = dedup_replicated_tensors
self.mappings = {}
if dedup_replicated_tensors is not None:
logger.warning(
"DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
"deprecated, and no longer has any effect. Please remove this argument "
"from your call."
)
def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
@@ -91,8 +97,7 @@
def create_global_plan(
self, all_plans: List[SavePlan]
) -> Tuple[List[SavePlan], Metadata]:
if self.dedup_replicated_tensors:
all_plans = dedup_tensors(all_plans)
all_plans = dedup_save_plans(all_plans)
global_plan, metadata = create_default_global_save_plan(all_plans)


@@ -2,6 +2,7 @@ import abc
import io
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union
import torch
@@ -48,12 +49,28 @@ class TensorWriteData:
@dataclass(frozen=True)
class WriteItem:
"""Dataclass which holds information about what needs to be written to storage."""
index: MetadataIndex
type: WriteItemType
# Value present if it's a tensor write
tensor_data: Optional[TensorWriteData] = None
def tensor_storage_size(self) -> Optional[int]:
"""
Calculates the storage size of the underlying tensor, or None if this is not a tensor write.
Returns:
Optional[int] storage size, in bytes of underlying tensor if any.
"""
if self.tensor_data is None:
return None
numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1)
dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
return numels * dtype_size
@dataclass(frozen=True)
class ReadItem: