59 Commits

Author SHA1 Message Date
70925bdf82 [1/N] Use "is" in python type comparison (#165037)
It is generally recommended to use `is`/`is not` when comparing types, so this series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.
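A minimal sketch of the pattern being applied (hypothetical example, not taken from the diff):

```python
# `type(x) is int` checks exact type identity and cannot be fooled by a
# custom __eq__, which is why the linter prefers it over `==`.
def describe(value):
    if type(value) is int:    # preferred
        return "an int"
    if type(value) == float:  # flagged by the linter
        return "a float"
    return "something else"

print(describe(3), describe(2.5), describe("x"))  # an int a float something else
```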

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
2025-10-10 12:36:50 +00:00
a43c4c3972 [5/N] Apply ruff UP035 rule (#164423)
Continued code migration to enable ruff `UP035`. Most changes move `Callable` imports from `typing` to `collections.abc`.
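A typical hunk looks roughly like this (hypothetical before/after):

```python
# Before (flagged by UP035: `typing.Callable` is a deprecated alias)
# from typing import Callable

# After
from collections.abc import Callable


def apply(fn: Callable[[int], int], x: int) -> int:
    return fn(x)


print(apply(lambda n: n + 1, 41))  # 42
```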

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423
Approved by: https://github.com/ezyang
2025-10-02 07:31:11 +00:00
768d73f692 use torch.special.xlogy to implement x_log_x (#144220)
Fixes #144279

Using `x * x.log()` does not produce the correct value when `x = 0`: `log(0)` is `-inf`, and `0 * -inf` evaluates to `nan`.
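A small sketch of the difference:

```python
import torch

x = torch.tensor([0.0, 0.5, 1.0])
naive = x * x.log()                 # 0 * log(0) = 0 * -inf = nan
stable = torch.special.xlogy(x, x)  # defined to be 0 when x == 0

print(naive)   # tensor([    nan, -0.3466,  0.0000])
print(stable)  # tensor([ 0.0000, -0.3466,  0.0000])
```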

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144220
Approved by: https://github.com/Skylion007
2025-01-08 17:41:55 +00:00
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotations to all `@property` and `@lazy_property` members inside the `torch.distributions` module (see the sketch below).
- added a simple type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with their modern counterparts.
- simplified `torch.Tensor` hints to plain `Tensor`, since signatures otherwise become very verbose.
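A sketch of the kind of annotation this adds (hypothetical distribution class, not the actual diff):

```python
from torch import Tensor
from torch.distributions import Distribution
from torch.distributions.utils import lazy_property


class MyDistribution(Distribution):
    @property
    def mean(self) -> Tensor:  # plain `Tensor` instead of `torch.Tensor` keeps signatures short
        ...

    @lazy_property
    def _expensive_stat(self) -> Tensor:  # annotated lazy_property lets mypy infer Tensor here
        ...
```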

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
5e8e1d725a Remove some unused type ignores (round 1) (#142325)
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.

Having these `# type: ignore` linger around is not ideal for two reasons:

- They are verbose/ugly syntactically.
- They could hide genuine bugs in the future, if a refactoring introduces an error that the stale ignore then masks.

I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional", like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address those at a later point, and would eventually enable `warn_unused_ignores = True` in the mypy configuration, as discussed in that comment, to prevent accumulating more dead ignores going forward.

This PR should have no effect on runtime at all.
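For reference, the eventual configuration change discussed above would look roughly like this (a sketch, not part of this PR):

```ini
[mypy]
# Once the unused ignores are gone, fail on any `# type: ignore` that no
# longer suppresses an error, so dead ignores cannot accumulate again.
warn_unused_ignores = True
```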

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
2024-12-09 18:23:46 +00:00
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
5a1216bb2e [BE]: Update ruff to 0.4.1 (#124549)
Update ruff to 0.4.1.
This version fixes a lot of false negatives/false positives, is 20-40% faster, and has various other bug fixes.

Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0

| Repository                                         | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7         | 251.8         | 351.1            | 274.9            |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
2024-04-21 14:06:23 +00:00
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.
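The corresponding mypy.ini change is roughly (sketch):

```ini
[mypy]
# dmypy always runs with this behavior; setting it explicitly keeps
# regular mypy and dmypy consistent.
local_partial_types = True
```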

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow-up in the pyupgrade series, converting more strings to f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`
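A representative conversion produced by flynt (hypothetical example):

```python
count, name = 3, "bias"

# Before
msg = "found {} tensors named {}".format(count, name)

# After
msg_f = f"found {count} tensors named {name}"

assert msg == msg_f
```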

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2/3 compatibility libraries [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future), along with `torch._six`. We only support Python 3.8+ now, so it's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: it removes explicit inheritance from `object` and removes unused `__future__` imports.
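Representative changes (hypothetical before/after):

```python
# Before: Python 2 compatibility boilerplate
# from __future__ import division, print_function
#
# class Accumulator(object):
#     pass

# After: both are no-ops on Python 3, so pyupgrade drops them
class Accumulator:
    pass
```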

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
cf2f552cd8 Add __all__ to torch.{fx, distributed, backends} submodules (#85079)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85079
Approved by: https://github.com/rohan-varma
2022-09-20 12:51:08 +00:00
27fc9fcd13 More stable computation of KL between two Bernoulli distributions (#79944)
Fixes #20164

@neerajprad here is the new PR with the updated master.
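For context, KL between two Bernoulli distributions has the closed form KL(Ber(p) || Ber(q)) = p*log(p/q) + (1-p)*log((1-p)/(1-q)); a stability-minded sketch (not necessarily the exact implementation this PR lands) evaluates it with `torch.special.xlogy` so that p = 0 and p = 1 stay finite:

```python
import torch
from torch.special import xlogy


def bernoulli_kl(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # xlogy(0, y) == 0 for y >= 0, so the p == 0 and p == 1 endpoints do not produce nan
    return xlogy(p, p / q) + xlogy(1 - p, (1 - p) / (1 - q))


p = torch.tensor([0.0, 0.3, 1.0])
q = torch.tensor([0.5, 0.5, 0.5])
print(bernoulli_kl(p, q))  # tensor([0.6931, 0.0823, 0.6931])
```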

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79944
Approved by: https://github.com/neerajprad
2022-06-27 21:31:45 +00:00
dd620c4575 add type annotation to distributions.kl_divergence (#78432)
Fixes #78431

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78432
Approved by: https://github.com/fritzo, https://github.com/ejguan
2022-06-10 13:39:20 +00:00
f92cddd890 Removed direct doc formatting
Fixes #76034

This does not make Python remove all `__doc__` attributes, because in some places `__doc__` is assigned a string directly.

Example:
04b3313379/torch/nn/modules/conv.py (L174-L233)

Since there are quite a few of these, I will add all of them to this PR later. (Basically, a lot of docstrings will still persist even with `-OO` enabled.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76619
Approved by: https://github.com/albanD
2022-05-02 14:14:33 +00:00
c837caf5c5 Adding details to kl.py (#72845)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/72765.

- [x] Improved `NotImplementedError` verbosity.
- [x] Automate the docstring generation process

## Improved `NotImplementedError` verbosity
### Code
```python
import torch

dist = torch.distributions

torch_normal = dist.Normal(loc=0.0, scale=1.0)
torch_mixture = dist.MixtureSameFamily(
    dist.Categorical(torch.ones(5,)
    ),
    dist.Normal(torch.randn(5,), torch.rand(5,)),
)

dist.kl_divergence(torch_normal, torch_mixture)
```
#### Output before this PR
```python
NotImplementedError:
```
#### Output after this PR
```python
NotImplementedError: No KL(p || q) is implemented for p type Normal and q type MixtureSameFamily
```

## Automate the docstring generation process
### Docstring before this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
```
### Docstring after this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    KL divergence is currently implemented for the following distribution pairs:
        * :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Bernoulli`
        * :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Poisson`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Binomial` and :class:`~torch.distributions.Binomial`
        * :class:`~torch.distributions.Categorical` and :class:`~torch.distributions.Categorical`
        * :class:`~torch.distributions.Cauchy` and :class:`~torch.distributions.Cauchy`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Dirichlet` and :class:`~torch.distributions.Dirichlet`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.ExponentialFamily` and :class:`~torch.distributions.ExponentialFamily`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Geometric` and :class:`~torch.distributions.Geometric`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.HalfNormal` and :class:`~torch.distributions.HalfNormal`
        * :class:`~torch.distributions.Independent` and :class:`~torch.distributions.Independent`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Laplace`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
        * :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
        * :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
        * :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Laplace`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.OneHotCategorical` and :class:`~torch.distributions.OneHotCategorical`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Bernoulli`
        * :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Binomial`
        * :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Poisson`
        * :class:`~torch.distributions.TransformedDistribution` and :class:`~torch.distributions.TransformedDistribution`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Uniform`
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72845

Reviewed By: mikaylagawarecki

Differential Revision: D34344551

Pulled By: soulitzer

fbshipit-source-id: 7a603613a2f56f71138d56399c7c521e2238e8c5
(cherry picked from commit 6b2a51c796cd8a16551d629ca368360eec34faef)
2022-02-19 06:33:08 +00:00
cafcf599d0 Deprecate torch.triangular_solve (#63570)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63570

There is a use of `at::triangular_solve_out` in the file
`torch/csrc/jit/tensorexpr/external_functions.cpp` that I have not dared
to move to `at::linalg_solve_triangular_out`.

**Deprecation note:**

This PR deprecates the `torch.triangular_solve` function in favor of
`torch.linalg.solve_triangular`. An upgrade guide is added to the
documentation for `torch.triangular_solve`.

Note that it DOES NOT remove `torch.triangular_solve`, but
`torch.triangular_solve` will be removed in a future PyTorch release.
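A migration sketch (note that the two APIs take the matrix and right-hand side in opposite order, and the deprecated call also returns a copy of the coefficient matrix):

```python
import torch

A = torch.randn(3, 3).triu() + 3 * torch.eye(3)  # well-conditioned upper-triangular matrix
B = torch.randn(3, 2)

# Deprecated
X_old, _ = torch.triangular_solve(B, A, upper=True)

# Replacement
X_new = torch.linalg.solve_triangular(A, B, upper=True)

assert torch.allclose(X_old, X_new)
```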

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D32618035

Pulled By: anjali411

fbshipit-source-id: 0bfb48eeb6d96eff3e96e8a14818268cceb93c83
2021-12-02 13:24:55 -08:00
deaf745aee Add kl divergence between normal and laplace distribution. (#68807)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/68746
![KL_normal_laplace](https://user-images.githubusercontent.com/35850237/143008244-f304cee1-9583-4de1-b0d0-5751ebdb8188.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68807

Reviewed By: H-Huang

Differential Revision: D32750391

Pulled By: neerajprad

fbshipit-source-id: 129e6ef60d6e244d0d6b02b3944bfd5d8b06edcb
2021-12-01 10:22:08 -08:00
0974215c4d Prefer mT and mH over transpose(-2, -1) and transpose(-2, -1).conj() (#64181)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64181

This PR replaces all the calls to:
- `transpose(-2, -1)` or `transpose(-1, -2)` by `mT()` in C++ and `mT` in Python
- `conj().transpose(-2, -1)` or `transpose(-2, -1).conj()` or `conj().transpose(-1, -2)` or `transpose(-1, -2).conj()` by `mH()` in C++ and `mH` in Python.

It also simplifies two pieces of code and fixes one bug where a pair
of parentheses was missing in the function `make_symmetric_matrices`.
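In Python the substitution is simply (sketch):

```python
import torch

A = torch.randn(2, 3, 4, dtype=torch.complex64)

assert torch.equal(A.mT, A.transpose(-2, -1))
assert torch.equal(A.mH, A.transpose(-2, -1).conj())
```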

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D31692896

Pulled By: anjali411

fbshipit-source-id: e9112c42343663d442dc5bd53ff2b492094b434a
2021-10-18 13:02:25 -07:00
75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00
a347c747df Fix TransformedDistribution shaping logic (#50581)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50496
Fixes https://github.com/pytorch/pytorch/issues/34859
Fixes https://github.com/pytorch/pytorch/issues/21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.
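A small sketch of the new pieces (illustrative usage, not taken from the PR's tests):

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform, IndependentTransform

# Reinterpret the last batch dimension of the transform as an event dimension,
# analogous to distributions.Independent.
transform = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
dist = TransformedDistribution(Normal(torch.zeros(4), torch.ones(4)), transform)

print(dist.event_shape)                 # torch.Size([4])
print(transform.forward_shape((3, 4)))  # torch.Size([3, 4])
```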

## Summary of changes

- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`
- Adds methods `Transform.forward_shape()`, `.inverse_shape()` which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb's flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapeTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for https://github.com/pytorch/pytorch/issues/34859

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
2021-01-25 16:34:12 -08:00
47db191f0c Implement Kumaraswamy Distribution (#48285)
Summary:
This PR implements the Kumaraswamy distribution.

cc: fritzo alicanb sdaulton

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48285

Reviewed By: ejguan

Differential Revision: D25221015

Pulled By: ezyang

fbshipit-source-id: e621b25a9c75671bdfc94af145a4d9de2f07231e
2020-12-02 07:46:45 -08:00
789e935304 Annotate torch.nn.cpp (#46490)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/46489

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46490

Reviewed By: zhangguanheng66

Differential Revision: D24509519

Pulled By: ezyang

fbshipit-source-id: edffd32ab2ac17ae4bbd44826b71f5cb9f1da1c5
2020-10-23 17:40:32 -07:00
146721f1df Fix typing errors in the torch.distributions module (#45689)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
2020-10-12 10:29:45 -07:00
4e365b9cd1 [Distribution] Implement kl divergence for Cauchy distribution (#36477)
Summary:
Implement the closed-form KL divergence between Cauchy distributions.

### Reference:
https://arxiv.org/pdf/1905.10965.pdf
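A sketch of the paper's closed form, checked against the newly registered pair (the standalone helper is illustrative, not the exact code added):

```python
import torch
from torch.distributions import Cauchy, kl_divergence


def cauchy_kl(p: Cauchy, q: Cauchy) -> torch.Tensor:
    # KL(p || q) = log( ((gamma_p + gamma_q)^2 + (x0_p - x0_q)^2) / (4 * gamma_p * gamma_q) )
    return (((p.scale + q.scale) ** 2 + (p.loc - q.loc) ** 2) / (4 * p.scale * q.scale)).log()


p, q = Cauchy(0.0, 1.0), Cauchy(1.0, 2.0)
print(cauchy_kl(p, q), kl_divergence(p, q))  # both approximately 0.2231
```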
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36477

Differential Revision: D21134487

Pulled By: ezyang

fbshipit-source-id: 69d2cc2237aa931f224c3807baee7c63f91583fc
2020-04-20 13:27:11 -07:00
a74fbea345 Continuous bernoulli distribution (take 2) (#34619)
Summary:
We recently had a NeurIPS paper (https://arxiv.org/abs/1907.06845 and https://papers.nips.cc/paper/9484-the-continuous-bernoulli-fixing-a-pervasive-error-in-variational-autoencoders) in which we introduce a new [0,1]-supported distribution: the continuous Bernoulli. This pull request implements this distribution in PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34619

Differential Revision: D20403123

Pulled By: ngimel

fbshipit-source-id: d807c7d0d372c6daf6cb6ef09df178bc7491abb2
2020-03-12 11:53:18 -07:00
75309b45f3 explicitly provide memory format when calling to clone() at Indexing.cpp
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28660

Test Plan: Imported from OSS

Differential Revision: D18333346

Pulled By: ifedan

fbshipit-source-id: 06590205d883a5096388a4ae318389244130972d
2019-11-07 05:38:32 -08:00
fb40e58f24 Remove deprecated tensor constructors in torch.distributions (#19979)
Summary:
This removes the deprecated `tensor.new_*` constructors (see #16770) from the `torch.distributions` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19979

Differential Revision: D15195618

Pulled By: soumith

fbshipit-source-id: 46b519bfd32017265e90bd5c53f12cfe4a138021
2019-05-02 20:45:02 -07:00
173f224570 Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned
on for Facebook by default.  "Sure, why not."

I had to noqa a number of imports in __init__.  Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it.  Left for future work.

Be careful!  flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments.  flake8-3 will
report an import unused; flake8-2 will not.  For now, I just
noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
2019-03-30 09:01:17 -07:00
291746f110 Rename trtrs to triangular_solve (#18213)
Summary:
Changelog:
- Renames `trtrs` to `triangular_solve` to remain consistent with `cholesky_solve` and `solve`.
- Rename all tests, fix callsites
- Create a tentative alias for `triangular_solve` under the name `trtrs`, and add a deprecation warning to not promote usage.
- Move `isnan` to _torch_docs.py
- Remove unnecessary imports
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18213

Differential Revision: D14566902

Pulled By: ezyang

fbshipit-source-id: 544f57c29477df391bacd5de700bed1add456d3f
2019-03-21 14:27:21 -07:00
a519217ee7 Add batched version of trtrs (#18025)
Summary:
- Remove single batch TH/THC implementations
- Remove `_batch_trtrs_lower` from `multivariate_normal`
- Add tests for batched behavior
- Modify trtrs_backward to accommodate for batched case
- Modify docs

In a future PR, this will be renamed to `triangular_solve`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18025

Differential Revision: D14523004

Pulled By: ifedan

fbshipit-source-id: 11c6a967d107f969b60e5a5c73ce6bb8099ebbe1
2019-03-20 11:11:32 -07:00
8045b3eb14 Registering of kl-divergence for independent distribution (#17681)
Summary:
This addresses issue https://github.com/pytorch/pytorch/issues/13545 and implements the proposed fix together with a single test.
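For reference, registrations in `kl.py` use the `register_kl` decorator; a sketch of what the Independent/Independent case looks like (illustrative, not the exact code merged):

```python
from torch.distributions import Independent, kl_divergence
from torch.distributions.kl import register_kl


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # KL of the base distributions, then sum over the reinterpreted event dims
    result = kl_divergence(p.base_dist, q.base_dist)
    if p.reinterpreted_batch_ndims:
        result = result.sum(dim=list(range(-p.reinterpreted_batch_ndims, 0)))
    return result
```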
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17681

Differential Revision: D14360161

Pulled By: ezyang

fbshipit-source-id: 427afc88e9054b5b0dc39ebbab1087b990695ea5
2019-03-11 08:10:16 -07:00
2681af1c8a Remove redundant wrappers in torch.distributions (#16807)
Summary:
Changelog:
- Remove torch.distributions.multivariate_normal._batch_diag: the same functionality is provided by torch.diagonal
- Remove torch.distributions.lowrank_multivariate_normal._batch_vector_diag: the same functionality is provided by torch.diag_embed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16807

Differential Revision: D13985550

Pulled By: soumith

fbshipit-source-id: 25c7d00c52ff7f85e431134e9ce0d5dda453667b
2019-02-07 01:13:55 -08:00
67308a9323 Fix expanded mvn and lowrankmvn (#14557)
Summary:
This PR fixes a slowness issue with expanded MVN.

A notebook showing the problem is [here](https://gist.github.com/fehiepsi/b15ac2978f1045d6d96b1d35b640d742). Basically, MVN's sample and log_prob involve expensive computations based on `cholesky` and `trtrs`. We can save a lot of computation by caching the unbroadcasted version of `scale_tril` (or `cov_diag`, `cov_factor` in low-rank MVN).
When expanding, this cached tensor should not be expanded together with other arguments.

Ref: https://github.com/uber/pyro/issues/1586

cc neerajprad fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14557

Differential Revision: D13277408

Pulled By: soumith

fbshipit-source-id: a6b16f999b008d5da148ccf519b7f32d9c6a5351
2018-11-30 10:49:13 -08:00
9646d68962 support broadcasting in _kl_categorical_categorical (#10533)
Summary:
Support broadcasting in _kl_categorical_categorical

This makes it possible to do:
```
import torch.distributions as dist
import torch
p_dist = dist.Categorical(torch.ones(1,10))
q_dist = dist.Categorical(torch.ones(100,10))
dist.kl_divergence(p_dist, q_dist)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10533

Differential Revision: D9341252

Pulled By: soumith

fbshipit-source-id: 34575b30160b43b6c9e4c3070dd7ef07c00ff5d7
2018-08-15 12:40:17 -07:00
9525925119 Low rank multivariate normal (#8635)
Summary:
This pull request implements a low-rank multivariate normal distribution whose covariance matrix has the form `W @ W.T + D`. Here D is a diagonal matrix and W has shape n x m, where m << n. It uses the matrix determinant lemma and the Woodbury matrix identity to save computational cost.
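A usage sketch of the resulting distribution (hypothetical sizes):

```python
import torch
from torch.distributions import LowRankMultivariateNormal

n, m = 1000, 5                  # full dimension n, rank m << n
loc = torch.zeros(n)
cov_factor = torch.randn(n, m)  # W
cov_diag = torch.ones(n)        # diagonal of D

# Covariance is W @ W.T + D, but sampling and log_prob never build the n x n matrix.
dist = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
x = dist.rsample()
print(x.shape, dist.log_prob(x).shape)  # torch.Size([1000]) torch.Size([])
```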

Along the way, I also revised the MultivariateNormal distribution a bit. Here are the other changes:
+ `torch.trtrs` works with CUDA tensors, so I used it instead of `torch.inverse`.
+ Use `torch.matmul` instead of `torch.bmm` in `_batch_mv`. The former is faster and simpler.
+ Use `torch.diagonal` for `_batch_diag`
+ Reimplement `_batch_mahalanobis` based on `_batch_trtrs_lower`.
+ Use trtrs to compute term2 of KL.
+ `variance` relies on `scale_tril` instead of `covariance_matrix`

TODO:
- [x] Resolve the fail at `_gradcheck_log_prob`
- [x] Add test for KL

cc fritzo stepelu apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8635

Differential Revision: D8951893

Pulled By: ezyang

fbshipit-source-id: 488ee3db6071150c33a1fb6624f3cfd9b52760c3
2018-07-23 10:10:53 -07:00
27455e9c78 Use _six for inf and nan (#9500)
Summary:
Things like `float('inf')` are actually quite expensive.
```py
In [1]: import math

In [2]: %timeit -n 200 math.inf
49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)

In [3]: %timeit -n 200 float('inf')
194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500

Reviewed By: soumith

Differential Revision: D8876229

Pulled By: SsnL

fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7
2018-07-18 10:40:29 -07:00
9d88ff7d0d Add half cauchy, half normal distributions (#8411) 2018-06-14 10:28:42 +02:00
3cbaa6b785 [ready] Clean up torch.distributions (#8046) 2018-06-02 16:54:53 +02:00
3964253f94 Allowing for vectorized counts in Binomial Distribution (#6720) 2018-04-26 15:53:01 +02:00
1c01eabd3c Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scripts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
3497f0207c [distributions] KL-Divergence for Multivariate Normal (#6172) 2018-04-04 13:19:47 +02:00
7cbbc0bc74 Implementation of the logistic-normal distribution (#5547) 2018-03-22 00:32:14 +01:00
54b4cdeffa Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable'  and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
3b63e552f9 Fix test_distributions when WITH_SCALARS. (#5121)
* Fix test_distributions when WITH_SCALARS.

* Use SCALAR_SHAPE in test, use self.scale in AffineTransform.

* Handle device correctly for scalars.

* Fix one hot categorical.

* Fix relaxed categorical.

* Add a new_tensor instance method to Variable that takes only data.

This is to work around the legacy problems of new, where e.g.
new(5) will give you an unfilled tensor rather than a scalar.

* Fix cuda scalar code path.

* Remove double return.

* Work around lack of WITH_SCALARS.

* Use tensor_new.
2018-02-09 11:01:13 -05:00
85a7e0fc41 Addition of ExponentialFamily (#4876) 2018-02-04 12:18:28 +01:00
423677bacc Add KL-divergence for Categorical and OneHotCategorical and stronger tests (#4961) 2018-02-03 12:47:13 +01:00