Compare commits

...

34 Commits

Author SHA1 Message Date
89fb2567e7 Add annotation to assertion nodes 2025-11-05 17:53:13 -08:00
fd5edda1ed Reland "Add model code stack trace to torch.profile (#166677)" (#167110)
```bash
python test/test_fx.py -k profiler
```

Inserts `torch._C._profiler._RecordFunctionFast` markers into the FX graph codegen.

We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add stack traces to the dumped trace.

`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hashed fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key into the registry.

One can call `fx_g.enrich_profiler_metadata()` to add the debugging info, or `fx_g.enrich_profiler_metadata(enable=False)` to remove it.

`aot_eager` calls `fx_g.enrich_profiler_metadata()` if `TORCH_ENRICH_RPOFILER_STACK_TRACE` is set or `_dynamo.config.enrich_profiler_metadata=True`.
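A minimal usage sketch, assuming the `enrich_profiler_metadata()` API on `fx.GraphModule` described above (the exact post-processing entry point may differ):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def f(x, y):
    return (x * y).sin() + 5

# Trace to an fx.GraphModule and opt in to the profiler annotations above.
gm = torch.fx.symbolic_trace(f)
gm.enrich_profiler_metadata()  # inserts the _RecordFunctionFast markers into codegen

with profile(activities=[ProfilerActivity.CPU]) as prof:
    gm(torch.randn(8), torch.randn(8))

prof.export_chrome_trace("trace.json")
# The dumped trace can then be post-processed with
# map_recorded_events_to_aten_ops_with_stack_trace (see above) to attach
# model-code stack traces to the recorded aten ops.
```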

<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a" />

Example of the generated code:
```
def forward(self, args_list):
    args_iter = iter(args_list)
    arg0_1 = next(args_iter)
    arg1_1 = next(args_iter)
    args_list.clear()
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
    repeated_subgraph0 = self.repeated_subgraph0
    _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1);  repeated_subgraph0 = arg0_1 = arg1_1 = None
    _rf_invoke_subgraph.__exit__(None, None, None)
    _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
    getitem = invoke_subgraph[0];  invoke_subgraph = None
    _rf_getitem.__exit__(None, None, None)
    return (getitem,)
    _rf.__exit__(None, None, None)

def forward(self, arg0_1, arg1_1):
    _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
    _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    _rf_mul.__exit__(None, None, None)
    _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
    sin = torch.ops.aten.sin.default(mul);  mul = None
    _rf_sin.__exit__(None, None, None)
    _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
    _rf_add.__exit__(None, None, None)
    return (add,)
    _rf.__exit__(None, None, None)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167110
Approved by: https://github.com/pianpwk
2025-11-06 01:14:27 +00:00
872d1daec2 Avoid DDE in narrow with unbacked start (#166361)
Slice already knows how to handle an unbacked start, so we do not need to offset start before calling slice; we can leave that to slice.
The only edge case is when start < 0 and start + length == 0: in that case slice and narrow would deviate,
so for that case we pass dim_size instead of start + length.
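A tiny eager-mode illustration of that edge case:

```python
import torch

x = torch.arange(6)
# narrow wraps a negative start: start = -2 + 6 = 4, so it returns elements 4..5
print(x.narrow(0, -2, 2))  # tensor([4, 5])
# the naive slice x[start : start + length] = x[-2:0] is empty, which is why
# start + length == 0 must pass dim_size instead of 0 as the slice end
print(x[-2:0])             # tensor([], dtype=torch.int64)
```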

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-06 01:04:19 +00:00
eqy
6cd57e6fc2 [cuBLAS] Force tensor-core-no-reduction algo in cuBLASLt for n=1 cases (#166735)
Ostensibly useful for batch-invariance purposes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166735
Approved by: https://github.com/ngimel
2025-11-06 00:50:42 +00:00
d29efba8fa Move almalinux docker image to DEVTOOLSET 13 (#167018)
1. Update the general Almalinux image to devtoolset-13.
2. Fix ROCm images, which were missing devtoolset-13.
This image is used by the Linux Job in test-infra.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167018
Approved by: https://github.com/sudharssun, https://github.com/d4l3k
2025-11-06 00:34:40 +00:00
a344069f2a Add missing skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) to test/test_transformers.py (#166969)
This PR adds missing skips for efficient attention tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166969
Approved by: https://github.com/jeffdaily
2025-11-05 23:16:51 +00:00
af829c0dad [ROCm] Skip nvfp4 tests on ROCm (#167066)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167066
Approved by: https://github.com/jeffdaily, https://github.com/slayton58
2025-11-05 23:15:17 +00:00
3869aa115b fix fr reset api (#166970)
Summary:
- There are various places that access FR's `entries_` field.
- If we empty `entries_` on reset, those accesses can result in an error.
- So we only perform a soft delete instead of clearing out the entries completely:
  - only reset `id_` on the reset
  - keep track of a `reset_epoch` which increments every time reset is called
  - `dump_entries` only returns entries from the latest epoch
  - APIs that access entries also check whether the reset epoch matches
- Make `next_` always track the index in the circular buffer; this change was needed to make the soft delete's implementation easier. (A toy sketch of this scheme follows below.)
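A toy sketch of the soft-reset scheme (Python pseudocode for illustration, not the actual C++ flight-recorder code):

```python
class ToyFlightRecorder:
    def __init__(self, capacity):
        self.capacity = capacity
        self.entries_ = [None] * capacity  # never cleared on reset (soft delete)
        self.next_ = 0                     # always the raw index into the circular buffer
        self.id_ = 0
        self.reset_epoch_ = 0

    def record(self, payload):
        self.entries_[self.next_ % self.capacity] = (self.reset_epoch_, self.id_, payload)
        self.next_ += 1
        self.id_ += 1

    def reset(self):
        # soft delete: keep entries_, only bump the epoch and reset id_
        self.reset_epoch_ += 1
        self.id_ = 0

    def dump_entries(self):
        # only return entries recorded after the most recent reset
        out = []
        for entry in self.entries_:
            if entry is not None and entry[0] == self.reset_epoch_:
                out.append(entry[2])
        return out
```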

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/166970).
* #166972
* #166971
* __->__ #166970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166970
Approved by: https://github.com/fduwjj
2025-11-05 23:06:00 +00:00
47eb34b7ac [ATEN][CUDA] Reduce register pressure in radix_sort_pairs to improve torch.sort performance (#167094)
# Summary
This PR improves `torch.sort` and `torch.unique` performance by **15% to 50%** on NVIDIA GPUs by optimizing CUDA register allocation in radix sort operations.

The key change: specialize `OpaqueType<N>` to use native integer types (uint8_t, uint16_t, uint32_t, uint64_t) for common sizes (1, 2, 4, 8 bytes) instead of `char data[N]`. This enables more efficient register allocation while preserving the template deduplication strategy.

The following table shows the speedup on various input shapes and GPUs. Sorting is performed on the last dimension, and the baseline torch version is 2.9.0.

| GPU  | input shape | input dtype | Before (ms) | After (ms) | Speedup |
| ---- | ----------- | ----------- | ------------------- | ---------- | ------- |
| H100 | (16, 1e6)   | int32       | 1.61                | 1.37       | 1.18×   |
| H100 | (1, 1e8)    | int32       | 6.6                 | 5.0        | 1.3×    |
| H20  | (16, 1e6)   | int64       | 3.57                | 3.03       | 1.18×   |
| H20  | (1, 1e8)    | int64       | 19.3                | 13.0       | 1.48×   |
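A rough benchmarking sketch in the spirit of the table above (this is not the exact harness used to collect those numbers):

```python
import torch

def time_sort(shape, dtype, iters=20):
    # sort along the last dimension, as in the table above
    x = torch.randint(0, 2**31 - 1, shape, dtype=dtype, device="cuda")
    for _ in range(3):          # warmup
        torch.sort(x, dim=-1)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.sort(x, dim=-1)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

print(time_sort((16, 1_000_000), torch.int32))
print(time_sort((1, 100_000_000), torch.int32))
```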

# Analysis

`torch.sort` and `torch.unique` use `radix_sort_pairs`, which internally calls `cub::DeviceRadixSort::SortPairs`. Since values are only copied (never compared), we cast them to `OpaqueType<sizeof(value_t)>` to minimize template instantiations. For example, both `int32` and `float32` values map to the same `OpaqueType<4>`.

## The Problem

The previous `char data[N]` implementation causes inefficient register allocation. Here is one reason I found in the SASS code. For 8-byte types:

- `char data[8]`: the compiler may allocate 8 registers (one per byte)

- `uint64_t data`: the compiler allocates 2 registers (standard 64-bit handling)

This happens because the compiler doesn't recognize char[8] as a cohesive 64-bit value, treating each byte independently, which increases register pressure and reduces GPU occupancy.

From Nsight Compute, when using `char data[8]` the kernel uses 166 registers per thread, with a corresponding theoretical occupancy of 18.75%. When using native `uint64_t`, it uses 80 registers per thread, with a corresponding theoretical occupancy of 37.5%.

## The Solution

Specialize `OpaqueType<N>` for common sizes using native integer types:

```
// Before
template <int N> struct alignas(N) OpaqueType { char data[N]; };

// After
template <int N> struct alignas(N) OpaqueType { char data[N]; }; // fallback
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
```

This preserves the template deduplication strategy (all 8-byte types still use the same `OpaqueType<8>` instantiation) while enabling better register allocation.

# Testing & Compatibility
## Testing:
- Correctness tests pass for various input types (bfloat16, int32, float32, int64), shapes, and dimensions (1, 2, 3)
- Register usage reduction verified with Nsight Compute
- Linter passes
## Compatibility:
- No API/ABI changes
- Template instantiation count unchanged

# Reference
For detailed analysis, please refer to my previous blog post: [Performance Optimization of torch.sort on GPU](https://yywangcs.notion.site/Performance-Optimization-of-torch-sort-on-GPU-192fc9f5d8058018a1bec1efa35da3f9)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167094
Approved by: https://github.com/ngimel, https://github.com/Skylion007
2025-11-05 22:34:19 +00:00
08200280ce [CP][BE][3/N] Add _templated_ring_attention to the backward compatility stub (#166991)
While `_templated_ring_attention` is a private API, it is unfortunately used by some packages.
Add it to __all__ so that people can still use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991
Approved by: https://github.com/XilunWu
ghstack dependencies: #166456, #166501
2025-11-05 22:22:55 +00:00
ad7a57262c [12/N] Apply ruff UP035 rule (#166929)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 22:06:19 +00:00
711a775878 fix nccl estimations (#167093)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093
Approved by: https://github.com/kwen2501, https://github.com/eellison
2025-11-05 22:01:49 +00:00
e9a688f02e [DebugMode] output, tensor id annotations for DebugMode (#165076)
Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)`

Example output for `test_debug_mode_mm`, with both enabled:
```
  torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))  ->  dt$12: f32[8, 32]| S(0)
    aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
      redistribute_input(1, S(0) -> R)
        redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
          _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0)  ->  t$6: f32[8, 32]
          _c10d_functional::wait_tensor(t$7: f32[8, 32])  ->  t$8: f32[8, 32]
      aten::mm(t$9: f32[1, 8], t$10: f32[8, 32])  ->  t$11: f32[1, 32]
  <method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0))  ->  dt$17: f32[]| P
    aten::sum(dt$14: f32[8, 32]| S(0))
      aten::sum(t$15: f32[1, 32])  ->  t$16: f32[]
```

Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.
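A minimal sketch of how the new flags might be exercised (assuming `debug_string()` as the accessor for the recorded log):

```python
import torch
from torch.utils._debug_mode import DebugMode

x = torch.randn(8, 8)
y = torch.randn(8, 32)

# record_torchfunction is needed (per the note above) to see op outputs;
# record_output / record_ids are the new annotations added in this PR.
with DebugMode(record_torchfunction=True, record_output=True, record_ids=True) as dm:
    (x @ y).sum()

print(dm.debug_string())  # assumed accessor for the recorded call log
```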

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-05 22:00:11 +00:00
e69aaaf45a [user-streams] Add backward test (#167021)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167019
2025-11-05 21:24:44 +00:00
fd8f368d31 [user-streams] Add graph annotation checks (#167019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019
Approved by: https://github.com/Lucaskabela
2025-11-05 21:24:44 +00:00
13d2cc7bd2 Remove python workaround for ContextDecorator (#167049)
This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-05 20:56:04 +00:00
c6c913d18e Add torch::stable::Tensor sizes and strides (#165153)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165153
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991, #165152
2025-11-05 20:55:34 +00:00
ef3f953966 Revert "[DebugMode] output, tensor id annotations for DebugMode (#165076)"
This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f.

Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://l.workplace.com/l.php?u=https%3A%2F%2Fwww.internalfb.com%2Fdiff%2FD86245252&h=AT1oPbS1XTv6HjYeYdxmDMW1-jlT0pS8yBO2iSfbPfUB9ydsEjFXBNT56QhV1v5TKc4_QaQNxykNowSKmb4fgenjOyCv20NuL7oV_Id5fhh32hhv1IpjgsDJYK-PBFfSfv_miLIWfNgj902KcgXojbBgDcDzQeS9lNt0GQ) for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159))
2025-11-05 20:52:43 +00:00
ea44f12bce [13/N] Apply ruff UP035 rule (#167048)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-05 20:51:53 +00:00
a74fe75c45 Don't hardcode double argument for reduction base (#166951)
Fixes https://github.com/pytorch/pytorch/issues/43254

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951
Approved by: https://github.com/ngimel, https://github.com/Skylion007
ghstack dependencies: #166813
2025-11-05 20:34:15 +00:00
6d30666bc1 Revert "[12/N] Apply ruff UP035 rule (#166929)"
This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52.

Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076. @cyyever Please re-merge after revert of #165076. ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596))
2025-11-05 20:02:47 +00:00
8e8cbb85ee Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)"
This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23.

Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see fbd70fb84e/1 and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038))
2025-11-05 19:42:39 +00:00
fbd70fb84e Update typing docs to reference pyrefly (#166883)
Replacing mypy documentation in the CONTRIBUTING.md file with pyrefly references. I have made initial changes to the https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883
Approved by: https://github.com/malfet
2025-11-05 19:35:38 +00:00
6c5db82584 [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, giving a 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. num_warps is removed from the kernel definition to enable autotune support in the generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

The num_warps=8 default is due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374
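For context, a minimal sketch of the kind of code that lowers to `triton_for_fused_*` foreach kernels under Inductor (the kernel names and timings above come from an internal workload, not from this snippet):

```python
import torch

params = [torch.randn(1024, device="cuda") for _ in range(16)]
grads = [torch.randn(1024, device="cuda") for _ in range(16)]

@torch.compile
def sgd_step(params, grads, lr=0.1):
    # _foreach_* ops are what Inductor fuses into foreach/combo kernels,
    # which this PR now autotunes instead of fixing num_warps.
    torch._foreach_add_(params, torch._foreach_mul(grads, -lr))
    return params

sgd_step(params, grads)
```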

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-11-05 19:27:23 +00:00
6052a01b71 [BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py (#167022)
Provides type coverage to torch/_dynamo/variables/dicts.py

Coverage report:
`mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log`

Compare before to after - we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022
Approved by: https://github.com/Skylion007
2025-11-05 19:18:35 +00:00
14b153bcf2 include DTensor metadata when pretty-printing fx.Graphs (#166750)
Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:

<img width="1446" height="582" alt="image" src="https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b" />

```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree

world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128

A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()

def f(dA, dB):
    dy = dA @ dB
    loss = dy.sum()
    loss.backward()
    return dA.grad, dB.grad

# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750
Approved by: https://github.com/ezyang, https://github.com/wconstab, https://github.com/Skylion007
2025-11-05 18:58:54 +00:00
641de23c96 ci: Add aarch64 docker builds for modern clang (#166416)
Should enable us to build using some arm optimizations that are only
available on the newest versions of clang.

Signed-off-by: Eli Uriegas <eliuriegas@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416
Approved by: https://github.com/malfet
2025-11-05 18:55:56 +00:00
89165c0a2b Update triton to 3.5.1 release (#166968)
This includes the sm103 fix: https://github.com/triton-lang/triton/pull/8485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela, https://github.com/njriasan
2025-11-05 18:26:34 +00:00
dcc2ba4ca4 Add some code for exploring the space of accessible size/stride configs via plain views (#167076)
We are working on a translation from as_strided to view operations, but
only when the as_strided is representable as a plain view.  A useful
testing utility in this situation is the ability to enumerate all valid
views on an original tensor.  So we have a small test here that shows
it is possible.

To avoid an explosion of states, we don't handle permutes and size=1,
which are degenerate cases (you can always do a single permute and
a series of unsqueezes to get to the final desired state.)
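A toy sketch of the enumeration idea, restricted to contiguous reshapes (the utility in this PR also covers strided configurations reachable through slicing):

```python
import torch

def factorizations(n):
    # all ordered factorizations of n into factors >= 2; each one is a valid
    # view shape for a contiguous tensor with n elements (no permutes, no size-1 dims)
    if n == 1:
        yield ()
        return
    for f in range(2, n + 1):
        if n % f == 0:
            for rest in factorizations(n // f):
                yield (f,) + rest

def all_plain_reshapes(t):
    assert t.is_contiguous()
    for shape in factorizations(t.numel()):
        yield t.view(shape)

x = torch.arange(12)
for v in all_plain_reshapes(x):
    print(tuple(v.shape), tuple(v.stride()))
```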

Authored with claude code assistance.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076
Approved by: https://github.com/albanD
ghstack dependencies: #166868, #166867
2025-11-05 18:25:19 +00:00
ad5c7c20e0 Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)"
This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578.

Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312))
2025-11-05 18:13:57 +00:00
c86540f120 Revert "Add model code stack trace to torch.profile (#166677)"
This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8.

Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160))
2025-11-05 18:11:11 +00:00
c17aa0f113 [ROCm] Enable group gemm through CK (#166334)
Fixes #161366
All 4 input-dimension combinations are supported:
2d-2d, 2d-3d, 3d-3d, 3d-2d. The corresponding test cases in test_matmul_cuda pass
for both the forward and backward pass.
The CK path is enabled for gfx942 and gfx950.
ToDo: enable support on gfx90a, since the CK kernel used in this commit produces a GPU error there;
it might require a different CK kernel config, based on the profiler result on gfx90a.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334
Approved by: https://github.com/atalman
2025-11-05 18:03:59 +00:00
4ff068c33a [Code Clean] Replace assert with if statement and raise AssertionError (#166935)
Including:
- `torch/profiler/profiler.py`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935
Approved by: https://github.com/fffrog, https://github.com/albanD
2025-11-05 17:59:16 +00:00
0c7a4a6b48 [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)
When a function compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated Triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.
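A minimal sketch of the failing pattern (an assumed repro shape; the actual test lives with the linked issue):

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True  # keep .item() in the graph

@torch.compile(fullgraph=True)
def clamp_by_tensor(x, lo, hi):
    # .item() on float tensors creates unbacked float symbols (e.g. zuf0),
    # which previously never made it into the Triton kernel's parameter list
    return torch.clamp(x, lo.item(), hi.item())

x = torch.randn(32, device="cuda")
print(clamp_by_tensor(x, torch.tensor(-0.5), torch.tensor(0.5)))
```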

Fixes: #166888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890
Approved by: https://github.com/eellison
2025-11-05 17:50:08 +00:00
63 changed files with 2079 additions and 396 deletions

View File

@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
ARG DEVTOOLSET_VERSION=11
ARG DEVTOOLSET_VERSION=13
RUN yum -y update
RUN yum -y install epel-release
# install glibc-langpack-en make sure en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
# Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
# Install CUDA
FROM base as cuda
ARG CUDA_VERSION=12.6
ARG DEVTOOLSET_VERSION=13
RUN rm -rf /usr/local/cuda-*
ADD ./common/install_cuda.sh install_cuda.sh
COPY ./common/install_nccl.sh install_nccl.sh
@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
# Preserve CUDA_VERSION for the builds
ENV CUDA_VERSION=${CUDA_VERSION}
# Make things in our path by default
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6
@ -68,8 +70,22 @@ FROM cuda as cuda13.0
RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0
FROM ${ROCM_IMAGE} as rocm
FROM ${ROCM_IMAGE} as rocm_base
ARG DEVTOOLSET_VERSION=13
ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
# Install devtoolset on ROCm base image
RUN yum -y update && \
yum -y install epel-release && \
yum -y install glibc-langpack-en && \
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM rocm_base as rocm
ARG PYTORCH_ROCM_ARCH
ARG DEVTOOLSET_VERSION=13
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh
@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
# Final step
FROM ${BASE_TARGET} as final
ARG DEVTOOLSET_VERSION=13
COPY --from=openssl /opt/openssl /opt/openssl
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
COPY --from=conda /opt/conda /opt/conda

View File

@ -63,7 +63,7 @@ docker build \
--target final \
--progress plain \
--build-arg "BASE_TARGET=${BASE_TARGET}" \
--build-arg "DEVTOOLSET_VERSION=11" \
--build-arg "DEVTOOLSET_VERSION=13" \
${EXTRA_BUILD_ARGS} \
-t ${tmp_tag} \
$@ \

View File

@ -271,6 +271,16 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11

View File

@ -1 +1 @@
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
fi
fi

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.10.2.21
CUDNN_VERSION=9.8.0.87
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -1 +1 @@
3.5.0
3.5.1

View File

@ -272,18 +272,6 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -79,6 +79,8 @@ jobs:
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,15 +350,32 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `mypy`
#### Running `pyrefly`
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
### C++ Unit Testing

View File

@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
uint32_t mask = -1;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
#ifndef USE_ROCM
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
const bool lie_to_cublaslt = false;
#endif
if (lie_to_cublaslt) {
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
FakeBdesc.descriptor(),
FakeCdesc.descriptor(),
FakeCdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
}
if (returnedResult == 0) {
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
}

View File

@ -24,7 +24,13 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@ -1,5 +1,6 @@
#include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@ -1710,11 +1711,37 @@ Tensor narrow_symint(
"], but got ",
start,
")")
if (start < 0) {
start = start + cur_size;
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
if (cond1 || cond2) {
if (cond1) {
start = start + cur_size;
}
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
"start (",
start,
") + length (",
length,
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
}
// Unbacked start handling!
// Bounds check without converting start:
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
// length <= 0
// - If start >= 0: need start + length <= cur_size
auto end = start + length;
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
(start.sym_lt(0).sym_and((end).sym_le(0)))
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
"start (",
start,
") + length (",
@ -1722,7 +1749,28 @@ Tensor narrow_symint(
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
return at::slice_symint(self, dim, start, end, 1);
} else {
// Cannot statically determine the condition due to unbacked.
// This is an interesting situation; when start is negative and
// start + length == 0, slice and narrow do different things.
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
// pass curr_size instead of 0. Otherwise, they would do the same thing.
// This says at runtime: if start < 0 and end == 0, then pass curr_size
// instead of 0.
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
auto result =
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
// Ensure slice allocated unbacked size is specialized to length.
SymInt new_size = result.sym_size(dim);
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
return result;
}
}
// This overload exists purely for XLA, because they wanted to pass in

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because of Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
static_cast<double>(upper_bound<scalar_t>()));
upper_bound<scalar_t>());
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}

View File

@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -0,0 +1,462 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc),
{} // --> stride_Ds_
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at
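The layout handling above relies on two facts: row-major vs. column-major is detected purely from strides, and the double-transpose pattern forces column-major storage without altering values. A minimal standalone Python sketch of that trick (independent of the CK kernel code; the sizes are arbitrary example values):

```python
# Sketch of the "double transpose trick" used above to force a column-major
# operand: transpose, materialize contiguously, transpose back.
import torch

K, N = 64, 32
b = torch.randn(K, N)                                       # row-major: strides (N, 1) == (32, 1)
b_col = b.transpose(-2, -1).contiguous().transpose(-2, -1)  # same values, new layout

print(b.stride())      # (32, 1)
print(b_col.stride())  # (1, 64) -> stride(0) == 1, stride(1) == K, i.e. column-major
assert torch.equal(b, b_col)
```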

View File

@ -1,4 +1,5 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
namespace c10 {
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
return toSymNodeImpl()->has_hint();
}
SymInt SymBool::toSymInt() const {
// If concrete bool, return concrete SymInt
if (auto ma = maybe_as_bool()) {
return SymInt(*ma ? 1 : 0);
}
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
auto node = toSymNodeImpl();
auto one_node = node->wrap_int(1);
auto zero_node = node->wrap_int(0);
return SymInt(node->sym_ite(one_node, zero_node));
}
} // namespace c10
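The conversion above has two paths: a concrete bool maps straight to 0/1, while a symbolic bool is routed through `sym_ite` so no guard is introduced on its value. A hedged Python sketch of the same idea (illustrative only — this is not the library's `cast_symbool_to_symint_guardless`, and the symbolic branch assumes `torch.sym_ite(cond, t, f)` is available):

```python
import torch

def symbool_to_symint(b):
    # Concrete case: a plain Python bool maps directly to 0 or 1.
    if isinstance(b, bool):
        return int(b)
    # Symbolic case: select 1 or 0 without guarding on the value of `b`.
    return torch.sym_ite(b, 1, 0)

print(symbool_to_symint(True), symbool_to_symint(False))  # 1 0
```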

View File

@ -12,6 +12,8 @@
namespace c10 {
class SymInt;
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b) {}
@ -80,6 +82,10 @@ class C10_API SymBool {
return toSymNodeImplUnowned()->constant_bool();
}
// Convert SymBool to SymInt (0 or 1)
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
SymInt toSymInt() const;
bool is_heap_allocated() const {
return ptr_;
}

View File

@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
auto out = new_empty(param, param.sizes());
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to a HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@ -5,8 +5,16 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -426,6 +434,31 @@ class TestDTensorDebugMode(TestCase):
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph,
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -5789,6 +5789,229 @@ class NCCLTraceTest(NCCLTraceTestBase):
else:
self.assertTrue("duration_ms" not in t["entries"][0])
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
"""
Test that when the circular buffer in entries_ is full and we call reset,
then fill the buffer with new entries, dump_entries returns only the new
entries and not the old ones.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely with 10 entries
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify buffer is full with 10 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Now reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add new entries after reset - fill the buffer completely again
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we get exactly 10 new entries, not 20
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Verify all entries have the expected properties (from after reset)
# After reset, record IDs should start from 0 again
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
self.assertIn("record_id", entry)
# Record IDs should be sequential starting from 0 after reset
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
"""
Test that when the circular buffer is full, we reset, and then add fewer
entries than the buffer size, we only get the new entries.
This tests that old entries at the end of the circular buffer are properly
filtered out based on reset_epoch.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add only 3 new entries (much less than buffer size)
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we only get the 3 new entries, not 10
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 3)
# Verify record IDs start from 0 after reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_wraparound(self, timing_enabled):
"""
Test that when we reset in the middle of the circular buffer and then
wrap around, dump_entries correctly returns only entries from the current
epoch in the correct order.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill half the buffer
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset at this point (reset happens at index 5)
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Now add 8 entries, which will wrap around
# (5->9 fills rest of buffer, then 0->2 wraps around)
for _ in range(8):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should get exactly 8 entries, properly ordered
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 8)
# Entries should be in chronological order
# The dump_entries() method returns entries from next_ to end, then 0 to next_
# After filtering old entries, we should have 8 entries in order
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_multiple_resets(self, timing_enabled):
"""
Test multiple consecutive resets to ensure each reset properly increments
the epoch and filters out entries from previous epochs.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# First batch: 2 entries
for _ in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# First reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Second batch: 3 entries
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Second reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Third batch: 4 entries
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should only see the last 4 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
# Verify record IDs start from 0 after the last reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
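The four tests above all exercise the same flight-recorder behavior: the entry buffer is circular, a reset bumps an epoch (and restarts record IDs) rather than clearing storage, and dumping filters out entries recorded before the latest reset while preserving chronological order across the wraparound point. A small self-contained sketch of that bookkeeping, with illustrative names only (the real implementation is the C++ recorder behind `_dump_nccl_trace`, not this class):

```python
# Minimal model of a circular buffer with epoch-based reset filtering.
from dataclasses import dataclass

@dataclass
class Entry:
    record_id: int
    epoch: int

class ToyRecorder:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.buf: list[Entry] = []
        self.next_ = 0       # next slot to overwrite once the buffer is full
        self.epoch = 0
        self.next_id = 0

    def record(self) -> None:
        e = Entry(self.next_id, self.epoch)
        self.next_id += 1
        if len(self.buf) < self.capacity:
            self.buf.append(e)
        else:
            self.buf[self.next_] = e
        self.next_ = (self.next_ + 1) % self.capacity

    def reset(self) -> None:
        self.epoch += 1      # old entries stay in storage but become stale
        self.next_id = 0     # record IDs restart from 0, as the tests assert

    def dump(self) -> list[Entry]:
        # Chronological order: from next_ to the end, then 0 to next_,
        # keeping only entries written after the most recent reset.
        if len(self.buf) == self.capacity:
            ordered = self.buf[self.next_:] + self.buf[:self.next_]
        else:
            ordered = list(self.buf)
        return [e for e in ordered if e.epoch == self.epoch]

r = ToyRecorder(10)
for _ in range(5):
    r.record()
r.reset()
for _ in range(8):           # wraps past the end of the 10-slot buffer
    r.record()
assert [e.record_id for e in r.dump()] == list(range(8))
```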
def check_if_test_is_skipped(fn):
def wrapper(self, *args, **kwargs):

View File

@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)
return extract_graph(fn, *args, **kwargs)[0:3]
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union
import torch
import torch._dynamo

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING
import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -1,11 +1,13 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@ -15,6 +17,14 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -36,9 +46,7 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
def fn(x, y, s1, s2):
with s1:
z1 = torch.add(x, y)
with s2:
@ -47,13 +55,36 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@ -70,9 +101,16 @@ class TestStreams(torch._dynamo.test_case.TestCase):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@ -155,22 +193,248 @@ class TestStreams(torch._dynamo.test_case.TestCase):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
pass
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
def test_local_stream_nested_enter_exit(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
def test_stream_with_mutation(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_run_opcheck(self):

View File

@ -721,6 +721,34 @@ class TestExport(TestCase):
)
self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
def test_annotate_on_assert(self):
# nodes added in `apply_runtime_assertion_pass` will be annotated
class M(torch.nn.Module):
def forward(self, x, y):
with torch.fx.traceback.annotate({"moo": 0}):
x = torch.cat([x, x])
b = y.item()
torch._check(b >= x.shape[0])
return x * b
with torch.fx.traceback.preserve_node_meta():
ep = torch.export.export(
M(),
(torch.randn(3), torch.tensor(6)),
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
)
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)
@requires_gpu
def test_flex_attention_export(self):
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
@ -6093,26 +6121,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
retry_export(
cf_implicitsize(),
(torch.tensor(2), torch.randn(10)),
fixes=[
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
fixes=[],
)
class cf_stacklist(torch.nn.Module):
def forward(self, xs, y, fixes):
i = y.item()
eval(fixes)
# instead of xs[i]
return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
retry_export(
cf_stacklist(),
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
fixes=[
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
fixes=[],
)
class cf_tensorsplit(torch.nn.Module):
@ -6166,7 +6187,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
class cf_stacklist(torch.nn.Module):
def forward(self, xs, y):
# y.item() is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
if y.item() < 0:
return (
torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
)
else:
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
with self.assertRaisesRegex(
error_type,
@ -6196,7 +6222,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, xs, y):
box = Box(y.item())
# box.content is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
if box.content < 0:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
else:
return (
torch.stack(xs, 0)
.narrow(0, box.content, 1)
.squeeze()
)
with self.assertRaisesRegex(
error_type,

test/test_as_strided.py (new file, 176 lines)
View File

@ -0,0 +1,176 @@
# Owner(s): ["oncall: pt2"]
from collections import deque
from typing import Optional
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Extract (sizes, strides) tuple from a tensor."""
return (tuple(t.size()), tuple(t.stride()))
def enumerate_reachable_states(
initial_size: int,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
"""
Use BFS with DP to enumerate all reachable (size, stride) states from
a 1D contiguous tensor via valid view operations.
We only explore states with offset=0 (you can retroactively change the offset).
We reject states with size=0 or size=1 dimensions as they are degenerate.
"""
# Create initial 1D contiguous tensor
initial_tensor = torch.arange(initial_size)
initial_state = get_state(initial_tensor)
# Map from state to tensor for that state
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
initial_state: initial_tensor
}
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
while queue:
state = queue.popleft()
t = state_to_tensor[state]
sizes, strides = state
ndim = len(sizes)
def add_state(new_t: torch.Tensor) -> None:
new_state = get_state(new_t)
sizes, strides = new_state
# Skip if has size-0 or size-1 dimensions
if any(s == 0 or s == 1 for s in sizes):
return
# Only accept states where strides are in descending order
if list(strides) != sorted(strides, reverse=True):
return
if new_state not in visited:
visited.add(new_state)
queue.append(new_state)
state_to_tensor[new_state] = new_t
# 1. Unflatten: try factoring each dimension
for dim in range(ndim):
size = sizes[dim]
assert size > 1
# Try all factorizations x * y = size where both x, y >= 2
# We only need to check x up to size // 2 since when x > size // 2,
# y = size // x < 2, which we reject
for x in range(2, size // 2 + 1):
if size % x == 0:
y = size // x
add_state(t.unflatten(dim, (x, y)))
# 2. Slice: exhaustively check all possible slicing parameters
for dim in range(ndim):
size = sizes[dim]
for start in range(size):
for stop in range(start + 1, size + 1):
for step in range(1, size + 1):
slices = [slice(None)] * ndim
slices[dim] = slice(start, stop, step)
add_state(t[tuple(slices)])
# 3. Flatten: merge adjacent dimensions
for dim in range(ndim - 1):
add_state(t.flatten(dim, dim + 1))
return visited
class TestAsStrided(TestCase):
def test_size_10_exhaustive(self) -> None:
"""Test that size 10 produces exactly the expected 54 states."""
expected_states = {
((2,), (1,)),
((2,), (2,)),
((2,), (3,)),
((2,), (4,)),
((2,), (5,)),
((2,), (6,)),
((2,), (7,)),
((2,), (8,)),
((2,), (9,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (3, 2)),
((2, 2), (4, 1)),
((2, 2), (4, 2)),
((2, 2), (4, 3)),
((2, 2), (5, 1)),
((2, 2), (5, 2)),
((2, 2), (5, 3)),
((2, 2), (5, 4)),
((2, 2), (6, 1)),
((2, 2), (6, 2)),
((2, 2), (6, 3)),
((2, 2), (8, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 3), (5, 2)),
((2, 3), (6, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3,), (2,)),
((3,), (3,)),
((3,), (4,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 2), (3, 2)),
((3, 2), (4, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4,), (2,)),
((4,), (3,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5,), (2,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10)
self.assertEqual(len(actual_states), 54)
self.assertEqual(actual_states, expected_states)
def test_subset_property(self) -> None:
"""
Test that for sizes 2..10, each smaller tensor results in a strict
subset of possible states compared to the next one.
"""
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
for size in range(2, 11):
current_states = enumerate_reachable_states(size)
if prev_states is not None:
# Check that prev_states is a strict subset of current_states
self.assertTrue(
prev_states.issubset(current_states),
f"States from size {size - 1} are not a subset of size {size}",
)
# Check that it's a strict subset (not equal)
self.assertTrue(
len(prev_states) < len(current_states),
f"States from size {size - 1} should be strictly fewer than size {size}",
)
prev_states = current_states
if __name__ == "__main__":
run_tests()

View File

@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
self.assertEqual(compiled(a, b), func(a, b))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_narrow_unbacked_start(self):
def func(x, start, length):
# unbacked start
u0 = start.item()
return torch.narrow(x, 0, u0, length)
compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
x = torch.tensor([1, 2, 3, 4, 5, 6])
# Test cases: (start, length)
test_cases = [
# Negative starts
(-2, 2), # Start from second-to-last element
(-1, 1), # Start from last element
(-3, 3), # Start from third-to-last element
(-6, 2), # Start from beginning (negative)
(-4, 1), # Start from fourth-to-last element
# Positive starts
(0, 2), # Start from beginning
(1, 3), # Start from second element
(2, 2), # Start from third element
(4, 2), # Start near end
# Edge cases
(0, 6), # Full tensor
(0, 1), # Single element from start
(5, 1), # Single element from end
]
for start_val, length in test_cases:
with self.subTest(start=start_val, length=length):
start = torch.tensor([start_val])
# Test with compiled function
result_compiled = compiled_func(x, start, length)
# Test with eager function (expected behavior)
result_eager = func(x, start, length)
# Compare results
self.assertEqual(result_compiled, result_eager)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_narrow_unbacked_start_cpp_wrapper(self):
"""Test narrow with unbacked start with cpp_wrapper"""
self.test_narrow_unbacked_start()
instantiate_parametrized_tests(TestUnbacked)

View File

@ -72,6 +72,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
skipIfRocm,
)
from torch.testing._internal.jit_utils import JitTestCase
@ -4249,6 +4250,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
@ -4304,6 +4306,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
@ -4347,6 +4350,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""

View File

@ -359,6 +359,29 @@ class TestMatmulCuda(InductorTestCase):
self.assertEqual(agrad, a.grad)
self.assertEqual(bgrad, b.grad)
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
@unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell")
@serialTest()
def test_cublas_batch_invariance_blackwell(self, device, dtype):
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
with blas_library_context('cublaslt'):
N = 2048
K = 6144
M_max = 32
x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
full = x @ w
xx = x[:1]
out = xx @ w
self.assertEqual(full[:1], out, atol=0., rtol=0.)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@ -490,8 +513,6 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

View File

@ -1864,6 +1864,8 @@ class TestFP8Matmul(TestCase):
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if torch.version.hip and recipe == "nvfp4":
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

View File

@ -257,34 +257,6 @@ class TestFuzzerCompileIssues(TestCase):
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #163971")
def test_fuzzer_issue_163971(self):
torch.manual_seed(0)
def foo(arg0):
t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda
t1 = torch.softmax(
t0, dim=0
) # size=(), stride=(), dtype=bfloat16, device=cuda
t2 = torch.nn.functional.gelu(
t1
) # size=(), stride=(), dtype=bfloat16, device=cuda
t3 = torch.softmax(
t2, dim=0
) # size=(), stride=(), dtype=bfloat16, device=cuda
output = t3
return output
arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
out_eager = foo(arg0)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164059")
def test_fuzzer_issue_164059(self):
torch.manual_seed(0)

View File

@ -1914,6 +1914,7 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, is_causal=True))
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@ -1935,6 +1936,7 @@ class TestSDPAFailureModes(NNTestCase):
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@ -1948,6 +1950,7 @@ class TestSDPAFailureModes(NNTestCase):
@largeTensorTest("15GB", "cuda")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
device = torch.device("cuda")
dtype = torch.bfloat16

View File

@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import assert_type, TypeAlias
from typing import TypeAlias, Union
from typing_extensions import assert_type
from torch import randn, Tensor

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Callable, Optional, overload, Union
from typing import Any, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {})
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
DICT_UPDATE = DICT_MERGE

View File

@ -87,6 +87,12 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence
from collections.abc import Callable, Sequence, Sized
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@ -26,7 +24,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING, Union
from torch._subclasses.fake_tensor import is_fake
@ -59,11 +57,13 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj):
def was_instancecheck_override(obj: Any) -> bool:
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(arg, tx=None):
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
)
def is_hashable(x):
def is_hashable(x: VariableTracker) -> bool:
# NB - performing an isinstance check on a LazyVT realizes the VT, accidentally
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
# the underlying value without realizing the VT. Consider updating the
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt) -> None:
def __init__(self, vt: VariableTracker) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self):
def underlying_value(self) -> Any:
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value
return self.vt.value # type: ignore[attr-defined,union-attr]
else:
x = self.vt.as_python_constant()
return x
def __hash__(self):
def __hash__(self) -> int:
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a, b):
def _eq_impl(a: Any, b: Any) -> bool:
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
else:
return a == b
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
def __eq__(self, other: object) -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls=dict,
**kwargs,
user_cls: type = dict,
**kwargs: Any,
) -> None:
# .clone() pass these arguments in kwargs but they're recreated a few
# lines below
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
for x, v in items.items()
)
def make_hashable(key):
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = not is_from_local_source(self.source)
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls):
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self):
def as_proxy(self) -> dict[Any, Any]:
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self):
def debug_repr(self) -> str:
return (
"{"
+ ", ".join(
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self):
def as_python_constant(self) -> dict[Any, Any]:
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self):
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self):
def python_type(self) -> type:
return self.user_cls
def __contains__(self, vt) -> bool:
def __contains__(self, vt: VariableTracker) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(self, value, other):
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen):
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
):
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing"
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}",
context=f"Key {arg.value}", # type: ignore[attr-defined]
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker):
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker):
def realize_key_vt(self, arg: VariableTracker) -> None:
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
else:
k, v = self.items.popitem()
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt = args[0]
dict_vt: ConstDictVariable = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
if has_kwargs:
# Handle kwargs
kwargs = {
kwargs_hashable = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs)
self.items.update(kwargs_hashable)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value
last = kwargs.get("last").value # type: ignore[union-attr]
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
)
elif name == "__or__":
if len(args) != 1:
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
msg = (
err_msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[err_msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls}
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
# dict not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs):
def clone(self, **kwargs: Any) -> VariableTracker:
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self):
def python_type(self) -> type:
return types.MappingProxyType
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# load types.MappingProxyType
if self.source:
msg = (
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
@ -900,35 +913,44 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
if default_factory is None:
default_factory = ConstantVariable.create(None)
self.default_factory = default_factory
def is_python_constant(self):
def is_python_constant(self) -> bool:
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
def debug_repr(self):
def debug_repr(self) -> str:
assert self.default_factory is not None
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg):
def is_supported_arg(arg: VariableTracker) -> bool:
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@ -942,11 +964,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -962,13 +984,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", (args[0], default_var), kwargs
tx, "__setitem__", [args[0], default_var], kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
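For context, a short sketch of the plain-Python defaultdict behavior that the `__getitem__` branch above mirrors (missing key: call `default_factory`, store the result, return it):

```python
from collections import defaultdict

d = defaultdict(list)            # list is one of the factories is_supported_arg accepts
d["missing"].append(1)           # missing key: factory() is called, stored, then returned
assert d["missing"] == [1]
assert d["also_missing"] == []   # a plain lookup on a missing key also inserts the default
```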
def reconstruct(self, codegen):
def reconstruct(self, codegen: "PyCodegen") -> None:
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return set(self.items.keys())
@staticmethod
def _default_value():
def _default_value() -> VariableTracker:
# Variable to fill in the keys of the dictionary
return ConstantVariable.create(None)
def as_proxy(self):
def as_proxy(self) -> Any:
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self):
def python_type(self) -> type:
return set
def as_python_constant(self):
def as_python_constant(self) -> Any:
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(self, tx, fn, args, kwargs):
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@ -1037,15 +1067,16 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@ -1065,10 +1096,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items)
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@ -1079,7 +1110,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = (args[0], SetVariable._default_value())
args = [args[0], SetVariable._default_value()]
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@ -1090,12 +1121,14 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result = self.set_items.pop().vt
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
super().call_method(tx, name, (result,), kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@ -1217,6 +1250,7 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@ -1230,29 +1264,34 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
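A quick plain-Python illustration of the operator-to-method mapping used above: the in-place operators delegate to the corresponding `*_update` methods, and `__eq__` is expressed through an empty symmetric difference.

```python
s = {1, 2, 3}
s &= {2, 3, 4}          # __iand__ behaves like intersection_update
assert s == {2, 3}

a, b = {1, 2}, {2, 1}
assert len(a ^ b) == 0  # empty symmetric difference <=> a == b
```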
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
super().install_dict_contains_guard(tx, args)
@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return self.items.keys()
def python_type(self):
def python_type(self) -> type:
return frozenset
def as_python_constant(self):
def as_python_constant(self) -> Any:
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1293,11 +1332,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@ -1316,7 +1355,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return super().call_method(tx, name, args, kwargs)
@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "dict_keys([])"
else:
@ -1338,33 +1377,35 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self):
def set_items(self) -> Any:
return self.items
def python_type(self):
def python_type(self) -> type:
return dict_keys
def as_python_constant(self):
def as_python_constant(self) -> Any:
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self):
def view_items(self) -> Any:
assert self.kv is not None
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self):
def set_items(self) -> set[VariableTracker]:
return set(self.view_items)
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_keys
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1460,13 +1506,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always return a set
m = getattr(self.set_items, name)
r = m(args[0].set_items)
r = m(args[0].set_items) # type: ignore[attr-defined]
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
return list(self.view_items)
def python_type(self):
def python_type(self) -> type:
return dict_values
@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_items
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,6 +17,8 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -2063,7 +2063,8 @@ class PythonWrapperCodegen(CodeGen):
neg = self.codegen_sizevar(
sympy.Max(0, sympy.Min(x + node.size, node.size))
)
return f"{pos} if {x} >= 0 else {neg}"
x_cond = self.codegen_sizevar(x)
return f"{pos} if {x_cond} >= 0 else {neg}"
def codegen_with_step(start_var, end_var, step):
if step == 1:

View File

@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node,
override_size: Optional[int] = None,
# TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
use_nccl_estimator: bool = False,
use_nccl_estimator: bool = True,
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -3586,13 +3586,24 @@ def user_autotune(
)
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -52,26 +52,7 @@ __all__ = [
"MemRecordsAcc",
]
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
from contextlib import ContextDecorator
# global python state - whether profiler is currently enabled
@ -744,8 +725,7 @@ class profile:
return all_function_events
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
class record_function(ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

View File

@ -108,12 +108,14 @@ struct FlightRecorder {
capture_cpp_stack_ = getCvarBool(
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
enabled_ = max_entries_ > 0;
reset_epoch_start_idx_[0] = 0;
}
struct Entry {
size_t id_; // incremented id in the trace buffer
// used to figure out where in the circular entries
// buffer this entry will be located to
// update state information
size_t reset_epoch_; // epoch when this entry was created
size_t pg_id_;
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
@ -183,11 +185,34 @@ struct FlightRecorder {
size_t max_entries_ = 0;
size_t next_ = 0;
size_t id_ = 0;
size_t reset_epoch_ = 0;
std::unordered_map<size_t, size_t>
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
pg_name_to_ranks_;
std::string comm_lib_version_;
struct TraceIdentifier {
std::optional<size_t> id;
std::optional<size_t> reset_epoch;
};
TraceIdentifier recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P);
std::optional<size_t> record(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
@ -213,8 +238,16 @@ struct FlightRecorder {
std::vector<Entry> dump_entries();
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
// Returns the entry with the given id and reset_epoch, if it exists.
// Otherwise, returns std::nullopt.
TORCH_API std::optional<Entry> getEntry(
std::optional<size_t> id,
std::optional<size_t> reset_epoch);
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
/*
@ -227,6 +260,11 @@ struct FlightRecorder {
never hang. (timing must also be enabled for compute_duration - see
TORCH_NCCL_ENABLE_TIMING).
*/
TORCH_API void retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration = true);
TORCH_API void retire_id(
std::optional<size_t> id,
bool compute_duration = true);

View File

@ -53,8 +53,41 @@ std::optional<size_t> FlightRecorder<EventType>::record(
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
auto result = recordWithResetEnabled(
pg_id,
pg_name,
collective_seq_id,
p2p_seq_id,
op_id,
std::move(profiling_name),
inputs,
outputs,
start,
end,
timeout_ms,
std::move(pg_status),
isP2P);
return result.id;
}
template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
if (!enabled_) {
return std::nullopt;
return TraceIdentifier{std::nullopt, std::nullopt};
}
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
// Current pg_status is not in FR.
@ -64,8 +97,13 @@ std::optional<size_t> FlightRecorder<EventType>::record(
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
reset_epoch_start_idx_.find(reset_epoch_) !=
reset_epoch_start_idx_.end());
auto te = Entry{
id_,
reset_epoch_,
pg_id,
pg_name,
collective_seq_id,
@ -104,15 +142,20 @@ std::optional<size_t> FlightRecorder<EventType>::record(
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
}
const auto next = next_++;
if (entries_.size() < max_entries_) {
entries_.emplace_back(std::move(te));
} else {
entries_[next_++] = std::move(te);
if (next_ == max_entries_) {
next_ = 0;
}
entries_[next] = std::move(te);
}
return id_++;
if (next_ == max_entries_) {
next_ = 0;
}
const auto id = id_++;
return TraceIdentifier{id, reset_epoch_};
}
template <typename EventType>
@ -163,15 +206,20 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
std::vector<Entry> result;
{
std::lock_guard<std::mutex> guard(mutex_);
result.reserve(entries_.size());
result.insert(
result.end(),
// Filter entries during insertion - only keep entries from current epoch
auto filter = [this](const Entry& e) {
return e.reset_epoch_ == reset_epoch_;
};
std::copy_if(
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
entries_.end());
result.insert(
result.end(),
entries_.end(),
std::back_inserter(result),
filter);
std::copy_if(
entries_.begin(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
std::back_inserter(result),
filter);
}
// query any remaining events
for (auto& r : result) {
@ -182,28 +230,47 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}
template <typename EventType>
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
const {
// Look up the starting idx for the given reset epoch
auto it = reset_epoch_start_idx_.find(reset_epoch);
TORCH_CHECK(it != reset_epoch_start_idx_.end());
// Calculate idx based on where the epoch started
return (it->second + id) % max_entries_;
}
template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
if (!enabled_ || !id) {
EventType>::
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
if (!enabled_ || !id || !reset_epoch) {
return std::nullopt;
}
std::unique_lock<std::mutex> guard(mutex_);
Entry entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
return entry;
} else {
return std::nullopt;
}
return std::nullopt;
}
template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
return getEntry(id, 0);
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration) {
if (!enabled_ || !id) {
if (!enabled_ || !id || !reset_epoch) {
return;
}
@ -214,8 +281,8 @@ void FlightRecorder<EventType>::retire_id(
std::unique_lock<std::mutex> guard(mutex_);
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
update_state(*entry);
if (compute_duration) {
@ -237,8 +304,8 @@ void FlightRecorder<EventType>::retire_id(
guard.lock();
// Refresh the entry pointer, see if the entry has been overwritten
entry = &entries_.at(*id % max_entries_);
if (entry->id_ != *id) {
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
LOG(INFO) << "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
@ -249,12 +316,23 @@ void FlightRecorder<EventType>::retire_id(
}
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
bool compute_duration) {
retire_id(id, 0, compute_duration);
}
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
std::lock_guard<std::mutex> guard(mutex_);
next_ = 0;
id_ = 0;
entries_.clear();
if (!entries_.empty()) {
// Soft delete: increment epoch to mark all existing entries as old
// Store where the new epoch starts in the circular buffer
reset_epoch_++;
reset_epoch_start_idx_[reset_epoch_] = next_;
id_ = 0;
}
}
template <typename EventType>
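A minimal Python sketch (not the actual C++ implementation) of the epoch-indexed circular buffer introduced here: `reset_all()` soft-deletes by bumping the epoch and remembering where it starts, `getIdxFromId()` maps `(id, reset_epoch)` back to a slot, and `dump_entries()` filters to the current epoch.

```python
class TinyRecorder:
    def __init__(self, max_entries):
        self.max_entries = max_entries
        self.entries = []              # circular buffer of (id, epoch, payload)
        self.next = 0                  # next slot to write once the buffer is full
        self.id = 0
        self.epoch = 0
        self.epoch_start = {0: 0}      # epoch -> index where that epoch begins

    def record(self, payload):
        entry = (self.id, self.epoch, payload)
        if len(self.entries) < self.max_entries:
            self.entries.append(entry)
        else:
            self.entries[self.next] = entry
        self.next = (self.next + 1) % self.max_entries
        self.id += 1
        return self.id - 1, self.epoch            # TraceIdentifier-style (id, reset_epoch)

    def idx_from_id(self, id_, epoch):
        return (self.epoch_start[epoch] + id_) % self.max_entries

    def get_entry(self, id_, epoch):
        e = self.entries[self.idx_from_id(id_, epoch)]
        return e if e[0] == id_ and e[1] == epoch else None   # None if overwritten

    def reset_all(self):
        if self.entries:               # soft delete: old entries stay until overwritten
            self.epoch += 1
            self.epoch_start[self.epoch] = self.next
            self.id = 0

    def dump_entries(self):
        ordered = self.entries[self.next:] + self.entries[:self.next]
        return [e for e in ordered if e[1] == self.epoch]
```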

View File

@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
// TODO: We need to have numel of tensors for gloo as well.
pgStatus_->lastCompletedNumelIn = 0;
pgStatus_->lastCompletedNumelOut = 0;
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
FlightRecorder<c10::Event>::get()->retire_id(
work->trace_id_, work->trace_reset_epoch_, false);
lock.lock();
workInProgress_[workerIndex].reset();
}
@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
pgStatus_->lastEnqueuedNumelOut = 0;
// using c10d::FlightRecorder;
// TODO: We need to have a way to use c10::Event inside gloo as well.
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
collectiveCounter_,
@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
work->getTimeout(),
pgStatus_,
false);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
workQueue_.push_back(std::move(work));
lock.unlock();

View File

@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
std::shared_ptr<gloo::Context> context_;
const std::chrono::milliseconds timeout_;

View File

@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
trace_reset_epoch_(w.trace_reset_epoch_),
distDebugLevel_(w.distDebugLevel_) {
exception_ = w.exception_;
}
@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
// First step we get the corresponding record entry from FR, based on work's
// trace_id_
// trace_id_ and trace_reset_epoch_
std::optional<FlightRecorderCUDA::Entry> entry =
FlightRecorderCUDA::get()->getEntry(trace_id_);
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
if (entry.has_value()) {
auto entryVal = entry.value();
// Get stack trace from FR entry, in string format
@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
FlightRecorderCUDA::get()->retire_id(
work.trace_id_, work.trace_reset_epoch_, true);
if (pg_->onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
@ -3360,7 +3362,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
// these objects to the Work because it has implications for keeping those
// tensors alive longer and adds overhead when copying Work objects
// between threads
r->trace_id_ = FlightRecorderCUDA::get()->record(
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -3374,6 +3376,8 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
options_->timeout,
pgStatus_,
isP2P);
r->trace_id_ = traceId.id;
r->trace_reset_epoch_ = traceId.reset_epoch;
}
return r;
}
@ -3593,6 +3597,7 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
--ncclActiveGroupCounter_;
return simInfo.estimatedTime;
#else
TORCH_CHECK(
@ -3676,7 +3681,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// later in endCoalescing we record a 'coalesced' Work which has
// timing/state updates via watchdog thread, but lacks op metadata such as
// input/output sizes and profilingTitle per-op in the group.
FlightRecorderCUDA::get()->record(
FlightRecorderCUDA::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4168,7 +4173,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
// initWork to not record, and then we manually call record passing all the
// information it wants.
work->trace_id_ = FlightRecorderCUDA::get()->record(
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4182,6 +4187,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
options_->timeout,
pgStatus_,
/*isP2P=*/true);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
}
// Only check for NaN for send ops, for recv ops `tensor` can be a random

View File

@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};

View File

@ -4,6 +4,7 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -13,6 +14,7 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -93,6 +95,32 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.

View File

@ -1,9 +1,8 @@
import functools
import math
import operator
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -10,6 +10,7 @@ from ._context_parallel._attention import (
_enable_context_parallel_dispatcher,
_is_causal_behavior,
_RotateMethod,
_templated_ring_attention,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@ -22,6 +23,7 @@ from ._context_parallel._load_balancer import (
)
# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
"_CausalBehavior",
"_context_parallel_shard",
@ -31,6 +33,7 @@ __all__ = [
"_enable_context_parallel_dispatcher",
"_is_causal_behavior",
"_RotateMethod",
"_templated_ring_attention",
"context_parallel",
"context_parallel_unshard",
"set_rotate_method",

View File

@ -547,6 +547,7 @@ def rebind_unbacked(
assert shape_env is not None
for raw_u0, path in bindings.items():
u1 = pytree.key_get(result, path)
# Sometimes, things that were previously unbacked bindings become constants.
# There are two situations this can happen.
#
@ -602,7 +603,23 @@ def rebind_unbacked(
if u1.node.hint is not None:
continue
raw_u1 = u1.node.expr
# unbacked symbol bindings might have been replaced with other backed or
# unbacked expressions.
#
# Example:
# u = x.item()
# torch._check(u == 5)
#
# The safest approach is to retrieve raw_u1 from u1.node._expr
# and perform the rebinding on the original unbacked symbol,
# even if it's no longer directly referenced.
#
# In other words, we should always rebind the original symbol
# before any replacements are applied.
# u0 -> u0 == s1
raw_u1 = u1.node._expr
# TODO Do we still need this logic below?
# Simplify SymBool binding
if (
isinstance(raw_u1, sympy.Piecewise)
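To illustrate the scenario the comment describes (a hedged sketch, not part of the PR): an unbacked symbol produced by `.item()` can acquire a replacement once `torch._check` pins it down, which is why rebinding has to start from the original `u1.node._expr` rather than the post-replacement expression.

```python
import torch
import torch._dynamo as dynamo

dynamo.config.capture_scalar_outputs = True  # let .item() become an unbacked SymInt

@torch.compile(fullgraph=True)
def f(x):
    u = x.item()            # unbacked symbol, e.g. u0
    torch._check(u == 5)    # the shape env may now replace u0 downstream
    return torch.zeros(u)   # after the check this is effectively zeros(5)

print(f(torch.tensor(5)).shape)
```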

View File

@ -648,6 +648,15 @@ class CodeGen:
if verbose:
# override annotation with more detailed information
try:
from torch.distributed.tensor._api import DTensor, DTensorSpec
dtensorspec_format_shard_order_str = (
DTensorSpec.format_shard_order_str
)
except ModuleNotFoundError:
DTensor = None # type: ignore[assignment,misc]
dtensorspec_format_shard_order_str = None
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
@ -678,6 +687,16 @@ class CodeGen:
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
elif type(meta_val) is DTensor:
assert dtensorspec_format_shard_order_str is not None
dtensor_meta = dtensorspec_format_shard_order_str(
meta_val._spec.placements, # type: ignore[attr-defined]
meta_val._spec.shard_order, # type: ignore[attr-defined]
)
cls = meta_val.__class__.__name__
maybe_type_annotation = (
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
)
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'

View File

@ -165,6 +165,7 @@ def insert_deferred_runtime_asserts(
node: torch.fx.Node,
stack_trace: Optional[str] = None,
nn_module_stack: Optional[dict[str, Any]] = None,
custom: Optional[dict[str, Any]] = None,
) -> None:
fake_args = pytree.tree_map(
lambda arg: (
@ -188,6 +189,8 @@ def insert_deferred_runtime_asserts(
node.meta["stack_trace"] = stack_trace
if nn_module_stack is not None:
node.meta["nn_module_stack"] = nn_module_stack
if custom is not None:
node.meta["custom"] = custom
# Track asserts/checks we've added
added_asserts: set[sympy.Expr] = set()
@ -617,6 +620,9 @@ def insert_deferred_runtime_asserts(
_node_metadata_hook,
stack_trace=node.meta.get("stack_trace"),
nn_module_stack=node.meta.get("nn_module_stack"),
# nodes added in `apply_runtime_assertion_pass` will have the same annotation
# as the input node to the assertion
custom=node.meta.get("custom"),
),
):
if (min_val := convert(vr.lower)) is not None:

View File

@ -210,7 +210,8 @@ class _KinetoProfile:
def start_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.start()
assert self.profiler is not None
if self.profiler is None:
raise AssertionError("Profiler must be initialized before starting trace")
self.profiler._start_trace()
if self.profile_memory:
@ -256,7 +257,8 @@ class _KinetoProfile:
def stop_trace(self) -> None:
if self.execution_trace_observer:
self.execution_trace_observer.stop()
assert self.profiler is not None
if self.profiler is None:
raise AssertionError("Profiler must be initialized before stopping trace")
self.profiler.__exit__(None, None, None)
def export_chrome_trace(self, path: str):
@ -264,7 +266,10 @@ class _KinetoProfile:
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
last cycle in schedule is exported.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before exporting chrome trace"
)
if path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
fp.close()
@ -284,7 +289,8 @@ class _KinetoProfile:
path (str): save stacks file to this location;
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
"""
assert self.profiler
if self.profiler is None:
raise AssertionError("Profiler must be initialized before exporting stacks")
return self.profiler.export_stacks(path, metric)
def toggle_collection_dynamic(
@ -316,7 +322,7 @@ class _KinetoProfile:
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
"""
if not self.profiler:
if self.profiler is None:
return
self.profiler.toggle_collection_dynamic(enable, activities)
@ -333,7 +339,10 @@ class _KinetoProfile:
To use shape/stack functionality make sure to set record_shapes/with_stack
when creating profiler context manager.
"""
assert self.profiler
if self.profiler is None:
raise AssertionError(
"Profiler must be initialized before getting key averages"
)
return self.profiler.key_averages(
group_by_input_shape, group_by_stack_n, group_by_overload_name
)
@ -343,7 +352,8 @@ class _KinetoProfile:
Returns the list of unaggregated profiler events,
to be used in the trace callback or after the profiling is finished
"""
assert self.profiler
if self.profiler is None:
raise AssertionError("Profiler must be initialized before accessing events")
return self.profiler.function_events
def add_metadata(self, key: str, value: str) -> None:
@ -395,7 +405,10 @@ class _KinetoProfile:
if missing:
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
assert self.profiler is not None and self.profiler.kineto_results is not None
if self.profiler is None or self.profiler.kineto_results is None:
raise AssertionError(
"Profiler and kineto_results must be initialized for memory profiling"
)
return MemoryProfile(self.profiler.kineto_results)
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
@ -485,7 +498,8 @@ def schedule(
"""
def schedule_fn(step: int) -> ProfilerAction:
assert step >= 0
if step < 0:
raise AssertionError(f"Step must be non-negative. Got {step}.")
if step < skip_first:
return ProfilerAction.NONE
else:
@ -508,9 +522,11 @@ def schedule(
else ProfilerAction.RECORD_AND_SAVE
)
assert (
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
), "Invalid profiler schedule arguments"
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
raise AssertionError(
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
)
if warmup == 0:
warn(
"Profiler won't be using warmup, this can skew profiler results",
@ -717,7 +733,8 @@ class profile(_KinetoProfile):
activities_set.add(ProfilerActivity.CUDA)
elif ProfilerActivity.CUDA in activities_set:
activities_set.remove(ProfilerActivity.CUDA)
assert len(activities_set) > 0, "No valid profiler activities found"
if len(activities_set) == 0:
raise AssertionError("No valid profiler activities found")
super().__init__(
activities=activities,
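A small usage sketch exercising the arguments validated above: a schedule needs `wait`/`warmup`/`repeat`/`skip_first` >= 0 and `active` > 0, and the profiler needs at least one valid activity.

```python
from torch.profiler import ProfilerActivity, profile, schedule

sched = schedule(wait=1, warmup=1, active=3, repeat=2, skip_first=1)
with profile(activities=[ProfilerActivity.CPU], schedule=sched) as prof:
    for _ in range(12):
        # ... one training/inference step ...
        prof.step()
```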

View File

@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
raise TypeError("ndigits must be an instance of sympy.Integer")
return f"round({self._print(number)}, {ndigits})"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
result: Optional[str] = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self._print(expr_i)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self._print(cond_i)
if result is None:
result = expr_str
else:
result = f"({expr_str} if {cond_str} else {result})"
return result if result else "0"
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr: sympy.Expr) -> str:
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
)
return f"{c} ? {p} : {q}"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
result: Optional[str] = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
if result is None:
result = expr_str
else:
result = f"{cond_str} ? {expr_str} : {result}"
return f"({result})" if result else "0"
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
x, div, mod = expr.args
x = self.doprint(x)
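A standalone sketch of the Piecewise-to-nested-ternary rewrite added above (`piecewise_to_ternary` is a hypothetical helper mirroring `PythonPrinter._print_Piecewise`, not part of the PR):

```python
import sympy

def piecewise_to_ternary(expr: sympy.Piecewise) -> str:
    # Walk (expr, cond) pairs in reverse; the trailing True branch becomes the default.
    result = None
    for e, c in reversed(expr.args):
        if c == True:  # noqa: E712  (sympy BooleanTrue)
            result = str(e)
        else:
            result = str(e) if result is None else f"({e} if {c} else {result})"
    return result if result else "0"

x = sympy.Symbol("x")
pw = sympy.Piecewise((x + 1, x > 0), (x - 1, x < 0), (0, True))
print(piecewise_to_ternary(pw))
# (x + 1 if x > 0 else (x - 1 if x < 0 else 0))
```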