[SobolEngine] Fix edge case of dtype of first sample (#51578)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51578

https://github.com/pytorch/pytorch/pull/49710 introduced an edge case in which
drawing a single sample resulted in ignoring the `dtype` arg to `draw`. This
fixes this and adds a unit test to cover this behavior.

Test Plan: Unit tests

Reviewed By: danielrjiang

Differential Revision: D26204393

fbshipit-source-id: 441a44dc035002e7bbe6b662bf6d1af0e2cd88f4
This commit is contained in:
Max Balandat
2021-02-02 14:22:35 -08:00
committed by Facebook GitHub Bot
parent 4746b3d1fb
commit a990ff7001
2 changed files with 13 additions and 1 deletions

View File

@ -83,7 +83,7 @@ class SobolEngine(object):
"""
if self.num_generated == 0:
if n == 1:
result = self._first_point
result = self._first_point.to(dtype)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,