mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] enable UFMT for top-level files torch/*.py
(#127707)
Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
cc231a8e2b
commit
dd143d44cc
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class SobolEngine:
|
||||
r"""
|
||||
@ -48,8 +49,10 @@ class SobolEngine:
|
||||
|
||||
def __init__(self, dimension, scramble=False, seed=None):
|
||||
if dimension > self.MAXDIM or dimension < 1:
|
||||
raise ValueError("Supported range of dimensionality "
|
||||
f"for SobolEngine is [1, {self.MAXDIM}]")
|
||||
raise ValueError(
|
||||
"Supported range of dimensionality "
|
||||
f"for SobolEngine is [1, {self.MAXDIM}]"
|
||||
)
|
||||
|
||||
self.seed = seed
|
||||
self.scramble = scramble
|
||||
@ -57,7 +60,9 @@ class SobolEngine:
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
|
||||
self.sobolstate = torch.zeros(
|
||||
dimension, self.MAXBIT, device=cpu, dtype=torch.long
|
||||
)
|
||||
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
|
||||
|
||||
if not self.scramble:
|
||||
@ -66,11 +71,15 @@ class SobolEngine:
|
||||
self._scramble()
|
||||
|
||||
self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
|
||||
self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1)
|
||||
self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
|
||||
self.num_generated = 0
|
||||
|
||||
def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
|
||||
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
def draw(
|
||||
self,
|
||||
n: int = 1,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Function to draw a sequence of :attr:`n` points from a Sobol sequence.
|
||||
Note that the samples are dependent on the previous samples. The size
|
||||
@ -92,12 +101,22 @@ class SobolEngine:
|
||||
result = self._first_point.to(dtype)
|
||||
else:
|
||||
result, self.quasi = torch._sobol_engine_draw(
|
||||
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
|
||||
self.quasi,
|
||||
n - 1,
|
||||
self.sobolstate,
|
||||
self.dimension,
|
||||
self.num_generated,
|
||||
dtype=dtype,
|
||||
)
|
||||
result = torch.cat((self._first_point.to(dtype), result), dim=-2)
|
||||
else:
|
||||
result, self.quasi = torch._sobol_engine_draw(
|
||||
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
|
||||
self.quasi,
|
||||
n,
|
||||
self.sobolstate,
|
||||
self.dimension,
|
||||
self.num_generated - 1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
self.num_generated += n
|
||||
@ -108,8 +127,12 @@ class SobolEngine:
|
||||
|
||||
return result
|
||||
|
||||
def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
|
||||
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
def draw_base2(
|
||||
self,
|
||||
m: int,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
|
||||
Note that the samples are dependent on the previous samples. The size
|
||||
@ -122,15 +145,16 @@ class SobolEngine:
|
||||
returned tensor.
|
||||
Default: ``None``
|
||||
"""
|
||||
n = 2 ** m
|
||||
n = 2**m
|
||||
total_n = self.num_generated + n
|
||||
if not (total_n & (total_n - 1) == 0):
|
||||
raise ValueError("The balance properties of Sobol' points require "
|
||||
f"n to be a power of 2. {self.num_generated} points have been "
|
||||
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
|
||||
"If you still want to do this, please use "
|
||||
"'SobolEngine.draw()' instead."
|
||||
)
|
||||
raise ValueError(
|
||||
"The balance properties of Sobol' points require "
|
||||
f"n to be a power of 2. {self.num_generated} points have been "
|
||||
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
|
||||
"If you still want to do this, please use "
|
||||
"'SobolEngine.draw()' instead."
|
||||
)
|
||||
return self.draw(n=n, out=out, dtype=dtype)
|
||||
|
||||
def reset(self):
|
||||
@ -151,9 +175,13 @@ class SobolEngine:
|
||||
n (Int): The number of steps to fast-forward by.
|
||||
"""
|
||||
if self.num_generated == 0:
|
||||
torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
|
||||
torch._sobol_engine_ff_(
|
||||
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
|
||||
)
|
||||
else:
|
||||
torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
|
||||
torch._sobol_engine_ff_(
|
||||
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
|
||||
)
|
||||
self.num_generated += n
|
||||
return self
|
||||
|
||||
@ -166,8 +194,12 @@ class SobolEngine:
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
# Generate shift vector
|
||||
shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
|
||||
self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
|
||||
shift_ints = torch.randint(
|
||||
2, (self.dimension, self.MAXBIT), device=cpu, generator=g
|
||||
)
|
||||
self.shift = torch.mv(
|
||||
shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
|
||||
)
|
||||
|
||||
# Generate lower triangular matrices (stacked across dimensions)
|
||||
ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
|
||||
@ -176,9 +208,9 @@ class SobolEngine:
|
||||
torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
|
||||
|
||||
def __repr__(self):
|
||||
fmt_string = [f'dimension={self.dimension}']
|
||||
fmt_string = [f"dimension={self.dimension}"]
|
||||
if self.scramble:
|
||||
fmt_string += ['scramble=True']
|
||||
fmt_string += ["scramble=True"]
|
||||
if self.seed is not None:
|
||||
fmt_string += [f'seed={self.seed}']
|
||||
return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
|
||||
fmt_string += [f"seed={self.seed}"]
|
||||
return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
|
||||
|
Reference in New Issue
Block a user