Support boolean tensor for torch.fused_moving_avg_obs_fake_quant on CUDA (#153699)

Fixes #153310

As the title says: the CUDA implementation of `torch.fused_moving_avg_obs_fake_quant` now accepts boolean tensors for the `observer_on` and `fake_quant_on` flags, by casting them to `at::kLong` before the kernels consume them.
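
For illustration, a minimal sketch of calling the op with boolean flag tensors after this change (shapes, bounds, and initial values are assumptions, not taken from the PR):

```
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(5, 5, device=device)

# These flags may now be torch.bool tensors on CUDA (previously long-only).
observer_on = torch.tensor(True, device=device)
fake_quant_on = torch.tensor(True, device=device)

# Running statistics and qparams; the op updates them in place.
running_min = torch.tensor(float("inf"), device=device)
running_max = torch.tensor(float("-inf"), device=device)
scale = torch.tensor([1.0], device=device)
zero_point = torch.tensor([0], dtype=torch.int, device=device)

out = torch.fused_moving_avg_obs_fake_quant(
    x, observer_on, fake_quant_on,
    running_min, running_max, scale, zero_point,
    0.01,    # averaging_const
    0, 255,  # quant_min, quant_max
    0,       # ch_axis
    False,   # per_row_fake_quant
    False,   # symmetric_quant
)
```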

**Test plan**
```
pytest test/quantization/core/test_workflow_ops.py -k test_fused_obs_fake_quant_moving_avg
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153699
Approved by: https://github.com/mingfeima, https://github.com/jerryzh168
Author: Xia, Weiwen
Date: 2025-06-16 07:10:06 +00:00
Committed by: PyTorch MergeBot
Parent: 156b28e62a
Commit: d9799a2ee7

2 changed files with 14 additions and 14 deletions


CUDA implementation (`fused_moving_avg_obs_fake_quant_cuda`):

```
@@ -273,7 +273,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
     }
     _calculate_moving_average(
         y,
-        observer_on,
+        observer_on.to(at::kLong),
         running_min,
         running_max,
         averaging_const,
@@ -282,7 +282,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
   } else {
     _calculate_moving_average(
         x_contig,
-        observer_on,
+        observer_on.to(at::kLong),
         running_min,
         running_max,
         averaging_const,
@@ -295,7 +295,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
   _calc_moving_avg_qparams_helper(
       x_contig,
-      fake_quant_on,
+      fake_quant_on.to(at::kLong),
       running_min,
       running_max,
       scale_ptr,
@@ -316,7 +316,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
     }
   } else {
     return at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
-        x, scale, zero_point, fake_quant_on, qmin, qmax);
+        x, scale, zero_point, fake_quant_on.to(at::kLong), qmin, qmax);
   }
 }
 } // namespace at::native
```
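
The cast suggests the helper kernels read the flags as long (int64) data; widening a bool tensor gives the expected 0/1 values, and the existing integer path is unaffected. A small Python sketch of the same semantics (illustrative, not code from the PR):

```
import torch

flag = torch.tensor(True)          # torch.bool
widened = flag.to(torch.long)      # tensor(1); False would widen to 0
print(widened.dtype)               # torch.int64

already_long = torch.tensor(1)
# When the dtype already matches, .to() returns the same tensor,
# so the pre-existing integer-flag path pays no extra copy.
print(already_long.to(torch.long) is already_long)  # True
```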


`test/quantization/core/test_workflow_ops.py`:

```
@@ -1056,9 +1056,9 @@ class TestFakeQuantizeOps(TestCase):
 class TestFusedObsFakeQuant(TestCase):
     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
-           symmetric_quant=st.booleans())
+           symmetric_quant=st.booleans(), use_bool=st.booleans())
     @settings(deadline=None)
-    def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None:
+    def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None:
         """
         Tests the case where we call the fused_obs_fake_quant op multiple times
         and update the running_min and max of the activation tensors.
@@ -1070,15 +1070,15 @@ class TestFusedObsFakeQuant(TestCase):
         avg_const = 0.01
         scale = torch.tensor([1.0], device=device)
         zero_point = torch.tensor([0], dtype=torch.int, device=device)
-        observer_on = fake_quant_on = 0
+        observer_on = fake_quant_on = False if use_bool else 0
         pt_op = torch.fused_moving_avg_obs_fake_quant
         # enable observer after 2 iterations and fake_quant after 4 iterations
         for i in range(10):
             if i > 2:
-                observer_on = 1
+                observer_on = True if use_bool else 1
             if i > 4:
-                fake_quant_on = 1
+                fake_quant_on = True if use_bool else 1
             x = torch.randn(5, 5, device=device)
             out = pt_op(
@@ -1147,9 +1147,9 @@ class TestFusedObsFakeQuant(TestCase):
         self.assertEqual(out.shape, output_shape)
 
     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
-           symmetric_quant=st.booleans())
+           symmetric_quant=st.booleans(), use_bool=st.booleans())
     @settings(deadline=None)
-    def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None:
+    def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant, use_bool) -> None:
         """
         Tests the case where we call the fused_obs_fake_quant op multiple times
         and update the running_min and max of the activation tensors.
@@ -1166,15 +1166,15 @@ class TestFusedObsFakeQuant(TestCase):
         scale = torch.empty(m, device=device).fill_(0.1)
         zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
-        observer_on = fake_quant_on = 0
+        observer_on = fake_quant_on = False if use_bool else 0
         pt_op = torch.fused_moving_avg_obs_fake_quant
         # enable observer after 2 iterations and fake_quant after 4 iterations
         for i in range(10):
             if i > 2:
-                observer_on = 1
+                observer_on = True if use_bool else 1
             if i > 4:
-                fake_quant_on = 1
+                fake_quant_on = True if use_bool else 1
             x = torch.randn(size, device=device)
             out = pt_op(
```
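
The new `use_bool` parameter makes Hypothesis drive the op with both flag dtypes: a Python bool presumably becomes a `torch.bool` tensor in the elided `torch.tensor(...)` call, while a plain int becomes `torch.int64`. A condensed sketch of that pattern (test name and body are illustrative, not the actual test):

```
import torch
from hypothesis import given, settings, strategies as st

@given(use_bool=st.booleans())
@settings(deadline=None)
def test_flag_dtypes(use_bool):
    # Mirrors the test's toggle: Python bools produce torch.bool flag
    # tensors, plain ints produce torch.int64 ones.
    observer_on = True if use_bool else 1
    flag = torch.tensor(observer_on)
    assert flag.dtype == (torch.bool if use_bool else torch.int64)
```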