bf16 support for fake_quantize_learnable_per_channel_affine (#165098)
Adds bf16 support for the `torch._fake_quantize_learnable_per_channel_affine()` op by relaxing the type check on scale.

TODO: bf16 support still needs to be added to `per_tensor_affine_`, as `torch._fake_quantize_learnable_per_tensor_affine_backward` gets called in the backward pass.

**Test**: modified unit test in `test_workflow_ops.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165098
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14
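For illustration only (not part of the commit): a minimal sketch of exercising the learnable per-channel op with a BFloat16 scale, roughly mirroring the updated test. The shape, axis, clamp range, and quant_min/quant_max values are arbitrary, and a CUDA device is assumed.

```python
# Minimal sketch (not from the PR): call the learnable per-channel fake-quant op
# with a bf16 scale, which the relaxed type check now permits.
import torch

x = torch.randn(2, 1, 2, 10, device="cuda", dtype=torch.bfloat16)
axis = 1
channels = x.size(axis)

# Per-channel scale and zero-point in bf16, clamped the same way as in the test.
scale = torch.normal(mean=0, std=1, size=(channels,)).clamp(1e-4, 100).to("cuda", torch.bfloat16)
zero_point = torch.normal(mean=0, std=128, size=(channels,)).to("cuda", torch.bfloat16)

# Positional arguments: input, scale, zero_point, axis, quant_min, quant_max, grad_factor.
y = torch._fake_quantize_learnable_per_channel_affine(
    x, scale, zero_point, axis, 0, 255, 1.0)
print(y.shape, y.dtype)  # inspect the result
```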
commit 253fd765bd (parent abb2f7179e), committed by PyTorch MergeBot
@@ -48,8 +48,8 @@ std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max) {
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
-              "Scale must be Float, found ", scale.scalar_type());
+  TORCH_CHECK(scale.scalar_type() == ScalarType::Float || scale.scalar_type() == at::kBFloat16,
+              "Scale must be Float or BFloat16, found ", scale.scalar_type());
   TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float || zero_point.scalar_type() == ScalarType::Half,
               "Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
   TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
@@ -823,21 +823,20 @@ class TestFakeQuantizeOps(TestCase):
         self._test_learnable_forward_per_channel(
             X_base, 'cpu', scale_base, zero_point_base, axis)

-    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
-           qparams=hu.qparams(dtypes=torch.quint8)))
     @unittest.skipIf(not TEST_CUDA, "No gpu is not available.")
-    @unittest.skip(
-        "this is broken without changes to any relevant code, "
-        "we need to remove hypothesis testing in CI")
-    def test_learnable_forward_per_channel_cuda(self, X):
+    def test_learnable_forward_per_channel_cuda(self):
         torch.random.manual_seed(NP_RANDOM_SEED)
-        X, (_, _, axis, _) = X
-        X_base = torch.tensor(X).to('cuda')
-        channel_size = X_base.size(axis)
-        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
-        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
-        self._test_learnable_forward_per_channel(
-            X_base, 'cuda', scale_base, zero_point_base, axis)
+        shape = (2, 1, 2, 10)
+        axis = 1
+
+        for dtype in [torch.float32, torch.bfloat16]:
+            X_base = torch.randn(shape, device="cuda").to(dtype)
+            channel_size = X_base.size(axis)
+            scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100).to(dtype)
+            zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,)).to(dtype)
+
+            self._test_learnable_forward_per_channel(
+                X_base, 'cuda', scale_base, zero_point_base, axis)

     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
            X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),