Codemod inductor/fx_passes from `Optional[X]` to `X | None` union syntax (#165606)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165606
Approved by: https://github.com/aorenste
ghstack dependencies: #165604, #165605
This commit is contained in:
Oguz Ulgen
2025-10-15 19:29:51 -07:00
committed by PyTorch MergeBot
parent ab6014a903
commit 5d0b22008d
12 changed files with 94 additions and 99 deletions

View File

@ -1,7 +1,7 @@
import collections
import logging
from collections import defaultdict
from typing import Any, Callable, Optional
from typing import Any, Callable
import torch
import torch.distributed as dist
@ -34,7 +34,7 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
return (group_name, reduce_op, dtype)
def bucket_key(node: torch.fx.Node) -> Optional[object]:
def bucket_key(node: torch.fx.Node) -> object | None:
if is_all_gather_into_tensor(node):
return _ag_group_key(node)
elif is_reduce_scatter_tensor(node):
@ -58,8 +58,8 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
def bucket_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@ -75,8 +75,8 @@ def bucket_all_gather(
def bucket_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@ -156,7 +156,7 @@ def greedy_bucket_collective_by_mb(
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_node: Callable[[torch.fx.Node], bool],
node_group_key: Callable[[torch.fx.Node], Any],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
"""
Bucketing adjacent collectives with equal node_group_key.
@ -234,7 +234,7 @@ def greedy_bucket_collective_by_mb(
def bucket_all_gather_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets,
@ -247,7 +247,7 @@ def bucket_all_gather_by_mb(
to specify different sizes of the buckets at the start,
as the first all_gather is usually exposed. The interface of `bucket_cap_mb_by_bucket_idx`
is that of the `bucket_cap_mb_by_bucket_idx_default` function, which is the default value for `bucket_cap_mb_by_bucket_idx`.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
only all_gather nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
Returns:
@ -266,7 +266,7 @@ def bucket_all_gather_by_mb(
def bucket_reduce_scatter_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
"""
Identifies all reduce_scatter nodes and groups them into buckets,
@ -277,7 +277,7 @@ def bucket_reduce_scatter_by_mb(
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
Returns:
@ -577,8 +577,8 @@ def process_collective_bucket(
bucket_nodes: list[torch.fx.Node],
fn_to_trace: Callable[..., list[torch.Tensor]],
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
"""
Process a single bucket of collective operation nodes with flexible insertion control.
@ -666,9 +666,9 @@ def process_collective_bucket(
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
mode: str | None = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
# Validate bucket consistency
rs0 = rs_nodes[0]
@ -716,9 +716,9 @@ def merge_reduce_scatter_bucket(
def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
mode: str | None = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
from torch.distributed.distributed_c10d import _resolve_process_group
@ -764,7 +764,7 @@ def merge_all_gather_bucket(
def merge_reduce_scatter(
gm: torch.fx.GraphModule,
rs_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
mode: str | None = None,
) -> None:
"""
Merges specified buckets of reduce_scatter to joint reduce_scatter.
@ -788,7 +788,7 @@ def merge_reduce_scatter(
def merge_all_gather(
gm: torch.fx.GraphModule,
ag_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
mode: str | None = None,
) -> None:
"""
Merges specified buckets of all_gather to joint all_gather.