Do not import transformer_engine on import (#3056)

* Do not import `transformer_engine` on import
* fix message
* add test
* Update test_imports.py
* resolve comment 1/2
* resolve comment 1.5/2
* lint
* more lint
* Update tests/test_imports.py
  Co-authored-by: Zach Mueller <muellerzr@gmail.com>
* fmt

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
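In short, the module-level `import transformer_engine.pytorch as te` (guarded by `is_fp8_available()`) is replaced by function-local imports, so `import accelerate` alone never touches `transformer_engine`. A minimal self-contained sketch of the pattern, where `importlib.util.find_spec` stands in for the real `is_fp8_available` helper from `accelerate.utils.imports` (which performs additional checks):

    import importlib.util


    def is_fp8_available():
        # Simplified stand-in for accelerate.utils.imports.is_fp8_available:
        # find_spec locates the package without executing its module code.
        return importlib.util.find_spec("transformer_engine") is not None


    def convert_model(model):
        # Fail loudly at call time if the optional dependency is missing...
        if not is_fp8_available():
            raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
        # ...and defer the heavy (and potentially circular) import until
        # the helper is actually used.
        import transformer_engine.pytorch as te  # noqa: F401

The same guard-plus-local-import shape is applied to each helper in the diff below.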
src/accelerate/utils/transformer_engine.py:

@@ -20,8 +20,7 @@ from .imports import is_fp8_available
 from .operations import GatheredParameters
 
 
-if is_fp8_available():
-    import transformer_engine.pytorch as te
+# Do not import `transformer_engine` at package level to avoid potential issues
 
 
 def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
@@ -30,6 +29,8 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
     """
     if not is_fp8_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
+    import transformer_engine.pytorch as te
+
     for name, module in model.named_children():
         if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
             has_bias = module.bias is not None
@@ -87,6 +88,8 @@ def has_transformer_engine_layers(model):
     """
     if not is_fp8_available():
         raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
+    import transformer_engine.pytorch as te
+
     for m in model.modules():
         if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
             return True
@@ -98,6 +101,8 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
     Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
     disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
     """
     if not is_fp8_available():
         raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
+    from transformer_engine.pytorch import fp8_autocast
+
     def forward(self, *args, **kwargs):
@@ -115,7 +120,8 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
     """
     Applies FP8 context manager to the model's forward method
     """
-    # Import here to keep base imports fast
+    if not is_fp8_available():
+        raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
     import transformer_engine.common.recipe as te_recipe
 
     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
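The effect, in a hypothetical session where transformer_engine is not installed (the import succeeds; only the call fails):

    import torch.nn as nn

    # Importing the module itself no longer triggers a transformer_engine import.
    from accelerate.utils.transformer_engine import convert_model

    try:
        convert_model(nn.Linear(4, 4))  # the dependency is only required here
    except ImportError as err:
        print(err)  # -> Using `convert_model` requires transformer_engine to be installed.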
tests/test_imports.py:

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import subprocess
+import sys
 
+from accelerate.test_utils import require_transformer_engine
 from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
 from accelerate.utils import is_import_timer_available
 
@@ -31,7 +33,7 @@ def convert_list_to_string(data):
 
 
 def run_import_time(command: str):
-    output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
+    output = subprocess.run([sys.executable, "-X", "importtime", "-c", command], capture_output=True, text=True)
     return output.stderr
 
 
@@ -81,3 +83,18 @@ class ImportSpeedTester(TempDirTestCase):
         paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
         err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
         self.assertLess(pct_more, 20, err_msg)
+
+
+@require_transformer_engine
+class LazyImportTester(TempDirTestCase):
+    """
+    Test suite which checks if specific packages are lazy-loaded.
+
+    Eager-import will trigger circular import in some case,
+    e.g. in huggingface/accelerate#3056.
+    """
+
+    def test_te_import(self):
+        output = run_import_time("import accelerate, accelerate.utils.transformer_engine")
+
+        self.assertFalse(" transformer_engine" in output, "`transformer_engine` should not be imported on import")
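The new test leans on CPython's `-X importtime` flag, which writes one timing line per imported module to stderr; `run_import_time` also now uses `sys.executable` instead of a hard-coded `python3`, so the child process runs under the same interpreter as the test suite. A minimal standalone version of the same check:

    import subprocess
    import sys

    # Launch a child interpreter with per-import timing enabled; the report
    # ("import time: <self us> | <cumulative us> | <module>") goes to stderr.
    result = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", "import accelerate"],
        capture_output=True,
        text=True,
    )

    # Match with a leading space, as the test above does, so only the module
    # name field (which follows whitespace in the report) can match.
    assert " transformer_engine" not in result.stderr, "transformer_engine was imported eagerly"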