mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Codemod codecache.py from Optional to union none (#165604)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165604 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
66b75693ae
commit
f6daffc54d
@ -34,17 +34,7 @@ from pathlib import Path
|
||||
from tempfile import _TemporaryFileWrapper
|
||||
from time import time, time_ns
|
||||
from types import ModuleType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Generic,
|
||||
NoReturn,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import override, Self
|
||||
|
||||
import torch
|
||||
@ -258,7 +248,7 @@ class CacheBase:
|
||||
|
||||
|
||||
class LocalCache(CacheBase):
|
||||
def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
|
||||
def lookup(self, *keys: str) -> dict[str, Any] | None:
|
||||
cache = self.get_local_cache()
|
||||
|
||||
sub_cache = cache
|
||||
@ -288,8 +278,8 @@ class PersistentCache(CacheBase):
|
||||
choices: list[ChoiceCaller],
|
||||
op: str,
|
||||
inputs: str,
|
||||
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
|
||||
hint_override: Optional[int] = None,
|
||||
benchmark: Callable[[Any], dict[ChoiceCaller, float]] | None,
|
||||
hint_override: int | None = None,
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
"""
|
||||
Check to see if we have benchmarked the given choice callers. For each
|
||||
@ -424,7 +414,7 @@ def write(
|
||||
extra: str = "",
|
||||
hash_type: str = "code",
|
||||
specified_dir: str = "",
|
||||
key: Optional[str] = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
if key is None:
|
||||
# use striped content to compute hash so we don't end up with different
|
||||
@ -937,7 +927,7 @@ class FxGraphHashDetails:
|
||||
# - if any of them are set to custom callables, we will need to cache miss
|
||||
# Future work is for someone to find any places where these functions are used
|
||||
# and force them to be of type CustomGraphPass, so we can guarantee serialization.
|
||||
def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
|
||||
def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Any | None:
|
||||
if not custom_pass:
|
||||
return None
|
||||
if isinstance(custom_pass, list):
|
||||
@ -954,7 +944,7 @@ class FxGraphHashDetails:
|
||||
|
||||
def _get_custom_pass_detail(
|
||||
self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
|
||||
) -> Optional[Any]:
|
||||
) -> Any | None:
|
||||
if not custom_pass:
|
||||
return None
|
||||
assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
|
||||
@ -962,7 +952,7 @@ class FxGraphHashDetails:
|
||||
|
||||
def _get_custom_partitioner_fn_detail(
|
||||
self, custom_partitioner_fn: CustomPartitionerFnType
|
||||
) -> Optional[Any]:
|
||||
) -> Any | None:
|
||||
if not custom_partitioner_fn:
|
||||
return None
|
||||
assert isinstance(custom_partitioner_fn, CustomPartitionerFn)
|
||||
@ -1032,7 +1022,7 @@ class GuardedCache(Generic[T]):
|
||||
def iterate_over_candidates(
|
||||
cls: type[GuardedCache[T]],
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
key: str,
|
||||
) -> Generator[tuple[T, bytes], None, None]:
|
||||
if local:
|
||||
@ -1067,10 +1057,10 @@ class GuardedCache(Generic[T]):
|
||||
cls: type[GuardedCache[T]],
|
||||
key: str,
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
|
||||
hints: list[int],
|
||||
) -> tuple[Optional[T], Optional[bytes], dict[str, str]]:
|
||||
) -> tuple[T | None, bytes | None, dict[str, str]]:
|
||||
"""
|
||||
Find the first cache entry in iterate_over_candidates that passes `evaluate_guards`.
|
||||
|
||||
@ -1134,7 +1124,7 @@ class GuardedCache(Generic[T]):
|
||||
return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]
|
||||
|
||||
@classmethod
|
||||
def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]:
|
||||
def _get_shape_env(cls: type[GuardedCache[T]]) -> ShapeEnv | None:
|
||||
"""
|
||||
Helper to get the shape env from the tracing context.
|
||||
"""
|
||||
@ -1205,7 +1195,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
graph: CompiledFxGraph,
|
||||
cache_info: dict[str, Any],
|
||||
constants: CompiledFxGraphConstants,
|
||||
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
|
||||
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
|
||||
"""
|
||||
Cache specific post compile steps that need to run if we find a graph in the cache
|
||||
This includes putting bundled triton artifacts in the right place,
|
||||
@ -1300,12 +1290,11 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
key: str,
|
||||
example_inputs: Sequence[InputType],
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
constants: CompiledFxGraphConstants,
|
||||
evaluate_guards: Optional[
|
||||
Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
|
||||
] = None,
|
||||
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
|
||||
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
|
||||
| None = None,
|
||||
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
|
||||
"""
|
||||
Lookup a compiled graph in the cache by key. On a hit, return the
|
||||
deserialized CompiledFxGraph object. On a miss, return None.
|
||||
@ -1373,7 +1362,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
compiled_graph: OutputCode,
|
||||
example_inputs: Sequence[InputType],
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
) -> None:
|
||||
"""
|
||||
Store a serialized CompiledFxGraph on disk.
|
||||
@ -1502,7 +1491,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
fx_kwargs: _CompileFxKwargs,
|
||||
inputs_to_check: Sequence[int],
|
||||
remote: bool,
|
||||
) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
|
||||
) -> tuple[tuple[str, list[str]] | None, dict[str, Any]]:
|
||||
"""
|
||||
Checks that the inductor input is cacheable, then computes
|
||||
and returns the cache key for the input.
|
||||
@ -1533,7 +1522,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
return (key, debug_lines), {}
|
||||
|
||||
@staticmethod
|
||||
def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
|
||||
def get_remote_cache() -> RemoteCache[JsonDataTy] | None:
|
||||
"""
|
||||
Attempts to load the remote cache, returns None on error.
|
||||
"""
|
||||
@ -1551,13 +1540,12 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
debug_lines: list[str],
|
||||
example_inputs: Sequence[InputType],
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
is_backward: bool,
|
||||
constants: CompiledFxGraphConstants,
|
||||
evaluate_guards: Optional[
|
||||
Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
|
||||
] = None,
|
||||
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
|
||||
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
|
||||
| None = None,
|
||||
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
|
||||
"""
|
||||
Lookup the graph with the given key, and return results and metadata.
|
||||
Doesn't do any logging on its own, because AOTAutograd handles a cache miss
|
||||
@ -1655,11 +1643,11 @@ class CudaKernelParamCache:
|
||||
def set(
|
||||
cls,
|
||||
key: str,
|
||||
params: dict[str, Optional[str]],
|
||||
params: dict[str, str | None],
|
||||
cubin: str,
|
||||
bin_type: str,
|
||||
asm: Optional[str] = None,
|
||||
asm_type: Optional[str] = None,
|
||||
asm: str | None = None,
|
||||
asm_type: str | None = None,
|
||||
) -> None:
|
||||
basename = None
|
||||
if config.aot_inductor.package_cpp_only:
|
||||
@ -1712,7 +1700,7 @@ class CudaKernelParamCache:
|
||||
cls.cache[key] = params
|
||||
|
||||
@classmethod
|
||||
def get(cls, key: str) -> Optional[dict[str, Any]]:
|
||||
def get(cls, key: str) -> dict[str, Any] | None:
|
||||
return cls.cache.get(key, None)
|
||||
|
||||
@classmethod
|
||||
@ -1731,7 +1719,7 @@ class AotCodeCompiler:
|
||||
graph: GraphLowering,
|
||||
wrapper_code: str,
|
||||
kernel_code: str,
|
||||
serialized_extern_kernel_nodes: Optional[str],
|
||||
serialized_extern_kernel_nodes: str | None,
|
||||
*,
|
||||
device_type: str,
|
||||
additional_files: list[str],
|
||||
@ -2564,7 +2552,7 @@ end
|
||||
return output_so
|
||||
|
||||
|
||||
_libgomp: Optional[CDLL] = None
|
||||
_libgomp: CDLL | None = None
|
||||
|
||||
|
||||
def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]:
|
||||
@ -2687,7 +2675,7 @@ def _precompile_header(
|
||||
return header_full_path
|
||||
|
||||
|
||||
def _get_cpp_prefix_header(device: str) -> Optional[str]:
|
||||
def _get_cpp_prefix_header(device: str) -> str | None:
|
||||
if device.startswith("cpu"):
|
||||
return "torch/csrc/inductor/cpp_prefix.h"
|
||||
return None
|
||||
@ -2755,7 +2743,7 @@ class CppCodeCache:
|
||||
device_type: str = "cpu",
|
||||
submit_fn: Any = None,
|
||||
extra_flags: Sequence[str] = (),
|
||||
optimized_code: Optional[str] = None,
|
||||
optimized_code: str | None = None,
|
||||
) -> Any:
|
||||
"""Compile and load a C++ library. Returns a callable that returns the loaded
|
||||
library."""
|
||||
@ -2814,7 +2802,7 @@ class CppCodeCache:
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
lock_path = os.path.join(get_lock_dir(), key + ".lock")
|
||||
future: Optional[Future[Any]] = None
|
||||
future: Future[Any] | None = None
|
||||
lib = None
|
||||
|
||||
# if requested, pre-compile any headers
|
||||
@ -3053,7 +3041,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
num_outputs: int = -1,
|
||||
submit_fn: Any = None,
|
||||
extra_flags: Sequence[str] = (),
|
||||
kernel_code: Optional[str] = None,
|
||||
kernel_code: str | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Wrap a C++ function in fast Python bindings.
|
||||
@ -3175,7 +3163,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||
class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
_standalone_runtime_path: Optional[str] = None
|
||||
_standalone_runtime_path: str | None = None
|
||||
prefix = textwrap.dedent(
|
||||
"""
|
||||
#include "{halideruntime_h}"
|
||||
@ -3606,8 +3594,8 @@ class PyCodeCache:
|
||||
cls,
|
||||
key: str,
|
||||
path: str,
|
||||
linemap: Optional[list[tuple[int, str]]] = None,
|
||||
attrs: Optional[dict[str, Any]] = None,
|
||||
linemap: list[tuple[int, str]] | None = None,
|
||||
attrs: dict[str, Any] | None = None,
|
||||
) -> ModuleType:
|
||||
if linemap is None:
|
||||
linemap = []
|
||||
@ -3655,7 +3643,7 @@ class PyCodeCache:
|
||||
@functools.cache
|
||||
def stack_frames_for_code(
|
||||
cls, path: str, lineno: int
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if path not in cls.linemaps:
|
||||
return None
|
||||
if len(cls.linemaps[path]) == 0:
|
||||
@ -3688,7 +3676,7 @@ def _load_triton_kernel_from_source(
|
||||
return getattr(PyCodeCache.load(source_code), kernel_name)
|
||||
|
||||
|
||||
def _cuda_compiler() -> Optional[str]:
|
||||
def _cuda_compiler() -> str | None:
|
||||
if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
|
||||
return config.cuda.cuda_cxx
|
||||
if config.is_fbcode():
|
||||
@ -3855,7 +3843,7 @@ def cuda_compile_command(
|
||||
src_files: list[str],
|
||||
dst_file: str,
|
||||
dst_file_ext: str,
|
||||
extra_args: Optional[list[str]] = None,
|
||||
extra_args: list[str] | None = None,
|
||||
) -> str:
|
||||
if extra_args is None:
|
||||
extra_args = []
|
||||
@ -3993,7 +3981,7 @@ class CUDACodeCache:
|
||||
class CacheEntry:
|
||||
input_path: str
|
||||
output_path: str
|
||||
error_json: Optional[str] = None
|
||||
error_json: str | None = None
|
||||
|
||||
cache: dict[str, CacheEntry] = {}
|
||||
aot_kernels_o: list[str] = []
|
||||
@ -4008,7 +3996,7 @@ class CUDACodeCache:
|
||||
@lru_cache(maxsize=4)
|
||||
def get_kernel_binary_remote_cache(
|
||||
caching_enabled: bool, caching_available: bool
|
||||
) -> Optional[Any]:
|
||||
) -> Any | None:
|
||||
"""
|
||||
Get or create the class instance of the CUTLASSKernelBinaryRemoteCache.
|
||||
|
||||
@ -4069,7 +4057,7 @@ class CUDACodeCache:
|
||||
|
||||
@classmethod
|
||||
def compile(
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: list[str] | None = None
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Compiles CUDA source_code into a file with dst_file_ext extension.
|
||||
@ -4279,7 +4267,7 @@ class ROCmCodeCache:
|
||||
|
||||
@classmethod
|
||||
def compile(
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: list[str] | None = None
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Compiles source_code into a file with dst_file_ext extension,
|
||||
@ -4352,7 +4340,7 @@ class CodeCacheFuture:
|
||||
|
||||
class LambdaFuture(CodeCacheFuture):
|
||||
def __init__(
|
||||
self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None
|
||||
self, result_fn: Callable[..., Any], future: Future[Any] | None = None
|
||||
) -> None:
|
||||
self.result_fn = result_fn
|
||||
self.future = future
|
||||
@ -4373,7 +4361,7 @@ class StaticAutotunerFuture(CodeCacheFuture):
|
||||
# we need to reload the CachingAutotuner from its source code
|
||||
# We don't store the source code on the CachingAutotuner itself
|
||||
# since it can be very large.
|
||||
self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
|
||||
self.reload_kernel_from_src: Callable[[], Any] | None = None
|
||||
|
||||
def result(self) -> CachingAutotuner:
|
||||
assert self.reload_kernel_from_src is not None
|
||||
|
Reference in New Issue
Block a user