Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)"
This reverts commit 8d4c8df33a58dc5f905dfdcee1cd124a79aaf2d8. Reverted https://github.com/pytorch/pytorch/pull/77825 on behalf of https://github.com/janeyx99 because it will break multigpu test reporting.
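For orientation before the diff: the decorator restored by this revert is sharded_op_impl, which #77825 had renamed to custom_sharded_op_impl while splitting the registry into separate default and custom tables. A minimal usage sketch in the restored spelling, modeled on the test_custom_op case in the diff below (the handler body is illustrative only, not the author's code):

    import torch
    from torch.distributed._shard.sharded_tensor import sharded_op_impl, ShardedTensor

    # Register a sharded override for torch.asin; handlers receive
    # (types, args, kwargs, process_group), matching the restored docstring.
    @sharded_op_impl(torch.asin)
    def my_sharded_asin(types, args, kwargs, process_group):
        st: ShardedTensor = args[0]
        # Apply the op to the local shard, as the test in the diff does.
        return torch.asin(st.local_shards()[0].tensor)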
@@ -28,27 +28,4 @@ time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
 time python test/run_test.py --verbose -i distributed/test_store
 time python test/run_test.py --verbose -i distributed/test_pg_wrapper
 time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
-time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_checkpoint
-time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_file_system_checkpoint
-time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
-time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan
-time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
-time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_megatron_prototype
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor_reshard
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_chunk
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_elementwise_ops
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding_bag
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_binary_cmp
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_init
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_linear
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_math_ops
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_matrix_ops
-time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_softmax
-time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
-time python test/run_test.py --verbose -i distributed/_shard/sharded_optim/test_sharded_optim
-time python test/run_test.py --verbose -i distributed/_shard/test_partial_tensor
-time python test/run_test.py --verbose -i distributed/_shard/test_replicated_tensor
 assert_git_not_dirty
@@ -101,8 +101,8 @@ class TestShardedTensorMatrixOps(ShardedTensorTestBase):
             enumerable_spec, 10, 10, init_rrefs=False, dtype=torch.double
         )
         with self.assertRaisesRegex(
-            RuntimeError,
-            "not supported",
+            NotImplementedError,
+            "Only ChunkShardingSpec supported for 'transpose'",
         ):
             st.transpose(1, 0)

@@ -19,7 +19,7 @@ from torch.distributed._shard.api import (
     _reshard_output,
 )
 from torch.distributed._shard.sharded_tensor import (
-    custom_sharded_op_impl,
+    sharded_op_impl,
     pre_load_state_dict_hook,
     state_dict_hook,
     ShardedTensor,
@@ -174,7 +174,7 @@ class TestShardParameter(ShardedTensorTestBase):
         with self.assertRaisesRegex(ValueError, 'does not match with src_rank'):
             shard_parameter(fc, 'weight', spec, src_rank=self.rank)

-        with self.assertRaisesRegex(AttributeError, 'has no attribute'):
+        with self.assertRaisesRegex(AttributeError, 'Linear have no attribute'):
             shard_parameter(fc, 'foo', spec)

         with self.assertRaisesRegex(ValueError, 'Expected Linear.bias to be a Tensor, but found str'):
@@ -2463,7 +2463,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
     @requires_nccl()
     def test_custom_op(self):

-        @custom_sharded_op_impl(torch.asin)
+        @sharded_op_impl(torch.asin)
         def my_sharded_asin(types, args, kwargs, process_group):
             return torch.asin(args[0].local_shards()[0].tensor)

@@ -2491,7 +2491,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
         from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op

         @custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
-        def my_sharded_linear(types, args, kwargs, process_group):
+        def my_sharded_linear(types, args, kwargs):
             return t

         spec = ChunkShardingSpec(
@@ -2515,12 +2515,12 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
     def test_custom_op_errors(self):

         with self.assertRaisesRegex(TypeError, 'expects signature'):
-            @custom_sharded_op_impl(torch.nn.functional.linear)
+            @sharded_op_impl(torch.nn.functional.linear)
             def my_op1(types, args, kwargs, process_group, random_param):
                 pass

         with self.assertRaisesRegex(TypeError, 'expects signature'):
-            @custom_sharded_op_impl(torch.nn.functional.linear)
+            @sharded_op_impl(torch.nn.functional.linear)
             def my_op2(types):
                 pass

@@ -201,8 +201,6 @@ WINDOWS_BLOCKLIST = [
     "distributed/pipeline/sync/test_worker",
     "distributed/elastic/agent/server/test/api_test",
     "distributed/elastic/multiprocessing/api_test",
-    "distributed/_shard/checkpoint/test_checkpoint"
-    "distributed/_shard/checkpoint/test_file_system_checkpoint"
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharding_plan/test_sharding_plan",
     "distributed/_shard/sharded_tensor/test_megatron_prototype",
@@ -218,6 +216,8 @@ WINDOWS_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_math_ops",
     "distributed/_shard/sharded_tensor/ops/test_matrix_ops",
     "distributed/_shard/sharded_tensor/ops/test_softmax",
+    "distributed/_shard/sharded_tensor/ops/test_tensor_ops",
+    "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
     "distributed/_shard/test_partial_tensor",
     "distributed/_shard/test_replicated_tensor",
@@ -228,8 +228,6 @@ ROCM_BLOCKLIST = [
     "distributed/rpc/test_faulty_agent",
     "distributed/rpc/test_tensorpipe_agent",
     "distributed/rpc/cuda/test_tensorpipe_agent",
-    "distributed/_shard/checkpoint/test_checkpoint"
-    "distributed/_shard/checkpoint/test_file_system_checkpoint"
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharding_plan/test_sharding_plan",
     "distributed/_shard/sharded_tensor/test_megatron_prototype",
@@ -245,6 +243,8 @@ ROCM_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_math_ops",
     "distributed/_shard/sharded_tensor/ops/test_matrix_ops",
     "distributed/_shard/sharded_tensor/ops/test_softmax",
+    "distributed/_shard/sharded_tensor/ops/test_tensor_ops",
+    "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
     "distributed/_shard/test_partial_tensor",
     "distributed/_shard/test_replicated_tensor",
@@ -9,7 +9,6 @@ import torch.distributed._shard.sharding_spec as shard_spec
 from torch.distributed._shard.partial_tensor import _PartialTensor

 from .api import (
-    _CUSTOM_SHARDED_OPS,
     _SHARDED_OPS,
     Shard,
     ShardedTensor,
@@ -412,7 +411,7 @@ def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict,
                 if isinstance(state_dict[key], ShardedTensor):
                     setattr(submodule, attr_name, state_dict[key])

-def custom_sharded_op_impl(func):
+def sharded_op_impl(func):
     """
     Provides a way for users to write their own custom sharded operator. This
     can be used to override existing ShardedTensor operators or write a new
@@ -421,7 +420,7 @@ def custom_sharded_op_impl(func):
     parameters, the function provided will be invoked for that operator.

     Example::
-        >>> @custom_sharded_op_impl(torch.nn.functional.linear)
+        >>> @sharded_op_impl(torch.nn.functional.linear)
         >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
         >>>   ....
         >>>
@@ -442,16 +441,6 @@ def custom_sharded_op_impl(func):
         func(Callable): Torch function for which we want to provide a sharded
             implementation (ex: torch.nn.functional.linear)
     """
-    return functools.partial(
-        _decorator_func,
-        op=func,
-        op_table=_CUSTOM_SHARDED_OPS
-    )
-
-def _sharded_op_impl(func):
-    """
-    Decorator to register a default sharded op.
-    """
     return functools.partial(
         _decorator_func,
         op=func,
@@ -1,6 +1,6 @@
 import functools
 from torch.distributed._shard.sharded_tensor import (
-    _sharded_op_impl,
+    sharded_op_impl,
     Shard,
     ShardedTensor,
 )
@@ -13,7 +13,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):

     Example::
         >>> op = torch.transpose
-        >>> @_sharded_op_impl(op)
+        >>> @sharded_op_impl(op)
         >>> @_sharded_op_common(op, early_stop_func, extra_check)
         >>> def sharded_tensor_op(types, args, kwargs, process_group):
         >>>   ....
@@ -82,7 +82,7 @@ def _register_sharded_op_on_local_shards(
         func (Callable): registered implementation for sharded op for
             ``__torch_function__`` dispatch.
     """
-    @_sharded_op_impl(op)
+    @sharded_op_impl(op)
     @_sharded_op_common(op, early_stop_func, extra_check)
     def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
         st = args[0]
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import torch.distributed.distributed_c10d as distributed_c10d
 from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
-    _sharded_op_impl
+    sharded_op_impl
 )

 def _communicate_result(result, pg):
@@ -59,10 +59,10 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):

     return _communicate_result(True, st1._process_group)

-@_sharded_op_impl(torch.equal)
+@sharded_op_impl(torch.equal)
 def equal(types, args, kwargs, process_group):
     return binary_cmp(torch.equal, types, args, kwargs, process_group)

-@_sharded_op_impl(torch.allclose)
+@sharded_op_impl(torch.allclose)
 def allclose(types, args, kwargs, process_group):
     return binary_cmp(torch.allclose, types, args, kwargs, process_group)
@@ -1,13 +1,13 @@
 import torch
 from torch.distributed._shard.sharded_tensor import (
-    _sharded_op_impl,
+    sharded_op_impl,
     ShardedTensor,
 )
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec


 def register_chunk_op(op):
-    @_sharded_op_impl(op)
+    @sharded_op_impl(op)
     def sharded_chunk(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for the chunk op.
@@ -1,14 +1,14 @@
 import torch
 import torch.distributed._shard.sharded_tensor as sharded_tensor
 from torch.distributed._shard.sharded_tensor import (
-    _sharded_op_impl,
+    sharded_op_impl,
 )

 def validate_param(param, param_name):
     if param is None:
         raise ValueError(f"param: {param_name} shouldn't be None!")

-@_sharded_op_impl(torch.nn.init.uniform_)
+@sharded_op_impl(torch.nn.init.uniform_)
 def uniform_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
@@ -30,7 +30,7 @@ def uniform_(types, args=(), kwargs=None, pg=None):
         torch.nn.init.uniform_(shard.tensor, a=a, b=b)
     return sharded_tensor

-@_sharded_op_impl(torch.nn.init.normal_)
+@sharded_op_impl(torch.nn.init.normal_)
 def normal_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensors in sharded_tensor.local_shards with values drawn from the normal
@@ -52,7 +52,7 @@ def normal_(types, args=(), kwargs=None, pg=None):
         torch.nn.init.normal_(shard.tensor, mean=mean, std=std)
     return sharded_tensor

-@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
+@sharded_op_impl(torch.nn.init.kaiming_uniform_)
 def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensors in sharded_tensor.local_shards with values according to the method
@@ -88,7 +88,7 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
         torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity)
     return sharded_tensor

-@_sharded_op_impl(torch.nn.init.constant_)
+@sharded_op_impl(torch.nn.init.constant_)
 def constant_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the input ShardedTensor with the value \text{val}val.
@@ -116,7 +116,7 @@ tensor_like_creation_op_map = {

 # tensor ops that behave the same as the default tensor
 def register_tensor_creation_op(op):
-    @_sharded_op_impl(op)
+    @sharded_op_impl(op)
     def tensor_creation_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for tensor creation ops that
@@ -2,7 +2,7 @@ import torch
 from torch import Tensor
 from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
-    _sharded_op_impl
+    sharded_op_impl
 )
 from torch.distributed._shard.replicated_tensor import ReplicatedTensor
 from torch.distributed._shard._utils import narrow_tensor
@@ -74,7 +74,7 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
             f"kwargs: {kwargs} not supported yet for ShardedTensor!")

 def register_math_op(op):
-    @_sharded_op_impl(op)
+    @sharded_op_impl(op)
     def binary_math_op(types, args=(), kwargs=None, pg=None):
         return binary_math_op_impl(op, types, args, kwargs, pg)

@@ -1,7 +1,7 @@
 import copy
 import torch
 from torch.distributed._shard.sharded_tensor import (
-    _sharded_op_impl,
+    sharded_op_impl,
     Shard,
     ShardedTensor,
 )
@@ -10,7 +10,7 @@ from ._common import (
 )
 from torch.distributed._shard.common_op_utils import _register_default_op

-@_sharded_op_impl(torch.Tensor.__deepcopy__)
+@sharded_op_impl(torch.Tensor.__deepcopy__)
 def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
     # NOTE: we directly implement deepcopy magic method
     # instead of using the default tensor.__deepcopy__
@@ -31,18 +31,18 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None):


 # Tensor properties access
-_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.size, _sharded_op_impl)
-_register_default_op(torch.Tensor.dim, _sharded_op_impl)
-_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
-_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
+_register_default_op(torch.Tensor.requires_grad.__get__, sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.shape.__get__, sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.dtype.__get__, sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.layout.__get__, sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.size, sharded_op_impl)
+_register_default_op(torch.Tensor.dim, sharded_op_impl)
+_register_default_op(torch.Tensor.ndim.__get__, sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.is_contiguous, sharded_op_impl)
+_register_default_op(torch.Tensor.contiguous, sharded_op_impl)

 # __reduce_ex__ to dispatch to get_state/set_state
-_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
+_register_default_op(torch.Tensor.__reduce_ex__, sharded_op_impl)

 def sharded_type_as_check(*args, **kwargs):
     """
@@ -153,7 +153,7 @@ _register_sharded_op_on_local_shards(
     customized_func=sharded_detach,
 )

-@_sharded_op_impl(torch.Tensor.requires_grad_)
+@sharded_op_impl(torch.Tensor.requires_grad_)
 def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
     self_st = args[0]
     requires_grad = args[1]
@@ -11,8 +11,8 @@ from typing import (
     cast,
 )
 import copy
-from functools import reduce
 import weakref
+import math

 import threading
 import torch
@@ -49,12 +49,9 @@ _sharded_tensor_lock = threading.Lock()
 _sharded_tensor_current_id = 0
 _sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}

-# Default sharded ops
+# Custom sharded ops
 _SHARDED_OPS: Dict[Callable, Callable] = {}
-
-# Customized user ops
-_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}

 def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int):
     with _sharded_tensor_lock:
         if sharded_tensor_id not in _sharded_tensor_map:
@@ -287,7 +284,7 @@ class ShardedTensor(object):
                 Default: ``None``
         """
         def shard_size(shard_md):
-            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  # type: ignore[attr-defined]
+            return math.prod(shard_md.shard_sizes)  # type: ignore[attr-defined]

         rank = dist.get_rank(self._process_group)
         full_size = self.metadata().size
@@ -785,10 +782,6 @@ class ShardedTensor(object):
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         def dispatch(st: ShardedTensor, func: Callable):
-            # Dispatch to custom user provided op first if it exists.
-            if func in _CUSTOM_SHARDED_OPS:
-                return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)
-
             # Dispatch to custom sharding spec op if it has one.
             if _has_custom_op(st._sharding_spec, func):
                 return _dispatch_custom_op(
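Taken together, the hunks above put registration back onto the single _SHARDED_OPS table that ShardedTensor.__torch_function__ consults. A rough, self-contained sketch of that flow, using only the names visible in the diff (the real api.py routes registration through _decorator_func and also checks sharding-spec-level custom ops, which this sketch omits):

    from typing import Callable, Dict

    # Single registry after the revert; the separate _CUSTOM_SHARDED_OPS table
    # introduced by #77825 is gone.
    _SHARDED_OPS: Dict[Callable, Callable] = {}

    def sharded_op_impl(func: Callable) -> Callable:
        # Decorator: record 'wrapped' as the sharded implementation of 'func'.
        def decorator(wrapped: Callable) -> Callable:
            _SHARDED_OPS[func] = wrapped
            return wrapped
        return decorator

    def dispatch(func, types, args=(), kwargs=None, process_group=None):
        # Roughly what __torch_function__ does: look the op up in the one table
        # and invoke the handler with the process group.
        if func in _SHARDED_OPS:
            return _SHARDED_OPS[func](types, args, kwargs, process_group)
        raise RuntimeError(f"{func} not supported for ShardedTensor")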