mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-13 21:59:56 +08:00

Compare commits (11 commits): csl/remove… ... d4l3k/debu…

| SHA1 |
|---|
| 04c8da3809 |
| e401a56b96 |
| 22650c89fb |
| c62a17a2fb |
| 713e289ae7 |
| 69784a0dbe |
| 3c2409c465 |
| 724cd32b0c |
| b62935d1a5 |
| ccc8c117dc |
| 86db4de10f |

.github/labeler.yml (vendored, 22 lines changed)
@@ -138,7 +138,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@@ -148,7 +149,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@@ -158,20 +160,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

"ciflow/mps":
- aten/src/ATen/mps/**
- aten/src/ATen/native/mps/**
- torch/_inductor/codegen/mps.py
- test/test_mps.py
- test/inductor/test_mps_basic.py

"ciflow/h100-symm-mem":
- torch/csrc/distributed/c10d/symm_mem/**
- torch/distributed/_symmetric_memory/**
- test/distributed/**/*mem*
- test/distributed/**/*mem*/**
@@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

# Low Precision GEMMs
# Low Precision & Grouped GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
@@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
store=dist.FileStore(self.file_name, self.world_size),
)

@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_transformer(self):
"""
This tests that replicate works on a transformer model with fully_shard and replicate layers
@@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Shard(dim=0),))

@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_transformer_managed_modules(self):
"""
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
@@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
replicate_model = replicate(replicate_model)
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)

@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_replicate_tp_device_mesh(self):
"""
This tests that a user can pass in a device mesh to replicate a module
@@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))

@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_train_replicate_fsdp(self):
"""
Tests that replicate_model has the same behavior as original model when training
@@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(replicate_loss, loss)
check_sharded_parity(self, model, replicate_model)

@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_train_parity_2d_mlp(self):
"""
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training
@@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):

@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
"""
Saving a dtensor with uneven shards.
@@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase):

@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_checkpointable_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@@ -498,6 +499,7 @@ class TestCheckpointableReshard(DTensorTestBase):

@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.assertEqual(cpu_model_value, meta_model_value)

@with_comms
@skip_if_lt_x_gpu(2)
@skip_if_lt_x_gpu(4)
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
# This test verifies that we can set model state dict by a meta device model
# With the correlated changes in state_dict, meta device model should be accepted
@@ -39,6 +39,7 @@ from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@@ -231,6 +232,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
@skip_if_lt_x_gpu(4)
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
x_clone = x.clone()
@@ -274,6 +276,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_eval_inference_mode(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@@ -351,6 +354,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@@ -406,6 +410,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
chunks = 4
@@ -429,6 +434,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_kwargs_with_tracer(self, ScheduleClass):
mod = ModelWithKwargs(d_hid, splits=self.world_size)
mod.to(self.device)
@@ -481,6 +487,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)

@@ -523,6 +530,7 @@ class ScheduleTest(MultiProcContinuousTest):
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
@skip_if_lt_x_gpu(4)
def test_grad_with_manual(self, ScheduleClass, shape_inference):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)

@@ -586,6 +594,7 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_grad_with_manual_interleaved(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@@ -650,6 +659,7 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
@skip_if_lt_x_gpu(4)
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@@ -736,6 +746,7 @@ class ScheduleTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@skip_if_lt_x_gpu(4)
def test_v_shape_schedules(self, schedule_class):
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
@@ -780,6 +791,7 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@skip_if_lt_x_gpu(4)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
@@ -979,6 +991,7 @@ class ScheduleTest(MultiProcContinuousTest):
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
@skip_if_lt_x_gpu(4)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@@ -1072,6 +1085,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@skip_if_lt_x_gpu(4)
def test_non_symmetric_stage_ids(self, schedule_class):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@@ -1121,6 +1135,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@skip_if_lt_x_gpu(4)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@@ -1181,6 +1196,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW])
@skip_if_lt_x_gpu(4)
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
n_stages = ScheduleClass.n_stages
num_microbatches = ScheduleClass.num_microbatches
@@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@@ -764,6 +765,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)

@with_comms
@skip_if_lt_x_gpu(4)
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(
@@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(
comm_mode.get_comm_counts(),
{
torch.ops.c10d_functional.all_gather_into_tensor: 4,
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
},
)
expected_cost = [
test/distributed/test_debug.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Owner(s): ["oncall: distributed"]

import os

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase


session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)


class TestDebug(TestCase):
    def test_basics(self) -> None:
        store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(store.port)
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"

        port = 25999

        def fetch(path: str) -> str:
            resp = session.get(f"http://localhost:{port}{path}")
            resp.raise_for_status()
            return resp.text

        print("starting!")

        start_debug_server(port=port)

        self.assertIn("torch profiler", fetch("/"))
        self.assertIn("View 0", fetch("/profile?duration=0.01"))
        self.assertIn("test_basics", fetch("/stacks"))
        self.assertIn("pg_status", fetch("/fr_trace"))
        self.assertIn("pg_status", fetch("/fr_trace_nccl"))

        stop_debug_server()


if __name__ == "__main__":
    run_tests()
@@ -485,7 +485,7 @@ elif TEST_XPU:
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)


def with_comms(func=None):
@@ -2109,6 +2109,52 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
with self.assertRaises(Unsupported):
outer_f2(inp)

def test_disable_recursive_flags(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)

def forward(self, inp):
return self.layer0(torch.sigmoid(inp))

class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)

def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)

for recursive_flag in [True, False]:
model = SimpleModel()
other_model = SimpleModel()

model.forward = torch._dynamo.disable(
model.forward,
recursive=recursive_flag,
)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(model.forward),
recursive_flag,
)

other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(
other_model.forward
if isinstance(other_model, torch.nn.Module)
else other_model
),
recursive_flag,
)

# check the model is compilable
torch.compile(model)
torch.compile(other_model)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
@@ -422,34 +422,41 @@ from user code:
import optree

@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten_with_path(d)
return torch.sin(x)

def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return re.sub(
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
def fn1(x):
tree = {"a": x, "b": (x - 1, 2 * x)}
sin, cos = optree.tree_transpose_map(
lambda t: (torch.sin(t), torch.cos(t)),
tree,
)
return sin, cos

fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
fn1(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 0)

@torch.compile(backend="eager")
def fn2(x):
spec = optree.treespec_deque([])
return spec, x

fn2(torch.randn(4))
self.assertGreaterEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))

def post_munge(string):
return re.sub(
r"(optree\.|qualname: )\S*(\.make_from_collection)",
r"\1<path>\2",
string,
)

self.assertExpectedInline(
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py

Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>

For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)
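For reference, a rough `torch.utils._pytree` equivalent of the `optree.tree_transpose_map` call exercised above (an illustrative sketch, not part of the test):

```python
import torch
import torch.utils._pytree as pytree

tree = {"a": torch.randn(4), "b": (torch.randn(4), torch.randn(4))}
leaves, spec = pytree.tree_flatten(tree)
pairs = [(torch.sin(t), torch.cos(t)) for t in leaves]          # map each leaf to (sin, cos)
sin_tree = pytree.tree_unflatten([p[0] for p in pairs], spec)   # regroup into a "sin" tree
cos_tree = pytree.tree_unflatten([p[1] for p in pairs], spec)   # ... and a "cos" tree
```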
@@ -1217,6 +1217,43 @@ class TestPatternMatcher(TestCase):
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])

def test_addmm_alpha_beta_with_pointwise(self):
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
# See https://github.com/pytorch/pytorch/issues/167313
x = torch.rand(2, device=GPU_TYPE)
a = torch.rand(2, 3, device=GPU_TYPE)
b = torch.rand(3, 2, device=GPU_TYPE)

def f(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))

fc = torch.compile(f)

expected = f(x, a, b)
actual = fc(x, a, b)

# The compiled version should produce the same result as eager
torch.testing.assert_close(actual, expected)

# Verify that addmm is unfused (should not use extern_kernels.addmm)
# The pattern should be replaced with beta * x + alpha * (a @ b)
_, (code) = run_and_get_code(fc, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])

# Test with alpha=1, beta=1 (default) - should also unfuse
def f_default(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b))

fc_default = torch.compile(f_default)
expected_default = f_default(x, a, b)
actual_default = fc_default(x, a, b)

torch.testing.assert_close(actual_default, expected_default)

# Should unfuse and not use extern_kernels.addmm
_, (code) = run_and_get_code(fc_default, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])

def test_serialized_patterns_up_to_date(self):
import torch.utils._pytree as pytree
from torch._inductor.fx_passes import joint_graph
@@ -165,21 +165,6 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str,
return flattened


def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int
) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON
test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
test_cases.extend(
parse_xml_report(
"testcase", xml_report, workflow_run_id, workflow_run_attempt
)
)

return test_cases


def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does
@@ -258,21 +243,11 @@ if __name__ == "__main__":
required=True,
help="Head repository of the workflow",
)
parser.add_argument(
"--circleci",
action="store_true",
help="If this is being run through circleci",
)
args = parser.parse_args()

print(f"Workflow id is: {args.workflow_run_id}")

if args.circleci:
test_cases = get_tests_for_circleci(
args.workflow_run_id, args.workflow_run_attempt
)
else:
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)

# Flush stdout so that any errors in the upload show up last in the logs.
sys.stdout.flush()
@@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...

class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int

def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...

def get_debug_level(): ...
@@ -206,6 +208,7 @@ class Store:
desired_value: str,
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def multi_get(self, keys: list[str]) -> list[bytes]: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
@@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):

def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...

class _Request:
def body(self) -> bytes: ...
def get_param(self, str) -> str: ...

class _Response:
def set_content(self, content: str | bytes, content_type: str) -> None: ...
def set_status(self, status: int) -> None: ...

def _register_handler(
name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...
@@ -60,6 +60,7 @@ class _ExperimentalConfig:
verbose: bool = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
profile_all_threads: bool = ...,
) -> None: ...

class ProfilerConfig:
@@ -32,6 +32,7 @@ from .decorators import (
error_on_graph_break,
forbid_in_graph,
graph_break,
is_dynamo_disable_recursive,
mark_dynamic,
mark_static,
mark_static_address,
@@ -87,6 +88,7 @@ __all__ = [
"forbid_in_graph",
"graph_break",
"is_compiling",
"is_dynamo_disable_recursive",
"list_backends",
"lookup_backend",
"mark_dynamic",
@@ -15,7 +15,7 @@ import dataclasses
import re
import sys
import types
from collections import Counter
from collections import Counter, deque
from collections.abc import Callable, Iterable
from typing import Any, Optional, TYPE_CHECKING, Union

@@ -597,32 +597,35 @@ class PyCodegen:

graphargs = self.tx.output.graphargs

seen_sources: OrderedSet[Source] = OrderedSet()

def collect_temp_source(source: Source) -> None:
if source in seen_sources:
# This source is used at least twice, so it can be reused
self.mark_source_temp(source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
return

seen_sources.add(source)

def extract_nested_sources(source: Source) -> list[Source]:
nested_sources: list[Source] = []
if isinstance(source, ChainedSource):
collect_temp_source(source.base)

nested_sources.append(source.base)
if isinstance(source, DictGetItemSource) and isinstance(
source.index, Source
):
collect_temp_source(source.index)
nested_sources.append(source.index)
return nested_sources

def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
seen_sources: OrderedSet[Source] = OrderedSet()
while sources:
current_source = sources.popleft()
if current_source in seen_sources:
# This source is used at least twice, so it can be reused
codegen.mark_source_temp(current_source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
continue
seen_sources.add(current_source)
sources.extend(extract_nested_sources(current_source))

# Collect all the sources that are used more than once, so that we can
# generate tmp variables in the generated pre-graph bytecode. This
# essentially implements CSE.
for arg in graphargs:
if arg.source is not None:
collect_temp_source(arg.source)
collect_temp_sources(
deque([arg.source for arg in graphargs if arg.source is not None]), self
)

cm_var = None
if config.record_runtime_overhead:
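The shape of this refactor, recursion replaced by an explicit worklist so deeply chained sources cannot hit Python's recursion limit, shown as a generic illustrative sketch (the names here are made up, not PyTorch APIs):

```python
from collections import deque

def mark_duplicates(roots, expand, mark):
    # Breadth-first worklist: mark any node reached a second time as reusable (CSE),
    # and do not expand past it.
    seen = set()
    work = deque(roots)
    while work:
        node = work.popleft()
        if node in seen:
            mark(node)
            continue
        seen.add(node)
        work.extend(expand(node))
```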
@@ -828,6 +828,7 @@ def trace_frame(
raise
finally:
tracer.output.call_cleanup_hooks()
tracer.f_locals = {}

try:
run_tracer()
@@ -96,6 +96,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_recursive = False # type: ignore[attr-defined]
# pyrefly: ignore [bad-return]
return nonrecursive_disable_wrapper

@@ -1023,3 +1024,13 @@ def error_on_graph_break(
The default value of torch.compile's `error_on_graph_break` setting is False.
"""
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)


def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> Optional[bool]:
"""
Check if a method is marked as `dynamo_disable` recursively. It returns:
- True if disable(recursive=True)
- False if disable(recursive=False)
- None if method is not a disable decorator
"""
return getattr(method, "_torchdynamo_disable_recursive", None)
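A minimal usage sketch of the new helper alongside `torch._dynamo.disable`, mirroring the semantics in the docstring above (illustrative, not part of the diff):

```python
import torch

def helper(x):
    return x.cos()

skipped = torch._dynamo.disable(helper)                    # recursive=True by default
shallow = torch._dynamo.disable(helper, recursive=False)   # only this frame is skipped

assert torch._dynamo.is_dynamo_disable_recursive(skipped) is True
assert torch._dynamo.is_dynamo_disable_recursive(shallow) is False
assert torch._dynamo.is_dynamo_disable_recursive(helper) is None   # not a disable wrapper
```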
@@ -1155,6 +1155,8 @@ class DisableContext(_TorchDynamoContext):
# of decorators.
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]

_fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined]

return _fn

def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:
|
||||
import optree
|
||||
import optree._C
|
||||
import optree.utils
|
||||
from optree import (
|
||||
is_namedtuple,
|
||||
is_namedtuple_class,
|
||||
is_namedtuple_instance,
|
||||
is_structseq,
|
||||
is_structseq_class,
|
||||
is_structseq_instance,
|
||||
namedtuple_fields,
|
||||
structseq_fields,
|
||||
)
|
||||
|
||||
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
|
||||
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
|
||||
@ -26,7 +36,26 @@ if TYPE_CHECKING:
|
||||
from torch.utils._cxx_pytree import PyTree
|
||||
|
||||
|
||||
__all__: list[str] = []
|
||||
__all__ = [
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
"namedtuple_fields",
|
||||
"structseq_fields",
|
||||
"treespec_leaf",
|
||||
"treespec_tuple",
|
||||
"treespec_dict",
|
||||
"tree_is_leaf",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_structure",
|
||||
"tree_unflatten",
|
||||
]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -48,21 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
|
||||
|
||||
|
||||
__name = ""
|
||||
for __name in (
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
"namedtuple_fields",
|
||||
"structseq_fields",
|
||||
for __name, __func in (
|
||||
("is_namedtuple", is_namedtuple),
|
||||
("is_namedtuple_class", is_namedtuple_class),
|
||||
("is_namedtuple_instance", is_namedtuple_instance),
|
||||
("is_structseq", is_structseq),
|
||||
("is_structseq_class", is_structseq_class),
|
||||
("is_structseq_instance", is_structseq_instance),
|
||||
("namedtuple_fields", namedtuple_fields),
|
||||
("structseq_fields", structseq_fields),
|
||||
):
|
||||
__func = getattr(optree, __name)
|
||||
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
|
||||
__func.__python_implementation__
|
||||
)
|
||||
__all__ += [__name] # noqa: PLE0604
|
||||
globals()[__name] = substitute_in_graph(
|
||||
__func, # type: ignore[arg-type]
|
||||
can_constant_fold_through=True,
|
||||
)(__func.__python_implementation__) # type: ignore[attr-defined]
|
||||
del __func
|
||||
del __name
|
||||
|
||||
@ -78,7 +106,7 @@ def tree_is_leaf(
|
||||
) -> bool:
|
||||
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
|
||||
return True
|
||||
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
|
||||
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -113,9 +141,6 @@ def tree_iter(
|
||||
stack.extend(reversed(children))
|
||||
|
||||
|
||||
__all__ += ["tree_iter"]
|
||||
|
||||
|
||||
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type]
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
@ -135,9 +160,6 @@ def tree_leaves(
|
||||
)
|
||||
|
||||
|
||||
__all__ += ["tree_leaves"]
|
||||
|
||||
|
||||
class _Asterisk(str):
|
||||
__slots__ = ()
|
||||
|
||||
@ -168,7 +190,7 @@ class PyTreeSpec:
|
||||
num_leaves: int = field(init=False)
|
||||
num_children: int = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self, /) -> None:
|
||||
if self._type is None:
|
||||
assert len(self._children) == 0
|
||||
assert self._metadata is None
|
||||
@ -187,7 +209,7 @@ class PyTreeSpec:
|
||||
object.__setattr__(self, "num_leaves", num_leaves)
|
||||
object.__setattr__(self, "num_children", num_children)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self, /) -> str:
|
||||
def helper(treespec: PyTreeSpec) -> str:
|
||||
if treespec.is_leaf():
|
||||
assert treespec.type is None
|
||||
@ -221,29 +243,78 @@ class PyTreeSpec:
|
||||
]
|
||||
return f"PyTreeSpec({', '.join(inner)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self, /) -> int:
|
||||
return self.num_leaves
|
||||
|
||||
@property
|
||||
def type(self) -> builtins.type | None:
|
||||
def type(self, /) -> builtins.type | None:
|
||||
return self._type
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
def is_leaf(self, /) -> bool:
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def children(self) -> list[PyTreeSpec]:
|
||||
def paths(self, /) -> list[tuple[Any, ...]]:
|
||||
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
|
||||
if treespec.is_leaf():
|
||||
paths.append(path_prefix)
|
||||
return
|
||||
|
||||
for entry, subspec in zip(
|
||||
treespec._entries,
|
||||
treespec._children,
|
||||
strict=True,
|
||||
):
|
||||
helper(subspec, path_prefix + [entry])
|
||||
|
||||
paths: list[list[Any]] = []
|
||||
helper(self, [])
|
||||
return [tuple(path) for path in paths]
|
||||
|
||||
def accessors(self, /) -> list[optree.PyTreeAccessor]:
|
||||
def helper(
|
||||
treespec: PyTreeSpec,
|
||||
entry_path_prefix: list[optree.PyTreeEntry],
|
||||
) -> None:
|
||||
if treespec.is_leaf():
|
||||
entry_paths.append(entry_path_prefix)
|
||||
return
|
||||
|
||||
node_type = treespec.type
|
||||
assert node_type is not None
|
||||
handler = optree.register_pytree_node.get(
|
||||
node_type, namespace=treespec.namespace
|
||||
)
|
||||
assert handler is not None
|
||||
kind: optree.PyTreeKind = handler.kind
|
||||
path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
|
||||
|
||||
for entry, subspec in zip(
|
||||
treespec._entries,
|
||||
treespec._children,
|
||||
strict=True,
|
||||
):
|
||||
helper(
|
||||
subspec,
|
||||
entry_path_prefix + [path_entry_type(entry, node_type, kind)],
|
||||
)
|
||||
|
||||
entry_paths: list[list[optree.PyTreeEntry]] = []
|
||||
helper(self, [])
|
||||
return [optree.PyTreeAccessor(path) for path in entry_paths]
|
||||
|
||||
def children(self, /) -> list[PyTreeSpec]:
|
||||
return list(self._children)
|
||||
|
||||
def child(self, index: int) -> PyTreeSpec:
|
||||
def child(self, index: int, /) -> PyTreeSpec:
|
||||
return self._children[index]
|
||||
|
||||
def entries(self) -> list[Any]:
|
||||
def entries(self, /) -> list[Any]:
|
||||
return list(self._entries)
|
||||
|
||||
def entry(self, index: int) -> Any:
|
||||
def entry(self, index: int, /) -> Any:
|
||||
return self._entries[index]
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
||||
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
|
||||
def helper(
|
||||
treespec: PyTreeSpec,
|
||||
node: PyTree,
|
||||
@ -324,14 +395,14 @@ class PyTreeSpec:
|
||||
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
|
||||
)
|
||||
|
||||
for subtree, subspec in zip(children, treespec._children):
|
||||
for subtree, subspec in zip(children, treespec._children, strict=True):
|
||||
helper(subspec, subtree, subtrees)
|
||||
|
||||
subtrees: list[PyTree] = []
|
||||
helper(self, tree, subtrees)
|
||||
return subtrees
|
||||
|
||||
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
|
||||
def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
|
||||
if not isinstance(leaves, (list, tuple)):
|
||||
leaves = list(leaves)
|
||||
if len(leaves) != self.num_leaves:
|
||||
@ -408,7 +479,7 @@ def treespec_tuple(
|
||||
"All children PyTreeSpecs must have the same `namespace` value "
|
||||
f"as the parent; expected {namespace!r}, got: {children!r}.",
|
||||
)
|
||||
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
|
||||
handler = optree.register_pytree_node.get(tuple, namespace=namespace)
|
||||
assert handler is not None
|
||||
return PyTreeSpec(
|
||||
tuple(children),
|
||||
@ -531,7 +602,69 @@ def tree_flatten(
|
||||
return leaves, treespec
|
||||
|
||||
|
||||
__all__ += ["tree_flatten"]
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree._C.flatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def _C_flatten(
|
||||
tree: PyTree,
|
||||
/,
|
||||
leaf_predicate: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[Any], PyTreeSpec]:
|
||||
return tree_flatten( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=leaf_predicate,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.tree_flatten_with_path,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_flatten_with_path(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
return treespec.paths(), leaves, treespec # type: ignore[return-value]
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree._C.flatten_with_path,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def _C_flatten_with_path(
|
||||
tree: PyTree,
|
||||
/,
|
||||
leaf_predicate: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
|
||||
return tree_flatten_with_path( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=leaf_predicate,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
@ -556,9 +689,6 @@ def tree_structure(
|
||||
)[1]
|
||||
|
||||
|
||||
__all__ += ["tree_structure"]
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.tree_unflatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
@ -574,55 +704,6 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
|
||||
__all__ += ["tree_unflatten"]
|
||||
|
||||
|
||||
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type]
|
||||
def tree_map(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
return treespec.unflatten(map(func, *flat_args))
|
||||
|
||||
|
||||
__all__ += ["tree_map"]
|
||||
|
||||
|
||||
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type]
|
||||
def tree_map_(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
|
||||
__all__ += ["tree_map_"]
|
||||
|
||||
_none_registration = optree.register_pytree_node.get(type(None))
|
||||
assert _none_registration is not None
|
||||
|
||||
@ -669,5 +750,5 @@ def dict_unflatten(
|
||||
) -> dict[_KT, _VT]:
|
||||
original_keys, sorted_keys = metadata
|
||||
d = dict.fromkeys(original_keys)
|
||||
d.update(zip(sorted_keys, values))
|
||||
d.update(zip(sorted_keys, values, strict=True))
|
||||
return d # type: ignore[return-value]
|
||||
|
||||
@@ -1516,17 +1516,29 @@ def should_prefer_unfused_addmm(match):


@register_graph_pattern(
CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
CallFunction(
aten.addmm,
KeywordArg("inp"),
Arg(),
Arg(),
beta=KeywordArg("beta"),
alpha=KeywordArg("alpha"),
),
# pyrefly: ignore [bad-argument-type]
pass_dict=pass_patterns[2],
extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
def repl(inp, x1, x2):
return x1 @ x2 + inp
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
def repl(inp, x1, x2, alpha, beta):
mm_result = x1 @ x2
if alpha != 1:
mm_result = alpha * mm_result
if beta != 1:
inp = beta * inp
return inp + mm_result

# pyrefly: ignore [bad-argument-type]
match.replace_by_example(repl, [inp, mat1, mat2])
match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])


def is_valid_addmm_fusion(match):
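For reference, the replacement preserves addmm semantics, `out = beta * inp + alpha * (mat1 @ mat2)`; a quick numerical check (illustrative sketch, not part of the pass):

```python
import torch

inp = torch.randn(2)
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 2)
alpha, beta = 0.8, 0.2

fused = torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta)
unfused = beta * inp + alpha * (mat1 @ mat2)   # what repl() builds when alpha/beta != 1
torch.testing.assert_close(fused, unfused)
```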
@@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>

#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};

RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});

} // namespace

void registerHandler(const std::string& name, HandlerFunc f) {
@@ -18,6 +18,14 @@ class TORCH_API Request {
virtual const std::string& body() const = 0;

virtual const std::multimap<std::string, std::string>& params() const = 0;

std::string getParam(const std::string& key) const {
auto it = params().find(key);
if (it != params().end()) {
return it->second;
}
return "";
}
};

// Response represents a response to the handler. This conceptually maps to an
@@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}

serverThread_ = std::thread([this]() {
@@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {

void shutdown();

int port() {
return port_;
}

private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};

} // namespace c10d::control_plane
@@ -46,6 +46,7 @@

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);

module.def(
"_get_handler",
@@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");

module.def(
"_register_handler",
[](const std::string& name, const py::function& handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},

py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");

module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
// Default constructor.
.def(py::init<>())
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
.def("get_param", &::c10d::control_plane::Request::getParam);

py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(
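Taken together, the bindings above expose the control-plane handlers to Python; a sketch of how they could be driven directly (the handler name and response text here are made up for illustration):

```python
from torch._C._distributed_c10d import _register_handler, _Request, _Response, _WorkerServer

def _hello(req: _Request, resp: _Response) -> None:
    # Echo a query parameter back as plain text.
    resp.set_content(f"hello {req.get_param('name')}", "text/plain")
    resp.set_status(200)

_register_handler("hello", _hello)

server = _WorkerServer("::", 0)   # port 0 binds to any free port (see WorkerServer.cpp above)
print("handlers served on port", server.port)
# The handler is then reachable at http://<host>:<port>/handler/hello?name=world,
# which is the URL scheme the debug frontend uses.
server.shutdown()
```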
torch/distributed/debug/__init__.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import multiprocessing
import socket

# import for registration side effect
import torch.distributed.debug._handlers  # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client


__all__ = [
    "start_debug_server",
    "stop_debug_server",
]

_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None


def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _WORKER_SERVER is None, "debug server already started"
    assert _DEBUG_SERVER_PROC is None, "debug server already started"

    store = tcpstore_client()

    _WORKER_SERVER = _WorkerServer("::", worker_port)

    RANK = get_rank()
    store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")

    from torch.distributed.debug._flask import main

    if RANK == 0:
        _DEBUG_SERVER_PROC = multiprocessing.Process(
            target=main, args=(port,), daemon=True
        )
        _DEBUG_SERVER_PROC.start()


def stop_debug_server() -> None:
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _DEBUG_SERVER_PROC is not None
    assert _WORKER_SERVER is not None

    _DEBUG_SERVER_PROC.terminate()
    _WORKER_SERVER.shutdown()
    _DEBUG_SERVER_PROC.join()

    _WORKER_SERVER = None
    _DEBUG_SERVER_PROC = None
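A usage sketch for the new module (assumes torchrun-style environment variables RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are set, e.g. by torchrun; the training-loop placeholder is illustrative):

```python
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server

dist.init_process_group("gloo")   # rank 0 hosts the TCPStore the debug server reuses

start_debug_server(port=25999)    # rank 0 additionally serves the web UI on this port
try:
    ...  # training loop; browse /stacks, /fr_trace, /profile while it runs
finally:
    stop_debug_server()
    dist.destroy_process_group()
```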
265
torch/distributed/debug/_flask.py
Normal file
265
torch/distributed/debug/_flask.py
Normal file
@ -0,0 +1,265 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterator
|
||||
|
||||
import requests
|
||||
from flask import Flask, render_template, request
|
||||
from jinja2 import DictLoader
|
||||
|
||||
from torch.distributed.debug._store import get_world_size, tcpstore_client
|
||||
|
||||
|
||||
def fetch_all(
|
||||
endpoint: str, args: str = ""
|
||||
) -> tuple[list[str], Iterator[requests.Response]]:
|
||||
store = tcpstore_client()
|
||||
keys = [f"rank{r}" for r in range(get_world_size())]
|
||||
addrs = store.multi_get(keys)
|
||||
addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
resps = executor.map(requests.post, addrs)
|
||||
|
||||
return addrs, resps
|
||||
|
||||
|
||||
def format_json(blob: str):
|
||||
parsed = json.loads(blob)
|
||||
return json.dumps(parsed, indent=2)
|
||||
|
||||
|
||||
templates = {
|
||||
"base.html": """
|
||||
<!doctype html>
|
||||
<head>
|
||||
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
|
||||
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
|
||||
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
font-family:
|
||||
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
|
||||
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
|
||||
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
|
||||
font-size: 1rem;
|
||||
font-weight: 400;
|
||||
line-height: 1.5;
|
||||
color: #212529;
|
||||
text-align: left;
|
||||
background-color: #fff;
|
||||
}
|
||||
h1, h2, h2, h4, h5, h6, .h1, .h2, .h2, .h4, .h5, .h6 {
|
||||
margin-bottom: .5rem;
|
||||
font-weight: 500;
|
||||
line-height: 1.2;
|
||||
}
|
||||
nav {
|
||||
background-color: rgba(0, 0, 0, 0.17);
|
||||
padding: 10px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 16px;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
nav h1 {
|
||||
display: inline-block;
|
||||
margin: 0;
|
||||
}
|
||||
nav a {
|
||||
margin: 0 8px;
|
||||
}
|
||||
section {
|
||||
max-width: 1280px;
|
||||
padding: 16px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
pre {
|
||||
white-space: pre-wrap;
|
||||
max-width: 100%;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<nav>
|
||||
<h1>Torch Distributed Debug Server</h1>
|
||||
|
||||
<a href="/">Home</a>
|
||||
<a href="/stacks">Python Stack Traces</a>
|
||||
<a href="/fr_trace">FlightRecorder</a>
|
||||
<a href="/fr_trace_nccl">FlightRecorder NCCL</a>
|
||||
<a href="/profile">torch profiler</a>
|
||||
</nav>
|
||||
|
||||
<section class="content">
|
||||
{% block header %}{% endblock %}
|
||||
{% for message in get_flashed_messages() %}
|
||||
<div class="flash">{{ message }}</div>
|
||||
{% endfor %}
|
||||
{% block content %}{% endblock %}
|
||||
</section>
|
||||
""",
|
||||
"index.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}Index{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
Hi
|
||||
{% endblock %}
|
||||
""",
|
||||
"raw_resp.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}{{title}}{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
|
||||
<h2>Rank {{ i }}: {{ addr }}</h2>
|
||||
{% if resp.status_code != 200 %}
|
||||
<p>Failed to fetch: status={{ resp.status_code }}</p>
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% else %}
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
""",
|
||||
"json_resp.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}{{ title }}{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
|
||||
<h2>Rank {{ i }}: {{ addr }}</h2>
|
||||
{% if resp.status_code != 200 %}
|
||||
<p>Failed to fetch: status={{ resp.status_code }}</p>
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% else %}
|
||||
<pre>{{ format_json(resp.text) }}</pre>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
""",
|
||||
"profile.html": """
|
||||
{% extends "base.html" %}
|
||||
{% block header %}
|
||||
<h1>{% block title %}torch.profiler{% endblock %}</h1>
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<form action="/profile" method="get">
|
||||
<label for="duration">Duration (seconds):</label>
|
||||
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
|
||||
<input type="submit" value="Submit">
|
||||
</form>
|
||||
|
||||
<script>
|
||||
function stringToArrayBuffer(str) {
|
||||
const encoder = new TextEncoder();
|
||||
return encoder.encode(str).buffer;
|
||||
}
|
||||
async function openPerfetto(data) {
|
||||
const ui = window.open('https://ui.perfetto.dev/#!/');
|
||||
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
|
||||
|
||||
// Perfetto readiness handshake: PING until we receive PONG
|
||||
await new Promise((resolve, reject) => {
|
||||
const onMsg = (e) => {
|
||||
if (e.source === ui && e.data === 'PONG') {
|
||||
window.removeEventListener('message', onMsg);
|
||||
clearInterval(pinger);
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
window.addEventListener('message', onMsg);
|
||||
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
|
||||
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
|
||||
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
|
||||
|
||||
ui.postMessage({
|
||||
perfetto: {
|
||||
buffer: stringToArrayBuffer(JSON.stringify(data)),
|
||||
title: "torch profiler",
|
||||
fileName: "trace.json",
|
||||
}
|
||||
}, '*');
|
||||
}
|
||||
</script>
|
||||
|
||||
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
|
||||
<h2>Rank {{ i }}: {{ addr }}</h2>
|
||||
{% if resp.status_code != 200 %}
|
||||
<p>Failed to fetch: status={{ resp.status_code }}</p>
|
||||
<pre>{{ resp.text }}</pre>
|
||||
{% else %}
|
||||
<script>
|
||||
function run{{ i }}() {
|
||||
var data = {{ resp.text | safe }};
|
||||
openPerfetto(data);
|
||||
}
|
||||
</script>
|
||||
|
||||
<button onclick="run{{ i }}()">View {{ i }}</button>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endblock %}
|
||||
""",
|
||||
}


app = Flask(__name__)
app.jinja_loader = DictLoader(templates)
app.jinja_env.globals.update(
    zip=zip,
    format_json=format_json,
    enumerate=enumerate,
)


@app.route("/")
def _index_handler():
    return render_template("index.html")


@app.route("/stacks")
def _stacks_handler():
    addrs, resps = fetch_all("dump_traceback")
    return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps)


@app.route("/fr_trace")
def _fr_trace_handler():
    addrs, resps = fetch_all("fr_trace_json")

    return render_template(
        "json_resp.html",
        title="FlightRecorder",
        addrs=addrs,
        resps=resps,
    )


@app.route("/fr_trace_nccl")
def _fr_trace_nccl_handler():
    addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

    return render_template(
        "json_resp.html",
        title="FlightRecorder NCCL",
        addrs=addrs,
        resps=resps,
    )


@app.route("/profile")
def _profiler_handler():
    duration = request.args.get("duration", default=1.0, type=float)

    addrs, resps = fetch_all("torch_profile", f"duration={duration}")

    return render_template("profile.html", addrs=addrs, resps=resps)


def main(port: int) -> None:
    app.run(host="::", port=port)
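The routes above are thin wrappers around fetch_all(); a minimal sketch of driving them from a client, assuming the frontend was started via main(port=25999). The host and port are illustrative assumptions, not part of the diff.

```python
# Illustrative client for the debug frontend above; host/port are assumptions.
import urllib.request

BASE = "http://localhost:25999"  # wherever main(port=...) was started

for path in ("/stacks", "/fr_trace", "/profile?duration=2"):
    with urllib.request.urlopen(BASE + path) as resp:
        # Each page aggregates the per-rank responses collected via fetch_all().
        print(path, resp.status, len(resp.read()), "bytes")
```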
22
torch/distributed/debug/_handlers.py
Normal file
@ -0,0 +1,22 @@
import tempfile
import time

from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile


def _torch_profile(req: _Request, resp: _Response) -> None:
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    duration = float(req.get_param("duration"))
    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        time.sleep(duration)

    with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
        prof.export_chrome_trace(f.name)
        resp.set_content(open(f.name, "rb").read(), "application/json")
        resp.set_status(200)


_register_handler("torch_profile", _torch_profile)
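The registration call above is a small extension point; here is a sketch of a hypothetical extra handler using the same _Request/_Response surface (get_param, set_content, set_status) shown in this file. The handler name and payload are made up for illustration.

```python
# Hypothetical additional handler, mirroring the _torch_profile pattern above.
import gc
import json

from torch._C._distributed_c10d import _register_handler, _Request, _Response


def _gc_stats(req: _Request, resp: _Response) -> None:
    # Report this rank's garbage-collector generation counts as JSON.
    resp.set_content(json.dumps({"gc_counts": gc.get_count()}).encode(), "application/json")
    resp.set_status(200)


_register_handler("gc_stats", _gc_stats)
```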
24
torch/distributed/debug/_store.py
Normal file
@ -0,0 +1,24 @@
import os

import torch.distributed as dist


def get_rank() -> int:
    return int(os.environ["RANK"])


def get_world_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def tcpstore_client() -> dist.Store:
    MASTER_ADDR = os.environ["MASTER_ADDR"]
    MASTER_PORT = int(os.environ["MASTER_PORT"])

    store = dist.TCPStore(
        host_name=MASTER_ADDR,
        port=MASTER_PORT,
        is_master=False,
    )
    store = dist.PrefixStore("debug_server", store)
    return store
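For orientation, a sketch of how a rank-side process might use tcpstore_client() to publish where its debug handler listens; the key and value formats are assumptions for illustration.

```python
# Illustrative use of tcpstore_client(); the key/value layout is an assumption.
import socket

from torch.distributed.debug._store import get_rank, tcpstore_client

store = tcpstore_client()
store.set(f"rank{get_rank()}", f"{socket.gethostname()}:25999")  # advertise handler address
print(store.get("rank0"))  # another process can look it up by rank key
```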
torch/distributed/pipelining/schedules.py

@ -1485,6 +1485,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    use_full_backward: Optional[bool] = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    # Init parent
    super().__init__(
@ -1517,6 +1518,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
    # This will be set during init of derived schedules
    self.pipeline_order: dict[int, list[Optional[_Action]]] = {}

    # When using a custom backward function, we may or may not need autograd to be used
    # for the backward pass. This flag determines whether or not the torch.is_grad_enabled()
    # check should be performed before the step function.
    self._backward_requires_autograd = backward_requires_autograd

    if use_full_backward is not None:
        logger.warning(
            "Deprecation warning: 'use_full_backward' is no longer supported. "
@ -1609,7 +1615,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
        losses: a list to store the losses for each microbatch.
        return_outputs: whether to return the outputs from the last stage.
    """
    if self._has_backward and not torch.is_grad_enabled():
    if (
        self._has_backward
        and self._backward_requires_autograd
        and not torch.is_grad_enabled()
    ):
        raise RuntimeError(
            "step() requires gradients to be enabled for backward computation; "
            "it should not be used under torch.no_grad() context. "
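Given the new flag, a sketch of how a caller might opt out of the autograd check when the backward pass is driven by a custom function. The stage/model setup is elided and assumed, so this is not runnable as-is.

```python
# Sketch only: `stages` and `inputs` are assumed to be built elsewhere (process
# group, PipelineStage objects, microbatched inputs). Shows the new flag in use.
import torch
from torch.distributed.pipelining import ScheduleInterleaved1F1B

schedule = ScheduleInterleaved1F1B(
    stages,                            # assumed list of PipelineStage objects
    n_microbatches=8,
    loss_fn=torch.nn.functional.cross_entropy,
    backward_requires_autograd=False,  # custom backward; skip the is_grad_enabled() check
)

with torch.no_grad():                  # permitted now that the check is bypassed
    schedule.step(inputs)              # `inputs` assumed
```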
@ -1891,7 +1901,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
    Args:
        computation_type: The computation type for which to register the custom function
        custom_function: The function to execute when this computation type is encountered.
            Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None
            Must have signature: (action: _Action, ctx: _PipelineContext) -> None
    """
    # Ensure that the computation type is valid
    if computation_type not in (
@ -1900,10 +1910,13 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        BACKWARD_INPUT,
        BACKWARD_WEIGHT,
        OVERLAP_F_B,
        UNSHARD,
        RESHARD,
        REDUCE_GRAD,
    ):
        raise ValueError(
            f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
            BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
            BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
        )

    # Check if computation_type is already registered
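A sketch of what a callback under the new (action, ctx) signature could look like. The registration method name and the import paths are assumptions based on the docstring above, and `schedule` is presumed to be a _PipelineScheduleRuntime built elsewhere.

```python
# Sketch only: assumes `schedule` is an existing _PipelineScheduleRuntime and that
# UNSHARD, _Action, and _PipelineContext are importable from the schedules module.
from torch.distributed.pipelining.schedules import UNSHARD, _Action, _PipelineContext


def log_unshard(action: _Action, ctx: _PipelineContext) -> None:
    # Called whenever the runtime executes an UNSHARD action.
    print(f"unshard action: {action}")


schedule.register_custom_function(UNSHARD, log_unshard)  # method name assumed
```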
@ -2296,6 +2309,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
    loss_fn: Optional[Union[Callable, _Loss]] = None,
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    super().__init__(
        stages=stages,
@ -2303,6 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
        loss_fn=loss_fn,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
        backward_requires_autograd=backward_requires_autograd,
    )

    # 1. Create the pipeline_order (all ranks do this calculation)
@ -2510,6 +2525,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    self.pp_group_size = stages[0].group_size
    super().__init__(
@ -2520,6 +2536,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
        kwargs_chunk_spec=kwargs_chunk_spec,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
        backward_requires_autograd=backward_requires_autograd,
    )
    self.n_local_stages = len(stages)
    self.rank = stages[0].group_rank
@ -2622,6 +2639,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    # TODO: we don't support input/weight backward split with torch.compile
    _check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2634,6 +2652,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
        kwargs_chunk_spec=kwargs_chunk_spec,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
        backward_requires_autograd=backward_requires_autograd,
    )
    self.n_local_stages = len(stages)
    self.rank = stages[0].group_rank
@ -2819,6 +2838,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    # TODO: we don't support input/weight backward split with torch.compile
    _check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2831,6 +2851,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
        kwargs_chunk_spec=kwargs_chunk_spec,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
        backward_requires_autograd=backward_requires_autograd,
    )
    self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
        self.pp_group_size, self._num_stages, style="v"
@ -2995,6 +3016,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    # TODO: we don't support input/weight backward split with torch.compile
    _check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -3007,6 +3029,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
        kwargs_chunk_spec=kwargs_chunk_spec,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
        backward_requires_autograd=backward_requires_autograd,
    )
    self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
        self.pp_group_size, self._num_stages, style="v"

torch/testing/_internal/common_distributed.py

@ -1771,6 +1771,22 @@ class MultiProcContinuousTest(TestCase):
    cls._run_test_given_id(test_id)
    completion_queue.put(test_id)
except BaseException as ex:  # noqa: B036
    if isinstance(ex, SystemExit):
        # Get exit code from the process
        exit_code = getattr(ex, "code", None)

        # Look up exit code in TEST_SKIPS to see if it is a valid skip
        skip_entry = next(
            (v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
            None,
        )

        # If we found an entry, skip the test and send the SkipTest object back to the main process
        if skip_entry:
            completion_queue.put(unittest.SkipTest(skip_entry.message))
            # Skip exception handling below; move to the main thread to process the skip
            continue

    raised_exception = True
    # Send the exception and stack trace back to the dispatcher
    exc_info = sys.exc_info()
@ -1892,6 +1908,8 @@ class MultiProcContinuousTest(TestCase):
    # Wait for the workers to finish the test
    for i, completion_queue in enumerate(self.completion_queues):
        rv = completion_queue.get()
        if isinstance(rv, unittest.SkipTest):
            raise rv
        if isinstance(rv, BaseException):
            # Hit an exception, re-raise it in the main process.
            logger.warning(

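For context, the path this handles end to end: a worker test exits with one of the TEST_SKIPS exit codes, the loop above maps it to unittest.SkipTest, and the dispatcher re-raises it. A minimal sketch of the worker-side trigger, with the skip key chosen as an assumption.

```python
# Sketch of the worker-side trigger for the skip handling above.
# The TEST_SKIPS key is an assumption; each entry carries exit_code and message.
import sys

import torch
from torch.testing._internal.common_distributed import TEST_SKIPS

if torch.cuda.device_count() < 4:
    sys.exit(TEST_SKIPS["multi-gpu-4"].exit_code)  # worker exits; dispatcher raises SkipTest
```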
torch/testing/_internal/common_utils.py

@ -114,8 +114,6 @@ class ProfilingMode(Enum):
    PROFILING = 3

# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
@ -959,8 +957,6 @@ def _get_test_report_path():
    return os.path.join('test-reports', test_source)

def parse_cmd_line_args():
    global CI_FUNCTORCH_ROOT
    global CI_PT_ROOT
    global CI_TEST_PREFIX
    global DISABLED_TESTS_FILE
    global GRAPH_EXECUTOR
@ -1039,10 +1035,8 @@ def parse_cmd_line_args():

    set_rng_seed()

    # CI Prefix path used only on CI environment
    CI_TEST_PREFIX = str(Path(os.getcwd()))
    CI_PT_ROOT = str(Path(os.getcwd()).parent)
    CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))

def wait_for_process(p, timeout=None):
    try: