[FSDP2][ez] Removed error check for swap tensors flag (#124513)
Since `DTensor` uses the `swap_tensors` path automatically now, we can remove this check for the global flag.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124513
Approved by: https://github.com/weifengpy
ghstack dependencies: #124319, #120256
committed by: PyTorch MergeBot
parent: 1c2cb36811
commit: f9fce110af
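For context, a minimal sketch (not part of this PR; the `nn.Linear` module and the dtype conversion are purely illustrative) of the global flag whose check this commit removes. With `torch.__future__.set_swap_module_params_on_conversion(True)`, `_apply` methods such as `Module.to()` swap parameter payloads in place via `torch.utils.swap_tensors`, so references held before the conversion still point at the module's parameters afterwards, which is the invariant the remaining assertion in the diff below checks.

import torch
import torch.nn as nn

# Illustrative sketch (not from this PR): enable the swap_tensors path for
# module conversion methods.
torch.__future__.set_swap_module_params_on_conversion(True)
assert torch.__future__.get_swap_module_params_on_conversion()

lin = nn.Linear(4, 4)               # plain module, purely for illustration
weight_before = lin.weight          # reference held before conversion

lin.to(torch.float64)               # an _apply method (dtype conversion)

# The parameter object is preserved; only its payload changed. This is the
# invariant that FSDP2's remaining AssertionError (see the diff below) relies
# on for its DTensor sharded parameters.
assert lin.weight is weight_before
assert lin.weight.dtype == torch.float64

torch.__future__.set_swap_module_params_on_conversion(False)  # restore default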
@@ -289,16 +289,10 @@ class FSDP:
                 module_info = fsdp_param._module_info
                 new_param = getattr(module_info.module, module_info.param_name)
                 if new_param is not fsdp_param.sharded_param:
-                    if torch.__future__.get_swap_module_params_on_conversion():
-                        raise AssertionError(
-                            "Expects swap_tensors to preserve object but got "
-                            f"{new_param} instead of {fsdp_param.sharded_param}"
-                        )
-                    else:
-                        raise AssertionError(
-                            "Please set torch.__future__.set_swap_module_params_on_conversion(True) "
-                            "to use _apply methods with FSDP"
-                        )
+                    raise AssertionError(
+                        "Expects swap_tensors to preserve object but got "
+                        f"{new_param} instead of {fsdp_param.sharded_param}"
+                    )
                 local_tensor = new_param._local_tensor
                 padded_sharded_size = fsdp_param.padded_sharded_param_size
                 if local_tensor.size() != padded_sharded_size:
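A minimal sketch of the object-identity guarantee named in the kept error message ("Expects swap_tensors to preserve object"): `torch.utils.swap_tensors` exchanges the contents of two tensors while leaving the Python objects themselves in place, so a previously saved reference (analogous to `fsdp_param.sharded_param`) still points at the module's parameter after conversion. The plain tensors below are purely illustrative.

import torch
from torch.utils import swap_tensors

a = torch.ones(3)
b = torch.zeros(3)
a_ref = a                     # saved reference, analogous to fsdp_param.sharded_param

# swap_tensors exchanges the payloads of the two tensor objects in place
swap_tensors(a, b)

assert a is a_ref                        # the Python object is preserved ...
assert torch.equal(a, torch.zeros(3))    # ... while its contents now come from b
assert torch.equal(b, torch.ones(3))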