Compare commits

...

5 Commits

Author SHA1 Message Date
d9bd0f252b distributed/debug: add an HTTP server for debugging running jobs 2025-11-14 12:11:17 -08:00
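The diff below adds a `torch.distributed.debug` module backed by the C++ `_WorkerServer` and a small Flask frontend. A minimal usage sketch, modeled on the new test in this diff (single rank; the env-var setup and the port number follow that test and are assumptions rather than requirements):

```
# Minimal sketch, mirroring test/distributed/test_debug.py from this diff.
import os
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server

# The debug helpers read rank/world size and the TCPStore address from the
# usual env vars; a single-process run can set them up by hand like the test does.
store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

start_debug_server(port=25999)  # rank 0 also serves the aggregating Flask UI on this port
# ... run the job; browse http://localhost:25999/ for stacks, FR traces, and profiles
stop_debug_server()
```

Every rank starts a `_WorkerServer` and publishes its address through the TCPStore; only rank 0 spawns the Flask process that aggregates the per-rank handlers behind `/stacks`, `/fr_trace`, `/fr_trace_nccl`, and `/profile`.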
bfddfde50c Add basic spin config and linting commands (#167226)
This PR adds a basic spin configuration for linting. It is designed as a drop-in replacement for the current Makefile-based solution, i.e. it sets up and updates lintrunner based on the hashes of certain configuration files.

Lintrunner is invoked via uv's `uvx` command, which keeps its environment separate from the general development environment so that competing requirements are less likely to break either one.
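The caching works by hashing the lint configuration files and re-running `lintrunner init` (through `uvx`) only when those hashes change. A condensed sketch of that flow, not the verbatim implementation; the real logic lives in the `lazy_setup_lint` command of `.spin/cmds.py`, added below:

```
# Condensed sketch of the hash-gated setup in .spin/cmds.py (not verbatim).
import hashlib
import subprocess
from pathlib import Path

CONFIG_FILES = ["requirements.txt", "pyproject.toml", ".lintrunner.toml"]
HASH_FILE = Path(".lintbin/.lintrunner.sha256")

def current_hashes() -> dict[str, str]:
    return {f: hashlib.sha256(Path(f).read_bytes()).hexdigest() for f in CONFIG_FILES}

def stored_hashes() -> dict[str, str]:
    if not HASH_FILE.exists():
        return {}
    return {line.split()[1]: line.split()[0] for line in HASH_FILE.read_text().splitlines()}

new = current_hashes()
if new != stored_hashes():
    # uvx keeps lintrunner's environment separate from the dev environment.
    subprocess.run(["uvx", "--python", "3.10", "lintrunner@0.12.7", "init"], check=True)
    HASH_FILE.parent.mkdir(parents=True, exist_ok=True)
    HASH_FILE.write_text("".join(f"{h}  {f}\n" for f, h in new.items()))
```

With the `[tool.spin.commands]` entries added to pyproject.toml below, this check runs implicitly as part of `spin lint`, `spin quicklint`, `spin fixlint`, and `spin quickfix`.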

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167226
Approved by: https://github.com/atalman, https://github.com/albanD
2025-11-14 15:35:42 +00:00
b6570615f8 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER
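For orientation, the change adds a `use_aoti` option to `torch.compile(options=...)` that swaps the inductor backend for the new `_TorchCompileAOTInductorWrapper` and records a `backend_name` in the precompile artifacts. A minimal sketch based on `test_aot_compile_with_aoti_torch_compile` added below (CUDA assumed; the function, shapes, and save path are illustrative):

```
# Sketch based on the new dynamo test in this diff; requires CUDA.
import torch

def fn(x, y):
    return x + y

with torch.device("cuda"):
    inputs = (torch.randn(3, 4), torch.randn(3, 4))
    # options={"use_aoti": True} selects the AOTInductor-backed wrapper.
    compiled_fn = torch.compile(
        fn, fullgraph=True, options={"use_aoti": True}
    ).aot_compile((inputs, {}))
    out = compiled_fn(*inputs)

    # The compiled function can be saved and reloaded; the artifacts record
    # backend_name == "aotinductor". The path here is arbitrary.
    compiled_fn.save_compiled_function("/tmp/fn_aoti.pt")
    with open("/tmp/fn_aoti.pt", "rb") as f:
        reloaded = torch.compiler.load_compiled_function(f)
    assert torch.equal(out, reloaded(*inputs))
```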

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-14 15:33:11 +00:00
226850cc66 [ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)
This PR adds an sm_121a flag for row-wise scaled matmuls on DGX Spark.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734
Approved by: https://github.com/eqy, https://github.com/cyyever
2025-11-14 08:44:04 +00:00
f8a2ce3b9a Fix inplace ops on Partial DTensors to preserve aliasing semantics (#164729)
Fixes #163374.

Here is the output from reproducible code:

```
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811]
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
  aten::clamp_(dt: f32[][R], None, 2)
    redistribute_input(0, [P] -> [R])
      redistribute_input(t: f32[], [P] -> [R])
        _c10d_functional::all_reduce(t: f32[], sum, 0)
        _c10d_functional::wait_tensor(t: f32[])
    aten::clamp_(t: f32[], None, 2)
    aten::view(t: f32[], [])
(Replicate(),)
tensor(2., device='cuda:0')
```

The behavior now matches what was expected in issue #163374:

Expected behavior (from the issue):
1. Placement should change from Partial(sum) to Replicate()
2. Value should be tensor(2.) instead of tensor(144.)

Actual output from this build:
1. (Replicate(),) - placement is correct
2. tensor(2., device='cuda:0') - value is correct

So the in-place operation now properly redistributes the partial DTensor to replicate before performing the clamp and maintains the correct aliasing semantics. It also produces the expected clamped value.
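For reference, the new unit test in this diff also pins down the guard path: an in-place op whose output placement differs from the input's `Partial` placement raises instead of silently redistributing. A minimal sketch mirroring that test (assumes a single-rank distributed context with CUDA, e.g. launched with torchrun):

```
# Sketch mirroring test_inplace_op_partial_to_replicate from this diff.
# Assumes a 1-rank distributed run (e.g. torchrun --nproc-per-node=1) with CUDA.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial

mesh = init_device_mesh("cuda", (1,))
partial_dt = DTensor.from_local(
    torch.tensor(64.0, device="cuda"), mesh, placements=(Partial(),)
)

try:
    partial_dt.clamp_(max=10)  # would require Partial -> Replicate redistribution
except RuntimeError as e:
    print(e)  # "... in-place operations that require placement changes are not supported ..."
```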

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164729
Approved by: https://github.com/ezyang
2025-11-14 07:46:35 +00:00
27 changed files with 1112 additions and 22 deletions

View File

@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
Flask==3.1.2
#Description: required for torch.distributed.debug

.spin/cmds.py (new file, 330 lines)
View File

@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path
import click
import spin
def file_digest(file, algorithm: str):
try:
return hashlib.file_digest(file, algorithm)
except AttributeError:
pass # Fallback to manual implementation below
hash = hashlib.new(algorithm)
while chunk := file.read(8192):
hash.update(chunk)
return hash
def _hash_file(file):
with open(file, "rb") as f:
hash = file_digest(f, "sha256")
return hash.hexdigest()
def _hash_files(files):
hashes = {file: _hash_file(file) for file in files}
return hashes
def _read_hashes(hash_file: Path):
if not hash_file.exists():
return {}
with hash_file.open("r") as f:
lines = f.readlines()
hashes = {}
for line in lines:
hash = line[:64]
file = line[66:].strip()
hashes[file] = hash
return hashes
def _updated_hashes(hash_file, files_to_hash):
old_hashes = _read_hashes(hash_file)
new_hashes = _hash_files(files_to_hash)
if new_hashes != old_hashes:
return new_hashes
return None
@click.command()
def regenerate_version():
"""Regenerate version.py."""
cmd = [
sys.executable,
"-m",
"tools.generate_torch_version",
"--is-debug=false",
]
spin.util.run(cmd)
TYPE_STUBS = [
(
"Pytorch type stubs",
Path(".lintbin/.pytorch-type-stubs.sha256"),
[
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"tools/autograd/deprecated.yaml",
],
[
sys.executable,
"-m",
"tools.pyi.gen_pyi",
"--native-functions-path",
"aten/src/ATen/native/native_functions.yaml",
"--tags-path",
"aten/src/ATen/native/tags.yaml",
"--deprecated-functions-path",
"tools/autograd/deprecated.yaml",
],
),
(
"Datapipes type stubs",
None,
[],
[
sys.executable,
"torch/utils/data/datapipes/gen_pyi.py",
],
),
]
@click.command()
def regenerate_type_stubs():
"""Regenerate type stubs."""
for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
if hash_file:
if hashes := _updated_hashes(hash_file, files_to_hash):
click.echo(
f"Changes detected in type stub files for {name}. Regenerating..."
)
spin.util.run(cmd)
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Type stubs and hashes updated.")
else:
click.echo(f"No changes detected in type stub files for {name}.")
else:
click.echo(f"No hash file for {name}. Regenerating...")
spin.util.run(cmd)
click.echo("Type stubs regenerated.")
@click.command()
def regenerate_clangtidy_files():
"""Regenerate clang-tidy files."""
cmd = [
sys.executable,
"-m",
"tools.linter.clang_tidy.generate_build_files",
]
spin.util.run(cmd)
#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
"ATEN_CPU_GPU_AGNOSTIC",
"BAZEL_LINTER",
"C10_NODISCARD",
"C10_UNUSED",
"CALL_ONCE",
"CMAKE_MINIMUM_REQUIRED",
"CONTEXT_DECORATOR",
"COPYRIGHT",
"CUBINCLUDE",
"DEPLOY_DETECTION",
"ERROR_PRONE_ISINSTANCE",
"EXEC",
"HEADER_ONLY_LINTER",
"IMPORT_LINTER",
"INCLUDE",
"LINTRUNNER_VERSION",
"MERGE_CONFLICTLESS_CSV",
"META_NO_CREATE_UNBACKED",
"NEWLINE",
"NOQA",
"NO_WORKFLOWS_ON_FORK",
"ONCE_FLAG",
"PYBIND11_INCLUDE",
"PYBIND11_SPECIALIZATION",
"PYPIDEP",
"PYPROJECT",
"RAWCUDA",
"RAWCUDADEVICE",
"ROOT_LOGGING",
"TABS",
"TESTOWNERS",
"TYPEIGNORE",
"TYPENOSKIP",
"WORKFLOWSYNC",
}
#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
"CMAKE",
"DOCSTRING_LINTER",
"GHA",
"NATIVEFUNCTIONS",
"RUFF",
"SET_LINTER",
"SHELLCHECK",
"SPACES",
}
#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
"ACTIONLINT",
"CLANGFORMAT",
"CLANGTIDY",
"CODESPELL",
"FLAKE8",
"GB_REGISTRY",
"PYFMT",
"PYREFLY",
"TEST_DEVICE_BIAS",
"TEST_HAS_MAIN",
}
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
LINTRUNNER_CACHE_INFO = (
Path(".lintbin/.lintrunner.sha256"),
[
"requirements.txt",
"pyproject.toml",
".lintrunner.toml",
],
)
LINTRUNNER_BASE_CMD = [
"uvx",
"--python",
"3.10",
"lintrunner@0.12.7",
]
@click.command()
def setup_lint():
"""Set up lintrunner with current CI version."""
cmd = LINTRUNNER_BASE_CMD + ["init"]
subprocess.run(cmd, check=True, capture_output=True, text=True)
def _check_linters():
cmd = LINTRUNNER_BASE_CMD + ["list"]
ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
unknown_linters = linters - ALL_LINTERS
missing_linters = ALL_LINTERS - linters
if unknown_linters:
click.secho(
f"Unknown linters found; please add them to the correct category "
f"in .spin/cmds.py: {', '.join(unknown_linters)}",
fg="yellow",
)
if missing_linters:
click.secho(
f"Missing linters found; please update the corresponding category "
f"in .spin/cmds.py: {', '.join(missing_linters)}",
fg="yellow",
)
return unknown_linters, missing_linters
@spin.util.extend_command(
setup_lint,
doc=f"""
If configuration has changed, update lintrunner.
Compares the stored old hashes of configuration files with new ones and
performs setup via setup-lint if the hashes have changed.
Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
""",
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
click.echo(
"Changes detected in lint configuration files. Setting up linting tools..."
)
parent_callback(**kwargs)
hash_file = LINTRUNNER_CACHE_INFO[0]
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Linting tools set up and hashes updated.")
else:
click.echo("No changes detected in lint configuration files. Skipping setup.")
click.echo("Regenerating version...")
ctx.invoke(regenerate_version)
click.echo("Regenerating type stubs...")
ctx.invoke(regenerate_type_stubs)
click.echo("Done.")
_check_linters()
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
"""Lint all files."""
ctx.invoke(lazy_setup_lint)
all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
changed_files_linters = SLOW_LINTERS
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
all_files_cmd = cmd + [
"--take",
",".join(all_files_linters),
"--all-files",
]
spin.util.run(all_files_cmd)
changed_files_cmd = cmd + [
"--take",
",".join(changed_files_linters),
]
spin.util.run(changed_files_cmd)
@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
"""Autofix all files."""
ctx.invoke(lint, apply_patches=True)
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
"""Lint changed files."""
ctx.invoke(lazy_setup_lint)
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
spin.util.run(cmd)
@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
"""Autofix changed files."""
ctx.invoke(quicklint, apply_patches=True)

View File

@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
if("${_arch}" STREQUAL "121a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a")
"89;90a;100a;103a;120a;121a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")

View File

@ -376,3 +376,19 @@ keep-runtime-typing = true
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
[tool.spin]
package = 'torch'
[tool.spin.commands]
"Build" = [
".spin/cmds.py:lint",
".spin/cmds.py:fixlint",
".spin/cmds.py:quicklint",
".spin/cmds.py:quickfix",
]
"Regenerate" = [
".spin/cmds.py:regenerate_version",
".spin/cmds.py:regenerate_type_stubs",
".spin/cmds.py:regenerate_clangtidy_files",
]

View File

@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel

View File

@ -331,6 +331,25 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
self.assertEqual(z.placements, (Replicate(),))
self.assertEqual(z.to_local(), input)
def test_inplace_op_partial_to_replicate(self):
# test that in-place operations that require redistribution raise an error
# to preserve aliasing semantics (issue #163374)
device_mesh = self.build_device_mesh()
input_tensor = torch.tensor(64.0, device=self.device_type)
partial_dt = DTensor.from_local(
input_tensor, device_mesh, placements=(Partial(),)
)
self.assertTrue(partial_dt.placements[0].is_partial())
# Inplace ops that require placement changes (Partial -> Replicate) should error
with self.assertRaisesRegex(
RuntimeError,
"in-place operations that require placement changes are not supported",
):
partial_dt.clamp_(max=10)
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,53 @@
# Owner(s): ["oncall: distributed"]
import os
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import torch
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase
session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
class TestDebug(TestCase):
def test_basics(self) -> None:
store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
port = 25999
def fetch(path: str) -> str:
resp = session.get(f"http://localhost:{port}{path}")
resp.raise_for_status()
return resp.text
print("starting!")
start_debug_server(port=port)
self.assertIn("torch profiler", fetch("/"))
self.assertIn("View 0", fetch("/profile?duration=0.01"))
self.assertIn("test_basics", fetch("/stacks"))
self.assertIn("pg_status", fetch("/fr_trace"))
if torch.cuda.is_available():
self.assertIn("pg_status", fetch("/fr_trace_nccl"))
stop_debug_server()
if __name__ == "__main__":
run_tests()

View File

@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@ -599,6 +604,92 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -100,7 +100,9 @@ class Logger:
def _set_static_graph(self) -> None: ...
class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
port: int
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
@ -206,6 +208,7 @@ class Store:
desired_value: str,
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def multi_get(self, keys: list[str]) -> list[bytes]: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):
def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...
class _Request:
def body(self) -> bytes: ...
def get_param(self, str) -> str: ...
class _Response:
def set_content(self, content: str | bytes, content_type: str) -> None: ...
def set_status(self, status: int) -> None: ...
def _register_handler(
name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...

View File

@ -60,6 +60,7 @@ class _ExperimentalConfig:
verbose: bool = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
profile_all_threads: bool = ...,
) -> None: ...
class ProfilerConfig:

View File

@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
compiler_name = "aotinductor"
def __init__(self, mode, options, dynamic):
super().__init__(mode, options, dynamic)
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
def __call__(self, model_, inputs_):
from contextlib import nullcontext
from unittest import mock
from torch._guards import detect_fake_mode
from torch._inductor.virtualized import V
fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
if fake_mode
else nullcontext()
)
with (
V.set_aot_compilation(True),
ctx,
torch._inductor.config.patch("enable_autograd_for_aot", True),
):
return super().__call__(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@ -2672,8 +2701,10 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@ -2700,7 +2731,10 @@ def compile(
return export_wrapped_fn
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)

View File

@ -53,6 +53,7 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@ -166,7 +167,8 @@ class AOTCompiledFunction:
state = pickle.loads(data)
state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
deserializer, compiled_fn_state = state["compiled_fn"]
state["compiled_fn"] = deserializer(compiled_fn_state)
with torch._inductor.config.patch(enable_autograd_for_aot=True):
state["compiled_fn"] = deserializer(compiled_fn_state)
state["original_code"] = SerializedCode.to_code_object(state["original_code"])
artifacts = CompileArtifacts(**state)
@ -273,6 +275,7 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@ -2713,7 +2715,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation:
if V.aot_compilation and not config.enable_autograd_for_aot:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()

View File

@ -1190,6 +1190,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@ -773,9 +773,86 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
if config.aot_inductor.cross_target_platform == "windows":
return
if config.aot_inductor.package_cpp_only:
return
if not config.enable_autograd_for_aot:
return
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
def post_compile(
self,
@ -783,10 +860,8 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
pass
def prepare_for_serialization(self) -> None:
pass
if self.current_callable is None:
self.__post_init__()
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
res.setStatus(200);
}};
RegisterHandler frTracehandler(
"fr_trace_json",
[](const Request&, Response& res) {
auto trace = ::c10d::dump_fr_trace_json(true, true);
res.setContent(std::move(trace), "application/json");
res.setStatus(200);
});
} // namespace
void registerHandler(const std::string& name, HandlerFunc f) {

View File

@ -18,6 +18,14 @@ class TORCH_API Request {
virtual const std::string& body() const = 0;
virtual const std::multimap<std::string, std::string>& params() const = 0;
std::string getParam(const std::string& key) const {
auto it = params().find(key);
if (it != params().end()) {
return it->second;
}
return "";
}
};
// Response represents a response to the handler. This conceptually maps to an

View File

@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
TORCH_CHECK(
server_.bind_to_port(hostOrFile, 80),
fmt::format("Error binding to {}", hostOrFile));
} else if (port == 0) {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
port_ = server_.bind_to_any_port(hostOrFile);
TORCH_CHECK(
port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
} else {
C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
TORCH_CHECK(
server_.bind_to_port(hostOrFile, port),
fmt::format("Error binding to {}:{}", hostOrFile, port));
port_ = port;
}
serverThread_ = std::thread([this]() {

View File

@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {
void shutdown();
int port() {
return port_;
}
private:
httplib::Server server_;
std::thread serverThread_;
int port_;
};
} // namespace c10d::control_plane

View File

@ -46,6 +46,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
}),
py::arg("host_or_file"),
py::arg("port") = -1)
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
.def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
.def_property_readonly(
"port", &::c10d::control_plane::WorkerServer::port);
module.def(
"_get_handler",
@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
Returns the handler with the specified name.
)");
module.def(
"_register_handler",
[](const std::string& name, const py::function& handler) {
::c10d::control_plane::registerHandler(
name,
[handler](
const ::c10d::control_plane::Request& req,
::c10d::control_plane::Response& res) {
py::gil_scoped_acquire acquire;
handler(std::ref(req), std::ref(res));
});
},
py::arg("name"),
py::arg("handler"),
R"(
Registers a handler by name.
)");
module.def(
"_get_handler_names",
&::c10d::control_plane::getHandlerNames,
@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
// Default constructor.
.def(py::init<>())
.def("body", &::c10d::control_plane::Request::body)
.def("params", &::c10d::control_plane::Request::params);
.def("get_param", &::c10d::control_plane::Request::getParam);
py::class_<
::c10d::control_plane::Response,
std::shared_ptr<::c10d::control_plane::Response>,
PythonResponse>(
py::class_<::c10d::control_plane::Response, PythonResponse>(
module,
"_Response",
R"(

View File

@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@ -0,0 +1,59 @@
import logging
import multiprocessing
import socket
# import for registration side effect
import torch.distributed.debug._handlers # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client
__all__ = [
"start_debug_server",
"stop_debug_server",
]
logger: logging.Logger = logging.getLogger(__name__)
_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None
def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _WORKER_SERVER is None, "debug server already started"
assert _DEBUG_SERVER_PROC is None, "debug server already started"
logger.info("Starting debug server on port %d", port)
store = tcpstore_client()
_WORKER_SERVER = _WorkerServer("::", worker_port)
RANK = get_rank()
store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")
from torch.distributed.debug._flask import main
if RANK == 0:
_DEBUG_SERVER_PROC = multiprocessing.Process(
target=main, args=(port,), daemon=True
)
_DEBUG_SERVER_PROC.start()
def stop_debug_server() -> None:
global _WORKER_SERVER, _DEBUG_SERVER_PROC
assert _DEBUG_SERVER_PROC is not None
assert _WORKER_SERVER is not None
logger.info("Stopping debug server")
_DEBUG_SERVER_PROC.terminate()
_WORKER_SERVER.shutdown()
_DEBUG_SERVER_PROC.join()
_WORKER_SERVER = None
_DEBUG_SERVER_PROC = None

View File

@ -0,0 +1,265 @@
import json
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
import requests
from flask import Flask, render_template, request
from jinja2 import DictLoader
from torch.distributed.debug._store import get_world_size, tcpstore_client
def fetch_all(
endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
store = tcpstore_client()
keys = [f"rank{r}" for r in range(get_world_size())]
addrs = store.multi_get(keys)
addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]
with ThreadPoolExecutor(max_workers=10) as executor:
resps = executor.map(requests.post, addrs)
return addrs, resps
def format_json(blob: str):
parsed = json.loads(blob)
return json.dumps(parsed, indent=2)
templates = {
"base.html": """
<!doctype html>
<head>
<title>{% block title %}{% endblock %} - PyTorch Distributed</title>
<link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">
<style>
body {
margin: 0;
font-family:
-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
"Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
"Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
font-size: 1rem;
font-weight: 400;
line-height: 1.5;
color: #212529;
text-align: left;
background-color: #fff;
}
h1, h2, h2, h4, h5, h6, .h1, .h2, .h2, .h4, .h5, .h6 {
margin-bottom: .5rem;
font-weight: 500;
line-height: 1.2;
}
nav {
background-color: rgba(0, 0, 0, 0.17);
padding: 10px;
display: flex;
align-items: center;
padding: 16px;
justify-content: flex-start;
}
nav h1 {
display: inline-block;
margin: 0;
}
nav a {
margin: 0 8px;
}
section {
max-width: 1280px;
padding: 16px;
margin: 0 auto;
}
pre {
white-space: pre-wrap;
max-width: 100%;
}
</style>
</head>
<nav>
<h1>Torch Distributed Debug Server</h1>
<a href="/">Home</a> <!--@lint-ignore-->
<a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
<a href="/fr_trace">FlightRecorder</a> <!--@lint-ignore-->
<a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
<a href="/profile">torch profiler</a> <!--@lint-ignore-->
</nav>
<section class="content">
{% block header %}{% endblock %}
{% for message in get_flashed_messages() %}
<div class="flash">{{ message }}</div>
{% endfor %}
{% block content %}{% endblock %}
</section>
""",
"index.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
Hi
{% endblock %}
""",
"raw_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ resp.text }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"json_resp.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<pre>{{ format_json(resp.text) }}</pre>
{% endif %}
{% endfor %}
{% endblock %}
""",
"profile.html": """
{% extends "base.html" %}
{% block header %}
<h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}
{% block content %}
<form action="/profile" method="get">
<label for="duration">Duration (seconds):</label>
<input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
<input type="submit" value="Submit">
</form>
<script>
function stringToArrayBuffer(str) {
const encoder = new TextEncoder();
return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
const ui = window.open('https://ui.perfetto.dev/#!/');
if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }
// Perfetto readiness handshake: PING until we receive PONG
await new Promise((resolve, reject) => {
const onMsg = (e) => {
if (e.source === ui && e.data === 'PONG') {
window.removeEventListener('message', onMsg);
clearInterval(pinger);
resolve();
}
};
window.addEventListener('message', onMsg);
const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
}).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });
ui.postMessage({
perfetto: {
buffer: stringToArrayBuffer(JSON.stringify(data)),
title: "torch profiler",
fileName: "trace.json",
}
}, '*');
}
</script>
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
<h2>Rank {{ i }}: {{ addr }}</h2>
{% if resp.status_code != 200 %}
<p>Failed to fetch: status={{ resp.status_code }}</p>
<pre>{{ resp.text }}</pre>
{% else %}
<script>
function run{{ i }}() {
var data = {{ resp.text | safe }};
openPerfetto(data);
}
</script>
<button onclick="run{{ i }}()">View {{ i }}</button>
{% endif %}
{% endfor %}
{% endblock %}
""",
}
app = Flask(__name__)
app.jinja_loader = DictLoader(templates)
app.jinja_env.globals.update(
zip=zip,
format_json=format_json,
enumerate=enumerate,
)
@app.route("/")
def _index_handler():
return render_template("index.html")
@app.route("/stacks")
def _stacks_handler():
addrs, resps = fetch_all("dump_traceback")
return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps)
@app.route("/fr_trace")
def _fr_trace_handler():
addrs, resps = fetch_all("fr_trace_json")
return render_template(
"json_resp.html",
title="FlightRecorder",
addrs=addrs,
resps=resps,
)
@app.route("/fr_trace_nccl")
def _fr_trace_nccl_handler():
addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
return render_template(
"json_resp.html",
title="FlightRecorder NCCL",
addrs=addrs,
resps=resps,
)
@app.route("/profile")
def _profiler_handler():
duration = request.args.get("duration", default=1.0, type=float)
addrs, resps = fetch_all("torch_profile", f"duration={duration}")
return render_template("profile.html", addrs=addrs, resps=resps)
def main(port: int) -> None:
app.run(host="::", port=port)

View File

@ -0,0 +1,22 @@
import tempfile
import time
from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile
def _torch_profile(req: _Request, resp: _Response) -> None:
experimental_config = _ExperimentalConfig(
profile_all_threads=True,
)
duration = float(req.get_param("duration"))
with profile(record_shapes=True, experimental_config=experimental_config) as prof:
time.sleep(duration)
with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
prof.export_chrome_trace(f.name)
resp.set_content(open(f.name, "rb").read(), "application/json")
resp.set_status(200)
_register_handler("torch_profile", _torch_profile)

View File

@ -0,0 +1,24 @@
import os
import torch.distributed as dist
def get_rank() -> int:
return int(os.environ["RANK"])
def get_world_size() -> int:
return int(os.environ["WORLD_SIZE"])
def tcpstore_client() -> dist.Store:
MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
store = dist.TCPStore(
host_name=MASTER_ADDR,
port=MASTER_PORT,
is_master=False,
)
store = dist.PrefixStore("debug_server", store)
return store

View File

@ -337,19 +337,34 @@ class OpDispatcher:
if is_inplace_op:
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
# NOTE: aten.squeeze_.dim is an inplace op but it also may change
# the inplace argument's tensor meta. Here we choose to special case
# this op because as far as I know this is the only inplace op that
# has such as behavior. We can extend this special case if necessary.
if op_call == aten.squeeze_.dim:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
# update the spec to handle tensor meta changes
args[0]._spec = output_spec
# use return_and_correct_aliasing to match the outer and the inner
# aliasing. See https://github.com/pytorch/pytorch/pull/158954
return return_and_correct_aliasing(op_call, args, kwargs, args[0])
else:
# For all other inplace ops, check if placement changes are required
# Inplace operations that change placement are not supported because
# they would require redistribution, which breaks aliasing semantics.
# If there are views into the tensor, the views would not be updated.
if args[0]._spec.placements != output_spec.placements:
raise RuntimeError(
f"{op_call}: in-place operations that require placement changes "
f"are not supported. The operation would change placement from "
f"{args[0]._spec.placements} to {output_spec.placements}, "
f"which requires redistribution and breaks aliasing semantics. "
f"Please use the out-of-place version of this operation instead."
)
# Most inplace ops don't change tensor meta, so no spec update needed
return args[0]
else:
return None