[quant] Enable backward for choose_qparams_per_token_asymmetric (#123452)

Summary: When running the backward for this op, we get the error:
```
RuntimeError: derivative for aten::aminmax is not implemented
```
This commit replaces the aminmax call with separate amin and amax
calls, which do have implemented derivatives.
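
For illustration, here is a minimal sketch of the kind of replacement described (not the exact diff to the decomposed op in torch/ao/quantization/fx/_decomposed.py); it assumes the per-token reduction runs over the last dimension with keepdim=True:
```python
import torch

x = torch.randn(2, 3, requires_grad=True)

# Before: torch.aminmax has no backward formula, so backward() raises
# "derivative for aten::aminmax is not implemented".
# min_val, max_val = torch.aminmax(x, dim=-1, keepdim=True)

# After: separate amin/amax calls, both of which support autograd.
min_val = torch.amin(x, dim=-1, keepdim=True)
max_val = torch.amax(x, dim=-1, keepdim=True)

# Gradients now flow back to x through the per-token min/max.
(max_val - min_val).sum().backward()
print(x.grad)
```
The values match what aminmax would return; only the autograd support differs, so the forward numerics of choose_qparams_per_token_asymmetric are unchanged.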

Test Plan:
python test/test_quantization.py -k test_decomposed_choose_qparams_per_token_asymmetric_backward

Reviewers: jerryzh168, digantdesai

Subscribers: jerryzh168, digantdesai, supriyar

Differential Revision: [D55805170](https://our.internmc.facebook.com/intern/diff/D55805170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123452
Approved by: https://github.com/digantdesai, https://github.com/jerryzh168, https://github.com/zou3519
Author: andrewor14
Date: 2024-04-12 08:52:55 -07:00
Committed by: PyTorch MergeBot
Parent: 3346ec8263
Commit: 762e19606e
2 changed files with 11 additions and 18 deletions

```diff
@@ -1602,6 +1602,14 @@ class TestQuantizedTensor(TestCase):
         self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
         self.assertEqual(dequantized_X, dequantized_decomposed_X)
 
+    def test_decomposed_choose_qparams_per_token_asymmetric_backward(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        x = torch.randn(2, 3).requires_grad_()
+        (s, zp) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(x, torch.int8)
+        out = x.div(s).add(zp).round()
+        out.sum().backward()
+
 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                        "\tpython test/test_quantization.py TESTNAME\n\n"
```