[kernels] refactor function kernel calling (#41577)

* refactor function kernel calling

* nit

* don't pass the mapping

* use _kernels_available

* rm import
Author: Mohamed Mekkouri
Date: 2025-10-16 15:43:02 +02:00
Committed by: GitHub
Parent: 9176af574a
Commit: 1fb3fc4db0
4 changed files with 112 additions and 69 deletions
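In short: the per-model `_lazy_load_causal_conv1d` helpers duplicated in `modeling_mamba.py` and `modeling_falcon_mamba.py` are replaced by a single `lazy_load_kernel` helper plus a `_HUB_KERNEL_MAPPING` table in the hub-kernels integration. Condensed from the diffs below, each call site changes roughly like this:

    # before: each model kept its own cached loader
    causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()

    # after: one shared helper returning the kernel module (or None)
    causal_conv1d = lazy_load_kernel("causal-conv1d")
    causal_conv1d_update, causal_conv1d_fn = (
        (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
        if causal_conv1d is not None
        else (None, None)
    )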

src/transformers/integrations/hub_kernels.py

@@ -14,12 +14,16 @@
 import re
 from collections.abc import Callable
 from functools import partial
+from types import ModuleType
 from typing import Optional, Union
 
 from ..modeling_flash_attention_utils import lazy_import_flash_attention
+from ..utils import logging
 from .flash_attention import flash_attention_forward
 
+logger = logging.get_logger(__name__)
+
 try:
     from kernels import (
         Device,
@@ -158,6 +162,13 @@ except ImportError:
         raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
 
 
+_HUB_KERNEL_MAPPING: dict[str, str] = {
+    "causal-conv1d": "kernels-community/causal-conv1d",
+}
+
+_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
+
+
 def is_kernel(attn_implementation: Optional[str]) -> bool:
     """Check whether `attn_implementation` matches a kernel pattern from the hub."""
     return (
@@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
         ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
 
 
+def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
+    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
+        return mapping[kernel_name]
+
+    if kernel_name not in _HUB_KERNEL_MAPPING:
+        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
+        mapping[kernel_name] = None
+        return None
+
+    if _kernels_available:
+        from kernels import get_kernel
+
+        try:
+            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
+            mapping[kernel_name] = kernel
+        except FileNotFoundError:
+            mapping[kernel_name] = None
+    else:
+        # Try to import is_{kernel_name}_available from ..utils
+        import importlib
+
+        new_kernel_name = kernel_name.replace("-", "_")
+        func_name = f"is_{new_kernel_name}_available"
+        try:
+            utils_mod = importlib.import_module("..utils.import_utils", __package__)
+            is_kernel_available = getattr(utils_mod, func_name, None)
+        except Exception:
+            is_kernel_available = None
+
+        if callable(is_kernel_available) and is_kernel_available():
+            # Import the installed package; the module name uses underscores, not hyphens
+            try:
+                module = importlib.import_module(new_kernel_name)
+                mapping[kernel_name] = module
+                return module
+            except Exception:
+                mapping[kernel_name] = None
+        else:
+            mapping[kernel_name] = None
+
+    return mapping[kernel_name]
+
+
 __all__ = [
     "LayerRepository",
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
+    "lazy_load_kernel",
 ]
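A minimal usage sketch for the new helper (not part of the commit; assumes a transformers install containing this change). Once a module has been resolved, repeated lookups are served from `_KERNEL_MODULE_MAPPING`:

    from transformers.integrations.hub_kernels import lazy_load_kernel

    # Resolves via the `kernels` hub package, falling back to a locally
    # installed causal_conv1d build; returns a module or None.
    causal_conv1d = lazy_load_kernel("causal-conv1d")

    # A second call returns the same cached module (or None again).
    assert lazy_load_kernel("causal-conv1d") is causal_conv1d

    # Names missing from _HUB_KERNEL_MAPPING log a warning and yield None.
    assert lazy_load_kernel("definitely-not-a-kernel") is None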

src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

@@ -30,12 +30,11 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -162,33 +161,6 @@ class FalconMambaCache:
         self.ssm_states[layer_idx].zero_()
 
 
-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-
-    return _causal_conv1d_cache
-
-
-_causal_conv1d_cache = None
-
-
 def rms_forward(hidden_states, variance_epsilon=1e-6):
     """
     Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -268,7 +240,12 @@ class FalconMambaMixer(nn.Module):
         self.rms_eps = config.mixer_rms_eps
 
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -323,7 +300,12 @@ class FalconMambaMixer(nn.Module):
             )
         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -518,7 +500,12 @@ class FalconMambaMixer(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
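The `(update, fn) if module is not None else (None, None)` unpack now recurs at every call site in the three model files. A hypothetical helper (not in this commit; the name is invented) could factor it out:

    def _causal_conv1d_fns(module):
        """Return (causal_conv1d_update, causal_conv1d_fn), or (None, None) if no kernel loaded."""
        if module is None:
            return None, None
        return module.causal_conv1d_update, module.causal_conv1d_fn

    # call sites would then collapse to:
    # causal_conv1d_update, causal_conv1d_fn = _causal_conv1d_fns(lazy_load_kernel("causal-conv1d"))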

src/transformers/models/falcon_mamba/modular_falcon_mamba.py

@@ -19,6 +19,7 @@ from typing import Optional
 import torch
 from torch import nn
 
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...utils import auto_docstring, logging
 from ...utils.import_utils import (
     is_mamba_ssm_available,
@@ -35,7 +36,6 @@ from ..mamba.modeling_mamba import (
     MambaOutput,
     MambaPreTrainedModel,
     MambaRMSNorm,
-    _lazy_load_causal_conv1d,
 )
@@ -54,8 +54,6 @@ if is_mamba_ssm_available():
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-_causal_conv1d_cache = None
-
 
 class FalconMambaConfig(MambaConfig):
     """
@@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):
 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -324,7 +327,12 @@ class FalconMambaMixer(MambaMixer):
             )
         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -518,7 +526,12 @@ class FalconMambaMixer(MambaMixer):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )

src/transformers/models/mamba/modeling_mamba.py

@@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -33,8 +34,6 @@ from ...utils import (
     logging,
 )
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -54,32 +53,6 @@ if is_mamba_ssm_available():
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-_causal_conv1d_cache = None
-
-
-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-
-    return _causal_conv1d_cache
-
-
 class MambaCache:
     """
@@ -236,7 +209,12 @@ class MambaMixer(nn.Module):
         self.warn_slow_implementation()
 
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -287,7 +265,12 @@ class MambaMixer(nn.Module):
             )
         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -451,7 +434,12 @@ class MambaMixer(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
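Taken together, the resolution order is unchanged from the removed `_lazy_load_causal_conv1d`: a hub kernel via the `kernels` package first, then a locally installed `causal_conv1d`, else `(None, None)` and the slow eager path. A quick way to check which branch a given environment takes (a sketch, not part of the commit):

    from transformers.integrations.hub_kernels import lazy_load_kernel

    causal_conv1d = lazy_load_kernel("causal-conv1d")
    if causal_conv1d is None:
        print("no kernel available: Mamba/FalconMamba fall back to the slow path")
    else:
        print(f"fast conv1d kernel loaded: {causal_conv1d.__name__}")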