[_shard] make ShardedTensor be a Tensor and nn.Parameter (#79825)

Differential Revision: [D37707371](https://our.internmc.facebook.com/intern/diff/D37707371)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79825
Approved by: https://github.com/kumpera
Wanchao Liang
2022-07-22 00:09:54 -07:00
committed by PyTorch MergeBot
parent 38988a8d14
commit 9c32439a77
16 changed files with 463 additions and 317 deletions

View File

@ -64,11 +64,7 @@ def assert_state_dict_equal(
for key, value_1 in state_dict_1.items():
value_2 = state_dict_2[key]
if isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)
elif isinstance(value_1, ShardedTensor):
if isinstance(value_1, ShardedTensor):
for local_shard_1, local_shard_2 in zip(
value_1.local_shards(), value_2.local_shards()
):
@ -76,6 +72,10 @@ def assert_state_dict_equal(
torch.equal(local_shard_1.tensor, local_shard_2.tensor),
f"Key {key}'s shard does not match",
)
elif isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)
return True

View File

@ -63,11 +63,7 @@ def assert_state_dict_equal(
for key, value_1 in state_dict_1.items():
value_2 = state_dict_2[key]
if isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)
elif isinstance(value_1, ShardedTensor):
if isinstance(value_1, ShardedTensor):
for local_shard_1, local_shard_2 in zip(
value_1.local_shards(), value_2.local_shards()
):
@ -75,6 +71,10 @@ def assert_state_dict_equal(
torch.equal(local_shard_1.tensor, local_shard_2.tensor),
f"Key {key}'s shard does not match",
)
elif isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)
return True
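
The reordering above is needed because ShardedTensor is now a torch.Tensor subclass, so an `isinstance(value_1, torch.Tensor)` branch placed first would also catch ShardedTensor values and the shard-by-shard comparison would never run. A minimal sketch of the ordering rule (the helper below is illustrative, not part of the patch):

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor

def compare_entry(value_1, value_2) -> bool:
    # Check the more specific type first: after this change every
    # ShardedTensor is also a torch.Tensor, so the order matters.
    if isinstance(value_1, ShardedTensor):
        return all(
            torch.equal(s1.tensor, s2.tensor)
            for s1, s2 in zip(value_1.local_shards(), value_2.local_shards())
        )
    elif isinstance(value_1, torch.Tensor):
        return torch.equal(value_1, value_2)
    return value_1 == value_2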

View File

@ -13,7 +13,6 @@ from torch.distributed._shard.sharding_spec import (
)
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor
)
from torch.testing._internal.common_distributed import (
requires_nccl,
@ -35,7 +34,7 @@ class MyShardedModel(torch.nn.Module):
torch.manual_seed(0)
self.param = torch.nn.Parameter(torch.rand(5, 10))
if spec is not None:
self.sharded_param = sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group)
self.sharded_param = torch.nn.Parameter(sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group))
else:
self.sharded_param = torch.nn.Parameter(torch.rand(5, 10))
@ -110,7 +109,7 @@ class TestShardedOptimizer(ShardedTensorTestBase):
local_model.sharded_param.detach().clone().requires_grad_()
local_optim = optim.SGD(local_model.parameters(), lr=0.1)
sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
sharded_model_params = dict(sharded_model.named_parameters())
sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)
local_optim.zero_grad()
@ -163,7 +162,7 @@ class TestShardedOptimizer(ShardedTensorTestBase):
],
)
sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)
sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
sharded_model_params = dict(sharded_model.named_parameters())
param_keys = list(sharded_model_params.keys())
self.assertEqual(len(param_keys), 2)
self.assertTrue("param" in param_keys)
@ -171,7 +170,7 @@ class TestShardedOptimizer(ShardedTensorTestBase):
sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank)
sharded_linear.shard_parameter()
sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear))
sharded_linear_params = dict(sharded_linear.named_parameters())
param_keys = list(sharded_linear_params.keys())
self.assertEqual(len(param_keys), 4)
self.assertTrue("linear1.bias" in param_keys)
@ -180,9 +179,5 @@ class TestShardedOptimizer(ShardedTensorTestBase):
self.assertTrue("linear2.weight" in param_keys)
self.assertFalse("bias" in param_keys)
if __name__ == '__main__':
run_tests()
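
With ShardedTensor usable as an nn.Parameter, sharded parameters are discovered by `module.named_parameters()` just like dense ones, which is why the hunks above drop `named_params_with_sharded_tensor`. A minimal single-process sketch mirroring the test (assuming a gloo default process group and a CPU placement; the module and parameter names are illustrative):

import os
import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_optim import ShardedOptimizer
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(5, 10))
        # A ShardedTensor can now be wrapped in nn.Parameter directly.
        self.sharded_param = torch.nn.Parameter(
            sharded_tensor.rand(spec, 20, 10, requires_grad=True)
        )

model = ToyModel()
# Both the dense and the sharded parameter show up here; no special helper needed.
params = dict(model.named_parameters())
assert "param" in params and "sharded_param" in params
sharded_optim = ShardedOptimizer(params, torch.optim.SGD, lr=0.1)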

View File

@ -12,7 +12,6 @@ from torch.distributed._shard.api import (
)
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.distributed._shard.sharded_tensor import (
empty,
@ -127,7 +126,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
previous_sharded_weight = sharded_weight.clone()
previous_sharded_bias = sharded_linear.bias.clone()
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(sharded_linear)),
dict(sharded_linear.named_parameters()),
torch.optim.SGD,
lr=0.1,
)
@ -192,6 +191,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
def test_sharded_linear_errors(self):
for spec in generate_chunk_sharding_specs_for_test(0):
fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
shard_parameter(fc1, "weight", spec)
shard_parameter(fc1, "bias", spec)
with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
fc1(torch.rand(10, 10).cuda(self.rank))

View File

@ -7,7 +7,6 @@ import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.distributed._shard.api import (
shard_parameter,
@ -171,7 +170,7 @@ class TestShardedTensorMegatronLinear(ShardedTensorTestBase):
optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
optim.step()
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(sharded_megatron_lm)),
dict(sharded_megatron_lm.named_parameters()),
torch.optim.SGD,
lr=0.1,
)

View File

@ -404,6 +404,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
st_metadata = st.metadata()
self.assertEqual(torch.Size([10, 20]), st_metadata.size)
self.assertEqual(torch.Size([10, 20]), st.size())
self.assertEqual(torch.float, st.dtype)
self.assertEqual(torch.strided, st.layout)
self.assertEqual(False, st.requires_grad)
@ -432,7 +433,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
# test read only properties, they're read only as we can't simply change
# the global metadata without changing the underlying shard's properties
with self.assertRaisesRegex(AttributeError, "can't set attribute"):
with self.assertRaisesRegex(RuntimeError, "torch function '__set__'"):
st.requires_grad = True
@with_comms
@ -952,7 +953,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'):
sharded_tensor.empty(spec, 10, 20, layout=torch.sparse)
sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo)
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'):
@ -1069,11 +1070,18 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
self.assertEqual(st.size(1), 20)
# Test with negative indexed size
st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
self.assertEqual(st.size(-1), 20)
# Test with dim/ndim
self.assertEqual(st.dim(), 2)
self.assertEqual(st.ndim, 2)
# Test with invalid input
st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[-2, 2\\)'):
with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
st.size(-3)
with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[-2, 2\\)'):
with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
st.size(2)
with self.assertRaises(TypeError):
@ -1545,7 +1553,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
# CPU sharded tensor should return the same instance (no copy)
st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
new_st_cpu = st_cpu.cpu()
self.assertEqual(st_cpu, new_st_cpu)
self.assertTrue(st_cpu is new_st_cpu)
# GPU sharded tensor to cpu
st = sharded_tensor.zeros(spec, h, w)
@ -1553,7 +1561,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
spec_before_move = st.sharding_spec()
new_st = st.cpu(process_group=gloo_pg)
# return a copy of the original st
self.assertNotEqual(st, new_st)
self.assertFalse(st is new_st)
# check the spec is still ChunkShardingSpec
spec_after_move = new_st.sharding_spec()
self.assertIsInstance(spec_after_move, ChunkShardingSpec)
@ -1586,7 +1594,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
st = sharded_tensor.zeros(mixed_spec, h, w, process_group=gloo_pg)
new_st = st.cpu()
# return a copy of the original st
self.assertNotEqual(st, new_st)
self.assertFalse(st is new_st)
# check the spec is still ChunkShardingSpec
spec_after_move = new_st.sharding_spec()
self.assertIsInstance(spec_after_move, ChunkShardingSpec)
@ -1603,6 +1611,158 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
for meta in metas:
self.assertEqual(str(meta.placement.device()), "cpu")
@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_sharded_tensor_to_cuda(self):
cpu_spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cpu",
"rank:1/cpu",
"rank:2/cpu",
"rank:3/cpu",
],
)
spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)
h, w = 10, 20
# CUDA sharded tensor should return a new ShardedTensor, but same
# local shards(no movements)
st_cuda = sharded_tensor.zeros(spec, h, w)
new_st_cuda = st_cuda.cuda()
self.assertTrue(st_cuda is not new_st_cuda)
self.assertTrue(st_cuda.local_tensor() is new_st_cuda.local_tensor())
gloo_pg = dist.new_group(backend="gloo")
# CPU sharded tensor to GPU
st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
# test ability to move st to GPU
spec_before_move = st_cpu.sharding_spec()
new_st_gpu = st_cpu.cuda()
# check the spec is still ChunkShardingSpec
spec_after_move = new_st_gpu.sharding_spec()
self.assertIsInstance(spec_after_move, ChunkShardingSpec)
# test specs before and after the move almost the same except placement device
self.assertEqual(spec_before_move.dim, spec_after_move.dim)
self.assertEqual(len(spec_before_move.placements), len(spec_after_move.placements))
for i, remote_device_after in enumerate(spec_after_move.placements):
remote_device_before = spec_before_move.placements[i]
self.assertEqual(remote_device_before.rank(), remote_device_after.rank())
self.assertEqual(str(remote_device_before.device().type), "cpu")
self.assertEqual(str(remote_device_after.device().type), "cuda")
# ensure metadata also gets changed to GPU
metas = new_st_gpu.metadata().shards_metadata
for meta in metas:
self.assertEqual(str(meta.placement.device().type), "cuda")
@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_sharded_tensor_to_test(self):
spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)
h, w = 10, 20
# CUDA sharded tensor should return a new ShardedTensor, but same
# local shards(no movements)
st = sharded_tensor.zeros(spec, h, w)
# test same dtype, device return itself
st_self = st.to(dtype=st.dtype, device="cuda")
self.assertTrue(st_self is st)
# test dtype to
st_16 = st.to(torch.float16)
self.assertFalse(st_16 is st)
self.assertEqual(st_16.dtype, torch.float16)
# test device to
st_cpu = st.to(device=torch.device("cpu"))
self.assertFalse(st_cpu is st)
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
st_cuda = st_cpu.to(device=torch.device("cuda"))
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
# non-kwarg device to
st_cuda = st_cpu.to(torch.device("cuda"))
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
st_cpu = st_cuda.to(torch.device("cpu"))
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
# with string like device conversion
st_cpu = st_cuda.to("cpu")
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
st_cuda = st_cpu.to("cuda")
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
# with int like device conversion
st_cpu = st_cuda.to("cpu")
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
st_cuda = st_cpu.to(self.rank)
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
# test tensor to
cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda")
st_cuda = st.to(cuda_tensor)
self.assertFalse(st_cuda is st)
self.assertEqual(st_cuda.dtype, torch.float16)
cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda:2")
st_cuda = st.to(cuda_tensor)
self.assertEqual(st_cuda.dtype, torch.float16)
# test dtype and device together
st_cpu_16 = st.to("cpu", torch.float16)
self.assertEqual(st_cpu_16.dtype, torch.float16)
self.assertEqual(st_cpu_16.local_tensor().device.type, "cpu")
st_cuda_32 = st_cpu_16.to("cuda", torch.float32)
self.assertEqual(st_cuda_32.dtype, torch.float32)
self.assertEqual(st_cuda_32.local_tensor().device.type, "cuda")
# test pass additional process group
gloo_pg = dist.new_group(backend="gloo")
st_gloo = st.to(device="cpu", process_group=gloo_pg)
self.assertFalse(st_gloo is st)
self.assertEqual(st_gloo.local_tensor().device.type, "cpu")
self.assertEqual(st_gloo._process_group, gloo_pg)
@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_sharded_tensor_device(self):
spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)
h, w = 10, 20
# CUDA sharded tensor should return a new ShardedTensor, but same
# local shards(no movements)
st = sharded_tensor.zeros(spec, h, w)
current_device = torch.device(torch.cuda.current_device())
self.assertEqual(current_device, st.device)
# test that after moving to cpu, the device gets changed
cpu_device = torch.device("cpu")
st_cpu = st.to(device=cpu_device)
self.assertEqual(st_cpu.device, cpu_device)
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_uneven_shards(self):

View File

@ -8,7 +8,6 @@ import torch.nn as nn
import torch.distributed as dist
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.testing._internal.common_distributed import (
requires_nccl,
@ -19,7 +18,10 @@ from torch.distributed._shard.sharding_plan import ShardingPlan, ShardingPlanner
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
TEST_GPU_NUM,
ShardedTensorTestBase,
@ -169,7 +171,7 @@ class TestShardingPlan(ShardedTensorTestBase):
optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
optim.step()
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(megatron_lm)),
dict(megatron_lm.named_parameters()),
torch.optim.SGD,
lr=0.1,
)
@ -360,3 +362,6 @@ class TestShardingPlan(ShardedTensorTestBase):
if self.rank >= 2:
shard_module(megatron_lm, sharding_plan, process_group=pg)
if __name__ == "__main__":
run_tests()

View File

@ -119,15 +119,7 @@ def shard_parameter(
st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)
# Replace param with ShardedTensor.
# Need to delete the attribute first since param_name might be
# torch.nn.Parameter and can't be replaced with ShardedTensor which is
# not torch.nn.Parameter.
delattr(module, param_name)
# Now we can set the attribute appropriately.
setattr(module, param_name, st)
module.register_parameter(param_name, nn.Parameter(st))
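
Since nn.Parameter now accepts a ShardedTensor, `shard_parameter` can simply re-register the sharded weight instead of the delete-then-set workaround removed above. A small usage sketch (same single-process assumptions as the optimizer example earlier; module and sizes are arbitrary):

import torch
from torch.distributed._shard.api import shard_parameter
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumes a process group has already been initialized (see the earlier sketch).
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
linear = torch.nn.Linear(10, 10)
shard_parameter(linear, "weight", spec)

# The sharded weight remains a registered parameter of the module,
# and its value is a ShardedTensor rather than a dense tensor.
assert "weight" in dict(linear.named_parameters())
assert isinstance(linear.weight, ShardedTensor)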
def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
"""

View File

@ -20,8 +20,7 @@ class ShardedOptimizer(optim.Optimizer):
Args:
named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
of parameters, where key is the parameter key, value is either
Tensor or ShardedTensor parameter. This usually used in
conjunction with :meth:`named_params_with_sharded_tensor`
Tensor or ShardedTensor parameter.
optimizer_class (torch.optim.Optimizer): the Optimizer to use
locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
*optimizer_args: the arguments to initialize the optimizer.

View File

@ -3,6 +3,7 @@ import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
import torch.distributed._shard.sharded_tensor._ops.math_ops
import torch.distributed._shard.sharded_tensor._ops.matrix_ops
import torch.distributed._shard.sharded_tensor._ops.tensor_ops
import torch.distributed._shard.sharded_tensor._ops.misc_ops
from .binary_cmp import equal, allclose
from .init import kaiming_uniform_, normal_, uniform_, constant_

View File

@ -0,0 +1,12 @@
import torch
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
)
# This is used by `_apply()` within module.py to set new
# parameters after applying a certain method; we should follow
# the future behavior of overwriting the existing tensor
# instead of doing an in-place change using `.data = `.
@_sharded_op_impl(torch._has_compatible_shallow_copy_type)
def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None):
return False
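
Returning False here steers `nn.Module._apply()` (used by `.cuda()`, `.float()`, etc.) away from the shallow-copy/`.data =` path, so the newly computed parameter is swapped in out of place. The effect can be observed directly; a short sketch under the same single-process assumptions as before:

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumes an initialized process group as in the earlier sketches.
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
st = sharded_tensor.zeros(spec, 4, 4)

# nn.Module._apply() probes this before reusing a parameter's storage in place;
# the override registered above always answers False for a ShardedTensor.
print(torch._has_compatible_shallow_copy_type(st, torch.empty(4, 4)))  # False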

View File

@ -31,7 +31,6 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
# Tensor properties access
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined]
@ -44,6 +43,27 @@ _register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
# __reduce_ex__ to dispatch to get_state/set_state
_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
# autograd related properties
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined]
# TODO: set grad with a ShardedTensor that consists of all local grads
_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl) # type: ignore[union-attr]
_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined]
# The device property is ambiguous: from a global perspective,
# ShardedTensor.device consists of multiple devices (which might even span hosts).
# We choose to return the current device of the local tensor to represent
# the device property on each rank.
@_sharded_op_impl(torch.Tensor.device.__get__)
def tensor_device(types, args=(), kwargs=None, pg=None):
self_st = args[0]
# Validate types
if not isinstance(self_st, ShardedTensor):
raise TypeError("input needs to be a ShardedTensor")
return self_st.local_shards()[0].tensor.device
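
A quick check of what the property reports (single-rank CPU placement, assuming an initialized process group as above):

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
st = sharded_tensor.zeros(spec, 4, 4)
# .device reports the device of this rank's local shard, not a global device.
assert st.device == st.local_shards()[0].tensor.device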
def sharded_type_as_check(*args, **kwargs):
"""
Perform extra checks for the sharded_type_as op such as the input needs to
@ -167,6 +187,9 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
for local_shard in self_st.local_shards():
local_shard.tensor.requires_grad_(requires_grad)
# update the wrapper class property
with torch._C.DisableTorchFunction():
self_st.requires_grad_(requires_grad)
# update the metadata in the meanwhile
self_st._metadata.tensor_properties.requires_grad = requires_grad
return self_st
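
The `DisableTorchFunction()` guard lets the handler flip the autograd flag on the wrapper tensor itself without re-entering ShardedTensor's `__torch_function__`, which would dispatch straight back into this handler. A toy illustration of the pattern (the class below is illustrative, not ShardedTensor):

import torch

class Wrapper(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.Tensor.requires_grad_:
            print("handling requires_grad_ for the wrapper")
            with torch._C.DisableTorchFunction():
                # Inside the guard the call runs the plain Tensor implementation
                # instead of recursing into this handler again.
                return func(*args, **kwargs)
        with torch._C.DisableTorchFunction():
            return func(*args, **kwargs)

w = torch.randn(2, 2).as_subclass(Wrapper)
w.requires_grad_(True)   # prints once; no infinite recursion
assert w.requires_grad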

View File

@ -7,7 +7,6 @@ from typing import (
Optional,
Sequence,
Tuple,
Union,
cast,
)
import copy
@ -40,7 +39,6 @@ from .utils import (
build_metadata_from_local_shards,
build_global_metadata
)
from torch.overrides import handle_torch_function
from torch.distributed.remote_device import _remote_device
from torch.utils._pytree import tree_map
@ -67,9 +65,9 @@ def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]]
else:
sharded_tensor._register_remote_shards(rrefs, rpc_rank)
class ShardedTensor(object):
class ShardedTensor(torch.Tensor):
"""
ShardedTensor is an abstraction to represent Tensors that are sharded
ShardedTensor is a torch.Tensor subclass to represent Tensors that are sharded
across multiple devices and multiple processes.
ShardedTensor is initialized in an SPMD like fashion where each rank
@ -116,10 +114,47 @@ class ShardedTensor(object):
"""
def __new__(cls, *args, **kwargs):
# Use __new__ for logging purposes.
_sharding_spec: shard_spec.ShardingSpec
_metadata: ShardedTensorMetadata
def __new__(cls,
sharding_spec: shard_spec.ShardingSpec,
*size,
**kwargs):
# Use __new__ to construct a wrapper tensor, for recording tensor
# properties and logging purposes.
torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
return super(ShardedTensor, cls).__new__(cls)
# check sharding spec and build sharded tensor metadata
if not isinstance(sharding_spec, shard_spec.ShardingSpec):
raise ValueError(f'Expecting ShardingSpec but got: {type(sharding_spec)}')
sizes = _flatten_tensor_size(size)
dtype = kwargs['dtype']
layout = kwargs['layout']
pin_memory = kwargs['pin_memory']
requires_grad = kwargs['requires_grad']
if dtype is None:
dtype = torch.get_default_dtype()
tensor_properties = TensorProperties(dtype, layout, requires_grad, pin_memory=pin_memory)
sharded_tensor_metadata = sharding_spec.build_metadata(
sizes, tensor_properties=tensor_properties)
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
sizes,
dtype=dtype,
layout=layout,
pin_memory=pin_memory,
requires_grad=requires_grad
)
# set sharding spec
r._sharding_spec = sharding_spec
# set metadata
r._metadata = sharded_tensor_metadata
return r
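
`_make_wrapper_subclass` builds a tensor "shell" that records the global size, dtype, layout and requires_grad of the ShardedTensor, while the real storage stays in the per-rank local shards attached later. A stripped-down sketch of the same construction pattern (the class and attribute names below are illustrative):

import torch

class MetaOnlyTensor(torch.Tensor):
    """A wrapper subclass that carries tensor properties but owns no storage."""

    @staticmethod
    def __new__(cls, size, *, dtype=None, requires_grad=False):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            torch.Size(size),
            dtype=dtype if dtype is not None else torch.get_default_dtype(),
            layout=torch.strided,
            requires_grad=requires_grad,
        )
        # Extra bookkeeping hangs off the wrapper, much like _sharding_spec
        # and _metadata do on ShardedTensor.
        r._note = "data lives elsewhere (e.g. in local shards)"
        return r

t = MetaOnlyTensor((10, 20), dtype=torch.float32)
print(t.shape, t.dtype, t.layout, t.requires_grad)
# torch.Size([10, 20]) torch.float32 torch.strided False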
def __init__(
self,
@ -137,42 +172,25 @@ class ShardedTensor(object):
# _process_group, _local_shards, etc.
self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
tensor_properties = TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory)
if tensor_properties is None:
raise ValueError('tensor_properties must not be None.')
if tensor_properties.dtype is None:
tensor_properties.dtype = torch.get_default_dtype()
if tensor_properties.layout != torch.strided:
if layout != torch.strided:
raise ValueError('Only torch.strided layout is currently supported')
if tensor_properties.memory_format != torch.contiguous_format:
if memory_format != torch.contiguous_format:
raise ValueError('Only torch.contiguous_format memory_format is currently supported')
dims = _flatten_tensor_size(size)
if not isinstance(sharding_spec, shard_spec.ShardingSpec):
raise ValueError(f'Expecting ShardingSpec but got: {type(sharding_spec)}')
self._sharding_spec = sharding_spec
sharded_tensor_metadata = sharding_spec.build_metadata(
dims, tensor_properties=tensor_properties)
self._metadata.tensor_properties.memory_format = memory_format
current_rank = dist.get_rank(self._process_group)
for shard_metadata in sharded_tensor_metadata.shards_metadata:
for shard_metadata in self._metadata.shards_metadata:
rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement)
if rank == current_rank:
local_tensor = _create_tensor_from_params(
shard_metadata.shard_sizes,
local_device=device,
tensor_properties=sharded_tensor_metadata.tensor_properties
tensor_properties=self._metadata.tensor_properties
)
self._local_shards.append(Shard(local_tensor, shard_metadata))
self._metadata = sharded_tensor_metadata
# do post initialization (i.e. register sharded_tensor_id, initialize_rpc)
self._post_init()
@ -266,7 +284,7 @@ class ShardedTensor(object):
return torch.device(torch.cuda.current_device())
return torch.device("cpu")
def gather(
def gather( # type: ignore[override]
self,
dst: int = 0,
out: Optional[torch.Tensor] = None,
@ -407,6 +425,148 @@ class ShardedTensor(object):
)
return st_cpu
def cuda(
self,
device=None,
non_blocking=False,
memory_format=torch.preserve_format,
process_group=None
) -> ShardedTensor:
"""
Returns a copy of this object in CUDA memory. If the original ShardedTensor
is on CPU, we move each local shard to the current GPU device of its
process, in an SPMD fashion.
If this ShardedTensor is already in CUDA memory and the local shards on each
rank are already on the current device, we still return a new ShardedTensor
object with new metadata, but no underlying data movement is performed.
.. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might
need to be managed by a different type of ProcessGroup (i.e. ProcessGroupNCCL);
it is the user's responsibility to explicitly pass in a new process_group that
is compatible with GPU.
"""
if memory_format != torch.preserve_format and \
memory_format != torch.contiguous_format:
raise RuntimeError("Only `torch.contiguous_format` or "
"`torch.preserve_format` is supported!")
if device is not None:
device = torch.device(device) if isinstance(device, str) else device
assert isinstance(device, torch.device) and device.index == torch.cuda.current_device(), \
'''Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!'''
current_device = torch.device(torch.cuda.current_device())
# returns a copy of ShardedTensor on CUDA current device
list_shards: List[Shard] = []
# Move all local shards to the current device and change the metadata.
# If the local shards are already on the current device, there's no
# real data movement; only the metadata is copied.
for shard in self._local_shards:
cuda_tensor = shard.tensor.cuda(
device=current_device,
non_blocking=non_blocking,
memory_format=memory_format
) # type: ignore[call-arg]
metadata = copy.deepcopy(shard.metadata)
metadata.placement._device = current_device # type: ignore[union-attr]
list_shards.append(
Shard(cuda_tensor, metadata)
)
st_meta = copy.deepcopy(self.metadata())
for meta in st_meta.shards_metadata:
if meta.placement.device().type != "cuda": # type: ignore[union-attr]
meta.placement._device = current_device # type: ignore[union-attr]
pg = self._process_group if process_group is None else process_group
# we need to use `init_from_local_shards` to communicate between ranks
# and update the sharding spec/shards metadata.
st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata(
list_shards,
sharded_tensor_metadata=st_meta,
process_group=pg,
init_rrefs=self._init_rrefs
)
return st_cuda
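
Usage sketch for the new `cuda()` method: a CPU ShardedTensor backed by a gloo group is moved to the current GPU, with an NCCL-capable group passed in so the resulting tensor can keep communicating. This mirrors the test above and assumes a CUDA-capable rank with an NCCL default backend already initialized:

import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumes dist.init_process_group("nccl", ...) has already run on this rank.
gloo_pg = dist.new_group(backend="gloo")
nccl_pg = dist.new_group(backend="nccl")

cpu_spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
st_cpu = sharded_tensor.zeros(cpu_spec, 10, 20, process_group=gloo_pg)

# Local shards land on the current CUDA device; the NCCL-backed group takes
# over communication for the returned ShardedTensor.
st_cuda = st_cpu.cuda(process_group=nccl_pg)
print(st_cuda.local_tensor().device)        # cuda:<current device index>
print(st_cuda.sharding_spec().placements)   # placements rewritten to cuda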
def to(self, *args, **kwargs) -> ShardedTensor:
current_device = self._local_shards[0].tensor.device
current_dtype = self.dtype
device_to = current_device
dtype_to = current_dtype
if len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype_to = args[0]
elif isinstance(args[0], torch.device):
device_to = args[0]
elif isinstance(args[0], (str, int)):
device_to = torch.device(args[0])
elif isinstance(args[0], torch.Tensor):
dtype_to = args[0].dtype
device_to = args[0].device
else:
raise RuntimeError(f"ShardedTensor.to() received invalid arguments: {args}")
elif len(args) == 2:
device_to, dtype_to = args
else:
dtype_to = kwargs.get("dtype", current_dtype)
device_to = kwargs.get("device", current_device)
device_to = torch.device(device_to) if isinstance(device_to, (str, int)) else device_to
if device_to.type == "cuda":
# if device_to is set to cuda, use the current device even
# if the user specifies a device index.
current_idx = torch.cuda.current_device()
if device_to.index != current_idx:
import warnings
warnings.warn("ShardedTensor.to only moves the tensor to its current device. "
"If you want to move it to a different device, use `reshard` instead.")
device_to = torch.device(current_idx)
copy_tensor = kwargs.get("copy", False)
non_blocking = kwargs.get("non_blocking", False)
memory_format = kwargs.get("memory_format", torch.preserve_format)
process_group = kwargs.get("process_group", None)
if not copy_tensor and dtype_to == current_dtype and device_to == current_device:
# already have correct dtype and device, return itself
return self
# returns a copy of the ShardedTensor with each local shard converted to the target dtype/device
list_shards: List[Shard] = []
for shard in self._local_shards:
new_tensor = shard.tensor.to( # type: ignore[call-overload]
device=device_to,
dtype=dtype_to,
non_blocking=non_blocking,
copy=copy_tensor,
memory_format=memory_format
)
metadata = copy.deepcopy(shard.metadata)
if metadata.placement is not None:
metadata.placement._device = device_to
list_shards.append(Shard(new_tensor, metadata))
# update metadata
st_meta = copy.deepcopy(self.metadata())
st_meta.tensor_properties.dtype = dtype_to
for meta in st_meta.shards_metadata:
meta.placement._device = device_to # type: ignore[union-attr]
pg = self._process_group if process_group is None else process_group
# we need to use `init_from_local_shards` to communicate between ranks
# and update the sharding spec/shards metadata.
st_to = ShardedTensor._init_from_local_shards_and_global_metadata(
list_shards,
sharded_tensor_metadata=st_meta,
process_group=pg,
init_rrefs=self._init_rrefs
)
return st_to
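
The `to()` override mirrors `torch.Tensor.to`: it accepts a dtype, a device, another tensor to take dtype/device from, or both, returns `self` when nothing would change, and takes an extra `process_group` for moves that need a differently backed group. A CPU-only sketch (gloo default group assumed, as in the earlier examples):

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
st = sharded_tensor.zeros(spec, 10, 20)

# Same dtype and device: no copy, the original object comes back.
assert st.to(dtype=st.dtype, device="cpu") is st

# dtype-only conversion returns a new ShardedTensor with converted shards.
st_half = st.to(torch.float16)
assert st_half is not st
assert st_half.dtype == torch.float16
assert st_half.local_tensor().dtype == torch.float16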
@classmethod
def _init_from_local_shards(
cls,
@ -446,18 +606,24 @@ class ShardedTensor(object):
gathered_metadatas = [local_sharded_tensor_metadata]
global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
tensor_properties = global_sharded_tensor_metadata.tensor_properties
# STEP 3: Validation done, create the actual ShardedTensor and populate fields
# prepare initialization
sharded_tensor = cls.__new__(cls)
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
# add to metadata and local_shards
sharded_tensor._metadata = global_sharded_tensor_metadata
sharded_tensor._local_shards = local_shards
sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(
spec = shard_spec._infer_sharding_spec_from_shards_metadata(
global_sharded_tensor_metadata.shards_metadata
)
sharded_tensor = cls.__new__(cls,
spec,
global_sharded_tensor_metadata.size,
dtype=tensor_properties.dtype,
layout=tensor_properties.layout,
pin_memory=tensor_properties.pin_memory,
requires_grad=tensor_properties.requires_grad)
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
# attach local_shards to the ShardedTensor created
sharded_tensor._local_shards = local_shards
# run post initialization, i.e. map registration, rpc initialization
sharded_tensor._post_init()
@ -598,10 +764,19 @@ class ShardedTensor(object):
if tensor_properties.layout != torch.strided:
raise ValueError('Only torch.strided layout is currently supported')
sharded_tensor = cls.__new__(cls)
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
if sharding_spec is None:
spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
else:
spec = sharding_spec
sharded_tensor._metadata = sharded_tensor_metadata
sharded_tensor = cls.__new__(cls,
spec,
sharded_tensor_metadata.size,
dtype=tensor_properties.dtype,
layout=tensor_properties.layout,
pin_memory=tensor_properties.pin_memory,
requires_grad=tensor_properties.requires_grad)
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
local_shard_metadatas = []
@ -655,11 +830,6 @@ class ShardedTensor(object):
# done validation, add local_shards
sharded_tensor._local_shards = local_shards
if sharding_spec is None:
sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
else:
sharded_tensor._sharding_spec = sharding_spec
# run post initialization, i.e. map registration, rpc initialization
sharded_tensor._post_init()
return sharded_tensor
@ -825,6 +995,11 @@ class ShardedTensor(object):
f"torch function '{func.__name__}', with args: {args} and "
f"kwargs: {kwargs} not supported for ShardedTensor!")
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise RuntimeError(f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} "
"but there is no custom __torch_dispatch__ implementation for it.")
def metadata(self) -> ShardedTensorMetadata:
"""
Returns a :class:`ShardedTensorMetadata` object corresponding to the
@ -840,186 +1015,12 @@ class ShardedTensor(object):
"""
return self._local_shards
def size(self, dim: int = None) -> Union[torch.Size, int]:
"""
Returns a :Union:`[torch.Size, int]` which represents the size of the tensor.
The dimension can be specified.
Args:
dim (int, optional): the dimension over which the size represents.
If specified, it returns the size of the given dimension.
If not, it returns a subclass of tuple.
Default: ``None``
Returns:
A :Union:`[torch.Size, int]` represents the size of the tensor.
"""
size = self._metadata.size
if dim is None:
return size
if dim < -len(size) or dim >= len(size):
raise ValueError(
"Argument ``dim`` must be within the range of tensor "
f"dimensions [-{len(size)}, {len(size)})"
)
return size[dim]
def is_pinned(self) -> bool:
def is_pinned(self) -> bool: # type: ignore[override]
"""
Returns True if the sharded tensor (each local shard) resides in pinned memory.
"""
return self._metadata.tensor_properties.pin_memory
def is_contiguous(self) -> bool:
"""
Returns True if the sharded tensor (each local shard) is contiguous in memory
in the order specified by memory format.
"""
return self._metadata.tensor_properties.memory_format == torch.contiguous_format
def dim(self) -> int:
"""
Returns a `int` which represents the dimension of the tensor.
Returns:
A `int` represents the dimension of the tensor.
"""
return len(self._metadata.size)
# TODO: This op needs further definition of what exactly its behavior will be.
def contiguous(self) -> ShardedTensor:
"""
Returns a new sharded tensor with the local tensor is made to contiguous.
"""
if self.is_contiguous():
return self
local_shards = []
for shard in self.local_shards():
local_shards.append(
Shard(shard.tensor.contiguous(), shard.metadata)
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards,
self._metadata,
process_group=self._process_group,
init_rrefs=self._init_rrefs,
)
def masked_fill(self, mask, value) -> ShardedTensor:
"""
Returns a new sharded tensor with each shard has been filled elements
with value where mask is True. The shape of mask must be broadcastable
with the shape of the underlying tensor.
Args:
mask (BoolTensor): the boolean mask.
value (float): the value to fill in with.
Returns:
A :class:`ShardedTensor` object whose shards have been applied masked_fill.
"""
return handle_torch_function(
torch.Tensor.masked_fill, (self, mask, value), self, mask, value
)
def type_as(self, tensor) -> ShardedTensor:
"""
Returns a new sharded tensor with each shard has been
cast to the type of the given tensor.
Args:
tensor (Tensor): the tensor which has the desired type.
Returns:
A :class:`ShardedTensor` object whose shards have been applied type_as.
"""
return handle_torch_function(torch.Tensor.type_as, (self, tensor), self, tensor)
def view(self, *shape) -> ShardedTensor:
"""
Returns a new sharded tensor with the same data as the
self tensor but of a different shape for its local tensor.
For now, we only support to pass through the view op to the local
tensor.
Args:
shape (torch.Size or int...) the desired size.
Returns:
A :class:`ShardedTensor` object whose shards have been applied
with view to its local tensor.
"""
return handle_torch_function(torch.Tensor.view, (self, *shape), self, *shape)
def transpose(self, dim0, dim1) -> ShardedTensor:
"""
Returns a new sharded tensor with the given dimensions transposed.
During the transpose, we keep the original shading dim, e.g., if the
tensor is sharded by dim 0 and if we call transpose(1, 0). The returned
tensor will be sharded by dim 1.
Args:
dim0 (int): the first dimension to be transposed.
dim1 (int): the second dimension to be transposed.
Returns:
A :class:`ShardedTensor` object whose dims have been transposed
specified in the input.
"""
return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1)
def bmm(self, st2, *, out=None) -> ShardedTensor:
"""
Performs a batch matrix-matrix product of matrices stored in self and st2.
Warning: For now we only supports the case when both tensors are sharded
by dim 0 so that no communication is needed.
Args:
st2 (ShardedTensor) the second batch of sharded matrices to be multiplied.
Returns:
A :class:`ShardedTensor` object which is the result of the batch multiplication.
"""
return handle_torch_function(torch.Tensor.bmm, (self, st2, out), self, st2, out=out)
def chunk(self, chunks, dim=0) -> List[ShardedTensor]:
"""
Attempts to split a tensor into the specified number of chunks.
Each chunk is a view of the input tensor.
Warnings: Chunk by the sharding dim is not supported.
Args:
chunks (int) number of chunks to return
dim (int) dimension along which to split the tensor
Returns:
A List of :class:`ShardedTensor` object chunked on dims.
"""
return handle_torch_function(torch.Tensor.chunk, (self, chunks, dim), self, chunks, dim=dim)
@property
def shape(self):
return self._metadata.size
@property
def requires_grad(self):
return self._metadata.tensor_properties.requires_grad
def requires_grad_(self, requires_grad=True):
return handle_torch_function(torch.Tensor.requires_grad_, (self, requires_grad), self, requires_grad)
@property
def dtype(self):
return self._metadata.tensor_properties.dtype
@property
def layout(self):
return self._metadata.tensor_properties.layout
def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int):
self._remote_shards[rpc_rank] = remote_shards
@ -1043,45 +1044,6 @@ class ShardedTensor(object):
def __repr__(self):
return f'ShardedTensor({self._metadata})'
def __add__(self, other):
return handle_torch_function(torch.Tensor.__add__, (self, other), self, other)
def __radd__(self, other):
return handle_torch_function(torch.Tensor.__radd__, (self, other), self, other)
def __sub__(self, other):
return handle_torch_function(torch.Tensor.__sub__, (self, other), self, other)
def __rsub__(self, other):
return handle_torch_function(torch.Tensor.__rsub__, (self, other), self, other)
def __mul__(self, other):
return handle_torch_function(torch.Tensor.__mul__, (self, other), self, other)
def __rmul__(self, other):
return handle_torch_function(torch.Tensor.__rmul__, (self, other), self, other)
def __truediv__(self, other):
return handle_torch_function(torch.Tensor.__div__, (self, other), self, other)
def __rtruediv__(self, other):
return handle_torch_function(torch.Tensor.__rdiv__, (self, other), self, other)
def tanh(self):
return handle_torch_function(torch.Tensor.tanh, (self,), self)
def __getitem__(self, key):
return handle_torch_function(torch.Tensor.__getitem__, (self, key), self, key)
def __deepcopy__(self, memo):
return handle_torch_function(torch.Tensor.__deepcopy__, (self, memo), self, memo)
def clone(self, *, memory_format=torch.preserve_format):
return handle_torch_function(torch.Tensor.clone, (self,), self, memory_format=memory_format)
def detach(self):
return handle_torch_function(torch.Tensor.detach, (self,), self)
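
The block of hand-written forwarding methods removed above becomes unnecessary once ShardedTensor subclasses torch.Tensor: `__add__`, `clone`, `detach`, indexing and the rest are found on the Tensor base class, and each call funnels through `ShardedTensor.__torch_function__`, which looks up a handler registered via `_sharded_op_impl` or raises the "not supported" error. A toy model of that dispatch flow (illustrative, not the real implementation):

import torch

HANDLERS = {}

def register(op):
    def deco(fn):
        HANDLERS[op] = fn
        return fn
    return deco

class Dispatching(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLERS:                      # op with a custom handler
            return HANDLERS[func](*args, **kwargs)
        with torch._C.DisableTorchFunction():     # otherwise fall back to Tensor
            return func(*args, **kwargs)

@register(torch.Tensor.add)
def _custom_add(self, other, *, alpha=1):
    print("custom add handler")
    with torch._C.DisableTorchFunction():
        return torch.Tensor.add(self, other, alpha=alpha)

x = torch.ones(2).as_subclass(Dispatching)
y = x.add(x)          # the method comes from torch.Tensor, dispatch lands here
print(y.tolist())     # [2.0, 2.0]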
@dataclass
class ProcessGroupState:
"""

View File

@ -1,7 +1,5 @@
# coding=utf-8
from typing import cast
import torch
import torch.distributed as dist
from ._common import (
@ -158,7 +156,7 @@ def _validate_embedding_param(args, kwargs):
raise TypeError("input need to be torch.Tensor")
if not isinstance(weight, ShardedTensor):
raise TypeError("weight needs to be ShardedTensor")
weight_size = cast(torch.Size, weight.size())
weight_size = weight.size()
if len(weight_size) != 2:
raise ValueError("Weight needs to have exactly 2 dims")
if int(torch.min(input).item()) < 0:

View File

@ -204,7 +204,7 @@ def _validate_embedding_bag_param(args, kwargs):
raise TypeError("weight needs to be ShardedTensor")
if len(input.size()) > 2:
raise ValueError("Input more than 2 dims not supported")
weight_size = cast(torch.Size, weight.size())
weight_size = weight.size()
if len(weight_size) != 2:
raise ValueError("Weight needs to have exactly 2 dims")
if int(torch.min(input).item()) < 0:

View File

@ -1,4 +1,4 @@
from typing import List, cast
from typing import List
import torch
import torch.distributed as dist
@ -105,14 +105,14 @@ def sharded_linear(types, args, kwargs, pg):
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
if sharding_dim == 1 and isinstance(input, torch.Tensor):
return _handle_row_wise_sharding_tensor(
input, world_size, weight, rank, local_shard_t, bias, pg
)
elif sharding_dim == 1 and isinstance(input, ShardedTensor):
if sharding_dim == 1 and isinstance(input, ShardedTensor):
return _handle_row_wise_sharding_sharded_tensor(
input, world_size, weight, local_shard_t, bias, pg
)
elif sharding_dim == 1 and isinstance(input, torch.Tensor):
return _handle_row_wise_sharding_tensor(
input, world_size, weight, rank, local_shard_t, bias, pg
)
elif sharding_dim == 0:
return _handle_col_wise_sharding(
input, world_size, weight, rank, local_shard_t, bias, pg
@ -125,7 +125,7 @@ def sharded_linear(types, args, kwargs, pg):
def _validate_linear_op_param(args, kwargs):
"""
Validate input params of sharded embedding op.
Validate input params of sharded linear op.
Args:
input: input of the linear layer.
@ -141,13 +141,13 @@ def _validate_linear_op_param(args, kwargs):
# Validate types
if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor):
raise TypeError("input needs to be either torch.Tensor or ShardedTensor")
if not isinstance(bias, torch.Tensor):
if type(bias) != torch.Tensor and type(bias) != torch.nn.Parameter:
raise TypeError("bias needs to be torch.Tensor")
if not isinstance(weight, ShardedTensor):
raise TypeError("weight needs to be ShardedTensor")
if len(input.size()) < 1: # type: ignore[arg-type]
raise ValueError("Input needs to have at least 1 dim")
weight_size = cast(torch.Size, weight.size())
weight_size = weight.size()
if len(weight_size) != 2:
raise ValueError("Weight needs to have exactly 2 dims")
if len(bias.size()) != 1:
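
The switch from `isinstance` to an exact `type()` comparison above is deliberate: a sharded bias is now itself a torch.Tensor subclass, so `isinstance(bias, torch.Tensor)` would accept it and the `bias needs to be torch.Tensor` error exercised by the test earlier in this commit would never fire. A small illustration with a toy subclass (not the real module code):

import torch

class FakeSharded(torch.Tensor):
    """Stand-in for a ShardedTensor-like subclass."""

dense_bias = torch.zeros(10)
param_bias = torch.nn.Parameter(torch.zeros(10))
sharded_bias = torch.zeros(10).as_subclass(FakeSharded)

def bias_ok(bias):
    # Exact type check: plain tensors and parameters pass, subclasses do not.
    return type(bias) == torch.Tensor or type(bias) == torch.nn.Parameter

assert bias_ok(dense_bias) and bias_ok(param_bias)
assert isinstance(sharded_bias, torch.Tensor)   # isinstance would let it through
assert not bias_ok(sharded_bias)                # the exact type check rejects it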