Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Codemod inductor/runtime from Optional to union none (#165605)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165605
Approved by: https://github.com/aorenste
ghstack dependencies: #165604
committed by PyTorch MergeBot
parent f6daffc54d
commit ab6014a903
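
This commit mechanically rewrites type annotations under torch/_inductor/runtime from typing.Optional[X] to the equivalent PEP 604 spelling X | None (Python 3.10+), dropping the Optional import wherever it is no longer referenced. A minimal before/after sketch of the pattern (illustrative only, not taken from the diff below):

    from typing import Optional

    def before(timeout: Optional[float] = None) -> Optional[str]:
        return None if timeout is None else f"{timeout:.1f}s"

    def after(timeout: float | None = None) -> str | None:
        return None if timeout is None else f"{timeout:.1f}s"

The two forms are interchangeable for typing purposes; the union spelling simply avoids the extra import.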
@@ -31,7 +31,7 @@ import logging
 import os
 import os.path
 import re
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 from typing_extensions import override

 import torch
@@ -115,14 +115,14 @@ class AutotuneCacheArtifact(CacheArtifact):
 @dataclasses.dataclass
 class AutotuneCache:
     configs_hash: str
-    local_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None
-    remote_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None
+    local_cache: tuple[RemoteCache[JsonDataTy], str] | None = None
+    remote_cache: tuple[RemoteCache[JsonDataTy], str] | None = None

     # Create a AutotuneCache. Returns None if none of the caches can be used.
     @staticmethod
     def create(
         inductor_meta: _InductorMetaTy, filename: str, configs_hash: str
-    ) -> Optional[AutotuneCache]:
+    ) -> AutotuneCache | None:
         cache = AutotuneCache(configs_hash)
         key = AutotuneCache._prepare_key(filename)
@@ -142,7 +142,7 @@ class AutotuneCache:
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

     # Read the best config options from the most local cache and return it.
-    def _read(self) -> Optional[dict[str, JsonDataTy]]:
+    def _read(self) -> dict[str, JsonDataTy] | None:
         if local_cache := self.local_cache:
             cache, key = local_cache
             if best_config := cache.get(key):
@@ -161,7 +161,7 @@ class AutotuneCache:
     # which `configs` represents that option.
     def read_best(
         self, inductor_meta: _InductorMetaTy, configs: list[Config]
-    ) -> Optional[Config]:
+    ) -> Config | None:
         if best := self._read():
             return _load_cached_autotuning(
                 best, self.configs_hash, configs, inductor_meta
@@ -272,7 +272,7 @@ class AutotuneCache:
         config: Config,
         time_taken_ns: int,
         found_by_coordesc: bool = False,
-        triton_cache_hash: Optional[str] = None,
+        triton_cache_hash: str | None = None,
     ) -> None:
         data = {
             **config.kwargs,
@@ -414,7 +414,7 @@ class _AutotuneCacheBundlerImpl:


 class AutotuneCacheBundler:
-    _bundler: Optional[_AutotuneCacheBundlerImpl] = None
+    _bundler: _AutotuneCacheBundlerImpl | None = None

     def __init__(self) -> None:
         pass
@@ -427,8 +427,8 @@ class AutotuneCacheBundler:
         cls,
         inductor_meta: _InductorMetaTy,
         *,
-        code: Optional[str] = None,
-        code_hash: Optional[str] = None,
+        code: str | None = None,
+        code_hash: str | None = None,
     ) -> None:
         assert cls._bundler is None

@@ -536,7 +536,7 @@ def _load_cached_autotuning(
     configs_hash: str,
     configs: list[Config],
     inductor_meta: _InductorMetaTy,
-) -> Optional[Config]:
+) -> Config | None:
     if best_config is None:
         return None
     if best_config.pop("configs_hash", None) != configs_hash:
@@ -589,7 +589,7 @@ def _load_cached_autotuning(

 class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]):
     @override
-    def _get(self, key: str) -> Optional[bytes]:
+    def _get(self, key: str) -> bytes | None:
         try:
             with open(key, "rb") as fd:
                 return fd.read()
@@ -611,7 +611,7 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
         super().__init__(backend, serde)

     @override
-    def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]:
+    def _get(self, key: str, sample: Sample | None) -> JsonDataTy | None:
         AutotuneCacheBundler.sync()
         result = super()._get(key, sample)
         if result is not None:
@@ -629,7 +629,7 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
         return result

     @override
-    def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None:
+    def _put(self, key: str, value: JsonDataTy, sample: Sample | None) -> None:
         AutotuneCacheBundler.put(key, value)
         super()._put(key, value, sample)

@@ -4,7 +4,7 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
@@ -273,7 +273,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
         benchmark_iters: int = 100,
         max_benchmark_duration: int = 25,
         return_mode: str = "min",
-        grad_to_none: Optional[list[torch.Tensor]] = None,
+        grad_to_none: list[torch.Tensor] | None = None,
         is_vetted_benchmarking: bool = False,
         **kwargs: Any,
     ) -> Union[float, list[float]]:
@@ -1,5 +1,4 @@
 import os
-from typing import Optional

 import torch
 from torch._environment import is_fbcode
@@ -9,7 +8,7 @@ def _versioned_config(
     jk_name: str,
     this_version: int,
     oss_default: bool,
-    env_var_override: Optional[str] = None,
+    env_var_override: str | None = None,
 ) -> bool:
     """
     A versioned configuration utility that determines boolean settings based on:
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
 from base64 import b64encode
 from functools import cache
 from hashlib import sha256
-from typing import Any, Optional, Sequence
+from typing import Any, Sequence
 from typing_extensions import override, TypedDict

 import torch
@@ -152,7 +152,7 @@ class _CompileContext(_Context):

     @cache
     @staticmethod
-    def triton_version_hash() -> Optional[str]:
+    def triton_version_hash() -> str | None:
         """Get Triton version key if Triton is available.

         Returns:
@@ -164,7 +164,7 @@ class _CompileContext(_Context):

     @cache
     @staticmethod
-    def runtime() -> Optional[str]:
+    def runtime() -> str | None:
         """Determine the runtime type based on available backends.

         Returns:
@@ -174,7 +174,7 @@ class _CompileContext(_Context):

     @cache
     @staticmethod
-    def runtime_version() -> Optional[str]:
+    def runtime_version() -> str | None:
         """Get the version string for the detected runtime.

         Returns:
@@ -188,7 +188,7 @@ class _CompileContext(_Context):

     @cache
     @staticmethod
-    def accelerator_properties() -> Optional[str]:
+    def accelerator_properties() -> str | None:
         """Get string representation of CUDA device properties.

         Returns:
@@ -254,7 +254,7 @@ def _isolation_context(
         ("runtime_context", _RuntimeContext),
         ("compile_context", _CompileContext),
     ):
-        selected_context: Optional[dict[str, Any]] = None
+        selected_context: dict[str, Any] | None = None
         if ischema[context_name] is True:  # type: ignore[literal-required]
             selected_context = {
                 form_of_context: getattr(context_cls, form_of_context)()
@@ -14,7 +14,7 @@ from io import BufferedReader, BufferedWriter
 from os import PathLike
 from pathlib import Path
 from threading import Lock
-from typing import Any, Callable, Generator, Optional
+from typing import Any, Callable, Generator
 from typing_extensions import override, TypeAlias

 from filelock import FileLock
@@ -71,7 +71,7 @@ class _CacheImpl(ABC):
         self._lock: Lock = Lock()

     @property
-    def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
+    def lock(self) -> Callable[[float | None], _LockContextManager]:
         """Get a context manager for acquiring the cache lock.

         Locking of the cache is not done by the implementation itself, but by the
@@ -87,14 +87,14 @@ class _CacheImpl(ABC):
         """

         def _lock_with_timeout(
-            timeout: Optional[float] = None,
+            timeout: float | None = None,
         ) -> _LockContextManager:
             return locks._acquire_lock_with_timeout(self._lock, timeout)

         return _lock_with_timeout

     @abstractmethod
-    def get(self, key: Any) -> Optional[Hit]:
+    def get(self, key: Any) -> Hit | None:
         """Retrieve a value from the cache.

         Args:
@@ -132,7 +132,7 @@ class _InMemoryCacheImpl(_CacheImpl):
         self._memory: dict[bytes, Any] = {}

     @override
-    def get(self, key: Any) -> Optional[Hit]:
+    def get(self, key: Any) -> Hit | None:
         """Retrieve a value from the in-memory cache.

         Args:
@@ -182,7 +182,7 @@ class _OnDiskCacheImpl(_CacheImpl):
     _version: int = 0
     _version_header_length: int = 4

-    def __init__(self, sub_dir: Optional[PathLike[str]] = None) -> None:
+    def __init__(self, sub_dir: PathLike[str] | None = None) -> None:
         """Initialize the on-disk cache with a specified subdirectory.

         Args:
@@ -246,7 +246,7 @@ class _OnDiskCacheImpl(_CacheImpl):

     @override
     @property
-    def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
+    def lock(self) -> Callable[[float | None], _LockContextManager]:
         """Get a context manager for acquiring the file lock.

         Uses file locking to ensure thread safety across processes.
@@ -259,14 +259,14 @@ class _OnDiskCacheImpl(_CacheImpl):
         """

         def _lock_with_timeout(
-            timeout: Optional[float] = None,
+            timeout: float | None = None,
         ) -> _LockContextManager:
             return locks._acquire_flock_with_timeout(self._flock, timeout)

         return _lock_with_timeout

     @override
-    def get(self, key: Any) -> Optional[Hit]:
+    def get(self, key: Any) -> Hit | None:
         """Retrieve a value from the on-disk cache.

         Args:
@@ -281,7 +281,7 @@ class _OnDiskCacheImpl(_CacheImpl):
         if not fpath.is_file():
             return None

-        pickled_value: Optional[bytes] = None
+        pickled_value: bytes | None = None
         with open(fpath, "rb") as fp:
             if self._version_header_matches(fp):
                 pickled_value = fp.read()
@@ -370,7 +370,7 @@ except ModuleNotFoundError:

         @override
         @property
-        def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
+        def lock(self) -> Callable[[float | None], _LockContextManager]:
             """Get a pseudo lock that does nothing.

             Most remote cache implementations don't have an ability to implement
@@ -386,14 +386,14 @@ except ModuleNotFoundError:

             @contextmanager
             def pseudo_lock(
-                timeout: Optional[float] = None,
+                timeout: float | None = None,
             ) -> Generator[None, None, None]:
                 yield

             return pseudo_lock

         @override
-        def get(self, key: Any) -> Optional[Hit]:
+        def get(self, key: Any) -> Hit | None:
             """Raise NotImplementedError for remote cache get operations.

             Args:
@@ -11,7 +11,7 @@ The module offers both context manager and manual acquisition patterns:

 from contextlib import contextmanager
 from threading import Lock
-from typing import Generator, Optional
+from typing import Generator

 from filelock import FileLock, Timeout

@@ -31,7 +31,7 @@ _DEFAULT_TIMEOUT: float = _BLOCKING_WITH_TIMEOUT
 @contextmanager
 def _acquire_lock_with_timeout(
     lock: Lock,
-    timeout: Optional[float] = None,
+    timeout: float | None = None,
 ) -> Generator[None, None, None]:
     """Context manager that safely acquires a threading.Lock with timeout and automatically releases it.

@@ -65,9 +65,7 @@ def _acquire_lock_with_timeout(
         lock.release()


-def _unsafe_acquire_lock_with_timeout(
-    lock: Lock, timeout: Optional[float] = None
-) -> None:
+def _unsafe_acquire_lock_with_timeout(lock: Lock, timeout: float | None = None) -> None:
    """Acquire a threading.Lock with timeout without automatic release (unsafe).

    This function acquires a lock with timeout support but does NOT automatically
@@ -106,7 +104,7 @@ def _unsafe_acquire_lock_with_timeout(
 @contextmanager
 def _acquire_flock_with_timeout(
     flock: FileLock,
-    timeout: Optional[float] = None,
+    timeout: float | None = None,
 ) -> Generator[None, None, None]:
     """Context manager that safely acquires a FileLock with timeout and automatically releases it.

@@ -141,9 +139,7 @@ def _acquire_flock_with_timeout(
         flock.release()


-def _unsafe_acquire_flock_with_timeout(
-    flock: FileLock, timeout: Optional[float]
-) -> None:
+def _unsafe_acquire_flock_with_timeout(flock: FileLock, timeout: float | None) -> None:
    """Acquire a FileLock with timeout without automatic release (unsafe).

    This function acquires a file lock with timeout support but does NOT automatically
@@ -2,7 +2,7 @@
 import copy
 import itertools
 import logging
-from typing import Callable, Optional, TYPE_CHECKING
+from typing import Callable, TYPE_CHECKING

 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -257,7 +257,7 @@ class CoordescTuner:
         self,
         func: Callable[["triton.Config"], float],
         baseline_config: "triton.Config",
-        baseline_timing: Optional[float] = None,
+        baseline_timing: float | None = None,
     ) -> "triton.Config":
         if baseline_timing is None:
             baseline_timing = self.call_func(func, baseline_config)
@@ -5,7 +5,7 @@ import collections
 import functools
 import typing
 from enum import auto, Enum
-from typing import Optional, Union
+from typing import Union

 from torch.utils._triton import has_triton_package

@@ -130,10 +130,10 @@ class DeviceProperties(typing.NamedTuple):
     index: int  # type: ignore[assignment]
     multi_processor_count: int
     cc: int
-    major: Optional[int] = None
-    regs_per_multiprocessor: Optional[int] = None
-    max_threads_per_multi_processor: Optional[int] = None
-    warp_size: Optional[int] = None
+    major: int | None = None
+    regs_per_multiprocessor: int | None = None
+    max_threads_per_multi_processor: int | None = None
+    warp_size: int | None = None

     @classmethod
     @functools.cache
@@ -174,10 +174,10 @@ class DeviceProperties(typing.NamedTuple):
 class HalideInputSpec(typing.NamedTuple):
     ctype: str
     name: str
-    shape: Optional[list[str]] = None
-    stride: Optional[list[str]] = None
-    offset: Optional[str] = None
-    alias_of: Optional[str] = None
+    shape: list[str] | None = None
+    stride: list[str] | None = None
+    offset: str | None = None
+    alias_of: str | None = None

     def bindings_type(self) -> str:
         if self.ctype in ("at::Half*", "at::BFloat16*"):
@@ -201,9 +201,9 @@ class HalideInputSpec(typing.NamedTuple):
 class HalideMeta(typing.NamedTuple):
     argtypes: list[HalideInputSpec]
     target: str
-    scheduler: Optional[str] = None
-    scheduler_flags: Optional[dict[str, Union[int, str]]] = None
-    cuda_device: Optional[int] = None
+    scheduler: str | None = None
+    scheduler_flags: dict[str, Union[int, str]] | None = None
+    cuda_device: int | None = None

     def args(self) -> list[str]:
         """Command line args to pass to halide generator"""
@@ -1,6 +1,6 @@
 import functools
 import os
-from typing import Any, Optional
+from typing import Any
 from typing_extensions import Unpack

 from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs
@@ -92,9 +92,7 @@ class StaticallyLaunchedCudaKernel:
         self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")

         self.arg_tys = self.arg_ty_from_signature(kernel.src)
-        self.function: Optional[int] = (
-            None  # Loaded by load_kernel(on the parent process)
-        )
+        self.function: int | None = None  # Loaded by load_kernel(on the parent process)
         num_ctas = 1
         if hasattr(kernel, "num_ctas"):
             num_ctas = kernel.num_ctas
@@ -18,16 +18,7 @@ import sys
 import threading
 import time
 from collections import namedtuple
-from typing import (
-    Any,
-    Callable,
-    Generic,
-    Literal,
-    Optional,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Generic, Literal, TYPE_CHECKING, TypeVar, Union

 import torch
 from torch._dynamo.utils import counters, set_feature_use
@@ -119,7 +110,7 @@ def generate_lookup_hash_from_source_code(size_hints_str: str, source_code: str)
     return fn_hash


-def lookup_autotune_config(size_hints, fn) -> Optional[Config]:
+def lookup_autotune_config(size_hints, fn) -> Config | None:
     lookup_table = torch._inductor.config.autotune_lookup_table
     cached_config = None
     if len(lookup_table) > 0 and "_fused_" in fn.src:
@@ -157,7 +148,7 @@ def autotune_hints_to_configs(
     Based on those hints, this function will generate a list of additional autotuning
     configs to try.
     """
-    xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
+    xyz_options: tuple[tuple[int, int | None, int | None], ...]
     configs: list[Config] = []
     for hint in hints:
         if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
@@ -217,8 +208,8 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):


 def check_autotune_cache(
-    configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
-) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
+    configs: list[Config], filename: str | None, inductor_meta: dict[str, Any]
+) -> tuple[list[Config], AutotuneCache | None, dict[str, Any]]:
     """
     Given a list of configs, checks autotune cache and return metadata
     """
@@ -285,9 +276,9 @@ class CachingAutotuner(KernelInterface):
         size_hints=None,
         inductor_meta=None,  # metadata not relevant to triton
         custom_kernel=False,  # whether the kernel is inductor-generated or custom
-        filename: Optional[str] = None,
-        reset_to_zero_arg_names: Optional[list[str]] = None,
-        autotune_cache_info: Optional[dict[str, Any]] = None,
+        filename: str | None = None,
+        reset_to_zero_arg_names: list[str] | None = None,
+        autotune_cache_info: dict[str, Any] | None = None,
     ):
         super().__init__()

@@ -367,7 +358,7 @@ class CachingAutotuner(KernelInterface):
         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

         # Compile-time info included in runtime logginging
-        self.compile_id: Optional[CompileId] = None
+        self.compile_id: CompileId | None = None
         self.is_backward = False

         # Mode for launch grid calculation
@@ -419,17 +410,15 @@ class CachingAutotuner(KernelInterface):
             self.fn = reload_kernel_from_src().fn
         self.compile_results = [self._precompile_config(best_config)]

-    def set_compile_info(
-        self, compile_id: Optional[CompileId], is_backward: bool
-    ) -> None:
+    def set_compile_info(self, compile_id: CompileId | None, is_backward: bool) -> None:
         self.compile_id = compile_id
         self.is_backward = is_backward

     def precompile(
         self,
         warm_cache_only=False,
-        reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
-        static_triton_bundle_key: Optional[str] = None,
+        reload_kernel: Callable[[], CachingAutotuner] | None = None,
+        static_triton_bundle_key: str | None = None,
     ):
         if warm_cache_only:
             self._precompile_worker()
@@ -492,7 +481,7 @@ class CachingAutotuner(KernelInterface):
         assert device_prop.regs_per_multiprocessor
         assert device_prop.max_threads_per_multi_processor
         assert device_prop.multi_processor_count
-        seen_config_hashes: Optional[OrderedSet[Hashable]] = None
+        seen_config_hashes: OrderedSet[Hashable] | None = None
         warp_size = device_prop.warp_size or 32
         for result in self.compile_results:
             triton_config = result.config
@@ -638,7 +627,7 @@ class CachingAutotuner(KernelInterface):
         return old_values

     def restore_after_unpickle(
-        self, old_values: Optional[tuple[Any, Any, Any, Any, Any, Any]]
+        self, old_values: tuple[Any, Any, Any, Any, Any, Any] | None
     ) -> None:
         if old_values:
             (
@@ -1322,7 +1311,7 @@ class CachingAutotuner(KernelInterface):
     ):  # type:ignore[override]
         if hasattr(triton, "set_allocator"):

-            def alloc_fn(size: int, align: int, stream: Optional[int]):
+            def alloc_fn(size: int, align: int, stream: int | None):
                 return torch.empty(
                     size, dtype=torch.int8, device=self.device_props.type
                 )
@@ -1571,7 +1560,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
         inductor_meta: dict[str, Any],
         triton_meta: dict[str, Any],
         heuristic_type: HeuristicType,
-    ) -> Optional[StaticallyLaunchedCudaKernel]:
+    ) -> StaticallyLaunchedCudaKernel | None:
         if not torch._inductor.config.use_static_cuda_launcher:
             return None

@@ -1932,12 +1921,12 @@ class TritonCompileResult(CompileResult[CompiledKernel]):

         # in AMD's Triton backend, the global scratch size is never provided
         # (but for AMD it's safe to pass an extra null arg, so always include it)
-        global_scratch: Optional[int] = getattr(
+        global_scratch: int | None = getattr(
             kernel_metadata,
             "global_scratch_size",
             (0 if torch.version.hip else None),
         )
-        profile_scratch: Optional[int] = getattr(
+        profile_scratch: int | None = getattr(
             kernel_metadata, "profile_scratch_size", None
         )
         launcher.global_scratch = global_scratch
@@ -2091,7 +2080,7 @@ def hash_configs(configs: list[Config]):


 def cached_autotune(
-    size_hints: Optional[list[int]],
+    size_hints: list[int] | None,
     configs: list[Config],
     triton_meta,
     heuristic_type,
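For reference, a deliberately naive sketch of how a rewrite like this can be automated. This is a hypothetical illustration, not the tooling actually used for this PR, and it only handles flat Optional[...] annotations with no nested brackets:

    import re

    def naive_optional_codemod(source: str) -> str:
        # Rewrite Optional[X] -> X | None for simple X without nested brackets.
        # A production codemod would transform the syntax tree instead, so that
        # nested subscripts, aliases, and quoted annotations are handled correctly.
        return re.sub(r"\bOptional\[([^\[\]]+)\]", r"\1 | None", source)

    print(naive_optional_codemod("def f(x: Optional[int] = None) -> Optional[str]: ..."))
    # def f(x: int | None = None) -> str | None: ...

Cases like Optional[tuple[RemoteCache[JsonDataTy], str]] in this diff require the tree-based approach, since the nested brackets defeat a flat regex.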