mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check --- Step 1: uncomment lines in the `pyrefly.toml` file. Before: https://gist.github.com/maggiemoss/911b4d0bc88bf8cf3ab91f67184e9d46 After: ``` INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml` INFO 0 errors (1,152 ignored) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164513 Approved by: https://github.com/oulgen
58 lines
1.3 KiB
Python
58 lines
1.3 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
# Public API of this module; every other name here is an implementation detail.
__all__ = ["version", "is_available", "get_max_alg_id"]

# The cuSPARSELt bindings are optional. When ``torch._C`` does not expose
# ``_cusparselt`` (presumably a torch build without cuSPARSELt support), fall
# back to ``None`` so importing this module never fails on that account.
try:
    from torch._C import _cusparselt
except ImportError:
    _cusparselt = None  # type: ignore[assignment]

# Module-level cache populated lazily by ``_init()``; both start out unset.
__cusparselt_version: Optional[int] = None
__MAX_ALG_ID: Optional[int] = None
|
|
|
|
if _cusparselt is None:

    def _init() -> bool:
        """Report that cuSPARSELt cannot be initialized: the bindings are missing."""
        return False

else:

    def _init() -> bool:
        """Populate the cached cuSPARSELt version and max algorithm id once.

        On the first call this queries ``_cusparselt.getVersionInt()`` and
        stores the result in ``__cusparselt_version``. For the version integers
        recognized here (400, 502, 602 -- presumably cuSPARSELt 0.4.0, 0.5.2
        and 0.6.2; confirm against the library) it also records the matching
        ``__MAX_ALG_ID``. Any other version leaves ``__MAX_ALG_ID`` untouched.
        Later calls reuse the cached values without querying the bindings again.

        Returns:
            Always ``True`` in this branch, because the bindings were imported.
        """
        global __cusparselt_version, __MAX_ALG_ID

        if __cusparselt_version is not None:
            # Already initialized by an earlier call.
            return True

        # pyrefly: ignore # missing-attribute
        __cusparselt_version = _cusparselt.getVersionInt()

        # Recognized version -> max algorithm id. Defaulting to the current
        # ``__MAX_ALG_ID`` means an unrecognized version changes nothing.
        max_alg_id_by_version = {400: 4, 502: 5, 602: 37}
        __MAX_ALG_ID = max_alg_id_by_version.get(
            __cusparselt_version, __MAX_ALG_ID
        )
        return True
|
|
|
|
|
|
def version() -> Optional[int]:
    """Return the cuSPARSELt version reported by the bindings.

    Returns:
        The integer obtained from ``_cusparselt.getVersionInt()`` (queried once
        and cached), or ``None`` when the cuSPARSELt bindings are unavailable.
    """
    # ``_init()`` is evaluated before the cached global is read, so the first
    # call both populates and returns the version.
    return __cusparselt_version if _init() else None
|
|
|
|
|
|
def is_available() -> bool:
|
|
r"""Return a bool indicating if cuSPARSELt is currently available."""
|
|
return torch._C._has_cusparselt
|
|
|
|
|
|
def get_max_alg_id() -> Optional[int]:
    """Return the maximum cuSPARSELt algorithm id for the loaded library.

    Returns:
        The value ``_init()`` recorded in ``__MAX_ALG_ID`` for the detected
        library version. Returns ``None`` when the cuSPARSELt bindings are
        unavailable or when the version is not one that ``_init()`` recognizes.
    """
    if _init():
        return __MAX_ALG_ID
    return None
|