Mirror of https://github.com/huggingface/accelerate.git
feat: allow mixed precision policy as dtype (#3751)
* feat: allow mixed precision as dtype
* test: extend test for MP as str dtype
* Fix: style

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Commit a0bc36e8ed, parent 8830e58a91, committed via GitHub.
@@ -1553,10 +1553,12 @@ class FullyShardedDataParallelPlugin:
         backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
             Backward prefetch strategy to use. Should be either a `str` or an instance of
             `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
-        mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
+        mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
             A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
             should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
-            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
+            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
+            should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
+            `reduce_dtype`, and `buffer_dtype`.
         auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
             A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
             of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
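In practice, the new `str` option reduces the setup to a one-liner. Below is a minimal usage sketch, not part of this diff: it assumes `accelerate` with a recent PyTorch is installed, and the `Accelerator` line assumes an FSDP-enabled `accelerate launch` run.

    # Minimal sketch: string vs. dict form of mixed_precision_policy.
    import torch
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    # Passing the dtype as a string now sets param_dtype, reduce_dtype and
    # buffer_dtype in one go, per the docstring above.
    fsdp_plugin = FullyShardedDataParallelPlugin(mixed_precision_policy="bf16")

    # Equivalent, more verbose dict form that was already supported:
    fsdp_plugin_dict = FullyShardedDataParallelPlugin(
        mixed_precision_policy={
            "param_dtype": torch.bfloat16,
            "reduce_dtype": torch.bfloat16,
            "buffer_dtype": torch.bfloat16,
        }
    )

    # Assumes this script is started with an FSDP-enabled `accelerate launch`.
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)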
@@ -1635,6 +1637,7 @@ class FullyShardedDataParallelPlugin:
     mixed_precision_policy: Optional[
         Union[
             dict,
+            str,
             "torch.distributed.fsdp.MixedPrecision",
             "torch.distributed.fsdp.MixedPrecisionPolicy",
         ]
@@ -1926,7 +1929,11 @@ class FullyShardedDataParallelPlugin:
             )
             os.environ[env_var] = str(self.cpu_ram_efficient_loading)

-        if isinstance(self.mixed_precision_policy, dict):
+        if isinstance(self.mixed_precision_policy, str):
+            # override is True since self.mixed_precision_policy is not None
+            # has to be overwritten with the correct mixed precision object
+            self.set_mixed_precision(self.mixed_precision_policy, override=True)
+        elif isinstance(self.mixed_precision_policy, dict):
             self.set_mixed_precision(self.mixed_precision_policy)
         if self.mixed_precision_policy is not None:
             self.validate_mixed_precision_policy()
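For readers unfamiliar with `set_mixed_precision`, the docstring above implies an expansion roughly like the sketch below. This is a hypothetical illustration, not the actual accelerate implementation; in particular, which float8 dtype backs "fp8" is an assumption here.

    import torch
    from torch.distributed.fsdp import MixedPrecision

    # Hypothetical mapping from the accepted strings to torch dtypes.
    _STR_TO_DTYPE = {
        "fp8": torch.float8_e4m3fn,  # assumption: one of torch's float8 variants
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    def policy_from_str(name: str) -> MixedPrecision:
        dtype = _STR_TO_DTYPE[name]
        # Per the docstring, the same dtype is applied to parameters,
        # gradient reduction, and buffers.
        return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)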
@@ -316,6 +316,9 @@ class FSDPPluginIntegration(AccelerateTestCase):
             AcceleratorState._reset_state(True)

             env = self.fsdp_envs[fsdp_version].copy()
+            with patch_environment(**env):
+                plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
+                assert plugin.mixed_precision_policy == mp_policy
             with patch_environment(**env):
                 plugin = FullyShardedDataParallelPlugin(
                     mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}
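Outside the test suite, the same property the new test checks can be verified directly. A standalone sketch, assuming (per the docstring) that the string expands to all three dtypes and that merely constructing the plugin needs no distributed launch:

    import torch
    from accelerate import FullyShardedDataParallelPlugin

    plugin_str = FullyShardedDataParallelPlugin(mixed_precision_policy="fp16")
    plugin_dict = FullyShardedDataParallelPlugin(
        mixed_precision_policy={
            "param_dtype": torch.float16,
            "reduce_dtype": torch.float16,
            "buffer_dtype": torch.float16,
        }
    )
    # Both forms should resolve to the same torch.distributed.fsdp.MixedPrecision policy.
    assert plugin_str.mixed_precision_policy == plugin_dict.mixed_precision_policy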