Add wrapper for fbgemm quantization operations (#122763)

Summary:
We add wrappers for fbgemm's packing so we can pass it through PT2 to
lowering phase of AOTInductor.

Test Plan:
Included in commit.
test_quantized_ops::test_wrapped_fbgemm_linear_fp16

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D55433204](https://our.internmc.facebook.com/intern/diff/D55433204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122763
Approved by: https://github.com/jerryzh168
ghstack dependencies: #122762
This commit is contained in:
Mu-Chu Lee
2024-03-27 21:19:46 -07:00
committed by PyTorch MergeBot
parent e296722e0e
commit 966ae943df
3 changed files with 122 additions and 1 deletions

View File

@ -4,9 +4,10 @@
import copy
import itertools
import numpy as np
import unittest
import operator
import random
import sys
import unittest
from typing import NamedTuple, List
import torch
@ -3377,6 +3378,55 @@ class TestDynamicQuantizedOps(TestCase):
opcheck(qlinear_dynamic, (x, w, bias))
@skipIfNoFBGEMM
def test_wrapped_fbgemm_linear_fp16(self):
options = itertools.product(
(2, 4), # batch_size
(4, 5), # input_channels
(4, 7), # output_channels
)
for batch_size, input_channels, output_channels in options:
pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
x = torch.randn(batch_size, input_channels)
w = torch.randn(output_channels, input_channels)
bias = torch.randn(output_channels)
w_packed = pack_op(w)
out = linear_op(x, w_packed, bias, output_channels)
w_fp16 = w.to(torch.float16).to(torch.float32)
ref = F.linear(x, w_fp16, bias)
self.assertEqual(out, ref)
@unittest.skipIf(
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@skipIfNoFBGEMM
def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
# We are not using opcheck over here because the output for the op we're testing
# (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic
# due to the C-struct it's procuding. This would fail the check when we're trying
# to match the result between compiled and eager version.
#
# This is only a temporary solution, long term, we should be able to support PT2
# with torchbind natively.
def func(X, W, B):
packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))
x = torch.randn(1, 4, device="cpu")
w = torch.randn(4, 4, device="cpu")
b = torch.zeros(4, device="cpu")
ref_out = func(x, w, b)
compiled = torch.compile(func)
compiled_out = compiled(x, w, b)
self.assertEqual(ref_out, compiled_out)
"""Tests the correctness of the dynamic quantized lstm/gru."""