diff --git a/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp b/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp
index 811830dd1a98..86601e346731 100644
--- a/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp
+++ b/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp
@@ -48,8 +48,8 @@ std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max) {
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
-              "Scale must be Float, found ", scale.scalar_type());
+  TORCH_CHECK(scale.scalar_type() == ScalarType::Float || scale.scalar_type() == at::kBFloat16,
+              "Scale must be Float or BFloat16, found ", scale.scalar_type());
   TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float ||
               zero_point.scalar_type() == ScalarType::Half, "Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
   TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py
index b107349678c6..d4ae27677dd7 100644
--- a/test/quantization/core/test_workflow_ops.py
+++ b/test/quantization/core/test_workflow_ops.py
@@ -823,21 +823,20 @@ class TestFakeQuantizeOps(TestCase):
         self._test_learnable_forward_per_channel(
             X_base, 'cpu', scale_base, zero_point_base, axis)
 
-    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
-                                   qparams=hu.qparams(dtypes=torch.quint8)))
     @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
-    @unittest.skip(
-        "this is broken without changes to any relevant code, "
-        "we need to remove hypothesis testing in CI")
-    def test_learnable_forward_per_channel_cuda(self, X):
+    def test_learnable_forward_per_channel_cuda(self):
         torch.random.manual_seed(NP_RANDOM_SEED)
-        X, (_, _, axis, _) = X
-        X_base = torch.tensor(X).to('cuda')
-        channel_size = X_base.size(axis)
-        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
-        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
-        self._test_learnable_forward_per_channel(
-            X_base, 'cuda', scale_base, zero_point_base, axis)
+        shape = (2, 1, 2, 10)
+        axis = 1
+
+        for dtype in [torch.float32, torch.bfloat16]:
+            X_base = torch.randn(shape, device="cuda").to(dtype)
+            channel_size = X_base.size(axis)
+            scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100).to(dtype)
+            zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,)).to(dtype)
+
+            self._test_learnable_forward_per_channel(
+                X_base, 'cuda', scale_base, zero_point_base, axis)
 
     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
            X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
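
For context (not part of the diff): a minimal sketch of how the relaxed TORCH_CHECK could be exercised from Python once BFloat16 scales are accepted. The tensor shapes, quant range, and axis below are illustrative, and it assumes the underlying fake-quant kernels handle a float input with a BFloat16 scale after this change.

    import torch

    x = torch.randn(2, 4, 8)                          # per-channel quantization along axis 1 (4 channels)
    scale = torch.rand(4).clamp(1e-4, 1.0).to(torch.bfloat16)  # BFloat16 scale, previously rejected by the check
    zero_point = torch.zeros(4, dtype=torch.int32)

    # fake_quantize_per_channel_affine dispatches to the cachemask variant whose check was relaxed
    y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, 0, 255)
    print(y.shape)  # same shape as x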