mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[dynamo] unimplemented -> unimplemented_v2 for the rest of variables/misc.py (#167001)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167001 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
bf8297afe0
commit
91b626e2ef
@ -67,7 +67,7 @@ class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
|
|||||||
self.assertEqual(len(counters["graph_break"]), 0)
|
self.assertEqual(len(counters["graph_break"]), 0)
|
||||||
else:
|
else:
|
||||||
self.assertIn("moo", printed_output)
|
self.assertIn("moo", printed_output)
|
||||||
self.assertEqual(len(counters["graph_break"]), 1)
|
self.assertGreater(len(counters["graph_break"]), 0)
|
||||||
|
|
||||||
|
|
||||||
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
|
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
|
||||||
|
|||||||
@ -2950,5 +2950,127 @@
|
|||||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
],
|
||||||
|
"GB0289": [
|
||||||
|
{
|
||||||
|
"Gb_type": "unsupported method call on `typing` variable",
|
||||||
|
"Context": "typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}",
|
||||||
|
"Explanation": "`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.",
|
||||||
|
"Hints": [
|
||||||
|
"Avoid calling the {name} method on {self.value}.",
|
||||||
|
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0290": [
|
||||||
|
{
|
||||||
|
"Gb_type": "attempted to trace numpy.* function as a method",
|
||||||
|
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
|
||||||
|
"Explanation": "Tracing numpy.* functions as methods is not supported.",
|
||||||
|
"Hints": [
|
||||||
|
"This graph break may be difficult to debug. Please report an issue to PyTorch for assistance."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0291": [
|
||||||
|
{
|
||||||
|
"Gb_type": "logging.Logger method not supported for non-export cases",
|
||||||
|
"Context": "method: {self.value}.{name}, args: {args}, kwargs: {kwargs}",
|
||||||
|
"Explanation": "logging.Logger methods are not supported for non-export cases.",
|
||||||
|
"Hints": [
|
||||||
|
"Add the logging method to `torch._dynamo.config.ignore_logger_methods."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0292": [
|
||||||
|
{
|
||||||
|
"Gb_type": "constant-like method call with unsupported return type",
|
||||||
|
"Context": "{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}",
|
||||||
|
"Explanation": "Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.",
|
||||||
|
"Hints": [
|
||||||
|
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0293": [
|
||||||
|
{
|
||||||
|
"Gb_type": "attempted to trace numpy function with config.trace_numpy=False",
|
||||||
|
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
|
||||||
|
"Explanation": "Attempted to trace numpy function {self.value} while `torch._dynamo.config.trace_numpy` was set to False.",
|
||||||
|
"Hints": [
|
||||||
|
"Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0294": [
|
||||||
|
{
|
||||||
|
"Gb_type": "attempted to trace numpy function unsupported by PyTorch",
|
||||||
|
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
|
||||||
|
"Explanation": "Can't find numpy numpy function {self.value} in torch._numpy.",
|
||||||
|
"Hints": [
|
||||||
|
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0295": [
|
||||||
|
{
|
||||||
|
"Gb_type": "cannot reconstruct NullVariable in Python < 3.11",
|
||||||
|
"Context": "",
|
||||||
|
"Explanation": "Attempted to generate PUSH_NULL instruction in Python < 3.11; where this instruction does not exist.",
|
||||||
|
"Hints": [
|
||||||
|
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0296": [
|
||||||
|
{
|
||||||
|
"Gb_type": "attempted to reorder a debugging function that can't actually be reordered",
|
||||||
|
"Context": "fn: {self.value}, args: {args}, kwargs: {kwargs}",
|
||||||
|
"Explanation": "`torch.compile` can only reorder functions where the arguments are Tensors, constants, or string formatters.",
|
||||||
|
"Hints": [
|
||||||
|
"Avoid calling the logging function {self.value} with args that are not supported."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0297": [
|
||||||
|
{
|
||||||
|
"Gb_type": "random.Random() with improper arguments",
|
||||||
|
"Context": "args: {args}, kwargs: {kwargs}",
|
||||||
|
"Explanation": "random.Random() with > 1 arg or with kwargs is not supported.",
|
||||||
|
"Hints": [
|
||||||
|
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0298": [
|
||||||
|
{
|
||||||
|
"Gb_type": "attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True",
|
||||||
|
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
|
||||||
|
"Explanation": "Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` is set to True.",
|
||||||
|
"Hints": [
|
||||||
|
"Set `torch._dynamo.config.use_numpy_random_stream` to False.",
|
||||||
|
"Avoid calling {self.value}."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0299": [
|
||||||
|
{
|
||||||
|
"Gb_type": "constant-like method call with non-constant args",
|
||||||
|
"Context": "{self._error_prefix}.{name}(*{args}, **{kwargs})",
|
||||||
|
"Explanation": "Attempted to call {self._error_prefix}.{name} with non-constant args.",
|
||||||
|
"Hints": [
|
||||||
|
"Ensure that the args to the method call are constant (int, str, etc.)."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"GB0300": [
|
||||||
|
{
|
||||||
|
"Gb_type": "numpy function that produces a const collection type encountered non-const arguments",
|
||||||
|
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
|
||||||
|
"Explanation": "numpy function {self.value} that produces a const collection type (e.g. np.dtype, np.iinfo/np.finfo) received arguments that are not constant.",
|
||||||
|
"Hints": [
|
||||||
|
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
|
||||||
|
]
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -39,7 +39,7 @@ from ..bytecode_transformation import (
|
|||||||
create_instruction,
|
create_instruction,
|
||||||
)
|
)
|
||||||
from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
||||||
from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
|
from ..exc import raise_observed_exception, unimplemented_v2
|
||||||
from ..guards import GuardBuilder, install_guard
|
from ..guards import GuardBuilder, install_guard
|
||||||
from ..mutation_guard import unpatched_nn_module_init
|
from ..mutation_guard import unpatched_nn_module_init
|
||||||
from ..source import (
|
from ..source import (
|
||||||
@ -1382,7 +1382,15 @@ class TypingVariable(VariableTracker):
|
|||||||
if name == "__getitem__" and len(args) == 1:
|
if name == "__getitem__" and len(args) == 1:
|
||||||
new_typing = self.value[args[0].as_python_constant()]
|
new_typing = self.value[args[0].as_python_constant()]
|
||||||
return TypingVariable(new_typing)
|
return TypingVariable(new_typing)
|
||||||
unimplemented("unsupported method call on typing variable")
|
unimplemented_v2(
|
||||||
|
gb_type="unsupported method call on `typing` variable",
|
||||||
|
context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}",
|
||||||
|
explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.",
|
||||||
|
hints=[
|
||||||
|
f"Avoid calling the {name} method on {self.value}.",
|
||||||
|
*graph_break_hints.SUPPORTABLE,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
from .builder import SourcelessBuilder, VariableBuilder
|
||||||
@ -1493,16 +1501,28 @@ class NumpyVariable(VariableTracker):
|
|||||||
kwargs: "dict[str, VariableTracker]",
|
kwargs: "dict[str, VariableTracker]",
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
if not config.trace_numpy:
|
if not config.trace_numpy:
|
||||||
unimplemented(f"numpy.{self.value}()")
|
unimplemented_v2(
|
||||||
|
gb_type="attempted to trace numpy function with config.trace_numpy=False",
|
||||||
|
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
|
||||||
|
explanation=f"Attempted to trace numpy function {self.value} "
|
||||||
|
"while `torch._dynamo.config.trace_numpy` was set to False.",
|
||||||
|
hints=[
|
||||||
|
"Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions.",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
from ..utils import numpy_to_tensor_wrapper
|
from ..utils import numpy_to_tensor_wrapper
|
||||||
from .tensor import NumpyNdarrayVariable
|
from .tensor import NumpyNdarrayVariable
|
||||||
|
|
||||||
func = get_np_to_tnp_map().get(self.value)
|
func = get_np_to_tnp_map().get(self.value)
|
||||||
if func is None:
|
if func is None:
|
||||||
unimplemented(
|
unimplemented_v2(
|
||||||
f"Can't find numpy function {self.value} in torch._numpy. "
|
gb_type="attempted to trace numpy function unsupported by PyTorch",
|
||||||
" Please file an issue to request support for this function."
|
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
|
||||||
|
explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.",
|
||||||
|
hints=[
|
||||||
|
*graph_break_hints.SUPPORTABLE,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
|
# We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
|
||||||
@ -1516,20 +1536,32 @@ class NumpyVariable(VariableTracker):
|
|||||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except NotImplementedError:
|
except AsPythonConstantNotImplementedError:
|
||||||
unimplemented(
|
unimplemented_v2(
|
||||||
f"{self.value.__name__} with non-const args: {args} {kwargs}"
|
gb_type="numpy function that produces a const collection type encountered non-const arguments",
|
||||||
|
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
|
||||||
|
explanation=f"numpy function {self.value} that produces a const collection type "
|
||||||
|
"(e.g. np.dtype, np.iinfo/np.finfo) "
|
||||||
|
"received arguments that are not constant.",
|
||||||
|
hints=[
|
||||||
|
*graph_break_hints.USER_ERROR,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
func.__module__ == "torch._numpy.random"
|
func.__module__ == "torch._numpy.random"
|
||||||
and config.use_numpy_random_stream
|
and config.use_numpy_random_stream
|
||||||
):
|
):
|
||||||
msg = f"delegate '{func.__qualname__}' to NumPy itself via "
|
unimplemented_v2(
|
||||||
msg += (
|
gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True",
|
||||||
f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
|
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
|
||||||
|
explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` "
|
||||||
|
"is set to True.",
|
||||||
|
hints=[
|
||||||
|
"Set `torch._dynamo.config.use_numpy_random_stream` to False.",
|
||||||
|
f"Avoid calling {self.value}.",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
unimplemented(msg)
|
|
||||||
|
|
||||||
args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)
|
args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)
|
||||||
|
|
||||||
@ -1559,7 +1591,14 @@ class NumpyVariable(VariableTracker):
|
|||||||
args: "list[VariableTracker]",
|
args: "list[VariableTracker]",
|
||||||
kwargs: "dict[str, VariableTracker]",
|
kwargs: "dict[str, VariableTracker]",
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
unimplemented("numpy")
|
unimplemented_v2(
|
||||||
|
gb_type="attempted to trace numpy.* function as a method",
|
||||||
|
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
|
||||||
|
explanation="Tracing numpy.* functions as methods is not supported.",
|
||||||
|
hints=[
|
||||||
|
*graph_break_hints.DIFFICULT,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def as_python_constant(self):
|
def as_python_constant(self):
|
||||||
return self.value
|
return self.value
|
||||||
@ -1584,7 +1623,15 @@ class NullVariable(VariableTracker):
|
|||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen"):
|
||||||
if sys.version_info < (3, 11):
|
if sys.version_info < (3, 11):
|
||||||
unimplemented("cannot reconstruct NullVariable in < Python 3.11")
|
unimplemented_v2(
|
||||||
|
gb_type="cannot reconstruct NullVariable in Python < 3.11",
|
||||||
|
context="",
|
||||||
|
explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; "
|
||||||
|
"where this instruction does not exist.",
|
||||||
|
hints=[
|
||||||
|
*graph_break_hints.DYNAMO_BUG,
|
||||||
|
],
|
||||||
|
)
|
||||||
codegen.append_output(create_instruction("PUSH_NULL"))
|
codegen.append_output(create_instruction("PUSH_NULL"))
|
||||||
|
|
||||||
|
|
||||||
@ -1665,9 +1712,14 @@ class DebuggingVariable(VariableTracker):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if not self.can_reorder_logs(self.value, args, kwargs):
|
if not self.can_reorder_logs(self.value, args, kwargs):
|
||||||
unimplemented(
|
unimplemented_v2(
|
||||||
f"Reordering debugging function {self.value} "
|
gb_type="attempted to reorder a debugging function that can't actually be reordered",
|
||||||
f"with inputs {args} {kwargs} is not yet implemented."
|
context=f"fn: {self.value}, args: {args}, kwargs: {kwargs}",
|
||||||
|
explanation="`torch.compile` can only reorder functions where the arguments "
|
||||||
|
"are Tensors, constants, or string formatters.",
|
||||||
|
hints=[
|
||||||
|
f"Avoid calling the logging function {self.value} with args that are not supported.",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tx.debug_locals.append((self, list(args)))
|
tx.debug_locals.append((self, list(args)))
|
||||||
@ -1719,10 +1771,13 @@ class LoggingLoggerVariable(VariableTracker):
|
|||||||
function = getattr(method, "__func__", None)
|
function = getattr(method, "__func__", None)
|
||||||
if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods):
|
if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods):
|
||||||
return variables.ConstantVariable.create(None)
|
return variables.ConstantVariable.create(None)
|
||||||
unimplemented(
|
unimplemented_v2(
|
||||||
"Logger not supported for non-export cases. "
|
gb_type="logging.Logger method not supported for non-export cases",
|
||||||
"To avoid graph breaks caused by logger in compile-mode, it is recommended to"
|
context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}",
|
||||||
" disable logging by adding logging methods to config.ignore_logger_methods"
|
explanation="logging.Logger methods are not supported for non-export cases.",
|
||||||
|
hints=[
|
||||||
|
"Add the logging method to `torch._dynamo.config.ignore_logger_methods.",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1759,7 +1814,14 @@ class ConstantLikeVariable(VariableTracker):
|
|||||||
cargs = [x.as_python_constant() for x in args]
|
cargs = [x.as_python_constant() for x in args]
|
||||||
ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")
|
unimplemented_v2(
|
||||||
|
gb_type="constant-like method call with non-constant args",
|
||||||
|
context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})",
|
||||||
|
explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.",
|
||||||
|
hints=[
|
||||||
|
"Ensure that the args to the method call are constant (int, str, etc.).",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
result = getattr(self.value, name)(*cargs, **ckwargs)
|
result = getattr(self.value, name)(*cargs, **ckwargs)
|
||||||
|
|
||||||
@ -1768,7 +1830,14 @@ class ConstantLikeVariable(VariableTracker):
|
|||||||
if isinstance(result, re.Match):
|
if isinstance(result, re.Match):
|
||||||
return ConstantRegexMatchVariable(result)
|
return ConstantRegexMatchVariable(result)
|
||||||
|
|
||||||
unimplemented(f"{self._error_prefix}.{name}() -> {result}")
|
unimplemented_v2(
|
||||||
|
gb_type="constant-like method call with unsupported return type",
|
||||||
|
context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}",
|
||||||
|
explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.",
|
||||||
|
hints=[
|
||||||
|
*graph_break_hints.SUPPORTABLE,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||||
result = getattr(self.value, name)
|
result = getattr(self.value, name)
|
||||||
@ -1831,10 +1900,15 @@ class RandomClassVariable(VariableTracker):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def call_function(self, tx: "InstructionTranslator", args, kwargs):
|
def call_function(self, tx: "InstructionTranslator", args, kwargs):
|
||||||
if len(args) > 1:
|
if len(args) > 1 or kwargs:
|
||||||
unimplemented("random.Random() with > 1 arg")
|
unimplemented_v2(
|
||||||
elif kwargs:
|
gb_type="random.Random() with improper arguments",
|
||||||
unimplemented("random.Random() with kwargs")
|
context=f"args: {args}, kwargs: {kwargs}",
|
||||||
|
explanation="random.Random() with > 1 arg or with kwargs is not supported.",
|
||||||
|
hints=[
|
||||||
|
*graph_break_hints.USER_ERROR,
|
||||||
|
],
|
||||||
|
)
|
||||||
seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0]
|
seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0]
|
||||||
return RandomVariable(
|
return RandomVariable(
|
||||||
seed=seed, mutation_type=variables.base.ValueMutationNew()
|
seed=seed, mutation_type=variables.base.ValueMutationNew()
|
||||||
|
|||||||
Reference in New Issue
Block a user