[BE] remove torch deploy - conditionals (#158288)

This PR is part of the work to deprecate torch::deploy in OSS. It does three things to get started:
1. Remove `test_deploy_interaction`, as we no longer need to worry about this interaction.
2. Remove all `torch._running_with_deploy` checks and always take the `False` path (this is what surfaced item 1); see the sketch after the note below for the recurring pattern.
3. Remove `USE_DEPLOY` and always take the default path.

Note: MyPy does fail on a number of things here, since several older files are touched. It may be better to fix those in a separate PR.
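For context, item 2 mostly amounts to deleting guards like the one sketched below and keeping only what used to be the `False` branch. This is a minimal, illustrative sketch (the library and op names are made up, not taken from any file in this diff); the real call sites follow in the per-file hunks.

```python
import torch

# Before this PR, module-scope library registrations were typically guarded,
# because creating a "DEF" Library crashed torch::deploy's embedded interpreter:
#
#     if not torch._running_with_deploy():
#         _lib = torch.library.Library("mylib_example", "DEF")  # noqa: TOR901
#         _lib.define("identity(Tensor self) -> Tensor")
#         _lib.impl("identity", lambda x: x.clone(), "CPU")
#
# After this PR, the guard is gone and the registration runs unconditionally:
_lib = torch.library.Library("mylib_example", "DEF")  # noqa: TOR901
_lib.define("identity(Tensor self) -> Tensor")
_lib.impl("identity", lambda x: x.clone(), "CPU")
```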

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
PaliC
2025-07-23 10:34:43 -04:00
committed by PyTorch MergeBot
parent da94023b02
commit ab26d4fbeb
22 changed files with 371 additions and 496 deletions

View File

@ -544,62 +544,6 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
class TestCustomOp(CustomOpTestCaseBase):
test_ns = "_test_custom_op"
def test_deploy_interaction(self):
# run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy
script = """
import torch
torch._running_with_deploy = lambda: True
# creating the library is a no-op, so you can DEF multiple times
m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901
m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901
m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901
# define is a no-op
m.define("foobarbaz9996(Tensor x) -> Tensor")
assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop"
def sin_override(x):
raise AssertionError("m.impl should have been a noop")
# impl is a no-op
m.impl("sin", sin_override, "CompositeImplicitAutograd")
x = torch.randn(3)
y = torch.sin(x)
# should be a no-op
@torch.library.custom_op("mylib::foobar", mutates_args={})
def foobar(x: torch.Tensor) -> torch.Tensor:
return x.sin()
# should be a no-op
@foobar.register_fake
def _(x):
return torch.empty_like(x)
# should be a no-op
m2.define("foobarbaz9996(Tensor x) -> Tensor")
# should be a no-op
@torch.library.register_fake("mylib4392::foobarbaz9996")
def _(x):
return torch.empty_like(x)
"""
script = script.strip()
env = os.environ.copy()
try:
subprocess.check_output(
[sys.executable, "-c", script],
stderr=subprocess.STDOUT,
# On Windows, opening the subprocess with the default CWD makes `import torch`
# fail, so just set CWD to this script's directory
cwd=os.path.dirname(os.path.realpath(__file__)),
env=env,
)
except subprocess.CalledProcessError as e:
self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8")))
@requires_compile
def test_functionalize_error(self):
with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:

View File

@ -3603,8 +3603,8 @@ class TestSparseCompressedTritonKernels(TestCase):
@onlyCUDA
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(),
"Skipped for deploy and internal with remote GPUs")
@unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU),
"Skipped for internal with remote GPUs")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial
from torch.sparse._triton_ops import bsr_dense_mm
@ -3680,8 +3680,8 @@ class TestSparseCompressedTritonKernels(TestCase):
@onlyCUDA
@dtypes(torch.half)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(),
"Skipped for deploy and internal with remote GPUs")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU,
"Skipped for internal with remote GPUs")
def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
from torch.sparse._triton_ops import bsr_dense_mm

View File

@ -34,20 +34,10 @@ from typing import (
)
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
from . import version
if TYPE_CHECKING:
from .types import Device, IntLikeType
# multipy/deploy is setting this import before importing torch, this is the most # codespell:ignore multipy
# reliable way we have to detect if we're running within deploy.
# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 # codespell:ignore multipy # noqa: B950
def _running_with_deploy() -> builtins.bool:
return sys.modules.get("torch._meta_registrations", None) is object
from torch._utils import (
_functionalize_sync as _sync,
_import_dotted_name,
@ -60,14 +50,9 @@ from torch._utils_internal import (
USE_GLOBAL_DEPS,
USE_RTLD_GLOBAL_WITH_LIBTORCH,
)
# TODO(torch_deploy) figure out how to freeze version.py in fbcode build
if _running_with_deploy():
__version__ = "torch-deploy-1.8"
else:
from torch.torch_version import __version__ as __version__
__all__ = [
"BoolStorage",
"BoolTensor",
@ -317,7 +302,7 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
# See Note [Global dependencies]
def _load_global_deps() -> None:
if _running_with_deploy() or platform.system() == "Windows":
if platform.system() == "Windows":
return
# Determine the file extension based on the platform
@ -381,7 +366,7 @@ def _load_global_deps() -> None:
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
_running_with_deploy() or platform.system() != "Windows"
platform.system() != "Windows"
):
# Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
# few circumstances:
@ -2082,7 +2067,7 @@ from torch.serialization import load, save
# Shared memory manager needs to know the exact location of manager executable
def _manager_path():
if _running_with_deploy() or platform.system() == "Windows":
if platform.system() == "Windows":
return b""
path = get_file_path("torch", "bin", "torch_shm_manager")
prepare_multiprocessing_environment(get_file_path("torch"))
@ -2687,9 +2672,9 @@ from torch import fx as fx
# Register MPS specific decomps
torch.backends.mps._init()
if not _running_with_deploy():
from torch import compiler as compiler
class _TritonLibrary:
lib = torch.library.Library("triton", "DEF")
ops_table: dict[tuple[str, str], _Callable] = {}

View File

@ -49,9 +49,6 @@ Tensor = torch.Tensor
__all__ = ["trace_wrapped"]
if not torch._running_with_deploy():
# torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore
@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc]
def zeros_and_scatter(
shape: list[int],
@ -62,6 +59,7 @@ if not torch._running_with_deploy():
grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype)
return torch.ops.aten.index_put(grad, indices, vals, accumulate=True)
@zeros_and_scatter.register_fake # type: ignore[misc]
def _(
shape: list[int],
@ -70,6 +68,7 @@ if not torch._running_with_deploy():
) -> Tensor:
return vals.new_empty(shape)
@zeros_and_scatter.register_vmap # type: ignore[misc]
def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def]
"""The batching rule is special in that it returns a tensor that is not batched"""

View File

@ -2409,7 +2409,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._lowrank.svd_lowrank",
"torch._preload_cuda_deps",
"torch._register_device_module",
"torch._running_with_deploy",
"torch._utils._dummy_type",
"torch._utils._flatten_dense_tensors",
"torch._utils._unflatten_dense_tensors",

View File

@ -5,16 +5,14 @@ from torch import Tensor
from torch.autograd import Function
if not torch._running_with_deploy():
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
_test_lib_def.define(
"realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag
)
_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
class Realize(Function):
@staticmethod
def forward(ctx: object, x: Tensor) -> Tensor:
@ -25,5 +23,6 @@ if not torch._running_with_deploy():
def backward(ctx: Any, *grad_output: Any) -> Any:
return grad_output[0]
def realize(x: Tensor) -> Tensor:
return Realize.apply(x)

View File

@ -595,10 +595,6 @@ class CustomOpDef:
self._setup_context_fn = setup_context
def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
if torch._running_with_deploy():
utils.warn_deploy(stacklevel=5)
return
lib = self._lib
schema_str = self._name + self._schema
cpp_schema = _C.parse_schema(schema_str)

View File

@ -2,7 +2,6 @@
import dataclasses
import inspect
import sys
import warnings
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Union
@ -12,15 +11,6 @@ from torch import _C, _utils_internal
from torch._ops import OpOverload
def warn_deploy(stacklevel=3):
warnings.warn(
"Python torch.library APIs do nothing under torch::deploy (multipy). " # codespell:ignore multipy
"Please instead use C++ custom operator registration APIs.",
RuntimeWarning,
stacklevel=stacklevel,
)
@dataclasses.dataclass
class Kernel:
"""Models a (function, source location)"""

View File

@ -1478,9 +1478,6 @@ class _Ops(types.ModuleType):
Args:
path (str): A path to a shared library to load.
"""
if torch._running_with_deploy():
return
path = _utils_internal.resolve_library_path(path)
with dl_open_guard():
# Import the shared library into the process, thus running its

View File

@ -33,12 +33,6 @@ if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
# use is the FB build environment, where this source file is replaced
# by an equivalent.
if torch._running_with_deploy():
# __file__ is meaningless in the context of frozen torch used in torch deploy.
# setting empty torch_parent should allow below functions to operate without crashing,
# but it's unclear if there is a valid use case for them in the context of deploy.
torch_parent = ""
else:
if os.path.basename(os.path.dirname(__file__)) == "shared":
torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
else:

View File

@ -331,13 +331,9 @@ void initLazyBindings(PyObject* module) {
// So far this problem has only been observed internally, so we will just
// block it off there.
#if !(defined(USE_DEPLOY))
// When libtorch_python is loaded, we register the python frame getter
// otherwise, debug util simply omits python frames
GetPythonFramesFunction() = GetPythonFrames;
#endif // USE_DEPLOY
}
} // namespace torch::lazy

View File

@ -187,15 +187,6 @@ class PythonKernelHolder : public c10::OperatorKernel {
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g;
// Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy
// so stop forcing hermetic mode unconditionally in all situations when
// you're using multipy. // codespell:ignore multipy
// Eventually just delete this entirely. (Note that you may break
// multipy anyway this way with dispatcher // codespell:ignore multipy
// registered functions that require hermetic to be off.)
#if defined(USE_DEPLOY)
EnableHermeticPyObject g2;
#endif
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto func =
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));

View File

@ -1693,9 +1693,6 @@ class _WrappedTritonKernel:
def _register_triton_kernels():
if torch._running_with_deploy():
return
@_WrappedTritonKernel
def kernel_impl(*args, **kwargs):
from torch.sparse._triton_ops import bsr_dense_mm

View File

@ -19,13 +19,6 @@ except ImportError:
from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
if torch._running_with_deploy():
def is_torchdynamo_compiling():
"""Can't import torchdynamo in torchdeploy builds currently."""
return False
else:
try:
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:
@ -33,7 +26,8 @@ else:
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
)
def is_torchdynamo_compiling():
def is_torchdynamo_compiling(): # type: ignore[misc]
return False
return False
@ -987,10 +981,7 @@ def _reduce_scatter_tensor_coalesced_native_meta(
]
if not torch._running_with_deploy():
# Library MUST be defined at module scope or it doesn't work
# Creating a "DEF" Library always crashes torch::deploy so we create our
# Library instances here guarded against running inside it
lib_impl = torch.library.Library("_c10d_functional", "IMPL")
lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
@ -1043,11 +1034,6 @@ if not torch._running_with_deploy():
legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
else:
warnings.warn(
"PyTorch Distributed functional collectives do not work with torch::deploy."
)
"""
Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into

View File

@ -63,7 +63,6 @@ _META_FUNCTIONS = {
"recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
}
if not torch._running_with_deploy():
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
for op, meta_func in _META_FUNCTIONS.items():
lib_impl.impl(op, meta_func, "Meta")

View File

@ -15,15 +15,6 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec
_compiled_autograd_enabled: bool = False
if torch._running_with_deploy():
def detect_compiled_autograd():
pass
def compiled_autograd_enabled():
return False
else:
def detect_compiled_autograd():
assert not torch.compiler.is_compiling(), (
@ -38,6 +29,7 @@ else:
or ca.in_compiled_autograd_region
)
def compiled_autograd_enabled():
global _compiled_autograd_enabled
return _compiled_autograd_enabled

View File

@ -140,7 +140,6 @@ def copy__functionalize(tensor, data):
torch.ops.fsdp.copy_.default(tensor_inner, data_inner)
if not torch._running_with_deploy():
torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)

View File

@ -25,8 +25,6 @@ from torch.distributed.distributed_c10d import (
logger = logging.getLogger(__name__)
if not torch._running_with_deploy():
@torch.library.register_fake("_dtensor::shard_dim_alltoall")
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
group_size = _get_group_size_by_name(group_name)
@ -40,13 +38,6 @@ if not torch._running_with_deploy():
.contiguous()
)
else:
import warnings
warnings.warn(
"PyTorch Distributed functional collectives do not work with torch::deploy."
)
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
if mesh.device_type == "cpu":

View File

@ -102,9 +102,6 @@ class Library:
ns,
" is a reserved namespace. Please try creating a library with another name.",
)
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
frame = traceback.extract_stack(limit=3)[0]
filename, lineno = frame.filename, frame.lineno
@ -156,9 +153,6 @@ class Library:
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
"""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
# AliasAnalysis type in C++
@ -191,9 +185,6 @@ class Library:
def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False):
r"""Registers the fake impl for an operator defined in the library."""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
source = torch._library.utils.get_source(_stacklevel + 1)
frame = sys._getframe(_stacklevel)
@ -237,9 +228,6 @@ class Library:
If it is a TorchDispatchMode, we expect fn to have the following signature:
(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
"""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
qualname = f"{self.ns}::{op_name}"
entry = torch._library.simple_registry.singleton.find(qualname)
@ -259,9 +247,6 @@ class Library:
>>> my_lib = Library("aten", "IMPL")
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
"""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
if dispatch_key == "":
dispatch_key = self.dispatch_key
@ -324,9 +309,6 @@ class Library:
>>> return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
"""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
if not callable(fn):
raise TypeError(
@ -409,9 +391,6 @@ class Library:
>>> # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")
"""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
if dispatch_key == "":
dispatch_key = self.dispatch_key

View File

@ -29,13 +29,7 @@ def set_module(obj, mod):
obj.__module__ = mod
if torch._running_with_deploy():
# not valid inside torch_deploy interpreter, no paths exists for frozen modules
cmake_prefix_path = None
else:
cmake_prefix_path = _osp.join(
_osp.dirname(_osp.dirname(__file__)), "share", "cmake"
)
cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake")
def swap_tensors(t1, t2):

View File

@ -3,8 +3,6 @@ import importlib.util
from types import ModuleType
from typing import Optional
import torch
def _check_module_exists(name: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
@ -22,11 +20,7 @@ def _check_module_exists(name: str) -> bool:
@functools.lru_cache
def dill_available() -> bool:
return (
_check_module_exists("dill")
# dill fails to import under torchdeploy
and not torch._running_with_deploy()
)
return _check_module_exists("dill")
@functools.lru_cache

View File

@ -6,49 +6,53 @@
import datetime
import json
import locale
import os
import re
import subprocess
import sys
import os
from typing import cast as _cast
from collections import namedtuple
from typing import cast as _cast
try:
import torch
TORCH_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError):
TORCH_AVAILABLE = False
# System Environment Information
SystemEnv = namedtuple('SystemEnv', [
'torch_version',
'is_debug_build',
'cuda_compiled_version',
'gcc_version',
'clang_version',
'cmake_version',
'os',
'libc_version',
'python_version',
'python_platform',
'is_cuda_available',
'cuda_runtime_version',
'cuda_module_loading',
'nvidia_driver_version',
'nvidia_gpu_models',
'cudnn_version',
'is_xpu_available',
'pip_version', # 'pip' or 'pip3'
'pip_packages',
'conda_packages',
'hip_compiled_version',
'hip_runtime_version',
'miopen_runtime_version',
'caching_allocator_config',
'is_xnnpack_available',
'cpu_info',
])
SystemEnv = namedtuple(
"SystemEnv",
[
"torch_version",
"is_debug_build",
"cuda_compiled_version",
"gcc_version",
"clang_version",
"cmake_version",
"os",
"libc_version",
"python_version",
"python_platform",
"is_cuda_available",
"cuda_runtime_version",
"cuda_module_loading",
"nvidia_driver_version",
"nvidia_gpu_models",
"cudnn_version",
"is_xpu_available",
"pip_version", # 'pip' or 'pip3'
"pip_packages",
"conda_packages",
"hip_compiled_version",
"hip_runtime_version",
"miopen_runtime_version",
"caching_allocator_config",
"is_xnnpack_available",
"cpu_info",
],
)
COMMON_PATTERNS = [
"torch",
@ -116,12 +120,13 @@ PIP_PATTERNS = [
def run(command):
"""Return (return-code, stdout, stderr)."""
shell = True if type(command) is str else False
p = subprocess.Popen(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=shell)
p = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
)
raw_output, raw_err = p.communicate()
rc = p.returncode
if get_platform() == 'win32':
enc = 'oem'
if get_platform() == "win32":
enc = "oem"
else:
enc = locale.getpreferredencoding()
output = raw_output.decode(enc)
@ -147,18 +152,19 @@ def run_and_parse_first_match(run_lambda, command, regex):
return None
return match.group(1)
def run_and_return_first_line(run_lambda, command):
"""Run command using run_lambda and returns first line if output is not empty."""
rc, out, _ = run_lambda(command)
if rc != 0:
return None
return out.split('\n')[0]
return out.split("\n")[0]
def get_conda_packages(run_lambda, patterns=None):
if patterns is None:
patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS
conda = os.environ.get('CONDA_EXE', 'conda')
conda = os.environ.get("CONDA_EXE", "conda")
out = run_and_read_all(run_lambda, "{} list".format(conda))
if out is None:
return out
@ -166,32 +172,40 @@ def get_conda_packages(run_lambda, patterns=None):
return "\n".join(
line
for line in out.splitlines()
if not line.startswith("#")
and any(name in line for name in patterns)
if not line.startswith("#") and any(name in line for name in patterns)
)
def get_gcc_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
def get_clang_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)')
return run_and_parse_first_match(
run_lambda, "clang --version", r"clang version (.*)"
)
def get_cmake_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)')
return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
def get_nvidia_driver_version(run_lambda):
if get_platform() == 'darwin':
cmd = 'kextstat | grep -i cuda'
return run_and_parse_first_match(run_lambda, cmd,
r'com[.]nvidia[.]CUDA [(](.*?)[)]')
if get_platform() == "darwin":
cmd = "kextstat | grep -i cuda"
return run_and_parse_first_match(
run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]"
)
smi = get_nvidia_smi()
return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ')
return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
def get_gpu_info(run_lambda):
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None):
if get_platform() == "darwin" or (
TORCH_AVAILABLE
and hasattr(torch.version, "hip")
and torch.version.hip is not None
):
if TORCH_AVAILABLE and torch.cuda.is_available():
if torch.version.hip is not None:
prop = torch.cuda.get_device_properties(0)
@ -204,42 +218,42 @@ def get_gpu_info(run_lambda):
return torch.cuda.get_device_name(None) + gcnArch
return None
smi = get_nvidia_smi()
uuid_regex = re.compile(r' \(UUID: .+?\)')
rc, out, _ = run_lambda(smi + ' -L')
uuid_regex = re.compile(r" \(UUID: .+?\)")
rc, out, _ = run_lambda(smi + " -L")
if rc != 0:
return None
# Anonymize GPUs by removing their UUID
return re.sub(uuid_regex, '', out)
return re.sub(uuid_regex, "", out)
def get_running_cuda_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)')
return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
def get_cudnn_version(run_lambda):
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
if get_platform() == 'win32':
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
where_cmd = os.path.join(system_root, 'System32', 'where')
if get_platform() == "win32":
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
where_cmd = os.path.join(system_root, "System32", "where")
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
elif get_platform() == 'darwin':
elif get_platform() == "darwin":
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
# https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation
# https://docs.nvidia.com/deeplearning/cudnn/installation/latest/
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*'
cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
else:
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
rc, out, _ = run_lambda(cudnn_cmd)
# find will return 1 if there are permission errors or if not found
if len(out) == 0 or (rc != 1 and rc != 0):
l = os.environ.get('CUDNN_LIBRARY')
l = os.environ.get("CUDNN_LIBRARY")
if l is not None and os.path.isfile(l):
return os.path.realpath(l)
return None
files_set = set()
for fn in out.split('\n'):
for fn in out.split("\n"):
fn = os.path.realpath(fn) # eliminate symbolic links
if os.path.isfile(fn):
files_set.add(fn)
@ -249,18 +263,20 @@ def get_cudnn_version(run_lambda):
files = sorted(files_set)
if len(files) == 1:
return files[0]
result = '\n'.join(files)
return 'Probably one of the following:\n{}'.format(result)
result = "\n".join(files)
return "Probably one of the following:\n{}".format(result)
def get_nvidia_smi():
# Note: nvidia-smi is currently available only on Windows and Linux
smi = 'nvidia-smi'
if get_platform() == 'win32':
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files')
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi)
new_path = os.path.join(system_root, 'System32', smi)
smi = "nvidia-smi"
if get_platform() == "win32":
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
legacy_path = os.path.join(
program_files_root, "NVIDIA Corporation", "NVSMI", smi
)
new_path = os.path.join(system_root, "System32", smi)
smis = [new_path, legacy_path]
for candidate_smi in smis:
if os.path.exists(candidate_smi):
@ -411,7 +427,9 @@ def get_intel_gpu_detected(run_lambda):
if device_count == 0:
return "N/A"
devices = [f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)]
devices = [
f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)
]
return "\n".join(devices)
@ -490,11 +508,12 @@ def get_intel_gpu_detected(run_lambda):
# ProcessorType=3
# Revision=27142
def get_cpu_info(run_lambda):
rc, out, err = 0, '', ''
if get_platform() == 'linux':
rc, out, err = run_lambda('lscpu')
elif get_platform() == 'win32':
rc, out, err = 0, "", ""
if get_platform() == "linux":
rc, out, err = run_lambda("lscpu")
elif get_platform() == "win32":
rc, out, err = run_lambda(
'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\
Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\
@ -514,9 +533,9 @@ def get_cpu_info(run_lambda):
lst.append(out)
lst.append(str(e))
out = "\n".join(lst)
elif get_platform() == 'darwin':
elif get_platform() == "darwin":
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
cpu_info = 'None'
cpu_info = "None"
if rc == 0:
cpu_info = out
else:
@ -525,20 +544,20 @@ def get_cpu_info(run_lambda):
def get_platform():
if sys.platform.startswith('linux'):
return 'linux'
elif sys.platform.startswith('win32'):
return 'win32'
elif sys.platform.startswith('cygwin'):
return 'cygwin'
elif sys.platform.startswith('darwin'):
return 'darwin'
if sys.platform.startswith("linux"):
return "linux"
elif sys.platform.startswith("win32"):
return "win32"
elif sys.platform.startswith("cygwin"):
return "cygwin"
elif sys.platform.startswith("darwin"):
return "darwin"
else:
return sys.platform
def get_mac_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)')
return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
def get_windows_version(run_lambda):
@ -556,39 +575,43 @@ def get_windows_version(run_lambda):
def get_lsb_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)')
return run_and_parse_first_match(
run_lambda, "lsb_release -a", r"Description:\t(.*)"
)
def check_release_file(run_lambda):
return run_and_parse_first_match(run_lambda, 'cat /etc/*-release',
r'PRETTY_NAME="(.*)"')
return run_and_parse_first_match(
run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"'
)
def get_os(run_lambda):
from platform import machine
platform = get_platform()
if platform in ["win32", "cygwin"]:
return get_windows_version(run_lambda)
if platform == 'darwin':
if platform == "darwin":
version = get_mac_version(run_lambda)
if version is None:
return None
return 'macOS {} ({})'.format(version, machine())
return "macOS {} ({})".format(version, machine())
if platform == 'linux':
if platform == "linux":
# Ubuntu/Debian based
desc = get_lsb_version(run_lambda)
if desc is not None:
return '{} ({})'.format(desc, machine())
return "{} ({})".format(desc, machine())
# Try reading /etc/*-release
desc = check_release_file(run_lambda)
if desc is not None:
return '{} ({})'.format(desc, machine())
return "{} ({})".format(desc, machine())
return '{} ({})'.format(platform, machine())
return "{} ({})".format(platform, machine())
# Unknown platform
return platform
@ -596,14 +619,16 @@ def get_os(run_lambda):
def get_python_platform():
import platform
return platform.platform()
def get_libc_version():
import platform
if get_platform() != 'linux':
return 'N/A'
return '-'.join(platform.libc_ver())
if get_platform() != "linux":
return "N/A"
return "-".join(platform.libc_ver())
def get_pip_packages(run_lambda, patterns=None):
@ -611,35 +636,35 @@ def get_pip_packages(run_lambda, patterns=None):
if patterns is None:
patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS
pip_version = 'pip3' if sys.version_info.major == 3 else 'pip'
pip_version = "pip3" if sys.version_info.major == 3 else "pip"
os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1'
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
# People generally have pip as `pip` or `pip3`
# But here it is invoked as `python -mpip`
out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze'])
out = run_and_read_all(
run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"]
)
if out is None:
return pip_version, out
filtered_out = '\n'.join(
line
for line in out.splitlines()
if any(name in line for name in patterns)
filtered_out = "\n".join(
line for line in out.splitlines() if any(name in line for name in patterns)
)
return pip_version, filtered_out
def get_cachingallocator_config():
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if not ca_config:
ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '')
ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "")
return ca_config
def get_cuda_module_loading_config():
if TORCH_AVAILABLE and torch.cuda.is_available():
torch.cuda.init()
config = os.environ.get('CUDA_MODULE_LOADING', '')
config = os.environ.get("CUDA_MODULE_LOADING", "")
return config
else:
return "N/A"
@ -648,10 +673,12 @@ def get_cuda_module_loading_config():
def is_xnnpack_available():
if TORCH_AVAILABLE:
import torch.backends.xnnpack
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
else:
return "N/A"
def get_env_info():
"""
Collects environment information to aid in debugging.
@ -678,26 +705,31 @@ def get_env_info():
cuda_version_str = torch.version.cuda
xpu_available_str = str(torch.xpu.is_available())
if torch.xpu.is_available():
xpu_available_str = f'{xpu_available_str}\n' + \
f'XPU used to build PyTorch: {torch.version.xpu}\n' + \
f'Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n' + \
f'Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n' + \
f'Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}'
if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
xpu_available_str = (
f"{xpu_available_str}\n"
+ f"XPU used to build PyTorch: {torch.version.xpu}\n"
+ f"Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n"
+ f"Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n"
+ f"Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}"
)
if (
not hasattr(torch.version, "hip") or torch.version.hip is None
): # cuda version
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
else: # HIP version
def get_version_or_na(cfg, prefix):
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
return _lst[0] if _lst else 'N/A'
return _lst[0] if _lst else "N/A"
cfg = torch._C._show_config().split('\n')
hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime')
miopen_runtime_version = get_version_or_na(cfg, 'MIOpen')
cuda_version_str = 'N/A'
cfg = torch._C._show_config().split("\n")
hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
cuda_version_str = "N/A"
hip_compiled_version = torch.version.hip
else:
version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = 'N/A'
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment]
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
sys_version = sys.version.replace("\n", " ")
@ -706,7 +738,9 @@ def get_env_info():
return SystemEnv(
torch_version=version_str,
is_debug_build=debug_mode_str,
python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1),
python_version="{} ({}-bit runtime)".format(
sys_version, sys.maxsize.bit_length() + 1
),
python_platform=get_python_platform(),
is_cuda_available=cuda_available_str,
cuda_compiled_version=cuda_version_str,
@ -732,6 +766,7 @@ def get_env_info():
cpu_info=get_cpu_info(run_lambda),
)
env_info_fmt = """
PyTorch version: {torch_version}
Is debug build: {is_debug_build}
@ -767,14 +802,14 @@ Versions of relevant libraries:
def pretty_str(envinfo):
def replace_nones(dct, replacement='Could not collect'):
def replace_nones(dct, replacement="Could not collect"):
for key in dct.keys():
if dct[key] is not None:
continue
dct[key] = replacement
return dct
def replace_bools(dct, true='Yes', false='No'):
def replace_bools(dct, true="Yes", false="No"):
for key in dct.keys():
if dct[key] is True:
dct[key] = true
@ -782,42 +817,48 @@ def pretty_str(envinfo):
dct[key] = false
return dct
def prepend(text, tag='[prepend]'):
lines = text.split('\n')
def prepend(text, tag="[prepend]"):
lines = text.split("\n")
updated_lines = [tag + line for line in lines]
return '\n'.join(updated_lines)
return "\n".join(updated_lines)
def replace_if_empty(text, replacement='No relevant packages'):
def replace_if_empty(text, replacement="No relevant packages"):
if text is not None and len(text) == 0:
return replacement
return text
def maybe_start_on_next_line(string):
# If `string` is multiline, prepend a \n to it.
if string is not None and len(string.split('\n')) > 1:
return '\n{}\n'.format(string)
if string is not None and len(string.split("\n")) > 1:
return "\n{}\n".format(string)
return string
mutable_dict = envinfo._asdict()
# If nvidia_gpu_models is multiline, start on the next line
mutable_dict['nvidia_gpu_models'] = \
maybe_start_on_next_line(envinfo.nvidia_gpu_models)
mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(
envinfo.nvidia_gpu_models
)
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
dynamic_cuda_fields = [
'cuda_runtime_version',
'nvidia_gpu_models',
'nvidia_driver_version',
"cuda_runtime_version",
"nvidia_gpu_models",
"nvidia_driver_version",
]
all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
all_dynamic_cuda_fields_missing = all(
mutable_dict[field] is None for field in dynamic_cuda_fields)
if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
mutable_dict[field] is None for field in dynamic_cuda_fields
)
if (
TORCH_AVAILABLE
and not torch.cuda.is_available()
and all_dynamic_cuda_fields_missing
):
for field in all_cuda_fields:
mutable_dict[field] = 'No CUDA'
mutable_dict[field] = "No CUDA"
if envinfo.cuda_compiled_version is None:
mutable_dict['cuda_compiled_version'] = 'None'
mutable_dict["cuda_compiled_version"] = "None"
# Replace True with Yes, False with No
mutable_dict = replace_bools(mutable_dict)
@ -826,18 +867,20 @@ def pretty_str(envinfo):
mutable_dict = replace_nones(mutable_dict)
# If either of these are '', replace with 'No relevant packages'
mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages'])
mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages'])
mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
# Tag conda and pip packages with a prefix
# If they were previously None, they'll show up as ie '[conda] Could not collect'
if mutable_dict['pip_packages']:
mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'],
'[{}] '.format(envinfo.pip_version))
if mutable_dict['conda_packages']:
mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'],
'[conda] ')
mutable_dict['cpu_info'] = envinfo.cpu_info
if mutable_dict["pip_packages"]:
mutable_dict["pip_packages"] = prepend(
mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)
)
if mutable_dict["conda_packages"]:
mutable_dict["conda_packages"] = prepend(
mutable_dict["conda_packages"], "[conda] "
)
mutable_dict["cpu_info"] = envinfo.cpu_info
return env_info_fmt.format(**mutable_dict)
@ -861,18 +904,29 @@ def main():
output = get_pretty_env_info()
print(output)
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'):
if (
TORCH_AVAILABLE
and hasattr(torch, "utils")
and hasattr(torch.utils, "_crash_handler")
):
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
if sys.platform == "linux" and os.path.exists(minidump_dir):
dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
dumps = [
os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)
]
latest = max(dumps, key=os.path.getctime)
ctime = os.path.getctime(latest)
creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S')
msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
"if this is related to your bug please include it when you file a report ***"
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
"%Y-%m-%d %H:%M:%S"
)
msg = (
"\n*** Detected a minidump at {} created on {}, ".format(
latest, creation_time
)
+ "if this is related to your bug please include it when you file a report ***"
)
print(msg, file=sys.stderr)
if __name__ == '__main__':
if __name__ == "__main__":
main()