Introduce UserDefinedExceptionClassVariable (#146504)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146504
Approved by: https://github.com/anijain2305
This commit is contained in:
Guilherme Leobas
2025-03-11 13:33:09 +00:00
committed by PyTorch MergeBot
parent 8d08b49015
commit 4e7d264cf8
14 changed files with 179 additions and 32 deletions

View File

@ -2696,7 +2696,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
# https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70
@unittest.expectedFailure
def test_contextmanager_plain(self):
state = []
@ -2977,7 +2976,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
with woohoo():
raise StopIteration
@unittest.expectedFailure
def test_keywords(self):
# Ensure no keyword arguments are inhibited
@contextmanager
@ -2991,7 +2989,6 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
fn(torch.randn(2, 3))
@unittest.expectedFailure
def test_recursive(self):
depth = 0
ncols = 0

View File

@ -13,6 +13,16 @@ from torch._dynamo.bytecode_transformation import Instruction
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
class CustomException(Exception):
...
class CustomExceptionWithArgs(Exception):
def __init__(self, a, b=None):
self.a = a
self.b = b
class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_exception(self):
def fn(x):
@ -215,6 +225,23 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
got = opt_fn(x)
self.assertEqual(expected, got)
def test_raise_custom_exception(self):
class Exc(Exception):
...
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
try:
raise Exc
except Exc:
return t.sin()
except Exception:
return t.cos()
t = torch.randn(2)
y = fn(t)
self.assertEqual(y, t.sin())
def test_nn_module_getattr(self):
class A:
def __init__(self) -> None:
@ -468,6 +495,37 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
def test_user_defined_exception_variable(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
z = 0
try:
raise CustomException
except ValueError:
z = 1
except CustomException:
z = 2
assert z == 2
return t.sin()
t = torch.randn(2)
fn(t)
def test_user_defined_exception_with_args(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
z = 0
try:
raise CustomExceptionWithArgs(2, b=3)
except ValueError:
z = 1
except CustomExceptionWithArgs:
z = 2
assert z == 2
t = torch.randn(2)
fn(t)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -7963,7 +7963,6 @@ torch.cuda.synchronize()
self.assertEqual(output._metadata_cache, cache)
# See https://github.com/pytorch/pytorch/issues/128649
@xfailIfTorchDynamo
@dtypes(torch.float32)
def test_composite_op_in_inference_mode(self, device, dtype):
# expect view

View File

@ -343,8 +343,9 @@ observed_exception_map = {
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
if exc_type not in observed_exception_map:
name = getattr(exc_type, "__name__", str(exc_type))
observed_exception_map[exc_type] = type(
f"Observed{exc_type.__name__}Error", (ObservedException,), {}
f"Observed{name}Error", (ObservedException,), {}
)
return observed_exception_map[exc_type]

View File

@ -256,6 +256,7 @@ class SideEffects:
int.__getattribute__,
str.__getattribute__,
list.__getattribute__,
BaseException.__getattribute__,
)
def is_attribute_mutation(self, item):
@ -377,6 +378,8 @@ class SideEffects:
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
variable_cls = FrozenDataClassVariable
elif issubclass(user_cls, BaseException):
variable_cls = variables.UserDefinedExceptionObjectVariable
assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
return variable_cls

View File

@ -160,6 +160,8 @@ from .variables.torch_function import (
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedExceptionClassVariable,
UserDefinedExceptionObjectVariable,
UserDefinedObjectVariable,
)
@ -1660,7 +1662,9 @@ class InstructionTranslatorBase(
# 2) raise execption instance - raise NotImplemetedError("foo")
# 1) when user raises exception type
if isinstance(val, variables.BuiltinVariable):
if isinstance(
val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable)
):
# Create the instance of the exception type
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
val = val.call_function(self, [], {}) # type: ignore[arg-type]
@ -1677,10 +1681,9 @@ class InstructionTranslatorBase(
self.exn_vt_stack.append(val)
# 2) when user raises exception instance
if isinstance(val, variables.ExceptionVariable):
if observed_exception_type := exc.observed_exception_map.get(val.exc_type):
if self._isinstance_exception(val):
observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined]
raise observed_exception_type(f"raised exception {val}")
raise exc.ObservedException(f"raised exception {val}")
unimplemented_v2(
gb_type="Failed to raise exception",
context=str(exc),
@ -1699,7 +1702,7 @@ class InstructionTranslatorBase(
"in Python < 3.11 (empty `raise`)",
hints=[],
)
assert isinstance(self.stack[-1], ExceptionVariable)
assert self._isinstance_exception(self.stack[-1])
self.stack.append(self.stack[-1])
self._raise_exception_variable(inst)
elif inst.arg == 1:
@ -1743,6 +1746,16 @@ class InstructionTranslatorBase(
hints=[],
)
def _isinstance_exception(self, val):
return isinstance(
val,
(
variables.ExceptionVariable,
UserDefinedExceptionClassVariable,
UserDefinedExceptionObjectVariable,
),
)
def WITH_EXCEPT_START(self, inst):
if sys.version_info >= (3, 11):
# At the top of the stack are 4 values:
@ -1755,15 +1768,15 @@ class InstructionTranslatorBase(
assert len(self.stack) >= 4
fn = self.stack[-4]
val = self.stack[-1]
assert isinstance(val, variables.ExceptionVariable)
typ = BuiltinVariable(val.exc_type)
assert self._isinstance_exception(val)
typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
tb = ConstantVariable(None)
else:
assert len(self.stack) >= 7
fn = self.stack[-7]
val = self.stack[-4]
assert isinstance(val, variables.ExceptionVariable)
typ = BuiltinVariable(val.exc_type)
assert self._isinstance_exception(val)
typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
tb = ConstantVariable(None)
self.call_function(fn, [typ, val, tb], {})
@ -1910,8 +1923,7 @@ class InstructionTranslatorBase(
def POP_EXCEPT(self, inst):
if sys.version_info >= (3, 11):
val = self.pop()
assert isinstance(val, variables.ExceptionVariable)
assert self._isinstance_exception(val)
# This exception is handled and therefore we can clear the error indicator
assert len(self.exn_vt_stack)
self.exn_vt_stack.pop()
@ -1947,11 +1959,20 @@ class InstructionTranslatorBase(
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
exc_instance = self.stack.pop()
# Users can check exception in 2 ways
# 1) except NotImplementedError --> BuilinVariable
# 2) except (NotImplemetedError, AttributeError) -> TupleVariable
# Users can check exception in 3 ways
# 1) except NotImplementedError --> BuiltinVariable
# 2) except CustomException --> UserDefinedExceptionClasVariable
# 3) except (NotImplemetedError, AttributeError) -> TupleVariable
if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)):
if not isinstance(
expected_exc_types,
(
BuiltinVariable,
TupleVariable,
UserDefinedExceptionClassVariable,
UserDefinedExceptionObjectVariable,
),
):
unimplemented_v2(
gb_type="Exception with bad expected type",
context=str(expected_exc_types),
@ -1960,7 +1981,7 @@ class InstructionTranslatorBase(
)
if sys.version_info >= (3, 11):
if not isinstance(exc_instance, variables.ExceptionVariable):
if not self._isinstance_exception(exc_instance):
unimplemented_v2(
gb_type="Caught non-Exception value",
context=str(exc_instance),
@ -1976,15 +1997,23 @@ class InstructionTranslatorBase(
]
for expected_type in expected_types:
if not isinstance(expected_type, BuiltinVariable):
if not isinstance(
expected_type,
(
BuiltinVariable,
UserDefinedExceptionObjectVariable,
UserDefinedExceptionClassVariable,
),
):
unimplemented_v2(
gb_type="Exception with non-type expectation",
context=str(expected_type),
explanation=f"`except ...` expects a non-type: {expected_type}.",
hints=[*graph_break_hints.USER_ERROR],
)
if isinstance(exc_instance, variables.ExceptionVariable) and issubclass(
exc_instance.exc_type, expected_type.fn
if self._isinstance_exception(exc_instance) and issubclass(
exc_instance.exc_type, # type: ignore[attr-defined]
expected_type.fn, # type: ignore[attr-defined]
):
return True
elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
@ -2608,8 +2637,8 @@ class InstructionTranslatorBase(
# https://github.com/python/cpython/pull/99006
# https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2
val = self.stack[-1]
assert isinstance(val, ExceptionVariable)
if val.exc_type is StopIteration:
assert self._isinstance_exception(val)
if val.exc_type is StopIteration: # type: ignore[attr-defined]
new_val = variables.BuiltinVariable(RuntimeError).call_function(
self, # type: ignore[arg-type]
[],

View File

@ -137,6 +137,8 @@ from .user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedExceptionClassVariable,
UserDefinedExceptionObjectVariable,
UserDefinedListVariable,
UserDefinedObjectVariable,
UserDefinedTupleVariable,

View File

@ -257,6 +257,7 @@ from .user_defined import (
SourcelessGraphModuleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedExceptionClassVariable,
UserDefinedListVariable,
UserDefinedObjectVariable,
UserDefinedTupleVariable,
@ -1162,6 +1163,10 @@ class VariableBuilder:
# insert a FUNCTION_MATCH guard here. method-wrappers are very
# unlikely to change, so its ok to skip the guard here.
return MethodWrapperVariable(value)
elif issubclass(type(value), type) and issubclass(value, BaseException):
# match user defined exceptions
self.install_guards(GuardBuilder.ID_MATCH)
return UserDefinedExceptionClassVariable(value)
elif issubclass(type(value), type):
if value in (
torch.utils.hooks.BackwardHook,

View File

@ -21,6 +21,7 @@ These classes help Dynamo track and handle arbitrary Python objects during traci
maintaining proper semantics while enabling optimizations where possible.
"""
import builtins
import collections
import contextlib
import dataclasses
@ -99,7 +100,7 @@ if TYPE_CHECKING:
def is_standard_setattr(val):
return val in (object.__setattr__,)
return val in (object.__setattr__, BaseException.__setattr__)
def is_forbidden_context_manager(ctx):
@ -138,7 +139,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
return self.value
def __repr__(self) -> str:
return f"UserDefinedClassVariable({self.value})"
return f"{self.__class__.__name__}({self.value})"
@staticmethod
@functools.lru_cache(None)
@ -171,12 +172,18 @@ class UserDefinedClassVariable(UserDefinedVariable):
@staticmethod
@functools.lru_cache(None)
def supported_c_new_functions():
exceptions = [
getattr(builtins, name).__new__
for name in dir(builtins)
if isinstance(getattr(builtins, name), type)
and issubclass(getattr(builtins, name), BaseException)
]
return {
object.__new__,
dict.__new__,
tuple.__new__,
list.__new__,
}
}.union(exceptions)
@staticmethod
def is_supported_new_method(value):
@ -689,6 +696,16 @@ class UserDefinedClassVariable(UserDefinedVariable):
return super().const_getattr(tx, name)
class UserDefinedExceptionClassVariable(UserDefinedClassVariable):
@property
def fn(self):
return self.value
@property
def python_type(self):
return self.value
class NO_SUCH_SUBOBJ:
pass
@ -1393,6 +1410,39 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
)
class UserDefinedExceptionObjectVariable(UserDefinedObjectVariable):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.exc_vt = variables.ExceptionVariable(self.value_type, ())
@property
def fn(self):
return self.value_type
def call_method(self, tx, name, args, kwargs):
if (
name == "__init__"
and (method := self._maybe_get_baseclass_method(name))
and inspect.ismethoddescriptor(method)
and len(kwargs) == 0
):
self.exc_vt.args = args
self.value.args = args
return variables.ConstantVariable(None)
return super().call_method(tx, name, args, kwargs)
@property
def __context__(self):
return self.exc_vt.__context__
def set_context(self, context: "variables.ExceptionVariable"):
return self.exc_vt.set_context(context)
@property
def exc_type(self):
return self.exc_vt.exc_type
class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
@staticmethod
def is_matching_object(obj):

View File

@ -235,6 +235,7 @@ dtype_abbrs = {
torch.uint32: "u32",
torch.uint64: "u64",
torch.bits16: "b16",
torch.bits1x8: "b1x8",
}
@ -619,8 +620,10 @@ class CodeGen:
node.meta.get("tensor_meta", node.meta.get("example_value", None)),
)
# use string as annotation, to make it valid python code
if isinstance(meta_val, torch.Tensor):
if isinstance(meta_val, torch.Tensor) and meta_val.layout not in (
torch.sparse_csc,
torch.sparse_csr,
):
stride_annotation = (
f"{stringify_shape(meta_val.stride())}"
if include_stride