[dynamo] make some more graph break messages readable in English [2/N] (#147385)

This is for "for some large number Z, make sure the error messages are readable English." - beginning to audit all `unimplemented` sites and making sure that all messages are at least English-readable. Hints may not necessarily be provided.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147385
Approved by: https://github.com/jansel
This commit is contained in:
William Wen
2025-02-25 16:42:22 -08:00
committed by PyTorch MergeBot
parent 7a06bfdd1c
commit 3fd68e4e2f
9 changed files with 463 additions and 105 deletions

View File

@ -48,7 +48,7 @@ class TestCustomOperators(TestCase):
with self.assertRaisesRegex(
RuntimeError,
r"unsupported operator: .* you may need to `import nonexistent`",
r"(?s)Operator does not support running with fake tensors.*you may need to `import nonexistent`",
):
f(x)

View File

@ -36,7 +36,7 @@ from .bytecode_transformation import (
create_rot_n,
Instruction,
)
from .exc import IncorrectUsage, unimplemented
from .exc import IncorrectUsage, unimplemented, unimplemented_v2
from .source import AttrSource, ChainedSource, DictGetItemSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import ValueMutationExisting, VariableTracker
@ -335,7 +335,19 @@ class PyCodegen:
try:
self.call_reconstruct(value)
except NotImplementedError:
unimplemented(f"reconstruct: {value}")
unimplemented_v2(
gb_type="Reconstruction failure",
context=str(value),
explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
hints=[
"If Dynamo attempting to trace a return statement and your code is attempting to return a variable "
"that Dynamo cannot reconstruct, then remove it from the return statement.",
"If this reconstruction graph break occurs while handling another graph break, then resolve the "
"initial graph break.",
"Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have"
"reconstruction rules may be fundamentally unreconstructable.",
],
)
if allow_cache and value in self.tempvars:
self._output.append(create_dup_top())
self.add_cache(value)

View File

@ -46,7 +46,7 @@ from typing import Optional, Union
import torch
from torch.fx.experimental.symbolic_shapes import free_symbols
from .exc import unimplemented
from .exc import unimplemented_v2
from .variables import CellVariable
from .variables.constant import ConstantVariable
from .variables.tensor import SymNodeVariable
@ -192,7 +192,12 @@ class ComptimeContext:
"""
Manually trigger a graph break
"""
unimplemented(msg)
unimplemented_v2(
gb_type="ComptimeContext graph break",
context=msg,
explanation=f"Manually triggered ComptimeContext graph break with message {msg}.",
hints=[],
)
def graph(self):
"""

View File

@ -104,7 +104,7 @@ from .exc import (
SkipCodeRecursiveException,
TorchRuntimeError,
UncapturedHigherOrderOpError,
unimplemented,
unimplemented_v2,
Unsupported,
)
from .guards import (
@ -535,7 +535,15 @@ class ConvertFrameAssert:
return ConvertFrameReturn()
if is_generator(code):
unimplemented("generator")
unimplemented_v2(
gb_type="Attempt to trace generator",
context="",
explanation="Generators cannot be compiled directly with `torch.compile`.",
hints=[
"Call a generator from inside of a non-generator Python function and "
"compile that function instead.",
],
)
if not has_tensor_in_frame(frame):
return ConvertFrameReturn()
@ -793,7 +801,14 @@ def _compile(
# We now have a new "last attempt", reset the clock
last_attempt_start_time = time.time()
if attempt > 100:
unimplemented("100+ RestartAnalysis() calls")
unimplemented_v2(
gb_type="Excessive RestartAnalysis() calls",
context="",
explanation="Dynamo attempted to trace the same frame 100+ times. "
"Giving up on compiling as the compile time tradeoff is likely not "
"worth the performance gain.",
hints=[],
)
except exc.SkipFrame as e:
if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
TensorifyState.clear()
@ -962,7 +977,15 @@ def _compile(
raise RecompileLimitExceeded(f"{limit_type} reached")
else:
# do not recursively skip frames
unimplemented(f"{limit_type} reached")
unimplemented_v2(
gb_type="Dynamo cache limit exceeded",
context=f"Limit type: {limit_type}",
explanation="Dynamo attempted to recompile the code object too many times, "
f"exceeding the {limit_type} cache size limit."
"Giving up on compiling as the compile time tradeoff is likely not "
"worth the performance gain.",
hints=[],
)
log.debug(
"torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",

View File

@ -1691,7 +1691,14 @@ class GuardBuilder(GuardBuilderBase):
assert istype(val.training, bool)
self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
else:
exc.unimplemented(f"Guard setup for uninitialized class {type(val)}")
exc.unimplemented_v2(
gb_type="Attempted to guard on uninitialized nn.Module",
context="",
explanation="Attempted to setup an NN_MODULE guard on unitialized "
f"nn.Module subclass `{type(val)}`. Please ensure the `nn.Module` "
"subclass instance has called `super().__init__()`.",
hints=[],
)
def FUNCTION_MATCH(self, guard: Guard):
"""things like torch.add and user defined functions"""

View File

@ -81,7 +81,7 @@ from .exc import (
BackendCompilerFailed,
exceptions_allowed_to_be_fallback,
SkipFrame,
unimplemented,
unimplemented_v2,
unimplemented_v2_with_warning,
)
from .graph_deduplication import apply_graph_deduplication
@ -486,7 +486,12 @@ class OutputGraph:
def get_backward_state_proxy(self):
if self.backward_state_proxy is None:
if self.export:
unimplemented("backward_state does not support export")
unimplemented_v2(
gb_type="backward_state does not support export",
context="",
explanation="Compiled autograd doesn't work with `torch.export`.",
hints=[],
)
example_value = BackwardState()
self.backward_state_proxy = self.root_tracer.create_graph_input(
"dynamo_backward_state",
@ -994,7 +999,12 @@ class OutputGraph:
log.debug("COMPILING GRAPH due to %s", reason)
if not all(block.can_restore() for block in tx.block_stack):
unimplemented("compile_subgraph with block_depth != 0")
unimplemented_v2(
gb_type="Attempt to compile graph in a try block",
context="",
explanation="Dynamo cannot compile traced graphs while in a try block.",
hints=[],
)
prefix_insts: list[Instruction] = []
if sys.version_info >= (3, 11):
@ -1843,7 +1853,12 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
def encountered_non_compliant_op(target, msg):
output_graph.non_compliant_ops.add(target)
if config.only_allow_pt2_compliant_ops:
unimplemented(msg + " " + err_epilogue)
unimplemented_v2(
gb_type="Encountered non-PT2-compliant op",
context="",
explanation=msg + " " + err_epilogue,
hints=[],
)
if isinstance(target, torch._ops.OpOverload):
if torch.Tag.pt2_compliant_tag in target.tags:
@ -1881,7 +1896,12 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
target._qualified_op_name, *args, **kwargs
)
except RuntimeError as e:
unimplemented(str(e))
unimplemented_v2(
gb_type="Error when attempting to resolve op packet",
context="",
explanation=str(e),
hints=[],
)
op = getattr(target, overload)
if torch.Tag.pt2_compliant_tag in op.tags:
@ -1933,6 +1953,7 @@ class SubgraphTracer(fx.Tracer):
# SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
self.parent = parent
self.source_target = source_target
# A dict mapping previously free variables (Proxy objects)
# to new Proxy objects that wrap inputs to this subgraph.
#
@ -2115,7 +2136,13 @@ class SubgraphTracer(fx.Tracer):
]
elif kind == "call_module":
if self.parent is not None:
unimplemented("Invoking an nn.Module inside HigherOrderOperator")
# TODO can remove once inline_inbuilt_nn_modules is always True
unimplemented_v2(
gb_type="Invoking an nn.Module inside a higher order operator",
context=f"Higher order op name: {self.source_target}",
explanation="This is not supported.",
hints=[],
)
# For modules we store the class
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
(
@ -2143,8 +2170,12 @@ class SubgraphTracer(fx.Tracer):
]
elif kind == "call_module":
if self.parent is not None:
unimplemented(
"Invoking an nn.Module inside HigherOrderOperator"
# TODO can remove once inline_inbuilt_nn_modules is always True
unimplemented_v2(
gb_type="Invoking an nn.Module inside a HigherOrderOperator",
context="",
explanation="This is not supported.",
hints=[],
)
# For modules we store the class
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [

View File

@ -19,7 +19,7 @@ from .bytecode_transformation import (
create_instruction,
)
from .codegen import PyCodegen
from .exc import SideEffectsError, unimplemented
from .exc import SideEffectsError, unimplemented_v2
from .source import GlobalSource, LocalCellSource, LocalSource, Source
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
@ -186,8 +186,12 @@ class SideEffects:
"unintended variable modifications."
)
if not is_side_effect_safe(item.mutation_type):
unimplemented(
"HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
# TODO plumb HOP information here
unimplemented_v2(
gb_type="HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)",
context="",
explanation="This is not supported.",
hints=[],
)
def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
@ -202,12 +206,22 @@ class SideEffects:
assert self.is_attribute_mutation(item)
result = self.store_attr_mutations[item][name]
if not deleted_ok and isinstance(result, variables.DeletedVariable):
unimplemented("read deleted attribute")
unimplemented_v2(
gb_type="Attempted to read a deleted variable",
context=f"item: {item}, name: {name}",
explanation="",
hints=[],
)
return result
def store_cell(self, cellvar, value):
if cellvar.is_immutable():
unimplemented("Dynamo currently doesn't support writing to such cell")
unimplemented_v2(
gb_type="Write to immutable cell",
context=f"cellvar: {cellvar}, value: {value}",
explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.",
hints=[],
)
assert isinstance(cellvar, variables.CellVariable)
assert isinstance(value, variables.VariableTracker)
self.store_attr(cellvar, "cell_contents", value)
@ -218,7 +232,12 @@ class SideEffects:
return self.load_attr(cellvar, "cell_contents", check=False)
if cellvar.pre_existing_contents:
return cellvar.pre_existing_contents
unimplemented("cannot read uninitialized cell")
unimplemented_v2(
gb_type="Read uninitialized cell",
context=str(cellvar),
explanation="Attempted to read a cell variable that has not been populated yet.",
hints=[],
)
def load_global(self, gvar: VariableTracker, name: str):
assert isinstance(gvar, variables.VariableTracker)
@ -574,7 +593,12 @@ class SideEffects:
var.source = LocalCellSource(var.local_name)
elif isinstance(var.mutation_type, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
unimplemented_v2(
gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
context="",
explanation="We cannot reconstruct a torch.autograd.Function's context object.",
hints=[],
)
# Reconstruct the bytecode for
# base_cls.__new__(user_cls, *args)
@ -723,7 +747,14 @@ class SideEffects:
isinstance(var.maxlen, variables.ConstantVariable)
and var.maxlen.value is None
):
unimplemented("side effect on existing deque with limited maxlen")
unimplemented_v2(
gb_type="Side effect on existing deque with limited maxlen",
context="",
explanation="This is not supported.",
hints=[
"Don't use a deque with `maxlen` specified.",
],
)
# old.extend(new), this runs last
cg(var.source)

View File

@ -74,13 +74,7 @@ from .bytecode_transformation import (
)
from .code_context import code_context
from .codegen import PyCodegen
from .exc import (
ArgsMismatchError,
BackendCompilerFailed,
unimplemented,
unimplemented_v2,
Unsupported,
)
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented_v2, Unsupported
from .funcname_cache import get_funcname
from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph
@ -538,9 +532,16 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
def jump_graph_break(self, inst, value, extra_msg=""):
log_graph_break(self.code_options, reason="Data-dependent jump")
log_graph_break(self.code_options, reason="Data-dependent branching")
if not self.should_compile_partial_graph():
unimplemented("should_compile_partial_graph=False")
unimplemented_v2(
gb_type="Should not compile partial graph (data-dependent branching)",
context="",
explanation="Dynamo has determined when encountering data-dependent "
"branching (e.g. `if my_tensor.item() > 0:`) that it should not "
"compile the partial graph.",
hints=[],
)
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = (
@ -610,8 +611,13 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
if not result:
unimplemented(
"Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
unimplemented_v2(
gb_type="Assertion failed on symbolic shapes",
context=str(sym_expr),
explanation="",
hints=[
"Did you make sure your code works without compile?",
],
)
self.jump(inst)
return
@ -683,8 +689,12 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
self.push(value)
self.jump(inst)
else:
unimplemented(
"generic_jump on UserDefined with __bool__ returning non-constant"
unimplemented_v2(
gb_type="Data-dependent branching with non-constant __bool__",
context=f"method: {x}, result: {result}",
explanation="Attempted to perform data-dependent branching on a user-defined "
"object with a __bool__ method that did not return a constant.",
hints=[],
)
# __bool__ or __len__ is non-function or not existed in the user defined object
else:
@ -757,7 +767,12 @@ def break_graph_if_unsupported(*, push):
# We don't support graph break under GenericContextWrappingVariable,
# If there is, we roll back to the checkpoint and fall back.
excp.remove_from_stats()
unimplemented("Graph break under GenericContextWrappingVariable")
unimplemented_v2(
gb_type="Graph break under GenericContextWrappingVariable",
context="",
explanation="Attempted to graph break in an active context manager that doesn't support graph breaking.",
hints=[],
)
if isinstance(excp, exc.UncapturedHigherOrderOpError):
raise
@ -866,7 +881,12 @@ class BytecodeDistpatchTableMeta(type):
super().__init__(name, bases, dct)
def _missing(opname, *args):
unimplemented(f"missing: {opname}")
unimplemented_v2(
gb_type="Missing bytecode handler",
context=opname,
explanation=f"Dynamo does not know how to handle the bytecode instruction {opname}",
hints=[],
)
dispatch_table = {
op: getattr(cls, opname, functools.partial(_missing, opname))
@ -1219,9 +1239,21 @@ class InstructionTranslatorBase(
new_name = name.replace(".", "implicit")
self.push(self.symbolic_locals[new_name])
except KeyError:
unimplemented("undefined LOAD_FAST (implicit)")
unimplemented_v2(
gb_type="Attempted to read undefined local variable (implicit)",
context=f"LOAD_FAST {name}",
explanation=f"Could not find an implicit local variable with name `{name}`",
hints=[
"This happens in dict/list comprehensions",
],
)
else:
unimplemented("undefined LOAD_FAST")
unimplemented_v2(
gb_type="Attempted to read undefined local variable",
context=f"LOAD_FAST {name}",
explanation=f"Could not find a local variable with name `{name}`",
hints=[],
)
# for continuation functions
if name.startswith("___stack"):
@ -1315,7 +1347,12 @@ class InstructionTranslatorBase(
source, self.symbolic_globals[name]
)
if isinstance(value, RemovableHandleVariable):
unimplemented("Storing handles in globals - NYI")
unimplemented_v2(
gb_type="Storing Tensor hook handle in globals",
context=name,
explanation="This is not supported.",
hints=[],
)
self.output.side_effects.store_global(variable, name, value)
# Cache note: This cache only exists for the duration of this
@ -1402,7 +1439,12 @@ class InstructionTranslatorBase(
globals=self.f_globals,
)
except ImportError:
unimplemented("import a module that does not exist")
unimplemented_v2(
gb_type="Import failure",
context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}",
explanation="Failure when attempting to import.",
hints=[],
)
if level != 0:
pkg = self.calc_package()
@ -1425,7 +1467,12 @@ class InstructionTranslatorBase(
if istype(value, (types.ModuleType, DummyModule)):
self.push(PythonModuleVariable(value, source=source))
else:
unimplemented(f"IMPORT_NAME {typestr(value)}")
unimplemented_v2(
gb_type="Bad import result",
context=typestr(value),
explanation="Import result is not a Python module.",
hints=[],
)
def IMPORT_FROM(self, inst):
self.DUP_TOP(inst)
@ -1566,13 +1613,24 @@ class InstructionTranslatorBase(
if observed_exception_type := exc.observed_exception_map.get(val.exc_type):
raise observed_exception_type(f"raised exception {val}")
raise exc.ObservedException(f"raised exception {val}")
unimplemented(f"raise {exc}")
unimplemented_v2(
gb_type="Failed to raise exception",
context=str(exc),
explanation="Attempted to raise a non-Exception type/value.",
hints=[],
)
def RAISE_VARARGS(self, inst):
if inst.arg == 0:
# duplicate the top of the stack and re-raise it
if sys.version_info < (3, 11):
unimplemented("re-raise")
unimplemented_v2(
gb_type="Re-raise with no arguments",
context="",
explanation="Dynamo does not support re-raising the previous exception "
"in Python < 3.11 (empty `raise`)",
hints=[],
)
assert isinstance(self.stack[-1], ExceptionVariable)
self.stack.append(self.stack[-1])
self._raise_exception_variable(inst)
@ -1584,14 +1642,24 @@ class InstructionTranslatorBase(
from_vt = self.pop()
if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
self._raise_exception_variable(inst)
unimplemented("raise ... from ...")
unimplemented_v2(
gb_type="Re-raise with 2 arguments",
context=str(from_vt),
explanation="Dynamo does not support `raise ... from [not-None]`",
hints=[],
)
def CLEANUP_THROW(self, inst):
# https://github.com/python/cpython/pull/96010
tos = self.stack[-1]
assert isinstance(tos, ExceptionVariable)
if tos.exc_type is StopIteration:
unimplemented("CLEANUP_THROW with StopIteration")
unimplemented_v2(
gb_type="CLEANUP_THROW with StopIteration",
context="",
explanation="Received StopIteration when handling generator.throw/close. This is not supported.",
hints=[],
)
else:
self.RERAISE(inst)
@ -1599,7 +1667,13 @@ class InstructionTranslatorBase(
if sys.version_info >= (3, 11):
# RERAISE is currently supported in a narrow case of `raise ... from None`
self._raise_exception_variable(inst)
unimplemented("RERAISE")
unimplemented_v2(
gb_type="RERAISE in Python < 3.11",
context="",
explanation="RERAISE bytecode (https://docs.python.org/3.10/library/dis.html#opcode-RERAISE) "
"not supported in Python < 3.11. This bytecode is generated by try/with blocks.",
hints=[],
)
def WITH_EXCEPT_START(self, inst):
if sys.version_info >= (3, 11):
@ -1679,10 +1753,13 @@ class InstructionTranslatorBase(
block_stack_entry = self.block_stack.pop()
if block_stack_entry.inst.opname != "SETUP_FINALLY":
unimplemented(
"exception is raised when top of the block stack "
unimplemented_v2(
gb_type="Exception raised with invalid exception handler",
context="",
explanation="Exception raised when top of the block stack "
"is not exception handler (e.g. try .. with .. except). "
f"Current TOS is {block_stack_entry.inst}"
f"Current TOS is {block_stack_entry.inst}",
hints=[],
)
# 1) pop values from the stack until it matches the stack depth
@ -1779,14 +1856,20 @@ class InstructionTranslatorBase(
# 2) except (NotImplemetedError, AttributeError) -> TupleVariable
if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)):
unimplemented(
f"except has an unsupported types of objects {expected_exc_types}"
unimplemented_v2(
gb_type="Exception with bad expected type",
context=str(expected_exc_types),
explanation=f"`except ...` has unsupported type {expected_exc_types}.",
hints=[],
)
if sys.version_info >= (3, 11):
if not isinstance(exc_instance, variables.ExceptionVariable):
unimplemented(
f"except expects to recieve an object of exception type but received {exc_instance}"
unimplemented_v2(
gb_type="Caught non-Exception value",
context=str(exc_instance),
explanation=f"Except expects to recieve an object of Exception type but received {exc_instance}.",
hints=[],
)
if isinstance(expected_exc_types, TupleVariable):
@ -1798,8 +1881,11 @@ class InstructionTranslatorBase(
for expected_type in expected_types:
if not isinstance(expected_type, BuiltinVariable):
unimplemented(
f"except has an unsupported types of object {expected_type}"
unimplemented_v2(
gb_type="Exception with non-type expectation",
context=str(expected_type),
explanation=f"`except ...` expects a non-type: {expected_type}.",
hints=[],
)
if isinstance(exc_instance, variables.ExceptionVariable) and issubclass(
exc_instance.exc_type, expected_type.fn
@ -1844,7 +1930,12 @@ class InstructionTranslatorBase(
kwargsvars = self.pop()
argsvars = self.pop()
else:
unimplemented("CALL_FUNCTION_EX")
unimplemented_v2(
gb_type="Variadic function call with bad flags",
context=f"flags: {inst.argval}",
explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}",
hints=[],
)
if sys.version_info >= (3, 13):
# 3.13 swapped null and callable
@ -1879,7 +1970,12 @@ class InstructionTranslatorBase(
# args, aot_autograd/inductor while lowering generates
# aten.random.from, again causing syntax errors. Since this
# usecase is uncommon, graph break.
unimplemented("random_ op is called with from keyword")
unimplemented_v2(
gb_type="Tensor.random_ op called with `from` keyword",
context="",
explanation="This is not supported.",
hints=[],
)
elif (
fn.name == "uniform_"
and isinstance(argsvars, TupleVariable)
@ -1892,7 +1988,12 @@ class InstructionTranslatorBase(
# args, aot_autograd/inductor while lowering generates
# aten.uniform.from, again causing syntax errors. Since this
# usecase is uncommon, graph break.
unimplemented("uniform_ op is called with from keyword")
unimplemented_v2(
gb_type="Tensor.uniform_ op called with `from` keyword",
context="",
explanation="This is not supported.",
hints=[],
)
if not isinstance(
argsvars, BaseListVariable
@ -1906,7 +2007,12 @@ class InstructionTranslatorBase(
if not isinstance(argsvars, BaseListVariable) or not isinstance(
kwargsvars, ConstDictVariable
):
unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")
unimplemented_v2(
gb_type="Variadic function call with bad args/kwargs type",
context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
explanation="Expected args to be a list and kwargs to be a dict",
hints=[],
)
# Map to a dictionary of str -> VariableTracker
kwargsvars = kwargsvars.keys_as_python_constant()
@ -2005,7 +2111,13 @@ class InstructionTranslatorBase(
def store_attr_graph_break(self, inst):
log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break")
if not self.should_compile_partial_graph():
unimplemented("should_compile_partial_graph=False")
unimplemented_v2(
gb_type="Should not compile partial graph (STORE_ATTR)",
context="",
explanation="Dynamo has determined when encountering an unsupported "
"STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.",
hints=[],
)
self.output.compile_subgraph(
self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
)
@ -2054,7 +2166,12 @@ class InstructionTranslatorBase(
def BUILD_SET(self, inst):
if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
unimplemented("missing: BUILD_SET")
unimplemented_v2(
gb_type="missing BUILD_SET handler",
context="",
explanation="Missing BUILD_SET bytecode handler (for testing purposes).",
hints=[],
)
items = self.popn(inst.argval)
new_set = SetVariable(items, mutation_type=ValueMutationNew())
self.push(new_set)
@ -2066,7 +2183,13 @@ class InstructionTranslatorBase(
try:
items.extend(seq.force_unpack_var_sequence(self))
except NotImplementedError:
unimplemented(f"BUILD_LIST_UNPACK {seq}")
unimplemented_v2(
gb_type="Failed to unpack object for BUILD_LIST_UNPACK",
context=str(seq),
explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK "
"bytecode (`[*x, *y, ...]`).",
hints=[],
)
self.push(cls(items, mutation_type=ValueMutationNew()))
def BUILD_TUPLE_UNPACK(self, inst):
@ -2193,9 +2316,21 @@ class InstructionTranslatorBase(
elif seq.has_force_unpack_var_sequence(self):
val = seq.force_unpack_var_sequence(self)
else:
unimplemented(f"UNPACK_SEQUENCE {seq}")
unimplemented_v2(
gb_type="Failed to unpack object for UNPACK_SEQUENCE",
context=str(seq),
explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode "
"(i.e. `a, b, c = d`).",
hints=[],
)
if len(val) != inst.argval:
unimplemented("UNPACK_SEQUENCE length mismatch")
unimplemented_v2(
gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
context=f"expected length: {inst.argval}, actual: {len(val)}",
explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
"(i.e. `a, b, c = d`) with unexpected length.",
hints=[],
)
for i in reversed(val):
self.push(i)
@ -2216,7 +2351,13 @@ class InstructionTranslatorBase(
for item in reversed(vals_prefix):
self.push(item)
else:
unimplemented(f"UNPACK_EX {seq}")
unimplemented_v2(
gb_type="Failed to unpack object for UNPACK_EX",
context=str(seq),
explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode "
"(i.e. `a, *b, c = d`).",
hints=[],
)
def NOP(self, inst):
pass
@ -2311,12 +2452,20 @@ class InstructionTranslatorBase(
format_string_parts.append(part.format_string)
args.extend(part.sym_args)
if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
unimplemented(
f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}"
unimplemented_v2(
gb_type="BUILD_STRING key conflict",
context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}",
explanation="Failed to build format string due to key conflict",
hints=[],
)
kwargs.update(part.sym_kwargs)
else:
unimplemented(f"BUILD_STRING {part}")
unimplemented_v2(
gb_type="BUILD_STRING type error",
context=str(part),
explanation="Format string part type is not correct - expected constant or format string.",
hints=[],
)
self.push(
variables.StringFormatVariable.create(
"".join(format_string_parts), args, kwargs
@ -2638,7 +2787,12 @@ class InstructionTranslatorBase(
def LOAD_FAST_CHECK(self, inst):
if isinstance(self.symbolic_locals[inst.argval], NullVariable):
unimplemented("LOAD_FAST_CHECK on uninitialized variable")
unimplemented_v2(
gb_type="LOAD_FAST_CHECK on uninitialized variable",
context=inst.argval,
explanation=f"Attempted to load uninitialized local variable {inst.argval}",
hints=[],
)
self.LOAD_FAST(inst)
def LOAD_FAST_AND_CLEAR(self, inst):
@ -2666,7 +2820,12 @@ class InstructionTranslatorBase(
# INTRINSIC_LIST_TO_TUPLE
self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
else:
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
unimplemented_v2(
gb_type="Missing CALL_INTRINSIC_1 handler",
context=f"CALL_INTRINSIC_1 operand: {inst.argval}",
explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.",
hints=[],
)
def END_SEND(self, inst):
tos = self.pop()
@ -3052,7 +3211,12 @@ class InstructionTranslator(InstructionTranslatorBase):
# if it reaches here, it means Dynamo failed to inline a functorch function
f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
)
unimplemented(msg)
unimplemented_v2(
gb_type="Unsupported functorch tracing attempt",
context="",
explanation=msg,
hints=[],
)
def get_example_value(self, source: Source):
if isinstance(source, LocalSource):
@ -3288,7 +3452,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
@staticmethod
def check_inlineable(func):
if func.has_self():
unimplemented("inline with __self__")
unimplemented_v2(
gb_type="Inline attempt with __self__",
context=str(func),
explanation="Attempted to inline a function with the `__self__` attribute. "
"Dynamo is expected to decompose method calls into function calls with a `self` argument.",
hints=[],
)
result = trace_rules.check_verbose(func, is_inlined_call=True)
if result.skipped:
@ -3346,7 +3516,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
kwargs,
):
if isinstance(func, SkipFunctionVariable):
unimplemented("inline with functions in skip files")
unimplemented_v2(
gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)",
context=f"Attempted to inline a SkipFunctionVariable {func}",
explanation="Attempted to inline a function that was previously determined to be marked as intentionally skipped.",
hints=[],
)
assert isinstance(
func,
(
@ -3373,13 +3548,23 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
for v in itertools.chain(sub_locals.values()):
if not isinstance(v, VariableTracker):
unimplemented(f"unconverted arg {v}")
unimplemented_v2(
gb_type="Encountered unconverted argument when attempting to inline",
context=f"func: {func}, arg: {v}",
explanation="An argument to an inlined function was not successfully converted to a VariableTracker.",
hints=[],
)
code: types.CodeType = func.get_code()
if code.co_name in ("__setitem__", "__setattr__") and not (
args and isinstance(args[0], variables.UserDefinedObjectVariable)
):
unimplemented(f"inline {code.co_name}")
unimplemented_v2(
gb_type="Unsupported __setitem__/__setattr__ inline attempt",
context=f"code name: {code.co_name}, args: {args}",
explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.",
hints=[],
)
suffix = ""
# TODO: mlazos, add support for enabling multiple artifact logs
@ -3541,7 +3726,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
return False # inlining functions is all-or-nothing
def create_call_resume_at(self, offset):
unimplemented("cant resume while inlining")
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",
explanation="Graph breaks in an inlined call are not supported.",
hints=[],
)
def RETURN_VALUE(self, inst):
self.symbolic_result = self.pop() # type: ignore[assignment]
@ -3600,7 +3790,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
else:
value = self.pop()
if isinstance(value, RemovableHandleVariable):
unimplemented("Storing handles in globals - NYI")
unimplemented_v2(
gb_type="Storing Tensor hook handle in globals (inline call)",
context=inst.argval,
explanation="This is not supported.",
hints=[],
)
name = inst.argval
_fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
self.output.side_effects.store_attr(fglobals_vt, name, value)
@ -3658,7 +3853,12 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
# subgenerator and lines up with this line in Python 3.10
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
unimplemented("Unreachable sub-generator code")
unimplemented_v2(
gb_type="Unreachable sub-generator code",
context="",
explanation="Should only be encountered while implementing generator support.",
hints=[],
)
try:
val = tos.next_variable(self)
@ -3711,6 +3911,16 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
# subgenerator and lines up with this line in Python 3.11
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
unimplemented("Unreachable sub-generator code")
unimplemented_v2(
gb_type="Unreachable sub-generator code",
context="",
explanation="Should only be encountered while implementing generator support.",
hints=[],
)
else:
unimplemented(f"SEND {typestr(tos)}")
unimplemented_v2(
gb_type="SEND with bad type",
context=f"TOS type: {typestr(tos)}",
explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.",
hints=[],
)

View File

@ -1109,11 +1109,14 @@ def proxy_args_kwargs(args, kwargs):
proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
return proxy_args, proxy_kwargs
except NotImplementedError as e:
from .exc import unimplemented
from .exc import unimplemented_v2
from .variables.base import typestr
unimplemented(
f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
unimplemented_v2(
gb_type="Failed to convert args/kwargs to proxy",
context=f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
explanation="Missing `as_proxy()` implementation for some arg/kwarg.",
hints=[],
from_exc=e,
)
@ -2464,9 +2467,14 @@ def set_example_value(node, example_value):
def _get_fake_tensor(vt):
fake_tensor = vt.as_proxy().node.meta.get("example_value")
if not is_fake(fake_tensor):
from .exc import unimplemented
from .exc import unimplemented_v2
unimplemented("Cannot check Tensor object identity without its fake value")
unimplemented_v2(
gb_type="Cannot check Tensor object identity without its fake value",
context=str(fake_tensor),
explanation="TensorVariable is missing a fake example_value.",
hints=[],
)
return fake_tensor
@ -2578,11 +2586,17 @@ def wrap_fake_exception(fn):
try:
return fn()
except UnsupportedFakeTensorException as e:
from .exc import unimplemented
from .exc import unimplemented_v2
msg = f"Unsupported: {e.reason} with fake tensor propagation."
msg = f"Encountered exception ({e.reason}) during fake tensor propagation."
log.warning(msg)
unimplemented(msg, from_exc=e)
unimplemented_v2(
gb_type="Fake tensor propagation exception",
context=str(e.reason),
explanation=msg,
hints=[],
from_exc=e,
)
def deepcopy_to_fake_tensor(obj, fake_mode):
@ -2956,9 +2970,14 @@ def extract_fake_example_value(node, required=True):
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
return node.meta["example_value"]
elif required:
from torch._dynamo.exc import unimplemented
from torch._dynamo.exc import unimplemented_v2
unimplemented("`FakeTensor` example value was required but not available")
unimplemented_v2(
gb_type="Missing FakeTensor example value",
context=str(node),
explanation=f"`FakeTensor` example value was required for {node} but not available.",
hints=[],
)
else:
return None
@ -3002,7 +3021,6 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
from .exc import (
TorchRuntimeError,
unimplemented,
unimplemented_v2,
Unsupported,
UserError,
@ -3124,10 +3142,15 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
f"module `{module}` and you may need to `import {module}`"
f"({ctx}), otherwise "
)
unimplemented(
f"unsupported operator: {cause.func} ({import_suggestion}see "
"https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
" for how to fix)"
unimplemented_v2(
gb_type="Operator does not support running with fake tensors",
context=f"unsupported operator: {cause.func}",
explanation="",
hints=[
f"{import_suggestion}see "
"https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
" for how to fix",
],
)
elif isinstance(
cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
@ -3140,7 +3163,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
elif isinstance(cause, ValueRangeError):
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
elif isinstance(cause, TypeError) and "argument" in str(cause):
unimplemented(f"TypeError {node.target}: {cause}")
unimplemented_v2(
gb_type="TypeError when making fake tensor call",
context=f"TypeError {node.target}: {cause}",
explanation="",
hints=[],
)
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
@ -3197,9 +3225,14 @@ def run_node(tracer, node, args, kwargs, nnmodule):
return node.target(*args, **kwargs)
elif op == "call_method":
if not hasattr(args[0], node.target):
from .exc import unimplemented
from .exc import unimplemented_v2
unimplemented(make_error_message("attribute not defined"))
unimplemented_v2(
gb_type="Missing attribute when running call_method node",
context="",
explanation=make_error_message("attribute not defined"),
hints=[],
)
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
@ -3212,9 +3245,15 @@ def run_node(tracer, node, args, kwargs, nnmodule):
except (NotImplementedError, UnsupportedFakeTensorException) as e:
# NB: mimic how wrap_fake_exception does it
from .exc import unimplemented
from .exc import unimplemented_v2
unimplemented(make_error_message(e), from_exc=e)
unimplemented_v2(
gb_type="NotImplementedError/UnsupportedFakeTensorException when running node",
context=str(e),
explanation=make_error_message(e),
hints=[],
from_exc=e,
)
except Unsupported:
raise
except Exception as e: