Set parallelism_config in constructor due to Trainer reset of State (#3713)

Wing Lian
2025-08-06 07:47:49 -04:00
committed by GitHub
parent 6891c57072
commit 24c8157bba

@@ -903,6 +903,7 @@ class AcceleratorState:
         fsdp_plugin=None,
         torch_tp_plugin=None,
         megatron_lm_plugin=None,
+        parallelism_config=None,
         _from_accelerator: bool = False,
         **kwargs,
     ):
@@ -917,6 +918,7 @@
         self.deepspeed_plugins = None
         self.use_ipex = None
         self.torch_tp_plugin = torch_tp_plugin
+        self.parallelism_config = parallelism_config
         mixed_precision = (
             parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
             if mixed_precision is None
@@ -995,13 +997,13 @@
                     raise ValueError(
                         "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                     )
-            if (
-                os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None
-            ) or (self.parallelism_config is not None and self.parallelism_config.cp_enabled):
-                self.distributed_type = DistributedType.FSDP
-                if self._mixed_precision != "no":
-                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
-                self.fsdp_plugin = fsdp_plugin
+            if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
+                self.parallelism_config is not None and self.parallelism_config.cp_enabled
+            ):
+                self.distributed_type = DistributedType.FSDP
+                if self._mixed_precision != "no" and fsdp_plugin is not None:
+                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
+                self.fsdp_plugin = fsdp_plugin
             if os.environ.get(
                 "ACCELERATE_USE_MEGATRON_LM", "false"
             ).lower() == "true" and self.distributed_type not in [
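To illustrate why the kwarg has to flow through __init__, the toy below loosely mimics (it is not accelerate's real code) the borg-style shared state that AcceleratorState uses and the reset that transformers' Trainer triggers: an attribute patched onto the state after construction disappears once the state is rebuilt, while a value passed to the constructor survives. ToyAcceleratorState, its _reset_state helper, and the "cp_size=2" string are stand-ins invented for this sketch.

    # Minimal sketch of the failure mode this commit works around; NOT accelerate's implementation.
    class ToyAcceleratorState:
        _shared_state = {}  # shared across instances, loosely mirroring AcceleratorState

        def __init__(self, parallelism_config=None):
            self.__dict__ = self._shared_state
            # With this commit, the config is assigned inside __init__ ...
            self.parallelism_config = parallelism_config

        @classmethod
        def _reset_state(cls):
            # ... so it is the constructor, not a later attribute assignment,
            # that decides what a freshly rebuilt state sees after a reset.
            cls._shared_state.clear()


    # Before: a config attached after construction is lost once the state is
    # reset and rebuilt (the Trainer-style sequence described in the title).
    state = ToyAcceleratorState()
    state.parallelism_config = "cp_size=2"   # patched on afterwards
    ToyAcceleratorState._reset_state()
    state = ToyAcceleratorState()
    assert state.parallelism_config is None  # gone

    # After: the config travels through the constructor, so the rebuild keeps it.
    ToyAcceleratorState._reset_state()
    state = ToyAcceleratorState(parallelism_config="cp_size=2")
    assert state.parallelism_config == "cp_size=2"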