[BE][DTensor] move torch.distributed._tensor import to torch.distributed.tensor in test files (#153225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153225
Approved by: https://github.com/kwen2501, https://github.com/fegin
Xilun Wu
2025-05-09 01:50:00 -07:00
committed by PyTorch MergeBot
parent 3976e52264
commit cbb03e6971
59 changed files with 182 additions and 179 deletions
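
In short, the change swaps the deprecated private import path for the public one across the test suite: DTensor types and factory functions now come from torch.distributed.tensor, and init_device_mesh comes from torch.distributed.device_mesh. A minimal sketch of the pattern, assuming a test that only needs DTensor and a device mesh (illustrative; not any single file from this commit verbatim):

    # before: private namespace (deprecated)
    from torch.distributed._tensor import DTensor, init_device_mesh

    # after: public namespaces
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor

    # usage is unchanged, e.g. a 1-D mesh (requires an initialized process group)
    mesh = init_device_mesh("cpu", (1,))

Submodules follow the same mapping, e.g. torch.distributed._tensor.experimental becomes torch.distributed.tensor.experimental, and the former placement_types internals resolve to torch.distributed.tensor (public placements) or torch.distributed.tensor._dtensor_spec (DTensorSpec, TensorMeta), as the hunks below show.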

View File

@@ -17,7 +17,7 @@ from torch import nn
 from torch._dynamo.utils import counters
 from torch._inductor import comms
 from torch._inductor.utils import is_fallback_op, run_and_get_code
-from torch.distributed._tensor import init_device_mesh
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
     fully_shard,
     FullyShardedDataParallel as FSDP,
@@ -731,15 +731,17 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
         with self._reinplace_all_gather_with_optional_checks(
             fwd_fullgraph
         ), torch._inductor.config.patch(
-            post_grad_custom_post_pass=functools.partial(
-                self._check_fsdp_copy_and_resize_ops_count_in_graph,
-                fwd_copy_count=0,
-                fwd_resize_count=0,
-                bwd_copy_count=0,
-                bwd_resize_count=0,
-            )
-            if fwd_fullgraph
-            else None
+            post_grad_custom_post_pass=(
+                functools.partial(
+                    self._check_fsdp_copy_and_resize_ops_count_in_graph,
+                    fwd_copy_count=0,
+                    fwd_resize_count=0,
+                    bwd_copy_count=0,
+                    bwd_resize_count=0,
+                )
+                if fwd_fullgraph
+                else None
+            )
         ):
             _, triton_codes = run_and_get_code(
                 lambda: self._test_traceable_fsdp(
@@ -954,18 +956,20 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
         with self._reinplace_all_gather_with_optional_checks(
             fwd_fullgraph
         ), torch._inductor.config.patch(
-            post_grad_custom_post_pass=functools.partial(
-                self._check_fsdp_copy_and_resize_ops_count_in_graph,
-                # NOTE: For the root unsharded params, we don't reshard after forward since for training,
-                # the parameters would be freed and all-gathered immediately. Hence we still have
-                # their resize and copy ops in the graph.
-                fwd_copy_count=4,
-                fwd_resize_count=4,
-                bwd_copy_count=0,
-                bwd_resize_count=4,
-            )
-            if fwd_fullgraph
-            else None
+            post_grad_custom_post_pass=(
+                functools.partial(
+                    self._check_fsdp_copy_and_resize_ops_count_in_graph,
+                    # NOTE: For the root unsharded params, we don't reshard after forward since for training,
+                    # the parameters would be freed and all-gathered immediately. Hence we still have
+                    # their resize and copy ops in the graph.
+                    fwd_copy_count=4,
+                    fwd_resize_count=4,
+                    bwd_copy_count=0,
+                    bwd_resize_count=4,
+                )
+                if fwd_fullgraph
+                else None
+            )
         ):
             _, triton_codes = run_and_get_code(
                 lambda: self._test_traceable_fsdp(
@@ -988,9 +992,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
             file_check = FileCheck().check("def call(args):")
             for fwd_ag_block_info in [
                 dict(
-                    overlapped_compute_op_str="triton_"
-                    if all_requires_grad
-                    else None,
+                    overlapped_compute_op_str=(
+                        "triton_" if all_requires_grad else None
+                    ),
                 ),
                 dict(
                     overlapped_compute_op_str="aten.native_dropout.",
@@ -1029,9 +1033,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
             # )
             pass
             for bwd_rs_block_info in [
-                dict(overlapped_compute_op_str="extern_kernels.mm(")
-                if all_requires_grad
-                else None,
+                (
+                    dict(overlapped_compute_op_str="extern_kernels.mm(")
+                    if all_requires_grad
+                    else None
+                ),
                 dict(
                     overlapped_compute_op_str=None
                 ),  # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None

View File

@@ -4,7 +4,7 @@ import copy
 import torch
 import torch.nn as nn
 from torch.amp.grad_scaler import GradScaler, OptState
-from torch.distributed._tensor import init_device_mesh
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,

View File

@@ -7,9 +7,9 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
-from torch.distributed._tensor.experimental import implicit_replication
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental import implicit_replication
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest

View File

@@ -9,13 +9,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import replicate
-from torch.distributed._tensor import (
-    DeviceMesh,
-    distribute_tensor,
-    DTensor,
-    Replicate,
-    Shard,
-)
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard
 from torch.distributed.fsdp._fully_shard._fsdp_init import (
@@ -31,6 +24,13 @@ from torch.distributed.fsdp._init_utils import (
     _init_inter_node_process_group,
     _init_intra_node_process_group,
 )
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -7,8 +7,8 @@ from typing import Callable
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._tensor.experimental import implicit_replication
 from torch.distributed.fsdp import fully_shard
+from torch.distributed.tensor.experimental import implicit_replication
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,

View File

@@ -12,7 +12,6 @@ import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import replicate
-from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_optimizer_state_dict,
@@ -20,7 +19,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_optimizer_state_dict,
     StateDictOptions,
 )
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     fully_shard,
@@ -31,6 +30,7 @@ from torch.distributed.fsdp._common_utils import (
     clean_tensor_name,
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from torch.distributed.tensor import DTensor, Replicate, Shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,

View File

@@ -8,8 +8,8 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
 from torch.distributed._composable.replicate import replicate
-from torch.distributed._tensor import DTensor
 from torch.distributed.fsdp import fully_shard
+from torch.distributed.tensor import DTensor
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,

View File

@@ -414,7 +414,7 @@ class DDP_TP_Test(InductorTestCase):
         # https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474
         with self.assertRaisesRegex(
             AssertionError,
-            "Expected ProxyTensor, got <class 'torch.distributed._tensor.api.DTensor'>",
+            "Expected ProxyTensor, got <class 'torch.distributed.tensor.DTensor'>",
         ):
             loss.backward()

View File

@@ -6,12 +6,12 @@ from typing import Union
 import torch
 import torch.nn as nn
 from torch.distributed._composable import checkpoint
-from torch.distributed._tensor import init_device_mesh
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
     CheckpointWrapper,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     fully_shard,

View File

@@ -13,7 +13,6 @@ import torch.distributed.checkpoint as DCP
 import torch.distributed.checkpoint.state_dict_saver as saver
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (
     _patch_model_state_dict,
     _patch_optimizer_state_dict,
@@ -26,6 +25,7 @@ from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from
 from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.checkpoint.utils import CheckpointException
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.distributed_c10d import ReduceOp
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import ShardingStrategy
@@ -262,9 +262,11 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
         f = saver.async_save(
             sd,
             storage_writer=writer,
-            async_checkpointer_type=async_checkpointer_type
-            if async_checkpointer_type
-            else AsyncCheckpointerType.THREAD,
+            async_checkpointer_type=(
+                async_checkpointer_type
+                if async_checkpointer_type
+                else AsyncCheckpointerType.THREAD
+            ),
         )
         t = time.monotonic()
         while not f.done():

View File

@@ -7,7 +7,6 @@ import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dist_cp
 import torch.nn as nn
-from torch.distributed._tensor import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_state_dict,
@@ -15,6 +14,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
     StateDictOptions,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

View File

@@ -2,10 +2,10 @@
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.state_dict import get_state_dict
 from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DTensor
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,

View File

@@ -6,19 +6,20 @@ import copy
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed._tensor import DTensor, init_device_mesh
-from torch.distributed._tensor.experimental import implicit_replication
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_optimizer_state_dict,
     StateDictOptions,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
     fully_shard,
     FullyShardedDataParallel as FSDP,
     StateDictType,
 )
 from torch.distributed.fsdp.wrap import always_wrap_policy
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental import implicit_replication
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -4,7 +4,7 @@ from typing import Union
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dist_cp
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
     DTensor,

View File

@@ -1,14 +1,9 @@
 # Owner(s): ["oncall: distributed"]
 import torch
 import torch.distributed.checkpoint as dist_cp
-from torch.distributed._tensor import (
-    distribute_tensor,
-    init_device_mesh,
-    Replicate,
-    Shard,
-    zeros,
-)
 from torch.distributed.checkpoint._extension import ZStandard
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,

View File

@@ -5,13 +5,13 @@ import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.format_utils import (
     BroadcastingTorchSaveReader,
     dcp_to_torch_save,
     DynamicMetaLoadPlanner,
     torch_save_to_dcp,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests

View File

@@ -3,9 +3,10 @@ import torch
 import torch.distributed.checkpoint as dist_cp
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
-from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from torch.distributed.tensor import DTensor, Replicate
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -5,16 +5,17 @@ import torch
 import torch.distributed.checkpoint as dist_cp
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor import init_device_mesh, Replicate
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     ShardingStrategy,
     StateDictType,
 )
+from torch.distributed.tensor import Replicate
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,

View File

@@ -5,7 +5,7 @@ from unittest.mock import patch
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed._tensor.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (

View File

@@ -11,7 +11,6 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import replicate
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed._tensor import DTensor, init_device_mesh
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
@@ -26,6 +25,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_optimizer_state_dict,
     StateDictOptions,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
     fully_shard,
     FullyShardedDataParallel as FSDP,
@@ -34,6 +34,7 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.distributed.optim import _apply_optimizer_in_backward
+from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -13,12 +13,8 @@ from torch.distributed._state_dict_utils import (
     _gather_state_dict,
     _offload_state_dict_to_cpu,
 )
-from torch.distributed._tensor import (
-    distribute_tensor,
-    DTensor,
-    init_device_mesh,
-    Shard,
-)
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import distribute_tensor, DTensor, Shard
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,

View File

@@ -4,11 +4,11 @@ from copy import deepcopy
 import torch
 import torch.distributed.checkpoint as dcp
-from torch.distributed._tensor import init_device_mesh
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -5,7 +5,6 @@ from copy import deepcopy
 import torch
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed._tensor import DTensor, Shard
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import (
@@ -13,6 +12,7 @@ from torch.distributed.fsdp.api import (
     ShardedStateDictConfig,
     StateDictType,
 )
+from torch.distributed.tensor import DTensor, Shard
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import parametrize, run_tests

View File

@@ -6,19 +6,19 @@ from typing import Optional
 import torch
 from torch import distributed as dist
-from torch.distributed._tensor import (
-    DeviceMesh,
-    distribute_module,
-    DTensor,
-    init_device_mesh,
-    Replicate,
-    Shard,
-)
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     CPUOffload,
     FullyShardedDataParallel as FSDP,
     ShardingStrategy,
 )
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_module,
+    DTensor,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -190,9 +190,11 @@ class TestTPFSDPIntegration(FSDPTest):
         local_grads_as_flattened = (
             torch.cat(
                 [
-                    torch.flatten(param.grad)
-                    if param.grad is not None
-                    else torch.zeros_like(torch.flatten(param))
+                    (
+                        torch.flatten(param.grad)
+                        if param.grad is not None
+                        else torch.zeros_like(torch.flatten(param))
+                    )
                     for param in model.parameters()
                 ]
             )

View File

@@ -7,7 +7,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import (
@@ -16,6 +15,7 @@ from torch.distributed.fsdp.api import (
     ShardingStrategy,
     StateDictType,
 )
+from torch.distributed.tensor import DTensor, Replicate, Shard
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import parametrize, run_tests

View File

@@ -4,8 +4,7 @@ import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor.placement_types import Shard
+from torch.distributed.tensor import DeviceMesh, DTensor, Shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import requires_nccl
 from torch.testing._internal.common_utils import run_tests, TestCase

View File

@@ -4,8 +4,7 @@
 from typing import Any
 import torch
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.api import distribute_tensor, DTensor
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,

View File

@@ -4,15 +4,10 @@ from functools import partial
 import torch
 import torch.distributed._functional_collectives as funcol
-from torch.distributed._tensor import (
-    distribute_tensor,
-    DTensor,
-    init_device_mesh,
-    Replicate,
-    Shard,
-)
-from torch.distributed._tensor.experimental import local_map
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.experimental import local_map
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,

View File

@@ -3,9 +3,9 @@
 import itertools
 import torch
-from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
-from torch.distributed._tensor.experimental import register_sharding
-from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor.experimental import register_sharding
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,

View File

@@ -2,7 +2,7 @@
 from collections import defaultdict
 import torch
-from torch.distributed._tensor.experimental._tp_transform import (
+from torch.distributed.tensor.experimental._tp_transform import (
     tensor_parallel_transformation,
 )
 from torch.distributed.tensor.parallel.style import (

View File

@@ -19,9 +19,8 @@ from torch.distributed._functional_collectives import (
     reduce_scatter_tensor,
 )
 from torch.distributed._symmetric_memory import _test_mode
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.placement_types import Shard
 from torch.distributed.distributed_c10d import _get_group_size_by_name
+from torch.distributed.tensor import DeviceMesh, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -3,7 +3,7 @@ from collections import OrderedDict
 from copy import deepcopy
 import torch
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel.api import parallelize_module
 from torch.distributed.tensor.parallel.style import (

View File

@@ -8,17 +8,17 @@ from typing import NamedTuple, Optional
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from torch.distributed._tensor import (
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper,
+    CheckpointImpl,
+)
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
     DTensor,
     Replicate,
     Shard,
 )
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper,
-    CheckpointImpl,
-)
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -259,7 +259,7 @@ class DistTensorParallelExampleTest(DTensorTestBase):
             check_comms=True,
         ):
             optim.step()  # Ensure model weights are still the same after update.
-        from torch.distributed._tensor.experimental import implicit_replication
+        from torch.distributed.tensor.experimental import implicit_replication
         with implicit_replication():
             with CommDebugMode() as comm_mode:

View File

@@ -2,7 +2,8 @@
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._random as random
-from torch.distributed._tensor import init_device_mesh, Replicate
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import Replicate
 from torch.distributed.tensor.parallel.api import parallelize_module
 from torch.distributed.tensor.parallel.style import ColwiseParallel
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

View File

@@ -5,13 +5,8 @@ from copy import deepcopy
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import (
-    distribute_tensor,
-    DTensor,
-    init_device_mesh,
-    Replicate,
-    Shard,
-)
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.distributed.tensor.parallel.style import (

View File

@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,

View File

@@ -2,8 +2,8 @@
 # Owner(s): ["oncall: distributed"]
 import torch
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import OpSchema
 from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
 from torch.testing._internal.common_utils import run_tests

View File

@@ -11,22 +11,19 @@ from numpy.testing import assert_array_equal
 import torch
 import torch.nn.functional as F
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed._tensor import (
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
     DTensor,
-    init_device_mesh,
-)
-from torch.distributed._tensor.experimental import implicit_replication
-from torch.distributed._tensor.placement_types import (
-    DTensorSpec,
     Partial,
     Replicate,
     Shard,
-    TensorMeta,
 )
 from torch.distributed.tensor._api import _shard_tensor
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.experimental import implicit_replication
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -743,7 +740,7 @@ class DTensorMeshTest(DTensorTestBase):
             ),
         ]
-        from torch.distributed._tensor._utils import (
+        from torch.distributed.tensor._utils import (
             compute_local_shape_and_global_offset,
         )
@@ -1009,7 +1006,8 @@ class DTensorLogTest(LoggingTestCase):
             """\
 import logging
 import torch
-from torch.distributed._tensor import init_device_mesh, distribute_tensor, Shard
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import distribute_tensor, Shard
 mesh = init_device_mesh("cuda", (1,), mesh_dim_names=("dp",))
 placements = [Shard(0)]

View File

@@ -13,20 +13,14 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch._C import FileCheck
 from torch._inductor.utils import run_and_get_triton_code
-from torch.distributed._tensor import (
-    DeviceMesh,
-    DTensor,
-    init_device_mesh,
-    Partial,
-    Replicate,
-    Shard,
-)
-from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -7,7 +7,7 @@ import warnings
 import torch
 import torch.distributed as dist
 import torch.testing._internal.common_methods_invocations as common_ops
-from torch.distributed._tensor import DeviceMesh, DTensor
+from torch.distributed.tensor import DeviceMesh, DTensor
 from torch.overrides import resolve_name
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,

View File

@@ -3,7 +3,7 @@
 import sys
 import torch
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     distribute_module,
     distribute_tensor,
     DTensor,

View File

@@ -4,7 +4,7 @@
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,

View File

@@ -2,7 +2,7 @@
 # Owner(s): ["oncall: distributed"]
 import torch
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -94,7 +94,7 @@ class DTensorConstructorTest(DTensorTestBase):
     def test_ones(self):
         self._run_init_op(
             torch.ones,
-            torch.distributed._tensor.ones,
+            torch.distributed.tensor.ones,
             self.assertEqual,
             requires_grad=True,
         )
@@ -103,7 +103,7 @@ class DTensorConstructorTest(DTensorTestBase):
     def test_empty(self):
         self._run_init_op(
             torch.empty,
-            torch.distributed._tensor.empty,
+            torch.distributed.tensor.empty,
             lambda x, y: (x.shape == y.shape)
             and (x.dtype == y.dtype)
             and (x.layout == y.layout),
@@ -114,7 +114,7 @@ class DTensorConstructorTest(DTensorTestBase):
     def test_full(self):
         self._run_init_op(
             torch.full,
-            torch.distributed._tensor.full,
+            torch.distributed.tensor.full,
             self.assertEqual,
             123.4,
             requires_grad=True,
@@ -124,7 +124,7 @@ class DTensorConstructorTest(DTensorTestBase):
     def test_zeros(self):
         self._run_init_op(
             torch.zeros,
-            torch.distributed._tensor.zeros,
+            torch.distributed.tensor.zeros,
             self.assertEqual,
             requires_grad=True,
         )

View File

@@ -7,13 +7,14 @@ from pprint import pformat
 from typing import NamedTuple
 import torch
-from torch.distributed._tensor.placement_types import Replicate, Shard
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,
     DTensor,
-    init_device_mesh,
+    Replicate,
+    Shard,
 )
 from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
 from torch.distributed.tensor.debug import CommDebugMode
@@ -310,7 +311,7 @@ class DistMathOpsTest(DTensorTestBase):
                     f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
                 )
-                from torch.distributed._tensor.placement_types import TensorMeta
+                from torch.distributed.tensor._dtensor_spec import TensorMeta
                 dtensor_meta = y_dist._spec.tensor_meta
                 assert isinstance(dtensor_meta, TensorMeta)

View File

@@ -3,15 +3,9 @@
 from itertools import chain
 import torch
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor.placement_types import (
-    DTensorSpec,
-    Partial,
-    Replicate,
-    Shard,
-    TensorMeta,
-)
+from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
 from torch.distributed.tensor._collective_utils import redistribute_cost
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
 from torch.distributed.tensor._ops._einsum_strategy import (
     EinsumDims,

View File

@@ -4,7 +4,7 @@ from copy import deepcopy
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,

View File

@@ -8,8 +8,10 @@ from unittest import skip
 import torch
 import torch.utils._pytree as pytree
 from torch import Tensor
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
     Partial,
     Placement,
     Replicate,

View File

@@ -6,17 +6,22 @@ import itertools
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._random as random
-from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
-from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
-from torch.distributed._tensor.api import distribute_tensor
-from torch.distributed._tensor.placement_types import Replicate, Shard
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.distributed_c10d import broadcast_object_list
 from torch.distributed.fsdp import fully_shard
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor._random import (
     is_rng_supported_mesh,
     manual_seed,
     OffsetBasedRNGTracker,
 )
+from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 from torch.testing._internal.common_utils import run_tests, TEST_HPU
@@ -396,8 +401,8 @@ class DistTensorRandomOpTest(DTensorTestBase):
         size = [4, 4 * self.world_size]
         for fn in [
-            torch.distributed._tensor.rand,
-            torch.distributed._tensor.randn,
+            torch.distributed.tensor.rand,
+            torch.distributed.tensor.randn,
         ]:
             dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)])
             local_tensor = funcol.all_gather_tensor(

View File

@@ -4,9 +4,15 @@
 import itertools
 import torch
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
+    Partial,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU

View File

@@ -2,8 +2,14 @@
 # Owner(s): ["oncall: distributed"]
 import torch
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
+    Partial,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests, skipIfRocm

View File

@@ -10,7 +10,7 @@ import numpy as np
 import torch
 from torch import nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,

View File

@@ -8,7 +8,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
@@ -20,6 +19,7 @@ from torch.distributed.pipelining.schedules import (
     ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
+from torch.distributed.tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (

View File

@@ -5,7 +5,6 @@ import os
 import torch
 import torch.distributed._functional_collectives as funcol
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
 from torch.distributed.distributed_c10d import (
     _get_default_group,
@@ -17,6 +16,7 @@ from torch.distributed.distributed_c10d import (
     new_group,
     ProcessGroup,
 )
+from torch.distributed.tensor import DTensor
 from torch.distributed.tensor._collective_utils import (
     mesh_broadcast,
     mesh_scatter,

View File

@@ -7,8 +7,9 @@ import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DeviceMesh, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -7,8 +7,8 @@ from functools import partial, wraps
 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
-import torch.distributed._tensor as dt
 import torch.distributed.distributed_c10d as c10d
+import torch.distributed.tensor as dt
 from functorch import make_fx
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck

View File

@@ -996,8 +996,6 @@ def _dtensor_init_helper(  # type: ignore[no-untyped-def]
     placements: Optional[Sequence[Placement]] = None,
     **kwargs,
 ) -> DTensor:
-    # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
-
     # if device_mesh is None, use the one from mesh resources
     device_mesh = device_mesh or _mesh_resources.get_current_mesh()
     kwargs["device"] = device_mesh.device_type

View File

@@ -1372,7 +1372,7 @@ def context_parallel(
             these buffers can be put in this list to avoid extra restore time.
 
     .. warning::
-        `torch.distributed._tensor.experimental.attention.context_parallel` is a
+        `torch.distributed.tensor.experimental.context_parallel` is a
         prototype feature in PyTorch. The API is subject to change.
     """
     buffers = [] if buffers is None else buffers

View File

@@ -14,8 +14,13 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch._utils import _get_device_module
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    Placement,
+    Replicate,
+    Shard,
+)
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

View File

@@ -10,13 +10,13 @@ import torch
 import torch.nn as nn
 from torch.distributed._sharded_tensor import ShardedTensor
 from torch.distributed._state_dict_utils import _gather_state_dict
-from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.state_dict import (
     _PG,
     _STATE,
     set_state_dict,
     StateDictOptions,
 )
+from torch.distributed.tensor import DTensor
 
 
 class VerifyStateDictMixin:
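
As a quick sanity check after this kind of migration, the public import surface exercised by the hunks above can be smoke-tested in isolation; a minimal sketch, not part of the commit, using only names that the new imports rely on:

    # illustrative smoke check for the public DTensor import surface
    from torch.distributed.device_mesh import init_device_mesh  # noqa: F401
    from torch.distributed.tensor import (  # noqa: F401
        distribute_tensor,
        DTensor,
        Partial,
        Replicate,
        Shard,
        zeros,
    )
    from torch.distributed.tensor.experimental import implicit_replication, local_map  # noqa: F401

    print("public torch.distributed.tensor import surface looks intact")

If any of these imports fail, the corresponding test file in this commit would fail at collection time as well.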