Files
pytorch/torch/_inductor/freezing_utils.py
Aaron Orenstein db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00

56 lines
1.2 KiB
Python

import contextlib
import threading
from collections.abc import Generator
from typing import Any
import torch
# Thread-local storage holding the ``freezing_active`` flag. Because it is a
# threading.local, each thread sees its own flag: enter_freezing() sets it and
# _freezing_active() reads it, and threads that never set it have no attribute.
_TLS = threading.local()
def _freezing_active() -> bool:
    """
    Report whether freezing is currently active on the calling thread.

    A thread that has never entered ``enter_freezing()`` has no
    ``freezing_active`` attribute on ``_TLS``; that case reads as False.
    """
    try:
        return _TLS.freezing_active
    except AttributeError:
        return False
@contextlib.contextmanager
def enter_freezing() -> Generator[Any, None, None]:
    """
    Context manager that marks freezing as active for the enclosed scope.

    The thread-local flag is set to True on entry. On exit it is restored to
    whatever value it held before entry, even if the body raises. This keeps
    nested scopes correct.
    """
    saved_state = _freezing_active()
    _TLS.freezing_active = True
    try:
        yield
    finally:
        # Restore rather than clear, so an enclosing scope stays active.
        _TLS.freezing_active = saved_state
def record_has_frozen_params(gm: torch.fx.GraphModule) -> None:
    """
    Tag ``gm`` so that later ``has_frozen_params(gm)`` checks return True.

    The tag is stored as the ``_has_frozen_params`` attribute on the module.
    """
    setattr(gm, "_has_frozen_params", True)
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
    """
    Tell whether ``gm`` has been tagged as having frozen parameters.

    Returns the stored ``_has_frozen_params`` attribute, or False when the
    attribute is absent.
    """
    marker_attr = "_has_frozen_params"
    return getattr(gm, marker_attr, False)
def maybe_set_is_frozen_param(t: torch.Tensor) -> None:
    """
    Tag ``t`` as a frozen parameter, but only while freezing is active.

    Outside an ``enter_freezing()`` scope this is a no-op.
    """
    if not _freezing_active():
        return
    setattr(t, "_is_frozen_param", True)
def is_frozen_param(t: torch.Tensor) -> bool:
    """
    Tell whether ``t`` has been tagged as a frozen parameter.

    Returns the stored ``_is_frozen_param`` attribute, or False when the
    attribute is absent.
    """
    marker_attr = "_is_frozen_param"
    return getattr(t, marker_attr, False)