Compare commits

...

22 Commits

Author SHA1 Message Date
f36f372acc bwd pass (#164504)
**Summary**
This implements the backward pass for the Varlen API and registers `_varlen_attn()` as a custom op.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|         | Variable Length API | SDPA      |
|---------|---------------------|-----------|
| Runtime | 0.819 ms            | 3.264 ms  |
| TFLOPs  | 268.652             | 158.731   |

We can see that Varlen runtime is >3x faster than padded SDPA.
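
For reference, here is a minimal sketch of the benchmark setup described above (not the exact script behind these numbers; the sequence-length generation and the CUDA-event timing loop are assumptions based on the listed settings):

```python
import torch

torch.manual_seed(0)
batch_size, max_seq_len, embed_dim, num_heads = 8, 2048, 1024, 16

# Random sequence lengths that are multiples of 64, up to max_seq_len.
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64
# cu_seqlens-style offsets [0, len0, len0+len1, ...] as consumed by varlen kernels.
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens[1:] = seq_lens.cumsum(0)

def time_ms(fn, iters=100):
    # CUDA-event timing loop; assumes a CUDA device (e.g. the H100 above).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```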

**Testing**

Run `python test/test_varlen_attention.py` for unit tests, which verify basic functionality and confirm that varlen gradients numerically match SDPA.

For custom op testing, `test_custom_op_registration` uses logging mode to verify that `_varlen_attn()` is called and tests it with `torch.compile`; `test_custom_op_compliances` uses `torch.library.opcheck()` to verify op compliance.
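
A generic illustration of the `torch.library.opcheck()` mechanism, using a toy custom op rather than `_varlen_attn()` itself:

```python
import torch

# Register a trivial custom op (schema inferred from the type annotations).
@torch.library.custom_op("demo::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# Fake-tensor implementation so the op works under torch.compile / meta tensors.
@double.register_fake
def _(x):
    return torch.empty_like(x)

# opcheck runs a suite of compliance tests (schema, fake registration, autograd, ...).
torch.library.opcheck(double, (torch.randn(3),))
```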

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
Approved by: https://github.com/drisspg
2025-10-28 22:35:11 +00:00
d9483d4c8d [dynamo] Clean up assert in dynamo [3/N] (#165903)
Some previous PRs have been merged. This PR targets **assert**s that users can trigger; it may be better to turn them into graph breaks. Correct me if there are any problems.

* -> #165903 (Clean up for graph break)
* #165745
* #165430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165903
Approved by: https://github.com/williamwen42

Co-authored-by: William Wen <william.wen42@gmail.com>
2025-10-28 22:29:35 +00:00
fea819ed08 added type annotation to _NoParamDecoratorContextManager.__new__ (#166414)
Fixes #166413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166414
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-28 21:59:20 +00:00
84a2715d34 [dynamo] Revert C++-fying of symbolic shape guards (#166427)
Moving symbolic shape guards to C++ causes compile time issues. This basically boils down to a tradeoff question.

For models that have a large number of dynamic shape guards, this flag helps reduce guard latency. But for most models, which have very few dynamic shape guards, the guard latency is already small. Those models would still see a high compile-time hit because of calling gcc during compilation.

So a good default value seems to be False. We can write a doc to give guidance on reducing guard latency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166427
Approved by: https://github.com/zou3519
2025-10-28 21:57:31 +00:00
572cc12b42 Move MaskPartial to placement_types to improve discoverability (#164414)
Had trouble finding this one myself in #163030.
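
After this move the placement is importable from the public module (matching the test updates in the diff below):

```python
# New import location; previously this lived at
# torch.distributed.tensor._ops._embedding_ops._MaskPartial.
from torch.distributed.tensor.placement_types import MaskPartial
```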

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164414
Approved by: https://github.com/ezyang
2025-10-28 21:56:02 +00:00
1fdef664a5 Revert "[Pytorch] Update Kineto Submodule (#166317)"
This reverts commit be283297100ab86123e74b7a8372995d32b140c8.

Reverted https://github.com/pytorch/pytorch/pull/166317 on behalf of https://github.com/jeffdaily due to ROCm CI was clean, but post-merge ROCm failures showed up ([comment](https://github.com/pytorch/pytorch/pull/166317#issuecomment-3458665809))
2025-10-28 21:55:38 +00:00
08ae55021e support batch size=0 for flash attention (#166318)
Fixes #165944

**Summary**

Today, if we attempt to run flash attention with `batch_size=0`, we get the error `RuntimeError: batch size must be positive`. This PR fixes that by returning early with empty tensors in both the forward and backward passes.
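
A minimal repro sketch (assumes a CUDA device with flash attention available); with this PR the call returns empty tensors instead of raising:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Empty batch: (B=0, H, S, D)
q = torch.randn(0, 16, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(0, 16, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(0, 16, 128, 64, device="cuda", dtype=torch.bfloat16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([0, 16, 128, 64])
```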

**Test plan**
`python test/test_transformers.py -k test_scaled_dot_product_attention` - added case for batch_size=0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166318
Approved by: https://github.com/drisspg
2025-10-28 21:53:48 +00:00
551921d484 Change t.is_cuda to t.device.type == 'cuda' in torch/utils/viz (#156418)
Fixes #156417

Unlike `.is_cuda`, the property `.device` is supported by `ShardedTensor`.
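
A small illustration of the equivalence for plain tensors (the `.device`-based check additionally works for `ShardedTensor`, which lacks `.is_cuda`):

```python
import torch

t = torch.randn(2, 2)
# Old check: t.is_cuda; new check: t.device.type == "cuda".
assert t.is_cuda == (t.device.type == "cuda")
```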
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156418
Approved by: https://github.com/mikaylagawarecki

Co-authored-by: Alexander Zhipa <azzhipa@amazon.com>
2025-10-28 20:34:14 +00:00
b5189e269e NVFP4 grouped gemm support via FBGEMM kernels (#166308)
Summary:

* Add NVFP4 (1x16 block e4m3, tensor-wise fp32) scaled grouped gemm
* Extend testing to add nvfp4 support

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```
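
A rough sketch of the blocked-scale shape math exercised by these tests (mirroring `_check_scales_blocked` in the diff further below): nvfp4 uses 1x16 blocks with e4m3 scales, mxfp8 uses 1x32 blocks with e8m0 scales, and blocked scales are padded to (128, 4) tiles.

```python
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

def expected_blocked_scale_shape(outer_dim: int, K: int, fmt: str) -> tuple[int, int]:
    # nvfp4: 16-element blocks (e4m3 scales); mxfp8: 32-element blocks (e8m0 scales).
    blocksize = 16 if fmt == "nvfp4" else 32
    return (round_up(outer_dim, 128), round_up(K // blocksize, 4))

print(expected_blocked_scale_shape(2048, 16640, "nvfp4"))  # (2048, 1040)
```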

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166308
Approved by: https://github.com/ngimel
2025-10-28 20:32:53 +00:00
3895ce093f [inductor] add in-kernel nan-check (#166008)
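A hedged usage sketch, assuming the `config` patched in the new test below is `torch._inductor.config` and that the flag name is `runtime_triton_nan_asserts` (requires a GPU build):

```python
import torch
import torch._inductor.config as inductor_config

inductor_config.runtime_triton_nan_asserts = True  # assumed flag name, taken from the test diff

@torch.compile
def fn(x):
    x = x - 1
    return torch.where(x.isnan(), 3.14, x)

out = fn(torch.randn(4096, device="cuda"))  # generated kernel carries a "NaN or Inf found" assert
```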
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166008
Approved by: https://github.com/eellison
2025-10-28 20:19:10 +00:00
8aa087a29d [ez] Fix print for failing test when entire file fails (#166420)
Previously this printed "FAILED CONSISTENTLY: ul" because the value was null. This change prints the test_file instead by moving the null-handling logic earlier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166420
Approved by: https://github.com/Skylion007
2025-10-28 20:13:58 +00:00
7379972cc0 Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit cdb60e44eb528bf02c6bb2d7e384298283e755ca.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/xmfan due to Compile time regression ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3458252331))
2025-10-28 20:01:54 +00:00
b903018c26 [CD] Windows builds migrate python 3.14rc1->3.14.0 (#166408)
Python 3.14 has been released, so we can now use the official release version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166408
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-28 19:52:38 +00:00
21b48f8dfa Fixes torch.compile(nn.ModuleList()) changes bool() behavior (#159208)
Fixes #159139
## The Cause

The bug occurs because the `OptimizedModule` wrapper in `torch._dynamo.eval_frame` does not proxy the `__len__` method. This causes Python's `bool()` check to fall back to the default object truthiness (always `True`) instead of correctly evaluating containers with `len() == 0` as `False`.
## The Fix

A simple fix: I added a `__len__` method to `OptimizedModule` in `torch._dynamo.eval_frame` that delegates the call to the original module.
```python
def __len__(self):
    """
    Proxy the len() call to the original module to fix truthiness checks.
    """
    return len(self._orig_mod)
```
This successfully fixes the issue; the script now works as expected.
## Reproduction Script
```python
import torch
import torch.nn as nn

# Create an empty nn.ModuleList
original = nn.ModuleList()

# Compile it using torch.compile
compiled = torch.compile(original)

# Compare their boolean evaluations
print(f"bool(original): {bool(original)}")
print(f"bool(compiled): {bool(compiled)}")

# Trigger failure if they differ
assert bool(original) == bool(compiled), "BUG: truthiness behavior mismatch after compilation"
```
## Output

```
bool(original): False
bool(compiled): False
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159208
Approved by: https://github.com/andrewboldi, https://github.com/Lucaskabela

Co-authored-by: pushkar-hue <pushkarsharma.rtm@gmail.com>
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
2025-10-28 19:21:24 +00:00
009ea77234 Remove not needed code path. (#166278)
I accepted a PR that added this code, but re-examining it now, I'm questioning the approach. It seems like we're working around an issue with the inductor generating incorrect sizes. A comment suggests it might be related to unsqueezed u0 values. Removing this code didn't cause any failures, so I'll take it out and address the root issue if it arises.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166278
Approved by: https://github.com/Lucaskabela
2025-10-28 19:03:22 +00:00
0e46a10aa7 [ONNX] Warn when it's training (#166412)
Fixes #166163

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166412
Approved by: https://github.com/justinchuby
2025-10-28 19:01:05 +00:00
a25818cf7e Fix image display on pypi project description section (#166404)
Fixes https://github.com/pytorch/pytorch/issues/165559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166404
Approved by: https://github.com/malfet, https://github.com/Skylion007, https://github.com/Camyll
2025-10-28 18:58:24 +00:00
e3e93c7107 [MPS] Fix random in-place ops on non-contiguous tensors (#165267)
Random in-place operations (normal_, uniform_, exponential_, bernoulli_, random_) were silently failing on non-contiguous tensors on macOS < 15.0.

* Added a needsGather check and scatter-back logic to handle non-contiguous output tensors, following the pattern used in PointwiseOps.
* Added a test to confirm these ops now work.
* Removed the pre-macOS15 xfail for test_Dropout.

Fixes #165257 and #124029
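
A minimal repro sketch mirroring the new test (requires an MPS device):

```python
import torch

t = torch.zeros(50, 50, device="mps").T.clone()  # non-contiguous tensor
assert not t.is_contiguous()
t.uniform_(0, 1)
# Before this fix the op silently left t unchanged (all zeros) on macOS < 15.
assert t.max().item() != 0.0
```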

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165267
Approved by: https://github.com/kulinseth, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-10-28 18:43:22 +00:00
1abfa5f70b [EZ][MPS] Improve distribution error checking (#166425)
Essentially, disallow ops on self-overlapping outputs by adding the `at::assert_no_internal_overlap(self);` check that is already used in CPU and CUDA builds; see
895795f07c/aten/src/ATen/native/DistributionTemplates.h (L366)

This fixes `test_error_inputs_bernoulli_mps`
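
A hedged illustration of the new behavior (requires an MPS device; the exact error message comes from `assert_no_internal_overlap`):

```python
import torch

t = torch.zeros(1, device="mps").expand(10)  # internally overlapping view
try:
    t.bernoulli_(0.5)  # now rejected instead of writing through overlapping memory
except RuntimeError as e:
    print("raised as expected:", e)
```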

Should be landed ahead of https://github.com/pytorch/pytorch/pull/165267
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166425
Approved by: https://github.com/Skylion007, https://github.com/seemethere
2025-10-28 18:42:12 +00:00
687c15c0b3 [AOTI][BE] Change test_aoti_inference to one-pass build (#164277)
Summary: To fix https://github.com/pytorch/pytorch/issues/159400. Currently, test_aoti_abi_check and test_aoti_inference need to be built in two passes: first build PyTorch with the regular `python setup.py develop`, then build again with `CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 python setup.py develop`. This is cumbersome. Fix by rewriting CMakeLists.txt for test_aoti_inference as a one-pass build that runs AOTI to compile models at test time. Also update the CI test script to get rid of the two-pass build. test_aoti_abi_check is not AOTI specific, so it is no longer guarded by BUILD_AOT_INDUCTOR_TEST.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164277
Approved by: https://github.com/janeyx99
2025-10-28 17:43:22 +00:00
895795f07c [ROCm][CI] forward fix kineto submodule bump (#166421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166421
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-28 17:40:23 +00:00
2dc56456cb refactor: pull _replace_node common functionality out of Scheduler.finalize_multi_template_buffers (#163368)
Pull the replace_node function out of Scheduler.finalize_multi_template_buffers(); this is needed by the next PR (#163369). As part of this, also pull _replace_operation_buffer() up to top level since it needs no self references.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163368
Approved by: https://github.com/PaulZhang12
2025-10-28 17:21:52 +00:00
51 changed files with 1486 additions and 583 deletions

View File

@ -100,6 +100,8 @@ COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
# Only build aoti cpp tests when INDUCTOR_BENCHMARKS is set to True
ENV BUILD_AOT_INDUCTOR_TEST ${INDUCTOR_BENCHMARKS}
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt

View File

@ -460,28 +460,18 @@ test_inductor_shard() {
--verbose
}
test_inductor_aoti() {
# docker build uses bdist_wheel which does not work with test_aot_inductor
# TODO: need a faster way to build
test_inductor_aoti_cpp() {
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# We need to hipify before building again
python3 tools/amd_build/build_amd.py
fi
if [[ "$BUILD_ENVIRONMENT" == *sm86* ]]; then
BUILD_COMMAND=(TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python -m pip install --no-build-isolation -v -e .)
# TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB
TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="/opt/conda/envs/py_3.10/lib:${TORCH_LIB_DIR}:${LD_LIBRARY_PATH}")
else
BUILD_COMMAND=(python -m pip install --no-build-isolation -v -e .)
TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}")
fi
# aoti cmake custom command requires `torch` to be installed
# initialize the cmake build cache and install torch
/usr/bin/env "${BUILD_COMMAND[@]}"
# rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
/usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}
@ -1776,7 +1766,7 @@ elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
install_torchvision
PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
if [[ "$SHARD_NUMBER" -eq "1" ]]; then
test_inductor_aoti
test_inductor_aoti_cpp
fi
elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
install_torchvision

View File

@ -7,12 +7,9 @@ if "%DESIRED_PYTHON%" == "3.13t" (
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.13.0/python-3.13.0-amd64.exe"
set ADDITIONAL_OPTIONS="Include_freethreaded=1"
set PYTHON_EXEC="python3.13t"
) else if "%DESIRED_PYTHON%"=="3.14" (
echo Python version is set to 3.14 or 3.14t
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe"
) else if "%DESIRED_PYTHON%"=="3.14t" (
echo Python version is set to 3.14 or 3.14t
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe"
set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0-amd64.exe"
set ADDITIONAL_OPTIONS="Include_freethreaded=1"
set PYTHON_EXEC="python3.14t"
) else (

View File

@ -1,4 +1,4 @@
![PyTorch Logo](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/pytorch-logo-dark.png)
![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png)
--------------------------------------------------------------------------------
@ -72,7 +72,7 @@ Elaborating Further:
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
![Tensor illustration](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/tensor_illustration.png)
![Tensor illustration](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/tensor_illustration.png)
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the
computation by a huge amount.
@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
You get the best of speed and flexibility for your crazy research.
![Dynamic graph](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/dynamic_graph.gif)
![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif)
### Python First

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
@ -291,6 +291,7 @@ IF(USE_FBGEMM_GENAI)
set(fbgemm_genai_cuh
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/f4f4bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)

View File

@ -208,6 +208,48 @@ _f8_f8_bf16_rowwise_grouped_mm(
#endif
}
Tensor&
_f4_f4_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& global_scale_a,
const Tensor& scale_b,
const Tensor& global_scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
Tensor& out) {
#if !defined(USE_ROCM) && defined(USE_FBGEMM_GENAI)
// Typing checks
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out,
global_scale_a.mul(global_scale_b)
);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
#endif
return out;
}
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
@ -245,7 +287,15 @@ void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int
}
}
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
void _check_scales_blocked(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// if {mx,nv}fp4, will need to modify K later
bool is_fp4 = (mat.scalar_type() == kFloat4_e2m1fn_x2);
int blocksize = 32;
// check for nvfp4 vs. mxfp4 to fix blocksize
if (is_fp4 && scale.scalar_type() == kFloat8_e4m3fn) {
blocksize = 16;
}
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
@ -253,17 +303,19 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
"for block-scaled, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(),
" and scale.dim() = ", scale.dim(), " for arg ", arg_idx
);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/blocksize, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/blocksize, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"for block-scaled, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
@ -273,32 +325,40 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8/nvfp4 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
if (is_fp4) {
// FP4 packs 2 values into a single 8b word - the "real" K is 2x the
// reported K. Reverse that adjustment.
const int fp4_elems_per_byte = 2;
K *= fp4_elems_per_byte;
}
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_K = round_up(K/blocksize, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
"for block-scaled 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N),",
"but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
"for block-scaled grouped GEMM, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ")",
" for arg ", arg_idx, ", got: ", scale.size(0), ", ", scale.size(1)
);
}
}
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
bool using_mx = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else if (using_mx) {
_check_scales_blocked(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
@ -411,9 +471,10 @@ namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
} // anonymous namespace
@ -525,8 +586,9 @@ _scaled_grouped_mm_cuda_v2(
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
@ -537,6 +599,21 @@ _scaled_grouped_mm_cuda_v2(
offs.value(),
out);
}
case ScaledGemmImplementation::NVFP4_NVFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _f4_f4_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0], /* block-scale A */
scale_a[1], /* global-scale A */
scale_b[0], /* block-scale B */
scale_b[1], /* global-scale B */
offs.value(),
std::nullopt, /* bias */
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");

View File

@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self,
if (self.numel() == 0) {
return self;
}
at::assert_no_internal_overlap(self);
// MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624
const auto need_reshape = self.ndimension() > 4;
auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(gen, at::mps::detail::getDefaultMPSGenerator());
@ -153,8 +154,16 @@ Tensor& random_mps_impl(Tensor& self,
feeds[meanPlaceholder.getMPSGraphTensor()] = meanPlaceholder.getMPSGraphTensorData();
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self);
// Handle non-contiguous output tensors by creating a contiguous temporary
const auto needs_gather = needsGather(self);
Tensor self_ = needs_gather ? at::empty_like(self, MemoryFormat::Contiguous) : self;
Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self_);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
// Copy results back to original non-contiguous output
if (needs_gather) {
self.copy_(self_);
}
}
return self;

View File

@ -22,6 +22,7 @@
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
@ -42,7 +43,6 @@ C10_DIAGNOSTIC_POP()
#include <static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#include <c10/util/Exception.h>
namespace FLASH_NAMESPACE {
@ -417,6 +417,26 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (batch_size == 0) {
auto opts = q.options();
at::Tensor out = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor q_padded = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor k_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor v_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor softmax_lse = at::empty({0, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor p = at::empty({0}, opts);
if (return_softmax) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
p = at::empty({0, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
}
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), _unused, std::move(p)};
}
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@ -547,7 +567,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), std::move(_unused), std::move(p)};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -852,7 +872,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
@ -863,6 +882,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (batch_size == 0) {
auto opts = q.options();
at::Tensor dq = at::empty_like(q);
at::Tensor dk = at::empty_like(k);
at::Tensor dv = at::empty_like(v);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
at::Tensor softmax_d = at::empty({0, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
return {dq, dk, dv, softmax_d};
}
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");

View File

@ -1358,9 +1358,15 @@ if(BUILD_TEST)
)
else()
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(${TORCH_ROOT}/test/cpp/lazy ${CMAKE_BINARY_DIR}/test_lazy)
# NativeRT is disabled
# add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_abi_check ${CMAKE_BINARY_DIR}/test_aoti_abi_check)
if(BUILD_AOT_INDUCTOR_TEST)
add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_inference ${CMAKE_BINARY_DIR}/test_aoti_inference)
endif()
if(USE_DISTRIBUTED)
add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
if(NOT WIN32)
@ -1378,16 +1384,6 @@ if(BUILD_TEST)
${CMAKE_BINARY_DIR}/test_mobile_nnc
)
endif()
add_subdirectory(${TORCH_ROOT}/test/cpp/lazy
${CMAKE_BINARY_DIR}/test_lazy)
endif()
if(BUILD_AOT_INDUCTOR_TEST)
add_subdirectory(
${TORCH_ROOT}/test/cpp/aoti_abi_check
${CMAKE_BINARY_DIR}/test_aoti_abi_check)
add_subdirectory(
${TORCH_ROOT}/test/cpp/aoti_inference
${CMAKE_BINARY_DIR}/test_aoti_inference)
endif()
endif()

View File

@ -99,6 +99,12 @@ DTensor supports the following types of {class}`Placement` on each {class}`Devic
:undoc-members:
```
```{eval-rst}
.. autoclass:: MaskPartial
:members:
:undoc-members:
```
```{eval-rst}
.. autoclass:: Placement
:members:

View File

@ -1,3 +1,8 @@
# Skip on windows
if(WIN32)
return()
endif()
set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
# Build the cpp gtest binary containing the cpp-only tests.
@ -30,8 +35,15 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are written in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main sleef)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})
if(NOT USE_SYSTEM_SLEEF)
target_include_directories(test_aoti_abi_check PRIVATE ${CMAKE_BINARY_DIR}/include)
endif()
# Disable unused-variable warnings for variables that are only used to test compilation
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable)
foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
@ -41,12 +53,17 @@ foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}")
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main)
target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main sleef)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE})
if(NOT USE_SYSTEM_SLEEF)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${CMAKE_BINARY_DIR}/include)
endif()
# Define CPU_CAPABILITY and CPU_CAPABILITY_XXX macros for conditional compilation
target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY})
target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS})
target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-variable)
target_compile_options_if_supported(${test_name}_${CPU_CAPABILITY} -Wno-unused-but-set-variable)
endforeach()
endforeach()

View File

@ -2,10 +2,27 @@
#include <ATen/cpu/vec/vec.h>
#include <iostream>
namespace torch {
namespace aot_inductor {
template <typename T>
void ExpectVecEqual(
const at::vec::Vectorized<T>& expected,
const at::vec::Vectorized<T>& actual) {
using Vec = at::vec::Vectorized<T>;
// Have to use std::vector for comparison because at::vec::Vectorized doesn't
// support operator[] on aarch64
std::vector<T> expected_data(Vec::size());
std::vector<T> actual_data(Vec::size());
expected.store(expected_data.data());
actual.store(actual_data.data());
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_data[i], actual_data[i]);
}
}
TEST(TestVec, TestAdd) {
using Vec = at::vec::Vectorized<int>;
std::vector<int> a(1024, 1);
@ -16,9 +33,7 @@ TEST(TestVec, TestAdd) {
std::vector<int> expected(1024, 3);
Vec expected_vec = Vec::loadu(expected.data());
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
ExpectVecEqual(expected_vec, actual_vec);
}
TEST(TestVec, TestMax) {
@ -30,9 +45,7 @@ TEST(TestVec, TestMax) {
Vec actual_vec = at::vec::maximum(a_vec, b_vec);
Vec expected_vec = b_vec;
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
ExpectVecEqual(expected_vec, actual_vec);
}
TEST(TestVec, TestMin) {
@ -44,9 +57,7 @@ TEST(TestVec, TestMin) {
Vec actual_vec = at::vec::minimum(a_vec, b_vec);
Vec expected_vec = a_vec;
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
ExpectVecEqual(expected_vec, actual_vec);
}
TEST(TestVec, TestConvert) {
@ -58,9 +69,7 @@ TEST(TestVec, TestConvert) {
auto actual_vec = at::vec::convert<float>(a_vec);
auto expected_vec = b_vec;
for (int i = 0; i < at::vec::Vectorized<int>::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
ExpectVecEqual(expected_vec, actual_vec);
}
TEST(TestVec, TestClampMin) {
@ -72,9 +81,7 @@ TEST(TestVec, TestClampMin) {
Vec actual_vec = at::vec::clamp_min(a_vec, min_vec);
Vec expected_vec = min_vec;
for (int i = 0; i < Vec::size(); i++) {
EXPECT_EQ(expected_vec[i], actual_vec[i]);
}
ExpectVecEqual(expected_vec, actual_vec);
}
} // namespace aot_inductor

View File

@ -1,4 +1,3 @@
set(AOT_INDUCTOR_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_inference)
# Build custom TorchScript op for AOTInductor
@ -8,27 +7,12 @@ set_target_properties(aoti_custom_class PROPERTIES
if(USE_CUDA)
target_compile_definitions(aoti_custom_class PRIVATE USE_CUDA)
elseif(USE_ROCM)
target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
target_compile_definitions(aoti_custom_class PRIVATE USE_ROCM)
endif()
# Link against LibTorch
target_link_libraries(aoti_custom_class torch)
# the custom command that generates the TorchScript module
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
# This script requires the torch package to be installed.
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
DEPENDS torch torch_python aoti_custom_class ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py
)
add_custom_target(aoti_script_model ALL
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt
)
add_dependencies(aoti_script_model aoti_custom_class)
# Build the cpp gtest binary containing the cpp-only tests.
set(INDUCTOR_TEST_SRCS
${AOT_INDUCTOR_TEST_ROOT}/test.cpp
@ -37,23 +21,12 @@ set(INDUCTOR_TEST_SRCS
add_executable(test_aoti_inference
${TORCH_ROOT}/test/cpp/common/main.cpp
${INDUCTOR_TEST_SRCS}
data.pt
script_data.pt
script_model_cpu.pt
script_model_cuda.pt
)
add_dependencies(test_aoti_inference aoti_custom_class aoti_script_model)
add_dependencies(test_aoti_inference aoti_custom_class)
# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_aoti_inference PRIVATE USE_GTEST)
# Define a custom command to generate the library
add_custom_command(
OUTPUT data.pt
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
)
target_link_libraries(test_aoti_inference PRIVATE
torch
gtest_main
@ -71,6 +44,10 @@ target_compile_definitions(test_aoti_inference PRIVATE
CMAKE_CURRENT_BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR}
)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-but-set-variable)
target_compile_options_if_supported(test_aoti_inference -Wno-unused-function)
if(INSTALL_TEST)
install(TARGETS test_aoti_inference DESTINATION bin)
# Install PDB files for MSVC builds

View File

@ -2,7 +2,9 @@
#include <gtest/gtest.h>
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <functional>
#include <mutex>
#include <queue>
@ -28,6 +30,64 @@
namespace {
// Function to check if test data files exist and are valid
bool testDataFilesExist() {
std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);
std::array<std::string, 4> required_files = {
"data.pt",
"script_data.pt",
"script_model_cpu.pt",
"script_model_cuda.pt"};
for (const auto& filename : required_files) {
std::string filepath = bindir + "/" + filename;
std::ifstream file(filepath);
if (!file.good()) {
return false;
}
}
return true;
}
// Function to ensure test data files are generated at runtime
void ensureTestDataGenerated() {
static std::once_flag generated_flag;
std::call_once(generated_flag, []() {
// Only generate if files don't exist or are placeholders
if (testDataFilesExist()) {
return;
}
std::string bindir = STRINGIZE(CMAKE_CURRENT_BINARY_DIR);
// Calculate path to source directory: build/test_aoti_inference -> build ->
// pytorch
std::string pytorch_root = bindir.substr(0, bindir.find_last_of("/"));
pytorch_root = pytorch_root.substr(0, pytorch_root.find_last_of("/"));
std::string source_dir = pytorch_root + "/test/cpp/aoti_inference";
// Generate test data files (data.pt, etc.) by running test.py directly
std::string test_script = source_dir + "/test.py";
std::string test_data_cmd = "cd " + bindir + " && python " + test_script;
std::cout << "Generating test data: " << test_data_cmd << std::endl;
int result1 = std::system(test_data_cmd.c_str());
if (result1 != 0) {
std::cerr << "Warning: Test data generation failed with code " << result1
<< std::endl;
}
// Generate model files (script_*.pt) by running compile_model.py directly
std::string compile_script = source_dir + "/compile_model.py";
std::string models_cmd = "cd " + bindir + " && python " + compile_script;
std::cout << "Generating model files: " << models_cmd << std::endl;
int result2 = std::system(models_cmd.c_str());
if (result2 != 0) {
std::cerr << "Warning: Model generation failed with code " << result2
<< std::endl;
}
});
}
const std::unordered_map<std::string, at::Tensor> derefTensorConstantMap(
torch::inductor::TensorConstantMap tensor_constant_map) {
std::unordered_map<std::string, at::Tensor> ret;
@ -855,7 +915,6 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) {
}
}
#if defined(USE_CUDA) || defined(USE_ROCM)
void test_cuda_alloc_test() {
torch::NoGradGuard no_grad;
@ -895,8 +954,8 @@ void test_cuda_alloc_test() {
runner->run(data_loader.attr(inputs_attr.c_str()).toTensorList().vec());
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}
#endif
#ifdef USE_CUDA
class ThreadPool {
private:
struct Task {
@ -1037,86 +1096,96 @@ void test_multi_cuda_streams(const std::string& device) {
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], all_outputs[i][0]));
}
}
#endif
#endif // USE_CUDA
#endif // USE_CUDA || USE_ROCM
} // namespace
namespace torch::aot_inductor {
TEST(AotInductorTest, BasicTestCpu) {
// Test fixture that ensures test data is generated once for all tests
class AotInductorTest : public ::testing::Test {
public:
// This runs once before all tests in this test suite
static void SetUpTestSuite() {
ensureTestDataGenerated();
}
};
TEST_F(AotInductorTest, BasicTestCpu) {
test_aoti("cpu", false);
}
TEST(AotInductorTest, BasicScriptTestCpu) {
TEST_F(AotInductorTest, BasicScriptTestCpu) {
test_aoti_script("cpu");
}
TEST(AotInductorTest, BasicPackageLoaderTestCpu) {
TEST_F(AotInductorTest, BasicPackageLoaderTestCpu) {
test_aoti_package_loader("cpu", false);
}
TEST(AotInductorTest, ExtractConstantsMapCpu) {
TEST_F(AotInductorTest, ExtractConstantsMapCpu) {
test_aoti_extract_constants_map("cpu");
}
#ifdef USE_CUDA
TEST(AotInductorTest, BasicTestCuda) {
TEST_F(AotInductorTest, BasicTestCuda) {
test_aoti("cuda", true);
test_aoti("cuda", false);
}
TEST(AotInductorTest, BasicScriptTestCuda) {
TEST_F(AotInductorTest, BasicScriptTestCuda) {
test_aoti_script("cuda");
}
TEST(AotInductorTest, BasicPackageLoaderTestCuda) {
TEST_F(AotInductorTest, BasicPackageLoaderTestCuda) {
test_aoti_package_loader("cuda", false);
}
TEST(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
TEST_F(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
test_aoti_package_loader_multi_gpu("cuda", false);
}
TEST(AotInductorTest, UpdateUserManagedConstantsCuda) {
TEST_F(AotInductorTest, UpdateUserManagedConstantsCuda) {
test_aoti_user_managed_buffer();
}
TEST(AotInductorTest, RuntimeUpdateConstantsCuda) {
TEST_F(AotInductorTest, RuntimeUpdateConstantsCuda) {
test_aoti_constants_update("cuda", true);
}
TEST(AotInductorTest, UpdateConstantsCuda) {
TEST_F(AotInductorTest, UpdateConstantsCuda) {
test_aoti_constants_update("cuda", false);
}
TEST(AotInductorTest, ExtractConstantsMapCuda) {
TEST_F(AotInductorTest, ExtractConstantsMapCuda) {
test_aoti_extract_constants_map("cuda");
}
TEST(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
TEST_F(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
test_aoti_double_buffering("cuda", true);
}
TEST(AotInductorTest, UpdateInactiveConstantsCuda) {
TEST_F(AotInductorTest, UpdateInactiveConstantsCuda) {
test_aoti_double_buffering("cuda", false);
}
TEST(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
TEST_F(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
test_aoti_double_buffering_with_tensor_constants();
}
TEST(AotInductorTest, FreeInactiveConstantBufferCuda) {
TEST_F(AotInductorTest, FreeInactiveConstantBufferCuda) {
test_aoti_free_buffer(false);
}
TEST(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
TEST_F(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
test_aoti_free_buffer(true);
}
TEST(AotInductorTest, MultiStreamTestCuda) {
TEST_F(AotInductorTest, MultiStreamTestCuda) {
test_multi_cuda_streams("cuda");
}
TEST(AotInductorTest, CudaAllocTestCuda) {
TEST_F(AotInductorTest, CudaAllocTestCuda) {
test_cuda_alloc_test();
}
#endif

View File

@ -168,7 +168,7 @@ class TestEmbeddingOp(DTensorTestBase):
self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial
# test collectives
embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
@ -176,7 +176,7 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (8, 8), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
output = sharded_embedding(replicated_inp)
self.assertIsInstance(output.placements[0], _MaskPartial)
self.assertIsInstance(output.placements[0], MaskPartial)
comm_mode = CommDebugMode()
@ -192,9 +192,9 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (4, 4), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial
# case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
# case 1: two embeddings with the same shape, thus sharing the underlying MaskPartial
# and MaskBuffer, because of cache hit from sharding propagation
emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
@ -206,23 +206,23 @@ class TestEmbeddingOp(DTensorTestBase):
output2 = sharded_emb2(replicated_inp)
partial_placement1 = output1.placements[0]
self.assertIsInstance(partial_placement1, _MaskPartial)
self.assertIsInstance(partial_placement1, MaskPartial)
output1.full_tensor()
partial_placement2 = output2.placements[0]
self.assertIsInstance(partial_placement2, _MaskPartial)
self.assertIsInstance(partial_placement2, MaskPartial)
output2.full_tensor()
self.assertTrue(id(partial_placement1), id(partial_placement2))
# case 2: two embeddings with the same logical_dim_size, but different logical_shape
# thus they will have different _MaskPartial placements (with no cache hit)
# thus they will have different MaskPartial placements (with no cache hit)
emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
output3 = sharded_emb3(replicated_inp)
partial_placement3 = output3.placements[0]
self.assertIsInstance(partial_placement3, _MaskPartial)
self.assertIsInstance(partial_placement3, MaskPartial)
output2.full_tensor()
# not equal because of different logical_shape, despite of same logical_dim_size

View File

@ -511,7 +511,7 @@ class DistTensorOpsTest(DTensorTestBase):
# case 2 input sharding: input sharded, index replicated, output mask partial
# only works when index has size 1 on the gather dimension and
# input is sharded on the gather dimension
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial
gather_dim = 1
global_input = torch.randn(12, 8, 16)
@ -522,7 +522,7 @@ class DistTensorOpsTest(DTensorTestBase):
with comm_mode:
output_dt = torch.gather(input_dt, gather_dim, index_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertIsInstance(output_dt.placements[0], _MaskPartial)
self.assertIsInstance(output_dt.placements[0], MaskPartial)
self.assertEqual(output_dt.full_tensor(), global_output)
# case 3 index sharding: input replicated, index sharded, output sharded

View File

@ -892,10 +892,16 @@ fn(torch.randn(5))
os.remove(
file_path
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:
self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix.
empty_line_normalizer(lines),
empty_line_normalizer(stderr.decode("utf-8")),
)
except Exception:
unittest.TestCase.maxDiff = orig_maxDiff
raise
@make_settings_test("torch._dynamo.eval_frame")
def test_log_traced_frames(self, records):

View File

@ -1000,6 +1000,18 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.exit_stack.close()
super().tearDown()
def test_compiled_module_truthiness(self):
# Test with empty ModuleList
original_empty = nn.ModuleList()
compiled_empty = torch.compile(original_empty)
self.assertEqual(bool(original_empty), bool(compiled_empty))
self.assertFalse(bool(compiled_empty))
# Test with non-empty ModuleList
original_filled = nn.ModuleList([nn.Linear(10, 5)])
compiled_filled = torch.compile(original_filled)
self.assertEqual(bool(original_filled), bool(compiled_filled))
self.assertTrue(bool(compiled_filled))
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
root = guard_manager_wrapper.root
cloned_root = root.clone_manager(lambda x: True)

View File

@ -14269,6 +14269,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertTrue("'enable_fp_fusion': False" in code)
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
@requires_cuda_and_triton
@config.patch(runtime_triton_nan_asserts=True)
def test_nan_assert_inside_triton_kernel(self):
def fn(x):
x = x - 1
# Uncomment the following line can trigger the failure of
# the device size assertion
# x = torch.log(x)
return torch.where(x.isnan(), 3.14, x)
compiled = torch.compile(fn)
x = torch.randn(4096, device=GPU_TYPE)
out, (code,) = run_and_get_code(compiled, x)
self.assertTrue("'NaN or Inf found'" in code)
torch.testing.assert_close(out, fn(x))
@skip_if_cpp_wrapper("skip cpp wrapper")
@requires_cuda_and_triton
def test_repeat_interleave_decomposition_has_clamp(self):

View File

@ -12,7 +12,6 @@ from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfMPS,
expectedFailureMPS,
expectedFailureMPSPre15,
expectedFailureXLA,
instantiate_device_type_tests,
)
@ -173,7 +172,6 @@ class TestDropoutNNDeviceType(NNTestCase):
else:
self.assertNotEqual(permuted_inp, out)
@expectedFailureMPSPre15
def test_Dropout(self, device):
input = torch.empty(1000)
self._test_dropout(nn.Dropout, device, input)

View File

@ -529,7 +529,7 @@ class TestProfiler(TestCase):
found_mm = True
if "gemm" in e.name.lower() or "Cijk" in e.name:
found_gemm = True
if "memcpy" in e.name.lower():
if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
found_memcpy = True
if use_cuda:
self.assertTrue(found_gemm)

View File

@ -755,6 +755,8 @@ def run_test_retries(
REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
) as f:
current_failure = f.read()
if current_failure == "null":
current_failure = f"'{test_file}'"
except FileNotFoundError:
print_to_file(
"No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
@ -791,8 +793,6 @@ def run_test_retries(
print_to_file("Retrying single test...")
print_items = [] # do not continue printing them, massive waste of space
if "null" in num_failures:
num_failures[f"'{test_file}'"] = num_failures.pop("null")
consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
if len(flaky_failures) > 0:

View File

@ -7846,6 +7846,45 @@ class TestMPS(TestCaseMPS):
y = torch.normal(torch.zeros(shape, device="mps"), torch.ones(shape, device="mps"))
self.assertNotEqual(y[0], y[1])
def test_random_ops_noncontiguous(self):
"""Test random in-place operations on non-contiguous tensors.
All random in-place operations should work on non-contiguous tensors.
See issues #165257 and #124029.
"""
# Test each random in-place operation
ops = [
("normal_", lambda t: t.normal_(0, 1)),
("uniform_", lambda t: t.uniform_(0, 1)),
("exponential_", lambda t: t.exponential_(1.0)),
("bernoulli_", lambda t: t.bernoulli_(0.5)),
("random_", lambda t: t.random_()),
("random_with_to", lambda t: t.random_(10)),
("random_with_range", lambda t: t.random_(0, 10)),
]
for name, op_func in ops:
with self.subTest(operation=name):
# Create non-contiguous tensor via transpose
t_mps = torch.zeros(50, 50, device='mps').T.clone()
self.assertFalse(t_mps.is_contiguous(),
f"{name}: tensor should be non-contiguous")
# Apply operation
op_func(t_mps)
# Verify tensor was modified (not all zeros)
max_val = t_mps.max().item()
self.assertNotEqual(max_val, 0.0,
f"{name}: operation failed to modify non-contiguous tensor")
# Test rand_like specifically (issue #124029)
t = torch.ones((3, 2, 2), device='mps').permute(2, 0, 1)
self.assertFalse(t.is_contiguous(), "rand_like input should be non-contiguous")
result = torch.rand_like(t)
self.assertFalse(result.is_contiguous(), "rand_like result should be non-contiguous")
self.assertNotEqual(result.max().item(), 0.0, "rand_like should generate non-zero values")
# Test exponential
@unittest.skip("This does not test anything")
def test_exponential(self):

View File

@ -46,6 +46,7 @@ from torch.testing._internal.common_quantized import (
_floatx_unpacked_to_f32,
ceil_div, to_blocked,
to_mxfp8,
from_blocked_format,
generate_jagged_offs,
)
@ -462,6 +463,24 @@ def pack_uint4(uint8_data) -> torch.Tensor:
uint8_data = uint8_data.contiguous().view(-1)
return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))
def unpack_uint4(uint8_data) -> torch.Tensor:
# Take a packed uint8 tensor (i.e. nvfp4) and unpack into
# a tensor twice as wide. Useful for dequant operations.
shape = list(uint8_data.shape)
# 2x packed elements -> single non-packed => adjust shape
shape[-1] *= 2
out = torch.empty(
*shape,
device=uint8_data.device,
dtype=torch.uint8
).view(-1)
uint8_data_as_uint8 = uint8_data.view(torch.uint8).view(-1)
out[1::2] = uint8_data_as_uint8[:] >> 4
out[::2] = uint8_data_as_uint8 & 15
return out.view(shape)
def _bfloat16_to_float4_e2m1fn_x2(x):
assert x.dtype == torch.bfloat16
@ -470,6 +489,119 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
x = x.view(torch.float4_e2m1fn_x2)
return x
def _convert_to_nvfp4_with_hp_ref(t):
# Convert a tensor to nvfp4, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : nvfp4 tensor (2x elements packed into uint8)
# t_scale: e4m3 block-wise scaling factors (non-swizzled)
# t_global_scale: fp32 tensor-wise global scaling factor
t_lp, t_scale, t_global_scale = data_to_nvfp4_with_global_scale(
t,
16,
)
t_hp = from_blocked_format(
_floatx_unpacked_to_f32(
unpack_uint4(t_lp),
FP4_EBITS,
FP4_MBITS),
t_scale,
blocksize=16) * t_global_scale
return t_hp, t_lp, t_scale, t_global_scale
def _convert_to_mxfp8_with_hp_ref(t):
# Convert a tensor to mxfp8, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : fp8_e4m3 tensor
# t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled)
t_scale, t_lp = to_mxfp8(t)
t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)
return t_hp, t_lp, t_scale
def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'):
# Convert scales to blocked format. either mxfp8 or nvfp4
th_list = []
t_list = []
t_blocked_scale_list = []
t_global_scale_list = []
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else offs[group_idx - 1]
)
curr_group_end_offset = offs[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
t_slice = t[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
if format == 'mxfp8':
th_slice, tq_slice, t_scale_slice = _convert_to_mxfp8_with_hp_ref(t_slice)
elif format == 'nvfp4':
th_slice, tq_slice, t_scale_slice, tq_global = _convert_to_nvfp4_with_hp_ref(
t_slice,
)
t_global_scale_list.append(tq_global)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
t_list.append(tq_slice)
th_list.append(th_slice)
# Convert scales to blocked format.
t_scale_slice_blocked = to_blocked(
t_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
t_blocked_scale_list.append(t_scale_slice_blocked)
# Assemble the full XQ and WQ
tq = torch.cat(t_list, dim=1).contiguous()
th = torch.cat(th_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
t_blocked_scales = torch.cat(t_blocked_scale_list, dim=0)
MN_rounded = round_up(MN, 128)
t_blocked_scales = t_blocked_scales.reshape(MN_rounded, -1)
# Global scales only exist for nvfp4
t_global_scales = None
if len(t_global_scale_list) > 0:
t_global_scales = torch.stack(t_global_scale_list)
return th, tq, t_blocked_scales, t_global_scales
def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format):
# Build the standard (and fairly wordy) kwargs shared by the scaled grouped-mm calls.
# Note: if/when ROCm support is added, the swizzle handling needs to change.
kwargs = {
'mxfp8': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': ScalingType.BlockWise1x32,
'scale_recipe_b': ScalingType.BlockWise1x32,
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
'nvfp4': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
}
return kwargs[format]
class TestFP8Matmul(TestCase):
@ -526,13 +658,15 @@ class TestFP8Matmul(TestCase):
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [2048, 2049])
@parametrize("N", [8192])
@parametrize("K", [16640])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
torch.manual_seed(42)
total_K = K  # Alias for clarity: the K dim consists of several groups laid out contiguously
input_group_end_offsets = generate_jagged_offs(
@ -541,95 +675,61 @@ class TestFP8Matmul(TestCase):
X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01
# Convert scales to blocked format.
x_list = []
w_list = []
x_blocked_scale_list = []
w_blocked_scale_list = []
xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
X, M, G, input_group_end_offsets, format=format
)
wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
W, N, G, input_group_end_offsets, format=format
)
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
curr_group_end_offset = input_group_end_offsets[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
x_slice = X[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
w_slice = W[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (N, K_group)
x_scale_slice, xq_slice = to_mxfp8(
x_slice
) # scale shape -> (M, K_group // 32)
w_scale_slice, wq_slice = to_mxfp8(
w_slice
) # scale shape -> (N, K_group // 32)
x_list.append(xq_slice)
w_list.append(wq_slice)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Convert scales to blocked format.
x_scale_slice_blocked = to_blocked(
x_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
w_scale_slice_blocked = to_blocked(
w_scale_slice
) # (round_up(N, 128), round_up(K_group//32, 4))
x_blocked_scale_list.append(x_scale_slice_blocked)
w_blocked_scale_list.append(w_scale_slice_blocked)
# Assemble the full XQ and WQ
xq = torch.cat(x_list, dim=1).contiguous()
wq = torch.cat(w_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
M_rounded = round_up(M, 128)
x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
# Combine all WQ groups blocked scales into one tensor.
w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
N_rounded = round_up(N, 128)
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G
# Compute mxfp8 grouped mm output
y_mxfp8 = scaled_grouped_mm_wrap(
xq, # (M, total_K)
wq.transpose(-2, -1), # (total_K, N)
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
offs=input_group_end_offsets, # (G,)
out_dtype=torch.bfloat16,
wrap_v2=wrap_v2
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
**kwargs,
)
# bf16 reference output
y_bf16 = torch._grouped_mm(
X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
# Note: the reference result should be computed on the reconstructed values,
# i.e. float(fp4(t)), not on t itself.
xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
)
# Assert no NaNs
assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
assert not y_lp.isnan().any(), "mxfp8 output contains NaN"
# Assert outputs are close
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [16640])
@parametrize("N", [8192])
@parametrize("K", [4096])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
torch.manual_seed(42)
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
# 2D inputs with groups along M, 3D weights.
@ -643,60 +743,120 @@ class TestFP8Matmul(TestCase):
# For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
# as each one is used for an independent gemm within the grouped gemm.
wq_list = []
w_scale_list = []
for i in range(G):
w_scale, wq = to_mxfp8(W[i])
w_scale = to_blocked(w_scale)
wq_list.append(wq)
w_scale_list.append(w_scale)
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
def _3d_to_blocked_scaled(W, G, format):
wh_list = []
wq_list = []
w_scale_list = []
w_global_scale_list = []
for i in range(G):
if format == "mxfp8":
wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i])
elif format == "nvfp4":
w_scale, wq = to_mxfp8(W[i])
wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i])
w_global_scale_list.append(w_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Swizzle scales
# TODO(slayton): gate on cuda/hip
w_scale = to_blocked(w_scale)
wh_list.append(wh)
wq_list.append(wq)
w_scale_list.append(w_scale)
wh = torch.stack(wh_list, dim=0).contiguous()
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Global scales only exist for nvfp4
if len(w_global_scale_list) > 0:
w_global_scales = torch.stack(w_global_scale_list)
else:
w_global_scales = None
return wh, wq, w_scale, w_global_scales
wh, wq, w_blocked_scales, w_global_scales = _3d_to_blocked_scaled(W, G, format)
# For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
# as each one is used for an independent gemm within the grouped gemm.
xq_list = []
x_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
x_scale, xq = to_mxfp8(x_slice)
x_scale = to_blocked(x_scale)
xq_list.append(xq)
x_scale_list.append(x_scale)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
def _2d_to_blocked_scaled(X, K, G, offs, format):
xh_list = []
xq_list = []
x_scale_list = []
x_global_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
if format == "mxfp8":
xh, xq, x_scale = _convert_to_mxfp8_with_hp_ref(x_slice)
elif format == "nvfp4":
xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice)
x_global_scale_list.append(x_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Compute mxfp8 grouped gemm.
y_mxfp8 = scaled_grouped_mm_wrap(
x_scale = to_blocked(x_scale)
xh_list.append(xh)
xq_list.append(xq)
x_scale_list.append(x_scale)
xh = torch.cat(xh_list, dim=0).contiguous()
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
xh = xh.view(-1, xh.shape[-1])
x_global_scales = None
if len(x_global_scale_list) > 0:
x_global_scales = torch.stack(x_global_scale_list)
return xh, xq, x_scale, x_global_scales
xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format)
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G
# Compute low-precision grouped gemm.
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
x_scale,
w_scale,
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
wrap_v2=wrap_v2)
**kwargs
)
# Compute reference bf16 grouped gemm.
# Note: the reference result should be computed on the reconstructed values,
# i.e. float(fp4(t)), not on t itself.
y_bf16 = torch._grouped_mm(
X,
W.transpose(-2, -1),
xh,
wh.transpose(-2, -1),
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
)
# Assert outputs are close.
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@ -1704,6 +1864,7 @@ class TestFP8Matmul(TestCase):
@parametrize("fast_accum", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
# AMD does not support NVFP4
@parametrize("wrap_v2", [True, False])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
device = "cuda"

View File

@ -1107,6 +1107,7 @@ class TestTransformers(NNTestCase):
)[0]
@tf32_on_and_off(0.003)
@parametrize("batch_size", [0, 5])
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@ -1116,7 +1117,7 @@ class TestTransformers(NNTestCase):
if attn_dim is not None else "no_attn_mask")))
@parametrize("dropout_p", [0.0, 0.2, 0.5])
@sdpa_kernel(backends=[SDPBackend.MATH])
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
def test_scaled_dot_product_attention(self, device, batch_size, input_dim, attn_mask_dim, is_causal, dropout_p):
def sdp_ref(
q,
k,
@ -1140,12 +1141,13 @@ class TestTransformers(NNTestCase):
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
dtypes = [torch.double, torch.float]
for dtype in dtypes:
N = batch_size
def rand_tensor(*shape):
return torch.randn(shape, device=device, dtype=dtype)
# This test compares python and C++ implementations of SDP.
N, N_prime, L, S, E = 5, 2, 4, 3, 6
N_prime, L, S, E = 2, 4, 3, 6
if input_dim == 3:
query = rand_tensor(N, L, E)
key = rand_tensor(N, S, E)

View File

@ -5,11 +5,12 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import varlen_attn
from torch.nn.attention.varlen import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode
VarlenShape = namedtuple(
@ -23,6 +24,18 @@ default_tolerances = {
}
class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""
def __init__(self):
self.called_ops = []
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))
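# Hedged sketch (added here for clarity, not part of the PR): under the mode, every
# dispatched aten op is recorded by name, which is how test_custom_op_registration
# below checks that the varlen custom ops were actually hit. The helper name is
# illustrative only.
def _op_logging_example():
    with OpLoggingMode() as mode:
        torch.add(torch.ones(2), torch.ones(2))
    assert any("aten.add" in op for op in mode.called_ops)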
class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@ -39,12 +52,9 @@ class AttentionBlock(nn.Module):
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)
def forward_varlen(
def get_varlen_qkv(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
@ -53,24 +63,51 @@ class AttentionBlock(nn.Module):
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
return q, k, v
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
q, k, v = self.get_varlen_qkv(x_packed)
attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = attn_out.view(-1, self.embed_dim)
return self.out_proj(attn_out)
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
dtype: torch.dtype,
is_causal: bool = False,
):
batch_size, seq_len, _ = x_padded.shape
qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)
mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)
attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
@ -91,7 +128,9 @@ def create_variable_length_batch(
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
@ -106,6 +145,7 @@ def create_variable_length_batch(
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()
return {
"seq_lengths": seq_lengths,
@ -133,7 +173,11 @@ class TestVarlenAttention(NNTestCase):
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@ -147,6 +191,128 @@ class TestVarlenAttention(NNTestCase):
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
varlen_grad_out = torch.ones_like(output)
varlen_grad = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
self.assertIsNotNone(varlen_grad)
self.assertEqual(varlen_grad.shape, x_packed.shape)
self.assertEqual(varlen_grad.dtype, x_packed.dtype)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
q, k, v = attention_block.get_varlen_qkv(x_packed)
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn,
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
)
out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
)
grad_out = torch.randn_like(out)
# we don't support double backward
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn_backward,
(
grad_out,
q,
k,
v,
out,
lse,
cu_seq,
cu_seq,
shape.max_seq_len,
shape.max_seq_len,
False,
rng_state,
),
test_utils=["test_schema", "test_faketensor"],
)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
varlen_grad_out = torch.ones_like(output)
_ = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
called_ops = mode.called_ops
custom_ops_called = any(
"torch_attn._varlen_attn" in op for op in called_ops
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
assert custom_ops_called
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@ -172,7 +338,10 @@ class TestVarlenAttention(NNTestCase):
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"], is_causal=is_causal
variable_length_batch_data["x_padded"],
variable_length_batch_data["seq_lengths"],
dtype=dtype,
is_causal=is_causal,
)
tolerances = default_tolerances[dtype]
@ -186,6 +355,44 @@ class TestVarlenAttention(NNTestCase):
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
start_idx = end_idx
varlen_grad_out = torch.ones_like(varlen_output)
sdpa_grad_out = torch.zeros_like(sdpa_output)
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
start_idx = end_idx
varlen_grad = torch.autograd.grad(
outputs=varlen_output,
inputs=variable_length_batch_data["x_packed"],
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
sdpa_grad = torch.autograd.grad(
outputs=sdpa_output,
inputs=variable_length_batch_data["x_padded"],
grad_outputs=sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_grad_seq = varlen_grad[start_idx:end_idx]
sdpa_grad_seq = sdpa_grad[i, :seq_len]
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
start_idx = end_idx
device_types = ("cuda",)

View File

@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True
# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = not is_fbcode()
enable_cpp_symbolic_shape_guards = False
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

View File

@ -42,7 +42,7 @@ import weakref
from dataclasses import dataclass
from enum import Enum
from os.path import dirname, join
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union
from unittest.mock import patch
import sympy
@ -395,6 +395,13 @@ class OptimizedModule(torch.nn.Module):
self._initialize()
self.training = self._orig_mod.training
def __len__(self) -> int:
# Proxy the len call to the original module
if isinstance(self._orig_mod, Sized):
return len(self._orig_mod)
# Mimic python's default behavior for objects without a length
raise TypeError(f"{type(self._orig_mod).__name__} does not support len()")
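# Hedged sketch (comment added for clarity, not part of the PR): with the __len__
# proxy above, a compiled container reports the same length as the wrapped module,
# while modules without __len__ raise TypeError, mirroring Python's default:
#   seq = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
#   assert len(torch.compile(seq)) == len(seq) == 2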
def _initialize(self) -> None:
# Do this stuff in constructor to lower overhead slightly
if isinstance(self.dynamo_ctx, DisableContext):

View File

@ -2820,5 +2820,36 @@
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0280": [
{
"Gb_type": "1-arg super not implemented",
"Context": "",
"Explanation": "Dynamo failed to trace attribute `{name}` accessed via `super()` (for type `{self.typevar}` and object `{self.objvar}`) because one-argument of super() is not supported.",
"Hints": [
"Use two-argument super(type, object_or_type)."
]
}
],
"GB0281": [
{
"Gb_type": "Invalid or non-const argument in nn.Module __getitem__",
"Context": "call_method: {self} {name} {args} {kwargs}",
"Explanation": "Dynamo does not support calling method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
"Hints": [
"Use constant arguments of type str or int for __getitem__"
]
}
],
"GB0282": [
{
"Gb_type": "Placement with custom __getattr__ not supported",
"Context": "{value_type.__name__} with custom __getattr__",
"Explanation": "Dynamo does not support Placement types with custom __getattr__ methods",
"Hints": [
"Use Placement types without custom __getattr__ methods",
"Move the Placement usage outside the compiled region"
]
}
]
}

View File

@ -210,9 +210,16 @@ class PlacementVariable(DistributedVariable):
if name in constant_fold_functions:
try:
value_type = type(self.value)
assert (
inspect.getattr_static(value_type, "__getattr__", None) is None
), "no custom getattr allowed!"
if inspect.getattr_static(value_type, "__getattr__", None) is not None:
unimplemented_v2(
gb_type="Placement with custom __getattr__ not supported",
context=f"{value_type.__name__} with custom __getattr__",
explanation="Dynamo does not support Placement types with custom __getattr__ methods",
hints=[
"Use Placement types without custom __getattr__ methods",
"Move the Placement usage outside the compiled region",
],
)
method = inspect.getattr_static(value_type, name)
except AttributeError:
method = None

View File

@ -103,7 +103,17 @@ class SuperVariable(VariableTracker):
codegen.extend_output(create_call_function(1, False))
def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
assert self.objvar, "1-arg super not implemented"
if not self.objvar:
unimplemented_v2(
gb_type="1-arg super not implemented",
context="",
explanation=f"Dynamo failed to trace attribute `{name}` accessed "
f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
"because one-argument of super() is not supported.",
hints=[
"Use two-argument super(type, object_or_type).",
],
)
search_type = self.typevar.as_python_constant()
# The rest of this function does two things:

View File

@ -822,9 +822,19 @@ class NNModuleVariable(VariableTracker):
)
if type(module).__getitem__ not in builtin_supported:
assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
key = args[0].as_python_constant()
assert isinstance(key, (str, int))
if not (
isinstance(args[0], variables.ConstantVariable)
and isinstance(args[0].as_python_constant(), (str, int))
):
unimplemented_v2(
gb_type="Invalid or non-const argument in nn.Module __getitem__",
context=f"call_method: {self} {name} {args} {kwargs}",
explanation="Dynamo does not support calling "
f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
hints=[
"Use constant arguments of type str or int for __getitem__"
],
)
fn = getattr(module, name).__func__
assert isinstance(fn, types.FunctionType)

View File

@ -1793,14 +1793,6 @@ def _aot_stage2b_bw_compile(
# tensor which is wrong.
ph_size = ph_arg.size()
# pyrefly: ignore # bad-argument-type
if len(ph_size) == 0 and len(real_stride) > 0:
# Fix for 0-dimensional tensors: When a tensor becomes 0-d
# (e.g., via squeeze), its stride should be () not (1,).
# This mismatch can occur when dynamic shape operations produce
# tensors that are later squeezed to 0-d. The stride metadata
# may get preserved causing a dimension mismatch (#164814)
real_stride = ()
# pyrefly: ignore # bad-argument-type
placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride)

View File

@ -720,13 +720,22 @@ def check_shape(
) -> None:
backend = get_current_backend()
assert shape is not None
if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
if config.test_configs.runtime_triton_shape_assert and backend == "triton":
shape_str = (
", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
)
buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")
def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
backend = get_current_backend()
if backend == "triton":
msg = "NaN or Inf found"
buffer.writeline(
f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
)
class DataTypePropagation:
def __init__(self, body: LoopBody) -> None:
self.body = body
@ -2623,6 +2632,9 @@ class CSEProxy(DefaultHandler):
assert output_shape is not None
check_shape(V.kernel.compute, csevar, output_shape)
if config.runtime_triton_nan_asserts:
check_nan(V.kernel.compute, csevar)
return csevar
return pytree.tree_map(do_cse, value)

View File

@ -626,7 +626,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -206,6 +206,9 @@ static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
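# Hedged usage sketch (comment added for clarity, not part of the PR): the flag can
# also be flipped from Python before compiling, assuming torch._inductor.config has
# been imported first:
#   import torch._inductor.config as inductor_config
#   inductor_config.runtime_triton_nan_asserts = True
#   torch.compile(lambda x: x.sin() + 1)(torch.randn(8, device="cuda"))
# The generated Triton kernel then carries the tl.device_assert NaN/Inf checks
# emitted by the check_nan helper added earlier in this diff.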
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"
# Disable by default in fbcode

View File

@ -3550,24 +3550,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not inductor_meta.get("autotune_pointwise", True) and not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -409,9 +409,10 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
class BaseSchedulerNode:
ancestors: OrderedSet[str]
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
last_usage: OrderedSet[str]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is the X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@ -420,22 +421,24 @@ class BaseSchedulerNode:
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
mutation_renames: dict[str, str]
node: Optional[ir.Operation]
outputs: list[SchedulerBuffer]
outputs_by_name: dict[str, SchedulerBuffer]
override_estimated_runtime: Optional[float] = None
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)
self.scheduler = scheduler
self.debug_device_str = lambda *args, **kwargs: []
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
self.ancestors: OrderedSet[str] = OrderedSet()
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
self.node = node
self.ancestors = OrderedSet()
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
self.written = False
self.outputs: list[SchedulerBuffer] = [
self.outputs = [
SchedulerBuffer(
scheduler=self.scheduler,
node=output,
@ -443,16 +446,14 @@ class BaseSchedulerNode:
)
for output in node.get_outputs()
]
self.outputs_by_name: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for buf in self.outputs
}
self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
# mutation_renames for the current node. Because more mutations
# may happen later, this can be different from
# Scheduler.mutation_renames. Also this dict should be small
# since only mutation information relevant to the deps for this
# node is stored here.
self.mutation_renames: dict[str, str] = {}
self.mutation_renames = {}
def __repr__(self) -> str:
return f"{type(self).__name__}(name={self.get_name()!r})"
@ -2435,6 +2436,34 @@ def pick_loop_order(
return order
def _replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
@dataclasses.dataclass
class NodeUser:
node: Union[BaseSchedulerNode, OutputNode]
@ -3336,33 +3365,6 @@ class Scheduler:
will force completion of compilation and benchmarking.
"""
def replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
@ -3416,40 +3418,47 @@ class Scheduler:
assign_origin_node(out_tensorbox, multi_node.origin_node)
out_buffer.layout = multi_node.layout
replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self._replace_node(out_buffer, multi_node, i, node)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def _replace_node(
self,
out_buffer: ir.OperationBuffer,
multi_node: ir.MultiTemplateBuffer,
i: int,
node: SchedulerNode,
) -> None:
_replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(
node.read_writes.reads, node.unmet_dependencies
):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
return any(

View File

@ -1,13 +1,9 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from dataclasses import dataclass, field
from typing import cast, Optional
from typing import cast
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._local_tensor import maybe_run_for_local_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
@ -19,8 +15,8 @@ from torch.distributed.tensor._ops.utils import (
register_op_strategy,
)
from torch.distributed.tensor.placement_types import (
MaskPartial,
Partial,
Placement,
Replicate,
Shard,
)
@ -29,190 +25,6 @@ from torch.distributed.tensor.placement_types import (
aten = torch.ops.aten
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0
def materialize_mask(self, mask):
if self.refcount == 0:
self.data = mask
else:
assert self.data is not None
if not torch.equal(self.data, mask):
raise RuntimeError(
"MaskBuffer has been materialized with conflicting data"
)
self.refcount += 1
def release_mask(self):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
self.refcount -= 1
if self.refcount == 0:
self.data = None
def apply_mask(self, tensor):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
# NOTE: _MaskPartial is being used by the embedding op and the gather op.
# For gather, the mask has the same dimension as the output tensor, whereas
# the output of the embedding op has an additional dimension compare to the input,
# hence the output masking logic below having two different cases.
if tensor.ndim == self.data.ndim:
tensor[self.data] = 0.0
else:
tensor[self.data, :] = 0.0
@dataclass(frozen=True)
class _MaskPartial(Partial):
"""
A partial mask placement devised for rowwise sharded embedding op, where we need
to mask and adjust the indices to the local embedding shard, embedding masking
is a special type of the Partial placement
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
"""
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def __init__(
self,
reduce_op=None,
mask_buffer=None,
offset_shape=None,
offset_dim=0,
*args,
**kwargs,
):
super().__init__(reduce_op)
if mask_buffer is None:
mask_buffer = MaskBuffer()
object.__setattr__(self, "mask_buffer", mask_buffer)
object.__setattr__(self, "offset_shape", offset_shape)
object.__setattr__(self, "offset_dim", offset_dim)
@staticmethod
@maybe_run_for_local_tensor
def _mask_tensor(
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
# Build the input mask and save it for the current partial placement
# this is so that the output of embedding op can reuse the same partial
# placement saved mask to perform mask + reduction
mask = (tensor < local_offset_on_dim) | (
tensor >= local_offset_on_dim + local_shard_size
)
# mask the input tensor
masked_tensor = tensor.clone() - local_offset_on_dim
masked_tensor[mask] = 0
return mask, masked_tensor
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
my_coordinate = mesh.get_coordinate()
assert my_coordinate is not None, "my_coordinate should not be None"
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert self.offset_shape is not None, (
"offset_shape needs to be set for _MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
self.offset_shape[self.offset_dim],
num_chunks,
my_coordinate[mesh_dim],
)
mask, masked_tensor = _MaskPartial._mask_tensor(
tensor, local_offset_on_dim, local_shard_size
)
# materialize the mask buffer to be used for reduction
self.mask_buffer.materialize_mask(mask)
return masked_tensor
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# perform sum reduction
return funcol.all_reduce(
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
)
def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# call reduce_shard_tensor of the shard_spec.
shard_spec = cast(Shard, shard_spec)
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
def __eq__(self, other: object) -> bool:
if not isinstance(other, _MaskPartial):
return False
# if either data is not None, we invalidate the sharding cache, as this indicates
# the current MaskPartial placement is still in use and should not be used for cache hit.
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
return False
return (
self.reduce_op == other.reduce_op
and self.offset_shape == other.offset_shape
and self.offset_dim == other.offset_dim
)
def __hash__(self) -> int:
return 1 + hash(
(
self.reduce_op,
self.offset_shape,
self.offset_dim,
)
)
def __repr__(self) -> str:
"""
machine readable representation of the MaskPartial placement
"""
return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
def __str__(self) -> str:
"""
human readable representation of the MaskPartial placement
"""
return "MaskP"
@register_op_strategy(aten.embedding.default)
def embedding_strategy(op_schema: OpSchema) -> StrategyType:
"""
@ -239,7 +51,7 @@ def embedding_strategy(op_schema: OpSchema) -> StrategyType:
single_mesh_dim_strategies.append(colwise_sharding)
# rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
embedding_partial_placement = MaskPartial(offset_shape=weight_shape, offset_dim=0)
# NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
# from the input indices and use it for output reduction

View File

@ -0,0 +1,44 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0
def materialize_mask(self, mask):
if self.refcount == 0:
self.data = mask
else:
assert self.data is not None
if not torch.equal(self.data, mask):
raise RuntimeError(
"MaskBuffer has been materialized with conflicting data"
)
self.refcount += 1
def release_mask(self):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
self.refcount -= 1
if self.refcount == 0:
self.data = None
def apply_mask(self, tensor):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
# NOTE: MaskPartial is being used by the embedding op and the gather op.
# For gather, the mask has the same dimension as the output tensor, whereas
# the output of the embedding op has an additional dimension compared to the input,
# hence the output masking logic below having two different cases.
if tensor.ndim == self.data.ndim:
tensor[self.data] = 0.0
else:
tensor[self.data, :] = 0.0
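# Hedged sketch (added here for clarity, not part of the PR): the intended refcounted
# lifecycle of a shared MaskBuffer. The helper name and values are illustrative only.
def _mask_buffer_example() -> None:
    buf = MaskBuffer()
    mask = torch.tensor([True, False, True])
    buf.materialize_mask(mask)   # refcount 0 -> 1, stores the mask
    buf.materialize_mask(mask)   # second user with identical data, refcount -> 2
    t = torch.ones(3)
    buf.apply_mask(t)            # zeroes the masked positions of t in place
    buf.release_mask()           # refcount 2 -> 1, data kept for the other user
    buf.release_mask()           # refcount 1 -> 0, data dropped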

View File

@ -17,7 +17,7 @@ from torch.distributed.tensor._op_schema import (
TupleStrategy,
)
from torch.distributed.tensor._ops._common_rules import pointwise_rule
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
generate_redistribute_costs,
@ -646,7 +646,7 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType:
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if dim < len(index_shape) and index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
index_partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),

View File

@ -11,7 +11,7 @@ from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
from torch.distributed.tensor._ops._math_ops import (
_skip_dim,
Reduction,
@ -236,7 +236,7 @@ def _nll_loss_forward(
# The following code block is a distributed version of
# result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target_partial_ = partial_placement._partition_value(
safe_target_, mesh, mesh_dim
)
@ -375,7 +375,7 @@ def _nll_loss_and_log_softmax_backward(
# The following code block is a distributed version of
# grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target = safe_target.squeeze(channel_dim).flatten()
masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
# only update grad_input to -1 if not masked

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass, field
from typing import cast, Optional
import torch
@ -17,9 +18,10 @@ from torch.distributed.tensor._collective_utils import (
shard_dim_alltoall,
unpad_tensor,
)
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
__all__ = ["Placement", "Shard", "Replicate", "Partial", "MaskPartial"]
# Appease TestPublicBindings.test_correct_module_names
@ -841,3 +843,149 @@ class Partial(torch._C._distributed.Partial):
# We keep the old _Partial name for a while for BC reason
_Partial = Partial
@dataclass(frozen=True)
class MaskPartial(Partial):
"""
A partial mask placement devised for the rowwise sharded embedding op, where we need
to mask and adjust the indices to the local embedding shard; embedding masking
is a special type of the Partial placement.
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
"""
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def __init__(
self,
reduce_op=None,
mask_buffer=None,
offset_shape=None,
offset_dim=0,
*args,
**kwargs,
):
super().__init__(reduce_op)
if mask_buffer is None:
mask_buffer = MaskBuffer()
object.__setattr__(self, "mask_buffer", mask_buffer)
object.__setattr__(self, "offset_shape", offset_shape)
object.__setattr__(self, "offset_dim", offset_dim)
@staticmethod
@maybe_run_for_local_tensor
def _mask_tensor(
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
# Build the input mask and save it for the current partial placement
# this is so that the output of embedding op can reuse the same partial
# placement saved mask to perform mask + reduction
mask = (tensor < local_offset_on_dim) | (
tensor >= local_offset_on_dim + local_shard_size
)
# mask the input tensor
masked_tensor = tensor.clone() - local_offset_on_dim
masked_tensor[mask] = 0
return mask, masked_tensor
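# Worked example (comment added for clarity, not part of the PR): with
# local_offset_on_dim=4 and local_shard_size=4, indices [1, 5, 7] give
# mask [True, False, False] and masked values [0, 1, 3]: the out-of-shard
# index is zeroed and the in-shard indices are shifted into local coordinates.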
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
my_coordinate = mesh.get_coordinate()
assert my_coordinate is not None, "my_coordinate should not be None"
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert self.offset_shape is not None, (
"offset_shape needs to be set for MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
self.offset_shape[self.offset_dim],
num_chunks,
my_coordinate[mesh_dim],
)
mask, masked_tensor = MaskPartial._mask_tensor(
tensor, local_offset_on_dim, local_shard_size
)
# materialize the mask buffer to be used for reduction
self.mask_buffer.materialize_mask(mask)
return masked_tensor
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# perform sum reduction
return funcol.all_reduce(
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
)
def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# call reduce_shard_tensor of the shard_spec.
shard_spec = cast(Shard, shard_spec)
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
def __eq__(self, other: object) -> bool:
if not isinstance(other, MaskPartial):
return False
# if either data is not None, we invalidate the sharding cache, as this indicates
# the current MaskPartial placement is still in use and should not be used for cache hit.
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
return False
return (
self.reduce_op == other.reduce_op
and self.offset_shape == other.offset_shape
and self.offset_dim == other.offset_dim
)
def __hash__(self) -> int:
return 1 + hash(
(
self.reduce_op,
self.offset_shape,
self.offset_dim,
)
)
def __repr__(self) -> str:
"""
machine readable representation of the MaskPartial placement
"""
return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
def __str__(self) -> str:
"""
human readable representation of the MaskPartial placement
"""
return "MaskP"

View File

@ -14,14 +14,11 @@ from torch.backends.cuda import (
SDPAParams,
)
from .varlen import varlen_attn
__all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"varlen_attn",
]
# Note: [SDPA warnings]

View File

@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.
import logging
from functools import lru_cache
from typing import NamedTuple, Optional, Union
from typing import Any, NamedTuple, Optional, Union
import torch
@ -33,8 +33,7 @@ class AuxRequest(NamedTuple):
lse: bool = False
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
@ -44,7 +43,7 @@ def _varlen_attn(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention.
@ -70,7 +69,7 @@ def _varlen_attn(
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse = result[0], result[1]
output, softmax_lse, rng_state = result[0], result[1], result[6]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@ -86,10 +85,13 @@ def _varlen_attn(
return_debug_mask=False,
)
return output, softmax_lse
rng_state_ = torch.zeros(
(2,), dtype=torch.uint64, device=query.device
) # hardcoded since dropout is hardcoded to 0
return output, softmax_lse, rng_state_
# @_varlen_attn.register_fake
@_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
@ -99,7 +101,7 @@ def _varlen_attn_fake(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
@ -117,7 +119,9 @@ def _varlen_attn_fake(
(num_heads, total_q), dtype=torch.float, device=query.device
)
return output, logsumexp
rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
return output, logsumexp, rng_state
def varlen_attn(
@ -191,9 +195,145 @@ def varlen_attn(
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
"""
out, lse = _varlen_attn(
out, lse, _ = torch.ops.torch_attn._varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out
def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
out, lse, rng_state = output
ctx.query = query
ctx.key = key
ctx.value = value
ctx.cu_seq_q = cu_seq_q
ctx.cu_seq_k = cu_seq_k
ctx.max_q = max_q
ctx.max_k = max_k
ctx.is_causal = is_causal
ctx.output = out
ctx.lse = lse
ctx.rng_state = rng_state
@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
unused = torch.empty(0, device=query.device)
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
else:
log.info("Using Flash Attention backend for varlen_attn")
dq, dk, dv = torch.ops.aten._flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
return dq, dk, dv
@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
"""
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
return grad_query, grad_key, grad_value
def _backward(
ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
query = ctx.query
key = ctx.key
value = ctx.value
cu_seq_q = ctx.cu_seq_q
cu_seq_k = ctx.cu_seq_k
max_q = ctx.max_q
max_k = ctx.max_k
is_causal = ctx.is_causal
out = ctx.output
lse = ctx.lse
rng_state = ctx.rng_state
# rng_state = torch.empty(2, device=query.device)
dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
is_causal,
rng_state,
)
return dq, dk, dv, None, None, None, None, None, None
_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
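# Editorial sketch, not part of this diff: the registration pattern used above
# in miniature. torch.library.custom_op defines the op, register_fake supplies
# shapes/dtypes for meta tensors and tracing, and register_autograd wires the
# backward through setup_context. "mylib::scaled_add" and its body are hypothetical.
@torch.library.custom_op("mylib::scaled_add", mutates_args={})
def _scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    return x + alpha * y

@_scaled_add.register_fake
def _scaled_add_fake(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    return torch.empty_like(x)  # only shape/dtype matter under FakeTensor

def _scaled_add_setup(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
    _, _, alpha = inputs
    ctx.alpha = alpha  # stash non-tensor data for the backward

def _scaled_add_bwd(ctx: Any, grad_out: torch.Tensor) -> tuple[Optional[torch.Tensor], ...]:
    # one gradient slot per input; the non-tensor alpha gets None
    return grad_out, grad_out * ctx.alpha, None

_scaled_add.register_autograd(_scaled_add_bwd, setup_context=_scaled_add_setup)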


@@ -74,6 +74,17 @@ def export_compat(
if opset_version is None:
opset_version = onnx_constants.ONNX_DEFAULT_OPSET
if isinstance(model, torch.nn.Module):
if model.training:
warnings.warn(
"Exporting a model while it is in training mode. "
"Please ensure that this is intended, as it may lead to "
"different behavior during inference. "
"Calling model.eval() before export is recommended.",
UserWarning,
stacklevel=2,
)
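        # Editorial sketch, not part of this diff: silencing the warning above by
        # switching to eval mode before exporting (module and file name are illustrative):
        #     model = torch.nn.Linear(4, 2)
        #     model.eval()  # freeze dropout / batch-norm behavior for inference
        #     torch.onnx.export(model, (torch.randn(1, 4),), "linear.onnx")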
if isinstance(model, torch.export.ExportedProgram):
# We know the model is already exported program, so the args, kwargs, and dynamic_shapes
# are not used


@@ -812,7 +812,6 @@ if torch.backends.mps.is_available():
"__rmod__",
"__rsub__",
"__rpow__",
"bernoulli",
"clamp_max",
"clamp_min",
"masked_scatter",


@@ -447,6 +447,56 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
def ceil_div(a, b):
return (a + b - 1) // b
# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
# for non-gemm purposes (like testing), we need to de-swizzle them; then they can be applied much
# more naturally.
def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
# Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
# Output should be [input.size(0), input.size(1) // blocksize] scales
output_scales = torch.zeros(
(input.size(0), input.size(1) // blocksize),
device=input.device,
dtype=input_scales.dtype,
)
# Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
# happened for offset purposes.
# There are K//blocksize scales, padded to groups of 4.
num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)
# (Very) slow reference implementation using horrifying loops.
for i in range(input.size(0)):
for j in range(input.size(1) // blocksize):
# which 128x4 tile of scaling factors am I in
scale_tile_h = i // 128
scale_tile_w = j // 4
# There are (padded) input_scales.size(1) // 4 tiles along the w dim.
# So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)
# indices within the tile - use nomenclature directly from cublas docs
outer = i % 128 # "outer" in cublas docs
inner = j % 4 # "inner" in cublas docs
# Note: "offset" is given in terms of bytes, in cublas docs, but our scales are e8m0,
# anyway, and so 1B == 1 value => use offset directly.
# Formula directly from cublas docs in 3.1.4.3.2
offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner
output_scales[i, j] = input_scales[offset]
return output_scales
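# Editorial worked example, not part of this diff: the offset formula above for
# the scale at row i=130, scale-column j=5 when there are K // blocksize == 8
# scale columns (so num_col_tiles = ceil_div(8, 4) = 2):
#   scale_tile_h = 130 // 128 = 1,  scale_tile_w = 5 // 4 = 1
#   tile_offset  = 512 * (1 * 2 + 1) = 1536
#   outer = 130 % 128 = 2,  inner = 5 % 4 = 1
#   offset = 1536 + (2 % 32) * 16 + (2 // 32) * 4 + 1 = 1569
# so output_scales[130, 5] is read from input_scales[1569].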
def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
# expand scales
scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)
# de-scale and convert
x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
return x_f32.to(torch.bfloat16)
def to_blocked(input_matrix) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.


@@ -7,7 +7,8 @@ import inspect
import sys
import warnings
from collections.abc import Callable
from typing import Any, cast, TypeVar
from typing import Any, cast, overload, TypeVar
from typing_extensions import Self
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@@ -158,7 +159,12 @@ class _DecoratorContextManager:
class _NoParamDecoratorContextManager(_DecoratorContextManager):
"""Allow a context manager to be used as a decorator without parentheses."""
def __new__(cls, orig_func=None):
@overload
def __new__(cls, orig_func: F) -> F: ... # type: ignore[misc]
@overload
def __new__(cls, orig_func: None = None) -> Self: ...
def __new__(cls, orig_func: F | None = None) -> Self | F: # type: ignore[misc]
if orig_func is None:
return super().__new__(cls)
return cls()(orig_func)
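# Editorial sketch, not part of this diff: what the two overloads above describe
# for a concrete subclass such as torch.no_grad. Without parentheses the class
# receives the function and returns the wrapped function (typed F); with
# parentheses it returns a context-manager instance (typed Self).
import torch

@torch.no_grad  # bare decorator: __new__(cls, orig_func) -> wrapped function
def _example_infer(x: torch.Tensor) -> torch.Tensor:
    return x * 2

with torch.no_grad():  # call form: __new__(cls) -> context-manager instance
    _ = _example_infer(torch.ones(3))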


@@ -311,7 +311,11 @@ def escape(n):
def is_cuda_tensor(obj):
return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
return (
isinstance(obj, torch.Tensor) and
obj.device.type == "cuda" and
not isinstance(obj, torch._subclasses.FakeTensor)
)
def cuda_allocation_context():
snapshot = torch.cuda.memory._snapshot()