[Quant][Inductor] Enable lowering of dynamic qlinear for X86Inductor (#120605)
**Description**

Enable lowering of dynamic qlinear for X86Inductor. The pattern is `choose_qparams -> getitem -> q -> dq -> linear`. Only `dq -> linear` is fused, yielding `choose_qparams -> getitem -> q -> onednn.qlinear_pointwise`, so this is treated as dynamic quantization of the activation followed by a statically quantized linear. The previous implementation of `onednn.qlinear_pointwise` only handles the case where `x_scale` and `x_zp` are scalars; since `choose_qparams` returns tensors, a new variant `onednn.qlinear_pointwise.tensor` is added to support that case. This feature targets the PyTorch 2.3 release.

**Test plan**
```
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_cpu
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_qat_cpu
python inductor/test_cpu_cpp_wrapper.py -k test_dynamic_qlinear
```

**Performance before and after lowering `choose_qparams` to Inductor**

Before:
- latency for shape (32, 32) = 0.151 ms
- latency for shape (128, 128) = 0.153 ms
- latency for shape (1024, 1024) = 0.247 ms

After:
- latency for shape (32, 32) = 0.049 ms
- latency for shape (128, 128) = 0.052 ms
- latency for shape (1024, 1024) = 0.133 ms

Test method: a module with a single Linear layer, dynamically quantized and lowered to X86Inductor.

Test env & config: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz, single instance, single core, using Intel OpenMP and Tcmalloc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120605
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
Commit 83d848e1c7 (parent af5376c444), committed by PyTorch MergeBot.
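For orientation, a minimal sketch of the user-facing flow this PR enables, modeled on the `_generate_qdq_quantized_model` helper changed below. The module `M` and shapes are illustrative; it assumes a PyTorch 2.3-era build where `capture_pre_autograd_graph` and the `is_dynamic` config flag exist:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class M(torch.nn.Module):  # illustrative single-Linear module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


mod, example_inputs = M().eval(), (torch.randn(2, 4),)
with torch.no_grad():
    exported = capture_pre_autograd_graph(mod, example_inputs)
    quantizer = X86InductorQuantizer()
    quantizer.set_global(
        xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
    )
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass
    converted = convert_pt2e(prepared)
    # Inductor then fuses dq -> linear into onednn.qlinear_pointwise(.tensor)
    out = torch.compile(converted)(*example_inputs)
```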
```diff
@@ -1158,6 +1158,33 @@ class QLinearOnednn final {
         bias, output_scale, output_zero_point, output_dtype,
         post_op_name, post_op_args, post_op_algorithm
     );
 #endif
     TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
   }
+
+  static Tensor run_pointwise_tensor(
+      Tensor act, // int8 CPU tensor, not QTensor
+      Tensor act_scale,
+      Tensor act_zero_point,
+      Tensor onednn_weight, // int8 tensor from MkldnnCPU
+      Tensor weight_scales,
+      Tensor weight_zero_points,
+      c10::optional<Tensor> bias,
+      double output_scale,
+      int64_t output_zero_point,
+      c10::optional<c10::ScalarType> output_dtype,
+      std::string post_op_name,
+      torch::List<c10::optional<at::Scalar>> post_op_args,
+      std::string post_op_algorithm) {
+#if AT_MKLDNN_ENABLED()
+    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
+        "onednn int8 linear: act scale/zp size should be 1");
+    return linear_int8_with_onednn_weight(
+        act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
+        onednn_weight, weight_scales, weight_zero_points,
+        bias, output_scale, output_zero_point, output_dtype,
+        post_op_name, post_op_args, post_op_algorithm
+    );
+#endif
+    TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
+  }
@@ -1185,6 +1212,8 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
 TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
       TORCH_FN(QLinearOnednn::run_pointwise));
+  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
+      TORCH_FN(QLinearOnednn::run_pointwise_tensor));
 }

 } // namespace
```
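The new entry point is a thin adapter: it checks that the dynamic scale/zero-point are 1-element tensors and funnels into the existing scalar implementation. A rough Python rendering of that contract; both function names below are hypothetical stand-ins, not real ops:

```python
import torch


def qlinear_pointwise_scalar(act, act_scale: float, act_zp: int):
    """Hypothetical stand-in for the existing scalar-qparam kernel."""
    return act  # placeholder body


def qlinear_pointwise_tensor_adapter(act, act_scale: torch.Tensor, act_zp: torch.Tensor):
    # Mirrors the TORCH_CHECK above: dynamic qparams arrive as 1-element tensors.
    assert act_scale.numel() == 1 and act_zp.numel() == 1, \
        "onednn int8 linear: act scale/zp size should be 1"
    # Extract scalars, as act_scale.item().toDouble() does in the C++ overload.
    return qlinear_pointwise_scalar(act, float(act_scale.item()), int(act_zp.item()))
```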
```diff
@@ -266,4 +266,5 @@ TORCH_LIBRARY(onednn, m) {

   // Linear with unary postop
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
 }
```
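The only schema change is `x_scale`/`x_zero_point` moving from `float`/`int` to `Tensor`, because dynamic quantization computes them at runtime. A small sketch of where those tensors come from; the int8 quant_min/quant_max/eps values are the usual defaults, assumed here:

```python
import torch

x = torch.randn(2, 4)
# choose_qparams returns *tensors*, which is why the .tensor overload exists.
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
    x, -128, 127, torch.finfo(torch.float32).eps, torch.int8
)
print(scale.shape, zero_point.shape)  # 1-element tensors, e.g. torch.Size([1])
```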
```diff
@@ -257,6 +257,18 @@ if RUN_CPU:
             test_mkldnn_pattern_matcher.TestPatternMatcher(),
             condition=torch.backends.mkldnn.is_available(),
         ),
+        BaseTest(
+            "test_dynamic_qlinear",
+            "cpu",
+            test_mkldnn_pattern_matcher.TestPatternMatcher(),
+            condition=torch.backends.mkldnn.is_available(),
+        ),
+        BaseTest(
+            "test_dynamic_qlinear_qat",
+            "cpu",
+            test_mkldnn_pattern_matcher.TestPatternMatcher(),
+            condition=torch.backends.mkldnn.is_available(),
+        ),
         BaseTest("test_randint"),
         BaseTest("test_randn_with_dtype_and_device"),
         BaseTest("test_reduction1"),  # Reduction
```
```diff
@@ -89,7 +89,9 @@ class TestPatternMatcherBase(TestCase):

         return tuple(clone(x) for x in inputs)

-    def _generate_qdq_quantized_model(self, mod, inputs, is_qat=False):
+    def _generate_qdq_quantized_model(
+        self, mod, inputs, is_qat=False, is_dynamic=False
+    ):
         maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
         with maybe_no_grad:
             export_model = capture_pre_autograd_graph(
@@ -98,7 +100,9 @@ class TestPatternMatcherBase(TestCase):
             )
             quantizer = X86InductorQuantizer()
             quantizer.set_global(
-                xiq.get_default_x86_inductor_quantization_config(is_qat=is_qat)
+                xiq.get_default_x86_inductor_quantization_config(
+                    is_qat=is_qat, is_dynamic=is_dynamic
+                )
             )
             prepare_model = (
                 prepare_qat_pt2e(export_model, quantizer)
@@ -123,6 +127,7 @@ class TestPatternMatcherBase(TestCase):
         is_qat=False,
         matcher_check_fn=None,
         dtype=None,
+        is_dynamic=False,
     ):
         counters.clear()
         torch._dynamo.reset()
@@ -146,7 +151,9 @@ class TestPatternMatcherBase(TestCase):
             maybe_autocast = contextlib.nullcontext()

         if check_quantization:
-            convert_model = self._generate_qdq_quantized_model(mod, inputs, is_qat)
+            convert_model = self._generate_qdq_quantized_model(
+                mod, inputs, is_qat, is_dynamic
+            )
             with torch.no_grad(), maybe_autocast:
                 _ = torch.compile(convert_model)(*inputs)
                 if matcher_count is not None:
@@ -1195,6 +1202,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
         do_permute=False,
         matcher_check_fn=None,
         bias=True,
+        is_dynamic=False,
+        is_qat=False,
     ):
         class M(torch.nn.Module):
             def __init__(self, use_bias, do_permute=False):
@@ -1223,6 +1232,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
             matcher_check_fn=matcher_check_fn
             if matcher_check_fn is not None
             else _default_matcher_check_fn,
+            is_qat=is_qat,
+            is_dynamic=is_dynamic,
         )

     @skipIfNoDynamoSupport
@@ -1235,6 +1246,30 @@ class TestPatternMatcher(TestPatternMatcherBase):
         for bias in [True, False]:
             self._qlinear_cpu_test_helper((torch.randn((2, 4)),), bias=bias)

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_dynamic_qlinear_cpu(self):
+        r"""
+        This testcase will quantize a single Linear module.
+        """
+        for bias in [True, False]:
+            self._qlinear_cpu_test_helper(
+                (torch.randn((2, 4)),), bias=bias, is_dynamic=True
+            )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_dynamic_qlinear_qat_cpu(self):
+        r"""
+        This testcase will quantize a single Linear module.
+        """
+        for bias in [True, False]:
+            self._qlinear_cpu_test_helper(
+                (torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
+            )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
```
```diff
@@ -521,7 +521,9 @@ def dequantize_per_tensor_tensor_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    return (input.to(torch.float32) - zero_point) * scale
+    return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
+        torch.float32
+    )


 @register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
```
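The added casts normalize the tensor-valued `zero_point` and `scale` (which may arrive as int64/float64 tensors) before the affine dequantize. A quick numeric check of the identity the decomposition implements, with arbitrary values:

```python
import torch

q = torch.tensor([-3, 0, 5], dtype=torch.int8)
scale = torch.tensor([0.1], dtype=torch.float64)   # tensor qparams, as from choose_qparams
zero_point = torch.tensor([2], dtype=torch.int64)

# dequantize: (q - zp) * scale, computed in float32
deq = (q.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(torch.float32)
print(deq)  # tensor([-0.5000, -0.2000,  0.3000])
```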
```diff
@@ -646,3 +648,15 @@ def masked_scatter(self, mask, source):
         source_idx = mask.reshape(-1).cumsum(0) - 1
         return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
     return NotImplemented
+
+
+@register_decomposition(quantized_decomposed.choose_qparams.tensor)
+def choose_qparams_tensor(
+    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
+):
+    min_val, max_val = torch.aminmax(input)
+    scale = (max_val - min_val) / float(quant_max - quant_min)
+    scale = torch.max(scale, torch.Tensor([eps]))
+    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
+    zero_point = torch.clamp(zero_point, quant_min, quant_max)
+    return scale.to(torch.float64), zero_point.to(torch.int64)
```
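A worked example of this asymmetric-qparams math, mirroring the decomposition step by step (input values chosen for round numbers):

```python
import torch

x = torch.tensor([-1.0, 0.0, 3.0])
quant_min, quant_max, eps = -128, 127, torch.finfo(torch.float32).eps

min_val, max_val = torch.aminmax(x)                         # -1.0, 3.0
scale = (max_val - min_val) / float(quant_max - quant_min)  # 4.0 / 255 ≈ 0.0157
scale = torch.max(scale, torch.tensor([eps]))
zero_point = quant_min - torch.round(min_val / scale).to(torch.int)  # -128 + 64
zero_point = torch.clamp(zero_point, quant_min, quant_max)
print(scale.item(), zero_point.item())  # ≈ 0.01569, -64
```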
```diff
@@ -150,22 +150,29 @@ def get_dequantize_qconv_pt2e_pattern(users=1):
     )


-qlinear_pt2e_pattern = CallFunction(
-    torch.ops.onednn.qlinear_pointwise.default,
-    KeywordArg("x"),
-    KeywordArg("x_scale"),
-    KeywordArg("x_zp"),
-    KeywordArg("packed_weight"),
-    KeywordArg("w_scale"),
-    KeywordArg("w_zp"),
-    KeywordArg("b"),
-    KeywordArg("output_scale"),
-    KeywordArg("output_zero_point"),
-    KeywordArg("output_dtype"),
-    KeywordArg("postop_name"),
-    KeywordArg("postop_args"),
-    KeywordArg("postop_algorithm"),
-)
+def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors):
+    qlinear_op = (
+        torch.ops.onednn.qlinear_pointwise.tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qlinear_pointwise.default
+    )
+    return CallFunction(
+        qlinear_op,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        KeywordArg("x_zp"),
+        KeywordArg("packed_weight"),
+        KeywordArg("w_scale"),
+        KeywordArg("w_zp"),
+        KeywordArg("b"),
+        KeywordArg("output_scale"),
+        KeywordArg("output_zero_point"),
+        KeywordArg("output_dtype"),
+        KeywordArg("postop_name"),
+        KeywordArg("postop_args"),
+        KeywordArg("postop_algorithm"),
+    )


 dequantize_accum_pattern = CallFunction(
     aten.mul.Tensor,
```
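The module-level pattern constant becomes a factory so the matcher can target either overload; the registration loop in the next hunk calls it once per variant. For illustration (the import path is an assumption based on where this code appears to live):

```python
# Assumed import path for the factory added in this diff (illustrative).
from torch._inductor.fx_passes.quantization import get_qlinear_pt2e_pattern

pattern_scalar = get_qlinear_pt2e_pattern(False)  # matches onednn.qlinear_pointwise.default
pattern_tensor = get_qlinear_pt2e_pattern(True)   # matches onednn.qlinear_pointwise.tensor
```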
```diff
@@ -670,44 +677,46 @@ def _register_quantization_unary_fusion():
     )

     # QLinear
-    # Priority 1 to match: QLinear Unary pattern with int8 output
-    linear_unary_replace_patterns = {
-        UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
-            qlinear_pt2e_pattern,
-            dtype=original_pattern_output_dtype,
-        ),
-        UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
-            generate_pattern_with_unary(qlinear_pt2e_pattern, aten.relu.default),
-            dtype=original_pattern_output_dtype,
-        ),
-    }
+    for x_scale_zp_are_tensors in (False, True):
+        qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+        # Priority 1 to match: QLinear Unary pattern with int8 output
+        linear_unary_replace_patterns = {
+            UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
+                qlinear_pattern,
+                dtype=original_pattern_output_dtype,
+            ),
+            UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
+                generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
+                dtype=original_pattern_output_dtype,
+            ),
+        }

-    for unary_attr, patterns in linear_unary_replace_patterns.items():
-        _register_quantized_linear_lowering(
-            patterns,
-            1,  # pass_number
-            torch.ops.onednn.qlinear_pointwise,  # computation_op
-            None,  # output_dtype
-            unary_attr,  # unary_attr
-            original_pattern_output_dtype=original_pattern_output_dtype,
-        )
+        for unary_attr, patterns in linear_unary_replace_patterns.items():
+            _register_quantized_linear_lowering(
+                patterns,
+                1,  # pass_number
+                torch.ops.onednn.qlinear_pointwise,  # computation_op
+                None,  # output_dtype
+                unary_attr,  # unary_attr
+                original_pattern_output_dtype=original_pattern_output_dtype,
+            )

-    # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
-    linear_unary_replace_float_out_patterns = {
-        UnaryAttr("relu", [], ""): generate_pattern_with_unary(
-            qlinear_pt2e_pattern, aten.relu.default
-        ),
-    }
+        # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
+        linear_unary_replace_float_out_patterns = {
+            UnaryAttr("relu", [], ""): generate_pattern_with_unary(
+                qlinear_pattern, aten.relu.default
+            ),
+        }

-    for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
-        _register_quantized_linear_lowering(
-            patterns,
-            2,  # pass_number
-            torch.ops.onednn.qlinear_pointwise,  # computation_op
-            original_pattern_output_dtype,  # output_dtype
-            unary_attr,  # unary_attr
-            original_pattern_output_dtype=original_pattern_output_dtype,
-        )
+        for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
+            _register_quantized_linear_lowering(
+                patterns,
+                2,  # pass_number
+                torch.ops.onednn.qlinear_pointwise,  # computation_op
+                original_pattern_output_dtype,  # output_dtype
+                unary_attr,  # unary_attr
+                original_pattern_output_dtype=original_pattern_output_dtype,
+            )


 def _register_quantization_binary_fusion():
```
```diff
@@ -1672,9 +1681,15 @@ def _register_qlinear_weight_prepack_pass(
             [],  # post op args
             "",  # post op algorithm
         )
-        new_linear_node = graph.call_function(
-            torch.ops.onednn.qlinear_pointwise.default, args=new_args
-        )
+        Node = torch.fx.node.Node
+        if isinstance(x_scale, Node) and isinstance(x_zp, Node):
+            new_linear_node = graph.call_function(
+                torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
+            )
+        else:
+            new_linear_node = graph.call_function(
+                torch.ops.onednn.qlinear_pointwise.default, args=new_args
+            )
        if input_dim_exceeds_two:
             if input_contiguous:
                 output_reshape_node.replace_all_uses_with(new_linear_node)
```
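The dispatch hinges on a graph-level fact: with dynamic quantization, `x_scale`/`x_zp` are produced by `choose_qparams` at runtime and therefore appear as `torch.fx.Node`s rather than Python scalars. A standalone sketch of that distinction (the traced function is illustrative):

```python
import torch
import torch.fx as fx


def forward(x):
    return x * 2


gm = fx.symbolic_trace(forward)
for n in gm.graph.nodes:
    # A runtime-computed value shows up as an fx.Node in a node's args;
    # a static constant stays a plain Python scalar.
    print(n.name, [isinstance(a, fx.Node) for a in n.args])
```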
```diff
@@ -1034,6 +1034,7 @@ class GraphLowering(torch.fx.Interpreter):
                 torch.ops.onednn.qconv2d_pointwise.default,
                 torch.ops.onednn.qconv2d_pointwise.binary,
                 torch.ops.onednn.qlinear_pointwise.default,
+                torch.ops.onednn.qlinear_pointwise.tensor,
             ]
             if torch._C.has_mkl:
                 need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
```
```diff
@@ -6478,6 +6478,8 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
         layout,
         inputs,
         constant_args=(),
+        has_bias=True,
+        x_scale_zp_are_tensors=False,
     ):
         """
         if bias is not None
@@ -6489,21 +6491,32 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
            - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp,
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
         """
-        self.has_bias = len(inputs) == 5
+        self.has_bias = has_bias
+        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
         super().__init__(
             layout,
             inputs,
             constant_args,
             None,
-            python_kernel_name="torch.ops.onednn.qlinear_pointwise",
+            python_kernel_name=(
+                "torch.ops.onednn.qlinear_pointwise.tensor"
+                if x_scale_zp_are_tensors
+                else "torch.ops.onednn.qlinear_pointwise.default"
+            ),
             cpp_kernel_name="onednn::qlinear_pointwise",
         )
+        self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else ""
         self.cpp_kernel_key = "qlinear_pointwise"
-        self.cpp_op_schema = """
+        x_scale_type_str, x_zp_type_str = (
+            ("at::Tensor", "at::Tensor")
+            if x_scale_zp_are_tensors
+            else ("double", "int64_t")
+        )
+        self.cpp_op_schema = f"""
             at::Tensor(
                 at::Tensor act,
-                double act_scale,
-                int64_t act_zero_point,
+                {x_scale_type_str} act_scale,
+                {x_zp_type_str} act_zero_point,
                 at::Tensor weight,
                 at::Tensor weight_scales,
                 at::Tensor weight_zero_points,
```
```diff
@@ -6525,16 +6538,29 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
         packed_weight = args[1]
         bias = args[2] if self.has_bias else const_args[0]
         w_scale, w_zp = args[-2], args[-1]
-        (
-            x_scale,
-            x_zp,
-            o_inv_scale,
-            o_zp,
-            output_dtype,
-            unary_attr,
-            unary_scalars,
-            unary_algorithm,
-        ) = const_args[-8:]
+        if self.x_scale_zp_are_tensors:
+            assert len(args) >= 4
+            x_scale, x_zp = args[-4], args[-3]
+            (
+                o_inv_scale,
+                o_zp,
+                output_dtype,
+                unary_attr,
+                unary_scalars,
+                unary_algorithm,
+            ) = const_args[-6:]
+        else:
+            assert len(const_args) >= 8
+            (
+                x_scale,
+                x_zp,
+                o_inv_scale,
+                o_zp,
+                output_dtype,
+                unary_attr,
+                unary_scalars,
+                unary_algorithm,
+            ) = const_args[-8:]

         codegen_args = (
             x,
@@ -6557,6 +6583,7 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
             codegen_args,
             self.cpp_op_schema,
             self.cpp_kernel_key,
+            self.cpp_kernel_overload_name,
         )
         if isinstance(self.layout, Layout):
             self.codegen_size_asserts(wrapper)
```
```diff
@@ -6585,12 +6612,19 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
             bias,
         )

+        if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox):
+            x_scale.realize()
+            x_zp.realize()
+            inputs = inputs + [x_scale, x_zp]
+            x_scale_zp_are_tensors = True
+        else:
+            assert isinstance(x_scale, float) and isinstance(x_zp, int)
+            constant_args = constant_args + [x_scale, x_zp]
+            x_scale_zp_are_tensors = False
         w_scale.realize()
         w_zp.realize()
         inputs = inputs + [w_scale, w_zp]
         constant_args = constant_args + [
-            x_scale,
-            x_zp,
             o_inv_scale,
             output_zero_point,
             output_dtype,
@@ -6609,6 +6643,8 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
             layout=kernel_layout,
             inputs=inputs,
             constant_args=constant_args,
+            has_bias=(bias is not None),
+            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
         )
```
```diff
@@ -2249,6 +2249,7 @@ if torch._C._has_mkldnn:
         return out

     @register_meta(torch.ops.onednn.qlinear_pointwise.default)
+    @register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
     def meta_qlinear_pointwise(
         x,
         x_scale,
```