39 Commits

Author SHA1 Message Date
b13cd141b3 Add pyrefly suppressions (#164748)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the `project-excludes` field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
2025-10-07 17:31:18 +00:00
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
628acc4ace Dirichlet.mode: use dim= instead of axis= (#144402)
`axis=` is undocumented and will raise typing errors when #144197 is merged.

See: https://github.com/pytorch/pytorch/pull/144197#pullrequestreview-2537398866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144402
Approved by: https://github.com/Skylion007
2025-01-08 20:14:01 +00:00
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotation to all `@property` and `@lazy_property` inside `torch.distributions` module.
- added simply type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with the corresponding counterpart.
- simplified `torch.Tensor` hints with plain `Tensor`, otherwise signatures can become very verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
e72e924eb5 Add correct typing annotations to rsample() for all distributions (#133516)
Fixes #133514
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133516
Approved by: https://github.com/Skylion007
2024-08-18 20:31:54 +00:00
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
e08577aec5 Spelling fix (#108490)
Fixes spelling mistake: non-deterinistic -> non-deterministic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108490
Approved by: https://github.com/ezyang
2023-09-04 16:59:35 +00:00
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
e75f7994e1 Fix Dirichlet.log_prob() when x=0 and alpha=1 (#103605)
`Dirichlet.log_prob()` incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$.  The Dirichlet PDF is given by:
$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as `0 * log(0)`.

This PR implements the same algorithm that `scipy.stats.dirichlet` uses to avoid this behavior, namely `xlogy(alpha - 1, x)` instead of `(alpha - 1) * log(x)`.  It also adds a test case comparing the pytorch and scipy implementations for this specific case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103605
Approved by: https://github.com/albanD
2023-06-15 16:16:50 +00:00
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
40feeea500 Fix typo in dirichlet.py example (#82062)
### Description
<!-- What did you change and why was it needed? -->

### Issue
<!-- Link to Issue ticket or RFP -->

### Testing
<!-- How did you test your change? -->

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82062
Approved by: https://github.com/kit1980
2022-07-23 22:30:12 +00:00
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
40576bceaf Add mode property to distributions. (#76690)
This PR fixes #69466 and introduces some other minor changes. Tests are somewhat more involved because a reference implementation in `scipy` is not available; tests proceed differently for discrete and continuous distributions.

For continuous distributions, we evaluate the gradient of the `log_prob` at the mode. Tests pass if the gradient is zero OR (the mode is at the boundary of the support of the distribution AND the `log_prob` decreases as we move away from the boundary to the interior of the support).

For discrete distributions, the notion of a gradient is not well defined. We thus "look" ahead and behind one step (e.g. if the mode of a Poisson distribution is 9, we consider 8 and 10). If the step ahead/behind is still within the support of the distribution, we assert that the `log_prob` is smaller than at the mode.

For one-hot encoded distributions (currently just `OneHotCategorical`), we evaluate the underlying mode (i.e. encoded as an integral tensor), "advance" by one label to get another sample that should have lower probability using `other = (mode + 1) % event_size` and re-encode as one-hot. The resultant `other` sample should have lower probability than the mode.

Furthermore, Gamma, half Cauchy, and half normal distributions have their support changed from positive to nonnegative. This change is necessary because the mode of the "half" distributions is zero, and the mode of the gamma distribution is zero for `concentration <= 1`.

cc @fritzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76690
Approved by: https://github.com/neerajprad
2022-05-11 18:26:56 +00:00
6465793011 Fix Dirichlet.arg_constraints event_dim (#51369)
Summary:
This fix ensures
```py
Dirichlet.arg_constraints["concentration"].event_dim == 1
```
which was missed in https://github.com/pytorch/pytorch/issues/50547

## Tested
- [x] added a regression test, covering all distributions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51369

Reviewed By: H-Huang

Differential Revision: D26160644

Pulled By: neerajprad

fbshipit-source-id: 1bb44c79480a1f0052b0ef9d4605e750ab07bea1
2021-02-02 10:26:45 -08:00
9a153412fd Fix underflow issue with dirichlet sample (#17488)
Summary:
Addresses #15738, using fritzo's suggestion. This adds a `torch._sample_dirichlet` method in `Distributions.cpp` and `Distributions.cu`.
 - For CPU, this leads to no perf hit since all we do is to promote the `alpha` to double when getting the gamma samples (the gamma sampler anyways uses `accscalar_t`(double for CPU)) and cast it back to float32 on return.
 - I have added an analogous method for CUDA as well, but the default sampler for CUDA uses scalar_t for efficiency, so I have kept it as that. With this, I do not see the bias towards 1 as reported in #15738 with `float32`, but there is a spurious mode at 0.5, as would be expected. Users would need to explicitly use `float64` for GPU to not see the spurious mode at 0.5. (EDIT: see note below, it appears that the bias issue is still there for certain builds).

Added some tests and checked that there is no perf regression. My experience with C++ is very limited, so apologies in advance if I missed something basic. cc. ailzhang, fritzo, fmassa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17488

Differential Revision: D14410301

Pulled By: ezyang

fbshipit-source-id: 62b2f694b4642685eab06db96d74ce28e05c3992
2019-03-19 10:34:13 -07:00
6911ce19d7 Remove _finfo; replace _finfo usage with torch.finfo (#15165)
Summary:
This PR removes the usage of _finfo defined in torch.distributions.utils and changes the call sites
to use torch.finfo instead

Differential Revision: D13451936

Pulled By: soumith

fbshipit-source-id: 6dbda3a6179d9407bc3396bf1a2baf3e85bc4cf2
2018-12-13 14:30:27 -08:00
cda71e2600 Disallow scalar parameters in Dirichlet and Categorical (#11589)
Summary:
This adds a small check in `Dirichlet` and `Categorical` `__init__` methods to ensure that scalar parameters are not admissible.

**Motivation**
Currently, `Dirichlet` throws no error when provided with a scalar parameter, but if we `expand` a scalar instance, it inherits the empty event shape from the original instance and gives unexpected results.

The alternative to this check is to promote `event_shape` to be `torch.Size((1,))` if the original instance was a scalar, but that seems to add a bit more complexity (and changes the behavior of `expand` in that it would affect the `event_shape` as well as the `batch_shape` now). Does this seem reasonable? cc. alicanb, fritzo.

```python
In [4]: d = dist.Dirichlet(torch.tensor(1.))

In [5]: d.sample()
Out[5]: tensor(1.0000)

In [6]: d.log_prob(d.sample())
Out[6]: tensor(0.)

In [7]: e = d.expand([3])

In [8]: e.sample()
Out[8]: tensor([0.3953, 0.1797, 0.4250])  # interpreted as events

In [9]: e.log_prob(e.sample())
Out[9]: tensor(0.6931)  # wrongly summed out

In [10]: e.batch_shape
Out[10]: torch.Size([3])

In [11]: e.event_shape
Out[11]: torch.Size([])  # cannot be empty
```

Additionally, based on review comments, this removes `real_vector` constraint. This was only being used in `MultivariateNormal`, but I am happy to revert this if we want to keep it around for backwards compatibility.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11589

Differential Revision: D9818271

Pulled By: soumith

fbshipit-source-id: f9bbba90ed6f04e0b5bdfa169e70ca20b280fc74
2018-09-14 07:55:35 -07:00
80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
```

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` (or wrap over another distribution instance. e.g. `OneHotCategorical` uses a `Categorical` instance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them and construct new instances.
 - In the few cases where this is even possible, the resulting implementation would be inefficient since we will go through a lot of broadcasting and args validation logic in `__init__.py` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses `__init__.py` (using `__new__` and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate determining steps in many applications).

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00
f940af6293 Bag of Distributions doc fixes (#10894)
Summary:
- Added `__repr__` for Constraints and Transforms.
- Arguments passed to the constructor are now rendered with :attr:

Closes https://github.com/pytorch/pytorch/issues/10884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10894

Differential Revision: D9514161

Pulled By: apaszke

fbshipit-source-id: 4abf60335d876449f2b6477eb9655afed9d5b80b
2018-08-27 09:55:27 -07:00
3cbaa6b785 [ready] Clean up torch.distributions (#8046) 2018-06-02 16:54:53 +02:00
d564ecb4a5 Update docs with new tensor repr (#6454)
* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Link to torch.no_grad, etc, from torch doc

* Add dtype aliases to table

* regen docs again

* Tensor attributes stub page

* link to inplace sampling

* Link torch.dtype, device, and layout

* fix dots after nonfinite floats

* better layout docs
2018-04-21 07:35:37 -04:00
1d51dd8665 [distributions] Fix Independent.rsample() and add more tests (#6806) 2018-04-20 21:55:39 +02:00
1c01eabd3c Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scri[ts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
b2da9fd220 [distributions] Rename .params to .arg_constraints, fix logic (#5989) 2018-03-25 15:24:32 +02:00
7f864bbe52 Fixed distribution constraints and added some test cases for distributions parameter check (#5358) 2018-03-15 23:11:20 +01:00
54b4cdeffa Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable'  and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
48a3349c29 Delete dead Tensor code paths (#5417)
This deletes most of the dead Tensor code paths, including the TensorMethods cwrap and generic/Tensor.cpp.

This also moves the THNN.cwrap/.cpp generation to generate_code which can use ninja if installed.
2018-02-27 17:58:09 -05:00
85a7e0fc41 Addition of ExponentialFamily (#4876) 2018-02-04 12:18:28 +01:00
20fbdb9a8b Adding mean, variance, stddev to distributions (#4923) 2018-01-31 00:26:32 +01:00
8c2d35c754 Refactor distributions (#4688) 2018-01-17 11:58:08 +01:00
71b1120ba8 Fix bug in Dirichlet.rsample(); add tests (#4602)
* Fix bug in Dirichlet.rsample(); add tests

* Address review comments
2018-01-11 12:29:10 -05:00
8cff8e93d2 Add torch.distributions.utils._finfo for numerical stability (#4572)
* Add torch.distributions.utils.finfo

* Make _finfo private

* Address review comments

* Simplify _finfo() to key on Storage type
2018-01-10 21:42:47 -05:00
a3e91515de Declare constraints for distribution parameters and support (#4450) 2018-01-04 23:58:26 +01:00
35abc4efa2 Add low-precision digamma() and polygamma() functions (#4399) 2018-01-02 11:53:23 +01:00
0bc1505f34 Implement .entropy() methods for all distributions (#4268) 2017-12-20 14:06:01 +01:00
ee98e7a82e Implement Dirichlet and Beta distributions (#4117) 2017-12-18 19:11:37 +01:00