31 Commits

Author SHA1 Message Date
b13cd141b3 Add pyrefly suppressions (#164748)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the `project-excludes` field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
2025-10-07 17:31:18 +00:00
e2c9d8d641 Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)
Fix https://github.com/pytorch/pytorch/issues/145838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146845
Approved by: https://github.com/Skylion007
2025-06-24 15:41:34 +00:00
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotation to all `@property` and `@lazy_property` inside `torch.distributions` module.
- added simply type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with the corresponding counterpart.
- simplified `torch.Tensor` hints with plain `Tensor`, otherwise signatures can become very verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
e72e924eb5 Add correct typing annotations to rsample() for all distributions (#133516)
Fixes #133514
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133516
Approved by: https://github.com/Skylion007
2024-08-18 20:31:54 +00:00
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
5a1216bb2e [BE]: Update ruff to 0.4.1 (#124549)
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.

Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0

| Repository                                         | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7         | 251.8         | 351.1            | 274.9            |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
2024-04-21 14:06:23 +00:00
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
8ebbd5a89a Easier to understand event_dim computation (#81396)
Fixes #81254
Only easier to understand, not a real fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81396
Approved by: https://github.com/fritzo, https://github.com/kit1980
2022-11-16 04:38:32 +00:00
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
8548657ddb TransformedDistribution.icdf: Fix erroneous icdf ValueError (#71393)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/66946

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71393

Reviewed By: albanD

Differential Revision: D33810839

Pulled By: neerajprad

fbshipit-source-id: a86d8ce3b196b6b06e1466e4030fc549bc07d332
(cherry picked from commit 88743f53b2877a3ab43365c6cb5771856bf24967)
2022-01-28 00:34:08 +00:00
76b58dd9ae Fix distributions which don't properly honor validate_args=False (#53600)
Summary:
A number of derived distributions use base distributions in their
implementation.

We add what we hope is a comprehensive test whether all distributions
actually honor skipping validation of arguments in log_prob and then
fix the bugs we found. These bugs are particularly cumbersome in
PyTorch 1.8 and master when validate_args is turned on by default
In addition one might argue that validate_args is not performing
as well as it should when the default is not to validate but the
validation is turned on in instantiation.

Arguably, there is another set of bugs or at least inconsistencies
when validation of inputs does not prevent invalid indices in
sample validation (when with validation an IndexError is raised
in the test). We would encourage the implementors to be more
ambitious when validation is turned on and amend sample validation
to throw a ValueError for consistency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53600

Reviewed By: mrshenli

Differential Revision: D26928088

Pulled By: neerajprad

fbshipit-source-id: 52784a754da2faee1a922976e2142957c6c02e28
2021-03-10 13:16:32 -08:00
a347c747df Fix TransformedDistribution shaping logic (#50581)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50496
Fixes https://github.com/pytorch/pytorch/issues/34859
Fixes https://github.com/pytorch/pytorch/issues/21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.

## Summary of changes

- Fixes `TransformDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`
- Adds methods `Transform.forward_shape()`, `.inverse_shape()` which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapedTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for https://github.com/pytorch/pytorch/issues/34859

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
2021-01-25 16:34:12 -08:00
146721f1df Fix typing errors in the torch.distributions module (#45689)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
2020-10-12 10:29:45 -07:00
b1b00f329e Fix the flake8 linter
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16549

Reviewed By: bddppq

Differential Revision: D13877435

Pulled By: houseroad

fbshipit-source-id: dbe575ba3f6dd30d27ac6aa5eec2eea025063540
2019-01-30 09:36:00 -08:00
c391c20063 Adding .expand method for TransformedDistribution (#11607)
Summary:
This PR:
 - adds a `.expand` method for `TransformedDistribution` along the lines of #11341.
 - uses this method to simplify `.expand` in distribution classes that subclass off of `TransformedDistribution`.
 - restores testing of `TransformedDistribution` fixtures.
 - fixes some bugs wherein we were not setting certain attributes in the expanded instances, and adds tests for `.mean` and `.variance` which use these attributes.

There are many cases where users directly use `TransformedDistribution` rather than subclassing off it. In such cases, it seems rather inconvenient to have to write a separate class just to define a `.expand` method. The default implementation should suffice in these cases.

cc. fritzo, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11607

Differential Revision: D9818225

Pulled By: soumith

fbshipit-source-id: 2c4b3812b9a03e6985278cfce0f9a127ce536f23
2018-09-14 07:55:33 -07:00
bca10ad706 Implementation of Weibull distribution (#9454)
Summary:
This implements the two-parameter Weibull distribution, with scale $\lambda$ and shape $k$ parameters as described on [Wikipedia](https://en.wikipedia.org/wiki/Weibull_distribution).

**Details**
- We implement as a transformed exponential distribution, as described [here](https://en.wikipedia.org/wiki/Weibull_distribution#Related_distributions).
- The `weibull_min` variance function in scipy does not yet support a vector of distributions, so our unit test uses a scalar distribution instead of a vector.

Example of the bug:

```
>>> sp.stats.expon(np.array([0.5, 1, 2])).var() # fine
array([1., 1., 1.])
>>> sp.stats.weibull_min(c=np.array([0.5, 1, 2])).var() # buggy
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py", line 490, in var
    return self.dist.var(*self.args, **self.kwds)
  File "/usr/local/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py", line 1242, in var
    res = self.stats(*args, **kwds)
  File "/usr/local/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py", line 1038, in stats
    if np.isinf(mu):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9454

Differential Revision: D8863574

Pulled By: SsnL

fbshipit-source-id: 1ad3e175b469eee2b6af98e7b379ea170d3d9787
2018-07-24 20:40:15 -07:00
9d88ff7d0d Add half cauchy, half normal distributions (#8411) 2018-06-14 10:28:42 +02:00
52368f25cc Example for Transformed Distribution (#8011) 2018-06-01 16:23:57 +02:00
98c24fae6b Fix broadcasting error in LogNormal and TransformedDistribution (#7269) 2018-05-03 23:03:51 -04:00
b2da9fd220 [distributions] Rename .params to .arg_constraints, fix logic (#5989) 2018-03-25 15:24:32 +02:00
7f864bbe52 Fixed distribution constraints and added some test cases for distributions parameter check (#5358) 2018-03-15 23:11:20 +01:00
a4d0a74cee Ensure Distribution.sample() result is detached (#5086) 2018-02-14 01:32:11 +01:00
b608ea9178 Fix sign error in TransformedDistribution.cdf() and .icdf() (#5172) 2018-02-12 11:45:22 +01:00
011941087a Implementation of the cumulative distribution function and its inverse (#5079) 2018-02-07 16:10:19 +01:00
ca5071d072 Support multivariate TransformedDistributions (#4937) 2018-01-31 18:32:24 +01:00
967bceb16b Implement Transforms (#4771) 2018-01-28 21:17:16 +01:00