Fix FP8 tests, enable FP8 to be used without directly configuring Accelerator() (#3677)

* single-gpu tests passing

* install deepspeed in fp8 container

* revert mixed_precision check
Author: Peter St. John (committed by GitHub)
Date: 2025-07-15 07:20:57 -06:00
Parent: 6e104f31de
Commit: 847ae58c74
6 changed files with 128 additions and 55 deletions

View File

@@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git
RUN cd accelerate && \
pip install -e . && \
pip install -e .[deepspeed] && \
cd benchmarks/fp8
RUN /bin/bash
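
The only functional change here is installing accelerate's `deepspeed` extra so the FP8 + DeepSpeed test paths can run inside this container. A minimal sanity check one might run in the rebuilt image (the helper `is_deepspeed_available` is from `accelerate.utils` and is not part of this diff):

# Quick check that the deepspeed extra actually landed in the container.
from accelerate.utils import is_deepspeed_available

assert is_deepspeed_available(), "deepspeed missing: rebuild with `pip install -e .[deepspeed]`"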

View File

@@ -11,8 +11,8 @@ fp8_config:
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: (false, false, false)
override_linear_precision: [false, false, false]
# Generally this should always be set to `false` to have the most realistic fp8 eval performance
use_autocast_during_eval: false
# If using MS-AMP, we ignore all of the prior settings and set an opt_level
#opt_level: O1
#opt_level: O1
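
The change above swaps the Python-tuple syntax for a proper YAML list. Config files are read with a YAML loader, which treats `[false, false, false]` as a list of booleans but treats the old parenthesized form as a single opaque string. A minimal sketch of the difference (PyYAML assumed, not part of this PR):

import yaml

# The new flow-sequence form round-trips into a Python list of bools...
good = yaml.safe_load("override_linear_precision: [false, false, false]")
assert good["override_linear_precision"] == [False, False, False]

# ...whereas the old tuple-style form is just a plain string and cannot be unpacked
# into the three fprop/dgrad/wgrad flags.
bad = yaml.safe_load("override_linear_precision: (false, false, false)")
assert isinstance(bad["override_linear_precision"], str)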

View File

@@ -33,6 +33,8 @@ import torch
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards
from accelerate.utils.dataclasses import FP8BackendType
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
@@ -301,6 +303,7 @@ class Accelerator:
self.project_configuration = ProjectConfiguration(project_dir=project_dir)
if project_dir is not None and self.project_dir is None:
self.project_configuration.set_directories(project_dir)
if mixed_precision is not None:
mixed_precision = str(mixed_precision)
if mixed_precision not in PrecisionType:
@@ -458,27 +461,34 @@ class Accelerator:
# Check for automatic FP8 recipe creation
if self.fp8_enabled and not self.has_fp8_handler:
# Prioritize AO -> TE -> MSAMP
if is_torchao_available():
logger.info("Found `torchao` installed, using it for FP8 training.")
if self.fp8_backend == FP8BackendType.AO:
self.ao_recipe_handler = AORecipeKwargs()
elif is_transformer_engine_available():
logger.info("Found `transformer-engine` installed, using it for FP8 training.")
elif self.fp8_backend == FP8BackendType.TE:
self.te_recipe_handler = TERecipeKwargs()
elif is_msamp_available():
logger.info("Found `msamp` installed, using it for FP8 training.")
elif self.fp8_backend == FP8BackendType.MSAMP:
self.msamp_recipe_handler = MSAMPRecipeKwargs()
else:
raise ImportError(
"Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
"Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
)
elif self.fp8_backend == FP8BackendType.NO:
# Prioritize AO -> TE -> MSAMP
if is_torchao_available():
logger.info("Found `torchao` installed, using it for FP8 training.")
self.ao_recipe_handler = AORecipeKwargs()
elif is_transformer_engine_available():
logger.info("Found `transformer-engine` installed, using it for FP8 training.")
self.te_recipe_handler = TERecipeKwargs()
elif is_msamp_available():
logger.info("Found `msamp` installed, using it for FP8 training.")
self.msamp_recipe_handler = MSAMPRecipeKwargs()
else:
raise ImportError(
"Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
"Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
)
self.has_fp8_handler = True
self.delayed_fp8_autocast = False
if self.has_fp8_handler:
# We already check if FP8 is available during `self.state`
if mixed_precision != "fp8" and (
if not self.fp8_enabled and (
self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
):
raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
@@ -488,7 +498,11 @@
)
# TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading:
if (
self.fp8_backend == FP8BackendType.AO
and self.state.distributed_type == DistributedType.FSDP
and self.state.fsdp_plugin.cpu_ram_efficient_loading
):
raise ValueError(
"torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
)
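
For reference, the combination the tightened guard rejects is torchao FP8 together with FSDP2 and `cpu_ram_efficient_loading=True`; keeping the flag at `False` works. The sketch below is illustrative only, not code from this PR: it assumes `FullyShardedDataParallelPlugin` accepts `fsdp_version` and `cpu_ram_efficient_loading` as shown and is meant to run under an `accelerate launch` FSDP setup.

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin

# fsdp_version=2 selects FSDP2; with the torchao FP8 backend, cpu_ram_efficient_loading
# must stay False, otherwise the Accelerator now raises the ValueError shown above.
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2, cpu_ram_efficient_loading=False)
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs()],
    fsdp_plugin=fsdp_plugin,
)
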
@@ -572,7 +586,7 @@
elif self.fp8_enabled:
# We always enable `native_amp` for FP8
self.native_amp = True
if self.fp8_backend == "MSAMP":
if self.fp8_backend == FP8BackendType.MSAMP:
if self.distributed_type == DistributedType.FSDP:
raise NotImplementedError(
"`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ class Accelerator:
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
)
args = self._prepare_ipex(*args)
if self.fp8_backend == "TE":
if self.fp8_backend == FP8BackendType.TE:
args = self._prepare_te(*args)
elif self.fp8_backend == "AO":
elif self.fp8_backend == FP8BackendType.AO:
args = self._prepare_ao(*args)
if self.distributed_type == DistributedType.DEEPSPEED:
result = self._prepare_deepspeed(*args)
@@ -1430,7 +1444,7 @@
elif self.is_fsdp2:
result = self._prepare_fsdp2(*args)
else:
if self.fp8_backend == "MSAMP":
if self.fp8_backend == FP8BackendType.MSAMP:
args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
result = tuple(
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
@@ -1570,7 +1584,7 @@
model._original_forward = model.forward
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
# NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
model_forward_func = model.forward
model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
else:
@@ -1580,7 +1594,7 @@
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
# We prepare TE after, allowing for bf16 autocast to happen first
if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
@@ -1806,7 +1820,7 @@
elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# Now we can apply the FP8 autocast
if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1884,7 +1898,7 @@
import deepspeed
ds_initialize = deepspeed.initialize
if self.fp8_backend == "MSAMP":
if self.fp8_backend == FP8BackendType.MSAMP:
# MS-AMP requires DeepSpeed patches
from msamp import deepspeed as msamp_deepspeed
@@ -2022,7 +2036,7 @@
if model is not None:
# If we are using FP8, we need to apply the autowrap now
if self.fp8_backend == "TE":
if self.fp8_backend == FP8BackendType.TE:
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
# if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2479,7 +2493,7 @@
device_placement = self.device_placement
# NOTE: Special case with MS-AMP: we do *not* pass the scaler explicitly to the `AcceleratedOptimizer`;
# its optimizer handles it for us.
scaler = None if self.fp8_backend == "MSAMP" else self.scaler
scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)
self._optimizers.append(optimizer)
return optimizer
@@ -3668,7 +3682,7 @@
# we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called;
# the underlying pointer is actually hidden in the `_tensor` attribute
if self.fp8_backend == "AO":
if self.fp8_backend == FP8BackendType.AO:
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
@@ -3977,17 +3991,18 @@
)
@property
def fp8_backend(self):
def fp8_backend(self) -> FP8BackendType:
"Returns the configured backend for training in FP8"
if self.has_fp8_handler:
if self.fp8_recipe_handler is not None:
return self.fp8_recipe_handler.backend
return FP8BackendType(self.fp8_recipe_handler.backend)
elif self.ao_recipe_handler is not None:
return "AO"
return FP8BackendType.AO
elif self.te_recipe_handler is not None:
return "TE"
return FP8BackendType.TE
elif self.msamp_recipe_handler is not None:
return "MSAMP"
return FP8BackendType.MSAMP
elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
return "MSAMP"
return None
return FP8BackendType.MSAMP
return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))
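
With the property now typed, call sites compare against `FP8BackendType` members instead of bare strings, and when no handler or DeepSpeed MS-AMP option is set the backend is resolved from the environment. A small sketch of that fallback, using plain `os.environ.get` in place of `parse_choice_from_env`:

import os
from accelerate.utils.dataclasses import FP8BackendType

# With no recipe handler configured, fp8_backend falls back to the
# ACCELERATE_FP8_BACKEND environment variable and defaults to "NO".
os.environ.pop("ACCELERATE_FP8_BACKEND", None)
assert FP8BackendType(os.environ.get("ACCELERATE_FP8_BACKEND", "NO")) is FP8BackendType.NO

os.environ["ACCELERATE_FP8_BACKEND"] = "TE"
assert FP8BackendType(os.environ.get("ACCELERATE_FP8_BACKEND", "NO")) is FP8BackendType.TE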

View File

@@ -616,8 +616,10 @@ class FP8BackendType(str, enum.Enum):
"""
# Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
NO = "NO"
TE = "TE"
MSAMP = "MSAMP"
AO = "AO"
class ComputeEnvironment(str, enum.Enum):
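
Because the new `NO` member keeps the `str` mixin noted in the docstring, code that still compares `fp8_backend` against the old string literals keeps working, and the values serialize without a custom JSON encoder. A quick sanity sketch:

import json
from accelerate.utils.dataclasses import FP8BackendType

assert FP8BackendType.MSAMP == "MSAMP"            # str mixin keeps old string comparisons valid
assert FP8BackendType("AO") is FP8BackendType.AO  # lookup by stored value returns the member
assert json.dumps({"backend": FP8BackendType.TE}) == '{"backend": "TE"}'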

View File

@@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
value = getattr(args, arg)
if value is not None:
if arg == "fp8_override_linear_precision":
current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
else:
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
return current_env
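
The underlying bug: environment values must be strings, but the three `override_linear_precision` entries arrive as booleans once the config is parsed from the YAML list shown earlier, so they are now stringified before being exported. A small sketch of the failure mode (the `ACCELERATE_` prefix is assumed here for illustration):

import os

value = [False, False, False]  # fp8_override_linear_precision parsed from the config

# Assigning a bool into os.environ (or passing it through to the subprocess env)
# raises a TypeError; str() makes the value a legal environment string.
# os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = value[0]   # TypeError: str expected, not bool
os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = str(value[0])  # "False"
os.environ["ACCELERATE_FP8_OVERRIDE_DGRAD"] = str(value[1])
os.environ["ACCELERATE_FP8_OVERRIDE_WGRAD"] = str(value[2])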

View File

@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import tempfile
import textwrap
import unittest
from pathlib import Path
import torch
@@ -32,16 +36,18 @@ from accelerate.test_utils import (
from accelerate.test_utils.testing import require_deepspeed, run_command
from accelerate.utils import (
AORecipeKwargs,
FP8RecipeKwargs,
TERecipeKwargs,
has_ao_layers,
has_transformer_engine_layers,
is_torchao_available,
is_transformer_engine_available,
)
def can_convert_te_model():
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
def can_convert_te_model(from_config=False):
if not from_config:
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [TERecipeKwargs()]}
else:
accelerator_kwargs = {}
accelerator = Accelerator(**accelerator_kwargs)
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
@@ -58,10 +64,14 @@ def maintain_proper_deepspeed_config(expected_version):
)
def can_convert_ao_model():
def can_convert_ao_model(from_config=False):
from transformers import AutoModelForSequenceClassification
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
if not from_config:
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
else:
accelerator_kwargs = {}
accelerator = Accelerator(**accelerator_kwargs)
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
@@ -78,13 +88,31 @@ def can_convert_ao_model():
class TestTransformerEngine(unittest.TestCase):
def test_can_prepare_model_single_gpu(self):
command = get_launch_command(num_processes=1, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
def test_can_prepare_model_single_gpu_from_config(self):
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
backend: TE
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@require_multi_device
def test_can_prepare_model_multi_gpu(self):
command = get_launch_command(num_processes=2, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
@require_deepspeed
@@ -116,7 +144,7 @@ class TestTransformerEngine(unittest.TestCase):
command = get_launch_command(
num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
@@ -125,13 +153,31 @@ class TestTransformerEngine(unittest.TestCase):
class TestTorchAO(unittest.TestCase):
def test_can_prepare_model_single_accelerator(self):
command = get_launch_command(num_processes=1, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_ao"]
run_command(command)
def test_can_prepare_model_single_gpu_from_config(self):
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
backend: AO
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
run_command(command)
@require_multi_device
def test_can_prepare_model_multi_accelerator(self):
command = get_launch_command(num_processes=2, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_ao"]
run_command(command)
@require_deepspeed
@@ -163,16 +209,26 @@ class TestTorchAO(unittest.TestCase):
command = get_launch_command(
num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_ao"]
run_command(command)
if __name__ == "__main__":
# TE suite
if is_transformer_engine_available():
can_convert_te_model()
parser = argparse.ArgumentParser()
parser.add_argument("--test_te", action="store_true", default=False)
parser.add_argument("--test_ao", action="store_true", default=False)
parser.add_argument("--from_config", action="store_true", default=False)
args = parser.parse_args()
if not args.test_te and not args.test_ao:
raise ValueError("Must specify at least one of --test_te or --test_ao")
if args.test_te:
can_convert_te_model(args.from_config)
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
# AO suite
if is_torchao_available():
can_convert_ao_model()
if args.test_ao:
can_convert_ao_model(args.from_config)
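
For completeness, the reworked entry point can also be driven directly; the flags below mirror what `get_launch_command(num_processes=1, monitor_interval=0.1)` expands to in the single-process TE test above (run from the repository root; an `accelerate` executable on PATH is assumed):

import subprocess

# Launch only the TE suite in a single process, as the unittest wrapper does.
subprocess.run(
    [
        "accelerate", "launch",
        "--num_processes", "1",
        "--monitor_interval", "0.1",
        "-m", "tests.test_fp8", "--test_te",
    ],
    check=True,
)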