Codemod inductor/fx_passes from Optional to union none (#165606)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165606
Approved by: https://github.com/aorenste
ghstack dependencies: #165604, #165605
Oguz Ulgen
2025-10-15 19:29:51 -07:00
committed by PyTorch MergeBot
parent ab6014a903
commit 5d0b22008d
12 changed files with 94 additions and 99 deletions
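
The change is mechanical: annotations spelled with typing.Optional[X] are rewritten to the equivalent PEP 604 union X | None. As a bare expression the union syntax needs Python 3.10+; on older interpreters it only works in annotations under `from __future__ import annotations`. A minimal before/after sketch of the pattern (the function name is hypothetical):

# Before: requires the extra typing import.
from typing import Optional

def lookup(key: str, default: Optional[int] = None) -> Optional[int]: ...

# After: no typing import needed for the union itself (PEP 604).
def lookup(key: str, default: int | None = None) -> int | None: ...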

View File

@@ -1,7 +1,7 @@
import collections
import logging
from collections import defaultdict
from typing import Any, Callable, Optional
from typing import Any, Callable
import torch
import torch.distributed as dist
@@ -34,7 +34,7 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
return (group_name, reduce_op, dtype)
def bucket_key(node: torch.fx.Node) -> Optional[object]:
def bucket_key(node: torch.fx.Node) -> object | None:
if is_all_gather_into_tensor(node):
return _ag_group_key(node)
elif is_reduce_scatter_tensor(node):
@@ -58,8 +58,8 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
def bucket_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@@ -75,8 +75,8 @@ def bucket_all_gather(
def bucket_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@@ -156,7 +156,7 @@ def greedy_bucket_collective_by_mb(
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_node: Callable[[torch.fx.Node], bool],
node_group_key: Callable[[torch.fx.Node], Any],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
"""
Buckets adjacent collectives with equal node_group_key.
@@ -234,7 +234,7 @@ def greedy_bucket_collective_by_mb(
def bucket_all_gather_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets,
@@ -247,7 +247,7 @@ def bucket_all_gather_by_mb(
to specify different bucket sizes at the start, as the first all_gather is usually exposed.
`bucket_cap_mb_by_bucket_idx_default` illustrates the interface and is the default value
for `bucket_cap_mb_by_bucket_idx`.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
only all_gather nodes whose wait_node satisfies `filter_wait_node` will be bucketed.
Returns:
@@ -266,7 +266,7 @@ def bucket_all_gather_by_mb(
def bucket_reduce_scatter_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
"""
Identifies all reduce_scatter nodes and groups them into buckets,
@@ -277,7 +277,7 @@ def bucket_reduce_scatter_by_mb(
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable that specifies the bucket cap
in megabytes by bucket idx, allowing different buckets to have different sizes.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
only reduce_scatter nodes whose wait_node satisfies `filter_wait_node` will be bucketed.
Returns:
@@ -577,8 +577,8 @@ def process_collective_bucket(
bucket_nodes: list[torch.fx.Node],
fn_to_trace: Callable[..., list[torch.Tensor]],
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
"""
Process a single bucket of collective operation nodes with flexible insertion control.
@@ -666,9 +666,9 @@ def process_collective_bucket(
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
mode: str | None = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
# Validate bucket consistency
rs0 = rs_nodes[0]
@@ -716,9 +716,9 @@ def merge_reduce_scatter_bucket(
def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
mode: str | None = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
from torch.distributed.distributed_c10d import _resolve_process_group
@@ -764,7 +764,7 @@ def merge_all_gather_bucket(
def merge_reduce_scatter(
gm: torch.fx.GraphModule,
rs_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
mode: str | None = None,
) -> None:
"""
Merges the specified buckets of reduce_scatter nodes into a joint reduce_scatter.
@@ -788,7 +788,7 @@ def merge_reduce_scatter(
def merge_all_gather(
gm: torch.fx.GraphModule,
ag_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
mode: str | None = None,
) -> None:
"""
Merges the specified buckets of all_gather nodes into a joint all_gather.

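For orientation, a minimal usage sketch of the bucketing entry points above, assuming they live in torch._inductor.fx_passes.bucketing as the lazy self-imports suggest; gm and the cap schedule are illustrative, not taken from this PR:

import torch.fx
from torch._inductor.fx_passes.bucketing import bucket_all_gather

def cap_mb(bucket_idx: int) -> float:
    # Use a smaller first bucket, since the first all_gather is usually exposed.
    return 100.0 if bucket_idx == 0 else 500.0

def run_bucketing(gm: torch.fx.GraphModule) -> None:
    # Both keyword arguments are now annotated with PEP 604 unions (... | None).
    bucket_all_gather(gm, bucket_cap_mb_by_bucket_idx=cap_mb)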
View File

@@ -7,7 +7,7 @@ import operator
from collections.abc import Generator
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, cast, Optional, Union
from typing import Any, Callable, cast, Union
import torch
import torch.fx as fx
@@ -40,8 +40,8 @@ def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None:
def call_function(
graph: fx.Graph,
target: Union[str, Callable[..., Any]],
args: Optional[tuple[fx.node.Argument, ...]] = None,
kwargs: Optional[dict[str, fx.node.Argument]] = None,
args: tuple[fx.node.Argument, ...] | None = None,
kwargs: dict[str, fx.node.Argument] | None = None,
) -> fx.Node:
# We accept target as a str to avoid a typing error, since the type of
# a node.target is Union[str, Callable[..., Any]].
@@ -70,7 +70,7 @@ class CommBlock:
outputs: OrderedSet[fx.Node]
def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
def get_comm_block(comm_node: fx.Node) -> CommBlock | None:
"""
Given a collective node (e.g., allreduce), find all the nodes that belong to
this communication.
@@ -150,7 +150,7 @@ def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
def get_all_comm_blocks(
graph: fx.Graph,
comm_ops: tuple[torch._ops.OpOverload, ...],
comm_filter: Optional[Callable[..., bool]] = None,
comm_filter: Callable[..., bool] | None = None,
) -> list[CommBlock]:
if comm_filter is None:

View File

@@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional
from typing import Callable
import torch
from torch._inductor.fx_passes.bucketing import (
@@ -55,15 +55,15 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
def bucket_fsdp_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
) -> None:
"""
Bucketing pass for SimpleFSDP all_gather ops.
Attributes:
gm (torch.fx.GraphModule): Graph module of the graph.
bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
takes a bucket index and returns the size of that bucket in megabytes.
"""
if bucket_cap_mb_by_bucket_idx is None:
@@ -85,15 +85,15 @@ def bucket_fsdp_all_gather(
def bucket_fsdp_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
) -> None:
"""
Bucketing pass for SimpleFSDP reduce_scatter ops.
Attributes:
gm (torch.fx.GraphModule): Graph module of the graph.
bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
takes a bucket index and returns the size of that bucket in megabytes. By default
torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used.

View File

@@ -4,7 +4,7 @@ import logging
import operator
from collections import OrderedDict
from collections.abc import Iterable, Iterator
from typing import Any, Optional
from typing import Any
import torch
from torch._dynamo.utils import counters, is_node_meta_valid
@@ -185,9 +185,7 @@ class PostGradBatchLinearFusion(BatchFusion):
and isinstance(input_shapes[1], int)
)
def match(
self, node: torch.fx.Node
) -> Optional[tuple[str, int, int, int, bool, str]]:
def match(self, node: torch.fx.Node) -> tuple[str, int, int, int, bool, str] | None:
if CallFunctionVarArgs(aten.mm).match(node):
input_m, weight_m = node.args
bias_m = None
@@ -325,7 +323,7 @@ class GroupLinearFusion(GroupFusion):
)
)
def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool]]:
def match(self, node: torch.fx.Node) -> tuple[str, bool] | None:
if CallFunctionVarArgs(aten.mm.default).match(
node
) and self._mm_node_can_be_fused(node):
@@ -493,7 +491,7 @@ class BatchLinearLHSFusion(BatchFusion):
We have a separate pass to eliminate contiguous transpose in a generic way.
"""
def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool, Any]]:
def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None:
if CallFunctionVarArgs(torch.nn.functional.linear).match(
node
) and is_linear_node_can_be_fused(node):

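Each match method above follows the same idiom: it returns a group-key tuple when the node is fusable and None to reject it, so tuple[...] | None doubles as a sentinel contract. A self-contained sketch of the idiom (the key layout here is hypothetical, not the pass's real key):

import torch.fx

def match_mm(node: torch.fx.Node) -> tuple[str, bool] | None:
    # A populated tuple acts as the grouping key for fusion candidates...
    if node.op == "call_function" and node.target is torch.ops.aten.mm.default:
        return ("batch_linear", True)
    # ...and None means "leave this node alone".
    return None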
View File

@@ -2,7 +2,7 @@ import itertools
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Optional, Union
from typing import Callable, Union
import torch
import torch.fx as fx
@@ -154,7 +154,7 @@ def device_filter(device: torch.device) -> bool:
def build_memory_profile(
graph: fx.Graph,
is_releasable: Callable[[fx.Node], bool],
size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
) -> list[int]:
"""
Function to estimate the memory profile of an input FX graph.
@@ -216,7 +216,7 @@ def build_memory_profile(
def get_fwd_bwd_interactions(
fwd_graph: fx.Graph,
bwd_graph: fx.Graph,
size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
) -> tuple[int, OrderedSet[str]]:
"""
Analyze the interactions between the forward (fwd) and backward (bwd) graphs
@@ -325,8 +325,8 @@ class MemoryTracker:
def __init__(
self,
graph: fx.Graph,
is_releasable: Optional[Callable[[fx.Node], bool]] = None,
device_filter: Optional[Callable[[torch.device], bool]] = None,
is_releasable: Callable[[fx.Node], bool] | None = None,
device_filter: Callable[[torch.device], bool] | None = None,
):
"""
Initialize memory tracker for alternative scheduling of the given graph.

View File

@@ -4,7 +4,7 @@ import operator
from collections import defaultdict
from dataclasses import dataclass, field
from math import prod
from typing import Any, cast, Optional
from typing import Any, cast
import torch
from torch.utils._ordered_set import OrderedSet
@@ -374,8 +374,8 @@ class _Matmul:
arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False)
A_node: torch.fx.Node
B_node: torch.fx.Node
pre_mm_reshape: Optional[torch.fx.Node]
post_mm_reshape: Optional[torch.fx.Node]
pre_mm_reshape: torch.fx.Node | None
post_mm_reshape: torch.fx.Node | None
def __post_init__(self):
assert len(self.nodes) in (1, 3)
@@ -450,12 +450,12 @@ class _Matmul:
class _ScaledMatmul(_Matmul):
A_scale_node: torch.fx.Node
B_scale_node: torch.fx.Node
bias_node: Optional[torch.fx.Node]
result_scale_node: Optional[torch.fx.Node]
out_dtype: Optional[torch.dtype]
bias_node: torch.fx.Node | None
result_scale_node: torch.fx.Node | None
out_dtype: torch.dtype | None
use_fast_accum: bool
pre_mm_reshape: Optional[torch.fx.Node]
post_mm_reshape: Optional[torch.fx.Node]
pre_mm_reshape: torch.fx.Node | None
post_mm_reshape: torch.fx.Node | None
def __post_init__(self):
super().__post_init__()
@@ -763,7 +763,7 @@ def _scatter_dim_after_reshape(
return 0 if leading_dims_collapsed else 1
def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None:
"""
Returns the producer matmul node if found, otherwise None.
"""

View File

@@ -6,7 +6,7 @@ import sys
from collections import Counter, defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union
import torch
import torch.fx as fx
@@ -42,7 +42,7 @@ def get_group_name(n: fx.Node) -> str:
return kwargs["group_name"]
def get_custom_estimation(n: fx.Node) -> Optional[float]:
def get_custom_estimation(n: fx.Node) -> float | None:
runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
if runtime_estimation == "default":
return None
@@ -51,7 +51,7 @@ def get_custom_estimation(n: fx.Node) -> Optional[float]:
return runtime_estimation(n)
def estimate_collective_time(n: fx.Node, override_size: Optional[int] = None) -> float:
def estimate_collective_time(n: fx.Node, override_size: int | None = None) -> float:
"""Estimate the runtime of a collective operation, optionally with an overridden size."""
if (est := get_custom_estimation(n)) is not None:
return est
@@ -82,7 +82,7 @@ def is_compute_node(n: fx.Node) -> bool:
)
def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]:
def get_hint(x: Union[int, torch.SymInt]) -> int | None:
if isinstance(x, int):
return x
assert isinstance(x, torch.SymInt)
@@ -100,7 +100,7 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
)
def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, Optional[str]]:
def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, str | None]:
assert is_compute_node(n)
from torch._dynamo.testing import rand_strided
@@ -115,7 +115,7 @@ def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, Optional[str]]:
key = f"{str(n.target)}: "
def to_real(t: torch.Tensor) -> Optional[torch.Tensor]:
def to_real(t: torch.Tensor) -> torch.Tensor | None:
shape = [get_hint(dim) for dim in t.shape]
stride = [get_hint(s) for s in t.stride()]
@@ -177,7 +177,7 @@ class CollectiveInfo:
size_bytes: int
estimated_time_ms: float
exposed_time_ms: float # How much of this collective is still exposed
hiding_node: Optional[fx.Node] = None # Node that hides this collective
hiding_node: fx.Node | None = None # Node that hides this collective
@property
def is_exposed(self) -> bool:
@@ -189,8 +189,8 @@ class CollBucket:
"""Track information about a bucket of collectives."""
collectives: list[fx.Node] # Original collective starts
bucketed_start: Optional[fx.Node] = None # After bucketing
bucketed_wait: Optional[fx.Node] = None # After bucketing
bucketed_start: fx.Node | None = None # After bucketing
bucketed_wait: fx.Node | None = None # After bucketing
total_bytes: int = 0
@@ -342,7 +342,7 @@ class OverlapScheduler:
log.info(
"Overlap scheduling: Aligning runtime estimations across all distributed ranks"
)
runtime_estimations_keys: list[Optional[str]] = []
runtime_estimations_keys: list[str | None] = []
runtime_estimations: list[float] = []
for n in self.compute_nodes:
val, key = benchmark_node_with_cache_key(n)
@@ -670,8 +670,8 @@ class OverlapScheduler:
available_compute_time -= overlap_amount
def _find_schedulable_path(
self, target: fx.Node, curr_compute_node: Optional[fx.Node]
) -> Optional[OrderedSet[fx.Node]]:
self, target: fx.Node, curr_compute_node: fx.Node | None
) -> OrderedSet[fx.Node] | None:
"""Find path to target by collecting unscheduled dependencies."""
# TODO - following path faster than doing set difference here
@@ -725,7 +725,7 @@ class OverlapScheduler:
return self.collective_info[oldest_start].wait_node
def _wait_is_hidden(
self, wait_node: fx.Node, compute_node: Optional[fx.Node] = None
self, wait_node: fx.Node, compute_node: fx.Node | None = None
) -> bool:
assert is_wait_tensor(wait_node)
info = self.collective_info[self.wait_to_start[wait_node]]
@@ -821,7 +821,7 @@ class OverlapScheduler:
used_compute_nodes: OrderedSet[fx.Node] = OrderedSet()
def could_be_hidden(start: fx.Node) -> Optional[fx.Node]:
def could_be_hidden(start: fx.Node) -> fx.Node | None:
for compute_node in self.compute_nodes:
if limit_coll_per_compute and compute_node in used_compute_nodes:
continue

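To make the CollectiveInfo fields concrete, a small self-contained sketch; the is_exposed body is an assumption, since this hunk elides it:

from dataclasses import dataclass
import torch.fx as fx

@dataclass
class CollectiveInfo:  # abridged restatement of the dataclass above
    size_bytes: int
    estimated_time_ms: float
    exposed_time_ms: float              # how much of this collective is still exposed
    hiding_node: fx.Node | None = None  # compute node that hides this collective

    @property
    def is_exposed(self) -> bool:
        # Assumption: exposed while any exposed time remains.
        return self.exposed_time_ms > 0

info = CollectiveInfo(size_bytes=1 << 20, estimated_time_ms=0.5, exposed_time_ms=0.5)
assert info.is_exposed and info.hiding_node is None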
View File

@@ -3,7 +3,7 @@ import itertools
import operator
import typing
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union
import torch
import torch._inductor.runtime.runtime_utils
@@ -83,12 +83,10 @@ def check_dtype(a: Tensor, b: Tensor) -> bool:
return a.is_floating_point() and b.is_floating_point()
def should_pad_common(
mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -> bool:
# It's fine if we have symbolic shapes or strides as long as they
# have hints. Later, we will make sure we only pad non-symbolic dimensions.
def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
def valid_shape_and_stride(t: Tensor | None) -> bool:
if t is None:
return True
@@ -153,7 +151,7 @@ def should_pad_addmm(match: Match) -> bool:
def pad_addmm(
input: Optional[Tensor],
input: Tensor | None,
mat1: Tensor,
mat2: Tensor,
m_padded_length: int,
@@ -195,7 +193,7 @@ def pad_addmm(
def addmm_replace(
input: Optional[Tensor],
input: Tensor | None,
mat1: Tensor,
mat2: Tensor,
beta: float = 1.0,
@@ -275,7 +273,7 @@ def should_pad_bench_key(
mat1: Tensor,
mat2: Tensor,
op: torch._ops.OpOverloadPacket,
input: Optional[Tensor] = None,
input: Tensor | None = None,
is_base_time_key: bool = False,
) -> str:
def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
@@ -285,7 +283,7 @@ def should_pad_bench_key(
None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
)
def fmt_pad(name: str) -> Optional[str]:
def fmt_pad(name: str) -> str | None:
if is_base_time_key:
return None
return f"exclude_pad:{should_exclude_padding_time(match, name)}"
@@ -412,7 +410,7 @@ def _should_pad_bench(
mat1: Tensor,
mat2: Tensor,
op: torch._ops.OpOverloadPacket,
input: Optional[Tensor] = None,
input: Tensor | None = None,
) -> bool:
do_bench = get_do_bench()
@@ -681,10 +679,10 @@ def run_autoheuristic(
ori_time: float,
ori_time_key: str,
key: str,
) -> Optional[bool]:
) -> bool | None:
def feedback_fn(
choice: str,
) -> Optional[float]:
) -> float | None:
if choice == orig_choice:
return do_bench(orig_bench_fn)
elif choice == pad_choice:

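The Tensor | None parameters above let the bias-like input operand be omitted, and the nested valid_shape_and_stride helper treats None as vacuously valid. A hedged restatement as a free function (the body past the None check is paraphrased, since the hunk truncates it):

import torch
from torch import Tensor

def valid_shape_and_stride(t: Tensor | None) -> bool:
    if t is None:
        return True  # absent operand: nothing to validate
    # Assumption: only concrete (non-symbolic) sizes and strides count as valid.
    return all(isinstance(d, int) for d in t.shape) and all(
        isinstance(s, int) for s in t.stride()
    )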
View File

@@ -5,7 +5,7 @@ import itertools
import logging
import operator
from collections import Counter, defaultdict
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@@ -1726,7 +1726,7 @@ class ConstructorMoverPass:
return False
def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
def get_node_device(self, node: fx.Node) -> torch.device | None:
"""
Get the device of a node.
"""

View File

@@ -5,7 +5,6 @@ import itertools
import logging
import types
from collections.abc import Sequence
from typing import Optional
import torch
import torch.nn as nn
@@ -191,8 +190,8 @@ def _get_pass_name_func(p):
def _run_pre_dispatch_passes(
gm: torch.fx.GraphModule,
example_inputs: Sequence[object] = (),
add_passes: Optional[str] = None,
remove_passes: Optional[str] = None,
add_passes: str | None = None,
remove_passes: str | None = None,
) -> None:
# order matters
default_pass_list = [
@@ -278,8 +277,8 @@ def _run_pre_dispatch_passes(
def pre_grad_passes(
gm: torch.fx.GraphModule,
example_inputs: Sequence[object] = (),
add_passes: Optional[str] = None,
remove_passes: Optional[str] = None,
add_passes: str | None = None,
remove_passes: str | None = None,
) -> torch.fx.GraphModule:
"""
Apply passes on the input FX graph using Torch IR.
@@ -763,7 +762,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
# ---->
# Y2 = (W * X^T + bias.unsqueeze(-1))^T
def linear_transpose(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
if bias is None:
return torch.matmul(weight, input.transpose(-1, -2))
@@ -860,7 +859,7 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
# ---->
# Y2 = X1.transpose(-1, -2) * W1^T + bias1
def transpose_linear(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
if bias is None:
return torch.matmul(input.transpose(-1, -2), weight.t())

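The fusion comments above rest on an algebraic identity: linear(X, W, b).transpose(-1, -2) equals W @ X.transpose(-1, -2) + b.unsqueeze(-1). A quick numerical check with arbitrary shapes:

import torch

X = torch.randn(4, 8, 16)   # (batch, n, in_features)
W = torch.randn(32, 16)     # (out_features, in_features)
b = torch.randn(32)

y1 = torch.nn.functional.linear(X, W, b).transpose(-1, -2)
y2 = torch.matmul(W, X.transpose(-1, -2)) + b.unsqueeze(-1)
assert torch.allclose(y1, y2, atol=1e-5)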
View File

@@ -679,7 +679,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
from torch._higher_order_ops.auto_functionalize import get_mutable_args
tensors_to_clone, _ = get_mutable_args(_mutable_op)
# Don't try to reinplace Optional[Tensor] args that are None.
# Don't try to reinplace Tensor | None args that are None.
tensors_to_clone = [
t for t in tensors_to_clone if node.kwargs[t] is not None
]

View File

@@ -5,7 +5,7 @@ import operator
import os
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union
from typing_extensions import TypeAlias
import torch
@@ -38,10 +38,10 @@ log = logging.getLogger(__name__)
_Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...]
_TransformParam: TypeAlias = tuple[
Optional[_Arguments],
Optional[_Arguments],
Optional[_Arguments],
Optional[_Arguments],
_Arguments | None,
_Arguments | None,
_Arguments | None,
_Arguments | None,
]
_Range: TypeAlias = tuple[int, int]
@@ -167,7 +167,7 @@ def _get_dim(node: Any):
def normalize_split_base(
match: Match,
_get_split_args: Callable[
[torch.fx.Node], tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]]
[torch.fx.Node], tuple[torch.fx.Node | None, Any | None, int | None]
],
):
"""
@@ -802,7 +802,7 @@ class SplitCatSimplifier:
split_sections,
next_users,
user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
) -> Optional[list[_Range]]:
) -> list[_Range] | None:
ranges = OrderedSet[Any]()
for user_inputs in user_inputs_list:
ranges.update(u for u in user_inputs if isinstance(u, tuple))
@@ -848,7 +848,7 @@ class SplitCatSimplifier:
split_node: torch.fx.Node,
next_users: list[torch.fx.Node],
user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
) -> Optional[list[list[_TransformParam]]]:
) -> list[list[_TransformParam]] | None:
"""
Figure out what transforms are needed for each input to each cat node.
@@ -1178,7 +1178,7 @@ class UnbindCatRemover(SplitCatSimplifier):
split_sections: list[int],
next_users: list[torch.fx.Node],
user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
) -> Optional[list[_Range]]:
) -> list[_Range] | None:
simplified_split_ranges = super().get_simplified_split_ranges(
split_sections, next_users, user_inputs_list
)
@@ -1191,7 +1191,7 @@ class UnbindCatRemover(SplitCatSimplifier):
split_node: torch.fx.Node,
next_users: list[torch.fx.Node],
user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
) -> Optional[list[list[_TransformParam]]]:
) -> list[list[_TransformParam]] | None:
"""
Figure out what transforms are needed for each input to each cat node.