mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check. `Callable` is imported from `collections.abc` instead of `typing`. This PR is the follow-up of #164054. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
cee4e36f9a
commit
da003d7b95
@ -1,9 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Callable, Generic, Optional, Protocol
|
||||
from typing_extensions import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate, Generic, Optional, Protocol
|
||||
from typing_extensions import ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -40,6 +40,8 @@ from .contract import _get_registry, contract
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from torch.distributed.tensor import Shard
|
||||
|
||||
|
||||
|
@ -36,8 +36,8 @@ Functions for manipulating IntTuples
|
||||
|
||||
from functools import reduce
|
||||
from itertools import chain
|
||||
from typing import Optional, Union
|
||||
from typing_extensions import TypeAlias, TypeIs
|
||||
from typing import Optional, TypeAlias, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from .typing import Integer
|
||||
|
||||
|
@ -36,8 +36,8 @@ of lexicographic instead of co-lexicographic as implemented in the original layo
|
||||
"""
|
||||
|
||||
from itertools import chain
|
||||
from typing import Optional, Union
|
||||
from typing_extensions import TypeAlias, TypeIs
|
||||
from typing import Optional, TypeAlias, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from .int_tuple import (
|
||||
crd2idx,
|
||||
|
@ -8,7 +8,7 @@ import warnings
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from typing import Callable, cast, Optional, TYPE_CHECKING
|
||||
from typing import cast, Optional, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -41,7 +41,7 @@ from .utils import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
|
||||
|
@ -2,8 +2,9 @@
|
||||
import functools
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
|
||||
|
@ -3,8 +3,8 @@ import copy
|
||||
import io
|
||||
import math
|
||||
import weakref
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
from collections.abc import Callable, Mapping, MutableMapping
|
||||
from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda._pin_memory_utils as pin_memory_utils
|
||||
|
@ -4,12 +4,12 @@ import math
|
||||
import os
|
||||
import socket
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
@ -1,7 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from enum import auto, Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
|
||||
|
||||
import torch
|
||||
|
@ -2,11 +2,12 @@ import math
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from enum import auto, Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
|
@ -2,9 +2,9 @@
|
||||
import operator
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, no_type_check, TYPE_CHECKING
|
||||
from typing import Any, no_type_check, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.autograd.graph import register_multi_grad_hook
|
||||
|
@ -2,7 +2,7 @@
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
@ -16,6 +16,10 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils.flop_counter import flop_registry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
# This value is hard-coded here:
|
||||
|
@ -1,10 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
from enum import auto, Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Callable, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, no_type_check
|
||||
from typing import Any, no_type_check
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -9,7 +9,8 @@
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, TypeVar
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
@ -6,7 +6,8 @@ of checkpointer instances by automatically handling component initialization
|
||||
and configuration with reasonable defaults.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
from torch.multiprocessing.spawn import ProcessExitedException
|
||||
|
@ -7,8 +7,7 @@ saving and loading.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing_extensions import TypeAlias
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
|
||||
# Type alias for state dictionaries used in checkpointing
|
||||
|
@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Callable, cast, Optional, TypeVar, Union
|
||||
from typing import cast, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, Work
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections.abc import Collection, Mapping, MutableMapping
|
||||
from typing import Callable, cast, Optional, TypeVar, Union
|
||||
from collections.abc import Callable, Collection, Mapping, MutableMapping
|
||||
from typing import cast, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
||||
|
@ -11,13 +11,13 @@ import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Iterable, Iterator, Sequence
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import UnsupportedOperation
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Final, IO, Optional, Union
|
||||
from typing import Any, cast, Final, IO, Optional, Union
|
||||
|
||||
# introduced as collections.abc.Buffer in Python 3.12
|
||||
from typing_extensions import Buffer
|
||||
|
@ -2,7 +2,8 @@
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, TypeVar
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
from uuid import uuid4
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import io
|
||||
from typing import Any, Callable, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -3,10 +3,10 @@ import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterable
|
||||
from collections.abc import Callable, Generator, Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, cast, no_type_check, Optional, Union
|
||||
from typing import Any, cast, no_type_check, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -5,11 +5,11 @@ import io
|
||||
import itertools
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from pstats import Stats
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -13,11 +13,11 @@ import importlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -14,8 +14,9 @@ import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
@ -15,10 +15,11 @@ import time
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch.distributed.elastic.rendezvous as rdzv
|
||||
import torch.distributed.elastic.utils.store as store_util
|
||||
|
@ -6,7 +6,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from torch.distributed.elastic.utils.logging import get_logger
|
||||
|
||||
|
@ -63,7 +63,8 @@ was launched a :class:`api.SubprocessContext` is returned. Both are specific
|
||||
implementations of the parent :class:`api.PContext` class.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
|
||||
_validate_full_rank,
|
||||
|
@ -19,12 +19,13 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntFlag
|
||||
from multiprocessing import synchronize
|
||||
from types import FrameType
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
|
||||
|
@ -54,11 +54,12 @@ import os
|
||||
import signal
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from string import Template
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from torch.distributed.elastic.utils.logging import get_logger
|
||||
|
@ -7,8 +7,9 @@
|
||||
|
||||
import socket
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, ClassVar, Optional
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
from torch.distributed import Store
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
|
@ -14,10 +14,11 @@ import threading
|
||||
import time
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import Store
|
||||
|
@ -11,9 +11,10 @@ import re
|
||||
import socket
|
||||
import time
|
||||
import weakref
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from threading import Event, Thread
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
|
||||
__all__ = ["parse_rendezvous_endpoint"]
|
||||
|
@ -13,7 +13,8 @@ import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional, TypeVar
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Callable, TypeVar
|
||||
from collections.abc import Callable, Iterator
|
||||
from typing import TypeVar
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
|
@ -7,10 +7,10 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -7,11 +7,11 @@ import logging
|
||||
import traceback
|
||||
import warnings
|
||||
import weakref
|
||||
from collections.abc import Generator, Iterable
|
||||
from collections.abc import Callable, Generator, Iterable
|
||||
from enum import auto, Enum
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, cast, no_type_check, Optional, TYPE_CHECKING
|
||||
from typing import Any, cast, no_type_check, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -4,10 +4,10 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterator, Sequence
|
||||
from collections.abc import Callable, Generator, Iterator, Sequence
|
||||
from enum import auto, Enum
|
||||
from itertools import accumulate, chain
|
||||
from typing import Any, Callable, cast, NamedTuple, no_type_check, Optional, Union
|
||||
from typing import Any, cast, NamedTuple, no_type_check, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,7 +1,7 @@
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional, Union
|
||||
from typing import Any, cast, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,10 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, cast, Optional
|
||||
from typing import Any, cast, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -2,8 +2,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -4,16 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
NoReturn,
|
||||
Optional,
|
||||
overload,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, cast, NoReturn, Optional, overload, TYPE_CHECKING, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -36,7 +27,7 @@ from ._fsdp_state import _get_module_fsdp_state, FSDPState
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from torch.distributed.tensor import DeviceMesh, Shard
|
||||
|
||||
|
@ -3,8 +3,8 @@ import collections
|
||||
import itertools
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterable, Iterator
|
||||
from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from typing import Any, no_type_check, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,8 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, no_type_check, Optional
|
||||
from typing import Any, no_type_check, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -3,8 +3,8 @@ import contextlib
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Any, Callable, cast, no_type_check
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from typing import Any, cast, no_type_check
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,8 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -3,8 +3,9 @@ import collections
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp._common_utils import (
|
||||
|
@ -6,10 +6,10 @@ import functools
|
||||
import math
|
||||
import traceback
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterable, Iterator
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -7,8 +7,8 @@
|
||||
import contextlib
|
||||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from collections.abc import Callable, Generator, Iterable, Sequence
|
||||
from typing import Any, cast, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -9,8 +9,9 @@
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
|
||||
|
@ -4,8 +4,8 @@ import collections
|
||||
import io
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from collections.abc import Callable, Iterator, Mapping
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
|
@ -1,8 +1,8 @@
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Collection, Mapping
|
||||
from collections.abc import Callable, Collection, Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, overload, Union
|
||||
from typing import Any, Optional, overload, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -11,8 +11,9 @@ import enum
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import enum
|
||||
from typing import Any, Callable, overload
|
||||
from collections.abc import Callable
|
||||
from typing import Any, overload
|
||||
|
||||
import torch
|
||||
from torch.distributed.algorithms.join import Joinable, JoinHook
|
||||
|
@ -4,10 +4,11 @@ import copy
|
||||
import logging
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from inspect import Parameter, Signature, signature
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
@ -8,9 +8,10 @@ import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, NamedTuple, Optional, Union
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -3,7 +3,8 @@
|
||||
import logging
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -9,9 +9,9 @@ except ImportError as e:
|
||||
import numbers
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
from datetime import timedelta
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from torch.distributed import FileStore, Store, TCPStore
|
||||
|
||||
|
@ -373,8 +373,9 @@ import os
|
||||
import sys
|
||||
import uuid
|
||||
from argparse import ArgumentParser, REMAINDER
|
||||
from collections.abc import Callable
|
||||
from importlib import metadata
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.argparse_util import check_env, env
|
||||
|
@ -3,8 +3,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, cast, Optional
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, cast, Optional
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
@ -1,8 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, cast, Optional, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -3,8 +3,8 @@
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Callable, cast, Optional, TypeVar, Union
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import cast, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
@ -1,9 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from typing import Callable, cast, Optional, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverload
|
||||
|
@ -5,7 +5,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.p
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Callable, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -26,6 +26,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def get_device_type() -> str:
|
||||
device_type = "cpu"
|
||||
if torch.accelerator.device_count() >= 4:
|
||||
|
@ -3,10 +3,10 @@ import itertools
|
||||
import logging
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, Optional, Protocol
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -1,8 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
|
@ -1,8 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Callable, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverload
|
||||
|
@ -2,8 +2,8 @@
|
||||
import dataclasses
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Container
|
||||
from typing import Any, Callable, Optional, overload, TypeVar
|
||||
from collections.abc import Callable, Container
|
||||
from typing import Any, Optional, overload, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
Reference in New Issue
Block a user