Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 10:03:46 +08:00)
Set parallelism_config in constructor due to Trainer reset of State (#3713)
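The Hugging Face `Trainer` resets `AcceleratorState` during setup, so a `parallelism_config` attached to the state only after construction was lost; this commit threads it through the constructor instead. A minimal sketch of the intended call path, assuming the public `Accelerator` and `ParallelismConfig` entry points (the import path and `cp_size` field are assumptions, not taken from this diff):

```python
# Sketch of the motivation, not the library's documented API surface.
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig  # import path assumed

# cp_size > 1 is what makes parallelism_config.cp_enabled True further down.
pc = ParallelismConfig(cp_size=2)  # field name assumed

# With this commit, __init__ accepts and stores parallelism_config (first two
# hunks below), so re-constructing AcceleratorState no longer drops the config.
accelerator = Accelerator(parallelism_config=pc)
```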
The diff touches `src/accelerate/state.py`:

```diff
@@ -903,6 +903,7 @@ class AcceleratorState:
         fsdp_plugin=None,
         torch_tp_plugin=None,
         megatron_lm_plugin=None,
+        parallelism_config=None,
         _from_accelerator: bool = False,
         **kwargs,
     ):
```
```diff
@@ -917,6 +918,7 @@ class AcceleratorState:
             self.deepspeed_plugins = None
             self.use_ipex = None
             self.torch_tp_plugin = torch_tp_plugin
+            self.parallelism_config = parallelism_config
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None
```
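The stored `self.parallelism_config` is what the FSDP branch in the next hunk reads through `cp_enabled`. A rough sketch of how such a flag could be derived, assuming it comes from a `cp_size` degree (only the `cp_enabled` name appears in this diff; the field and property below are illustrative):

```python
from dataclasses import dataclass

@dataclass
class ParallelismConfigSketch:
    """Illustrative stand-in for accelerate's ParallelismConfig."""

    cp_size: int = 1  # context-parallel degree; field name assumed

    @property
    def cp_enabled(self) -> bool:
        # Treat context parallelism as enabled only above degree 1, matching
        # the `cp_size>1` wording in the ValueError of the next hunk.
        return self.cp_size > 1
```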
```diff
@@ -995,13 +997,13 @@ class AcceleratorState:
                     raise ValueError(
                         "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                     )
-            if (
-                os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None
-            ) or (self.parallelism_config is not None and self.parallelism_config.cp_enabled):
-                self.distributed_type = DistributedType.FSDP
-                if self._mixed_precision != "no":
-                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
-                self.fsdp_plugin = fsdp_plugin
+            if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
+                self.parallelism_config is not None and self.parallelism_config.cp_enabled
+            ):
+                self.distributed_type = DistributedType.FSDP
+                if self._mixed_precision != "no" and fsdp_plugin is not None:
+                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
+                self.fsdp_plugin = fsdp_plugin
             if os.environ.get(
                 "ACCELERATE_USE_MEGATRON_LM", "false"
             ).lower() == "true" and self.distributed_type not in [
```
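The one behavioral change in this last hunk is the added `fsdp_plugin is not None` guard; the surrounding condition rewrite is just re-wrapping. Since context parallelism alone can now route into the FSDP branch, `fsdp_plugin` may be `None` there, and the old code would have raised `AttributeError` on the `set_mixed_precision` call. A self-contained sketch of the fixed logic (the function name and stub classes are illustrative, not from the source):

```python
import os

class FSDPPluginStub:
    """Stand-in for accelerate's FullyShardedDataParallelPlugin."""

    def set_mixed_precision(self, mixed_precision: str) -> None:
        print(f"FSDP mixed precision set to {mixed_precision!r}")

class CPOnlyConfig:
    """Stand-in config with context parallelism on and nothing else."""

    cp_enabled = True

def resolve_fsdp(fsdp_plugin, parallelism_config, mixed_precision="bf16"):
    # Mirrors the fixed condition: env var, explicit plugin, or context
    # parallelism all select the FSDP code path.
    use_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    cp_enabled = parallelism_config is not None and parallelism_config.cp_enabled
    if (use_fsdp or fsdp_plugin is not None) or cp_enabled:
        # The added guard: entering this branch via cp_enabled alone used to
        # crash here, because fsdp_plugin could still be None.
        if mixed_precision != "no" and fsdp_plugin is not None:
            fsdp_plugin.set_mixed_precision(mixed_precision)
        return fsdp_plugin
    return None

resolve_fsdp(FSDPPluginStub(), None)  # plugin path: sets mixed precision
resolve_fsdp(None, CPOnlyConfig())    # cp-only path: returns None, no crash
```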