mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] make some more graph break messages readable in English [2/N] (#147385)
This is for "for some large number Z, make sure the error messages are readable English." - beginning to audit all `unimplemented` sites and making sure that all messages are at least English-readable. Hints may not necessarily be provided. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147385 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a06bfdd1c
commit
3fd68e4e2f
@ -48,7 +48,7 @@ class TestCustomOperators(TestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"unsupported operator: .* you may need to `import nonexistent`",
|
||||
r"(?s)Operator does not support running with fake tensors.*you may need to `import nonexistent`",
|
||||
):
|
||||
f(x)
|
||||
|
||||
|
@ -36,7 +36,7 @@ from .bytecode_transformation import (
|
||||
create_rot_n,
|
||||
Instruction,
|
||||
)
|
||||
from .exc import IncorrectUsage, unimplemented
|
||||
from .exc import IncorrectUsage, unimplemented, unimplemented_v2
|
||||
from .source import AttrSource, ChainedSource, DictGetItemSource, Source
|
||||
from .utils import is_safe_constant, rot_n_helper
|
||||
from .variables.base import ValueMutationExisting, VariableTracker
|
||||
@ -335,7 +335,19 @@ class PyCodegen:
|
||||
try:
|
||||
self.call_reconstruct(value)
|
||||
except NotImplementedError:
|
||||
unimplemented(f"reconstruct: {value}")
|
||||
unimplemented_v2(
|
||||
gb_type="Reconstruction failure",
|
||||
context=str(value),
|
||||
explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
|
||||
hints=[
|
||||
"If Dynamo attempting to trace a return statement and your code is attempting to return a variable "
|
||||
"that Dynamo cannot reconstruct, then remove it from the return statement.",
|
||||
"If this reconstruction graph break occurs while handling another graph break, then resolve the "
|
||||
"initial graph break.",
|
||||
"Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have"
|
||||
"reconstruction rules may be fundamentally unreconstructable.",
|
||||
],
|
||||
)
|
||||
if allow_cache and value in self.tempvars:
|
||||
self._output.append(create_dup_top())
|
||||
self.add_cache(value)
|
||||
|
@ -46,7 +46,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
from torch.fx.experimental.symbolic_shapes import free_symbols
|
||||
|
||||
from .exc import unimplemented
|
||||
from .exc import unimplemented_v2
|
||||
from .variables import CellVariable
|
||||
from .variables.constant import ConstantVariable
|
||||
from .variables.tensor import SymNodeVariable
|
||||
@ -192,7 +192,12 @@ class ComptimeContext:
|
||||
"""
|
||||
Manually trigger a graph break
|
||||
"""
|
||||
unimplemented(msg)
|
||||
unimplemented_v2(
|
||||
gb_type="ComptimeContext graph break",
|
||||
context=msg,
|
||||
explanation=f"Manually triggered ComptimeContext graph break with message {msg}.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def graph(self):
|
||||
"""
|
||||
|
@ -104,7 +104,7 @@ from .exc import (
|
||||
SkipCodeRecursiveException,
|
||||
TorchRuntimeError,
|
||||
UncapturedHigherOrderOpError,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
from .guards import (
|
||||
@ -535,7 +535,15 @@ class ConvertFrameAssert:
|
||||
return ConvertFrameReturn()
|
||||
|
||||
if is_generator(code):
|
||||
unimplemented("generator")
|
||||
unimplemented_v2(
|
||||
gb_type="Attempt to trace generator",
|
||||
context="",
|
||||
explanation="Generators cannot be compiled directly with `torch.compile`.",
|
||||
hints=[
|
||||
"Call a generator from inside of a non-generator Python function and "
|
||||
"compile that function instead.",
|
||||
],
|
||||
)
|
||||
|
||||
if not has_tensor_in_frame(frame):
|
||||
return ConvertFrameReturn()
|
||||
@ -793,7 +801,14 @@ def _compile(
|
||||
# We now have a new "last attempt", reset the clock
|
||||
last_attempt_start_time = time.time()
|
||||
if attempt > 100:
|
||||
unimplemented("100+ RestartAnalysis() calls")
|
||||
unimplemented_v2(
|
||||
gb_type="Excessive RestartAnalysis() calls",
|
||||
context="",
|
||||
explanation="Dynamo attempted to trace the same frame 100+ times. "
|
||||
"Giving up on compiling as the compile time tradeoff is likely not "
|
||||
"worth the performance gain.",
|
||||
hints=[],
|
||||
)
|
||||
except exc.SkipFrame as e:
|
||||
if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
|
||||
TensorifyState.clear()
|
||||
@ -962,7 +977,15 @@ def _compile(
|
||||
raise RecompileLimitExceeded(f"{limit_type} reached")
|
||||
else:
|
||||
# do not recursively skip frames
|
||||
unimplemented(f"{limit_type} reached")
|
||||
unimplemented_v2(
|
||||
gb_type="Dynamo cache limit exceeded",
|
||||
context=f"Limit type: {limit_type}",
|
||||
explanation="Dynamo attempted to recompile the code object too many times, "
|
||||
f"exceeding the {limit_type} cache size limit."
|
||||
"Giving up on compiling as the compile time tradeoff is likely not "
|
||||
"worth the performance gain.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
|
||||
|
@ -1691,7 +1691,14 @@ class GuardBuilder(GuardBuilderBase):
|
||||
assert istype(val.training, bool)
|
||||
self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
|
||||
else:
|
||||
exc.unimplemented(f"Guard setup for uninitialized class {type(val)}")
|
||||
exc.unimplemented_v2(
|
||||
gb_type="Attempted to guard on uninitialized nn.Module",
|
||||
context="",
|
||||
explanation="Attempted to setup an NN_MODULE guard on unitialized "
|
||||
f"nn.Module subclass `{type(val)}`. Please ensure the `nn.Module` "
|
||||
"subclass instance has called `super().__init__()`.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def FUNCTION_MATCH(self, guard: Guard):
|
||||
"""things like torch.add and user defined functions"""
|
||||
|
@ -81,7 +81,7 @@ from .exc import (
|
||||
BackendCompilerFailed,
|
||||
exceptions_allowed_to_be_fallback,
|
||||
SkipFrame,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
unimplemented_v2_with_warning,
|
||||
)
|
||||
from .graph_deduplication import apply_graph_deduplication
|
||||
@ -486,7 +486,12 @@ class OutputGraph:
|
||||
def get_backward_state_proxy(self):
|
||||
if self.backward_state_proxy is None:
|
||||
if self.export:
|
||||
unimplemented("backward_state does not support export")
|
||||
unimplemented_v2(
|
||||
gb_type="backward_state does not support export",
|
||||
context="",
|
||||
explanation="Compiled autograd doesn't work with `torch.export`.",
|
||||
hints=[],
|
||||
)
|
||||
example_value = BackwardState()
|
||||
self.backward_state_proxy = self.root_tracer.create_graph_input(
|
||||
"dynamo_backward_state",
|
||||
@ -994,7 +999,12 @@ class OutputGraph:
|
||||
log.debug("COMPILING GRAPH due to %s", reason)
|
||||
|
||||
if not all(block.can_restore() for block in tx.block_stack):
|
||||
unimplemented("compile_subgraph with block_depth != 0")
|
||||
unimplemented_v2(
|
||||
gb_type="Attempt to compile graph in a try block",
|
||||
context="",
|
||||
explanation="Dynamo cannot compile traced graphs while in a try block.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
prefix_insts: list[Instruction] = []
|
||||
if sys.version_info >= (3, 11):
|
||||
@ -1843,7 +1853,12 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
|
||||
def encountered_non_compliant_op(target, msg):
|
||||
output_graph.non_compliant_ops.add(target)
|
||||
if config.only_allow_pt2_compliant_ops:
|
||||
unimplemented(msg + " " + err_epilogue)
|
||||
unimplemented_v2(
|
||||
gb_type="Encountered non-PT2-compliant op",
|
||||
context="",
|
||||
explanation=msg + " " + err_epilogue,
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
if torch.Tag.pt2_compliant_tag in target.tags:
|
||||
@ -1881,7 +1896,12 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
|
||||
target._qualified_op_name, *args, **kwargs
|
||||
)
|
||||
except RuntimeError as e:
|
||||
unimplemented(str(e))
|
||||
unimplemented_v2(
|
||||
gb_type="Error when attempting to resolve op packet",
|
||||
context="",
|
||||
explanation=str(e),
|
||||
hints=[],
|
||||
)
|
||||
|
||||
op = getattr(target, overload)
|
||||
if torch.Tag.pt2_compliant_tag in op.tags:
|
||||
@ -1933,6 +1953,7 @@ class SubgraphTracer(fx.Tracer):
|
||||
|
||||
# SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
|
||||
self.parent = parent
|
||||
self.source_target = source_target
|
||||
# A dict mapping previously free variables (Proxy objects)
|
||||
# to new Proxy objects that wrap inputs to this subgraph.
|
||||
#
|
||||
@ -2115,7 +2136,13 @@ class SubgraphTracer(fx.Tracer):
|
||||
]
|
||||
elif kind == "call_module":
|
||||
if self.parent is not None:
|
||||
unimplemented("Invoking an nn.Module inside HigherOrderOperator")
|
||||
# TODO can remove once inline_inbuilt_nn_modules is always True
|
||||
unimplemented_v2(
|
||||
gb_type="Invoking an nn.Module inside a higher order operator",
|
||||
context=f"Higher order op name: {self.source_target}",
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
# For modules we store the class
|
||||
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
||||
(
|
||||
@ -2143,8 +2170,12 @@ class SubgraphTracer(fx.Tracer):
|
||||
]
|
||||
elif kind == "call_module":
|
||||
if self.parent is not None:
|
||||
unimplemented(
|
||||
"Invoking an nn.Module inside HigherOrderOperator"
|
||||
# TODO can remove once inline_inbuilt_nn_modules is always True
|
||||
unimplemented_v2(
|
||||
gb_type="Invoking an nn.Module inside a HigherOrderOperator",
|
||||
context="",
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
# For modules we store the class
|
||||
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
||||
|
@ -19,7 +19,7 @@ from .bytecode_transformation import (
|
||||
create_instruction,
|
||||
)
|
||||
from .codegen import PyCodegen
|
||||
from .exc import SideEffectsError, unimplemented
|
||||
from .exc import SideEffectsError, unimplemented_v2
|
||||
from .source import GlobalSource, LocalCellSource, LocalSource, Source
|
||||
from .utils import is_frozen_dataclass, nn_module_new, object_new
|
||||
from .variables.base import (
|
||||
@ -186,8 +186,12 @@ class SideEffects:
|
||||
"unintended variable modifications."
|
||||
)
|
||||
if not is_side_effect_safe(item.mutation_type):
|
||||
unimplemented(
|
||||
"HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
|
||||
# TODO plumb HOP information here
|
||||
unimplemented_v2(
|
||||
gb_type="HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)",
|
||||
context="",
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
|
||||
@ -202,12 +206,22 @@ class SideEffects:
|
||||
assert self.is_attribute_mutation(item)
|
||||
result = self.store_attr_mutations[item][name]
|
||||
if not deleted_ok and isinstance(result, variables.DeletedVariable):
|
||||
unimplemented("read deleted attribute")
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to read a deleted variable",
|
||||
context=f"item: {item}, name: {name}",
|
||||
explanation="",
|
||||
hints=[],
|
||||
)
|
||||
return result
|
||||
|
||||
def store_cell(self, cellvar, value):
|
||||
if cellvar.is_immutable():
|
||||
unimplemented("Dynamo currently doesn't support writing to such cell")
|
||||
unimplemented_v2(
|
||||
gb_type="Write to immutable cell",
|
||||
context=f"cellvar: {cellvar}, value: {value}",
|
||||
explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.",
|
||||
hints=[],
|
||||
)
|
||||
assert isinstance(cellvar, variables.CellVariable)
|
||||
assert isinstance(value, variables.VariableTracker)
|
||||
self.store_attr(cellvar, "cell_contents", value)
|
||||
@ -218,7 +232,12 @@ class SideEffects:
|
||||
return self.load_attr(cellvar, "cell_contents", check=False)
|
||||
if cellvar.pre_existing_contents:
|
||||
return cellvar.pre_existing_contents
|
||||
unimplemented("cannot read uninitialized cell")
|
||||
unimplemented_v2(
|
||||
gb_type="Read uninitialized cell",
|
||||
context=str(cellvar),
|
||||
explanation="Attempted to read a cell variable that has not been populated yet.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def load_global(self, gvar: VariableTracker, name: str):
|
||||
assert isinstance(gvar, variables.VariableTracker)
|
||||
@ -574,7 +593,12 @@ class SideEffects:
|
||||
var.source = LocalCellSource(var.local_name)
|
||||
elif isinstance(var.mutation_type, AttributeMutationNew):
|
||||
if isinstance(var, variables.AutogradFunctionContextVariable):
|
||||
unimplemented("AutogradFunctionContextVariable escaped")
|
||||
unimplemented_v2(
|
||||
gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
|
||||
context="",
|
||||
explanation="We cannot reconstruct a torch.autograd.Function's context object.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
# Reconstruct the bytecode for
|
||||
# base_cls.__new__(user_cls, *args)
|
||||
@ -723,7 +747,14 @@ class SideEffects:
|
||||
isinstance(var.maxlen, variables.ConstantVariable)
|
||||
and var.maxlen.value is None
|
||||
):
|
||||
unimplemented("side effect on existing deque with limited maxlen")
|
||||
unimplemented_v2(
|
||||
gb_type="Side effect on existing deque with limited maxlen",
|
||||
context="",
|
||||
explanation="This is not supported.",
|
||||
hints=[
|
||||
"Don't use a deque with `maxlen` specified.",
|
||||
],
|
||||
)
|
||||
|
||||
# old.extend(new), this runs last
|
||||
cg(var.source)
|
||||
|
@ -74,13 +74,7 @@ from .bytecode_transformation import (
|
||||
)
|
||||
from .code_context import code_context
|
||||
from .codegen import PyCodegen
|
||||
from .exc import (
|
||||
ArgsMismatchError,
|
||||
BackendCompilerFailed,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented_v2, Unsupported
|
||||
from .funcname_cache import get_funcname
|
||||
from .guards import GuardBuilder, install_guard
|
||||
from .output_graph import GraphCompileReason, OutputGraph
|
||||
@ -538,9 +532,16 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
|
||||
|
||||
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
||||
def jump_graph_break(self, inst, value, extra_msg=""):
|
||||
log_graph_break(self.code_options, reason="Data-dependent jump")
|
||||
log_graph_break(self.code_options, reason="Data-dependent branching")
|
||||
if not self.should_compile_partial_graph():
|
||||
unimplemented("should_compile_partial_graph=False")
|
||||
unimplemented_v2(
|
||||
gb_type="Should not compile partial graph (data-dependent branching)",
|
||||
context="",
|
||||
explanation="Dynamo has determined when encountering data-dependent "
|
||||
"branching (e.g. `if my_tensor.item() > 0:`) that it should not "
|
||||
"compile the partial graph.",
|
||||
hints=[],
|
||||
)
|
||||
# compile a partial subgraph prefix then jump into user code
|
||||
if self.maybe_has_backedge():
|
||||
msg = (
|
||||
@ -610,8 +611,13 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
||||
|
||||
result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
|
||||
if not result:
|
||||
unimplemented(
|
||||
"Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
|
||||
unimplemented_v2(
|
||||
gb_type="Assertion failed on symbolic shapes",
|
||||
context=str(sym_expr),
|
||||
explanation="",
|
||||
hints=[
|
||||
"Did you make sure your code works without compile?",
|
||||
],
|
||||
)
|
||||
self.jump(inst)
|
||||
return
|
||||
@ -683,8 +689,12 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
||||
self.push(value)
|
||||
self.jump(inst)
|
||||
else:
|
||||
unimplemented(
|
||||
"generic_jump on UserDefined with __bool__ returning non-constant"
|
||||
unimplemented_v2(
|
||||
gb_type="Data-dependent branching with non-constant __bool__",
|
||||
context=f"method: {x}, result: {result}",
|
||||
explanation="Attempted to perform data-dependent branching on a user-defined "
|
||||
"object with a __bool__ method that did not return a constant.",
|
||||
hints=[],
|
||||
)
|
||||
# __bool__ or __len__ is non-function or not existed in the user defined object
|
||||
else:
|
||||
@ -757,7 +767,12 @@ def break_graph_if_unsupported(*, push):
|
||||
# We don't support graph break under GenericContextWrappingVariable,
|
||||
# If there is, we roll back to the checkpoint and fall back.
|
||||
excp.remove_from_stats()
|
||||
unimplemented("Graph break under GenericContextWrappingVariable")
|
||||
unimplemented_v2(
|
||||
gb_type="Graph break under GenericContextWrappingVariable",
|
||||
context="",
|
||||
explanation="Attempted to graph break in an active context manager that doesn't support graph breaking.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if isinstance(excp, exc.UncapturedHigherOrderOpError):
|
||||
raise
|
||||
@ -866,7 +881,12 @@ class BytecodeDistpatchTableMeta(type):
|
||||
super().__init__(name, bases, dct)
|
||||
|
||||
def _missing(opname, *args):
|
||||
unimplemented(f"missing: {opname}")
|
||||
unimplemented_v2(
|
||||
gb_type="Missing bytecode handler",
|
||||
context=opname,
|
||||
explanation=f"Dynamo does not know how to handle the bytecode instruction {opname}",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
dispatch_table = {
|
||||
op: getattr(cls, opname, functools.partial(_missing, opname))
|
||||
@ -1219,9 +1239,21 @@ class InstructionTranslatorBase(
|
||||
new_name = name.replace(".", "implicit")
|
||||
self.push(self.symbolic_locals[new_name])
|
||||
except KeyError:
|
||||
unimplemented("undefined LOAD_FAST (implicit)")
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to read undefined local variable (implicit)",
|
||||
context=f"LOAD_FAST {name}",
|
||||
explanation=f"Could not find an implicit local variable with name `{name}`",
|
||||
hints=[
|
||||
"This happens in dict/list comprehensions",
|
||||
],
|
||||
)
|
||||
else:
|
||||
unimplemented("undefined LOAD_FAST")
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to read undefined local variable",
|
||||
context=f"LOAD_FAST {name}",
|
||||
explanation=f"Could not find a local variable with name `{name}`",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
# for continuation functions
|
||||
if name.startswith("___stack"):
|
||||
@ -1315,7 +1347,12 @@ class InstructionTranslatorBase(
|
||||
source, self.symbolic_globals[name]
|
||||
)
|
||||
if isinstance(value, RemovableHandleVariable):
|
||||
unimplemented("Storing handles in globals - NYI")
|
||||
unimplemented_v2(
|
||||
gb_type="Storing Tensor hook handle in globals",
|
||||
context=name,
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
self.output.side_effects.store_global(variable, name, value)
|
||||
|
||||
# Cache note: This cache only exists for the duration of this
|
||||
@ -1402,7 +1439,12 @@ class InstructionTranslatorBase(
|
||||
globals=self.f_globals,
|
||||
)
|
||||
except ImportError:
|
||||
unimplemented("import a module that does not exist")
|
||||
unimplemented_v2(
|
||||
gb_type="Import failure",
|
||||
context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}",
|
||||
explanation="Failure when attempting to import.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if level != 0:
|
||||
pkg = self.calc_package()
|
||||
@ -1425,7 +1467,12 @@ class InstructionTranslatorBase(
|
||||
if istype(value, (types.ModuleType, DummyModule)):
|
||||
self.push(PythonModuleVariable(value, source=source))
|
||||
else:
|
||||
unimplemented(f"IMPORT_NAME {typestr(value)}")
|
||||
unimplemented_v2(
|
||||
gb_type="Bad import result",
|
||||
context=typestr(value),
|
||||
explanation="Import result is not a Python module.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def IMPORT_FROM(self, inst):
|
||||
self.DUP_TOP(inst)
|
||||
@ -1566,13 +1613,24 @@ class InstructionTranslatorBase(
|
||||
if observed_exception_type := exc.observed_exception_map.get(val.exc_type):
|
||||
raise observed_exception_type(f"raised exception {val}")
|
||||
raise exc.ObservedException(f"raised exception {val}")
|
||||
unimplemented(f"raise {exc}")
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to raise exception",
|
||||
context=str(exc),
|
||||
explanation="Attempted to raise a non-Exception type/value.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def RAISE_VARARGS(self, inst):
|
||||
if inst.arg == 0:
|
||||
# duplicate the top of the stack and re-raise it
|
||||
if sys.version_info < (3, 11):
|
||||
unimplemented("re-raise")
|
||||
unimplemented_v2(
|
||||
gb_type="Re-raise with no arguments",
|
||||
context="",
|
||||
explanation="Dynamo does not support re-raising the previous exception "
|
||||
"in Python < 3.11 (empty `raise`)",
|
||||
hints=[],
|
||||
)
|
||||
assert isinstance(self.stack[-1], ExceptionVariable)
|
||||
self.stack.append(self.stack[-1])
|
||||
self._raise_exception_variable(inst)
|
||||
@ -1584,14 +1642,24 @@ class InstructionTranslatorBase(
|
||||
from_vt = self.pop()
|
||||
if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
|
||||
self._raise_exception_variable(inst)
|
||||
unimplemented("raise ... from ...")
|
||||
unimplemented_v2(
|
||||
gb_type="Re-raise with 2 arguments",
|
||||
context=str(from_vt),
|
||||
explanation="Dynamo does not support `raise ... from [not-None]`",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def CLEANUP_THROW(self, inst):
|
||||
# https://github.com/python/cpython/pull/96010
|
||||
tos = self.stack[-1]
|
||||
assert isinstance(tos, ExceptionVariable)
|
||||
if tos.exc_type is StopIteration:
|
||||
unimplemented("CLEANUP_THROW with StopIteration")
|
||||
unimplemented_v2(
|
||||
gb_type="CLEANUP_THROW with StopIteration",
|
||||
context="",
|
||||
explanation="Received StopIteration when handling generator.throw/close. This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
else:
|
||||
self.RERAISE(inst)
|
||||
|
||||
@ -1599,7 +1667,13 @@ class InstructionTranslatorBase(
|
||||
if sys.version_info >= (3, 11):
|
||||
# RERAISE is currently supported in a narrow case of `raise ... from None`
|
||||
self._raise_exception_variable(inst)
|
||||
unimplemented("RERAISE")
|
||||
unimplemented_v2(
|
||||
gb_type="RERAISE in Python < 3.11",
|
||||
context="",
|
||||
explanation="RERAISE bytecode (https://docs.python.org/3.10/library/dis.html#opcode-RERAISE) "
|
||||
"not supported in Python < 3.11. This bytecode is generated by try/with blocks.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def WITH_EXCEPT_START(self, inst):
|
||||
if sys.version_info >= (3, 11):
|
||||
@ -1679,10 +1753,13 @@ class InstructionTranslatorBase(
|
||||
block_stack_entry = self.block_stack.pop()
|
||||
|
||||
if block_stack_entry.inst.opname != "SETUP_FINALLY":
|
||||
unimplemented(
|
||||
"exception is raised when top of the block stack "
|
||||
unimplemented_v2(
|
||||
gb_type="Exception raised with invalid exception handler",
|
||||
context="",
|
||||
explanation="Exception raised when top of the block stack "
|
||||
"is not exception handler (e.g. try .. with .. except). "
|
||||
f"Current TOS is {block_stack_entry.inst}"
|
||||
f"Current TOS is {block_stack_entry.inst}",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
# 1) pop values from the stack until it matches the stack depth
|
||||
@ -1779,14 +1856,20 @@ class InstructionTranslatorBase(
|
||||
# 2) except (NotImplemetedError, AttributeError) -> TupleVariable
|
||||
|
||||
if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)):
|
||||
unimplemented(
|
||||
f"except has an unsupported types of objects {expected_exc_types}"
|
||||
unimplemented_v2(
|
||||
gb_type="Exception with bad expected type",
|
||||
context=str(expected_exc_types),
|
||||
explanation=f"`except ...` has unsupported type {expected_exc_types}.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
if not isinstance(exc_instance, variables.ExceptionVariable):
|
||||
unimplemented(
|
||||
f"except expects to recieve an object of exception type but received {exc_instance}"
|
||||
unimplemented_v2(
|
||||
gb_type="Caught non-Exception value",
|
||||
context=str(exc_instance),
|
||||
explanation=f"Except expects to recieve an object of Exception type but received {exc_instance}.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if isinstance(expected_exc_types, TupleVariable):
|
||||
@ -1798,8 +1881,11 @@ class InstructionTranslatorBase(
|
||||
|
||||
for expected_type in expected_types:
|
||||
if not isinstance(expected_type, BuiltinVariable):
|
||||
unimplemented(
|
||||
f"except has an unsupported types of object {expected_type}"
|
||||
unimplemented_v2(
|
||||
gb_type="Exception with non-type expectation",
|
||||
context=str(expected_type),
|
||||
explanation=f"`except ...` expects a non-type: {expected_type}.",
|
||||
hints=[],
|
||||
)
|
||||
if isinstance(exc_instance, variables.ExceptionVariable) and issubclass(
|
||||
exc_instance.exc_type, expected_type.fn
|
||||
@ -1844,7 +1930,12 @@ class InstructionTranslatorBase(
|
||||
kwargsvars = self.pop()
|
||||
argsvars = self.pop()
|
||||
else:
|
||||
unimplemented("CALL_FUNCTION_EX")
|
||||
unimplemented_v2(
|
||||
gb_type="Variadic function call with bad flags",
|
||||
context=f"flags: {inst.argval}",
|
||||
explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
# 3.13 swapped null and callable
|
||||
@ -1879,7 +1970,12 @@ class InstructionTranslatorBase(
|
||||
# args, aot_autograd/inductor while lowering generates
|
||||
# aten.random.from, again causing syntax errors. Since this
|
||||
# usecase is uncommon, graph break.
|
||||
unimplemented("random_ op is called with from keyword")
|
||||
unimplemented_v2(
|
||||
gb_type="Tensor.random_ op called with `from` keyword",
|
||||
context="",
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
elif (
|
||||
fn.name == "uniform_"
|
||||
and isinstance(argsvars, TupleVariable)
|
||||
@ -1892,7 +1988,12 @@ class InstructionTranslatorBase(
|
||||
# args, aot_autograd/inductor while lowering generates
|
||||
# aten.uniform.from, again causing syntax errors. Since this
|
||||
# usecase is uncommon, graph break.
|
||||
unimplemented("uniform_ op is called with from keyword")
|
||||
unimplemented_v2(
|
||||
gb_type="Tensor.uniform_ op called with `from` keyword",
|
||||
context="",
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
argsvars, BaseListVariable
|
||||
@ -1906,7 +2007,12 @@ class InstructionTranslatorBase(
|
||||
if not isinstance(argsvars, BaseListVariable) or not isinstance(
|
||||
kwargsvars, ConstDictVariable
|
||||
):
|
||||
unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")
|
||||
unimplemented_v2(
|
||||
gb_type="Variadic function call with bad args/kwargs type",
|
||||
context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
|
||||
explanation="Expected args to be a list and kwargs to be a dict",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
# Map to a dictionary of str -> VariableTracker
|
||||
kwargsvars = kwargsvars.keys_as_python_constant()
|
||||
@ -2005,7 +2111,13 @@ class InstructionTranslatorBase(
|
||||
def store_attr_graph_break(self, inst):
|
||||
log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break")
|
||||
if not self.should_compile_partial_graph():
|
||||
unimplemented("should_compile_partial_graph=False")
|
||||
unimplemented_v2(
|
||||
gb_type="Should not compile partial graph (STORE_ATTR)",
|
||||
context="",
|
||||
explanation="Dynamo has determined when encountering an unsupported "
|
||||
"STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.",
|
||||
hints=[],
|
||||
)
|
||||
self.output.compile_subgraph(
|
||||
self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
|
||||
)
|
||||
@ -2054,7 +2166,12 @@ class InstructionTranslatorBase(
|
||||
|
||||
def BUILD_SET(self, inst):
|
||||
if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
|
||||
unimplemented("missing: BUILD_SET")
|
||||
unimplemented_v2(
|
||||
gb_type="missing BUILD_SET handler",
|
||||
context="",
|
||||
explanation="Missing BUILD_SET bytecode handler (for testing purposes).",
|
||||
hints=[],
|
||||
)
|
||||
items = self.popn(inst.argval)
|
||||
new_set = SetVariable(items, mutation_type=ValueMutationNew())
|
||||
self.push(new_set)
|
||||
@ -2066,7 +2183,13 @@ class InstructionTranslatorBase(
|
||||
try:
|
||||
items.extend(seq.force_unpack_var_sequence(self))
|
||||
except NotImplementedError:
|
||||
unimplemented(f"BUILD_LIST_UNPACK {seq}")
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to unpack object for BUILD_LIST_UNPACK",
|
||||
context=str(seq),
|
||||
explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK "
|
||||
"bytecode (`[*x, *y, ...]`).",
|
||||
hints=[],
|
||||
)
|
||||
self.push(cls(items, mutation_type=ValueMutationNew()))
|
||||
|
||||
def BUILD_TUPLE_UNPACK(self, inst):
|
||||
@ -2193,9 +2316,21 @@ class InstructionTranslatorBase(
|
||||
elif seq.has_force_unpack_var_sequence(self):
|
||||
val = seq.force_unpack_var_sequence(self)
|
||||
else:
|
||||
unimplemented(f"UNPACK_SEQUENCE {seq}")
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to unpack object for UNPACK_SEQUENCE",
|
||||
context=str(seq),
|
||||
explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode "
|
||||
"(i.e. `a, b, c = d`).",
|
||||
hints=[],
|
||||
)
|
||||
if len(val) != inst.argval:
|
||||
unimplemented("UNPACK_SEQUENCE length mismatch")
|
||||
unimplemented_v2(
|
||||
gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
|
||||
context=f"expected length: {inst.argval}, actual: {len(val)}",
|
||||
explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
|
||||
"(i.e. `a, b, c = d`) with unexpected length.",
|
||||
hints=[],
|
||||
)
|
||||
for i in reversed(val):
|
||||
self.push(i)
|
||||
|
||||
@ -2216,7 +2351,13 @@ class InstructionTranslatorBase(
|
||||
for item in reversed(vals_prefix):
|
||||
self.push(item)
|
||||
else:
|
||||
unimplemented(f"UNPACK_EX {seq}")
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to unpack object for UNPACK_EX",
|
||||
context=str(seq),
|
||||
explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode "
|
||||
"(i.e. `a, *b, c = d`).",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def NOP(self, inst):
|
||||
pass
|
||||
@ -2311,12 +2452,20 @@ class InstructionTranslatorBase(
|
||||
format_string_parts.append(part.format_string)
|
||||
args.extend(part.sym_args)
|
||||
if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
|
||||
unimplemented(
|
||||
f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}"
|
||||
unimplemented_v2(
|
||||
gb_type="BUILD_STRING key conflict",
|
||||
context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}",
|
||||
explanation="Failed to build format string due to key conflict",
|
||||
hints=[],
|
||||
)
|
||||
kwargs.update(part.sym_kwargs)
|
||||
else:
|
||||
unimplemented(f"BUILD_STRING {part}")
|
||||
unimplemented_v2(
|
||||
gb_type="BUILD_STRING type error",
|
||||
context=str(part),
|
||||
explanation="Format string part type is not correct - expected constant or format string.",
|
||||
hints=[],
|
||||
)
|
||||
self.push(
|
||||
variables.StringFormatVariable.create(
|
||||
"".join(format_string_parts), args, kwargs
|
||||
@ -2638,7 +2787,12 @@ class InstructionTranslatorBase(
|
||||
|
||||
def LOAD_FAST_CHECK(self, inst):
|
||||
if isinstance(self.symbolic_locals[inst.argval], NullVariable):
|
||||
unimplemented("LOAD_FAST_CHECK on uninitialized variable")
|
||||
unimplemented_v2(
|
||||
gb_type="LOAD_FAST_CHECK on uninitialized variable",
|
||||
context=inst.argval,
|
||||
explanation=f"Attempted to load uninitialized local variable {inst.argval}",
|
||||
hints=[],
|
||||
)
|
||||
self.LOAD_FAST(inst)
|
||||
|
||||
def LOAD_FAST_AND_CLEAR(self, inst):
|
||||
@ -2666,7 +2820,12 @@ class InstructionTranslatorBase(
|
||||
# INTRINSIC_LIST_TO_TUPLE
|
||||
self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
|
||||
else:
|
||||
unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
|
||||
unimplemented_v2(
|
||||
gb_type="Missing CALL_INTRINSIC_1 handler",
|
||||
context=f"CALL_INTRINSIC_1 operand: {inst.argval}",
|
||||
explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def END_SEND(self, inst):
|
||||
tos = self.pop()
|
||||
@ -3052,7 +3211,12 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
# if it reaches here, it means Dynamo failed to inline a functorch function
|
||||
f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
|
||||
)
|
||||
unimplemented(msg)
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported functorch tracing attempt",
|
||||
context="",
|
||||
explanation=msg,
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def get_example_value(self, source: Source):
|
||||
if isinstance(source, LocalSource):
|
||||
@ -3288,7 +3452,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
@staticmethod
|
||||
def check_inlineable(func):
|
||||
if func.has_self():
|
||||
unimplemented("inline with __self__")
|
||||
unimplemented_v2(
|
||||
gb_type="Inline attempt with __self__",
|
||||
context=str(func),
|
||||
explanation="Attempted to inline a function with the `__self__` attribute. "
|
||||
"Dynamo is expected to decompose method calls into function calls with a `self` argument.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
result = trace_rules.check_verbose(func, is_inlined_call=True)
|
||||
if result.skipped:
|
||||
@ -3346,7 +3516,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
kwargs,
|
||||
):
|
||||
if isinstance(func, SkipFunctionVariable):
|
||||
unimplemented("inline with functions in skip files")
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)",
|
||||
context=f"Attempted to inline a SkipFunctionVariable {func}",
|
||||
explanation="Attempted to inline a function that was previously determined to be marked as intentionally skipped.",
|
||||
hints=[],
|
||||
)
|
||||
assert isinstance(
|
||||
func,
|
||||
(
|
||||
@ -3373,13 +3548,23 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
|
||||
for v in itertools.chain(sub_locals.values()):
|
||||
if not isinstance(v, VariableTracker):
|
||||
unimplemented(f"unconverted arg {v}")
|
||||
unimplemented_v2(
|
||||
gb_type="Encountered unconverted argument when attempting to inline",
|
||||
context=f"func: {func}, arg: {v}",
|
||||
explanation="An argument to an inlined function was not successfully converted to a VariableTracker.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
code: types.CodeType = func.get_code()
|
||||
if code.co_name in ("__setitem__", "__setattr__") and not (
|
||||
args and isinstance(args[0], variables.UserDefinedObjectVariable)
|
||||
):
|
||||
unimplemented(f"inline {code.co_name}")
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported __setitem__/__setattr__ inline attempt",
|
||||
context=f"code name: {code.co_name}, args: {args}",
|
||||
explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
suffix = ""
|
||||
# TODO: mlazos, add support for enabling multiple artifact logs
|
||||
@ -3541,7 +3726,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
return False # inlining functions is all-or-nothing
|
||||
|
||||
def create_call_resume_at(self, offset):
|
||||
unimplemented("cant resume while inlining")
|
||||
unimplemented_v2(
|
||||
gb_type="Graph break in inlined function",
|
||||
context="",
|
||||
explanation="Graph breaks in an inlined call are not supported.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def RETURN_VALUE(self, inst):
|
||||
self.symbolic_result = self.pop() # type: ignore[assignment]
|
||||
@ -3600,7 +3790,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
else:
|
||||
value = self.pop()
|
||||
if isinstance(value, RemovableHandleVariable):
|
||||
unimplemented("Storing handles in globals - NYI")
|
||||
unimplemented_v2(
|
||||
gb_type="Storing Tensor hook handle in globals (inline call)",
|
||||
context=inst.argval,
|
||||
explanation="This is not supported.",
|
||||
hints=[],
|
||||
)
|
||||
name = inst.argval
|
||||
_fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
|
||||
self.output.side_effects.store_attr(fglobals_vt, name, value)
|
||||
@ -3658,7 +3853,12 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
||||
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
|
||||
# subgenerator and lines up with this line in Python 3.10
|
||||
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
|
||||
unimplemented("Unreachable sub-generator code")
|
||||
unimplemented_v2(
|
||||
gb_type="Unreachable sub-generator code",
|
||||
context="",
|
||||
explanation="Should only be encountered while implementing generator support.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
try:
|
||||
val = tos.next_variable(self)
|
||||
@ -3711,6 +3911,16 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
||||
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
|
||||
# subgenerator and lines up with this line in Python 3.11
|
||||
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
|
||||
unimplemented("Unreachable sub-generator code")
|
||||
unimplemented_v2(
|
||||
gb_type="Unreachable sub-generator code",
|
||||
context="",
|
||||
explanation="Should only be encountered while implementing generator support.",
|
||||
hints=[],
|
||||
)
|
||||
else:
|
||||
unimplemented(f"SEND {typestr(tos)}")
|
||||
unimplemented_v2(
|
||||
gb_type="SEND with bad type",
|
||||
context=f"TOS type: {typestr(tos)}",
|
||||
explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.",
|
||||
hints=[],
|
||||
)
|
||||
|
@ -1109,11 +1109,14 @@ def proxy_args_kwargs(args, kwargs):
|
||||
proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
|
||||
return proxy_args, proxy_kwargs
|
||||
except NotImplementedError as e:
|
||||
from .exc import unimplemented
|
||||
from .exc import unimplemented_v2
|
||||
from .variables.base import typestr
|
||||
|
||||
unimplemented(
|
||||
f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to convert args/kwargs to proxy",
|
||||
context=f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
|
||||
explanation="Missing `as_proxy()` implementation for some arg/kwarg.",
|
||||
hints=[],
|
||||
from_exc=e,
|
||||
)
|
||||
|
||||
@ -2464,9 +2467,14 @@ def set_example_value(node, example_value):
|
||||
def _get_fake_tensor(vt):
|
||||
fake_tensor = vt.as_proxy().node.meta.get("example_value")
|
||||
if not is_fake(fake_tensor):
|
||||
from .exc import unimplemented
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented("Cannot check Tensor object identity without its fake value")
|
||||
unimplemented_v2(
|
||||
gb_type="Cannot check Tensor object identity without its fake value",
|
||||
context=str(fake_tensor),
|
||||
explanation="TensorVariable is missing a fake example_value.",
|
||||
hints=[],
|
||||
)
|
||||
return fake_tensor
|
||||
|
||||
|
||||
@ -2578,11 +2586,17 @@ def wrap_fake_exception(fn):
|
||||
try:
|
||||
return fn()
|
||||
except UnsupportedFakeTensorException as e:
|
||||
from .exc import unimplemented
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
msg = f"Unsupported: {e.reason} with fake tensor propagation."
|
||||
msg = f"Encountered exception ({e.reason}) during fake tensor propagation."
|
||||
log.warning(msg)
|
||||
unimplemented(msg, from_exc=e)
|
||||
unimplemented_v2(
|
||||
gb_type="Fake tensor propagation exception",
|
||||
context=str(e.reason),
|
||||
explanation=msg,
|
||||
hints=[],
|
||||
from_exc=e,
|
||||
)
|
||||
|
||||
|
||||
def deepcopy_to_fake_tensor(obj, fake_mode):
|
||||
@ -2956,9 +2970,14 @@ def extract_fake_example_value(node, required=True):
|
||||
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
|
||||
return node.meta["example_value"]
|
||||
elif required:
|
||||
from torch._dynamo.exc import unimplemented
|
||||
from torch._dynamo.exc import unimplemented_v2
|
||||
|
||||
unimplemented("`FakeTensor` example value was required but not available")
|
||||
unimplemented_v2(
|
||||
gb_type="Missing FakeTensor example value",
|
||||
context=str(node),
|
||||
explanation=f"`FakeTensor` example value was required for {node} but not available.",
|
||||
hints=[],
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -3002,7 +3021,6 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
|
||||
from .exc import (
|
||||
TorchRuntimeError,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
UserError,
|
||||
@ -3124,10 +3142,15 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
f"module `{module}` and you may need to `import {module}`"
|
||||
f"({ctx}), otherwise "
|
||||
)
|
||||
unimplemented(
|
||||
f"unsupported operator: {cause.func} ({import_suggestion}see "
|
||||
"https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
|
||||
" for how to fix)"
|
||||
unimplemented_v2(
|
||||
gb_type="Operator does not support running with fake tensors",
|
||||
context=f"unsupported operator: {cause.func}",
|
||||
explanation="",
|
||||
hints=[
|
||||
f"{import_suggestion}see "
|
||||
"https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
|
||||
" for how to fix",
|
||||
],
|
||||
)
|
||||
elif isinstance(
|
||||
cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
|
||||
@ -3140,7 +3163,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
elif isinstance(cause, ValueRangeError):
|
||||
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
|
||||
elif isinstance(cause, TypeError) and "argument" in str(cause):
|
||||
unimplemented(f"TypeError {node.target}: {cause}")
|
||||
unimplemented_v2(
|
||||
gb_type="TypeError when making fake tensor call",
|
||||
context=f"TypeError {node.target}: {cause}",
|
||||
explanation="",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
|
||||
|
||||
@ -3197,9 +3225,14 @@ def run_node(tracer, node, args, kwargs, nnmodule):
|
||||
return node.target(*args, **kwargs)
|
||||
elif op == "call_method":
|
||||
if not hasattr(args[0], node.target):
|
||||
from .exc import unimplemented
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented(make_error_message("attribute not defined"))
|
||||
unimplemented_v2(
|
||||
gb_type="Missing attribute when running call_method node",
|
||||
context="",
|
||||
explanation=make_error_message("attribute not defined"),
|
||||
hints=[],
|
||||
)
|
||||
return getattr(args[0], node.target)(*args[1:], **kwargs)
|
||||
elif op == "call_module":
|
||||
assert nnmodule is not None
|
||||
@ -3212,9 +3245,15 @@ def run_node(tracer, node, args, kwargs, nnmodule):
|
||||
|
||||
except (NotImplementedError, UnsupportedFakeTensorException) as e:
|
||||
# NB: mimic how wrap_fake_exception does it
|
||||
from .exc import unimplemented
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented(make_error_message(e), from_exc=e)
|
||||
unimplemented_v2(
|
||||
gb_type="NotImplementedError/UnsupportedFakeTensorException when running node",
|
||||
context=str(e),
|
||||
explanation=make_error_message(e),
|
||||
hints=[],
|
||||
from_exc=e,
|
||||
)
|
||||
except Unsupported:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user