diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index cd9909e5aaf6..965e0654380c 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -1,7 +1,7 @@
 import collections
 import logging
 from collections import defaultdict
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 
 import torch
 import torch.distributed as dist
@@ -34,7 +34,7 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
     return (group_name, reduce_op, dtype)
 
 
-def bucket_key(node: torch.fx.Node) -> Optional[object]:
+def bucket_key(node: torch.fx.Node) -> object | None:
     if is_all_gather_into_tensor(node):
         return _ag_group_key(node)
     elif is_reduce_scatter_tensor(node):
@@ -58,8 +58,8 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
 
 def bucket_all_gather(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -75,8 +75,8 @@ def bucket_all_gather(
 
 def bucket_reduce_scatter(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -156,7 +156,7 @@ def greedy_bucket_collective_by_mb(
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
     filter_node: Callable[[torch.fx.Node], bool],
     node_group_key: Callable[[torch.fx.Node], Any],
-    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
 ) -> list[list[torch.fx.Node]]:
     """
     Bucketing adjacent collectives with equal node_group_key.
@@ -234,7 +234,7 @@ def greedy_bucket_collective_by_mb(
 def bucket_all_gather_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
-    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all all_gather nodes and groups them into buckets,
@@ -247,7 +247,7 @@ def bucket_all_gather_by_mb(
             to specify different sizes of the buckets at the start,
             as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
             is `bucket_cap_mb_by_bucket_idx_default` function that is default value for `bucket_cap_mb_by_bucket_idx`.
-        filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
+        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
             only all_gather nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
 
     Returns:
@@ -266,7 +266,7 @@ def bucket_all_gather_by_mb(
 def bucket_reduce_scatter_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
-    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all reduce_scatter nodes and groups them into buckets,
@@ -277,7 +277,7 @@ def bucket_reduce_scatter_by_mb(
         bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
             in megabytes by bucket idx.  The idea of `bucket_cap_mb_by_bucket_idx`
             is to allow to specify different sizes of the buckets.
-        filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
+        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
             only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
 
     Returns:
@@ -577,8 +577,8 @@ def process_collective_bucket(
     bucket_nodes: list[torch.fx.Node],
     fn_to_trace: Callable[..., list[torch.Tensor]],
     trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
-    insert_before: Optional[torch.fx.Node] = None,
-    wait_insertion_point: Optional[torch.fx.Node] = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     """
     Process a single bucket of collective operation nodes with flexible insertion control.
@@ -666,9 +666,9 @@ def process_collective_bucket(
 def merge_reduce_scatter_bucket(
     g: torch.fx.Graph,
     rs_nodes: list[torch.fx.Node],
-    mode: Optional[str] = None,
-    insert_before: Optional[torch.fx.Node] = None,
-    wait_insertion_point: Optional[torch.fx.Node] = None,
+    mode: str | None = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     # Validate bucket consistency
     rs0 = rs_nodes[0]
@@ -716,9 +716,9 @@ def merge_reduce_scatter_bucket(
 def merge_all_gather_bucket(
     g: torch.fx.Graph,
     ag_nodes: list[torch.fx.Node],
-    mode: Optional[str] = None,
-    insert_before: Optional[torch.fx.Node] = None,
-    wait_insertion_point: Optional[torch.fx.Node] = None,
+    mode: str | None = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     from torch.distributed.distributed_c10d import _resolve_process_group
 
@@ -764,7 +764,7 @@ def merge_all_gather_bucket(
 def merge_reduce_scatter(
     gm: torch.fx.GraphModule,
     rs_buckets: list[list[torch.fx.Node]],
-    mode: Optional[str] = None,
+    mode: str | None = None,
 ) -> None:
     """
     Merges specified buckets of reduce_scatter to joint reduce_scatter.
@@ -788,7 +788,7 @@ def merge_reduce_scatter(
 def merge_all_gather(
     gm: torch.fx.GraphModule,
     ag_buckets: list[list[torch.fx.Node]],
-    mode: Optional[str] = None,
+    mode: str | None = None,
 ) -> None:
     """
     Merges specified buckets of all_gather to joint all_gather.
diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py
index 8f5cc7bc5d2b..d8b26ddf7a9b 100644
--- a/torch/_inductor/fx_passes/ddp_fusion.py
+++ b/torch/_inductor/fx_passes/ddp_fusion.py
@@ -7,7 +7,7 @@ import operator
 from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, Callable, cast, Union
 
 import torch
 import torch.fx as fx
@@ -40,8 +40,8 @@ def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None:
 def call_function(
     graph: fx.Graph,
     target: Union[str, Callable[..., Any]],
-    args: Optional[tuple[fx.node.Argument, ...]] = None,
-    kwargs: Optional[dict[str, fx.node.Argument]] = None,
+    args: tuple[fx.node.Argument, ...] | None = None,
+    kwargs: dict[str, fx.node.Argument] | None = None,
 ) -> fx.Node:
     # We accept target as a str to avoid typing error as the type of
     # a node.target is Union[str, Callable[..., Any]].
@@ -70,7 +70,7 @@ class CommBlock:
     outputs: OrderedSet[fx.Node]
 
 
-def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
+def get_comm_block(comm_node: fx.Node) -> CommBlock | None:
     """
     Given a collective node (e.g., allreduce), find out all the nodes belong to
     this communication.
@@ -150,7 +150,7 @@ def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
 def get_all_comm_blocks(
     graph: fx.Graph,
     comm_ops: tuple[torch._ops.OpOverload, ...],
-    comm_filter: Optional[Callable[..., bool]] = None,
+    comm_filter: Callable[..., bool] | None = None,
 ) -> list[CommBlock]:
     if comm_filter is None:
 
diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py
index e7e574ae4934..73787bd928a5 100644
--- a/torch/_inductor/fx_passes/fsdp.py
+++ b/torch/_inductor/fx_passes/fsdp.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Optional
+from typing import Callable
 
 import torch
 from torch._inductor.fx_passes.bucketing import (
@@ -55,15 +55,15 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
 
 def bucket_fsdp_all_gather(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP all_gather ops.
 
     Attributes:
         gm (torch.fx.GraphModule): Graph module of the graph.
-        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
             takes in bucket id and returns size of a bucket in megabytes.
     """
     if bucket_cap_mb_by_bucket_idx is None:
@@ -85,15 +85,15 @@ def bucket_fsdp_all_gather(
 
 def bucket_fsdp_reduce_scatter(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP reduce_scatter ops.
 
     Attributes:
         gm (torch.fx.GraphModule): Graph module of the graph.
-        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
             takes in bucket idx and returns size of a bucket in megabytes.
             By default torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used.
     """
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index 743d9a1b85a0..a8e2a4816ec0 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -4,7 +4,7 @@ import logging
 import operator
 from collections import OrderedDict
 from collections.abc import Iterable, Iterator
-from typing import Any, Optional
+from typing import Any
 
 import torch
 from torch._dynamo.utils import counters, is_node_meta_valid
@@ -185,9 +185,7 @@ class PostGradBatchLinearFusion(BatchFusion):
             and isinstance(input_shapes[1], int)
         )
 
-    def match(
-        self, node: torch.fx.Node
-    ) -> Optional[tuple[str, int, int, int, bool, str]]:
+    def match(self, node: torch.fx.Node) -> tuple[str, int, int, int, bool, str] | None:
         if CallFunctionVarArgs(aten.mm).match(node):
             input_m, weight_m = node.args
             bias_m = None
@@ -325,7 +323,7 @@ class GroupLinearFusion(GroupFusion):
             )
         )
 
-    def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool]]:
+    def match(self, node: torch.fx.Node) -> tuple[str, bool] | None:
         if CallFunctionVarArgs(aten.mm.default).match(
             node
         ) and self._mm_node_can_be_fused(node):
@@ -493,7 +491,7 @@ class BatchLinearLHSFusion(BatchFusion):
     We have a separate pass to eliminate contiguous transpose in a generic way.
     """
 
-    def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool, Any]]:
+    def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None:
         if CallFunctionVarArgs(torch.nn.functional.linear).match(
             node
         ) and is_linear_node_can_be_fused(node):
diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py
index 3c941c9dc08f..f4bb1cc72cbf 100644
--- a/torch/_inductor/fx_passes/memory_estimator.py
+++ b/torch/_inductor/fx_passes/memory_estimator.py
@@ -2,7 +2,7 @@ import itertools
 import logging
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Callable, Optional, Union
+from typing import Callable, Union
 
 import torch
 import torch.fx as fx
@@ -154,7 +154,7 @@ def device_filter(device: torch.device) -> bool:
 def build_memory_profile(
     graph: fx.Graph,
     is_releasable: Callable[[fx.Node], bool],
-    size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
+    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
 ) -> list[int]:
     """
     Function to estimate the memory profile of an input FX graph.
@@ -216,7 +216,7 @@ def build_memory_profile(
 def get_fwd_bwd_interactions(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
-    size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
+    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
 ) -> tuple[int, OrderedSet[str]]:
     """
     Analyze the interactions between the forward (fwd) and backward (bwd) graphs
@@ -325,8 +325,8 @@ class MemoryTracker:
     def __init__(
         self,
         graph: fx.Graph,
-        is_releasable: Optional[Callable[[fx.Node], bool]] = None,
-        device_filter: Optional[Callable[[torch.device], bool]] = None,
+        is_releasable: Callable[[fx.Node], bool] | None = None,
+        device_filter: Callable[[torch.device], bool] | None = None,
     ):
         """
         Initialize memory tracker for alternative scheduling of the given graph.
diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py
index 4a4b3456f4a3..713143ec02fe 100644
--- a/torch/_inductor/fx_passes/micro_pipeline_tp.py
+++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -4,7 +4,7 @@ import operator
 from collections import defaultdict
 from dataclasses import dataclass, field
 from math import prod
-from typing import Any, cast, Optional
+from typing import Any, cast
 
 import torch
 from torch.utils._ordered_set import OrderedSet
@@ -374,8 +374,8 @@ class _Matmul:
     arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False)
     A_node: torch.fx.Node
    B_node: torch.fx.Node
-    pre_mm_reshape: Optional[torch.fx.Node]
-    post_mm_reshape: Optional[torch.fx.Node]
+    pre_mm_reshape: torch.fx.Node | None
+    post_mm_reshape: torch.fx.Node | None
 
     def __post_init__(self):
         assert len(self.nodes) in (1, 3)
@@ -450,12 +450,12 @@ class _ScaledMatmul(_Matmul):
     A_scale_node: torch.fx.Node
     B_scale_node: torch.fx.Node
-    bias_node: Optional[torch.fx.Node]
-    result_scale_node: Optional[torch.fx.Node]
-    out_dtype: Optional[torch.dtype]
+    bias_node: torch.fx.Node | None
+    result_scale_node: torch.fx.Node | None
+    out_dtype: torch.dtype | None
     use_fast_accum: bool
-    pre_mm_reshape: Optional[torch.fx.Node]
-    post_mm_reshape: Optional[torch.fx.Node]
+    pre_mm_reshape: torch.fx.Node | None
+    post_mm_reshape: torch.fx.Node | None
 
     def __post_init__(self):
         super().__post_init__()
@@ -763,7 +763,7 @@ def _scatter_dim_after_reshape(
     return 0 if leading_dims_collapsed else 1
 
 
-def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
+def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None:
     """
     Returns producer matmul node if found, otherwise returns None.
     """
diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py
index ad9b835372ec..9f02b2549eda 100644
--- a/torch/_inductor/fx_passes/overlap_scheduling.py
+++ b/torch/_inductor/fx_passes/overlap_scheduling.py
@@ -6,7 +6,7 @@ import sys
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 import torch
 import torch.fx as fx
@@ -42,7 +42,7 @@ def get_group_name(n: fx.Node) -> str:
     return kwargs["group_name"]
 
 
-def get_custom_estimation(n: fx.Node) -> Optional[float]:
+def get_custom_estimation(n: fx.Node) -> float | None:
     runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
     if runtime_estimation == "default":
         return None
@@ -51,7 +51,7 @@ def get_custom_estimation(n: fx.Node) -> Optional[float]:
     return runtime_estimation(n)
 
 
-def estimate_collective_time(n: fx.Node, override_size: Optional[int] = None) -> float:
+def estimate_collective_time(n: fx.Node, override_size: int | None = None) -> float:
    """Estimate the runtime of a collective operation, optionally with an overridden size."""
     if (est := get_custom_estimation(n)) is not None:
         return est
@@ -82,7 +82,7 @@ def is_compute_node(n: fx.Node) -> bool:
     )
 
 
-def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]:
+def get_hint(x: Union[int, torch.SymInt]) -> int | None:
     if isinstance(x, int):
         return x
     assert isinstance(x, torch.SymInt)
@@ -100,7 +100,7 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
     )
 
 
-def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, Optional[str]]:
+def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, str | None]:
     assert is_compute_node(n)
     from torch._dynamo.testing import rand_strided
 
@@ -115,7 +115,7 @@ def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, Optional[str]]:
 
     key = f"{str(n.target)}: "
 
-    def to_real(t: torch.Tensor) -> Optional[torch.Tensor]:
+    def to_real(t: torch.Tensor) -> torch.Tensor | None:
         shape = [get_hint(dim) for dim in t.shape]
         stride = [get_hint(s) for s in t.stride()]
 
@@ -177,7 +177,7 @@ class CollectiveInfo:
     size_bytes: int
     estimated_time_ms: float
     exposed_time_ms: float  # How much of this collective is still exposed
-    hiding_node: Optional[fx.Node] = None  # Node that hides this collective
+    hiding_node: fx.Node | None = None  # Node that hides this collective
 
     @property
     def is_exposed(self) -> bool:
@@ -189,8 +189,8 @@ class CollBucket:
     """Track information about a bucket of collectives."""
 
     collectives: list[fx.Node]  # Original collective starts
-    bucketed_start: Optional[fx.Node] = None  # After bucketing
-    bucketed_wait: Optional[fx.Node] = None  # After bucketing
+    bucketed_start: fx.Node | None = None  # After bucketing
+    bucketed_wait: fx.Node | None = None  # After bucketing
     total_bytes: int = 0
 
 
@@ -342,7 +342,7 @@ class OverlapScheduler:
         log.info(
             "Overlap scheduling: Aligning runtime estimations across all distributed ranks"
         )
-        runtime_estimations_keys: list[Optional[str]] = []
+        runtime_estimations_keys: list[str | None] = []
         runtime_estimations: list[float] = []
         for n in self.compute_nodes:
             val, key = benchmark_node_with_cache_key(n)
@@ -670,8 +670,8 @@ class OverlapScheduler:
             available_compute_time -= overlap_amount
 
     def _find_schedulable_path(
-        self, target: fx.Node, curr_compute_node: Optional[fx.Node]
-    ) -> Optional[OrderedSet[fx.Node]]:
+        self, target: fx.Node, curr_compute_node: fx.Node | None
+    ) -> OrderedSet[fx.Node] | None:
         """Find path to target by collecting unscheduled dependencies."""
 
         # TODO - following path faster than doing set difference here
@@ -725,7 +725,7 @@ class OverlapScheduler:
         return self.collective_info[oldest_start].wait_node
 
     def _wait_is_hidden(
-        self, wait_node: fx.Node, compute_node: Optional[fx.Node] = None
+        self, wait_node: fx.Node, compute_node: fx.Node | None = None
     ) -> bool:
         assert is_wait_tensor(wait_node)
         info = self.collective_info[self.wait_to_start[wait_node]]
@@ -821,7 +821,7 @@ class OverlapScheduler:
 
         used_compute_nodes: OrderedSet[fx.Node] = OrderedSet()
 
-        def could_be_hidden(start: fx.Node) -> Optional[fx.Node]:
+        def could_be_hidden(start: fx.Node) -> fx.Node | None:
             for compute_node in self.compute_nodes:
                 if limit_coll_per_compute and compute_node in used_compute_nodes:
                     continue
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index f58678e7651e..8d1b31eb4067 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -3,7 +3,7 @@ import itertools
 import operator
 import typing
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 import torch
 import torch._inductor.runtime.runtime_utils
@@ -83,12 +83,10 @@ def check_dtype(a: Tensor, b: Tensor) -> bool:
     return a.is_floating_point() and b.is_floating_point()
 
 
-def should_pad_common(
-    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
-) -> bool:
+def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -> bool:
     # It's fine we have symbolic shapes or strides as long as they
     # have hints. Later, we will make sure we only pad non-symbolic dimensions.
-    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
+    def valid_shape_and_stride(t: Tensor | None) -> bool:
         if t is None:
             return True
 
@@ -153,7 +151,7 @@ def should_pad_addmm(match: Match) -> bool:
 
 
 def pad_addmm(
-    input: Optional[Tensor],
+    input: Tensor | None,
     mat1: Tensor,
     mat2: Tensor,
     m_padded_length: int,
@@ -195,7 +193,7 @@ def pad_addmm(
 
 
 def addmm_replace(
-    input: Optional[Tensor],
+    input: Tensor | None,
     mat1: Tensor,
     mat2: Tensor,
     beta: float = 1.0,
@@ -275,7 +273,7 @@ def should_pad_bench_key(
     mat1: Tensor,
     mat2: Tensor,
     op: torch._ops.OpOverloadPacket,
-    input: Optional[Tensor] = None,
+    input: Tensor | None = None,
     is_base_time_key: bool = False,
 ) -> str:
     def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
@@ -285,7 +283,7 @@ def should_pad_bench_key(
         None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
     )
 
-    def fmt_pad(name: str) -> Optional[str]:
+    def fmt_pad(name: str) -> str | None:
         if is_base_time_key:
             return None
         return f"exclude_pad:{should_exclude_padding_time(match, name)}"
@@ -412,7 +410,7 @@ def _should_pad_bench(
     mat1: Tensor,
     mat2: Tensor,
     op: torch._ops.OpOverloadPacket,
-    input: Optional[Tensor] = None,
+    input: Tensor | None = None,
 ) -> bool:
     do_bench = get_do_bench()
 
@@ -681,10 +679,10 @@ def run_autoheuristic(
     ori_time: float,
    ori_time_key: str,
     key: str,
-) -> Optional[bool]:
+) -> bool | None:
     def feedback_fn(
         choice: str,
-    ) -> Optional[float]:
+    ) -> float | None:
         if choice == orig_choice:
             return do_bench(orig_bench_fn)
         elif choice == pad_choice:
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index db9f6f8563e6..938e15deedb2 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -5,7 +5,7 @@ import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
@@ -1726,7 +1726,7 @@ class ConstructorMoverPass:
 
         return False
 
-    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
+    def get_node_device(self, node: fx.Node) -> torch.device | None:
         """
         Get the device of a node.
         """
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 597013f6233c..238c6556b5c2 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -5,7 +5,6 @@ import itertools
 import logging
 import types
 from collections.abc import Sequence
-from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -191,8 +190,8 @@ def _get_pass_name_func(p):
 def _run_pre_dispatch_passes(
     gm: torch.fx.GraphModule,
     example_inputs: Sequence[object] = (),
-    add_passes: Optional[str] = None,
-    remove_passes: Optional[str] = None,
+    add_passes: str | None = None,
+    remove_passes: str | None = None,
 ) -> None:
     # order matters
     default_pass_list = [
@@ -278,8 +277,8 @@ def _run_pre_dispatch_passes(
 def pre_grad_passes(
     gm: torch.fx.GraphModule,
     example_inputs: Sequence[object] = (),
-    add_passes: Optional[str] = None,
-    remove_passes: Optional[str] = None,
+    add_passes: str | None = None,
+    remove_passes: str | None = None,
 ) -> torch.fx.GraphModule:
     """
     Apply passes on the input FX graph using Torch IR.
@@ -763,7 +762,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
 # ---->
 # Y2 = (W * X^T + bias.unsqueeze(-1))^T
 def linear_transpose(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
 ) -> torch.Tensor:
     if bias is None:
         return torch.matmul(weight, input.transpose(-1, -2))
@@ -860,7 +859,7 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
 # ---->
 # Y2 = X1.transpose(-1, -2) * W1^T + bias1
 def transpose_linear(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
 ) -> torch.Tensor:
     if bias is None:
         return torch.matmul(input.transpose(-1, -2), weight.t())
diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py
index 242bb98d4584..ee9fe6aff780 100644
--- a/torch/_inductor/fx_passes/reinplace.py
+++ b/torch/_inductor/fx_passes/reinplace.py
@@ -679,7 +679,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             from torch._higher_order_ops.auto_functionalize import get_mutable_args
 
             tensors_to_clone, _ = get_mutable_args(_mutable_op)
-            # Don't try to reinplace Optional[Tensor] args that are None.
+            # Don't try to reinplace Tensor | None args that are None.
             tensors_to_clone = [
                 t for t in tensors_to_clone if node.kwargs[t] is not None
             ]
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 9b0f5956cce6..015e33274434 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -5,7 +5,7 @@ import operator
 import os
 from collections import defaultdict
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 from typing_extensions import TypeAlias
 
 import torch
@@ -38,10 +38,10 @@ log = logging.getLogger(__name__)
 
 _Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...]
 _TransformParam: TypeAlias = tuple[
-    Optional[_Arguments],
-    Optional[_Arguments],
-    Optional[_Arguments],
-    Optional[_Arguments],
+    _Arguments | None,
+    _Arguments | None,
+    _Arguments | None,
+    _Arguments | None,
 ]
 
 _Range: TypeAlias = tuple[int, int]
@@ -167,7 +167,7 @@ def _get_dim(node: Any):
 def normalize_split_base(
     match: Match,
     _get_split_args: Callable[
-        [torch.fx.Node], tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]]
+        [torch.fx.Node], tuple[torch.fx.Node | None, Any | None, int | None]
     ],
 ):
     """
@@ -802,7 +802,7 @@ class SplitCatSimplifier:
         split_sections,
         next_users,
         user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[_Range]]:
+    ) -> list[_Range] | None:
         ranges = OrderedSet[Any]()
         for user_inputs in user_inputs_list:
             ranges.update(u for u in user_inputs if isinstance(u, tuple))
@@ -848,7 +848,7 @@ class SplitCatSimplifier:
         split_node: torch.fx.Node,
         next_users: list[torch.fx.Node],
         user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[list[_TransformParam]]]:
+    ) -> list[list[_TransformParam]] | None:
         """
         Figure out what transforms are needed for each input to each cat node.
 
@@ -1178,7 +1178,7 @@ class UnbindCatRemover(SplitCatSimplifier):
         split_sections: list[int],
         next_users: list[torch.fx.Node],
         user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[_Range]]:
+    ) -> list[_Range] | None:
         simplified_split_ranges = super().get_simplified_split_ranges(
             split_sections, next_users, user_inputs_list
        )
@@ -1191,7 +1191,7 @@ class UnbindCatRemover(SplitCatSimplifier):
         split_node: torch.fx.Node,
         next_users: list[torch.fx.Node],
         user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[list[_TransformParam]]]:
+    ) -> list[list[_TransformParam]] | None:
         """
         Figure out what transforms are needed for each input to each cat node.
 
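
The patch is a purely mechanical migration: every `typing.Optional[X]` annotation becomes the PEP 604 union spelling `X | None`, and the now-unused `Optional` imports are dropped. A minimal sketch (illustrative only, not code from the patch) of why this rewrite is behavior-preserving on Python 3.10+, where the two spellings compare equal as runtime typing objects:

```python
from typing import Optional


# Hypothetical example, not taken from the patch: both annotations below
# mean "int or None", and the two spellings may even be mixed freely.
def old_style(x: Optional[int] = None) -> int | None:
    return x


# On Python 3.10+, PEP 604 unions and typing.Optional produce equal
# runtime typing objects, so annotation-only rewrites like this patch
# cannot change behavior.
assert Optional[int] == (int | None)
```

Note that some of the rewritten annotations (e.g. the `_TransformParam` TypeAlias in split_cat.py) are evaluated eagerly at import time, so the `X | None` spelling there does assume a Python 3.10+ runtime rather than relying on `from __future__ import annotations`.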