[kernels] refactor function kernel calling (#41577)

* refactor function kernel calling

* nit

* don't pass the mapping

* use _kernels_available

* rm import
Mohamed Mekkouri
2025-10-16 15:43:02 +02:00
committed by GitHub
parent 9176af574a
commit 1fb3fc4db0
4 changed files with 112 additions and 69 deletions

View File

@ -14,12 +14,16 @@
import re
from collections.abc import Callable
from functools import partial
from types import ModuleType
from typing import Optional, Union
from ..modeling_flash_attention_utils import lazy_import_flash_attention
from ..utils import logging
from .flash_attention import flash_attention_forward
logger = logging.get_logger(__name__)
try:
    from kernels import (
        Device,
@ -158,6 +162,13 @@ except ImportError:
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")


_HUB_KERNEL_MAPPING: dict[str, str] = {
    "causal-conv1d": "kernels-community/causal-conv1d",
}

_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}


def is_kernel(attn_implementation: Optional[str]) -> bool:
    """Check whether `attn_implementation` matches a kernel pattern from the hub."""
    return (
@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])


def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
        return mapping[kernel_name]

    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None

    if _kernels_available:
        from kernels import get_kernel

        try:
            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
            mapping[kernel_name] = kernel
        except FileNotFoundError:
            mapping[kernel_name] = None
    else:
        # Try to import is_{new_kernel_name}_available from ..utils.import_utils
        import importlib

        new_kernel_name = kernel_name.replace("-", "_")
        func_name = f"is_{new_kernel_name}_available"
        try:
            utils_mod = importlib.import_module("..utils.import_utils", __package__)
            is_kernel_available = getattr(utils_mod, func_name, None)
        except Exception:
            is_kernel_available = None

        if callable(is_kernel_available) and is_kernel_available():
            # Import the locally installed package, e.g. `causal_conv1d` for "causal-conv1d"
            try:
                module = importlib.import_module(new_kernel_name)
                mapping[kernel_name] = module
                return module
            except Exception:
                mapping[kernel_name] = None
        else:
            mapping[kernel_name] = None

    return mapping[kernel_name]


__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "lazy_load_kernel",
]
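
For context, the call pattern the new `lazy_load_kernel` helper is meant to support on the model side looks roughly like this (a minimal sketch using the public transformers import path; the fallback tuple mirrors the mixer changes in the files below):

from transformers.integrations.hub_kernels import lazy_load_kernel

# Resolve the causal-conv1d kernel; a successfully loaded module is cached in
# _KERNEL_MODULE_MAPPING, so later calls return it directly.
causal_conv1d = lazy_load_kernel("causal-conv1d")

# Fall back to the slow path when neither the hub kernel nor the local package is available.
causal_conv1d_update, causal_conv1d_fn = (
    (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
    if causal_conv1d is not None
    else (None, None)
)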

View File

@ -30,12 +30,11 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import (
    is_causal_conv1d_available,
    is_kernels_available,
    is_mamba_ssm_available,
    is_mambapy_available,
)
@ -162,33 +161,6 @@ class FalconMambaCache:
            self.ssm_states[layer_idx].zero_()


def _lazy_load_causal_conv1d():
    global _causal_conv1d_cache

    if _causal_conv1d_cache is not None:
        return _causal_conv1d_cache

    if is_kernels_available():
        from kernels import get_kernel

        try:
            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
        except FileNotFoundError:
            # no kernel binary match, fallback to slow path
            _causal_conv1d_cache = (None, None)
        else:
            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
    elif is_causal_conv1d_available():
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
    else:
        _causal_conv1d_cache = (None, None)

    return _causal_conv1d_cache


_causal_conv1d_cache = None
def rms_forward(hidden_states, variance_epsilon=1e-6):
"""
Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@ -268,7 +240,12 @@ class FalconMambaMixer(nn.Module):
        self.rms_eps = config.mixer_rms_eps

    def warn_slow_implementation(self):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@ -323,7 +300,12 @@ class FalconMambaMixer(nn.Module):
            )

        else:
            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
            causal_conv1d = lazy_load_kernel("causal-conv1d")
            causal_conv1d_update, causal_conv1d_fn = (
                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
                if causal_conv1d is not None
                else (None, None)
            )
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
@ -518,7 +500,12 @@ class FalconMambaMixer(nn.Module):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
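
The same six-line unpacking idiom now repeats in every mixer method that needs the kernel. A small helper along the following lines could factor it out (illustrative only; `_get_causal_conv1d_fns` is a hypothetical name, not part of this commit):

from transformers.integrations.hub_kernels import lazy_load_kernel


def _get_causal_conv1d_fns():
    # Hypothetical helper (not in this diff) that wraps lazy_load_kernel("causal-conv1d")
    # and returns (causal_conv1d_update, causal_conv1d_fn), or (None, None) for the slow path.
    causal_conv1d = lazy_load_kernel("causal-conv1d")
    if causal_conv1d is None:
        return None, None
    return causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn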

View File

@ -19,6 +19,7 @@ from typing import Optional
import torch
from torch import nn
from ...integrations.hub_kernels import lazy_load_kernel
from ...utils import auto_docstring, logging
from ...utils.import_utils import (
    is_mamba_ssm_available,
@ -35,7 +36,6 @@ from ..mamba.modeling_mamba import (
    MambaOutput,
    MambaPreTrainedModel,
    MambaRMSNorm,
    _lazy_load_causal_conv1d,
)
@ -54,8 +54,6 @@ if is_mamba_ssm_available():
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

_causal_conv1d_cache = None


class FalconMambaConfig(MambaConfig):
    """
@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):
class FalconMambaMixer(MambaMixer):
    def warn_slow_implementation(self):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@ -324,7 +327,12 @@ class FalconMambaMixer(MambaMixer):
            )

        else:
            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
            causal_conv1d = lazy_load_kernel("causal-conv1d")
            causal_conv1d_update, causal_conv1d_fn = (
                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
                if causal_conv1d is not None
                else (None, None)
            )
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
@ -518,7 +526,12 @@ class FalconMambaMixer(MambaMixer):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
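
As a quick illustration of the caching behaviour behind these call sites (a sketch; it assumes only the default `_KERNEL_MODULE_MAPPING` defined in the first file):

from transformers.integrations.hub_kernels import lazy_load_kernel

first = lazy_load_kernel("causal-conv1d")   # resolves the kernel and records the result
second = lazy_load_kernel("causal-conv1d")  # returns the cached ModuleType if resolution succeeded
assert first is second  # holds whether a module was found or both calls fell back to None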

View File

@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import (
@ -33,8 +34,6 @@ from ...utils import (
    logging,
)
from ...utils.import_utils import (
    is_causal_conv1d_available,
    is_kernels_available,
    is_mamba_ssm_available,
    is_mambapy_available,
)
@ -54,32 +53,6 @@ if is_mamba_ssm_available():
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

_causal_conv1d_cache = None


def _lazy_load_causal_conv1d():
    global _causal_conv1d_cache

    if _causal_conv1d_cache is not None:
        return _causal_conv1d_cache

    if is_kernels_available():
        from kernels import get_kernel

        try:
            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
        except FileNotFoundError:
            # no kernel binary match, fallback to slow path
            _causal_conv1d_cache = (None, None)
        else:
            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
    elif is_causal_conv1d_available():
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
    else:
        _causal_conv1d_cache = (None, None)

    return _causal_conv1d_cache


class MambaCache:
    """
@ -236,7 +209,12 @@ class MambaMixer(nn.Module):
        self.warn_slow_implementation()

    def warn_slow_implementation(self):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
@ -287,7 +265,12 @@ class MambaMixer(nn.Module):
            )

        else:
            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
            causal_conv1d = lazy_load_kernel("causal-conv1d")
            causal_conv1d_update, causal_conv1d_fn = (
                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
                if causal_conv1d is not None
                else (None, None)
            )
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
@ -451,7 +434,12 @@ class MambaMixer(nn.Module):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update, causal_conv1d_fn = (
            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
            if causal_conv1d is not None
            else (None, None)
        )
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
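
Since `_HUB_KERNEL_MAPPING` keys every hub-backed kernel by name, wiring up an additional kernel would roughly amount to the following (a sketch only; `some-kernel` and its repository id are made-up placeholders):

import transformers.integrations.hub_kernels as hub_kernels

# Hypothetical entry: map a kernel name to its Hub repository so that
# lazy_load_kernel("some-kernel") can resolve it when `kernels` is installed.
hub_kernels._HUB_KERNEL_MAPPING["some-kernel"] = "kernels-community/some-kernel"

# Resolves via kernels.get_kernel(...) when the `kernels` package is available;
# otherwise looks for is_some_kernel_available() in transformers.utils.import_utils
# and falls back to importing the local package.
module = hub_kernels.lazy_load_kernel("some-kernel")  # a ModuleType, or None on failure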