Compare commits

...

12 Commits

Author SHA1 Message Date
50ea044f8a Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 13:37:01 -08:00
257bf8e59e Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 13:37:01 -08:00
be71654b78 Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:41:08 -08:00
fa0c57142a Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:41:08 -08:00
b737df7704 Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:38:22 -08:00
fb09741981 Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:38:22 -08:00
5230b7c0ac Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-17 17:44:34 -08:00
881cd1c6f4 Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-17 17:44:34 -08:00
c4f3d7d410 [MPS] remove expected failure for a test (#167922)
remove expected failure for a test for MPS backend, but lower the precision to `1e-4`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167922
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-17 22:58:13 +00:00
2fdd517fd9 Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-14 17:37:25 -08:00
43fe667181 Update on "[dynamo] add set_c_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-14 17:34:27 -08:00
d8ebe3543d [dynamo] add set_c_recursion_limit to fix 3.12/3.13 RecursionError problems
[ghstack-poisoned]
2025-11-14 15:55:54 -08:00
7 changed files with 183 additions and 2 deletions

View File

@ -7456,6 +7456,97 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
msg,
)
def test_dynamo_set_recursion_limit_simple(self):
    # torch._dynamo.set_recursion_limit is expected to also raise the Python
    # recursion limit (sys.setrecursionlimit) on every supported Python version.
    saved_sys_limit = sys.getrecursionlimit()
    saved_dynamo_limit = torch._dynamo.get_recursion_limit()
    try:

        def fn(x, n):
            if n == 0:
                return x
            return fn(x, n - 1) + 1

        # With a tiny Python limit the plain recursive call must overflow.
        sys.setrecursionlimit(100)
        with self.assertRaises(RecursionError):
            fn(torch.ones(3), 1000)

        opt_fn = torch.compile(fn, backend="eager", dynamic=False)

        # Raising the limit only through Dynamo must be enough for both the
        # eager call and the compiled call to run to completion.
        torch._dynamo.set_recursion_limit(100000)
        self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000))
    finally:
        # Only restore the Dynamo limit when one was recorded before the test;
        # the Python limit is always put back last.
        if saved_dynamo_limit > 0:
            torch._dynamo.set_recursion_limit(saved_dynamo_limit)
        sys.setrecursionlimit(saved_sys_limit)
@unittest.skipIf(
    sys.version_info < (3, 12) or sys.version_info >= (3, 14),
    "only 3.12, 3.13 affected by c recursion limit",
)
def test_dynamo_set_recursion_limit(self):
    # Checks that raising the Python limit alone does not let the compiled
    # call succeed, while torch._dynamo.set_recursion_limit does.
    saved_sys_limit = sys.getrecursionlimit()
    saved_dynamo_limit = torch._dynamo.get_recursion_limit()
    try:

        def fn(x, n):
            if n == 0:
                return x
            return fn(x, n - 1) + 1

        # Eager baseline: fails with a limit of 100, succeeds with 2000.
        sys.setrecursionlimit(100)
        with self.assertRaises(RecursionError):
            fn(torch.ones(3), 1000)

        sys.setrecursionlimit(2000)
        fn(torch.ones(3), 1000)

        opt_fn = torch.compile(fn, backend="eager", dynamic=False)

        # A large Python-level limit on its own is expected to still fail
        # when going through the compiled function.
        sys.setrecursionlimit(100000)
        with self.assertRaises(Exception):
            opt_fn(torch.ones(3), 1000)

        # Once the Dynamo limit is raised, eager and compiled results agree.
        torch._dynamo.set_recursion_limit(100000)
        self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000))
    finally:
        # Only restore the Dynamo limit when one was recorded before the test.
        if saved_dynamo_limit > 0:
            torch._dynamo.set_recursion_limit(saved_dynamo_limit)
        sys.setrecursionlimit(saved_sys_limit)
