[BE] remove torch deploy - conditionals (#158288)

This PR is part of the work to deprecate torch::deploy in OSS. Effectively, it does three things to get started:
1. Remove `test_deploy_interaction`, since we no longer need to worry about this interaction.
2. Remove all `torch._running_with_deploy` checks and always take the `False` path (this is what surfaced item 1; see the sketch below).
3. Remove `USE_DEPLOY` and always take the default path.
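
A minimal sketch of the pattern in item 2 (simplified, not the exact diff; the `_before`/`_after` helpers and the `platform_system` parameter are illustrative only): guarded call sites drop the deploy branch, while the `_running_with_deploy` stub is kept but now always reports `False`.

```python
# Minimal sketch of the change pattern, assuming nothing beyond the stdlib.

def _running_with_deploy() -> bool:
    # Kept only because some torch.package call sites still reference it;
    # after this PR it unconditionally reports "not running under deploy".
    return False


def _load_global_deps_before(platform_system: str) -> str:
    # Old shape: two reasons to bail out early.
    if _running_with_deploy() or platform_system == "Windows":
        return "skipped"
    return "loaded"


def _load_global_deps_after(platform_system: str) -> str:
    # New shape: only the Windows early return survives.
    if platform_system == "Windows":
        return "skipped"
    return "loaded"


if __name__ == "__main__":
    for system in ("Linux", "Windows"):
        # The two shapes agree because _running_with_deploy() is always False now.
        assert _load_global_deps_before(system) == _load_global_deps_after(system)
```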

Note: MyPy does fail on a number of checks here, since several older files are touched. It may be better to fix those in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
Author: PaliC
Date: 2025-07-25 17:46:23 -07:00
Committed by: PyTorch MergeBot
Parent: 5d89634ca8
Commit: 6162e650b0
22 changed files with 375 additions and 492 deletions

View File

@ -544,62 +544,6 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
class TestCustomOp(CustomOpTestCaseBase): class TestCustomOp(CustomOpTestCaseBase):
test_ns = "_test_custom_op" test_ns = "_test_custom_op"
def test_deploy_interaction(self):
# run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy
script = """
import torch
torch._running_with_deploy = lambda: True
# creating the library is a no-op, so you can DEF multiple times
m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901
m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901
m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901
# define is a no-op
m.define("foobarbaz9996(Tensor x) -> Tensor")
assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop"
def sin_override(x):
raise AssertionError("m.impl should have been a noop")
# impl is a no-op
m.impl("sin", sin_override, "CompositeImplicitAutograd")
x = torch.randn(3)
y = torch.sin(x)
# should be a no-op
@torch.library.custom_op("mylib::foobar", mutates_args={})
def foobar(x: torch.Tensor) -> torch.Tensor:
return x.sin()
# should be a no-op
@foobar.register_fake
def _(x):
return torch.empty_like(x)
# should be a no-op
m2.define("foobarbaz9996(Tensor x) -> Tensor")
# should be a no-op
@torch.library.register_fake("mylib4392::foobarbaz9996")
def _(x):
return torch.empty_like(x)
"""
script = script.strip()
env = os.environ.copy()
try:
subprocess.check_output(
[sys.executable, "-c", script],
stderr=subprocess.STDOUT,
# On Windows, opening the subprocess with the default CWD makes `import torch`
# fail, so just set CWD to this script's directory
cwd=os.path.dirname(os.path.realpath(__file__)),
env=env,
)
except subprocess.CalledProcessError as e:
self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8")))
@requires_compile @requires_compile
def test_functionalize_error(self): def test_functionalize_error(self):
with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:

View File

@ -3603,8 +3603,8 @@ class TestSparseCompressedTritonKernels(TestCase):
@onlyCUDA @onlyCUDA
@dtypes(torch.half, torch.bfloat16, torch.float) @dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(), @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU),
"Skipped for deploy and internal with remote GPUs") "Skipped for internal with remote GPUs")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial from functools import partial
from torch.sparse._triton_ops import bsr_dense_mm from torch.sparse._triton_ops import bsr_dense_mm
@ -3680,8 +3680,8 @@ class TestSparseCompressedTritonKernels(TestCase):
@onlyCUDA @onlyCUDA
@dtypes(torch.half) @dtypes(torch.half)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(), @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU,
"Skipped for deploy and internal with remote GPUs") "Skipped for internal with remote GPUs")
def test_triton_bsr_dense_bmm_error_messages(self, device, dtype): def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
from torch.sparse._triton_ops import bsr_dense_mm from torch.sparse._triton_ops import bsr_dense_mm

View File

@ -34,18 +34,16 @@ from typing import (
) )
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
from . import version
if TYPE_CHECKING: if TYPE_CHECKING:
from .types import Device, IntLikeType from .types import Device, IntLikeType
# multipy/deploy is setting this import before importing torch, this is the most # codespell:ignore multipy # As a bunch of torch.packages internally still have this check
# reliable way we have to detect if we're running within deploy. # we need to keep this. @todo: Remove tests that rely on this check as
# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 # codespell:ignore multipy # noqa: B950 # they are likely stale.
def _running_with_deploy() -> builtins.bool: def _running_with_deploy() -> builtins.bool:
return sys.modules.get("torch._meta_registrations", None) is object return False
from torch._utils import ( from torch._utils import (
@ -60,14 +58,9 @@ from torch._utils_internal import (
USE_GLOBAL_DEPS, USE_GLOBAL_DEPS,
USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_RTLD_GLOBAL_WITH_LIBTORCH,
) )
from torch.torch_version import __version__ as __version__
# TODO(torch_deploy) figure out how to freeze version.py in fbcode build
if _running_with_deploy():
__version__ = "torch-deploy-1.8"
else:
from torch.torch_version import __version__ as __version__
__all__ = [ __all__ = [
"BoolStorage", "BoolStorage",
"BoolTensor", "BoolTensor",
@ -317,7 +310,7 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
# See Note [Global dependencies] # See Note [Global dependencies]
def _load_global_deps() -> None: def _load_global_deps() -> None:
if _running_with_deploy() or platform.system() == "Windows": if platform.system() == "Windows":
return return
# Determine the file extension based on the platform # Determine the file extension based on the platform
@ -381,7 +374,7 @@ def _load_global_deps() -> None:
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
_running_with_deploy() or platform.system() != "Windows" platform.system() != "Windows"
): ):
# Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
# few circumstances: # few circumstances:
@ -2082,7 +2075,7 @@ from torch.serialization import load, save
# Shared memory manager needs to know the exact location of manager executable # Shared memory manager needs to know the exact location of manager executable
def _manager_path(): def _manager_path():
if _running_with_deploy() or platform.system() == "Windows": if platform.system() == "Windows":
return b"" return b""
path = get_file_path("torch", "bin", "torch_shm_manager") path = get_file_path("torch", "bin", "torch_shm_manager")
prepare_multiprocessing_environment(get_file_path("torch")) prepare_multiprocessing_environment(get_file_path("torch"))
@ -2687,21 +2680,21 @@ from torch import fx as fx
# Register MPS specific decomps # Register MPS specific decomps
torch.backends.mps._init() torch.backends.mps._init()
if not _running_with_deploy(): from torch import compiler as compiler
from torch import compiler as compiler
class _TritonLibrary:
lib = torch.library.Library("triton", "DEF")
ops_table: dict[tuple[str, str], _Callable] = {}
@classmethod class _TritonLibrary:
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): lib = torch.library.Library("triton", "DEF")
if (op_key, dispatch_key) not in cls.ops_table: ops_table: dict[tuple[str, str], _Callable] = {}
cls.lib.define(full_schema)
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
cls.ops_table[(op_key, dispatch_key)] = op_impl
return cls.ops_table[(op_key, dispatch_key)] @classmethod
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
if (op_key, dispatch_key) not in cls.ops_table:
cls.lib.define(full_schema)
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
cls.ops_table[(op_key, dispatch_key)] = op_impl
return cls.ops_table[(op_key, dispatch_key)]
# Deprecated attributes # Deprecated attributes

View File

@ -49,47 +49,46 @@ Tensor = torch.Tensor
__all__ = ["trace_wrapped"] __all__ = ["trace_wrapped"]
if not torch._running_with_deploy(): @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc]
# torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore def zeros_and_scatter(
shape: list[int],
indices: list[Tensor],
vals: Tensor,
) -> Tensor:
"""Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass"""
grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype)
return torch.ops.aten.index_put(grad, indices, vals, accumulate=True)
@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc]
def zeros_and_scatter(
shape: list[int],
indices: list[Tensor],
vals: Tensor,
) -> Tensor:
"""Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass"""
grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype)
return torch.ops.aten.index_put(grad, indices, vals, accumulate=True)
@zeros_and_scatter.register_fake # type: ignore[misc] @zeros_and_scatter.register_fake # type: ignore[misc]
def _( def _(
shape: list[int], shape: list[int],
indices: list[Tensor], indices: list[Tensor],
vals: Tensor, vals: Tensor,
) -> Tensor: ) -> Tensor:
return vals.new_empty(shape) return vals.new_empty(shape)
@zeros_and_scatter.register_vmap # type: ignore[misc]
def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def]
"""The batching rule is special in that it returns a tensor that is not batched"""
indices_indims = indims[1]
expanded_indices = []
for idx, idx_indim in zip(indices, indices_indims):
# The index is not a being batched, we should unsqueeze and expand to val
if idx_indim is None:
expanded_indices.append(idx.expand(value.shape))
else:
# the index is being part of the vmap batch, it should be the same size as val
assert idx.shape == value.shape
expanded_indices.append(idx)
out = torch.ops.flex_lib.zeros_and_scatter( @zeros_and_scatter.register_vmap # type: ignore[misc]
shape, def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def]
expanded_indices, """The batching rule is special in that it returns a tensor that is not batched"""
value, indices_indims = indims[1]
) expanded_indices = []
return out, None for idx, idx_indim in zip(indices, indices_indims):
# The index is not a being batched, we should unsqueeze and expand to val
if idx_indim is None:
expanded_indices.append(idx.expand(value.shape))
else:
# the index is being part of the vmap batch, it should be the same size as val
assert idx.shape == value.shape
expanded_indices.append(idx)
out = torch.ops.flex_lib.zeros_and_scatter(
shape,
expanded_indices,
value,
)
return out, None
class ModIndex(torch.autograd.Function): class ModIndex(torch.autograd.Function):

View File

@ -2411,7 +2411,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._lowrank.svd_lowrank", "torch._lowrank.svd_lowrank",
"torch._preload_cuda_deps", "torch._preload_cuda_deps",
"torch._register_device_module", "torch._register_device_module",
"torch._running_with_deploy",
"torch._utils._dummy_type", "torch._utils._dummy_type",
"torch._utils._flatten_dense_tensors", "torch._utils._flatten_dense_tensors",
"torch._utils._unflatten_dense_tensors", "torch._utils._unflatten_dense_tensors",

View File

@ -5,25 +5,24 @@ from torch import Tensor
from torch.autograd import Function from torch.autograd import Function
if not torch._running_with_deploy(): _test_lib_def = torch.library.Library("_inductor_test", "DEF")
_test_lib_def = torch.library.Library("_inductor_test", "DEF") _test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
_test_lib_def.define(
"realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag
)
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") _test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
class Realize(Function):
@staticmethod
def forward(ctx: object, x: Tensor) -> Tensor:
return torch.ops._inductor_test.realize(x)
@staticmethod class Realize(Function):
# types need to stay consistent with _SingleLevelFunction @staticmethod
def backward(ctx: Any, *grad_output: Any) -> Any: def forward(ctx: object, x: Tensor) -> Tensor:
return grad_output[0] return torch.ops._inductor_test.realize(x)
def realize(x: Tensor) -> Tensor: @staticmethod
return Realize.apply(x) # types need to stay consistent with _SingleLevelFunction
def backward(ctx: Any, *grad_output: Any) -> Any:
return grad_output[0]
def realize(x: Tensor) -> Tensor:
return Realize.apply(x)

View File

@ -595,10 +595,6 @@ class CustomOpDef:
self._setup_context_fn = setup_context self._setup_context_fn = setup_context
def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
if torch._running_with_deploy():
utils.warn_deploy(stacklevel=5)
return
lib = self._lib lib = self._lib
schema_str = self._name + self._schema schema_str = self._name + self._schema
cpp_schema = _C.parse_schema(schema_str) cpp_schema = _C.parse_schema(schema_str)

View File

@ -2,7 +2,6 @@
import dataclasses import dataclasses
import inspect import inspect
import sys import sys
import warnings
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
from typing import Any, Callable, Union from typing import Any, Callable, Union
@ -12,15 +11,6 @@ from torch import _C, _utils_internal
from torch._ops import OpOverload from torch._ops import OpOverload
def warn_deploy(stacklevel=3):
warnings.warn(
"Python torch.library APIs do nothing under torch::deploy (multipy). " # codespell:ignore multipy
"Please instead use C++ custom operator registration APIs.",
RuntimeWarning,
stacklevel=stacklevel,
)
@dataclasses.dataclass @dataclasses.dataclass
class Kernel: class Kernel:
"""Models a (function, source location)""" """Models a (function, source location)"""

View File

@ -1478,9 +1478,6 @@ class _Ops(types.ModuleType):
Args: Args:
path (str): A path to a shared library to load. path (str): A path to a shared library to load.
""" """
if torch._running_with_deploy():
return
path = _utils_internal.resolve_library_path(path) path = _utils_internal.resolve_library_path(path)
with dl_open_guard(): with dl_open_guard():
# Import the shared library into the process, thus running its # Import the shared library into the process, thus running its

View File

@ -33,16 +33,10 @@ if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
# use is the FB build environment, where this source file is replaced # use is the FB build environment, where this source file is replaced
# by an equivalent. # by an equivalent.
if torch._running_with_deploy(): if os.path.basename(os.path.dirname(__file__)) == "shared":
# __file__ is meaningless in the context of frozen torch used in torch deploy. torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# setting empty torch_parent should allow below functions to operate without crashing,
# but it's unclear if there is a valid use case for them in the context of deploy.
torch_parent = ""
else: else:
if os.path.basename(os.path.dirname(__file__)) == "shared": torch_parent = os.path.dirname(os.path.dirname(__file__))
torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
else:
torch_parent = os.path.dirname(os.path.dirname(__file__))
def get_file_path(*path_components: str) -> str: def get_file_path(*path_components: str) -> str:

View File

@ -331,13 +331,9 @@ void initLazyBindings(PyObject* module) {
// So far this problem has only been observed internally, so we will just // So far this problem has only been observed internally, so we will just
// block it off there. // block it off there.
#if !(defined(USE_DEPLOY))
// When libtorch_python is loaded, we register the python frame getter // When libtorch_python is loaded, we register the python frame getter
// otherwise, debug util simply omits python frames // otherwise, debug util simply omits python frames
GetPythonFramesFunction() = GetPythonFrames; GetPythonFramesFunction() = GetPythonFrames;
#endif // USE_DEPLOY
} }
} // namespace torch::lazy } // namespace torch::lazy

View File

@ -187,15 +187,6 @@ class PythonKernelHolder : public c10::OperatorKernel {
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g; py::gil_scoped_acquire g;
// Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy
// so stop forcing hermetic mode unconditionally in all situations when
// you're using multipy. // codespell:ignore multipy
// Eventually just delete this entirely. (Note that you may break
// multipy anyway this way with dispatcher // codespell:ignore multipy
// registered functions that require hermetic to be off.)
#if defined(USE_DEPLOY)
EnableHermeticPyObject g2;
#endif
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto func = auto func =
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter())); py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));

View File

@ -1693,9 +1693,6 @@ class _WrappedTritonKernel:
def _register_triton_kernels(): def _register_triton_kernels():
if torch._running_with_deploy():
return
@_WrappedTritonKernel @_WrappedTritonKernel
def kernel_impl(*args, **kwargs): def kernel_impl(*args, **kwargs):
from torch.sparse._triton_ops import bsr_dense_mm from torch.sparse._triton_ops import bsr_dense_mm

View File

@ -19,22 +19,16 @@ except ImportError:
from torch.utils._pytree import tree_map_only # type: ignore[no-redef] from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
if torch._running_with_deploy(): try:
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:
warnings.warn(
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
)
def is_torchdynamo_compiling(): def is_torchdynamo_compiling(): # type: ignore[misc]
"""Can't import torchdynamo in torchdeploy builds currently.""" return False
return False return False
else:
try:
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:
warnings.warn(
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
)
def is_torchdynamo_compiling():
return False
""" """
@ -985,66 +979,58 @@ def _reduce_scatter_tensor_coalesced_native_meta(
] ]
if not torch._running_with_deploy(): # Library MUST be defined at module scope or it doesn't work
# Library MUST be defined at module scope or it doesn't work lib_impl = torch.library.Library("_c10d_functional", "IMPL")
# Creating a "DEF" Library always crashes torch::deploy so we create our lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
# Library instances here guarded against running inside it lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
lib_impl = torch.library.Library("_c10d_functional", "IMPL") lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") lib_impl.impl(
lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") )
lib_impl.impl( lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
"all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" lib_impl.impl(
) "all_gather_into_tensor_coalesced",
lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") _all_gather_into_tensor_coalesced_native_meta,
lib_impl.impl( "Meta",
"all_gather_into_tensor_coalesced", )
_all_gather_into_tensor_coalesced_native_meta, lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
"Meta", lib_impl.impl(
) "reduce_scatter_tensor_coalesced",
lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") _reduce_scatter_tensor_coalesced_native_meta,
lib_impl.impl( "Meta",
"reduce_scatter_tensor_coalesced", )
_reduce_scatter_tensor_coalesced_native_meta, lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
"Meta", lib_impl.impl("broadcast", _broadcast_meta, "Meta")
) lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
lib_impl.impl("broadcast", _broadcast_meta, "Meta")
lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
# mark these ops has side effect so that they won't be removed by DCE # mark these ops has side effect so that they won't be removed by DCE
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
# Register legacy ops for backward compatibility # Register legacy ops for backward compatibility
# TODO(yifu): remove these in functional collective beta release # TODO(yifu): remove these in functional collective beta release
legacy_lib = torch.library.Library("c10d_functional", "DEF") legacy_lib = torch.library.Library("c10d_functional", "DEF")
legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
ops_defs = [ ops_defs = [
"broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
"all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
"all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
"wait_tensor(Tensor self) -> Tensor", "wait_tensor(Tensor self) -> Tensor",
"all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
"all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
"reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
"reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
"all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950
] ]
my_module = sys.modules[__name__] my_module = sys.modules[__name__]
for op_def in ops_defs: for op_def in ops_defs:
op_name = op_def[0 : op_def.index("(")] op_name = op_def[0 : op_def.index("(")]
backend_impl = getattr(fun_col_impl, f"_{op_name}") backend_impl = getattr(fun_col_impl, f"_{op_name}")
legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
else:
warnings.warn(
"PyTorch Distributed functional collectives do not work with torch::deploy."
)
""" """

View File

@ -63,10 +63,9 @@ _META_FUNCTIONS = {
"recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
} }
if not torch._running_with_deploy(): lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 for op, meta_func in _META_FUNCTIONS.items():
for op, meta_func in _META_FUNCTIONS.items(): lib_impl.impl(op, meta_func, "Meta")
lib_impl.impl(op, meta_func, "Meta")
# List of collective operation functions including functional collectives # List of collective operation functions including functional collectives
# Note: The following collectives might be deprecated soon hence not adding them # Note: The following collectives might be deprecated soon hence not adding them

View File

@ -15,32 +15,24 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec
_compiled_autograd_enabled: bool = False _compiled_autograd_enabled: bool = False
if torch._running_with_deploy():
def detect_compiled_autograd(): def detect_compiled_autograd():
pass assert not torch.compiler.is_compiling(), (
"`detect_compiled_autograd()` is designed to be called in eager mode"
)
global _compiled_autograd_enabled
import torch._dynamo.compiled_autograd as ca
def compiled_autograd_enabled(): _compiled_autograd_enabled = (
return False ca.compiled_autograd_enabled
or ca.compiled_autograd_enabled_force_eager
or ca.in_compiled_autograd_region
)
else:
def detect_compiled_autograd(): def compiled_autograd_enabled():
assert not torch.compiler.is_compiling(), ( global _compiled_autograd_enabled
"`detect_compiled_autograd()` is designed to be called in eager mode" return _compiled_autograd_enabled
)
global _compiled_autograd_enabled
import torch._dynamo.compiled_autograd as ca
_compiled_autograd_enabled = (
ca.compiled_autograd_enabled
or ca.compiled_autograd_enabled_force_eager
or ca.in_compiled_autograd_region
)
def compiled_autograd_enabled():
global _compiled_autograd_enabled
return _compiled_autograd_enabled
@dataclass @dataclass

View File

@ -140,8 +140,7 @@ def copy__functionalize(tensor, data):
torch.ops.fsdp.copy_.default(tensor_inner, data_inner) torch.ops.fsdp.copy_.default(tensor_inner, data_inner)
if not torch._running_with_deploy(): torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)
torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)
class ShardedState(Enum): class ShardedState(Enum):

View File

@ -25,26 +25,17 @@ from torch.distributed.distributed_c10d import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if not torch._running_with_deploy(): @torch.library.register_fake("_dtensor::shard_dim_alltoall")
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
group_size = _get_group_size_by_name(group_name)
stacked_list = [torch.empty_like(input) for _ in range(group_size)]
group = _resolve_process_group(group_name)
group_rank = get_group_rank(group, get_rank())
@torch.library.register_fake("_dtensor::shard_dim_alltoall") return (
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): torch.cat(stacked_list, dim=gather_dim)
group_size = _get_group_size_by_name(group_name) .chunk(group_size, dim=shard_dim)[group_rank]
stacked_list = [torch.empty_like(input) for _ in range(group_size)] .contiguous()
group = _resolve_process_group(group_name)
group_rank = get_group_rank(group, get_rank())
return (
torch.cat(stacked_list, dim=gather_dim)
.chunk(group_size, dim=shard_dim)[group_rank]
.contiguous()
)
else:
import warnings
warnings.warn(
"PyTorch Distributed functional collectives do not work with torch::deploy."
) )

View File

@ -102,9 +102,6 @@ class Library:
ns, ns,
" is a reserved namespace. Please try creating a library with another name.", " is a reserved namespace. Please try creating a library with another name.",
) )
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
frame = traceback.extract_stack(limit=3)[0] frame = traceback.extract_stack(limit=3)[0]
filename, lineno = frame.filename, frame.lineno filename, lineno = frame.filename, frame.lineno
@ -156,9 +153,6 @@ class Library:
>>> my_lib = Library("mylib", "DEF") >>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor") >>> my_lib.define("sum(Tensor self) -> Tensor")
""" """
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
# AliasAnalysis type in C++ # AliasAnalysis type in C++
@ -191,9 +185,6 @@ class Library:
def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False):
r"""Registers the fake impl for an operator defined in the library.""" r"""Registers the fake impl for an operator defined in the library."""
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
source = torch._library.utils.get_source(_stacklevel + 1) source = torch._library.utils.get_source(_stacklevel + 1)
frame = sys._getframe(_stacklevel) frame = sys._getframe(_stacklevel)
@ -237,9 +228,6 @@ class Library:
If it is a TorchDispatchMode, we expect fn to have the following signature: If it is a TorchDispatchMode, we expect fn to have the following signature:
(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
""" """
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
qualname = f"{self.ns}::{op_name}" qualname = f"{self.ns}::{op_name}"
entry = torch._library.simple_registry.singleton.find(qualname) entry = torch._library.simple_registry.singleton.find(qualname)
@ -259,9 +247,6 @@ class Library:
>>> my_lib = Library("aten", "IMPL") >>> my_lib = Library("aten", "IMPL")
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
""" """
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
if dispatch_key == "": if dispatch_key == "":
dispatch_key = self.dispatch_key dispatch_key = self.dispatch_key
@ -324,9 +309,6 @@ class Library:
>>> return self * (1 / other) >>> return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU") >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
""" """
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
if not callable(fn): if not callable(fn):
raise TypeError( raise TypeError(
@ -409,9 +391,6 @@ class Library:
>>> # ... >>> # ...
>>> my_lib.fallback(fallback_kernel, "Autocast") >>> my_lib.fallback(fallback_kernel, "Autocast")
""" """
if torch._running_with_deploy():
_library.utils.warn_deploy()
return
if dispatch_key == "": if dispatch_key == "":
dispatch_key = self.dispatch_key dispatch_key = self.dispatch_key

View File

@ -29,13 +29,7 @@ def set_module(obj, mod):
obj.__module__ = mod obj.__module__ = mod
if torch._running_with_deploy(): cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake")
# not valid inside torch_deploy interpreter, no paths exists for frozen modules
cmake_prefix_path = None
else:
cmake_prefix_path = _osp.join(
_osp.dirname(_osp.dirname(__file__)), "share", "cmake"
)
def swap_tensors(t1, t2): def swap_tensors(t1, t2):

View File

@ -3,8 +3,6 @@ import importlib.util
from types import ModuleType from types import ModuleType
from typing import Optional from typing import Optional
import torch
def _check_module_exists(name: str) -> bool: def _check_module_exists(name: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without** r"""Returns if a top-level module with :attr:`name` exists *without**
@ -22,11 +20,7 @@ def _check_module_exists(name: str) -> bool:
@functools.lru_cache @functools.lru_cache
def dill_available() -> bool: def dill_available() -> bool:
return ( return _check_module_exists("dill")
_check_module_exists("dill")
# dill fails to import under torchdeploy
and not torch._running_with_deploy()
)
@functools.lru_cache @functools.lru_cache

View File

@ -6,49 +6,53 @@
import datetime import datetime
import json import json
import locale import locale
import os
import re import re
import subprocess import subprocess
import sys import sys
import os
from typing import cast as _cast
from collections import namedtuple from collections import namedtuple
from typing import cast as _cast
try: try:
import torch import torch
TORCH_AVAILABLE = True TORCH_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError): except (ImportError, NameError, AttributeError, OSError):
TORCH_AVAILABLE = False TORCH_AVAILABLE = False
# System Environment Information # System Environment Information
SystemEnv = namedtuple('SystemEnv', [ SystemEnv = namedtuple(
'torch_version', "SystemEnv",
'is_debug_build', [
'cuda_compiled_version', "torch_version",
'gcc_version', "is_debug_build",
'clang_version', "cuda_compiled_version",
'cmake_version', "gcc_version",
'os', "clang_version",
'libc_version', "cmake_version",
'python_version', "os",
'python_platform', "libc_version",
'is_cuda_available', "python_version",
'cuda_runtime_version', "python_platform",
'cuda_module_loading', "is_cuda_available",
'nvidia_driver_version', "cuda_runtime_version",
'nvidia_gpu_models', "cuda_module_loading",
'cudnn_version', "nvidia_driver_version",
'is_xpu_available', "nvidia_gpu_models",
'pip_version', # 'pip' or 'pip3' "cudnn_version",
'pip_packages', "is_xpu_available",
'conda_packages', "pip_version", # 'pip' or 'pip3'
'hip_compiled_version', "pip_packages",
'hip_runtime_version', "conda_packages",
'miopen_runtime_version', "hip_compiled_version",
'caching_allocator_config', "hip_runtime_version",
'is_xnnpack_available', "miopen_runtime_version",
'cpu_info', "caching_allocator_config",
]) "is_xnnpack_available",
"cpu_info",
],
)
COMMON_PATTERNS = [ COMMON_PATTERNS = [
"torch", "torch",
@ -116,12 +120,13 @@ PIP_PATTERNS = [
def run(command): def run(command):
"""Return (return-code, stdout, stderr).""" """Return (return-code, stdout, stderr)."""
shell = True if type(command) is str else False shell = True if type(command) is str else False
p = subprocess.Popen(command, stdout=subprocess.PIPE, p = subprocess.Popen(
stderr=subprocess.PIPE, shell=shell) command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
)
raw_output, raw_err = p.communicate() raw_output, raw_err = p.communicate()
rc = p.returncode rc = p.returncode
if get_platform() == 'win32': if get_platform() == "win32":
enc = 'oem' enc = "oem"
else: else:
enc = locale.getpreferredencoding() enc = locale.getpreferredencoding()
output = raw_output.decode(enc) output = raw_output.decode(enc)
@ -147,18 +152,19 @@ def run_and_parse_first_match(run_lambda, command, regex):
return None return None
return match.group(1) return match.group(1)
def run_and_return_first_line(run_lambda, command): def run_and_return_first_line(run_lambda, command):
"""Run command using run_lambda and returns first line if output is not empty.""" """Run command using run_lambda and returns first line if output is not empty."""
rc, out, _ = run_lambda(command) rc, out, _ = run_lambda(command)
if rc != 0: if rc != 0:
return None return None
return out.split('\n')[0] return out.split("\n")[0]
def get_conda_packages(run_lambda, patterns=None): def get_conda_packages(run_lambda, patterns=None):
if patterns is None: if patterns is None:
patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS
conda = os.environ.get('CONDA_EXE', 'conda') conda = os.environ.get("CONDA_EXE", "conda")
out = run_and_read_all(run_lambda, "{} list".format(conda)) out = run_and_read_all(run_lambda, "{} list".format(conda))
if out is None: if out is None:
return out return out
@ -166,32 +172,40 @@ def get_conda_packages(run_lambda, patterns=None):
return "\n".join( return "\n".join(
line line
for line in out.splitlines() for line in out.splitlines()
if not line.startswith("#") if not line.startswith("#") and any(name in line for name in patterns)
and any(name in line for name in patterns)
) )
def get_gcc_version(run_lambda): def get_gcc_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
def get_clang_version(run_lambda): def get_clang_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') return run_and_parse_first_match(
run_lambda, "clang --version", r"clang version (.*)"
)
def get_cmake_version(run_lambda): def get_cmake_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
def get_nvidia_driver_version(run_lambda): def get_nvidia_driver_version(run_lambda):
if get_platform() == 'darwin': if get_platform() == "darwin":
cmd = 'kextstat | grep -i cuda' cmd = "kextstat | grep -i cuda"
return run_and_parse_first_match(run_lambda, cmd, return run_and_parse_first_match(
r'com[.]nvidia[.]CUDA [(](.*?)[)]') run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]"
)
smi = get_nvidia_smi() smi = get_nvidia_smi()
return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
def get_gpu_info(run_lambda): def get_gpu_info(run_lambda):
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): if get_platform() == "darwin" or (
TORCH_AVAILABLE
and hasattr(torch.version, "hip")
and torch.version.hip is not None
):
if TORCH_AVAILABLE and torch.cuda.is_available(): if TORCH_AVAILABLE and torch.cuda.is_available():
if torch.version.hip is not None: if torch.version.hip is not None:
prop = torch.cuda.get_device_properties(0) prop = torch.cuda.get_device_properties(0)
@ -204,42 +218,42 @@ def get_gpu_info(run_lambda):
return torch.cuda.get_device_name(None) + gcnArch return torch.cuda.get_device_name(None) + gcnArch
return None return None
smi = get_nvidia_smi() smi = get_nvidia_smi()
uuid_regex = re.compile(r' \(UUID: .+?\)') uuid_regex = re.compile(r" \(UUID: .+?\)")
rc, out, _ = run_lambda(smi + ' -L') rc, out, _ = run_lambda(smi + " -L")
if rc != 0: if rc != 0:
return None return None
# Anonymize GPUs by removing their UUID # Anonymize GPUs by removing their UUID
return re.sub(uuid_regex, '', out) return re.sub(uuid_regex, "", out)
def get_running_cuda_version(run_lambda): def get_running_cuda_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
def get_cudnn_version(run_lambda): def get_cudnn_version(run_lambda):
"""Return a list of libcudnn.so; it's hard to tell which one is being used.""" """Return a list of libcudnn.so; it's hard to tell which one is being used."""
if get_platform() == 'win32': if get_platform() == "win32":
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
where_cmd = os.path.join(system_root, 'System32', 'where') where_cmd = os.path.join(system_root, "System32", "where")
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
elif get_platform() == 'darwin': elif get_platform() == "darwin":
# CUDA libraries and drivers can be found in /usr/local/cuda/. See # CUDA libraries and drivers can be found in /usr/local/cuda/. See
# https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation
# https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere. # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
else: else:
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
rc, out, _ = run_lambda(cudnn_cmd) rc, out, _ = run_lambda(cudnn_cmd)
# find will return 1 if there are permission errors or if not found # find will return 1 if there are permission errors or if not found
if len(out) == 0 or (rc != 1 and rc != 0): if len(out) == 0 or (rc != 1 and rc != 0):
l = os.environ.get('CUDNN_LIBRARY') l = os.environ.get("CUDNN_LIBRARY")
if l is not None and os.path.isfile(l): if l is not None and os.path.isfile(l):
return os.path.realpath(l) return os.path.realpath(l)
return None return None
files_set = set() files_set = set()
for fn in out.split('\n'): for fn in out.split("\n"):
fn = os.path.realpath(fn) # eliminate symbolic links fn = os.path.realpath(fn) # eliminate symbolic links
if os.path.isfile(fn): if os.path.isfile(fn):
files_set.add(fn) files_set.add(fn)
@ -249,18 +263,20 @@ def get_cudnn_version(run_lambda):
files = sorted(files_set) files = sorted(files_set)
if len(files) == 1: if len(files) == 1:
return files[0] return files[0]
result = '\n'.join(files) result = "\n".join(files)
return 'Probably one of the following:\n{}'.format(result) return "Probably one of the following:\n{}".format(result)
def get_nvidia_smi(): def get_nvidia_smi():
# Note: nvidia-smi is currently available only on Windows and Linux # Note: nvidia-smi is currently available only on Windows and Linux
smi = 'nvidia-smi' smi = "nvidia-smi"
if get_platform() == 'win32': if get_platform() == "win32":
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) legacy_path = os.path.join(
new_path = os.path.join(system_root, 'System32', smi) program_files_root, "NVIDIA Corporation", "NVSMI", smi
)
new_path = os.path.join(system_root, "System32", smi)
smis = [new_path, legacy_path] smis = [new_path, legacy_path]
for candidate_smi in smis: for candidate_smi in smis:
if os.path.exists(candidate_smi): if os.path.exists(candidate_smi):
@ -411,7 +427,9 @@ def get_intel_gpu_detected(run_lambda):
if device_count == 0: if device_count == 0:
return "N/A" return "N/A"
devices = [f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)] devices = [
f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)
]
return "\n".join(devices) return "\n".join(devices)
@ -490,11 +508,12 @@ def get_intel_gpu_detected(run_lambda):
# ProcessorType=3 # ProcessorType=3
# Revision=27142 # Revision=27142
def get_cpu_info(run_lambda): def get_cpu_info(run_lambda):
rc, out, err = 0, '', '' rc, out, err = 0, "", ""
if get_platform() == 'linux': if get_platform() == "linux":
rc, out, err = run_lambda('lscpu') rc, out, err = run_lambda("lscpu")
elif get_platform() == 'win32': elif get_platform() == "win32":
rc, out, err = run_lambda( rc, out, err = run_lambda(
'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\
Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\
@ -514,9 +533,9 @@ def get_cpu_info(run_lambda):
lst.append(out) lst.append(out)
lst.append(str(e)) lst.append(str(e))
out = "\n".join(lst) out = "\n".join(lst)
elif get_platform() == 'darwin': elif get_platform() == "darwin":
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
cpu_info = 'None' cpu_info = "None"
if rc == 0: if rc == 0:
cpu_info = out cpu_info = out
else: else:
@ -525,20 +544,20 @@ def get_cpu_info(run_lambda):
def get_platform(): def get_platform():
if sys.platform.startswith('linux'): if sys.platform.startswith("linux"):
return 'linux' return "linux"
elif sys.platform.startswith('win32'): elif sys.platform.startswith("win32"):
return 'win32' return "win32"
elif sys.platform.startswith('cygwin'): elif sys.platform.startswith("cygwin"):
return 'cygwin' return "cygwin"
elif sys.platform.startswith('darwin'): elif sys.platform.startswith("darwin"):
return 'darwin' return "darwin"
else: else:
return sys.platform return sys.platform
def get_mac_version(run_lambda): def get_mac_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
def get_windows_version(run_lambda): def get_windows_version(run_lambda):
@ -556,39 +575,43 @@ def get_windows_version(run_lambda):
def get_lsb_version(run_lambda): def get_lsb_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') return run_and_parse_first_match(
run_lambda, "lsb_release -a", r"Description:\t(.*)"
)
def check_release_file(run_lambda): def check_release_file(run_lambda):
return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', return run_and_parse_first_match(
r'PRETTY_NAME="(.*)"') run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"'
)
def get_os(run_lambda): def get_os(run_lambda):
from platform import machine from platform import machine
platform = get_platform() platform = get_platform()
if platform in ["win32", "cygwin"]: if platform in ["win32", "cygwin"]:
return get_windows_version(run_lambda) return get_windows_version(run_lambda)
if platform == 'darwin': if platform == "darwin":
version = get_mac_version(run_lambda) version = get_mac_version(run_lambda)
if version is None: if version is None:
return None return None
return 'macOS {} ({})'.format(version, machine()) return "macOS {} ({})".format(version, machine())
if platform == 'linux': if platform == "linux":
# Ubuntu/Debian based # Ubuntu/Debian based
desc = get_lsb_version(run_lambda) desc = get_lsb_version(run_lambda)
if desc is not None: if desc is not None:
return '{} ({})'.format(desc, machine()) return "{} ({})".format(desc, machine())
# Try reading /etc/*-release # Try reading /etc/*-release
desc = check_release_file(run_lambda) desc = check_release_file(run_lambda)
if desc is not None: if desc is not None:
return '{} ({})'.format(desc, machine()) return "{} ({})".format(desc, machine())
return '{} ({})'.format(platform, machine()) return "{} ({})".format(platform, machine())
# Unknown platform # Unknown platform
return platform return platform
@ -596,14 +619,16 @@ def get_os(run_lambda):
def get_python_platform(): def get_python_platform():
import platform import platform
return platform.platform() return platform.platform()
def get_libc_version(): def get_libc_version():
import platform import platform
if get_platform() != 'linux':
return 'N/A' if get_platform() != "linux":
return '-'.join(platform.libc_ver()) return "N/A"
return "-".join(platform.libc_ver())
def get_pip_packages(run_lambda, patterns=None): def get_pip_packages(run_lambda, patterns=None):
@ -611,35 +636,35 @@ def get_pip_packages(run_lambda, patterns=None):
if patterns is None: if patterns is None:
patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS
pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' pip_version = "pip3" if sys.version_info.major == 3 else "pip"
os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
# People generally have pip as `pip` or `pip3` # People generally have pip as `pip` or `pip3`
# But here it is invoked as `python -mpip` # But here it is invoked as `python -mpip`
out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) out = run_and_read_all(
run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"]
)
if out is None: if out is None:
return pip_version, out return pip_version, out
filtered_out = '\n'.join( filtered_out = "\n".join(
line line for line in out.splitlines() if any(name in line for name in patterns)
for line in out.splitlines()
if any(name in line for name in patterns)
) )
return pip_version, filtered_out return pip_version, filtered_out
def get_cachingallocator_config(): def get_cachingallocator_config():
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if not ca_config: if not ca_config:
ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "")
return ca_config return ca_config
def get_cuda_module_loading_config(): def get_cuda_module_loading_config():
if TORCH_AVAILABLE and torch.cuda.is_available(): if TORCH_AVAILABLE and torch.cuda.is_available():
torch.cuda.init() torch.cuda.init()
config = os.environ.get('CUDA_MODULE_LOADING', '') config = os.environ.get("CUDA_MODULE_LOADING", "")
return config return config
else: else:
return "N/A" return "N/A"
@ -648,10 +673,12 @@ def get_cuda_module_loading_config():
def is_xnnpack_available(): def is_xnnpack_available():
if TORCH_AVAILABLE: if TORCH_AVAILABLE:
import torch.backends.xnnpack import torch.backends.xnnpack
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
else: else:
return "N/A" return "N/A"
def get_env_info(): def get_env_info():
""" """
Collects environment information to aid in debugging. Collects environment information to aid in debugging.
@ -678,26 +705,31 @@ def get_env_info():
cuda_version_str = torch.version.cuda cuda_version_str = torch.version.cuda
xpu_available_str = str(torch.xpu.is_available()) xpu_available_str = str(torch.xpu.is_available())
if torch.xpu.is_available(): if torch.xpu.is_available():
xpu_available_str = f'{xpu_available_str}\n' + \ xpu_available_str = (
f'XPU used to build PyTorch: {torch.version.xpu}\n' + \ f"{xpu_available_str}\n"
f'Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n' + \ + f"XPU used to build PyTorch: {torch.version.xpu}\n"
f'Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n' + \ + f"Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n"
f'Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}' + f"Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n"
if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + f"Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}"
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' )
if (
not hasattr(torch.version, "hip") or torch.version.hip is None
): # cuda version
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
else: # HIP version else: # HIP version
def get_version_or_na(cfg, prefix): def get_version_or_na(cfg, prefix):
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
return _lst[0] if _lst else 'N/A' return _lst[0] if _lst else "N/A"
cfg = torch._C._show_config().split('\n') cfg = torch._C._show_config().split("\n")
hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
cuda_version_str = 'N/A' cuda_version_str = "N/A"
hip_compiled_version = torch.version.hip hip_compiled_version = torch.version.hip
else: else:
version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = 'N/A' version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment]
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
sys_version = sys.version.replace("\n", " ") sys_version = sys.version.replace("\n", " ")
@ -706,7 +738,9 @@ def get_env_info():
return SystemEnv( return SystemEnv(
torch_version=version_str, torch_version=version_str,
is_debug_build=debug_mode_str, is_debug_build=debug_mode_str,
python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), python_version="{} ({}-bit runtime)".format(
sys_version, sys.maxsize.bit_length() + 1
),
python_platform=get_python_platform(), python_platform=get_python_platform(),
is_cuda_available=cuda_available_str, is_cuda_available=cuda_available_str,
cuda_compiled_version=cuda_version_str, cuda_compiled_version=cuda_version_str,
@ -732,6 +766,7 @@ def get_env_info():
cpu_info=get_cpu_info(run_lambda), cpu_info=get_cpu_info(run_lambda),
) )
env_info_fmt = """ env_info_fmt = """
PyTorch version: {torch_version} PyTorch version: {torch_version}
Is debug build: {is_debug_build} Is debug build: {is_debug_build}
@ -767,14 +802,14 @@ Versions of relevant libraries:
def pretty_str(envinfo): def pretty_str(envinfo):
def replace_nones(dct, replacement='Could not collect'): def replace_nones(dct, replacement="Could not collect"):
for key in dct.keys(): for key in dct.keys():
if dct[key] is not None: if dct[key] is not None:
continue continue
dct[key] = replacement dct[key] = replacement
return dct return dct
def replace_bools(dct, true='Yes', false='No'): def replace_bools(dct, true="Yes", false="No"):
for key in dct.keys(): for key in dct.keys():
if dct[key] is True: if dct[key] is True:
dct[key] = true dct[key] = true
@ -782,42 +817,48 @@ def pretty_str(envinfo):
dct[key] = false dct[key] = false
return dct return dct
def prepend(text, tag='[prepend]'): def prepend(text, tag="[prepend]"):
lines = text.split('\n') lines = text.split("\n")
updated_lines = [tag + line for line in lines] updated_lines = [tag + line for line in lines]
return '\n'.join(updated_lines) return "\n".join(updated_lines)
def replace_if_empty(text, replacement='No relevant packages'): def replace_if_empty(text, replacement="No relevant packages"):
if text is not None and len(text) == 0: if text is not None and len(text) == 0:
return replacement return replacement
return text return text
def maybe_start_on_next_line(string): def maybe_start_on_next_line(string):
# If `string` is multiline, prepend a \n to it. # If `string` is multiline, prepend a \n to it.
if string is not None and len(string.split('\n')) > 1: if string is not None and len(string.split("\n")) > 1:
return '\n{}\n'.format(string) return "\n{}\n".format(string)
return string return string
mutable_dict = envinfo._asdict() mutable_dict = envinfo._asdict()
# If nvidia_gpu_models is multiline, start on the next line # If nvidia_gpu_models is multiline, start on the next line
mutable_dict['nvidia_gpu_models'] = \ mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(
maybe_start_on_next_line(envinfo.nvidia_gpu_models) envinfo.nvidia_gpu_models
)
# If the machine doesn't have CUDA, report some fields as 'No CUDA' # If the machine doesn't have CUDA, report some fields as 'No CUDA'
dynamic_cuda_fields = [ dynamic_cuda_fields = [
'cuda_runtime_version', "cuda_runtime_version",
'nvidia_gpu_models', "nvidia_gpu_models",
'nvidia_driver_version', "nvidia_driver_version",
] ]
all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
all_dynamic_cuda_fields_missing = all( all_dynamic_cuda_fields_missing = all(
mutable_dict[field] is None for field in dynamic_cuda_fields) mutable_dict[field] is None for field in dynamic_cuda_fields
if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: )
if (
TORCH_AVAILABLE
and not torch.cuda.is_available()
and all_dynamic_cuda_fields_missing
):
for field in all_cuda_fields: for field in all_cuda_fields:
mutable_dict[field] = 'No CUDA' mutable_dict[field] = "No CUDA"
if envinfo.cuda_compiled_version is None: if envinfo.cuda_compiled_version is None:
mutable_dict['cuda_compiled_version'] = 'None' mutable_dict["cuda_compiled_version"] = "None"
# Replace True with Yes, False with No # Replace True with Yes, False with No
mutable_dict = replace_bools(mutable_dict) mutable_dict = replace_bools(mutable_dict)
@ -826,18 +867,20 @@ def pretty_str(envinfo):
mutable_dict = replace_nones(mutable_dict) mutable_dict = replace_nones(mutable_dict)
# If either of these are '', replace with 'No relevant packages' # If either of these are '', replace with 'No relevant packages'
mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
# Tag conda and pip packages with a prefix # Tag conda and pip packages with a prefix
# If they were previously None, they'll show up as ie '[conda] Could not collect' # If they were previously None, they'll show up as ie '[conda] Could not collect'
if mutable_dict['pip_packages']: if mutable_dict["pip_packages"]:
mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], mutable_dict["pip_packages"] = prepend(
'[{}] '.format(envinfo.pip_version)) mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)
if mutable_dict['conda_packages']: )
mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], if mutable_dict["conda_packages"]:
'[conda] ') mutable_dict["conda_packages"] = prepend(
mutable_dict['cpu_info'] = envinfo.cpu_info mutable_dict["conda_packages"], "[conda] "
)
mutable_dict["cpu_info"] = envinfo.cpu_info
return env_info_fmt.format(**mutable_dict) return env_info_fmt.format(**mutable_dict)
@ -861,18 +904,29 @@ def main():
output = get_pretty_env_info() output = get_pretty_env_info()
print(output) print(output)
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): if (
TORCH_AVAILABLE
and hasattr(torch, "utils")
and hasattr(torch.utils, "_crash_handler")
):
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
if sys.platform == "linux" and os.path.exists(minidump_dir): if sys.platform == "linux" and os.path.exists(minidump_dir):
dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] dumps = [
os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)
]
latest = max(dumps, key=os.path.getctime) latest = max(dumps, key=os.path.getctime)
ctime = os.path.getctime(latest) ctime = os.path.getctime(latest)
creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ "%Y-%m-%d %H:%M:%S"
"if this is related to your bug please include it when you file a report ***" )
msg = (
"\n*** Detected a minidump at {} created on {}, ".format(
latest, creation_time
)
+ "if this is related to your bug please include it when you file a report ***"
)
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
if __name__ == "__main__":
if __name__ == '__main__':
main() main()