[dynamo] unimplemented -> unimplemented_v2 in variables/builder.py (#151044)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151044
Approved by: https://github.com/anijain2305, https://github.com/zou3519
William Wen
2025-04-10 14:20:29 -07:00
committed by PyTorch MergeBot
parent d6f1c72354
commit 183bca41de
3 changed files with 102 additions and 30 deletions
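This commit migrates the graph-break sites in variables/builder.py from the legacy unimplemented() call, which takes a single free-form message string, to the structured unimplemented_v2() API. A rough sketch of the before/after pattern, with field names taken from the hunks below (exact signature details such as defaults are not shown in this diff):

# Before: one message string per graph break.
unimplemented("torch.compile does not support sparse Tensors")

# After: structured fields describing the graph break.
unimplemented_v2(
    gb_type="Attempted to wrap sparse Tensor",      # short, stable label for this break
    context="",                                     # dynamic details (values, dtypes, sources), if any
    explanation="torch.compile does not support sparse Tensors",
    hints=[*graph_break_hints.SUPPORTABLE],         # shared hint lists from torch/_dynamo/graph_break_hints.py
)
# One call site below also forwards the caught exception via from_exc=e
# when a numpy.ndarray fails to convert to a Tensor.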


@@ -5899,7 +5899,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
with _patch_config({"allow_rnn": False}):
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"TorchDynamo purposely graph breaks on RNN, GRU, LSTMs",
"Dynamo does not support RNN, GRU, or LSTM.",
):
_ = export(mod, inp, strict=True)


@@ -21,6 +21,6 @@ CAUSED_BY_EARLIER_GRAPH_BREAK = [
]
INFERENCE_MODE = [
"Avoid using `tensor.is_inference()` and `torch.is_inference_mode_enabled()` in your compile code. "
"This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead ",
" because `torch.no_grad` leads to same improvements as `inference_mode` when `torch.compile` is used.",
"This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead "
"because `torch.no_grad` leads to same improvements as `inference_mode` when `torch.compile` is used.",
]
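The hunk above is a small correctness fix rather than a wording change: with the trailing comma, Python treated the two string literals as separate list elements, so INFERENCE_MODE held two half-sentences (one with a stray leading space) instead of a single hint. Dropping the comma lets implicit string concatenation join them. A minimal illustration with placeholder strings:

# With a comma, adjacent literals stay separate list items.
two_items = ["Consider using X instead ", " because Y."]   # len(two_items) == 2

# Without the comma, adjacent literals concatenate into one item.
one_item = ["Consider using X instead " "because Y."]      # len(one_item) == 1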


@@ -78,7 +78,7 @@ from torch.utils.weak import TensorWeakRef
from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented, unimplemented_v2
from ..exc import InternalTorchDynamoError, unimplemented_v2
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..pgo import (
auto_dynamic,
@@ -516,7 +516,13 @@ class VariableBuilder:
# Our current infra requires the hook to be registered and removed in
# the same frame. So graph break.
# Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks
unimplemented("unregistered hook removable handle")
unimplemented_v2(
gb_type="Attempted to represent unregistered RemovableHandle",
context="",
explanation="Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, "
"which is not supported. This happens because the RemovableHandle was created in another frame.",
hints=[],
)
def wrap_jit_function(self, value):
self.install_guards(GuardBuilder.TYPE_MATCH)
@@ -532,7 +538,14 @@
all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
if not all_const:
unimplemented("mapping proxy type supports only const keys")
unimplemented_v2(
gb_type="non-const keys in mappingproxy",
context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}",
explanation="Dynamo expects mappingproxy keys to be constants.",
hints=[
"Ensure your mappingproxy keys are constants (e.g. int, float, strings)",
],
)
def build_key_value(k, v):
key = ConstantVariable.create(k)
@@ -777,7 +790,12 @@
keywords_source = AttrSource(self.get_source(), "keywords")
for k, v in value.keywords.items():
if not ConstantVariable.is_literal(k):
unimplemented("functools.partial with non-literal keyword")
unimplemented_v2(
gb_type="functools.partial() with non-literal keyword",
context=f"non-literal keyword: {k}",
explanation="functools.partial() expects literal/string keywords",
hints=[*graph_break_hints.USER_ERROR],
)
keywords[k] = VariableBuilder(
self.tx, DictGetItemSource(keywords_source, k)
)(v)
@@ -904,8 +922,11 @@
return self.wrap_unspecialized_primitive(value)
elif isinstance(value, HigherOrderOperator):
if value is torch._higher_order_ops.invoke_subgraph:
unimplemented(
"Directly using invoke_subgraph is not supported. Use mark_compile_region"
unimplemented_v2(
gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph",
context="",
explanation="Directly using invoke_subgraph is not supported. Use mark_compile_region",
hints=[],
)
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
return TorchHigherOrderOperatorVariable.make(value, source=self.source)
@@ -1030,8 +1051,11 @@
# this is automatically done by evaluating the guards once but this
# will cause data-dependent error when we evaluate the outer unbacked symints.
# The test case that triggers this graph break is test_cond_unbacked_symint_closure
unimplemented(
"unbacked symint input is not supported yet. If you need this feature, please file a github issue."
unimplemented_v2(
gb_type="Attempted to wrap unbacked SymInt",
context="",
explanation="Unbacked SymInt input is not supported yet.",
hints=[*graph_break_hints.SUPPORTABLE],
)
sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
@@ -1404,7 +1428,14 @@
)
return DictKeySetVariable(items, source=self.source)
else:
unimplemented("dict_keys with non-constant keys are not supported")
unimplemented_v2(
gb_type="non-const keys in dict_keys",
context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}",
explanation="Dynamo expects dict_keys keys to be constants.",
hints=[
"Ensure your dict_keys keys are constants (e.g. int, float, strings)",
],
)
else:
return self.wrap_user_defined(value)
@@ -1589,7 +1620,12 @@
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
and not config.allow_rnn
):
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
unimplemented_v2(
gb_type="Attempted to wrap RNN, GRU, or LSTM",
context=str(value),
explanation="Dynamo does not support RNN, GRU, or LSTM.",
hints=[*graph_break_hints.SUPPORTABLE],
)
if getattr(value, "_is_fsdp_managed_module", False):
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
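Note that the check in the hunk above only fires when config.allow_rnn is false; the test change at the top of this commit patches that same flag. As a sketch (assuming the public torch._dynamo.config mirror of the `config` module imported in builder.py), a user who wants to experiment with tracing recurrent modules instead of graph breaking could opt in explicitly:

import torch

# Hypothetical opt-in: skip the RNN/GRU/LSTM graph break and let Dynamo attempt to trace the module.
torch._dynamo.config.allow_rnn = True
compiled_lstm = torch.compile(torch.nn.LSTM(input_size=8, hidden_size=8))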
@@ -1598,7 +1634,12 @@
# we can't do this assert inside FSDP constructor,
# since we don't know yet whether dynamo will be used
if not getattr(value, "_fsdp_use_orig_params", False):
unimplemented("Dynamo only supports FSDP with use_orig_params=True")
unimplemented_v2(
gb_type="FSDP with use_orig_params=False",
context="",
explanation="Dynamo only supports FSDP with use_orig_params=True",
hints=[],
)
# Note on FSDP guarding
# Eager FSDP already assumes (requires, but without enforcement)
@@ -1814,7 +1855,12 @@
and value.is_nested
and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
):
unimplemented("torch.compile does not support strided NestedTensor")
unimplemented_v2(
gb_type="Attempted to wrap strided NestedTensor",
context="",
explanation="torch.compile does not support strided NestedTensor",
hints=[],
)
# TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards
if (
@@ -1825,18 +1871,24 @@
# A hot fix for sparse tensors + torch.compile. Support for
# export + sparsity is being added but we need to create
# SPARSE_TENSOR_GUARDS for guards to work properly.
unimplemented("torch.compile does not support sparse Tensors")
unimplemented_v2(
gb_type="Attempted to wrap sparse Tensor",
context="",
explanation="torch.compile does not support sparse Tensors",
hints=[*graph_break_hints.SUPPORTABLE],
)
if (
safe_has_grad(value)
and safe_grad(value) is not None
and value.dtype != safe_grad(value).dtype
):
unimplemented(
"Inconsistent dtype between tensor and its gradient. "
"This can happen in FSDP and crashes meta tensor creation. "
"This is potentially a workaround. Fixing it correctly "
"requires some design around FSDP + torch.compile."
unimplemented_v2(
gb_type="dtype mismatch between tensor and its gradient",
context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}",
explanation="Inconsistent dtype between tensor and its gradient. "
"This can happen in FSDP and crashes meta tensor creation.",
hints=[*graph_break_hints.SUPPORTABLE],
)
# tx.output has multiple tracers if we're introspecting HigherOrderOperator.
@@ -1952,7 +2004,13 @@
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
unimplemented_v2(
gb_type="failed to convert numpy.ndarray to Tensor",
context=str(value),
explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor",
hints=[],
from_exc=e,
)
# We do this because we want the full behavior of guarding the numpy ndarray as if it were
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
@@ -2318,7 +2376,12 @@ def _dataclasses_fields_lambda(obj):
if isinstance(obj, UserDefinedObjectVariable):
value = obj.value
else:
unimplemented(f"Dataclass fields handling fails for type {obj}")
unimplemented_v2(
gb_type="dataclass fields failure",
context=f"obj: {obj}; variable type: {type(obj)}",
explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.",
hints=[],
)
items = []
for field in dataclasses.fields(value):
source = None
@@ -2715,10 +2778,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
else:
unimplemented(
"torch.* op returned non-Tensor "
+ f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}",
case_name="unsupported_operator",
unimplemented_v2(
gb_type="torch.* op returned non-Tensor",
context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}",
explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output",
hints=[],
)
@@ -2813,7 +2877,12 @@ def _automatic_dynamic(
if e.is_nested and not isinstance(
e, torch.nested._internal.nested_tensor.NestedTensor
):
unimplemented("torch.compile does not support strided NestedTensor")
unimplemented_v2(
gb_type="Encountered strided NestedTensor in automatic dynamic dim determination",
context="",
explanation="torch.compile does not support strided NestedTensor",
hints=[],
)
name = source.name()
dynamic_sources = get_dynamic_sources()
@@ -3259,8 +3328,11 @@ class SourcelessBuilder:
)
elif isinstance(value, types.GenericAlias):
return TypingVariable(value)
unimplemented(
f"Unexpected type in sourceless builder {value_type.__module__}.{value_type.__qualname__}"
unimplemented_v2(
gb_type="Unexpected type in sourceless builder",
context=f"{value_type.__module__}.{value_type.__qualname__}",
explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}",
hints=[*graph_break_hints.DYNAMO_BUG],
)
@staticmethod