mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix randint distribution for large max (#143787)
Fixes #ISSUE_NUMBER Similar to #143682, for large maximum values we were sampling integers via `%`, which doesn't provide a uniform distribution. Here we limit the max skew to approx 1% (random32 is used for max values `<= 2**32 / 128`). This comes with a significant perf penalty, especially for CUDA, but it's a pretty bad bug, so we'll have to figure out what can be done to improve it. `torch.compile` has always been producing correct results for this, and its performance is also significantly better than current eager (eager is ~660 GB/s on H100, torch.compile 1200 GB/s), so we have to figure out why torch.compile is better. `__launch_bounds__` slightly regresses perf, so perhaps we can figure out how to specify them better, but it's only 20-30 GB/s, so the big difference is still unexplained. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143787 Approved by: https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
0e1675a89b
commit
ab1f627aa4
@ -3499,6 +3499,24 @@ class TestRandomTensorCreation(TestCase):
|
||||
self.assertTrue((res1 < 6).all().item())
|
||||
self.assertTrue((res1 >= 0).all().item())
|
||||
|
||||
|
||||
def test_randint_distribution(self, device):
|
||||
size = 1_000_000
|
||||
n_max = int(0.75 * 2 ** 32)
|
||||
n_bins = 8
|
||||
|
||||
def bin(index, max_size):
|
||||
return index // (max_size // n_bins)
|
||||
res = torch.randint(n_max, (size,), device=device)
|
||||
# histogram implemented for float only
|
||||
bins = bin(res, n_max).float().cpu()
|
||||
hist, _ = bins.histogram(8, range=(0, n_bins))
|
||||
expected_bin = res.shape[0] / 8
|
||||
expected_error = math.sqrt(expected_bin) / expected_bin * 3
|
||||
error = (hist - expected_bin).abs().max() / expected_bin
|
||||
self.assertTrue(error < expected_error)
|
||||
|
||||
|
||||
@dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
|
||||
torch.complex32, torch.complex64, torch.complex128)
|
||||
def test_randn(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user