From 6162e650b0e0cae720837a0e46070e3b7fc754b5 Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 25 Jul 2025 17:46:23 -0700 Subject: [PATCH] [BE] remove torch deploy - conditionals (#158288) This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started. 1. Remove test_deploy_interaction as we no longer need to worry about this 2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1) 3. Remove `USE_DEPLOY` and switch to the default path always Note: MyPy does fail on a bunch of things here as a bunch of older files are touched. It may be better to fix these things on a separate PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288 Approved by: https://github.com/albanD --- test/test_custom_ops.py | 56 --- test/test_sparse_csr.py | 8 +- torch/__init__.py | 47 +-- .../_dynamo/_trace_wrapped_higher_order_op.py | 73 ++-- torch/_dynamo/trace_rules.py | 1 - torch/_inductor/test_operators.py | 35 +- torch/_library/custom_ops.py | 4 - torch/_library/utils.py | 10 - torch/_ops.py | 3 - torch/_utils_internal.py | 12 +- torch/csrc/lazy/python/init.cpp | 4 - torch/csrc/utils/python_dispatch.cpp | 9 - torch/cuda/__init__.py | 3 - torch/distributed/_functional_collectives.py | 128 +++---- torch/distributed/_tools/fake_collectives.py | 7 +- .../fsdp/_fully_shard/_fsdp_common.py | 36 +- .../fsdp/_fully_shard/_fsdp_param.py | 3 +- torch/distributed/tensor/_collective_utils.py | 29 +- torch/library.py | 21 - torch/utils/__init__.py | 8 +- torch/utils/_import_utils.py | 8 +- torch/utils/collect_env.py | 362 ++++++++++-------- 22 files changed, 375 insertions(+), 492 deletions(-) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index f8d845652ee2..b713edeb7a95 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -544,62 +544,6 @@ class TestCustomOpTesting(CustomOpTestCaseBase): class TestCustomOp(CustomOpTestCaseBase): test_ns = "_test_custom_op" - def test_deploy_interaction(self): - # run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy - script = """ -import torch -torch._running_with_deploy = lambda: True - -# creating the library is a no-op, so you can DEF multiple times -m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 -m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 - -m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901 - -# define is a no-op -m.define("foobarbaz9996(Tensor x) -> Tensor") -assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop" - -def sin_override(x): - raise AssertionError("m.impl should have been a noop") - -# impl is a no-op -m.impl("sin", sin_override, "CompositeImplicitAutograd") -x = torch.randn(3) -y = torch.sin(x) - -# should be a no-op -@torch.library.custom_op("mylib::foobar", mutates_args={}) -def foobar(x: torch.Tensor) -> torch.Tensor: - return x.sin() - -# should be a no-op -@foobar.register_fake -def _(x): - return torch.empty_like(x) - -# should be a no-op -m2.define("foobarbaz9996(Tensor x) -> Tensor") - -# should be a no-op -@torch.library.register_fake("mylib4392::foobarbaz9996") -def _(x): - return torch.empty_like(x) - """ - script = script.strip() - env = os.environ.copy() - try: - subprocess.check_output( - [sys.executable, "-c", script], - stderr=subprocess.STDOUT, - # On Windows, opening the subprocess with the default CWD makes `import torch` - # fail, so just set CWD to this script's directory - 
cwd=os.path.dirname(os.path.realpath(__file__)), - env=env, - ) - except subprocess.CalledProcessError as e: - self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8"))) - @requires_compile def test_functionalize_error(self): with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index cc313c586a09..8fb490e1b5bc 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -3603,8 +3603,8 @@ class TestSparseCompressedTritonKernels(TestCase): @onlyCUDA @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) - @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(), - "Skipped for deploy and internal with remote GPUs") + @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU), + "Skipped for internal with remote GPUs") def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): from functools import partial from torch.sparse._triton_ops import bsr_dense_mm @@ -3680,8 +3680,8 @@ class TestSparseCompressedTritonKernels(TestCase): @onlyCUDA @dtypes(torch.half) - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(), - "Skipped for deploy and internal with remote GPUs") + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, + "Skipped for internal with remote GPUs") def test_triton_bsr_dense_bmm_error_messages(self, device, dtype): from torch.sparse._triton_ops import bsr_dense_mm diff --git a/torch/__init__.py b/torch/__init__.py index 99cb83db84b8..34340b51d0e7 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -34,18 +34,16 @@ from typing import ( ) from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs -from . import version - if TYPE_CHECKING: from .types import Device, IntLikeType -# multipy/deploy is setting this import before importing torch, this is the most # codespell:ignore multipy -# reliable way we have to detect if we're running within deploy. -# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 # codespell:ignore multipy # noqa: B950 +# As a bunch of torch.packages internally still have this check +# we need to keep this. @todo: Remove tests that rely on this check as +# they are likely stale. 
def _running_with_deploy() -> builtins.bool: - return sys.modules.get("torch._meta_registrations", None) is object + return False from torch._utils import ( @@ -60,14 +58,9 @@ from torch._utils_internal import ( USE_GLOBAL_DEPS, USE_RTLD_GLOBAL_WITH_LIBTORCH, ) +from torch.torch_version import __version__ as __version__ -# TODO(torch_deploy) figure out how to freeze version.py in fbcode build -if _running_with_deploy(): - __version__ = "torch-deploy-1.8" -else: - from torch.torch_version import __version__ as __version__ - __all__ = [ "BoolStorage", "BoolTensor", @@ -317,7 +310,7 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: # See Note [Global dependencies] def _load_global_deps() -> None: - if _running_with_deploy() or platform.system() == "Windows": + if platform.system() == "Windows": return # Determine the file extension based on the platform @@ -381,7 +374,7 @@ def _load_global_deps() -> None: if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( - _running_with_deploy() or platform.system() != "Windows" + platform.system() != "Windows" ): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: @@ -2082,7 +2075,7 @@ from torch.serialization import load, save # Shared memory manager needs to know the exact location of manager executable def _manager_path(): - if _running_with_deploy() or platform.system() == "Windows": + if platform.system() == "Windows": return b"" path = get_file_path("torch", "bin", "torch_shm_manager") prepare_multiprocessing_environment(get_file_path("torch")) @@ -2687,21 +2680,21 @@ from torch import fx as fx # Register MPS specific decomps torch.backends.mps._init() -if not _running_with_deploy(): - from torch import compiler as compiler +from torch import compiler as compiler - class _TritonLibrary: - lib = torch.library.Library("triton", "DEF") - ops_table: dict[tuple[str, str], _Callable] = {} - @classmethod - def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): - if (op_key, dispatch_key) not in cls.ops_table: - cls.lib.define(full_schema) - cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) - cls.ops_table[(op_key, dispatch_key)] = op_impl +class _TritonLibrary: + lib = torch.library.Library("triton", "DEF") + ops_table: dict[tuple[str, str], _Callable] = {} - return cls.ops_table[(op_key, dispatch_key)] + @classmethod + def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): + if (op_key, dispatch_key) not in cls.ops_table: + cls.lib.define(full_schema) + cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) + cls.ops_table[(op_key, dispatch_key)] = op_impl + + return cls.ops_table[(op_key, dispatch_key)] # Deprecated attributes diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 8fab0b200549..17b664fc5e0e 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -49,47 +49,46 @@ Tensor = torch.Tensor __all__ = ["trace_wrapped"] -if not torch._running_with_deploy(): - # torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore +@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] +def zeros_and_scatter( + shape: list[int], + indices: list[Tensor], + vals: Tensor, +) -> Tensor: + """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" + grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) + 
return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) - @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] - def zeros_and_scatter( - shape: list[int], - indices: list[Tensor], - vals: Tensor, - ) -> Tensor: - """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" - grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) - return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) - @zeros_and_scatter.register_fake # type: ignore[misc] - def _( - shape: list[int], - indices: list[Tensor], - vals: Tensor, - ) -> Tensor: - return vals.new_empty(shape) +@zeros_and_scatter.register_fake # type: ignore[misc] +def _( + shape: list[int], + indices: list[Tensor], + vals: Tensor, +) -> Tensor: + return vals.new_empty(shape) - @zeros_and_scatter.register_vmap # type: ignore[misc] - def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] - """The batching rule is special in that it returns a tensor that is not batched""" - indices_indims = indims[1] - expanded_indices = [] - for idx, idx_indim in zip(indices, indices_indims): - # The index is not a being batched, we should unsqueeze and expand to val - if idx_indim is None: - expanded_indices.append(idx.expand(value.shape)) - else: - # the index is being part of the vmap batch, it should be the same size as val - assert idx.shape == value.shape - expanded_indices.append(idx) - out = torch.ops.flex_lib.zeros_and_scatter( - shape, - expanded_indices, - value, - ) - return out, None +@zeros_and_scatter.register_vmap # type: ignore[misc] +def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] + """The batching rule is special in that it returns a tensor that is not batched""" + indices_indims = indims[1] + expanded_indices = [] + for idx, idx_indim in zip(indices, indices_indims): + # The index is not a being batched, we should unsqueeze and expand to val + if idx_indim is None: + expanded_indices.append(idx.expand(value.shape)) + else: + # the index is being part of the vmap batch, it should be the same size as val + assert idx.shape == value.shape + expanded_indices.append(idx) + + out = torch.ops.flex_lib.zeros_and_scatter( + shape, + expanded_indices, + value, + ) + return out, None class ModIndex(torch.autograd.Function): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 5a9feb29f192..3684aca12852 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2411,7 +2411,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch._lowrank.svd_lowrank", "torch._preload_cuda_deps", "torch._register_device_module", - "torch._running_with_deploy", "torch._utils._dummy_type", "torch._utils._flatten_dense_tensors", "torch._utils._unflatten_dense_tensors", diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index bf49f3f5d04a..d3d2705f8c78 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -5,25 +5,24 @@ from torch import Tensor from torch.autograd import Function -if not torch._running_with_deploy(): - _test_lib_def = torch.library.Library("_inductor_test", "DEF") - _test_lib_def.define( - "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag - ) +_test_lib_def = torch.library.Library("_inductor_test", "DEF") +_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) - _test_lib_impl = 
torch.library.Library("_inductor_test", "IMPL") - for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): - _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) +_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") +for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) - class Realize(Function): - @staticmethod - def forward(ctx: object, x: Tensor) -> Tensor: - return torch.ops._inductor_test.realize(x) - @staticmethod - # types need to stay consistent with _SingleLevelFunction - def backward(ctx: Any, *grad_output: Any) -> Any: - return grad_output[0] +class Realize(Function): + @staticmethod + def forward(ctx: object, x: Tensor) -> Tensor: + return torch.ops._inductor_test.realize(x) - def realize(x: Tensor) -> Tensor: - return Realize.apply(x) + @staticmethod + # types need to stay consistent with _SingleLevelFunction + def backward(ctx: Any, *grad_output: Any) -> Any: + return grad_output[0] + + +def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 3dfe21d45894..bd8acb2789e1 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -595,10 +595,6 @@ class CustomOpDef: self._setup_context_fn = setup_context def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: - if torch._running_with_deploy(): - utils.warn_deploy(stacklevel=5) - return - lib = self._lib schema_str = self._name + self._schema cpp_schema = _C.parse_schema(schema_str) diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 17e128bdbe0f..940318520452 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -2,7 +2,6 @@ import dataclasses import inspect import sys -import warnings from collections.abc import Iterable, Iterator from typing import Any, Callable, Union @@ -12,15 +11,6 @@ from torch import _C, _utils_internal from torch._ops import OpOverload -def warn_deploy(stacklevel=3): - warnings.warn( - "Python torch.library APIs do nothing under torch::deploy (multipy). " # codespell:ignore multipy - "Please instead use C++ custom operator registration APIs.", - RuntimeWarning, - stacklevel=stacklevel, - ) - - @dataclasses.dataclass class Kernel: """Models a (function, source location)""" diff --git a/torch/_ops.py b/torch/_ops.py index fecfebaeaa53..83a5dc0e57a5 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1478,9 +1478,6 @@ class _Ops(types.ModuleType): Args: path (str): A path to a shared library to load. """ - if torch._running_with_deploy(): - return - path = _utils_internal.resolve_library_path(path) with dl_open_guard(): # Import the shared library into the process, thus running its diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 8a3236260b9d..c3de3b3af59d 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -33,16 +33,10 @@ if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): # use is the FB build environment, where this source file is replaced # by an equivalent. -if torch._running_with_deploy(): - # __file__ is meaningless in the context of frozen torch used in torch deploy. - # setting empty torch_parent should allow below functions to operate without crashing, - # but it's unclear if there is a valid use case for them in the context of deploy. 
- torch_parent = "" +if os.path.basename(os.path.dirname(__file__)) == "shared": + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) else: - if os.path.basename(os.path.dirname(__file__)) == "shared": - torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - else: - torch_parent = os.path.dirname(os.path.dirname(__file__)) + torch_parent = os.path.dirname(os.path.dirname(__file__)) def get_file_path(*path_components: str) -> str: diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index f2b14cbfd7bb..4807aa6a4c7d 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -331,13 +331,9 @@ void initLazyBindings(PyObject* module) { // So far this problem has only been observed internally, so we will just // block it off there. -#if !(defined(USE_DEPLOY)) - // When libtorch_python is loaded, we register the python frame getter // otherwise, debug util simply omits python frames GetPythonFramesFunction() = GetPythonFrames; - -#endif // USE_DEPLOY } } // namespace torch::lazy diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 34fbfec49c91..b2b0e848a7e7 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -187,15 +187,6 @@ class PythonKernelHolder : public c10::OperatorKernel { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; - // Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy - // so stop forcing hermetic mode unconditionally in all situations when - // you're using multipy. // codespell:ignore multipy - // Eventually just delete this entirely. (Note that you may break - // multipy anyway this way with dispatcher // codespell:ignore multipy - // registered functions that require hermetic to be off.) 
-#if defined(USE_DEPLOY) - EnableHermeticPyObject g2; -#endif auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto func = py::reinterpret_borrow(func_.ptr(getPyInterpreter())); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 6a2d62bd424c..01bc4d73a459 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1693,9 +1693,6 @@ class _WrappedTritonKernel: def _register_triton_kernels(): - if torch._running_with_deploy(): - return - @_WrappedTritonKernel def kernel_impl(*args, **kwargs): from torch.sparse._triton_ops import bsr_dense_mm diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 46c2ac1a698f..8472f0d9dd04 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -19,22 +19,16 @@ except ImportError: from torch.utils._pytree import tree_map_only # type: ignore[no-redef] -if torch._running_with_deploy(): +try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling +except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" + ) - def is_torchdynamo_compiling(): - """Can't import torchdynamo in torchdeploy builds currently.""" + def is_torchdynamo_compiling(): # type: ignore[misc] + return False return False - -else: - try: - from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling - except Exception: - warnings.warn( - "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" - ) - - def is_torchdynamo_compiling(): - return False """ @@ -985,66 +979,58 @@ def _reduce_scatter_tensor_coalesced_native_meta( ] -if not torch._running_with_deploy(): - # Library MUST be defined at module scope or it doesn't work - # Creating a "DEF" Library always crashes torch::deploy so we create our - # Library instances here guarded against running inside it - lib_impl = torch.library.Library("_c10d_functional", "IMPL") - lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") - lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") - lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") - lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") - lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") - lib_impl.impl( - "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" - ) - lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") - lib_impl.impl( - "all_gather_into_tensor_coalesced", - _all_gather_into_tensor_coalesced_native_meta, - "Meta", - ) - lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") - lib_impl.impl( - "reduce_scatter_tensor_coalesced", - _reduce_scatter_tensor_coalesced_native_meta, - "Meta", - ) - lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") - lib_impl.impl("broadcast", _broadcast_meta, "Meta") - lib_impl.impl("broadcast_", _broadcast__meta, "Meta") +# Library MUST be defined at module scope or it doesn't work +lib_impl = torch.library.Library("_c10d_functional", "IMPL") +lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") +lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") +lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") +lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") +lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") +lib_impl.impl( + 
"all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" +) +lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") +lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", +) +lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") +lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", +) +lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") +lib_impl.impl("broadcast", _broadcast_meta, "Meta") +lib_impl.impl("broadcast_", _broadcast__meta, "Meta") - # mark these ops has side effect so that they won't be removed by DCE - torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) - torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) +# mark these ops has side effect so that they won't be removed by DCE +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) - # Register legacy ops for backward compatibility - # TODO(yifu): remove these in functional collective beta release - legacy_lib = torch.library.Library("c10d_functional", "DEF") - legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") - ops_defs = [ - "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", - "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", - "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "wait_tensor(Tensor self) -> Tensor", - "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", - "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", - "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", - "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 - ] +# Register legacy ops for backward compatibility +# TODO(yifu): remove these in functional collective beta release +legacy_lib = torch.library.Library("c10d_functional", "DEF") +legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") +ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? 
input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 +] - my_module = sys.modules[__name__] - for op_def in ops_defs: - op_name = op_def[0 : op_def.index("(")] - backend_impl = getattr(fun_col_impl, f"_{op_name}") - legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) - legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") - -else: - warnings.warn( - "PyTorch Distributed functional collectives do not work with torch::deploy." - ) +my_module = sys.modules[__name__] +for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") """ diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index f6cb23a06b67..3b201b395334 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -63,10 +63,9 @@ _META_FUNCTIONS = { "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), } -if not torch._running_with_deploy(): - lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 - for op, meta_func in _META_FUNCTIONS.items(): - lib_impl.impl(op, meta_func, "Meta") +lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 +for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") # List of collective operation functions including functional collectives # Note: The following collectives might be deprecated soon hence not adding them diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index fdcf32e22a33..b599f48d77d1 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -15,32 +15,24 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec _compiled_autograd_enabled: bool = False -if torch._running_with_deploy(): - def detect_compiled_autograd(): - pass +def detect_compiled_autograd(): + assert not torch.compiler.is_compiling(), ( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca - def compiled_autograd_enabled(): - return False + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) -else: - def detect_compiled_autograd(): - assert not torch.compiler.is_compiling(), ( - "`detect_compiled_autograd()` is designed to be called in eager mode" - ) - global _compiled_autograd_enabled - import torch._dynamo.compiled_autograd as ca - - _compiled_autograd_enabled = ( - ca.compiled_autograd_enabled - or ca.compiled_autograd_enabled_force_eager - or ca.in_compiled_autograd_region - ) - - def compiled_autograd_enabled(): - global _compiled_autograd_enabled - return _compiled_autograd_enabled +def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled @dataclass diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index 7649c32ec1c0..b7c8f4ea7c78 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -140,8 +140,7 @@ def copy__functionalize(tensor, data): 
torch.ops.fsdp.copy_.default(tensor_inner, data_inner) -if not torch._running_with_deploy(): - torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) +torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) class ShardedState(Enum): diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 36316b2f0567..4fce6fea538a 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -25,26 +25,17 @@ from torch.distributed.distributed_c10d import ( logger = logging.getLogger(__name__) -if not torch._running_with_deploy(): +@torch.library.register_fake("_dtensor::shard_dim_alltoall") +def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) - @torch.library.register_fake("_dtensor::shard_dim_alltoall") - def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): - group_size = _get_group_size_by_name(group_name) - stacked_list = [torch.empty_like(input) for _ in range(group_size)] - group = _resolve_process_group(group_name) - group_rank = get_group_rank(group, get_rank()) - - return ( - torch.cat(stacked_list, dim=gather_dim) - .chunk(group_size, dim=shard_dim)[group_rank] - .contiguous() - ) - -else: - import warnings - - warnings.warn( - "PyTorch Distributed functional collectives do not work with torch::deploy." + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() ) diff --git a/torch/library.py b/torch/library.py index 11f70e36c0f2..f24c3fbd4276 100644 --- a/torch/library.py +++ b/torch/library.py @@ -102,9 +102,6 @@ class Library: ns, " is a reserved namespace. 
Please try creating a library with another name.", ) - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return frame = traceback.extract_stack(limit=3)[0] filename, lineno = frame.filename, frame.lineno @@ -156,9 +153,6 @@ class Library: >>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid # AliasAnalysis type in C++ @@ -191,9 +185,6 @@ class Library: def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): r"""Registers the fake impl for an operator defined in the library.""" - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return source = torch._library.utils.get_source(_stacklevel + 1) frame = sys._getframe(_stacklevel) @@ -237,9 +228,6 @@ class Library: If it is a TorchDispatchMode, we expect fn to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return qualname = f"{self.ns}::{op_name}" entry = torch._library.simple_registry.singleton.find(qualname) @@ -259,9 +247,6 @@ class Library: >>> my_lib = Library("aten", "IMPL") >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return if dispatch_key == "": dispatch_key = self.dispatch_key @@ -324,9 +309,6 @@ class Library: >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return if not callable(fn): raise TypeError( @@ -409,9 +391,6 @@ class Library: >>> # ... 
>>> my_lib.fallback(fallback_kernel, "Autocast") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return if dispatch_key == "": dispatch_key = self.dispatch_key diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index 23188bba9b80..1c3ec1579006 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -29,13 +29,7 @@ def set_module(obj, mod): obj.__module__ = mod -if torch._running_with_deploy(): - # not valid inside torch_deploy interpreter, no paths exists for frozen modules - cmake_prefix_path = None -else: - cmake_prefix_path = _osp.join( - _osp.dirname(_osp.dirname(__file__)), "share", "cmake" - ) +cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake") def swap_tensors(t1, t2): diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index dc2d7d4f0382..240f92acacb9 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -3,8 +3,6 @@ import importlib.util from types import ModuleType from typing import Optional -import torch - def _check_module_exists(name: str) -> bool: r"""Returns if a top-level module with :attr:`name` exists *without** @@ -22,11 +20,7 @@ def _check_module_exists(name: str) -> bool: @functools.lru_cache def dill_available() -> bool: - return ( - _check_module_exists("dill") - # dill fails to import under torchdeploy - and not torch._running_with_deploy() - ) + return _check_module_exists("dill") @functools.lru_cache diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 9bb80c65076b..c6473220bc00 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -6,49 +6,53 @@ import datetime import json import locale +import os import re import subprocess import sys -import os -from typing import cast as _cast from collections import namedtuple +from typing import cast as _cast try: import torch + TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information -SystemEnv = namedtuple('SystemEnv', [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'is_xpu_available', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', -]) +SystemEnv = namedtuple( + "SystemEnv", + [ + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "is_xpu_available", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + ], +) COMMON_PATTERNS = [ "torch", @@ -116,12 +120,13 @@ PIP_PATTERNS = [ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, 
shell=shell) + p = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell + ) raw_output, raw_err = p.communicate() rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' + if get_platform() == "win32": + enc = "oem" else: enc = locale.getpreferredencoding() output = raw_output.decode(enc) @@ -147,18 +152,19 @@ def run_and_parse_first_match(run_lambda, command, regex): return None return match.group(1) + def run_and_return_first_line(run_lambda, command): """Run command using run_lambda and returns first line if output is not empty.""" rc, out, _ = run_lambda(command) if rc != 0: return None - return out.split('\n')[0] + return out.split("\n")[0] def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS - conda = os.environ.get('CONDA_EXE', 'conda') + conda = os.environ.get("CONDA_EXE", "conda") out = run_and_read_all(run_lambda, "{} list".format(conda)) if out is None: return out @@ -166,32 +172,40 @@ def get_conda_packages(run_lambda, patterns=None): return "\n".join( line for line in out.splitlines() - if not line.startswith("#") - and any(name in line for name in patterns) + if not line.startswith("#") and any(name in line for name in patterns) ) + def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") + def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + return run_and_parse_first_match( + run_lambda, "clang --version", r"clang version (.*)" + ) def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") def get_nvidia_driver_version(run_lambda): - if get_platform() == 'darwin': - cmd = 'kextstat | grep -i cuda' - return run_and_parse_first_match(run_lambda, cmd, - r'com[.]nvidia[.]CUDA [(](.*?)[)]') + if get_platform() == "darwin": + cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match( + run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" + ) smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) 
") def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE + and hasattr(torch.version, "hip") + and torch.version.hip is not None + ): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -204,42 +218,42 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() - uuid_regex = re.compile(r' \(UUID: .+?\)') - rc, out, _ = run_lambda(smi + ' -L') + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") if rc != 0: return None # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, '', out) + return re.sub(uuid_regex, "", out) def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") - where_cmd = os.path.join(system_root, 'System32', 'where') + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
- cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get('CUDNN_LIBRARY') + l = os.environ.get("CUDNN_LIBRARY") if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() - for fn in out.split('\n'): + for fn in out.split("\n"): fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -249,18 +263,20 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] - result = '\n'.join(files) - return 'Probably one of the following:\n{}'.format(result) + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux - smi = 'nvidia-smi' - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) - new_path = os.path.join(system_root, 'System32', smi) + smi = "nvidia-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join( + program_files_root, "NVIDIA Corporation", "NVSMI", smi + ) + new_path = os.path.join(system_root, "System32", smi) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -411,7 +427,9 @@ def get_intel_gpu_detected(run_lambda): if device_count == 0: return "N/A" - devices = [f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)] + devices = [ + f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count) + ] return "\n".join(devices) @@ -490,11 +508,12 @@ def get_intel_gpu_detected(run_lambda): # ProcessorType=3 # Revision=27142 + def get_cpu_info(run_lambda): - rc, out, err = 0, '', '' - if get_platform() == 'linux': - rc, out, err = run_lambda('lscpu') - elif get_platform() == 'win32': + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": rc, out, err = run_lambda( 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ @@ -514,9 +533,9 @@ def get_cpu_info(run_lambda): lst.append(out) lst.append(str(e)) out = "\n".join(lst) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = 'None' + cpu_info = "None" if rc == 0: cpu_info = out else: @@ -525,20 +544,20 @@ def get_cpu_info(run_lambda): def get_platform(): - if sys.platform.startswith('linux'): - return 'linux' - elif sys.platform.startswith('win32'): - return 'win32' - elif sys.platform.startswith('cygwin'): - return 'cygwin' - elif sys.platform.startswith('darwin'): - return 'darwin' + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + 
return "darwin" else: return sys.platform def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") def get_windows_version(run_lambda): @@ -556,39 +575,43 @@ def get_windows_version(run_lambda): def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + return run_and_parse_first_match( + run_lambda, "lsb_release -a", r"Description:\t(.*)" + ) def check_release_file(run_lambda): - return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', - r'PRETTY_NAME="(.*)"') + return run_and_parse_first_match( + run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' + ) def get_os(run_lambda): from platform import machine + platform = get_platform() if platform in ["win32", "cygwin"]: return get_windows_version(run_lambda) - if platform == 'darwin': + if platform == "darwin": version = get_mac_version(run_lambda) if version is None: return None - return 'macOS {} ({})'.format(version, machine()) + return "macOS {} ({})".format(version, machine()) - if platform == 'linux': + if platform == "linux": # Ubuntu/Debian based desc = get_lsb_version(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) - return '{} ({})'.format(platform, machine()) + return "{} ({})".format(platform, machine()) # Unknown platform return platform @@ -596,14 +619,16 @@ def get_os(run_lambda): def get_python_platform(): import platform + return platform.platform() def get_libc_version(): import platform - if get_platform() != 'linux': - return 'N/A' - return '-'.join(platform.libc_ver()) + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) def get_pip_packages(run_lambda, patterns=None): @@ -611,35 +636,35 @@ def get_pip_packages(run_lambda, patterns=None): if patterns is None: patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS - pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' + pip_version = "pip3" if sys.version_info.major == 3 else "pip" - os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' + os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1" # People generally have pip as `pip` or `pip3` # But here it is invoked as `python -mpip` - out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) + out = run_and_read_all( + run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"] + ) if out is None: return pip_version, out - filtered_out = '\n'.join( - line - for line in out.splitlines() - if any(name in line for name in patterns) + filtered_out = "\n".join( + line for line in out.splitlines() if any(name in line for name in patterns) ) return pip_version, filtered_out def get_cachingallocator_config(): - ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") if not ca_config: - ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "") return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() - config = os.environ.get('CUDA_MODULE_LOADING', '') + config = os.environ.get("CUDA_MODULE_LOADING", 
"") return config else: return "N/A" @@ -648,10 +673,12 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" + def get_env_info(): """ Collects environment information to aid in debugging. @@ -678,26 +705,31 @@ def get_env_info(): cuda_version_str = torch.version.cuda xpu_available_str = str(torch.xpu.is_available()) if torch.xpu.is_available(): - xpu_available_str = f'{xpu_available_str}\n' + \ - f'XPU used to build PyTorch: {torch.version.xpu}\n' + \ - f'Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n' + \ - f'Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n' + \ - f'Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}' - if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + xpu_available_str = ( + f"{xpu_available_str}\n" + + f"XPU used to build PyTorch: {torch.version.xpu}\n" + + f"Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n" + + f"Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n" + + f"Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}" + ) + if ( + not hasattr(torch.version, "hip") or torch.version.hip is None + ): # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" else: # HIP version + def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else 'N/A' + return _lst[0] if _lst else "N/A" - cfg = torch._C._show_config().split('\n') - hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') - miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') - cuda_version_str = 'N/A' + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" hip_compiled_version = torch.version.hip else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = 'N/A' - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment] + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" sys_version = sys.version.replace("\n", " ") @@ -706,7 +738,9 @@ def get_env_info(): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), + python_version="{} ({}-bit runtime)".format( + sys_version, sys.maxsize.bit_length() + 1 + ), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -732,6 +766,7 @@ def get_env_info(): cpu_info=get_cpu_info(run_lambda), ) + env_info_fmt = """ PyTorch version: {torch_version} Is debug build: {is_debug_build} @@ -767,14 +802,14 @@ Versions of relevant libraries: def pretty_str(envinfo): - def replace_nones(dct, replacement='Could not collect'): + def replace_nones(dct, replacement="Could not collect"): for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct - def replace_bools(dct, true='Yes', false='No'): + def replace_bools(dct, true="Yes", false="No"): for key in dct.keys(): if 
dct[key] is True: dct[key] = true @@ -782,42 +817,48 @@ def pretty_str(envinfo): dct[key] = false return dct - def prepend(text, tag='[prepend]'): - lines = text.split('\n') + def prepend(text, tag="[prepend]"): + lines = text.split("\n") updated_lines = [tag + line for line in lines] - return '\n'.join(updated_lines) + return "\n".join(updated_lines) - def replace_if_empty(text, replacement='No relevant packages'): + def replace_if_empty(text, replacement="No relevant packages"): if text is not None and len(text) == 0: return replacement return text def maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. - if string is not None and len(string.split('\n')) > 1: - return '\n{}\n'.format(string) + if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line - mutable_dict['nvidia_gpu_models'] = \ - maybe_start_on_next_line(envinfo.nvidia_gpu_models) + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( + envinfo.nvidia_gpu_models + ) # If the machine doesn't have CUDA, report some fields as 'No CUDA' dynamic_cuda_fields = [ - 'cuda_runtime_version', - 'nvidia_gpu_models', - 'nvidia_driver_version', + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", ] - all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] all_dynamic_cuda_fields_missing = all( - mutable_dict[field] is None for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + mutable_dict[field] is None for field in dynamic_cuda_fields + ) + if ( + TORCH_AVAILABLE + and not torch.cuda.is_available() + and all_dynamic_cuda_fields_missing + ): for field in all_cuda_fields: - mutable_dict[field] = 'No CUDA' + mutable_dict[field] = "No CUDA" if envinfo.cuda_compiled_version is None: - mutable_dict['cuda_compiled_version'] = 'None' + mutable_dict["cuda_compiled_version"] = "None" # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -826,18 +867,20 @@ def pretty_str(envinfo): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], - '[{}] '.format(envinfo.pip_version)) - if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], - '[conda] ') - mutable_dict['cpu_info'] = envinfo.cpu_info + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend( + mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) + ) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend( + mutable_dict["conda_packages"], "[conda] " + ) + mutable_dict["cpu_info"] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -861,18 +904,29 @@ def main(): output = get_pretty_env_info() print(output) - if 
TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + if ( + TORCH_AVAILABLE + and hasattr(torch, "utils") + and hasattr(torch.utils, "_crash_handler") + ): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): - dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + dumps = [ + os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) + ] latest = max(dumps, key=os.path.getctime) ctime = os.path.getctime(latest) - creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') - msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ - "if this is related to your bug please include it when you file a report ***" + creation_time = datetime.datetime.fromtimestamp(ctime).strftime( + "%Y-%m-%d %H:%M:%S" + ) + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format( + latest, creation_time + ) + + "if this is related to your bug please include it when you file a report ***" + ) print(msg, file=sys.stderr) - -if __name__ == '__main__': +if __name__ == "__main__": main()
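---

The central effect of this patch for downstream code is that `torch._running_with_deploy()` is now hard-coded to return `False`, so `torch.library` registrations that used to be wrapped in a deploy guard can run unconditionally. Below is a minimal, illustrative sketch of that simplification; it is not part of the patch itself, and the `my_ns` / `my_op` names are hypothetical placeholders chosen only for the example.

```python
# Sketch: a library registration that previously would have been guarded with
# `if not torch._running_with_deploy():` because creating a "DEF" Library
# crashed under torch::deploy. After this PR the guard is dead code, matching
# the now-unconditional registrations in torch/_inductor/test_operators.py and
# torch/distributed/_functional_collectives.py.
import torch

# Namespace and op name here are hypothetical, for illustration only.
_lib = torch.library.Library("my_ns", "DEF")  # noqa: TOR901
_lib.define("my_op(Tensor self) -> Tensor")
_lib.impl("my_op", lambda x: x.clone(), "CompositeExplicitAutograd")

if __name__ == "__main__":
    x = torch.randn(3)
    # The registered op dispatches like any other torch.ops entry.
    print(torch.ops.my_ns.my_op(x))
```

The same pattern applies to the meta/fake registrations un-guarded elsewhere in this patch (e.g. `lib_impl.impl(..., "Meta")` in the functional-collectives and fake-collectives modules): with the deploy path gone, module-scope `Library` creation is always safe.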