Codemod inductor/runtime from Optional to union none (#165605)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165605
Approved by: https://github.com/aorenste
ghstack dependencies: #165604
Oguz Ulgen
2025-10-15 19:29:51 -07:00
committed by PyTorch MergeBot
parent f6daffc54d
commit ab6014a903
10 changed files with 76 additions and 94 deletions
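The change is mechanical: every `typing.Optional[X]` annotation becomes the equivalent PEP 604 spelling `X | None`, and the now-unused `Optional` imports are dropped. A minimal before/after sketch of the pattern (illustrative names, not code from this PR); note that evaluating `X | None` at runtime requires Python 3.10+, or `from __future__ import annotations` on older interpreters:

from typing import Optional


def fetch_before(timeout: Optional[float] = None) -> Optional[str]:
    ...


# Equivalent after the codemod: same semantics, one typing import fewer.
def fetch_after(timeout: float | None = None) -> str | None:
    ...


assert Optional[float] == (float | None)  # both spellings compare equal on 3.10+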

View File

@@ -31,7 +31,7 @@ import logging
 import os
 import os.path
 import re
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 from typing_extensions import override

 import torch
@@ -115,14 +115,14 @@ class AutotuneCacheArtifact(CacheArtifact):
 @dataclasses.dataclass
 class AutotuneCache:
     configs_hash: str
-    local_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None
-    remote_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None
+    local_cache: tuple[RemoteCache[JsonDataTy], str] | None = None
+    remote_cache: tuple[RemoteCache[JsonDataTy], str] | None = None

     # Create a AutotuneCache. Returns None if none of the caches can be used.
     @staticmethod
     def create(
         inductor_meta: _InductorMetaTy, filename: str, configs_hash: str
-    ) -> Optional[AutotuneCache]:
+    ) -> AutotuneCache | None:
         cache = AutotuneCache(configs_hash)
         key = AutotuneCache._prepare_key(filename)
@@ -142,7 +142,7 @@ class AutotuneCache:
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

     # Read the best config options from the most local cache and return it.
-    def _read(self) -> Optional[dict[str, JsonDataTy]]:
+    def _read(self) -> dict[str, JsonDataTy] | None:
         if local_cache := self.local_cache:
             cache, key = local_cache
             if best_config := cache.get(key):
@@ -161,7 +161,7 @@ class AutotuneCache:
     # which `configs` represents that option.
     def read_best(
         self, inductor_meta: _InductorMetaTy, configs: list[Config]
-    ) -> Optional[Config]:
+    ) -> Config | None:
         if best := self._read():
             return _load_cached_autotuning(
                 best, self.configs_hash, configs, inductor_meta
@@ -272,7 +272,7 @@ class AutotuneCache:
         config: Config,
         time_taken_ns: int,
         found_by_coordesc: bool = False,
-        triton_cache_hash: Optional[str] = None,
+        triton_cache_hash: str | None = None,
     ) -> None:
         data = {
             **config.kwargs,
@@ -414,7 +414,7 @@ class _AutotuneCacheBundlerImpl:
 class AutotuneCacheBundler:
-    _bundler: Optional[_AutotuneCacheBundlerImpl] = None
+    _bundler: _AutotuneCacheBundlerImpl | None = None

     def __init__(self) -> None:
         pass
@@ -427,8 +427,8 @@ class AutotuneCacheBundler:
         cls,
         inductor_meta: _InductorMetaTy,
         *,
-        code: Optional[str] = None,
-        code_hash: Optional[str] = None,
+        code: str | None = None,
+        code_hash: str | None = None,
     ) -> None:
         assert cls._bundler is None
@@ -536,7 +536,7 @@ def _load_cached_autotuning(
     configs_hash: str,
     configs: list[Config],
     inductor_meta: _InductorMetaTy,
-) -> Optional[Config]:
+) -> Config | None:
     if best_config is None:
         return None

    if best_config.pop("configs_hash", None) != configs_hash:
@@ -589,7 +589,7 @@ def _load_cached_autotuning(
 class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]):
     @override
-    def _get(self, key: str) -> Optional[bytes]:
+    def _get(self, key: str) -> bytes | None:
         try:
             with open(key, "rb") as fd:
                 return fd.read()
@@ -611,7 +611,7 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
         super().__init__(backend, serde)

     @override
-    def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]:
+    def _get(self, key: str, sample: Sample | None) -> JsonDataTy | None:
         AutotuneCacheBundler.sync()
         result = super()._get(key, sample)
         if result is not None:
@@ -629,7 +629,7 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
         return result

     @override
-    def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None:
+    def _put(self, key: str, value: JsonDataTy, sample: Sample | None) -> None:
         AutotuneCacheBundler.put(key, value)
         super()._put(key, value, sample)
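As the comments in this file note, `AutotuneCache.create` returns `None` when no cache backend is usable, and `_read` consults the most local cache first. A minimal sketch of that layered-lookup pattern (illustrative stand-in classes, not the PyTorch types):

class DictBackend:
    """Toy stand-in for a RemoteCache backend: get() returns a hit or None."""

    def __init__(self, data: dict) -> None:
        self._data = data

    def get(self, key: str):
        return self._data.get(key)


class LayeredCache:
    def __init__(self, local=None, remote=None) -> None:
        # Each layer is an optional (backend, key) pair, mirroring the
        # `tuple[RemoteCache[JsonDataTy], str] | None` fields above.
        self.local = local
        self.remote = remote

    def read(self) -> dict | None:
        # Prefer the most local layer; fall through to the next on a miss.
        for layer in (self.local, self.remote):
            if layer is not None:
                backend, key = layer
                if (hit := backend.get(key)) is not None:
                    return hit
        return None


cache = LayeredCache(
    local=(DictBackend({}), "k"),
    remote=(DictBackend({"k": {"XBLOCK": 64}}), "k"),
)
assert cache.read() == {"XBLOCK": 64}  # local misses, remote hits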

View File

@@ -4,7 +4,7 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
@@ -273,7 +273,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
         benchmark_iters: int = 100,
         max_benchmark_duration: int = 25,
         return_mode: str = "min",
-        grad_to_none: Optional[list[torch.Tensor]] = None,
+        grad_to_none: list[torch.Tensor] | None = None,
         is_vetted_benchmarking: bool = False,
         **kwargs: Any,
     ) -> Union[float, list[float]]:

View File

@@ -1,5 +1,4 @@
 import os
-from typing import Optional

 import torch
 from torch._environment import is_fbcode
@@ -9,7 +8,7 @@ def _versioned_config(
     jk_name: str,
     this_version: int,
     oss_default: bool,
-    env_var_override: Optional[str] = None,
+    env_var_override: str | None = None,
 ) -> bool:
     """
     A versioned configuration utility that determines boolean settings based on:

View File

@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
 from base64 import b64encode
 from functools import cache
 from hashlib import sha256
-from typing import Any, Optional, Sequence
+from typing import Any, Sequence
 from typing_extensions import override, TypedDict

 import torch
@@ -152,7 +152,7 @@ class _CompileContext(_Context):
     @cache
     @staticmethod
-    def triton_version_hash() -> Optional[str]:
+    def triton_version_hash() -> str | None:
         """Get Triton version key if Triton is available.

         Returns:
@@ -164,7 +164,7 @@ class _CompileContext(_Context):
     @cache
     @staticmethod
-    def runtime() -> Optional[str]:
+    def runtime() -> str | None:
         """Determine the runtime type based on available backends.

         Returns:
@@ -174,7 +174,7 @@ class _CompileContext(_Context):
     @cache
     @staticmethod
-    def runtime_version() -> Optional[str]:
+    def runtime_version() -> str | None:
         """Get the version string for the detected runtime.

         Returns:
@@ -188,7 +188,7 @@ class _CompileContext(_Context):
     @cache
     @staticmethod
-    def accelerator_properties() -> Optional[str]:
+    def accelerator_properties() -> str | None:
         """Get string representation of CUDA device properties.

         Returns:
@@ -254,7 +254,7 @@ def _isolation_context(
         ("runtime_context", _RuntimeContext),
         ("compile_context", _CompileContext),
     ):
-        selected_context: Optional[dict[str, Any]] = None
+        selected_context: dict[str, Any] | None = None
         if ischema[context_name] is True:  # type: ignore[literal-required]
             selected_context = {
                 form_of_context: getattr(context_cls, form_of_context)()

View File

@@ -14,7 +14,7 @@ from io import BufferedReader, BufferedWriter
 from os import PathLike
 from pathlib import Path
 from threading import Lock
-from typing import Any, Callable, Generator, Optional
+from typing import Any, Callable, Generator
 from typing_extensions import override, TypeAlias

 from filelock import FileLock
@@ -71,7 +71,7 @@ class _CacheImpl(ABC):
         self._lock: Lock = Lock()

     @property
-    def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
+    def lock(self) -> Callable[[float | None], _LockContextManager]:
         """Get a context manager for acquiring the cache lock.

         Locking of the cache is not done by the implementation itself, but by the
@@ -87,14 +87,14 @@ class _CacheImpl(ABC):
         """

         def _lock_with_timeout(
-            timeout: Optional[float] = None,
+            timeout: float | None = None,
         ) -> _LockContextManager:
             return locks._acquire_lock_with_timeout(self._lock, timeout)

         return _lock_with_timeout

     @abstractmethod
-    def get(self, key: Any) -> Optional[Hit]:
+    def get(self, key: Any) -> Hit | None:
         """Retrieve a value from the cache.

         Args:
@@ -132,7 +132,7 @@ class _InMemoryCacheImpl(_CacheImpl):
         self._memory: dict[bytes, Any] = {}

     @override
-    def get(self, key: Any) -> Optional[Hit]:
+    def get(self, key: Any) -> Hit | None:
         """Retrieve a value from the in-memory cache.

         Args:
@@ -182,7 +182,7 @@ class _OnDiskCacheImpl(_CacheImpl):
     _version: int = 0
     _version_header_length: int = 4

-    def __init__(self, sub_dir: Optional[PathLike[str]] = None) -> None:
+    def __init__(self, sub_dir: PathLike[str] | None = None) -> None:
         """Initialize the on-disk cache with a specified subdirectory.

         Args:
@@ -246,7 +246,7 @@ class _OnDiskCacheImpl(_CacheImpl):
     @override
     @property
-    def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
+    def lock(self) -> Callable[[float | None], _LockContextManager]:
         """Get a context manager for acquiring the file lock.

         Uses file locking to ensure thread safety across processes.
@@ -259,14 +259,14 @@ class _OnDiskCacheImpl(_CacheImpl):
         """

         def _lock_with_timeout(
-            timeout: Optional[float] = None,
+            timeout: float | None = None,
         ) -> _LockContextManager:
             return locks._acquire_flock_with_timeout(self._flock, timeout)

         return _lock_with_timeout

     @override
-    def get(self, key: Any) -> Optional[Hit]:
+    def get(self, key: Any) -> Hit | None:
         """Retrieve a value from the on-disk cache.

         Args:
@@ -281,7 +281,7 @@ class _OnDiskCacheImpl(_CacheImpl):
         if not fpath.is_file():
             return None

-        pickled_value: Optional[bytes] = None
+        pickled_value: bytes | None = None
         with open(fpath, "rb") as fp:
             if self._version_header_matches(fp):
                 pickled_value = fp.read()
@@ -370,7 +370,7 @@ except ModuleNotFoundError:
         @override
         @property
-        def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
+        def lock(self) -> Callable[[float | None], _LockContextManager]:
             """Get a pseudo lock that does nothing.

             Most remote cache implementations don't have an ability to implement
@@ -386,14 +386,14 @@ except ModuleNotFoundError:
             @contextmanager
             def pseudo_lock(
-                timeout: Optional[float] = None,
+                timeout: float | None = None,
             ) -> Generator[None, None, None]:
                 yield

             return pseudo_lock

         @override
-        def get(self, key: Any) -> Optional[Hit]:
+        def get(self, key: Any) -> Hit | None:
             """Raise NotImplementedError for remote cache get operations.

             Args:
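One detail worth noting from the on-disk hunks above: reads are guarded by a fixed-length version header (`_version_header_length = 4`), and a mismatch leaves `pickled_value` as `None`, i.e. it is treated as a cache miss rather than an error. A minimal sketch of that guard, assuming a 4-byte little-endian integer header; the actual header encoding is not shown in this diff:

import struct
from pathlib import Path

CACHE_VERSION = 0  # mirrors _version above
HEADER_LEN = 4  # mirrors _version_header_length above


def read_versioned(fpath: Path) -> bytes | None:
    """Return the pickled payload, or None if the file is missing or stale."""
    if not fpath.is_file():
        return None
    with open(fpath, "rb") as fp:
        header = fp.read(HEADER_LEN)
        # Hypothetical encoding: reject short reads and version mismatches.
        if len(header) < HEADER_LEN or struct.unpack("<I", header)[0] != CACHE_VERSION:
            return None
        return fp.read()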

View File

@@ -11,7 +11,7 @@ The module offers both context manager and manual acquisition patterns:
 from contextlib import contextmanager
 from threading import Lock
-from typing import Generator, Optional
+from typing import Generator

 from filelock import FileLock, Timeout
@@ -31,7 +31,7 @@ _DEFAULT_TIMEOUT: float = _BLOCKING_WITH_TIMEOUT
 @contextmanager
 def _acquire_lock_with_timeout(
     lock: Lock,
-    timeout: Optional[float] = None,
+    timeout: float | None = None,
 ) -> Generator[None, None, None]:
     """Context manager that safely acquires a threading.Lock with timeout and automatically releases it.
@@ -65,9 +65,7 @@ def _acquire_lock_with_timeout(
         lock.release()


-def _unsafe_acquire_lock_with_timeout(
-    lock: Lock, timeout: Optional[float] = None
-) -> None:
+def _unsafe_acquire_lock_with_timeout(lock: Lock, timeout: float | None = None) -> None:
     """Acquire a threading.Lock with timeout without automatic release (unsafe).

     This function acquires a lock with timeout support but does NOT automatically
@@ -106,7 +104,7 @@ def _unsafe_acquire_lock_with_timeout(
 @contextmanager
 def _acquire_flock_with_timeout(
     flock: FileLock,
-    timeout: Optional[float] = None,
+    timeout: float | None = None,
 ) -> Generator[None, None, None]:
     """Context manager that safely acquires a FileLock with timeout and automatically releases it.
@@ -141,9 +139,7 @@ def _acquire_flock_with_timeout(
         flock.release()


-def _unsafe_acquire_flock_with_timeout(
-    flock: FileLock, timeout: Optional[float]
-) -> None:
+def _unsafe_acquire_flock_with_timeout(flock: FileLock, timeout: float | None) -> None:
    """Acquire a FileLock with timeout without automatic release (unsafe).

    This function acquires a file lock with timeout support but does NOT automatically
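As the docstrings above describe, the context-manager helpers acquire with a timeout and always release on exit, while the `_unsafe_*` variants leave release to the caller. A minimal sketch of the safe acquire-with-timeout pattern (a standalone illustration with invented names; the PyTorch helpers differ in details such as the error raised):

from contextlib import contextmanager
from threading import Lock
from typing import Generator


@contextmanager
def acquire_with_timeout(lock: Lock, timeout: float | None = None) -> Generator[None, None, None]:
    # threading.Lock.acquire treats timeout=-1 as "block forever",
    # so map None onto that sentinel.
    if not lock.acquire(timeout=-1 if timeout is None else timeout):
        raise TimeoutError(f"failed to acquire lock within {timeout}s")
    try:
        yield
    finally:
        lock.release()


# Usage: the lock is released even if the body raises.
lock = Lock()
with acquire_with_timeout(lock, timeout=0.5):
    ...  # critical section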

View File

@@ -2,7 +2,7 @@
 import copy
 import itertools
 import logging
-from typing import Callable, Optional, TYPE_CHECKING
+from typing import Callable, TYPE_CHECKING

 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -257,7 +257,7 @@ class CoordescTuner:
         self,
         func: Callable[["triton.Config"], float],
         baseline_config: "triton.Config",
-        baseline_timing: Optional[float] = None,
+        baseline_timing: float | None = None,
     ) -> "triton.Config":
         if baseline_timing is None:
             baseline_timing = self.call_func(func, baseline_config)

View File

@@ -5,7 +5,7 @@ import collections
 import functools
 import typing
 from enum import auto, Enum
-from typing import Optional, Union
+from typing import Union

 from torch.utils._triton import has_triton_package
@@ -130,10 +130,10 @@ class DeviceProperties(typing.NamedTuple):
     index: int  # type: ignore[assignment]
     multi_processor_count: int
     cc: int
-    major: Optional[int] = None
-    regs_per_multiprocessor: Optional[int] = None
-    max_threads_per_multi_processor: Optional[int] = None
-    warp_size: Optional[int] = None
+    major: int | None = None
+    regs_per_multiprocessor: int | None = None
+    max_threads_per_multi_processor: int | None = None
+    warp_size: int | None = None

     @classmethod
     @functools.cache
@@ -174,10 +174,10 @@ class DeviceProperties(typing.NamedTuple):
 class HalideInputSpec(typing.NamedTuple):
     ctype: str
     name: str
-    shape: Optional[list[str]] = None
-    stride: Optional[list[str]] = None
-    offset: Optional[str] = None
-    alias_of: Optional[str] = None
+    shape: list[str] | None = None
+    stride: list[str] | None = None
+    offset: str | None = None
+    alias_of: str | None = None

     def bindings_type(self) -> str:
         if self.ctype in ("at::Half*", "at::BFloat16*"):
@@ -201,9 +201,9 @@ class HalideInputSpec(typing.NamedTuple):
 class HalideMeta(typing.NamedTuple):
     argtypes: list[HalideInputSpec]
     target: str
-    scheduler: Optional[str] = None
-    scheduler_flags: Optional[dict[str, Union[int, str]]] = None
-    cuda_device: Optional[int] = None
+    scheduler: str | None = None
+    scheduler_flags: dict[str, Union[int, str]] | None = None
+    cuda_device: int | None = None

     def args(self) -> list[str]:
         """Command line args to pass to halide generator"""

View File

@@ -1,6 +1,6 @@
 import functools
 import os
-from typing import Any, Optional
+from typing import Any
 from typing_extensions import Unpack

 from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs
@@ -92,9 +92,7 @@ class StaticallyLaunchedCudaKernel:
         self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size")

         self.arg_tys = self.arg_ty_from_signature(kernel.src)
-        self.function: Optional[int] = (
-            None  # Loaded by load_kernel(on the parent process)
-        )
+        self.function: int | None = None  # Loaded by load_kernel(on the parent process)
         num_ctas = 1
         if hasattr(kernel, "num_ctas"):
             num_ctas = kernel.num_ctas

View File

@@ -18,16 +18,7 @@ import sys
 import threading
 import time
 from collections import namedtuple
-from typing import (
-    Any,
-    Callable,
-    Generic,
-    Literal,
-    Optional,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Generic, Literal, TYPE_CHECKING, TypeVar, Union

 import torch
 from torch._dynamo.utils import counters, set_feature_use
@@ -119,7 +110,7 @@ def generate_lookup_hash_from_source_code(size_hints_str: str, source_code: str)
     return fn_hash


-def lookup_autotune_config(size_hints, fn) -> Optional[Config]:
+def lookup_autotune_config(size_hints, fn) -> Config | None:
     lookup_table = torch._inductor.config.autotune_lookup_table
     cached_config = None
     if len(lookup_table) > 0 and "_fused_" in fn.src:
@@ -157,7 +148,7 @@ def autotune_hints_to_configs(
     Based on those hints, this function will generate a list of additional autotuning
     configs to try.
     """
-    xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
+    xyz_options: tuple[tuple[int, int | None, int | None], ...]
     configs: list[Config] = []
     for hint in hints:
         if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
@@ -217,8 +208,8 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
 def check_autotune_cache(
-    configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
-) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
+    configs: list[Config], filename: str | None, inductor_meta: dict[str, Any]
+) -> tuple[list[Config], AutotuneCache | None, dict[str, Any]]:
     """
     Given a list of configs, checks autotune cache and return metadata
     """
@@ -285,9 +276,9 @@ class CachingAutotuner(KernelInterface):
         size_hints=None,
         inductor_meta=None,  # metadata not relevant to triton
         custom_kernel=False,  # whether the kernel is inductor-generated or custom
-        filename: Optional[str] = None,
-        reset_to_zero_arg_names: Optional[list[str]] = None,
-        autotune_cache_info: Optional[dict[str, Any]] = None,
+        filename: str | None = None,
+        reset_to_zero_arg_names: list[str] | None = None,
+        autotune_cache_info: dict[str, Any] | None = None,
     ):
         super().__init__()
@@ -367,7 +358,7 @@ class CachingAutotuner(KernelInterface):
         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

         # Compile-time info included in runtime logginging
-        self.compile_id: Optional[CompileId] = None
+        self.compile_id: CompileId | None = None
         self.is_backward = False

         # Mode for launch grid calculation
@@ -419,17 +410,15 @@ class CachingAutotuner(KernelInterface):
             self.fn = reload_kernel_from_src().fn
         self.compile_results = [self._precompile_config(best_config)]

-    def set_compile_info(
-        self, compile_id: Optional[CompileId], is_backward: bool
-    ) -> None:
+    def set_compile_info(self, compile_id: CompileId | None, is_backward: bool) -> None:
         self.compile_id = compile_id
         self.is_backward = is_backward

     def precompile(
         self,
         warm_cache_only=False,
-        reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
-        static_triton_bundle_key: Optional[str] = None,
+        reload_kernel: Callable[[], CachingAutotuner] | None = None,
+        static_triton_bundle_key: str | None = None,
     ):
         if warm_cache_only:
             self._precompile_worker()
@@ -492,7 +481,7 @@ class CachingAutotuner(KernelInterface):
         assert device_prop.regs_per_multiprocessor
         assert device_prop.max_threads_per_multi_processor
         assert device_prop.multi_processor_count
-        seen_config_hashes: Optional[OrderedSet[Hashable]] = None
+        seen_config_hashes: OrderedSet[Hashable] | None = None
         warp_size = device_prop.warp_size or 32

         for result in self.compile_results:
             triton_config = result.config
@@ -638,7 +627,7 @@ class CachingAutotuner(KernelInterface):
         return old_values

     def restore_after_unpickle(
-        self, old_values: Optional[tuple[Any, Any, Any, Any, Any, Any]]
+        self, old_values: tuple[Any, Any, Any, Any, Any, Any] | None
     ) -> None:
         if old_values:
             (
@@ -1322,7 +1311,7 @@ class CachingAutotuner(KernelInterface):
     ):  # type:ignore[override]
         if hasattr(triton, "set_allocator"):

-            def alloc_fn(size: int, align: int, stream: Optional[int]):
+            def alloc_fn(size: int, align: int, stream: int | None):
                 return torch.empty(
                     size, dtype=torch.int8, device=self.device_props.type
                 )
@@ -1571,7 +1560,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
         inductor_meta: dict[str, Any],
         triton_meta: dict[str, Any],
         heuristic_type: HeuristicType,
-    ) -> Optional[StaticallyLaunchedCudaKernel]:
+    ) -> StaticallyLaunchedCudaKernel | None:
         if not torch._inductor.config.use_static_cuda_launcher:
             return None
@@ -1932,12 +1921,12 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
         # in AMD's Triton backend, the global scratch size is never provided
         # (but for AMD it's safe to pass an extra null arg, so always include it)
-        global_scratch: Optional[int] = getattr(
+        global_scratch: int | None = getattr(
             kernel_metadata,
             "global_scratch_size",
             (0 if torch.version.hip else None),
         )
-        profile_scratch: Optional[int] = getattr(
+        profile_scratch: int | None = getattr(
             kernel_metadata, "profile_scratch_size", None
         )
         launcher.global_scratch = global_scratch
@@ -2091,7 +2080,7 @@ def hash_configs(configs: list[Config]):
 def cached_autotune(
-    size_hints: Optional[list[int]],
+    size_hints: list[int] | None,
     configs: list[Config],
     triton_meta,
     heuristic_type,