344 Commits

Author SHA1 Message Date
70925bdf82 [1/N] Use "is" in python type comparison (#165037)
It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
2025-10-10 12:36:50 +00:00
b13cd141b3 Add pyrefly suppressions (#164748)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the `project-excludes` field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
2025-10-07 17:31:18 +00:00
a43c4c3972 [5/N] Apply ruff UP035 rule (#164423)
Continued code migration to enable ruff `UP035`. Most changes are about moving `Callable` from `typing` to `from collections.abc`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423
Approved by: https://github.com/ezyang
2025-10-02 07:31:11 +00:00
4cc8b60d1b [BE][1/16] fix typos in torch/ (#156311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156311
Approved by: https://github.com/albanD
2025-07-09 11:02:22 +00:00
e2c9d8d641 Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)
Fix https://github.com/pytorch/pytorch/issues/145838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146845
Approved by: https://github.com/Skylion007
2025-06-24 15:41:34 +00:00
ba3f91af97 Type hints for distributions/utils (#154712)
Fixes #144196
Part of #144219

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154712
Approved by: https://github.com/Skylion007
2025-05-30 15:50:31 +00:00
839c9c6156 Use property instead of ClassVar for Uniform.arg_constraints and Wishart.arg_constraints (#154361)
Fixes #154355

For these two distributions, the constraints depend on the actual values, and so `arg_constraints` cannot be a `ClassVar`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154361
Approved by: https://github.com/Skylion007
2025-05-26 17:48:28 +00:00
a54bf43baa Fix support of MixtureSameFamily [bugfix]. (#151317)
Fixes https://github.com/pyro-ppl/pyro/issues/3419 which is actually a `torch` bug that can be replicated by the below code:

```
from torch import rand
from torch.distributions import MixtureSameFamily, Categorical, Binomial

max_count = 20
probs = rand(10, 5)
binom_probs = rand(10, 5)

d = MixtureSameFamily(Categorical(probs=probs), Binomial(max_count, binom_probs))
d.log_prob(d.sample())
```

which results in:

```
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    d.log_prob(d.sample())
  File "pytorch\torch\distributions\mixture_same_family.py", line 168, in log_prob
    self._validate_sample(x)
  File "pytorch\torch\distributions\distribution.py", line 315, in _validate_sample
    valid = support.check(value)
            ^^^^^^^^^^^^^^^^^^^^
  File "pytorch\torch\distributions\constraints.py", line 307, in check
    (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
                                                      ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 1
```

### Fix explanation (only for cases when the component distribution contains parameters with batch dimenisons)

- The failure is due to sample validation taking place before padding in `MixtureSameFamily.log_prob`, and hence the fix is to pad before doing sample validation.
- The fix itself does not alter the calculations at all. It only affects the sample validation process.
- The failure does not occur with the component distribution set to the `Normal` distribution, as its validation is not defined elementwise (the validation itself is elementwise).
- I've split the `test_mixture_same_family_log_prob` test into two tests based on the `Normal` and `Binomial` distributions.
- Initially, the `Binomial` version of the test did not fail, but this was due to the component distribution having equal batch dimensions of (5, 5) so I changed it to (10, 5).

### Updated fix explanation (for all cases)

- The previous fix caused a bug in sample shape validation (which is done correctly) due to the padding taking place before the sample validation.
- The updated fix corrects the support to reflect the fact that the support of `MixtureSameFamily` is equal to the support of its components distribution with the first event dimension removed.
- This issue was already anticipated in the [code](331423e5c2/torch/distributions/mixture_same_family.py (L127)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151317
Approved by: https://github.com/albanD, https://github.com/fritzo
2025-05-14 19:24:36 +00:00
b027cb8f9e [Docs] Add Description of validate_args for torch.distributions (#152173)
Fixes #152165

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152173
Approved by: https://github.com/soulitzer
2025-04-30 18:01:20 +00:00
2ed2cb5805 add generalized pareto distribution (GPD) (#135968)
Add the GPD as a distribution class

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135968
Approved by: https://github.com/albanD

Co-authored-by: Alexander März <statmixedmlgit@gmail.com>
2025-04-17 18:51:02 +00:00
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
302f56a1f2 Revert "Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)"
This reverts commit 59b7e52ad8f6146b4364515a7f3e54d6f3edd6da.

Reverted https://github.com/pytorch/pytorch/pull/146845 on behalf of https://github.com/jeanschmidt due to Seems to break a few code dependencies in multiple places ([comment](https://github.com/pytorch/pytorch/pull/146845#issuecomment-2666656834))
2025-02-18 19:01:27 +00:00
59b7e52ad8 Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)
Fix https://github.com/pytorch/pytorch/issues/145838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146845
Approved by: https://github.com/Skylion007
2025-02-17 22:42:16 +00:00
64cd81712d torch.distributions: replace numbers.Number with torch.types.Number. (#145086)
Fixes #144788 (partial)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145086
Approved by: https://github.com/malfet
2025-01-27 20:24:55 +00:00
5e4cf3e6ad Moved .all() checks for distributions to _is_all_true (#145029)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145029
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2025-01-18 07:55:48 +00:00
983bf604e5 ReshapeTransform: added missing argument in docstring (#144401)
See https://github.com/pytorch/pytorch/pull/144197#discussion_r1907336339

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144401
Approved by: https://github.com/janeyx99, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-01-13 17:59:59 +00:00
ad221269b0 remove allow-untyped-defs from torch/distributions/pareto.py (#144624)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144624
Approved by: https://github.com/Skylion007
2025-01-12 00:06:56 +00:00
dcc3cf7066 [BE] fix ruff rule E226: add missing whitespace around operator in f-strings (#144415)
The fixes are generated by:

```bash
ruff check --fix --preview --unsafe-fixes --select=E226 .
lintrunner -a --take "RUFF,PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144415
Approved by: https://github.com/huydhn, https://github.com/Skylion007
2025-01-08 21:55:00 +00:00
628acc4ace Dirichlet.mode: use dim= instead of axis= (#144402)
`axis=` is undocumented and will raise typing errors when #144197 is merged.

See: https://github.com/pytorch/pytorch/pull/144197#pullrequestreview-2537398866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144402
Approved by: https://github.com/Skylion007
2025-01-08 20:14:01 +00:00
768d73f692 use torch.special.xlogy to implement x_log_x (#144220)
Fixes #144279

Using `x* x.log()` does not produce the correct value when `x=0`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144220
Approved by: https://github.com/Skylion007
2025-01-08 17:41:55 +00:00
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotation to all `@property` and `@lazy_property` inside `torch.distributions` module.
- added simply type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with the corresponding counterpart.
- simplified `torch.Tensor` hints with plain `Tensor`, otherwise signatures can become very verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
5e8e1d725a Remove some unused type ignores (round 1) (#142325)
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.

Having these `# type: ignore` linger around is not ideal for two reasons:

- They are verbose/ugly syntatically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.

I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.

This PR should have no effect on runtime at all.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
2024-12-09 18:23:46 +00:00
f2d388eddd [BE] Use torch.special.expm1 (#141518)
Instead of `torch.exp(x)-1`, as suggested by TorchFix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141518
Approved by: https://github.com/kit1980
2024-11-26 01:47:11 +00:00
090b778b8a Clarify meaning of rate parameter in Gamma distribution (#134847)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134847
Approved by: https://github.com/fritzo
2024-11-09 00:22:13 +00:00
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc223fa838aadd8d6d6fa3ed541fa5acd1.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
e72e924eb5 Add correct typing annotations to rsample() for all distributions (#133516)
Fixes #133514
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133516
Approved by: https://github.com/Skylion007
2024-08-18 20:31:54 +00:00
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
93a14aba6e [BE]: Update mypy to 1.10.0 (#127717)
Updates mypy to the latest and greatest.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127717
Approved by: https://github.com/ezyang
2024-06-13 15:57:13 +00:00
624e8ae491 Documentation for is_dependent function (#128197)
Docstring for torch.distributions.constraints.is_dependent

Fixes #127900

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128197
Approved by: https://github.com/fritzo, https://github.com/malfet
2024-06-12 17:50:41 +00:00
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
f99409903c Documenting torch.distributions.utils.clamp_probs (#128136)
Fixes https://github.com/pytorch/pytorch/issues/127889

This PR adds docstring to the `torch.distributions.utils.clamp_probs` function.

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128136
Approved by: https://github.com/janeyx99, https://github.com/svekars, https://github.com/malfet
2024-06-07 00:49:41 +00:00
80d34217c6 Typo fixes: et al. (#127811)
"et al." is short for _et alia_ and should be abbreviated with a period on the second word. Noticed this typo when reading through the SGD docs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127811
Approved by: https://github.com/janeyx99
2024-06-06 01:03:25 +00:00
84776d7597 Revert "[BE]: Update mypy to 1.10.0 (#127717)"
This reverts commit 30213ab0a7b27277e76ea9dd707ce629a63d91ee.

Reverted https://github.com/pytorch/pytorch/pull/127717 on behalf of https://github.com/huydhn due to I am not sure why but the failures look legit and they are showing up in trunk 30213ab0a7 ([comment](https://github.com/pytorch/pytorch/pull/127717#issuecomment-2144183347))
2024-06-03 02:52:47 +00:00
30213ab0a7 [BE]: Update mypy to 1.10.0 (#127717)
Updates mypy to the latest and greatest.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127717
Approved by: https://github.com/ezyang
2024-06-02 21:07:23 +00:00
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
5a1216bb2e [BE]: Update ruff to 0.4.1 (#124549)
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.

Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0

| Repository                                         | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7         | 251.8         | 351.1            | 274.9            |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
2024-04-21 14:06:23 +00:00
a77be631e0 Bugfix to MixtureSameFamily's _pad_mixture_dimension (#118947)
Fixes Issue #73792

This is a duplicate of pull request. #73864. It's a small bugfix that should have happened a long time ago, but it didn't because I didn't actually follow up with the pull request after originally submitting. That's my bad. Trying to remedy the error.

This contains a fix to _pad_mixture_dimension, which intends to count the number of dimensions in its referent tensors, but accidentally counts the number of elements (and can thus end up creating tensors with potentially thousands of dimensions by mistake). Also contains a single test for the fixed behavior.

Co-authored-by: Jeffrey Wan <soulitzer@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118947
Approved by: https://github.com/soulitzer
2024-02-06 16:24:22 +00:00
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
aa6920c542 Fix hang in VonMises rejection sampling for small values of concentration (#114498)
Fixes #88443

Forces the internal `dtype` of `torch.distributions.von_mises.VonMises` to be `torch.double` and mirrors the numpy implementation of the second order Taylor expansion for `concentration < 1e-5`. Samples and log probs are returned with `dtype` of argument `loc`.

In principle one could also use masking in the rejection sampler to return uniformly distributed numbers for `concentration < 1e-8`, as in numpy. This may be slightly more efficient, but isn't required to solve the hanging issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114498
Approved by: https://github.com/fritzo
2023-12-04 23:07:06 +00:00
6f4409073f [doc] two diff meanings of rv generated by torch.tensor.geometric_ and torch.distributions.geometric.Geometric (#113183)
The meaning of random variables generated by `torch.tensor.geometric_` and `torch.distributions.geometric.Geometric` are different, and they are defined by two different PMFs.
Inform the user, so the user can choose their desired one.

Background: https://github.com/pytorch/pytorch/pull/37984#issuecomment-630336511

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113183
Approved by: https://github.com/albanD
2023-11-15 03:49:04 +00:00
5296c14094 Add inverse gamma distribution and fix sign bug in PowerTransform. (#104501)
This PR comprises a few small contributions:

1. `PowerTransform` returned a sign of `+1` irrespective of exponent. However, it should return the sign of the exponent because the gradient has the same sign as the exponent. That issue has been fixed.
2. Added tests to catch errors akin to 1. in the future.
3. Added an `InverseGamma` distribution as a `TransformedDistribution` with `PowerTransform(-1)` and `Gamma` base distribution. The `InverseGamma` is often used as a prior for the length scale of Gaussian processes to aggressively suppress short length scales (see [here](https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model) for a discussion).

Note: I added a `positive` constraint for the support of the inverse gamma distribution because the `PowerTransform(-1)` can fail for `nonnegative` constraints if the random variable is zero.

```python
>>> torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-8-758aa22deacd> in <module>
----> 1 torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))

~/git/pytorch/torch/distributions/transformed_distribution.py in log_prob(self, value)
    140         """
    141         if self._validate_args:
--> 142             self._validate_sample(value)
    143         event_dim = len(self.event_shape)
    144         log_prob = 0.0

~/git/pytorch/torch/distributions/distribution.py in _validate_sample(self, value)
    298         valid = support.check(value)
    299         if not valid.all():
--> 300             raise ValueError(
    301                 "Expected value argument "
    302                 f"({type(value).__name__} of shape {tuple(value.shape)}) "

ValueError: Expected value argument (Tensor of shape (1,)) to be within the support (GreaterThan(lower_bound=0.0)) of the distribution InverseGamma(), but found invalid values:
tensor([0.])
```

This differs from the scipy implementation.

```python
>>> scipy.stats.invgamma(0.5).pdf(0)
0.0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104501
Approved by: https://github.com/fritzo, https://github.com/ezyang
2023-11-01 02:26:25 +00:00
e08577aec5 Spelling fix (#108490)
Fixes spelling mistake: non-deterinistic -> non-deterministic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108490
Approved by: https://github.com/ezyang
2023-09-04 16:59:35 +00:00
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e43228b7440a33bf534cde493446a31538c.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00