Fix https://github.com/pytorch/pytorch/issues/134095

This fixes a hang in the distributed state dict `full_state_dict` option during `set_state_dict`. We switch `_distribute_tensors` in `_state_dict_utils.py` to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet.

`distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to make sure the full tensor from the `full_state_dict` is the same across all ranks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135725
Approved by: https://github.com/fegin
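To illustrate the difference between the two construction paths, here is a minimal sketch (not the actual `_distribute_tensors` code) using a simple 1-D `Shard(0)` placement on CPU; the real helper also deals with 2D strided sharding and uneven shards. It assumes a recent PyTorch where `torch.distributed.tensor` is public and is meant to be launched with something like `torchrun --nproc-per-node=2 sketch.py`.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor


def main() -> None:
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = init_device_mesh("cpu", (world_size,))
    placements = [Shard(0)]

    # The full tensor must already be identical on every rank: under the
    # from_local path there is no collective to broadcast or verify it.
    full_tensor = torch.arange(4 * world_size * 4, dtype=torch.float32).reshape(4 * world_size, 4)

    # Old path: distribute_tensor scatters shards from the source rank (a collective).
    scattered = distribute_tensor(full_tensor, mesh, placements)

    # New path: each rank slices its own shard out of the full tensor and wraps
    # it directly; no communication is issued.
    local_shard = full_tensor.chunk(world_size, dim=0)[rank]
    wrapped = DTensor.from_local(local_shard, mesh, placements, run_check=False)

    # Both DTensors describe the same global tensor.
    assert torch.equal(scattered.to_local(), wrapped.to_local())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Because the `from_local` path skips the scatter, it sidesteps the strided-sharding limitation of `distribute_tensor`, at the cost of trusting the caller to provide the same full tensor on every rank.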