Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[BE][DTensor] move torch.distributed._tensor import to torch.distributed.tensor in test files (#153225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153225
Approved by: https://github.com/kwen2501, https://github.com/fegin

Committed by: PyTorch MergeBot
Parent: 3976e52264
Commit: cbb03e6971
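The diff below is mechanical: imports from the private `torch.distributed._tensor` namespace are repointed at the public `torch.distributed.tensor` package, with `init_device_mesh` taken from `torch.distributed.device_mesh`. A minimal sketch of the migration pattern follows, assuming a PyTorch build where the public module is available; the commented usage (mesh size, tensor shape) is illustrative only and not part of the commit:

```python
# Old (private) spelling, as removed throughout this commit:
#   from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard

# New (public) spelling used after this commit:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

# Usage is unchanged by the rename; only the import paths move. For example,
# inside an initialized process group one could build a sharded DTensor with:
#   mesh = init_device_mesh("cuda", (2,))
#   dt = DTensor.from_local(torch.randn(4, 4), mesh, [Shard(0)])
```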
@@ -17,7 +17,7 @@ from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import comms
from torch._inductor.utils import is_fallback_op, run_and_get_code
from torch.distributed._tensor import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
@@ -731,15 +731,17 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
with self._reinplace_all_gather_with_optional_checks(
fwd_fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
)
if fwd_fullgraph
else None
)
if fwd_fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@@ -954,18 +956,20 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
with self._reinplace_all_gather_with_optional_checks(
fwd_fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
)
if fwd_fullgraph
else None
)
if fwd_fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@@ -988,9 +992,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
file_check = FileCheck().check("def call(args):")
for fwd_ag_block_info in [
dict(
overlapped_compute_op_str="triton_"
if all_requires_grad
else None,
overlapped_compute_op_str=(
"triton_" if all_requires_grad else None
),
),
dict(
overlapped_compute_op_str="aten.native_dropout.",
@@ -1029,9 +1033,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
# )
pass
for bwd_rs_block_info in [
dict(overlapped_compute_op_str="extern_kernels.mm(")
if all_requires_grad
else None,
(
dict(overlapped_compute_op_str="extern_kernels.mm(")
if all_requires_grad
else None
),
dict(
overlapped_compute_op_str=None
), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
@@ -4,7 +4,7 @@ import copy
import torch
import torch.nn as nn
from torch.amp.grad_scaler import GradScaler, OptState
from torch.distributed._tensor import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -7,9 +7,9 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
@@ -9,13 +9,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed._tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp._fully_shard._fsdp_init import (
@@ -31,6 +24,13 @@ from torch.distributed.fsdp._init_utils import (
_init_inter_node_process_group,
_init_intra_node_process_group,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -7,8 +7,8 @@ from typing import Callable
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
@@ -12,7 +12,6 @@ import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import replicate
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
@@ -20,7 +19,7 @@ from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import (
CPUOffloadPolicy,
fully_shard,
@@ -31,6 +30,7 @@ from torch.distributed.fsdp._common_utils import (
clean_tensor_name,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -8,8 +8,8 @@ import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
@@ -414,7 +414,7 @@ class DDP_TP_Test(InductorTestCase):
# https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474
with self.assertRaisesRegex(
AssertionError,
"Expected ProxyTensor, got <class 'torch.distributed._tensor.api.DTensor'>",
"Expected ProxyTensor, got <class 'torch.distributed.tensor.DTensor'>",
):
loss.backward()
@@ -6,12 +6,12 @@ from typing import Union
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
CPUOffloadPolicy,
fully_shard,
@@ -13,7 +13,6 @@ import torch.distributed.checkpoint as DCP
import torch.distributed.checkpoint.state_dict_saver as saver
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict,
_patch_optimizer_state_dict,
@@ -26,6 +25,7 @@ from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
@@ -262,9 +262,11 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
f = saver.async_save(
sd,
storage_writer=writer,
async_checkpointer_type=async_checkpointer_type
if async_checkpointer_type
else AsyncCheckpointerType.THREAD,
async_checkpointer_type=(
async_checkpointer_type
if async_checkpointer_type
else AsyncCheckpointerType.THREAD
),
)
t = time.monotonic()
while not f.done():
@@ -7,7 +7,6 @@ import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_state_dict,
@@ -15,6 +14,7 @@ from torch.distributed.checkpoint.state_dict import (
set_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@@ -2,10 +2,10 @@

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -6,19 +6,20 @@ import copy
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -4,7 +4,7 @@ from typing import Union
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
@@ -1,14 +1,9 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
distribute_tensor,
init_device_mesh,
Replicate,
Shard,
zeros,
)
from torch.distributed.checkpoint._extension import ZStandard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -5,13 +5,13 @@ import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.format_utils import (
BroadcastingTorchSaveReader,
dcp_to_torch_save,
DynamicMetaLoadPlanner,
torch_save_to_dcp,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
@@ -3,9 +3,10 @@ import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -5,16 +5,17 @@ import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
ShardingStrategy,
StateDictType,
)
from torch.distributed.tensor import Replicate
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -5,7 +5,7 @@ from unittest.mock import patch
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -11,7 +11,6 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
@@ -26,6 +25,7 @@ from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
@@ -34,6 +34,7 @@ from torch.distributed.fsdp import (
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -13,12 +13,8 @@ from torch.distributed._state_dict_utils import (
_gather_state_dict,
_offload_state_dict_to_cpu,
)
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -4,11 +4,11 @@ from copy import deepcopy

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -5,7 +5,6 @@ from copy import deepcopy
import torch
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
@@ -13,6 +12,7 @@ from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
StateDictType,
)
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import parametrize, run_tests
@@ -6,19 +6,19 @@ from typing import Optional

import torch
from torch import distributed as dist
from torch.distributed._tensor import (
DeviceMesh,
distribute_module,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -190,9 +190,11 @@ class TestTPFSDPIntegration(FSDPTest):
local_grads_as_flattened = (
torch.cat(
[
torch.flatten(param.grad)
if param.grad is not None
else torch.zeros_like(torch.flatten(param))
(
torch.flatten(param.grad)
if param.grad is not None
else torch.zeros_like(torch.flatten(param))
)
for param in model.parameters()
]
)
@@ -7,7 +7,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
@@ -16,6 +15,7 @@ from torch.distributed.fsdp.api import (
ShardingStrategy,
StateDictType,
)
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import parametrize, run_tests
@@ -4,8 +4,7 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -4,8 +4,7 @@
from typing import Any

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.api import distribute_tensor, DTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -4,15 +4,10 @@ from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed._tensor.experimental import local_map
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import local_map
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -3,9 +3,9 @@
import itertools

import torch
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed._tensor.experimental import register_sharding
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.experimental import register_sharding
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -2,7 +2,7 @@
from collections import defaultdict

import torch
from torch.distributed._tensor.experimental._tp_transform import (
from torch.distributed.tensor.experimental._tp_transform import (
tensor_parallel_transformation,
)
from torch.distributed.tensor.parallel.style import (
@@ -19,9 +19,8 @@ from torch.distributed._functional_collectives import (
reduce_scatter_tensor,
)
from torch.distributed._symmetric_memory import _test_mode
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.distributed_c10d import _get_group_size_by_name
from torch.distributed.tensor import DeviceMesh, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -3,7 +3,7 @@ from collections import OrderedDict
from copy import deepcopy

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
@@ -8,17 +8,17 @@ from typing import NamedTuple, Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._tensor import (
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -259,7 +259,7 @@ class DistTensorParallelExampleTest(DTensorTestBase):
check_comms=True,
):
optim.step() # Ensure model weights are still the same after update.
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.tensor.experimental import implicit_replication

with implicit_replication():
with CommDebugMode() as comm_mode:
@@ -2,7 +2,8 @@
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -5,13 +5,8 @@ from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import (
@@ -3,7 +3,7 @@

import torch
import torch.nn as nn
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
@@ -2,8 +2,8 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests
@@ -11,22 +11,19 @@ from numpy.testing import assert_array_equal
import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import (
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
init_device_mesh,
)
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Replicate,
Shard,
TensorMeta,
)
from torch.distributed.tensor._api import _shard_tensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import implicit_replication
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -743,7 +740,7 @@ class DTensorMeshTest(DTensorTestBase):
),
]

from torch.distributed._tensor._utils import (
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
)

@@ -1009,7 +1006,8 @@ class DTensorLogTest(LoggingTestCase):
"""\
import logging
import torch
from torch.distributed._tensor import init_device_mesh, distribute_tensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (1,), mesh_dim_names=("dp",))
placements = [Shard(0)]
@@ -13,20 +13,14 @@ import torch.distributed as dist
import torch.nn as nn
from torch._C import FileCheck
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed._tensor import (
DeviceMesh,
DTensor,
init_device_mesh,
Partial,
Replicate,
Shard,
)
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -7,7 +7,7 @@ import warnings
import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.overrides import resolve_name
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
@@ -3,7 +3,7 @@
import sys

import torch
from torch.distributed._tensor import (
from torch.distributed.tensor import (
distribute_module,
distribute_tensor,
DTensor,
@@ -4,7 +4,7 @@

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -94,7 +94,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_ones(self):
self._run_init_op(
torch.ones,
torch.distributed._tensor.ones,
torch.distributed.tensor.ones,
self.assertEqual,
requires_grad=True,
)
@@ -103,7 +103,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_empty(self):
self._run_init_op(
torch.empty,
torch.distributed._tensor.empty,
torch.distributed.tensor.empty,
lambda x, y: (x.shape == y.shape)
and (x.dtype == y.dtype)
and (x.layout == y.layout),
@@ -114,7 +114,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_full(self):
self._run_init_op(
torch.full,
torch.distributed._tensor.full,
torch.distributed.tensor.full,
self.assertEqual,
123.4,
requires_grad=True,
@@ -124,7 +124,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_zeros(self):
self._run_init_op(
torch.zeros,
torch.distributed._tensor.zeros,
torch.distributed.tensor.zeros,
self.assertEqual,
requires_grad=True,
)
@@ -7,13 +7,14 @@ from pprint import pformat
from typing import NamedTuple

import torch
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
from torch.distributed.tensor.debug import CommDebugMode
@@ -310,7 +311,7 @@ class DistMathOpsTest(DTensorTestBase):
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)

from torch.distributed._tensor.placement_types import TensorMeta
from torch.distributed.tensor._dtensor_spec import TensorMeta

dtensor_meta = y_dist._spec.tensor_meta
assert isinstance(dtensor_meta, TensorMeta)
@@ -3,15 +3,9 @@
from itertools import chain

import torch
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Replicate,
Shard,
TensorMeta,
)
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed.tensor._ops._einsum_strategy import (
EinsumDims,
@@ -4,7 +4,7 @@ from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
@@ -8,8 +8,10 @@ from unittest import skip
import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Placement,
Replicate,
@@ -6,17 +6,22 @@ import itertools
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor._random import (
is_rng_supported_mesh,
manual_seed,
OffsetBasedRNGTracker,
)
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.testing._internal.common_utils import run_tests, TEST_HPU
@@ -396,8 +401,8 @@ class DistTensorRandomOpTest(DTensorTestBase):
size = [4, 4 * self.world_size]

for fn in [
torch.distributed._tensor.rand,
torch.distributed._tensor.randn,
torch.distributed.tensor.rand,
torch.distributed.tensor.randn,
]:
dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)])
local_tensor = funcol.all_gather_tensor(
@@ -4,9 +4,15 @@
import itertools

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU
@@ -2,8 +2,14 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
@@ -10,7 +10,7 @@ import numpy as np

import torch
from torch import nn
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
@@ -8,7 +8,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import PipelineStage
@@ -20,6 +19,7 @@ from torch.distributed.pipelining.schedules import (
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
@@ -5,7 +5,6 @@ import os
import torch
import torch.distributed._functional_collectives as funcol
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
_get_default_group,
@@ -17,6 +16,7 @@ from torch.distributed.distributed_c10d import (
new_group,
ProcessGroup,
)
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._collective_utils import (
mesh_broadcast,
mesh_scatter,
@@ -7,8 +7,9 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DeviceMesh, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -7,8 +7,8 @@ from functools import partial, wraps
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
import torch.distributed.tensor as dt
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
@@ -996,8 +996,6 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
placements: Optional[Sequence[Placement]] = None,
**kwargs,
) -> DTensor:
# from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta

# if device_mesh is None, use the one from mesh resources
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
kwargs["device"] = device_mesh.device_type
@@ -1372,7 +1372,7 @@ def context_parallel(
these buffers can be put in this list to avoid extra restore time.

.. warning::
`torch.distributed._tensor.experimental.attention.context_parallel` is a
`torch.distributed.tensor.experimental.context_parallel` is a
prototype feature in PyTorch. The API is subject to change.
"""
buffers = [] if buffers is None else buffers
@@ -14,8 +14,13 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch._utils import _get_device_module
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
Placement,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -10,13 +10,13 @@ import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
_PG,
_STATE,
set_state_dict,
StateDictOptions,
)
from torch.distributed.tensor import DTensor


class VerifyStateDictMixin: