[Quant][X86] add ops to compute uint8 pointwise add/add_relu (#152411)

**Summary**
This PR adds two new ops, `onednn.qadd.tensor` and `onednn.qadd_relu.tensor`, for int8 elementwise add, which accepts inputs on CPU device (instead of QuantizedCPU).
The new ops are implemented with AVX512 instructions and it provides similar or better performance, depending on shape, than its counterpart for QuantizedCPU device `quantized.add` and `quantized.add_relu`.
The new op supports output dtypes other than uint8 (fp32, fp16 and bf16 are supported).

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k test_int8_add_onednn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152411
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
This commit is contained in:
Xia, Weiwen
2025-05-14 19:13:11 -07:00
committed by PyTorch MergeBot
parent a762dd1f67
commit 55784be01b
5 changed files with 244 additions and 60 deletions

View File

@ -535,6 +535,48 @@ Tensor qadd_scalar_tensor_out(Tensor qa, Tensor b, Tensor out) {
return qadd_scalar_out(std::move(qa), b.item(), std::move(out)); return qadd_scalar_out(std::move(qa), b.item(), std::move(out));
} }
DEFINE_DISPATCH(qadd_tensor_cpu_stub);
DEFINE_DISPATCH(qadd_relu_tensor_cpu_stub);
template <bool ReLUFused = false>
Tensor int8_add_tensor_onednn(
const Tensor& self, double self_scale, int64_t self_zero_point,
const Tensor& other, double other_scale, int64_t other_zero_point,
double output_scale, int64_t output_zero_point, c10::ScalarType output_dtype) {
// Both inputs should have the same shape and both in uint8 dtype.
// If output_dtype is uint8, output is requantized with output scale/zero point.
// Otherwise, output scale should be 1 and zero point 0.
TORCH_CHECK(self.sizes() == other.sizes(),
"Quantized add operands should have the same size.");
TORCH_CHECK(self.scalar_type() == at::kByte && other.scalar_type() == at::kByte,
"Quantized add operands should be of type uint8, but got ",
self.scalar_type(), " and ", other.scalar_type());
TORCH_CHECK(output_dtype == at::kByte || output_dtype == at::kFloat || output_dtype == at::kBFloat16 || output_dtype == at::kHalf,
"Quantized add output should be of type uint8, float, bfloat16 or float16, but got ",
output_dtype);
if (output_dtype != at::kByte) {
TORCH_CHECK(output_scale == 1.0 && output_zero_point == 0,
"Quantized add output scale and zero point should be 1 and 0 for "
"output_dtype ", output_dtype, ", but got scale = ",
output_scale, " and zero point = ", output_zero_point);
}
at::Tensor out = at::empty_like(self, self.options().dtype(output_dtype));
if constexpr (ReLUFused) {
qadd_relu_tensor_cpu_stub(
self.device().type(), out, self, self_scale, self_zero_point,
other, other_scale, other_zero_point,
output_scale, output_zero_point);
} else {
qadd_tensor_cpu_stub(
self.device().type(), out, self, self_scale, self_zero_point,
other, other_scale, other_zero_point,
output_scale, output_zero_point);
}
return out;
}
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd</*ReLUFused=*/false>)); m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out</*ReLUFused=*/false>)); m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out</*ReLUFused=*/false>));
@ -563,6 +605,11 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("_quantized::add"), TORCH_FN(qadd</*ReLUFused=*/false>)); m.impl(TORCH_SELECTIVE_NAME("_quantized::add"), TORCH_FN(qadd</*ReLUFused=*/false>));
} }
TORCH_LIBRARY_IMPL(onednn, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::qadd.tensor"), TORCH_FN(int8_add_tensor_onednn<false>));
m.impl(TORCH_SELECTIVE_NAME("onednn::qadd_relu.tensor"), TORCH_FN(int8_add_tensor_onednn<true>));
}
} // namespace } // namespace
Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){ Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){

View File

@ -216,7 +216,7 @@ using qnormalize_nhwc_fn = void (*)(
using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
const Tensor& /*qw*/); const Tensor& /*qw*/);
using qmul_tensor_cpu_fn = void (*)( using qbinary_eltwise_cpu_fn = void (*)(
Tensor& /*out*/, Tensor& /*out*/,
const Tensor& /*qx*/, const Tensor& /*qx*/,
double /*qx_scale*/, double /*qx_scale*/,
@ -263,6 +263,8 @@ DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub)
DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub) DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub)
DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub) DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub)
DECLARE_DISPATCH(qprelu_fn, qprelu_stub) DECLARE_DISPATCH(qprelu_fn, qprelu_stub)
DECLARE_DISPATCH(qmul_tensor_cpu_fn, qmul_tensor_cpu_stub) DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qmul_tensor_cpu_stub)
DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_tensor_cpu_stub)
DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_relu_tensor_cpu_stub)
} // namespace at::native } // namespace at::native

