PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145202
Approved by: https://github.com/bobrenjc93
Authored by Aaron Orenstein on 2025-01-19 21:35:06 -08:00; committed by PyTorch MergeBot
parent 54a00af2c6
commit 805c4b597a
39 changed files with 482 additions and 511 deletions
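
For context, PEP 585 (Python 3.9+) allows the builtin container types to be used directly as generics, so annotations written with typing.List, typing.Dict, and typing.Tuple become list, dict, and tuple, while abstract collection types such as Sequence are imported from collections.abc. A minimal before/after sketch of the pattern applied throughout this diff (the function below is illustrative only, not taken from the PyTorch sources):

    # Before: typing generics (deprecated since Python 3.9 by PEP 585)
    #   from typing import Dict, List, Optional, Tuple
    #   def summarize(xs: List[int], opts: Dict[str, int]) -> Tuple[int, Optional[int]]:
    #       return sum(xs), opts.get("limit")

    # After: builtin generics; Callable, Optional, and Union remain in typing,
    # and abstract types such as Sequence move to collections.abc.
    from collections.abc import Sequence
    from typing import Optional

    def summarize(xs: list[int], opts: dict[str, int]) -> tuple[int, Optional[int]]:
        return sum(xs), opts.get("limit")

    def total(xs: Sequence[float]) -> float:
        return float(sum(xs))

Runtime behavior is unchanged; only imports and annotations differ, which is why the hunks below touch signatures and type aliases but not logic.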

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import itertools
from typing import Any, Callable, List
from typing import Any, Callable
import torch
import torch._prims_common as utils
@ -335,7 +335,7 @@ def generic_associative_scan(operator, leaves, dim=0):
def trace_associative_scan(
proxy_mode, func_overload, combine_fn: Callable, xs: List[torch.Tensor], dim: int
proxy_mode, func_overload, combine_fn: Callable, xs: list[torch.Tensor], dim: int
):
with disable_proxy_modes_tracing():
sample_xs = [first_slice_copy(x, dim) for x in itertools.chain(xs, xs)]
@ -415,7 +415,7 @@ def associative_scan_functionalize(ctx, combine_fn, xs, dim):
def _fake_associative_scan(combine_fn, xs, dim, reverse=False): # noqa: F811
inp_leaves, spec = pytree.tree_flatten(xs)
result_flat: List[Any] = []
result_flat: list[Any] = []
num_leaves = len(inp_leaves)
op = reversed if reverse else lambda x: x

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -32,7 +33,7 @@ class ViewInfo(ABC):
self.base_index = base_index
@abstractmethod
def regenerate_view(self, bases_list: List[Tensor]):
def regenerate_view(self, bases_list: list[Tensor]):
pass
@ -48,7 +49,7 @@ class AsStridedViewInfo(ViewInfo):
self.stride = stride
self.storage_offset = storage_offset
def regenerate_view(self, bases_list: List[Tensor]):
def regenerate_view(self, bases_list: list[Tensor]):
return torch.as_strided(
bases_list[self.base_index],
self.size,
@ -69,7 +70,7 @@ class SliceViewInfo(ViewInfo):
self.start = start
self.end = end
def regenerate_view(self, bases_list: List[Tensor]):
def regenerate_view(self, bases_list: list[Tensor]):
return torch.ops.aten.slice.Tensor(
bases_list[self.base_index], self.dim, self.start, self.end
)
@ -80,7 +81,7 @@ class AliasViewInfo(ViewInfo):
def __init__(self, base_index):
super().__init__(base_index)
def regenerate_view(self, bases_list: List[Tensor]):
def regenerate_view(self, bases_list: list[Tensor]):
return torch.ops.aten.alias.default(bases_list[self.base_index])
@ -89,7 +90,7 @@ class NotView(ViewInfo):
def __init__(self, base_index):
super().__init__(base_index)
def regenerate_view(self, bases_list: List[Tensor]):
def regenerate_view(self, bases_list: list[Tensor]):
return bases_list[self.base_index]
@ -137,10 +138,10 @@ def try_use_slice(base, tensor):
def write_view_information_to_args(
mutable_arg_names: List[str],
mutable_arg_types: List[torch.Type],
kwargs: Dict[str, Any],
arg_to_base_index: Dict[str, Any],
mutable_arg_names: list[str],
mutable_arg_types: list[torch.Type],
kwargs: dict[str, Any],
arg_to_base_index: dict[str, Any],
):
"""
This function writes the view information into kwargs. It reads mutable_args from kwargs.
@ -215,10 +216,10 @@ def write_view_information_to_args(
# Returns a dict of arg_name -> ViewInfo | [ViewInfo]
def read_view_information_from_args(
mutable_arg_names: List[str],
mutable_arg_types: List[torch.Type],
kwargs: Dict[str, Any],
all_bases: List[Tensor],
mutable_arg_names: list[str],
mutable_arg_types: list[torch.Type],
kwargs: dict[str, Any],
all_bases: list[Tensor],
):
"""
This reads the view information added by `write_view_information_to_args` from kwargs, pops them,
@ -254,7 +255,7 @@ def read_view_information_from_args(
# This means that the argument is the base tensor
return NotView(base_index)
args_view_info: Dict[str, Any] = {}
args_view_info: dict[str, Any] = {}
for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
if isinstance(arg_type, torch.ListType):
length = get_arg(f"_{arg_name}_length")
@ -321,7 +322,7 @@ class AutoFunctionalized(HigherOrderOperator):
/,
_mutable_op: OpOverload,
**kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
) -> tuple[Any, tuple[Tensor, ...]]:
assert can_auto_functionalize(_mutable_op)
assert isinstance(kwargs, dict)
return super().__call__(_mutable_op, **kwargs)
@ -350,7 +351,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
/,
_mutable_op: OpOverload,
**kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
) -> tuple[Any, tuple[Tensor, ...]]:
assert can_auto_functionalize(_mutable_op)
assert isinstance(kwargs, dict)
return super().__call__(_mutable_op, **kwargs)
@ -411,7 +412,7 @@ def can_auto_functionalize(op: OperatorBase) -> bool:
return True
def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:
def get_mutable_args(op: OpOverload) -> tuple[list[str], list[torch.Type]]:
"""
Returns the list of argument names that get mutated according to the
schema and their types.
@ -432,8 +433,8 @@ def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:
def do_auto_functionalize(
op: OpOverload,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
"""Functionalizes a call to op(*args, **kwargs) by emitting a call to
`outs = auto_functionalized(op, normalized_kwargs)`
@ -476,7 +477,7 @@ def do_auto_functionalize(
# List of the name of args that get mutated (according to the schema)
mutable_args_names, _ = get_mutable_args(op)
unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
unwrapped_actual_out: Union[Any, tuple[Any]] = unwrapped_outs[
: -len(mutable_args_names)
]
unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]
@ -521,8 +522,8 @@ def do_auto_functionalize(
def do_auto_functionalize_v2(
op: OpOverload,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
from torch._subclasses.functional_tensor import PythonFunctionalizeAPI
@ -553,7 +554,7 @@ def do_auto_functionalize_v2(
all_bases_addresses: list[int] = []
# Map arg_name to the index of its base in all_bases.
arg_to_base_index: Dict[str, Any] = {}
arg_to_base_index: dict[str, Any] = {}
def update_dict(tensor, arg_name, index=None):
base = tensor if get_base(tensor) is None else get_base(tensor)
@ -613,7 +614,7 @@ def do_auto_functionalize_v2(
op, **dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped) # type: ignore[arg-type]
)
unwrapped_actual_out: Union[Any, Tuple[Any]] = (
unwrapped_actual_out: Union[Any, tuple[Any]] = (
unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)]
)
@ -661,9 +662,9 @@ def do_auto_functionalize_v2(
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
_mutable_op: OpOverload,
_only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
_only_clone_these_tensors: Optional[tuple[str, ...]] = None,
**kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
) -> tuple[Any, tuple[Tensor, ...]]:
new_kwargs = dict(**kwargs)
result = []
@ -698,7 +699,7 @@ def auto_functionalized_fake(
mode,
_mutable_op: OpOverload,
**kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
) -> tuple[Any, tuple[Tensor, ...]]:
with mode:
result = auto_functionalized_dense(
_mutable_op, _only_clone_these_tensors=None, **kwargs
@ -711,7 +712,7 @@ def auto_functionalized_proxy(
mode,
_mutable_op: OpOverload,
**kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
) -> tuple[Any, tuple[Tensor, ...]]:
with disable_proxy_modes_tracing():
out = auto_functionalized(_mutable_op, **kwargs)
@ -738,10 +739,10 @@ def auto_functionalized_func(ctx, _mutable_op, **kwargs):
@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_v2_dense(
_mutable_op: OpOverload,
_only_clone_these_bases: Optional[Tuple[int, ...]] = None,
_only_clone_these_bases: Optional[tuple[int, ...]] = None,
**kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
all_bases: List[Tensor] = kwargs.pop("_all_bases", [])
) -> tuple[Any, tuple[Tensor, ...]]:
all_bases: list[Tensor] = kwargs.pop("_all_bases", [])
mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op)
args_view_info = read_view_information_from_args(
mutable_args_names, mutable_args_types, kwargs, all_bases
@ -794,8 +795,8 @@ def auto_functionalized_v2_dense(
def auto_functionalized_v2_fake(
mode,
_mutable_op: OpOverload,
**kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
**kwargs: dict[str, Any],
) -> tuple[Any, tuple[Tensor, ...]]:
with mode:
result = auto_functionalized_v2_dense(
_mutable_op, _only_clone_these_bases=None, **kwargs
@ -807,8 +808,8 @@ def auto_functionalized_v2_fake(
def auto_functionalized_v2_proxy(
mode,
_mutable_op: OpOverload,
**kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
**kwargs: dict[str, Any],
) -> tuple[Any, tuple[Tensor, ...]]:
with disable_proxy_modes_tracing():
out = auto_functionalized_v2(_mutable_op, **kwargs)

View File

@ -3,7 +3,7 @@
import contextlib
import logging
import warnings
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, Union
import torch
import torch._subclasses.functional_tensor
@ -71,7 +71,7 @@ def cond(
pred: Union[bool, int, float, torch.Tensor],
true_fn: Callable,
false_fn: Callable,
operands: Union[Tuple, List] = (),
operands: Union[tuple, list] = (),
) -> Any:
r"""
Conditionally applies `true_fn` or `false_fn`.

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Optional, Union
from weakref import WeakKeyDictionary
import torch
@ -71,9 +71,9 @@ class WithEffects(HigherOrderOperator):
self,
token,
op: OpType,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> Tuple[Any, ...]:
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
) -> tuple[Any, ...]:
assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
assert not has_aliasing(op), "Ops with aliasing are not supported"
assert has_effects(op, args, kwargs)
@ -133,9 +133,9 @@ def new_token_tensor() -> torch.Tensor:
def with_effects_dense(
token: torch.Tensor,
op: torch._ops.OpOverload,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
) -> tuple[torch.Tensor, ...]:
out = op(*args, **kwargs)
new_token = new_token_tensor()
if isinstance(out, tuple):
@ -148,9 +148,9 @@ def with_effects_fake(
mode,
token: torch.Tensor,
op: torch._ops.OpOverload,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
) -> tuple[torch.Tensor, ...]:
with mode:
result = with_effects_dense(token, op, *args, **kwargs)
return result
@ -161,9 +161,9 @@ def with_effects_proxy(
mode,
token: torch.Tensor,
op: torch._ops.OpOverload,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
) -> tuple[torch.Tensor, ...]:
with disable_proxy_modes_tracing():
out = with_effects(token, op, *args, **kwargs)
@ -202,10 +202,10 @@ def _get_schema(op, args) -> torch.FunctionSchema:
def handle_effects(
allow_token_discovery: bool,
tokens: Dict[_EffectType, torch.Tensor],
tokens: dict[_EffectType, torch.Tensor],
op: OpType,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
"""
Args:

View File

@ -1,5 +1,6 @@
import math
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -48,7 +49,7 @@ def _construct_strides(
return strides
def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch.Tensor:
"""
Create a new tensor with the same data and shape as the input,
but with strides permuted based on the input tensor's stride order.
@ -81,12 +82,12 @@ class FlexAttentionHOP(HigherOrderOperator):
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
return super().__call__(
query,
@ -119,13 +120,13 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
return super().__call__(
@ -154,12 +155,12 @@ def _math_attention_inner(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32
@ -197,12 +198,12 @@ def math_attention(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
"""Eager implementation
This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
@ -256,12 +257,12 @@ def sdpa_dense(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
out, lse = math_attention(
query,
key,
@ -283,12 +284,12 @@ def trace_flex_attention(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
"""Traces the flex_attention operator with the given score_mod function and other_buffers.
Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function
@ -353,12 +354,12 @@ def flex_attention_proxy_torch_dispatch_mode(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention(
mode,
@ -381,12 +382,12 @@ def flex_attention_functionalize(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
"""Defines the functionalization rules for the flex_attention operator.
Right now we are unwrapping each tensor and then redispatching to the next, however we want to
@ -443,7 +444,7 @@ def flex_attention_functionalize(
def flex_attention_fake_impl(
query: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Figure out a better way to handle this for NJT than using sum()
if query.is_nested:
out = torch.empty_like(query, memory_format=torch.contiguous_format)
@ -466,12 +467,12 @@ def flex_attention_fake_tensor_mode(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
with mode:
out, logsumexp = flex_attention_fake_impl(query, value)
return out, logsumexp
@ -480,9 +481,9 @@ def flex_attention_fake_tensor_mode(
# ---------------------------- Autograd Implementation ----------------------------
def create_fw_bw_graph(
score_mod: Callable,
index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
other_buffers: Tuple[Tensor, ...],
) -> Tuple[Callable, Callable]:
index_values: tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
other_buffers: tuple[Tensor, ...],
) -> tuple[Callable, Callable]:
# See Note:[HOP create fw_bw graph]
# All of these imports need to be here in order to avoid circular dependencies
@ -554,11 +555,11 @@ def create_fw_bw_graph(
m: Tensor,
n: Tensor,
example_grad: Tensor,
*other_buffers: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]:
*other_buffers: tuple[Tensor, ...],
) -> tuple[Tensor, ...]:
def fw_with_masks(
*args: Tuple[Tensor, ...]
) -> Tuple[Tuple[Tensor], Tuple[bool]]:
*args: tuple[Tensor, ...]
) -> tuple[tuple[Tensor], tuple[bool]]:
fw_out = score_mod(*args)
out_requires_grad = fw_out.requires_grad
return ((fw_out,), (out_requires_grad,))
@ -585,12 +586,12 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
value: Tensor,
fw_graph: Callable,
joint_graph: Callable,
block_mask: Tuple[Any, ...],
block_mask: tuple[Any, ...],
scale: float,
kernel_options: Dict[str, Any],
mask_mod_other_buffers: Tuple[Any, ...],
*score_mod_other_buffers: Tuple[Any, ...],
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
mask_mod_other_buffers: tuple[Any, ...],
*score_mod_other_buffers: tuple[Any, ...],
) -> tuple[torch.Tensor, torch.Tensor]:
any_buffer_requires_grad = any(
buffer.requires_grad
for buffer in mask_mod_other_buffers
@ -634,7 +635,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
return out, logsumexp
@staticmethod
def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]: # type: ignore[override]
def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override]
fw_args = saved_tensors_and_symints(ctx)
(
query,
@ -714,12 +715,12 @@ def flex_attention_autograd(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple[Tensor, ...] = (),
mask_mod_other_buffers: Tuple[Tensor, ...] = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple[Tensor, ...] = (),
mask_mod_other_buffers: tuple[Tensor, ...] = (),
) -> tuple[torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
with TransformGetItemToIndex():
@ -769,13 +770,13 @@ def sdpa_dense_backward(
grad_logsumexp: torch.Tensor,
fw_graph: Callable, # GraphModule type hint?
joint_graph: Callable,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple,
mask_mod_other_buffers: Tuple,
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple,
mask_mod_other_buffers: tuple,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
@ -921,13 +922,13 @@ def trace_flex_attention_backward(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
"""We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
@ -1018,13 +1019,13 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention_backward(
@ -1058,13 +1059,13 @@ def flex_attention_backward_functionalize(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
"""Defines the functionalization rules for the flex_attention operator.
@ -1136,13 +1137,13 @@ def flex_attention_backward_fake_tensor_mode(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
block_mask: Tuple,
block_mask: tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
with mode:
grad_query = torch.empty_like(query)

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable
from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase
@ -18,7 +18,7 @@ _foreach_map = ForeachMap()
def foreach_map(
op: Callable, operands: Any, *unused: Tuple[Any], **kwargs: Dict[str, Any]
op: Callable, operands: Any, *unused: tuple[Any], **kwargs: dict[str, Any]
):
from torch._dynamo.polyfills import foreach_map_fn

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch.utils._pytree as pytree
@ -42,8 +42,8 @@ class InvokeSubgraphHOP(HigherOrderOperator):
subgraph: GraphModule,
identifier: Optional[str],
operands: Union[
List[Union[torch.Tensor, int, torch.SymInt]],
Tuple[Union[torch.Tensor, int, torch.SymInt]],
list[Union[torch.Tensor, int, torch.SymInt]],
tuple[Union[torch.Tensor, int, torch.SymInt]],
],
):
assert identifier is None or isinstance(

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import itertools
from typing import Any, Callable, List, Tuple
from typing import Any, Callable
import torch
import torch._prims_common as utils
@ -45,20 +45,20 @@ def wrap_combine_fn_flat(
return [*carry_flat, *combined_flat]
def _extract_carry_and_out(flat_out: List[Any], num_carry: int):
def _extract_carry_and_out(flat_out: list[Any], num_carry: int):
return flat_out[:num_carry], flat_out[num_carry:]
def scan(
combine_fn: Callable[
[pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree]
[pytree.PyTree, pytree.PyTree], tuple[pytree.PyTree, pytree.PyTree]
],
init: pytree.PyTree,
xs: pytree.PyTree,
*,
dim: int = 0,
reverse: bool = False,
) -> Tuple[pytree.PyTree, pytree.PyTree]:
) -> tuple[pytree.PyTree, pytree.PyTree]:
r"""
Performs an inclusive scan with a combine function.
@ -331,11 +331,11 @@ def trace_scan(
proxy_mode,
func_overload,
combine_fn: Callable,
init: List[torch.Tensor],
xs: List[torch.Tensor],
init: list[torch.Tensor],
xs: list[torch.Tensor],
dim: int,
reverse: bool,
additional_inputs: List[torch.Tensor],
additional_inputs: list[torch.Tensor],
):
from torch._dynamo.utils import clone_input

View File

@ -5,17 +5,8 @@ import inspect
import logging
import threading
from collections import defaultdict
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import Never
import sympy
@ -48,9 +39,9 @@ if TYPE_CHECKING:
from torch.fx.proxy import Proxy
from torch.utils._triton import has_triton
TritonMetaParamsType = Dict[str, int]
TritonGridTupleType = Tuple[Union[int, sympy.Expr, SymInt], ...]
TritonGridCallableType = Callable[[TritonMetaParamsType], Tuple[int, ...]]
TritonMetaParamsType = dict[str, int]
TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]
if has_triton():
@ -76,11 +67,11 @@ log = logging.getLogger("torch._dynamo")
# consisting of a list of dims, a list of block dims, and an element size. E.g., for this
# call in host-side Triton TMA API ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``,
# the metadata will look like ``([50, 60], [32, 15], 4)``. All ints can be SymInts.
TMADescriptorMetadata = Dict[
TMADescriptorMetadata = dict[
str, # kernel parameter name
Tuple[
List[Union[int, SymInt]], # dims
List[Union[int, SymInt]], # block_dims
tuple[
list[Union[int, SymInt]], # dims
list[Union[int, SymInt]], # block_dims
Union[int, SymInt], # element_size
],
]
@ -95,9 +86,9 @@ TMADescriptorMetadata = Dict[
# Use a side table.
# We use two dicts so that fetching both the kernel and id are O(1)
class KernelSideTable:
id_to_kernel: Dict[int, "TritonKernelType"] = {}
kernel_to_id: Dict["TritonKernelType", int] = {}
constant_args: Dict[int, Dict[str, Any]] = {}
id_to_kernel: dict[int, "TritonKernelType"] = {}
kernel_to_id: dict["TritonKernelType", int] = {}
constant_args: dict[int, dict[str, Any]] = {}
lock = threading.Lock()
# Returns index on the table
@ -119,14 +110,14 @@ class KernelSideTable:
# Not every constant arg can be added to the graph. Use this side table
# for constant args.
def add_constant_args(self, args: Dict[str, Any]) -> int:
def add_constant_args(self, args: dict[str, Any]) -> int:
with self.lock:
idx = len(self.constant_args)
self.constant_args[idx] = args
return idx
# Returns the constant args
def get_constant_args(self, idx: int) -> Dict[str, Any]:
def get_constant_args(self, idx: int) -> dict[str, Any]:
# No need to lock here as fetching from dict is atomic
assert idx in self.constant_args
return self.constant_args[idx]
@ -163,7 +154,7 @@ class Intermediate:
class Op:
name: str
fn_call_name: Optional[str]
args: List[Union[Param, Intermediate]]
args: list[Union[Param, Intermediate]]
ret: Intermediate = dataclasses.field(repr=False)
def __post_init__(self) -> None:
@ -174,8 +165,8 @@ class Op:
def generate_ttir(
kernel: "TritonKernelType", kwargs: Dict[str, Any]
) -> Tuple["TritonIRModule", List[str]]:
kernel: "TritonKernelType", kwargs: dict[str, Any]
) -> tuple["TritonIRModule", list[str]]:
"""
Uses Triton's internal code generation to create TTIR
"""
@ -218,7 +209,7 @@ def generate_ttir(
# Replace all SymExprs with a regular value for TTIR generation
# Replace all FakeTensor/TensorBox with real tensors
# These replacements are needed for triton's type, key and config functions
ordered_args: Dict[str, Any] = {}
ordered_args: dict[str, Any] = {}
for name in kernel.arg_names:
a = kwargs[name]
if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
@ -280,22 +271,22 @@ def generate_ttir(
def ttir_to_functions(
ttir_module: "TritonIRModule",
) -> Dict[str, Dict[Intermediate, List[Op]]]:
) -> dict[str, dict[Intermediate, list[Op]]]:
"""
Walk the `ttir_module` bottom up to mine the `functions` from
the structured MLIR entities representing the Triton kernel
(mlir::Operation, mlir::Block, mlir::Region).
"""
functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
functions: dict[str, dict[Intermediate, list[Op]]] = {}
# block id --> op result (Intermediate) --> one or more ops
op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict(
op_stack: dict[int, dict[Intermediate, list[Op]]] = defaultdict(
lambda: defaultdict(list)
)
region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list)
block_id_to_block_arg_ids: Dict[int, List[int]] = {}
replacements: Dict[int, Union[Intermediate, Param]] = {}
reindex_map: Dict[int, int] = {}
region_id_to_block_ids: dict[int, list[int]] = defaultdict(list)
block_id_to_block_arg_ids: dict[int, list[int]] = {}
replacements: dict[int, Union[Intermediate, Param]] = {}
reindex_map: dict[int, int] = {}
next_fake_intermediate = 0
def reindex(idx: int) -> int:
@ -309,14 +300,14 @@ def ttir_to_functions(
# this wraps all tt.func ops
return
operand_ids: List[int] = [
operand_ids: list[int] = [
reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
]
result_ids: List[int] = [
result_ids: list[int] = [
reindex(op.get_result(i).id()) for i in range(op.get_num_results())
]
child_block_ids: List[int] = []
child_block_ids: list[int] = []
for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
# as the walk is bottom-up, the region_id_to_block_ids[i]
# must be populated by the time we process the enclosing op
@ -460,7 +451,7 @@ def ttir_to_functions(
callee = None
if name == "tt.call":
callee = op.get_flat_symbol_ref_attr("callee")
args: List[Union[Param, Intermediate]] = [
args: list[Union[Param, Intermediate]] = [
Intermediate(operand) for operand in operand_ids
]
block_ops = op_stack[parent_block_id]
@ -480,7 +471,7 @@ def ttir_to_functions(
class MemoizeWithCycleCheck:
fn: Callable[..., Any]
cache: Dict[Tuple[str, int], Any]
cache: dict[tuple[str, int], Any]
def __init__(self, fn: Callable[..., Any]) -> None:
self.fn = fn
@ -488,10 +479,10 @@ class MemoizeWithCycleCheck:
def __call__(
self,
functions: Dict[str, Dict[Intermediate, List[Op]]],
functions: dict[str, dict[Intermediate, list[Op]]],
fn_name: str,
num_args: int,
) -> List[bool]:
) -> list[bool]:
key = (fn_name, num_args)
if key not in self.cache:
self.cache[key] = None
@ -506,8 +497,8 @@ class MemoizeWithCycleCheck:
@MemoizeWithCycleCheck
def analyze_kernel_mutations(
functions: Dict[str, Dict[Intermediate, List[Op]]], fn_name: str, num_args: int
) -> List[bool]:
functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, num_args: int
) -> list[bool]:
"""
Analyzes the graph to detect all sinks from a predefined list of sinks
by using triton's MemWrite trait list. NOTE: What if triton exposed this?
@ -527,7 +518,7 @@ def analyze_kernel_mutations(
# Ops that we want to bail out on
UNKNOWN_OPS = {"tt.elementwise_inline_asm"}
stack: List[Union[Param, Intermediate]] = []
stack: list[Union[Param, Intermediate]] = []
visited = set()
ops = functions[fn_name]
for op_list in ops.values():
@ -569,8 +560,8 @@ def analyze_kernel_mutations(
def identify_mutated_tensors(
kernel: "TritonKernelType", kwargs: Dict[str, Any]
) -> List[str]:
kernel: "TritonKernelType", kwargs: dict[str, Any]
) -> list[str]:
"""
Given a triton kernel and the arguments for this kernel, this function
1) Retrieves the TTIR converted version of the kernel from Triton's API.
@ -630,9 +621,9 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
self,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> Any:
return super().__call__(
kernel_idx=kernel_idx,
@ -655,11 +646,11 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
self,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
tensors_to_clone: List[str],
) -> Dict[str, Any]:
kwargs: dict[str, Any],
tensors_to_clone: list[str],
) -> dict[str, Any]:
return super().__call__(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
@ -678,9 +669,9 @@ def triton_kernel_wrapper_mutation_dense(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code
@ -693,7 +684,7 @@ def triton_kernel_wrapper_mutation_dense(
fn_name, code = user_defined_kernel_grid_fn_code(
kernel.fn.__name__, kernel.configs, grid
)
namespace: Dict[str, Any] = {}
namespace: dict[str, Any] = {}
exec(code, namespace)
grid_fn = namespace[fn_name]
@ -746,9 +737,9 @@ def triton_kernel_wrapper_mutation_fake_tensor_mode(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
with mode:
return None
@ -759,9 +750,9 @@ def _(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
return None
@ -769,8 +760,8 @@ def _(
def trace_triton_kernel_wrapper(
proxy_mode: ProxyTorchDispatchMode,
func_overload: Callable[..., Any],
node_args: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
node_args: dict[str, Any],
) -> Optional[dict[str, Any]]:
with disable_proxy_modes_tracing():
out = func_overload(**node_args)
@ -795,9 +786,9 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
trace_triton_kernel_wrapper(
mode,
@ -815,8 +806,8 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
def get_mutated_tensors(
kernel_idx: int, constant_args_idx: int, kwargs: Dict[str, Any]
) -> List[str]:
kernel_idx: int, constant_args_idx: int, kwargs: dict[str, Any]
) -> list[str]:
kernel = kernel_side_table.get_kernel(kernel_idx)
constant_args = kernel_side_table.get_constant_args(constant_args_idx)
return identify_mutated_tensors(kernel, {**kwargs, **constant_args})
@ -827,9 +818,9 @@ def triton_kernel_wrapper_mutation_functionalize(
ctx: "BaseFunctionalizeAPI",
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type]
# TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
@ -869,11 +860,11 @@ def triton_kernel_wrapper_functional_dense(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
tensors_to_clone: List[str],
) -> Dict[str, Any]:
kwargs: dict[str, Any],
tensors_to_clone: list[str],
) -> dict[str, Any]:
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
# (inductor should always optimize them away).
@ -898,11 +889,11 @@ def triton_kernel_wrapper_functional_fake_tensor_mode(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
tensors_to_clone: List[str],
) -> Dict[str, Any]:
kwargs: dict[str, Any],
tensors_to_clone: list[str],
) -> dict[str, Any]:
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
# (inductor should always optimize them away).
@ -921,11 +912,11 @@ def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
*,
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
tensors_to_clone: List[str],
) -> Dict[str, Any]:
kwargs: dict[str, Any],
tensors_to_clone: list[str],
) -> dict[str, Any]:
ret = trace_triton_kernel_wrapper(
mode,
triton_kernel_wrapper_functional,
@ -947,11 +938,11 @@ def triton_kernel_wrapper_functional_functionalize(
ctx: "BaseFunctionalizeAPI",
kernel_idx: int,
constant_args_idx: int,
grid: List["TritonGridType"],
grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs: Dict[str, Any],
tensors_to_clone: List[str],
) -> Dict[str, Any]:
kwargs: dict[str, Any],
tensors_to_clone: list[str],
) -> dict[str, Any]:
unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type]
with ctx.redispatch_to_next():
outputs = triton_kernel_wrapper_functional(
@ -1024,7 +1015,7 @@ class TritonHOPifier:
grid,
meta,
tx,
) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]:
) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
raise NotImplementedError("abstract method")
def wrap_user_defined_obj(
@ -1041,8 +1032,8 @@ class TritonHOPifier:
def call_user_defined_fn(
self,
user_fn: Callable[..., Any],
args: List,
kwargs: Dict,
args: list,
kwargs: dict,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
@ -1051,8 +1042,8 @@ class TritonHOPifier:
raise NotImplementedError("abstract method")
def maybe_unpack_configs(
self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
) -> List["TritonConfig"]:
self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
) -> list["TritonConfig"]:
raise NotImplementedError("abstract method")
def maybe_unpack_heuristic_result(self, result: Any) -> Any:
@ -1064,10 +1055,10 @@ class TritonHOPifier:
early_config_prune: Optional[Callable],
perf_model: Optional[Callable],
top_k: float,
configs: List,
named_args: Dict,
kwargs: Dict,
) -> List["TritonConfig"]:
configs: list,
named_args: dict,
kwargs: dict,
) -> list["TritonConfig"]:
# Reimplement autotuner.prune_configs(...) here
# see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
# We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model
@ -1107,14 +1098,14 @@ class TritonHOPifier:
self,
variable,
grids,
combined_args: Dict[str, Any],
combined_args: dict[str, Any],
tx,
) -> Optional["ConstantVariable"]:
raise NotImplementedError("abstract method")
def check_grid( # type: ignore[no-untyped-def]
self, grid
) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]:
) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
raise NotImplementedError("abstract method")
def init_variable(
@ -1215,7 +1206,7 @@ class TritonHOPifier:
self,
variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
args: Sequence[Any],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
tx: Optional["InstructionTranslator"],
) -> Optional["ConstantVariable"]:
if "grid" not in kwargs:
@ -1236,7 +1227,7 @@ class TritonHOPifier:
self,
variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
args: Sequence[Any],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
tx: Optional["InstructionTranslator"],
) -> Optional["ConstantVariable"]:
from triton import JITFunction
@ -1546,7 +1537,7 @@ class TracingTritonHOPifier(TritonHOPifier):
grid: "TritonGridCallableType",
meta: "TritonMetaParamsType",
tx: None,
) -> Tuple[Union[int, sympy.Expr, SymInt], ...]:
) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
assert tx is None
assert isinstance(meta, dict)
assert callable(grid)
@ -1567,8 +1558,8 @@ class TracingTritonHOPifier(TritonHOPifier):
def call_user_defined_fn(
self,
user_fn: Callable[..., Any],
args: List,
kwargs: Dict,
args: list,
kwargs: dict,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
@ -1580,8 +1571,8 @@ class TracingTritonHOPifier(TritonHOPifier):
return user_fn(*args, **kwargs)
def maybe_unpack_configs(
self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
) -> List["TritonConfig"]:
self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
) -> list["TritonConfig"]:
assert isinstance(configs, list)
return configs
@ -1591,7 +1582,7 @@ class TracingTritonHOPifier(TritonHOPifier):
def check_grid(
self,
grid: "TritonGridType",
) -> Tuple[Union[int, sympy.Expr, SymInt], ...]:
) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
if not isinstance(grid, collections.abc.Sequence):
raise RuntimeError(
"wrap_triton can only handle grids that resolve to Sequence[int]."
@ -1602,8 +1593,8 @@ class TracingTritonHOPifier(TritonHOPifier):
def call_HOP(
self,
variable: "TraceableTritonKernelWrapper",
grids: List["TritonGridTupleType"],
combined_args: Dict[str, Any],
grids: list["TritonGridTupleType"],
combined_args: dict[str, Any],
tx: None,
) -> None:
assert tx is None
@ -1652,7 +1643,7 @@ class TraceableTritonKernelWrapper:
def __getitem__(self, *args: Sequence[Any]) -> "TraceableTritonKernelWrapper":
return tracing_triton_hopifier_singleton.call_getitem(self, args) # type: ignore[return-value]
def run(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any:
def run(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
from torch._library.triton import is_wrap_triton_enabled
if is_wrap_triton_enabled():
@ -1661,7 +1652,7 @@ class TraceableTritonKernelWrapper:
assert self.kernel is not None
return self.kernel.run(*args, **kwargs)
def __call__(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any:
def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
from torch._library.triton import is_wrap_triton_enabled
if is_wrap_triton_enabled():

View File

@ -2,7 +2,7 @@
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, Union
import torch
import torch.fx.traceback as fx_traceback
@ -462,7 +462,7 @@ def save_tensors_and_symints_for_backward(ctx, args):
assert all(
isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
), args
partitioned_args: List[Any] = [[], []]
partitioned_args: list[Any] = [[], []]
pos = []
for i, arg in enumerate(args):
idx = 0 if isinstance(arg, torch.Tensor) else 1
@ -514,7 +514,7 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
# Reports the difference between meta of two tensors in a string
def diff_tensor_meta(
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
) -> List[str]:
) -> list[str]:
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
pair_diffs = []
@ -541,7 +541,7 @@ def diff_tensor_meta(
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
# hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[Tuple[Any, ...], List[Any]]):
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Callable, List, Tuple, Union
from typing import Callable, Union
import torch
import torch.utils._pytree as pytree
@ -33,8 +33,8 @@ class WhileLoopOp(HigherOrderOperator):
self,
cond_fn: Callable,
body_fn: Callable,
carried_inputs: Tuple[Union[torch.Tensor, int, float, bool]],
additional_inputs: Tuple[Union[torch.Tensor, torch.SymInt, int], ...],
carried_inputs: tuple[Union[torch.Tensor, int, float, bool]],
additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...],
/,
):
if not isinstance(carried_inputs, tuple):
@ -126,7 +126,7 @@ def while_loop(cond_fn, body_fn, carried_inputs):
# Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
# parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
additional_inputs: Tuple = ()
additional_inputs: tuple = ()
# The reason we flatten the output before calling into dynamo is that
# we want to create a consistent input ordering for cond_fn and body_fn.
@ -330,15 +330,15 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
def check_meta_consistency(
lhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
rhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
lhs_name: str,
rhs_name: str,
) -> None:
def diff_meta_pairs(
lhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
rhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
) -> List[str]:
lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
) -> list[str]:
def diff_meta(
lhs: Union[torch.Tensor, torch.SymInt, int],
rhs: Union[torch.Tensor, torch.SymInt, int],

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
from torch import SymInt
@ -80,7 +80,7 @@ class _DeconstructedSymType:
Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
"""
ty: Type[PySymType]
ty: type[PySymType]
node: _DeconstructedSymNode
@staticmethod
@ -216,7 +216,7 @@ class _CacheKeyState:
# We track the SymNodes so when we get the output we can see if it exactly
# matches one of the inputs so we can uncache it properly.
sym_node_lookup: Dict[int, int] # id(SymNode) -> index
sym_node_lookup: dict[int, int] # id(SymNode) -> index
# There are cases where we're asked to perform an op when we have no
# ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
@ -241,7 +241,7 @@ class _CacheKeyState:
"""
return bool(self.sym_node_lookup)
def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
def convert_sym_int(self, result: list[object], arg: SymInt) -> None:
node_id = id(arg.node)
if node_id in self.sym_node_lookup:
result.append(_InputBackref(self.sym_node_lookup[node_id]))

View File

@ -13,25 +13,7 @@ import typing
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self, TypeGuard
from weakref import ReferenceType
@ -69,6 +51,7 @@ from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputSt
if TYPE_CHECKING:
from collections.abc import Generator, Iterable, Mapping, Sequence
from types import TracebackType
from torch._guards import Source
@ -94,7 +77,7 @@ class _Unassigned:
_UNASSIGNED = _Unassigned()
DimList = List
DimList = list
pytree = torch.utils._pytree
T = TypeVar("T")
@ -145,7 +128,7 @@ class MetadataMismatchError(RuntimeError):
reason: str
def ordered_set(*items: T) -> Dict[T, Literal[True]]:
def ordered_set(*items: T) -> dict[T, Literal[True]]:
return dict.fromkeys(items, True)
@ -160,8 +143,8 @@ def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, Non
def get_plain_tensors(
subclass: Tensor, *, out: List[Union[Tensor, int, SymInt]]
) -> List[Union[Tensor, int, SymInt]]:
subclass: Tensor, *, out: list[Union[Tensor, int, SymInt]]
) -> list[Union[Tensor, int, SymInt]]:
# This function is used in Runtime, do not add redundant asserts
todo = [subclass]
while todo:
@ -248,7 +231,7 @@ def torch_decomp_decompositions(func: OpOverload) -> bool:
) and decomposition_table[func].__name__ in dir(decompositions)
def tree_flatten_only(ty: Type[T], tree: PyTree) -> List[T]:
def tree_flatten_only(ty: type[T], tree: PyTree) -> list[T]:
flat_vals = pytree.tree_leaves(tree)
return [elem for elem in flat_vals if isinstance(elem, ty)]
@ -281,7 +264,7 @@ class FakeTensorConverter:
return self.meta_converter.tensor_memo
meta_converter: MetaConverter
constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
constant_storage_mapping: dict[StorageWeakRef, list[ReferenceType]]
export: bool
def __init__(self, *, copy_data: bool = False, export: bool = False) -> None:
@ -581,7 +564,7 @@ class SymNumberMemoDescriptor:
return f"_{self._name}_epoch"
def __get__(
self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
self, obj: FakeTensor, objtype: Optional[type[FakeTensor]] = None
) -> Optional[Union[torch.SymInt, torch.SymFloat]]:
if (r := getattr(obj, self._memo(obj))) is None:
return None
@ -674,13 +657,13 @@ class FakeTensor(Tensor):
# We don't support named tensors; graph break
@property
def names(self) -> List[str]:
def names(self) -> list[str]:
raise UnsupportedFakeTensorException(
"torch.compile doesn't support named tensors"
)
@names.setter
def names(self, _: List[str]) -> None:
def names(self, _: list[str]) -> None:
raise NotImplementedError
@staticmethod
@ -774,7 +757,7 @@ class FakeTensor(Tensor):
def __torch_dispatch__( # type: ignore[override] # TODO
cls,
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
@ -847,7 +830,7 @@ class FakeTensor(Tensor):
@staticmethod
def _find_common_device(
func: OpOverload, flat_args: Sequence[object]
) -> Tuple[torch.device, bool]:
) -> tuple[torch.device, bool]:
# Returns: (common_device, has_scalar_only_inputs)
# cpu - zero-dim tensors can be called in cuda kernels,
@ -942,8 +925,8 @@ class TensorMetadata:
"""
dtype: torch.dtype
shape: Tuple[_MetadataIntLike, ...]
stride: Tuple[_MetadataIntLike, ...]
shape: tuple[_MetadataIntLike, ...]
stride: tuple[_MetadataIntLike, ...]
device: torch.device
layout: torch.layout
memory_format: Optional[torch.memory_format]
@ -961,7 +944,7 @@ class TensorMetadata:
def _flatten_into(
self,
result: List[object],
result: list[object],
mode: FakeTensorMode,
state: _CacheKeyState,
) -> None:
@ -1024,10 +1007,10 @@ class _DispatchCacheKey:
Key for the FakeTensor dispatch cache.
"""
key: Tuple[object, ...]
key: tuple[object, ...]
hashvalue: int
def __init__(self, tup: Tuple[object, ...]) -> None:
def __init__(self, tup: tuple[object, ...]) -> None:
self.key = tup
self.hashvalue = hash(tup)
@ -1073,7 +1056,7 @@ class _DispatchCacheEntry:
is_output_tuple flag helps in differentiating the return type
"""
output_infos: Tuple[_DispatchCacheEntryOutputInfo]
output_infos: tuple[_DispatchCacheEntryOutputInfo]
is_output_tuple: bool = False
@ -1096,7 +1079,7 @@ class DispatchCacheInfo:
hits: int
misses: int
bypasses: Dict[str, int]
bypasses: dict[str, int]
size: int
@ -1110,10 +1093,10 @@ class DispatchCacheInfo:
class FakeTensorMode(TorchDispatchMode):
cache: Dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
cache: dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
cache_hits: int = 0
cache_misses: int = 0
cache_bypasses: Dict[str, int] = defaultdict(int)
cache_bypasses: dict[str, int] = defaultdict(int)
# Every time you retrace using the same fake tensor mode, you should
# advance the epoch so we don't reuse unbacked memos
epoch: int = 0
@ -1208,8 +1191,8 @@ class FakeTensorMode(TorchDispatchMode):
# in_kernel_invocation
# If another fake mode was already active when we enter, we also stash it here.
# That way when we exit, we know to re-enable the previous fake mode.
self.enter_stack: List[
Tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
self.enter_stack: list[
tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
] = []
self.shape_env = shape_env
@ -1272,7 +1255,7 @@ class FakeTensorMode(TorchDispatchMode):
def __torch_dispatch__(
self,
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
@ -1309,7 +1292,7 @@ class FakeTensorMode(TorchDispatchMode):
def __exit__(
self,
a: Optional[Type[BaseException]],
a: Optional[type[BaseException]],
b: Optional[BaseException],
c: Optional[TracebackType],
) -> None:
@ -1356,7 +1339,7 @@ class FakeTensorMode(TorchDispatchMode):
def _cached_dispatch_impl(
self,
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object],
kwargs: Mapping[str, object],
) -> object:
@ -1474,7 +1457,7 @@ class FakeTensorMode(TorchDispatchMode):
def _prep_args_for_hash(
self,
result: List[object],
result: list[object],
args: Union[Mapping[str, object], Sequence[object], Iterable[object]],
state: _CacheKeyState,
) -> None:
@ -1734,7 +1717,7 @@ class FakeTensorMode(TorchDispatchMode):
key: _DispatchCacheKey,
func: OpOverload,
args: Sequence[object],
) -> Union[Optional[FakeTensor], Tuple[Optional[FakeTensor], ...]]:
) -> Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]]:
"""
Create a new FakeTensor from the cache entry.
"""
@ -1758,9 +1741,9 @@ class FakeTensorMode(TorchDispatchMode):
def _crosscheck_cache_output(
self,
output: Union[Optional[FakeTensor], Tuple[Optional[FakeTensor], ...]],
output: Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]],
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object],
kwargs: Mapping[str, object],
) -> None:
@ -1796,7 +1779,7 @@ class FakeTensorMode(TorchDispatchMode):
def dispatch(
self,
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
@ -1982,7 +1965,7 @@ class FakeTensorMode(TorchDispatchMode):
def _dispatch_impl(
self,
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object],
kwargs: Mapping[str, object],
) -> Optional[FakeTensor]:
@ -2436,13 +2419,13 @@ class FakeTensorMode(TorchDispatchMode):
converter: FakeTensorConverter,
flat_args: Sequence[object],
args_spec: TreeSpec,
) -> Tuple[List[object], List[FakeTensor]]:
) -> tuple[list[object], list[FakeTensor]]:
"""
Checks whether the tensors in the list are fake tensors.
If not, tries to convert them to fake tensors.
Returns the original args, kwargs, and a flattened list of the (args, kwargs) tensors that are fake tensors.
"""
flat_arg_fake_tensors: List[FakeTensor] = []
flat_arg_fake_tensors: list[FakeTensor] = []
def validate(x: T) -> Union[T, FakeTensor]:
if not isinstance(x, Tensor):
@ -2663,7 +2646,7 @@ def run_fallback_kernel(
r = func(*args, **kwargs)
storages: Set[_StoragePointer] = set()
storages: set[_StoragePointer] = set()
for e in flat_args:
if isinstance(e, Tensor):
@ -2703,7 +2686,7 @@ class FakeCopyMode(TorchFunctionMode):
def __torch_function__(
self,
func: OpOverload,
types: Sequence[Type],
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Optional[Mapping[str, object]] = None,
) -> FakeTensor:
@ -2718,7 +2701,7 @@ class FakeCopyMode(TorchFunctionMode):
elif func == Tensor.__deepcopy__:
assert len(args) == 2 and len(kwargs) == 0
tensor = cast(Tensor, args[0])
memo = cast(Dict[int, FakeTensor], args[1])
memo = cast(dict[int, FakeTensor], args[1])
if id(tensor) in memo:
return memo[id(tensor)]

View File

@ -2,7 +2,7 @@
import functools
import warnings
from typing import Any, Callable, List, Union
from typing import Any, Callable, Union
import torch
import torch.utils._pytree as pytree
@ -102,8 +102,8 @@ def is_sdpa_error(func, idx, e):
def try_convert_fake_to_real(
ten_list: List[Union[FakeTensor, Any]]
) -> List[Union[FakeTensor, torch.Tensor, Any]]:
ten_list: list[Union[FakeTensor, Any]]
) -> list[Union[FakeTensor, torch.Tensor, Any]]:
"""
Attempt to convert fake tensors to corresponding real tensors with the correct underlying storage by looking up
the FakeTensorMode meta-to-real storage mapping. On failure to find the storage mapping, the FakeTensor will

View File

@ -3,7 +3,7 @@ import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, ContextManager, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -308,10 +308,10 @@ class FunctionalTensorMode(TorchDispatchMode):
self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None # type: ignore[attr-defined]
# Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
# track of the ordering between side effectful operations.
self._tokens: Dict[Any, torch.Tensor] = {}
self._tokens: dict[Any, torch.Tensor] = {}
# Filled after forward tracing.
self._tokens_forward_output: Dict[Any, torch.Tensor] = {}
self._tokens_forward_output: dict[Any, torch.Tensor] = {}
# Functionalization runs twice in AOTAutograd, once in
# `run_functionalized_fw_and_collect_metadata` to collect metadata to
@ -648,12 +648,12 @@ def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMo
class BaseFunctionalizeAPI(ABC):
@abstractmethod
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
pass
@abstractmethod
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]]
) -> Any:
pass
@ -690,14 +690,14 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
self.mode = mode if mode else FunctionalTensorMode()
self.pre_dispatch = pre_dispatch
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
with self.mode:
return torch.utils._pytree.tree_map_only(
torch.Tensor, FunctionalTensor.to_functional, args
)
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
self, args: Union[torch.Tensor, tuple[torch.Tensor, ...], list[torch.Tensor]]
) -> Any:
return torch.utils._pytree.tree_map_only(
FunctionalTensor, FunctionalTensor.from_functional, args
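
wrap_tensors/unwrap_tensors above lean on pytree's tree_map_only to touch only the tensor leaves of a nested structure. A small usage sketch of that helper (the example data is made up):

import torch
import torch.utils._pytree as pytree

# Map a function over only the torch.Tensor leaves; other leaves (the int 4
# below) pass through unchanged.
args = (torch.ones(2), [torch.zeros(3), 4], {"k": torch.arange(2)})
doubled = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, args)
print(doubled[1][1])  # 4
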
@ -733,14 +733,14 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
class CppFunctionalizeAPI(BaseFunctionalizeAPI):
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional
return _wrap_all_tensors_to_functional(args, level=0)
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]]
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
)
@ -772,14 +772,14 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
def __init__(self, interpreter):
self.interpreter = interpreter
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional
return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]]
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
)

View File

@ -13,15 +13,10 @@ from typing import (
Callable,
ClassVar,
ContextManager,
Dict,
Generic,
List,
NewType,
Optional,
Protocol,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
@ -66,7 +61,7 @@ def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]:
return isinstance(t, FakeTensor)
DimList = List
DimList = list
_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor)
_T = TypeVar("_T")
_TensorT = TypeVar("_TensorT", bound=torch.Tensor)
@ -178,7 +173,7 @@ def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
return is_sparse_coo(t) or is_sparse_compressed(t)
def _checked_cast(ty: Type[_T], obj: object) -> _T:
def _checked_cast(ty: type[_T], obj: object) -> _T:
assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
return obj
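
_checked_cast above combines a PEP 585 type[_T] annotation with an isinstance assertion so callers get a narrowed static type. A self-contained sketch of the same idiom (standalone names, not the torch._subclasses module):

from typing import TypeVar

_T = TypeVar("_T")


def checked_cast(ty: type[_T], obj: object) -> _T:
    # type[_T] is the builtin-generic spelling of typing.Type[_T].
    assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
    return obj


n = checked_cast(int, 5)  # type checkers infer n: int
print(n + 1)
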
@ -224,8 +219,8 @@ class MetaTensorDescriber:
# Storage -> int
self.lookup_storage = WeakIdKeyDictionary()
self.copy_data = copy_data
self.traced_tensors: Set[int] = set()
self.traced_storages: Set[int] = set()
self.traced_tensors: set[int] = set()
self.traced_storages: set[int] = set()
def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
if t not in self.lookup_tensor:
@ -480,7 +475,7 @@ class MetaStorageDesc:
# serializable in JSON, you want to do something special here anyway
data: Optional[torch.UntypedStorage]
def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
return {
"id": self.id,
"describer_id": describer_id,
@ -579,8 +574,8 @@ class MetaTensorDesc(Generic[_TensorT]):
# throw an error, but we don't currently have any subclasses that do this
# except C++ nested tensor but we're going to have nested int to make this
# defined on NJT
size: Tuple[int, ...]
dynamo_dynamic_indices: List[int]
size: tuple[int, ...]
dynamo_dynamic_indices: list[int]
layout: torch.layout = torch.strided
is_inference: bool = False
@ -603,7 +598,7 @@ class MetaTensorDesc(Generic[_TensorT]):
is_conj: bool = False
is_neg: bool = False
is_parameter: bool = False
stride: Optional[Tuple[int, ...]] = None
stride: Optional[tuple[int, ...]] = None
storage_offset: int = 0
# NB: We have a choice whether or not to store the id or a direct pointer
# to the data structure. For ease of use, we store the data structure,
@ -621,13 +616,13 @@ class MetaTensorDesc(Generic[_TensorT]):
unwrapped: Optional[MetaTensorDesc] = None # is_functorch_wrapped
bdim: Optional[int] = None # is_functorch_wrapped
base: Optional[MetaTensorDesc] = None # is_view
attrs: Optional[Dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass
attrs: Optional[dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass
creation_meta: Optional[CreationMeta] = None
grad: Optional[MetaTensorDesc] = None
# Everything below is NOT serializable, need some more work
_UNSERIALIZABLE: ClassVar[Set[str]] = {
_UNSERIALIZABLE: ClassVar[set[str]] = {
"ctx",
"type",
"fake_mode",
@ -642,14 +637,14 @@ class MetaTensorDesc(Generic[_TensorT]):
}
ctx: Optional[object] = None # is_traceable_wrapper_subclass
type: Optional[Type] = None # is_traceable_wrapper_subclass
type: Optional[type] = None # is_traceable_wrapper_subclass
fake_mode: Optional[FakeTensorMode] = None
view_func: Optional[ViewFunc] = None
# level looks serializable, but actually it is meaningless without
# the functorch_stack below
level: Optional[int] = None # is_functorch_wrapped
current_level: Optional[int] = None
functorch_stack: Optional[List[CInterpreter]] = None
functorch_stack: Optional[list[CInterpreter]] = None
autograd_meta_from: Optional[torch.Tensor] = None
# This is only populated on copy_data, and typically is not used at all,
@ -669,7 +664,7 @@ class MetaTensorDesc(Generic[_TensorT]):
# NB: This will reference numeric IDs, and it is assumed that you've
# already serialized everything this recursively references
def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
def json(k: str, v: object) -> object:
# Some best-effort debugging serialization for unserializable
# fields (feel free to add other special cases as appropriate)
@ -706,7 +701,7 @@ class MetaTensorDesc(Generic[_TensorT]):
return r
@property
def shape(self) -> Tuple[int, ...]:
def shape(self) -> tuple[int, ...]:
return self.size
@ -893,7 +888,7 @@ class MetaConverter(Generic[_TensorT]):
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
) -> tuple[tuple[int, ...], tuple[int, ...], int]:
assert t.stride is not None
if shape_env is not None:
fake_mode = t.fake_mode
@ -948,8 +943,8 @@ class MetaConverter(Generic[_TensorT]):
# symbolic context.
def empty_create_subclass(
t: MetaTensorDesc,
outer_size: Tuple[int, ...],
outer_stride: Tuple[int, ...],
outer_size: tuple[int, ...],
outer_stride: tuple[int, ...],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
@ -981,8 +976,8 @@ class MetaConverter(Generic[_TensorT]):
def _empty_create_subclass(
t: MetaTensorDesc,
outer_size: Optional[Tuple[int, ...]],
outer_stride: Optional[Tuple[int, ...]],
outer_size: Optional[tuple[int, ...]],
outer_stride: Optional[tuple[int, ...]],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
@ -1028,7 +1023,7 @@ class MetaConverter(Generic[_TensorT]):
inner_tensors[attr] = new_empty_tensor
assert t.type is not None
return t.type.__tensor_unflatten__(
return t.type.__tensor_unflatten__( # type: ignore[attr-defined]
inner_tensors, t.ctx, outer_size, outer_stride
)
@ -1081,7 +1076,7 @@ class MetaConverter(Generic[_TensorT]):
t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
if t.is_traceable_wrapper_subclass:
assert t.attrs is not None
inner_contexts: Dict[
inner_contexts: dict[
str, torch.fx.experimental.symbolic_shapes.SymbolicContext
] = {}
for attr, inner in t.attrs.items():

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import hashlib
import json
from typing import Dict, Tuple
import coremltools as ct # type: ignore[import]
from coremltools.converters.mil.input_types import TensorType # type: ignore[import]
@ -83,7 +82,7 @@ def _convert_to_mil_type(shape, dtype, name: str):
return ml_type
def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tuple]):
spec = compile_spec["forward"]
(
input_specs,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
import torch
from torch.backends._nnapi.serializer import _NnapiSerializer
@ -21,16 +21,16 @@ class NnapiModule(torch.nn.Module):
# _nnapi.Compilation is defined
comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
weights: List[torch.Tensor]
out_templates: List[torch.Tensor]
weights: list[torch.Tensor]
out_templates: list[torch.Tensor]
def __init__(
self,
shape_compute_module: torch.nn.Module,
ser_model: torch.Tensor,
weights: List[torch.Tensor],
inp_mem_fmts: List[int],
out_mem_fmts: List[int],
weights: list[torch.Tensor],
inp_mem_fmts: list[int],
out_mem_fmts: list[int],
compilation_preference: int,
relax_f32_to_f16: bool,
):
@ -46,7 +46,7 @@ class NnapiModule(torch.nn.Module):
self.relax_f32_to_f16 = relax_f32_to_f16
@torch.jit.export
def init(self, args: List[torch.Tensor]):
def init(self, args: list[torch.Tensor]):
assert self.comp is None
self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
self.weights = [w.contiguous() for w in self.weights]
@ -60,7 +60,7 @@ class NnapiModule(torch.nn.Module):
self.comp = comp
def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]:
if self.comp is None:
self.init(args)
comp = self.comp

View File

@ -6,7 +6,7 @@ import logging
import operator
import struct
import sys
from typing import List, NamedTuple, Optional, Tuple
from typing import NamedTuple, Optional
import torch
@ -210,7 +210,7 @@ class Operand(NamedTuple):
# This is always the PyTorch shape, which is NCHW for feature maps.
# The actual NNAPI operand might have a transposed shape.
# We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes
shape: Tuple[int, ...]
shape: tuple[int, ...]
# Specifies how the shape of the operand that we define in NNAPI
# relates to the shape we track above.
@ -943,8 +943,8 @@ class _NnapiSerializer:
assert node.outputsSize() == 1
output = node.outputsAt(0)
ctype = output.type()
const_vals: Optional[List] = []
tensors: Optional[List] = []
const_vals: Optional[list] = []
tensors: Optional[list] = []
for inp in node.inputs():
if const_vals is not None and inp in self.constants:
_, val = self.get_constant_value(inp)

View File

@ -39,7 +39,7 @@ class _QEngineProp:
class _SupportedQEnginesProp:
def __get__(self, obj, objtype) -> List[str]:
def __get__(self, obj, objtype) -> list[str]:
qengines = torch._C._supported_qengines()
return [_get_qengine_str(qe) for qe in qengines]
@ -63,4 +63,4 @@ class QuantizedEngine(types.ModuleType):
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
engine: str
supported_engines: List[str]
supported_engines: list[str]

View File

@ -132,7 +132,6 @@ import subprocess
import sys
from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER
from os.path import expanduser
from typing import Dict, List
from torch.distributed.elastic.multiprocessing import (
DefaultLogsSpecs as _DefaultLogsSpecs,
@ -181,8 +180,8 @@ class _CPUinfo:
# physical cores := core column in lscpu output
# logical cores := CPU column in lscpu output
self.node_nums = int(max(line[3] for line in self.cpuinfo)) + 1
self.node_physical_cores: List[List[int]] = [] # node_id is index
self.node_logical_cores: List[List[int]] = [] # node_id is index
self.node_physical_cores: list[list[int]] = [] # node_id is index
self.node_logical_cores: list[list[int]] = [] # node_id is index
self.physical_core_node_map = {} # physical core to numa node id
self.logical_core_node_map = {} # logical core to numa node id
@ -594,7 +593,7 @@ won't take effect even if it is set explicitly."
)
entrypoint = ""
launch_args = {}
launch_envs: Dict[int, Dict] = {}
launch_envs: dict[int, dict] = {}
launch_tee = {}
# check whether is launched from torchrun with --nproc-per-node <num workers>
local_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
@ -623,7 +622,7 @@ won't take effect even if it is set explicitly."
* args.ncores_per_instance
]
core_ranges: List[Dict] = []
core_ranges: list[dict] = []
if local_size > 1:
total_num_cores = len(core_list)
cores_per_rank = total_num_cores // local_size

View File

@ -432,7 +432,7 @@ def is_exporting() -> bool:
return _is_exporting_flag
def save_cache_artifacts() -> Optional[Tuple[bytes, "CacheInfo"]]:
def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
"""
Serializes all the cache artifacts that were created during the compilation

View File

@ -3,7 +3,7 @@ import logging
import os
import pickle
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
from torch._inductor.remote_cache import JsonDataTy, RemoteCacheJsonSerde
from torch._inductor.runtime.runtime_utils import cache_dir
@ -41,10 +41,10 @@ class CacheInfo:
instrumentation
"""
inductor_artifacts: List[str] = dataclasses.field(default_factory=list)
autotune_artifacts: List[str] = dataclasses.field(default_factory=list)
aot_autograd_artifacts: List[str] = dataclasses.field(default_factory=list)
pgo_artifacts: List[str] = dataclasses.field(default_factory=list)
inductor_artifacts: list[str] = dataclasses.field(default_factory=list)
autotune_artifacts: list[str] = dataclasses.field(default_factory=list)
aot_autograd_artifacts: list[str] = dataclasses.field(default_factory=list)
pgo_artifacts: list[str] = dataclasses.field(default_factory=list)
def add(self, artifact: CacheArtifact) -> None:
if artifact.type == CacheArtifactType.INDUCTOR:
@ -77,7 +77,7 @@ class CacheArtifactManager:
"""
# Protected by the compile_lock
_cache_artifacts: List[CacheArtifact] = []
_cache_artifacts: list[CacheArtifact] = []
@classmethod
def clear(cls) -> None:
@ -102,7 +102,7 @@ class CacheArtifactManager:
cls._cache_artifacts.append(CacheArtifact(artifact_type, key, content))
@classmethod
def serialize(cls) -> Optional[Tuple[bytes, CacheInfo]]:
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
"""
Converts the "mega" list into portable format
"""

View File

@ -45,8 +45,8 @@ except ImportError:
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls: List[
Tuple[Callable[[], None], List[str]]
_queued_calls: list[
tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
@ -101,7 +101,7 @@ else:
has_half: bool = True
has_magma: bool = torch._C._has_magma
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment]
def _is_compiled() -> bool:
@ -492,7 +492,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
return get_device_properties(device).name
def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Get the cuda capability of a device.
Args:
@ -642,7 +642,7 @@ def set_stream(stream: Stream):
)
def _parse_visible_devices() -> Union[List[int], List[str]]:
def _parse_visible_devices() -> Union[list[int], list[str]]:
r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
var = os.getenv("CUDA_VISIBLE_DEVICES")
@ -683,12 +683,12 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
idx += 1
return int(s[:idx]) if idx > 0 else -1
def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
rcs: List[str] = []
def parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
rcs: list[str] = []
for elem in lst.split(","):
# Repeated id results in empty set
if elem in rcs:
return cast(List[str], [])
return cast(list[str], [])
# Anything other but prefix is ignored
if not elem.startswith(prefix):
break
@ -701,12 +701,12 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
return parse_list_with_prefix(var, "MIG-")
# CUDA_VISIBLE_DEVICES uses something like strtoul
# which makes `1gpu2,2ampere` equivalent to `1,2`
rc: List[int] = []
rc: list[int] = []
for elem in var.split(","):
x = _strtoul(elem.strip())
# Repeated ordinal results in empty set
if x in rc:
return cast(List[int], [])
return cast(list[int], [])
# Negative value aborts the sequence
if x < 0:
break
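
A rough, hedged sketch of the strtoul-style behavior the comments above describe (simplified helper names; the real torch.cuda parser also handles sign characters and MIG-/GPU- prefixed UUID lists):

def _leading_int(s: str) -> int:
    # Take the longest leading run of digits, like C's strtoul; -1 if none.
    idx = 0
    while idx < len(s) and s[idx].isdigit():
        idx += 1
    return int(s[:idx]) if idx > 0 else -1


def _parse_ordinals(var: str) -> list[int]:
    rc: list[int] = []
    for elem in var.split(","):
        x = _leading_int(elem.strip())
        if x in rc:  # a repeated ordinal empties the whole set
            return []
        if x < 0:  # a negative / non-numeric entry ends the sequence
            break
        rc.append(x)
    return rc


print(_parse_ordinals("1gpu2,2ampere"))  # [1, 2]
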
@ -744,7 +744,7 @@ def _raw_device_count_nvml() -> int:
return dev_count.value
def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
def _raw_device_uuid_amdsmi() -> Optional[list[str]]:
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
if not _HAS_PYNVML: # If amdsmi is not available
@ -760,7 +760,7 @@ def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
except amdsmi.AmdSmiException:
warnings.warn("Can't get amdsmi device count")
return None
uuids: List[str] = []
uuids: list[str] = []
for idx in range(dev_count):
try:
handler = amdsmi.amdsmi_get_processor_handles()[idx]
@ -780,7 +780,7 @@ def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
return uuids
def _raw_device_uuid_nvml() -> Optional[List[str]]:
def _raw_device_uuid_nvml() -> Optional[list[str]]:
r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
@ -794,7 +794,7 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]:
if rc != 0:
warnings.warn("Can't get nvml device count")
return None
uuids: List[str] = []
uuids: list[str] = []
for idx in range(dev_count.value):
dev_id = c_void_p()
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
@ -812,10 +812,10 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]:
return uuids
def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
def _transform_uuid_to_ordinals(candidates: list[str], uuids: list[str]) -> list[int]:
r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""
def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
def uuid_to_ordinal(candidate: str, uuids: list[str]) -> int:
best_match = -1
for idx, uuid in enumerate(uuids):
if not uuid.startswith(candidate):
@ -826,7 +826,7 @@ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List
best_match = idx
return best_match
rc: List[int] = []
rc: list[int] = []
for candidate in candidates:
if torch.version.hip:
candidate = candidate.replace(
@ -838,7 +838,7 @@ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List
break
# Duplicates result in empty set
if idx in rc:
return cast(List[int], [])
return cast(list[int], [])
rc.append(idx)
return rc
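
A hedged reconstruction of the ambiguity check sketched above (the elided middle of uuid_to_ordinal is filled in here as an assumption; the real helper also normalizes ROCm GPU-/MIG- prefixes):

def uuid_to_ordinal(candidate: str, uuids: list[str]) -> int:
    best_match = -1
    for idx, uuid in enumerate(uuids):
        if not uuid.startswith(candidate):
            continue
        if best_match != -1:  # assumed: a second prefix match makes the candidate ambiguous
            return -1
        best_match = idx
    return best_match


known = ["GPU-9e8d35e3", "GPU-9f2a41c7"]
print(uuid_to_ordinal("GPU-9e", known))  # 0
print(uuid_to_ordinal("GPU-9", known))   # -1 (ambiguous)
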
@ -853,7 +853,7 @@ def _device_count_amdsmi() -> int:
if uuids is None:
return -1
# Create string version of visible devices to avoid mypy warnings
visible_device_str = cast(List[str], visible_devices)
visible_device_str = cast(list[str], visible_devices)
visible_devices = _transform_uuid_to_ordinals(visible_device_str, uuids)
else:
raw_cnt = _raw_device_count_amdsmi()
@ -887,7 +887,7 @@ def _device_count_nvml() -> int:
if uuids is None:
return -1
visible_devices = _transform_uuid_to_ordinals(
cast(List[str], visible_devices), uuids
cast(list[str], visible_devices), uuids
)
else:
raw_cnt = _raw_device_count_nvml()
@ -913,9 +913,9 @@ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
if uuids is None:
raise RuntimeError("Can't get device UUIDs")
visible_devices = _transform_uuid_to_ordinals(
cast(List[str], visible_devices), uuids
cast(list[str], visible_devices), uuids
)
visible_devices = cast(List[int], visible_devices)
visible_devices = cast(list[int], visible_devices)
if idx < 0 or idx >= len(visible_devices):
raise RuntimeError(
f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
@ -944,7 +944,7 @@ def device_count() -> int:
return r
def get_arch_list() -> List[str]:
def get_arch_list() -> list[str]:
r"""Return list CUDA architectures this library was compiled for."""
if not is_available():
return []
@ -1145,10 +1145,10 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
if uuids is None:
raise RuntimeError("Can't get device UUIDs")
visible_devices_str = cast(
List[str], visible_devices
list[str], visible_devices
) # Create str variable for mypy
visible_devices = _transform_uuid_to_ordinals(visible_devices_str, uuids)
idx_map = dict(enumerate(cast(List[int], visible_devices)))
idx_map = dict(enumerate(cast(list[int], visible_devices)))
if idx not in idx_map:
raise RuntimeError(
f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})"

View File

@ -20,8 +20,9 @@ import re
import sys
import textwrap
import traceback
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
from typing import Any, Optional, TypeVar
import torch
import torch.cuda._gpu_trace as gpu_trace
@ -74,7 +75,7 @@ class Access:
seq_num: SeqNum
stream: StreamId
operator: str
aliases: List[str]
aliases: list[str]
is_output: bool
stack_trace: traceback.StackSummary
@ -141,7 +142,7 @@ class UnsynchronizedAccessError(SynchronizationError):
class CUDASanitizerErrors(Exception):
"""Wrapper class for errors reported by CUDA Sanitizer."""
def __init__(self, errors: List[SynchronizationError]):
def __init__(self, errors: list[SynchronizationError]):
self.errors = errors
def __str__(self):
@ -161,13 +162,13 @@ class TensorInfo:
"""
allocation_stack_trace: Optional[traceback.StackSummary]
reads: List[Access] = field(default_factory=list)
reads: list[Access] = field(default_factory=list)
write: Optional[Access] = None
class _TensorsAccessed:
def __init__(self) -> None:
self.accesses: Dict[DataPtr, TensorInfo] = {}
self.accesses: dict[DataPtr, TensorInfo] = {}
def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
if data_ptr not in self.accesses:
@ -209,7 +210,7 @@ class _TensorsAccessed:
def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
return self.accesses[data_ptr].write
def get_reads(self, data_ptr: DataPtr) -> List[Access]:
def get_reads(self, data_ptr: DataPtr) -> list[Access]:
return self.accesses[data_ptr].reads
def add_read(self, data_ptr: DataPtr, access: Access) -> None:
@ -222,9 +223,9 @@ class _TensorsAccessed:
class StreamSynchronizations:
def __init__(self) -> None:
self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
self.host_sync_state: Dict[StreamId, SeqNum] = {}
self.current_sync_states: dict[StreamId, dict[StreamId, SeqNum]] = {}
self.recorded_sync_states: dict[EventId, dict[StreamId, SeqNum]] = {}
self.host_sync_state: dict[StreamId, SeqNum] = {}
self.create_stream(DEFAULT_STREAM_ID)
def _ensure_stream_exists(self, stream: StreamId) -> None:
@ -288,7 +289,7 @@ class StreamSynchronizations:
self.recorded_sync_states[event] = self.current_sync_states[stream].copy()
def _state_wait_for_other(
self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
self, state: dict[StreamId, SeqNum], other: dict[StreamId, SeqNum]
) -> None:
for stream, seq_num in other.items():
state[stream] = max(state.get(stream, -1), seq_num)
@ -349,12 +350,12 @@ class EventHandler:
def _handle_kernel_launch(
self,
stream: StreamId,
read_only: Set[DataPtr],
read_write: Set[DataPtr],
outputs: Set[DataPtr],
read_only: set[DataPtr],
read_write: set[DataPtr],
outputs: set[DataPtr],
operator: str,
tensor_aliases: Dict[int, List[str]],
) -> List[SynchronizationError]:
tensor_aliases: dict[int, list[str]],
) -> list[SynchronizationError]:
def check_conflict(
data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
) -> None:
@ -372,7 +373,7 @@ class EventHandler:
)
)
error_list: List[SynchronizationError] = []
error_list: list[SynchronizationError] = []
self.seq_num += 1
self.syncs.update_seq_num(stream, self.seq_num)
stack_trace = traceback.StackSummary.extract(
@ -462,15 +463,15 @@ class EventHandler:
self.syncs.all_streams_wait_for_event(event)
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
def zip_by_key(a: dict[TK, TVa], b: dict[TK, TVb]) -> Iterator[tuple[TK, TVa, TVb]]:
for arg, value in a.items():
if arg in b:
yield arg, value, b[arg]
def zip_arguments(
schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
schema: torch.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Iterator[tuple[torch.Argument, Any]]:
schema_args = schema.arguments[: len(args)]
schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}
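
zip_by_key above is small enough to exercise directly; a usage sketch showing it yields (key, a[key], b[key]) only for keys present in both mappings, with the PEP 585 dict/tuple generics in its signature:

from collections.abc import Iterator
from typing import TypeVar

TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")


def zip_by_key(a: dict[TK, TVa], b: dict[TK, TVb]) -> Iterator[tuple[TK, TVa, TVb]]:
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


print(list(zip_by_key({"x": 1, "y": 2}, {"x": "meta"})))  # [('x', 1, 'meta')]
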
@ -482,10 +483,10 @@ def zip_arguments(
class ArgumentHandler:
def __init__(self) -> None:
self.dataptrs_read: Set[DataPtr] = set()
self.dataptrs_written: Set[DataPtr] = set()
self.tensor_aliases: Dict[DataPtr, List[str]] = {}
self.outputs: Set[DataPtr] = set()
self.dataptrs_read: set[DataPtr] = set()
self.dataptrs_written: set[DataPtr] = set()
self.tensor_aliases: dict[DataPtr, list[str]] = {}
self.outputs: set[DataPtr] = set()
def _handle_argument(
self,
@ -511,8 +512,8 @@ class ArgumentHandler:
def parse_inputs(
self,
schema: torch.FunctionSchema,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
*,
is_factory: bool,
) -> None:

View File

@ -1,12 +1,12 @@
import os
import sys
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.types import Storage
__all__: List[str] = []
__all__: list[str] = []
def _dummy_fn(name: str) -> Callable:

View File

@ -1,12 +1,12 @@
# mypy: allow-untyped-defs
import re
from typing import Callable, List
from typing import Callable
import torch
from torch import Tensor
__all__: List[str] = []
__all__: list[str] = []
class _CodeParser:

View File

@ -8,7 +8,7 @@ import pickle
import sys
import warnings
from inspect import signature
from typing import Any, Dict, Literal, Optional, Tuple, Union
from typing import Any, Literal, Optional, Union
from typing_extensions import deprecated
import torch
@ -218,7 +218,7 @@ def empty_cache() -> None:
torch._C._cuda_emptyCache()
def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
def memory_stats(device: Union[Device, int] = None) -> dict[str, Any]:
r"""Return a dictionary of CUDA memory allocator statistics for a given device.
The return value of this function is a dictionary of statistics, each of
@ -323,7 +323,7 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
return collections.OrderedDict(result)
def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> dict[str, Any]:
r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
if not is_initialized():
return {}
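
A short usage sketch for the memory_stats annotations above (assumes a CUDA build; the dotted key shown is one of the documented statistics names):

import torch

if torch.cuda.is_available():
    stats: dict[str, object] = torch.cuda.memory_stats()
    # memory_stats() returns a flat, ordered mapping of dotted stat names.
    print(stats["allocated_bytes.all.current"])
    nested = torch.cuda.memory_stats_as_nested_dict()
    print(nested["allocated_bytes"]["all"]["current"])
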
@ -719,7 +719,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str:
return "\n".join(lines)
def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
def mem_get_info(device: Union[Device, int] = None) -> tuple[int, int]:
r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
Args:
@ -1035,7 +1035,7 @@ class MemPool(_MemPool):
super().__init__(allocator, True)
@property
def id(self) -> Tuple[int, int]:
def id(self) -> tuple[int, int]:
r"""Returns the ID of this pool as a tuple of two ints."""
return super().id

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import collections
import warnings
from typing import Optional, Sequence, Union
from collections.abc import Sequence
from typing import Optional, Union
import torch.cuda

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Iterable, List, Union
from collections.abc import Iterable
from typing import Union
import torch
from torch import Tensor
@ -42,7 +43,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
return default_generator.get_state()
def get_rng_state_all() -> List[Tensor]:
def get_rng_state_all() -> list[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = [get_rng_state(i) for i in range(device_count())]
return results

View File

@ -118,7 +118,7 @@ import multiprocessing as mp
import os
import shutil
import warnings
from typing import Optional, Tuple
from typing import Optional
import torch
@ -228,12 +228,12 @@ def get_filename() -> str:
return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined]
def get_results() -> Tuple[str, str, str, float]:
def get_results() -> tuple[str, str, str, float]:
r"""Return all TunableOp results."""
return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined]
def get_validators() -> Tuple[str, str]:
def get_validators() -> tuple[str, str]:
r"""Return the TunableOp validators."""
return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined]

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import warnings
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -14,14 +14,14 @@ from torch.masked.maskedtensor.creation import as_masked_tensor
if TYPE_CHECKING:
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
DimOrDims = Optional[Union[int, tuple[int], list[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
DimOrDims = Optional[Tuple[int]]
DimOrDims = Optional[tuple[int]]
__all__: List[str] = []
__all__: list[str] = []
_T = TypeVar("_T")
_P = ParamSpec("_P")
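
A minimal sketch of the TYPE_CHECKING split used above (standalone module, simplified alias): the builtin-generic alias is only evaluated by static checkers, while the runtime branch keeps the simpler form the JIT understands.

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Seen only by type checkers, so PEP 585 generics are fine here.
    DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]]
else:
    # Evaluated at runtime; kept simple for TorchScript.
    DimOrDims = Optional[tuple[int]]


def reduce_over(dim: DimOrDims = None) -> None:
    print(dim)


reduce_over((0, 1))
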
@ -291,7 +291,7 @@ defined as ``prod(x[:i])``.""",
example_dim = 1
example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
example_mask = torch.tensor([[True, False, True], [False, False, False]])
example_args: Tuple[Any, ...]
example_args: tuple[Any, ...]
if func.__name__ in {"norm", "normalize"}:
example_args = (2.0, example_dim)
example_input = example_input.to(dtype=torch.float32)
@ -303,8 +303,8 @@ defined as ``prod(x[:i])``.""",
else:
example_args = (example_dim,)
operation_args: Tuple[str, ...]
operation_kwargs: Tuple[str, ...]
operation_args: tuple[str, ...]
operation_kwargs: tuple[str, ...]
operation_args, operation_kwargs = args_and_kwargs[func.__name__]
arg_declarations = [
"\n ".join(
@ -461,9 +461,9 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
raise NotImplementedError(f"identity of {op_name} on {dtype} input")
def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
"""Return dim argument as a tuple of sorted dim values."""
dims: List[int] = []
dims: list[int] = []
if dim == ():
# Currently, `dim=()` in reductions operations means "reduce
# over all dimensions" while in future, it will read "no
@ -618,7 +618,7 @@ def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor
def _sparse_coo_scatter_reduction_helper(
op,
mask_input: Tensor,
dims: Tuple[int, ...],
dims: tuple[int, ...],
keepdim: bool,
dtype: Optional[DType] = None,
) -> Tensor:
@ -738,7 +738,7 @@ def _sparse_coo_scatter_reduction_helper(
def _sparse_csr_segment_reduction_helper(
op,
mask_input: Tensor,
dims: Tuple[int, ...],
dims: tuple[int, ...],
keepdim: bool,
dtype: Optional[DType] = None,
) -> Tensor:

View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Callable, Dict, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
import torch
@ -228,7 +228,7 @@ def _function_to_sparse_csr(func, *args, **kwargs):
return _MaskedToSparseCsr.apply(args[0])
_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {}
_MASKEDTENSOR_DISPATCH_TABLE: dict["OpOverload", Callable[..., Any]] = {}
def register_dispatch_func(aten_ops):

View File

@ -22,8 +22,8 @@ Event = torch.Event
Stream = torch.Stream
_initialized = False
_queued_calls: List[
Tuple[Callable[[], None], List[str]]
_queued_calls: list[
tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
@ -170,13 +170,13 @@ def record_memory_history(
torch._C._mtia_recordMemoryHistory(enabled, stacks, max_entries)
def snapshot() -> Dict[str, Any]:
def snapshot() -> dict[str, Any]:
r"""Return a dictionary of MTIA memory allocator history"""
return torch._C._mtia_memorySnapshot()
def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Return capability of a given device as a tuple of (major version, minor version).
Args:

View File

@ -2,7 +2,7 @@
r"""This package adds support for device memory management implemented in MTIA."""
from typing import Any, Dict, Optional
from typing import Any, Optional
import torch
@ -10,7 +10,7 @@ from . import _device_t, is_initialized
from ._utils import _get_device_index
def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
def memory_stats(device: Optional[_device_t] = None) -> dict[str, Any]:
r"""Return a dictionary of MTIA memory allocator statistics for a given device.
Args:

View File

@ -25,7 +25,7 @@ torch.serialization.add_safe_globals([_NestedTensor, _rebuild_njt])
def as_nested_tensor(
ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
ts: Union[Tensor, list[Tensor], tuple[Tensor, ...]],
dtype: Optional[DType] = None,
device: Optional[Device] = None,
layout=None,

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from typing import * # noqa: F403
from typing import Tuple
import torch
from torch._C import DispatchKey, DispatchKeySet
@ -69,8 +68,8 @@ class NestedTensor(torch.Tensor):
# We also use nested ints to represent the strides of this tensor.
# For example, a jagged tensor with shape [B, x, D] can be strided in two
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
_size: Tuple[int, ...]
_strides: Tuple[int, ...]
_size: tuple[int, ...]
_strides: tuple[int, ...]
# Indicates that the nth dimension is ragged
_ragged_idx: int
_metadata_cache: Dict[str, Any]
@ -417,7 +416,7 @@ def jagged_from_list(
offsets: Optional[torch.Tensor],
dtype=None,
device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
) -> tuple[NestedTensor, torch.Tensor]:
"""Constructs a NestedTensor backed by jagged layout from a list of tensors"""
if len(tensors) == 0:
@ -500,7 +499,7 @@ def jagged_from_list(
def jagged_from_tensor_and_lengths(
tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
"""Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
batch_size = tensor.shape[0]
if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(

View File

@ -3,7 +3,7 @@ import functools
import math
import operator
from typing import * # noqa: F403
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
@ -13,7 +13,7 @@ from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
from .nested_tensor import NestedTensor
__all__: List[Any] = []
__all__: list[Any] = []
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
@ -973,7 +973,7 @@ def unbind_int(func, *args, **kwargs):
lengths = inp.lengths()
ragged_idx = inp._ragged_idx
def _torch_check(_lengths: List[int], _offsets: Optional[List[int]] = None):
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
# These torch._check and torch._check_is_size calls are needed for torch.compile
# symbolic shapes processing.
# offsets and lengths are symbolic variables during compilation,
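
A hedged sketch of the guard pattern the comments above refer to (private helpers; the exact call sites in the real unbind implementation differ):

import torch


def _torch_check_demo(lengths: list[int], total: int) -> None:
    for length in lengths:
        # Record that `length` is a valid size and bounded by `total`,
        # so torch.compile's symbolic-shape reasoning can rely on it.
        torch._check_is_size(length)
        torch._check(length <= total)


_torch_check_demo([2, 3, 1], total=6)
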

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn
@ -302,7 +302,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
return SDPBackend.ERROR
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
# This function is used to calculate two pieces of metadata that are needed
# for use with flash-attention and efficient_attention kernels. They are the
# cumulative sequence_length over a batch of sequences and the maximum
@ -634,7 +634,7 @@ def _autocast(
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
[Autocasting SDPA for NJT]