mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix compilation and "import torch" issues for cpython 3.14 (#158184)
Beginning of process for 3.14 bringup. State of things from this PR: - Nothing too scary looking from the Dynamo CPython side, nothing we heavily rely on seems to be missing @williamwen42 - The existing check that makes torch.compile() nicely fail is working as expected. So all these empty functions shouldn't cause any weirdness. - The `__module__` update changes look suspicious, we should investigate what is the reason and impact of that, in particular for our public API checking @jbschlosser - Leaving the weakref.py thread safety change as a follow up to keep this a bit simpler. I vendored the whole struct in the meantime FYI @ezyang EDIT: The `__module__` change is even more cursed than I though due to changes to Union and Optional type where the `__module__` field cannot be changed anymore. See https://github.com/python/cpython/issues/132139 for details. For now, I'm just skipping the `__module__` setting for 3.14 which will trip the public API checks. Will revisit once I have a final answer on the cpython issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158184 Approved by: https://github.com/msaroufim
This commit is contained in:
@ -33,7 +33,7 @@ if sys.version_info >= (3, 11):
|
||||
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
|
||||
else:
|
||||
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
|
||||
if sys.version_info >= (3, 12):
|
||||
if (3, 12) <= sys.version_info < (3, 14):
|
||||
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
|
||||
if sys.version_info >= (3, 13):
|
||||
TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"])
|
||||
|
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -33,7 +34,9 @@ from .stubs import * # noqa: F403
|
||||
|
||||
# ensure __module__ is set correctly for public APIs
|
||||
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
|
||||
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
|
||||
if sys.version_info < (3, 14):
|
||||
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
|
||||
|
||||
for _f in [
|
||||
compare_results,
|
||||
extract_results_from_loggers,
|
||||
|
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from typing import Any, Optional, Union
|
||||
@ -568,7 +569,8 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
|
||||
|
||||
|
||||
QConfigAny = Optional[QConfig]
|
||||
QConfigAny.__module__ = "torch.ao.quantization.qconfig"
|
||||
if sys.version_info < (3, 14):
|
||||
QConfigAny.__module__ = "torch.ao.quantization.qconfig"
|
||||
|
||||
|
||||
def _add_module_to_qconfig_obs_ctr(
|
||||
|
@ -4,6 +4,7 @@ Utils shared by different modes of quantization (eager/graph)
|
||||
"""
|
||||
|
||||
import functools
|
||||
import sys
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from inspect import getfullargspec, signature
|
||||
@ -16,7 +17,8 @@ from torch.nn.utils.parametrize import is_parametrized
|
||||
|
||||
|
||||
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
|
||||
NodePattern.__module__ = "torch.ao.quantization.utils"
|
||||
if sys.version_info < (3, 14):
|
||||
NodePattern.__module__ = "torch.ao.quantization.utils"
|
||||
|
||||
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
|
||||
# Define separately to prevent circular imports.
|
||||
@ -31,7 +33,8 @@ QuantizerCls = Any
|
||||
Pattern = Union[
|
||||
Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any
|
||||
]
|
||||
Pattern.__module__ = "torch.ao.quantization.utils"
|
||||
if sys.version_info < (3, 14):
|
||||
Pattern.__module__ = "torch.ao.quantization.utils"
|
||||
|
||||
|
||||
# TODO: maybe rename this to MatchInputNode
|
||||
|
@ -2,6 +2,20 @@
|
||||
#include <torch/csrc/dynamo/cpython_includes.h>
|
||||
#include <torch/csrc/dynamo/debug_macros.h>
|
||||
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
|
||||
const uint8_t* THP_PyOpcode_Caches = NULL;
|
||||
const int THP_PyOpcode_Caches_size = 0;
|
||||
|
||||
void
|
||||
THP_PyThreadState_PopFrame(PyThreadState *tstate, _PyInterpreterFrame * frame)
|
||||
{}
|
||||
void
|
||||
THP_PyFrame_Clear(_PyInterpreterFrame *frame)
|
||||
{}
|
||||
|
||||
#else
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
|
||||
#define Py_BUILD_CORE
|
||||
@ -360,3 +374,5 @@ const uint8_t* THP_PyOpcode_Caches = NULL;
|
||||
const int THP_PyOpcode_Caches_size = 0;
|
||||
|
||||
#endif
|
||||
|
||||
#endif // IS_PYTHON_3_14_PLUS
|
@ -21,6 +21,14 @@
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
#include <internal/pycore_frame.h>
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
#include <internal/pycore_interpframe_structs.h>
|
||||
#include <internal/pycore_stackref.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
#include <internal/pycore_code.h>
|
||||
#endif
|
||||
|
||||
#undef Py_BUILD_CORE
|
||||
@ -30,6 +38,13 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
|
||||
#define F_CODE(x) (PyCodeObject*)PyStackRef_AsPyObjectBorrow(x->f_executable)
|
||||
#define PREV_INSTR(x) (x)->instr_ptr
|
||||
|
||||
#else
|
||||
|
||||
#if IS_PYTHON_3_13_PLUS
|
||||
#define F_CODE(x) ((PyCodeObject*)(x)->f_executable)
|
||||
#define PREV_INSTR(x) (x)->instr_ptr
|
||||
@ -38,6 +53,8 @@ extern "C" {
|
||||
#define PREV_INSTR(x) (x)->prev_instr
|
||||
#endif
|
||||
|
||||
#endif // IS_PYTHON_3_14_PLUS
|
||||
|
||||
#if IS_PYTHON_3_12_PLUS
|
||||
#define FUNC(x) ((x)->f_funcobj)
|
||||
#else
|
||||
|
@ -224,17 +224,6 @@ const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
|
||||
return PyUnicode_AsUTF8(F_CODE(frame)->co_name);
|
||||
}
|
||||
|
||||
void clear_old_frame_if_python_312_plus(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame) {
|
||||
#if IS_PYTHON_3_12_PLUS
|
||||
|
||||
THP_PyFrame_Clear(frame);
|
||||
THP_PyThreadState_PopFrame(tstate, frame);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyObject* dynamo_eval_custom_code_impl(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
@ -485,6 +474,18 @@ static PyObject* dynamo__custom_eval_frame_shim(
|
||||
|
||||
static void enable_eval_frame_shim(PyThreadState* tstate) {}
|
||||
static void enable_eval_frame_default(PyThreadState* tstate) {}
|
||||
PyObject* dynamo_eval_custom_code(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
PyCodeObject* code,
|
||||
const char* trace_annotation,
|
||||
int throw_flag) {}
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(
|
||||
THP_EVAL_API_FRAME_OBJECT* frame) {}
|
||||
PyObject* dynamo_eval_frame_default(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
int throw_flag) {}
|
||||
|
||||
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL};
|
||||
|
||||
@ -498,6 +499,17 @@ static PyTypeObject THPPyInterpreterFrameType = {
|
||||
|
||||
#endif // !(IS_PYTHON_3_14_PLUS)
|
||||
|
||||
void clear_old_frame_if_python_312_plus(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame) {
|
||||
#if IS_PYTHON_3_12_PLUS
|
||||
|
||||
THP_PyFrame_Clear(frame);
|
||||
THP_PyThreadState_PopFrame(tstate, frame);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
static PyObject* increment_working_threads(
|
||||
PyThreadState* tstate,
|
||||
PyObject* module) {
|
||||
|
@ -26,9 +26,13 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame)
|
||||
PyCodeObject* co = F_CODE(frame);
|
||||
_framelocals.resize(co->co_nlocalsplus, nullptr);
|
||||
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
TORCH_CHECK(false, "Python 3.14+ not supported");
|
||||
#else
|
||||
if (!frame->stacktop) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
auto update_framelocals = [&](int i, PyObject* value) {
|
||||
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
|
||||
@ -53,11 +57,21 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame)
|
||||
};
|
||||
|
||||
auto offset = co->co_nlocalsplus - co->co_nfreevars;
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
TORCH_CHECK(false, "Python 3.14+ not supported");
|
||||
#else
|
||||
for (int i = 0; i < offset; i++) {
|
||||
update_framelocals(i, frame->localsplus[i]);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Get references to closure variables
|
||||
#if IS_PYTHON_3_14_PLUS
|
||||
PyObject* closure;
|
||||
TORCH_CHECK(false, "Python 3.14+ not supported");
|
||||
#else
|
||||
PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure;
|
||||
#endif
|
||||
for (int i = 0; i < co->co_nfreevars; i++) {
|
||||
update_framelocals(offset + i, PyTuple_GET_ITEM(closure, i));
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ extern "C" {
|
||||
#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000
|
||||
#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000
|
||||
#define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000
|
||||
#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000
|
||||
|
||||
static inline int PyCode_GetNCellvars(PyCodeObject* code) {
|
||||
// gh-26364 added co_ncellvars to Python 3.11.0rc1
|
||||
|
@ -100,7 +100,6 @@ ONNXProgram.__module__ = "torch.onnx"
|
||||
OnnxExporterError.__module__ = "torch.onnx"
|
||||
_OrtBackend.__module__ = "torch.onnx"
|
||||
_OrtBackendOptions.__module__ = "torch.onnx"
|
||||
_OrtExecutionProvider.__module__ = "torch.onnx"
|
||||
enable_fake_mode.__module__ = "torch.onnx"
|
||||
is_onnxrt_backend_supported.__module__ = "torch.onnx"
|
||||
|
||||
|
@ -3,8 +3,6 @@ from __future__ import annotations
|
||||
|
||||
import collections.abc as _collections_abc
|
||||
import weakref
|
||||
|
||||
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from weakref import ref
|
||||
|
||||
@ -22,6 +20,33 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
# TODO: make weakref properly thread safe following
|
||||
# https://github.com/python/cpython/pull/125325
|
||||
class _IterationGuard:
|
||||
# This context manager registers itself in the current iterators of the
|
||||
# weak container, such as to delay all removals until the context manager
|
||||
# exits.
|
||||
# This technique should be relatively thread-safe (since sets are).
|
||||
|
||||
def __init__(self, weakcontainer):
|
||||
# Don't create cycles
|
||||
self.weakcontainer = ref(weakcontainer)
|
||||
|
||||
def __enter__(self):
|
||||
w = self.weakcontainer()
|
||||
if w is not None:
|
||||
w._iterating.add(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, e, t, b):
|
||||
w = self.weakcontainer()
|
||||
if w is not None:
|
||||
s = w._iterating
|
||||
s.remove(self)
|
||||
if not s:
|
||||
w._commit_removals()
|
||||
|
||||
|
||||
# This file defines a variant of WeakKeyDictionary that overrides the hashing
|
||||
# behavior of the key to use object identity, rather than the builtin
|
||||
# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
|
||||
|
Reference in New Issue
Block a user