mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19984
	Add qint8 for QTensor, with an underlying type of int8_t.
	Reviewed By: jianyuh
	Differential Revision: D15150715
	fbshipit-source-id: 57580f599d46f9323af5ce462dbbc464b25e40d7
		
			
				
	
	
		
			27 lines
		
	
	
		
			791 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			27 lines
		
	
	
		
			791 B
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import torch.nn.quantized.functional as F
 | 
						|
 | 
						|
import numpy as np
 | 
						|
from common_utils import TestCase, run_tests
 | 
						|
 | 
						|
def _quantize(x, scale, zero_point, qmin=0, qmax=255, dtype=np.uint8):
    """Quantize a numpy array with an affine (scale, zero_point) mapping.

    Computes ``clip(round(x / scale + zero_point), qmin, qmax)`` and casts
    the result to ``dtype``. Note that ``np.round`` rounds halves to the
    nearest even integer.

    Args:
        x: Array of real values to quantize.
        scale: Quantization step size; ``x`` is divided by it.
        zero_point: Offset added after scaling.
        qmin: Lower clipping bound of the quantized range.
        qmax: Upper clipping bound of the quantized range.
        dtype: NumPy dtype of the result. Defaults to ``np.uint8`` (the
            default range 0..255 fits it). For a signed range such as
            ``qmin=-128, qmax=127`` pass ``np.int8``; otherwise the cast
            would wrap the negative values.

    Returns:
        ndarray of ``dtype`` holding the clipped, rounded values.
    """
    qx = np.round(x / scale + zero_point)
    # Clip before casting so the cast never sees out-of-range values.
    qx = np.clip(qx, qmin, qmax).astype(dtype)
    return qx
 | 
						|
 | 
						|
class FunctionalAPITest(TestCase):
    """Tests for the quantized functional API imported as ``F``."""

    def test_functional_api(self):
        """Compare quantized ``F.relu`` against a numpy reference.

        A float ramp is quantized to ``torch.quint8`` and passed through the
        quantized ReLU. Its integer representation must equal the
        ``_quantize`` result of a ReLU computed in float space with numpy.
        """
        scale, zero_point = 2.0, 1
        x_float = torch.arange(-5, 5, dtype=torch.float)

        # Reference path: ReLU in float space (np.maximum returns a new
        # array, so x_float is left untouched), then quantize.
        relu_ref = np.maximum(x_float.numpy(), 0)
        expected = _quantize(relu_ref, scale, zero_point)

        # Path under test: quantize first, then apply the quantized ReLU.
        x_quant = x_float.quantize_linear(scale=scale, zero_point=zero_point,
                                          dtype=torch.quint8)
        actual = F.relu(x_quant)

        np.testing.assert_equal(expected, actual.int_repr())
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Running this file directly hands control to the shared runner
    # imported from common_utils.
    run_tests()
 |