Compare commits

...

3 Commits

SHA1        Message                                    Date
9e3d367b21  Update [ghstack-poisoned]                  2025-11-12 07:53:09 +00:00
66713ca92a  Update [ghstack-poisoned]                  2025-11-12 03:34:29 +00:00
77d3cc45fb  Update (base update) [ghstack-poisoned]    2025-11-12 03:34:29 +00:00
2 changed files with 14 additions and 7 deletions


@@ -1426,6 +1426,9 @@ static at::Tensor _fp8_convolution_onednn_ref(
   w_scales_new_shape[0] = -1;
   auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
   auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
+  if (bias.has_value()){
+    bias = bias.value().to(at::kFloat);
+  }
   auto y_f32 = at::convolution(
       dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
   );
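For context, the reference path above dequantizes the fp8 input and weight to fp32, upcasts the optional (possibly bfloat16) bias to fp32, and then calls the regular convolution. A minimal Python sketch of the same computation, assuming a 2-D case; fp8_conv_ref and its argument layout are illustrative, not part of the patch:

import torch
import torch.nn.functional as F

def fp8_conv_ref(x_q, x_scale, w_q, w_scales, bias=None,
                 stride=1, padding=0, dilation=1, groups=1):
    # Dequantize fp8 activation and weight to fp32; the per-output-channel
    # weight scales are broadcast via a reshape, as in the C++ reference.
    dqx = x_q.to(torch.float32) * x_scale
    scale_shape = [-1] + [1] * (w_q.dim() - 1)
    dqw = w_q.to(torch.float32) * w_scales.reshape(scale_shape)
    if bias is not None:
        bias = bias.to(torch.float32)  # mirrors the new bias upcast above
    return F.conv2d(dqx, dqw, bias, stride, padding, dilation, groups)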


@@ -7849,7 +7849,7 @@ class TestQuantizedConv(TestCase):
    def _make_qconv_tensors_fp8(
        self, batch_size, input_channels_per_group, input_feature_map_shape,
        output_channels_per_group, groups, kernels, strides, pads, dilations,
-        use_bias, use_channelwise, use_transpose,
+        use_bias, use_channelwise, use_transpose, bfloat16_output,
        device=torch.device("cpu"),
    ):
        assert not (use_channelwise and use_transpose), \
@@ -7879,9 +7879,11 @@ class TestQuantizedConv(TestCase):
        X_q, X_scale = _quantize_fp8e4m3(X, channelwise=False)
        W = torch.randn(output_shape + kernels, device=device) * 0.1
        W_q, W_scale = _quantize_fp8e4m3(W, channelwise=use_channelwise)
-        bias_float = torch.randn((output_channels,), device=device) if use_bias else None
-        return X, W, X_q, W_q, X_scale, W_scale, bias_float
+        bias = torch.randn((output_channels,), device=device) if use_bias else None
+        if use_bias and bfloat16_output:
+            bias = bias.bfloat16()
+        return X, W, X_q, W_q, X_scale, W_scale, bias

    def _test_qconv_impl_cpu_tensor_fp8(
        self,
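The _quantize_fp8e4m3 / _dequantize_fp8e4m3 helpers used here are defined elsewhere in the test file and are not part of this diff. A plausible sketch of what such helpers compute; the _sketch names and the amax-based scaling policy are assumptions, not the file's actual implementation:

import torch

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn

def _quantize_fp8e4m3_sketch(t, channelwise=False):
    # Pick a scale so the tensor fits the e4m3 range, then cast to fp8.
    if channelwise:
        reduce_dims = tuple(range(1, t.dim()))
        amax = t.abs().amax(dim=reduce_dims, keepdim=True)
    else:
        amax = t.abs().amax()
    scale = (amax / FP8_E4M3_MAX).clamp(min=1e-12)
    t_q = (t / scale).to(torch.float8_e4m3fn)
    return t_q, scale

def _dequantize_fp8e4m3_sketch(t_q, scale):
    return t_q.to(torch.float32) * scale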
@@ -7913,7 +7915,7 @@ class TestQuantizedConv(TestCase):
        batch_size = 3
        device = torch.device("cpu")
        use_transpose = False
-        X, W, X_q, W_q, X_scale, W_scale, bias_float = self._make_qconv_tensors_fp8(
+        X, W, X_q, W_q, X_scale, W_scale, bias = self._make_qconv_tensors_fp8(
            batch_size,
            input_channels_per_group,
            input_feature_map_shape,
@@ -7926,11 +7928,13 @@ class TestQuantizedConv(TestCase):
            use_bias,
            use_channelwise,
            use_transpose,
+            bfloat16_output,
            device=device,
        )
        # Assign weights
        dqW = _dequantize_fp8e4m3(W_q, W_scale)
        dqX = _dequantize_fp8e4m3(X_q, X_scale)
+        bias_float = bias.float() if use_bias else bias
        conv_op.weight = torch.nn.Parameter(dqW, requires_grad=False)
        conv_op.bias = (
            torch.nn.Parameter(bias_float, requires_grad=False) if use_bias else None
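Note the dtype contract this establishes: the fp32 reference module receives the upcast bias_float, while the quantized op under test receives bias in its original dtype (bfloat16 when bfloat16_output is set). The upcast itself is lossless, as this small illustration shows; the variable names here are illustrative:

import torch

bias = torch.randn(8).bfloat16()  # dtype passed to the quantized op
bias_float = bias.float()         # dtype assigned to the reference conv_op
# bf16 -> fp32 is exact: every bf16 value is representable in fp32.
assert torch.equal(bias_float.bfloat16(), bias)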
@@ -8011,7 +8015,7 @@ class TestQuantizedConv(TestCase):
            W_scale,
            torch.zeros([], dtype=torch.int8),  # W_zero_point
            accum,
-            bias_float,
+            bias,
            strides,
            pads,
            dilations,
@@ -8035,7 +8039,7 @@ class TestQuantizedConv(TestCase):
            packed_weight,
            W_scale,
            torch.zeros([], dtype=torch.int8),  # W_zero_point
-            bias_float,
+            bias,
            strides,
            pads,
            dilations,
@@ -8050,7 +8054,7 @@ class TestQuantizedConv(TestCase):
        if fp32_output or bfloat16_output:
            self.assertTrue(result.dtype == qconv_output_dtype)
-            self.assertEqual(result.float(), result_ref.float(), atol=1e-6, rtol=1e-5)
+            self.assertEqual(result.float(), result_ref.float(), atol=1e-2, rtol=1e-2)
        assert not torch.isnan(result).any()

    def _test_qconv_fp8_helper(self, nd, pointwise_post_op):
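The relaxed tolerances (atol 1e-6 -> 1e-2, rtol 1e-5 -> 1e-2) account for the bfloat16 path: bf16 keeps only 8 significand bits, so rounding a value to bf16 introduces a relative error of up to about 2 ** -8, which the old tolerances could not absorb once the bias is created in bf16. A quick, illustrative check of the magnitude:

import torch

x = torch.randn(10000).abs() + 0.1  # avoid division by tiny values
rel_err = ((x.bfloat16().float() - x) / x).abs().max()
print(f"max bf16 rounding error: {rel_err:.1e}")  # on the order of 4e-3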