mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-20 02:24:54 +08:00
Revert "Remove python workaround for ContextDecorator (#167049)"
This reverts commit e20ca3bc2e6ef9935c782fe548348f81fabc5bd7. Reverted https://github.com/pytorch/pytorch/pull/167049 on behalf of https://github.com/jeanschmidt due to breaks internal tests see D87120562, @Skylion007 please thelp the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/167049#issuecomment-3542847796))
This commit is contained in:
@ -4,7 +4,6 @@
|
||||
# ruff: noqa: F401
|
||||
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
from contextlib import AbstractContextManager
|
||||
from enum import Enum, IntEnum
|
||||
from pathlib import Path
|
||||
from types import EllipsisType
|
||||
@ -231,8 +230,8 @@ ${dtype_class_hints}
|
||||
class layout: ...
|
||||
|
||||
# Defined in torch/csrc/utils/disable_torch_function.cpp
|
||||
def DisableTorchFunction() -> AbstractContextManager: ...
|
||||
def DisableTorchFunctionSubclass() -> AbstractContextManager: ...
|
||||
def DisableTorchFunction(): ...
|
||||
def DisableTorchFunctionSubclass(): ...
|
||||
|
||||
# Defined in torch/csrc/utils/tensor_layouts.cpp
|
||||
strided: layout = ...
|
||||
|
||||
@ -52,7 +52,26 @@ __all__ = [
|
||||
"MemRecordsAcc",
|
||||
]
|
||||
|
||||
from contextlib import ContextDecorator
|
||||
try:
|
||||
# Available in Python >= 3.2
|
||||
from contextlib import ContextDecorator as _ContextDecorator
|
||||
except ImportError:
|
||||
import functools
|
||||
|
||||
class _ContextDecorator: # type: ignore[no-redef]
|
||||
def __enter__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, func):
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# global python state - whether profiler is currently enabled
|
||||
@ -209,12 +228,12 @@ class profile:
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.use_device: str | None = "cuda"
|
||||
self.use_device: Optional[str] = "cuda"
|
||||
else:
|
||||
self.use_device = use_device
|
||||
# TODO Consider changing _function_events into data structure with size cap
|
||||
self._function_events: EventList | None = None
|
||||
self._old_function_events: EventList | None = None
|
||||
self._function_events: Optional[EventList] = None
|
||||
self._old_function_events: Optional[EventList] = None
|
||||
# Function event processing is done lazily
|
||||
self._needs_processing = False
|
||||
self.entered = False
|
||||
@ -229,7 +248,7 @@ class profile:
|
||||
if experimental_config is None:
|
||||
experimental_config = _ExperimentalConfig()
|
||||
self.experimental_config = experimental_config
|
||||
self.kineto_results: _ProfilerResult | None = None
|
||||
self.kineto_results: Optional[_ProfilerResult] = None
|
||||
self.profiling_start_time_ns = 0
|
||||
self.profiling_end_time_ns = 0
|
||||
self._stats = _ProfilerStats()
|
||||
@ -725,7 +744,8 @@ class profile:
|
||||
return all_function_events
|
||||
|
||||
|
||||
class record_function(ContextDecorator):
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
class record_function(_ContextDecorator):
|
||||
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
|
||||
Label will only appear if CPU activity tracing is enabled.
|
||||
|
||||
@ -764,13 +784,16 @@ class record_function(ContextDecorator):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, args: str | None = None):
|
||||
def __init__(self, name: str, args: Optional[str] = None):
|
||||
self.name: str = name
|
||||
self.args: str | None = args
|
||||
self.args: Optional[str] = args
|
||||
# Whether or not we should run record function's end callbacks when exiting.
|
||||
self.run_callbacks_on_exit: bool = True
|
||||
# TODO: TorchScript ignores standard type annotation here
|
||||
# self.record: Optional["torch.classes.profiler._RecordFunction"] = None
|
||||
self.record = torch.jit.annotate(
|
||||
Optional[torch.classes.profiler._RecordFunction],
|
||||
# pyrefly: ignore [not-a-type]
|
||||
Optional["torch.classes.profiler._RecordFunction"],
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@ -308,12 +308,6 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const {
|
||||
if (auto custom_class_type = getCustomClass(*name)) {
|
||||
return custom_class_type;
|
||||
}
|
||||
// Check if the type is a custom class. This is done by checking
|
||||
// if type_name starts with "torch.classes."
|
||||
if (name->find("torch.classes.") == 0) {
|
||||
auto custom_class_type = getCustomClass("__torch__." + *name);
|
||||
return custom_class_type;
|
||||
}
|
||||
|
||||
throw ErrorReport(expr) << "Unknown type name '" << *name << "'";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user