feat: allow mixed precision policy as dtype (#3751)

* feat: allow mixed precision as dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: allow mixed precision as dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: allow mixed precision as dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* test: extend test for MP as str dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* Fix: style

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Author: Mehant Kammakomati
Date: 2025-09-09 02:59:20 +05:30
Committed by: GitHub
Parent: 8830e58a91
Commit: a0bc36e8ed
2 changed files with 13 additions and 3 deletions


@@ -1553,10 +1553,12 @@ class FullyShardedDataParallelPlugin:
         backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
             Backward prefetch strategy to use. Should be either a `str` or an instance of
             `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
-        mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
+        mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
             A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
             should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
-            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
+            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
+            should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
+            `reduce_dtype`, and `buffer_dtype`.
         auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
             A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
             of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
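
For context (not part of the diff): a minimal usage sketch of the new string form, assuming an FSDP-capable distributed environment is already configured; the value "bf16" is just an arbitrary example of the accepted strings listed in the docstring above.

    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Sketch only: per the docstring above, a plain string sets param_dtype,
    # reduce_dtype, and buffer_dtype to the same dtype.
    fsdp_plugin = FullyShardedDataParallelPlugin(mixed_precision_policy="bf16")
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)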
@@ -1635,6 +1637,7 @@ class FullyShardedDataParallelPlugin:
     mixed_precision_policy: Optional[
         Union[
             dict,
+            str,
             "torch.distributed.fsdp.MixedPrecision",
             "torch.distributed.fsdp.MixedPrecisionPolicy",
         ]
@@ -1926,7 +1929,11 @@ class FullyShardedDataParallelPlugin:
             )
             os.environ[env_var] = str(self.cpu_ram_efficient_loading)
-        if isinstance(self.mixed_precision_policy, dict):
+        if isinstance(self.mixed_precision_policy, str):
+            # override is True since self.mixed_precision_policy is not None
+            # has to be overwritten with the correct mixed precision object
+            self.set_mixed_precision(self.mixed_precision_policy, override=True)
+        elif isinstance(self.mixed_precision_policy, dict):
             self.set_mixed_precision(self.mixed_precision_policy)
         if self.mixed_precision_policy is not None:
             self.validate_mixed_precision_policy()

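The body of `set_mixed_precision` is not shown in this diff; as a rough illustration of what resolving the string form amounts to, the snippet below builds an FSDP1 `MixedPrecision` object from a dtype string. The helper name and dtype table are assumptions for illustration, not accelerate code.

    import torch
    from torch.distributed.fsdp import MixedPrecision

    # Assumed mapping from the accepted strings (see docstring above) to torch dtypes.
    _STR_TO_DTYPE = {
        "fp8": torch.float8_e4m3fn,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }


    def mixed_precision_from_str(mp: str) -> MixedPrecision:
        # Hypothetical helper: the same dtype is applied to all three policy fields,
        # mirroring what the updated docstring describes for the str form.
        dtype = _STR_TO_DTYPE[mp]
        return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)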

@@ -316,6 +316,9 @@ class FSDPPluginIntegration(AccelerateTestCase):
         AcceleratorState._reset_state(True)
         env = self.fsdp_envs[fsdp_version].copy()
         with patch_environment(**env):
+            plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
+            assert plugin.mixed_precision_policy == mp_policy
+        with patch_environment(**env):
             plugin = FullyShardedDataParallelPlugin(
                 mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}