Compare commits


3 Commits

Author SHA1 Message Date
d9bd0f252b distributed/debug: add an HTTP server for debugging running jobs 2025-11-14 12:11:17 -08:00
bfddfde50c Add basic spin config and linting commands (#167226)
This PR adds a basic spin configuration for linting. It is designed as a drop-in replacement for the current Makefile-based solution, i.e. it sets up and updates lintrunner based on the hashes of certain configuration files.

Lintrunner is called via uv's `uvx` command, which keeps its environment separate from the general development environment and reduces the chance of competing requirements breaking either one.
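A rough sketch of that isolation, mirroring the `LINTRUNNER_BASE_CMD` list defined in `.spin/cmds.py` below (assumes `uv` is installed):

```python
import subprocess

# uvx resolves and caches lintrunner@0.12.7 in its own isolated environment,
# so its dependencies never mix with the project's development environment.
cmd = ["uvx", "--python", "3.10", "lintrunner@0.12.7", "init"]
subprocess.run(cmd, check=True)
```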

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167226
Approved by: https://github.com/atalman, https://github.com/albanD
2025-11-14 15:35:42 +00:00
b6570615f8 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-14 15:33:11 +00:00
24 changed files with 1069 additions and 18 deletions

View File

@@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These packages are needed to build FBGEMM and torchrec on PyTorch CI
Flask==3.1.2
#Description: required for torch.distributed.debug

.spin/cmds.py Normal file
View File

@@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path
import click
import spin
def file_digest(file, algorithm: str):
try:
return hashlib.file_digest(file, algorithm)
except AttributeError:
pass # Fallback to manual implementation below
hash = hashlib.new(algorithm)
while chunk := file.read(8192):
hash.update(chunk)
return hash
def _hash_file(file):
with open(file, "rb") as f:
hash = file_digest(f, "sha256")
return hash.hexdigest()
def _hash_files(files):
hashes = {file: _hash_file(file) for file in files}
return hashes
def _read_hashes(hash_file: Path):
if not hash_file.exists():
return {}
with hash_file.open("r") as f:
lines = f.readlines()
hashes = {}
for line in lines:
hash = line[:64]
file = line[66:].strip()
hashes[file] = hash
return hashes
def _updated_hashes(hash_file, files_to_hash):
old_hashes = _read_hashes(hash_file)
new_hashes = _hash_files(files_to_hash)
if new_hashes != old_hashes:
return new_hashes
return None
@click.command()
def regenerate_version():
"""Regenerate version.py."""
cmd = [
sys.executable,
"-m",
"tools.generate_torch_version",
"--is-debug=false",
]
spin.util.run(cmd)
TYPE_STUBS = [
(
"Pytorch type stubs",
Path(".lintbin/.pytorch-type-stubs.sha256"),
[
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"tools/autograd/deprecated.yaml",
],
[
sys.executable,
"-m",
"tools.pyi.gen_pyi",
"--native-functions-path",
"aten/src/ATen/native/native_functions.yaml",
"--tags-path",
"aten/src/ATen/native/tags.yaml",
"--deprecated-functions-path",
"tools/autograd/deprecated.yaml",
],
),
(
"Datapipes type stubs",
None,
[],
[
sys.executable,
"torch/utils/data/datapipes/gen_pyi.py",
],
),
]
@click.command()
def regenerate_type_stubs():
"""Regenerate type stubs."""
for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
if hash_file:
if hashes := _updated_hashes(hash_file, files_to_hash):
click.echo(
f"Changes detected in type stub files for {name}. Regenerating..."
)
spin.util.run(cmd)
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Type stubs and hashes updated.")
else:
click.echo(f"No changes detected in type stub files for {name}.")
else:
click.echo(f"No hash file for {name}. Regenerating...")
spin.util.run(cmd)
click.echo("Type stubs regenerated.")
@click.command()
def regenerate_clangtidy_files():
"""Regenerate clang-tidy files."""
cmd = [
sys.executable,
"-m",
"tools.linter.clang_tidy.generate_build_files",
]
spin.util.run(cmd)
#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
"ATEN_CPU_GPU_AGNOSTIC",
"BAZEL_LINTER",
"C10_NODISCARD",
"C10_UNUSED",
"CALL_ONCE",
"CMAKE_MINIMUM_REQUIRED",
"CONTEXT_DECORATOR",
"COPYRIGHT",
"CUBINCLUDE",
"DEPLOY_DETECTION",
"ERROR_PRONE_ISINSTANCE",
"EXEC",
"HEADER_ONLY_LINTER",
"IMPORT_LINTER",
"INCLUDE",
"LINTRUNNER_VERSION",
"MERGE_CONFLICTLESS_CSV",
"META_NO_CREATE_UNBACKED",
"NEWLINE",
"NOQA",
"NO_WORKFLOWS_ON_FORK",
"ONCE_FLAG",
"PYBIND11_INCLUDE",
"PYBIND11_SPECIALIZATION",
"PYPIDEP",
"PYPROJECT",
"RAWCUDA",
"RAWCUDADEVICE",
"ROOT_LOGGING",
"TABS",
"TESTOWNERS",
"TYPEIGNORE",
"TYPENOSKIP",
"WORKFLOWSYNC",
}
#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
"CMAKE",
"DOCSTRING_LINTER",
"GHA",
"NATIVEFUNCTIONS",
"RUFF",
"SET_LINTER",
"SHELLCHECK",
"SPACES",
}
#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
"ACTIONLINT",
"CLANGFORMAT",
"CLANGTIDY",
"CODESPELL",
"FLAKE8",
"GB_REGISTRY",
"PYFMT",
"PYREFLY",
"TEST_DEVICE_BIAS",
"TEST_HAS_MAIN",
}
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
LINTRUNNER_CACHE_INFO = (
Path(".lintbin/.lintrunner.sha256"),
[
"requirements.txt",
"pyproject.toml",
".lintrunner.toml",
],
)
LINTRUNNER_BASE_CMD = [
"uvx",
"--python",
"3.10",
"lintrunner@0.12.7",
]
@click.command()
def setup_lint():
"""Set up lintrunner with current CI version."""
cmd = LINTRUNNER_BASE_CMD + ["init"]
subprocess.run(cmd, check=True, capture_output=True, text=True)
def _check_linters():
cmd = LINTRUNNER_BASE_CMD + ["list"]
ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
unknown_linters = linters - ALL_LINTERS
missing_linters = ALL_LINTERS - linters
if unknown_linters:
click.secho(
f"Unknown linters found; please add them to the correct category "
f"in .spin/cmds.py: {', '.join(unknown_linters)}",
fg="yellow",
)
if missing_linters:
click.secho(
f"Missing linters found; please update the corresponding category "
f"in .spin/cmds.py: {', '.join(missing_linters)}",
fg="yellow",
)
return unknown_linters, missing_linters
@spin.util.extend_command(
setup_lint,
doc=f"""
If configuration has changed, update lintrunner.
Compares the stored old hashes of configuration files with new ones and
performs setup via setup-lint if the hashes have changed.
Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
""",
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
click.echo(
"Changes detected in lint configuration files. Setting up linting tools..."
)
parent_callback(**kwargs)
hash_file = LINTRUNNER_CACHE_INFO[0]
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Linting tools set up and hashes updated.")
else:
click.echo("No changes detected in lint configuration files. Skipping setup.")
click.echo("Regenerating version...")
ctx.invoke(regenerate_version)
click.echo("Regenerating type stubs...")
ctx.invoke(regenerate_type_stubs)
click.echo("Done.")
_check_linters()
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
"""Lint all files."""
ctx.invoke(lazy_setup_lint)
all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
changed_files_linters = SLOW_LINTERS
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
all_files_cmd = cmd + [
"--take",
",".join(all_files_linters),
"--all-files",
]
spin.util.run(all_files_cmd)
changed_files_cmd = cmd + [
"--take",
",".join(changed_files_linters),
]
spin.util.run(changed_files_cmd)
@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
"""Autofix all files."""
ctx.invoke(lint, apply_patches=True)
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
"""Lint changed files."""
ctx.invoke(lazy_setup_lint)
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
spin.util.run(cmd)
@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
"""Autofix changed files."""
ctx.invoke(quicklint, apply_patches=True)

View File

@@ -376,3 +376,19 @@ keep-runtime-typing = true
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
[tool.spin]
package = 'torch'
[tool.spin.commands]
"Build" = [
".spin/cmds.py:lint",
".spin/cmds.py:fixlint",
".spin/cmds.py:quicklint",
".spin/cmds.py:quickfix",
]
"Regenerate" = [
".spin/cmds.py:regenerate_version",
".spin/cmds.py:regenerate_type_stubs",
".spin/cmds.py:regenerate_clangtidy_files",
]
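With these entries registered, the new commands are expected to be invoked as spin subcommands. A hedged usage sketch — the command names come from `[tool.spin.commands]` above, and `spin` must be installed:

```python
import subprocess

subprocess.run(["spin", "lint"], check=True)       # lint all files
subprocess.run(["spin", "quicklint"], check=True)  # lint changed files only
subprocess.run(["spin", "fixlint"], check=True)    # autofix all files
```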

View File

@@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel

View File

@@ -0,0 +1,53 @@
# Owner(s): ["oncall: distributed"]
import os
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import torch
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase
session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
class TestDebug(TestCase):
def test_basics(self) -> None:
store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
port = 25999
def fetch(path: str) -> str:
resp = session.get(f"http://localhost:{port}{path}")
resp.raise_for_status()
return resp.text
print("starting!")
start_debug_server(port=port)
self.assertIn("torch profiler", fetch("/"))
self.assertIn("View 0", fetch("/profile?duration=0.01"))
self.assertIn("test_basics", fetch("/stacks"))
self.assertIn("pg_status", fetch("/fr_trace"))
if torch.cuda.is_available():
self.assertIn("pg_status", fetch("/fr_trace_nccl"))
stop_debug_server()
if __name__ == "__main__":
run_tests()

View File

@@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@@ -599,6 +604,92 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
@@ -206,6 +208,7 @@ class Store:
desired_value: str,
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def multi_get(self, keys: list[str]) -> list[bytes]: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
@@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):
def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...
class _Request:
def body(self) -> bytes: ...
def get_param(self, key: str) -> str: ...
class _Response:
def set_content(self, content: str | bytes, content_type: str) -> None: ...
def set_status(self, status: int) -> None: ...
def _register_handler(
name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...
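Together with the `torch_profile` handler added later in this diff, these stubs suggest the following registration pattern; the handler name and body here are hypothetical:

```python
from torch._C._distributed_c10d import _register_handler, _Request, _Response

def _echo(req: _Request, resp: _Response) -> None:
    name = req.get_param("name")  # returns "" when the query parameter is absent
    resp.set_content(f"hello {name}", "text/plain")
    resp.set_status(200)

_register_handler("echo", _echo)  # served at /handler/echo on the worker server
```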

View File

@@ -60,6 +60,7 @@ class _ExperimentalConfig:
verbose: bool = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
profile_all_threads: bool = ...,
) -> None: ...
class ProfilerConfig:

View File

@@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
compiler_name = "aotinductor"
def __init__(self, mode, options, dynamic):
super().__init__(mode, options, dynamic)
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
def __call__(self, model_, inputs_):
from contextlib import nullcontext
from unittest import mock
from torch._guards import detect_fake_mode
from torch._inductor.virtualized import V
fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
if fake_mode
else nullcontext()
)
with (
V.set_aot_compilation(True),
ctx,
torch._inductor.config.patch("enable_autograd_for_aot", True),
):
return super().__call__(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@@ -2672,8 +2701,10 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@@ -2700,7 +2731,10 @@
return export_wrapped_fn
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
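A usage sketch distilled from `test_aot_compile_with_aoti_torch_compile` above (the test runs under CUDA):

```python
import torch

def fn(x, y):
    return x + y

# use_aoti routes the inductor backend through _TorchCompileAOTInductorWrapper;
# aot_compile then produces a serializable AOTInductor-backed callable.
compiled_fn = torch.compile(fn, fullgraph=True, options={"use_aoti": True}).aot_compile(
    ((torch.randn(3, 4), torch.randn(3, 4)), {})
)
```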

View File

@@ -53,6 +53,7 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@@ -166,7 +167,8 @@ class AOTCompiledFunction:
state = pickle.loads(data)
state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
deserializer, compiled_fn_state = state["compiled_fn"]
state["compiled_fn"] = deserializer(compiled_fn_state)
with torch._inductor.config.patch(enable_autograd_for_aot=True):
state["compiled_fn"] = deserializer(compiled_fn_state)
state["original_code"] = SerializedCode.to_code_object(state["original_code"])
artifacts = CompileArtifacts(**state)
@@ -273,6 +275,7 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@@ -2713,7 +2715,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation:
if V.aot_compilation and not config.enable_autograd_for_aot:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()

View File

@@ -1190,6 +1190,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@@ -773,9 +773,86 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
if config.aot_inductor.cross_target_platform == "windows":
return
if config.aot_inductor.package_cpp_only:
return
if not config.enable_autograd_for_aot:
return
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
def post_compile(
self,
@@ -783,10 +860,8 @@
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
pass
def prepare_for_serialization(self) -> None:
pass
if self.current_callable is None:
self.__post_init__()
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {

View File

@@ -18,6 +18,14 @@ class TORCH_API Request {
virtual const std::string& body() const = 0;
virtual const std::multimap<std::string, std::string>& params() const = 0;
std::string getParam(const std::string& key) const {
auto it = params().find(key);
if (it != params().end()) {
return it->second;
}
return "";
}
};
// Response represents a response to the handler. This conceptually maps to an

View File

@@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {

View File

@@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane

View File

@@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, const py::function& handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
// Default constructor.
.def(py::init<>())
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
.def("get_param", &::c10d::control_plane::Request::getParam);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(

View File

@@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@@ -0,0 +1,59 @@
import logging
import multiprocessing
import socket
# import for registration side effect
import torch.distributed.debug._handlers # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client
__all__ = [
"start_debug_server",
"stop_debug_server",
]
logger: logging.Logger = logging.getLogger(__name__)
_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None
def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _WORKER_SERVER is None, "debug server already started"
assert _DEBUG_SERVER_PROC is None, "debug server already started"
logger.info("Starting debug server on port %d", port)
store = tcpstore_client()
_WORKER_SERVER = _WorkerServer("::", worker_port)
RANK = get_rank()
store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")
from torch.distributed.debug._flask import main
if RANK == 0:
_DEBUG_SERVER_PROC = multiprocessing.Process(
target=main, args=(port,), daemon=True
)
_DEBUG_SERVER_PROC.start()
def stop_debug_server() -> None:
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _DEBUG_SERVER_PROC is not None
assert _WORKER_SERVER is not None
logger.info("Stopping debug server")
_DEBUG_SERVER_PROC.terminate()
_WORKER_SERVER.shutdown()
_DEBUG_SERVER_PROC.join()
_WORKER_SERVER = None
_DEBUG_SERVER_PROC = None
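A hedged usage sketch based on the test added in this PR; the standard torch.distributed environment variables must already be set (torchrun sets them automatically):

```python
from torch.distributed.debug import start_debug_server, stop_debug_server

# Requires MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE in the environment.
start_debug_server(port=25999)  # rank 0 also spawns the Flask UI on this port
# ... run the job; browse http://<rank-0 host>:25999 for stacks, FR traces, profiles
stop_debug_server()
```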

View File

@@ -0,0 +1,265 @@
import json
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
import requests
from flask import Flask, render_template, request
from jinja2 import DictLoader
from torch.distributed.debug._store import get_world_size, tcpstore_client
def fetch_all(
endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
store = tcpstore_client()
keys = [f"rank{r}" for r in range(get_world_size())]
addrs = store.multi_get(keys)
addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
with ThreadPoolExecutor(max_workers=10) as executor:
resps = executor.map(requests.post, addrs)
return addrs, resps
def format_json(blob: str):
parsed = json.loads(blob)
return json.dumps(parsed, indent=2)
templates = {
"base.html": """
<!doctype html>
<head>
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
<style>
body {
margin: 0;
font-family:
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
font-size: 1rem;
font-weight: 400;
line-height: 1.5;
color: #212529;
text-align: left;
background-color: #fff;
}
h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
margin-bottom: .5rem;
font-weight: 500;
line-height: 1.2;
}
nav {
background-color: rgba(0, 0, 0, 0.17);
padding: 10px;
display: flex;
align-items: center;
padding: 16px;
justify-content: flex-start;
}
nav h1 {
display: inline-block;
margin: 0;
}
nav a {
margin: 0 8px;
}
section {
max-width: 1280px;
padding: 16px;
margin: 0 auto;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
</head>
<nav>
<h1>Torch Distributed Debug Server</h1>
<a href="/">Home</a> <!--@lint-ignore-->
<a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
<a href="/fr_trace">FlightRecorder</a> <!--@lint-ignore-->
<a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
<a href="/profile">torch profiler</a> <!--@lint-ignore-->
</nav>
<section class="content">
{% block header %}{% endblock %}
{% for message in get_flashed_messages() %}
<div class="flash">{{ message }}</div>
{% endfor %}
{% block content %}{% endblock %}
</section>
""",
"index.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
Hi
{% endblock %}
""",
"raw_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ resp.text }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"json_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ format_json(resp.text) }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"profile.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}
{% block content %}
<form action="/profile" method="get">
<label for="duration">Duration (seconds):</label>
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
<input type="submit" value="Submit">
</form>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<script>
function run{{ i }}() {
var data = {{ resp.text | safe }};
openPerfetto(data);
}
</script>
<button onclick="run{{ i }}()">View {{ i }}</button>
{% endif %}
{% endfor %}
{% endblock %}
""",
}
app = Flask(__name__)
app.jinja_loader = DictLoader(templates)
app.jinja_env.globals.update(
zip=zip,
format_json=format_json,
enumerate=enumerate,
)
@app.route("/")
def _index_handler():
return render_template("index.html")
@app.route("/stacks")
def _stacks_handler():
addrs, resps = fetch_all("dump_traceback")
return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps)
@app.route("/fr_trace")
def _fr_trace_handler():
addrs, resps = fetch_all("fr_trace_json")
return render_template(
"json_resp.html",
title="FlightRecorder",
addrs=addrs,
resps=resps,
)
@app.route("/fr_trace_nccl")
def _fr_trace_nccl_handler():
addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
return render_template(
"json_resp.html",
title="FlightRecorder NCCL",
addrs=addrs,
resps=resps,
)
@app.route("/profile")
def _profiler_handler():
duration = request.args.get("duration", default=1.0, type=float)
addrs, resps = fetch_all("torch_profile", f"duration={duration}")
return render_template("profile.html", addrs=addrs, resps=resps)
def main(port: int) -> None:
app.run(host="::", port=port)

View File

@@ -0,0 +1,22 @@
import tempfile
import time
from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile
def _torch_profile(req: _Request, resp: _Response) -> None:
experimental_config = _ExperimentalConfig(
profile_all_threads=True,
)
duration = float(req.get_param("duration"))
with profile(record_shapes=True, experimental_config=experimental_config) as prof:
time.sleep(duration)
with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
prof.export_chrome_trace(f.name)
resp.set_content(open(f.name, "rb").read(), "application/json")
resp.set_status(200)
_register_handler("torch_profile", _torch_profile)

View File

@@ -0,0 +1,24 @@
import os
import torch.distributed as dist
def get_rank() -> int:
return int(os.environ["RANK"])
def get_world_size() -> int:
return int(os.environ["WORLD_SIZE"])
def tcpstore_client() -> dist.Store:
MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
store = dist.TCPStore(
host_name=MASTER_ADDR,
port=MASTER_PORT,
is_master=False,
)
store = dist.PrefixStore("debug_server", store)
return store