Commit Graph

57 Commits

Author SHA1 Message Date
1cf62e86a4 skip various unit tests for Jetson (#122531)
skip multiprocessing, cuda expandable segments, mem eff and flash attention tests on Jetson due to hanging / sigkill issues from nvidia internal testing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy, https://github.com/malfet
2024-04-16 01:26:26 +00:00
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
12116aee68 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton support them.
- [x] Only supports power of two sequence lengths.
    * Now it support arbitrary sequence length
- [ ] No support for varlen APIs.
    * varlen API will be supported in future release of AOTriton
- [x] Only support head dimension 16,32,64,128.
    * Now it support arbitrary head dimension <= 256
- [x] Performance is still being optimized.
    * Kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
2024-03-28 00:27:38 +00:00
764eae9c4e Revert "Add Flash Attention support on ROCM (#121561)"
This reverts commit a37e22de7059d06b75e4602f0568c3154076718a.

Reverted https://github.com/pytorch/pytorch/pull/121561 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this needs more work to be able to land in fbcode because https://github.com/ROCm/aotriton is not available there atm.  We are working to reland this change before 2.3 release ([comment](https://github.com/pytorch/pytorch/pull/121561#issuecomment-2007717091))
2024-03-19 17:14:28 +00:00
a37e22de70 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton support them.
- [x] Only supports power of two sequence lengths.
    * Now it support arbitrary sequence length
- [ ] No support for varlen APIs.
    * varlen API will be supported in the next release of AOTriton
- [x] Only support head dimension 16,32,64,128.
    * Now it support arbitrary head dimension <= 256
- [x] Performance is still being optimized.
    * Kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
2024-03-12 01:16:53 +00:00
cd380c794f [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-02-14 22:02:06 +00:00
113138aa55 add test cases for GradScaler on CPU (#109994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109994
Approved by: https://github.com/jgong5, https://github.com/ezyang
2024-02-02 21:49:07 +00:00
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00
2f84a9d37c Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)"
This reverts commit 5aa92b5090e3db4a053548a3f360dd06c16df2f7.

Reverted https://github.com/pytorch/pytorch/pull/115663 on behalf of https://github.com/PaliC due to Unfortunately, this pr breaks cuda builds internally ([comment](https://github.com/pytorch/pytorch/pull/115663#issuecomment-1899388813))
2024-01-18 23:40:30 +00:00
5aa92b5090 [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-01-18 01:20:36 +00:00
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compiling errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
e3aefe2970 Revert "Initial Flash Attention support on ROCM (#114309)" (#115975)
This reverts commit 5bddbed399a89bf2875a38bb84cb869f382f1809.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115975
Approved by: https://github.com/atalman, https://github.com/malfet
2023-12-16 03:40:14 +00:00
5bddbed399 Initial Flash Attention support on ROCM (#114309)
This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- [ ] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- [ ] Only supports power of two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only support head dimension 16,32,64,128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
2023-12-14 08:52:57 -08:00
e686341f64 Consider that ops can be fused into cat in the min-cut partitioner (#110501)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110501
Approved by: https://github.com/eellison
2023-10-05 01:34:57 +00:00
a2d5f13310 [Inductor CUTLASS backend] Step 5: Gemm CUTLASS templates (#108015)
This is the step 5 to add cutlass as an alternative inductor backend.

Feature request: https://github.com/pytorch/pytorch/issues/106991.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108015
Approved by: https://github.com/kadeng, https://github.com/jansel, https://github.com/aakhundov
ghstack dependencies: #107802, #107847, #107901, #107931
2023-09-12 17:44:38 +00:00
a9c663c269 Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1cc8048fd0b43445b28fec7d93281f00.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 07:43:04 +00:00
e45b290127 Revert "Revert "Flash Attention v2 (#105602)" (#108827)"
This reverts commit 24e9bbe22af296048f8242c6112d13cff726c588.

Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
2023-09-08 03:25:45 +00:00
24e9bbe22a Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1cc8048fd0b43445b28fec7d93281f00.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 02:54:20 +00:00
add45aea1c Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-01 22:14:44 +00:00
bb2fcc7659 unify TEST_CUDA (#106685)
Fixes #ISSUE_NUMBER
as title, unify TEST_CUDA
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106685
Approved by: https://github.com/zou3519
2023-08-10 09:01:36 +00:00
3c7331742a test_fused_sdp_choice in test_transformers.py fix (#106587)
sdp dispatcher prioritizes flash attention over efficient attention: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L684-L687, and flash attention is enabled for sm75+: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L625. Thus, the unit test `test_fused_sdp_choice` from `test_transformers.py` which is failing on T4 (sm75) should have this `SM80OrLater` check changed to `SM75OrLater`: https://github.com/pytorch/pytorch/blob/main/test/test_transformers.py#L1914-L1917.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106587
Approved by: https://github.com/drisspg
2023-08-04 03:43:56 +00:00
1cebfef8a4 sm90 efficient attention test fixes (#105978)
Fixes the following two test cases involving efficient attention on sm90:

Explanations:

functorch/test_ops.py: test_vjp_nn_functional_scaled_dot_product_attention_cuda_float32
* originally the test had xfail for all sm
* in https://github.com/pytorch/pytorch/issues/102029, we found that it was unexpectedly passing on sm90
* I made https://github.com/pytorch/pytorch/pull/102131 to update the test to let it pass
* @drisspg seems to have made changes to the behavior such that the original xfail was getting triggered (https://github.com/pytorch/pytorch/issues/102029#issuecomment-1560071148)
* the CI began complaining about the failure again: https://github.com/pytorch/pytorch/issues/102663
* I'm now reverting https://github.com/pytorch/pytorch/pull/102131 to bring back the original xfail now that the behavior has been fixed by @drisspg to trigger the xfail in sm90 similar to all other sm

test_transformers.py: test_mem_efficient_fail_sm90_cuda
* the test as it's currently written seems to expect the sdp dispatcher to fail for mem efficient attention on sm90; however, testing this on H100, it actually succeeds, so I'm disabling the test for now as the current expected result may be outdated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105978
Approved by: https://github.com/eqy, https://github.com/kshitij12345, https://github.com/zou3519
2023-07-31 17:59:40 +00:00
be03a56955 [BE] Enable ruff's UP rules and autoformat testing/ (#105425)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105425
Approved by: https://github.com/malfet
2023-07-18 21:04:39 +00:00
c3e4a67905 Refactor multigpu tests to test_cuda_multigpu (#104059)
Mostly refactor, that moves all the tests from `test_cuda` that benefit from multiGPU environment into its own file.

- Add `TestCudaMallocAsync` class for Async tests ( to separate them from `TestCudaComm`)
- Move individual tests from `TestCuda` to `TestCudaMultiGPU`
- Move `_create_scaling_models_optimizers` and `_create_scaling_case` to `torch.testing._internal.common_cuda`
- Add newly created `test_cuda_multigpu` to the multigpu periodic test

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at f4d46fa</samp>

This pull request fixes a flaky test and improves the testing of gradient scaling on multiple GPUs. It adds verbose output for two CUDA tests, and refactors some common code into helper functions in `torch/testing/_internal/common_cuda.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104059
Approved by: https://github.com/huydhn
2023-06-27 05:32:05 +00:00
cd05c3b98c [BE] Use TEST_MULTIGPU from common_cuda.py (#103982)
Comment about `TEST_CUDNN` called over and over has long been alleviated by wrapping the check with `LazyVal`, that caches the results.
Also, delete unused `TEST_MAGMA`.

Prep change for https://github.com/pytorch/pytorch/issues/100006

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at e3a5b39</samp>

> _`common_cuda.py`_
> _Refactored for dynamo tests_
> _Winter code cleanup_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103982
Approved by: https://github.com/atalman, https://github.com/janeyx99
2023-06-22 00:07:44 +00:00
5b01c8dc6a fix functorch/test_ops.py test_vjp flash attention unexpected success (#102131)
add isSm90 check for expected failure in nn.functional.scaled_dot_product_attention in functorch/test_ops.py

Fixes #102029

Uses solution https://github.com/pytorch/pytorch/issues/102029#issuecomment-1560052965 which was verified by
https://github.com/pytorch/pytorch/issues/102029#issuecomment-1560071148

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102131
Approved by: https://github.com/zou3519
2023-05-25 22:17:25 +00:00
3a5427baf4 Add torch.utils._content_store (#99809)
Implements a simple content-addressable store for storages (with tensors implemented as cheap references on top), enabling incremental serialization of tensors to disk, which I intend to use in the accuracy repro extractor.  Check the comment at the top of torch/utils/_content_store.py for more details on the intended use case.

One major piece of this PR is implementing the content hash for tensors.  For our prospective use case, we may need to repeatedly hash up to 80 GB of tensor data every time we snapshot (and we may snapshot multiple times).  Using a conventional cryptographic hash and hashing each snapshot would likely take on order of minutes, which seemed too slow to me.  So instead, I implemented a crappy hash function that can be run on GPU.  It is at least somewhat theoretically grounded: using random parameters generated by Philox, we use the standard shift-multiply and xor sum universal hash family.  The hash function is a bit dorky though; instead of properly doing 160-bit math, it just runs 32-bit hash five times and cats them together.  By the way, this sets the first precedent for kernel in PyTorch library which MUST be torch.compile'd to be run (in fact, this kernel does not run in eager mode because of the use of xor_sum, which doesn't actually exist in ATen.)

I had to add a few more primitives to inductor, namely randint (over the entire int range) and xor_sum.  Fortunately, these primitives are natively supported by Triton/C++, and so they were very easy to plumb through.  xor_sum is exposed as a prim, while randint special cases on when low/high span the entire 32-bit signed integer range.

Thanks to Jeff Johnson for letting me bounce ideas of him on a Saturday morning lol.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99809
Approved by: https://github.com/voznesenskym
2023-04-26 18:02:59 +00:00
cf354a0491 Don't eagerly initialize CUDA when importing common_cuda (#99536)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99536
Approved by: https://github.com/Chillee, https://github.com/bertmaher, https://github.com/albanD
2023-04-19 22:12:10 +00:00
eqy
2fddcf0fc0 [CUDA][CUDA 11] Remove more CUDA 11 version checks (#92934)
Working on removing stragglers missed in previous CUDA version < 11.0 cleanup PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92934
Approved by: https://github.com/ngimel
2023-03-30 19:49:52 +00:00
4610ce49f6 Fix typo under torch/testing directory (#97254)
This PR fixes typo in comments and messages under `torch/testing` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97254
Approved by: https://github.com/kit1980, https://github.com/malfet
2023-03-23 01:46:17 +00:00
653dc73df0 [SDPA] Wire up FlashAttention's backward (#92917)
# Summary
This PR creates _flash_attention_backward and _scaled_dot_product_flash_attention_backward native functions and registers them to the respective derivatives.yaml.

The goal is to replicate the torch.autograd.Function defined in the FlashAttention repo [here](33e0860c9c/flash_attn/flash_attn_interface.py (L126)) natively in PyTorch.  One thing that we don't have access to is ctx.save_for_backward in native PyTorch so in order to save these variables I extended the returned objects from the forward functions.

### MetaFunctions
I also updated the FlashAttention meta functions to mirror the real outputs now. As well I added a meta registration for backwards. I have an XLMR training script and while eager training now works with FlashAttention compiling this module fails with the inductor error down below.

### Questions?
Performance issues vs mem efficient when using torch.nn.mha_forward

TorchCompile -> See purposed solution below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92917
Approved by: https://github.com/cpuhrsch
2023-02-02 04:02:30 +00:00
0bf7506051 [CUDA] Drop CUDA < 11.0 test flags (#92605)
Follow-up of #89582 to drop flags like `CUDA11OrLater` in tests. Note that in some places it appears that `TEST_WITH_ROCM` is _implicitly_ guarded against via the `CUDA11OrLater` version check, based on my best-guess of how `torch.version.cuda` would behave in ROCM builds, so I've added `not TEST_WITH_ROCM` in cases where ROCM wasn't previously explicitly allowed.

CC @ptrblck @malfet @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92605
Approved by: https://github.com/ngimel
2023-01-24 04:34:06 +00:00
38dd4cbdf1 ROCm enable sparse_sampled_addmm (#86401)
Enables:
test_comprehensive_sparse_sampled_addmm_cuda_complex128
test_comprehensive_sparse_sampled_addmm_cuda_complex64
test_comprehensive_sparse_sampled_addmm_cuda_float32
test_comprehensive_sparse_sampled_addmm_cuda_float64
test_dispatch_meta_sparse_sampled_addmm_cuda_complex128
test_dispatch_meta_sparse_sampled_addmm_cuda_complex64
test_dispatch_meta_sparse_sampled_addmm_cuda_float32
test_dispatch_meta_sparse_sampled_addmm_cuda_float64
test_meta_sparse_sampled_addmm_cuda_complex128
test_meta_sparse_sampled_addmm_cuda_complex64
test_meta_sparse_sampled_addmm_cuda_float32
test_meta_sparse_sampled_addmm_cuda_float64

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86401
Approved by: https://github.com/ngimel
2022-10-26 19:39:24 +00:00
247468baf0 [ROCm] More Sparse UTs enablement and more hipification mappings. (#78939)
Enables:

 test_bmm_cuda_float64
 test_bmm_deterministic_cuda_float64
 test_csr_matvec_cuda_complex128
 test_csr_matvec_cuda_complex64
 test_csr_matvec_cuda_float32
 test_csr_matvec_cuda_float64

To enable the above tests had to add some more hip mappings for the hipification process.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78939
Approved by: https://github.com/pruthvistony, https://github.com/malfet
2022-08-23 13:54:09 +00:00
eqy
ad1bff1bff [TF32] Fix typo in tf32 wrapper function (#78438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78438
Approved by: https://github.com/ngimel
2022-06-03 01:03:43 +00:00
8bb7203049 Add torch.linalg.ldl_factor_ex and torch.linalg.ldl_solve
This PR adds a function for computing the LDL decomposition and a function that can solve systems of linear equations using this decomposition. The result of `torch.linalg.ldl_factor_ex` is in a compact form and it's required to use it only through `torch.linalg.ldl_solve`. In the future, we could provide `ldl_unpack` function that transforms the compact representation into explicit matrices.

Fixes https://github.com/pytorch/pytorch/issues/54847.

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @Lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69828
Approved by: https://github.com/Lezcano, https://github.com/mruberry, https://github.com/albanD
2022-04-28 19:23:37 +00:00
d71b8e1a8d More distutils.version.LooseVersion changes (#69947)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69947

Reviewed By: seemethere

Differential Revision: D33111996

Pulled By: malfet

fbshipit-source-id: e7d2cc4ed3e39452e809965e360b05f0b409ec0d
2021-12-15 08:07:36 -08:00
541eb1db63 Add cuSPARSE descriptors and update CSR addmm (#60838)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60838

Rewrote `addmm_out_sparse_csr_dense_cuda` implementation using new cusparse descriptors.

`addmm` now works without conversions with both 32-bit and 64-bit indices.
The dense tensors can have a row- or column-major layout. If the dense tensors are a contiguous slice of a larger tensor, the storage is used directly without temporary copies.

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D30643191

Pulled By: cpuhrsch

fbshipit-source-id: 5555f5b59b288daa3a3987d322a93dada63b46c8
2021-09-30 11:32:51 -07:00
1fec9cd76b [Fixed] Enable Half, BFloat16, and Complex dtypes for coo-coo sparse matmul [CUDA] (#59980)
Summary:
This PR enables Half, BFloat16, ComplexFloat, and ComplexDouble support for matrix-matrix multiplication of COO sparse matrices.
The change is applied only to CUDA 11+ builds.

`cusparseSpGEMM` also supports `CUDA_C_16F` (complex float16) and `CUDA_C_16BF` (complex bfloat16). PyTorch also supports the complex float16 dtype (`ScalarType::ComplexHalf`), but there is no convenient dispatch, so this dtype is omitted in this PR.

cc nikitaved pearu cpuhrsch IvanYashchuk ezyang anjali411 dylanbespalko mruberry Lezcano

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59980

Reviewed By: ngimel

Differential Revision: D30994115

Pulled By: cpuhrsch

fbshipit-source-id: 4f55b99e8e25079d6273b4edf95ad6fa85aeaf24
2021-09-21 13:03:40 -07:00
92b31b59af Revert D29699456: [pytorch][PR] Enable Half, BFloat16, and Complex dtypes for coo-coo sparse matmul [CUDA]
Test Plan: revert-hammer

Differential Revision:
D29699456 (ad4848565e)

Original commit changeset: 407ae53392ac

fbshipit-source-id: b6c70ba8bb28c0c38de47857030b69792a8470de
2021-09-01 07:32:24 -07:00
ad4848565e Enable Half, BFloat16, and Complex dtypes for coo-coo sparse matmul [CUDA] (#59980)
Summary:
This PR enables Half, BFloat16, ComplexFloat, and ComplexDouble support for matrix-matrix multiplication of COO sparse matrices.
The change is applied only to CUDA 11+ builds.

`cusparseSpGEMM` also supports `CUDA_C_16F` (complex float16) and `CUDA_C_16BF` (complex bfloat16). PyTorch also supports the complex float16 dtype (`ScalarType::ComplexHalf`), but there is no convenient dispatch, so this dtype is omitted in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59980

Reviewed By: ngimel

Differential Revision: D29699456

Pulled By: cpuhrsch

fbshipit-source-id: 407ae53392acb2f92396a62a57cbaeb0fe6e950b
2021-08-30 15:06:25 -07:00
c966ce6933 Fix several test_ops cuda dtypes tests (#60922)
Summary:
Close https://github.com/pytorch/pytorch/issues/60443

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60922

Reviewed By: jdonald, iramazanli

Differential Revision: D29630122

Pulled By: mruberry

fbshipit-source-id: 441f79828860282e5849a2565facf9e7f72912e8
2021-07-09 09:29:13 -07:00
d99a8a31b1 Fix version comparison for defining CUDA11OrLater (#60010)
Summary:
Before this PR `CUDA11OrLater` was incorrectly set to `False` when `torch.version.cuda == "11.0"`.
`torch.version.cuda` returns major and minor CUDA versions, it doesn't return patch info.
LooseVersion comparison was calling `[11, 0] >= [11, 0, 0]` which evaluates to `False`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60010

Reviewed By: mruberry

Differential Revision: D29147107

Pulled By: ezyang

fbshipit-source-id: bd9ed076337b4d32bf1c3376b8f7ae15dbc4d08d
2021-06-16 18:04:29 -07:00
4b96fc060b Remove distutils (#57040)
Summary:
[distutils](https://docs.python.org/3/library/distutils.html) is on its way out and will be deprecated-on-import for Python 3.10+ and removed in Python 3.12 (see [PEP 632](https://www.python.org/dev/peps/pep-0632/)). There's no reason for us to keep it around since all the functionality we want from it can be found in `setuptools` / `sysconfig`. `setuptools` includes a copy of most of `distutils` (which is fine to use according to the PEP), that it uses under the hood, so this PR also uses that in some places.

Fixes #56527
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57040

Pulled By: driazati

Reviewed By: nikithamalgifb

Differential Revision: D28051356

fbshipit-source-id: 1ca312219032540e755593e50da0c9e23c62d720
2021-04-29 12:10:11 -07:00
9f336bdf10 Fixes new tf32 failures in test_nn.py (#52871)
Summary:
Also modify the `tf32_on_and_off` decorator to make it support function without `device` argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52871

Reviewed By: ngimel

Differential Revision: D27286674

Pulled By: mruberry

fbshipit-source-id: 14f6d558271bd6a1d0bc40691c170d47e81de1ff
2021-03-24 21:53:33 -07:00
8c798e0622 Forbid trailing whitespace (#53406)
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857

These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
  - `GLOSSARY.md`
  - `aten/src/ATen/core/op_registration/README.md`
  - `scripts/README.md`
  - `torch/csrc/jit/codegen/fuser/README.md`

The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```

I looked over the auto-generated changes and didn't see anything that looked problematic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406

Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377

This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348

Reviewed By: walterddr, seemethere

Differential Revision: D26856620

Pulled By: samestep

fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
2021-03-05 17:22:55 -08:00
b52e2e6045 [BE] _get_torch_cuda_version should return tuple (#52409)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52409

Reviewed By: jbschlosser, glaringlee

Differential Revision: D26513924

Pulled By: walterddr

fbshipit-source-id: ee18ef357c326c5ad344d80c59821cc2b8814734
2021-02-18 09:28:38 -08:00
b822aba8ec Enable BFloat support for gemms on arch other than ampere (#50442)
Summary:
Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50442

Reviewed By: bdhirsh

Differential Revision: D26044981

Pulled By: mruberry

fbshipit-source-id: 65c42f2c1de8d24e4852a1b5bd8f4b1735b2230e
2021-01-26 11:07:07 -08:00
3f5eee666c Adjust TF32 tests (#44240)
Summary:
- The thresholds of some tests are bumped up. Depending on the random generator, sometimes these tests fail with things like 0.0059 is not smaller than 0.005. I ran `test_nn.py` and `test_torch.py` for 10+ times to check these are no longer flaky.
- Add `tf32_on_and_off` to new `matrix_exp` tests.
- Disable TF32 on test suites other than `test_nn.py` and `test_torch.py`

cc: ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44240

Reviewed By: mruberry

Differential Revision: D23882498

Pulled By: ngimel

fbshipit-source-id: 44a9ec08802c93a2efaf4e01d7487222478b6df8
2020-09-24 10:25:58 -07:00
d75c402755 Add cusolver to build, rewrite MAGMA inverse with cusolver (#42403)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42265

This PR adds cusolver to the pytorch build, and enables the use of cusolver/cublas library functions on GPU `torch.inverse` on certain tensor shapes.

Specifically, when

* the tensor is two dimensional (single batch), or
* has >2 dimensions (multiple batches) and `batch_size <= 2`, or
* magma is not linked,

cusolver/cublas will be used. In other conditions, the current implementation of MAGMA will still be used.

8c0949ae45/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu (L742-L752)

The reason for this is that for tensors with large batch_size, `cublasXgetrfBatched` and `cublasXgetriBatched` doesn't perform very well. For `batch_size > 1`, we launch cusolver functions in multiple streams. This lets cusolver functions run in parallel, and can greatly increase the performance. When `batch_size > 2`, the parallel launched cusolver functions are slightly slower than the current magma implementation, so we still use the current magma impl.

On CUDA 9.2, there were some numerical issues detected, so cusolver impl will not be used. The cusolver impl will also not be used on platforms other than Nvidia CUDA.

060769feaf/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h (L10-L13)

Note that there is a new heuristic used before cusolver/cublas calls here:

8c0949ae45/aten/src/ATen/native/cuda/MiscUtils.h (L113-L121)

where `use_loop_launch = true` means launch single batch cusolver functions in parallel, and `use_loop_launch = false` means use cublas_X_batched functions. When magma is enabled (only `batch_size <= 2` will be dispatched to cusolver/cublas), the heuristic will always return `true` and the cusolver calls are faster than small batch_size magma calls. When magma is disabled, this adds the functionality of `torch.inverse`, which was disabled before for all shapes (though large batch_size cublas performance may not be as well as magma).

Checklist:
- [X] Add benchmark, cpu, gpu-before (magma), gpu-after (cusolver)
- [X] Rewrite single inverse (ndim == 2) with cusolver
- [X] Rewrite batched inverse (ndim > 2) with cublas
- [X] Add cusolver to build
- [x] Clean up functions related to `USE_MAGMA` define guard
- [x] Workaround for non-cuda platform
- [x] Workaround for cuda 9.2
- [x] Add zero size check
- [x] Add tests

Next step:

If cusolver doesn't cause any problem in pytorch build, and there are no major performance regressions reported after this PR being merged, I will start porting other cusolver/cublas functions for linear algebra to improve the performance.

<details>
<summary> benchmark 73499c6 </summary>

benchmark code: https://github.com/xwang233/code-snippet/blob/master/torch.inverse/inverse-cusolver.ipynb

shape meaning:

* `[] 2 torch.float32 -> torch.randn(2, 2, dtype=torch.float32)`
* `[2] 4 torch.float32 -> torch.randn(2, 4, 4, dtype=torch.float32)`

| shape | cpu_time (ms) | gpu_time_before (magma) (ms) | gpu_time_after (ms) |
| --- | --- | --- | --- |
| [] 2 torch.float32 |  0.095 |  7.534 |  0.129  |
| [] 4 torch.float32 |  0.009 |  7.522 |  0.129  |
| [] 8 torch.float32 |  0.011 |  7.647 |  0.138  |
| [] 16 torch.float32 |  0.075 |  7.582 |  0.135  |
| [] 32 torch.float32 |  0.073 |  7.573 |  0.191  |
| [] 64 torch.float32 |  0.134 |  7.694 |  0.288  |
| [] 128 torch.float32 |  0.398 |  8.073 |  0.491  |
| [] 256 torch.float32 |  1.054 |  11.860 |  1.074  |
| [] 512 torch.float32 |  5.218 |  14.130 |  2.582  |
| [] 1024 torch.float32 |  19.010 |  18.780 |  6.936  |
| [1] 2 torch.float32 |  0.009 |  0.113 |  0.128 ***regressed |
| [1] 4 torch.float32 |  0.009 |  0.113 |  0.131 ***regressed |
| [1] 8 torch.float32 |  0.011 |  0.116 |  0.129 ***regressed |
| [1] 16 torch.float32 |  0.015 |  0.122 |  0.135 ***regressed |
| [1] 32 torch.float32 |  0.032 |  0.177 |  0.178 ***regressed |
| [1] 64 torch.float32 |  0.070 |  0.420 |  0.281  |
| [1] 128 torch.float32 |  0.328 |  0.816 |  0.490  |
| [1] 256 torch.float32 |  1.125 |  1.690 |  1.084  |
| [1] 512 torch.float32 |  4.344 |  4.305 |  2.576  |
| [1] 1024 torch.float32 |  16.510 |  16.340 |  6.928  |
| [2] 2 torch.float32 |  0.009 |  0.113 |  0.186 ***regressed |
| [2] 4 torch.float32 |  0.011 |  0.115 |  0.184 ***regressed |
| [2] 8 torch.float32 |  0.012 |  0.114 |  0.184 ***regressed |
| [2] 16 torch.float32 |  0.019 |  0.119 |  0.173 ***regressed |
| [2] 32 torch.float32 |  0.050 |  0.170 |  0.240 ***regressed |
| [2] 64 torch.float32 |  0.120 |  0.429 |  0.375  |
| [2] 128 torch.float32 |  0.576 |  0.830 |  0.675  |
| [2] 256 torch.float32 |  2.021 |  1.748 |  1.451  |
| [2] 512 torch.float32 |  9.070 |  4.749 |  3.539  |
| [2] 1024 torch.float32 |  33.655 |  18.240 |  12.220  |
| [4] 2 torch.float32 |  0.009 |  0.112 |  0.318 ***regressed |
| [4] 4 torch.float32 |  0.010 |  0.115 |  0.319 ***regressed |
| [4] 8 torch.float32 |  0.013 |  0.115 |  0.320 ***regressed |
| [4] 16 torch.float32 |  0.027 |  0.120 |  0.331 ***regressed |
| [4] 32 torch.float32 |  0.085 |  0.173 |  0.385 ***regressed |
| [4] 64 torch.float32 |  0.221 |  0.431 |  0.646 ***regressed |
| [4] 128 torch.float32 |  1.102 |  0.834 |  1.055 ***regressed |
| [4] 256 torch.float32 |  4.042 |  1.811 |  2.054 ***regressed |
| [4] 512 torch.float32 |  18.390 |  4.884 |  5.087 ***regressed |
| [4] 1024 torch.float32 |  69.025 |  19.840 |  20.000 ***regressed |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42403

Reviewed By: ailzhang, mruberry

Differential Revision: D23717984

Pulled By: ngimel

fbshipit-source-id: 54cbd9ea72a97989cff4127089938e8a8e29a72b
2020-09-18 20:43:29 -07:00