[pytorch] Make behavior of SobolEngine consistent w/ other RNG functions (#36427)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36427

Addresses https://github.com/pytorch/pytorch/issues/36341

Test Plan: unit tests

Reviewed By: ldworkin

Differential Revision: D20952703

fbshipit-source-id: 28055f4c4c0f8012c2d96e473b822fa455dd833c
This commit is contained in:
Max Balandat
2020-04-13 07:48:05 -07:00
committed by Facebook GitHub Bot
parent d2e0c628e9
commit 379e4d9cad
2 changed files with 14 additions and 2 deletions

View File

@ -3618,6 +3618,18 @@ class _TestTorchMixin(object):
engine_3d.fast_forward(2)
self.assertEqual(engine_3d.draw(5), actual_3d[5:])
def test_sobolengine_scrambled_lowdim_default_rng(self):
expected_1d = [0.039826, 0.484409, 0.953192, 0.799275, 0.267996]
torch.manual_seed(123456)
engine_1d = torch.quasirandom.SobolEngine(1, scramble=True)
actual_1d = engine_1d.draw(5)
self.assertEqual(actual_1d[:, 0], expected_1d)
torch.manual_seed(123456)
expected_3d = [0.133490, 0.480183, 0.855304, 0.970967, 0.345844]
engine_3d = torch.quasirandom.SobolEngine(3, scramble=True)
actual_3d = engine_3d.draw(5)
self.assertEqual(actual_3d[:, 0], expected_3d)
def test_sobolengine_scrambled_highdim(self):
engine = torch.quasirandom.SobolEngine(1111, scramble=True)
draws = engine.draw(1000)

View File

@ -57,11 +57,11 @@ class SobolEngine(object):
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
if self.scramble:
g = torch.Generator()
if self.seed is not None:
g = torch.Generator()
g.manual_seed(self.seed)
else:
g.seed()
g = None
shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))