Compare commits

..

11 Commits

Author SHA1 Message Date
04c8da3809 distributed/debug: add an HTTP server for debugging running jobs 2025-11-12 16:48:52 -08:00
e401a56b96 [ez] Remove some dead code from test artifact related files (#166966)
Remove circle ci path since it's no longer used

Remove function that is not used
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166966
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-07 18:14:44 +00:00
22650c89fb [ROCm] Update skip_if_lt_x_gpu to work with MultiProcContinuous class (#167281)
- Since MultiProcContinuous class spawns one process per GPU and runs UT in each of the processes, we need to ensure we are propagating the exit code associated with skip all the way to the main worker thread that spawned all the child processes.
- This commit also updates several UTs that are meant for 4 GPUs but incorrectly calls skip_if_lt_x_gpu with 2 as an input. Examples:
    - test_replicate_with_fsdp.py
    - test_dtensor_resharding.py
    - test_state_dict.py
    - test_functional_api.py: Fix typo. multi-accelerator doesn't exit, replaced with multi-gpu
    - test_op_strategy.py: world_size was hardcoded
    - test_math_ops.py: UT written for 4 GPU, so skipping for anything less
    - test_schedule_multiproc.py: All UTs in this suite are required to run on 2+ GPUs, therefore, adding skips if less than 4 GPUs are supplied

Fixes https://github.com/pytorch/pytorch/issues/166875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167281
Approved by: https://github.com/jeffdaily
2025-11-07 18:11:48 +00:00
c62a17a2fb [ez] Remove some unused vars in common_utils.py (#166453)
I can't find where these are used
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166453
Approved by: https://github.com/malfet
2025-11-07 18:09:40 +00:00
713e289ae7 [dynamo][pytree] support more optree functions by polyfill the underlying CXX functions directly (#167292)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167292
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167221, #167211
2025-11-07 18:09:19 +00:00
69784a0dbe [dynamo][pytree] add polyfills for optree path APIs (#167211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167211
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167221
2025-11-07 17:53:32 +00:00
3c2409c465 Refactor recursive call of collect_temp_source (#166714)
Recursive function call creates a reference cycle: closure <- function <- cell inside closure
Capturing self (PyCodegen instance) in same closure prolongs it's life until next gc.collect() which might result in worse resource management

After the introduction of e9209e0 OOM issues has been observed. Looking for reference cycles one has been uncovered that would result in the prolonging lifetime of tensors. As the result of that OOM issues might occur. Such a dependency chain has been uncovered:
<img width="1059" height="540" alt="image" src="https://github.com/user-attachments/assets/359a8534-e7cd-491f-be40-547c2af5cbbc" />

At the end of it a reference cycle can be found that consists of a closure for function collect_temp_source, the function itself, and a cell object inside closure that would point to the function due to the recursive call.

This issue can either be resolved by removing recurrency or removing PyCodegen instance from the closure.
Another precaution that can be made is to explicitly empty f_locals dict. This way we cut the tensor from the chain leading to reference cycle.

Fixes #166721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166714
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007, https://github.com/jeromean, https://github.com/williamwen42, https://github.com/mlazos
2025-11-07 17:52:54 +00:00
724cd32b0c [PT2 Compiler] Add flag in dynamo disable wrapper to indicate reursive disable (#165790)
Summary: After torch._dynamo.disable is applied, wrapped method does not have any flag to indicate whether it was disabled recursively or not. This flag is needed if to preserve dynamo disable methods in torch.export-ed model

Test Plan:
```
buck test mode/opt caffe2/test/dynamo:test_dynamo -- 'test_disable_recursive_flags'
````
https://www.internalfb.com/intern/testinfra/testrun/7599824674075603

Differential Revision: D84949143

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165790
Approved by: https://github.com/angelayi, https://github.com/williamwen42
2025-11-07 17:48:20 +00:00
b62935d1a5 fix alpha beta in decomp (#167317)
fix for https://github.com/pytorch/pytorch/issues/167313

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167317
Approved by: https://github.com/zou3519
ghstack dependencies: #161404
2025-11-07 17:42:13 +00:00
ccc8c117dc Codeowner/Labeler updates post-Blas-reorgs (#167130)
Summary:

Previous PRs have split out scaled/grouped Blas routines into
their own files. This updates the codeowners and labeler to reflect
those changes.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167130
Approved by: https://github.com/drisspg
2025-11-07 17:27:41 +00:00
86db4de10f [PP] PP Runtime Features for supporting Graph Based execution (#167277)
Allow overriding UNSHARD, RESHARD and REDUCE_GRAD actions.
Enable running pp backward without torch.grad.is_enabled().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167277
Approved by: https://github.com/wconstab
2025-11-07 17:11:14 +00:00
35 changed files with 904 additions and 201 deletions

22
.github/labeler.yml vendored
View File

@ -138,7 +138,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -148,7 +149,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -158,20 +160,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/mps":
- aten/src/ATen/mps/**
- aten/src/ATen/native/mps/**
- torch/_inductor/codegen/mps.py
- test/test_mps.py
- test/inductor/test_mps_basic.py
"ciflow/h100-symm-mem":
- torch/csrc/distributed/c10d/symm_mem/**
- torch/distributed/_symmetric_memory/**
- test/distributed/**/*mem*
- test/distributed/**/*mem*/**

View File

@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision GEMMs
# Low Precision & Grouped GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
store=dist.FileStore(self.file_name, self.world_size),
)
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_transformer(self):
"""
This tests that replicate works on a transformer model with fully_shard and replicate layers
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Shard(dim=0),))
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_transformer_managed_modules(self):
"""
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
replicate_model = replicate(replicate_model)
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_tp_device_mesh(self):
"""
This tests that a user can pass in a device mesh to replicate a module
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_train_replicate_fsdp(self):
"""
Tests that replicate_model has the same behavior as original model when training
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(replicate_loss, loss)
check_sharded_parity(self, model, replicate_model)
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_train_parity_2d_mlp(self):
"""
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training

View File

@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
"""
Saving a dtensor with uneven shards.
@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_checkpointable_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@ -498,6 +499,7 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.

View File

@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.assertEqual(cpu_model_value, meta_model_value)
@with_comms
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
# This test verifies that we can set model state dict by a meta device model
# With the correlated changes in state_dict, meta device model should be accepted

View File

@ -39,6 +39,7 @@ from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@ -231,6 +232,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
@skip_if_lt_x_gpu(4)
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
x_clone = x.clone()
@ -274,6 +276,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_eval_inference_mode(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -351,6 +354,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -406,6 +410,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
chunks = 4
@ -429,6 +434,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_kwargs_with_tracer(self, ScheduleClass):
mod = ModelWithKwargs(d_hid, splits=self.world_size)
mod.to(self.device)
@ -481,6 +487,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -523,6 +530,7 @@ class ScheduleTest(MultiProcContinuousTest):
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
@skip_if_lt_x_gpu(4)
def test_grad_with_manual(self, ScheduleClass, shape_inference):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -586,6 +594,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_grad_with_manual_interleaved(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -650,6 +659,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
@skip_if_lt_x_gpu(4)
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -736,6 +746,7 @@ class ScheduleTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@skip_if_lt_x_gpu(4)
def test_v_shape_schedules(self, schedule_class):
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
@ -780,6 +791,7 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@skip_if_lt_x_gpu(4)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
@ -979,6 +991,7 @@ class ScheduleTest(MultiProcContinuousTest):
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
@skip_if_lt_x_gpu(4)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -1072,6 +1085,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@skip_if_lt_x_gpu(4)
def test_non_symmetric_stage_ids(self, schedule_class):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@ -1121,6 +1135,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@skip_if_lt_x_gpu(4)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@ -1181,6 +1196,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW])
@skip_if_lt_x_gpu(4)
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
n_stages = ScheduleClass.n_stages
num_microbatches = ScheduleClass.num_microbatches

View File

@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -764,6 +765,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skip_if_lt_x_gpu(4)
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(

View File

@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(
comm_mode.get_comm_counts(),
{
torch.ops.c10d_functional.all_gather_into_tensor: 4,
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
},
)
expected_cost = [

View File

@ -0,0 +1,50 @@
# Owner(s): ["oncall: distributed"]
import os
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase
session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
class TestDebug(TestCase):
def test_basics(self) -> None:
store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
port = 25999
def fetch(path: str) -> str:
resp = session.get(f"http://localhost:{port}{path}")
resp.raise_for_status()
return resp.text
print("starting!")
start_debug_server(port=port)
self.assertIn("torch profiler", fetch("/"))
self.assertIn("View 0", fetch("/profile?duration=0.01"))
self.assertIn("test_basics", fetch("/stacks"))
self.assertIn("pg_status", fetch("/fr_trace"))
self.assertIn("pg_status", fetch("/fr_trace_nccl"))
stop_debug_server()
if __name__ == "__main__":
run_tests()

View File

@ -485,7 +485,7 @@ elif TEST_XPU:
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
def with_comms(func=None):

View File

@ -2109,6 +2109,52 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
with self.assertRaises(Unsupported):
outer_f2(inp)
def test_disable_recursive_flags(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
for recursive_flag in [True, False]:
model = SimpleModel()
other_model = SimpleModel()
model.forward = torch._dynamo.disable(
model.forward,
recursive=recursive_flag,
)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(model.forward),
recursive_flag,
)
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(
other_model.forward
if isinstance(other_model, torch.nn.Module)
else other_model
),
recursive_flag,
)
# check the model is compilable
torch.compile(model)
torch.compile(other_model)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -422,34 +422,41 @@ from user code:
import optree
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return re.sub(
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
def fn1(x):
tree = {"a": x, "b": (x - 1, 2 * x)}
sin, cos = optree.tree_transpose_map(
lambda t: (torch.sin(t), torch.cos(t)),
tree,
)
return sin, cos
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
fn1(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 0)
@torch.compile(backend="eager")
def fn2(x):
spec = optree.treespec_deque([])
return spec, x
fn2(torch.randn(4))
self.assertGreaterEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
def post_munge(string):
return re.sub(
r"(optree\.|qualname: )\S*(\.make_from_collection)",
r"\1<path>\2",
string,
)
self.assertExpectedInline(
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -1217,6 +1217,43 @@ class TestPatternMatcher(TestCase):
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_addmm_alpha_beta_with_pointwise(self):
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
# See https://github.com/pytorch/pytorch/issues/167313
x = torch.rand(2, device=GPU_TYPE)
a = torch.rand(2, 3, device=GPU_TYPE)
b = torch.rand(3, 2, device=GPU_TYPE)
def f(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))
fc = torch.compile(f)
expected = f(x, a, b)
actual = fc(x, a, b)
# The compiled version should produce the same result as eager
torch.testing.assert_close(actual, expected)
# Verify that addmm is unfused (should not use extern_kernels.addmm)
# The pattern should be replaced with beta * x + alpha * (a @ b)
_, (code) = run_and_get_code(fc, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
# Test with alpha=1, beta=1 (default) - should also unfuse
def f_default(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b))
fc_default = torch.compile(f_default)
expected_default = f_default(x, a, b)
actual_default = fc_default(x, a, b)
torch.testing.assert_close(actual_default, expected_default)
# Should unfuse and not use extern_kernels.addmm
_, (code) = run_and_get_code(fc_default, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_serialized_patterns_up_to_date(self):
import torch.utils._pytree as pytree
from torch._inductor.fx_passes import joint_graph

View File

@ -165,21 +165,6 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str,
return flattened
def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int
) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON
test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
test_cases.extend(
parse_xml_report(
"testcase", xml_report, workflow_run_id, workflow_run_attempt
)
)
return test_cases
def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does
@ -258,21 +243,11 @@ if __name__ == "__main__":
required=True,
help="Head repository of the workflow",
)
parser.add_argument(
"--circleci",
action="store_true",
help="If this is being run through circleci",
)
args = parser.parse_args()
print(f"Workflow id is: {args.workflow_run_id}")
if args.circleci:
test_cases = get_tests_for_circleci(
args.workflow_run_id, args.workflow_run_attempt
)
else:
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
# Flush stdout so that any errors in the upload show up last in the logs.
sys.stdout.flush()

View File

@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
@ -206,6 +208,7 @@ class Store:
desired_value: str,
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def multi_get(self, keys: list[str]) -> list[bytes]: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):
def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...
class _Request:
def body(self) -> bytes: ...
def get_param(self, str) -> str: ...
class _Response:
def set_content(self, content: str | bytes, content_type: str) -> None: ...
def set_status(self, status: int) -> None: ...
def _register_handler(
name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...

View File

@ -60,6 +60,7 @@ class _ExperimentalConfig:
verbose: bool = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
profile_all_threads: bool = ...,
) -> None: ...
class ProfilerConfig:

View File

@ -32,6 +32,7 @@ from .decorators import (
error_on_graph_break,
forbid_in_graph,
graph_break,
is_dynamo_disable_recursive,
mark_dynamic,
mark_static,
mark_static_address,
@ -87,6 +88,7 @@ __all__ = [
"forbid_in_graph",
"graph_break",
"is_compiling",
"is_dynamo_disable_recursive",
"list_backends",
"lookup_backend",
"mark_dynamic",

View File

@ -15,7 +15,7 @@ import dataclasses
import re
import sys
import types
from collections import Counter
from collections import Counter, deque
from collections.abc import Callable, Iterable
from typing import Any, Optional, TYPE_CHECKING, Union
@ -597,32 +597,35 @@ class PyCodegen:
graphargs = self.tx.output.graphargs
seen_sources: OrderedSet[Source] = OrderedSet()
def collect_temp_source(source: Source) -> None:
if source in seen_sources:
# This source is used at least twice, so it can be reused
self.mark_source_temp(source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
return
seen_sources.add(source)
def extract_nested_sources(source: Source) -> list[Source]:
nested_sources: list[Source] = []
if isinstance(source, ChainedSource):
collect_temp_source(source.base)
nested_sources.append(source.base)
if isinstance(source, DictGetItemSource) and isinstance(
source.index, Source
):
collect_temp_source(source.index)
nested_sources.append(source.index)
return nested_sources
def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
seen_sources: OrderedSet[Source] = OrderedSet()
while sources:
current_source = sources.popleft()
if current_source in seen_sources:
# This source is used at least twice, so it can be reused
codegen.mark_source_temp(current_source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
continue
seen_sources.add(current_source)
sources.extend(extract_nested_sources(current_source))
# Collect all the sources that are used more than once, so that we can
# generate tmp variables in the generated pre-graph bytecode. This
# essentially implements CSE.
for arg in graphargs:
if arg.source is not None:
collect_temp_source(arg.source)
collect_temp_sources(
deque([arg.source for arg in graphargs if arg.source is not None]), self
)
cm_var = None
if config.record_runtime_overhead:

View File

@ -828,6 +828,7 @@ def trace_frame(
raise
finally:
tracer.output.call_cleanup_hooks()
tracer.f_locals = {}
try:
run_tracer()

View File

@ -96,6 +96,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_recursive = False # type: ignore[attr-defined]
# pyrefly: ignore [bad-return]
return nonrecursive_disable_wrapper
@ -1023,3 +1024,13 @@ def error_on_graph_break(
The default value of torch.compile's `error_on_graph_break` setting is False.
"""
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)
def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> Optional[bool]:
"""
Check if a method is marked as `dynamo_disable` recursively. It returns:
- True if disable(recursive=True)
- False if disable(recursive=False)
- None if method is not a disable decorator
"""
return getattr(method, "_torchdynamo_disable_recursive", None)

View File

@ -1155,6 +1155,8 @@ class DisableContext(_TorchDynamoContext):
# of decorators.
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
_fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined]
return _fn
def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:

View File

@ -11,6 +11,16 @@ from typing import Any, TYPE_CHECKING, TypeVar
import optree
import optree._C
import optree.utils
from optree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
namedtuple_fields,
structseq_fields,
)
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
@ -26,7 +36,26 @@ if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
__all__: list[str] = []
__all__ = [
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
"treespec_leaf",
"treespec_tuple",
"treespec_dict",
"tree_is_leaf",
"tree_iter",
"tree_leaves",
"tree_flatten",
"tree_flatten_with_path",
"tree_structure",
"tree_unflatten",
]
_T = TypeVar("_T")
@ -48,21 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
__name = ""
for __name in (
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
for __name, __func in (
("is_namedtuple", is_namedtuple),
("is_namedtuple_class", is_namedtuple_class),
("is_namedtuple_instance", is_namedtuple_instance),
("is_structseq", is_structseq),
("is_structseq_class", is_structseq_class),
("is_structseq_instance", is_structseq_instance),
("namedtuple_fields", namedtuple_fields),
("structseq_fields", structseq_fields),
):
__func = getattr(optree, __name)
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
__func.__python_implementation__
)
__all__ += [__name] # noqa: PLE0604
globals()[__name] = substitute_in_graph(
__func, # type: ignore[arg-type]
can_constant_fold_through=True,
)(__func.__python_implementation__) # type: ignore[attr-defined]
del __func
del __name
@ -78,7 +106,7 @@ def tree_is_leaf(
) -> bool:
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
return True
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None:
return True
return False
@ -113,9 +141,6 @@ def tree_iter(
stack.extend(reversed(children))
__all__ += ["tree_iter"]
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type]
def tree_leaves(
tree: PyTree,
@ -135,9 +160,6 @@ def tree_leaves(
)
__all__ += ["tree_leaves"]
class _Asterisk(str):
__slots__ = ()
@ -168,7 +190,7 @@ class PyTreeSpec:
num_leaves: int = field(init=False)
num_children: int = field(init=False)
def __post_init__(self) -> None:
def __post_init__(self, /) -> None:
if self._type is None:
assert len(self._children) == 0
assert self._metadata is None
@ -187,7 +209,7 @@ class PyTreeSpec:
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
def __repr__(self) -> str:
def __repr__(self, /) -> str:
def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf():
assert treespec.type is None
@ -221,29 +243,78 @@ class PyTreeSpec:
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self) -> int:
def __len__(self, /) -> int:
return self.num_leaves
@property
def type(self) -> builtins.type | None:
def type(self, /) -> builtins.type | None:
return self._type
def is_leaf(self) -> bool:
def is_leaf(self, /) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[PyTreeSpec]:
def paths(self, /) -> list[tuple[Any, ...]]:
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
if treespec.is_leaf():
paths.append(path_prefix)
return
for entry, subspec in zip(
treespec._entries,
treespec._children,
strict=True,
):
helper(subspec, path_prefix + [entry])
paths: list[list[Any]] = []
helper(self, [])
return [tuple(path) for path in paths]
def accessors(self, /) -> list[optree.PyTreeAccessor]:
def helper(
treespec: PyTreeSpec,
entry_path_prefix: list[optree.PyTreeEntry],
) -> None:
if treespec.is_leaf():
entry_paths.append(entry_path_prefix)
return
node_type = treespec.type
assert node_type is not None
handler = optree.register_pytree_node.get(
node_type, namespace=treespec.namespace
)
assert handler is not None
kind: optree.PyTreeKind = handler.kind
path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
for entry, subspec in zip(
treespec._entries,
treespec._children,
strict=True,
):
helper(
subspec,
entry_path_prefix + [path_entry_type(entry, node_type, kind)],
)
entry_paths: list[list[optree.PyTreeEntry]] = []
helper(self, [])
return [optree.PyTreeAccessor(path) for path in entry_paths]
def children(self, /) -> list[PyTreeSpec]:
return list(self._children)
def child(self, index: int) -> PyTreeSpec:
def child(self, index: int, /) -> PyTreeSpec:
return self._children[index]
def entries(self) -> list[Any]:
def entries(self, /) -> list[Any]:
return list(self._entries)
def entry(self, index: int) -> Any:
def entry(self, index: int, /) -> Any:
return self._entries[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
def helper(
treespec: PyTreeSpec,
node: PyTree,
@ -324,14 +395,14 @@ class PyTreeSpec:
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec._children):
for subtree, subspec in zip(children, treespec._children, strict=True):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
helper(self, tree, subtrees)
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
@ -408,7 +479,7 @@ def treespec_tuple(
"All children PyTreeSpecs must have the same `namespace` value "
f"as the parent; expected {namespace!r}, got: {children!r}.",
)
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
handler = optree.register_pytree_node.get(tuple, namespace=namespace)
assert handler is not None
return PyTreeSpec(
tuple(children),
@ -531,7 +602,69 @@ def tree_flatten(
return leaves, treespec
__all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type]
optree._C.flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def _C_flatten(
tree: PyTree,
/,
leaf_predicate: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
return tree_flatten( # type: ignore[return-value]
tree,
is_leaf=leaf_predicate,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
optree.tree_flatten_with_path,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten_with_path(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return treespec.paths(), leaves, treespec # type: ignore[return-value]
@substitute_in_graph( # type: ignore[arg-type]
optree._C.flatten_with_path,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def _C_flatten_with_path(
tree: PyTree,
/,
leaf_predicate: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
return tree_flatten_with_path( # type: ignore[return-value]
tree,
is_leaf=leaf_predicate,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@substitute_in_graph( # type: ignore[arg-type]
@ -556,9 +689,6 @@ def tree_structure(
)[1]
__all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type]
optree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
@ -574,55 +704,6 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
return treespec.unflatten(leaves)
__all__ += ["tree_unflatten"]
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type]
def tree_map(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type]
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]
_none_registration = optree.register_pytree_node.get(type(None))
assert _none_registration is not None
@ -669,5 +750,5 @@ def dict_unflatten(
) -> dict[_KT, _VT]:
original_keys, sorted_keys = metadata
d = dict.fromkeys(original_keys)
d.update(zip(sorted_keys, values))
d.update(zip(sorted_keys, values, strict=True))
return d # type: ignore[return-value]

View File

@ -1516,17 +1516,29 @@ def should_prefer_unfused_addmm(match):
@register_graph_pattern(
CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
CallFunction(
aten.addmm,
KeywordArg("inp"),
Arg(),
Arg(),
beta=KeywordArg("beta"),
alpha=KeywordArg("alpha"),
),
# pyrefly: ignore [bad-argument-type]
pass_dict=pass_patterns[2],
extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
def repl(inp, x1, x2):
return x1 @ x2 + inp
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
def repl(inp, x1, x2, alpha, beta):
mm_result = x1 @ x2
if alpha != 1:
mm_result = alpha * mm_result
if beta != 1:
inp = beta * inp
return inp + mm_result
# pyrefly: ignore [bad-argument-type]
match.replace_by_example(repl, [inp, mat1, mat2])
match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])
def is_valid_addmm_fusion(match):

View File

@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {

View File

@ -18,6 +18,14 @@ class TORCH_API Request {
virtual const std::string& body() const = 0;
virtual const std::multimap<std::string, std::string>& params() const = 0;
std::string getParam(const std::string& key) const {
auto it = params().find(key);
if (it != params().end()) {
return it->second;
}
return "";
}
};
// Response represents a response to the handler. This conceptually maps to an

View File

@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {

View File

@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane

View File

@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, const py::function& handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
// Default constructor.
.def(py::init<>())
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
.def("get_param", &::c10d::control_plane::Request::getParam);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(

View File

@ -0,0 +1,52 @@
import multiprocessing
import socket
# import for registration side effect
import torch.distributed.debug._handlers # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client
__all__ = [
"start_debug_server",
"stop_debug_server",
]
_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None
def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _WORKER_SERVER is None, "debug server already started"
assert _DEBUG_SERVER_PROC is None, "debug server already started"
store = tcpstore_client()
_WORKER_SERVER = _WorkerServer("::", worker_port)
RANK = get_rank()
store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")
from torch.distributed.debug._flask import main
if RANK == 0:
_DEBUG_SERVER_PROC = multiprocessing.Process(
target=main, args=(port,), daemon=True
)
_DEBUG_SERVER_PROC.start()
def stop_debug_server() -> None:
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _DEBUG_SERVER_PROC is not None
assert _WORKER_SERVER is not None
_DEBUG_SERVER_PROC.terminate()
_WORKER_SERVER.shutdown()
_DEBUG_SERVER_PROC.join()
_WORKER_SERVER = None
_DEBUG_SERVER_PROC = None

View File

@ -0,0 +1,265 @@
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator
import requests
from flask import Flask, render_template, request
from jinja2 import DictLoader
from torch.distributed.debug._store import get_world_size, tcpstore_client
def fetch_all(
endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
store = tcpstore_client()
keys = [f"rank{r}" for r in range(get_world_size())]
addrs = store.multi_get(keys)
addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
with ThreadPoolExecutor(max_workers=10) as executor:
resps = executor.map(requests.post, addrs)
return addrs, resps
def format_json(blob: str):
parsed = json.loads(blob)
return json.dumps(parsed, indent=2)
templates = {
"base.html": """
<!doctype html>
<head>
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
<style>
body {
margin: 0;
font-family:
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
font-size: 1rem;
font-weight: 400;
line-height: 1.5;
color: #212529;
text-align: left;
background-color: #fff;
}
h1, h2, h2, h4, h5, h6, .h1, .h2, .h2, .h4, .h5, .h6 {
margin-bottom: .5rem;
font-weight: 500;
line-height: 1.2;
}
nav {
background-color: rgba(0, 0, 0, 0.17);
padding: 10px;
display: flex;
align-items: center;
padding: 16px;
justify-content: flex-start;
}
nav h1 {
display: inline-block;
margin: 0;
}
nav a {
margin: 0 8px;
}
section {
max-width: 1280px;
padding: 16px;
margin: 0 auto;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
</head>
<nav>
<h1>Torch Distributed Debug Server</h1>
<a href="/">Home</a>
<a href="/stacks">Python Stack Traces</a>
<a href="/fr_trace">FlightRecorder</a>
<a href="/fr_trace_nccl">FlightRecorder NCCL</a>
<a href="/profile">torch profiler</a>
</nav>
<section class="content">
{% block header %}{% endblock %}
{% for message in get_flashed_messages() %}
<div class="flash">{{ message }}</div>
{% endfor %}
{% block content %}{% endblock %}
</section>
""",
"index.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
Hi
{% endblock %}
""",
"raw_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ resp.text }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"json_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ format_json(resp.text) }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"profile.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}
{% block content %}
<form action="/profile" method="get">
<label for="duration">Duration (seconds):</label>
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
<input type="submit" value="Submit">
</form>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<script>
function run{{ i }}() {
var data = {{ resp.text | safe }};
openPerfetto(data);
}
</script>
<button onclick="run{{ i }}()">View {{ i }}</button>
{% endif %}
{% endfor %}
{% endblock %}
""",
}
app = Flask(__name__)
app.jinja_loader = DictLoader(templates)
app.jinja_env.globals.update(
zip=zip,
format_json=format_json,
enumerate=enumerate,
)
@app.route("/")
def _index_handler():
return render_template("index.html")
@app.route("/stacks")
def _stacks_handler():
addrs, resps = fetch_all("dump_traceback")
return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps)
@app.route("/fr_trace")
def _fr_trace_handler():
addrs, resps = fetch_all("fr_trace_json")
return render_template(
"json_resp.html",
title="FlightRecorder",
addrs=addrs,
resps=resps,
)
@app.route("/fr_trace_nccl")
def _fr_trace_nccl_handler():
addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
return render_template(
"json_resp.html",
title="FlightRecorder NCCL",
addrs=addrs,
resps=resps,
)
@app.route("/profile")
def _profiler_handler():
duration = request.args.get("duration", default=1.0, type=float)
addrs, resps = fetch_all("torch_profile", f"duration={duration}")
return render_template("profile.html", addrs=addrs, resps=resps)
def main(port: int) -> None:
app.run(host="::", port=port)

View File

@ -0,0 +1,22 @@
import tempfile
import time
from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile
def _torch_profile(req: _Request, resp: _Response) -> None:
experimental_config = _ExperimentalConfig(
profile_all_threads=True,
)
duration = float(req.get_param("duration"))
with profile(record_shapes=True, experimental_config=experimental_config) as prof:
time.sleep(duration)
with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
prof.export_chrome_trace(f.name)
resp.set_content(open(f.name, "rb").read(), "application/json")
resp.set_status(200)
_register_handler("torch_profile", _torch_profile)

View File

@ -0,0 +1,24 @@
import os
import torch.distributed as dist
def get_rank() -> int:
return int(os.environ["RANK"])
def get_world_size() -> int:
return int(os.environ["WORLD_SIZE"])
def tcpstore_client() -> dist.Store:
MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
store = dist.TCPStore(
host_name=MASTER_ADDR,
port=MASTER_PORT,
is_master=False,
)
store = dist.PrefixStore("debug_server", store)
return store

View File

@ -1485,6 +1485,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
use_full_backward: Optional[bool] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# Init parent
super().__init__(
@ -1517,6 +1518,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
# This will be set during init of derived schedules
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
# When using a custom backward function, we may or may not need autograd to be used
# for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
# check should be performed before the step function.
self._backward_requires_autograd = backward_requires_autograd
if use_full_backward is not None:
logger.warning(
"Deprecation warning: 'use_full_backward' is no longer supported. "
@ -1609,7 +1615,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
losses: a list to store the losses for each microbatch.
return_outputs: whether to return the outputs from the last stage.
"""
if self._has_backward and not torch.is_grad_enabled():
if (
self._has_backward
and self._backward_requires_autograd
and not torch.is_grad_enabled()
):
raise RuntimeError(
"step() requires gradients to be enabled for backward computation; "
"it should not be used under torch.no_grad() context. "
@ -1891,7 +1901,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
Args:
computation_type: The computation type for which to register the custom function
custom_function: The function to execute when this computation type is encountered.
Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None
Must have signature: (action: _Action, ctx: _PipelineContext) -> None
"""
# Ensure that the computation type is valid
if computation_type not in (
@ -1900,10 +1910,13 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
BACKWARD_INPUT,
BACKWARD_WEIGHT,
OVERLAP_F_B,
UNSHARD,
RESHARD,
REDUCE_GRAD,
):
raise ValueError(
f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
)
# Check if computation_type is already registered
@ -2296,6 +2309,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
loss_fn: Optional[Union[Callable, _Loss]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
super().__init__(
stages=stages,
@ -2303,6 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
# 1. Create the pipeline_order (all ranks do this calculation)
@ -2510,6 +2525,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
self.pp_group_size = stages[0].group_size
super().__init__(
@ -2520,6 +2536,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
@ -2622,6 +2639,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we dont support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2634,6 +2652,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
@ -2819,6 +2838,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we dont support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2831,6 +2851,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"
@ -2995,6 +3016,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we dont support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -3007,6 +3029,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"

View File

@ -1771,6 +1771,22 @@ class MultiProcContinuousTest(TestCase):
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
except BaseException as ex: # noqa: B036
if isinstance(ex, SystemExit):
# Get exit code from the process
exit_code = getattr(ex, "code", None)
# Look up exit code in TEST_SKIPS to see if it is a valid skip
skip_entry = next(
(v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
None,
)
# If we found an entry, we want to skip the test and the object back to the main process
if skip_entry:
completion_queue.put(unittest.SkipTest(skip_entry.message))
# Skip exception handling below, move to main thread for processing the skip
continue
raised_exception = True
# Send the exception and stack trace back to the dispatcher
exc_info = sys.exc_info()
@ -1892,6 +1908,8 @@ class MultiProcContinuousTest(TestCase):
# Wait for the workers to finish the test
for i, completion_queue in enumerate(self.completion_queues):
rv = completion_queue.get()
if isinstance(rv, unittest.SkipTest):
raise rv
if isinstance(rv, BaseException):
# Hit an exception, re-raise it in the main process.
logger.warning(

View File

@ -114,8 +114,6 @@ class ProfilingMode(Enum):
PROFILING = 3
# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
@ -959,8 +957,6 @@ def _get_test_report_path():
return os.path.join('test-reports', test_source)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
global CI_PT_ROOT
global CI_TEST_PREFIX
global DISABLED_TESTS_FILE
global GRAPH_EXECUTOR
@ -1039,10 +1035,8 @@ def parse_cmd_line_args():
set_rng_seed()
# CI Prefix path used only on CI environment
# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p, timeout=None):
try: