39 Commits

Author SHA1 Message Date
b13cd141b3 Add pyrefly suppressions (#164748)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the `project-excludes` field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
2025-10-07 17:31:18 +00:00
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotation to all `@property` and `@lazy_property` inside `torch.distributions` module.
- added simply type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with the corresponding counterpart.
- simplified `torch.Tensor` hints with plain `Tensor`, otherwise signatures can become very verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
e08577aec5 Spelling fix (#108490)
Fixes spelling mistake: non-deterinistic -> non-deterministic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108490
Approved by: https://github.com/ezyang
2023-09-04 16:59:35 +00:00
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
40576bceaf Add mode property to distributions. (#76690)
This PR fixes #69466 and introduces some other minor changes. Tests are somewhat more involved because a reference implementation in `scipy` is not available; tests proceed differently for discrete and continuous distributions.

For continuous distributions, we evaluate the gradient of the `log_prob` at the mode. Tests pass if the gradient is zero OR (the mode is at the boundary of the support of the distribution AND the `log_prob` decreases as we move away from the boundary to the interior of the support).

For discrete distributions, the notion of a gradient is not well defined. We thus "look" ahead and behind one step (e.g. if the mode of a Poisson distribution is 9, we consider 8 and 10). If the step ahead/behind is still within the support of the distribution, we assert that the `log_prob` is smaller than at the mode.

For one-hot encoded distributions (currently just `OneHotCategorical`), we evaluate the underlying mode (i.e. encoded as an integral tensor), "advance" by one label to get another sample that should have lower probability using `other = (mode + 1) % event_size` and re-encode as one-hot. The resultant `other` sample should have lower probability than the mode.

Furthermore, Gamma, half Cauchy, and half normal distributions have their support changed from positive to nonnegative. This change is necessary because the mode of the "half" distributions is zero, and the mode of the gamma distribution is zero for `concentration <= 1`.

cc @fritzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76690
Approved by: https://github.com/neerajprad
2022-05-11 18:26:56 +00:00
35f1617001 Implement Entropy methods for Binomial and Multinomial distributions (#67609)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60866.

Because it seems https://github.com/pytorch/pytorch/pull/61719 shows no response for a long time, I made this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67609

Reviewed By: malfet

Differential Revision: D32310866

Pulled By: mruberry

fbshipit-source-id: b3a8dde452f448e5981f5405f5f925f860b0d84f
2021-11-11 09:16:28 -08:00
a347c747df Fix TransformedDistribution shaping logic (#50581)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50496
Fixes https://github.com/pytorch/pytorch/issues/34859
Fixes https://github.com/pytorch/pytorch/issues/21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.

## Summary of changes

- Fixes `TransformDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`
- Adds methods `Transform.forward_shape()`, `.inverse_shape()` which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapedTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for https://github.com/pytorch/pytorch/issues/34859

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
2021-01-25 16:34:12 -08:00
21c2542b6a Independent constraint (#50547)
Summary:
Addresses https://github.com/pytorch/pytorch/issues/50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.
- Adds a `constraints.independent` and replaces `real_vector` with `independent(real, 1)`. (this pattern has long been used in Pyro)
- Adds an `.event_dim` attribute to all constraints.
- Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect).
- Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`.
- Fixes constraints for a number of distributions.

## Tested
- added a new check to the constraints tests
- added a new check for `.event_dim`

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50547

Reviewed By: VitalyFedyunin

Differential Revision: D25918330

Pulled By: neerajprad

fbshipit-source-id: a648c3de3e8704f70f445c0f1c39f2593c8c74db
2021-01-21 18:42:45 -08:00
9202c44379 Fix error in Binomial to retain lazy logit initialization (#46055)
Summary:
Some internal tests were sporadically failing for https://github.com/pytorch/pytorch/pull/45648. The cause of this is a bug in `Binomial.__init__` that references the lazy `logits` attribute and sets it when not needed. This cleans up the `is_scalar` logic too which isn't needed given that `broadcast_all` will convert `Number` to a `tensor`.

The reason for the flakiness is the mutation of the params dict by the first test, which is fixed by doing a shallow copy. It will be better to convert this into a pytest parameterized test once https://github.com/pytorch/pytorch/pull/45648 is merged.

cc. fritzo, ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46055

Reviewed By: agolynski

Differential Revision: D24221151

Pulled By: neerajprad

fbshipit-source-id: 15aae90a692ee6aed729c9f1d2d1b1388170a3c0
2020-10-12 10:56:06 -07:00
6fcabf619d [takeover] BTRS algorithm for fast/efficient binomial sampling (#36858)
Summary:
The original PR is https://github.com/pytorch/pytorch/pull/31278.

CC: ezyang jamestwebber fritzo zasdfgbnm

 ---

<!-- # This PR - CPU
In [1]: import torch; import torch.distributions as dist

In [2]: counts = torch.randint(10, 1000, [1000,1000])
   ...: p = 0.5 * torch.ones(1000, 1000)

In [3]: %timeit dist.binomial.Binomial(total_count=counts, probs=p).sample()
94.8 ms ± 911 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
-->
```
# This PR - GPU
In [1]: import torch; import torch.distributions as dist

In [2]: counts = torch.randint(10, 1000, [1000,1000]).cuda(); p = 0.5 * torch.ones(1000, 1000).cuda()

In [3]:  %timeit dist.binomial.Binomial(total_count=counts, probs=p).sample()
737 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# master (commit: 806f22b167c74897cf67c0828b528fa3e4e6d6de) - GPU
In [5]: counts = torch.randint(10, 1000, [1000,1000]).cuda(); p = 0.5 * torch.ones(1000, 1000).cuda()

In [6]: %timeit dist.binomial.Binomial(total_count=counts, probs=p).sample()
46.3 ms ± 76.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36858

Differential Revision: D21178367

Pulled By: ezyang

fbshipit-source-id: 7e7d6f463e35b07156d69bd7452040b2f9c2eb7a
2020-04-22 15:53:41 -07:00
071971476d Fix Binomimal overflow when logits is large (#20679)
Summary:
This PR fixes  #17843. In addition (test locally), this still maintains the continuity of log_prob which is addressed in https://github.com/pytorch/pytorch/pull/15962

cc neerajprad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20679

Differential Revision: D15413311

Pulled By: ezyang

fbshipit-source-id: 4fc0ca755ae6a85aa7deb2206dab675f82f9aa25
2019-05-20 06:52:29 -07:00
8e1e29124d Fix precision issue with expansion that prefers 'probs' over 'logits' (#18614)
Summary:
I have experienced that sometimes both were in `__dict__`, but it chose to copy `probs` which loses precision over `logits`. This is especially important when training (bayesian) neural networks or doing other type of optimization, since the loss is heavily affected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18614

Differential Revision: D14793486

Pulled By: ezyang

fbshipit-source-id: d4ff5e34fbb4021ea9de9f58af09a7de00d80a63
2019-04-05 13:07:01 -07:00
67f2039f4c Fix numerical stability in binomial.log_prob (#15962)
Summary:
This issue was discovered by fehiepsi in https://github.com/uber/pyro/issues/1706 with the `log_prob` computation for Binomial, ~and can be seen with `torch.float32` when we have a combination of low probability value and high `total_count` - a test is added to capture this (since scipy only uses float64, the comparison is done using relative tolerance).~

The problem is in the code that tries to pull out the minimum values amongst the logits (written by me earlier, presumably to avoid numerical instability issues), but it is not needed.

EDIT: After a few attempts, I have been unable to reliably show that the change is more numerically stable, and have removed my previous test which fails on linux. The reason is that the issue manifests itself when `total_count` is high and `probs` is very low. However, the precision of `lgamma` when `total_count` is high is bad enough to wash away any benefits. The justification for this still stands though - (a) simplifies code (removes the unnecessary bit), (b) is no worse than the previous implementation, (c) has better continuity behavior as observed by fehiepsi in the issue above.

cc. fehiepsi, alicanb, fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15962

Differential Revision: D13709541

Pulled By: ezyang

fbshipit-source-id: 596c6853b6e4d5fba42336afa168a665ab6fbde2
2019-01-17 10:18:37 -08:00
2431eac7c0 Ensure most Distribution methods are jittable (#11560)
Summary:
This adds tests in tests/test_distributions.py to ensure that all methods of `Distribution` objects are jittable.

I've replaced a few samplers with jittable versions:
- `.uniform_()` -> `torch.rand()`
- `.exponential_()` -> `-(-torch.rand()).log1p()`
- `.normal_()` -> `torch.normal(torch.zeros(...), torch.ones(...), ...)`

Some jit failures remain, and are marked in test_distributions.py
- `Cauchy` and `HalfCauchy` do not support sampling due to missing `.cauchy_()`
- `Binomial` does not support `.enumerate_support()` due to `arange` ignoring its first arg.
- `MultivariateNormal`, `LowRankMultivariateNormal` do not support `.mean`, `.entropy`

- [x] Currently some tests fail (I've skipped those) due to unavailability of `aten::uniform` and `aten::cauchy` in the jit. Can someone suggest how to add these? I tried to add declarations to `torch/csrc/ir.cpp` and `torch/csrc/passes/shape_analysis.cpp`, but that resulted in "Couldn't find operator" errors.
- [x] There are still lots of `TracerWarning`s that something doesn't match something. I'm not sure whether these are real.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11560

Differential Revision: D9816327

Pulled By: apaszke

fbshipit-source-id: 72ec998ea13fc4c76d1ed003d9502e0fbaf728b8
2018-09-13 19:55:01 -07:00
bbf54ea37c Ensure .enumerate_support() methods are jittable (#11542)
Summary:
This works around #11535 by avoiding `arange(n, out=x)` and `eye(n, out=x)` in `torch.distributions`. I've confirmed that the `.enumerate_support()` methods are now jittable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11542

Differential Revision: D9777805

Pulled By: apaszke

fbshipit-source-id: fa38f2f1acfc0a289f725fd8c92478573cfdbefb
2018-09-11 18:26:09 -07:00
80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
```

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` (or wrap over another distribution instance. e.g. `OneHotCategorical` uses a `Categorical` instance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them and construct new instances.
 - In the few cases where this is even possible, the resulting implementation would be inefficient since we will go through a lot of broadcasting and args validation logic in `__init__.py` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses `__init__.py` (using `__new__` and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate determining steps in many applications).

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00
b3b1e7624d Optional expand=True kwarg in distribution.enumerate_support (#11231)
Summary:
This adds an optional `expand=True` kwarg to the `distribution.expand_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`.
 - The default `expand=True` preserves the current behavior, whereas `expand=False` collapses the batch dimensions.

e.g.
```python
In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

In [48]: d.batch_shape
Out[48]: torch.Size([3])

In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]])

In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])

In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.]]])

In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])
```

**Motivation:**
 - Currently `enumerate_support` builds up tensors of size `support + batch_shape + event_shape`, but the values are *repeated* over the `batch_shape` (adding little in the way of information). This can lead to expensive matrix operations over large tensors when `batch_shape` is large (see, example above), often leading to OOM issues. We use `expand=False` in Pyro for message passing inference. e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the markov dependence, and allows for the possibility of using optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, will create tensors that scale exponentially in size with the length of the Markov chain.
 - We have been using this in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py) of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.

cc. apaszke, fritzo, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11231

Differential Revision: D9696290

Pulled By: soumith

fbshipit-source-id: c556f8ff374092e8366897ebe3f3b349538d9318
2018-09-06 21:39:42 -07:00
434e943b08 Fix to distribution.__repr__ with lazy attributes (#11263)
Summary:
`__repr__` currently fails for distributions with lazy attributes in PyTorch master, throwing a `KeyError`. This fixes the issue.

**Additionally:**
 - Added `logits` to `arg_constraints` for distributions that accept either `probs` or `logits`. This is both to have `__repr__` display the `logits` param when available, and to be able to do validation checks (e.g. NaN checks) when the logit parametrization is used. fritzo, alicanb - I think there were reasons why we had not done so in the first place, but I am unable to recall now. It passes all the tests, but let me know if there is something that I am missing at the moment.
 - There are certain distributions, e.g. `OneHotCategorical` which won't show any parameters because it uses a `categorical` instance under the hood and neither `logits` / `probs` in `arg_constraints` are present in the instance's `__dict__`. This isn't addressed in this PR.

cc. vishwakftw, fritzo, nadavbh12, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11263

Differential Revision: D9654959

Pulled By: apaszke

fbshipit-source-id: 16f5b20243fe8e2c13e9c528050d4df0b8ea6e45
2018-09-05 09:55:51 -07:00
f940af6293 Bag of Distributions doc fixes (#10894)
Summary:
- Added `__repr__` for Constraints and Transforms.
- Arguments passed to the constructor are now rendered with :attr:

Closes https://github.com/pytorch/pytorch/issues/10884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10894

Differential Revision: D9514161

Pulled By: apaszke

fbshipit-source-id: 4abf60335d876449f2b6477eb9655afed9d5b80b
2018-08-27 09:55:27 -07:00
3cbaa6b785 [ready] Clean up torch.distributions (#8046) 2018-06-02 16:54:53 +02:00
3964253f94 Allowing for vectorized counts in Binomial Distribution (#6720) 2018-04-26 15:53:01 +02:00
d564ecb4a5 Update docs with new tensor repr (#6454)
* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Link to torch.no_grad, etc, from torch doc

* Add dtype aliases to table

* regen docs again

* Tensor attributes stub page

* link to inplace sampling

* Link torch.dtype, device, and layout

* fix dots after nonfinite floats

* better layout docs
2018-04-21 07:35:37 -04:00
1c01eabd3c Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scri[ts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
b2da9fd220 [distributions] Rename .params to .arg_constraints, fix logic (#5989) 2018-03-25 15:24:32 +02:00
7f864bbe52 Fixed distribution constraints and added some test cases for distributions parameter check (#5358) 2018-03-15 23:11:20 +01:00
54b4cdeffa Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable'  and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
a4d0a74cee Ensure Distribution.sample() result is detached (#5086) 2018-02-14 01:32:11 +01:00
85a7e0fc41 Addition of ExponentialFamily (#4876) 2018-02-04 12:18:28 +01:00
20fbdb9a8b Adding mean, variance, stddev to distributions (#4923) 2018-01-31 00:26:32 +01:00
691c38d670 Remove windows linebreaks in various distributions files. (#4817) 2018-01-23 17:15:59 -05:00
b37aa2bf0e Ensure lazy evaluation for probs and logits (#4691) 2018-01-17 22:36:40 +01:00
3254eca8c8 Implement binomial distribution (#4658) 2018-01-16 21:39:05 +01:00