[BE] enable UFMT for top-level files torch/*.py (#127707)

Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707
Approved by: https://github.com/ezyang
Xuehai Pan
2024-06-12 20:10:47 +08:00
committed by PyTorch MergeBot
parent cc231a8e2b
commit dd143d44cc
15 changed files with 1548 additions and 875 deletions
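For context on what the reformatting below entails: UFMT combines import sorting with Black-style code formatting, which accounts for the quote normalization, call wrapping, and `2**m` spacing in the hunks that follow. As a rough, illustrative sketch only (not the repository's actual lint invocation), Black's Python API reproduces the quote change seen in the final hunk:

import black  # assumes the `black` package is installed

# A pre-change line from the __repr__ hunk below; Black rewrites the
# single-quoted f-string with double quotes, matching the "+" side.
SRC = "fmt_string = [f'dimension={self.dimension}']\n"
print(black.format_str(SRC, mode=black.Mode()), end="")
# prints: fmt_string = [f"dimension={self.dimension}"]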

torch/quasirandom.py

@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
-import torch
 from typing import Optional
+import torch
 class SobolEngine:
     r"""
@@ -48,8 +49,10 @@ class SobolEngine:
     def __init__(self, dimension, scramble=False, seed=None):
         if dimension > self.MAXDIM or dimension < 1:
-            raise ValueError("Supported range of dimensionality "
-                             f"for SobolEngine is [1, {self.MAXDIM}]")
+            raise ValueError(
+                "Supported range of dimensionality "
+                f"for SobolEngine is [1, {self.MAXDIM}]"
+            )
         self.seed = seed
         self.scramble = scramble
@@ -57,7 +60,9 @@ class SobolEngine:
         cpu = torch.device("cpu")
-        self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
+        self.sobolstate = torch.zeros(
+            dimension, self.MAXBIT, device=cpu, dtype=torch.long
+        )
         torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
         if not self.scramble:
@@ -66,11 +71,15 @@ class SobolEngine:
             self._scramble()
         self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
-        self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1)
+        self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
         self.num_generated = 0
-    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
-             dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def draw(
+        self,
+        n: int = 1,
+        out: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
         r"""
         Function to draw a sequence of :attr:`n` points from a Sobol sequence.
         Note that the samples are dependent on the previous samples. The size
@@ -92,12 +101,22 @@ class SobolEngine:
                 result = self._first_point.to(dtype)
             else:
                 result, self.quasi = torch._sobol_engine_draw(
-                    self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
+                    self.quasi,
+                    n - 1,
+                    self.sobolstate,
+                    self.dimension,
+                    self.num_generated,
+                    dtype=dtype,
                 )
                 result = torch.cat((self._first_point.to(dtype), result), dim=-2)
         else:
             result, self.quasi = torch._sobol_engine_draw(
-                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
+                self.quasi,
+                n,
+                self.sobolstate,
+                self.dimension,
+                self.num_generated - 1,
+                dtype=dtype,
             )
         self.num_generated += n
@@ -108,8 +127,12 @@ class SobolEngine:
         return result
-    def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
-                   dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def draw_base2(
+        self,
+        m: int,
+        out: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
         r"""
         Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
         Note that the samples are dependent on the previous samples. The size
@@ -122,15 +145,16 @@ class SobolEngine:
                 returned tensor.
                 Default: ``None``
         """
-        n = 2 ** m
+        n = 2**m
         total_n = self.num_generated + n
         if not (total_n & (total_n - 1) == 0):
-            raise ValueError("The balance properties of Sobol' points require "
-                             f"n to be a power of 2. {self.num_generated} points have been "
-                             f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
-                             "If you still want to do this, please use "
-                             "'SobolEngine.draw()' instead."
-                             )
+            raise ValueError(
+                "The balance properties of Sobol' points require "
+                f"n to be a power of 2. {self.num_generated} points have been "
+                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
+                "If you still want to do this, please use "
+                "'SobolEngine.draw()' instead."
+            )
         return self.draw(n=n, out=out, dtype=dtype)
     def reset(self):
@@ -151,9 +175,13 @@ class SobolEngine:
             n (Int): The number of steps to fast-forward by.
         """
         if self.num_generated == 0:
-            torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
+            torch._sobol_engine_ff_(
+                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
+            )
         else:
-            torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
+            torch._sobol_engine_ff_(
+                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
+            )
         self.num_generated += n
         return self
@@ -166,8 +194,12 @@ class SobolEngine:
         cpu = torch.device("cpu")
         # Generate shift vector
-        shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
-        self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
+        shift_ints = torch.randint(
+            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
+        )
+        self.shift = torch.mv(
+            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
+        )
         # Generate lower triangular matrices (stacked across dimensions)
         ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
@@ -176,9 +208,9 @@ class SobolEngine:
         torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
     def __repr__(self):
-        fmt_string = [f'dimension={self.dimension}']
+        fmt_string = [f"dimension={self.dimension}"]
         if self.scramble:
-            fmt_string += ['scramble=True']
+            fmt_string += ["scramble=True"]
         if self.seed is not None:
-            fmt_string += [f'seed={self.seed}']
-        return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
+            fmt_string += [f"seed={self.seed}"]
+        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
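
For reference, none of these hunks change behavior; they only reformat the code. A minimal, illustrative usage sketch of the torch.quasirandom.SobolEngine API touched by this diff (shapes and value ranges follow the class's documented behavior):

from torch.quasirandom import SobolEngine

# Scrambled 3-dimensional Sobol sequence with a fixed seed for reproducibility.
engine = SobolEngine(dimension=3, scramble=True, seed=0)

points = engine.draw(8)       # shape (8, 3), quasirandom points in [0, 1)
more = engine.draw_base2(3)   # 2**3 further points; total count stays a power of 2
engine.reset()                # rewind to the start of the sequence
engine.fast_forward(16)       # skip ahead 16 steps without returning them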