[3/N] Import Callable from collections.abc in torch/distributed (#164104)

This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is the follow-up of #164054.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
This commit is contained in:
Yuanyuan Chen
2025-09-30 00:28:50 +00:00
committed by PyTorch MergeBot
parent cee4e36f9a
commit da003d7b95
74 changed files with 154 additions and 121 deletions

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
import uuid
from collections import OrderedDict
from collections.abc import Callable
from functools import wraps
from typing import Callable, Generic, Optional, Protocol
from typing_extensions import Concatenate, ParamSpec, TypeVar
from typing import Concatenate, Generic, Optional, Protocol
from typing_extensions import ParamSpec, TypeVar
import torch
import torch.nn as nn

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Callable, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -40,6 +40,8 @@ from .contract import _get_registry, contract
if TYPE_CHECKING:
from collections.abc import Callable
from torch.distributed.tensor import Shard

View File

@ -36,8 +36,8 @@ Functions for manipulating IntTuples
from functools import reduce
from itertools import chain
from typing import Optional, Union
from typing_extensions import TypeAlias, TypeIs
from typing import Optional, TypeAlias, Union
from typing_extensions import TypeIs
from .typing import Integer

View File

@ -36,8 +36,8 @@ of lexicographic instead of co-lexicographic as implemented in the original layo
"""
from itertools import chain
from typing import Optional, Union
from typing_extensions import TypeAlias, TypeIs
from typing import Optional, TypeAlias, Union
from typing_extensions import TypeIs
from .int_tuple import (
crd2idx,

View File

@ -8,7 +8,7 @@ import warnings
import weakref
from dataclasses import dataclass
from functools import reduce
from typing import Callable, cast, Optional, TYPE_CHECKING
from typing import cast, Optional, TYPE_CHECKING
from typing_extensions import deprecated
import torch
@ -41,7 +41,7 @@ from .utils import (
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from torch.distributed._shard.metadata import ShardMetadata

View File

@ -2,8 +2,9 @@
import functools
import operator
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING
from typing import TYPE_CHECKING
import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta

View File

@ -3,8 +3,8 @@ import copy
import io
import math
import weakref
from collections.abc import Mapping, MutableMapping
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
from collections.abc import Callable, Mapping, MutableMapping
from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
import torch.cuda._pin_memory_utils as pin_memory_utils

View File

@ -4,12 +4,12 @@ import math
import os
import socket
import uuid
from collections.abc import Generator
from collections.abc import Callable, Generator
from contextlib import contextmanager
from datetime import timedelta
from enum import Enum
from functools import partial
from typing import Any, Callable, Literal
from typing import Any, Literal
import torch
import torch.distributed._functional_collectives as funcol

View File

@ -1,7 +1,8 @@
from collections.abc import Callable
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
from typing import Any, NamedTuple, Optional, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch

View File

@ -2,11 +2,12 @@ import math
import os
import re
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
import torch

View File

@ -2,9 +2,9 @@
import operator
import pickle
from collections import defaultdict
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from itertools import chain
from typing import Any, Callable, no_type_check, TYPE_CHECKING
from typing import Any, no_type_check, TYPE_CHECKING
import torch
import torch.nn as nn

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import warnings
import weakref
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional
import torch
from torch.autograd.graph import register_multi_grad_hook

View File

@ -2,7 +2,7 @@
import math
import os
from collections import defaultdict
from typing import Any, Callable
from typing import Any, TYPE_CHECKING
from typing_extensions import Self
import torch
@ -16,6 +16,10 @@ from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.flop_counter import flop_registry
if TYPE_CHECKING:
from collections.abc import Callable
aten = torch.ops.aten
# This value is hard-coded here:

View File

@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from enum import auto, Enum
from functools import partial
from typing import Any, Callable, Optional
from typing import Any, Optional
import torch
import torch.nn as nn

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import weakref
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
import torch
import torch.distributed as dist

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, cast
from collections.abc import Callable
from typing import Any, cast
import torch
import torch.distributed as dist

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, no_type_check
from typing import Any, no_type_check
import torch
import torch.distributed as dist

View File

@ -9,7 +9,8 @@
import functools
import logging
from typing import Any, Callable, TypeVar
from collections.abc import Callable
from typing import Any, TypeVar
from typing_extensions import ParamSpec
import torch

View File

@ -6,7 +6,8 @@ of checkpointer instances by automatically handling component initialization
and configuration with reasonable defaults.
"""
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
import torch.distributed as dist

View File

@ -1,11 +1,12 @@
import logging
import os
import traceback
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from multiprocessing.connection import Connection
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch.multiprocessing as mp
from torch.multiprocessing.spawn import ProcessExitedException

View File

@ -7,8 +7,7 @@ saving and loading.
"""
from dataclasses import dataclass
from typing import Any
from typing_extensions import TypeAlias
from typing import Any, TypeAlias
# Type alias for state dictionaries used in checkpointing

View File

@ -1,11 +1,11 @@
import logging
import pickle
import time
from collections.abc import Generator
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Callable, cast, Optional, TypeVar, Union
from typing import cast, Optional, TypeVar, Union
import torch
from torch.distributed import ProcessGroup, Work

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Collection, Mapping, MutableMapping
from typing import Callable, cast, Optional, TypeVar, Union
from collections.abc import Callable, Collection, Mapping, MutableMapping
from typing import cast, Optional, TypeVar, Union
import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor

View File

@ -11,13 +11,13 @@ import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Iterator, Sequence
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from io import UnsupportedOperation
from pathlib import Path
from typing import Any, Callable, cast, Final, IO, Optional, Union
from typing import Any, cast, Final, IO, Optional, Union
# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer

View File

@ -2,7 +2,8 @@
import functools
import logging
import time
from typing import Any, Callable, TypeVar
from collections.abc import Callable
from typing import Any, TypeVar
from typing_extensions import ParamSpec
from uuid import uuid4

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import io
from typing import Any, Callable, cast
from collections.abc import Callable
from typing import Any, cast
import torch
import torch.distributed as dist

View File

@ -3,10 +3,10 @@ import contextlib
import functools
import gc
import warnings
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import Any, Callable, cast, no_type_check, Optional, Union
from typing import Any, cast, no_type_check, Optional, Union
import torch
import torch.distributed as dist

View File

@ -5,11 +5,11 @@ import io
import itertools
import os
import warnings
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from functools import wraps
from pstats import Stats
from typing import Any, Callable, cast, Optional, TypeVar, Union
from typing import Any, cast, Optional, TypeVar, Union
import torch
import torch.distributed as dist

View File

@ -13,11 +13,11 @@ import importlib
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Callable, Iterable
import torch
import torch.distributed as dist

View File

@ -14,8 +14,9 @@ import sys
import time
import warnings
from collections import namedtuple
from collections.abc import Callable
from datetime import timedelta
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch

View File

@ -15,10 +15,11 @@ import time
import traceback
import warnings
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util

View File

@ -6,7 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable
from collections.abc import Callable
from torch.distributed.elastic.utils.logging import get_logger

View File

@ -63,7 +63,8 @@ was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
"""
from typing import Callable, Optional, Union
from collections.abc import Callable
from typing import Optional, Union
from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
_validate_full_rank,

View File

@ -19,12 +19,13 @@ import tempfile
import threading
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record

View File

@ -54,11 +54,12 @@ import os
import signal
import socket
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from torch.distributed.elastic.utils.logging import get_logger

View File

@ -7,8 +7,9 @@
import socket
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Optional
from typing import Any, ClassVar, Optional
from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port

View File

@ -14,10 +14,11 @@ import threading
import time
import weakref
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Optional
from typing import Any, Optional
import torch.distributed as dist
from torch.distributed import Store

View File

@ -11,9 +11,10 @@ import re
import socket
import time
import weakref
from collections.abc import Callable
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
__all__ = ["parse_rendezvous_endpoint"]

View File

@ -13,7 +13,8 @@ import signal
import sys
import threading
import time
from typing import Callable, Optional, TypeVar
from collections.abc import Callable
from typing import Optional, TypeVar
from typing_extensions import ParamSpec
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
from collections.abc import Iterator
from typing import Callable, TypeVar
from collections.abc import Callable, Iterator
from typing import TypeVar
from typing_extensions import Self

View File

@ -7,10 +7,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Optional
from typing import Optional
import torch

View File

@ -7,11 +7,11 @@ import logging
import traceback
import warnings
import weakref
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable
from enum import auto, Enum
from functools import partial
from itertools import chain
from typing import Any, Callable, cast, no_type_check, Optional, TYPE_CHECKING
from typing import Any, cast, no_type_check, Optional, TYPE_CHECKING
import torch
import torch.distributed as dist

View File

@ -4,10 +4,10 @@ import functools
import logging
import os
import warnings
from collections.abc import Generator, Iterator, Sequence
from collections.abc import Callable, Generator, Iterator, Sequence
from enum import auto, Enum
from itertools import accumulate, chain
from typing import Any, Callable, cast, NamedTuple, no_type_check, Optional, Union
from typing import Any, cast, NamedTuple, no_type_check, Optional, Union
import torch
import torch.distributed as dist

View File

@ -1,7 +1,7 @@
import math
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from itertools import chain
from typing import Any, Callable, cast, NamedTuple, Optional, Union
from typing import Any, cast, NamedTuple, Optional, Union
import torch
import torch.distributed as dist

View File

@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
import inspect
import itertools
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, Callable, cast, Optional
from typing import Any, cast, Optional
import torch
import torch.nn as nn

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, Callable, cast, NamedTuple, Optional
from collections.abc import Callable
from typing import Any, cast, NamedTuple, Optional
import torch
import torch.distributed as dist

View File

@ -2,8 +2,8 @@
# mypy: allow-untyped-defs
import functools
import logging
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING
from collections.abc import Callable, Sequence
from typing import Any, Optional, TYPE_CHECKING
import torch
import torch.nn as nn

View File

@ -4,16 +4,7 @@
from __future__ import annotations
import functools
from typing import (
Any,
Callable,
cast,
NoReturn,
Optional,
overload,
TYPE_CHECKING,
Union,
)
from typing import Any, cast, NoReturn, Optional, overload, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch
@ -36,7 +27,7 @@ from ._fsdp_state import _get_module_fsdp_state, FSDPState
if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from torch.distributed.tensor import DeviceMesh, Shard

View File

@ -3,8 +3,8 @@ import collections
import itertools
import os
import warnings
from collections.abc import Generator, Iterable, Iterator
from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union
from collections.abc import Callable, Generator, Iterable, Iterator
from typing import Any, no_type_check, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import functools
import logging
from collections.abc import Callable
from enum import auto, Enum
from typing import Any, Callable, no_type_check, Optional
from typing import Any, no_type_check, Optional
import torch
import torch.distributed as dist

View File

@ -3,8 +3,8 @@ import contextlib
import logging
import math
import warnings
from collections.abc import Generator, Iterator
from typing import Any, Callable, cast, no_type_check
from collections.abc import Callable, Generator, Iterator
from typing import Any, cast, no_type_check
import torch
import torch.distributed as dist

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import functools
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, NamedTuple, Optional
import torch
import torch.nn as nn

View File

@ -3,8 +3,9 @@ import collections
import functools
import inspect
import warnings
from collections.abc import Callable
from functools import partial
from typing import Any, Callable, Union
from typing import Any, Union
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (

View File

@ -6,10 +6,10 @@ import functools
import math
import traceback
import warnings
from collections.abc import Generator, Iterable, Iterator
from collections.abc import Callable, Generator, Iterable, Iterator
from contextlib import contextmanager
from enum import auto, Enum
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist

View File

@ -7,8 +7,8 @@
import contextlib
import copy
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Sequence
from typing import Any, Callable, cast, Optional, Union
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import Any, cast, Optional, Union
import torch.nn as nn

View File

@ -9,8 +9,9 @@
import os
import sys
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.distributed.elastic.rendezvous.registry as rdzv_registry

View File

@ -4,8 +4,8 @@ import collections
import io
import sys
import types
from collections.abc import Iterator, Mapping
from typing import Any, Callable, Optional, TypeVar, Union
from collections.abc import Callable, Iterator, Mapping
from typing import Any, Optional, TypeVar, Union
from typing_extensions import Self
import torch

View File

@ -1,8 +1,8 @@
import logging
import warnings
from collections.abc import Collection, Mapping
from collections.abc import Callable, Collection, Mapping
from copy import deepcopy
from typing import Any, Callable, Optional, overload, Union
from typing import Any, Optional, overload, Union
import torch
import torch.nn as nn

View File

@ -11,8 +11,9 @@ import enum
import inspect
import io
import logging
from collections.abc import Callable
from itertools import chain
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload
from collections.abc import Callable
from typing import Any, overload
import torch
from torch.distributed.algorithms.join import Joinable, JoinHook

View File

@ -4,10 +4,11 @@ import copy
import logging
import operator
from collections import defaultdict
from collections.abc import Callable
from enum import Enum
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.fx as fx

View File

@ -8,9 +8,10 @@ import logging
import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import Callable
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, NamedTuple, Optional, Union
from typing import Any, NamedTuple, Optional, Union
import torch
import torch.distributed as dist

View File

@ -3,7 +3,8 @@
import logging
import operator
from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Optional, Union
from collections.abc import Callable
from typing import Any, cast, Optional, Union
import torch
import torch.distributed as dist

View File

@ -9,9 +9,9 @@ except ImportError as e:
import numbers
import os
import sys
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from datetime import timedelta
from typing import Callable, Optional
from typing import Optional
from torch.distributed import FileStore, Store, TCPStore

View File

@ -373,8 +373,9 @@ import os
import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from collections.abc import Callable
from importlib import metadata
from typing import Callable, Optional, Union
from typing import Optional, Union
import torch
from torch.distributed.argparse_util import check_env, env

View File

@ -3,8 +3,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import inspect
import warnings
from collections.abc import Sequence
from typing import Any, Callable, cast, Optional
from collections.abc import Callable, Sequence
from typing import Any, cast, Optional
from typing_extensions import deprecated
import torch

View File

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import Callable, cast, Optional, Union
from typing import cast, Optional, Union
import torch
from torch import Tensor

View File

@ -3,8 +3,8 @@
import functools
import itertools
import operator
from collections.abc import Iterable, Sequence
from typing import Callable, cast, Optional, TypeVar, Union
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch

View File

@ -1,9 +1,9 @@
# mypy: allow-untyped-defs
import threading
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Optional, Union
from typing import cast, Optional, Union
import torch
from torch._ops import OpOverload

View File

@ -5,7 +5,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.p
import argparse
import os
from typing import Callable, Union
from typing import TYPE_CHECKING, Union
import torch
import torch.nn as nn
@ -26,6 +26,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.utils.checkpoint import checkpoint
if TYPE_CHECKING:
from collections.abc import Callable
def get_device_type() -> str:
device_type = "cpu"
if torch.accelerator.device_count() >= 4:

View File

@ -3,10 +3,10 @@ import itertools
import logging
import types
from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Callable, Generator
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, Callable, Optional, Protocol
from typing import Any, Optional, Protocol
import torch
import torch.distributed as dist

View File

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
from collections.abc import Sequence
from typing import Callable, Optional, Union
from collections.abc import Callable, Sequence
from typing import Optional, Union
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

View File

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from typing import Callable, Union
from typing import Union
import torch
from torch._ops import OpOverload

View File

@ -2,8 +2,8 @@
import dataclasses
import traceback
from collections import OrderedDict
from collections.abc import Container
from typing import Any, Callable, Optional, overload, TypeVar
from collections.abc import Callable, Container
from typing import Any, Optional, overload, TypeVar
import torch
import torch.distributed as dist