Compare commits

..

35 Commits

Author SHA1 Message Date
2913cdf29d Update 2025-11-05 07:45:47 -08:00
0661a232a5 Update 2025-11-04 20:06:01 -08:00
5db844dafa Move back 2025-11-04 19:00:42 -08:00
73efad99d7 Update 2025-11-04 18:14:52 -08:00
df1268c311 Make the printed report clearer 2025-11-04 15:28:16 -08:00
84f9f1541d Test that make coverage works 2025-11-04 11:24:38 -08:00
27c0c126bf Update 2025-11-04 10:34:47 -08:00
670873155a Update 2025-11-04 10:34:47 -08:00
923737c510 Update 2025-11-04 10:34:47 -08:00
13d5b14a73 Update 2025-11-04 10:34:47 -08:00
a35a42b21c Update 2025-11-04 10:34:47 -08:00
15956bc1e8 Update 2025-11-04 10:34:47 -08:00
b319ea1111 Change python doc push script to print the undocumented modules 2025-11-04 10:34:47 -08:00
ce4c68a5f6 Update 2025-11-04 10:34:47 -08:00
c6da4a59a3 Test 2025-11-04 10:34:47 -08:00
53f75cd5ba Fixed some syntax errors in SECURITY.md file. (#166718)
Fixed some syntax errors in the SECURITY.md file, including PyTorch capitalization problems, some grammatical inconsistencies, etc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718
Approved by: https://github.com/mikaylagawarecki
2025-11-04 18:18:38 +00:00
527b1109a8 Delete deprecated fp32 precision warnings (#166956)
The deprecation warning led to warning spam in PyTorch APIs such as
torch.compile. That is not how a deprecation should be rolled out: if we
add a deprecation warning, we should also update our built-in APIs to
prevent the warning spam.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166956
Approved by: https://github.com/albanD
2025-11-04 17:50:04 +00:00
clr
3144713325 subproc_pool: Add support for enabling quiesce via a timer (#166467)
This adds the capability for the subproc pool to enable quiescing via a timer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166467
Approved by: https://github.com/masnesral
2025-11-04 17:37:41 +00:00
eefa16342c [Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)
Prefer unfused addmm when there is at least one elementwise/reduction consumer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165
Approved by: https://github.com/eellison
2025-11-04 17:23:04 +00:00
d02f68f484 [BE] Use [[maybe_unused]] (#166865)
Use it instead of the `(void) foo; // Unused parameter` trick, since `[[maybe_unused]]` is a standard C++17 feature.

Further repetitions of the same pattern will be replaced in follow-ups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166865
Approved by: https://github.com/mikaylagawarecki, https://github.com/Skylion007, https://github.com/janeyx99
2025-11-04 17:08:28 +00:00
68eb55c4b2 Add model code stack trace to cuda.memory._snapshot (#166676)
We store a mapping between generated FX graph code and the original model code stack trace in `fx.traceback._FX_METADATA_REGISTRY`, and we post-process the memory snapshot to append the original model stack trace information.

To achieve this, the biggest change we had to make in `aot_eager` mode is to give each generated FX graph a unique stack trace, i.e. it cannot just be `<eval_with_key>`. We set `co_filename` to **pretend** that the code comes from the `co_filename` file. Now, instead of `<eval_with_key>` in the stack trace, we get something like `fx_generated_3a4b5c6d7e8f9a0.py`.
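
As a standalone illustration of that mechanism (not the PR's code; the file name below is just the example name from above), compiling generated source under a synthetic filename makes every resulting code object and traceback frame report that name:

```python
# Minimal sketch: compile() stamps the given filename onto every code object it
# produces, so tracebacks and introspection report the synthetic file instead of
# a placeholder like "<eval_with_key>".
src = "def f():\n    return 1\n"
code = compile(src, "fx_generated_3a4b5c6d7e8f9a0.py", "exec")

ns: dict = {}
exec(code, ns)
print(ns["f"].__code__.co_filename)  # -> fx_generated_3a4b5c6d7e8f9a0.py
```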

An `augment_with_fx_traces` arg is added to `torch.cuda.memory._snapshot` and `_dump_snapshot`. When the arg is set to `True`, a post-processing pass populates the original model stack trace into the snapshot frames.

The new behavior of GraphModule can be controlled by `TORCH_ENRICH_RPOFILER_STACK_TRACE` or `_dynamo.config.enrich_profiler_metadata=True`.

Alternative:

Instead of setting `co_filename`, we could also do it as shown below.
Note that if we do it this way, we will need to dump the file to keep the graph module torch-scriptable: TorchScript requires source access in order to carry out compilation, so we need to make sure the original `.py` files are available.
```
        key = filename
        globals_copy = globals.copy()
        globals_copy["__file__"] = key
        globals_copy["__name__"] = key
        linecache.lazycache(key, globals_copy)
        exec(compile(src, key, "exec"), globals)
```

Other changes:

- Update `MemoryViz.js` to display FX node information and the original model code if they exist

```
python test/test_fx.py -k test_lineno_map
python test/test_fx.py -k test_custom_traceback_raised
python test/test_public_bindings.py
python test/test_cuda.py -k test_fx_memory
python test/test_fx.py -k test_informative_co_filename
python test/test_fx.py -k test_autowrap_functions
python test/dynamo/test_utils.py -k test_inductor_provenance
```

```python
# Profile with memory snapshot
torch.cuda.memory._record_memory_history()

with torch._dynamo.config.patch("enrich_profiler_stack_trace", True):
    compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
    result = compiled(torch.randn(10, 10, device="cuda:0"))

torch.cuda.memory._dump_snapshot("memory_snapshot.pickle", augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None)
```

<img width="913" height="711" alt="Screenshot 2025-10-30 at 10 40 44 AM" src="https://github.com/user-attachments/assets/8d7a1833-f98d-4756-b666-1d63ab57b27b" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166676
Approved by: https://github.com/albanD, https://github.com/ezyang
2025-11-04 17:01:02 +00:00
8d4b8ab430 [ez] Print some more test timing info in the logs (#166447)
You can just subtract timestamps, but this makes it easier
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166447
Approved by: https://github.com/Skylion007
2025-11-04 16:45:22 +00:00
afd50bdd29 [CI] Use smaller amx + avx2 runners for inductor test? (#164989)
Results from CI:
No failures, but runs generally take longer, roughly a ~20% increase in time.
However, the smaller runner costs ~25% as much as the current runner, so in terms of cost this is still a decrease.

If the ~20% is too much, we can try the 4x larger runners, which are about half the cost of the current runner; that would probably still yield cost savings with less impact on time.
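
A rough back-of-the-envelope check of that claim, using the approximate ratios quoted above (illustrative numbers, not measurements):

```python
# Approximate ratios from the commit message, normalized to the current runner.
small_runner_cost = 0.25   # ~25% of the current runner's cost per hour
small_runner_time = 1.20   # ~20% longer wall-clock time

relative_cost = small_runner_cost * small_runner_time
print(f"relative CI cost on the smaller runner: {relative_cost:.2f}")  # ~0.30
```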

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164989
Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn
2025-11-04 16:43:06 +00:00
56dfd4c74b Add CUDA MXFP4 scaled mm support via FBGEMM (#166526)
Summary:

* Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA
* Add testing

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526
Approved by: https://github.com/drisspg, https://github.com/ngimel
2025-11-04 15:53:16 +00:00
24db5c4451 [inductor] do not hard fail on FakePG with nccl estimator (#166869)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166869
Approved by: https://github.com/eellison
ghstack dependencies: #166521
2025-11-04 15:22:38 +00:00
cc8bfd1206 Docker release build: Use 13.0.0 nvidia docker (#166904)
Forward fix for failing Docker release builds
Related to: https://github.com/pytorch/pytorch/issues/166897

Nightly Docker build failure https://github.com/pytorch/pytorch/actions/runs/18900508440/job/53946606434
Due to missing base image:
```
ERROR: failed to build: failed to solve: docker.io/nvidia/cuda:13.0.2-devel-ubuntu22.04: not found
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166904
Approved by: https://github.com/tinglvv, https://github.com/malfet
2025-11-04 13:58:10 +00:00
c45b156605 Fix DeepSeek scaling tensor handling (#166752)
Summary:

cuBlasLt enforces size/stride requirements for 1x128 and 128x128 blockwise scaling kernels, some of which weren't being handled, causing silently incorrect answers, especially in the 128x128 scaling cases.

For DeepSeek-style scaling, cuBlasLt enforces ([docs](https://docs.nvidia.com/cuda/cublas/#scaling-factors-layouts)) the following for `A: MxK`, `B: KxN`:

```Py
L = K // 128
L4 = round_up(L, 4)

1x128 x 128x128:
* A_scale: [M, K // 128], stride: [1, M]
* B_scale: [L4, N // 128], stride: [1, L4]

128x128 x 1x128:
* A_scale: [L4, M // 128], stride: [1, L4]
* B_scale: [N, K // 128], stride: [1, N]

1x128 x 1x128:
* A_scale: [M, K // 128], stride: [1, M]
* B_scale: [N, K // 128], stride: [1, N]
```

Notable here is the `L4` term, which means we must round up to the nearest multiple of 4 blocks in the `K` dimension. This wasn't enforced previously and caused silently wrong answers when `(K // 128) % 4 != 0`.
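
As a small worked example (illustrative shapes only, not code from the PR), here is what the `L4` rounding implies for the 128x128 x 1x128 case when `K // 128` is not a multiple of 4:

```python
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

M, K, N = 256, 384, 512          # K // 128 == 3, not a multiple of 4
L = K // 128                     # 3
L4 = round_up(L, 4)              # 4 -- the padding that was previously missed

A_scale_shape = (L4, M // 128)   # (4, 2), expected stride (1, L4) == (1, 4)
B_scale_shape = (N, K // 128)    # (512, 3), expected stride (1, N) == (1, 512)
print(A_scale_shape, B_scale_shape)
```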

@vkuzo

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166752
Approved by: https://github.com/drisspg, https://github.com/vkuzo
2025-11-04 13:32:24 +00:00
8fff7e36b4 [xpu][test] Add UT for expandable segments (#166495)
# Motivation
This PR aims to reuse some UTs to validate the expandable segments feature.

# Additional Context
Currently, the failure is related to the internal tracker `GSD-11403`; the fix should be available once the driver is upgraded to `ci-neo-master-034630` or greater.
TODO: add conv and gemm tests to this test case when the driver is upgraded.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166495
Approved by: https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
ghstack dependencies: #166299, #166292, #166424
2025-11-04 08:01:35 +00:00
82fa2aa269 DTensor: Fix trivial as_strided case, add alias support (#166867)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166867
Approved by: https://github.com/albanD
ghstack dependencies: #166868
2025-11-04 07:18:32 +00:00
09e0285608 [xpu][feature][inductor] Enable decompose_mm_pass and UT on Intel GPU (#166613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166613
Approved by: https://github.com/hl475
2025-11-04 06:58:05 +00:00
d980d8dc79 [dynamo] Implement __sym_float__ for SymBool to fix multiplication TypeError (#165264)
Fixes #164684

### Description

Symbolic tracing fails during multiplication between a `SymBool` and a `Tensor`. This scenario is triggered when `.item()` is called on a 0-dim boolean tensor within a `torch.compile` region. In compile mode, this yields a `SymBool`, and the subsequent `SymBool * FakeTensor` operation is unsupported, leading to a `TypeError` or a data-dependent `UserError`.

### Solution

This PR addresses the issue at the type-conversion level, as suggested by reviewers.

The root cause of the `TypeError` is that `torch.sym_float()` (called by `_maybe_convert_to_dtype` during type promotion for `aten.mul`) lacks a conversion path for `SymBool` and incorrectly falls back to `builtins.float(SymBool)`.

The fix implements the `__sym_float__(self)` method on the `SymBool` class (defined in `torch/__init__.py`).

The `torch.sym_float(a)` utility is already designed to check for `hasattr(a, "__sym_float__")` before falling back to `builtins.float()`. By adding this method, `SymBool` instances now correctly advertise that they can be cast to `SymFloat`. The implementation leverages `self.node.sym_float()` to convert the symbolic boolean value to its symbolic float representation (0.0 or 1.0), resolving the `TypeError` at its source.

This approach is more fundamental than modifying a specific operation in `builtin.py`, and it ensures `SymBool` can be promoted to `SymFloat` in any operation while still preserving its boolean nature for control-flow operations like `guard_or_false` (verified by a new test case).
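
A minimal sketch of the protocol described above, with simplified stand-ins rather than the PR's actual classes: `torch.sym_float(a)` prefers an `__sym_float__` hook over `builtins.float(a)`, so giving `SymBool` that hook lets it promote cleanly to a symbolic float.

```python
class FakeSymNode:
    """Toy stand-in for the symbolic node backing a SymBool."""
    def __init__(self, value: bool) -> None:
        self.value = value

    def sym_float(self) -> float:
        # A real SymNode would return a symbolic float; 0.0/1.0 suffices here.
        return 1.0 if self.value else 0.0

class ToySymBool:
    def __init__(self, node: FakeSymNode) -> None:
        self.node = node

    def __sym_float__(self) -> float:
        # Mirrors the description above: delegate to the node's sym_float().
        return self.node.sym_float()

def sym_float(a):
    # Simplified dispatch order of torch.sym_float: hook first, then float().
    if hasattr(a, "__sym_float__"):
        return a.__sym_float__()
    return float(a)

print(sym_float(ToySymBool(FakeSymNode(True))))   # 1.0
print(sym_float(ToySymBool(FakeSymNode(False))))  # 0.0
```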

### Verification

1.  **Bug Reproduced**: The initial `UserError: Could not guard on data-dependent expression` was successfully reproduced with the script from the issue. As shown below
<img width="1369" height="945" alt="Screenshot 2025-10-13 at 10 29 05" src="https://github.com/user-attachments/assets/8daa4555-3347-4af5-906a-02150b8df9d1" />

2.  **Fix Validated**: After applying the code changes, the same script now runs to completion, printing ` eager success` and ` compile success`. As shown below
<img width="1228" height="82" alt="Screenshot 2025-10-13 at 10 29 21" src="https://github.com/user-attachments/assets/94c4f143-b898-4dda-9bff-0ad5450a30fa" />

3.  **Tests Added**: Added a new test class `DynamoOpPromotionTests` to `test/dynamo/test_misc.py` with three new test cases:
    1. `test_symbool_tensor_mul_does_not_fail`: verifies that the original bug report code (with `.item()` + `*`) no longer raises an error when compiled.
    2. `test_symbool_guard_or_false`: verifies that this fix does not cause a regression for `guard_or_false(SymBool)` (the concern raised by reviewers).
    3. `test_symbool_tensor_mul`: verifies the behavior of `Tensor(bool) * Tensor(float)` (without `.item()`) for completeness.

    All new tests pass locally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165264
Approved by: https://github.com/laithsakka, https://github.com/Lucaskabela
2025-11-04 06:33:20 +00:00
c7d00de115 [xpu][fix] Fix XPU oneDNN memory query bug: pointer to array (#166830)
# Motivation

I believe this is a bug - here's why:
In [dnnl_common_types.h](98132c4908/include/oneapi/dnnl/dnnl_common_types.h (L116-L125)), `md_padded_dims` is defined as a pointer to an `int64_t[12]` array.
We can confirm this from the implementation in [memory_desc.cpp](98132c4908/src/common/memory_desc.cpp (L746-L748)), where the member indeed points to an internal array.

# Solution

Therefore, when accessing `md_padded_dims`, we should first dereference the pointer and then index into it; indexing the pointer directly without dereferencing accesses the wrong memory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166830
Approved by: https://github.com/EikanWang
2025-11-04 06:12:40 +00:00
d3cf90ada5 Revert "[inductor] require shape in TritonCSEVariable (#162275)"
This reverts commit c21868b4359586550b12e1d9102283c792f45dff.

Reverted https://github.com/pytorch/pytorch/pull/162275 on behalf of https://github.com/izaitsevfb due to breaking test_rms_norm_bwd_float32_split_reductions_True_shape2 ([comment](https://github.com/pytorch/pytorch/pull/162275#issuecomment-3484049109))
2025-11-04 06:06:18 +00:00
0e1a88904f [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#165036)
Make sure you're on cutlass 4.2.0+

Test Plan:
Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16  --num-inputs 1 --metrics tflops,accuracy`

Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`

Differential Revision: D82010227

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165036
Approved by: https://github.com/alexsamardzic, https://github.com/drisspg, https://github.com/mlazos
2025-11-04 05:58:58 +00:00
3232caa078 [XPU][Fix] Register convolution_overrideable for flops count (#166839)
Fixes #166838
1. Register the `convolution_overrideable` key for the flop counter. CUDA relies on keys like `cudnn_convolution`; devices like `XPU` fall back to `convolution_overrideable`. Without the correct registration, the flop counter silently returns 0 for XPU at the line below (see the sketch after this list):
e1d011d6eb/torch/_inductor/analysis/profile_analysis.py (L178-L179)

2. Enable the corresponding tests when enabling XPU in `test_analysis.py`.
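
For context, a minimal sketch of how the silent zero shows up (this assumes an XPU build and an available XPU device; `FlopCounterMode` is the standard utility in `torch.utils.flop_counter`):

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# On XPU, the convolution dispatches through aten::convolution_overrideable.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).to("xpu")
x = torch.randn(1, 3, 32, 32, device="xpu")

with FlopCounterMode(display=False) as counter:
    conv(x)

# Without a formula registered for convolution_overrideable, this total was
# silently 0; with the registration it reflects the convolution's FLOPs.
print(counter.get_total_flops())
```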

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166839
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/jansel
2025-11-04 05:56:29 +00:00
57 changed files with 2353 additions and 680 deletions

View File

@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.2.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.2.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0
exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7
docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1

View File

@ -89,23 +89,41 @@ if [ "$is_main_doc" = true ]; then
make coverage
# Now we have the coverage report, we need to make sure it is empty.
# Count the number of lines in the file and turn that number into a variable
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
# Skip the report header by subtracting 2: the header will be output even if
# there are no undocumented items.
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
# showing the undocumented count in the third column.
# Example: | TOTAL | 99.83% | 2 |
#
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
# be documented then removed from there.
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
undocumented=$((lines - 2))
if [ $undocumented -lt 0 ]; then
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
# The table format is: | Module | Coverage | Undocumented |
# Extract the third column (undocumented count) from the TOTAL row
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
cat build/coverage/python.txt
elif [ "$undocumented" -gt 0 ]; then
set +x # Disable command echoing for cleaner output
echo ""
echo "====================="
echo "UNDOCUMENTED OBJECTS:"
echo "====================="
echo ""
# Find the line number of the TOTAL row and print only what comes after it
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
if [ -n "$total_line" ]; then
# Print only the detailed list (skip the statistics table)
tail -n +$((total_line + 2)) build/coverage/python.txt
else
# Fallback to showing entire file if TOTAL line not found
cat build/coverage/python.txt
fi
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
set -x # Re-enable command echoing
exit 1
fi
else

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

View File

@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
"12.6": "12.6.3",
"12.8": "12.8.1",
"12.9": "12.9.1",
"13.0": "13.0.2",
"13.0": "13.0.0",
}
CUDA_ARCHES_CUDNN_VERSION = {
"12.6": "9",

View File

@ -8,6 +8,7 @@ on:
- docker.Makefile
- .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py
- .github/scripts/generate_binary_build_matrix.py
push:
branches:
- nightly

View File

@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
]}
secrets: inherit

View File

@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"

.gitignore vendored
View File

@ -127,6 +127,7 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
### Using distributed features

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

View File

@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {
namespace {
/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@ -41,16 +39,6 @@ namespace {
->rnn
*/
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;
@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}
@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}
@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}
@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;

View File

@ -59,6 +59,24 @@
// forward declare
class cublasCommonArgs;
#ifndef _WIN32
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
@ -767,33 +785,6 @@ _scaled_rowwise_rowwise(
return out;
}
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}
void
_check_deepseek_support() {
#ifndef USE_ROCM
@ -806,7 +797,7 @@ _check_deepseek_support() {
}
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
@ -823,23 +814,61 @@ _scaled_block1x128_block1x128(
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
// As: [M x K // 128], stride: [1, M]
// Bs: [N x K // 128], stride: [1, N]
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
// check types
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
),
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
);
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -861,24 +890,65 @@ _scaled_block128x128_block1x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
// Bs: [N x K // 128], stride: [1, N]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
);
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -900,24 +970,62 @@ _scaled_block1x128_block128x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [M x K // 128], stride: [1, M]
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
int64_t M = mat_a.size(0);
int64_t K = mat_a.size(1);
int64_t N = mat_b.size(1);
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -997,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
@ -1031,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
#endif
}
Tensor&
@ -1160,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
mat_a.sizes()[1] % 16 == 0,
K_multiplier * mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
mat_a.sizes()[1],
K_multiplier * mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.

View File

@ -157,10 +157,10 @@ bool onednn_strides_check(const Tensor& src) {
return true;
dnnl_dims_t blocks = {0};
int perm[DNNL_MAX_NDIMS] = {0};
std::array<int, DNNL_MAX_NDIMS> perm = {0};
for (int d = 0; d < md_ndims; ++d) {
// no strides check needed for empty tensor
if (md_padded_dims[d] == nullptr)
if ((*md_padded_dims)[d] == 0)
return true;
// no strides verification for runtime dims
@ -178,14 +178,15 @@ bool onednn_strides_check(const Tensor& src) {
// A custom comparator to yield linear order on perm
auto idx_sorter = [&](const int a, const int b) -> bool {
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
if (strides[a] == strides[b] &&
(*md_padded_dims)[a] == (*md_padded_dims)[b])
return a < b;
else if (strides[a] == strides[b])
return md_padded_dims[a] < md_padded_dims[b];
return (*md_padded_dims)[a] < (*md_padded_dims)[b];
else
return strides[a] < strides[b];
};
std::sort(perm, perm + md_ndims, idx_sorter);
std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter);
auto min_stride = block_size;
for (int idx = 0; idx < md_ndims; ++idx) {
@ -199,9 +200,10 @@ bool onednn_strides_check(const Tensor& src) {
return false;
// update min_stride for next iteration
const auto padded_dim = *md_padded_dims[d];
const auto padded_dim = (*md_padded_dims)[d];
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
}
return true;
}

View File

@ -206,6 +206,41 @@ templates_path = [
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
@ -3195,6 +3230,11 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,7 +253,6 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1616,6 +1647,8 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1649,6 +1682,7 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -1019,6 +1019,28 @@ class DTensorMeshTest(DTensorTestBase):
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
@with_comms
def test_as_strided_identity(self):
# Test calling as_strided with the same size/stride/offset as input tensor
# This should be a no-op but currently fails
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 4, device=self.device_type)
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)
# Get the current size, stride, and storage_offset
size = dtensor.size()
stride = dtensor.stride()
storage_offset = dtensor.storage_offset()
# Call as_strided with the exact same parameters
result = dtensor.as_strided(size, stride, storage_offset)
# The result should be identical to the input
self.assertEqual(result.size(), dtensor.size())
self.assertEqual(result.stride(), dtensor.stride())
self.assertEqual(result.to_local(), dtensor.to_local())
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
DTensorMeshTest,

View File

@ -8,11 +8,21 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -30,7 +40,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)[0:3]
return extract_graph(fn, *args, **kwargs)
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -69,6 +69,7 @@ from torch.fx.experimental.symbolic_shapes import (
constrain_unify,
ConstraintViolationError,
expect_true,
guard_or_false,
guard_size_oblivious,
ShapeEnv,
)
@ -100,7 +101,6 @@ from torch.testing._internal.common_utils import (
wrapDeterministicFlagAPITest,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string
pytree_modules = {
@ -13636,6 +13636,74 @@ instantiate_device_type_tests(
)
class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul(self):
def symbool_mul_fn(x_bool, sentinel):
result = x_bool * sentinel
return result
x_true = torch.tensor([True], device="cuda")
x_false = torch.tensor([False], device="cuda")
sentinel = torch.tensor(2.0, requires_grad=True, device="cuda")
eager_result_true = symbool_mul_fn(x_true, sentinel)
eager_result_false = symbool_mul_fn(x_false, sentinel)
compiled_fn = torch.compile(symbool_mul_fn, fullgraph=True, dynamic=True)
compiled_result_true = compiled_fn(x_true, sentinel)
compiled_result_false = compiled_fn(x_false, sentinel)
self.assertEqual(eager_result_true, compiled_result_true)
self.assertEqual(eager_result_false, compiled_result_false)
self.assertEqual(compiled_result_true.item(), 2.0)
self.assertEqual(compiled_result_false.item(), 0.0)
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_guard_or_false(self):
def symbool_guard_fn(a_bool_tensor, b):
u0 = a_bool_tensor.item()
# Make sure guard_or_false still handles SymBool produced by .item()
if guard_or_false(u0):
return b * 10
else:
return b * 100
compiled_guard_fn = torch.compile(
symbool_guard_fn, backend="eager", dynamic=True
)
a_true = torch.tensor(True, device="cuda")
a_false = torch.tensor(False, device="cuda")
b = torch.randn(6, device="cuda")
eager_res_true = symbool_guard_fn(a_true, b)
compiled_res_true = compiled_guard_fn(a_true, b)
self.assertEqual(eager_res_true, compiled_res_true)
eager_res_false = symbool_guard_fn(a_false, b)
compiled_res_false = compiled_guard_fn(a_false, b)
self.assertEqual(eager_res_false, compiled_res_false)
self.assertEqual(compiled_res_true, b * 10)
self.assertEqual(compiled_res_false, b * 100)
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul_does_not_fail(self):
def fuzzed_program(arg_0, sentinel):
var_node_2 = arg_0
var_node_1 = torch.squeeze(var_node_2)
var_node_0 = var_node_1.item()
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result
sentinel = torch.tensor(1.0, requires_grad=True, device="cuda")
arg_0 = torch.tensor([True], dtype=torch.bool, device="cuda")
args = (arg_0,) + (sentinel,)
try:
compiled_program = torch.compile(
fuzzed_program, fullgraph=True, dynamic=True
)
compiled_program(*args)
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1,13 +1,11 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@ -17,14 +15,6 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -46,7 +36,9 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y, s1, s2):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
@ -55,36 +47,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@ -101,16 +70,9 @@ class <lambda>(torch.nn.Module):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@ -193,248 +155,22 @@ class <lambda>(torch.nn.Module):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
pass
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
pass
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
pass
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
pass
@requires_cuda
def test_run_opcheck(self):

View File

@ -20,8 +20,14 @@ from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
skipIf,
skipXPUIf,
)
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
TEST_WITH_SLOW,
TestCase,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU
@ -382,7 +388,11 @@ class TestAnalysis(TestCase):
verify_triton(comp_omni)
@skipIf(not SM80OrLater, "Requires SM80")
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
@skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU")
@dtypes(torch.float, torch.float16)
@parametrize(
"maxat",
@ -467,6 +477,7 @@ class TestAnalysis(TestCase):
"aten::cudnn_convolution",
"aten::convolution",
"aten::_convolution",
"aten::convolution_overrideable",
)
)
or "conv" in name

View File

@ -4,6 +4,7 @@ import os
import tempfile
from threading import Event
import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,
@ -16,9 +17,12 @@ from torch.testing._internal.inductor_utils import HAS_CPU
class TestCompileWorker(TestCase):
def make_pool(self, size):
return SubprocPool(size)
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_basic_jobs(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
b = pool.submit(operator.sub, 100, 1)
@ -29,7 +33,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_exception(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(raise_testexc)
with self.assertRaisesRegex(
@ -42,7 +46,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_crash(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
with self.assertRaises(Exception):
a = pool.submit(os._exit, 1)
@ -58,7 +62,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()
@ -75,7 +79,7 @@ class TestCompileWorker(TestCase):
os.environ["ROLE_RANK"] = "0"
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
pool.submit(operator.add, 100, 1)
self.assertEqual(os.path.exists(temp_log.name), True)
@ -83,6 +87,12 @@ class TestCompileWorker(TestCase):
pool.shutdown()
@config.patch("quiesce_async_compile_time", 0.1)
class TestCompileWorkerWithTimer(TestCompileWorker):
def make_pool(self, size):
return SubprocPool(size, quiesce=True)
class TestTimer(TestCase):
def test_basics(self):
done = Event()

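A minimal usage sketch of the quiesce-enabled pool these tests exercise; illustrative only and not part of the diff, assuming SubprocPool.submit returns a standard Future as the tests imply:

import operator

import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import SubprocPool

# With quiesce=True and a short quiesce_async_compile_time, the pool parks its
# workers after a period of inactivity; submitting new work wakes it back up.
with config.patch("quiesce_async_compile_time", 0.1):
    pool = SubprocPool(2, quiesce=True)
    try:
        fut = pool.submit(operator.add, 100, 1)
        assert fut.result() == 101
    finally:
        pool.shutdown()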
View File

@ -0,0 +1,154 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()
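For orientation, a rough eager-mode sketch of the input layout this test feeds to the grouped GEMM. torch._grouped_mm is a private op and, like the compiled CuTeDSL path above, needs recent GPU support, so treat this as an assumption-laden illustration rather than a supported recipe:

import torch

G, K, N = 4, 64, 128
M_sizes = torch.tensor([128, 256, 128, 64])  # tile-aligned per-group row counts
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device="cuda")
A = torch.randn(int(M_sizes.sum()), K, device="cuda", dtype=torch.bfloat16)  # packed [sum(M), K]
B = torch.randn(G, K, N, device="cuda", dtype=torch.bfloat16)                # batched [G, K, N]
out = torch._grouped_mm(A, B, offs=offsets)  # output rows are partitioned according to `offsets`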

View File

@ -15,9 +15,8 @@ from torch.testing._internal.common_utils import (
is_navi3_arch,
parametrize,
patch_test_members,
TEST_XPU,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
from torch.testing._internal.triton_utils import requires_gpu
@ -61,11 +60,6 @@ class TestDecomposeAddMM(torch.nn.Module):
@requires_gpu
@unittest.skipIf(
TEST_XPU,
"Intel GPU has not enabled decompose_mem_bound_mm PASS in "
"torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
)
@torch._inductor.config.patch(
post_grad_fusion_options={
"decompose_mm_pass": {},
@ -144,7 +138,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
@ -155,7 +149,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 3 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
@ -204,7 +198,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -259,7 +253,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -304,7 +298,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
@ -316,7 +310,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,
@ -374,7 +368,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
@ -386,7 +380,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,
@ -410,7 +404,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -424,7 +418,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_gradients(module, traced)
expected_val = 0
if HAS_CUDA_AND_TRITON:
if HAS_GPU_AND_TRITON:
expected_val = 1 if has_bias else 2
self.assertEqual(
@ -447,12 +441,8 @@ class TestDecomposeMemMM(TestCase):
_, code = run_and_get_code(foo, input1, input2)
if GPU_TYPE == "xpu":
# only 1 kernel generated on the XPU stack
FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
else:
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
def test_check_device(self):
m = 5
@ -462,7 +452,7 @@ class TestDecomposeMemMM(TestCase):
input1 = torch.randn(m, k, device=GPU_TYPE)
input2 = torch.randn(k, n, device=GPU_TYPE)
self.assertTrue(check_device(input1, input2))
self.assertTrue(check_device(input1, input2, device=GPU_TYPE))
self.assertFalse(check_device(input1, input2, device="cpu"))
input1 = torch.randn(m, k)

View File

@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]
# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
softmax_load_str in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -15280,7 +15280,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_native_layer_norm_relu",
"triton_poi_fused_addmm_native_layer_norm",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@ -15293,7 +15293,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_ReLU",
"triton_poi_fused_LayerNorm_Linear_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

View File

@ -1826,9 +1826,14 @@ def run_test_module(
test_name = test.name
# Printing the date here can help diagnose which tests are slow
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
start = time.perf_counter()
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
handler = CUSTOM_HANDLERS.get(test_name, run_test)
return_code = handler(test, test_directory, options)
end = time.perf_counter()
print_to_stderr(
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
)
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
f"While running {str(test)} got non integer return code {return_code}"
)

View File

@ -7413,6 +7413,140 @@ class TestCudaDeviceParametrized(TestCase):
)
class TestFXMemoryProfiler(TestCase):
"""Tests for memory profiler augmentation with original stack traces."""
def collect_frames(
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
):
"""Collects all frames that has node metadata from a memory snapshot."""
# Collect all frames with FX metadata
fx_frames = []
# Check device traces for FX debug fields
if collect_device_traces and "device_traces" in augmented_snapshot:
for trace_list in augmented_snapshot["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
for frame in trace_entry["frames"]:
if isinstance(frame, dict):
# Check for FX debug fields
if "fx_node_op" in frame or "fx_node_name" in frame:
fx_frames.append(frame)
# Check segments/blocks for FX debug fields
if collect_segments and "segments" in augmented_snapshot:
for segment in augmented_snapshot["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
for frame in block["frames"]:
if isinstance(frame, dict):
if "fx_node_op" in frame or "fx_node_name" in frame:
fx_frames.append(frame)
return fx_frames
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""
# Create a simple model
class MLPModule(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
a = self.net1(x)
b = self.relu(a)
c = self.net2(b)
return c
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
torch.cuda.empty_cache()
fx_frames = self.collect_frames(augmented_snapshot)
if TEST_WITH_ROCM:
self.assertGreater(len(fx_frames), 0)
else:
self.assertEqual(len(fx_frames), 12)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
# Test that when we have two graphs with the same src_code, they're not hashed
# to the same metadata
class MLPModule2(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
d = self.net1(x)
e = self.relu(d)
f = self.net2(e)
return f
mod = MLPModule2(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
# Avoid collecting segments from the previous run, for unit-test purposes
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
self.assertGreater(len(fx_frames), 0)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_parametrized_tests(TestCompileKernel)
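Outside the test harness, the same augmentation flow looks roughly like the sketch below (assuming a CUDA device and the aot_eager backend used above); illustrative only:

import torch
import torch.nn as nn

torch._dynamo.config.enrich_profiler_metadata = True
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 10)).cuda()

torch.cuda.memory._record_memory_history()
compiled = torch.compile(model, backend="aot_eager", fullgraph=True)
compiled(torch.randn(10, 10, device="cuda"))
snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
# Frames in snapshot["device_traces"] / snapshot["segments"] can now carry
# fx_node_op, fx_node_name, fx_node_target and fx_original_trace entries.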

View File

@ -771,6 +771,7 @@ class TestFX(JitTestCase):
gm = GraphModule(tracer.root, graph)
expected = {1: 2, 2: 3, 3: 4, 4: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
self.assertEqual(gm._prologue_start, 4)
# test custom codegen
def transform_code(code):
@ -780,6 +781,7 @@ class TestFX(JitTestCase):
gm.recompile()
expected = {2: 2, 3: 3, 4: 4, 5: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
self.assertEqual(gm._prologue_start, 4)
def test_graph_unique_names_manual(self):
graph: torch.fx.Graph = torch.fx.Graph()

View File

@ -11,7 +11,7 @@ from typing import Optional
import torch
from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
from torch.testing._internal.common_cuda import (
IS_SM90,
_get_torch_cuda_version,
@ -107,11 +107,76 @@ def tensor_to_scale_block(
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = torch.finfo(float8_dtype).max / amax
# if amax == 0, entire block = 0, set scale = 0 to ensure elements are
# zero'd out correctly (and remove bad effects of / 0)
scale[amax == 0] = 0
# Scale x, noting that blocks where amax == 0 are explicitly 0 now.
x = x.mul(scale).to(float8_dtype)
# if amax == 0, all values in the block are 0, scale=0
# but we need scale.reciprocal later, which breaks when scale=0...
# So. Replace 0 -> 1 in the scale so we don't break things later.
# Elements are already zeroed, so don't actually care what the scale
# is, as long as it's not inf/nan.
scale[scale == 0] = 1.
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale
def hp_from_128x128(x_lp, x_scale):
orig_shape = x_lp.shape
M, K = orig_shape
x_lp = x_lp.view(M // 128, 128, K // 128, 128)
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
x_hp = x_lp.to(torch.float32)
x_hp = x_hp / x_scale
return x_hp.reshape(orig_shape).to(torch.bfloat16)
def hp_to_128x128(x, x_scale):
orig_shape = x.shape
M, K = orig_shape
x = x.view(M // 128, 128, K // 128, 128)
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
x_lp = x * x_scale
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
def hp_from_1x128(x_lp, x_scale):
orig_shape = x_lp.shape
x_lp = x_lp.reshape(x_lp.shape[0], x_lp.shape[-1] // 128, 128)
x_hp = x_lp.to(torch.float32)
x_hp = x_hp / x_scale.unsqueeze(-1)
return x_hp.reshape(orig_shape).to(torch.bfloat16)
def hp_to_1x128(x, x_scale):
orig_shape = x.shape
x = x.reshape(x.shape[0], x.shape[-1] // 128, 128)
x_lp = x * x_scale.unsqueeze(-1)
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
# cublas requires specific padding for 128x128 scales, see:
# https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
# Notably L = ceil_div(K, 128),
# L4 = round_up(L, 4),
# and then for A/B the shape must be
# scale: [L4, ceil_div({M,N}, 128)] and K/L/L4-major in memory.
#
# This routine pads L -> L4
def _pad_128x128_scales(scale: torch.Tensor) -> tuple[torch.Tensor, int]:
# scale is either [L4, ceil_div(M, 128)] or [L4, ceil_div(N, 128)], stride: [1, L4]
# However, we get passed it as [ceil_div(M, 128), L] or [ceil_div(N, 128), L]
# so check inner dim % 4, and pad if necessary
pad_amount = scale.shape[-1] % 4
if pad_amount == 0:
return scale, 0
else:
pad_amount = 4 - pad_amount
return pad(scale, (0, pad_amount), "constant", 0), pad_amount
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
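# Illustrative example (not part of the change): with K = 1280 the 128x128 scale
# has L = ceil_div(1280, 128) = 10 inner columns; 10 % 4 == 2, so two zero
# columns are appended and L4 = round_up(10, 4) = 12.
#   scale = torch.randn(3, 10)                       # ceil_div(M, 128) == 3 rows
#   padded, pad_amount = _pad_128x128_scales(scale)
#   assert padded.shape == (3, 12) and pad_amount == 2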
@ -144,42 +209,36 @@ def infer_scale_swizzle(mat, scale):
] == math.ceil(mat.shape[1] // 128):
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
# if we're checking for nvfp4, need to adjust for packed-K
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
# NVFP4
if (
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e4m3fn
):
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
# MXFP4 w/o swizzle
if (
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
# MX formats
if not torch.version.hip:
# MXFP8 w/ swizzle
# MX w/swizzle (NVIDIA)
if (
(scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
else:
# MXFP8 w/o swizzle
# MX w/o swizzle (AMD)
if (
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
@ -1252,7 +1311,6 @@ class TestFP8Matmul(TestCase):
else:
test()
# Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
@unittest.skipIf(
@ -1261,59 +1319,224 @@ class TestFP8Matmul(TestCase):
)
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
@parametrize("M,N,K", [(256, 768, 512)])
@with_tf32_off
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
@parametrize("M,N,K", [
# Nice size
(256, 768, 512),
# Requires padding for 128x128 scale
(384, 128, 1280),
# M=N=K for eyes test
(512, 512, 512),
])
@parametrize("test_case", [
"x_eye_b_eye",
"x_ones_y_ones_calc_scales",
"x_ones_y_ones_set_scales",
"x_ones_y_ones_modify_scales",
"data_random_scales_one",
"data_random_calc_scales",
])
def test_scaled_mm_block_wise_numerics(self, output_dtype, lhs_block, rhs_block, M, N, K, test_case):
"""
Subsumes test_scaled_mm_vs_emulated_block_wise with random inputs and random scales,
and adds some other functional tests as well.
# Inputs (as generated are):
# A: [M, K]
# B: [N, K]
# then scales are, for the 3 combinations:
# 1x128 x 1x128:
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
# Bs: [N, K // 128], stride: [1, N] -> scale.t().contiguous().t()
# 1x128 x 128x128
# L4 = round_up(K // 128, 4)
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
# Bs: [L4, N // 128], stride: [1, L4] -> scale.t()
# 128x128 x 1x128
# L4 = round_up(K // 128, 4)
# As: [L4, M // 128], stride: [1, L4]
# Bs: [N, K // 128], stride: [1, N]
"""
torch.manual_seed(42)
x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
def _adjust_lhs_scale(x_fp8, x_scales, lhs_block):
M, K = x_fp8.shape
x_scales_original = x_scales.clone()
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
lhs_recipe = ScalingType.BlockWise1x128
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
x_hp = hp_from_1x128(x_fp8, x_scales_original)
else:
lhs_recipe = ScalingType.BlockWise128x128
x_scales, pad_amount = _pad_128x128_scales(x_scales)
# scales in [M // 128, L4] -> [L4, M // 128]
x_scales = x_scales.t()
x_hp = hp_from_128x128(x_fp8, x_scales_original)
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
return x_hp, lhs_recipe, x_scales, x_scales_original
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
lhs_recipe = ScalingType.BlockWise1x128
def _adjust_rhs_scale(y_fp8, y_scales, rhs_block):
N, K = y_fp8.shape
y_scales_original = y_scales.clone()
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
rhs_recipe = ScalingType.BlockWise1x128
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
y_hp = hp_from_1x128(y_fp8, y_scales_original)
else:
rhs_recipe = ScalingType.BlockWise128x128
y_scales, pad_amount = _pad_128x128_scales(y_scales)
# Scale in [N // 128, L4] -> [L4, N // 128]
y_scales = y_scales.t()
y_hp = hp_from_128x128(y_fp8, y_scales_original)
return y_hp, rhs_recipe, y_scales, y_scales_original
def _build_lhs(x, lhs_block):
M, K = x.shape
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
x_scales_original = x_scales
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
return x_hp, x_recipe, x_fp8, x_scales, x_scales_original
def _build_rhs(y, rhs_block):
N, K = y.shape
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
return y_hp, y_recipe, y_fp8, y_scales, y_scales_original
def _run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
y_hp, y_recipe, y_fp8, y_scales, y_scales_original):
# Calculate actual F8 mm
out_scaled_mm = scaled_mm_wrap(
x_fp8,
y_fp8.t(),
scale_a=x_scales.reciprocal(),
scale_recipe_a=x_recipe,
# Note: No more .t() on scale_b, not necessary.
scale_b=y_scales.reciprocal(),
scale_recipe_b=y_recipe,
out_dtype=output_dtype,
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8,
x_scales_original,
y_fp8.t(),
y_scales_original.t(),
output_dtype
)
cosine_sim = torch.nn.functional.cosine_similarity(
out_emulated.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3
self.assertEqual(out_scaled_mm, out_emulated.to(output_dtype), atol=atol, rtol=rtol)
# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
def _build_constant_scale(t, block, val):
M, K = t.shape
if block == 1:
scale_shape = M, K // 128
else:
scale_shape = M // 128, K // 128
scale = torch.full(scale_shape, val, device='cuda')
return scale
def hp_to_scaled(t, scale, block):
if block == 1:
return hp_to_1x128(t, scale)
else:
return hp_to_128x128(t, scale)
e4m3_type = torch.float8_e4m3fn
if test_case == "x_eye_b_eye":
if M != K or M != N:
return unittest.skip("a_eye_b_eye only defined for M = N = K")
x = torch.eye(M, device='cuda')
y = torch.eye(M, device='cuda')
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_rhs(y, rhs_block)
elif test_case == "x_ones_y_ones_calc_scales":
x = torch.full((M, K), 1.0, device='cuda')
y = torch.full((N, K), 1.0, device='cuda')
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_rhs(y, rhs_block)
elif test_case in ["x_ones_y_ones_set_scales", "x_ones_y_ones_modify_scales"]:
x = torch.full((M, K), 1.0, device='cuda')
y = torch.full((N, K), 1.0, device='cuda')
x_scales = _build_constant_scale(x, lhs_block, 1.)
y_scales = _build_constant_scale(y, rhs_block, 1.)
if "modify" in test_case:
x_scales[0, 0] = 4.
y_scales[-1, -1] = 4.
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
elif test_case == "data_random_scales_one":
x = torch.randint(0, 255, (M, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
y = torch.randint(0, 255, (N, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
x_scales = _build_constant_scale(x, lhs_block, 1.)
y_scales = _build_constant_scale(y, rhs_block, 1.)
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
elif test_case == "data_random_calc_scales":
# Note: Old test_scaled_mm_vs_emulated_block_wise test case
x = torch.randn(M, K, device="cuda", dtype=output_dtype)
y = torch.randn(N, K, device="cuda", dtype=output_dtype) * 1e-3
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_rhs(y, rhs_block)
else:
lhs_recipe = ScalingType.BlockWise128x128
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
rhs_recipe = ScalingType.BlockWise1x128
else:
rhs_recipe = ScalingType.BlockWise128x128
raise ValueError("Unknown test-case passed")
_run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
y_hp, y_recipe, y_fp8, y_scales, y_scales_original)
# Calculate actual F8 mm
out_scaled_mm = scaled_mm_wrap(
x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
@ -1335,18 +1558,30 @@ class TestFP8Matmul(TestCase):
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
x_scales_original = x_scales
y_scales_original = y_scales
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
lhs_recipe = ScalingType.BlockWise1x128
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
else:
lhs_recipe = ScalingType.BlockWise128x128
x_scales, pad_amount = _pad_128x128_scales(x_scales)
# scales in [M // 128, L4] -> [L4, M // 128]
x_scales = x_scales.t()
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
rhs_recipe = ScalingType.BlockWise1x128
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
else:
rhs_recipe = ScalingType.BlockWise128x128
y_scales, pad_amount = _pad_128x128_scales(y_scales)
# Scale in [N // 128, L4] -> [L4, N // 128]
y_scales = y_scales.t()
# Verify that actual F8 mm doesn't error
scaled_mm_wrap(
@ -1354,13 +1589,20 @@ class TestFP8Matmul(TestCase):
y_fp8.t(),
scale_a=x_scales,
scale_recipe_a=lhs_recipe,
scale_b=y_scales.t(),
# Note: No more .t() on scale_b, not necessary.
scale_b=y_scales,
scale_recipe_b=rhs_recipe,
out_dtype=output_dtype,
)
# Verify that emulated F8 mm doesn't error
mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype)
mm_float8_emulated_block(
x_fp8,
x_scales_original,
y_fp8.t(),
y_scales_original.t(),
output_dtype
)
@skipIfRocm
@onlyCUDA
@ -1620,7 +1862,7 @@ class TestFP8Matmul(TestCase):
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
@ -1634,8 +1876,12 @@ class TestFP8Matmul(TestCase):
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
if K % BLOCK_SIZE != 0:
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
require_exact_match = True
approx_match_sqnr_target = 22.0
@ -1813,7 +2059,7 @@ class TestFP8Matmul(TestCase):
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
C_ref = A_ref @ B_ref.t()

View File

@ -14,10 +14,8 @@ from torch.testing import make_tensor
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyXPU,
OpDTypes,
ops,
skipXPUIf,
)
from torch.testing._internal.common_methods_invocations import ops_and_refs
from torch.testing._internal.common_utils import (
@ -74,6 +72,8 @@ _xpu_computation_ops = [
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
class TestXpu(TestCase):
expandable_segments = False
def test_device_behavior(self):
current_device = torch.xpu.current_device()
torch.xpu.set_device(current_device)
@ -385,56 +385,6 @@ if __name__ == "__main__":
torch.xpu.set_rng_state(g_state0)
self.assertEqual(2024, torch.xpu.initial_seed())
@onlyXPU
@suppress_warnings
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
def test_compare_cpu(self, device, dtype, op):
def to_cpu(arg):
if isinstance(arg, torch.Tensor):
return arg.to(device="cpu")
return arg
samples = op.reference_inputs(device, dtype)
for sample in samples:
cpu_sample = sample.transform(to_cpu)
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
xpu_results = sample.output_process_fn_grad(xpu_results)
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
# Lower tolerance because we are running this as a `@slowTest`
# Don't want the periodic tests to fail frequently
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
@onlyXPU
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
def test_non_standard_bool_values(self, device, dtype, op):
# Test boolean values other than 0x00 and 0x01 (gh-54789)
def convert_boolean_tensors(x):
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
return x
# Map False -> 0 and True -> Random value in [2, 255]
true_vals = torch.randint(
2, 255, x.shape, dtype=torch.uint8, device=x.device
)
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
x_int = torch.where(x, true_vals, false_vals)
ret = x_int.view(torch.bool)
self.assertEqual(ret, x)
return ret
for sample in op.sample_inputs(device, dtype):
expect = op(sample.input, *sample.args, **sample.kwargs)
transformed = sample.transform(convert_boolean_tensors)
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
self.assertEqual(expect, actual)
def test_serialization_array_with_storage(self):
x = torch.randn(5, 5).xpu()
y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
@ -470,6 +420,8 @@ if __name__ == "__main__":
self.assertEqual(copy.get_device(), original.get_device())
def test_out_of_memory(self):
if self.expandable_segments:
self.skipTest("Skipping OOM test for expandable segments allocator.")
tensor = torch.zeros(1024, device="xpu") # noqa: F841
with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"):
@ -479,6 +431,8 @@ if __name__ == "__main__":
torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu")
def test_raises_oom(self):
if self.expandable_segments:
self.skipTest("Skipping OOM test for expandable segments allocator.")
torch.xpu.memory.empty_cache()
with self.assertRaises(torch.OutOfMemoryError):
torch.empty(1024 * 1024 * 1024 * 1024, device="xpu")
@ -591,7 +545,7 @@ if __name__ == "__main__":
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
@skipXPUIf(
@unittest.skipIf(
int(torch.version.xpu) < 20250000,
"Test requires SYCL compiler version 2025.0.0 or newer.",
)
@ -639,6 +593,8 @@ if __name__ == "__main__":
self.assertTrue(b"libsycl.so" in result)
def test_dlpack_conversion(self):
if self.expandable_segments:
self.skipTest("Skipping DLPack test for expandable segments allocator.")
x = make_tensor((5,), dtype=torch.float32, device="xpu")
if IS_WINDOWS and int(torch.version.xpu) < 20250000:
with self.assertRaisesRegex(
@ -652,7 +608,58 @@ if __name__ == "__main__":
self.assertEqual(z, x)
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
class TestXpuOps(TestCase):
@suppress_warnings
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
def test_compare_cpu(self, device, dtype, op):
def to_cpu(arg):
if isinstance(arg, torch.Tensor):
return arg.to(device="cpu")
return arg
samples = op.reference_inputs(device, dtype)
for sample in samples:
cpu_sample = sample.transform(to_cpu)
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
xpu_results = sample.output_process_fn_grad(xpu_results)
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
# Lower tolerance because we are running this as a `@slowTest`
# Don't want the periodic tests to fail frequently
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
def test_non_standard_bool_values(self, device, dtype, op):
# Test boolean values other than 0x00 and 0x01 (gh-54789)
def convert_boolean_tensors(x):
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
return x
# Map False -> 0 and True -> Random value in [2, 255]
true_vals = torch.randint(
2, 255, x.shape, dtype=torch.uint8, device=x.device
)
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
x_int = torch.where(x, true_vals, false_vals)
ret = x_int.view(torch.bool)
self.assertEqual(ret, x)
return ret
for sample in op.sample_inputs(device, dtype):
expect = op(sample.input, *sample.args, **sample.kwargs)
transformed = sample.transform(convert_boolean_tensors)
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
self.assertEqual(expect, actual)
instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True)
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")

View File

@ -0,0 +1,26 @@
# Owner(s): ["module: intel"]
import pathlib
import sys
from test_xpu import TestXpu, TestXpuOpsXPU # noqa: F401
import torch
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import get_disabled_tests
sys.path.remove(str(REPO_ROOT))
if __name__ == "__main__":
if torch.xpu.is_available() and not IS_WINDOWS:
get_disabled_tests(".")
torch._C._accelerator_setAllocatorSettings("expandable_segments:True")
TestXpu.expandable_segments = True
run_tests()

View File

@ -814,6 +814,15 @@ class SymBool:
# Force specialization
return hash(builtins.bool(self))
def __sym_float__(self):
"""
Provides a SymFloat representation (0.0 or 1.0) for this SymBool.
Called by torch.sym_float() when casting SymBool to float.
"""
from torch.fx.experimental.sym_node import wrap_node
return wrap_node(self.node.sym_float())
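# Illustrative usage (a sketch, not part of the diff above): under dynamic shapes,
# torch.sym_float() dispatches to __sym_float__, e.g.
#   sb = (x.shape[0] == 4)      # a SymBool
#   f = torch.sym_float(sb)     # a SymFloat equal to 0.0 or 1.0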
def sym_not(a):
r"""SymInt-aware utility for logical negation.

View File

@ -739,6 +739,12 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
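# Illustrative ways to enable this (a sketch; the flag is experimental):
#   TORCH_ENRICH_RPOFILER_STACK_TRACE=1 python train.py
# or, from code:
#   import torch._dynamo.config as dynamo_config
#   dynamo_config.enrich_profiler_metadata = True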
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -87,12 +87,6 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@ -76,7 +76,9 @@ def _slow_conv2d_adapter(
return conv_adapter(tuple(tmp), tuple(tmp2))
@register_adapter(["convolution", "_convolution", "cudnn_convolution"])
@register_adapter(
["convolution", "_convolution", "cudnn_convolution", "convolution_overrideable"]
)
def conv_adapter(
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
) -> tuple[tuple[Any], dict[Any, Any]]:

View File

@ -969,7 +969,8 @@ class TritonCSEVariable(CSEVariable):
# We'll use this to track which masks the variable needs when used for indirect indexing
self.mask_vars: OrderedSet[str] = OrderedSet()
assert dtype is not None, "TritonCSEVariable must have dtype"
assert shape is not None, "TritonCSEVariable must have shape"
# TODO: uncomment this and fix the few failures left
# assert shape is not None, "TritonCSEVariable must have shape"
def update_on_args(self, name, args, kwargs):
for arg in args:

View File

@ -423,6 +423,10 @@ def estimate_nccl_collective_runtime_from_fx_node(
from torch.distributed.distributed_c10d import _resolve_process_group
pg = _resolve_process_group(group_name)
if torch.distributed.distributed_c10d.get_backend(pg) == "fake":
# The NCCL estimator requires a real process group

return None
fn = fx_node.target
assert isinstance(fn, torch._ops.OpOverload)
with torch.distributed._time_estimator(group=pg) as time_estimator:

View File

@ -24,6 +24,7 @@ from typing_extensions import Never, ParamSpec
import torch._thread_safe_fork # noqa: F401
from torch._inductor import config
from torch._inductor.codecache import torch_key
from torch._inductor.compile_worker.timer import Timer
from torch._inductor.compile_worker.tracked_process_pool import (
TrackedProcessPoolExecutor,
)
@ -132,6 +133,7 @@ class SubprocPool:
nprocs: int,
pickler: Optional[SubprocPickler] = None,
kind: SubprocKind = SubprocKind.FORK,
quiesce: bool = False,
) -> None:
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
self.pickler = pickler or SubprocPickler()
@ -216,6 +218,13 @@ class SubprocPool:
"pytorch.wait_counter.subproc_pool.first_job"
).guard()
if quiesce:
self.timer: Optional[Timer] = Timer(
config.quiesce_async_compile_time, self.quiesce
)
else:
self.timer = None
# Start thread last to ensure all member variables are initialized
# before any access.
self.read_thread.start()
@ -288,6 +297,8 @@ class SubprocPool:
with self.futures_lock:
if not self.running:
return
if self.timer:
self.timer.record_call()
if isinstance(result, _SubprocExceptionInfo):
# An exception occurred in the submitted job
self.pending_futures[job_id].set_exception(
@ -322,6 +333,8 @@ class SubprocPool:
with self.write_lock:
if not self.running:
return
if self.timer:
self.timer.quit()
self.running = False
self.running_waitcounter.__exit__()
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
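# Illustrative usage of the new flag (internal API, sketched here; the pool arms a
# Timer with config.quiesce_async_compile_time and each completed job resets it):
#   pool = SubprocPool(nprocs=4, quiesce=True)
#   fut = pool.submit(some_compile_job)   # hypothetical job callable
#   ...
#   pool.shutdown()                       # also stops the timer via timer.quit()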

View File

@ -17,7 +17,7 @@ class Timer:
self.background_thread: Optional[Thread] = None
self.last_called: Optional[float] = None
self.duration = duration
self.sleep_time = 60
self.sleep_time = duration / 2
self.call = call
self.exit = False

View File

@ -546,6 +546,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
@ -960,6 +964,11 @@ quiesce_async_compile_pool: bool = Config(
default=False,
)
# Time in seconds to wait before quiescing
quiesce_async_compile_time: int = Config(
default=60,
)
# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()

View File

@ -66,7 +66,9 @@ def should_decompose_bmm(mat1, mat2) -> bool:
return False
if len(mat1.shape) != 3 or len(mat2.shape) != 3:
return False
if check_device(mat1, mat2, device="cuda"):
if check_device(mat1, mat2, device="cuda") or check_device(
mat1, mat2, device="xpu"
):
if mat1.shape[0] < min_first_dimension_decomposition:
return False
# 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
@ -130,7 +132,10 @@ def should_decompose_mm(mat1, mat2) -> bool:
"skip_dynamic_shape_dim_check", False
):
return (
check_device(mat1, mat2, device="cuda")
(
check_device(mat1, mat2, device="cuda")
or check_device(mat1, mat2, device="xpu")
)
and statically_known_true(
mat1.shape[0] >= min_first_dimension_decomposition
)
@ -151,7 +156,10 @@ def should_decompose_mm(mat1, mat2) -> bool:
# case 2: we decompose mm if the input is dynamic shape
else:
return (
check_device(mat1, mat2, device="cuda")
(
check_device(mat1, mat2, device="cuda")
or check_device(mat1, mat2, device="xpu")
)
and (
statically_known_true(
mat1.shape[0] >= min_first_dimension_decomposition

View File

@ -51,8 +51,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
@register_graph_pattern(

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -18,19 +19,25 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
ensure_cute_available,
get_gpu_shared_memory,
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
if ensure_cute_available():
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
log = logging.getLogger(__name__)
aten = torch.ops.aten
@ -513,6 +520,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -714,43 +726,44 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -788,6 +801,22 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None

View File

@ -0,0 +1,333 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import the GroupedGemm implementation, copied into the PyTorch build from the Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
shapes, strides, and dtypes of input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)

View File

@ -0,0 +1,141 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]
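# Selection summary (mirrors the branches above; shown for illustration):
#   cutedsl_enable_autotuning and max_autotune_gemm_search_space == "EXHAUSTIVE"
#       -> get_exhaustive_groupgemm_configs()        (full sweep)
#   cutedsl_enable_autotuning only
#       -> get_default_groupgemm_configs()           (curated list)
#   otherwise (the default)
#       -> [get_default_groupgemm_configs()[0]]      (single baseline config)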

View File

@ -549,6 +549,70 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()
def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.
Uses in view ops follow the view's own uses.
"""
def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False
target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
return use_selector_fn(target)
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Returns True if `target` has a use with any of the given tags.
"""
return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)
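# Illustrative usage (a sketch): require that *every* use of a node is pointwise-tagged,
# rather than merely *some* use:
#   all_pointwise = has_uses_tagged_as(
#       node, (torch.Tag.pointwise,), use_aggregate_type=LogicalConnective.AND
#   )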
def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:
@ -1911,6 +1975,77 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the Blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL is available
2. We are on a Blackwell arch
3. The dtype is bf16
4. Max autotune or max autotune gemm is enabled
5. A, B, and the output are 16B aligned
6. We are not using dynamic shapes
7. A is 2d
8. B is 3d
9. Offsets are provided
10. Bias and scale_result are not provided
"""
if not ensure_cute_available():
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not (is_gpu(layout.device.type) and is_datacenter_blackwell_arch()):
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V

View File

@ -31,10 +31,8 @@ template <typename T>
struct FromImpl {
static StableIValue call(
T val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
@ -75,10 +73,8 @@ template <>
struct FromImpl<ScalarType> {
static StableIValue call(
ScalarType val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());
@ -133,10 +129,8 @@ template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(
std::nullopt_t val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return from(nullptr);
}
};
@ -190,10 +184,8 @@ template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(
const torch::stable::Tensor& val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);
@ -209,10 +201,8 @@ template <typename T>
struct ToImpl {
static T call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
@ -249,10 +239,8 @@ template <>
struct ToImpl<ScalarType> {
static ScalarType call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;
@ -309,10 +297,8 @@ template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}
@ -350,10 +336,8 @@ template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};

View File

@ -4,12 +4,14 @@ r"""This package adds support for device memory management implemented in CUDA."
import collections
import contextlib
import ctypes
import os
import pickle
import re
import sys
import warnings
from inspect import signature
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing_extensions import deprecated
from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict
from typing_extensions import deprecated, NotRequired
import torch
from torch import _C
@ -29,6 +31,60 @@ if TYPE_CHECKING:
from torch.types import Device
# Type definitions for memory profiler
class _Frame(TypedDict):
"""Frame information from memory profiler snapshots."""
filename: str
line: int
name: str
# Fields added by FX augmentation (optional)
fx_node_op: NotRequired[str]
fx_node_name: NotRequired[str]
fx_node_target: NotRequired[str]
fx_original_trace: NotRequired[str]
class _Block(TypedDict):
"""Memory block information."""
size: int
requested_size: int
address: int
state: str
frames: list[_Frame]
class _Segment(TypedDict):
"""Memory segment information."""
address: int
total_size: int
stream: int
segment_type: str
allocated_size: int
active_size: int
blocks: list[_Block]
class _TraceEntry(TypedDict):
"""Memory trace entry information."""
action: str
addr: NotRequired[int]
frames: list[_Frame]
size: int
stream: int
device_free: NotRequired[int]
class _Snapshot(TypedDict):
"""Memory snapshot structure."""
segments: list[_Segment]
device_traces: NotRequired[list[list[_TraceEntry]]]
__all__ = [
"caching_allocator_alloc",
"caching_allocator_delete",
@ -964,7 +1020,120 @@ def _record_memory_history_impl(
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
def _snapshot(device: "Device" = None):
def _augment_frames(frames: list[_Frame]) -> int:
"""
Augment a list of frames with FX debug information.
Args:
frames: List of frame dictionaries to augment
Returns:
The count of frames that were augmented.
"""
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
# Regex pattern to match FX generated files
_FX_GENERATED_PATTERN = re.compile(
rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$"
)
count = 0
if not frames:
return count
for frame in frames:
if "filename" in frame and "line" in frame:
filename = frame["filename"]
lineno = frame["line"]
# Check if this looks like an FX generated file
if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)):
continue
# Look up metadata from the global registry
from torch.fx.traceback import _FX_METADATA_REGISTRY
metadata = _FX_METADATA_REGISTRY.get(filename)
if metadata is None:
continue
lineno_map = metadata.get("lineno_map", {})
node_metadata = metadata.get("node_metadata", {})
prologue_start = metadata.get("prologue_start", 0)
# Get the node index for this line
node_idx = lineno_map.get(lineno - prologue_start)
if node_idx is not None and node_idx in node_metadata:
node_info = node_metadata[node_idx]
original_trace = node_info.get("stack_trace")
node_op = node_info.get("op")
node_name = node_info.get("name")
node_target = node_info.get("target")
# Always add node metadata
frame["fx_node_op"] = node_op
frame["fx_node_name"] = node_name
frame["fx_node_target"] = str(node_target)
# Add original trace if available
if original_trace:
frame["fx_original_trace"] = original_trace
count += 1
return count
def _augment_memory_snapshot_stack_traces(
snapshot: str | _Snapshot,
) -> _Snapshot:
"""
Augment a memory snapshot with original source stack traces from FX metadata.
IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
that is populated during graph module compilation. It must be called in the same
Python process where the FX graphs were compiled. It cannot be used to augment
snapshots loaded from disk in a different process.
Args:
snapshot: Either a memory snapshot dict or path to a snapshot pickle file
Returns:
The augmented snapshot dictionary with fx_node_op, fx_node_name, fx_node_target,
and fx_original_trace fields added to frames
"""
snapshot_dict: _Snapshot
if isinstance(snapshot, str):
# Load the memory snapshot
with open(snapshot, "rb") as f:
snapshot_dict = cast(_Snapshot, pickle.load(f))
else:
snapshot_dict = snapshot
# Process stack traces in the snapshot
augmented_count = 0
# Process blocks in segments (for regular allocations)
if "segments" in snapshot_dict:
for segment in snapshot_dict["segments"]:
if "blocks" in segment:
for block in segment["blocks"]:
if "frames" in block:
augmented_count += _augment_frames(block["frames"])
# Process device traces (for memory history)
if "device_traces" in snapshot_dict:
for trace_list in snapshot_dict["device_traces"]:
for trace_entry in trace_list:
if isinstance(trace_entry, dict) and "frames" in trace_entry:
augmented_count += _augment_frames(trace_entry["frames"])
return snapshot_dict
def _snapshot(device: "Device" = None, augment_with_fx_traces=False):
"""Save a snapshot of CUDA memory state at the time it was called.
The state is represented as a dictionary with the following structure.
@ -1012,6 +1181,11 @@ def _snapshot(device: "Device" = None):
filename: str
line: int
name: str
# Optional FX debug fields (present when augment_with_fx_traces=True
# and the frame corresponds to FX-generated code)
fx_node_op: str  # FX node operation type (e.g., 'call_function', 'output')
fx_node_name: str  # FX node name (e.g., 'linear', 'relu_1')
fx_node_target: str  # FX node target (the resolved op, as a string)
fx_original_trace: str  # Original model source code stack trace
class TraceEntry(TypedDict):
@ -1041,13 +1215,23 @@ def _snapshot(device: "Device" = None):
device_free: int # only present for OOM, the amount of
# memory cuda still reports to be free
Args:
device: Device to capture snapshot for. If None, captures for current device.
augment_with_fx_traces: If True, augment stack trace frames with FX debug information
that maps generated FX code back to original model source code.
This adds fx_node_op, fx_node_name, fx_node_target, and
fx_original_trace fields to Frame objects. Default: False.
Returns:
The Snapshot dictionary object
"""
return _C._cuda_memorySnapshot(None)
s = _C._cuda_memorySnapshot(None)
if augment_with_fx_traces:
s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type]
return s
def _dump_snapshot(filename="dump_snapshot.pickle"):
def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False):
"""
Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"):
Args:
filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information
before dumping. This maps generated FX code stack traces
back to original model source code. Defaults to False.
"""
s = _snapshot()
s = _snapshot(augment_with_fx_traces=augment_with_fx_traces)
with open(filename, "wb") as f:
pickle.dump(s, f)
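A minimal end-to-end sketch of how these pieces are intended to fit together (private APIs, shown only for illustration; assumes a CUDA build, and the names follow the code above):
import torch
import torch._dynamo.config as dynamo_config
dynamo_config.enrich_profiler_metadata = True      # register FX metadata during recompile()
torch.cuda.memory._record_memory_history()         # start collecting allocation stack traces
model = torch.compile(torch.nn.Linear(64, 64).cuda(), backend="aot_eager")
model(torch.randn(8, 64, device="cuda"))
# Augmentation must run in this same process: it reads the in-memory
# _FX_METADATA_REGISTRY that was populated at compile time.
torch.cuda.memory._dump_snapshot("snapshot.pickle", augment_with_fx_traces=True)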

View File

@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
from torch._library.utils import fill_defaults
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType
@ -34,6 +35,23 @@ aten = torch.ops.aten
logger = logging.getLogger(__name__)
def as_strided_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
):
args, kwargs = fill_defaults(op_call._schema, args, kwargs)
assert not kwargs
tensor, size, stride, storage_offset = args
if (
tensor.size() == tuple(size)
and tensor.stride() == tuple(stride)
and (storage_offset is None or tensor.storage_offset() == storage_offset)
):
return torch.ops.aten.alias.default(tensor)
raise RuntimeError("as_strided not supported with DTensor")
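# Illustrative consequence (a sketch): a no-op restride on a DTensor `dt` now aliases,
#   dt.as_strided(dt.size(), dt.stride())   # dispatched to aten.alias.default
# while any as_strided call that would actually change size/stride/offset still raises.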
def is_same_size_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
@ -121,6 +139,7 @@ class OpDispatcher:
aten.convolution.default: convolution_handler,
aten.convolution_backward.default: convolution_backward_handler,
aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
aten.as_strided.default: as_strided_handler,
}
# This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)

View File

@ -84,6 +84,7 @@ register_op_strategy(
aten.clone.default,
aten.contiguous.default,
aten.detach.default,
aten.alias.default,
aten.fill_.Scalar,
aten.view.dtype,
aten.zero_.default,

View File

@ -226,8 +226,10 @@ class PythonCode:
# Values in global scope during execution of `src_def`.
globals: dict[str, Any]
# Optional mapping from the forward function's line number to
# node index.
# node index. Line numbers are relative to the start of the prologue (i.e. the forward() definition).
_lineno_map: Optional[dict[int, Optional[int]]]
# The line number in fn_code at which the prologue starts
_prologue_start: int = 0
def _format_target(base: str, target: str) -> str:
@ -854,7 +856,14 @@ class CodeGen:
{prologue}
{code}"""
return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
# The +4 accounts for the empty lines before the prologue in fn_code
prologue_start = wrap_stmts.count("\n") + 4
return PythonCode(
fn_code,
globals_,
_lineno_map=lineno_map,
_prologue_start=prologue_start,
)
# Ideally, we'd like to refactor all of the pytree logic into this codegen

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import base64
import contextlib
import copy
import hashlib
import itertools
import linecache
import os
@ -36,6 +38,7 @@ __all__ = [
]
_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"
# Normal exec loses the source code, however we can work with
@ -61,7 +64,13 @@ class _EvalCacheLoader:
key = self._get_key()
if co_fields:
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
if "co_filename" in co_fields:
# If only co_filename is provided, use it directly as the key
if "co_firstlineno" not in co_fields or "co_name" not in co_fields:
key = co_fields["co_filename"]
else:
# Full co_fields with all three components
key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
self.eval_cache[key] = src
# Don't mutate globals so that this loader is only used
@ -353,6 +362,36 @@ def _print_readable(
return output
def _metadata_hash(code: str, node_metadata: dict) -> str:
"""
Create a content-addressed hash from code and metadata.
Args:
code: The source code string
node_metadata: Metadata for each node
Returns:
A 51-character base32-encoded hash
"""
import json
# Create a deterministic string representation of all components
# We use JSON to ensure consistent serialization
hash_data = {
"code": code,
"node_metadata": node_metadata,
}
hashing_str = json.dumps(hash_data).encode("utf-8")
# [:51] to strip off the "Q====" suffix common to every hash value.
return (
base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
.decode("utf-8")
.lower()
)
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
@ -825,9 +864,47 @@ class {module_name}(torch.nn.Module):
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):
node_metadata[i] = {
"name": node.name,
"op": node.op,
"target": str(node.target),
"stack_trace": node.meta.get("stack_trace", None),
}
# Generate a content-addressed filename based on hash of code and metadata
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
co_fields = {
"co_filename": filename,
}
# Store metadata in global in-memory registry
metadata = {
"lineno_map": python_code._lineno_map,
"prologue_start": python_code._prologue_start,
"node_metadata": node_metadata,
}
# Register metadata in the global registry
from torch.fx.traceback import _register_fx_metadata
_register_fx_metadata(filename, metadata)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation

View File

@ -38,6 +38,28 @@ current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False
# =============================================================================
# FX Metadata Registry for Memory Profiler
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.
This is called automatically during graph module compilation to store metadata
for later use by memory profiler augmentation.
Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, prologue_start, and node_metadata
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata
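# Illustrative same-process lookup (a sketch): entries are keyed by the generated
# filename ("fx_generated_*.py") and hold the fields registered above, e.g.
#   for fname, meta in _FX_METADATA_REGISTRY.items():
#       meta["lineno_map"], meta["prologue_start"], meta["node_metadata"]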
@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):

View File

@ -149,7 +149,11 @@ def conv_flop_count(
flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
return flop
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
@register_flop_formula([aten.convolution,
aten._convolution,
aten.cudnn_convolution,
aten._slow_conv2d_forward,
aten.convolution_overrideable])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
"""Count flops for convolution."""
# pyrefly: ignore [bad-argument-type]
@ -582,6 +586,7 @@ flop_registry = {
aten.convolution: conv_flop,
aten._convolution: conv_flop,
aten.cudnn_convolution: conv_flop,
aten.convolution_overrideable: conv_flop,
aten._slow_conv2d_forward: conv_flop,
aten.convolution_backward: conv_backward_flop,
aten._scaled_dot_product_efficient_attention: sdpa_flop,
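# Illustrative check (a sketch; assumes a backend whose convolutions dispatch to
# aten.convolution_overrideable, e.g. an out-of-tree device backend):
#   from torch.utils.flop_counter import FlopCounterMode
#   conv, x = torch.nn.Conv2d(3, 8, 3), torch.randn(1, 3, 32, 32)
#   with FlopCounterMode(display=False) as counter:
#       conv(x)
#   counter.get_total_flops()   # conv flops are now counted for convolution_overrideable too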

View File

@ -806,7 +806,29 @@ function format_frames(frames) {
}
const frame_strings = frames
.filter(frameFilter)
.map(f => `${f.filename}:${f.line}:${f.name}`);
.map(f => {
let frame_str = `${f.filename}:${f.line}:${f.name}`;
// Add FX debug information if available
if (f.fx_node_op || f.fx_node_name || f.fx_node_target) {
const fx_parts = [];
if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`);
if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`);
if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`);
frame_str += `\n >> FX: ${fx_parts.join(', ')}`;
}
if (f.fx_original_trace) {
frame_str += `\n >> Original Model Code:`;
const original_lines = f.fx_original_trace.trim().split('\n');
// Show all lines of the original trace
for (const line of original_lines) {
frame_str += `\n ${line}`;
}
}
return frame_str;
});
return elideRepeats(frame_strings).join('\n');
}