Codemod codecache.py from Optional to union none (#165604)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165604
Approved by: https://github.com/aorenste
Oguz Ulgen authored on 2025-10-15 19:29:50 -07:00; committed by PyTorch MergeBot
parent 66b75693ae
commit f6daffc54d

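The pattern throughout the diff is mechanical: every Optional[X] annotation becomes the equivalent PEP 604 spelling X | None, while Union[...] annotations are left untouched. A minimal sketch of the before/after shape (toy functions for illustration, not the real LocalCache.lookup):

from typing import Any, Optional

# Before: the typing.Optional spelling.
def lookup_before(*keys: str) -> Optional[dict[str, Any]]:
    return None

# After: the equivalent PEP 604 union spelling. X | None in annotations
# evaluated at runtime needs Python 3.10+; on older interpreters it works
# under "from __future__ import annotations".
def lookup_after(*keys: str) -> dict[str, Any] | None:
    return None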

@@ -34,17 +34,7 @@ from pathlib import Path
 from tempfile import _TemporaryFileWrapper
 from time import time, time_ns
 from types import ModuleType
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Generic,
-    NoReturn,
-    Optional,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import override, Self
 
 import torch
@@ -258,7 +248,7 @@ class CacheBase:
 
 
 class LocalCache(CacheBase):
-    def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
+    def lookup(self, *keys: str) -> dict[str, Any] | None:
         cache = self.get_local_cache()
 
         sub_cache = cache
@@ -288,8 +278,8 @@ class PersistentCache(CacheBase):
         choices: list[ChoiceCaller],
         op: str,
         inputs: str,
-        benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
-        hint_override: Optional[int] = None,
+        benchmark: Callable[[Any], dict[ChoiceCaller, float]] | None,
+        hint_override: int | None = None,
     ) -> dict[ChoiceCaller, float]:
         """
         Check to see if we have benchmarked the given choice callers. For each
@@ -424,7 +414,7 @@ def write(
     extra: str = "",
     hash_type: str = "code",
     specified_dir: str = "",
-    key: Optional[str] = None,
+    key: str | None = None,
 ) -> tuple[str, str]:
     if key is None:
         # use striped content to compute hash so we don't end up with different
@@ -937,7 +927,7 @@ class FxGraphHashDetails:
     # - if any of them are set to custom callables, we will need to cache miss
     # Future work is for someone to find any places where these functions are used
    # and force them to be of type CustomGraphPass, so we can guarantee serialization.
-    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
+    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Any | None:
         if not custom_pass:
             return None
         if isinstance(custom_pass, list):
@@ -954,7 +944,7 @@ class FxGraphHashDetails:
 
     def _get_custom_pass_detail(
         self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
-    ) -> Optional[Any]:
+    ) -> Any | None:
         if not custom_pass:
             return None
         assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
@@ -962,7 +952,7 @@ class FxGraphHashDetails:
 
     def _get_custom_partitioner_fn_detail(
         self, custom_partitioner_fn: CustomPartitionerFnType
-    ) -> Optional[Any]:
+    ) -> Any | None:
         if not custom_partitioner_fn:
             return None
         assert isinstance(custom_partitioner_fn, CustomPartitionerFn)
@@ -1032,7 +1022,7 @@ class GuardedCache(Generic[T]):
     def iterate_over_candidates(
         cls: type[GuardedCache[T]],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: RemoteCache[JsonDataTy] | None,
         key: str,
     ) -> Generator[tuple[T, bytes], None, None]:
         if local:
@@ -1067,10 +1057,10 @@ class GuardedCache(Generic[T]):
         cls: type[GuardedCache[T]],
         key: str,
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: RemoteCache[JsonDataTy] | None,
         evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
         hints: list[int],
-    ) -> tuple[Optional[T], Optional[bytes], dict[str, str]]:
+    ) -> tuple[T | None, bytes | None, dict[str, str]]:
         """
         Find the first cache entry in iterate_over_candidates that passes `evaluate_guards`.
@@ -1134,7 +1124,7 @@ class GuardedCache(Generic[T]):
         return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]
 
     @classmethod
-    def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]:
+    def _get_shape_env(cls: type[GuardedCache[T]]) -> ShapeEnv | None:
         """
         Helper to get the shape env from the tracing context.
         """
@@ -1205,7 +1195,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         graph: CompiledFxGraph,
         cache_info: dict[str, Any],
         constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+    ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
         Cache specific post compile steps that need to run if we find a graph in the cache
         This includes putting bundled triton artifacts in the right place,
@@ -1300,12 +1290,11 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         key: str,
         example_inputs: Sequence[InputType],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: RemoteCache[JsonDataTy] | None,
         constants: CompiledFxGraphConstants,
-        evaluate_guards: Optional[
-            Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
-        ] = None,
-    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
+        | None = None,
+    ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
         Lookup a compiled graph in the cache by key. On a hit, return the
         deserialized CompiledFxGraph object. On a miss, return None.
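A formatting subtlety in the hunk above: where Optional wrapped a type too long for one line, the codemod flattens it into a multi-line binary union, relying on the enclosing signature's parentheses for line continuation:

# Before
evaluate_guards: Optional[
    Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
] = None,

# After: the trailing "| None = None" continues the same expression.
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
| None = None,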
@@ -1373,7 +1362,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         compiled_graph: OutputCode,
         example_inputs: Sequence[InputType],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: RemoteCache[JsonDataTy] | None,
     ) -> None:
         """
         Store a serialized CompiledFxGraph on disk.
@@ -1502,7 +1491,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
         remote: bool,
-    ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
+    ) -> tuple[tuple[str, list[str]] | None, dict[str, Any]]:
         """
         Checks that the inductor input is cacheable, then computes
         and returns the cache key for the input.
@@ -1533,7 +1522,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         return (key, debug_lines), {}
 
     @staticmethod
-    def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
+    def get_remote_cache() -> RemoteCache[JsonDataTy] | None:
         """
         Attempts to load the remote cache, returns None on error.
         """
@@ -1551,13 +1540,12 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         debug_lines: list[str],
         example_inputs: Sequence[InputType],
         local: bool,
-        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        remote_cache: RemoteCache[JsonDataTy] | None,
         is_backward: bool,
         constants: CompiledFxGraphConstants,
-        evaluate_guards: Optional[
-            Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
-        ] = None,
-    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
+        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
+        | None = None,
+    ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
         Lookup the graph with the given key, and return results and metadata.
         Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1655,11 +1643,11 @@ class CudaKernelParamCache:
     def set(
         cls,
         key: str,
-        params: dict[str, Optional[str]],
+        params: dict[str, str | None],
         cubin: str,
         bin_type: str,
-        asm: Optional[str] = None,
-        asm_type: Optional[str] = None,
+        asm: str | None = None,
+        asm_type: str | None = None,
     ) -> None:
         basename = None
         if config.aot_inductor.package_cpp_only:
@@ -1712,7 +1700,7 @@ class CudaKernelParamCache:
         cls.cache[key] = params
 
     @classmethod
-    def get(cls, key: str) -> Optional[dict[str, Any]]:
+    def get(cls, key: str) -> dict[str, Any] | None:
         return cls.cache.get(key, None)
 
     @classmethod
@@ -1731,7 +1719,7 @@ class AotCodeCompiler:
         graph: GraphLowering,
         wrapper_code: str,
         kernel_code: str,
-        serialized_extern_kernel_nodes: Optional[str],
+        serialized_extern_kernel_nodes: str | None,
         *,
         device_type: str,
         additional_files: list[str],
@@ -2564,7 +2552,7 @@ end
         return output_so
 
 
-_libgomp: Optional[CDLL] = None
+_libgomp: CDLL | None = None
 
 
 def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]:
@@ -2687,7 +2675,7 @@ def _precompile_header(
     return header_full_path
 
 
-def _get_cpp_prefix_header(device: str) -> Optional[str]:
+def _get_cpp_prefix_header(device: str) -> str | None:
     if device.startswith("cpu"):
         return "torch/csrc/inductor/cpp_prefix.h"
     return None
@@ -2755,7 +2743,7 @@ class CppCodeCache:
         device_type: str = "cpu",
         submit_fn: Any = None,
         extra_flags: Sequence[str] = (),
-        optimized_code: Optional[str] = None,
+        optimized_code: str | None = None,
     ) -> Any:
         """Compile and load a C++ library. Returns a callable that returns the loaded
         library."""
@@ -2814,7 +2802,7 @@ class CppCodeCache:
         from torch.utils._filelock import FileLock
 
         lock_path = os.path.join(get_lock_dir(), key + ".lock")
-        future: Optional[Future[Any]] = None
+        future: Future[Any] | None = None
         lib = None
 
         # if requested, pre-compile any headers
@@ -3053,7 +3041,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
         num_outputs: int = -1,
         submit_fn: Any = None,
         extra_flags: Sequence[str] = (),
-        kernel_code: Optional[str] = None,
+        kernel_code: str | None = None,
     ) -> Any:
         """
         Wrap a C++ function in fast Python bindings.
@@ -3175,7 +3163,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
 class HalideCodeCache(CppPythonBindingsCodeCache):
     cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
     cache_clear = staticmethod(cache.clear)
-    _standalone_runtime_path: Optional[str] = None
+    _standalone_runtime_path: str | None = None
     prefix = textwrap.dedent(
         """
         #include "{halideruntime_h}"
@@ -3606,8 +3594,8 @@ class PyCodeCache:
         cls,
         key: str,
         path: str,
-        linemap: Optional[list[tuple[int, str]]] = None,
-        attrs: Optional[dict[str, Any]] = None,
+        linemap: list[tuple[int, str]] | None = None,
+        attrs: dict[str, Any] | None = None,
     ) -> ModuleType:
         if linemap is None:
             linemap = []
@@ -3655,7 +3643,7 @@ class PyCodeCache:
     @functools.cache
     def stack_frames_for_code(
         cls, path: str, lineno: int
-    ) -> Optional[list[dict[str, Any]]]:
+    ) -> list[dict[str, Any]] | None:
         if path not in cls.linemaps:
             return None
         if len(cls.linemaps[path]) == 0:
@@ -3688,7 +3676,7 @@ def _load_triton_kernel_from_source(
     return getattr(PyCodeCache.load(source_code), kernel_name)
 
 
-def _cuda_compiler() -> Optional[str]:
+def _cuda_compiler() -> str | None:
     if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
         return config.cuda.cuda_cxx
     if config.is_fbcode():
@@ -3855,7 +3843,7 @@ def cuda_compile_command(
     src_files: list[str],
     dst_file: str,
     dst_file_ext: str,
-    extra_args: Optional[list[str]] = None,
+    extra_args: list[str] | None = None,
 ) -> str:
     if extra_args is None:
         extra_args = []
@@ -3993,7 +3981,7 @@ class CUDACodeCache:
     class CacheEntry:
         input_path: str
         output_path: str
-        error_json: Optional[str] = None
+        error_json: str | None = None
 
     cache: dict[str, CacheEntry] = {}
     aot_kernels_o: list[str] = []
@@ -4008,7 +3996,7 @@ class CUDACodeCache:
     @lru_cache(maxsize=4)
     def get_kernel_binary_remote_cache(
         caching_enabled: bool, caching_available: bool
-    ) -> Optional[Any]:
+    ) -> Any | None:
         """
         Get or create the class instance of the CUTLASSKernelBinaryRemoteCache.
 
@@ -4069,7 +4057,7 @@ class CUDACodeCache:
 
     @classmethod
     def compile(
-        cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
+        cls, source_code: str, dst_file_ext: str, extra_args: list[str] | None = None
    ) -> tuple[str, str, str]:
         """
         Compiles CUDA source_code into a file with dst_file_ext extension.
@@ -4279,7 +4267,7 @@ class ROCmCodeCache:
 
     @classmethod
     def compile(
-        cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
+        cls, source_code: str, dst_file_ext: str, extra_args: list[str] | None = None
     ) -> tuple[str, str, str]:
         """
         Compiles source_code into a file with dst_file_ext extension,
@@ -4352,7 +4340,7 @@ class CodeCacheFuture:
 
 class LambdaFuture(CodeCacheFuture):
     def __init__(
-        self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None
+        self, result_fn: Callable[..., Any], future: Future[Any] | None = None
     ) -> None:
         self.result_fn = result_fn
         self.future = future
@@ -4373,7 +4361,7 @@ class StaticAutotunerFuture(CodeCacheFuture):
         # we need to reload the CachingAutotuner from its source code
         # We don't store the source code on the CachingAutotuner itself
         # since it can be very large.
-        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
+        self.reload_kernel_from_src: Callable[[], Any] | None = None
 
     def result(self) -> CachingAutotuner:
         assert self.reload_kernel_from_src is not None
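A rewrite of this size is normally produced by a tool rather than by hand. The PR does not say which tool generated it, so the following is only an illustrative sketch of the transformation using libcst; the class name and driver snippet are made up for the example:

import libcst as cst


class OptionalToUnionNone(cst.CSTTransformer):
    """Rewrite Optional[X] subscripts into the PEP 604 form X | None."""

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.BaseExpression:
        # Only rewrite bare Optional[...]; a production codemod would also
        # match typing.Optional[...] and drop the now-unused import.
        if (
            isinstance(updated_node.value, cst.Name)
            and updated_node.value.value == "Optional"
            and len(updated_node.slice) == 1
            and isinstance(updated_node.slice[0].slice, cst.Index)
        ):
            inner = updated_node.slice[0].slice.value
            return cst.BinaryOperation(
                left=inner, operator=cst.BitOr(), right=cst.Name("None")
            )
        return updated_node


source = "def lookup(*keys: str) -> Optional[dict[str, Any]]: ..."
print(cst.parse_module(source).visit(OptionalToUnionNone()).code)
# def lookup(*keys: str) -> dict[str, Any] | None: ...

Because leave_Subscript fires bottom-up (children before parents), nested occurrences such as dict[str, Optional[str]] come out as dict[str, str | None], matching the CudaKernelParamCache.set change above.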