Compare commits

...

1 commit

SHA1: 3118a33aa6
Message: Use Python 3.10 typing
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
Date: 2025-11-10 17:59:01 +08:00

9 changed files with 79 additions and 84 deletions
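
The change is mechanical throughout: PEP 604 union syntax, available since Python 3.10, replaces typing.Optional and typing.Union. A standalone before/after sketch of the pattern (hypothetical functions, not code from the diff):

    # Before: typing-module spellings (work on any supported Python 3).
    from typing import Optional, Union

    def f_old(x: Union[int, str], y: Optional[float] = None) -> Optional[str]:
        return None if y is None else f"{x}:{y}"

    # After: PEP 604 unions; no typing import needed, but the annotations
    # only evaluate on Python 3.10+ (or under postponed evaluation).
    def f_new(x: int | str, y: float | None = None) -> str | None:
        return None if y is None else f"{x}:{y}"

    print(f_new(1), f_new(1, 2.0))  # None 1:2.0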

View File

@ -1,5 +1,3 @@
from typing import Optional
import torch
from torch.export import ExportedProgram
@ -10,7 +8,7 @@ class LoweredBackendModule(torch.nn.Module):
original_exported_program: ExportedProgram,
backend_id: str,
*,
module_name: Optional[str] = None,
module_name: str | None = None,
) -> None:
super().__init__()
self._backend_id = backend_id
@ -22,7 +20,7 @@ class LoweredBackendModule(torch.nn.Module):
return self._backend_id
@property
def module_name(self) -> Optional[str]:
def module_name(self) -> str | None:
return self._module_name
@property

View File

@ -25,9 +25,9 @@ torch.serialization.add_safe_globals([_NestedTensor, _rebuild_njt])
def as_nested_tensor(
ts: Union[Tensor, list[Tensor], tuple[Tensor, ...]],
dtype: Optional[DType] = None,
device: Optional[Device] = None,
ts: Tensor | list[Tensor] | tuple[Tensor, ...],
dtype: DType | None = None,
device: Device | None = None,
layout=None,
) -> Tensor:
r"""
@ -281,8 +281,8 @@ def nested_tensor(
def narrow(
tensor: Tensor,
dim: int,
start: Union[int, Tensor],
length: Union[int, Tensor],
start: int | Tensor,
length: int | Tensor,
layout=torch.strided,
) -> Tensor:
r"""
@ -358,11 +358,11 @@ def narrow(
def nested_tensor_from_jagged(
values: Tensor,
offsets: Optional[Tensor] = None,
lengths: Optional[Tensor] = None,
jagged_dim: Optional[int] = None,
min_seqlen: Optional[int] = None,
max_seqlen: Optional[int] = None,
offsets: Tensor | None = None,
lengths: Tensor | None = None,
jagged_dim: int | None = None,
min_seqlen: int | None = None,
max_seqlen: int | None = None,
) -> Tensor:
r"""
Constructs a jagged layout nested tensor from the given jagged components. The jagged layout

View File

@ -3,7 +3,6 @@ import functools
import math
import operator
from typing import * # noqa: F403
from typing import Optional
import torch
import torch.nn.functional as F
@ -249,7 +248,7 @@ def register_func(tables, aten_ops, schema_str):
register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
def lookup_jagged(func, *args, **kwargs) -> Callable | None:
dispatch_func = JAGGED_OPS_TABLE.get(func, None)
if dispatch_func is not None:
return dispatch_func
@ -1138,7 +1137,7 @@ def unbind_int(func, *args, **kwargs):
lengths = inp.lengths()
ragged_idx = inp._ragged_idx
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None) -> None:
def _torch_check(_lengths: list[int], _offsets: list[int] | None = None) -> None:
# This torch._check are needed for torch.compile
# symbolic shapes processing.
# offsets and lengths are symbolic variables during compilation,
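
The lookup_jagged signature above now returns Callable | None from a table lookup. A self-contained sketch of that shape, with a hypothetical table and handler (not the jagged ops table itself):

    # Hypothetical dispatch table mirroring the Callable | None lookup above.
    from collections.abc import Callable

    OPS_TABLE: dict[str, Callable[..., int]] = {"add": lambda a, b: a + b}

    def lookup(name: str) -> Callable[..., int] | None:
        # Return the registered handler, or None when nothing is registered.
        return OPS_TABLE.get(name)

    handler = lookup("add")
    print(handler(1, 2) if handler is not None else "no handler")  # 3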

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional
import torch
import torch.nn
@ -27,7 +26,7 @@ def _validate_sdpa_input(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attn_mask: torch.Tensor | None = None,
dropout_p=0.0,
is_causal=False,
scale=None,
@ -668,8 +667,8 @@ def _autocast(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
attn_mask: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""
[Autocasting SDPA for NJT]
@ -714,7 +713,7 @@ def jagged_scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attn_mask: torch.Tensor | None = None,
dropout_p=0.0,
is_causal=False,
scale=None,
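
The hunks above annotate attn_mask as torch.Tensor | None along the nested-tensor SDPA path. As a quick illustration of calling with the mask omitted, here is the public dense entry point (not the jagged kernel defined in this file):

    # Illustrative call, not taken from the diff.
    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
    print(out.shape)  # torch.Size([2, 4, 8, 16])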

View File

@ -7,7 +7,7 @@ from dataclasses import asdict, dataclass
from enum import Enum
from functools import wraps
from logging import getLogger
from typing import Optional, ParamSpec, TypeVar
from typing import ParamSpec, TypeVar
import torch
from torch._utils_internal import signpost_event
@ -53,7 +53,7 @@ def maybe_wrap_command_args_with_numa_binding(
command_args: tuple[str, ...],
*,
gpu_index: int,
numa_options: Optional[NumaOptions],
numa_options: NumaOptions | None,
) -> tuple[str, ...]:
"""
Wraps command arguments with numactl to apply NUMA CPU binding.
@ -115,7 +115,7 @@ def maybe_wrap_with_numa_binding(
func: Callable[_TParams, _TReturn],
*,
gpu_index: int,
numa_options: Optional[NumaOptions],
numa_options: NumaOptions | None,
) -> Callable[_TParams, _TReturn]:
"""
Wraps a function to apply NUMA CPU binding before execution.

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from collections.abc import Callable, Iterable
from math import sqrt
from typing import Optional, TypeVar
from typing import TypeVar
import torch
from torch import Tensor
@ -133,12 +133,12 @@ Examples::
def exponential(
M: int,
*,
center: Optional[float] = None,
center: float | None = None,
tau: float = 1.0,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -220,9 +220,9 @@ def cosine(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -294,9 +294,9 @@ def gaussian(
*,
std: float = 1.0,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -373,9 +373,9 @@ def kaiser(
*,
beta: float = 12.0,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -465,9 +465,9 @@ def hamming(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
return general_hamming(
@ -519,9 +519,9 @@ def hann(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
return general_hamming(
@ -573,9 +573,9 @@ def blackman(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -634,9 +634,9 @@ def bartlett(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -710,9 +710,9 @@ def general_cosine(
*,
a: Iterable,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
if dtype is None:
@ -803,9 +803,9 @@ def general_hamming(
*,
alpha: float = 0.54,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
return general_cosine(
@ -867,9 +867,9 @@ def nuttall(
M: int,
*,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
device: torch.device | None = None,
requires_grad: bool = False,
) -> Tensor:
return general_cosine(
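
Each window function above pairs its dtype: torch.dtype | None = None parameter with an "if dtype is None" fallback. A minimal sketch of that idiom with a hypothetical window function (the real ones compute actual window shapes):

    # Hypothetical window-style function, not from the diff.
    import torch

    def flat_window(
        M: int,
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        # Fall back to the global default dtype when the caller passes None.
        if dtype is None:
            dtype = torch.get_default_dtype()
        return torch.ones(M, dtype=dtype, device=device)

    print(flat_window(4))  # tensor([1., 1., 1., 1.])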

View File

@ -19,7 +19,7 @@ from .semi_structured import (
if TYPE_CHECKING:
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]]
DimOrDims = Optional[int | tuple[int, ...] | list[int]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
@ -198,7 +198,7 @@ Examples::
)
def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
def sum(input: Tensor, dim: DimOrDims = None, dtype: DType | None = None) -> Tensor:
r"""Return the sum of each row of the given sparse tensor.
Returns the sum of each row of the sparse tensor :attr:`input` in the given
@ -521,7 +521,7 @@ class check_sparse_tensor_invariants:
# context manager support
def __init__(self, enable=True):
self.state = enable
self.saved_state: Optional[bool] = None
self.saved_state: bool | None = None
def __enter__(self):
if self.saved_state is not None:
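
The first hunk of this file keeps the outer Optional on DimOrDims and only switches the inner Union to | syntax, inside an if TYPE_CHECKING: block whose runtime branch exists for the JIT. A structure-only sketch of that pattern; the alias bodies below are simplified stand-ins, not the module's real definitions:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen only by static type checkers, never evaluated at runtime,
        # so new-style unions are safe here regardless of the JIT.
        DimOrDims = Optional[int | tuple[int, ...] | list[int]]
    else:
        # Runtime branch kept simple: the JIT doesn't understand Union here.
        DimOrDims = Optional[int]

    print(DimOrDims)  # typing.Optional[int] at runtime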

View File

@ -4,7 +4,6 @@ import math
import os
import weakref
from functools import lru_cache
from typing import Optional
import torch
from torch._dynamo.utils import warn_once
@ -1123,12 +1122,12 @@ def _int_bsr_dense_addmm(
*,
beta=1,
alpha=1,
left_alpha: Optional[torch.Tensor] = None,
right_alpha: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
left_alpha: torch.Tensor | None = None,
right_alpha: torch.Tensor | None = None,
out: torch.Tensor | None = None,
skip_checks: bool = False,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
meta: Optional[dict] = None,
max_grid: tuple[int | None, int | None, int | None] | None = None,
meta: dict | None = None,
):
if out is None and dense.dtype is torch.int8:
f_name = "_int_bsr_dense_addmm"
@ -1164,12 +1163,12 @@ def bsr_dense_addmm(
*,
beta=1,
alpha=1,
left_alpha: Optional[torch.Tensor] = None,
right_alpha: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
left_alpha: torch.Tensor | None = None,
right_alpha: torch.Tensor | None = None,
out: torch.Tensor | None = None,
skip_checks: bool = False,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
meta: Optional[dict] = None,
max_grid: tuple[int | None, int | None, int | None] | None = None,
meta: dict | None = None,
):
"""Compute
@ -1667,9 +1666,9 @@ if has_triton():
*,
beta=1.0,
alpha=1.0,
out: Optional[torch.Tensor] = None,
out: torch.Tensor | None = None,
skip_checks: bool = False,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
max_grid: tuple[int | None, int | None, int | None] | None = None,
):
f_name = "sampled_addmm"
@ -1751,10 +1750,10 @@ if has_triton():
bsr: torch.Tensor,
dense: torch.Tensor,
*,
out: Optional[torch.Tensor] = None,
out: torch.Tensor | None = None,
skip_checks: bool = False,
max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
meta: Optional[dict] = None,
max_grid: tuple[int | None, int | None, int | None] | None = None,
meta: dict | None = None,
):
f_name = "bsr_dense_mm"
m, _kl = bsr.shape[-2:]
@ -1967,10 +1966,10 @@ if has_triton():
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
attn_mask: torch.Tensor | None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
scale: float | None = None,
):
f_name = "_scaled_dot_product_attention"
check(not is_causal, f"{f_name}(): is_causal == True is not supported.")

View File

@ -2,7 +2,7 @@
import warnings
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Optional
from typing import Any
import torch
from torch.sparse._semi_structured_conversions import (
@ -63,11 +63,11 @@ class SparseSemiStructuredTensor(torch.Tensor):
BACKEND: str
SPARSE_DISPATCH: dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
packed_t: Optional[torch.Tensor]
meta_t: Optional[torch.Tensor]
compressed_swizzled_bitmask: Optional[torch.Tensor]
packed: torch.Tensor | None
meta: torch.Tensor | None
packed_t: torch.Tensor | None
meta_t: torch.Tensor | None
compressed_swizzled_bitmask: torch.Tensor | None
fuse_transpose_cusparselt: bool
alg_id_cusparselt: int
@ -77,11 +77,11 @@ class SparseSemiStructuredTensor(torch.Tensor):
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
packed: Optional[torch.Tensor],
meta: Optional[torch.Tensor],
packed_t: Optional[torch.Tensor],
meta_t: Optional[torch.Tensor],
compressed_swizzled_bitmask: Optional[torch.Tensor],
packed: torch.Tensor | None,
meta: torch.Tensor | None,
packed_t: torch.Tensor | None,
meta_t: torch.Tensor | None,
compressed_swizzled_bitmask: torch.Tensor | None,
fuse_transpose_cusparselt: bool = False,
alg_id_cusparselt: int = 0,
requires_grad: bool = False,
@ -312,7 +312,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
raise NotImplementedError
@ -514,7 +514,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
)
def _mm(
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
@ -643,7 +643,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
)
def _mm(
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
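
Across these files the new spelling also appears in class-level attribute annotations (packed: torch.Tensor | None above), which the interpreter evaluates as ordinary expressions unless postponed evaluation is enabled. A small illustrative check of the Python 3.10 requirement behind the commit title (hypothetical Demo class, not PyTorch code):

    # Illustrative only, not part of the commit.
    import sys
    import types

    assert sys.version_info >= (3, 10), "X | None syntax needs Python 3.10+"

    class Demo:
        packed: int | None = None  # class-level annotation, new spelling

        def get(self, fallback: int | None = None) -> int | None:
            return self.packed if self.packed is not None else fallback

    print(isinstance(int | None, types.UnionType))  # True on 3.10+
    print(Demo().get(7))  # 7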