Compare commits

...

2 Commits

Author SHA1 Message Date
d1bde8d84b [Dynamo] Remove partial graph printing on data-dependent graph breaks
ghstack-source-id: da213ba5363760125fee99a70f9ddd2e3bf7acab
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149686
2025-03-20 17:19:04 -07:00
ac3faf97ce [Hierarchical Compilation] Handle origin nodes without children
ghstack-source-id: 0c7481535ff01ff8513ae2f6cb806ed6b3cdb4d2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149685
2025-03-20 17:19:00 -07:00
2 changed files with 18 additions and 22 deletions

View File

@ -179,6 +179,9 @@ class BackwardBfsArgIter:
else:
self._queue.append(arg)
def __str__(self) -> str:
return f"BackwardBfsArgIter(cur={self._cur}, queue={self._queue})"
class GraphRegionTracker:
"""
@ -315,7 +318,11 @@ def fully_expand_region_group(
region_it.add_children(node)
current_node = region_iters[0].next()
assert current_node is not None
# No children
if current_node is None:
return
# Loop incrementally adding new nodes to each region
# regions are only expanded if the node to add is valid
# for ALL regions

View File

@ -1340,18 +1340,7 @@ class InstructionTranslatorBase(
raise
except BackendCompilerFailed:
raise
except RuntimeError as e:
if hasattr(e, "msg") and "Data-dependent" in e.msg:
print(
"\n"
+ torch.fx.GraphModule(
self.output.nn_modules, self.output.graph
).print_readable(
print_output=False, include_stride=True, include_device=True
),
file=sys.stderr,
)
except RuntimeError:
raise
except Exception as e:
if self.exec_recorder:
@ -1369,9 +1358,9 @@ class InstructionTranslatorBase(
self.output.cleanup()
def push(self, val: Optional[VariableTracker]):
assert val is None or isinstance(val, VariableTracker), (
f"push expects VariableTracker, got {typestr(val)}"
)
assert val is None or isinstance(
val, VariableTracker
), f"push expects VariableTracker, got {typestr(val)}"
self.stack.append(val) # type: ignore[arg-type]
def push_many(self, vals: list[VariableTracker]):
@ -2340,9 +2329,9 @@ class InstructionTranslatorBase(
if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable):
# We don't allow side effects during export on non-constant values
# https://github.com/pytorch/torchdynamo/issues/1475
assert not self.export, (
f"Mutating module attribute {inst.argval} during export."
)
assert (
not self.export
), f"Mutating module attribute {inst.argval} during export."
try:
BuiltinVariable(setattr).call_function(
@ -3364,9 +3353,9 @@ class InstructionTranslator(InstructionTranslatorBase):
self.one_graph: bool = one_graph
self.export = export
if self.export:
assert self.one_graph, (
"Export without one graph - something has gone wrong."
)
assert (
self.one_graph
), "Export without one graph - something has gone wrong."
self.symbolic_locals = {}
# Populate `symbolic_locals` with non-cell variables.