Resolves https://github.com/pytorch/torchtitan/issues/1136, where torchtitan uses a cached state dict for ft. `reset_sharded_param` should be idempotent if `model.parameters()` are already padded:

```
# pad DTensor._local_tensor
fsdp_model = fully_shard(model)
sd = fsdp_model.state_dict()
# reset_sharded_param should be a no-op in lazy_init
loss = fsdp_model(inp).sum()
```

This PR makes `reset_sharded_param` idempotent by checking the storage data pointer and returning early.

Unit test:

```
pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163130
Approved by: https://github.com/tianyu-l
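
For context, a minimal sketch of the early-return idea, using plain tensors instead of the real `FSDPParam`/`DTensor` machinery; the class and attribute names below are placeholders for illustration, not the actual FSDP2 implementation:

```
import torch


class _FSDPParamSketch:
    """Illustrative stand-in for FSDPParam; names and attributes are hypothetical."""

    def __init__(self, local_shard: torch.Tensor, padded_data: torch.Tensor):
        self.local_shard = local_shard   # stands in for DTensor._local_tensor
        self._padded_data = padded_data  # padded flat storage owned by FSDP

    def reset_sharded_param(self) -> None:
        # Idempotency check: if the local shard already views the padded
        # storage, a previous call (e.g. triggered by state_dict()) has
        # already done the work, so return early instead of re-padding.
        if (
            self.local_shard.untyped_storage().data_ptr()
            == self._padded_data.untyped_storage().data_ptr()
        ):
            return
        # Otherwise copy the shard into the padded storage and re-alias it.
        padded_view = self._padded_data[: self.local_shard.numel()].view_as(self.local_shard)
        padded_view.copy_(self.local_shard)
        self.local_shard = padded_view


shard = torch.randn(6)
storage = torch.zeros(8)  # padded length
p = _FSDPParamSketch(shard, storage)
p.reset_sharded_param()  # first call: copies the shard into the padded storage
p.reset_sharded_param()  # second call: data pointers match, returns early (no-op)
```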