Commit Graph

721 Commits

Author SHA1 Message Date
c855f8632e Pyrefly suppressions 7/n (#164913)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
 INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
2025-10-08 07:27:17 +00:00
35c4130fd1 [2/N] Fix ruff warnings (#164460)
Apply ruff `SIM` rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164460
Approved by: https://github.com/ezyang
2025-10-04 03:40:32 +00:00
f7ab8a2710 [1/N] Fix ruff warnings (#164333)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164333
Approved by: https://github.com/albanD
2025-10-01 16:48:32 +00:00
d2c5f231f6 Fix the shape check inside gnll loss (#147522)
Fixes #147521
This modification allow user to put any size of var in GaussianNLLLoss if the var is broadcastable (to input/target's size)

Therefore, the demo code in #147521 will result in expected behaviour and correct output.

This allow all input size that match:
`input.size = (..., n, ...), var.size = (..., 1, ...)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147522
Approved by: https://github.com/mikaylagawarecki
2025-09-30 18:40:15 +00:00
5c020beba4 Update LPPool docs to clarify ceil_mode padding semantics when ceil_mode=True (#163186)
# Summary

- Add a note to each `nn.LPPool*d` docstring explaining how `ceil_mode=True` interacts with right padding.
- Mirror the same clarification in the `torch.nn.functional.lp_pool*` docstrings so the rendered functional docs stay in sync.

# Motivation

The current PyTorch spec for **LPPool** does not fully match runtime behavior, which has led to downstream confusion in other specs (e.g., ONNX) and runtimes (e.g., [onnxruntime issue #25848](https://github.com/microsoft/onnxruntime/issues/25848)). A corresponding clarification was also made in the ONNX spec: [onnx/onnx#5741](https://github.com/onnx/onnx/pull/5741).

PyTorch’s **LPPool** implementation calls into **AvgPool**, which enforces the rule that windows starting entirely in the right padded region are ignored when `ceil_mode=True`. As a result, **LPPool** inherits the same behavior.

This is an edge case where the output size formula shown in the LPPool docs/spec is not sufficient on its own. Without the added caveat, the documentation is technically incorrect. This PR brings the LPPool docs in line with actual behavior.

Note that this is a trivial fix to the spec as all major implementers of the spec adhere to this caveat.

For comparison, both **MaxPool** and **AvgPool** already include this clarification in their spec. Their docstrings explicitly state:

> *When `ceil_mode=True`, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored.*

Adding the same note to LPPool ensures consistency across all pooling operators.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163186
Approved by: https://github.com/mikaylagawarecki
2025-09-30 15:22:46 +00:00
3cda34ebde [2/N] Apply ruff UP035 check in torch files (#164054)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
`TypeAlias` is also imported from `typing`.
This PR is the follow-up of #163947.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-09-29 03:35:32 +00:00
1621b5494c Removed redundant dtype conversion in scaled_dot_product_attention docstring example (#161613)
Suggested changes done for Fixes #161611.

Removed the line attn_bias.to(query.dtype) entirely

Fixes #161611
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161613
Approved by: https://github.com/mikaylagawarecki
2025-08-28 19:58:07 +00:00
641ee74781 Revert "Add label_smoothing param in nn.BCELoss and nn.BCEWithLogitsLoss (#150282)"
This reverts commit f990490a23815ea6ee27e487c70ba2cf513ba43d.

Reverted https://github.com/pytorch/pytorch/pull/150282 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/150282#issuecomment-3182844949))
2025-08-13 09:01:52 +00:00
f990490a23 Add label_smoothing param in nn.BCELoss and nn.BCEWithLogitsLoss (#150282)
Fixes #91545

## Changes

- Add `label_smoothing` param and docs
- Add test case for `label_smoothing`
- Remove duplicate description in `nn.BCELoss` and `nn.BCEWithLogitsLoss`

##  Test Result

```bash
pytest -s test/test_nn.py -k test_bce
```

![image](https://github.com/user-attachments/assets/30c0b7fe-fe49-4aa0-9b05-4d70403a7b05)

![image](https://github.com/user-attachments/assets/4fe3fd1c-54b8-4012-afd9-133ce9fb4964)

![image](https://github.com/user-attachments/assets/5cad019a-3a4c-475a-9fde-9c1acad5792d)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150282
Approved by: https://github.com/cyyever, https://github.com/mikaylagawarecki
2025-08-12 09:37:03 +00:00
d7a5ec9355 Fix the Doc of padding in avg_poolnd (#159142)
Fixes #159141

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159142
Approved by: https://github.com/mikaylagawarecki
2025-07-31 02:02:48 +00:00
db259bd6b8 [BE][12/16] fix typos in torch/ (#156602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156602
Approved by: https://github.com/justinchuby, https://github.com/albanD
ghstack dependencies: #156318, #156320
2025-07-02 22:55:29 +00:00
c808af514d Support deterministic upsample trilinear backward (#154239)
Fixes https://github.com/pytorch/pytorch/issues/154183
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154239
Approved by: https://github.com/eellison, https://github.com/albanD
2025-06-26 15:02:27 +00:00
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
671553bd23 Update documentation wording for transformer-related layers (#155123)
<img width="947" alt="Screenshot 2025-06-04 at 1 33 53 PM" src="https://github.com/user-attachments/assets/4dbb66b3-43f4-4d04-afb5-dc80cec0f2cd" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155123
Approved by: https://github.com/albanD, https://github.com/jbschlosser
2025-06-04 22:20:32 +00:00
31d12b3955 Fix avg_pool2d param kernel_size descripthon (#154353)
Fixes part of #153149

## Test Result

![image](https://github.com/user-attachments/assets/216ffd2b-dd2b-4cf6-9fca-aeed075be5e7)

![image](https://github.com/user-attachments/assets/820cd184-1f8e-4a7a-b64e-15dfb9c7dad2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154353
Approved by: https://github.com/colesbury
2025-06-04 11:55:01 +00:00
a69da90a9f Add pad limit of avg_poolnd and AvgPoolnd (#152680)
Fixes #152156

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152680
Approved by: https://github.com/mikaylagawarecki
2025-05-04 17:25:22 +00:00
7e2081fa93 Optimize interpolate saturate description (#151304)
Fixes #108225

## Test Result

### Before

![image](https://github.com/user-attachments/assets/bdbf8a5c-d5a4-44a5-b81e-2cbb5b8bfd02)

### After

![image](https://github.com/user-attachments/assets/1c21a27d-1700-4661-9988-dbb1cdc81fa2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151304
Approved by: https://github.com/albanD

Co-authored-by: albanD <desmaison.alban@gmail.com>
2025-04-17 18:34:29 +00:00
0a6e1d6b9b Expand docs for nn.functional, and make the wording consistent (#148436)
Expands the docs for the loss functions, and makes the wording consistent.

Fixes #148353

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148436
Approved by: https://github.com/albanD
2025-04-14 19:37:12 +00:00
3e9f4f3f78 docs: allow empty targets tensor in ctc_loss (#151080)
docs: allow empty targets tensor in ctc_losswhen target_lengths are zero, as described in issue

Fixes #150995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151080
Approved by: https://github.com/albanD
2025-04-12 05:26:54 +00:00
4a545eb85d Fix torch.nn.functional.one_hot param num_classes optional description (#146470)
`torch.nn.functional.one_hot` [document](https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html) describe param `num_classes` not optional, but user can call method without pass it.

![image](https://github.com/user-attachments/assets/4e6d4feb-691f-451f-95b5-4ac11bac7bc2)

```python
>>> import torch
>>> a = torch.arange(0, 5) % 3  # [0,1,2,0,1]
>>> torch.nn.functional.one_hot(a)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])

```

`num_classes` has default value -1

93d98aca31/aten/src/ATen/native/native_functions.yaml (L6154-L6157)

## Test Result

![image](https://github.com/user-attachments/assets/2c7203b7-6226-4ebc-84c8-cbf912fc48e2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146470
Approved by: https://github.com/albanD
2025-02-06 07:48:05 +00:00
0afd335174 PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145175
Approved by: https://github.com/bobrenjc93
2025-01-21 16:57:27 +00:00
5fd881a5b6 Revert "PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)"
This reverts commit 54a00af2c6026a830f40d9e6a659ff81d51f9bc6.

Reverted https://github.com/pytorch/pytorch/pull/145175 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some trunk tests ([comment](https://github.com/pytorch/pytorch/pull/145175#issuecomment-2603418267))
2025-01-21 00:49:55 +00:00
54a00af2c6 PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145175
Approved by: https://github.com/bobrenjc93
2025-01-20 22:32:59 +00:00
cyy
d87aad6877 [5/N] Apply Ruff fixes and pyupgrade to Python 3.9 (#144205)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144205
Approved by: https://github.com/albanD
2025-01-15 04:00:47 +00:00
b8f383107e Link to transformer tutorial in transformer docs (#144425)
<img width="1045" alt="Screenshot 2025-01-08 at 4 50 20 PM" src="https://github.com/user-attachments/assets/05adfecb-8a23-4c48-9a2c-50c5b3f886b0" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144425
Approved by: https://github.com/albanD
2025-01-09 17:42:09 +00:00
a9d84875a9 Fix mha torch._check in jit tracing (#142059)
Test Plan: `buck2 run @//mode/dev-nosan //mobile-vision/d2go/projects_oss/detr:tests -- -r test_detr_fbnet_export`

Differential Revision: D66769339

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142059
Approved by: https://github.com/ezyang
2024-12-05 18:38:17 +00:00
80705d3abf Convert assert to torch._check in MHA (#141918)
Fixes https://github.com/pytorch/pytorch/issues/139610
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141918
Approved by: https://github.com/ezyang
2024-12-03 21:58:02 +00:00
723498aab8 Gaussian nll loss scalar variance support (#138931)
Fixes #138747

Adds support for `variance` being a Tensor or a float in `gaussian_nll_loss` to avoid a cpu-gpu sync point in the loss function, when the variance is a static tensor like `<scalar>*torch.ones_like(input)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138931
Approved by: https://github.com/mikaylagawarecki
2024-11-21 18:20:09 +00:00
c1e7d85ce6 Add Weighted Loss Functions to PyTorch : WMSE, WMAE, and Weighted Huber Loss (#132049)
#### Summary
This pull request introduces new weighted loss functions to the PyTorch library: `weighted_huber_loss`, `wmse_loss`, and `wmae_loss`. These functions allow for precise control over the influence of each sample during training, important for imbalanced data or when certain samples are more significant than others.

#### Changes
- **`weighted_huber_loss`**: Huber loss modified to incorporate weights, providing a balance between L1 and L2 loss based on the `delta` parameter.
- **`wmse_loss`** (Weighted Mean Squared Error): Applies weights to the standard MSE loss, useful for emphasizing certain samples in regression tasks.
- **`wmae_loss`** (Weighted Mean Absolute Error): Adjusts MAE loss calculation by including weights, ideal for datasets with outliers.

#### Code Details
- **Input Validation**: Ensures `input`, `target`, and `weights` tensors match in size to prevent broadcasting errors.
- **Reduction Options**: Supports `none`, `mean`, and `sum` reductions to suit various computational needs.
- **Backward Compatibility**: Maintains support for deprecated arguments `size_average` and `reduce`, while encouraging use of the `reduction` argument.

#### Usage Example
```python
import torch
input = torch.tensor([0.5, 2.5, 2.0], dtype=torch.float32)
target = torch.tensor([0.0, 2.0, 1.5], dtype=torch.float32)
weights = torch.tensor([1.0, 0.5, 1.5], dtype=torch.float32)

loss = weighted_huber_loss(input, target, weights, delta=1.0)
print(loss)
```
---

Feedback on these implementations is welcome; please let me know if further modifications are required.

Resolves #132465

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132049
Approved by: https://github.com/mikaylagawarecki

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
2024-10-31 21:59:43 +00:00
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
83a3ee0699 Support embedding_bag() with NJT input (#135888)
Fixes #93843

`EmbeddingBag()` / `embedding_bag()` support 1D inputs with offsets to handle raggedness. NJT is a natural fit here as it already maintains offsets of the same form. This PR updates the python-side to support NJT and adds corresponding OpInfo-based NJT tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135888
Approved by: https://github.com/cpuhrsch
2024-09-23 17:35:19 +00:00
e6c3f58584 Fix example: Address broadcasting error in the addition of `attn_bias… (#135427)
…` and `attn_mask`, and correct device assignment for newly created variables in the method.

Fix example: Address broadcasting error in the addition of `attn_bias` and `attn_mask`, and correct device assignment for newly created variables in the method.

1. Adding `attn_bias += attn_mask` results in a broadcasting error. The expected shape of `attn_bias` is (L, S), so the output should also have the shape (L, S). However, when the input shape is (N, num_heads, L, S), broadcasting occurs, leading to an output shape of (N, num_heads, L, S), which is not desired.
2. `attn_bias` is a newly created variable within the method, but it is not assigned to the correct device.

**This is my retry of PR #130209 . The PR has been merged into commit `d4a79d4a7c746068d25fe5cf9333495561f4ce1f`, but the modifications were overwritten by subsequent commits.**

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
@mikaylagawarecki  provided a more elegant implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135427
Approved by: https://github.com/ezyang
2024-09-09 03:47:34 +00:00
85fa019697 [Docs] Fix call to deprecated function (#135037)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135037
Approved by: https://github.com/janeyx99, https://github.com/jbschlosser
2024-09-03 20:57:11 +00:00
8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
c7cfa51721 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts.

Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-04 23:58:14 +00:00
bcb4f7c172 Revert "Grouped Query Attention (#128898)"
This reverts commit 6b28af1b79eaa63e2f423d925bbd42330582983f.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
2024-08-02 18:58:46 +00:00
59b73079a0 Revert "Always use high precision for SDPA math backend (#128922)"
This reverts commit fbf3bc0a602b4ec1eab169202d5b1158fe2c1def.

Reverted https://github.com/pytorch/pytorch/pull/128922 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR has a dependency on another PR (https://github.com/pytorch/pytorch/pull/128898) that has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/128922#issuecomment-2265949958))
2024-08-02 18:46:50 +00:00
fbf3bc0a60 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts.

Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-01 18:55:48 +00:00
6b28af1b79 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
499ead96ff Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207fe659d664c590efc06cc0a2abc96c0.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
2024-07-30 13:11:24 +00:00
d039b14207 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00
d4a79d4a7c Fix an example: Resolve broadcasting error in attn_bias and attn_mask… (#130209)
… addition, fix device assignment for newly created variables in method

Fix an example: Resolve broadcasting error in attn_bias and attn_mask addition, fix device assignment for newly created variables in method

1. `attn_bias += attn_mask` would cause a broadcasting error. Because the shape of `attn_bias` is (L, S), the shape of the output would be expected as (L, S) too. When the shape of input is (N, num_heads, L, S), a broadcasting should be triggered. Then, the shape of the output would be (N, num_heads, L, S), which is unexpected.
2. `attn_bias` is a newly created variables in method, which is not assigned device.

**This is my retry of #130200 .** I used a wrong account in that pr.

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130209
Approved by: https://github.com/mikaylagawarecki
2024-07-19 15:23:22 +00:00
52cb9abb1d Add deterministic support in nn.functional.interpolate for XPU (#129864)
Both for CUDA and XPU, there are no deterministic implementation at native in `aten::upsample_bilinear` and `aten::replication_pad`. CUDA leverage operator decomposition path in frontend hook `nn.functional.interpolate` as its deterministic implentation. XPU backend uses the same solution in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129864
Approved by: https://github.com/dvrogozh, https://github.com/albanD, https://github.com/EikanWang
2024-07-19 02:15:42 +00:00
662e9e1076 [BE] enable UFMT for torch/nn/functional.py (#128592)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128592
Approved by: https://github.com/mikaylagawarecki
2024-06-24 06:24:12 +00:00
cc8193c707 Revert "[BE] enable UFMT for torch/nn/functional.py (#128592)"
This reverts commit f6e6e55fa7d883a89ba99584f8632c260519ba73.

Reverted https://github.com/pytorch/pytorch/pull/128592 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128592#issuecomment-2181783936))
2024-06-21 00:44:16 +00:00
f6e6e55fa7 [BE] enable UFMT for torch/nn/functional.py (#128592)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128592
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #128596, #128594
2024-06-17 16:29:29 +00:00
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
ff65b18fcf Update the is_causal explaination in the SDPA doc (#127209)
Fixes #126873

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127209
Approved by: https://github.com/drisspg
2024-05-29 18:53:17 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00