[FSDP2][ez] Removed error check for swap tensors flag (#124513)

Since `DTensor` uses `swap_tensors` path automatically now, we can remove this check for the global flag.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124513
Approved by: https://github.com/weifengpy
ghstack dependencies: #124319, #120256
This commit is contained in:
Andrew Gu
2024-04-19 12:12:23 -07:00
committed by PyTorch MergeBot
parent 1c2cb36811
commit f9fce110af

View File

@ -289,16 +289,10 @@ class FSDP:
module_info = fsdp_param._module_info
new_param = getattr(module_info.module, module_info.param_name)
if new_param is not fsdp_param.sharded_param:
if torch.__future__.get_swap_module_params_on_conversion():
raise AssertionError(
"Expects swap_tensors to preserve object but got "
f"{new_param} instead of {fsdp_param.sharded_param}"
)
else:
raise AssertionError(
"Please set torch.__future__.set_swap_module_params_on_conversion(True) "
"to use _apply methods with FSDP"
)
raise AssertionError(
"Expects swap_tensors to preserve object but got "
f"{new_param} instead of {fsdp_param.sharded_param}"
)
local_tensor = new_param._local_tensor
padded_sharded_size = fsdp_param.padded_sharded_param_size
if local_tensor.size() != padded_sharded_size: