pytorch/torch/_utils_internal.py
PyTorch MergeBot 47956196d9 Revert "Call internal log_compilation_event if it exists (#164855)"
This reverts commit 98a081a24c22072362dc536afd39a469e28939d4.

Reverted https://github.com/pytorch/pytorch/pull/164855 on behalf of https://github.com/albanD due to We should not land this kind of code in core ([comment](https://github.com/pytorch/pytorch/pull/164855#issuecomment-3387692988))
2025-10-09 22:38:45 +00:00

# mypy: allow-untyped-defs
import functools
import logging
import os
import sys
import tempfile
import typing_extensions
from collections.abc import Callable
from typing import Any, Optional, TypeVar
from typing_extensions import ParamSpec
import torch
from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler
_T = TypeVar("_T")
_P = ParamSpec("_P")
log = logging.getLogger(__name__)
if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
import shutil
if not shutil.which("strobeclient"):
log.info(
"TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine."
)
else:
log.info("Strobelight profiler is enabled via environment variable")
StrobelightCompileTimeProfiler.enable()
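
# Example (illustrative; "train.py" stands in for any entry point script):
#
#     TORCH_COMPILE_STROBELIGHT=1 python train.py
#
# Any non-empty value is treated as truthy by the check above, and the
# strobeclient binary must be on PATH for the profiler to actually be enabled.
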
# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overridable behavior. The motivating
# use is the FB build environment, where this source file is replaced
# by an equivalent.
if os.path.basename(os.path.dirname(__file__)) == "shared":
    torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
else:
    torch_parent = os.path.dirname(os.path.dirname(__file__))


def get_file_path(*path_components: str) -> str:
    return os.path.join(torch_parent, *path_components)


def get_file_path_2(*path_components: str) -> str:
    return os.path.join(*path_components)


def get_writable_path(path: str) -> str:
    if os.access(path, os.W_OK):
        return path
    return tempfile.mkdtemp(suffix=os.path.basename(path))


def prepare_multiprocessing_environment(path: str) -> None:
    pass


def resolve_library_path(path: str) -> str:
    return os.path.realpath(path)


def throw_abstract_impl_not_imported_error(opname, module, context):
    if module in sys.modules:
        raise NotImplementedError(
            f"{opname}: We could not find the fake impl for this operator. "
        )
    else:
        raise NotImplementedError(
            f"{opname}: We could not find the fake impl for this operator. "
            f"The operator specified that you may need to import the '{module}' "
            f"Python module to load the fake impl. {context}"
        )


# NB! This treats "skip" kwarg specially!!
def compile_time_strobelight_meta(
    phase_name: str,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    def compile_time_strobelight_meta_inner(
        function: Callable[_P, _T],
    ) -> Callable[_P, _T]:
        @functools.wraps(function)
        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            if "skip" in kwargs and isinstance(
                # pyrefly: ignore # unsupported-operation
                skip := kwargs["skip"],
                int,
            ):
                # pyrefly: ignore # unbound-name
                kwargs["skip"] = skip + 1

            # This is not needed but we have it here to avoid having profile_compile_time
            # in stack traces when profiling is not enabled.
            if not StrobelightCompileTimeProfiler.enabled:
                return function(*args, **kwargs)

            return StrobelightCompileTimeProfiler.profile_compile_time(
                function, phase_name, *args, **kwargs
            )

        return wrapper_function

    return compile_time_strobelight_meta_inner
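

# Example (illustrative; the phase name and function below are hypothetical):
#
#     @compile_time_strobelight_meta(phase_name="example_phase")
#     def _example_compile_step(*args, skip: int = 0, **kwargs):
#         ...
#
# An integer "skip" kwarg, if present, is incremented by one before the call;
# when Strobelight profiling is disabled the wrapped function is invoked
# directly, so the profiler never appears in stack traces.

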
# Meta only, see
# https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/
#
# This will cause an event to get logged to Scuba via the signposts API. You
# can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs;
# we log to subsystem "torch", and the category and name you provide here.
# Each of the arguments translate into a Scuba column. We're still figuring
# out local conventions in PyTorch, but category should be something like
# "dynamo" or "inductor", and name should be a specific string describing what
# kind of event happened.
#
# Killswitch is at
# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
def signpost_event(category: str, name: str, parameters: dict[str, Any]):
    log.info("%s %s: %r", category, name, parameters)
def add_mlhub_insight(category: str, insight: str, insight_description: str):
    pass


def log_compilation_event(metrics):
    log.info("%s", metrics)


def upload_graph(graph):
    pass


def set_pytorch_distributed_envs_from_justknobs():
    pass


def log_export_usage(**kwargs):
    pass


def log_draft_export_usage(**kwargs):
    pass


def log_trace_structured_event(*args, **kwargs) -> None:
    pass


def log_cache_bypass(*args, **kwargs) -> None:
    pass


def log_torchscript_usage(api: str, **kwargs):
    _ = api
    return


def check_if_torch_exportable():
    return False


def export_training_ir_rollout_check() -> bool:
    return True


def full_aoti_runtime_assert() -> bool:
    return True


def log_torch_jit_trace_exportability(
    api: str,
    type_of_export: str,
    export_outcome: str,
    result: str,
):
    _, _, _, _ = api, type_of_export, export_outcome, result
    return


DISABLE_JUSTKNOBS = True


def justknobs_check(name: str, default: bool = True) -> bool:
    """
    This function can be used to killswitch functionality in FB prod,
    where you can toggle this value to False in JK without having to
    do a code push. In OSS, we always have everything turned on all
    the time, because downstream users can simply choose to not update
    PyTorch. (If more fine-grained enable/disable is needed, we could
    potentially have a map we lookup name in to toggle behavior. But
    the point is that it's all tied to source code in OSS, since there's
    no live server to query.)

    This is the bare minimum functionality I needed to do some killswitches.
    We have a more detailed plan at
    https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit
    In particular, in some circumstances it may be necessary to read in
    a knob once at process start, and then use it consistently for the
    rest of the process. Future functionality will codify these patterns
    into a better high level API.

    WARNING: Do NOT call this function at module import time, JK is not
    fork safe and you will break anyone who forks the process and then
    hits JK again.
    """
    return default
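
# Example (illustrative; the knob name is hypothetical):
#
#     if justknobs_check("pytorch/compiler:enable_some_feature"):
#         ...  # enabled; in OSS this branch is taken whenever default is True
#     else:
#         ...  # killswitched off via JK in fbcode
#
# Per the warning above, call this lazily inside the gated code path rather
# than at module import time.

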
def justknobs_getval_int(name: str) -> int:
    """
    Read warning on justknobs_check
    """
    return 0


def is_fb_unit_test() -> bool:
    return False


@functools.cache
def max_clock_rate():
    """
    unit: MHz
    """
    if not torch.version.hip:
        from triton.testing import nvsmi

        return nvsmi(["clocks.max.sm"])[0]
    else:
        # Manually set max-clock speeds on ROCm until equivalent nvmsi
        # functionality in triton.testing or via pyamdsmi enablement. Required
        # for test_snode_runtime unit tests.
        gcn_arch = str(
            torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0]
        )
        if "gfx94" in gcn_arch:
            return 1700
        elif "gfx90a" in gcn_arch:
            return 1700
        elif "gfx908" in gcn_arch:
            return 1502
        elif "gfx12" in gcn_arch:
            return 1700
        elif "gfx11" in gcn_arch:
            return 1700
        elif "gfx103" in gcn_arch:
            return 1967
        elif "gfx101" in gcn_arch:
            return 1144
        elif "gfx95" in gcn_arch:
            return 1700  # TODO: placeholder, get actual value
        else:
            return 1100


def get_mast_job_name_version() -> Optional[tuple[str, int]]:
    return None


TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
# libtorch_global_deps, see Note [Global dependencies]
USE_GLOBAL_DEPS = True
# USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load
# _C.so with RTLD_GLOBAL during the call to dlopen.
USE_RTLD_GLOBAL_WITH_LIBTORCH = False
# If an op was defined in C++ and extended from Python using
# torch.library.register_fake, this controls whether we require a
# m.set_python_module("mylib.ops") call from C++ that associates
# the C++ op with a Python module.
REQUIRES_SET_PYTHON_MODULE = False
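
# Example (illustrative; "mylib::op" and "mylib.ops" are hypothetical names):
#
#     # Python side: attach a fake (meta) implementation to a C++-defined op
#     # that returns a tensor shaped like its input.
#     @torch.library.register_fake("mylib::op")
#     def _(x):
#         return torch.empty_like(x)
#
#     # C++ side, required only when REQUIRES_SET_PYTHON_MODULE is True:
#     #     m.set_python_module("mylib.ops");

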
def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]:
    print("Uploading profile stats (fb-only otherwise no-op)")
    return None


def log_chromium_event_internal(
    event: dict[str, Any],
    stack: list[str],
    logger_uuid: str,
    start_time_ns: int,
):
    return None


def record_chromium_event_internal(
    event: dict[str, Any],
):
    return None


def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12():
    return True


def deprecated():
    """
    When we deprecate a function that might still be in use, we make it internal
    by adding a leading underscore. This decorator is applied to the private
    function and creates a public alias without the leading underscore that
    emits a deprecation warning. This tells users "THIS FUNCTION IS DEPRECATED,
    please use something else" without breaking them; callers who really want
    the deprecated function without the warning can still call it by its
    internal (underscored) name.
    """

    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        # Validate naming convention - single leading underscore, not dunder
        if not func.__name__.startswith("_"):
            raise ValueError(
                "@deprecated must decorate a function whose name "
                "starts with a single leading underscore (e.g. '_foo'), as the API should be considered internal for deprecation."
            )
        public_name = func.__name__[1:]  # drop exactly one leading underscore
        module = sys.modules[func.__module__]

        # Don't clobber an existing symbol accidentally.
        if hasattr(module, public_name):
            raise RuntimeError(
                f"Cannot create alias '{public_name}' -> symbol already exists in {module.__name__}. \
                Please rename it or consult a pytorch developer on what to do"
            )

        warning_msg = f"{func.__name__[1:]} is DEPRECATED, please consider using an alternative API(s). "

        # public deprecated alias
        alias = typing_extensions.deprecated(
            # pyrefly: ignore # bad-argument-type
            warning_msg,
            category=UserWarning,
            stacklevel=1,
        )(func)

        alias.__name__ = public_name

        # Adjust qualname if nested inside a class or another function
        if "." in func.__qualname__:
            alias.__qualname__ = func.__qualname__.rsplit(".", 1)[0] + "." + public_name
        else:
            alias.__qualname__ = public_name

        setattr(module, public_name, alias)
        return func

    return decorator
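

# Example (illustrative; "_frobnicate" is a hypothetical function):
#
#     @deprecated()
#     def _frobnicate(x):
#         return x
#
# "_frobnicate" stays callable without a warning, and a module-level alias
# "frobnicate" is published that emits a UserWarning when called.

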
def get_default_numa_options():
    """
    When using the elastic agent, if no numa options are provided, we will use
    these as the default.

    For external use cases, we return None, i.e. no numa binding. If you would
    like to use torch's automatic numa binding capabilities, you should provide
    NumaOptions to your launch config directly or use the numa binding option
    available in torchrun.

    Must return None or NumaOptions; the annotation is omitted to avoid a
    circular import.
    """
    return None


def log_triton_builds(fail: Optional[str]):
    pass


def find_compile_subproc_binary() -> Optional[str]:
    """
    Allows overriding the binary used for subprocesses
    """
    return None