mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[quant] Enable backward for choose_qparams_per_token_asymmetric (#123452)
Summary: When running the backward for this op, we get the error: ``` RuntimeError: derivative for aten::aminmax is not implemented ``` This commit replaces this call with separate amin and amax calls instead, which do have implemented derivatives. Test Plan: python test/test_quantization.py -k test_decomposed_choose_qparams_per_token_asymmetric_backward Reviewers: jerryzh168, digantdesai Subscribers: jerryzh168, digantdesai, supriyar Differential Revision: [D55805170](https://our.internmc.facebook.com/intern/diff/D55805170) Pull Request resolved: https://github.com/pytorch/pytorch/pull/123452 Approved by: https://github.com/digantdesai, https://github.com/jerryzh168, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
3346ec8263
commit
762e19606e
@ -1602,6 +1602,14 @@ class TestQuantizedTensor(TestCase):
|
||||
self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
|
||||
self.assertEqual(dequantized_X, dequantized_decomposed_X)
|
||||
|
||||
def test_decomposed_choose_qparams_per_token_asymmetric_backward(self):
|
||||
# register the ops
|
||||
import torch.ao.quantization.fx._decomposed
|
||||
x = torch.randn(2, 3).requires_grad_()
|
||||
(s, zp) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(x, torch.int8)
|
||||
out = x.div(s).add(zp).round()
|
||||
out.sum().backward()
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_quantization.py TESTNAME\n\n"
|
||||
|
||||
Reference in New Issue
Block a user