e3d00beddd
Fix triu_/tril_ overlap handling
2025-10-21 07:54:24 -07:00
21131a2444
Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners ( #165481 )"
...
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.
Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171 ))
2025-10-21 14:15:55 +00:00
1009790ad8
[pytree][dynamo] trace on native optree functions for community pytree support ( #165860 )
...
Resolves #164972
- #164972
All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to the generic `optree` functions with those arguments left unhardcoded, so `torch.utils._cxx_pytree` functions remain traceable while community `optree` usages additionally gain dynamo support.
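A minimal sketch of what this enables, assuming this PR's polyfills and a recent `optree`; the tree contents here are illustrative:
```python
import optree
import torch

def fn(tree):
    # A generic optree call with the previously hardcoded arguments spelled
    # out explicitly; after this PR, dynamo should trace it rather than fall back.
    return optree.tree_map(lambda t: t * 2, tree, none_is_leaf=True, namespace="torch")

compiled = torch.compile(fn, fullgraph=True)
out = compiled({"a": torch.ones(2), "b": (torch.zeros(3),)})
```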
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
2025-10-21 14:13:08 +00:00
410e6a4321
Better error handling in torch/csrc/jit/frontend/* ( #165213 )
...
Refactor error handling to use TORCH_CHECK for improved clarity in constants and scope management in several files under torch/csrc/jit/frontend/*.
Fixes parts of issue https://github.com/pytorch/pytorch/issues/148114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165213
Approved by: https://github.com/FFFrog , https://github.com/albanD
2025-10-21 13:54:59 +00:00
23c55c5b66
[Code Clean]Replace assert statements with explicit if/raise patterns ( #165735 )
...
Fix part of #164878
Replace 75 assert statements with explicit if/raise patterns in `torch/ao/ns` (a sketch of the pattern follows the list), including:
- `torch/ao/ns/_numeric_suite_fx.py` - 5 asserts
- `torch/ao/ns/fx/graph_matcher.py` - 6 asserts
- `torch/ao/ns/fx/graph_passes.py` - 12 asserts
- `torch/ao/ns/fx/n_shadows_utils.py` - 20 asserts
- `torch/ao/ns/fx/pattern_utils.py` - 2 asserts
- `torch/ao/ns/fx/utils.py` - 21 asserts
- `torch/ao/ns/fx/weight_utils.py` - 19 asserts
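For reference, a minimal sketch of the conversion pattern (illustrative, not the actual diff; a plain `assert` is stripped under `python -O`, an explicit raise is not):
```python
from torch.fx import Node

def check_node(node):
    # Before: assert isinstance(node, Node), "expected an fx.Node"
    # After: an explicit check that survives optimized bytecode.
    if not isinstance(node, Node):
        raise TypeError(f"expected an fx.Node, got {type(node).__name__}")
```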
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165735
Approved by: https://github.com/albanD
2025-10-21 11:21:57 +00:00
1290b077f2
[dynamo][misc] Replace UserFunctionVariable with VariableTracker build ( #165707 )
...
Audit: To prevent future issues with functools.partial or callable objects.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
2025-10-21 09:27:41 +00:00
9f9ab881b2
[ROCm][inductor] heuristic improvements for reduction kernels ( #161280 )
...
Improvements to reduction kernel heuristics for MI350.
Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161280
Approved by: https://github.com/jansel , https://github.com/PaulZhang12 , https://github.com/eellison , https://github.com/jeffdaily
2025-10-21 07:48:54 +00:00
f2bb22ff84
[Inductor-FX] Support Tensor.item ( #165599 )
...
# Feature
This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`.
# Implementation
The implementation is fairly mechanical, following the usual flow for these types of PRs.
1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`.
2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR.
3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars.
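A hedged usage sketch of the pattern this supports (backend selection for the FX wrapper is internal and elided here; `capture_scalar_outputs` is the standard flag for allowing `.item()` in the graph):
```python
import torch
import torch._dynamo

torch._dynamo.config.capture_scalar_outputs = True  # allow .item() inside the graph

def f(x):
    # Scalar extraction is the operation codegen_dynamic_scalar handles in the wrapper.
    return x.item() + 1.0

compiled = torch.compile(f)
print(compiled(torch.tensor(2.0)))
```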
# Test plan
Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165599
Approved by: https://github.com/angelayi , https://github.com/jansel
2025-10-21 07:09:56 +00:00
03f3f7899c
[ATen] Add reduction tag to reduction operators ( #165155 )
...
Add a new 'reduction' tag to tags.yaml and apply it to 98 reduction
operator variants across 21 operator families (sum, mean, min, max,
argmin, argmax, amin, amax, aminmax, prod, all, any, norm, var, std,
std_mean, var_mean, nansum, logsumexp, count_nonzero, linalg_vector_norm).
This tag categorizes operators that perform reduction operations,
computing aggregate values across one or more dimensions of input
tensor(s).
Based on PR #153342 - co-written with @AlonSardas.
Just as with the pointwise tag, this can be useful for compiler passes or for opting into sharding rules.
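A small sketch of how a pass might consult the tag (the `reduction` tag name is from this PR; whether a given build has it depends on whether the PR has landed, so this checks by name):
```python
import torch

op = torch.ops.aten.sum.default
print(op.tags)  # on builds with this PR, the list includes the new reduction tag
is_reduction = any(str(t).endswith("reduction") for t in op.tags)
print(is_reduction)
```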
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165155
Approved by: https://github.com/ezyang , https://github.com/zou3519 , https://github.com/mlazos
2025-10-21 04:35:03 +00:00
771170807b
[dynamo][nn_module] Replace UserFunctionVariable with VariableTracker build ( #165708 )
...
Audit: To prevent future issues with functools.partial or callable objects.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165708
Approved by: https://github.com/Lucaskabela
2025-10-21 04:13:12 +00:00
ffa90d46e6
[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners ( #165481 )
...
* Moves rocm.yml from the persistent non-ARC runners on the combined MI2xx (MI210 + MI250) cluster to the ARC runners on the MI250 cluster. This halves the number of nodes but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.
Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165481
Approved by: https://github.com/jeffdaily
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com >
2025-10-21 04:02:04 +00:00
0e083942cc
Enable PLW0127 in ruff ( #165851 )
...
This PR enables `PLW0127` in ruff, which flags self-assignments of the form `var = var`.
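What PLW0127 flags, for reference:
```python
def f(x):
    x = x  # PLW0127: self-assignment of variable `x` (a no-op)
    return x
```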
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165851
Approved by: https://github.com/Lucaskabela
2025-10-21 03:30:57 +00:00
ce1fcff03e
[ROCm] Keep amdgpu-coerce-illegal-types flag if rocm version is less than 7.2 ( #165789 )
...
The `-amdgpu-coerce-illegal-types=1` flag targets the LLVM that ships in ROCm 6.3, 6.4, 7.0, and 7.1; it will not be in ROCm 7.2. It was added to enable performance improvements for composable kernel. ROCm 7.2 and newer changed the compiler so the flag is no longer needed to achieve those improvements, and keeping it with ROCm 7.2 breaks the PyTorch build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165789
Approved by: https://github.com/jithunnair-amd , https://github.com/jeffdaily
2025-10-21 03:17:33 +00:00
a238a9a100
Add clang-tidy misc-definitions-in-headers check ( #164959 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164959
Approved by: https://github.com/Skylion007 , https://github.com/mikaylagawarecki
ghstack dependencies: #164882 , #164956
2025-10-21 02:59:46 +00:00
fe69a2bbbd
Move from/to to torch::stable::detail ( #164956 )
...
To not pollute the global namespace, we should move the `from`/`to` APIs into torch::stable::detail. We are also following our normal deprecation cycle and choosing to continue exposing the global `from`/`to` for the time being as people who onboard their extensions onto 2.9 would not be able to build with 2.10 otherwise.
Note that this means that within libtorch, we do not get the luxury of tacking on a `using torch::stable::detail::from` because then it leads to build time ambiguous calls --> both the global and namespace APIs are exposed, which one do I want? So that is why you see every local site is updated.
Note that the update is _not_ necessary from a custom op writer point of view. FA3 can continue to build on torch nightlies without changing any code. (Since this is a header change, this PR has no implication on runtime, a previously built FA3 ABI stable wheel will continue to work fine with newer torch versions after this PR.)
Once TORCH_BOX lands, we would be free to remove these global APIs when the deprecation cycle is up (April 2026) and encourage people to use TORCH_BOX and avoid from/to entirely.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164956
Approved by: https://github.com/malfet
ghstack dependencies: #164882
2025-10-21 02:59:46 +00:00
0be0de4ffa
Add type suppressions to _inductor/runtime ( #165918 )
...
Original PR that did this was reverted due to merge conflicts.
Trying it again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165918
Approved by: https://github.com/oulgen
2025-10-21 02:54:22 +00:00
7406d2e665
[DeviceMesh] Clean up the call into mesh_resources to get root mesh ( #165787 )
...
We moved the method to get the root mesh into the class in https://github.com/pytorch/pytorch/pull/164510 . This PR further cleans up the code.
Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165787
Approved by: https://github.com/fegin
2025-10-21 02:54:04 +00:00
303c9cf048
Save Python refcount bump on each arg in maybe_handle_torch_function ( #164625 )
...
Pybind's API entails a small unnecessary overhead when working with args. (Similarly, we should probably be using vectorcall, but that's a bigger change for both us and pybind11.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164625
Approved by: https://github.com/albanD
ghstack dependencies: #164624
2025-10-21 02:40:12 +00:00
d7d4bb7c51
Add XPU part for persons_of_interest ( #165920 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165920
Approved by: https://github.com/albanD
2025-10-21 01:57:17 +00:00
0b1c462979
Making NumPy dependency in LocalTensor optional to fix broken Torchao CI ( #165938 )
...
In a recent change, LocalTensor introduced a dependency on NumPy, which broke the Torchao CI.
This dependency can be made optional and required only when LocalTensor is used.
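A minimal sketch of the optional-import pattern this implies (assumed shape, not the actual diff; the helper name is hypothetical):
```python
def _require_numpy():
    # Import NumPy lazily so only LocalTensor users pay for the dependency.
    try:
        import numpy as np
    except ImportError as e:
        raise RuntimeError("LocalTensor requires numpy; please install it") from e
    return np
```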
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165938
Approved by: https://github.com/atalman
2025-10-21 01:46:53 +00:00
4a6cf0a93e
Fix dynamo stack trace ( #165930 )
...
Fixes #165911
- Add the message to the AttributeError so we see `Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])` instead of just `Developer debug context: raised exception AttributeError([])`
- Add the stack trace to `ObservedException` so we display the innermost error stack trace back to user code
Output:
```
/data/users/shangdiy/pytorch/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
warnings.warn(
Traceback (most recent call last):
File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1385, in var_getattr
subobj = self._getattr_static(name)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1256, in _getattr_static
subobj = type(self.value).__getattribute__(self.value, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Linear' object has no attribute 'w'
During handling of the above exception, another exception occurred:
torch._dynamo.exc.ObservedAttributeError: 'Linear' object has no attribute 'w'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/shangdiy/pytorch/test.py", line 34, in <module>
mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 481, in inner
out = fullgraph_capture(
^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1053, in fullgraph_capture
return _fullgraph_capture_frame(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1115, in _fullgraph_capture_frame
raise e.with_traceback(None) from e.__cause__ # User compiler error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html
from user code:
File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 171, in forward
res = self._export_root(*args, **kwargs)
File "/data/users/shangdiy/pytorch/test.py", line 31, in forward
weight = self.linear.w
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165930
Approved by: https://github.com/anijain2305
2025-10-21 01:32:23 +00:00
4c963a68d7
Use inline instead of anon namespace for stableivalue from/to ( #164882 )
...
Fixes https://github.com/pytorch/pytorch/issues/163343 .
After some consideration, I propose we remove the anonymous namespace around from/to in favor of:
1. Adding inline to the function implementations, assuming that they will not change in the near future
2. If we decide to change them, we will wrap the code in inline versioned namespaces such that the implementations within any versioned namespace will be guaranteed identical.
Note that:
- We eventually intend to abstract away usage of `from`/`to` (related: @lw's TORCH_BOX work)
- The from/to implementations are now powered through class template specializations, where adding a specialization does not change the from/to signatures.
I do plan to deprecate top-level from/to in favor of torch::stable::detail::from/to consequently. This way we can stop polluting the global namespace.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164882
Approved by: https://github.com/lw , https://github.com/albanD
2025-10-21 00:12:15 +00:00
b20deec3d1
[PP] Add optional argument to not save outputs ( #165822 )
...
Fix https://github.com/pytorch/pytorch/issues/159251
Add an optional argument `return_outputs` to the schedule `step`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165822
Approved by: https://github.com/wconstab
2025-10-21 00:09:31 +00:00
51d0d8ee67
[ATen] Fix CUDA reduction warp shuffle order ( #164790 )
...
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166 " />
which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8 " />
Switch the warp shuffle order to make bitwise equivalence between the two easier.
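The motivation in miniature: floating-point reduction is order-sensitive, so matching the shuffle order to Triton's tree makes bitwise comparison meaningful. A small CPU-side illustration of order sensitivity (not the kernel change itself):
```python
import torch

torch.manual_seed(0)
x = torch.randn(1 << 20, dtype=torch.float32)
a = x.view(2, -1).sum(dim=1).sum()  # one reduction tree
b = x.view(-1, 2).sum(dim=1).sum()  # a different reduction tree
print(a.item(), b.item(), a.item() == b.item())  # typically differ in the last bits
```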
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/
Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape Operation New all dims (ms) New dim=0 (ms) New dim=1 (ms) Old all dims (ms) Old dim=0 (ms) Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024) mean 0.015817 0.016259 0.013642 0.015990 0.016258 0.013631
(1024, 1024) sum 0.015917 0.015906 0.013359 0.015707 0.016266 0.013226
(1024, 1024) min 0.016021 0.024625 0.015631 0.015761 0.024485 0.015317
(1024, 1024) max 0.016349 0.024971 0.015972 0.015771 0.025001 0.015314
(1024, 1024) argmin 0.018070 0.024448 0.015578 0.018135 0.025370 0.015322
(1024, 1024) argmax 0.018427 0.024859 0.015932 0.018164 0.024452 0.015639
(1024, 1024) var 0.020078 0.026413 0.020295 0.020199 0.026381 0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048) mean 0.023826 0.023726 0.022273 0.023236 0.023776 0.022248
(2048, 2048) sum 0.023840 0.023355 0.021974 0.023294 0.023354 0.021884
(2048, 2048) min 0.024519 0.041263 0.024620 0.023292 0.041491 0.024358
(2048, 2048) max 0.024509 0.041670 0.024277 0.023334 0.041231 0.024395
(2048, 2048) argmin 0.026125 0.041282 0.024567 0.026772 0.041773 0.024296
(2048, 2048) argmax 0.026117 0.041487 0.024572 0.026412 0.041477 0.024273
(2048, 2048) var 0.026603 0.048581 0.031308 0.027587 0.048603 0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096) mean 0.053927 0.057070 0.054073 0.053028 0.057544 0.053935
(4096, 4096) sum 0.053604 0.057410 0.054451 0.053076 0.057033 0.054266
(4096, 4096) min 0.054293 0.109122 0.058363 0.053821 0.108689 0.058382
(4096, 4096) max 0.054258 0.108035 0.058703 0.053492 0.110552 0.058376
(4096, 4096) argmin 0.056805 0.111167 0.058301 0.056836 0.112325 0.058292
(4096, 4096) argmax 0.056488 0.110958 0.058636 0.056844 0.111000 0.057928
(4096, 4096) var 0.058936 0.141755 0.068693 0.059735 0.141284 0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192) mean 0.145552 0.148082 0.138647 0.145364 0.147818 0.138207
(8192, 8192) sum 0.145985 0.147900 0.138714 0.145755 0.148031 0.138616
(8192, 8192) min 0.146566 0.205359 0.192739 0.145611 0.205237 0.182335
(8192, 8192) max 0.146526 0.204844 0.193050 0.146073 0.205457 0.182697
(8192, 8192) argmin 0.150190 0.206605 0.192543 0.150654 0.206847 0.182007
(8192, 8192) argmax 0.150481 0.206368 0.192535 0.150845 0.206430 0.182022
(8192, 8192) var 0.150884 0.184546 0.203900 0.151594 0.184172 0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128) mean 0.014293 0.008119 0.014533 0.013861 0.008022 0.014449
(1, 1024, 128) sum 0.014039 0.007877 0.014111 0.014219 0.008227 0.014045
(1, 1024, 128) min 0.014159 0.011354 0.023493 0.014271 0.010862 0.023644
(1, 1024, 128) max 0.014154 0.011027 0.023368 0.014259 0.011234 0.023692
(1, 1024, 128) argmin 0.016403 0.005677 0.023328 0.016273 0.005683 0.024073
(1, 1024, 128) argmax 0.016734 0.005675 0.023437 0.016580 0.005318 0.023331
(1, 1024, 128) var 0.018338 0.009549 0.025538 0.018528 0.009391 0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128) mean 0.014873 0.010131 0.015546 0.015123 0.010131 0.015481
(5, 1024, 128) sum 0.015334 0.009673 0.015824 0.014736 0.009671 0.015438
(5, 1024, 128) min 0.015047 0.013252 0.024573 0.014803 0.013163 0.024551
(5, 1024, 128) max 0.015050 0.013339 0.024197 0.014810 0.013525 0.024230
(5, 1024, 128) argmin 0.017341 0.012737 0.024306 0.017471 0.012379 0.024991
(5, 1024, 128) argmax 0.017345 0.012411 0.024421 0.017422 0.012471 0.024237
(5, 1024, 128) var 0.019973 0.011453 0.026188 0.020050 0.011438 0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128) mean 0.016976 0.011575 0.016831 0.016722 0.011927 0.017173
(10, 1024, 128) sum 0.017039 0.011841 0.017159 0.016385 0.011860 0.016753
(10, 1024, 128) min 0.017036 0.015331 0.026770 0.016944 0.015205 0.027166
(10, 1024, 128) max 0.017369 0.015348 0.027077 0.016531 0.015716 0.026819
(10, 1024, 128) argmin 0.019203 0.014447 0.026813 0.018994 0.014497 0.027313
(10, 1024, 128) argmax 0.019563 0.014795 0.027140 0.019460 0.014912 0.026733
(10, 1024, 128) var 0.020529 0.014316 0.030405 0.020719 0.013960 0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128) mean 0.045046 0.039168 0.046082 0.044839 0.039217 0.045782
(100, 1024, 128) sum 0.045094 0.039150 0.045777 0.044496 0.039542 0.046083
(100, 1024, 128) min 0.045768 0.054466 0.076244 0.044915 0.053943 0.076599
(100, 1024, 128) max 0.045748 0.054459 0.076188 0.044931 0.053949 0.076856
(100, 1024, 128) argmin 0.048275 0.054046 0.076647 0.048694 0.054105 0.077004
(100, 1024, 128) argmax 0.048267 0.054395 0.077401 0.048691 0.054131 0.076751
(100, 1024, 128) var 0.049710 0.043254 0.083077 0.050971 0.043251 0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100) mean 0.202312 0.196723 0.197765 0.201774 0.196641 0.197459
(1000, 1000, 100) sum 0.202651 0.196682 0.197736 0.202175 0.196313 0.197523
(1000, 1000, 100) min 0.203022 0.264762 0.269200 0.202729 0.264129 0.268694
(1000, 1000, 100) max 0.202864 0.264396 0.269388 0.202486 0.263896 0.268720
(1000, 1000, 100) argmin 0.226727 0.263781 0.268651 0.226597 0.264676 0.268983
(1000, 1000, 100) argmax 0.226412 0.264469 0.269090 0.226570 0.264595 0.269178
(1000, 1000, 100) var 0.243223 0.204079 0.216096 0.241942 0.204079 0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100) mean 0.016193 0.020277 0.014316 0.016152 0.020324 0.013712
(10000, 100) sum 0.016289 0.020237 0.014034 0.016168 0.020265 0.013708
(10000, 100) min 0.016046 0.030872 0.019609 0.016208 0.030867 0.018627
(10000, 100) max 0.016369 0.030835 0.019257 0.016218 0.030861 0.018209
(10000, 100) argmin 0.017957 0.031171 0.019517 0.018050 0.031556 0.018077
(10000, 100) argmax 0.017961 0.031658 0.019521 0.018060 0.031564 0.018087
(10000, 100) var 0.020393 0.035652 0.019339 0.020144 0.035987 0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10) mean 0.015718 0.016576 0.016555 0.015999 0.016246 0.014869
(100000, 10) sum 0.015833 0.016247 0.016572 0.016007 0.016627 0.014872
(100000, 10) min 0.015888 0.020510 0.023920 0.015671 0.020821 0.021417
(100000, 10) max 0.015889 0.020479 0.023918 0.016077 0.020386 0.021421
(100000, 10) argmin 0.018233 0.020863 0.023647 0.017574 0.020864 0.021103
(100000, 10) argmax 0.017896 0.020527 0.023296 0.017569 0.020447 0.021098
(100000, 10) var 0.020005 0.024198 0.024372 0.020075 0.024167 0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023) mean 1.874816 1.963506 1.903909 1.873279 1.963859 1.903230
(1023, 1023, 1023) sum 1.875030 1.965716 1.902458 1.873566 1.960730 1.901642
(1023, 1023, 1023) min 1.878563 2.473455 2.179092 1.875174 2.482086 2.183027
(1023, 1023, 1023) max 1.879128 2.474803 2.178895 1.874831 2.482253 2.183884
(1023, 1023, 1023) argmin 1.921800 2.476629 2.174831 1.923987 2.472641 2.170453
(1023, 1023, 1023) argmax 1.922605 2.476688 2.177927 1.923366 2.472808 2.172979
(1023, 1023, 1023) var 1.972606 3.088695 2.758797 1.978679 3.095658 2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255) mean 0.489984 0.500954 0.492957 0.489891 0.500654 0.491971
(1023, 1023, 255) sum 0.490228 0.500764 0.492289 0.489624 0.501089 0.492824
(1023, 1023, 255) min 0.491457 0.563560 0.553334 0.490355 0.564709 0.554754
(1023, 1023, 255) max 0.491396 0.563628 0.553345 0.490017 0.565004 0.554947
(1023, 1023, 255) argmin 0.503666 0.561512 0.551831 0.503845 0.560972 0.551017
(1023, 1023, 255) argmax 0.503602 0.561185 0.551407 0.504328 0.561267 0.551448
(1023, 1023, 255) var 0.510844 0.709452 0.701630 0.512693 0.710365 0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377) mean 0.707439 0.727646 0.712019 0.706769 0.727101 0.711632
(1023, 1023, 377) sum 0.707780 0.727453 0.711554 0.706807 0.726656 0.711729
(1023, 1023, 377) min 0.709423 0.819809 0.794379 0.707847 0.822086 0.796664
(1023, 1023, 377) max 0.709297 0.819780 0.794308 0.707566 0.821913 0.796690
(1023, 1023, 377) argmin 0.725028 0.817088 0.791695 0.726039 0.816445 0.790828
(1023, 1023, 377) argmax 0.725301 0.817011 0.791420 0.726040 0.816917 0.791143
(1023, 1023, 377) var 0.740859 1.034165 1.006712 0.743413 1.035506 1.007638
```
Differential Revision: [D85022826](https://our.internmc.facebook.com/intern/diff/D85022826 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel , https://github.com/eqy
2025-10-21 00:09:13 +00:00
70592c6819
[ROCm][CI] Move gfx1100 workflows to own yaml file ( #165699 )
...
This should allow us to move the gfx1100 workflow to a lower frequency and also allow it to be triggered on PRs via a dedicated label, for any PRs that target Navi fixes such as [this](https://github.com/pytorch/pytorch/pull/165630 ) or [this](https://github.com/pytorch/pytorch/pull/165625 ).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165699
Approved by: https://github.com/jeffdaily
2025-10-20 23:52:48 +00:00
259cb945f5
[stage 2c] make autograd and inference functions ( #165668 )
...
Add final stage of aot_stage2_compile for autograd and inference.
Differential Revision: D84844699
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165668
Approved by: https://github.com/zhxchen17 , https://github.com/tugsbayasgalan
2025-10-20 23:50:31 +00:00
e20c9bf288
[torch/utils][Code Clean] Clean asserts in torch/utils/*.py ( #165410 )
...
Including:
- `torch/utils/*.py`
Fixes part of #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165410
Approved by: https://github.com/albanD
2025-10-20 23:29:17 +00:00
99c8640b5d
[1/N] Change C-style casts to static_cast or reinterpret_cast ( #165750 )
...
This series of changes converts C-style casts to their C++ alternatives.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 23:27:13 +00:00
96b0e7aaa6
[Code Clean] Clean asserts in torch/ao/quantization/experimental/* and torch/ao/quantization/pt2e/* ( #165317 )
...
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/experimental/* (11 errors)
- torch/ao/quantization/pt2e/* (68 errors)
Fixes part of #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165317
Approved by: https://github.com/albanD
2025-10-20 23:07:11 +00:00
850ba8c96d
[Code Clean] Clean asserts in torch/autograd. ( #165627 )
...
Replaces 78 assert statements across 10 files in torch.autograd with explicit if-checks raising AssertionError, to prevent the checks from being disabled by the Python -O flag. This ensures error checking remains active in optimized builds.
Fixes part of #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165627
Approved by: https://github.com/albanD
2025-10-20 23:03:47 +00:00
1bcd736f91
fix bad merge duplicate pre pass ( #165917 )
...
Fix for https://github.com/pytorch/pytorch/issues/165624 - we were applying the pre pass multiple times.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165917
Approved by: https://github.com/bdhirsh
2025-10-20 22:54:36 +00:00
df64c0c464
[Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) ( #165433 )
...
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/ (root)
- torch/ao/quantization/quantizer/
- torch/ao/quantization/backend_config/
Fixes part of #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165433
Approved by: https://github.com/albanD
2025-10-20 22:42:51 +00:00
1891239a1d
[Graph Partition] fix graph partition input signature for fallback kernels ( #165815 )
...
The scheduler relies on node.last_usage to free buffers. `last_usage` may contain a buffer that is allocated in a previous graph partition AND not directly accessed in the current graph partition.
## Example
```python
def f(x):
y = x + 1
z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
z_cpu = z.cpu()
u_cuda = z_cpu.cuda()
return u_cuda
```
In the generated code, we have
```
def partition_0(args):
...
# Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn) # < ------ buf1 is a view of buf0
buf2 = buf1 # <------- buf2 is buf1
assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
return (buf2, )
def call(self, args):
...
(buf2,) = self.partitions[0](partition0_args)
...
buf3.copy_(buf2, False)
del buf0
del buf1
del buf2 # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
...
```
Note: view is treated as a fallback kernel due to its special dtype.
de09bab4b6/torch/_inductor/lowering.py (L841-L843)
## Fix
This PR fixes the issue by also returning these buffers to be freed later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165815
Approved by: https://github.com/eellison
2025-10-20 22:23:29 +00:00
cf280ca1e8
Revert "[Inductor] Naive foreach autotune support ( #162053 )"
...
This reverts commit 779296a3fce5db0829377c792f13a8eafe537b30.
Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3423808492 ))
2025-10-20 21:36:44 +00:00
efc277cac7
[annotation] add logging for debugging annotation ( #165797 )
...
Add logging for debugging annotation bugs. Logs will show with `TORCH_LOGS="+annotation"`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165797
Approved by: https://github.com/ezyang , https://github.com/Skylion007 , https://github.com/SherlockNoMad
2025-10-20 21:27:38 +00:00
4f7f43253d
Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners ( #165481 )"
...
This reverts commit 8700d68fef855850e2e0aa65056a77b8f80adbdb.
Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/malfet due to Broke lint somehow, see 8f06a1308f/1 ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3423642456 ))
2025-10-20 20:39:56 +00:00
779296a3fc
[Inductor] Naive foreach autotune support ( #162053 )
...
Initial autotuning support for foreach kernels, yielding a 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. Removes num_warps from the kernel definition to enable autotune support in the generated wrapper code.
Before:
triton_for_fused_18.kd | 4.986 ms | 4.986 ms | 2.493 ms | 2
triton_for_fused_6.kd  | 0.098 ms | 0.098 ms | 0.049 ms | 2
triton_for_fused_7.kd  | 0.036 ms | 0.036 ms | 0.018 ms | 2
After:
triton_for_fused_18.kd | 1.273 ms | 1.273 ms | 0.636 ms | 2
triton_for_fused_6.kd  | 0.044 ms | 0.044 ms | 0.022 ms | 2
triton_for_fused_7.kd  | 0.024 ms | 0.024 ms | 0.012 ms | 2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos , https://github.com/naromero77amd
2025-10-20 20:39:04 +00:00
8f06a1308f
[MPS] slightly faster cholesky ( #165867 )
...
Slightly faster Cholesky: removed one redundant simdgroup_multiply.
[figure: benchmark results]
Benchmarks generated with the following script (measured on an M1 Pro):
```
import torch
import numpy as np
import time
import csv

matrix_sizes = [512, 1024, 2048, 4096]
batch_sizes = [1, 2, 4, 8, 16]
num_runs = 10
warmup_runs = 3

def create_spd_matrix(n, batch_size):
    torch.manual_seed(42)
    A = torch.randn(batch_size, n, n, dtype=torch.float32)
    return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

def run_cholesky_mps(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    b = torch.linalg.cholesky(A, upper=False)
    torch.mps.synchronize()
    end = time.perf_counter()
    return b, end - start

results = {
    'N': [],
    'batch_size': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    for batch_size in batch_sizes:
        print(f"\nBenchmarking N={n}, batch_size={batch_size}")
        try:
            A_cpu = create_spd_matrix(n, batch_size)
            A_mps = A_cpu.to("mps")
            for _ in range(warmup_runs):
                _, _ = run_cholesky_mps(A_mps)
            times = []
            for _ in range(num_runs):
                _, t = run_cholesky_mps(A_mps)
                times.append(t)
            mean_time = np.mean(times)
            std_time = np.std(times)
            results['N'].append(n)
            results['batch_size'].append(batch_size)
            results['mean_time'].append(mean_time)
            results['std_time'].append(std_time)
            print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")
        except RuntimeError as e:
            print(f"Error for N={n}, batch_size={batch_size}: {e}")
            continue

with open('cholesky_benchmark_times.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'batch_size', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['batch_size'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165867
Approved by: https://github.com/malfet
2025-10-20 18:56:17 +00:00
240c13394e
Revert "[inductor] require shape in TritonCSEVariable ( #162275 )"
...
This reverts commit 3af2f0c12accc6bd10ef2b76fb5c51aa0f6b73a3.
Reverted https://github.com/pytorch/pytorch/pull/162275 on behalf of https://github.com/clee2000 due to still failing due to the above D84932446 ([comment](https://github.com/pytorch/pytorch/pull/162275#issuecomment-3423153819 ))
2025-10-20 17:55:54 +00:00
150682ba7f
Revert "Remove workaround to old CUDA bug ( #164354 )"
...
This reverts commit 26f38034332a99f2bdcc67ce1f4ba9403d420e52.
Reverted https://github.com/pytorch/pytorch/pull/164354 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164354#issuecomment-3423132083 ))
2025-10-20 17:48:08 +00:00
ca7360e996
Revert "Move toString(ScalarType) and ScalarType ostream operator to headeronly ( #164405 )"
...
This reverts commit ca8bd5dbedb5b46f78026e0378b0f47500ddba38.
Reverted https://github.com/pytorch/pytorch/pull/164405 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164354#issuecomment-3423132083 ))
2025-10-20 17:48:08 +00:00
0bf604320f
Revert "[dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build ( #165706 )"
...
This reverts commit 1dc9a05d0323ee3c7a20945c62463959d40f1a51.
Reverted https://github.com/pytorch/pytorch/pull/165706 on behalf of https://github.com/clee2000 due to breaking internal tests D84961097 ([comment](https://github.com/pytorch/pytorch/pull/165706#issuecomment-3423059867 ))
2025-10-20 17:28:58 +00:00
9875e70da8
Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build ( #165707 )"
...
This reverts commit 630520b346b8883db7821562e589ccde7d12687a.
Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to breaking internal tests D84961097 ([comment](https://github.com/pytorch/pytorch/pull/165706#issuecomment-3423059867 ))
2025-10-20 17:28:58 +00:00
69a4bfe8bb
Revert "Refactor out headeronly ArrayRef ( #164991 )"
...
This reverts commit 3806e9767b03d06edc317cb90a3a996abdf192a0.
Reverted https://github.com/pytorch/pytorch/pull/164991 on behalf of https://github.com/clee2000 due to breaking internal tests D84961075 ([comment](https://github.com/pytorch/pytorch/pull/164991#issuecomment-3423058017 ))
2025-10-20 17:26:42 +00:00
62a263b8d4
Revert "Widen ops support to take in IntHOArrayRef vs only std::vec ( #165152 )"
...
This reverts commit e4454947e2c692db1a249591121f8583fefe7df1.
Reverted https://github.com/pytorch/pytorch/pull/165152 on behalf of https://github.com/clee2000 due to breaking internal tests D84961075 ([comment](https://github.com/pytorch/pytorch/pull/164991#issuecomment-3423058017 ))
2025-10-20 17:26:42 +00:00
0da1f911dc
Revert "[Submodule] Bump FBGEMM to latest ( #165544 )"
...
This reverts commit 23417ae50f5d9bc02e988d916c103ff3a03c5903.
Reverted https://github.com/pytorch/pytorch/pull/165544 on behalf of https://github.com/clee2000 due to failing in internal D84996252, probably needs some sort of update to fbgemm internally? ([comment](https://github.com/pytorch/pytorch/pull/165544#issuecomment-3422993703 ))
2025-10-20 17:06:07 +00:00
8700d68fef
[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners ( #165481 )
...
* Moves rocm.yml from the persistent non-ARC runners on the combined MI2xx (MI210 + MI250) cluster to the ARC runners on the MI250 cluster. This halves the number of nodes but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.
Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165481
Approved by: https://github.com/jeffdaily , https://github.com/pruthvistony , https://github.com/albanD
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com >
2025-10-20 16:06:37 +00:00
ab82456c16
Revert "[1/N] Change C-style casts to static_cast or reinterpret_cast ( #165750 )"
...
This reverts commit e1e8491b316df810388d9fa24f135cdba27ab40e.
Reverted https://github.com/pytorch/pytorch/pull/165750 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165750#issuecomment-3422413890 ))
2025-10-20 14:51:58 +00:00
b23f4687fd
[Inductor][CuTeDSL] Move load_template up two directories ( #165868 )
...
Summary:
This is a reland of https://github.com/pytorch/pytorch/pull/165347
Moves the function used to load CuTeDSL Jinja templates up one level, out of the flex attention folder. This way it can be used for more general Inductor templates in the future.
Test Plan: test/inductor/test_flex_flash
Differential Revision: D85013024
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165868
Approved by: https://github.com/jananisriram
2025-10-20 12:14:38 +00:00
2705937080
[CI] Add rocm CI back to trunk for pre-submit/PR jobs ( #165674 )
...
Only adding single-GPU shards for now, to observe how current capacity handles it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165674
Approved by: https://github.com/jeffdaily
2025-10-20 12:14:06 +00:00
c1eda348be
[cuda] fix triu/tril int32 overflow for large matrices ( #164705 )
...
Fixes #136611
Cast blockIdx.x to int64_t before multiplication to prevent overflow when computing linear_idx for matrices larger than 2^31 elements.
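A repro sketch of the overflow from the linked issue, assuming a GPU with enough memory (the shape just needs more than 2^31 elements):
```python
import torch

# 50000 * 50000 = 2.5e9 > 2**31 elements; before this fix the int32 index math
# in the triu/tril kernel overflowed for tensors this large.
x = torch.ones(50000, 50000, device="cuda", dtype=torch.float16)
y = torch.triu(x)
torch.cuda.synchronize()
```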
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164705
Approved by: https://github.com/eqy , https://github.com/ngimel
2025-10-20 07:17:41 +00:00
ba93d5636e
[cuda] fix nll_loss2d backward bounds check with reduction=none ( #165247 )
...
Fixes #49882
Add the missing bounds check in the nll_loss2d backward kernel with reduction=none. The forward kernel already had a CUDA_KERNEL_ASSERT for target bounds; the backward kernel now matches.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165247
Approved by: https://github.com/ngimel
2025-10-20 06:25:11 +00:00
722b2b86c9
[dynamo] Remove duplicated guards ( #165806 )
...
This was found by looking at a tlparse of an internal job. We will need a deeper audit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165806
Approved by: https://github.com/jansel
2025-10-20 05:50:33 +00:00
e1e8491b31
[1/N] Change C-style casts to static_cast or reinterpret_cast ( #165750 )
...
This series of changes converts C-style casts to their C++ alternatives.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 04:36:19 +00:00
767199fd9b
[flex_attention] replace sliced BlockMask noop with helpful error ( #164702 )
...
Fixes part of #163314
After slicing a BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silently incorrect results when users applied transformations to `sliced_mask.mask_mod`.
Replace the noop with `_sliced_mask_mod_error`, which raises a RuntimeError with guidance to use `base_mask.mask_mod` instead.
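A hedged sketch of the new behavior (BlockMask built via the public flex_attention helper; the exact error message from `_sliced_mask_mod_error` is assumed):
```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

mask = create_block_mask(causal, B=1, H=1, Q_LEN=256, KV_LEN=256, device="cpu")
sliced = mask[:, :, :128]
# Previously this silently became a noop mask; now calling it raises a
# RuntimeError pointing users to the base mask's mask_mod instead.
sliced.mask_mod(0, 0, 0, 0)
```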
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg , https://github.com/BoyuanFeng
2025-10-20 03:46:16 +00:00
602ace5eb4
Revert "[ATen] Fix CUDA reduction warp shuffle order ( #164790 )"
...
This reverts commit 36371b8ec7a1baed255c18451b2c716386a54c95.
Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/clee2000 due to was reverted due to failing internal tests after merge D84992607 ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3420373755 ))
2025-10-20 03:06:52 +00:00
47804ce467
Revert "12/n : Remove fbandroid_compiler_flags ( #165558 )"
...
This reverts commit aead9270f56ebc7302c7f5fa7e5dff959f26608e.
Reverted https://github.com/pytorch/pytorch/pull/165558 on behalf of https://github.com/clee2000 due to Diff was actually reverted internally D84832629 ([comment](https://github.com/pytorch/pytorch/pull/165558#issuecomment-3420367955 ))
2025-10-20 03:03:13 +00:00
e8cb34dd52
[Inductor] support masked vectorization for the tail_loop for fp8 datatype ( #163324 )
...
**Summary:**
Support masked vectorization for the tail_loop for fp8 datatype.
**Example:**
```
import torch

def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```
**Generated code:**
- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(432L); x0_tail < static_cast<int64_t>(441L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = c10::convert<float>(tmp0);
                        auto tmp2 = static_cast<float>(100.0);
                        auto tmp3 = float(tmp1 - tmp2);
                        auto tmp4 = static_cast<float>(0.01);
                        auto tmp5 = float(tmp3 * tmp4);
                        auto tmp6 = c10::convert<float>(tmp5);
                        auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
                        auto tmp8 = float(tmp7 * tmp2);
                        auto tmp9 = std::nearbyint(tmp8);
                        auto tmp10 = float(tmp9 + tmp2);
                        auto tmp11 = static_cast<float>(-128.0);
                        auto tmp12 = max_propagate_nan(tmp10, tmp11);
                        auto tmp13 = static_cast<float>(127.0);
                        auto tmp14 = min_propagate_nan(tmp12, tmp13);
                        auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
                    }
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen , https://github.com/mingfeima , https://github.com/jansel
ghstack dependencies: #163316
2025-10-20 01:56:00 +00:00
e9d8973427
[Inductor] support masked vectorization for the tail_loop for float64 datatype ( #163316 )
...
**Summary:**
Support masked vectorization for the tail_loop for float64 datatype.
**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```
**Generated code:**
- Before
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(480L); x0_tail < static_cast<int64_t>(484L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = double(tmp0 * tmp0);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
                    }
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163316
Approved by: https://github.com/mingfeima , https://github.com/jansel
2025-10-20 01:41:38 +00:00
61d9a5180e
[Fix XPU CI] [Inductor UT] Fix test cases broken by community. ( #165714 )
...
Fixes #165719 , Fixes #165771
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165714
Approved by: https://github.com/jansel
2025-10-19 23:59:04 +00:00
8a8329b51f
[ATen] Switch order of blocked reduce when vectorizing loads ( #165178 )
...
Performance benchmarking, perf neutral:
```
================================================================================================================================================================================================================================================
Tensor Shape Operation Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full diff % Non-Contig diff % Contig diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256) mean 0.015684 0.017056 0.008287 0.016015 0.016929 0.008170 -2.07% +0.75% +1.43%
(256, 256) sum 0.015774 0.016638 0.007926 0.015811 0.016935 0.008330 -0.23% -1.75% -4.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512) mean 0.013385 0.025742 0.008629 0.013046 0.026005 0.008924 +2.60% -1.01% -3.31%
(512, 512) sum 0.013390 0.026059 0.009116 0.013054 0.025696 0.008952 +2.57% +1.41% +1.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024) mean 0.014213 0.015467 0.010334 0.013862 0.015082 0.010318 +2.53% +2.55% +0.16%
(1024, 1024) sum 0.014179 0.015446 0.010774 0.014132 0.015073 0.010350 +0.33% +2.47% +4.10%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048) mean 0.018234 0.019487 0.014812 0.018482 0.019397 0.014802 -1.34% +0.46% +0.07%
(2048, 2048) sum 0.018202 0.019529 0.015195 0.018122 0.019485 0.015129 +0.44% +0.23% +0.44%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096) mean 0.033582 0.039378 0.030751 0.033810 0.039673 0.031019 -0.67% -0.74% -0.86%
(4096, 4096) sum 0.033604 0.039777 0.030809 0.033530 0.039386 0.031113 +0.22% +0.99% -0.98%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192) mean 0.085824 0.091133 0.084200 0.085431 0.091364 0.084303 +0.46% -0.25% -0.12%
(8192, 8192) sum 0.085763 0.091442 0.084180 0.085508 0.091419 0.084595 +0.30% +0.03% -0.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384) mean 0.146480 0.147666 0.138807 0.146515 0.147987 0.138930 -0.02% -0.22% -0.09%
(8192, 16384) sum 0.146446 0.147593 0.138559 0.146151 0.147982 0.139120 +0.20% -0.26% -0.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768) mean 0.266047 0.265386 0.253837 0.265648 0.265885 0.253652 +0.15% -0.19% +0.07%
(8192, 32768) sum 0.266093 0.265421 0.253890 0.265458 0.265591 0.253567 +0.24% -0.06% +0.13%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536) mean 0.498632 0.508976 0.481865 0.498237 0.508777 0.481476 +0.08% +0.04% +0.08%
(8192, 65536) sum 0.498917 0.508202 0.481883 0.498104 0.508016 0.481972 +0.16% +0.04% -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072) mean 0.957633 0.968519 0.938172 0.956766 0.968267 0.938196 +0.09% +0.03% -0.00%
(8192, 131072) sum 0.956972 0.968140 0.937741 0.957365 0.968404 0.938056 -0.04% -0.03% -0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144) mean 1.906661 1.928377 1.861846 1.907327 1.928811 1.862083 -0.03% -0.02% -0.01%
(8192, 262144) sum 1.905976 1.928362 1.862399 1.907098 1.928844 1.861782 -0.06% -0.02% +0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144) mean 0.956852 0.970101 0.936524 0.957263 0.969809 0.936965 -0.04% +0.03% -0.05%
(4096, 262144) sum 0.957117 0.969933 0.936247 0.956675 0.969451 0.936395 +0.05% +0.05% -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144) mean 0.498813 0.511299 0.483415 0.498567 0.511482 0.483376 +0.05% -0.04% +0.01%
(2048, 262144) sum 0.498813 0.510834 0.483641 0.498875 0.511036 0.483338 -0.01% -0.04% +0.06%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144) mean 0.266157 0.276751 0.255192 0.265966 0.276808 0.255544 +0.07% -0.02% -0.14%
(1024, 262144) sum 0.266133 0.276709 0.255528 0.265658 0.276685 0.255287 +0.18% +0.01% +0.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072) mean 0.085941 0.081184 0.087931 0.085591 0.080832 0.088008 +0.41% +0.44% -0.09%
(512, 131072) sum 0.085962 0.081107 0.088045 0.085882 0.081160 0.088024 +0.09% -0.07% +0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000) mean 0.014203 0.045859 0.010310 0.013885 0.046132 0.010621 +2.29% -0.59% -2.93%
(1000, 1000) sum 0.014180 0.046165 0.010756 0.013893 0.046109 0.010338 +2.07% +0.12% +4.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129) mean 0.012953 0.016751 0.008536 0.012977 0.016714 0.008916 -0.18% +0.22% -4.26%
(1024, 129) sum 0.013356 0.016806 0.008722 0.013003 0.017071 0.008611 +2.71% -1.55% +1.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257) mean 0.013075 0.016787 0.009102 0.013116 0.016769 0.008679 -0.31% +0.11% +4.87%
(1024, 257) sum 0.013092 0.016842 0.008786 0.013126 0.017128 0.008771 -0.26% -1.67% +0.17%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587) mean 0.013662 0.017412 0.010055 0.013659 0.017019 0.010033 +0.02% +2.31% +0.22%
(1024, 587) sum 0.013636 0.017473 0.010163 0.013642 0.017363 0.010101 -0.04% +0.63% +0.61%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977) mean 0.015276 0.027873 0.012531 0.015241 0.027783 0.012467 +0.23% +0.32% +0.51%
(2048, 977) sum 0.015345 0.027949 0.012192 0.015255 0.027839 0.012485 +0.59% +0.40% -2.35%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128) mean 0.012806 0.014020 0.008291 0.013137 0.014309 0.007908 -2.52% -2.02% +4.84%
(1024, 128) sum 0.012769 0.014308 0.007924 0.012788 0.014236 0.008038 -0.15% +0.51% -1.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128) mean 0.014145 0.023049 0.009143 0.014104 0.023298 0.009501 +0.29% -1.07% -3.77%
(8192, 128) sum 0.014132 0.023082 0.009638 0.014107 0.023331 0.009244 +0.18% -1.07% +4.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130) mean 0.013420 0.025834 0.008949 0.013368 0.025724 0.008918 +0.39% +0.43% +0.35%
(1024, 130) sum 0.013300 0.025940 0.009113 0.013266 0.025419 0.008922 +0.26% +2.05% +2.14%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130) mean 0.013993 0.017883 0.009661 0.014275 0.018220 0.009596 -1.98% -1.85% +0.68%
(8192, 130) sum 0.014026 0.018297 0.010066 0.014326 0.018257 0.009659 -2.09% +0.22% +4.21%
================================================================================================================================================================================================================================================
```
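For context, numbers like these can be gathered with `torch.utils.benchmark`; a minimal sketch under stated assumptions (the PR does not show its actual harness, so the shape, reduction choices, and iteration count below are illustrative, and a CUDA device is assumed):
```python
import torch
from torch.utils.benchmark import Timer

x = torch.randn(4096, 4096, device="cuda")  # row-major: strides (4096, 1)
cases = [
    ("full reduce", "x.sum()"),
    ("non-contig dim", "x.sum(dim=0)"),  # reduces along the strided dim
    ("contig dim", "x.sum(dim=1)"),      # reduces along the contiguous dim
]
for label, stmt in cases:
    t = Timer(stmt=stmt, globals={"x": x})
    print(label, t.timeit(100))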
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165178
Approved by: https://github.com/ngimel
ghstack dependencies: #165494 , #164790 , #165055
2025-10-19 23:39:05 +00:00
6b80c94901
[FlexAttention] Fix dynamic shaped heads flex_flash check ( #165866 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #165729
2025-10-19 23:10:16 +00:00
8951df03de
test_scaled_matmul_cuda: fix infer_scale_swizzle ( #165788 )
...
Extend #165747 fix to other cases.
Add parentheses to clarify operator precedence.
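As a generic illustration of why the parentheses matter (this is not the exact expression from the test file):
```python
# `and` binds tighter than `or`, so an unparenthesized mix is easy to misread.
a, b, c = True, False, False
print(a or b and c)    # parsed as a or (b and c) -> True
print((a or b) and c)  # explicit grouping        -> False
```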
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165788
Approved by: https://github.com/jeffdaily , https://github.com/slayton58
2025-10-19 21:42:01 +00:00
8139f33fa5
[dynamo] Add recompile reason for set_stance fail_on_recompile ( #165445 )
...
Fixes #163500
### Summary:
With `set_stance("fail_on_recompile")`, failures now report the reason the recompilation occurred.
### Impacts:
module: dynamo
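A small sketch of the behavior being improved (the exact error text is an assumption; `torch.compiler.set_stance` and the `"fail_on_recompile"` stance are the real API):
```python
import torch

@torch.compile
def f(x):
    return x * 2

f(torch.randn(4))  # first compilation
torch.compiler.set_stance("fail_on_recompile")
# This call would trigger a recompile (the input dtype changed), so it now
# raises -- and with this change the error includes the recompile reason.
f(torch.randn(4, dtype=torch.float64))
```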
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165445
Approved by: https://github.com/williamwen42
2025-10-19 21:12:19 +00:00
a88587348b
[dynamo] Clean up assert in dynamo [1/N] ( #165430 )
...
Fixes parts of #162852 and #164878, which are related issues.
* __->__ #165430
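The cleanup follows the same before/after pattern as those issues; a representative sketch (illustrative, not a line from the diff):
```python
import torch

def check_tensor(x):
    # Before: a bare assert, silently stripped under `python -O`:
    #   assert isinstance(x, torch.Tensor), "expected a Tensor"
    # After: an explicit check that always runs and raises a precise error:
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"expected a Tensor, got {type(x).__name__}")
```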
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165430
Approved by: https://github.com/Lucaskabela , https://github.com/williamwen42
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com >
2025-10-19 21:00:05 +00:00
633a3b7f67
Revert "shrink_group implementation to expose ncclCommShrink API ( #164518 )"
...
This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381.
Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217 ))
2025-10-19 19:20:45 +00:00
fa0db212e7
shrink_group implementation to expose ncclCommShrink API ( #164518 )
...
Closes #164529
Exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink ) API to PyTorch.
This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.
For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator )
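A heavily hedged usage sketch; the Python-side name `shrink_group` and its argument are assumptions based on the PR title (and note the commit was reverted above):
```python
# Hypothetical sketch: drop an unhealthy rank from a NCCL communicator.
import torch
import torch.distributed as dist

# Assumes launch via torchrun, so rank/world-size env vars are set.
dist.init_process_group("nccl")
# Suppose rank 3 was lost; build a smaller communicator without it and
# continue collectives on the surviving ranks.
new_pg = dist.shrink_group(ranks_to_exclude=[3])  # assumed API
dist.all_reduce(torch.ones(1, device="cuda"), group=new_pg)
```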
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-19 18:00:08 +00:00
15ff1cd28b
Remove E721 suppression in flake8 ( #165855 )
...
Currently all files pass the E721 check.
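E721 is flake8's "do not compare types" check; a quick illustration of what it flags and the preferred forms:
```python
class A: ...
a, b = A(), A()

# Flagged by E721: type comparison with `==`.
#   if type(a) == type(b): ...

# Preferred: identity comparison for exact type checks...
if type(a) is type(b):
    print("same exact type")

# ...or isinstance() when subclasses should also match.
if isinstance(a, type(b)):
    print("instance check")
```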
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165855
Approved by: https://github.com/albanD
2025-10-19 17:51:12 +00:00
c73f5080de
Migrating some more callsites ( #163580 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163580
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #165582
2025-10-19 15:52:17 +00:00
22ae059d32
AOTI util deprecated flow using the new tracer ( #165582 )
...
Reapply of https://github.com/pytorch/pytorch/pull/163260
AOTI utils sometimes expect a free function, so the export API is adjusted to handle that; we haven't seen any methods being exported. Some AOTI flows also require populating `dynamo_flat_name_to_original_fqn`, so this copies how it is done in eval_frame.py. This also cleans up how `export_root` is removed and fixes some overcomplicated `nn_module_stack` handling in export code. The logic is simpler now thanks to @anijain2305.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165582
Approved by: https://github.com/anijain2305
2025-10-19 15:52:16 +00:00
1b121d636e
Fix AllocatorConfig parse roundup division bug ( #165304 )
...
* #165288
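The log gives no detail; as a minimal sketch of the round-up-division pattern such fixes usually involve (illustrative only, not the actual patch):
```python
def round_up_div(n: int, d: int) -> int:
    # Plain floor division drops the remainder, under-counting by one
    # whenever n is not an exact multiple of d; adding d - 1 first
    # rounds the result up instead.
    return (n + d - 1) // d

assert round_up_div(10, 4) == 3   # floor division would give 2
assert round_up_div(8, 4) == 2    # exact multiples are unchanged
```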
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165304
Approved by: https://github.com/albanD
ghstack dependencies: #165288 , #165289 , #165291 , #165298
2025-10-19 15:34:44 +00:00
1ba808dd97
Refine CUDA BackendStaticInitializer for allocator selection ( #165298 )
...
* #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165298
Approved by: https://github.com/albanD
ghstack dependencies: #165288 , #165289 , #165291
2025-10-19 15:34:44 +00:00
b2f5c25b27
Introduce a generic API torch._C._accelerator_setAllocatorSettings ( #165291 )
...
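No description in the log; a heavily hedged sketch, assuming the new API accepts the same `key:value` settings string used by `PYTORCH_CUDA_ALLOC_CONF` (both the string format and the key below are assumptions):
```python
import torch

# Assumed usage: apply allocator settings at runtime for the current
# accelerator backend, mirroring the PYTORCH_CUDA_ALLOC_CONF format.
torch._C._accelerator_setAllocatorSettings("max_split_size_mb:256")  # assumed signature
```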
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165291
Approved by: https://github.com/albanD
ghstack dependencies: #165288 , #165289
2025-10-19 15:34:36 +00:00
a1114beed2
Deprecate overlapped functions in CUDAAllocatorConfig ( #165289 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165289
Approved by: https://github.com/albanD
ghstack dependencies: #165288
2025-10-19 15:34:26 +00:00
4888ed440e
Refine AllocatorConfig error messages to be friendlier ( #165288 )
...
* __->__ #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165288
Approved by: https://github.com/albanD
2025-10-19 15:34:17 +00:00
5d62b63a76
[BE] Use Python-3.14 GE build ( #165804 )
...
Python 3.14 reached general availability on Oct 7th, 2025, so we can remove all pre-release workarounds.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165804
Approved by: https://github.com/yangw-dev , https://github.com/Skylion007 , https://github.com/cyyever
2025-10-19 11:45:10 +00:00