Support boolean tensor for torch.fused_moving_avg_obs_fake_quant on CUDA (#153699)

Fixes #153310

As the title

**Test plan**
```
pytest test/quantization/core/test_workflow_ops.py -k test_fused_obs_fake_quant_moving_avg
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153699
Approved by: https://github.com/mingfeima, https://github.com/jerryzh168
This commit is contained in:
Xia, Weiwen
2025-06-16 07:10:06 +00:00
committed by PyTorch MergeBot
parent 156b28e62a
commit d9799a2ee7
2 changed files with 14 additions and 14 deletions

View File

@ -1056,9 +1056,9 @@ class TestFakeQuantizeOps(TestCase):
class TestFusedObsFakeQuant(TestCase):
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
symmetric_quant=st.booleans())
symmetric_quant=st.booleans(), use_bool=st.booleans())
@settings(deadline=None)
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None:
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None:
"""
Tests the case where we call the fused_obs_fake_quant op multiple times
and update the running_min and max of the activation tensors.
@ -1070,15 +1070,15 @@ class TestFusedObsFakeQuant(TestCase):
avg_const = 0.01
scale = torch.tensor([1.0], device=device)
zero_point = torch.tensor([0], dtype=torch.int, device=device)
observer_on = fake_quant_on = 0
observer_on = fake_quant_on = False if use_bool else 0
pt_op = torch.fused_moving_avg_obs_fake_quant
# enable observer after 2 iterations and fake_quant after 4 iterations
for i in range(10):
if i > 2:
observer_on = 1
observer_on = True if use_bool else 1
if i > 4:
fake_quant_on = 1
fake_quant_on = True if use_bool else 1
x = torch.randn(5, 5, device=device)
out = pt_op(
@ -1147,9 +1147,9 @@ class TestFusedObsFakeQuant(TestCase):
self.assertEqual(out.shape, output_shape)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
symmetric_quant=st.booleans())
symmetric_quant=st.booleans(), use_bool=st.booleans())
@settings(deadline=None)
def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None:
def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant, use_bool) -> None:
"""
Tests the case where we call the fused_obs_fake_quant op multiple times
and update the running_min and max of the activation tensors.
@ -1166,15 +1166,15 @@ class TestFusedObsFakeQuant(TestCase):
scale = torch.empty(m, device=device).fill_(0.1)
zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
observer_on = fake_quant_on = 0
observer_on = fake_quant_on = False if use_bool else 0
pt_op = torch.fused_moving_avg_obs_fake_quant
# enable observer after 2 iterations and fake_quant after 4 iterations
for i in range(10):
if i > 2:
observer_on = 1
observer_on = True if use_bool else 1
if i > 4:
fake_quant_on = 1
fake_quant_on = True if use_bool else 1
x = torch.randn(size, device=device)
out = pt_op(