Do not import transformer_engine on import (#3056)

* Do not import `transformer_engine` on import

* fix message

* add test

* Update test_imports.py

* resolve comment 1/2

* resolve comment 1.5/2

* lint

* more lint

* Update tests/test_imports.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fmt

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Author: Yichen Yan
Committed by: GitHub
Date: 2024-08-28 21:06:13 +08:00
Commit: 3fcc9461c4 (parent: 939ce400cb)
2 changed files with 27 additions and 4 deletions
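
For context, the pattern this commit applies is the standard deferred ("lazy") import: the optional heavy dependency is imported inside the functions that need it rather than at module import time. Below is a minimal sketch of that pattern under assumed names; `_is_fp8_available` and `convert_model_sketch` are illustrative stand-ins, not accelerate's actual API.

```python
# Illustrative sketch of the deferred-import pattern (hypothetical names).

def _is_fp8_available() -> bool:
    # Stand-in for accelerate's `is_fp8_available` check.
    try:
        import transformer_engine  # noqa: F401
    except ImportError:
        return False
    return True


def convert_model_sketch(model):
    """Import the heavy optional dependency only when the function runs."""
    if not _is_fp8_available():
        raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
    # Deferred: this line never executes at `import accelerate` time.
    import transformer_engine.pytorch as te

    return te  # placeholder for the real conversion logic
```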

src/accelerate/utils/transformer_engine.py

@@ -20,8 +20,7 @@ from .imports import is_fp8_available
 from .operations import GatheredParameters
 
-if is_fp8_available():
-    import transformer_engine.pytorch as te
+# Do not import `transformer_engine` at package level to avoid potential issues
 
 
 def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
@@ -30,6 +29,8 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
     """
     if not is_fp8_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
+    import transformer_engine.pytorch as te
+
     for name, module in model.named_children():
         if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
             has_bias = module.bias is not None
@@ -87,6 +88,8 @@ def has_transformer_engine_layers(model):
     """
     if not is_fp8_available():
         raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
+    import transformer_engine.pytorch as te
+
     for m in model.modules():
         if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
             return True
@@ -98,6 +101,8 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
     Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
     disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
     """
+    if not is_fp8_available():
+        raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
     from transformer_engine.pytorch import fp8_autocast
 
     def forward(self, *args, **kwargs):
@@ -115,7 +120,8 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
     """
     Applies FP8 context manager to the model's forward method
     """
-    # Import here to keep base imports fast
+    if not is_fp8_available():
+        raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
     import transformer_engine.common.recipe as te_recipe
 
     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
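
With every helper now guarding and then importing locally, `import accelerate` should no longer pull in `transformer_engine` at all. A quick interpreter-level check of that property might look like the sketch below (run it in a fresh interpreter for a meaningful result):

```python
import sys

import accelerate  # noqa: F401

# After this commit, the optional dependency should be absent from
# sys.modules unless something imported it explicitly.
eagerly_loaded = [m for m in sys.modules if m.startswith("transformer_engine")]
print(eagerly_loaded)  # expected: []
```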

tests/test_imports.py

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import subprocess
+import sys
 
+from accelerate.test_utils import require_transformer_engine
 from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
 from accelerate.utils import is_import_timer_available
@@ -31,7 +33,7 @@ def convert_list_to_string(data):
 
 
 def run_import_time(command: str):
-    output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
+    output = subprocess.run([sys.executable, "-X", "importtime", "-c", command], capture_output=True, text=True)
     return output.stderr
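
Switching from a hard-coded `python3` to `sys.executable` makes the child process run under the same interpreter (and virtual environment) as the test suite itself. A standalone sketch of the same technique:

```python
import subprocess
import sys

# Use the *current* interpreter so the check works inside virtualenvs and on
# systems where `python3` resolves to a different installation.
result = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import json"],
    capture_output=True,
    text=True,
)
# `-X importtime` writes its per-module timing report to stderr.
print(result.stderr.splitlines()[-1])
```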
@@ -81,3 +83,18 @@ class ImportSpeedTester(TempDirTestCase):
         paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
         err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
         self.assertLess(pct_more, 20, err_msg)
+
+
+@require_transformer_engine
+class LazyImportTester(TempDirTestCase):
+    """
+    Test suite that checks whether specific packages are lazy-loaded.
+
+    An eager import can trigger a circular import in some cases,
+    e.g. huggingface/accelerate#3056.
+    """
+
+    def test_te_import(self):
+        output = run_import_time("import accelerate, accelerate.utils.transformer_engine")
+        self.assertFalse(" transformer_engine" in output, "`transformer_engine` should not be imported on import")
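
The assertion in `test_te_import` works because `-X importtime` emits one stderr line per imported module; the leading space in `" transformer_engine"` keeps the check from matching accelerate's own `accelerate.utils.transformer_engine` entry. A small sketch of that report format and how module names could be pulled out of it:

```python
# `python -X importtime` prints a header plus one line per imported module:
sample = (
    "import time: self [us] | cumulative | imported package\n"
    "import time:       283 |        283 |   _io\n"
    "import time:      1450 |       9321 | accelerate\n"
)
# The module name is the last `|`-separated column.
modules = [line.rsplit("|", 1)[-1].strip() for line in sample.splitlines()[1:]]
print(modules)  # ['_io', 'accelerate'] -- no transformer_engine entry
```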