@unittest.skipIf(
    sys.version_info < (3, 12) or sys.version_info >= (3, 14),
    "only 3.12, 3.13 affected by c recursion limit",
)
def test_dynamo_set_recursion_limit_usage(self):
    # Exercises the getter/setter round trip, rejection of an invalid limit,
    # and the error raised when the configured limit is too low to compile.
    saved_sys_limit = sys.getrecursionlimit()
    saved_dynamo_limit = torch._dynamo.get_recursion_limit()
    try:
        # A valid limit is stored and reported back unchanged.
        torch._dynamo.set_recursion_limit(100)
        self.assertEqual(torch._dynamo.get_recursion_limit(), 100)

        # An invalid limit is rejected and leaves the stored value intact.
        with self.assertRaisesRegex(ValueError, "recursion limit"):
            torch._dynamo.set_recursion_limit(0)
        self.assertEqual(torch._dynamo.get_recursion_limit(), 100)

        # Set the smallest accepted Dynamo limit, then put the Python limit
        # back to a usable value before compiling.
        torch._dynamo.set_recursion_limit(1)
        sys.setrecursionlimit(100)

        @torch.compile(backend="eager", dynamic=False)
        def fn(x, n):
            if n == 0:
                return x
            return fn(x, n - 1) + 1

        with self.assertRaisesRegex(RuntimeError, "new c_recursion limit"):
            fn(torch.ones(3), 5)
    finally:
        # Only restore the Dynamo limit when one was recorded before the test.
        if saved_dynamo_limit > 0:
            torch._dynamo.set_recursion_limit(saved_dynamo_limit)
        sys.setrecursionlimit(saved_sys_limit)
@expectedFailureDynamic
def test_dynamo_default_lru_cache_behavior(self):
@torch.compile(backend="eager")

View File

