PEP585 update - torch/distributed (#145164)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-20 14:50:01 -08:00
committed by PyTorch MergeBot
parent c6986ca2e1
commit 00ffeca1b1
79 changed files with 805 additions and 860 deletions

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
from torch.fx.node import map_aggregate
@ -92,7 +92,7 @@ class TensorChunkSpec:
@staticmethod
def from_dict(
chunk_dims: Dict[str, int],
chunk_dims: dict[str, int],
):
"""
A helper for creating a dictionary of `TensorChunkSpec` from a
@ -243,11 +243,11 @@ def _shard_dict_of_args(
def split_args_kwargs_into_chunks(
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
kwargs: Optional[dict[str, Any]],
chunks: int,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> tuple[List[Tuple], List[Dict]]:
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
"""
Given a sequence of args and kwargs, split them into a number of chunks
according to their respective chunking specs.
@ -347,7 +347,7 @@ def split_args_kwargs_into_chunks(
def merge_chunks(
chunks: List[Any],
chunks: list[Any],
chunk_spec,
):
"""