Description:

1. Quantize Linear layer weights to 4 bits: quantize the weights of the Linear layer to 4 bits using symmetric quantization, packing two 4-bit weights into one uint8 container. Choose a quantization scheme (channel-wise or group-wise), with the group size a multiple of 32.

2. Prepare quantized weights, scales, and optional bias: after quantizing, obtain the quantized_weights, scales, and groupsize. If the original Linear layer has a bias, prepare it as well.

3. Pack the weights efficiently: use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias:

   ```python
   packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(
       weight, scales_and_zeros, bias, groupsize, in_features, out_features
   )
   ```

   Input parameters include in_features and out_features (the same as the Linear layer's corresponding parameters).

4. Perform dynamic quantized matrix multiplication: use torch.ops.aten._dyn_quant_matmul_4bit to multiply activations against the quantized weights:

   ```python
   output = torch.ops.aten._dyn_quant_matmul_4bit(
       input, packed_weights, groupsize, in_features, out_features
   )
   ```

   Required inputs: the input tensor, packed_weights, groupsize, in_features, and out_features. (An end-to-end sketch of these steps follows the description.)

API usage: https://github.com/pytorch/pytorch/issues/143289

Model perf:
- 7B Transformer model: Prefill: 340 t/s, Decode: 40 t/s
- 2B Transformer model: Prefill: 747 t/s, Decode: 80 t/s

Tests:

```
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s
OK

python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s
OK

python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s
```

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
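Below is a minimal end-to-end sketch of the four steps above. It assumes group-wise symmetric quantization with no zero points; the helper quantize_4bit_groupwise, the low-nibble-first packing, and the scales layout are illustrative assumptions rather than the canonical format — see issue #143289 for authoritative usage. The two aten op calls follow the signatures quoted above.

```python
import torch

# Hypothetical helper: group-wise symmetric 4-bit quantization of a Linear
# weight. The nibble order and the scales layout are assumptions made for
# illustration; consult issue #143289 for the exact format the ops expect.
def quantize_4bit_groupwise(weight: torch.Tensor, groupsize: int):
    out_features, in_features = weight.shape
    assert in_features % groupsize == 0
    w = weight.float().reshape(out_features, -1, groupsize)
    # Symmetric scheme: map each group's max magnitude onto the int4 range
    # [-8, 7]; no zero points are needed.
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(w / scales), -8, 7) + 8  # shift to [0, 15]
    q = q.to(torch.uint8).reshape(out_features, in_features)
    # Pack two 4-bit values per uint8 (low nibble first -- assumed order).
    packed = q[:, ::2] | (q[:, 1::2] << 4)
    return packed, scales.reshape(out_features, -1)


in_features, out_features, groupsize = 64, 32, 32
linear = torch.nn.Linear(in_features, out_features, bias=False)

qweight, scales = quantize_4bit_groupwise(linear.weight.detach(), groupsize)

# Step 3: pack quantized weights, scales, and the (optional) bias. With a
# symmetric scheme, only scales are supplied for scales_and_zeros.
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(
    qweight, scales, linear.bias, groupsize, in_features, out_features
)

# Step 4: dynamic quantized matmul against a float activation tensor.
x = torch.randn(1, in_features)
output = torch.ops.aten._dyn_quant_matmul_4bit(
    x, packed_weights, groupsize, in_features, out_features
)
print(output.shape)  # torch.Size([1, 32])
```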
# mypy: allow-untyped-defs
import types
from contextlib import contextmanager

# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
# test suite, where it's very easy to forget to undo the change
# later.
__allow_nonbracketed_mutation_flag = True

def disable_global_flags():
    global __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = False

def flags_frozen():
    return not __allow_nonbracketed_mutation_flag

@contextmanager
def __allow_nonbracketed_mutation():
    global __allow_nonbracketed_mutation_flag
    old = __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = True
    try:
        yield
    finally:
        __allow_nonbracketed_mutation_flag = old

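# Usage sketch (illustrative, not part of this module): a test harness calls
# disable_global_flags() once at startup; from then on, bare assignment to
# guarded flags such as torch.backends.cudnn.enabled raises a RuntimeError,
# while library-internal code can still mutate state through the
# __allow_nonbracketed_mutation() context manager, which restores the
# previous freeze state on exit:
#
#     disable_global_flags()
#     assert flags_frozen()
#     with __allow_nonbracketed_mutation():
#         ...  # temporary mutation permitted; freeze restored afterwards
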
class ContextProp:
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter()

    def __set__(self, obj, val):
        if not flags_frozen():
            self.setter(val)
        else:
            raise RuntimeError(
                f"not allowed to set {obj.__name__} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )

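# Usage sketch (illustrative): a backend exposes a guarded flag by declaring
# a ContextProp on its PropModule subclass (defined below), pairing a native
# getter/setter so reads always reflect the underlying C++ state. For
# example, torch.backends.cudnn wires its flag roughly as:
#
#     enabled = ContextProp(torch._C._get_cudnn_enabled,
#                           torch._C._set_cudnn_enabled)
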
class PropModule(types.ModuleType):
    def __init__(self, m, name):
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        return self.m.__getattribute__(attr)

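# Usage sketch (illustrative): a backend module replaces its own entry in
# sys.modules with a PropModule subclass, so attribute access goes through
# the ContextProp descriptors while everything else falls through to the
# original module object:
#
#     import sys
#     sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)
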
from torch.backends import (
    cpu as cpu,
    cuda as cuda,
    cudnn as cudnn,
    cusparselt as cusparselt,
    kleidiai as kleidiai,
    mha as mha,
    mkl as mkl,
    mkldnn as mkldnn,
    mps as mps,
    nnpack as nnpack,
    openmp as openmp,
    quantized as quantized,
)
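# Note: the `name as name` aliases above mark these submodules as explicit
# re-exports for static type checkers (e.g. mypy's implicit-reexport check).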