bf16 support for fake_quantize_learnable_per_channel_affine (#165098)

Adds bf16 support for the `torch._fake_quantize_learnable_per_channel_affine()` op by relaxing the type check on `scale` to also accept BFloat16.
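
For context, a minimal usage sketch of the call this relaxation enables; the shapes and quant range here are illustrative (not taken from this PR), and a CUDA device is assumed to mirror the modified test:

```python
import torch

x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
scale = torch.full((4,), 0.1, device="cuda", dtype=torch.bfloat16)  # bf16 scale previously failed the TORCH_CHECK
zero_point = torch.zeros(4, device="cuda", dtype=torch.float32)     # zero_point may be Int32, Float, or Half
out = torch._fake_quantize_learnable_per_channel_affine(
    x, scale, zero_point, 1, 0, 255)  # axis=1, quant_min=0, quant_max=255
```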

TODO: bf16 support still needs to be added to the `per_tensor_affine_` variant, since `torch._fake_quantize_learnable_per_tensor_affine_backward` is called in the backward pass.
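
A hedged sketch of the still-unsupported path this TODO refers to; the exact failure mode is assumed from the remaining Float-only check on the per-tensor op, not verified here:

```python
import torch

# Per-tensor learnable fake-quant with bf16 inputs; the backward pass routes
# through torch._fake_quantize_learnable_per_tensor_affine_backward, which
# does not yet accept BFloat16 (per the TODO above).
x = torch.randn(8, dtype=torch.bfloat16, requires_grad=True)
scale = torch.tensor([0.1], dtype=torch.bfloat16, requires_grad=True)
zero_point = torch.tensor([0.0], requires_grad=True)
out = torch._fake_quantize_learnable_per_tensor_affine(x, scale, zero_point, 0, 255)
out.sum().backward()  # expected to hit the Float-only check until bf16 support lands
```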

**Test**
Modified the unit test in `test_workflow_ops.py` to exercise both `torch.float32` and `torch.bfloat16`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165098
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14
Author: Angel Li
Date: 2025-10-10 16:24:48 +00:00
Committed by: PyTorch MergeBot
Parent: abb2f7179e
Commit: 253fd765bd

2 changed files with 14 additions and 15 deletions


@@ -48,8 +48,8 @@ std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max) {
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
-              "Scale must be Float, found ", scale.scalar_type());
+  TORCH_CHECK(scale.scalar_type() == ScalarType::Float || scale.scalar_type() == at::kBFloat16,
+              "Scale must be Float or BFloat16, found ", scale.scalar_type());
   TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float || zero_point.scalar_type() == ScalarType::Half,
               "Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
   TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
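
In effect, a bf16 scale no longer trips this check. A small sketch of the before/after behavior, with the error text taken from the removed `TORCH_CHECK` message:

```python
import torch

scale_bf16 = torch.full((4,), 0.1, dtype=torch.bfloat16)
# before this change: RuntimeError: Scale must be Float, found BFloat16
# after this change:  the bf16 scale passes the check and the op proceeds
```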


@@ -823,19 +823,18 @@ class TestFakeQuantizeOps(TestCase):
         self._test_learnable_forward_per_channel(
             X_base, 'cpu', scale_base, zero_point_base, axis)

-    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
-           qparams=hu.qparams(dtypes=torch.quint8)))
     @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
-    @unittest.skip(
-        "this is broken without changes to any relevant code, "
-        "we need to remove hypothesis testing in CI")
-    def test_learnable_forward_per_channel_cuda(self, X):
+    def test_learnable_forward_per_channel_cuda(self):
         torch.random.manual_seed(NP_RANDOM_SEED)
-        X, (_, _, axis, _) = X
-        X_base = torch.tensor(X).to('cuda')
-        channel_size = X_base.size(axis)
-        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
-        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
-        self._test_learnable_forward_per_channel(
-            X_base, 'cuda', scale_base, zero_point_base, axis)
+        shape = (2, 1, 2, 10)
+        axis = 1
+        for dtype in [torch.float32, torch.bfloat16]:
+            X_base = torch.randn(shape, device="cuda").to(dtype)
+            channel_size = X_base.size(axis)
+            scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100).to(dtype)
+            zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,)).to(dtype)
+            self._test_learnable_forward_per_channel(
+                X_base, 'cuda', scale_base, zero_point_base, axis)
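
For reference, the forward computation the helper presumably validates follows the standard per-channel fake-quantization formula. A hedged sketch; `_test_learnable_forward_per_channel` is not shown in this diff, so the exact comparison it performs is assumed:

```python
import torch

def ref_fake_quant_per_channel(x, scale, zero_point, axis, quant_min, quant_max):
    # Broadcast scale/zero_point along `axis`, compute in float for stability,
    # then cast back to the input dtype (float32 or bfloat16 here).
    shape = [1] * x.dim()
    shape[axis] = x.size(axis)
    s = scale.float().reshape(shape)
    zp = zero_point.float().reshape(shape)
    xq = torch.clamp(torch.round(x.float() / s + zp), quant_min, quant_max)
    return ((xq - zp) * s).to(x.dtype)
```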