Compare commits

...

14 Commits

Author SHA1 Message Date
67f237a245 distributed/debug: add an HTTP server for debugging running jobs 2025-11-07 18:35:48 -08:00
e401a56b96 [ez] Remove some dead code from test artifact related files (#166966)
Remove circle ci path since it's no longer used

Remove function that is not used
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166966
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-07 18:14:44 +00:00
22650c89fb [ROCm] Update skip_if_lt_x_gpu to work with MultiProcContinuous class (#167281)
- Since MultiProcContinuous class spawns one process per GPU and runs UT in each of the processes, we need to ensure we are propagating the exit code associated with skip all the way to the main worker thread that spawned all the child processes.
- This commit also updates several UTs that are meant for 4 GPUs but incorrectly calls skip_if_lt_x_gpu with 2 as an input. Examples:
    - test_replicate_with_fsdp.py
    - test_dtensor_resharding.py
    - test_state_dict.py
    - test_functional_api.py: Fix typo. multi-accelerator doesn't exit, replaced with multi-gpu
    - test_op_strategy.py: world_size was hardcoded
    - test_math_ops.py: UT written for 4 GPU, so skipping for anything less
    - test_schedule_multiproc.py: All UTs in this suite are required to run on 2+ GPUs, therefore, adding skips if less than 4 GPUs are supplied

Fixes https://github.com/pytorch/pytorch/issues/166875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167281
Approved by: https://github.com/jeffdaily
2025-11-07 18:11:48 +00:00
c62a17a2fb [ez] Remove some unused vars in common_utils.py (#166453)
I can't find where these are used
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166453
Approved by: https://github.com/malfet
2025-11-07 18:09:40 +00:00
713e289ae7 [dynamo][pytree] support more optree functions by polyfill the underlying CXX functions directly (#167292)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167292
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167221, #167211
2025-11-07 18:09:19 +00:00
69784a0dbe [dynamo][pytree] add polyfills for optree path APIs (#167211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167211
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167221
2025-11-07 17:53:32 +00:00
3c2409c465 Refactor recursive call of collect_temp_source (#166714)
Recursive function call creates a reference cycle: closure <- function <- cell inside closure
Capturing self (PyCodegen instance) in same closure prolongs it's life until next gc.collect() which might result in worse resource management

After the introduction of e9209e0 OOM issues has been observed. Looking for reference cycles one has been uncovered that would result in the prolonging lifetime of tensors. As the result of that OOM issues might occur. Such a dependency chain has been uncovered:
<img width="1059" height="540" alt="image" src="https://github.com/user-attachments/assets/359a8534-e7cd-491f-be40-547c2af5cbbc" />

At the end of it a reference cycle can be found that consists of a closure for function collect_temp_source, the function itself, and a cell object inside closure that would point to the function due to the recursive call.

This issue can either be resolved by removing recurrency or removing PyCodegen instance from the closure.
Another precaution that can be made is to explicitly empty f_locals dict. This way we cut the tensor from the chain leading to reference cycle.

Fixes #166721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166714
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007, https://github.com/jeromean, https://github.com/williamwen42, https://github.com/mlazos
2025-11-07 17:52:54 +00:00
724cd32b0c [PT2 Compiler] Add flag in dynamo disable wrapper to indicate reursive disable (#165790)
Summary: After torch._dynamo.disable is applied, wrapped method does not have any flag to indicate whether it was disabled recursively or not. This flag is needed if to preserve dynamo disable methods in torch.export-ed model

Test Plan:
```
buck test mode/opt caffe2/test/dynamo:test_dynamo -- 'test_disable_recursive_flags'
````
https://www.internalfb.com/intern/testinfra/testrun/7599824674075603

Differential Revision: D84949143

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165790
Approved by: https://github.com/angelayi, https://github.com/williamwen42
2025-11-07 17:48:20 +00:00
b62935d1a5 fix alpha beta in decomp (#167317)
fix for https://github.com/pytorch/pytorch/issues/167313

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167317
Approved by: https://github.com/zou3519
ghstack dependencies: #161404
2025-11-07 17:42:13 +00:00
ccc8c117dc Codeowner/Labeler updates post-Blas-reorgs (#167130)
Summary:

Previous PRs have split out scaled/grouped Blas routines into
their own files. This updates the codeowners and labeler to reflect
those changes.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167130
Approved by: https://github.com/drisspg
2025-11-07 17:27:41 +00:00
86db4de10f [PP] PP Runtime Features for supporting Graph Based execution (#167277)
Allow overriding UNSHARD, RESHARD and REDUCE_GRAD actions.
Enable running pp backward without torch.grad.is_enabled().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167277
Approved by: https://github.com/wconstab
2025-11-07 17:11:14 +00:00
12860892f8 Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167182)"
This reverts commit 77b70970f70d53de71b9703ad4c3199d714c535a.

Reverted https://github.com/pytorch/pytorch/pull/167182 on behalf of https://github.com/NikhilAPatel due to breaks local source build ([comment](https://github.com/pytorch/pytorch/pull/167182#issuecomment-3503598156))
2025-11-07 16:45:23 +00:00
694592ac1e Move enrich_profiler_metadata config import out of gm.recompile() (#167114)
Fixes T243967987

Move `enrich_profiler_metadata` from `torch._dynamo.config` to `torch.fx.experimental._config`.

We cannot import anything inside recompile(), it made some perf regress internally. We move the config so we can import it at the top of `graph_module.py` without causing any circular import.

We also cannot delete the old config right now because some internal tests rely on copies of the old `graph_module.py` cpp file in unit tests. But I think we should be able to delete the old config soon after this PR lands.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167114
Approved by: https://github.com/angelayi
2025-11-07 16:12:47 +00:00
285748e838 fix the cpp_builder error under riscv (#167071)
**fix the cpp_builder error under riscv**

`g++: error: ‘-march=native’: ISA string must begin with rv32 or rv64`

(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]   File "/usr/local/lib64/python3.11/site-packages/torch/_inductor/cpp_builder.py", line 1718, in build
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]     run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]   File "/usr/local/lib64/python3.11/site-packages/torch/_inductor/cpp_builder.py", line 401, in run_compile_cmd
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]     _run_compile_cmd(cmd_line, cwd)
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]   File "/usr/local/lib64/python3.11/site-packages/torch/_inductor/cpp_builder.py", line 396, in _run_compile_cmd
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]     raise exc.CppCompileError(cmd, output) from e
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] torch._inductor.exc.InductorError: CppCompileError: C++ compile error
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] Command:
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] g++ /tmp/tmpv8qz53jp/header.hpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -fPIC -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fexcess-precision=fast -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -fno-tree-loop-vectorize -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/usr/include/python3.11 -I/usr/local/lib64/python3.11/site-packages/torch/include -I/usr/local/lib64/python3.11/site-packages/torch/include/torch/csrc/api/include -D_GLIBCXX_USE_CXX11_ABI=1 -E -P -o /tmp/tmpv8qz53jp/header.i
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779]
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] Output:
(EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] g++: error: ‘-march=native’: ISA string must begin with rv32 or rv64

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167071
Approved by: https://github.com/malfet
2025-11-07 16:01:30 +00:00
45 changed files with 824 additions and 1064 deletions

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

9
.github/labeler.yml vendored
View File

@ -138,7 +138,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -148,7 +149,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -158,7 +160,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

1
.gitignore vendored
View File

@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision GEMMs
# Low Precision & Grouped GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1647,8 +1616,6 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1682,7 +1649,6 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
store=dist.FileStore(self.file_name, self.world_size),
)
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_transformer(self):
"""
This tests that replicate works on a transformer model with fully_shard and replicate layers
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Shard(dim=0),))
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_transformer_managed_modules(self):
"""
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
replicate_model = replicate(replicate_model)
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_tp_device_mesh(self):
"""
This tests that a user can pass in a device mesh to replicate a module
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_train_replicate_fsdp(self):
"""
Tests that replicate_model has the same behavior as original model when training
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(replicate_loss, loss)
check_sharded_parity(self, model, replicate_model)
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_train_parity_2d_mlp(self):
"""
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training

View File

@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
"""
Saving a dtensor with uneven shards.
@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_checkpointable_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@ -498,6 +499,7 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.

View File

@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.assertEqual(cpu_model_value, meta_model_value)
@with_comms
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
# This test verifies that we can set model state dict by a meta device model
# With the correlated changes in state_dict, meta device model should be accepted

View File

@ -39,6 +39,7 @@ from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@ -231,6 +232,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
@skip_if_lt_x_gpu(4)
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
x_clone = x.clone()
@ -274,6 +276,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_eval_inference_mode(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -351,6 +354,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -406,6 +410,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
chunks = 4
@ -429,6 +434,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_kwargs_with_tracer(self, ScheduleClass):
mod = ModelWithKwargs(d_hid, splits=self.world_size)
mod.to(self.device)
@ -481,6 +487,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -523,6 +530,7 @@ class ScheduleTest(MultiProcContinuousTest):
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
@skip_if_lt_x_gpu(4)
def test_grad_with_manual(self, ScheduleClass, shape_inference):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -586,6 +594,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_grad_with_manual_interleaved(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -650,6 +659,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
@skip_if_lt_x_gpu(4)
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -736,6 +746,7 @@ class ScheduleTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@skip_if_lt_x_gpu(4)
def test_v_shape_schedules(self, schedule_class):
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
@ -780,6 +791,7 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@skip_if_lt_x_gpu(4)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
@ -979,6 +991,7 @@ class ScheduleTest(MultiProcContinuousTest):
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
@skip_if_lt_x_gpu(4)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -1072,6 +1085,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@skip_if_lt_x_gpu(4)
def test_non_symmetric_stage_ids(self, schedule_class):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@ -1121,6 +1135,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@skip_if_lt_x_gpu(4)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@ -1181,6 +1196,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW])
@skip_if_lt_x_gpu(4)
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
n_stages = ScheduleClass.n_stages
num_microbatches = ScheduleClass.num_microbatches

View File

@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -764,6 +765,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skip_if_lt_x_gpu(4)
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(

View File

@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(
comm_mode.get_comm_counts(),
{
torch.ops.c10d_functional.all_gather_into_tensor: 4,
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
},
)
expected_cost = [

View File

@ -485,7 +485,7 @@ elif TEST_XPU:
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
def with_comms(func=None):

View File

@ -2109,6 +2109,52 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
with self.assertRaises(Unsupported):
outer_f2(inp)
def test_disable_recursive_flags(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
for recursive_flag in [True, False]:
model = SimpleModel()
other_model = SimpleModel()
model.forward = torch._dynamo.disable(
model.forward,
recursive=recursive_flag,
)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(model.forward),
recursive_flag,
)
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(
other_model.forward
if isinstance(other_model, torch.nn.Module)
else other_model
),
recursive_flag,
)
# check the model is compilable
torch.compile(model)
torch.compile(other_model)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -422,34 +422,41 @@ from user code:
import optree
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return re.sub(
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
def fn1(x):
tree = {"a": x, "b": (x - 1, 2 * x)}
sin, cos = optree.tree_transpose_map(
lambda t: (torch.sin(t), torch.cos(t)),
tree,
)
return sin, cos
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
fn1(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 0)
@torch.compile(backend="eager")
def fn2(x):
spec = optree.treespec_deque([])
return spec, x
fn2(torch.randn(4))
self.assertGreaterEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
def post_munge(string):
return re.sub(
r"(optree\.|qualname: )\S*(\.make_from_collection)",
r"\1<path>\2",
string,
)
self.assertExpectedInline(
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()

View File

@ -1217,6 +1217,43 @@ class TestPatternMatcher(TestCase):
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_addmm_alpha_beta_with_pointwise(self):
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
# See https://github.com/pytorch/pytorch/issues/167313
x = torch.rand(2, device=GPU_TYPE)
a = torch.rand(2, 3, device=GPU_TYPE)
b = torch.rand(3, 2, device=GPU_TYPE)
def f(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))
fc = torch.compile(f)
expected = f(x, a, b)
actual = fc(x, a, b)
# The compiled version should produce the same result as eager
torch.testing.assert_close(actual, expected)
# Verify that addmm is unfused (should not use extern_kernels.addmm)
# The pattern should be replaced with beta * x + alpha * (a @ b)
_, (code) = run_and_get_code(fc, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
# Test with alpha=1, beta=1 (default) - should also unfuse
def f_default(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b))
fc_default = torch.compile(f_default)
expected_default = f_default(x, a, b)
actual_default = fc_default(x, a, b)
torch.testing.assert_close(actual_default, expected_default)
# Should unfuse and not use extern_kernels.addmm
_, (code) = run_and_get_code(fc_default, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_serialized_patterns_up_to_date(self):
import torch.utils._pytree as pytree
from torch._inductor.fx_passes import joint_graph

View File

@ -7486,7 +7486,7 @@ class TestFXMemoryProfiler(TestCase):
return fx_frames
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""

View File

@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)

View File

@ -165,21 +165,6 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str,
return flattened
def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int
) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON
test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
test_cases.extend(
parse_xml_report(
"testcase", xml_report, workflow_run_id, workflow_run_attempt
)
)
return test_cases
def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does
@ -258,21 +243,11 @@ if __name__ == "__main__":
required=True,
help="Head repository of the workflow",
)
parser.add_argument(
"--circleci",
action="store_true",
help="If this is being run through circleci",
)
args = parser.parse_args()
print(f"Workflow id is: {args.workflow_run_id}")
if args.circleci:
test_cases = get_tests_for_circleci(
args.workflow_run_id, args.workflow_run_attempt
)
else:
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
# Flush stdout so that any errors in the upload show up last in the logs.
sys.stdout.flush()

View File

@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...

View File

@ -32,6 +32,7 @@ from .decorators import (
error_on_graph_break,
forbid_in_graph,
graph_break,
is_dynamo_disable_recursive,
mark_dynamic,
mark_static,
mark_static_address,
@ -87,6 +88,7 @@ __all__ = [
"forbid_in_graph",
"graph_break",
"is_compiling",
"is_dynamo_disable_recursive",
"list_backends",
"lookup_backend",
"mark_dynamic",

View File

@ -15,7 +15,7 @@ import dataclasses
import re
import sys
import types
from collections import Counter
from collections import Counter, deque
from collections.abc import Callable, Iterable
from typing import Any, Optional, TYPE_CHECKING, Union
@ -597,32 +597,35 @@ class PyCodegen:
graphargs = self.tx.output.graphargs
seen_sources: OrderedSet[Source] = OrderedSet()
def collect_temp_source(source: Source) -> None:
if source in seen_sources:
# This source is used at least twice, so it can be reused
self.mark_source_temp(source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
return
seen_sources.add(source)
def extract_nested_sources(source: Source) -> list[Source]:
nested_sources: list[Source] = []
if isinstance(source, ChainedSource):
collect_temp_source(source.base)
nested_sources.append(source.base)
if isinstance(source, DictGetItemSource) and isinstance(
source.index, Source
):
collect_temp_source(source.index)
nested_sources.append(source.index)
return nested_sources
def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
seen_sources: OrderedSet[Source] = OrderedSet()
while sources:
current_source = sources.popleft()
if current_source in seen_sources:
# This source is used at least twice, so it can be reused
codegen.mark_source_temp(current_source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
continue
seen_sources.add(current_source)
sources.extend(extract_nested_sources(current_source))
# Collect all the sources that are used more than once, so that we can
# generate tmp variables in the generated pre-graph bytecode. This
# essentially implements CSE.
for arg in graphargs:
if arg.source is not None:
collect_temp_source(arg.source)
collect_temp_sources(
deque([arg.source for arg in graphargs if arg.source is not None]), self
)
cm_var = None
if config.record_runtime_overhead:

View File

@ -740,11 +740,8 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -828,6 +828,7 @@ def trace_frame(
raise
finally:
tracer.output.call_cleanup_hooks()
tracer.f_locals = {}
try:
run_tracer()

View File

@ -96,6 +96,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_recursive = False # type: ignore[attr-defined]
# pyrefly: ignore [bad-return]
return nonrecursive_disable_wrapper
@ -1023,3 +1024,13 @@ def error_on_graph_break(
The default value of torch.compile's `error_on_graph_break` setting is False.
"""
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)
def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> Optional[bool]:
"""
Check if a method is marked as `dynamo_disable` recursively. It returns:
- True if disable(recursive=True)
- False if disable(recursive=False)
- None if method is not a disable decorator
"""
return getattr(method, "_torchdynamo_disable_recursive", None)

View File

@ -1155,6 +1155,8 @@ class DisableContext(_TorchDynamoContext):
# of decorators.
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
_fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined]
return _fn
def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:

View File

@ -11,6 +11,16 @@ from typing import Any, TYPE_CHECKING, TypeVar
import optree
import optree._C
import optree.utils
from optree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
namedtuple_fields,
structseq_fields,
)
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
@ -26,7 +36,26 @@ if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
__all__: list[str] = []
__all__ = [
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
"treespec_leaf",
"treespec_tuple",
"treespec_dict",
"tree_is_leaf",
"tree_iter",
"tree_leaves",
"tree_flatten",
"tree_flatten_with_path",
"tree_structure",
"tree_unflatten",
]
_T = TypeVar("_T")
@ -48,21 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
__name = ""
for __name in (
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
for __name, __func in (
("is_namedtuple", is_namedtuple),
("is_namedtuple_class", is_namedtuple_class),
("is_namedtuple_instance", is_namedtuple_instance),
("is_structseq", is_structseq),
("is_structseq_class", is_structseq_class),
("is_structseq_instance", is_structseq_instance),
("namedtuple_fields", namedtuple_fields),
("structseq_fields", structseq_fields),
):
__func = getattr(optree, __name)
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
__func.__python_implementation__
)
__all__ += [__name] # noqa: PLE0604
globals()[__name] = substitute_in_graph(
__func, # type: ignore[arg-type]
can_constant_fold_through=True,
)(__func.__python_implementation__) # type: ignore[attr-defined]
del __func
del __name
@ -78,7 +106,7 @@ def tree_is_leaf(
) -> bool:
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
return True
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None:
return True
return False
@ -113,9 +141,6 @@ def tree_iter(
stack.extend(reversed(children))
__all__ += ["tree_iter"]
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type]
def tree_leaves(
tree: PyTree,
@ -135,9 +160,6 @@ def tree_leaves(
)
__all__ += ["tree_leaves"]
class _Asterisk(str):
__slots__ = ()
@ -168,7 +190,7 @@ class PyTreeSpec:
num_leaves: int = field(init=False)
num_children: int = field(init=False)
def __post_init__(self) -> None:
def __post_init__(self, /) -> None:
if self._type is None:
assert len(self._children) == 0
assert self._metadata is None
@ -187,7 +209,7 @@ class PyTreeSpec:
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
def __repr__(self) -> str:
def __repr__(self, /) -> str:
def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf():
assert treespec.type is None
@ -221,29 +243,78 @@ class PyTreeSpec:
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self) -> int:
def __len__(self, /) -> int:
return self.num_leaves
@property
def type(self) -> builtins.type | None:
def type(self, /) -> builtins.type | None:
return self._type
def is_leaf(self) -> bool:
def is_leaf(self, /) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[PyTreeSpec]:
def paths(self, /) -> list[tuple[Any, ...]]:
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
if treespec.is_leaf():
paths.append(path_prefix)
return
for entry, subspec in zip(
treespec._entries,
treespec._children,
strict=True,
):
helper(subspec, path_prefix + [entry])
paths: list[list[Any]] = []
helper(self, [])
return [tuple(path) for path in paths]
def accessors(self, /) -> list[optree.PyTreeAccessor]:
def helper(
treespec: PyTreeSpec,
entry_path_prefix: list[optree.PyTreeEntry],
) -> None:
if treespec.is_leaf():
entry_paths.append(entry_path_prefix)
return
node_type = treespec.type
assert node_type is not None
handler = optree.register_pytree_node.get(
node_type, namespace=treespec.namespace
)
assert handler is not None
kind: optree.PyTreeKind = handler.kind
path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
for entry, subspec in zip(
treespec._entries,
treespec._children,
strict=True,
):
helper(
subspec,
entry_path_prefix + [path_entry_type(entry, node_type, kind)],
)
entry_paths: list[list[optree.PyTreeEntry]] = []
helper(self, [])
return [optree.PyTreeAccessor(path) for path in entry_paths]
def children(self, /) -> list[PyTreeSpec]:
return list(self._children)
def child(self, index: int) -> PyTreeSpec:
def child(self, index: int, /) -> PyTreeSpec:
return self._children[index]
def entries(self) -> list[Any]:
def entries(self, /) -> list[Any]:
return list(self._entries)
def entry(self, index: int) -> Any:
def entry(self, index: int, /) -> Any:
return self._entries[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
def helper(
treespec: PyTreeSpec,
node: PyTree,
@ -324,14 +395,14 @@ class PyTreeSpec:
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec._children):
for subtree, subspec in zip(children, treespec._children, strict=True):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
helper(self, tree, subtrees)
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
@ -408,7 +479,7 @@ def treespec_tuple(
"All children PyTreeSpecs must have the same `namespace` value "
f"as the parent; expected {namespace!r}, got: {children!r}.",
)
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
handler = optree.register_pytree_node.get(tuple, namespace=namespace)
assert handler is not None
return PyTreeSpec(
tuple(children),
@ -531,7 +602,69 @@ def tree_flatten(
return leaves, treespec
__all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type]
optree._C.flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def _C_flatten(
tree: PyTree,
/,
leaf_predicate: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
return tree_flatten( # type: ignore[return-value]
tree,
is_leaf=leaf_predicate,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
optree.tree_flatten_with_path,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten_with_path(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return treespec.paths(), leaves, treespec # type: ignore[return-value]
@substitute_in_graph( # type: ignore[arg-type]
optree._C.flatten_with_path,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def _C_flatten_with_path(
tree: PyTree,
/,
leaf_predicate: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
return tree_flatten_with_path( # type: ignore[return-value]
tree,
is_leaf=leaf_predicate,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
@ -556,9 +689,6 @@ def tree_structure(
)[1]
__all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type]
optree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
@ -574,55 +704,6 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
return treespec.unflatten(leaves)
__all__ += ["tree_unflatten"]
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type]
def tree_map(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type]
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]
_none_registration = optree.register_pytree_node.get(type(None))
assert _none_registration is not None
@ -669,5 +750,5 @@ def dict_unflatten(
) -> dict[_KT, _VT]:
original_keys, sorted_keys = metadata
d = dict.fromkeys(original_keys)
d.update(zip(sorted_keys, values))
d.update(zip(sorted_keys, values, strict=True))
return d # type: ignore[return-value]

View File

@ -550,10 +550,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False

View File

@ -913,6 +913,10 @@ def _get_optimization_cflags(
if not config.is_fbcode():
if platform.machine() == "ppc64le":
cflags.append("mcpu=native")
elif platform.machine() == "riscv64":
cflags.append("march=rv64gc")
elif platform.machine() == "riscv32":
cflags.append("march=rv32gc")
else:
cflags.append("march=native")

View File

@ -1516,17 +1516,29 @@ def should_prefer_unfused_addmm(match):
@register_graph_pattern(
CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
CallFunction(
aten.addmm,
KeywordArg("inp"),
Arg(),
Arg(),
beta=KeywordArg("beta"),
alpha=KeywordArg("alpha"),
),
# pyrefly: ignore [bad-argument-type]
pass_dict=pass_patterns[2],
extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
def repl(inp, x1, x2):
return x1 @ x2 + inp
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
def repl(inp, x1, x2, alpha, beta):
mm_result = x1 @ x2
if alpha != 1:
mm_result = alpha * mm_result
if beta != 1:
inp = beta * inp
return inp + mm_result
# pyrefly: ignore [bad-argument-type]
match.replace_by_example(repl, [inp, mat1, mat2])
match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])
def is_valid_addmm_fusion(match):

View File

@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)

View File

@ -1,13 +1,11 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -24,13 +22,11 @@ from ..utils import (
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
@ -517,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -723,44 +714,43 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -798,22 +788,6 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None

View File

@ -1,333 +0,0 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
shapes, strides, and dtypes of input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)

View File

@ -1,141 +0,0 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]

View File

@ -1911,84 +1911,6 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL backend is enabled
2. CuTeDSL is available
3. We are on a blackwell arch
4. The dtype is bf16
5. Max autotune or max autotune gemm is enabled
6. A, B, and the output are 16B aligned
7. We are not using dynamic shapes
8. A is 2d
9. B is 3d
10. Offsets are provided
11. Bias and Scale are not provided
"""
if not ensure_cute_available():
return False
if not _use_autotune_backend("CUTEDSL"):
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not is_gpu(layout.device.type):
return False
if not is_datacenter_blackwell_arch():
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V

View File

@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {

View File

@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {

View File

@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane

View File

@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@ -381,8 +382,9 @@ void _register_comm_hook(
::c10d::Reducer& reducer,
py::object state,
py::object comm_hook) {
reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
reducer.register_comm_hook(
std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
}
// Called from DDP's Python API to create a c10d C++ comm hook.
@ -882,37 +884,39 @@ This class does not support ``__members__`` property.)");
[](const ::c10d::ReduceOp& self, const py::dict& memo) {
return ::c10d::ReduceOp(self);
})
.def(py::pickle(
[](const ::c10d::ReduceOp& r) {
// __getstate__
if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return py::make_tuple(r.op_, py::none());
}
TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
const auto* preMulSupplement =
reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
r.supplement_.get());
if (!preMulSupplement->tensor_factor.defined()) {
return py::make_tuple(r.op_, preMulSupplement->double_factor);
} else {
return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
}
},
[](const py::tuple& t) {
// __setstate__
TORCH_CHECK(t.size() == 2, "Invalid state");
const auto op =
static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return ::c10d::ReduceOp(op);
}
const auto preMulSupplement_factor = t[1];
if (py::isinstance<py::float_>(preMulSupplement_factor)) {
return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
} else {
return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
}
}));
.def(
py::pickle(
[](const ::c10d::ReduceOp& r) {
// __getstate__
if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return py::make_tuple(r.op_, py::none());
}
TORCH_CHECK(
r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
const auto* preMulSupplement =
reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
r.supplement_.get());
if (!preMulSupplement->tensor_factor.defined()) {
return py::make_tuple(r.op_, preMulSupplement->double_factor);
} else {
return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
}
},
[](const py::tuple& t) {
// __setstate__
TORCH_CHECK(t.size() == 2, "Invalid state");
const auto op = static_cast<::c10d::ReduceOp::RedOpType>(
t[0].cast<uint8_t>());
if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return ::c10d::ReduceOp(op);
}
const auto preMulSupplement_factor = t[1];
if (py::isinstance<py::float_>(preMulSupplement_factor)) {
return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
} else {
return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
}
}));
py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
.value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
@ -3579,10 +3583,11 @@ Example::
[](std::optional<bool> includeCollectives,
std::optional<bool> includeStackTraces,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_xccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_xccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("includeStackTraces") = std::optional<bool>(),
@ -4112,8 +4117,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
"_dump_nccl_trace_json",
[](std::optional<bool> includeCollectives,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_nccl_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_nccl_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("onlyActive") = std::optional<bool>(),
@ -4130,10 +4136,11 @@ such as `dist.all_reduce(tensor, async_op=True)`.
[](std::optional<bool> includeCollectives,
std::optional<bool> includeStackTraces,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_nccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_nccl_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("includeStackTraces") = std::optional<bool>(),
@ -4157,8 +4164,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
"_dump_fr_trace_json",
[](std::optional<bool> includeCollectives,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_fr_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_fr_trace_json(
includeCollectives.value_or(true), onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("onlyActive") = std::optional<bool>(),
@ -4175,10 +4183,11 @@ such as `dist.all_reduce(tensor, async_op=True)`.
[](std::optional<bool> includeCollectives,
std::optional<bool> includeStackTraces,
std::optional<bool> onlyActive) {
return py::bytes(::c10d::dump_fr_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
return py::bytes(
::c10d::dump_fr_trace(
includeCollectives.value_or(true),
includeStackTraces.value_or(true),
onlyActive.value_or(false)));
},
py::arg("includeCollectives") = std::optional<bool>(),
py::arg("includeStackTraces") = std::optional<bool>(),
@ -4203,7 +4212,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@ -4219,6 +4230,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, py::function handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@ -4238,10 +4268,7 @@ such as `dist.all_reduce(tensor, async_op=True)`.
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(
@ -4267,9 +4294,10 @@ such as `dist.all_reduce(tensor, async_op=True)`.
} // namespace
// c10d methods on torch._C
static PyMethodDef methods[] = { // NOLINT
{"_c10d_init", c10d_init, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
static PyMethodDef methods[] =
{ // NOLINT
{"_c10d_init", c10d_init, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
// NOLINTNEXTLINE(misc-use-internal-linkage)
PyMethodDef* python_functions() {

239
torch/distributed/debug.py Normal file
View File

@ -0,0 +1,239 @@
import os
import socket
import multiprocessing
import requests
from concurrent.futures import ThreadPoolExecutor
import json
import time
import tempfile
from flask import Flask
import torch.distributed as dist
from torch.profiler import (
profile,
ProfilerActivity,
record_function,
_ExperimentalConfig,
)
from torch._C._distributed_c10d import _WorkerServer, _register_handler
def _torch_profile(req, resp):
experimental_config = _ExperimentalConfig(
profile_all_threads=True,
)
with profile(record_shapes=True, experimental_config=experimental_config) as prof:
time.sleep(2)
with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
prof.export_chrome_trace(f.name)
resp.set_content(open(f.name, "rb").read(), "application/json")
resp.set_status(200)
_register_handler("torch_profile", _torch_profile)
MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
def _tcpstore_client() -> dist.Store:
store = dist.TCPStore(
host_name=MASTER_ADDR,
port=MASTER_PORT,
is_master=False,
)
store = dist.PrefixStore("debug_server", store)
return store
def fetch_all(endpoint: str) -> list[bytes]:
store = _tcpstore_client()
keys = [f"rank{r}" for r in range(WORLD_SIZE)]
addrs = store.multi_get(keys)
addrs = [f"{addr.decode()}/handler/{endpoint}" for addr in addrs]
with ThreadPoolExecutor(max_workers=10) as executor:
resps = executor.map(requests.post, addrs)
return addrs, resps
app = Flask(__name__)
def nav():
return """
<style>
body {
font-family: sans-serif;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
<h1>Torch Distributed Debug Server</h1>
<ul>
<li><a href="/">Home</a></li>
<li><a href="/stacks">Python Stack Traces</a></li>
<li><a href="/fr_trace">FlightRecorder</a></li>
<li><a href="/fr_trace_nccl">FlightRecorder NCCL</a></li>
<li><a href="/profile">torch profiler</a></li>
</ul>
"""
@app.route("/")
def index():
return nav()
@app.route("/stacks")
def stacks():
addrs, resps = fetch_all("dump_traceback")
def generate():
yield nav()
yield "<h2>Stacks</h2>"
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = resp.text
yield f"<pre>{stack}</pre>"
return generate()
def format_json(blob: str):
parsed = json.loads(blob)
return json.dumps(parsed, indent=2)
@app.route("/fr_trace")
def fr_trace():
addrs, resps = fetch_all("fr_trace_json")
def generate():
yield nav()
yield "<h2>FlightRecorder</h2>"
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = format_json(resp.text)
yield f"<pre>{stack}</pre>"
return generate()
@app.route("/fr_trace_nccl")
def fr_trace_nccl():
addrs, resps = fetch_all("dump_nccl_trace_json?onlyactive=true")
def generate():
yield nav()
yield "<h2>FlightRecorder NCCL</h2>"
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = format_json(resp.text)
yield f"<pre>{stack}</pre>"
return generate()
@app.route("/profile")
def profiler():
addrs, resps = fetch_all("torch_profile")
def generate():
yield nav()
yield """
<h2>torch profile</h2>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
"""
for i, addr, resp in zip(range(len(addrs)), addrs, resps):
yield f"<h3>Rank {i}: {addr}</h3>"
if resp.status_code != 200:
yield f"<p>Failed to fetch: status={resp.status_code}</p>"
stack = resp.text
yield f"""
<script>
function run{i}() {{
var data = {stack};
openPerfetto(data);
}}
</script>
<button onclick="run{i}()">View {i}</button>
"""
return generate()
def _interactive_server() -> None:
app.run(host="::", port=25999)
def enable_debug_server() -> None:
global _worker_server, _p
store = _tcpstore_client()
_worker_server = _WorkerServer("::", 0)
store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_worker_server.port}")
if RANK == 0:
_p = multiprocessing.Process(
target=_interactive_server,
)
_p.start()

View File

@ -1485,6 +1485,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
use_full_backward: Optional[bool] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# Init parent
super().__init__(
@ -1517,6 +1518,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
# This will be set during init of derived schedules
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
# When using a custom backward function, we may or may not need autograd to be used
# for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
# check should be performed before the step function.
self._backward_requires_autograd = backward_requires_autograd
if use_full_backward is not None:
logger.warning(
"Deprecation warning: 'use_full_backward' is no longer supported. "
@ -1609,7 +1615,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
losses: a list to store the losses for each microbatch.
return_outputs: whether to return the outputs from the last stage.
"""
if self._has_backward and not torch.is_grad_enabled():
if (
self._has_backward
and self._backward_requires_autograd
and not torch.is_grad_enabled()
):
raise RuntimeError(
"step() requires gradients to be enabled for backward computation; "
"it should not be used under torch.no_grad() context. "
@ -1891,7 +1901,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
Args:
computation_type: The computation type for which to register the custom function
custom_function: The function to execute when this computation type is encountered.
Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None
Must have signature: (action: _Action, ctx: _PipelineContext) -> None
"""
# Ensure that the computation type is valid
if computation_type not in (
@ -1900,10 +1910,13 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
BACKWARD_INPUT,
BACKWARD_WEIGHT,
OVERLAP_F_B,
UNSHARD,
RESHARD,
REDUCE_GRAD,
):
raise ValueError(
f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
)
# Check if computation_type is already registered
@ -2296,6 +2309,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
loss_fn: Optional[Union[Callable, _Loss]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
super().__init__(
stages=stages,
@ -2303,6 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
# 1. Create the pipeline_order (all ranks do this calculation)
@ -2510,6 +2525,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
self.pp_group_size = stages[0].group_size
super().__init__(
@ -2520,6 +2536,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
@ -2622,6 +2639,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we dont support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2634,6 +2652,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
@ -2819,6 +2838,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we dont support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2831,6 +2851,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"
@ -2995,6 +3016,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we dont support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -3007,6 +3029,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"

View File

@ -2,6 +2,8 @@ import os
import sys
from typing import Optional
from torch.utils._config_module import Config, install_config_module
# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
skip_dtype_check_in_meta_registrations = False
from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
install_config_module(sys.modules[__name__])

View File

@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
_BoxedCodeGen,
_custom_builtins,
@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
# Do not import anything inside recompile, it might slow down the
# function and cause perf regression. Import outside of the method instead.
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
root_module="self",
record_func=fx_experimental_config.enrich_profiler_metadata,
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
if dynamo_config.enrich_profiler_metadata:
if fx_experimental_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):

View File

@ -1771,6 +1771,22 @@ class MultiProcContinuousTest(TestCase):
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
except BaseException as ex: # noqa: B036
if isinstance(ex, SystemExit):
# Get exit code from the process
exit_code = getattr(ex, "code", None)
# Look up exit code in TEST_SKIPS to see if it is a valid skip
skip_entry = next(
(v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
None,
)
# If we found an entry, we want to skip the test and the object back to the main process
if skip_entry:
completion_queue.put(unittest.SkipTest(skip_entry.message))
# Skip exception handling below, move to main thread for processing the skip
continue
raised_exception = True
# Send the exception and stack trace back to the dispatcher
exc_info = sys.exc_info()
@ -1892,6 +1908,8 @@ class MultiProcContinuousTest(TestCase):
# Wait for the workers to finish the test
for i, completion_queue in enumerate(self.completion_queues):
rv = completion_queue.get()
if isinstance(rv, unittest.SkipTest):
raise rv
if isinstance(rv, BaseException):
# Hit an exception, re-raise it in the main process.
logger.warning(

View File

@ -114,8 +114,6 @@ class ProfilingMode(Enum):
PROFILING = 3
# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
@ -959,8 +957,6 @@ def _get_test_report_path():
return os.path.join('test-reports', test_source)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
global CI_PT_ROOT
global CI_TEST_PREFIX
global DISABLED_TESTS_FILE
global GRAPH_EXECUTOR
@ -1039,10 +1035,8 @@ def parse_cmd_line_args():
set_rng_seed()
# CI Prefix path used only on CI environment
# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p, timeout=None):
try: