Mirror of https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
[kernels] refactor function kernel calling (#41577)
* refactor function kernel calling
* nit
* don't pass the mapping
* use _kernels_available
* rm import
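In short, model files no longer keep their own `_lazy_load_causal_conv1d()` cache; they ask the shared `lazy_load_kernel` helper for the hub kernel module and unpack the callables themselves. A minimal sketch of the new calling pattern (taken from the diff below; the `(None, None)` fallback preserves the existing slow path):

from transformers.integrations.hub_kernels import lazy_load_kernel

# Resolve the hub kernel module once; returns None when neither the `kernels`
# package nor a local causal-conv1d install can provide it.
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
    (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
    if causal_conv1d is not None
    else (None, None)  # slow-path fallback, same behaviour as before the refactor
)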
@@ -14,12 +14,16 @@
import re
from collections.abc import Callable
from functools import partial
from types import ModuleType
from typing import Optional, Union

from ..modeling_flash_attention_utils import lazy_import_flash_attention
from ..utils import logging
from .flash_attention import flash_attention_forward


logger = logging.get_logger(__name__)


try:
    from kernels import (
        Device,
@@ -158,6 +162,13 @@ except ImportError:
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")


_HUB_KERNEL_MAPPING: dict[str, str] = {
    "causal-conv1d": "kernels-community/causal-conv1d",
}

_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}


def is_kernel(attn_implementation: Optional[str]) -> bool:
    """Check whether `attn_implementation` matches a kernel pattern from the hub."""
    return (
@@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])


def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
        return mapping[kernel_name]
    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None
    if _kernels_available:
        from kernels import get_kernel

        try:
            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
            mapping[kernel_name] = kernel
        except FileNotFoundError:
            mapping[kernel_name] = None

    else:
        # Try to import is_{kernel_name}_available from ..utils
        import importlib

        new_kernel_name = kernel_name.replace("-", "_")
        func_name = f"is_{new_kernel_name}_available"

        try:
            utils_mod = importlib.import_module("..utils.import_utils", __package__)
            is_kernel_available = getattr(utils_mod, func_name, None)
        except Exception:
            is_kernel_available = None

        if callable(is_kernel_available) and is_kernel_available():
            # Try to import the module "{kernel_name}" from parent package level
            try:
                module = importlib.import_module(f"{kernel_name}")
                mapping[kernel_name] = module
                return module
            except Exception:
                mapping[kernel_name] = None
        else:
            mapping[kernel_name] = None

    return mapping[kernel_name]


__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "lazy_load_kernel",
]
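For context, an illustrative sketch of how the new helper is driven; this is not part of the diff, the `"my-kernel"` / `"my-org/my-kernel"` entry is hypothetical, and the import path assumes the hunk above lives in transformers/integrations/hub_kernels.py (as the model-side imports below suggest):

from transformers.integrations import hub_kernels

# lazy_load_kernel caches its result in _KERNEL_MODULE_MAPPING:
#   1. return the cached module if one was already loaded,
#   2. otherwise fetch it with kernels.get_kernel() when `kernels` is installed,
#   3. otherwise fall back to importing the module named after the kernel,
#      guarded by the matching is_<name>_available() check in utils.import_utils,
#   4. cache None (slow path) when nothing is available.
hub_kernels._HUB_KERNEL_MAPPING["my-kernel"] = "my-org/my-kernel"  # hypothetical entry
module = hub_kernels.lazy_load_kernel("my-kernel")
if module is None:
    print("kernel unavailable, falling back to the slow path")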
@@ -30,12 +30,11 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import (
    is_causal_conv1d_available,
    is_kernels_available,
    is_mamba_ssm_available,
    is_mambapy_available,
)
@@ -162,33 +161,6 @@ class FalconMambaCache:
        self.ssm_states[layer_idx].zero_()


def _lazy_load_causal_conv1d():
    global _causal_conv1d_cache
    if _causal_conv1d_cache is not None:
        return _causal_conv1d_cache

    if is_kernels_available():
        from kernels import get_kernel

        try:
            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
        except FileNotFoundError:
            # no kernel binary match, fallback to slow path
            _causal_conv1d_cache = (None, None)
        else:
            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
    elif is_causal_conv1d_available():
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
    else:
        _causal_conv1d_cache = (None, None)
    return _causal_conv1d_cache


_causal_conv1d_cache = None


def rms_forward(hidden_states, variance_epsilon=1e-6):
    """
    Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -268,7 +240,12 @@ class FalconMambaMixer(nn.Module):
        self.rms_eps = config.mixer_rms_eps

    def warn_slow_implementation(self):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@@ -323,7 +300,12 @@ class FalconMambaMixer(nn.Module):
            )

        else:
            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
            causal_conv1d = lazy_load_kernel("causal-conv1d")
            causal_conv1d_update, causal_conv1d_fn = (
                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
                if causal_conv1d is not None
                else (None, None)
            )
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
@@ -518,7 +500,12 @@ class FalconMambaMixer(nn.Module):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@@ -19,6 +19,7 @@ from typing import Optional
import torch
from torch import nn

from ...integrations.hub_kernels import lazy_load_kernel
from ...utils import auto_docstring, logging
from ...utils.import_utils import (
    is_mamba_ssm_available,
@@ -35,7 +36,6 @@ from ..mamba.modeling_mamba import (
    MambaOutput,
    MambaPreTrainedModel,
    MambaRMSNorm,
    _lazy_load_causal_conv1d,
)
@@ -54,8 +54,6 @@ if is_mamba_ssm_available():
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

_causal_conv1d_cache = None


class FalconMambaConfig(MambaConfig):
    """
@@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):

class FalconMambaMixer(MambaMixer):
    def warn_slow_implementation(self):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@@ -324,7 +327,12 @@ class FalconMambaMixer(MambaMixer):
            )

        else:
            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
            causal_conv1d = lazy_load_kernel("causal-conv1d")
            causal_conv1d_update, causal_conv1d_fn = (
                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
                if causal_conv1d is not None
                else (None, None)
            )
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
@@ -518,7 +526,12 @@ class FalconMambaMixer(MambaMixer):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -33,8 +34,6 @@ from ...utils import (
    logging,
)
from ...utils.import_utils import (
    is_causal_conv1d_available,
    is_kernels_available,
    is_mamba_ssm_available,
    is_mambapy_available,
)
@@ -54,32 +53,6 @@ if is_mamba_ssm_available():
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

_causal_conv1d_cache = None


def _lazy_load_causal_conv1d():
    global _causal_conv1d_cache
    if _causal_conv1d_cache is not None:
        return _causal_conv1d_cache

    if is_kernels_available():
        from kernels import get_kernel

        try:
            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
        except FileNotFoundError:
            # no kernel binary match, fallback to slow path
            _causal_conv1d_cache = (None, None)
        else:
            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
    elif is_causal_conv1d_available():
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
    else:
        _causal_conv1d_cache = (None, None)
    return _causal_conv1d_cache


class MambaCache:
    """
@@ -236,7 +209,12 @@ class MambaMixer(nn.Module):
        self.warn_slow_implementation()

    def warn_slow_implementation(self):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@@ -287,7 +265,12 @@ class MambaMixer(nn.Module):
            )

        else:
            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
            causal_conv1d = lazy_load_kernel("causal-conv1d")
            causal_conv1d_update, causal_conv1d_fn = (
                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
                if causal_conv1d is not None
                else (None, None)
            )
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
@@ -451,7 +434,12 @@ class MambaMixer(nn.Module):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )