Summary: As above; also updates a number of the build files to improve them
Test Plan:
internal and external CI
Ran `buck2 build fbcode//caffe2:torch` and it succeeded
Rollback Plan:
Reviewed By: swolchok
Differential Revision: D78016591
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158035
Approved by: https://github.com/swolchok
The general context for the upcoming stack of commits is that I am attempting
to "pipeline" AOTAutograd. Instead of having function f call function g, the
next "stage" of compilation, f should return its outputs, which are then piped
to g for the next stage. This will make it easier to implement an early exit /
resume pipeline without forcing a callback structure, which is good for
export-style use cases. It also reduces the size of our stack traces, which
makes tools like Perfetto happy.
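As a toy sketch of the control-flow change being described (made-up stage functions, not the actual AOTAutograd code):
```py
# Before: each stage calls the next, so the whole pipeline sits on one call stack.
def f_nested(x):
    y = x + 1            # stand-in for f's stage of compilation
    return g_nested(y)   # f stays on the stack until g finishes

def g_nested(y):
    return y * 2         # stand-in for the next stage

# After: each stage just returns its outputs and a small driver pipes them to
# the next stage, so we can exit early or resume without callbacks, and stack
# traces stay shallow.
def f_piped(x):
    return x + 1

def g_piped(y):
    return y * 2

def run_pipeline(x, stages=(f_piped, g_piped)):
    for stage in stages:
        x = stage(x)     # the driver could stop here and resume later
    return x

assert f_nested(3) == run_pipeline(3) == 8
```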
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158149
Approved by: https://github.com/jamesjwu
This PR disables the `strict-aliasing` GCC C++ optimization flag on all AArch64 CPUs for GCC versions 12 and above.
Pull Request #152825 upgraded the GCC version from 11 to 13 in manywheel, which caused several segmentation faults in unit tests (not visible in CI workflows because the jammy GCC version has not been updated yet).
We identified that the problem also exists in GCC 12, hence the `__GNUC__ >= 12` check.
Fixes #157626
Fixes these test failures when PyTorch is built with GCC 12 and above:
```
test_ops.py::TestCommonCPU::test_noncontiguous_samples_grid_sampler_2d_cpu_float32 Fatal Python error: Segmentation fault
test_ops.py::TestCommonCPU::test_dtypes_grid_sampler_2d_cpu Fatal Python error: Segmentation fault
test_ops.py::TestMathBitsCPU::test_neg_view_nn_functional_grid_sample_cpu_float64 free(): invalid next size (fast)
test_ops.py::TestCompositeComplianceCPU::test_backward_grid_sampler_2d_cpu_float32 Fatal Python error: Segmentation fault
test_ops.py::TestCommonCPU::test_dtypes_nn_functional_grid_sample_cpu Fatal Python error: Segmentation fault
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158117
Approved by: https://github.com/malfet
Fixes#124435
This updates the torch.histogramdd documentation to correctly state that bins are inclusive of their left edges, not exclusive as currently written. A previous PR addressed this but was closed due to inactivity; this picks that work up and applies the fix.
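For example (a small illustration of the documented semantics; values chosen for clarity):
```py
import torch

# Each bin includes its left edge; the rightmost bin also includes its right edge.
samples = torch.tensor([[0.0], [1.0], [2.0]])  # N=3 points in D=1
hist, edges = torch.histogramdd(samples, bins=[2], range=[0.0, 2.0])
print(edges[0])  # tensor([0., 1., 2.])
print(hist)      # tensor([1., 2.]): 0.0 falls in [0, 1); 1.0 and 2.0 fall in [1, 2]
```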
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158275
Approved by: https://github.com/albanD
Summary: Add a TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL flag that force-inlines the kernel function when set to 1. It's disabled by default because force inlining may increase build time.
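For example, one way to opt in (the variable can equally be exported in the shell before running):
```py
import os

# Disabled by default because force-inlining may increase build time; set it
# before Inductor compiles anything (e.g. before running torch.compile'd code).
os.environ["TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL"] = "1"
```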
Differential Revision: D77915987
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157949
Approved by: https://github.com/desertfire
# Motivation
Refactor `CUDAAllocatorConfig` to reuse `AcceleratorAllocatorConfig` and `ConfigTokenizer`. We will deprecate the options that overlap with `AcceleratorAllocatorConfig` in a follow-up PR and keep them only for BC.
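For context, a typical option string these config classes tokenize and apply looks like the following (standard, documented allocator options; not new in this PR):
```py
import os

# Comma-separated key:value options for the CUDA caching allocator; must be set
# before the first CUDA allocation to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
```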
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150312
Approved by: https://github.com/albanD
----
# Refactor and Improve the OpenReg Module
## Background
Since PrivateUse1 has become the main path for integrating new devices with PyTorch, there have been some feature requests related to PrivateUse1 regarding interfaces, documentation, reference examples, etc., such as the following:
- https://github.com/pytorch/pytorch/issues/155864
- https://github.com/pytorch/pytorch/issues/144955
- https://github.com/pytorch/pytorch/issues/144845
Taking these requests into consideration, and given OpenReg's current role as the test backend for PrivateUse1, I'm planning to make the following optimizations:
- Optimize the implementation of OpenReg to make it align with the standard specifications for real backend (C++) access, serving as a reference for new device integration code.
- Add comprehensive documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html) to guide new accelerator integration, functioning as a reference manual.
## Design Principles:
- Minimization Principle: Keep the code small and clear; only implement the minimum set of code required for verification and as an integration reference.
- Authenticity Principle: Integrate OpenReg in the same way that real accelerators access PyTorch (see the sketch below).
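As a minimal sketch of the Python-side entry point a PrivateUse1 backend typically uses (the "openreg" name is only illustrative here; the real device guard, allocator, and kernels are registered from C++ as the README below describes):
```py
import torch

# Rename the PrivateUse1 dispatch key to the backend's device name and generate
# the matching Tensor/Module convenience methods.
torch.utils.rename_privateuse1_backend("openreg")
torch.utils.generate_methods_for_privateuse1_backend()

# Once the C++ side registers its hooks, users can target the device directly:
#   x = torch.empty(2, 2, device="openreg")
```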
## More Infos:
Please refer to [this](6b8020f1ab/test/cpp_extensions/open_registration_extension/torch_openreg/README.md) for more information about `OpenReg`.
## Current Progress:
- Refer to the implementation of [torch_xla](https://github.com/pytorch/xla) to refactor all of OpenReg's code, making it easier to understand.
- Ensure all tests in [test/test_openreg.py](https://github.com/FFFrog/pytorch/blob/openreg/test/test_openreg.py) pass after refactoring.
## Next Steps:
- Add more features to cover all integration points.
- Gradually add user guides and documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158090
Approved by: https://github.com/seemethere, https://github.com/albanD
The `test_triton_wait_until` test was hanging due to an NCCL synchronization issue stemming from mismatched NVSHMEM operations. Specifically, the flag variable was updated using `nvshmemx_signal_op` (a signaling operation), but waited on with `nvshmem_wait_until` (intended for put/get updates). Per NVSHMEM documentation (see documentation reference section below), signal-updated variables require `nvshmem_signal_wait_until` for proper completion guarantees, so the mismatch caused a deadlock and NCCL hang.
**Fix:**
- A simple fix was to replace the flag update with a regular `nvshmem_putmem_block` (via `put_kernel`) to match `nvshmem_wait_until`. I also added a fence (`nvshmem_fence`) between data and flag puts on the sender (Rank 1) for ordered delivery.
- In a follow-up PR I will add a kernel/test to demonstrate usage of `nvshmemx_signal_op`
**Testing:**
- I ran `python test/distributed/test_nvshmem_triton.py` and `python test/distributed/test_nvshmem_triton.py -k test_triton_wait_until`
- I also verified with debug prints (Sender completes puts/fence before receiver's wait returns, and assertions confirm correct state). Multiple runs show no hangs or failures.
**Documentation Referenced:**
- [NVSHMEM Point-To-Point Synchronization](https://docs.nvidia.com/nvshmem/api/gen/api/sync.html) explicitly states: *"the sig_addr object at the calling PE is expected only to be updated as a signal, through the signaling operations available in Section NVSHMEM_PUT_SIGNAL and Section NVSHMEM_PUT_SIGNAL_NBI"*
- [NVIDIA's Official Ring Broadcast Example](https://docs.nvidia.com/nvshmem/api/examples.html) demonstrates the correct pairing: `nvshmemx_signal_op` with `nvshmem_signal_wait_until` (not `nvshmem_wait_until`)
- [NVSHMEM Signaling Operations](https://docs.nvidia.com/nvshmem/api/gen/api/signal.html) documents that signal operations work on special "signal data objects" with specific atomicity guarantees distinct from regular RMA operations
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158167
Approved by: https://github.com/Skylion007, https://github.com/fduwjj
Beginning of process for 3.14 bringup.
State of things from this PR:
- Nothing too scary looking from the Dynamo CPython side, nothing we heavily rely on seems to be missing @williamwen42
- The existing check that makes torch.compile() nicely fail is working as expected. So all these empty functions shouldn't cause any weirdness.
- The `__module__` update changes look suspicious, we should investigate what is the reason and impact of that, in particular for our public API checking @jbschlosser
- Leaving the weakref.py thread safety change as a follow up to keep this a bit simpler. I vendored the whole struct in the meantime FYI @ezyang
EDIT: The `__module__` change is even more cursed than I thought due to changes to the Union and Optional types, where the `__module__` field cannot be changed anymore. See https://github.com/python/cpython/issues/132139 for details.
For now, I'm just skipping the `__module__` setting for 3.14 which will trip the public API checks. Will revisit once I have a final answer on the cpython issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158184
Approved by: https://github.com/msaroufim
**Summary**
`split_strategy` used `TupleStrategy` as its return type because DTensor sharding
propagation's `OpStrategy` support for multiple returns only applies to `Tuple`.
However, `TupleStrategy` is not a good fit for the `split` op: it was initially
introduced to handle the sharding strategy of `foreach_*` ops, where the input
args can be split into independent subsets regarding sharding decisions, as can
the outputs.
To address the misuse, this PR adds `OpStrategy` propagation for `List[Tensor]`
(note that this support is INCOMPLETE because it only checks that the return type
is `torch.ListType`). Nevertheless, the logic for `Tuple` returns makes a similar
assumption, so I think it's fine to unblock things this way.
Besides adding `OpStrategy` support to ops having `List[Tensor]` return type,
this PR also changes `split_strategy`'s return from `TupleStrategy` to `OpStrategy`.
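As a usage sketch of the op this affects (single-rank CPU/gloo setup purely for illustration; the real test below runs across ranks and exercises the Partial case):
```py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))
dt = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])

# split returns List[Tensor]; its sharding now propagates via a single OpStrategy.
chunks = torch.split(dt, 2, dim=1)
print([type(c).__name__ for c in chunks])  # ['DTensor', 'DTensor']

dist.destroy_process_group()
```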
**Test**
`pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158051
Approved by: https://github.com/wconstab, https://github.com/zpcore
The local_tensor input to grouped_mm has a stride requirement (see
`_meta_grouped_mm_common` in meta_registrations.py or
`check_valid_strides_and_return_transposed` in native/cuda/Blas.cpp).
Don't allow sharding a tensor if its shape would result in an
incompatible local_tensor stride.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158245
Approved by: https://github.com/zpcore, https://github.com/XilunWu
This PR allows for symints in `gen_slice_strategy` which is the strategy for `aten.slice.Tensor`. Previously, using dynamic shapes with slicing would result in
```
File ".../pytorch/torch/distributed/tensor/_ops/_tensor_ops.py", line 348, in gen_slice_strategy
assert isinstance(end, int)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function getitem>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(s3, 2)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), slice(None, (s77//2), None)), **{}): got AssertionError()
```
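For reference, the failing pattern boils down to slicing a DTensor by a length that becomes a SymInt under dynamic shapes, roughly (sketch only; `dt` stands for a sharded DTensor as in the trace above):
```py
import torch

@torch.compile(dynamic=True)
def take_first_half(dt):
    n = dt.shape[0] // 2  # a SymInt once dynamic shapes kick in
    return dt[:n]         # lowers to aten.slice.Tensor with a SymInt end
```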
Questions before merge:
1. `dim` is still asserted to be int. Is this fine, or is this potentially dynamic as well?
2. I'm using an argtype ignore for `normalize_dim`. Should I instead change the types for `normalize_dim` and its downstream dependencies to be `IntLike` as well?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157953
Approved by: https://github.com/wconstab
When loading a package and calling package.install(backends), we create a new frame and compile id for each package load, so that tlparse and chromium events still show compile times on warm start.
There is an argument for not doing this in AOT precompile, as no "compile" occurs. So for now, we put it in `package.install`, which hopefully won't be a thing for AOT precompile.
## Recompiles
Recompiles get saved to the same frame and code entry, so on warm start, each recompile will get collapsed into the same entry. Therefore, dynamo compiles that have recompiles on cold start (0/0, 0/1, 0/2, etc) will all get collapsed into a single compile id (0/0), as warm start will load all of the entries properly.
## Graph breaks
Graph breaks get their own compile id, and therefore their own code entry. These are replicated on warm start, so if cold start you had 4 different graphs (and therefore 4 compile ids), you'll have 4 compile ids on warm start as well.
## Test plan
Added a frame counter check to existing unit tests for automatic dynamic, showing that the frame counter is the same between the old and new loads.
This is the chromium event for test_automatic_dynamo_graph_breaks_device_cuda:
```
python test/dynamo/test_package.py -k test_automatic_dynamo_graph_breaks_device_cuda
```
![Chromium event for test_automatic_dynamo_graph_breaks_device_cuda](https://github.com/user-attachments/assets/f604ed33-5c31-464b-9320-d67b2e6f57a1)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158028
Approved by: https://github.com/oulgen
This is intended to make it easier for users to provide backend-specific "hints" about certain options.
```py
import torch.distributed._dist2 as dist2
pg = dist2.new_group(backend="my_custom_backend", device=..., timeout=..., foo=1234, bar="1234")
pg.allreduce(...)
```
Test plan:
```
pytest test/distributed/test_dist2.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158147
Approved by: https://github.com/fduwjj