mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Support boolean tensor for torch.fused_moving_avg_obs_fake_quant on CUDA (#153699)
Fixes #153310 As the title **Test plan** ``` pytest test/quantization/core/test_workflow_ops.py -k test_fused_obs_fake_quant_moving_avg ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153699 Approved by: https://github.com/mingfeima, https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
156b28e62a
commit
d9799a2ee7
@ -273,7 +273,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
|
|||||||
}
|
}
|
||||||
_calculate_moving_average(
|
_calculate_moving_average(
|
||||||
y,
|
y,
|
||||||
observer_on,
|
observer_on.to(at::kLong),
|
||||||
running_min,
|
running_min,
|
||||||
running_max,
|
running_max,
|
||||||
averaging_const,
|
averaging_const,
|
||||||
@ -282,7 +282,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
|
|||||||
} else {
|
} else {
|
||||||
_calculate_moving_average(
|
_calculate_moving_average(
|
||||||
x_contig,
|
x_contig,
|
||||||
observer_on,
|
observer_on.to(at::kLong),
|
||||||
running_min,
|
running_min,
|
||||||
running_max,
|
running_max,
|
||||||
averaging_const,
|
averaging_const,
|
||||||
@ -295,7 +295,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
|
|||||||
|
|
||||||
_calc_moving_avg_qparams_helper(
|
_calc_moving_avg_qparams_helper(
|
||||||
x_contig,
|
x_contig,
|
||||||
fake_quant_on,
|
fake_quant_on.to(at::kLong),
|
||||||
running_min,
|
running_min,
|
||||||
running_max,
|
running_max,
|
||||||
scale_ptr,
|
scale_ptr,
|
||||||
@ -316,7 +316,7 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cuda(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
|
return at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
|
||||||
x, scale, zero_point, fake_quant_on, qmin, qmax);
|
x, scale, zero_point, fake_quant_on.to(at::kLong), qmin, qmax);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace at::native
|
} // namespace at::native
|
||||||
|
@ -1056,9 +1056,9 @@ class TestFakeQuantizeOps(TestCase):
|
|||||||
|
|
||||||
class TestFusedObsFakeQuant(TestCase):
|
class TestFusedObsFakeQuant(TestCase):
|
||||||
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
||||||
symmetric_quant=st.booleans())
|
symmetric_quant=st.booleans(), use_bool=st.booleans())
|
||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None:
|
def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the case where we call the fused_obs_fake_quant op multiple times
|
Tests the case where we call the fused_obs_fake_quant op multiple times
|
||||||
and update the running_min and max of the activation tensors.
|
and update the running_min and max of the activation tensors.
|
||||||
@ -1070,15 +1070,15 @@ class TestFusedObsFakeQuant(TestCase):
|
|||||||
avg_const = 0.01
|
avg_const = 0.01
|
||||||
scale = torch.tensor([1.0], device=device)
|
scale = torch.tensor([1.0], device=device)
|
||||||
zero_point = torch.tensor([0], dtype=torch.int, device=device)
|
zero_point = torch.tensor([0], dtype=torch.int, device=device)
|
||||||
observer_on = fake_quant_on = 0
|
observer_on = fake_quant_on = False if use_bool else 0
|
||||||
|
|
||||||
pt_op = torch.fused_moving_avg_obs_fake_quant
|
pt_op = torch.fused_moving_avg_obs_fake_quant
|
||||||
# enable observer after 2 iterations and fake_quant after 4 iterations
|
# enable observer after 2 iterations and fake_quant after 4 iterations
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
if i > 2:
|
if i > 2:
|
||||||
observer_on = 1
|
observer_on = True if use_bool else 1
|
||||||
if i > 4:
|
if i > 4:
|
||||||
fake_quant_on = 1
|
fake_quant_on = True if use_bool else 1
|
||||||
|
|
||||||
x = torch.randn(5, 5, device=device)
|
x = torch.randn(5, 5, device=device)
|
||||||
out = pt_op(
|
out = pt_op(
|
||||||
@ -1147,9 +1147,9 @@ class TestFusedObsFakeQuant(TestCase):
|
|||||||
self.assertEqual(out.shape, output_shape)
|
self.assertEqual(out.shape, output_shape)
|
||||||
|
|
||||||
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
|
||||||
symmetric_quant=st.booleans())
|
symmetric_quant=st.booleans(), use_bool=st.booleans())
|
||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None:
|
def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant, use_bool) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the case where we call the fused_obs_fake_quant op multiple times
|
Tests the case where we call the fused_obs_fake_quant op multiple times
|
||||||
and update the running_min and max of the activation tensors.
|
and update the running_min and max of the activation tensors.
|
||||||
@ -1166,15 +1166,15 @@ class TestFusedObsFakeQuant(TestCase):
|
|||||||
scale = torch.empty(m, device=device).fill_(0.1)
|
scale = torch.empty(m, device=device).fill_(0.1)
|
||||||
zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
|
zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
|
||||||
|
|
||||||
observer_on = fake_quant_on = 0
|
observer_on = fake_quant_on = False if use_bool else 0
|
||||||
|
|
||||||
pt_op = torch.fused_moving_avg_obs_fake_quant
|
pt_op = torch.fused_moving_avg_obs_fake_quant
|
||||||
# enable observer after 2 iterations and fake_quant after 4 iterations
|
# enable observer after 2 iterations and fake_quant after 4 iterations
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
if i > 2:
|
if i > 2:
|
||||||
observer_on = 1
|
observer_on = True if use_bool else 1
|
||||||
if i > 4:
|
if i > 4:
|
||||||
fake_quant_on = 1
|
fake_quant_on = True if use_bool else 1
|
||||||
|
|
||||||
x = torch.randn(size, device=device)
|
x = torch.randn(size, device=device)
|
||||||
out = pt_op(
|
out = pt_op(
|
||||||
|
Reference in New Issue
Block a user