mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	fix _MaskPartial when multiple embeddings coexist (#131264)
Previously, using _MaskPartial when multiple embeddings have the following issues: 1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`. When there are more than one embeddings, sharing the same `vocab_size` but with different `emb_size`s. Then they would not share `OpStrategy` since each, when involved in computation, would have different `OpSchema`; however, there would be cache hit for redistribute (specifically `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when doing `Replicate` -> `_MaskPartial`) as the `_MaskPartial` only has `vocab_size` as `logical_dim_size` but not `emb_size` as attribute. This cache hit is undesirable and would cause trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`. The error was reported in #130725. In this PR, we introduce `offset_shape` to represent the embedding's full shape to avoid cache hit from embeddings of different shapes. 2. The second issue is when we have two `nn.Embedding`s `emb1` and `emb2` with the same shape. There will be cache hit not only in `_gen_transform_infos`, but also in `OpStrategy` generation. Previously, if we sequentially do `Replicate` -> `_MaskPartial` for both `emb1` `emb2` and then sequentially do reduction on the `_MaskPartial` of `emb1`, it would destroy the `MaskBuffer` and `emb2` would hit error. This PR adds a `refcount` for the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264 Approved by: https://github.com/wanchaol
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							0ab6551bcb
						
					
				
				
					commit
					7b0e10f0e5
				
			| @ -184,6 +184,49 @@ class TestEmbeddingOp(DTensorTestBase): | ||||
|             self.assertEqual(comm_mode.get_total_counts(), 1) | ||||
|             self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1) | ||||
|  | ||||
|     @with_comms | ||||
|     def test_multiple_embeddings_rowwise(self): | ||||
|         mesh = self.build_device_mesh() | ||||
|  | ||||
|         inp = torch.randint(0, 10, (4, 4), device=self.device_type) | ||||
|         replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False) | ||||
|  | ||||
|         from torch.distributed._tensor.ops._embedding_ops import _MaskPartial | ||||
|  | ||||
|         # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial | ||||
|         # and MaskBuffer, because of cache hit from sharding propagation | ||||
|  | ||||
|         emb1 = torch.nn.Embedding(10, 23, device=self.device_type) | ||||
|         sharded_emb1 = self._apply_sharding(emb1, 0, mesh) | ||||
|         output1 = sharded_emb1(replicated_inp) | ||||
|  | ||||
|         emb2 = torch.nn.Embedding(10, 29, device=self.device_type) | ||||
|         sharded_emb2 = self._apply_sharding(emb2, 0, mesh) | ||||
|         output2 = sharded_emb2(replicated_inp) | ||||
|  | ||||
|         partial_placement1 = output1.placements[0] | ||||
|         self.assertIsInstance(partial_placement1, _MaskPartial) | ||||
|         output1.full_tensor() | ||||
|  | ||||
|         partial_placement2 = output2.placements[0] | ||||
|         self.assertIsInstance(partial_placement2, _MaskPartial) | ||||
|         output2.full_tensor() | ||||
|  | ||||
|         self.assertTrue(id(partial_placement1), id(partial_placement2)) | ||||
|  | ||||
|         # case 2: two embeddings with the same logical_dim_size, but different logical_shape | ||||
|         # thus they will have different _MaskPartial placements (with no cache hit) | ||||
|  | ||||
|         emb3 = torch.nn.Embedding(10, 29, device=self.device_type) | ||||
|         sharded_emb3 = self._apply_sharding(emb3, 0, mesh) | ||||
|         output3 = sharded_emb3(replicated_inp) | ||||
|         partial_placement3 = output3.placements[0] | ||||
|         self.assertIsInstance(partial_placement3, _MaskPartial) | ||||
|         output2.full_tensor() | ||||
|  | ||||
|         # not equal because of different logical_shape, despite of same logical_dim_size | ||||
|         self.assertNotEqual(partial_placement1, partial_placement3) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -32,21 +32,29 @@ aten = torch.ops.aten | ||||
| @dataclass | ||||
| class MaskBuffer: | ||||
|     data: Optional[torch.Tensor] = None | ||||
|     # refcount allows shared usage of the MaskBuffer, as long as all users have the same data | ||||
|     refcount: int = 0 | ||||
|  | ||||
|     def materialize_mask(self, mask): | ||||
|         if self.data is not None: | ||||
|             raise RuntimeError("MaskBuffer has already been materialized") | ||||
|         self.data = mask | ||||
|         if self.refcount == 0: | ||||
|             self.data = mask | ||||
|         else: | ||||
|             assert self.data is not None | ||||
|             if not torch.equal(self.data, mask): | ||||
|                 raise RuntimeError( | ||||
|                     "MaskBuffer has been materialized with conflicting data" | ||||
|                 ) | ||||
|         self.refcount += 1 | ||||
|  | ||||
|     def release_mask(self): | ||||
|         # TODO: evaluate if we need to release the mask buffer or the buffer | ||||
|         # can just have the same lifetime as the Partial placement | ||||
|         if self.data is None: | ||||
|         if self.refcount == 0 or self.data is None: | ||||
|             raise RuntimeError("MaskBuffer has not been materialized") | ||||
|         self.data = None | ||||
|         self.refcount -= 1 | ||||
|         if self.refcount == 0: | ||||
|             self.data = None | ||||
|  | ||||
|     def apply_mask(self, tensor): | ||||
|         if self.data is None: | ||||
|         if self.refcount == 0 or self.data is None: | ||||
|             raise RuntimeError("MaskBuffer has not been materialized") | ||||
|  | ||||
|         # NOTE: _MaskPartial is being used by the embedding op and the gather op. | ||||
| @ -70,17 +78,23 @@ class _MaskPartial(Partial): | ||||
|     lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. | ||||
|     """ | ||||
|  | ||||
|     logical_dim_size: int = -1 | ||||
|     mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) | ||||
|  | ||||
|     # required fields for computing the local offset and deriving the mask | ||||
|     offset_shape: Optional[torch.Size] = None | ||||
|     offset_dim: int = 0 | ||||
|  | ||||
|     def _partition_value( | ||||
|         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int | ||||
|     ) -> torch.Tensor: | ||||
|         # override parent logic to perform partial mask for embedding | ||||
|         num_chunks = mesh.size(mesh_dim) | ||||
|         # get local shard size and offset on the embedding_dim | ||||
|         assert ( | ||||
|             self.offset_shape is not None | ||||
|         ), "offset_shape needs to be set for _MaskPartial" | ||||
|         local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( | ||||
|             self.logical_dim_size, | ||||
|             self.offset_shape[self.offset_dim], | ||||
|             num_chunks, | ||||
|             mesh.get_local_rank(mesh_dim), | ||||
|             return_offset=True, | ||||
| @ -146,19 +160,24 @@ class _MaskPartial(Partial): | ||||
|  | ||||
|         return ( | ||||
|             self.reduce_op == other.reduce_op | ||||
|             and self.logical_dim_size == other.logical_dim_size | ||||
|             and self.offset_shape == other.offset_shape | ||||
|             and self.offset_dim == other.offset_dim | ||||
|         ) | ||||
|  | ||||
|     def __hash__(self) -> int: | ||||
|         return 1 + hash( | ||||
|             (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op) | ||||
|             ( | ||||
|                 self.reduce_op, | ||||
|                 self.offset_shape, | ||||
|                 self.offset_dim, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         """ | ||||
|         machine readable representation of the MaskPartial placement | ||||
|         """ | ||||
|         return f"_MaskPartial(logical_dim_size={self.logical_dim_size})" | ||||
|         return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         """ | ||||
| @ -192,7 +211,7 @@ def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: | ||||
|     single_mesh_dim_strategies.append(colwise_sharding) | ||||
|  | ||||
|     # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial | ||||
|     embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) | ||||
|     embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0) | ||||
|  | ||||
|     # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates | ||||
|     # from the input indices and use it for output reduction | ||||
|  | ||||
| @ -408,7 +408,7 @@ def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: | ||||
|     # this only works when the input is sharded on the gather dimension, and | ||||
|     # index has size 1 on the gather dimension | ||||
|     if index_shape[dim] == 1: | ||||
|         index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim]) | ||||
|         index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) | ||||
|         input_sharding: PlacementList = [ | ||||
|             index_partial_placement, | ||||
|             Shard(dim), | ||||
|  | ||||
| @ -200,7 +200,8 @@ def _nll_loss_forward( | ||||
|     local_weight: Optional[Tensor], | ||||
|     reduction: int, | ||||
|     ignore_index: int, | ||||
|     channel_dim_size: int, | ||||
|     input_shape: torch.Size, | ||||
|     channel_dim: int, | ||||
|     mesh: DeviceMesh, | ||||
|     mesh_dim: int, | ||||
| ) -> Tuple[Tensor, Tensor]: | ||||
| @ -230,7 +231,7 @@ def _nll_loss_forward( | ||||
|  | ||||
|     # The following code block is a distributed version of | ||||
|     # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) | ||||
|     partial_placement = _MaskPartial(logical_dim_size=channel_dim_size) | ||||
|     partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) | ||||
|     safe_target_partial_ = partial_placement._partition_value( | ||||
|         safe_target_, mesh, mesh_dim | ||||
|     ) | ||||
| @ -317,7 +318,8 @@ def _nll_loss_forward_handler( | ||||
|         local_weight, | ||||
|         reduction, | ||||
|         ignore_index, | ||||
|         channel_dim_size, | ||||
|         x.shape, | ||||
|         channel_dim, | ||||
|         spec.mesh, | ||||
|         mesh_dim, | ||||
|     ) | ||||
| @ -348,7 +350,8 @@ def _nll_loss_and_log_softmax_backward( | ||||
|     reduction: int, | ||||
|     ignore_index: int, | ||||
|     total_weight: Tensor, | ||||
|     channel_dim_size: int, | ||||
|     input_shape: torch.Size, | ||||
|     channel_dim: int, | ||||
|     mesh: DeviceMesh, | ||||
|     mesh_dim: int, | ||||
| ) -> Tensor: | ||||
| @ -362,7 +365,7 @@ def _nll_loss_and_log_softmax_backward( | ||||
|  | ||||
|     # The following code block is a distributed version of | ||||
|     # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) | ||||
|     partial_placement = _MaskPartial(logical_dim_size=channel_dim_size) | ||||
|     partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) | ||||
|     safe_target = safe_target.squeeze(channel_dim).flatten() | ||||
|     masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim) | ||||
|     # only update grad_input to -1 if not masked | ||||
| @ -422,7 +425,6 @@ def _nll_loss_backward_handler( | ||||
|     total_weight = cast(Tensor, args[6]) | ||||
|  | ||||
|     channel_dim = 1 if x.dim() >= 2 else 0 | ||||
|     channel_dim_size = x.shape[channel_dim] | ||||
|     spec = x._spec | ||||
|     mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) | ||||
|  | ||||
| @ -449,7 +451,8 @@ def _nll_loss_backward_handler( | ||||
|         reduction, | ||||
|         ignore_index, | ||||
|         total_weight, | ||||
|         channel_dim_size, | ||||
|         x.shape, | ||||
|         channel_dim, | ||||
|         spec.mesh, | ||||
|         mesh_dim, | ||||
|     ) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user