Compare commits

...

13 Commits

Author SHA1 Message Date
967e7093a5 Update on "[dynamo, guards] apply functools.cached_property to Source.name"
Partial fix for https://github.com/pytorch/pytorch/issues/168118. Decreases guard build time from 25s -> 16s locally.

However, there are a lot of changes to `source.name()` callsites. This could technically be avoided by writing our own `cached_property` decorator that requires the function call.

cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-18 16:58:24 -08:00
8cc430db5d [dynamo, guards] apply functools.cached_property to Source.name
[ghstack-poisoned]
2025-11-18 16:45:09 -08:00
50ea044f8a Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 13:37:01 -08:00
257bf8e59e Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 13:37:01 -08:00
be71654b78 Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:41:08 -08:00
fa0c57142a Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:41:08 -08:00
b737df7704 Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:38:22 -08:00
fb09741981 Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-18 11:38:22 -08:00
5230b7c0ac Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-17 17:44:34 -08:00
881cd1c6f4 Update base for Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-17 17:44:34 -08:00
2fdd517fd9 Update on "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-14 17:37:25 -08:00
43fe667181 Update on "[dynamo] add set_c_recursion_limit to fix 3.12/3.13 RecursionError problems"
Fixes https://github.com/pytorch/pytorch/issues/167789

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela mlazos

[ghstack-poisoned]
2025-11-14 17:34:27 -08:00
d8ebe3543d [dynamo] add set_c_recursion_limit to fix 3.12/3.13 RecursionError problems
[ghstack-poisoned]
2025-11-14 15:55:54 -08:00
24 changed files with 380 additions and 150 deletions

View File

@ -928,8 +928,8 @@ class TypePropagationTests(torch._dynamo.test_case.TestCase):
foo_source = LocalSource("foo")
foo_x_source = AttrSource(foo_source, "x")
self.assertTrue(builder.get(foo_source.name()) is foo)
self.assertTrue(builder.get(foo_x_source.name()) is foo.x)
self.assertTrue(builder.get(foo_source.name) is foo)
self.assertTrue(builder.get(foo_x_source.name) is foo.x)
# Check types of foo.x
foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source)

View File

@ -7456,6 +7456,97 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
msg,
)
def test_dynamo_set_recursion_limit_simple(self):
# Test that torch._dynamo.set_recursion_limit calls sys.setrecursionlimit for all supported
# Python versions
old_recursion_limit = sys.getrecursionlimit()
old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit()
try:
def fn(x, n):
if n == 0:
return x
return fn(x, n - 1) + 1
sys.setrecursionlimit(100)
with self.assertRaises(RecursionError):
fn(torch.ones(3), 1000)
opt_fn = torch.compile(fn, backend="eager", dynamic=False)
torch._dynamo.set_recursion_limit(100000)
self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000))
finally:
if old_dynamo_recursion_limit > 0:
torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit)
sys.setrecursionlimit(old_recursion_limit)
@unittest.skipIf(
sys.version_info < (3, 12) or sys.version_info >= (3, 14),
"only 3.12, 3.13 affected by c recursion limit",
)
def test_dynamo_set_recursion_limit(self):
old_recursion_limit = sys.getrecursionlimit()
old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit()
try:
def fn(x, n):
if n == 0:
return x
return fn(x, n - 1) + 1
sys.setrecursionlimit(100)
with self.assertRaises(RecursionError):
fn(torch.ones(3), 1000)
sys.setrecursionlimit(2000)
fn(torch.ones(3), 1000)
opt_fn = torch.compile(fn, backend="eager", dynamic=False)
sys.setrecursionlimit(100000)
with self.assertRaises(Exception):
opt_fn(torch.ones(3), 1000)
torch._dynamo.set_recursion_limit(100000)
self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000))
finally:
if old_dynamo_recursion_limit > 0:
torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit)
sys.setrecursionlimit(old_recursion_limit)
@unittest.skipIf(
sys.version_info < (3, 12) or sys.version_info >= (3, 14),
"only 3.12, 3.13 affected by c recursion limit",
)
def test_dynamo_set_recursion_limit_usage(self):
old_recursion_limit = sys.getrecursionlimit()
old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit()
try:
torch._dynamo.set_recursion_limit(100)
self.assertEqual(torch._dynamo.get_recursion_limit(), 100)
with self.assertRaisesRegex(ValueError, "recursion limit"):
torch._dynamo.set_recursion_limit(0)
self.assertEqual(torch._dynamo.get_recursion_limit(), 100)
torch._dynamo.set_recursion_limit(1)
sys.setrecursionlimit(100)
@torch.compile(backend="eager", dynamic=False)
def fn(x, n):
if n == 0:
return x
return fn(x, n - 1) + 1
with self.assertRaisesRegex(RuntimeError, "new c_recursion limit"):
fn(torch.ones(3), 5)
finally:
if old_dynamo_recursion_limit > 0:
torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit)
sys.setrecursionlimit(old_recursion_limit)
@expectedFailureDynamic
def test_dynamo_default_lru_cache_behavior(self):
@torch.compile(backend="eager")

View File

@ -1624,7 +1624,7 @@ class GraphModule(torch.nn.Module):
str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
}
curr_var_to_sources = {
str(k): v[0].name()
str(k): v[0].name
for k, v in context.fake_mode.shape_env.var_to_sources.items()
}
return gm

View File

@ -19,6 +19,8 @@ def set_guard_complete_hook(
hook: Optional[DynamoGuardCompleteHook],
) -> Optional[DynamoGuardCompleteHook]: ...
def raise_sigtrap() -> None: ...
def set_c_recursion_limit(limit: int) -> None: ...
def get_c_recursion_limit() -> int: ...
class _CacheEntry:
def check_fn(self, *args: object, **kwargs: object) -> bool: ...

View File

