torch.distributions: replace numbers.Number with torch.types.Number. (#145086)

Fixes #144788 (partial)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145086
Approved by: https://github.com/malfet
This commit is contained in:
Randolf Scholz
2025-01-27 20:24:52 +00:00
committed by PyTorch MergeBot
parent 2f8ad8f4b9
commit 64cd81712d
19 changed files with 55 additions and 77 deletions

View File

@ -1,12 +1,10 @@
# mypy: allow-untyped-defs
from numbers import Number
import torch
from torch import nan, Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _size
from torch.types import _Number, _size
__all__ = ["Uniform"]
@ -54,7 +52,7 @@ class Uniform(Distribution):
def __init__(self, low, high, validate_args=None):
self.low, self.high = broadcast_all(low, high)
if isinstance(low, Number) and isinstance(high, Number):
if isinstance(low, _Number) and isinstance(high, _Number):
batch_shape = torch.Size()
else:
batch_shape = self.low.size()