Compare commits


105 Commits

Author SHA1 Message Date
d0e00d5448 remove resize and del 2024-06-20 17:07:47 -07:00
25229787d6 [Traceable FSDP2] Add unit tests for simple MLP and transformer model
ghstack-source-id: a23c48d5d56d2633f7ec7efb7b1567340d4feeb1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129157
2024-06-20 14:45:17 -07:00
e84cf805d2 Revert "Modularize aten parameter parser and checker (#125308)"
This reverts commit 60bbdc0b40656cf70b2b098c7d715e19f031fb0d.

Reverted https://github.com/pytorch/pytorch/pull/125308 on behalf of https://github.com/fbgheith due to test failures when run by meta ([comment](https://github.com/pytorch/pytorch/pull/125308#issuecomment-2181327211))
2024-06-20 18:52:05 +00:00
254487f288 Revert "Separate AOTI Eager utils as a single file (#125819)"
This reverts commit 18634048a1f939a961b7c96b0acfe78b474c821e.

Reverted https://github.com/pytorch/pytorch/pull/125819 on behalf of https://github.com/fbgheith due to test failures when run by meta ([comment](https://github.com/pytorch/pytorch/pull/125819#issuecomment-2181317332))
2024-06-20 18:49:08 +00:00
73340f0909 Revert "[3/N] Non-Tensor: Support string parameter for aten operations (#125831)"
This reverts commit a52c8ace98afe76dc9e2c330b415972fd1529077.

Reverted https://github.com/pytorch/pytorch/pull/125831 on behalf of https://github.com/fbgheith due to test failures when run by meta ([comment](https://github.com/pytorch/pytorch/pull/125831#issuecomment-2181313892))
2024-06-20 18:45:41 +00:00
8c2542623b [Traceable FSDP2] [Dynamo] Add tracing support for out-variant custom ops that return None (#129078)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129078
Approved by: https://github.com/yanboliang
2024-06-20 17:46:13 +00:00
734891ac22 Fix export log script (#128967)
Summary: Title

Test Plan: CI

Differential Revision: D58699557

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128967
Approved by: https://github.com/jiashenC
2024-06-20 17:01:00 +00:00
ddb95dbb0d Fixing equalize with three things and improving functionality (#124632)
Summary:
(1) Make the code work when the first layer does not have a bias.
(2) Make it possible to provide both modules and module names as input.
(3) Allow sequences of contiguous layers as input, which then get split into pairs.
(4) Fix the documentation to be clearer about the inputs to be provided.

Test Plan:
Run this new version of the algorithm on a network and see if it throws errors.

There's also this notebook to run and test N5199827

If you tell me where I can find the tests for this code, I can add some simple unit tests as well.

Differential Revision: D55895862

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124632
Approved by: https://github.com/jerryzh168
2024-06-20 16:55:56 +00:00
832fc35211 Revert "Improved flexattention bwd perf + added configurations for benchmarks (#129013)"
This reverts commit 6d2b3c90f144d7b77d51da27e6696192b2b97ebd.

Reverted https://github.com/pytorch/pytorch/pull/129013 on behalf of https://github.com/ZainRizvi due to Sorry but this is causing a flexattention test to fail on ROCm. Can you please fix that test before remerging this in? See 6d2b3c90f1 for details ([comment](https://github.com/pytorch/pytorch/pull/129013#issuecomment-2181133070))
2024-06-20 16:51:41 +00:00
65286883d4 [export] reland "experimental joint graph API." (#129081)
Summary: the previous diff got reverted despite CI being green.

Test Plan: CI

Differential Revision: D58790048

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129081
Approved by: https://github.com/tugsbayasgalan
2024-06-20 16:50:53 +00:00
fc5b0ff2d7 [BE][Hackaday] deprecate legacy cuda docker image (#128859)
Fixes https://github.com/pytorch/builder/issues/1795 from the pytorch side specifically for the cuda image

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128859
Approved by: https://github.com/atalman
2024-06-20 16:30:49 +00:00
b2a9b8d485 [CpuInductor] Enable NEON ISA detection on Linux ARM (#129075)
Also, clean up the code a bit to use `x in [y, z]` instead of `x == y or x == z`.

And do not redefine `at_align`, but instead use `alignas(64)` as was suggested in https://github.com/pytorch/pytorch/pull/128686/files#r1639365978

Test plan: `python3 -c "import torch._inductor.codecache as cc; isa = cc.valid_vec_isa_list()[0];print(str(isa), bool(isa))"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129075
Approved by: https://github.com/jansel
2024-06-20 16:22:57 +00:00
e0aa992d73 Fix inductor and deploy jobs timing out (#129108)
Some trunk and periodic jobs are timing out at the moment, including:

* `deploy`.  This is because https://github.com/pytorch/pytorch/pull/127952 has removed `deploy` config, but there is one left over in periodic.
    * [periodic / linux-focal-cuda12.4-py3.10-gcc9 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)](https://github.com/pytorch/pytorch/actions/runs/9525590191/job/26260620457).
* `inductor`, including `py3.10`, `py3.12`, and `cuda12.1`, `cuda12.4`.  The increase comes from this change https://github.com/pytorch/pytorch/pull/128343, so I add another GPU shard.
    * [inductor / cuda12.1-py3.12-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu)](https://github.com/pytorch/pytorch/actions/runs/9522817887/job/26255069269)
    * [inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu)](https://github.com/pytorch/pytorch/actions/runs/9524651902/job/26260009757)
    * [inductor-cu124 / cuda12.4-py3.10-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu)](https://github.com/pytorch/pytorch/actions/runs/9587982228/job/26440205869)
    * [inductor-cu124 / cuda12.4-py3.12-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu)](https://github.com/pytorch/pytorch/actions/runs/9587982228/job/26440634200)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129108
Approved by: https://github.com/malfet
2024-06-20 16:03:11 +00:00
2bb8ee602b Fix DEBUG=1 asserts with NJT ops (#129014)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129014
Approved by: https://github.com/YuqingJ, https://github.com/soulitzer
2024-06-20 15:15:28 +00:00
7178b4e987 [Dynamo x torch_function] fix incorrect source (#128980)
Fixes https://github.com/pytorch/pytorch/issues/128964

The problem was that we were installing the source for a type
incorrectly.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128980
Approved by: https://github.com/mlazos
2024-06-20 14:54:00 +00:00
ea47d542ca [dynamo][guards] Remove BOOL_FALSE - not needed after C++ guards (#129098)
PyDict_Size is very fast. Earlier, with Python guards, CPython would go through layers of fluff before finally calling PyDict_Size. With C++ guards, it's not needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129098
Approved by: https://github.com/jansel
2024-06-20 14:40:27 +00:00
54b0006cb2 Evaluate symexprs on load path of cache not write (#128997)
When caching is enabled, an internal model fails with
```
assert_size_stride(bmm_9, (17, s0, 512), (54784, 512, 1))
AssertionError: expected size 17==17, stride 57344==54784 at dim=0
```
Looking at this model, the exact problem is that when the cache is hit on the forward graph, the generated code for backward fails, since the strides of the outputs of forward, passed to backward as inputs, are not what we expected.

This PR changes the evaluation logic so that we defer evaluation of output stride exprs to load path as opposed to eagerly doing it on save path.
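A minimal sketch of the idea using sympy (not the actual FX-graph-cache code; `s0` and the stride values are illustrative): keep the output stride expressions symbolic when the cache entry is saved, and substitute the hints of the current call only when the entry is loaded.

```python
import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)

# Save path: store the output strides as expressions instead of evaluating eagerly.
cached_output_strides = [(s0 * 512, 512, 1)]

# Load path: evaluate against the hints of the inputs we were actually called with.
hints = {s0: 107}
concrete = [
    tuple(int(e.subs(hints)) if isinstance(e, sympy.Expr) else e for e in strides)
    for strides in cached_output_strides
]
print(concrete)  # [(54784, 512, 1)]
```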

I have not been able to come up with a unit test repro for this problem.

Differential Revision: [D58796503](https://our.internmc.facebook.com/intern/diff/D58796503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128997
Approved by: https://github.com/ezyang
2024-06-20 08:55:12 +00:00
799acd31b4 [MPS] Add lu_factor (#99269)
### <samp>🤖 Generated by Copilot at d75cde1</samp>

Added MPS support and autograd formulas for LU factorization of tensors. Implemented the `linalg_lu_factor` and `linalg_lu_factor.out` functions for the MPS backend in `LinearAlgebra.mm` and added tests in `test_mps.py`. Added the corresponding dispatch entries in `native_functions.yaml` and the backward and forward formulas in `derivatives.yaml`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99269
Approved by: https://github.com/kulinseth, https://github.com/lezcano
2024-06-20 07:35:29 +00:00
0d25f096c1 [CppInductor] Fix erfinv codegen when non-vectorized isa (#129090)
Fix erfinv codegen when ISA could not be detected

Manual test plan (on MacOS):
 - Modify `valid_vec_isa_list` to return empty list
 - Run `python3 inductor/test_torchinductor_opinfo.py -v -k test_comprehensive_erfinv_cpu_bool`

Before this change, the above-mentioned test fails with
```
Output:
/var/folders/rk/fxg20zvx6vvb5bk7cplq4xrc0000gn/T/tmpgic60b6c/ns/cnsp7snp7fyclkm5lsfiyiv3m6c3svevkbhcb3v7pijdfjwlyaij.cpp:11:25: error: use of undeclared identifier 'calc_erfinv'
            auto tmp2 = calc_erfinv(tmp1);
                        ^
1 error generated.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129090
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-06-20 06:09:48 +00:00
6d2b3c90f1 Improved flexattention bwd perf + added configurations for benchmarks (#129013)
Before:
<img width="519" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/6f4a9b37-4aff-48d3-aaba-7e8e5a5bf0fb">

After:
<img width="541" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/423f179e-76f5-457b-8064-ee8a70247534">

After fixing strides:
![image](https://github.com/pytorch/pytorch/assets/6355099/58471587-404b-4bfc-b9b2-7546bdf53f54)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129013
Approved by: https://github.com/drisspg, https://github.com/yanboliang
ghstack dependencies: #128938
2024-06-20 05:15:48 +00:00
ad2593cb86 [Animesh's PR #125340] [dynamo][fsdp] Track FSDPNNModuleVariable for mutations (#129045)
This is a copy of Animesh's work in https://github.com/pytorch/pytorch/pull/125340, with very small changes to the unit test. It's needed sooner for the Traceable FSDP2 work, so I copy it here and will work through landing it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129045
Approved by: https://github.com/anijain2305
2024-06-20 04:02:36 +00:00
19f3abcde4 [Docs][MPS] Add mps environment variable table (#129008)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129008
Approved by: https://github.com/malfet
ghstack dependencies: #129006
2024-06-20 03:30:35 +00:00
609ffaf717 Add more shards for slow CPU and ROCm jobs (#128873)
As they start to time out in trunk fc2913fb80/1. Adding one more shard for the slow CPU job is trivial. ROCm runners are harder to find, but I assume this is OK because slow jobs only run periodically.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128873
Approved by: https://github.com/PaliC
2024-06-20 03:13:19 +00:00
d8db074988 [Traceable FSDP2] [Dynamo] Fix OptimizedModule._initialize to allow tracing into FSDP2 module hooks for module from user-defined module class (#129046)
This is a workaround to allow inplace fully-sharded module to still go into this branch:
3a185778ed/torch/_dynamo/eval_frame.py (L163)
instead of the second branch:
3a185778ed/torch/_dynamo/eval_frame.py (L166)

If we don't do this, `torch.compile(fully_shard(module_from_user_defined_module_class))` will ignore all module hooks which will break FSDP tracing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129046
Approved by: https://github.com/anijain2305
2024-06-20 00:15:55 +00:00
859fa183fe BE: Use future annotations in inductor scheduler and ir (#128892)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128892
Approved by: https://github.com/lezcano
2024-06-20 00:10:43 +00:00
a2b1673dfb [Horace's PR #126446] Prevent partitioner from ever saving views (#129039)
Most of the work was done by Horace in https://github.com/pytorch/pytorch/issues/126446; this PR just adds the config for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129039
Approved by: https://github.com/Chillee
2024-06-19 23:21:16 +00:00
9d06e3783d [Inductor][CPP] Fix the symbolic size cast issue in GEMM Benchmark (#128824)
**Summary**
The symbolic size generated from the size hint (a Python int) differs from the C type `long` of the kernel args, which may cause the benchmark to fail to run.
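A hypothetical sketch of the mismatch (names illustrative, not the Inductor benchmark code): the size hint arrives as a Python int, while the generated kernel signature declares a C `long`, so the benchmark harness has to cast explicitly when it builds the arguments.

```python
import ctypes

size_hint = 64  # Python int produced by the symbolic-shape size-hint machinery
kernel_arg = ctypes.c_long(int(size_hint))  # explicit cast to the C type the kernel declares
print(kernel_arg.value)  # 64
```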

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128824
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-06-19 23:11:53 +00:00
a6ac6447b5 Re-enable py3.12 nightly wheel builds and add triton dependency for ROCm (#128525)
The llnl-hatchet developers have published the py3.12 binaries on [PyPI](https://pypi.org/project/llnl-hatchet/#files). In fact, looking [here](https://download.pytorch.org/whl/nightly/llnl-hatchet), it seems we already have the py3.12 wheels mirrored. This should allow us to re-enable py3.12 binaries for ROCm.

This PR reverts commit 9d849d4312cd1e62d97b9e9d58979ec78d36c95f.

It also adds the pytorch-triton-rocm dependency for torch wheels on ROCm since pytorch-triton-rocm py3.12 wheels are available now

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128525
Approved by: https://github.com/malfet
2024-06-19 21:56:54 +00:00
571a0db132 [inductor] Fix logging for run_and_get_cpp_code (#128794)
Summary: Found during testing with remote caching: Use the same output logger object between graph.py and codecache.py since it's patched in `run_and_get_cpp_code`. That allows us to capture any logging produced from the codecache path when using `run_and_get_cpp_code`. I'm also fixing a few tests that were passing mistakenly because logging was missing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128794
Approved by: https://github.com/oulgen, https://github.com/leslie-fang-intel
2024-06-19 21:32:34 +00:00
cyy
277f2914a5 [9/N] Remove unused functions (#128704)
MKL cannot be enabled on aarch64, and since CI compiles code with `-Werror=unused-function`, the build fails with
```
/usr/bin/c++ -DAT_PER_OPERATOR_HEADERS -DBUILD_ONEDNN_GRAPH -DCAFFE2_BUILD_MAIN_LIB -DCPUINFO_SUPPORTED_PLATFORM=1 -DFLASHATTENTION_DISABLE_ALIBI -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 -DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DUSE_C10D_GLOO -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/var/lib/jenkins/workspace/build/aten/src -I/var/lib/jenkins/workspace/aten/src -I/var/lib/jenkins/workspace/build -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/cmake/../third_party/benchmark/include -I/var/lib/jenkins/workspace/third_party/onnx -I/var/lib/jenkins/workspace/build/third_party/onnx -I/var/lib/jenkins/workspace/third_party/foxi -I/var/lib/jenkins/workspace/build/third_party/foxi -I/var/lib/jenkins/workspace/torch/csrc/api -I/var/lib/jenkins/workspace/torch/csrc/api/include -I/var/lib/jenkins/workspace/caffe2/aten/src/TH -I/var/lib/jenkins/workspace/build/caffe2/aten/src/TH -I/var/lib/jenkins/workspace/build/caffe2/aten/src -I/var/lib/jenkins/workspace/build/caffe2/../aten/src -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/third_party/miniz-2.1.0 -I/var/lib/jenkins/workspace/third_party/kineto/libkineto/include -I/var/lib/jenkins/workspace/third_party/kineto/libkineto/src -I/var/lib/jenkins/workspace/third_party/cpp-httplib -I/var/lib/jenkins/workspace/aten/src/ATen/.. -I/var/lib/jenkins/workspace/third_party/FXdiv/include -I/var/lib/jenkins/workspace/c10/.. -I/var/lib/jenkins/workspace/third_party/pthreadpool/include -I/var/lib/jenkins/workspace/third_party/cpuinfo/include -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include -I/var/lib/jenkins/workspace/third_party/NNPACK/include -I/var/lib/jenkins/workspace/third_party/FP16/include -I/var/lib/jenkins/workspace/third_party/tensorpipe -I/var/lib/jenkins/workspace/build/third_party/tensorpipe -I/var/lib/jenkins/workspace/third_party/tensorpipe/third_party/libnop/include -I/var/lib/jenkins/workspace/third_party/fmt/include -I/var/lib/jenkins/workspace/build/third_party/ideep/mkl-dnn/include -I/var/lib/jenkins/workspace/third_party/ideep/mkl-dnn/src/../include -I/var/lib/jenkins/workspace/third_party/flatbuffers/include -isystem /var/lib/jenkins/workspace/build/third_party/gloo -isystem /var/lib/jenkins/workspace/cmake/../third_party/gloo -isystem /var/lib/jenkins/workspace/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /var/lib/jenkins/workspace/cmake/../third_party/googletest/googlemock/include -isystem /var/lib/jenkins/workspace/cmake/../third_party/googletest/googletest/include -isystem /var/lib/jenkins/workspace/third_party/protobuf/src -isystem /var/lib/jenkins/workspace/third_party/XNNPACK/include -isystem /var/lib/jenkins/workspace/cmake/../third_party/eigen -isystem /var/lib/jenkins/workspace/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /var/lib/jenkins/workspace/third_party/ideep/include -isystem /var/lib/jenkins/workspace/build/include -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK 
-DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Werror -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow -O3 -DNDEBUG -DNDEBUG -std=gnu++17 -fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -D__NEON__ -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-type-limits -Wno-array-bounds -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wno-maybe-uninitialized -fvisibility=hidden -O2 -pthread -fopenmp -MD -MT caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/mkldnn/Linear.cpp.o -MF caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/mkldnn/Linear.cpp.o.d -o caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/mkldnn/Linear.cpp.o -c /var/lib/jenkins/workspace/aten/src/ATen/native/mkldnn/Linear.cpp
/var/lib/jenkins/workspace/aten/src/ATen/native/mkldnn/Linear.cpp:426:15: error: ‘at::Tensor at::native::mkl_linear(const at::Tensor&, const at::Tensor&, const at::Tensor&, const std::optional<at::Tensor>&, int64_t)’ defined but not used [-Werror=unused-function]
  426 | static Tensor mkl_linear(
      |               ^~~~~~~~~~
```

Follows #128499

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128704
Approved by: https://github.com/malfet
2024-06-19 20:46:45 +00:00
fca408fa29 s390x vectorization: rework operators (#129066)
Move operators from member functions to free functions. This is needed to fix torch inductor on s390x.

This change fixes tests like
DynamicShapesMiscTests::test_numpy_min_dynamic_shapes from test/dynamo/test_dynamic_shapes.py

This change also fixes a recently introduced build failure on s390x.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129066
Approved by: https://github.com/malfet
2024-06-19 20:12:41 +00:00
73f5d2b787 Run ET unit tests on PT CI (#128560)
This is the first PR to add all existing ET unit tests into PT CI. The goal is to improve the coverage there to avoid breaking changes from PT that could break ET. With this, any future unit tests on ET will automatically be run on PT CI. The duration of the job is now 40+ minutes, not too bad.

This also fixed the failed ET build in https://github.com/pytorch/pytorch/pull/123043.

Adding model coverage is a bit more involved and requires adding new shards, so I will follow up on that in separate PRs.

[T192117506](https://www.internalfb.com/intern/tasks/?t=192117506), with the failed diffs D58295865 and D58394154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128560
Approved by: https://github.com/guangy10, https://github.com/digantdesai
2024-06-19 20:08:58 +00:00
df94d57c0a Revert "[export] experimental joint graph API. (#128847)"
This reverts commit 0707811286d1846209676435f4f86f2b4b3d1a17.

Reverted https://github.com/pytorch/pytorch/pull/128847 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/128847#issuecomment-2179326891))
2024-06-19 19:04:36 +00:00
b5d541609d [Memory Snapshot] Add recordAnnotations to capture record_function annotations (#129072)
Summary:
Add new traceEvents into Memory Snapshot for record_function annotations. These will capture both the profiler's step annotation as well as user annotations.

Test Plan:
CI

Pulled By: aaronenyeshi

Differential Revision: D55941362

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129072
Approved by: https://github.com/zdevito
2024-06-19 18:05:41 +00:00
bafd68b4fc [inductor] fix windows python module ext and func export declaration (#129059)
I have run the first Inductor case on Windows based on the exploration code: https://github.com/pytorch/pytorch/pull/128330
Since some fundamental PRs still need to pass `fb_code` (https://github.com/pytorch/pytorch/pull/128303), this PR lands part of the exploration code:
1. Fix the Windows Python module extension type, pyd (a sketch of the idea follows below).
2. Add the function export declaration for Windows.
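A hypothetical sketch of the first point (not the Inductor codecache itself): pick the compiled-extension suffix per platform, since Windows loads Python extension modules as `.pyd` rather than `.so`.

```python
import sys

def python_module_ext() -> str:
    # Windows loads compiled Python extension modules with the .pyd suffix.
    return ".pyd" if sys.platform == "win32" else ".so"

print(python_module_ext())
```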

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129059
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-06-19 17:51:32 +00:00
0707811286 [export] experimental joint graph API. (#128847)
Summary:
WARNING: This API is highly unstable and will be subject to change in the future.

Add a prototype to "decompose" an ExportedProgram into a joint graph form, so that we can compute the gradients on this graph.

Test Plan: buck test mode/opt caffe2/torch/fb/export:test_experimental

Differential Revision: D55657917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128847
Approved by: https://github.com/tugsbayasgalan
2024-06-19 16:45:27 +00:00
0fc603ece4 [optim] Fused implementation stability table (#129006)
I'd like to discuss the criteria by which we regard an implementation as stable. If there is no existing standard, my initial proposal would be a 6-month period after the commit to regard it as stable. As a result, Adam and AdamW on CUDA would now be considered stable, while the rest are beta.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129006
Approved by: https://github.com/malfet
2024-06-19 16:29:49 +00:00
1b92bdd0ea [ALI] [Reland] Use LF runners for Lint (#129071)
Quick experiment with using LF runners for lint jobs.

Picking a set of jobs where infra failures would be obvious to most people (lint)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129071
Approved by: https://github.com/malfet
2024-06-19 16:10:51 +00:00
236fbcbdf4 [Split Build] Test split build in pull CI workflow (#126813)
This PR builds the split build in the pull workflow and runs the appropriate tests against them. A single linux cpu and single gpu build were chosen arbitrarily to not add too many tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126813
Approved by: https://github.com/atalman
ghstack dependencies: #127934
2024-06-19 15:57:21 +00:00
7d33ff59ba [Split Build]Use same package (#127934)
This PR removes the second separate package we were using for the libtorch wheel.
In terms of testing that this works, we will use the PRs above this one in the stack.

As for sanity checking these are the wheels that are produced by running
```
python setup.py clean && BUILD_LIBTORCH_WHL=1 with-proxy python setup.py bdist_wheel && BUILD_PYTHON_ONLY=1 with-proxy python setup.py bdist_wheel --cmake
```

```
sahanp@devgpu086 ~/pytorch ((5f15e171…))> ls -al dist/                                                        (pytorch-3.10)
total 677236
drwxr-xr-x 1 sahanp users       188 Jun  4 12:19 ./
drwxr-xr-x 1 sahanp users      1696 Jun  4 12:59 ../
-rw-r--r-- 1 sahanp users  81405742 Jun  4 12:19 torch-2.4.0a0+gitca0a73c-cp310-cp310-linux_x86_64.whl
-rw-r--r-- 1 sahanp users 612076919 Jun  4 12:19 libtorch-2.4.0a0+gitca0a73c-py3-none-any.whl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127934
Approved by: https://github.com/atalman
2024-06-19 15:57:21 +00:00
lyb
ffb50fb691 [ONNX] Add onnx::Gelu support for version 20 (#128773)
Fixes https://github.com/pytorch/pytorch/issues/128772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128773
Approved by: https://github.com/justinchuby
2024-06-19 15:39:02 +00:00
3397d5ef90 Revert "[ALI] Use lf runners for Lint" (#129070)
Reverts pytorch/pytorch#128978
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129070
Approved by: https://github.com/atalman
2024-06-19 14:48:16 +00:00
118f9ceb7c [inductor][ci] Fix torchbench dependency issue with numpy (#128968)
For some reason, pip will always upgrade the numpy version even when an older version has been installed.
We have to lock the numpy version to the old one to make this constraint explicit.

Torchbench commit: 23512dbebd

Second attempt to fix #128845

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128968
Approved by: https://github.com/eellison
2024-06-19 12:10:50 +00:00
e49525275d Make TraceUtils.h to be device-agnostic (#126969)
Some features of third-party devices depend on TraceUtils.h, so some of the CUDA code was removed and split into NCCLUtils files.

In addition, some common functions still remain in TraceUtils.h since I'm not sure if other devices will use them later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126969
Approved by: https://github.com/c-p-i-o
2024-06-19 09:06:49 +00:00
7fac03aee9 [ALI] Use lf runners for Lint (#128978) 2024-06-19 10:59:07 +02:00
50567f7081 Pass device to is_pinned call inside TensorProperties.create_from_tensor (#128896)
Summary:
The default input device for the is_pinned function is CUDA. This can unnecessarily create a CUDA context for CPU tensors when just generating TensorProperties, bloating memory usage. Passing the device to the is_pinned call site inside `create_from_tensor` solves this issue.
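A minimal sketch of the side effect being avoided (not the DCP code itself; `torch.cuda.is_initialized` is used here only to make the context creation visible):

```python
import torch

t = torch.empty(4)                  # plain CPU tensor
print(torch.cuda.is_initialized())  # False
t.is_pinned()                       # no device given -> routed through the CUDA backend
print(torch.cuda.is_initialized())  # may now be True on a CUDA build: the unwanted context
```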

This also fixes Model Store test
https://www.internalfb.com/intern/test/844425019931542?ref_report_id=0
which is currently broken on memory usage assertions.

Test Plan: UT

Differential Revision: D58695006

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128896
Approved by: https://github.com/fegin
2024-06-19 08:50:46 +00:00
d3e8b8bf47 Remove cuda check in the CUDAGraph destructor (#127382)
Fixes #125804

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127382
Approved by: https://github.com/eqy, https://github.com/eellison
2024-06-19 08:09:31 +00:00
ba92f5277f [inductor][refactor] Unify the use of generate_kernel_call (#128467)
Summary: Refactor TritonTemplateKernel.call_kernel and ForeachKernel.call_kernel to use wrapper.generate_kernel_call to generate kernel calls instead of explicitly composing the kernel call string. This consolidates the entry point of generate_kernel_call and simplifies later changes in this PR stack.

Differential Revision: [D58733631](https://our.internmc.facebook.com/intern/diff/D58733631)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128467
Approved by: https://github.com/shunting314
2024-06-19 07:47:25 +00:00
3a185778ed [aotinductor] Add torch.polar fallback op for shim v2 (#128722)
Compilation error:
```
$ TORCHINDUCTOR_C_SHIM_VERSION=2 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_LOGS_FORMAT="%(pathname)s:%(lineno)s: %(message)s" TORCH_LOGS="+output_code" python test/inductor/test_cpu_cpp_wrapper.py -k test_polar

/tmp/tmp2sp128xj/dy/cdypvu3hvgg3mwxydwbiuddsnmuoi37it3mrpjktcnu6vt4hr3ki.cpp:59:33: error: ‘aoti_torch_cpu_polar’ was not declared in this scope; did you mean ‘aoti_torch_cpu_topk’?
```

Steps:
1. Add aten.polar
2. run `python torchgen/gen.py --update-aoti-c-shim`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128722
Approved by: https://github.com/chenyang78, https://github.com/desertfire
2024-06-19 05:06:58 +00:00
a584b2a389 Revert "Add test to xfail_list only for abi_compatible (#128506)"
This reverts commit df85f34a14dd30f784418624b05bd52b12ab8b0b.

Reverted https://github.com/pytorch/pytorch/pull/128506 on behalf of https://github.com/huydhn due to The failure shows up in trunk df85f34a14 ([comment](https://github.com/pytorch/pytorch/pull/128506#issuecomment-2177744578))
2024-06-19 04:59:10 +00:00
fcf2a1378b Enable fp8 rowwise scaling kernel on cuda, TAKE 2: #125204 (#128989)
# Summary
First PR got reverted and needed a redo

This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".
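A hypothetical helper (not the actual kernel-selection code) spelling out that shape condition for `x @ y` with `x: [M, K]` and `y: [K, N]`:

```python
import torch

def uses_rowwise_scaling(x, y, scale_x, scale_y) -> bool:
    M = x.shape[0]
    N = y.shape[1]
    return (scale_x.dim() == 1 and scale_x.numel() == M
            and scale_y.dim() == 1 and scale_y.numel() == N)

x, y = torch.randn(16, 32), torch.randn(32, 8)
print(uses_rowwise_scaling(x, y, torch.ones(16), torch.ones(8)))  # True  -> row-wise kernel
print(uses_rowwise_scaling(x, y, torch.ones(1), torch.ones(1)))   # False -> tensor-wise path
```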

The following two PRs were required to enable local builds:
- [PR #126185](https://github.com/pytorch/pytorch/pull/126185)
- [PR #125523](https://github.com/pytorch/pytorch/pull/125523)

### Todo
We still do not build our Python wheels with this architecture.

@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954

#### ifdef

I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel. I was having a hell of a time with this, so I am not really sure of the right way to do it.

Kernel Credit:
@jwfromm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128989
Approved by: https://github.com/yangsiyu007, https://github.com/vkuzo
2024-06-19 04:49:39 +00:00
2f88597aad [inductor] For internal, allow multiple workers if the method is "subprocess" (#129002)
Summary: This does not change the current default behavior in fbcode ("fork" as the start method if unspecified, and no worker processes if unspecified). But it allows us to more easily test the subprocess-based parallelism if we override the start method to subprocess.

Test Plan: Set `TORCHINDUCTOR_WORKER_START=subprocess` and locally ran all torchbench models listed [here](https://www.internalfb.com/intern/wiki/PyTorch/Teams/PyTorch_Perf_Infra/TorchBench/#torchbench-internal-mode)

Differential Revision: D58755021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129002
Approved by: https://github.com/eellison
2024-06-19 04:28:27 +00:00
1f0a68b572 [ROCm] Fix fp32 atomicAdd for non-MI100 GPUs (#128750)
The current implementation is very specific to MI100, which causes performance degradation on other GPUs.

Fixes #128631

Benchmarking on MI300X:
```
Before:  1918.5126953125 ms
After: 0.8285150527954102 ms
```

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128750
Approved by: https://github.com/xw285cornell
2024-06-19 03:56:20 +00:00
acefc5c016 [torch.compile] Enable bwd compilation metrics (#128973)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128973
Approved by: https://github.com/dshi7
2024-06-19 03:45:41 +00:00
eb9f4da11e Modified template indexing to broadcast indices to out instead of mask and some other flexattention micro-opts (#128938)
For headdim=64 and headdim=128

Old:
<img width="656" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/2c5d1613-96dc-4300-8dc0-dccaef59e73c">

New:
<img width="644" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/730004a8-6d5f-46a5-82a0-2594feb5e192">

Note, this does regress headdim=256. We can unregress it by special casing `headdim=256`, but ehh.... we can do it later

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128938
Approved by: https://github.com/drisspg
2024-06-19 03:41:22 +00:00
8771e3429c Introduce a prototype for SymmetricMemory (#128582)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them.

### SymmetricMemory

`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers).

### Python API Example

```python
from torch._C.distributed_c10d import _SymmetricMemory

# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of device identified by a
# ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
# backends might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)

# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name)

# Users can write Python custom ops that leverages the symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).

# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has a collective semantic and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)

# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)

if symm_mem.rank == 0:
    symm_mem.wait_signal(src_rank=1)
    assert buf.eq(42).all()
else:
    # The remote buffer can be used as a regular tensor
    buf.fill_(42)
    symm_mem.put_signal(dst_rank=0)

symm_mem.barrier()

if symm_mem.rank == 0:
    symm_mem.barrier()
    assert buf.eq(43).all()
else:
    new_val = torch.empty_like(buf)
    new_val.fill_(43)
    # Contiguous copies to/from a remote buffer utilize copy engines
    # which bypasses SMs (i.e. no need to load the data into registers)
    buf.copy_(new_val)
    symm_mem.barrier()
```

### Custom CUDA Comm Kernels

Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels.

```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor);

class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 public:
  ...
  virtual std::vector<void*> get_buffer_ptrs() = 0;
  virtual std::vector<void*> get_signal_pad_ptrs() = 0;
  virtual void** get_buffer_ptrs_dev() = 0;
  virtual void** get_signal_pad_ptrs_dev() = 0;
  virtual size_t get_buffer_size() = 0;
  virtual size_t get_signal_pad_size() = 0;
  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;
  ...
};
```

### Limitations of IntraNodeComm and ProcessGroupCudaP2p
`IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace needs to be specified upfront.
- Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.

In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels.

* __->__ #128582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582
Approved by: https://github.com/wanchaol
2024-06-19 03:38:58 +00:00
ed5b8432cd Enable mixed_mm only if casting from lower-bitwidth type to a higher one (#128899)
This PR changes the behavior of `cuda_and_enabled_mixed_mm` such that mixed_mm is only enabled if we are casting from a lower-bitwidth type to a higher one.
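A hypothetical sketch of that condition (not Inductor's actual pattern-matching code): only treat the matmul as a mixed_mm candidate when the cast widens the dtype, e.g. int8/uint8 to fp16/bf16.

```python
import torch

def _bits(dtype: torch.dtype) -> int:
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits

def casts_to_higher_bitwidth(src: torch.dtype, dst: torch.dtype) -> bool:
    return _bits(src) < _bits(dst)

print(casts_to_higher_bitwidth(torch.uint8, torch.float16))    # True  -> mixed_mm enabled
print(casts_to_higher_bitwidth(torch.float32, torch.float16))  # False -> regular mm path
```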

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128899
Approved by: https://github.com/eellison
2024-06-19 03:12:18 +00:00
df85f34a14 Add test to xfail_list only for abi_compatible (#128506)
https://github.com/pytorch/pytorch/pull/126717 will skip the tests in both ABI compatible and non-ABI compatible mode.
It's not expected to skip them in non-ABI compatible mode since they can actually run successfully in such mode but only have issues in ABI compatible mode.

We leverage the existing `xfail_list` for those that will only fail in ABI compatible mode.

- `test_qlinear_add` is already in the `xfail_list`.
- `test_linear_packed` doesn't fail either in my local run (running with `TORCHINDUCTOR_ABI_COMPATIBLE=1`) or in the CI of this PR so I didn't add it into `xfail_list`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128506
Approved by: https://github.com/jgong5, https://github.com/desertfire
2024-06-19 01:18:37 +00:00
4bc90185fb fix: Print statements causing parse error (#128969)
The print statements in the get_workflow_type script are problematic because the shell script calling it expects the output to be JSON only. This PR resolves this by removing all print statements and converting them into a message field in the JSON return output, so that the output remains pure JSON while still giving us the debug data we are looking for.
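A minimal sketch of the pattern (function and field names hypothetical): collect would-be debug prints into a `message` field so the script still emits a single JSON object.

```python
import json

def get_runner_type(label: str) -> str:
    messages = []
    messages.append(f"resolving runner type for label {label!r}")
    runner_type = "lf" if label.startswith("lf.") else "default"
    messages.append(f"selected {runner_type!r}")
    # One JSON object on stdout; debug data rides along instead of breaking the parser.
    return json.dumps({"workflow_type": runner_type, "message": "; ".join(messages)})

print(get_runner_type("lf.linux.2xlarge"))
```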

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128969
Approved by: https://github.com/tylertitsworth, https://github.com/ZainRizvi
2024-06-19 01:17:08 +00:00
eda375a490 [Inductor] Remove min/max from inductor opinfo test (#128925)
**Summary**
Remove `max.binary, min.binary, maximum, minimum` from `inductor_one_sample` op list as we fix the bool vectorization issue in https://github.com/pytorch/pytorch/pull/126841.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_maximum
python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_minimum
python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_min_binary
python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_max_binary
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128925
Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10
2024-06-19 01:14:27 +00:00
2458f79f83 [Inductor UT][Intel GPU] Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU (#128881)
Skip the newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU because reduction kernel split has not been implemented for Intel GPU yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128881
Approved by: https://github.com/blaine-rister, https://github.com/EikanWang, https://github.com/malfet
2024-06-19 00:44:57 +00:00
b0d2fe6299 Revert "Short-term fix to preserve NJT metadata cache in torch.compile (#122836)"
This reverts commit 2a41fc03903de63270d325bd1886a50faf32d7e4.

Reverted https://github.com/pytorch/pytorch/pull/122836 on behalf of https://github.com/jbschlosser due to internal test failures with DEBUG=1 asserts ([comment](https://github.com/pytorch/pytorch/pull/122836#issuecomment-2177298245))
2024-06-19 00:28:53 +00:00
5ffb032be6 Revert "Backward support for unbind() with NJT (#128032)"
This reverts commit 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74.

Reverted https://github.com/pytorch/pytorch/pull/128032 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128032#issuecomment-2177296325))
2024-06-19 00:26:40 +00:00
35c78668b4 Improve the debugging message for when foreach mta_called (#128991)
The hope that lives in this PR: I am currently trying to debug why the foreach tests are so flaky. It looks like every flaky test falls under this pattern:
- a test is flaky due to the mta_called assertion, which gathers data from the profiler regarding whether the multi_tensor_apply_kernel has been called.
- then, a later test fails deterministically, usually failing to compare two results.

```
================== 1 failed, 241 deselected, 2 rerun in 1.76s ==================
Got exit code 1
Stopping at first consistent failure
The following tests failed and then succeeded when run in a new process ['test/test_foreach.py::TestForeachCUDA::test_binary_op_float_inf_nan__foreach_add_cuda_bfloat16']
The following tests failed consistently: ['test/test_foreach.py::TestForeachCUDA::test_binary_op_list_error_cases__foreach_add_cuda_bfloat16']
```

So my suspicion is that the first causes the second, but what causes the first? Idk! So it would be nice to have the error message tell us what the profiler actually saw in case it's getting muddled. This change would help mostly because I have not been able to repro this flakiness locally.
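A hedged sketch of the kind of assertion message this adds (not the actual test code): surface the kernel names the profiler recorded so a flaky failure is diagnosable from the log alone.

```python
# Hypothetical values standing in for what the profiler trace reports.
profiled_kernels = ["aten::copy_", "void multi_tensor_apply_kernel<...>(...)"]
mta_called = any("multi_tensor_apply_kernel" in name for name in profiled_kernels)
assert mta_called, f"multi_tensor_apply_kernel not found; profiler saw: {profiled_kernels}"
```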

Also undo the useless changes in #128220 which are actually redundant as Joel and I realized that we set the seed during the setUp of every test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128991
Approved by: https://github.com/clee2000
2024-06-19 00:25:09 +00:00
99f042d336 Revert "Forward fix to skip ROCm tests for #122836 (#128891)"
This reverts commit 4061b3b8225f522ae0ed6db00111441e7d3cc3d5.

Reverted https://github.com/pytorch/pytorch/pull/128891 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128891#issuecomment-2177291249))
2024-06-19 00:21:21 +00:00
670b94c9c8 [inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128484
Approved by: https://github.com/mlazos
ghstack dependencies: #128428
2024-06-19 00:06:46 +00:00
c5e0b84484 [dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428)
Co-authored-by: Laith Sakka <lsakka@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128428
Approved by: https://github.com/yanboliang, https://github.com/mlazos
2024-06-19 00:06:46 +00:00
cyy
cb5e9183c6 [Caffe2] [2/N] Remove Caffe2 from tests (#128911)
Follows #128675

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128911
Approved by: https://github.com/titaiwangms, https://github.com/r-barnes
2024-06-19 00:05:50 +00:00
ac5f565fa7 [FSDP2] Added set_post_optim_event (#128975)
This PR adds `set_post_optim_event` that allows power users to provide their own CUDA event that is recorded after the optimizer step for the FSDP root module to wait the all-gather streams on.
```
def set_post_optim_event(self, event: torch.cuda.Event) -> None:
```
By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models.

To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event to wait the all-gather streams on.

We include both correctness test (`test_fully_shard_training.py`) and overlap test (`test_fully_shard_overlap.py`).

---

One possible way to use the API is to register a post-step hook on the optimizer. For example:
12e8d1399b/test/distributed/_composable/fsdp/test_fully_shard_training.py (L546-L552)
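A sketch of that wiring under the assumptions above (`model` being an FSDP2 root module that exposes `set_post_optim_event`; the names are placeholders, not the tested recipe):

```python
import torch

def install_post_optim_event_hook(model, optimizer) -> None:
    post_optim_event = torch.cuda.Event()

    def _hook(optim, args, kwargs):
        # Record right after optimizer.step() and hand the event to the FSDP root,
        # so the all-gather streams wait on it rather than on the current stream.
        post_optim_event.record()
        model.set_post_optim_event(post_optim_event)

    optimizer.register_step_post_hook(_hook)
```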

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128975
Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy
ghstack dependencies: #128884
2024-06-18 22:26:14 +00:00
d9c294c672 [Inductor] Fix arguments passed to triton kernel launch hooks (#128732)
`binary.launch_enter_hook` is treated as an instance method and will add a `self` argument to the hooks.
`CompiledKernel.launch_enter_hook` is a static method, which matches the hook calling convention of profilers (i.e., a single `LazyDict` argument only).
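A plain-Python sketch of the binding pitfall (no Triton required): a hook stored as an ordinary class attribute gets bound like a method, so the profiler's one-argument calling convention breaks.

```python
class CompiledKernelSketch:  # stand-in for the kernel class holding the hooks
    pass

def profiler_hook(metadata):
    print("launch metadata:", metadata)

CompiledKernelSketch.launch_enter_hook = profiler_hook
try:
    CompiledKernelSketch().launch_enter_hook({"kernel": "demo"})  # bound: `self` is passed too
except TypeError as err:
    print("instance-method binding breaks the convention:", err)

CompiledKernelSketch.launch_enter_hook = staticmethod(profiler_hook)
CompiledKernelSketch().launch_enter_hook({"kernel": "demo"})  # single argument, as expected
```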

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128732
Approved by: https://github.com/shunting314, https://github.com/bertmaher
2024-06-18 22:06:55 +00:00
a0e1e20c41 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin
ghstack dependencies: #128868, #128869
2024-06-18 21:49:08 +00:00
3b798df853 [BE][Easy] enable UFMT for torch/distributed/{fsdp,optim,rpc}/ (#128869)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869
Approved by: https://github.com/fegin
ghstack dependencies: #128868
2024-06-18 21:49:08 +00:00
cec31050b4 [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128868
Approved by: https://github.com/fegin
2024-06-18 21:49:02 +00:00
e47603a549 Fix weight_norm decomposition behavior (#128956)
By upcasting norm to float32 to align with CUDA and CPU behaviors
e6d4451ae8/aten/src/ATen/native/WeightNorm.cpp (L56-L59)

Discovered this when I started running OpInfo tests; see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060
```
  File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
AssertionError: 1 Operation:  aten._weight_norm_interface.default
```
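A minimal sketch of the dtype handling (a reference formula, not the actual decomposition in `torch/_decomp`): compute the norm in float32 as the CPU/CUDA kernels do, then cast back so the output dtype matches eager.

```python
import torch

def weight_norm_ref(v: torch.Tensor, g: torch.Tensor, dim: int = 0) -> torch.Tensor:
    reduce_dims = [d for d in range(v.dim()) if d != dim]
    # Upcast to float32 for the reduction, then return to the input dtype.
    norm = v.float().pow(2).sum(dim=reduce_dims, keepdim=True).sqrt().to(v.dtype)
    return v * (g / norm)

v = torch.randn(8, 4, dtype=torch.bfloat16)
g = torch.ones(8, 1, dtype=torch.bfloat16)
print(weight_norm_ref(v, g).dtype)  # torch.bfloat16, matching the eager op
```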
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128956
Approved by: https://github.com/albanD
ghstack dependencies: #128955
2024-06-18 21:24:12 +00:00
2227da4431 [Profiler] Clean up use_mtia to follow standard use_device instead (#126284)
Summary:
use_mtia should instead set use_device='mtia' similar to cuda, xpu, and privateuseone. Avoid an ever-growing list of use_* arguments.

Since use_mtia is specific to FBCode, we don't need a deprecation warning.

Test Plan: CI.

Differential Revision: D57338005

Pulled By: aaronenyeshi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126284
Approved by: https://github.com/fenypatel99
2024-06-18 21:01:03 +00:00
4cc3fb5ee2 Bump urllib3 from 2.2.1 to 2.2.2 in /tools/build/bazel (#128908)
Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.2.1 to 2.2.2.
- [Release notes](https://github.com/urllib3/urllib3/releases)
- [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
- [Commits](https://github.com/urllib3/urllib3/compare/2.2.1...2.2.2)

---
updated-dependencies:
- dependency-name: urllib3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-18 13:38:22 -07:00
5dc4f652bc Backward support for unbind() with NJT (#128032)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
2024-06-18 20:29:00 +00:00
44722c6b10 Revert "[dynamo][fsdp] Dont take unspecializedNNModuleVariable path for FSDP modules (#128453)"
This reverts commit 2b28b107dbafeec18d1095a2002e79511aa241df.

Reverted https://github.com/pytorch/pytorch/pull/128453 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667))
2024-06-18 20:09:00 +00:00
1babeddbbf Revert "[inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484)"
This reverts commit 1f6e84fa6852805e15ddc9583c5f36c3a7f93df8.

Reverted https://github.com/pytorch/pytorch/pull/128484 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667))
2024-06-18 20:09:00 +00:00
5bc9835d64 Revert "[dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428)"
This reverts commit c52eda896eb3ec7f8d04b6321861f4c5614a40bb.

Reverted https://github.com/pytorch/pytorch/pull/128428 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667))
2024-06-18 20:09:00 +00:00
9a7e2519d3 [MPS] Fused Adam & AdamW (#127242)
Summary:

This PR adds fused Adam and AdamW implementations.

Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        89
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        90
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        83
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       12      |        94
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       11      |        88
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       12      |        90
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |       100
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       27      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       23      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       27      |       100
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       23      |        98
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       82      |       480
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       72      |       450
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       82      |       450
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       73      |       420
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       91      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       83      |       400
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |       94      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       78      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      170      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      140      |       600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      170      |       600
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      140      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      250      |       890
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      220      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      250      |       830
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      220      |       770
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      270      |       870
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      230      |       840
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      270      |       810
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      240      |       800
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      400      |      1000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      360      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      430      |      2000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      360      |      1300

Times are in milliseconds (ms).
```

**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        79
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       11      |        93
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       10      |        90
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       11      |        91
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |        81
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       34      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       31      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       34      |        95
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       31      |       100
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       94      |       500
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       82      |       430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       92      |       430
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       81      |       390
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       98      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       88      |       430
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |      100      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       88      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      210      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      190      |       610
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      210      |       510
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      190      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      300      |       900
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      260      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      295      |       900
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      260      |       800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      320      |       910
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      280      |       900
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      320      |       900
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      300      |       900
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      500      |      2000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      480      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      540      |      1500
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      480      |      1200

Times are in milliseconds (ms).
```

```python
def profile_fused_adam():
    import itertools

    import torch
    import torch.utils.benchmark as benchmark
    from torch.optim import adam, adamw

    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"

    results = []

    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:

            t = benchmark.Timer(
                    stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                    label='Fused Adam',
                    sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                    globals=locals(),
                    description= f"Fused: {fused}",
                ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2024-06-18 19:59:50 +00:00
fe8558b7aa [DSD] Add unittest to verify HSDP1 + broadcast_from_rank0 (#128755)
HSDP1 + broadcast_from_rank0 actually behaves differently from FSDP1 + broadcast_from_rank0, so we need a unit test to cover this use case.

This test relies on the fix from https://github.com/pytorch/pytorch/pull/128446.

Differential Revision: [D58621436](https://our.internmc.facebook.com/intern/diff/D58621436/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128755
Approved by: https://github.com/Skylion007, https://github.com/wz337
ghstack dependencies: #128685
2024-06-18 19:42:51 +00:00
abde6cab4c Remove compile_threads=1 in test_inductor_collectives.py (#128580)
Summary: I believe https://github.com/pytorch/pytorch/issues/125235 should be fixed after switching to subprocess-based parallel compile.

Test Plan: Ran locally with python-3.9

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128580
Approved by: https://github.com/eellison
2024-06-18 19:31:13 +00:00
04a5d3228e [ts migration] Support prim::tolist and aten::len (#128894)
Support prim::tolist and aten::len. Add unit tests for prim::min.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128894
Approved by: https://github.com/angelayi
2024-06-18 19:11:07 +00:00
44483972bd [EZ] Keep weight_norm var name aligned (#128955)
To keep it aligned with
e6d4451ae8/aten/src/ATen/native/native_functions.yaml (L6484)
I.e.  `x`->`v`, `y`->`g`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128955
Approved by: https://github.com/albanD, https://github.com/Skylion007
2024-06-18 18:40:59 +00:00
bdffd9f0c6 [export] Graph break on nn.Parameter construction (#128935)
Fixes https://github.com/pytorch/pytorch/issues/126109

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128935
Approved by: https://github.com/angelayi
2024-06-18 18:37:44 +00:00
1a527915a6 [DSD] Correctly handle shared parameters for optimizer state_dict (#128685)
Fixes https://github.com/pytorch/pytorch/issues/128011

See the discussion in https://github.com/pytorch/pytorch/pull/128076

The current implementation of `set_optimizer_state_dict()` assumes that all the FQNs returned by `_get_fqns()` exist in the optimizer state_dict. This is not true if the model has shared parameters: in that case, only one FQN of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue.
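
For illustration, a minimal sketch (the tied-embedding module below is a made-up example, not from this PR) of why shared parameters yield more FQNs than optimizer entries:

```python
import torch
import torch.nn as nn

class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 10, bias=False)
        self.proj.weight = self.embed.weight  # shared parameter (weight tying)

m = TiedModel()
opt = torch.optim.Adam(m.parameters())

# Two FQNs point at one Parameter object, but the optimizer tracks it only once,
# so the optimizer state_dict ends up with a single entry for both names.
print(sorted(m.state_dict().keys()))       # ['embed.weight', 'proj.weight']
print(len(opt.param_groups[0]["params"]))  # 1
```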

Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128685
Approved by: https://github.com/LucasLLC
2024-06-18 18:34:32 +00:00
d77a1aaa86 DOC: add note about same sized tensors to dist.gather() (#128676)
Fixes #103305
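
For context, a hedged single-process sketch (assuming the gloo backend is available) of the constraint the note documents: every buffer in `gather_list` must have the same size as the gathered tensor.

```python
import torch
import torch.distributed as dist

# One-rank process group purely to illustrate the call signature.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

t = torch.arange(4, dtype=torch.float32)
# gather_list is only needed on the dst rank and must contain one
# same-sized tensor per rank in the group.
gather_list = [torch.empty_like(t) for _ in range(dist.get_world_size())]
dist.gather(t, gather_list=gather_list, dst=0)
print(gather_list)

dist.destroy_process_group()
```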

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128676
Approved by: https://github.com/wconstab
2024-06-18 18:26:07 +00:00
1877b7896c [checkpoint] Clean up selective activation checkpoint and make public (#125795)
### bc-breaking for existing users of the private API:
- Existing policy functions must now change their return value to be the [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230)) enum instead of a bool.
   - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True`, depending on whether you prefer to let the compiler override your policy.
- Policy function now accepts a `ctx` object instead of `mode` for its first argument.
   - To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts`. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).

Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to detect whether any cached tensor has been mutated and raise an error if so. In-place operations that do not mutate cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user cleverly also saves the output of the in-place op)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views; the author of the policy function is expected to handle them properly. (Would it be beneficial to error? We either want to save all or recompute all random tensors.)

Tensor object preservation
- ~We guarantee that if a tensor does not require grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- Regarding the usage of the Enum today: there actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. Stacking SAC policies would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would normally save (this would be fixed with AC v3).
- The number of times we call the policy_fn is something that should be documented as part of the public API. We call the policy function for all ops except ~~detach~~ (UPDATE: the metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute; the user is expected to handle that via `is_recompute`, see below).
Tensors guaranteed to be the same tensor as-is
- The policy function signature takes a ctx object as its first argument. The ctx argument is an object encapsulating info that may be useful to the user; it currently only holds `is_recompute`. Adding this indirection gives us flexibility to add more attrs later if necessary.
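
A minimal sketch of how the pieces fit together (the toy `fn`, the tensor shapes, and the `aten.mm.default` choice are illustrative assumptions, not from this PR):

```python
import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul outputs; let everything else be recomputed
    # (the compiler may override PREFER_* decisions).
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x, w):
    return torch.relu(x @ w).sum()

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
out = checkpoint(fn, x, w, use_reentrant=False, context_fn=context_fn)
out.backward()
```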

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-18 18:18:50 +00:00
77830d509f Revert "Introduce a prototype for SymmetricMemory (#128582)"
This reverts commit 7a39755da28d5a109bf0c37f72b364d3a83137b1.

Reverted https://github.com/pytorch/pytorch/pull/128582 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128582#issuecomment-2176685232))
2024-06-18 18:11:43 +00:00
84c86e56bd Update tracker issues after successfully cherry-picking a PR (#128924)
This extends the cherry-pick bot so that it also automatically updates the tracker issue with the cherry-pick information. For this to work, the tracker issue needs to be an open one with a `release tracker` label, e.g. https://github.com/pytorch/pytorch/issues/128436. The version from the release branch, e.g. `release/2.4`, will be matched against the title of the tracker issue, e.g. `[v.2.4.0] Release Tracker` or `[v.2.4.1] Release Tracker`.
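
Roughly how the matching works (a standalone sketch mirroring the helpers added to `cherry_pick.py` in the diff below; the example titles are made up):

```python
import re

RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")

def get_release_version(onto_branch: str) -> str:
    # Return the release version if the target branch is a release branch.
    m = RELEASE_BRANCH_REGEX.match(onto_branch)
    return m.group("version") if m else ""

version = get_release_version("release/2.4")  # -> "2.4"
titles = ["[v.2.4.0] Release Tracker", "[v.2.3.1] Release Tracker"]
print([t for t in titles if version in t])    # only the v.2.4.0 tracker matches
```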

### Testing

`python cherry_pick.py --onto-branch release/2.4 --classification release --fixes "DEBUG DEBUG" --github-actor huydhn 128718`

* On the PR https://github.com/pytorch/pytorch/pull/128718#issuecomment-2174846771
* On the tracker issue https://github.com/pytorch/pytorch/issues/128436#issuecomment-2174846757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128924
Approved by: https://github.com/atalman
2024-06-18 17:48:47 +00:00
4e03263224 [CUDA][Convolution] Add missing launch bounds to vol2col_kernel (#128740)
Fix "too many resources requested" that can happen with recent toolkits on V100.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128740
Approved by: https://github.com/mikaylagawarecki
2024-06-18 17:26:23 +00:00
26e374e3ca [EZ] Fix typos in RELEASE.md (#128769)
This PR fixes typo in `RELEASE.md`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128769
Approved by: https://github.com/yumium, https://github.com/mikaylagawarecki
2024-06-18 17:15:05 +00:00
9818283da1 re-enable jacrev/jacfwd/hessian after #128028 landed (#128622)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128622
Approved by: https://github.com/zou3519
2024-06-18 17:08:58 +00:00
ec616da518 RNN API cleanup for cuDNN 9.1 (#122011)
Can potentially avoid a bit of boilerplate if we move directly to cuDNN 9.1's RNN API...

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122011
Approved by: https://github.com/Skylion007
2024-06-18 16:16:38 +00:00
108318ad10 [BE][JIT] Handle case where codegen object can be unset (#128951)
Summary:
Unblocks a test that's failing.

`codegen` can be unset until `compile` is called. If `codegen` is not set, then just use the kernel name directly.

Test Plan:
```
buck2 run //caffe2/test:tensorexpr -- --regex test_simple_add
```

Differential Revision: D58727391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128951
Approved by: https://github.com/aaronenyeshi
2024-06-18 15:40:45 +00:00
4817180601 make fallback for aten.argsort.stable (#128907)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128907
Approved by: https://github.com/lezcano
ghstack dependencies: #128343
2024-06-18 14:56:35 +00:00
22d258427b [BE][Easy] enable UFMT for torch/distributed/_shard/ (#128867)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128867
Approved by: https://github.com/fegin
ghstack dependencies: #128866
2024-06-18 14:39:25 +00:00
e6d4451ae8 [BE][Easy] enable UFMT for torch/distributed/{algorithms,autograd,benchmarks,checkpoint,elastic}/ (#128866)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128866
Approved by: https://github.com/fegin
2024-06-18 13:51:53 +00:00
f2805a0408 [FSDP2] Added APIs for explicit fwd/bwd prefetching (#128884)
This PR adds two APIs `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch` to enable explicit forward/backward all-gather prefetching, respectively.

```
def set_modules_to_forward_prefetch(self, modules: List[FSDPModule]): -> None
def set_modules_to_backward_prefetch(self, modules: List[FSDPModule]): -> None
```

**Motivation**
FSDP2 implements _reasonable defaults_ for forward and backward prefetching. In forward, it uses implicit prefetching and allows two all-gather output tensors to be alive at once (so that the current all-gather copy-out can overlap with the next all-gather). In backward, it uses explicit prefetching based on the reverse post-forward order.

However, there may be cases where, with expert knowledge, we can reduce communication bubbles by moving all-gathers manually. One way to expose such behavior is to expose _prefetching limits_, i.e. integers that configure how many outstanding all-gathers/all-gather output tensors can be alive at once. IMHO, this leans toward _easy_, not _simple_ (see [PyTorch design principles](https://pytorch.org/docs/stable/community/design.html#principle-2-simple-over-easy)).

The crux of the problem is that there may be special cases where manual intervention can give better performance. Exposing a prefetching limit and allowing users to pass a value >1 just smooths over the problem, since such a limit would generally apply over the entire model even though it possibly should not. Then, expert users will find a specific all-gather for which they want to deviate from this limit, and there is little we can do.

Thus, we instead choose to expose the most primitive extension point: namely, every `FSDPModule` gives an opportunity to prefetch other all-gathers in forward and in backward. How to leverage this extension point is fully up to the user. Implementing the prefetch limit can be done using this extension point (e.g. record the post-forward order yourself using forward hooks, iterate over that order, and call the `set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch` APIs).
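
A rough sketch of how the extension point could be used (the toy `Block`/`Model` modules are assumptions, and this only runs under an initialized process group, e.g. launched via torchrun):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(64, 64)

    def forward(self, x):
        return self.lin(x).relu()

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(Block() for _ in range(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = Model()
for layer in model.layers:
    fully_shard(layer)  # each Block becomes an FSDPModule
fully_shard(model)

# Explicit prefetching: in forward, kick off the next two layers' all-gathers early;
# in backward, prefetch the previous layer's all-gather.
layers = list(model.layers)
for i, layer in enumerate(layers):
    layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])
    if i > 0:
        layer.set_modules_to_backward_prefetch([layers[i - 1]])
```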

Differential Revision: [D58700346](https://our.internmc.facebook.com/intern/diff/D58700346)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128884
Approved by: https://github.com/ckluk2, https://github.com/weifengpy
2024-06-18 13:32:57 +00:00
3dd5f0ecbb Remove circular import (#128875)
Summary: A spurious import is causing circular dependency errors

Test Plan: phabricator signals

Differential Revision: D58685676

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128875
Approved by: https://github.com/kit1980
2024-06-18 12:30:13 +00:00
304c934572 Move MKLDNN Specific IR to Separate File (#126504)
**Summary**
Following the discussion in https://github.com/pytorch/pytorch/pull/122593#discussion_r1604144782, move the Inductor MKLDNN-specific IRs to a separate file.

Co-authored-by: Isuru Fernando <ifernando@quansight.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126504
Approved by: https://github.com/desertfire, https://github.com/jgong5
ghstack dependencies: #126841, #126940
2024-06-18 09:29:13 +00:00
6e43897912 [BE][ptd_fb_test][3/N] Enable TestSlide for MultiThreadedTestCase (#128843)
Enabling testslide for MultiThreadedTestCase, similar to https://github.com/pytorch/pytorch/pull/127512.

Differential Revision: [D58677457](https://our.internmc.facebook.com/intern/diff/D58677457/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128843
Approved by: https://github.com/wz337
2024-06-18 07:05:31 +00:00
60baeee59f [BE] Skip the test if CUDA is not available (#128885)
As title

Differential Revision: [D58690210](https://our.internmc.facebook.com/intern/diff/D58690210/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128885
Approved by: https://github.com/wz337
2024-06-18 07:02:44 +00:00
e3a39d49a0 [Traceable FSDP][Compiled Autograd] Add queue_callback() support (#126366)
Adds support for `Variable._execution_engine.queue_callback()`, which is used in FSDP2.
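
For reference, a minimal eager-mode sketch (not from this PR) of what `queue_callback()` does: callbacks queued during a backward pass run once that backward pass finishes.

```python
import torch
from torch.autograd import Variable

messages = []

def grad_hook(grad):
    # queue_callback() must be called while a backward pass is executing;
    # the callback then runs after the whole backward pass has finished.
    Variable._execution_engine.queue_callback(lambda: messages.append("backward finished"))
    return grad

x = torch.ones(3, requires_grad=True)
x.register_hook(grad_hook)
(x * 2).sum().backward()
print(messages)  # ['backward finished']
```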

Important tests:
- `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_callback_graph_break_throws_error`
- `pytest -rA test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_callback_adds_callback`
- `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_callback_adds_callback`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126366
Approved by: https://github.com/xmfan
2024-06-18 06:22:14 +00:00
436 changed files with 10975 additions and 6691 deletions

View File

@ -1 +1 @@
d4b3e5cc607e97afdba79dc90f8ef968142f347c
172574a6be5910a4609e4ed1bef2b6b8475ddb3d

View File

@ -37,6 +37,9 @@ install_conda_dependencies() {
install_pip_dependencies() {
pushd executorch/.ci/docker
# Install PyTorch CPU build beforehand to avoid installing the much bigger CUDA
# binaries later, ExecuTorch only needs CPU
pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# Install all Python dependencies
pip_install -r requirements-ci.txt
popd
@ -44,13 +47,14 @@ install_pip_dependencies() {
setup_executorch() {
pushd executorch
source .ci/scripts/utils.sh
# Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate
as_jenkins bash .ci/scripts/setup-vulkan-linux-deps.sh
install_flatc_from_source
pip_install .
export PYTHON_EXECUTABLE=python
export EXECUTORCH_BUILD_PYBIND=ON
export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"
# Make sure that all the newly generate files are owned by Jenkins
chown -R jenkins .
as_jenkins .ci/scripts/setup-linux.sh cmake
popd
}

View File

@ -284,12 +284,26 @@ else
# Which should be backward compatible with Numpy-1.X
python -mpip install --pre numpy==2.0.0rc1
fi
WERROR=1 python setup.py bdist_wheel
WERROR=1 python setup.py clean
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel
BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 python setup.py bdist_wheel --cmake
else
WERROR=1 python setup.py bdist_wheel
fi
else
python setup.py clean
if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
source .ci/pytorch/install_cache_xla.sh
fi
python setup.py bdist_wheel
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
echo "USE_SPLIT_BUILD cannot be used with xla or rocm"
exit 1
else
python setup.py bdist_wheel
fi
fi
pip_install_whl "$(echo dist/*.whl)"
@ -328,9 +342,10 @@ else
CUSTOM_OP_TEST="$PWD/test/custom_operator"
python --version
SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
mkdir -p "$CUSTOM_OP_BUILD"
pushd "$CUSTOM_OP_BUILD"
cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \
cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \
-DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM"
make VERBOSE=1
popd
@ -343,7 +358,7 @@ else
SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
mkdir -p "$JIT_HOOK_BUILD"
pushd "$JIT_HOOK_BUILD"
cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \
cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \
-DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM"
make VERBOSE=1
popd
@ -355,7 +370,7 @@ else
python --version
mkdir -p "$CUSTOM_BACKEND_BUILD"
pushd "$CUSTOM_BACKEND_BUILD"
cmake "$CUSTOM_BACKEND_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \
cmake "$CUSTOM_BACKEND_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \
-DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM"
make VERBOSE=1
popd

View File

@ -56,9 +56,29 @@ function assert_git_not_dirty() {
function pip_install_whl() {
# This is used to install PyTorch and other build artifacts wheel locally
# without using any network connection
python3 -mpip install --no-index --no-deps "$@"
# Convert the input arguments into an array
local args=("$@")
# Check if the first argument contains multiple paths separated by spaces
if [[ "${args[0]}" == *" "* ]]; then
# Split the string by spaces into an array
IFS=' ' read -r -a paths <<< "${args[0]}"
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
done
fi
}
function pip_install() {
# retry 3 times
# old versions of pip don't have the "--progress-bar" flag

View File

@ -289,6 +289,9 @@ test_python_shard() {
# Bare --include flag is not supported and quoting for lint ends up with flag not being interpreted correctly
# shellcheck disable=SC2086
# modify LD_LIBRARY_PATH to ensure it has the conda env.
# This set of tests has been shown to be buggy without it for the split-build
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION
assert_git_not_dirty
@ -1174,15 +1177,21 @@ test_executorch() {
pushd /executorch
# NB: We need to build ExecuTorch runner here and not inside the Docker image
# because it depends on PyTorch
export PYTHON_EXECUTABLE=python
export EXECUTORCH_BUILD_PYBIND=ON
export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"
# NB: We need to rebuild ExecuTorch runner here because it depends on PyTorch
# from the PR
# shellcheck disable=SC1091
source .ci/scripts/utils.sh
build_executorch_runner "cmake"
source .ci/scripts/setup-linux.sh cmake
echo "Run ExecuTorch unit tests"
pytest -v -n auto
# shellcheck disable=SC1091
LLVM_PROFDATA=llvm-profdata-12 LLVM_COV=llvm-cov-12 bash test/run_oss_cpp_tests.sh
echo "Run ExecuTorch regression tests for some models"
# NB: This is a sample model, more can be added here
export PYTHON_EXECUTABLE=python
# TODO(huydhn): Add more coverage here using ExecuTorch's gather models script
# shellcheck disable=SC1091
source .ci/scripts/test.sh mv3 cmake xnnpack-quantization-delegation ''

View File

@ -33,9 +33,9 @@ if [[ -z "$DOCKER_IMAGE" ]]; then
if [[ "$PACKAGE_TYPE" == conda ]]; then
export DOCKER_IMAGE="pytorch/conda-cuda"
elif [[ "$DESIRED_CUDA" == cpu ]]; then
export DOCKER_IMAGE="pytorch/manylinux-cpu"
export DOCKER_IMAGE="pytorch/manylinux:cpu"
else
export DOCKER_IMAGE="pytorch/manylinux-cuda${DESIRED_CUDA:2}"
export DOCKER_IMAGE="pytorch/manylinux-builder:${DESIRED_CUDA:2}"
fi
fi
@ -75,9 +75,9 @@ export PYTORCH_BUILD_NUMBER=1
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'"
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
# Only linux Python < 3.13 are supported wheels for triton
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'"
TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
@ -87,11 +87,11 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:
fi
# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton rocm package
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" && "$DESIRED_PYTHON" != "3.12" ]]; then
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}"
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt)
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}"
TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
fi
if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}"

View File

@ -52,6 +52,13 @@ inputs:
description: Hugging Face Hub token
required: false
default: ""
use_split_build:
description: |
[Experimental] Build a libtorch only wheel and build pytorch such that
are built from the libtorch wheel.
required: false
type: boolean
default: false
outputs:
docker-image:
value: ${{ steps.calculate-docker-image.outputs.docker-image }}
@ -144,6 +151,7 @@ runs:
DEBUG: ${{ inputs.build-with-debug == 'true' && '1' || '0' }}
OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
HUGGING_FACE_HUB_TOKEN: ${{ inputs.HUGGING_FACE_HUB_TOKEN }}
USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
shell: bash
run: |
# detached container should get cleaned up by teardown_ec2_linux
@ -163,6 +171,7 @@ runs:
-e PR_LABELS \
-e OUR_GITHUB_JOB_ID \
-e HUGGING_FACE_HUB_TOKEN \
-e USE_SPLIT_BUILD \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--security-opt seccomp=unconfined \
--cap-add=SYS_PTRACE \
@ -183,7 +192,7 @@ runs:
- name: Store PyTorch Build Artifacts on S3
uses: seemethere/upload-artifact-s3@v5
if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped'
if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' && inputs.use_split_build != 'true'
with:
name: ${{ inputs.build-environment }}
retention-days: 14
@ -191,6 +200,16 @@ runs:
path: artifacts.zip
s3-bucket: ${{ inputs.s3-bucket }}
- name: Store PyTorch Build Artifacts on S3 for split build
uses: seemethere/upload-artifact-s3@v5
if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' && inputs.use_split_build == 'true'
with:
name: ${{ inputs.build-environment }}-experimental-split-build
retention-days: 14
if-no-files-found: error
path: artifacts.zip
s3-bucket: ${{ inputs.s3-bucket }}
- name: Upload sccache stats
if: steps.build.outcome != 'skipped'
uses: seemethere/upload-artifact-s3@v5

View File

@ -1 +1 @@
0dab1dd97709096e8129f8a08115ee83f64f2194
23512dbebd44a11eb84afbf53c3c071dd105297e

View File

@ -3,11 +3,11 @@
import json
import os
import re
from typing import Any, Optional
from typing import Any, cast, Dict, List, Optional
from urllib.error import HTTPError
from github_utils import gh_fetch_url, gh_post_pr_comment
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import get_pr_commit_sha, GitHubPR
@ -19,6 +19,7 @@ REQUIRES_ISSUE = {
"critical",
"fixnewfeature",
}
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")
def parse_args() -> Any:
@ -58,6 +59,33 @@ def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
return commit_sha if pr.is_closed() else None
def get_release_version(onto_branch: str) -> Optional[str]:
"""
Return the release version if the target branch is a release branch
"""
m = re.match(RELEASE_BRANCH_REGEX, onto_branch)
return m.group("version") if m else ""
def get_tracker_issues(
org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
"""
Find the tracker issue from the repo. The tracker issue needs to have the title
like [VERSION] Release Tracker following the convention on PyTorch
"""
version = get_release_version(onto_branch)
if not version:
return []
tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
if not tracker_issues:
return []
# Figure out the tracker issue from the list by looking at the title
return [issue for issue in tracker_issues if version in issue.get("title", "")]
def cherry_pick(
github_actor: str,
repo: GitRepo,
@ -77,17 +105,49 @@ def cherry_pick(
)
try:
org, project = repo.gh_owner_and_name()
cherry_pick_pr = ""
if not dry_run:
org, project = repo.gh_owner_and_name()
cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)
msg = f"The cherry pick PR is at {cherry_pick_pr}"
if fixes:
msg += f" and it is linked with issue {fixes}"
elif classification in REQUIRES_ISSUE:
msg += f" and it is recommended to link a {classification} cherry pick PR with an issue"
tracker_issues_comments = []
tracker_issues = get_tracker_issues(org, project, onto_branch)
for issue in tracker_issues:
issue_number = int(str(issue.get("number", "0")))
if not issue_number:
continue
post_comment(org, project, pr.pr_num, msg)
res = cast(
Dict[str, Any],
post_tracker_issue_comment(
org,
project,
issue_number,
pr.pr_num,
cherry_pick_pr,
classification,
fixes,
dry_run,
),
)
comment_url = res.get("html_url", "")
if comment_url:
tracker_issues_comments.append(comment_url)
msg = f"The cherry pick PR is at {cherry_pick_pr}"
if fixes:
msg += f" and it is linked with issue {fixes}."
elif classification in REQUIRES_ISSUE:
msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."
if tracker_issues_comments:
msg += " The following tracker issues are updated:\n"
for tracker_issues_comment in tracker_issues_comments:
msg += f"* {tracker_issues_comment}\n"
post_pr_comment(org, project, pr.pr_num, msg, dry_run)
finally:
if current_branch:
@ -159,7 +219,9 @@ def submit_pr(
raise RuntimeError(msg) from error
def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
def post_pr_comment(
org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
"""
Post a comment on the PR itself to point to the cherry picking PR when success
or print the error when failure
@ -182,7 +244,35 @@ def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
comment = "\n".join(
(f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}")
)
gh_post_pr_comment(org, project, pr_num, comment)
return gh_post_pr_comment(org, project, pr_num, comment, dry_run)
def post_tracker_issue_comment(
org: str,
project: str,
issue_num: int,
pr_num: int,
cherry_pick_pr: str,
classification: str,
fixes: str,
dry_run: bool = False,
) -> List[Dict[str, Any]]:
"""
Post a comment on the tracker issue (if any) to record the cherry pick
"""
comment = "\n".join(
(
"Link to landed trunk PR (if applicable):",
f"* https://github.com/{org}/{project}/pull/{pr_num}",
"",
"Link to release branch PR:",
f"* {cherry_pick_pr}",
"",
"Criteria Category:",
" - ".join((classification.capitalize(), fixes.capitalize())),
)
)
return gh_post_pr_comment(org, project, issue_num, comment, dry_run)
def main() -> None:
@ -214,7 +304,7 @@ def main() -> None:
except RuntimeError as error:
if not args.dry_run:
post_comment(org, project, pr_num, str(error))
post_pr_comment(org, project, pr_num, str(error))
else:
raise error

View File

@ -347,10 +347,6 @@ def generate_wheels_matrix(
for python_version in python_versions:
for arch_version in arches:
gpu_arch_type = arch_type(arch_version)
# Disable py3.12 builds for ROCm because of triton dependency
# on llnl-hatchet, which doesn't have py3.12 wheels available
if gpu_arch_type == "rocm" and python_version == "3.12":
continue
gpu_arch_version = (
""
if arch_version == "cpu"

View File

@ -1,6 +1,6 @@
import json
from argparse import ArgumentParser
from typing import Any
from typing import Any, Tuple
from github import Auth, Github
from github.Issue import Issue
@ -9,6 +9,8 @@ from github.Issue import Issue
WORKFLOW_LABEL_META = "" # use meta runners
WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation
LABEL_TYPE_KEY = "label_type"
MESSAGE_KEY = "message"
MESSAGE = "" # Debug message to return to the caller
def parse_args() -> Any:
@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool:
return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}
def get_workflow_type(issue: Issue, username: str) -> str:
def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]:
try:
user_list = issue.get_comments()[0].body.split()
if user_list[0] == "!":
print("LF Workflows are disabled for everyone. Using meta runners.")
return WORKFLOW_LABEL_META
MESSAGE = "LF Workflows are disabled for everyone. Using meta runners."
return WORKFLOW_LABEL_META, MESSAGE
elif user_list[0] == "*":
print("LF Workflows are enabled for everyone. Using LF runners.")
return WORKFLOW_LABEL_LF
MESSAGE = "LF Workflows are enabled for everyone. Using LF runners."
return WORKFLOW_LABEL_LF, MESSAGE
elif username in user_list:
print(f"LF Workflows are enabled for {username}. Using LF runners.")
return WORKFLOW_LABEL_LF
MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners."
return WORKFLOW_LABEL_LF, MESSAGE
else:
print(f"LF Workflows are disabled for {username}. Using meta runners.")
return WORKFLOW_LABEL_META
MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners."
return WORKFLOW_LABEL_META, MESSAGE
except Exception as e:
print(
f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
)
return WORKFLOW_LABEL_META
MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
return WORKFLOW_LABEL_META, MESSAGE
def main() -> None:
args = parse_args()
if is_exception_branch(args.github_branch):
print(f"Exception branch: '{args.github_branch}', using meta runners")
output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
output = {
LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners",
}
else:
try:
gh = get_gh_client(args.github_token)
# The default issue we use - https://github.com/pytorch/test-infra/issues/5132
issue = get_issue(gh, args.github_repo, args.github_issue)
output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)}
label_type, message = get_workflow_type(issue, args.github_user)
output = {
LABEL_TYPE_KEY: label_type,
MESSAGE_KEY: message,
}
except Exception as e:
print(f"Failed to get issue. Falling back to meta runners. Exception: {e}")
output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
output = {
LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}",
}
json_output = json.dumps(output)
print(json_output)

View File

@ -202,3 +202,12 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->
)
else:
raise
def gh_query_issues_by_labels(
org: str, repo: str, labels: List[str], state: str = "open"
) -> List[Dict[str, Any]]:
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
return gh_fetch_json(
url, method="GET", params={"labels": ",".join(labels), "state": state}
)

View File

@ -56,6 +56,13 @@ on:
required: false
type: string
default: ""
use_split_build:
description: |
[Experimental] Build a libtorch only wheel and build pytorch such that
are built from the libtorch wheel.
required: false
type: boolean
default: false
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
@ -107,3 +114,4 @@ jobs:
aws-role-to-assume: ${{ inputs.aws-role-to-assume }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
use_split_build: ${{ inputs.use_split_build }}

View File

@ -64,6 +64,14 @@ on:
required: false
type: string
default: ""
use_split_build:
description: |
[Experimental] Build a libtorch only wheel and build pytorch such that
are built from the libtorch wheel.
required: false
type: boolean
default: false
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
@ -181,6 +189,7 @@ jobs:
DEBUG: ${{ inputs.build-with-debug && '1' || '0' }}
OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
run: |
# detached container should get cleaned up by teardown_ec2_linux
container_name=$(docker run \
@ -199,6 +208,7 @@ jobs:
-e PR_LABELS \
-e OUR_GITHUB_JOB_ID \
-e HUGGING_FACE_HUB_TOKEN \
-e USE_SPLIT_BUILD \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--security-opt seccomp=unconfined \
--cap-add=SYS_PTRACE \

View File

@ -2410,3 +2410,209 @@ jobs:
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
manywheel-py3_12-rocm6_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.0
GPU_ARCH_VERSION: 6.0
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-rocm6_0
build_environment: linux-binary-manywheel
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-rocm6_0-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: manywheel-py3_12-rocm6_0-build
runs-on: linux.rocm.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.0
GPU_ARCH_VERSION: 6.0
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main
DESIRED_PYTHON: "3.12"
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: manywheel-py3_12-rocm6_0
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: pytorch/manylinux-builder:rocm6.0-main
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
manywheel-py3_12-rocm6_0-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: manywheel-py3_12-rocm6_0-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.0
GPU_ARCH_VERSION: 6.0
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-rocm6_0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
manywheel-py3_12-rocm6_1-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-rocm6_1
build_environment: linux-binary-manywheel
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-rocm6_1-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: manywheel-py3_12-rocm6_1-build
runs-on: linux.rocm.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
DESIRED_PYTHON: "3.12"
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: manywheel-py3_12-rocm6_1
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: pytorch/manylinux-builder:rocm6.1-main
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
manywheel-py3_12-rocm6_1-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: manywheel-py3_12-rocm6_1-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-rocm6_1
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml

View File

@ -28,7 +28,8 @@ jobs:
cuda-arch-list: '8.6'
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
{ config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
@ -95,7 +96,8 @@ jobs:
cuda-arch-list: '8.6'
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
]}
linux-focal-cuda12_4-py3_12-gcc9-inductor-test:

View File

@ -48,7 +48,8 @@ jobs:
cuda-arch-list: '8.6'
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
{ config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
@ -90,7 +91,8 @@ jobs:
cuda-arch-list: '8.6'
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
]}
linux-focal-cuda12_1-py3_12-gcc9-inductor-test:

View File

@ -19,10 +19,10 @@ jobs:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
timeout: 120
runner: linux.2xlarge
runner: lf.linux.2xlarge
docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter
# NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout
# to run git rev-parse HEAD~:.ci/docker when a new image is needed
# to run git rev-parse HEAD~:.ci/docker when a new image is needed.
fetch-depth: 0
submodules: true
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
@ -35,7 +35,7 @@ jobs:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
timeout: 120
runner: linux.2xlarge
runner: lf.linux.2xlarge
docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter
# NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout
# to run git rev-parse HEAD~:.ci/docker when a new image is needed
@ -49,7 +49,7 @@ jobs:
quick-checks:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
runner: linux.2xlarge
runner: lf.linux.2xlarge
docker-image: pytorch-linux-focal-linter
fetch-depth: 0
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}

View File

@ -73,7 +73,6 @@ jobs:
{ config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" },

View File

@ -487,3 +487,31 @@ jobs:
build-environment: linux-jammy-py3-clang12-executorch
docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build:
name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build
uses: ./.github/workflows/_linux-build-label.yml
with:
use_split_build: true
build-environment: linux-focal-cuda12.1-py3.10-gcc9
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
]}
linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build-test:
name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build
- target-determination
with:
timeout-minutes: 360
build-environment: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build
docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build.outputs.test-matrix }}

View File

@ -97,7 +97,8 @@ jobs:
docker-image-name: pytorch-linux-focal-py3.8-clang10
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
]}
linux-focal-py3_8-clang10-test:
@ -119,7 +120,8 @@ jobs:
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" },
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" },
]}
linux-focal-rocm6_1-py3_8-test:

View File

@ -1390,169 +1390,6 @@ exclude_patterns = [
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
'torch/distributed/__init__.py',
'torch/distributed/_composable_state.py',
'torch/distributed/_shard/__init__.py',
'torch/distributed/_shard/_utils.py',
'torch/distributed/_shard/api.py',
'torch/distributed/_shard/checkpoint/__init__.py',
'torch/distributed/_shard/common_op_utils.py',
'torch/distributed/_shard/metadata.py',
'torch/distributed/_shard/op_registry_utils.py',
'torch/distributed/_shard/sharded_optim/__init__.py',
'torch/distributed/_shard/sharded_optim/api.py',
'torch/distributed/_shard/sharded_tensor/__init__.py',
'torch/distributed/_shard/sharded_tensor/_ops/__init__.py',
'torch/distributed/_shard/sharded_tensor/_ops/_common.py',
'torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py',
'torch/distributed/_shard/sharded_tensor/_ops/init.py',
'torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py',
'torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py',
'torch/distributed/_shard/sharded_tensor/api.py',
'torch/distributed/_shard/sharded_tensor/logger.py',
'torch/distributed/_shard/sharded_tensor/logging_handlers.py',
'torch/distributed/_shard/sharded_tensor/metadata.py',
'torch/distributed/_shard/sharded_tensor/reshard.py',
'torch/distributed/_shard/sharded_tensor/shard.py',
'torch/distributed/_shard/sharded_tensor/utils.py',
'torch/distributed/_shard/sharder.py',
'torch/distributed/_shard/sharding_plan/__init__.py',
'torch/distributed/_shard/sharding_plan/api.py',
'torch/distributed/_shard/sharding_spec/__init__.py',
'torch/distributed/_shard/sharding_spec/_internals.py',
'torch/distributed/_shard/sharding_spec/api.py',
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py',
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py',
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py',
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py',
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py',
'torch/distributed/_sharded_tensor/__init__.py',
'torch/distributed/_sharding_spec/__init__.py',
'torch/distributed/_tools/__init__.py',
'torch/distributed/_tools/memory_tracker.py',
'torch/distributed/algorithms/__init__.py',
'torch/distributed/algorithms/_checkpoint/__init__.py',
'torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py',
'torch/distributed/algorithms/_comm_hooks/__init__.py',
'torch/distributed/algorithms/_comm_hooks/default_hooks.py',
'torch/distributed/algorithms/_optimizer_overlap/__init__.py',
'torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py',
'torch/distributed/algorithms/_quantization/__init__.py',
'torch/distributed/algorithms/_quantization/quantization.py',
'torch/distributed/algorithms/ddp_comm_hooks/__init__.py',
'torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py',
'torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py',
'torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py',
'torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py',
'torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py',
'torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py',
'torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py',
'torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py',
'torch/distributed/algorithms/join.py',
'torch/distributed/algorithms/model_averaging/__init__.py',
'torch/distributed/algorithms/model_averaging/averagers.py',
'torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py',
'torch/distributed/algorithms/model_averaging/utils.py',
'torch/distributed/argparse_util.py',
'torch/distributed/autograd/__init__.py',
'torch/distributed/benchmarks/benchmark_ddp_rpc.py',
'torch/distributed/c10d_logger.py',
'torch/distributed/collective_utils.py',
'torch/distributed/constants.py',
'torch/distributed/distributed_c10d.py',
'torch/distributed/elastic/__init__.py',
'torch/distributed/elastic/agent/__init__.py',
'torch/distributed/elastic/agent/server/__init__.py',
'torch/distributed/elastic/agent/server/api.py',
'torch/distributed/elastic/agent/server/local_elastic_agent.py',
'torch/distributed/elastic/events/__init__.py',
'torch/distributed/elastic/events/api.py',
'torch/distributed/elastic/events/handlers.py',
'torch/distributed/elastic/metrics/__init__.py',
'torch/distributed/elastic/metrics/api.py',
'torch/distributed/elastic/multiprocessing/__init__.py',
'torch/distributed/elastic/multiprocessing/api.py',
'torch/distributed/elastic/multiprocessing/errors/__init__.py',
'torch/distributed/elastic/multiprocessing/errors/error_handler.py',
'torch/distributed/elastic/multiprocessing/errors/handlers.py',
'torch/distributed/elastic/multiprocessing/redirects.py',
'torch/distributed/elastic/multiprocessing/tail_log.py',
'torch/distributed/elastic/rendezvous/__init__.py',
'torch/distributed/elastic/rendezvous/api.py',
'torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py',
'torch/distributed/elastic/rendezvous/dynamic_rendezvous.py',
'torch/distributed/elastic/rendezvous/etcd_rendezvous.py',
'torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py',
'torch/distributed/elastic/rendezvous/etcd_server.py',
'torch/distributed/elastic/rendezvous/etcd_store.py',
'torch/distributed/elastic/rendezvous/registry.py',
'torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py',
'torch/distributed/elastic/rendezvous/utils.py',
'torch/distributed/elastic/timer/__init__.py',
'torch/distributed/elastic/timer/api.py',
'torch/distributed/elastic/timer/file_based_local_timer.py',
'torch/distributed/elastic/timer/local_timer.py',
'torch/distributed/elastic/utils/__init__.py',
'torch/distributed/elastic/utils/api.py',
'torch/distributed/elastic/utils/data/__init__.py',
'torch/distributed/elastic/utils/data/cycling_iterator.py',
'torch/distributed/elastic/utils/data/elastic_distributed_sampler.py',
'torch/distributed/elastic/utils/distributed.py',
'torch/distributed/elastic/utils/log_level.py',
'torch/distributed/elastic/utils/logging.py',
'torch/distributed/elastic/utils/store.py',
'torch/distributed/examples/memory_tracker_example.py',
'torch/distributed/launch.py',
'torch/distributed/launcher/__init__.py',
'torch/distributed/launcher/api.py',
'torch/distributed/logging_handlers.py',
'torch/distributed/nn/__init__.py',
'torch/distributed/nn/api/__init__.py',
'torch/distributed/nn/api/remote_module.py',
'torch/distributed/nn/functional.py',
'torch/distributed/nn/jit/__init__.py',
'torch/distributed/nn/jit/instantiator.py',
'torch/distributed/nn/jit/templates/__init__.py',
'torch/distributed/nn/jit/templates/remote_module_template.py',
'torch/distributed/optim/__init__.py',
'torch/distributed/optim/apply_optimizer_in_backward.py',
'torch/distributed/optim/functional_adadelta.py',
'torch/distributed/optim/functional_adagrad.py',
'torch/distributed/optim/functional_adam.py',
'torch/distributed/optim/functional_adamax.py',
'torch/distributed/optim/functional_adamw.py',
'torch/distributed/optim/functional_rmsprop.py',
'torch/distributed/optim/functional_rprop.py',
'torch/distributed/optim/functional_sgd.py',
'torch/distributed/optim/named_optimizer.py',
'torch/distributed/optim/optimizer.py',
'torch/distributed/optim/post_localSGD_optimizer.py',
'torch/distributed/optim/utils.py',
'torch/distributed/optim/zero_redundancy_optimizer.py',
'torch/distributed/remote_device.py',
'torch/distributed/rendezvous.py',
'torch/distributed/rpc/__init__.py',
'torch/distributed/rpc/_testing/__init__.py',
'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py',
'torch/distributed/rpc/_utils.py',
'torch/distributed/rpc/api.py',
'torch/distributed/rpc/backend_registry.py',
'torch/distributed/rpc/constants.py',
'torch/distributed/rpc/functions.py',
'torch/distributed/rpc/internal.py',
'torch/distributed/rpc/options.py',
'torch/distributed/rpc/rref_proxy.py',
'torch/distributed/rpc/server_process_global_profiler.py',
'torch/distributed/run.py',
'torch/distributed/tensor/__init__.py',
'torch/distributed/tensor/parallel/__init__.py',
'torch/distributed/tensor/parallel/_utils.py',
'torch/distributed/tensor/parallel/_view_with_dim_change.py',
'torch/distributed/tensor/parallel/api.py',
'torch/distributed/tensor/parallel/fsdp.py',
'torch/distributed/tensor/parallel/input_reshard.py',
'torch/distributed/tensor/parallel/multihead_attention_tp.py',
'torch/distributed/tensor/parallel/style.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/futures/__init__.py',


@ -290,7 +290,7 @@ After the final RC is created. The following tasks should be performed :
* Create validation issue for the release, see for example [Validations for 2.1.2 release](https://github.com/pytorch/pytorch/issues/114904) and perform required validations.
* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no prerformance regressions.
* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no performance regressions.
* Prepare and stage PyPI binaries for promotion. This is done with this script:
[`pytorch/builder:release/pypi/promote_pypi_to_staging.sh`](https://github.com/pytorch/builder/blob/main/release/pypi/promote_pypi_to_staging.sh)
@ -429,12 +429,12 @@ need to support these particular versions of software.
## Operating Systems
Supported OS flavors are summarized in the table below:
| Operating System family | Architectrue | Notes |
| Operating System family | Architecture | Notes |
| --- | --- | --- |
| Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. |
| MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). |
| MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 |
| Windows | x86_64 | Buils are compatible with Windows-10 or newer. |
| Windows | x86_64 | Builds are compatible with Windows-10 or newer. |
# Submitting Tutorials


@ -473,6 +473,7 @@ endif()
if(USE_CUDA AND NOT USE_ROCM)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}


@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base,
return Tensor();
}
Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx, const c10::optional<Tensor>& min_seqlen, const c10::optional<Tensor>& max_seqlen) {
Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx) {
auto values = at::_nested_get_values(mutated_view);
if (inverse_return_mode != InverseReturnMode::NeverView) {
return values;
@ -317,12 +317,7 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const
auto lengths = at::_nested_get_lengths(base);
auto ragged_idx = at::_nested_get_ragged_idx(base);
auto dummy = at::_nested_get_jagged_dummy(base);
auto min_seqlen = at::_nested_get_min_seqlen(base);
auto max_seqlen = at::_nested_get_max_seqlen(base);
auto nt = at::_nested_view_from_jagged(
mutated_view, offsets, dummy, lengths, ragged_idx,
(min_seqlen.defined() ? c10::optional<Tensor>(min_seqlen) : c10::nullopt),
(max_seqlen.defined() ? c10::optional<Tensor>(max_seqlen) : c10::nullopt));
auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx);
if (inverse_return_mode != InverseReturnMode::NeverView) {
return nt;


@ -765,115 +765,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
const ElementType& operator[](int idx) const = delete;
ElementType& operator[](int idx) = delete;
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
return Vectorized<T>{_vec0 + other._vec0, _vec1 + other._vec1};
}
Vectorized<T> C10_ALWAYS_INLINE operator-(const Vectorized<T>& other) const {
return Vectorized<T>{_vec0 - other._vec0, _vec1 - other._vec1};
}
Vectorized<T> C10_ALWAYS_INLINE operator*(const Vectorized<T>& other) const {
return Vectorized<T>{_vec0 * other._vec0, _vec1 * other._vec1};
}
Vectorized<T> C10_ALWAYS_INLINE operator/(const Vectorized<T>& other) const {
return Vectorized<T>{_vec0 / other._vec0, _vec1 / other._vec1};
}
Vectorized<T> C10_ALWAYS_INLINE operator&(const Vectorized<T>& other) const {
return Vectorized<T>{
(vtype)(vecb0() & other.vecb0()), (vtype)(vecb1() & other.vecb1())};
}
Vectorized<T> C10_ALWAYS_INLINE operator|(const Vectorized<T>& other) const {
return Vectorized<T>{
(vtype)(vecb0() | other.vecb0()), (vtype)(vecb1() | other.vecb1())};
}
Vectorized<T> C10_ALWAYS_INLINE operator^(const Vectorized<T>& other) const {
return Vectorized<T>{
(vtype)(vecb0() ^ other.vecb0()), (vtype)(vecb1() ^ other.vecb1())};
}
Vectorized<T> C10_ALWAYS_INLINE operator<<(const Vectorized<T> &other) const {
constexpr ElementType max_shift = sizeof(ElementType) * CHAR_BIT;
ElementType a_array[Vectorized<T>::size()];
ElementType b_array[Vectorized<T>::size()];
ElementType c_array[Vectorized<T>::size()];
store(a_array);
other.store(b_array);
for (int i = 0; i != Vectorized<T>::size(); i++) {
T shift = b_array[i];
if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
c_array[i] = 0;
} else {
c_array[i] = static_cast<std::make_unsigned_t<T>>(a_array[i]) << shift;
}
}
return loadu(c_array);
}
Vectorized<T> C10_ALWAYS_INLINE operator>>(const Vectorized<T> &other) const {
// right shift value to retain sign bit for signed and no bits for unsigned
constexpr ElementType max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
ElementType a_array[Vectorized<T>::size()];
ElementType b_array[Vectorized<T>::size()];
ElementType c_array[Vectorized<T>::size()];
store(a_array);
other.store(b_array);
for (int i = 0; i != Vectorized<T>::size(); i++) {
T shift = b_array[i];
if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
c_array[i] = a_array[i] >> max_shift;
} else {
c_array[i] = a_array[i] >> shift;
}
}
return loadu(c_array);
}
Vectorized<T> _not() const {
return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())};
}
Vectorized<T> C10_ALWAYS_INLINE operator==(const Vectorized<T>& other) const {
return Vectorized<T>{
vec_cmpeq(_vec0, other._vec0), vec_cmpeq(_vec1, other._vec1)};
}
Vectorized<T> C10_ALWAYS_INLINE operator!=(const Vectorized<T>& other) const {
return Vectorized<T>{
vec_cmpeq(_vec0, other._vec0), vec_cmpeq(_vec1, other._vec1)}
._not();
}
Vectorized<T> C10_ALWAYS_INLINE operator>(const Vectorized<T>& other) const {
return Vectorized<T>{
vec_cmpgt(_vec0, other._vec0), vec_cmpgt(_vec1, other._vec1)};
}
Vectorized<T> C10_ALWAYS_INLINE operator>=(const Vectorized<T>& other) const {
return Vectorized<T>{
vec_cmpge(_vec0, other._vec0), vec_cmpge(_vec1, other._vec1)};
}
Vectorized<T> C10_ALWAYS_INLINE operator<(const Vectorized<T>& other) const {
return Vectorized<T>{
vec_cmplt(_vec0, other._vec0), vec_cmplt(_vec1, other._vec1)};
}
Vectorized<T> C10_ALWAYS_INLINE operator<=(const Vectorized<T>& other) const {
return Vectorized<T>{
vec_cmple(_vec0, other._vec0), vec_cmple(_vec1, other._vec1)};
}
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
return (*this == other) & Vectorized<T>((T)1.0);
}
@ -1410,30 +1305,153 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
}
};
template <>
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
return a._not();
}
#define ZVECTOR_OPERATORS(typex) \
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec0() + b.vec0(), a.vec1() + b.vec1()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec0() - b.vec0(), a.vec1() - b.vec1()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec0() * b.vec0(), a.vec1() * b.vec1()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
(Vectorized<typex>::vtype)(a.vecb0() & b.vecb0()), \
(Vectorized<typex>::vtype)(a.vecb1() & b.vecb1())}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
(Vectorized<typex>::vtype)(a.vecb0() | b.vecb0()), \
(Vectorized<typex>::vtype)(a.vecb1() | b.vecb1())}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
(Vectorized<typex>::vtype)(a.vecb0() ^ b.vecb0()), \
(Vectorized<typex>::vtype)(a.vecb1() ^ b.vecb1())}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())} \
._not(); \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
vec_cmpgt(a.vec0(), b.vec0()), vec_cmpgt(a.vec1(), b.vec1())}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
vec_cmpge(a.vec0(), b.vec0()), vec_cmpge(a.vec1(), b.vec1())}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
vec_cmplt(a.vec0(), b.vec0()), vec_cmplt(a.vec1(), b.vec1())}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{ \
vec_cmple(a.vec0(), b.vec0()), vec_cmple(a.vec1(), b.vec1())}; \
}
template <>
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
return a._not();
}
ZVECTOR_OPERATORS(float)
ZVECTOR_OPERATORS(double)
ZVECTOR_OPERATORS(int8_t)
ZVECTOR_OPERATORS(uint8_t)
ZVECTOR_OPERATORS(uint16_t)
ZVECTOR_OPERATORS(int16_t)
ZVECTOR_OPERATORS(int32_t)
ZVECTOR_OPERATORS(int64_t)
template <>
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
return a._not();
}
#undef ZVECTOR_OPERATORS
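The macro above stamps out the element-wise arithmetic, bitwise and comparison operators as free-function specializations, one set per supported element type, instead of keeping them as member functions. A minimal standalone sketch of the same pattern, using a toy Vec<T> type and hypothetical names (not part of this file):

template <typename T>
struct Vec {
  T lo, hi;                      // stand-ins for the two underlying vector registers
  T vec0() const { return lo; }
  T vec1() const { return hi; }
};

#define TOY_VEC_OPERATORS(typex)                                          \
  inline Vec<typex> operator+(const Vec<typex>& a, const Vec<typex>& b) { \
    return Vec<typex>{a.vec0() + b.vec0(), a.vec1() + b.vec1()};          \
  }                                                                       \
  inline Vec<typex> operator*(const Vec<typex>& a, const Vec<typex>& b) { \
    return Vec<typex>{a.vec0() * b.vec0(), a.vec1() * b.vec1()};          \
  }

TOY_VEC_OPERATORS(float)
TOY_VEC_OPERATORS(int32_t)
#undef TOY_VEC_OPERATORS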
template <>
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
return a._not();
}
#define ZVECTOR_OPERATORS(typex) \
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator<<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
constexpr Vectorized<typex>::ElementType max_shift \
= sizeof(Vectorized<typex>::ElementType) * CHAR_BIT; \
\
Vectorized<typex>::ElementType a_array[Vectorized<typex>::size()]; \
Vectorized<typex>::ElementType b_array[Vectorized<typex>::size()]; \
Vectorized<typex>::ElementType c_array[Vectorized<typex>::size()]; \
\
a.store(a_array); \
b.store(b_array); \
\
for (int i = 0; i != Vectorized<typex>::size(); i++) { \
typex shift = b_array[i]; \
if ((static_cast<std::make_signed_t<typex>>(shift) < 0) || (shift >= max_shift)) { \
c_array[i] = 0; \
} else { \
c_array[i] = static_cast<std::make_unsigned_t<typex>>(a_array[i]) << shift; \
} \
} \
\
return Vectorized<typex>::loadu(c_array); \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator>>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
/* right shift value to retain sign bit for signed and no bits for unsigned */ \
constexpr Vectorized<typex>::ElementType max_shift \
= sizeof(typex) * CHAR_BIT - std::is_signed_v<typex>; \
\
Vectorized<typex>::ElementType a_array[Vectorized<typex>::size()]; \
Vectorized<typex>::ElementType b_array[Vectorized<typex>::size()]; \
Vectorized<typex>::ElementType c_array[Vectorized<typex>::size()]; \
\
a.store(a_array); \
b.store(b_array); \
\
for (int i = 0; i != Vectorized<typex>::size(); i++) { \
typex shift = b_array[i]; \
if ((static_cast<std::make_signed_t<typex>>(shift) < 0) || (shift >= max_shift)) { \
c_array[i] = a_array[i] >> max_shift; \
} else { \
c_array[i] = a_array[i] >> shift; \
} \
} \
\
return Vectorized<typex>::loadu(c_array); \
} \
\
template <> \
inline Vectorized<typex> operator~(const Vectorized<typex>& a) { \
return a._not(); \
}
template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
return a._not();
}
ZVECTOR_OPERATORS(int8_t)
ZVECTOR_OPERATORS(uint8_t)
ZVECTOR_OPERATORS(uint16_t)
ZVECTOR_OPERATORS(int16_t)
ZVECTOR_OPERATORS(int32_t)
ZVECTOR_OPERATORS(int64_t)
#undef ZVECTOR_OPERATORS
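The shift operators defined above guard the shift amount: a count that is negative or at least the element's bit width would be undefined behavior for a plain C++ shift, so operator<< returns 0 in that case while operator>> fills with the sign bit for signed types and with zeros for unsigned ones. A standalone scalar sketch of the same guard (toy helpers, not the vector code):

#include <climits>
#include <cstdint>
#include <type_traits>

template <typename T>
T guarded_shl(T a, T b) {
  constexpr T max_shift = sizeof(T) * CHAR_BIT;
  if (static_cast<std::make_signed_t<T>>(b) < 0 || b >= max_shift) {
    return 0;
  }
  return static_cast<T>(static_cast<std::make_unsigned_t<T>>(a) << b);
}

template <typename T>
T guarded_shr(T a, T b) {
  // Keep the sign bit for signed types, no bits for unsigned ones.
  constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
  if (static_cast<std::make_signed_t<T>>(b) < 0 || b >= max_shift) {
    return a >> max_shift;
  }
  return a >> b;
}

// e.g. guarded_shl<int32_t>(1, 40) == 0 and guarded_shr<int32_t>(-8, 64) == -1.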
#define DEFINE_MAXMIN_FUNCS(operand_type) \
template <> \
@ -1976,55 +1994,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
return Vectorized<U>{ret};
}
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
return Vectorized<T>{_vec + other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator-(const Vectorized<T>& other) const {
return Vectorized<T>{_vec - other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator*(const Vectorized<T>& other) const {
return Vectorized<T>{_vec * other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator/(const Vectorized<T>& other) const {
return Vectorized<T>{_vec / other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator&(const Vectorized<T>& other) const {
return Vectorized<T>{_vec & other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator|(const Vectorized<T>& other) const {
return Vectorized<T>{_vec | other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator^(const Vectorized<T>& other) const {
return Vectorized<T>{_vec ^ other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator==(const Vectorized<T>& other) const {
return Vectorized<T>{_vec == other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator!=(const Vectorized<T>& other) const {
return Vectorized<T>{_vec != other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator>(const Vectorized<T>& other) const {
return Vectorized<T>{_vec > other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator>=(const Vectorized<T>& other) const {
return Vectorized<T>{_vec >= other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator<(const Vectorized<T>& other) const {
return Vectorized<T>{_vec < other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator<=(const Vectorized<T>& other) const {
return Vectorized<T>{_vec <= other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
return Vectorized<T>{_vec.eq(other._vec)};
}
@ -2061,6 +2030,72 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
}
};
#define ZVECTOR_OPERATORS(typex) \
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() + b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() - b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() * b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() / b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() & b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() | b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() ^ b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() == b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() != b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() > b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() >= b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() < b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() <= b.vec()}; \
}
ZVECTOR_OPERATORS(c10::qint32)
ZVECTOR_OPERATORS(c10::qint8)
ZVECTOR_OPERATORS(c10::quint8)
#undef ZVECTOR_OPERATORS
DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8)
DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8)
DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32)
@ -2364,35 +2399,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
return Vectorized<T>{a00, a01};
}
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
return Vectorized<T>{_vec + other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator-(const Vectorized<T>& other) const {
return Vectorized<T>{_vec - other._vec};
}
Vectorized<T> inline operator*(const Vectorized<T>& b) const {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
vinner_type bv = b.vec();
#if !defined(ZVECTOR_SIMULATE_X86_MULT)
// this is more z arch friendly than simulating horizontal from x86
vinner_type vi = bv.mergeo();
vinner_type vr = bv.mergee();
vi = vi ^ rsign_mask<underline_type>();
vinner_type ret = _vec * vr;
vinner_type vx_swapped = _vec.swapped();
ret = fmadd(vx_swapped, vi, ret);
#else
vinner_type ac_bd = _vec * b;
vinner_type d_c = bv.swapped();
d_c = d_c ^ isign_mask<underline_type>();
vinner_type ad_bc = _vec * d_c;
vinner_type ret = vinner_type::horizontal_sub_perm(ac_bd, ad_bc);
#endif
return Vectorized<T>{ret};
}
template <
typename U = T,
std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
@ -2418,29 +2424,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
return { v0, v1 };
}
Vectorized<T> inline operator/(const Vectorized<T>& b) const {
// Unfortunately, this breaks some tests
// Implement it like it's done for avx2
auto fabs_cd = b.vec().abs(); // |c| |d|
auto fabs_dc = fabs_cd.swapped(); // |d| |c|
auto scale = vinner_type {1.0} / maximum(fabs_cd, fabs_dc); // 1/sc 1/sc
auto a2 = vec() * scale; // a/sc b/sc
auto b2 = b.vec() * scale; // c/sc d/sc
auto acbd2 = a2 * b2; // ac/sc^2 bd/sc^2
auto dc2 = b2.swapped(); // d/sc c/sc
dc2 = Vectorized<T>::real_neg(dc2); // -d/|c,d| c/sc
auto adbc2 = a2 * dc2; // -ad/sc^2 bc/sc^2
auto sum1 = acbd2 + acbd2.swapped(); // (ac+bd)/sc^2 (ac+bd)/sc^2
auto sum2 = adbc2 + adbc2.swapped(); // (bc-ad)/sc^2 (bc-ad)/sc^2
auto res2 = vinner_type::mergee(sum1, sum2); // (ac+bd)/sc^2 (bc-ad)/sc^2
// get the denominator
auto denom2 = Vectorized<T>{b2}.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
res2 = res2 / denom2;
return Vectorized<T>{ res2 };
}
Vectorized<T> angle2_() const {
auto b_a = _vec.swapped(); // b a
return Vectorized<T>{_vec.atan2(b_a).swapped()};
@ -2528,25 +2511,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
return Vectorized<T>{_vec.trunc()};
}
Vectorized<T> C10_ALWAYS_INLINE operator&(const Vectorized<T>& other) const {
return Vectorized<T>{_vec & other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator|(const Vectorized<T>& other) const {
return Vectorized<T>{_vec | other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator^(const Vectorized<T>& other) const {
return Vectorized<T>{_vec ^ other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator==(const Vectorized<T>& other) const {
return Vectorized<T>{_vec == other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE operator!=(const Vectorized<T>& other) const {
return Vectorized<T>{_vec != other._vec};
}
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
auto eq = _vec.eq(other._vec); // compares real and imag individually
// If both real numbers and imag numbers are equal, then the complex numbers are equal
@ -2648,22 +2612,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
return sqrt().reciprocal();
}
Vectorized<T> operator<(const Vectorized<T>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<T> operator<=(const Vectorized<T>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<T> operator>(const Vectorized<T>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<T> operator>=(const Vectorized<T>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<T> lt(const Vectorized<T>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
@ -2681,6 +2629,101 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
}
};
#define ZVECTOR_OPERATORS(typex) \
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() + b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() - b.vec()}; \
} \
\
template <> \
Vectorized<typex> inline operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
/* (a + bi) * (c + di) = (ac - bd) + (ad + bc)i */ \
Vectorized<typex>::vinner_type bv = b.vec(); \
\
/* this is more z arch friendly than simulating horizontal from x86 */ \
Vectorized<typex>::vinner_type vi = bv.mergeo(); \
Vectorized<typex>::vinner_type vr = bv.mergee(); \
vi = vi ^ Vectorized<typex>::vinner_type(rsign_mask<Vectorized<typex>::underline_type>()); \
Vectorized<typex>::vinner_type ret = a.vec() * vr; \
Vectorized<typex>::vinner_type vx_swapped = a.vec().swapped(); \
ret = fmadd(vx_swapped, vi, ret); \
\
return Vectorized<typex>{ret}; \
} \
\
template <> \
Vectorized<typex> inline operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
/* Unfortunately, this breaks some tests */ \
/* Implement it like it's done for avx2 */ \
auto fabs_cd = b.vec().abs(); /* |c| |d| */ \
auto fabs_dc = fabs_cd.swapped(); /* |d| |c| */ \
auto scale = Vectorized<typex>::vinner_type {1.0} / maximum(fabs_cd, fabs_dc); /* 1/sc 1/sc */ \
auto a2 = a.vec() * scale; /* a/sc b/sc */ \
auto b2 = b.vec() * scale; /* c/sc d/sc */ \
auto acbd2 = a2 * b2; /* ac/sc^2 bd/sc^2 */ \
\
auto dc2 = b2.swapped(); /* d/sc c/sc */ \
dc2 = Vectorized<typex>::real_neg(dc2); /* -d/|c,d| c/sc */ \
auto adbc2 = a2 * dc2; /* -ad/sc^2 bc/sc^2 */ \
auto sum1 = acbd2 + acbd2.swapped(); /* (ac+bd)/sc^2 (ac+bd)/sc^2 */ \
auto sum2 = adbc2 + adbc2.swapped(); /* (bc-ad)/sc^2 (bc-ad)/sc^2 */ \
auto res2 = Vectorized<typex>::vinner_type::mergee(sum1, sum2); /* (ac+bd)/sc^2 (bc-ad)/sc^2 */ \
\
/* get the denominator */ \
Vectorized<typex>::vinner_type denom2 = Vectorized<typex>{b2}.abs_2_(); /* (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 */ \
res2 = res2 / denom2; \
return Vectorized<typex>{ res2 }; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() & b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() | b.vec()}; \
} \
\
template <> \
Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() ^ b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() == b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
return Vectorized<typex>{a.vec() != b.vec()}; \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
TORCH_CHECK(false, "not supported for complex numbers"); \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
TORCH_CHECK(false, "not supported for complex numbers"); \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
TORCH_CHECK(false, "not supported for complex numbers"); \
} \
\
Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
TORCH_CHECK(false, "not supported for complex numbers"); \
}
ZVECTOR_OPERATORS(c10::complex<float>)
ZVECTOR_OPERATORS(c10::complex<double>)
#undef ZVECTOR_OPERATORS
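The complex operators above follow the usual formulas. Multiplication uses

(a + bi) * (c + di) = (ac - bd) + (ad + bc)i

and division uses the conjugate form

(a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2),

with every operand first multiplied by 1 / max(|c|, |d|) so the intermediate products stay in range, which is exactly what the step-by-step comments inside the macro compute before dividing by the scaled denominator.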
template <typename T, std::enable_if_t<(sizeof(T) == 8), int> = 0>
std::pair<Vectorized<T>, Vectorized<T>> inline inner_interleave2(
const Vectorized<T>& a,


@ -334,7 +334,13 @@ static inline __device__ void gpuAtomicAddNoReturn(double *address, double val)
/* Special case fp32 atomic. */
#if defined(USE_ROCM)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
#if defined(__gfx908__)
atomicAddNoRet(address, val);
#else
(void)unsafeAtomicAdd(address, val);
#endif
}
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
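A minimal sketch of a caller (hypothetical kernel and buffers; only gpuAtomicAddNoReturn comes from the helper above). The point of the no-return variant is that the backend can emit a fire-and-forget atomic when the old value is not needed:

__global__ void accumulate_kernel(const float* src, float* dst, int64_t n) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    // dst is assumed to hold at least 128 float bins.
    gpuAtomicAddNoReturn(dst + (i % 128), src[i]);
  }
}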


@ -152,9 +152,6 @@ void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
* Unregisters a CUDA graph from the RNG state.
*/
void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
// Ensures that the RNG state is not currently being captured.
at::cuda::assertNotCapturing(
"Cannot unregister the state during capturing stage.");
// Verify the graph was previously registered.
TORCH_CHECK(
registered_graphs_.find(graph) != registered_graphs_.end(),


@ -170,6 +170,43 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction);
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
CUresult CUDAAPI
cuTensorMapEncodeTiled(
CUtensorMap* tensorMap,
CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank,
void* globalAddress,
const cuuint64_t* globalDim,
const cuuint64_t* globalStrides,
const cuuint32_t* boxDim,
const cuuint32_t* elementStrides,
CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill) {
auto fn = reinterpret_cast<decltype(&cuTensorMapEncodeTiled)>(
getCUDALibrary().sym(__func__));
if (!fn)
throw std::runtime_error("Can't get cuTensorMapEncodeTiled");
lazyNVRTC.cuTensorMapEncodeTiled = fn;
return fn(
tensorMap,
tensorDataType,
tensorRank,
globalAddress,
globalDim,
globalStrides,
boxDim,
elementStrides,
interleave,
swizzle,
l2Promotion,
oobFill);
}
#endif
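The stub above resolves the real driver entry point from the CUDA library on first use, caches the pointer in the lazyNVRTC table, and forwards the call. A stripped-down sketch of that lazy-resolution pattern (all names below are toy stand-ins, not real APIs):

#include <stdexcept>

static int real_add(int a, int b) { return a + b; }
// Stand-in for getCUDALibrary().sym(name): look a symbol up in a shared library.
static void* lookup_symbol(const char* /*name*/) {
  return reinterpret_cast<void*>(&real_add);
}

using add_fn_t = int (*)(int, int);
static add_fn_t cached_add = nullptr;  // filled in on first use

static int lazy_add(int a, int b) {
  if (cached_add == nullptr) {
    auto fn = reinterpret_cast<add_fn_t>(lookup_symbol("real_add"));
    if (!fn) {
      throw std::runtime_error("Can't resolve real_add");
    }
    cached_add = fn;  // later calls skip the lookup
  }
  return cached_add(a, b);
}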
// Irregularly shaped functions
CUresult CUDAAPI cuLaunchKernel(CUfunction f,
unsigned int gridDimX,


@ -59,16 +59,25 @@ namespace at { namespace cuda {
_(cuLinkAddData) \
_(cuLinkComplete) \
_(cuFuncSetAttribute) \
_(cuFuncGetAttribute)
_(cuFuncGetAttribute) \
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
#define AT_FORALL_NVRTC_EXTENDED(_) \
AT_FORALL_NVRTC_BASE(_) \
_(cuTensorMapEncodeTiled)
#else
#define AT_FORALL_NVRTC_EXTENDED(_) \
AT_FORALL_NVRTC_BASE(_)
#endif
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
#define AT_FORALL_NVRTC(_) \
AT_FORALL_NVRTC_BASE(_) \
AT_FORALL_NVRTC_EXTENDED(_) \
_(nvrtcGetCUBINSize) \
_(nvrtcGetCUBIN)
#else
#define AT_FORALL_NVRTC(_) \
AT_FORALL_NVRTC_BASE(_)
AT_FORALL_NVRTC_EXTENDED(_)
#endif
#else

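The AT_FORALL_* lists above are X-macros: each `_(name)` entry is expanded by whatever one-argument macro the caller passes in, so a single list can drive declarations, stubs and registrations. A toy sketch of the mechanism (hypothetical names):

// The list names every function exactly once.
#define FORALL_TOY_FNS(_) \
  _(foo)                  \
  _(bar)

// One expansion declares function-pointer members; other expansions of the
// same list could print the names, register stubs, etc.
struct ToyTable {
#define DECLARE_MEMBER(name) int (*name)(int);
  FORALL_TOY_FNS(DECLARE_MEMBER)
#undef DECLARE_MEMBER
};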

@ -1,3 +1,7 @@
#include <cstdint>
#include <c10/util/Exception.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
@ -10,6 +14,7 @@
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -819,24 +824,106 @@ static bool _scaled_mm_allowed_device() {
#endif
}
namespace{
enum class ScalingType {
TensorWise,
RowWise,
Error
};
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale_a.numel() == 1 && scale_b.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
* - Returns RowWise.
*
* - Otherwise:
* - Returns Error.
*/
// Validates the scale tensors passed to scaled_mm
// and returns which type of scaling (i.e. which kernel) to use
ScalingType get_scaling_type(
const at::Tensor& scale_a,
const at::Tensor& scale_b,
int64_t dim_m,
int64_t dim_n) {
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"Both scale_a and scale_b must be float (fp32) tensors.");
// Check the singular scale case for per-tensor scaling
if (scale_a.numel() == 1 && scale_b.numel() == 1) {
return ScalingType::TensorWise;
} else if (scale_a.dim() == 1 && scale_a.size(0) == dim_m) {
// Check the per-row scaling case
#if !defined(USE_ROCM) && !defined(_MSC_VER) || \
(defined(USE_ROCM) && ROCM_VERSION >= 60000)
TORCH_CHECK(
scale_a.dim() == 1 && scale_b.dim() == 1,
"Both scale_a and scale_b must be 1-dimensional tensors");
TORCH_CHECK(
scale_b.size(0) == dim_n,
"For row-wise scaling, scale_b must have size ",
dim_n,
" but got ",
scale_b.size(0),
".");
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"Both scale_a and scale_b must be contiguous.");
return ScalingType::RowWise;
#else
TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
return ScalingType::Error;
#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) &&
// ROCM_VERSION >= 60000)
} else {
// Prettier Error Case messaging
TORCH_CHECK(
false,
"For row-wise scaling, scale_a must be size ",
dim_m,
" but got ",
scale_a.numel(),
" and scale_b must be size ",
dim_n,
" but got ",
scale_b.numel(),
".");
// Unreachable
return ScalingType::RowWise;
}
return ScalingType::Error;
}
} // namespace
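As a concrete illustration of the decision table above (hypothetical shapes: mat1 is 128 x 64 and mat2 is 64 x 256, so dim_m = 128 and dim_n = 256):

auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
auto s_a = at::ones({1}, opts);    // numel() == 1
auto s_b = at::ones({1}, opts);    // numel() == 1
// get_scaling_type(s_a, s_b, /*dim_m=*/128, /*dim_n=*/256) -> ScalingType::TensorWise

auto r_a = at::ones({128}, opts);  // one scale per row of mat1
auto r_b = at::ones({256}, opts);  // one scale per column of mat2
// get_scaling_type(r_a, r_b, 128, 256) -> ScalingType::RowWise
// (both must be fp32, 1-D and contiguous; other shapes trip the checks above)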
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
// Scales are only applicable when matrices are of Float8 type and are assumed to be equal to 1.0 by default.
// If the output matrix is a 16-bit or 32-bit type, scale_result is not applied and amax is not computed.
// Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
//
// - If 1-dimensional tensors are used, then scale_a should have size mat1.size(0)
//   and scale_b should have size mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type
// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `out`: a reference to the output tensor
// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace
Tensor&
_scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
@ -855,10 +942,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK((scale_a.numel() == 1 && scale_a.scalar_type() == kFloat),
"scale_a must be float scalar");
TORCH_CHECK((scale_b.numel() == 1 && scale_b.scalar_type() == kFloat),
"scale_b must be a float scalar");
// Check what type of scaling we are doing based on inputs
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
@ -899,11 +987,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
{scale_result_, "scale_result", 6}};
checkAllSameGPU(__func__, targs);
}
// Validation checks have passed lets resize the output to actual size
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
// We are doing row-wise scaling
if (scaling_choice == ScalingType::RowWise) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
use_fast_accum,
out);
return out;
}
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");


@ -0,0 +1,536 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
// Determine if the architecture supports rowwise scaled mm
// Currently failing on Windows with: https://github.com/NVIDIA/cutlass/issues/1571
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
#define BUILD_ROWWISE_FP8_KERNEL
#endif
#if defined(BUILD_ROWWISE_FP8_KERNEL)
// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
CUtensorMap* tensorMap,
CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank,
void* globalAddress,
const cuuint64_t* globalDim,
const cuuint64_t* globalStrides,
const cuuint32_t* boxDim,
const cuuint32_t* elementStrides,
CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill) {
return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
tensorMap,
tensorDataType,
tensorRank,
globalAddress,
globalDim,
globalStrides,
boxDim,
elementStrides,
interleave,
swizzle,
l2Promotion,
oobFill);
}
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>
// Rename the global function symbol
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
#include <cute/tensor.hpp>
#undef cuTensorMapEncodeTiled
// Set everything back to normal
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
namespace {
// Cutlass rowwise kernel
template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM,
bool USE_BIAS,
typename INPUT_DTYPE,
typename BIAS_DTYPE>
void f8f8bf16_rowwise_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor w_scale,
c10::optional<at::Tensor> bias,
at::Tensor out) {
int M = XQ.size(0);
int N = WQ.size(1);
int K = XQ.size(1);
TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(
WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) &&
WQ.stride(0) == 1);
// auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
using ElementInputA = INPUT_DTYPE;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(ElementInputA);
using ElementInputB = cutlass::float_e4m3_t;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(ElementInputB);
using ElementBias = BIAS_DTYPE;
using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput = 16 / sizeof(ElementOutput);
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using KernelSchedule = cutlass::gemm::collective::
KernelScheduleAuto; // Kernel to launch based on the default setting in
// the Collective Builder
// Implement rowwise scaling epilogue.
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
PONG ? 2 : 1,
TileShape,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
PONG ? 2 : 1,
TileShape,
ElementBias,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementComputeEpilogue, // First stage output type.
ElementComputeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
cute::conditional_t< // Second stage output type.
USE_BIAS,
ElementBias,
ElementOutput>,
ElementComputeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
using ComputeBias = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus,
ElementOutput, // Final (optional) stage output type.
ElementBias, // Final stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeBias =
cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
using EpilogueEVT =
cute::conditional_t<USE_BIAS, EVTComputeBias, EVTCompute1>;
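// The fused epilogue above evaluates, per output element,
//   out[i][j] = x_scale[i] * (w_scale[j] * acc[i][j])   (+ bias[j] when USE_BIAS),
// i.e. XScale is a length-M column vector broadcast across the N columns, while
// WScale and Bias are length-N row vectors broadcast down the M rows.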
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecialized,
EpilogueEVT>::CollectiveOp;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using FastDefaultSchedule =
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using FastPongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
using FastAccum =
cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
using MainLoopSchedule =
cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementInputB,
LayoutInputB,
AlignmentInputB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainLoopSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, N, 1));
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
stride_a,
reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b},
{{}, // Epilogue thread we populate below.
(ElementOutput*)out.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)out.data_ptr<at::BFloat16>(),
stride_output}};
if constexpr (USE_BIAS) {
arguments.epilogue.thread = {
{reinterpret_cast<ElementBias*>(bias.value().data_ptr())}, // bias
// compute_1
{
{reinterpret_cast<ElementComputeEpilogue*>(
x_scale.data_ptr())}, // x_scale
// compute_0
{
{reinterpret_cast<ElementComputeEpilogue*>(
w_scale.data_ptr())}, // w_scale
{}, // Accumulator
{} // Multiplies
},
{}, // Multiplies
},
{}, // Plus
};
} else {
arguments.epilogue.thread = {
{reinterpret_cast<ElementComputeEpilogue*>(
x_scale.data_ptr())}, // x_scale
// compute_0
{
{reinterpret_cast<ElementComputeEpilogue*>(
w_scale.data_ptr())}, // w_scale
{}, // Accumulator
{} // Multiplies
},
{}, // Multiplies
};
}
Gemm gemm;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// FP8 Rowwise Cutlass kernel dispatch.
enum class KernelMode { Small, Large, Default };
KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto M = XQ.size(0);
auto K = XQ.size(1);
auto N = WQ.size(0);
// Use a large kernel if at least two of the dimensions are large.
bool use_large_kernel =
((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) ||
(K >= 2048 && N >= 2048));
if (M <= 128 || N <= 128) {
return KernelMode::Small;
} else if (use_large_kernel) {
return KernelMode::Large;
} else {
return KernelMode::Default;
}
}
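// For example (hypothetical shapes): M = 64, K = 4096, N = 4096 -> Small (the M <= 128 check wins);
// M = 4096, K = 4096, N = 4096 -> Large; M = 512, K = 1024, N = 512 -> Default.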
template <typename InputDType, bool FastAccum, bool UseBias, typename BiasDType>
void dispatch_fp8_rowwise_kernel(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
c10::optional<at::Tensor> bias,
at::Tensor out) {
KernelMode kernel = get_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_rowwise_impl<
64,
128,
128,
2,
1,
1,
false,
FastAccum,
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_impl<
128,
128,
128,
2,
1,
1,
true,
FastAccum,
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return f8f8bf16_rowwise_impl<
128,
128,
128,
1,
2,
1,
false,
FastAccum,
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, out);
}
}
} // namespace
#endif // BUILD_ROWWISE_FP8_KERNEL
namespace at::cuda::detail {
void f8f8bf16_rowwise(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale, // FP32
at::Tensor w_scale, // FP32
c10::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out) {
#if defined(BUILD_ROWWISE_FP8_KERNEL)
// Check datatypes.
TORCH_CHECK(
x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
"Scale tensors must be float32.");
if (bias.has_value()) {
TORCH_CHECK(
bias.value().dtype() == at::kFloat ||
bias.value().dtype() == at::kBFloat16,
"Bias type must be bfloat16 or float32 if provided.");
}
// Extract problem size.
int M = XQ.size(0);
int N = WQ.size(1);
int K = XQ.size(1);
bool use_bias = bias.has_value();
bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16;
// Templatize based on input dtype.
bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.");
if (use_bias) {
if (bf16_bias) {
if (use_fast_accum) {
if (use_e5m2) {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e5m2_t,
true,
true,
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e4m3_t,
true,
true,
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
}
} else {
if (use_e5m2) {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e5m2_t,
false,
true,
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e4m3_t,
false,
true,
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
}
}
} else {
if (use_fast_accum) {
if (use_e5m2) {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e5m2_t,
true,
true,
float>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e4m3_t,
true,
true,
float>(XQ, WQ, x_scale, w_scale, bias, out);
}
} else {
if (use_e5m2) {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e5m2_t,
false,
true,
float>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e4m3_t,
false,
true,
float>(XQ, WQ, x_scale, w_scale, bias, out);
}
}
}
} else {
if (use_fast_accum) {
if (use_e5m2) {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e5m2_t,
true,
false,
float>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e4m3_t,
true,
false,
float>(XQ, WQ, x_scale, w_scale, bias, out);
}
} else {
if (use_e5m2) {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e5m2_t,
false,
false,
float>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return dispatch_fp8_rowwise_kernel<
cutlass::float_e4m3_t,
false,
false,
float>(XQ, WQ, x_scale, w_scale, bias, out);
}
}
}
#else // BUILD_ROWWISE_FP8_KERNEL
TORCH_CHECK(false, "Rowwise scaling is not currently supported on your device");
#endif
}
} // namespace at::cuda::detail


@ -0,0 +1,15 @@
#pragma once
#include <ATen/core/TensorBase.h>
#include <c10/util/Optional.h>
namespace at::cuda::detail {
TORCH_API void f8f8bf16_rowwise(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale, // FP32
at::Tensor w_scale, // FP32
c10::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out);
} // at::cuda::detail
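A minimal usage sketch of the entry point declared above (shapes and values are hypothetical; assumes an SM90 GPU, since the kernel above is built for cutlass::arch::Sm90):

int64_t M = 256, K = 512, N = 128;
auto opts = at::TensorOptions().device(at::kCUDA);
auto XQ = at::randn({M, K}, opts).to(at::kFloat8_e4m3fn);      // row-major FP8 activations
auto WQ = at::randn({N, K}, opts).to(at::kFloat8_e4m3fn).t();  // K x N, column-major, as the kernel expects
auto x_scale = at::rand({M}, opts.dtype(at::kFloat));          // one scale per row of XQ
auto w_scale = at::rand({N}, opts.dtype(at::kFloat));          // one scale per column of WQ
auto out = at::empty({M, N}, opts.dtype(at::kBFloat16));       // bf16 output, preallocated
at::cuda::detail::f8f8bf16_rowwise(XQ, WQ, x_scale, w_scale,
                                   /*bias=*/c10::nullopt,
                                   /*use_fast_accum=*/true, out);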


@ -14,6 +14,7 @@ using namespace at::cuda::detail;
// Kernel for fast unfold+copy on volumes
template <typename T>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void vol2col_kernel(
const int64_t n,
const T* data_vol,


@ -614,8 +614,6 @@ void add_projection_weights(
/*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
/*linLayerMat=*/&matrix_pointer));
#else
void* unused_pointer;
TensorDescriptor unused_desc;
TensorDescriptor lin_layer_mat_desc;
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
/*handle=*/handle,
@ -626,8 +624,8 @@ void add_projection_weights(
/*linLayerID=*/linear_id,
/*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
/*linLayerMat=*/&matrix_pointer,
unused_desc.mut_desc(),
&unused_pointer));
nullptr,
nullptr));
#endif
cudnnDataType_t data_type;
@ -735,8 +733,6 @@ get_parameters(
lin_layer_mat_desc.mut_desc(),
&matrix_pointer));
#else
void* unused_pointer = nullptr;
TensorDescriptor unused_desc;
TensorDescriptor lin_layer_mat_desc;
for (int stateless = 0; stateless < 100; stateless++) {
if (cudnn_method) { // matrix
@ -749,8 +745,8 @@ get_parameters(
linear_id,
lin_layer_mat_desc.mut_desc(),
&matrix_pointer,
unused_desc.mut_desc(),
&unused_pointer));
nullptr,
nullptr));
} else { // bias
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
handle,
@ -759,8 +755,8 @@ get_parameters(
weight_buf.numel() * weight_buf.element_size(),
weight_buf.data_ptr(),
linear_id,
unused_desc.mut_desc(),
&unused_pointer,
nullptr,
nullptr,
lin_layer_mat_desc.mut_desc(),
&matrix_pointer));
}
@ -922,8 +918,6 @@ std::vector<void*> get_expected_data_ptrs(
lin_layer_mat_desc.mut_desc(),
&matrix_pointer));
#else
void* unused_pointer = nullptr;
TensorDescriptor unused_desc;
TensorDescriptor lin_layer_mat_desc;
if (cudnn_method) { // matrix
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
@ -935,8 +929,8 @@ std::vector<void*> get_expected_data_ptrs(
linear_id,
lin_layer_mat_desc.mut_desc(),
&matrix_pointer,
unused_desc.mut_desc(),
&unused_pointer));
nullptr,
nullptr));
} else { // bias
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
handle,
@ -945,8 +939,8 @@ std::vector<void*> get_expected_data_ptrs(
weight_buf.numel() * weight_buf.element_size(),
weight_buf.data_ptr(),
linear_id,
unused_desc.mut_desc(),
&unused_pointer,
nullptr,
nullptr,
lin_layer_mat_desc.mut_desc(),
&matrix_pointer));
}
@ -972,8 +966,6 @@ std::vector<void*> get_expected_data_ptrs(
lin_layer_mat_desc.mut_desc(),
&matrix_pointer));
#else
void* unused_pointer;
TensorDescriptor unused_desc;
TensorDescriptor lin_layer_mat_desc;
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
@ -985,8 +977,8 @@ std::vector<void*> get_expected_data_ptrs(
linear_id,
lin_layer_mat_desc.mut_desc(),
&matrix_pointer,
unused_desc.mut_desc(),
&unused_pointer));
nullptr,
nullptr));
#endif
data_ptrs.push_back(matrix_pointer);
}


@ -421,17 +421,6 @@ TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_linear"), TORCH_FN(mkl_linear));
}
#else // AT_MKL_ENABLED
static Tensor mkl_linear(
const Tensor& self,
const Tensor& mkl_weight_t,
const Tensor& origin_weight_t,
const std::optional<Tensor>& bias_opt,
const int64_t prepack_batch_size) {
TORCH_CHECK(false, "mkl_linear: ATen not compiled with MKL support");
}
#endif // AT_MKL_ENABLED
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {


@ -336,25 +336,34 @@ inline bool is_dense_in_storage(const at::Tensor& t) {
class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {}
MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname);
return getLibraryPipelineState(getLibrary(), fname).first;
}
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname);
return getLibraryPipelineState(getLibrary(params), fname).first;
}
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).second;
}
id<MTLFunction> getMTLFunction(const std::string& fname, const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).second;
}
private:
id<MTLComputePipelineState> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
id<MTLLibrary> getLibrary();
id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
id<MTLLibrary> compileLibrary(const std::string& src);
std::string shaderSource;
unsigned nparams;
MTLCompileOptions* compile_options;
id<MTLLibrary> library = nil;
std::unordered_map<std::string, id<MTLLibrary>> libMap;
std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
};
static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {

View File

@ -656,31 +656,38 @@ id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::s
id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion2_3];
// [options setFastMathEnabled: NO];
auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
MTLCompileOptions* options = compile_options;
if (!options) {
options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion2_3];
[options setFastMathEnabled:NO];
}
const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
auto device = MPSDevice::getInstance()->device();
library = [device newLibraryWithSource:str options:options error:&error];
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
return library;
}
id<MTLComputePipelineState> MetalShaderLibrary::getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname) {
auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
auto cpl = cplMap[key];
if (cpl) {
return cpl;
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> MetalShaderLibrary::getLibraryPipelineState(
id<MTLLibrary> lib,
const std::string& fname) {
const auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
auto found_cpl = cplMap.find(key);
if (found_cpl != cplMap.end()) {
return found_cpl->second;
}
NSError* error = nil;
id<MTLFunction> func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func, "Failed to create function state object for: ", fname);
cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(cpl, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
return cplMap[key] = cpl;
cplMap[key] = std::make_pair(cpl, func);
return cplMap[key];
}
} // namespace at::native::mps

View File

@ -0,0 +1,24 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at::native {
namespace mps {
void _fused_adam_amsgrad_mps_impl_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf
);
} //namespace mps
}// namespace at::native

View File

@ -0,0 +1,37 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>
namespace at::native {
namespace mps {
void _fused_adam_amsgrad_mps_impl_(at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
std::vector<std::vector<at::Tensor>> tensor_lists{
params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};
const std::string kernel_name = "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
scalarToMetalTypeString(state_steps[0].scalar_type());
multi_tensor_apply_for_fused_adam<5, 512>(
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
}
} // namespace mps
} // namespace at::native

View File

@ -0,0 +1,69 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h>
#include <ATen/native/mps/operations/FusedAdamKernelImpl.h>
#include <c10/util/Exception.h>
#include <iostream>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam_native.h>
#endif
namespace at::native {
void _fused_adam_kernel_mps_(at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
if (amsgrad) {
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
"params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
mps::_fused_adam_amsgrad_mps_impl_(params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale,
found_inf);
} else {
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
"params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
mps::_fused_adam_mps_impl_(params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale,
found_inf);
}
}
} // namespace at::native

View File

@ -0,0 +1,23 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at::native {
namespace mps {
void _fused_adam_mps_impl_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf
);
} //namespace mps
}// namespace at::native

View File

@ -0,0 +1,35 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamKernelImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>
namespace at::native {
namespace mps {
void _fused_adam_mps_impl_(at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};
const std::string kernel_name = "fused_adam_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
scalarToMetalTypeString(state_steps[0].scalar_type());
multi_tensor_apply_for_fused_adam<4, 512>(
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
}
} // namespace mps
} // namespace at::native

View File

@ -0,0 +1,24 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at::native {
namespace mps {
void _fused_adamw_amsgrad_mps_impl_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf
);
} //namespace mps
}// namespace at::native

View File

@ -0,0 +1,37 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>
namespace at::native {
namespace mps {
void _fused_adamw_amsgrad_mps_impl_(at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
std::vector<std::vector<at::Tensor>> tensor_lists{
params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};
const std::string kernel_name = "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
scalarToMetalTypeString(state_steps[0].scalar_type());
multi_tensor_apply_for_fused_adam<5, 512>(
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
}
} // namespace mps
} // namespace at::native

View File

@ -0,0 +1,68 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>
#include <c10/util/Exception.h>
#include <iostream>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adamw_native.h>
#endif
namespace at::native {
void _fused_adamw_kernel_mps_(at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
if (amsgrad) {
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
"params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
mps::_fused_adamw_amsgrad_mps_impl_(params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale,
found_inf);
} else {
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
"params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
mps::_fused_adamw_mps_impl_(params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale,
found_inf);
}
}
} // namespace at::native

View File

@ -0,0 +1,23 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at::native {
namespace mps {
void _fused_adamw_mps_impl_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf
);
} //namespace mps
}// namespace at::native

View File

@ -0,0 +1,35 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>
namespace at::native {
namespace mps {
void _fused_adamw_mps_impl_(at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};
const std::string kernel_name = "fused_adamw_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
scalarToMetalTypeString(state_steps[0].scalar_type());
multi_tensor_apply_for_fused_adam<4, 512>(
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
}
} // namespace mps
} // namespace at::native

View File

@ -0,0 +1,274 @@
#pragma once
#include <ATen/native/mps/OperationUtils.h>
namespace at::native {
namespace mps {
static const char* FUSED_ADAM_OPS = R"METAL(
#include <metal_stdlib>
#define kmaxThreadGroups 32
#define kmaxTensors 32
#define chunk_size 65536
constexpr constant uint kParamIdx = 0;
constexpr constant uint kGradIdx = kParamIdx + kmaxTensors;
constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors;
constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors;
constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors;
constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors;
constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors;
template<typename T, typename state_steps_t>
struct AdamArguments {
metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
metal::array<device T *, kmaxTensors> exp_avgs [[ id(kExpAvgIdx) ]];
metal::array<device T *, kmaxTensors> exp_avg_sqs [[ id(kExpAvgSqIdx) ]];
metal::array<device state_steps_t *, kmaxTensors> state_steps [[ id(kStateStepsIdx) ]];
};
template<typename T, typename state_steps_t>
struct AdamAmsgradArguments {
metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
metal::array<device T *, kmaxTensors> exp_avgs [[ id(kExpAvgIdx) ]];
metal::array<device T *, kmaxTensors> exp_avg_sqs [[ id(kExpAvgSqIdx) ]];
metal::array<device T *, kmaxTensors> max_exp_avg_sqs [[ id(kMaxExpAvgSqIdx) ]];
metal::array<device state_steps_t *, kmaxTensors> state_steps [[ id(kStateStepsIdxForAmsgrad) ]];
};
struct MetadataArguments {
uint32_t numels[kmaxTensors];
uint32_t threadgroup_to_tensor[kmaxThreadGroups];
uint32_t threadgroup_to_chunk[kmaxThreadGroups];
};
enum ADAM_MODE : uint8_t {
ORIGINAL = 0,
ADAMW = 1
};
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
inline void adam_math_amsgrad(
device T & param,
device T & grad,
device T & exp_avg,
device T & exp_avg_sq,
device T & max_exp_avg_sq,
device state_steps_t & state_steps,
const float lr,
const float beta1,
const float beta2,
const float weight_decay,
const float eps,
const uint8_t maximize
) {
T grad_ = grad;
if (maximize) {
grad = -grad;
}
// Update param, grad, 1st and 2nd order momentum.
if (weight_decay != 0) {
switch (adam_mode) {
case ADAM_MODE::ORIGINAL:
grad += param * weight_decay;
break;
case ADAM_MODE::ADAMW:
param -= lr * weight_decay * param;
break;
}
}
exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
const float casted_state_steps = static_cast<float>(state_steps);
const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
const T step_size = lr / bias_correction1;
const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);
const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
param -= step_size * exp_avg / denom;
grad = grad_;
}
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
inline void adam_math(
device T & param,
device T & grad,
device T & exp_avg,
device T & exp_avg_sq,
device state_steps_t & state_steps,
const float lr,
const float beta1,
const float beta2,
const float weight_decay,
const float eps,
const uint8_t maximize
) {
T grad_ = grad;
if (maximize) {
grad = -grad;
}
// Update param, grad, 1st and 2nd order momentum.
if (weight_decay != 0) {
switch (adam_mode) {
case ADAM_MODE::ORIGINAL:
grad += param * weight_decay;
break;
case ADAM_MODE::ADAMW:
param -= lr * weight_decay * param;
break;
}
}
exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
const float casted_state_steps = static_cast<float>(state_steps);
const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
const T step_size = lr / bias_correction1;
const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
param -= step_size * exp_avg / denom;
grad = grad_;
}
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
kernel void fused_adam_amsgrad(
device AdamAmsgradArguments<T, state_steps_t> & args [[buffer(0)]],
constant MetadataArguments & metadata_args [[buffer(1)]],
constant float & lr [[buffer(2)]],
constant float & beta1 [[buffer(3)]],
constant float & beta2 [[buffer(4)]],
constant float & weight_decay [[buffer(5)]],
constant float & eps [[buffer(6)]],
constant uint8_t & maximize [[buffer(7)]],
uint tid [[thread_position_in_threadgroup]],
uint tgid [[threadgroup_position_in_grid]],
uint tptg [[threads_per_threadgroup]]) {
const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
const uint32_t chunk_offset = chunk_idx * chunk_size;
const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;
const auto step_count = args.state_steps[tensor_loc];
// each chunk is a threadgroup
auto param = args.params[tensor_loc] + chunk_offset;
auto grad = args.grads[tensor_loc] + chunk_offset;
auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset;
auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset;
auto max_exp_avg_sq = args.max_exp_avg_sqs[tensor_loc] + chunk_offset;
for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
adam_math_amsgrad<T, state_steps_t, adam_mode>(
*(param + i_start),
*(grad + i_start),
*(exp_avg + i_start),
*(exp_avg_sq + i_start),
*(max_exp_avg_sq + i_start),
*step_count,
lr,
beta1,
beta2,
weight_decay,
eps,
maximize
);
}
}
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
kernel void fused_adam(
device AdamArguments<T, state_steps_t> & args [[buffer(0)]],
constant MetadataArguments & metadata_args [[buffer(1)]],
constant float & lr [[buffer(2)]],
constant float & beta1 [[buffer(3)]],
constant float & beta2 [[buffer(4)]],
constant float & weight_decay [[buffer(5)]],
constant float & eps [[buffer(6)]],
constant uint8_t & maximize [[buffer(7)]],
uint tid [[thread_position_in_threadgroup]],
uint tgid [[threadgroup_position_in_grid]],
uint tptg [[threads_per_threadgroup]]) {
const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
const uint32_t chunk_offset = chunk_idx * chunk_size;
const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;
const auto step_count = args.state_steps[tensor_loc];
// each chunk is a threadgroup
auto param = args.params[tensor_loc] + chunk_offset;
auto grad = args.grads[tensor_loc] + chunk_offset;
auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset;
auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset;
for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
adam_math<T, state_steps_t, adam_mode>(
*(param + i_start),
*(grad + i_start),
*(exp_avg + i_start),
*(exp_avg_sq + i_start),
*step_count,
lr,
beta1,
beta2,
weight_decay,
eps,
maximize
);
}
}
#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT) \
template \
[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]] \
kernel void KERNEL_NAME<DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE>( \
device ARGUMENTS_STRUCT<DTYPE, STATE_STEPS_DTYPE> & args [[buffer(0)]],\
constant MetadataArguments & metadata_args [[buffer(1)]],\
constant float & lr [[buffer(2)]],\
constant float & beta1 [[buffer(3)]],\
constant float & beta2 [[buffer(4)]],\
constant float & weight_decay [[buffer(5)]],\
constant float & eps [[buffer(6)]],\
constant uint8_t & maximize [[buffer(7)]],\
uint tid [[thread_position_in_threadgroup]],\
uint tgid [[threadgroup_position_in_grid]],\
uint tptg [[threads_per_threadgroup]])
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
)METAL";
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}
} //namespace mps
} // namespace at::native
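For reference, a sketch of the update the adam_math helpers above implement, written to mirror the kernel's variable names (this is the standard bias-corrected Adam step, not anything specific to this diff):

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{2} \\
\theta_t &= \theta_{t-1} - \frac{\mathrm{lr}}{1-\beta_1^{t}} \cdot \frac{m_t}{\sqrt{v_t}\,/\,\sqrt{1-\beta_2^{t}} + \varepsilon}
\end{aligned}

With weight decay $w$, ADAM_MODE::ORIGINAL folds it into the gradient ($g_t \leftarrow g_t + w\,\theta_{t-1}$), while ADAM_MODE::ADAMW applies it decoupled ($\theta_{t-1} \leftarrow \theta_{t-1} - \mathrm{lr}\, w\, \theta_{t-1}$) before the step; the amsgrad kernels substitute the running maximum of $v_t$ (max_exp_avg_sq) into the denominator.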

View File

@ -17,11 +17,15 @@
#include <ATen/ops/addr_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/linalg_lu_factor_native.h>
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/triangular_solve_native.h>
#endif
#include <algorithm>
namespace at::native {
namespace mps {
namespace {
@ -127,6 +131,116 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
} // anonymous namespace
static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
using namespace mps;
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");
Tensor A_t = A;
uint64_t aRows = A_t.size(-2);
uint64_t aCols = A_t.size(-1);
uint64_t aElemSize = A_t.element_size();
uint64_t numPivots = std::min(aRows, aCols);
std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
pivot_sizes.push_back(numPivots);
resize_output(pivots, pivot_sizes);
if (A_t.numel() == 0) {
return;
}
Tensor A_ = A_t.dim() > 3 ? A_t.flatten(0, -3) : A_t;
uint64_t batchSize = A_.dim() > 2 ? A_.size(0) : 1;
std::vector<Tensor> status_tensors;
std::vector<Tensor> pivots_list;
status_tensors.reserve(batchSize);
pivots_list.reserve(batchSize);
for (C10_UNUSED const auto i : c10::irange(batchSize)) {
status_tensors.push_back(at::zeros(1, kInt, c10::nullopt, kMPS, c10::nullopt));
pivots_list.push_back(at::zeros(numPivots, kInt, c10::nullopt, kMPS, c10::nullopt));
}
// Since MPSMatrixDecompositionLU operates in-place when the result matrix completely aliases the source matrix,
// we copy A into LU and use LU as the new A.
resize_output(LU, A_.sizes());
if (!LU.is_same(A_)) {
A_ = LU.copy_(A_);
} else {
A_ = LU;
}
TORCH_INTERNAL_ASSERT(A_.is_contiguous())
id<MTLBuffer> aBuffer = getMTLBufferStorage(A_);
MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
MPSMatrixDecompositionLU* filter = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
rows:aRows
columns:aCols] autorelease];
MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
columns:aCols
matrices:batchSize
rowBytes:aCols * aElemSize
matrixBytes:aRows * aCols * aElemSize
dataType:getMPSDataType(A_)];
MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
columns:numPivots
matrices:1
rowBytes:numPivots * sizeof(uint32_t)
matrixBytes:numPivots * sizeof(uint32_t)
dataType:MPSDataTypeUInt32];
for (const auto i : c10::irange(batchSize)) {
const uint64_t aBatchOffset = i * aRows * aCols;
MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
offset:(A_.storage_offset() + aBatchOffset) * aElemSize
descriptor:sourceMatrixDesc] autorelease];
MPSMatrix* pivotIndices = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
offset:0
descriptor:pivotsMatrixDesc] autorelease];
MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
offset:(A_.storage_offset() + aBatchOffset) * aElemSize
descriptor:sourceMatrixDesc] autorelease];
id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
[filter encodeToCommandBuffer:commandBuffer
sourceMatrix:sourceMatrix
resultMatrix:solutionMatrix
pivotIndices:pivotIndices
status:statusBuffer];
}
}
});
auto stacked_pivots = A_.dim() > 2 ? at::stack(pivots_list) : pivots_list[0];
if (A_t.dim() > 3) {
resize_output(LU, A_t.sizes());
pivots.copy_(stacked_pivots.view(pivot_sizes));
} else {
pivots.copy_(stacked_pivots);
}
pivots += 1; // PyTorch's `pivots` is 1-index.
for (const auto i : c10::irange(status_tensors.size())) {
int status = status_tensors[i].item<int>();
TORCH_CHECK(
status == 0,
"lu_factor(): LU factorization failure at the ",
i + 1,
" sample with status: ",
status,
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
}
}
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;
@ -753,4 +867,16 @@ TORCH_IMPL_FUNC(triangular_solve_mps_out)
result.copy_(out);
}
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
return std::tie(LU, pivots);
}
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
Tensor LU = at::empty({0}, A.options());
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
return std::make_tuple(std::move(LU), std::move(pivots));
}
} // namespace at::native
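A minimal usage sketch for the new MPS path, assuming a macOS build with MPS available; note that the factorization returns 1-indexed pivots, matching the `pivots += 1` adjustment above:

import torch

A = torch.randn(4, 4, device="mps")
LU, pivots = torch.linalg.lu_factor(A)          # pivots are 1-indexed
# Unpack on CPU for the check (assumption: keeps the sketch independent of MPS op coverage).
P, L, U = torch.lu_unpack(LU.cpu(), pivots.cpu())
torch.testing.assert_close(P @ L @ U, A.cpu(), rtol=1e-4, atol=1e-4)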

View File

@ -0,0 +1,190 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
namespace at::native {
namespace mps {
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kmaxThreadGroups = 32;
static constexpr int64_t kmaxTensors = 32;
struct MetadataArguments { // the size of this struct must be less than 4 kilobytes (the Metal setBytes limit)
uint numels[kmaxTensors];
uint threadgroup_to_tensor[kmaxThreadGroups];
uint threadgroup_to_chunk[kmaxThreadGroups];
};
template <int depth, uint32_t kThreadGroupSize>
static void multi_tensor_apply_for_fused_adam(
const std::string& kernel_name,
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) {
const auto num_tensors = tensor_lists[0].size();
if (num_tensors == 0) {
return;
}
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth");
for (const auto& d : c10::irange(depth)) {
TORCH_CHECK(
tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported");
}
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
float lr_lv = lr;
float beta1_lv = beta1;
float beta2_lv = beta2;
float weight_decay_lv = weight_decay;
float eps_lv = eps;
uint8_t maximize_lv = maximize;
// Uncomment the block below for debugging
/*
mpsStream->addCompletedHandler(^(id<MTLCommandBuffer> cb) {
[cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) {
NSLog(@"MPSStream: %@", log);
}
];
});
*/
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]});
[computeEncoder setComputePipelineState:fusedOptimizerPSO];
// BufferIndex is the index in the kernel function
auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
int64_t tensor_loc = 0;
int64_t threadgroup_loc = 0;
MetadataArguments metadata_arguments;
for (const auto tensor_index : c10::irange(num_tensors)) {
// short-circuit to avoid adding empty tensors to tensorListMeta
if (tensor_lists[0][tensor_index].numel() == 0) {
continue;
}
for (const auto& d : c10::irange(depth)) {
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index])
offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size()
atIndex:d * kmaxTensors + tensor_loc];
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
}
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index])
offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size()
atIndex:depth * kmaxTensors + tensor_loc];
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel();
tensor_loc++;
const auto numel = tensor_lists[0][tensor_index].numel();
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
TORCH_CHECK(chunks > -1);
for (const auto& chunk : c10::irange(chunks)) {
metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;
threadgroup_loc++;
const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
// Reach the maximum threadgroups per dispatch
const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
if (tensor_full || blocks_full){
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2];
[computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3];
[computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4];
[computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5];
[computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6];
[computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7];
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
// Reset
threadgroup_loc = 0;
if (chunk == chunks - 1) {
// last chunk
tensor_loc = 0;
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
} else {
// reuse the current tensor since the current one isn't done.
metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
for (const auto& d : c10::irange(depth)) {
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index])
offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size()
atIndex:d * kmaxTensors + 0];
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
}
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index])
offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size()
atIndex:depth * kmaxTensors + 0];
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
tensor_loc = 1;
}
}
}
}
if (threadgroup_loc != 0) {
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
[computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2];
[computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3];
[computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4];
[computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5];
[computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6];
[computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7];
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
}
getMPSProfiler().endProfileKernel(fusedOptimizerPSO);
}
});
}
} // namespace mps
} // namespace at::native

View File

@ -6185,12 +6185,12 @@
CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
autogen: _nested_view_from_buffer_copy.out
- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
variants: function
device_check: NoCheck
dispatch: {}
- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor
- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor
variants: function
device_check: NoCheck
tags: view_copy
@ -6227,16 +6227,6 @@
device_check: NoCheck
dispatch: {}
- func: _nested_get_min_seqlen(Tensor self) -> Tensor
variants: function
device_check: NoCheck
dispatch: {}
- func: _nested_get_max_seqlen(Tensor self) -> Tensor
variants: function
device_check: NoCheck
dispatch: {}
- func: _nested_get_jagged_dummy(Tensor any) -> Tensor
category_override: dummy
dispatch: {}
@ -13797,10 +13787,16 @@
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor
MPS: linalg_lu_factor_mps
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor_out
MPS: linalg_lu_factor_out_mps
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
python_module: linalg
@ -15575,6 +15571,7 @@
dispatch:
CPU: _fused_adam_kernel_cpu_
CUDA: _fused_adam_kernel_cuda_
MPS: _fused_adam_kernel_mps_
autogen: _fused_adam, _fused_adam.out
- func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
@ -15593,6 +15590,7 @@
dispatch:
CPU: _fused_adamw_kernel_cpu_
CUDA: _fused_adamw_kernel_cuda_
MPS: _fused_adamw_kernel_mps_
autogen: _fused_adamw, _fused_adamw.out
- func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()

View File

@ -264,7 +264,7 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
batch_sizes = [2, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64, 128, 256]
head_dims = [64, 128]
dtypes = [
torch.bfloat16,
]
@ -302,8 +302,6 @@ def main(dynamic: bool, calculate_bwd: bool):
results.append(
Experiment(config, run_single_experiment(config, dynamic=dynamic))
)
for config in tqdm(generate_experiment_configs(calculate_bwd)):
results.append(Experiment(config, run_single_experiment(config)))
print_results(results)

View File

@ -4,6 +4,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <c10/core/Device.h>

View File

@ -958,6 +958,10 @@ class DeviceCachingAllocator {
}
}
void recordAnnotation(const std::shared_ptr<GatheredContext>& name) {
record_trace(TraceEntry::USER_DEFINED, 0, 0, nullptr, 0, name);
}
bool isHistoryEnabled() {
return record_history;
}
@ -3026,6 +3030,12 @@ class NativeCachingAllocator : public CUDAAllocator {
}
}
void recordAnnotation(const std::shared_ptr<GatheredContext>& name) override {
for (auto& allocator : device_allocator) {
allocator->recordAnnotation(name);
}
}
bool isHistoryEnabled() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));

View File

@ -170,8 +170,9 @@ struct TraceEntry {
SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments)
SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace
// events
OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free
// bytes reported by cuda)
OOM, // the allocator threw an OutOfMemoryError (addr_ is the amount of free
// bytes reported by cuda)
USER_DEFINED // a call made from user defined API such as record_function
};
TraceEntry(
Action action,
@ -289,6 +290,7 @@ class CUDAAllocator : public Allocator {
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
RecordContext when) = 0;
virtual void recordAnnotation(const std::shared_ptr<GatheredContext>& name){};
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
// Attached AllocatorTraceTracker callbacks will be called while the
@ -428,6 +430,10 @@ inline void recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when);
}
inline void recordAnnotation(const std::shared_ptr<GatheredContext>& name) {
return get()->recordAnnotation(name);
}
inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}
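A hedged sketch of how the new USER_DEFINED trace entries are meant to be consumed from Python, assuming the existing private memory-history API and that record_function ranges are what feed recordAnnotation (as the enum comment suggests):

import torch
from torch.profiler import record_function

torch.cuda.memory._record_memory_history(max_entries=100000)
with record_function("## training step ##"):         # assumed to surface as a USER_DEFINED trace entry
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
torch.cuda.memory._dump_snapshot("snapshot.pickle")   # inspect with the memory_viz tooling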

View File

@ -750,6 +750,9 @@ if(BUILD_LIBTORCHLESS)
find_library(TORCH_XPU_LIB torch_xpu PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH)
endif()
add_subdirectory(../torch torch)
# ---[ Torch python bindings build
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
else()
set(TORCH_LIB torch)
set(TORCH_CPU_LIB torch_cpu)
@ -1270,12 +1273,10 @@ install(FILES
${PROJECT_BINARY_DIR}/TorchConfig.cmake
DESTINATION share/cmake/Torch)
# ---[ Torch python bindings build
add_subdirectory(../torch torch)
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
# ==========================================================
# END formerly-libtorch flags
# ==========================================================

View File

@ -35,3 +35,6 @@ torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
.. autofunction:: set_checkpoint_debug_enabled
.. autoclass:: CheckpointPolicy
.. autoclass:: SelectiveCheckpointContext
.. autofunction:: create_selective_checkpoint_contexts
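A hedged sketch of the selective-checkpointing API these new entries document; the policy-function signature and enum members shown here are assumptions based on the public docs, not something this diff confirms:

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul outputs, recompute everything else (assumed policy-fn signature).
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x):
    return torch.mm(x, x).relu().sum()

x = torch.randn(16, 16, requires_grad=True)
out = checkpoint(
    fn, x, use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.backward()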

View File

@ -2713,6 +2713,7 @@ coverage_ignore_classes = [
"GuardOnDataDependentSymNode",
"PendingUnbackedSymbolNotFound",
"LoggingShapeGuardPrinter",
"SymExprPrinter",
"RelaxedUnspecConstraint",
"RuntimeAssert",
"ShapeGuardPrinter",

View File

@ -719,3 +719,5 @@ API Reference
:members:
.. automodule:: torch.export.custom_obj
.. automodule:: torch.export.experimental

View File

@ -0,0 +1,39 @@
.. _mps_environment_variables:
MPS Environment Variables
==========================
**PyTorch Environment Variables**
.. list-table::
:header-rows: 1
* - Variable
- Description
* - ``PYTORCH_DEBUG_MPS_ALLOCATOR``
- If set to ``1``, set allocator logging level to verbose.
* - ``PYTORCH_MPS_HIGH_WATERMARK_RATIO``
- High watermark ratio for MPS allocator. By default, it is set to 1.7.
* - ``PYTORCH_MPS_LOW_WATERMARK_RATIO``
- Low watermark ratio for MPS allocator. By default, it is set to 1.4 if the memory is unified and set to 1.0 if the memory is discrete.
* - ``PYTORCH_MPS_PREFER_METAL``
- If set to ``1``, force using metal kernels instead of using MPS Graph APIs. For now this is only used for matmul op.
* - ``PYTORCH_ENABLE_MPS_FALLBACK``
- If set to ``1``, fall back to CPU for operations that MPS does not support.
.. note::
**high watermark ratio** is a hard limit for the total allowed allocations
- `0.0` : disables high watermark limit (may cause system failure if system-wide OOM occurs)
- `1.0` : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
- `>1.0`: allows limits beyond the device.recommendedMaxWorkingSetSize
e.g., value 0.95 means we allocate up to 95% of recommended maximum
allocation size; beyond that, the allocations would fail with OOM error.
**low watermark ratio** is a soft limit that attempts to keep allocations below the lower watermark
level through garbage collection or by committing command buffers more frequently (a.k.a. adaptive commit).
Value between 0 and m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
allocation size.
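A short sketch of driving these knobs from Python; the assumption here is that setting them in ``os.environ`` before the first MPS allocation is sufficient for the allocator to pick them up:

import os

# Must be set before the MPS allocator is initialized (assumption: before first MPS use).
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.95"  # hard cap at 95% of recommendedMaxWorkingSetSize
os.environ["PYTORCH_MPS_LOW_WATERMARK_RATIO"] = "0.9"    # trigger garbage collection / adaptive commit earlier
os.environ["PYTORCH_DEBUG_MPS_ALLOCATOR"] = "1"          # verbose allocator logging

import torch
x = torch.ones(1024, 1024, device="mps")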

View File

@ -164,10 +164,10 @@ horizontally and fused implementations as fusing vertically on top of that.
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop.
So when applicable, we default to foreach over for-loop. Applicable means the foreach
implementation is available, the user has not specified any implementation-specific kwargs
(e.g., fused, foreach, differentiable), and all tensors are native and on CUDA. Note that
while fused should be even faster than foreach, the implementations are newer and we would
like to give them more bake-in time before flipping the switch everywhere. You are welcome
to try them out though!
(e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused
should be even faster than foreach, the implementations are newer and we would like to give
them more bake-in time before flipping the switch everywhere. We summarize the stability status
for each implementation in the second table below; you are welcome to try them out though!
Below is a table showing the available and default implementations of each algorithm:
@ -177,7 +177,7 @@ Below is a table showing the available and default implementations of each algor
:delim: ;
:class:`Adadelta`;foreach;yes;no
:class:`Adagrad`;foreach;yes;no
:class:`Adagrad`;foreach;yes;yes (cpu only)
:class:`Adam`;foreach;yes;yes
:class:`AdamW`;foreach;yes;yes
:class:`SparseAdam`;for-loop;no;no
@ -188,7 +188,28 @@ Below is a table showing the available and default implementations of each algor
:class:`RAdam`;foreach;yes;no
:class:`RMSprop`;foreach;yes;no
:class:`Rprop`;foreach;yes;no
:class:`SGD`;foreach;yes;no
:class:`SGD`;foreach;yes;yes (CPU and CUDA only)
The table below shows the stability status of the fused implementations:
.. csv-table::
:header: "Algorithm", "CPU", "CUDA", "MPS"
:widths: 25, 25, 25, 25
:delim: ;
:class:`Adadelta`;unsupported;unsupported;unsupported
:class:`Adagrad`;beta;unsupported;unsupported
:class:`Adam`;beta;stable;beta
:class:`AdamW`;beta;stable;beta
:class:`SparseAdam`;unsupported;unsupported;unsupported
:class:`Adamax`;unsupported;unsupported;unsupported
:class:`ASGD`;unsupported;unsupported;unsupported
:class:`LBFGS`;unsupported;unsupported;unsupported
:class:`NAdam`;unsupported;unsupported;unsupported
:class:`RAdam`;unsupported;unsupported;unsupported
:class:`RMSprop`;unsupported;unsupported;unsupported
:class:`Rprop`;unsupported;unsupported;unsupported
:class:`SGD`;beta;beta;unsupported
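A minimal sketch of opting into a fused implementation explicitly; ``fused=True`` is only accepted where the tables above list support, and the device chosen here ("cuda") is an assumption that should be swapped for "cpu" or "mps" as appropriate:

import torch

model = torch.nn.Linear(8, 8, device="cuda")
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()
opt.step()
opt.zero_grad()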
How to adjust learning rate
---------------------------

View File

@ -21,6 +21,7 @@ If you find anything in this documentation that is missing, incorrect, or could
threading_environment_variables
cuda_environment_variables
mps_environment_variables
debugging_environment_variables
miscellaneous_environment_variables
logging

View File

@ -199,7 +199,6 @@
# Builds pytorch as a wheel using libtorch.so from a separate wheel
import os
import pkgutil
import sys
if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
@ -210,19 +209,6 @@ if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
import platform
def _get_package_path(package_name):
loader = pkgutil.find_loader(package_name)
if loader:
# The package might be a namespace package, so get_data may fail
try:
file_path = loader.get_filename()
return os.path.dirname(file_path)
except AttributeError:
pass
return None
BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1"
BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1"
@ -237,6 +223,7 @@ if sys.version_info < python_min_version:
import filecmp
import glob
import importlib
import importlib.util
import json
import shutil
import subprocess
@ -253,15 +240,24 @@ from setuptools.dist import Distribution
from tools.build_pytorch_libs import build_caffe2
from tools.generate_torch_version import get_torch_version
from tools.setup_helpers.cmake import CMake
from tools.setup_helpers.env import (
build_type,
IS_DARWIN,
IS_LINUX,
IS_WINDOWS,
LIBTORCH_PKG_NAME,
)
from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
from tools.setup_helpers.generate_linker_script import gen_linker_script
def _get_package_path(package_name):
spec = importlib.util.find_spec(package_name)
if spec:
# The package might be a namespace package, so get_data may fail
try:
loader = spec.loader
if loader is not None:
file_path = loader.get_filename() # type: ignore[attr-defined]
return os.path.dirname(file_path)
except AttributeError:
pass
return None
# set up appropriate env variables
if BUILD_LIBTORCH_WHL:
# Set up environment variables for ONLY building libtorch.so and not libtorch_python.so
@ -271,7 +267,7 @@ if BUILD_LIBTORCH_WHL:
if BUILD_PYTHON_ONLY:
os.environ["BUILD_LIBTORCHLESS"] = "ON"
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path(LIBTORCH_PKG_NAME)}/lib"
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
################################################################################
# Parameters parsed from environment
@ -347,9 +343,12 @@ cmake_python_include_dir = sysconfig.get_path("include")
# Version, create_version_file, and package_name
################################################################################
DEFAULT_PACKAGE_NAME = LIBTORCH_PKG_NAME if BUILD_LIBTORCH_WHL else "torch"
package_name = os.getenv("TORCH_PACKAGE_NAME", "torch")
LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "libtorch")
if BUILD_LIBTORCH_WHL:
package_name = LIBTORCH_PKG_NAME
package_name = os.getenv("TORCH_PACKAGE_NAME", DEFAULT_PACKAGE_NAME)
package_type = os.getenv("PACKAGE_TYPE", "wheel")
version = get_torch_version()
report(f"Building wheel {package_name}-{version}")
@ -472,7 +471,6 @@ def build_deps():
check_submodules()
check_pydep("yaml", "pyyaml")
build_python = not BUILD_LIBTORCH_WHL
build_caffe2(
version=version,
cmake_python_library=cmake_python_library,
@ -1125,8 +1123,6 @@ def main():
raise RuntimeError(
"Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. Set one to 0 and rerun."
)
# the list of runtime dependencies required by this built package
install_requires = [
"filelock",
"typing-extensions>=4.8.0",
@ -1141,7 +1137,7 @@ def main():
install_requires.append("setuptools")
if BUILD_PYTHON_ONLY:
install_requires.append(LIBTORCH_PKG_NAME)
install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}")
use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", ""))
if (
@ -1190,7 +1186,6 @@ def main():
entry_points,
extra_install_requires,
) = configure_extension_build()
install_requires += extra_install_requires
extras_require = {
@ -1219,6 +1214,7 @@ def main():
"utils/data/*.pyi",
"utils/data/datapipes/*.pyi",
"lib/*.pdb",
"lib/*shm*",
"lib/torch_shm_manager",
"lib/*.h",
"include/*.h",
@ -1383,15 +1379,15 @@ def main():
"utils/model_dump/*.mjs",
]
if BUILD_PYTHON_ONLY:
if not BUILD_LIBTORCH_WHL:
torch_package_data.extend(
[
"lib/libtorch_python*",
"lib/*shm*",
"lib/libtorch_global_deps*",
"lib/libtorch_python.so",
"lib/libtorch_python.dylib",
"lib/libtorch_python.dll",
]
)
else:
if not BUILD_PYTHON_ONLY:
torch_package_data.extend(
[
"lib/*.so*",
@ -1442,28 +1438,18 @@ def main():
"packaged/autograd/*",
"packaged/autograd/templates/*",
]
package_data = {
"torch": torch_package_data,
}
if BUILD_LIBTORCH_WHL:
modified_packages = []
for package in packages:
parts = package.split(".")
if parts[0] == "torch":
modified_packages.append(DEFAULT_PACKAGE_NAME + package[len("torch") :])
packages = modified_packages
package_dir = {LIBTORCH_PKG_NAME: "torch"}
torch_package_dir_name = LIBTORCH_PKG_NAME
package_data = {LIBTORCH_PKG_NAME: torch_package_data}
extensions = []
if not BUILD_LIBTORCH_WHL:
package_data["torchgen"] = torchgen_package_data
package_data["caffe2"] = [
"python/serialized_test/data/operator_test/*.zip",
]
else:
torch_package_dir_name = "torch"
package_dir = {}
package_data = {
"torch": torch_package_data,
"torchgen": torchgen_package_data,
"caffe2": [
"python/serialized_test/data/operator_test/*.zip",
],
}
# no extensions in BUILD_LIBTORCH_WHL mode
extensions = []
setup(
name=package_name,
@ -1481,7 +1467,6 @@ def main():
install_requires=install_requires,
extras_require=extras_require,
package_data=package_data,
package_dir=package_dir,
url="https://pytorch.org/",
download_url="https://github.com/pytorch/pytorch/tags",
author="PyTorch Team",

View File

@ -1970,6 +1970,7 @@
"EqualityConstraint",
"GuardOnDataDependentSymNode",
"LoggingShapeGuardPrinter",
"SymExprPrinter",
"RelaxedUnspecConstraint",
"RuntimeAssert",
"ShapeGuardPrinter",

View File

@ -43,6 +43,7 @@ from torch.testing._internal.common_fsdp import (
FSDPTestMultiThread,
MLP,
patch_post_backward,
patch_reshard,
patch_unshard,
)
from torch.testing._internal.common_utils import run_tests
@ -372,7 +373,7 @@ class TestFullyShardCommunication(FSDPTest):
)
class TestFullyShardBackwardPrefetch(FSDPTest):
class TestFullyShardPrefetch(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
@ -578,6 +579,193 @@ class TestFullyShardBackwardPrefetch(FSDPTest):
self.assertEqual(events, expected_events)
events.clear()
@skip_if_lt_x_gpu(2)
def test_set_modules_to_forward_prefetch(self):
n_layers = 4
reshard_after_forward = True
checkpoint_impl = "utils"
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)
def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
# Use model-specific knowledge to configure forward prefetching:
# each transformer block (layer) prefetches for the next few
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
expected_backward_events = [
# Default backward prefetching
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
set_forward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
expected_forward_events = [
("unshard", "", TrainingState.FORWARD),
# `layers.i` prefetches `layers.i+1`
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
self.assertEqual(events, expected_backward_events)
events.clear()
set_forward_prefetch(model, num_to_prefetch=2)
loss = model(inp)
expected_forward_events = [
("unshard", "", TrainingState.FORWARD),
# `layers.i` prefetches `layers.i+1` and `layers.i+2`
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
self.assertEqual(events, expected_backward_events)
events.clear()
@skip_if_lt_x_gpu(2)
def test_set_modules_to_backward_prefetch(self):
n_layers = 4
reshard_after_forward = True
checkpoint_impl = "utils"
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)
def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
# Use model-specific knowledge to configure backward prefetching:
# each transformer block (layer) prefetches for the previous few layers
for i, layer in enumerate(model.layers):
if i < num_to_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
expected_forward_events = [
# Default forward prefetching
("unshard", "", TrainingState.FORWARD), # root
("unshard", "layers.0", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
set_backward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
expected_backward_events = [
# Root prefetches `layers.3` per default
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
# `layers.i` prefetches for `layers.i-1` (same as default)
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()
set_backward_prefetch(model, num_to_prefetch=2)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
expected_backward_events = [
# Root prefetches `layers.3` per default
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
# `layers.i` prefetches for `layers.i-1` and `layers.i-2`
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()
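The two tests above drive FSDP2's explicit prefetching hooks, `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch`. As a minimal sketch of the same configuration pattern outside the test harness, assuming a model whose blocks live in `model.layers` and are each already wrapped with `fully_shard` (the helper name and loop structure are illustrative, not from the diff):
def configure_explicit_prefetch(model, num_forward: int, num_backward: int) -> None:
    # Assumes every element of model.layers is an FSDP-wrapped module that
    # exposes the two setter methods below.
    layers = list(model.layers)
    for i, layer in enumerate(layers):
        # Forward: layers[i] prefetches layers[i+1] .. layers[i+num_forward].
        if i < len(layers) - num_forward:
            layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 1 + num_forward])
        # Backward: layers[i] prefetches layers[i-1] .. layers[i-num_backward].
        if i >= num_backward:
            layer.set_modules_to_backward_prefetch(
                [layers[i - j] for j in range(1, num_backward + 1)]
            )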
def _init_transformer(
self,
n_layers: int,
@ -614,6 +802,21 @@ class TestFullyShardBackwardPrefetch(FSDPTest):
return unshard_with_record
def _get_reshard_with_record(
self, orig_reshard: Callable, events: List[EventType]
) -> Callable:
def reshard_with_record(self, *args, **kwargs):
nonlocal events
if (
self._training_state == TrainingState.FORWARD
and not self._reshard_after_forward
): # skip no-ops
return
events.append(("reshard", self._module_fqn, self._training_state))
return orig_reshard(self, *args, **kwargs)
return reshard_with_record
def _get_post_backward_with_record(
self, orig_post_backward: Callable, events: List[EventType]
) -> Callable:

View File

@ -1,16 +1,30 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import itertools
import unittest
import torch
import torch._dynamo.testing
from torch import nn
from torch._dynamo import compiled_autograd
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
from torch.distributed._composable.fsdp._fsdp_init import (
_get_managed_modules,
_get_managed_states,
)
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tensor import init_device_mesh
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
from torch.utils._triton import has_triton
@ -64,6 +78,10 @@ class TestFullyShardCompileCompute(FSDPTest):
class TestFullyShardCompile(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
def test_dynamo_trace_use_training_state(self):
torch._dynamo.reset()
# Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
@ -100,6 +118,174 @@ class TestFullyShardCompile(FSDPTest):
self.assertEqual(cnt.op_count, 1)
self.assertEqual(len(cnt.graphs), 1)
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
@torch._functorch.config.patch(recompute_views=True)
def _test_traceable_fsdp(
self, model_init_fn, input_creation_fn, backend, fullgraph
):
n_iter = 10
def compiler_fn(compiled_autograd_backend):
def _fn(gm):
# fullgraph=True because graph breaks in the Compiled Autograd BWD graph are not supported by Traceable FSDP2 yet
# (the main difficulty is that queue_callback does not work well when the BWD graph has a graph break).
return torch.compile(
gm, backend=compiled_autograd_backend, fullgraph=True
)
return _fn
def run_all_iters(model, optim, n_iter=n_iter, compiled_autograd_backend=None):
torch.manual_seed(42)
losses = []
for i in range(n_iter):
optim.zero_grad(set_to_none=True)
inp = input_creation_fn()
if compiled_autograd_backend is not None:
maybe_compiled_autograd_ctx = compiled_autograd.enable(
compiler_fn(compiled_autograd_backend)
)
else:
maybe_compiled_autograd_ctx = contextlib.nullcontext()
with maybe_compiled_autograd_ctx:
out = model(inp)
loss = out.sum()
losses.append(loss.item())
loss.backward()
optim.step()
torch.cuda.synchronize()
return losses
def test_compiled():
model, optim = model_init_fn()
# FSDP2 does lazy init on the 1st run, so run one iteration in eager mode to initialize
run_all_iters(model, optim, 1)
model_compiled = torch.compile(model, backend=backend, fullgraph=True)
res = run_all_iters(
model_compiled, optim, compiled_autograd_backend=backend
)
optim.zero_grad(set_to_none=True)
return res
def test_eager():
model, optim = model_init_fn()
# FSDP2 does lazy init on the 1st run, so run one iteration in eager mode to initialize
run_all_iters(model, optim, 1)
res = run_all_iters(model, optim)
optim.zero_grad(set_to_none=True)
return res
losses_compiled = test_compiled()
losses_eager = test_eager()
for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
self.assertTrue(
torch.allclose(
torch.tensor(loss_compiled), torch.tensor(loss_eager), rtol=1e-3
),
f"{loss_compiled} vs {loss_eager}",
)
def _create_simple_mlp_factory_fns(self):
hidden_dim = 16
def model_init_fn():
torch.manual_seed(0)
fsdp_config = {}
model = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
)
fully_shard(model, reshard_after_forward=True, **fsdp_config)
optim = torch.optim.SGD(model.parameters(), lr=1e-6)
return model, optim
def input_creation_fn():
torch.manual_seed(0)
inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
return inp
return model_init_fn, input_creation_fn
@skip_if_lt_x_gpu(2)
def test_simple_mlp_fullgraph_backend_eager(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(), "eager", fullgraph=True
)
@skip_if_lt_x_gpu(2)
def test_simple_mlp_fullgraph_backend_aot_eager(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
)
@unittest.expectedFailure
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_simple_mlp_fullgraph_backend_inductor(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
)
def _create_transformer_factory_fns(self):
hidden_dim = 16
def model_init_fn():
torch.manual_seed(0)
fsdp_config = {}
mesh = init_device_mesh("cuda", (self.world_size,))
model_args = ModelArgs(
dim=hidden_dim,
n_layers=2,
n_heads=1,
vocab_size=1024,
)
model = Transformer(model_args)
for layer_id, mod in enumerate(model.layers):
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
model.layers[layer_id] = mod
model = fully_shard(
model, mesh=mesh, reshard_after_forward=True, **fsdp_config
)
optim = torch.optim.SGD(model.parameters(), lr=1e-6)
return model, optim
def input_creation_fn():
torch.manual_seed(0)
inp = torch.zeros(
(2, hidden_dim),
device="cuda",
requires_grad=False,
dtype=torch.long,
)
return inp
return model_init_fn, input_creation_fn
@skip_if_lt_x_gpu(2)
def test_transformer_fullgraph_backend_eager(self):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(), "eager", fullgraph=True
)
@skip_if_lt_x_gpu(2)
def test_transformer_fullgraph_backend_aot_eager(self):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(), "aot_eager", fullgraph=True
)
@unittest.expectedFailure
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_transformer_fullgraph_backend_inductor(self):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(), "inductor", fullgraph=True
)
if __name__ == "__main__":
run_tests()
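For context, `_test_traceable_fsdp` above compiles the forward with `torch.compile` and wraps the backward in `compiled_autograd.enable` so the autograd graph is compiled as well. A stripped-down sketch of that pattern on a plain module (the model, backend, and shapes here are illustrative assumptions, not taken from the diff):
import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # Mirror the constraint noted in the test: the compiled-autograd backward
    # graph is compiled with fullgraph=True, i.e. no graph breaks allowed.
    return torch.compile(gm, backend="aot_eager", fullgraph=True)

model = torch.nn.Linear(16, 16)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
inp = torch.randn(2, 16)

# The forward runs through the Dynamo-compiled module; wrapping backward() in
# compiled_autograd.enable() compiles the autograd graph with compiler_fn.
with compiled_autograd.enable(compiler_fn):
    compiled_model(inp).sum().backward()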

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
import functools
from typing import Callable
import torch
@ -7,6 +8,7 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
@ -23,15 +25,6 @@ class TestFullyShardOverlap(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_fully_shard_training_overlap(self):
class LinearWithSleep(nn.Module):
def __init__(self, dim: int, sleep_ms: int):
super().__init__()
self.weight = nn.Parameter(torch.randn((dim, dim)))
self.sleep_ms = sleep_ms
def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))
torch.manual_seed(42)
# Use non-trivial comm. time but still shorter than compute time
@ -44,7 +37,7 @@ class TestFullyShardOverlap(FSDPTest):
fully_shard(model, reshard_after_forward=True)
orig_all_gather_into_tensor = dist.all_gather_into_tensor
orig_reduce_scatter = dist.reduce_scatter_tensor
orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
comm_stream = torch.cuda.Stream()
def delay_collective():
@ -61,7 +54,7 @@ class TestFullyShardOverlap(FSDPTest):
def delayed_reduce_scatter(*args, **kwargs):
delay_collective()
return orig_reduce_scatter(*args, **kwargs)
return orig_reduce_scatter_tensor(*args, **kwargs)
inp = torch.randn((2, dim), device="cuda")
loss = model(inp).sum() # warmup CUDA and allocator
@ -92,6 +85,63 @@ class TestFullyShardOverlap(FSDPTest):
)
self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time)
@skip_if_lt_x_gpu(2)
def test_fully_shard_post_optim_event_overlap(self):
torch.manual_seed(42)
# Use non-trivial comm. time but still shorter than compute time
dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10)
# Define the model to have a high-compute linear followed by a
# low-compute linear, where only the low-compute linear uses FSDP
model = nn.Sequential(
LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
).cuda()
fully_shard(model[1], reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
orig_all_gather_into_tensor = dist.all_gather_into_tensor
def delayed_all_gather(*args, **kwargs):
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
return orig_all_gather_into_tensor(*args, **kwargs)
inp = torch.randn((2, dim), device="cuda")
def run_train_steps(num_iters: int, use_post_optim_event: bool):
for _ in range(num_iters):
optim.zero_grad()
with patch_all_gather(delayed_all_gather):
loss = model(inp).sum()
loss.backward()
with implicit_replication():
optim.step()
if use_post_optim_event:
post_optim_event = torch.cuda.current_stream().record_event()
model[1].set_post_optim_event(post_optim_event)
run_train_steps(1, False) # warmup CUDA and allocator
num_iters = 5
baseline_time = self._time_fn(
functools.partial(run_train_steps, num_iters, False)
)
test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True))
buffer_ms = 4 # CPU delays and copies
# Baseline: FSDP all-gather is exposed since the FSDP module waits for
# the current stream and hence the high-compute linear
self.assertLessEqual(
baseline_time,
num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms),
)
# Test: FSDP all-gather is overlapped with the high-compute linear
# since the FSDP module only waits for the post-optim event (except on
# the 1st iteration when no event has been recorded)
expected_test_time = (
num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms
)
self.assertLessEqual(test_time, expected_test_time)
self.assertGreater(baseline_time, expected_test_time)
def _time_fn(self, fn: Callable):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
@ -123,5 +173,15 @@ class Matmul(torch.autograd.Function):
return grad_input, grad_weight, None
class LinearWithSleep(nn.Module):
def __init__(self, dim: int, sleep_ms: int):
super().__init__()
self.weight = nn.Parameter(torch.randn((dim, dim)))
self.sleep_ms = sleep_ms
def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))
if __name__ == "__main__":
run_tests()
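The post-optim event test above reduces to a small recipe: record a CUDA event immediately after `optimizer.step()` and hand it to the FSDP module, so its next all-gather waits only on that event rather than on everything queued in the current stream. A minimal sketch, assuming `model` is already wrapped with `fully_shard` and a process group plus a CUDA device are initialized:
import torch

def train_step(model, optim, inp):
    optim.zero_grad()
    model(inp).sum().backward()
    optim.step()
    # Record an event marking the end of the optimizer step and let FSDP's
    # next all-gather wait on it instead of on the whole current stream.
    post_optim_event = torch.cuda.current_stream().record_event()
    model.set_post_optim_event(post_optim_event)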

View File

@ -3,6 +3,7 @@
import contextlib
import copy
import functools
import itertools
import unittest
from typing import Iterable, List, Tuple, Type, Union
@ -337,7 +338,6 @@ class TestFullyShard1DTrainingCore(FSDPTest):
return
assert device_type in ("cuda", "cpu"), f"{device_type}"
torch.manual_seed(42)
lin_dim = 32
vocab_size = 1024
model_args = ModelArgs(
n_layers=3,
@ -494,6 +494,85 @@ class TestFullyShard1DTrainingCore(FSDPTest):
_optim.step()
self.assertEqual(losses[0], losses[1])
@skip_if_lt_x_gpu(2)
def test_explicit_prefetching(self):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=8, dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
num_to_forward_prefetch = num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_forward_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
for i, layer in enumerate(model.layers):
if i < num_to_backward_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad()
losses.append(_model(inp).sum())
losses[-1].backward()
_optim.step()
self.assertEqual(losses[0], losses[1])
@skip_if_lt_x_gpu(2)
def test_post_optim_event(self):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
def step_post_hook(
fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
) -> None:
post_optim_event = torch.cuda.current_stream().record_event()
fsdp_module.set_post_optim_event(post_optim_event)
optim.register_step_post_hook(functools.partial(step_post_hook, model))
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
# Track all losses and check for equality at the end to avoid a CPU
# sync point after each iteration
ref_losses: List[torch.Tensor] = []
losses: List[torch.Tensor] = []
for iter_idx in range(10):
ref_optim.zero_grad()
ref_losses.append(ref_model(inp).sum())
ref_losses[-1].backward()
ref_optim.step()
for iter_idx in range(10):
optim.zero_grad()
losses.append(model(inp).sum())
losses[-1].backward()
optim.step()
# Sleep after the optimizer step to allow CPU to run ahead into the
# next iteration's forward, exercising the post-optim stream sync
torch.cuda._sleep(int(25 * get_cycles_per_ms()))
for ref_loss, loss in zip(ref_losses, losses):
self.assertEqual(ref_loss, loss)
class TestFullyShard1DTrainingCompose(FSDPTest):
@property

View File

@ -279,12 +279,16 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
return code
def test_bucketing_coalesced_op(self):
torch._inductor.config._fuse_ddp_communication_passes = [
@torch._inductor.config.patch(
_fuse_ddp_communication_passes=[
"fuse_ddp_with_coalesced_op",
"schedule_comm_wait",
]
)
# TODO: This pass mucks things up since Inductor thinks it's inference
# and can apply this pass. These passes should be turned off in compiled autograd.
@torch._inductor.config.patch(reorder_for_locality=False)
def test_bucketing_coalesced_op(self):
# Gradient is None
code = self._test_bucketing()
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
@ -311,12 +315,16 @@ class ReplicateTest(MultiProcessTestCase):
fc.run(code)
def test_bucketing_concat_op(self):
torch._inductor.config._fuse_ddp_communication_passes = [
@torch._inductor.config.patch(
_fuse_ddp_communication_passes=[
"fuse_ddp_with_concat_op",
"schedule_comm_wait",
]
)
# TODO: This pass mucks things up since Inductor thinks it's inference
# and can apply this pass. These passes should be turned off in compiled autograd.
@torch._inductor.config.patch(reorder_for_locality=False)
def test_bucketing_concat_op(self):
# Gradient is None
code = self._test_bucketing()
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
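These hunks replace direct mutation of `torch._inductor.config` with the `config.patch` decorator, which scopes the override to the decorated test and restores the previous value afterwards. A small illustrative sketch of both usages (the wrapped functions are made up for the example):
import torch

@torch._inductor.config.patch(reorder_for_locality=False)
def run_compiled_add(x, y):
    # The config override is active only while this function runs.
    return torch.compile(lambda a, b: a + b)(x, y)

# config.patch also works as a context manager for one-off overrides.
with torch._inductor.config.patch(compile_threads=1):
    out = torch.compile(lambda a: a * 2)(torch.randn(4))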

View File

@ -116,6 +116,9 @@ class TestCommMode(TestCase):
@requires_nccl()
def test_comm_mode_with_c10d(self):
if not torch.cuda.is_available():
return
world_pg = self.world_pg
inp = torch.rand(2, 8, 16).cuda()

View File

@ -33,7 +33,11 @@ from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn.parallel import DistributedDataParallel as DDP
@ -70,7 +74,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
@property
def world_size(self) -> int:
return 2
return min(4, torch.cuda.device_count())
def _test_save_load(
self,
@ -567,55 +571,71 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
@with_comms
@skip_if_lt_x_gpu(2)
def test_broadcast_from_rank0(self) -> None:
def inner_test(wrapper):
model = CompositeParamModel(device=torch.device("cuda"))
optim = torch.optim.Adam(model.parameters())
fsdp_model = wrapper(copy.deepcopy(model))
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
def _test_broadcast_from_rank0(self, wrapper) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
optim = torch.optim.Adam(model.parameters())
fsdp_model = wrapper(copy.deepcopy(model))
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
batch = torch.rand(8, 100, device="cuda")
model(batch).sum().backward()
optim.step()
states, optim_states = get_state_dict(model, optim)
batch = torch.rand(8, 100, device="cuda")
model(batch).sum().backward()
optim.step()
states, optim_states = get_state_dict(model, optim)
fsdp_model(batch).sum().backward()
fsdp_optim.step()
fsdp_model(batch).sum().backward()
fsdp_optim.step()
def check(equal):
fsdp_states = get_model_state_dict(
fsdp_model,
options=StateDictOptions(full_state_dict=True),
)
fsdp_optim_states = get_optimizer_state_dict(
fsdp_model,
fsdp_optim,
options=StateDictOptions(full_state_dict=True),
)
if equal:
self.assertEqual(states, fsdp_states)
self.assertEqual(optim_states, fsdp_optim_states)
else:
self.assertNotEqual(states, fsdp_states)
self.assertNotEqual(optim_states, fsdp_optim_states)
check(equal=True)
fsdp_model(batch).sum().backward()
fsdp_optim.step()
check(equal=False)
# Drop the states to simulate loading from rank0
if dist.get_rank() > 0:
load_states = {}
load_states2 = {}
load_optim_states = {}
def check(equal):
fsdp_states = get_model_state_dict(
fsdp_model,
options=StateDictOptions(full_state_dict=True),
)
fsdp_optim_states = get_optimizer_state_dict(
fsdp_model,
fsdp_optim,
options=StateDictOptions(full_state_dict=True),
)
if equal:
self.assertEqual(states, fsdp_states)
self.assertEqual(optim_states, fsdp_optim_states)
else:
load_states = copy.deepcopy(states)
load_states2 = copy.deepcopy(states)
load_optim_states = copy.deepcopy(optim_states)
self.assertNotEqual(states, fsdp_states)
self.assertNotEqual(optim_states, fsdp_optim_states)
check(equal=True)
fsdp_model(batch).sum().backward()
fsdp_optim.step()
check(equal=False)
# Drop the states to simulate loading from rank0
if dist.get_rank() > 0:
load_states = {}
load_states2 = {}
load_optim_states = {}
else:
load_states = copy.deepcopy(states)
load_states2 = copy.deepcopy(states)
load_optim_states = copy.deepcopy(optim_states)
set_model_state_dict(
fsdp_model,
model_state_dict=load_states,
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
)
set_optimizer_state_dict(
fsdp_model,
fsdp_optim,
optim_state_dict=load_optim_states,
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
)
check(equal=True)
# Verify the `strict` flag.
load_states = load_states2
if load_states:
key = next(iter(load_states.keys()))
load_states.pop(key)
with self.assertRaisesRegex(RuntimeError, "Missing key"):
set_model_state_dict(
fsdp_model,
model_state_dict=load_states,
@ -623,30 +643,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
broadcast_from_rank0=True, full_state_dict=True
),
)
set_optimizer_state_dict(
fsdp_model,
fsdp_optim,
optim_state_dict=load_optim_states,
options=StateDictOptions(
broadcast_from_rank0=True, full_state_dict=True
),
)
check(equal=True)
# Verify the `strict` flag.
load_states = load_states2
if load_states:
key = next(iter(load_states.keys()))
load_states.pop(key)
with self.assertRaisesRegex(RuntimeError, "Missing key"):
set_model_state_dict(
fsdp_model,
model_state_dict=load_states,
options=StateDictOptions(
broadcast_from_rank0=True, full_state_dict=True
),
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_broadcast_from_rank0(self) -> None:
device_mesh = init_device_mesh("cuda", (self.world_size,))
self.run_subtests(
{
@ -655,7 +655,24 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
functools.partial(FSDP, device_mesh=device_mesh),
]
},
inner_test,
self._test_broadcast_from_rank0,
)
@with_comms
@skip_if_lt_x_gpu(4)
def test_broadcast_from_rank0_hsdp(self) -> None:
device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
self.run_subtests(
{
"wrapper": [
functools.partial(
FSDP,
device_mesh=device_mesh,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
),
]
},
self._test_broadcast_from_rank0,
)
@with_comms
@ -851,6 +868,33 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
):
get_model_state_dict(model)
@with_comms
@skip_if_lt_x_gpu(2)
def test_shared_weight(self):
class TiedEmbeddingModel(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.decoder = nn.Linear(embedding_dim, vocab_size)
self.decoder.weight = self.embedding.weight # Tying weights
def forward(self, input):
input = (input * 10).to(torch.int)
embedded = self.embedding(input)
output = self.decoder(embedded)
return output
def init_model_optim():
device_mesh = init_device_mesh("cuda", (self.world_size,))
orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
class TestNoComm(MultiProcessTestCase):
def setUp(self) -> None:

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import sys
from unittest.mock import MagicMock
import torch
@ -123,5 +124,13 @@ class TestMedatadaIndex(TestCase):
find_state_dict_object(state_dict, MetadataIndex("st", [1]))
class TestTensorProperties(TestCase):
def test_create_from_tensor_correct_device(self):
t = torch.randn([10, 2], device="cpu")
t.is_pinned = MagicMock(return_value=True)
TensorProperties.create_from_tensor(t)
t.is_pinned.assert_called_with(device=torch.device("cpu"))
if __name__ == "__main__":
run_tests()

View File

@ -92,6 +92,38 @@ def get_model(
return m, inputs, outputs
class MutatingModel(nn.Module):
def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
super().__init__()
self.ctx_manager = ctx_manager
self.net = nn.Sequential(
*[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
)
self.state = 1
def forward(self, inputs):
self.state = 2
return self.net(inputs) * self.state
def get_mutating_model(
device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
m = MutatingModel(
in_feat=in_feat,
hidden_feat=hidden_feat,
out_feat=out_feat,
ctx_manager=ctx_manager,
).to(device)
m.apply(init_weights)
inputs = torch.rand(bsz, in_feat).to(device)
outputs = m(inputs)
return m, inputs, outputs
class ToyInnerModel(nn.Module):
def __init__(self):
super().__init__()
@ -484,6 +516,26 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@skip_if_lt_x_gpu(1)
def test_fsdp_setattr(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
# Test with basic FSDP wrapping (outer wrap around whole model)
m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
fsdp_m = FSDP(m, use_orig_params=True)
prof = torch._dynamo.utils.CompileProfiler()
fsdp_m = torch.compile(fsdp_m, backend=prof, fullgraph=False)
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
FileCheck().check("Torchdynamo Profiler Report").check(
"Graph Breaks"
).check_not(
"setattr(FSDPManagedNNModuleVariable(MutatingModel), state, ...)"
).check_not(
"setattr(FSDPManagedNNModuleVariable(FullyShardedDataParallel), _is_root, ...)"
).run(
prof.report()
)
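The new test above leans on `CompileProfiler` doubling as a Dynamo backend so graph breaks can be inspected from its report. A tiny standalone sketch of that reporting pattern, using a toy function rather than the FSDP-wrapped module from the test:
import torch
from torch._dynamo.utils import CompileProfiler

prof = CompileProfiler()
fn = torch.compile(lambda x: x.sin() + 1, backend=prof)
fn(torch.randn(4))
# The report summarizes compile events and any graph breaks encountered.
print(prof.report())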
@skip_if_lt_x_gpu(1)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_inductor(self):

View File

@ -60,8 +60,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_broadcast_inductor(self):
"""
Testing if broadcast works correctly when using inductor
@ -94,8 +92,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allreduce_inductor(self):
"""
This matmul/cat/allreduce is a pattern we aim to optimize.
@ -129,8 +125,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allreduce_inductor_cudagraph_trees(self):
"""
Tests whether cudagraph trees support all_reduce from nccl
@ -177,8 +171,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_eager_allreduce_inductor_wait(self):
def eager_func(a, b, c, d, *, tag, ranks, group_size):
x = torch.matmul(a, b)
@ -218,8 +210,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_inductor_allreduce_eager_wait(self):
def inductor_func(a, b, c, d, *, tag, ranks, group_size):
x = torch.matmul(a, b)
@ -256,8 +246,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allreduce_input_buffer_reuse(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
@ -275,8 +263,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_permute_tensor(self):
def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
return _functional_collectives.permute_tensor(
@ -304,8 +290,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allgather_output_buffer_reuse(self):
class Model(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
@ -329,8 +313,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allgather_contiguous_input(self):
class Model(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
@ -355,8 +337,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_allgather_into_tensor_inductor(self):
"""
This matmul/cat/allreduce is a pattern we aim to optimize.
@ -388,8 +368,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_reduce_scatter_tensor_inductor(self):
def example(a, b, *, tag, ranks, group_size):
c = torch.matmul(a, b)
@ -418,8 +396,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_all_to_all_single_inductor(self):
def example(
inp,
@ -488,8 +464,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
def test_all_to_all_single_inductor_split_sizes_none(self):
def example(inp, *, tag, ranks, group_size):
a2a = torch.ops.c10d_functional.all_to_all_single(

View File

@ -19,7 +19,11 @@ from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
from torch.utils.checkpoint import (
checkpoint,
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
@ -105,8 +109,11 @@ def op_count(gm):
def _get_custom_policy(no_recompute_list=None):
def _custom_policy(mode, func, *args, **kwargs):
return func in no_recompute_list
def _custom_policy(ctx, func, *args, **kwargs):
if func in no_recompute_list:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE
return _custom_policy
@ -530,7 +537,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
no_recompute_list = [
torch.ops.aten.mm.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
@ -580,7 +587,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
no_recompute_list = [
torch.ops.aten.mm.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
@ -650,7 +657,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
def selective_checkpointing_context_fn():
meta = {}
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
def gn(x, y):
return torch.sigmoid(
@ -698,7 +705,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
)
def test_compile_selective_checkpoint_partial_ctx_fn(self):
def selective_checkpointing_context_fn(no_recompute_list):
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
@ -751,7 +758,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list),
)
@ -803,7 +810,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
@ -854,7 +861,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
no_recompute_list = [
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
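The hunks above migrate from the private `_pt2_selective_checkpoint_context_fn_gen` helper to the public `create_selective_checkpoint_contexts` API, with policies returning `CheckpointPolicy` values instead of booleans. A hedged end-to-end sketch of how such a policy is typically plugged into `torch.utils.checkpoint.checkpoint` (the toy function, shapes, and policy choice are illustrative):
import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def policy(ctx, func, *args, **kwargs):
    # Save matmul outputs during forward; recompute everything else in backward.
    if func == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
out = checkpoint(
    gn,
    x,
    y,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy),
)
out.sum().backward()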

View File

@ -2746,26 +2746,6 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
wrapped_gm = backend.graphs[graph_idx]
return wrapped_gm
def test_hessian_graph_break(self):
counters.clear()
def wrapper_fn(x):
return torch.func.hessian(torch.sin)(x)
x = torch.randn(4, 3)
expected = wrapper_fn(x)
got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 2)
self.assertEqual(
{
"'skip function disable in file _dynamo/decorators.py'": 1,
"call torch._dynamo.disable() wrapped function <function jacfwd.<locals>.wrapper_fn at 0xN>": 1,
},
{munge_exc(k): v for k, v in counters["graph_break"].items()},
)
@unittest.expectedFailure
def test_hessian(self):
counters.clear()
@ -2900,7 +2880,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_hessian_argnums(self):
counters.clear()
@ -3046,7 +3025,6 @@ class GraphModule(torch.nn.Module):
""" return (unflatten,)""",
)
@unittest.expectedFailure
def test_hessian_disable_capture(self):
counters.clear()
@ -3073,26 +3051,6 @@ class GraphModule(torch.nn.Module):
)
self.assertEqual(actual, expected)
def test_jacrev_graph_break(self):
counters.clear()
def wrapper_fn(x):
return torch.func.jacrev(torch.sin)(x)
x = torch.randn(4, 3)
expected = wrapper_fn(x)
got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 2)
self.assertEqual(
{
"'skip function disable in file _dynamo/decorators.py'": 1,
"call torch._dynamo.disable() wrapped function <function jacrev.<locals>.wrapper_fn at 0xN>": 1,
},
{munge_exc(k): v for k, v in counters["graph_break"].items()},
)
@unittest.expectedFailure
def test_jacrev(self):
counters.clear()
@ -3169,7 +3127,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacrev_two_tensors_argnums(self):
counters.clear()
@ -3252,7 +3209,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacrev_has_aux(self):
counters.clear()
@ -3337,7 +3293,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacrev_disable_capture(self):
counters.clear()
@ -4284,26 +4239,6 @@ class GraphModule(torch.nn.Module):
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
def test_jacfwd_graph_break(self):
counters.clear()
def wrapper_fn(x):
return torch.func.jacfwd(torch.sin)(x)
x = torch.randn(4, 3)
expected = wrapper_fn(x)
got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 2)
self.assertEqual(
{
"'skip function disable in file _dynamo/decorators.py'": 1,
"call torch._dynamo.disable() wrapped function <function jacfwd.<locals>.wrapper_fn at 0xN>": 1,
},
{munge_exc(k): v for k, v in counters["graph_break"].items()},
)
@unittest.expectedFailure
def test_jacfwd(self):
counters.clear()
@ -4387,7 +4322,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacfwd_two_tensors_argnums(self):
counters.clear()
@ -4477,7 +4411,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacfwd_has_aux(self):
counters.clear()
@ -4572,7 +4505,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacfwd_randomness(self):
counters.clear()
@ -4676,7 +4608,6 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.expectedFailure
def test_jacfwd_disable_capture(self):
counters.clear()

View File

@ -47,7 +47,6 @@ from torch._dynamo.testing import (
same,
skipIfNotPy311,
unsupported,
xfailIfPy312,
)
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
from torch._inductor.utils import run_and_get_code
@ -8289,6 +8288,72 @@ def ___make_guard_fn():
x = torch.zeros(100, dtype=torch.int64)
f(x)
def test_out_variant_custom_op(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define(
"split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)
@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: typing.List[int],
dim: int,
out: typing.List[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)
@torch.compile(backend="eager", fullgraph=True)
def f1(all_gather_output, all_gather_input_split_sizes, dim, out):
return torch.ops.mylib.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim, out=out
)
all_gather_output = torch.randn(2, 272)
all_gather_input_split_sizes = [128, 8, 128, 8]
dim = 1
out = [
torch.empty(2, 128),
torch.empty(2, 8),
torch.empty(2, 128),
torch.empty(2, 8),
]
f1(all_gather_output, all_gather_input_split_sizes, dim, out)
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)
@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
tensors: typing.List[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
torch._chunk_cat(tensors, dim, num_chunks, out=out)
@torch.compile(backend="eager", fullgraph=True)
def f2(tensors, dim, num_chunks, out):
return torch.ops.mylib.chunk_cat(tensors, dim, num_chunks, out=out)
x = torch.zeros(100, dtype=torch.int64)
tensors = [
torch.randn(16, 16),
torch.randn(16),
torch.randn(16, 16),
torch.randn(16),
]
dim = 0
num_chunks = 2
out = torch.empty(2, 272)
f2(tensors, dim, num_chunks, out)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_runtime_assert_replacement(self):
@torch.compile(backend="aot_eager")
@ -9946,10 +10011,6 @@ fn
lambda mod: mod,
)
# The following 2 tests fail due to https://github.com/python/cpython/issues/118013.
# Tracked by https://github.com/pytorch/pytorch/issues/124302.
# The xfails can be removed once Python 3.12 is updated on CI.
@xfailIfPy312
def test_outside_linear_module_free(self):
# Compared to test_linear_module_free, the linear
# layer is not the code object that is directly compiled.
@ -9984,7 +10045,6 @@ fn
gc.collect()
self.assertTrue(cleared)
@xfailIfPy312
def test_parameter_free(self):
def model_inp_ctr():
param = torch.nn.Parameter(torch.randn(100, 100))

View File

@ -4781,6 +4781,9 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
res = opt_fn(x_weak, y)
self.assertEqual(ref, res)
@torch._functorch.config.patch(
recompute_views=True,
)
def test_storage_resize_forward_full_graph(self):
class TestModule(torch.nn.Module):
def __init__(self):
@ -4839,8 +4842,7 @@ def forward(self, primals_1, primals_2):
_foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None
getitem = _foreach_copy[0]; _foreach_copy = None
mm = torch.ops.aten.mm.default(getitem, getitem)
t_1 = torch.ops.aten.t.default(getitem); getitem = None
return [mm, t_1]""",
return [mm, getitem]""",
)
self.assertEqual(out_ref, out_test)

View File

@ -334,6 +334,41 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
res = fn(input)
self.assertIsInstance(res, BadNewTorchFunction)
def test_no_torch_function_recompiles(self):
class NJT:
def __repr__(self):
return f"NJT(shape={self.shape})"
def __init__(self, values, offsets):
self._values = values
self._offsets = offsets
def sin(self):
return torch.sin(self)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.sin:
self = args[0]
return NJT(func(self._values), self._offsets)
raise AssertionError("should not get here")
values1 = torch.randn(10, 3, 4, requires_grad=True)
values2 = torch.randn(10, 3, 4, requires_grad=True)
offsets = torch.tensor([0, 3, 10])
njt1 = NJT(values1, offsets)
njt2 = NJT(values2, offsets)
@torch.compile(backend="eager", fullgraph=True)
def f(x):
return torch.sin(x)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
f(njt1)
f(njt2)
def test_base_torch_function_tracing(self):
def fn(x):
return torch.add(x, 1)
@ -1616,15 +1651,15 @@ Eq(s10, s8)""",
guard_str,
"""\
Eq(s3 - 1, s0)
Eq(zf1, zf6)""",
Eq(zf1, zf4)""",
)
else:
self.assertExpectedInline(
guard_str,
"""\
Eq(s4 - 1, s1)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
Eq(s10 - 1, s5)
Eq(s9, s7)""",
)
return gm

View File

@ -446,8 +446,6 @@ aten::_nested_from_padded_and_nested_example
aten::_nested_from_padded_and_nested_example.out
aten::_nested_get_jagged_dummy
aten::_nested_get_lengths
aten::_nested_get_max_seqlen
aten::_nested_get_min_seqlen
aten::_nested_get_offsets
aten::_nested_get_ragged_idx
aten::_nested_get_values

View File

@ -111,13 +111,102 @@ class TestConverter(TestCase):
def test_aten_len(self):
class Module(torch.nn.Module):
def forward(self, x):
def forward(self, x: torch.Tensor):
length = len(x)
return torch.ones(length)
# aten::len.Tensor
inp = (torch.ones(2, 3),)
self._check_equal_ts_ep_converter(Module(), inp)
class Module(torch.nn.Module):
def forward(self, x: List[int]):
length = len(x)
return torch.ones(length)
# aten::len.t
inp = ([1, 2, 3],)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[int, str]):
length = len(x)
return torch.ones(length)
# aten::len.Dict_int
inp = ({1: "a", 2: "b", 3: "c"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[bool, str]):
length = len(x)
return torch.ones(length)
# aten::len.Dict_bool
inp = ({True: "a", False: "b"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[float, str]):
length = len(x)
return torch.ones(length)
# aten::len.Dict_float
inp = ({1.2: "a", 3.4: "b"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[torch.Tensor, str]):
length = len(x)
return torch.ones(length)
# aten::len.Dict_Tensor
inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
# aten::len.str and aten::len.Dict_str are not supported
# since torch._C._jit_flatten does not support str
# inp = ("abcdefg",)
# self._check_equal_ts_ep_converter(Module(), inp)
# inp = ({"a": 1, "b": 2},)
# self._check_equal_ts_ep_converter(Module(), inp)
def test_prim_min(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x_len = len(x)
y_len = len(y)
# prim::min.int
len_int = min(x_len, y_len)
# prim::min.float
len_float = int(min(x_len * 2.0, y_len * 2.0))
# prim::min.self_int
len_self_int = min([x_len, y_len])
# prim::min.self_float
len_self_float = int(min([x_len * 2.0, y_len * 2.0]))
# prim::min.float_int
len_float_int = int(min(x_len * 2.0, y_len))
# prim::min.int_float
len_int_float = int(min(x_len, y_len * 2.0))
return torch.ones(
len_int
+ len_float
+ len_self_int
+ len_self_float
+ len_float_int
+ len_int_float
)
inp = (torch.randn(10, 2), torch.randn(5))
self._check_equal_ts_ep_converter(Module(), inp)
def test_aten___getitem___list(self):
class Module(torch.nn.Module):
def forward(self, x):
@ -659,6 +748,21 @@ class TestConverter(TestCase):
# inp = (torch.randn([2, 3, 4]),)
# self._check_equal_ts_ep_converter(func6, inp)
def test_prim_tolist(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[int]:
return x.tolist()
inp = (torch.tensor([1, 2, 3]),)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[List[int]]:
return x.tolist()
inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
if __name__ == "__main__":
run_tests()

View File

@ -11,6 +11,7 @@ from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.export._trace import _convert_ts_to_export_experimental
from torch.export.experimental import _export_forward_backward
from torch.testing import FileCheck
@ -194,6 +195,76 @@ def forward(self, arg0_1, arg1_1):
MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
)
def test_joint_basic(self) -> None:
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
self.loss = torch.nn.CrossEntropyLoss()
def forward(self, x):
return self.loss(
self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
)
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True)
joint_ep = _export_forward_backward(ep)
print(joint_ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"):
# No stacktrace found for following nodes
view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None
addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None
view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax)
clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None
detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None
_log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax)
mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None
div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None
ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format)
div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None
neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None
expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None
_log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None
_softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None
view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None
t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2)
mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None
t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None
sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None
t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None
return (div, t_3, view_3)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None),
InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='div'), target=None),
OutputSpec(kind=<OutputKind.GRADIENT_TO_PARAMETER: 4>, arg=TensorArgument(name='t_3'), target='linear.weight'),
OutputSpec(kind=<OutputKind.GRADIENT_TO_PARAMETER: 4>, arg=TensorArgument(name='view_3'), target='linear.bias')
]
)
Range constraints: {}
"""
if __name__ == "__main__":
run_tests()

View File

@ -77,6 +77,7 @@ from torch.testing._internal.common_utils import (
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
xfailIfTorchDynamo,
)
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@ -2341,8 +2342,7 @@ class TestJac(VmapTearDownMixin, TestCase):
self.assertEqual(actual, expected)
# https://github.com/pytorch/pytorch/issues/127036
# it won't fail as jacrev/jacfwd were not inlined (see #128255)
# @xfailIfTorchDynamo
@xfailIfTorchDynamo
@parametrize("_preallocate_and_copy", (True, False))
def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
# With chunk_size=1, we shouldn't `vmap` and hence not be limited

View File

@ -1767,6 +1767,33 @@ TORCH_LIBRARY(test_autograd_cpp_node_data_dependent, m) {
out = compiled_fn(activations)
self.assertTrue(len(activations) == 0)
def test_callback_graph_break_throws_error(self):
called = [0]
def callback_final():
called[0] += 1
class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad):
torch.autograd.Variable._execution_engine.queue_callback(callback_final)
torch._dynamo.graph_break()
return grad
a = torch.rand((3, 3), requires_grad=True)
with self.assertRaisesRegex(
AssertionError,
"only supported when Compiled Autograd is enabled with fullgraph=True",
):
with compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
b = MyFunc.apply(a)
b.sum().backward()
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_cudagraphs_cpu_division(self):
from torch._dynamo.testing import reduce_to_scalar_loss
@@ -2177,7 +2204,6 @@ known_failing_tests = {
"test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable(
"test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance
"test_autograd_simple_views_python", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
"test_callback_adds_callback", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable
"test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method
"test_custom_autograd_no_early_free", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
"test_custom_function_cycle", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}


@@ -235,6 +235,7 @@ if RUN_CPU:
BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
BaseTest("test_linear1"),
BaseTest("test_linear2"),
BaseTest("test_polar"),
BaseTest(
"test_linear_binary",
"",
@@ -255,7 +256,8 @@ if RUN_CPU:
BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
BaseTest(
"test_multi_threading",
code_string_count={"py::gil_scoped_release release;": 1},
# Two threads compile, so we expect the output code to be printed twice.
code_string_count={"py::gil_scoped_release release;": 2},
),
BaseTest("test_profiler_mark_wrapper_call"),
BaseTest(


@@ -1920,6 +1920,8 @@ class CPUReproTests(TestCase):
FileCheck().check(_target_code_check).run(code)
if _target_code_check_not:
FileCheck().check_not(_target_code_check_not).run(code)
# Verify that the output isn't empty
FileCheck().check("Output code:").run(code)
self.assertEqual(
_fn(*_inps),
@@ -1933,10 +1935,16 @@ class CPUReproTests(TestCase):
_internal_check(fn, inps, "aten.scatter_reduce_")
if "ATen parallel backend: OpenMP" in torch.__config__.parallel_info():
# Fix https://github.com/pytorch/pytorch/issues/118518,
# where the native thread pool fails to change the thread count
with set_num_threads(1):
_internal_check(fn, inps, _target_code_check_not="aten.scatter_reduce_")
# When running with a single thread, we expect aten.scatter to go through
# the cpp backend codegen instead of falling back to aten.scatter_reduce_.
# Avoid the inductor cache so we don't serve an entry compiled above.
with config.patch(
{"fx_graph_cache": False, "fx_graph_remote_cache": False}
):
_internal_check(
fn, inps, _target_code_check_not="aten.scatter_reduce_"
)
with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1):
_internal_check(fn, inps, "aten.scatter_reduce_")
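The patch above has to bypass both Inductor caches, otherwise the second _internal_check would be served the artifact compiled a few lines earlier. A minimal standalone sketch of the same idea (the helper name is hypothetical; the config keys are the ones used in the test):

import torch
import torch._dynamo
from torch._inductor import config as inductor_config

def fresh_inductor_compile(fn, *args):
    # Clear Dynamo's code cache and disable the local and remote FX-graph caches
    # so Inductor actually regenerates the kernel instead of reusing a cached one.
    torch._dynamo.reset()
    with inductor_config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False}):
        return torch.compile(fn)(*args)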


@@ -442,7 +442,15 @@ class TestPatternMatcher(TestCase):
.sub(8),
)
args_list = [
def check_uint4x2_mixed_mm(args, expect_mixed_mm):
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertEqual("uint4x2_mixed_mm" in code, expect_mixed_mm)
args_expect_mixed_mm = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
@@ -454,6 +462,13 @@ class TestPatternMatcher(TestCase):
.contiguous()
.t(),
),
]
for args in args_expect_mixed_mm:
check_uint4x2_mixed_mm(args, True)
# mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one
args_expect_no_mixed_mm = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
@@ -464,13 +479,8 @@ class TestPatternMatcher(TestCase):
),
]
for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)
for args in args_expect_no_mixed_mm:
check_uint4x2_mixed_mm(args, False)
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)
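For context on the pattern exercised above: the uint4x2 layout packs two 4-bit values into each uint8 byte, and the fusion replaces the unpack-then-matmul sequence with a single mixed-precision mm kernel. The sketch below illustrates only the unpacking idea; it is not the exact expression the test's fn or the pattern matcher uses:

import torch

packed = torch.randint(0, 255, (4, 8), dtype=torch.uint8)  # two 4-bit values per byte
low = packed & 0xF                                          # lower nibble of each byte
high = packed >> 4                                          # upper nibble of each byte
unpacked = torch.cat((low, high), dim=0).to(torch.float32)  # (8, 8) matrix of 4-bit values
a = torch.randn(8, 8)
out = a @ unpacked  # without the fusion this runs as a separate unpack plus float matmul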


@@ -158,10 +158,10 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
hooks_called = {"enter": False, "exit": False}
def launch_enter_hook(*args):
def launch_enter_hook(lazy_dict):
hooks_called["enter"] = True
def launch_exit_hook(*args):
def launch_exit_hook(lazy_dict):
hooks_called["exit"] = True
CompiledKernel.launch_enter_hook = launch_enter_hook


@@ -28,7 +28,6 @@ import numpy as np
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.aoti_eager
import torch.nn as nn
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.debug_utils import aot_graph_input_parser
@@ -40,16 +39,14 @@ from torch._dynamo.testing import (
skipIfPy312,
)
from torch._dynamo.utils import ifdynstaticdefault
from torch._inductor.aoti_eager import (
aoti_compile_with_persistent_cache,
aoti_eager_cache_dir,
load_aoti_eager_cache,
)
from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
from torch._inductor.fx_passes import pad_mm
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import (
add_scheduler_init_hook,
aoti_compile_with_persistent_cache,
aoti_eager_cache_dir,
load_aoti_eager_cache,
run_and_get_code,
run_and_get_cpp_code,
run_and_get_triton_code,
@@ -772,7 +769,7 @@ class CommonTemplate:
)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_support_out(self):
def test_eager_aoti_support_out(self):
ns = "aten"
op_name = "clamp"
dispatch_key = "CPU"
@@ -824,44 +821,7 @@ class CommonTemplate:
self.assertEqual(ref_out_tensor1, res_out_tensor1)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_support_str(self):
ns = "aten"
op_name = "div"
dispatch_key = "CPU"
device = "cpu"
if self.device.lower() == "cuda":
dispatch_key = "CUDA"
device = "cuda"
a = torch.randn(128, dtype=torch.float, device=device)
b = torch.randn(128, dtype=torch.float, device=device)
rounding_mode_list = ["trunc", "floor"]
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
# Get ref result from eager
ref_value_list = []
for rounding_mode in rounding_mode_list:
ref_value = getattr(torch.ops.aten, op_name)(
a, b, rounding_mode=rounding_mode
)
ref_value_list.append(ref_value)
register_ops_with_aoti_compile(
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
)
# Invoke the pre-compiled kernel and get result.
res_value_list = []
for rounding_mode in rounding_mode_list:
res_value = getattr(torch.ops.aten, op_name)(
a, b, rounding_mode=rounding_mode
)
res_value_list.append(res_value)
for ref_value, res_value in zip(ref_value_list, res_value_list):
self.assertEqual(ref_value, res_value)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_cache_hit(self):
def test_eager_aoti_cache_hit(self):
ns = "aten"
op_name = "abs"
dispatch_key = "CPU"
@@ -886,7 +846,7 @@ class CommonTemplate:
# Patch aoti_compile_with_persistent_cache to None to ensure no new kernel is generated
with mock.patch(
"torch._inductor.aoti_eager.aoti_compile_with_persistent_cache", None
"torch._inductor.utils.aoti_compile_with_persistent_cache", None
):
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
# Get ref result from eager
@@ -902,7 +862,7 @@ class CommonTemplate:
self.assertEqual(ref_value, res_value)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_with_persistent_cache(self):
def test_eager_aoti_with_persistent_cache(self):
def fn(a):
return torch.abs(a)
@@ -946,7 +906,7 @@ class CommonTemplate:
self.assertTrue(kernel_lib_path in kernel_libs_abs_path)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_with_scalar(self):
def test_eager_aoti_with_scalar(self):
namespace_name = "aten"
op_name = "add"
op_overload_name = "Tensor"
@@ -982,18 +942,18 @@ class CommonTemplate:
self.assertTrue(isinstance(op_info, dict))
self.assertTrue("meta_info" in op_info)
self.assertTrue(len(op_info["meta_info"]) == 3)
# Scalar Tensor
self.assertTrue("scalar_value" not in op_info["meta_info"][0])
self.assertTrue(op_info["meta_info"][0]["sizes"] == [])
self.assertTrue(op_info["meta_info"][0]["strides"] == [])
# Scalar Tensor
self.assertTrue("scalar_value" not in op_info["meta_info"][1])
self.assertTrue("scalar_value" not in op_info["meta_info"][0])
self.assertTrue(op_info["meta_info"][1]["sizes"] == [])
self.assertTrue(op_info["meta_info"][1]["strides"] == [])
# Scalar Tensor
self.assertTrue("scalar_value" not in op_info["meta_info"][1])
self.assertTrue(op_info["meta_info"][2]["sizes"] == [])
self.assertTrue(op_info["meta_info"][2]["strides"] == [])
# Scalar
self.assertTrue("scalar_value" in op_info["meta_info"][2])
self.assertTrue("sizes" not in op_info["meta_info"][2])
self.assertTrue("strides" not in op_info["meta_info"][2])
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
a = torch.randn(128, device=device)
@@ -1016,7 +976,7 @@ class CommonTemplate:
self.assertEqual(ref_values, res_values)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_override_registration(self):
def test_eager_aoti_override_registration(self):
namespace_name = "aten"
dispatch_key = "CPU"
device = torch.device("cpu")
@@ -4697,6 +4657,16 @@ class CommonTemplate:
self.common(fn, (x,))
def test_polar(self):
def fn(dist, angle):
return torch.polar(dist, angle)
inp = (
torch.tensor([1, 2], dtype=torch.float64),
torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64),
)
self.common(fn, (*inp,))
def test_cauchy(self):
def fn(x, y):
return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
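The new test_polar above exercises torch.polar, which builds a complex tensor from a magnitude and a phase. A quick standalone check of that semantics, independent of the test harness:

import math
import torch

dist = torch.tensor([1.0, 2.0], dtype=torch.float64)
angle = torch.tensor([math.pi / 2, 5 * math.pi / 4], dtype=torch.float64)
z = torch.polar(dist, angle)  # dist * (cos(angle) + 1j * sin(angle))
expected = torch.complex(dist * torch.cos(angle), dist * torch.sin(angle))
torch.testing.assert_close(z, expected)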
@@ -10167,7 +10137,8 @@ class CommonTemplate:
self.assertEqual(rot.grad, rot_e.grad)
self.assertEqual(trans.grad, trans_e.grad)
@config.patch({"fx_graph_cache": False})
# If we serve from the cache, the init hook isn't called
@config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False})
def test_inner_fn_str_and_stride(self):
def f(x):
x = x + 1


@@ -237,6 +237,7 @@ test_failures = {
"test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_polar_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True),
"test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
"test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_single_elem_dynamic_shapes": TestFailure(("cpu",)),


@@ -411,7 +411,6 @@ inductor_one_sample = {
"_segment_reduce.lengths": {f16},
"_segment_reduce.offsets": {f16},
"addmv": {f16},
"argsort": {b8, f16, f32, f64, i32, i64},
"as_strided.partial_views": {f16},
"corrcoef": {f16},
"diff": {f16},
@@ -426,11 +425,7 @@ inductor_one_sample = {
"logspace": {f16},
"logspace.tensor_overload": {f16, f32, f64, i32, i64},
"masked_logsumexp": {i64},
"max.binary": {b8},
"max_pool2d_with_indices_backward": {f16, f32, f64},
"maximum": {b8},
"min.binary": {b8},
"minimum": {b8},
"new_empty_strided": {f16},
"nn.functional.adaptive_avg_pool3d": {f16},
"nn.functional.adaptive_max_pool1d": {f16, f32},


@@ -14,6 +14,7 @@ from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
@@ -214,6 +215,7 @@ class TritonBlockPointerTest(InductorTestCase):
# Expect 3 block pointers: 2 inputs, 1 output
self.run_and_compare(foo, x, y, expected_num_block_pointers=3)
@skipIfXpu
@parametrize(
"view_size,num_block_pointers,num_triton_kernels",
[


@@ -911,51 +911,6 @@ class TestTracer(JitTestCase):
self.assertEqual(len(list(g.inputs())), 2)
FileCheck().check("mul").check("add").run(str(g))
def test_trace_c10_ops(self):
try:
_ = torch.ops._caffe2.GenerateProposals
except AttributeError:
self.skipTest("Skip the test since c2 ops are not registered.")
class MyModel(torch.nn.Module):
def forward(self, scores, bbox_deltas, im_info, anchors):
a, b = torch.ops._caffe2.GenerateProposals(
(scores),
(bbox_deltas),
(im_info),
(anchors),
2.0,
6000,
300,
0.7,
16,
True,
-90,
90,
1.0,
True,
)
return a, b
model = MyModel()
A = 4
H = 10
W = 8
img_count = 3
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
bbox_deltas = torch.linspace(
0, 10, steps=img_count * 4 * A * H * W, dtype=torch.float32
)
bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
im_info = torch.ones(img_count, 3, dtype=torch.float32)
anchors = torch.ones(A, 4, dtype=torch.float32)
inputs = (scores, bbox_deltas, im_info, anchors)
traced_model = torch.jit.trace(model, inputs)
self.assertEqual(traced_model(*inputs), model(*inputs))
self.assertExportImportModule(
traced_model, (scores, bbox_deltas, im_info, anchors)
)
def run_ge_tests(self, optimize, use_cuda):
with enable_profiling_mode_for_profiling_tests():
with torch.jit.optimized_execution(optimize):


@@ -340,8 +340,8 @@ def xfail(error_message: str, reason: Optional[str] = None):
# skips tests for opset_versions listed in unsupported_opset_versions.
# if the caffe2 test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in caffe2)
# if the PyTorch test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in PyTorch)
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
def skip_dec(func):
@functools.wraps(func)


@@ -873,33 +873,6 @@ class TestOperators(common_utils.TestCase):
x = torch.randn(2, 3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
# Github Issue: https://github.com/pytorch/pytorch/issues/71095
# def test_c2_op(self):
# class MyModel(torch.nn.Module):
# def __init__(self):
# super().__init__()
#
# def forward(self, scores, bbox_deltas, im_info, anchors):
# a, b = torch.ops._caffe2.GenerateProposals(
# (scores), (bbox_deltas), (im_info), (anchors),
# 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
# )
# return a, b
#
# model = MyModel()
# A = 4
# H = 10
# W = 8
# img_count = 3
# scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
# bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
# dtype=torch.float32)
# bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
# im_info = torch.ones(img_count, 3, dtype=torch.float32)
# anchors = torch.ones(A, 4, dtype=torch.float32)
# inputs = (scores, bbox_deltas, im_info, anchors)
# self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
def test_dict(self):
class MyModel(torch.nn.Module):
def forward(self, x_in):


@@ -1358,6 +1358,8 @@ class TestUtilityFuns(_BaseTestCase):
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")
# gelu is exported as onnx::Gelu for opset >= 20
@skipIfUnsupportedMaxOpsetVersion(19)
def test_custom_opsets_gelu(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
@@ -1382,6 +1384,8 @@ class TestUtilityFuns(_BaseTestCase):
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
self.assertEqual(graph.opset_import[1].version, 1)
# gelu is exported as onnx::Gelu for opset >= 20
@skipIfUnsupportedMaxOpsetVersion(19)
def test_register_aten_custom_op_symbolic(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
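Both tests above register a custom symbolic for gelu, which only makes sense while the exporter lacks a native mapping; from opset 20 onward aten::gelu lowers to onnx::Gelu directly, so the tests are capped at opset 19. A minimal sketch of the opset-20 path, assuming the installed torch.onnx exporter supports opset 20 (the module and file name are illustrative):

import torch

class Gelu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

# With opset_version=20 the exporter can emit the native onnx::Gelu operator,
# so no custom symbolic registration is needed.
torch.onnx.export(Gelu(), (torch.randn(2, 3),), "gelu.onnx", opset_version=20)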
