Compare commits

...

5 Commits

Author SHA1 Message Date
30d470d92a Update
[ghstack-poisoned]
2025-11-19 14:55:17 +00:00
f232cd9dc3 Update (base update)
[ghstack-poisoned]
2025-11-19 14:55:17 +00:00
59d601043a Update
[ghstack-poisoned]
2025-11-17 15:12:41 +00:00
e03c076156 Update (base update)
[ghstack-poisoned]
2025-11-17 15:00:23 +00:00
54f2693e45 Update
[ghstack-poisoned]
2025-11-17 15:00:23 +00:00
21 changed files with 50 additions and 55 deletions

View File

@@ -63,9 +63,9 @@ def type_casts(
):
@functools.wraps(f)
def inner(*args, **kwargs):
allowed_types = (
(Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
) # type: ignore[arg-type]
allowed_types: tuple[type, ...] = (
(Tensor, torch.types.Number) if include_non_tensor_args else (Tensor,)
)
flat_args = [
x
for x in pytree.arg_tree_leaves(*args, **kwargs)

View File

@@ -12,7 +12,7 @@ from torch.distributions.utils import (
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number, Number
from torch.types import Number
__all__ = ["Bernoulli"]
@@ -56,12 +56,12 @@ class Bernoulli(ExponentialFamily):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
is_scalar = isinstance(probs, Number)
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
is_scalar = isinstance(logits, Number)
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits

View File

@@ -7,7 +7,7 @@ from torch.distributions import constraints
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Beta"]
@@ -45,7 +45,7 @@ class Beta(ExponentialFamily):
concentration0: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
if isinstance(concentration1, _Number) and isinstance(concentration0, _Number):
if isinstance(concentration1, Number) and isinstance(concentration0, Number):
concentration1_concentration0 = torch.tensor(
[float(concentration1), float(concentration0)]
)
@@ -97,7 +97,7 @@ class Beta(ExponentialFamily):
@property
def concentration1(self) -> Tensor:
result = self._dirichlet.concentration[..., 0]
if isinstance(result, _Number):
if isinstance(result, Number):
return torch.tensor([result])
else:
return result
@@ -105,7 +105,7 @@ class Beta(ExponentialFamily):
@property
def concentration0(self) -> Tensor:
result = self._dirichlet.concentration[..., 1]
if isinstance(result, _Number):
if isinstance(result, Number):
return torch.tensor([result])
else:
return result

View File

@@ -7,7 +7,7 @@ from torch import inf, nan, Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Cauchy"]
@@ -43,7 +43,7 @@ class Cauchy(Distribution):
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, _Number) and isinstance(scale, _Number):
if isinstance(loc, Number) and isinstance(scale, Number):
batch_shape = torch.Size()
else:
batch_shape = self.loc.size()

View File

@@ -67,7 +67,7 @@ object.
"""
from torch.distributions import constraints, transforms
from torch.types import _Number
from torch.types import Number
__all__ = [
@@ -220,10 +220,10 @@ def _transform_to_less_than(constraint):
def _transform_to_interval(constraint):
# Handle the special case of the unit interval.
lower_is_0 = (
isinstance(constraint.lower_bound, _Number) and constraint.lower_bound == 0
isinstance(constraint.lower_bound, Number) and constraint.lower_bound == 0
)
upper_is_1 = (
isinstance(constraint.upper_bound, _Number) and constraint.upper_bound == 1
isinstance(constraint.upper_bound, Number) and constraint.upper_bound == 1
)
if lower_is_0 and upper_is_1:
return transforms.SigmoidTransform()

View File

@@ -14,7 +14,7 @@ from torch.distributions.utils import (
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number, _size, Number
from torch.types import _size, Number
__all__ = ["ContinuousBernoulli"]
@@ -65,7 +65,7 @@ class ContinuousBernoulli(ExponentialFamily):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
is_scalar = isinstance(probs, Number)
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
# validate 'probs' here if necessary as it is later clamped for numerical stability
@@ -77,7 +77,7 @@ class ContinuousBernoulli(ExponentialFamily):
self.probs = clamp_probs(self.probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
is_scalar = isinstance(logits, Number)
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits

View File

@@ -6,7 +6,7 @@ from torch import Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Exponential"]
@@ -55,7 +55,7 @@ class Exponential(ExponentialFamily):
validate_args: Optional[bool] = None,
) -> None:
(self.rate,) = broadcast_all(rate)
batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
super().__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@@ -7,7 +7,7 @@ from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import Gamma
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["FisherSnedecor"]
@@ -44,7 +44,7 @@ class FisherSnedecor(Distribution):
self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
if isinstance(df1, _Number) and isinstance(df2, _Number):
if isinstance(df1, Number) and isinstance(df2, Number):
batch_shape = torch.Size()
else:
batch_shape = self.df1.size()

View File

@@ -6,7 +6,7 @@ from torch import Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Gamma"]
@@ -62,7 +62,7 @@ class Gamma(ExponentialFamily):
validate_args: Optional[bool] = None,
) -> None:
self.concentration, self.rate = broadcast_all(concentration, rate)
if isinstance(concentration, _Number) and isinstance(rate, _Number):
if isinstance(concentration, Number) and isinstance(rate, Number):
batch_shape = torch.Size()
else:
batch_shape = self.concentration.size()

View File

@@ -12,7 +12,7 @@ from torch.distributions.utils import (
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _Number, Number
from torch.types import Number
__all__ = ["Geometric"]
@@ -66,7 +66,7 @@ class Geometric(Distribution):
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number):
if isinstance(probs_or_logits, Number):
batch_shape = torch.Size()
else:
assert probs_or_logits is not None # helps mypy

View File

@@ -9,7 +9,7 @@ from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all, euler_constant
from torch.types import _Number
from torch.types import Number
__all__ = ["Gumbel"]
@@ -43,7 +43,7 @@ class Gumbel(TransformedDistribution):
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
finfo = torch.finfo(self.loc.dtype)
if isinstance(loc, _Number) and isinstance(scale, _Number):
if isinstance(loc, Number) and isinstance(scale, Number):
base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
else:
base_dist = Uniform(

View File

@@ -6,7 +6,7 @@ from torch import Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Laplace"]
@@ -56,7 +56,7 @@ class Laplace(Distribution):
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, _Number) and isinstance(scale, _Number):
if isinstance(loc, Number) and isinstance(scale, Number):
batch_shape = torch.Size()
else:
batch_shape = self.loc.size()

View File

@@ -7,7 +7,7 @@ from torch import Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import _standard_normal, broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Normal"]
@@ -60,7 +60,7 @@ class Normal(ExponentialFamily):
validate_args: Optional[bool] = None,
) -> None:
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, _Number) and isinstance(scale, _Number):
if isinstance(loc, Number) and isinstance(scale, Number):
batch_shape = torch.Size()
else:
batch_shape = self.loc.size()
@@ -92,9 +92,7 @@ class Normal(ExponentialFamily):
# pyrefly: ignore [unsupported-operation]
var = self.scale**2
log_scale = (
math.log(self.scale)
if isinstance(self.scale, _Number)
else self.scale.log()
math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
)
return (
-((value - self.loc) ** 2) / (2 * var)

View File

@@ -6,7 +6,7 @@ from torch import Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _Number, Number
from torch.types import Number
__all__ = ["Poisson"]
@@ -54,7 +54,7 @@ class Poisson(ExponentialFamily):
validate_args: Optional[bool] = None,
) -> None:
(self.rate,) = broadcast_all(rate)
if isinstance(rate, _Number):
if isinstance(rate, Number):
batch_shape = torch.Size()
else:
batch_shape = self.rate.size()

View File

@@ -14,7 +14,7 @@ from torch.distributions.utils import (
logits_to_probs,
probs_to_logits,
)
from torch.types import _Number, _size, Number
from torch.types import _size, Number
__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]
@@ -57,12 +57,12 @@ class LogitRelaxedBernoulli(Distribution):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
is_scalar = isinstance(probs, Number)
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
is_scalar = isinstance(logits, Number)
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits

View File

@@ -19,7 +19,7 @@ from torch.distributions.utils import (
vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus
from torch.types import _Number
from torch.types import Number
__all__ = [
@@ -796,14 +796,14 @@ class AffineTransform(Transform):
if not isinstance(other, AffineTransform):
return False
if isinstance(self.loc, _Number) and isinstance(other.loc, _Number):
if isinstance(self.loc, Number) and isinstance(other.loc, Number):
if self.loc != other.loc:
return False
else:
if not (self.loc == other.loc).all().item(): # type: ignore[union-attr]
return False
if isinstance(self.scale, _Number) and isinstance(other.scale, _Number):
if isinstance(self.scale, Number) and isinstance(other.scale, Number):
if self.scale != other.scale:
return False
else:
@@ -814,7 +814,7 @@ class AffineTransform(Transform):
@property
def sign(self) -> Union[Tensor, int]: # type: ignore[override]
if isinstance(self.scale, _Number):
if isinstance(self.scale, Number):
return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
return self.scale.sign()
@@ -827,7 +827,7 @@ class AffineTransform(Transform):
def log_abs_det_jacobian(self, x, y):
shape = x.shape
scale = self.scale
if isinstance(scale, _Number):
if isinstance(scale, Number):
result = torch.full_like(x, math.log(abs(scale)))
else:
result = torch.abs(scale).log()

View File

@@ -6,7 +6,7 @@ from torch import nan, Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _Number, _size
from torch.types import _size, Number
__all__ = ["Uniform"]
@@ -63,7 +63,7 @@ class Uniform(Distribution):
) -> None:
self.low, self.high = broadcast_all(low, high)
if isinstance(low, _Number) and isinstance(high, _Number):
if isinstance(low, Number) and isinstance(high, Number):
batch_shape = torch.Size()
else:
batch_shape = self.low.size()

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from torch import SymInt, Tensor
from torch.overrides import is_tensor_like
from torch.types import _dtype, _Number, Device, Number
from torch.types import _dtype, Device, Number
euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant
@@ -40,7 +40,7 @@ def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
ValueError: if any of the values is not a `Number` instance,
a `torch.*Tensor` instance, or an instance implementing __torch_function__
"""
if not all(is_tensor_like(v) or isinstance(v, _Number) for v in values):
if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
raise ValueError(
"Input arguments must all be instances of Number, "
"torch.Tensor or objects implementing __torch_function__."

View File

@@ -9,7 +9,7 @@ from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.multivariate_normal import _precision_to_scale_tril
from torch.distributions.utils import lazy_property
from torch.types import _Number, _size, Number
from torch.types import _size, Number
__all__ = ["Wishart"]
@@ -102,7 +102,7 @@ class Wishart(ExponentialFamily):
"scale_tril must be at least two-dimensional, with optional leading batch dimensions"
)
if isinstance(df, _Number):
if isinstance(df, Number):
batch_shape = torch.Size(param.shape[:-2])
self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
else:

View File

@@ -62,9 +62,6 @@ PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool]
# Meta-type for "numeric" things; matches our docs
Number: TypeAlias = Union[int, float, bool]
# tuple for isinstance(x, Number) checks.
# FIXME: refactor once python 3.9 support is dropped.
_Number = (int, float, bool)
FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]]