mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytorch] Make behavior of SobolEngine consistent w/ other RNG functions (#36427)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36427 Addresses https://github.com/pytorch/pytorch/issues/36341 Test Plan: unit tests Reviewed By: ldworkin Differential Revision: D20952703 fbshipit-source-id: 28055f4c4c0f8012c2d96e473b822fa455dd833c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d2e0c628e9
commit
379e4d9cad
@ -3618,6 +3618,18 @@ class _TestTorchMixin(object):
|
||||
engine_3d.fast_forward(2)
|
||||
self.assertEqual(engine_3d.draw(5), actual_3d[5:])
|
||||
|
||||
def test_sobolengine_scrambled_lowdim_default_rng(self):
|
||||
expected_1d = [0.039826, 0.484409, 0.953192, 0.799275, 0.267996]
|
||||
torch.manual_seed(123456)
|
||||
engine_1d = torch.quasirandom.SobolEngine(1, scramble=True)
|
||||
actual_1d = engine_1d.draw(5)
|
||||
self.assertEqual(actual_1d[:, 0], expected_1d)
|
||||
torch.manual_seed(123456)
|
||||
expected_3d = [0.133490, 0.480183, 0.855304, 0.970967, 0.345844]
|
||||
engine_3d = torch.quasirandom.SobolEngine(3, scramble=True)
|
||||
actual_3d = engine_3d.draw(5)
|
||||
self.assertEqual(actual_3d[:, 0], expected_3d)
|
||||
|
||||
def test_sobolengine_scrambled_highdim(self):
|
||||
engine = torch.quasirandom.SobolEngine(1111, scramble=True)
|
||||
draws = engine.draw(1000)
|
||||
|
@ -57,11 +57,11 @@ class SobolEngine(object):
|
||||
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
|
||||
|
||||
if self.scramble:
|
||||
g = torch.Generator()
|
||||
if self.seed is not None:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed)
|
||||
else:
|
||||
g.seed()
|
||||
g = None
|
||||
|
||||
shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
|
||||
self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
|
||||
|
Reference in New Issue
Block a user