mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch.distributions
: replace numbers.Number
with torch.types.Number
. (#145086)
Fixes #144788 (partial) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145086 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
2f8ad8f4b9
commit
64cd81712d
@ -1,12 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from numbers import Number
|
||||
|
||||
import torch
|
||||
from torch import nan, Tensor
|
||||
from torch.distributions import constraints
|
||||
from torch.distributions.distribution import Distribution
|
||||
from torch.distributions.utils import broadcast_all
|
||||
from torch.types import _size
|
||||
from torch.types import _Number, _size
|
||||
|
||||
|
||||
__all__ = ["Uniform"]
|
||||
@ -54,7 +52,7 @@ class Uniform(Distribution):
|
||||
def __init__(self, low, high, validate_args=None):
|
||||
self.low, self.high = broadcast_all(low, high)
|
||||
|
||||
if isinstance(low, Number) and isinstance(high, Number):
|
||||
if isinstance(low, _Number) and isinstance(high, _Number):
|
||||
batch_shape = torch.Size()
|
||||
else:
|
||||
batch_shape = self.low.size()
|
||||
|
Reference in New Issue
Block a user