diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index f6d02f2737c6..75ea50d24860 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -68,7 +68,7 @@ class Distribution: continue # skip checking lazily-constructed args value = getattr(self, param) valid = constraint.check(value) - if not valid.all(): + if not torch._is_all_true(valid): raise ValueError( f"Expected parameter {param} " f"({type(value).__name__} of shape {tuple(value.shape)}) " @@ -313,7 +313,7 @@ class Distribution: return assert support is not None valid = support.check(value) - if not valid.all(): + if not torch._is_all_true(valid): raise ValueError( "Expected value argument " f"({type(value).__name__} of shape {tuple(value.shape)}) "