mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[AOTI][CPU] Consider bias=None case for fbgemm_linear_fp16_weight (#158535)
Test Plan: Rollback Plan:
Differential Revision: D78458214
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158535
Approved by: https://github.com/houseroad, https://github.com/henryoier, https://github.com/jingsh

committed by: PyTorch MergeBot
parent: 08540b13c6
commit: 2c37acfd89
@@ -409,7 +409,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
 Tensor fbgemm_linear_fp16_weight_fp32_activation(
     const Tensor& input,
     const Tensor& packed_weight,
-    const Tensor& bias) {
+    const std::optional<Tensor>& bias) {
   TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
       "and will be removed in a future PyTorch release.")
 
@@ -430,7 +430,6 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
 
   TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
   TORCH_CHECK(input.dim() >= 2);
-  TORCH_CHECK(bias.dim() == 1);
 
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
@@ -449,7 +448,12 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
       output.data_ptr<float>());
 
   // Add bias term
-  output.add_(bias);
+  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias);
+  const Tensor& bias_ = *bias_maybe_owned;
+  if (bias_.defined()) {
+    TORCH_CHECK(bias_.dim() == 1);
+    output.add_(bias_);
+  }
 
   return output;
 }
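For reference, here is a minimal pure-PyTorch sketch (an illustration, not the FBGEMM code path) of the semantics the updated kernel now implements: the packed weight is fp16, the activation and output stay fp32, and the bias is added only when one is supplied.

    # Semantic sketch only; `fp16_weight_linear_ref` is a hypothetical helper,
    # not part of this change.
    from typing import Optional

    import torch

    def fp16_weight_linear_ref(
        x: torch.Tensor,                      # fp32 activation, shape (..., K)
        w: torch.Tensor,                      # fp32 weight, shape (N, K); packing stores it as fp16
        bias: Optional[torch.Tensor] = None,  # optional 1-D bias of length N
    ) -> torch.Tensor:
        out = x @ w.to(torch.float16).float().t()  # fp16 weight, fp32 math
        if bias is not None:                       # mirrors the new `bias_.defined()` branch
            out = out + bias
        return out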
@@ -551,7 +555,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
 Tensor fbgemm_linear_fp16_weight_fp32_activation(
     const Tensor& input,
     const Tensor& packed_weight,
-    const Tensor& bias) {
+    const std::optional<Tensor>& bias) {
   TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
       "and will be removed in a future PyTorch release.")
 
@@ -3432,7 +3432,7 @@
 
 - func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor
 
-- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor
 
 - func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
 
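With the `Tensor? bias` schema above, the ATen entry point can be called from Python with `None`. A small usage sketch, assuming an FBGEMM-enabled CPU build (the call also triggers the deprecation warning added in this patch):

    import torch

    x = torch.randn(3, 4)
    w = torch.randn(5, 4)

    # Pack the fp32 weight into the opaque fp16 GEMM matrix, then run the linear
    # with and without a bias.
    packed = torch.fbgemm_pack_gemm_matrix_fp16(w)
    out_no_bias = torch.fbgemm_linear_fp16_weight_fp32_activation(x, packed, None)
    out_biased = torch.fbgemm_linear_fp16_weight_fp32_activation(x, packed, torch.randn(5))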
@@ -888,7 +888,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor run(
       at::Tensor input,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -908,7 +908,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor meta(
       at::Tensor input,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -929,7 +929,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor run(
       at::Tensor /* input */,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -940,7 +940,7 @@ class QLinearUnpackedDynamicFp16 final {
   static at::Tensor meta(
       at::Tensor /* input */,
       const at::Tensor& weight,
-      const at::Tensor& bias) {
+      const std::optional<at::Tensor>& bias) {
     TORCH_CHECK(
         false, "This PyTorch installation was not built with FBGEMM operators");
   }
@@ -142,7 +142,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
-  m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor? bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_leaky_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i, float negative_slope) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_tanh(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"), {at::Tag::pt2_compliant_tag});
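At the dispatcher level, the relaxed `quantized::linear_dynamic_fp16_unpacked_weight` schema accepts either a 1-D bias tensor or `None`, which is what the new AOTI test below exercises. A minimal eager-mode sketch, assuming FBGEMM is available:

    import torch

    x = torch.randn(10, 10)
    w = torch.randn(10, 10)
    b = torch.randn(10)

    out_with_bias = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(x, w, b)
    out_bias_none = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(x, w, None)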
@@ -1455,6 +1455,22 @@ class AOTInductorTestsTemplate:
         with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
             self.check_model(Model(self.device), example_inputs)
 
+    @skipIfNoFBGEMM
+    def test_quantized_linear_bias_none(self):
+        class Model(torch.nn.Module):
+            def __init__(self, device):
+                super().__init__()
+                self.weight = torch.randn(10, 10, device=device)
+
+            def forward(self, x):
+                return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
+                    x, self.weight, None
+                )
+
+        example_inputs = (torch.randn(10, 10, device=self.device),)
+        with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
+            self.check_model(Model(self.device), example_inputs)
+
     @skipIfNoFBGEMM
     def test_quanatized_int8_linear(self):
         class Model(torch.nn.Module):
@@ -6714,6 +6730,7 @@ GPU_TEST_FAILURES = {
     # quantized unsupported for GPU
     "test_quantized_linear": fail_gpu(("cuda", "xpu")),
     "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
+    "test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")),
     # No scaled_dot_product_efficient_attention implementation for XPU yet.
     "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
     # No fft implementation for XPU yet.
@@ -512,6 +512,7 @@ CUDA_TEST_FAILURES = {
     # quantized unsupported for GPU
     "test_quantized_linear": fail_cuda(),
     "test_quanatized_int8_linear": fail_cuda(),
+    "test_quantized_linear_bias_none": fail_cuda(),
 }
 
 
@@ -3550,14 +3550,15 @@ class TestDynamicQuantizedOps(TestCase):
             (2, 4),  # batch_size
             (4, 5),  # input_channels
             (4, 7),  # output_channels
+            (True, False),  # bias None or not
         )
-        for batch_size, input_channels, output_channels in options:
+        for batch_size, input_channels, output_channels, bias_is_none in options:
             pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
             linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
 
             x = torch.randn(batch_size, input_channels)
             w = torch.randn(output_channels, input_channels)
-            bias = torch.randn(output_channels)
+            bias = torch.randn(output_channels) if not bias_is_none else None
 
             w_packed = pack_op(w)
             out = linear_op(x, w_packed, bias, output_channels)
@@ -3591,6 +3592,18 @@ class TestDynamicQuantizedOps(TestCase):
 
         self.assertEqual(ref_out, compiled_out)
 
+        def func(X, W):
+            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
+            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, None, W.size(0))
+
+        ref_out = func(x, w)
+
+        compiled = torch.compile(func)
+        compiled_out = compiled(x, w)
+
+        self.assertEqual(ref_out, compiled_out)
+
+
     """Tests the correctness of the dynamic quantized lstm/gru."""
 
     def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range):
@@ -701,7 +701,7 @@ def randint(
 def linear_dynamic_fp16_unpacked_weight(
     input: torch.Tensor,
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
     return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
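Because the Python-side registration above now defaults `bias` to `None`, a call without a bias traces cleanly under `torch.compile` (the decomposition packs the fp16 weight and forwards the optional bias to the wrapped linear op). A short sketch, again assuming an FBGEMM-enabled build:

    import torch

    def f(x, w):
        # bias passed as None, exercising the new default
        return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(x, w, None)

    x = torch.randn(8, 16)
    w = torch.randn(4, 16)

    eager_out = f(x, w)
    compiled_out = torch.compile(f)(x, w)
    torch.testing.assert_close(eager_out, compiled_out)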
@@ -365,7 +365,7 @@ AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
     AtenTensorHandle input,
     AtenTensorHandle weight,
-    AtenTensorHandle bias,
+    AtenTensorHandle bias, // optional argument
     int64_t out_channel,
     AtenTensorHandle* out);
 
@@ -981,16 +981,17 @@ AOTITorchError aoti_torch_cpu__wrapped_linear_prepack(
 AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
     AtenTensorHandle input,
     AtenTensorHandle weight,
-    AtenTensorHandle bias,
+    AtenTensorHandle bias, // optional argument
     int64_t out_channel,
     AtenTensorHandle* out) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
     at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
-    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
+    auto optional_bias_tensor =
+        pointer_to_optional(tensor_handle_to_tensor_pointer(bias));
 
     *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation(
-        *input_tensor, *weight_tensor, *bias_tensor));
+        *input_tensor, *weight_tensor, optional_bias_tensor));
   });
 }
 
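End to end, the shim change above is what allows an AOTInductor-compiled model to reach aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight with a null bias handle. Below is a hedged sketch of exporting and packaging such a model on CPU; torch.export.export, aoti_compile_and_package, and aoti_load_package are assumed to be available in the build and are not part of this diff.

    import torch
    from torch._inductor import aoti_compile_and_package, aoti_load_package

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(10, 10)

        def forward(self, x):
            # bias=None is the case this PR adds support for
            return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
                x, self.weight, None
            )

    example_inputs = (torch.randn(10, 10),)
    ep = torch.export.export(M(), example_inputs)
    pkg = aoti_compile_and_package(ep)   # writes a .pt2 package
    runner = aoti_load_package(pkg)      # loads the compiled model
    print(runner(*example_inputs).shape)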