mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/97302 Approved by: https://github.com/zou3519
		
			
				
	
	
		
			83 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			83 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Owner(s): ["module: unknown"]
 | 
						|
 | 
						|
import hypothesis.strategies as st
 | 
						|
from hypothesis import given
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from torch.testing._internal.common_utils import TestCase, run_tests
 | 
						|
import torch.testing._internal.hypothesis_utils as hu
 | 
						|
# Module-level guard from the shared hypothesis_utils helpers. Going by its
# name, it asserts that Hypothesis per-example deadlines are disabled for
# this test module — confirm the exact behavior in hypothesis_utils.
hu.assert_deadline_disabled()
class PruningOpTest(TestCase):
    """Tests ``torch._rowwise_prune`` against a simple reference implementation."""

    # Weight dtypes sampled by both property-based tests. They are defined once
    # so the 32-bit and 64-bit index variants always cover the same set.
    _weights_dtypes = st.sampled_from([torch.float64, torch.float32,
                                       torch.float16, torch.int8,
                                       torch.int16, torch.int32, torch.int64])

    # Weight dtypes that are filled with ``torch.randint``. Every other dtype
    # is filled with ``torch.rand``.
    _integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)

    def _generate_rowwise_mask(self, embedding_rows):
        """Return a random boolean keep-mask with one entry per weight row.

        ``indicator`` holds one float32 importance value in [0, 1) per row.
        ``threshold`` is a single random value in [0, 1). A row is kept
        (``True``) when its indicator is >= the threshold. It is masked out
        (``False``) when its indicator is less than the threshold.
        """
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = np.random.random_sample()
        # A single vectorized comparison instead of building a Python list
        # element by element. The result is already a 1-D torch.bool tensor.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Compare ``torch._rowwise_prune`` with the reference on random inputs.

        Args:
            embedding_rows: number of rows in the weight matrix (and mask length).
            embedding_dims: number of columns in the weight matrix.
            indices_type: dtype requested for the compressed index map.
            weights_dtype: dtype of the weight matrix.
        """
        shape = (embedding_rows, embedding_dims)
        if weights_dtype in self._integer_dtypes:
            embedding_weights = torch.randint(0, 100, shape, dtype=weights_dtype)
        else:
            embedding_weights = torch.rand(shape, dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        # Reference implementation.
        def get_reference_result(embedding_weights, mask, indices_type):
            # Kept rows, in their original order.
            pruned_weights_out = embedding_weights[mask]
            # compressed_idx_out[i] is the position of original row i within
            # pruned_weights_out, or -1 if row i was pruned.
            compressed_idx_out = torch.full((mask.size(0),), -1, dtype=indices_type)
            next_idx = 0
            for i, keep in enumerate(mask.tolist()):
                if keep:
                    compressed_idx_out[i] = next_idx
                    next_idx += 1
            return pruned_weights_out, compressed_idx_out

        pt_pruned_weights, pt_compressed_indices_map = torch._rowwise_prune(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=_weights_dtypes,
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        """Row-wise pruning with a 32-bit (``torch.int``) compressed index map."""
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=_weights_dtypes,
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        """Row-wise pruning with a 64-bit (``torch.int64``) compressed index map."""
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)
# Allow running this file directly. run_tests() is the shared test entry point
# imported from torch.testing._internal.common_utils.
if __name__ == '__main__':
    run_tests()