From 3fd68e4e2f5df96b942e77b58e5d2a74814530df Mon Sep 17 00:00:00 2001
From: William Wen
Date: Tue, 25 Feb 2025 16:42:22 -0800
Subject: [PATCH] [dynamo] make some more graph break messages readable in
 English [2/N] (#147385)

This is part of the goal "for some large number Z, make sure the error
messages are readable English": we begin auditing all `unimplemented` call
sites and make sure that every message is at least readable English. Hints
are not necessarily provided.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147385
Approved by: https://github.com/jansel
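For illustration, the conversion follows this shape throughout the patch (a
minimal sketch assembled from the call sites below; the exact keyword
signature of `unimplemented_v2` lives in torch/_dynamo/exc.py, which this
diff does not show):

    from torch._dynamo.exc import unimplemented, unimplemented_v2

    # Before: a terse message that is often not readable English.
    unimplemented("generator")

    # After: a structured, English-readable graph break report.
    unimplemented_v2(
        gb_type="Attempt to trace generator",  # short, stable category of the break
        context="",  # call-site details such as variable names and values
        explanation="Generators cannot be compiled directly with `torch.compile`.",
        hints=[  # optional user-facing suggestions; may be empty
            "Call a generator from inside of a non-generator Python function and "
            "compile that function instead.",
        ],
    )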
---
 test/custom_operator/test_custom_ops.py |   2 +-
 torch/_dynamo/codegen.py                |  16 +-
 torch/_dynamo/comptime.py               |   9 +-
 torch/_dynamo/convert_frame.py          |  31 ++-
 torch/_dynamo/guards.py                 |   9 +-
 torch/_dynamo/output_graph.py           |  47 +++-
 torch/_dynamo/side_effects.py           |  47 +++-
 torch/_dynamo/symbolic_convert.py       | 328 +++++++++++++++++++-----
 torch/_dynamo/utils.py                  |  79 ++++--
 9 files changed, 463 insertions(+), 105 deletions(-)

diff --git a/test/custom_operator/test_custom_ops.py b/test/custom_operator/test_custom_ops.py
index 38c7349f1390..24bc4db520a8 100644
--- a/test/custom_operator/test_custom_ops.py
+++ b/test/custom_operator/test_custom_ops.py
@@ -48,7 +48,7 @@ class TestCustomOperators(TestCase):
         with self.assertRaisesRegex(
             RuntimeError,
-            r"unsupported operator: .* you may need to `import nonexistent`",
+            r"(?s)Operator does not support running with fake tensors.*you may need to `import nonexistent`",
         ):
             f(x)

diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py
index de8b24bc7ddb..84fc0ef52634 100644
--- a/torch/_dynamo/codegen.py
+++ b/torch/_dynamo/codegen.py
@@ -36,7 +36,7 @@ from .bytecode_transformation import (
     create_rot_n,
     Instruction,
 )
-from .exc import IncorrectUsage, unimplemented
+from .exc import IncorrectUsage, unimplemented, unimplemented_v2
 from .source import AttrSource, ChainedSource, DictGetItemSource, Source
 from .utils import is_safe_constant, rot_n_helper
 from .variables.base import ValueMutationExisting, VariableTracker
@@ -335,7 +335,19 @@ class PyCodegen:
             try:
                 self.call_reconstruct(value)
             except NotImplementedError:
-                unimplemented(f"reconstruct: {value}")
+                unimplemented_v2(
+                    gb_type="Reconstruction failure",
+                    context=str(value),
+                    explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
+                    hints=[
+                        "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
+                        "that Dynamo cannot reconstruct, then remove it from the return statement.",
+                        "If this reconstruction graph break occurs while handling another graph break, then resolve the "
+                        "initial graph break.",
+                        "Report an issue to PyTorch if you need reconstruction support. Note that objects that don't have "
+                        "reconstruction rules may be fundamentally unreconstructable.",
+                    ],
+                )
             if allow_cache and value in self.tempvars:
                 self._output.append(create_dup_top())
                 self.add_cache(value)
diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py
index 5ed096c68354..9e855419af1b 100644
--- a/torch/_dynamo/comptime.py
+++ b/torch/_dynamo/comptime.py
@@ -46,7 +46,7 @@ from typing import Optional, Union
 import torch
 from torch.fx.experimental.symbolic_shapes import free_symbols

-from .exc import unimplemented
+from .exc import unimplemented_v2
 from .variables import CellVariable
 from .variables.constant import ConstantVariable
 from .variables.tensor import SymNodeVariable
@@ -192,7 +192,12 @@ class ComptimeContext:
         """
         Manually trigger a graph break
         """
-        unimplemented(msg)
+        unimplemented_v2(
+            gb_type="ComptimeContext graph break",
+            context=msg,
+            explanation=f"Manually triggered ComptimeContext graph break with message {msg}.",
+            hints=[],
+        )

     def graph(self):
         """
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 52126e2bc10f..475ad64380d9 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -104,7 +104,7 @@ from .exc import (
     SkipCodeRecursiveException,
     TorchRuntimeError,
     UncapturedHigherOrderOpError,
-    unimplemented,
+    unimplemented_v2,
     Unsupported,
 )
 from .guards import (
@@ -535,7 +535,15 @@ class ConvertFrameAssert:
             return ConvertFrameReturn()

         if is_generator(code):
-            unimplemented("generator")
+            unimplemented_v2(
+                gb_type="Attempt to trace generator",
+                context="",
+                explanation="Generators cannot be compiled directly with `torch.compile`.",
+                hints=[
+                    "Call a generator from inside of a non-generator Python function and "
+                    "compile that function instead.",
+                ],
+            )

         if not has_tensor_in_frame(frame):
             return ConvertFrameReturn()
@@ -793,7 +801,14 @@ def _compile(
                     # We now have a new "last attempt", reset the clock
                     last_attempt_start_time = time.time()
                     if attempt > 100:
-                        unimplemented("100+ RestartAnalysis() calls")
+                        unimplemented_v2(
+                            gb_type="Excessive RestartAnalysis() calls",
+                            context="",
+                            explanation="Dynamo attempted to trace the same frame 100+ times. "
+                            "Giving up on compiling as the compile time tradeoff is likely not "
+                            "worth the performance gain.",
+                            hints=[],
+                        )
         except exc.SkipFrame as e:
             if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                 TensorifyState.clear()
@@ -962,7 +977,15 @@ def _compile(
                 raise RecompileLimitExceeded(f"{limit_type} reached")
             else:
                 # do not recursively skip frames
-                unimplemented(f"{limit_type} reached")
+                unimplemented_v2(
+                    gb_type="Dynamo cache limit exceeded",
+                    context=f"Limit type: {limit_type}",
+                    explanation="Dynamo attempted to recompile the code object too many times, "
+                    f"exceeding the {limit_type} cache size limit. "
+ "Giving up on compiling as the compile time tradeoff is likely not " + "worth the performance gain.", + hints=[], + ) log.debug( "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s", diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 9e711048c628..30af985878f5 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1691,7 +1691,14 @@ class GuardBuilder(GuardBuilderBase): assert istype(val.training, bool) self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) else: - exc.unimplemented(f"Guard setup for uninitialized class {type(val)}") + exc.unimplemented_v2( + gb_type="Attempted to guard on uninitialized nn.Module", + context="", + explanation="Attempted to setup an NN_MODULE guard on unitialized " + f"nn.Module subclass `{type(val)}`. Please ensure the `nn.Module` " + "subclass instance has called `super().__init__()`.", + hints=[], + ) def FUNCTION_MATCH(self, guard: Guard): """things like torch.add and user defined functions""" diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 95d36ababd0d..2e5f1ef9213c 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -81,7 +81,7 @@ from .exc import ( BackendCompilerFailed, exceptions_allowed_to_be_fallback, SkipFrame, - unimplemented, + unimplemented_v2, unimplemented_v2_with_warning, ) from .graph_deduplication import apply_graph_deduplication @@ -486,7 +486,12 @@ class OutputGraph: def get_backward_state_proxy(self): if self.backward_state_proxy is None: if self.export: - unimplemented("backward_state does not support export") + unimplemented_v2( + gb_type="backward_state does not support export", + context="", + explanation="Compiled autograd doesn't work with `torch.export`.", + hints=[], + ) example_value = BackwardState() self.backward_state_proxy = self.root_tracer.create_graph_input( "dynamo_backward_state", @@ -994,7 +999,12 @@ class OutputGraph: log.debug("COMPILING GRAPH due to %s", reason) if not all(block.can_restore() for block in tx.block_stack): - unimplemented("compile_subgraph with block_depth != 0") + unimplemented_v2( + gb_type="Attempt to compile graph in a try block", + context="", + explanation="Dynamo cannot compile traced graphs while in a try block.", + hints=[], + ) prefix_insts: list[Instruction] = [] if sys.version_info >= (3, 11): @@ -1843,7 +1853,12 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): def encountered_non_compliant_op(target, msg): output_graph.non_compliant_ops.add(target) if config.only_allow_pt2_compliant_ops: - unimplemented(msg + " " + err_epilogue) + unimplemented_v2( + gb_type="Encountered non-PT2-compliant op", + context="", + explanation=msg + " " + err_epilogue, + hints=[], + ) if isinstance(target, torch._ops.OpOverload): if torch.Tag.pt2_compliant_tag in target.tags: @@ -1881,7 +1896,12 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): target._qualified_op_name, *args, **kwargs ) except RuntimeError as e: - unimplemented(str(e)) + unimplemented_v2( + gb_type="Error when attempting to resolve op packet", + context="", + explanation=str(e), + hints=[], + ) op = getattr(target, overload) if torch.Tag.pt2_compliant_tag in op.tags: @@ -1933,6 +1953,7 @@ class SubgraphTracer(fx.Tracer): # SubgraphTracers can be nested. 
         self.parent = parent
+        self.source_target = source_target
         # A dict mapping previously free variables (Proxy objects)
         # to new Proxy objects that wrap inputs to this subgraph.
         #
@@ -2115,7 +2136,13 @@
             ]
         elif kind == "call_module":
             if self.parent is not None:
-                unimplemented("Invoking an nn.Module inside HigherOrderOperator")
+                # TODO can remove once inline_inbuilt_nn_modules is always True
+                unimplemented_v2(
+                    gb_type="Invoking an nn.Module inside a higher order operator",
+                    context=f"Higher order op name: {self.source_target}",
+                    explanation="This is not supported.",
+                    hints=[],
+                )
             # For modules we store the class
             rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                 (
@@ -2143,8 +2170,12 @@
             ]
         elif kind == "call_module":
             if self.parent is not None:
-                unimplemented(
-                    "Invoking an nn.Module inside HigherOrderOperator"
-                )
+                # TODO can remove once inline_inbuilt_nn_modules is always True
+                unimplemented_v2(
+                    gb_type="Invoking an nn.Module inside a HigherOrderOperator",
+                    context="",
+                    explanation="This is not supported.",
+                    hints=[],
+                )
             # For modules we store the class
             rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
index c8e74e4bffdf..a2ba0ee1e91e 100644
--- a/torch/_dynamo/side_effects.py
+++ b/torch/_dynamo/side_effects.py
@@ -19,7 +19,7 @@ from .bytecode_transformation import (
     create_instruction,
 )
 from .codegen import PyCodegen
-from .exc import SideEffectsError, unimplemented
+from .exc import SideEffectsError, unimplemented_v2
 from .source import GlobalSource, LocalCellSource, LocalSource, Source
 from .utils import is_frozen_dataclass, nn_module_new, object_new
 from .variables.base import (
@@ -186,8 +186,12 @@ class SideEffects:
                 "unintended variable modifications."
             )

         if not is_side_effect_safe(item.mutation_type):
-            unimplemented(
-                "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
-            )
+            # TODO plumb HOP information here
+            unimplemented_v2(
+                gb_type="HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)",
+                context="",
+                explanation="This is not supported.",
+                hints=[],
+            )

     def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
@@ -202,12 +206,22 @@ class SideEffects:
         assert self.is_attribute_mutation(item)
         result = self.store_attr_mutations[item][name]
         if not deleted_ok and isinstance(result, variables.DeletedVariable):
-            unimplemented("read deleted attribute")
+            unimplemented_v2(
+                gb_type="Attempted to read a deleted variable",
+                context=f"item: {item}, name: {name}",
+                explanation="",
+                hints=[],
+            )
         return result

     def store_cell(self, cellvar, value):
         if cellvar.is_immutable():
-            unimplemented("Dynamo currently doesn't support writing to such cell")
+            unimplemented_v2(
+                gb_type="Write to immutable cell",
+                context=f"cellvar: {cellvar}, value: {value}",
+                explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.",
+                hints=[],
+            )
         assert isinstance(cellvar, variables.CellVariable)
         assert isinstance(value, variables.VariableTracker)
         self.store_attr(cellvar, "cell_contents", value)
@@ -218,7 +232,12 @@
             return self.load_attr(cellvar, "cell_contents", check=False)
         if cellvar.pre_existing_contents:
             return cellvar.pre_existing_contents
-        unimplemented("cannot read uninitialized cell")
+        unimplemented_v2(
+            gb_type="Read uninitialized cell",
+            context=str(cellvar),
+            explanation="Attempted to read a cell variable that has not been populated yet.",
+            hints=[],
+        )

     def load_global(self, gvar: VariableTracker, name: str):
         assert isinstance(gvar, variables.VariableTracker)
@@ -574,7 +593,12 @@ class SideEffects:
                 var.source = LocalCellSource(var.local_name)
             elif isinstance(var.mutation_type, AttributeMutationNew):
                 if isinstance(var, variables.AutogradFunctionContextVariable):
-                    unimplemented("AutogradFunctionContextVariable escaped")
+                    unimplemented_v2(
+                        gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
+                        context="",
+                        explanation="We cannot reconstruct a torch.autograd.Function's context object.",
+                        hints=[],
+                    )

                 # Reconstruct the bytecode for
                 # base_cls.__new__(user_cls, *args)
@@ -723,7 +747,14 @@ class SideEffects:
                     isinstance(var.maxlen, variables.ConstantVariable)
                    and var.maxlen.value is None
                 ):
-                    unimplemented("side effect on existing deque with limited maxlen")
+                    unimplemented_v2(
+                        gb_type="Side effect on existing deque with limited maxlen",
+                        context="",
+                        explanation="This is not supported.",
+                        hints=[
+                            "Don't use a deque with `maxlen` specified.",
+                        ],
+                    )

                 # old.extend(new), this runs last
                 cg(var.source)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index fdb1ee04c860..a140afa8a226 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -74,13 +74,7 @@ from .bytecode_transformation import (
 )
 from .code_context import code_context
 from .codegen import PyCodegen
-from .exc import (
-    ArgsMismatchError,
-    BackendCompilerFailed,
-    unimplemented,
-    unimplemented_v2,
-    Unsupported,
-)
+from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented_v2, Unsupported
 from .funcname_cache import get_funcname
 from .guards import GuardBuilder, install_guard
 from .output_graph import GraphCompileReason, OutputGraph
@@ -538,9 +532,16 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
 def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
     def jump_graph_break(self, inst, value, extra_msg=""):
-        log_graph_break(self.code_options, reason="Data-dependent jump")
+        log_graph_break(self.code_options, reason="Data-dependent branching")
         if not self.should_compile_partial_graph():
-            unimplemented("should_compile_partial_graph=False")
+            unimplemented_v2(
+                gb_type="Should not compile partial graph (data-dependent branching)",
+                context="",
+                explanation="Dynamo has determined when encountering data-dependent "
+                "branching (e.g. `if my_tensor.item() > 0:`) that it should not "
+                "compile the partial graph.",
+                hints=[],
+            )
         # compile a partial subgraph prefix then jump into user code
         if self.maybe_has_backedge():
             msg = (
@@ -610,8 +611,13 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):

                 result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
                 if not result:
-                    unimplemented(
-                        "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
+                    unimplemented_v2(
+                        gb_type="Assertion failed on symbolic shapes",
+                        context=str(sym_expr),
+                        explanation="",
+                        hints=[
+                            "Did you make sure your code works without compile?",
+                        ],
                     )
                 self.jump(inst)
                 return
@@ -683,8 +689,12 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
                     self.push(value)
                 self.jump(inst)
             else:
-                unimplemented(
-                    "generic_jump on UserDefined with __bool__ returning non-constant"
+                unimplemented_v2(
+                    gb_type="Data-dependent branching with non-constant __bool__",
+                    context=f"method: {x}, result: {result}",
+                    explanation="Attempted to perform data-dependent branching on a user-defined "
+                    "object with a __bool__ method that did not return a constant.",
+                    hints=[],
                 )
         # __bool__ or __len__ is non-function or not existed in the user defined object
         else:
@@ -757,7 +767,12 @@ def break_graph_if_unsupported(*, push):
                 # We don't support graph break under GenericContextWrappingVariable,
                 # If there is, we roll back to the checkpoint and fall back.
                 excp.remove_from_stats()
-                unimplemented("Graph break under GenericContextWrappingVariable")
+                unimplemented_v2(
+                    gb_type="Graph break under GenericContextWrappingVariable",
+                    context="",
+                    explanation="Attempted to graph break in an active context manager that doesn't support graph breaking.",
+                    hints=[],
+                )

             if isinstance(excp, exc.UncapturedHigherOrderOpError):
                 raise
@@ -866,7 +881,12 @@ class BytecodeDistpatchTableMeta(type):
         super().__init__(name, bases, dct)

         def _missing(opname, *args):
-            unimplemented(f"missing: {opname}")
+            unimplemented_v2(
+                gb_type="Missing bytecode handler",
+                context=opname,
+                explanation=f"Dynamo does not know how to handle the bytecode instruction {opname}.",
+                hints=[],
+            )

         dispatch_table = {
             op: getattr(cls, opname, functools.partial(_missing, opname))
@@ -1219,9 +1239,21 @@ class InstructionTranslatorBase(
                 new_name = name.replace(".", "implicit")
                 self.push(self.symbolic_locals[new_name])
             except KeyError:
-                unimplemented("undefined LOAD_FAST (implicit)")
+                unimplemented_v2(
+                    gb_type="Attempted to read undefined local variable (implicit)",
+                    context=f"LOAD_FAST {name}",
+                    explanation=f"Could not find an implicit local variable with name `{name}`.",
+                    hints=[
+                        "This can happen in dict/list comprehensions.",
+                    ],
+                )
         else:
-            unimplemented("undefined LOAD_FAST")
+            unimplemented_v2(
+                gb_type="Attempted to read undefined local variable",
+                context=f"LOAD_FAST {name}",
+                explanation=f"Could not find a local variable with name `{name}`.",
+                hints=[],
+            )

         # for continuation functions
         if name.startswith("___stack"):
@@ -1315,7 +1347,12 @@ class InstructionTranslatorBase(
                 source, self.symbolic_globals[name]
             )
         if isinstance(value, RemovableHandleVariable):
-            unimplemented("Storing handles in globals - NYI")
+            unimplemented_v2(
+                gb_type="Storing Tensor hook handle in globals",
+                context=name,
+                explanation="This is not supported.",
+                hints=[],
+            )
         self.output.side_effects.store_global(variable, name, value)

     # Cache note: This cache only exists for the duration of this
@@ -1402,7 +1439,12 @@ class InstructionTranslatorBase(
                 globals=self.f_globals,
             )
         except ImportError:
-            unimplemented("import a module that does not exist")
+            unimplemented_v2(
+                gb_type="Import failure",
+                context=f"module_name: {module_name}, fromlist: {fromlist}, level: {level}",
+                explanation="Failure when attempting to import.",
+                hints=[],
+            )

         if level != 0:
             pkg = self.calc_package()
@@ -1425,7 +1467,12 @@ class InstructionTranslatorBase(
         if istype(value, (types.ModuleType, DummyModule)):
             self.push(PythonModuleVariable(value, source=source))
         else:
-            unimplemented(f"IMPORT_NAME {typestr(value)}")
+            unimplemented_v2(
+                gb_type="Bad import result",
+                context=typestr(value),
+                explanation="Import result is not a Python module.",
+                hints=[],
+            )

     def IMPORT_FROM(self, inst):
         self.DUP_TOP(inst)
@@ -1566,13 +1613,24 @@ class InstructionTranslatorBase(
             if observed_exception_type := exc.observed_exception_map.get(val.exc_type):
                 raise observed_exception_type(f"raised exception {val}")
             raise exc.ObservedException(f"raised exception {val}")
-        unimplemented(f"raise {exc}")
+        unimplemented_v2(
+            gb_type="Failed to raise exception",
+            context=str(exc),
+            explanation="Attempted to raise a non-Exception type/value.",
+            hints=[],
+        )

     def RAISE_VARARGS(self, inst):
         if inst.arg == 0:
             # duplicate the top of the stack and re-raise it
             if sys.version_info < (3, 11):
-                unimplemented("re-raise")
+                unimplemented_v2(
+                    gb_type="Re-raise with no arguments",
+                    context="",
+                    explanation="Dynamo does not support re-raising the previous exception "
+                    "in Python < 3.11 (empty `raise`).",
+                    hints=[],
+                )
             assert isinstance(self.stack[-1], ExceptionVariable)
             self.stack.append(self.stack[-1])
             self._raise_exception_variable(inst)
@@ -1584,14 +1642,24 @@ class InstructionTranslatorBase(
             from_vt = self.pop()
             if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
                 self._raise_exception_variable(inst)
-            unimplemented("raise ... from ...")
+            unimplemented_v2(
+                gb_type="Re-raise with 2 arguments",
+                context=str(from_vt),
+                explanation="Dynamo does not support `raise ... from [not-None]`.",
+                hints=[],
+            )

     def CLEANUP_THROW(self, inst):
         # https://github.com/python/cpython/pull/96010
         tos = self.stack[-1]
         assert isinstance(tos, ExceptionVariable)
         if tos.exc_type is StopIteration:
-            unimplemented("CLEANUP_THROW with StopIteration")
+            unimplemented_v2(
+                gb_type="CLEANUP_THROW with StopIteration",
+                context="",
+                explanation="Received StopIteration when handling generator.throw/close. This is not supported.",
+                hints=[],
+            )
         else:
             self.RERAISE(inst)

@@ -1599,7 +1667,13 @@
         if sys.version_info >= (3, 11):
             # RERAISE is currently supported in a narrow case of `raise ... from None`
             self._raise_exception_variable(inst)
-        unimplemented("RERAISE")
+        unimplemented_v2(
+            gb_type="RERAISE in Python < 3.11",
+            context="",
+            explanation="RERAISE bytecode (https://docs.python.org/3.10/library/dis.html#opcode-RERAISE) "
+            "is not supported in Python < 3.11. This bytecode is generated by try/with blocks.",
+            hints=[],
+        )

     def WITH_EXCEPT_START(self, inst):
         if sys.version_info >= (3, 11):
@@ -1679,10 +1753,13 @@ class InstructionTranslatorBase(

             block_stack_entry = self.block_stack.pop()
             if block_stack_entry.inst.opname != "SETUP_FINALLY":
-                unimplemented(
-                    "exception is raised when top of the block stack "
+                unimplemented_v2(
+                    gb_type="Exception raised with invalid exception handler",
+                    context="",
+                    explanation="Exception raised when top of the block stack "
                     "is not exception handler (e.g. try .. with .. except). "
" - f"Current TOS is {block_stack_entry.inst}" + f"Current TOS is {block_stack_entry.inst}", + hints=[], ) # 1) pop values from the stack until it matches the stack depth @@ -1779,14 +1856,20 @@ class InstructionTranslatorBase( # 2) except (NotImplemetedError, AttributeError) -> TupleVariable if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)): - unimplemented( - f"except has an unsupported types of objects {expected_exc_types}" + unimplemented_v2( + gb_type="Exception with bad expected type", + context=str(expected_exc_types), + explanation=f"`except ...` has unsupported type {expected_exc_types}.", + hints=[], ) if sys.version_info >= (3, 11): if not isinstance(exc_instance, variables.ExceptionVariable): - unimplemented( - f"except expects to recieve an object of exception type but received {exc_instance}" + unimplemented_v2( + gb_type="Caught non-Exception value", + context=str(exc_instance), + explanation=f"Except expects to recieve an object of Exception type but received {exc_instance}.", + hints=[], ) if isinstance(expected_exc_types, TupleVariable): @@ -1798,8 +1881,11 @@ class InstructionTranslatorBase( for expected_type in expected_types: if not isinstance(expected_type, BuiltinVariable): - unimplemented( - f"except has an unsupported types of object {expected_type}" + unimplemented_v2( + gb_type="Exception with non-type expectation", + context=str(expected_type), + explanation=f"`except ...` expects a non-type: {expected_type}.", + hints=[], ) if isinstance(exc_instance, variables.ExceptionVariable) and issubclass( exc_instance.exc_type, expected_type.fn @@ -1844,7 +1930,12 @@ class InstructionTranslatorBase( kwargsvars = self.pop() argsvars = self.pop() else: - unimplemented("CALL_FUNCTION_EX") + unimplemented_v2( + gb_type="Variadic function call with bad flags", + context=f"flags: {inst.argval}", + explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}", + hints=[], + ) if sys.version_info >= (3, 13): # 3.13 swapped null and callable @@ -1879,7 +1970,12 @@ class InstructionTranslatorBase( # args, aot_autograd/inductor while lowering generates # aten.random.from, again causing syntax errors. Since this # usecase is uncommon, graph break. - unimplemented("random_ op is called with from keyword") + unimplemented_v2( + gb_type="Tensor.random_ op called with `from` keyword", + context="", + explanation="This is not supported.", + hints=[], + ) elif ( fn.name == "uniform_" and isinstance(argsvars, TupleVariable) @@ -1892,7 +1988,12 @@ class InstructionTranslatorBase( # args, aot_autograd/inductor while lowering generates # aten.uniform.from, again causing syntax errors. Since this # usecase is uncommon, graph break. 
- unimplemented("uniform_ op is called with from keyword") + unimplemented_v2( + gb_type="Tensor.uniform_ op called with `from` keyword", + context="", + explanation="This is not supported.", + hints=[], + ) if not isinstance( argsvars, BaseListVariable @@ -1906,7 +2007,12 @@ class InstructionTranslatorBase( if not isinstance(argsvars, BaseListVariable) or not isinstance( kwargsvars, ConstDictVariable ): - unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}") + unimplemented_v2( + gb_type="Variadic function call with bad args/kwargs type", + context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}", + explanation="Expected args to be a list and kwargs to be a dict", + hints=[], + ) # Map to a dictionary of str -> VariableTracker kwargsvars = kwargsvars.keys_as_python_constant() @@ -2005,7 +2111,13 @@ class InstructionTranslatorBase( def store_attr_graph_break(self, inst): log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break") if not self.should_compile_partial_graph(): - unimplemented("should_compile_partial_graph=False") + unimplemented_v2( + gb_type="Should not compile partial graph (STORE_ATTR)", + context="", + explanation="Dynamo has determined when encountering an unsupported " + "STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.", + hints=[], + ) self.output.compile_subgraph( self, reason=GraphCompileReason("store_attr", [self.frame_summary()]) ) @@ -2054,7 +2166,12 @@ class InstructionTranslatorBase( def BUILD_SET(self, inst): if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: - unimplemented("missing: BUILD_SET") + unimplemented_v2( + gb_type="missing BUILD_SET handler", + context="", + explanation="Missing BUILD_SET bytecode handler (for testing purposes).", + hints=[], + ) items = self.popn(inst.argval) new_set = SetVariable(items, mutation_type=ValueMutationNew()) self.push(new_set) @@ -2066,7 +2183,13 @@ class InstructionTranslatorBase( try: items.extend(seq.force_unpack_var_sequence(self)) except NotImplementedError: - unimplemented(f"BUILD_LIST_UNPACK {seq}") + unimplemented_v2( + gb_type="Failed to unpack object for BUILD_LIST_UNPACK", + context=str(seq), + explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK " + "bytecode (`[*x, *y, ...]`).", + hints=[], + ) self.push(cls(items, mutation_type=ValueMutationNew())) def BUILD_TUPLE_UNPACK(self, inst): @@ -2193,9 +2316,21 @@ class InstructionTranslatorBase( elif seq.has_force_unpack_var_sequence(self): val = seq.force_unpack_var_sequence(self) else: - unimplemented(f"UNPACK_SEQUENCE {seq}") + unimplemented_v2( + gb_type="Failed to unpack object for UNPACK_SEQUENCE", + context=str(seq), + explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode " + "(i.e. `a, b, c = d`).", + hints=[], + ) if len(val) != inst.argval: - unimplemented("UNPACK_SEQUENCE length mismatch") + unimplemented_v2( + gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE", + context=f"expected length: {inst.argval}, actual: {len(val)}", + explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode " + "(i.e. 
+                hints=[],
+            )
         for i in reversed(val):
             self.push(i)

@@ -2216,7 +2351,13 @@ class InstructionTranslatorBase(
             for item in reversed(vals_prefix):
                 self.push(item)
         else:
-            unimplemented(f"UNPACK_EX {seq}")
+            unimplemented_v2(
+                gb_type="Failed to unpack object for UNPACK_EX",
+                context=str(seq),
+                explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode "
+                "(i.e. `a, *b, c = d`).",
+                hints=[],
+            )

     def NOP(self, inst):
         pass
@@ -2311,12 +2452,20 @@ class InstructionTranslatorBase(
                 format_string_parts.append(part.format_string)
                 args.extend(part.sym_args)
                 if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
-                    unimplemented(
-                        f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}"
+                    unimplemented_v2(
+                        gb_type="BUILD_STRING key conflict",
+                        context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}",
+                        explanation="Failed to build format string due to key conflict.",
+                        hints=[],
                     )
                 kwargs.update(part.sym_kwargs)
             else:
-                unimplemented(f"BUILD_STRING {part}")
+                unimplemented_v2(
+                    gb_type="BUILD_STRING type error",
+                    context=str(part),
+                    explanation="Format string part type is not correct - expected constant or format string.",
+                    hints=[],
+                )
         self.push(
             variables.StringFormatVariable.create(
                 "".join(format_string_parts), args, kwargs
             )
         )
@@ -2638,7 +2787,12 @@ class InstructionTranslatorBase(

     def LOAD_FAST_CHECK(self, inst):
         if isinstance(self.symbolic_locals[inst.argval], NullVariable):
-            unimplemented("LOAD_FAST_CHECK on uninitialized variable")
+            unimplemented_v2(
+                gb_type="LOAD_FAST_CHECK on uninitialized variable",
+                context=inst.argval,
+                explanation=f"Attempted to load uninitialized local variable {inst.argval}.",
+                hints=[],
+            )
         self.LOAD_FAST(inst)

     def LOAD_FAST_AND_CLEAR(self, inst):
@@ -2666,7 +2820,12 @@ class InstructionTranslatorBase(
             # INTRINSIC_LIST_TO_TUPLE
             self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
         else:
-            unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
+            unimplemented_v2(
+                gb_type="Missing CALL_INTRINSIC_1 handler",
+                context=f"CALL_INTRINSIC_1 operand: {inst.argval}",
+                explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.",
+                hints=[],
+            )

     def END_SEND(self, inst):
         tos = self.pop()
@@ -3052,7 +3211,12 @@ class InstructionTranslator(InstructionTranslatorBase):
             # if it reaches here, it means Dynamo failed to inline a functorch function
             f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
         )
-        unimplemented(msg)
+        unimplemented_v2(
+            gb_type="Unsupported functorch tracing attempt",
+            context="",
+            explanation=msg,
+            hints=[],
+        )

    def get_example_value(self, source: Source):
        if isinstance(source, LocalSource):
@@ -3288,7 +3452,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
     @staticmethod
     def check_inlineable(func):
         if func.has_self():
-            unimplemented("inline with __self__")
+            unimplemented_v2(
+                gb_type="Inline attempt with __self__",
+                context=str(func),
+                explanation="Attempted to inline a function with the `__self__` attribute. "
" + "Dynamo is expected to decompose method calls into function calls with a `self` argument.", + hints=[], + ) result = trace_rules.check_verbose(func, is_inlined_call=True) if result.skipped: @@ -3346,7 +3516,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase): kwargs, ): if isinstance(func, SkipFunctionVariable): - unimplemented("inline with functions in skip files") + unimplemented_v2( + gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)", + context=f"Attempted to inline a SkipFunctionVariable {func}", + explanation="Attempted to inline a function that was previously determined to be marked as intentionally skipped.", + hints=[], + ) assert isinstance( func, ( @@ -3373,13 +3548,23 @@ class InliningInstructionTranslator(InstructionTranslatorBase): for v in itertools.chain(sub_locals.values()): if not isinstance(v, VariableTracker): - unimplemented(f"unconverted arg {v}") + unimplemented_v2( + gb_type="Encountered unconverted argument when attempting to inline", + context=f"func: {func}, arg: {v}", + explanation="An argument to an inlined function was not successfully converted to a VariableTracker.", + hints=[], + ) code: types.CodeType = func.get_code() if code.co_name in ("__setitem__", "__setattr__") and not ( args and isinstance(args[0], variables.UserDefinedObjectVariable) ): - unimplemented(f"inline {code.co_name}") + unimplemented_v2( + gb_type="Unsupported __setitem__/__setattr__ inline attempt", + context=f"code name: {code.co_name}, args: {args}", + explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.", + hints=[], + ) suffix = "" # TODO: mlazos, add support for enabling multiple artifact logs @@ -3541,7 +3726,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase): return False # inlining functions is all-or-nothing def create_call_resume_at(self, offset): - unimplemented("cant resume while inlining") + unimplemented_v2( + gb_type="Graph break in inlined function", + context="", + explanation="Graph breaks in an inlined call are not supported.", + hints=[], + ) def RETURN_VALUE(self, inst): self.symbolic_result = self.pop() # type: ignore[assignment] @@ -3600,7 +3790,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase): else: value = self.pop() if isinstance(value, RemovableHandleVariable): - unimplemented("Storing handles in globals - NYI") + unimplemented_v2( + gb_type="Storing Tensor hook handle in globals (inline call)", + context=inst.argval, + explanation="This is not supported.", + hints=[], + ) name = inst.argval _fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name) self.output.side_effects.store_attr(fglobals_vt, name, value) @@ -3658,7 +3853,12 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): # lifted the `unimplemented("generator")` in frame conversion. This codepath handles # subgenerator and lines up with this line in Python 3.10 # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599 - unimplemented("Unreachable sub-generator code") + unimplemented_v2( + gb_type="Unreachable sub-generator code", + context="", + explanation="Should only be encountered while implementing generator support.", + hints=[], + ) try: val = tos.next_variable(self) @@ -3711,6 +3911,16 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): # lifted the `unimplemented("generator")` in frame conversion. 
                 # subgenerator and lines up with this line in Python 3.11
                 # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
-                unimplemented("Unreachable sub-generator code")
+                unimplemented_v2(
+                    gb_type="Unreachable sub-generator code",
+                    context="",
+                    explanation="Should only be encountered while implementing generator support.",
+                    hints=[],
+                )
         else:
-            unimplemented(f"SEND {typestr(tos)}")
+            unimplemented_v2(
+                gb_type="SEND with bad type",
+                context=f"TOS type: {typestr(tos)}",
+                explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.",
+                hints=[],
+            )
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 6a248cb849b2..69f79a7b4ee7 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1109,11 +1109,14 @@ def proxy_args_kwargs(args, kwargs):
         proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
         return proxy_args, proxy_kwargs
     except NotImplementedError as e:
-        from .exc import unimplemented
+        from .exc import unimplemented_v2
         from .variables.base import typestr

-        unimplemented(
-            f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
+        unimplemented_v2(
+            gb_type="Failed to convert args/kwargs to proxy",
+            context=f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
+            explanation="Missing `as_proxy()` implementation for some arg/kwarg.",
+            hints=[],
             from_exc=e,
         )

@@ -2464,9 +2467,14 @@ def set_example_value(node, example_value):

 def _get_fake_tensor(vt):
     fake_tensor = vt.as_proxy().node.meta.get("example_value")
     if not is_fake(fake_tensor):
-        from .exc import unimplemented
+        from .exc import unimplemented_v2

-        unimplemented("Cannot check Tensor object identity without its fake value")
+        unimplemented_v2(
+            gb_type="Cannot check Tensor object identity without its fake value",
+            context=str(fake_tensor),
+            explanation="TensorVariable is missing a fake example_value.",
+            hints=[],
+        )
     return fake_tensor

@@ -2578,11 +2586,17 @@ def wrap_fake_exception(fn):
     try:
         return fn()
     except UnsupportedFakeTensorException as e:
-        from .exc import unimplemented
+        from .exc import unimplemented_v2

-        msg = f"Unsupported: {e.reason} with fake tensor propagation."
+        msg = f"Encountered exception ({e.reason}) during fake tensor propagation."
         log.warning(msg)
-        unimplemented(msg, from_exc=e)
+        unimplemented_v2(
+            gb_type="Fake tensor propagation exception",
+            context=str(e.reason),
+            explanation=msg,
+            hints=[],
+            from_exc=e,
+        )

 def deepcopy_to_fake_tensor(obj, fake_mode):
@@ -2956,9 +2970,14 @@ def extract_fake_example_value(node, required=True):
     if "example_value" in node.meta and is_fake(node.meta["example_value"]):
         return node.meta["example_value"]
     elif required:
-        from torch._dynamo.exc import unimplemented
+        from torch._dynamo.exc import unimplemented_v2

-        unimplemented("`FakeTensor` example value was required but not available")
+        unimplemented_v2(
+            gb_type="Missing FakeTensor example value",
+            context=str(node),
+            explanation=f"`FakeTensor` example value was required for {node} but not available.",
+            hints=[],
+        )
     else:
         return None

@@ -3002,7 +3021,6 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):

     from .exc import (
         TorchRuntimeError,
-        unimplemented,
         unimplemented_v2,
         Unsupported,
         UserError,
@@ -3124,10 +3142,15 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
                     f"module `{module}` and you may need to `import {module}`"
                     f"({ctx}), otherwise "
                 )
-            unimplemented(
-                f"unsupported operator: {cause.func} ({import_suggestion}see "
-                "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
-                " for how to fix)"
+            unimplemented_v2(
+                gb_type="Operator does not support running with fake tensors",
+                context=f"unsupported operator: {cause.func}",
+                explanation="",
+                hints=[
+                    f"{import_suggestion}see "
+                    "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
+                    " for how to fix",
+                ],
             )
         elif isinstance(
             cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
@@ -3140,7 +3163,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
         elif isinstance(cause, ValueRangeError):
             raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
         elif isinstance(cause, TypeError) and "argument" in str(cause):
-            unimplemented(f"TypeError {node.target}: {cause}")
+            unimplemented_v2(
+                gb_type="TypeError when making fake tensor call",
+                context=f"TypeError {node.target}: {cause}",
+                explanation="",
+                hints=[],
+            )

         raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None

@@ -3197,9 +3225,14 @@ def run_node(tracer, node, args, kwargs, nnmodule):
                 return node.target(*args, **kwargs)
             elif op == "call_method":
                 if not hasattr(args[0], node.target):
-                    from .exc import unimplemented
+                    from .exc import unimplemented_v2

-                    unimplemented(make_error_message("attribute not defined"))
+                    unimplemented_v2(
+                        gb_type="Missing attribute when running call_method node",
+                        context="",
+                        explanation=make_error_message("attribute not defined"),
+                        hints=[],
+                    )
                 return getattr(args[0], node.target)(*args[1:], **kwargs)
             elif op == "call_module":
                 assert nnmodule is not None
@@ -3212,9 +3245,15 @@ def run_node(tracer, node, args, kwargs, nnmodule):

         except (NotImplementedError, UnsupportedFakeTensorException) as e:
             # NB: mimic how wrap_fake_exception does it
-            from .exc import unimplemented
+            from .exc import unimplemented_v2

-            unimplemented(make_error_message(e), from_exc=e)
+            unimplemented_v2(
+                gb_type="NotImplementedError/UnsupportedFakeTensorException when running node",
+                context=str(e),
+                explanation=make_error_message(e),
+                hints=[],
+                from_exc=e,
+            )
         except Unsupported:
             raise
         except Exception as e: