mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	This reverts commit 321e6026925f6b6e8a36e3a8b7c0295cd7541911. Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
		
			
				
	
	
		
			85 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			85 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: unknown"]
 | |
| 
 | |
| import hypothesis.strategies as st
 | |
| from hypothesis import given
 | |
| import numpy as np
 | |
| import torch
 | |
| from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
 | |
| import torch.testing._internal.hypothesis_utils as hu
 | |
# Module-level check executed at import time via the project helper
# ``torch.testing._internal.hypothesis_utils``. NOTE(review): presumably this
# asserts that hypothesis' per-example deadline is disabled for this run —
# confirm against hypothesis_utils.
hu.assert_deadline_disabled()
 | |
| 
 | |
| 
 | |
class PruningOpTest(TestCase):
    """Tests for ``torch._rowwise_prune`` against a pure-Python reference."""

    # Weight dtypes exercised by both index-width variants below; kept in one
    # place so the two hypothesis strategies cannot drift apart.
    _WEIGHTS_DTYPES = [
        torch.float64,
        torch.float32,
        torch.float16,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ]
    # Weight dtypes generated with ``torch.randint``; all other dtypes in
    # ``_WEIGHTS_DTYPES`` are generated with ``torch.rand``.
    _INTEGER_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    def _generate_rowwise_mask(self, embedding_rows):
        """Return a random boolean keep-mask with one entry per weight row.

        ``indicator`` holds one random importance value per row. A single
        random ``threshold`` is drawn, and a row is kept (``True``) when its
        indicator is >= the threshold, otherwise it is masked (``False``).
        Depending on the threshold, all rows, no rows or any subset may be kept.

        Args:
            embedding_rows: number of rows, i.e. the length of the mask.

        Returns:
            A 1-D ``torch.bool`` tensor of length ``embedding_rows``.
        """
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32)
        )
        threshold = float(np.random.random_sample())
        # Vectorized comparison replaces a Python loop over tensor elements.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Compare ``torch._rowwise_prune`` with a reference implementation.

        Args:
            embedding_rows: number of rows in the random weight matrix.
            embedding_dims: number of columns in the random weight matrix.
            indices_type: dtype requested for the compressed index map.
            weights_dtype: dtype of the random weight matrix.
        """
        shape = (embedding_rows, embedding_dims)
        if weights_dtype in self._INTEGER_DTYPES:
            embedding_weights = torch.randint(0, 100, shape, dtype=weights_dtype)
        else:
            embedding_weights = torch.rand(shape, dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        def get_pt_result(embedding_weights, mask, indices_type):
            return torch._rowwise_prune(embedding_weights, mask, indices_type)

        # Reference implementation: keep the rows selected by ``mask`` and map
        # each original row index to its position among the kept rows, or to
        # -1 when the row was pruned.
        def get_reference_result(embedding_weights, mask, indices_type):
            num_embeddings = mask.size(0)
            # Start with every row marked as pruned; kept rows are overwritten.
            compressed_idx_out = torch.full((num_embeddings,), -1, dtype=indices_type)
            pruned_weights_out = embedding_weights[mask]
            next_idx = 0
            for i, keep in enumerate(mask.tolist()):
                if keep:
                    compressed_idx_out[i] = next_idx
                    next_idx += 1
            return (pruned_weights_out, compressed_idx_out)

        pt_pruned_weights, pt_compressed_indices_map = get_pt_result(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @skipIfTorchDynamo()
    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHTS_DTYPES),
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        # ``torch.int`` is the 32-bit integer dtype (alias of ``torch.int32``).
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)

    @skipIfTorchDynamo()
    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHTS_DTYPES),
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)
 | |
| 
 | |
| 
 | |
# When executed directly as a script, hand control to the ``run_tests`` runner
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
 |