mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SobolEngine] Fix edge case of dtype of first sample (#51578)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51578 https://github.com/pytorch/pytorch/pull/49710 introduced an edge case in which drawing a single sample resulted in ignoring the `dtype` arg to `draw`. This fixes this and adds a unit test to cover this behavior. Test Plan: Unit tests Reviewed By: danielrjiang Differential Revision: D26204393 fbshipit-source-id: 441a44dc035002e7bbe6b662bf6d1af0e2cd88f4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4746b3d1fb
commit
a990ff7001
@ -83,7 +83,7 @@ class SobolEngine(object):
|
||||
"""
|
||||
if self.num_generated == 0:
|
||||
if n == 1:
|
||||
result = self._first_point
|
||||
result = self._first_point.to(dtype)
|
||||
else:
|
||||
result, self.quasi = torch._sobol_engine_draw(
|
||||
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
|
||||
|
||||
Reference in New Issue
Block a user