mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[Dynamo] Remove ignored modes from torch function mode stack guard (#135503)"
This reverts commit c56728b643e2b7d796abd7ec45803319e1c5967d. Reverted https://github.com/pytorch/pytorch/pull/135503 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378))
This commit is contained in:
@ -67,7 +67,7 @@ class GuardManager:
|
||||
) -> None: ...
|
||||
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_torch_function_mode_stack_guard(
|
||||
self, initial_stack, verbose_code_parts: list[str]
|
||||
self, initial_stack, ignored_types, verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
|
||||
class RootGuardManager(GuardManager):
|
||||
|
@ -2350,6 +2350,7 @@ class CheckFunctionManager:
|
||||
|
||||
self.guard_manager.root.add_torch_function_mode_stack_guard(
|
||||
self.torch_function_mode_stack,
|
||||
list(),
|
||||
["___check_torch_function_mode_stack()"],
|
||||
)
|
||||
# Clear references to torch_function modes held in the list
|
||||
|
@ -2515,40 +2515,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
|
||||
public:
|
||||
TORCH_FUNCTION_MODE_STACK(
|
||||
const py::list& initial_stack,
|
||||
const py::list& ignored_types,
|
||||
py::object verbose_code_parts)
|
||||
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
|
||||
: LeafGuard(std::move(verbose_code_parts)),
|
||||
_ref_stack(),
|
||||
_ignored_types() {
|
||||
Py_ssize_t len = PyList_Size(initial_stack.ptr());
|
||||
for (Py_ssize_t idx = 0; idx < len; idx++) {
|
||||
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
|
||||
auto type = Py_TYPE(mode);
|
||||
this->_ref_stack.push_back(type);
|
||||
}
|
||||
|
||||
len = PyList_Size(ignored_types.ptr());
|
||||
for (Py_ssize_t idx = 0; idx < len; idx++) {
|
||||
PyObject* type_obj =
|
||||
PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
|
||||
if (PyType_Check(type_obj) == 0) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError, "ignored_types should contain a list of types");
|
||||
return;
|
||||
}
|
||||
PyTypeObject* type = (PyTypeObject*)type_obj;
|
||||
this->_ignored_types.insert(type);
|
||||
}
|
||||
}
|
||||
|
||||
bool check_nopybind(PyObject* value) override {
|
||||
// Ignore value arg, only used to satisfy the interface
|
||||
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
|
||||
size_t ref_ind = 0;
|
||||
const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
|
||||
const size_t ref_stack_size = this->_ref_stack.size();
|
||||
|
||||
if (len != ref_stack_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int64_t idx = 0; (size_t)idx < len; idx++) {
|
||||
int64_t idx = 0;
|
||||
while ((idx < len) && (ref_ind < ref_stack_size)) {
|
||||
std::shared_ptr<c10::SafePyObject> mode =
|
||||
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
|
||||
|
||||
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
|
||||
if (mode_type != _ref_stack.at(idx)) {
|
||||
bool act_ignored = this->_ignored_types.count(mode_type) > 0;
|
||||
bool ref_ignored =
|
||||
this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
|
||||
// skip ignored types
|
||||
if (act_ignored && ref_ignored) {
|
||||
idx++;
|
||||
ref_ind++;
|
||||
continue;
|
||||
} else if (ref_ignored) {
|
||||
ref_ind++;
|
||||
continue;
|
||||
} else if (act_ignored) {
|
||||
idx++;
|
||||
continue;
|
||||
}
|
||||
// if we already have more non-ignored modes than the ref stack
|
||||
// or if the mode doesn't match at the current index, return false
|
||||
else if (mode_type != _ref_stack.at(ref_ind)) {
|
||||
return false;
|
||||
}
|
||||
ref_ind++;
|
||||
idx++;
|
||||
}
|
||||
|
||||
for (; ref_ind < ref_stack_size; ref_ind++) {
|
||||
if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
for (; idx < len; idx++) {
|
||||
std::shared_ptr<c10::SafePyObject> mode =
|
||||
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
|
||||
|
||||
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
|
||||
if (!(this->_ignored_types.count(mode_type) > 0)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return ref_ind == ref_stack_size && idx == len;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<PyTypeObject*> _ref_stack;
|
||||
std::set<PyTypeObject*> _ignored_types;
|
||||
};
|
||||
|
||||
class TENSOR_MATCH : public LeafGuard {
|
||||
@ -3713,7 +3763,7 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
LeafGuard,
|
||||
std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
|
||||
py_m, "TORCH_FUNCTION_MODE_STACK")
|
||||
.def(py::init<py::list, py::list>())
|
||||
.def(py::init<py::list, py::list, py::list>())
|
||||
.def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
|
||||
py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
|
||||
py_m, "DATA_PTR_MATCH")
|
||||
@ -3950,9 +4000,10 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
"add_torch_function_mode_stack_guard",
|
||||
[](GuardManager& self,
|
||||
const py::list& initial_stack,
|
||||
const py::list& ignored_types,
|
||||
py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
|
||||
initial_stack, std::move(verbose_code_parts)));
|
||||
initial_stack, ignored_types, std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_data_ptr_guard",
|
||||
|
Reference in New Issue
Block a user