fix _MaskPartial when multiple embeddings coexist (#131264)

Previously, using `_MaskPartial` when multiple embeddings coexist had the following issues:
1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`. When there is more than one embedding sharing the same `vocab_size` but with different `emb_size`s, the embeddings do not share an `OpStrategy`, since each produces a different `OpSchema` when involved in computation. However, there would still be a cache hit during redistribute (specifically in `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when going from `Replicate` to `_MaskPartial`), because `_MaskPartial` only carried `vocab_size` as `logical_dim_size` and did not record `emb_size` as an attribute. This cache hit is undesirable and causes trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`; the error was reported in #130725. In this PR, we introduce `offset_shape` to represent the embedding's full shape, so that embeddings of different shapes no longer hit the cache (a minimal sketch of the cache-key problem follows this list).
2. The second issue arises when two `nn.Embedding`s, `emb1` and `emb2`, have the same shape. There is then a cache hit not only in `_gen_transform_infos`, but also in `OpStrategy` generation. Previously, if we sequentially did `Replicate` -> `_MaskPartial` for both `emb1` and `emb2`, and then performed the reduction on the `_MaskPartial` of `emb1`, it would destroy the shared `MaskBuffer` and `emb2` would hit an error. This PR adds a `refcount` to the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s (a sketch of the refcounting scheme also follows below).
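
To illustrate issue 1, here is a minimal, self-contained sketch. The `OldKey`/`NewKey` dataclasses are hypothetical stand-ins that only mimic the eq/hash behavior of the placement used as a redistribute cache key, not the actual `_MaskPartial`. It shows why keying on `logical_dim_size` alone collides for embeddings of shape `[10, 23]` and `[10, 29]`, while keying on the full `offset_shape` does not:

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-ins for the placement used as a redistribute cache key;
# they only mimic the eq/hash behavior, not the real _MaskPartial.
@dataclass(frozen=True)
class OldKey:
    logical_dim_size: int  # only vocab_size participates in eq/hash


@dataclass(frozen=True)
class NewKey:
    offset_shape: Tuple[int, ...]  # full embedding shape, e.g. (vocab_size, emb_size)
    offset_dim: int = 0


# Two embeddings sharing vocab_size=10 but with emb_size 23 vs 29.
old_a, old_b = OldKey(10), OldKey(10)
new_a, new_b = NewKey((10, 23)), NewKey((10, 29))

# Keyed on logical_dim_size alone, the two embeddings are indistinguishable,
# so redistribute reuses cached transform infos across them (the bad cache hit).
assert old_a == old_b and hash(old_a) == hash(old_b)

# Keyed on the full offset_shape, each embedding gets its own cache entry.
assert new_a != new_b
```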

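For issue 2, the sketch below shows the refcounting idea with a simplified class (it mirrors, but is not identical to, the `MaskBuffer` change in the diff): the mask is only freed when the last user releases it, so reducing `emb1`'s partial output no longer invalidates the buffer that `emb2` still needs.

```python
from typing import Optional

import torch


class RefCountedMaskBuffer:
    """Simplified illustration of the refcount scheme; not the actual MaskBuffer."""

    def __init__(self) -> None:
        self.data: Optional[torch.Tensor] = None
        self.refcount: int = 0

    def materialize_mask(self, mask: torch.Tensor) -> None:
        # The first user stores the mask; later users must bring identical data.
        if self.refcount == 0:
            self.data = mask
        elif self.data is None or not torch.equal(self.data, mask):
            raise RuntimeError("MaskBuffer has been materialized with conflicting data")
        self.refcount += 1

    def release_mask(self) -> None:
        # Only the last release actually frees the mask.
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
        self.refcount -= 1
        if self.refcount == 0:
            self.data = None


buf = RefCountedMaskBuffer()
mask = torch.tensor([True, False, True])
buf.materialize_mask(mask)   # emb1 materializes the shared mask
buf.materialize_mask(mask)   # emb2 shares the same mask data
buf.release_mask()           # reduction on emb1's partial output
assert buf.data is not None  # emb2 can still use the mask
buf.release_mask()           # last user releases; buffer is freed
assert buf.data is None
```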
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol
Author: Tianyu Liu
Date: 2024-07-26 15:13:23 -07:00
Committed by: PyTorch MergeBot
Parent: 0ab6551bcb
Commit: 7b0e10f0e5
4 changed files with 87 additions and 22 deletions


@@ -184,6 +184,49 @@ class TestEmbeddingOp(DTensorTestBase):
         self.assertEqual(comm_mode.get_total_counts(), 1)
         self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
 
+    @with_comms
+    def test_multiple_embeddings_rowwise(self):
+        mesh = self.build_device_mesh()
+        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
+        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
+
+        from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+
+        # case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
+        # and MaskBuffer, because of cache hit from sharding propagation
+        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
+        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
+        output1 = sharded_emb1(replicated_inp)
+
+        emb2 = torch.nn.Embedding(10, 23, device=self.device_type)
+        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
+        output2 = sharded_emb2(replicated_inp)
+
+        partial_placement1 = output1.placements[0]
+        self.assertIsInstance(partial_placement1, _MaskPartial)
+        output1.full_tensor()
+
+        partial_placement2 = output2.placements[0]
+        self.assertIsInstance(partial_placement2, _MaskPartial)
+        output2.full_tensor()
+
+        self.assertEqual(id(partial_placement1), id(partial_placement2))
+
+        # case 2: two embeddings with the same logical_dim_size, but different logical_shape
+        # thus they will have different _MaskPartial placements (with no cache hit)
+        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
+        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
+        output3 = sharded_emb3(replicated_inp)
+
+        partial_placement3 = output3.placements[0]
+        self.assertIsInstance(partial_placement3, _MaskPartial)
+        output3.full_tensor()
+
+        # not equal because of different logical_shape, despite the same logical_dim_size
+        self.assertNotEqual(partial_placement1, partial_placement3)
+
 
 if __name__ == "__main__":
     run_tests()


@@ -32,21 +32,29 @@ aten = torch.ops.aten
 @dataclass
 class MaskBuffer:
     data: Optional[torch.Tensor] = None
+    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
+    refcount: int = 0
 
     def materialize_mask(self, mask):
-        if self.data is not None:
-            raise RuntimeError("MaskBuffer has already been materialized")
-        self.data = mask
+        if self.refcount == 0:
+            self.data = mask
+        else:
+            assert self.data is not None
+            if not torch.equal(self.data, mask):
+                raise RuntimeError(
+                    "MaskBuffer has been materialized with conflicting data"
+                )
+        self.refcount += 1
 
     def release_mask(self):
         # TODO: evaluate if we need to release the mask buffer or the buffer
         # can just have the same lifetime as the Partial placement
-        if self.data is None:
+        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
-        self.data = None
+        self.refcount -= 1
+        if self.refcount == 0:
+            self.data = None
 
     def apply_mask(self, tensor):
-        if self.data is None:
+        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
 
         # NOTE: _MaskPartial is being used by the embedding op and the gather op.
@@ -70,17 +78,23 @@ class _MaskPartial(Partial):
     lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
     """
 
-    logical_dim_size: int = -1
     mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
 
+    # required fields for computing the local offset and deriving the mask
+    offset_shape: Optional[torch.Size] = None
+    offset_dim: int = 0
+
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
         # override parent logic to perform partial mask for embedding
         num_chunks = mesh.size(mesh_dim)
         # get local shard size and offset on the embedding_dim
+        assert (
+            self.offset_shape is not None
+        ), "offset_shape needs to be set for _MaskPartial"
         local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
-            self.logical_dim_size,
+            self.offset_shape[self.offset_dim],
             num_chunks,
             mesh.get_local_rank(mesh_dim),
             return_offset=True,
@@ -146,19 +160,24 @@ class _MaskPartial(Partial):
         return (
             self.reduce_op == other.reduce_op
-            and self.logical_dim_size == other.logical_dim_size
+            and self.offset_shape == other.offset_shape
+            and self.offset_dim == other.offset_dim
         )
 
     def __hash__(self) -> int:
         return 1 + hash(
-            (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)
+            (
+                self.reduce_op,
+                self.offset_shape,
+                self.offset_dim,
+            )
         )
 
     def __repr__(self) -> str:
         """
         machine readable representation of the MaskPartial placement
         """
-        return f"_MaskPartial(logical_dim_size={self.logical_dim_size})"
+        return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
 
     def __str__(self) -> str:
         """
@@ -192,7 +211,7 @@ def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     single_mesh_dim_strategies.append(colwise_sharding)
 
     # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
-    embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0])
+    embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
 
     # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
     # from the input indices and use it for output reduction


@@ -408,7 +408,7 @@ def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     # this only works when the input is sharded on the gather dimension, and
     # index has size 1 on the gather dimension
     if index_shape[dim] == 1:
-        index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim])
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
         input_sharding: PlacementList = [
             index_partial_placement,
             Shard(dim),


@@ -200,7 +200,8 @@ def _nll_loss_forward(
     local_weight: Optional[Tensor],
     reduction: int,
     ignore_index: int,
-    channel_dim_size: int,
+    input_shape: torch.Size,
+    channel_dim: int,
     mesh: DeviceMesh,
     mesh_dim: int,
 ) -> Tuple[Tensor, Tensor]:
@@ -230,7 +231,7 @@ def _nll_loss_forward(
 
     # The following code block is a distributed version of
     # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
-    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
     safe_target_partial_ = partial_placement._partition_value(
         safe_target_, mesh, mesh_dim
     )
@@ -317,7 +318,8 @@ def _nll_loss_forward_handler(
         local_weight,
         reduction,
         ignore_index,
-        channel_dim_size,
+        x.shape,
+        channel_dim,
         spec.mesh,
         mesh_dim,
     )
@@ -348,7 +350,8 @@ def _nll_loss_and_log_softmax_backward(
     reduction: int,
     ignore_index: int,
     total_weight: Tensor,
-    channel_dim_size: int,
+    input_shape: torch.Size,
+    channel_dim: int,
     mesh: DeviceMesh,
     mesh_dim: int,
 ) -> Tensor:
@@ -362,7 +365,7 @@ def _nll_loss_and_log_softmax_backward(
 
     # The following code block is a distributed version of
     # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
-    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
     safe_target = safe_target.squeeze(channel_dim).flatten()
     masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
     # only update grad_input to -1 if not masked
@@ -422,7 +425,6 @@ def _nll_loss_backward_handler(
     total_weight = cast(Tensor, args[6])
 
     channel_dim = 1 if x.dim() >= 2 else 0
-    channel_dim_size = x.shape[channel_dim]
     spec = x._spec
     mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
@@ -449,7 +451,8 @@ def _nll_loss_backward_handler(
         reduction,
         ignore_index,
         total_weight,
-        channel_dim_size,
+        x.shape,
+        channel_dim,
         spec.mesh,
         mesh_dim,
     )