[BE][DTensor] move torch.distributed._tensor import to torch.distributed.tensor in test files (#153225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153225
Approved by: https://github.com/kwen2501, https://github.com/fegin
Author: Xilun Wu
Date: 2025-05-09 01:50:00 -07:00
Committed by: PyTorch MergeBot
Parent: 3976e52264
Commit: cbb03e6971
59 changed files with 182 additions and 179 deletions
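
In short, this is a mechanical import-path rename across the test suite: the private torch.distributed._tensor namespace is replaced by the public torch.distributed.tensor module, init_device_mesh now comes from torch.distributed.device_mesh, DTensorSpec and TensorMeta from torch.distributed.tensor._dtensor_spec, and the experimental helpers from torch.distributed.tensor.experimental; the remaining hunks only reflow inline conditional expressions. A minimal before/after sketch of the import pattern follows; the single-rank usage code is a hypothetical illustration and not part of this commit.

import os

import torch
import torch.distributed as dist

# Old, pre-#153225 spelling (private namespace), kept only as a comment:
#   from torch.distributed._tensor import DTensor, Shard, init_device_mesh
# New spelling, matching the pattern applied throughout this commit:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Shard

if __name__ == "__main__":
    # Hypothetical single-rank setup so the sketch runs without a launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    mesh = init_device_mesh("cpu", (1,))  # 1-D mesh over the single rank
    dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    assert isinstance(dtensor, DTensor)

    dist.destroy_process_group()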

@@ -17,7 +17,7 @@ from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import comms
from torch._inductor.utils import is_fallback_op, run_and_get_code
from torch.distributed._tensor import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
@@ -731,15 +731,17 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
with self._reinplace_all_gather_with_optional_checks(
fwd_fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
)
if fwd_fullgraph
else None
)
if fwd_fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@@ -954,18 +956,20 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
with self._reinplace_all_gather_with_optional_checks(
fwd_fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
)
if fwd_fullgraph
else None
)
if fwd_fullgraph
else None
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@@ -988,9 +992,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
file_check = FileCheck().check("def call(args):")
for fwd_ag_block_info in [
dict(
overlapped_compute_op_str="triton_"
if all_requires_grad
else None,
overlapped_compute_op_str=(
"triton_" if all_requires_grad else None
),
),
dict(
overlapped_compute_op_str="aten.native_dropout.",
@@ -1029,9 +1033,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
# )
pass
for bwd_rs_block_info in [
dict(overlapped_compute_op_str="extern_kernels.mm(")
if all_requires_grad
else None,
(
dict(overlapped_compute_op_str="extern_kernels.mm(")
if all_requires_grad
else None
),
dict(
overlapped_compute_op_str=None
), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None

@@ -4,7 +4,7 @@ import copy
import torch
import torch.nn as nn
from torch.amp.grad_scaler import GradScaler, OptState
from torch.distributed._tensor import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,

@@ -7,9 +7,9 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest

@@ -9,13 +9,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed._tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp._fully_shard._fsdp_init import (
@@ -31,6 +24,13 @@ from torch.distributed.fsdp._init_utils import (
_init_inter_node_process_group,
_init_intra_node_process_group,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -7,8 +7,8 @@ from typing import Callable
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,

@@ -12,7 +12,6 @@ import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import replicate
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
@@ -20,7 +19,7 @@ from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import (
CPUOffloadPolicy,
fully_shard,
@@ -31,6 +30,7 @@ from torch.distributed.fsdp._common_utils import (
clean_tensor_name,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,

@@ -8,8 +8,8 @@ import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,

@@ -414,7 +414,7 @@ class DDP_TP_Test(InductorTestCase):
# https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474
with self.assertRaisesRegex(
AssertionError,
"Expected ProxyTensor, got <class 'torch.distributed._tensor.api.DTensor'>",
"Expected ProxyTensor, got <class 'torch.distributed.tensor.DTensor'>",
):
loss.backward()

@@ -6,12 +6,12 @@ from typing import Union
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
CPUOffloadPolicy,
fully_shard,

@@ -13,7 +13,6 @@ import torch.distributed.checkpoint as DCP
import torch.distributed.checkpoint.state_dict_saver as saver
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict,
_patch_optimizer_state_dict,
@@ -26,6 +25,7 @@ from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from
from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
@@ -262,9 +262,11 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
f = saver.async_save(
sd,
storage_writer=writer,
async_checkpointer_type=async_checkpointer_type
if async_checkpointer_type
else AsyncCheckpointerType.THREAD,
async_checkpointer_type=(
async_checkpointer_type
if async_checkpointer_type
else AsyncCheckpointerType.THREAD
),
)
t = time.monotonic()
while not f.done():

@@ -7,7 +7,6 @@ import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_state_dict,
@@ -15,6 +14,7 @@ from torch.distributed.checkpoint.state_dict import (
set_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

@@ -2,10 +2,10 @@
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,

@@ -6,19 +6,20 @@ import copy
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -4,7 +4,7 @@ from typing import Union
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,

@@ -1,14 +1,9 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
distribute_tensor,
init_device_mesh,
Replicate,
Shard,
zeros,
)
from torch.distributed.checkpoint._extension import ZStandard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,

@@ -5,13 +5,13 @@ import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.format_utils import (
BroadcastingTorchSaveReader,
dcp_to_torch_save,
DynamicMetaLoadPlanner,
torch_save_to_dcp,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests

@@ -3,9 +3,10 @@ import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -5,16 +5,17 @@ import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
ShardingStrategy,
StateDictType,
)
from torch.distributed.tensor import Replicate
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,

@@ -5,7 +5,7 @@ from unittest.mock import patch
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (

@@ -11,7 +11,6 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
@@ -26,6 +25,7 @@ from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
@@ -34,6 +34,7 @@ from torch.distributed.fsdp import (
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -13,12 +13,8 @@ from torch.distributed._state_dict_utils import (
_gather_state_dict,
_offload_state_dict_to_cpu,
)
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,

@@ -4,11 +4,11 @@ from copy import deepcopy
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -5,7 +5,6 @@ from copy import deepcopy
import torch
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
@@ -13,6 +12,7 @@ from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
StateDictType,
)
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import parametrize, run_tests

@@ -6,19 +6,19 @@ from typing import Optional
import torch
from torch import distributed as dist
from torch.distributed._tensor import (
DeviceMesh,
distribute_module,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -190,9 +190,11 @@ class TestTPFSDPIntegration(FSDPTest):
local_grads_as_flattened = (
torch.cat(
[
torch.flatten(param.grad)
if param.grad is not None
else torch.zeros_like(torch.flatten(param))
(
torch.flatten(param.grad)
if param.grad is not None
else torch.zeros_like(torch.flatten(param))
)
for param in model.parameters()
]
)

@@ -7,7 +7,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
@@ -16,6 +15,7 @@ from torch.distributed.fsdp.api import (
ShardingStrategy,
StateDictType,
)
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import parametrize, run_tests

@@ -4,8 +4,7 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_utils import run_tests, TestCase

@@ -4,8 +4,7 @@
from typing import Any
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.api import distribute_tensor, DTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,

@@ -4,15 +4,10 @@ from functools import partial
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed._tensor.experimental import local_map
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import local_map
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,

@@ -3,9 +3,9 @@
import itertools
import torch
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed._tensor.experimental import register_sharding
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.experimental import register_sharding
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,

@@ -2,7 +2,7 @@
from collections import defaultdict
import torch
from torch.distributed._tensor.experimental._tp_transform import (
from torch.distributed.tensor.experimental._tp_transform import (
tensor_parallel_transformation,
)
from torch.distributed.tensor.parallel.style import (

@@ -19,9 +19,8 @@ from torch.distributed._functional_collectives import (
reduce_scatter_tensor,
)
from torch.distributed._symmetric_memory import _test_mode
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.distributed_c10d import _get_group_size_by_name
from torch.distributed.tensor import DeviceMesh, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -3,7 +3,7 @@ from collections import OrderedDict
from copy import deepcopy
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (

@@ -8,17 +8,17 @@ from typing import NamedTuple, Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._tensor import (
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -259,7 +259,7 @@ class DistTensorParallelExampleTest(DTensorTestBase):
check_comms=True,
):
optim.step() # Ensure model weights are still the same after update.
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.tensor.experimental import implicit_replication
with implicit_replication():
with CommDebugMode() as comm_mode:

@@ -2,7 +2,8 @@
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -5,13 +5,8 @@ from copy import deepcopy
import torch
import torch.nn as nn
from torch.distributed._tensor import (
distribute_tensor,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import (

@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,

@@ -2,8 +2,8 @@
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests

@@ -11,22 +11,19 @@ from numpy.testing import assert_array_equal
import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import (
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
init_device_mesh,
)
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Replicate,
Shard,
TensorMeta,
)
from torch.distributed.tensor._api import _shard_tensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import implicit_replication
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -743,7 +740,7 @@ class DTensorMeshTest(DTensorTestBase):
),
]
from torch.distributed._tensor._utils import (
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
)
@@ -1009,7 +1006,8 @@ class DTensorLogTest(LoggingTestCase):
"""\
import logging
import torch
from torch.distributed._tensor import init_device_mesh, distribute_tensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard
mesh = init_device_mesh("cuda", (1,), mesh_dim_names=("dp",))
placements = [Shard(0)]

@@ -13,20 +13,14 @@ import torch.distributed as dist
import torch.nn as nn
from torch._C import FileCheck
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed._tensor import (
DeviceMesh,
DTensor,
init_device_mesh,
Partial,
Replicate,
Shard,
)
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -7,7 +7,7 @@ import warnings
import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.overrides import resolve_name
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,

@@ -3,7 +3,7 @@
import sys
import torch
from torch.distributed._tensor import (
from torch.distributed.tensor import (
distribute_module,
distribute_tensor,
DTensor,

@@ -4,7 +4,7 @@
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,

@@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -94,7 +94,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_ones(self):
self._run_init_op(
torch.ones,
torch.distributed._tensor.ones,
torch.distributed.tensor.ones,
self.assertEqual,
requires_grad=True,
)
@@ -103,7 +103,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_empty(self):
self._run_init_op(
torch.empty,
torch.distributed._tensor.empty,
torch.distributed.tensor.empty,
lambda x, y: (x.shape == y.shape)
and (x.dtype == y.dtype)
and (x.layout == y.layout),
@@ -114,7 +114,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_full(self):
self._run_init_op(
torch.full,
torch.distributed._tensor.full,
torch.distributed.tensor.full,
self.assertEqual,
123.4,
requires_grad=True,
@@ -124,7 +124,7 @@ class DTensorConstructorTest(DTensorTestBase):
def test_zeros(self):
self._run_init_op(
torch.zeros,
torch.distributed._tensor.zeros,
torch.distributed.tensor.zeros,
self.assertEqual,
requires_grad=True,
)

@@ -7,13 +7,14 @@ from pprint import pformat
from typing import NamedTuple
import torch
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
from torch.distributed.tensor.debug import CommDebugMode
@@ -310,7 +311,7 @@ class DistMathOpsTest(DTensorTestBase):
f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
)
from torch.distributed._tensor.placement_types import TensorMeta
from torch.distributed.tensor._dtensor_spec import TensorMeta
dtensor_meta = y_dist._spec.tensor_meta
assert isinstance(dtensor_meta, TensorMeta)

@@ -3,15 +3,9 @@
from itertools import chain
import torch
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Replicate,
Shard,
TensorMeta,
)
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed.tensor._ops._einsum_strategy import (
EinsumDims,

@@ -4,7 +4,7 @@ from copy import deepcopy
import torch
import torch.nn as nn
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,

@@ -8,8 +8,10 @@ from unittest import skip
import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Placement,
Replicate,

@@ -6,17 +6,22 @@ import itertools
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor._random import (
is_rng_supported_mesh,
manual_seed,
OffsetBasedRNGTracker,
)
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.testing._internal.common_utils import run_tests, TEST_HPU
@@ -396,8 +401,8 @@ class DistTensorRandomOpTest(DTensorTestBase):
size = [4, 4 * self.world_size]
for fn in [
torch.distributed._tensor.rand,
torch.distributed._tensor.randn,
torch.distributed.tensor.rand,
torch.distributed.tensor.randn,
]:
dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)])
local_tensor = funcol.all_gather_tensor(

@@ -4,9 +4,15 @@
import itertools
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU

@@ -2,8 +2,14 @@
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm

@@ -10,7 +10,7 @@ import numpy as np
import torch
from torch import nn
from torch.distributed._tensor import (
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,

@@ -8,7 +8,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import PipelineStage
@@ -20,6 +19,7 @@ from torch.distributed.pipelining.schedules import (
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (

@@ -5,7 +5,6 @@ import os
import torch
import torch.distributed._functional_collectives as funcol
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
_get_default_group,
@@ -17,6 +16,7 @@ from torch.distributed.distributed_c10d import (
new_group,
ProcessGroup,
)
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._collective_utils import (
mesh_broadcast,
mesh_scatter,

@@ -7,8 +7,9 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DeviceMesh, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -7,8 +7,8 @@ from functools import partial, wraps
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
import torch.distributed.tensor as dt
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

@@ -996,8 +996,6 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
placements: Optional[Sequence[Placement]] = None,
**kwargs,
) -> DTensor:
# from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
# if device_mesh is None, use the one from mesh resources
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
kwargs["device"] = device_mesh.device_type

@@ -1372,7 +1372,7 @@ def context_parallel(
these buffers can be put in this list to avoid extra restore time.
.. warning::
`torch.distributed._tensor.experimental.attention.context_parallel` is a
`torch.distributed.tensor.experimental.context_parallel` is a
prototype feature in PyTorch. The API is subject to change.
"""
buffers = [] if buffers is None else buffers

@@ -14,8 +14,13 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch._utils import _get_device_module
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
Placement,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,

@@ -10,13 +10,13 @@ import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
_PG,
_STATE,
set_state_dict,
StateDictOptions,
)
from torch.distributed.tensor import DTensor
class VerifyStateDictMixin: