Compare commits


44 Commits

Author SHA1 Message Date
b552a4eba1 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 16:43:15 -07:00
b3e120665b Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 16:43:15 -07:00
0d4992c170 [dynamo][easy] Use CONSTANT_MATCH for __code__ guard (#166445)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166445
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166437, #166444
2025-10-28 23:19:42 +00:00
b060e5c131 [dynamo] Move more FUNCTION_MATCH to CLOSURE_MATCH (#166444)
Closure match is more relaxed than function match, which is an ID match.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166444
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166437
2025-10-28 23:19:42 +00:00
6d5e651a50 [user-streams] update stream context to use fork/join (#162903)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162903
Approved by: https://github.com/anijain2305
2025-10-28 23:12:05 +00:00
3cc5949dc2 Remove global pytree registration for blockmask (#166434)
The global pytree registration of `BlockMask` was added in https://github.com/pytorch/pytorch/pull/166045

In general, people assume `BlockMask` is a leaf, so the global registration could lead to unexpected failures when calling `tree_map()` on a `BlockMask`, since it would now flatten all the way down.

Therefore, we remove the global registration but keep the `_flatten()` and `_unflatten()` classmethods. Users can easily do a local registration when needed, as sketched below.
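
For reference, a minimal local registration looks like the snippet below (mirroring the test change in this PR); it assumes the `_flatten`/`_unflatten`/`_flatten_with_keys` classmethods kept by this change:

```
from torch.nn.attention.flex_attention import BlockMask
from torch.utils._pytree import register_pytree_node

# Opt-in, local registration: after this, tree_map() recurses into BlockMask
# instead of treating it as a leaf.
register_pytree_node(
    BlockMask,
    BlockMask._flatten,
    BlockMask._unflatten,
    flatten_with_keys_fn=BlockMask._flatten_with_keys,
    serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)
```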

In pytorch:
```
python test/distributed/tensor/test_dtensor_export.py -k test_flex_attention_dtensor_export
```

In torchtitan:
```
python -m tests.integration_tests.run_tests ./outputs --test_suite features --ngpu 8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166434
Approved by: https://github.com/wwwjn
2025-10-28 23:11:52 +00:00
f167fd09fa [annotation] Override metadata on regenerated node in functional mode (#166200)
Fixes #165810

If we regenerate a node during functionalization, we override the "stack_trace", "custom", and "seq_nr" metadata of the regenerated node with the node meta of the original node.
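
A minimal sketch of the annotation API whose metadata is now preserved (mirroring the test added in this PR; `fx_traceback` is `torch.fx.traceback`, and the `pp_stage` key is just an example payload):

```
import torch
import torch.fx.traceback as fx_traceback

def f(x):
    # Nodes created under this context carry {"pp_stage": 0} in their "custom"
    # metadata; with this PR, nodes regenerated during functionalization keep
    # it, along with "stack_trace" and "seq_nr".
    with fx_traceback.annotate({"pp_stage": 0}):
        y = x[:-1]
    return y.cos()
```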

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_annotate_replay_view
python test/functorch/test_aotdispatch.py TestAOTAutogradWithDynamo.test_duplicated_arguments_on_tensor_overlap
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166200
Approved by: https://github.com/bdhirsh
2025-10-28 22:59:39 +00:00
68b3984b77 [xpu][test] Enable skipped SparseAdam UTs (#166375)
With `SparseAdam` now correctly supported on Intel GPU, the previously disabled UTs can be enabled.
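
A rough usage sketch, assuming an XPU build with a device available:

```
import torch

# SparseAdam requires sparse gradients, e.g. from an embedding with sparse=True.
emb = torch.nn.Embedding(10, 4, sparse=True, device="xpu")
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)

loss = emb(torch.tensor([1, 2, 3], device="xpu")).sum()
loss.backward()  # produces sparse gradients
opt.step()
```
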
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166375
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
2025-10-28 22:49:25 +00:00
a1eb6b5538 [dynamo][guards] Do not guard on the queue_callback (#166437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166437
Approved by: https://github.com/xmfan
2025-10-28 22:37:38 +00:00
f36f372acc bwd pass (#164504)
**Summary**
This implements the backward pass for the Varlen API and registers `_varlen_attn()` as a custom op.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

| Metric  | Variable Length API | SDPA      |
|---------|---------------------|-----------|
| Runtime | 0.819 ms            | 3.264 ms  |
| TFLOPs  | 268.652             | 158.731   |

We can see that the Varlen runtime is >3x faster than padded SDPA.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests, where we verify basic functionality and confirm that varlen gradients numerically match SDPA.

For custom op testing, `test_custom_op_registration` uses a logging mode to verify that `_varlen_attn()` was called and that it works with `torch.compile`. `test_custom_op_compliances` uses `torch.library.opcheck()` to verify op compliance.
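
For context, a rough call sketch of the API exercised here (shapes, dtypes, and the int32 `cu_seq` convention are assumptions inferred from the test file):

```
import torch
from torch.nn.attention.varlen import varlen_attn

device, dtype = "cuda", torch.bfloat16
num_heads, head_dim = 16, 64

# Packed layout: all sequences concatenated along dim 0, with cumulative
# sequence lengths marking the per-sequence boundaries.
seq_lens = torch.tensor([128, 64, 256], device=device)
cu_seq = torch.zeros(len(seq_lens) + 1, device=device, dtype=torch.int32)
cu_seq[1:] = torch.cumsum(seq_lens, dim=0)
total, max_len = int(cu_seq[-1]), int(seq_lens.max())

q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
out.sum().backward()  # backward pass added by this PR
```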

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
Approved by: https://github.com/drisspg
2025-10-28 22:35:11 +00:00
d9483d4c8d [dynamo] Clean up assert in dynamo [3/N] (#165903)
Some previous PRs have been merged. This PR targets **assert** statements that users can trigger, which may be better turned into graph breaks. Correct me if there are any problems.

* -> #165903 (Clean up for graph break)
* #165745
* #165430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165903
Approved by: https://github.com/williamwen42

Co-authored-by: William Wen <william.wen42@gmail.com>
2025-10-28 22:29:35 +00:00
fea819ed08 added type annotation to _NoParamDecoratorContextManager.__new__ (#166414)
Fixes #166413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166414
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-28 21:59:20 +00:00
84a2715d34 [dynamo] Revert C++-fying of symbolic shape guards (#166427)
Moving symbolic shape guards to C++ causes compile time issues. This basically boils down to a tradeoff question.

For models with a large number of dynamic shape guards, this flag helps reduce guard latency. But most models have only a few dynamic shape guards, so their guard latency is already small; they would still see a high compile time hit from invoking gcc during compilation.

So a good default value seems to be False. We can write a doc to give guidance on reducing guard latency.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166427
Approved by: https://github.com/zou3519
2025-10-28 21:57:31 +00:00
572cc12b42 Move MaskPartial to placement_types to improve discoverability (#164414)
Had trouble finding this one myself in #163030.
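
For reference, the class is now importable from the public placement-types module, as the updated tests in this diff do:

```
# Old, private location:
#   from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
# New, discoverable location after this PR:
from torch.distributed.tensor.placement_types import MaskPartial
```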

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164414
Approved by: https://github.com/ezyang
2025-10-28 21:56:02 +00:00
1fdef664a5 Revert "[Pytorch] Update Kineto Submodule (#166317)"
This reverts commit be283297100ab86123e74b7a8372995d32b140c8.

Reverted https://github.com/pytorch/pytorch/pull/166317 on behalf of https://github.com/jeffdaily due to ROCm CI was clean, but post-merge ROCm failures showed up ([comment](https://github.com/pytorch/pytorch/pull/166317#issuecomment-3458665809))
2025-10-28 21:55:38 +00:00
08ae55021e support batch size=0 for flash attention (#166318)
Fixes #165944

**Summary**

Today, if we attempt to run flash attention with `batch_size=0`, we get the error `RuntimeError: batch size must be positive`. This PR fixes that by returning early with empty tensors in the forward and backward passes.

**Test plan**
`python test/test_transformers.py -k test_scaled_dot_product_attention` - added case for batch_size=0
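
A minimal sketch of the now-supported edge case (flash backend, zero-batch inputs; shapes are illustrative):

```
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# batch_size=0 no longer raises "batch size must be positive".
q = torch.randn(0, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(0, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(0, 8, 128, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([0, 8, 128, 64])
```
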
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166318
Approved by: https://github.com/drisspg
2025-10-28 21:53:48 +00:00
96a8d1c5e0 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:37:38 -07:00
39307c3db2 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:37:38 -07:00
3d6061d56a Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:34:54 -07:00
dc55769bb6 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:34:54 -07:00
551921d484 Change t.is_cuda to t.device.type == 'cuda' in torch/utils/viz (#156418)
Fixes #156417

Unlike `.is_cuda`, the `.device` property is supported by `ShardedTensor`.
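
Illustrative sketch of the equivalent check (shown on a plain tensor; the point of the change is that it also works for wrapper types like `ShardedTensor`, which expose `.device` but not `.is_cuda`):

```
import torch

t = torch.randn(2, 2)
# Before: t.is_cuda
# After:  works for any object that exposes a torch.device
on_cuda = t.device.type == "cuda"
```
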
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156418
Approved by: https://github.com/mikaylagawarecki

Co-authored-by: Alexander Zhipa <azzhipa@amazon.com>
2025-10-28 20:34:14 +00:00
2c74beddf6 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:33:32 -07:00
12ff17857e Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:33:32 -07:00
b5189e269e NVFP4 grouped gemm support via. FBGEMM kernels (#166308)
Summary:

* Add NVFP4 (1x16 block e4m3, tensor-wise fp32) scaled grouped gemm
* Extend testing to add nvfp4 support

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166308
Approved by: https://github.com/ngimel
2025-10-28 20:32:53 +00:00
3895ce093f [inductor] add in-kernel nan-check (#166008)
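
A rough usage sketch, assuming the `runtime_triton_nan_asserts` inductor config flag exercised by the new test in this diff (the flag name comes from that test):

```
import torch
from torch._inductor import config

def fn(x):
    x = x - 1
    return torch.where(x.isnan(), 3.14, x)

# With the flag on, generated Triton kernels assert on NaN/Inf intermediates
# ("NaN or Inf found" device-side assert).
with config.patch(runtime_triton_nan_asserts=True):
    compiled = torch.compile(fn)
    out = compiled(torch.randn(4096, device="cuda"))
```
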
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166008
Approved by: https://github.com/eellison
2025-10-28 20:19:10 +00:00
8aa087a29d [ez] Fix print for failing test when entire file fails (#166420)
Was previously printing "FAILED CONSISTENTLY: ul" since the recorded failure was null.
This changes it to print the test_file instead, by moving some of the checking logic earlier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166420
Approved by: https://github.com/Skylion007
2025-10-28 20:13:58 +00:00
7379972cc0 Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit cdb60e44eb528bf02c6bb2d7e384298283e755ca.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/xmfan due to Compile time regression ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3458252331))
2025-10-28 20:01:54 +00:00
4ae3c59ce2 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 12:19:15 -07:00
7a8ad5f874 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 12:19:15 -07:00
dd09fa089d Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:37:16 -07:00
0995593caa Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:37:16 -07:00
69a4358a01 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:24:52 -07:00
0ab9e050ab Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:24:52 -07:00
651e9dbf94 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 14:08:25 -07:00
56bd4c695a Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 14:08:24 -07:00
1cb7be9419 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:34:38 -07:00
00f68803d3 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:34:38 -07:00
de1f732075 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:31:13 -07:00
cbfee32779 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:31:13 -07:00
0e38867920 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:57:54 -07:00
cefd269c35 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:57:54 -07:00
7fcf3a1488 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:28:36 -07:00
9a88bd06e1 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:28:36 -07:00
ccc9750df1 [DONT MERGE] Get rid of FUNCTION_MATCH
[ghstack-poisoned]
2025-10-27 11:38:49 -07:00
47 changed files with 1441 additions and 530 deletions

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
@ -291,6 +291,7 @@ IF(USE_FBGEMM_GENAI)
set(fbgemm_genai_cuh
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/f4f4bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)

View File

@ -208,6 +208,48 @@ _f8_f8_bf16_rowwise_grouped_mm(
#endif
}
Tensor&
_f4_f4_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& global_scale_a,
const Tensor& scale_b,
const Tensor& global_scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
Tensor& out) {
#if !defined(USE_ROCM) && defined(USE_FBGEMM_GENAI)
// Typing checks
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out,
global_scale_a.mul(global_scale_b)
);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
#endif
return out;
}
void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
@ -245,7 +287,15 @@ void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int
}
}
void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
void _check_scales_blocked(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// if {mx,nv}fp4, will need to modify K later
bool is_fp4 = (mat.scalar_type() == kFloat4_e2m1fn_x2);
int blocksize = 32;
// check for nvfp4 vs. mxfp4 to fix blocksize
if (is_fp4 && scale.scalar_type() == kFloat8_e4m3fn) {
blocksize = 16;
}
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,
@ -253,17 +303,19 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
"for block-scaled, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(),
" and scale.dim() = ", scale.dim(), " for arg ", arg_idx
);
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/blocksize, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/blocksize, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"for block-scaled, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
@ -273,32 +325,40 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
};
// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8/nvfp4 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
if (is_fp4) {
// FP4 packs 2 values into a single 8b word - the "real" K is 2x the
// reported K. Reverse that adjustment.
const int fp4_elems_per_byte = 2;
K *= fp4_elems_per_byte;
}
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_K = round_up(K/blocksize, 4);
int64_t blocked_scale_N = round_up(N, 128);
// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
"for block-scaled 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N),",
"but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
"for block-scaled grouped GEMM, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ")",
" for arg ", arg_idx, ", got: ", scale.size(0), ", ", scale.size(1)
);
}
}
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
bool using_mx = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else if (using_mx) {
_check_scales_blocked(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
@ -411,9 +471,10 @@ namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
} // anonymous namespace
@ -525,8 +586,9 @@ _scaled_grouped_mm_cuda_v2(
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
@ -537,6 +599,21 @@ _scaled_grouped_mm_cuda_v2(
offs.value(),
out);
}
case ScaledGemmImplementation::NVFP4_NVFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _f4_f4_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0], /* block-scale A */
scale_a[1], /* global-scale A */
scale_b[0], /* block-scale B */
scale_b[1], /* global-scale B */
offs.value(),
std::nullopt, /* bias */
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");

View File

@ -22,6 +22,7 @@
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
@ -42,7 +43,6 @@ C10_DIAGNOSTIC_POP()
#include <static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#include <c10/util/Exception.h>
namespace FLASH_NAMESPACE {
@ -417,6 +417,26 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (batch_size == 0) {
auto opts = q.options();
at::Tensor out = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor q_padded = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor k_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor v_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor softmax_lse = at::empty({0, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor p = at::empty({0}, opts);
if (return_softmax) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
p = at::empty({0, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
}
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), _unused, std::move(p)};
}
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@ -547,7 +567,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), std::move(_unused), std::move(p)};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -852,7 +872,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
@ -863,6 +882,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (batch_size == 0) {
auto opts = q.options();
at::Tensor dq = at::empty_like(q);
at::Tensor dk = at::empty_like(k);
at::Tensor dv = at::empty_like(v);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
at::Tensor softmax_d = at::empty({0, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
return {dq, dk, dv, softmax_d};
}
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");

View File

@ -1066,6 +1066,8 @@ coverage_ignore_functions = [
"set_current_meta",
"set_grad_fn_seq_nr",
"set_stack_trace",
"set_current_replay_node",
"get_current_replay_node",
# torch.jit.annotations
"ann_to_type",
"check_fn",

View File

@ -99,6 +99,12 @@ DTensor supports the following types of {class}`Placement` on each {class}`Devic
:undoc-members:
```
```{eval-rst}
.. autoclass:: MaskPartial
:members:
:undoc-members:
```
```{eval-rst}
.. autoclass:: Placement
:members:

View File

@ -22,7 +22,11 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -32,6 +36,7 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._pytree import register_pytree_node
class SimpleModel(torch.nn.Module):
@ -176,6 +181,15 @@ def _count_op(gm, target):
return sum(1 for node in gm.graph.nodes if node.target == target)
register_pytree_node(
BlockMask,
BlockMask._flatten,
BlockMask._unflatten,
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)
@requires_cuda
class DTensorExportTest(TestCase):
def tearDown(self):

View File

@ -168,7 +168,7 @@ class TestEmbeddingOp(DTensorTestBase):
self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial
# test collectives
embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
@ -176,7 +176,7 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (8, 8), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
output = sharded_embedding(replicated_inp)
self.assertIsInstance(output.placements[0], _MaskPartial)
self.assertIsInstance(output.placements[0], MaskPartial)
comm_mode = CommDebugMode()
@ -192,9 +192,9 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (4, 4), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial
# case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
# case 1: two embeddings with the same shape, thus sharing the underlying MaskPartial
# and MaskBuffer, because of cache hit from sharding propagation
emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
@ -206,23 +206,23 @@ class TestEmbeddingOp(DTensorTestBase):
output2 = sharded_emb2(replicated_inp)
partial_placement1 = output1.placements[0]
self.assertIsInstance(partial_placement1, _MaskPartial)
self.assertIsInstance(partial_placement1, MaskPartial)
output1.full_tensor()
partial_placement2 = output2.placements[0]
self.assertIsInstance(partial_placement2, _MaskPartial)
self.assertIsInstance(partial_placement2, MaskPartial)
output2.full_tensor()
self.assertTrue(id(partial_placement1), id(partial_placement2))
# case 2: two embeddings with the same logical_dim_size, but different logical_shape
# thus they will have different _MaskPartial placements (with no cache hit)
# thus they will have different MaskPartial placements (with no cache hit)
emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
output3 = sharded_emb3(replicated_inp)
partial_placement3 = output3.placements[0]
self.assertIsInstance(partial_placement3, _MaskPartial)
self.assertIsInstance(partial_placement3, MaskPartial)
output2.full_tensor()
# not equal because of different logical_shape, despite of same logical_dim_size

View File

@ -511,7 +511,7 @@ class DistTensorOpsTest(DTensorTestBase):
# case 2 input sharding: input sharded, index replicated, output mask partial
# only works when index has size 1 on the gather dimension and
# input is sharded on the gather dimension
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial
gather_dim = 1
global_input = torch.randn(12, 8, 16)
@ -522,7 +522,7 @@ class DistTensorOpsTest(DTensorTestBase):
with comm_mode:
output_dt = torch.gather(input_dt, gather_dim, index_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertIsInstance(output_dt.placements[0], _MaskPartial)
self.assertIsInstance(output_dt.placements[0], MaskPartial)
self.assertEqual(output_dt.full_tensor(), global_output)
# case 3 index sharding: input replicated, index sharded, output sharded

View File

@ -230,7 +230,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 12)
self.assertEqual(cnts.op_count, 20)
@unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/118204
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@ -335,7 +335,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 21)
self.assertEqual(cnts.op_count, 37)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_compared_with_constant(self):
@ -517,7 +517,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x, cur_stream, new_stream)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 19)
self.assertEqual(cnts.op_count, 27)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_event_method(self):
@ -557,7 +557,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 19)
self.assertEqual(cnts.op_count, 27)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_device(self):

View File

@ -1016,6 +1016,59 @@ class inner_f(torch.nn.Module):
self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))
def test_preserve_annotate_replay_view(self):
"""Test stack trace and annotation are correct on nodes regenerated in functionalization"""
def _unpermute(out, input_shape, permuted_indices):
"""
Unpermute operation from torchtitan MoE utils.
"""
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
return out
class Module(nn.Module):
def __init__(self):
super().__init__()
self.input_shape = (5, 3)
self.permuted_indices = torch.tensor([2, 0, 3, 1])
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
routed_output = _unpermute(
x, self.input_shape, self.permuted_indices
)
return routed_output.cos()
inputs = (torch.randn(4, 3, requires_grad=True),)
model = Module()
graph_module = graph_capture(model, inputs, True)
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
slice_nodes = graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.slice.Tensor
)
self.assertEqual(len(slice_nodes), 1)
slice_backward_nodes = graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.slice_backward.default
)
self.assertEqual(len(slice_backward_nodes), 1)
slice_node = slice_nodes[0]
slice_backward_node = slice_backward_nodes[0]
self.assertEqual(slice_node.meta["seq_nr"], slice_backward_node.meta["seq_nr"])
self.assertTrue("out = out_unpermuted[:-1]" in slice_node.meta["stack_trace"])
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 'new_empty', {'pp_stage': 0})
('call_function', 'index_put', {'pp_stage': 0})
('call_function', 'slice_2', {'pp_stage': 0})
('call_function', 'slice_backward', {'pp_stage': 0})
('call_function', 'index', {'pp_stage': 0})""",
)
if __name__ == "__main__":
run_tests()

View File

@ -3245,8 +3245,8 @@ def forward(self, primals_1):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
as_strided_9 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_9, [4]); as_strided_9 = None
return (as_strided_scatter, view_1)""",
) # noqa: B950
@ -3409,13 +3409,13 @@ def forward(self, primals_1, primals_2, primals_3):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
unsqueeze = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze); add_1 = None
as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
return (as_strided_scatter, add_2, view_2, unsqueeze)""",
) # noqa: B950
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")

View File

@ -14269,6 +14269,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertTrue("'enable_fp_fusion': False" in code)
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)
@requires_cuda_and_triton
@config.patch(runtime_triton_nan_asserts=True)
def test_nan_assert_inside_triton_kernel(self):
def fn(x):
x = x - 1
# Uncomment the following line can trigger the failure of
# the device size assertion
# x = torch.log(x)
return torch.where(x.isnan(), 3.14, x)
compiled = torch.compile(fn)
x = torch.randn(4096, device=GPU_TYPE)
out, (code,) = run_and_get_code(compiled, x)
self.assertTrue("'NaN or Inf found'" in code)
torch.testing.assert_close(out, fn(x))
@skip_if_cpp_wrapper("skip cpp wrapper")
@requires_cuda_and_triton
def test_repeat_interleave_decomposition_has_clamp(self):

View File

@ -27,7 +27,6 @@ import torch
import torch.distributed as dist
from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
get_report_dir,
get_report_path,
IS_CI,
IS_MACOS,
@ -35,6 +34,7 @@ from torch.testing._internal.common_utils import (
set_cwd,
shell,
TEST_CUDA,
TEST_SAVE_XML,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TEST_WITH_SLOW_GRADCHECK,
@ -529,14 +529,6 @@ def run_test(
replacement = {"-f": "-x", "-dist=loadfile": "--dist=loadfile"}
unittest_args = [replacement.get(arg, arg) for arg in unittest_args]
xml_report_dir = get_report_dir(test_file, None, options.pytest)
if is_cpp_test:
unittest_args.append(
f"--junit-xml-reruns={get_report_path(xml_report_dir, test_file)}"
)
else:
unittest_args.append(f"--save-xml={xml_report_dir}")
if options.showlocals:
if options.pytest:
unittest_args.extend(["--showlocals", "--tb=long", "--color=yes"])
@ -763,6 +755,8 @@ def run_test_retries(
REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
) as f:
current_failure = f.read()
if current_failure == "null":
current_failure = f"'{test_file}'"
except FileNotFoundError:
print_to_file(
"No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
@ -799,8 +793,6 @@ def run_test_retries(
print_to_file("Retrying single test...")
print_items = [] # do not continue printing them, massive waste of space
if "null" in num_failures:
num_failures[f"'{test_file}'"] = num_failures.pop("null")
consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
if len(flaky_failures) > 0:
@ -1234,6 +1226,12 @@ def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
# is much slower than running them directly
pytest_args.extend(["-n", str(NUM_PROCS)])
if TEST_SAVE_XML:
# Add the option to generate XML test report here as C++ tests
# won't go into common_utils
test_report_path = get_report_path(pytest=True)
pytest_args.extend(["--junit-xml-reruns", test_report_path])
if options.pytest_k_expr:
pytest_args.extend(["-k", options.pytest_k_expr])

View File

@ -46,6 +46,7 @@ from torch.testing._internal.common_quantized import (
_floatx_unpacked_to_f32,
ceil_div, to_blocked,
to_mxfp8,
from_blocked_format,
generate_jagged_offs,
)
@ -462,6 +463,24 @@ def pack_uint4(uint8_data) -> torch.Tensor:
uint8_data = uint8_data.contiguous().view(-1)
return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))
def unpack_uint4(uint8_data) -> torch.Tensor:
# Take a packed uint8 tensor (i.e. nvfp4) and unpack into
# a tensor twice as wide. Useful for dequant operations.
shape = list(uint8_data.shape)
# 2x packed elements -> single non-packed => adjust shape
shape[-1] *= 2
out = torch.empty(
*shape,
device=uint8_data.device,
dtype=torch.uint8
).view(-1)
uint8_data_as_uint8 = uint8_data.view(torch.uint8).view(-1)
out[1::2] = uint8_data_as_uint8[:] >> 4
out[::2] = uint8_data_as_uint8 & 15
return out.view(shape)
def _bfloat16_to_float4_e2m1fn_x2(x):
assert x.dtype == torch.bfloat16
@ -470,6 +489,119 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
x = x.view(torch.float4_e2m1fn_x2)
return x
def _convert_to_nvfp4_with_hp_ref(t):
# Convert a tensor to nvfp4, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : nvfp4 tensor (2x elements packed into uint8)
# t_scale: e4m3 block-wise scaling factors (non-swizzled)
# t_global_scale: fp32 tensor-wise global scaling factor
t_lp, t_scale, t_global_scale = data_to_nvfp4_with_global_scale(
t,
16,
)
t_hp = from_blocked_format(
_floatx_unpacked_to_f32(
unpack_uint4(t_lp),
FP4_EBITS,
FP4_MBITS),
t_scale,
blocksize=16) * t_global_scale
return t_hp, t_lp, t_scale, t_global_scale
def _convert_to_mxfp8_with_hp_ref(t):
# Convert a tensor to mxfp8, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : fp8_e4m3 tensor
# t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled)
t_scale, t_lp = to_mxfp8(t)
t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)
return t_hp, t_lp, t_scale
def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'):
# Convert scales to blocked format. either mxfp8 or nvfp4
th_list = []
t_list = []
t_blocked_scale_list = []
t_global_scale_list = []
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else offs[group_idx - 1]
)
curr_group_end_offset = offs[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
t_slice = t[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
if format == 'mxfp8':
th_slice, tq_slice, t_scale_slice = _convert_to_mxfp8_with_hp_ref(t_slice)
elif format == 'nvfp4':
th_slice, tq_slice, t_scale_slice, tq_global = _convert_to_nvfp4_with_hp_ref(
t_slice,
)
t_global_scale_list.append(tq_global)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
t_list.append(tq_slice)
th_list.append(th_slice)
# Convert scales to blocked format.
t_scale_slice_blocked = to_blocked(
t_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
t_blocked_scale_list.append(t_scale_slice_blocked)
# Assemble the full XQ and WQ
tq = torch.cat(t_list, dim=1).contiguous()
th = torch.cat(th_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
t_blocked_scales = torch.cat(t_blocked_scale_list, dim=0)
MN_rounded = round_up(MN, 128)
t_blocked_scales = t_blocked_scales.reshape(MN_rounded, -1)
# Global scales only exist for nvfp4
t_global_scales = None
if len(t_global_scale_list) > 0:
t_global_scales = torch.stack(t_global_scale_list)
return th, tq, t_blocked_scales, t_global_scales
def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format):
# Build some standard args that are wordy
# Note: if/when ROCm support added, need to change swizzle handling
kwargs = {
'mxfp8': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': ScalingType.BlockWise1x32,
'scale_recipe_b': ScalingType.BlockWise1x32,
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
'nvfp4': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
}
return kwargs[format]
class TestFP8Matmul(TestCase):
@ -526,13 +658,15 @@ class TestFP8Matmul(TestCase):
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [2048, 2049])
@parametrize("N", [8192])
@parametrize("K", [16640])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
torch.manual_seed(42)
total_K = K # Alias for clarity, communicating this consists of several groups along this dim
input_group_end_offsets = generate_jagged_offs(
@ -541,95 +675,61 @@ class TestFP8Matmul(TestCase):
X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01
# Convert scales to blocked format.
x_list = []
w_list = []
x_blocked_scale_list = []
w_blocked_scale_list = []
xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
X, M, G, input_group_end_offsets, format=format
)
wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
W, N, G, input_group_end_offsets, format=format
)
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
curr_group_end_offset = input_group_end_offsets[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
x_slice = X[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
w_slice = W[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (N, K_group)
x_scale_slice, xq_slice = to_mxfp8(
x_slice
) # scale shape -> (M, K_group // 32)
w_scale_slice, wq_slice = to_mxfp8(
w_slice
) # scale shape -> (N, K_group // 32)
x_list.append(xq_slice)
w_list.append(wq_slice)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Convert scales to blocked format.
x_scale_slice_blocked = to_blocked(
x_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
w_scale_slice_blocked = to_blocked(
w_scale_slice
) # (round_up(N, 128), round_up(K_group//32, 4))
x_blocked_scale_list.append(x_scale_slice_blocked)
w_blocked_scale_list.append(w_scale_slice_blocked)
# Assemble the full XQ and WQ
xq = torch.cat(x_list, dim=1).contiguous()
wq = torch.cat(w_list, dim=1).contiguous()
# Combine all XQ groups blocked scales into one tensor.
x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
M_rounded = round_up(M, 128)
x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
# Combine all WQ groups blocked scales into one tensor.
w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
N_rounded = round_up(N, 128)
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G
# Compute mxfp8 grouped mm output
y_mxfp8 = scaled_grouped_mm_wrap(
xq, # (M, total_K)
wq.transpose(-2, -1), # (total_K, N)
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
offs=input_group_end_offsets, # (G,)
out_dtype=torch.bfloat16,
wrap_v2=wrap_v2
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
**kwargs,
)
# bf16 reference output
y_bf16 = torch._grouped_mm(
X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
# Note: Reference result should be on reconstructed, not original values.
# as-in float(fp4(t)) not t itself.
xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
)
# Assert no NaNs
assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
assert not y_lp.isnan().any(), "mxfp8 output contains NaN"
# Assert outputs are close
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [16640])
@parametrize("N", [8192])
@parametrize("K", [4096])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
torch.manual_seed(42)
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
# 2D inputs with groups along M, 3D weights.
@ -643,60 +743,120 @@ class TestFP8Matmul(TestCase):
# For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
# as they each used for independent gemm in the grouped gemm.
wq_list = []
w_scale_list = []
for i in range(G):
w_scale, wq = to_mxfp8(W[i])
w_scale = to_blocked(w_scale)
wq_list.append(wq)
w_scale_list.append(w_scale)
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
def _3d_to_blocked_scaled(W, G, format):
wh_list = []
wq_list = []
w_scale_list = []
w_global_scale_list = []
for i in range(G):
if format == "mxfp8":
wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i])
elif format == "nvfp4":
w_scale, wq = to_mxfp8(W[i])
wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i])
w_global_scale_list.append(w_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Swizzle scaled
# TODO(slayton): gate on cuda/hip
w_scale = to_blocked(w_scale)
wh_list.append(wh)
wq_list.append(wq)
w_scale_list.append(w_scale)
wh = torch.stack(wh_list, dim=0).contiguous()
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Global scales only exist for nvfp4
if len(w_global_scale_list) > 0:
w_global_scales = torch.stack(w_global_scale_list)
else:
w_global_scales = None
return wh, wq, w_scale, w_global_scales
wh, wq, w_blocked_scales, w_global_scales = _3d_to_blocked_scaled(W, G, format)
# For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
# as they each used for independent gemm in the grouped gemm.
xq_list = []
x_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
x_scale, xq = to_mxfp8(x_slice)
x_scale = to_blocked(x_scale)
xq_list.append(xq)
x_scale_list.append(x_scale)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
def _2d_to_blocked_scaled(X, K, G, offs, format):
xh_list = []
xq_list = []
x_scale_list = []
x_global_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
if format == "mxfp8":
xh, xq, x_scale = _convert_to_mxfp8_with_hp_ref(x_slice)
elif format == "nvfp4":
xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice)
x_global_scale_list.append(x_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
# Compute mxfp8 grouped gemm.
y_mxfp8 = scaled_grouped_mm_wrap(
x_scale = to_blocked(x_scale)
xh_list.append(xh)
xq_list.append(xq)
x_scale_list.append(x_scale)
xh = torch.cat(xh_list, dim=0).contiguous()
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
xh = xh.view(-1, xh.shape[-1])
x_global_scales = None
if len(x_global_scale_list) > 0:
x_global_scales = torch.stack(x_global_scale_list)
return xh, xq, x_scale, x_global_scales
xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format)
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G
# Compute low-precision grouped gemm.
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
x_scale,
w_scale,
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
wrap_v2=wrap_v2)
**kwargs
)
# Compute reference bf16 grouped gemm.
# Note: Reference result should be on reconstructed, not original values.
# as-in float(fp4(t)) not t itself.
y_bf16 = torch._grouped_mm(
X,
W.transpose(-2, -1),
xh,
wh.transpose(-2, -1),
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
)
# Assert outputs are close.
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@ -1704,6 +1864,7 @@ class TestFP8Matmul(TestCase):
@parametrize("fast_accum", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
# AMD does not support NVFP4
@parametrize("wrap_v2", [True, False])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
device = "cuda"

View File

@ -1107,6 +1107,7 @@ class TestTransformers(NNTestCase):
)[0]
@tf32_on_and_off(0.003)
@parametrize("batch_size", [0, 5])
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@ -1116,7 +1117,7 @@ class TestTransformers(NNTestCase):
if attn_dim is not None else "no_attn_mask")))
@parametrize("dropout_p", [0.0, 0.2, 0.5])
@sdpa_kernel(backends=[SDPBackend.MATH])
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
def test_scaled_dot_product_attention(self, device, batch_size, input_dim, attn_mask_dim, is_causal, dropout_p):
def sdp_ref(
q,
k,
@ -1140,12 +1141,13 @@ class TestTransformers(NNTestCase):
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
dtypes = [torch.double, torch.float]
for dtype in dtypes:
N = batch_size
def rand_tensor(*shape):
return torch.randn(shape, device=device, dtype=dtype)
# This test compares python and C++ implementations of SDP.
N, N_prime, L, S, E = 5, 2, 4, 3, 6
N_prime, L, S, E = 2, 4, 3, 6
if input_dim == 3:
query = rand_tensor(N, L, E)
key = rand_tensor(N, S, E)

View File

@ -5,11 +5,12 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import varlen_attn
from torch.nn.attention.varlen import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode
VarlenShape = namedtuple(
@ -23,6 +24,18 @@ default_tolerances = {
}
class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""
def __init__(self):
self.called_ops = []
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))
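As a quick illustration of the logging mode above, a minimal usage sketch (the tensors and the op are illustrative; the recorded names come from str(func) on the dispatched ATen overloads):
with OpLoggingMode() as mode:
    a = torch.randn(4)
    b = torch.randn(4)
    _ = a + b  # dispatches aten.add.Tensor under the mode

print(mode.called_ops)  # e.g. [..., 'aten.add.Tensor']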
class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@ -39,12 +52,9 @@ class AttentionBlock(nn.Module):
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)
def forward_varlen(
def get_varlen_qkv(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
@ -53,24 +63,51 @@ class AttentionBlock(nn.Module):
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
return q, k, v
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
q, k, v = self.get_varlen_qkv(x_packed)
attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = attn_out.view(-1, self.embed_dim)
return self.out_proj(attn_out)
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
dtype: torch.dtype,
is_causal: bool = False,
):
batch_size, seq_len, _ = x_padded.shape
qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)
mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)
attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
@ -91,7 +128,9 @@ def create_variable_length_batch(
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
@ -106,6 +145,7 @@ def create_variable_length_batch(
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()
return {
"seq_lengths": seq_lengths,
@ -133,7 +173,11 @@ class TestVarlenAttention(NNTestCase):
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@ -147,6 +191,128 @@ class TestVarlenAttention(NNTestCase):
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
varlen_grad_out = torch.ones_like(output)
varlen_grad = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
self.assertIsNotNone(varlen_grad)
self.assertEqual(varlen_grad.shape, x_packed.shape)
self.assertEqual(varlen_grad.dtype, x_packed.dtype)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
q, k, v = attention_block.get_varlen_qkv(x_packed)
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn,
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
)
out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
)
grad_out = torch.randn_like(out)
# we don't support double backward
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn_backward,
(
grad_out,
q,
k,
v,
out,
lse,
cu_seq,
cu_seq,
shape.max_seq_len,
shape.max_seq_len,
False,
rng_state,
),
test_utils=["test_schema", "test_faketensor"],
)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
varlen_grad_out = torch.ones_like(output)
_ = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
called_ops = mode.called_ops
custom_ops_called = any(
"torch_attn._varlen_attn" in op for op in called_ops
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
assert custom_ops_called
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@ -172,7 +338,10 @@ class TestVarlenAttention(NNTestCase):
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"], is_causal=is_causal
variable_length_batch_data["x_padded"],
variable_length_batch_data["seq_lengths"],
dtype=dtype,
is_causal=is_causal,
)
tolerances = default_tolerances[dtype]
@ -186,6 +355,44 @@ class TestVarlenAttention(NNTestCase):
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
start_idx = end_idx
varlen_grad_out = torch.ones_like(varlen_output)
sdpa_grad_out = torch.zeros_like(sdpa_output)
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
start_idx = end_idx
varlen_grad = torch.autograd.grad(
outputs=varlen_output,
inputs=variable_length_batch_data["x_packed"],
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
sdpa_grad = torch.autograd.grad(
outputs=sdpa_output,
inputs=variable_length_batch_data["x_padded"],
grad_outputs=sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_grad_seq = varlen_grad[start_idx:end_idx]
sdpa_grad_seq = sdpa_grad[i, :seq_len]
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
start_idx = end_idx
device_types = ("cuda",)

View File

@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True
# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = not is_fbcode()
enable_cpp_symbolic_shape_guards = False
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

View File

@ -2820,5 +2820,36 @@
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0280": [
{
"Gb_type": "1-arg super not implemented",
"Context": "",
"Explanation": "Dynamo failed to trace attribute `{name}` accessed via `super()` (for type `{self.typevar}` and object `{self.objvar}`) because one-argument of super() is not supported.",
"Hints": [
"Use two-argument super(type, object_or_type)."
]
}
],
"GB0281": [
{
"Gb_type": "Invalid or non-const argument in nn.Module __getitem__",
"Context": "call_method: {self} {name} {args} {kwargs}",
"Explanation": "Dynamo does not support calling method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
"Hints": [
"Use constant arguments of type str or int for __getitem__"
]
}
],
"GB0282": [
{
"Gb_type": "Placement with custom __getattr__ not supported",
"Context": "{value_type.__name__} with custom __getattr__",
"Explanation": "Dynamo does not support Placement types with custom __getattr__ methods",
"Hints": [
"Use Placement types without custom __getattr__ methods",
"Move the Placement usage outside the compiled region"
]
}
]
}
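As a rough sketch of the workarounds named in the hints above (the class, module, and tensor values here are illustrative, not taken from the diff):
import torch

class Base:
    def scale(self, x):
        return x * 2

class Child(Base):
    def scale(self, x):
        # GB0280 hint: use two-argument super(type, object_or_type)
        # rather than one-argument super(type).
        return super(Child, self).scale(x) + 1

child = Child()
layers = torch.nn.ModuleDict({"fc": torch.nn.Linear(4, 4)})

@torch.compile(backend="eager")
def fn(x):
    # GB0281 hint: index nn.Module containers with a constant str/int key.
    y = layers["fc"](x)
    return child.scale(y)

_ = fn(torch.randn(2, 4))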

View File

@ -2299,10 +2299,13 @@ class GuardBuilder(GuardBuilderBase):
],
)
def FUNCTION_MATCH(self, guard: Guard) -> None:
"""things like torch.add and user defined functions"""
# don't support this in serialization because it uses unsupported ID_MATCH
return self.ID_MATCH(guard)
def UNCLASSIFIED_ID_MATCH(self, guard: Guard) -> None:
"""
Calls the id_match guard, and also aids future debugging in cases where we
call ID_MATCH on an object without a clear reason why. These guards will
show up in tlparse.
"""
self.id_match_unchecked(guard)
def CLASS_MATCH(self, guard: Guard) -> None:
"""Equals ID_MATCH on classes - better readability than directly calling ID_MATCH"""
@ -2324,14 +2327,13 @@ class GuardBuilder(GuardBuilderBase):
def CLOSURE_MATCH(self, guard: Guard) -> None:
"""matches a closure by __code__ id."""
# don't support this in serialization because it uses unsupported FUNCTION_MATCH
val = self.get(guard.name)
# Strictly only want user-defined functions
if type(val) is types.FunctionType and hasattr(val, "__code__"):
self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type]
self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type]
self._guard_on_attribute(guard, "__code__", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type]
else:
self.FUNCTION_MATCH(guard)
self.UNCLASSIFIED_ID_MATCH(guard)
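The intuition behind guarding on __code__ rather than the function object itself: closures returned by the same factory share one code object even though every call produces a distinct function, so a guard on __code__ is more relaxed than an ID_MATCH on the closure. A small sketch (names illustrative):
def make_adder(n):
    def add(x):
        return x + n
    return add

f1 = make_adder(1)
f2 = make_adder(2)

assert f1 is not f2                  # an ID_MATCH installed for f1 would reject f2
assert f1.__code__ is f2.__code__    # a guard on __code__ accepts both closures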
def BUILTIN_MATCH(self, guard: Guard) -> None:
if self.save_guards:
@ -3718,11 +3720,11 @@ class CheckFunctionManager:
"DICT_VERSION",
"NN_MODULE",
"ID_MATCH",
"FUNCTION_MATCH",
"CLASS_MATCH",
"MODULE_MATCH",
"CLOSURE_MATCH",
"WEAKREF_ALIVE",
"UNCLASSIFIED_ID_MATCH",
)
def serialize_guards(

View File

@ -629,7 +629,7 @@ class VariableBuilder:
lambda self, value: LambdaVariable(
_dataclasses_fields_lambda,
source=self.source,
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
**self.install_guards(GuardBuilder.CLOSURE_MATCH),
),
),
(torch.__version__, lambda self, value: TorchVersionVariable()),
@ -927,8 +927,10 @@ class VariableBuilder:
)
elif inspect.isclass(value):
self.install_guards(GuardBuilder.CLASS_MATCH)
elif inspect.isfunction(value):
self.install_guards(GuardBuilder.CLOSURE_MATCH)
elif callable(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
self.install_guards(GuardBuilder.ID_MATCH)
else:
self.install_guards(GuardBuilder.TYPE_MATCH)
return NumpyVariable(value, source=self.source)
@ -945,7 +947,7 @@ class VariableBuilder:
return NumpyTypeInfoVariable(value, source=self.source)
# NB: These can't be put in type_dispatch, they have to run later
elif CollectiveFunctionRewriteVariable.can_rewrite(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
self.install_guards(GuardBuilder.CLOSURE_MATCH)
return CollectiveFunctionRewriteVariable.create(
self.tx,
value,
@ -1371,7 +1373,7 @@ class VariableBuilder:
elif isinstance(value, types.MethodWrapperType):
# Method-wrappers are written in C, and they are not guaranteed to
# return the same object on attribute lookup. Therefore, we cannot
# insert a FUNCTION_MATCH guard here. method-wrappers are very
# insert an ID_MATCH guard here. method-wrappers are very
# unlikely to change, so it's ok to skip the guard here.
return MethodWrapperVariable(value)
elif issubclass(type(value), type) and issubclass(value, BaseException):

View File

@ -24,10 +24,12 @@ import inspect
import sys
import warnings
from contextlib import ExitStack
from typing import TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import torch._C
from torch._dynamo.variables.misc import GetAttrVariable
from torch._guards import Guard
from torch.fx import Proxy
from .. import graph_break_hints, variables
from ..bytecode_transformation import (
@ -41,6 +43,7 @@ from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from ..utils import _get_error_on_graph_break, _set_error_on_graph_break
from .base import VariableTracker
from .constant import ConstantVariable
from .functions import (
NestedUserFunctionVariable,
SkipFunctionVariable,
@ -992,13 +995,82 @@ class ProfilerContextVariable(ContextWrappingVariable):
class StreamContextVariable(ContextWrappingVariable):
"""This represents torch.cuda.StreamContext"""
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
def create(
tx: "InstructionTranslator",
target_value: "StreamVariable",
**kwargs: dict[str, Any],
) -> "StreamContextVariable":
return StreamContextVariable(
target_values=[target_value],
initial_values=[
StreamContextVariable._get_current_stream(target_value.device, tx)
],
device=target_value.device,
**kwargs,
)
def __init__(
self,
target_values: list["StreamVariable"],
device: torch.device,
initial_values: Optional[list["StreamVariable"]] = None,
**kwargs: dict[str, Any],
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
# argument order is (to_stream, from_stream):
# we are entering the target stream and leaving the initial stream
tx.output.create_proxy(
"call_function",
torch.ops.streams.fork.default,
self._target_stream_proxies() + self._initial_stream_proxies(),
{},
)
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
# argument order is (to_stream, from_stream):
# we are leaving the target stream and entering the initial stream
tx.output.create_proxy(
"call_function",
torch.ops.streams.join.default,
self._initial_stream_proxies() + self._target_stream_proxies(),
{},
)
return ConstantVariable.create(None)
def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
assert self.initial_values, "No initial stream to move from"
return StreamContextVariable._extract_stream_properties(
self.initial_values[0].as_proxy()
)
def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
return StreamContextVariable._extract_stream_properties(
self.target_values[0].as_proxy()
)
@staticmethod
def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]:
stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
return stream_index, stream_device
@staticmethod
def _get_current_stream(
device: torch.device, tx: "InstructionTranslator"
) -> "StreamVariable":
from .builder import wrap_fx_proxy_cls
current_stream_method = get_interface_for_device(
target_value.device
).current_stream
current_stream_method = get_interface_for_device(device).current_stream
current_stream = wrap_fx_proxy_cls(
StreamVariable,
tx,
@ -1009,50 +1081,7 @@ class StreamContextVariable(ContextWrappingVariable):
{},
),
)
return StreamContextVariable(
target_values=[target_value],
initial_values=[current_stream],
device=target_value.device,
**kwargs,
)
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
self.set_stream = get_interface_for_device(self.device).set_stream
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx):
# stream generated inside the traced function
if self.target_values[0].as_proxy() is not None:
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.target_values[0].as_proxy(),),
{},
)
# stream passed from outside the traced function
else:
stream = self.target_values[0].value
tx.output.create_proxy(
"call_function",
self.set_stream_id,
(stream.stream_id, stream.device_index, stream.device_type),
{},
)
self.set_stream(self.target_values[0].value)
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
def exit(self, tx: "InstructionTranslator", *args):
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.initial_values[0].as_proxy(),),
{},
)
self.cleanup_assert()
return current_stream
class PreserveVersionContextVariable(ContextWrappingVariable):

View File

@ -210,9 +210,16 @@ class PlacementVariable(DistributedVariable):
if name in constant_fold_functions:
try:
value_type = type(self.value)
assert (
inspect.getattr_static(value_type, "__getattr__", None) is None
), "no custom getattr allowed!"
if inspect.getattr_static(value_type, "__getattr__", None) is not None:
unimplemented_v2(
gb_type="Placement with custom __getattr__ not supported",
context=f"{value_type.__name__} with custom __getattr__",
explanation="Dynamo does not support Placement types with custom __getattr__ methods",
hints=[
"Use Placement types without custom __getattr__ methods",
"Move the Placement usage outside the compiled region",
],
)
method = inspect.getattr_static(value_type, name)
except AttributeError:
method = None

View File

@ -2001,7 +2001,7 @@ class PolyfilledFunctionVariable(VariableTracker):
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
return cls(value, source=source)

View File

@ -103,7 +103,17 @@ class SuperVariable(VariableTracker):
codegen.extend_output(create_call_function(1, False))
def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
assert self.objvar, "1-arg super not implemented"
if not self.objvar:
unimplemented_v2(
gb_type="1-arg super not implemented",
context="",
explanation=f"Dynamo failed to trace attribute `{name}` accessed "
f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
"because one-argument of super() is not supported.",
hints=[
"Use two-argument super(type, object_or_type).",
],
)
search_type = self.typevar.as_python_constant()
# The rest of this function does two things:
@ -1032,10 +1042,10 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
# queue_callback is a method-wrapper, no need to insert a guard.
fn_vt = VariableTracker.build(
tx,
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
)
return fn_vt.call_function(
tx,

View File

@ -822,9 +822,19 @@ class NNModuleVariable(VariableTracker):
)
if type(module).__getitem__ not in builtin_supported:
assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
key = args[0].as_python_constant()
assert isinstance(key, (str, int))
if not (
isinstance(args[0], variables.ConstantVariable)
and isinstance(args[0].as_python_constant(), (str, int))
):
unimplemented_v2(
gb_type="Invalid or non-const argument in nn.Module __getitem__",
context=f"call_method: {self} {name} {args} {kwargs}",
explanation="Dynamo does not support calling "
f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
hints=[
"Use constant arguments of type str or int for __getitem__"
],
)
fn = getattr(module, name).__func__
assert isinstance(fn, types.FunctionType)

View File

@ -262,7 +262,9 @@ class BaseTorchVariable(VariableTracker):
# Dont need to guard on wrappers
pass
else:
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
# Installing an ID_MATCH to preserve the old behavior. But making it
# unclassified so that we can eventually remove it.
install_guard(source.make_guard(GuardBuilder.UNCLASSIFIED_ID_MATCH))
return cls(value, source=source)
def __init__(self, value, **kwargs) -> None:

View File

@ -720,13 +720,22 @@ def check_shape(
) -> None:
backend = get_current_backend()
assert shape is not None
if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
if config.test_configs.runtime_triton_shape_assert and backend == "triton":
shape_str = (
", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
)
buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")
def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
backend = get_current_backend()
if backend == "triton":
msg = "NaN or Inf found"
buffer.writeline(
f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
)
class DataTypePropagation:
def __init__(self, body: LoopBody) -> None:
self.body = body
@ -2623,6 +2632,9 @@ class CSEProxy(DefaultHandler):
assert output_shape is not None
check_shape(V.kernel.compute, csevar, output_shape)
if config.runtime_triton_nan_asserts:
check_nan(V.kernel.compute, csevar)
return csevar
return pytree.tree_map(do_cse, value)

View File

@ -626,7 +626,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -206,6 +206,9 @@ static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"
# Disable by default in fbcode
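A sketch of exercising the new flag; the compiled function is illustrative, and the flag can equivalently be set through TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS=1 in the environment before import:
import torch
import torch._inductor.config as inductor_config

inductor_config.runtime_triton_nan_asserts = True

@torch.compile
def f(x):
    return x * 2 + 1

# The generated Triton kernel now carries tl.device_assert NaN/Inf checks on CSE values.
_ = f(torch.randn(1024, device="cuda"))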

View File

@ -3550,24 +3550,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not inductor_meta.get("autotune_pointwise", True) and not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -8,6 +8,7 @@ from contextlib import AbstractContextManager
from typing import Any, Optional, Union
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
@ -512,6 +513,30 @@ class FunctionalTensorMode(TorchDispatchMode):
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating them
# the next time they are used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch

View File

@ -1,13 +1,9 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from dataclasses import dataclass, field
from typing import cast, Optional
from typing import cast
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._local_tensor import maybe_run_for_local_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
@ -19,8 +15,8 @@ from torch.distributed.tensor._ops.utils import (
register_op_strategy,
)
from torch.distributed.tensor.placement_types import (
MaskPartial,
Partial,
Placement,
Replicate,
Shard,
)
@ -29,190 +25,6 @@ from torch.distributed.tensor.placement_types import (
aten = torch.ops.aten
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0
def materialize_mask(self, mask):
if self.refcount == 0:
self.data = mask
else:
assert self.data is not None
if not torch.equal(self.data, mask):
raise RuntimeError(
"MaskBuffer has been materialized with conflicting data"
)
self.refcount += 1
def release_mask(self):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
self.refcount -= 1
if self.refcount == 0:
self.data = None
def apply_mask(self, tensor):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
# NOTE: _MaskPartial is being used by the embedding op and the gather op.
# For gather, the mask has the same dimension as the output tensor, whereas
# the output of the embedding op has an additional dimension compared to the input,
# hence the output masking logic below having two different cases.
if tensor.ndim == self.data.ndim:
tensor[self.data] = 0.0
else:
tensor[self.data, :] = 0.0
@dataclass(frozen=True)
class _MaskPartial(Partial):
"""
A partial mask placement devised for the rowwise sharded embedding op, where we need
to mask and adjust the indices to the local embedding shard; embedding masking
is a special type of the Partial placement.
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
"""
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def __init__(
self,
reduce_op=None,
mask_buffer=None,
offset_shape=None,
offset_dim=0,
*args,
**kwargs,
):
super().__init__(reduce_op)
if mask_buffer is None:
mask_buffer = MaskBuffer()
object.__setattr__(self, "mask_buffer", mask_buffer)
object.__setattr__(self, "offset_shape", offset_shape)
object.__setattr__(self, "offset_dim", offset_dim)
@staticmethod
@maybe_run_for_local_tensor
def _mask_tensor(
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
# Build the input mask and save it for the current partial placement
# this is so that the output of embedding op can reuse the same partial
# placement saved mask to perform mask + reduction
mask = (tensor < local_offset_on_dim) | (
tensor >= local_offset_on_dim + local_shard_size
)
# mask the input tensor
masked_tensor = tensor.clone() - local_offset_on_dim
masked_tensor[mask] = 0
return mask, masked_tensor
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
my_coordinate = mesh.get_coordinate()
assert my_coordinate is not None, "my_coordinate should not be None"
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert self.offset_shape is not None, (
"offset_shape needs to be set for _MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
self.offset_shape[self.offset_dim],
num_chunks,
my_coordinate[mesh_dim],
)
mask, masked_tensor = _MaskPartial._mask_tensor(
tensor, local_offset_on_dim, local_shard_size
)
# materialize the mask buffer to be used for reduction
self.mask_buffer.materialize_mask(mask)
return masked_tensor
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that is pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# perform sum reduction
return funcol.all_reduce(
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
)
def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that is pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# call reduce_shard_tensor of the shard_spec.
shard_spec = cast(Shard, shard_spec)
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
def __eq__(self, other: object) -> bool:
if not isinstance(other, _MaskPartial):
return False
# if either data is not None, we invalidate the sharding cache, as this indicates
# the current MaskPartial placement is still in use and should not be used for cache hit.
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
return False
return (
self.reduce_op == other.reduce_op
and self.offset_shape == other.offset_shape
and self.offset_dim == other.offset_dim
)
def __hash__(self) -> int:
return 1 + hash(
(
self.reduce_op,
self.offset_shape,
self.offset_dim,
)
)
def __repr__(self) -> str:
"""
machine readable representation of the MaskPartial placement
"""
return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
def __str__(self) -> str:
"""
human readable representation of the MaskPartial placement
"""
return "MaskP"
@register_op_strategy(aten.embedding.default)
def embedding_strategy(op_schema: OpSchema) -> StrategyType:
"""
@ -239,7 +51,7 @@ def embedding_strategy(op_schema: OpSchema) -> StrategyType:
single_mesh_dim_strategies.append(colwise_sharding)
# rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
embedding_partial_placement = MaskPartial(offset_shape=weight_shape, offset_dim=0)
# NOTE we want to reuse the same mask partial placement so that we can reuse the mask generated
# from the input indices and use it for output reduction

View File

@ -0,0 +1,44 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0
def materialize_mask(self, mask):
if self.refcount == 0:
self.data = mask
else:
assert self.data is not None
if not torch.equal(self.data, mask):
raise RuntimeError(
"MaskBuffer has been materialized with conflicting data"
)
self.refcount += 1
def release_mask(self):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
self.refcount -= 1
if self.refcount == 0:
self.data = None
def apply_mask(self, tensor):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
# NOTE: MaskPartial is being used by the embedding op and the gather op.
# For gather, the mask has the same dimension as the output tensor, whereas
# the output of the embedding op has an additional dimension compared to the input,
# hence the output masking logic below having two different cases.
if tensor.ndim == self.data.ndim:
tensor[self.data] = 0.0
else:
tensor[self.data, :] = 0.0
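A small sketch of the refcount protocol above (tensors illustrative):
import torch

buf = MaskBuffer()
mask = torch.tensor([True, False, True])

buf.materialize_mask(mask)   # first user materializes, refcount -> 1
buf.materialize_mask(mask)   # second user must supply equal data, refcount -> 2

out = torch.ones(3)
buf.apply_mask(out)          # zeros out masked positions -> tensor([0., 1., 0.])

buf.release_mask()           # refcount -> 1, data kept
buf.release_mask()           # refcount -> 0, data cleared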

View File

@ -17,7 +17,7 @@ from torch.distributed.tensor._op_schema import (
TupleStrategy,
)
from torch.distributed.tensor._ops._common_rules import pointwise_rule
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
generate_redistribute_costs,
@ -646,7 +646,7 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType:
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if dim < len(index_shape) and index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
index_partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),

View File

@ -11,7 +11,7 @@ from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
from torch.distributed.tensor._ops._math_ops import (
_skip_dim,
Reduction,
@ -236,7 +236,7 @@ def _nll_loss_forward(
# The following code block is a distributed version of
# result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target_partial_ = partial_placement._partition_value(
safe_target_, mesh, mesh_dim
)
@ -375,7 +375,7 @@ def _nll_loss_and_log_softmax_backward(
# The following code block is a distributed version of
# grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target = safe_target.squeeze(channel_dim).flatten()
masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
# only update grad_input to -1 if not masked

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass, field
from typing import cast, Optional
import torch
@ -17,9 +18,10 @@ from torch.distributed.tensor._collective_utils import (
shard_dim_alltoall,
unpad_tensor,
)
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
__all__ = ["Placement", "Shard", "Replicate", "Partial", "MaskPartial"]
# Appease TestPublicBindings.test_correct_module_names
@ -841,3 +843,149 @@ class Partial(torch._C._distributed.Partial):
# We keep the old _Partial name for a while for BC reason
_Partial = Partial
@dataclass(frozen=True)
class MaskPartial(Partial):
"""
A partial mask placement devised for the rowwise sharded embedding op, where we need
to mask and adjust the indices to the local embedding shard; embedding masking
is a special type of the Partial placement.
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
"""
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def __init__(
self,
reduce_op=None,
mask_buffer=None,
offset_shape=None,
offset_dim=0,
*args,
**kwargs,
):
super().__init__(reduce_op)
if mask_buffer is None:
mask_buffer = MaskBuffer()
object.__setattr__(self, "mask_buffer", mask_buffer)
object.__setattr__(self, "offset_shape", offset_shape)
object.__setattr__(self, "offset_dim", offset_dim)
@staticmethod
@maybe_run_for_local_tensor
def _mask_tensor(
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
# Build the input mask and save it for the current partial placement
# this is so that the output of embedding op can reuse the same partial
# placement saved mask to perform mask + reduction
mask = (tensor < local_offset_on_dim) | (
tensor >= local_offset_on_dim + local_shard_size
)
# mask the input tensor
masked_tensor = tensor.clone() - local_offset_on_dim
masked_tensor[mask] = 0
return mask, masked_tensor
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
my_coordinate = mesh.get_coordinate()
assert my_coordinate is not None, "my_coordinate should not be None"
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert self.offset_shape is not None, (
"offset_shape needs to be set for MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
self.offset_shape[self.offset_dim],
num_chunks,
my_coordinate[mesh_dim],
)
mask, masked_tensor = MaskPartial._mask_tensor(
tensor, local_offset_on_dim, local_shard_size
)
# materialize the mask buffer to be used for reduction
self.mask_buffer.materialize_mask(mask)
return masked_tensor
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that is pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# perform sum reduction
return funcol.all_reduce(
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
)
def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
# by the time we need reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
# apply the mask to the tensor that is pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# call reduce_shard_tensor of the shard_spec.
shard_spec = cast(Shard, shard_spec)
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
def __eq__(self, other: object) -> bool:
if not isinstance(other, MaskPartial):
return False
# if either data is not None, we invalidate the sharding cache, as this indicates
# the current MaskPartial placement is still in use and should not be used for cache hit.
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
return False
return (
self.reduce_op == other.reduce_op
and self.offset_shape == other.offset_shape
and self.offset_dim == other.offset_dim
)
def __hash__(self) -> int:
return 1 + hash(
(
self.reduce_op,
self.offset_shape,
self.offset_dim,
)
)
def __repr__(self) -> str:
"""
machine readable representation of the MaskPartial placement
"""
return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
def __str__(self) -> str:
"""
human readable representation of the MaskPartial placement
"""
return "MaskP"

View File

@ -206,6 +206,21 @@ class TracerBase:
if current_meta.get("in_grad_fn", 0) > 0:
annotation_log.debug("seq_nr from current_meta")
new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
# See Note [Functionalization View Replay Annotation]
# Overriding some node meta with the original node meta of the
# regenerated node.
replay_node: Node = fx_traceback.get_current_replay_node()
if replay_node is not None:
node.meta["is_functional_regenerated"] = True
if "seq_nr" in replay_node.meta:
annotation_log.debug("seq_nr from replay_node")
new_seq_nr = replay_node.meta["seq_nr"]
if "custom" in replay_node.meta:
node.meta["custom"] = replay_node.meta.get("custom")
if "stack_trace" in replay_node.meta:
node.stack_trace = replay_node.meta.get("stack_trace")
annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
node.meta["seq_nr"] = new_seq_nr

View File

@ -30,9 +30,12 @@ __all__ = [
"NodeSource",
"NodeSourceAction",
"get_graph_provenance_json",
"set_current_replay_node",
"get_current_replay_node",
]
current_meta: dict[str, Any] = {}
current_replay_node: Optional[Node] = None
should_preserve_node_meta = False
@ -400,6 +403,31 @@ def get_current_meta() -> dict[str, Any]:
return current_meta
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_replay_node(node):
"""
Set the current replay node. If `current_replay_node` is not None,
then we are re-generating `current_replay_node` in FunctionalTensorMode.
"""
# See Note [Functionalization View Replay Annotation] for more details.
global current_replay_node
saved_current_replay_node = current_replay_node
try:
current_replay_node = node
yield
finally:
current_replay_node = saved_current_replay_node
@compatibility(is_backward_compatible=False)
def get_current_replay_node():
"""
Get the current replay node
"""
return current_replay_node
@compatibility(is_backward_compatible=False)
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
"""

View File

@ -14,14 +14,11 @@ from torch.backends.cuda import (
SDPAParams,
)
from .varlen import varlen_attn
__all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"varlen_attn",
]
# Note: [SDPA warnings]

View File

@ -34,7 +34,7 @@ from torch.fx.experimental.proxy_tensor import (
_temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _validate_sdpa_input
from torch.utils._pytree import GetAttrKey, register_pytree_node, tree_map_only
from torch.utils._pytree import GetAttrKey, tree_map_only
# Private debug flag to disable internal compilation wrapping for debugging purposes.
@ -1648,12 +1648,3 @@ def flex_attention(
return _finalize_outputs(
out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
)
register_pytree_node(
BlockMask,
BlockMask._flatten,
BlockMask._unflatten,
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)

View File

@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.
import logging
from functools import lru_cache
from typing import NamedTuple, Optional, Union
from typing import Any, NamedTuple, Optional, Union
import torch
@ -33,8 +33,7 @@ class AuxRequest(NamedTuple):
lse: bool = False
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
@ -44,7 +43,7 @@ def _varlen_attn(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention.
@ -70,7 +69,7 @@ def _varlen_attn(
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse = result[0], result[1]
output, softmax_lse, rng_state = result[0], result[1], result[6]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@ -86,10 +85,13 @@ def _varlen_attn(
return_debug_mask=False,
)
return output, softmax_lse
rng_state_ = torch.zeros(
(2,), dtype=torch.uint64, device=query.device
) # hardcoded since dropout is hardcoded to 0
return output, softmax_lse, rng_state_
# @_varlen_attn.register_fake
@_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
@ -99,7 +101,7 @@ def _varlen_attn_fake(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
@ -117,7 +119,9 @@ def _varlen_attn_fake(
(num_heads, total_q), dtype=torch.float, device=query.device
)
return output, logsumexp
rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
return output, logsumexp, rng_state
def varlen_attn(
@ -191,9 +195,145 @@ def varlen_attn(
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
"""
out, lse = _varlen_attn(
out, lse, _ = torch.ops.torch_attn._varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out
def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
out, lse, rng_state = output
ctx.query = query
ctx.key = key
ctx.value = value
ctx.cu_seq_q = cu_seq_q
ctx.cu_seq_k = cu_seq_k
ctx.max_q = max_q
ctx.max_k = max_k
ctx.is_causal = is_causal
ctx.output = out
ctx.lse = lse
ctx.rng_state = rng_state
@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
unused = torch.empty(0, device=query.device)
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
else:
log.info("Using Flash Attention backend for varlen_attn")
dq, dk, dv = torch.ops.aten._flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
return dq, dk, dv
@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
"""
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
return grad_query, grad_key, grad_value
def _backward(
ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
query = ctx.query
key = ctx.key
value = ctx.value
cu_seq_q = ctx.cu_seq_q
cu_seq_k = ctx.cu_seq_k
max_q = ctx.max_q
max_k = ctx.max_k
is_causal = ctx.is_causal
out = ctx.output
lse = ctx.lse
rng_state = ctx.rng_state
# rng_state = torch.empty(2, device=query.device)
dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
is_causal,
rng_state,
)
return dq, dk, dv, None, None, None, None, None, None
_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
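For orientation, a compact sketch of driving the op end to end through the public wrapper; the shapes, dtypes, and sequence lengths are illustrative, and a CUDA device with a supported attention backend is assumed:
import torch
from torch.nn.attention.varlen import varlen_attn

device, dtype = "cuda", torch.bfloat16
num_heads, head_dim = 8, 64
seq_lens = torch.tensor([3, 5], device=device)
total = int(seq_lens.sum())

cu_seq = torch.zeros(seq_lens.numel() + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lens.cumsum(0)
max_len = int(seq_lens.max())

q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
out.sum().backward()          # routes through torch_attn._varlen_attn_backward
assert q.grad.shape == q.shape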

View File

@ -43,7 +43,6 @@ from torch.testing._internal.common_utils import (
_TestParametrizer,
skipIfMPS,
skipIfTorchDynamo,
skipIfXpu,
TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
@ -2201,9 +2200,6 @@ optim_db: list[OptimizerInfo] = [
"TestOptimRenewed",
device_type="mps",
),
DecorateInfo(
skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",

View File

@ -447,6 +447,56 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
def ceil_div(a, b):
return (a + b - 1) // b
# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
# more naturally.
def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
# Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
# Output should be [input.size(0), input.size(1) // blocksize] scales
output_scales = torch.zeros(
(input.size(0), input.size(1) // blocksize),
device=input.device,
dtype=input_scales.dtype,
)
# Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
# happened for offset purposes.
# There are K//blocksize scales, padded to groups of 4.
num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)
# (Very) slow reference implementation using horrifying loops.
for i in range(input.size(0)):
for j in range(input.size(1) // blocksize):
# which 128x4 tile of scaling factors am I in
scale_tile_h = i // 128
scale_tile_w = j // 4
# There are (padded) input_scales.size(1) // 4 tiles along the w dim.
# So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)
# indices within the tile - use nomenclature directly from cublas docs
outer = i % 128 # "outer" in cublas docs
inner = j % 4 # "inner" in cublas docs
# Note: "offset" is given in terms of bytes, in cublas docs, but our scales are e8m0,
# anyway, and so 1B == 1 value => use offset directly.
# Formula directly from cublas docs in 3.1.4.3.2
offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner
output_scales[i, j] = input_scales[offset]
return output_scales
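As a worked instance of the offset formula above (values illustrative): with blocksize = 32 and num_col_tiles = 2, the scale at (i, j) = (130, 5) falls in tile (scale_tile_h, scale_tile_w) = (1, 1), so tile_offset = 512 * (1 * 2 + 1) = 1536; with outer = 130 % 128 = 2 and inner = 5 % 4 = 1, the final offset is 1536 + 2 * 16 + 0 * 4 + 1 = 1569, i.e. output_scales[130, 5] = input_scales[1569].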
def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
# expand scales
scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)
# de-scale and convert
x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
return x_f32.to(torch.bfloat16)
def to_blocked(input_matrix) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

View File

@ -950,6 +950,13 @@ def prof_meth_call(*args, **kwargs):
torch._C.ScriptFunction.__call__ = prof_func_call # type: ignore[method-assign]
torch._C.ScriptMethod.__call__ = prof_meth_call # type: ignore[method-assign]
def _get_test_report_path():
# allow users to override the test file location. We need this
# because the distributed tests run the same test file multiple
# times with different configurations.
override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
test_source = override if override is not None else 'python-unittest'
return os.path.join('test-reports', test_source)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
@ -980,7 +987,9 @@ def parse_cmd_line_args():
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', type=str)
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
@ -1010,9 +1019,6 @@ def parse_cmd_line_args():
# infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
if args.save_xml is None and IS_CI:
args.xml_dir = get_report_dir(sys.argv[0], args.log_suffix, args.use_pytest)
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
SLOW_TESTS_FILE = args.import_slow_tests
@ -1185,37 +1191,19 @@ def lint_test_case_extension(suite):
return succeed
def get_report_dir(test_name: str, log_suffix: Optional[str], is_pytest: bool) -> str:
"""Generates a test report directory path. Test name does not need to be
sanitized."""
# total path = test-reports/test_source+log_suffix/test_filename
# Base path
test_source = "python-unittest"
if is_pytest:
test_source = "python-pytest"
# allow users to override the test file location. We need this
# because the distributed tests run the same test file multiple
# times with different configurations.
override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
if override is not None:
test_source = override
# Add log suffix to if provided
if log_suffix and log_suffix != "":
test_source = test_source + log_suffix
test_report_dir = os.path.join('test-reports', test_source)
# Add test file name to path
test_filename = sanitize_test_filename(test_name)
test_report_dir = os.path.join(test_report_dir, test_filename)
os.makedirs(test_report_dir, exist_ok=True)
return test_report_dir
def get_report_path(report_dir: str, test_filename: str) -> str:
return os.path.join(report_dir, f"{sanitize_test_filename(test_filename)}-{os.urandom(8).hex()}.xml")
def get_report_path(argv=None, pytest=False):
if argv is None:
argv = UNITTEST_ARGS
test_filename = sanitize_test_filename(argv[0])
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename)
if pytest:
test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
os.makedirs(test_report_path, exist_ok=True)
test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
return test_report_path
os.makedirs(test_report_path, exist_ok=True)
return test_report_path
def sanitize_pytest_xml(xml_file: str):
@ -1358,7 +1346,7 @@ def run_tests(argv=None):
pytest_args = argv + ["--use-main-module"]
test_report_path = ""
if TEST_SAVE_XML:
test_report_path = get_report_path(TEST_SAVE_XML, argv[0])
test_report_path = get_report_path(pytest=True)
print(f'Test results will be stored in {test_report_path}')
pytest_args.append(f'--junit-xml-reruns={test_report_path}')
if PYTEST_SINGLE_TEST:
@ -1402,7 +1390,7 @@ def run_tests(argv=None):
def printErrors(self) -> None:
super().printErrors()
self.printErrorList("XPASS", self.unexpectedSuccesses)
test_report_path = get_report_path(TEST_SAVE_XML, argv[0])
test_report_path = get_report_path()
verbose = '--verbose' in argv or '-v' in argv
if verbose:
print(f'Test results will be stored in {test_report_path}')

View File

@ -7,7 +7,8 @@ import inspect
import sys
import warnings
from collections.abc import Callable
from typing import Any, cast, TypeVar
from typing import Any, cast, overload, TypeVar
from typing_extensions import Self
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@ -158,7 +159,12 @@ class _DecoratorContextManager:
class _NoParamDecoratorContextManager(_DecoratorContextManager):
"""Allow a context manager to be used as a decorator without parentheses."""
def __new__(cls, orig_func=None):
@overload
def __new__(cls, orig_func: F) -> F: ... # type: ignore[misc]
@overload
def __new__(cls, orig_func: None = None) -> Self: ...
def __new__(cls, orig_func: F | None = None) -> Self | F: # type: ignore[misc]
if orig_func is None:
return super().__new__(cls)
return cls()(orig_func)
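The two overloads above correspond to the two decorator spellings; a sketch using torch.no_grad, which derives from this helper (the functions are illustrative):
import torch

@torch.no_grad          # bare form: __new__ receives orig_func and returns the wrapped function
def f(x):
    return x * 2

@torch.no_grad()        # called form: __new__ returns an instance that then decorates f
def g(x):
    return x + 1

x = torch.ones(3, requires_grad=True)
assert not f(x).requires_grad and not g(x).requires_grad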

View File

@ -311,7 +311,11 @@ def escape(n):
def is_cuda_tensor(obj):
return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
return (
isinstance(obj, torch.Tensor) and
obj.device.type == "cuda" and
not isinstance(obj, torch._subclasses.FakeTensor)
)
def cuda_allocation_context():
snapshot = torch.cuda.memory._snapshot()