@ -630,7 +630,6 @@ class TestSparse(TestSparseBase):
i[0][0] = 0
self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@ -647,7 +646,8 @@ class TestSparse(TestSparseBase):
def fn(x):
return x.to_dense(masked_grad=gradcheck.masked)
x.requires_grad_(True)
gradcheck(fn, (x,))
kwargs = {"eps": 1e-4} if device == "mps:0" else {}
gradcheck(fn, (x,), **kwargs)
i = self.index_tensor([
[0, 1, 2, 2],

View File

@ -19,6 +19,8 @@ def set_guard_complete_hook(
hook: Optional[DynamoGuardCompleteHook],
) -> Optional[DynamoGuardCompleteHook]: ...
def raise_sigtrap() -> None: ...
def set_c_recursion_limit(limit: int) -> None: ...
def get_c_recursion_limit() -> int: ...
class _CacheEntry:
def check_fn(self, *args: object, **kwargs: object) -> bool: ...

View File

@ -105,6 +105,7 @@ __all__ = [
"reset",
"run",
"error_on_graph_break",
"set_recursion_limit",
"set_stance",
"skip_frame",
"step_unsupported",
@ -181,3 +182,26 @@ def reset_code_caches() -> None:
if code:
reset_code(code)
code_context.clear()
def get_recursion_limit() -> int:
    """
    Query the internal dynamo recursion limit that was configured through
    `torch._dynamo.set_recursion_limit`.

    Returns:
        The configured limit, or -1 if no c recursion limit has been set.
    """
    eval_frame = torch._C._dynamo.eval_frame
    return eval_frame.get_c_recursion_limit()
def set_recursion_limit(limit: int) -> None:
    """
    Configure the internal dynamo recursion limit.

    Python 3.12-3.13 have a separate C recursion limit that is not visible at
    the Python level. If Dynamo compilation raises RecursionError and
    `sys.setrecursionlimit()` does not help, calling this function may
    alleviate the issue.

    NOTE: this function will also call `sys.setrecursionlimit()`.

    Args:
        limit: the new recursion limit; must be >= 1.
    """
    eval_frame = torch._C._dynamo.eval_frame
    eval_frame.set_c_recursion_limit(limit)

View File

@ -50,6 +50,56 @@ static py::handle _callback_from_action(
return callback;
}
// c_recursion_remaining only defined in 3.12 and 3.13
// File-scope storage for the limit requested via set_c_recursion_limit();
// starts at -1, which get_c_recursion_limit() callers treat as "not set".
static int32_t c_recursion_limit = -1;

// Stores `limit` as the C recursion limit used by Dynamo and applies the same
// value as the Python recursion limit.
// Throws std::range_error when `limit` is smaller than 1.
void set_c_recursion_limit(int32_t limit) {
  const bool limit_is_valid = limit >= 1;
  if (!limit_is_valid) {
    throw std::range_error("recursion limit must be greater or equal than 1");
  }
  // Keep the Python limit in sync; this call cannot fail.
  Py_SetRecursionLimit(limit);
  c_recursion_limit = limit;
}
// Returns the current value of the file-scope c_recursion_limit variable,
// i.e. the last limit accepted by set_c_recursion_limit() or the variable's
// initial value if no limit has been stored yet.
int32_t get_c_recursion_limit() {
  return c_recursion_limit;
}
#if IS_PYTHON_3_12_PLUS && !IS_PYTHON_3_14_PLUS
// Scope guard that replaces the thread's c_recursion_remaining with the limit
// returned by get_c_recursion_limit() for the guard's lifetime and writes the
// previous value back on destruction.
//
// Construction outcomes:
//  - limit < 0 (nothing configured): the thread state is left unchanged.
//  - limit < current c_recursion_remaining: a Python RuntimeError is set and
//    the thread state is left unchanged. The constructor cannot report the
//    failure directly, so callers must check PyErr_Occurred() afterwards.
//  - otherwise: c_recursion_remaining is set to the limit.
struct CRecursionLimitRAII {
  PyThreadState* tstate;
  int32_t old_recursion_remaining;

  explicit CRecursionLimitRAII(PyThreadState* tstate)
      : tstate{tstate},
        old_recursion_remaining(tstate->c_recursion_remaining) {
    const auto limit = get_c_recursion_limit();
    if (limit < 0) {
      // no change to limit
      return;
    }
    if (limit < old_recursion_remaining) {
      PyErr_SetString(
          PyExc_RuntimeError,
          "new c_recursion limit is lower than thread's current c_recursion_remaining.");
      // Do not apply the rejected (lower) value: the caller's error handling
      // should run with the budget the thread already had.
      return;
    }
    tstate->c_recursion_remaining = limit;
  }

  // The guard restores state in its destructor, so a copy would restore the
  // same value twice; forbid copying (which also disables moving).
  CRecursionLimitRAII(const CRecursionLimitRAII&) = delete;
  CRecursionLimitRAII& operator=(const CRecursionLimitRAII&) = delete;

  ~CRecursionLimitRAII() {
    this->tstate->c_recursion_remaining = this->old_recursion_remaining;
  }
};
#else
// On other Python versions the guard does nothing.
struct CRecursionLimitRAII {
  explicit CRecursionLimitRAII(PyThreadState* /*tstate*/) {}

  CRecursionLimitRAII(const CRecursionLimitRAII&) = delete;
  CRecursionLimitRAII& operator=(const CRecursionLimitRAII&) = delete;
};
#endif
// frame and callback are borrowed references.
// Returns new reference.
PyObject* dynamo__custom_eval_frame(
@ -258,6 +308,13 @@ PyObject* dynamo__custom_eval_frame(
bool apply_to_code = false;
PyObject* guarded_code = nullptr;
try {
CRecursionLimitRAII tmp(tstate); // increase C recursion limit to the given
// value during compilation
// C recursion limit failure
if (PyErr_Occurred()) {
fail();
return eval_result;
}
callback_result = dynamo_call_callback(
callback, frame, locals.get(), cache_entry, frame_state);
new_strategy =

View File

@ -19,6 +19,9 @@ PyObject* dynamo__custom_eval_frame(
PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj);
void skip_code_recursive(PyCodeObject* code);
void set_c_recursion_limit(int32_t limit);
int32_t get_c_recursion_limit();
#ifdef __cplusplus
} // extern "C"

View File

@ -7,6 +7,7 @@
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/dynamo/eval_frame.h>
#include <torch/csrc/dynamo/eval_frame_cpp.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/dynamo/python_compiled_autograd.h>
@ -250,6 +251,9 @@ void initDynamoBindings(PyObject* torch) {
.def_readwrite("cur_action", &FrameExecStrategy::cur_action)
.def_readwrite("recursive_action", &FrameExecStrategy::recursive_action);
m.def("set_c_recursion_limit", &set_c_recursion_limit);
m.def("get_c_recursion_limit", &get_c_recursion_limit);
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
m.def("_reset_precompile_entries", &_reset_precompile_entries);
m.def("_load_precompile_entry", &_load_precompile_entry);