Revert "Remove python workaround for ContextDecorator (#167049)"

This reverts commit e20ca3bc2e6ef9935c782fe548348f81fabc5bd7.

Reverted https://github.com/pytorch/pytorch/pull/167049 on behalf of https://github.com/jeanschmidt because it breaks internal tests (see D87120562). @Skylion007, please help the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/167049#issuecomment-3542847796))
commit 39ebab1dd9
parent 4c152a71ad
Author: PyTorch MergeBot
Date:   2025-11-17 16:41:26 +00:00

3 changed files with 34 additions and 18 deletions

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in

@@ -4,7 +4,6 @@
 # ruff: noqa: F401
 from collections.abc import Callable, Iterable, Iterator, Sequence
-from contextlib import AbstractContextManager
 from enum import Enum, IntEnum
 from pathlib import Path
 from types import EllipsisType
@@ -231,8 +230,8 @@ ${dtype_class_hints}
 class layout: ...
 
 # Defined in torch/csrc/utils/disable_torch_function.cpp
-def DisableTorchFunction() -> AbstractContextManager: ...
-def DisableTorchFunctionSubclass() -> AbstractContextManager: ...
+def DisableTorchFunction(): ...
+def DisableTorchFunctionSubclass(): ...
 
 # Defined in torch/csrc/utils/tensor_layouts.cpp
 strided: layout = ...
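For context: this stub change is typing-only; at runtime both functions still return context managers. A minimal sketch (not part of this diff) of how `DisableTorchFunction` is typically used, assuming a `Tensor` subclass with a `__torch_function__` hook:

import torch

class MyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Custom dispatch hook; defers to the default implementation.
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.ones(2).as_subclass(MyTensor)

# Inside the block, __torch_function__ overrides are skipped and ops
# see plain Tensors, so the hook above is not invoked.
with torch._C.DisableTorchFunction():
    out = torch.add(t, t)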

diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py

@@ -52,7 +52,26 @@ __all__ = [
     "MemRecordsAcc",
 ]
 
-from contextlib import ContextDecorator
+try:
+    # Available in Python >= 3.2
+    from contextlib import ContextDecorator as _ContextDecorator
+except ImportError:
+    import functools
+
+    class _ContextDecorator:  # type: ignore[no-redef]
+        def __enter__(self):
+            raise NotImplementedError
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            raise NotImplementedError
+
+        def __call__(self, func):
+            @functools.wraps(func)
+            def wrapped(*args, **kwargs):
+                with self:
+                    return func(*args, **kwargs)
+
+            return wrapped
 
 
 # global python state - whether profiler is currently enabled
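The restored fallback only matters on interpreters lacking `contextlib.ContextDecorator`; its `__call__` is what lets a context manager double as a decorator. A minimal sketch of the pattern it enables (hypothetical `timed` class, not from this commit):

import functools
import time

class _ContextDecorator:
    # Same shape as the fallback above: wrap the target in `with self:`.
    def __call__(self, func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapped

class timed(_ContextDecorator):
    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"took {time.perf_counter() - self.start:.6f}s")

@timed()               # decorator form: each call runs inside __enter__/__exit__
def work():
    return sum(range(1000))

with timed():          # context-manager form works unchanged
    work()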
@@ -209,12 +228,12 @@ class profile:
                 FutureWarning,
                 stacklevel=2,
             )
-            self.use_device: str | None = "cuda"
+            self.use_device: Optional[str] = "cuda"
         else:
             self.use_device = use_device
         # TODO Consider changing _function_events into data structure with size cap
-        self._function_events: EventList | None = None
-        self._old_function_events: EventList | None = None
+        self._function_events: Optional[EventList] = None
+        self._old_function_events: Optional[EventList] = None
         # Function event processing is done lazily
         self._needs_processing = False
         self.entered = False
@@ -229,7 +248,7 @@ class profile:
         if experimental_config is None:
             experimental_config = _ExperimentalConfig()
         self.experimental_config = experimental_config
-        self.kineto_results: _ProfilerResult | None = None
+        self.kineto_results: Optional[_ProfilerResult] = None
         self.profiling_start_time_ns = 0
         self.profiling_end_time_ns = 0
         self._stats = _ProfilerStats()
@@ -725,7 +744,8 @@ class profile:
         return all_function_events
 
 
-class record_function(ContextDecorator):
+# pyrefly: ignore [invalid-inheritance]
+class record_function(_ContextDecorator):
     """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
 
     Label will only appear if CPU activity tracing is enabled.
@@ -764,13 +784,16 @@
     """
 
-    def __init__(self, name: str, args: str | None = None):
+    def __init__(self, name: str, args: Optional[str] = None):
         self.name: str = name
-        self.args: str | None = args
+        self.args: Optional[str] = args
         # Whether or not we should run record function's end callbacks when exiting.
         self.run_callbacks_on_exit: bool = True
+        # TODO: TorchScript ignores standard type annotation here
+        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
         self.record = torch.jit.annotate(
-            Optional[torch.classes.profiler._RecordFunction],
+            # pyrefly: ignore [not-a-type]
+            Optional["torch.classes.profiler._RecordFunction"],
             None,
         )
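For reference, the class this hunk touches is used both ways thanks to `_ContextDecorator`. A short usage sketch with the standard profiler API (not part of this diff):

import torch
from torch.autograd.profiler import record_function
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU]) as prof:
    # record_function works as a context manager and, via
    # _ContextDecorator.__call__, as a function decorator.
    with record_function("my_block"):
        x = torch.randn(8, 8)
        y = x @ x

print(prof.key_averages().table(sort_by="cpu_time_total"))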

diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp
--- a/torch/csrc/jit/frontend/script_type_parser.cpp
+++ b/torch/csrc/jit/frontend/script_type_parser.cpp

@@ -308,12 +308,6 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const {
     if (auto custom_class_type = getCustomClass(*name)) {
      return custom_class_type;
    }
-    // Check if the type is a custom class. This is done by checking
-    // if type_name starts with "torch.classes."
-    if (name->find("torch.classes.") == 0) {
-      auto custom_class_type = getCustomClass("__torch__." + *name);
-      return custom_class_type;
-    }
 
     throw ErrorReport(expr) << "Unknown type name '" << *name << "'";
   }
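The deleted block (added by the reverted PR) resolved a bare `torch.classes.<ns>.<cls>` annotation by retrying the custom-class lookup under its internal qualified name, which carries a `__torch__.` prefix; without it, the string-annotation workaround in profiler.py is needed again. A hedged sketch of how such names arise (hypothetical library path and class names):

import torch

# A custom class registered from C++ roughly as:
#   torch::class_<MyStack>("my_ns", "MyStack").def(torch::init<>());
# Python exposes it as torch.classes.my_ns.MyStack, while TorchScript
# records it as "__torch__.torch.classes.my_ns.MyStack" -- the prefix
# the deleted fallback prepended before calling getCustomClass().
torch.classes.load_library("libmy_classes.so")  # hypothetical path
stack = torch.classes.my_ns.MyStack()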