Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Codemod inductor/fx_passes from Optional to union none (#165606)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165606
Approved by: https://github.com/aorenste
ghstack dependencies: #165604, #165605
commit 5d0b22008d
parent ab6014a903
committed by PyTorch MergeBot
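The change is purely mechanical: every `Optional[X]` annotation becomes the equivalent PEP 604 spelling `X | None`, and `Optional` is dropped from the `typing` imports. A minimal before/after sketch of the pattern (the names here are illustrative, not taken from the diff):

# Before: typing.Optional spelling
from typing import Callable, Optional

def run(cap_mb: Optional[Callable[[int], float]] = None, mode: Optional[str] = None) -> None:
    ...

# After: PEP 604 union spelling
from typing import Callable

def run(cap_mb: Callable[[int], float] | None = None, mode: str | None = None) -> None:
    ...

The `X | None` form is runtime-valid on Python 3.10+; on older interpreters it still works in annotation positions under `from __future__ import annotations`, since annotations are then stored as strings rather than evaluated.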
@@ -1,7 +1,7 @@
 import collections
 import logging
 from collections import defaultdict
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 
 import torch
 import torch.distributed as dist
@@ -34,7 +34,7 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
     return (group_name, reduce_op, dtype)
 
 
-def bucket_key(node: torch.fx.Node) -> Optional[object]:
+def bucket_key(node: torch.fx.Node) -> object | None:
     if is_all_gather_into_tensor(node):
         return _ag_group_key(node)
     elif is_reduce_scatter_tensor(node):
@@ -58,8 +58,8 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
 
 def bucket_all_gather(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -75,8 +75,8 @@ def bucket_all_gather(
 
 def bucket_reduce_scatter(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -156,7 +156,7 @@ def greedy_bucket_collective_by_mb(
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
     filter_node: Callable[[torch.fx.Node], bool],
     node_group_key: Callable[[torch.fx.Node], Any],
-    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
 ) -> list[list[torch.fx.Node]]:
     """
     Bucketing adjacent collectives with equal node_group_key.
@@ -234,7 +234,7 @@ def greedy_bucket_collective_by_mb(
 def bucket_all_gather_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
-    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all all_gather nodes and groups them into buckets,
@@ -247,7 +247,7 @@ def bucket_all_gather_by_mb(
             to specify different sizes of the buckets at the start,
             as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
             is `bucket_cap_mb_by_bucket_idx_default` function that is default value for `bucket_cap_mb_by_bucket_idx`.
-        filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
+        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
             only all_gather nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
 
     Returns:
@@ -266,7 +266,7 @@ def bucket_all_gather_by_mb(
 def bucket_reduce_scatter_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
-    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all reduce_scatter nodes and groups them into buckets,
@@ -277,7 +277,7 @@ def bucket_reduce_scatter_by_mb(
         bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
             in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
             to specify different sizes of the buckets.
-        filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
+        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
             only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
 
     Returns:
@@ -577,8 +577,8 @@ def process_collective_bucket(
     bucket_nodes: list[torch.fx.Node],
     fn_to_trace: Callable[..., list[torch.Tensor]],
     trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
-    insert_before: Optional[torch.fx.Node] = None,
-    wait_insertion_point: Optional[torch.fx.Node] = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     """
     Process a single bucket of collective operation nodes with flexible insertion control.
@@ -666,9 +666,9 @@ def process_collective_bucket(
 def merge_reduce_scatter_bucket(
     g: torch.fx.Graph,
     rs_nodes: list[torch.fx.Node],
-    mode: Optional[str] = None,
-    insert_before: Optional[torch.fx.Node] = None,
-    wait_insertion_point: Optional[torch.fx.Node] = None,
+    mode: str | None = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     # Validate bucket consistency
     rs0 = rs_nodes[0]
@@ -716,9 +716,9 @@ def merge_reduce_scatter_bucket(
 def merge_all_gather_bucket(
     g: torch.fx.Graph,
     ag_nodes: list[torch.fx.Node],
-    mode: Optional[str] = None,
-    insert_before: Optional[torch.fx.Node] = None,
-    wait_insertion_point: Optional[torch.fx.Node] = None,
+    mode: str | None = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     from torch.distributed.distributed_c10d import _resolve_process_group
 
@@ -764,7 +764,7 @@ def merge_all_gather_bucket(
 def merge_reduce_scatter(
     gm: torch.fx.GraphModule,
     rs_buckets: list[list[torch.fx.Node]],
-    mode: Optional[str] = None,
+    mode: str | None = None,
 ) -> None:
     """
     Merges specified buckets of reduce_scatter to joint reduce_scatter.
@@ -788,7 +788,7 @@ def merge_reduce_scatter(
 def merge_all_gather(
     gm: torch.fx.GraphModule,
     ag_buckets: list[list[torch.fx.Node]],
-    mode: Optional[str] = None,
+    mode: str | None = None,
 ) -> None:
     """
     Merges specified buckets of all_gather to joint all_gather.
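For illustration, a usage sketch of the `bucket_all_gather` entry point above with a custom cap callback. This assumes the function is importable from `torch._inductor.fx_passes.bucketing`, as the lazy imports in the hunk suggest; `my_cap_mb` and the traced function are made up:

import torch
import torch.fx
from torch._inductor.fx_passes.bucketing import bucket_all_gather  # assumed location

def my_cap_mb(bucket_idx: int) -> float:
    # Hypothetical policy: keep the first bucket small (the first
    # all_gather is usually exposed), then allow larger buckets.
    return 1.0 if bucket_idx == 0 else 100.0

def f(x: torch.Tensor) -> torch.Tensor:
    return x + 1

gm = torch.fx.symbolic_trace(f)
# Mutates gm in place and returns None; a no-op here, since this toy
# graph contains no all_gather nodes.
bucket_all_gather(gm, bucket_cap_mb_by_bucket_idx=my_cap_mb)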
@@ -7,7 +7,7 @@ import operator
 from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, Callable, cast, Union
 
 import torch
 import torch.fx as fx
@@ -40,8 +40,8 @@ def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None:
 def call_function(
     graph: fx.Graph,
     target: Union[str, Callable[..., Any]],
-    args: Optional[tuple[fx.node.Argument, ...]] = None,
-    kwargs: Optional[dict[str, fx.node.Argument]] = None,
+    args: tuple[fx.node.Argument, ...] | None = None,
+    kwargs: dict[str, fx.node.Argument] | None = None,
 ) -> fx.Node:
     # We accept target as a str to avoid typing error as the type of
     # a node.target is Union[str, Callable[..., Any]].
@@ -70,7 +70,7 @@ class CommBlock:
     outputs: OrderedSet[fx.Node]
 
 
-def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
+def get_comm_block(comm_node: fx.Node) -> CommBlock | None:
     """
     Given a collective node (e.g., allreduce), find out all the nodes belong to
     this communication.
@@ -150,7 +150,7 @@ def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
 def get_all_comm_blocks(
     graph: fx.Graph,
     comm_ops: tuple[torch._ops.OpOverload, ...],
-    comm_filter: Optional[Callable[..., bool]] = None,
+    comm_filter: Callable[..., bool] | None = None,
 ) -> list[CommBlock]:
     if comm_filter is None:
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Optional
+from typing import Callable
 
 import torch
 from torch._inductor.fx_passes.bucketing import (
@@ -55,15 +55,15 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
 
 def bucket_fsdp_all_gather(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP all_gather ops.
 
     Attributes:
         gm (torch.fx.GraphModule): Graph module of the graph.
-        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
             takes in bucket id and returns size of a bucket in megabytes.
     """
     if bucket_cap_mb_by_bucket_idx is None:
@@ -85,15 +85,15 @@ def bucket_fsdp_all_gather(
 
 def bucket_fsdp_reduce_scatter(
     gm: torch.fx.GraphModule,
-    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
-    mode: Optional[str] = None,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP reduce_scatter ops.
 
     Attributes:
         gm (torch.fx.GraphModule): Graph module of the graph.
-        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
+        bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that
            takes in bucket idx and returns size of a bucket in megabytes. By default
            torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used.
@@ -4,7 +4,7 @@ import logging
 import operator
 from collections import OrderedDict
 from collections.abc import Iterable, Iterator
-from typing import Any, Optional
+from typing import Any
 
 import torch
 from torch._dynamo.utils import counters, is_node_meta_valid
@@ -185,9 +185,7 @@ class PostGradBatchLinearFusion(BatchFusion):
             and isinstance(input_shapes[1], int)
         )
 
-    def match(
-        self, node: torch.fx.Node
-    ) -> Optional[tuple[str, int, int, int, bool, str]]:
+    def match(self, node: torch.fx.Node) -> tuple[str, int, int, int, bool, str] | None:
         if CallFunctionVarArgs(aten.mm).match(node):
             input_m, weight_m = node.args
             bias_m = None
@@ -325,7 +323,7 @@ class GroupLinearFusion(GroupFusion):
         )
     )
 
-    def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool]]:
+    def match(self, node: torch.fx.Node) -> tuple[str, bool] | None:
         if CallFunctionVarArgs(aten.mm.default).match(
             node
         ) and self._mm_node_can_be_fused(node):
@@ -493,7 +491,7 @@ class BatchLinearLHSFusion(BatchFusion):
     We have a separate pass to eliminate contiguous transpose in a generic way.
     """
 
-    def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool, Any]]:
+    def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None:
         if CallFunctionVarArgs(torch.nn.functional.linear).match(
             node
         ) and is_linear_node_can_be_fused(node):
@@ -2,7 +2,7 @@ import itertools
 import logging
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Callable, Optional, Union
+from typing import Callable, Union
 
 import torch
 import torch.fx as fx
@@ -154,7 +154,7 @@ def device_filter(device: torch.device) -> bool:
 def build_memory_profile(
     graph: fx.Graph,
     is_releasable: Callable[[fx.Node], bool],
-    size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
+    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
 ) -> list[int]:
     """
     Function to estimate the memory profile of an input FX graph.
@@ -216,7 +216,7 @@ def build_memory_profile(
 def get_fwd_bwd_interactions(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
-    size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
+    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
 ) -> tuple[int, OrderedSet[str]]:
     """
     Analyze the interactions between the forward (fwd) and backward (bwd) graphs
@@ -325,8 +325,8 @@ class MemoryTracker:
     def __init__(
         self,
         graph: fx.Graph,
-        is_releasable: Optional[Callable[[fx.Node], bool]] = None,
-        device_filter: Optional[Callable[[torch.device], bool]] = None,
+        is_releasable: Callable[[fx.Node], bool] | None = None,
+        device_filter: Callable[[torch.device], bool] | None = None,
     ):
         """
         Initialize memory tracker for alternative scheduling of the given graph.
@@ -4,7 +4,7 @@ import operator
 from collections import defaultdict
 from dataclasses import dataclass, field
 from math import prod
-from typing import Any, cast, Optional
+from typing import Any, cast
 
 import torch
 from torch.utils._ordered_set import OrderedSet
@@ -374,8 +374,8 @@ class _Matmul:
     arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False)
     A_node: torch.fx.Node
     B_node: torch.fx.Node
-    pre_mm_reshape: Optional[torch.fx.Node]
-    post_mm_reshape: Optional[torch.fx.Node]
+    pre_mm_reshape: torch.fx.Node | None
+    post_mm_reshape: torch.fx.Node | None
 
     def __post_init__(self):
         assert len(self.nodes) in (1, 3)
@@ -450,12 +450,12 @@ class _Matmul:
 class _ScaledMatmul(_Matmul):
     A_scale_node: torch.fx.Node
     B_scale_node: torch.fx.Node
-    bias_node: Optional[torch.fx.Node]
-    result_scale_node: Optional[torch.fx.Node]
-    out_dtype: Optional[torch.dtype]
+    bias_node: torch.fx.Node | None
+    result_scale_node: torch.fx.Node | None
+    out_dtype: torch.dtype | None
     use_fast_accum: bool
-    pre_mm_reshape: Optional[torch.fx.Node]
-    post_mm_reshape: Optional[torch.fx.Node]
+    pre_mm_reshape: torch.fx.Node | None
+    post_mm_reshape: torch.fx.Node | None
 
     def __post_init__(self):
         super().__post_init__()
@@ -763,7 +763,7 @@ def _scatter_dim_after_reshape(
     return 0 if leading_dims_collapsed else 1
 
 
-def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
+def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None:
     """
     Returns producer matmul node if found, otherwise returns None.
     """
@@ -6,7 +6,7 @@ import sys
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 import torch
 import torch.fx as fx
@@ -42,7 +42,7 @@ def get_group_name(n: fx.Node) -> str:
     return kwargs["group_name"]
 
 
-def get_custom_estimation(n: fx.Node) -> Optional[float]:
+def get_custom_estimation(n: fx.Node) -> float | None:
     runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
     if runtime_estimation == "default":
         return None
@@ -51,7 +51,7 @@ def get_custom_estimation(n: fx.Node) -> Optional[float]:
     return runtime_estimation(n)
 
 
-def estimate_collective_time(n: fx.Node, override_size: Optional[int] = None) -> float:
+def estimate_collective_time(n: fx.Node, override_size: int | None = None) -> float:
     """Estimate the runtime of a collective operation, optionally with an overridden size."""
     if (est := get_custom_estimation(n)) is not None:
         return est
@@ -82,7 +82,7 @@ def is_compute_node(n: fx.Node) -> bool:
     )
 
 
-def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]:
+def get_hint(x: Union[int, torch.SymInt]) -> int | None:
     if isinstance(x, int):
         return x
     assert isinstance(x, torch.SymInt)
@@ -100,7 +100,7 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
     )
 
 
-def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, Optional[str]]:
+def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, str | None]:
     assert is_compute_node(n)
 
     from torch._dynamo.testing import rand_strided
@@ -115,7 +115,7 @@ def benchmark_node_with_cache_key(n: fx.Node) -> tuple[float, Optional[str]]:
 
     key = f"{str(n.target)}: "
 
-    def to_real(t: torch.Tensor) -> Optional[torch.Tensor]:
+    def to_real(t: torch.Tensor) -> torch.Tensor | None:
         shape = [get_hint(dim) for dim in t.shape]
         stride = [get_hint(s) for s in t.stride()]
 
@@ -177,7 +177,7 @@ class CollectiveInfo:
     size_bytes: int
     estimated_time_ms: float
     exposed_time_ms: float  # How much of this collective is still exposed
-    hiding_node: Optional[fx.Node] = None  # Node that hides this collective
+    hiding_node: fx.Node | None = None  # Node that hides this collective
 
     @property
     def is_exposed(self) -> bool:
@@ -189,8 +189,8 @@ class CollBucket:
     """Track information about a bucket of collectives."""
 
    collectives: list[fx.Node]  # Original collective starts
-    bucketed_start: Optional[fx.Node] = None  # After bucketing
-    bucketed_wait: Optional[fx.Node] = None  # After bucketing
+    bucketed_start: fx.Node | None = None  # After bucketing
+    bucketed_wait: fx.Node | None = None  # After bucketing
     total_bytes: int = 0
 
 
@@ -342,7 +342,7 @@ class OverlapScheduler:
         log.info(
             "Overlap scheduling: Aligning runtime estimations across all distributed ranks"
         )
-        runtime_estimations_keys: list[Optional[str]] = []
+        runtime_estimations_keys: list[str | None] = []
         runtime_estimations: list[float] = []
         for n in self.compute_nodes:
             val, key = benchmark_node_with_cache_key(n)
@@ -670,8 +670,8 @@ class OverlapScheduler:
             available_compute_time -= overlap_amount
 
     def _find_schedulable_path(
-        self, target: fx.Node, curr_compute_node: Optional[fx.Node]
-    ) -> Optional[OrderedSet[fx.Node]]:
+        self, target: fx.Node, curr_compute_node: fx.Node | None
+    ) -> OrderedSet[fx.Node] | None:
         """Find path to target by collecting unscheduled dependencies."""
 
         # TODO - following path faster than doing set difference here
@@ -725,7 +725,7 @@ class OverlapScheduler:
         return self.collective_info[oldest_start].wait_node
 
     def _wait_is_hidden(
-        self, wait_node: fx.Node, compute_node: Optional[fx.Node] = None
+        self, wait_node: fx.Node, compute_node: fx.Node | None = None
     ) -> bool:
         assert is_wait_tensor(wait_node)
         info = self.collective_info[self.wait_to_start[wait_node]]
@@ -821,7 +821,7 @@ class OverlapScheduler:
 
         used_compute_nodes: OrderedSet[fx.Node] = OrderedSet()
 
-        def could_be_hidden(start: fx.Node) -> Optional[fx.Node]:
+        def could_be_hidden(start: fx.Node) -> fx.Node | None:
            for compute_node in self.compute_nodes:
                if limit_coll_per_compute and compute_node in used_compute_nodes:
                    continue
@@ -3,7 +3,7 @@ import itertools
 import operator
 import typing
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 import torch
 import torch._inductor.runtime.runtime_utils
@@ -83,12 +83,10 @@ def check_dtype(a: Tensor, b: Tensor) -> bool:
     return a.is_floating_point() and b.is_floating_point()
 
 
-def should_pad_common(
-    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
-) -> bool:
+def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -> bool:
     # It's fine we have symbolic shapes or strides as long as they
     # have hints. Later, we will make sure we only pad non-symbolic dimensions.
-    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
+    def valid_shape_and_stride(t: Tensor | None) -> bool:
         if t is None:
             return True
 
@@ -153,7 +151,7 @@ def should_pad_addmm(match: Match) -> bool:
 
 
 def pad_addmm(
-    input: Optional[Tensor],
+    input: Tensor | None,
     mat1: Tensor,
     mat2: Tensor,
     m_padded_length: int,
@@ -195,7 +193,7 @@ def pad_addmm(
 
 
 def addmm_replace(
-    input: Optional[Tensor],
+    input: Tensor | None,
     mat1: Tensor,
     mat2: Tensor,
     beta: float = 1.0,
@@ -275,7 +273,7 @@ def should_pad_bench_key(
     mat1: Tensor,
     mat2: Tensor,
     op: torch._ops.OpOverloadPacket,
-    input: Optional[Tensor] = None,
+    input: Tensor | None = None,
     is_base_time_key: bool = False,
 ) -> str:
     def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
@@ -285,7 +283,7 @@ def should_pad_bench_key(
         None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
     )
 
-    def fmt_pad(name: str) -> Optional[str]:
+    def fmt_pad(name: str) -> str | None:
         if is_base_time_key:
             return None
         return f"exclude_pad:{should_exclude_padding_time(match, name)}"
@@ -412,7 +410,7 @@ def _should_pad_bench(
     mat1: Tensor,
     mat2: Tensor,
     op: torch._ops.OpOverloadPacket,
-    input: Optional[Tensor] = None,
+    input: Tensor | None = None,
 ) -> bool:
     do_bench = get_do_bench()
 
@@ -681,10 +679,10 @@ def run_autoheuristic(
     ori_time: float,
     ori_time_key: str,
     key: str,
-) -> Optional[bool]:
+) -> bool | None:
     def feedback_fn(
         choice: str,
-    ) -> Optional[float]:
+    ) -> float | None:
         if choice == orig_choice:
             return do_bench(orig_bench_fn)
         elif choice == pad_choice:
@@ -5,7 +5,7 @@ import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
@@ -1726,7 +1726,7 @@ class ConstructorMoverPass:
 
         return False
 
-    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
+    def get_node_device(self, node: fx.Node) -> torch.device | None:
         """
        Get the device of a node.
        """
@@ -5,7 +5,6 @@ import itertools
 import logging
 import types
 from collections.abc import Sequence
-from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -191,8 +190,8 @@ def _get_pass_name_func(p):
 def _run_pre_dispatch_passes(
     gm: torch.fx.GraphModule,
     example_inputs: Sequence[object] = (),
-    add_passes: Optional[str] = None,
-    remove_passes: Optional[str] = None,
+    add_passes: str | None = None,
+    remove_passes: str | None = None,
 ) -> None:
     # order matters
     default_pass_list = [
@@ -278,8 +277,8 @@ def _run_pre_dispatch_passes(
 def pre_grad_passes(
     gm: torch.fx.GraphModule,
     example_inputs: Sequence[object] = (),
-    add_passes: Optional[str] = None,
-    remove_passes: Optional[str] = None,
+    add_passes: str | None = None,
+    remove_passes: str | None = None,
 ) -> torch.fx.GraphModule:
     """
     Apply passes on the input FX graph using Torch IR.
@@ -763,7 +762,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
 # ---->
 # Y2 = (W * X^T + bias.unsqueeze(-1))^T
 def linear_transpose(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
 ) -> torch.Tensor:
     if bias is None:
         return torch.matmul(weight, input.transpose(-1, -2))
@@ -860,7 +859,7 @@ def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
 # ---->
 # Y2 = X1.transpose(-1, -2) * W1^T + bias1
 def transpose_linear(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
 ) -> torch.Tensor:
     if bias is None:
         return torch.matmul(input.transpose(-1, -2), weight.t())
@@ -679,7 +679,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
        from torch._higher_order_ops.auto_functionalize import get_mutable_args
 
        tensors_to_clone, _ = get_mutable_args(_mutable_op)
-        # Don't try to reinplace Optional[Tensor] args that are None.
+        # Don't try to reinplace Tensor | None args that are None.
        tensors_to_clone = [
            t for t in tensors_to_clone if node.kwargs[t] is not None
        ]
@@ -5,7 +5,7 @@ import operator
 import os
 from collections import defaultdict
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 from typing_extensions import TypeAlias
 
 import torch
@@ -38,10 +38,10 @@ log = logging.getLogger(__name__)
 
 _Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...]
 _TransformParam: TypeAlias = tuple[
-    Optional[_Arguments],
-    Optional[_Arguments],
-    Optional[_Arguments],
-    Optional[_Arguments],
+    _Arguments | None,
+    _Arguments | None,
+    _Arguments | None,
+    _Arguments | None,
 ]
 _Range: TypeAlias = tuple[int, int]
 
@@ -167,7 +167,7 @@ def _get_dim(node: Any):
 def normalize_split_base(
     match: Match,
     _get_split_args: Callable[
-        [torch.fx.Node], tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]]
+        [torch.fx.Node], tuple[torch.fx.Node | None, Any | None, int | None]
     ],
 ):
     """
@@ -802,7 +802,7 @@ class SplitCatSimplifier:
        split_sections,
        next_users,
        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[_Range]]:
+    ) -> list[_Range] | None:
        ranges = OrderedSet[Any]()
        for user_inputs in user_inputs_list:
            ranges.update(u for u in user_inputs if isinstance(u, tuple))
@@ -848,7 +848,7 @@ class SplitCatSimplifier:
        split_node: torch.fx.Node,
        next_users: list[torch.fx.Node],
        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[list[_TransformParam]]]:
+    ) -> list[list[_TransformParam]] | None:
        """
        Figure out what transforms are needed for each input to each cat node.
 
@@ -1178,7 +1178,7 @@ class UnbindCatRemover(SplitCatSimplifier):
        split_sections: list[int],
        next_users: list[torch.fx.Node],
        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[_Range]]:
+    ) -> list[_Range] | None:
        simplified_split_ranges = super().get_simplified_split_ranges(
            split_sections, next_users, user_inputs_list
        )
@@ -1191,7 +1191,7 @@ class UnbindCatRemover(SplitCatSimplifier):
        split_node: torch.fx.Node,
        next_users: list[torch.fx.Node],
        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
-    ) -> Optional[list[list[_TransformParam]]]:
+    ) -> list[list[_TransformParam]] | None:
        """
        Figure out what transforms are needed for each input to each cat node.
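One reason a rewrite like this can be applied mechanically: on Python 3.10+ the two spellings compare equal at runtime, even though they are distinct objects (`typing.Union` vs. `types.UnionType`). A quick sanity check, assuming Python 3.10+:

import types
from typing import Optional, Union

# PEP 604 unions compare equal to their typing counterparts.
assert Optional[int] == int | None
assert Union[int, str] == int | str

# The PEP 604 form is its own runtime type.
assert isinstance(int | None, types.UnionType)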