@ -105,6 +105,7 @@ __all__ = [
"reset",
"run",
"error_on_graph_break",
"set_recursion_limit",
"set_stance",
"skip_frame",
"step_unsupported",
@ -181,3 +182,26 @@ def reset_code_caches() -> None:
if code:
reset_code(code)
code_context.clear()
def get_recursion_limit() -> int:
"""
Returns the internal dynamo recursion limit set by `torch._dynamo.set_recursion_limit`.
Returns -1 if no c recursion limit has been set.
"""
return torch._C._dynamo.eval_frame.get_c_recursion_limit()
def set_recursion_limit(limit: int) -> None:
"""
Sets an internal dynamo recursion limit. The limit must be >= 1.
This is possibly needed in Python 3.12-3.13 since there is a separate C recursion limit
that is not visible at the Python level. If you are getting RecursionErrors during
Dynamo compilation and `sys.setrecursionlimit()` doesn't help, this function may alleviate
the issue.
NOTE: this function will also call `sys.setrecursionlimit()`.
"""
torch._C._dynamo.eval_frame.set_c_recursion_limit(limit)

View File

@ -1707,13 +1707,13 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
stack = s
break
if stack is None:
msg = f"{source.name()}, a closed over free variable"
msg = f"{source.name}, a closed over free variable"
else:
tb = "".join(traceback.format_list(stack))
extra = ""
if len(user_stacks) > 1:
extra = f"(elided {len(user_stacks) - 1} more accesses)"
msg = f"{source.name()}, accessed at:\n{tb}{extra}"
msg = f"{source.name}, accessed at:\n{tb}{extra}"
# TODO: option to print ALL of the stack traces at once
input_errors.append(msg)

View File

@ -389,7 +389,7 @@
{
"Gb_type": "Encountered aliasing during higher order op tracing",
"Context": "context",
"Explanation": "Higher order ops do not support aliasing. Found in {source_target.name()}",
"Explanation": "Higher order ops do not support aliasing. Found in {source_target.name}",
"Hints": [
"Replace `return input` with `return input.clone()` to avoid aliasing.",
"Consider using the debug context to change user code to avoid aliasing.",
@ -401,7 +401,7 @@
{
"Gb_type": "Encountered input mutation during higher order op tracing",
"Context": "context",
"Explanation": "Higher order ops do not support input mutation. Found in {source_target.name()}",
"Explanation": "Higher order ops do not support input mutation. Found in {source_target.name}",
"Hints": [
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue."
@ -1469,7 +1469,7 @@
{
"Gb_type": "Unsupported function call (delayed)",
"Context": "source: {self.source}",
"Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name()}`. Reason: {self.msg}",
"Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name}`. Reason: {self.msg}",
"Hints": []
}
],

View File

@ -913,7 +913,7 @@ def getitem_on_dict_manager(
example_value: Any,
guard_manager_enum: GuardManagerType,
) -> GuardManager:
base_source_name = source.base.name()
base_source_name = source.base.name
if isinstance(source.index, ConstDictKeySource):
index = source.index.index
else:
@ -1042,9 +1042,9 @@ class GuardBuilder(GuardBuilderBase):
self.key_order_guarded_dict_ids = set()
assert self.check_fn_manager.output_graph is not None
for source in self.check_fn_manager.output_graph.guard_on_key_order:
dict_obj = self.get(source.name())
dict_obj = self.get(source.name)
if self.save_guards:
self.source_get_cache[source.name()] = dict_obj
self.source_get_cache[source.name] = dict_obj
self.key_order_guarded_dict_ids.add(id(dict_obj))
# Keep track of weak references of objects with ID_MATCH guard. This
@ -1072,7 +1072,7 @@ class GuardBuilder(GuardBuilderBase):
)
# Iterate over the dicts and install a dict_getitem_manager.
dict_source = guard.originating_source.name()
dict_source = guard.originating_source.name
# Ensure that we call dict.keys and not value.keys (which can call
# overridden keys method). In the C++ guards, we relied on PyDict_Next
@ -1255,7 +1255,7 @@ class GuardBuilder(GuardBuilderBase):
l1_guard_manager_enum = l2_guard_manager_enum = None
if l2_key:
l1_source = AttrSource(source.base, l1_key)
l1_source_name = l1_source.name()
l1_source_name = l1_source.name
l1_value = mod_dict[l1_key]
# do not guard on key order for _parameters etc unless the user code
# actually needs the key order (e.g. calling named_parameters)
@ -1303,7 +1303,7 @@ class GuardBuilder(GuardBuilderBase):
return l1_mgr
def requires_key_order_guarding(self, source: Source) -> bool:
source_name = source.name()
source_name = source.name
if source_name == "":
return False
obj_id = id(self.get(source_name))
@ -1346,7 +1346,7 @@ class GuardBuilder(GuardBuilderBase):
root_guard_manager = self.guard_manager.root
example_value = None
source_name = source.name()
source_name = source.name
if source_name != "" and source_name in self._cached_guard_managers:
return self._cached_guard_managers[source_name]
@ -1363,7 +1363,7 @@ class GuardBuilder(GuardBuilderBase):
base_guard_manager = None
base_guard_manager_enum = GuardManagerType.GUARD_MANAGER
if isinstance(source, ChainedSource):
base_source_name = source.base.name()
base_source_name = source.base.name
base_example_value = self.get(base_source_name)
base_guard_manager = self.get_guard_manager_from_source(source.base)
base_guard_manager_enum = self.get_guard_manager_type(
@ -1747,10 +1747,10 @@ class GuardBuilder(GuardBuilderBase):
)
else:
raise AssertionError(
f"missing guard manager builder {source} - {source.name()}"
f"missing guard manager builder {source} - {source.name}"
)
self._cached_guard_managers[source.name()] = out
self._cached_guard_managers[source.name] = out
return out
def get_guard_manager(self, guard: Guard) -> GuardManager:
@ -1849,7 +1849,7 @@ class GuardBuilder(GuardBuilderBase):
return
assert isinstance(source, AttrSource), f"invalid source {guard.name}"
base_source = source.base
base = base_source.name()
base = base_source.name
attr = source.member
ref = self.arg_ref(base)
@ -1871,7 +1871,7 @@ class GuardBuilder(GuardBuilderBase):
if val:
# Just install a getattr manager. GetAttrGuardAccessor itself
# acts as hasattr guard.
example_value = self.get(source.name())
example_value = self.get(source.name)
base_example_value = self.get(base)
guard_manager_enum = self.get_guard_manager_type(source, example_value)
@ -1884,7 +1884,7 @@ class GuardBuilder(GuardBuilderBase):
base_example_value,
example_value,
base,
source.name(),
source.name,
guard_manager_enum,
)
else:
@ -2419,7 +2419,7 @@ class GuardBuilder(GuardBuilderBase):
self.check_fn_manager.additional_used_global_vars.add(name)
ref_a = self.arg_ref(guard)
ref_b = self.arg_ref(source_b.name())
ref_b = self.arg_ref(source_b.name)
if is_from_optimizer_source(
guard.originating_source
@ -2694,7 +2694,7 @@ class GuardBuilder(GuardBuilderBase):
python_fallback = True
else:
example_value = self.get(
source.name(),
source.name,
closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
)
if isinstance(example_value, int):
@ -3895,11 +3895,11 @@ class CheckFunctionManager:
guard_source = source.guard_source()
if guard_source is GuardSource.CONSTANT:
# No need to track constants
return source.name()
return source.name
assert w_builder
r_builder = w_builder()
assert r_builder is not None
return r_builder.arg_ref(source.name())
return r_builder.arg_ref(source.name)
builder = GuardBuilder(
f_code,
@ -4063,7 +4063,7 @@ class CheckFunctionManager:
if isinstance(guard, DuplicateInputs):
source_a = guard.input_source_a
source_b = guard.input_source_b
code_part = f"{source_a.name()} is {source_b.name()}"
code_part = f"{source_a.name} is {source_b.name}"
install_object_aliasing_guard(
builder.get_guard_manager_from_source(source_a),
builder.get_guard_manager_from_source(source_b),
@ -4081,8 +4081,8 @@ class CheckFunctionManager:
]
code_part = (
"""check_overlapping("""
f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """
f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])"""
f"""overlapping=[{", ".join(s.name for s in guard.overlapping_sources)}], """
f"""non_overlapping=[{", ".join(s.name for s in guard.non_overlapping_sources)}])"""
)
install_storage_overlapping_guard(
overlapping_guard_managers,
@ -4533,7 +4533,7 @@ def make_dupe_guard(
dupe_source
) or is_from_flatten_script_object_source(obj_source):
raise exc.UnsafeScriptObjectError(
f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported."
f"{obj_source.name} is aliasing {dupe_source.name}. This is not supported."
f" Please do a clone for corresponding input."
)

View File

@ -1230,7 +1230,7 @@ class OutputGraph(OutputGraphCommon):
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
OutputGraph.module_key_name(new_source.name())
OutputGraph.module_key_name(new_source.name)
] = leaf_name
# annoying, but there are cases when we do not have parameters
@ -2559,7 +2559,7 @@ class OutputGraph(OutputGraphCommon):
return None
def remove_unused(node: fx.Node) -> None:
log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name)
# I'm not really sure why you need to delete these from the
# node since the node is going to get removed
del node.meta["grapharg"]
@ -2741,7 +2741,7 @@ class OutputGraph(OutputGraphCommon):
def add_fqn_info_for_inlined_modules(
self, inlined_module: torch.nn.Module, source: Source
) -> None:
name = OutputGraph.module_key_name(source.name())
name = OutputGraph.module_key_name(source.name)
name = get_unique_name_wrt(
name, self.used_inlined_inbuilt_modules_names, self.global_scope
)
@ -2754,7 +2754,7 @@ class OutputGraph(OutputGraphCommon):
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
OutputGraph.module_key_name(new_source.name())
OutputGraph.module_key_name(new_source.name)
] = leaf_name
# annoying, but there are cases when we do not have parameters
@ -3306,7 +3306,7 @@ class SubgraphTracer(fx.Tracer):
log.debug(
"create_graph_input %s %s %s at debug_level %s before=%s",
name,
source.name() if source is not None else "(none)",
source.name if source is not None else "(none)",
example_value,
self.debug_level,
before,
@ -3652,7 +3652,7 @@ class SubgraphTracer(fx.Tracer):
log.debug(
"_lift_symbols_in_symint %s from %s at debug_level %s",
s0,
source.name() if source is not None else "subgraph inputs",
source.name if source is not None else "subgraph inputs",
self.debug_level,
)
self.lifted_freevars[parent_proxy] = ph # type: ignore[index]
@ -3678,7 +3678,7 @@ class SubgraphTracer(fx.Tracer):
log.debug(
"_lift_symbols_in_symint %s from %s at debug_level %s",
s,
source.name() if source is not None else "subgraph inputs",
source.name if source is not None else "subgraph inputs",
self.debug_level,
)
ph.node.meta["grapharg"] = GraphArg(

View File

@ -117,7 +117,7 @@ def _get_source_debug_name(source: Optional[Source]) -> str:
return "<unknown source>"
else:
try:
return source.name()
return source.name
except NotImplementedError:
return "<unknown source>"
@ -147,6 +147,7 @@ class LocalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.LOCAL
@functools.cached_property
def name(self) -> str:
return f"L[{repr(self.local_name)}]"
@ -162,6 +163,7 @@ class TempLocalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.TEMP_LOCAL
@functools.cached_property
def name(self) -> str:
raise NotImplementedError(
"Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub."
@ -178,6 +180,7 @@ class SyntheticLocalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.SYNTHETIC_LOCAL
@functools.cached_property
def name(self) -> str:
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
@ -194,6 +197,7 @@ class RandomValueSource(Source):
codegen.append_output(codegen.create_load_const(self.random_call_index))
codegen.append_output(create_binary_subscr())
@functools.cached_property
def name(self) -> str:
return f"random_value_{self.random_call_index}"
@ -208,6 +212,7 @@ class GlobalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@functools.cached_property
def name(self) -> str:
return f"G[{repr(self.global_name)}]"
@ -227,6 +232,7 @@ class GlobalWeakRefSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@functools.cached_property
def name(self) -> str:
return f"G[{repr(self.global_name)}]()"
@ -240,8 +246,9 @@ class WeakRefCallSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}()"
return f"{self.base.name}()"
@dataclasses.dataclass(frozen=True)
@ -269,10 +276,11 @@ class AttrSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
if not self.member.isidentifier():
return f"getattr({self.base.name()}, {self.member!r})"
return f"{self.base.name()}.{self.member}"
return f"getattr({self.base.name}, {self.member!r})"
return f"{self.base.name}.{self.member}"
@dataclasses.dataclass(frozen=True)
@ -295,8 +303,9 @@ class GenericAttrSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
return f"object.__getattribute__({self.base.name}, {self.member!r})"
# Represents obj.__dict__ where obj is a type object
@ -309,12 +318,13 @@ class TypeDictSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
# type(ob).__dict__ can return a proxy of the dict. But in the C++
# guard accessor, we are use type->tp_dict which is a dict. So,
# forcefully pass a dict object to ensure that the GuardManager
# registers that its working on a dict object.
return f"dict({self.base.name()}.__dict__)"
return f"dict({self.base.name}.__dict__)"
# Represents obj.__mro__ where object is type object
@ -327,8 +337,9 @@ class TypeMROSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.__mro__"
return f"{self.base.name}.__mro__"
@dataclasses.dataclass(frozen=True)
@ -360,8 +371,9 @@ class CodeSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.__code__"
return f"{self.base.name}.__code__"
# Represents obj.__closure__ where object is type object
@ -374,8 +386,9 @@ class ClosureSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.__closure__"
return f"{self.base.name}.__closure__"
# Represents tensor.grad source. It could be represented by AttrSource as well.
@ -393,8 +406,9 @@ class GradSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.{self.member}"
return f"{self.base.name}.{self.member}"
@dataclasses.dataclass(frozen=True)
@ -425,6 +439,7 @@ class EphemeralSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.EPHEMERAL
@functools.cached_property
def name(self) -> str:
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
@ -443,8 +458,9 @@ class SkipGuardSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return self.base.name()
return self.base.name
class TensorProperty(enum.Enum):
@ -492,14 +508,15 @@ class TensorPropertySource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
if self.prop is TensorProperty.SIZE:
return f"{self.base.name()}.size()[{self.idx}]"
return f"{self.base.name}.size()[{self.idx}]"
elif self.prop is TensorProperty.STRIDE:
return f"{self.base.name()}.stride()[{self.idx}]"
return f"{self.base.name}.stride()[{self.idx}]"
elif self.prop is TensorProperty.STORAGE_OFFSET:
assert self.idx is None
return f"{self.base.name()}.storage_offset()"
return f"{self.base.name}.storage_offset()"
else:
raise AssertionError(f"unhandled {self.prop}")
@ -517,8 +534,9 @@ class IndexedSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"({self.idx}, {self.base.name()})"
return f"({self.idx}, {self.base.name})"
@dataclasses.dataclass(frozen=True)
@ -532,9 +550,10 @@ class NegateSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
# NB: use method call so that function stripping regexes work
return f"{self.base.name()}.__neg__()"
return f"{self.base.name}.__neg__()"
@dataclasses.dataclass(frozen=True)
@ -548,8 +567,9 @@ class ConvertIntSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"cast_symbool_to_symint_guardless({self.base.name()})"
return f"cast_symbool_to_symint_guardless({self.base.name})"
@dataclasses.dataclass(frozen=True)
@ -571,8 +591,9 @@ class DynamicScalarSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"int({self.base.name()})"
return f"int({self.base.name})"
@dataclasses.dataclass(frozen=True)
@ -586,8 +607,9 @@ class FlattenScriptObjectSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.__obj_flatten__()"
return f"{self.base.name}.__obj_flatten__()"
@dataclasses.dataclass(frozen=True)
@ -601,8 +623,9 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}._type().qualified_name()"
return f"{self.base.name}._type().qualified_name()"
class AttrProxySource(ChainedSource):
@ -612,8 +635,9 @@ class AttrProxySource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.get_base()"
return f"{self.base.name}.get_base()"
@dataclasses.dataclass(frozen=True)
@ -631,13 +655,13 @@ class DefaultsSource(ChainedSource):
assert isinstance(self.idx_key, str)
object.__setattr__(self, "field", "__kwdefaults__")
object.__setattr__(
self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
self, "_name", f"{self.base.name}.{self.field}['{self.idx_key}']"
)
else:
assert isinstance(self.idx_key, int)
object.__setattr__(self, "field", "__defaults__")
object.__setattr__(
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
self, "_name", f"{self.base.name}.{self.field}[{self.idx_key}]"
)
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -649,6 +673,7 @@ class DefaultsSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return self._name
@ -681,15 +706,16 @@ class GetItemSource(ChainedSource):
slice_class, slice_args = self.index
return slice_class(*slice_args)
@functools.cached_property
def name(self) -> str:
# Index can be of following types
# 1) index is a slice - example 1:4
# 2) index is a constant - example string, integer
assert not isinstance(self.index, Source)
if self.index_is_slice:
return f"{self.base.name()}[{self.unpack_slice()!r}]"
return f"{self.base.name}[{self.unpack_slice()!r}]"
else:
return f"{self.base.name()}[{self.index!r}]"
return f"{self.base.name}[{self.index!r}]"
@dataclasses.dataclass(frozen=True)
@ -707,9 +733,10 @@ class ConstDictKeySource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
@functools.cached_property
def name(self) -> str:
# The list creation will be CSE'd by PyExprCSEPass
return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
return f"list(dict.keys({self.base.name}))[{self.index!r}]"
def is_dict_key(self) -> bool:
return True
@ -735,9 +762,10 @@ class NonSerializableSetGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
@functools.cached_property
def name(self) -> str:
# set ordering might not be stable
return f"list({self.base.name()})[{self.index!r}]"
return f"list({self.base.name})[{self.index!r}]"
def is_dict_key(self) -> bool:
return False
@ -772,11 +800,12 @@ class DictGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_binary_subscr())
@functools.cached_property
def name(self) -> str:
if isinstance(self.index, ConstDictKeySource):
return f"{self.base.name()}[{self.index.name()}]"
return f"{self.base.name}[{self.index.name}]"
else:
return f"{self.base.name()}[{self.index!r}]"
return f"{self.base.name}[{self.index!r}]"
# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that
@ -817,11 +846,12 @@ class DictSubclassGetItemSource(ChainedSource):
codegen.extend_output(create_call_function(2, False))
@functools.cached_property
def name(self) -> str:
if isinstance(self.index, ConstDictKeySource):
return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
return f"dict.__getitem__({self.base.name}, {self.index.name})"
else:
return f"{self.base.name()}[{self.index!r}]"
return f"{self.base.name}[{self.index!r}]"
@dataclasses.dataclass(frozen=True)
@ -852,6 +882,7 @@ class ListGetItemSource(GetItemSource):
codegen.extend_output(create_call_function(2, False))
@functools.cached_property
def name(self) -> str:
# Index can be of following types
# 1) index is a slice - example 1:4
@ -862,7 +893,7 @@ class ListGetItemSource(GetItemSource):
"List[slice] is a temporary object and should not have a source"
)
else:
return f"list.__getitem__({self.base.name()}, {self.index!r})"
return f"list.__getitem__({self.base.name}, {self.index!r})"
@dataclasses.dataclass(frozen=True)
@ -875,8 +906,9 @@ class TupleIteratorGetItemSource(GetItemSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
@functools.cached_property
def name(self) -> str:
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
return f"___tuple_iterator_getitem({self.base.name}, {self.index!r})"
@dataclasses.dataclass(frozen=True)
@ -888,8 +920,9 @@ class NamedTupleFieldsSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"___namedtuple_fields({self.base.name()})"
return f"___namedtuple_fields({self.base.name})"
@dataclasses.dataclass(frozen=True)
@ -904,8 +937,9 @@ class DataclassFieldsSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"___dataclass_fields({self.base.name()})"
return f"___dataclass_fields({self.base.name})"
@dataclasses.dataclass(frozen=True)
@ -921,8 +955,9 @@ class TypeSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return f"type({self.base.name()})"
return f"type({self.base.name})"
@dataclasses.dataclass(frozen=True)
@ -933,8 +968,9 @@ class OptimizerSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@functools.cached_property
def name(self) -> str:
return self.base.name()
return self.base.name
@dataclasses.dataclass(frozen=True)
@ -945,8 +981,9 @@ class NNModuleSource(ChainedSource):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
@functools.cached_property
def name(self) -> str:
return self.base.name()
return self.base.name
@dataclasses.dataclass(frozen=True)
@ -969,6 +1006,7 @@ class FSDPNNModuleSource(NNModuleSource):
@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
@functools.cached_property
def name(self) -> str:
return ""
@ -987,6 +1025,7 @@ class TorchSource(Source):
install_guard(self.make_guard(GuardBuilder.ID_MATCH))
@functools.cached_property
def name(self) -> str:
return "__import__('torch')"
@ -1007,6 +1046,7 @@ class TorchSource(Source):
class TorchFunctionModeStackSource(Source):
ind: int
@functools.cached_property
def name(self) -> str:
return f"___get_torch_function_mode_stack_at({self._get_index()})"
@ -1038,6 +1078,7 @@ class ConstantSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.CONSTANT
@functools.cached_property
def name(self) -> str:
return self.source_name
@ -1047,8 +1088,9 @@ class ConstantSource(Source):
@dataclasses.dataclass(frozen=True)
class NumpyTensorSource(ChainedSource):
@functools.cached_property
def name(self) -> str:
return f"___from_numpy({self.base.name()})"
return f"___from_numpy({self.base.name})"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -1061,8 +1103,9 @@ class NumpyTensorSource(ChainedSource):
@dataclasses.dataclass(frozen=True)
class SubclassAttrListSource(ChainedSource):
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.__tensor_flatten__()[0]"
return f"{self.base.name}.__tensor_flatten__()[0]"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -1072,8 +1115,9 @@ class SubclassAttrListSource(ChainedSource):
# source, it is ephemeral
@dataclasses.dataclass(frozen=True)
class FloatTensorSource(ChainedSource):
@functools.cached_property
def name(self) -> str:
return f"___as_tensor({self.base.name()})"
return f"___as_tensor({self.base.name})"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -1081,8 +1125,9 @@ class FloatTensorSource(ChainedSource):
@dataclasses.dataclass(frozen=True)
class CallMethodItemSource(ChainedSource):
@functools.cached_property
def name(self) -> str:
return f"{self.base.name()}.item()"
return f"{self.base.name}.item()"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -1093,6 +1138,7 @@ class CallMethodItemSource(ChainedSource):
# guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source):
@functools.cached_property
def name(self) -> str:
return ""
@ -1104,6 +1150,7 @@ class ShapeEnvSource(Source):
class CurrentStreamSource(Source):
device: device_type
@functools.cached_property
def name(self) -> str:
return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))"
@ -1126,6 +1173,7 @@ class CurrentStreamSource(Source):
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
@functools.cached_property
def name(self) -> str:
return ""

