Compare commits


34 Commits

Author SHA1 Message Date
5f81117a5d Update on "[dynamo] Replace tx.inline_user_function_return with call_function"
tx.inline_user_function_return only works for UserDefinedFunction; as we use VariableTracker build more widely, we should use call_function instead.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 15:48:42 -07:00
f340bb91f5 Update base for Update on "[dynamo] Replace tx.inline_user_function_return with call_function"
tx.inline_user_function_return only works for UserDefinedFunction; as we use VariableTracker build more widely, we should use call_function instead.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 15:48:42 -07:00
8daef35cf1 Revert "[Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) (#165433)"
This reverts commit df64c0c4649984093bd1a46f1e9c658c72018200.

Reverted https://github.com/pytorch/pytorch/pull/165433 on behalf of https://github.com/clee2000 due to I think this broke some quantization tests ([comment](https://github.com/pytorch/pytorch/pull/165433#issuecomment-3429741770))
2025-10-21 22:10:19 +00:00
51319ca090 [Pytorch] Add NEON Vectorized<uint> family of translation layers (#165690)
Summary:
Adding NEON specializations of Vectorized<T> for uint8, uint16, uint32 and uint64.

Correctness has been checked using test_ops.py.

operator_benchmark_test.py, which uses the PyTorch API, shows significant improvements in some operations:

Before:

uint8 mul: 1460.751us
uint8 add: 2359.565us
uint8 lsl: 2151.206us

After:

uint8 mul: 194.792us ---> 650% higher throughput
uint8 add: 195.609us ---> 1100% higher throughput
uint8 lsl: 186.249us ---> 1055% higher throughput
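
For reference, a minimal Python micro-benchmark of the affected uint8 element-wise ops. This is not the operator_benchmark configuration used for the numbers above; the tensor shape and iteration count are arbitrary assumptions.

```
import timeit

import torch

# Illustrative only: exercises the uint8 mul/add/left-shift paths touched by this
# change. The shape (2**20 elements) and number=100 are arbitrary choices.
a = torch.randint(0, 255, (1 << 20,), dtype=torch.uint8)
b = torch.randint(1, 8, (1 << 20,), dtype=torch.uint8)

for name, fn in [
    ("uint8 mul", lambda: a * b),
    ("uint8 add", lambda: a + b),
    ("uint8 lsl", lambda: a << b),
]:
    t = timeit.timeit(fn, number=100) / 100
    print(f"{name}: {t * 1e6:.3f}us")
```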

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Reviewed By: mcfi

Differential Revision: D84770153

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165690
Approved by: https://github.com/malfet
2025-10-21 21:46:55 +00:00
d311a3d1dc A temporary fix to autotune out of range and related IMA (#165943)
Summary:
Autotune issue during lowering w/ AOTI:
```
setStorage: sizes [1536, 32, 8192], strides [8192, 8192, 1], storage offset 0, and itemsize 2 requiring a storage size of 25673728 are out of bounds for storage of size 25362432
```
We need a hack to create a new base tensor with sufficient storage.
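
A minimal sketch of the general idea, in Python rather than the actual Inductor/AOTI code: compute the storage a strided view requires and, if the current base is too small, allocate a sufficiently large base first. The helper name and logic below are illustrative assumptions, not the code in this PR.

```
import torch

# Illustrative sketch only (not the AOTI/Inductor fix itself): if sizes/strides/offset
# need more storage than the base tensor provides, rebuild the view on a larger base.
def view_with_sufficient_storage(base, sizes, strides, offset=0):
    needed = offset + sum((s - 1) * st for s, st in zip(sizes, strides)) + 1
    if base.numel() < needed:
        new_base = torch.empty(needed, dtype=base.dtype, device=base.device)
        new_base[: base.numel()].copy_(base.flatten())
        base = new_base
    return base.as_strided(sizes, strides, offset)
```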

Test Plan: Finally able to see the e2e test pass on CI. See the detailed Test Plan in D83520844.

Differential Revision: D84872792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165943
Approved by: https://github.com/laithsakka
2025-10-21 21:40:20 +00:00
04adfe5ba9 Make Backend::setGroupUid virtual (#165957)
As titled, so that we may customize this function in custom backends.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165957
Approved by: https://github.com/d4l3k
2025-10-21 21:33:24 +00:00
4be1e3bf92 [AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)
This PR refactors the autocast context manager in autocast_mode.py to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping device_supported_dtypes is used to associate device types with their supported dtypes, and the validation logic is unified.
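
A minimal sketch of the pattern described above; the dtype values below are illustrative assumptions, not the exact mapping in autocast_mode.py.

```
import warnings

import torch

# Sketch only: a single dict drives dtype validation for every device type
# instead of repeating per-device if/else branches (values are illustrative).
device_supported_dtypes = {
    "cpu": (torch.bfloat16, torch.float16),
    "cuda": (torch.float16, torch.bfloat16, torch.float32),
}

def validate_autocast_dtype(device_type: str, dtype: torch.dtype) -> bool:
    supported = device_supported_dtypes.get(device_type, ())
    if dtype not in supported:
        warnings.warn(
            f"In {device_type} autocast, but the target dtype {dtype} is not supported. "
            "Disabling autocast."
        )
        return False
    return True
```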

**The former PR #163446 was merged but reverted due to failing CI on `openreg`-related tests.**

This PR additionally makes slight modifications to some test assertions so that the CI tests pass. CI previously failed because an assertion expected the exact error message. For example:
```
File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
    with self.assertWarnsRegex(
        AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```

Sorry for the inconvenience again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 21:32:12 +00:00
e7592f4005 [CI] Move the periodic debug tests to newer runner (#165158)
Previously: g3 = NVIDIA Tesla M60
Now: g6 = NVIDIA L4
Also changed the CUDA arch list accordingly.

Pros:
More memory, newer GPU

Cons:
That was one of the few remaining tests on g3 runners, so we probably lost coverage?

We can probably run more tests in parallel now, but I'm not going to do that here.

Disabled a bunch of sparse and nestedtensor tests that were previously skipped, presumably due to insufficient hardware(?). They are now failing with:
```
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3293, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3292, in wrapper
    with policy():
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2532, in __enter__
    self.beforeStreams[-1].synchronize()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/cuda/streams.py", line 105, in synchronize
    super().synchronize()
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from stream_synchronize at /var/lib/jenkins/workspace/c10/cuda/CUDAFunctions.h:120 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, unsigned int, bool) [clone .cold] from CUDAException.cpp:0
#7 THCPStream_synchronize(_object*, _object*) from Stream.cpp:0
#8 cfunction_vectorcall_NOARGS from /usr/local/src/conda/python-3.10.14/Objects/methodobject.c:489
#9 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#10 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
#11 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114
#12 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:46
```
When run with CUDA_LAUNCH_BLOCKING=1, I got a ton of output like:
```

/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [2,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [5,3,0], thread: [3,7,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,0,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,1,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [2,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,2,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [0,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,3,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [1,4,0] Assertion `value < upper_bound` failed.
/var/lib/jenkins/workspace/third_party/cutlass/include/cutlass/integer_subbyte.h:124: cutlass::integer_subbyte<Bits, Signed>::integer_subbyte(unsigned int) [with int Bits = 2; __nv_bool Signed = false]: block: [3,8,0], thread: [3,4,0] Assertion `value < upper_bound` failed.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165158
Approved by: https://github.com/seemethere
2025-10-21 21:28:12 +00:00
d334c3649d [CUDA] fix reflection padding for large batch size (#165942)
Fixes [#165861](https://github.com/pytorch/pytorch/issues/165861)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165942
Approved by: https://github.com/eqy
2025-10-21 21:07:38 +00:00
9f82535c5a [ROCm] [Normalization] Update block size (#165941)
* Seeing up to 6x improvement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165941
Approved by: https://github.com/jeffdaily
2025-10-21 20:53:05 +00:00
5b35fc8777 Support multiple commits on push events in trunk tagging workflow (#165937)
Context:
* this workflow is used to create tags like `trunk/{sha}` for all `main` commits
* those tags are used by [autorevert](https://github.com/pytorch/test-infra/blob/main/aws/lambda/pytorch-auto-revert/README.md) to rerun selected workflows

Problem: currently the workflow creates only a single tag per push event, while ghstack pushes multiple commits per single push.

This PR supports tag creation for all commits in the push event.

Complementary autorevert PR: https://github.com/pytorch/test-infra/pull/7291
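
For illustration, a Python sketch of the per-commit tagging idea; the real implementation is the bash workflow step shown in the diff further below, and this helper itself is not part of the PR.

```
import subprocess

# Illustrative sketch only: mirrors what the bash workflow step does.
# before_sha/after_sha correspond to the push event's before/after fields.
def tag_push_range(before_sha: str, after_sha: str) -> None:
    commits = subprocess.run(
        ["git", "rev-list", "--reverse", f"{before_sha}..{after_sha}"],
        check=True, capture_output=True, text=True,
    ).stdout.split()
    for sha in commits:
        tag = f"trunk/{sha}"
        exists = subprocess.run(
            ["git", "rev-parse", "-q", "--verify", f"refs/tags/{tag}"],
            capture_output=True,
        ).returncode == 0
        if exists:
            continue  # idempotent: skip tags that already exist
        subprocess.run(["git", "tag", tag, sha], check=True)
        subprocess.run(["git", "push", "origin", tag], check=True)
```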

---

### Testing

I created an identical copy of this workflow in my personal repo: https://github.com/izaitsevfb/pr-head-test/actions/workflows/trunk-tagging.yml

See action runs there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165937
Approved by: https://github.com/huydhn
2025-10-21 20:52:34 +00:00
2f38eece7c [CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).
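
For context, a user-level sketch of the two paths being refactored. DISABLE_ADDMM_CUDA_LT is the environment switch read in this file (see the diff below); the shapes and dtype are arbitrary assumptions, and Lt eligibility additionally depends on the input checks consolidated in this PR.

```
import os

# Must be set before the first CUDA addmm call: the flag is read only once.
os.environ.setdefault("DISABLE_ADDMM_CUDA_LT", "0")

import torch

if torch.cuda.is_available():
    m, k, n = 128, 256, 64  # arbitrary example shapes
    bias = torch.randn(n, device="cuda", dtype=torch.half)    # 1-D bias: Lt-eligible epilogue path
    mat1 = torch.randn(m, k, device="cuda", dtype=torch.half)
    mat2 = torch.randn(k, n, device="cuda", dtype=torch.half)
    # With DISABLE_ADDMM_CUDA_LT=1 this takes the plain cuBLAS GEMM path instead.
    out = torch.addmm(bias, mat1, mat2)
```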

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-21 20:48:12 +00:00
830e789a55 [dynamo][annotate] Graph break cleanly on fx.traceback.annotate reconstruction (#166006)
This avoids generating bad bytecode, which led to a really confusing error. I am not sure why we can't reconstruct cleanly; it has to do with the input being a dict, while other supported ctx managers take bools.

Fixing that is for another day. Let's give a good error message for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166006
Approved by: https://github.com/yushangdi, https://github.com/SherlockNoMad
2025-10-21 20:48:04 +00:00
ad4dc52bf6 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit 4e643422f63a3cdd71bd141615f98de6bb54d15f.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/albanD due to Breaks lint ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3429426503))
2025-10-21 20:24:14 +00:00
dac9ed9790 Bump uv from 0.8.6 to 0.9.5 in /.ci/lumen_cli (#166017)
Bumps [uv](https://github.com/astral-sh/uv) from 0.8.6 to 0.9.5.
- [Release notes](https://github.com/astral-sh/uv/releases)
- [Changelog](https://github.com/astral-sh/uv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/uv/compare/0.8.6...0.9.5)

---
updated-dependencies:
- dependency-name: uv
  dependency-version: 0.9.5
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-21 13:16:30 -07:00
1c7fe8f861 [BugFix] chunk_size should always be int64_t (#165971)
Inspired by https://github.com/pytorch/pytorch/pull/156872
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165971
Approved by: https://github.com/albanD
2025-10-21 19:52:47 +00:00
4e643422f6 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)
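
The exact Python signature of the new shrink_group binding is not shown in this listing, so the snippet below is a conceptual stand-in only: it builds a communicator over the surviving ranks with the existing new_group API, whereas ncclCommShrink reaches the same end state by deriving the smaller communicator from the existing one instead of initializing from scratch.

```
import torch.distributed as dist

# Conceptual stand-in only (not the new shrink_group API). Assumes the default
# process group was already initialized via dist.init_process_group("nccl", ...).
def exclude_ranks(world_size: int, bad_ranks: set):
    surviving = [r for r in range(world_size) if r not in bad_ranks]
    return dist.new_group(ranks=surviving, backend="nccl")
```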

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-21 19:47:33 +00:00
3c3b278872 [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)
Relands #148261 that was reverted by #150542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882
Approved by: https://github.com/ezyang
2025-10-21 19:43:55 +00:00
dd0138db0b [dynamo] Replace tx.inline_user_function_return with call_function
tx.inline_user_function_return only works for UserDefinedFunction; as we use VariableTracker build more widely, we should use call_function instead.

[ghstack-poisoned]
2025-10-21 11:04:27 -07:00
c25828c0fa Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 08:36:43 -07:00
1383dd025a Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 08:36:43 -07:00
b75a2bac08 Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 22:51:57 -07:00
fcc5040fbf Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 22:51:57 -07:00
1476d60587 Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 18:50:38 -07:00
fb6f547840 Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 18:50:38 -07:00
7ee521258d Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 18:01:55 -07:00
9481632922 Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 18:01:55 -07:00
03d0cb831d Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 17:54:31 -07:00
b0b0860505 Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 17:54:31 -07:00
456a3934c1 Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 16:09:32 -07:00
47b6d8f026 Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 16:09:32 -07:00
b962834f9d Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 15:50:44 -07:00
e2f1ae8b4f Update base for Update on "[dynamo][remaining] Replace UserFunctionVariable with VariableTracker build"
Audit: To prevent future issues with functools.partial or callable objects.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 15:50:44 -07:00
60880a6fe2 [dynamo][remaining] Replace UserFunctionVariable with VariableTracker build
Audit: To prevent future issues with functools.partial or callable objects.

[ghstack-poisoned]
2025-10-20 00:42:21 -07:00
59 changed files with 2461 additions and 1641 deletions

View File

@@ -1,11 +1,15 @@
sphinx==7.2.6
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 7.2.6
#Pinned versions: 5.3.0
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.36.0
breathe==4.34.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.36.0
#Pinned versions: 4.34.0
exhale==0.3.7
exhale==0.2.3
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.3.7
#Pinned versions: 0.2.3
docutils==0.20
docutils==0.16
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.20
#Pinned versions: 0.16
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@@ -52,13 +56,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==1.3.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 1.3.0
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.6.1
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==4.0.1
myst-parser==0.18.1

View File

@@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.8.6"
"uv==0.9.5"
]
[tool.setuptools]

View File

@@ -102,18 +102,8 @@ if [ "$is_main_doc" = true ]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo "======================================"
echo "ERROR: $undocumented undocumented objects found!"
echo "======================================"
echo ""
echo "Full coverage report:"
echo undocumented objects found:
cat build/coverage/python.txt
echo ""
echo "======================================"
echo "Undocumented modules/objects (lines after TOTAL):"
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
echo "======================================"
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1

View File

@@ -147,15 +147,16 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

View File

@@ -58,8 +58,10 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@@ -87,7 +89,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag with retry
- name: Create and push tag(s) with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@@ -112,14 +114,23 @@ jobs:
return 1
}
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
# Retry configuration
MAX_RETRIES=5
@@ -194,31 +205,111 @@ jobs:
}
}
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
exit 0
else
echo "Tag creation failed after all retry attempts"
exit 1
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
fi

View File

@@ -9,6 +9,7 @@
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@@ -0,0 +1,378 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<uint##bit##_t> { \
using neon_type = uint##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = uint##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_u##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(uint##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
values = vld1q_u##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<uint##bit##_t> loadu( \
const void* ptr, \
uint64_t count = size()); \
void store(void* ptr, uint64_t count = size()) const; \
template <uint64_t mask> \
static Vectorized<uint##bit##_t> blend( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b); \
static Vectorized<uint##bit##_t> blendv( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
const Vectorized<uint##bit##_t>& mask_) { \
return vbslq_u##bit(mask_.values, b, a); \
} \
template <typename step_t> \
static Vectorized<uint##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<uint##bit##_t> set( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
uint64_t count = size()); \
const uint##bit##_t& operator[](uint idx) const = delete; \
uint##bit##_t& operator[](uint idx) = delete; \
Vectorized<uint##bit##_t> abs() const { \
return values; \
} \
Vectorized<uint##bit##_t> real() const { \
return values; \
} \
Vectorized<uint##bit##_t> imag() const { \
return vdupq_n_u##bit(0); \
} \
Vectorized<uint##bit##_t> conj() const { \
return values; \
} \
Vectorized<uint##bit##_t> neg() const { \
return vreinterpretq_u##bit##_s##bit( \
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
} \
uint##bit##_t reduce_add() const { \
return vaddvq_u##bit(values); \
} \
uint##bit##_t reduce_max() const; \
Vectorized<uint##bit##_t> operator==( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator!=( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> operator<( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator<=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> eq( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ne( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> gt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ge( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> lt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> le( \
const Vectorized<uint##bit##_t>& other) const; \
}; \
template <> \
Vectorized<uint##bit##_t> inline operator+( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vaddq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator-( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vsubq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator&( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vandq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator|( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vorrq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator^( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return veorq_u##bit(a, b); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this == other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this != other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this > other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this < other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
}
VEC_UINT_NEON_TEMPLATE(16, 8)
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
return vmaxvq_u8(values);
}
template <>
Vectorized<uint8_t> inline operator*(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmulq_u8(a, b);
}
template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
return vmvnq_u8(a);
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
const Vectorized<uint8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<uint8_t> inline minimum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vminq_u8(a, b);
}
template <>
Vectorized<uint8_t> inline maximum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmaxq_u8(a, b);
}
template <uint64_t mask>
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
#define VEC_UINT_NEON_OPS(vl, bit) \
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
values = vdupq_n_u##bit(val); \
} \
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
const void* ptr, uint64_t count) { \
if (count == size()) { \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
} else { \
__at_align__ uint##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const uint##bit##_t*>(ptr), \
count * sizeof(uint##bit##_t)); \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
const { \
if (count == size()) { \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
} else { \
uint##bit##_t tmp_values[size()]; \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
} \
}
VEC_UINT_NEON_OPS(16, 8)
template <typename step_t>
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
uint8_t base,
step_t step) {
const Vectorized<uint8_t> base_vec(base);
const Vectorized<uint8_t> step_vec(step);
const uint8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_u8(base_vec, step_sizes, step_vec);
}
template <>
Vectorized<uint8_t> inline operator>>(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return x >> z;
}
template <>
Vectorized<uint8_t> inline operator<<(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return vshlq_u8(a, vreinterpretq_s8_u8(z));
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b,
uint64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
}
template <>
Vectorized<uint8_t> inline operator/(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t y = b;
return x / y;
}
template <>
Vectorized<uint8_t> inline clamp(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min,
const Vectorized<uint8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<uint8_t> inline clamp_max(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<uint8_t> inline clamp_min(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u8x8 = vget_low_u8(src);
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u8x8 = vget_low_u8(src);
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

View File

@@ -272,28 +272,110 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies 2D bias which we currently do not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be relaxed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@@ -335,7 +417,70 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@@ -345,105 +490,62 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be to set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
// if lt path fails, we recurse back into this function here and force the lt path to off
// we cannot update variable disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions is to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// In the non-Lt path we expand self even before
// checking for beta != 0.0, to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// stays green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
}
}
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
// Short circuit on empty result
if (result.numel() == 0) {
return result;
}
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@ -455,158 +557,64 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
}
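This short-circuit follows the documented addmm semantics for an empty reduction dimension: with beta == 0 the contents of self are ignored (NaNs/Infs must not leak into the result), otherwise the result is simply beta * self because the matmul term contributes nothing. A quick eager-mode illustration (shapes are arbitrary; shown on CPU, the CUDA branch above implements the same contract):
import torch

mat1 = torch.randn(4, 0)                     # k == 0: the reduction dimension is empty
mat2 = torch.randn(0, 5)
nan_bias = torch.full((4, 5), float("nan"))

# beta == 0: self is ignored, so the NaNs do not propagate
print(torch.addmm(nan_bias, mat1, mat2, beta=0.0))           # all zeros
# beta != 0: result is just beta * self, since mat1 @ mat2 is all zeros
print(torch.addmm(torch.ones(4, 5), mat1, mat2, beta=2.0))   # all twos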
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for the mm case on ROCm's hipblasLt path.
// Pass the bias ptr as null to avoid accuracy issues in the mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
// end Lt path
} else {
// No Lt, we use a GEMM instead
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@ -614,28 +622,12 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@ -647,14 +639,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
}
} // end GEMM path
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (useLtInterface && activation == Activation::GELU) {
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif

View File

@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
constexpr int MAX_BLOCK_SIZE = 1024;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
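Only the bucket tables change in this hunk; the selection logic inside getNumThreads is not shown here. A hedged Python sketch, assuming the usual rule of picking the smallest bucket that covers nElem and falling back to MAX_BLOCK_SIZE:
# Assumed selection rule (not shown in the hunk above): smallest bucket >= nElem,
# otherwise MAX_BLOCK_SIZE. The tuple below is the new ROCm table, MAX_BLOCK_SIZE = 1024.
ROCM_THREAD_SIZES = (64, 128, 256, 512, 1024)

def get_num_threads(n_elem: int, sizes=ROCM_THREAD_SIZES) -> int:
    for s in sizes:
        if n_elem <= s:
            return s
    return sizes[-1]

print(get_num_threads(100))    # 128
print(get_num_threads(5000))   # 1024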

View File

@ -92,6 +92,16 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
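reflect_index maps an arbitrary, possibly negative column coordinate back into [0, len) by folding it over the period 2 * (len - 1), with a guard for the degenerate single-element case. A direct Python transcription with a small worked example:
def reflect_index(x: int, length: int) -> int:
    two = (length - 1) * 2
    if two <= 0:              # a single-element row always reflects to index 0
        return 0
    m = x % two
    if m < 0:                 # kept for parity with the C++, where % can be negative
        m += two
    return m if m < length else two - m

# input_w = 5, pad_l = 2, pad_r = 2: output positions -2..6 map to input columns
print([reflect_index(x, 5) for x in range(-2, 7)])   # [2, 1, 0, 1, 2, 3, 4, 3, 2]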
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@ -106,6 +116,28 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@ -710,25 +742,44 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
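The launch code above keeps the original one-row-of-blocks-per-(plane, batch) 3-D grid whenever it fits within the device grid limits, and otherwise falls back to the flat grid-stride kernel, which linearizes (plane, x) and caps grid.x at the device maximum. A rough Python sketch of the same sizing arithmetic, with typical CUDA grid limits supplied as an assumed default:
from math import ceil

def plan_reflection_pad1d_launch(output_w, nplane, nbatch,
                                 max_grid=(2**31 - 1, 65535, 65535)):  # assumed typical limits
    block_x = min(256, max(1, output_w))
    gx = ceil(output_w / block_x)
    max_x, max_y, max_z = max_grid
    if nplane <= max_y and nbatch <= max_z and gx <= max_x:
        return "3d", (gx, nplane, nbatch), block_x
    plane_count = nplane * nbatch
    total_blocks = ceil(plane_count * output_w / block_x)
    return "flat", (min(max_x, max(1, total_blocks)), 1, 1), block_x

# A plane count above the 65535 y-limit forces the flat fallback:
print(plan_reflection_pad1d_launch(output_w=4, nplane=70000, nbatch=1))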
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.1
update_hint_regression,compile_time_instruction_count,1645000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1


View File

@ -48,17 +48,89 @@ PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,Fa
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.024000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
@ -71,6 +143,9 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000


View File

@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
"dtype_one": [torch.int32, torch.uint8],
"dtype_two": [torch.int32, torch.uint8],
},
tags=["short"],
)
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
dtype_one=[torch.int8, torch.int32, torch.uint8],
dtype_two=[torch.int8, torch.int32, torch.uint8],
tags=["long"],
)
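Adding torch.uint8 to these dtype lists extends the (dtype_one, dtype_two) cross-product; for the short config that is now a 2x2 grid, which is why the benchmark CSV above gains int32/uint8, uint8/int32, and uint8/uint8 rows for every binary op. The short-config pairs it generates:
import itertools
import torch

dtype_one = [torch.int32, torch.uint8]
dtype_two = [torch.int32, torch.uint8]
print(list(itertools.product(dtype_one, dtype_two)))
# [(torch.int32, torch.int32), (torch.int32, torch.uint8),
#  (torch.uint8, torch.int32), (torch.uint8, torch.uint8)]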

View File

@ -207,42 +207,6 @@ templates_path = [
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
@ -3229,11 +3193,6 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,6 +253,7 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -8,7 +8,8 @@ class TestAutocast(TestCase):
def test_autocast_with_unsupported_type(self):
with self.assertWarnsRegex(
UserWarning,
"In openreg autocast, but the target dtype torch.float32 is not supported.",
"In openreg autocast, but the target dtype is not supported."
"openreg Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
):
with torch.autocast(device_type="openreg", dtype=torch.float32):
_ = torch.ones(10)

View File

@ -67,7 +67,21 @@ class TestFullyShardMemory(FSDPTest):
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
# NOTE: before https://github.com/pytorch/pytorch/pull/163955,
# the input shape was (1, 768), so that the forward gemm used
# cublaslt, and the backward used cublas.
# With the aforementioned PR, and with shape (1, 768),
# the cublas path is used both in forward and in backward,
# altering the peak memory usage, which no longer accounts for the cublaslt workspace.
# Here we change the input shape to (2, 768), and that swaps
# the cublas/cublaslt selection in the forward/backward,
# but that does not affect the peak memory usage stored in `base_mem_mb`.
# Reasons for the flip:
# before PR: no Lt in addmm when mat2 has nrows/ncols <= 1,
# after PR: no Lt in addmm when either mat1 or mat2 has nrows/ncols <= 1,
# since the input preparation can swap matrices based on output
# row-/col-majorness.
inp = torch.randn(2, 768, device=device_type)
lin(inp).sum().backward()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()
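A sketch of the selection condition described in the note above (a restatement of the comment, not the actual C++ predicate): before the referenced PR only mat2 was checked for a degenerate row/column count, afterwards either operand disqualifies the Lt bias path, because input preparation may swap the matrices.
def lt_bias_path_allowed(mat1_shape, mat2_shape) -> bool:
    # Post-PR rule as described in the note: skip the Lt (bias-fused) path when
    # either operand has nrows or ncols <= 1.
    def degenerate(shape):
        return shape[0] <= 1 or shape[1] <= 1
    return not (degenerate(mat1_shape) or degenerate(mat2_shape))

print(lt_bias_path_allowed((1, 768), (768, 768)))   # False -> plain cublas gemm
print(lt_bias_path_allowed((2, 768), (768, 768)))   # True  -> cublaslt path possible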

View File

@ -288,6 +288,18 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)
def test_graph_break(self):
def fn(x):
with torch.fx.traceback.annotate({"pp_stage": 0}):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(10, requires_grad=True)
self.assertEqual(fn(x), opt_fn(x))
if __name__ == "__main__":
run_tests()

View File

@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
def test_mps_autocast_error_message(self):
with self.assertWarnsRegex(
UserWarning,
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
):
with torch.autocast(device_type="mps", dtype=torch.float32):
_ = torch.ones(10)

View File

@ -6,6 +6,7 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@ -19,6 +20,7 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):
self.assertTrue(neg not in relu.users)
@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))
relu.prepend(neg)
ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()
self.assertIsNone(ref())
def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x")

View File

@ -7381,6 +7381,10 @@ torch.cuda.synchronize()
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@parametrize("use_legacy_api", [True, False])
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_dummy_mha_with_nt(self, device, use_legacy_api):
bs = 3
d1 = 2

View File

@ -8490,6 +8490,14 @@ class TestNNDeviceType(NNTestCase):
y_cuda_contig = pool(x_cuda.contiguous())
self.assertEqual(y_cuda_ch_last, y_cuda_contig)
@onlyCUDA
def test_large_reflect_pad(self, device):
# https://github.com/pytorch/pytorch/issues/165861
x = torch.rand(2**16, 2, device="cuda")
c = F.pad(x, (1, 1), mode="reflect")
c_cpu = F.pad(x.cpu(), (1, 1), mode="reflect")
self.assertEqual(c, c_cpu)
@onlyCUDA
@largeTensorTest("48GB", "cpu")
@largeTensorTest("48GB", "cuda")

View File

@ -247,6 +247,10 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@ -576,6 +580,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_prune_dense_static_sort(self, dtype) -> None:
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@ -621,6 +629,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
inp = torch.tensor(
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@ -658,6 +670,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@parametrize_backends
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
M, N = 128, 256
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
@ -692,6 +708,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_id(self, dtype) -> None:
N = 512
torch.manual_seed(0)
@ -729,6 +749,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_pack_both_ways_edge_case1(self, dtype) -> None:
# In this case, the heuristic will keep 7 values out of 16
# instead of 8. let's see how the kernel handles this
@ -754,6 +778,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -770,6 +798,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_apply_dense(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
@ -808,6 +840,10 @@ class TestSparseSemiStructuredTraining(TestCase):
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls(self, dtype) -> None:
M, N, K = 64, 256, 1024
a = torch.randn([M, K], device="cuda", dtype=dtype)
@ -843,6 +879,10 @@ class TestSparseSemiStructuredTraining(TestCase):
)
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_mat_vec(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
@ -853,6 +893,10 @@ class TestSparseSemiStructuredTraining(TestCase):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
@unittest.skipIf(
"RelWithAssert" in torch.__config__.show(),
"failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)
def test_sp24_matmuls_bmm(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)

View File

@ -2758,6 +2758,12 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -2810,5 +2810,15 @@
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
],
"GB0279": [
{
"Gb_type": "torch.fx.traceback.annotate escaped from compiled region",
"Context": "str(self)",
"Explanation": "Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -3502,10 +3502,8 @@ class InstructionTranslatorBase(
if isinstance(excp, Unsupported):
excp.remove_from_stats()
self.push(
self.inline_user_function_return(
VariableTracker.build(self, impl_CONTAINS_OP_fallback),
[left, right],
{},
VariableTracker.build(self, impl_CONTAINS_OP_fallback).call_function(
self, [left, right], {}
)
)
if op == 1:

View File

@ -745,9 +745,9 @@ class BuiltinVariable(VariableTracker):
)
def handler(tx, a, b):
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {}
)
return VariableTracker.build(
tx, polyfill_fn_mapping[op]
).call_function(tx, [a, b], {})
result.append(((VariableTracker, VariableTracker), handler))
return result
@ -1559,19 +1559,18 @@ class BuiltinVariable(VariableTracker):
)
else:
# Overrides for custom str method
# Pass method as function to call tx.inline_user_function_return
bound_method = str_method.__func__ # type: ignore[attr-defined]
try:
# Only supports certain function types
user_func_variable = variables.UserFunctionVariable(bound_method)
user_func_variable = VariableTracker.build(tx, bound_method)
except AssertionError:
# Won't be able to do inline the str method, return to avoid graph break
log.warning("Failed to create UserFunctionVariable", exc_info=True)
return
# Inline the user function
return tx.inline_user_function_return(user_func_variable, [arg], {})
return user_func_variable.call_function(tx, [arg], {})
elif isinstance(arg, (variables.ExceptionVariable,)):
if len(arg.args) == 0:
value = f"{arg.exc_type}"
@ -1925,8 +1924,8 @@ class BuiltinVariable(VariableTracker):
# VT(foo.__dict__). This simplifies the construction of the new
# dict.
args[0] = args[0].get_forwarded_dict(tx)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
return VariableTracker.build(tx, polyfills.construct_dict).call_function(
tx,
[VariableTracker.build(tx, user_cls), *args],
kwargs,
)
@ -2022,7 +2021,7 @@ class BuiltinVariable(VariableTracker):
):
iter_fn = arg.var_getattr(tx, "__iter__")
if isinstance(iter_fn, variables.UserMethodVariable):
out = tx.inline_user_function_return(iter_fn, args, kwargs)
out = iter_fn.call_function(tx, list(args), kwargs)
if isinstance(out, SetVariable):
return out
return BuiltinVariable(set).call_set(tx, out)

View File

@ -1295,6 +1295,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
def fn_name(self):
return "annotate"
def reconstruct_type(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.fx.traceback.annotate escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""

View File

@ -189,8 +189,8 @@ class ItertoolsVariable(VariableTracker):
*args, mutation_type=ValueMutationNew()
)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.repeat), args, kwargs
return VariableTracker.build(tx, polyfills.repeat).call_function(
tx, args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(

View File

@ -181,11 +181,9 @@ class BaseListVariable(VariableTracker):
if not len(args):
raise_args_mismatch(tx, name)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.index),
[self] + list(args),
kwargs,
)
return VariableTracker.build(tx, polyfills.index).call_function(
tx, [self] + list(args), kwargs
)
elif name == "count":
if len(args) != 1:
raise_args_mismatch(tx, name)

View File

@ -542,11 +542,9 @@ class NNModuleVariable(VariableTracker):
args = [self] + args
else:
assert istype(fn, types.FunctionType)
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=fn_source),
args,
kwargs,
)
return variables.UserFunctionVariable(
fn, source=fn_source
).call_function(tx, args, kwargs)
def call_method(
self,
@ -773,8 +771,8 @@ class NNModuleVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)
src = AttrSource(AttrSource(self.source, name), "__func__")
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=src),
return variables.UserFunctionVariable(fn, source=src).call_function(
tx,
[self] + list(args),
kwargs,
)
@ -851,8 +849,8 @@ class NNModuleVariable(VariableTracker):
# Inline the function
fn = getattr(module, name).__func__
fn_source = AttrSource(AttrSource(self.source, name), "__func__")
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=fn_source),
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx,
[self] + args,
kwargs,
)
@ -951,13 +949,18 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
# The program can mutate the nn module object but the saved `value`
# will not reflect the mutations. So, trace through the `__iter__`
# function to reflect any tracked mutations.
return tx.inline_user_function_return(
VariableTracker.build(tx, fn),
[
self,
],
{},
).unpack_var_sequence(tx)
return (
VariableTracker.build(tx, fn)
.call_function(
tx,
[
self,
],
{},
)
.unpack_var_sequence(tx)
)
return super().unpack_var_sequence(tx)

View File

@ -1085,8 +1085,8 @@ class TensorVariable(VariableTracker):
if value is not None:
from .. import polyfills
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.addcmul_inplace),
return VariableTracker.build(tx, polyfills.addcmul_inplace).call_function(
tx,
[self, tensor1, tensor2, value],
{},
)

View File

@ -568,16 +568,16 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
@register(torch.ops.inductor.accumulate_grad_.default)
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
return VariableTracker.build(tx, polyfills.accumulate_grad).call_function(
tx, args, kwargs
)
@register(math.radians)
def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs):
if not check_unspec_or_constant_args(args, kwargs):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.radians), args, kwargs
return VariableTracker.build(tx, polyfills.radians).call_function(
tx, args, kwargs
)
@register(torch.is_inference_mode_enabled)
@ -829,8 +829,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
_, tx: "InstructionTranslator", *args, **kwargs
):
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
return VariableTracker.build(
tx, polyfills.foreach_lerp_inplace
).call_function(
tx,
args,
kwargs,
)
@ -840,8 +842,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# In eager it's more performant to call item() from within the C op implementation
# in compile, it's more performant to not graph break.
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.foreach_pow_scalar),
return VariableTracker.build(
tx, polyfills.foreach_pow_scalar
).call_function(
tx,
args,
kwargs,
)
@ -1968,8 +1972,8 @@ class FuncTorchInterpreterVariable(BaseTorchVariable):
if name == "key":
return variables.EnumVariable(self.value.key())
elif name == "process":
return tx.inline_user_function_return(
variables.UserFunctionVariable(self.value.process.__func__),
return VariableTracker.build(tx, self.value.process.__func__).call_function(
tx,
[self] + args,
kwargs,
)

View File

@ -59,7 +59,7 @@ from ..utils import (
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .functions import UserMethodVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -620,7 +620,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
elif isinstance(attr, property):
getter_source = AttrSource(attr_source, "fget")
getter = attr.fget
getter_var = UserFunctionVariable(getter, source=getter_source)
getter_var = VariableTracker.build(tx, getter, source=getter_source)
return getter_var.call_function(tx, [self], {})
elif isinstance(attr, classmethod):

View File

@ -490,8 +490,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
return NullContextVariable(*args, **kwargs)
elif self.value is collections.OrderedDict:
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
return VariableTracker.build(tx, polyfills.construct_dict).call_function(
tx,
[self, *args],
kwargs,
)
@ -823,10 +823,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.MappingProxyVariable(args[0])
elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source:
with do_not_convert_to_tracable_parameter():
return tx.inline_user_function_return(
VariableTracker.build(
tx, polyfills.instantiate_user_defined_class_object
),
return VariableTracker.build(
tx, polyfills.instantiate_user_defined_class_object
).call_function(
tx,
[self, *args],
kwargs,
)
@ -1803,8 +1803,8 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
) -> "VariableTracker":
fn_variable = variables.UserFunctionVariable(self.value.forward.__func__)
args = [self] + args
return tx.inline_user_function_return(
fn_variable,
return fn_variable.call_function(
tx,
args,
kwargs,
)

View File

@ -3147,35 +3147,82 @@ class AlgorithmSelectorCache(PersistentCache):
for i, x in enumerate(input_nodes)
}
example_inputs = list(unique_example_inputs.values())
example_inputs_extern = [
(
unique_example_inputs[input_node.get_name()]
if unique_example_inputs[input_node.get_name()].is_mkldnn
else torch.as_strided(
unique_example_inputs[input_node.get_name()],
V.graph.sizevars.size_hints(
input_node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
input_node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
example_inputs_extern = []
for input_node in input_nodes:
if unique_example_inputs[input_node.get_name()].is_mkldnn:
example_inputs_extern.append(
unique_example_inputs[input_node.get_name()]
)
else:
base = unique_example_inputs[input_node.get_name()]
base = base if base._base is None else base._base
sizes = tuple(
V.graph.sizevars.atomically_apply_size_hint(
size,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
for size in input_node.get_size()
)
strides = tuple(
V.graph.sizevars.atomically_apply_size_hint(
stride,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
for stride in input_node.get_stride()
)
storage_offset = V.graph.sizevars.atomically_apply_size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
# Check if the required storage size exceeds the current storage
# to avoid illegal memory access
needed_size = torch._prims_common.compute_required_storage_length(
sizes, strides, storage_offset
)
current_size = base.storage().size()
if needed_size > current_size:
# Create a new base tensor with sufficient storage
new_base = torch.randn(
needed_size,
dtype=base.dtype,
device=base.device,
requires_grad=base.requires_grad,
)
base = new_base.as_strided(
base.size(), base.stride(), base.storage_offset()
)
example_inputs_extern.append(
torch.as_strided(base, sizes, strides, storage_offset)
)
)
for input_node in input_nodes
]
out = cls.benchmark_example_value(layout, hint_override=hint_override)
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
# Also check the output tensor for storage size
out_base = out if out._base is None else out._base
out_offset = V.graph.sizevars.size_hint(layout.offset)
needed_out_size = torch._prims_common.compute_required_storage_length(
out.size(), out.stride(), out_offset
)
current_out_size = out_base.storage().size()
if needed_out_size > current_out_size:
# Create a new base tensor with sufficient storage
new_out_base = torch.randn(
needed_out_size,
dtype=out_base.dtype,
device=out_base.device,
requires_grad=out_base.requires_grad,
)
out_base = new_out_base.as_strided(
out_base.size(), out_base.stride(), out_base.storage_offset()
)
out_extern = torch.as_strided(out_base, out.size(), out.stride(), out_offset)
expected = None
if VERIFY:
choices[0].benchmark(*example_inputs_extern, out=out_extern)
@ -3616,10 +3663,13 @@ class AlgorithmSelectorCache(PersistentCache):
# So we need call as_strided in the end to 'view' the tensor with the correct
# sizes/strides
return AlgorithmSelectorCache.generate_example_value(
V.graph.sizevars.size_hints(
node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
tuple(
V.graph.sizevars.atomically_apply_size_hint(
size,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
for size in node.get_size()
),
tuple(
V.graph.sizevars.atomically_apply_size_hint(
@ -3632,13 +3682,20 @@ class AlgorithmSelectorCache(PersistentCache):
node.get_device(),
node.get_dtype(),
# pyrefly: ignore # missing-attribute
node.layout.offset,
V.graph.sizevars.size_hints(
# pyrefly: ignore # bad-argument-type
V.graph.get_allocation_size(node),
V.graph.sizevars.atomically_apply_size_hint(
node.layout.offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
tuple(
V.graph.sizevars.atomically_apply_size_hint(
size,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
# pyrefly: ignore # bad-argument-type
for size in V.graph.get_allocation_size(node)
),
)
@staticmethod
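
The first hunk of this file guards torch.as_strided with torch._prims_common.compute_required_storage_length and re-allocates the base tensor when the requested view would overrun its storage. A small standalone sketch of that check with toy sizes (not the autotuning code itself; the real code compares against base.storage().size()):

import torch
from torch._prims_common import compute_required_storage_length

base = torch.randn(16)
sizes, strides, offset = (4, 8), (8, 1), 0

# Number of elements the strided view can touch; larger than base.numel() here.
needed = compute_required_storage_length(sizes, strides, offset)
if needed > base.numel():
    # Re-allocate a base with enough storage before building the view.
    base = torch.randn(needed, dtype=base.dtype, device=base.device)

view = torch.as_strided(base, sizes, strides, offset)  # now in bounds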

View File

@ -230,9 +230,9 @@ class autocast:
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
self.fast_dtype = dtype
self.fast_dtype = (
torch.get_autocast_dtype(device_type) if dtype is None else dtype
)
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
@ -243,6 +243,9 @@ class autocast:
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
device_supported_dtypes = [torch.bfloat16, torch.float16]
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
if self.device == self.custom_backend_name:
necessary_funcs = [
@ -259,110 +262,55 @@ class autocast:
assert hasattr(self.custom_device_mod, func), (
message + f"But the func `{func}` is missing. \n"
)
device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
self._cache_enabled = torch.is_autocast_cache_enabled()
if (
enabled
and self.device == "cuda"
and torch.cuda.amp.common.amp_definitely_not_available()
):
warnings.warn(
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
)
enabled = False
if cache_enabled is not None:
self._cache_enabled = cache_enabled
self._cache_enabled = (
torch.is_autocast_cache_enabled()
if cache_enabled is None
else cache_enabled
)
if self.device == "cpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype and enabled:
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "CPU Autocast only supports dtype of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "mtia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "maia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "xpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "ipu":
supported_dtypes = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtypes:
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "hpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == self.custom_backend_name:
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
if self.fast_dtype not in supported_dtype:
error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "cuda":
if (
enabled
and self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.device == "mps":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
device_name = (
self.device
if self.device == self.custom_backend_name
else self.device.upper()
)
if enabled:
# Special case for CUDA AMP and bfloat16 support
if self.device == "cuda":
if torch.cuda.amp.common.amp_definitely_not_available():
warnings.warn(
"CUDA is not available or torch_xla is imported. Disabling autocast."
)
enabled = False
elif (
self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.fast_dtype not in device_supported_dtypes:
error_message = (
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
f"{device_name} Autocast only supports dtypes of "
+ ", ".join(map(str, device_supported_dtypes))
+ " currently."
)
warnings.warn(error_message)
enabled = False
elif self.fast_dtype == torch.bfloat16:
if not torch.backends.mps.is_macos_or_newer(14, 0):
# Special case for MPS bfloat16 support on macOS < 14
if (
self.device == "mps"
and self.fast_dtype == torch.bfloat16
and not torch.backends.mps.is_macos_or_newer(14, 0)
):
error_message = (
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
"on macOS versions below 14. Disabling autocast."
)
warnings.warn(error_message)
enabled = False
elif self.device == "xla":
supported_dtype = [torch.float16, torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"XLA Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
self._enabled = enabled
def __enter__(self):
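
A compressed sketch of the dtype check the rewritten constructor converges on: one shared supported-dtype list and one warning template parameterized by the device name (toy helper, not the actual autocast constructor):

import warnings
import torch

def _check_autocast_dtype(device_name, fast_dtype,
                          supported=(torch.bfloat16, torch.float16)):
    # Mirrors the unified branch above: warn and disable instead of raising.
    if fast_dtype not in supported:
        warnings.warn(
            f"In {device_name.upper()} autocast, but the target dtype is not supported. "
            "Disabling autocast.\n"
            f"{device_name.upper()} Autocast only supports dtypes of "
            + ", ".join(map(str, supported)) + " currently."
        )
        return False
    return True

assert _check_autocast_dtype("cpu", torch.bfloat16) is True
assert _check_autocast_dtype("xpu", torch.float64) is False  # emits the warning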

View File

@ -235,10 +235,9 @@ class _DerivedObserverOrFakeQuantize(ObserverBase):
from .utils import is_per_channel
if is_per_channel(self.qscheme):
if self.ch_axis is None:
raise AssertionError(
"Must provide a valid ch_axis if qscheme is per channel"
)
assert self.ch_axis is not None, (
"Must provide a valid ch_axis if qscheme is per channel"
)
def forward(self, x: Tensor) -> Tensor:
return x
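
This file, and most of the quantization hunks that follow, revert the same mechanical change: an explicit "if ...: raise AssertionError(...)" guard goes back to a bare assert. Both forms check the same condition; a bare assert is additionally stripped when Python runs with -O. A side-by-side sketch (hypothetical value standing in for self.ch_axis):

ch_axis = 0  # would be None when the qscheme is misconfigured

# Form being removed by the revert:
if ch_axis is None:
    raise AssertionError("Must provide a valid ch_axis if qscheme is per channel")

# Form being restored (skipped entirely under python -O):
assert ch_axis is not None, (
    "Must provide a valid ch_axis if qscheme is per channel"
)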

View File

@ -92,10 +92,9 @@ def channel_range(input, axis=0):
mins = min_over_ndim(input, axis_list)
maxs = max_over_ndim(input, axis_list)
if mins.size(0) != input.size(axis):
raise AssertionError(
"Dimensions of resultant channel range does not match size of requested axis"
)
assert mins.size(0) == input.size(axis), (
"Dimensions of resultant channel range does not match size of requested axis"
)
return maxs - mins

View File

@ -45,8 +45,7 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
**observer_kwargs,
):
super().__init__()
if quant_min >= quant_max:
raise AssertionError("quant_min must be strictly less than quant_max.")
assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
self.quant_min = quant_min
self.quant_max = quant_max
# also pass quant_min and quant_max to observer
@ -57,16 +56,19 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
self.scale = Parameter(torch.tensor([scale]))
self.zero_point = Parameter(torch.tensor([zero_point]))
else:
if not (isinstance(channel_len, int) and channel_len > 0):
raise AssertionError("Channel size must be a positive integer.")
assert isinstance(channel_len, int) and channel_len > 0, (
"Channel size must be a positive integer."
)
self.scale = Parameter(torch.tensor([scale] * channel_len))
self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))
self.activation_post_process = observer(**observer_kwargs)
if not torch.iinfo(self.activation_post_process.dtype).min > quant_min:
raise AssertionError("quant_min out of bound")
if quant_max > torch.iinfo(self.activation_post_process.dtype).max:
raise AssertionError("quant_max out of bound")
assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, (
"quant_min out of bound"
)
assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, (
"quant_max out of bound"
)
self.dtype = self.activation_post_process.dtype
self.qscheme = self.activation_post_process.qscheme
self.ch_axis = (

View File

@ -88,10 +88,9 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
>>> lr = nn.LeakyReLU(0.01)
>>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
"""
if linear.training != bn.training or bn.training != leaky_relu.training:
raise AssertionError(
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
)
assert linear.training == bn.training and bn.training == leaky_relu.training, (
"Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
)
if is_qat:
raise NotImplementedError(

View File

@ -164,11 +164,10 @@ def remove_boolean_dispatch_from_name(p) -> Any:
return "torch.nn.functional.adaptive_max_pool2d"
elif p is F.adaptive_max_pool3d:
return "torch.nn.functional.adaptive_max_pool3d"
if "boolean_dispatch" in str(p):
raise AssertionError(
f"{p} does not have a human readable representation in "
+ "quantization documentation"
)
assert "boolean_dispatch" not in str(p), (
f"{p} does not have a human readable representation in "
+ "quantization documentation"
)
return p
@ -301,8 +300,7 @@ def _get_fuser_method_in_reversed_nested_tuple_format(
The first argument of a fuser method is always `is_qat` and is not affected
in the conversion. We currently only support functions with 3 or 4 arguments.
"""
if config.fuser_method is None:
raise AssertionError("config.fuser_method must be provided")
assert config.fuser_method is not None
if config._pattern_complex_format is not None:
return config.fuser_method
if not isinstance(config.pattern, tuple):

View File

@ -175,10 +175,9 @@ class FakeQuantize(FakeQuantizeBase):
super().__init__()
# Populate quant_min/quant_max to observer_kwargs if valid
if quant_min is not None and quant_max is not None:
if quant_min > quant_max:
raise AssertionError(
"quant_min must be less than or equal to quant_max"
)
assert quant_min <= quant_max, (
"quant_min must be less than or equal to quant_max"
)
dtype = observer_kwargs.get("dtype", torch.quint8)
if hasattr(observer, "p"):
# In case observer is _PartialWrapper, dtype can be stored in
@ -187,11 +186,9 @@ class FakeQuantize(FakeQuantizeBase):
"dtype", dtype
)
# pyrefly: ignore # bad-argument-type
if torch.iinfo(dtype).min > quant_min:
raise AssertionError("quant_min out of bound")
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
# pyrefly: ignore # bad-argument-type
if quant_max > torch.iinfo(dtype).max:
raise AssertionError("quant_max out of bound")
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
observer_kwargs["is_dynamic"] = is_dynamic
self.activation_post_process = observer(**observer_kwargs)
@ -213,12 +210,11 @@ class FakeQuantize(FakeQuantizeBase):
if hasattr(self.activation_post_process, "ch_axis")
else -1
)
if not (_is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme)):
raise AssertionError(
"Only per channel and per tensor quantization are supported in fake quantize"
+ " got qscheme: "
+ str(self.qscheme)
)
assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), (
"Only per channel and per tensor quantization are supported in fake quantize"
+ " got qscheme: "
+ str(self.qscheme)
)
self.is_per_channel = _is_per_channel(self.qscheme)
@torch.jit.export
@ -299,10 +295,7 @@ class FakeQuantize(FakeQuantizeBase):
if name == "scale":
self.scale.resize_(val.shape)
else:
if name != "zero_point":
raise AssertionError(
"Expected 'zero_point' but got different state key"
)
assert name == "zero_point"
self.zero_point.resize_(val.shape)
# For torchscript module we need to update the attributes here since we do not
# call the `_load_from_state_dict` function defined module.py
@ -310,10 +303,7 @@ class FakeQuantize(FakeQuantizeBase):
if name == "scale":
self.scale.copy_(val)
else:
if name != "zero_point":
raise AssertionError(
"Expected 'zero_point' but got different state key"
)
assert name == "zero_point"
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
@ -339,19 +329,17 @@ class FixedQParamsFakeQuantize(FakeQuantize):
# TODO: rename observer to observer_ctr
def __init__(self, observer):
super().__init__(observer=observer)
if type(self.activation_post_process) is not FixedQParamsObserver:
raise AssertionError(
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
)
assert type(self.activation_post_process) is FixedQParamsObserver, (
f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
)
self._observer_ctr = observer
self.scale = self.activation_post_process.scale
self.zero_point = self.activation_post_process.zero_point
if not _is_per_tensor(self.qscheme):
raise AssertionError(
"Only per tensor quantization is supported"
+ " FixedQParamsFakeQuantize module, got qscheme:"
+ str(self.qscheme)
)
assert _is_per_tensor(self.qscheme), (
"Only per tensor quantization is supported"
+ " FixedQParamsFakeQuantize module, got qscheme:"
+ str(self.qscheme)
)
@torch.jit.export
def calculate_qparams(self): # type: ignore[override]
@ -394,13 +382,12 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
**observer_kwargs: Any,
) -> None:
super().__init__(observer, quant_min, quant_max, **observer_kwargs)
if not isinstance(
assert isinstance(
self.activation_post_process,
(MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
):
raise AssertionError(
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
)
), (
"Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
)
self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
self.is_symmetric_quant = _is_symmetric_quant(

View File

@ -35,10 +35,9 @@ def fuse_conv_bn(is_qat, conv, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn(m1, b1)
"""
if conv.training != bn.training:
raise AssertionError(
"Conv and BN both must be in the same mode (train or eval)."
)
assert conv.training == bn.training, (
"Conv and BN both must be in the same mode (train or eval)."
)
fused_module_class_map = {
nn.Conv1d: nni.ConvBn1d,
@ -47,18 +46,13 @@ def fuse_conv_bn(is_qat, conv, bn):
}
if is_qat:
if bn.num_features != conv.out_channels:
raise AssertionError(
"Output channel of Conv2d must match num_features of BatchNorm2d."
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm2d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
assert bn.num_features == conv.out_channels, (
"Output channel of Conv2d must match num_features of BatchNorm2d"
)
assert bn.affine, "Only support fusing BatchNorm2d with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
fused_module_class = fused_module_class_map.get((type(conv)), None)
if fused_module_class is not None:
return fused_module_class(conv, bn)
@ -87,10 +81,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
>>> # xdoctest: +SKIP
>>> m2 = fuse_conv_bn_relu(m1, b1, r1)
"""
if not (conv.training == bn.training == relu.training):
raise AssertionError(
"Conv and BN both must be in the same mode (train or eval)."
)
assert conv.training == bn.training == relu.training, (
"Conv and BN both must be in the same mode (train or eval)."
)
fused_module: Optional[type[nn.Sequential]] = None
if is_qat:
map_to_fused_module_train = {
@ -98,18 +91,13 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
nn.Conv2d: nni.ConvBnReLU2d,
nn.Conv3d: nni.ConvBnReLU3d,
}
if bn.num_features != conv.out_channels:
raise AssertionError(
"Output channel of Conv2d must match num_features of BatchNorm2d"
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm2d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm2d with tracking_running_stats set to True"
)
assert bn.num_features == conv.out_channels, (
"Output channel of Conv must match num_features of BatchNorm"
)
assert bn.affine, "Only support fusing BatchNorm with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm with tracking_running_stats set to True"
)
fused_module = map_to_fused_module_train.get(type(conv), None)
if fused_module is not None:
return fused_module(conv, bn, relu)
@ -146,24 +134,18 @@ def fuse_linear_bn(is_qat, linear, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_linear_bn(m1, b1)
"""
if linear.training != bn.training:
raise AssertionError(
"Linear and BN both must be in the same mode (train or eval)."
)
assert linear.training == bn.training, (
"Linear and BN both must be in the same mode (train or eval)."
)
if is_qat:
if bn.num_features != linear.out_features:
raise AssertionError(
"Output features of Linear must match num_features of BatchNorm1d"
)
if not bn.affine:
raise AssertionError(
"Only support fusing BatchNorm1d with affine set to True"
)
if not bn.track_running_stats:
raise AssertionError(
"Only support fusing BatchNorm1d with tracking_running_stats set to True"
)
assert bn.num_features == linear.out_features, (
"Output features of Linear must match num_features of BatchNorm1d"
)
assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
assert bn.track_running_stats, (
"Only support fusing BatchNorm1d with tracking_running_stats set to True"
)
return nni.LinearBn1d(linear, bn)
else:
return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
@ -185,10 +167,9 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
>>> # xdoctest: +SKIP
>>> m2 = fuse_convtranspose_bn(m1, b1)
"""
if convt.training != bn.training:
raise AssertionError(
"ConvTranspose and BN both must be in the same mode (train or eval)."
)
assert convt.training == bn.training, (
"ConvTranspose and BN both must be in the same mode (train or eval)."
)
if is_qat:
raise Exception( # noqa: TRY002
@ -243,8 +224,7 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
_DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping
)
fuser_method = all_mappings.get(op_list, None)
if fuser_method is None:
raise AssertionError(f"did not find fuser method for: {op_list} ")
assert fuser_method is not None, f"did not find fuser method for: {op_list} "
return fuser_method
@ -309,6 +289,5 @@ def get_fuser_method_new(
fuser_method = fuser_method_mapping.get(op_pattern)
if fuser_method is not None:
break
if fuser_method is None:
raise AssertionError(f"did not find fuser method for: {op_pattern} ")
assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
return fuser_method
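
Not part of the diff: a minimal eval-mode usage sketch of the fusion path these helpers back, going through the public torch.ao.quantization.fuse_modules entry point (toy three-layer model):

import torch.nn as nn
from torch.ao.quantization import fuse_modules

# The restored asserts require all modules to be in the same mode; here the
# whole model is in eval mode, so the BatchNorm is folded into the Conv weights.
m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
fused = fuse_modules(m, [["0", "1", "2"]])
print(fused[0])  # fused Conv+ReLU module; slots "1" and "2" become nn.Identity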

View File

@ -249,17 +249,17 @@ class UniformQuantizationObserverBase(ObserverBase):
)
self.reduce_range = reduce_range
self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
if self.qscheme not in (
assert self.qscheme in (
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
):
raise AssertionError(
"Default Observer only works for per_tensor_affine, per_tensor_symmetric, "
"per_channel_affine, per_channel_symmetric and per_channel_float_qparams quantization scheme"
)
), (
"Default Observer only works for per_tensor_affine, \
per_tensor_symmetric, per_channel_affine, \
per_channel_symmetric and per_channel_float_qparams quantization scheme"
)
_ALLOWED_DTYPES = (
torch.qint8,
@ -275,10 +275,9 @@ class UniformQuantizationObserverBase(ObserverBase):
torch.uint16,
)
if self.dtype not in _ALLOWED_DTYPES:
raise AssertionError(
f"Default Observer only works for {_ALLOWED_DTYPES} data type"
)
assert self.dtype in _ALLOWED_DTYPES, (
f"Default Observer only works for {_ALLOWED_DTYPES} data type"
)
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
# pyrefly: ignore # bad-argument-type
@ -337,12 +336,12 @@ class UniformQuantizationObserverBase(ObserverBase):
"""
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
if not quant_min <= 0 <= quant_max:
raise AssertionError("Used-specified quantization range must include 0.")
if quant_min >= quant_max:
raise AssertionError(
"qmin must be strictly less than qmax for user-specified quantization range."
)
assert quant_min <= 0 <= quant_max, (
"Used-specified quantization range must include 0."
)
assert quant_min < quant_max, (
"qmin must be strictly less than qmax for user-specified quantization range."
)
@torch.jit.export
def _calculate_qparams(
@ -1132,8 +1131,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
caffe2/quantization/server/norm_minimization.cc
"""
if self.histogram.size()[0] != self.bins:
raise AssertionError("bins mismatch")
assert self.histogram.size()[0] == self.bins, "bins mismatch"
bin_width = (self.max_val - self.min_val) / self.bins
# cumulative sum
@ -1254,10 +1252,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
return transformed_orig_hist + update_hist
# We assume the update_hist is already in the target range, we will map the orig_max to it
if update_min > orig_min:
raise AssertionError("update_min must be <= orig_min")
if update_max < orig_max:
raise AssertionError("update_max must be >= orig_max")
assert update_min <= orig_min
assert update_max >= orig_max
# Now we need to turn the old_histogram, into the range of the new histogram
transformed_orig_hist = self._upscale_histogram(
@ -1277,8 +1273,9 @@ class HistogramObserver(UniformQuantizationObserverBase):
self.min_val.copy_(min_val)
self.max_val.resize_(max_val.shape)
self.max_val.copy_(max_val)
if min_val.numel() != 1 or max_val.numel() != 1:
raise AssertionError("histogram min/max values must be scalar.")
assert min_val.numel() == 1 and max_val.numel() == 1, (
"histogram min/max values must be scalar."
)
new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type]
self.histogram.detach_().resize_(new_histogram.shape)
self.histogram.copy_(new_histogram)
@ -1353,11 +1350,10 @@ class HistogramObserver(UniformQuantizationObserverBase):
return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor(
[0], device=self.min_val.device.type
)
if self.bins != len(self.histogram):
raise AssertionError(
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)
assert self.bins == len(self.histogram), (
"The number of bins in histogram should be equal to the number of bins "
"supplied while making this observer"
)
new_min, new_max = self._non_linear_param_search()
@ -1789,10 +1785,9 @@ def get_block_size(
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
if not isinstance(granularity, Granularity):
raise AssertionError(
"Please provide an instance of Granularity, not subclass of it"
)
assert isinstance(granularity, Granularity), (
"Please provide an instance of Granularity, not subclass of it"
)
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):
@ -1802,10 +1797,9 @@ def get_block_size(
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerGroup):
if len(input_shape) != 2:
raise AssertionError(
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
assert len(input_shape) == 2, (
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
return (1, granularity.group_size)
elif isinstance(granularity, PerToken):
block_size = [1] * len(input_shape)
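
A standalone restatement of the get_block_size branches visible in the hunk above, simplified to plain tuples so the granularity-to-block-size mapping is easy to follow (illustrative only, not the real helper; the truncated PerAxis and PerToken branches are omitted):

def block_size_for(input_shape, granularity, group_size=None):
    if granularity == "per_tensor":
        return input_shape  # one block covers the whole tensor
    if granularity == "per_row":
        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
    if granularity == "per_group":
        assert len(input_shape) == 2, "per-group quantization expects a 2D input shape"
        return (1, group_size)
    raise NotImplementedError(granularity)

print(block_size_for((128, 256), "per_row"))                   # (1, 256)
print(block_size_for((128, 256), "per_group", group_size=64))  # (1, 64)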
@ -1842,8 +1836,8 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
**kwargs,
):
super().__init__()
if granularity is None:
raise AssertionError("granularity is None")
assert granularity is not None, "granularity is None"
self.mapping_type = mapping_type
self.target_dtype = target_dtype
self.granularity = granularity
@ -1881,10 +1875,10 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
from torch.ao.quantization.fx.utils import create_getattr_from_value
with model.graph.inserting_before(observer_node):
if self.block_size is None:
raise AssertionError("Expecting block_size to be populated")
if self.original_dtype is None:
raise AssertionError("Expecting original_dtype to be populated")
assert self.block_size is not None, "Expecting block_size to be populated"
assert self.original_dtype is not None, (
"Expecting original_dtype to be populated"
)
if hasattr(self, "is_dynamic") and self.is_dynamic:
choose_qparams_affine = model.graph.call_function(
torch.ops.pt2e_quant.choose_qparams_affine,

View File

@ -565,10 +565,9 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
torch.ao.quantization.MovingAveragePerChannelMinMaxObserver,
),
)
if is_per_channel:
raise AssertionError(
"Per channel weight observer is not supported yet for ConvTranspose{n}d."
)
assert not is_per_channel, (
"Per channel weight observer is not supported yet for ConvTranspose{n}d."
)
if sys.version_info < (3, 12):
@ -600,8 +599,7 @@ def _add_module_to_qconfig_obs_ctr(
return qconfig
def get_factory_kwargs_based_on_module_device():
if not isinstance(module, torch.nn.Module):
raise AssertionError("module must be an instance of torch.nn.Module")
assert isinstance(module, torch.nn.Module)
devices = {p.device for p in module.parameters()} | {
p.device for p in module.buffers()
}
@ -673,10 +671,7 @@ def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
if q1 is None or q2 is None:
return q1 == q2
else:
if q1 is None or q2 is None:
raise AssertionError(
"Both q1 and q2 must be non-None for qconfig comparison"
)
assert q1 is not None and q2 is not None
try:
# Qconfig weight and activation can be either a partial wrapper,
# or an observer class. Special handling is required (above) for

View File

@ -252,11 +252,10 @@ def get_static_quant_module_class(
additional_static_quant_mapping,
)
static_quant_module_class = all_mappings.get(float_module_class, None)
if static_quant_module_class is None:
raise AssertionError(
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
assert static_quant_module_class is not None, (
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
return copy.deepcopy(static_quant_module_class)
@ -273,11 +272,10 @@ def get_dynamic_quant_module_class(
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping
)
dynamic_quant_module_class = all_mappings.get(float_module_class, None)
if dynamic_quant_module_class is None:
raise AssertionError(
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
assert dynamic_quant_module_class is not None, (
f"Floating point module class {str(float_module_class)}"
+ " does not have a corresponding quantized module class"
)
return copy.deepcopy(dynamic_quant_module_class)
@ -346,10 +344,9 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
"""Get the quantized operator corresponding to the float operator"""
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
if quantized_op is None:
raise AssertionError(
f"Operator {str(float_op)} does not have corresponding quantized op"
)
assert quantized_op is not None, (
f"Operator {str(float_op)} does not have corresponding quantized op"
)
return quantized_op

View File

@ -158,10 +158,9 @@ def _observer_forward_pre_hook(self, input):
def _register_activation_post_process_hook(module, pre_hook=False):
if not hasattr(module, "activation_post_process"):
raise AssertionError(
"Expect activation_post_process attribute already attached to the module"
)
assert hasattr(module, "activation_post_process"), (
"Expect activation_post_process attribute already attached to the module"
)
if pre_hook:
module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True)
else:
@ -199,10 +198,9 @@ def _add_observer_(
# respect device affinity when adding observers
if device is None:
devices = _get_unique_devices_(module)
if len(devices) > 1:
raise AssertionError(
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
)
assert len(devices) <= 1, (
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
def get_activation_post_process(qconfig, device, special_act_post_process=None):
@ -245,10 +243,9 @@ def _add_observer_(
type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
):
if needs_observation(child):
if not hasattr(child, "activation_post_process"):
raise AssertionError(
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
)
assert hasattr(child, "activation_post_process"), (
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
)
child.activation_post_process = get_activation_post_process(
child.qconfig, device
)
@ -587,8 +584,7 @@ def prepare_qat(model, mapping=None, inplace=False):
is mutated
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
if not model.training:
raise AssertionError("prepare_qat only works on models in training mode")
assert model.training, "prepare_qat only works on models in training mode"
if mapping is None:
mapping = get_default_qat_module_mappings()
@ -764,10 +760,7 @@ def swap_module(
elif type_before_parametrizations(mod) in mapping:
qmod = mapping[type_before_parametrizations(mod)]
if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
if mod.qconfig is None:
raise AssertionError(
"module qconfig must not be None when swapping to reference module"
)
assert mod.qconfig is not None
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
weight_qparams = get_qparam_dict(weight_post_process)
@ -794,13 +787,11 @@ def swap_module(
# respect device affinity when swapping modules
devices = _get_unique_devices_(mod)
if not (
len(devices) <= 1
or (len(devices) == 2 and torch.device("meta") in devices)
):
raise AssertionError(
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
)
assert len(devices) <= 1 or (
len(devices) == 2 and torch.device("meta") in devices
), (
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
if device:
new_mod.to(device)
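
Not part of the diff: a short sketch of the prepare_qat precondition enforced above, namely that the model must be in training mode (hypothetical tiny model; qconfig taken from the public get_default_qat_qconfig helper):

import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).train()  # must be in train mode
model.qconfig = get_default_qat_qconfig("fbgemm")
prepare_qat(model, inplace=True)  # the restored assert fires if model.eval() was used
print(model[0])  # Conv2d swapped for its QAT counterpart with fake-quant attached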

View File

@ -157,12 +157,12 @@ def _convert_ondevice_jit(
model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC
):
_check_is_script_module(model)
if quant_type != QuantType.DYNAMIC:
raise AssertionError(
"This API, while should work for static quant, is only tested for dynamic quant."
)
if method_name.startswith("observe_"):
raise AssertionError("Pass in valid method to be quantized, e.g. forward")
assert quant_type == QuantType.DYNAMIC, (
"This API, while should work for static quant, is only tested for dynamic quant."
)
assert not method_name.startswith("observe_"), (
"Pass in valid method to be quantized, e.g. forward"
)
observe_method_name = "observe_" + method_name
quantize_method_name = "quantize_" + method_name
model_c = model._c
@ -230,14 +230,12 @@ def _quantize_jit(
model = prepare_dynamic_jit(model, qconfig_dict, inplace)
model = convert_dynamic_jit(model, True, debug)
else:
if not run_fn:
raise AssertionError(
"Must provide calibration function for post training static quantization"
)
if not run_args:
raise AssertionError(
"Must provide calibration dataset for post training static quantization"
)
assert run_fn, (
"Must provide calibration function for post training static quantization"
)
assert run_args, (
"Must provide calibration dataset for post training static quantization"
)
model = prepare_jit(model, qconfig_dict, inplace)
run_fn(model, *run_args)
model = convert_jit(model, True, debug)

View File

@ -263,10 +263,7 @@ def _is_quantized_op_pt2e(node: torch.fx.Node):
# The node has not been annotated, directly return False
return False
quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
if not isinstance(quantization_annotation, _X86InductorQuantizationAnnotation):
raise AssertionError(
"quantization_annotation must be an _X86InductorQuantizationAnnotation"
)
assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
return quantization_annotation._is_output_of_quantized_pattern
@ -431,22 +428,20 @@ class X86InductorQuantizer(Quantizer):
if qat_state is None:
qat_state = qconfig.is_qat
else:
if qat_state != qconfig.is_qat:
raise AssertionError(
f"All non-None quantization configs should have the same `is_qat`,"
f"but got {qat_state} and {qconfig.is_qat}."
)
assert qat_state == qconfig.is_qat, (
f"All non-None quantization configs should have the same `is_qat`,"
f"but got {qat_state} and {qconfig.is_qat}."
)
# Query the `is_dynamic` state
input_activation_spec = qconfig.input_activation
if input_activation_spec is not None:
if dynamic_state is None:
dynamic_state = input_activation_spec.is_dynamic
else:
if dynamic_state != input_activation_spec.is_dynamic:
raise AssertionError(
f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
)
assert dynamic_state == input_activation_spec.is_dynamic, (
f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
)
return _CurrentQuantizationMode(
qat_state=qat_state, dynamic_state=dynamic_state
)
@ -572,12 +567,10 @@ class X86InductorQuantizer(Quantizer):
return
input_qspec_map = {}
input_node = conv_node.args[0]
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
assert isinstance(input_node, Node)
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
weight_node = conv_node.args[1]
if not isinstance(weight_node, Node):
raise AssertionError("weight_node must be a FX Node")
assert isinstance(weight_node, Node)
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
bias_node = None if len(conv_node.args) == 2 else conv_node.args[2]
if isinstance(bias_node, Node):
@ -605,23 +598,18 @@ class X86InductorQuantizer(Quantizer):
_annotate_nodes_not_quantize(linear_node)
return
input_qspec_map = {}
if linear_node.target != torch.ops.aten.linear.default:
raise AssertionError(
"linear_node.target must be torch.ops.aten.linear.default"
)
assert linear_node.target == torch.ops.aten.linear.default
has_bias = len(linear_node.args) == 3
input_index = 0
weight_index = 1
bias_index = 2
input_node = linear_node.args[input_index]
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
assert isinstance(input_node, Node)
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
weight_node = linear_node.args[weight_index]
if not isinstance(weight_node, Node):
raise AssertionError("weight_node must be a FX Node")
assert isinstance(weight_node, Node)
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
bias_node = linear_node.args[bias_index] if has_bias else None
@ -649,8 +637,7 @@ class X86InductorQuantizer(Quantizer):
if len(partition.output_nodes) > 1:
raise ValueError("Input partition has more than one output node")
output_node = partition.output_nodes[0]
if not isinstance(output_node, Node):
raise AssertionError("output_node must be a FX Node")
assert isinstance(output_node, Node)
output_node_list.append(output_node)
if len(output_node_list) != len(partition_list):
raise ValueError(
@ -679,8 +666,7 @@ class X86InductorQuantizer(Quantizer):
conv_gemm_node_idx = 1
extra_input_node_idx = 0
extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
if not isinstance(extra_input_node, Node):
raise AssertionError("extra_input_node must be a FX Node")
assert isinstance(extra_input_node, Node)
return conv_gemm_node_idx, extra_input_node_idx
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@ -1137,8 +1123,7 @@ class X86InductorQuantizer(Quantizer):
if conv_node != binary_node.args[conv_node_idx]:
raise ValueError(f"{conv_node} doesn't match input of binary node")
extra_input_node = binary_node.args[extra_input_node_idx]
if not isinstance(conv_node, Node):
raise AssertionError("conv_node must be a FX Node")
assert isinstance(conv_node, Node)
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.conv2d.default
@ -1252,8 +1237,7 @@ class X86InductorQuantizer(Quantizer):
return
input_node = maxpool_node.args[0]
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
assert isinstance(input_node, Node)
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
@ -1270,14 +1254,11 @@ class X86InductorQuantizer(Quantizer):
return
cat_node = node
input_nodes = cat_node.args[0]
if not isinstance(input_nodes, Sequence):
raise AssertionError("input_nodes must be a Sequence of FX Nodes")
assert isinstance(input_nodes, Sequence)
first_input_node = input_nodes[0]
input_qspec_map = {}
if not isinstance(first_input_node, Node):
raise AssertionError("first_input_node must be a FX Node")
if not isinstance(cat_node, Node):
raise AssertionError("cat_node must be a FX Node")
assert isinstance(first_input_node, Node)
assert isinstance(cat_node, Node)
input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, cat_node)
@ -1286,8 +1267,7 @@ class X86InductorQuantizer(Quantizer):
for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
# There has the case of cat same nodes: torch.cat([input0, input0], 1)
if not isinstance(input_node, Node):
raise AssertionError("input_node must be a FX Node")
assert isinstance(input_node, Node)
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
@ -1425,10 +1405,8 @@ class X86InductorQuantizer(Quantizer):
):
# Annotate the output_qspec of getitem_node
input_act = maxpool_node.args[0]
if not isinstance(input_act, Node):
raise AssertionError("input_act must be a FX Node")
if not isinstance(maxpool_node, Node):
raise AssertionError("maxpool_node must be a FX Node")
assert isinstance(input_act, Node)
assert isinstance(maxpool_node, Node)
edge_or_node = (input_act, maxpool_node)
maxpool_node_quantization_annotation.output_qspec = (
SharedQuantizationSpec(edge_or_node)
@ -1556,8 +1534,7 @@ class X86InductorQuantizer(Quantizer):
raise ValueError(
f"{linear_node} doesn't match input of binary node"
)
if not isinstance(linear_node, Node):
raise AssertionError("linear_node must be a FX Node")
assert isinstance(linear_node, Node)
if (
linear_node.op != "call_function"
or linear_node.target != torch.ops.aten.linear.default

View File

@ -347,8 +347,9 @@ class XNNPACKQuantizer(Quantizer):
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""
if quantization_config is None:
raise AssertionError("quantization_config == None is not supported yet")
assert quantization_config is not None, (
" quantization_config == None is not supported yet"
)
self.module_name_config[module_name] = quantization_config
return self

View File

@ -121,13 +121,10 @@ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config.input_activation is None:
return None
quantization_spec: QuantizationSpec = quantization_config.input_activation
if quantization_spec.qscheme not in [
assert quantization_spec.qscheme in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]:
raise AssertionError(
f"Unsupported activation qscheme: {quantization_spec.qscheme}"
)
]
return quantization_spec
@ -137,21 +134,17 @@ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config.output_activation is None:
return None
quantization_spec: QuantizationSpec = quantization_config.output_activation
if quantization_spec.qscheme not in [
assert quantization_spec.qscheme in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]:
raise AssertionError(
f"Unsupported activation qscheme: {quantization_spec.qscheme}"
)
]
return quantization_spec
def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
if quantization_config is None:
raise AssertionError("quantization_config must not be None")
assert quantization_config is not None
if quantization_config.weight is None:
return None
quantization_spec: QuantizationSpec = quantization_config.weight
@ -169,15 +162,13 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
if quantization_config is None:
raise AssertionError("quantization_config must not be None")
assert quantization_config is not None
if quantization_config.bias is None:
return None
quantization_spec: QuantizationSpec = quantization_config.bias
if quantization_spec.dtype != torch.float:
raise AssertionError(
"Only float dtype for bias is supported for bias right now"
)
assert quantization_spec.dtype == torch.float, (
"Only float dtype for bias is supported for bias right now"
)
return quantization_spec
@ -262,13 +253,11 @@ def _annotate_linear_relu(
input_qspec_map = {}
input_act = linear_node.args[0]
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
assert isinstance(input_act, Node)
input_qspec_map[input_act] = input_act_qspec
weight = linear_node.args[1]
if not isinstance(weight, Node):
raise AssertionError("weight must be a FX Node")
assert isinstance(weight, Node)
input_qspec_map[weight] = weight_qspec
# adding weight node to the partition as well
@ -314,13 +303,11 @@ def _annotate_conv(
input_qspec_map = {}
input_act = conv_node.args[0]
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
if not isinstance(weight, Node):
raise AssertionError("weight must be a FX Node")
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
# adding weight node to the partition as well
@ -375,13 +362,11 @@ def _do_annotate_conv_relu(
input_qspec_map = {}
input_act = conv_node.args[0]
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
if not isinstance(weight, Node):
raise AssertionError("weight must be a FX Node")
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
# adding weight node to the partition as well
@ -650,10 +635,8 @@ def _annotate_gru_io_only(
# subgraph
input_act = input_nodes[0]
input_act_user = next(iter(input_act.users.keys()))
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
if not isinstance(input_act_user, Node):
raise AssertionError("input activation user must be a FX Node")
assert isinstance(input_act, Node)
assert isinstance(input_act_user, Node)
input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: get_input_act_qspec(quantization_config),
@ -663,10 +646,8 @@ def _annotate_gru_io_only(
hidden_state = input_nodes[1]
hidden_state_user = next(iter(hidden_state.users.keys()))
if not isinstance(hidden_state, Node):
raise AssertionError("hidden state must be a FX Node")
if not isinstance(hidden_state_user, Node):
raise AssertionError("hidden state user must be a FX Node")
assert isinstance(hidden_state, Node)
assert isinstance(hidden_state_user, Node)
hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
hidden_state: get_input_act_qspec(quantization_config),
@ -674,8 +655,7 @@ def _annotate_gru_io_only(
_annotated=True,
)
if len(output_nodes) != 2:
raise AssertionError("expecting GRU to have two outputs")
assert len(output_nodes) == 2, "expecting GRU to have two outputs"
for output in output_nodes:
output.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config),
@ -711,8 +691,7 @@ def _annotate_adaptive_avg_pool2d(
annotated_partitions.append(partition.nodes)
input_act = pool_node.args[0]
if not isinstance(input_act, Node):
raise AssertionError("input activation must be a FX Node")
assert isinstance(input_act, Node)
# only annotate input output sharing operator
# when the output of the input node is annotated

View File

@ -214,8 +214,7 @@ def to_underlying_dtype(qdtype):
torch.float8_e5m2: torch.float8_e5m2,
torch.float8_e4m3fn: torch.float8_e4m3fn,
}
if qdtype not in DTYPE_MAPPING:
raise AssertionError("Unsupported dtype: " + str(qdtype))
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
return DTYPE_MAPPING[qdtype]
@ -270,24 +269,21 @@ def get_swapped_custom_module_class(
"""
quant_type = get_quant_type(qconfig)
class_mapping = custom_module_class_mapping.get(quant_type, {})
if type(custom_module) not in class_mapping:
raise AssertionError(
"did not find corresponding observed "
f"module class for {type(custom_module)} in mapping: {class_mapping}"
)
assert type(custom_module) in class_mapping, (
"did not find corresponding observed "
f"module class for {type(custom_module)} in mapping: {class_mapping}"
)
return class_mapping[type(custom_module)]
def activation_dtype(qconfig):
if qconfig is None:
raise AssertionError("qconfig must be provided to determine activation dtype")
assert qconfig is not None
activation = qconfig.activation()
return activation.dtype
def weight_dtype(qconfig):
if qconfig is None:
raise AssertionError("qconfig must be provided to determine weight dtype")
assert qconfig is not None
weight = qconfig.weight()
return weight.dtype
@ -381,8 +377,7 @@ def get_qconfig_dtypes(qconfig):
r"""returns the qconfig tuple for qconfig:
(activation_dtype, weight_dtype, activation_is_dynamic)
"""
if qconfig is None:
raise AssertionError("qconfig must be provided to extract dtypes")
assert qconfig is not None
activation = qconfig.activation()
weight = qconfig.weight()
act_is_dynamic = getattr(activation, "is_dynamic", False)
@ -390,8 +385,7 @@ def get_qconfig_dtypes(qconfig):
def get_quant_type(qconfig):
-if qconfig is None:
-    raise AssertionError("qconfig must be provided to determine quant type")
+assert qconfig is not None
activation = qconfig.activation()
weight = qconfig.weight()
static_dtypes = [
@ -446,11 +440,11 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
return False
-if min_val > max_val:
-    raise AssertionError(f"min {min_val} should be less than max {max_val}")
+assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
else:
-if torch.any(min_val > max_val):
-    raise AssertionError(f"min {min_val} should be less than max {max_val}")
+assert torch.all(min_val <= max_val), (
+    f"min {min_val} should be less than max {max_val}"
+)
return True
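The tensor branch above is the same condition written two ways: `torch.any(min_val > max_val)` is the negation of `torch.all(min_val <= max_val)`. A quick self-contained check (made-up tensors, not from the diff):

```
import torch

min_val = torch.tensor([0.0, -1.0, 2.0])
max_val = torch.tensor([1.0,  3.0, 2.0])
# De Morgan on elementwise comparisons: not any(a > b)  <=>  all(a <= b)
assert torch.any(min_val > max_val) == (~torch.all(min_val <= max_val))
```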
@ -485,15 +479,13 @@ def calculate_qmin_qmax(
qrange_len = initial_quant_max - initial_quant_min + 1
if dtype in [torch.qint8, torch.int8]:
-if not (0 < qrange_len <= 256):
-    raise AssertionError(
-        "quantization range should be positive and not exceed the maximum bit range (=256)."
-    )
+assert 0 < qrange_len <= 256, (
+    "quantization range should be positive and not exceed the maximum bit range (=256)."
+)
elif dtype in [torch.qint32, torch.int32]:
-if not (0 < qrange_len <= 2**32):
-    raise AssertionError(
-        "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
-    )
+assert 0 < qrange_len <= 2**32, (
+    "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
+)
if reduce_range:
quant_min, quant_max = quant_min // 2, quant_max // 2
else:
@ -641,12 +633,12 @@ def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
"""
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
-if not (quant_min <= 0 <= quant_max):
-    raise AssertionError("Used-specified quantization range must include 0.")
-if quant_min >= quant_max:
-    raise AssertionError(
-        "qmin must be strictly less than qmax for user-specified quantization range."
-    )
+assert quant_min <= 0 <= quant_max, (
+    "Used-specified quantization range must include 0."
+)
+assert quant_min < quant_max, (
+    "qmin must be strictly less than qmax for user-specified quantization range."
+)
# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme
@ -818,11 +810,10 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
)
devices = {torch.device("cpu")}
""
-if len(devices) > 1:
-    raise AssertionError(
-        "prepare only works with cpu or single-device CUDA modules, "
-        f"but got devices {devices}"
-    )
+assert len(devices) <= 1, (
+    "prepare only works with cpu or single-device CUDA modules, "
+    f"but got devices {devices}"
+)
device = next(iter(devices)) if len(devices) > 0 else None
return device
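The arithmetic in `calculate_qmin_qmax` above is untouched by the revert; only the error-reporting style changes. For reference, a simplified stand-in (not the real torch.ao helper) showing the range checks and the `reduce_range` halving for an int8 range:

```
# Simplified stand-in for the checks above (not the actual torch.ao helpers).
def reduced_range(quant_min: int, quant_max: int, reduce_range: bool):
    qrange_len = quant_max - quant_min + 1
    assert 0 < qrange_len <= 256, (
        "quantization range should be positive and not exceed the maximum bit range (=256)."
    )
    assert quant_min <= 0 <= quant_max, "quantization range must include 0."
    assert quant_min < quant_max, "qmin must be strictly less than qmax."
    if reduce_range:
        # Halving leaves headroom for intermediate accumulations on some backends.
        quant_min, quant_max = quant_min // 2, quant_max // 2
    return quant_min, quant_max

print(reduced_range(-128, 127, reduce_range=False))  # (-128, 127)
print(reduced_range(-128, 127, reduce_range=True))   # (-64, 63)
```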

View File

@ -419,7 +419,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
}
// Do not call this directly, use ProcessGroup::setGroupName instead.
-void setGroupUid(const std::string& pg_uid) {
+virtual void setGroupUid(const std::string& pg_uid) {
pg_uid_ = pg_uid;
}
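Making `setGroupUid` virtual is what allows a custom backend's override to be reached through a base-class pointer. A simplified, self-contained sketch of that dispatch (the `BackendLike`/`MyBackend` names are illustrative, not the real c10d classes):

```
#include <iostream>
#include <memory>
#include <string>

// Simplified stand-in for the pattern above; not the real c10d::Backend.
class BackendLike {
 public:
  virtual ~BackendLike() = default;
  // With `virtual`, a derived override is reached via a BackendLike*.
  virtual void setGroupUid(const std::string& pg_uid) {
    pg_uid_ = pg_uid;
  }
 protected:
  std::string pg_uid_;
};

class MyBackend : public BackendLike {
 public:
  void setGroupUid(const std::string& pg_uid) override {
    // A custom backend can hook group-name propagation here.
    std::cout << "MyBackend sees uid: " << pg_uid << "\n";
    BackendLike::setGroupUid(pg_uid);
  }
};

int main() {
  std::unique_ptr<BackendLike> b = std::make_unique<MyBackend>();
  b->setGroupUid("group_0");  // dispatches to MyBackend::setGroupUid
}
```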

View File

@ -1,11 +1,15 @@
#include <torch/csrc/fx/node.h>
+#include <c10/util/Exception.h>
+#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
+#include <algorithm>
namespace {
+using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@ -163,7 +167,41 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
-PyObject* _sort_key;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
inline void set_prev(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _prev;
_prev = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
inline void set_next(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _next;
_next = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
if (this->_prev == this && this->_next == this) {
return;
}
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->set_next(n);
n->set_prev(p);
}
};
static PyObject* NodeBase_new(
@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
+new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
+    NodeSortKey(); // placement new does not allocate
return self;
}
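Because `NodeBase` instances come from CPython's `tp_alloc` rather than C++ `new`, the `NodeSortKey` member lives in a raw `alignas` buffer, is constructed with placement new in `NodeBase_new`, and is destroyed explicitly in `NodeBase_dealloc` further down. A self-contained sketch of that lifetime pattern (generic `Holder`/`Payload` names, not code from the PR):

```
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <new>
#include <vector>

using Payload = std::vector<int64_t>;  // stands in for NodeSortKey

struct Holder {
  // Raw storage with the right size and alignment; no constructor runs here.
  alignas(Payload) char payload_buf[sizeof(Payload)];
  Payload& payload() {
    return *reinterpret_cast<Payload*>(payload_buf);
  }
};

int main() {
  // Mimic tp_alloc: memory comes from a C allocator, so constructors do not run.
  Holder* h = static_cast<Holder*>(std::malloc(sizeof(Holder)));
  new (h->payload_buf) Payload();   // placement new: construct in place
  h->payload().push_back(42);
  std::cout << h->payload().back() << "\n";
  h->payload().~Payload();          // explicit destructor before freeing
  std::free(h);
}
```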
@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
-self->_sort_key = PyTuple_New(0);
return 0;
}
@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
-{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
-Py_VISIT(self->_sort_key);
return 0;
}
@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
-Py_CLEAR(self->_sort_key);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
+reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
static PyObject* NodeBase__remove_from_list(
PyObject* self,
PyObject* _ignored) {
reinterpret_cast<NodeBase*>(self)->remove_from_list();
Py_RETURN_NONE;
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
NodeBase* p = self->_prev;
p->set_next(x);
x->set_prev(p);
x->set_next(self);
self->set_prev(x);
// Now compute x.sort_key()
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
// METH_O => one argument: 'other'
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
bool less = std::lexicographical_compare(
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
if (less)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
// __gt__(self, other): Return self.sort_key() > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
// "a > b" is equivalent to "b < a"
bool greater = std::lexicographical_compare(
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
if (greater)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___gt__(self, other);
}
// __le__(self, other): Return not (self > other)
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___lt__(self, other);
}
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
const NodeSortKey& vec = node->sort_key();
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
THPObjectPtr tuple(PyTuple_New(n));
if (!tuple) {
return nullptr; // Out of memory
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* value = PyLong_FromSsize_t(vec[i]);
if (!value) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), i, value);
}
return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
PyObject* self,
PyObject* value,
void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
return -1;
}
Py_ssize_t size = PyTuple_GET_SIZE(value);
NodeSortKey new_vec;
new_vec.reserve(size);
for (Py_ssize_t i = 0; i < size; i++) {
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
if (val == -1 && PyErr_Occurred()) {
return -1;
}
new_vec.emplace_back(val);
}
node->sort_key() = std::move(new_vec);
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
-nullptr, /* tp_getset */
+NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
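The `_prepend` implementation above keeps nodes ordered by a lexicographic sort key so that `__lt__`/`__gt__` never have to walk the linked list. A standalone sketch of the key-derivation rule, mirroring the logic in the C++ (and in the Python it replaces below), with made-up keys:

```
# Standalone sketch of the sort-key rule used by _prepend (keys are made up).
def between(psk: tuple, nsk: tuple) -> tuple:
    """Return a key that sorts strictly between psk (prev) and nsk (next)."""
    if len(psk) > len(nsk):
        *prefix, idx = psk[: len(nsk) + 1]
        return (*prefix, idx + 1)
    elif len(psk) < len(nsk):
        *prefix, idx = nsk[: len(psk) + 1]
        return (*prefix, idx - 1)
    else:  # same length: extend by one position
        return (*psk, 0)

print(between((1,), (2,)))     # (1, 0)
print(between((1, 0), (2,)))   # (1, 1)
print(between((1,), (1, 1)))   # (1, 0)
assert (1,) < between((1,), (2,)) < (2,)
```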

View File

@ -385,41 +385,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
-assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
-if self == x:
-    log.debug(
-        "Trying to prepend a node to itself. This behavior has no effect on the graph."
-    )
-    return
-x._remove_from_list()
-p = self._prev
-p._next, x._prev = x, p
-x._next, self._prev = self, x
-# compute x._sort_key
-psk = x._prev._sort_key
-nsk = x._next._sort_key
-if len(psk) > len(nsk):
-    idx: int
-    *prefix, idx = psk[: len(nsk) + 1]
-    x._sort_key = (*prefix, idx + 1)
-elif len(psk) < len(nsk):
-    *prefix, idx = nsk[: len(psk) + 1]
-    x._sort_key = (*prefix, idx - 1)
-else: # same length, increase length by 1
-    x._sort_key = (*psk, 0)
-def __gt__(self, other: "Node") -> bool:
-    return self._sort_key > other._sort_key
-def __lt__(self, other: "Node") -> bool:
-    return self._sort_key < other._sort_key
-def __ge__(self, other: "Node") -> bool:
-    return self > other or self == other
-def __le__(self, other: "Node") -> bool:
-    return self < other or self == other
+self._prepend(x)
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@ -430,11 +396,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
-self._next.prepend(x)
-def _remove_from_list(self) -> None:
-    p, n = self._prev, self._next
-    p._next, n._prev = n, p
+self._next._prepend(x)
@property
def args(self) -> tuple[Argument, ...]:
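A short usage sketch of the unchanged public API after the port: `Node.prepend`/`Node.append` still reorder nodes and nodes still compare by graph position. The toy module and the `operator.add`/`operator.mul` targets below are assumptions for illustration:

```
import operator
import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return b

gm = fx.symbolic_trace(M())
add_node = next(n for n in gm.graph.nodes if n.target is operator.add)
mul_node = next(n for n in gm.graph.nodes if n.target is operator.mul)

# The public surface is unchanged by the C++ port: prepend/append reorder
# nodes in place, and nodes compare by their position in the graph.
mul_node.prepend(add_node)   # add_node already precedes mul_node; a no-op move
print(add_node < mul_node)   # True
print([n.op for n in gm.graph.nodes])
```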