mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] remove torch deploy - conditionals (#158288)
This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started. 1. Remove test_deploy_interaction as we no longer need to worry about this 2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1) 3. Remove `USE_DEPLOY` and switch to the default path always Note: MyPy does fail on a bunch of things here as a bunch of older files are touched. It may be better to fix these things on a separate PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288 Approved by: https://github.com/albanD
This commit is contained in:
@ -544,62 +544,6 @@ class TestCustomOpTesting(CustomOpTestCaseBase):
|
||||
class TestCustomOp(CustomOpTestCaseBase):
|
||||
test_ns = "_test_custom_op"
|
||||
|
||||
def test_deploy_interaction(self):
|
||||
# run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy
|
||||
script = """
|
||||
import torch
|
||||
torch._running_with_deploy = lambda: True
|
||||
|
||||
# creating the library is a no-op, so you can DEF multiple times
|
||||
m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901
|
||||
m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901
|
||||
|
||||
m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901
|
||||
|
||||
# define is a no-op
|
||||
m.define("foobarbaz9996(Tensor x) -> Tensor")
|
||||
assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop"
|
||||
|
||||
def sin_override(x):
|
||||
raise AssertionError("m.impl should have been a noop")
|
||||
|
||||
# impl is a no-op
|
||||
m.impl("sin", sin_override, "CompositeImplicitAutograd")
|
||||
x = torch.randn(3)
|
||||
y = torch.sin(x)
|
||||
|
||||
# should be a no-op
|
||||
@torch.library.custom_op("mylib::foobar", mutates_args={})
|
||||
def foobar(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sin()
|
||||
|
||||
# should be a no-op
|
||||
@foobar.register_fake
|
||||
def _(x):
|
||||
return torch.empty_like(x)
|
||||
|
||||
# should be a no-op
|
||||
m2.define("foobarbaz9996(Tensor x) -> Tensor")
|
||||
|
||||
# should be a no-op
|
||||
@torch.library.register_fake("mylib4392::foobarbaz9996")
|
||||
def _(x):
|
||||
return torch.empty_like(x)
|
||||
"""
|
||||
script = script.strip()
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
subprocess.check_output(
|
||||
[sys.executable, "-c", script],
|
||||
stderr=subprocess.STDOUT,
|
||||
# On Windows, opening the subprocess with the default CWD makes `import torch`
|
||||
# fail, so just set CWD to this script's directory
|
||||
cwd=os.path.dirname(os.path.realpath(__file__)),
|
||||
env=env,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8")))
|
||||
|
||||
@requires_compile
|
||||
def test_functionalize_error(self):
|
||||
with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
|
@ -3603,8 +3603,8 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
@onlyCUDA
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float)
|
||||
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
|
||||
@unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(),
|
||||
"Skipped for deploy and internal with remote GPUs")
|
||||
@unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU),
|
||||
"Skipped for internal with remote GPUs")
|
||||
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
|
||||
from functools import partial
|
||||
from torch.sparse._triton_ops import bsr_dense_mm
|
||||
@ -3680,8 +3680,8 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.half)
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(),
|
||||
"Skipped for deploy and internal with remote GPUs")
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU,
|
||||
"Skipped for internal with remote GPUs")
|
||||
def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
|
||||
from torch.sparse._triton_ops import bsr_dense_mm
|
||||
|
||||
|
@ -34,20 +34,10 @@ from typing import (
|
||||
)
|
||||
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
|
||||
|
||||
from . import version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types import Device, IntLikeType
|
||||
|
||||
|
||||
# multipy/deploy is setting this import before importing torch, this is the most # codespell:ignore multipy
|
||||
# reliable way we have to detect if we're running within deploy.
|
||||
# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 # codespell:ignore multipy # noqa: B950
|
||||
def _running_with_deploy() -> builtins.bool:
|
||||
return sys.modules.get("torch._meta_registrations", None) is object
|
||||
|
||||
|
||||
from torch._utils import (
|
||||
_functionalize_sync as _sync,
|
||||
_import_dotted_name,
|
||||
@ -60,14 +50,9 @@ from torch._utils_internal import (
|
||||
USE_GLOBAL_DEPS,
|
||||
USE_RTLD_GLOBAL_WITH_LIBTORCH,
|
||||
)
|
||||
from torch.torch_version import __version__ as __version__
|
||||
|
||||
|
||||
# TODO(torch_deploy) figure out how to freeze version.py in fbcode build
|
||||
if _running_with_deploy():
|
||||
__version__ = "torch-deploy-1.8"
|
||||
else:
|
||||
from torch.torch_version import __version__ as __version__
|
||||
|
||||
__all__ = [
|
||||
"BoolStorage",
|
||||
"BoolTensor",
|
||||
@ -317,7 +302,7 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
|
||||
|
||||
# See Note [Global dependencies]
|
||||
def _load_global_deps() -> None:
|
||||
if _running_with_deploy() or platform.system() == "Windows":
|
||||
if platform.system() == "Windows":
|
||||
return
|
||||
|
||||
# Determine the file extension based on the platform
|
||||
@ -381,7 +366,7 @@ def _load_global_deps() -> None:
|
||||
|
||||
|
||||
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
|
||||
_running_with_deploy() or platform.system() != "Windows"
|
||||
platform.system() != "Windows"
|
||||
):
|
||||
# Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
|
||||
# few circumstances:
|
||||
@ -2082,7 +2067,7 @@ from torch.serialization import load, save
|
||||
|
||||
# Shared memory manager needs to know the exact location of manager executable
|
||||
def _manager_path():
|
||||
if _running_with_deploy() or platform.system() == "Windows":
|
||||
if platform.system() == "Windows":
|
||||
return b""
|
||||
path = get_file_path("torch", "bin", "torch_shm_manager")
|
||||
prepare_multiprocessing_environment(get_file_path("torch"))
|
||||
@ -2687,10 +2672,10 @@ from torch import fx as fx
|
||||
# Register MPS specific decomps
|
||||
torch.backends.mps._init()
|
||||
|
||||
if not _running_with_deploy():
|
||||
from torch import compiler as compiler
|
||||
from torch import compiler as compiler
|
||||
|
||||
class _TritonLibrary:
|
||||
|
||||
class _TritonLibrary:
|
||||
lib = torch.library.Library("triton", "DEF")
|
||||
ops_table: dict[tuple[str, str], _Callable] = {}
|
||||
|
||||
|
@ -49,29 +49,28 @@ Tensor = torch.Tensor
|
||||
__all__ = ["trace_wrapped"]
|
||||
|
||||
|
||||
if not torch._running_with_deploy():
|
||||
# torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore
|
||||
|
||||
@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc]
|
||||
def zeros_and_scatter(
|
||||
@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc]
|
||||
def zeros_and_scatter(
|
||||
shape: list[int],
|
||||
indices: list[Tensor],
|
||||
vals: Tensor,
|
||||
) -> Tensor:
|
||||
) -> Tensor:
|
||||
"""Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass"""
|
||||
grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype)
|
||||
return torch.ops.aten.index_put(grad, indices, vals, accumulate=True)
|
||||
|
||||
@zeros_and_scatter.register_fake # type: ignore[misc]
|
||||
def _(
|
||||
|
||||
@zeros_and_scatter.register_fake # type: ignore[misc]
|
||||
def _(
|
||||
shape: list[int],
|
||||
indices: list[Tensor],
|
||||
vals: Tensor,
|
||||
) -> Tensor:
|
||||
) -> Tensor:
|
||||
return vals.new_empty(shape)
|
||||
|
||||
@zeros_and_scatter.register_vmap # type: ignore[misc]
|
||||
def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def]
|
||||
|
||||
@zeros_and_scatter.register_vmap # type: ignore[misc]
|
||||
def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def]
|
||||
"""The batching rule is special in that it returns a tensor that is not batched"""
|
||||
indices_indims = indims[1]
|
||||
expanded_indices = []
|
||||
|
@ -2409,7 +2409,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._lowrank.svd_lowrank",
|
||||
"torch._preload_cuda_deps",
|
||||
"torch._register_device_module",
|
||||
"torch._running_with_deploy",
|
||||
"torch._utils._dummy_type",
|
||||
"torch._utils._flatten_dense_tensors",
|
||||
"torch._utils._unflatten_dense_tensors",
|
||||
|
@ -5,17 +5,15 @@ from torch import Tensor
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
if not torch._running_with_deploy():
|
||||
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
|
||||
_test_lib_def.define(
|
||||
"realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag
|
||||
)
|
||||
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
|
||||
_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
|
||||
|
||||
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
|
||||
for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
|
||||
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
|
||||
for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
|
||||
_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
|
||||
|
||||
class Realize(Function):
|
||||
|
||||
class Realize(Function):
|
||||
@staticmethod
|
||||
def forward(ctx: object, x: Tensor) -> Tensor:
|
||||
return torch.ops._inductor_test.realize(x)
|
||||
@ -25,5 +23,6 @@ if not torch._running_with_deploy():
|
||||
def backward(ctx: Any, *grad_output: Any) -> Any:
|
||||
return grad_output[0]
|
||||
|
||||
def realize(x: Tensor) -> Tensor:
|
||||
|
||||
def realize(x: Tensor) -> Tensor:
|
||||
return Realize.apply(x)
|
||||
|
@ -595,10 +595,6 @@ class CustomOpDef:
|
||||
self._setup_context_fn = setup_context
|
||||
|
||||
def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
|
||||
if torch._running_with_deploy():
|
||||
utils.warn_deploy(stacklevel=5)
|
||||
return
|
||||
|
||||
lib = self._lib
|
||||
schema_str = self._name + self._schema
|
||||
cpp_schema = _C.parse_schema(schema_str)
|
||||
|
@ -2,7 +2,6 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
@ -12,15 +11,6 @@ from torch import _C, _utils_internal
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
def warn_deploy(stacklevel=3):
|
||||
warnings.warn(
|
||||
"Python torch.library APIs do nothing under torch::deploy (multipy). " # codespell:ignore multipy
|
||||
"Please instead use C++ custom operator registration APIs.",
|
||||
RuntimeWarning,
|
||||
stacklevel=stacklevel,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Kernel:
|
||||
"""Models a (function, source location)"""
|
||||
|
@ -1478,9 +1478,6 @@ class _Ops(types.ModuleType):
|
||||
Args:
|
||||
path (str): A path to a shared library to load.
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
return
|
||||
|
||||
path = _utils_internal.resolve_library_path(path)
|
||||
with dl_open_guard():
|
||||
# Import the shared library into the process, thus running its
|
||||
|
@ -33,15 +33,9 @@ if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
|
||||
# use is the FB build environment, where this source file is replaced
|
||||
# by an equivalent.
|
||||
|
||||
if torch._running_with_deploy():
|
||||
# __file__ is meaningless in the context of frozen torch used in torch deploy.
|
||||
# setting empty torch_parent should allow below functions to operate without crashing,
|
||||
# but it's unclear if there is a valid use case for them in the context of deploy.
|
||||
torch_parent = ""
|
||||
else:
|
||||
if os.path.basename(os.path.dirname(__file__)) == "shared":
|
||||
if os.path.basename(os.path.dirname(__file__)) == "shared":
|
||||
torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
else:
|
||||
else:
|
||||
torch_parent = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
|
||||
|
@ -331,13 +331,9 @@ void initLazyBindings(PyObject* module) {
|
||||
// So far this problem has only been observed internally, so we will just
|
||||
// block it off there.
|
||||
|
||||
#if !(defined(USE_DEPLOY))
|
||||
|
||||
// When libtorch_python is loaded, we register the python frame getter
|
||||
// otherwise, debug util simply omits python frames
|
||||
GetPythonFramesFunction() = GetPythonFrames;
|
||||
|
||||
#endif // USE_DEPLOY
|
||||
}
|
||||
|
||||
} // namespace torch::lazy
|
||||
|
@ -187,15 +187,6 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
||||
|
||||
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
|
||||
py::gil_scoped_acquire g;
|
||||
// Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy
|
||||
// so stop forcing hermetic mode unconditionally in all situations when
|
||||
// you're using multipy. // codespell:ignore multipy
|
||||
// Eventually just delete this entirely. (Note that you may break
|
||||
// multipy anyway this way with dispatcher // codespell:ignore multipy
|
||||
// registered functions that require hermetic to be off.)
|
||||
#if defined(USE_DEPLOY)
|
||||
EnableHermeticPyObject g2;
|
||||
#endif
|
||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||
auto func =
|
||||
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
|
||||
|
@ -1693,9 +1693,6 @@ class _WrappedTritonKernel:
|
||||
|
||||
|
||||
def _register_triton_kernels():
|
||||
if torch._running_with_deploy():
|
||||
return
|
||||
|
||||
@_WrappedTritonKernel
|
||||
def kernel_impl(*args, **kwargs):
|
||||
from torch.sparse._triton_ops import bsr_dense_mm
|
||||
|
@ -19,21 +19,15 @@ except ImportError:
|
||||
from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
|
||||
|
||||
|
||||
if torch._running_with_deploy():
|
||||
|
||||
def is_torchdynamo_compiling():
|
||||
"""Can't import torchdynamo in torchdeploy builds currently."""
|
||||
return False
|
||||
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
|
||||
except Exception:
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
|
||||
)
|
||||
|
||||
def is_torchdynamo_compiling():
|
||||
def is_torchdynamo_compiling(): # type: ignore[misc]
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
@ -987,44 +981,41 @@ def _reduce_scatter_tensor_coalesced_native_meta(
|
||||
]
|
||||
|
||||
|
||||
if not torch._running_with_deploy():
|
||||
# Library MUST be defined at module scope or it doesn't work
|
||||
# Creating a "DEF" Library always crashes torch::deploy so we create our
|
||||
# Library instances here guarded against running inside it
|
||||
lib_impl = torch.library.Library("_c10d_functional", "IMPL")
|
||||
lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
|
||||
lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
|
||||
lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
|
||||
lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
|
||||
lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
|
||||
lib_impl.impl(
|
||||
# Library MUST be defined at module scope or it doesn't work
|
||||
lib_impl = torch.library.Library("_c10d_functional", "IMPL")
|
||||
lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
|
||||
lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
|
||||
lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
|
||||
lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
|
||||
lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
|
||||
lib_impl.impl(
|
||||
"all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
|
||||
)
|
||||
lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
|
||||
lib_impl.impl(
|
||||
)
|
||||
lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
|
||||
lib_impl.impl(
|
||||
"all_gather_into_tensor_coalesced",
|
||||
_all_gather_into_tensor_coalesced_native_meta,
|
||||
"Meta",
|
||||
)
|
||||
lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
|
||||
lib_impl.impl(
|
||||
)
|
||||
lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
|
||||
lib_impl.impl(
|
||||
"reduce_scatter_tensor_coalesced",
|
||||
_reduce_scatter_tensor_coalesced_native_meta,
|
||||
"Meta",
|
||||
)
|
||||
lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
|
||||
lib_impl.impl("broadcast", _broadcast_meta, "Meta")
|
||||
lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
|
||||
)
|
||||
lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
|
||||
lib_impl.impl("broadcast", _broadcast_meta, "Meta")
|
||||
lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
|
||||
|
||||
# mark these ops has side effect so that they won't be removed by DCE
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
|
||||
# mark these ops has side effect so that they won't be removed by DCE
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
|
||||
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
|
||||
|
||||
# Register legacy ops for backward compatibility
|
||||
# TODO(yifu): remove these in functional collective beta release
|
||||
legacy_lib = torch.library.Library("c10d_functional", "DEF")
|
||||
legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
|
||||
ops_defs = [
|
||||
# Register legacy ops for backward compatibility
|
||||
# TODO(yifu): remove these in functional collective beta release
|
||||
legacy_lib = torch.library.Library("c10d_functional", "DEF")
|
||||
legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
|
||||
ops_defs = [
|
||||
"broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
|
||||
"all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
|
||||
"all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
|
||||
@ -1034,20 +1025,15 @@ if not torch._running_with_deploy():
|
||||
"reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
|
||||
"reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
|
||||
"all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950
|
||||
]
|
||||
]
|
||||
|
||||
my_module = sys.modules[__name__]
|
||||
for op_def in ops_defs:
|
||||
my_module = sys.modules[__name__]
|
||||
for op_def in ops_defs:
|
||||
op_name = op_def[0 : op_def.index("(")]
|
||||
backend_impl = getattr(fun_col_impl, f"_{op_name}")
|
||||
legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
|
||||
legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
|
||||
|
||||
else:
|
||||
warnings.warn(
|
||||
"PyTorch Distributed functional collectives do not work with torch::deploy."
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into
|
||||
|
@ -63,9 +63,8 @@ _META_FUNCTIONS = {
|
||||
"recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
|
||||
}
|
||||
|
||||
if not torch._running_with_deploy():
|
||||
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
|
||||
for op, meta_func in _META_FUNCTIONS.items():
|
||||
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
|
||||
for op, meta_func in _META_FUNCTIONS.items():
|
||||
lib_impl.impl(op, meta_func, "Meta")
|
||||
|
||||
# List of collective operation functions including functional collectives
|
||||
|
@ -15,17 +15,8 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
||||
|
||||
_compiled_autograd_enabled: bool = False
|
||||
|
||||
if torch._running_with_deploy():
|
||||
|
||||
def detect_compiled_autograd():
|
||||
pass
|
||||
|
||||
def compiled_autograd_enabled():
|
||||
return False
|
||||
|
||||
else:
|
||||
|
||||
def detect_compiled_autograd():
|
||||
def detect_compiled_autograd():
|
||||
assert not torch.compiler.is_compiling(), (
|
||||
"`detect_compiled_autograd()` is designed to be called in eager mode"
|
||||
)
|
||||
@ -38,7 +29,8 @@ else:
|
||||
or ca.in_compiled_autograd_region
|
||||
)
|
||||
|
||||
def compiled_autograd_enabled():
|
||||
|
||||
def compiled_autograd_enabled():
|
||||
global _compiled_autograd_enabled
|
||||
return _compiled_autograd_enabled
|
||||
|
||||
|
@ -140,8 +140,7 @@ def copy__functionalize(tensor, data):
|
||||
torch.ops.fsdp.copy_.default(tensor_inner, data_inner)
|
||||
|
||||
|
||||
if not torch._running_with_deploy():
|
||||
torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)
|
||||
torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)
|
||||
|
||||
|
||||
class ShardedState(Enum):
|
||||
|
@ -25,10 +25,8 @@ from torch.distributed.distributed_c10d import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if not torch._running_with_deploy():
|
||||
|
||||
@torch.library.register_fake("_dtensor::shard_dim_alltoall")
|
||||
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
|
||||
@torch.library.register_fake("_dtensor::shard_dim_alltoall")
|
||||
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
|
||||
group_size = _get_group_size_by_name(group_name)
|
||||
stacked_list = [torch.empty_like(input) for _ in range(group_size)]
|
||||
group = _resolve_process_group(group_name)
|
||||
@ -40,13 +38,6 @@ if not torch._running_with_deploy():
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
else:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"PyTorch Distributed functional collectives do not work with torch::deploy."
|
||||
)
|
||||
|
||||
|
||||
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
|
||||
if mesh.device_type == "cpu":
|
||||
|
@ -102,9 +102,6 @@ class Library:
|
||||
ns,
|
||||
" is a reserved namespace. Please try creating a library with another name.",
|
||||
)
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
frame = traceback.extract_stack(limit=3)[0]
|
||||
filename, lineno = frame.filename, frame.lineno
|
||||
@ -156,9 +153,6 @@ class Library:
|
||||
>>> my_lib = Library("mylib", "DEF")
|
||||
>>> my_lib.define("sum(Tensor self) -> Tensor")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
|
||||
# AliasAnalysis type in C++
|
||||
@ -191,9 +185,6 @@ class Library:
|
||||
|
||||
def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False):
|
||||
r"""Registers the fake impl for an operator defined in the library."""
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
source = torch._library.utils.get_source(_stacklevel + 1)
|
||||
frame = sys._getframe(_stacklevel)
|
||||
@ -237,9 +228,6 @@ class Library:
|
||||
If it is a TorchDispatchMode, we expect fn to have the following signature:
|
||||
(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
qualname = f"{self.ns}::{op_name}"
|
||||
entry = torch._library.simple_registry.singleton.find(qualname)
|
||||
@ -259,9 +247,6 @@ class Library:
|
||||
>>> my_lib = Library("aten", "IMPL")
|
||||
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
if dispatch_key == "":
|
||||
dispatch_key = self.dispatch_key
|
||||
@ -324,9 +309,6 @@ class Library:
|
||||
>>> return self * (1 / other)
|
||||
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
if not callable(fn):
|
||||
raise TypeError(
|
||||
@ -409,9 +391,6 @@ class Library:
|
||||
>>> # ...
|
||||
>>> my_lib.fallback(fallback_kernel, "Autocast")
|
||||
"""
|
||||
if torch._running_with_deploy():
|
||||
_library.utils.warn_deploy()
|
||||
return
|
||||
|
||||
if dispatch_key == "":
|
||||
dispatch_key = self.dispatch_key
|
||||
|
@ -29,13 +29,7 @@ def set_module(obj, mod):
|
||||
obj.__module__ = mod
|
||||
|
||||
|
||||
if torch._running_with_deploy():
|
||||
# not valid inside torch_deploy interpreter, no paths exists for frozen modules
|
||||
cmake_prefix_path = None
|
||||
else:
|
||||
cmake_prefix_path = _osp.join(
|
||||
_osp.dirname(_osp.dirname(__file__)), "share", "cmake"
|
||||
)
|
||||
cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake")
|
||||
|
||||
|
||||
def swap_tensors(t1, t2):
|
||||
|
@ -3,8 +3,6 @@ import importlib.util
|
||||
from types import ModuleType
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _check_module_exists(name: str) -> bool:
|
||||
r"""Returns if a top-level module with :attr:`name` exists *without**
|
||||
@ -22,11 +20,7 @@ def _check_module_exists(name: str) -> bool:
|
||||
|
||||
@functools.lru_cache
|
||||
def dill_available() -> bool:
|
||||
return (
|
||||
_check_module_exists("dill")
|
||||
# dill fails to import under torchdeploy
|
||||
and not torch._running_with_deploy()
|
||||
)
|
||||
return _check_module_exists("dill")
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
|
@ -6,49 +6,53 @@
|
||||
import datetime
|
||||
import json
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from typing import cast as _cast
|
||||
from collections import namedtuple
|
||||
from typing import cast as _cast
|
||||
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
# System Environment Information
|
||||
SystemEnv = namedtuple('SystemEnv', [
|
||||
'torch_version',
|
||||
'is_debug_build',
|
||||
'cuda_compiled_version',
|
||||
'gcc_version',
|
||||
'clang_version',
|
||||
'cmake_version',
|
||||
'os',
|
||||
'libc_version',
|
||||
'python_version',
|
||||
'python_platform',
|
||||
'is_cuda_available',
|
||||
'cuda_runtime_version',
|
||||
'cuda_module_loading',
|
||||
'nvidia_driver_version',
|
||||
'nvidia_gpu_models',
|
||||
'cudnn_version',
|
||||
'is_xpu_available',
|
||||
'pip_version', # 'pip' or 'pip3'
|
||||
'pip_packages',
|
||||
'conda_packages',
|
||||
'hip_compiled_version',
|
||||
'hip_runtime_version',
|
||||
'miopen_runtime_version',
|
||||
'caching_allocator_config',
|
||||
'is_xnnpack_available',
|
||||
'cpu_info',
|
||||
])
|
||||
SystemEnv = namedtuple(
|
||||
"SystemEnv",
|
||||
[
|
||||
"torch_version",
|
||||
"is_debug_build",
|
||||
"cuda_compiled_version",
|
||||
"gcc_version",
|
||||
"clang_version",
|
||||
"cmake_version",
|
||||
"os",
|
||||
"libc_version",
|
||||
"python_version",
|
||||
"python_platform",
|
||||
"is_cuda_available",
|
||||
"cuda_runtime_version",
|
||||
"cuda_module_loading",
|
||||
"nvidia_driver_version",
|
||||
"nvidia_gpu_models",
|
||||
"cudnn_version",
|
||||
"is_xpu_available",
|
||||
"pip_version", # 'pip' or 'pip3'
|
||||
"pip_packages",
|
||||
"conda_packages",
|
||||
"hip_compiled_version",
|
||||
"hip_runtime_version",
|
||||
"miopen_runtime_version",
|
||||
"caching_allocator_config",
|
||||
"is_xnnpack_available",
|
||||
"cpu_info",
|
||||
],
|
||||
)
|
||||
|
||||
COMMON_PATTERNS = [
|
||||
"torch",
|
||||
@ -116,12 +120,13 @@ PIP_PATTERNS = [
|
||||
def run(command):
|
||||
"""Return (return-code, stdout, stderr)."""
|
||||
shell = True if type(command) is str else False
|
||||
p = subprocess.Popen(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, shell=shell)
|
||||
p = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
|
||||
)
|
||||
raw_output, raw_err = p.communicate()
|
||||
rc = p.returncode
|
||||
if get_platform() == 'win32':
|
||||
enc = 'oem'
|
||||
if get_platform() == "win32":
|
||||
enc = "oem"
|
||||
else:
|
||||
enc = locale.getpreferredencoding()
|
||||
output = raw_output.decode(enc)
|
||||
@ -147,18 +152,19 @@ def run_and_parse_first_match(run_lambda, command, regex):
|
||||
return None
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def run_and_return_first_line(run_lambda, command):
|
||||
"""Run command using run_lambda and returns first line if output is not empty."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out.split('\n')[0]
|
||||
return out.split("\n")[0]
|
||||
|
||||
|
||||
def get_conda_packages(run_lambda, patterns=None):
|
||||
if patterns is None:
|
||||
patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS
|
||||
conda = os.environ.get('CONDA_EXE', 'conda')
|
||||
conda = os.environ.get("CONDA_EXE", "conda")
|
||||
out = run_and_read_all(run_lambda, "{} list".format(conda))
|
||||
if out is None:
|
||||
return out
|
||||
@ -166,32 +172,40 @@ def get_conda_packages(run_lambda, patterns=None):
|
||||
return "\n".join(
|
||||
line
|
||||
for line in out.splitlines()
|
||||
if not line.startswith("#")
|
||||
and any(name in line for name in patterns)
|
||||
if not line.startswith("#") and any(name in line for name in patterns)
|
||||
)
|
||||
|
||||
|
||||
def get_gcc_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
|
||||
return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
|
||||
|
||||
|
||||
def get_clang_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)')
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, "clang --version", r"clang version (.*)"
|
||||
)
|
||||
|
||||
|
||||
def get_cmake_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)')
|
||||
return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
|
||||
|
||||
|
||||
def get_nvidia_driver_version(run_lambda):
|
||||
if get_platform() == 'darwin':
|
||||
cmd = 'kextstat | grep -i cuda'
|
||||
return run_and_parse_first_match(run_lambda, cmd,
|
||||
r'com[.]nvidia[.]CUDA [(](.*?)[)]')
|
||||
if get_platform() == "darwin":
|
||||
cmd = "kextstat | grep -i cuda"
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]"
|
||||
)
|
||||
smi = get_nvidia_smi()
|
||||
return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ')
|
||||
return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
|
||||
|
||||
|
||||
def get_gpu_info(run_lambda):
|
||||
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None):
|
||||
if get_platform() == "darwin" or (
|
||||
TORCH_AVAILABLE
|
||||
and hasattr(torch.version, "hip")
|
||||
and torch.version.hip is not None
|
||||
):
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
if torch.version.hip is not None:
|
||||
prop = torch.cuda.get_device_properties(0)
|
||||
@ -204,42 +218,42 @@ def get_gpu_info(run_lambda):
|
||||
return torch.cuda.get_device_name(None) + gcnArch
|
||||
return None
|
||||
smi = get_nvidia_smi()
|
||||
uuid_regex = re.compile(r' \(UUID: .+?\)')
|
||||
rc, out, _ = run_lambda(smi + ' -L')
|
||||
uuid_regex = re.compile(r" \(UUID: .+?\)")
|
||||
rc, out, _ = run_lambda(smi + " -L")
|
||||
if rc != 0:
|
||||
return None
|
||||
# Anonymize GPUs by removing their UUID
|
||||
return re.sub(uuid_regex, '', out)
|
||||
return re.sub(uuid_regex, "", out)
|
||||
|
||||
|
||||
def get_running_cuda_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)')
|
||||
return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
|
||||
|
||||
|
||||
def get_cudnn_version(run_lambda):
|
||||
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
|
||||
if get_platform() == 'win32':
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
|
||||
where_cmd = os.path.join(system_root, 'System32', 'where')
|
||||
if get_platform() == "win32":
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
|
||||
where_cmd = os.path.join(system_root, "System32", "where")
|
||||
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
|
||||
elif get_platform() == 'darwin':
|
||||
elif get_platform() == "darwin":
|
||||
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
|
||||
# https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation
|
||||
# https://docs.nvidia.com/deeplearning/cudnn/installation/latest/
|
||||
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
|
||||
cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*'
|
||||
cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
|
||||
else:
|
||||
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
|
||||
rc, out, _ = run_lambda(cudnn_cmd)
|
||||
# find will return 1 if there are permission errors or if not found
|
||||
if len(out) == 0 or (rc != 1 and rc != 0):
|
||||
l = os.environ.get('CUDNN_LIBRARY')
|
||||
l = os.environ.get("CUDNN_LIBRARY")
|
||||
if l is not None and os.path.isfile(l):
|
||||
return os.path.realpath(l)
|
||||
return None
|
||||
files_set = set()
|
||||
for fn in out.split('\n'):
|
||||
for fn in out.split("\n"):
|
||||
fn = os.path.realpath(fn) # eliminate symbolic links
|
||||
if os.path.isfile(fn):
|
||||
files_set.add(fn)
|
||||
@ -249,18 +263,20 @@ def get_cudnn_version(run_lambda):
|
||||
files = sorted(files_set)
|
||||
if len(files) == 1:
|
||||
return files[0]
|
||||
result = '\n'.join(files)
|
||||
return 'Probably one of the following:\n{}'.format(result)
|
||||
result = "\n".join(files)
|
||||
return "Probably one of the following:\n{}".format(result)
|
||||
|
||||
|
||||
def get_nvidia_smi():
|
||||
# Note: nvidia-smi is currently available only on Windows and Linux
|
||||
smi = 'nvidia-smi'
|
||||
if get_platform() == 'win32':
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files')
|
||||
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi)
|
||||
new_path = os.path.join(system_root, 'System32', smi)
|
||||
smi = "nvidia-smi"
|
||||
if get_platform() == "win32":
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
|
||||
legacy_path = os.path.join(
|
||||
program_files_root, "NVIDIA Corporation", "NVSMI", smi
|
||||
)
|
||||
new_path = os.path.join(system_root, "System32", smi)
|
||||
smis = [new_path, legacy_path]
|
||||
for candidate_smi in smis:
|
||||
if os.path.exists(candidate_smi):
|
||||
@ -411,7 +427,9 @@ def get_intel_gpu_detected(run_lambda):
|
||||
if device_count == 0:
|
||||
return "N/A"
|
||||
|
||||
devices = [f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)]
|
||||
devices = [
|
||||
f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)
|
||||
]
|
||||
return "\n".join(devices)
|
||||
|
||||
|
||||
@ -490,11 +508,12 @@ def get_intel_gpu_detected(run_lambda):
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
|
||||
|
||||
def get_cpu_info(run_lambda):
|
||||
rc, out, err = 0, '', ''
|
||||
if get_platform() == 'linux':
|
||||
rc, out, err = run_lambda('lscpu')
|
||||
elif get_platform() == 'win32':
|
||||
rc, out, err = 0, "", ""
|
||||
if get_platform() == "linux":
|
||||
rc, out, err = run_lambda("lscpu")
|
||||
elif get_platform() == "win32":
|
||||
rc, out, err = run_lambda(
|
||||
'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\
|
||||
Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\
|
||||
@ -514,9 +533,9 @@ def get_cpu_info(run_lambda):
|
||||
lst.append(out)
|
||||
lst.append(str(e))
|
||||
out = "\n".join(lst)
|
||||
elif get_platform() == 'darwin':
|
||||
elif get_platform() == "darwin":
|
||||
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
||||
cpu_info = 'None'
|
||||
cpu_info = "None"
|
||||
if rc == 0:
|
||||
cpu_info = out
|
||||
else:
|
||||
@ -525,20 +544,20 @@ def get_cpu_info(run_lambda):
|
||||
|
||||
|
||||
def get_platform():
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
elif sys.platform.startswith('win32'):
|
||||
return 'win32'
|
||||
elif sys.platform.startswith('cygwin'):
|
||||
return 'cygwin'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
return 'darwin'
|
||||
if sys.platform.startswith("linux"):
|
||||
return "linux"
|
||||
elif sys.platform.startswith("win32"):
|
||||
return "win32"
|
||||
elif sys.platform.startswith("cygwin"):
|
||||
return "cygwin"
|
||||
elif sys.platform.startswith("darwin"):
|
||||
return "darwin"
|
||||
else:
|
||||
return sys.platform
|
||||
|
||||
|
||||
def get_mac_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)')
|
||||
return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
|
||||
|
||||
|
||||
def get_windows_version(run_lambda):
|
||||
@ -556,39 +575,43 @@ def get_windows_version(run_lambda):
|
||||
|
||||
|
||||
def get_lsb_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)')
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, "lsb_release -a", r"Description:\t(.*)"
|
||||
)
|
||||
|
||||
|
||||
def check_release_file(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'cat /etc/*-release',
|
||||
r'PRETTY_NAME="(.*)"')
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"'
|
||||
)
|
||||
|
||||
|
||||
def get_os(run_lambda):
|
||||
from platform import machine
|
||||
|
||||
platform = get_platform()
|
||||
|
||||
if platform in ["win32", "cygwin"]:
|
||||
return get_windows_version(run_lambda)
|
||||
|
||||
if platform == 'darwin':
|
||||
if platform == "darwin":
|
||||
version = get_mac_version(run_lambda)
|
||||
if version is None:
|
||||
return None
|
||||
return 'macOS {} ({})'.format(version, machine())
|
||||
return "macOS {} ({})".format(version, machine())
|
||||
|
||||
if platform == 'linux':
|
||||
if platform == "linux":
|
||||
# Ubuntu/Debian based
|
||||
desc = get_lsb_version(run_lambda)
|
||||
if desc is not None:
|
||||
return '{} ({})'.format(desc, machine())
|
||||
return "{} ({})".format(desc, machine())
|
||||
|
||||
# Try reading /etc/*-release
|
||||
desc = check_release_file(run_lambda)
|
||||
if desc is not None:
|
||||
return '{} ({})'.format(desc, machine())
|
||||
return "{} ({})".format(desc, machine())
|
||||
|
||||
return '{} ({})'.format(platform, machine())
|
||||
return "{} ({})".format(platform, machine())
|
||||
|
||||
# Unknown platform
|
||||
return platform
|
||||
@ -596,14 +619,16 @@ def get_os(run_lambda):
|
||||
|
||||
def get_python_platform():
|
||||
import platform
|
||||
|
||||
return platform.platform()
|
||||
|
||||
|
||||
def get_libc_version():
|
||||
import platform
|
||||
if get_platform() != 'linux':
|
||||
return 'N/A'
|
||||
return '-'.join(platform.libc_ver())
|
||||
|
||||
if get_platform() != "linux":
|
||||
return "N/A"
|
||||
return "-".join(platform.libc_ver())
|
||||
|
||||
|
||||
def get_pip_packages(run_lambda, patterns=None):
|
||||
@ -611,35 +636,35 @@ def get_pip_packages(run_lambda, patterns=None):
|
||||
if patterns is None:
|
||||
patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS
|
||||
|
||||
pip_version = 'pip3' if sys.version_info.major == 3 else 'pip'
|
||||
pip_version = "pip3" if sys.version_info.major == 3 else "pip"
|
||||
|
||||
os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1'
|
||||
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
|
||||
# People generally have pip as `pip` or `pip3`
|
||||
# But here it is invoked as `python -mpip`
|
||||
out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze'])
|
||||
out = run_and_read_all(
|
||||
run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"]
|
||||
)
|
||||
if out is None:
|
||||
return pip_version, out
|
||||
|
||||
filtered_out = '\n'.join(
|
||||
line
|
||||
for line in out.splitlines()
|
||||
if any(name in line for name in patterns)
|
||||
filtered_out = "\n".join(
|
||||
line for line in out.splitlines() if any(name in line for name in patterns)
|
||||
)
|
||||
|
||||
return pip_version, filtered_out
|
||||
|
||||
|
||||
def get_cachingallocator_config():
|
||||
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
|
||||
ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
|
||||
if not ca_config:
|
||||
ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '')
|
||||
ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "")
|
||||
return ca_config
|
||||
|
||||
|
||||
def get_cuda_module_loading_config():
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
torch.cuda.init()
|
||||
config = os.environ.get('CUDA_MODULE_LOADING', '')
|
||||
config = os.environ.get("CUDA_MODULE_LOADING", "")
|
||||
return config
|
||||
else:
|
||||
return "N/A"
|
||||
@ -648,10 +673,12 @@ def get_cuda_module_loading_config():
|
||||
def is_xnnpack_available():
|
||||
if TORCH_AVAILABLE:
|
||||
import torch.backends.xnnpack
|
||||
|
||||
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def get_env_info():
|
||||
"""
|
||||
Collects environment information to aid in debugging.
|
||||
@ -678,26 +705,31 @@ def get_env_info():
|
||||
cuda_version_str = torch.version.cuda
|
||||
xpu_available_str = str(torch.xpu.is_available())
|
||||
if torch.xpu.is_available():
|
||||
xpu_available_str = f'{xpu_available_str}\n' + \
|
||||
f'XPU used to build PyTorch: {torch.version.xpu}\n' + \
|
||||
f'Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n' + \
|
||||
f'Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n' + \
|
||||
f'Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}'
|
||||
if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
|
||||
xpu_available_str = (
|
||||
f"{xpu_available_str}\n"
|
||||
+ f"XPU used to build PyTorch: {torch.version.xpu}\n"
|
||||
+ f"Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n"
|
||||
+ f"Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n"
|
||||
+ f"Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}"
|
||||
)
|
||||
if (
|
||||
not hasattr(torch.version, "hip") or torch.version.hip is None
|
||||
): # cuda version
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
|
||||
else: # HIP version
|
||||
|
||||
def get_version_or_na(cfg, prefix):
|
||||
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
|
||||
return _lst[0] if _lst else 'N/A'
|
||||
return _lst[0] if _lst else "N/A"
|
||||
|
||||
cfg = torch._C._show_config().split('\n')
|
||||
hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime')
|
||||
miopen_runtime_version = get_version_or_na(cfg, 'MIOpen')
|
||||
cuda_version_str = 'N/A'
|
||||
cfg = torch._C._show_config().split("\n")
|
||||
hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
|
||||
miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
|
||||
cuda_version_str = "N/A"
|
||||
hip_compiled_version = torch.version.hip
|
||||
else:
|
||||
version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = 'N/A'
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
|
||||
version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment]
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
|
||||
|
||||
sys_version = sys.version.replace("\n", " ")
|
||||
|
||||
@ -706,7 +738,9 @@ def get_env_info():
|
||||
return SystemEnv(
|
||||
torch_version=version_str,
|
||||
is_debug_build=debug_mode_str,
|
||||
python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1),
|
||||
python_version="{} ({}-bit runtime)".format(
|
||||
sys_version, sys.maxsize.bit_length() + 1
|
||||
),
|
||||
python_platform=get_python_platform(),
|
||||
is_cuda_available=cuda_available_str,
|
||||
cuda_compiled_version=cuda_version_str,
|
||||
@ -732,6 +766,7 @@ def get_env_info():
|
||||
cpu_info=get_cpu_info(run_lambda),
|
||||
)
|
||||
|
||||
|
||||
env_info_fmt = """
|
||||
PyTorch version: {torch_version}
|
||||
Is debug build: {is_debug_build}
|
||||
@ -767,14 +802,14 @@ Versions of relevant libraries:
|
||||
|
||||
|
||||
def pretty_str(envinfo):
|
||||
def replace_nones(dct, replacement='Could not collect'):
|
||||
def replace_nones(dct, replacement="Could not collect"):
|
||||
for key in dct.keys():
|
||||
if dct[key] is not None:
|
||||
continue
|
||||
dct[key] = replacement
|
||||
return dct
|
||||
|
||||
def replace_bools(dct, true='Yes', false='No'):
|
||||
def replace_bools(dct, true="Yes", false="No"):
|
||||
for key in dct.keys():
|
||||
if dct[key] is True:
|
||||
dct[key] = true
|
||||
@ -782,42 +817,48 @@ def pretty_str(envinfo):
|
||||
dct[key] = false
|
||||
return dct
|
||||
|
||||
def prepend(text, tag='[prepend]'):
|
||||
lines = text.split('\n')
|
||||
def prepend(text, tag="[prepend]"):
|
||||
lines = text.split("\n")
|
||||
updated_lines = [tag + line for line in lines]
|
||||
return '\n'.join(updated_lines)
|
||||
return "\n".join(updated_lines)
|
||||
|
||||
def replace_if_empty(text, replacement='No relevant packages'):
|
||||
def replace_if_empty(text, replacement="No relevant packages"):
|
||||
if text is not None and len(text) == 0:
|
||||
return replacement
|
||||
return text
|
||||
|
||||
def maybe_start_on_next_line(string):
|
||||
# If `string` is multiline, prepend a \n to it.
|
||||
if string is not None and len(string.split('\n')) > 1:
|
||||
return '\n{}\n'.format(string)
|
||||
if string is not None and len(string.split("\n")) > 1:
|
||||
return "\n{}\n".format(string)
|
||||
return string
|
||||
|
||||
mutable_dict = envinfo._asdict()
|
||||
|
||||
# If nvidia_gpu_models is multiline, start on the next line
|
||||
mutable_dict['nvidia_gpu_models'] = \
|
||||
maybe_start_on_next_line(envinfo.nvidia_gpu_models)
|
||||
mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(
|
||||
envinfo.nvidia_gpu_models
|
||||
)
|
||||
|
||||
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
|
||||
dynamic_cuda_fields = [
|
||||
'cuda_runtime_version',
|
||||
'nvidia_gpu_models',
|
||||
'nvidia_driver_version',
|
||||
"cuda_runtime_version",
|
||||
"nvidia_gpu_models",
|
||||
"nvidia_driver_version",
|
||||
]
|
||||
all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
|
||||
all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
|
||||
all_dynamic_cuda_fields_missing = all(
|
||||
mutable_dict[field] is None for field in dynamic_cuda_fields)
|
||||
if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
|
||||
mutable_dict[field] is None for field in dynamic_cuda_fields
|
||||
)
|
||||
if (
|
||||
TORCH_AVAILABLE
|
||||
and not torch.cuda.is_available()
|
||||
and all_dynamic_cuda_fields_missing
|
||||
):
|
||||
for field in all_cuda_fields:
|
||||
mutable_dict[field] = 'No CUDA'
|
||||
mutable_dict[field] = "No CUDA"
|
||||
if envinfo.cuda_compiled_version is None:
|
||||
mutable_dict['cuda_compiled_version'] = 'None'
|
||||
mutable_dict["cuda_compiled_version"] = "None"
|
||||
|
||||
# Replace True with Yes, False with No
|
||||
mutable_dict = replace_bools(mutable_dict)
|
||||
@ -826,18 +867,20 @@ def pretty_str(envinfo):
|
||||
mutable_dict = replace_nones(mutable_dict)
|
||||
|
||||
# If either of these are '', replace with 'No relevant packages'
|
||||
mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages'])
|
||||
mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages'])
|
||||
mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
|
||||
mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
|
||||
|
||||
# Tag conda and pip packages with a prefix
|
||||
# If they were previously None, they'll show up as ie '[conda] Could not collect'
|
||||
if mutable_dict['pip_packages']:
|
||||
mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'],
|
||||
'[{}] '.format(envinfo.pip_version))
|
||||
if mutable_dict['conda_packages']:
|
||||
mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'],
|
||||
'[conda] ')
|
||||
mutable_dict['cpu_info'] = envinfo.cpu_info
|
||||
if mutable_dict["pip_packages"]:
|
||||
mutable_dict["pip_packages"] = prepend(
|
||||
mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)
|
||||
)
|
||||
if mutable_dict["conda_packages"]:
|
||||
mutable_dict["conda_packages"] = prepend(
|
||||
mutable_dict["conda_packages"], "[conda] "
|
||||
)
|
||||
mutable_dict["cpu_info"] = envinfo.cpu_info
|
||||
return env_info_fmt.format(**mutable_dict)
|
||||
|
||||
|
||||
@ -861,18 +904,29 @@ def main():
|
||||
output = get_pretty_env_info()
|
||||
print(output)
|
||||
|
||||
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'):
|
||||
if (
|
||||
TORCH_AVAILABLE
|
||||
and hasattr(torch, "utils")
|
||||
and hasattr(torch.utils, "_crash_handler")
|
||||
):
|
||||
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
|
||||
if sys.platform == "linux" and os.path.exists(minidump_dir):
|
||||
dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
|
||||
dumps = [
|
||||
os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)
|
||||
]
|
||||
latest = max(dumps, key=os.path.getctime)
|
||||
ctime = os.path.getctime(latest)
|
||||
creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S')
|
||||
msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
|
||||
"if this is related to your bug please include it when you file a report ***"
|
||||
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
msg = (
|
||||
"\n*** Detected a minidump at {} created on {}, ".format(
|
||||
latest, creation_time
|
||||
)
|
||||
+ "if this is related to your bug please include it when you file a report ***"
|
||||
)
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user