[Docs] Add Description of validate_args for torch.distributions (#152173)

Fixes #152165

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152173
Approved by: https://github.com/soulitzer
This commit is contained in:
Yuanhao Ji
2025-04-30 18:01:20 +00:00
committed by PyTorch MergeBot
parent 256c96332c
commit b027cb8f9e
3 changed files with 7 additions and 0 deletions

View File

@ -36,6 +36,7 @@ class Bernoulli(ExponentialFamily):
Args:
probs (Number, Tensor): the probability of sampling `1`
logits (Number, Tensor): the log-odds of sampling `1`
validate_args (bool, optional): whether to validate arguments, None by default
"""
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}

View File

@ -16,6 +16,11 @@ __all__ = ["Distribution"]
class Distribution:
r"""
Distribution is the abstract base class for probability distributions.
Args:
batch_shape (torch.Size): The shape over which parameters are batched.
event_shape (torch.Size): The shape of a single sample (without batching).
validate_args (bool, optional): Whether to validate arguments. Default: None.
"""
has_rsample = False

View File

@ -28,6 +28,7 @@ class Weibull(TransformedDistribution):
Args:
scale (float or Tensor): Scale parameter of distribution (lambda).
concentration (float or Tensor): Concentration parameter of distribution (k/shape).
validate_args (bool, optional): Whether to validate arguments. Default: None.
"""
arg_constraints = {