Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-20 02:24:54 +08:00)

Compare commits: 13 commits (optimizer_ ... ciflow/ind)

| SHA1 |
|---|
| 967e7093a5 |
| 8cc430db5d |
| 50ea044f8a |
| 257bf8e59e |
| be71654b78 |
| fa0c57142a |
| b737df7704 |
| fb09741981 |
| 5230b7c0ac |
| 881cd1c6f4 |
| 2fdd517fd9 |
| 43fe667181 |
| d8ebe3543d |
@@ -928,8 +928,8 @@ class TypePropagationTests(torch._dynamo.test_case.TestCase):
foo_source = LocalSource("foo")
foo_x_source = AttrSource(foo_source, "x")
- self.assertTrue(builder.get(foo_source.name()) is foo)
- self.assertTrue(builder.get(foo_x_source.name()) is foo.x)
+ self.assertTrue(builder.get(foo_source.name) is foo)
+ self.assertTrue(builder.get(foo_x_source.name) is foo.x)
# Check types of foo.x
foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source)
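The test above exercises the central change of this diff: `Source.name` becomes a cached property, so call sites read `source.name` instead of `source.name()` while the guard expression strings stay the same. A minimal, self-contained sketch of that pattern (simplified stand-ins, not the real classes in `torch/_dynamo/source.py`):

```python
import dataclasses
import functools

@dataclasses.dataclass(frozen=True)
class LocalSource:
    local_name: str

    @functools.cached_property
    def name(self) -> str:
        return f"L[{self.local_name!r}]"

@dataclasses.dataclass(frozen=True)
class AttrSource:
    base: LocalSource
    member: str

    @functools.cached_property
    def name(self) -> str:
        return f"{self.base.name}.{self.member}"

foo_x = AttrSource(LocalSource("foo"), "x")
print(foo_x.name)  # L['foo'].x -- read as an attribute, computed once per instance
```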
@@ -7456,6 +7456,97 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
msg,
)

+ def test_dynamo_set_recursion_limit_simple(self):
+ # Test that torch._dynamo.set_recursion_limit calls sys.setrecursionlimit for all supported
+ # Python versions
+ old_recursion_limit = sys.getrecursionlimit()
+ old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit()
+ try:
+ def fn(x, n):
+ if n == 0:
+ return x
+ return fn(x, n - 1) + 1
+ sys.setrecursionlimit(100)
+ with self.assertRaises(RecursionError):
+ fn(torch.ones(3), 1000)
+ opt_fn = torch.compile(fn, backend="eager", dynamic=False)
+ torch._dynamo.set_recursion_limit(100000)
+ self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000))
+ finally:
+ if old_dynamo_recursion_limit > 0:
+ torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit)
+ sys.setrecursionlimit(old_recursion_limit)
+
+ @unittest.skipIf(
+ sys.version_info < (3, 12) or sys.version_info >= (3, 14),
+ "only 3.12, 3.13 affected by c recursion limit",
+ )
+ def test_dynamo_set_recursion_limit(self):
+ old_recursion_limit = sys.getrecursionlimit()
+ old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit()
+ try:
+ def fn(x, n):
+ if n == 0:
+ return x
+ return fn(x, n - 1) + 1
+ sys.setrecursionlimit(100)
+ with self.assertRaises(RecursionError):
+ fn(torch.ones(3), 1000)
+ sys.setrecursionlimit(2000)
+ fn(torch.ones(3), 1000)
+ opt_fn = torch.compile(fn, backend="eager", dynamic=False)
+ sys.setrecursionlimit(100000)
+ with self.assertRaises(Exception):
+ opt_fn(torch.ones(3), 1000)
+ torch._dynamo.set_recursion_limit(100000)
+ self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000))
+ finally:
+ if old_dynamo_recursion_limit > 0:
+ torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit)
+ sys.setrecursionlimit(old_recursion_limit)
+
+ @unittest.skipIf(
+ sys.version_info < (3, 12) or sys.version_info >= (3, 14),
+ "only 3.12, 3.13 affected by c recursion limit",
+ )
+ def test_dynamo_set_recursion_limit_usage(self):
+ old_recursion_limit = sys.getrecursionlimit()
+ old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit()
+ try:
+ torch._dynamo.set_recursion_limit(100)
+ self.assertEqual(torch._dynamo.get_recursion_limit(), 100)
+ with self.assertRaisesRegex(ValueError, "recursion limit"):
+ torch._dynamo.set_recursion_limit(0)
+ self.assertEqual(torch._dynamo.get_recursion_limit(), 100)
+ torch._dynamo.set_recursion_limit(1)
+ sys.setrecursionlimit(100)
+ @torch.compile(backend="eager", dynamic=False)
+ def fn(x, n):
+ if n == 0:
+ return x
+ return fn(x, n - 1) + 1
+ with self.assertRaisesRegex(RuntimeError, "new c_recursion limit"):
+ fn(torch.ones(3), 5)
+ finally:
+ if old_dynamo_recursion_limit > 0:
+ torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit)
+ sys.setrecursionlimit(old_recursion_limit)

@expectedFailureDynamic
def test_dynamo_default_lru_cache_behavior(self):
@torch.compile(backend="eager")
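The new tests rely on the `torch._dynamo.set_recursion_limit` / `get_recursion_limit` API added in this diff. A minimal usage sketch under the assumption that this API is available (mirroring the save/restore pattern the tests use):

```python
import sys
import torch

def fn(x, n):
    if n == 0:
        return x
    return fn(x, n - 1) + 1

opt_fn = torch.compile(fn, backend="eager", dynamic=False)

# On CPython 3.12/3.13 a separate C recursion limit can make deep compiles fail
# even after raising sys.setrecursionlimit(); the new API raises both limits.
old_sys = sys.getrecursionlimit()
old_dynamo = torch._dynamo.get_recursion_limit()  # -1 if never set
try:
    torch._dynamo.set_recursion_limit(100000)  # also calls sys.setrecursionlimit()
    out = opt_fn(torch.ones(3), 1000)
finally:
    if old_dynamo > 0:
        torch._dynamo.set_recursion_limit(old_dynamo)
    sys.setrecursionlimit(old_sys)
```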
@@ -1624,7 +1624,7 @@ class GraphModule(torch.nn.Module):
str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
}
curr_var_to_sources = {
- str(k): v[0].name()
+ str(k): v[0].name
for k, v in context.fake_mode.shape_env.var_to_sources.items()
}
return gm
@@ -19,6 +19,8 @@ def set_guard_complete_hook(
hook: Optional[DynamoGuardCompleteHook],
) -> Optional[DynamoGuardCompleteHook]: ...
def raise_sigtrap() -> None: ...
+ def set_c_recursion_limit(limit: int) -> None: ...
+ def get_c_recursion_limit() -> int: ...

class _CacheEntry:
def check_fn(self, *args: object, **kwargs: object) -> bool: ...
@@ -105,6 +105,7 @@ __all__ = [
"reset",
"run",
"error_on_graph_break",
+ "set_recursion_limit",
"set_stance",
"skip_frame",
"step_unsupported",
@@ -181,3 +182,26 @@ def reset_code_caches() -> None:
if code:
reset_code(code)
code_context.clear()
+
+
+ def get_recursion_limit() -> int:
+ """
+ Returns the internal dynamo recursion limit set by `torch._dynamo.set_recursion_limit`.
+
+ Returns -1 if no c recursion limit has been set.
+ """
+ return torch._C._dynamo.eval_frame.get_c_recursion_limit()
+
+
+ def set_recursion_limit(limit: int) -> None:
+ """
+ Sets an internal dynamo recursion limit. The limit must be >= 1.
+
+ This is possibly needed in Python 3.12-3.13 since there is a separate C recursion limit
+ that is not visible at the Python level. If you are getting RecursionErrors during
+ Dynamo compilation and `sys.setrecursionlimit()` doesn't help, this function may alleviate
+ the issue.
+
+ NOTE: this function will also call `sys.setrecursionlimit()`.
+ """
+ torch._C._dynamo.eval_frame.set_c_recursion_limit(limit)
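Based on the docstrings and the tests earlier in this diff, the expected behavior of the new helpers is roughly the following (a sketch, not a specification beyond what the tests assert):

```python
import torch

torch._dynamo.set_recursion_limit(100)
assert torch._dynamo.get_recursion_limit() == 100

# Values below 1 are rejected; per the test above this surfaces as a ValueError
# whose message mentions "recursion limit", and the previous limit is preserved.
try:
    torch._dynamo.set_recursion_limit(0)
except ValueError:
    pass
assert torch._dynamo.get_recursion_limit() == 100
```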
@@ -1707,13 +1707,13 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
stack = s
break
if stack is None:
- msg = f"{source.name()}, a closed over free variable"
+ msg = f"{source.name}, a closed over free variable"
else:
tb = "".join(traceback.format_list(stack))
extra = ""
if len(user_stacks) > 1:
extra = f"(elided {len(user_stacks) - 1} more accesses)"
- msg = f"{source.name()}, accessed at:\n{tb}{extra}"
+ msg = f"{source.name}, accessed at:\n{tb}{extra}"
# TODO: option to print ALL of the stack traces at once
input_errors.append(msg)
@@ -389,7 +389,7 @@
{
"Gb_type": "Encountered aliasing during higher order op tracing",
"Context": "context",
- "Explanation": "Higher order ops do not support aliasing. Found in {source_target.name()}",
+ "Explanation": "Higher order ops do not support aliasing. Found in {source_target.name}",
"Hints": [
"Replace `return input` with `return input.clone()` to avoid aliasing.",
"Consider using the debug context to change user code to avoid aliasing.",
@@ -401,7 +401,7 @@
{
"Gb_type": "Encountered input mutation during higher order op tracing",
"Context": "context",
- "Explanation": "Higher order ops do not support input mutation. Found in {source_target.name()}",
+ "Explanation": "Higher order ops do not support input mutation. Found in {source_target.name}",
"Hints": [
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue."
@@ -1469,7 +1469,7 @@
{
"Gb_type": "Unsupported function call (delayed)",
"Context": "source: {self.source}",
- "Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name()}`. Reason: {self.msg}",
+ "Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name}`. Reason: {self.msg}",
"Hints": []
}
],
@@ -913,7 +913,7 @@ def getitem_on_dict_manager(
example_value: Any,
guard_manager_enum: GuardManagerType,
) -> GuardManager:
- base_source_name = source.base.name()
+ base_source_name = source.base.name
if isinstance(source.index, ConstDictKeySource):
index = source.index.index
else:
@@ -1042,9 +1042,9 @@ class GuardBuilder(GuardBuilderBase):
self.key_order_guarded_dict_ids = set()
assert self.check_fn_manager.output_graph is not None
for source in self.check_fn_manager.output_graph.guard_on_key_order:
- dict_obj = self.get(source.name())
+ dict_obj = self.get(source.name)
if self.save_guards:
- self.source_get_cache[source.name()] = dict_obj
+ self.source_get_cache[source.name] = dict_obj
self.key_order_guarded_dict_ids.add(id(dict_obj))
# Keep track of weak references of objects with ID_MATCH guard. This
@@ -1072,7 +1072,7 @@ class GuardBuilder(GuardBuilderBase):
)
# Iterate over the dicts and install a dict_getitem_manager.
- dict_source = guard.originating_source.name()
+ dict_source = guard.originating_source.name
# Ensure that we call dict.keys and not value.keys (which can call
# overridden keys method). In the C++ guards, we relied on PyDict_Next
@@ -1255,7 +1255,7 @@ class GuardBuilder(GuardBuilderBase):
l1_guard_manager_enum = l2_guard_manager_enum = None
if l2_key:
l1_source = AttrSource(source.base, l1_key)
- l1_source_name = l1_source.name()
+ l1_source_name = l1_source.name
l1_value = mod_dict[l1_key]
# do not guard on key order for _parameters etc unless the user code
# actually needs the key order (e.g. calling named_parameters)
@@ -1303,7 +1303,7 @@ class GuardBuilder(GuardBuilderBase):
return l1_mgr
def requires_key_order_guarding(self, source: Source) -> bool:
- source_name = source.name()
+ source_name = source.name
if source_name == "":
return False
obj_id = id(self.get(source_name))
@@ -1346,7 +1346,7 @@ class GuardBuilder(GuardBuilderBase):
root_guard_manager = self.guard_manager.root
example_value = None
- source_name = source.name()
+ source_name = source.name
if source_name != "" and source_name in self._cached_guard_managers:
return self._cached_guard_managers[source_name]
@@ -1363,7 +1363,7 @@ class GuardBuilder(GuardBuilderBase):
base_guard_manager = None
base_guard_manager_enum = GuardManagerType.GUARD_MANAGER
if isinstance(source, ChainedSource):
- base_source_name = source.base.name()
+ base_source_name = source.base.name
base_example_value = self.get(base_source_name)
base_guard_manager = self.get_guard_manager_from_source(source.base)
base_guard_manager_enum = self.get_guard_manager_type(
@@ -1747,10 +1747,10 @@ class GuardBuilder(GuardBuilderBase):
)
else:
raise AssertionError(
- f"missing guard manager builder {source} - {source.name()}"
+ f"missing guard manager builder {source} - {source.name}"
)
- self._cached_guard_managers[source.name()] = out
+ self._cached_guard_managers[source.name] = out
return out
def get_guard_manager(self, guard: Guard) -> GuardManager:
@@ -1849,7 +1849,7 @@ class GuardBuilder(GuardBuilderBase):
return
assert isinstance(source, AttrSource), f"invalid source {guard.name}"
base_source = source.base
- base = base_source.name()
+ base = base_source.name
attr = source.member
ref = self.arg_ref(base)
@@ -1871,7 +1871,7 @@ class GuardBuilder(GuardBuilderBase):
if val:
# Just install a getattr manager. GetAttrGuardAccessor itself
# acts as hasattr guard.
- example_value = self.get(source.name())
+ example_value = self.get(source.name)
base_example_value = self.get(base)
guard_manager_enum = self.get_guard_manager_type(source, example_value)
@@ -1884,7 +1884,7 @@ class GuardBuilder(GuardBuilderBase):
base_example_value,
example_value,
base,
- source.name(),
+ source.name,
guard_manager_enum,
)
else:
@@ -2419,7 +2419,7 @@ class GuardBuilder(GuardBuilderBase):
self.check_fn_manager.additional_used_global_vars.add(name)
ref_a = self.arg_ref(guard)
- ref_b = self.arg_ref(source_b.name())
+ ref_b = self.arg_ref(source_b.name)
if is_from_optimizer_source(
guard.originating_source
@@ -2694,7 +2694,7 @@ class GuardBuilder(GuardBuilderBase):
python_fallback = True
else:
example_value = self.get(
- source.name(),
+ source.name,
closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
)
if isinstance(example_value, int):
@@ -3895,11 +3895,11 @@ class CheckFunctionManager:
guard_source = source.guard_source()
if guard_source is GuardSource.CONSTANT:
# No need to track constants
- return source.name()
+ return source.name
assert w_builder
r_builder = w_builder()
assert r_builder is not None
- return r_builder.arg_ref(source.name())
+ return r_builder.arg_ref(source.name)
builder = GuardBuilder(
f_code,
@@ -4063,7 +4063,7 @@ class CheckFunctionManager:
if isinstance(guard, DuplicateInputs):
source_a = guard.input_source_a
source_b = guard.input_source_b
- code_part = f"{source_a.name()} is {source_b.name()}"
+ code_part = f"{source_a.name} is {source_b.name}"
install_object_aliasing_guard(
builder.get_guard_manager_from_source(source_a),
builder.get_guard_manager_from_source(source_b),
@@ -4081,8 +4081,8 @@ class CheckFunctionManager:
]
code_part = (
"""check_overlapping("""
- f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """
- f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])"""
+ f"""overlapping=[{", ".join(s.name for s in guard.overlapping_sources)}], """
+ f"""non_overlapping=[{", ".join(s.name for s in guard.non_overlapping_sources)}])"""
)
install_storage_overlapping_guard(
overlapping_guard_managers,
@@ -4533,7 +4533,7 @@ def make_dupe_guard(
dupe_source
) or is_from_flatten_script_object_source(obj_source):
raise exc.UnsafeScriptObjectError(
- f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported."
+ f"{obj_source.name} is aliasing {dupe_source.name}. This is not supported."
f" Please do a clone for corresponding input."
)
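Throughout `guards.py`, the string returned by `source.name` is still used as an eval()-able guard expression; only the spelling at the call sites changes. For instance, the aliasing guard above would build its code part along these lines (hypothetical source names, for illustration only):

```python
# Stand-ins for guard.input_source_a.name / guard.input_source_b.name
source_a_name = "L['a']"             # LocalSource("a").name
source_b_name = "L['kwargs']['a']"   # a DictGetItemSource chained on a LocalSource
code_part = f"{source_a_name} is {source_b_name}"
print(code_part)  # L['a'] is L['kwargs']['a']
```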
@@ -1230,7 +1230,7 @@ class OutputGraph(OutputGraphCommon):
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
- OutputGraph.module_key_name(new_source.name())
+ OutputGraph.module_key_name(new_source.name)
] = leaf_name
# annoying, but there are cases when we do not have parameters
@@ -2559,7 +2559,7 @@ class OutputGraph(OutputGraphCommon):
return None
def remove_unused(node: fx.Node) -> None:
- log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
+ log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name)
# I'm not really sure why you need to delete these from the
# node since the node is going to get removed
del node.meta["grapharg"]
@@ -2741,7 +2741,7 @@ class OutputGraph(OutputGraphCommon):
def add_fqn_info_for_inlined_modules(
self, inlined_module: torch.nn.Module, source: Source
) -> None:
- name = OutputGraph.module_key_name(source.name())
+ name = OutputGraph.module_key_name(source.name)
name = get_unique_name_wrt(
name, self.used_inlined_inbuilt_modules_names, self.global_scope
)
@@ -2754,7 +2754,7 @@ class OutputGraph(OutputGraphCommon):
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
- OutputGraph.module_key_name(new_source.name())
+ OutputGraph.module_key_name(new_source.name)
] = leaf_name
# annoying, but there are cases when we do not have parameters
@@ -3306,7 +3306,7 @@ class SubgraphTracer(fx.Tracer):
log.debug(
"create_graph_input %s %s %s at debug_level %s before=%s",
name,
- source.name() if source is not None else "(none)",
+ source.name if source is not None else "(none)",
example_value,
self.debug_level,
before,
@@ -3652,7 +3652,7 @@ class SubgraphTracer(fx.Tracer):
log.debug(
"_lift_symbols_in_symint %s from %s at debug_level %s",
s0,
- source.name() if source is not None else "subgraph inputs",
+ source.name if source is not None else "subgraph inputs",
self.debug_level,
)
self.lifted_freevars[parent_proxy] = ph  # type: ignore[index]
@@ -3678,7 +3678,7 @@ class SubgraphTracer(fx.Tracer):
log.debug(
"_lift_symbols_in_symint %s from %s at debug_level %s",
s,
- source.name() if source is not None else "subgraph inputs",
+ source.name if source is not None else "subgraph inputs",
self.debug_level,
)
ph.node.meta["grapharg"] = GraphArg(
@@ -117,7 +117,7 @@ def _get_source_debug_name(source: Optional[Source]) -> str:
return "<unknown source>"
else:
try:
- return source.name()
+ return source.name
except NotImplementedError:
return "<unknown source>"
@@ -147,6 +147,7 @@ class LocalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.LOCAL
+ @functools.cached_property
def name(self) -> str:
return f"L[{repr(self.local_name)}]"
@@ -162,6 +163,7 @@ class TempLocalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.TEMP_LOCAL
+ @functools.cached_property
def name(self) -> str:
raise NotImplementedError(
"Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub."
@@ -178,6 +180,7 @@ class SyntheticLocalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.SYNTHETIC_LOCAL
+ @functools.cached_property
def name(self) -> str:
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
@@ -194,6 +197,7 @@ class RandomValueSource(Source):
codegen.append_output(codegen.create_load_const(self.random_call_index))
codegen.append_output(create_binary_subscr())
+ @functools.cached_property
def name(self) -> str:
return f"random_value_{self.random_call_index}"
@@ -208,6 +212,7 @@ class GlobalSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
+ @functools.cached_property
def name(self) -> str:
return f"G[{repr(self.global_name)}]"
@@ -227,6 +232,7 @@ class GlobalWeakRefSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
+ @functools.cached_property
def name(self) -> str:
return f"G[{repr(self.global_name)}]()"
@@ -240,8 +246,9 @@ class WeakRefCallSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}()"
+ return f"{self.base.name}()"

@dataclasses.dataclass(frozen=True)
@@ -269,10 +276,11 @@ class AttrSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
if not self.member.isidentifier():
- return f"getattr({self.base.name()}, {self.member!r})"
- return f"{self.base.name()}.{self.member}"
+ return f"getattr({self.base.name}, {self.member!r})"
+ return f"{self.base.name}.{self.member}"

@dataclasses.dataclass(frozen=True)
@@ -295,8 +303,9 @@ class GenericAttrSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"object.__getattribute__({self.base.name()}, {self.member!r})"
+ return f"object.__getattribute__({self.base.name}, {self.member!r})"

# Represents obj.__dict__ where obj is a type object
@@ -309,12 +318,13 @@ class TypeDictSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
# type(ob).__dict__ can return a proxy of the dict. But in the C++
# guard accessor, we are use type->tp_dict which is a dict. So,
# forcefully pass a dict object to ensure that the GuardManager
# registers that its working on a dict object.
- return f"dict({self.base.name()}.__dict__)"
+ return f"dict({self.base.name}.__dict__)"

# Represents obj.__mro__ where object is type object
@@ -327,8 +337,9 @@ class TypeMROSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.__mro__"
+ return f"{self.base.name}.__mro__"

@dataclasses.dataclass(frozen=True)
@@ -360,8 +371,9 @@ class CodeSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.__code__"
+ return f"{self.base.name}.__code__"

# Represents obj.__closure__ where object is type object
@@ -374,8 +386,9 @@ class ClosureSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.__closure__"
+ return f"{self.base.name}.__closure__"

# Represents tensor.grad source. It could be represented by AttrSource as well.
@@ -393,8 +406,9 @@ class GradSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.{self.member}"
+ return f"{self.base.name}.{self.member}"

@dataclasses.dataclass(frozen=True)
@@ -425,6 +439,7 @@ class EphemeralSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.EPHEMERAL
+ @functools.cached_property
def name(self) -> str:
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
@@ -443,8 +458,9 @@ class SkipGuardSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return self.base.name()
+ return self.base.name

class TensorProperty(enum.Enum):
@@ -492,14 +508,15 @@ class TensorPropertySource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
if self.prop is TensorProperty.SIZE:
- return f"{self.base.name()}.size()[{self.idx}]"
+ return f"{self.base.name}.size()[{self.idx}]"
elif self.prop is TensorProperty.STRIDE:
- return f"{self.base.name()}.stride()[{self.idx}]"
+ return f"{self.base.name}.stride()[{self.idx}]"
elif self.prop is TensorProperty.STORAGE_OFFSET:
assert self.idx is None
- return f"{self.base.name()}.storage_offset()"
+ return f"{self.base.name}.storage_offset()"
else:
raise AssertionError(f"unhandled {self.prop}")
@@ -517,8 +534,9 @@ class IndexedSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"({self.idx}, {self.base.name()})"
+ return f"({self.idx}, {self.base.name})"

@dataclasses.dataclass(frozen=True)
@@ -532,9 +550,10 @@ class NegateSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
# NB: use method call so that function stripping regexes work
- return f"{self.base.name()}.__neg__()"
+ return f"{self.base.name}.__neg__()"

@dataclasses.dataclass(frozen=True)
@@ -548,8 +567,9 @@ class ConvertIntSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"cast_symbool_to_symint_guardless({self.base.name()})"
+ return f"cast_symbool_to_symint_guardless({self.base.name})"

@dataclasses.dataclass(frozen=True)
@@ -571,8 +591,9 @@ class DynamicScalarSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"int({self.base.name()})"
+ return f"int({self.base.name})"

@dataclasses.dataclass(frozen=True)
@@ -586,8 +607,9 @@ class FlattenScriptObjectSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.__obj_flatten__()"
+ return f"{self.base.name}.__obj_flatten__()"

@dataclasses.dataclass(frozen=True)
@@ -601,8 +623,9 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}._type().qualified_name()"
+ return f"{self.base.name}._type().qualified_name()"

class AttrProxySource(ChainedSource):
@@ -612,8 +635,9 @@ class AttrProxySource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.get_base()"
+ return f"{self.base.name}.get_base()"

@dataclasses.dataclass(frozen=True)
@@ -631,13 +655,13 @@ class DefaultsSource(ChainedSource):
assert isinstance(self.idx_key, str)
object.__setattr__(self, "field", "__kwdefaults__")
object.__setattr__(
- self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
+ self, "_name", f"{self.base.name}.{self.field}['{self.idx_key}']"
)
else:
assert isinstance(self.idx_key, int)
object.__setattr__(self, "field", "__defaults__")
object.__setattr__(
- self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
+ self, "_name", f"{self.base.name}.{self.field}[{self.idx_key}]"
)
def reconstruct(self, codegen: "PyCodegen") -> None:
@@ -649,6 +673,7 @@ class DefaultsSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
return self._name
@@ -681,15 +706,16 @@ class GetItemSource(ChainedSource):
slice_class, slice_args = self.index
return slice_class(*slice_args)
+ @functools.cached_property
def name(self) -> str:
# Index can be of following types
# 1) index is a slice - example 1:4
# 2) index is a constant - example string, integer
assert not isinstance(self.index, Source)
if self.index_is_slice:
- return f"{self.base.name()}[{self.unpack_slice()!r}]"
+ return f"{self.base.name}[{self.unpack_slice()!r}]"
else:
- return f"{self.base.name()}[{self.index!r}]"
+ return f"{self.base.name}[{self.index!r}]"

@dataclasses.dataclass(frozen=True)
@@ -707,9 +733,10 @@ class ConstDictKeySource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
+ @functools.cached_property
def name(self) -> str:
# The list creation will be CSE'd by PyExprCSEPass
- return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
+ return f"list(dict.keys({self.base.name}))[{self.index!r}]"
def is_dict_key(self) -> bool:
return True
@@ -735,9 +762,10 @@ class NonSerializableSetGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
+ @functools.cached_property
def name(self) -> str:
# set ordering might not be stable
- return f"list({self.base.name()})[{self.index!r}]"
+ return f"list({self.base.name})[{self.index!r}]"
def is_dict_key(self) -> bool:
return False
@@ -772,11 +800,12 @@ class DictGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_binary_subscr())
+ @functools.cached_property
def name(self) -> str:
if isinstance(self.index, ConstDictKeySource):
- return f"{self.base.name()}[{self.index.name()}]"
+ return f"{self.base.name}[{self.index.name}]"
else:
- return f"{self.base.name()}[{self.index!r}]"
+ return f"{self.base.name}[{self.index!r}]"

# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that
@@ -817,11 +846,12 @@ class DictSubclassGetItemSource(ChainedSource):
codegen.extend_output(create_call_function(2, False))
+ @functools.cached_property
def name(self) -> str:
if isinstance(self.index, ConstDictKeySource):
- return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
+ return f"dict.__getitem__({self.base.name}, {self.index.name})"
else:
- return f"{self.base.name()}[{self.index!r}]"
+ return f"{self.base.name}[{self.index!r}]"

@dataclasses.dataclass(frozen=True)
@@ -852,6 +882,7 @@ class ListGetItemSource(GetItemSource):
codegen.extend_output(create_call_function(2, False))
+ @functools.cached_property
def name(self) -> str:
# Index can be of following types
# 1) index is a slice - example 1:4
@@ -862,7 +893,7 @@ class ListGetItemSource(GetItemSource):
"List[slice] is a temporary object and should not have a source"
)
else:
- return f"list.__getitem__({self.base.name()}, {self.index!r})"
+ return f"list.__getitem__({self.base.name}, {self.index!r})"

@dataclasses.dataclass(frozen=True)
@@ -875,8 +906,9 @@ class TupleIteratorGetItemSource(GetItemSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
+ @functools.cached_property
def name(self) -> str:
- return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
+ return f"___tuple_iterator_getitem({self.base.name}, {self.index!r})"

@dataclasses.dataclass(frozen=True)
@@ -888,8 +920,9 @@ class NamedTupleFieldsSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"___namedtuple_fields({self.base.name()})"
+ return f"___namedtuple_fields({self.base.name})"

@dataclasses.dataclass(frozen=True)
@@ -904,8 +937,9 @@ class DataclassFieldsSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"___dataclass_fields({self.base.name()})"
+ return f"___dataclass_fields({self.base.name})"

@dataclasses.dataclass(frozen=True)
@@ -921,8 +955,9 @@ class TypeSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return f"type({self.base.name()})"
+ return f"type({self.base.name})"

@dataclasses.dataclass(frozen=True)
@@ -933,8 +968,9 @@ class OptimizerSource(ChainedSource):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
+ @functools.cached_property
def name(self) -> str:
- return self.base.name()
+ return self.base.name

@dataclasses.dataclass(frozen=True)
@@ -945,8 +981,9 @@ class NNModuleSource(ChainedSource):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
+ @functools.cached_property
def name(self) -> str:
- return self.base.name()
+ return self.base.name

@dataclasses.dataclass(frozen=True)
@@ -969,6 +1006,7 @@ class FSDPNNModuleSource(NNModuleSource):
@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
+ @functools.cached_property
def name(self) -> str:
return ""
@@ -987,6 +1025,7 @@ class TorchSource(Source):
install_guard(self.make_guard(GuardBuilder.ID_MATCH))
+ @functools.cached_property
def name(self) -> str:
return "__import__('torch')"
@@ -1007,6 +1046,7 @@ class TorchSource(Source):
class TorchFunctionModeStackSource(Source):
ind: int
+ @functools.cached_property
def name(self) -> str:
return f"___get_torch_function_mode_stack_at({self._get_index()})"
@@ -1038,6 +1078,7 @@ class ConstantSource(Source):
def guard_source(self) -> GuardSource:
return GuardSource.CONSTANT
+ @functools.cached_property
def name(self) -> str:
return self.source_name
@@ -1047,8 +1088,9 @@ class ConstantSource(Source):
@dataclasses.dataclass(frozen=True)
class NumpyTensorSource(ChainedSource):
+ @functools.cached_property
def name(self) -> str:
- return f"___from_numpy({self.base.name()})"
+ return f"___from_numpy({self.base.name})"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@@ -1061,8 +1103,9 @@ class NumpyTensorSource(ChainedSource):
@dataclasses.dataclass(frozen=True)
class SubclassAttrListSource(ChainedSource):
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.__tensor_flatten__()[0]"
+ return f"{self.base.name}.__tensor_flatten__()[0]"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@@ -1072,8 +1115,9 @@ class SubclassAttrListSource(ChainedSource):
# source, it is ephemeral
@dataclasses.dataclass(frozen=True)
class FloatTensorSource(ChainedSource):
+ @functools.cached_property
def name(self) -> str:
- return f"___as_tensor({self.base.name()})"
+ return f"___as_tensor({self.base.name})"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@@ -1081,8 +1125,9 @@ class FloatTensorSource(ChainedSource):
@dataclasses.dataclass(frozen=True)
class CallMethodItemSource(ChainedSource):
+ @functools.cached_property
def name(self) -> str:
- return f"{self.base.name()}.item()"
+ return f"{self.base.name}.item()"
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@@ -1093,6 +1138,7 @@ class CallMethodItemSource(ChainedSource):
# guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source):
+ @functools.cached_property
def name(self) -> str:
return ""
@@ -1104,6 +1150,7 @@ class ShapeEnvSource(Source):
class CurrentStreamSource(Source):
device: device_type
+ @functools.cached_property
def name(self) -> str:
return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))"
@@ -1126,6 +1173,7 @@ class CurrentStreamSource(Source):
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
+ @functools.cached_property
def name(self) -> str:
return ""
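All of the `Source` subclasses above follow the same recipe: decorate `name` with `functools.cached_property` and drop the parentheses when reading `self.base.name`. One practical effect, sketched below under the assumption that nothing else memoizes these strings, is that a deeply chained source computes its expression string only once instead of rebuilding the whole chain on every guard lookup (the `Src` class is a hypothetical stand-in):

```python
import dataclasses
import functools

@dataclasses.dataclass(frozen=True)
class Src:
    base: "Src | None"
    member: str
    # counts how often the name string is actually built (illustration only)
    calls: list = dataclasses.field(default_factory=list, compare=False)

    @functools.cached_property
    def name(self) -> str:
        self.calls.append(1)
        prefix = f"{self.base.name}." if self.base is not None else ""
        return f"{prefix}{self.member}"

leaf = Src(Src(Src(None, "L['mod']"), "layer"), "weight")
assert leaf.name == "L['mod'].layer.weight"
assert leaf.name == "L['mod'].layer.weight"  # second read hits the per-instance cache
assert len(leaf.calls) == 1                  # the string was built exactly once
```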
@@ -389,7 +389,7 @@ class GraphArg:
self.example_strong_ref = None
def __eq__(self, other):
- return self.source.name() == other.source.name()
+ return self.source.name == other.source.name

class BackwardStateGraphArg(GraphArg):
@@ -444,7 +444,7 @@ class VariableBuilder:
super().__init__()
self.tx = tx
self.source = source
- self.name = source.name()
+ self.name = source.name
def __call__(self, value):
if value in self.tx.output.side_effects:
@@ -1645,7 +1645,7 @@ class VariableBuilder:
elif value.dynamism.type == _DimHintType.DYNAMIC:
log.debug(
"%s marked %s via IntWrapper",
- self.source.name(),
+ self.source.name,
DimDynamic.DYNAMIC,
)
return self.wrap_symint(
@@ -1658,7 +1658,7 @@ class VariableBuilder:
elif value.dynamism.type == _DimHintType.AUTO:
log.debug(
"%s marked %s via IntWrapper",
- self.source.name(),
+ self.source.name,
DimDynamic.DYNAMIC,
)
return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
@@ -1831,7 +1831,7 @@ class VariableBuilder:
from ..decorators import mark_static_address
static_inputs_log.debug(
- "Marking static input %s, id: %s)", self.source.name(), id(value)
+ "Marking static input %s, id: %s)", self.source.name, id(value)
)
mark_static_address(value, guard=guard)
@@ -2003,12 +2003,12 @@ class VariableBuilder:
def wrap_literal(self, value):
if type(value) is int:
# allowlist has higher precedence over specialization control.
- if is_dynamic_source(self.source.name()):
- log.debug("%s marked dynamic via source whitelist", self.source.name())
+ if is_dynamic_source(self.source.name):
+ log.debug("%s marked dynamic via source whitelist", self.source.name)
return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
- if is_unbacked_source(self.source.name()):
- log.debug("%s marked unbacked via source whitelist", self.source.name())
+ if is_unbacked_source(self.source.name):
+ log.debug("%s marked unbacked via source whitelist", self.source.name)
return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)
if not config.specialize_int:
@@ -2034,7 +2034,7 @@ class VariableBuilder:
process_automatic_dynamic(
self.tx,
- self.source.name(),
+ self.source.name,
FrameStateSizeEntry.make_scalar(value),
is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
)
@@ -2440,7 +2440,7 @@ class VariableBuilder:
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value, source=self.source)
- name = self.source.name()
+ name = self.source.name
frame_state_entry = process_automatic_dynamic(
self.tx,
@@ -2453,7 +2453,7 @@ class VariableBuilder:
# know if bare integers are actually going to be sizevars
# and it is inappropriate to eagerly duck size them with
# real sizevars
- normalized_source_name = normalize_source_name(self.source.name())
+ normalized_source_name = normalize_source_name(self.source.name)
base_source = self.source
if isinstance(base_source, ChainedSource):
base_source = base_source.get_base()
@@ -2539,7 +2539,7 @@ class VariableBuilder:
frame_state_entry = process_automatic_dynamic(
self.tx,
- self.source.name(),
+ self.source.name,
FrameStateSizeEntry.make_scalar(value),
is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
)
@@ -3386,7 +3386,7 @@ def _automatic_dynamic(
hints=[],
)
- name = source.name()
+ name = source.name
prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
shape_env_to_source_to_symbol_cache = (
prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
@@ -3509,7 +3509,7 @@ def _automatic_dynamic(
# Reflect the user directive in the frame_state
# For dynamic, apply None always
- normalized_source_name = normalize_source_name(source.name())
+ normalized_source_name = normalize_source_name(source.name)
base_source = source
if isinstance(base_source, ChainedSource):
base_source = base_source.get_base()
@@ -3670,7 +3670,7 @@ def wrap_to_fake_tensor_and_record(
log.debug(
"wrap_to_fake %s %s %s %s",
- source.name(),
+ source.name,
tuple(e.shape),
symbolic_context,
type(e),
@@ -1080,7 +1080,7 @@ def check_aliasing_and_input_mutation(
unimplemented(
gb_type="Encountered input mutation during higher order op tracing",
context=context,
- explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}",
+ explanation=f"Higher order ops do not support input mutation. Found in {source_target.name}",
hints=[
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue.",
@@ -1094,7 +1094,7 @@ def check_aliasing_and_input_mutation(
unimplemented(
gb_type="Encountered aliasing during higher order op tracing",
context=context,
- explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}",
+ explanation=f"Higher order ops do not support aliasing. Found in {source_target.name}",
hints=[
"Replace `return input` with `return input.clone()` to avoid aliasing.",
"Consider using the debug context to change user code to avoid aliasing.",
@@ -572,7 +572,7 @@ class DelayGraphBreakVariable(UnknownVariable):
gb_type="Unsupported function call (delayed)",
context=f"source: {self.source}",
explanation="Dynamo determined that a graph break should occur "
- f"when calling `{self.source.name()}`. Reason: {self.msg}",
+ f"when calling `{self.source.name}`. Reason: {self.msg}",
hints=[],
)
@@ -113,7 +113,7 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
- fully_qualified_name = source.name()
+ fully_qualified_name = source.name
# Remove redundant namings
fully_qualified_name = re.sub(
r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]",
@@ -323,7 +323,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
# Note: to avoid spam logs only warn if perf hint artifact is enabled
# (NB: artifacts are only enabled at the debug or warning level)
if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
- non_static_grad_names = [src.name() for src in non_static_grads]
+ non_static_grad_names = [src.name for src in non_static_grads]
perf_hint_log.warning(
(
"Grad tensors %s will be copied during cudagraphs execution."
@@ -365,7 +365,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value, guard=True)
source = self.tensor_to_source[tensor_value]
- self.static_tensor_names.add(tx.output.module_key_name(source.name()))
+ self.static_tensor_names.add(tx.output.module_key_name(source.name))
elif tensor_value in self.grad_to_source:
source = self.grad_to_source[tensor_value]
else:
@@ -374,7 +374,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
source = GlobalWeakRefSource(global_name)
- self.static_tensor_names.add(tx.output.module_key_name(source.name()))
+ self.static_tensor_names.add(tx.output.module_key_name(source.name))
return VariableTracker.build(tx, tensor_value, source)
@@ -314,7 +314,7 @@ class TensorVariable(VariableTracker):
# eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
# L['mod'].model.model.encoder.embed_positions)", scope)
# Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
- _input_associated_real_value = eval(self.source.name(), scope)
+ _input_associated_real_value = eval(self.source.name, scope)
except Exception as exc:
raise NotImplementedError from exc
@@ -551,7 +551,7 @@ class TensorVariable(VariableTracker):
# For local source, we associate the real value. We use this real value
scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
try:
- _input_associated_real_value = eval(self.source.name(), scope)
+ _input_associated_real_value = eval(self.source.name, scope)
except Exception as exc:
unimplemented(
gb_type="Error getting associated real value",
@@ -280,7 +280,7 @@ def _create_symbolic_context_for_tensor(t, source, t_constraints, sources, mode)
if isinstance(constraint, _RelaxedConstraint):
continue
symbolic_context.constraint_sizes[i] = constraint.constraint_range
- mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name  # type: ignore[assignment]
+ mode.shape_env.source_name_to_debug_name[src.name] = constraint.name  # type: ignore[assignment]
return symbolic_context
@@ -173,7 +173,7 @@ def _try_get_metadata_from_dynamo(
assert source is None or source not in seen_sources, source
seen_sources.add(source)
aot_autograd_arg_pos_to_source.append(source)
- source_name = source.name() if source else str(source)
+ source_name = source.name if source else str(source)
# input[i] in dynamo is now:
# input[i + len(extra_params)] in AOT,
@@ -245,7 +245,7 @@ class Guard:
# globals (and locals, if you create a LOCAL guard) to extract the Python
# object that we want to perform guard tests on. This evaluation
# typically happens in GuardBuilder.eval. In these cases, name is
- # typically produced by originating_source.name() (not to be confused with
+ # typically produced by originating_source.name (not to be confused with
# GuardSource - the property source).
#
# Occasionally, name is not a valid Python expression; sometimes
@@ -297,7 +297,7 @@ class Guard:
@property
def name(self) -> str:
- return self.originating_source.name()
+ return self.originating_source.name
@property
def source(self) -> GuardSource:
@@ -1074,6 +1074,7 @@ class Source:
def guard_source(self) -> GuardSource:
raise NotImplementedError
+ @functools.cached_property
def name(self) -> str:
raise NotImplementedError
@@ -870,7 +870,7 @@ class MetaConverter(Generic[_TensorT]):
# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
- # precisely to the Source.name() of the tensor we're fakeifying and
+ # precisely to the Source.name of the tensor we're fakeifying and
# corresponds to a valid Python expression. When we construct sub-names
# as part of this process, we will maintain this invariant! (Even though
# other users of this may not need it this property to be upheld.)
@@ -1937,7 +1937,7 @@ class MetaConverter(Generic[_TensorT]):
metadata_fn=lambda: {
"describer_id": self.describer.id,
"id": t_desc.id,
- "source": source.name(),
+ "source": source.name,
},
)
@@ -50,6 +50,56 @@ static py::handle _callback_from_action(
return callback;
}

+ // c_recursion_remaining only defined in 3.12 and 3.13
+
+ static int32_t c_recursion_limit = -1;
+
+ void set_c_recursion_limit(int32_t limit) {
+ if (limit < 1) {
+ throw std::range_error("recursion limit must be greater or equal than 1");
+ }
+ c_recursion_limit = limit;
+ // cannot fail
+ Py_SetRecursionLimit(limit); // also set the Python limit
+ }
+
+ int32_t get_c_recursion_limit() {
+ return c_recursion_limit;
+ }
+
+ #if IS_PYTHON_3_12_PLUS && !IS_PYTHON_3_14_PLUS
+
+ struct CRecursionLimitRAII {
+ PyThreadState* tstate;
+ int32_t old_recursion_remaining;
+ CRecursionLimitRAII(PyThreadState* tstate) : tstate{tstate} {
+ auto limit = get_c_recursion_limit();
+ auto& remaining = tstate->c_recursion_remaining;
+ this->old_recursion_remaining = remaining;
+ if (limit < 0) {
+ // no change to limit
+ return;
+ }
+ if (limit < remaining) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ "new c_recursion limit is lower than thread's current c_recursion_remaining.");
+ }
+ remaining = limit;
+ }
+ ~CRecursionLimitRAII() {
+ this->tstate->c_recursion_remaining = this->old_recursion_remaining;
+ }
+ };
+
+ #else
+
+ struct CRecursionLimitRAII {
+ CRecursionLimitRAII(PyThreadState* tstate) {}
+ };
+
+ #endif

// frame and callback are borrowed references.
// Returns new reference.
PyObject* dynamo__custom_eval_frame(
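`CRecursionLimitRAII` temporarily swaps the thread's `c_recursion_remaining` to the configured limit for the duration of a Dynamo compile and restores the old value on scope exit; outside CPython 3.12/3.13 it is a no-op. A Python-flavored sketch of the same save/restore idea, offered only as an analogy to the C++ above (it raises instead of setting a pending error, and `state` is a hypothetical stand-in for the per-thread counter):

```python
from contextlib import contextmanager

@contextmanager
def temporarily_raised_limit(state: dict, limit: int):
    # limit < 0 means "not configured", matching the -1 sentinel above
    old = state["c_recursion_remaining"]
    if limit >= 0:
        if limit < old:
            raise RuntimeError(
                "new c_recursion limit is lower than thread's current c_recursion_remaining."
            )
        state["c_recursion_remaining"] = limit
    try:
        yield
    finally:
        state["c_recursion_remaining"] = old
```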
@@ -258,6 +308,13 @@ PyObject* dynamo__custom_eval_frame(
bool apply_to_code = false;
PyObject* guarded_code = nullptr;
try {
+ CRecursionLimitRAII tmp(tstate); // increase C recursion limit to the given
+ // value during compilation
+ // C recursion limit failure
+ if (PyErr_Occurred()) {
+ fail();
+ return eval_result;
+ }
callback_result = dynamo_call_callback(
callback, frame, locals.get(), cache_entry, frame_state);
new_strategy =
@@ -19,6 +19,9 @@ PyObject* dynamo__custom_eval_frame(
PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj);
void skip_code_recursive(PyCodeObject* code);

+ void set_c_recursion_limit(int32_t limit);
+ int32_t get_c_recursion_limit();
+
#ifdef __cplusplus
} // extern "C"
@@ -7,6 +7,7 @@
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/dynamo/eval_frame.h>
+ #include <torch/csrc/dynamo/eval_frame_cpp.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/dynamo/python_compiled_autograd.h>
@@ -250,6 +251,9 @@ void initDynamoBindings(PyObject* torch) {
.def_readwrite("cur_action", &FrameExecStrategy::cur_action)
.def_readwrite("recursive_action", &FrameExecStrategy::recursive_action);

+ m.def("set_c_recursion_limit", &set_c_recursion_limit);
+ m.def("get_c_recursion_limit", &get_c_recursion_limit);
+
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
m.def("_reset_precompile_entries", &_reset_precompile_entries);
m.def("_load_precompile_entry", &_load_precompile_entry);
@ -1918,7 +1918,7 @@ class StrictMinMaxConstraint(Constraint):
|
||||
def render(self, source: Source) -> str:
|
||||
"""Format the constrain equation"""
|
||||
# TODO: better printing for -oo and oo
|
||||
return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}"
|
||||
return f"{self.vr.lower} <= {source.name} <= {self.vr.upper}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -1943,7 +1943,7 @@ class RelaxedUnspecConstraint(Constraint):
|
||||
"""
|
||||
|
||||
def render(self, source: Source) -> str:
|
||||
return f"RelaxedUnspecConstraint({source.name()})"
|
||||
return f"RelaxedUnspecConstraint({source.name})"
|
||||
|
||||
|
||||
# NB: None here indicates the client constraint is whatever is implicitly
|
||||
@ -2039,7 +2039,7 @@ class EqualityConstraint(Constraint):
|
||||
return self._defs[src]
|
||||
else:
|
||||
# otherwise, create a symbol representing the source
|
||||
return sympy.Symbol(src.name())
|
||||
return sympy.Symbol(src.name)
|
||||
|
||||
def is_equal(self, source1: Source, source2: Source) -> bool:
|
||||
return (
|
||||
@ -2252,11 +2252,11 @@ class TrackedFake:
|
||||
symbolic_context: Optional[SymbolicContext]
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.fake, self.source.name()))
|
||||
return hash((self.fake, self.source.name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, TrackedFake):
|
||||
return self.fake is other.fake and self.source.name() == other.source.name()
|
||||
return self.fake is other.fake and self.source.name == other.source.name
|
||||
return False
|
||||
|
||||
|
||||
@ -2712,7 +2712,7 @@ class _ShapeGuardPrinter(abc.ABC):
|
||||
def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str:
|
||||
return repr(
|
||||
{
|
||||
symbol: [s.name() for s in sources]
|
||||
symbol: [s.name for s in sources]
|
||||
for symbol, sources in src.items()
|
||||
}
|
||||
)
|
||||
@ -2820,7 +2820,7 @@ class _ShapeGuardCppPrinter(_ShapeGuardPrinter, CppPrinter):
|
||||
if source in self.source_to_symbol:
|
||||
return self.source_to_symbol[source].name
|
||||
|
||||
source_name = source.name()
|
||||
source_name = source.name
|
||||
mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name)
|
||||
old_mangled_name = mangled_name
|
||||
count = 0
|
||||
@ -2849,7 +2849,7 @@ class _CppShapeGuardsHelper(_ShapeGuardsHelper):

class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter):
def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]):
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
super().__init__(var_to_sources, lambda n: n.name, var_to_sources)


class DynamicDimConstraintPrinter(PythonPrinter):
@ -2875,7 +2875,7 @@ class DynamicDimConstraintPrinter(PythonPrinter):
assert self.symbol_to_source.get(expr), (
f"Unknown symbol {expr} created by constraints solver"
)
return self.symbol_to_source[expr][0].name()
return self.symbol_to_source[expr][0].name


class DimConstraints:
@ -3095,7 +3095,7 @@ class DimConstraints:
"""Add an equality constraint"""
if expr.is_number:
# specialization, right here
self._static_results.add(f"{source.name()} == {expr}")
self._static_results.add(f"{source.name} == {expr}")
else:
# these will resolve to either specializations or dynamic equality constraints
self._symbolic_equivalences.append((source, expr))
@ -3175,7 +3175,7 @@ class DimConstraints:
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization
self._static_results.add(
f"{self._dcp.symbol_to_source[s][0].name()} == {val}"
f"{self._dcp.symbol_to_source[s][0].name} == {val}"
)
# add this as a substitution to simplify other constraints
self._substitutions[s] = val # type: ignore[assignment]
@ -3200,8 +3200,8 @@ class DimConstraints:
base, divisor = congruence.args
tmp_name = "_" + str(
self._dcp.source_name_to_debug_name.get(
self._dcp.symbol_to_source[s][0].name(),
self._dcp.symbol_to_source[s][0].name(),
self._dcp.symbol_to_source[s][0].name,
self._dcp.symbol_to_source[s][0].name,
)
)
tmp = sympy.Symbol(tmp_name, integer=True)
@ -3243,7 +3243,7 @@ class DimConstraints:

# remaining symbolic equivalences become dynamic equality constraints
for source, expr3 in self._symbolic_equivalences:
self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr3)}")
self._dynamic_results.add(f"{source.name} == {self._dcp.doprint(expr3)}")

@classmethod
def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool:
@ -3266,7 +3266,7 @@ class DimConstraints:
"""Returns a dictionary of the names of symbols to their specialized value"""

def debug_name(src: Source) -> str:
name = src.name()
name = src.name
if self._dcp.source_name_to_debug_name:
return f"{self._dcp.source_name_to_debug_name[name]} = {name}"
else:
@ -4011,7 +4011,7 @@ class ShapeEnv:
check_fn: A function that takes a sympy Symbol and returns a sympy expression
representing a constraint/specialization to be applied
"""
name = source.name()
name = source.name
sym = self.source_to_var[name]
expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr
new_axioms = dict(self.get_implications(self.simplify(expr)))
@ -4284,7 +4284,7 @@ class ShapeEnv:
def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
if not self._translation_validation_enabled:
return None
srcname = source.name()
srcname = source.name
if source not in self.source_to_symbol:
self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
return self.source_to_symbol[srcname]
@ -4874,7 +4874,7 @@ class ShapeEnv:
if source is None:
sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
else:
sloc, maybe_extra_debug = source.name(), ""
sloc, maybe_extra_debug = source.name, ""
log.info(
"%s %s [%s, %s] %s%s",
prefix,
@ -5028,7 +5028,7 @@ class ShapeEnv:
if constraint_dim.vr.lower != val:
raise ConstraintViolationError(
f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, "
f"for {source.name()}"
f"for {source.name}"
)
if symbolic_context:
from torch._dynamo.source import TensorPropertySource
@ -5041,7 +5041,7 @@ class ShapeEnv:
constraint_dim = None

# see note [Tensor Fakification and Symbol Caching]
source_name = source.name()
source_name = source.name
if (
isinstance(symbolic_context, StatefulSymbolicContext)
and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache
@ -5115,7 +5115,7 @@ class ShapeEnv:
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
symbol_id = self._generate_unique_id(source.name())
symbol_id = self._generate_unique_id(source.name)
if type(val) is int or is_nested_int(val):
sympy_expr = make_symbol(
SymT.SIZE, symbol_id, positive=positive, integer=True
@ -5219,7 +5219,7 @@ class ShapeEnv:
"create_symbol %s = %s for %s %s %s%s%s",
sympy_expr,
val,
source.name(),
source.name,
range_str,
sloc,
maybe_more_info,
@ -5232,7 +5232,7 @@ class ShapeEnv:
"symbol": str(sympy_expr),
"val": repr(val),
"vr": range_str,
"source": source.name(),
"source": source.name,
"user_stack": structured.from_traceback(
TracingContext.extract_stack()
),
@ -5248,7 +5248,7 @@ class ShapeEnv:
# the same symint
r = self.val_to_var[val]
self.source_to_var[source_name] = r
self.log.debug("create_symbol %s duck sized %s", r, source.name())
self.log.debug("create_symbol %s duck sized %s", r, source.name)

if isinstance(r, sympy.Symbol):
r_sources = self.var_to_sources[r]
@ -5275,7 +5275,7 @@ class ShapeEnv:
self.var_to_val[expr] = sympy.Integer(val)

def _debug_name(self, source: Source) -> str:
src_name = source.name()
src_name = source.name
return self.source_name_to_debug_name.get(src_name, src_name)

def _render_range_for_constraint_violation(
@ -5289,7 +5289,7 @@ class ShapeEnv:
if upper >= default.upper:
upper = None
c_render = (
f"{self._debug_name(source)} = {source.name()} in the specified range"
f"{self._debug_name(source)} = {source.name} in the specified range"
)
if lower is not None and upper is not None:
c_render += f" {lower} <= {self._debug_name(source)} <= {upper}"
@ -5311,7 +5311,7 @@ class ShapeEnv:
self,
placeholders: Sequence[FakeTensor],
sources: Sequence[Source],
source_ref: Callable[[Source], str] = lambda n: n.name(),
source_ref: Callable[[Source], str] = lambda n: n.name,
*,
guards: Optional[list[ShapeGuard]] = None,
input_contexts: Optional[DimList[SymbolicContext]] = None,
@ -5501,10 +5501,10 @@ class ShapeEnv:
if equalities_inputs:
source_index = {}
for i, src in enumerate(sources):
source_index[src.name()] = i
source_index[src.name] = i

def get_expression(tensor_dim_src: Source) -> sympy.Expr:
fake = placeholders[source_index[tensor_dim_src.base.name()]] # type: ignore[attr-defined]
fake = placeholders[source_index[tensor_dim_src.base.name]] # type: ignore[attr-defined]
assert tensor_dim_src.idx is not None # type: ignore[attr-defined]
symint = fake.shape[tensor_dim_src.idx] # type: ignore[attr-defined]
if isinstance(symint, torch.SymInt):
@ -5521,16 +5521,16 @@ class ShapeEnv:
concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
if not concrete_val:
raise ConstraintViolationError(
f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}"
f"{src1.name} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}"
" is not equal to "
f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}"
f"{src2.name} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}"
)

for srcEq, root, fn in equalities_inputs.derived_equalities:
expr1 = get_expression(srcEq)
# recall that root is either a phantom symbol or an input source
if isinstance(root, sympy.Symbol):
expr2, debug_name = root, self.var_to_sources[root][0].name()
expr2, debug_name = root, self.var_to_sources[root][0].name
elif isinstance(root, sympy.Integer):
expr2, debug_name = root, str(root)
else:
@ -5542,7 +5542,7 @@ class ShapeEnv:
concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
if not concrete_val:
raise ConstraintViolationError(
f"Expected input {srcEq.name()} to be equal to "
f"Expected input {srcEq.name} to be equal to "
f"{fn(sympy.Symbol(debug_name))}, "
f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, "
f"but got {expr1.xreplace(self.var_to_val)}"
@ -5764,7 +5764,7 @@ class ShapeEnv:

if not _simplified:
for source, expr in input_guards:
srcname = source.name()
srcname = source.name
if self._translation_validation_enabled:
# Ignore sources that were not turned into SymInts.
if srcname in self.source_to_symbol:
@ -5827,8 +5827,8 @@ class ShapeEnv:
)
):
msg = (
f"The values of {self._debug_name(source)} = {source.name()} and "
f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} "
f"The values of {self._debug_name(source)} = {source.name} and "
f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name} "
"must always be equal."
)
record_constraint_violation(
@ -5846,8 +5846,8 @@ class ShapeEnv:
):
src = symbol_to_source[symbol][0]
msg = (
f"The values of {self._debug_name(source)} = {source.name()} must always be related to "
f"the values of {self._debug_name(src)} = {src.name()} by "
f"The values of {self._debug_name(source)} = {source.name} must always be related to "
f"the values of {self._debug_name(src)} = {src.name} by "
f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}."
)
record_constraint_violation(
@ -6868,7 +6868,7 @@ class ShapeEnv:
"symbolic_shape_specialization",
metadata_fn=lambda: {
"symbol": repr(a),
"sources": [s.name() for s in self.var_to_sources.get(a, [])],
"sources": [s.name for s in self.var_to_sources.get(a, [])],
"value": repr(tgt),
"reason": msg,
"stack": structured.from_traceback(
@ -6886,7 +6886,7 @@ class ShapeEnv:

if config.print_specializations:
self.log.warning(
"Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
"Specializing %s to %s", self.var_to_sources[a][0].name, tgt
)
self.log.debug("SPECIALIZATION", stack_info=True)
log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
@ -7211,7 +7211,7 @@ class ShapeEnv:
if str(s) in frame_symbols: # type: ignore[operator]
continue
if s in self.var_to_sources:
frame_symbols[str(s)] = self.var_to_sources[s][0].name() # type: ignore[assignment]
frame_symbols[str(s)] = self.var_to_sources[s][0].name # type: ignore[assignment]
return str(x)
return None
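Every Python hunk above makes the same mechanical change: the guard-source name is read as an attribute (source.name) instead of being called as a method (source.name()), i.e. Source.name becomes a property-style, likely cached, attribute. A minimal sketch of that pattern on a hypothetical source class, not the actual torch._guards.Source implementation:

import functools
from dataclasses import dataclass

@dataclass(frozen=True)
class ExampleLocalSource:
    # Hypothetical stand-in for a Dynamo guard source.
    local_name: str

    @functools.cached_property
    def name(self) -> str:
        # Computed once per instance, then read via plain attribute access,
        # which is what lets call sites drop the parentheses.
        return f"L['{self.local_name}']"

src = ExampleLocalSource("foo")
assert src.name == "L['foo']"  # attribute access, no call
assert src.name is src.name    # cached: the same string object is reused

With this change, any leftover call site written as src.name() would fail with "TypeError: 'str' object is not callable", which is why every f-string, dict key, and lambda in the hunks above drops the parentheses.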