# mypy: allow-untyped-defs
import contextlib
from typing import Any, Union
from typing_extensions import deprecated

import torch


__all__ = [
    "is_built",
    "cuFFTPlanCacheAttrContextProp",
    "cuFFTPlanCache",
    "cuFFTPlanCacheManager",
    "cuBLASModule",
    "preferred_linalg_library",
    "preferred_blas_library",
    "preferred_rocm_fa_library",
    "cufft_plan_cache",
    "matmul",
    "SDPAParams",
    "enable_cudnn_sdp",
    "cudnn_sdp_enabled",
    "enable_flash_sdp",
    "flash_sdp_enabled",
    "enable_mem_efficient_sdp",
    "mem_efficient_sdp_enabled",
    "math_sdp_enabled",
    "enable_math_sdp",
    "allow_fp16_bf16_reduction_math_sdp",
    "fp16_bf16_reduction_math_sdp_allowed",
    "is_flash_attention_available",
    "can_use_flash_attention",
    "can_use_efficient_attention",
    "can_use_cudnn_attention",
    "sdp_kernel",
]


def is_built():
    r"""
    Return whether PyTorch is built with CUDA support.

    Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch
    binary were run on a machine with working CUDA drivers and devices, we would be able to use it.
    """
    return torch._C._has_cuda


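# Usage sketch (illustrative only, not part of the module's public API): `is_built()` reports
# compile-time CUDA support, while `torch.cuda.is_available()` additionally requires a working
# driver and at least one visible device. The helper name below is hypothetical.
def _example_report_cuda_support() -> None:
    # A binary can be built with CUDA yet still have no usable device at runtime.
    print(f"built with CUDA: {is_built()}")
    print(f"CUDA usable at runtime: {torch.cuda.is_available()}")

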
class cuFFTPlanCacheAttrContextProp:
    # Like regular ContextProp, but uses the `.device_index` attribute from the
    # calling object as the first argument to the getter and setter.
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter(obj.device_index)

    def __set__(self, obj, val):
        if isinstance(self.setter, str):
            raise RuntimeError(self.setter)
        self.setter(obj.device_index, val)


class cuFFTPlanCache:
    r"""
    Represent a specific plan cache for a specific `device_index`.

    The attributes `size` and `max_size`, and method `clear`, can fetch and/or
    change properties of the C++ cuFFT plan cache.
    """

    def __init__(self, device_index):
        self.device_index = device_index

    size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_size,
        ".size is a read-only property showing the number of plans currently in the "
        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
    )

    max_size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
    )

    def clear(self):
        return torch._cufft_clear_plan_cache(self.device_index)


class cuFFTPlanCacheManager:
    r"""
    Represent all cuFFT plan caches, returning the cuFFTPlanCache for a given device when indexed.

    When this object is used directly as a `cuFFTPlanCache` object (e.g., by setting
    the `.max_size` attribute), the current device's cuFFT plan cache is used.
    """

    __initialized = False

    def __init__(self):
        self.caches = []
        self.__initialized = True

    def __getitem__(self, device):
        index = torch.cuda._utils._get_device_index(device)
        if index < 0 or index >= torch.cuda.device_count():
            raise RuntimeError(
                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
                f"device with index {index}"
            )
        if len(self.caches) == 0:
            self.caches.extend(
                cuFFTPlanCache(index) for index in range(torch.cuda.device_count())
            )
        return self.caches[index]

    def __getattr__(self, name):
        return getattr(self[torch.cuda.current_device()], name)

    def __setattr__(self, name, value):
        if self.__initialized:
            return setattr(self[torch.cuda.current_device()], name, value)
        else:
            return super().__setattr__(name, value)


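# Usage sketch (illustrative only): the module-level `cufft_plan_cache` manager instantiated at
# the bottom of this file can be indexed by device, or used directly for the current device.
# The helper name below is hypothetical.
def _example_tune_cufft_cache(device: int = 0) -> None:
    # Per-device cache: raise the capacity and inspect the current number of cached plans.
    cufft_plan_cache[device].max_size = 32
    print(cufft_plan_cache[device].size)
    # Attribute access on the manager itself targets the current device's cache.
    cufft_plan_cache.clear()

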
class cuBLASModule:
    @staticmethod
    def _parse_reduction_setting(value: Any, attr_name: str) -> tuple[bool, bool]:
        def _ensure_bool(obj: Any, which: str) -> bool:
            if isinstance(obj, bool):
                return obj
            raise TypeError(
                f"{attr_name} expects a bool for {which}, but got {type(obj)!r}"
            )

        if isinstance(value, bool):
            return value, True
        if isinstance(value, (list, tuple)):
            if not value:
                raise TypeError(f"{attr_name} expects at least one boolean argument")
            if len(value) > 2:
                raise TypeError(f"{attr_name} expects at most two boolean arguments")
            allow_reduced_precision = _ensure_bool(value[0], "allow_reduced_precision")
            if len(value) == 1:
                return allow_reduced_precision, True
            allow_splitk = _ensure_bool(value[1], "allow_splitk")
            return allow_reduced_precision, allow_splitk
        raise TypeError(
            f"{attr_name} expects a bool or a tuple/list of bools, but got {type(value)!r}"
        )

    def __getattr__(self, name):
        if name == "allow_tf32":
            return torch._C._get_cublas_allow_tf32()
        elif name == "allow_fp16_reduced_precision_reduction":
            # pyrefly: ignore # not-iterable
            allow_reduced_precision, _ = (
                torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
            )
            return allow_reduced_precision
        elif name == "allow_fp16_reduced_precision_reduction_split_k":
            # pyrefly: ignore # not-iterable
            _, allow_splitk = (
                torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
            )
            return allow_splitk
        elif name == "allow_bf16_reduced_precision_reduction":
            # pyrefly: ignore # not-iterable
            allow_reduced_precision, _ = (
                torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
            )
            return allow_reduced_precision
        elif name == "allow_bf16_reduced_precision_reduction_split_k":
            # pyrefly: ignore # not-iterable
            _, allow_splitk = (
                torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
            )
            return allow_splitk
        elif name == "allow_fp16_accumulation":
            return torch._C._get_cublas_allow_fp16_accumulation()
        elif name == "fp32_precision":
            return torch._C._get_fp32_precision_getter("cuda", "matmul")
        raise AttributeError("Unknown attribute " + name)

    def __setattr__(self, name, value):
        if name == "allow_tf32":
            return torch._C._set_cublas_allow_tf32(value)
        elif name == "allow_fp16_reduced_precision_reduction":
            allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
                value, "allow_fp16_reduced_precision_reduction"
            )
            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
                allow_reduced_precision,
                # pyrefly: ignore # bad-argument-count
                allow_splitk,
            )
        elif name == "allow_bf16_reduced_precision_reduction":
            allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
                value, "allow_bf16_reduced_precision_reduction"
            )
            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
                allow_reduced_precision,
                # pyrefly: ignore # bad-argument-count
                allow_splitk,
            )
        elif name == "allow_fp16_accumulation":
            return torch._C._set_cublas_allow_fp16_accumulation(value)
        elif name == "fp32_precision":
            return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
        raise AttributeError("Unknown attribute " + name)


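# Usage sketch (illustrative only): `matmul`, instantiated at the bottom of this file, exposes
# these cuBLAS knobs as attributes. The helper name below is hypothetical.
def _example_configure_cublas() -> None:
    # Allow TF32 tensor cores for float32 matmuls.
    matmul.allow_tf32 = True
    # The reduced-precision-reduction settings accept either a bool or a
    # (allow_reduced_precision, allow_splitk) pair, as parsed by _parse_reduction_setting above.
    matmul.allow_fp16_reduced_precision_reduction = (True, False)
    # Reading the attribute returns the allow_reduced_precision flag.
    print(matmul.allow_fp16_reduced_precision_reduction)

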
_LinalgBackends = {
    "default": torch._C._LinalgBackend.Default,
    "cusolver": torch._C._LinalgBackend.Cusolver,
    "magma": torch._C._LinalgBackend.Magma,
}
_LinalgBackends_str = ", ".join(_LinalgBackends.keys())


def preferred_linalg_library(
    backend: Union[None, str, torch._C._LinalgBackend] = None,
) -> torch._C._LinalgBackend:
    r"""
    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
    and if both are available it decides which to use with a heuristic.
    This flag (a :class:`str`) allows overriding those heuristics.

    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
    * If `"magma"` is set then MAGMA will be used wherever possible.
    * If `"default"` (the default) is set then heuristics will be used to pick between
      cuSOLVER and MAGMA if both are available.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER
      globally.
      This flag only sets the initial value of the preferred library, and the preferred library
      may still be overridden by this function call later in your script.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
    for your application's inputs.

    Currently supported linalg operators:

    * :func:`torch.linalg.inv`
    * :func:`torch.linalg.inv_ex`
    * :func:`torch.linalg.cholesky`
    * :func:`torch.linalg.cholesky_ex`
    * :func:`torch.cholesky_solve`
    * :func:`torch.cholesky_inverse`
    * :func:`torch.linalg.lu_factor`
    * :func:`torch.linalg.lu`
    * :func:`torch.linalg.lu_solve`
    * :func:`torch.linalg.qr`
    * :func:`torch.linalg.eigh`
    * :func:`torch.linalg.eigvalsh`
    * :func:`torch.linalg.svd`
    * :func:`torch.linalg.svdvals`
    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _LinalgBackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_LinalgBackends_str}."
            )
        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
    elif isinstance(backend, torch._C._LinalgBackend):
        torch._C._set_linalg_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    return torch._C._get_linalg_preferred_backend()


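# Usage sketch (illustrative only): query the current preference, then force cuSOLVER for the
# operators listed in the docstring above. The helper name below is hypothetical.
def _example_prefer_cusolver() -> None:
    current = preferred_linalg_library()  # no argument: just read the current setting
    print(f"current linalg backend: {current}")
    preferred_linalg_library("cusolver")  # accepts a string key or a torch._C._LinalgBackend

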
_BlasBackends = {
    "default": torch._C._BlasBackend.Default,
    "cublas": torch._C._BlasBackend.Cublas,
    "hipblas": torch._C._BlasBackend.Cublas,  # alias
    "cublaslt": torch._C._BlasBackend.Cublaslt,
    "hipblaslt": torch._C._BlasBackend.Cublaslt,  # alias
    "ck": torch._C._BlasBackend.Ck,
}
_BlasBackends_str = ", ".join(_BlasBackends.keys())


def preferred_blas_library(
    backend: Union[None, str, torch._C._BlasBackend] = None,
) -> torch._C._BlasBackend:
    r"""
    Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only].

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
    For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance.
    This flag (a :class:`str`) allows overriding which BLAS library to use.

    * If `"cublas"` is set then cuBLAS will be used wherever possible.
    * If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
    * If `"ck"` is set then CK will be used wherever possible.
    * If `"default"` (the default) is set then heuristics will be used to pick between the other options.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt
      globally.
      This flag only sets the initial value of the preferred library, and the preferred library
      may still be overridden by this function call later in your script.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's library selection is incorrect
    for your application's inputs.

    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _BlasBackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_BlasBackends_str}."
            )
        torch._C._set_blas_preferred_backend(_BlasBackends[backend])
    elif isinstance(backend, torch._C._BlasBackend):
        torch._C._set_blas_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    return torch._C._get_blas_preferred_backend()


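# Usage sketch (illustrative only): prefer cuBLASLt (hipBLASLt on ROCm) and read the setting back.
# The helper name below is hypothetical.
def _example_prefer_cublaslt() -> None:
    preferred_blas_library("cublaslt")  # "hipblaslt" is an accepted alias on ROCm builds
    print(preferred_blas_library())  # call with no argument to query the current preference

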
_ROCmFABackends = {
    "default": torch._C._ROCmFABackend.Default,
    "aotriton": torch._C._ROCmFABackend.AOTriton,
    "ck": torch._C._ROCmFABackend.Ck,
}
_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys())


from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend


def preferred_rocm_fa_library(
    backend: Union[None, str, torch._C._ROCmFABackend] = None,
) -> torch._C._ROCmFABackend:
    r"""
    [ROCm-only]
    Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK.

    .. warning:: This flag is experimental and subject to change.

    When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend.
    This flag (a :class:`str`) allows users to override this backend to use composable_kernel.

    * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton.
    * If `"aotriton"` is set then AOTriton will be used wherever possible.
    * If `"ck"` is set then CK will be used wherever possible.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK
      globally.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's library selection is incorrect
    for your application's inputs.
    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _ROCmFABackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_ROCmFABackends_str}."
            )
        torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend])
    elif isinstance(backend, torch._C._ROCmFABackend):
        torch._C._set_rocm_fa_preferred_backend(backend)
    else:
        raise ValueError(f"Unknown input value. Choose from: {_ROCmFABackends_str}.")

    return torch._C._get_rocm_fa_preferred_backend()


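# Usage sketch (illustrative only, ROCm builds): switch Flash Attention to the composable_kernel
# backend and read the setting back. The helper name below is hypothetical.
def _example_prefer_ck_flash_attention() -> None:
    preferred_rocm_fa_library("ck")  # accepts a string key or a torch._C._ROCmFABackend
    print(preferred_rocm_fa_library())  # no argument: query the current preference

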
# Set the __module__ attribute
SDPAParams.__module__ = "torch.backends.cuda"
SDPAParams.__name__ = "SDPAParams"


def flash_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether flash scaled dot product attention is enabled or not.
    """
    return torch._C._get_flash_sdp_enabled()


def enable_flash_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables flash scaled dot product attention.
    """
    torch._C._set_sdp_use_flash(enabled)


def mem_efficient_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether memory efficient scaled dot product attention is enabled or not.
    """
    return torch._C._get_mem_efficient_sdp_enabled()


def enable_mem_efficient_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables memory efficient scaled dot product attention.
    """
    torch._C._set_sdp_use_mem_efficient(enabled)


def math_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether math scaled dot product attention is enabled or not.
    """
    return torch._C._get_math_sdp_enabled()


def enable_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables math scaled dot product attention.
    """
    torch._C._set_sdp_use_math(enabled)


def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables fp16/bf16 reduction in math scaled dot product attention.
    """
    torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)


def fp16_bf16_reduction_math_sdp_allowed():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.
    """
    return torch._C._get_math_sdp_allow_fp16_bf16_reduction()


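# Usage sketch (illustrative only): these setters are process-wide toggles; prefer the
# `torch.nn.attention.sdpa_kernel` context manager for scoped control. The helper name below
# is hypothetical.
def _example_disable_math_fallback() -> None:
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(True)
    enable_math_sdp(False)  # disable the math fallback globally
    print(flash_sdp_enabled(), mem_efficient_sdp_enabled(), math_sdp_enabled())

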
def is_flash_attention_available() -> bool:
    r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.

    Returns:
        True if FlashAttention is built and available; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._is_flash_attention_available()


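# Usage sketch (illustrative only): only opt in to the FlashAttention backend when the binary
# was built with it. The helper name below is hypothetical.
def _example_enable_flash_if_available() -> None:
    if is_flash_attention_available():
        enable_flash_sdp(True)
    else:
        # Not compiled in (or a non-CUDA build); leave the other backends to handle dispatch.
        enable_flash_sdp(False)

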
def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if FlashAttention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
                key, value, an optional attention mask, dropout rate, and
                a flag indicating if the attention is causal.
        debug: Whether to log, via ``logging.warn``, debug information as to why FlashAttention could not be run.
            Defaults to False.

    Returns:
        True if FlashAttention can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_flash_attention(params, debug)


def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if efficient_attention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
                key, value, an optional attention mask, dropout rate, and
                a flag indicating if the attention is causal.
        debug: Whether to log, via ``logging.warn``, information as to why efficient_attention could not be run.
            Defaults to False.

    Returns:
        True if efficient_attention can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_mem_efficient_attention(params, debug)


def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
                key, value, an optional attention mask, dropout rate, and
                a flag indicating if the attention is causal.
        debug: Whether to log, via ``logging.warn``, information as to why cuDNN attention could not be run.
            Defaults to False.

    Returns:
        True if cuDNN can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_cudnn_attention(params, debug)


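# Usage sketch (illustrative only): probe which fused kernels accept a given problem shape.
# The SDPAParams constructor layout assumed here is
# (query, key, value, attn_mask, dropout, is_causal, enable_gqa); verify it against your build.
# The helper name below is hypothetical.
def _example_probe_sdp_kernels() -> None:
    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    params = SDPAParams(q, k, v, None, 0.0, True, False)
    # With debug=True, the checks log the reason a backend is rejected.
    print("flash:", can_use_flash_attention(params, debug=True))
    print("efficient:", can_use_efficient_attention(params, debug=True))
    print("cudnn:", can_use_cudnn_attention(params, debug=True))

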
def cudnn_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether cuDNN scaled dot product attention is enabled or not.
    """
    return torch._C._get_cudnn_sdp_enabled()


def enable_cudnn_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables cuDNN scaled dot product attention.
    """
    torch._C._set_sdp_use_cudnn(enabled)


@contextlib.contextmanager
@deprecated(
    (
        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
        "In the future, this context manager will be removed. "
        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
        "with updated signature."
    ),
    category=FutureWarning,
)
def sdp_kernel(
    enable_flash: bool = True,
    enable_math: bool = True,
    enable_mem_efficient: bool = True,
    enable_cudnn: bool = True,
):
    r"""
    .. warning:: This flag is beta and subject to change.

    This context manager can be used to temporarily enable or disable any of the four backends for scaled dot product attention.
    Upon exiting the context manager, the previous state of the flags will be restored.
    """
    from torch.nn.attention import sdpa_kernel

    backend_list = []
    if enable_flash:
        backend_list.append(SDPBackend.FLASH_ATTENTION)
    if enable_mem_efficient:
        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
    if enable_math:
        backend_list.append(SDPBackend.MATH)
    if enable_cudnn:
        backend_list.append(SDPBackend.CUDNN_ATTENTION)

    with sdpa_kernel(backend_list) as context:
        try:
            yield context
        finally:
            pass


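# Migration sketch (illustrative only): the deprecated context manager above maps onto
# `torch.nn.attention.sdpa_kernel`, which takes an explicit backend list instead of keyword flags.
# The helper name below is hypothetical.
def _example_sdpa_kernel_migration() -> None:
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Old: with torch.backends.cuda.sdp_kernel(enable_math=False): ...
    # New: pass the allowed backends explicitly.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        pass  # run scaled_dot_product_attention here

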
cufft_plan_cache = cuFFTPlanCacheManager()
matmul = cuBLASModule()