mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Introduce UserDefinedExceptionClassVariable
(#146504)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146504 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
8d08b49015
commit
4e7d264cf8
@ -2696,7 +2696,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
|
||||
# https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_contextmanager_plain(self):
|
||||
state = []
|
||||
|
||||
@ -2977,7 +2976,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
||||
with woohoo():
|
||||
raise StopIteration
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_keywords(self):
|
||||
# Ensure no keyword arguments are inhibited
|
||||
@contextmanager
|
||||
@ -2991,7 +2989,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_recursive(self):
|
||||
depth = 0
|
||||
ncols = 0
|
||||
|
@ -13,6 +13,16 @@ from torch._dynamo.bytecode_transformation import Instruction
|
||||
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
|
||||
|
||||
|
||||
class CustomException(Exception):
|
||||
...
|
||||
|
||||
|
||||
class CustomExceptionWithArgs(Exception):
|
||||
def __init__(self, a, b=None):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
|
||||
class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
def test_exception(self):
|
||||
def fn(x):
|
||||
@ -215,6 +225,23 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
got = opt_fn(x)
|
||||
self.assertEqual(expected, got)
|
||||
|
||||
def test_raise_custom_exception(self):
|
||||
class Exc(Exception):
|
||||
...
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
try:
|
||||
raise Exc
|
||||
except Exc:
|
||||
return t.sin()
|
||||
except Exception:
|
||||
return t.cos()
|
||||
|
||||
t = torch.randn(2)
|
||||
y = fn(t)
|
||||
self.assertEqual(y, t.sin())
|
||||
|
||||
def test_nn_module_getattr(self):
|
||||
class A:
|
||||
def __init__(self) -> None:
|
||||
@ -468,6 +495,37 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_user_defined_exception_variable(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
z = 0
|
||||
try:
|
||||
raise CustomException
|
||||
except ValueError:
|
||||
z = 1
|
||||
except CustomException:
|
||||
z = 2
|
||||
assert z == 2
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_user_defined_exception_with_args(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
z = 0
|
||||
try:
|
||||
raise CustomExceptionWithArgs(2, b=3)
|
||||
except ValueError:
|
||||
z = 1
|
||||
except CustomExceptionWithArgs:
|
||||
z = 2
|
||||
assert z == 2
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
0
test/dynamo_expected_failures/TestNN.test_unflatten
Normal file
0
test/dynamo_expected_failures/TestNN.test_unflatten
Normal file
@ -7963,7 +7963,6 @@ torch.cuda.synchronize()
|
||||
self.assertEqual(output._metadata_cache, cache)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/128649
|
||||
@xfailIfTorchDynamo
|
||||
@dtypes(torch.float32)
|
||||
def test_composite_op_in_inference_mode(self, device, dtype):
|
||||
# expect view
|
||||
|
@ -343,8 +343,9 @@ observed_exception_map = {
|
||||
|
||||
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
|
||||
if exc_type not in observed_exception_map:
|
||||
name = getattr(exc_type, "__name__", str(exc_type))
|
||||
observed_exception_map[exc_type] = type(
|
||||
f"Observed{exc_type.__name__}Error", (ObservedException,), {}
|
||||
f"Observed{name}Error", (ObservedException,), {}
|
||||
)
|
||||
return observed_exception_map[exc_type]
|
||||
|
||||
|
@ -256,6 +256,7 @@ class SideEffects:
|
||||
int.__getattribute__,
|
||||
str.__getattribute__,
|
||||
list.__getattribute__,
|
||||
BaseException.__getattribute__,
|
||||
)
|
||||
|
||||
def is_attribute_mutation(self, item):
|
||||
@ -377,6 +378,8 @@ class SideEffects:
|
||||
variable_cls = variables.MutableMappingVariable
|
||||
elif is_frozen_dataclass(user_cls):
|
||||
variable_cls = FrozenDataClassVariable
|
||||
elif issubclass(user_cls, BaseException):
|
||||
variable_cls = variables.UserDefinedExceptionObjectVariable
|
||||
assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
|
||||
return variable_cls
|
||||
|
||||
|
@ -160,6 +160,8 @@ from .variables.torch_function import (
|
||||
from .variables.user_defined import (
|
||||
RemovableHandleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedExceptionClassVariable,
|
||||
UserDefinedExceptionObjectVariable,
|
||||
UserDefinedObjectVariable,
|
||||
)
|
||||
|
||||
@ -1660,7 +1662,9 @@ class InstructionTranslatorBase(
|
||||
# 2) raise execption instance - raise NotImplemetedError("foo")
|
||||
|
||||
# 1) when user raises exception type
|
||||
if isinstance(val, variables.BuiltinVariable):
|
||||
if isinstance(
|
||||
val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable)
|
||||
):
|
||||
# Create the instance of the exception type
|
||||
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
|
||||
val = val.call_function(self, [], {}) # type: ignore[arg-type]
|
||||
@ -1677,10 +1681,9 @@ class InstructionTranslatorBase(
|
||||
self.exn_vt_stack.append(val)
|
||||
|
||||
# 2) when user raises exception instance
|
||||
if isinstance(val, variables.ExceptionVariable):
|
||||
if observed_exception_type := exc.observed_exception_map.get(val.exc_type):
|
||||
raise observed_exception_type(f"raised exception {val}")
|
||||
raise exc.ObservedException(f"raised exception {val}")
|
||||
if self._isinstance_exception(val):
|
||||
observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined]
|
||||
raise observed_exception_type(f"raised exception {val}")
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to raise exception",
|
||||
context=str(exc),
|
||||
@ -1699,7 +1702,7 @@ class InstructionTranslatorBase(
|
||||
"in Python < 3.11 (empty `raise`)",
|
||||
hints=[],
|
||||
)
|
||||
assert isinstance(self.stack[-1], ExceptionVariable)
|
||||
assert self._isinstance_exception(self.stack[-1])
|
||||
self.stack.append(self.stack[-1])
|
||||
self._raise_exception_variable(inst)
|
||||
elif inst.arg == 1:
|
||||
@ -1743,6 +1746,16 @@ class InstructionTranslatorBase(
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def _isinstance_exception(self, val):
|
||||
return isinstance(
|
||||
val,
|
||||
(
|
||||
variables.ExceptionVariable,
|
||||
UserDefinedExceptionClassVariable,
|
||||
UserDefinedExceptionObjectVariable,
|
||||
),
|
||||
)
|
||||
|
||||
def WITH_EXCEPT_START(self, inst):
|
||||
if sys.version_info >= (3, 11):
|
||||
# At the top of the stack are 4 values:
|
||||
@ -1755,15 +1768,15 @@ class InstructionTranslatorBase(
|
||||
assert len(self.stack) >= 4
|
||||
fn = self.stack[-4]
|
||||
val = self.stack[-1]
|
||||
assert isinstance(val, variables.ExceptionVariable)
|
||||
typ = BuiltinVariable(val.exc_type)
|
||||
assert self._isinstance_exception(val)
|
||||
typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
|
||||
tb = ConstantVariable(None)
|
||||
else:
|
||||
assert len(self.stack) >= 7
|
||||
fn = self.stack[-7]
|
||||
val = self.stack[-4]
|
||||
assert isinstance(val, variables.ExceptionVariable)
|
||||
typ = BuiltinVariable(val.exc_type)
|
||||
assert self._isinstance_exception(val)
|
||||
typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
|
||||
tb = ConstantVariable(None)
|
||||
|
||||
self.call_function(fn, [typ, val, tb], {})
|
||||
@ -1910,8 +1923,7 @@ class InstructionTranslatorBase(
|
||||
def POP_EXCEPT(self, inst):
|
||||
if sys.version_info >= (3, 11):
|
||||
val = self.pop()
|
||||
assert isinstance(val, variables.ExceptionVariable)
|
||||
|
||||
assert self._isinstance_exception(val)
|
||||
# This exception is handled and therefore we can clear the error indicator
|
||||
assert len(self.exn_vt_stack)
|
||||
self.exn_vt_stack.pop()
|
||||
@ -1947,11 +1959,20 @@ class InstructionTranslatorBase(
|
||||
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
|
||||
exc_instance = self.stack.pop()
|
||||
|
||||
# Users can check exception in 2 ways
|
||||
# 1) except NotImplementedError --> BuilinVariable
|
||||
# 2) except (NotImplemetedError, AttributeError) -> TupleVariable
|
||||
# Users can check exception in 3 ways
|
||||
# 1) except NotImplementedError --> BuiltinVariable
|
||||
# 2) except CustomException --> UserDefinedExceptionClasVariable
|
||||
# 3) except (NotImplemetedError, AttributeError) -> TupleVariable
|
||||
|
||||
if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)):
|
||||
if not isinstance(
|
||||
expected_exc_types,
|
||||
(
|
||||
BuiltinVariable,
|
||||
TupleVariable,
|
||||
UserDefinedExceptionClassVariable,
|
||||
UserDefinedExceptionObjectVariable,
|
||||
),
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Exception with bad expected type",
|
||||
context=str(expected_exc_types),
|
||||
@ -1960,7 +1981,7 @@ class InstructionTranslatorBase(
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
if not isinstance(exc_instance, variables.ExceptionVariable):
|
||||
if not self._isinstance_exception(exc_instance):
|
||||
unimplemented_v2(
|
||||
gb_type="Caught non-Exception value",
|
||||
context=str(exc_instance),
|
||||
@ -1976,15 +1997,23 @@ class InstructionTranslatorBase(
|
||||
]
|
||||
|
||||
for expected_type in expected_types:
|
||||
if not isinstance(expected_type, BuiltinVariable):
|
||||
if not isinstance(
|
||||
expected_type,
|
||||
(
|
||||
BuiltinVariable,
|
||||
UserDefinedExceptionObjectVariable,
|
||||
UserDefinedExceptionClassVariable,
|
||||
),
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Exception with non-type expectation",
|
||||
context=str(expected_type),
|
||||
explanation=f"`except ...` expects a non-type: {expected_type}.",
|
||||
hints=[*graph_break_hints.USER_ERROR],
|
||||
)
|
||||
if isinstance(exc_instance, variables.ExceptionVariable) and issubclass(
|
||||
exc_instance.exc_type, expected_type.fn
|
||||
if self._isinstance_exception(exc_instance) and issubclass(
|
||||
exc_instance.exc_type, # type: ignore[attr-defined]
|
||||
expected_type.fn, # type: ignore[attr-defined]
|
||||
):
|
||||
return True
|
||||
elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
|
||||
@ -2608,8 +2637,8 @@ class InstructionTranslatorBase(
|
||||
# https://github.com/python/cpython/pull/99006
|
||||
# https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2
|
||||
val = self.stack[-1]
|
||||
assert isinstance(val, ExceptionVariable)
|
||||
if val.exc_type is StopIteration:
|
||||
assert self._isinstance_exception(val)
|
||||
if val.exc_type is StopIteration: # type: ignore[attr-defined]
|
||||
new_val = variables.BuiltinVariable(RuntimeError).call_function(
|
||||
self, # type: ignore[arg-type]
|
||||
[],
|
||||
|
@ -137,6 +137,8 @@ from .user_defined import (
|
||||
RemovableHandleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedExceptionClassVariable,
|
||||
UserDefinedExceptionObjectVariable,
|
||||
UserDefinedListVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedTupleVariable,
|
||||
|
@ -257,6 +257,7 @@ from .user_defined import (
|
||||
SourcelessGraphModuleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedExceptionClassVariable,
|
||||
UserDefinedListVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedTupleVariable,
|
||||
@ -1162,6 +1163,10 @@ class VariableBuilder:
|
||||
# insert a FUNCTION_MATCH guard here. method-wrappers are very
|
||||
# unlikely to change, so its ok to skip the guard here.
|
||||
return MethodWrapperVariable(value)
|
||||
elif issubclass(type(value), type) and issubclass(value, BaseException):
|
||||
# match user defined exceptions
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return UserDefinedExceptionClassVariable(value)
|
||||
elif issubclass(type(value), type):
|
||||
if value in (
|
||||
torch.utils.hooks.BackwardHook,
|
||||
|
@ -21,6 +21,7 @@ These classes help Dynamo track and handle arbitrary Python objects during traci
|
||||
maintaining proper semantics while enabling optimizations where possible.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
import contextlib
|
||||
import dataclasses
|
||||
@ -99,7 +100,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def is_standard_setattr(val):
|
||||
return val in (object.__setattr__,)
|
||||
return val in (object.__setattr__, BaseException.__setattr__)
|
||||
|
||||
|
||||
def is_forbidden_context_manager(ctx):
|
||||
@ -138,7 +139,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
return self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"UserDefinedClassVariable({self.value})"
|
||||
return f"{self.__class__.__name__}({self.value})"
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
@ -171,12 +172,18 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def supported_c_new_functions():
|
||||
exceptions = [
|
||||
getattr(builtins, name).__new__
|
||||
for name in dir(builtins)
|
||||
if isinstance(getattr(builtins, name), type)
|
||||
and issubclass(getattr(builtins, name), BaseException)
|
||||
]
|
||||
return {
|
||||
object.__new__,
|
||||
dict.__new__,
|
||||
tuple.__new__,
|
||||
list.__new__,
|
||||
}
|
||||
}.union(exceptions)
|
||||
|
||||
@staticmethod
|
||||
def is_supported_new_method(value):
|
||||
@ -689,6 +696,16 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
return super().const_getattr(tx, name)
|
||||
|
||||
|
||||
class UserDefinedExceptionClassVariable(UserDefinedClassVariable):
|
||||
@property
|
||||
def fn(self):
|
||||
return self.value
|
||||
|
||||
@property
|
||||
def python_type(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class NO_SUCH_SUBOBJ:
|
||||
pass
|
||||
|
||||
@ -1393,6 +1410,39 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
|
||||
)
|
||||
|
||||
|
||||
class UserDefinedExceptionObjectVariable(UserDefinedObjectVariable):
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self.exc_vt = variables.ExceptionVariable(self.value_type, ())
|
||||
|
||||
@property
|
||||
def fn(self):
|
||||
return self.value_type
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
if (
|
||||
name == "__init__"
|
||||
and (method := self._maybe_get_baseclass_method(name))
|
||||
and inspect.ismethoddescriptor(method)
|
||||
and len(kwargs) == 0
|
||||
):
|
||||
self.exc_vt.args = args
|
||||
self.value.args = args
|
||||
return variables.ConstantVariable(None)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@property
|
||||
def __context__(self):
|
||||
return self.exc_vt.__context__
|
||||
|
||||
def set_context(self, context: "variables.ExceptionVariable"):
|
||||
return self.exc_vt.set_context(context)
|
||||
|
||||
@property
|
||||
def exc_type(self):
|
||||
return self.exc_vt.exc_type
|
||||
|
||||
|
||||
class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
|
||||
@staticmethod
|
||||
def is_matching_object(obj):
|
||||
|
@ -235,6 +235,7 @@ dtype_abbrs = {
|
||||
torch.uint32: "u32",
|
||||
torch.uint64: "u64",
|
||||
torch.bits16: "b16",
|
||||
torch.bits1x8: "b1x8",
|
||||
}
|
||||
|
||||
|
||||
@ -619,8 +620,10 @@ class CodeGen:
|
||||
node.meta.get("tensor_meta", node.meta.get("example_value", None)),
|
||||
)
|
||||
# use string as annotation, to make it valid python code
|
||||
|
||||
if isinstance(meta_val, torch.Tensor):
|
||||
if isinstance(meta_val, torch.Tensor) and meta_val.layout not in (
|
||||
torch.sparse_csc,
|
||||
torch.sparse_csr,
|
||||
):
|
||||
stride_annotation = (
|
||||
f"{stringify_shape(meta_val.stride())}"
|
||||
if include_stride
|
||||
|
Reference in New Issue
Block a user