mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support boolean tensor for torch.fused_moving_avg_obs_fake_quant on CUDA (#153699)
Fixes #153310 As the title **Test plan** ``` pytest test/quantization/core/test_workflow_ops.py -k test_fused_obs_fake_quant_moving_avg ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153699 Approved by: https://github.com/mingfeima, https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
156b28e62a
commit
d9799a2ee7
@ -1056,9 +1056,9 @@ class TestFakeQuantizeOps(TestCase):
|
||||
|
||||
class TestFusedObsFakeQuant(TestCase):
|
||||
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
||||
symmetric_quant=st.booleans())
|
||||
symmetric_quant=st.booleans(), use_bool=st.booleans())
|
||||
@settings(deadline=None)
|
||||
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None:
|
||||
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None:
|
||||
"""
|
||||
Tests the case where we call the fused_obs_fake_quant op multiple times
|
||||
and update the running_min and max of the activation tensors.
|
||||
@ -1070,15 +1070,15 @@ class TestFusedObsFakeQuant(TestCase):
|
||||
avg_const = 0.01
|
||||
scale = torch.tensor([1.0], device=device)
|
||||
zero_point = torch.tensor([0], dtype=torch.int, device=device)
|
||||
observer_on = fake_quant_on = 0
|
||||
observer_on = fake_quant_on = False if use_bool else 0
|
||||
|
||||
pt_op = torch.fused_moving_avg_obs_fake_quant
|
||||
# enable observer after 2 iterations and fake_quant after 4 iterations
|
||||
for i in range(10):
|
||||
if i > 2:
|
||||
observer_on = 1
|
||||
observer_on = True if use_bool else 1
|
||||
if i > 4:
|
||||
fake_quant_on = 1
|
||||
fake_quant_on = True if use_bool else 1
|
||||
|
||||
x = torch.randn(5, 5, device=device)
|
||||
out = pt_op(
|
||||
@ -1147,9 +1147,9 @@ class TestFusedObsFakeQuant(TestCase):
|
||||
self.assertEqual(out.shape, output_shape)
|
||||
|
||||
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
||||
symmetric_quant=st.booleans())
|
||||
symmetric_quant=st.booleans(), use_bool=st.booleans())
|
||||
@settings(deadline=None)
|
||||
def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None:
|
||||
def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant, use_bool) -> None:
|
||||
"""
|
||||
Tests the case where we call the fused_obs_fake_quant op multiple times
|
||||
and update the running_min and max of the activation tensors.
|
||||
@ -1166,15 +1166,15 @@ class TestFusedObsFakeQuant(TestCase):
|
||||
scale = torch.empty(m, device=device).fill_(0.1)
|
||||
zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
|
||||
|
||||
observer_on = fake_quant_on = 0
|
||||
observer_on = fake_quant_on = False if use_bool else 0
|
||||
|
||||
pt_op = torch.fused_moving_avg_obs_fake_quant
|
||||
# enable observer after 2 iterations and fake_quant after 4 iterations
|
||||
for i in range(10):
|
||||
if i > 2:
|
||||
observer_on = 1
|
||||
observer_on = True if use_bool else 1
|
||||
if i > 4:
|
||||
fake_quant_on = 1
|
||||
fake_quant_on = True if use_bool else 1
|
||||
|
||||
x = torch.randn(size, device=device)
|
||||
out = pt_op(
|
||||
|
Reference in New Issue
Block a user