[quant] Do not decompose choose_qparams_per_token_asymmetric (#124178)

Summary: https://github.com/pytorch/pytorch/pull/123452 added
backward support to this op by turning it into
CompositeImplicitAutograd, which meant it gets decomposed during
export/compile. However, this is not desirable behavior for the
PTQ case when we try to lower the model. This commit enables
QAT without breaking PTQ by refactoring the impl into a separate
op that does have backward support.

Test Plan:
python test/test_quantization.py -k test_decomposed_choose_qparams_per_token_asymmetric_backward

Reviewers: jerryzh168, digantdesai, zou3519

Subscribers: jerryzh168, digantdesai, zou3519, supriyar

Differential Revision: [D56192116](https://our.internmc.facebook.com/intern/diff/D56192116)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124178
Approved by: https://github.com/digantdesai
This commit is contained in:
andrewor14
2024-04-16 08:10:52 -07:00
committed by PyTorch MergeBot
parent 3e90e93a78
commit 3eea300680
2 changed files with 36 additions and 4 deletions

View File

@ -1606,7 +1606,7 @@ class TestQuantizedTensor(TestCase):
# register the ops
import torch.ao.quantization.fx._decomposed
x = torch.randn(2, 3).requires_grad_()
(s, zp) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(x, torch.int8)
(s, zp) = torch.ops.quantized_decomposed._choose_qparams_per_token_asymmetric_impl(x, torch.int8)
out = x.div(s).add(zp).round()
out.sum().backward()

View File

@ -639,16 +639,16 @@ def choose_qparams_per_token_meta(
quantized_decomposed_lib.define(
"choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
"_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token_asymmetric",
"_choose_qparams_per_token_asymmetric_impl",
"CompositeImplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
def _choose_qparams_per_token_asymmetric_impl(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -691,6 +691,38 @@ def choose_qparams_per_token_asymmetric(
return scale.to(torch.float32), zero_point.to(torch.float32)
quantized_decomposed_lib.define(
"choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token_asymmetric",
"CompositeExplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
return _choose_qparams_per_token_asymmetric_impl(input, dtype)
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token_asymmetric",
"Meta",
)
def choose_qparams_per_token_asymmetric_meta(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
size = (1, input.size(-1))
return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
size, dtype=torch.int64, device=input.device
)
def _per_token_quant_qparam_dim_check(input, scales, zero_points):
num_tokens = math.prod(list(input.size())[:-1])
assert (