mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dcp] add new checkpoint staging to preserve storage sharing and support mutable state_dicts (#155192)
Summary: This implements staging in way that doesnt mess up checkpointing semantics. We want to be close to torch.save/load semantics and when async checkpointing is used it messes up shared storages, doesnt handle custom objects or tensors well. EG: users passes a state_dict with a cuda tensor in datatype. this is deepcloned causing the staging tensor to be created on GPU. This can cause ooms is hard to debug. This diffs hooks into deepcopy of storages to move them to cpu using the cached storages created for async checkpoint staging. This allows reusing storages created for staging to avoid recreating them on each checkpoint while also being flexible enough to handle any changes - clean up old storages or create new ones as needed. Lifetime of staging storages is tied to the original storage object. when the original storage object is gc-ed, we delete the corresponding staging storage from cache possibly causing it to gc-ed is there are no other references. I am using data_ptr of the storage to keep track of this. Please share thoughts on this. The alternative is to use fqn's instead of storage_id and verify the underlying storage object has same shape/size,etc to make the caching logic work. Current implementation is much simpler and cleaner. The API: ``` # construct a stager once per job in checkpointing. stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory) # do this on every checkpoint: with staging_context(stager): cpu_state_dict = copy.deepcopy(state_dict) ``` Also, adds support for pinned-memory. One problem this implementation does not address is that we lose the original device. The only alternatives here are - pickle synchronously like torch.save but with special handling for storages. It is valuable to keep state_dict throughout the checkpointing process. so users can manipulate and debug as needed. so we need to unpickle in the background process. I think this is flexible, not performant and not very different to current solution but needs more code. One idea if we really want to address is this to stick the original device in a some variable on storage and then use it recover on load side. I think we do not need this for now and can be explicit about losing device type for async checkpointing. Update: Note: Due to reservations on hooking into deepcopy to customize it, the PR is now updated to use deepcopy like logic to clone the state_dict. There are some caveats to this solution: 1. Duplicated deepcopy code to hook into for tensors. There is a risk of this code getting outdated with python version changes. This is needed to handle several different types like NamedTuples, frozen dataclasses, nested dataclasses. deepcopy logic is relying on reduce_ex to get a function with which these can be constructed. 2. Since we are bypassing deepcopy and adding custom logic to clone a tensor, we are missing some of the functionality that exists in deepcopy for torch.Tensor like _clear_non_serializable_cached_data(), or other logic. Would like thoughts on which logic or if everything should be copied? 3. If any object implemented deepcopy , we will not be able to handle any tensors in the attrs with this logic because they likely just call copy.deepcopy on the attrs instead of this deepcopy logic. We are taking care of subclasses of torch.Tensor to workaround this. The new API: ``` # construct a stager once per job in checkpointing. stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory) # do this on every checkpoint: cpu_state_dict = copy.stage(state_dict) ``` Test Plan: unit tests Differential Revision: D75993324 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155192 Approved by: https://github.com/mikaylagawarecki, https://github.com/pradeepfn
This commit is contained in:
committed by
PyTorch MergeBot
parent
d4ad280429
commit
19ffdf4ea0
821
test/distributed/checkpoint/test_state_dict_stager.py
Normal file
821
test/distributed/checkpoint/test_state_dict_stager.py
Normal file
@ -0,0 +1,821 @@
|
|||||||
|
# Owner(s): ["oncall: distributed"]
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed._tensor import DTensor
|
||||||
|
from torch.distributed._tensor.placement_types import Shard
|
||||||
|
from torch.distributed.checkpoint._state_dict_stager import StateDictStager
|
||||||
|
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||||
|
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
|
||||||
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||||
|
DTensorTestBase,
|
||||||
|
with_comms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_cpu_state_dict(state_dict):
|
||||||
|
cpu_state_dict = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
cpu_state_dict[key] = value.cpu()
|
||||||
|
return cpu_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def compare_state_dicts(cuda_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
|
||||||
|
"""
|
||||||
|
Compare if two state dictionaries (one on CUDA, one on CPU) are otherwise the same.
|
||||||
|
|
||||||
|
This function checks if the tensors in both state dictionaries have the same values,
|
||||||
|
shapes, dtypes, etc., ignoring the device difference. It also checks if tensors that
|
||||||
|
share storage in one state dict also share storage in the other.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cuda_state_dict: The state dictionary with tensors on CUDA
|
||||||
|
cpu_state_dict: The state dictionary with tensors on CPU
|
||||||
|
rtol: Relative tolerance for comparing tensor values
|
||||||
|
atol: Absolute tolerance for comparing tensor values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the state dictionaries are equivalent, False otherwise
|
||||||
|
str: Error message if the state dictionaries are not equivalent, empty string otherwise
|
||||||
|
"""
|
||||||
|
# Track storage data pointers to check storage sharing
|
||||||
|
cuda_storage_ptrs = {}
|
||||||
|
cpu_storage_ptrs = {}
|
||||||
|
|
||||||
|
def compare_objects(cuda_obj, cpu_obj, path=""):
|
||||||
|
# If objects are tensors, compare them
|
||||||
|
if isinstance(cuda_obj, torch.Tensor) and isinstance(cpu_obj, torch.Tensor):
|
||||||
|
# Check if devices are as expected
|
||||||
|
if cuda_obj.device.type != "cuda":
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Expected CUDA tensor, got {cuda_obj.device.type} tensor at {path}",
|
||||||
|
)
|
||||||
|
if cpu_obj.device.type != "cpu":
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Expected CPU tensor, got {cpu_obj.device.type} tensor at {path}",
|
||||||
|
)
|
||||||
|
if cuda_obj.storage_offset() != cpu_obj.storage_offset():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Storage offset mismatch at {path}: {cuda_obj.storage_offset()} vs {cpu_obj.storage_offset()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not torch.equal(cuda_obj.cpu(), cpu_obj):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Tensors are not same at {path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track storage sharing
|
||||||
|
cuda_storage_ptr = cuda_obj.storage().data_ptr()
|
||||||
|
cpu_storage_ptr = cpu_obj.storage().data_ptr()
|
||||||
|
|
||||||
|
if cuda_storage_ptr in cuda_storage_ptrs:
|
||||||
|
# This CUDA tensor shares storage with another tensor
|
||||||
|
# Check if the corresponding CPU tensors also share storage
|
||||||
|
if cpu_storage_ptr != cuda_storage_ptrs[cuda_storage_ptr]:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Storage sharing mismatch: CUDA tensors share storage but CPU tensors don't at {path}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# First time seeing this storage
|
||||||
|
cuda_storage_ptrs[cuda_storage_ptr] = cpu_storage_ptr
|
||||||
|
cpu_storage_ptrs[cpu_storage_ptr] = cuda_storage_ptr
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# If objects are dictionaries, compare them recursively
|
||||||
|
elif isinstance(cuda_obj, dict) and isinstance(cpu_obj, dict):
|
||||||
|
if cuda_obj.keys() != cpu_obj.keys():
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Dictionary keys mismatch at {path}: {cuda_obj.keys()} vs {cpu_obj.keys()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in cuda_obj:
|
||||||
|
result, error = compare_objects(
|
||||||
|
cuda_obj[key], cpu_obj[key], f"{path}.{key}" if path else key
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# If objects are lists, tuples, or sets, compare them recursively
|
||||||
|
elif isinstance(cuda_obj, (list, tuple, set)) and isinstance(
|
||||||
|
cpu_obj, (list, tuple, set)
|
||||||
|
):
|
||||||
|
if len(cuda_obj) != len(cpu_obj):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Collection length mismatch at {path}: {len(cuda_obj)} vs {len(cpu_obj)}",
|
||||||
|
)
|
||||||
|
if type(cuda_obj) != type(cpu_obj):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Collection type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, (cuda_item, cpu_item) in enumerate(zip(cuda_obj, cpu_obj)):
|
||||||
|
result, error = compare_objects(cuda_item, cpu_item, f"{path}[{i}]")
|
||||||
|
if not result:
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# If objects are custom classes, compare their attributes
|
||||||
|
elif hasattr(cuda_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
|
||||||
|
if type(cuda_obj) != type(cpu_obj):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Object type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
result, error = compare_objects(
|
||||||
|
cuda_obj.__dict__, cpu_obj.__dict__, f"{path}.__dict__"
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# For other types, use direct equality comparison
|
||||||
|
else:
|
||||||
|
if type(cuda_obj) != type(cpu_obj):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
|
||||||
|
)
|
||||||
|
if cuda_obj != cpu_obj:
|
||||||
|
return False, f"Value mismatch at {path}: {cuda_obj} vs {cpu_obj}"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# Start the recursive comparison
|
||||||
|
result, error = compare_objects(cuda_state_dict, cpu_state_dict)
|
||||||
|
return result, error
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TestStruct:
|
||||||
|
tensor1: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class NestedTensorStruct:
|
||||||
|
tensor: torch.Tensor
|
||||||
|
value: int = 42
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ComplexDataClass:
|
||||||
|
tensor: torch.Tensor
|
||||||
|
name: str
|
||||||
|
values: list[float]
|
||||||
|
nested: NestedTensorStruct
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class FrozenDataClass:
|
||||||
|
tensor: torch.Tensor
|
||||||
|
value: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestStateDictStager(TestCase):
|
||||||
|
@requires_cuda
|
||||||
|
def test_views(self):
|
||||||
|
test_configs = [
|
||||||
|
(False, False), # pin_memory=False, share_memory=False,
|
||||||
|
(True, False), # pin_memory=True, share_memory=False
|
||||||
|
(False, True), # pin_memory=False, share_memory=True
|
||||||
|
(True, True), # pin_memory=True, share_memory=True
|
||||||
|
]
|
||||||
|
for pin_memory, share_memory in test_configs:
|
||||||
|
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||||
|
tensor1 = torch.randn(4, 4).cuda()
|
||||||
|
tensor2 = tensor1.view(16)
|
||||||
|
tensor3 = torch.randn(4, 4).cuda()
|
||||||
|
state_dict = {
|
||||||
|
"tensor1": tensor1,
|
||||||
|
"tensor2": tensor2,
|
||||||
|
"recursive": {
|
||||||
|
"tensor3": tensor3,
|
||||||
|
"type": TestStruct(tensor1=tensor3.narrow(0, 0, 2)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
state_dict["tensor1"].storage().data_ptr()
|
||||||
|
== state_dict["tensor2"].storage().data_ptr()
|
||||||
|
)
|
||||||
|
|
||||||
|
stager = StateDictStager(
|
||||||
|
pin_memory=pin_memory, share_memory=share_memory
|
||||||
|
)
|
||||||
|
|
||||||
|
cpu_state_dict = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Calculate stats
|
||||||
|
num_storages = len(stager._cached_storage_mapping)
|
||||||
|
num_bytes = sum(
|
||||||
|
storage.nbytes()
|
||||||
|
for storage in stager._cached_storage_mapping.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate tensor count and bytes
|
||||||
|
expected_storage_cnt = 2
|
||||||
|
assert (
|
||||||
|
num_storages == expected_storage_cnt
|
||||||
|
), f"Expected {expected_storage_cnt} storages, got {num_storages}"
|
||||||
|
|
||||||
|
# Calculate expected bytes
|
||||||
|
# Note: Only unique storages are counted in the byte count
|
||||||
|
expected_bytes = (
|
||||||
|
tensor1.numel() * tensor1.element_size()
|
||||||
|
+ tensor3.numel() # tensor1 and tensor2 share storage
|
||||||
|
* tensor3.element_size() # tensor3 and its narrow view share storage
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
num_bytes == expected_bytes
|
||||||
|
), f"Expected {expected_bytes} bytes, got {num_bytes}"
|
||||||
|
# Verify that the CPU state dict is equivalent to the original CUDA state dict
|
||||||
|
result, error = compare_state_dicts(state_dict, cpu_state_dict)
|
||||||
|
assert result, f"State dicts are not equivalent: {error}"
|
||||||
|
|
||||||
|
# Additional checks for storage sharing
|
||||||
|
assert cpu_state_dict["tensor1"].device == torch.device("cpu")
|
||||||
|
assert cpu_state_dict["tensor2"].device == torch.device("cpu")
|
||||||
|
assert (
|
||||||
|
cpu_state_dict["tensor1"].storage().data_ptr()
|
||||||
|
== cpu_state_dict["tensor2"].storage().data_ptr()
|
||||||
|
)
|
||||||
|
|
||||||
|
recursive = cpu_state_dict["recursive"]
|
||||||
|
assert recursive["tensor3"].device == torch.device("cpu")
|
||||||
|
assert recursive["type"].tensor1.device == torch.device("cpu")
|
||||||
|
assert (
|
||||||
|
recursive["tensor3"].storage().data_ptr()
|
||||||
|
== recursive["type"].tensor1.storage().data_ptr()
|
||||||
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_caching(self):
|
||||||
|
"""
|
||||||
|
Test that the StateDictStager correctly caches and reuses storages.
|
||||||
|
"""
|
||||||
|
test_configs = [
|
||||||
|
(False, False), # pin_memory=False, share_memory=False,
|
||||||
|
(True, False), # pin_memory=True, share_memory=False
|
||||||
|
(False, True), # pin_memory=False, share_memory=True
|
||||||
|
(True, True), # pin_memory=True, share_memory=True
|
||||||
|
]
|
||||||
|
for pin_memory, share_memory in test_configs:
|
||||||
|
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||||
|
# Create test tensors and state dict
|
||||||
|
tensor1 = torch.randn(4, 4).cuda()
|
||||||
|
tensor2 = tensor1.view(16)
|
||||||
|
tensor3 = torch.randn(4, 4).cuda()
|
||||||
|
state_dict = {
|
||||||
|
"tensor1": tensor1,
|
||||||
|
"tensor2": tensor2,
|
||||||
|
"recursive": {
|
||||||
|
"tensor3": tensor3,
|
||||||
|
"type": TestStruct(tensor1=tensor3.narrow(0, 0, 2)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a StateDictStager instance
|
||||||
|
stager = StateDictStager(
|
||||||
|
pin_memory=pin_memory, share_memory=share_memory
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call to stage with staging context
|
||||||
|
cpu_state_dict1 = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Get the number of cached storages after first stage
|
||||||
|
num_storages1 = len(stager._cached_storage_mapping)
|
||||||
|
|
||||||
|
# Verify the first result is correct
|
||||||
|
result, error = compare_state_dicts(state_dict, cpu_state_dict1)
|
||||||
|
assert (
|
||||||
|
result
|
||||||
|
), f"First state dict is not equivalent to original: {error}"
|
||||||
|
|
||||||
|
# Modify the original tensors
|
||||||
|
tensor1.fill_(0)
|
||||||
|
tensor3.fill_(0)
|
||||||
|
|
||||||
|
# Second call to stage with staging context
|
||||||
|
cpu_state_dict2 = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Get the number of cached storages after second stage
|
||||||
|
num_storages2 = len(stager._cached_storage_mapping)
|
||||||
|
|
||||||
|
# Verify that the second CPU state dict is equivalent to the modified original state dict
|
||||||
|
result, error = compare_state_dicts(state_dict, cpu_state_dict2)
|
||||||
|
assert (
|
||||||
|
result
|
||||||
|
), f"Second state dict is not equivalent to modified original: {error}"
|
||||||
|
|
||||||
|
# Verify that the number of cached storages hasn't changed
|
||||||
|
assert (
|
||||||
|
num_storages1 == num_storages2
|
||||||
|
), f"Storage count changed: {num_storages1} vs {num_storages2}"
|
||||||
|
|
||||||
|
# Verify that the tensors in the second state dict have the same storage pointers as the first
|
||||||
|
assert (
|
||||||
|
cpu_state_dict1["tensor1"].storage().data_ptr()
|
||||||
|
== cpu_state_dict2["tensor1"].storage().data_ptr()
|
||||||
|
), "Storage pointers should match for tensor1"
|
||||||
|
assert (
|
||||||
|
cpu_state_dict1["tensor2"].storage().data_ptr()
|
||||||
|
== cpu_state_dict2["tensor2"].storage().data_ptr()
|
||||||
|
), "Storage pointers should match for tensor2"
|
||||||
|
assert (
|
||||||
|
cpu_state_dict1["recursive"]["tensor3"].storage().data_ptr()
|
||||||
|
== cpu_state_dict2["recursive"]["tensor3"].storage().data_ptr()
|
||||||
|
), "Storage pointers should match for tensor3"
|
||||||
|
|
||||||
|
# Modify the original tensors again with different values
|
||||||
|
tensor1.fill_(42.0)
|
||||||
|
|
||||||
|
# Third call to stage with staging context
|
||||||
|
cpu_state_dict3 = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Verify that the third CPU state dict reflects the updated values
|
||||||
|
assert torch.all(
|
||||||
|
cpu_state_dict3["tensor1"] == 42.0
|
||||||
|
), "Updated values should be reflected in the cached state dict"
|
||||||
|
assert torch.all(
|
||||||
|
cpu_state_dict3["tensor2"] == 42.0
|
||||||
|
), "Updated values should be reflected in the cached state dict"
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_tensor_attrs(self):
|
||||||
|
"""
|
||||||
|
Test that tensor attributes are preserved during stage with StateDictStager.
|
||||||
|
"""
|
||||||
|
tensor1 = torch.randn(4, 4).cuda()
|
||||||
|
tensor2 = tensor1.view(16)
|
||||||
|
tensor3 = torch.randn(4, 4).cuda()
|
||||||
|
|
||||||
|
# Add custom attributes to tensors
|
||||||
|
tensor1.a = 42
|
||||||
|
tensor1.b = 43
|
||||||
|
tensor3.c = 44
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
"tensor1": tensor1,
|
||||||
|
"tensor2": tensor2,
|
||||||
|
"recursive": {
|
||||||
|
"tensor3": tensor3,
|
||||||
|
"type": TestStruct(tensor1=tensor3.narrow(0, 0, 2)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
stager = StateDictStager(pin_memory=True, share_memory=True)
|
||||||
|
cpu_state_dict = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Verify that tensor attributes are preserved
|
||||||
|
assert hasattr(
|
||||||
|
cpu_state_dict["tensor1"], "a"
|
||||||
|
), "Tensor attribute 'a' was not preserved"
|
||||||
|
assert (
|
||||||
|
cpu_state_dict["tensor1"].a == 42
|
||||||
|
), "Tensor attribute 'a' has incorrect value"
|
||||||
|
assert hasattr(
|
||||||
|
cpu_state_dict["tensor1"], "b"
|
||||||
|
), "Tensor attribute 'b' was not preserved"
|
||||||
|
assert (
|
||||||
|
cpu_state_dict["tensor1"].b == 43
|
||||||
|
), "Tensor attribute 'b' has incorrect value"
|
||||||
|
assert hasattr(
|
||||||
|
cpu_state_dict["recursive"]["tensor3"], "c"
|
||||||
|
), "Tensor attribute 'c' was not preserved"
|
||||||
|
assert (
|
||||||
|
cpu_state_dict["recursive"]["tensor3"].c == 44
|
||||||
|
), "Tensor attribute 'c' has incorrect value"
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_different_dtypes(self):
|
||||||
|
"""
|
||||||
|
Test that StateDictStager works correctly with tensors of different data types.
|
||||||
|
"""
|
||||||
|
# Create tensors with different dtypes
|
||||||
|
tensors = {
|
||||||
|
"float32": torch.randn(4, 4, dtype=torch.float32).cuda(),
|
||||||
|
"float64": torch.randn(4, 4, dtype=torch.float64).cuda(),
|
||||||
|
"int32": torch.randint(-100, 100, (4, 4), dtype=torch.int32).cuda(),
|
||||||
|
"int64": torch.randint(-100, 100, (4, 4), dtype=torch.int64).cuda(),
|
||||||
|
"bool": torch.randint(0, 2, (4, 4), dtype=torch.bool).cuda(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a state dict with these tensors
|
||||||
|
state_dict = tensors.copy()
|
||||||
|
|
||||||
|
stager = StateDictStager()
|
||||||
|
cpu_state_dict = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Verify that all tensors have been correctly copied to CPU with the right dtypes
|
||||||
|
for dtype_name, original_tensor in tensors.items():
|
||||||
|
cpu_tensor = cpu_state_dict[dtype_name]
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor.device.type, "cpu", f"Tensor {dtype_name} should be on CPU"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor.dtype,
|
||||||
|
original_tensor.dtype,
|
||||||
|
f"Tensor {dtype_name} has incorrect dtype",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(cpu_tensor, original_tensor.cpu()),
|
||||||
|
f"Tensor {dtype_name} has incorrect values",
|
||||||
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_empty_tensors(self):
|
||||||
|
"""
|
||||||
|
Test that StateDictStager works correctly with empty tensors.
|
||||||
|
"""
|
||||||
|
test_configs = [
|
||||||
|
(False, False), # pin_memory=False, share_memory=False,
|
||||||
|
(True, False), # pin_memory=True, share_memory=False
|
||||||
|
(False, True), # pin_memory=False, share_memory=True
|
||||||
|
(True, True), # pin_memory=True, share_memory=True
|
||||||
|
]
|
||||||
|
for pin_memory, share_memory in test_configs:
|
||||||
|
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||||
|
# Create empty tensors with different shapes
|
||||||
|
tensors = {
|
||||||
|
"empty_0d": torch.tensor([], dtype=torch.float32).cuda(),
|
||||||
|
"empty_1d": torch.tensor([], dtype=torch.float32).reshape(0).cuda(),
|
||||||
|
"empty_2d": torch.tensor([], dtype=torch.float32)
|
||||||
|
.reshape(0, 0)
|
||||||
|
.cuda(),
|
||||||
|
"empty_3d": torch.tensor([], dtype=torch.float32)
|
||||||
|
.reshape(0, 0, 0)
|
||||||
|
.cuda(),
|
||||||
|
"zero_dim": torch.tensor(0.0).cuda(), # scalar tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a state dict with these tensors
|
||||||
|
state_dict = tensors.copy()
|
||||||
|
|
||||||
|
cpu_state_dict = StateDictStager(pin_memory, share_memory).stage(
|
||||||
|
state_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that all tensors have been correctly copied to CPU
|
||||||
|
for tensor_name, original_tensor in tensors.items():
|
||||||
|
cpu_tensor = cpu_state_dict[tensor_name]
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor.device.type,
|
||||||
|
"cpu",
|
||||||
|
f"Tensor {tensor_name} should be on CPU",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor.shape,
|
||||||
|
original_tensor.shape,
|
||||||
|
f"Tensor {tensor_name} has incorrect shape",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor.dtype,
|
||||||
|
original_tensor.dtype,
|
||||||
|
f"Tensor {tensor_name} has incorrect dtype",
|
||||||
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_complex_storage_sharing(self):
|
||||||
|
"""
|
||||||
|
Test that StateDictStager correctly handles complex storage sharing scenarios.
|
||||||
|
"""
|
||||||
|
# Create a base tensor
|
||||||
|
base_tensor = torch.randn(10, 10).cuda()
|
||||||
|
|
||||||
|
# Create various views and slices that share storage
|
||||||
|
view1 = base_tensor.view(100)
|
||||||
|
view2 = base_tensor.view(10, 10)
|
||||||
|
slice1 = base_tensor[2:8, 2:8]
|
||||||
|
slice2 = base_tensor[:, :5]
|
||||||
|
slice3 = view1[10:60]
|
||||||
|
|
||||||
|
# Create a state dict with these tensors
|
||||||
|
state_dict = {
|
||||||
|
"base": base_tensor,
|
||||||
|
"view1": view1,
|
||||||
|
"view2": view2,
|
||||||
|
"slice1": slice1,
|
||||||
|
"slice2": slice2,
|
||||||
|
"slice3": slice3,
|
||||||
|
}
|
||||||
|
cpu_state_dict = StateDictStager().stage(state_dict)
|
||||||
|
|
||||||
|
# Verify that all tensors have been correctly copied to CPU
|
||||||
|
result, error = compare_state_dicts(state_dict, cpu_state_dict)
|
||||||
|
self.assertTrue(result, f"State dicts are not equivalent: {error}")
|
||||||
|
|
||||||
|
# Verify storage sharing is preserved
|
||||||
|
# All these tensors should share the same storage
|
||||||
|
storage_ptr = cpu_state_dict["base"].storage().data_ptr()
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_state_dict["view1"].storage().data_ptr(),
|
||||||
|
storage_ptr,
|
||||||
|
"view1 should share storage with base",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_state_dict["view2"].storage().data_ptr(),
|
||||||
|
storage_ptr,
|
||||||
|
"view2 should share storage with base",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_state_dict["slice1"].storage().data_ptr(),
|
||||||
|
storage_ptr,
|
||||||
|
"slice1 should share storage with base",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_state_dict["slice2"].storage().data_ptr(),
|
||||||
|
storage_ptr,
|
||||||
|
"slice2 should share storage with base",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_state_dict["slice3"].storage().data_ptr(),
|
||||||
|
storage_ptr,
|
||||||
|
"slice3 should share storage with base",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that modifying the base tensor affects all views and slices
|
||||||
|
cpu_state_dict["base"].fill_(42.0)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.all(cpu_state_dict["view1"] == 42.0),
|
||||||
|
"view1 should reflect changes to base",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.all(cpu_state_dict["view2"] == 42.0),
|
||||||
|
"view2 should reflect changes to base",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.all(cpu_state_dict["slice1"] == 42.0),
|
||||||
|
"slice1 should reflect changes to base",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.all(cpu_state_dict["slice2"] == 42.0),
|
||||||
|
"slice2 should reflect changes to base",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.all(cpu_state_dict["slice3"] == 42.0),
|
||||||
|
"slice3 should reflect changes to base",
|
||||||
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_dataclasses(self):
|
||||||
|
# Create tensors
|
||||||
|
tensor1 = torch.randn(4, 4).cuda()
|
||||||
|
tensor2 = torch.randn(8, 8).cuda()
|
||||||
|
tensor3 = torch.randn(2, 6).cuda()
|
||||||
|
tensor4 = torch.randn(3, 5).cuda()
|
||||||
|
|
||||||
|
# Create dataclass instances
|
||||||
|
nested = NestedTensorStruct(tensor=tensor3)
|
||||||
|
complex_dc = ComplexDataClass(
|
||||||
|
tensor=tensor1, name="test", values=[1.0, 2.0, 3.0], nested=nested
|
||||||
|
)
|
||||||
|
frozen_dc = FrozenDataClass(tensor=tensor4)
|
||||||
|
|
||||||
|
# Create a state dict with these dataclasses
|
||||||
|
state_dict = {
|
||||||
|
"regular_tensor": tensor2,
|
||||||
|
"complex_dataclass": complex_dc,
|
||||||
|
"frozen_dataclass": frozen_dc,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Stage the state dict
|
||||||
|
stager = StateDictStager(pin_memory=False, share_memory=False)
|
||||||
|
cpu_state_dict = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Verify regular tensor
|
||||||
|
self.assertEqual(cpu_state_dict["regular_tensor"].device.type, "cpu")
|
||||||
|
self.assertTrue(torch.allclose(cpu_state_dict["regular_tensor"], tensor2.cpu()))
|
||||||
|
|
||||||
|
# Verify complex dataclass
|
||||||
|
complex_cpu = cpu_state_dict["complex_dataclass"]
|
||||||
|
self.assertEqual(complex_cpu.name, "test")
|
||||||
|
self.assertEqual(complex_cpu.values, [1.0, 2.0, 3.0])
|
||||||
|
self.assertEqual(complex_cpu.tensor.device.type, "cpu")
|
||||||
|
self.assertTrue(torch.allclose(complex_cpu.tensor, tensor1.cpu()))
|
||||||
|
|
||||||
|
# Verify nested dataclass inside complex dataclass
|
||||||
|
nested_cpu = complex_cpu.nested
|
||||||
|
self.assertEqual(nested_cpu.value, 42)
|
||||||
|
self.assertEqual(nested_cpu.tensor.device.type, "cpu")
|
||||||
|
self.assertTrue(torch.allclose(nested_cpu.tensor, tensor3.cpu()))
|
||||||
|
|
||||||
|
# Verify frozen dataclass
|
||||||
|
frozen_cpu = cpu_state_dict["frozen_dataclass"]
|
||||||
|
self.assertEqual(frozen_cpu.value, 100)
|
||||||
|
self.assertEqual(frozen_cpu.tensor.device.type, "cpu")
|
||||||
|
self.assertTrue(torch.allclose(frozen_cpu.tensor, tensor4.cpu()))
|
||||||
|
|
||||||
|
# Verify that modifying the original tensors doesn't affect the staged ones
|
||||||
|
tensor1.fill_(99.0)
|
||||||
|
tensor3.fill_(88.0)
|
||||||
|
tensor4.fill_(77.0)
|
||||||
|
|
||||||
|
self.assertFalse(torch.allclose(complex_cpu.tensor, tensor1.cpu()))
|
||||||
|
self.assertFalse(torch.allclose(nested_cpu.tensor, tensor3.cpu()))
|
||||||
|
self.assertFalse(torch.allclose(frozen_cpu.tensor, tensor4.cpu()))
|
||||||
|
|
||||||
|
def test_cpu_storage_independence(self):
|
||||||
|
"""
|
||||||
|
Test ensures CPU tensors passed to StateDictStager are actually cloned
|
||||||
|
"""
|
||||||
|
# Create test tensors
|
||||||
|
tensor1 = torch.randn(4, 4)
|
||||||
|
tensor2 = torch.randn(8, 8)
|
||||||
|
|
||||||
|
# Create a state dict with these tensors
|
||||||
|
state_dict = {
|
||||||
|
"tensor1": tensor1,
|
||||||
|
"tensor2": tensor2,
|
||||||
|
}
|
||||||
|
|
||||||
|
cpu_state_dict = StateDictStager().stage(state_dict)
|
||||||
|
cpu_tensor1 = cpu_state_dict["tensor1"]
|
||||||
|
cpu_tensor2 = cpu_state_dict["tensor2"]
|
||||||
|
|
||||||
|
# Verify that the CPU tensors have different storage pointers than the original tensors
|
||||||
|
self.assertNotEqual(
|
||||||
|
tensor1.storage().data_ptr(),
|
||||||
|
cpu_tensor1.storage().data_ptr(),
|
||||||
|
"CPU tensor should have a different storage pointer than the original tensor",
|
||||||
|
)
|
||||||
|
self.assertNotEqual(
|
||||||
|
tensor2.storage().data_ptr(),
|
||||||
|
cpu_tensor2.storage().data_ptr(),
|
||||||
|
"CPU tensor should have a different storage pointer than the original tensor",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(tensor1, cpu_tensor1),
|
||||||
|
"CPU tensor should have the same values as the original tensor",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(tensor2, cpu_tensor2),
|
||||||
|
"CPU tensor should have the same values as the original tensor",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Modify the original CPU tensors and validate staged tensors are not modified
|
||||||
|
cloned_orginial1 = tensor1.clone()
|
||||||
|
cloned_orginia2 = tensor2.clone()
|
||||||
|
tensor1.fill_(99.0)
|
||||||
|
tensor2.fill_(88.0)
|
||||||
|
|
||||||
|
self.assertFalse(torch.allclose(cloned_orginial1, tensor1))
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(cloned_orginial1, cpu_tensor1),
|
||||||
|
"CPU tensor should have the same values as the original tensor",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(cloned_orginia2, cpu_tensor2),
|
||||||
|
"CPU tensor should have the same values as the original tensor",
|
||||||
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_tensor_pinned_and_shared(self):
|
||||||
|
"""
|
||||||
|
Test that verifies tensors are actually pinned and shared using tensor.is_pinned() and tensor.is_shared() methods.
|
||||||
|
"""
|
||||||
|
# Create test tensors
|
||||||
|
tensor1 = torch.randn(4, 4).cuda()
|
||||||
|
tensor2 = torch.randn(8, 8).cuda()
|
||||||
|
|
||||||
|
# Create a state dict with these tensors
|
||||||
|
state_dict = {
|
||||||
|
"tensor1": tensor1,
|
||||||
|
"tensor2": tensor2,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test all combinations of pin_memory and share_memory
|
||||||
|
test_configs = [
|
||||||
|
(False, False), # pin_memory=False, share_memory=False
|
||||||
|
(True, False), # pin_memory=True, share_memory=False
|
||||||
|
(False, True), # pin_memory=False, share_memory=True
|
||||||
|
(True, True), # pin_memory=True, share_memory=True
|
||||||
|
]
|
||||||
|
|
||||||
|
for pin_memory, share_memory in test_configs:
|
||||||
|
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||||
|
# Create stager with specific configuration
|
||||||
|
stager = StateDictStager(
|
||||||
|
pin_memory=pin_memory, share_memory=share_memory
|
||||||
|
)
|
||||||
|
cpu_state_dict = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Get the staged tensors
|
||||||
|
cpu_tensor1 = cpu_state_dict["tensor1"]
|
||||||
|
cpu_tensor2 = cpu_state_dict["tensor2"]
|
||||||
|
|
||||||
|
# Verify tensor device
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor1.device.type, "cpu", "Staged tensor should be on CPU"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor2.device.type, "cpu", "Staged tensor should be on CPU"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify tensor values
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(cpu_tensor1, tensor1.cpu()),
|
||||||
|
"CPU tensor should have the same values as the original tensor",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(cpu_tensor2, tensor2.cpu()),
|
||||||
|
"CPU tensor should have the same values as the original tensor",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify pinned memory status
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor1.is_pinned(),
|
||||||
|
pin_memory,
|
||||||
|
f"Tensor pinned status should be {pin_memory}",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor2.is_pinned(),
|
||||||
|
pin_memory,
|
||||||
|
f"Tensor pinned status should be {pin_memory}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify shared memory status
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor1.is_shared(),
|
||||||
|
share_memory,
|
||||||
|
f"Tensor shared status should be {share_memory}",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
cpu_tensor2.is_shared(),
|
||||||
|
share_memory,
|
||||||
|
f"Tensor shared status should be {share_memory}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify storage sharing is consistent with tensor sharing
|
||||||
|
if share_memory:
|
||||||
|
# When share_memory is True, the storage should also be shared
|
||||||
|
self.assertTrue(
|
||||||
|
cpu_tensor1.storage().is_shared(),
|
||||||
|
"When share_memory=True, tensor storage should be shared",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
cpu_tensor2.storage().is_shared(),
|
||||||
|
"When share_memory=True, tensor storage should be shared",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# When share_memory is False, the storage should not be shared
|
||||||
|
self.assertFalse(
|
||||||
|
cpu_tensor1.storage().is_shared(),
|
||||||
|
"When share_memory=False, tensor storage should not be shared",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
cpu_tensor2.storage().is_shared(),
|
||||||
|
"When share_memory=False, tensor storage should not be shared",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDTensorStateDictStager(DTensorTestBase):
|
||||||
|
@with_comms
|
||||||
|
@requires_nccl()
|
||||||
|
@skip_if_lt_x_gpu(2)
|
||||||
|
def test_dtensor(self):
|
||||||
|
"""
|
||||||
|
Test that StateDictStager works correctly with DTensors.
|
||||||
|
"""
|
||||||
|
# Create a DTensor
|
||||||
|
device_mesh = dist.DeviceMesh("cuda", list(range(dist.get_world_size())))
|
||||||
|
tensor = torch.randn(3, 3, device="cuda")
|
||||||
|
dtensor = DTensor.from_local(tensor, device_mesh, [Shard(0)])
|
||||||
|
|
||||||
|
dtensor = dtensor + 1
|
||||||
|
dtensor = dtensor * 2
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
"dtensor": dtensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
stager = StateDictStager(pin_memory=True, share_memory=True)
|
||||||
|
cpu_state_dict = stager.stage(state_dict)
|
||||||
|
|
||||||
|
# Verify the original DTensor has the expected values
|
||||||
|
self.assertTrue(torch.allclose(dtensor.to_local(), (tensor + 1) * 2))
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
cpu_state_dict["dtensor"].to_local(), dtensor.to_local().cpu()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(cpu_state_dict["dtensor"]._spec, dtensor._spec)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_tests()
|
24
torch/cuda/_pin_memory_utils.py
Normal file
24
torch/cuda/_pin_memory_utils.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def pin_memory(data_ptr: int, size: int) -> None:
|
||||||
|
cudart = torch.cuda.cudart()
|
||||||
|
succ = int(
|
||||||
|
cudart.cudaHostRegister(
|
||||||
|
data_ptr,
|
||||||
|
size,
|
||||||
|
1, # lines up with 'cudaHostRegisterPortable'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if succ != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Registering memory failed with cudaError: {succ}."
|
||||||
|
" It's possible that this is an asynchronous error raised from a previous cuda operation."
|
||||||
|
" Consider launching with CUDA_LAUNCH_BLOCKING=1 to debug."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def unpin_memory(data_ptr: int) -> None:
|
||||||
|
succ = int(torch.cuda.cudart().cudaHostUnregister(data_ptr))
|
||||||
|
assert succ == 0, f"Unpinning shared memory failed with error-code: {succ}"
|
@ -7,6 +7,7 @@ from collections.abc import Mapping, MutableMapping
|
|||||||
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
|
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.cuda._pin_memory_utils as pin_memory_utils
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||||
@ -421,24 +422,9 @@ def _create_cpu_state_dict(
|
|||||||
t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)
|
t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)
|
||||||
t = t.share_memory_()
|
t = t.share_memory_()
|
||||||
if pin_memory:
|
if pin_memory:
|
||||||
|
pin_memory_utils.pin_memory(t.data_ptr(), t.numel() * t.element_size())
|
||||||
|
weakref.finalize(t, pin_memory_utils.unpin_memory, t)
|
||||||
|
|
||||||
def unpin_memory(t):
|
|
||||||
succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
|
|
||||||
assert succ == 0, (
|
|
||||||
f"Unpinning shared memory failed with error-code: {succ}"
|
|
||||||
)
|
|
||||||
|
|
||||||
weakref.finalize(t, unpin_memory, t)
|
|
||||||
succ = int(
|
|
||||||
torch.cuda.cudart().cudaHostRegister(
|
|
||||||
t.data_ptr(),
|
|
||||||
t.numel() * t.element_size(),
|
|
||||||
1, # lines up with 'cudaHostRegisterPortable'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert succ == 0, (
|
|
||||||
f"Pinning shared memory failed with error-code: {succ}"
|
|
||||||
)
|
|
||||||
return t
|
return t
|
||||||
elif pin_memory:
|
elif pin_memory:
|
||||||
return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
|
return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
|
||||||
|
354
torch/distributed/checkpoint/_state_dict_stager.py
Normal file
354
torch/distributed/checkpoint/_state_dict_stager.py
Normal file
@ -0,0 +1,354 @@
|
|||||||
|
# mypy: allow-untyped-defs
|
||||||
|
import logging
|
||||||
|
import types
|
||||||
|
import weakref
|
||||||
|
from copyreg import dispatch_table
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.cuda._pin_memory_utils as pin_memory_utils
|
||||||
|
from torch.storage import UntypedStorage
|
||||||
|
from torch.utils.weak import WeakIdKeyDictionary
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger()
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class StateDictStager:
|
||||||
|
"""
|
||||||
|
A class for optimizing storage objects during staging for async checkpointing.
|
||||||
|
|
||||||
|
StateDictStager stages the state_dict to CPU DRAM while applying optimizations
|
||||||
|
like memory sharing and pinning to improve performance. It caches storage objects
|
||||||
|
to avoid redundant copies and can be configured to automatically share memory
|
||||||
|
(for multi-process usage) and pin memory (for faster CPU-GPU transfers).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
pin_memory (bool): Whether to pin CPU memory for faster CPU-GPU transfers
|
||||||
|
share_memory (bool): Whether to share memory across processes
|
||||||
|
_cached_storage_mapping (WeakIdKeyDictionary): Maps storage objects to optimized CPU storages using weak references
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pin_memory: bool = False, share_memory: bool = False):
|
||||||
|
if pin_memory and not torch.cuda.is_available():
|
||||||
|
logger.warning(
|
||||||
|
"Ignoring pin_memory flag for checkpoint staging as pinning memory"
|
||||||
|
"requires CUDA, but CUDA is not available. "
|
||||||
|
)
|
||||||
|
self.pin_memory = False
|
||||||
|
else:
|
||||||
|
self.pin_memory = pin_memory
|
||||||
|
self.share_memory = share_memory
|
||||||
|
# Mapping from original storage objects to CPU storages using weak references
|
||||||
|
self._cached_storage_mapping = WeakIdKeyDictionary()
|
||||||
|
|
||||||
|
def _deepcopy_atomic(x, _):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _deepcopy_list(x, memo):
|
||||||
|
y: list = []
|
||||||
|
memo[id(x)] = y
|
||||||
|
append = y.append
|
||||||
|
for a in x:
|
||||||
|
append(self.deepcopy_with_tensor_offload(a, memo))
|
||||||
|
return y
|
||||||
|
|
||||||
|
def _deepcopy_tuple(x, memo):
|
||||||
|
y = [self.deepcopy_with_tensor_offload(a, memo) for a in x]
|
||||||
|
# We're not going to put the tuple in the memo, but it's still important we
|
||||||
|
# check for it, in case the tuple contains recursive mutable structures.
|
||||||
|
try:
|
||||||
|
return memo[id(x)]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check if any elements changed during deepcopy
|
||||||
|
for k, j in zip(x, y):
|
||||||
|
if k is not j:
|
||||||
|
# At least one element changed, create new tuple
|
||||||
|
return tuple(y)
|
||||||
|
|
||||||
|
# No elements changed, return original tuple
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _deepcopy_dict(x, memo):
|
||||||
|
y: dict = {}
|
||||||
|
memo[id(x)] = y
|
||||||
|
for key, value in x.items():
|
||||||
|
y[self.deepcopy_with_tensor_offload(key, memo)] = (
|
||||||
|
self.deepcopy_with_tensor_offload(value, memo)
|
||||||
|
)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def _deepcopy_method(x, memo): # Copy instance methods
|
||||||
|
return type(x)(
|
||||||
|
x.__func__, self.deepcopy_with_tensor_offload(x.__self__, memo)
|
||||||
|
)
|
||||||
|
|
||||||
|
d: dict[Any, Any] = {}
|
||||||
|
self._deepcopy_dispatch = d
|
||||||
|
d[type(None)] = _deepcopy_atomic
|
||||||
|
d[int] = _deepcopy_atomic
|
||||||
|
d[float] = _deepcopy_atomic
|
||||||
|
d[bool] = _deepcopy_atomic
|
||||||
|
d[complex] = _deepcopy_atomic
|
||||||
|
d[bytes] = _deepcopy_atomic
|
||||||
|
d[str] = _deepcopy_atomic
|
||||||
|
d[types.CodeType] = _deepcopy_atomic
|
||||||
|
d[type] = _deepcopy_atomic
|
||||||
|
d[range] = _deepcopy_atomic
|
||||||
|
d[types.BuiltinFunctionType] = _deepcopy_atomic
|
||||||
|
d[types.FunctionType] = _deepcopy_atomic
|
||||||
|
d[weakref.ref] = _deepcopy_atomic
|
||||||
|
d[property] = _deepcopy_atomic
|
||||||
|
d[types.MethodType] = _deepcopy_method
|
||||||
|
d[dict] = _deepcopy_dict
|
||||||
|
d[tuple] = _deepcopy_tuple
|
||||||
|
d[list] = _deepcopy_list
|
||||||
|
|
||||||
|
def _stage_untyped_storage(
|
||||||
|
self, storage: UntypedStorage, non_blocking: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Called from the hooked storage_deepcopy function in torch.Tensor.__deepcopy__.
|
||||||
|
|
||||||
|
This method handles the storage optimization logic for the StagingStateDict class.
|
||||||
|
It checks if the storage has already been cached, and if so, reuses it.
|
||||||
|
Otherwise, it creates a new CPU storage and applies memory optimizations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage: The storage to optimize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The optimized storage
|
||||||
|
"""
|
||||||
|
# Check if we've already cached this storage
|
||||||
|
if storage in self._cached_storage_mapping:
|
||||||
|
cached_storage = self._cached_storage_mapping[storage]
|
||||||
|
assert cached_storage.size() == storage.size(), (
|
||||||
|
"For async checkpointing, We cache storages in DRAM and reuse them."
|
||||||
|
"Cached storage size does not match original storage size."
|
||||||
|
"This should never happen as we track the original storage weakref "
|
||||||
|
"and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing."
|
||||||
|
)
|
||||||
|
# Reuse cached storage but update with new data
|
||||||
|
cached_storage.copy_(storage, non_blocking=non_blocking)
|
||||||
|
return cached_storage
|
||||||
|
|
||||||
|
# Create new CPU storage
|
||||||
|
if self.share_memory:
|
||||||
|
new_storage = type(storage)._new_shared(storage.size(), device="cpu")
|
||||||
|
else:
|
||||||
|
new_storage = type(storage)(storage.size(), device="cpu")
|
||||||
|
|
||||||
|
if self.pin_memory and new_storage.nbytes() > 0:
|
||||||
|
pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes())
|
||||||
|
# Set up a weak reference to unpin when cpu storage is garbage collected
|
||||||
|
f = weakref.finalize(
|
||||||
|
new_storage, pin_memory_utils.unpin_memory, new_storage.data_ptr()
|
||||||
|
)
|
||||||
|
# This makes sure that the finalizer is not called after
|
||||||
|
# cuda context is destroyed.
|
||||||
|
f.atexit = False
|
||||||
|
|
||||||
|
new_storage.copy_(storage, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
# Cache the storage - WeakIdKeyDictionary will automatically clean up when storage is garbage collected
|
||||||
|
self._cached_storage_mapping[storage] = new_storage
|
||||||
|
return new_storage
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def stage(
|
||||||
|
self,
|
||||||
|
state_dict: dict[str, Any],
|
||||||
|
non_blocking: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return self.deepcopy_with_tensor_offload(state_dict, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
def _offload_tensor(self, x, memo, non_blocking=False):
|
||||||
|
"""
|
||||||
|
Deep copy a PyTorch tensor with optimized storage handling.
|
||||||
|
|
||||||
|
This method creates a CPU copy of a tensor while applying memory optimizations
|
||||||
|
like sharing and pinning based on the StateDictStager configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The tensor to copy
|
||||||
|
memo: Memo dictionary for tracking already copied objects
|
||||||
|
non_blocking: Whether to perform non-blocking copies where possible
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A CPU copy of the tensor with optimized storage
|
||||||
|
"""
|
||||||
|
# Create a new empty tensor on CPU
|
||||||
|
y = x.new_empty([], device="cpu")
|
||||||
|
|
||||||
|
# Store in memo dict early to handle recursive references
|
||||||
|
d = id(x)
|
||||||
|
memo[d] = y
|
||||||
|
|
||||||
|
if type(x) is torch.Tensor or x.data_ptr() != 0:
|
||||||
|
# Try to get the untyped storage and optimize it
|
||||||
|
untyped_storage = x.untyped_storage()
|
||||||
|
copied_storage = self._stage_untyped_storage(
|
||||||
|
untyped_storage, non_blocking=non_blocking
|
||||||
|
)
|
||||||
|
# Set the tensor data using the optimized storage
|
||||||
|
y.set_(copied_storage, x.storage_offset(), x.size(), x.stride())
|
||||||
|
|
||||||
|
# Copy any attributes the tensor might have
|
||||||
|
if hasattr(x, "__dict__"):
|
||||||
|
for attr_name, attr_value in x.__dict__.items():
|
||||||
|
setattr(
|
||||||
|
y,
|
||||||
|
attr_name,
|
||||||
|
self.deepcopy_with_tensor_offload(
|
||||||
|
attr_value, memo, non_blocking=non_blocking
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(x, "__slots__"):
|
||||||
|
for slot in x.__slots__:
|
||||||
|
if hasattr(x, slot):
|
||||||
|
setattr(
|
||||||
|
y,
|
||||||
|
slot,
|
||||||
|
self.deepcopy_with_tensor_offload(
|
||||||
|
getattr(x, slot), memo, non_blocking=non_blocking
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006
|
||||||
|
"""Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors.
|
||||||
|
|
||||||
|
This implementation extends the standard deepcopy functionality to handle PyTorch tensors
|
||||||
|
and their storages in a way that optimizes memory usage and performance, similar to the
|
||||||
|
stage method. It applies memory sharing and pinning optimizations based on the StateDictStager
|
||||||
|
configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The object to deep copy
|
||||||
|
memo: Memo dictionary for tracking already copied objects
|
||||||
|
_nil: Sentinel value for memo dictionary
|
||||||
|
non_blocking: Whether to perform non-blocking copies where possible
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A deep copy of the input object with optimized tensor storage handling
|
||||||
|
"""
|
||||||
|
if memo is None:
|
||||||
|
memo = {}
|
||||||
|
|
||||||
|
d = id(x)
|
||||||
|
y = memo.get(d, _nil)
|
||||||
|
if y is not _nil:
|
||||||
|
return y
|
||||||
|
|
||||||
|
cls = type(x)
|
||||||
|
|
||||||
|
# tensors and subclasses of tensors are handled separately
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
y = self._offload_tensor(x, memo, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
# Use the dispatch table for standard types
|
||||||
|
copier = self._deepcopy_dispatch.get(cls)
|
||||||
|
if copier is not None:
|
||||||
|
y = copier(x, memo)
|
||||||
|
else:
|
||||||
|
if issubclass(cls, type):
|
||||||
|
y = self._deepcopy_dispatch[type](x, memo)
|
||||||
|
else:
|
||||||
|
copier = getattr(x, "__deepcopy__", None)
|
||||||
|
if copier is not None:
|
||||||
|
y = copier(memo)
|
||||||
|
else:
|
||||||
|
reductor = dispatch_table.get(cls)
|
||||||
|
if reductor:
|
||||||
|
rv = reductor(x)
|
||||||
|
else:
|
||||||
|
reductor = getattr(x, "__reduce_ex__", None)
|
||||||
|
if reductor is not None:
|
||||||
|
rv = reductor(4)
|
||||||
|
else:
|
||||||
|
reductor = getattr(x, "__reduce__", None)
|
||||||
|
if reductor:
|
||||||
|
rv = reductor()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"un(deep)copyable object of type {cls}"
|
||||||
|
)
|
||||||
|
if isinstance(rv, str):
|
||||||
|
y = x
|
||||||
|
else:
|
||||||
|
y = self._reconstruct(x, memo, *rv)
|
||||||
|
|
||||||
|
# If is its own copy, don't memoize.
|
||||||
|
if y is not x:
|
||||||
|
memo[d] = y
|
||||||
|
self._keep_alive(x, memo) # Make sure x lives at least as long as d
|
||||||
|
return y
|
||||||
|
|
||||||
|
def _keep_alive(self, x, memo):
|
||||||
|
"""Keeps a reference to the object x in the memo.
|
||||||
|
|
||||||
|
Because we remember objects by their id, we have
|
||||||
|
to assure that possibly temporary objects are kept
|
||||||
|
alive by referencing them.
|
||||||
|
We store a reference at the id of the memo, which should
|
||||||
|
normally not be used unless someone tries to deepcopy
|
||||||
|
the memo itself...
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
memo[id(memo)].append(x)
|
||||||
|
except KeyError:
|
||||||
|
# aha, this is the first one :-)
|
||||||
|
memo[id(memo)] = [x]
|
||||||
|
|
||||||
|
def _reconstruct(
|
||||||
|
self, x, memo, func, args, state=None, listiter=None, dictiter=None
|
||||||
|
):
|
||||||
|
deep = memo is not None
|
||||||
|
if deep and args:
|
||||||
|
args = (self.deepcopy_with_tensor_offload(arg, memo) for arg in args)
|
||||||
|
y = func(*args)
|
||||||
|
if deep:
|
||||||
|
memo[id(x)] = y
|
||||||
|
|
||||||
|
if state is not None:
|
||||||
|
if deep:
|
||||||
|
state = self.deepcopy_with_tensor_offload(state, memo)
|
||||||
|
if hasattr(y, "__setstate__"):
|
||||||
|
y.__setstate__(state)
|
||||||
|
else:
|
||||||
|
if isinstance(state, tuple) and len(state) == 2:
|
||||||
|
state, slotstate = state
|
||||||
|
else:
|
||||||
|
slotstate = None
|
||||||
|
if state is not None:
|
||||||
|
y.__dict__.update(state)
|
||||||
|
if slotstate is not None:
|
||||||
|
for key, value in slotstate.items():
|
||||||
|
setattr(y, key, value)
|
||||||
|
|
||||||
|
if listiter is not None:
|
||||||
|
if deep:
|
||||||
|
for item in listiter:
|
||||||
|
item = self.deepcopy_with_tensor_offload(item, memo)
|
||||||
|
y.append(item)
|
||||||
|
else:
|
||||||
|
for item in listiter:
|
||||||
|
y.append(item)
|
||||||
|
if dictiter is not None:
|
||||||
|
if deep:
|
||||||
|
for key, value in dictiter:
|
||||||
|
key = self.deepcopy_with_tensor_offload(key, memo)
|
||||||
|
value = self.deepcopy_with_tensor_offload(value, memo)
|
||||||
|
y[key] = value
|
||||||
|
else:
|
||||||
|
for key, value in dictiter:
|
||||||
|
y[key] = value
|
||||||
|
return y
|
Reference in New Issue
Block a user