Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)
[kernels] refactor function kernel calling (#41577)
* refactor function kernel calling
* nit
* don't pass the mapping
* use _kernels_available
* rm import
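In short, call sites that previously unpacked a per-model helper

    causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()

now go through the shared integration helper and unpack its result, falling back to the slow path when no kernel module is found (this is exactly the pattern the hunks below introduce):

    causal_conv1d = lazy_load_kernel("causal-conv1d")
    causal_conv1d_update, causal_conv1d_fn = (
        (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
        if causal_conv1d is not None
        else (None, None)
    )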
@@ -14,12 +14,16 @@
 import re
 from collections.abc import Callable
 from functools import partial
+from types import ModuleType
 from typing import Optional, Union

 from ..modeling_flash_attention_utils import lazy_import_flash_attention
+from ..utils import logging
 from .flash_attention import flash_attention_forward


+logger = logging.get_logger(__name__)
+
 try:
     from kernels import (
         Device,
@@ -158,6 +162,13 @@ except ImportError:
         raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")


+_HUB_KERNEL_MAPPING: dict[str, str] = {
+    "causal-conv1d": "kernels-community/causal-conv1d",
+}
+
+_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
+
+
 def is_kernel(attn_implementation: Optional[str]) -> bool:
     """Check whether `attn_implementation` matches a kernel pattern from the hub."""
     return (
@@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
     ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])


+def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
+    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
+        return mapping[kernel_name]
+    if kernel_name not in _HUB_KERNEL_MAPPING:
+        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
+        mapping[kernel_name] = None
+        return None
+    if _kernels_available:
+        from kernels import get_kernel
+
+        try:
+            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
+            mapping[kernel_name] = kernel
+        except FileNotFoundError:
+            mapping[kernel_name] = None
+
+    else:
+        # Try to import is_{kernel_name}_available from ..utils
+        import importlib
+
+        new_kernel_name = kernel_name.replace("-", "_")
+        func_name = f"is_{new_kernel_name}_available"
+
+        try:
+            utils_mod = importlib.import_module("..utils.import_utils", __package__)
+            is_kernel_available = getattr(utils_mod, func_name, None)
+        except Exception:
+            is_kernel_available = None
+
+        if callable(is_kernel_available) and is_kernel_available():
+            # Try to import the module "{kernel_name}" from parent package level
+            try:
+                module = importlib.import_module(f"{kernel_name}")
+                mapping[kernel_name] = module
+                return module
+            except Exception:
+                mapping[kernel_name] = None
+        else:
+            mapping[kernel_name] = None
+
+    return mapping[kernel_name]
+
+
 __all__ = [
     "LayerRepository",
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
+    "lazy_load_kernel",
 ]
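For reference, a short usage sketch of the new helper. The absolute import path is inferred from the relative imports added in the model files below; the behavior summary in the comments is taken from the function above.

    from transformers.integrations.hub_kernels import lazy_load_kernel

    # Resolution order implemented by lazy_load_kernel:
    #   1. return the module cached in _KERNEL_MODULE_MAPPING if it was already loaded
    #   2. unknown kernel names: log a warning, cache None, return None
    #   3. if the `kernels` package is installed: kernels.get_kernel(_HUB_KERNEL_MAPPING[name]),
    #      caching None on FileNotFoundError (no kernel binary for this platform)
    #   4. otherwise: check is_{name}_available() from ..utils.import_utils and import the
    #      local package of the same name, caching None if that fails
    causal_conv1d = lazy_load_kernel("causal-conv1d")  # module-like object, or None for the slow path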
@@ -30,12 +30,11 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -162,33 +161,6 @@ class FalconMambaCache:
         self.ssm_states[layer_idx].zero_()


-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-    return _causal_conv1d_cache
-
-
-_causal_conv1d_cache = None
-
-
 def rms_forward(hidden_states, variance_epsilon=1e-6):
     """
     Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -268,7 +240,12 @@ class FalconMambaMixer(nn.Module):
         self.rms_eps = config.mixer_rms_eps

     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -323,7 +300,12 @@ class FalconMambaMixer(nn.Module):
             )

         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -518,7 +500,12 @@ class FalconMambaMixer(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -19,6 +19,7 @@ from typing import Optional
 import torch
 from torch import nn

+from ...integrations.hub_kernels import lazy_load_kernel
 from ...utils import auto_docstring, logging
 from ...utils.import_utils import (
     is_mamba_ssm_available,
@@ -35,7 +36,6 @@ from ..mamba.modeling_mamba import (
     MambaOutput,
     MambaPreTrainedModel,
     MambaRMSNorm,
-    _lazy_load_causal_conv1d,
 )


@@ -54,8 +54,6 @@ if is_mamba_ssm_available():
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-_causal_conv1d_cache = None
-

 class FalconMambaConfig(MambaConfig):
     """
@@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):

 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -324,7 +327,12 @@ class FalconMambaMixer(MambaMixer):
             )

         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -518,7 +526,12 @@ class FalconMambaMixer(MambaMixer):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -33,8 +34,6 @@ from ...utils import (
     logging,
 )
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -54,32 +53,6 @@ if is_mamba_ssm_available():
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-_causal_conv1d_cache = None
-
-
-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-    return _causal_conv1d_cache
-

 class MambaCache:
     """
@@ -236,7 +209,12 @@ class MambaMixer(nn.Module):
         self.warn_slow_implementation()

     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -287,7 +265,12 @@ class MambaMixer(nn.Module):
             )

         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -451,7 +434,12 @@ class MambaMixer(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )