Moved .all() checks for distributions to _is_all_true (#145029)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145029
Approved by: https://github.com/Skylion007, https://github.com/zou3519
This commit is contained in:
chilli
2025-01-16 17:04:42 -08:00
committed by PyTorch MergeBot
parent 2bf772d1ba
commit 5e4cf3e6ad

View File

@ -68,7 +68,7 @@ class Distribution:
continue # skip checking lazily-constructed args
value = getattr(self, param)
valid = constraint.check(value)
if not valid.all():
if not torch._is_all_true(valid):
raise ValueError(
f"Expected parameter {param} "
f"({type(value).__name__} of shape {tuple(value.shape)}) "
@ -313,7 +313,7 @@ class Distribution:
return
assert support is not None
valid = support.check(value)
if not valid.all():
if not torch._is_all_true(valid):
raise ValueError(
"Expected value argument "
f"({type(value).__name__} of shape {tuple(value.shape)}) "