Fix FP8 tests, enable FP8 to be used without direct Accelerator() configuring (#3677)

* single-gpu tests passing

* install deepspeed in fp8 container

* revert mixed_precision check
Author: Peter St. John
Date: 2025-07-15 07:20:57 -06:00 (committed by GitHub)
Parent: 6e104f31de
Commit: 847ae58c74

6 changed files with 128 additions and 55 deletions
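
The net effect of the change set: `mixed_precision="fp8"` no longer requires wiring a recipe handler into `Accelerator()` by hand; the backend is taken from the launch config (exported as `ACCELERATE_FP8_BACKEND`) or auto-detected from the installed packages (torchao, then transformer-engine, then MS-AMP). A minimal sketch of the new usage, assuming an FP8-capable GPU and at least one of those backends installed:

import torch

from accelerate import Accelerator

# Previously FP8 required passing a recipe handler explicitly, e.g.
# Accelerator(mixed_precision="fp8", kwargs_handlers=[TERecipeKwargs()]).
# With this commit the backend is resolved from the launch config
# (ACCELERATE_FP8_BACKEND) or auto-detected from whichever FP8 library is installed.
accelerator = Accelerator(mixed_precision="fp8")

model = torch.nn.Linear(32, 32)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)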

======================  changed file  ======================

@@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
 RUN git clone https://github.com/huggingface/accelerate.git
 RUN cd accelerate && \
-    pip install -e . && \
+    pip install -e .[deepspeed] && \
     cd benchmarks/fp8

 RUN /bin/bash

======================  changed file  ======================

@@ -11,8 +11,8 @@ fp8_config:
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: (false, false, false)
+  override_linear_precision: [false, false, false]
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
   #opt_level: O1
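
The `override_linear_precision` change is more than cosmetic: YAML has no tuple literal, so `(false, false, false)` parses as a single string while `[false, false, false]` parses as a list of booleans, which is what the launcher expects. A quick check (assumes PyYAML is available):

import yaml

print(yaml.safe_load("override_linear_precision: (false, false, false)"))
# {'override_linear_precision': '(false, false, false)'}  -- one opaque string

print(yaml.safe_load("override_linear_precision: [false, false, false]"))
# {'override_linear_precision': [False, False, False]}    -- three booleans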

======================  changed file  ======================

@@ -33,6 +33,8 @@ import torch
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards

+from accelerate.utils.dataclasses import FP8BackendType
+
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
 from .logging import get_logger
@@ -301,6 +303,7 @@ class Accelerator:
             self.project_configuration = ProjectConfiguration(project_dir=project_dir)
         if project_dir is not None and self.project_dir is None:
             self.project_configuration.set_directories(project_dir)
+
         if mixed_precision is not None:
             mixed_precision = str(mixed_precision)
             if mixed_precision not in PrecisionType:
@@ -458,27 +461,34 @@

         # Check for automatic FP8 recipe creation
         if self.fp8_enabled and not self.has_fp8_handler:
-            # Prioritize AO -> TE -> MSAMP
-            if is_torchao_available():
-                logger.info("Found `torchao` installed, using it for FP8 training.")
+            if self.fp8_backend == FP8BackendType.AO:
                 self.ao_recipe_handler = AORecipeKwargs()
-            elif is_transformer_engine_available():
-                logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+            elif self.fp8_backend == FP8BackendType.TE:
                 self.te_recipe_handler = TERecipeKwargs()
-            elif is_msamp_available():
-                logger.info("Found `msamp` installed, using it for FP8 training.")
+            elif self.fp8_backend == FP8BackendType.MSAMP:
                 self.msamp_recipe_handler = MSAMPRecipeKwargs()
-            else:
-                raise ImportError(
-                    "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
-                    "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
-                )
+            elif self.fp8_backend == FP8BackendType.NO:
+                # Prioritize AO -> TE -> MSAMP
+                if is_torchao_available():
+                    logger.info("Found `torchao` installed, using it for FP8 training.")
+                    self.ao_recipe_handler = AORecipeKwargs()
+                elif is_transformer_engine_available():
+                    logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+                    self.te_recipe_handler = TERecipeKwargs()
+                elif is_msamp_available():
+                    logger.info("Found `msamp` installed, using it for FP8 training.")
+                    self.msamp_recipe_handler = MSAMPRecipeKwargs()
+                else:
+                    raise ImportError(
+                        "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
+                        "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
+                    )
             self.has_fp8_handler = True

         self.delayed_fp8_autocast = False
         if self.has_fp8_handler:
             # We already check if FP8 is available during `self.state`
-            if mixed_precision != "fp8" and (
+            if not self.fp8_enabled and (
                 self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
             ):
                 raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
@@ -488,7 +498,11 @@ class Accelerator:
                 )

             # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
-            if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading:
+            if (
+                self.fp8_backend == FP8BackendType.AO
+                and self.state.distributed_type == DistributedType.FSDP
+                and self.state.fsdp_plugin.cpu_ram_efficient_loading
+            ):
                 raise ValueError(
                     "torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
                 )
@@ -572,7 +586,7 @@ class Accelerator:
         elif self.fp8_enabled:
             # We always enable `native_amp` for FP8
             self.native_amp = True
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 if self.distributed_type == DistributedType.FSDP:
                     raise NotImplementedError(
                         "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ class Accelerator:
                 "You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
             )
             args = self._prepare_ipex(*args)
-        if self.fp8_backend == "TE":
+        if self.fp8_backend == FP8BackendType.TE:
             args = self._prepare_te(*args)
-        elif self.fp8_backend == "AO":
+        elif self.fp8_backend == FP8BackendType.AO:
             args = self._prepare_ao(*args)
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)
@@ -1430,7 +1444,7 @@ class Accelerator:
         elif self.is_fsdp2:
             result = self._prepare_fsdp2(*args)
         else:
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
@@ -1570,7 +1584,7 @@ class Accelerator:
             model._original_forward = model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
             # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
-            if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
+            if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
                 model_forward_func = model.forward
                 model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
             else:
@@ -1580,7 +1594,7 @@ class Accelerator:
                 model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

         # We prepare TE after, allowing for bf16 autocast to happen first
-        if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)

         if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
@@ -1806,7 +1820,7 @@ class Accelerator:
         elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         # Now we can apply the FP8 autocast
-        if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled
         if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1884,7 +1898,7 @@ class Accelerator:
             import deepspeed

         ds_initialize = deepspeed.initialize

-        if self.fp8_backend == "MSAMP":
+        if self.fp8_backend == FP8BackendType.MSAMP:
             # MS-AMP requires DeepSpeed patches
             from msamp import deepspeed as msamp_deepspeed
@@ -2022,7 +2036,7 @@ class Accelerator:
             if model is not None:
                 # If we are using FP8, we need to apply the autowrap now
-                if self.fp8_backend == "TE":
+                if self.fp8_backend == FP8BackendType.TE:
                     model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
                 # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
                 deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2479,7 +2493,7 @@ class Accelerator:
             device_placement = self.device_placement
         # NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,
         # Their optimizer handles it for us.
-        scaler = None if self.fp8_backend == "MSAMP" else self.scaler
+        scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
         optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)
         self._optimizers.append(optimizer)
         return optimizer
@@ -3668,7 +3682,7 @@ class Accelerator:
        # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
        # the underlying pointer is actually hidden in `_tensor` attribute
-        if self.fp8_backend == "AO":
+        if self.fp8_backend == FP8BackendType.AO:
            from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

            accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
@@ -3977,17 +3991,18 @@ class Accelerator:
         )

     @property
-    def fp8_backend(self):
+    def fp8_backend(self) -> FP8BackendType:
         "Returns the configured backend for training in FP8"
         if self.has_fp8_handler:
             if self.fp8_recipe_handler is not None:
-                return self.fp8_recipe_handler.backend
+                return FP8BackendType(self.fp8_recipe_handler.backend)
             elif self.ao_recipe_handler is not None:
-                return "AO"
+                return FP8BackendType.AO
             elif self.te_recipe_handler is not None:
-                return "TE"
+                return FP8BackendType.TE
             elif self.msamp_recipe_handler is not None:
-                return "MSAMP"
+                return FP8BackendType.MSAMP
         elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
-            return "MSAMP"
-        return None
+            return FP8BackendType.MSAMP
+
+        return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))
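
With no handler attached, the property now falls back to the `ACCELERATE_FP8_BACKEND` environment variable (which `accelerate launch` exports from the config's `fp8_config.backend`, see the `setup_fp8_env` hunk below) and otherwise resolves to `FP8BackendType.NO`. A sketch of the equivalent lookup, assuming an accelerate version that includes this commit (the real code goes through accelerate's `parse_choice_from_env` helper):

import os

from accelerate.utils.dataclasses import FP8BackendType

# What `accelerate launch` exports when the config contains `fp8_config: {backend: TE}`:
os.environ["ACCELERATE_FP8_BACKEND"] = "TE"

# Equivalent of the property's new fallback path:
backend = FP8BackendType(os.environ.get("ACCELERATE_FP8_BACKEND", "NO"))
assert backend is FP8BackendType.TE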

======================  changed file  ======================

@@ -616,8 +616,10 @@ class FP8BackendType(str, enum.Enum):
     """

     # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
+    NO = "NO"
     TE = "TE"
     MSAMP = "MSAMP"
+    AO = "AO"


 class ComputeEnvironment(str, enum.Enum):
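
As the in-code comment says, subclassing both `str` and `Enum` keeps the new `NO` and `AO` members JSON-serializable and directly comparable to plain config/env strings. A self-contained illustration using a local copy of the enum (the real one lives in `accelerate.utils.dataclasses`):

import enum
import json


class FP8BackendType(str, enum.Enum):
    # Local copy mirroring the enum in the hunk above.
    NO = "NO"
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"


print(json.dumps({"backend": FP8BackendType.AO}))  # {"backend": "AO"}
print(FP8BackendType.TE == "TE")                   # True
print(FP8BackendType("NO") is FP8BackendType.NO)   # True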

======================  changed file  ======================

@@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
         value = getattr(args, arg)
         if value is not None:
             if arg == "fp8_override_linear_precision":
-                current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
-                current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
-                current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
+                current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
+                current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
+                current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
             else:
                 current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
     return current_env
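
The added `str(...)` calls fix a concrete failure: `--fp8_override_linear_precision` comes back from the launcher as a triple of (typically boolean) values, and environment variables must be strings, so exporting the raw values raises a `TypeError` when the child process environment is built. A minimal reproduction of the difference (variable names are illustrative):

import os

value = (False, False, False)  # the parsed --fp8_override_linear_precision triple

# os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = value[0]     # TypeError: str expected, not bool
os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = str(value[0])  # "False" -- what the fix exports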

======================  changed file  ======================

@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import argparse
 import json
 import os
+import tempfile
+import textwrap
 import unittest
+from pathlib import Path

 import torch
@@ -32,16 +36,18 @@ from accelerate.test_utils import (
 from accelerate.test_utils.testing import require_deepspeed, run_command
 from accelerate.utils import (
     AORecipeKwargs,
-    FP8RecipeKwargs,
+    TERecipeKwargs,
     has_ao_layers,
     has_transformer_engine_layers,
-    is_torchao_available,
-    is_transformer_engine_available,
 )


-def can_convert_te_model():
-    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
+def can_convert_te_model(from_config=False):
+    if not from_config:
+        accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [TERecipeKwargs()]}
+    else:
+        accelerator_kwargs = {}
+
     accelerator = Accelerator(**accelerator_kwargs)

     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
@@ -58,10 +64,14 @@ def maintain_proper_deepspeed_config(expected_version):
     )


-def can_convert_ao_model():
+def can_convert_ao_model(from_config=False):
     from transformers import AutoModelForSequenceClassification

-    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
+    if not from_config:
+        accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
+    else:
+        accelerator_kwargs = {}
+
     accelerator = Accelerator(**accelerator_kwargs)

     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
@@ -78,13 +88,31 @@ def can_convert_ao_model():
 class TestTransformerEngine(unittest.TestCase):
     def test_can_prepare_model_single_gpu(self):
         command = get_launch_command(num_processes=1, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)

+    def test_can_prepare_model_single_gpu_from_config(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "NO"
+                    num_processes: 1
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: TE
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
+            run_command(command)
+
     @require_multi_device
     def test_can_prepare_model_multi_gpu(self):
         command = get_launch_command(num_processes=2, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)

     @require_deepspeed
@@ -116,7 +144,7 @@ class TestTransformerEngine(unittest.TestCase):
         command = get_launch_command(
             num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
         )
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)
@@ -125,13 +153,31 @@ class TestTransformerEngine(unittest.TestCase):
 class TestTorchAO(unittest.TestCase):
     def test_can_prepare_model_single_accelerator(self):
         command = get_launch_command(num_processes=1, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_ao"]
         run_command(command)

+    def test_can_prepare_model_single_gpu_from_config(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "NO"
+                    num_processes: 1
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: AO
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
+            run_command(command)
+
     @require_multi_device
     def test_can_prepare_model_multi_accelerator(self):
         command = get_launch_command(num_processes=2, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_ao"]
         run_command(command)

     @require_deepspeed
@@ -163,16 +209,26 @@ class TestTorchAO(unittest.TestCase):
         command = get_launch_command(
             num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
         )
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_ao"]
         run_command(command)


 if __name__ == "__main__":
     # TE suite
-    if is_transformer_engine_available():
-        can_convert_te_model()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--test_te", action="store_true", default=False)
+    parser.add_argument("--test_ao", action="store_true", default=False)
+    parser.add_argument("--from_config", action="store_true", default=False)
+    args = parser.parse_args()
+
+    if not args.test_te and not args.test_ao:
+        raise ValueError("Must specify at least one of --test_te or --test_ao")
+
+    if args.test_te:
+        can_convert_te_model(args.from_config)
         if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
             maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
     # AO suite
-    if is_torchao_available():
-        can_convert_ao_model()
+    if args.test_ao:
+        can_convert_ao_model(args.from_config)
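
For reference, the command each updated test assembles is roughly the following (a sketch; `get_launch_command` produces the `accelerate launch ...` prefix, and the trailing flags are the new module-level argparse options):

import subprocess

# Single-GPU TE suite; swap "--test_te" for "--test_ao", and add "--from_config"
# (plus --config_file) for the config-file variants.
subprocess.run(
    [
        "accelerate", "launch", "--num_processes", "1", "--monitor_interval", "0.1",
        "-m", "tests.test_fp8", "--test_te",
    ],
    check=True,
)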