mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-15 23:04:54 +08:00
Compare commits
5 Commits
ciflow/tru
...
d4l3k/debu
| Author | SHA1 | Date |
|---|---|---|
| | d9bd0f252b | |
| | bfddfde50c | |
| | b6570615f8 | |
| | 226850cc66 | |
| | f8a2ce3b9a | |
@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
+
+Flask==3.1.2
+#Description: required for torch.distributed.debug
330  .spin/cmds.py  Normal file
@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path

import click
import spin


def file_digest(file, algorithm: str):
    try:
        return hashlib.file_digest(file, algorithm)
    except AttributeError:
        pass  # Fallback to manual implementation below
    hash = hashlib.new(algorithm)
    while chunk := file.read(8192):
        hash.update(chunk)
    return hash


def _hash_file(file):
    with open(file, "rb") as f:
        hash = file_digest(f, "sha256")
    return hash.hexdigest()


def _hash_files(files):
    hashes = {file: _hash_file(file) for file in files}
    return hashes


def _read_hashes(hash_file: Path):
    if not hash_file.exists():
        return {}
    with hash_file.open("r") as f:
        lines = f.readlines()
    hashes = {}
    for line in lines:
        # sha256sum layout: 64 hex chars, two spaces, then the file path
        hash = line[:64]
        file = line[66:].strip()
        hashes[file] = hash
    return hashes


def _updated_hashes(hash_file, files_to_hash):
    old_hashes = _read_hashes(hash_file)
    new_hashes = _hash_files(files_to_hash)
    if new_hashes != old_hashes:
        return new_hashes
    return None


@click.command()
def regenerate_version():
    """Regenerate version.py."""
    cmd = [
        sys.executable,
        "-m",
        "tools.generate_torch_version",
        "--is-debug=false",
    ]
    spin.util.run(cmd)


TYPE_STUBS = [
    (
        "Pytorch type stubs",
        Path(".lintbin/.pytorch-type-stubs.sha256"),
        [
            "aten/src/ATen/native/native_functions.yaml",
            "aten/src/ATen/native/tags.yaml",
            "tools/autograd/deprecated.yaml",
        ],
        [
            sys.executable,
            "-m",
            "tools.pyi.gen_pyi",
            "--native-functions-path",
            "aten/src/ATen/native/native_functions.yaml",
            "--tags-path",
            "aten/src/ATen/native/tags.yaml",
            "--deprecated-functions-path",
            "tools/autograd/deprecated.yaml",
        ],
    ),
    (
        "Datapipes type stubs",
        None,
        [],
        [
            sys.executable,
            "torch/utils/data/datapipes/gen_pyi.py",
        ],
    ),
]


@click.command()
def regenerate_type_stubs():
    """Regenerate type stubs."""
    for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
        if hash_file:
            if hashes := _updated_hashes(hash_file, files_to_hash):
                click.echo(
                    f"Changes detected in type stub files for {name}. Regenerating..."
                )
                spin.util.run(cmd)
                hash_file.parent.mkdir(parents=True, exist_ok=True)
                with hash_file.open("w") as f:
                    for file, hash in hashes.items():
                        f.write(f"{hash}  {file}\n")
                click.echo("Type stubs and hashes updated.")
            else:
                click.echo(f"No changes detected in type stub files for {name}.")
        else:
            click.echo(f"No hash file for {name}. Regenerating...")
            spin.util.run(cmd)
            click.echo("Type stubs regenerated.")


@click.command()
def regenerate_clangtidy_files():
    """Regenerate clang-tidy files."""
    cmd = [
        sys.executable,
        "-m",
        "tools.linter.clang_tidy.generate_build_files",
    ]
    spin.util.run(cmd)


#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
    "ATEN_CPU_GPU_AGNOSTIC",
    "BAZEL_LINTER",
    "C10_NODISCARD",
    "C10_UNUSED",
    "CALL_ONCE",
    "CMAKE_MINIMUM_REQUIRED",
    "CONTEXT_DECORATOR",
    "COPYRIGHT",
    "CUBINCLUDE",
    "DEPLOY_DETECTION",
    "ERROR_PRONE_ISINSTANCE",
    "EXEC",
    "HEADER_ONLY_LINTER",
    "IMPORT_LINTER",
    "INCLUDE",
    "LINTRUNNER_VERSION",
    "MERGE_CONFLICTLESS_CSV",
    "META_NO_CREATE_UNBACKED",
    "NEWLINE",
    "NOQA",
    "NO_WORKFLOWS_ON_FORK",
    "ONCE_FLAG",
    "PYBIND11_INCLUDE",
    "PYBIND11_SPECIALIZATION",
    "PYPIDEP",
    "PYPROJECT",
    "RAWCUDA",
    "RAWCUDADEVICE",
    "ROOT_LOGGING",
    "TABS",
    "TESTOWNERS",
    "TYPEIGNORE",
    "TYPENOSKIP",
    "WORKFLOWSYNC",
}


#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
    "CMAKE",
    "DOCSTRING_LINTER",
    "GHA",
    "NATIVEFUNCTIONS",
    "RUFF",
    "SET_LINTER",
    "SHELLCHECK",
    "SPACES",
}


#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
    "ACTIONLINT",
    "CLANGFORMAT",
    "CLANGTIDY",
    "CODESPELL",
    "FLAKE8",
    "GB_REGISTRY",
    "PYFMT",
    "PYREFLY",
    "TEST_DEVICE_BIAS",
    "TEST_HAS_MAIN",
}


ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS


LINTRUNNER_CACHE_INFO = (
    Path(".lintbin/.lintrunner.sha256"),
    [
        "requirements.txt",
        "pyproject.toml",
        ".lintrunner.toml",
    ],
)


LINTRUNNER_BASE_CMD = [
    "uvx",
    "--python",
    "3.10",
    "lintrunner@0.12.7",
]


@click.command()
def setup_lint():
    """Set up lintrunner with current CI version."""
    cmd = LINTRUNNER_BASE_CMD + ["init"]
    subprocess.run(cmd, check=True, capture_output=True, text=True)


def _check_linters():
    cmd = LINTRUNNER_BASE_CMD + ["list"]
    ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
    linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
    unknown_linters = linters - ALL_LINTERS
    missing_linters = ALL_LINTERS - linters
    if unknown_linters:
        click.secho(
            f"Unknown linters found; please add them to the correct category "
            f"in .spin/cmds.py: {', '.join(unknown_linters)}",
            fg="yellow",
        )
    if missing_linters:
        click.secho(
            f"Missing linters found; please update the corresponding category "
            f"in .spin/cmds.py: {', '.join(missing_linters)}",
            fg="yellow",
        )
    return unknown_linters, missing_linters


@spin.util.extend_command(
    setup_lint,
    doc=f"""
    If configuration has changed, update lintrunner.

    Compares the stored old hashes of configuration files with new ones and
    performs setup via setup-lint if the hashes have changed.
    Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
    considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
    """,
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
    if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
        click.echo(
            "Changes detected in lint configuration files. Setting up linting tools..."
        )
        parent_callback(**kwargs)
        hash_file = LINTRUNNER_CACHE_INFO[0]
        hash_file.parent.mkdir(parents=True, exist_ok=True)
        with hash_file.open("w") as f:
            for file, hash in hashes.items():
                f.write(f"{hash}  {file}\n")
        click.echo("Linting tools set up and hashes updated.")
    else:
        click.echo("No changes detected in lint configuration files. Skipping setup.")
    click.echo("Regenerating version...")
    ctx.invoke(regenerate_version)
    click.echo("Regenerating type stubs...")
    ctx.invoke(regenerate_type_stubs)
    click.echo("Done.")
    _check_linters()


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
    """Lint all files."""
    ctx.invoke(lazy_setup_lint)
    all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
    changed_files_linters = SLOW_LINTERS
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    all_files_cmd = cmd + [
        "--take",
        ",".join(all_files_linters),
        "--all-files",
    ]
    spin.util.run(all_files_cmd)
    changed_files_cmd = cmd + [
        "--take",
        ",".join(changed_files_linters),
    ]
    spin.util.run(changed_files_cmd)


@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
    """Autofix all files."""
    ctx.invoke(lint, apply_patches=True)


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
    """Lint changed files."""
    ctx.invoke(lazy_setup_lint)
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    spin.util.run(cmd)


@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
    """Autofix changed files."""
    ctx.invoke(quicklint, apply_patches=True)
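The two-space separator matters: `_read_hashes` slices `line[:64]` for the digest and `line[66:]` for the path. A standalone round-trip sketch of that format (temporary paths here are hypothetical, not part of the PR):

# Round-trip sketch of the hash-store format used above: 64 sha256 hex chars,
# two spaces, then the file path (sha256sum-style).
import hashlib
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    hash_file = Path(tmp) / "demo.sha256"
    digest = hashlib.sha256(b"example contents").hexdigest()

    # write: matches the f"{hash}  {file}\n" lines emitted by the commands above
    hash_file.write_text(f"{digest}  aten/src/ATen/native/tags.yaml\n")

    # read: same slicing as _read_hashes
    line = hash_file.read_text().splitlines()[0]
    assert line[:64] == digest
    assert line[66:].strip() == "aten/src/ATen/native/tags.yaml"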
@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
        list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
      endif()
    endif()
+   if("${_arch}" STREQUAL "121a")
+     if(_existing_arch_flags MATCHES ".*compute_120.*")
+       list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
+     endif()
+   endif()
  endforeach()
  list(JOIN _file_compile_flags " " _file_compile_flags)

@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)

  _BUILD_FOR_ADDITIONAL_ARCHS(
    "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
-   "89;90a;100a;103a;120a")
+   "89;90a;100a;103a;120a;121a")
  _BUILD_FOR_ADDITIONAL_ARCHS(
    "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
    "90a")
@ -376,3 +376,19 @@ keep-runtime-typing = true

[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
+
+[tool.spin]
+package = 'torch'
+
+[tool.spin.commands]
+"Build" = [
+    ".spin/cmds.py:lint",
+    ".spin/cmds.py:fixlint",
+    ".spin/cmds.py:quicklint",
+    ".spin/cmds.py:quickfix",
+]
+"Regenerate" = [
+    ".spin/cmds.py:regenerate_version",
+    ".spin/cmds.py:regenerate_type_stubs",
+    ".spin/cmds.py:regenerate_clangtidy_files",
+]
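With this registration in place, spin should expose each entry as a subcommand named after its click command (click derives names by hyphenating the function name), e.g. `spin lint`, `spin quickfix`, or `spin regenerate-type-stubs`; the "Build" and "Regenerate" strings are just the section headings `spin --help` groups them under.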
@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
+spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel
@ -331,6 +331,25 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
        self.assertEqual(z.placements, (Replicate(),))
        self.assertEqual(z.to_local(), input)

    def test_inplace_op_partial_to_replicate(self):
        # test that in-place operations that require redistribution raise an error
        # to preserve aliasing semantics (issue #163374)
        device_mesh = self.build_device_mesh()

        input_tensor = torch.tensor(64.0, device=self.device_type)
        partial_dt = DTensor.from_local(
            input_tensor, device_mesh, placements=(Partial(),)
        )

        self.assertTrue(partial_dt.placements[0].is_partial())

        # Inplace ops that require placement changes (Partial -> Replicate) should error
        with self.assertRaisesRegex(
            RuntimeError,
            "in-place operations that require placement changes are not supported",
        ):
            partial_dt.clamp_(max=10)


if __name__ == "__main__":
    run_tests()
53  test/distributed/test_debug.py  Normal file
@ -0,0 +1,53 @@
# Owner(s): ["oncall: distributed"]

import os

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server
from torch.testing._internal.common_utils import run_tests, TestCase


session = requests.Session()
retry_strategy = Retry(total=5, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)


class TestDebug(TestCase):
    def test_basics(self) -> None:
        store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(store.port)
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"

        port = 25999

        def fetch(path: str) -> str:
            resp = session.get(f"http://localhost:{port}{path}")
            resp.raise_for_status()
            return resp.text

        print("starting!")

        start_debug_server(port=port)

        self.assertIn("torch profiler", fetch("/"))
        self.assertIn("View 0", fetch("/profile?duration=0.01"))
        self.assertIn("test_basics", fetch("/stacks"))
        self.assertIn("pg_status", fetch("/fr_trace"))

        if torch.cuda.is_available():
            self.assertIn("pg_status", fetch("/fr_trace_nccl"))

        stop_debug_server()


if __name__ == "__main__":
    run_tests()
@ -1,9 +1,11 @@
# Owner(s): ["module: dynamo"]

import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch

@ -13,13 +15,16 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
-from torch._dynamo.aot_compile import ModelInput, SerializableCallable
+from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
-from torch.testing._internal.common_utils import instantiate_parametrized_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    TEST_CUDA,
+)


MY_LAMBDA = lambda x: x + 1  # noqa: E731

@ -599,6 +604,92 @@ from user code:
        actual = compiled_fn(*inputs)
        self.assertEqual(expected, actual)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_aot_compile_with_aoti(self):
        with torch.device("cuda"):
            from torch._dynamo.hooks import Hooks

            def fn(x, y):
                return x + y

            def make_inputs():
                return (torch.randn(3, 4), torch.randn(3, 4))

            compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
                fn,
                (make_inputs(), {}),
                Hooks(),
                torch._TorchCompileAOTInductorWrapper(None, None, None),
            )

            test_inputs = make_inputs()
            expected = fn(*test_inputs)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(expected, actual)
            compiled_fn.save_compiled_function(self.path())
            with open(self.path(), "rb") as f:
                compiled_fn = torch.compiler.load_compiled_function(f)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(expected, actual)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_aot_compile_with_aoti_module(self):
        with torch.device("cuda"):
            from torch._dynamo.hooks import Hooks

            mod = SimpleLinearModule()

            def make_inputs():
                return (torch.randn(4, 3),)

            compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
                mod,
                [ModelInput(make_inputs(), {}, [])],
                Hooks(),
                torch._TorchCompileAOTInductorWrapper(None, None, None),
            )

            def get_grads(m: torch.nn.Module):
                return {name: p.grad for name, p in m.named_parameters()}

            original_mod = copy.deepcopy(mod)
            test_inputs = make_inputs()
            expected = mod(*test_inputs)
            expected.sum().backward()
            expected_grads = get_grads(mod)

            actual = compiled_mod(*test_inputs)
            self.assertEqual(expected, actual)
            serialized = compiled_mod.serialize()
            compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
            actual = compiled_fn(*test_inputs)
            actual.sum().backward()
            self.assertEqual(get_grads(original_mod), expected_grads)

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_aot_compile_with_aoti_torch_compile(self):
        with torch.device("cuda"):

            def fn(x, y):
                return x + y

            def make_inputs():
                return (torch.randn(3, 4), torch.randn(3, 4))

            compiled_fn = torch.compile(
                fn, fullgraph=True, options={"use_aoti": True}
            ).aot_compile((make_inputs(), {}))
            test_inputs = make_inputs()
            expected = fn(*test_inputs)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(expected, actual)
            compiled_fn.save_compiled_function(self.path())
            with open(self.path(), "rb") as f:
                compiled_fn = torch.compiler.load_compiled_function(f)
            actual = compiled_fn(*test_inputs)
            self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
            self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@ -100,7 +100,9 @@ class Logger:
    def _set_static_graph(self) -> None: ...

class _WorkerServer:
-    def __init__(self, socket_path: str) -> None: ...
+    port: int
+
+    def __init__(self, host_or_file: str, port: int = ...) -> None: ...
    def shutdown(self) -> None: ...

def get_debug_level(): ...

@ -206,6 +208,7 @@ class Store:
        desired_value: str,
    ) -> bytes: ...
    def delete_key(self, key: str) -> bool: ...
    def multi_get(self, keys: list[str]) -> list[bytes]: ...
    def num_keys(self) -> int: ...
    def set_timeout(self, timeout: timedelta): ...
    @overload

@ -871,3 +874,15 @@ class ProcessGroupXCCL(Backend):

def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...

class _Request:
    def body(self) -> bytes: ...
    def get_param(self, str) -> str: ...

class _Response:
    def set_content(self, content: str | bytes, content_type: str) -> None: ...
    def set_status(self, status: int) -> None: ...

def _register_handler(
    name: str, handler: Callable[[_Request, _Response], None]
) -> None: ...

@ -60,6 +60,7 @@ class _ExperimentalConfig:
        verbose: bool = ...,
        performance_events: list[str] = ...,
        enable_cuda_sync_events: bool = ...,
        profile_all_threads: bool = ...,
    ) -> None: ...

class ProfilerConfig:
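The new `_Request`/`_Response`/`_register_handler` stubs are enough to write a custom control-plane handler from Python. A hedged sketch (the handler name `py_ping` is illustrative, not part of this diff):

from torch._C._distributed_c10d import _register_handler, _Request, _Response

def _py_ping(req: _Request, resp: _Response) -> None:
    # get_param returns "" when the key is absent (see Request::getParam below)
    who = req.get_param("who") or "world"
    resp.set_content(f"pong {who}", "text/plain")
    resp.set_status(200)

_register_handler("py_ping", _py_ping)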
@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
        reset_cudagraph_trees()


class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
    compiler_name = "aotinductor"

    def __init__(self, mode, options, dynamic):
        super().__init__(mode, options, dynamic)
        self.apply_options({"cpp_wrapper": True})
        self.apply_options({"aot_inductor.package": True})

    def __call__(self, model_, inputs_):
        from contextlib import nullcontext
        from unittest import mock

        from torch._guards import detect_fake_mode
        from torch._inductor.virtualized import V

        fake_mode = detect_fake_mode(inputs_)
        ctx = (
            mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
            if fake_mode
            else nullcontext()
        )
        with (
            V.set_aot_compilation(True),
            ctx,
            torch._inductor.config.patch("enable_autograd_for_aot", True),
        ):
            return super().__call__(model_, inputs_)


class _TorchCompileWrapper:
    def __init__(self, backend, mode, options, dynamic):
        from torch._dynamo.backends.registry import lookup_backend

@ -2672,8 +2701,10 @@ def compile(
        backend = bisect_backend

    guard_filter_fn = None
    use_aoti = False
    if options and isinstance(options, dict):
        guard_filter_fn = options.pop("guard_filter_fn", None)
        use_aoti = options.pop("use_aoti", False)

    if torch.compiler.is_exporting():
        warnings.warn(

@ -2700,7 +2731,10 @@ def compile(
        return export_wrapped_fn

    if backend == "inductor":
-        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
+        if use_aoti:
+            backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
+        else:
+            backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)
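End to end, the new `use_aoti` option routes `torch.compile` through `_TorchCompileAOTInductorWrapper`. A minimal sketch mirroring `test_aot_compile_with_aoti_torch_compile` (assumes a CUDA-enabled build, as the tests do):

import torch

def fn(x, y):
    return x + y

with torch.device("cuda"):
    example_args = (torch.randn(3, 4), torch.randn(3, 4))
    compiled = torch.compile(fn, fullgraph=True, options={"use_aoti": True})
    aot_fn = compiled.aot_compile((example_args, {}))  # (args, kwargs) tuple
    out = aot_fn(*example_args)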
@ -53,6 +53,7 @@ class CompileArtifacts:
    argdefs: Optional[tuple[Any, ...]]
    source_info: "SourceInfo"
    device_type: str
    backend_name: str
    system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

    def check_compatibility(self) -> None:

@ -166,7 +167,8 @@ class AOTCompiledFunction:
        state = pickle.loads(data)
        state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
        deserializer, compiled_fn_state = state["compiled_fn"]
-        state["compiled_fn"] = deserializer(compiled_fn_state)
+        with torch._inductor.config.patch(enable_autograd_for_aot=True):
+            state["compiled_fn"] = deserializer(compiled_fn_state)
        state["original_code"] = SerializedCode.to_code_object(state["original_code"])

        artifacts = CompileArtifacts(**state)

@ -273,6 +275,7 @@ def aot_compile_fullgraph(
        argdefs=fn.__defaults__,
        source_info=source_info,
        device_type=device_type,
        backend_name=getattr(backend, "compiler_name", "unknown"),
    )
    aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
        ).post_compile(
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )
        compiled_fw_func._boxed_call = True
        disable_amp = torch._C._is_any_autocast_enabled()

        if needs_autograd:
@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
                    # pyrefly: ignore [unbound-name]
                    (str, list, torch.fx.GraphModule),
                ), type(compiled_fn)
-                return CompiledAOTI(compiled_fn)
+                return CompiledAOTI(
+                    filename=compiled_fn, device_type=graph.device_type
+                )

            # TODO: Hoist this above V.aot_compilation
            # pyrefly: ignore [unbound-name]

@ -2713,7 +2715,7 @@ def _compile_fx_main(
        or torch._guards.TracingContext(fake_mode)
    )

-    if V.aot_compilation:
+    if V.aot_compilation and not config.enable_autograd_for_aot:
        from .utils import is_valid_aoti_model_name

        is_valid_aoti_model_name()
@ -1190,6 +1190,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}

file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))

enable_autograd_for_aot: bool = False


def get_worker_log_path() -> Optional[str]:
    log_loc = None
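This flag is the thread running through the changes above: `_TorchCompileAOTInductorWrapper` patches it on while compiling, `AOTCompiledFunction.deserialize` patches it on while loading, and `compile_fx`/`CompiledAOTI` use it to skip the AOTI model-name check and to eagerly load the compiled shared library.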
@ -773,9 +773,86 @@ class CompiledAOTI(OutputCode):
    """

    filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
    device_type: str
    current_callable: Optional[Callable[..., Any]] = None
    _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if not config.aot_inductor.link_libtorch:
            return

        if (
            torch._inductor.cpp_builder._IS_MACOS
            or torch._inductor.cpp_builder._IS_WINDOWS
        ):
            return

        if config.aot_inductor.cross_target_platform == "windows":
            return

        if config.aot_inductor.package_cpp_only:
            return

        if not config.enable_autograd_for_aot:
            return

        if isinstance(self.filename, list):
            current_callable = next(
                fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
            )
        else:
            current_callable = self.filename

        if isinstance(current_callable, torch.fx.GraphModule):
            self.current_callable = current_callable
            return

        if self.device_type.startswith("cuda"):
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCuda(  # type: ignore[call-arg]
                    current_callable,
                    1,
                    self.device_type,
                    "",
                    True,
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        elif self.device_type == "cpu":
            current_callable = (
                torch._C._aoti.AOTIModelContainerRunnerCpu(  # type: ignore[call-arg]
                    current_callable, 1
                ).run  # type: ignore[attr-defined]
            )  # type: ignore[attr-defined]
        else:
            raise RuntimeError(f"unsupported device type {self.device_type}")
        self.current_callable = current_callable
        self._boxed_call = True
        for file in self._cached_files:
            if not os.path.exists(file):
                with open(file, "wb") as f:
                    f.write(self._cached_files[file])

    def __call__(self, inputs: Sequence[Any]) -> Any:
-        raise NotImplementedError("NYI")
+        if self.current_callable is None:
+            raise RuntimeError("AOTInductor compiled so is not loaded")
+        return self.current_callable(inputs)

    def prepare_for_serialization(self) -> None:
        self.current_callable = None
        self._cached_files = {}
        filenames: list[str] = []
        if isinstance(self.filename, list):
            filenames = self.filename  # type: ignore[assignment]
        elif isinstance(self.filename, str):
            filenames = [self.filename]
        for name in filenames:
            with open(name, "rb") as f:
                self._cached_files[name] = f.read()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["current_callable"] = None
        return state

    def post_compile(
        self,

@ -783,10 +860,8 @@ class CompiledAOTI(OutputCode):
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
-        pass
-
-    def prepare_for_serialization(self) -> None:
-        pass
+        if self.current_callable is None:
+            self.__post_init__()

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass
@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>

#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>

@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
  res.setStatus(200);
}};

RegisterHandler frTracehandler(
    "fr_trace_json",
    [](const Request&, Response& res) {
      auto trace = ::c10d::dump_fr_trace_json(true, true);
      res.setContent(std::move(trace), "application/json");
      res.setStatus(200);
    });

} // namespace

void registerHandler(const std::string& name, HandlerFunc f) {
@ -18,6 +18,14 @@ class TORCH_API Request {
  virtual const std::string& body() const = 0;

  virtual const std::multimap<std::string, std::string>& params() const = 0;

  std::string getParam(const std::string& key) const {
    auto it = params().find(key);
    if (it != params().end()) {
      return it->second;
    }
    return "";
  }
};

// Response represents a response to the handler. This conceptually maps to an
@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
    TORCH_CHECK(
        server_.bind_to_port(hostOrFile, 80),
        fmt::format("Error binding to {}", hostOrFile));
  } else if (port == 0) {
    C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
    port_ = server_.bind_to_any_port(hostOrFile);
    TORCH_CHECK(
        port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
  } else {
    C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
    TORCH_CHECK(
        server_.bind_to_port(hostOrFile, port),
        fmt::format("Error binding to {}:{}", hostOrFile, port));
    port_ = port;
  }

  serverThread_ = std::thread([this]() {
@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {

  void shutdown();

  int port() {
    return port_;
  }

 private:
  httplib::Server server_;
  std::thread serverThread_;
  int port_;
};

} // namespace c10d::control_plane
@ -46,6 +46,7 @@

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>

@ -4203,7 +4204,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
          }),
          py::arg("host_or_file"),
          py::arg("port") = -1)
-      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
+      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
+      .def_property_readonly(
+          "port", &::c10d::control_plane::WorkerServer::port);

  module.def(
      "_get_handler",

@ -4219,6 +4222,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
      Returns the handler with the specified name.
      )");

  module.def(
      "_register_handler",
      [](const std::string& name, const py::function& handler) {
        ::c10d::control_plane::registerHandler(
            name,
            [handler](
                const ::c10d::control_plane::Request& req,
                ::c10d::control_plane::Response& res) {
              py::gil_scoped_acquire acquire;
              handler(std::ref(req), std::ref(res));
            });
      },
      py::arg("name"),
      py::arg("handler"),
      R"(
      Registers a handler by name.
      )");

  module.def(
      "_get_handler_names",
      &::c10d::control_plane::getHandlerNames,

@ -4236,12 +4258,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
      // Default constructor.
      .def(py::init<>())
      .def("body", &::c10d::control_plane::Request::body)
-      .def("params", &::c10d::control_plane::Request::params);
+      .def("get_param", &::c10d::control_plane::Request::getParam);

-  py::class_<
-      ::c10d::control_plane::Response,
-      std::shared_ptr<::c10d::control_plane::Response>,
-      PythonResponse>(
+  py::class_<::c10d::control_plane::Response, PythonResponse>(
      module,
      "_Response",
      R"(
@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
          int,
          const std::string&,
          const std::string&>())
      .def(py::init<
          const std::string&,
          int,
          const std::string&,
          const std::string&,
          const bool>())
      .def(
          "run",
          &AOTIModelContainerRunnerCuda::run,
59  torch/distributed/debug/__init__.py  Normal file
@ -0,0 +1,59 @@
import logging
import multiprocessing
import socket

# import for registration side effect
import torch.distributed.debug._handlers  # noqa: F401
from torch._C._distributed_c10d import _WorkerServer
from torch.distributed.debug._store import get_rank, tcpstore_client


__all__ = [
    "start_debug_server",
    "stop_debug_server",
]

logger: logging.Logger = logging.getLogger(__name__)

_WORKER_SERVER: _WorkerServer | None = None
_DEBUG_SERVER_PROC: multiprocessing.Process | None = None


def start_debug_server(port: int = 25999, worker_port: int = 0) -> None:
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _WORKER_SERVER is None, "debug server already started"
    assert _DEBUG_SERVER_PROC is None, "debug server already started"

    logger.info("Starting debug server on port %d", port)

    store = tcpstore_client()

    _WORKER_SERVER = _WorkerServer("::", worker_port)

    RANK = get_rank()
    store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}")

    from torch.distributed.debug._flask import main

    if RANK == 0:
        _DEBUG_SERVER_PROC = multiprocessing.Process(
            target=main, args=(port,), daemon=True
        )
        _DEBUG_SERVER_PROC.start()


def stop_debug_server() -> None:
    global _WORKER_SERVER, _DEBUG_SERVER_PROC

    assert _DEBUG_SERVER_PROC is not None
    assert _WORKER_SERVER is not None

    logger.info("Stopping debug server")

    _DEBUG_SERVER_PROC.terminate()
    _WORKER_SERVER.shutdown()
    _DEBUG_SERVER_PROC.join()

    _WORKER_SERVER = None
    _DEBUG_SERVER_PROC = None
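A hedged single-process usage sketch of this API (environment setup mirrors test/distributed/test_debug.py above):

import os
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server

store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(store.port)
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

start_debug_server(port=25999)  # rank 0 also spawns the Flask frontend process
# browse http://localhost:25999/ for stacks, FlightRecorder dumps, and profiles
stop_debug_server()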
265  torch/distributed/debug/_flask.py  Normal file
@ -0,0 +1,265 @@
import json
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor

import requests
from flask import Flask, render_template, request
from jinja2 import DictLoader

from torch.distributed.debug._store import get_world_size, tcpstore_client


def fetch_all(
    endpoint: str, args: str = ""
) -> tuple[list[str], Iterator[requests.Response]]:
    store = tcpstore_client()
    keys = [f"rank{r}" for r in range(get_world_size())]
    addrs = store.multi_get(keys)
    addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]

    with ThreadPoolExecutor(max_workers=10) as executor:
        resps = executor.map(requests.post, addrs)

    return addrs, resps


def format_json(blob: str):
    parsed = json.loads(blob)
    return json.dumps(parsed, indent=2)


templates = {
    "base.html": """
<!doctype html>
<head>
    <title>{% block title %}{% endblock %} - PyTorch Distributed</title>
    <link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">

    <style>
        body {
            margin: 0;
            font-family:
                -apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
                "Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
                "Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
            font-size: 1rem;
            font-weight: 400;
            line-height: 1.5;
            color: #212529;
            text-align: left;
            background-color: #fff;
        }
        h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
            margin-bottom: .5rem;
            font-weight: 500;
            line-height: 1.2;
        }
        nav {
            background-color: rgba(0, 0, 0, 0.17);
            padding: 10px;
            display: flex;
            align-items: center;
            padding: 16px;
            justify-content: flex-start;
        }
        nav h1 {
            display: inline-block;
            margin: 0;
        }
        nav a {
            margin: 0 8px;
        }
        section {
            max-width: 1280px;
            padding: 16px;
            margin: 0 auto;
        }
        pre {
            white-space: pre-wrap;
            max-width: 100%;
        }
    </style>
</head>

<nav>
    <h1>Torch Distributed Debug Server</h1>

    <a href="/">Home</a> <!--@lint-ignore-->
    <a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
    <a href="/fr_trace">FlightRecorder</a> <!--@lint-ignore-->
    <a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
    <a href="/profile">torch profiler</a> <!--@lint-ignore-->
</nav>

<section class="content">
    {% block header %}{% endblock %}
    {% for message in get_flashed_messages() %}
        <div class="flash">{{ message }}</div>
    {% endfor %}
    {% block content %}{% endblock %}
</section>
""",
    "index.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
    Hi
{% endblock %}
""",
    "raw_resp.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <pre>{{ resp.text }}</pre>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
    "json_resp.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <pre>{{ format_json(resp.text) }}</pre>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
    "profile.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}

{% block content %}
    <form action="/profile" method="get">
        <label for="duration">Duration (seconds):</label>
        <input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
        <input type="submit" value="Submit">
    </form>

    <script>
    function stringToArrayBuffer(str) {
        const encoder = new TextEncoder();
        return encoder.encode(str).buffer;
    }
    async function openPerfetto(data) {
        const ui = window.open('https://ui.perfetto.dev/#!/');
        if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }

        // Perfetto readiness handshake: PING until we receive PONG
        await new Promise((resolve, reject) => {
            const onMsg = (e) => {
                if (e.source === ui && e.data === 'PONG') {
                    window.removeEventListener('message', onMsg);
                    clearInterval(pinger);
                    resolve();
                }
            };
            window.addEventListener('message', onMsg);
            const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
            setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
        }).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });

        ui.postMessage({
            perfetto: {
                buffer: stringToArrayBuffer(JSON.stringify(data)),
                title: "torch profiler",
                fileName: "trace.json",
            }
        }, '*');
    }
    </script>

    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <script>
            function run{{ i }}() {
                var data = {{ resp.text | safe }};
                openPerfetto(data);
            }
            </script>

            <button onclick="run{{ i }}()">View {{ i }}</button>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
}

app = Flask(__name__)
app.jinja_loader = DictLoader(templates)
app.jinja_env.globals.update(
    zip=zip,
    format_json=format_json,
    enumerate=enumerate,
)


@app.route("/")
def _index_handler():
    return render_template("index.html")


@app.route("/stacks")
def _stacks_handler():
    addrs, resps = fetch_all("dump_traceback")
    return render_template("raw_resp.html", title="Stacks", addrs=addrs, resps=resps)


@app.route("/fr_trace")
def _fr_trace_handler():
    addrs, resps = fetch_all("fr_trace_json")

    return render_template(
        "json_resp.html",
        title="FlightRecorder",
        addrs=addrs,
        resps=resps,
    )


@app.route("/fr_trace_nccl")
def _fr_trace_nccl_handler():
    addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

    return render_template(
        "json_resp.html",
        title="FlightRecorder NCCL",
        addrs=addrs,
        resps=resps,
    )


@app.route("/profile")
def _profiler_handler():
    duration = request.args.get("duration", default=1.0, type=float)

    addrs, resps = fetch_all("torch_profile", f"duration={duration}")

    return render_template("profile.html", addrs=addrs, resps=resps)


def main(port: int) -> None:
    app.run(host="::", port=port)
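Each route is a fan-out: the frontend reads every rank's worker address out of the TCPStore, POSTs to that rank's /handler/<endpoint> in a thread pool, and renders one "Rank N" block per response, so a page load blocks on the slowest rank.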
22  torch/distributed/debug/_handlers.py  Normal file
@ -0,0 +1,22 @@
import tempfile
import time

from torch._C._distributed_c10d import _register_handler, _Request, _Response
from torch.profiler import _ExperimentalConfig, profile


def _torch_profile(req: _Request, resp: _Response) -> None:
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    duration = float(req.get_param("duration"))
    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        time.sleep(duration)

    with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
        prof.export_chrome_trace(f.name)
        resp.set_content(open(f.name, "rb").read(), "application/json")
        resp.set_status(200)


_register_handler("torch_profile", _torch_profile)
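Because the handler is served by the C++ worker server, it can also be hit directly over HTTP. A hedged sketch (the worker port below is illustrative; _flask.fetch_all builds the same URL shape):

import requests

resp = requests.post("http://localhost:12345/handler/torch_profile?duration=0.01")
resp.raise_for_status()
trace = resp.json()  # chrome-trace JSON emitted by _torch_profile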
24  torch/distributed/debug/_store.py  Normal file
@ -0,0 +1,24 @@
import os

import torch.distributed as dist


def get_rank() -> int:
    return int(os.environ["RANK"])


def get_world_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def tcpstore_client() -> dist.Store:
    MASTER_ADDR = os.environ["MASTER_ADDR"]
    MASTER_PORT = int(os.environ["MASTER_PORT"])

    store = dist.TCPStore(
        host_name=MASTER_ADDR,
        port=MASTER_PORT,
        is_master=False,
    )
    store = dist.PrefixStore("debug_server", store)
    return store
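Keeping all keys under the "debug_server" prefix means the rank{N} address entries written by start_debug_server cannot collide with other uses of the job's TCPStore.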
@ -337,19 +337,34 @@ class OpDispatcher:
        if is_inplace_op:
            # inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
+                output_spec = output_sharding.output_spec
+                assert isinstance(output_spec, DTensorSpec)
+                assert isinstance(args[0], dtensor.DTensor)
+
                # NOTE: aten.squeeze_.dim is an inplace op but it also may change
                # the inplace argument's tensor meta. Here we choose to special case
                # this op because as far as I know this is the only inplace op that
                # has such as behavior. We can extend this special case if necessary.
                if op_call == aten.squeeze_.dim:
-                    output_spec = output_sharding.output_spec
-                    assert isinstance(output_spec, DTensorSpec)
-                    assert isinstance(args[0], dtensor.DTensor)
                    # update the spec to handle tensor meta changes
                    args[0]._spec = output_spec
                    # use return_and_correct_aliasing to match the outer and the inner
                    # aliasing. See https://github.com/pytorch/pytorch/pull/158954
                    return return_and_correct_aliasing(op_call, args, kwargs, args[0])
+                else:
+                    # For all other inplace ops, check if placement changes are required
+                    # Inplace operations that change placement are not supported because
+                    # they would require redistribution, which breaks aliasing semantics.
+                    # If there are views into the tensor, the views would not be updated.
+                    if args[0]._spec.placements != output_spec.placements:
+                        raise RuntimeError(
+                            f"{op_call}: in-place operations that require placement changes "
+                            f"are not supported. The operation would change placement from "
+                            f"{args[0]._spec.placements} to {output_spec.placements}, "
+                            f"which requires redistribution and breaks aliasing semantics. "
+                            f"Please use the out-of-place version of this operation instead."
+                        )
+                    # Most inplace ops don't change tensor meta, so no spec update needed
+                    return args[0]
            else:
                return None
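In practice the fix trades silent redistribution for an explicit error: partial_dt.clamp_(max=10) now raises, while the out-of-place partial_dt.clamp(max=10) remains supported because it returns a new DTensor and leaves existing aliases untouched.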