[AOTI][CPU] Consider bias=None case for fbgemm_linear_fp16_weight (#158535)

Test Plan: covered by the new AOTI test test_quantized_linear_bias_none and the bias=None cases added to the dynamic fp16 linear tests in this diff.

Differential Revision: D78458214

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158535
Approved by: https://github.com/houseroad, https://github.com/henryoier, https://github.com/jingsh
Authored by Huamin Li on 2025-07-21 23:42:40 +00:00; committed by PyTorch MergeBot
parent 08540b13c6
commit 2c37acfd89
10 changed files with 53 additions and 17 deletions
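
For context, a minimal sketch of the call this change enables, assuming a CPU build of PyTorch with FBGEMM (shapes mirror the new tests in this diff):

    import torch

    x = torch.randn(10, 10)  # fp32 activation
    w = torch.randn(10, 10)  # fp32 weight; the op packs it to fp16 internally

    # Before this change the schema required "Tensor bias", so passing None was
    # rejected at schema validation; with "Tensor? bias" both calls are accepted.
    out_with_bias = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(x, w, torch.randn(10))
    out_no_bias = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(x, w, None)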

View File

@@ -409,7 +409,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
 Tensor fbgemm_linear_fp16_weight_fp32_activation(
     const Tensor& input,
     const Tensor& packed_weight,
-    const Tensor& bias) {
+    const std::optional<Tensor>& bias) {
   TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
       "and will be removed in a future PyTorch release.")
@@ -430,7 +430,6 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
   TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
   TORCH_CHECK(input.dim() >= 2);
-  TORCH_CHECK(bias.dim() == 1);

   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
@@ -449,7 +448,12 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
       output.data_ptr<float>());

   // Add bias term
-  output.add_(bias);
+  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias);
+  const Tensor& bias_ = *bias_maybe_owned;
+  if (bias_.defined()) {
+    TORCH_CHECK(bias_.dim() == 1);
+    output.add_(bias_);
+  }

   return output;
 }
@@ -551,7 +555,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
 Tensor fbgemm_linear_fp16_weight_fp32_activation(
     const Tensor& input,
     const Tensor& packed_weight,
-    const Tensor& bias) {
+    const std::optional<Tensor>& bias) {
   TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
       "and will be removed in a future PyTorch release.")

View File

@@ -3432,7 +3432,7 @@
 - func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor

-- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor

 - func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

View File

@@ -888,7 +888,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor run(
       at::Tensor input,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -908,7 +908,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor meta(
       at::Tensor input,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -929,7 +929,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor run(
       at::Tensor /* input */,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -940,7 +940,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor meta(
       at::Tensor /* input */,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     TORCH_CHECK(
         false, "This PyTorch installation was not built with FBGEMM operators");
   }

View File

@@ -142,7 +142,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
-  m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor? bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_leaky_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i, float negative_slope) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_tanh(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"), {at::Tag::pt2_compliant_tag});

View File

@@ -1455,6 +1455,22 @@ class AOTInductorTestsTemplate:
         with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
             self.check_model(Model(self.device), example_inputs)

+    @skipIfNoFBGEMM
+    def test_quantized_linear_bias_none(self):
+        class Model(torch.nn.Module):
+            def __init__(self, device):
+                super().__init__()
+                self.weight = torch.randn(10, 10, device=device)
+
+            def forward(self, x):
+                return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
+                    x, self.weight, None
+                )
+
+        example_inputs = (torch.randn(10, 10, device=self.device),)
+        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
+            self.check_model(Model(self.device), example_inputs)
+
     @skipIfNoFBGEMM
     def test_quanatized_int8_linear(self):
         class Model(torch.nn.Module):
@@ -6714,6 +6730,7 @@ GPU_TEST_FAILURES = {
     # quantized unsupported for GPU
     "test_quantized_linear": fail_gpu(("cuda", "xpu")),
     "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
+    "test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")),
     # No scaled_dot_product_efficient_attention implementation for XPU yet.
     "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
     # No fft implementation for XPU yet.
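
Note: the new test above runs the bias=None model through AOT Inductor via check_model with runtime constant folding enabled. As a rough standalone equivalent outside the test harness (the export/packaging calls below are an assumption for illustration, not part of this diff; requires an FBGEMM-enabled CPU build):

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(10, 10)

        def forward(self, x):
            # bias=None is the case this PR adds support for
            return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(x, self.weight, None)

    ep = torch.export.export(M(), (torch.randn(10, 10),))
    pkg_path = torch._inductor.aoti_compile_and_package(ep)  # AOTInductor compile (CPU, since inputs are CPU tensors)
    runner = torch._inductor.aoti_load_package(pkg_path)
    print(runner(torch.randn(10, 10)))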

View File

@@ -512,6 +512,7 @@ CUDA_TEST_FAILURES = {
     # quantized unsupported for GPU
     "test_quantized_linear": fail_cuda(),
     "test_quanatized_int8_linear": fail_cuda(),
+    "test_quantized_linear_bias_none": fail_cuda(),
 }

View File

@@ -3550,14 +3550,15 @@ class TestDynamicQuantizedOps(TestCase):
             (2, 4),  # batch_size
             (4, 5),  # input_channels
             (4, 7),  # output_channels
+            (True, False),  # bias None or not
         )
-        for batch_size, input_channels, output_channels in options:
+        for batch_size, input_channels, output_channels, bias_is_none in options:
             pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
             linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight

             x = torch.randn(batch_size, input_channels)
             w = torch.randn(output_channels, input_channels)
-            bias = torch.randn(output_channels)
+            bias = torch.randn(output_channels) if not bias_is_none else None

             w_packed = pack_op(w)
             out = linear_op(x, w_packed, bias, output_channels)
@@ -3591,6 +3592,18 @@ class TestDynamicQuantizedOps(TestCase):
         self.assertEqual(ref_out, compiled_out)

+        def func(X, W):
+            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
+            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, None, W.size(0))
+
+        ref_out = func(x, w)
+
+        compiled = torch.compile(func)
+        compiled_out = compiled(x, w)
+
+        self.assertEqual(ref_out, compiled_out)
+
     """Tests the correctness of the dynamic quantized lstm/gru."""
     def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range):

View File

@@ -701,7 +701,7 @@ def randint(
 def linear_dynamic_fp16_unpacked_weight(
     input: torch.Tensor,
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
     return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(

View File

@@ -365,7 +365,7 @@ AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
     AtenTensorHandle input,
     AtenTensorHandle weight,
-    AtenTensorHandle bias,
+    AtenTensorHandle bias, // optional argument
     int64_t out_channel,
     AtenTensorHandle* out);

View File

@@ -981,16 +981,17 @@ AOTITorchError aoti_torch_cpu__wrapped_linear_prepack(
 AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
     AtenTensorHandle input,
     AtenTensorHandle weight,
-    AtenTensorHandle bias,
+    AtenTensorHandle bias, // optional argument
     int64_t out_channel,
     AtenTensorHandle* out) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
     at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
-    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
+    auto optional_bias_tensor =
+        pointer_to_optional(tensor_handle_to_tensor_pointer(bias));
     *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation(
-        *input_tensor, *weight_tensor, *bias_tensor));
+        *input_tensor, *weight_tensor, optional_bias_tensor));
   });
 }