pytorch/test/distributed/checkpoint/_experimental/test_builder.py
Huy Do 73078f305f Add missing super().setUp() (#167163)
In a trunk failure today, we saw the same test running on both the trunk and slow shards. The reason is that this test didn't invoke `super().setUp()`, so test features like slow-test and disabled-test handling didn't apply to it.

I used Claude to find all test classes with a `setUp()` method that didn't call `super().setUp()` and patched all of them.
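
For context, here is a minimal sketch of how such classes could be located mechanically with Python's `ast` module. The script and the `find_missing_super_setup` helper are illustrative assumptions, not the actual tooling used for this PR:

```python
# Illustrative sketch (assumption: not the actual script used for this PR).
# Reports classes that define setUp() without ever calling super().setUp().
import ast
import sys


def find_missing_super_setup(path: str) -> list[str]:
    """Return "path:Class.setUp" entries whose body never calls super().setUp()."""
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=path)

    offenders = []
    for cls in (n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)):
        for fn in cls.body:
            if not (isinstance(fn, ast.FunctionDef) and fn.name == "setUp"):
                continue
            # Look for a call of the form super().setUp() anywhere in the method body.
            calls_super = any(
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == "setUp"
                and isinstance(node.func.value, ast.Call)
                and isinstance(node.func.value.func, ast.Name)
                and node.func.value.func.id == "super"
                for node in ast.walk(fn)
            )
            if not calls_super:
                offenders.append(f"{path}:{cls.name}.setUp")
    return offenders


if __name__ == "__main__":
    for test_file in sys.argv[1:]:
        for entry in find_missing_super_setup(test_file):
            print(entry)
```

Running a check like this over the test tree prints one entry per offending `setUp()`, each of which can then be patched to call `super().setUp()` first, as this file now does.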

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167163
Approved by: https://github.com/malfet
2025-11-06 17:55:23 +00:00


# Owner(s): ["oncall: distributed checkpointing"]

import os
import shutil
import tempfile

import torch
from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.builder import (
    make_async_checkpointer,
    make_sync_checkpointer,
)
from torch.distributed.checkpoint._experimental.checkpointer import (
    AsyncCheckpointer,
    SyncCheckpointer,
)
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig
from torch.distributed.checkpoint._experimental.types import RankInfo
from torch.testing._internal.common_utils import run_tests, TestCase


class TestMakeCheckpointer(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # Create a temporary directory for checkpoints
        self.temp_dir = tempfile.mkdtemp()

        # Create real objects for testing
        self.rank_info = RankInfo(
            global_world_size=1,
            global_rank=0,
        )

        # Create a test state dictionary
        self.state_dict = {
            "model": torch.nn.Linear(10, 5).state_dict(),
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
            "step": 1000,
        }

    def tearDown(self) -> None:
        # Clean up the temporary directory
        shutil.rmtree(self.temp_dir)

    def test_make_sync_checkpointer(self) -> None:
        """Test creating a synchronous checkpointer using make_sync_checkpointer."""
        # Create sync checkpointer using factory function with no barrier
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_sync")
        result = checkpointer.save(checkpoint_path, self.state_dict)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

        # Test loading
        loaded_state_dict = checkpointer.load(checkpoint_path)
        self.assertEqual(loaded_state_dict["epoch"], 5)

    def test_make_sync_checkpointer_with_config_first(self) -> None:
        """Test creating a synchronous checkpointer with config as first parameter."""
        # Create sync checkpointer with config as first parameter
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(
            self.temp_dir, "checkpoint_factory_sync_config_first"
        )
        result = checkpointer.save(checkpoint_path, self.state_dict)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

    def test_make_sync_checkpointer_with_custom_config(self) -> None:
        """Test creating a synchronous checkpointer with a custom config."""
        # Create a custom config with no barrier
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))

        # Create sync checkpointer with the custom config
        checkpointer = make_sync_checkpointer(rank_info=self.rank_info, config=config)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(
            self.temp_dir, "checkpoint_factory_sync_custom_config"
        )
        result = checkpointer.save(checkpoint_path, self.state_dict)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

        # Test loading
        loaded_state_dict = checkpointer.load(checkpoint_path)
        self.assertEqual(loaded_state_dict["epoch"], 5)

    def test_make_async_checkpointer(self) -> None:
        """Test creating an asynchronous checkpointer using make_async_checkpointer."""
        # Create async checkpointer using factory function with default parameters
        config: CheckpointerConfig = CheckpointerConfig()
        config.staging_config = CheckpointStagerConfig(
            use_non_blocking_copy=torch.accelerator.is_available(),
            use_pinned_memory=torch.accelerator.is_available(),
        )
        checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info)

        try:
            # Verify it's an AsyncCheckpointer instance
            self.assertIsInstance(checkpointer, AsyncCheckpointer)

            # Test that it works for async operations
            checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_async")
            stage_future, write_future = checkpointer.save(
                checkpoint_path, self.state_dict
            )

            # Verify futures are returned
            self.assertIsNotNone(stage_future)
            self.assertIsNotNone(write_future)

            # Wait for completion
            stage_future.result()
            write_future.result()

            # Verify checkpoint was created
            checkpoint_file = os.path.join(
                checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
            )
            self.assertTrue(os.path.exists(checkpoint_file))

            # Test loading
            loaded_state_dict = checkpointer.load(checkpoint_path)
            self.assertEqual(loaded_state_dict["epoch"], 5)
        finally:
            # Clean up
            checkpointer.close()


if __name__ == "__main__":
    run_tests()