mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Moved .all() checks for distributions to _is_all_true (#145029)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145029 Approved by: https://github.com/Skylion007, https://github.com/zou3519
This commit is contained in:
@ -68,7 +68,7 @@ class Distribution:
|
||||
continue # skip checking lazily-constructed args
|
||||
value = getattr(self, param)
|
||||
valid = constraint.check(value)
|
||||
if not valid.all():
|
||||
if not torch._is_all_true(valid):
|
||||
raise ValueError(
|
||||
f"Expected parameter {param} "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
@ -313,7 +313,7 @@ class Distribution:
|
||||
return
|
||||
assert support is not None
|
||||
valid = support.check(value)
|
||||
if not valid.all():
|
||||
if not torch._is_all_true(valid):
|
||||
raise ValueError(
|
||||
"Expected value argument "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
|
Reference in New Issue
Block a user