PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145202
Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: 54a00af2c6
Commit: 805c4b597a
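The pattern applied throughout the hunks below is mechanical: under PEP 585 the builtin containers are usable as generics, so `typing.List`, `typing.Dict`, and `typing.Tuple` become `list`, `dict`, and `tuple`, and abstract types such as `Sequence` are imported from `collections.abc` instead of `typing`. A minimal sketch of the rewrite (illustrative only, not code from this PR; the function names are hypothetical):

```python
from collections.abc import Sequence
from typing import Any, Optional

# Before (pre-PEP 585):
#   from typing import Dict, List, Optional, Tuple
#   def summarize(xs: List[int], opts: Dict[str, Any]) -> Tuple[int, ...]: ...

# After: builtin generics are subscriptable on Python 3.9+, so the typing aliases are unnecessary.
def summarize(xs: list[int], opts: dict[str, Any]) -> tuple[int, ...]:
    return tuple(x for x in xs if opts.get("keep_even", True) or x % 2)


def first_or_default(items: Sequence[int], default: Optional[int] = None) -> Optional[int]:
    # Sequence now comes from collections.abc rather than typing.
    return items[0] if items else default
```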
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import itertools
-from typing import Any, Callable, List
+from typing import Any, Callable

import torch
import torch._prims_common as utils
@@ -335,7 +335,7 @@ def generic_associative_scan(operator, leaves, dim=0):


def trace_associative_scan(
-proxy_mode, func_overload, combine_fn: Callable, xs: List[torch.Tensor], dim: int
+proxy_mode, func_overload, combine_fn: Callable, xs: list[torch.Tensor], dim: int
):
with disable_proxy_modes_tracing():
sample_xs = [first_slice_copy(x, dim) for x in itertools.chain(xs, xs)]
@@ -415,7 +415,7 @@ def associative_scan_functionalize(ctx, combine_fn, xs, dim):

def _fake_associative_scan(combine_fn, xs, dim, reverse=False):  # noqa: F811
inp_leaves, spec = pytree.tree_flatten(xs)
-result_flat: List[Any] = []
+result_flat: list[Any] = []
num_leaves = len(inp_leaves)
op = reversed if reverse else lambda x: x

@@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
+from collections.abc import Sequence
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Union

import torch
import torch.utils._pytree as pytree
@@ -32,7 +33,7 @@ class ViewInfo(ABC):
self.base_index = base_index

@abstractmethod
-def regenerate_view(self, bases_list: List[Tensor]):
+def regenerate_view(self, bases_list: list[Tensor]):
pass


@@ -48,7 +49,7 @@ class AsStridedViewInfo(ViewInfo):
self.stride = stride
self.storage_offset = storage_offset

-def regenerate_view(self, bases_list: List[Tensor]):
+def regenerate_view(self, bases_list: list[Tensor]):
return torch.as_strided(
bases_list[self.base_index],
self.size,
@@ -69,7 +70,7 @@ class SliceViewInfo(ViewInfo):
self.start = start
self.end = end

-def regenerate_view(self, bases_list: List[Tensor]):
+def regenerate_view(self, bases_list: list[Tensor]):
return torch.ops.aten.slice.Tensor(
bases_list[self.base_index], self.dim, self.start, self.end
)
@@ -80,7 +81,7 @@ class AliasViewInfo(ViewInfo):
def __init__(self, base_index):
super().__init__(base_index)

-def regenerate_view(self, bases_list: List[Tensor]):
+def regenerate_view(self, bases_list: list[Tensor]):
return torch.ops.aten.alias.default(bases_list[self.base_index])


@@ -89,7 +90,7 @@ class NotView(ViewInfo):
def __init__(self, base_index):
super().__init__(base_index)

-def regenerate_view(self, bases_list: List[Tensor]):
+def regenerate_view(self, bases_list: list[Tensor]):
return bases_list[self.base_index]


@@ -137,10 +138,10 @@ def try_use_slice(base, tensor):


def write_view_information_to_args(
-mutable_arg_names: List[str],
-mutable_arg_types: List[torch.Type],
-kwargs: Dict[str, Any],
-arg_to_base_index: Dict[str, Any],
+mutable_arg_names: list[str],
+mutable_arg_types: list[torch.Type],
+kwargs: dict[str, Any],
+arg_to_base_index: dict[str, Any],
):
"""
This function writes the view information into kwargs. It reads mutable_args from kwargs.
@@ -215,10 +216,10 @@ def write_view_information_to_args(

# Returns a dict of arg_name -> ViewInfo | [ViewInfo]
def read_view_information_from_args(
-mutable_arg_names: List[str],
-mutable_arg_types: List[torch.Type],
-kwargs: Dict[str, Any],
-all_bases: List[Tensor],
+mutable_arg_names: list[str],
+mutable_arg_types: list[torch.Type],
+kwargs: dict[str, Any],
+all_bases: list[Tensor],
):
"""
This reads the view information added by `write_view_information_to_args` from kwargs, pop them,
@@ -254,7 +255,7 @@ def read_view_information_from_args(
# This means that the argument is the base tensor
return NotView(base_index)

-args_view_info: Dict[str, Any] = {}
+args_view_info: dict[str, Any] = {}
for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
if isinstance(arg_type, torch.ListType):
length = get_arg(f"_{arg_name}_length")
@@ -321,7 +322,7 @@ class AutoFunctionalized(HigherOrderOperator):
/,
_mutable_op: OpOverload,
**kwargs: Any,
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+) -> tuple[Any, tuple[Tensor, ...]]:
assert can_auto_functionalize(_mutable_op)
assert isinstance(kwargs, dict)
return super().__call__(_mutable_op, **kwargs)
@@ -350,7 +351,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
/,
_mutable_op: OpOverload,
**kwargs: Any,
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+) -> tuple[Any, tuple[Tensor, ...]]:
assert can_auto_functionalize(_mutable_op)
assert isinstance(kwargs, dict)
return super().__call__(_mutable_op, **kwargs)
@@ -411,7 +412,7 @@ def can_auto_functionalize(op: OperatorBase) -> bool:
return True


-def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:
+def get_mutable_args(op: OpOverload) -> tuple[list[str], list[torch.Type]]:
"""
Returns the list of argument names that get mutated according to the
schema and their types.
@@ -432,8 +433,8 @@ def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:

def do_auto_functionalize(
op: OpOverload,
-args: Tuple[Any, ...],
-kwargs: Dict[str, Any],
+args: tuple[Any, ...],
+kwargs: dict[str, Any],
) -> Any:
"""Functionalizes a call to op(*args, **kwargs) by emitting a call to
`outs = auto_functionalized(op, normalized_kwargs)`
@@ -476,7 +477,7 @@ def do_auto_functionalize(
# List of the name of args that get mutated (according to the schema)
mutable_args_names, _ = get_mutable_args(op)

-unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
+unwrapped_actual_out: Union[Any, tuple[Any]] = unwrapped_outs[
: -len(mutable_args_names)
]
unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]
@@ -521,8 +522,8 @@ def do_auto_functionalize(

def do_auto_functionalize_v2(
op: OpOverload,
-args: Tuple[Any, ...],
-kwargs: Dict[str, Any],
+args: tuple[Any, ...],
+kwargs: dict[str, Any],
) -> Any:
from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

@@ -553,7 +554,7 @@ def do_auto_functionalize_v2(
all_bases_addresses: list[int] = []

# Map arg_name to the index of its base in all_bases.
-arg_to_base_index: Dict[str, Any] = {}
+arg_to_base_index: dict[str, Any] = {}

def update_dict(tensor, arg_name, index=None):
base = tensor if get_base(tensor) is None else get_base(tensor)
@@ -613,7 +614,7 @@ def do_auto_functionalize_v2(
op, **dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped)  # type: ignore[arg-type]
)

-unwrapped_actual_out: Union[Any, Tuple[Any]] = (
+unwrapped_actual_out: Union[Any, tuple[Any]] = (
unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)]
)

@@ -661,9 +662,9 @@ def do_auto_functionalize_v2(
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
_mutable_op: OpOverload,
-_only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
+_only_clone_these_tensors: Optional[tuple[str, ...]] = None,
**kwargs: Any,
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+) -> tuple[Any, tuple[Tensor, ...]]:
new_kwargs = dict(**kwargs)
result = []

@@ -698,7 +699,7 @@ def auto_functionalized_fake(
mode,
_mutable_op: OpOverload,
**kwargs: Any,
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+) -> tuple[Any, tuple[Tensor, ...]]:
with mode:
result = auto_functionalized_dense(
_mutable_op, _only_clone_these_tensors=None, **kwargs
@@ -711,7 +712,7 @@ def auto_functionalized_proxy(
mode,
_mutable_op: OpOverload,
**kwargs: Any,
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+) -> tuple[Any, tuple[Tensor, ...]]:
with disable_proxy_modes_tracing():
out = auto_functionalized(_mutable_op, **kwargs)

@@ -738,10 +739,10 @@ def auto_functionalized_func(ctx, _mutable_op, **kwargs):
@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_v2_dense(
_mutable_op: OpOverload,
-_only_clone_these_bases: Optional[Tuple[int, ...]] = None,
+_only_clone_these_bases: Optional[tuple[int, ...]] = None,
**kwargs: Any,
-) -> Tuple[Any, Tuple[Tensor, ...]]:
-all_bases: List[Tensor] = kwargs.pop("_all_bases", [])
+) -> tuple[Any, tuple[Tensor, ...]]:
+all_bases: list[Tensor] = kwargs.pop("_all_bases", [])
mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op)
args_view_info = read_view_information_from_args(
mutable_args_names, mutable_args_types, kwargs, all_bases
@@ -794,8 +795,8 @@ def auto_functionalized_v2_dense(
def auto_functionalized_v2_fake(
mode,
_mutable_op: OpOverload,
-**kwargs: Dict[str, Any],
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+**kwargs: dict[str, Any],
+) -> tuple[Any, tuple[Tensor, ...]]:
with mode:
result = auto_functionalized_v2_dense(
_mutable_op, _only_clone_these_bases=None, **kwargs
@@ -807,8 +808,8 @@ def auto_functionalized_v2_fake(
def auto_functionalized_v2_proxy(
mode,
_mutable_op: OpOverload,
-**kwargs: Dict[str, Any],
-) -> Tuple[Any, Tuple[Tensor, ...]]:
+**kwargs: dict[str, Any],
+) -> tuple[Any, tuple[Tensor, ...]]:
with disable_proxy_modes_tracing():
out = auto_functionalized_v2(_mutable_op, **kwargs)

@@ -3,7 +3,7 @@
import contextlib
import logging
import warnings
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Union

import torch
import torch._subclasses.functional_tensor
@@ -71,7 +71,7 @@ def cond(
pred: Union[bool, int, float, torch.Tensor],
true_fn: Callable,
false_fn: Callable,
-operands: Union[Tuple, List] = (),
+operands: Union[tuple, list] = (),
) -> Any:
r"""
Conditionally applies `true_fn` or `false_fn`.
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from enum import Enum
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union
from weakref import WeakKeyDictionary

import torch
@@ -71,9 +71,9 @@ class WithEffects(HigherOrderOperator):
self,
token,
op: OpType,
-*args: Tuple[Any, ...],
-**kwargs: Dict[str, Any],
-) -> Tuple[Any, ...]:
+*args: tuple[Any, ...],
+**kwargs: dict[str, Any],
+) -> tuple[Any, ...]:
assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
assert not has_aliasing(op), "Ops with aliasing is not supported"
assert has_effects(op, args, kwargs)
@@ -133,9 +133,9 @@ def new_token_tensor() -> torch.Tensor:
def with_effects_dense(
token: torch.Tensor,
op: torch._ops.OpOverload,
-*args: Tuple[Any, ...],
-**kwargs: Dict[str, Any],
-) -> Tuple[torch.Tensor, ...]:
+*args: tuple[Any, ...],
+**kwargs: dict[str, Any],
+) -> tuple[torch.Tensor, ...]:
out = op(*args, **kwargs)
new_token = new_token_tensor()
if isinstance(out, tuple):
@@ -148,9 +148,9 @@ def with_effects_fake(
mode,
token: torch.Tensor,
op: torch._ops.OpOverload,
-*args: Tuple[Any, ...],
-**kwargs: Dict[str, Any],
-) -> Tuple[torch.Tensor, ...]:
+*args: tuple[Any, ...],
+**kwargs: dict[str, Any],
+) -> tuple[torch.Tensor, ...]:
with mode:
result = with_effects_dense(token, op, *args, **kwargs)
return result
@@ -161,9 +161,9 @@ def with_effects_proxy(
mode,
token: torch.Tensor,
op: torch._ops.OpOverload,
-*args: Tuple[Any, ...],
-**kwargs: Dict[str, Any],
-) -> Tuple[torch.Tensor, ...]:
+*args: tuple[Any, ...],
+**kwargs: dict[str, Any],
+) -> tuple[torch.Tensor, ...]:
with disable_proxy_modes_tracing():
out = with_effects(token, op, *args, **kwargs)

@@ -202,10 +202,10 @@ def _get_schema(op, args) -> torch.FunctionSchema:

def handle_effects(
allow_token_discovery: bool,
-tokens: Dict[_EffectType, torch.Tensor],
+tokens: dict[_EffectType, torch.Tensor],
op: OpType,
-args: Tuple[Any, ...],
-kwargs: Dict[str, Any],
+args: tuple[Any, ...],
+kwargs: dict[str, Any],
) -> Any:
"""
Args:
@@ -1,5 +1,6 @@
import math
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union

import torch
import torch.utils._pytree as pytree
@@ -48,7 +49,7 @@ def _construct_strides(
return strides


-def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor:
+def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch.Tensor:
"""
Create a new tensor with the same data and shape as the input,
but with strides permuted based on the input tensor's stride order.
@@ -81,12 +82,12 @@ class FlexAttentionHOP(HigherOrderOperator):
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
return super().__call__(
query,
@@ -119,13 +120,13 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[
-torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[
+torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
return super().__call__(
@@ -154,12 +155,12 @@ def _math_attention_inner(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32
@@ -197,12 +198,12 @@ def math_attention(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Eager implementation

This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
@@ -256,12 +257,12 @@ def sdpa_dense(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
out, lse = math_attention(
query,
key,
@@ -283,12 +284,12 @@ def trace_flex_attention(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Traces the flex_attention operator with the given score_mod function and other_buffers.

Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function
@@ -353,12 +354,12 @@ def flex_attention_proxy_torch_dispatch_mode(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention(
mode,
@@ -381,12 +382,12 @@ def flex_attention_functionalize(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Defines the functionalization rules for the flex_attention operator.

Write now we are unwrapping each tensor and then redispatching to the next, however we want to
@@ -443,7 +444,7 @@ def flex_attention_functionalize(

def flex_attention_fake_impl(
query: torch.Tensor, value: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Figure out a better way to handle this for NJT than using sum()
if query.is_nested:
out = torch.empty_like(query, memory_format=torch.contiguous_format)
@@ -466,12 +467,12 @@ def flex_attention_fake_tensor_mode(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
with mode:
out, logsumexp = flex_attention_fake_impl(query, value)
return out, logsumexp
@@ -480,9 +481,9 @@ def flex_attention_fake_tensor_mode(
# ---------------------------- Autograd Implementation ----------------------------
def create_fw_bw_graph(
score_mod: Callable,
-index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
-other_buffers: Tuple[Tensor, ...],
-) -> Tuple[Callable, Callable]:
+index_values: tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+other_buffers: tuple[Tensor, ...],
+) -> tuple[Callable, Callable]:
# See Note:[HOP create fw_bw graph]

# All of these imports need to be here in order to avoid circular dependencies
@@ -554,11 +555,11 @@ def create_fw_bw_graph(
m: Tensor,
n: Tensor,
example_grad: Tensor,
-*other_buffers: Tuple[Tensor, ...],
-) -> Tuple[Tensor, ...]:
+*other_buffers: tuple[Tensor, ...],
+) -> tuple[Tensor, ...]:
def fw_with_masks(
-*args: Tuple[Tensor, ...]
-) -> Tuple[Tuple[Tensor], Tuple[bool]]:
+*args: tuple[Tensor, ...]
+) -> tuple[tuple[Tensor], tuple[bool]]:
fw_out = score_mod(*args)
out_requires_grad = fw_out.requires_grad
return ((fw_out,), (out_requires_grad,))
@@ -585,12 +586,12 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
value: Tensor,
fw_graph: Callable,
joint_graph: Callable,
-block_mask: Tuple[Any, ...],
+block_mask: tuple[Any, ...],
scale: float,
-kernel_options: Dict[str, Any],
-mask_mod_other_buffers: Tuple[Any, ...],
-*score_mod_other_buffers: Tuple[Any, ...],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+mask_mod_other_buffers: tuple[Any, ...],
+*score_mod_other_buffers: tuple[Any, ...],
+) -> tuple[torch.Tensor, torch.Tensor]:
any_buffer_requires_grad = any(
buffer.requires_grad
for buffer in mask_mod_other_buffers
@@ -634,7 +635,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
return out, logsumexp

@staticmethod
-def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
+def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]:  # type: ignore[override]
fw_args = saved_tensors_and_symints(ctx)
(
query,
@@ -714,12 +715,12 @@ def flex_attention_autograd(
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple[Tensor, ...] = (),
-mask_mod_other_buffers: Tuple[Tensor, ...] = (),
-) -> Tuple[torch.Tensor, torch.Tensor]:
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple[Tensor, ...] = (),
+mask_mod_other_buffers: tuple[Tensor, ...] = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

with TransformGetItemToIndex():
@@ -769,13 +770,13 @@ def sdpa_dense_backward(
grad_logsumexp: torch.Tensor,
fw_graph: Callable,  # GraphModule type hint?
joint_graph: Callable,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple,
-mask_mod_other_buffers: Tuple,
-) -> Tuple[
-torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple,
+mask_mod_other_buffers: tuple,
+) -> tuple[
+torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

@@ -921,13 +922,13 @@ def trace_flex_attention_backward(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[
-torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[
+torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
"""We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
@@ -1018,13 +1019,13 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[
-torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[
+torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention_backward(
@@ -1058,13 +1059,13 @@ def flex_attention_backward_functionalize(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[
-torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[
+torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
"""Defines the functionalization rules for the flex_attention operator.

@@ -1136,13 +1137,13 @@ def flex_attention_backward_fake_tensor_mode(
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
-block_mask: Tuple,
+block_mask: tuple,
scale: float,
-kernel_options: Dict[str, Any],
-score_mod_other_buffers: Tuple = (),
-mask_mod_other_buffers: Tuple = (),
-) -> Tuple[
-torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+kernel_options: dict[str, Any],
+score_mod_other_buffers: tuple = (),
+mask_mod_other_buffers: tuple = (),
+) -> tuple[
+torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
with mode:
grad_query = torch.empty_like(query)
@@ -1,6 +1,6 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable

from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase

@@ -18,7 +18,7 @@ _foreach_map = ForeachMap()


def foreach_map(
-op: Callable, operands: Any, *unused: Tuple[Any], **kwargs: Dict[str, Any]
+op: Callable, operands: Any, *unused: tuple[Any], **kwargs: dict[str, Any]
):
from torch._dynamo.polyfills import foreach_map_fn

@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs


-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union

import torch
import torch.utils._pytree as pytree
@@ -42,8 +42,8 @@ class InvokeSubgraphHOP(HigherOrderOperator):
subgraph: GraphModule,
identifier: Optional[str],
operands: Union[
-List[Union[torch.Tensor, int, torch.SymInt]],
-Tuple[Union[torch.Tensor, int, torch.SymInt]],
+list[Union[torch.Tensor, int, torch.SymInt]],
+tuple[Union[torch.Tensor, int, torch.SymInt]],
],
):
assert identifier is None or isinstance(
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import itertools
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable

import torch
import torch._prims_common as utils
@@ -45,20 +45,20 @@ def wrap_combine_fn_flat(
return [*carry_flat, *combined_flat]


-def _extract_carry_and_out(flat_out: List[Any], num_carry: int):
+def _extract_carry_and_out(flat_out: list[Any], num_carry: int):
return flat_out[:num_carry], flat_out[num_carry:]


def scan(
combine_fn: Callable[
-[pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree]
+[pytree.PyTree, pytree.PyTree], tuple[pytree.PyTree, pytree.PyTree]
],
init: pytree.PyTree,
xs: pytree.PyTree,
*,
dim: int = 0,
reverse: bool = False,
-) -> Tuple[pytree.PyTree, pytree.PyTree]:
+) -> tuple[pytree.PyTree, pytree.PyTree]:
r"""
Performs an inclusive scan with a combine function.

@@ -331,11 +331,11 @@ def trace_scan(
proxy_mode,
func_overload,
combine_fn: Callable,
-init: List[torch.Tensor],
-xs: List[torch.Tensor],
+init: list[torch.Tensor],
+xs: list[torch.Tensor],
dim: int,
reverse: bool,
-additional_inputs: List[torch.Tensor],
+additional_inputs: list[torch.Tensor],
):
from torch._dynamo.utils import clone_input

@@ -5,17 +5,8 @@ import inspect
import logging
import threading
from collections import defaultdict
-from typing import (
-Any,
-Callable,
-Dict,
-List,
-Optional,
-Sequence,
-Tuple,
-TYPE_CHECKING,
-Union,
-)
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import Never

import sympy
@@ -48,9 +39,9 @@ if TYPE_CHECKING:
from torch.fx.proxy import Proxy
from torch.utils._triton import has_triton

-TritonMetaParamsType = Dict[str, int]
-TritonGridTupleType = Tuple[Union[int, sympy.Expr, SymInt], ...]
-TritonGridCallableType = Callable[[TritonMetaParamsType], Tuple[int, ...]]
+TritonMetaParamsType = dict[str, int]
+TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
+TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]

if has_triton():
@@ -76,11 +67,11 @@ log = logging.getLogger("torch._dynamo")
# conisting of list of dims, list of block dims, and element size. E.g., for this
# call in host-side Triton TMA API ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``,
# the metadata will look like ``([50, 60], [32, 15], 4)``. All ints can be SymInts.
-TMADescriptorMetadata = Dict[
+TMADescriptorMetadata = dict[
str,  # kernel parameter name
-Tuple[
-List[Union[int, SymInt]],  # dims
-List[Union[int, SymInt]],  # block_dims
+tuple[
+list[Union[int, SymInt]],  # dims
+list[Union[int, SymInt]],  # block_dims
Union[int, SymInt],  # element_size
],
]
@@ -95,9 +86,9 @@ TMADescriptorMetadata = Dict[
# Use a side table.
# We use two dicts so that fetching both the kernel and id are O(1)
class KernelSideTable:
-id_to_kernel: Dict[int, "TritonKernelType"] = {}
-kernel_to_id: Dict["TritonKernelType", int] = {}
-constant_args: Dict[int, Dict[str, Any]] = {}
+id_to_kernel: dict[int, "TritonKernelType"] = {}
+kernel_to_id: dict["TritonKernelType", int] = {}
+constant_args: dict[int, dict[str, Any]] = {}
lock = threading.Lock()

# Returns index on the table
@@ -119,14 +110,14 @@ class KernelSideTable:

# Not every constant arg can be added to the graph. Use this side table
# for constant args.
-def add_constant_args(self, args: Dict[str, Any]) -> int:
+def add_constant_args(self, args: dict[str, Any]) -> int:
with self.lock:
idx = len(self.constant_args)
self.constant_args[idx] = args
return idx

# Returns the constant args
-def get_constant_args(self, idx: int) -> Dict[str, Any]:
+def get_constant_args(self, idx: int) -> dict[str, Any]:
# No need to lock here as fetching from dict is atomic
assert idx in self.constant_args
return self.constant_args[idx]
@@ -163,7 +154,7 @@ class Intermediate:
class Op:
name: str
fn_call_name: Optional[str]
-args: List[Union[Param, Intermediate]]
+args: list[Union[Param, Intermediate]]
ret: Intermediate = dataclasses.field(repr=False)

def __post_init__(self) -> None:
@@ -174,8 +165,8 @@ class Op:


def generate_ttir(
-kernel: "TritonKernelType", kwargs: Dict[str, Any]
-) -> Tuple["TritonIRModule", List[str]]:
+kernel: "TritonKernelType", kwargs: dict[str, Any]
+) -> tuple["TritonIRModule", list[str]]:
"""
Uses Triton's internal code generation to create TTIR
"""
@@ -218,7 +209,7 @@ def generate_ttir(
# Replace all SymExprs with a regular value for TTIR generation
# Replace all FakeTensor/TensorBox with real tensors
# These replacements are needed for triton's type, key and config functions
-ordered_args: Dict[str, Any] = {}
+ordered_args: dict[str, Any] = {}
for name in kernel.arg_names:
a = kwargs[name]
if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
@@ -280,22 +271,22 @@ def generate_ttir(

def ttir_to_functions(
ttir_module: "TritonIRModule",
-) -> Dict[str, Dict[Intermediate, List[Op]]]:
+) -> dict[str, dict[Intermediate, list[Op]]]:
"""
Walk the `ttir_module` bottom up to mine the `functions` from
the structured MLIR entities representing the Triton kernel
(mlir::Operation, mlir::Block, mlir::Region).
"""
-functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
+functions: dict[str, dict[Intermediate, list[Op]]] = {}

# block id --> op result (Intermediate) --> one or more ops
-op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict(
+op_stack: dict[int, dict[Intermediate, list[Op]]] = defaultdict(
lambda: defaultdict(list)
)
-region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list)
-block_id_to_block_arg_ids: Dict[int, List[int]] = {}
-replacements: Dict[int, Union[Intermediate, Param]] = {}
-reindex_map: Dict[int, int] = {}
+region_id_to_block_ids: dict[int, list[int]] = defaultdict(list)
+block_id_to_block_arg_ids: dict[int, list[int]] = {}
+replacements: dict[int, Union[Intermediate, Param]] = {}
+reindex_map: dict[int, int] = {}
next_fake_intermediate = 0

def reindex(idx: int) -> int:
@@ -309,14 +300,14 @@ def ttir_to_functions(
# this wraps all tt.func ops
return

-operand_ids: List[int] = [
+operand_ids: list[int] = [
reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
]
-result_ids: List[int] = [
+result_ids: list[int] = [
reindex(op.get_result(i).id()) for i in range(op.get_num_results())
]

-child_block_ids: List[int] = []
+child_block_ids: list[int] = []
for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
# as the walk is bottom-up, the region_id_to_block_ids[i]
# must be populated by the time we process the enclosing op
@@ -460,7 +451,7 @@ def ttir_to_functions(
callee = None
if name == "tt.call":
callee = op.get_flat_symbol_ref_attr("callee")
-args: List[Union[Param, Intermediate]] = [
+args: list[Union[Param, Intermediate]] = [
Intermediate(operand) for operand in operand_ids
]
block_ops = op_stack[parent_block_id]
@@ -480,7 +471,7 @@ def ttir_to_functions(

class MemoizeWithCycleCheck:
fn: Callable[..., Any]
-cache: Dict[Tuple[str, int], Any]
+cache: dict[tuple[str, int], Any]

def __init__(self, fn: Callable[..., Any]) -> None:
self.fn = fn
@@ -488,10 +479,10 @@ class MemoizeWithCycleCheck:

def __call__(
self,
-functions: Dict[str, Dict[Intermediate, List[Op]]],
+functions: dict[str, dict[Intermediate, list[Op]]],
fn_name: str,
num_args: int,
-) -> List[bool]:
+) -> list[bool]:
key = (fn_name, num_args)
if key not in self.cache:
self.cache[key] = None
@@ -506,8 +497,8 @@ class MemoizeWithCycleCheck:

@MemoizeWithCycleCheck
def analyze_kernel_mutations(
-functions: Dict[str, Dict[Intermediate, List[Op]]], fn_name: str, num_args: int
-) -> List[bool]:
+functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, num_args: int
+) -> list[bool]:
"""
Analyzes the graph to detect all sinks from a predefined list of sinks
by using triton's MemWrite trait list. NOTE: What if triton exposed this?
@@ -527,7 +518,7 @@ def analyze_kernel_mutations(
# Ops that we want to bail out on
UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

-stack: List[Union[Param, Intermediate]] = []
+stack: list[Union[Param, Intermediate]] = []
visited = set()
ops = functions[fn_name]
for op_list in ops.values():
@@ -569,8 +560,8 @@ def analyze_kernel_mutations(


def identify_mutated_tensors(
-kernel: "TritonKernelType", kwargs: Dict[str, Any]
-) -> List[str]:
+kernel: "TritonKernelType", kwargs: dict[str, Any]
+) -> list[str]:
"""
Given a triton kernel and the arguments for this kernel, this function
1) Retrieves the TTIR converted version of the kernel from Triton's API.
@@ -630,9 +621,9 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
self,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
) -> Any:
return super().__call__(
kernel_idx=kernel_idx,
@@ -655,11 +646,11 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
self,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
-tensors_to_clone: List[str],
-) -> Dict[str, Any]:
+kwargs: dict[str, Any],
+tensors_to_clone: list[str],
+) -> dict[str, Any]:
return super().__call__(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
@@ -678,9 +669,9 @@ def triton_kernel_wrapper_mutation_dense(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
) -> None:
from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

@@ -693,7 +684,7 @@ def triton_kernel_wrapper_mutation_dense(
fn_name, code = user_defined_kernel_grid_fn_code(
kernel.fn.__name__, kernel.configs, grid
)
-namespace: Dict[str, Any] = {}
+namespace: dict[str, Any] = {}
exec(code, namespace)
grid_fn = namespace[fn_name]

@@ -746,9 +737,9 @@ def triton_kernel_wrapper_mutation_fake_tensor_mode(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
) -> None:
with mode:
return None
@@ -759,9 +750,9 @@ def _(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
) -> None:
return None

@@ -769,8 +760,8 @@ def _(
def trace_triton_kernel_wrapper(
proxy_mode: ProxyTorchDispatchMode,
func_overload: Callable[..., Any],
-node_args: Dict[str, Any],
-) -> Optional[Dict[str, Any]]:
+node_args: dict[str, Any],
+) -> Optional[dict[str, Any]]:
with disable_proxy_modes_tracing():
out = func_overload(**node_args)

@@ -795,9 +786,9 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
) -> None:
trace_triton_kernel_wrapper(
mode,
@@ -815,8 +806,8 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(


def get_mutated_tensors(
-kernel_idx: int, constant_args_idx: int, kwargs: Dict[str, Any]
-) -> List[str]:
+kernel_idx: int, constant_args_idx: int, kwargs: dict[str, Any]
+) -> list[str]:
kernel = kernel_side_table.get_kernel(kernel_idx)
constant_args = kernel_side_table.get_constant_args(constant_args_idx)
return identify_mutated_tensors(kernel, {**kwargs, **constant_args})
@@ -827,9 +818,9 @@ def triton_kernel_wrapper_mutation_functionalize(
ctx: "BaseFunctionalizeAPI",
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
) -> None:
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
# TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
@@ -869,11 +860,11 @@ def triton_kernel_wrapper_functional_dense(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
-tensors_to_clone: List[str],
-) -> Dict[str, Any]:
+kwargs: dict[str, Any],
+tensors_to_clone: list[str],
+) -> dict[str, Any]:
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
# (inductor should always optimize them away).
@@ -898,11 +889,11 @@ def triton_kernel_wrapper_functional_fake_tensor_mode(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
-tensors_to_clone: List[str],
-) -> Dict[str, Any]:
+kwargs: dict[str, Any],
+tensors_to_clone: list[str],
+) -> dict[str, Any]:
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
# (inductor should always optimize them away).
@@ -921,11 +912,11 @@ def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
*,
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
-tensors_to_clone: List[str],
-) -> Dict[str, Any]:
+kwargs: dict[str, Any],
+tensors_to_clone: list[str],
+) -> dict[str, Any]:
ret = trace_triton_kernel_wrapper(
mode,
triton_kernel_wrapper_functional,
@@ -947,11 +938,11 @@ def triton_kernel_wrapper_functional_functionalize(
ctx: "BaseFunctionalizeAPI",
kernel_idx: int,
constant_args_idx: int,
-grid: List["TritonGridType"],
+grid: list["TritonGridType"],
tma_descriptor_metadata: TMADescriptorMetadata,
-kwargs: Dict[str, Any],
-tensors_to_clone: List[str],
-) -> Dict[str, Any]:
+kwargs: dict[str, Any],
+tensors_to_clone: list[str],
+) -> dict[str, Any]:
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
with ctx.redispatch_to_next():
outputs = triton_kernel_wrapper_functional(
@@ -1024,7 +1015,7 @@ class TritonHOPifier:
grid,
meta,
tx,
-) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]:
+) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
raise NotImplementedError("abstract method")

def wrap_user_defined_obj(
@@ -1041,8 +1032,8 @@ class TritonHOPifier:
def call_user_defined_fn(
self,
user_fn: Callable[..., Any],
-args: List,
-kwargs: Dict,
+args: list,
+kwargs: dict,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
@@ -1051,8 +1042,8 @@ class TritonHOPifier:
raise NotImplementedError("abstract method")

def maybe_unpack_configs(
-self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
-) -> List["TritonConfig"]:
+self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
+) -> list["TritonConfig"]:
raise NotImplementedError("abstract method")

def maybe_unpack_heuristic_result(self, result: Any) -> Any:
@@ -1064,10 +1055,10 @@ class TritonHOPifier:
early_config_prune: Optional[Callable],
perf_model: Optional[Callable],
top_k: float,
-configs: List,
-named_args: Dict,
-kwargs: Dict,
-) -> List["TritonConfig"]:
+configs: list,
+named_args: dict,
+kwargs: dict,
+) -> list["TritonConfig"]:
# Reimplement autotuner.prune_configs(...) here
# see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
# We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model
@@ -1107,14 +1098,14 @@ class TritonHOPifier:
self,
variable,
grids,
-combined_args: Dict[str, Any],
+combined_args: dict[str, Any],
tx,
) -> Optional["ConstantVariable"]:
raise NotImplementedError("abstract method")

def check_grid(  # type: ignore[no-untyped-def]
self, grid
-) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]:
+) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
raise NotImplementedError("abstract method")

def init_variable(
@@ -1215,7 +1206,7 @@ class TritonHOPifier:
self,
variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
args: Sequence[Any],
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
tx: Optional["InstructionTranslator"],
) -> Optional["ConstantVariable"]:
if "grid" not in kwargs:
@@ -1236,7 +1227,7 @@ class TritonHOPifier:
self,
variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
args: Sequence[Any],
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
tx: Optional["InstructionTranslator"],
) -> Optional["ConstantVariable"]:
from triton import JITFunction
@@ -1546,7 +1537,7 @@ class TracingTritonHOPifier(TritonHOPifier):
grid: "TritonGridCallableType",
meta: "TritonMetaParamsType",
tx: None,
-) -> Tuple[Union[int, sympy.Expr, SymInt], ...]:
+) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
assert tx is None
assert isinstance(meta, dict)
assert callable(grid)
@@ -1567,8 +1558,8 @@ class TracingTritonHOPifier(TritonHOPifier):
def call_user_defined_fn(
self,
user_fn: Callable[..., Any],
-args: List,
-kwargs: Dict,
+args: list,
+kwargs: dict,
tx: Optional["InstructionTranslator"],
variable: Optional[
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
@@ -1580,8 +1571,8 @@ class TracingTritonHOPifier(TritonHOPifier):
return user_fn(*args, **kwargs)

def maybe_unpack_configs(
-self, configs: List["TritonConfig"], tx: Optional["InstructionTranslator"]
-) -> List["TritonConfig"]:
+self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
+) -> list["TritonConfig"]:
assert isinstance(configs, list)
return configs

@@ -1591,7 +1582,7 @@ class TracingTritonHOPifier(TritonHOPifier):
def check_grid(
self,
grid: "TritonGridType",
-) -> Tuple[Union[int, sympy.Expr, SymInt], ...]:
+) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
if not isinstance(grid, collections.abc.Sequence):
raise RuntimeError(
"wrap_triton can only handle grids that resolve to Sequence[int]."
@@ -1602,8 +1593,8 @@ class TracingTritonHOPifier(TritonHOPifier):
def call_HOP(
self,
variable: "TraceableTritonKernelWrapper",
-grids: List["TritonGridTupleType"],
-combined_args: Dict[str, Any],
+grids: list["TritonGridTupleType"],
+combined_args: dict[str, Any],
tx: None,
) -> None:
assert tx is None
@@ -1652,7 +1643,7 @@ class TraceableTritonKernelWrapper:
def __getitem__(self, *args: Sequence[Any]) -> "TraceableTritonKernelWrapper":
return tracing_triton_hopifier_singleton.call_getitem(self, args)  # type: ignore[return-value]

-def run(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any:
+def run(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
from torch._library.triton import is_wrap_triton_enabled

if is_wrap_triton_enabled():
@@ -1661,7 +1652,7 @@ class TraceableTritonKernelWrapper:
assert self.kernel is not None
return self.kernel.run(*args, **kwargs)

-def __call__(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any:
+def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
from torch._library.triton import is_wrap_triton_enabled

if is_wrap_triton_enabled():
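For reference, the `TMADescriptorMetadata` alias changed in the hunks above maps a kernel parameter name to `(dims, block_dims, element_size)`. A hypothetical value matching the `create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)` example quoted in the source comment would look like this sketch (the key name is illustrative):

```python
# Hypothetical illustration of the post-change TMADescriptorMetadata shape:
# dict[str, tuple[list[int | SymInt], list[int | SymInt], int | SymInt]]
tma_descriptor_metadata = {
    "in_desc_ptr": ([50, 60], [32, 15], 4),  # dims, block_dims, element_size
}
```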
@@ -2,7 +2,7 @@
import functools
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Union

import torch
import torch.fx.traceback as fx_traceback
@@ -462,7 +462,7 @@ def save_tensors_and_symints_for_backward(ctx, args):
assert all(
isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
), args
-partitioned_args: List[Any] = [[], []]
+partitioned_args: list[Any] = [[], []]
pos = []
for i, arg in enumerate(args):
idx = 0 if isinstance(arg, torch.Tensor) else 1
@@ -514,7 +514,7 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
# Reports the difference between meta of two tensors in a string
def diff_tensor_meta(
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
-) -> List[str]:
+) -> list[str]:
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

pair_diffs = []
@@ -541,7 +541,7 @@ def diff_tensor_meta(
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
# hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
-def validate_subgraph_args_types(lifted_args: Union[Tuple[Any, ...], List[Any]]):
+def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
@@ -33,8 +33,8 @@ class WhileLoopOp(HigherOrderOperator):
self,
cond_fn: Callable,
body_fn: Callable,
-carried_inputs: Tuple[Union[torch.Tensor, int, float, bool]],
-additional_inputs: Tuple[Union[torch.Tensor, torch.SymInt, int], ...],
+carried_inputs: tuple[Union[torch.Tensor, int, float, bool]],
+additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...],
/,
):
if not isinstance(carried_inputs, tuple):
@@ -126,7 +126,7 @@ def while_loop(cond_fn, body_fn, carried_inputs):

# Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
# parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
-additional_inputs: Tuple = ()
+additional_inputs: tuple = ()

# The reason we flatten the output before calling into dynamo is that
# we want to create a consistent input ordering for cond_fn and body_fn.
@@ -330,15 +330,15 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs


def check_meta_consistency(
-lhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
-rhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
+lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
lhs_name: str,
rhs_name: str,
) -> None:
def diff_meta_pairs(
-lhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
-rhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
-) -> List[str]:
+lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+) -> list[str]:
def diff_meta(
lhs: Union[torch.Tensor, torch.SymInt, int],
rhs: Union[torch.Tensor, torch.SymInt, int],
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
-from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union

import torch
from torch import SymInt
@@ -80,7 +80,7 @@ class _DeconstructedSymType:
Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
"""

-ty: Type[PySymType]
+ty: type[PySymType]
node: _DeconstructedSymNode

@staticmethod
@@ -216,7 +216,7 @@ class _CacheKeyState:

# We track the SymNodes so when we get the output we can see if it exactly
# matches one of the inputs so we can uncache it properly.
-sym_node_lookup: Dict[int, int] # id(SymNode) -> index
+sym_node_lookup: dict[int, int] # id(SymNode) -> index

# There are cases where we're asked to perform an op when we have no
# ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
@@ -241,7 +241,7 @@ class _CacheKeyState:
"""
return bool(self.sym_node_lookup)

-def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
+def convert_sym_int(self, result: list[object], arg: SymInt) -> None:
node_id = id(arg.node)
if node_id in self.sym_node_lookup:
result.append(_InputBackref(self.sym_node_lookup[node_id]))
@@ -13,25 +13,7 @@ import typing
import weakref
from collections import defaultdict
from dataclasses import dataclass
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Generator,
-    Iterable,
-    List,
-    Literal,
-    Mapping,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self, TypeGuard
from weakref import ReferenceType

@@ -69,6 +51,7 @@ from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputSt


if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable, Mapping, Sequence
from types import TracebackType

from torch._guards import Source
@@ -94,7 +77,7 @@ class _Unassigned:

_UNASSIGNED = _Unassigned()

-DimList = List
+DimList = list

pytree = torch.utils._pytree
T = TypeVar("T")
@@ -145,7 +128,7 @@ class MetadataMismatchError(RuntimeError):
reason: str


-def ordered_set(*items: T) -> Dict[T, Literal[True]]:
+def ordered_set(*items: T) -> dict[T, Literal[True]]:
return dict.fromkeys(items, True)


@@ -160,8 +143,8 @@ def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, Non


def get_plain_tensors(
-subclass: Tensor, *, out: List[Union[Tensor, int, SymInt]]
-) -> List[Union[Tensor, int, SymInt]]:
+subclass: Tensor, *, out: list[Union[Tensor, int, SymInt]]
+) -> list[Union[Tensor, int, SymInt]]:
# This function is used in Runtime, do not add redundant asserts
todo = [subclass]
while todo:
@@ -248,7 +231,7 @@ def torch_decomp_decompositions(func: OpOverload) -> bool:
) and decomposition_table[func].__name__ in dir(decompositions)


-def tree_flatten_only(ty: Type[T], tree: PyTree) -> List[T]:
+def tree_flatten_only(ty: type[T], tree: PyTree) -> list[T]:
flat_vals = pytree.tree_leaves(tree)
return [elem for elem in flat_vals if isinstance(elem, ty)]

@@ -281,7 +264,7 @@ class FakeTensorConverter:
return self.meta_converter.tensor_memo

meta_converter: MetaConverter
-constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
+constant_storage_mapping: dict[StorageWeakRef, list[ReferenceType]]
export: bool

def __init__(self, *, copy_data: bool = False, export: bool = False) -> None:
@@ -581,7 +564,7 @@ class SymNumberMemoDescriptor:
return f"_{self._name}_epoch"

def __get__(
-self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
+self, obj: FakeTensor, objtype: Optional[type[FakeTensor]] = None
) -> Optional[Union[torch.SymInt, torch.SymFloat]]:
if (r := getattr(obj, self._memo(obj))) is None:
return None
@@ -674,13 +657,13 @@ class FakeTensor(Tensor):

# We don't support named tensors; graph break
@property
-def names(self) -> List[str]:
+def names(self) -> list[str]:
raise UnsupportedFakeTensorException(
"torch.compile doesn't support named tensors"
)

@names.setter
-def names(self, _: List[str]) -> None:
+def names(self, _: list[str]) -> None:
raise NotImplementedError

@staticmethod
@@ -774,7 +757,7 @@ class FakeTensor(Tensor):
def __torch_dispatch__( # type: ignore[override] # TODO
cls,
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
@@ -847,7 +830,7 @@ class FakeTensor(Tensor):
@staticmethod
def _find_common_device(
func: OpOverload, flat_args: Sequence[object]
-) -> Tuple[torch.device, bool]:
+) -> tuple[torch.device, bool]:
# Returns: (common_device, has_scalar_only_inputs)

# cpu - zero-dim tensors can be called in cuda kernels,
@@ -942,8 +925,8 @@ class TensorMetadata:
"""

dtype: torch.dtype
-shape: Tuple[_MetadataIntLike, ...]
-stride: Tuple[_MetadataIntLike, ...]
+shape: tuple[_MetadataIntLike, ...]
+stride: tuple[_MetadataIntLike, ...]
device: torch.device
layout: torch.layout
memory_format: Optional[torch.memory_format]
@@ -961,7 +944,7 @@ class TensorMetadata:

def _flatten_into(
self,
-result: List[object],
+result: list[object],
mode: FakeTensorMode,
state: _CacheKeyState,
) -> None:
@@ -1024,10 +1007,10 @@ class _DispatchCacheKey:
Key for the FakeTensor dispatch cache.
"""

-key: Tuple[object, ...]
+key: tuple[object, ...]
hashvalue: int

-def __init__(self, tup: Tuple[object, ...]) -> None:
+def __init__(self, tup: tuple[object, ...]) -> None:
self.key = tup
self.hashvalue = hash(tup)

@@ -1073,7 +1056,7 @@ class _DispatchCacheEntry:
is_output_tuple flag helps in differentiating the return type
"""

-output_infos: Tuple[_DispatchCacheEntryOutputInfo]
+output_infos: tuple[_DispatchCacheEntryOutputInfo]
is_output_tuple: bool = False


@@ -1096,7 +1079,7 @@ class DispatchCacheInfo:

hits: int
misses: int
-bypasses: Dict[str, int]
+bypasses: dict[str, int]
size: int


@@ -1110,10 +1093,10 @@ class DispatchCacheInfo:


class FakeTensorMode(TorchDispatchMode):
-cache: Dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
+cache: dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
cache_hits: int = 0
cache_misses: int = 0
-cache_bypasses: Dict[str, int] = defaultdict(int)
+cache_bypasses: dict[str, int] = defaultdict(int)
# Every time you retrace using the same fake tensor mode, you should
# advance the epoch so we don't reuse unbacked memos
epoch: int = 0
@@ -1208,8 +1191,8 @@ class FakeTensorMode(TorchDispatchMode):
# in_kernel_invocation
# If another fake mode was already active when we enter, we also stash it here.
# That way when we exit, we know to re-enable the previous fake mode.
-self.enter_stack: List[
-    Tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
+self.enter_stack: list[
+    tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
] = []

self.shape_env = shape_env
@@ -1272,7 +1255,7 @@ class FakeTensorMode(TorchDispatchMode):
def __torch_dispatch__(
self,
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
@@ -1309,7 +1292,7 @@ class FakeTensorMode(TorchDispatchMode):

def __exit__(
self,
-a: Optional[Type[BaseException]],
+a: Optional[type[BaseException]],
b: Optional[BaseException],
c: Optional[TracebackType],
) -> None:
@@ -1356,7 +1339,7 @@ class FakeTensorMode(TorchDispatchMode):
def _cached_dispatch_impl(
self,
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object],
kwargs: Mapping[str, object],
) -> object:
@@ -1474,7 +1457,7 @@ class FakeTensorMode(TorchDispatchMode):

def _prep_args_for_hash(
self,
-result: List[object],
+result: list[object],
args: Union[Mapping[str, object], Sequence[object], Iterable[object]],
state: _CacheKeyState,
) -> None:
@@ -1734,7 +1717,7 @@ class FakeTensorMode(TorchDispatchMode):
key: _DispatchCacheKey,
func: OpOverload,
args: Sequence[object],
-) -> Union[Optional[FakeTensor], Tuple[Optional[FakeTensor], ...]]:
+) -> Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]]:
"""
Create a new FakeTensor from the cache entry.
"""
@@ -1758,9 +1741,9 @@ class FakeTensorMode(TorchDispatchMode):

def _crosscheck_cache_output(
self,
-output: Union[Optional[FakeTensor], Tuple[Optional[FakeTensor], ...]],
+output: Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]],
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object],
kwargs: Mapping[str, object],
) -> None:
@@ -1796,7 +1779,7 @@ class FakeTensorMode(TorchDispatchMode):
def dispatch(
self,
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
@@ -1982,7 +1965,7 @@ class FakeTensorMode(TorchDispatchMode):
def _dispatch_impl(
self,
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object],
kwargs: Mapping[str, object],
) -> Optional[FakeTensor]:
@@ -2436,13 +2419,13 @@ class FakeTensorMode(TorchDispatchMode):
converter: FakeTensorConverter,
flat_args: Sequence[object],
args_spec: TreeSpec,
-) -> Tuple[List[object], List[FakeTensor]]:
+) -> tuple[list[object], list[FakeTensor]]:
"""
Checks if the list of tensors are fake tensors.
If not, try to convert them to fake tensors.
Returns the original args, kwargs, and a flattened list of (args, kwargs) that are fake tensors.
"""
-flat_arg_fake_tensors: List[FakeTensor] = []
+flat_arg_fake_tensors: list[FakeTensor] = []

def validate(x: T) -> Union[T, FakeTensor]:
if not isinstance(x, Tensor):
@@ -2663,7 +2646,7 @@ def run_fallback_kernel(

r = func(*args, **kwargs)

-storages: Set[_StoragePointer] = set()
+storages: set[_StoragePointer] = set()

for e in flat_args:
if isinstance(e, Tensor):
@@ -2703,7 +2686,7 @@ class FakeCopyMode(TorchFunctionMode):
def __torch_function__(
self,
func: OpOverload,
-types: Sequence[Type],
+types: Sequence[type],
args: Sequence[object] = (),
kwargs: Optional[Mapping[str, object]] = None,
) -> FakeTensor:
@@ -2718,7 +2701,7 @@ class FakeCopyMode(TorchFunctionMode):
elif func == Tensor.__deepcopy__:
assert len(args) == 2 and len(kwargs) == 0
tensor = cast(Tensor, args[0])
-memo = cast(Dict[int, FakeTensor], args[1])
+memo = cast(dict[int, FakeTensor], args[1])

if id(tensor) in memo:
return memo[id(tensor)]
@@ -2,7 +2,7 @@

import functools
import warnings
-from typing import Any, Callable, List, Union
+from typing import Any, Callable, Union

import torch
import torch.utils._pytree as pytree
@@ -102,8 +102,8 @@ def is_sdpa_error(func, idx, e):


def try_convert_fake_to_real(
-ten_list: List[Union[FakeTensor, Any]]
-) -> List[Union[FakeTensor, torch.Tensor, Any]]:
+ten_list: list[Union[FakeTensor, Any]]
+) -> list[Union[FakeTensor, torch.Tensor, Any]]:
"""
Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up
the FakeTensorMode meta to real storage mapping. On failure to find the storage mapping, the FakeTensor will
@@ -3,7 +3,7 @@ import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
-from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, ContextManager, Optional, Union

import torch
import torch.utils._pytree as pytree
@@ -308,10 +308,10 @@ class FunctionalTensorMode(TorchDispatchMode):
self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None # type: ignore[attr-defined]
# Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
# track of the ordering between side effectful operations.
-self._tokens: Dict[Any, torch.Tensor] = {}
+self._tokens: dict[Any, torch.Tensor] = {}

# Filled after forward tracing.
-self._tokens_forward_output: Dict[Any, torch.Tensor] = {}
+self._tokens_forward_output: dict[Any, torch.Tensor] = {}

# Functionalization runs twice in AOTAutograd, once in
# `run_functionalized_fw_and_collect_metadata` to collect metadata to
@@ -648,12 +648,12 @@ def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMo

class BaseFunctionalizeAPI(ABC):
@abstractmethod
-def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
+def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
pass

@abstractmethod
def unwrap_tensors(
-self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
+self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]]
) -> Any:
pass

@@ -690,14 +690,14 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
self.mode = mode if mode else FunctionalTensorMode()
self.pre_dispatch = pre_dispatch

-def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
+def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
with self.mode:
return torch.utils._pytree.tree_map_only(
torch.Tensor, FunctionalTensor.to_functional, args
)

def unwrap_tensors(
-self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
+self, args: Union[torch.Tensor, tuple[torch.Tensor, ...], list[torch.Tensor]]
) -> Any:
return torch.utils._pytree.tree_map_only(
FunctionalTensor, FunctionalTensor.from_functional, args
@@ -733,14 +733,14 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
-def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
+def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

return _wrap_all_tensors_to_functional(args, level=0)

def unwrap_tensors(
-self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
-) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]]
+) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
)
@@ -772,14 +772,14 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
def __init__(self, interpreter):
self.interpreter = interpreter

-def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
+def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]:
from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())

def unwrap_tensors(
-self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
-) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]]
+) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
)
@@ -13,15 +13,10 @@ from typing import (
Callable,
ClassVar,
ContextManager,
-    Dict,
Generic,
-    List,
NewType,
Optional,
Protocol,
-    Set,
-    Tuple,
-    Type,
TYPE_CHECKING,
TypeVar,
Union,
@@ -66,7 +61,7 @@ def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]:
return isinstance(t, FakeTensor)


-DimList = List
+DimList = list
_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor)
_T = TypeVar("_T")
_TensorT = TypeVar("_TensorT", bound=torch.Tensor)
@@ -178,7 +173,7 @@ def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
return is_sparse_coo(t) or is_sparse_compressed(t)


-def _checked_cast(ty: Type[_T], obj: object) -> _T:
+def _checked_cast(ty: type[_T], obj: object) -> _T:
assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
return obj

@@ -224,8 +219,8 @@ class MetaTensorDescriber:
# Storage -> int
self.lookup_storage = WeakIdKeyDictionary()
self.copy_data = copy_data
-self.traced_tensors: Set[int] = set()
-self.traced_storages: Set[int] = set()
+self.traced_tensors: set[int] = set()
+self.traced_storages: set[int] = set()

def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
if t not in self.lookup_tensor:
@@ -480,7 +475,7 @@ class MetaStorageDesc:
# serializable in JSON, you want to do something special here anyway
data: Optional[torch.UntypedStorage]

-def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
+def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
return {
"id": self.id,
"describer_id": describer_id,
@@ -579,8 +574,8 @@ class MetaTensorDesc(Generic[_TensorT]):
# throw an error, but we don't currently have any subclasses that do this
# except C++ nested tensor but we're going to have nested int to make this
# defined on NJT
-size: Tuple[int, ...]
-dynamo_dynamic_indices: List[int]
+size: tuple[int, ...]
+dynamo_dynamic_indices: list[int]

layout: torch.layout = torch.strided
is_inference: bool = False
@@ -603,7 +598,7 @@ class MetaTensorDesc(Generic[_TensorT]):
is_conj: bool = False
is_neg: bool = False
is_parameter: bool = False
-stride: Optional[Tuple[int, ...]] = None
+stride: Optional[tuple[int, ...]] = None
storage_offset: int = 0
# NB: We have a choice whether or not to store the id or a direct pointer
# to the data structure. For ease of use, we store the data structure,
@@ -621,13 +616,13 @@ class MetaTensorDesc(Generic[_TensorT]):
unwrapped: Optional[MetaTensorDesc] = None # is_functorch_wrapped
bdim: Optional[int] = None # is_functorch_wrapped
base: Optional[MetaTensorDesc] = None # is_view
-attrs: Optional[Dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass
+attrs: Optional[dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass
creation_meta: Optional[CreationMeta] = None
grad: Optional[MetaTensorDesc] = None

# Everything below is NOT serializable, need some more work

-_UNSERIALIZABLE: ClassVar[Set[str]] = {
+_UNSERIALIZABLE: ClassVar[set[str]] = {
"ctx",
"type",
"fake_mode",
@@ -642,14 +637,14 @@ class MetaTensorDesc(Generic[_TensorT]):
}

ctx: Optional[object] = None # is_traceable_wrapper_subclass
-type: Optional[Type] = None # is_traceable_wrapper_subclass
+type: Optional[type] = None # is_traceable_wrapper_subclass
fake_mode: Optional[FakeTensorMode] = None
view_func: Optional[ViewFunc] = None
# level looks serializable, but actually it is meaningless without
# the functorch_stack below
level: Optional[int] = None # is_functorch_wrapped
current_level: Optional[int] = None
-functorch_stack: Optional[List[CInterpreter]] = None
+functorch_stack: Optional[list[CInterpreter]] = None
autograd_meta_from: Optional[torch.Tensor] = None

# This is only populated on copy_data, and typically is not used at all,
@@ -669,7 +664,7 @@ class MetaTensorDesc(Generic[_TensorT]):

# NB: This will reference numeric IDs, and it is assumed that you've
# already serialized everything this recursively references
-def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
+def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
def json(k: str, v: object) -> object:
# Some best-effort debugging serialization for unserializable
# fields (feel free to add other special cases as appropriate)
@@ -706,7 +701,7 @@ class MetaTensorDesc(Generic[_TensorT]):
return r

@property
-def shape(self) -> Tuple[int, ...]:
+def shape(self) -> tuple[int, ...]:
return self.size


@@ -893,7 +888,7 @@ class MetaConverter(Generic[_TensorT]):
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
-) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
+) -> tuple[tuple[int, ...], tuple[int, ...], int]:
assert t.stride is not None
if shape_env is not None:
fake_mode = t.fake_mode
@@ -948,8 +943,8 @@ class MetaConverter(Generic[_TensorT]):
# symbolic context.
def empty_create_subclass(
t: MetaTensorDesc,
-outer_size: Tuple[int, ...],
-outer_stride: Tuple[int, ...],
+outer_size: tuple[int, ...],
+outer_stride: tuple[int, ...],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
@@ -981,8 +976,8 @@ class MetaConverter(Generic[_TensorT]):

def _empty_create_subclass(
t: MetaTensorDesc,
-outer_size: Optional[Tuple[int, ...]],
-outer_stride: Optional[Tuple[int, ...]],
+outer_size: Optional[tuple[int, ...]],
+outer_stride: Optional[tuple[int, ...]],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
@@ -1028,7 +1023,7 @@ class MetaConverter(Generic[_TensorT]):
inner_tensors[attr] = new_empty_tensor

assert t.type is not None
-return t.type.__tensor_unflatten__(
+return t.type.__tensor_unflatten__( # type: ignore[attr-defined]
inner_tensors, t.ctx, outer_size, outer_stride
)

@@ -1081,7 +1076,7 @@ class MetaConverter(Generic[_TensorT]):
t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
if t.is_traceable_wrapper_subclass:
assert t.attrs is not None
-inner_contexts: Dict[
+inner_contexts: dict[
str, torch.fx.experimental.symbolic_shapes.SymbolicContext
] = {}
for attr, inner in t.attrs.items():
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import hashlib
import json
-from typing import Dict, Tuple

import coremltools as ct # type: ignore[import]
from coremltools.converters.mil.input_types import TensorType # type: ignore[import]
@@ -83,7 +82,7 @@ def _convert_to_mil_type(shape, dtype, name: str):
return ml_type


-def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
+def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tuple]):
spec = compile_spec["forward"]
(
input_specs,
@@ -1,6 +1,6 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
-from typing import List, Optional
+from typing import Optional

import torch
from torch.backends._nnapi.serializer import _NnapiSerializer
@@ -21,16 +21,16 @@ class NnapiModule(torch.nn.Module):

# _nnapi.Compilation is defined
comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
-weights: List[torch.Tensor]
-out_templates: List[torch.Tensor]
+weights: list[torch.Tensor]
+out_templates: list[torch.Tensor]

def __init__(
self,
shape_compute_module: torch.nn.Module,
ser_model: torch.Tensor,
-weights: List[torch.Tensor],
-inp_mem_fmts: List[int],
-out_mem_fmts: List[int],
+weights: list[torch.Tensor],
+inp_mem_fmts: list[int],
+out_mem_fmts: list[int],
compilation_preference: int,
relax_f32_to_f16: bool,
):
@@ -46,7 +46,7 @@ class NnapiModule(torch.nn.Module):
self.relax_f32_to_f16 = relax_f32_to_f16

@torch.jit.export
-def init(self, args: List[torch.Tensor]):
+def init(self, args: list[torch.Tensor]):
assert self.comp is None
self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
self.weights = [w.contiguous() for w in self.weights]
@@ -60,7 +60,7 @@ class NnapiModule(torch.nn.Module):

self.comp = comp

-def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
+def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]:
if self.comp is None:
self.init(args)
comp = self.comp
@@ -6,7 +6,7 @@ import logging
import operator
import struct
import sys
-from typing import List, NamedTuple, Optional, Tuple
+from typing import NamedTuple, Optional

import torch

@@ -210,7 +210,7 @@ class Operand(NamedTuple):
# This is always the PyTorch shape, which is NCHW for feature maps.
# The actual NNAPI operand might have a transposed shape.
# we use 0 for load time dynamic shapes & -1 for runtime dynamic shapes
-shape: Tuple[int, ...]
+shape: tuple[int, ...]

# Specifies how the shape of the operand that we define in NNAPI
# relates to the shape we track above.
@@ -943,8 +943,8 @@ class _NnapiSerializer:
assert node.outputsSize() == 1
output = node.outputsAt(0)
ctype = output.type()
-const_vals: Optional[List] = []
-tensors: Optional[List] = []
+const_vals: Optional[list] = []
+tensors: Optional[list] = []
for inp in node.inputs():
if const_vals is not None and inp in self.constants:
_, val = self.get_constant_value(inp)
@@ -39,7 +39,7 @@ class _QEngineProp:


class _SupportedQEnginesProp:
-def __get__(self, obj, objtype) -> List[str]:
+def __get__(self, obj, objtype) -> list[str]:
qengines = torch._C._supported_qengines()
return [_get_qengine_str(qe) for qe in qengines]

@@ -63,4 +63,4 @@ class QuantizedEngine(types.ModuleType):
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
engine: str
-supported_engines: List[str]
+supported_engines: list[str]
@@ -132,7 +132,6 @@ import subprocess
import sys
from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER
from os.path import expanduser
-from typing import Dict, List

from torch.distributed.elastic.multiprocessing import (
DefaultLogsSpecs as _DefaultLogsSpecs,
@@ -181,8 +180,8 @@ class _CPUinfo:
# physical cores := core column in lscpu output
# logical cores := cPU column in lscpu output
self.node_nums = int(max(line[3] for line in self.cpuinfo)) + 1
-self.node_physical_cores: List[List[int]] = [] # node_id is index
-self.node_logical_cores: List[List[int]] = [] # node_id is index
+self.node_physical_cores: list[list[int]] = [] # node_id is index
+self.node_logical_cores: list[list[int]] = [] # node_id is index
self.physical_core_node_map = {} # physical core to numa node id
self.logical_core_node_map = {} # logical core to numa node id

@@ -594,7 +593,7 @@ won't take effect even if it is set explicitly."
)
entrypoint = ""
launch_args = {}
-launch_envs: Dict[int, Dict] = {}
+launch_envs: dict[int, dict] = {}
launch_tee = {}
# check whether is launched from torchrun with --nproc-per-node <num workers>
local_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
@@ -623,7 +622,7 @@ won't take effect even if it is set explicitly."
* args.ncores_per_instance
]

-core_ranges: List[Dict] = []
+core_ranges: list[dict] = []
if local_size > 1:
total_num_cores = len(core_list)
cores_per_rank = total_num_cores // local_size
@@ -432,7 +432,7 @@ def is_exporting() -> bool:
return _is_exporting_flag


-def save_cache_artifacts() -> Optional[Tuple[bytes, "CacheInfo"]]:
+def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
"""
Serializes all the cache artifacts that were created during the compilation
@@ -3,7 +3,7 @@ import logging
import os
import pickle
from enum import Enum
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union

from torch._inductor.remote_cache import JsonDataTy, RemoteCacheJsonSerde
from torch._inductor.runtime.runtime_utils import cache_dir
@@ -41,10 +41,10 @@ class CacheInfo:
instrumentation
"""

-inductor_artifacts: List[str] = dataclasses.field(default_factory=list)
-autotune_artifacts: List[str] = dataclasses.field(default_factory=list)
-aot_autograd_artifacts: List[str] = dataclasses.field(default_factory=list)
-pgo_artifacts: List[str] = dataclasses.field(default_factory=list)
+inductor_artifacts: list[str] = dataclasses.field(default_factory=list)
+autotune_artifacts: list[str] = dataclasses.field(default_factory=list)
+aot_autograd_artifacts: list[str] = dataclasses.field(default_factory=list)
+pgo_artifacts: list[str] = dataclasses.field(default_factory=list)

def add(self, artifact: CacheArtifact) -> None:
if artifact.type == CacheArtifactType.INDUCTOR:
@@ -77,7 +77,7 @@ class CacheArtifactManager:
"""

# Protected by the compile_lock
-_cache_artifacts: List[CacheArtifact] = []
+_cache_artifacts: list[CacheArtifact] = []

@classmethod
def clear(cls) -> None:
@@ -102,7 +102,7 @@ class CacheArtifactManager:
cls._cache_artifacts.append(CacheArtifact(artifact_type, key, content))

@classmethod
-def serialize(cls) -> Optional[Tuple[bytes, CacheInfo]]:
+def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
"""
Converts the "mega" list into portable format
"""
@@ -45,8 +45,8 @@ except ImportError:
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
-_queued_calls: List[
-    Tuple[Callable[[], None], List[str]]
+_queued_calls: list[
+    tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
@@ -101,7 +101,7 @@ else:
has_half: bool = True
has_magma: bool = torch._C._has_magma

-default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
+default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment]


def _is_compiled() -> bool:
@@ -492,7 +492,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str:
return get_device_properties(device).name


-def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
+def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Get the cuda capability of a device.

Args:
@@ -642,7 +642,7 @@ def set_stream(stream: Stream):
)


-def _parse_visible_devices() -> Union[List[int], List[str]]:
+def _parse_visible_devices() -> Union[list[int], list[str]]:
r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
var = os.getenv("CUDA_VISIBLE_DEVICES")

@@ -683,12 +683,12 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
idx += 1
return int(s[:idx]) if idx > 0 else -1

-def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
-    rcs: List[str] = []
+def parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
+    rcs: list[str] = []
for elem in lst.split(","):
# Repeated id results in empty set
if elem in rcs:
-return cast(List[str], [])
+return cast(list[str], [])
# Anything other but prefix is ignored
if not elem.startswith(prefix):
break
@@ -701,12 +701,12 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
return parse_list_with_prefix(var, "MIG-")
# CUDA_VISIBLE_DEVICES uses something like strtoul
# which makes `1gpu2,2ampere` is equivalent to `1,2`
-rc: List[int] = []
+rc: list[int] = []
for elem in var.split(","):
x = _strtoul(elem.strip())
# Repeated ordinal results in empty set
if x in rc:
-return cast(List[int], [])
+return cast(list[int], [])
# Negative value aborts the sequence
if x < 0:
break
@@ -744,7 +744,7 @@ def _raw_device_count_nvml() -> int:
return dev_count.value


-def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
+def _raw_device_uuid_amdsmi() -> Optional[list[str]]:
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer

if not _HAS_PYNVML: # If amdsmi is not available
@@ -760,7 +760,7 @@ def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
except amdsmi.AmdSmiException:
warnings.warn("Can't get amdsmi device count")
return None
-uuids: List[str] = []
+uuids: list[str] = []
for idx in range(dev_count):
try:
handler = amdsmi.amdsmi_get_processor_handles()[idx]
@@ -780,7 +780,7 @@ def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
return uuids


-def _raw_device_uuid_nvml() -> Optional[List[str]]:
+def _raw_device_uuid_nvml() -> Optional[list[str]]:
r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer

@@ -794,7 +794,7 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]:
if rc != 0:
warnings.warn("Can't get nvml device count")
return None
-uuids: List[str] = []
+uuids: list[str] = []
for idx in range(dev_count.value):
dev_id = c_void_p()
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
@@ -812,10 +812,10 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]:
return uuids


-def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
+def _transform_uuid_to_ordinals(candidates: list[str], uuids: list[str]) -> list[int]:
r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""

-def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
+def uuid_to_ordinal(candidate: str, uuids: list[str]) -> int:
best_match = -1
for idx, uuid in enumerate(uuids):
if not uuid.startswith(candidate):
@@ -826,7 +826,7 @@ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List
best_match = idx
return best_match

-rc: List[int] = []
+rc: list[int] = []
for candidate in candidates:
if torch.version.hip:
candidate = candidate.replace(
@@ -838,7 +838,7 @@ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List
break
# Duplicates result in empty set
if idx in rc:
-return cast(List[int], [])
+return cast(list[int], [])
rc.append(idx)
return rc

@@ -853,7 +853,7 @@ def _device_count_amdsmi() -> int:
if uuids is None:
return -1
# Create string version of visible devices to avoid mypy warnings
-visible_device_str = cast(List[str], visible_devices)
+visible_device_str = cast(list[str], visible_devices)
visible_devices = _transform_uuid_to_ordinals(visible_device_str, uuids)
else:
raw_cnt = _raw_device_count_amdsmi()
@@ -887,7 +887,7 @@ def _device_count_nvml() -> int:
if uuids is None:
return -1
visible_devices = _transform_uuid_to_ordinals(
-cast(List[str], visible_devices), uuids
+cast(list[str], visible_devices), uuids
)
else:
raw_cnt = _raw_device_count_nvml()
@@ -913,9 +913,9 @@ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
if uuids is None:
raise RuntimeError("Can't get device UUIDs")
visible_devices = _transform_uuid_to_ordinals(
-cast(List[str], visible_devices), uuids
+cast(list[str], visible_devices), uuids
)
-visible_devices = cast(List[int], visible_devices)
+visible_devices = cast(list[int], visible_devices)
if idx < 0 or idx >= len(visible_devices):
raise RuntimeError(
f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
@@ -944,7 +944,7 @@ def device_count() -> int:
return r


-def get_arch_list() -> List[str]:
+def get_arch_list() -> list[str]:
r"""Return list CUDA architectures this library was compiled for."""
if not is_available():
return []
@@ -1145,10 +1145,10 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
if uuids is None:
raise RuntimeError("Can't get device UUIDs")
visible_devices_str = cast(
-List[str], visible_devices
+list[str], visible_devices
) # Create str variable for mypy
visible_devices = _transform_uuid_to_ordinals(visible_devices_str, uuids)
-idx_map = dict(enumerate(cast(List[int], visible_devices)))
+idx_map = dict(enumerate(cast(list[int], visible_devices)))
if idx not in idx_map:
raise RuntimeError(
f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})"
@@ -20,8 +20,9 @@ import re
import sys
import textwrap
import traceback
+from collections.abc import Iterator
from dataclasses import dataclass, field
-from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import Any, Optional, TypeVar

import torch
import torch.cuda._gpu_trace as gpu_trace
@@ -74,7 +75,7 @@ class Access:
seq_num: SeqNum
stream: StreamId
operator: str
-aliases: List[str]
+aliases: list[str]
is_output: bool
stack_trace: traceback.StackSummary

@@ -141,7 +142,7 @@ class UnsynchronizedAccessError(SynchronizationError):
class CUDASanitizerErrors(Exception):
"""Wrapper class for errors reported by CUDA Sanitizer."""

-def __init__(self, errors: List[SynchronizationError]):
+def __init__(self, errors: list[SynchronizationError]):
self.errors = errors

def __str__(self):
@@ -161,13 +162,13 @@ class TensorInfo:
"""

allocation_stack_trace: Optional[traceback.StackSummary]
-reads: List[Access] = field(default_factory=list)
+reads: list[Access] = field(default_factory=list)
write: Optional[Access] = None


class _TensorsAccessed:
def __init__(self) -> None:
-self.accesses: Dict[DataPtr, TensorInfo] = {}
+self.accesses: dict[DataPtr, TensorInfo] = {}

def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
if data_ptr not in self.accesses:
@@ -209,7 +210,7 @@ class _TensorsAccessed:
def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
return self.accesses[data_ptr].write

-def get_reads(self, data_ptr: DataPtr) -> List[Access]:
+def get_reads(self, data_ptr: DataPtr) -> list[Access]:
return self.accesses[data_ptr].reads

def add_read(self, data_ptr: DataPtr, access: Access) -> None:
@@ -222,9 +223,9 @@ class _TensorsAccessed:

class StreamSynchronizations:
def __init__(self) -> None:
-self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
-self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
-self.host_sync_state: Dict[StreamId, SeqNum] = {}
+self.current_sync_states: dict[StreamId, dict[StreamId, SeqNum]] = {}
+self.recorded_sync_states: dict[EventId, dict[StreamId, SeqNum]] = {}
+self.host_sync_state: dict[StreamId, SeqNum] = {}
self.create_stream(DEFAULT_STREAM_ID)

def _ensure_stream_exists(self, stream: StreamId) -> None:
@@ -288,7 +289,7 @@ class StreamSynchronizations:
self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

def _state_wait_for_other(
-self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
+self, state: dict[StreamId, SeqNum], other: dict[StreamId, SeqNum]
) -> None:
for stream, seq_num in other.items():
state[stream] = max(state.get(stream, -1), seq_num)
@@ -349,12 +350,12 @@ class EventHandler:
def _handle_kernel_launch(
self,
stream: StreamId,
-read_only: Set[DataPtr],
-read_write: Set[DataPtr],
-outputs: Set[DataPtr],
+read_only: set[DataPtr],
+read_write: set[DataPtr],
+outputs: set[DataPtr],
operator: str,
-tensor_aliases: Dict[int, List[str]],
-) -> List[SynchronizationError]:
+tensor_aliases: dict[int, list[str]],
+) -> list[SynchronizationError]:
def check_conflict(
data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
) -> None:
@@ -372,7 +373,7 @@ class EventHandler:
)
)

-error_list: List[SynchronizationError] = []
+error_list: list[SynchronizationError] = []
self.seq_num += 1
self.syncs.update_seq_num(stream, self.seq_num)
stack_trace = traceback.StackSummary.extract(
@@ -462,15 +463,15 @@ class EventHandler:
self.syncs.all_streams_wait_for_event(event)


-def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
+def zip_by_key(a: dict[TK, TVa], b: dict[TK, TVb]) -> Iterator[tuple[TK, TVa, TVb]]:
for arg, value in a.items():
if arg in b:
yield arg, value, b[arg]


def zip_arguments(
-schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
-) -> Iterator[Tuple[torch.Argument, Any]]:
+schema: torch.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> Iterator[tuple[torch.Argument, Any]]:
schema_args = schema.arguments[: len(args)]
schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

@@ -482,10 +483,10 @@ def zip_arguments(

class ArgumentHandler:
def __init__(self) -> None:
-self.dataptrs_read: Set[DataPtr] = set()
-self.dataptrs_written: Set[DataPtr] = set()
-self.tensor_aliases: Dict[DataPtr, List[str]] = {}
-self.outputs: Set[DataPtr] = set()
+self.dataptrs_read: set[DataPtr] = set()
+self.dataptrs_written: set[DataPtr] = set()
+self.tensor_aliases: dict[DataPtr, list[str]] = {}
+self.outputs: set[DataPtr] = set()

def _handle_argument(
self,
@@ -511,8 +512,8 @@ class ArgumentHandler:
def parse_inputs(
self,
schema: torch.FunctionSchema,
-args: Tuple[Any, ...],
-kwargs: Dict[str, Any],
+args: tuple[Any, ...],
+kwargs: dict[str, Any],
*,
is_factory: bool,
) -> None:
@@ -1,12 +1,12 @@
import os
import sys
-from typing import Callable, List, Optional
+from typing import Callable, Optional

import torch
from torch.types import Storage


-__all__: List[str] = []
+__all__: list[str] = []


def _dummy_fn(name: str) -> Callable:
@@ -1,12 +1,12 @@
# mypy: allow-untyped-defs
import re
-from typing import Callable, List
+from typing import Callable

import torch
from torch import Tensor


-__all__: List[str] = []
+__all__: list[str] = []


class _CodeParser:
@@ -8,7 +8,7 @@ import pickle
import sys
import warnings
from inspect import signature
-from typing import Any, Dict, Literal, Optional, Tuple, Union
+from typing import Any, Literal, Optional, Union
from typing_extensions import deprecated

import torch
@@ -218,7 +218,7 @@ def empty_cache() -> None:
torch._C._cuda_emptyCache()


-def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
+def memory_stats(device: Union[Device, int] = None) -> dict[str, Any]:
r"""Return a dictionary of CUDA memory allocator statistics for a given device.

The return value of this function is a dictionary of statistics, each of
@@ -323,7 +323,7 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
return collections.OrderedDict(result)


-def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
+def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> dict[str, Any]:
r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
if not is_initialized():
return {}
@@ -719,7 +719,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str:
return "\n".join(lines)


-def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
+def mem_get_info(device: Union[Device, int] = None) -> tuple[int, int]:
r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.

Args:
@@ -1035,7 +1035,7 @@ class MemPool(_MemPool):
super().__init__(allocator, True)

@property
-def id(self) -> Tuple[int, int]:
+def id(self) -> tuple[int, int]:
r"""Returns the ID of this pool as a tuple of two ints."""
return super().id
@@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import collections
import warnings
-from typing import Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Optional, Union

import torch.cuda
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
-from typing import Iterable, List, Union
+from collections.abc import Iterable
+from typing import Union

import torch
from torch import Tensor
@@ -42,7 +43,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
return default_generator.get_state()


-def get_rng_state_all() -> List[Tensor]:
+def get_rng_state_all() -> list[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = [get_rng_state(i) for i in range(device_count())]
return results
@@ -118,7 +118,7 @@ import multiprocessing as mp
import os
import shutil
import warnings
-from typing import Optional, Tuple
+from typing import Optional

import torch

@@ -228,12 +228,12 @@ def get_filename() -> str:
return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined]


-def get_results() -> Tuple[str, str, str, float]:
+def get_results() -> tuple[str, str, str, float]:
r"""Return all TunableOp results."""
return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined]


-def get_validators() -> Tuple[str, str]:
+def get_validators() -> tuple[str, str]:
r"""Return the TunableOp validators."""
return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined]
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import warnings
-from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec

import torch
@@ -14,14 +14,14 @@ from torch.masked.maskedtensor.creation import as_masked_tensor
if TYPE_CHECKING:
from torch.types import _dtype as DType

-    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
+    DimOrDims = Optional[Union[int, tuple[int], list[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
-    DimOrDims = Optional[Tuple[int]]
+    DimOrDims = Optional[tuple[int]]


-__all__: List[str] = []
+__all__: list[str] = []

_T = TypeVar("_T")
_P = ParamSpec("_P")
@@ -291,7 +291,7 @@ defined as ``prod(x[:i])``.""",
example_dim = 1
example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
example_mask = torch.tensor([[True, False, True], [False, False, False]])
-example_args: Tuple[Any, ...]
+example_args: tuple[Any, ...]
if func.__name__ in {"norm", "normalize"}:
example_args = (2.0, example_dim)
example_input = example_input.to(dtype=torch.float32)
@@ -303,8 +303,8 @@ defined as ``prod(x[:i])``.""",
else:
example_args = (example_dim,)

-operation_args: Tuple[str, ...]
-operation_kwargs: Tuple[str, ...]
+operation_args: tuple[str, ...]
+operation_kwargs: tuple[str, ...]
operation_args, operation_kwargs = args_and_kwargs[func.__name__]
arg_declarations = [
"\n ".join(
@@ -461,9 +461,9 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
raise NotImplementedError(f"identity of {op_name} on {dtype} input")


-def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
+def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
"""Return dim argument as a tuple of sorted dim values."""
-dims: List[int] = []
+dims: list[int] = []
if dim == ():
# Currently, `dim=()` in reductions operations means "reduce
# over all dimensions" while in future, it will read "no
@@ -618,7 +618,7 @@ def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor
def _sparse_coo_scatter_reduction_helper(
op,
mask_input: Tensor,
-dims: Tuple[int, ...],
+dims: tuple[int, ...],
keepdim: bool,
dtype: Optional[DType] = None,
) -> Tensor:
@@ -738,7 +738,7 @@ def _sparse_coo_scatter_reduction_helper(
def _sparse_csr_segment_reduction_helper(
op,
mask_input: Tensor,
-dims: Tuple[int, ...],
+dims: tuple[int, ...],
keepdim: bool,
dtype: Optional[DType] = None,
) -> Tensor:
@@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

from functools import partial
-from typing import Any, Callable, Dict, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING

import torch

@@ -228,7 +228,7 @@ def _function_to_sparse_csr(func, *args, **kwargs):
return _MaskedToSparseCsr.apply(args[0])


-_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {}
+_MASKEDTENSOR_DISPATCH_TABLE: dict["OpOverload", Callable[..., Any]] = {}


def register_dispatch_func(aten_ops):
@@ -22,8 +22,8 @@ Event = torch.Event
Stream = torch.Stream

_initialized = False
-_queued_calls: List[
-    Tuple[Callable[[], None], List[str]]
+_queued_calls: list[
+    tuple[Callable[[], None], list[str]]
] = [] # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
@@ -170,13 +170,13 @@ def record_memory_history(
torch._C._mtia_recordMemoryHistory(enabled, stacks, max_entries)


-def snapshot() -> Dict[str, Any]:
+def snapshot() -> dict[str, Any]:
r"""Return a dictionary of MTIA memory allocator history"""

return torch._C._mtia_memorySnapshot()


-def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
+def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Return capability of a given device as a tuple of (major version, minor version).

Args:
@@ -2,7 +2,7 @@

r"""This package adds support for device memory management implemented in MTIA."""

-from typing import Any, Dict, Optional
+from typing import Any, Optional

import torch

@@ -10,7 +10,7 @@ from . import _device_t, is_initialized
from ._utils import _get_device_index


-def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
+def memory_stats(device: Optional[_device_t] = None) -> dict[str, Any]:
r"""Return a dictionary of MTIA memory allocator statistics for a given device.

Args:
@ -25,7 +25,7 @@ torch.serialization.add_safe_globals([_NestedTensor, _rebuild_njt])


def as_nested_tensor(
ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
ts: Union[Tensor, list[Tensor], tuple[Tensor, ...]],
dtype: Optional[DType] = None,
device: Optional[Device] = None,
layout=None,
@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from typing import * # noqa: F403
from typing import Tuple

import torch
from torch._C import DispatchKey, DispatchKeySet
@ -69,8 +68,8 @@ class NestedTensor(torch.Tensor):
# We also use nested ints to represent the strides of this tensor.
# For example, a jagged tensor with shape [B, x, D] can be strided in two
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
_size: Tuple[int, ...]
_strides: Tuple[int, ...]
_size: tuple[int, ...]
_strides: tuple[int, ...]
# Indicates that the nth dimension is ragged
_ragged_idx: int
_metadata_cache: Dict[str, Any]
@ -417,7 +416,7 @@ def jagged_from_list(
offsets: Optional[torch.Tensor],
dtype=None,
device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
) -> tuple[NestedTensor, torch.Tensor]:
"""Constructs a NestedTensor backed by jagged layout from a list of tensors"""

if len(tensors) == 0:
@ -500,7 +499,7 @@ def jagged_from_list(

def jagged_from_tensor_and_lengths(
tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
"""Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
batch_size = tensor.shape[0]
if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
@ -3,7 +3,7 @@ import functools
import math
import operator
from typing import * # noqa: F403
from typing import List, Optional
from typing import Optional

import torch
import torch.nn.functional as F
@ -13,7 +13,7 @@ from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
from .nested_tensor import NestedTensor


__all__: List[Any] = []
__all__: list[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}

@ -973,7 +973,7 @@ def unbind_int(func, *args, **kwargs):
lengths = inp.lengths()
ragged_idx = inp._ragged_idx

def _torch_check(_lengths: List[int], _offsets: Optional[List[int]] = None):
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
# This torch._check and torch._check_is_size are needed for torch.compile
# symbolic shapes processing.
# offsets and lengths are symbolic variables during compilation,
@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple
from typing import Optional

import torch
import torch.nn
@ -302,7 +302,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
return SDPBackend.ERROR


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
# This function is used to calculate two pieces of metadata that are needed
# for use with flash-attention and efficient_attention kernels. They are the
# cumulative sequence_length over a batch of sequences and the maximum
@ -634,7 +634,7 @@ def _autocast(
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
[Autocasting SDPA for NJT]