Revert "Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)"

This reverts commit 8d4c8df33a58dc5f905dfdcee1cd124a79aaf2d8.

Reverted https://github.com/pytorch/pytorch/pull/77825 on behalf of https://github.com/janeyx99 because it will break multigpu test reporting
PyTorch MergeBot
2022-05-20 17:59:03 +00:00
parent 9d44b3d110
commit 0f74b44f1a
12 changed files with 46 additions and 87 deletions


@@ -28,27 +28,4 @@ time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_checkpoint
time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_file_system_checkpoint
time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan
time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_megatron_prototype
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor_reshard
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_chunk
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_elementwise_ops
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding_bag
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_binary_cmp
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_init
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_linear
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_math_ops
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_matrix_ops
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_softmax
time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
time python test/run_test.py --verbose -i distributed/_shard/sharded_optim/test_sharded_optim
time python test/run_test.py --verbose -i distributed/_shard/test_partial_tensor
time python test/run_test.py --verbose -i distributed/_shard/test_replicated_tensor
assert_git_not_dirty


@@ -101,8 +101,8 @@ class TestShardedTensorMatrixOps(ShardedTensorTestBase):
enumerable_spec, 10, 10, init_rrefs=False, dtype=torch.double
)
with self.assertRaisesRegex(
RuntimeError,
"not supported",
NotImplementedError,
"Only ChunkShardingSpec supported for 'transpose'",
):
st.transpose(1, 0)


@@ -19,7 +19,7 @@ from torch.distributed._shard.api import (
_reshard_output,
)
from torch.distributed._shard.sharded_tensor import (
custom_sharded_op_impl,
sharded_op_impl,
pre_load_state_dict_hook,
state_dict_hook,
ShardedTensor,
@@ -174,7 +174,7 @@ class TestShardParameter(ShardedTensorTestBase):
with self.assertRaisesRegex(ValueError, 'does not match with src_rank'):
shard_parameter(fc, 'weight', spec, src_rank=self.rank)
with self.assertRaisesRegex(AttributeError, 'has no attribute'):
with self.assertRaisesRegex(AttributeError, 'Linear have no attribute'):
shard_parameter(fc, 'foo', spec)
with self.assertRaisesRegex(ValueError, 'Expected Linear.bias to be a Tensor, but found str'):
@@ -2463,7 +2463,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
@requires_nccl()
def test_custom_op(self):
@custom_sharded_op_impl(torch.asin)
@sharded_op_impl(torch.asin)
def my_sharded_asin(types, args, kwargs, process_group):
return torch.asin(args[0].local_shards()[0].tensor)
@@ -2491,7 +2491,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
def my_sharded_linear(types, args, kwargs, process_group):
def my_sharded_linear(types, args, kwargs):
return t
spec = ChunkShardingSpec(
@@ -2515,12 +2515,12 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
def test_custom_op_errors(self):
with self.assertRaisesRegex(TypeError, 'expects signature'):
@custom_sharded_op_impl(torch.nn.functional.linear)
@sharded_op_impl(torch.nn.functional.linear)
def my_op1(types, args, kwargs, process_group, random_param):
pass
with self.assertRaisesRegex(TypeError, 'expects signature'):
@custom_sharded_op_impl(torch.nn.functional.linear)
@sharded_op_impl(torch.nn.functional.linear)
def my_op2(types):
pass


@@ -201,8 +201,6 @@ WINDOWS_BLOCKLIST = [
"distributed/pipeline/sync/test_worker",
"distributed/elastic/agent/server/test/api_test",
"distributed/elastic/multiprocessing/api_test",
"distributed/_shard/checkpoint/test_checkpoint"
"distributed/_shard/checkpoint/test_file_system_checkpoint"
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharding_plan/test_sharding_plan",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
@@ -218,6 +216,8 @@ WINDOWS_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharded_tensor/ops/test_softmax",
"distributed/_shard/sharded_tensor/ops/test_tensor_ops",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor",
"distributed/_shard/test_replicated_tensor",
@@ -228,8 +228,6 @@ ROCM_BLOCKLIST = [
"distributed/rpc/test_faulty_agent",
"distributed/rpc/test_tensorpipe_agent",
"distributed/rpc/cuda/test_tensorpipe_agent",
"distributed/_shard/checkpoint/test_checkpoint"
"distributed/_shard/checkpoint/test_file_system_checkpoint"
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharding_plan/test_sharding_plan",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
@@ -245,6 +243,8 @@ ROCM_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharded_tensor/ops/test_softmax",
"distributed/_shard/sharded_tensor/ops/test_tensor_ops",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor",
"distributed/_shard/test_replicated_tensor",


@@ -9,7 +9,6 @@ import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed._shard.partial_tensor import _PartialTensor
from .api import (
_CUSTOM_SHARDED_OPS,
_SHARDED_OPS,
Shard,
ShardedTensor,
@@ -412,7 +411,7 @@ def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict,
if isinstance(state_dict[key], ShardedTensor):
setattr(submodule, attr_name, state_dict[key])
def custom_sharded_op_impl(func):
def sharded_op_impl(func):
"""
Provides a way for users to write their own custom sharded operator. This
can be used to override existing ShardedTensor operators or write a new
@@ -421,7 +420,7 @@ def custom_sharded_op_impl(func):
parameters, the function provided will be invoked for that operator.
Example::
>>> @custom_sharded_op_impl(torch.nn.functional.linear)
>>> @sharded_op_impl(torch.nn.functional.linear)
>>> def my_custom_sharded_linear(types, args, kwargs, process_group):
>>> ....
>>>
@@ -442,16 +441,6 @@ def custom_sharded_op_impl(func):
func(Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.nn.functional.linear)
"""
return functools.partial(
_decorator_func,
op=func,
op_table=_CUSTOM_SHARDED_OPS
)
def _sharded_op_impl(func):
"""
Decorator to register a default sharded op.
"""
return functools.partial(
_decorator_func,
op=func,
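For orientation, a minimal sketch of the custom-operator registration that this revert restores, based on the decorator docstring above and the test_custom_op case earlier in the diff. The handler mirrors that test; the ShardedTensor `st` in the closing comment is an assumption (it must already exist on an initialized process group).

import torch
from torch.distributed._shard.sharded_tensor import sharded_op_impl

# Register a handler for torch.asin on ShardedTensor inputs. After the
# revert, the public decorator is `sharded_op_impl` and handlers take
# (types, args, kwargs, process_group), which is the signature the
# 'expects signature' TypeError tests earlier in this diff check for.
@sharded_op_impl(torch.asin)
def my_sharded_asin(types, args, kwargs, process_group):
    # Operate on the first local shard only, as in the test case;
    # a production implementation would handle every local shard.
    return torch.asin(args[0].local_shards()[0].tensor)

# With the handler registered, torch.asin(st) on a ShardedTensor `st`
# dispatches here via ShardedTensor.__torch_function__.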


@@ -1,6 +1,6 @@
import functools
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
sharded_op_impl,
Shard,
ShardedTensor,
)
@@ -13,7 +13,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):
Example::
>>> op = torch.transpose
>>> @_sharded_op_impl(op)
>>> @sharded_op_impl(op)
>>> @_sharded_op_common(op, early_stop_func, extra_check)
>>> def sharded_tensor_op(types, args, kwargs, process_group):
>>> ....
@@ -82,7 +82,7 @@ def _register_sharded_op_on_local_shards(
func (Callable): registered implementation for sharded op for
``__torch_function__`` dispatch.
"""
@_sharded_op_impl(op)
@sharded_op_impl(op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
st = args[0]


@@ -3,7 +3,7 @@ import torch.distributed as dist
import torch.distributed.distributed_c10d as distributed_c10d
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
_sharded_op_impl
sharded_op_impl
)
def _communicate_result(result, pg):
@@ -59,10 +59,10 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
return _communicate_result(True, st1._process_group)
@_sharded_op_impl(torch.equal)
@sharded_op_impl(torch.equal)
def equal(types, args, kwargs, process_group):
return binary_cmp(torch.equal, types, args, kwargs, process_group)
@_sharded_op_impl(torch.allclose)
@sharded_op_impl(torch.allclose)
def allclose(types, args, kwargs, process_group):
return binary_cmp(torch.allclose, types, args, kwargs, process_group)
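As a usage note, a hedged sketch of how these restored registrations are reached. It assumes `st1` and `st2` are ShardedTensors built with matching ChunkShardingSpec placements on an initialized process group (the names are illustrative, not from this commit):

import torch

def sharded_tensors_equal(st1, st2):
    # torch.equal on ShardedTensor arguments is intercepted by
    # ShardedTensor.__torch_function__ and routed to the
    # @sharded_op_impl(torch.equal) handler above, which delegates to
    # binary_cmp and combines the per-rank results via _communicate_result.
    return torch.equal(st1, st2)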


@@ -1,13 +1,13 @@
import torch
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
sharded_op_impl,
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
def register_chunk_op(op):
@_sharded_op_impl(op)
@sharded_op_impl(op)
def sharded_chunk(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the chunk op.


@@ -1,14 +1,14 @@
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
sharded_op_impl,
)
def validate_param(param, param_name):
if param is None:
raise ValueError(f"param: {param_name} shouldn't be None!")
@_sharded_op_impl(torch.nn.init.uniform_)
@sharded_op_impl(torch.nn.init.uniform_)
def uniform_(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
@@ -30,7 +30,7 @@ def uniform_(types, args=(), kwargs=None, pg=None):
torch.nn.init.uniform_(shard.tensor, a=a, b=b)
return sharded_tensor
@_sharded_op_impl(torch.nn.init.normal_)
@sharded_op_impl(torch.nn.init.normal_)
def normal_(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensors in sharded_tensor.local_shards with values drawn from the normal
@@ -52,7 +52,7 @@ def normal_(types, args=(), kwargs=None, pg=None):
torch.nn.init.normal_(shard.tensor, mean=mean, std=std)
return sharded_tensor
@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
@sharded_op_impl(torch.nn.init.kaiming_uniform_)
def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensors in sharded_tensor.local_shards with values according to the method
@@ -88,7 +88,7 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity)
return sharded_tensor
@_sharded_op_impl(torch.nn.init.constant_)
@sharded_op_impl(torch.nn.init.constant_)
def constant_(types, args=(), kwargs=None, pg=None):
r"""
Fills the input ShardedTensor with the value \text{val}.
@@ -116,7 +116,7 @@ tensor_like_creation_op_map = {
# tensor ops that behave the same as the default tensor
def register_tensor_creation_op(op):
@_sharded_op_impl(op)
@sharded_op_impl(op)
def tensor_creation_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for tensor creation ops that


@@ -2,7 +2,7 @@ import torch
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
_sharded_op_impl
sharded_op_impl
)
from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard._utils import narrow_tensor
@@ -74,7 +74,7 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
f"kwargs: {kwargs} not supported yet for ShardedTensor!")
def register_math_op(op):
@_sharded_op_impl(op)
@sharded_op_impl(op)
def binary_math_op(types, args=(), kwargs=None, pg=None):
return binary_math_op_impl(op, types, args, kwargs, pg)


@@ -1,7 +1,7 @@
import copy
import torch
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
sharded_op_impl,
Shard,
ShardedTensor,
)
@@ -10,7 +10,7 @@ from ._common import (
)
from torch.distributed._shard.common_op_utils import _register_default_op
@_sharded_op_impl(torch.Tensor.__deepcopy__)
@sharded_op_impl(torch.Tensor.__deepcopy__)
def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
# NOTE: we directly implement deepcopy magic method
# instead of using the default tensor.__deepcopy__
@@ -31,18 +31,18 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
# Tensor properties access
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, _sharded_op_impl)
_register_default_op(torch.Tensor.dim, _sharded_op_impl)
_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
_register_default_op(torch.Tensor.requires_grad.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, sharded_op_impl)
_register_default_op(torch.Tensor.dim, sharded_op_impl)
_register_default_op(torch.Tensor.ndim.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, sharded_op_impl)
_register_default_op(torch.Tensor.contiguous, sharded_op_impl)
# __reduce_ex__ to dispatch to get_state/set_state
_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
_register_default_op(torch.Tensor.__reduce_ex__, sharded_op_impl)
def sharded_type_as_check(*args, **kwargs):
"""
@@ -153,7 +153,7 @@ _register_sharded_op_on_local_shards(
customized_func=sharded_detach,
)
@_sharded_op_impl(torch.Tensor.requires_grad_)
@sharded_op_impl(torch.Tensor.requires_grad_)
def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
self_st = args[0]
requires_grad = args[1]


@@ -11,8 +11,8 @@ from typing import (
cast,
)
import copy
from functools import reduce
import weakref
import math
import threading
import torch
@@ -49,12 +49,9 @@ _sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0
_sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}
# Default sharded ops
# Custom sharded ops
_SHARDED_OPS: Dict[Callable, Callable] = {}
# Customized user ops
_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int):
with _sharded_tensor_lock:
if sharded_tensor_id not in _sharded_tensor_map:
@@ -287,7 +284,7 @@ class ShardedTensor(object):
Default: ``None``
"""
def shard_size(shard_md):
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
return math.prod(shard_md.shard_sizes) # type: ignore[attr-defined]
rank = dist.get_rank(self._process_group)
full_size = self.metadata().size
@@ -785,10 +782,6 @@
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
def dispatch(st: ShardedTensor, func: Callable):
# Dispatch to custom user provided op first if it exists.
if func in _CUSTOM_SHARDED_OPS:
return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)
# Dispatch to custom sharding spec op if it has one.
if _has_custom_op(st._sharding_spec, func):
return _dispatch_custom_op(