View File

@ -389,7 +389,7 @@ class GraphArg:
self.example_strong_ref = None
def __eq__(self, other):
return self.source.name() == other.source.name()
return self.source.name == other.source.name
class BackwardStateGraphArg(GraphArg):
@ -444,7 +444,7 @@ class VariableBuilder:
super().__init__()
self.tx = tx
self.source = source
self.name = source.name()
self.name = source.name
def __call__(self, value):
if value in self.tx.output.side_effects:
@ -1645,7 +1645,7 @@ class VariableBuilder:
elif value.dynamism.type == _DimHintType.DYNAMIC:
log.debug(
"%s marked %s via IntWrapper",
self.source.name(),
self.source.name,
DimDynamic.DYNAMIC,
)
return self.wrap_symint(
@ -1658,7 +1658,7 @@ class VariableBuilder:
elif value.dynamism.type == _DimHintType.AUTO:
log.debug(
"%s marked %s via IntWrapper",
self.source.name(),
self.source.name,
DimDynamic.DYNAMIC,
)
return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
@ -1831,7 +1831,7 @@ class VariableBuilder:
from ..decorators import mark_static_address
static_inputs_log.debug(
"Marking static input %s, id: %s)", self.source.name(), id(value)
"Marking static input %s, id: %s)", self.source.name, id(value)
)
mark_static_address(value, guard=guard)
@ -2003,12 +2003,12 @@ class VariableBuilder:
def wrap_literal(self, value):
if type(value) is int:
# allowlist has higher precedence over specialization control.
if is_dynamic_source(self.source.name()):
log.debug("%s marked dynamic via source whitelist", self.source.name())
if is_dynamic_source(self.source.name):
log.debug("%s marked dynamic via source whitelist", self.source.name)
return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
if is_unbacked_source(self.source.name()):
log.debug("%s marked unbacked via source whitelist", self.source.name())
if is_unbacked_source(self.source.name):
log.debug("%s marked unbacked via source whitelist", self.source.name)
return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)
if not config.specialize_int:
@ -2034,7 +2034,7 @@ class VariableBuilder:
process_automatic_dynamic(
self.tx,
self.source.name(),
self.source.name,
FrameStateSizeEntry.make_scalar(value),
is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
)
@ -2440,7 +2440,7 @@ class VariableBuilder:
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value, source=self.source)
name = self.source.name()
name = self.source.name
frame_state_entry = process_automatic_dynamic(
self.tx,
@ -2453,7 +2453,7 @@ class VariableBuilder:
# know if bare integers are actually going to be sizevars
# and it is inappropriate to eagerly duck size them with
# real sizevars
normalized_source_name = normalize_source_name(self.source.name())
normalized_source_name = normalize_source_name(self.source.name)
base_source = self.source
if isinstance(base_source, ChainedSource):
base_source = base_source.get_base()
@ -2539,7 +2539,7 @@ class VariableBuilder:
frame_state_entry = process_automatic_dynamic(
self.tx,
self.source.name(),
self.source.name,
FrameStateSizeEntry.make_scalar(value),
is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
)
@ -3386,7 +3386,7 @@ def _automatic_dynamic(
hints=[],
)
name = source.name()
name = source.name
prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
shape_env_to_source_to_symbol_cache = (
prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
@ -3509,7 +3509,7 @@ def _automatic_dynamic(
# Reflect the user directive in the frame_state
# For dynamic, apply None always
normalized_source_name = normalize_source_name(source.name())
normalized_source_name = normalize_source_name(source.name)
base_source = source
if isinstance(base_source, ChainedSource):
base_source = base_source.get_base()
@ -3670,7 +3670,7 @@ def wrap_to_fake_tensor_and_record(
log.debug(
"wrap_to_fake %s %s %s %s",
source.name(),
source.name,
tuple(e.shape),
symbolic_context,
type(e),

View File

@ -1080,7 +1080,7 @@ def check_aliasing_and_input_mutation(
unimplemented(
gb_type="Encountered input mutation during higher order op tracing",
context=context,
explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}",
explanation=f"Higher order ops do not support input mutation. Found in {source_target.name}",
hints=[
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue.",
@ -1094,7 +1094,7 @@ def check_aliasing_and_input_mutation(
unimplemented(
gb_type="Encountered aliasing during higher order op tracing",
context=context,
explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}",
explanation=f"Higher order ops do not support aliasing. Found in {source_target.name}",
hints=[
"Replace `return input` with `return input.clone()` to avoid aliasing.",
"Consider using the debug context to change user code to avoid aliasing.",

View File

@ -572,7 +572,7 @@ class DelayGraphBreakVariable(UnknownVariable):
gb_type="Unsupported function call (delayed)",
context=f"source: {self.source}",
explanation="Dynamo determined that a graph break should occur "
f"when calling `{self.source.name()}`. Reason: {self.msg}",
f"when calling `{self.source.name}`. Reason: {self.msg}",
hints=[],
)

View File

@ -113,7 +113,7 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
fully_qualified_name = source.name()
fully_qualified_name = source.name
# Remove redundant namings
fully_qualified_name = re.sub(
r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]",

View File

@ -323,7 +323,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
# Note: to avoid spam logs only warn if perf hint artifact is enabled
# (NB: artifacts are only enabled at the debug or warning level)
if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
non_static_grad_names = [src.name() for src in non_static_grads]
non_static_grad_names = [src.name for src in non_static_grads]
perf_hint_log.warning(
(
"Grad tensors %s will be copied during cudagraphs execution."
@ -365,7 +365,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value, guard=True)
source = self.tensor_to_source[tensor_value]
self.static_tensor_names.add(tx.output.module_key_name(source.name()))
self.static_tensor_names.add(tx.output.module_key_name(source.name))
elif tensor_value in self.grad_to_source:
source = self.grad_to_source[tensor_value]
else:
@ -374,7 +374,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
source = GlobalWeakRefSource(global_name)
self.static_tensor_names.add(tx.output.module_key_name(source.name()))
self.static_tensor_names.add(tx.output.module_key_name(source.name))
return VariableTracker.build(tx, tensor_value, source)

View File

@ -314,7 +314,7 @@ class TensorVariable(VariableTracker):
# eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
# L['mod'].model.model.encoder.embed_positions)", scope)
# Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
_input_associated_real_value = eval(self.source.name(), scope)
_input_associated_real_value = eval(self.source.name, scope)
except Exception as exc:
raise NotImplementedError from exc
@ -551,7 +551,7 @@ class TensorVariable(VariableTracker):
# For local source, we associate the real value. We use this real value
scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
try:
_input_associated_real_value = eval(self.source.name(), scope)
_input_associated_real_value = eval(self.source.name, scope)
except Exception as exc:
unimplemented(
gb_type="Error getting associated real value",

View File

@ -280,7 +280,7 @@ def _create_symbolic_context_for_tensor(t, source, t_constraints, sources, mode)
if isinstance(constraint, _RelaxedConstraint):
continue
symbolic_context.constraint_sizes[i] = constraint.constraint_range
mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment]
mode.shape_env.source_name_to_debug_name[src.name] = constraint.name # type: ignore[assignment]
return symbolic_context

View File

@ -173,7 +173,7 @@ def _try_get_metadata_from_dynamo(
assert source is None or source not in seen_sources, source
seen_sources.add(source)
aot_autograd_arg_pos_to_source.append(source)
source_name = source.name() if source else str(source)
source_name = source.name if source else str(source)
# input[i] in dynamo is now:
# input[i + len(extra_params)] in AOT,

View File

@ -245,7 +245,7 @@ class Guard:
# globals (and locals, if you create a LOCAL guard) to extract the Python
# object that we want to perform guard tests on. This evaluation
# typically happens in GuardBuilder.eval. In these cases, name is
# typically produced by originating_source.name() (not to be confused with
# typically produced by originating_source.name (not to be confused with
# GuardSource - the property source).
#
# Occasionally, name is not a valid Python expression; sometimes
@ -297,7 +297,7 @@ class Guard:
@property
def name(self) -> str:
return self.originating_source.name()
return self.originating_source.name
@property
def source(self) -> GuardSource:
@ -1074,6 +1074,7 @@ class Source:
def guard_source(self) -> GuardSource:
raise NotImplementedError
@functools.cached_property
def name(self) -> str:
raise NotImplementedError

View File

@ -870,7 +870,7 @@ class MetaConverter(Generic[_TensorT]):
# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
# precisely to the Source.name() of the tensor we're fakeifying and
# precisely to the Source.name of the tensor we're fakeifying and
# corresponds to a valid Python expression. When we construct sub-names
# as part of this process, we will maintain this invariant! (Even though
# other users of this may not need it this property to be upheld.)
@ -1937,7 +1937,7 @@ class MetaConverter(Generic[_TensorT]):
metadata_fn=lambda: {
"describer_id": self.describer.id,
"id": t_desc.id,
"source": source.name(),
"source": source.name,
},
)

View File

@ -50,6 +50,56 @@ static py::handle _callback_from_action(
return callback;
}
// c_recursion_remaining only defined in 3.12 and 3.13
static int32_t c_recursion_limit = -1;
void set_c_recursion_limit(int32_t limit) {
if (limit < 1) {
throw std::range_error("recursion limit must be greater or equal than 1");
}
c_recursion_limit = limit;
// cannot fail
Py_SetRecursionLimit(limit); // also set the Python limit
}
int32_t get_c_recursion_limit() {
return c_recursion_limit;
}
#if IS_PYTHON_3_12_PLUS && !IS_PYTHON_3_14_PLUS
struct CRecursionLimitRAII {
PyThreadState* tstate;
int32_t old_recursion_remaining;
CRecursionLimitRAII(PyThreadState* tstate) : tstate{tstate} {
auto limit = get_c_recursion_limit();
auto& remaining = tstate->c_recursion_remaining;
this->old_recursion_remaining = remaining;
if (limit < 0) {
// no change to limit
return;
}
if (limit < remaining) {
PyErr_SetString(
PyExc_RuntimeError,
"new c_recursion limit is lower than thread's current c_recursion_remaining.");
}
remaining = limit;
}
~CRecursionLimitRAII() {
this->tstate->c_recursion_remaining = this->old_recursion_remaining;
}
};
#else
struct CRecursionLimitRAII {
CRecursionLimitRAII(PyThreadState* tstate) {}
};
#endif
// frame and callback are borrowed references.
// Returns new reference.
PyObject* dynamo__custom_eval_frame(
@ -258,6 +308,13 @@ PyObject* dynamo__custom_eval_frame(
bool apply_to_code = false;
PyObject* guarded_code = nullptr;
try {
CRecursionLimitRAII tmp(tstate); // increase C recursion limit to the given
// value during compilation
// C recursion limit failure
if (PyErr_Occurred()) {
fail();
return eval_result;
}
callback_result = dynamo_call_callback(
callback, frame, locals.get(), cache_entry, frame_state);
new_strategy =

View File

@ -19,6 +19,9 @@ PyObject* dynamo__custom_eval_frame(
PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj);
void skip_code_recursive(PyCodeObject* code);
void set_c_recursion_limit(int32_t limit);
int32_t get_c_recursion_limit();
#ifdef __cplusplus
} // extern "C"

View File

@ -7,6 +7,7 @@
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/dynamo/eval_frame.h>
#include <torch/csrc/dynamo/eval_frame_cpp.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/dynamo/python_compiled_autograd.h>
@ -250,6 +251,9 @@ void initDynamoBindings(PyObject* torch) {
.def_readwrite("cur_action", &FrameExecStrategy::cur_action)
.def_readwrite("recursive_action", &FrameExecStrategy::recursive_action);
m.def("set_c_recursion_limit", &set_c_recursion_limit);
m.def("get_c_recursion_limit", &get_c_recursion_limit);
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
m.def("_reset_precompile_entries", &_reset_precompile_entries);
m.def("_load_precompile_entry", &_load_precompile_entry);

View File

@ -1918,7 +1918,7 @@ class StrictMinMaxConstraint(Constraint):
def render(self, source: Source) -> str:
"""Format the constrain equation"""
# TODO: better printing for -oo and oo
return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}"
return f"{self.vr.lower} <= {source.name} <= {self.vr.upper}"
@dataclass(frozen=True)
@ -1943,7 +1943,7 @@ class RelaxedUnspecConstraint(Constraint):
"""
def render(self, source: Source) -> str:
return f"RelaxedUnspecConstraint({source.name()})"
return f"RelaxedUnspecConstraint({source.name})"
# NB: None here indicates the client constraint is whatever is implicitly
@ -2039,7 +2039,7 @@ class EqualityConstraint(Constraint):
return self._defs[src]
else:
# otherwise, create a symbol representing the source
return sympy.Symbol(src.name())
return sympy.Symbol(src.name)
def is_equal(self, source1: Source, source2: Source) -> bool:
return (
@ -2252,11 +2252,11 @@ class TrackedFake:
symbolic_context: Optional[SymbolicContext]
def __hash__(self) -> int:
return hash((self.fake, self.source.name()))
return hash((self.fake, self.source.name))
def __eq__(self, other: object) -> bool:
if isinstance(other, TrackedFake):
return self.fake is other.fake and self.source.name() == other.source.name()
return self.fake is other.fake and self.source.name == other.source.name
return False
@ -2712,7 +2712,7 @@ class _ShapeGuardPrinter(abc.ABC):
def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str:
return repr(
{
symbol: [s.name() for s in sources]
symbol: [s.name for s in sources]
for symbol, sources in src.items()
}
)
@ -2820,7 +2820,7 @@ class _ShapeGuardCppPrinter(_ShapeGuardPrinter, CppPrinter):
if source in self.source_to_symbol:
return self.source_to_symbol[source].name
source_name = source.name()
source_name = source.name
mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name)
old_mangled_name = mangled_name
count = 0
@ -2849,7 +2849,7 @@ class _CppShapeGuardsHelper(_ShapeGuardsHelper):
class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter):
def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]):
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
super().__init__(var_to_sources, lambda n: n.name, var_to_sources)
class DynamicDimConstraintPrinter(PythonPrinter):
@ -2875,7 +2875,7 @@ class DynamicDimConstraintPrinter(PythonPrinter):
assert self.symbol_to_source.get(expr), (
f"Unknown symbol {expr} created by constraints solver"
)
return self.symbol_to_source[expr][0].name()
return self.symbol_to_source[expr][0].name
class DimConstraints:
@ -3095,7 +3095,7 @@ class DimConstraints:
"""Add an equality constraint"""
if expr.is_number:
# specialization, right here
self._static_results.add(f"{source.name()} == {expr}")
self._static_results.add(f"{source.name} == {expr}")
else:
# these will resolve to either specializations or dynamic equality constraints
self._symbolic_equivalences.append((source, expr))
@ -3175,7 +3175,7 @@ class DimConstraints:
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization
self._static_results.add(
f"{self._dcp.symbol_to_source[s][0].name()} == {val}"
f"{self._dcp.symbol_to_source[s][0].name} == {val}"
)
# add this as a substitution to simplify other constraints
self._substitutions[s] = val # type: ignore[assignment]
@ -3200,8 +3200,8 @@ class DimConstraints:
base, divisor = congruence.args
tmp_name = "_" + str(
self._dcp.source_name_to_debug_name.get(
self._dcp.symbol_to_source[s][0].name(),
self._dcp.symbol_to_source[s][0].name(),
self._dcp.symbol_to_source[s][0].name,
self._dcp.symbol_to_source[s][0].name,
)
)
tmp = sympy.Symbol(tmp_name, integer=True)
@ -3243,7 +3243,7 @@ class DimConstraints:
# remaining symbolic equivalences become dynamic equality constraints
for source, expr3 in self._symbolic_equivalences:
self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr3)}")
self._dynamic_results.add(f"{source.name} == {self._dcp.doprint(expr3)}")
@classmethod
def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool:
@ -3266,7 +3266,7 @@ class DimConstraints:
"""Returns a dictionary of the names of symbols to their specialized value"""
def debug_name(src: Source) -> str:
name = src.name()
name = src.name
if self._dcp.source_name_to_debug_name:
return f"{self._dcp.source_name_to_debug_name[name]} = {name}"
else:
@ -4011,7 +4011,7 @@ class ShapeEnv:
check_fn: A function that takes a sympy Symbol and returns a sympy expression
representing a constraint/specialization to be applied
"""
name = source.name()
name = source.name
sym = self.source_to_var[name]
expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr
new_axioms = dict(self.get_implications(self.simplify(expr)))
@ -4284,7 +4284,7 @@ class ShapeEnv:
def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
if not self._translation_validation_enabled:
return None
srcname = source.name()
srcname = source.name
if source not in self.source_to_symbol:
self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
return self.source_to_symbol[srcname]
@ -4874,7 +4874,7 @@ class ShapeEnv:
if source is None:
sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
else:
sloc, maybe_extra_debug = source.name(), ""
sloc, maybe_extra_debug = source.name, ""
log.info(
"%s %s [%s, %s] %s%s",
prefix,
@ -5028,7 +5028,7 @@ class ShapeEnv:
if constraint_dim.vr.lower != val:
raise ConstraintViolationError(
f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, "
f"for {source.name()}"
f"for {source.name}"
)
if symbolic_context:
from torch._dynamo.source import TensorPropertySource
@ -5041,7 +5041,7 @@ class ShapeEnv:
constraint_dim = None
# see note [Tensor Fakification and Symbol Caching]
source_name = source.name()
source_name = source.name
if (
isinstance(symbolic_context, StatefulSymbolicContext)
and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache
@ -5115,7 +5115,7 @@ class ShapeEnv:
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
symbol_id = self._generate_unique_id(source.name())
symbol_id = self._generate_unique_id(source.name)
if type(val) is int or is_nested_int(val):
sympy_expr = make_symbol(
SymT.SIZE, symbol_id, positive=positive, integer=True
@ -5219,7 +5219,7 @@ class ShapeEnv:
"create_symbol %s = %s for %s %s %s%s%s",
sympy_expr,
val,
source.name(),
source.name,
range_str,
sloc,
maybe_more_info,
@ -5232,7 +5232,7 @@ class ShapeEnv:
"symbol": str(sympy_expr),
"val": repr(val),
"vr": range_str,
"source": source.name(),
"source": source.name,
"user_stack": structured.from_traceback(
TracingContext.extract_stack()
),
@ -5248,7 +5248,7 @@ class ShapeEnv:
# the same symint
r = self.val_to_var[val]
self.source_to_var[source_name] = r
self.log.debug("create_symbol %s duck sized %s", r, source.name())
self.log.debug("create_symbol %s duck sized %s", r, source.name)
if isinstance(r, sympy.Symbol):
r_sources = self.var_to_sources[r]
@ -5275,7 +5275,7 @@ class ShapeEnv:
self.var_to_val[expr] = sympy.Integer(val)
def _debug_name(self, source: Source) -> str:
src_name = source.name()
src_name = source.name
return self.source_name_to_debug_name.get(src_name, src_name)
def _render_range_for_constraint_violation(
@ -5289,7 +5289,7 @@ class ShapeEnv:
if upper >= default.upper:
upper = None
c_render = (
f"{self._debug_name(source)} = {source.name()} in the specified range"
f"{self._debug_name(source)} = {source.name} in the specified range"
)
if lower is not None and upper is not None:
c_render += f" {lower} <= {self._debug_name(source)} <= {upper}"
@ -5311,7 +5311,7 @@ class ShapeEnv:
self,
placeholders: Sequence[FakeTensor],
sources: Sequence[Source],
source_ref: Callable[[Source], str] = lambda n: n.name(),
source_ref: Callable[[Source], str] = lambda n: n.name,
*,
guards: Optional[list[ShapeGuard]] = None,
input_contexts: Optional[DimList[SymbolicContext]] = None,
@ -5501,10 +5501,10 @@ class ShapeEnv:
if equalities_inputs:
source_index = {}
for i, src in enumerate(sources):
source_index[src.name()] = i
source_index[src.name] = i
def get_expression(tensor_dim_src: Source) -> sympy.Expr:
fake = placeholders[source_index[tensor_dim_src.base.name()]] # type: ignore[attr-defined]
fake = placeholders[source_index[tensor_dim_src.base.name]] # type: ignore[attr-defined]
assert tensor_dim_src.idx is not None # type: ignore[attr-defined]
symint = fake.shape[tensor_dim_src.idx] # type: ignore[attr-defined]
if isinstance(symint, torch.SymInt):
@ -5521,16 +5521,16 @@ class ShapeEnv:
concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
if not concrete_val:
raise ConstraintViolationError(
f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}"
f"{src1.name} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}"
" is not equal to "
f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}"
f"{src2.name} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}"
)
for srcEq, root, fn in equalities_inputs.derived_equalities:
expr1 = get_expression(srcEq)
# recall that root is either a phantom symbol or an input source
if isinstance(root, sympy.Symbol):
expr2, debug_name = root, self.var_to_sources[root][0].name()
expr2, debug_name = root, self.var_to_sources[root][0].name
elif isinstance(root, sympy.Integer):
expr2, debug_name = root, str(root)
else:
@ -5542,7 +5542,7 @@ class ShapeEnv:
concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
if not concrete_val:
raise ConstraintViolationError(
f"Expected input {srcEq.name()} to be equal to "
f"Expected input {srcEq.name} to be equal to "
f"{fn(sympy.Symbol(debug_name))}, "
f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, "
f"but got {expr1.xreplace(self.var_to_val)}"
@ -5764,7 +5764,7 @@ class ShapeEnv:
if not _simplified:
for source, expr in input_guards:
srcname = source.name()
srcname = source.name
if self._translation_validation_enabled:
# Ignore sources that were not turned into SymInts.
if srcname in self.source_to_symbol:
@ -5827,8 +5827,8 @@ class ShapeEnv:
)
):
msg = (
f"The values of {self._debug_name(source)} = {source.name()} and "
f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} "
f"The values of {self._debug_name(source)} = {source.name} and "
f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name} "
"must always be equal."
)
record_constraint_violation(
@ -5846,8 +5846,8 @@ class ShapeEnv:
):
src = symbol_to_source[symbol][0]
msg = (
f"The values of {self._debug_name(source)} = {source.name()} must always be related to "
f"the values of {self._debug_name(src)} = {src.name()} by "
f"The values of {self._debug_name(source)} = {source.name} must always be related to "
f"the values of {self._debug_name(src)} = {src.name} by "
f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}."
)
record_constraint_violation(
@ -6868,7 +6868,7 @@ class ShapeEnv:
"symbolic_shape_specialization",
metadata_fn=lambda: {
"symbol": repr(a),
"sources": [s.name() for s in self.var_to_sources.get(a, [])],
"sources": [s.name for s in self.var_to_sources.get(a, [])],
"value": repr(tgt),
"reason": msg,
"stack": structured.from_traceback(
@ -6886,7 +6886,7 @@ class ShapeEnv:
if config.print_specializations:
self.log.warning(
"Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
"Specializing %s to %s", self.var_to_sources[a][0].name, tgt
)
self.log.debug("SPECIALIZATION", stack_info=True)
log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
@ -7211,7 +7211,7 @@ class ShapeEnv:
if str(s) in frame_symbols: # type: ignore[operator]
continue
if s in self.var_to_sources:
frame_symbols[str(s)] = self.var_to_sources[s][0].name() # type: ignore[assignment]
frame_symbols[str(s)] = self.var_to_sources[s][0].name # type: ignore[assignment]
return str(x)
return None