View File

@ -4260,6 +4260,26 @@ void _qmul_tensor_cpu_impl(
double output_scale, double output_scale,
int64_t output_zero_point) { int64_t output_zero_point) {
float multiplier = x_scale * y_scale / output_scale; float multiplier = x_scale * y_scale / output_scale;
auto compute_with_scalar = [&](int idx) {
uint8_t x_data = *(x_ptr + idx);
uint8_t y_data = *(y_ptr + idx);
int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
int32_t out_val = static_cast<int32_t>(x_val * y_val);
float out_val_f = (float)out_val * multiplier;
if constexpr (std::is_same<T, float>::value) {
*(out_ptr + idx) = out_val_f;
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
*(out_ptr + idx) = at::BFloat16(out_val_f);
} else if constexpr (std::is_same<T, at::Half>::value) {
*(out_ptr + idx) = at::Half(out_val_f);
} else { // T == uint8, requantization needed
out_val_f = std::round(out_val_f);
int32_t out_val_i32 = (int32_t)out_val_f + output_zero_point;
out_val_i32 = std::min(255, std::max(0, out_val_i32));
*(out_ptr + idx) = static_cast<uint8_t>(out_val_i32);
}
};
#if defined(CPU_CAPABILITY_AVX512) #if defined(CPU_CAPABILITY_AVX512)
int64_t size_rem = size % 16; int64_t size_rem = size % 16;
int64_t size_com = size - size_rem; int64_t size_com = size - size_rem;
@ -4304,47 +4324,13 @@ void _qmul_tensor_cpu_impl(
}); });
if (size_rem > 0) { if (size_rem > 0) {
for (const auto d : c10::irange(size_rem)) { for (const auto d : c10::irange(size_rem)) {
uint8_t x_data = *(x_ptr + size_com + d); compute_with_scalar(size_com + d);
uint8_t y_data = *(y_ptr + size_com + d);
int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
int32_t out_val = static_cast<int32_t>(x_val * y_val);
float out_val_f = (float)out_val * multiplier;
if constexpr (std::is_same<T, float>::value) {
*(out_ptr + size_com + d) = out_val_f;
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
*(out_ptr + size_com + d) = at::BFloat16(out_val_f);
} else if constexpr (std::is_same<T, at::Half>::value) {
*(out_ptr + size_com + d) = at::Half(out_val_f);
} else { // T == uint8, requantization needed
out_val_f = std::round(out_val_f);
int32_t out_val_i32 = (int32_t)out_val_f + output_zero_point;
out_val_i32 = std::min(255, std::max(0, out_val_i32));
*(out_ptr + size_com + d) = static_cast<uint8_t>(out_val_i32);
}
} }
} }
#else #else
at::parallel_for(0, size, 1, [&](int64_t start, int64_t end) { at::parallel_for(0, size, 1, [&](int64_t start, int64_t end) {
for (const auto d : c10::irange(start, end)) { for (const auto d : c10::irange(start, end)) {
uint8_t x_data = *(x_ptr + d); compute_with_scalar(d);
uint8_t y_data = *(y_ptr + d);
int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
int32_t out_val = static_cast<int32_t>(x_val * y_val);
float out_val_f = (float)out_val * multiplier;
if constexpr (std::is_same<T, float>::value) {
*(out_ptr + d) = out_val_f;
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
*(out_ptr + d) = at::BFloat16(out_val_f);
} else if constexpr (std::is_same<T, at::Half>::value) {
*(out_ptr + d) = at::Half(out_val_f);
} else { // T == uint8, requantization needed
out_val_f = std::round(out_val_f);
int32_t out_val_i32 = (int32_t)out_val_f + output_zero_point;
out_val_i32 = std::min(255, std::max(0, out_val_i32));
*(out_ptr + d) = static_cast<uint8_t>(out_val_i32);
}
} }
}); });
#endif #endif
@ -4366,29 +4352,139 @@ void qmul_tensor_cpu_kernel(
TORCH_CHECK( TORCH_CHECK(
size == qy.numel() && size == out.numel(), size == qy.numel() && size == out.numel(),
"qmul_cpu: Expect qx, qy and out to have the same number of elements"); "qmul_cpu: Expect qx, qy and out to have the same number of elements");
if (out.scalar_type() == c10::ScalarType::Float) { AT_DISPATCH_FLOATING_TYPES_AND3(
auto out_ptr = out.data_ptr<float>(); at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Byte, out.scalar_type(), "int8_mul_cpu", [&] {
_qmul_tensor_cpu_impl<float>( auto out_ptr = out.data_ptr<scalar_t>();
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point _qmul_tensor_cpu_impl<scalar_t>(
); out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point);
} else if (out.scalar_type() == c10::ScalarType::BFloat16) { });
auto out_ptr = out.data_ptr<at::BFloat16>(); }
_qmul_tensor_cpu_impl<at::BFloat16>(
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point template<typename T, bool ReLUFused>
); void _qadd_tensor_cpu_impl(
} else if (out.scalar_type() == c10::ScalarType::Half) { T* out_ptr,
auto out_ptr = out.data_ptr<at::Half>(); int64_t size,
_qmul_tensor_cpu_impl<at::Half>( const uint8_t* x_ptr,
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point double x_scale,
); int64_t x_zero_point,
} else { const uint8_t* y_ptr,
TORCH_CHECK(out.scalar_type() == c10::ScalarType::Byte, double y_scale,
"qmul_cpu: Unsupported output dtype: ", out.scalar_type()); int64_t y_zero_point,
auto out_ptr = out.data_ptr<uint8_t>(); double output_scale,
_qmul_tensor_cpu_impl<uint8_t>( int64_t output_zero_point) {
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point float inv_output_scale = 1.0 / output_scale;
); auto compute_with_scalar = [&](int idx) {
uint8_t x_data = *(x_ptr + idx);
uint8_t y_data = *(y_ptr + idx);
int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
float x_val_f = static_cast<float>(x_val) * x_scale;
float y_val_f = static_cast<float>(y_val) * y_scale;
float out_val_f = x_val_f + y_val_f;
if constexpr (ReLUFused) {
out_val_f = std::max(out_val_f, 0.f);
}
if constexpr (std::is_same<T, float>::value) {
*(out_ptr + idx) = out_val_f;
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
*(out_ptr + idx) = at::BFloat16(out_val_f);
} else if constexpr (std::is_same<T, at::Half>::value) {
*(out_ptr + idx) = at::Half(out_val_f);
} else { // T == uint8, requantization needed
out_val_f = std::round(out_val_f * inv_output_scale);
int32_t out_val_i32 = (int32_t)out_val_f + output_zero_point;
out_val_i32 = std::min(255, std::max(0, out_val_i32));
*(out_ptr + idx) = static_cast<uint8_t>(out_val_i32);
}
};
#if defined(CPU_CAPABILITY_AVX512)
int64_t size_rem = size % 16;
int64_t size_com = size - size_rem;
int64_t steps = size_com / 16;
__m512 vsa = _mm512_set1_ps(x_scale);
__m512 vsb = _mm512_set1_ps(y_scale);
__m512 vsc = _mm512_set1_ps(inv_output_scale);
__m512i vza = _mm512_set1_epi32(x_zero_point);
__m512i vzb = _mm512_set1_epi32(y_zero_point);
__m512i vzc = _mm512_set1_epi32(output_zero_point);
__m512i v255 = _mm512_set1_epi32(255);
__m512i v0 = _mm512_set1_epi32(0);
__m512 v0f = _mm512_set1_ps(0);
at::parallel_for(0, steps, 1, [&](int64_t start, int64_t end) {
for (const auto d : c10::irange(start, end)) {
auto x_data = x_ptr + d * 16;
auto y_data = y_ptr + d * 16;
auto out_data = out_ptr + d * 16;
__m128i va = _mm_loadu_si128((__m128i*)x_data);
__m128i vb = _mm_loadu_si128((__m128i*)y_data);
__m512i va_i32 = _mm512_cvtepi8_epi32(va);
__m512i vb_i32 = _mm512_cvtepi8_epi32(vb);
va_i32 = _mm512_sub_epi32(va_i32, vza);
vb_i32 = _mm512_sub_epi32(vb_i32, vzb);
__m512 va_f = _mm512_cvtepi32_ps(va_i32);
__m512 vb_f = _mm512_cvtepi32_ps(vb_i32);
va_f = _mm512_mul_ps(va_f, vsa);
vb_f = _mm512_mul_ps(vb_f, vsb);
__m512 vc_f = _mm512_add_ps(va_f, vb_f);
if constexpr (ReLUFused) {
vc_f = _mm512_max_ps(vc_f, v0f);
}
if constexpr (std::is_same<T, float>::value) {
_mm512_storeu_ps(out_data, vc_f);
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
__m256i vc_bf16 = cvtfp32_bf16(vc_f);
_mm256_storeu_si256((__m256i*)out_data, vc_bf16);
} else if constexpr (std::is_same<T, at::Half>::value) {
__m256i vc_f16 = cvtfp32_fp16(vc_f);
_mm256_storeu_si256((__m256i*)out_data, vc_f16);
} else { // T == uint8, requantization needed
vc_f = _mm512_mul_ps(vc_f, vsc);
__m512i vc_i32 = _mm512_cvtps_epi32(vc_f);
vc_i32 = _mm512_add_epi32(vc_i32, vzc);
vc_i32 = _mm512_min_epi32(vc_i32, v255);
vc_i32 = _mm512_max_epi32(vc_i32, v0);
__m128i vc_i8 = _mm512_cvtepi32_epi8(vc_i32);
_mm_storeu_si128((__m128i*)out_data, vc_i8);
}
}
});
if (size_rem > 0) {
for (const auto d : c10::irange(size_rem)) {
compute_with_scalar(size_com + d);
}
} }
#else
at::parallel_for(0, size, 1, [&](int64_t start, int64_t end) {
for (const auto d : c10::irange(start, end)) {
compute_with_scalar(d);
}
});
#endif
}
template <bool ReLUFused>
void qadd_tensor_cpu_kernel(
Tensor& out,
const Tensor& qx,
double qx_scale,
int64_t qx_zero_point,
const Tensor& qy,
double qy_scale,
int64_t qy_zero_point,
double output_scale,
int64_t output_zero_point) {
auto qx_ptr = qx.const_data_ptr<uint8_t>();
auto qy_ptr = qy.const_data_ptr<uint8_t>();
int64_t size = qx.numel();
TORCH_CHECK(
size == qy.numel() && size == out.numel(),
"qadd_cpu: Expect qx, qy and out to have the same number of elements");
AT_DISPATCH_FLOATING_TYPES_AND3(
at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Byte, out.scalar_type(), "int8_add_cpu", [&] {
auto out_ptr = out.data_ptr<scalar_t>();
_qadd_tensor_cpu_impl<scalar_t, ReLUFused>(
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point);
});
} }
} // anonymous namespace } // anonymous namespace
@ -4489,5 +4585,7 @@ REGISTER_DISPATCH(
REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel) REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel)
REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel) REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel)
ALSO_REGISTER_AVX512_DISPATCH(qmul_tensor_cpu_stub, &qmul_tensor_cpu_kernel) ALSO_REGISTER_AVX512_DISPATCH(qmul_tensor_cpu_stub, &qmul_tensor_cpu_kernel)
ALSO_REGISTER_AVX512_DISPATCH(qadd_tensor_cpu_stub, &qadd_tensor_cpu_kernel<false>)
ALSO_REGISTER_AVX512_DISPATCH(qadd_relu_tensor_cpu_stub, &qadd_tensor_cpu_kernel<true>)
} // namespace at::native } // namespace at::native
// NOLINTEND(*-c-arrays) // NOLINTEND(*-c-arrays)

