Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 10:03:46 +08:00)
Fix FP8 tests, enable FP8 to be used without direct Accelerator() configuring (#3677)

* single-gpu tests passing
* install deepspeed in fp8 container
* revert mixed_precision check
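What this enables, roughly: FP8 no longer has to be configured on the `Accelerator()` call itself; it can come entirely from the accelerate config file, as the new tests below exercise. A minimal sketch (the file name and training-loop bits are illustrative; the config keys mirror the YAML written by the new tests):

    # fp8_config.yaml (mirrors the YAML written by the new tests)
    #   distributed_type: "NO"
    #   num_processes: 1
    #   mixed_precision: fp8
    #   fp8_config:
    #     backend: TE
    # launched with: accelerate launch --config_file fp8_config.yaml train.py
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()  # no mixed_precision / kwargs_handlers passed here
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)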
@@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
 RUN git clone https://github.com/huggingface/accelerate.git
 
 RUN cd accelerate && \
-    pip install -e . && \
+    pip install -e .[deepspeed] && \
     cd benchmarks/fp8
 
 RUN /bin/bash
@@ -11,8 +11,8 @@ fp8_config:
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: (false, false, false)
+  override_linear_precision: [false, false, false]
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
   #opt_level: O1
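Note on the value change above: YAML has no tuple literal, so `(false, false, false)` is parsed as a plain string, while the bracketed flow sequence loads as a list of booleans, which is what the launcher expects. A quick check (assumes PyYAML):

    import yaml

    assert yaml.safe_load("x: (false, false, false)")["x"] == "(false, false, false)"
    assert yaml.safe_load("x: [false, false, false]")["x"] == [False, False, False]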
@@ -33,6 +33,8 @@ import torch
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards
 
+from accelerate.utils.dataclasses import FP8BackendType
+
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
 from .logging import get_logger
@@ -301,6 +303,7 @@ class Accelerator:
             self.project_configuration = ProjectConfiguration(project_dir=project_dir)
         if project_dir is not None and self.project_dir is None:
             self.project_configuration.set_directories(project_dir)
+
         if mixed_precision is not None:
             mixed_precision = str(mixed_precision)
             if mixed_precision not in PrecisionType:
@@ -458,27 +461,34 @@ class Accelerator:
 
         # Check for automatic FP8 recipe creation
         if self.fp8_enabled and not self.has_fp8_handler:
-            # Prioritize AO -> TE -> MSAMP
-            if is_torchao_available():
-                logger.info("Found `torchao` installed, using it for FP8 training.")
-                self.ao_recipe_handler = AORecipeKwargs()
-            elif is_transformer_engine_available():
-                logger.info("Found `transformer-engine` installed, using it for FP8 training.")
-                self.te_recipe_handler = TERecipeKwargs()
-            elif is_msamp_available():
-                logger.info("Found `msamp` installed, using it for FP8 training.")
-                self.msamp_recipe_handler = MSAMPRecipeKwargs()
-            else:
-                raise ImportError(
-                    "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
-                    "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
-                )
+            if self.fp8_backend == FP8BackendType.AO:
+                self.ao_recipe_handler = AORecipeKwargs()
+            elif self.fp8_backend == FP8BackendType.TE:
+                self.te_recipe_handler = TERecipeKwargs()
+            elif self.fp8_backend == FP8BackendType.MSAMP:
+                self.msamp_recipe_handler = MSAMPRecipeKwargs()
+            elif self.fp8_backend == FP8BackendType.NO:
+                # Prioritize AO -> TE -> MSAMP
+                if is_torchao_available():
+                    logger.info("Found `torchao` installed, using it for FP8 training.")
+                    self.ao_recipe_handler = AORecipeKwargs()
+                elif is_transformer_engine_available():
+                    logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+                    self.te_recipe_handler = TERecipeKwargs()
+                elif is_msamp_available():
+                    logger.info("Found `msamp` installed, using it for FP8 training.")
+                    self.msamp_recipe_handler = MSAMPRecipeKwargs()
+                else:
+                    raise ImportError(
+                        "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
+                        "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
+                    )
             self.has_fp8_handler = True
 
         self.delayed_fp8_autocast = False
         if self.has_fp8_handler:
             # We already check if FP8 is available during `self.state`
-            if mixed_precision != "fp8" and (
+            if not self.fp8_enabled and (
                 self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
             ):
                 raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
@@ -488,7 +498,11 @@ class Accelerator:
                 )
 
             # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
-            if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading:
+            if (
+                self.fp8_backend == FP8BackendType.AO
+                and self.state.distributed_type == DistributedType.FSDP
+                and self.state.fsdp_plugin.cpu_ram_efficient_loading
+            ):
                 raise ValueError(
                     "torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
                 )
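If the error above is hit, the workaround it describes would look roughly like the following sketch; the `cpu_ram_efficient_loading` constructor argument of the FSDP plugin is assumed here, only the attribute check itself is part of this diff:

    from accelerate import Accelerator
    from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin(cpu_ram_efficient_loading=False)
    accelerator = Accelerator(
        mixed_precision="fp8",
        fsdp_plugin=fsdp_plugin,
        kwargs_handlers=[AORecipeKwargs()],
    )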
@@ -572,7 +586,7 @@ class Accelerator:
         elif self.fp8_enabled:
             # We always enable `native_amp` for FP8
             self.native_amp = True
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 if self.distributed_type == DistributedType.FSDP:
                     raise NotImplementedError(
                         "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ class Accelerator:
                 "You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
             )
             args = self._prepare_ipex(*args)
-        if self.fp8_backend == "TE":
+        if self.fp8_backend == FP8BackendType.TE:
             args = self._prepare_te(*args)
-        elif self.fp8_backend == "AO":
+        elif self.fp8_backend == FP8BackendType.AO:
             args = self._prepare_ao(*args)
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)
@@ -1430,7 +1444,7 @@ class Accelerator:
         elif self.is_fsdp2:
             result = self._prepare_fsdp2(*args)
         else:
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
@@ -1570,7 +1584,7 @@ class Accelerator:
             model._original_forward = model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
             # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
-            if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
+            if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
                 model_forward_func = model.forward
                 model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
             else:
@@ -1580,7 +1594,7 @@ class Accelerator:
                 model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
 
         # We prepare TE after, allowing for bf16 autocast to happen first
-        if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
 
         if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
@@ -1806,7 +1820,7 @@ class Accelerator:
         elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         # Now we can apply the FP8 autocast
-        if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled
         if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1884,7 +1898,7 @@ class Accelerator:
         import deepspeed
 
         ds_initialize = deepspeed.initialize
-        if self.fp8_backend == "MSAMP":
+        if self.fp8_backend == FP8BackendType.MSAMP:
             # MS-AMP requires DeepSpeed patches
             from msamp import deepspeed as msamp_deepspeed
 
@@ -2022,7 +2036,7 @@ class Accelerator:
 
         if model is not None:
             # If we are using FP8, we need to apply the autowrap now
-            if self.fp8_backend == "TE":
+            if self.fp8_backend == FP8BackendType.TE:
                 model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
             # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
             deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2479,7 +2493,7 @@ class Accelerator:
             device_placement = self.device_placement
         # NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,
         # Their optimizer handles it for us.
-        scaler = None if self.fp8_backend == "MSAMP" else self.scaler
+        scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
         optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)
         self._optimizers.append(optimizer)
         return optimizer
@@ -3668,7 +3682,7 @@ class Accelerator:
 
         # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
        # the underlying pointer is actually hidden in `_tensor` attribute
-        if self.fp8_backend == "AO":
+        if self.fp8_backend == FP8BackendType.AO:
             from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
             accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
@@ -3977,17 +3991,18 @@ class Accelerator:
         )
 
     @property
-    def fp8_backend(self):
+    def fp8_backend(self) -> FP8BackendType:
         "Returns the configured backend for training in FP8"
         if self.has_fp8_handler:
             if self.fp8_recipe_handler is not None:
-                return self.fp8_recipe_handler.backend
+                return FP8BackendType(self.fp8_recipe_handler.backend)
             elif self.ao_recipe_handler is not None:
-                return "AO"
+                return FP8BackendType.AO
             elif self.te_recipe_handler is not None:
-                return "TE"
+                return FP8BackendType.TE
             elif self.msamp_recipe_handler is not None:
-                return "MSAMP"
+                return FP8BackendType.MSAMP
             elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
-                return "MSAMP"
-        return None
+                return FP8BackendType.MSAMP
+
+        return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))
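With the fallback added above, `fp8_backend` now always resolves to an `FP8BackendType` member: an explicit handler wins, otherwise the `ACCELERATE_FP8_BACKEND` environment variable is parsed and defaults to `NO` instead of the old `None`. A minimal sketch of that last step (assuming this branch of accelerate is installed):

    import os

    from accelerate.utils.dataclasses import FP8BackendType

    # mirrors parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO")
    backend = FP8BackendType(os.environ.get("ACCELERATE_FP8_BACKEND", "NO"))
    assert backend in FP8BackendType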
@@ -616,8 +616,10 @@ class FP8BackendType(str, enum.Enum):
     """
 
     # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
+    NO = "NO"
     TE = "TE"
     MSAMP = "MSAMP"
+    AO = "AO"
 
 
 class ComputeEnvironment(str, enum.Enum):
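Since the class subclasses `str` (per the comment kept above), the new `NO` and `AO` members behave like plain strings: they compare equal to their values and JSON-serialize without a custom encoder, e.g.:

    import json

    from accelerate.utils.dataclasses import FP8BackendType

    assert FP8BackendType.AO == "AO"
    assert json.dumps({"backend": FP8BackendType.NO}) == '{"backend": "NO"}'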
@@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
         value = getattr(args, arg)
         if value is not None:
             if arg == "fp8_override_linear_precision":
-                current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
-                current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
-                current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
+                current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
+                current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
+                current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
             else:
                 current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
     return current_env
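The `str()` wrapping above is the actual fix: `fp8_override_linear_precision` is parsed as a tuple of booleans, and environment mappings only accept string values, so forwarding the raw bools fails when the launcher builds the child process environment. Roughly (the `ACCELERATE_` prefix is assumed here for illustration):

    import os

    value = (False, False, False)  # what the launcher parses for the flag
    # os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = value[0]   # TypeError: str expected, not bool
    os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = str(value[0])  # stores "False"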
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
 import json
 import os
+import tempfile
+import textwrap
 import unittest
+from pathlib import Path
 
 import torch
 
@@ -32,16 +36,18 @@ from accelerate.test_utils import (
 from accelerate.test_utils.testing import require_deepspeed, run_command
 from accelerate.utils import (
     AORecipeKwargs,
-    FP8RecipeKwargs,
+    TERecipeKwargs,
     has_ao_layers,
     has_transformer_engine_layers,
-    is_torchao_available,
-    is_transformer_engine_available,
 )
 
 
-def can_convert_te_model():
-    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
+def can_convert_te_model(from_config=False):
+    if not from_config:
+        accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [TERecipeKwargs()]}
+    else:
+        accelerator_kwargs = {}
+
     accelerator = Accelerator(**accelerator_kwargs)
     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
@@ -58,10 +64,14 @@ def maintain_proper_deepspeed_config(expected_version):
     )
 
 
-def can_convert_ao_model():
+def can_convert_ao_model(from_config=False):
     from transformers import AutoModelForSequenceClassification
 
-    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
+    if not from_config:
+        accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
+    else:
+        accelerator_kwargs = {}
+
     accelerator = Accelerator(**accelerator_kwargs)
     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
@@ -78,13 +88,31 @@ def can_convert_ao_model():
 class TestTransformerEngine(unittest.TestCase):
     def test_can_prepare_model_single_gpu(self):
         command = get_launch_command(num_processes=1, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)
 
+    def test_can_prepare_model_single_gpu_from_config(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "NO"
+                    num_processes: 1
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: TE
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
+            run_command(command)
+
     @require_multi_device
     def test_can_prepare_model_multi_gpu(self):
         command = get_launch_command(num_processes=2, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)
 
     @require_deepspeed
@@ -116,7 +144,7 @@ class TestTransformerEngine(unittest.TestCase):
         command = get_launch_command(
             num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
         )
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)
 
 
@@ -125,13 +153,31 @@ class TestTransformerEngine(unittest.TestCase):
 class TestTorchAO(unittest.TestCase):
     def test_can_prepare_model_single_accelerator(self):
         command = get_launch_command(num_processes=1, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_ao"]
         run_command(command)
 
+    def test_can_prepare_model_single_gpu_from_config(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "NO"
+                    num_processes: 1
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: AO
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
+            run_command(command)
+
     @require_multi_device
     def test_can_prepare_model_multi_accelerator(self):
         command = get_launch_command(num_processes=2, monitor_interval=0.1)
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_ao"]
         run_command(command)
 
     @require_deepspeed
@@ -163,16 +209,26 @@ class TestTorchAO(unittest.TestCase):
         command = get_launch_command(
             num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
         )
-        command += ["-m", "tests.test_fp8"]
+        command += ["-m", "tests.test_fp8", "--test_ao"]
         run_command(command)
 
 
 if __name__ == "__main__":
     # TE suite
-    if is_transformer_engine_available():
-        can_convert_te_model()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--test_te", action="store_true", default=False)
+    parser.add_argument("--test_ao", action="store_true", default=False)
+    parser.add_argument("--from_config", action="store_true", default=False)
+    args = parser.parse_args()
+
+    if not args.test_te and not args.test_ao:
+        raise ValueError("Must specify at least one of --test_te or --test_ao")
+
+    if args.test_te:
+        can_convert_te_model(args.from_config)
         if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
             maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
+
     # AO suite
-    if is_torchao_available():
-        can_convert_ao_model()
+    if args.test_ao:
+        can_convert_ao_model(args.from_config)
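With the new argparse entry point, invoking the module requires picking at least one suite; the tests drive it through accelerate's own helpers along these lines:

    from accelerate.test_utils import get_launch_command
    from accelerate.test_utils.testing import run_command

    command = get_launch_command(num_processes=1, monitor_interval=0.1)
    command += ["-m", "tests.test_fp8", "--test_ao"]  # or "--test_te"
    # append "--from_config" to exercise the config-file path added in this commit
    run_command(command)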