Revert "[Dynamo] Remove ignored modes from torch function mode stack guard (#135503)"

This reverts commit c56728b643e2b7d796abd7ec45803319e1c5967d.

Reverted https://github.com/pytorch/pytorch/pull/135503 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378))
This commit is contained in:
PyTorch MergeBot
2024-09-13 12:52:57 +00:00
parent 1cdf658f4a
commit dc71e7a7d4
3 changed files with 64 additions and 12 deletions

View File

@@ -67,7 +67,7 @@ class GuardManager:
) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
def add_torch_function_mode_stack_guard(
self, initial_stack, verbose_code_parts: list[str]
self, initial_stack, ignored_types, verbose_code_parts: list[str]
) -> None: ...
class RootGuardManager(GuardManager):

View File

@@ -2350,6 +2350,7 @@ class CheckFunctionManager:
self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack,
list(),
["___check_torch_function_mode_stack()"],
)
# Clear references to torch_function modes held in the list

View File

@@ -2515,40 +2515,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
public:
TORCH_FUNCTION_MODE_STACK(
const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
: LeafGuard(std::move(verbose_code_parts)),
_ref_stack(),
_ignored_types() {
Py_ssize_t len = PyList_Size(initial_stack.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
auto type = Py_TYPE(mode);
this->_ref_stack.push_back(type);
}
len = PyList_Size(ignored_types.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* type_obj =
PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
if (PyType_Check(type_obj) == 0) {
PyErr_SetString(
PyExc_TypeError, "ignored_types should contain a list of types");
return;
}
PyTypeObject* type = (PyTypeObject*)type_obj;
this->_ignored_types.insert(type);
}
}
bool check_nopybind(PyObject* value) override {
// Ignore value arg, only used to satisfy the interface
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
size_t ref_ind = 0;
const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
const size_t ref_stack_size = this->_ref_stack.size();
if (len != ref_stack_size) {
return false;
}
for (int64_t idx = 0; (size_t)idx < len; idx++) {
int64_t idx = 0;
while ((idx < len) && (ref_ind < ref_stack_size)) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
if (mode_type != _ref_stack.at(idx)) {
bool act_ignored = this->_ignored_types.count(mode_type) > 0;
bool ref_ignored =
this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
// skip ignored types
if (act_ignored && ref_ignored) {
idx++;
ref_ind++;
continue;
} else if (ref_ignored) {
ref_ind++;
continue;
} else if (act_ignored) {
idx++;
continue;
}
// if we already have more non-ignored modes than the ref stack
// or if the mode doesn't match at the current index, return false
else if (mode_type != _ref_stack.at(ref_ind)) {
return false;
}
ref_ind++;
idx++;
}
for (; ref_ind < ref_stack_size; ref_ind++) {
if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
return false;
}
}
return true;
for (; idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
if (!(this->_ignored_types.count(mode_type) > 0)) {
return false;
}
}
return ref_ind == ref_stack_size && idx == len;
}
private:
std::vector<PyTypeObject*> _ref_stack;
std::set<PyTypeObject*> _ignored_types;
};
class TENSOR_MATCH : public LeafGuard {
@@ -3713,7 +3763,7 @@ PyObject* torch_c_dynamo_guards_init() {
LeafGuard,
std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
py_m, "TORCH_FUNCTION_MODE_STACK")
.def(py::init<py::list, py::list>())
.def(py::init<py::list, py::list, py::list>())
.def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
py_m, "DATA_PTR_MATCH")
@@ -3950,9 +4000,10 @@ PyObject* torch_c_dynamo_guards_init() {
"add_torch_function_mode_stack_guard",
[](GuardManager& self,
const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
initial_stack, std::move(verbose_code_parts)));
initial_stack, ignored_types, std::move(verbose_code_parts)));
})
.def(
"add_data_ptr_guard",