Compare commits


39 Commits

Author SHA1 Message Date
c263bd43e8 [inductor] use triu ref instead of lowering (#96040) (#96462)
Fixes #95958
Generated code is functionally identical with the ref and the lowering; there are only minor differences.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96040
Approved by: https://github.com/jansel

Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
2023-03-09 17:42:00 -05:00
c9913cf66f Add jinja2 as mandatory dependency (#95691) (#96450)
Should fix #95671, the nightly wheels issue. The v2.0.0 RC does not need this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95691
Approved by: https://github.com/malfet

Co-authored-by: Wei Wang <weiwangmeta@meta.com>
2023-03-09 17:31:12 -05:00
2f7d8bbf17 Fix expired deprecation of comparison dtype for NumPy 1.24+ (#91517) (#96452)
> The `dtype=` argument to comparison ufuncs is now applied correctly. That
> means that only `bool` and `object` are valid values and `dtype=object` is
> enforced.

Source: https://numpy.org/doc/stable/release/1.24.0-notes.html#expired-deprecations

Fixes #91516

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91517
Approved by: https://github.com/zou3519, https://github.com/huydhn
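To illustrate the behavior change, a minimal sketch assuming NumPy 1.24+ (array values are made up):

```python
import numpy as np

a = np.arange(3)
b = np.arange(3)

# Only bool (and object) are accepted as dtype= for comparison ufuncs in 1.24+.
np.equal(a, b, dtype=bool)

# With any other dtype, 1.24+ raises an error instead of a deprecation warning:
# np.equal(a, b, dtype=np.int64)
```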

Co-authored-by: Johnson <j3.soon@msa.hinet.net>
2023-03-09 14:30:00 -08:00
ca0cdf52ca dl_open_guard should restore flag even after exception (#96231) (#96457)
I.e., follow the pattern outlined in https://docs.python.org/3.8/library/contextlib.html#contextlib.contextmanager

Also, return early on non-Unix platforms (where `sys.getdlopenflags` is not defined)

Fixes https://github.com/pytorch/pytorch/issues/96159

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96231
Approved by: https://github.com/atalman

(cherry picked from commit 941ff109d32d51d6e93a2c2f4a028ff3826ece31)
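A minimal sketch of the described pattern (not the exact PyTorch implementation): restore the dlopen flags in a `finally` block so they are reset even if the body raises, and return early where `sys.getdlopenflags` does not exist.

```python
import os
import sys
from contextlib import contextmanager

@contextmanager
def dl_open_guard_sketch():
    # Non-Unix platforms (e.g. Windows) have no dlopen flags to manage.
    if not hasattr(sys, "getdlopenflags"):
        yield
        return
    old_flags = sys.getdlopenflags()
    try:
        sys.setdlopenflags(old_flags | os.RTLD_GLOBAL)
        yield
    finally:
        # Runs even if the body raised, restoring the original flags.
        sys.setdlopenflags(old_flags)
```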
2023-03-09 14:29:17 -08:00
9cfa076da8 [Release/2.0] Use Triton from PYPI (#96010)
* [Release/2.0] Use Triton from PYPI

Remove `[dynamo]` extras from setup.py

Build torchtriton conda wheels as 2.0.0

* Also, upload triton conda packages to test channel
2023-03-03 20:15:48 -05:00
8e05e41dbc [Release/2.0] Use builder release branch for tests 2023-03-03 16:22:04 -08:00
d8ffc60bc1 Remove mention of dynamo.optimize() in docs (#95802) (#96007)
This should be self-contained enough to merge, but other things that have been bugging me are:
* Instructions on debugging IMA issues
* Dynamic shape instructions
* Explaining config options better

Will look at adding a config options doc

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95802
Approved by: https://github.com/svekars
2023-03-03 17:43:31 -05:00
1483723037 [MPS] Disallow reshape in slice (#95905) (#95978)
Disallow reshapes for arrayViews.
The current code allows a base shape of `[2, 4, 256]` to be sliced into `[4, 1, 256]` (the view's shape), which is not possible. Slicing a smaller dimension into a bigger one now always errors out.

Fixes https://github.com/pytorch/pytorch/issues/95883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95905
Approved by: https://github.com/razarmehr, https://github.com/kulinseth

Co-authored-by: Denis Vieriu <dvieriu@apple.com>
2023-03-03 10:15:10 -08:00
c4572aa1b7 [MPS] Add fixes for div with floor (#95869)
* [MPS] Add fixes for div with floor and raise error for div_trunc (#95769)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95769
Approved by: https://github.com/DenisVieriu97
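Usage sketch for the kind of call affected (assumes an MPS device; values are illustrative):

```python
import torch

a = torch.tensor([7, -7], device="mps")
b = torch.tensor([2, 2], device="mps")
# Floor division on int64 MPS tensors; per this commit, rounding_mode="trunc" raises an error on MPS.
torch.div(a, b, rounding_mode="floor")
```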

* Add back the unittest skip for macOS 12.
2023-03-02 12:36:02 -08:00
82b078ba64 [MPS] Fix views with 3 or more sliced dimensions (#95762) (#95871)
Fixes https://github.com/pytorch/pytorch/issues/95482
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95762
Approved by: https://github.com/razarmehr

Co-authored-by: Denis Vieriu <dvieriu@apple.com>
2023-03-02 12:27:46 -08:00
77f7bc5f9d Remove torch._inductor.config.triton.convolution (#95840) 2023-03-02 13:49:20 -05:00
0865964576 [optim] _actually_ default to foreach (#95862)
* [optim] include nn.Parameter as foreach supported (#95811)

This PR is the result of realizing that models were NOT subscribed to the foreach defaulting that our documentation has claimed for months now. BIG OOPS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95811
Approved by: https://github.com/albanD

* [optim] Widen the cases for defaulting to foreach (#95820)

Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes CPU tensors. Since foreach _can_ handle CPU tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
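A rough sketch of the widened defaulting logic described above (hypothetical helper, not the actual `torch.optim` code): only the params themselves need to be CUDA tensors, and `nn.Parameter` counts because it is a `Tensor` subclass.

```python
import torch

def _should_default_to_foreach(params, differentiable=False):
    # Hypothetical: state_steps and other bookkeeping tensors may live on CPU,
    # so they are deliberately not part of this check.
    if differentiable:
        return False
    return len(params) > 0 and all(
        isinstance(p, torch.Tensor) and p.is_cuda for p in params
    )
```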
2023-03-02 13:33:57 -05:00
f18ac1b386 Release version of fixed nll decomp (#95853)
* fix nll loss decomposition to properly ignore ignore_index (see the sketch below)

* remove branch
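A minimal sketch of what ignoring `ignore_index` means in the decomposition (illustrative only, not the actual decomp code): ignored rows must contribute neither to the sum nor to the normalization count.

```python
import torch

def nll_loss_mean_sketch(log_probs, target, ignore_index=-100):
    mask = target != ignore_index
    # Replace ignored targets with a valid index so gather() stays in bounds.
    safe_target = torch.where(mask, target, torch.zeros_like(target))
    picked = log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)
    loss = -(picked * mask)
    # Normalize by the number of non-ignored rows only.
    return loss.sum() / mask.sum().clamp(min=1)
```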
2023-03-02 13:26:45 -05:00
c04134cdb1 [ROCM] Restrict pytorch rocm to only use triton 2.0.x (#95793) (#95834)
To align with upstream, we are requiring the triton dependency to be between 2.0.0 and 2.1. This will allow PyTorch 2.0 on ROCm to stay flexible enough to pick up any performance/stability improvements from Triton, without needing to cut a separate PyTorch version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95793
Approved by: https://github.com/huydhn
2023-03-01 19:10:00 -05:00
72d0863ab2 [BE] Fix TORCH_WARN_ONCE (#95559) (#95822)
It does not take a condition as its first argument, unlike `TORCH_CHECK`.
Test plan: run ` python3 -c "import torch;print(torch.arange(1., 10.,device='mps').view(3, 3).trace())"` and observe that no warning is emitted.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95559
Approved by: https://github.com/Skylion007

(cherry picked from commit 9bca9df42b5898e45e2a80e03a4a4ba9a6fe654a)
2023-03-01 19:03:41 -05:00
1bd334dc25 Update copyright (#95652) (#95700)
Updating the copyright so it is reflected on the website.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95652
Approved by: https://github.com/atalman

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
2023-02-28 11:00:18 -05:00
93e13cd429 [MPS] Remove FFT from the fallback as it's causing crashes in test_ops and TestConsistency tests. (#95625) 2023-02-27 17:26:32 -08:00
4e4d4b0afe [MPS] Add TORCH_CHECK for Convolution (#95495)
* Raise errors for Conv and remove FFTs from Fallback list.

* Move the FFT to a separate commit.
2023-02-27 17:25:14 -08:00
Wei
c4fa850827 Reserve the tensorrt backend name for torch-tensorrt (#95627) 2023-02-27 17:17:47 -08:00
36ead09873 Add float to list of allowed ops (#94910) (#95661)
By adding `BINFLOAT` op support

Fixes https://github.com/pytorch/pytorch/issues/94670
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94910
Approved by: https://github.com/albanD

Co-authored-by: Nikita Shulga <nshulga@meta.com>
2023-02-27 15:09:00 -08:00
66d23dbad7 fix spurious aot autograd warning (#95521) (#95614)
The _make_boxed logic probably needs a cleanup, but this fixes a spurious warning, and the fix should land before the release.

Confirmed that this used to emit a warning and no longer does:
```
import torch

lin = torch.nn.Linear(100, 10)
def f(x):
    return lin(x)

opt_f = torch.compile(f)
opt_f(torch.randn(10, 100, requires_grad=False))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95521
Approved by: https://github.com/ngimel
2023-02-27 14:31:02 -05:00
e2fff58844 [CUDA][CUBLAS] Explicitly link against cuBLASLt (#95094) (#95615)
An issue surfaced recently that revealed that we were never explicitly linking against `cuBLASLt`; this fixes it by linking explicitly rather than depending on linker magic.

CC @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95094
Approved by: https://github.com/malfet, https://github.com/ngimel, https://github.com/atalman

Co-authored-by: eqy <eddiey@nvidia.com>
2023-02-27 14:27:06 -05:00
735333a7ff Update triton hash (#95540) (#95577)
Fixes #95523

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95540
Approved by: https://github.com/ngimel
2023-02-27 09:03:50 -08:00
6017488801 [MPS] LSTM fixes (#95388)
* [MPS] Fix LSTM backward and forward pass (#95137)

Fixes #91694
Fixes #92615

Several transpositions were missing from the backward graph in the case of `batch_first=True`. Issue #91694 does not reproduce with `batch_first=False`.

After fixing the transpose issue, I finally thought I could use LSTM freely in my project. And then I got horrific results during training; it seems related to #92615.

After that I decided to fix LSTM's backward step completely. I collected all my findings in this thread, and it seems I succeeded.

Funnily enough, backward tests were completely disabled before and were not passing:
```python
    @unittest.skipIf(True, "Backward of lstm returns wrong result")
    def test_lstm_2(self, device="mps", dtype=torch.float32):
```

UPD: the forward pass of the multi-layer version was also wrong due to incorrect `initState, initCell` slices. Tests were passing because the states were initialized with zeros. *Accidentally* fixed this too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
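A small repro sketch for the `batch_first=True` backward path (assumes an MPS-capable machine; sizes are made up):

```python
import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True).to("mps")
x = torch.randn(4, 5, 8, device="mps", requires_grad=True)  # (batch, seq, feature)
out, _ = lstm(x)
out.sum().backward()  # previously produced wrong gradients with batch_first=True
```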

* Update the allowlist for lstm_mps_backward

* More update to the BC allowlist

---------

Co-authored-by: alexdremov <dremov.me@gmail.com>
Co-authored-by: albanD <desmaison.alban@gmail.com>
2023-02-25 14:04:15 -05:00
e51e5e721c [optim] Add general documentation on our algorithm defaults (#95391) (#95516)
I added a section + table under Algorithms
https://docs-preview.pytorch.org/95391/optim.html?highlight=optim#module-torch.optim
(screenshot of the new algorithm-defaults table omitted)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95391
Approved by: https://github.com/albanD
2023-02-25 14:02:57 -05:00
91739a0279 hotfix for memory leak in aot autograd induced by saving tensors for backward (#95101) (#95477)
Workaround fix in AOTAutograd for https://github.com/pytorch/pytorch/issues/94990 (see the comments for more details / discussion)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95101
Approved by: https://github.com/albanD
2023-02-24 16:40:30 -05:00
531f097b6f inductor: fix compiler error when trying to vectorize logical_and and logical_or (#95361) (#95439)
Currently, `operator&&` and `operator||` don't have a vectorized implementation; disable them for now as a quick fix for the 2.0 release.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95361
Approved by: https://github.com/ngimel, https://github.com/EikanWang
2023-02-24 09:23:29 -05:00
00eb7b0d78 [optim] Set defaults to foreach, NOT fused (#95241) (#95415)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.

Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
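Either implementation can still be requested explicitly; the defaults only apply when neither flag is passed (sketch, assumes a CUDA device):

```python
import torch

model = torch.nn.Linear(10, 10).cuda()
# Multi-tensor (foreach) path, which remains the default on CUDA:
opt_foreach = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)
# Fused kernel now requires opting in explicitly:
opt_fused = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
```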
2023-02-24 09:19:40 -05:00
2180f342c4 [SDPA] Fix bug in parsing scaled_dot_product_attention arguments (#95311) (#95397)
Fixes #95266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95311
Approved by: https://github.com/cpuhrsch
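For reference, a plain call to the function whose argument parsing is fixed (shapes are made up):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 16, 64)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
```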
2023-02-24 09:18:19 -05:00
a90b4f09ac use 4 warps for small block config in mm (#95383)
* use 4 warps for small block config in mm

* Update test/inductor/test_select_algorithm.py

* Update test/inductor/test_select_algorithm.py
2023-02-24 09:12:36 -05:00
1211ceeaa4 [MPS] Fix issues with max_pool2d (#95325)
* [MPS] Fix upsample for NHWC output  (#94963)

Fixes https://github.com/huggingface/diffusers/issues/941

**Before** / **After**: screenshots of the upsample output (omitted).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94963
Approved by: https://github.com/razarmehr

* [MPS] Move max_pool2d to mps dispatch key (#90772)

Related issue: #77394

This PR also modifies some assertions in the codegen, an explanatory comment for it has been added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90772
Approved by: https://github.com/albanD

* [MPS] Convert output back to ChannelsLast for MaxPool2D (#94877)

Since we re-stride the indices and output in MPS pooling from ChannelsLast to Contiguous, we need to convert the results back to ChannelsLast.
This will fix the failure with test_memory_format with MaxPool2D in test_modules.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94877
Approved by: https://github.com/kulinseth, https://github.com/DenisVieriu97
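A quick check of the ChannelsLast round trip described above (sketch, assumes an MPS device):

```python
import torch

pool = torch.nn.MaxPool2d(kernel_size=2)
x = torch.randn(1, 3, 8, 8, device="mps").to(memory_format=torch.channels_last)
y = pool(x)
# With the fix, the output should come back in the input's memory format.
assert y.is_contiguous(memory_format=torch.channels_last)
```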

---------

Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
2023-02-24 09:10:49 -05:00
beaa5c5908 [MPS] View fixes (#95323)
* [MPS] Fix the uint8 type issue with View ops kernels (#95145)

This should fix the problem in Resnet model with image artifacts due to saturation on int8 type and also the incorrect class recognition reported in #86954.

Fixes #86954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95145
Approved by: https://github.com/kulinseth, https://github.com/DenisVieriu97

* [MPS] Fix tensor with non-zero storage offset graph gathering (#91071)

Previously, the "can slice" flag in the Placeholder constructor in `OperationUtils.mm` was conditioned on whether the numbers of dimensions of the base shape and the view shape are the same. This doesn't consider the situation where a view tensor could be the base tensor's sliced and then unsqueezed version, resulting in a different number of dims.

For example, if we want to stack `y_mps` and `x_mps` on the last dim:
```
t_mps = torch.tensor([1, 2, 3, 4], device="mps")
x_mps = t_mps[2:]  # [3, 4]
y_mps = t_mps[:2]  # [1, 2]

res_mps = torch.stack((y_mps, x_mps), dim=-1)
```

the kernel will unsqueeze both of them on the last dim and then concatenate them, which is equivalent to:

```
res_mps = torch.cat((y_mps.unsqueeze(-1), x_mps.unsqueeze(-1)), dim=-1)
```

`x_mps.unsqueeze(-1)` is an unsqueezed and contiguous tensor with a storage offset; this kind of tensor should be sliceable without cloning its storage.

Fixes #87856
Fixes #91065

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91071
Approved by: https://github.com/kulinseth

* [MPS] Fix fill_ where input tensor has a storage offset (#95113)

Fixes #94390

Apart from fixing the issue above, this PR also fixes a bug where, when an input tensor can be sliced, a sliced array view is created. This array view seems to be either not writable or backed by different storage than the original tensor, causing incorrect results with the in-place `fill`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95113
Approved by: https://github.com/kulinseth
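A minimal repro sketch for the storage-offset case (assumes an MPS device):

```python
import torch

t = torch.zeros(5, device="mps")
view = t[2:]        # contiguous view with a non-zero storage offset
view.fill_(1.0)     # previously could write into the wrong storage
print(t)            # expected: tensor([0., 0., 1., 1., 1.], device='mps:0')
```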

* [MPS] Fix view op slicing for 2nd dim in case of 0 offset (#95381)

* Fix view op slicing for 2nd dim in case of 0 offset

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95381
Approved by: https://github.com/razarmehr

---------

Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com>
Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
2023-02-24 09:09:49 -05:00
4bd5c1e4f4 Fix warning if backend registers timer (#91702) (#95363)
Currently, the logger timer is registered by default for CPU/CUDA. For other backends it may or may not be registered, and the code reports a warning and returns, which is not expected: it is wrong when a backend has in fact registered the timer. For example, the HPU (Habana) backend registers this timer, yet the code still reports a warning and returns, which is incorrect.

The other case is a lazy backend whose timer is never registered; the warning-and-return path was added for that reason, but it breaks the case above.

Add a generic check: if the timer is registered, don't report the warning.

Signed-off-by: Jeeja <jeejakp@habana.ai>

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91702
Approved by: https://github.com/kit1980
2023-02-23 18:57:09 -05:00
f3c97a4e43 Raise error on 3.11 dynamo export (#95088) (#95396)
For https://github.com/pytorch/pytorch/issues/94914. Realized that `dynamo.export` doesn't immediately raise an error when dynamo tries to run on 3.11/Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95088
Approved by: https://github.com/weiwangmeta
2023-02-23 18:55:32 -05:00
30cf0e70f7 [MPS] Copy fixes for MPS backend (#95321)
* [MPS] Handle broadcasting by expanding src tensor in Copy.mm (#95272)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95272
Approved by: https://github.com/DenisVieriu97
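A small sketch of the broadcasting copy now handled by expanding `src` (assumes an MPS device):

```python
import torch

dst = torch.empty(2, 3, device="mps")
src = torch.arange(3.0)   # CPU tensor with fewer dims than dst
dst.copy_(src)            # src is expanded to dst's shape before the device copy
```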

* [MPS] Fix copy_cast_mps() on tensors with storage offset (#95093)

- The copy_cast path requires storage_offset to be applied before casting
- This should fix some correctness issues in transformer models

Fixes #94980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95093
Approved by: https://github.com/kulinseth

---------

Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
2023-02-23 18:17:20 -05:00
96f627dcde [MPS] Fixes in backward functions of the MPS ops (#95327)
* [MPS] Fix bilinear backward pass (#94892)

Fixes backward pass for bilinear.

Summary of changes:
- The bilinear op can produce **contiguous, non-view** tensors with a storage offset, such as shape=`[1, 1, 1, 1]` with `storage_offset=12`. This seems like a weird case, but it is valid, and for these kinds of tensors we wouldn't be able to gather/scatter since we look at the view flag (which is not set here). This change looks at `storage_offset` only, rather than the is_view flag, which is not set.
- **reduction sum** must return a zeroed-out output when passed an input with 0 elements (e.g. a shape of `(0, 5)`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94892
Approved by: https://github.com/kulinseth

* [MPS] Fix the crash in elu_backward() (#94923)

Fixes a crash where the inputTensor could go null.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94923
Approved by: https://github.com/DenisVieriu97, https://github.com/kulinseth

* [MPS] Fix prelu backward pass (#94933)

Allocate the correct shape for the weights gradient
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94933
Approved by: https://github.com/razarmehr

* [MPS] Fix embedding_backward() issue with Float16 (#94950)

- Cast the float16 input tensor to float32 and cast the output tensor back

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94950
Approved by: https://github.com/DenisVieriu97

---------

Co-authored-by: Denis Vieriu <dvieriu@apple.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
2023-02-23 16:28:33 -05:00
6f11e6d6a1 [MPS] Convolution fixes (#95318)
* [MPS] Convolution cleanup; remove unnecessary contiguous calls (#95078)

- Fixes convolution crashes in backward with weights
- Removes unnecessary contiguous calls
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95078
Approved by: https://github.com/kulinseth

* [MPS] Fix nn.functional.conv_transpose2d grad (#94871)

- add _mps_convolution_impl that takes optional shape
- for conv_transpose2d grad, use the shape from the forward pass directly
- for conv, calculate the shape from input
- remove nn.functional.conv_transpose2d grad from blocklist

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94871
Approved by: https://github.com/kulinseth

---------

Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
Co-authored-by: Denis Vieriu <dvieriu@apple.com>
2023-02-23 12:31:30 -05:00
fcec27f7d5 [MPS] Numerical stability and reduction fixes (#95317)
* [MPS] Fixes for LSTM. (#94889)

- The backward pass has to be given an explicit bias tensor of zeros if none is passed to the op, or the bias gradient will not be calculated.
- Fixed the bias tensor mistakenly getting overwritten to zeros.
- Fixes a crash when the lstm op is called with has_biases set to false. The change takes into account the changed shape of the input params TensorList depending on the bias flag.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94889
Approved by: https://github.com/DenisVieriu97

* [MPS] LogSoftmax numerical stability (#95091)

Fixes #94043

Calculations are now consistent with the numerically stable formula and with the CPU:

$LogSoftmax(X, \dim) = X - \max(X, \dim) - \log(sum(\exp(X - \max(X, \dim)), \dim))$

@malfet

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95091
Approved by: https://github.com/malfet, https://github.com/kulinseth
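The same numerically stable computation in eager PyTorch, as a sketch of what the MPS graph now does:

```python
import torch

def log_softmax_stable(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Subtract the per-slice max before exponentiating so exp() cannot overflow.
    shifted = x - x.amax(dim=dim, keepdim=True)
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()
```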

* [MPS] Cast int64 to int32 for reduction ops (#95231)

- warn when converting int64 for reduction ops
- use a cast tensor for reduction sum on trace
- unblock trace from running
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95231
Approved by: https://github.com/razarmehr

* [MPS] Fix Float16 issue with Reduction ops for macOS 12 (#94952)

This would fix the issue with `__rdiv__` with float16
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94952
Approved by: https://github.com/kulinseth

---------

Co-authored-by: alexdremov <dremov.me@gmail.com>
Co-authored-by: Denis Vieriu <dvieriu@apple.com>
Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
2023-02-23 12:27:40 -05:00
cddcb1e526 Raise error if torch.compile is called from windows or py 3.11 (#94940) (#95329)
For https://github.com/pytorch/pytorch/issues/94914

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94940
Approved by: https://github.com/albanD

Co-authored-by: William Wen <williamwen@fb.com>
2023-02-23 07:56:44 -05:00
83 changed files with 1336 additions and 587 deletions

View File

@ -1 +1 @@
d54c04abe2c3e67b2139c68cdbda87b59e8dd01b
b8b470bc597c1c5bd03682c09fe3e6b7c53787fd

View File

@ -1 +1 @@
pytorch-triton-rocm>=2.0.0.dev
pytorch-triton-rocm>=2.0.0,<2.1

View File

@ -38,7 +38,7 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
if build_conda:
with open(triton_basedir / "meta.yaml", "w") as meta:
print(f"package:\n name: torchtriton\n version: 2.0.0+{commit_hash[:10]}\n", file=meta)
print("package:\n name: torchtriton\n version: 2.0.0\n", file=meta)
print("source:\n path: .\n", file=meta)
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)

View File

@ -226,7 +226,8 @@ def generate_wheels_matrix(os: str,
"nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'",
"nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"build_name":
f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn"
.replace(

View File

@ -153,7 +153,7 @@ jobs:
- name: Checkout pytorch/builder to builder dir
uses: malfet/checkout@silent-checkout
with:
ref: main
ref: release/2.0
submodules: recursive
repository: pytorch/builder
path: builder

View File

@ -137,7 +137,7 @@ jobs:
run: |
set -ex
pip install -q awscli
s3_dir="${UPLOAD_BUCKET}/whl/nightly/"
s3_dir="${UPLOAD_BUCKET}/whl/test/"
for pkg in "${PKG_DIR}/"*.whl; do
aws s3 cp --no-progress --acl public-read "${pkg}" "${s3_dir}"
done
@ -193,7 +193,7 @@ jobs:
if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/master' || github.event.ref == 'refs/heads/main') }}
run: |
container_name=$(docker container ps --format '{{.ID}}')
docker exec -t "${container_name}" sh -c "anaconda upload /artifacts/torch*.tar.bz2 -u pytorch-nightly --label main --no-progress --force"
docker exec -t "${container_name}" sh -c "anaconda upload /artifacts/torch*.tar.bz2 -u pytorch-test --label main --no-progress --force"
- name: Chown artifacts
run: |

View File

@ -47,7 +47,7 @@ jobs:
DESIRED_PYTHON: "3.8"
build_name: manywheel-py3_8-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -169,7 +169,7 @@ jobs:
DESIRED_PYTHON: "3.8"
build_name: manywheel-py3_8-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -667,7 +667,7 @@ jobs:
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1165,7 +1165,7 @@ jobs:
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1663,7 +1663,7 @@ jobs:
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -54,8 +54,6 @@ TORCH_LIBRARY_IMPL(aten, MPS, m) {
m.impl("embedding_renorm_", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_fft_c2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_fft_r2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold
m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());

View File

@ -9,7 +9,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_mps_max_pool2d.h>
#include <ATen/ops/adaptive_avg_pool1d_native.h>
#include <ATen/ops/adaptive_avg_pool2d.h>
#include <ATen/ops/adaptive_max_pool1d_native.h>
@ -141,12 +140,6 @@ Tensor max_pool2d(
return at::mkldnn_max_pool2d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
#ifdef USE_MPS
if (self.is_mps()) {
return at::_mps_max_pool2d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
#endif
#if defined(C10_MOBILE)
if(xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
dilation, ceil_mode)) {

View File

@ -1428,7 +1428,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
}
#ifdef USE_MPS
if (_input.is_mps() && !bidirectional) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
return return_values;

View File

@ -138,4 +138,7 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
constantValue:(double) constantValue
name:(NSString * _Nullable) name;
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
@end

View File

@ -265,7 +265,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
bool sliceViewTensor = canSliceViewTensor(src, mpsShape);
// a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
if ((!src.is_contiguous() || (src.is_view() && src.storage_offset() && !sliceViewTensor)) && gatherTensorData) {
if ((!src.is_contiguous() || (src.storage_offset() && !sliceViewTensor)) && gatherTensorData) {
Tensor emptyShell = Tensor();
// use "_tensor" from Placeholder to retain view's output during its usage in other ops
_tensor = gatherViewTensor(src, emptyShell);
@ -289,7 +289,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
} else {
if (!mpsShape) {
mpsShape = getMPSShape(_tensor);
}
}
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape

View File

@ -311,11 +311,25 @@ TORCH_IMPL_FUNC(log_softmax_mps_out) (
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* softmaxTensor = [mpsGraph softMaxWithTensor:inputTensor
axis:dim
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph logarithmWithTensor:softmaxTensor
name:nil];
MPSGraphTensor* maximumsTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
axis:dim
name:nil];
MPSGraphTensor* inputTensorSubMax = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:maximumsTensor
name:nil];
MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:inputTensorSubMax
name:nil];
MPSGraphTensor* exponentTensorReduced = [mpsGraph reductionSumWithTensor:exponentTensor
axis:dim
name:nil];
MPSGraphTensor* logSumExpTensor = [mpsGraph logarithmWithTensor:exponentTensorReduced
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensorSubMax
secondaryTensor:logSumExpTensor
name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@ -1208,8 +1222,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *resultTensor_ = nil;
MPSGraphTensor *selfOrResultTensor_ = nil;
MPSGraphTensor *gradInputTensor_ = nil;
};
@ -1218,7 +1231,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
to_string(alpha.to<double>()) + ":" +
to_string(scale.to<double>()) + ":" +
to_string(input_scale.to<double>()) + ":" +
@ -1235,18 +1248,14 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* resultTensor = nil;
MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* lessThanZeroGradTensor = nil;
if(is_result) {
resultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to<double>()
shape:@[@1]
dataType:getMPSDataType(grad_output.scalar_type())];
MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:resultTensor
MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor
secondaryTensor:alphaTensor
name:nil];
auto constMul = scale.to<double>() * input_scale.to<double>();
@ -1258,11 +1267,10 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
name:nil];
}
else {
inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to<double>()
shape:@[@1]
dataType:getMPSDataType(grad_output.scalar_type())];
MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor
secondaryTensor:inputScaleTensor
name:nil];
MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor
@ -1282,7 +1290,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f
shape:@[@1]
dataType:getMPSDataType(grad_output.scalar_type())];
MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor
@ -1294,8 +1302,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
name:nil];
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->resultTensor_ = resultTensor;
newCachedGraph->selfOrResultTensor_ = selfOrResultTensor;
newCachedGraph->gradInputTensor_ = gradInputTensor;
}
return newCachedGraph;
@ -1304,28 +1311,14 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
}
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp);
Placeholder selfPlaceholder = Placeholder();
Placeholder resultPlaceholder = Placeholder();
if(is_result)
resultPlaceholder = Placeholder(cachedGraph->resultTensor_, self_or_result, nil, executeGatherOp);
else
selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp);
Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->selfOrResultTensor_, self_or_result, nil, executeGatherOp);
Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
if(is_result)
feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
};
else
feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
selfOrResultPlaceholder.getMPSGraphTensor() : selfOrResultPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
@ -1840,7 +1833,7 @@ std::tuple<Tensor, Tensor> prelu_backward_mps(const Tensor& grad_output, const T
using namespace mps;
Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous);
Tensor weight_grad = at::empty_like(self, at::MemoryFormat::Contiguous);
if (grad_output.numel() == 0) {
return std::tuple<Tensor, Tensor>{grad_input, weight_grad};
}

View File

@ -177,10 +177,6 @@ void div_mode_template(const Tensor& self, const Tensor& other,
c10::optional<c10::string_view> rounding_mode,
const Tensor& output, const string op_name)
{
if(rounding_mode.has_value() && *rounding_mode == "floor"){
TORCH_CHECK(self.scalar_type() != ScalarType::Long,
"MPS: does not support floor_divide op with int64 input");
}
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;

View File

@ -12,7 +12,7 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
}
Tensor output = self;
bool needsCopyToOutput = false;
if (!self.is_contiguous()) {
if (!self.is_contiguous() || self.storage_offset()) {
output = empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS);
needsCopyToOutput = true;
}
@ -89,7 +89,7 @@ bool fill_mps_tensor_(Tensor& self, uint8_t value) {
if (self.is_contiguous()) {
MPSStream* stream = getCurrentMPSStream();
auto storage_byte_offset = self.storage_offset() * self.itemsize();
stream->fill(mps::getMTLBufferStorage(self), 0, self.nbytes(), storage_byte_offset);
stream->fill(mps::getMTLBufferStorage(self), 0, self.storage().nbytes(), storage_byte_offset);
return true;
}
return false;

View File

@ -56,15 +56,17 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.groups = groups;
}
Tensor _mps_convolution(
Tensor _mps_convolution_impl(
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
int64_t groups,
c10::optional<IntArrayRef> input_shape) {
TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS");
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
namespace native_mps = at::native::mps;
CheckedFrom c = "mps_convolution";
@ -83,6 +85,8 @@ Tensor _mps_convolution(
auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto output_t = at::empty(
input_shape.has_value() ?
input_shape.value() :
conv_output_size(input->sizes(), weight->sizes(),
padding, stride, dilation),
input->scalar_type(),
@ -237,21 +241,30 @@ Tensor _mps_convolution(
return *output;
}
Tensor _mps_convolution(
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
}
Tensor mps_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_input";
TensorArg grad_output{ grad_output_, "grad_output", 1 },
weight{ weight_, "weight", 2 };
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
weight{ weight_t, "weight", 2 };
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});
auto memory_format = grad_output_.suggest_memory_format();
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
Tensor grad_output_t = grad_output_.contiguous(memory_format);
Tensor weight_t = weight_.contiguous(memory_format);
MPSShape* weightShape = getMPSShape(weight_);
auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
// Avoid "grad_input" when this is being used as transposed convolution
@ -327,10 +340,10 @@ Tensor mps_convolution_backward_input(
}
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
MPSGraphTensor* gradInputTensor;
@ -359,7 +372,7 @@ Tensor mps_convolution_backward_input(
}
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@ -377,17 +390,15 @@ Tensor mps_convolution_backward_input(
}
Tensor mps_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output_, const Tensor& input_,
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_weights";
auto memory_format = input_.suggest_memory_format();
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto grad_output_t = grad_output_.to(memory_format);
auto input_t = input_.to(memory_format);
MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);
// For uniformity with everything else, although it seems grad_weight
@ -475,7 +486,7 @@ Tensor mps_convolution_backward_weights(
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
@ -525,12 +536,9 @@ Tensor mps_convolution_backward_weights(
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,3> output_mask) {
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
Tensor grad_input, grad_weight, grad_bias;
if (input.numel() == 0) {
if (output_mask[0]) {
@ -576,10 +584,10 @@ Tensor _mps_convolution_transpose(
Tensor mps_convolution_transpose_backward_input(
const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups)
int64_t groups, IntArrayRef input_shape)
{
return at::_mps_convolution(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups);
return _mps_convolution_impl(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
}
Tensor mps_convolution_transpose_backward_weight(
@ -595,15 +603,12 @@ Tensor mps_convolution_transpose_backward_weight(
std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,2> output_mask) {
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups);
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
}
if (output_mask[1]) {
grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);

View File

@ -251,8 +251,11 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
bool returnGatherOutput = dst_.is_contiguous();
Tensor src;
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
const bool sameDataType = src_.dtype() == dst_.dtype();
if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) ||
// the copy_cast path requires storage_offset to be applied before casting
(src_.storage_offset() && !sameDataType)) {
Tensor emptyShell = Tensor();
src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);
@ -282,7 +285,7 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
src._set_neg(src_.is_neg());
const size_t src_size = src.nbytes();
if (src.dtype() == dst_.dtype()) {
if (sameDataType) {
MPSStream* stream = getCurrentMPSStream();
// for GPU to GPU copies we only encode to stream's command buffer (no flushing)
stream->copy(sourceBuffer, destBuffer, src_size, src_byte_offset, dst_byte_offset);
@ -297,22 +300,27 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
TORCH_CHECK(dst.defined(), "dst is undefined");
TORCH_CHECK(src.defined(), "src is undefined");
bool needs_broadcasting = false;
if (src.numel() == 0 || dst.is_same(src)) {
return dst;
}
if (dst.numel() == 0) {
dst.resize_as_(src);
}
if (dst.dim() > src.dim()) {
needs_broadcasting = true;
}
if (src.device().type() == at::kMPS && dst.device().type() == at::kCPU) {
return copy_from_mps_(dst, src, non_blocking);
return copy_from_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
if (src.device().type() == at::kCPU && dst.device().type() == at::kMPS) {
return copy_to_mps_(dst, src, non_blocking);
return copy_to_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
return copy_kernel_mps(dst, src, non_blocking);
return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
TORCH_INTERNAL_ASSERT(
src.device().type() == DeviceType::MPS,
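
A hedged illustration of the broadcasting handling added above: `Tensor.copy_` with a lower-dimensional source now expands the source to the destination's shape for CPU→MPS, MPS→CPU and MPS→MPS copies. Shapes and values below are arbitrary, and the sketch assumes an MPS device.

```python
import torch

if torch.backends.mps.is_available():
    dst = torch.empty(4, 3, device="mps")
    src = torch.arange(3, dtype=torch.float32)   # CPU tensor, shape (3,)
    dst.copy_(src)                               # src is expanded to (4, 3) before the copy
    dst_cpu = torch.empty(4, 3)
    dst_cpu.copy_(torch.ones(3, device="mps"))   # MPS -> CPU broadcast copy
    print(dst[0], dst_cpu[0])
```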


@ -886,19 +886,31 @@ Tensor embedding_dense_backward_mps(
MPSGraphTensor* reshapedIndicesTensor = indicesTensor;
MPSGraphTensor* castGradTensor = incomingGradTensor;
MPSDataType dataType = mps::getMPSDataType(grad_.scalar_type());
// issue 105486100, scatterNDWithUpdatesTensor produces wrong result for float16
if (dataType == MPSDataTypeFloat16) {
castGradTensor = [mpsGraph castTensor: incomingGradTensor
toType: MPSDataTypeFloat32
name: @"castGradTensor"];
}
if (num_indices_dims != 0) {
reshapedIndicesTensor = [mpsGraph expandDimsOfTensor: indicesTensor
axes: @[@-1]
name: nil];
}
auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: incomingGradTensor
auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: castGradTensor
indicesTensor: reshapedIndicesTensor
shape: native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape))
batchDimensions: 0
mode: MPSGraphScatterModeAdd
name: @"edb"];
if (dataType == MPSDataTypeFloat16) {
outgoingGradTensor = [mpsGraph castTensor: outgoingGradTensor
toType: MPSDataTypeFloat16
name: @"castGradTensor"];
}
newCachedGraph->incomingGradTensor_ = incomingGradTensor;
newCachedGraph->indicesTensor_ = indicesTensor;
newCachedGraph->outgoingGradTensor_ = outgoingGradTensor;
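
A small sketch of the float16 case the cast above works around (scatterNDWithUpdatesTensor producing wrong results for half precision). It only runs where an MPS device is present.

```python
import torch
import torch.nn as nn

if torch.backends.mps.is_available():
    emb = nn.Embedding(10, 4).to("mps").half()      # float16 weights
    idx = torch.tensor([1, 2, 2, 5], device="mps")
    emb(idx).sum().backward()                       # dense embedding backward (scatter-add)
    # The graph casts the incoming gradient to float32 for the scatter, then back to float16.
    print(emb.weight.grad.dtype)                    # torch.float16
```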


@ -83,6 +83,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
auto output_memory_format = output.suggest_memory_format();
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
// by simply restriding them (instead of calling the costly Contiguous()).
if (indices.suggest_memory_format() == MemoryFormat::ChannelsLast) {
@ -94,8 +95,9 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
outputSizes.insert(outputSizes.begin(), nbatch);
}
output.resize_(outputSizes);
} else if (output.suggest_memory_format() == MemoryFormat::ChannelsLast) {
} else if (output_memory_format == MemoryFormat::ChannelsLast) {
output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
output_memory_format = MemoryFormat::Contiguous;
}
if (output.numel() == 0 || (is_backward_pass && grad_output.numel() == 0)) {
@ -196,6 +198,10 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
}
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
if (output_memory_format != suggested_memory_format) {
const_cast<Tensor&>(output) = output.to(suggested_memory_format);
}
}
}
@ -302,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
} // namespace mps
Tensor _mps_max_pool2d(
Tensor mps_max_pool2d(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
@ -356,6 +362,8 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
const Tensor& output,
const Tensor& indices) {
auto indices_memory_format = indices.suggest_memory_format();
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor
@ -366,6 +374,10 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
};
mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride,
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices");
if (indices_memory_format == MemoryFormat::ChannelsLast) {
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
}
}
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
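
A sketch of the channels-last pooling case handled above; `return_indices=True` exercises the indices path that is converted back to channels-last after the graph runs. Assumes an MPS device.

```python
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    x = torch.randn(1, 3, 8, 8, device="mps").to(memory_format=torch.channels_last)
    out, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)
    # After the change above, indices should follow the input's channels_last layout.
    print(out.shape, idx.is_contiguous(memory_format=torch.channels_last))
```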


@ -139,6 +139,10 @@ void reduction_out_mps(
MPSReductionType reduction_type,
const std::string& func_name) {
// issue 103641234, reduction ops does not have int64 support
if (input_t.scalar_type() == ScalarType::Long) {
TORCH_WARN_ONCE("MPS: no support for int64 reduction ops, casting it to int32");
}
IntArrayRef input_shape = input_t.sizes();
if (opt_dim.has_value()) {
@ -163,6 +167,9 @@ void reduction_out_mps(
if (reduction_type == MPSReductionType::PROD) {
output_t.fill_(1);
}
else if (reduction_type == MPSReductionType::SUM) {
output_t.zero_();
}
return;
}
@ -197,7 +204,10 @@ void reduction_out_mps(
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt)) {
inputCastDtype = getMPSDataType(dtype.value());
} else if (input_type != MPSDataTypeInt32 &&
input_type != MPSDataTypeFloat32) {
input_type != MPSDataTypeFloat32 &&
input_type != MPSDataTypeFloat16) {
inputCastDtype = MPSDataTypeFloat32;
} else if (!is_macos_13_or_newer() && input_type == MPSDataTypeFloat16) {
inputCastDtype = MPSDataTypeFloat32;
}
@ -241,7 +251,7 @@ void reduction_out_mps(
axes:wrappedAxes
name:nil];
} else if (reduction_type == MPSReductionType::TRACE) {
MPSGraphTensor *bandPartWithTensor = [mpsGraph bandPartWithTensor:inputTensor
MPSGraphTensor *bandPartWithTensor = [mpsGraph bandPartWithTensor:castInputTensor
numLower:0
numUpper:0
name:nil];
@ -1257,7 +1267,9 @@ Tensor min_max_mps
(const Tensor& input_t,
MPSReductionType reduction_type,
const std::string& func_name) {
TORCH_WARN_ONCE(input_t.scalar_type() != ScalarType::Long, "MPS: no support for int64 min/max ops, casting it to int32");
if (input_t.scalar_type() == ScalarType::Long) {
TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
}
using CachedGraph = MPSUnaryCachedGraph;
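
Two tiny cases corresponding to the reduction changes above: int64 inputs warn once and are computed in int32 internally, and summing over an empty dimension now zero-fills the output. A sketch that assumes an MPS device.

```python
import torch

if torch.backends.mps.is_available():
    t = torch.arange(6, device="mps")       # int64 input: warns once, cast to int32 internally
    print(t.sum().item())                   # 15
    empty = torch.empty(0, 3, device="mps")
    print(empty.sum(dim=0))                 # zeros of shape (3,) after the SUM zero-fill above
```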


@ -233,7 +233,7 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
if (repeat.scalar_type() == kLong) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
TORCH_WARN_ONCE(false, "MPS: no support for int64 repeats mask, casting it to int32");
TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
@ -243,4 +243,4 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
return output;
}
} // namespace at::native
} // namespace at::native
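
The int64 repeats path mentioned in the warning above, in user terms; a sketch assuming an MPS device.

```python
import torch

if torch.backends.mps.is_available():
    x = torch.tensor([10, 20, 30], device="mps")
    repeats = torch.tensor([1, 2, 3], device="mps")   # int64 repeats, cast internally to int32
    print(torch.repeat_interleave(x, repeats))        # tensor([10, 20, 20, 30, 30, 30])
```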


@ -23,17 +23,31 @@ std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
return output_dimensions;
}
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
using namespace mps;
//Projections are not currently supported, raise an error if needed
bool has_projections = (hx[0].size(2) != hx[1].size(2));
if(has_projections) {
AT_ERROR("LSTM with projections is not currently supported with MPS.");
}
TORCH_CHECK(!(!is_macos_13_or_newer() && num_layers > 1), "Multi-layer LSTM support in MPS available only on MacOS 13 onwards");
std::vector<Tensor> kernel_weights;
std::vector<Tensor> recurrent_kernel_weights;
std::vector<Tensor> biases;
std::vector<Tensor> recurrent_biases;
for (size_t i = 0; i < num_layers; i+=1) {
kernel_weights.push_back(params[i*4]);
recurrent_kernel_weights.push_back(params[i*4+1]);
biases.push_back(params[i*4+2]);
recurrent_biases.push_back(params[i*4+3]);
if (has_biases) {
kernel_weights.push_back(params[i*4]);
recurrent_kernel_weights.push_back(params[i*4+1]);
biases.push_back(params[i*4+2]);
recurrent_biases.push_back(params[i*4+3]);
} else {
kernel_weights.push_back(params[i*2]);
recurrent_kernel_weights.push_back(params[i*2+1]);
}
}
struct CachedGraph : public MPSCachedGraph {
@ -44,8 +58,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
std::vector<MPSGraphTensor*> outputCellStateFwdVector_;
std::vector<MPSGraphTensor*> outputZStateVector_;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@ -67,12 +79,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];
for (size_t i = 0; i < num_layers; i += 1) {
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
if(has_biases) {
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
}
}
MPSGraphLSTMDescriptor * opDesc = [MPSGraphLSTMDescriptor descriptor];
@ -93,25 +108,28 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
}
MPSGraphTensor* inputTensor_ = inputTensor;
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:0
length:1
name:nil];
MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:0
length:1
name:nil];
NSArray<MPSGraphTensor*>* outputs = nil;
NSMutableArray<MPSGraphTensor*>* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
for(int i = 0; i < num_layers; i++) {
MPSGraphTensor* biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
secondaryTensor:recurrentBiasList[i]
name:nil];
MPSGraphTensor* biasTensor = nil;
if(has_biases) {
biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
secondaryTensor:recurrentBiasList[i]
name:nil];
}
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:i
length:1
name:nil];
MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:i
length:1
name:nil];
outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
recurrentWeight:recurrentKernelWeightsList[i]
inputWeight:kernelWeightsList[i]
@ -121,18 +139,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
descriptor:opDesc
name:nil];
stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:i
length:1
name:nil];
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:i
length:1
name:nil];
inputTensor_ = [outputs objectAtIndex:0];
// no need to keep a final layer output copy as it is
// returned anyway and not used in backprop
if(i != num_layers - 1) {
[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
axis:0
name:nil]];
}
if(dropout_p>0.0 && train && (i!=num_layers-1)) {
inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
rate:dropout_p
@ -150,7 +164,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
name:nil]];
}
MPSGraphTensor* outputTensor = [outputs objectAtIndex:0];
MPSGraphTensor* outputTensor = inputTensor_;
if (batch_first) {
outputTensor = [mpsGraph transposeTensor:outputTensor
dimension:0
@ -169,8 +183,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray
dimension:0
name:nil];
MPSGraphTensor* layersOutputs = (num_layers > 1)
? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil]
: nil;
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd};
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs};
newCachedGraph->inputTensors_ = inputTensors;
newCachedGraph->outputTensors_ = outputTensors;
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
@ -188,20 +205,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
Placeholder kernelWeight;
Placeholder recurrentKernelWeight;
Placeholder bias;
Placeholder recurrentBias;
Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias;
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
for (size_t i = 0; i < num_layers; i+=1) {
kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
if(has_biases) {
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
}
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input);
@ -218,6 +235,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
Tensor cy = at::empty_like(hx[1], input.options());
Tensor zState = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[3])), input.options());
Tensor cellStateFwd = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[4])), input.options());
Tensor layerOutputs = (num_layers > 1)
? at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[5])), input.options())
: at::empty({ 1 }, input.options()); // not used if num_layers == 1
Placeholder outputPlaceholder0 = Placeholder(cachedGraph->outputTensors_[0], output);
Placeholder outputPlaceholder1 = Placeholder(cachedGraph->outputTensors_[1], hy);
@ -225,20 +245,25 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
Placeholder outputPlaceholder3 = Placeholder(cachedGraph->outputTensors_[3], zState);
Placeholder outputPlaceholder4 = Placeholder(cachedGraph->outputTensors_[4], cellStateFwd);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [@{
outputPlaceholder0.getMPSGraphTensor() : outputPlaceholder0.getMPSGraphTensorData(),
outputPlaceholder1.getMPSGraphTensor() : outputPlaceholder1.getMPSGraphTensorData(),
outputPlaceholder2.getMPSGraphTensor() : outputPlaceholder2.getMPSGraphTensorData(),
outputPlaceholder3.getMPSGraphTensor() : outputPlaceholder3.getMPSGraphTensorData(),
outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData()
};
outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData(),
} mutableCopy];
if (num_layers > 1) {
Placeholder outputPlaceholder5 = Placeholder(cachedGraph->outputTensors_[5], layerOutputs);
[results setObject:outputPlaceholder5.getMPSGraphTensorData() forKey: outputPlaceholder5.getMPSGraphTensor()];
}
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::make_tuple(output, hy, cy, zState, cellStateFwd);
return std::make_tuple(output, hy, cy, zState, cellStateFwd, layerOutputs);
}
}
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, const Tensor& layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
using namespace mps;
const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] {return Tensor();});
const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] {return Tensor();});
@ -250,10 +275,15 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
std::vector<Tensor> biases;
std::vector<Tensor> recurrent_biases;
for (size_t i = 0; i < num_layers; i+=1) {
kernel_weights.push_back(params[i*4]);
recurrent_kernel_weights.push_back(params[i*4+1]);
biases.push_back(params[i*4+2]);
recurrent_biases.push_back(params[i*4+3]);
if(has_biases) {
kernel_weights.push_back(params[i*4]);
recurrent_kernel_weights.push_back(params[i*4+1]);
biases.push_back(params[i*4+2]);
recurrent_biases.push_back(params[i*4+3]);
} else {
kernel_weights.push_back(params[i*2]);
recurrent_kernel_weights.push_back(params[i*2+1]);
}
}
struct CachedGraph : public MPSCachedGraph {
@ -264,12 +294,12 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
NSMutableArray<MPSGraphTensor*> *gradOutput_ = nil;
NSMutableArray<MPSGraphTensor*> *gradRecWeights_ = nil;
NSMutableArray<MPSGraphTensor*> *gradWeights_ = nil;
NSMutableArray<MPSGraphTensor*> *gradBias_ = nil;
NSMutableArray<MPSGraphTensor*> *gradState_ = nil;
NSMutableArray<MPSGraphTensor*> *gradCellState_ = nil;
MPSGraphTensor* gradOutput_ = nil;
MPSGraphTensor* gradState_ = nil;
MPSGraphTensor* gradCellState_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@ -296,8 +326,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
for (size_t i = 0; i < num_layers; i += 1) {
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
if(has_biases) {
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
}
}
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(input));
@ -308,8 +340,22 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
MPSGraphTensor* gradientCyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy.scalar_type()), getMPSShape(grad_cy));
MPSGraphTensor* gradientHyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy.scalar_type()), getMPSShape(grad_hy));
MPSGraphTensor* cellStateFwdTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd.scalar_type()), getMPSShape(cell_state_fwd));
MPSGraphTensor* layersOutputsTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs.scalar_type()), getMPSShape(layersOutputs));
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor, layersOutputsTensor};
if (batch_first) {
inputTensor = [mpsGraph transposeTensor: inputTensor
dimension: 0
withDimension: 1
name: nil];
gradientTensor = [mpsGraph transposeTensor: gradientTensor
dimension: 0
withDimension: 1
name: nil];
}
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor};
newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList;
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
newCachedGraph->biasList_ = kernelBiasList;
@ -325,7 +371,6 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
NSArray<MPSGraphTensor*>* outputs = nil;
NSMutableArray<MPSGraphTensor*>* gradOutputArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradRecWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradBiasArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
@ -349,9 +394,15 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd
axis:0
name:nil];
MPSGraphTensor* biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
secondaryTensor:recurrentBiasList[i]
name:nil];
MPSGraphTensor* biasTensor = nil;
if(has_biases) {
biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
secondaryTensor:recurrentBiasList[i]
name:nil];
} else {
biasTensor = [mpsGraph constantWithScalar:0.0
dataType:inputTensor.dataType];
}
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
@ -375,7 +426,23 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
length:1
name:nil];
outputs = [mpsGraph LSTMGradientsWithSourceTensor: inputTensor
MPSGraphTensor* iterationInputTensor_ = nil;
if (i == 0) {
iterationInputTensor_ = inputTensor;
} else {
iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor
dimension: 0
// last element in layersOutputsTensor contains
// **inputs** for the last layer
start: i - num_layers
length: 1
name: nil];
iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_
axis:0
name: nil];
}
outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_
recurrentWeight: recurrentKernelWeightsList[i]
sourceGradient: gradientTensor_
zState: zState
@ -391,24 +458,31 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
descriptor: opDesc
name: nil];
gradientTensor_ = [outputs objectAtIndex:0];
[gradOutputArray addObject:[outputs objectAtIndex:0]];
[gradRecWeightsArray addObject:[outputs objectAtIndex:1]];
[gradWeightsArray addObject:[outputs objectAtIndex:2]];
[gradBiasArray addObject:[outputs objectAtIndex:3]];
[gradStateArray addObject:[outputs objectAtIndex:4]];
[gradCellStateArray addObject:[outputs objectAtIndex:5]];
[gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0];
[gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0];
[gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0];
[gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil] atIndex:0];
[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0];
}
std::vector<MPSGraphTensor*> outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]};
if (batch_first) {
MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_
dimension: 0
withDimension: 1
name:nil];
newCachedGraph->gradOutput_ = gradientTensorTransposed;
} else {
newCachedGraph->gradOutput_ = gradientTensor_;
}
newCachedGraph->outputTensors_ = outputTensors;
newCachedGraph->gradOutput_ = gradOutputArray;
newCachedGraph->gradRecWeights_ = gradRecWeightsArray;
newCachedGraph->gradWeights_ = gradWeightsArray;
newCachedGraph->gradBias_ = gradBiasArray;
newCachedGraph->gradState_ = gradStateArray;
newCachedGraph->gradCellState_ = gradCellStateArray;
newCachedGraph->gradState_ = [mpsGraph concatTensors:gradStateArray dimension: 0 name: nil];
newCachedGraph->gradCellState_ = [mpsGraph concatTensors:gradCellStateArray dimension: 0 name: nil];
}
return newCachedGraph;
});
@ -423,6 +497,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
Placeholder cellStateFwdPlaceholder = Placeholder(cachedGraph->inputTensors_[5], cell_state_fwd);
Placeholder gradientHyPlaceholder = Placeholder(cachedGraph->inputTensors_[6], grad_hy);
Placeholder gradientCyPlaceholder = Placeholder(cachedGraph->inputTensors_[7], grad_cy);
Placeholder layersOutputsPlaceholder = Placeholder(cachedGraph->inputTensors_[8], layersOutputs);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
[feeds setObject:gradientPlaceholder.getMPSGraphTensorData() forKey:gradientPlaceholder.getMPSGraphTensor()];
@ -433,6 +508,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
[feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()];
[feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()];
[feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() forKey:cellStateFwdPlaceholder.getMPSGraphTensor()];
[feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() forKey:layersOutputsPlaceholder.getMPSGraphTensor()];
NSMutableArray<MPSGraphTensor*> *kernelWeightsList = cachedGraph->kernelWeightsList_;
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
@ -445,68 +521,65 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
for (size_t i = 0; i < num_layers; i+=1) {
kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
if(has_biases) {
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
}
}
Tensor output = at::empty_like(input);
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[0]);
Tensor grad_weights = at::empty_like(kernel_weights[0]);
Tensor grad_bias = at::empty_like(biases[0]);
Tensor grad_state = at::empty_like(hx[0]);
Tensor grad_cell_state = at::empty_like(hx[1]);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensors_[0], output);
Placeholder gradRecWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[1], grad_rec_weights);
Placeholder gradWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[2], grad_weights);
Placeholder gradBiasPlaceholder = Placeholder(cachedGraph->outputTensors_[3], grad_bias);
Placeholder gradStatePlaceholder = Placeholder(cachedGraph->outputTensors_[4], grad_state);
Placeholder gradCellStatePlaceholder = Placeholder(cachedGraph->outputTensors_[5], grad_cell_state);
Tensor output_out = at::empty_like(input);
Tensor grad_state_out = at::empty_like(hx[0]);
Tensor grad_cell_state_out = at::empty_like(hx[1]);
std::vector<Tensor> grad_hx = {grad_state, grad_cell_state};
std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = [[[NSMutableDictionary alloc] init] autorelease];
NSMutableArray<MPSGraphTensor*> *gradOutputArray = cachedGraph->gradOutput_;
NSMutableArray<MPSGraphTensor*> *gradRecWeightsArray = cachedGraph->gradRecWeights_;
NSMutableArray<MPSGraphTensor*> *gradWeightsArray = cachedGraph->gradWeights_;
NSMutableArray<MPSGraphTensor*> *gradBiasArray = cachedGraph->gradBias_;
NSMutableArray<MPSGraphTensor*> *gradStateArray = cachedGraph->gradState_;
NSMutableArray<MPSGraphTensor*> *gradCellStateArray = cachedGraph->gradCellState_;
Placeholder gradOutPlaceholder;
MPSGraphTensor* gradOutput = cachedGraph->gradOutput_;
MPSGraphTensor* gradState = cachedGraph->gradState_;
MPSGraphTensor* gradCellState = cachedGraph->gradCellState_;
Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out);
Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
Placeholder outputPlaceholder = Placeholder(gradOutput, output_out);
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;
std::vector<Tensor> weights;
for (int i = 0; i < num_layers; i++) {
Tensor output = at::empty_like(input);
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]);
Tensor grad_weights = at::empty_like(kernel_weights[i]);
Tensor grad_bias = at::empty_like(biases[i]);
Tensor grad_state = at::empty_like(hx[0]);
Tensor grad_cell_state = at::empty_like(hx[1]);
Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
weights.push_back(grad_weights);
weights.push_back(grad_rec_weights);
weights.push_back(grad_bias);
weights.push_back(grad_bias);
gradOutPlaceholder = Placeholder([gradOutputArray objectAtIndex:i], output);
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:i], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:i], grad_weights);
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:i], grad_bias);
gradStatePlaceholder = Placeholder([gradStateArray objectAtIndex:i], grad_state);
gradCellStatePlaceholder = Placeholder([gradCellStateArray objectAtIndex:i], grad_cell_state);
[results setObject:gradOutPlaceholder.getMPSGraphTensorData() forKey:gradOutPlaceholder.getMPSGraphTensor()];
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
if(has_biases) {
weights.push_back(grad_bias);
weights.push_back(grad_bias);
}
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights);
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
[results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
[results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
}
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output, grad_hx, weights);
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output_out, grad_hx, weights);
}
}
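
To summarise what the LSTM changes above enable from the Python side: `bias=False` configurations and, on macOS 13 or newer, `num_layers > 1`, with per-layer outputs now carried through to the backward pass. A sketch that assumes an MPS device on macOS 13+.

```python
import torch
import torch.nn as nn

if torch.backends.mps.is_available():
    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
                   bias=False, batch_first=True).to("mps")
    x = torch.randn(4, 5, 8, device="mps", requires_grad=True)
    out, (h, c) = lstm(x)       # forward now also stashes intermediate layer outputs
    out.sum().backward()        # backward consumes layersOutputs for layers beyond the first
    print(out.shape, h.shape, c.shape)   # (4, 5, 16), (2, 4, 16), (2, 4, 16)
```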


@ -35,7 +35,9 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
indices.copy_(cpu_indices);
return;
}
TORCH_WARN_ONCE(self.scalar_type() != ScalarType::Long, "MPS: no support for int64 min/max ops, casting it to int32");
if (self.scalar_type() == ScalarType::Long) {
TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
}
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {


@ -75,15 +75,20 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
return inputTensor;
}
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
dataType:inputTensor.dataType];
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
return [mpsGraph selectWithPredicateTensor:predicateTensor
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
name:nil];
if(!is_macos_13_or_newer()) {
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
dataType:inputTensor.dataType];
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
return [mpsGraph selectWithPredicateTensor:predicateTensor
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
name:nil];
} else {
return [mpsGraph truncateWithTensor:inputTensor
name:nil];
}
};
} // namespace mps
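
The two rounding modes affected by the `trunc_tensor` change above, shown from the Python API: on macOS 13+ the native truncate op is used, otherwise the ceil/floor select emulation. Assumes an MPS device.

```python
import torch

if torch.backends.mps.is_available():
    a = torch.tensor([-7.0, 7.0], device="mps")
    b = torch.tensor([2.0, 2.0], device="mps")
    print(torch.div(a, b, rounding_mode="trunc"))   # tensor([-3., 3.])
    print(torch.div(a, b, rounding_mode="floor"))   # tensor([-4., 3.])
```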


@ -26,6 +26,11 @@ void upsample_out_template(const Tensor& input,
} else {
native::upsample_2d_common_check(input.sizes(), output_size);
}
Tensor out;
if (!output.is_contiguous()) {
out = at::empty_like(output, MemoryFormat::Contiguous);
}
bool centerResults = false;
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
@ -199,7 +204,7 @@ void upsample_out_template(const Tensor& input,
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
@ -209,6 +214,10 @@ void upsample_out_template(const Tensor& input,
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
if (out.has_storage()) {
output.copy_(out);
}
}
}
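
A hedged sketch of one way the non-contiguous-output path above can be reached: interpolating a channels-last input, whose output is not contiguous in the default layout and is therefore computed into a contiguous temporary and copied back. Assumes an MPS device.

```python
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    x = torch.randn(1, 3, 4, 4, device="mps").to(memory_format=torch.channels_last)
    y = F.interpolate(x, scale_factor=2, mode="nearest")
    print(y.shape, y.is_contiguous())   # (1, 3, 8, 8); False in the default layout
```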


@ -424,22 +424,54 @@ MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTen
}
static
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape) {
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) {
bool hasMPSShape = (mpsShape != nil);
std::vector<int64_t> src_view_shape;
if (hasMPSShape) {
int src_ndim_view = [mpsShape count];
src_view_shape.resize(src_ndim_view);
for (const auto i : c10::irange(src_ndim_view)) {
src_view_shape[i] = [mpsShape[i] intValue];
if (squeeze) {
for (const auto i : c10::irange(src_ndim_view)) {
if ([mpsShape[i] intValue] == 1)
continue;
src_view_shape.emplace_back([mpsShape[i] intValue]);
}
} else {
src_view_shape.resize(src_ndim_view);
for (const auto i : c10::irange(src_ndim_view)) {
src_view_shape[i] = [mpsShape[i] intValue];
}
}
} else {
src_view_shape = src.sizes().vec();
if (squeeze) {
IntArrayRef src_shape = src.sizes();
size_t src_ndim_view = src_shape.size();
for (const auto i : c10::irange(src_ndim_view)) {
if (src_shape[i] == 1)
continue;
src_view_shape.emplace_back(src_shape[i]);
}
} else {
src_view_shape = src.sizes().vec();
}
}
return src_view_shape;
}
std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) {
std::vector<int64_t> src_base_shape;
for (const auto i : c10::irange(shape.size())) {
if (shape[i] == 1)
continue;
src_base_shape.emplace_back(shape[i]);
}
return src_base_shape;
}
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
if (!src.is_contiguous()) {
return false;
@ -447,57 +479,79 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
size_t src_ndim_base = src_base_shape.size();
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
size_t src_ndim_view = src_view_shape.size();
if (src_ndim_base != src_ndim_view) {
return false;
}
for (const auto i: c10::irange(src_ndim_base)) {
if (src_view_shape[i] > src_base_shape[i]) {
return false;
}
}
if (src_view_shape[i] > src_base_shape[i]) {
return false;
}
}
return true;
}
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) {
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
int src_ndim_base = src_base_shape.size();
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
int src_ndim_view = src_view_shape.size();
TORCH_CHECK(src_ndim_base == src_ndim_view);
size_t src_ndim_base = src_base_shape.size();
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
size_t src_ndim_view = src_view_shape.size();
MPSNDArray *srcTensorNDArrayView = nil;
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
MPSNDArray *srcTensorNDArray = nil;
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
int64_t base_idx = 0;
std::vector<int64_t> src_base_shape_vec;
if (src_ndim_view != src_ndim_base) {
src_base_shape_vec.reserve(src_ndim_view);
for (const auto i : c10::irange(src_ndim_view)) {
if (src_view_shape[i] == 1 && src_base_shape[base_idx] != 1) {
src_base_shape_vec.emplace_back(1);
} else {
src_base_shape_vec.emplace_back(src_base_shape[base_idx]);
if (base_idx < src_ndim_base - 1)
base_idx += 1;
}
}
src_base_shape = IntArrayRef(src_base_shape_vec);
src_ndim_base = src_base_shape.size();
}
srcTensorNDArray = ndArrayFromTensor(src, getMPSShape(src_base_shape), mpsDataType);
srcTensorNDArrayDesc = srcTensorNDArray.descriptor;
int firstDimToSlice = 0;
size_t firstDimToSlice = 0;
while (src_base_shape[firstDimToSlice] == src_view_shape[firstDimToSlice]) {
firstDimToSlice++;
}
int view_numel = 1;
int64_t view_numel = 1;
for (const auto i : c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
view_numel *= src_base_shape[i];
}
int sliceOffset = src.storage_offset() / view_numel;
// There are cases where both dimensions of a view can shrink
// E.g: x = torch.randn((3,6))[1, 1:3]
int nextSliceOffset = src.storage_offset() % view_numel;
int64_t sliceOffset = src.storage_offset() / view_numel;
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
if (nextSliceOffset) {
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 2 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(nextSliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice+1])}];
// Slice any remaining dimensions
for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
if (crtSliceOffset == src_base_shape.size() - 1) {
sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
} else {
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
}
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
}
}
srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer
descriptor:srcTensorNDArrayDesc
aliasing:MPSAliasingStrategyShallAlias];
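
The in-code comment above gives the motivating case; spelled out as a runnable sketch (MPS device assumed), a view where more than one dimension shrinks relative to its base buffer:

```python
import torch

if torch.backends.mps.is_available():
    base = torch.arange(18, dtype=torch.float32, device="mps").reshape(3, 6)
    view = base[1, 1:3]      # both dimensions shrink, offset 7 into the base storage
    print(view + 0)          # materialising the view exercises the multi-dimension slice path
```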
@ -696,7 +750,7 @@ const std::string& getGatherScatterScalarType(const Tensor& t) {
{c10::ScalarType::Int, "int"},
{c10::ScalarType::Short, "short"},
{c10::ScalarType::Char, "char"},
{c10::ScalarType::Byte, "char"},
{c10::ScalarType::Byte, "uchar"},
{c10::ScalarType::Bool, "bool"},
};


@ -3567,19 +3567,14 @@
- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
# https://github.com/pytorch/pytorch/issues/77394
- func: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
dispatch:
MPS: _mps_max_pool2d
autogen: _mps_max_pool2d.out
CompositeImplicitAutograd: max_pool2d
MPS: mps_max_pool2d
- func: mps_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
dispatch:
MPS: mps_max_pool2d_backward
autogen: mps_max_pool2d_backward.out
autogen: max_pool2d_backward.out
- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
dispatch:
@ -7188,12 +7183,12 @@
# MPS LSTM implementation
- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
dispatch:
MPS: _lstm_mps
autogen: _lstm_mps.out
- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
dispatch:
MPS: lstm_mps_backward
autogen: lstm_mps_backward.out


@ -630,6 +630,7 @@ macro(cuda_unset_include_and_libraries)
unset(CUDA_cublas_LIBRARY CACHE)
unset(CUDA_cublas_device_LIBRARY CACHE)
unset(CUDA_cublasemu_LIBRARY CACHE)
unset(CUDA_cublasLt_LIBRARY CACHE)
unset(CUDA_cufft_LIBRARY CACHE)
unset(CUDA_cufftemu_LIBRARY CACHE)
unset(CUDA_cupti_LIBRARY CACHE)
@ -963,6 +964,7 @@ endif()
find_cuda_helper_libs(cufft)
find_cuda_helper_libs(cublas)
find_cuda_helper_libs(cublasLt)
# cusparse showed up in version 3.2
find_cuda_helper_libs(cusparse)
find_cuda_helper_libs(curand)
@ -993,7 +995,7 @@ if (CUDA_BUILD_EMULATION)
set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublasemu_LIBRARY})
else()
set(CUDA_CUFFT_LIBRARIES ${CUDA_cufft_LIBRARY})
set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY})
set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
endif()
########################
@ -1962,7 +1964,7 @@ macro(CUDA_ADD_CUBLAS_TO_TARGET target)
if (CUDA_BUILD_EMULATION)
target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublasemu_LIBRARY})
else()
target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY})
target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
endif()
endmacro()


@ -351,7 +351,7 @@ master_doc = 'index'
# General information about the project.
project = 'PyTorch'
copyright = '2022, PyTorch Contributors'
copyright = '2023, PyTorch Contributors'
author = 'PyTorch Contributors'
torch_version = str(torch.__version__)


@ -6,13 +6,12 @@ significant speedups the newer your GPU is.
.. code:: python
from torch._dynamo import optimize
import torch
def fn(x, y):
a = torch.cos(x).cuda()
b = torch.sin(y).cuda()
return a + b
new_fn = optimize("inductor")(fn)
new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor, input_tensor)
@ -54,7 +53,7 @@ with the actual generated kernel being
tmp2 = tl.sin(tmp1)
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
And you can verify that fusing the two ``sins`` did actually occur
And you can verify that fusing the two ``sin`` did actually occur
because the two ``sin`` operations occur within a single Triton kernel
and the temporary variables are held in registers with very fast access.
@ -69,13 +68,12 @@ hub.
.. code-block:: python
import torch
import torch._dynamo as dynamo
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = dynamo.optimize("inductor")(model)
opt_model = torch.compile(model, backend="inductor")
model(torch.randn(1,3,64,64))
And that is not the only available backend, you can run in a REPL
``dynamo.list_backends()`` to see all the available backends. Try out the
``torch._dynamo.list_backends()`` to see all the available backends. Try out the
``cudagraphs`` or ``nvfuser`` next as inspiration.
Let's do something a bit more interesting now; our community frequently
@ -92,11 +90,10 @@ HuggingFace hub and optimize it:
import torch
from transformers import BertTokenizer, BertModel
import torch._dynamo as dynamo
# Copy pasted from here https://huggingface.co/bert-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
model = dynamo.optimize("inductor")(model) # This is the only line of code that we changed
model = torch.compile(model, backend="inductor") # This is the only line of code that we changed
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
output = model(**encoded_input)
@ -116,7 +113,7 @@ Similarly lets try out a TIMM example
import torch._dynamo as dynamo
import torch
model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
opt_model = dynamo.optimize("inductor")(model)
opt_model = torch.compile(model, backend="inductor")
opt_model(torch.randn(64,3,7,7))
Our goal with Dynamo and inductor is to build the highest coverage ML compiler
@ -132,16 +129,16 @@ or ``torch._dynamo.list_backends()`` each of which with its optional dependencie
Some of the most commonly used backends include:
**Training & inference backends**:
* ``dynamo.optimize("inductor")`` - Uses ``TorchInductor`` backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
* ``dynamo.optimize("aot_ts_nvfuser")`` - nvFuser with AotAutograd/TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
* ``dynamo.optimize("nvprims_nvfuser")`` - nvFuser with PrimTorch. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
* ``dynamo.optimize("cudagraphs")`` - cudagraphs with AotAutograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
* ``torch.compile(m, backend="inductor")`` - Uses ``TorchInductor`` backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
* ``torch.compile(m, backend="aot_ts_nvfuser")`` - nvFuser with AotAutograd/TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
* ``torch.compile(m, backend=""nvprims_nvfuser")`` - nvFuser with PrimTorch. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
* ``torch.compile(m, backend="cudagraphs")`` - cudagraphs with AotAutograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
**Inference-only backends**:
* ``dynamo.optimize("onnxrt")`` - Uses ONNXRT for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
* ``dynamo.optimize("tensorrt")`` - Uses ONNXRT to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
* ``dynamo.optimize("ipex")`` - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
* ``dynamo.optimize("tvm")`` - Uses Apach TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__
* ``torch.compile(m, backend="onnxrt")`` - Uses ONNXRT for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
* ``torch.compile(m, backend="tensorrt")`` - Uses ONNXRT to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
* ``torch.compile(m, backend="ipex")`` - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
* ``torch.compile(m, backend="tvm")`` - Uses Apach TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__
Why do you need another way of optimizing PyTorch code?
-------------------------------------------------------


@ -15,7 +15,7 @@ Where a complete example looks like this:
from typing import List
import torch
import torchdynamo
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()


@ -14,7 +14,7 @@ worlds — usability and performance.
TorchDynamo makes it easy to experiment with different compiler
backends to make PyTorch code faster with a single line decorator
``torch._dynamo.optimize()``
``torch._dynamo.optimize()`` which is wrapped for convenience by ``torch.compile()``
.. image:: ../_static/img/dynamo/TorchDynamo.png


@ -27,7 +27,7 @@ TorchDynamo dependencies (for CUDA 11.7):
.. code-block:: shell
pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
pip3 install numpy --pre torch --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
CPU requirements
~~~~~~~~~~~~~~~~
@ -41,16 +41,6 @@ To install, run the following command:
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
Install from Local Source
~~~~~~~~~~~~~~~~~~~~~~~~~
Alternatively, you can build PyTorch from `source
<https://github.com/pytorch/pytorch#from-source>`__, which has TorchDynamo
included.
To install GPU TorchDynamo dependencies, run ``make triton`` in the
PyTorch repo root directory.
Verify Installation
~~~~~~~~~~~~~~~~~~~


@ -129,6 +129,49 @@ Algorithms
Rprop
SGD
Many of our algorithms have various implementations optimized for performance,
readability and/or generality, so we attempt to default to the generally fastest
implementation for the current device if no particular implementation has been
specified by the user.
We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and
fused. The most straightforward implementations are for-loops over the parameters with
big chunks of computation. For-looping is usually slower than our foreach
implementations, which combine parameters into a multi-tensor and run the big chunks
of computation all at once, thereby saving many sequential kernel calls. A few of our
optimizers have even faster fused implementations, which fuse the big chunks of
computation into one kernel. We can think of foreach implementations as fusing
horizontally and fused implementations as fusing vertically on top of that.
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop.
So when applicable, we default to foreach over for-loop. Applicable means the foreach
implementation is available, the user has not specified any implementation-specific kwargs
(e.g., fused, foreach, differentiable), and all tensors are native and on CUDA. Note that
while fused should be even faster than foreach, the implementations are newer and we would
like to give them more bake-in time before flipping the switch everywhere. You are welcome
to try them out though!
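
A short sketch of how a user opts into a specific implementation through the constructor kwargs the paragraph above refers to (`foreach`, `fused`); the fused variant requires floating-point CUDA parameters.

```python
import torch

params = [torch.nn.Parameter(torch.randn(4, 4))]

# No kwarg: the default from the table below applies (foreach when all params are CUDA tensors).
opt_default = torch.optim.Adam(params, lr=1e-3)

# Force the single-tensor for-loop implementation:
opt_forloop = torch.optim.Adam(params, lr=1e-3, foreach=False)

# Opt in to the fused kernel explicitly (CUDA floating-point params only):
if torch.cuda.is_available():
    cuda_params = [torch.nn.Parameter(torch.randn(4, 4, device="cuda"))]
    opt_fused = torch.optim.Adam(cuda_params, lr=1e-3, fused=True)
```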
Below is a table showing the available and default implementations of each algorithm:
.. csv-table::
:header: "Algorithm", "Default", "Has foreach?", "Has fused?"
:widths: 25, 25, 25, 25
:delim: ;
:class:`Adadelta`;foreach;yes;no
:class:`Adagrad`;foreach;yes;no
:class:`Adam`;foreach;yes;yes
:class:`AdamW`;foreach;yes;yes
:class:`SparseAdam`;for-loop;no;no
:class:`Adamax`;foreach;yes;no
:class:`ASGD`;foreach;yes;no
:class:`LBFGS`;for-loop;no;no
:class:`NAdam`;foreach;yes;no
:class:`RAdam`;foreach;yes;no
:class:`RMSprop`;foreach;yes;no
:class:`Rprop`;foreach;yes;no
:class:`SGD`;foreach;yes;no
How to adjust learning rate
---------------------------


@ -1024,17 +1024,12 @@ def main():
'typing-extensions',
'sympy',
'networkx',
'jinja2',
]
extras_require = {
'opt-einsum': ['opt-einsum>=3.3']
}
if platform.system() == 'Linux':
triton_pin_file = os.path.join(cwd, ".github", "ci_commit_pins", "triton.txt")
if os.path.exists(triton_pin_file):
with open(triton_pin_file) as f:
triton_pin = f.read().strip()
extras_require['dynamo'] = ['pytorch-triton==2.0.0+' + triton_pin[:10], 'jinja2']
# Parse the command line and check the arguments before we proceed with
# building deps and setup. We need to set values so `--help` works.


@ -504,7 +504,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
fsdp_kwargs=fsdp_kwargs,
deterministic=True,
)
optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
fsdp_kwargs["use_orig_params"] = True
fsdp_model_orig_params = TransformerWithSharedParams.init(
self.process_group,
@ -513,7 +513,9 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
fsdp_kwargs=fsdp_kwargs,
deterministic=True,
)
optim_orig_params = torch.optim.Adam(fsdp_model_orig_params.parameters(), lr=LR)
optim_orig_params = torch.optim.Adam(
fsdp_model_orig_params.parameters(), foreach=False, lr=LR
)
return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params
def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:

View File

@ -60,6 +60,11 @@ unittest.expectedFailure(
# Cannot call sizes() on tensor with symbolic sizes/strides
)
unittest.expectedFailure(
DynamicShapesMiscTests.test_parsing_sdpa_dynamic_shapes
# Cannot call sizes() on tensor with symbolic sizes/strides
)
# DynamicShapesSubGraphTests
unittest.expectedFailure(

View File

@ -3145,6 +3145,53 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertEqual(compiled.device.index, 0)
self.assertEqual(compiled.dtype, torch.float16)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
"Can't run fused SDPA on this platform",
)
def test_parsing_sdpa(self):
class MyModule(torch.nn.Module):
def forward(self, query, key, value):
out = F.scaled_dot_product_attention(query, key, value, None, 0, True)
out = F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=None,
dropout_p=0,
is_causal=True,
)
out = F.scaled_dot_product_attention(
query,
key=key,
value=value,
attn_mask=None,
dropout_p=0,
is_causal=True,
)
out = F.scaled_dot_product_attention(
query, key, value, None, dropout_p=0, is_causal=True
)
return out
device = "cuda"
dtype = torch.float16
seq_len_q = 1
seq_len_k = 1
head_dim = 8
query = torch.ones(
1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True
)
key = torch.ones(
1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
)
value = torch.ones(
1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
)
module = MyModule()
opt_mod = torch._dynamo.optimize("inductor")(module)
opt_mod(query, key, value)
def test_autocast_cpu(self):
class MyModule(torch.nn.Module):
def forward(self, x):

View File

@ -377,8 +377,6 @@ aten::_mps_convolution
aten::_mps_convolution.out
aten::_mps_convolution_transpose
aten::_mps_convolution_transpose.out
aten::_mps_max_pool2d
aten::_mps_max_pool2d.out
aten::_native_batch_norm_legit.no_stats_out
aten::_native_batch_norm_legit.out
aten::_native_decoder_only_multi_head_attention
@ -857,6 +855,8 @@ aten::max
aten::max.dim
aten::max.dim_max
aten::max.unary_out
aten::max_pool2d_backward
aten::max_pool2d_backward.out
aten::max_pool2d_with_indices
aten::max_pool2d_with_indices.out
aten::max_pool2d_with_indices_backward
@ -930,8 +930,6 @@ aten::mps_convolution_backward
aten::mps_convolution_backward.out
aten::mps_convolution_transpose_backward
aten::mps_convolution_transpose_backward.out
aten::mps_max_pool2d_backward
aten::mps_max_pool2d_backward.out
aten::multi_margin_loss
aten::multi_margin_loss.out
aten::multi_margin_loss_backward

View File

@ -150,6 +150,10 @@ ALLOW_LIST = [
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
("aten::mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),
@ -269,7 +273,10 @@ ALLOW_LIST = [
("aten::dsplit.int", datetime.date(2022, 9, 1)),
("aten::hsplit.array", datetime.date(2022, 9, 1)),
("aten::hsplit.int", datetime.date(2022, 9, 1)),
("aten::lstm_mps_backward.out", datetime.date(2022, 9, 1)),
("aten::lstm_mps_backward.out", datetime.date(2023, 9, 1)),
("aten::lstm_mps_backward", datetime.date(2023, 9, 1)),
("aten::_lstm_mps.out", datetime.date(2023, 9, 1)),
("aten::_lstm_mps", datetime.date(2023, 9, 1)),
("aten::miopen_rnn_backward.out", datetime.date(2022, 9, 1)),
("aten::quantize_per_tensor.tensors_out", datetime.date(2022, 9, 1)),
("aten::split", datetime.date(2022, 9, 1)),

View File

@ -989,6 +989,34 @@ def forward(self, primals_1, primals_2):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_mem_leak_from_save_for_bw(self):
# See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990
# Note [Detaching saved tensors in AOTAutograd]
# This program creates a ref-cycle. Long term, we should fix this ref cycle
# (since it can arise, naturally albeit rarely, from uses of autograd.Function).
# But AOTAutograd makes it more likely to show up from tracing user programs,
# so we deal with it by manually detaching the tensors that we save for backward.
# This is completely wrong and would give wrong results if we were to do double backward.
# Fortunately today, double backward is explicitly banned in AOTAutograd.
def f(a, b):
add = a + a
split = torch.functional.split(add, [4, 4], dim=1)
getitem_2 = split[1]
unsqueeze = getitem_2.unsqueeze(-1)
mul = unsqueeze * b
return (getitem_2, mul)
f_compiled = aot_function(f, nop)
inps = [
torch.ones(8, 8, device='cuda', requires_grad=True),
torch.ones(1, 4, 1, device='cuda', requires_grad=True),
]
mem_before = torch.cuda.memory_allocated()
f_compiled(*inps)
mem_after = torch.cuda.memory_allocated()
self.assertTrue(mem_after == mem_before)
@patch("functorch.compile.config.use_fake_tensor", True)
def test_output_aliases_multiple_inputs_get_correct_one(self):
# a and b are aliased, but have different shapes

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import contextlib
import sys
from unittest.mock import patch
import functorch
@ -10,6 +11,7 @@ from torch._dynamo.backends.registry import register_backend
from torch._inductor import metrics
from torch._inductor.compile_fx import compile_fx, count_bytes_inner
from torch.testing._internal.common_utils import (
IS_WINDOWS,
TEST_WITH_ROCM,
TestCase as TorchTestCase,
)
@ -23,9 +25,17 @@ def count_bytes_inductor(gm, example_inputs):
return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)
@torch._dynamo.optimize("count_bytes_inductor")
def f(x):
return torch.cat([x, x.cos()])
# TODO remove version check once dynamo supports 3.11
if sys.version_info < (3, 11) and not IS_WINDOWS:
@torch._dynamo.optimize("count_bytes_inductor")
def f(x):
return torch.cat([x, x.cos()])
else:
def f(x):
return torch.cat([x, x.cos()])
def count_numel(f, *args):

View File

@ -62,11 +62,30 @@ class TestSelectAlgorithm(TestCase):
def foo(input, weight, bias):
return torch.addmm(bias, input, weight)
foo(
inps = (
torch.randn(20, 33, device="cuda"),
torch.randn(33, 16, device="cuda"),
torch.randn(20, 16, device="cuda"),
)
foo(*inps)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
@patches
def test_addmm_fp16(self):
@torch.compile
def foo(input, weight, bias):
return torch.addmm(bias, input, weight)
inps = (
torch.randn(2, 320, device="cuda", dtype=torch.half),
torch.randn(320, 320, device="cuda", dtype=torch.half).t(),
torch.empty(320, device="cuda", dtype=torch.half),
)
foo(*inps)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

View File

@ -4127,18 +4127,38 @@ class CommonTemplate:
self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
@requires_decomp(aten.nll_loss_forward)
def test_nll_loss_forward(self):
def fn(a, b):
return aten.nll_loss_forward(a, b, None, 1, -100)
self.common(
fn,
(
torch.randn([5, 5]),
torch.zeros([5], dtype=torch.int64),
),
labels = (
torch.zeros([5], dtype=torch.int64),
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
)
inps = (torch.randn(5, 5), torch.randn(5, 5))
for a, b in zip(inps, labels):
self.common(
fn,
(a, b),
)
def test_nll_loss_backward(self):
def fn(a, b, c):
return aten.nll_loss_backward(
a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device)
)
labels = (
torch.zeros([5], dtype=torch.int64),
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
)
inps = (torch.randn(5, 5), torch.randn(5, 5))
grad_outs = (torch.randn(()), torch.randn(()))
for a, b, c in zip(grad_outs, inps, labels):
self.common(
fn,
(a, b, c),
)
def test_isinf(self):
def fn(x):
@ -5613,6 +5633,22 @@ class CommonTemplate:
eager_out = eager_mod(*eager_args)
self.assertEqual(inductor_out, eager_out)
def test_where_with_logical_op(self):
def fn_and(x, y):
return torch.where(torch.logical_and(x, y), 1.0, 0.0)
def fn_or(x, y):
return torch.where(torch.logical_or(x, y), 1.0, 0.0)
self.common(
fn_and,
(torch.randn(32), torch.randn(32)),
)
self.common(
fn_or,
(torch.randn(32), torch.randn(32)),
)
test_skips = {
"test_alexnet_prefix_dynamic_shapes": ("cuda",),
@ -5956,6 +5992,8 @@ if HAS_CPU:
"randn",
"isnan",
"rand",
"logical_and",
"logical_or",
]
union = {*cpp_vec_op_list, *diff}
self.assertTrue(set(cpp_op_list).issubset(union))

View File

@ -448,6 +448,7 @@ inductor_all_samples = {
"mT",
"mH",
"rsub",
"triu",
}

View File

@ -435,6 +435,27 @@ class TestMPS(TestCaseMPS):
helper(0, [1024])
helper(0.2, [2, 3])
def test_fill_storage_offset(self):
shape = [2, 10]
val = 0.2
tensor = torch.ones(shape, device="mps")
tensor_mps = tensor[:][1].fill_(val)
tensor_0 = torch.ones(shape, device="cpu")
tensor_cpu = tensor_0[:][1].fill_(val)
self.assertEqual(tensor_mps, tensor_cpu)
shape = [1, 10]
val = 0.0
tensor = torch.ones(shape, device="mps")
val_tensor_mps = torch.tensor(val, device="mps")
tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
tensor_0 = torch.ones(shape, device="cpu")
val_tensor_cpu = torch.tensor(val, device="cpu")
tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
self.assertEqual(tensor_mps, tensor_cpu)
def test_cdist_large(self, device="mps"):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(100, 10, device=device)
@ -1786,6 +1807,87 @@ class TestMPS(TestCaseMPS):
x_cpu = x_cpu + 2
self.assertEqual(x, x_cpu)
def test_reshape_storage_offset(self):
# https://github.com/pytorch/pytorch/issues/95883
B = 4
T = 1
lin_cpu = nn.Linear(10, 256)
lin_mps = nn.Linear(10, 256, device="mps")
# Use the same weights and bias as the ones from the cpu
lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
x_cpu = x_mps.detach().clone().cpu().requires_grad_()
x_mps = lin_mps(x_mps)
x_cpu = lin_cpu(x_cpu)
self.assertEqual(x_mps.shape, (B, T, 256))
self.assertEqual(x_cpu.shape, (B, T, 256))
cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
cls_token_cpu = cls_token_mps.detach().clone().cpu()
x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
x_mps = x_mps.transpose(0, 1)
x_cpu = x_cpu.transpose(0, 1)
target_mps = torch.rand_like(x_mps)
target_cpu = target_mps.detach().clone().cpu()
loss_mps = F.mse_loss(x_mps, target_mps)
loss_cpu = F.mse_loss(x_cpu, target_cpu)
self.assertEqual(loss_mps, loss_cpu)
loss_mps.backward()
loss_cpu.backward()
self.assertEqual(x_mps.grad, x_cpu.grad)
def test_stack(self):
# https://github.com/pytorch/pytorch/issues/87856
x_cpu = torch.tensor([[1, 2]])
x_mps = x_cpu.detach().clone().to("mps")
y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
self.assertEqual(y_cpu, y_mps)
t_mps = torch.tensor([1, 2, 3, 4], device="mps")
t_cpu = t_mps.detach().cpu().detach()
x_mps = t_mps[2:]
y_mps = t_mps[:2]
x_cpu = t_cpu[2:]
y_cpu = t_cpu[:2]
res_mps = torch.stack((y_mps, x_mps), dim=-1)
res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
self.assertEqual(res_mps, res_cpu)
def test_unsafe_chunk(self):
# https://github.com/pytorch/pytorch/issues/91065
a = torch.rand(5, dtype=torch.float32, device="cpu")
ret = a.unsafe_chunk(4, 0)
y = ret[0] * ret[2]
a_mps = a.to("mps")
ret_mps = a_mps.unsafe_chunk(4, 0)
y_mps = ret_mps[0] * ret_mps[2]
self.assertEqual(y, y_mps)
def test_slice_casting(self):
# generate random binary numbers
cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
mps_in = cpu_in.detach().clone().to("mps")
# check copy_cast(unit8 -> bool) on tensors with storage offset
cpu_out = cpu_in[:, :, 11 : 12, :12].to(torch.bool)
mps_out = mps_in[:, :, 11 : 12, :12].to(torch.bool)
self.assertEqual(cpu_out, mps_out)
def test_slice_reshape_contg_view(self):
import torch
@ -1797,6 +1899,72 @@ class TestMPS(TestCaseMPS):
self.assertEqual(r_mps, r_cpu)
def test_contiguous_slice_2d(self):
def helper(shape):
for i in range(0, shape[0]):
for j in range(0, shape[1]):
t_mps = torch.randn(shape, device="mps")
t_cpu = t_mps.detach().clone().cpu()
y_mps = t_mps[i:, :j]
y_cpu = t_cpu[i:, :j]
self.assertEqual(y_mps + 1, y_cpu + 1)
y_mps = t_mps[i:, j]
y_cpu = t_cpu[i:, j]
self.assertEqual(y_mps + 1, y_cpu + 1)
y_mps = t_mps[i, :j]
y_cpu = t_cpu[i, :j]
self.assertEqual(y_mps + 1, y_cpu + 1)
y_mps = t_mps[:i, :j]
y_cpu = t_cpu[:i, :j]
self.assertEqual(y_mps + 1, y_cpu + 1)
y_mps = t_mps[:i, j]
y_cpu = t_cpu[:i, j]
self.assertEqual(y_mps + 1, y_cpu + 1)
y_mps = t_mps[:i, j:]
y_cpu = t_cpu[:i, j:]
self.assertEqual(y_mps + 1, y_cpu + 1)
l = []
for N in range(1, 3):
l.append(N)
for C in range(1, 3):
l.append(C)
helper(l)
for D in range(1, 3):
l.append(D)
helper(l)
for H in range(1, 3):
l.append(H)
helper(l)
for W in range(1, 3):
l.append(W)
helper(l)
l.pop()
l.pop()
l.pop()
l.pop()
l.pop()
helper([9, 15, 4])
helper([9, 3, 2])
helper([3, 4, 18, 22])
helper([3, 4, 18, 22, 150])
def test_contiguous_slice_3d(self):
x = torch.randn(2, 3, 3, device="mps")
x_cpu = x.detach().clone().cpu()
x = x[:1]
x_cpu = x_cpu[:1]
out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
self.assertEqual(out, out_cpu)
def test_view_slice(self):
# https://github.com/pytorch/pytorch/issues/83995
NUM_SAMPLES = 60
@ -1890,25 +2058,28 @@ class TestMPS(TestCaseMPS):
if operator == "<=":
res_mps = x_mps <= y_mps
res_cpu = x_cpu <= y_cpu
if operator == "<":
elif operator == "<":
res_mps = x_mps < y_mps
res_cpu = x_cpu < y_cpu
if operator == ">=":
elif operator == ">=":
res_mps = x_mps >= y_mps
res_cpu = x_cpu >= y_cpu
if operator == ">":
elif operator == ">":
res_mps = x_mps > y_mps
res_cpu = x_cpu > y_cpu
if operator == "==":
elif operator == "==":
res_mps = x_mps == y_mps
res_cpu = x_cpu == y_cpu
if operator == "!=":
elif operator == "!=":
res_mps = x_mps != y_mps
res_cpu = x_cpu != y_cpu
elif operator == "stack":
res_mps = torch.stack((y_mps, x_mps), dim=-1)
res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
self.assertEqual(res_mps, res_cpu)
for op in ["<=", "<", ">=", ">", "==", "!="]:
for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
helper(op)
def test_slice_of_slice(self):
@ -2245,10 +2416,9 @@ class TestMPS(TestCaseMPS):
# See https://github.com/pytorch/pytorch/issues/84995
def test_div_bugs(self):
for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
if dtype != torch.int64:
x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
y = torch.div(x, 101, rounding_mode=mode)
self.assertEqual(y.sum(), 0)
x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
y = torch.div(x, 101, rounding_mode=mode)
self.assertEqual(y.sum(), 0)
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
@ -3358,6 +3528,26 @@ class TestNLLLoss(TestCaseMPS):
self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
def test_log_softmax_large_numbers(self):
values = [
[10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
[-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
]
cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
mps_x = torch.tensor(values, device='mps', requires_grad=True)
cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
mps_log_softmax = F.log_softmax(mps_x, dim=-1)
self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
cpu_grad = torch.ones_like(cpu_log_softmax)
mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
cpu_log_softmax.backward(gradient=cpu_grad)
mps_log_softmax.backward(gradient=mps_grad)
self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
def test_eq(self):
values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
@ -4545,9 +4735,9 @@ class TestNLLLoss(TestCaseMPS):
)
def test_upsample_nearest2d(self):
def helper(N, C, H, W):
def helper(N, C, H, W, memory_format):
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
requires_grad=True).reshape(N, C, H, W)
requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
inputCPU.retain_grad()
inputMPS = inputCPU.detach().to('mps').requires_grad_()
@ -4573,8 +4763,9 @@ class TestNLLLoss(TestCaseMPS):
self.assertEqual(inputCPU.grad, inputMPS.grad)
helper(1, 1, 4, 4)
helper(7, 5, 3, 2)
for memory_format in [torch.channels_last, torch.contiguous_format]:
helper(1, 1, 4, 4, memory_format=memory_format)
helper(7, 5, 3, 2, memory_format=memory_format)
def test_upsample_bilinear2d(self):
def helper(N, C, H, W):
@ -7716,7 +7907,8 @@ class TestConvolutionMPS(TestCaseMPS):
def test_conv_backward_1d_channels_last(self):
def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
# https://github.com/pytorch/pytorch/issues/84511
conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups)
conv_cpu = torch.nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
conv_mps = torch.nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
@ -7756,15 +7948,89 @@ class TestConvolutionMPS(TestCaseMPS):
def test_conv2d_all_strides_paddings(self):
# https://github.com/pytorch/pytorch/issues/83180
y_cpu = torch.randn(2, 2, 3, 6)
y_gpu = y_cpu.to(device='mps')
for strideX in range(1, 4):
for strideY in range(1, 4):
conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=(strideX, strideY))
conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
x_cpu = conv_cpu(y_cpu)
x_gpu = conv_gpu(y_gpu)
self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()
if permute_data:
x_cpu.permute(0, 2, 3, 1)
x_mps.permute(0, 2, 3, 1)
for strideX in range(1, 4):
for strideY in range(1, 4):
conv_cpu = torch.nn.Conv2d(
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()
conv_mps = torch.nn.Conv2d(
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()
res_cpu = conv_cpu(x_cpu)
res_mps = conv_mps(x_mps)
self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
res_cpu = res_cpu.sum().backward()
res_mps = res_mps.sum().backward()
self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)
self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
self.assertEqual(x_cpu.grad, x_mps.grad)
for mem_format_input in [torch.contiguous_format, torch.channels_last]:
for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
for permute_data in [True, False]:
helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
def test_conv_transpose_2d_strided(self):
def helper(m_cpu, memory_format):
m_mps = copy.deepcopy(m_cpu).requires_grad_()
m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()
input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
input_mps = input_cpu.detach().clone().to("mps")
output_cpu = m_cpu(input_cpu)
output_mps = m_mps(input_mps)
self.assertEqual(output_cpu, output_mps)
for mem_format_input in [torch.contiguous_format, torch.channels_last]:
# With square kernels and equal stride
helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)
# non-square kernels and unequal stride and with padding
helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)
def test_conv_transpose_2d_specified_output(self):
input_cpu = torch.randn(1, 16, 12, 12)
input_mps = input_cpu.detach().clone().to("mps")
downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
h_cpu = downsample_cpu(input_cpu)
h_mps = downsample_mps(input_mps)
self.assertEqual(h_cpu, h_mps)
size_cpu = h_cpu.size()
size_mps = h_mps.size()
self.assertEqual(size_cpu, size_mps)
output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
output_mps = upsample_mps(h_mps, output_size=input_mps.size())
self.assertEqual(output_cpu, output_mps)
self.assertEqual(output_cpu.size(), output_mps.size())
def test_conv2d_single_stride(self):
y_cpu = torch.randn(2, 2, 3, 6)
@ -8822,64 +9088,148 @@ class TestAdvancedIndexing(TestCaseMPS):
class TestRNNMPS(TestCaseMPS):
def test_lstm_1(self, device="mps", dtype=torch.float32):
for layers in [1] if product_version < 13.0 else [1, 2, 5]:
torch.random.manual_seed(42)
rnn = nn.LSTM(7, 4, layers, device="cpu")
input = torch.randn(2, 3, 7, device="cpu")
hx = torch.randn(layers, 3, 4, device="cpu")
cx = torch.randn(layers, 3, 4, device="cpu")
rnn = nn.LSTM(1, 4, 2, device="cpu")
input = torch.randn(2, 3, 1, device="cpu")
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
# test batch_first
rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True)
input = torch.randn(3, 2, 7, device="cpu")
hx = torch.randn(layers, 3, 4, device="cpu")
cx = torch.randn(layers, 3, 4, device="cpu")
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
# test batch_first
rnn = nn.LSTM(1, 4, 2, device="cpu", batch_first=True)
input = torch.randn(3, 2, 1, device="cpu")
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
def test_lstm_backward(self, device="mps", dtype=torch.float32):
for layers in [1] if product_version < 13.0 else [1, 2, 5]:
lstm = nn.LSTM(2, 4, layers) # initialized globally for consistent parameters init
lstm.train()
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
def get_results(device):
rnn = nn.LSTM(1, 4, 1, device=device)
inp = torch.randn(2, 3, 1, device=device, requires_grad=True)
hx = torch.zeros(1, 3, 4, device=device)
cx = torch.zeros(1, 3, 4, device=device)
def get_results(device, inp, hx, cx):
rnn = lstm.to(device)
inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
output, _ = rnn(inp, (hx, cx))
output.sum().backward()
output, _ = rnn(inp, (hx, cx))
f = output.sum()
weight_grad = rnn.weight_ih_l0.grad.clone()
input_grad = inp.grad.clone()
param_names, params = zip(*rnn.named_parameters())
param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
return output, weight_grad, input_grad
input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
return output, param_grads, input_grad, hx_grad, cx_grad
inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device)
hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
self.assertEqual(cpu_hx_grad, mps_hx_grad)
self.assertEqual(cpu_cx_grad, mps_cx_grad)
self.assertEqual(cpu_output, mps_output)
self.assertEqual(cpu_input_grad, mps_input_grad)
for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
# test batch_first backward
lstm = nn.LSTM(2, 4, layers, batch_first=True)
lstm.train()
hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
self.assertEqual(cpu_hx_grad, mps_hx_grad)
self.assertEqual(cpu_cx_grad, mps_cx_grad)
self.assertEqual(cpu_output, mps_output)
self.assertEqual(cpu_input_grad, mps_input_grad)
for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
def test_RNN_cell_no_broadcasting(self):
def test(cell_module, input, hx, input_size, hidden_size):
cell = cell_module(input_size, hidden_size, device='mps')
self.assertRaises(RuntimeError, lambda: cell(input, hx))
def test_all(hidden_size, bad_hx, good_hx, input_size, input):
test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
hidden_size = 20
input_size = 10
input = torch.randn(3, input_size, device='mps')
bad_hx = torch.randn(1, hidden_size, device='mps')
good_hx = torch.randn(3, hidden_size, device='mps')
# Test hidden/input batch size broadcasting
test_all(hidden_size, bad_hx, good_hx, input_size, input)
# Test hx's hidden_size vs module's hidden_size broadcasting
bad_hx = torch.randn(3, 1)
test_all(hidden_size, bad_hx, good_hx, input_size, input)
# Test input's input_size vs module's input_size broadcasting
bad_input = torch.randn(3, 1)
test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
def test_LSTM_cell(self):
# this is just a smoke test; these modules are implemented through
# autograd so no Jacobian test is needed
for bias in (True, False):
input = torch.randn(3, 10, device='mps')
hx = torch.randn(3, 20, device='mps')
cx = torch.randn(3, 20, device='mps')
lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
for _ in range(6):
hx, cx = lstm(input, (hx, cx))
(hx + cx).sum().backward()
def test_LSTM_cell_forward_input_size(self):
input = torch.randn(3, 11, device='mps')
hx = torch.randn(3, 20, device='mps')
cx = torch.randn(3, 20, device='mps')
lstm = nn.LSTMCell(10, 20, device='mps')
self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
def test_LSTM_cell_forward_hidden_size(self):
input = torch.randn(3, 10, device='mps')
hx = torch.randn(3, 21, device='mps')
cx = torch.randn(3, 20, device='mps')
lstm = nn.LSTMCell(10, 20, device='mps')
self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
self.assertEqual(cpu_output, mps_output)
self.assertEqual(cpu_input_grad, mps_input_grad)
self.assertEqual(cpu_weight_grad, mps_weight_grad)
class TestFallbackWarning(TestCase):
# TODO: Remove once test_testing.py is running on MPS devices
@ -9152,6 +9502,7 @@ class TestConsistency(TestCaseMPS):
'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'linalg.matrix_norm': ['f16'],
'linalg.matrix_power': ['f32'],
'linalg.svd': ['f32'],
'linalg.vector_norm': ['f16', 'f32'],
'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@ -9315,6 +9666,7 @@ class TestConsistency(TestCaseMPS):
'nn.functional.bilinear': ['f32'],
'linalg.solve_triangular': ['f32'],
'triangular_solve': ['f32'],
'trace': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'_native_batch_norm_legit': ['f32'],
'native_batch_norm': ['f32'],
'minreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@ -9511,6 +9863,8 @@ class TestConsistency(TestCaseMPS):
'native_batch_norm': ['f32'],
'native_layer_norm': ['f32'],
'nn.functional.gelu': ['f32'],
'nn.functional.bilinear': ['f32'],
'nn.functional.prelu': ['f32'],
}
# These ops that are problematic. So never run them even when
@ -9530,7 +9884,6 @@ class TestConsistency(TestCaseMPS):
'stft': [torch.float32], 'var': [torch.float16],
# + forward when requires_grad=True or running backward
'nn.functional.embedding': [torch.float32, torch.float16],
'__rpow__': [torch.int64],
'as_strided_scatter': [torch.uint8],
'atan2': [torch.int64],

View File

@ -46,9 +46,10 @@ from torch.testing._internal.common_utils import (
skipIfRocm,
skipIfTorchDynamo
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
from unittest.mock import patch
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -252,21 +253,26 @@ class TestOptim(TestCase):
)
# Make sure that optimizers that support maximize can load older models
state_dict = optimizer.state_dict()
if "maximize" in state_dict["param_groups"][0]:
for group in state_dict["param_groups"]:
old_state_dict = deepcopy(optimizer.state_dict())
state_dict_no_maximize = deepcopy(optimizer.state_dict())
if "maximize" in state_dict_no_maximize["param_groups"][0]:
for group in state_dict_no_maximize["param_groups"]:
del group["maximize"]
optimizer.load_state_dict(state_dict)
optimizer.load_state_dict(state_dict_no_maximize)
# Make sure we can still step
optimizer.step()
# Undo these changes before proceeding!
optimizer.load_state_dict(old_state_dict)
# Make sure that optimizers that support foreach can load older models
state_dict = optimizer.state_dict()
if "foreach" in state_dict["param_groups"][0]:
for group in state_dict["param_groups"]:
state_dict_no_foreach = deepcopy(optimizer.state_dict())
if "foreach" in state_dict_no_foreach["param_groups"][0]:
for group in state_dict_no_foreach["param_groups"]:
del group["foreach"]
optimizer.load_state_dict(state_dict)
optimizer.load_state_dict(state_dict_no_foreach)
# Make sure we can still step
optimizer.step()
# Undo these changes before proceeding!
optimizer.load_state_dict(old_state_dict)
# Make sure that loading optimizers with step not wrapped in tensor can work
state_dict = optimizer.state_dict()
@ -4535,5 +4541,39 @@ class TestDifferentiableOptimizer(TestCase):
)
@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_defaults_changed_to_foreach(self):
from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
asgd, adamax, adadelta, adagrad)
multi_optims = ((optim.Adam, adam, "_multi_tensor_adam"),
(optim.AdamW, adamw, "_multi_tensor_adamw"),
(optim.NAdam, nadam, "_multi_tensor_nadam"),
(optim.SGD, sgd, "_multi_tensor_sgd"),
(optim.RAdam, radam, "_multi_tensor_radam"),
(optim.RMSprop, rmsprop, "_multi_tensor_rmsprop"),
(optim.Rprop, rprop, "_multi_tensor_rprop"),
(optim.ASGD, asgd, "_multi_tensor_asgd"),
(optim.Adamax, adamax, "_multi_tensor_adamax"),
(optim.Adadelta, adadelta, "_multi_tensor_adadelta"),
(optim.Adagrad, adagrad, "_multi_tensor_adagrad"),)
model = torch.nn.Linear(5, 5)
model.to(dtype=torch.float64, device="cuda")
input = torch.rand(2, 5, dtype=torch.float64, device="cuda")
for opt, mod, func in multi_optims:
defaults = {}
if opt == optim.SGD:
defaults["lr"] = 1e-2
optimizer = opt(model.parameters(), **defaults)
optimizer.zero_grad()
output = model(input)
loss = output.sum()
loss.backward()
with patch.object(mod, func) as mocked_foreach_impl:
optimizer.step()
self.assertTrue(mocked_foreach_impl.called)
if __name__ == "__main__":
run_tests()

View File

@ -736,6 +736,15 @@ class SerializationMixin:
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a.storage(), s_bytes], f)
def test_safe_load_basic_types(self):
with tempfile.NamedTemporaryFile() as f:
data = {"int": 123, "str": "world", "float": 3.14, "bool": False}
torch.save(data, f)
f.seek(0)
loaded_data = torch.load(f, weights_only=True)
self.assertEqual(data, loaded_data)
class serialization_method:
def __init__(self, use_zip):
self.use_zip = use_zip

View File

@ -2170,8 +2170,8 @@
input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)
#mps
- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
@ -2564,11 +2564,11 @@
input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
#LSTM MPS
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
output_differentiability: [True, True, True, False, False]
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
output_differentiability: [True, True, True, False, False, False]
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])

View File

@ -215,6 +215,10 @@ def main():
f"ROCM version: {rocm_ver}\n"
)
for args in _SANITY_CHECK_ARGS:
# TODO remove check when 3.11 is supported
if sys.version_info >= (3, 11):
warnings.warn("Dynamo not yet supported in Python 3.11. Skipping check.")
continue
check_dynamo(*args)
print("All required checks passed")

View File

@ -308,6 +308,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.trace,
aten.transpose.int,
aten.tril.default,
aten.triu.default,
aten.unfold,
aten.unfold_backward,
aten.upsample_bilinear2d,

View File

@ -405,8 +405,9 @@ def _nll_loss_backward(
grad_output = grad_output / total_weight
target = target.unsqueeze(channel_dim)
safe_target = torch.where(target != ignore_index, target, 0)
grad_input = torch.zeros_like(self)
grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
if grad_input.dim() > grad_output.dim() > 0:
grad_output = grad_output.unsqueeze(channel_dim)
@ -417,9 +418,7 @@ def _nll_loss_backward(
weight = weight.reshape(new_shape)
grad_output = grad_output * weight
has_ignore_index = ignore_index >= 0
if has_ignore_index:
grad_output = torch.where(target != ignore_index, grad_output, 0)
grad_output = torch.where(target != ignore_index, grad_output, 0)
return grad_input * grad_output
@ -2798,14 +2797,13 @@ def nll_loss_forward(
if weight is not None:
w = weight.unsqueeze(0) if n_dims > 1 else weight
self = self * w
target_ = target.unsqueeze(channel_dim)
safe_target = torch.where(target != ignore_index, target, 0)
safe_target_ = safe_target.unsqueeze(channel_dim)
# target can be [N, 1] or [1]
result = -torch.gather(self, channel_dim, target_).squeeze(channel_dim)
result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
if ignore_index >= 0:
result = torch.where(target != ignore_index, result, 0)
result = torch.where(target != ignore_index, result, 0)
if reduction == Reduction.NONE.value and n_dims > 1:
total_weight = self.new_full((), 0.0)
@ -2813,22 +2811,16 @@ def nll_loss_forward(
if weight is not None:
w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
wsum = torch.gather(w, channel_dim, target_).squeeze(channel_dim)
if ignore_index >= 0:
wsum = torch.where(target != ignore_index, wsum, 0)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)
total_weight = wsum.sum()
elif ignore_index >= 0:
total_weight = (target != ignore_index).sum().to(self)
else:
total_weight = self.new_full((), 1.0 * result.numel())
total_weight = (target != ignore_index).sum().to(self)
if reduction == Reduction.SUM.value:
result = result.sum()
elif reduction == Reduction.MEAN.value:
if weight is None:
result = result.sum() / total_weight if ignore_index >= 0 else result.mean()
else:
result = result.sum() / total_weight
result = result.sum() / total_weight
return result, total_weight
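A minimal sketch of the safe-target trick used in the decomposition above, with made-up inputs (unweighted, mean reduction): gather with a clamped index so ignored rows stay in bounds, then mask their contribution out afterwards.

import torch
import torch.nn.functional as F

log_probs = torch.randn(5, 3).log_softmax(dim=1)
target = torch.tensor([0, 2, -100, 1, -100])  # -100 is the (default) ignore_index
ignore_index = -100

# Replace ignored targets with a valid class index so gather stays in bounds...
safe_target = torch.where(target != ignore_index, target, 0)
picked = -torch.gather(log_probs, 1, safe_target.unsqueeze(1)).squeeze(1)
# ...then zero out the ignored rows and normalize by the number of kept rows.
picked = torch.where(target != ignore_index, picked, torch.zeros_like(picked))
total_weight = (target != ignore_index).sum().to(log_probs)
mean_loss = picked.sum() / total_weight
assert torch.allclose(mean_loss, F.nll_loss(log_probs, target, ignore_index=ignore_index))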

View File

@ -116,8 +116,3 @@ def onnxrt(gm, example_inputs, *, filename=None, provider=None):
return outputs
return _call
@register_backend
def tensorrt(gm, example_inputs):
return onnxrt(gm, example_inputs, provider="TensorrtExecutionProvider")

View File

@ -0,0 +1,12 @@
# import torch # type: ignore[import]
# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
# from .registry import register_backend # type: ignore[import]
"""
Placeholder for TensorRT backend for dynamo via torch-tensorrt
"""
# @register_backend
# def tensorrt(gm, example_inputs):
# import torch_tensorrt # type: ignore[import]
# pass

View File

@ -370,6 +370,13 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
return fn
def check_if_dynamo_supported():
if sys.platform == "win32":
raise RuntimeError("Windows not yet supported for torch.compile")
if sys.version_info >= (3, 11):
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
def optimize(
backend="inductor",
*,
@ -403,6 +410,7 @@ def optimize(
def toy_example(a, b):
...
"""
check_if_dynamo_supported()
# Note: The hooks object could be global instead of passed around, *however* that would make
# for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
# There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
@ -412,14 +420,6 @@ def optimize(
torch._C._log_api_usage_once("torch._dynamo.optimize")
if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
return _NullDecorator()
if sys.platform == "win32":
warnings.warn(
"Windows is not currently supported, torch.compile() will do nothing"
)
return _NullDecorator()
if sys.version_info >= (3, 11):
warnings.warn("Python 3.11+ not yet supported, torch.compile() will do nothing")
return _NullDecorator()
backend = get_compiler_fn(backend)
@ -521,6 +521,7 @@ def explain(f, *args, **kwargs):
def export(
f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
):
check_if_dynamo_supported()
torch._C._log_api_usage_once("torch._dynamo.export")
if decomposition_table is not None or tracing_mode != "real":
assert (
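A hedged sketch of how calling code might cope with the stricter gate above, which now raises instead of silently no-oping (the helper name here is illustrative, not part of the API):

import sys
import torch

def maybe_compile(fn):
    # torch.compile routes through the dynamo entry points gated above, so skip
    # it explicitly on configurations they reject and fall back to eager.
    if sys.platform == "win32" or sys.version_info >= (3, 11):
        return fn
    return torch.compile(fn)

compiled = maybe_compile(lambda x: x.sin() + 1)
print(compiled(torch.randn(4)))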

View File

@ -481,9 +481,34 @@ For now, dynamo will explicitly graph break when it encounters user code with th
if self.value == torch._C._nn.scaled_dot_product_attention:
# See:[Note] SDPA_flash's meta function returns incorrect Philox seed and offset
# in pytorch/torch/_meta_registrations.py
fake_query = args[0].as_proxy().node.meta["example_value"]
fake_key = args[1].as_proxy().node.meta["example_value"]
fake_value = args[2].as_proxy().node.meta["example_value"]
all_kwargs = kwargs.copy()
all_kwargs.update(
dict(
zip(
(
"query",
"key",
"value",
"attn_mask",
"dropout_p",
"is_causal",
),
args,
)
)
)
fake_query = all_kwargs["query"].as_proxy().node.meta["example_value"]
fake_key = all_kwargs["key"].as_proxy().node.meta["example_value"]
fake_value = all_kwargs["value"].as_proxy().node.meta["example_value"]
fake_mask = all_kwargs.get("attn_mask")
if isinstance(fake_mask, TensorVariable):
fake_mask = fake_mask.as_proxy().node.meta["example_value"]
else:
fake_mask = None
dropout_p = kwargs.get("dropout_p")
dropout_p = dropout_p.value if dropout_p is not None else 0.0
is_causal = kwargs.get("is_causal")
is_causal = is_causal.value if is_causal is not None else False
# We look through the stack to find a cuda autocast context
# If we do we will convert the fake tensors to torch.float16
is_cuda_autocast_context = False
@ -502,15 +527,10 @@ For now, dynamo will explicitly graph break when it encounters user code with th
fake_value = fake_value.clone().to(amp_dtype)
backend_choice = torch._fused_sdp_choice(
fake_query, fake_key, fake_value
fake_query, fake_key, fake_value, fake_mask, dropout_p, is_causal
)
if backend_choice == torch.backends.cuda.SDPBackend.FLASH_ATTENTION:
dropout_p = kwargs.get("dropout_p")
# Lets see if they passed it in as not an arg
if len(args) >= 5:
dropout_p = args[4]
if dropout_p is not None and dropout_p.value != 0.0:
if dropout_p is not None and dropout_p != 0.0:
unimplemented(
"FlashAttention with dropout is not supported in cuda graphs"
)

View File

@ -1870,6 +1870,9 @@ def create_runtime_wrapper(
trace_joint: bool,
keep_input_mutations: bool,
):
if not hasattr(compiled_fn, "_boxed_call"):
compiled_fn = make_boxed_func(compiled_fn)
def runtime_wrapper(*args):
# Step 2: remove aliased inputs that are mutated, replace with synthetic bases
# Only happens if our graph mutates an input that aliases another input.
@ -2180,7 +2183,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
assert all(
[isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards]
)
ctx.save_for_backward(*tensors_saved_for_backwards)
# See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
symint_outs = fw_outs[-num_symints_saved_for_bw:]
assert all(
[
@ -2190,7 +2194,9 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
)
ctx.symints = symint_outs
else:
ctx.save_for_backward(*fw_outs[num_forward_returns:])
tensors_saved_for_backwards = fw_outs[num_forward_returns:]
# See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
ctx.symints = []
raw_returns = fw_outs[0:num_forward_returns]
@ -2299,6 +2305,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
contiguous_args = [
t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args
]
all_args = (
list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args)
)

View File

@ -361,6 +361,8 @@ class CppVecOverrides(OpOverrides):
def lgamma(x):
return f"{x}.lgamma()"
"""
#TODO: support logical_and and logical_or vectorization
@staticmethod
def logical_and(a, b):
return f"{a} && {b}"
@ -368,6 +370,7 @@ class CppVecOverrides(OpOverrides):
@staticmethod
def logical_or(a, b):
return f"{a} || {b}"
"""
@staticmethod
def tan(a):

View File

@ -294,15 +294,6 @@ class WrapperCodeGen(CodeGen):
"""
)
if config.triton.convolution != "aten":
self.header.splice(
"""
from torch._inductor.triton_ops.conv_perf_model import early_config_prune
from torch._inductor.triton_ops.conv_perf_model import estimate_conv_time
from torch._inductor.triton_ops.autotune import conv_heuristics
"""
)
self.write_prefix()
for name, value in V.graph.constants.items():

View File

@ -153,9 +153,6 @@ class triton:
# Synchronize after every kernel launch, to help pinpoint bugs
debug_sync_kernel = False
# choose conv backend, "aten" or "triton"
convolution = "aten"
# Always load full blocks (rather than broadcasting inside the block)
dense_indexing = False

View File

@ -2184,20 +2184,10 @@ class ComputedBuffer(Buffer):
for reads_name in body.reads_name2expr.keys()
]
priority_idx = []
if config.triton.convolution == "aten":
memory_addrs = [
*body.reads_name2expr.values(),
*body.writes_name2expr.values(),
]
else:
# prioritize reads layout/loop_ordering over writes
if len(body.reads_name2expr.values()) > 0:
memory_addrs = [*body.reads_name2expr.values()]
else:
memory_addrs = [*body.writes_name2expr.values()]
for i, reads_buf in enumerate(reads_bufs):
if isinstance(reads_buf, Convolution):
priority_idx.append(i)
memory_addrs = [
*body.reads_name2expr.values(),
*body.writes_name2expr.values(),
]
index_vars = []
reduce_vars = []
index_size = []
@ -3140,12 +3130,8 @@ class Convolution(ExternKernelAlloc):
)
req_stride_order = get_stride_order(output.stride())
if config.triton.convolution == "aten":
weight = cls.require_stride_order(weight, req_stride_order)
x = cls.require_stride_order(x, req_stride_order)
else:
x = cls.require_stride1(cls.realize_input(x))
weight = cls.require_stride1(cls.realize_input(weight))
weight = cls.require_stride_order(weight, req_stride_order)
x = cls.require_stride_order(x, req_stride_order)
stride = tuple(stride_)
padding = tuple(padding_)
@ -3163,7 +3149,7 @@ class Convolution(ExternKernelAlloc):
_, _, *kernel_size = weight_shape
# choose runtime kernel
config_conv = config.triton.convolution
config_conv = "aten"
if (
config_conv == "aten"
or len(kernel_size) != 2 # triton conv only supports conv2d
@ -3196,7 +3182,7 @@ class Convolution(ExternKernelAlloc):
)
# for conv2d or conv3d, prefer channels last format
transform_x_layout = config.triton.convolution != "aten"
transform_x_layout = False
if kernel == "triton_ops.conv":
output_layout_str = "torch.channels_last"
else:

View File

@ -45,7 +45,7 @@ def mm_configs():
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=8
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4

View File

@ -1505,30 +1505,6 @@ def iota(
)
@register_lowering(aten.triu)
def triu(x, diagonal=0):
x_loader = x.make_loader()
dtype = x.get_dtype()
def inner_fn(index):
*_, i, j = index
return ops.where(
ops.ge(
ops.index_expr(j - i - diagonal, torch.int32),
ops.constant(0, torch.int32),
),
x_loader(index),
ops.constant(0, dtype),
)
return Pointwise.create(
device=x.get_device(),
dtype=dtype,
inner_fn=inner_fn,
ranges=list(x.get_size()),
)
@register_lowering(aten.select_scatter, type_promotion_kind=None)
def select_scatter(x, src, dim: int, index: int):
assert x.get_dtype() == src.get_dtype()
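For reference, a tiny sketch with an arbitrary input of the element-wise predicate that the removed triu lowering above computed; the ref-based path that replaces it produces the same result:

import torch

x = torch.arange(16.0).reshape(4, 4)
diagonal = 1

# Keep x[i, j] where j - i - diagonal >= 0, zero out everything below that band.
i = torch.arange(4).unsqueeze(1)
j = torch.arange(4).unsqueeze(0)
manual = torch.where(j - i - diagonal >= 0, x, torch.zeros_like(x))
assert torch.equal(manual, torch.triu(x, diagonal=diagonal))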

View File

@ -22,11 +22,14 @@ def dl_open_guard():
Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
shared library to load custom operators.
"""
if _SET_GLOBAL_FLAGS:
old_flags = sys.getdlopenflags()
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
yield
if _SET_GLOBAL_FLAGS:
if not _SET_GLOBAL_FLAGS:
yield
return
old_flags = sys.getdlopenflags()
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
try:
yield
finally:
sys.setdlopenflags(old_flags)

View File

@ -21,6 +21,7 @@ from collections import OrderedDict
from pickle import (
APPEND,
APPENDS,
BINFLOAT,
BINGET,
BININT,
BININT1,
@ -226,6 +227,8 @@ class Unpickler:
self.append(self.read(1)[0])
elif key[0] == BININT2[0]:
self.append(unpack("<H", read(2))[0])
elif key[0] == BINFLOAT[0]:
self.append(unpack(">d", self.read(8))[0])
elif key[0] == BINUNICODE[0]:
strlen = unpack("<I", read(4))[0]
if strlen > maxsize:
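As a small illustration (arbitrary value) of what the new BINFLOAT branch decodes: binary pickle protocols store a float as the opcode byte b'G' followed by 8 big-endian bytes, which is exactly the unpack(">d", read(8)) above.

import pickle
import struct

payload = pickle.dumps(3.14, protocol=2)  # PROTO header, then b'G' plus 8 big-endian bytes, then STOP
pos = payload.index(b"G")                 # locate the BINFLOAT opcode
value = struct.unpack(">d", payload[pos + 1:pos + 9])[0]
assert value == 3.14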

View File

@ -320,7 +320,9 @@ void Logger::set_runtime_stats_and_log() {
"Cuda time stats are not collected for multi-device modules.");
return;
}
if (!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu()) {
if (!reducer_->timer_ &&
(!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu())) {
TORCH_WARN_ONCE(
"Time stats are currently only collected for CPU and CUDA devices. "
"Please refer to CpuTimer or CudaTimer for how to register timer "

View File

@ -193,8 +193,7 @@ def adadelta(
# We still respect when the user inputs False for foreach.
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, acc_deltas],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -210,8 +210,7 @@ def adagrad(
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, state_sums, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -4,7 +4,7 @@ import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
_dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
_differentiable_doc, _foreach_doc, _maximize_doc)
_differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
__all__ = ['Adam', 'adam']
@ -218,28 +218,14 @@ Adam.__doc__ = r"""Implements Adam algorithm.
{maximize}
{capturable}
{differentiable}
fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. Since the fused implementation is usually significantly faster than
the for-loop implementation, we try to use it whenever possible (all parameters
are on CUDA and are of a supported type). Else, we attempt to use the foreach
implementation and lastly fall back to the for-loop implementation. (default: None)
.. note:: The foreach and fused implementations are typically faster than the for-loop,
single-tensor implementation, so we will try to default to them IF the user has
not specified either flag (i.e., when foreach = fused = None). For example, if
the user specifies True for foreach but nothing for fused, we will run the foreach
implementation. If the user specifies False for fused but nothing for foreach, we will
run the for-loop implementation. If the user specifies True for both foreach and
fused, we will prioritize fused over foreach. We attempt to use the fastest, so the
hierarchy goes fused -> foreach -> for-loop.
{fused}
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
""".format(foreach=_foreach_doc, maximize=_maximize_doc, capturable=_capturable_doc,
differentiable=_differentiable_doc)
differentiable=_differentiable_doc, fused=_fused_doc)
def adam(params: List[Tensor],
@ -268,10 +254,12 @@ def adam(params: List[Tensor],
See :class:`~torch.optim.Adam` for details.
"""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
fused, foreach = _default_to_fused_or_foreach(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
differentiable, has_fused=True)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if fused is None:
fused = False
if foreach is None:

View File

@ -206,8 +206,7 @@ def adamax(
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_infs, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -2,7 +2,7 @@ import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
_stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
_maximize_doc, _default_to_fused_or_foreach)
_fused_doc, _maximize_doc, _default_to_fused_or_foreach)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
@ -248,13 +248,7 @@ AdamW.__doc__ = r"""Implements AdamW algorithm.
{foreach}
{capturable}
{differentiable}
fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. Since the fused implementation is usually significantly faster than
the for-loop implementation, we try to use it whenever possible (all parameters
are on CUDA and are of a supported type). Else, we continue with the for-loop
implementation. (default: None)
{fused}
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
@ -262,6 +256,7 @@ AdamW.__doc__ = r"""Implements AdamW algorithm.
""".format(maximize=_maximize_doc,
foreach=_foreach_doc,
fused=_fused_doc,
capturable=_capturable_doc,
differentiable=_differentiable_doc)
@ -300,11 +295,12 @@ def adamw(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
# Respect when the user inputs False/True for foreach.
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
fused, foreach = _default_to_fused_or_foreach(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if fused is None:
fused = False
if foreach is None:

View File

@ -185,8 +185,7 @@ def asgd(
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, axs, mus, etas, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -187,8 +187,7 @@ def nadam(params: List[Tensor],
raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError('torch.jit.script not supported with foreach optimizers')

View File

@ -15,6 +15,7 @@ from torch._utils import is_compiling
__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
_global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
class _RequiredParameter:
"""Singleton class representing a required parameter for an Optimizer."""
@ -55,24 +56,21 @@ def _dispatch_sqrt(x: float): # float annotation is needed because of torchscri
return math.sqrt(x)
# For any optimizer with a faster implementation, we attempt to default to the
# fastest whenever possible. For foreach, the requirements are to have native
# tensors all on CUDA. For fused, there's currently the additional requirement
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
def _default_to_fused_or_foreach(params: List[torch.Tensor],
differentiable: bool,
has_fused: bool = False) -> Tuple[bool, bool]:
use_fused: bool = False) -> Tuple[bool, bool]:
if torch.jit.is_scripting() or differentiable:
return False, False
all_tensors = []
for tensorlist in tensorlists:
all_tensors.extend(tensorlist)
fused = has_fused and all(
p is None or (type(p) == torch.Tensor and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
fused = use_fused and all(
p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
)
foreach = not fused and all(
p is None or (type(p) == torch.Tensor and p.is_cuda) for p in all_tensors
p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
)
return fused, foreach
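A quick check of what the reworked helper returns, assuming a PyTorch build from this branch (the helper is private and its signature may change in later releases):

```python
import torch
from torch.optim.optimizer import _default_to_fused_or_foreach  # private helper

cpu_params = [torch.zeros(2)]
print(_default_to_fused_or_foreach(cpu_params, differentiable=False, use_fused=True))
# (False, False): CPU tensors fall back to the for-loop implementation

if torch.cuda.is_available():
    cuda_params = [torch.zeros(2, device="cuda")]
    print(_default_to_fused_or_foreach(cuda_params, differentiable=False, use_fused=True))   # (True, False)
    print(_default_to_fused_or_foreach(cuda_params, differentiable=False, use_fused=False))  # (False, True)
```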
@ -83,6 +81,23 @@ _foreach_doc = r"""foreach (bool, optional): whether foreach implementation of o
foreach over the for-loop implementation on CUDA, since it is usually
significantly more performant. (default: None)"""
_fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. (default: None)
.. note:: The foreach and fused implementations are typically faster than the for-loop,
single-tensor implementation. Thus, if the user has not specified BOTH flags
(i.e., when foreach = fused = None), we will attempt defaulting to the foreach
implementation when the tensors are all on CUDA. For example, if the user specifies
True for fused but nothing for foreach, we will run the fused implementation. If
the user specifies False for foreach but nothing for fused (or False for fused but
nothing for foreach), we will run the for-loop implementation. If the user specifies
True for both foreach and fused, we will prioritize fused over foreach, as it is
typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
we want to give it sufficient bake-in time, so we default to foreach and NOT
fused when the user has not specified either flag."""
_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
capture in a CUDA graph. Passing True can impair ungraphed performance,
so if you don't intend to graph capture this instance, leave it False

View File
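From the public API, the fused -> foreach -> for-loop hierarchy described in `_fused_doc` is exercised simply by passing the corresponding flags to the optimizer constructor. A short usage sketch (fused requires CUDA, floating-point params):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
params = [torch.nn.Parameter(torch.randn(4, 4, device=device))]

opt_default = torch.optim.Adam(params)                 # fused=foreach=None: defaults to foreach on CUDA, for-loop otherwise
opt_forloop = torch.optim.Adam(params, foreach=False)  # force the single-tensor reference implementation
if torch.cuda.is_available():
    opt_fused = torch.optim.Adam(params, fused=True)   # opt in to the fused CUDA kernels explicitly
```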

@ -209,8 +209,7 @@ def radam(
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -220,8 +220,7 @@ def rmsprop(
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, grad_avgs, momentum_buffer_list],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -192,8 +192,7 @@ def rprop(
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, prevs, step_sizes],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

View File

@ -207,8 +207,7 @@ def sgd(params: List[Tensor],
# why must we be explicit about an if statement for torch.jit.is_scripting here?
# because JIT can't handle Optionals nor fancy conditionals when scripting
if not torch.jit.is_scripting():
_, foreach = _default_to_fused_or_foreach([params, d_p_list, momentum_buffer_list],
differentiable=False, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
else:
foreach = False

View File

@ -2130,7 +2130,8 @@ class TestCase(expecttest.TestCase):
errors_before = 0 if result is None else len(result.errors)
skipped_before = 0 if result is None else len(result.skipped)
if TEST_WITH_TORCHDYNAMO:
# TODO remove version check once dynamo supports 3.11
if TEST_WITH_TORCHDYNAMO and sys.version_info < (3, 11):
# TorchDynamo optimize annotation
if TEST_WITH_TORCHINDUCTOR:
super_run = torch._dynamo.optimize("inductor")(super().run)

View File
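The common_utils.py hunk above gates the TorchDynamo wrapping on the interpreter version, since dynamo did not support Python 3.11 at the time. The same gating pattern in isolation (the test function here is a placeholder):

```python
import sys

import torch

def some_test_fn(x):
    return torch.relu(x) + 1

if sys.version_info < (3, 11):
    import torch._dynamo
    run = torch._dynamo.optimize("eager")(some_test_fn)
else:
    run = some_test_fn  # fall back to the plain function on 3.11+

print(run(torch.randn(3)))
```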

@ -380,7 +380,7 @@ def make_histogram(values, bins, max_bins=None):
limits = new_limits
# Find the first and the last bin defining the support of the histogram:
cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
cum_counts = np.cumsum(np.greater(counts, 0))
start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
start = int(start)
end = int(end) + 1

View File
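The make_histogram() change above is needed because NumPy 1.24 removed support for non-boolean `dtype=` arguments to comparison ufuncs such as `np.greater`; the boolean result can be fed straight to `np.cumsum`, which promotes it to an integer accumulator. A small check:

```python
import numpy as np

counts = np.array([0, 3, 0, 5, 2])

# np.greater(counts, 0, dtype=np.int32)  # raises TypeError on NumPy >= 1.24

cum_counts = np.cumsum(np.greater(counts, 0))
print(cum_counts)  # [0 1 1 2 3]; the downstream searchsorted logic is unchanged
```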

@ -1109,6 +1109,7 @@ SUPPORTED_RETURN_TYPES = {
"::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
"::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",

View File

@ -638,6 +638,7 @@ class NativeFunction:
raw_dispatch = e.pop("dispatch", None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
dispatch: Dict[DispatchKey, BackendMetadata] = {}
num_dispatch_keys: int = 0
if raw_dispatch is not None:
assert not manual_kernel_registration, (
"cannot specify both manual_kernel_registration and dispatch; with "
@ -650,6 +651,8 @@ class NativeFunction:
assert isinstance(ks, str), e
for k in ks.split(","):
dispatch_key = DispatchKey.parse(k.strip())
num_dispatch_keys += 1
if ignore_keys and dispatch_key in ignore_keys:
continue
assert dispatch_key in dispatch_keys, (
@ -677,7 +680,12 @@ class NativeFunction:
):
redundant_composite_implicit_autograd = True
assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), (
# We count the number of dispatch keys which have not been ignored to prevent a dispatch table
# in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
# from being treated as redundant.
assert not (
num_dispatch_keys == 1 and redundant_composite_implicit_autograd
), (
"unnecessary dispatch table for this function; just delete the dispatch "
"key entirely"
)
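The torchgen change counts dispatch keys before the ignore_keys filter so that an entry whose backend keys were all ignored is not misreported as a redundant table. A hypothetical, stripped-down illustration of that counting (not the actual torchgen parser):

```python
# Hypothetical native_functions.yaml entry: a composite kernel plus a CUDA override.
raw_dispatch = {"CompositeImplicitAutograd": "my_op", "CUDA": "my_op_cuda"}
ignore_keys = {"CUDA"}  # e.g. codegen invoked without the CUDA backend

dispatch = {}
num_dispatch_keys = 0
for ks, kernel in raw_dispatch.items():
    for k in ks.split(","):
        num_dispatch_keys += 1  # counted before the ignore filter
        key = k.strip()
        if key in ignore_keys:
            continue
        dispatch[key] = kernel

# The old redundancy check looked at len(dispatch) == 1 and would flag this
# entry; the new check uses num_dispatch_keys, which is 2 here, so it does not.
print(len(dispatch), num_dispatch_keys)  # 1 2
```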
@ -687,6 +695,7 @@ class NativeFunction:
structured_delegate
or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
or num_dispatch_keys != 1
), (
f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "