Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 00:54:56 +08:00)
[_shard] make ShardedTensor be a Tensor and nn.Parameter (#79825)
Differential Revision: [D37707371](https://our.internmc.facebook.com/intern/diff/D37707371)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79825
Approved by: https://github.com/kumpera
Commit 9c32439a77 (parent 38988a8d14), committed by PyTorch MergeBot.
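With this change a ShardedTensor is a `torch.Tensor` subclass, so it can be wrapped in `torch.nn.Parameter` and discovered through the standard `nn.Module` APIs instead of the old `named_params_with_sharded_tensor` helper. A minimal sketch of the new usage, mirroring the test changes below (it assumes a process group with at least two CUDA ranks has already been initialized, e.g. via `torchrun`; the module and spec names here are illustrative, not part of the patch):

```python
import torch
import torch.nn as nn
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

class MyShardedModule(nn.Module):
    def __init__(self, spec):
        super().__init__()
        # ShardedTensor is now a torch.Tensor subclass, so it can be registered
        # directly as an nn.Parameter like any dense tensor.
        self.sharded_param = nn.Parameter(
            sharded_tensor.rand(spec, 20, 10, requires_grad=True)
        )

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
module = MyShardedModule(spec)

# The sharded parameter shows up through the regular nn.Module API,
# so named_params_with_sharded_tensor() is no longer needed.
assert "sharded_param" in dict(module.named_parameters())
```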
@@ -64,11 +64,7 @@ def assert_state_dict_equal(
for key, value_1 in state_dict_1.items():
value_2 = state_dict_2[key]
if isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)
elif isinstance(value_1, ShardedTensor):
if isinstance(value_1, ShardedTensor):
for local_shard_1, local_shard_2 in zip(
value_1.local_shards(), value_2.local_shards()
):
@@ -76,6 +72,10 @@ def assert_state_dict_equal(
torch.equal(local_shard_1.tensor, local_shard_1.tensor),
f"Key {key}'s shard does not match",
)
elif isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)

return True
@@ -63,11 +63,7 @@ def assert_state_dict_equal(
for key, value_1 in state_dict_1.items():
value_2 = state_dict_2[key]
if isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)
elif isinstance(value_1, ShardedTensor):
if isinstance(value_1, ShardedTensor):
for local_shard_1, local_shard_2 in zip(
value_1.local_shards(), value_2.local_shards()
):
@@ -75,6 +71,10 @@ def assert_state_dict_equal(
torch.equal(local_shard_1.tensor, local_shard_1.tensor),
f"Key {key}'s shard does not match",
)
elif isinstance(value_1, torch.Tensor):
self.assertTrue(
torch.equal(value_1, value_2), f"Key {key}'s tensor does not match"
)

return True
@@ -13,7 +13,6 @@ from torch.distributed._shard.sharding_spec import (
)
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor
)
from torch.testing._internal.common_distributed import (
requires_nccl,
@@ -35,7 +34,7 @@ class MyShardedModel(torch.nn.Module):
torch.manual_seed(0)
self.param = torch.nn.Parameter(torch.rand(5, 10))
if spec is not None:
self.sharded_param = sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group)
self.sharded_param = torch.nn.Parameter(sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group))
else:
self.sharded_param = torch.nn.Parameter(torch.rand(5, 10))
@@ -110,7 +109,7 @@ class TestShardedOptimizer(ShardedTensorTestBase):
local_model.sharded_param.detach().clone().requires_grad_()

local_optim = optim.SGD(local_model.parameters(), lr=0.1)
sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
sharded_model_params = dict(sharded_model.named_parameters())
sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

local_optim.zero_grad()
@@ -163,7 +162,7 @@ class TestShardedOptimizer(ShardedTensorTestBase):
],
)
sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)
sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
sharded_model_params = dict(sharded_model.named_parameters())
param_keys = list(sharded_model_params.keys())
self.assertEqual(len(param_keys), 2)
self.assertTrue("param" in param_keys)
@@ -171,7 +170,7 @@ class TestShardedOptimizer(ShardedTensorTestBase):

sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank)
sharded_linear.shard_parameter()
sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear))
sharded_linear_params = dict(sharded_linear.named_parameters())
param_keys = list(sharded_linear_params.keys())
self.assertEqual(len(param_keys), 4)
self.assertTrue("linear1.bias" in param_keys)
@@ -180,9 +179,5 @@ class TestShardedOptimizer(ShardedTensorTestBase):
self.assertTrue("linear2.weight" in param_keys)
self.assertFalse("bias" in param_keys)

if __name__ == '__main__':
run_tests()
@@ -12,7 +12,6 @@ from torch.distributed._shard.api import (
)
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.distributed._shard.sharded_tensor import (
empty,
@@ -127,7 +126,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
previous_sharded_weight = sharded_weight.clone()
previous_sharded_bias = sharded_linear.bias.clone()
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(sharded_linear)),
dict(sharded_linear.named_parameters()),
torch.optim.SGD,
lr=0.1,
)
@@ -192,6 +191,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
def test_sharded_linear_errors(self):
for spec in generate_chunk_sharding_specs_for_test(0):
fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
shard_parameter(fc1, "weight", spec)
shard_parameter(fc1, "bias", spec)
with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
fc1(torch.rand(10, 10).cuda(self.rank))
@@ -7,7 +7,6 @@ import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.distributed._shard.api import (
shard_parameter,
@@ -171,7 +170,7 @@ class TestShardedTensorMegatronLinear(ShardedTensorTestBase):
optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
optim.step()
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(sharded_megatron_lm)),
dict(sharded_megatron_lm.named_parameters()),
torch.optim.SGD,
lr=0.1,
)
@ -404,6 +404,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||
st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
|
||||
st_metadata = st.metadata()
|
||||
self.assertEqual(torch.Size([10, 20]), st_metadata.size)
|
||||
self.assertEqual(torch.Size([10, 20]), st.size())
|
||||
self.assertEqual(torch.float, st.dtype)
|
||||
self.assertEqual(torch.strided, st.layout)
|
||||
self.assertEqual(False, st.requires_grad)
|
||||
@ -432,7 +433,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||
|
||||
# test read only properties, they're read only as we can't simply change
|
||||
# the global metadata without changing the underlying shard's properties
|
||||
with self.assertRaisesRegex(AttributeError, "can't set attribute"):
|
||||
with self.assertRaisesRegex(RuntimeError, "torch function '__set__'"):
|
||||
st.requires_grad = True
|
||||
|
||||
@with_comms
|
||||
@ -952,7 +953,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||
|
||||
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
|
||||
with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'):
|
||||
sharded_tensor.empty(spec, 10, 20, layout=torch.sparse)
|
||||
sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo)
|
||||
|
||||
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
|
||||
with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'):
|
||||
@ -1069,11 +1070,18 @@ class TestShardedTensorChunked(ShardedTensorTestBase):
|
||||
st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
|
||||
self.assertEqual(st.size(1), 20)
|
||||
|
||||
# Test with negative indexed size
|
||||
st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
|
||||
self.assertEqual(st.size(-1), 20)
|
||||
|
||||
# Test with dim/ndim
|
||||
self.assertEqual(st.dim(), 2)
|
||||
self.assertEqual(st.ndim, 2)
|
||||
# Test with invalid input
|
||||
st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
|
||||
with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[-2, 2\\)'):
|
||||
with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
|
||||
st.size(-3)
|
||||
with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[-2, 2\\)'):
|
||||
with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
|
||||
st.size(2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
@ -1545,7 +1553,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
|
||||
# CPU sharded tensor should return the same instance (no copy)
|
||||
st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
|
||||
new_st_cpu = st_cpu.cpu()
|
||||
self.assertEqual(st_cpu, new_st_cpu)
|
||||
self.assertTrue(st_cpu is new_st_cpu)
|
||||
|
||||
# GPU sharded tensor to cpu
|
||||
st = sharded_tensor.zeros(spec, h, w)
|
||||
@ -1553,7 +1561,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
|
||||
spec_before_move = st.sharding_spec()
|
||||
new_st = st.cpu(process_group=gloo_pg)
|
||||
# return a copy of the original st
|
||||
self.assertNotEqual(st, new_st)
|
||||
self.assertFalse(st is new_st)
|
||||
# check the spec is still ChunkShardingSpec
|
||||
spec_after_move = new_st.sharding_spec()
|
||||
self.assertIsInstance(spec_after_move, ChunkShardingSpec)
|
||||
@ -1586,7 +1594,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
|
||||
st = sharded_tensor.zeros(mixed_spec, h, w, process_group=gloo_pg)
|
||||
new_st = st.cpu()
|
||||
# return a copy of the original st
|
||||
self.assertNotEqual(st, new_st)
|
||||
self.assertFalse(st is new_st)
|
||||
# check the spec is still ChunkShardingSpec
|
||||
spec_after_move = new_st.sharding_spec()
|
||||
self.assertIsInstance(spec_after_move, ChunkShardingSpec)
|
||||
@ -1603,6 +1611,158 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase):
|
||||
for meta in metas:
|
||||
self.assertEqual(str(meta.placement.device()), "cpu")
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@requires_nccl()
|
||||
def test_sharded_tensor_to_cuda(self):
|
||||
cpu_spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
"rank:0/cpu",
|
||||
"rank:1/cpu",
|
||||
"rank:2/cpu",
|
||||
"rank:3/cpu",
|
||||
],
|
||||
)
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
"rank:2/cuda:2",
|
||||
"rank:3/cuda:3",
|
||||
],
|
||||
)
|
||||
h, w = 10, 20
|
||||
# CUDA sharded tensor should return a new ShardedTensor, but same
|
||||
# local shards(no movements)
|
||||
st_cuda = sharded_tensor.zeros(spec, h, w)
|
||||
new_st_cuda = st_cuda.cuda()
|
||||
self.assertTrue(st_cuda is not new_st_cuda)
|
||||
self.assertTrue(st_cuda.local_tensor() is new_st_cuda.local_tensor())
|
||||
|
||||
gloo_pg = dist.new_group(backend="gloo")
|
||||
|
||||
# CPU sharded tensor to GPU
|
||||
st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
|
||||
# test ability to move st to GPU
|
||||
spec_before_move = st_cpu.sharding_spec()
|
||||
new_st_gpu = st_cpu.cuda()
|
||||
# check the spec is still ChunkShardingSpec
|
||||
spec_after_move = new_st_gpu.sharding_spec()
|
||||
self.assertIsInstance(spec_after_move, ChunkShardingSpec)
|
||||
# test specs before and after the move almost the same except placement device
|
||||
self.assertEqual(spec_before_move.dim, spec_after_move.dim)
|
||||
self.assertEqual(len(spec_before_move.placements), len(spec_after_move.placements))
|
||||
for i, remote_device_after in enumerate(spec_after_move.placements):
|
||||
remote_device_before = spec_before_move.placements[i]
|
||||
self.assertEqual(remote_device_before.rank(), remote_device_after.rank())
|
||||
self.assertEqual(str(remote_device_before.device().type), "cpu")
|
||||
self.assertEqual(str(remote_device_after.device().type), "cuda")
|
||||
|
||||
# ensure metadata also gets changed to GPU
|
||||
metas = new_st_gpu.metadata().shards_metadata
|
||||
for meta in metas:
|
||||
self.assertEqual(str(meta.placement.device().type), "cuda")
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@requires_nccl()
|
||||
def test_sharded_tensor_to_test(self):
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
"rank:2/cuda:2",
|
||||
"rank:3/cuda:3",
|
||||
],
|
||||
)
|
||||
h, w = 10, 20
|
||||
# CUDA sharded tensor should return a new ShardedTensor, but same
|
||||
# local shards(no movements)
|
||||
st = sharded_tensor.zeros(spec, h, w)
|
||||
# test same dtype, device return itself
|
||||
st_self = st.to(dtype=st.dtype, device="cuda")
|
||||
self.assertTrue(st_self is st)
|
||||
|
||||
# test dtype to
|
||||
st_16 = st.to(torch.float16)
|
||||
self.assertFalse(st_16 is st)
|
||||
self.assertEqual(st_16.dtype, torch.float16)
|
||||
# test device to
|
||||
st_cpu = st.to(device=torch.device("cpu"))
|
||||
self.assertFalse(st_cpu is st)
|
||||
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
|
||||
st_cuda = st_cpu.to(device=torch.device("cuda"))
|
||||
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
|
||||
# non-kwarg device to
|
||||
st_cuda = st_cpu.to(torch.device("cuda"))
|
||||
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
|
||||
st_cpu = st_cuda.to(torch.device("cpu"))
|
||||
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
|
||||
# with string like device conversion
|
||||
st_cpu = st_cuda.to("cpu")
|
||||
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
|
||||
st_cuda = st_cpu.to("cuda")
|
||||
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
|
||||
# with int like device conversion
|
||||
st_cpu = st_cuda.to("cpu")
|
||||
self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
|
||||
st_cuda = st_cpu.to(self.rank)
|
||||
self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
|
||||
|
||||
# test tensor to
|
||||
cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda")
|
||||
st_cuda = st.to(cuda_tensor)
|
||||
self.assertFalse(st_cuda is st)
|
||||
self.assertEqual(st_cuda.dtype, torch.float16)
|
||||
|
||||
cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda:2")
|
||||
st_cuda = st.to(cuda_tensor)
|
||||
self.assertEqual(st_cuda.dtype, torch.float16)
|
||||
|
||||
# test dtype and device together
|
||||
st_cpu_16 = st.to("cpu", torch.float16)
|
||||
self.assertEqual(st_cpu_16.dtype, torch.float16)
|
||||
self.assertEqual(st_cpu_16.local_tensor().device.type, "cpu")
|
||||
|
||||
st_cuda_32 = st_cpu_16.to("cuda", torch.float32)
|
||||
self.assertEqual(st_cuda_32.dtype, torch.float32)
|
||||
self.assertEqual(st_cuda_32.local_tensor().device.type, "cuda")
|
||||
|
||||
# test pass additional process group
|
||||
gloo_pg = dist.new_group(backend="gloo")
|
||||
st_gloo = st.to(device="cpu", process_group=gloo_pg)
|
||||
self.assertFalse(st_gloo is st)
|
||||
self.assertEqual(st_gloo.local_tensor().device.type, "cpu")
|
||||
self.assertEqual(st_gloo._process_group, gloo_pg)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@requires_nccl()
|
||||
def test_sharded_tensor_device(self):
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
"rank:2/cuda:2",
|
||||
"rank:3/cuda:3",
|
||||
],
|
||||
)
|
||||
h, w = 10, 20
|
||||
# CUDA sharded tensor should return a new ShardedTensor, but same
|
||||
# local shards(no movements)
|
||||
st = sharded_tensor.zeros(spec, h, w)
|
||||
current_device = torch.device(torch.cuda.current_device())
|
||||
self.assertEqual(current_device, st.device)
|
||||
|
||||
# test after to cpu, device get changed
|
||||
cpu_device = torch.device("cpu")
|
||||
st_cpu = st.to(device=cpu_device)
|
||||
self.assertEqual(st_cpu.device, cpu_device)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@requires_nccl()
|
||||
def test_uneven_shards(self):
|
||||
|
||||
@ -8,7 +8,6 @@ import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._shard.sharded_optim import (
|
||||
ShardedOptimizer,
|
||||
named_params_with_sharded_tensor,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl,
|
||||
@ -19,7 +18,10 @@ from torch.distributed._shard.sharding_plan import ShardingPlan, ShardingPlanner
|
||||
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
|
||||
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.common_utils import (
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
run_tests,
|
||||
)
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
TEST_GPU_NUM,
|
||||
ShardedTensorTestBase,
|
||||
@ -169,7 +171,7 @@ class TestShardingPlan(ShardedTensorTestBase):
|
||||
optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
|
||||
optim.step()
|
||||
sharded_optim = ShardedOptimizer(
|
||||
dict(named_params_with_sharded_tensor(megatron_lm)),
|
||||
dict(megatron_lm.named_parameters()),
|
||||
torch.optim.SGD,
|
||||
lr=0.1,
|
||||
)
|
||||
@ -360,3 +362,6 @@ class TestShardingPlan(ShardedTensorTestBase):
|
||||
|
||||
if self.rank >= 2:
|
||||
shard_module(megatron_lm, sharding_plan, process_group=pg)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@@ -119,15 +119,7 @@ def shard_parameter(
st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

# Replace param with ShardedTensor.

# Need to delete the attribute first since param_name might be
# torch.nn.Parameter and can't be replaced with ShardedTensor which is
# not torch.nn.Parameter.
delattr(module, param_name)

# Now we can set the attribute appropriately.
setattr(module, param_name, st)

module.register_parameter(param_name, nn.Parameter(st))

def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
"""
@@ -20,8 +20,7 @@ class ShardedOptimizer(optim.Optimizer):
Args:
named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
of parameters, where key is the parameter key, value is either
Tensor or ShardedTensor parameter. This usually used in
conjunction with :meth:`named_params_with_sharded_tensor`
Tensor or ShardedTensor parameter.
optimizer_class (torch.optim.Optimizer): the Optimizer to use
locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
*optimizer_args: the arguments to initialize the optimizer.
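After this change the optimizer no longer needs the `named_params_with_sharded_tensor` helper; the module's regular parameter dict is enough. A short usage sketch mirroring the updated tests (it assumes `sharded_model` is an existing `nn.Module` holding `nn.Parameter`-wrapped ShardedTensors inside an already-initialized process group):

```python
import torch.optim as optim
from torch.distributed._shard.sharded_optim import ShardedOptimizer

# `sharded_model` is assumed to exist and to contain ShardedTensor parameters.
sharded_optim = ShardedOptimizer(
    dict(sharded_model.named_parameters()),  # replaces named_params_with_sharded_tensor()
    optim.SGD,
    lr=0.1,
)
sharded_optim.zero_grad()
# ... forward/backward pass on sharded_model ...
sharded_optim.step()
```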
@@ -3,6 +3,7 @@ import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
import torch.distributed._shard.sharded_tensor._ops.math_ops
import torch.distributed._shard.sharded_tensor._ops.matrix_ops
import torch.distributed._shard.sharded_tensor._ops.tensor_ops
import torch.distributed._shard.sharded_tensor._ops.misc_ops

from .binary_cmp import equal, allclose
from .init import kaiming_uniform_, normal_, uniform_, constant_
torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import torch
from torch.distributed._shard.sharded_tensor import (
    _sharded_op_impl,
)

# This is used by `_apply()` within module.py to set new
# parameters after applying a certain method; we should follow
# the future behavior of overwriting the existing tensor
# instead of doing an in-place change using `.data = `.
@_sharded_op_impl(torch._has_compatible_shallow_copy_type)
def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None):
    return False
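This override is what `nn.Module._apply()` consults when deciding whether a parameter's storage can be swapped in place; returning False forces it to rebuild the parameter instead of assigning to `.data`. A rough, hypothetical illustration (`st` stands for an existing ShardedTensor created inside an initialized process group):

```python
# Hypothetical: `st` is an existing ShardedTensor parameter.
# Because of the registration above, the shallow-copy compatibility check
# reports False, so nn.Module._apply() re-registers a new parameter rather
# than mutating the old one in place.
assert torch._has_compatible_shallow_copy_type(st, torch.empty(1)) is False
```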
@@ -31,7 +31,6 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None):


# Tensor properties access
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined]
@@ -44,6 +43,27 @@ _register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
# __reduce_ex__ to dispatch to get_state/set_state
_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)

# autograd related properties
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined]
# TODO: set grad with a ShardedTensor that consists of all local grads
_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl) # type: ignore[union-attr]
_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined]

# The device property is ambiguous from a global perspective:
# ShardedTensor.device consists of multiple devices (possibly even across hosts).
# We choose to return the current device of the local tensor to represent
# the device property on each rank.
@_sharded_op_impl(torch.Tensor.device.__get__)
def tensor_device(types, args=(), kwargs=None, pg=None):
    self_st = args[0]
    # Validate types
    if not isinstance(self_st, ShardedTensor):
        raise TypeError("input needs to be a ShardedTensor")

    return self_st.local_shards()[0].tensor.device
def sharded_type_as_check(*args, **kwargs):
|
||||
"""
|
||||
Perform extra checks for the sharded_type_as op such as the input needs to
|
||||
@ -167,6 +187,9 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
|
||||
for local_shard in self_st.local_shards():
|
||||
local_shard.tensor.requires_grad_(requires_grad)
|
||||
|
||||
# update the wrapper class property
|
||||
with torch._C.DisableTorchFunction():
|
||||
self_st.requires_grad_(requires_grad)
|
||||
# update the metadata in the meanwhile
|
||||
self_st._metadata.tensor_properties.requires_grad = requires_grad
|
||||
return self_st
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
import copy
|
||||
@ -40,7 +39,6 @@ from .utils import (
|
||||
build_metadata_from_local_shards,
|
||||
build_global_metadata
|
||||
)
|
||||
from torch.overrides import handle_torch_function
|
||||
from torch.distributed.remote_device import _remote_device
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
@ -67,9 +65,9 @@ def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]]
|
||||
else:
|
||||
sharded_tensor._register_remote_shards(rrefs, rpc_rank)
|
||||
|
||||
class ShardedTensor(object):
class ShardedTensor(torch.Tensor):
"""
ShardedTensor is an abstraction to represent Tensors that are sharded
ShardedTensor is a torch.Tensor subclass to represent Tensors that are sharded
across multiple devices and multiple processes.

ShardedTensor is initialized in an SPMD like fashion where each rank
@ -116,10 +114,47 @@ class ShardedTensor(object):
|
||||
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Use __new__ for logging purposes.
|
||||
_sharding_spec: shard_spec.ShardingSpec
|
||||
_metadata: ShardedTensorMetadata
|
||||
|
||||
def __new__(cls,
|
||||
sharding_spec: shard_spec.ShardingSpec,
|
||||
*size,
|
||||
**kwargs):
|
||||
# Use __new__ to construct a wrapper tensor, for recording tensor
|
||||
# properties and logging purposes.
|
||||
torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
|
||||
return super(ShardedTensor, cls).__new__(cls)
|
||||
|
||||
# check sharding spec and build sharded tensor metadata
|
||||
if not isinstance(sharding_spec, shard_spec.ShardingSpec):
|
||||
raise ValueError(f'Expecting ShardingSpec but got: {type(sharding_spec)}')
|
||||
|
||||
sizes = _flatten_tensor_size(size)
|
||||
dtype = kwargs['dtype']
|
||||
layout = kwargs['layout']
|
||||
pin_memory = kwargs['pin_memory']
|
||||
requires_grad = kwargs['requires_grad']
|
||||
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
tensor_properties = TensorProperties(dtype, layout, requires_grad, pin_memory=pin_memory)
|
||||
sharded_tensor_metadata = sharding_spec.build_metadata(
|
||||
sizes, tensor_properties=tensor_properties)
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
cls,
|
||||
sizes,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
pin_memory=pin_memory,
|
||||
requires_grad=requires_grad
|
||||
)
|
||||
# set sharding spec
|
||||
r._sharding_spec = sharding_spec
|
||||
# set metadata
|
||||
r._metadata = sharded_tensor_metadata
|
||||
return r
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -137,42 +172,25 @@ class ShardedTensor(object):
|
||||
# _process_group, _local_shards, etc.
|
||||
self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
|
||||
|
||||
tensor_properties = TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory)
|
||||
|
||||
if tensor_properties is None:
|
||||
raise ValueError('tensor_properties must not be None.')
|
||||
|
||||
if tensor_properties.dtype is None:
|
||||
tensor_properties.dtype = torch.get_default_dtype()
|
||||
|
||||
if tensor_properties.layout != torch.strided:
|
||||
if layout != torch.strided:
|
||||
raise ValueError('Only torch.strided layout is currently supported')
|
||||
|
||||
if tensor_properties.memory_format != torch.contiguous_format:
|
||||
if memory_format != torch.contiguous_format:
|
||||
raise ValueError('Only torch.contiguous_format memory_format is currently supported')
|
||||
|
||||
dims = _flatten_tensor_size(size)
|
||||
|
||||
if not isinstance(sharding_spec, shard_spec.ShardingSpec):
|
||||
raise ValueError(f'Expecting ShardingSpec but got: {type(sharding_spec)}')
|
||||
|
||||
self._sharding_spec = sharding_spec
|
||||
|
||||
sharded_tensor_metadata = sharding_spec.build_metadata(
|
||||
dims, tensor_properties=tensor_properties)
|
||||
self._metadata.tensor_properties.memory_format = memory_format
|
||||
|
||||
current_rank = dist.get_rank(self._process_group)
|
||||
|
||||
for shard_metadata in sharded_tensor_metadata.shards_metadata:
|
||||
for shard_metadata in self._metadata.shards_metadata:
|
||||
rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement)
|
||||
if rank == current_rank:
|
||||
local_tensor = _create_tensor_from_params(
|
||||
shard_metadata.shard_sizes,
|
||||
local_device=device,
|
||||
tensor_properties=sharded_tensor_metadata.tensor_properties
|
||||
tensor_properties=self._metadata.tensor_properties
|
||||
)
|
||||
self._local_shards.append(Shard(local_tensor, shard_metadata))
|
||||
self._metadata = sharded_tensor_metadata
|
||||
|
||||
# do post initialization (i.e. register sharded_tensor_id, initialize_rpc)
|
||||
self._post_init()
|
||||
@ -266,7 +284,7 @@ class ShardedTensor(object):
|
||||
return torch.device(torch.cuda.current_device())
|
||||
return torch.device("cpu")
|
||||
|
||||
def gather(
|
||||
def gather( # type: ignore[override]
|
||||
self,
|
||||
dst: int = 0,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
@ -407,6 +425,148 @@ class ShardedTensor(object):
|
||||
)
|
||||
return st_cpu
|
||||
|
||||
def cuda(
|
||||
self,
|
||||
device=None,
|
||||
non_blocking=False,
|
||||
memory_format=torch.preserve_format,
|
||||
process_group=None
|
||||
) -> ShardedTensor:
|
||||
"""
|
||||
Returns a copy of this object in CUDA memory, if the original ShardedTensor
|
||||
is on CPU, we will move the local shard to the current GPU device of each
|
||||
process in a SPMD fashion.
|
||||
If this ShardedTensor is already on CUDA memory and local shards on each rank are
|
||||
already on current device, we still returns a new ShardedTensor object with new
|
||||
metadata, but no underlying data movements are performed.
|
||||
.. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might
|
||||
need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL),
|
||||
it is the user's responsiblity to explicitly pass in a new process_group that
|
||||
is compatible with GPU.
|
||||
"""
|
||||
if memory_format != torch.preserve_format and \
|
||||
memory_format != torch.contiguous_format:
|
||||
raise RuntimeError("Only `torch.contiguous_format` or "
|
||||
"`torch.preserve_format` is supported!")
|
||||
|
||||
if device is not None:
|
||||
device = torch.device(device) if isinstance(device, str) else device
|
||||
assert isinstance(device, torch.device) and device.index == torch.cuda.current_device(), \
|
||||
'''Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!'''
|
||||
|
||||
current_device = torch.device(torch.cuda.current_device())
|
||||
# returns a copy of ShardedTensor on CUDA current device
|
||||
list_shards: List[Shard] = []
|
||||
# move all local shards to current device, and change metadata
|
||||
# if local shards already on the current device, there's no
|
||||
# real data movement, only the metadata are copied.
|
||||
for shard in self._local_shards:
|
||||
cuda_tensor = shard.tensor.cuda(
|
||||
device=current_device,
|
||||
non_blocking=non_blocking,
|
||||
memory_format=memory_format
|
||||
) # type: ignore[call-arg]
|
||||
metadata = copy.deepcopy(shard.metadata)
|
||||
metadata.placement._device = current_device # type: ignore[union-attr]
|
||||
|
||||
list_shards.append(
|
||||
Shard(cuda_tensor, metadata)
|
||||
)
|
||||
|
||||
st_meta = copy.deepcopy(self.metadata())
|
||||
for meta in st_meta.shards_metadata:
|
||||
if meta.placement.device().type != "cuda": # type: ignore[union-attr]
|
||||
meta.placement._device = current_device # type: ignore[union-attr]
|
||||
|
||||
pg = self._process_group if process_group is None else process_group
|
||||
# we need to use `init_from_local_shards` to communicate between ranks
|
||||
# and update the sharding spec/shards metadata.
|
||||
st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata(
|
||||
list_shards,
|
||||
sharded_tensor_metadata=st_meta,
|
||||
process_group=pg,
|
||||
init_rrefs=self._init_rrefs
|
||||
)
|
||||
return st_cuda
|
||||
|
||||
def to(self, *args, **kwargs) -> ShardedTensor:
|
||||
current_device = self._local_shards[0].tensor.device
|
||||
current_dtype = self.dtype
|
||||
device_to = current_device
|
||||
dtype_to = current_dtype
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], torch.dtype):
|
||||
dtype_to = args[0]
|
||||
elif isinstance(args[0], torch.device):
|
||||
device_to = args[0]
|
||||
elif isinstance(args[0], (str, int)):
|
||||
device_to = torch.device(args[0])
|
||||
elif isinstance(args[0], torch.Tensor):
|
||||
dtype_to = args[0].dtype
|
||||
device_to = args[0].device
|
||||
else:
|
||||
raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}")
|
||||
elif len(args) == 2:
|
||||
device_to, dtype_to = args
|
||||
else:
|
||||
dtype_to = kwargs.get("dtype", current_dtype)
|
||||
device_to = kwargs.get("device", current_device)
|
||||
|
||||
device_to = torch.device(device_to) if isinstance(device_to, (str, int)) else device_to
|
||||
|
||||
if device_to.type == "cuda":
|
||||
# if device_to set to cuda, set to current device even
|
||||
# if user specify the device index.
|
||||
current_idx = torch.cuda.current_device()
|
||||
if device_to.index != current_idx:
|
||||
import warnings
|
||||
warnings.warn("ShardedTensor.to only move tensor to its current device"
|
||||
"If you want to put to different device, use `reshard` instead.")
|
||||
device_to = torch.device(current_idx)
|
||||
|
||||
copy_tensor = kwargs.get("copy", False)
|
||||
non_blocking = kwargs.get("non_blocking", False)
|
||||
memory_format = kwargs.get("memory_format", torch.preserve_format)
|
||||
process_group = kwargs.get("process_group", None)
|
||||
|
||||
if not copy_tensor and dtype_to == current_dtype and device_to == current_device:
|
||||
# already have correct dtype and device, return itself
|
||||
return self
|
||||
|
||||
# returns a copy of ShardedTensor on CUDA current device
|
||||
list_shards: List[Shard] = []
|
||||
|
||||
for shard in self._local_shards:
|
||||
new_tensor = shard.tensor.to( # type: ignore[call-overload]
|
||||
device=device_to,
|
||||
dtype=dtype_to,
|
||||
non_blocking=non_blocking,
|
||||
copy=copy_tensor,
|
||||
memory_format=memory_format
|
||||
)
|
||||
metadata = copy.deepcopy(shard.metadata)
|
||||
if metadata.placement is not None:
|
||||
metadata.placement._device = device_to
|
||||
list_shards.append(Shard(new_tensor, metadata))
|
||||
|
||||
# update metadata
|
||||
st_meta = copy.deepcopy(self.metadata())
|
||||
st_meta.tensor_properties.dtype = dtype_to
|
||||
for meta in st_meta.shards_metadata:
|
||||
meta.placement._device = device_to # type: ignore[union-attr]
|
||||
|
||||
pg = self._process_group if process_group is None else process_group
|
||||
# we need to use `init_from_local_shards` to communicate between ranks
|
||||
# and update the sharding spec/shards metadata.
|
||||
st_to = ShardedTensor._init_from_local_shards_and_global_metadata(
|
||||
list_shards,
|
||||
sharded_tensor_metadata=st_meta,
|
||||
process_group=pg,
|
||||
init_rrefs=self._init_rrefs
|
||||
)
|
||||
return st_to
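Taken together, `.cuda()` and `.to()` preserve the sharding layout and only touch each rank's local shard. A short usage sketch based on the docstring and tests above (it assumes `st` is an existing ShardedTensor created inside an initialized process group with GPUs available):

```python
# Assumed: `st` is an existing CUDA-backed ShardedTensor on each rank.
st_fp16 = st.to(torch.float16)              # dtype-only conversion; shards stay in place
st_cpu = st.to(device="cpu")                # each rank's local shard moves to CPU
st_back = st_cpu.to("cuda", torch.float32)  # device and dtype in one call
assert st_back.local_tensor().device.type == "cuda"
```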
@classmethod
|
||||
def _init_from_local_shards(
|
||||
cls,
|
||||
@ -446,18 +606,24 @@ class ShardedTensor(object):
|
||||
gathered_metadatas = [local_sharded_tensor_metadata]
|
||||
|
||||
global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
|
||||
tensor_properties = global_sharded_tensor_metadata.tensor_properties
|
||||
|
||||
# STEP 3: Validation done, create the actual ShardedTensor and populate fields
|
||||
# prepare initialization
|
||||
sharded_tensor = cls.__new__(cls)
|
||||
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
|
||||
|
||||
# add to metadata and local_shards
|
||||
sharded_tensor._metadata = global_sharded_tensor_metadata
|
||||
sharded_tensor._local_shards = local_shards
|
||||
sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(
|
||||
spec = shard_spec._infer_sharding_spec_from_shards_metadata(
|
||||
global_sharded_tensor_metadata.shards_metadata
|
||||
)
|
||||
sharded_tensor = cls.__new__(cls,
|
||||
spec,
|
||||
global_sharded_tensor_metadata.size,
|
||||
dtype=tensor_properties.dtype,
|
||||
layout=tensor_properties.layout,
|
||||
pin_memory=tensor_properties.pin_memory,
|
||||
requires_grad=tensor_properties.requires_grad)
|
||||
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
|
||||
|
||||
# attach local_shards to the ShardedTensor created
|
||||
sharded_tensor._local_shards = local_shards
|
||||
|
||||
# run post initialization, i.e. map registration, rpc initialization
|
||||
sharded_tensor._post_init()
|
||||
@ -598,10 +764,19 @@ class ShardedTensor(object):
|
||||
if tensor_properties.layout != torch.strided:
|
||||
raise ValueError('Only torch.strided layout is currently supported')
|
||||
|
||||
sharded_tensor = cls.__new__(cls)
|
||||
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
|
||||
if sharding_spec is None:
|
||||
spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
|
||||
else:
|
||||
spec = sharding_spec
|
||||
|
||||
sharded_tensor._metadata = sharded_tensor_metadata
|
||||
sharded_tensor = cls.__new__(cls,
|
||||
spec,
|
||||
sharded_tensor_metadata.size,
|
||||
dtype=tensor_properties.dtype,
|
||||
layout=tensor_properties.layout,
|
||||
pin_memory=tensor_properties.pin_memory,
|
||||
requires_grad=tensor_properties.requires_grad)
|
||||
sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
|
||||
|
||||
local_shard_metadatas = []
|
||||
|
||||
@ -655,11 +830,6 @@ class ShardedTensor(object):
|
||||
|
||||
# done validation, add local_shards
|
||||
sharded_tensor._local_shards = local_shards
|
||||
if sharding_spec is None:
|
||||
sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
|
||||
else:
|
||||
sharded_tensor._sharding_spec = sharding_spec
|
||||
|
||||
# run post initialization, i.e. map registration, rpc initialization
|
||||
sharded_tensor._post_init()
|
||||
return sharded_tensor
|
||||
@ -825,6 +995,11 @@ class ShardedTensor(object):
|
||||
f"torch function '{func.__name__}', with args: {args} and "
|
||||
f"kwargs: {kwargs} not supported for ShardedTensor!")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise RuntimeError(f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} "
|
||||
"but the there is no custom __torch_dispatch__ implementation for it.")
|
||||
|
||||
def metadata(self) -> ShardedTensorMetadata:
|
||||
"""
|
||||
Returns a :class:`ShardedTensorMetadata` object corresponding to the
|
||||
@ -840,186 +1015,12 @@ class ShardedTensor(object):
|
||||
"""
|
||||
return self._local_shards
|
||||
|
||||
def size(self, dim: int = None) -> Union[torch.Size, int]:
|
||||
"""
|
||||
Returns a :Union:`[torch.Size, int]` which represents the size of the tensor.
|
||||
The dimension can be specified.
|
||||
|
||||
Args:
|
||||
dim (int, optional): the dimension over which the size represents.
|
||||
If specified, it returns the size of the given dimension.
|
||||
If not, it returns a subclass of tuple.
|
||||
Default: ``None``
|
||||
|
||||
Returns:
|
||||
A :Union:`[torch.Size, int]` represents the size of the tensor.
|
||||
"""
|
||||
size = self._metadata.size
|
||||
if dim is None:
|
||||
return size
|
||||
if dim < -len(size) or dim >= len(size):
|
||||
raise ValueError(
|
||||
"Argument ``dim`` must be within the range of tensor "
|
||||
f"dimensions [-{len(size)}, {len(size)})"
|
||||
)
|
||||
return size[dim]
|
||||
|
||||
|
||||
def is_pinned(self) -> bool:
|
||||
def is_pinned(self) -> bool: # type: ignore[override]
|
||||
"""
|
||||
Returns True if the sharded tensor (each local shard) resides in pinned memory.
|
||||
"""
|
||||
return self._metadata.tensor_properties.pin_memory
|
||||
|
||||
def is_contiguous(self) -> bool:
|
||||
"""
|
||||
Returns True if the sharded tensor (each local shard) is contiguous in memory
|
||||
in the order specified by memory format.
|
||||
"""
|
||||
return self._metadata.tensor_properties.memory_format == torch.contiguous_format
|
||||
|
||||
def dim(self) -> int:
|
||||
"""
|
||||
Returns a `int` which represents the dimension of the tensor.
|
||||
|
||||
Returns:
|
||||
A `int` represents the dimension of the tensor.
|
||||
"""
|
||||
return len(self._metadata.size)
|
||||
|
||||
# TODO: This op needs further definition of what exactly its behavior will be.
|
||||
def contiguous(self) -> ShardedTensor:
|
||||
"""
|
||||
Returns a new sharded tensor with the local tensor is made to contiguous.
|
||||
"""
|
||||
if self.is_contiguous():
|
||||
return self
|
||||
local_shards = []
|
||||
for shard in self.local_shards():
|
||||
local_shards.append(
|
||||
Shard(shard.tensor.contiguous(), shard.metadata)
|
||||
)
|
||||
return ShardedTensor._init_from_local_shards_and_global_metadata(
|
||||
local_shards,
|
||||
self._metadata,
|
||||
process_group=self._process_group,
|
||||
init_rrefs=self._init_rrefs,
|
||||
)
|
||||
|
||||
def masked_fill(self, mask, value) -> ShardedTensor:
|
||||
"""
|
||||
Returns a new sharded tensor with each shard has been filled elements
|
||||
with value where mask is True. The shape of mask must be broadcastable
|
||||
with the shape of the underlying tensor.
|
||||
|
||||
Args:
|
||||
mask (BoolTensor): the boolean mask.
|
||||
value (float): the value to fill in with.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object whose shards have been applied masked_fill.
|
||||
"""
|
||||
return handle_torch_function(
|
||||
torch.Tensor.masked_fill, (self, mask, value), self, mask, value
|
||||
)
|
||||
|
||||
def type_as(self, tensor) -> ShardedTensor:
|
||||
"""
|
||||
Returns a new sharded tensor with each shard has been
|
||||
cast to the type of the given tensor.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): the tensor which has the desired type.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object whose shards have been applied type_as.
|
||||
"""
|
||||
return handle_torch_function(torch.Tensor.type_as, (self, tensor), self, tensor)
|
||||
|
||||
def view(self, *shape) -> ShardedTensor:
|
||||
"""
|
||||
Returns a new sharded tensor with the same data as the
|
||||
self tensor but of a different shape for its local tensor.
|
||||
|
||||
For now, we only support to pass through the view op to the local
|
||||
tensor.
|
||||
|
||||
Args:
|
||||
shape (torch.Size or int...) – the desired size.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object whose shards have been applied
|
||||
with view to its local tensor.
|
||||
"""
|
||||
return handle_torch_function(torch.Tensor.view, (self, *shape), self, *shape)
|
||||
|
||||
def transpose(self, dim0, dim1) -> ShardedTensor:
|
||||
"""
|
||||
Returns a new sharded tensor with the given dimensions transposed.
|
||||
During the transpose, we keep the original shading dim, e.g., if the
|
||||
tensor is sharded by dim 0 and if we call transpose(1, 0). The returned
|
||||
tensor will be sharded by dim 1.
|
||||
|
||||
Args:
|
||||
dim0 (int): the first dimension to be transposed.
|
||||
dim1 (int): the second dimension to be transposed.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object whose dims have been transposed
|
||||
specified in the input.
|
||||
"""
|
||||
return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1)
|
||||
|
||||
def bmm(self, st2, *, out=None) -> ShardedTensor:
|
||||
"""
|
||||
Performs a batch matrix-matrix product of matrices stored in self and st2.
|
||||
|
||||
Warning: For now we only supports the case when both tensors are sharded
|
||||
by dim 0 so that no communication is needed.
|
||||
|
||||
Args:
|
||||
st2 (ShardedTensor) – the second batch of sharded matrices to be multiplied.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object which is the result of the batch multiplication.
|
||||
"""
|
||||
return handle_torch_function(torch.Tensor.bmm, (self, st2, out), self, st2, out=out)
|
||||
|
||||
def chunk(self, chunks, dim=0) -> List[ShardedTensor]:
|
||||
"""
|
||||
Attempts to split a tensor into the specified number of chunks.
|
||||
Each chunk is a view of the input tensor.
|
||||
|
||||
Warnings: Chunk by the sharding dim is not supported.
|
||||
|
||||
Args:
|
||||
chunks (int) – number of chunks to return
|
||||
dim (int) – dimension along which to split the tensor
|
||||
|
||||
Returns:
|
||||
A List of :class:`ShardedTensor` object chunked on dims.
|
||||
"""
|
||||
return handle_torch_function(torch.Tensor.chunk, (self, chunks, dim), self, chunks, dim=dim)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._metadata.size
|
||||
|
||||
@property
|
||||
def requires_grad(self):
|
||||
return self._metadata.tensor_properties.requires_grad
|
||||
|
||||
def requires_grad_(self, requires_grad=True):
|
||||
return handle_torch_function(torch.Tensor.requires_grad_, (self, requires_grad), self, requires_grad)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._metadata.tensor_properties.dtype
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
return self._metadata.tensor_properties.layout
|
||||
|
||||
def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int):
|
||||
self._remote_shards[rpc_rank] = remote_shards
|
||||
|
||||
@ -1043,45 +1044,6 @@ class ShardedTensor(object):
|
||||
def __repr__(self):
|
||||
return f'ShardedTensor({self._metadata})'
|
||||
|
||||
def __add__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__add__, (self, other), self, other)
|
||||
|
||||
def __radd__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__radd__, (self, other), self, other)
|
||||
|
||||
def __sub__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__sub__, (self, other), self, other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__rsub__, (self, other), self, other)
|
||||
|
||||
def __mul__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__mul__, (self, other), self, other)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__rmul__, (self, other), self, other)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__div__, (self, other), self, other)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return handle_torch_function(torch.Tensor.__rdiv__, (self, other), self, other)
|
||||
|
||||
def tanh(self):
|
||||
return handle_torch_function(torch.Tensor.tanh, (self,), self)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return handle_torch_function(torch.Tensor.__getitem__, (self, key), self, key)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return handle_torch_function(torch.Tensor.__deepcopy__, (self, memo), self, memo)
|
||||
|
||||
def clone(self, *, memory_format=torch.preserve_format):
|
||||
return handle_torch_function(torch.Tensor.clone, (self,), self, memory_format=memory_format)
|
||||
|
||||
def detach(self):
|
||||
return handle_torch_function(torch.Tensor.detach, (self,), self)
|
||||
|
||||
@dataclass
|
||||
class ProcessGroupState:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
# coding=utf-8
|
||||
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from ._common import (
|
||||
@ -158,7 +156,7 @@ def _validate_embedding_param(args, kwargs):
|
||||
raise TypeError("input need to be torch.Tensor")
|
||||
if not isinstance(weight, ShardedTensor):
|
||||
raise TypeError("weight needs to be ShardedTensor")
|
||||
weight_size = cast(torch.Size, weight.size())
|
||||
weight_size = weight.size()
|
||||
if len(weight_size) != 2:
|
||||
raise ValueError("Weight needs to have exactly 2 dims")
|
||||
if int(torch.min(input).item()) < 0:
|
||||
|
||||
@ -204,7 +204,7 @@ def _validate_embedding_bag_param(args, kwargs):
|
||||
raise TypeError("weight needs to be ShardedTensor")
|
||||
if len(input.size()) > 2:
|
||||
raise ValueError("Input more than 2 dims not supported")
|
||||
weight_size = cast(torch.Size, weight.size())
|
||||
weight_size = weight.size()
|
||||
if len(weight_size) != 2:
|
||||
raise ValueError("Weight needs to have exactly 2 dims")
|
||||
if int(torch.min(input).item()) < 0:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, cast
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -105,14 +105,14 @@ def sharded_linear(types, args, kwargs, pg):
|
||||
world_size = dist.get_world_size(pg)
|
||||
rank = dist.get_rank(pg)
|
||||
|
||||
if sharding_dim == 1 and isinstance(input, torch.Tensor):
|
||||
return _handle_row_wise_sharding_tensor(
|
||||
input, world_size, weight, rank, local_shard_t, bias, pg
|
||||
)
|
||||
elif sharding_dim == 1 and isinstance(input, ShardedTensor):
|
||||
if sharding_dim == 1 and isinstance(input, ShardedTensor):
|
||||
return _handle_row_wise_sharding_sharded_tensor(
|
||||
input, world_size, weight, local_shard_t, bias, pg
|
||||
)
|
||||
elif sharding_dim == 1 and isinstance(input, torch.Tensor):
|
||||
return _handle_row_wise_sharding_tensor(
|
||||
input, world_size, weight, rank, local_shard_t, bias, pg
|
||||
)
|
||||
elif sharding_dim == 0:
|
||||
return _handle_col_wise_sharding(
|
||||
input, world_size, weight, rank, local_shard_t, bias, pg
|
||||
@ -125,7 +125,7 @@ def sharded_linear(types, args, kwargs, pg):
|
||||
|
||||
def _validate_linear_op_param(args, kwargs):
|
||||
"""
|
||||
Validate input params of sharded embedding op.
|
||||
Validate input params of sharded linear op.
|
||||
|
||||
Args:
|
||||
input: input of the linear layer.
|
||||
@ -141,13 +141,13 @@ def _validate_linear_op_param(args, kwargs):
|
||||
# Validate types
|
||||
if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor):
|
||||
raise TypeError("input needs to be either torch.Tensor or ShardedTensor")
|
||||
if not isinstance(bias, torch.Tensor):
|
||||
if type(bias) != torch.Tensor and type(bias) != torch.nn.Parameter:
|
||||
raise TypeError("bias needs to be torch.Tensor")
|
||||
if not isinstance(weight, ShardedTensor):
|
||||
raise TypeError("weight needs to be ShardedTensor")
|
||||
if len(input.size()) < 1: # type: ignore[arg-type]
|
||||
raise ValueError("Input needs to have at least 1 dim")
|
||||
weight_size = cast(torch.Size, weight.size())
|
||||
weight_size = weight.size()
|
||||
if len(weight_size) != 2:
|
||||
raise ValueError("Weight needs to have exactly 2 dims")
|
||||
if len(bias.size()) != 1: