Summary:
Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more generate Inductor templates in the future.
Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`
Reviewed By: drisspg
Differential Revision: D84527470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165576
Approved by: https://github.com/jananisriram
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup.
The changes are from Yu-Yun <YuYun.Chang@amd.com>.
# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` opertion on MI300X/MI325X
- The original "scatter-add" approach
- Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach
- Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function).
# Breakdown of the code changes
- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
- Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
- Each thread processed one output pixel.
- The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
- Each thread is responsible for calculating the final gradient for a single input pixel.
- The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function for calculating the range of output pixels that could have possibly used that the input pixel (`input_pos`) during the forward pass interpolation
- This is essentially the mathematical inverse of the forward pass.
- This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation approach switching from "scatter-add" to "gather-sum"
- Scatter-add
- For each output pixel, the thread calculated 4 gradient contributions and use `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
- Gather-sum
- A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value.
- The thread iterates through this region, and for each output pixel in the regionre, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
- All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
- W/o any gloabl memory access, this accumulation is extremely fast.
- When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
# Why performance gets boosted
- Analysis of the root cause of performance drop
- Ref. (internal only) - https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of the contention of atomic operations
- Many parallel threads called `atomicAdd` frequently attempting to update the exact same memory location in the input gradient tensor at the same time.
- The GPU's memory controler has to serialize these operations, effectively nullifying the benefit of parallel capability at those contention points.
- MI300X/MI325X chiplet-based CDNA 3 architeture amplified the issue.
- When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
- The implementation change eliminates hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd`.
- Improved memory access pattern and locality
- Write coalescing
- The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by GPUs.
- Read locality
- Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
- The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X.
- Removal of atomic operations avoids expensive memory synchronization.
---
Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR.
Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572
Approved by: https://github.com/jeffdaily
Bucketing a number of smallish improvements:
- Account for bucketing in overlap calculation: if an in-flight collective exists with the same bucket key, reduce new collectives estimated time by its latency time
- Update compute domination so we are ordering based on compute idx, as opposed to compute depth, so we never reorder compute. this makes it a bit easier to reason about memory, and pre-fetching, although we can exploring reordering in the future.
- When we wait on a collective, force all collectives on the same process group as it that were enqueued prior to the collective to wait as well.
Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.
TODO:
- for each compute node, we know its original memory in the graph. we could limit pre-fetching that goes across peak memory
- By scheduling off-path collectives for overlap, we reduce memory, but if there weren't enough compute for overlap, we need to proactively schedule them. not an issue yet on examples.
- config some hard coded constants, clean up enablement (can do in subsequent pr)
On small llama 2d backward :
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem, 15.9GB
on forward:
254/256 potentially hideable collectives hidden
original mem 5.8 gb, reshceduled mem 5.8GB
WIP: adding tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
For `test_graph_partition_with_memory_plan_reuse`, before this PR, when using graph partition, it would error ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):
```
def partition_0(args):
...
del buf0
return (buf3, buf4, buf5, buf2, primals_4, )
...
File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
return (buf3, buf4, buf5, buf2, primals_4, )
^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```
When not using graph partition, it would work and give the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):
```
def call(self, args):
...
buf2 = buf0; del buf0 # reuse
...
```
Note that the issue is buf0 is not reused for buf2 when using graph partition.
Why? Because the codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pops tailing `MemoryPlanningLine` unless it is in graph output by checking `V.graph.get_output_names()`. However, for graph partition, we should check the output of the current partition instead of the graph before partition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
https://github.com/pytorch/pytorch/pull/164820 introduced a bug that `_StridedShard` will call parent class `Shard`'s `split_tensor` method, thus results in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this)
Meanwhile, I notice another bug that when we normalize a `_StridedShard`'s placement, it will also trigger parent class `Shard`'s `split_tensor` method because it will create a Shard class [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never test `distribute_tensor` for `_StridedShard` before. So I added a test here to compare against ordered shard.
Using classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
**Summary**
Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.
This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.
**Benchmarking**
To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.
Settings:
- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs
| | Variable Length API | SDPA |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms | 0.43171775817871094 ms |
| TFLOPs | 231.812 | 320.840 |
The sparsity is 0.453 which we can see matches the speedup we get from Varlen (approx 50%). TFLOPs remains around the same, with SDPA slightly larger due to potential higher overhead and total flops scaling with sequence length.
**Testing**
Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.
**Next steps**
Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.
(This stack builds on top of #162326)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert
This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665
Disclaimer that I used a lot of AI since I don't how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now)
The first one type is
Truncated:
```
default_pg, _ = _new_process_group_helper(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper
backend_class = creator_fn(dist_backend_opts, backend_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg
return FakeProcessGroup._create_internal(
RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero.
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0
#8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0
#9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0
#10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0
```
and I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the ref count handling instead. However, I also had to move the constructor to be public, which I think is not good, based on the reasoning of the original PR
The other one type is
```
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import
self.assertEqual(out, "")
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual
raise error_metas.pop()[0].to_error( # type: ignore[index]
AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != ''
- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode.
- if is_available() and not torch._C._c10d_init():
To execute this test, run the following from the base repo dir:
python test/test_testing.py TestImports.test_no_warning_on_import
```
which I fix by getting rid of the `__init__` which I think is ok since it'll just error if you try to make one?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479
Approved by: https://github.com/ezyang
Summary:
* Add `torch._scaled_grouped_mm_v2` with more functionality and
extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality
Test Plan:
```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```
Reviewers:
Subscribers:
Tasks:
Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
Summary: Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more generate Inductor templates in the future.
Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:flex_flash -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`
Differential Revision: D84527470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165347
Approved by: https://github.com/drisspg
The `_flatten_mapping` field was defined as a class attribute with a mutable default value {}:
```
_flatten_mapping: dict[str, "DeviceMesh"] = {}
```
This caused all DeviceMesh instances to share the same dictionary object. When multiple test instances tried to create flattened meshes with the same name (like "dp"), they would conflict because they were all using the same shared dictionary, resulting in the error: "Flatten mesh with mesh_dim_name dp has been created before, Please specify another valid mesh_dim_name."
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165521
Approved by: https://github.com/fegin, https://github.com/lw
Fixes#165110
The `PUBLIC` scope causes CUTLASS of the FBGEMM being included in for all PyTorch targets, including special matmuls (RowwiseScaledMM, ScaledGroupMM and GroupMM). Due to version mismatch between FBGEMM/CUTLASS and PyTorch/CUTLASS it is unacceptable to use FBGEMM/CUTLASS in PyTorch targets. This PR limits the scope of FBGEMM/CUTLASS to `fbgemm_genai` target only.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165424
Approved by: https://github.com/cthi, https://github.com/eqy, https://github.com/danielvegamyhre
https://github.com/pytorch/pytorch/pull/164790 modifies aten to perform a different reduction order intra warp. However, this change exposed a large difference in a sum for complex32. Namely the case:
```
import torch
a = torch.tensor([[ 4.82031250+7.34765625j,
-3.37109375-1.9501953125j],
[ 3.7832031250-2.43359375j,
-6.07812500+5.32812500j]], dtype=torch.complex32, device='cuda:0')
sum_out = torch.sum(a)
nansum_out = torch.nansum(a)
torch.testing.assert_close(
sum_out,
nansum_out,
rtol=0,
atol=0,
)
```
Here, the result of `sum` and `nansum` differed significantly by 1e-2. Further investigation showed that the explicit casting of b back to `arg_t` from `scalar_t` was the root cause. `arg_t` is the dtype of the accumulator, ComplexFloat, and `scalar_t` of the input dtype, ComplexHalf. When we cast in the reduction to the accumulator order, that means the input is still of ComplexHalf, which loses precision as it can store intermediate values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165494
Approved by: https://github.com/ngimel
Fixes#161943
## The Fix
I implemented a recursive unwrapping helper function in the `tensor_to_list.cpp` file that looks for wrapped tensors and unwraps them. The recursive implementation was needed for multi-level gradTrackingTensors.
Let me know if there is any more suggestions on fixing this issue!
@guilhermeleobas @KimbingNg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165184
Approved by: https://github.com/zou3519
reproducer
```
import torch
# does not crash
a = torch.rand((0), device="cpu")
b = torch.rand((0), device="cpu")
a.dot(b)
# crashes due to internal assert
a = torch.rand((0), device="mps")
b = torch.rand((0), device="mps")
a.dot(b)
```
Discovered when implementing an op for SparseMPS backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165237
Approved by: https://github.com/malfet
Fixes#160752
# Background:
`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With torch.compile(dynamic=True), FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap rule for one_hot decomposed into “zeros_symint + scatter,” which interacted poorly with the transform stack and dynamic shapes, leading to failures mid-trace. Using a functional equality construction makes one_hot composable with vmap/JVP and friendly to dynamic shape tracing.
# Changes:
- functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
- Replace “zeros + scatter” with eq(self.unsqueeze(-1), arange(num_classes)).to(kLong) under FuncTorchBatched.
- one_hot native path remains unchanged for regular eager; vmap transform no longer relies on scatter, which was fragile under dynamic shape tracing.
The minimal repro from the issue is now fixed:
```python
import torch
import torch.nn.functional as F
MAX, BATCH = 3, 37
def func(x, idxs):
return x.square() * F.one_hot(idxs, MAX)
def jacfunc(x, idxs):
return torch.func.jacfwd(func, argnums=0)(x, idxs)
idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)
# eager
out_eager = jacfunc(x, idxs)
# compiled dynamic
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)
torch.testing.assert_close(out_eager, out_comp)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
Summary:
The `nDims` variable is mutated inside the loop but never restored to its original value.
This affects subsequent iterations of the outer loop.
Each batch iteration may get incorrect `nDims` after the first batch.
Test Plan: CI
Reviewed By: ngimel
Differential Revision: D84612194
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165446
Approved by: https://github.com/ngimel
Summary: If a function is wrapped with functools, we should not look at the wrapped function signature but rather the wrapper, since we need to construct the frame for the top level function here.
Test Plan: test_decorated_function_with_functools_wrap_aot
Differential Revision: D84626752
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165454
Approved by: https://github.com/yiming0416