fix _MaskPartial when multiple embeddings coexist (#131264)
Previously, using `_MaskPartial` when multiple embeddings coexist had the following issues:

1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`. When there is more than one embedding sharing the same `vocab_size` but with different `emb_size`s, the embeddings do not share an `OpStrategy`, since each one, when involved in computation, has a different `OpSchema`. However, there is a cache hit on redistribute (specifically `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when going `Replicate` -> `_MaskPartial`), because `_MaskPartial` only carries `vocab_size` as `logical_dim_size` and does not carry `emb_size` as an attribute. This cache hit is undesirable and causes trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`. The error was reported in #130725. In this PR, we introduce `offset_shape` to represent the embedding's full shape, so that embeddings of different shapes no longer hit the same cache entry.

2. The second issue arises when we have two `nn.Embedding`s `emb1` and `emb2` with the same shape. There is then a cache hit not only in `_gen_transform_infos` but also in `OpStrategy` generation. Previously, if we sequentially did `Replicate` -> `_MaskPartial` for both `emb1` and `emb2`, and then performed the reduction on the `_MaskPartial` of `emb1`, it would destroy the shared `MaskBuffer` and `emb2` would hit an error. This PR adds a `refcount` to the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol
Committed by PyTorch MergeBot. Parent: 0ab6551bcb. Commit: 7b0e10f0e5.
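
To make the first issue concrete, here is a minimal standalone sketch of the cache collision, in plain Python rather than the actual DTensor redistribute code. `OldMaskPartial`, `transform_cache`, and `gen_transform_infos` are illustrative stand-ins for the real `_MaskPartial` placement and the memoization inside `_gen_transform_infos`, under the assumption that only the masked dimension's size identifies the target placement:

from dataclasses import dataclass


@dataclass(frozen=True)
class OldMaskPartial:
    # mimics the pre-fix _MaskPartial: only the masked dim's size identifies the placement
    logical_dim_size: int


transform_cache = {}  # stand-in for the memoized Replicate -> partial transform infos


def gen_transform_infos(weight_shape):
    dst = OldMaskPartial(logical_dim_size=weight_shape[0])
    key = ("Replicate", dst)
    if key not in transform_cache:
        transform_cache[key] = f"transform for weight of shape {weight_shape}"
    return transform_cache[key]


print(gen_transform_infos((10, 23)))  # populates the cache for the [10, 23] embedding
print(gen_transform_infos((10, 29)))  # cache hit: the [10, 23] entry is reused for a [10, 29] embedding
print(len(transform_cache))           # 1 -- two different embeddings collided on a single key

Keying the placement on the full `offset_shape`, as this PR does, makes the two keys distinct, so each embedding gets its own transform info.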
@@ -184,6 +184,49 @@ class TestEmbeddingOp(DTensorTestBase):
         self.assertEqual(comm_mode.get_total_counts(), 1)
         self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
 
+    @with_comms
+    def test_multiple_embeddings_rowwise(self):
+        mesh = self.build_device_mesh()
+
+        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
+        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
+
+        from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+
+        # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial
+        # and MaskBuffer, because of cache hit from sharding propagation
+
+        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
+        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
+        output1 = sharded_emb1(replicated_inp)
+
+        emb2 = torch.nn.Embedding(10, 29, device=self.device_type)
+        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
+        output2 = sharded_emb2(replicated_inp)
+
+        partial_placement1 = output1.placements[0]
+        self.assertIsInstance(partial_placement1, _MaskPartial)
+        output1.full_tensor()
+
+        partial_placement2 = output2.placements[0]
+        self.assertIsInstance(partial_placement2, _MaskPartial)
+        output2.full_tensor()
+
+        self.assertTrue(id(partial_placement1), id(partial_placement2))
+
+        # case 2: two embeddings with the same logical_dim_size, but different logical_shape
+        # thus they will have different _MaskPartial placements (with no cache hit)
+
+        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
+        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
+        output3 = sharded_emb3(replicated_inp)
+        partial_placement3 = output3.placements[0]
+        self.assertIsInstance(partial_placement3, _MaskPartial)
+        output2.full_tensor()
+
+        # not equal because of different logical_shape, despite of same logical_dim_size
+        self.assertNotEqual(partial_placement1, partial_placement3)
+
 
 if __name__ == "__main__":
     run_tests()
@@ -32,21 +32,29 @@ aten = torch.ops.aten
 @dataclass
 class MaskBuffer:
     data: Optional[torch.Tensor] = None
+    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
+    refcount: int = 0
 
     def materialize_mask(self, mask):
-        if self.data is not None:
-            raise RuntimeError("MaskBuffer has already been materialized")
-        self.data = mask
+        if self.refcount == 0:
+            self.data = mask
+        else:
+            assert self.data is not None
+            if not torch.equal(self.data, mask):
+                raise RuntimeError(
+                    "MaskBuffer has been materialized with conflicting data"
+                )
+        self.refcount += 1
 
     def release_mask(self):
-        # TODO: evaluate if we need to release the mask buffer or the buffer
-        # can just have the same lifetime as the Partial placement
-        if self.data is None:
+        if self.refcount == 0 or self.data is None:
             raise RuntimeError("MaskBuffer has not been materialized")
-        self.data = None
+        self.refcount -= 1
+        if self.refcount == 0:
+            self.data = None
 
     def apply_mask(self, tensor):
-        if self.data is None:
+        if self.refcount == 0 or self.data is None:
             raise RuntimeError("MaskBuffer has not been materialized")
 
 # NOTE: _MaskPartial is being used by the embedding op and the gather op.
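
As a side note, the sharing behavior that the new refcount enables can be sketched outside of DTensor with plain tensors: two users of one cached placement materialize the same mask into a shared buffer, and the data is only dropped once the last user releases it. The `MaskBuffer` below copies the logic from the hunk above; the driver code at the bottom is illustrative only:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBuffer:  # same logic as the patched class above
    data: Optional[torch.Tensor] = None
    refcount: int = 0

    def materialize_mask(self, mask):
        if self.refcount == 0:
            self.data = mask
        else:
            assert self.data is not None
            if not torch.equal(self.data, mask):
                raise RuntimeError("MaskBuffer has been materialized with conflicting data")
        self.refcount += 1

    def release_mask(self):
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
        self.refcount -= 1
        if self.refcount == 0:
            self.data = None


buf = MaskBuffer()
mask = torch.tensor([True, False, True])

# two users (e.g. two same-shaped embeddings sharing one cached placement) materialize the same mask
buf.materialize_mask(mask)
buf.materialize_mask(mask.clone())

buf.release_mask()             # first user done: the buffer is still alive
assert buf.data is not None
buf.release_mask()             # last user done: the buffer is actually freed
assert buf.data is None

With the old implementation, the first release would have set `data` to `None` immediately and the second user's reduction would fail, which is exactly the failure described in the second issue above.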
@@ -70,17 +78,23 @@ class _MaskPartial(Partial):
     lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
     """
 
-    logical_dim_size: int = -1
     mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
 
+    # required fields for computing the local offset and deriving the mask
+    offset_shape: Optional[torch.Size] = None
+    offset_dim: int = 0
+
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
         # override parent logic to perform partial mask for embedding
         num_chunks = mesh.size(mesh_dim)
         # get local shard size and offset on the embedding_dim
+        assert (
+            self.offset_shape is not None
+        ), "offset_shape needs to be set for _MaskPartial"
         local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
-            self.logical_dim_size,
+            self.offset_shape[self.offset_dim],
             num_chunks,
             mesh.get_local_rank(mesh_dim),
             return_offset=True,
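
The local shard size and offset computed here follow torch.chunk-style even sharding of the masked dimension. Below is a small illustration of that arithmetic using plain `torch.chunk` instead of the internal `Shard._local_shard_size_on_dim` helper (the call shown is an assumption for illustration; the resulting sizes and offsets are the standard chunking ones). For a vocabulary of 10 rows split across 4 ranks:

import torch

vocab_size, num_chunks = 10, 4  # e.g. nn.Embedding(10, emb_size) rowwise-sharded over 4 ranks

# torch.chunk-style sharding: per-rank sizes and the starting row offset of each local shard
chunks = torch.arange(vocab_size).chunk(num_chunks)
sizes = [len(c) for c in chunks]
offsets = [int(c[0]) for c in chunks]

print(sizes)    # [3, 3, 3, 1]
print(offsets)  # [0, 3, 6, 9]

# a replicated index contributes on a given rank only if it falls inside [offset, offset + size);
# everything else is masked out locally and recovered by the partial reduction afterwards

This is why the placement needs `offset_shape` and `offset_dim`: together they are enough to recover each rank's [offset, offset + size) window on the masked dimension.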
@@ -146,19 +160,24 @@ class _MaskPartial(Partial):
 
         return (
             self.reduce_op == other.reduce_op
-            and self.logical_dim_size == other.logical_dim_size
+            and self.offset_shape == other.offset_shape
+            and self.offset_dim == other.offset_dim
         )
 
     def __hash__(self) -> int:
         return 1 + hash(
-            (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)
+            (
+                self.reduce_op,
+                self.offset_shape,
+                self.offset_dim,
+            )
         )
 
     def __repr__(self) -> str:
         """
         machine readable representation of the MaskPartial placement
         """
-        return f"_MaskPartial(logical_dim_size={self.logical_dim_size})"
+        return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
 
     def __str__(self) -> str:
         """
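
To tie the new `__eq__`/`__hash__` back to the test added above: two rowwise-sharded embeddings with identical weight shapes still produce placements that compare equal (and can therefore share cached state and a MaskBuffer), while a different embedding dim now breaks the equality even though the vocabulary size matches. A tiny sketch with a simplified stand-in dataclass (`FakeMaskPartial` is not the real placement class, and `reduce_op` is reduced to a plain string here):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class FakeMaskPartial:
    # equality/hash derive from the same fields the patched __eq__/__hash__ use
    reduce_op: str = "sum"
    offset_shape: Optional[torch.Size] = None
    offset_dim: int = 0


p1 = FakeMaskPartial(offset_shape=torch.Size([10, 23]))  # emb1 in the test above
p2 = FakeMaskPartial(offset_shape=torch.Size([10, 23]))  # emb2: identical weight shape
p3 = FakeMaskPartial(offset_shape=torch.Size([10, 29]))  # emb3: same vocab size, different emb dim

assert p1 == p2 and hash(p1) == hash(p2)  # identical shapes may still share a cached placement
assert p1 != p3                           # different emb dims no longer collide (cf. assertNotEqual in the test)
print("ok")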
@@ -192,7 +211,7 @@ def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     single_mesh_dim_strategies.append(colwise_sharding)
 
     # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
-    embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0])
+    embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
 
     # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
     # from the input indices and use it for output reduction
@@ -408,7 +408,7 @@ def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     # this only works when the input is sharded on the gather dimension, and
     # index has size 1 on the gather dimension
     if index_shape[dim] == 1:
-        index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim])
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
         input_sharding: PlacementList = [
             index_partial_placement,
             Shard(dim),
@@ -200,7 +200,8 @@ def _nll_loss_forward(
     local_weight: Optional[Tensor],
     reduction: int,
     ignore_index: int,
-    channel_dim_size: int,
+    input_shape: torch.Size,
+    channel_dim: int,
     mesh: DeviceMesh,
     mesh_dim: int,
 ) -> Tuple[Tensor, Tensor]:
@@ -230,7 +231,7 @@ def _nll_loss_forward(
 
     # The following code block is a distributed version of
     # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
-    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
     safe_target_partial_ = partial_placement._partition_value(
         safe_target_, mesh, mesh_dim
     )
@@ -317,7 +318,8 @@ def _nll_loss_forward_handler(
         local_weight,
         reduction,
         ignore_index,
-        channel_dim_size,
+        x.shape,
+        channel_dim,
         spec.mesh,
         mesh_dim,
     )
@@ -348,7 +350,8 @@ def _nll_loss_and_log_softmax_backward(
     reduction: int,
     ignore_index: int,
     total_weight: Tensor,
-    channel_dim_size: int,
+    input_shape: torch.Size,
+    channel_dim: int,
     mesh: DeviceMesh,
     mesh_dim: int,
 ) -> Tensor:
@@ -362,7 +365,7 @@ def _nll_loss_and_log_softmax_backward(
 
     # The following code block is a distributed version of
     # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
-    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
     safe_target = safe_target.squeeze(channel_dim).flatten()
     masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
     # only update grad_input to -1 if not masked
@@ -422,7 +425,6 @@ def _nll_loss_backward_handler(
     total_weight = cast(Tensor, args[6])
 
     channel_dim = 1 if x.dim() >= 2 else 0
-    channel_dim_size = x.shape[channel_dim]
     spec = x._spec
     mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
 
@@ -449,7 +451,8 @@ def _nll_loss_backward_handler(
         reduction,
         ignore_index,
         total_weight,
-        channel_dim_size,
+        x.shape,
+        channel_dim,
         spec.mesh,
         mesh_dim,
     )