Add wrapper for fbgemm quantization operations (#122763)
Summary: We add wrappers for fbgemm's packing ops so that they can be passed through PT2 to the lowering phase of AOTInductor.

Test Plan: Included in this commit: test_quantized_ops::test_wrapped_fbgemm_linear_fp16

Differential Revision: [D55433204](https://our.internmc.facebook.com/intern/diff/D55433204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122763
Approved by: https://github.com/jerryzh168
ghstack dependencies: #122762
commit 966ae943df
parent e296722e0e
committed by PyTorch MergeBot
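
For context, a minimal usage sketch of the two wrapped ops this change exposes, mirroring the test added below (an FBGEMM-enabled CPU build is assumed; the shapes and tolerances are illustrative, not taken from the commit):

```python
import torch
import torch.nn.functional as F

# Activations stay fp32; the weight is packed into fbgemm's fp16 format up front.
x = torch.randn(2, 4)      # (batch_size, input_channels)
w = torch.randn(3, 4)      # (output_channels, input_channels)
bias = torch.randn(3)

packed_w = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(w)
out = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(x, packed_w, bias, w.size(0))

# The weight is rounded through fp16, so the reference is F.linear on the fp16-rounded
# weight; the tolerances are loosened here (assumption) to absorb that rounding.
ref = F.linear(x, w.to(torch.float16).to(torch.float32), bias)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-3)
```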
@@ -4,9 +4,10 @@
 import copy
 import itertools
 import numpy as np
-import unittest
 import operator
 import random
+import sys
+import unittest
 from typing import NamedTuple, List

 import torch
@@ -3377,6 +3378,55 @@ class TestDynamicQuantizedOps(TestCase):

             opcheck(qlinear_dynamic, (x, w, bias))

+    @skipIfNoFBGEMM
+    def test_wrapped_fbgemm_linear_fp16(self):
+        options = itertools.product(
+            (2, 4),  # batch_size
+            (4, 5),  # input_channels
+            (4, 7),  # output_channels
+        )
+        for batch_size, input_channels, output_channels in options:
+            pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
+            linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
+
+            x = torch.randn(batch_size, input_channels)
+            w = torch.randn(output_channels, input_channels)
+            bias = torch.randn(output_channels)
+
+            w_packed = pack_op(w)
+            out = linear_op(x, w_packed, bias, output_channels)
+
+            w_fp16 = w.to(torch.float16).to(torch.float32)
+            ref = F.linear(x, w_fp16, bias)
+
+            self.assertEqual(out, ref)
+
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    @skipIfNoFBGEMM
+    def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
+        # We are not using opcheck here because the output of the op we're testing
+        # (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic
+        # due to the C-struct it's producing. This would fail the check when trying
+        # to match the results between the compiled and eager versions.
+        #
+        # This is only a temporary solution; long term, we should be able to support
+        # PT2 with torchbind natively.
+        def func(X, W, B):
+            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
+            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))
+
+        x = torch.randn(1, 4, device="cpu")
+        w = torch.randn(4, 4, device="cpu")
+        b = torch.zeros(4, device="cpu")
+
+        ref_out = func(x, w, b)
+
+        compiled = torch.compile(func)
+        compiled_out = compiled(x, w, b)
+
+        self.assertEqual(ref_out, compiled_out)

     """Tests the correctness of the dynamic quantized lstm/gru."""
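
As a follow-up to the Summary's motivation (passing the packed weight through PT2 toward AOTInductor's lowering), here is a minimal sketch, not part of this commit, that captures both wrapped ops with torch.export; the module name, shapes, and the use of torch.export here are illustrative assumptions:

```python
import torch
import torch.nn as nn


class Fp16PackedLinear(nn.Module):
    # Hypothetical module: packs its weight with the fbgemm fp16 wrapper and calls the
    # wrapped fp16 linear, so both ops appear in the captured graph.
    def __init__(self, out_channels: int, in_channels: int):
        super().__init__()
        self.register_buffer("weight", torch.randn(out_channels, in_channels))
        self.register_buffer("bias", torch.zeros(out_channels))

    def forward(self, x):
        packed = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(self.weight)
        return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
            x, packed, self.bias, self.weight.size(0)
        )


mod = Fp16PackedLinear(4, 4)
ep = torch.export.export(mod, (torch.randn(1, 4),))
# Both wrapped ops should show up as call_function nodes in the exported graph,
# which is the IR that AOTInductor's lowering consumes.
print(ep.graph_module.graph)
```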