View File

@ -280,4 +280,7 @@ TORCH_LIBRARY(onednn, m) {
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
// int8 mul // int8 mul
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qmul.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qmul.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
// int8 add
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qadd.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qadd_relu.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
} }

View File

@ -3167,6 +3167,40 @@ class TestQuantizedOps(TestCase):
c = torch.ops.onednn.qmul.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype) c = torch.ops.onednn.qmul.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
self.assertEqual(c, c_ref) self.assertEqual(c, c_ref)
@skipIfNoONEDNN
@given(relu_fused=st.booleans())
def test_int8_add_onednn(self, relu_fused):
output_dtype_list = [torch.uint8, torch.float, torch.bfloat16, torch.half]
shape_list = [(16, 64), (15, 63)]
cases = itertools.product(shape_list, output_dtype_list)
for shape, output_dtype in cases:
a = torch.randn(shape)
b = torch.randn(shape)
s_a, z_a = 0.1, 1
s_b, z_b = 0.2, 2
if output_dtype == torch.uint8:
s_c, z_c = 0.3, 3
else:
s_c, z_c = 1, 0
qa = torch.quantize_per_tensor(a, s_a, z_a, torch.quint8)
qb = torch.quantize_per_tensor(b, s_b, z_b, torch.quint8)
dqa = qa.dequantize()
dqb = qb.dequantize()
c_ref = dqa + dqb
if relu_fused:
c_ref = torch.nn.functional.relu(c_ref)
if output_dtype == torch.uint8:
c_ref = torch.ops.quantized_decomposed.quantize_per_tensor.default(c_ref, s_c, z_c, 0, 255, torch.uint8)
c_ref = c_ref.to(output_dtype)
a_int8 = qa.int_repr()
b_int8 = qb.int_repr()
if relu_fused:
c = torch.ops.onednn.qadd_relu.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
else:
c = torch.ops.onednn.qadd.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
self.assertEqual(c, c_ref)
class TestDynamicQuantizedOps(TestCase): class TestDynamicQuantizedOps(TestCase):
"""Tests the correctness of the dynamic quantized linear and linear_relu op.""" """Tests the correctness of the dynamic quantized linear and linear_relu op."""