Compare commits

...

1 Commits

Author SHA1 Message Date
e192036227 fix dynamo stack trace 2025-10-20 13:56:48 -07:00
5 changed files with 69 additions and 10 deletions

View File

@ -942,6 +942,22 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
self.assertRaises(Unsupported, fn)
def test_stack_trace_from_observed_exception(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)
def forward(self, x):
# no attribute w on self.linear
weight = self.linear.w
return torch.nn.functional.linear(x, weight)
x = (torch.randn(4, 16, requires_grad=True),)
with self.assertRaisesRegex(Exception, "weight = self.linear.w"):
torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
instantiate_parametrized_tests(ExceptionTests)

View File

@ -163,9 +163,17 @@ class BackendCompilerFailed(ShortenTraceback):
class Unsupported(TorchDynamoException):
def __init__(self, msg: str, *, case_name: Optional[str] = None) -> None:
def __init__(
self,
msg: str,
*,
case_name: Optional[str] = None,
real_stack: None | StackSummary = None,
) -> None:
super().__init__(msg)
self.real_stack = torch._guards.TracingContext.extract_stack()
if not real_stack:
real_stack = torch._guards.TracingContext.extract_stack()
self.real_stack = real_stack
self.msg = msg
self.category: Optional[str] = None
self.add_to_stats()
@ -300,7 +308,9 @@ class PackageError(TorchDynamoException):
class ObservedException(TorchDynamoException):
# An exception observed during the tracing. This exception is used by Dynamo to handle exceptions.
pass
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.real_stack: StackSummary = torch._guards.TracingContext.extract_stack()
class ObservedUserStopIteration(ObservedException):
@ -384,14 +394,22 @@ def raise_observed_exception(
*,
args: Optional[list[Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
msg: Optional[str] = None,
) -> NoReturn:
from .variables import BuiltinVariable
# CPython here raises an exception. Since there is no python code, we have to manually setup the exception
# stack and raise the exception.
# If a message is provided but no args, use the message as the first argument
if msg is not None and (args is None or len(args) == 0):
args = [msg]
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
tx.exn_vt_stack.set_current_exception(exception_vt) # type: ignore[arg-type]
raise get_dynamo_observed_exception(exc_type)
raised_exc = get_dynamo_observed_exception(exc_type)
# Store the original exception arguments for better error messages
if args:
raise raised_exc(*args)
raise raised_exc
def handle_observed_exception(tx: Any) -> None:
@ -598,7 +616,10 @@ def unimplemented_v2(
if log_warning:
log.warning(msg)
if from_exc is not _NOTHING:
raise Unsupported(msg) from from_exc
past_real_stack = None
if hasattr(from_exc, "real_stack"):
past_real_stack = from_exc.real_stack
raise Unsupported(msg, real_stack=past_real_stack) from from_exc
raise Unsupported(msg)

View File

@ -2204,6 +2204,7 @@ class InstructionTranslatorBase(
*graph_break_hints.USER_ERROR,
*graph_break_hints.SUPPORTABLE,
],
from_exc=raised_exception,
)
if sys.version_info >= (3, 11):

View File

@ -103,10 +103,14 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()}
try:
mod._infer_parameters(mod, fake_args, fake_kwargs)
except AttributeError:
except AttributeError as e:
# Re-raise with the original error message from the AttributeError
raise_observed_exception(
AttributeError,
tx,
msg=str(e)
if str(e)
else "AttributeError during lazy module initialization",
)
@ -363,6 +367,7 @@ class NNModuleVariable(VariableTracker):
raise_observed_exception(
AttributeError,
tx,
msg=f"'{type(base).__name__}' object has no attribute '{name}'",
)
if name == "forward":
@ -1192,7 +1197,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
if out is None:
out = self.getattr_helper(tx, "_buffers", name_vt)
if out is None:
raise_observed_exception(AttributeError, tx)
raise_observed_exception(
AttributeError,
tx,
msg=f"'{type(self.value).__name__}' object has no attribute '{name}'",
)
return out

View File

@ -277,7 +277,11 @@ class UserDefinedClassVariable(UserDefinedVariable):
obj = inspect.getattr_static(self.value, name)
except AttributeError:
if type(self.value) is type:
raise_observed_exception(AttributeError, tx)
raise_observed_exception(
AttributeError,
tx,
msg=f"type object '{self.value.__name__}' has no attribute '{name}'",
)
else:
# Cannot reason about classes with a custom metaclass
# See: test_functions::test_getattr_metaclass
@ -1364,7 +1368,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
if isinstance(result, variables.DeletedVariable):
raise_observed_exception(AttributeError, tx)
raise_observed_exception(
AttributeError,
tx,
msg=f"'{type(self.value).__name__}' object has no attribute '{name}'",
)
return result
if name == "__dict__":
@ -1636,7 +1644,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return VariableTracker.build(tx, subobj)
# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
raise_observed_exception(AttributeError, tx)
raise_observed_exception(
AttributeError,
tx,
msg=f"'{type(self.value).__name__}' object has no attribute '{name}'",
)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str